|
- import os
- import argparse
- import collections
- import torch
- from mindspore import Tensor
- from mindspore.train.serialization import save_checkpoint
-
-
def build_params_map(attention_num=12):
    """Build the ordered parameter-name map from the PyTorch model to MindSpore.

    Args:
        attention_num (int): number of BERT encoder layers to map. Default 12.

    Returns:
        collections.OrderedDict: pytorch parameter name -> mindspore parameter
        name, in the same insertion order the conversion expects.
    """
    weight_map = collections.OrderedDict()

    # --- BERT embeddings -------------------------------------------------
    ms_post = 'base_encoder.bert.bert_embedding_postprocessor.'
    weight_map['bert.embeddings.word_embeddings.weight'] = \
        'base_encoder.bert.bert_embedding_lookup.embedding_table'
    weight_map['bert.embeddings.position_embeddings.weight'] = \
        ms_post + 'full_position_embedding.embedding_table'
    weight_map['bert.embeddings.token_type_embeddings.weight'] = \
        ms_post + 'token_type_embedding.embedding_table'
    weight_map['bert.embeddings.LayerNorm.weight'] = ms_post + 'layernorm.gamma'
    weight_map['bert.embeddings.LayerNorm.bias'] = ms_post + 'layernorm.beta'

    # --- BERT encoder layers ---------------------------------------------
    # (pytorch suffix, mindspore suffix) for one layer, in original order.
    layer_pairs = (
        ('attention.self.query.weight', 'attention.attention.query_layer.weight'),
        ('attention.self.query.bias', 'attention.attention.query_layer.bias'),
        ('attention.self.key.weight', 'attention.attention.key_layer.weight'),
        ('attention.self.key.bias', 'attention.attention.key_layer.bias'),
        ('attention.self.value.weight', 'attention.attention.value_layer.weight'),
        ('attention.self.value.bias', 'attention.attention.value_layer.bias'),
        ('attention.output.dense.weight', 'attention.output.dense.weight'),
        ('attention.output.dense.bias', 'attention.output.dense.bias'),
        ('attention.output.LayerNorm.weight', 'attention.output.layernorm.gamma'),
        ('attention.output.LayerNorm.bias', 'attention.output.layernorm.beta'),
        ('intermediate.dense.weight', 'intermediate.weight'),
        ('intermediate.dense.bias', 'intermediate.bias'),
        ('output.dense.weight', 'output.dense.weight'),
        ('output.dense.bias', 'output.dense.bias'),
        ('output.LayerNorm.weight', 'output.layernorm.gamma'),
        ('output.LayerNorm.bias', 'output.layernorm.beta'),
    )
    for i in range(attention_num):
        for pt_suffix, ms_suffix in layer_pairs:
            weight_map[f'bert.encoder.layer.{i}.{pt_suffix}'] = \
                f'base_encoder.bert.bert_encoder.layers.{i}.{ms_suffix}'

    # --- BERT pooler ------------------------------------------------------
    weight_map['bert.pooler.dense.weight'] = 'base_encoder.bert.dense.weight'
    weight_map['bert.pooler.dense.bias'] = 'base_encoder.bert.dense.bias'

    # --- transition model -------------------------------------------------
    # nn.Embedding weight becomes an embedding_table in mindspore.
    weight_map['position_embedding.weight'] = 'trans_model.position_embedding.embedding_table'
    weight_map['action_embedding.weight'] = 'trans_model.action_embedding.embedding_table'

    # LSTM cell parameters keep their pytorch names verbatim under trans_model.
    lstm_params = ('weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0')
    reverse_params = tuple(p + '_reverse' for p in lstm_params)
    for cell in ('stack_cell', 'buffer_cell'):   # bidirectional cells
        for suffix in lstm_params + reverse_params:
            weight_map[f'{cell}.{suffix}'] = f'trans_model.{cell}.{suffix}'
    for suffix in lstm_params:                    # action_cell: forward only
        weight_map[f'action_cell.{suffix}'] = f'trans_model.action_cell.{suffix}'

    # MLP stacks: Linear layers map 1:1, BatchNorm parameters are renamed.
    bn_rename = (
        ('weight', 'gamma'),
        ('bias', 'beta'),
        ('running_mean', 'moving_mean'),
        ('running_var', 'moving_variance'),
    )
    for mlp in ('single_MLP', 'tuple_MLP'):
        for linear_idx, bn_idx in ((0, 1), (4, 5)):
            weight_map[f'{mlp}.{linear_idx}.weight'] = f'trans_model.{mlp}.{linear_idx}.weight'
            weight_map[f'{mlp}.{linear_idx}.bias'] = f'trans_model.{mlp}.{linear_idx}.bias'
            for pt_name, ms_name in bn_rename:
                weight_map[f'{mlp}.{bn_idx}.{pt_name}'] = f'trans_model.{mlp}.{bn_idx}.{ms_name}'
        weight_map[f'{mlp}.8.weight'] = f'trans_model.{mlp}.8.weight'
        weight_map[f'{mlp}.8.bias'] = f'trans_model.{mlp}.8.bias'

    return weight_map
