|
- from glob import glob
- import time
- import datetime
- from tensorboardX import SummaryWriter
- from torch.utils.data import DataLoader
- from utils import *
- from losses import *
- from models.unet import UNet
- from models.pix2pix_networks import PixelDiscriminator
- from models.liteFlownet import lite_flownet as lite_flow
- import random
- import torch
- import numpy as np
- import cv2
- import glob
- from torch.utils.data import Dataset
-
-
-
class Train_dataset(Dataset):
    """
    Dataset of sliding video clips for training.

    No data augmentation. Frames are normalized from [0, 255] to [-1, 1];
    the channel order is BGR because cv2 and liteFlownet both use BGR.

    Each __getitem__ call returns one clip of `clip_length` consecutive
    frames from one video folder. The starting index is read from a
    pre-shuffled per-video list (`all_seqs`) so every possible clip of a
    video is eventually visited; the training loop pops the consumed index
    (it can't be popped here because of DataLoader worker processes).
    """

    def __init__(self, train_data, clip_length=5):
        """
        Args:
            train_data: Root folder; each sub-folder holds the .jpg frames
                of one video.
            clip_length: Number of consecutive frames per clip. The first
                clip_length - 1 frames are the inputs, the last one is the
                prediction target. Default 5 preserves the original behavior.
        """
        self.train_data = train_data
        self.img_h = 256
        self.img_w = 256
        self.clip_length = clip_length
        self.videos = []
        self.all_seqs = []
        for folder in sorted(glob.glob(f'{self.train_data}/*')):
            all_imgs = sorted(glob.glob(f'{folder}/*.jpg'))
            self.videos.append(all_imgs)
            # Valid clip start indices: every index that leaves room for a
            # full clip of `clip_length` frames (was a hard-coded `- 4`).
            random_seq = list(range(len(all_imgs) - (self.clip_length - 1)))
            random.shuffle(random_seq)
            self.all_seqs.append(random_seq)

    def __len__(self):  # This decides the index range of the PyTorch DataLoader.
        return len(self.videos)

    def __getitem__(self, indice):  # `indice` decides which video folder to load from.
        one_folder = self.videos[indice]

        video_clip = []
        # Always read the last not-yet-consumed start index; the training
        # loop pops it after the batch is formed.
        start = self.all_seqs[indice][-1]
        for i in range(start, start + self.clip_length):
            video_clip.append(np_load_frame(one_folder[i], self.img_h, self.img_w))

        video_clip = np.array(video_clip).reshape((-1, self.img_h, self.img_w))
        video_clip = torch.from_numpy(video_clip)

        # Identifier of the optical-flow pair: last input frame -> target frame
        # (was hard-coded `start + 3`/`start + 4` for clip_length == 5).
        flow_str = f'{indice}_{start + self.clip_length - 2}-{start + self.clip_length - 1}'
        return indice, video_clip, flow_str
-
def np_load_frame(filename, resize_h, resize_w):
    """
    Load one image, resize it to (resize_h, resize_w) and normalize to [-1, 1].

    The channel order stays BGR (cv2 default, matching liteFlownet).

    Args:
        filename: Path to the image file.
        resize_h: Target height in pixels.
        resize_w: Target width in pixels.

    Returns:
        A float32 ndarray of shape (C, H, W) with values in [-1, 1].

    Raises:
        FileNotFoundError: If the file is missing or cannot be decoded
            (cv2.imread returns None instead of raising).
    """
    img = cv2.imread(filename)
    if img is None:
        raise FileNotFoundError(f'Failed to read image: {filename}')
    image_resized = cv2.resize(img, (resize_w, resize_h)).astype('float32')
    image_resized = (image_resized / 127.5) - 1.0  # [0, 255] -> [-1, 1]
    image_resized = np.transpose(image_resized, [2, 0, 1])  # HWC -> (C, H, W); original comment said (C, W, H)
    return image_resized
