Model Quantization in Federated Learning
KXiong edited this page 1 year ago

Background

Compressing models with quantization improves federated learning training efficiency and relieves communication pressure, since clients upload and download model parameters every round. Quantization techniques fall into two categories: post-training quantization and quantization-aware training. Below, quantization-aware training is used to compress the network model in federated learning, shrinking the model size and reducing memory consumption while maintaining model accuracy.
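
As a rough illustration of the communication savings, the sketch below estimates the per-round upload payload before and after 8-bit quantization. This is a back-of-the-envelope calculation only: the parameter count is computed for the LeNet-5 variant defined later on this page (convolutions without bias), and the 4-byte/1-byte sizes assume float32 weights quantized to int8.

```python
# Rough estimate of a client's per-round upload size in federated learning,
# before and after 8-bit quantization (float32 -> int8).

def lenet5_param_count(num_class=10, num_channel=1):
    """Parameter count of the LeNet-5 variant defined on this page."""
    conv1 = 6 * num_channel * 5 * 5          # Conv2d, has_bias=False
    conv2 = 16 * 6 * 5 * 5                   # Conv2d, has_bias=False
    fc1 = 16 * 5 * 5 * 120 + 120             # Dense layers include a bias
    fc2 = 120 * 84 + 84
    fc3 = 84 * num_class + num_class
    return conv1 + conv2 + fc1 + fc2 + fc3

params = lenet5_param_count()
fp32_bytes = params * 4   # 4 bytes per float32 weight
int8_bytes = params * 1   # 1 byte per int8 weight

print(params)                    # total trainable parameters
print(fp32_bytes / int8_bytes)   # → 4.0 (payload shrinks 4x per round)
```

Even for a tiny model like LeNet-5, every round of every client saves 3/4 of the weight traffic, which is why quantization is attractive in the federated setting.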

Quantization-Aware Training

The following uses the simple LeNet-5 network and the MNIST dataset as an example, applying the SimQAT algorithm for quantization-aware training. The LeNet network is defined as follows:

import mindspore.nn as nn
from mindspore.common.initializer import Normal


class LeNet5(nn.Cell):
    """LeNet-5: two convolutional blocks followed by three dense layers."""

    def __init__(self, num_class=10, num_channel=1, include_top=True):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.include_top = include_top
        if self.include_top:
            self.flatten = nn.Flatten()
            self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
            self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
            self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))

    def construct(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        if not self.include_top:
            return x  # return feature maps only, when used as a backbone
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x
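
The SimQAT pass that produces the structure printed below comes from the MindSpore Golden Stick (`mindspore_gs`) package. A minimal usage sketch follows; the module path and defaults are given as I recall them from the Golden Stick documentation and should be checked against the installed version, and `net` is assumed to be the `LeNet5` instance defined above:

```python
from mindspore_gs.quantization.simulated_quantization import (
    SimulatedQuantizationAwareTraining as SimQAT,
)

net = LeNet5(num_class=10)

algo = SimQAT()            # defaults match the printed structure: 8-bit,
                           # per-channel symmetric weights, quant_delay=900
qat_net = algo.apply(net)  # rewrites the graph, wrapping quantizable layers
                           # in QuantizeWrapperCell with fake-quant nodes
print(qat_net)
```

The converted network `qat_net` is then trained in place of the original `net`; the fake-quant nodes only activate after `quant_delay` steps, letting the float network warm up first.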

After inserting the fake-quantization nodes, the network structure is as follows:

LeNet5Opt<
  (_handler): LeNet5<
    (conv1): Conv2d<input_channels=1, output_channels=6, kernel_size=(5, 5), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
    (conv2): Conv2d<input_channels=6, output_channels=16, kernel_size=(5, 5), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
    (relu): ReLU<>
    (max_pool2d): MaxPool2d<kernel_size=2, stride=2, pad_mode=VALID>
    (flatten): Flatten<>
    (fc1): Dense<input_channels=400, output_channels=120, has_bias=True>
    (fc2): Dense<input_channels=120, output_channels=84, has_bias=True>
    (fc3): Dense<input_channels=84, output_channels=10, has_bias=True>
    >
  (max_pool2d): MaxPool2d<kernel_size=2, stride=2, pad_mode=VALID>
  (flatten): Flatten<>
  (fc1): Dense<input_channels=400, output_channels=120, has_bias=True>
  (fc2): Dense<input_channels=120, output_channels=84, has_bias=True>
  (fc3): Dense<input_channels=84, output_channels=10, has_bias=True>
  (DenseQuant): QuantizeWrapperCell<
    (_handler): DenseQuant<
      in_channels=84, out_channels=10, weight=Parameter (name=DenseQuant._handler.weight, shape=(10, 84), dtype=Float32, requires_grad=True), has_bias=True, bias=Parameter (name=DenseQuant._handler.bias, shape=(10,), dtype=Float32, requires_grad=True)
      (fake_quant_weight): SimulatedFakeQuantizerPerChannel<bit_num=8, symmetric=True, narrow_range=False, ema=False(0.999), per_channel=True(0, 10), quant_delay=900>
      >
    (_input_quantizer): SimulatedFakeQuantizerPerLayer<bit_num=8, symmetric=False, narrow_range=False, ema=False(0.999), per_channel=False, quant_delay=900>
    (_output_quantizer): SimulatedFakeQuantizerPerLayer<bit_num=8, symmetric=False, narrow_range=False, ema=False(0.999), per_channel=False, quant_delay=900>
    >
  (Conv2dBnFoldQuant): QuantizeWrapperCell<
    (_handler): Conv2dBnFoldQuant<
      in_channels=1, out_channels=6, kernel_size=(5, 5), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, fake=True, freeze_bn=10000000, momentum=0.997
      (fake_quant_weight): SimulatedFakeQuantizerPerChannel<bit_num=8, symmetric=True, narrow_range=False, ema=False(0.999), per_channel=True(0, 6), quant_delay=900>
      (batchnorm_fold): BatchNormFoldCell<>
      >
    (_input_quantizer): SimulatedFakeQuantizerPerLayer<bit_num=8, symmetric=False, narrow_range=False, ema=False(0.999), per_channel=False, quant_delay=900>
    (_output_quantizer): SimulatedFakeQuantizerPerLayer<bit_num=8, symmetric=False, narrow_range=False, ema=False(0.999), per_channel=False, quant_delay=900>
    >
  (Conv2dBnFoldQuant_1): QuantizeWrapperCell<
    (_handler): Conv2dBnFoldQuant<
      in_channels=6, out_channels=16, kernel_size=(5, 5), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, fake=True, freeze_bn=10000000, momentum=0.997
      (fake_quant_weight): SimulatedFakeQuantizerPerChannel<bit_num=8, symmetric=True, narrow_range=False, ema=False(0.999), per_channel=True(0, 16), quant_delay=900>
      (batchnorm_fold): BatchNormFoldCell<>
      >
    (_input_quantizer): SimulatedFakeQuantizerPerLayer<bit_num=8, symmetric=False, narrow_range=False, ema=False(0.999), per_channel=False, quant_delay=900>
    (_output_quantizer): SimulatedFakeQuantizerPerLayer<bit_num=8, symmetric=False, narrow_range=False, ema=False(0.999), per_channel=False, quant_delay=900>
    >
  (DenseQuant_1): QuantizeWrapperCell<
    (_handler): DenseQuant<
      in_channels=400, out_channels=120, weight=Parameter (name=DenseQuant_1._handler.weight, shape=(120, 400), dtype=Float32, requires_grad=True), has_bias=True, bias=Parameter (name=DenseQuant_1._handler.bias, shape=(120,), dtype=Float32, requires_grad=True), activation=ReLU<>
      (activation): ReLU<>
      (fake_quant_weight): SimulatedFakeQuantizerPerChannel<bit_num=8, symmetric=True, narrow_range=False, ema=False(0.999), per_channel=True(0, 120), quant_delay=900>
      >
    (_input_quantizer): SimulatedFakeQuantizerPerLayer<bit_num=8, symmetric=False, narrow_range=False, ema=False(0.999), per_channel=False, quant_delay=900>
    (_output_quantizer): SimulatedFakeQuantizerPerLayer<bit_num=8, symmetric=False, narrow_range=False, ema=False(0.999), per_channel=False, quant_delay=900>
    >
  (DenseQuant_2): QuantizeWrapperCell<
    (_handler): DenseQuant<
      in_channels=120, out_channels=84, weight=Parameter (name=DenseQuant_2._handler.weight, shape=(84, 120), dtype=Float32, requires_grad=True), has_bias=True, bias=Parameter (name=DenseQuant_2._handler.bias, shape=(84,), dtype=Float32, requires_grad=True), activation=ReLU<>
      (activation): ReLU<>
      (fake_quant_weight): SimulatedFakeQuantizerPerChannel<bit_num=8, symmetric=True, narrow_range=False, ema=False(0.999), per_channel=True(0, 84), quant_delay=900>
      >
    (_input_quantizer): SimulatedFakeQuantizerPerLayer<bit_num=8, symmetric=False, narrow_range=False, ema=False(0.999), per_channel=False, quant_delay=900>
    (_output_quantizer): SimulatedFakeQuantizerPerLayer<bit_num=8, symmetric=False, narrow_range=False, ema=False(0.999), per_channel=False, quant_delay=900>
    >
  >
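
Each `SimulatedFakeQuantizer` node above simulates integer quantization during training: it quantizes its input and immediately dequantizes it, so the forward pass sees the quantization error while all tensors stay float32. A pure-Python sketch of the symmetric, non-narrow-range 8-bit weight scheme shown in the structure (`bit_num=8, symmetric=True, narrow_range=False`), independent of MindSpore:

```python
def fake_quant_symmetric(values, bit_num=8):
    """Quantize-dequantize a list of floats with a symmetric int8 scheme.

    symmetric=True: the range is [-max|v|, +max|v|], so 0.0 maps exactly to 0.
    narrow_range=False: the full signed range [-128, 127] is usable.
    """
    qmin, qmax = -(2 ** (bit_num - 1)), 2 ** (bit_num - 1) - 1  # -128, 127
    bound = max(abs(v) for v in values)
    if bound == 0.0:
        return list(values)
    scale = bound / qmax                 # one scale per tensor (or per channel)
    out = []
    for v in values:
        q = round(v / scale)             # quantize to an integer step
        q = min(max(q, qmin), qmax)      # clamp to the int8 range
        out.append(q * scale)            # dequantize back to float
    return out

weights = [0.50, -0.25, 0.1234, 0.0]
print(fake_quant_symmetric(weights))     # each value snaps to a multiple of scale
```

The per-channel variants in the structure apply this with a separate `scale` per output channel, which is why the printed nodes carry the channel count, e.g. `per_channel=True(0, 6)` for `conv1`.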

Training

The quantized network is then trained on the MNIST dataset:

============== Starting Training ==============
epoch: 1 step: 1875, loss is 0.1184234544634819
Train epoch time: 18381.949 ms, per step time: 9.804 ms
epoch: 2 step: 1875, loss is 0.005062167067080736
Train epoch time: 9518.341 ms, per step time: 5.076 ms
epoch: 3 step: 1875, loss is 0.004801809787750244
Train epoch time: 9568.458 ms, per step time: 5.103 ms
epoch: 4 step: 1875, loss is 0.032152824103832245
Train epoch time: 8535.799 ms, per step time: 4.552 ms
epoch: 5 step: 1875, loss is 0.004819052759557962
Train epoch time: 8222.206 ms, per step time: 4.385 ms
epoch: 6 step: 1875, loss is 0.0224038977175951
Train epoch time: 8426.765 ms, per step time: 4.494 ms
epoch: 7 step: 1875, loss is 0.11416034400463104
Train epoch time: 8182.984 ms, per step time: 4.364 ms
epoch: 8 step: 1875, loss is 0.21064913272857666
Train epoch time: 8471.022 ms, per step time: 4.518 ms
epoch: 9 step: 1875, loss is 0.006558818276971579
Train epoch time: 10166.347 ms, per step time: 5.422 ms
epoch: 10 step: 1875, loss is 0.0007507787086069584
Train epoch time: 9955.181 ms, per step time: 5.309 ms

Results

Evaluating the quantized weights gives the result below: the quantized model's accuracy (about 0.991) matches, and here even slightly exceeds, the full-precision model's accuracy of 0.983.

========== LeNet GRAPH mode accuracy: 0.9906850961538461 ==========