8wDlpd.png
8wDFp9.png
8wDEOx.png
8wDMfH.png
8wDKte.png

Tensorflow/Keras Transformer 很难预测序列中的最后一个位置,但在其他所有位置上都表现良好

dbrrt 2月前

12 0

我正在使用 Transformer 进行下一帧预测。每个帧都已使用 VAE 编码为 1D 潜在向量(VAE 的编码解码非常好)。我将

我正在使用 Transformer 进行下一帧预测。每个帧都已使用 VAE 预先编码为 1D 潜在向量(VAE 的编码解码效果非常好)。

我将视频分成 24 帧的块,并使用 VAE 创建了 24,100 这样的序列(其中 24 是序列长度,100 是来自 VAE 的潜在向量)。这些序列被输入到一个小型 Transformer(类似 GPT,只有带有随意注意的解码器)。由于它是批处理的,因此实际输入为 64,24,100。我们将其中一个称为“源”。我向 Transformer 输入“源 [:,:-1]”,目标是“源 [:,1:]”。损失函数是 MSE。当我检查 Transformer 的输出时,它在所有序列位置(0,23)中都做出了出色的预测,除了最后一个位置(24)。最后一个还算可以。不是随机的,但也不惊人。例如,所有其他位置的 MSE 损失约为 0.006,而最后一个位置的 MSE 损失为 0.15。

这些都表明我没有使用随意掩蔽——因为它在所有可以看到未来的序列条目中表现良好,但在无法看到未来的最后一个位置上却表现不佳。但我使用的是随意注意,所以没有任何条目应该看到未来:

class BaseAttention(tf.keras.layers.Layer):
  def __init__(self, **kwargs):
    super().__init__()
    self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)
    self.layernorm = tf.keras.layers.LayerNormalization()
    self.add = tf.keras.layers.Add()

class CausalSelfAttention(BaseAttention):
  def call(self, x):
    attn_output = self.mha(
        query=x,
        value=x,
        key=x,
        use_causal_mask = True)
    x = self.add([x, attn_output])
    x = self.layernorm(x)
    return x

变压器块使用如下代码:

class TransformerBlock(layers.Layer):
    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.3):
        super(TransformerBlock, self).__init__()
        self.att = CausalSelfAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=0.1)
        self.ffn = keras.Sequential(
            [layers.Dense(ff_dim, activation="relu"),
             layers.Dense(embed_dim), ]
        )
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout2 = layers.Dropout(rate)


    def call(self, inputs):
        attention_output = self.att(inputs)
        ffn_output = self.ffn(attention_output)
        ffn_output = self.dropout2(ffn_output)
        return self.layernorm2(attention_output + ffn_output)

在此之前,我自己计算掩码并在创建 MHA 时发送它们,而不是使用“use_casual_mask=True”。但结果是一样的。

在这里你还可以看到我的 train_step,你可以看到我快速输入和预测的窗口:

    def train_step(self, batch):
        """Processes one batch inside model.fit()."""
        source = batch[0]/4.  # [0] are the sequences of latent vectors
        source_input = source[:, :-1]
        source_target = source[:, 1:]
        with tf.GradientTape() as tape:
            preds = self(source_input)
            loss = self.compiled_loss(source_target, preds)*4.
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        self.loss_metric.update_state(loss)

        return {"loss": self.loss_metric.result()}

在这里您可以看到我的嵌入和位置嵌入:

class PositionEmbedding(layers.Layer):
    def __init__(self, maxlen, embed_dim):
        super(PositionEmbedding, self).__init__()
        # self.latent_emb = layers.Dense(embed_dim)
        self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)
        self.latent_emb = keras.Sequential(
            [layers.Conv2D(1, (5,5), padding="same", activation="relu"),
             layers.Reshape([maxlen, 100]),
             layers.Conv1D(embed_dim, 3, padding="same"), ]
        )

    def call(self, x):

        maxlen = tf.shape(x)[1]
        positions = tf.range(start=0, limit=maxlen, delta=1)
        positions = self.pos_emb(positions)

        x = self.latent_emb(tf.expand_dims(x,-1))
        # x = self.latent_emb(x)
        return x + positions

如您所见,我尝试了两件事(如果您查看注释)。首先,我尝试使用 Dense 层作为嵌入,仅此而已。然后我尝试使用 Conv2d 和 Conv1d(现在取消注释的那个)。结果是一样的 - 使用 Convs 会好一点,但好不了多少。[0,-1] 和 [-1] 之间的预测差异仍然存在。

我很困惑为什么会有这么大的差异,因为使用 Casual Masks 不应该这样。我还检查了问题是否出在进行 Conv 的填充上,但第一个预测(行)似乎没有受到影响。第一列和最后一列似乎也没有问题。唯一的问题是最后一行(新的预测)。我尝试了不同的填充技术,但问题仍然存在。如果 Conv 填充是问题所在,那么你会期望它出现在数据周围,而不仅仅是最后一行。

总的来说,我很困惑,因为这一切似乎都表明它忽略了随意的面具,但据我所知,它正在使用它。

帖子版权声明 1、本帖标题:Tensorflow/Keras Transformer 很难预测序列中的最后一个位置,但在其他所有位置上都表现良好
    本站网址:http://xjnalaquan.com/
2、本网站的资源部分来源于网络,如有侵权,请联系站长进行删除处理。
3、会员发帖仅代表会员个人观点,并不代表本站赞同其观点和对其真实性负责。
4、本站一律禁止以任何方式发布或转载任何违法的相关信息,访客发现请向站长举报
5、站长邮箱:yeweds@126.com 除非注明,本帖由dbrrt在本站《tensorflow》版块原创发布, 转载请注明出处!
最新回复 (0)
返回
作者最近主题: