import os
import glob
import argparse
import traceback
import subprocess
import numpy as np


parser = argparse.ArgumentParser()

parser.add_argument('-n', '--num', type=int, default=-1,
                    help="Number of samples to use; `-1` means all")
parser.add_argument('-s', '--start_idx', type=int, default=0,
                    help="Index of the first sample to process")

args = parser.parse_args()

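# The setup below assumes this pipeline's container layout: the project,
# dataset, and output directory are mounted under /tmp and exposed via
# the symlinks /code, /dataset, and /model/task_out that the rest of
# the script relies on.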
if os.path.exists('/tmp/code'):
    proj_name = os.listdir("/tmp/code/")[0]
    subprocess.run(["ln", "-sf", "/tmp/code/" + proj_name, "/code"])
if os.path.exists('/tmp/dataset'):
    subprocess.run(["ln", "-sf", "/tmp/dataset", "/dataset"])
if os.path.exists('/tmp/output'):
    subprocess.run(["ln", "-sf", "/tmp/output", "/model/task_out"])
# Ensure the output directory exists whether or not it was symlinked.
subprocess.run(["mkdir", "-p", "/model/task_out"])

os.chdir('/code')

# Generate the eye-texture data used by the FLAME texture model below.
from generate_eyes_tex import generate_datasets as generate_eyes_tex
generate_eyes_tex()

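# The imports below run only after chdir('/code'): sys.path.append uses a
# relative path, so FLAME, Renderer, and util resolve inside the project tree.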
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F

from PIL import Image

sys.path.append('./models/')
from FLAME import FLAME, FLAMETex
from renderer import Renderer
import util

torch.backends.cudnn.benchmark = True

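# Given fitted FLAME parameters for one image, renders the mesh and exports
# two PNGs: a UV-layout normal map and a packed sampling grid.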
class ExtractNormalAndGrid(object):
    def __init__(self, config, device='cuda'):
        self.batch_size = config.batch_size
        self.image_size = config.image_size
        self.config = config
        self.device = device

        # FLAME shape/expression model and its texture model
        self.flame = FLAME(self.config).to(self.device)
        self.flametex = FLAMETex(self.config).to(self.device)

        self._setup_renderer()

    def _setup_renderer(self):
        mesh_file = './data/head_template_mesh.obj'
        # image_size (224) sets the resolution of the rendered maps.
        self.render = Renderer(self.image_size, obj_filename=mesh_file).to(self.device)
    def extract(self, params, savefolder=None):
        bz = 1  # one sample per call
        shape = nn.Parameter(torch.from_numpy(params['shape']).float().to(self.device))
        tex = nn.Parameter(torch.zeros(bz, self.config.tex_params).float().to(self.device))
        exp = nn.Parameter(torch.from_numpy(params['exp']).float().to(self.device))
        pose = nn.Parameter(torch.from_numpy(params['pose']).float().to(self.device))
        eye_tex = nn.Parameter(torch.zeros(bz, self.flametex.eyes_basis.shape[-1]).float().to(self.device))
        eye_pose = nn.Parameter(torch.zeros(bz, self.config.eye_pose_params).float().to(self.device))
        cam = nn.Parameter(torch.from_numpy(params['cam']).float().to(self.device))
        # Lighting coefficients are zeroed; the outputs depend on geometry only.
        lights = nn.Parameter(torch.zeros(bz, 9, 3).float().to(self.device))

        try:
            with torch.no_grad():
                # Keep eye rotations within [-0.5, 0.5] rad.
                eye_pose_limit = torch.tanh(eye_pose) * 0.5
                vertices, landmarks2d, landmarks3d = self.flame(
                    shape_params=shape, expression_params=exp,
                    pose_params=pose, eye_pose_params=eye_pose_limit)
                trans_vertices = util.batch_orth_proj(vertices, cam)
                trans_vertices[..., 1:] = -trans_vertices[..., 1:]
                tex_full = torch.cat([tex, eye_tex], -1)
                albedos = self.flametex(tex_full) / 255.
                ops = self.render(vertices, trans_vertices, albedos, lights)
                uv_pverts = self.render.world2uv(trans_vertices)
                # Resample the rendered normals at the projected vertex
                # positions to obtain a UV-layout normal map.
                normal_images = ops['normal_images']
                normal_images = F.grid_sample(
                    normal_images, uv_pverts.permute(0, 2, 3, 1)[:, :, :, :2],
                    mode='bilinear', align_corners=False)
                normal_image = normal_images[0].cpu().numpy().transpose(1, 2, 0).copy() * 255

                # Pack the [-1, 1] sampling grid into 8-bit RGB: R and G hold the
                # integer parts of x and y, while B stores their fractional parts
                # as two 4-bit nibbles for extra precision.
                grid = ops['grid']
                grid = grid.permute(0, 3, 1, 2).clone()
                grid_r = (grid[:, :1] + 1) / 2 * 255
                grid_g = (grid[:, 1:2] + 1) / 2 * 255
                grid_b = ((grid_r - grid_r.floor()) * 15).floor() * 16 + ((grid_g - grid_g.floor()) * 15).floor()
                grid_images = torch.cat([grid_r, grid_g, grid_b], 1).floor() / 255.0
                grid_image = grid_images[0].cpu().numpy().transpose(1, 2, 0).copy() * 255

                save_dict = dict(normal=normal_image, grid=grid_image)
                for name, value in save_dict.items():
                    value = np.clip(value, 0, 255).astype(np.uint8)
                    Image.fromarray(value).save('{}_{}.png'.format(savefolder, name))
        except Exception:
            print(traceback.format_exc())
            sys.exit(1)

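    # Illustrative helper, not part of the original pipeline: a sketch that
    # inverts the packing above to recover an approximate [-1, 1] grid from
    # a saved grid PNG. The high nibble of B carries the fractional part of
    # R, the low nibble that of G, each quantized in 1/15 steps.
    @staticmethod
    def decode_grid_image(img):
        img = img.astype(np.float32)
        x = img[..., 0] + (img[..., 2] // 16) / 15.0
        y = img[..., 1] + (img[..., 2] % 16) / 15.0
        return np.stack([x, y], -1) / 255.0 * 2.0 - 1.0
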
    def run(self, posepath):
        params = np.load(posepath, allow_pickle=True).item()
        # Output files are prefixed with the pose file's basename (sample ID).
        self.extract(params, self.config.savefolder + '/' + os.path.basename(posepath)[:-4])


if __name__ == '__main__':
    device_name = 'cuda'
    use_preshape = True
    config = {
        # FLAME
        'flame_model_path': '/dataset/FLAME2020/generic_model.pkl',  # acquire it from the FLAME project page
        'flame_lmk_embedding_path': './data/landmark_embedding.npy',
        'flame_albedo_tex_space_path': '/dataset/albedoModel2020_FLAME_albedoPart/albedoModel2020_FLAME_albedoPart.npz',  # download it from the AlbedoMM release page
        'flame_eyes_tex': './data/head_eyes.png',
        'head_eyes_path': './data/head_eyes.npy',
        'camera_params': 3,
        'shape_params': 300 if use_preshape else 100,
        'expression_params': 100,
        'pose_params': 6,
        'eye_pose_params': 6,
        'tex_params': 145,
        'use_face_contour': True,

        'cropped_size': 256,
        'batch_size': 1,
        'image_size': 224,
        'e_lr': 0.005,
        'e_wd': 0.0001,
        'savefolder': '/model/task_out/ffhq_normal_and_grid',
        # weights of losses and regularization terms (unused in this
        # extraction script; kept for parity with the fitting config)
        'w_pho': 8,
        'w_lmks': 1,
        'w_shape_reg': 1e-4,
        'w_expr_reg': 1e-4,
        'w_pose_reg': 0,
    }

    config = util.dict2obj(config)
    config.batch_size = 1
    fitting = ExtractNormalAndGrid(config, device=device_name)

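    # Images and pose files are paired by sorted order alone; this assumes
    # both globs enumerate the same sample IDs.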
    ffhq_paths = sorted(glob.glob('/dataset/images256x256/**/*.png', recursive=True))
    ffhq_pose_paths = sorted(glob.glob('/dataset/ffhq_flame_pose_00000_to_69999/**/*.npy', recursive=True))
    if args.num == -1:
        args.num = len(ffhq_paths)
    # Take `num` samples starting at `start_idx`.
    ffhq_paths = ffhq_paths[args.start_idx:args.start_idx + args.num]
    ffhq_pose_paths = ffhq_pose_paths[args.start_idx:args.start_idx + args.num]
    print('Dataset size:', len(ffhq_paths))

    # Tag the output folder with the first and last sample IDs it covers.
    config.savefolder += "_" + \
        os.path.basename(ffhq_paths[0])[:-4] + "_to_" + \
        os.path.basename(ffhq_paths[-1])[:-4]
    save_folder_base = os.path.basename(config.savefolder)

    util.check_mkdir(config.savefolder)

    for posepath in ffhq_pose_paths:
        fitting.run(posepath)

    # Bundle the results for export.
    os.chdir('/model/task_out')
    print(f"Running `tar cfz {save_folder_base}.tar.gz {save_folder_base}/*`")
    os.system(f'tar cfz {save_folder_base}.tar.gz {save_folder_base}/*')
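# Example invocation (hypothetical script name; paths assume the mounts
# described at the top):
#   python extract_normal_and_grid.py --start_idx 0 --num 1000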