|
- # Copyright 2022 Wh1isper
- #
- # Use of this source code is governed by an MIT-style
- # license that can be found in the LICENSE file or at
- # https://opensource.org/licenses/MIT.
-
- import errno
- import os
- import json
- import logging
- import pathlib
- import signal
- import socket
- import subprocess
- import time
- import asyncio
- from asyncio import Queue
- import threading
- from contextlib import contextmanager
-
- import sh
- from tornado import ioloop
-
# Name of the environment variable that points at the HCCL rank table file.
RANK_TABLE_FILE = 'RANK_TABLE_FILE'
# Name of the logger shared by this whole module.
LOGO = 'datai'


class Logger:
    """Factory/accessor for the module-wide '%s' logger.""" % LOGO

    @staticmethod
    def setup_modelarts_logger():
        """Configure and return the module logger (idempotent).

        The original implementation attached a fresh StreamHandler on every
        call, so calling it twice duplicated every log record; the handler is
        now only added when the logger has none yet.

        Returns:
            logging.Logger: the configured 'datai' logger.
        """
        logger = logging.getLogger(LOGO)
        logger.setLevel(logging.INFO)
        # keep records from bubbling up to the root logger (avoids double output)
        logger.propagate = False

        if not logger.handlers:
            formatter = logging.Formatter(fmt='[DataI]%(asctime)s - %(levelname)s - %(message)s')
            handler = logging.StreamHandler()
            handler.setFormatter(formatter)
            logger.addHandler(handler)
        return logger

    @staticmethod
    def get_modelarts_logger():
        """Return the already-configured module logger."""
        return logging.getLogger(LOGO)
-
-
# Module-wide logger instance used by every class and function below.
log = Logger.setup_modelarts_logger()
-
-
class Device:
    """One Ascend NPU device entry taken from a rank table."""

    def __init__(self, device_id, device_ip, rank_id):
        # rank_id is populated for rank-table v1 entries; '' when unknown
        self.device_id, self.device_ip, self.rank_id = device_id, device_ip, rank_id
-
-
class Instance:
    """A server (pod) together with its list of Device objects."""

    def __init__(self, pod_name, server_id, devices):
        self.pod_name = pod_name
        self.server_id = server_id
        self.devices = self.parse_devices(devices)

    @staticmethod
    def parse_devices(devices):
        """Turn raw rank-table device dicts into Device objects (empty rank id)."""
        if devices is None:
            return []
        return [Device(spec['device_id'], spec['device_ip'], '') for spec in devices]

    def set_devices(self, devices):
        """Replace the parsed device list wholesale."""
        self.devices = devices
-
-
class RankTable:
    """Base wrapper around a parsed HCCL rank-table JSON document."""

    STATUS_FIELD = 'status'
    COMPLETED_STATUS = 'completed'

    def __init__(self):
        self.rank_table_path = ""
        self.rank_table = {}

    @staticmethod
    def read_from_file(file_path):
        """Parse and return the JSON document stored at *file_path*."""
        with open(file_path) as json_file:
            return json.load(json_file)

    @staticmethod
    def convert_server_to_instance(server):
        """Build an Instance (with empty pod name) from one 'server_list' entry."""
        device_objects = [
            Device(device_id=entry['device_id'],
                   device_ip=entry['device_ip'],
                   rank_id=entry['rank_id'])
            for entry in server['device']
        ]
        ins = Instance(pod_name='', server_id=server['server_id'], devices=[])
        ins.set_devices(device_objects)
        return ins

    def get_rank_table_path(self):
        """Path of the file this table was loaded from ('' for the base class)."""
        return self.rank_table_path

    def get_server(self, server_id):
        """Return the server entry whose id matches *server_id*, else None (logged)."""
        for server in self.rank_table['server_list']:
            if server['server_id'] != server_id:
                continue
            log.info('Current server')
            log.info('\n' + json.dumps(server, indent=4))
            return server

        log.error('server [%s] is not found' % server_id)
        return None
-
-
def get_current_host_ip():
    """Return the host IP injected by ModelArts, or None when not set."""
    return os.environ.get('MA_CURRENT_HOST_IP', None)
