|
import os
import glob
import argparse
import traceback
import subprocess
import numpy as np


parser = argparse.ArgumentParser()

# -n/--num and -s/--start_idx select the slice of the FFHQ file lists that
# this process will fit (used in the __main__ driver below).
parser.add_argument('-n', '--num', type=int, default=-1,
help="Number of samples to use, `-1` means all")
parser.add_argument('-s', '--start_idx', type=int, default=0,
help="Index of samples to start")

args = parser.parse_args()

# Container bootstrap: mirror the task mounts under the fixed absolute paths
# the rest of the script expects (/code, /dataset, /model/task_out).
if os.path.exists('/tmp/code'):
    # assumes /tmp/code contains exactly one project directory — TODO confirm
    proj_name = os.listdir("/tmp/code/")[0]
    subprocess.run(["ln", "-sf", "/tmp/code/" + proj_name, "/code"])
if os.path.exists('/tmp/dataset'):
    subprocess.run(["ln", "-sf", "/tmp/dataset", "/dataset"])
if os.path.exists('/tmp/output'):
    subprocess.run(["ln", "-sf", "/tmp/output", "/model/task_out"])
# Fallback when /tmp/output was absent: -sf above would have created the dir's
# symlink; mkdir -p is a no-op if it already exists.
subprocess.run(["mkdir", "-p", "/model/task_out"])

# All project-relative paths below ('./data/...', './models/') assume CWD is
# the project root.
os.chdir('/code')

# Imported only after the chdir so the project module resolves from /code.
from generate_eyes_tex import generate_datasets as generate_eyes_tex
generate_eyes_tex()
-
-
# Second import batch: deferred until after the /code symlink + chdir above so
# the project-local modules (FLAME, renderer, util) resolve.  os and numpy are
# re-imported; harmless no-ops.
import os, sys
import torch
import torchvision
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import datetime

from PIL import Image

