如何可视化用于测试使用混淆矩阵创建的模型的样本?例如如下所示。您可以访问 GitHub,整个过程都很相似,只是架构不同……
如何可视化用于测试使用混淆矩阵创建的模型的样本?例如如下所示。
你可以访问 GitHub,整个过程是相似的,只是架构和数据集不同
https://github.com/cendekialnazalia/CaisimPestDetection/blob/main/Percobaan%20E%20-%20CNN%20add%20Models%20Xception.ipynb
这是我的代码
训练模型
epochs = 10
mc = ModelCheckpoint('sequential', monitor='val_accuracy', mode='max', verbose=1, save_best_only=True)
early_stopping = EarlyStopping(monitor='val_loss', patience=2)
history=model.fit(x=train_gen, epochs=epochs, validation_data=valid_gen)
def print_info( test_gen, preds, print_code, save_dir, subject ):
class_dict=test_gen.class_indices
labels= test_gen.labels
file_names= test_gen.filenames
error_list=[]
true_class=[]
pred_class=[]
prob_list=[]
new_dict={}
error_indices=[]
y_pred=[]
for key,value in class_dict.items():
new_dict[value]=key # dictionary {integer of class number: string of class name}
# store new_dict as a text fine in the save_dir
classes=list(new_dict.values()) # list of string of class names
dict_as_text=str(new_dict)
dict_name= subject + '-' +str(len(classes)) +'.txt'
dict_path=os.path.join(save_dir,dict_name)
with open(dict_path, 'w') as x_file:
x_file.write(dict_as_text)
errors=0
for i, p in enumerate(preds):
pred_index=np.argmax(p)
true_index=labels[i] # labels are integer values
if pred_index != true_index: # a misclassification has occurred
error_list.append(file_names[i])
true_class.append(new_dict[true_index])
pred_class.append(new_dict[pred_index])
prob_list.append(p[pred_index])
error_indices.append(true_index)
errors=errors + 1
y_pred.append(pred_index)
if print_code !=0:
if errors>0:
if print_code>errors:
r=errors
else:
r=print_code
msg='{0:^28s}{1:^28s}{2:^28s}{3:^16s}'.format('Filename', 'Predicted Class' , 'True Class', 'Probability')
print_in_color(msg, (0,255,0),(55,65,80))
for i in range(r):
split1=os.path.split(error_list[i])
split2=os.path.split(split1[0])
fname=split2[1] + '/' + split1[1]
msg='{0:^28s}{1:^28s}{2:^28s}{3:4s}{4:^6.4f}'.format(fname, pred_class[i],true_class[i], ' ', prob_list[i])
print_in_color(msg, (255,255,255), (55,65,60))
#print(error_list[i] , pred_class[i], true_class[i], prob_list[i])
else:
msg='With accuracy of 100 % there are no errors to print'
print_in_color(msg, (0,255,0),(55,65,80))
if errors>0:
plot_bar=[]
plot_class=[]
for key, value in new_dict.items():
count=error_indices.count(key)
if count!=0:
plot_bar.append(count) # list containg how many times a class c had an error
plot_class.append(value) # stores the class
fig=plt.figure()
fig.set_figheight(len(plot_class)/3)
fig.set_figwidth(10)
plt.style.use('fivethirtyeight')
for i in range(0, len(plot_class)):
c=plot_class[i]
x=plot_bar[i]
plt.barh(c, x, )
plt.title( ' Errors by Class on Test Set')
y_true= np.array(labels)
y_pred=np.array(y_pred)
if len(classes)<= 30:
# create a confusion matrix
cm = confusion_matrix(y_true, y_pred )
length=len(classes)
if length<8:
fig_width=8
fig_height=8
else:
fig_width= int(length * .5)
fig_height= int(length * .5)
plt.figure(figsize=(fig_width, fig_height))
sns.heatmap(cm, annot=True, vmin=0, fmt='g', cmap='Blues', cbar=False)
plt.xticks(np.arange(length)+.5, classes, rotation= 90)
plt.yticks(np.arange(length)+.5, classes, rotation=0)
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.title("Confusion Matrix")
plt.show()
clr = classification_report(y_true, y_pred, target_names=classes)
print("Classification Report:\n----------------------\n", clr)
混淆矩阵
print_code=0
preds=model.predict(test_gen)
print_info( test_gen, preds, print_code, save_dir, subject )
输出
我不仅想显示表格输出和召回率、精确度和 f1 分数值,还想显示 CM 预测的每幅图像的可视化效果,例如像上面的图片,或者可能更好。
例如,在 \'Daun Sehat\' 表中,有一个数据样本被预测为 \'Karat Merah\',但如果没有该信息的可视化,我不知道 \'Daun Sehat\' 的哪个图像样本被检测为 \'Karat Merah\'
在最后一行代码“对测试集进行预测并生成混淆矩阵和分类报告”之后,我添加了如下代码来找出每个测试数据的预测
test_gen.class_indices
print(preds,preds.shape)
result_index = np.argmax(preds[])
print(result_index)
for i in range(len(preds)):
if(np.argmax(preds[i]) == 0):
print("Bercak Daun")
elif(np.argmax(preds[i]) == 1):
print("Daun Sehat")
elif(np.argmax(preds[i]) == 2):
print("Karat Merah")
else:
print("Lainya")
输出
Bercak Daun
Daun Sehat
.
.
.
up to 177
Daun Sehat
我不只是想显示带有字符串的预测,还想将其与图像一起显示。也许有更好、更高效的代码可以解决我的问题?