|
- # Copyright 2023 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
-
- import os
- import sys
- import time
- import argparse
- import importlib
-
- import albumentations
- import mindspore as ms
- from omegaconf import OmegaConf
- from mindspore import Model, context
- from mindspore import load_checkpoint, load_param_into_net
- from mindspore.communication.management import init, get_rank, get_group_size
- from mindspore.train.callback import LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint
- from mindspore.nn import DynamicLossScaleUpdateCell
- from mindspore.nn import TrainOneStepWithLossScaleCell
- from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
-
- from ldm.data.dataset_db import load_data
- from ldm.models.clip_zh.simple_tokenizer import WordpieceTokenizer
- from ldm.modules.train.optim import build_optimizer
- from ldm.modules.train.callback import OverflowMonitor
- from ldm.modules.train.learningrate import LearningRate
- from ldm.modules.train.parallel_config import ParallelConfig
- from ldm.modules.train.tools import parse_with_config, set_random_seed
- from ldm.modules.train.cell_wrapper import ParallelTrainOneStepWithLossScaleCell
-
-
# Set at import time so the value is already in the environment before
# init_env() runs the distributed initialisation.
# NOTE(review): presumably the HCCL link-establishment timeout in seconds --
# confirm the unit against the HCCL/MindSpore documentation.
os.environ["HCCL_CONNECT_TIMEOUT"] = "6000"
-
-
def init_env(opts):
    """Configure the runtime context and build the training dataset.

    Calls ``set_random_seed(opts.seed)``, sets up either the data-parallel or
    the single-device environment depending on ``opts.use_parallel``, applies
    the MindSpore context (graph mode, Ascend target) and finally creates the
    dataset through ``load_data``.

    Side effects:
        * ``opts.rank`` is set to the rank of this process.
        * In parallel mode ``ParallelConfig.dp`` is set to the group size.

    Args:
        opts: Parsed options; the attributes read here are ``seed``,
            ``use_parallel`` and the dataset-related fields forwarded to
            ``load_data``.

    Returns:
        tuple: ``(dataset, rank_id, device_id, device_num)``.
    """
    set_random_seed(opts.seed)

    if not opts.use_parallel:
        # Single-device run: rank 0 of 1, DEVICE_ID is optional here.
        device_num, rank_id = 1, 0
        device_id = int(os.getenv('DEVICE_ID', 0))
        opts.rank = rank_id
    else:
        init()
        # DEVICE_ID has no fallback in this branch: int(None) raises if the
        # variable is not exported.
        device_id = int(os.getenv('DEVICE_ID'))
        device_num = get_group_size()
        ParallelConfig.dp = device_num
        rank_id = get_rank()
        opts.rank = rank_id
        print(f"device_id is {device_id}, rank_id is {rank_id}, device_num is {device_num}")
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=context.ParallelMode.DATA_PARALLEL,
            gradients_mean=True,
            device_num=device_num,
        )

    context.set_context(
        mode=context.GRAPH_MODE,
        device_target="Ascend",
        device_id=device_id,
        max_device_memory="30GB",
    )

    # create dataset
    tokenizer = WordpieceTokenizer()
    data_kwargs = {
        "train_data_path": opts.train_data_path,
        "reg_data_path": opts.reg_data_path,
        "train_data_repeats": opts.train_data_repeats,
        "class_word": opts.class_word,
        "token": opts.token,
        "batch_size": opts.train_batch_size,
        "tokenizer": tokenizer,
        "image_size": opts.image_size,
        "image_filter_size": opts.image_filter_size,
        "device_num": device_num,
        "random_crop": opts.random_crop,
        "rank_id": rank_id,
        "sample_num": -1,
    }
    dataset = load_data(**data_kwargs)
    print(f"rank id {rank_id}, sample num is {dataset.get_dataset_size()}")

    return dataset, rank_id, device_id, device_num
-
-
def instantiate_from_config(config):
    """Build the object described by a model configuration.

    Args:
        config: Either something ``OmegaConf.load`` accepts (a path or an open
            file), in which case the file is loaded and its ``model`` node is
            used, or an already-loaded mapping (``dict`` / ``DictConfig``)
            describing the model directly. The mapping form is what
            ``load_model_from_config`` passes in.

    Returns:
        The result of calling the class/function named by ``config["target"]``
        with ``config["params"]`` (empty if absent) as keyword arguments, or
        ``None`` when the configuration is one of the sentinel strings
        ``'__is_first_stage__'`` / ``'__is_unconditional__'``.

    Raises:
        KeyError: If the configuration has no ``target`` key and is not one of
            the sentinel strings.
    """
    from collections.abc import Mapping

    # Only load from disk when we were not handed a ready-made mapping;
    # previously a mapping was passed to OmegaConf.load and failed there.
    if not isinstance(config, Mapping):
        config = OmegaConf.load(config).model

    if "target" not in config:
        if config == '__is_first_stage__' or config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")

    factory = get_obj_from_str(config["target"])
    return factory(**config.get("params", dict()))
