|
- import argparse
-
- import numpy as np
- import os
-
- import tensorflow as tf
- from AnimeGANv2.net import generator as tf_generator
-
- import torch
- from model import Generator
-
-
def load_tf_weights(tf_path):
    """Restore the TensorFlow generator checkpoint and return its weights.

    Builds the generator graph under the ``"generator"`` variable scope,
    restores the checkpoint recorded in the checkpoint state of ``tf_path``
    (in a session configured with zero GPU devices), and evaluates every
    trainable variable.

    Args:
        tf_path: Directory handed to ``tf.train.get_checkpoint_state``.

    Returns:
        A dict mapping each trainable variable's ``name`` to the value
        returned by its ``eval()`` (presumably NumPy arrays, since the caller
        feeds them to ``torch.from_numpy`` -- confirm against the TF version
        in use).

    Raises:
        FileNotFoundError: If no checkpoint state, or no checkpoint path, is
            found for ``tf_path``.
    """
    test_real = tf.placeholder(tf.float32, [1, None, None, 3], name='test')
    with tf.variable_scope("generator", reuse=False):
        # The output tensor itself is never used; the call is kept because
        # it creates the generator variables that the Saver restores below.
        _ = tf_generator.G_net(test_real).fake

    saver = tf.train.Saver()

    session_config = tf.ConfigProto(allow_soft_placement=True, device_count={'GPU': 0})
    with tf.Session(config=session_config) as sess:
        ckpt = tf.train.get_checkpoint_state(tf_path)

        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if ckpt is None or ckpt.model_checkpoint_path is None:
            raise FileNotFoundError(f"Failed to load checkpoint {tf_path}")

        saver.restore(sess, ckpt.model_checkpoint_path)
        print(f"Tensorflow model checkpoint {ckpt.model_checkpoint_path} loaded")

        # ``eval()`` relies on the default session, so this must stay inside
        # the ``with tf.Session(...)`` block.
        tf_weights = {v.name: v.eval() for v in tf.trainable_variables()}

    return tf_weights
-
-
def convert_keys(k):
    """Translate a TensorFlow variable name into a PyTorch ``state_dict`` key.

    The name is split on ``"/"`` and its first two components are dropped.
    What remains is treated as three parts -- block, layer index and
    parameter kind -- each of which is rewritten and then re-joined with
    ``"."``.  Names under block ``"C"`` may carry a fourth component (an
    ``r<N>`` sub-block, presumably a residual block), which is folded into the
    layer part as ``"<N>.layers.<idx>"``.

    Args:
        k: TensorFlow variable name, e.g.
            ``"generator/G_MODEL/A/Conv/weights:0"``.

    Returns:
        A ``(torch_key, is_dconv)`` tuple.  ``is_dconv`` is True only for
        four-part ``"C"`` names whose sub-block and layer components are
        identical (e.g. ``C/r1/r1/w:0``); the caller uses it to pick a
        different axis permutation (presumably for depthwise convolutions).

    Raises:
        ValueError: If the name does not reduce to exactly three parts.
        KeyError: If the parameter suffix (e.g. ``"weights:0"``) is unknown.
    """
    # 1. divide tf weight name in three parts [block_idx, layer_idx, weight/bias]
    # 2. handle each part & merge into a pytorch model keys

    # Give unnumbered layers an explicit "_0" index so every layer is numbered.
    normalized = k.replace("Conv/", "Conv_0/").replace("LayerNorm/", "LayerNorm_0/")
    keys = normalized.split("/")[2:]

    # Explicit raises instead of ``assert`` so validation survives ``python -O``
    # and short names fail with a clear message rather than an IndexError.
    if len(keys) < 3:
        raise ValueError(f"unexpected TensorFlow weight name: {k}")

    is_dconv = False

    # handle C block..
    if keys[0] == "C":
        # The second top-level Conv/LayerNorm of block C maps to index 5.
        if keys[1] in ("Conv_1", "LayerNorm_1"):
            keys[1] = keys[1].replace("1", "5")

        if len(keys) == 4:
            if "r" not in keys[1]:
                raise ValueError(f"unexpected TensorFlow weight name: {k}")

            if keys[1] == keys[2]:
                is_dconv = True
                keys[2] = "1.1"

            block_c_maps = {
                "1": "1.2",
                "Conv_1": "2",
                "2": "3",
            }
            keys[2] = block_c_maps.get(keys[2], keys[2])

            # Fold the sub-block number and inner layer into one layer part,
            # then shift the parameter suffix into the third slot.
            keys[1] = keys[1].replace("r", "") + ".layers." + keys[2]
            keys[2:] = [keys[3]]

    if len(keys) != 3:
        raise ValueError(f"unexpected TensorFlow weight name: {k}")

    # handle output block
    if "out" in keys[0]:
        keys[1] = "0"

    # first part
    if keys[0] in ("A", "B", "C", "D", "E"):
        keys[0] = "block_" + keys[0].lower()

    # second part
    if "LayerNorm_" in keys[1]:
        keys[1] = keys[1].replace("LayerNorm_", "") + ".2"
    if "Conv_" in keys[1]:
        keys[1] = keys[1].replace("Conv_", "") + ".1"

    # third part
    keys[2] = {
        "weights:0": "weight",
        "w:0": "weight",
        "bias:0": "bias",
        "gamma:0": "weight",
        "beta:0": "bias",
    }[keys[2]]

    return ".".join(keys), is_dconv
