|
- import argparse
- import os
- import random
- import torch
- import torch.optim as optim
- from torch.utils.data import DataLoader
-
- import sys
- sys.path.append("..")
- #from data.dataset import ShapeNetDataset
- from data.dataset_shapenet import ShapeNetDataset
- from model import PointNetSeg, feature_transform_regularizer
- from tqdm import tqdm
- import numpy as np
-
- os.environ["CUDA_VISIBLE_DEVICES"] = '0'
-
- # def cal_accuracy_iou(preds, labels, seg_classes, pt=True):
- # '''
- # :param pred: shape=(B, N)
- # :param labels: shape=(B, N)
- # :param seg_classes: dict: cat->labels
- # :return:
- # '''
- # nclasses, n = len(seg_classes), len(preds)
- # shape_ious = {cat: [] for cat in seg_classes}
- # shape_count = {cat: 0.0 for cat in seg_classes}
- # shape_points_seen = {cat: 0.0 for cat in seg_classes}
- # shape_points_correct = {cat: 0.0 for cat in seg_classes}
- # seg2cat = {}
- # for k, vs in seg_classes.items():
- # for v in vs:
- # seg2cat[v] = k
- # for i in range(n):
- # pred, label = preds[i], labels[i]
- # npoints = len(pred)
- # cat = seg2cat[label[0]]
- # shape_count[cat] += 1
- # shape_points_seen[cat] += npoints
- # shape_points_correct[cat] += np.sum(pred == label)
- # part_ious = []
- # for l in seg_classes[cat]:
- # intersection = np.sum(np.all([pred == l, label == l], axis=0))
- # union = np.sum(np.any([pred == l, label == l], axis=0))
- # if union < 1:
- # part_ious.append(1.0)
- # continue
- # part_ious.append(intersection / union)
- # shape_ious[cat].append(np.mean(part_ious))
-
- # if pt:
- # print('='*40)
- # weighted_acc = 0.0
- # weighted_average_iou = 0.0
- # accs, ious = [], []
- # for cat in sorted(seg_classes.keys()):
- # acc = shape_points_correct[cat] / float(shape_points_seen[cat])
- # iou = np.mean(shape_ious[cat])
- # if pt:
- # print('{} | acc: {:.4f}, iou: {:.4f}'.format(cat, acc, iou))
- # accs.append(round(acc * 100, 1))
- # ious.append(round(iou * 100, 1))
- # weighted_acc += shape_count[cat] * acc
- # weighted_average_iou += shape_count[cat] * iou
- # #print('accs: ', accs)
- # #print('ious: ', ious)
- # weighted_acc = weighted_acc / np.sum(list(shape_count.values())).astype(np.float32)
- # weighted_average_iou = weighted_average_iou / np.sum(list(shape_count.values())).astype(np.float32)
- # return weighted_average_iou, weighted_acc
-
-
def train(seg_model, opt, dataset, dataloader):
    """Train the part-segmentation model and checkpoint every 10 epochs.

    Args:
        seg_model: PointNetSeg network; moved to GPU here and trained in place.
        opt: parsed CLI options; uses ``batchsize``, ``epochs``, ``nclasses``
            and ``ckp`` (checkpoint output directory).
        dataset: training dataset (only its length is used, for logging).
        dataloader: DataLoader yielding ``(points, target)`` batches.
    """
    optimizer = optim.Adam(seg_model.parameters(), lr=0.001, betas=(0.9, 0.999))
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.8)
    seg_model.cuda()

    # Batches per epoch for the progress printout (integer, not float).
    num_batch = len(dataset) // opt.batchsize

    for epoch in tqdm(range(opt.epochs)):
        # Enable training mode once per epoch instead of re-assigning per batch.
        seg_model.train()
        for i, data in enumerate(dataloader, 0):
            points, target = data
            # (B, N, 3) -> (B, 3, N): the network expects channels-first input.
            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()
            optimizer.zero_grad()
            pred, trans, trans_feat = seg_model(points)
            pred = pred.view(-1, opt.nclasses)
            # Dataset labels are 1-based; shift to 0-based for nll_loss.
            target = target.view(-1, 1)[:, 0] - 1
            loss = torch.nn.functional.nll_loss(pred, target)
            # Encourage the feature transform to stay close to orthogonal.
            loss += feature_transform_regularizer(trans_feat) * 0.001
            loss.backward()
            optimizer.step()
            pred_choice = pred.data.max(1)[1]
            correct = pred_choice.eq(target.data).cpu().sum()
            # NOTE(review): accuracy assumes 2500 points per shape — confirm
            # this matches the dataset's per-shape sampling size.
            print('[%d: %d/%d] train loss: %f accuracy: %f' % (
                epoch, i, num_batch, loss.item(),
                correct.item() / float(opt.batchsize * 2500)))

        # Step the LR schedule once per epoch, AFTER the optimizer steps.
        # (Calling it before any optimizer.step() skips the first LR value;
        # PyTorch >= 1.1 warns about this ordering.)
        scheduler.step()

        if epoch % 10 == 0:
            torch.save(seg_model.state_dict(),
                       '%s/seg_model_%d.pth' % (opt.ckp, epoch))