-
-
class RankTableV1(RankTable):
    """Rank table in the v1 JSON layout (flat 'server_list')."""

    def __init__(self, rank_table_path):
        super().__init__()
        self.rank_table_path = rank_table_path
        self.rank_table = self.read_from_file(file_path=rank_table_path)

    def get_current_instance(self):
        """Locate this host's server entry and return it as an Instance.

        A single-server table is unambiguous. A multi-server table is matched
        against MA_CURRENT_HOST_IP; when that env var is absent the first
        entry is taken. Returns None (and logs an error) when nothing matches.
        """
        server_list = self.rank_table['server_list']
        current_server = None
        if len(server_list) == 1:
            current_server = server_list[0]
        elif len(server_list) > 1:
            host_ip = get_current_host_ip()
            if host_ip is None:
                current_server = server_list[0]
            else:
                for candidate in server_list:
                    if candidate['server_id'] == host_ip:
                        current_server = candidate
                        break

        if current_server is None:
            log.error('server is not found')
            return None
        return self.convert_server_to_instance(current_server)

    def get_device_num(self):
        """Total number of devices across every server in the table."""
        return sum(len(server['device']) for server in self.rank_table['server_list'])
-
-
class DistributedRuntimeError(RuntimeError):
    """Raised when a distributed training process exits with a non-zero code."""
-
-
# PIDs of spawned training children that the SIGCHLD handler should reap.
_registered_child_ids = []
-
-
class SigHandler:
    """Installs a SIGCHLD handler that reaps the registered training children."""

    @staticmethod
    def register_sig_child_handler():
        """Install SigHandler.wait_child as this process's SIGCHLD handler."""
        signal.signal(signal.SIGCHLD, SigHandler.wait_child)

    @staticmethod
    def register_wait_child(child_pid):
        """Track *child_pid* so wait_child() will reap it (no duplicates)."""
        if child_pid not in _registered_child_ids:
            _registered_child_ids.append(child_pid)

    @staticmethod
    def wait_child(signum, frame):
        """SIGCHLD handler: non-blocking reap of every registered child.

        The try/except now sits INSIDE the loop: previously it wrapped the
        whole loop, so the first already-reaped child (ECHILD) aborted the
        iteration and silently skipped waiting on all remaining pids.
        """
        for child_pid in _registered_child_ids:
            try:
                os.waitpid(child_pid, os.WNOHANG)
            except OSError as e:
                # ECHILD: this child was already reaped; keep checking the rest
                if e.errno != errno.ECHILD:
                    raise
-
-
class AscendVersionManager:
    """Inspects the Ascend driver version file to detect the platform generation."""

    driver_version_file_path = '/usr/local/Ascend/driver/version.info'

    # driver version string that identifies the Atlas C75-TR5 generation
    c75_tr5_driver_version = 'Version=20.1.0'

    @staticmethod
    def test_driver_version_file_exists():
        """Return True when the driver version file is present on this host."""
        return os.path.isfile(AscendVersionManager.driver_version_file_path)

    @staticmethod
    def print_ascend_driver_version():
        """Log the first line of the driver version file, or 'Unknown'.

        Only the first line matters; the unreachable trailing `return` from
        the original has been dropped, and the deprecated `log.warn` alias
        replaced with `log.warning`.
        """
        if not AscendVersionManager.test_driver_version_file_exists():
            log.warning('there is no %s file' % AscendVersionManager.driver_version_file_path)
            log.info('Ascend Driver: Unknown')
            return

        with open(AscendVersionManager.driver_version_file_path) as version_file:
            for line in version_file:
                log.info('Ascend Driver: %s' % line.strip())
                # we only take the first line into account
                return

    @staticmethod
    def is_atlas_c75_tr5():
        """Return True when the installed driver matches the C75-TR5 version."""
        if not AscendVersionManager.test_driver_version_file_exists():
            return False

        with open(AscendVersionManager.driver_version_file_path) as version_file:
            for line in version_file:
                if line.strip() == AscendVersionManager.c75_tr5_driver_version:
                    return True
        # file exists but no line matched (also covers an empty file)
        return False
