|
- import ddpm as ddpm
- import argparse
- import os
- from mindspore import context
- from mindspore.communication.management import init
- from mindspore.context import ParallelMode
- import time
- from upload import UploadOutput
- import moxing as mox
- import json
-
-
def parse_args(argv=None):
    """Parse the command-line options of the DDPM training script.

    Options that this parser does not define are skipped instead of causing
    an error, because ``parse_known_args`` is used.

    NOTE(review): the default of ``--multi_data_url`` is a local Windows
    path, while ``C2netMultiObsToEnv`` in this file feeds the value to
    ``json.loads``; the default is kept unchanged here -- confirm whether it
    is still wanted.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what the
            function did before this parameter existed.

    Returns:
        argparse.Namespace: with the attributes ``pretrain_url``,
        ``multi_data_url``, ``train_url``, ``steps``, ``save_every``,
        ``num_samples``, ``device_target`` and ``image_size``.
    """
    parser = argparse.ArgumentParser(
        description="train ddpm",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('--pretrain_url',
                        type=str,
                        default=None,
                        help='the pretrain model path')

    parser.add_argument('--multi_data_url',
                        type=str,
                        default="C:\\Users\\Administrator\\PycharmProjects\\DDPM\\datasets\\test",
                        help='training data file path')

    parser.add_argument('--train_url',
                        default='./results',
                        type=str,
                        help='the path model and fig save path')

    parser.add_argument('--steps',
                        default=20000,
                        type=int,
                        help='training steps')

    parser.add_argument('--save_every',
                        default=20000,
                        type=int,
                        help='save every')

    parser.add_argument('--num_samples',
                        default=4,
                        type=int,
                        help='num_samples must have a square root, like 4, 9, 16 ...')

    parser.add_argument('--device_target',
                        default="Ascend",
                        type=str,
                        help='device target')

    parser.add_argument('--image_size',
                        default=64,
                        type=int,
                        help='image size')

    # The second element is the list of unrecognised arguments; it is
    # deliberately discarded.
    args, _ = parser.parse_known_args(argv)
    return args
-
-
def C2netModelToEnv(model_url, model_dir):
    """Copy the pretrained checkpoint(s) named in ``model_url`` to ``model_dir``.

    Every entry is copied to the same target, ``<model_dir>/checkpoint.ckpt``,
    so when several entries are given a later successful copy overwrites an
    earlier one. A failed copy is reported on stdout and does not stop the
    remaining entries.

    Args:
        model_url: JSON text holding a list of objects that each carry a
            ``"model_url"`` key. ``None`` or an empty string means "no
            pretrained model"; nothing is downloaded in that case.
        model_dir: Local directory that receives ``checkpoint.ckpt``.

    Raises:
        json.JSONDecodeError: if ``model_url`` is non-empty but not valid JSON.
        KeyError: if an entry has no ``"model_url"`` key.
    """
    # The --pretrain_url option defaults to None, and json.loads(None) raises
    # TypeError, so an absent URL is treated as "nothing to download".
    if not model_url:
        print("pretrain model url is empty, skip model download")
        return

    # The url argument is JSON data and has to be parsed before use.
    model_url_json = json.loads(model_url)
    print("model_url_json:", model_url_json)

    # The destination does not depend on the entry, so build it once.
    modelfile_path = model_dir + "/" + "checkpoint.ckpt"
    for entry in model_url_json:
        source_url = entry["model_url"]
        try:
            mox.file.copy(source_url, modelfile_path)
            print("Successfully Download {} to {}".format(source_url, modelfile_path))
        except Exception as e:
            print('moxing download {} to {} failed: '.format(
                source_url, modelfile_path) + str(e))
-
-
def C2netMultiObsToEnv(multi_data_url, data_dir):
    """Download every dataset listed in ``multi_data_url`` and unzip it.

    For each entry the archive is copied to ``<data_dir>/<dataset_name>`` and
    extracted into ``<data_dir>/<dataset_name without extension>``. A failed
    download or extraction is reported on stdout and the loop carries on with
    the next entry.

    After the loop the marker file ``/cache/download_input.txt`` is created,
    whether or not every download succeeded. ``DownloadFromQizhi`` in this
    file polls for that file before it returns.

    Args:
        multi_data_url: JSON text holding a list of objects that each carry
            the keys ``"dataset_url"`` and ``"dataset_name"``.
        data_dir: Local directory that receives the archives and the
            extracted folders.

    Raises:
        json.JSONDecodeError: if ``multi_data_url`` is not valid JSON.
        KeyError: if an entry lacks one of the two keys.
        OSError: if the marker file cannot be created.
    """
    # Local import: keeps this function usable without touching the
    # file-level import block.
    import subprocess

    multi_data_json = json.loads(multi_data_url)
    for entry in multi_data_json:
        dataset_url = entry["dataset_url"]
        dataset_name = entry["dataset_name"]
        zipfile_path = data_dir + "/" + dataset_name

        try:
            mox.file.copy(dataset_url, zipfile_path)
            print("Successfully Download {} to {}".format(dataset_url, zipfile_path))
        except Exception as e:
            print('moxing download {} to {} failed: '.format(
                dataset_url, zipfile_path) + str(e))
            continue

        filename = os.path.splitext(dataset_name)[0]
        filePath = data_dir + "/" + filename
        try:
            os.makedirs(filePath, exist_ok=True)
            # Argument list instead of a shell string: paths containing
            # spaces or shell metacharacters are passed to unzip verbatim.
            result = subprocess.run(["unzip", zipfile_path, "-d", filePath],
                                    check=False)
            if result.returncode != 0:
                print('unzip {} to {} failed with exit code {}'.format(
                    zipfile_path, filePath, result.returncode))
        except OSError as e:
            print('unzip {} to {} failed: '.format(zipfile_path, filePath) + str(e))

    # Marker file that signals "download step finished". open() raises if it
    # cannot be created, so reaching the print means the file exists.
    with open("/cache/download_input.txt", 'w'):
        pass
    print("download_input succeed")
-
-
def EnvToObs(train_dir, obs_train_url):
    """Upload the contents of ``train_dir`` to ``obs_train_url`` via moxing.

    The outcome is printed to stdout; an upload error is reported and
    swallowed rather than raised.

    Args:
        train_dir: Source path handed to ``mox.file.copy_parallel``.
        obs_train_url: Destination path handed to ``mox.file.copy_parallel``.
    """
    try:
        mox.file.copy_parallel(train_dir, obs_train_url)
    except Exception as err:
        failure_prefix = 'moxing upload {} to {} failed: '.format(train_dir,
                                                                  obs_train_url)
        print(failure_prefix + str(err))
    else:
        print("Successfully Upload {} to {}".format(train_dir,
                                                    obs_train_url))
-
-
def DownloadFromQizhi(multi_data_url, data_dir):
    """Set up the MindSpore context and fetch the training datasets.

    Single device (``RANK_SIZE`` is 1 or unset): download the data, then set
    graph mode on the device target taken from the module-level ``args_opt``.

    Several devices (``RANK_SIZE`` > 1): set the context with this process'
    ``ASCEND_DEVICE_ID``, enable data-parallel mode, call ``init()``, and let
    only the ranks whose ``RANK_ID`` is a multiple of 8 download the data
    (presumably one downloader per 8-card machine -- confirm). Every rank
    then waits until the marker file ``/cache/download_input.txt`` written by
    ``C2netMultiObsToEnv`` exists.

    Args:
        multi_data_url: JSON description of the datasets, passed through to
            ``C2netMultiObsToEnv``.
        data_dir: Local directory the datasets are downloaded into.
    """
    # An unset RANK_SIZE used to crash with int(None); treat it as one device.
    device_num = int(os.getenv('RANK_SIZE', '1'))

    if device_num == 1:
        C2netMultiObsToEnv(multi_data_url, data_dir)
        context.set_context(mode=context.GRAPH_MODE,
                            device_target=args_opt.device_target)
    elif device_num > 1:
        # set device_id and init for multi-card training
        context.set_context(mode=context.GRAPH_MODE,
                            device_target=args_opt.device_target,
                            device_id=int(os.getenv('ASCEND_DEVICE_ID')))
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(device_num=device_num,
                                          parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True,
                                          parameter_broadcast=True)
        init()

        local_rank = int(os.getenv('RANK_ID'))
        if local_rank % 8 == 0:
            C2netMultiObsToEnv(multi_data_url, data_dir)

        # The marker file appears once the downloading rank has finished;
        # until then the other ranks must not touch the data.
        while not os.path.exists("/cache/download_input.txt"):
            time.sleep(1)
-
def DownloadModelFromQizhi(model_url, model_dir):
    """Download the pretrained model, at most once per group of 8 ranks.

    With a single device (``RANK_SIZE`` is 1 or unset) the model is always
    downloaded. With several devices only the ranks whose ``RANK_ID`` is a
    multiple of 8 call ``C2netModelToEnv``; the other ranks return without
    doing anything.

    Args:
        model_url: JSON description of the model, passed through to
            ``C2netModelToEnv``.
        model_dir: Local directory the checkpoint is copied into.
    """
    # An unset RANK_SIZE used to crash with int(None); treat it as one device.
    device_num = int(os.getenv('RANK_SIZE', '1'))

    if device_num == 1:
        C2netModelToEnv(model_url, model_dir)
    elif device_num > 1:
        # Copying obs data does not need to be executed multiple times,
        # just let the 0th card copy the data.
        local_rank = int(os.getenv('RANK_ID'))
        if local_rank % 8 == 0:
            C2netModelToEnv(model_url, model_dir)
-
def UploadToQizhi(train_dir, obs_train_url):
    """Upload the training output, at most once per group of 8 ranks.

    With a single device (``RANK_SIZE`` is 1 or unset) the output is always
    uploaded. With several devices only the ranks whose ``RANK_ID`` is a
    multiple of 8 call ``EnvToObs``; the other ranks return without doing
    anything.

    Args:
        train_dir: Local directory holding the results.
        obs_train_url: Destination passed through to ``EnvToObs``.
    """
    # An unset RANK_SIZE used to crash with int(None); treat it as one device.
    device_num = int(os.getenv('RANK_SIZE', '1'))

    if device_num == 1:
        EnvToObs(train_dir, obs_train_url)
    elif device_num > 1:
        # RANK_ID is only needed (and only read) for multi-device runs, the
        # same way DownloadModelFromQizhi handles it.
        local_rank = int(os.getenv('RANK_ID'))
        if local_rank % 8 == 0:
            EnvToObs(train_dir, obs_train_url)
-
-
def train_ddpm():
    """Download the inputs, build the DDPM model and run the training loop.

    Reads its configuration from the module-level ``args_opt`` namespace
    (``steps``, ``image_size``, ``save_every``, ``num_samples``,
    ``pretrain_url``, ``multi_data_url``). Data, output and pretrained-model
    files live under ``/cache/data``, ``/cache/output`` and
    ``/cache/pretrain``.

    Raises:
        OSError: if one of the working directories cannot be created.
    """
    steps = args_opt.steps
    image_size = args_opt.image_size
    data_dir = '/cache/data'
    train_dir = '/cache/output'
    model_dir = '/cache/pretrain'

    # exist_ok=True covers the "already there" case without hiding real
    # failures such as missing permissions. model_dir is created as well,
    # since the checkpoint download writes into it.
    for directory in (data_dir, train_dir, model_dir):
        os.makedirs(directory, exist_ok=True)

    # --pretrain_url defaults to None; only fetch a checkpoint when one was
    # actually requested.
    if args_opt.pretrain_url:
        DownloadModelFromQizhi(args_opt.pretrain_url, model_dir)
    DownloadFromQizhi(args_opt.multi_data_url, data_dir)

    print("List /cache/data: ", os.listdir(data_dir))
    model = ddpm.Unet(
        dim=image_size,
        out_dim=3,
        dim_mults=(1, 2, 4, 8)
    )

    diffusion = ddpm.GaussianDiffusion(
        model,
        image_size=image_size,
        timesteps=1000,  # number of time steps
        sampling_timesteps=50,
        loss_type='l2'  # L1 or L2
    )

    trainer = ddpm.Trainer(
        diffusion,
        os.path.join(data_dir, 'zitirefer/zitirefer'),
        train_batch_size=2,
        train_lr=8e-7,
        train_num_steps=steps,  # total training steps
        gradient_accumulate_every=1,  # gradient accumulation steps
        ema_decay=0.995,  # exponential moving average decay
        save_and_sample_every=args_opt.save_every,  # image sampling and step
        # Honour the --num_samples option (default 4) instead of a
        # hard-coded 4.
        num_samples=args_opt.num_samples,
        results_folder=train_dir,
        amp_level="O1",
        distributed=False
    )
    # NOTE(review): loading the downloaded checkpoint and uploading the
    # results were already disabled in this script; they are left disabled.
    #if args_opt.pretrain_url:
    #    trainer.load(os.path.join(model_dir,'checkpoint.ckpt'))
    trainer.train()
    #UploadToQizhi(train_dir, args_opt.train_url)
-
-
if __name__ == '__main__':
    # args_opt has to be a module-level name: train_ddpm and
    # DownloadFromQizhi read it as a global rather than taking it as an
    # argument.
    args_opt = parse_args()
    train_ddpm()
|