8wDlpd.png
8wDFp9.png
8wDEOx.png
8wDMfH.png
8wDKte.png

PowerTransform 双射中的 lambda 参数的 MLE

mxmissile 2月前

24 0

我正在尝试找到具有最大似然估计的 PowerTransform 双射器的最佳 lambda 参数。为了做到这一点,我必须修改双射器的构造函数,以便允许

我正在尝试用最大似然估计来找到双射器的最优 lambda 参数。 PowerTransform 为了做到这一点,我必须修改双射器的构造函数以使其 power 可训练 tf.Variable 。然后代码如下:

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow_probability import bijectors as tfb
from tensorflow_probability import distributions as tfd
from scipy.stats import boxcox

np.random.seed(42)
tf.random.set_seed(42)

# Data exponential generation
data_dist = tfd.Sample(tfd.Exponential(1.), sample_shape=1000)
x_train = data_dist.sample()
plt.hist(x_train.numpy(), bins=120, density=True, alpha=0.6, color='blue', label='Samples')
plt.show()

nf = tfd.TransformedDistribution(
    tfd.Normal(loc=0, scale=1),
    bijector=tfb.PowerTransform(power=tf.Variable(initial_value=1., name='power'))
)

# Training loop
num_steps = 2000
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
for step in range(num_steps):
    with tf.GradientTape() as tape:
        loss = -tf.reduce_sum(nf.log_prob(x_train))
        grads = tape.gradient(loss, nf.trainable_variables)
    optimizer.apply_gradients(zip(grads, nf.trainable_variables))

    if step % 100 == 0:
        print(f"Step {step}, Loss: {loss.numpy()}, Grads: {grads}, Power: {nf.trainable_variables[0].numpy()}")

_, llm_lmbda = boxcox(x_train, lmbda=None)
print(f"Scipy MLE power is: {llm_lmbda}") # 0.24647694003577084
print(f"My MLE power is: {nf.trainable_variables[0].numpy()}") # 0.6407918334007263

z_samples = nf.sample(1000)
plt.hist(z_samples.numpy(), bins=120, density=True, alpha=0.6, color='green', label='Samples')
plt.show()

使用 scipy 库中的 MLE 进行估计给出, lambda = 0.25 但我给出 lambda = 0.64 。如果我使用静态值为 0.25 的双射器,我可以恢复更接近原始指数的分布,因此我相信训练程序或双射器中前向雅可比矩阵的计算可能有问题, PowerTransform 但我找不到它。

有人可以帮忙吗?

帖子版权声明 1、本帖标题:PowerTransform 双射中的 lambda 参数的 MLE
    本站网址:http://xjnalaquan.com/
2、本网站的资源部分来源于网络,如有侵权,请联系站长进行删除处理。
3、会员发帖仅代表会员个人观点,并不代表本站赞同其观点和对其真实性负责。
4、本站一律禁止以任何方式发布或转载任何违法的相关信息,访客发现请向站长举报
5、站长邮箱:yeweds@126.com 除非注明,本帖由mxmissile在本站《tensorflow》版块原创发布, 转载请注明出处!
最新回复 (0)
返回
作者最近主题: