|
- # -*- coding:utf-8 -*-
- import torch
- import torch.nn as nn
- import torch.optim as optim
- import torchvision
- from torchvision import transforms, datasets
- import matplotlib.pyplot as plt
-
# Preprocessing shared by the training and validation sets: resize every image
# to 224x224, convert it to a tensor, then normalize each channel with the
# commonly used ImageNet mean/std statistics.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
-
# Load the training ('trainN') and validation ('testN') image folders with the
# shared preprocessing, then wrap each in a DataLoader of batch size 64.
# Only the training loader shuffles its samples.
train_data = datasets.ImageFolder('trainN', transform=transform)
valid_data = datasets.ImageFolder('testN', transform=transform)

train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=64,
    shuffle=True,
)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=64)
-
# Load a pretrained ResNet-18 and replace its final fully connected layer with
# a new one that outputs `num_classes` logits.
# NOTE: `pretrained=True` is deprecated in newer torchvision releases in favour
# of the `weights=` argument; it is kept here for compatibility with older ones.
model = torchvision.models.resnet18(pretrained=True)

num_classes = 21  # number of output classes of the new classification head

# Fail fast when the dataset has more class folders than the head has outputs.
# Otherwise training would only crash later, inside CrossEntropyLoss, with an
# out-of-range target error.
if len(train_data.classes) > num_classes:
    raise ValueError(
        f"Dataset has {len(train_data.classes)} classes but the model head "
        f"only has {num_classes} outputs"
    )

num_features = model.fc.in_features
model.fc = nn.Linear(num_features, num_classes)
-
# Cross-entropy loss over the class logits, optimized with Adam
# (learning rate 1e-3) over all of the model's parameters.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
-
# Training setup: run on the GPU when one is available, move the model onto
# that device, and prepare the per-epoch metric histories.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
model.to(device)

num_epochs = 10
train_losses, valid_losses, valid_accuracies = [], [], []
-
def _train_one_epoch(net, loader, loss_fn, opt, dev):
    """Run one optimization pass over *loader* and return the mean batch loss.

    The result is the average of the per-batch losses reported by *loss_fn*,
    or ``nan`` when *loader* yields no batches.
    """
    net.train()
    running_loss = 0.0
    num_batches = 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(dev), labels.to(dev)
        opt.zero_grad()
        loss = loss_fn(net(inputs), labels)
        loss.backward()
        opt.step()
        running_loss += loss.item()
        num_batches += 1
    return running_loss / num_batches if num_batches else float("nan")


def _evaluate(net, loader, loss_fn, dev):
    """Return ``(mean_loss, accuracy_percent)`` of *net* over all of *loader*.

    Each batch loss is weighted by its batch size, so the loss is averaged over
    every sample rather than taken from a single batch. This assumes *loss_fn*
    uses mean reduction, which is nn.CrossEntropyLoss's default. Both values
    are ``nan`` when *loader* yields no samples.
    """
    net.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(dev), labels.to(dev)
            outputs = net(inputs)
            batch_size = labels.size(0)
            loss_sum += loss_fn(outputs, labels).item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size
    if total == 0:
        # Avoid ZeroDivisionError and never fall back to stale outputs.
        return float("nan"), float("nan")
    return loss_sum / total, 100 * correct / total


# Train for `num_epochs` epochs. After each epoch, evaluate on the validation
# set and record the train loss, validation loss and validation accuracy.
for epoch in range(num_epochs):
    print("epoch:", end=" ")
    train_losses.append(
        _train_one_epoch(model, train_loader, criterion, optimizer, device)
    )
    print(epoch)

    valid_loss, accuracy = _evaluate(model, valid_loader, criterion, device)
    valid_losses.append(valid_loss)
    valid_accuracies.append(accuracy)

    print(
        f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_losses[-1]}, Valid Loss: {valid_losses[-1]}, Valid Accuracy: {accuracy}")
-
# Plot the train/validation loss curves, then the validation-accuracy curve.
# The x-axis numbers epochs from 1, matching the "Epoch N/num_epochs" log lines.
# Each chart gets its own figure: with a non-interactive backend, plt.show()
# can leave the current figure open. Without a new figure, the accuracy curve
# would then be drawn onto the loss axes.
epoch_numbers = list(range(1, len(train_losses) + 1))

plt.figure()
plt.plot(epoch_numbers, train_losses, label='Train Loss')
plt.plot(epoch_numbers, valid_losses, label='Valid Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

plt.figure()
plt.plot(epoch_numbers, valid_accuracies, label='Valid Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.show()
-
# Save the model weights to 'best_model.pth'.
# NOTE: no best-epoch selection is performed, so despite the names these are the
# weights after the final training epoch.
# TODO: keep the state_dict of the epoch with the highest validation accuracy
# if "best" is the intended meaning.
# The tensors are copied to the CPU so the checkpoint also loads on machines
# without CUDA (no map_location needed). load_state_dict still copies them onto
# whatever device the target model lives on.
best_model_weights = {
    name: tensor.detach().cpu() for name, tensor in model.state_dict().items()
}

torch.save(best_model_weights, 'best_model.pth')
|