|
- import os
- import traceback
- from multiprocessing import Queue, Process
-
-
def chunked_worker(worker_id, map_func, args, results_queue=None, init_ctx_func=None):
    """Run ``map_func`` over this worker's share of jobs and report every result.

    Args:
        worker_id: Index of this worker; passed to ``init_ctx_func``.
        map_func: Called as ``map_func(*arg)``, or ``map_func(*arg, ctx=ctx)``
            when ``init_ctx_func`` returned a non-None context.
        args: Iterable of ``(job_idx, arg)`` pairs, where ``arg`` is unpacked
            into ``map_func``.
        results_queue: Object with a ``put`` method; receives exactly one
            ``(job_idx, result)`` tuple per job, in the order of ``args``.
            Required despite the ``None`` default (kept for compatibility).
        init_ctx_func: Optional callable ``init_ctx_func(worker_id)`` run once
            before any job to build a per-worker context.

    Raises:
        ValueError: If ``results_queue`` is not provided.
    """
    if results_queue is None:
        # Fail before doing any work; previously this crashed with an
        # AttributeError only after init_ctx_func and the first job had run.
        raise ValueError('chunked_worker requires a results_queue to report into')
    ctx = init_ctx_func(worker_id) if init_ctx_func is not None else None
    for job_idx, arg in args:
        try:
            if ctx is not None:
                res = map_func(*arg, ctx=ctx)
            else:
                res = map_func(*arg)
        except Exception:
            # One failing job must not kill the worker, and the consumer expects
            # exactly one result per job_idx, so report None and keep going.
            # (KeyboardInterrupt/SystemExit are deliberately not swallowed.)
            traceback.print_exc()
            res = None
        results_queue.put((job_idx, res))
-
- def chunked_multiprocess_run(map_func, args, num_workers=None, ordered=True, init_ctx_func=None, q_max_size=1000):
- args = zip(range(len(args)), args)
- args = list(args)
- n_jobs = len(args)
- if num_workers is None:
- num_workers = int(os.getenv('N_PROC', os.cpu_count()))
- results_queues = []
- if ordered:
- for i in range(num_workers):
- results_queues.append(Queue(maxsize=q_max_size // num_workers))
- else:
- results_queue = Queue(maxsize=q_max_size)
- for i in range(num_workers):
- results_queues.append(results_queue)
- workers = []
- for i in range(num_workers):
- args_worker = args[i::num_workers]
- p = Process(target=chunked_worker, args=(
- i, map_func, args_worker, results_queues[i], init_ctx_func), daemon=True)
- workers.append(p)
- p.start()
- for n_finished in range(n_jobs):
- results_queue = results_queues[n_finished % num_workers]
- job_idx, res = results_queue.get()
- assert job_idx == n_finished or not ordered, (job_idx, n_finished)
- yield res
- for w in workers:
- w.join()
- w.close()
|