|
- import ddpm as ddpm
- import argparse
- import os
- from mindspore import context
- from mindspore.communication.management import init
- from mindspore.context import ParallelMode
- import time
- from upload import UploadOutput
- import moxing as mox
-
-
def parse_args(argv=None):
    """Parse the command-line options of the DDPM training script.

    Unknown options are ignored rather than rejected, because parsing is
    done with ``parse_known_args``.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, which matches calling
            ``parse_args()`` with no arguments.

    Returns:
        argparse.Namespace with the attributes ``pretrain_path``,
        ``data_url``, ``train_url``, ``steps``, ``save_every``,
        ``num_samples``, ``device_target`` and ``image_size``.
    """
    parser = argparse.ArgumentParser(
        description="train ddpm",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('--pretrain_path',
                        type=str,
                        default=None,
                        help='the pretrain model path')

    parser.add_argument('--data_url',
                        type=str,
                        default="C:\\Users\\Administrator\\PycharmProjects\\DDPM\\datasets\\test",
                        help='training data file path')

    parser.add_argument('--train_url',
                        default='./results',
                        type=str,
                        help='the path model and fig save path')

    parser.add_argument('--steps',
                        default=20000,
                        type=int,
                        help='training steps')

    parser.add_argument('--save_every',
                        default=5000,
                        type=int,
                        help='save_every')

    # The perfect-square requirement is only stated in the help text; it is
    # not validated here.
    parser.add_argument('--num_samples',
                        default=4,
                        type=int,
                        help='num_samples must have a square root, like 4, 9, 16 ...')

    parser.add_argument('--device_target',
                        default="Ascend",
                        type=str,
                        help='device target')

    parser.add_argument('--image_size',
                        default=200,
                        type=int,
                        help='image size')

    # Second element is the list of unrecognised arguments; deliberately dropped.
    args, _ = parser.parse_known_args(argv)
    return args
-
-
def ObsToEnv(obs_data_url, data_dir):
    """Copy ``obs_data_url`` into ``data_dir`` and create the download flag file.

    The copy is best-effort: any exception raised by
    ``mox.file.copy_parallel`` is printed and swallowed. The flag file
    ``/cache/download_input.txt`` is created afterwards whether or not the
    copy succeeded; ``DownloadFromQizhi`` polls for that file.

    Args:
        obs_data_url: Source location passed to ``mox.file.copy_parallel``.
        data_dir: Destination directory passed to ``mox.file.copy_parallel``.

    Raises:
        OSError: If the flag file cannot be created (for example when
            ``/cache`` does not exist).
    """
    flag_path = "/cache/download_input.txt"

    try:
        mox.file.copy_parallel(obs_data_url, data_dir)
        print("Successfully Download {} to {}".format(obs_data_url, data_dir))
    except Exception as e:
        print('moxing download {} to {} failed: '.format(obs_data_url, data_dir) + str(e))

    # Create (or truncate) the empty flag file; the context manager closes
    # the handle even if something goes wrong in between.
    with open(flag_path, 'w'):
        pass

    # os.path.exists reports a missing file by returning False, not by
    # raising, so the failure message belongs in an else branch.
    if os.path.exists(flag_path):
        print("download_input succeed")
    else:
        print("download_input failed")
-
-
def ObsUrlToEnv(obs_ckpt_url, ckpt_url):
    """Copy ``obs_ckpt_url`` to ``ckpt_url`` with ``mox.file.copy``.

    Best-effort: an exception raised during the copy is reported on stdout
    and swallowed, so the caller is never interrupted.

    Args:
        obs_ckpt_url: Source location handed to ``mox.file.copy``.
        ckpt_url: Destination path handed to ``mox.file.copy``.
    """
    try:
        mox.file.copy(obs_ckpt_url, ckpt_url)
        print("Successfully Download {} to {}".format(obs_ckpt_url, ckpt_url))
    except Exception as err:
        failure_prefix = 'moxing download {} to {} failed: '.format(obs_ckpt_url, ckpt_url)
        print(failure_prefix + str(err))
-
-
def EnvToObs(train_dir, obs_train_url):
    """Copy ``train_dir`` to ``obs_train_url`` with ``mox.file.copy_parallel``.

    Best-effort: an exception raised during the copy is reported on stdout
    and swallowed, so the caller is never interrupted.

    Args:
        train_dir: Local source directory handed to ``mox.file.copy_parallel``.
        obs_train_url: Destination location handed to ``mox.file.copy_parallel``.
    """
    try:
        mox.file.copy_parallel(train_dir, obs_train_url)
        print("Successfully Upload {} to {}".format(train_dir, obs_train_url))
    except Exception as err:
        failure_prefix = 'moxing upload {} to {} failed: '.format(train_dir, obs_train_url)
        print(failure_prefix + str(err))
-
-
def DownloadFromQizhi(obs_data_url, data_dir):
    """Fetch the dataset once and make every process wait until it is there.

    The number of devices is read from the ``RANK_SIZE`` environment
    variable.

    * One device: the data is downloaded directly via ``ObsToEnv``.
    * Several devices: the MindSpore context is configured for data-parallel
      training (using ``ASCEND_DEVICE_ID`` and the module-level
      ``args_opt.device_target``), the communication backend is initialised,
      processes whose ``RANK_ID`` is a multiple of 8 download the data, and
      every process then polls once per second for the flag file
      ``/cache/download_input.txt``.

    Args:
        obs_data_url: Source location forwarded to ``ObsToEnv``.
        data_dir: Local destination directory forwarded to ``ObsToEnv``.
    """
    flag_file = "/cache/download_input.txt"
    world_size = int(os.getenv('RANK_SIZE'))

    if world_size == 1:
        # NOTE: the single-device path does not call context.set_context here.
        ObsToEnv(obs_data_url, data_dir)
    elif world_size > 1:
        # Set device_id and init for multi-card training.
        context.set_context(
            mode=context.GRAPH_MODE,
            device_target=args_opt.device_target,
            device_id=int(os.getenv('ASCEND_DEVICE_ID')),
        )
        context.reset_auto_parallel_context()
        parallel_settings = dict(
            device_num=world_size,
            parallel_mode=ParallelMode.DATA_PARALLEL,
            gradients_mean=True,
            parameter_broadcast=True,
        )
        context.set_auto_parallel_context(**parallel_settings)
        init()

        # Only ranks divisible by 8 perform the copy; the rest just wait.
        if int(os.getenv('RANK_ID')) % 8 == 0:
            ObsToEnv(obs_data_url, data_dir)

        # Block until the flag file written by ObsToEnv shows up.
        while True:
            if os.path.exists(flag_file):
                break
            time.sleep(1)
-
-
def UploadToQizhi(train_dir, obs_train_url):
    """Upload the training output once per job.

    The number of devices is read from the ``RANK_SIZE`` environment
    variable. With one device the upload always happens; with several
    devices only processes whose ``RANK_ID`` is a multiple of 8 upload.

    ``RANK_ID`` is read only in the multi-device branch, so a single-device
    job that does not define it can still upload its results instead of
    failing after training has finished.

    Args:
        train_dir: Local directory forwarded to ``EnvToObs``.
        obs_train_url: Destination location forwarded to ``EnvToObs``.
    """
    device_num = int(os.getenv('RANK_SIZE'))
    if device_num == 1:
        EnvToObs(train_dir, obs_train_url)
    elif device_num > 1:
        local_rank = int(os.getenv('RANK_ID'))
        if local_rank % 8 == 0:
            EnvToObs(train_dir, obs_train_url)
-
-
def train_ddpm():
    """Download the data, train the diffusion model and upload the results.

    Reads its configuration from the module-level ``args_opt`` namespace
    (``steps``, ``image_size``, ``data_url``, ``save_every``,
    ``num_samples``, ``pretrain_path``, ``train_url``). Data is staged in
    ``/cache/data``, results are written to ``/cache/output`` and then
    uploaded to ``args_opt.train_url``.
    """
    steps = args_opt.steps
    image_size = args_opt.image_size

    data_dir = '/cache/data'
    train_dir = '/cache/output'

    # Creating the working directories stays best-effort (a failure does not
    # abort the run), but the real reason is reported instead of a generic
    # "already exists" message. exist_ok=True makes an existing directory a
    # non-error.
    for path in (data_dir, train_dir):
        try:
            os.makedirs(path, exist_ok=True)
        except OSError as e:
            print("failed to create {}: {}".format(path, e))

    DownloadFromQizhi(args_opt.data_url, data_dir)

    print("List /cache/data: ", os.listdir(data_dir))

    # NOTE(review): ``dim`` receives the image size here; confirm that is
    # what ddpm.Unet expects.
    model = ddpm.Unet(
        dim=image_size,
        out_dim=3,
        dim_mults=(1, 2, 4, 8)
    )

    diffusion = ddpm.GaussianDiffusion(
        model,
        image_size=image_size,
        timesteps=20,  # number of time steps
        sampling_timesteps=10,
        loss_type='l2'  # L1 or L2
    )

    trainer = ddpm.Trainer(
        diffusion,
        os.path.join(data_dir, 'test'),
        train_batch_size=1,
        train_lr=8e-5,
        train_num_steps=steps,  # total training steps
        gradient_accumulate_every=1,  # gradient accumulation steps
        ema_decay=0.995,  # exponential moving average decay
        save_and_sample_every=args_opt.save_every,  # image sampling and step
        num_samples=args_opt.num_samples,  # honour --num_samples (default 4)
        results_folder=train_dir,
        distributed=False
    )

    # Resume from a checkpoint only when one was given on the command line.
    if args_opt.pretrain_path:
        trainer.load(args_opt.pretrain_path)

    trainer.train()
    UploadToQizhi(train_dir, args_opt.train_url)
-
-
# Script entry point. ``args_opt`` is bound at module level on purpose:
# DownloadFromQizhi and train_ddpm read it as a global.
if __name__ == "__main__":
    args_opt = parse_args()
    train_ddpm()
|