|
- # -*- coding: utf-8 -*-
- # File : comm.py
- # Author : Jiayuan Mao
- # Email : maojiayuan@gmail.com
- # Date : 27/01/2018
- #
- # This file is part of Synchronized-BatchNorm-PyTorch.
- # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
- # Distributed under MIT License.
-
- import queue
- import collections
- import threading
-
- __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
-
-
class FutureResult(object):
    """A thread-safe future implementation. Used only as one-to-one pipe.

    ``None`` is used as the "no result stored" marker, so a ``None`` payload
    cannot be distinguished from an empty future.
    """

    def __init__(self):
        self._result = None
        self._lock = threading.Lock()
        # The condition shares ``_lock`` so that ``put`` and ``get`` serialize
        # on the same mutex.
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        """Store ``result`` and notify a thread waiting in ``get``.

        Args:
            result: the value to hand over to the consumer.

        Raises:
            AssertionError: if the previously stored result has not been
                fetched yet.
        """
        with self._lock:
            assert self._result is None, 'Previous result has\'t been fetched.'
            self._result = result
            self._cond.notify()

    def get(self):
        """Block until a result has been stored, then return it.

        The stored result is cleared before returning, so the future can be
        reused for the next ``put``.

        Returns: the value passed to the matching ``put`` call.
        """
        with self._lock:
            # Re-check the predicate after every wake-up instead of trusting a
            # single ``wait``: the condition that prompted the notification
            # may no longer hold by the time this thread re-acquires the lock.
            while self._result is None:
                self._cond.wait()

            res = self._result
            self._result = None
            return res
-
-
# Registry entry stored by the master for each registered slave; ``result`` is
# the object the master calls ``put`` on to deliver that slave's reply.
_MasterRegistry = collections.namedtuple(
    typename='MasterRegistry',
    field_names=('result',),
)

# Tuple layout underlying ``SlavePipe``: the slave's identifier, the queue it
# sends messages on, and the object it calls ``get`` on to read the reply.
_SlavePipeBase = collections.namedtuple(
    typename='_SlavePipeBase',
    field_names=('identifier', 'queue', 'result'),
)
-
-
class SlavePipe(_SlavePipeBase):
    """Slave-side endpoint of the master-slave communication channel."""

    def run_slave(self, msg):
        """Send ``msg`` to the master and wait for the reply.

        Args:
            msg: the payload to send; it is queued together with this pipe's
                identifier.

        Returns: whatever ``self.result.get()`` yields.
        """
        tagged_msg = (self.identifier, msg)
        self.queue.put(tagged_msg)

        reply = self.result.get()

        # Acknowledge on the same queue that the reply has been fetched.
        self.queue.put(True)
        return reply
-
-
class SyncMaster(object):
    """An abstract `SyncMaster` object.

    - During the replication, as the data parallel will trigger a callback of each module, all slave devices should
      call `register_slave(id)` and obtain a `SlavePipe` to communicate with the master.
    - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected,
      and passed to a registered callback.
    - After receiving the messages, the master device should gather the information and determine the message to be
      passed back to each slave device.
    """

    def __init__(self, master_callback):
        """

        Args:
            master_callback: a callback to be invoked after having collected messages from slave devices.
        """
        self._master_callback = master_callback
        self._queue = queue.Queue()
        self._registry = collections.OrderedDict()
        self._activated = False

    def __getstate__(self):
        """Return the picklable state: only the callback is kept."""
        return {'master_callback': self._master_callback}

    def __setstate__(self, state):
        """Restore from `state` by re-running `__init__`, which creates a fresh queue and an empty registry."""
        self.__init__(state['master_callback'])

    def register_slave(self, identifier):
        """
        Register a slave device.

        Args:
            identifier: an identifier, usually is the device id.

        Returns: a `SlavePipe` object which can be used to communicate with the master device.

        """
        if self._activated:
            # `run_master` has been called since the last registration round, so this is the start of a
            # new round: drop the previous registrations.
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()
        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg):
        """
        Main entry for the master device in each forward pass.
        The messages are first collected from each device (including the master device), and then
        a callback will be invoked to compute the message to be sent back to each device
        (including the master device).

        Args:
            master_msg: the message that the master wants to send to itself. This will be placed as the first
            message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.

        Returns: the message to be sent back to the master device.

        """
        self._activated = True

        # The master's own message always comes first, tagged with identifier 0.
        intermediates = [(0, master_msg)]
        for _ in range(self.nr_slaves):
            intermediates.append(self._queue.get())

        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belongs to the master.'

        for i, res in results:
            if i == 0:
                continue
            self._registry[i].result.put(res)

        for _ in range(self.nr_slaves):
            # Drain one acknowledgement per slave. The `get()` must stay outside the `assert`:
            # with `python -O` assert statements are removed, and the acknowledgements would
            # otherwise remain in the queue and be read as slave messages on the next call.
            ack = self._queue.get()
            assert ack is True

        return results[0][1]

    @property
    def nr_slaves(self):
        """Number of currently registered slaves."""
        return len(self._registry)
|