|
- import os
-
- import numpy as np
- import torch
- import torch.nn as nn
- from matplotlib import pyplot as plt
- from sklearn.preprocessing import StandardScaler, MinMaxScaler
- from sklearn.metrics import classification_report, confusion_matrix
- from torch.utils.data import Dataset, TensorDataset
- from tqdm import tqdm
- import seaborn as sns
-
- from dataset import TimeSeriesDataset, data2classes
- from models.resnet import resnet101, resnet50, resnet18
-
-
class Config():
    """Hyper-parameters and paths for validating the fault-classification model."""
    root_dir = r"D:\1\2024phm\Preliminary stage\Data_Pre Stage\Training data"
    used_data = ["data_motor.csv","data_gearbox.csv",
                 "data_leftaxlebox.csv","data_rightaxlebox.csv"]
    seq_length = 4096  # time-window length (number of samples per window; was 1024)
    img_size = 112  # side length of the image the window is rendered to
    batch_size = 32  # batch size (was 1024)
    feature_size = 6  # number of features (channels) per time step; note "21 = 6 + 9 + 3 + ..." in the original — presumably per-file channel counts, TODO confirm
    # hidden_size = 384  # hidden size (was 128)
    _, num_classes = data2classes(used_data)  # class count derived from the selected CSV files
    epochs = 100  # number of training epochs
    best_accuracy = 0.00  # best accuracy recorded so far
    learning_rate = 0.0001  # learning rate
    weight_decay = 0.00001  # weight decay (original comment said "learning rate" — copy/paste slip)
    model_name = 'resnet'  # model name
    save_path = './checkpoint/{}.pth'.format(model_name)  # path of the best saved checkpoint
-
-
if __name__ == '__main__':
    config = Config()

    # Build the validation split of the dataset.
    test_data = TimeSeriesDataset(config.root_dir, config.seq_length,
                                  img_size=config.img_size, used_data=config.used_data, mode="valid")
    # Wrap it in a DataLoader. shuffle=False: evaluation metrics are order-invariant,
    # and a deterministic order makes runs reproducible (the original shuffled here).
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=config.batch_size,
                                              shuffle=False)

    model = resnet18(in_channels=config.feature_size * 2, num_classes=config.num_classes).cuda()
    # Restore the best checkpoint saved during training (a raw state_dict).
    checkpoint = torch.load(config.save_path)
    model.load_state_dict(checkpoint)

    # Evaluate on the validation set.
    model.eval()
    correct = 0
    total = 0
    # Collect per-batch arrays and concatenate once after the loop —
    # np.concatenate inside the loop re-copies the whole buffer every batch.
    batch_preds = []
    batch_labels = []

    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs.cuda())
            _, predicted = torch.max(outputs, 1)
            # Labels appear to be one-hot encoded; argmax recovers the class index.
            _, labels = torch.max(labels, 1)
            total += labels.size(0)
            pred_np = predicted.cpu().numpy()
            label_np = labels.cpu().numpy()
            correct += (pred_np == label_np).sum()
            batch_labels.append(label_np)
            batch_preds.append(pred_np)

    val_labels = np.concatenate(batch_labels) if batch_labels else np.array([])
    val_preds = np.concatenate(batch_preds) if batch_preds else np.array([])
    accuracy = 100 * correct / total
    print("Validation Accuracy: {:.2f}%".format(accuracy))
    # Confusion matrix over the whole validation set.
    cm = confusion_matrix(val_labels, val_preds)
    # Visualize it as an annotated heatmap.
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.xlabel('Valid Predicted labels')
    plt.ylabel('Valid True labels')
    plt.title('Valid Confusion Matrix')
    plt.show()
|