8wDlpd.png
8wDFp9.png
8wDEOx.png
8wDMfH.png
8wDKte.png

如何解决VAE训练中的梯度爆炸问题?

Ganesh Gebhard 2月前

21 0

受 MNIST 的 Tensorflow 实现启发,我尝试在 CelebA 数据集上实现 VAE。我尝试过改变批处理大小,但似乎没有效果。形成的图像是

受 MNIST 的 Tensorflow 实现启发,我尝试在 CelebA 数据集上实现 VAE。我尝试过改变批处理大小,但似乎没有效果。形成的图像大部分都是灰色的。理想情况下,我们希望 KL 散度和重建损失都接近于零,但在我的例子中,两者都呈指数增长。
这是我得到的损失曲线。 以下是损失函数定义块:

optimizer = tf.keras.optimizers.Adam(1e-4)
def log_normal_pdf(sample, mean, logvar, raxis=1):
  log2pi = tf.math.log(2. * np.pi)
  return tf.reduce_sum(
      -.5 * ((sample - mean) ** 2. * tf.exp(-logvar) + logvar + log2pi),
      axis=raxis)
  
def compute_loss(model, x):
  mean, logvar = model.encode(x)
  z = model.reparameterize(mean, logvar)
  x_logit = model.decode(z)
  cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logit, labels=x)
  #logpx_z = tf.reduce_mean(tf.square(x - x_logit), axis=[1, 2, 3])
  logpx_z = -tf.reduce_sum(cross_ent, axis=[1, 2, 3])
  logpz = log_normal_pdf(z, 0., 0.)
  logqz_x = log_normal_pdf(z, mean, logvar)
  return -tf.reduce_mean(logpx_z + logpz - logqz_x), logpx_z, logqz_x-logpz

我的潜在维度是 16,批量大小是 500。此外,我的输入只有 500 张图像。

我已经尝试改变输入的大小,但似乎没有影响。这是模型定义:

class CVAE(tf.keras.Model):
    def __init__(self, latent_dim):
        super(CVAE, self).__init__()
        self.latent_dim = latent_dim
        self.encoder = tf.keras.Sequential(
            [
                tf.keras.layers.InputLayer(input_shape=(64, 64, 3)),
                tf.keras.layers.Conv2D(
                    filters=32, kernel_size=3, strides=(2, 2), activation='relu'),
                tf.keras.layers.BatchNormalization(),
                tf.keras.layers.Conv2D(
                    filters=64, kernel_size=3, strides=(2, 2), activation='relu'),
                tf.keras.layers.BatchNormalization(),
                tf.keras.layers.Conv2D(
                    filters=128, kernel_size=3, strides=(2, 2), activation='relu'),
                tf.keras.layers.BatchNormalization(),
                tf.keras.layers.Conv2D(
                    filters=256, kernel_size=3, strides=(2, 2), activation='relu'),
                tf.keras.layers.BatchNormalization(),
                tf.keras.layers.Conv2D(
                    filters=512, kernel_size=3, strides=(2, 2), activation='relu'),
                tf.keras.layers.BatchNormalization(),
                tf.keras.layers.Flatten(),
                # No activation
                tf.keras.layers.Dense(latent_dim + latent_dim),
            ]
        )

        self.decoder = tf.keras.Sequential(
            [
                tf.keras.layers.InputLayer(input_shape=(latent_dim,)),
                tf.keras.layers.Dense(units=4*4*256, activation=tf.nn.relu),
                tf.keras.layers.Reshape(target_shape=(4, 4, 256)),
                tf.keras.layers.Conv2DTranspose(
                    filters=128, kernel_size=3, strides=2, padding='same',
                    activation='relu'),
                tf.keras.layers.Conv2DTranspose(
                    filters=64, kernel_size=3, strides=2, padding='same',
                    activation='relu'),
                tf.keras.layers.Conv2DTranspose(
                    filters=32, kernel_size=3, strides=2, padding='same',
                    activation='relu'),
                tf.keras.layers.Conv2DTranspose(
                    filters=3, kernel_size=3, strides=2, padding='same'),
            ]
        )
    @tf.function
    def sample(self, eps=None):
        if eps is None:
            eps = tf.random.normal(shape=(100, self.latent_dim))
        return self.decode(eps, apply_sigmoid=True)

    def encode(self, x):
        mean, logvar = tf.split(self.encoder(x), num_or_size_splits=2, axis=1)
        return mean, logvar

    def reparameterize(self, mean, logvar):
        eps = tf.random.normal(shape=mean.shape)
        return eps * tf.exp(logvar * .5) + mean

    def decode(self, z, apply_sigmoid=False):
        logits = self.decoder(z)
        if apply_sigmoid:
            probs = tf.sigmoid(logits)
            return probs
        return logits

这是 colab 笔记本的链接

帖子版权声明 1、本帖标题:如何解决VAE训练中的梯度爆炸问题?
    本站网址:http://xjnalaquan.com/
2、本网站的资源部分来源于网络,如有侵权,请联系站长进行删除处理。
3、会员发帖仅代表会员个人观点,并不代表本站赞同其观点和对其真实性负责。
4、本站一律禁止以任何方式发布或转载任何违法的相关信息,访客发现请向站长举报
5、站长邮箱:yeweds@126.com 除非注明,本帖由Ganesh Gebhard在本站《tensorflow》版块原创发布, 转载请注明出处!
最新回复 (0)
返回
作者最近主题: