# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os

os.system("pip install OmegaConf -i https://pypi.tuna.tsinghua.edu.cn/simple/")
os.system("pip install imagesize -i https://pypi.tuna.tsinghua.edu.cn/simple/")
os.system("pip install toolz -i https://pypi.tuna.tsinghua.edu.cn/simple/")
os.system("pip install ftfy -i https://pypi.tuna.tsinghua.edu.cn/simple/")
os.system("pip install regex -i https://pypi.tuna.tsinghua.edu.cn/simple/")


# try:
#     """Environment setup: requirements.txt lists the required packages."""
#     os.system(f"pip install -r /cache/code/draw_pic/requirements.txt")
#     print("Environment installed successfully")
# except BaseException:
#     print("Environment installation failed")
from openi import openi_multidataset_to_env as DatasetToEnv

import time
import argparse
import importlib

import mindspore as ms
from omegaconf import OmegaConf
from mindspore import Model, context
from mindspore import load_checkpoint, load_param_into_net
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.callback import LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint
from mindspore.nn import TrainOneStepWithLossScaleCell
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell

from ldm.data.dataset_db import load_data
from ldm.models.clip_zh.simple_tokenizer import WordpieceTokenizer
from ldm.modules.train.optim import build_optimizer
from ldm.modules.train.callback import OverflowMonitor
from ldm.modules.train.learningrate import LearningRate
from ldm.modules.train.parallel_config import ParallelConfig
from ldm.modules.train.tools import parse_with_config, set_random_seed
from ldm.modules.train.cell_wrapper import ParallelTrainOneStepWithLossScaleCell

import moxing as mox
import json

os.environ['HCCL_CONNECT_TIMEOUT'] = '6000'


def C2netModelToEnv(model_url, model_dir):
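    """Download the pretrained model file(s) described by a JSON string to model_dir.

    Every entry's "model_url" is copied via moxing to model_dir/checkpoint.ckpt,
    so if several entries are given, later downloads overwrite earlier ones.
    """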
    # model_url is a JSON-encoded list; parse it to get the download entries
    model_url_json = json.loads(model_url)
    print("model_url_json:", model_url_json)
    for i in range(len(model_url_json)):
        modelfile_path = model_dir + "/" + "checkpoint.ckpt"
        try:
            mox.file.copy(model_url_json[i]["model_url"], modelfile_path)
            print("Successfully Download {} to {}".format(
                model_url_json[i]["model_url"], modelfile_path))
        except Exception as e:
            print('moxing download {} to {} failed: '.format(
                model_url_json[i]["model_url"], modelfile_path) + str(e))
    return


def C2netMultiObsToEnv(multi_data_url, data_dir):
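    """Download and unzip the datasets described by a JSON string into data_dir.

    Each entry's "dataset_url" is copied via moxing and extracted into a
    directory named after the dataset; afterwards a marker file
    /cache/download_input.txt is written so other ranks can tell that the
    download has finished.
    """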
    multi_data_json = json.loads(multi_data_url)
    for i in range(len(multi_data_json)):
        zipfile_path = data_dir + "/" + multi_data_json[i]["dataset_name"]
        try:
            mox.file.copy(multi_data_json[i]["dataset_url"], zipfile_path)
            print("Successfully Download {} to {}".format(
                multi_data_json[i]["dataset_url"], zipfile_path))
            filename = os.path.splitext(multi_data_json[i]["dataset_name"])[0]
            filePath = data_dir + "/" + filename
            if not os.path.exists(filePath):
                os.makedirs(filePath)
            os.system("unzip {} -d {}".format(zipfile_path, filePath))
        except Exception as e:
            print('moxing download {} to {} failed: '.format(
                multi_data_json[i]["dataset_url"], zipfile_path) + str(e))

    # touch a marker file to signal other ranks that downloading has finished
    f = open("/cache/download_input.txt", 'w')
    f.close()
    if os.path.exists("/cache/download_input.txt"):
        print("download_input succeed")
    else:
        print("download_input failed")
    return


def EnvToObs(train_dir, obs_train_url):
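    """Upload the local training output directory to OBS via moxing."""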
    try:
        mox.file.copy_parallel(train_dir, obs_train_url)
        print("Successfully Upload {} to {}".format(train_dir, obs_train_url))
    except Exception as e:
        print('moxing upload {} to {} failed: '.format(
            train_dir, obs_train_url) + str(e))
    return


def DownloadFromQizhi(multi_data_url, data_dir):
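    """Download the datasets and set up the MindSpore context.

    With a single device, the data is downloaded directly; with multiple
    devices, only the first rank of each node (RANK_ID % 8 == 0) downloads,
    while the remaining ranks wait for the marker file.
    """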
    device_num = int(os.getenv('RANK_SIZE'))
    if device_num == 1:
        C2netMultiObsToEnv(multi_data_url, data_dir)
        context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if device_num > 1:
        # set device_id and init for multi-card training
        context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
                            device_id=int(os.getenv('ASCEND_DEVICE_ID')))
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(device_num=device_num,
                                          parallel_mode=context.ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True, parameter_broadcast=True)
        init()
        local_rank = int(os.getenv('RANK_ID'))
        if local_rank % 8 == 0:
            C2netMultiObsToEnv(multi_data_url, data_dir)
        # wait until the downloading rank has written the marker file
        while not os.path.exists("/cache/download_input.txt"):
            time.sleep(1)
    return


def DownloadModelFromQizhi(model_url, model_dir):
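    """Download the pretrained checkpoint; with multiple devices, only the
    first rank of each node performs the copy."""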
    device_num = int(os.getenv('RANK_SIZE'))
    if device_num == 1:
        C2netModelToEnv(model_url, model_dir)
    if device_num > 1:
        # Copying obs data does not need to be executed multiple times, just
        # let the 0th card copy the data
        local_rank = int(os.getenv('RANK_ID'))
        if local_rank % 8 == 0:
            C2netModelToEnv(model_url, model_dir)
    return


def UploadToQizhi(train_dir, obs_train_url):
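    """Upload training results to OBS; with multiple devices, only the first
    rank of each node uploads."""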
    device_num = int(os.getenv('RANK_SIZE'))
    local_rank = int(os.getenv('RANK_ID'))
    if device_num == 1:
        EnvToObs(train_dir, obs_train_url)
    if device_num > 1:
        if local_rank % 8 == 0:
            EnvToObs(train_dir, obs_train_url)
    return


def init_env(opts):
- """ init_env """
- data_dir = '/cache/data'
-
- if not os.path.exists(data_dir):
- os.makedirs(data_dir)
-
- DatasetToEnv(args.multi_data_url, data_dir)
- set_random_seed(opts.seed)
- if opts.use_parallel:
- init()
- device_id = int(os.getenv('DEVICE_ID'))
- device_num = get_group_size()
- ParallelConfig.dp = device_num
- rank_id = get_rank()
- opts.rank = rank_id
- print("device_id is {}, rank_id is {}, device_num is {}".format(
- device_id, rank_id, device_num))
- context.reset_auto_parallel_context()
- context.set_auto_parallel_context(
- parallel_mode=context.ParallelMode.DATA_PARALLEL,
- gradients_mean=True,
- device_num=device_num)
- else:
- device_num = 1
- device_id = int(os.getenv('DEVICE_ID', 0))
- rank_id = 0
- opts.rank = rank_id
-
- context.set_context(mode=context.GRAPH_MODE,
- device_target="Ascend",
- device_id=device_id,
- max_device_memory="30GB",
- )

    # create dataset
    tokenizer = WordpieceTokenizer()
    print("data : ", os.listdir(data_dir))
    dataset = load_data(
        train_data_path=os.path.join(data_dir, 'db_data', 'train'),
        reg_data_path=os.path.join(data_dir, 'db_data', 'regular'),
        train_data_repeats=opts.train_data_repeats,
        class_word=opts.class_word,
        token=opts.token,
        batch_size=opts.train_batch_size,
        tokenizer=tokenizer,
        image_size=opts.image_size,
        image_filter_size=opts.image_filter_size,
        device_num=device_num,
        random_crop=opts.random_crop,
        rank_id=rank_id,
        sample_num=-1
    )
    print(f"rank id {rank_id}, dataset size is {dataset.get_dataset_size()}")

    return dataset, rank_id, device_id, device_num


def instantiate_from_config(config):
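    """Load an OmegaConf YAML file and instantiate the class named by its
    model.target key, passing model.params as keyword arguments."""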
    config = OmegaConf.load(config).model
    if "target" not in config:
        if config == '__is_first_stage__':
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


def str2bool(b):
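    """Parse a 'true'/'false' command-line string into a bool."""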
    if b.lower() not in ["false", "true"]:
        raise ValueError("Invalid Bool Value: expected 'true' or 'false'")
    if b.lower() in ["false"]:
        return False
    return True


def get_obj_from_str(string, reload=False):
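    """Import and return the object named by a dotted path such as
    'package.module.ClassName', optionally reloading its module first."""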
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def load_model_from_config(config, ckpt, verbose=False):
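    """Instantiate the model from a config file path and, if the checkpoint
    exists, load its parameters into the network."""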
- print(f"Loading model from {ckpt}")
- model = instantiate_from_config(config.model)
- if os.path.exists(ckpt):
- param_dict = ms.load_checkpoint(ckpt)
- if param_dict:
- param_not_load = ms.load_param_into_net(model, param_dict)
- print("param not load:", param_not_load)
- else:
- print(f"{ckpt} not exist:")
-
- return model


def load_pretrained_model(pretrained_ckpt, net):
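    """Load all parameters from a checkpoint file into the network."""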
- print(f"start loading pretrained_ckpt {pretrained_ckpt}")
- if os.path.exists(pretrained_ckpt):
- param_dict = load_checkpoint(pretrained_ckpt)
- param_not_load = load_param_into_net(net, param_dict)
- print("param not load:", param_not_load)
- else:
- print("ckpt file not exist!")
-
- print("end loading ckpt")
-


def load_pretrained_model_clip_and_vae(pretrained_ckpt, net):
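    """Load only the VAE ("first...") and text-encoder ("cond...") parameters
    from a checkpoint, leaving the remaining weights untouched."""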
    new_param_dict = {}
    print(f"start loading pretrained_ckpt {pretrained_ckpt}")
    if os.path.exists(pretrained_ckpt):
        param_dict = load_checkpoint(pretrained_ckpt)
        for key in param_dict:
            if key.startswith("first") or key.startswith("cond"):
                new_param_dict[key] = param_dict[key]
        param_not_load = load_param_into_net(net, new_param_dict)
        print("param not load:")
        for param in param_not_load:
            print(param)
    else:
        print("ckpt file does not exist!")

    print("end loading ckpt")


def main(opts):
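    """Training entry point: download data and the pretrained model, build
    the latent-diffusion network with optimizer and loss scaling, train, and
    upload the resulting checkpoints."""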
    model_dir = '/cache/pretrain'
    train_dir = '/cache/output'
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(model_dir, exist_ok=True)
    DownloadModelFromQizhi(opts.pretrain_url, model_dir)
    print("List /cache/pretrain: ", os.listdir(model_dir))
    dataset, rank_id, device_id, device_num = init_env(opts)
    LatentDiffusionWithLoss = instantiate_from_config(opts.model_config)
    pretrained_ckpt = os.path.join(model_dir, 'checkpoint.ckpt')
    load_pretrained_model(pretrained_ckpt, LatentDiffusionWithLoss)

    if not opts.decay_steps:
        dataset_size = dataset.get_dataset_size()
        opts.decay_steps = opts.epochs * dataset_size
    lr = LearningRate(
        opts.start_learning_rate,
        opts.end_learning_rate,
        opts.warmup_steps,
        opts.decay_steps)
    optimizer = build_optimizer(LatentDiffusionWithLoss, opts, lr)
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=opts.init_loss_scale,
                                             scale_factor=opts.loss_scale_factor,
                                             scale_window=opts.scale_window)

    if opts.use_parallel:
        net_with_grads = ParallelTrainOneStepWithLossScaleCell(LatentDiffusionWithLoss, optimizer=optimizer,
                                                               scale_sense=update_cell,
                                                               parallel_config=ParallelConfig)
    else:
        net_with_grads = TrainOneStepWithLossScaleCell(LatentDiffusionWithLoss, optimizer=optimizer,
                                                       scale_sense=update_cell)
    model = Model(net_with_grads)
    callback = [TimeMonitor(opts.callback_size), LossMonitor(opts.callback_size)]

    ofm_cb = OverflowMonitor()
    callback.append(ofm_cb)

    if rank_id == 0:
        dataset_size = dataset.get_dataset_size()
        if not opts.save_checkpoint_steps:
            opts.save_checkpoint_steps = dataset_size
        ckpt_dir = os.path.join(train_dir, "ckpt", f"rank_{str(rank_id)}")
        if not os.path.exists(ckpt_dir):
            os.makedirs(ckpt_dir)
        config_ck = CheckpointConfig(save_checkpoint_steps=opts.save_checkpoint_steps,
                                     keep_checkpoint_max=1,
                                     integrated_save=False)
        ckpoint_cb = ModelCheckpoint(prefix="wkhh_txt2img",
                                     directory=ckpt_dir,
                                     config=config_ck)
        callback.append(ckpoint_cb)

    print("start_training...")
    model.train(opts.epochs, dataset, callbacks=callback, dataset_sink_mode=False)
    end = time.time()
    print("training time: ", end - start)  # `start` is set in the __main__ block below
    UploadToQizhi(train_dir, opts.output_path)
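
# Example launch (assumed script name and illustrative flag values; on the
# Qizhi/OpenI platform, --multi_data_url and --pretrain_url are JSON strings
# injected by the platform):
#   python train_db.py --use_parallel True --epochs 10 --class_word "dog" --token "α"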


if __name__ == "__main__":
    print('process id:', os.getpid())
    parser = argparse.ArgumentParser()
    parser.add_argument('--multi_data_url', default=" ", type=str, help='qizhi data path')
    parser.add_argument('--use_parallel', default=False, type=str2bool, help='use parallel')
    parser.add_argument('--data_path', default="dataset", type=str, help='data path')
    parser.add_argument('--output_path', default="./results", type=str, help='output path')
    parser.add_argument('--train_config', default="configs/train_db_config.json", type=str, help='train config path')
    parser.add_argument('--model_config', default="configs/v1-train-db-chinese.yaml", type=str, help='model config path')
    parser.add_argument('--pretrain_url', default="", type=str, help='pretrained model directory')
    parser.add_argument('--train_data_path', default="/cache/code/wukong_huahua/db_data/train", type=str, help='train data path')
    parser.add_argument('--reg_data_path', default="/cache/code/wukong_huahua/db_data/regular", type=str, help='regularization data path')
    parser.add_argument('--train_data_repeats', default=100, type=int, help='repetition times of training data')
    parser.add_argument('--class_word', default="", type=str, help='class word matching the category of images you want to train')
    parser.add_argument('--token', default="α", type=str, help='unique token used to represent the trained subject')
    parser.add_argument('--optim', default="adamw", type=str, help='optimizer')
    parser.add_argument('--seed', default=3407, type=int, help='random seed')
    parser.add_argument('--warmup_steps', default=1000, type=int, help='warmup steps')
    parser.add_argument('--train_batch_size', default=10, type=int, help='batch size')
    parser.add_argument('--callback_size', default=1, type=int, help='callback size')
    parser.add_argument("--start_learning_rate", default=1e-5, type=float, help="the initial learning rate for Adam")
    parser.add_argument("--end_learning_rate", default=1e-7, type=float, help="the end learning rate for Adam")
    parser.add_argument("--decay_steps", default=0, type=int, help="lr decay steps")
    parser.add_argument("--epochs", default=10, type=int, help="epochs")
    parser.add_argument("--init_loss_scale", default=65536, type=float, help="initial loss scale")
    parser.add_argument("--loss_scale_factor", default=2, type=float, help="loss scale factor")
    parser.add_argument("--scale_window", default=1000, type=float, help="scale window")
    parser.add_argument("--save_checkpoint_steps", default=0, type=int, help="save checkpoint steps")
    parser.add_argument('--random_crop', default=False, type=str2bool, help='random crop')
    parser.add_argument('--filter_small_size', default=True, type=str2bool, help='filter small images')
    parser.add_argument('--image_size', default=512, type=int, help='image size')
    parser.add_argument('--image_filter_size', default=256, type=int, help='image filter size')

    args = parser.parse_args()
    args = parse_with_config(args)
    abs_path = os.path.dirname(os.path.abspath(__file__))
    args.model_config = os.path.join(abs_path, args.model_config)
    print(args)
    start = time.time()
    main(args)