|
- import os
- from src.mol_tree import Vocab
- from src.jtnn_vae import JTNNVAE
- from src.model_utils.config import config
-
- import random
- import rdkit
- import rdkit.Chem as Chem
- from tqdm import tqdm
- from mindspore import context
- from mindspore.common import set_seed
- from mindspore.train.serialization import load_checkpoint, load_param_into_net
-
-
### Copy single dataset from obs to inference image ###
def ObsToEnv(obs_data_url, data_dir):
    """Copy a dataset folder from OBS into the local inference image.

    The transfer is best-effort: any exception raised by the copy is
    printed and swallowed, so the caller is never interrupted.

    ``mox`` is looked up as a module global at call time; the script entry
    point imports it only when ``config.enable_modelarts`` is set.

    Args:
        obs_data_url: source location handed to ``mox.file.copy_parallel``.
        data_dir: destination directory handed to ``mox.file.copy_parallel``.
    """
    try:
        mox.file.copy_parallel(obs_data_url, data_dir)
        success_msg = "Successfully Download {} to {}".format(obs_data_url, data_dir)
        print(success_msg)
    except Exception as err:
        failure_prefix = 'moxing download {} to {} failed: '.format(obs_data_url, data_dir)
        print(failure_prefix + str(err))
-
### Copy ckpt file from obs to inference image ###
### Folders are copied with mox.file.copy_parallel; a single file, as here,
### is copied with mox.file.copy.
def ObsUrlToEnv(obs_ckpt_url, ckpt_url):
    """Copy one checkpoint file from OBS into the local inference image.

    The transfer is best-effort: any exception raised by the copy is
    printed and swallowed, so the caller is never interrupted.

    ``mox`` is looked up as a module global at call time; the script entry
    point imports it only when ``config.enable_modelarts`` is set.

    Args:
        obs_ckpt_url: source file location handed to ``mox.file.copy``.
        ckpt_url: destination file path handed to ``mox.file.copy``.
    """
    try:
        mox.file.copy(obs_ckpt_url, ckpt_url)
        success_msg = "Successfully Download {} to {}".format(obs_ckpt_url, ckpt_url)
        print(success_msg)
    except Exception as err:
        failure_prefix = 'moxing download {} to {} failed: '.format(obs_ckpt_url, ckpt_url)
        print(failure_prefix + str(err))
-
### Copy the output result to obs ###
def EnvToObs(train_dir, obs_train_url):
    """Upload a local result folder to OBS.

    The transfer is best-effort: any exception raised by the copy is
    printed and swallowed, so the caller is never interrupted.

    ``mox`` is looked up as a module global at call time; the script entry
    point imports it only when ``config.enable_modelarts`` is set.

    Args:
        train_dir: local source directory handed to ``mox.file.copy_parallel``.
        obs_train_url: destination location handed to ``mox.file.copy_parallel``.
    """
    try:
        mox.file.copy_parallel(train_dir, obs_train_url)
        success_msg = "Successfully Upload {} to {}".format(train_dir, obs_train_url)
        print(success_msg)
    except Exception as err:
        failure_prefix = 'moxing upload {} to {} failed: '.format(train_dir, obs_train_url)
        print(failure_prefix + str(err))
-
def eval():
    """Measure how often the trained JTNNVAE reconstructs its input exactly.

    Reads the vocabulary from ``zinc/vocab.txt`` and the test molecules from
    ``zinc/test.txt`` (first whitespace-separated token of each line), both
    under ``config.raw_data_dir``. Restores the weights stored at
    ``config.ckpt_path``, then counts the molecules for which
    ``model.reconstruct`` returns the same isomeric SMILES it was given.
    The resulting fraction is written to ``config.acclog_path`` and printed.

    Raises:
        ValueError: if the test file contains no SMILES entries.
    """
    # NOTE: this name shadows the builtin ``eval``; it is kept unchanged so
    # existing callers of this module keep working.
    vocab_path = os.path.join(config.raw_data_dir, "zinc/vocab.txt")
    with open(vocab_path) as vocab_file:
        vocab = Vocab([line.strip("\r\n ") for line in vocab_file])

    model = JTNNVAE(vocab, config.hidden_size, config.latent_size, config.depth, stereo=True)
    param_dict = load_checkpoint(config.ckpt_path)
    load_param_into_net(model, param_dict)
    model.set_train(False)

    test_path = os.path.join(config.raw_data_dir, "zinc/test.txt")
    data = []
    with open(test_path) as test_file:
        for line in test_file:
            fields = line.split()
            # Blank lines (e.g. a trailing newline at the end of the file)
            # carry no molecule and must not abort the whole evaluation.
            if fields:
                data.append(fields[0])

    if not data:
        raise ValueError(f'no SMILES entries found in {test_path}')

    matches = 0
    for smiles in tqdm(data, total=len(data)):
        mol = Chem.MolFromSmiles(smiles)
        # Compare against the re-serialised isomeric SMILES rather than the
        # raw text from the file.
        smiles3D = Chem.MolToSmiles(mol, isomericSmiles=True)
        if model.reconstruct(smiles3D) == smiles3D:
            matches += 1

    acc = matches / len(data)
    report = f'reconstruction accuracy: {acc}'
    with open(config.acclog_path, "w") as acc_file:
        acc_file.write(report)
    print(report)
-
-
-
if __name__ == "__main__":
    # Only let critical RDKit messages through.
    lg = rdkit.RDLogger.logger()
    lg.setLevel(rdkit.RDLogger.CRITICAL)

    # Fix both the MindSpore and the Python RNG seeds.
    set_seed(1)
    random.seed(1)
    context.set_context(mode=context.PYNATIVE_MODE)
    device_id = int(os.getenv('DEVICE_ID', '0'))
    context.set_context(device_target=config.device_target, device_id=device_id)

    if config.enable_modelarts:
        # Must stay a module-level binding: ObsToEnv, ObsUrlToEnv and EnvToObs
        # resolve ``mox`` as a global when they are called.
        import moxing as mox

        # Initialize the data and result directories in the inference image.
        # exist_ok avoids the check-then-create race when several processes
        # start at the same time.
        os.makedirs(config.data_dir, exist_ok=True)
        os.makedirs(config.result_dir, exist_ok=True)

        # Copy dataset from obs to inference image.
        ObsToEnv(config.data_url, config.data_dir)
        # Copy ckpt file from obs to inference image.
        ObsUrlToEnv(config.ckpt_url, config.ckpt_dir)

        # Point the evaluation at the local copies and place the accuracy
        # log inside the result directory that is uploaded afterwards.
        config.raw_data_dir = config.data_dir
        config.ckpt_path = config.ckpt_dir
        config.acclog_path = os.path.join(config.result_dir, config.acclog_path)

    eval()

    # Copy result data from the local running environment back to obs,
    # and download it in the inference task corresponding to the Qizhi platform.
    if config.enable_modelarts:
        EnvToObs(config.result_dir, config.result_url)
|