|
- import torch.nn as nn
- from typing import List, Tuple
-
-
class SharedMLP(nn.Sequential):
    """Chain of ``Conv2d`` layers whose channel widths are listed in ``args``.

    Layer ``k`` maps ``args[k]`` channels to ``args[k + 1]`` channels and is
    registered under ``name + "layer{k}"``, so ``len(args) - 1`` layers are
    created in total.

    Args:
        args: Channel sizes, input first.
        bn: Whether the layers use batch normalisation.
        activation: Activation module handed to every layer (the same
            instance is shared by all of them).
        preact: Build pre-activation layers.
        first: Combined with ``preact``, layer 0 is built with neither batch
            norm nor an activation.
        name: Prefix for the registered layer names.
        instance_norm: Forwarded unchanged to every ``Conv2d``.
    """

    def __init__(
        self,
        args: List[int],
        *,
        bn: bool = False,
        activation=nn.ReLU(inplace=True),
        preact: bool = False,
        first: bool = False,
        name: str = "",
        instance_norm: bool = False,
    ):
        super().__init__()

        for idx, (c_in, c_out) in enumerate(zip(args, args[1:])):
            # Only the very first layer of a pre-activated "first" block is
            # stripped of its norm and activation.
            bare = first and preact and idx == 0
            key = name + "layer{}".format(idx)
            layer = Conv2d(
                c_in,
                c_out,
                bn=(not bare) and bn,
                activation=None if bare else activation,
                preact=preact,
                instance_norm=instance_norm,
            )
            self.add_module(key, layer)
-
-
- class _ConvBase(nn.Sequential):
-
- def __init__(
- self,
- in_size,
- out_size,
- kernel_size,
- stride,
- padding,
- activation,
- bn,
- init,
- conv=None,
- batch_norm=None,
- bias=True,
- preact=False,
- name="",
- instance_norm=False,
- instance_norm_func=None
- ):
- super().__init__()
-
- bias = bias and (not bn)
- conv_unit = conv(
- in_size,
- out_size,
- kernel_size=kernel_size,
- stride=stride,
- padding=padding,
- bias=bias
- )
- init(conv_unit.weight)
- if bias:
- nn.init.constant_(conv_unit.bias, 0)
-
- if bn:
- if not preact:
- bn_unit = batch_norm(out_size)
- else:
- bn_unit = batch_norm(in_size)
- if instance_norm:
- if not preact:
- in_unit = instance_norm_func(out_size, affine=False, track_running_stats=False)
- else:
- in_unit = instance_norm_func(in_size, affine=False, track_running_stats=False)
-
- if preact:
- if bn:
- self.add_module(name + 'bn', bn_unit)
-
- if activation is not None:
- self.add_module(name + 'activation', activation)
-
- if not bn and instance_norm:
- self.add_module(name + 'in', in_unit)
-
- self.add_module(name + 'conv', conv_unit)
-
- if not preact:
- if bn:
- self.add_module(name + 'bn', bn_unit)
-
- if activation is not None:
- self.add_module(name + 'activation', activation)
-
- if not bn and instance_norm:
- self.add_module(name + 'in', in_unit)
-
-
- class _BNBase(nn.Sequential):
-
- def __init__(self, in_size, batch_norm=None, name=""):
- super().__init__()
- self.add_module(name + "bn", batch_norm(in_size))
-
- nn.init.constant_(self[0].weight, 1.0)
- nn.init.constant_(self[0].bias, 0)
-
-
class BatchNorm1d(_BNBase):
    """``_BNBase`` container around a single ``nn.BatchNorm1d``.

    Args:
        in_size: Number of channels to normalise.
        name: Keyword-only prefix for the registered sub-module name.
    """

    def __init__(self, in_size: int, *, name: str = ""):
        super().__init__(in_size, nn.BatchNorm1d, name)
-
-
class BatchNorm2d(_BNBase):
    """``_BNBase`` container around a single ``nn.BatchNorm2d``.

    Args:
        in_size: Number of channels to normalise.
        name: Prefix for the registered sub-module name (may be positional).
    """

    def __init__(self, in_size: int, name: str = ""):
        super().__init__(in_size, nn.BatchNorm2d, name)
