import argparse
import os.path


def main(args):
    import json, time, os, sys, glob
    import shutil
    import warnings
    import numpy as np
    import mindspore as ms
    import mindspore.ops as ops
    import mindspore.dataset as ds
    import queue
    import copy
    import pickle
    import mindspore.nn as nn
    import random
    import subprocess
    from concurrent.futures import ProcessPoolExecutor
    from utils import worker_init_fn, get_pdbs, loader_pdb, build_training_clusters, PDB_dataset, tied_featurize, \
        loss_smoothed, loss_nll, get_std_opt, featurize, lr_list
    from model_ascend_ import ProteinMPNN
    from datasets import StructureDataset, StructureLoader, Definebatch

    # ms.set_context configures the global backend and returns None, so there is
    # no device handle; `device` is kept only because featurize() still takes it
    # in the signature inherited from the PyTorch version.
    ms.set_context(device_target='GPU', device_id=args.device_id, mode=ms.GRAPH_MODE)
    device = None
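    # GRAPH_MODE compiles each Cell.construct into a static graph, so the two
    # training Cells below must stay graph-compatible (no Python side effects
    # inside construct).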

    class CustomTrainOneStepCell(nn.Cell):
        """Custom one-step training cell: forward pass, backward pass, and update."""

        def __init__(self, network_, optimizer_):
            """Two arguments: the network with loss attached, and the optimizer."""
            super(CustomTrainOneStepCell, self).__init__(auto_prefix=False)
            self.network = network_                          # forward network (backbone + loss)
            self.network.set_grad()                          # build the backward graph
            self.optimizer_ = optimizer_                     # optimizer that updates the weights
            self.weights = self.optimizer_.parameters        # parameters to be updated
            self.grad = ops.GradOperation(get_by_list=True)  # gradients w.r.t. the weight list

        def construct(self, *inputs):
            loss_ = self.network(*inputs)                            # loss on the current batch
            grads = self.grad(self.network, self.weights)(*inputs)   # backward pass: gradients of the loss
            self.optimizer_(grads)                                   # apply the update
            return loss_

    class CustomWithLossCell(nn.Cell):
        """Fuse the backbone with the loss so the wrapped cell returns a scalar."""

        def __init__(self, backbone, loss_fn):
            super(CustomWithLossCell, self).__init__(auto_prefix=False)
            self._backbone = backbone
            self._loss_fn = loss_fn

        def construct(self, X, S, mask, chain_M, residue_idx, chain_encoding_all, mask_for_loss):
            output = self._backbone(X, S, mask, chain_M, residue_idx, chain_encoding_all)
            loss_av = self._loss_fn(S, output, mask_for_loss)
            return loss_av
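    # These two cells mirror MindSpore's stock nn.WithLossCell and
    # nn.TrainOneStepCell; custom versions are used here because the loss takes
    # the extra mask_for_loss argument that the stock wrappers do not forward.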

    base_folder = time.strftime(args.path_for_outputs, time.localtime())

    if base_folder[-1] != '/':
        base_folder += '/'
    if not os.path.exists(base_folder):
        os.makedirs(base_folder)
    subfolders = ['model_weights']
    for subfolder in subfolders:
        if not os.path.exists(base_folder + subfolder):
            os.makedirs(base_folder + subfolder)

    PATH = args.previous_checkpoint

    logfile = base_folder + 'log.txt'
    if not PATH:
        with open(logfile, 'w') as f:
            f.write('Epoch\tTrain\tValidation\n')

    data_path = args.path_for_training_data
    params = {
        "LIST": f"{data_path}/list.csv",
        "VAL": f"{data_path}/valid_clusters.txt",
        "TEST": f"{data_path}/test_clusters.txt",
        "DIR": f"{data_path}",
        "DATCUT": "2030-Jan-01",
        "RESCUT": args.rescut,  # resolution cutoff for PDBs
        "HOMO": 0.70            # min seq. identity to detect homologous chains
    }

    LOAD_PARAM = {'batch_size': 1,
                  'shuffle': True,
                  'pin_memory': False,
                  'num_workers': 4}
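    # LOAD_PARAM keeps torch-DataLoader-style settings (pin_memory, num_workers)
    # for the commented-out cluster-loading pipeline below; the active code path
    # never reads it.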

    if args.debug:
        args.num_examples_per_epoch = 50
        args.max_protein_length = 1000
        args.batch_size = 1000

    # train, valid, test = build_training_clusters(params, args.debug)
    # with open("/home/zhaoyue/Huawei_BiologicalComputing/MindSpore_mpnn/ProteinMPNN_training/ProteinMPNN/pdb_2021aug02_sample/train_sample.pkl", 'wb') as f:
    #     pickle.dump(train, f)
    # with open("/home/zhaoyue/Huawei_BiologicalComputing/MindSpore_mpnn/ProteinMPNN_training/ProteinMPNN/pdb_2021aug02_sample/valid_sample.pkl", 'wb') as f:
    #     pickle.dump(valid, f)
    # with open("/home/zhaoyue/Huawei_BiologicalComputing/MindSpore_mpnn/ProteinMPNN_training/ProteinMPNN/pdb_2021aug02_sample/test_sample.pkl", 'wb') as f:
    #     pickle.dump(test, f)
    # with open("/home/zhaoyue/Huawei_BiologicalComputing/MindSpore_mpnn/ProteinMPNN_training/ProteinMPNN/pdb_2021aug02_sample/train_sample.pkl", 'rb') as f_read:
    #     train = pickle.load(f_read)
    # with open("/home/zhaoyue/Huawei_BiologicalComputing/MindSpore_mpnn/ProteinMPNN_training/ProteinMPNN/pdb_2021aug02_sample/valid_sample.pkl", 'rb') as f_read:
    #     valid = pickle.load(f_read)
    # with open("/home/zhaoyue/Huawei_BiologicalComputing/MindSpore_mpnn/ProteinMPNN_training/ProteinMPNN/pdb_2021aug02_sample/test_sample.pkl", 'rb') as f_read:
    #     test = pickle.load(f_read)

    # train_set = PDB_dataset(list(train.keys()), loader_pdb, train, params)
    # train_set = ds.GeneratorDataset(train_set, ['seq', 'xyz', 'idx', 'masked', 'label'], shuffle=True)
    # train_set = train_set.batch(batch_size=1)
    # train_loader = train_set.create_dict_iterator()
    # for data in train_loader:
    #     print("data: \n{}".format(data["seq"]))
    # valid_set = PDB_dataset(list(valid.keys()), loader_pdb, valid, params)
    # valid_set = ds.GeneratorDataset(valid_set, ['seq', 'xyz', 'idx', 'masked', 'label'], shuffle=True,
    #                                 num_parallel_workers=4)
    # valid_set = valid_set.batch(batch_size=1)
    # valid_loader = valid_set.create_dict_iterator()

    model = ProteinMPNN(num_letters=21,
                        node_features=args.hidden_dim,
                        edge_features=args.hidden_dim,
                        hidden_dim=args.hidden_dim,
                        num_encoder_layers=args.num_encoder_layers,
                        num_decoder_layers=args.num_decoder_layers,
                        k_neighbors=args.num_neighbors,
                        dropout=args.dropout,
                        augment_eps=args.backbone_noise)
    loss_fn = loss_smoothed()
    net_with_loss = CustomWithLossCell(model, loss_fn)
    # ms.load_checkpoint('/home/zhaoyue/Huawei_BiologicalComputing/MindSpore_mpnn/model_weights/v_48_020.ckpt', model)
    if PATH:
        # ms.load_checkpoint returns a flat {name: Parameter} dict; scalars
        # saved through append_dict come back as 0-d Parameters
        param_dict = ms.load_checkpoint(PATH)
        total_step = int(param_dict['step'].asnumpy())  # resume the global step
        epoch = int(param_dict['epoch'].asnumpy())      # resume the epoch count
        ms.load_param_into_net(model, param_dict)
    else:
        total_step = 0
        epoch = 0
    lr_schedule = lr_list(args.hidden_dim, 2, 4000)  # renamed so the imported lr_list class is not shadowed
    lr = lr_schedule.cal_lr(args.num_epochs)
    optimizer = nn.Adam(model.trainable_params(), learning_rate=ms.Tensor(lr), beta1=0.9, beta2=0.98, eps=1e-9)
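    # A 1-D Tensor passed as learning_rate makes MindSpore treat it as a
    # per-step dynamic schedule (step i uses lr[i]). lr_list is assumed to
    # mirror the NoamOpt schedule of the PyTorch ProteinMPNN trainer:
    #     lr(step) = factor * hidden_dim**-0.5 * min(step**-0.5, step * warmup**-1.5)
    # with factor=2 and warmup=4000 as constructed above.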
    if PATH:
        # optimizer moments live in the same checkpoint dict; load_param_into_net
        # restores the entries whose names match the optimizer's parameters
        ms.load_param_into_net(optimizer, param_dict)

    train_net = CustomTrainOneStepCell(net_with_loss, optimizer)
    # pdb_dict_train = get_pdbs(train_loader, 1, args.max_protein_length, args.num_examples_per_epoch)
    # pdb_dict_valid = get_pdbs(valid_loader, 1, args.max_protein_length, args.num_examples_per_epoch)
    # with open("/home/zhaoyue/Huawei_BiologicalComputing/MindSpore_mpnn/ProteinMPNN_training/ProteinMPNN/pdb_2021aug02_sample/pdb_dict_train.pkl", 'wb') as f:
    #     pickle.dump(pdb_dict_train, f)
    # with open("/home/zhaoyue/Huawei_BiologicalComputing/MindSpore_mpnn/ProteinMPNN_training/ProteinMPNN/pdb_2021aug02_sample/pdb_dict_valid.pkl", 'wb') as f:
    #     pickle.dump(pdb_dict_valid, f)
    with open("/home/huzhuping/mpnn_train1/mpnn_final/pdb_dict_train_sample_torch.pkl", 'rb') as f_read:
        pdb_dict_train = pickle.load(f_read)
    # with open("/home/zhaoyue/Huawei_BiologicalComputing/MindSpore_mpnn/ProteinMPNN_training/ProteinMPNN/pdb_2021aug02_sample/pdb_dict_valid.pkl", 'rb') as f_read:
    #     pdb_dict_valid = pickle.load(f_read)
    # with ProcessPoolExecutor(max_workers=12) as executor:
    #     q = queue.Queue(maxsize=3)
    #     p = queue.Queue(maxsize=3)
    #     for i in range(3):
    #         q.put_nowait(executor.submit(get_pdbs, train_loader, 1, args.max_protein_length, args.num_examples_per_epoch))
    #         p.put_nowait(executor.submit(get_pdbs, valid_loader, 1, args.max_protein_length, args.num_examples_per_epoch))
    #     pdb_dict_train = q.get().result()
    #     pdb_dict_valid = p.get().result()

    dataset_train = StructureDataset(pdb_dict_train, truncate=None, max_length=args.max_protein_length)
    # dataset_valid = StructureDataset(pdb_dict_valid, truncate=None, max_length=args.max_protein_length)

    loader_train = StructureLoader(dataset_train, batch_size=args.batch_size)
    # loader_valid = StructureLoader(dataset_valid, batch_size=args.batch_size)
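    # StructureLoader batches by token count: batch_size caps the total number
    # of residues per batch, so many short proteins or a few long ones are
    # packed into each step.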

    reload_c = 0
    for e in range(args.num_epochs):
        t0 = time.time()
        e = epoch + e  # offset by the epoch restored from a checkpoint, if any
        train_net.set_train()
        train_sum, train_weights = 0., 0.
        train_acc = 0.
        # if e % args.reload_data_every_n_epochs == 0:
        #     if reload_c != 0:
        #         pdb_dict_train = get_pdbs(train_loader, 1, args.max_protein_length, args.num_examples_per_epoch)
        #         dataset_train = StructureDataset(pdb_dict_train, truncate=None, max_length=args.max_protein_length)
        #         loader_train = StructureLoader(dataset_train, batch_size=args.batch_size)
        #         # pdb_dict_valid = get_pdbs(valid_loader, 1, args.max_protein_length, args.num_examples_per_epoch)
        #         # dataset_valid = StructureDataset(pdb_dict_valid, truncate=None, max_length=args.max_protein_length)
        #         # loader_valid = StructureLoader(dataset_valid, batch_size=args.batch_size)
        #     reload_c += 1
        for _, batch in enumerate(loader_train):
            start_batch = time.time()
            X, S, mask, lengths, chain_M, residue_idx, mask_self, chain_encoding_all = featurize(batch, device)
            elapsed_featurize = time.time() - start_batch
            mask_for_loss = mask * chain_M
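            # mask zeroes padded positions and chain_M marks the chains being
            # designed, so their product restricts the loss to designed residues.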
            # parity-debugging hooks against the PyTorch reference inputs:
            # X = ms.Tensor(np.load('X_torch.npy'))
            # S = ms.Tensor(np.load('S_torch.npy'))
            # mask = ms.Tensor(np.load('mask_torch.npy'))
            # chain_M = ms.Tensor(np.load('chain_M_torch.npy'))
            # residue_idx = ms.Tensor(np.load('residue_idx_torch.npy'))
            # chain_encoding_all = ms.Tensor(np.load('chain_encoding_all_torch.npy'))
            # mask_for_loss = ms.Tensor(np.load('mask_for_loss_torch.npy'))
            loss_av_smoothed = train_net(X, S, mask, chain_M, residue_idx, chain_encoding_all, mask_for_loss)
            # a second forward pass, for metrics only (NLL and accuracy on the
            # already-updated weights)
            log_probs = train_net.network._backbone(X, S, mask, chain_M, residue_idx, chain_encoding_all)
            loss, loss_av, true_false = loss_nll(S, log_probs, mask_for_loss)

            op = ms.ops.ReduceSum()
            train_sum += op(loss * mask_for_loss).asnumpy()
            train_acc += op(true_false * mask_for_loss).asnumpy()
            train_weights += op(mask_for_loss).asnumpy()
            total_step += 1

        train_net.set_train(False)

        # validation_sum, validation_weights = 0., 0.
        # validation_acc = 0.
        # for _, batch in enumerate(loader_valid):
        #     X, S, mask, lengths, chain_M, residue_idx, mask_self, chain_encoding_all = featurize(batch, device)
        #     log_probs = model(X, S, mask, chain_M, residue_idx, chain_encoding_all)
        #     mask_for_loss = mask * chain_M
        #     loss, loss_av, true_false = loss_nll(S, log_probs, mask_for_loss)
        #     op = ms.ops.ReduceSum(keep_dims=True)
        #     validation_sum += op(loss * mask_for_loss).asnumpy()
        #     validation_acc += op(true_false * mask_for_loss).asnumpy()
        #     validation_weights += op(mask_for_loss).asnumpy()

        train_loss = train_sum / train_weights
        train_accuracy = train_acc / train_weights
        train_perplexity = np.exp(train_loss)
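        # perplexity = exp(mean per-residue NLL); 1.0 is a perfect model and
        # ~21 is chance over the 21-letter amino-acid alphabet used here.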
        # validation_loss = validation_sum / validation_weights
        # validation_accuracy = validation_acc / validation_weights
        # validation_perplexity = np.exp(validation_loss)
        train_accuracy_ = np.format_float_positional(np.float32(train_accuracy), unique=False, precision=3)
        # validation_accuracy_ = np.format_float_positional(np.float32(validation_accuracy), unique=False, precision=3)

        t1 = time.time()
        dt = np.format_float_positional(np.float32(t1 - t0), unique=False, precision=1)
        # with open(logfile, 'a') as f:
        #     f.write(
        #         f'epoch: {e + 1}, step: {total_step}, time: {dt}, train: {train_perplexity}, valid: {validation_perplexity}, train_acc: {train_accuracy_}, valid_acc: {validation_accuracy_}\n')
        # print(
        #     f'epoch: {e + 1}, step: {total_step}, time: {dt}, train: {train_perplexity}, valid: {validation_perplexity}, train_acc: {train_accuracy_}, valid_acc: {validation_accuracy_}')
        with open(logfile, 'a') as f:
            f.write(
                f'epoch: {e + 1}, step: {total_step}, time: {dt}, train: {train_perplexity}, train_acc: {train_accuracy_}\n')
        print(
            f'epoch: {e + 1}, step: {total_step}, time: {dt}, train: {train_perplexity}, train_acc: {train_accuracy_}')
        checkpoint_filename_last = base_folder + 'model_weights/epoch_last.ckpt'
        # sketch of checkpointing with the MindSpore API (ms.save_checkpoint
        # takes a Cell plus append_dict for scalars, not a torch-style state
        # dict); saving train_net rather than model would also capture the
        # optimizer moments needed for exact resumption
        # ms.save_checkpoint(train_net, checkpoint_filename_last,
        #                    append_dict={'epoch': e + 1,
        #                                 'step': total_step,
        #                                 'num_edges': args.num_neighbors,
        #                                 'noise_level': args.backbone_noise})

        # if (e + 1) % args.save_model_every_n_epochs == 0:
        #     checkpoint_filename = base_folder + 'model_weights/epoch{}_step{}.ckpt'.format(e + 1, total_step)
        #     ms.save_checkpoint(train_net, checkpoint_filename,
        #                        append_dict={'epoch': e + 1,
        #                                     'step': total_step,
        #                                     'num_edges': args.num_neighbors,
        #                                     'noise_level': args.backbone_noise})
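        # the append_dict scalars sketched above are what the resume path at
        # startup reads back as param_dict['step'] and param_dict['epoch']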


if __name__ == "__main__":
    argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    argparser.add_argument("--path_for_training_data", type=str,
                           default="/home/huzhuping/mpnn_train1/mpnn_final/pdb_2021aug02_sample",
                           help="path for loading training data")
    argparser.add_argument("--path_for_outputs", type=str,
                           default="/home/huzhuping/mpnn_train1/mpnn_final/exp_020/",
                           help="path for logs and model weights")
    argparser.add_argument("--previous_checkpoint", type=str, default="",
                           help="path to previous model weights, e.g. file.ckpt")
    argparser.add_argument("--num_epochs", type=int, default=200, help="number of epochs to train for")
    argparser.add_argument("--save_model_every_n_epochs", type=int, default=10,
                           help="save model weights every n epochs")
    argparser.add_argument("--reload_data_every_n_epochs", type=int, default=2,
                           help="reload training data every n epochs")
    argparser.add_argument("--num_examples_per_epoch", type=int, default=1000000,
                           help="number of training examples to load for one epoch")
    argparser.add_argument("--batch_size", type=int, default=10000, help="number of tokens for one batch")
    argparser.add_argument("--max_protein_length", type=int, default=52,  # 10000
                           help="maximum length of the protein complex")
    argparser.add_argument("--hidden_dim", type=int, default=128, help="hidden model dimension")
    argparser.add_argument("--num_encoder_layers", type=int, default=3, help="number of encoder layers")
    argparser.add_argument("--num_decoder_layers", type=int, default=3, help="number of decoder layers")
    argparser.add_argument("--num_neighbors", type=int, default=48, help="number of neighbors for the sparse graph")
    argparser.add_argument("--dropout", type=float, default=0.1, help="dropout level; 0.0 means no dropout")
    argparser.add_argument("--backbone_noise", type=float, default=0.2,
                           help="amount of noise added to backbone during training")
    argparser.add_argument("--rescut", type=float, default=3.5, help="PDB resolution cutoff")
    # argparse's type=bool treats any non-empty string as True, so use store_true flags
    argparser.add_argument("--debug", action="store_true", help="minimal data loading for debugging")
    argparser.add_argument("--gradient_norm", type=float, default=-1.0,
                           help="clip gradient norm, set to negative to omit clipping")
    argparser.add_argument("--mixed_precision", action="store_true", help="train with mixed precision")
    argparser.add_argument('--device_id', help='device id', type=int, default=3)

    args = argparser.parse_args()
    main(args)