# FLAME / renderer implementations live under ./models relative to /code.
sys.path.append('./models/')
from FLAME import FLAME, FLAMETex
from renderer import Renderer
import util
# Let cuDNN auto-tune convolution algorithms; input sizes are fixed here so
# benchmarking pays off.
torch.backends.cudnn.benchmark = True
-
-
class LandmarksFitting(object):
    """Fit FLAME shape/expression/pose (and an orthographic camera) to 2D
    facial landmarks.

    The texture branch (FLAMETex, renderer, SH lights) is set up but NOT
    optimized — it is only used to produce diagnostic renders after each fit.
    """

    def __init__(self, config, device='cuda'):
        self.batch_size = config.batch_size
        self.image_size = config.image_size
        self.config = config
        self.device = device
        # FLAME geometry + albedo models on the target device.
        self.flame = FLAME(self.config).to(self.device)
        self.flametex = FLAMETex(self.config).to(self.device)

        self._setup_renderer()

    def _setup_renderer(self):
        # Differentiable renderer built around the FLAME head template mesh.
        mesh_file = './data/head_template_mesh.obj'
        self.render = Renderer(self.image_size, obj_filename=mesh_file).to(self.device)

    def optimize(self, images, landmarks, image_masks, preshape=None, savefolder=None):
        """Run the two-stage landmark fit on a batch of images.

        Args:
            images: (B, 3, H, W) float tensor in [0, 1] on self.device.
            landmarks: (B, 68, 2+) landmarks in NDC [-1, 1] coordinates.
            image_masks: (B, 1, H, W) skin masks; currently unused by the
                losses, kept for API compatibility with photometric fitting.
            preshape: optional precomputed identity shape (numpy, 1D). When
                given, shape is frozen (not a Parameter, not optimized).
            savefolder: directory for diagnostic PNGs; effectively required —
                the final rendering step formats save paths with it.

        Returns:
            dict with 'shape', 'exp', 'pose', 'cam' numpy arrays.
        """
        torch.cuda.manual_seed(42)

        bz = images.shape[0]
        shape = nn.Parameter(torch.zeros(bz, self.config.shape_params).float().to(self.device))
        if preshape is not None:
            # Frozen identity: a plain tensor, never handed to an optimizer.
            shape = torch.from_numpy(preshape)[None].float().to(self.device)
        tex = nn.Parameter(torch.zeros(bz, self.config.tex_params).float().to(self.device))
        exp = nn.Parameter(torch.zeros(bz, self.config.expression_params).float().to(self.device))
        pose = nn.Parameter(torch.zeros(bz, self.config.pose_params).float().to(self.device))
        eye_tex = nn.Parameter(torch.zeros(bz, self.flametex.eyes_basis.shape[-1]).float().to(self.device))
        cam = torch.zeros(bz, self.config.camera_params)
        cam[:, 0] = 5.  # initial orthographic scale
        cam = nn.Parameter(cam.float().to(self.device))
        # BUGFIX: the original wrapped torch.ones in nn.Parameter and THEN
        # multiplied by 0.3, which yields a plain non-leaf tensor, not a
        # Parameter.  Values are identical (constant 0.3 SH coefficients) and
        # lights are not optimized either way; this just fixes the type.
        lights = nn.Parameter((torch.ones(bz, 9, 3) * 0.3).float().to(self.device))
        e_opt = torch.optim.Adam(
            ([shape] if preshape is None else []) + [exp, pose, cam],
            lr=self.config.e_lr,
            weight_decay=self.config.e_wd
        )
        e_opt_rigid = torch.optim.Adam(
            [pose, cam],
            lr=self.config.e_lr,
            weight_decay=self.config.e_wd
        )

        gt_landmark = landmarks

        # Stage 1: rigid fit of pose and camera using only the 51 inner
        # (static) face landmarks — contour landmarks slide along the
        # silhouette and are excluded here.
        for k in range(200):
            losses = {}
            vertices, landmarks2d, landmarks3d = self.flame(
                shape_params=shape, expression_params=exp, pose_params=pose)
            trans_vertices = util.batch_orth_proj(vertices, cam)
            trans_vertices[..., 1:] = -trans_vertices[..., 1:]
            landmarks2d = util.batch_orth_proj(landmarks2d, cam)
            landmarks2d[..., 1:] = -landmarks2d[..., 1:]
            landmarks3d = util.batch_orth_proj(landmarks3d, cam)
            landmarks3d[..., 1:] = -landmarks3d[..., 1:]

            # BUGFIX: was the module-level global `config`; use self.config so
            # the class also works when imported outside this script (same
            # object at script runtime, so behavior is unchanged here).
            losses['landmark'] = util.l2_distance(
                landmarks2d[:, 17:, :2], gt_landmark[:, 17:, :2]) * self.config.w_lmks

            all_loss = 0.
            for key in losses.keys():
                all_loss = all_loss + losses[key]
            losses['all_loss'] = all_loss
            e_opt_rigid.zero_grad()
            all_loss.backward()
            e_opt_rigid.step()

            loss_info = '----iter: {}, time: {}\n'.format(
                k, datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S'))
            for key in losses.keys():
                loss_info = loss_info + '{}: {}, '.format(key, float(losses[key]))
            if k % 10 == 0:
                print(loss_info)

        # Stage 2: non-rigid fit of all free parameters against all 68
        # landmarks, with L2 regularization on shape/expression/pose.
        for k in range(200, 1000):
            losses = {}
            vertices, landmarks2d, landmarks3d = self.flame(
                shape_params=shape, expression_params=exp, pose_params=pose)
            trans_vertices = util.batch_orth_proj(vertices, cam)
            trans_vertices[..., 1:] = -trans_vertices[..., 1:]
            landmarks2d = util.batch_orth_proj(landmarks2d, cam)
            landmarks2d[..., 1:] = -landmarks2d[..., 1:]
            landmarks3d = util.batch_orth_proj(landmarks3d, cam)
            landmarks3d[..., 1:] = -landmarks3d[..., 1:]

            # BUGFIX (same as above): self.config instead of the global.
            losses['landmark'] = util.l2_distance(
                landmarks2d[:, :, :2], gt_landmark[:, :, :2]) * self.config.w_lmks
            losses['shape_reg'] = (torch.sum(shape ** 2) / 2) * self.config.w_shape_reg
            losses['expression_reg'] = (torch.sum(exp ** 2) / 2) * self.config.w_expr_reg
            losses['pose_reg'] = (torch.sum(pose ** 2) / 2) * self.config.w_pose_reg

            all_loss = 0.
            for key in losses.keys():
                all_loss = all_loss + losses[key]
            losses['all_loss'] = all_loss
            e_opt.zero_grad()
            all_loss.backward()
            e_opt.step()

            loss_info = '----iter: {}, time: {}\n'.format(
                k, datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S'))
            for key in losses.keys():
                loss_info = loss_info + '{}: {}, '.format(key, float(losses[key]))

            if k % 10 == 0:
                print(loss_info)

        print('A fitting task finished.')
        # Diagnostic renders: project the fitted mesh, sample textures into UV
        # space, and dump a few PNGs into `savefolder`.
        try:
            with torch.no_grad():
                vertices, landmarks2d, landmarks3d = self.flame(
                    shape_params=shape, expression_params=exp, pose_params=pose)
                trans_vertices = util.batch_orth_proj(vertices, cam)
                trans_vertices[..., 1:] = -trans_vertices[..., 1:]
                tex_full = torch.cat([tex, eye_tex], -1)
                albedos = self.flametex(tex_full) / 255.
                ops = self.render(vertices, trans_vertices, albedos, lights)
                uv_pverts = self.render.world2uv(trans_vertices)
                normal_images = ops['normal_images']
                normal_images = F.grid_sample(
                    normal_images, uv_pverts.permute(0, 2, 3, 1)[:, :, :, :2],
                    mode='bilinear', align_corners=False)
                # Ground-truth albedo proxy: the input image resampled into UV
                # space via the projected vertex positions (unshaded).
                albedos_gt = F.grid_sample(
                    images, uv_pverts.permute(0, 2, 3, 1)[:, :, :, :2],
                    mode='bilinear', align_corners=False)
                # NOTE(review): normal_image / texture_image / grid_image are
                # computed but never saved — dead diagnostics kept for parity.
                normal_image = (normal_images[0].cpu().numpy().transpose(1, 2, 0).copy() * 255)

                texture_image = (albedos[0, :3].cpu().numpy().transpose(1, 2, 0).copy() * 255)
                texture_gt_image = (albedos_gt[0, :3].cpu().numpy().transpose(1, 2, 0).copy() * 255)

                # Pack the rasterizer's sampling grid into RGB: R/G carry the
                # coarse coordinates, B encodes the sub-unit fractions of both
                # quantized to 4 bits each.
                grid = ops['grid']
                grid = grid.permute(0, 3, 1, 2).clone()
                grid_r = (grid[:, :1] + 1) / 2 * 255
                grid_g = (grid[:, 1:2] + 1) / 2 * 255
                grid_b = ((grid_r - grid_r.floor()) * 15).floor() * 16 \
                    + ((grid_g - grid_g.floor()) * 15).floor()
                grid_images = torch.cat([grid_r, grid_g, grid_b], 1).floor() / 255.0
                grid_image = (grid_images[0].cpu().numpy().transpose(1, 2, 0).copy() * 255)

                rendered_images = ops['images']
                alpha_images = ops['alpha_images'] * ops['pos_mask']
                rendered_image = (rendered_images[0].cpu().numpy().transpose(1, 2, 0).copy() * 255)
                alpha_image = (alpha_images[0, 0].cpu().numpy().copy() * 255)
                image = (images[0].cpu().numpy().transpose(1, 2, 0).copy() * 255)
                save_dict = dict(
                    texture_gt=texture_gt_image,
                    rendered=rendered_image, alpha=alpha_image
                )
                save_dict['original'] = image
                for name, value in save_dict.items():
                    # Clamp to [0, 255] before the uint8 cast to avoid wraparound.
                    value = np.minimum(np.maximum(value, 0), 255).astype(np.uint8)
                    Image.fromarray(value).save(
                        '{}/{}.png'.format(
                            savefolder,
                            name
                        )
                    )
        except Exception:
            # Diagnostics are the last step, and the batch driver treats a
            # failure here as fatal: log the traceback and abort the process.
            print(traceback.format_exc())
            # BUGFIX: `exit` is the interactive site helper; sys.exit is the
            # proper programmatic API.
            sys.exit(1)

        single_params = {
            'shape': shape.detach().cpu().numpy(),
            'exp': exp.detach().cpu().numpy(),
            'pose': pose.detach().cpu().numpy(),
            'cam': cam.detach().cpu().numpy(),
        }
        return single_params

    def run(self, imagepath, landmarkpath, image_mask_path, image_preshape_path):
        """Fit a single image and save the resulting FLAME parameters.

        The implementation could batch images (batch_size > 1); this entry
        point drives it with a single image per call.
        """
        images = []
        landmarks = []
        image_masks = []

        image_name = os.path.basename(imagepath)[:-4]
        savefile = os.path.sep.join([self.config.savefolder, image_name + '.npy'])

        # Photometric optimization is sensitive to hair/glasses occlusion, so
        # a face-segmentation mask selects the usable skin region.
        # BUGFIX: was the module-level global `config`; use self.config.
        image = Image.open(imagepath).convert('RGB').resize(
            (self.config.cropped_size, self.config.cropped_size), Image.BILINEAR)
        image = np.asarray(image).transpose(2, 0, 1) / 255.
        images.append(torch.from_numpy(image[None, :, :, :]).float().to(self.device))

        # Union of the parsing labels that count as "face" — TODO confirm the
        # id -> region mapping against the segmentation network used upstream.
        image_mask = np.asarray(Image.open(image_mask_path).convert('L'))
        _image_mask = 0
        mask_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 17]
        for mask_id in mask_ids:
            _image_mask = _image_mask + (image_mask == mask_id).astype('float32')
        image_mask = _image_mask > 0
        image_mask = image_mask[..., None].astype('float32')
        image_mask = image_mask.transpose(2, 0, 1)
        image_mask_bn = np.zeros_like(image_mask)
        image_mask_bn[np.where(image_mask != 0)] = 1.
        image_masks.append(torch.from_numpy(image_mask_bn[None, :, :, :]).to(self.device))

        # Landmarks arrive in pixel coordinates; map to NDC [-1, 1].
        landmark = np.load(landmarkpath).astype(np.float32)
        landmark[:, 0] = landmark[:, 0] / float(image.shape[2]) * 2 - 1
        landmark[:, 1] = landmark[:, 1] / float(image.shape[1]) * 2 - 1
        landmarks.append(torch.from_numpy(landmark)[None, :, :].float().to(self.device))

        images = torch.cat(images, dim=0)
        images = F.interpolate(images, [self.image_size, self.image_size])
        image_masks = torch.cat(image_masks, dim=0)
        image_masks = F.interpolate(image_masks, [self.image_size, self.image_size])

        landmarks = torch.cat(landmarks, dim=0)
        # Intermediate diagnostics go into a sibling "<savefolder>_intermediate" tree.
        savefolder_intermediate = os.path.sep.join(
            [self.config.savefolder + "_intermediate", image_name])

        preshape = np.load(image_preshape_path)

        util.check_mkdir(savefolder_intermediate)
        # Optimize and persist the fitted parameters.
        single_params = self.optimize(
            images, landmarks, image_masks, preshape, savefolder_intermediate)
        np.save(savefile, single_params)
-
-
if __name__ == '__main__':
    device_name = 'cuda'
    use_preshape = True
    # FLAME fitting configuration; absolute paths match the container layout
    # established at the top of the script.
    config = {
        # FLAME
        'flame_model_path': '/dataset/FLAME2020/generic_model.pkl',  # acquire it from FLAME project page
        'flame_lmk_embedding_path': './data/landmark_embedding.npy',
        'flame_albedo_tex_space_path': '/dataset/albedoModel2020_FLAME_albedoPart/albedoModel2020_FLAME_albedoPart.npz',  # download it from AlbedoMM release page
        'flame_eyes_tex': './data/head_eyes.png',
        'head_eyes_path': './data/head_eyes.npy',
        'camera_params': 3,
        'shape_params': 300 if use_preshape else 100,
        'expression_params': 100,
        'pose_params': 6,
        'eye_pose_params': 6,
        'tex_params': 145,
        'use_face_contour': True,

        'cropped_size': 256,
        'batch_size': 1,
        'image_size': 224,
        'e_lr': 0.005,
        'e_wd': 0.0001,
        'savefolder': '/model/task_out/ffhq_flame_pose',
        # weights of losses and reg terms
        'w_pho': 8,
        'w_lmks': 1,
        'w_shape_reg': 1e-4,
        'w_expr_reg': 1e-4,
        'w_pose_reg': 0,
    }

    config = util.dict2obj(config)

    config.batch_size = 1
    fitting = LandmarksFitting(config, device=device_name)

    # Images and their parsing masks live in the same tree; masks sit under
    # 'parsings' subdirectories and are split out of the image list.
    ffhq_paths = sorted(glob.glob('/dataset/ffhq_aging256x256/**/*.png', recursive=True))
    ffhq_mask_paths = [path for path in ffhq_paths if 'parsings' in path]
    ffhq_paths = [path for path in ffhq_paths if 'parsings' not in path]
    ffhq_lmks_paths = sorted(glob.glob('/dataset/FFHQ_lmks/**/*.npy', recursive=True))
    ffhq_shape_paths = sorted(glob.glob('/dataset/ffhq_identity/**/*.npy', recursive=True))
    if args.num == -1:
        args.num = len(ffhq_paths)
    # BUGFIX: --num is documented as "number of samples to use", but the
    # original sliced [start_idx:num] (num as an END index), so e.g.
    # `-s 100 -n 10` selected NOTHING.  Slice [start_idx:start_idx + num]
    # instead; the defaults (-s 0, -n -1) still select everything.
    stop_idx = args.start_idx + args.num
    ffhq_paths = ffhq_paths[args.start_idx:stop_idx]
    ffhq_mask_paths = ffhq_mask_paths[args.start_idx:stop_idx]
    ffhq_lmks_paths = ffhq_lmks_paths[args.start_idx:stop_idx]
    ffhq_shape_paths = ffhq_shape_paths[args.start_idx:stop_idx]
    print('Datas size:', len(ffhq_paths))

    # Encode the processed range into the output folder name so parallel
    # shards do not collide.
    config.savefolder += "_" + \
        os.path.basename(ffhq_paths[0])[:-4] + "_to_" + \
        os.path.basename(ffhq_paths[-1])[:-4]
    save_folder_base = os.path.basename(config.savefolder)

    util.check_mkdir(config.savefolder)
    util.check_mkdir(config.savefolder + "_intermediate")

    # NOTE(review): aligning the four sorted lists by position assumes the
    # dataset trees have matching per-sample basenames — verify upstream.
    for imagepath, maskpath, landmarkpath, shapepath in zip(
            ffhq_paths, ffhq_mask_paths, ffhq_lmks_paths, ffhq_shape_paths):

        fitting.run(imagepath, landmarkpath, maskpath, shapepath)

    # Bundle the shard's results for the task output (shell glob is needed,
    # hence os.system; paths are script-controlled, not user input).
    os.chdir('/model/task_out')
    os.system(f'tar cfz {save_folder_base}.tar.gz {save_folder_base}/*')
|