-
-
-
class Train():
    """
    Train a future-frame-prediction model for video anomaly detection.

    A UNet generator predicts the 5th frame of a clip from the first 4 frames;
    a pixel discriminator provides the adversarial signal, and a frozen
    LiteFlownet supplies optical-flow supervision. Checkpoints go to
    'weights/', scalars and sample images to tensorboard.
    """

    def __init__(self, dataset, train_data):
        """
        Args:
            dataset: Dataset name used in log/checkpoint file names, e.g. 'ped2'.
            train_data: Root folder of training frames (one sub-folder per video).
        """
        self.dataset = dataset
        self.train_data = train_data
        self.batch_size = 8
        self.iters = 40000  # total number of training iterations
        self.save_intervel = 1000  # save a checkpoint every N iterations (NOTE: original misspelling kept for compatibility)
        self.g_lr = 0.0002  # learning rate of generator
        self.d_lr = 0.00002  # learning rate of discriminator
        # 4 input frames x 3 BGR channels = 12 input channels; predicts 1 BGR frame.
        self.generator = UNet(input_channels=12, output_channel=3).cuda()
        self.discriminator = PixelDiscriminator(input_nc=3).cuda()
        self.optimizer_G = torch.optim.Adam(self.generator.parameters(), lr=self.g_lr)
        self.optimizer_D = torch.optim.Adam(self.discriminator.parameters(), lr=self.d_lr)
        self.generator.apply(weights_init_normal)
        self.discriminator.apply(weights_init_normal)
        self.flow_net = lite_flow.Network()  # LiteFlownet
        self.flow_net.load_state_dict(torch.load('models/liteFlownet/network-default.pytorch'))
        self.flow_net.cuda().eval()  # flow_net only generates flow targets; it is never trained.
        self.adversarial_loss = Adversarial_Loss().cuda()
        self.discriminate_loss = Discriminate_Loss().cuda()
        self.gradient_loss = Gradient_Loss(3).cuda()
        self.flow_loss = Flow_Loss().cuda()
        self.intensity_loss = Intensity_Loss().cuda()

    def __call__(self, *args, **kwargs):
        """Run the training loop for self.iters iterations."""
        train_dataset = Train_dataset(self.train_data)
        # Remember to set drop_last=True, because we need to use 4 frames to predict one frame.
        train_dataloader = DataLoader(dataset=train_dataset, batch_size=self.batch_size,
                                      shuffle=True, num_workers=4, drop_last=True)
        writer = SummaryWriter(f'tensorboard_log/{self.dataset}_bs{self.batch_size}')
        start_iter = 0
        training = True
        generator = self.generator.train()
        discriminator = self.discriminator.train()
        step = start_iter
        temp = time.time()  # FIX: initialize timing anchor so iter_t is always computable
        while training:
            for indice, clips, flow_strs in train_dataloader:
                input_frames = clips[:, 0:12, :, :].cuda()  # (n, 12, 256, 256): 4 stacked input frames
                target_frame = clips[:, 12:15, :, :].cuda()  # (n, 3, 256, 256): frame to predict
                input_last = input_frames[:, 9:12, :, :].cuda()  # last input frame, used for flow_loss

                # pop() the used frame index; this can't happen in
                # train_dataset.__getitem__ because of DataLoader multiprocessing.
                for index in indice:
                    train_dataset.all_seqs[index].pop()
                    if len(train_dataset.all_seqs[index]) == 0:
                        train_dataset.all_seqs[index] = list(range(len(train_dataset.videos[index]) - 4))
                        random.shuffle(train_dataset.all_seqs[index])

                G_frame = generator(input_frames)
                gt_flow_input = torch.cat([input_last, target_frame], 1)
                pred_flow_input = torch.cat([input_last, G_frame], 1)
                # No need to train flow_net: .detach() cuts off its gradients.
                # NOTE(review): passing self.flow_net as an argument to its own
                # method looks odd — confirm batch_estimate's expected signature.
                flow_gt = self.flow_net.batch_estimate(gt_flow_input, self.flow_net).detach()
                flow_pred = self.flow_net.batch_estimate(pred_flow_input, self.flow_net).detach()

                inte_l = self.intensity_loss(G_frame, target_frame)
                grad_l = self.gradient_loss(G_frame, target_frame)
                fl_l = self.flow_loss(flow_pred, flow_gt)
                g_l = self.adversarial_loss(discriminator(G_frame))
                # Weighted total generator loss.
                G_l_t = 1. * inte_l + 1. * grad_l + 2. * fl_l + 0.05 * g_l

                # Train discriminator first. G_frame.detach() keeps the
                # generator's weights out of this backward pass.
                D_l = self.discriminate_loss(discriminator(target_frame), discriminator(G_frame.detach()))
                self.optimizer_D.zero_grad()
                D_l.backward()
                self.optimizer_D.step()

                # Train generator.
                self.optimizer_G.zero_grad()
                G_l_t.backward()
                self.optimizer_G.step()

                torch.cuda.synchronize()
                time_end = time.time()
                if step > start_iter:  # This doesn't include the testing time during training.
                    iter_t = time_end - temp
                temp = time_end  # FIX: update unconditionally, otherwise `temp` is stale/unbound

                if step != start_iter:
                    if step % 20 == 0:
                        time_remain = (self.iters - step) * iter_t
                        eta = str(datetime.timedelta(seconds=time_remain)).split('.')[0]
                        psnr = psnr_error(G_frame, target_frame)
                        lr_g = self.optimizer_G.param_groups[0]['lr']
                        lr_d = self.optimizer_D.param_groups[0]['lr']

                        print(f"[{step}] inte_l: {inte_l:.3f} | grad_l: {grad_l:.3f} | fl_l: {fl_l:.3f} | "
                              f"g_l: {g_l:.3f} | G_l_total: {G_l_t:.3f} | D_l: {D_l:.3f} | psnr: {psnr:.3f} | "
                              f"iter: {iter_t:.3f}s | ETA: {eta} | lr: {lr_g} {lr_d}")

                        # [-1, 1] -> [0, 1] and BGR -> RGB for tensorboard display.
                        save_G_frame = ((G_frame[0] + 1) / 2)
                        save_G_frame = save_G_frame.cpu().detach()[(2, 1, 0), ...]
                        save_target = ((target_frame[0] + 1) / 2)
                        save_target = save_target.cpu().detach()[(2, 1, 0), ...]

                        # FIX: psnr/train_psnr was logged twice per step; log it once.
                        writer.add_scalar('psnr/train_psnr', psnr, global_step=step)
                        writer.add_scalar('total_loss/g_loss_total', G_l_t, global_step=step)
                        writer.add_scalar('total_loss/d_loss', D_l, global_step=step)
                        writer.add_scalar('G_loss_total/g_loss', g_l, global_step=step)
                        writer.add_scalar('G_loss_total/fl_loss', fl_l, global_step=step)
                        writer.add_scalar('G_loss_total/inte_loss', inte_l, global_step=step)
                        writer.add_scalar('G_loss_total/grad_loss', grad_l, global_step=step)

                    # Image logging interval (400 with defaults) is a multiple of 20,
                    # so save_G_frame/save_target are always defined here.
                    if step % int(self.iters / 100) == 0:
                        writer.add_image('image/G_frame', save_G_frame, global_step=step)
                        writer.add_image('image/target', save_target, global_step=step)

                    if step % self.save_intervel == 0:
                        model_dict = {'net_g': generator.state_dict(), 'optimizer_g': self.optimizer_G.state_dict(),
                                      'net_d': discriminator.state_dict(), 'optimizer_d': self.optimizer_D.state_dict()}
                        torch.save(model_dict, f'weights/{self.dataset}_{step}.pth')
                        print(f'\nAlready saved: \'{self.dataset}_{step}.pth\'.')

                step += 1
                if step > self.iters:
                    training = False
                    model_dict = {'net_g': generator.state_dict(), 'optimizer_g': self.optimizer_G.state_dict(),
                                  'net_d': discriminator.state_dict(), 'optimizer_d': self.optimizer_D.state_dict()}
                    torch.save(model_dict, f'weights/latest_{self.dataset}_{step}.pth')
                    break
-
-
if __name__ == '__main__':
    # Entry point: train on the ped2 dataset with hard-coded paths.
    dataset_name = 'ped2'
    training_folder = '/home/huangchao/data/ped2/training/'
    trainer = Train(dataset_name, training_folder)
    trainer()  # equivalent to trainer.__call__(), idiomatic callable syntax
    print('ok')
|