|
- import os.path
- import numpy as np
- import torch
- import torch.nn
- from torch.nn import CrossEntropyLoss
- from torch.utils.data import DataLoader
- from utils.dataset.TT_Dataset_new import MyDataset # 读取数据所用函数
- from Model.HrNet.hrnet_vit import HighResolutionNet
-
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
num_classes = 8
batch_size = 1
epoches = 200
# NOTE(review): this label says "UNet" but the network built below is
# HighResolutionNet; the variable is not used anywhere in this script.
modelname = "UNet"

# Dataset locations: training and validation image/label directories.
imagePath = r"G:\数据集\连云港\NewDataSet_5.18\cut\image"
labelPath = r"G:\数据集\连云港\NewDataSet_5.18\cut\label"

val_image = r"G:\数据集\连云港\NewDataSet_5.18\cut\val_image - 副本"
val_label = r"G:\数据集\连云港\NewDataSet_5.18\cut\val_label - 副本"

# NOTE(review): the training loader does not shuffle; consider shuffle=True
# unless a fixed sample order is intentional.
trainDataset = MyDataset(imagePath, labelPath)
trainDatasetloader = DataLoader(trainDataset, batch_size, shuffle=False)
trainLen = len(trainDatasetloader)

valDataset = MyDataset(val_image, val_label)
valDatasetloader = DataLoader(valDataset, 1, shuffle=False)
valLen = len(valDatasetloader)

# Build the model, loss and optimizer.
# The model must live on the same device as the input tensors, which the
# training loop moves to `device`; without `.to(device)` every forward pass on
# a CUDA machine fails with a CPU/GPU device-mismatch error. It is moved before
# the optimizer is constructed so the optimizer sees the final parameters.
model = HighResolutionNet(num_class=num_classes).to(device)
total = sum(param.numel() for param in model.parameters())
print("Number of parameter: %.2fM" % (total / 1e6))
loss = CrossEntropyLoss()
optimizer = torch.optim.AdamW(
    model.parameters(), lr=0.001, weight_decay=0.01, betas=(0.9, 0.95)
)
# Start training: each epoch runs one pass over the training set followed by
# one pass over the validation set.
for epoch in range(epoches):
    print(
        f"\n----------------------------------------------epoch: {epoch}----------------------------------------------")
    loss_total = 0.0
    acc_all = 0.0

    # Training pass. model.train() restores train-mode behaviour (dropout,
    # batch-norm statistic updates) after the eval() of the previous epoch.
    model.train()
    for i, data in enumerate(trainDatasetloader):
        image, label, image_down, ids = data
        image = image.to(device)
        label = label.to(device)
        image_down = image_down.to(device)
        # Clear gradients left over from the previous step.
        optimizer.zero_grad()
        output = model(image, image_down, ids)
        loss_step = loss(output, label.long())
        step_loss = loss_step.item()
        # `end=''` must go to print(), not str.format(); there it was silently
        # ignored and every step printed a new line.
        print("\r train: epoch: {}, step: {}/{}, loss: {} ".format(epoch, i, trainLen, step_loss), end='')
        loss_step.backward()
        optimizer.step()
        # Accumulate a Python float: adding the tensor itself keeps every
        # step's autograd graph alive until the end of the epoch.
        loss_total += step_loss
    print()  # terminate the carriage-return progress line

    # Validation pass. eval() + no_grad(): inference-mode layers and no
    # autograd graph construction (saves memory, deterministic output).
    model.eval()
    with torch.no_grad():
        for data in valDatasetloader:
            image, label, image_down, ids = data
            image = image.to(device)
            label = label.to(device)
            image_down = image_down.to(device)
            output = torch.argmax(model(image, image_down, ids), dim=1)
            # Fraction of elements where prediction equals label. Uses a
            # non-in-place comparison and casts label like the training loss.
            # NOTE(review): assumes label has the same shape as the argmax
            # prediction (no channel dim) - confirm against MyDataset.
            acc = torch.mean(output.eq(label.long()).to(torch.float32)).item()
            acc_all += acc
            print("acc", acc)

    # max(..., 1) avoids ZeroDivisionError if a loader is empty.
    acc_all = acc_all / max(valLen, 1)
    print("acc_all:" + str(acc_all))

    loss_epoch = loss_total / max(trainLen, 1)
    print("\r epoch: {}, epoch_loss: {}".format(epoch, round(float(loss_epoch), 8)), end='')
|