-
-
def tail(filename, msg, pid):
    """Follow *filename* (`tail -f`) until process *pid* exits.

    Each emitted line is printed prefixed with *msg*. Runs inside a
    TailManager worker thread.
    """
    stream = sh.tail("-f", "--pid", pid, filename, _iter=True)
    for text in stream:
        print(f'{msg}: {text}', end='')
-
-
class TailManager:
    """Runs tail() in worker threads and periodically reaps finished ones.

    Live threads are tracked in an asyncio Queue that is serviced by the
    clean_inactivate() coroutine running on the tornado IOLoop.
    """

    # queue of live tail threads awaiting cleanup
    threads = Queue()
    # whether the clean_inactivate() reaper loop has been scheduled yet
    _started = False

    @classmethod
    def start_tail(cls, file_path, msg, pid):
        """Spawn a thread tailing *file_path* and ensure the reaper is running."""
        thread = threading.Thread(target=tail, args=(file_path, msg, pid))
        cls.start_thread(thread)
        cls.start_clean_inactivate()

    @classmethod
    def start_thread(cls, thread):
        """Start *thread*, then enqueue it from the IOLoop's own thread.

        Queue.put is scheduled via add_callback because asyncio queues are not
        thread-safe; the put must execute on the event-loop thread.
        """
        thread.start()
        ioloop.IOLoop.current().add_callback(cls.threads.put, thread)

    @classmethod
    def start_clean_inactivate(cls):
        """Schedule the clean_inactivate() coroutine exactly once per process."""
        if cls._started:
            return
        ioloop.IOLoop.current().add_callback(cls.clean_inactivate)
        cls._started = True

    @classmethod
    async def clean_inactivate(cls, interval=2):
        """Every *interval* seconds, join dead tail threads; re-queue live ones."""
        while True:
            # drain the queue once; live threads go straight back in
            for _ in range(cls.threads.qsize()):
                t: threading.Thread = await cls.threads.get()
                if t.is_alive():
                    await cls.threads.put(t)
                else:
                    t.join()
            await asyncio.sleep(interval)
-
-
class LogRecorder(object):
    """Maps training-process pids to the log files their output is tee'd into."""

    # pid -> path of the user-visible log file
    pid_log_path = dict()

    @classmethod
    def record_pid_log_path(cls, pid, file_path):
        """Remember where *pid*'s output is being written."""
        cls.pid_log_path[pid] = file_path

    @classmethod
    def get_log_from_pid(cls, pid):
        """Return the recorded log content for *pid*, or a placeholder string."""
        file_path = cls.pid_log_path.get(pid)
        if not file_path:
            return 'No log file found'
        with open(file_path) as f:
            return f.read()
-
-
def get_job_id():
    """Return the training job id.

    Prefers BATCH_JOB_ID, then MA_VJ_NAME (with its 'ma-job' prefix rewritten
    to 'modelarts-job'), and finally falls back to the hostname.
    """
    batch_job_id = os.environ.get('BATCH_JOB_ID')
    if batch_job_id is not None:
        return batch_job_id

    ma_vj_name = os.environ.get('MA_VJ_NAME')
    if ma_vj_name is not None:
        return ma_vj_name.replace('ma-job', 'modelarts-job', 1)

    return socket.gethostname()
-
-
def is_in_notebook():
    """Best-effort check for whether we are running inside a notebook kernel."""
    try:
        import ipykernel
        ipykernel.get_connection_info()
    # Temporary fix for #84
    # TODO: remove blanket Exception catching after fixing #84
    except Exception:
        return False
    else:
        return True
-
-
class FMK:
    """Launches one training process bound to a single Ascend device.

    Holds the per-process identity (job id, rank id, device id) and builds the
    environment block the framework (MindSpore/TensorFlow/PyTorch) expects.
    """

    def __init__(self, c75_tr5, index, device):
        # True when running on an Atlas C75-TR5 (or earlier) driver
        self.c75_tr5 = c75_tr5

        self.job_id = get_job_id()
        self.rank_id = device.rank_id
        if not c75_tr5:
            # logical device id after c75-tr5
            # notably, mindspore needs the logical device id in c75-tr5 and after
            self.device_id = str(index)
        else:
            # physical device id in c75-tr5 (and before)
            self.device_id = device.device_id

    def gen_env_for_fmk(self, rank_size):
        """Return a copy of os.environ extended with this rank's variables."""
        current_envs = os.environ.copy()
        current_envs['JOB_ID'] = self.job_id

        if not self.c75_tr5:
            # export ASCEND_DEVICE_ID as the logical device id after c75-tr5
            current_envs['ASCEND_DEVICE_ID'] = self.device_id
        # the DEVICE_ID env will be deprecated, keep it in order to be compatible with moxing and mindspore
        # physical device id in c75-tr5 (non mindspore)
        # logical device id after c75-tr5
        current_envs['DEVICE_ID'] = self.device_id

        current_envs['RANK_ID'] = self.rank_id
        current_envs['RANK_SIZE'] = str(rank_size)

        FMK.set_env_if_not_exist(current_envs, 'HCCL_CONNECT_TIMEOUT', str(1800)) # 30min

        self.gen_diag_mode_env(current_envs)

        return current_envs

    def gen_diag_mode_env(self, current_envs):
        """Populate *current_envs* with diagnostic/log settings.

        Behaviour is keyed off MA_DIAG_MODE_ENV ('faults' | 'accuracy' |
        'profile') and MA_RUN_MODE_ENV ('performance' | 'normal'); every
        variable is only set when the user has not already exported it.
        """
        log_dir = FMK.get_log_dir()
        process_log_path = os.path.join(log_dir, self.job_id, 'ascend', 'process_log', 'rank_' + self.rank_id)
        FMK.set_env_if_not_exist(current_envs, 'ASCEND_PROCESS_LOG_PATH', process_log_path)
        pathlib.Path(current_envs['ASCEND_PROCESS_LOG_PATH']).mkdir(parents=True, exist_ok=True)
        diag_mode = current_envs.get('MA_DIAG_MODE_ENV', '')
        run_mode = current_envs.get('MA_RUN_MODE_ENV', '')
        engine_version = current_envs.get("MA_ENGINE_VERSION", '')
        # mindspore writes its glog / RDR / OM data into one shared directory
        glog_dir = ms_rdr_path = ms_om_path = os.path.join(log_dir, self.job_id, 'mindspore', 'log')
        if diag_mode == 'faults':
            FMK.set_env_if_not_exist(current_envs, 'PRINT_MODEL', str(1))
            FMK.set_env_if_not_exist(current_envs, 'DUMP_GE_GRAPH', str(2))
            FMK.set_env_if_not_exist(current_envs, 'DUMP_GRAPH_LEVEL', str(2))
            FMK.set_env_if_not_exist(current_envs, 'ASCEND_GLOBAL_LOG_LEVEL', str(1))
            FMK.set_env_if_not_exist(current_envs, 'ASCEND_HOST_LOG_FILE_NUM', str(1000))

            npu_collect_path = os.path.join(log_dir, self.job_id, 'ascend', 'npu_collect', 'rank_' + self.rank_id)
            FMK.set_env_if_not_exist(current_envs, 'NPU_COLLECT_PATH', npu_collect_path)
            pathlib.Path(os.path.join(current_envs['NPU_COLLECT_PATH'], 'extra-info', 'graph')).mkdir(parents=True,
                                                                                                     exist_ok=True)

            # e.g. MA_ENGINE_VERSION='mindspore_1.5-<build>' -> 'mindspore_1.5' / '1.5'
            framework_name_version = next(iter(engine_version.split('-')[0:1]), '')
            framework_version = next(iter(framework_name_version.split('_')[1:2]), '')
            # NOTE(review): lexicographic compare — '1.10' would sort below '1.4';
            # verify the engine version scheme never reaches two-digit minors
            if 'mindspore' in framework_name_version and '1.4' <= framework_version:
                FMK.set_env_if_not_exist(current_envs, 'GLOG_v', str(1))
                FMK.set_env_if_not_exist(current_envs, 'GLOG_log_dir', glog_dir)
                FMK.set_env_if_not_exist(current_envs, 'GLOG_logtostderr', str(0))
                FMK.set_env_if_not_exist(current_envs, 'MS_RDR_ENABLE', str(1))
                FMK.set_env_if_not_exist(current_envs, 'MS_RDR_PATH', ms_rdr_path)
                FMK.set_env_if_not_exist(current_envs, 'MS_OM_PATH', ms_om_path)

        elif diag_mode == 'accuracy' or diag_mode == 'profile':
            diag_data_path = os.path.join(log_dir, self.job_id, 'mindspore', 'diagnostic_data')
            FMK.set_env_if_not_exist(current_envs, 'MS_DIAGNOSTIC_DATA_PATH', diag_data_path)

        elif run_mode == 'performance':
            FMK.set_env_if_not_exist(current_envs, 'ASCEND_GLOBAL_LOG_LEVEL', str(3))
            FMK.set_env_if_not_exist(current_envs, 'ASCEND_GLOBAL_EVENT_LEVEL', str(0))
            FMK.set_env_if_not_exist(current_envs, 'GLOG_v', str(3))
            FMK.set_env_if_not_exist(current_envs, 'GLOG_log_dir', glog_dir)
            FMK.set_env_if_not_exist(current_envs, 'GLOG_logtostderr', str(0))
            FMK.set_env_if_not_exist(current_envs, 'MS_OM_PATH', ms_om_path)

        elif run_mode == 'normal':
            FMK.set_env_if_not_exist(current_envs, 'GLOG_v', str(1))
            FMK.set_env_if_not_exist(current_envs, 'GLOG_log_dir', glog_dir)
            FMK.set_env_if_not_exist(current_envs, 'GLOG_logtostderr', str(0))
            FMK.set_env_if_not_exist(current_envs, 'MS_OM_PATH', ms_om_path)

    @contextmanager
    def switch_directory(self, directory):
        """Context manager: chdir into *directory*, restore the old cwd on exit."""
        owd = os.getcwd()
        try:
            os.chdir(directory)
            yield directory
        finally:
            os.chdir(owd)

    @staticmethod
    def get_log_dir():
        """Directory receiving the tee'd per-rank log copies."""
        return '/tmp/logdir'

    @staticmethod
    def set_env_if_not_exist(envs, env_name, env_value):
        """Set *env_name* in *envs* unless the user already exported it."""
        if env_name in os.environ:
            log.info('env already exists. env_name: %s, env_value: %s ' % (env_name, env_value))
            return
        envs[env_name] = env_value

    def run(self, rank_size, command, work_dir, user_log_dir, *, output_notebook):
        """Spawn *command* for this rank and return its Popen handle.

        On c75-tr5 the child inherits stdout/stderr directly; on newer drivers
        stdout+stderr are piped through `tee` into both the system log dir and
        *user_log_dir*, and optionally tailed back into the notebook.
        """
        envs = self.gen_env_for_fmk(rank_size)
        log.info('bootstrap proc-rank-%s-device-%s' % (self.rank_id, self.device_id))

        working_dir = work_dir
        if not os.path.exists(working_dir):
            os.makedirs(working_dir)

        log_dir = FMK.get_log_dir()
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)

        if not os.path.exists(user_log_dir):
            os.makedirs(user_log_dir)

        if self.c75_tr5:
            with self.switch_directory(working_dir):
                return subprocess.Popen(command, env=envs, preexec_fn=os.setsid)

        # we `tee` a proc log of each training processes after c75-tr5

        # AOM collect (*.trace | *.log | *.out) log file
        # let log_file end with .txt, avoid AOM collect it
        log_file = '%s-proc-rank-%s-device-%s.txt' % (self.job_id, self.rank_id, self.device_id)
        log_file_path = os.path.join(log_dir, log_file)
        user_log_file_path = os.path.join(user_log_dir, log_file)

        with self.switch_directory(working_dir):
            # os.setsid: change the process(forked) group id to itself
            training_proc = subprocess.Popen(command, env=envs, preexec_fn=os.setsid,
                                             stdout=subprocess.PIPE,
                                             stderr=subprocess.STDOUT,
                                             )

            log.info('proc-rank-%s-device-%s (pid: %d)', self.rank_id, self.device_id, training_proc.pid)

            # https://docs.python.org/3/library/subprocess.html#subprocess.Popen.wait
            # modelarts_pipe_cmd should consume the stdout in time and avoid proc deadlock
            # and currently, we use `tee` instead of `modelarts-pipe`, as the modelarts-pipe requires singleton
            # TODO: limit the splitting log file size < 1GB
            subprocess.Popen(
                ['tee', log_file_path, user_log_file_path],
                stdin=training_proc.stdout,
            )
            LogRecorder.record_pid_log_path(training_proc.pid, user_log_file_path)
            if output_notebook and is_in_notebook():
                msg = f'proc-rank-{self.rank_id}-device-{self.device_id} (pid: {training_proc.pid})'
                TailManager.start_tail(user_log_file_path, msg, training_proc.pid)

            return training_proc