-
-
def str2bool(b):
    """Convert the strings ``"true"`` / ``"false"`` (any case) to a bool.

    Intended for use as an ``argparse`` ``type=`` converter.

    Args:
        b: The string to convert.

    Returns:
        bool: ``True`` for "true", ``False`` for "false".

    Raises:
        argparse.ArgumentTypeError: For any other value. This is a subclass of
            ``Exception``, so existing ``except Exception`` handlers still
            work, and argparse reports it as a normal usage error instead of
            letting a traceback escape.
    """
    lowered = b.lower()
    if lowered == "true":
        return True
    if lowered == "false":
        return False
    raise argparse.ArgumentTypeError("Invalid Bool Value")
-
-
def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as ``"package.module.Name"`` to the object.

    Args:
        string: Dotted path; everything before the last dot is imported as a
            module and the last component is looked up on it.
        reload: When true, the module is re-executed with ``importlib.reload``
            before the attribute is fetched.

    Returns:
        The attribute named by the last path component.
    """
    module_path, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_path))
    resolved_module = importlib.import_module(module_path, package=None)
    return getattr(resolved_module, attr_name)
-
-
def load_model_from_config(config, ckpt, verbose=False):
    """Instantiate a model from ``config.model`` and load ``ckpt`` into it.

    Args:
        config: Object exposing a ``model`` attribute, which is forwarded to
            ``instantiate_from_config``.
            NOTE(review): confirm that ``config.model`` is in a form that
            ``instantiate_from_config`` accepts.
        ckpt: Path of the checkpoint file. If it does not exist a message is
            printed and the model is returned without loaded weights.
        verbose: Accepted for interface compatibility; not used.

    Returns:
        The instantiated model.
    """
    print(f"Loading model from {ckpt}")
    model = instantiate_from_config(config.model)

    if not os.path.exists(ckpt):
        print(f"{ckpt} not exist:")
        return model

    weights = ms.load_checkpoint(ckpt)
    if weights:
        # Whatever load_param_into_net reports back is printed as-is.
        not_loaded = ms.load_param_into_net(model, weights)
        print("param not load:", not_loaded)

    return model
-
-
def load_pretrained_model(pretrained_ckpt, net):
    """Load every parameter of a checkpoint file into ``net``.

    Progress is reported on stdout. A missing checkpoint file is not an
    error: a message is printed and ``net`` is left untouched.

    Args:
        pretrained_ckpt: Path of the checkpoint file.
        net: Network that receives the parameters.
    """
    print(f"start loading pretrained_ckpt {pretrained_ckpt}")

    if not os.path.exists(pretrained_ckpt):
        print("ckpt file not exist!")
    else:
        weights = load_checkpoint(pretrained_ckpt)
        not_loaded = load_param_into_net(net, weights)
        print("param not load:", not_loaded)

    print("end loading ckpt")
-
-
def load_pretrained_model_clip_and_vae(pretrained_ckpt, net):
    """Load only the ``first*`` and ``cond*`` parameters of a checkpoint.

    Checkpoint entries whose name starts with ``"first"`` or ``"cond"`` are
    kept and loaded into ``net``; everything else in the file is ignored.
    NOTE(review): going by the function name these prefixes presumably cover
    the VAE (first stage) and CLIP (conditioning) sub-networks -- confirm.

    Progress is reported on stdout. A missing checkpoint file is not an
    error: a message is printed and ``net`` is left untouched.

    Args:
        pretrained_ckpt: Path of the checkpoint file.
        net: Network that receives the selected parameters.
    """
    print(f"start loading pretrained_ckpt {pretrained_ckpt}")

    if not os.path.exists(pretrained_ckpt):
        print("ckpt file not exist!")
    else:
        all_weights = load_checkpoint(pretrained_ckpt)
        selected = {
            name: all_weights[name]
            for name in all_weights
            if name.startswith(("first", "cond"))
        }
        not_loaded = load_param_into_net(net, selected)
        print("param not load:")
        for entry in not_loaded:
            print(entry)

    print("end loading ckpt")
-
-
def main(opts):
    """Run the complete training job described by ``opts``.

    Steps, in order: initialise the environment and dataset, instantiate the
    network from ``opts.model_config``, load the pretrained checkpoint, build
    the learning-rate schedule / optimizer / dynamic loss scaler, wrap the
    network in the matching train-one-step cell, register callbacks and call
    ``Model.train``.

    Side effects:
        * ``opts.decay_steps`` is filled in with ``epochs * dataset_size``
          when it is 0/unset.
        * On rank 0, ``opts.save_checkpoint_steps`` defaults to one epoch and
          checkpoints are written under ``<output_path>/ckpt/rank_0``.

    Args:
        opts: Parsed and config-merged command-line options.
    """
    # init_env also returns device_id and device_num; they are not needed here.
    dataset, rank_id, _, _ = init_env(opts)

    latent_diffusion_with_loss = instantiate_from_config(opts.model_config)
    pretrained_ckpt = os.path.join(opts.pretrained_model_path, opts.pretrained_model_file)
    load_pretrained_model(pretrained_ckpt, latent_diffusion_with_loss)

    # Queried once and reused for both the LR schedule and checkpoint cadence.
    dataset_size = dataset.get_dataset_size()

    if not opts.decay_steps:
        opts.decay_steps = opts.epochs * dataset_size
    lr = LearningRate(opts.start_learning_rate, opts.end_learning_rate,
                      opts.warmup_steps, opts.decay_steps)
    optimizer = build_optimizer(latent_diffusion_with_loss, opts, lr)
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=opts.init_loss_scale,
                                             scale_factor=opts.loss_scale_factor,
                                             scale_window=opts.scale_window)

    if opts.use_parallel:
        net_with_grads = ParallelTrainOneStepWithLossScaleCell(
            latent_diffusion_with_loss,
            optimizer=optimizer,
            scale_sense=update_cell,
            parallel_config=ParallelConfig,
        )
    else:
        net_with_grads = TrainOneStepWithLossScaleCell(
            latent_diffusion_with_loss,
            optimizer=optimizer,
            scale_sense=update_cell,
        )
    model = Model(net_with_grads)

    callback = [
        TimeMonitor(opts.callback_size),
        LossMonitor(opts.callback_size),
        OverflowMonitor(),
    ]

    # Only rank 0 writes checkpoints.
    if rank_id == 0:
        if not opts.save_checkpoint_steps:
            opts.save_checkpoint_steps = dataset_size
        ckpt_dir = os.path.join(opts.output_path, "ckpt", f"rank_{rank_id}")
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(ckpt_dir, exist_ok=True)
        config_ck = CheckpointConfig(save_checkpoint_steps=opts.save_checkpoint_steps,
                                     keep_checkpoint_max=10,
                                     integrated_save=False)
        ckpoint_cb = ModelCheckpoint(prefix="wkhh_txt2img",
                                     directory=ckpt_dir,
                                     config=config_ck)
        callback.append(ckpoint_cb)

    print("start_training...")
    model.train(opts.epochs, dataset, callbacks=callback, dataset_sink_mode=False)
-
-
if __name__ == "__main__":
    # Script entry point: parse the CLI, merge it with the training config
    # (parse_with_config), resolve the model config path relative to this
    # file, then run and time the training job.
    print('process id:', os.getpid())
    parser = argparse.ArgumentParser()
    parser.add_argument('--use_parallel', default=False, type=str2bool, help='use parallel')
    parser.add_argument('--data_path', default="dataset", type=str, help='data path')
    parser.add_argument('--output_path', default="output/", type=str,
                        help='output directory (checkpoints are written under <output_path>/ckpt)')
    parser.add_argument('--train_config', default="configs/train_db_config.json", type=str,
                        help='train config path')
    parser.add_argument('--model_config', default="configs/v1-train-db-chinese.yaml", type=str,
                        help='model config path')
    parser.add_argument('--pretrained_model_path', default="", type=str, help='pretrained model directory')
    parser.add_argument('--pretrained_model_file', default="", type=str, help='pretrained model file name')
    parser.add_argument('--train_data_path', default="", type=str, help='train data path')
    parser.add_argument('--reg_data_path', default="", type=str, help='regularization data path')

    parser.add_argument('--train_data_repeats', default=100, type=int, help='repetition times of training data')
    parser.add_argument('--class_word', default="", type=str,
                        help='Match class_word to the category of images you want to train')
    parser.add_argument('--token', default="α", type=str,
                        help='unique token you want to represent your trained model')
    parser.add_argument('--optim', default="adamw", type=str, help='optimizer')
    parser.add_argument('--seed', default=3407, type=int, help='random seed')
    parser.add_argument('--warmup_steps', default=1000, type=int, help='warmup steps')
    parser.add_argument('--train_batch_size', default=10, type=int, help='batch size')
    parser.add_argument('--callback_size', default=1, type=int, help='callback size.')
    parser.add_argument("--start_learning_rate", default=1e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--end_learning_rate", default=1e-7, type=float,
                        help="The end learning rate for Adam.")
    parser.add_argument("--decay_steps", default=0, type=int, help="lr decay steps.")
    parser.add_argument("--epochs", default=10, type=int, help="epochs")
    parser.add_argument("--init_loss_scale", default=65536, type=float, help="loss scale")
    parser.add_argument("--loss_scale_factor", default=2, type=float, help="loss scale factor")
    parser.add_argument("--scale_window", default=1000, type=float, help="scale window")
    parser.add_argument("--save_checkpoint_steps", default=0, type=int, help="save checkpoint steps")
    parser.add_argument('--random_crop', default=False, type=str2bool, help='random crop')
    parser.add_argument('--filter_small_size', default=True, type=str2bool, help='filter small images')
    parser.add_argument('--image_size', default=512, type=int, help='images size')
    parser.add_argument('--image_filter_size', default=256, type=int, help='image filter size')

    args = parser.parse_args()
    args = parse_with_config(args)
    # Resolve the model config relative to this script's directory; an
    # absolute --model_config is left unchanged by os.path.join.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    args.model_config = os.path.join(script_dir, args.model_config)
    print(args)

    start = time.time()
    main(args)
    end = time.time()
    print("training time: ", end-start)
|