|
- # Copyright (c) Facebook, Inc. and its affiliates.
- #
- # This source code is licensed under the MIT license found in the
- # LICENSE file in the root directory of this source tree.
-
- import logging
-
- from fairseq.modules.quantization import pq, quantization_options, scalar
- from omegaconf import DictConfig
-
-
- logger = logging.getLogger(__name__)
-
-
def quantize_model_scalar(model, model_cfg: DictConfig):
    """Optionally apply scalar quantization noise to ``model`` in place.

    Reads ``quant_noise_scalar`` from ``model_cfg``; a missing attribute or a
    ``None``/falsy value is treated as 0 (disabled).  When the value is
    positive, ``scalar.quantize_model_`` mutates the model in place with
    8-bit quantization noise at that probability.  The (possibly mutated)
    model object is returned either way.
    """
    noise_prob = getattr(model_cfg, "quant_noise_scalar", 0) or 0
    if noise_prob <= 0:
        # quantization noise disabled -- leave the model untouched
        return model
    # scalar.quantize_model_ edits the model in place
    scalar.quantize_model_(model, p=noise_prob, bits=8, update_step=1000)
    return model
-
-
class Quantizer(object):
    """Schedules iterative Product Quantization (iPQ) of a model during training.

    The total training budget (exactly one of ``max_epoch`` or ``max_update``)
    is split into ``len(layers_to_quantize)`` equal intervals.  At the start
    of each interval, :meth:`step` quantizes the next group of layers in place
    via ``pq.quantize_model_`` and reinitializes the Trainer so the optimizer
    is rebuilt around the modified parameters.
    """

    def __init__(self, config_path, max_epoch, max_update):
        """Parse the quantization config and compute the quantization schedule.

        Args:
            config_path: optional path to a YAML quantization config; when
                falsy, ``parse_config_yaml({})`` supplies the defaults.
            max_epoch: total training epochs (> 0 to schedule by epoch).
            max_update: total training updates (> 0 to schedule by update).

        Raises:
            ImportError: if PyYAML is not installed.
            AssertionError: if neither or both of ``max_epoch``/``max_update``
                are positive, or the positive one is not evenly divisible by
                the number of quantization iterations.
        """
        try:
            import yaml
        except ImportError:
            # The importable module is "yaml", but the PyPI distribution
            # that provides it is PyYAML.
            raise ImportError("Please install yaml with: pip install pyyaml")

        # parse config
        if config_path:
            with open(config_path) as config_file:
                config = quantization_options.parse_config_yaml(
                    yaml.safe_load(config_file)
                )
        else:
            config = quantization_options.parse_config_yaml({})

        self.n_centroids_config = config["n_centroids"]
        self.block_sizes_config = config["block_sizes"]
        self.layers_to_quantize = config["layers_to_quantize"]

        # We assume that training will run for a fixed number of epochs
        # (or updates) and that we should train for equal durations
        # between iterations of PQ.
        num_iterations = len(self.layers_to_quantize)
        if max_epoch > 0:
            assert max_epoch % num_iterations == 0, (
                "for iterative PQ, --max-epoch (={}) must be evenly divisible by "
                "len(layers_to_quantize) (={})".format(max_epoch, num_iterations)
            )
            self.epoch_schedule = max_epoch // num_iterations
        else:
            self.epoch_schedule = None
        if max_update > 0:
            assert max_update % num_iterations == 0, (
                "for iterative PQ, --max-update (={}) must be evenly divisible by "
                "len(layers_to_quantize) (={})".format(max_update, num_iterations)
            )
            self.update_schedule = max_update // num_iterations
        else:
            self.update_schedule = None
        # XOR: exactly one schedule must be active -- both-set and
        # neither-set are rejected.
        assert (self.epoch_schedule is not None) ^ (
            self.update_schedule is not None
        ), "for iterative PQ, must specify exactly one of --max-update and --max-epoch"

        # 0 is a special value for quantization step, which will force
        # the first call to begin_epoch() to call step()
        self.quantization_step = 0

    def set_trainer(self, trainer):
        """Attach the Trainer whose model will be iteratively quantized."""
        self.trainer = trainer
        # tracks compressed/uncompressed model size across PQ iterations
        self.size_tracker = pq.SizeTracker(self.trainer.get_model())

    def step(self):
        """Move to the next stage of quantization."""
        if self.quantization_step >= len(self.layers_to_quantize):
            # Maybe we just finished the last training step or we loaded
            # a checkpoint for an iterative PQ model which previously
            # finished training. Either way, don't quantize again.
            return

        logger.info(
            "quantizing model (step={}; layers_to_quantize[step]={})".format(
                self.quantization_step, self.layers_to_quantize[self.quantization_step]
            )
        )
        quantized_layers = pq.quantize_model_(
            self.trainer.get_model(),
            self.size_tracker,
            self.layers_to_quantize,
            self.block_sizes_config,
            self.n_centroids_config,
            step=self.quantization_step,
        )
        logger.info("quantized layers: {}".format(quantized_layers))
        logger.info(self.size_tracker)

        self.quantization_step += 1

        # reinitialize the Trainer since model parameters have changed
        self.trainer.reinitialize()

    def begin_epoch(self, epoch):
        """Called at the beginning of each epoch (epochs start at 1)."""
        if (
            (
                self.epoch_schedule is not None
                and epoch > 0
                and (epoch - 1) % self.epoch_schedule == 0
            )
            # we always step once in the beginning, even if using
            # update-based quantization
            or self.quantization_step == 0
        ):
            self.step()

    def step_update(self, num_updates):
        """Called at the end of each step."""
        if (
            self.update_schedule is not None
            and num_updates > 0
            and num_updates % self.update_schedule == 0
        ):
            self.step()

    def state_dict(self):
        """Return the quantizer state for checkpointing."""
        return {
            "n_centroids_config": self.n_centroids_config,
            "block_sizes_config": self.block_sizes_config,
            "layers_to_quantize": self.layers_to_quantize,
            "epoch_schedule": self.epoch_schedule,
            "update_schedule": self.update_schedule,
            "quantization_step": self.quantization_step,
        }

    def load_state_dict(self, state_dict):
        """Restore the quantizer state saved by :meth:`state_dict`."""
        self.n_centroids_config = state_dict["n_centroids_config"]
        self.block_sizes_config = state_dict["block_sizes_config"]
        self.layers_to_quantize = state_dict["layers_to_quantize"]
        self.epoch_schedule = state_dict["epoch_schedule"]
        self.update_schedule = state_dict["update_schedule"]
        self.quantization_step = state_dict["quantization_step"]
|