-
-
class FMKManager:
    """Spawns, monitors and tears down one FMK training process per device."""

    # max destroy time: ~20 (15 + 5)
    # ~ 15 (1 + 2 + 4 + 8): exponential back-off while polling for exit
    MAX_TEST_PROC_CNT = 4
    KILL_WAIT_TIME = 5

    # whether the SIGCHLD reaper has been installed for this process
    _registered = False

    @classmethod
    def _register(cls):
        # Only need to register once
        if cls._registered:
            return
        SigHandler.register_sig_child_handler()
        cls._registered = True

    def __init__(self, instance):
        # Instance describing this server and its devices
        self.instance = instance
        # parallel lists: fmk[i] describes fmk_processes[i]
        self.fmk = []
        self.fmk_processes = []
        # set to True by the SIGTERM handler to break the monitor loop
        self.get_sigterm = False
        self._register()

    # break the monitor and destroy processes when a terminate signal arrives
    def term_handle(func):
        """Decorator: install a temporary SIGTERM handler around *func*.

        The handler reaches into the decorated call's stack frame to flip
        self.get_sigterm, which monitor() polls between sleep periods.
        """
        def receive_term(signum, stack):
            log.info('Received terminate signal %d, try to destroyed all processes' % signum)
            stack.f_locals['self'].get_sigterm = True

        def handle_func(self, *args, **kwargs):
            # swap in receive_term for the duration of the call, then restore
            origin_handle = signal.getsignal(signal.SIGTERM)
            signal.signal(signal.SIGTERM, receive_term)
            res = func(self, *args, **kwargs)
            signal.signal(signal.SIGTERM, origin_handle)
            return res

        return handle_func

    def run(self, rank_size, command, work_dir, log_dir, *, output_notebook=False):
        """Start *command* once per device of this instance via FMK.run()."""
        c75_tr5_flag = AscendVersionManager.is_atlas_c75_tr5()
        for index, device in enumerate(self.instance.devices):
            fmk_instance = FMK(c75_tr5_flag, index, device)
            self.fmk.append(fmk_instance)

            self.fmk_processes.append(
                fmk_instance.run(rank_size, command, work_dir, log_dir, output_notebook=output_notebook))

    @term_handle
    def monitor(self, period=1, raise_exception=True):
        """Poll the children every *period* seconds until they all exit.

        Returns 0 when every process exited cleanly (or SIGTERM was received);
        otherwise returns — or raises DistributedRuntimeError for — the first
        non-zero return code observed.
        """
        # busy waiting for all fmk processes exit by zero
        # or there is one process exit by non-zero

        fmk_cnt = len(self.fmk_processes)
        zero_ret_cnt = 0
        while zero_ret_cnt != fmk_cnt:
            zero_ret_cnt = 0
            for index in range(fmk_cnt):
                fmk = self.fmk[index]
                fmk_process = self.fmk_processes[index]
                if fmk_process.poll() is not None:
                    if fmk_process.returncode != 0:
                        log.error('proc-rank-%s-device-%s (pid: %d) has exited with non-zero code: %d'
                                  % (fmk.rank_id, fmk.device_id, fmk_process.pid, fmk_process.returncode))
                        # only works when start by output_notebook=True
                        err_log = LogRecorder.get_log_from_pid(fmk_process.pid)
                        if raise_exception:
                            raise DistributedRuntimeError('\n' + err_log)
                        return fmk_process.returncode

                    zero_ret_cnt += 1
            if self.get_sigterm:
                break
            time.sleep(period)

        return 0

    def destroy(self, base_period=1):
        """SIGTERM every child group, then wait (and SIGKILL stragglers)."""
        log.info('Begin destroy training processes')
        self.send_sigterm_to_fmk_process()
        self.wait_fmk_process_end(base_period)
        log.info('End destroy training processes')

    def send_sigterm_to_fmk_process(self):
        """Signal every tracked process group; drop already-exited entries."""
        # send SIGTERM to fmk processes (and process group)
        # iterate in reverse so `del` does not shift pending indices
        for r_index in range(len(self.fmk_processes) - 1, -1, -1):
            fmk = self.fmk[r_index]
            fmk_process = self.fmk_processes[r_index]
            if fmk_process.poll() is not None:
                log.info('proc-rank-%s-device-%s (pid: %d) has exited', fmk.rank_id, fmk.device_id, fmk_process.pid)
                del self.fmk_processes[r_index]
                del self.fmk[r_index]

            try:
                os.killpg(fmk_process.pid, signal.SIGTERM)
            except ProcessLookupError:
                # the whole group is already gone
                pass

    def wait_fmk_process_end(self, base_period):
        """Poll with exponential back-off, then SIGKILL whatever remains."""
        test_cnt = 0
        period = base_period
        while len(self.fmk_processes) > 0 and test_cnt < self.MAX_TEST_PROC_CNT:
            for r_index in range(len(self.fmk_processes) - 1, -1, -1):
                fmk = self.fmk[r_index]
                fmk_process = self.fmk_processes[r_index]
                if fmk_process.poll() is not None:
                    log.info('proc-rank-%s-device-%s (pid: %d) has exited',
                             fmk.rank_id, fmk.device_id, fmk_process.pid)
                    del self.fmk_processes[r_index]
                    del self.fmk[r_index]

            time.sleep(period)
            period *= 2
            test_cnt += 1

        if len(self.fmk_processes) > 0:
            for r_index in range(len(self.fmk_processes) - 1, -1, -1):
                fmk = self.fmk[r_index]
                fmk_process = self.fmk_processes[r_index]
                if fmk_process.poll() is None:
                    # NOTE(review): log.warn is a deprecated alias of log.warning
                    log.warn('proc-rank-%s-device-%s (pid: %d) has not exited within the max waiting time, '
                             'send kill signal',
                             fmk.rank_id, fmk.device_id, fmk_process.pid)
                    os.killpg(fmk_process.pid, signal.SIGKILL)

    def wait(self, destroy_when_finished=True, raise_exception=True):
        """Monitor until completion; optionally destroy children afterwards.

        Returns monitor()'s return code on success; exceptions from monitor()
        propagate after the (optional) destroy in the finally block.
        """
        try:
            return_code = self.monitor(raise_exception=raise_exception)
        except DistributedRuntimeError:
            log.error('Running distributed work error, throw exception...')
            raise
        except Exception:
            log.error('Unknown Error occurred...')
            raise
        finally:
            if destroy_when_finished:
                self.destroy()
        return return_code