-
-
class Conv1d(_ConvBase):
    """``_ConvBase`` specialised to 1-D layers.

    The layer uses ``nn.Conv1d`` for the convolution, this module's
    ``BatchNorm1d`` wrapper for batch norm, and ``nn.InstanceNorm1d`` for
    instance norm. All other options are forwarded unchanged to ``_ConvBase``.
    The default activation is a single ``nn.ReLU(inplace=True)`` instance
    created at class-definition time.
    """

    def __init__(
        self,
        in_size: int,
        out_size: int,
        *,
        kernel_size: int = 1,
        stride: int = 1,
        padding: int = 0,
        activation=nn.ReLU(inplace=True),
        bn: bool = False,
        init=nn.init.kaiming_normal_,
        bias: bool = True,
        preact: bool = False,
        name: str = "",
        instance_norm=False
    ):
        super().__init__(
            in_size=in_size,
            out_size=out_size,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            activation=activation,
            bn=bn,
            init=init,
            conv=nn.Conv1d,
            batch_norm=BatchNorm1d,
            instance_norm_func=nn.InstanceNorm1d,
            bias=bias,
            preact=preact,
            name=name,
            instance_norm=instance_norm,
        )
-
-
class Conv2d(_ConvBase):
    """``_ConvBase`` specialised to 2-D layers.

    The layer uses ``nn.Conv2d`` for the convolution, this module's
    ``BatchNorm2d`` wrapper for batch norm, and ``nn.InstanceNorm2d`` for
    instance norm. The kernel, stride and padding default to a 1x1
    convolution with no padding. The default activation is a single
    ``nn.ReLU(inplace=True)`` instance created at class-definition time.
    """

    def __init__(
        self,
        in_size: int,
        out_size: int,
        *,
        kernel_size: Tuple[int, int] = (1, 1),
        stride: Tuple[int, int] = (1, 1),
        padding: Tuple[int, int] = (0, 0),
        activation=nn.ReLU(inplace=True),
        bn: bool = False,
        init=nn.init.kaiming_normal_,
        bias: bool = True,
        preact: bool = False,
        name: str = "",
        instance_norm=False
    ):
        super().__init__(
            in_size=in_size,
            out_size=out_size,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            activation=activation,
            bn=bn,
            init=init,
            conv=nn.Conv2d,
            batch_norm=BatchNorm2d,
            instance_norm_func=nn.InstanceNorm2d,
            bias=bias,
            preact=preact,
            name=name,
            instance_norm=instance_norm,
        )
-
-
class FC(nn.Sequential):
    """Fully connected layer with optional batch norm and activation.

    Registered order (each name prefixed with ``name``):
      * ``preact=False``: ``fc`` -> [``bn``] -> [``activation``]
      * ``preact=True``:  [``bn``] -> [``activation``] -> ``fc``

    Args:
        in_size: Input features of the ``nn.Linear``.
        out_size: Output features of the ``nn.Linear``.
        activation: Module to register, or ``None`` for no activation.
        bn: Add this module's ``BatchNorm1d`` wrapper (sized ``in_size`` when
            pre-activating, ``out_size`` otherwise) and drop the linear bias.
        init: Optional callable applied in place to the linear weight.
        preact: Place the norm/activation before the linear layer.
        name: Prefix for every registered sub-module name.
    """

    def __init__(
        self,
        in_size: int,
        out_size: int,
        *,
        activation=nn.ReLU(inplace=True),
        bn: bool = False,
        init=None,
        preact: bool = False,
        name: str = ""
    ):
        super().__init__()

        # The bias is dropped whenever batch norm is enabled (as in _ConvBase).
        fc = nn.Linear(in_size, out_size, bias=not bn)
        if init is not None:
            init(fc.weight)
        if not bn:
            # ``nn.init.constant`` is a deprecated alias that warns on every
            # call; use the in-place ``constant_`` like the conv blocks do.
            nn.init.constant_(fc.bias, 0)

        if preact:
            if bn:
                self.add_module(name + 'bn', BatchNorm1d(in_size))
            if activation is not None:
                self.add_module(name + 'activation', activation)

        self.add_module(name + 'fc', fc)

        if not preact:
            if bn:
                self.add_module(name + 'bn', BatchNorm1d(out_size))
            if activation is not None:
                self.add_module(name + 'activation', activation)
|