import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import DataLoader
import sklearn.metrics as metrics
import argparse
import copy
import utils.log
from PointDA.data.dataloader import ScanNet, ModelNet, ShapeNet, label_to_idx
from PointDA.Models import PointNet, DGCNN
from utils import pc_utils
from DefRec_and_PCM import DefRec, PCM


# Convert a string such as "yes"/"no" or "1"/"0" to a bool, so that boolean
# options can be passed explicitly on the command line.
def str2bool(v):
    """
    Input:
        v - string
    output:
        True/False
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')
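
# A hypothetical invocation, to illustrate str2bool as an argparse type:
#   python train.py --apply_PCM false --DefRec_on_src yes
# Both "false" and "yes" are parsed into real booleans instead of raw strings.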


# ==================
# Read Data
# ==================
def split_set(dataset, domain, set_type="source"):
    """
    Input:
        dataset - a dataset object exposing train_ind, val_ind and label
        domain - modelnet/shapenet/scannet
        set_type - source/target
    output:
        train_sampler, valid_sampler
    """
    train_indices = dataset.train_ind
    val_indices = dataset.val_ind
    unique, counts = np.unique(dataset.label[train_indices], return_counts=True)
    io.cprint("Occurrences count of classes in " + set_type + " " + domain +
              " train part: " + str(dict(zip(unique, counts))))
    unique, counts = np.unique(dataset.label[val_indices], return_counts=True)
    io.cprint("Occurrences count of classes in " + set_type + " " + domain +
              " validation part: " + str(dict(zip(unique, counts))))
    # Creating PyTorch data samplers:
    train_sampler = SubsetRandomSampler(train_indices)
    valid_sampler = SubsetRandomSampler(val_indices)
    return train_sampler, valid_sampler
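
# Note: SubsetRandomSampler shuffles and draws only from the indices it is
# given, so as long as train_ind and val_ind are disjoint (they are produced
# by the dataset itself), the two samplers never leak samples across splits.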


# ==================
# Validation/test
# ==================

def final_test(test_loader, model_sd=None, model_td=None, set_type="Target", partition="Val", epoch=0):
    # Evaluate the two-branch ensemble (no gradients needed)
    count = 0.0
    print_losses = {'cls': 0.0}
    batch_idx = 0

    with torch.no_grad():
        model_sd.eval()
        model_td.eval()
        test_pred = []
        test_true = []
        for data, labels in test_loader:
            data, labels = data.to(device), labels.to(device).squeeze()
            data = data.permute(0, 2, 1)
            batch_size = data.size()[0]

            logits_sd = model_sd(data, activate_DefRec=False)
            logits_td = model_td(data, activate_DefRec=False)
            loss_sd = criterion(logits_sd["cls"], labels)
            loss_td = criterion(logits_td["cls"], labels)

            # sum of the two branches' class logits, used only for the arg-max
            logits_sum = logits_sd["cls"] + logits_td["cls"]

            loss = loss_sd + loss_td
            print_losses['cls'] += loss.item() * batch_size

            # evaluation metrics
            _, final_preds = torch.topk(logits_sum, 1)
            test_true.append(labels.cpu().numpy())
            test_pred.append(final_preds.detach().cpu().numpy())
            count += batch_size
            batch_idx += 1

    test_true = np.concatenate(test_true)
    test_pred = np.concatenate(test_pred)
    print_losses = {k: v * 1.0 / count for (k, v) in print_losses.items()}
    test_acc = io.print_progress(set_type, partition, epoch, print_losses, test_true, test_pred)
    conf_mat = metrics.confusion_matrix(test_true, test_pred, labels=list(label_to_idx.values())).astype(int)

    return test_acc, print_losses['cls'], conf_mat
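
# Summing the two branches' logits before the arg-max is a product-of-experts
# vote: softmax(l_sd + l_td) is proportional to softmax(l_sd) * softmax(l_td),
# so a class must score reasonably well under *both* branches to win.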


def test(test_loader, model=None, set_type="Target", partition="Val", epoch=0):
    # Evaluate a single branch (no gradients needed)
    count = 0.0
    print_losses = {'cls': 0.0}
    batch_idx = 0

    with torch.no_grad():
        model.eval()
        test_pred = []
        test_true = []
        for data, labels in test_loader:
            data, labels = data.to(device), labels.to(device).squeeze()
            data = data.permute(0, 2, 1)
            batch_size = data.size()[0]

            logits = model(data, activate_DefRec=False)
            loss = criterion(logits["cls"], labels)
            print_losses['cls'] += loss.item() * batch_size

            # evaluation metrics
            preds = logits["cls"].max(dim=1)[1]
            test_true.append(labels.cpu().numpy())
            test_pred.append(preds.detach().cpu().numpy())
            count += batch_size
            batch_idx += 1

    test_true = np.concatenate(test_true)
    test_pred = np.concatenate(test_pred)
    print_losses = {k: v * 1.0 / count for (k, v) in print_losses.items()}
    test_acc = io.print_progress(set_type, partition, epoch, print_losses, test_true, test_pred)
    conf_mat = metrics.confusion_matrix(test_true, test_pred, labels=list(label_to_idx.values())).astype(int)

    return test_acc, print_losses['cls'], conf_mat


# ==================
# Utils
# ==================
def get_target_preds(args, x):
    top_prob, top_label = torch.topk(F.softmax(x['cls'], dim=1), k=1)
    top_label = top_label.squeeze().t()
    top_prob = top_prob.squeeze().t()
    top_mean, top_std = top_prob.mean(), top_prob.std()
    threshold = top_mean - args.th * top_std
    return top_label, top_prob, threshold
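
# The threshold above is adaptive per batch: mean - th * std of the top-1
# softmax probabilities. A larger --th lowers the bar, so more target samples
# are treated as confident (consumed by Bidirectional Matching below), while
# samples under the bar feed Self-Penalization instead.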


def generate_trgt_pseudo_label(trgt_data, logits, threshold):
    batch_size = trgt_data.size(0)
    pseudo_label = torch.zeros(batch_size, 10).long()  # one-hot label
    sfm = nn.Softmax(dim=1)
    cls_conf = sfm(logits['cls'])
    mask = torch.max(cls_conf, 1)  # (max values, argmax indices), each of shape (batch_size,)
    for i in range(batch_size):
        index = mask[1][i]
        if mask[0][i] > threshold:
            pseudo_label[i][index] = 1

    return pseudo_label
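
# Note: the 10 classes are hard-coded to match the 10 shared categories of the
# PointDA-10 benchmark. This helper builds hard one-hot pseudo-labels only for
# predictions above the threshold; it is defined for completeness but not
# called in the training loop below.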


def get_sp_loss(input, target, temp):
    criterion = nn.NLLLoss(reduction='none').cuda()  # assumes GPU, like the rest of the script
    loss = criterion(torch.log(1 - F.softmax(input / temp, dim=1)), target.detach()).mean()
    return loss
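
# What get_sp_loss computes: NLLLoss(log(1 - softmax(x / temp)), c) equals
# -log(1 - p_c), so minimizing it *lowers* the probability of the given
# pseudo-label c. Applied to low-confidence samples, this self-penalization
# pushes the model away from pseudo-labels that are likely wrong. temp is the
# learnable temperature sp_param_sd / sp_param_td defined below.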


def fix_mix(args, src_data, trgt_data, src_label, trgt_pseudo, ratio):
    batch_size, _, num_points = src_data.size()
    device = torch.device("cuda:" + str(src_data.get_device()) if args.cuda else "cpu")
    # Interpolation alternative, kept for reference:
    # mixed_x = ratio * src_data + (1 - ratio) * trgt_data
    # loss = mixup_criterion_hard(pred, src_label.detach(), trgt_pseudo.detach(), ratio)
    # draw lambda from beta distribution:
    # lam = np.random.beta(args.mixup_params, args.mixup_params) if args.mixup_params > 0 else 1.0

    num_pts_a = round(ratio * num_points)
    num_pts_b = num_points - round(ratio * num_points)

    pts_indices_a, pts_vals_a = pc_utils.farthest_point_sample(args, src_data, num_pts_a)
    pts_indices_b, pts_vals_b = pc_utils.farthest_point_sample(args, trgt_data, num_pts_b)
    mixed_X = torch.cat((pts_vals_a, pts_vals_b), 2)  # concatenate the sampled subsets along the point axis
    # draw a random permutation of points in the shape
    # (TODO from the original author: check how much this permutation affects results)
    points_perm = torch.randperm(num_points).to(device)
    mixed_X = mixed_X[:, :, points_perm]

    Y_a = src_label.clone()
    Y_b = trgt_pseudo.clone()

    return mixed_X, (Y_a, Y_b, ratio)
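
# A worked example of fix_mix, assuming clouds of num_points = 1024: with
# ratio = 0.7 (the default -lam_sd), farthest point sampling keeps
# round(0.7 * 1024) = 717 source points and 1024 - 717 = 307 target points,
# which are concatenated and randomly permuted so point order carries no
# domain cue. The (Y_a, Y_b, ratio) tuple is consumed by PCM.calc_loss, which,
# as in standard mixup, weights the source label by ratio and the target
# pseudo-label by 1 - ratio.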


if __name__ == '__main__':

    NWORKERS = 8
    MAX_LOSS = 9 * (10 ** 9)
    # ==================
    # Argparse
    # ==================

    # Set up the parser and describe what this program does
    parser = argparse.ArgumentParser(description='DA on Point Clouds')

    # Register the command-line arguments
    parser.add_argument('--exp_name', type=str, default='DefRec_PCM', help='Name of the experiment')
    parser.add_argument('--out_path', type=str, default='/tmp/output', help='log folder path')
    parser.add_argument('--dataroot', type=str, default='/tmp/dataset', metavar='N', help='data path')
    parser.add_argument('--src_dataset', type=str, default='shapenet', choices=['modelnet', 'shapenet', 'scannet'])
    parser.add_argument('--trgt_dataset', type=str, default='scannet', choices=['modelnet', 'shapenet', 'scannet'])
    parser.add_argument('--epochs', type=int, default=200, help='number of episode to train')
    parser.add_argument('--model', type=str, default='dgcnn', choices=['pointnet', 'dgcnn'], help='Model to use')
    parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
    parser.add_argument('--gpus', type=lambda s: [int(item.strip()) for item in s.split(',')], default='0',
                        help='comma delimited of gpu ids to use. Use "-1" for cpu usage')
    parser.add_argument('--DefRec_dist', type=str, default='volume_based_voxels', metavar='N',
                        choices=['volume_based_voxels', 'volume_based_radius'],
                        help='distortion of points')
    parser.add_argument('--num_regions', type=int, default=3, help='number of regions to split shape by')
    parser.add_argument('--DefRec_on_src', type=str2bool, default=False, help='Using DefRec in source')
    parser.add_argument('--apply_PCM', type=str2bool, default=True, help='Using mixup in source')
    parser.add_argument('--batch_size', type=int, default=16, metavar='batch_size',
                        help='Size of train batch per domain')
    parser.add_argument('--test_batch_size', type=int, default=16, metavar='batch_size',
                        help='Size of test batch per domain')
    parser.add_argument('--optimizer', type=str, default='ADAM', choices=['ADAM', 'SGD'])
    parser.add_argument('--DefRec_weight', type=float, default=0.5, help='weight of the DefRec loss')
    parser.add_argument('--mixup_params', type=float, default=1.0, help='a,b in beta distribution')
    parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
    parser.add_argument('--momentum', type=float, default=0.9, help='SGD momentum')
    parser.add_argument('--wd', type=float, default=5e-5, help='weight decay')
    parser.add_argument('--dropout', type=float, default=0.5, help='dropout rate')
    # Arguments added on top of the original DefRec/PCM script
    parser.add_argument('--DefRec_on_trgt', type=str2bool, default=True, help='Using DefRec in target')
    parser.add_argument('--fix_ratio_mixup', type=str2bool, default=True, help='Using fixed ratio mixup')
    parser.add_argument('--apply_bim', type=str2bool, default=True, help='Using Bidirectional Matching')
    parser.add_argument('--apply_sp', type=str2bool, default=True, help='Using Self-Penalization')
    parser.add_argument('--apply_cr', type=str2bool, default=True, help='Using Consistency Regularization')
    parser.add_argument('-th', default=2.0, type=float, help='Threshold')
    parser.add_argument('-bim_start', default=150, type=int, help='Bidirectional Matching')
    parser.add_argument('-sp_start', default=150, type=int, help='Self-Penalization')
    parser.add_argument('-cr_start', default=150, type=int, help='Consistency Regularization')
    parser.add_argument('-lam_sd', default=0.7, type=float, help='Source Dominant Mixup ratio')
    parser.add_argument('-lam_td', default=0.3, type=float, help='Target Dominant Mixup ratio')
    # Parse the registered arguments
    args = parser.parse_args()

    # ==================
    # init
    # ==================
    io = utils.log.IOStream(args)
    io.cprint(str(args))

    random.seed(1)
    np.random.seed(1)  # to get the same point choice in ModelNet and ScanNet leave it fixed
    torch.manual_seed(args.seed)
    args.cuda = (args.gpus[0] >= 0) and torch.cuda.is_available()
    device = torch.device("cuda:" + str(args.gpus[0]) if args.cuda else "cpu")
    if args.cuda:
        io.cprint('Using GPUs ' + str(args.gpus) + ',' + ' from ' +
                  str(torch.cuda.device_count()) + ' devices available')
        torch.cuda.manual_seed_all(args.seed)
        # make cuDNN deterministic so runs are reproducible (at some speed cost)
        torch.backends.cudnn.enabled = False
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    else:
        io.cprint('Using CPU')

    src_dataset = args.src_dataset
    trgt_dataset = args.trgt_dataset
    data_func = {'modelnet': ModelNet, 'scannet': ScanNet, 'shapenet': ShapeNet}

    src_trainset = data_func[src_dataset](io, args.dataroot, 'train')
    trgt_trainset = data_func[trgt_dataset](io, args.dataroot, 'train')
    trgt_testset = data_func[trgt_dataset](io, args.dataroot, 'test')

    # Creating data indices for training and validation splits:
    src_train_sampler, src_valid_sampler = split_set(src_trainset, src_dataset, "source")
    trgt_train_sampler, trgt_valid_sampler = split_set(trgt_trainset, trgt_dataset, "target")

    # dataloaders for source and target; drop_last discards the final incomplete batch
    src_train_loader = DataLoader(src_trainset, num_workers=NWORKERS, batch_size=args.batch_size,
                                  sampler=src_train_sampler, drop_last=True)
    src_val_loader = DataLoader(src_trainset, num_workers=NWORKERS, batch_size=args.test_batch_size,
                                sampler=src_valid_sampler)
    trgt_train_loader = DataLoader(trgt_trainset, num_workers=NWORKERS, batch_size=args.batch_size,
                                   sampler=trgt_train_sampler, drop_last=True)
    trgt_val_loader = DataLoader(trgt_trainset, num_workers=NWORKERS, batch_size=args.test_batch_size,
                                 sampler=trgt_valid_sampler)
    trgt_test_loader = DataLoader(trgt_testset, num_workers=NWORKERS, batch_size=args.test_batch_size)

    # ==================
    # Init Model
    # ==================
    # Two branches are trained: a source-dominant model (model_sd) and a
    # target-dominant model (model_td), initialized identically.
    if args.model == 'pointnet':
        model_sd = PointNet(args)
        model_td = copy.deepcopy(model_sd)
    elif args.model == 'dgcnn':
        model_sd = DGCNN(args)
        model_td = copy.deepcopy(model_sd)
    else:
        raise Exception("Not implemented")

    model_sd = model_sd.to(device)
    model_td = model_td.to(device)

    # Handle multi-gpu
    if (device.type == 'cuda') and len(args.gpus) > 1:
        model_sd = nn.DataParallel(model_sd, args.gpus)
        model_td = nn.DataParallel(model_td, args.gpus)

    # ==================
    # Optimizer
    # ==================
    opt_sd = optim.SGD(model_sd.parameters(), lr=args.lr, momentum=args.momentum,
                       weight_decay=args.wd) if args.optimizer == "SGD" \
        else optim.Adam(model_sd.parameters(), lr=args.lr, weight_decay=args.wd)
    opt_td = optim.SGD(model_td.parameters(), lr=args.lr, momentum=args.momentum,
                       weight_decay=args.wd) if args.optimizer == "SGD" \
        else optim.Adam(model_td.parameters(), lr=args.lr, weight_decay=args.wd)

    # learnable temperatures for the self-penalization loss, one per branch
    sp_param_sd = nn.Parameter(torch.tensor(5.0).to(device), requires_grad=True)
    sp_param_td = nn.Parameter(torch.tensor(5.0).to(device), requires_grad=True)
    opt_sd.add_param_group({"params": [sp_param_sd], "lr": args.lr})
    opt_td.add_param_group({"params": [sp_param_td], "lr": args.lr})
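    # Registering sp_param_* in extra parameter groups lets each optimizer
    # update its branch's temperature jointly with the network weights; the
    # gradient reaches the temperature through get_sp_loss, where the logits
    # are divided by it.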
    scheduler_sd = CosineAnnealingLR(opt_sd, args.epochs)
    scheduler_td = CosineAnnealingLR(opt_td, args.epochs)
    criterion = nn.CrossEntropyLoss()  # returns the mean CE over the batch
    ce = nn.CrossEntropyLoss()
    mse = nn.MSELoss()

    # lookup table of region means: num_regions^3 cells, e.g. 3*3*3 = 27
    # voxel centers, giving a tensor of shape (27, 3)
    lookup = torch.Tensor(pc_utils.region_mean(args.num_regions)).to(device)
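    # DefRec context: DefRec.deform_input (used in the training loop) displaces
    # the points that fall into one of these num_regions^3 voxels, and the
    # model must reconstruct the original shape, a self-supervised task that
    # can be trained on the unlabeled target domain.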

    # ==================
    # Train
    # ==================
    src_best_val_acc = trgt_best_val_acc = best_val_epoch = 0
    src_best_val_loss = trgt_best_val_loss = MAX_LOSS
    best_model_sd = io.save_model(model_sd, 'sd')
    best_model_td = io.save_model(model_td, 'td')

    for epoch in range(args.epochs):
        model_sd.train()
        model_td.train()

        # init data structures for saving epoch stats
        cls_type = 'mixup' if args.apply_PCM else 'cls'
        src_print_losses = {"total": 0.0, cls_type: 0.0}
        if args.fix_ratio_mixup:
            src_print_losses['FixMix'] = 0.0
        if args.DefRec_on_src:
            src_print_losses['DefRec'] = 0.0
        if args.apply_bim:
            src_print_losses['bim'] = 0.0
        if args.apply_sp:
            src_print_losses['sp'] = 0.0
        if args.apply_cr:
            src_print_losses['cr'] = 0.0

        trgt_print_losses = {'DefRec': 0.0}
        src_count = trgt_count = 0.0

        batch_idx = 1
        for data1, data2 in zip(src_train_loader, trgt_train_loader):
            opt_sd.zero_grad()
            opt_td.zero_grad()
            total_loss = 0.0

            #### source data ####
            if data1 is not None:
                # src_data: (batch_size, num_points, 3); src_label: (batch_size,)
                src_data, src_label = data1[0].to(device), data1[1].to(device).squeeze()
                # change to [batch_size, num_coordinates, num_points]
                src_data = src_data.permute(0, 2, 1)
                batch_size = src_data.size()[0]
                src_data_orig = src_data.clone()  # keep an undistorted copy
                device = torch.device("cuda:" + str(src_data.get_device()) if args.cuda else "cpu")

                # The original single-model source losses are disabled here;
                # source supervision instead enters through the FixMix branch below.
                # if args.DefRec_on_src:
                #     src_data, src_mask = DefRec.deform_input(src_data, lookup, args.DefRec_dist, device)
                #     src_logits = model(src_data, activate_DefRec=True)
                #     loss = DefRec.calc_loss(args, src_logits, src_data_orig, src_mask)
                #     src_print_losses['DefRec'] += loss.item() * batch_size
                #     src_print_losses['total'] += loss.item() * batch_size
                #     loss.backward()
                # if args.apply_PCM:
                #     src_data = src_data_orig.clone()
                #     src_data, mixup_vals = PCM.mix_shapes(args, src_data, src_label)
                #     src_cls_logits = model(src_data, activate_DefRec=False)
                #     loss = PCM.calc_loss(args, src_cls_logits, mixup_vals, criterion)
                #     src_print_losses['mixup'] += loss.item() * batch_size
                #     src_print_losses['total'] += loss.item() * batch_size
                #     loss.backward()
                # else:
                #     src_data = src_data_orig.clone()
                #     # predict with undistorted shape
                #     src_cls_logits = model(src_data, activate_DefRec=False)
                #     loss = (1 - args.DefRec_weight) * criterion(src_cls_logits["cls"], src_label)
                #     src_print_losses['cls'] += loss.item() * batch_size
                #     src_print_losses['total'] += loss.item() * batch_size
                #     loss.backward()

                src_count += batch_size

            #### target data ####
            if data2 is not None:
                trgt_data, trgt_label = data2[0].to(device), data2[1].to(device).squeeze()
                trgt_data = trgt_data.permute(0, 2, 1)
                batch_size = trgt_data.size()[0]
                trgt_data_orig = trgt_data.clone()
                device = torch.device("cuda:" + str(trgt_data.get_device()) if args.cuda else "cpu")

                if args.fix_ratio_mixup:
                    # class logits of both branches on the raw target batch; each 'cls' entry is (batch_size, 10)
                    x_sd, x_td = model_sd(trgt_data), model_td(trgt_data)
                    # pseudo-labels, top-1 probabilities and adaptive threshold, shapes (batch_size,)
                    pseudo_sd, top_prob_sd, threshold_sd = get_target_preds(args, x_sd)
                    data_sd, fix_values_sd = fix_mix(args, src_data, trgt_data, src_label, pseudo_sd, args.lam_sd)
                    logits_sd = model_sd(data_sd)
                    loss_sd = PCM.calc_loss(args, logits_sd, fix_values_sd, criterion)

                    pseudo_td, top_prob_td, threshold_td = get_target_preds(args, x_td)
                    data_td, fix_values_td = fix_mix(args, src_data, trgt_data, src_label, pseudo_td, args.lam_td)
                    logits_td = model_td(data_td)
                    loss_td = PCM.calc_loss(args, logits_td, fix_values_td, criterion)
                    total_loss += loss_sd
                    total_loss += loss_td
                    loss = loss_sd + loss_td
                    src_print_losses['FixMix'] += loss.item() * batch_size

                if args.DefRec_on_trgt:
                    trgt_data, trgt_mask = DefRec.deform_input(trgt_data, lookup, args.DefRec_dist, device)
                    trgt_logits = model_sd(trgt_data, activate_DefRec=True)
                    loss = DefRec.calc_loss(args, trgt_logits, trgt_data_orig, trgt_mask)
                    trgt_print_losses['DefRec'] += loss.item() * batch_size
                    total_loss += loss
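
                # Cross-teaching idea behind the block below: once pseudo-labels
                # have had time to stabilize (epoch > bim_start), each branch is
                # trained on the samples the *other* branch is confident about
                # (top-1 probability >= that branch's adaptive threshold), using
                # the other branch's pseudo-labels as targets.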
                # Bidirectional Matching
                if args.apply_bim:
                    if epoch > args.bim_start:
                        bim_mask_sd = torch.ge(top_prob_sd, threshold_sd)
                        bim_mask_sd = torch.nonzero(bim_mask_sd).squeeze()

                        bim_mask_td = torch.ge(top_prob_td, threshold_td)
                        bim_mask_td = torch.nonzero(bim_mask_td).squeeze()

                        if bim_mask_sd.dim() > 0 and bim_mask_td.dim() > 0:
                            if bim_mask_sd.numel() > 0 and bim_mask_td.numel() > 0:
                                # use the same number of confident samples from both branches
                                bim_mask = min(bim_mask_sd.size(0), bim_mask_td.size(0))
                                bim_sd_loss = ce(x_sd['cls'][bim_mask_td[:bim_mask]],
                                                 pseudo_td[bim_mask_td[:bim_mask]].cuda().detach())
                                bim_td_loss = ce(x_td['cls'][bim_mask_sd[:bim_mask]],
                                                 pseudo_sd[bim_mask_sd[:bim_mask]].cuda().detach())
                                total_loss += bim_sd_loss
                                total_loss += bim_td_loss
                                loss = bim_sd_loss + bim_td_loss
                                src_print_losses['bim'] += loss.item() * batch_size
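
                # Idea behind the block below: in the early epochs
                # (epoch <= sp_start), samples *below* the confidence threshold
                # (torch.lt, the complement of the matching mask above) are
                # pushed away from their unreliable pseudo-labels via
                # get_sp_loss instead of being fitted to them.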
                # Self Penalization
                if args.apply_sp:
                    if epoch <= args.sp_start:
                        sp_mask_sd = torch.lt(top_prob_sd, threshold_sd)
                        sp_mask_sd = torch.nonzero(sp_mask_sd).squeeze()

                        sp_mask_td = torch.lt(top_prob_td, threshold_td)
                        sp_mask_td = torch.nonzero(sp_mask_td).squeeze()

                        if sp_mask_sd.dim() > 0 and sp_mask_td.dim() > 0:
                            if sp_mask_sd.numel() > 0 and sp_mask_td.numel() > 0:
                                sp_mask = min(sp_mask_sd.size(0), sp_mask_td.size(0))
                                sp_sd_loss = get_sp_loss(x_sd['cls'][sp_mask_sd[:sp_mask]],
                                                         pseudo_sd[sp_mask_sd[:sp_mask]], sp_param_sd)
                                sp_td_loss = get_sp_loss(x_td['cls'][sp_mask_td[:sp_mask]],
                                                         pseudo_td[sp_mask_td[:sp_mask]], sp_param_td)
                                total_loss += sp_sd_loss
                                total_loss += sp_td_loss
                                loss = sp_sd_loss + sp_td_loss
                                src_print_losses['sp'] += loss.item() * batch_size
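
                # Idea behind the block below: late in training
                # (epoch > cr_start), both branches see the same half-and-half
                # mixed cloud (ratio 0.5) and an MSE penalty pulls their class
                # logits together, preventing the two branches from drifting
                # apart.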
                # Consistency Regularization
                if args.apply_cr:
                    if epoch > args.cr_start:
                        mixed_cr, __ = fix_mix(args, src_data, trgt_data, src_label, pseudo_sd, 0.5)
                        # interpolation alternative: mixed_cr = 0.5 * src_data + 0.5 * trgt_data
                        out_sd, out_td = model_sd(mixed_cr), model_td(mixed_cr)
                        cr_loss = mse(out_sd['cls'], out_td['cls'])
                        total_loss += cr_loss
                        src_print_losses['cr'] += cr_loss.item() * batch_size

                trgt_count += batch_size

            # one backward pass over the accumulated multi-task loss, then update both branches
            total_loss.backward()
            opt_sd.step()
            opt_td.step()
            batch_idx += 1

        scheduler_sd.step()
        scheduler_td.step()

        # print progress
        src_print_losses = {k: v * 1.0 / src_count for (k, v) in src_print_losses.items()}
        src_acc = io.print_progress("Source", "Trn", epoch, src_print_losses)
        trgt_print_losses = {k: v * 1.0 / trgt_count for (k, v) in trgt_print_losses.items()}
        trgt_acc = io.print_progress("Target", "Trn", epoch, trgt_print_losses)

        # ===================
        # Validation
        # ===================
        io.cprint("src: sd model " + str(epoch))
        src_val_acc_sd, src_val_loss_sd, src_conf_mat_sd = test(src_val_loader, model_sd, "Source", "Val", epoch)
        io.cprint("src: td model " + str(epoch))
        src_val_acc_td, src_val_loss_td, src_conf_mat_td = test(src_val_loader, model_td, "Source", "Val", epoch)
        io.cprint("src: final " + str(epoch))
        src_val_acc, src_val_loss, src_conf_mat = final_test(src_val_loader, model_sd, model_td, "Source", "Val", epoch)
        io.cprint("trgt: sd model " + str(epoch))
        trgt_val_acc_sd, trgt_val_loss_sd, trgt_conf_mat_sd = test(trgt_val_loader, model_sd, "Target", "Val", epoch)
        io.cprint("trgt: td model " + str(epoch))
        trgt_val_acc_td, trgt_val_loss_td, trgt_conf_mat_td = test(trgt_val_loader, model_td, "Target", "Val", epoch)
        io.cprint("trgt: final " + str(epoch))
        trgt_val_acc, trgt_val_loss, trgt_conf_mat = final_test(trgt_val_loader, model_sd, model_td, "Target", "Val",
                                                                epoch)

        # save model according to the best source validation accuracy (since we don't have target labels)
        if src_val_acc > src_best_val_acc:
            src_best_val_acc = src_val_acc
            src_best_val_loss = src_val_loss
            trgt_best_val_acc = trgt_val_acc
            trgt_best_val_loss = trgt_val_loss
            best_val_epoch = epoch
            best_epoch_conf_mat = trgt_conf_mat
            best_model_sd = io.save_model(model_sd, "sd")
            best_model_td = io.save_model(model_td, "td")

        # after epoch 100, also evaluate the ensemble on the test set and snapshot both branches
        if epoch > 100:
            trgt_test_acc, trgt_test_loss, trgt_conf_mat = final_test(trgt_test_loader, model_sd, model_td,
                                                                      "Target", "Test", 0)
            io.cprint("target test accuracy: %.4f, target test loss: %.4f" % (trgt_test_acc, trgt_test_loss))
            io.save_model(model_sd, str(epoch) + "sd")
            io.save_model(model_td, str(epoch) + "td")

    io.cprint("Best model was found at epoch %d, source validation accuracy: %.4f, source validation loss: %.4f, "
              "target validation accuracy: %.4f, target validation loss: %.4f"
              % (best_val_epoch, src_best_val_acc, src_best_val_loss, trgt_best_val_acc, trgt_best_val_loss))
    io.cprint("Best validation model confusion matrix:")
    io.cprint('\n' + str(best_epoch_conf_mat))

    # ===================
    # Test
    # ===================
    model_sd = best_model_sd
    model_td = best_model_td
    trgt_test_acc, trgt_test_loss, trgt_conf_mat = final_test(trgt_test_loader, model_sd, model_td, "Target", "Test", 0)
    io.cprint("target test accuracy: %.4f, target test loss: %.4f" % (trgt_test_acc, trgt_test_loss))
    io.cprint("Test confusion matrix:")
    io.cprint('\n' + str(trgt_conf_mat))