|
- import torch
-
- from .layers import QuantizedColumnParallelLinear
- from .layers import QuantizedRowParallelLinear
-
-
- def quantize(model, weight_bit_width):
- """Replace fp16 linear with quantized linear"""
-
- if torch.distributed.get_rank() == 0:
- print(f"> Quantizing model weight to {weight_bit_width} bits")
-
- for layer in model.transformer.layers:
- layer.attention.query_key_value = QuantizedColumnParallelLinear(
- weight_bit_width=weight_bit_width,
- weight=layer.attention.query_key_value.weight.to(torch.cuda.current_device()),
- input_size=layer.attention.query_key_value.input_size,
- output_size=layer.attention.query_key_value.output_size,
- bias=True,
- gather_output=False,
- params_dtype=torch.half,
- name="query_key_value",
- skip_init=True,
- device=layer.attention.query_key_value.weight.device,
- )
- layer.attention.dense = QuantizedRowParallelLinear(
- weight_bit_width=weight_bit_width,
- weight=layer.attention.dense.weight.to(torch.cuda.current_device()),
- input_size=layer.attention.dense.input_size,
- output_size=layer.attention.dense.output_size,
- bias=True,
- input_is_parallel=True,
- params_dtype=torch.half,
- name="dense",
- skip_init=True,
- device=layer.attention.dense.weight.device,
- )
- layer.mlp.dense_h_to_4h = QuantizedColumnParallelLinear(
- weight_bit_width=weight_bit_width,
- weight=layer.mlp.dense_h_to_4h.weight.to(torch.cuda.current_device()),
- input_size=layer.mlp.dense_h_to_4h.input_size,
- output_size=layer.mlp.dense_h_to_4h.output_size,
- bias=True,
- gather_output=False,
- params_dtype=torch.half,
- name="dense_h_to_4h",
- skip_init=True,
- device=layer.mlp.dense_h_to_4h.weight.device,
- )
- layer.mlp.dense_4h_to_h = QuantizedRowParallelLinear(
- weight_bit_width=weight_bit_width,
- weight=layer.mlp.dense_4h_to_h.weight.to(torch.cuda.current_device()),
- input_size=layer.mlp.dense_4h_to_h.input_size,
- output_size=layer.mlp.dense_4h_to_h.output_size,
- bias=True,
- input_is_parallel=True,
- params_dtype=torch.half,
- name="dense_h_to_4h",
- skip_init=True,
- device=layer.mlp.dense_4h_to_h.weight.device,
- )
-
- return model
|