|
# Short aliases accepted by get_config() for each canonical variant name.
_MODE_ALIASES = {"xxs": "xx_small", "xs": "x_small", "s": "small"}

# Per-variant hyper-parameters, keyed by canonical mode name:
#   (expansion ratio shared by every stage,
#    (layer1 out_channels, layer2 out_channels),
#    ((out_channels, transformer_channels, ffn_dim) for layer3, layer4, layer5))
_VARIANT_SPECS = {
    "xx_small": (2, (16, 24), ((48, 64, 128), (64, 80, 160), (80, 96, 192))),
    "x_small": (4, (32, 48), ((64, 96, 192), (80, 120, 240), (96, 144, 288))),
    "small": (4, (32, 64), ((96, 144, 288), (128, 192, 384), (160, 240, 480))),
}

# (layer name, transformer_blocks) for the "mobilevit" stages; the depths are
# identical for every variant. These stages were annotated with feature-map
# sizes 28x28 / 14x14 / 7x7, which presumably assume a 224x224 input -- TODO
# confirm against the model builder that consumes this config.
_MOBILEVIT_LAYERS = (("layer3", 2), ("layer4", 4), ("layer5", 3))

# Regularisation settings applied to every stage dict.
_LAYER_DROPOUTS = {"dropout": 0.1, "ffn_dropout": 0.0, "attn_dropout": 0.0}


def _mv2_layer(out_channels: int, expand_ratio: int, num_blocks: int, stride: int) -> dict:
    """Return a fresh config dict for one stage of block_type "mv2"."""
    return {
        "out_channels": out_channels,
        "expand_ratio": expand_ratio,
        "num_blocks": num_blocks,
        "stride": stride,
        "block_type": "mv2",
    }


def _mobilevit_layer(
    out_channels: int,
    transformer_channels: int,
    ffn_dim: int,
    transformer_blocks: int,
    mv_expand_ratio: int,
) -> dict:
    """Return a fresh config dict for one stage of block_type "mobilevit".

    Patch size (2x2), stride (2) and head count (4) are the same for every
    variant, so they are fixed here.
    """
    return {
        "out_channels": out_channels,
        "transformer_channels": transformer_channels,
        "ffn_dim": ffn_dim,
        "transformer_blocks": transformer_blocks,
        "patch_h": 2,
        "patch_w": 2,
        "stride": 2,
        "mv_expand_ratio": mv_expand_ratio,
        "num_heads": 4,
        "block_type": "mobilevit",
    }


def get_config(mode: str = "xxs") -> dict:
    """Return the stage-wise hyper-parameter dict for one model variant.

    Args:
        mode: Variant name -- "xx_small", "x_small" or "small" -- or one of
            the short aliases "xxs", "xs", "s". The default "xxs" selects
            "xx_small".

    Returns:
        A newly built dict (safe for the caller to mutate) with keys
        "layer1" .. "layer5", "last_layer_exp_factor" (4) and "cls_dropout"
        (0.1). "layer1"/"layer2" are stage dicts with block_type "mv2";
        "layer3" .. "layer5" have block_type "mobilevit". Every stage dict
        also carries "dropout" 0.1, "ffn_dropout" 0.0 and "attn_dropout" 0.0.

    Raises:
        NotImplementedError: If ``mode`` is not a supported name or alias.
    """
    # Resolve aliases; non-string modes can never match a variant name.
    canonical = _MODE_ALIASES.get(mode, mode) if isinstance(mode, str) else None
    if canonical not in _VARIANT_SPECS:
        supported = sorted(set(_VARIANT_SPECS) | set(_MODE_ALIASES))
        raise NotImplementedError(
            f"Unsupported mode {mode!r}; expected one of {supported}"
        )

    exp_mult, (layer1_channels, layer2_channels), vit_specs = _VARIANT_SPECS[canonical]

    config = {
        "layer1": _mv2_layer(layer1_channels, exp_mult, num_blocks=1, stride=1),
        "layer2": _mv2_layer(layer2_channels, exp_mult, num_blocks=3, stride=2),
    }
    for (layer_name, depth), (out_ch, tr_ch, ffn_dim) in zip(_MOBILEVIT_LAYERS, vit_specs):
        config[layer_name] = _mobilevit_layer(out_ch, tr_ch, ffn_dim, depth, exp_mult)
    config["last_layer_exp_factor"] = 4
    config["cls_dropout"] = 0.1

    for layer_name in ("layer1", "layer2", "layer3", "layer4", "layer5"):
        config[layer_name].update(_LAYER_DROPOUTS)

    return config
|