-
-
def extract_and_convert(
        input_dir,
        output_dir,
        bert_ckpt='with_reversal_bertmodel_pre_0.7438423645320197_rec_0.7023255813953488_f1_0.7224880382775121.mdl',
        trans_ckpt='with_reversal_transmodel_pre_0.7438423645320197_rec_0.7023255813953488_f1_0.7224880382775121.mdl'):
    """Extract PyTorch weights and save them as a MindSpore checkpoint.

    Loads the BERT-model and transition-model state dicts from ``input_dir``,
    renames every parameter found in the mapping from ``build_params_map`` and
    writes the result to ``output_dir/mindspore_net.ckpt``. Parameters absent
    from the map are skipped silently.

    Args:
        input_dir (str): directory containing the two PyTorch ``.mdl`` files.
        output_dir (str): destination directory (created if missing).
        bert_ckpt (str): file name of the BERT-model state dict.
        trans_ckpt (str): file name of the transition-model state dict.
    """
    os.makedirs(output_dir, exist_ok=True)
    print('=' * 20 + 'extract weights' + '=' * 20)
    weight_map = build_params_map()
    state_dict = []
    # Both checkpoints are converted by the same renaming procedure.
    for ckpt_name in (bert_ckpt, trans_ckpt):
        params = torch.load(os.path.join(input_dir, ckpt_name),
                            map_location=torch.device('cpu'))
        for weight_name, weight_value in params.items():
            if weight_name not in weight_map:
                continue
            # Bug fix: the original used `mindspore.Tensor`, but only `Tensor`
            # is imported (from mindspore import Tensor) -> NameError.
            state_dict.append({'name': weight_map[weight_name],
                               'data': Tensor(weight_value.numpy())})
            print(weight_name, '->', weight_map[weight_name], weight_value.shape)
    save_checkpoint(state_dict, os.path.join(output_dir, "mindspore_net.ckpt"))
-
-
def run_convert():
    """Parse command-line options and run the checkpoint conversion."""
    parser = argparse.ArgumentParser(description="run convert")
    # flag -> (default value, help text)
    cli_options = {
        "--input_dir": ("./Data/best_pytorch_ckpt", "pytorch ckpt dir"),
        "--output_dir": ("./Data/best_mindspore_ckpt", "Converted model dir"),
    }
    for flag, (default, help_text) in cli_options.items():
        parser.add_argument(flag, type=str, default=default, help=help_text)
    parsed = parser.parse_args()
    extract_and_convert(parsed.input_dir, parsed.output_dir)
-
-
# Script entry point: convert the default PyTorch checkpoints to MindSpore.
if __name__ == '__main__':
    run_convert()
|