-
-
def convert_and_save(tf_checkpoint_path, save_name):
    """Convert a TensorFlow generator checkpoint into a PyTorch state dict file.

    Loads the TensorFlow weights, renames each one with ``convert_keys``,
    permutes 4-D tensors into the layout of the matching ``Generator``
    parameter, checks names and shapes against ``Generator().state_dict()``,
    and saves the resulting state dict with ``torch.save``.

    Args:
        tf_checkpoint_path: Checkpoint directory passed to ``load_tf_weights``.
        save_name: Destination path passed to ``torch.save``.

    Raises:
        ValueError: If a converted name is not a key of the PyTorch model, if
            a converted tensor's shape differs from the model's, or if some
            model weights have no TensorFlow counterpart.
    """
    tf_weights = load_tf_weights(tf_checkpoint_path)

    torch_net = Generator()
    torch_weights = torch_net.state_dict()

    torch_converted_weights = {}
    for k, v in tf_weights.items():
        torch_k, is_dconv = convert_keys(k)
        # Explicit raises instead of ``assert`` so the checks survive ``python -O``.
        if torch_k not in torch_weights:
            raise ValueError(f"weight name mismatch: {k}")

        converted_weight = torch.from_numpy(v)
        if converted_weight.dim() == 4:
            # Presumably TF kernels are (H, W, in, out) and PyTorch expects
            # (out, in, H, W); for the ``is_dconv`` case the last two TF axes
            # are kept in order instead -- the shape check below guards both.
            if is_dconv:
                converted_weight = converted_weight.permute(2, 3, 0, 1)
            else:
                converted_weight = converted_weight.permute(3, 2, 0, 1)

        expected_shape = torch_weights[torch_k].shape
        if expected_shape != converted_weight.shape:
            raise ValueError(
                f"shape mismatch: {k} (expected {tuple(expected_shape)}, "
                f"got {tuple(converted_weight.shape)})"
            )

        torch_converted_weights[torch_k] = converted_weight

    # Every converted key was verified to exist in the model above, so the two
    # key sets can only differ by model weights that were never produced.
    missing = sorted(set(torch_weights) - set(torch_converted_weights))
    if missing:
        raise ValueError(f"some weights are missing: {missing}")

    torch_net.load_state_dict(torch_converted_weights)
    torch.save(torch_net.state_dict(), save_name)
    print(f"PyTorch model saved at {save_name}")
-
-
if __name__ == '__main__':
    # Script entry point: read the checkpoint directory and the output file
    # name from the command line (both optional) and run the conversion.
    string_options = {
        '--tf_checkpoint_path': 'AnimeGANv2/checkpoint/generator_Paprika_weight',
        '--save_name': 'pytorch_generator_Paprika.pt',
    }

    cli = argparse.ArgumentParser()
    for flag, fallback in string_options.items():
        cli.add_argument(flag, type=str, default=fallback)
    options = cli.parse_args()

    convert_and_save(options.tf_checkpoint_path, options.save_name)
|