I'm trying to train an LSTM model followed by a fully connected layer to classify a set of EEG time-series data with 22 channels (sequence length 1000).
I'm building the layers in PyTorch, but during training the validation accuracy does not improve.
If anything, it gets worse.
I assumed there was a lot of overfitting going on here, but I'm confused why the validation accuracy still sits at the same level as picking a class at random (I have 4 classes, so guessing on balanced classes gives about 25%, which matches the ~24-28% validation accuracy below; see the quick check after this paragraph).
Does anyone know why this is happening?
I'm attaching the code for creating the model and training it, plus some of the output produced during training.
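To confirm that ~25% really is the chance baseline, here is a minimal class-balance check (a sketch, assuming y_valid is a 1-D tensor of class indices 0-3, as in the training code below):

# Sketch: verify the chance-level baseline, assuming y_valid holds class
# indices 0-3. Roughly equal counts mean random guessing yields ~25%.
import torch
counts = torch.bincount(y_valid, minlength=4)
print(counts.float() / counts.sum())  # ~[0.25, 0.25, 0.25, 0.25] if balanced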
Code
import torch
import torch.nn as nn

# Hyperparameters
num_layers = 1
input_size = 22      # number of EEG channels
hidden_size = 32
num_classes = 4
batch_size = 50
num_epochs = 20

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LSTMClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(LSTMClassifier, self).__init__()
        self.hidden_size = hidden_size
        # Note: dropout only acts between stacked LSTM layers, so with
        # num_layers=1 it is a no-op (PyTorch emits a warning about this).
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True,
                            dropout=0.5, num_layers=num_layers)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # Set initial hidden and cell states
        h0 = torch.zeros(num_layers, x.size(0), self.hidden_size).to(x.device)
        c0 = torch.zeros(num_layers, x.size(0), self.hidden_size).to(x.device)
        # Forward propagate LSTM
        out, _ = self.lstm(x, (h0, c0))
        # Decode the hidden state of the last time step
        out = self.fc(out[:, -1, :])
        return out

# Instantiate the model
model = LSTMClassifier(input_size, hidden_size, num_classes)
model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Assumed data layout (not shown in the post): X_train / X_valid are float
# tensors of shape (N, 1000, 22) on `device`; y_train / y_valid are long
# tensors of shape (N,) holding class indices 0-3.

# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss, correct, total = 0.0, 0, 0
    for i in range(0, len(X_train), batch_size):
        optimizer.zero_grad()
        # Get batch data
        batch_input = X_train[i:i+batch_size, :, :input_size]  # ensure input size matches
        batch_target = y_train[i:i+batch_size]
        # Forward pass
        outputs = model(batch_input)
        # Calculate loss
        loss = criterion(outputs, batch_target)
        # Backward pass and optimize
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        # Calculate accuracy
        _, predicted = torch.max(outputs, 1)
        total += batch_target.size(0)
        correct += (predicted == batch_target).sum().item()
    accuracy = 100 * correct / total

    # Validation
    model.eval()
    total_loss_valid, correct_valid, total_valid = 0.0, 0, 0
    with torch.no_grad():
        outputs = model(X_valid)
        loss_valid = criterion(outputs, y_valid)
        total_loss_valid += loss_valid.item()
        _, predicted_valid = torch.max(outputs, 1)
        total_valid += y_valid.size(0)
        correct_valid += (predicted_valid == y_valid).sum().item()
    accuracy_valid = 100 * correct_valid / total_valid

    print("Epoch:", epoch+1, "\t\tTraining Loss:", total_loss, "\tTraining Accuracy:", accuracy)
    print("\t\t\tValidation Loss:", total_loss_valid, "\tValidation Accuracy:", accuracy_valid)

print("Model training complete.")
Training output
Epoch: 1 Training Loss: 47.797505140304565 Training Accuracy: 24.231678486997637
Validation Loss: 1.3961728811264038 Validation Accuracy: 23.87706855791962
Epoch: 2 Training Loss: 46.88989615440369 Training Accuracy: 27.77777777777778
Validation Loss: 1.3933675289154053 Validation Accuracy: 24.34988179669031
Epoch: 3 Training Loss: 46.60832667350769 Training Accuracy: 30.61465721040189
Validation Loss: 1.3946889638900757 Validation Accuracy: 28.368794326241133
Epoch: 4 Training Loss: 46.370394468307495 Training Accuracy: 31.85579196217494
Validation Loss: 1.3951939344406128 Validation Accuracy: 27.89598108747045
Epoch: 5 Training Loss: 46.11502695083618 Training Accuracy: 34.1016548463357
Validation Loss: 1.3980598449707031 Validation Accuracy: 27.659574468085108
Epoch: 6 Training Loss: 45.86429834365845 Training Accuracy: 35.579196217494086
Validation Loss: 1.399003267288208 Validation Accuracy: 26.24113475177305
Epoch: 7 Training Loss: 45.610363483428955 Training Accuracy: 37.35224586288416
Validation Loss: 1.4000375270843506 Validation Accuracy: 24.34988179669031
Epoch: 8 Training Loss: 45.35917508602142 Training Accuracy: 37.70685579196218
Validation Loss: 1.4008349180221558 Validation Accuracy: 26.24113475177305
Epoch: 9 Training Loss: 45.06815278530121 Training Accuracy: 39.657210401891255
Validation Loss: 1.4022436141967773 Validation Accuracy: 25.29550827423168
Epoch: 10 Training Loss: 44.72212290763855 Training Accuracy: 41.016548463356976
Validation Loss: 1.4047679901123047 Validation Accuracy: 23.87706855791962
Epoch: 11 Training Loss: 44.39665925502777 Training Accuracy: 41.78486997635934
Validation Loss: 1.4072729349136353 Validation Accuracy: 25.768321513002363
Epoch: 12 Training Loss: 44.04106819629669 Training Accuracy: 42.434988179669034
Validation Loss: 1.409785509109497 Validation Accuracy: 24.34988179669031
Epoch: 13 Training Loss: 43.70037293434143 Training Accuracy: 43.73522458628842
Validation Loss: 1.4153059720993042 Validation Accuracy: 25.059101654846337
Epoch: 14 Training Loss: 43.31694221496582 Training Accuracy: 45.62647754137116
Validation Loss: 1.4185421466827393 Validation Accuracy: 23.87706855791962
Epoch: 15 Training Loss: 42.95249891281128 Training Accuracy: 46.74940898345154
Validation Loss: 1.425550937652588 Validation Accuracy: 24.58628841607565
Epoch: 16 Training Loss: 42.523557305336 Training Accuracy: 47.75413711583924
Validation Loss: 1.4291025400161743 Validation Accuracy: 24.113475177304963
Epoch: 17 Training Loss: 42.142807960510254 Training Accuracy: 48.758865248226954
Validation Loss: 1.4345715045928955 Validation Accuracy: 24.822695035460992
Epoch: 18 Training Loss: 41.72265028953552 Training Accuracy: 49.645390070921984
Validation Loss: 1.4420568943023682 Validation Accuracy: 24.34988179669031
Epoch: 19 Training Loss: 41.276421666145325 Training Accuracy: 50.354609929078016
Validation Loss: 1.449327826499939 Validation Accuracy: 25.53191489361702
Epoch: 20 Training Loss: 40.81519865989685 Training Accuracy: 51.59574468085106
Validation Loss: 1.4553029537200928 Validation Accuracy: 25.29550827423168
Model training complete.