-
-
def get_rank_table():
    """Build a RankTableV1 from the path in the RANK_TABLE_FILE env var.

    Raises:
        RuntimeError: when the environment variable is not set.
    """
    rank_table_path = os.environ.get(RANK_TABLE_FILE)
    if rank_table_path is None:
        raise RuntimeError('No environment variable RANK_TABLE_FILE, try generate_rank_table() first.')
    return RankTableV1(rank_table_path)
-
-
def start_distributed_train(command, work_dir='./', log_dir='./log', *, output_notebook=False):
    """Launch *command* once per NPU device of the current server.

    Returns the FMKManager tracking the spawned processes; pass it to
    wait_distributed_train() (or call .wait()) to monitor them.
    """
    rank_table = get_rank_table()
    instance = rank_table.get_current_instance()
    server_entry = rank_table.get_server(instance.server_id)
    manager = FMKManager(RankTable.convert_server_to_instance(server_entry))
    manager.run(rank_table.get_device_num(), command, work_dir, log_dir,
                output_notebook=output_notebook)
    return manager
-
-
def wait_distributed_train(fmk_manager, destroy_when_finished=True, raise_exception=True):
    """Block until every training process exits; see FMKManager.wait()."""
    result = fmk_manager.wait(destroy_when_finished, raise_exception)
    return result
|