|
- import os
- import time
- import random
- import logging
- import numpy as np
- from mindspore import save_checkpoint
- from models.MultiTaskNet import MultiTaskUNet
- from models.utils import mmseg_acc, accuracy_pixel_level
- from mindspore.train.serialization import load_param_into_net, load_checkpoint
- from data_folder import DataFolder
- from my_transforms import get_transforms
- from tensorboardX import SummaryWriter
- # from sklearn.metrics import accuracy_score
- import mindspore.nn as nn
- import shutil
- import argparse
- from data_folder import create_dataset
- from models.loss import MultiLoss
- from models.crossentropy import CrossEntropy
- from mindspore import context
-
# Run MindSpore in PyNATIVE (eager) execution mode on Ascend hardware.
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")

# TensorBoard event writer. Not referenced anywhere else in this file;
# NOTE(review): presumably kept for its side effect of creating a log
# directory (tensorboardX defaults to ./runs) — confirm before removing.
writer = SummaryWriter()
-
-
def mysave_checkpoint(net, epoch, is_best, save_dir):
    """Serialize *net*'s parameters for *epoch* and track the best model.

    Args:
        net: network whose parameters are written with MindSpore
            ``save_checkpoint``.
        epoch: epoch number, embedded in the checkpoint filename.
        is_best: when True, the epoch checkpoint is also copied to
            ``checkpoints/final_best.ckpt``.
        save_dir: unused here; kept to preserve the call signature.
            NOTE(review): checkpoints always go under ``checkpoints/`` —
            confirm whether save_dir should be honored instead.
    """
    cp_dir = 'checkpoints'
    # makedirs(exist_ok=True) avoids the exists()-then-mkdir race and
    # does not crash if the directory was created elsewhere.
    os.makedirs(cp_dir, exist_ok=True)
    filename = os.path.join(cp_dir, 'epoch_{}.ckpt'.format(epoch))
    save_checkpoint(net, filename)
    if is_best:
        # Bug fix: the best copy previously used a ``.pth`` suffix while
        # the loader expects MindSpore ``.ckpt`` files (the restore site
        # references "final_best.ckpt").
        shutil.copyfile(filename, os.path.join(cp_dir, 'final_best.ckpt'))
-
-
def main(args, logger):
    """Evaluate the trained multi-task UNet on the validation split.

    Loads ``checkpoints/final_epoch.ckpt``, runs the validation loader
    once, and logs the mean of the two pixel-level metrics (mmseg
    all-class accuracy and ``accuracy_pixel_level``'s second metric,
    reported as mAcc / mIoU).

    Args:
        args: parsed CLI namespace; uses ``batch_size``, ``num_classes``,
            ``train_img_dir`` and ``train_label_dir``.
        logger: ``logging.Logger`` receiving the final report line.
    """
    data_transforms = {
        'train': get_transforms({
            'scale': 240,
            'horizontal_flip': True,
            'random_rotation': 90,
            'random_crop': 240,
            'to_tensor': 1
        }),
        'validation': get_transforms({
            'scale': 240,
            'to_tensor': 1
        })
    }

    # Map split name -> [image dir, label dir] for create_dataset.
    dsets = {}
    for split in ['train', 'validation']:
        img_dir = os.path.join(args.train_img_dir, split)
        target_dir = os.path.join(args.train_label_dir, split)
        dsets[split] = [img_dir, target_dir]

    val_loader = create_dataset(dir_list=dsets["validation"], post_fix=['.png'],
                                num_channels=[3, 1],
                                data_transforms=data_transforms["validation"],
                                column_names=["input", "target", "category"],
                                batch_size=args.batch_size, shuffle=False)

    net = MultiTaskUNet(n_channels=3, n_classes=args.num_classes)
    param_dict = load_checkpoint("checkpoints/final_epoch.ckpt")  # or final_best.ckpt
    load_param_into_net(net, param_dict)

    eval_results = np.zeros((2,), np.float32)
    num_batches = 0
    for index, sample in enumerate(val_loader):
        inputs, target, category = sample  # renamed from `input` (shadowed builtin)
        target = target.squeeze()
        segoutput, clsoutput = net(inputs)
        pred = np.argmax(segoutput.asnumpy(), axis=1)
        # ignore_index=0: class 0 is excluded from the mmseg accuracy.
        all_acc, acc, iou = mmseg_acc(pred, target.asnumpy(),
                                      num_classes=args.num_classes, ignore_index=0)
        metrics = accuracy_pixel_level(pred, target.asnumpy())
        eval_results += np.array([all_acc, metrics[1]])
        num_batches = index + 1
        # Classification-head prediction; currently unused (the
        # accuracy_score bookkeeping was commented out upstream).
        pred_cls = np.argmax(clsoutput.asnumpy(), axis=1)

    # Bug fix: the average previously divided by max(index), i.e.
    # num_batches - 1 (off by one), and raised ZeroDivisionError when the
    # loader yielded a single batch.
    if num_batches == 0:
        logger.info("Eval skipped: validation loader produced no batches.")
        return
    eval_results = [value / num_batches for value in eval_results.tolist()]
    logger.info("Eval Results : mAcc = {:.2f}, mIoU = {:.2f}".format(eval_results[0], eval_results[1]))
-
-
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Process some integers.')
    # dataloader
    parser.add_argument('--batch_size', default=4, type=int)
    parser.add_argument('--num_classes', default=8, type=int)

    # loss
    parser.add_argument('--seg_loss', default="CE", type=str, choices=["CE", "MSE"])

    # optimizer
    parser.add_argument('--optimizer', default="SGD", type=str, choices=["SGD", "Adam"])
    parser.add_argument('--lr', default=1e-4, type=float)
    parser.add_argument('--weight_decay', default=1e-4, type=float)

    # training
    parser.add_argument('--num_epoches', default=1, type=int)
    parser.add_argument('--train_img_dir', default="endoscope400/ade20k/images", type=str)
    parser.add_argument('--train_label_dir', default="endoscope400/ade20k/annotations", type=str)

    # evaluation
    parser.add_argument('--eval_per_epoch', default=5, type=int)

    parser.add_argument('--model_save_dir', default="runs", type=str)

    args = parser.parse_args()

    logging.basicConfig(filemode='a',
                        format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s',
                        datefmt='%H:%M:%S',
                        level=logging.INFO)
    sh = logging.StreamHandler()  # echo log records to the console
    # Robustness fix: create the log directory before FileHandler tries to
    # open a file inside it; previously this only worked if runs/ already
    # existed as a side effect of other tooling.
    os.makedirs('runs', exist_ok=True)
    # NOTE(review): the timestamp contains ':' and a space, which is legal
    # on POSIX but an invalid filename on Windows — confirm target platform.
    fh = logging.FileHandler('runs/logs_{}'.format(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())))

    logging.info("Training and Validation Record.")
    logger = logging.getLogger("MultiTask for Endoscope")
    logger.addHandler(sh)
    logger.addHandler(fh)
    main(args, logger)
|