|
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """run script"""
-
- import time
- import os
- import json
- import argparse
- import numpy as np
-
- import mindspore.context as context
- from mindspore.common.tensor import Tensor
- from mindspore import load_checkpoint
-
- from data.feature.feature_extraction import process_features
- from data.tools.data_process import data_process
- from commons.generate_pdb import to_pdb, from_prediction
- from commons.utils import compute_confidence
- from model import AlphaFold
- from config import config, global_config
-
parser = argparse.ArgumentParser(description='Inputs for run.py')

# All string-valued CLI options, registered in a single pass.
# (flag, help text) — order matters only for --help output and matches the
# original declaration order.
_STRING_OPTIONS = (
    ('--seq_length', 'padding sequence length'),
    ('--input_fasta_path', 'Path of FASTA files folder directory to be predicted.'),
    ('--msa_result_path', 'Path to save msa result.'),
    ('--database_dir', 'Path of data to generate msa.'),
    ('--database_envdb_dir', 'Path of expandable data to generate msa.'),
    ('--hhsearch_binary_path', 'Path of hhsearch executable.'),
    ('--pdb70_database_path', 'Path to pdb70.'),
    ('--template_mmcif_dir', 'Path of template mmcif.'),
    ('--max_template_date', 'Maximum template release date.'),
    ('--kalign_binary_path', 'Path to kalign executable.'),
    ('--obsolete_pdbs_path', 'Path to obsolete pdbs path.'),
    ('--checkpoint_path', 'Path of the checkpoint.'),
)
for _flag, _help in _STRING_OPTIONS:
    parser.add_argument(_flag, help=_help)

# The only typed option; defaults to device 0 when not given.
parser.add_argument('--device_id', default=0, type=int, help='Device id to be used.')
args = parser.parse_args()
-
- if __name__ == "__main__":
- context.set_context(mode=context.GRAPH_MODE,
- device_target="Ascend",
- variable_memory_max_size="31GB",
- device_id=args.device_id,
- save_graphs=False)
- model_name = "model_1"
- model_config = config.model_config(model_name)
- num_recycle = model_config.model.num_recycle
- global_config = global_config.global_config(args.seq_length)
- extra_msa_length = global_config.extra_msa_length
- fold_net = AlphaFold(model_config, global_config)
-
- load_checkpoint(args.checkpoint_path, fold_net)
-
- seq_files = os.listdir(args.input_fasta_path)
-
- for seq_file in seq_files:
- t1 = time.time()
- seq_name = seq_file.split('.')[0]
- input_features = data_process(seq_name, args)
- tensors, aatype, residue_index, ori_res_length = process_features(
- raw_features=input_features, config=model_config, global_config=global_config)
- prev_pos = Tensor(np.zeros([global_config.seq_length, 37, 3]).astype(np.float16))
- prev_msa_first_row = Tensor(np.zeros([global_config.seq_length, 256]).astype(np.float16))
- prev_pair = Tensor(np.zeros([global_config.seq_length, global_config.seq_length, 128]).astype(np.float16))
- """
- :param::@sequence_length
- """
- t2 = time.time()
- for i in range(num_recycle+1):
- tensors_i = [tensor[i] for tensor in tensors]
- input_feats = [Tensor(tensor) for tensor in tensors_i]
- final_atom_positions, final_atom_mask, predicted_lddt_logits,\
- prev_pos, prev_msa_first_row, prev_pair = fold_net(*input_feats,
- prev_pos,
- prev_msa_first_row,
- prev_pair)
-
- t3 = time.time()
-
- final_atom_positions = final_atom_positions.asnumpy()[:ori_res_length]
- final_atom_mask = final_atom_mask.asnumpy()[:ori_res_length]
- predicted_lddt_logits = predicted_lddt_logits.asnumpy()[:ori_res_length]
-
- confidence = compute_confidence(predicted_lddt_logits)
- unrelaxed_protein = from_prediction(final_atom_mask, aatype[0], final_atom_positions, residue_index[0])
- pdb_file = to_pdb(unrelaxed_protein)
-
- seq_length = aatype.shape[-1]
- os.makedirs(f'./result/seq_{seq_name}_{seq_length}', exist_ok=True)
-
- with open(os.path.join(f'./result/seq_{seq_name}_{seq_length}/', f'unrelaxed_model_{seq_name}.pdb'), 'w') as f:
- f.write(pdb_file)
- t4 = time.time()
- timings = {"pre_process_time": round(t2 - t1, 2),
- "model_time": round(t3 - t2, 2),
- "pos_process_time": round(t4 - t3, 2),
- "all_time": round(t4 - t1, 2),
- "confidence": confidence}
- print(timings)
- with open(f'./result/seq_{seq_name}_{seq_length}/timings', 'w') as f:
- f.write(json.dumps(timings))
|