-
-
## Testing...
## benchmark mIOU
def test(seg_model, opt, test_dataloader, checkpoint_path):
    """Evaluate per-shape mean IoU of a checkpoint on the chosen category.

    Relies on the module-level globals ``logfile`` (open log handle, closed
    here) and ``num_classes`` (part-label count for ``opt.cat_choice``),
    both set in the ``__main__`` block.

    Args:
        seg_model: PointNetSeg network; weights are loaded from the checkpoint.
        opt: parsed CLI options; uses ``ckp`` and ``cat_choice``.
        test_dataloader: DataLoader yielding ``(points, target)`` batches.
        checkpoint_path: checkpoint filename inside ``opt.ckp``.
    """
    logfile.flush()
    seg_model.cuda()
    seg_model.load_state_dict(torch.load(os.path.join(opt.ckp, checkpoint_path)))
    # .eval() works in place; the original bound it to a dead `eg_model` var.
    seg_model.eval()
    shape_ious = []
    # No gradients needed at evaluation time — avoids building autograd graphs.
    with torch.no_grad():
        for i, data in tqdm(enumerate(test_dataloader, 0)):
            points, target = data
            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()

            pred, _, _ = seg_model(points)
            pred_choice = pred.data.max(2)[1]

            pred_np = pred_choice.cpu().data.numpy()
            # Labels are 1-based in the dataset; shift to match predictions.
            target_np = target.cpu().data.numpy() - 1
            for shape_idx in range(target_np.shape[0]):
                # Average IoU over ALL parts of the category, not just those
                # present in this shape.
                parts = range(num_classes)
                part_ious = []
                for part in parts:
                    I = np.sum(np.logical_and(pred_np[shape_idx] == part,
                                              target_np[shape_idx] == part))
                    U = np.sum(np.logical_or(pred_np[shape_idx] == part,
                                             target_np[shape_idx] == part))
                    if U == 0:
                        # Part absent from both prediction and ground truth:
                        # count its IoU as 1 by convention.
                        iou = 1
                    else:
                        iou = I / float(U)
                    part_ious.append(iou)
                shape_ious.append(np.mean(part_ious))
    print("mIOU for class {}: {}\n".format(opt.cat_choice, np.mean(shape_ious)))
    logfile.write("mIOU for class {}: {}\n".format(opt.cat_choice, np.mean(shape_ious)))
    logfile.close()
-
-
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--batchsize', type=int, default=32, help='input batch size')
    parser.add_argument('--epochs', type=int, default=101, help='number of epochs to train for')
    parser.add_argument('--ckp', type=str, default='checkpoints', help='output folder')
    parser.add_argument('--model', type=str, default='', help='model path')
    parser.add_argument('--dataset', type=str, default='../data/', required=False, help="dataset path")
    parser.add_argument('--cat_choice', type=str, default='Chair', help="class_choice")
    parser.add_argument('--nclasses', type=int, default=50, help='Number of classes')
    parser.add_argument('--mode', type=str, default='train', help='train or test')

    # Module-level log handle: test() writes the mIoU result here and closes it.
    logfile = open('./log.txt', 'a')
    opt = parser.parse_args()
    print(opt)

    # exist_ok avoids the check-then-create race of exists()/makedirs().
    os.makedirs(opt.ckp, exist_ok=True)

    # Fix a random seed per run so results can be reproduced from the printout.
    opt.manualSeed = random.randint(1, 10000)
    print("Random Seed: ", opt.manualSeed)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)

    dataset = ShapeNetDataset(data_root=opt.dataset, split='train')
    dataloader = DataLoader(dataset, batch_size=opt.batchsize, shuffle=True, num_workers=4)

    test_dataset = ShapeNetDataset(data_root=opt.dataset, cat_choice=opt.cat_choice, split='test')
    test_dataloader = DataLoader(test_dataset, batch_size=opt.batchsize, shuffle=False, num_workers=4)

    print('-----------------------------')
    print(len(dataset), len(test_dataset))
    print('train on all category data')
    print('test on a specified category: %s' % opt.cat_choice)

    # num_classes: number of part labels for the chosen category; read as a
    # global by test(). --cat_choice defaults to 'Chair', so it is always set.
    if opt.cat_choice is not None:
        num_classes = test_dataset.seg_classes[opt.cat_choice]

    print('test_classes', num_classes)
    print('-----------------------------')

    seg_model = PointNetSeg(k=opt.nclasses)

    if opt.mode == 'train':
        train(seg_model, opt, dataset, dataloader)

    if opt.mode == 'test':
        checkpoint_path = 'seg_model_90.pth'
        test(seg_model, opt, test_dataloader, checkpoint_path)
-
|