|
- #!/usr/bin/env python3
-
- # Copyright (c) Facebook, Inc. and its affiliates.
- #
- # This source code is licensed under the MIT license found in the
- # LICENSE file in the root directory of this source tree.
-
- """Execute various operations (train, test, time, etc.) on a classification model."""
-
- import argparse
- import sys
-
- import xcom.core.builders as builders
- import xcom.core.config as config
- import xcom.core.distributed as dist
- import xcom.core.net as net
- import xcom.core.trainer as trainer
- import xcom.models.scaler as scaler
- from xcom.core.config import cfg
- import tools.compress as s_prune
- import xcom.core.checkpoint as cp
- import xcom.datasets.loader as data_loader
- import xcom.core.meters as meters
- import xcom.core.optimizer as optim
- def parse_args():
- """Parse command line options (mode and config)."""
- parser = argparse.ArgumentParser(description="Run a model.")
- help_s, choices = "Run mode", ["info", "train", "test", "time", "scale"]
- parser.add_argument("--mode", help=help_s, choices=choices, required=False, type=str,default='train')
- parser.add_argument("--method",required=False, type=str,default='l1')
- parser.add_argument("--stayed_channels", required=False, type=str,default=[32, 32, 32, 128, 32, 32, 32, 32,64, 64, 256, 64,64, 64, 64, 64, 64, 128, 128, 512, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 256, 256, 1024, 256, 256, 256, 256])
- parser.add_argument("--tr_scratch", required=False, default=False)
- help_s = "Config file location"
- parser.add_argument("--cfg", help=help_s, required=False, type=str,default='configs/dds_baselines/imagenet/R-50-imagenet.yaml')
- help_s = "See pycls/core/config.py for all options"
- parser.add_argument("opts", help=help_s, default=None, nargs=argparse.REMAINDER)
- # if len(sys.argv) == 1:
- # parser.print_help()
- # sys.exit(1)
- return parser.parse_args()
-
-
def _run_info():
    """Print the configured model and its complexity ("info" mode)."""
    model = builders.get_model()()
    print(model)
    print("complexity:", net.complexity(builders.get_model()))


def _run_train(args, model_class):
    """Train the configured model, then load the last checkpoint and prune it.

    Args:
        args: parsed command line; ``method``, ``stayed_channels`` and
            ``tr_scratch`` are forwarded to ``s_prune.compress``.
        model_class: the class returned by ``builders.get_model()``.
    """
    # Complexity of the unpruned model, measured on a fresh instance moved to GPU.
    print("complexity:", net.complexity(model_class().cuda()))
    dist.multi_proc_run(num_proc=cfg.NUM_GPUS, fun=trainer.train_model)

    # Load the latest checkpoint first, then prune the restored model.
    # NOTE(review): this runs in the launching process after multi_proc_run
    # returns; confirm trainer.setup_model() is valid here when
    # cfg.NUM_GPUS > 1 (e.g. if it wraps the model for distributed training).
    model, _ = trainer.setup_model()
    checkpoint_file = cp.get_last_checkpoint()
    cp.load_checkpoint(checkpoint_file, model, None, None)

    # NOTE(review): the pruned model is only measured here; it is neither saved
    # nor returned. Confirm whether s_prune.compress persists it.
    pruned = s_prune.compress(model, args.method, args.stayed_channels, args.tr_scratch)
    pruned = pruned.cuda()
    print("complexity:", net.complexity(pruned))


def _run_scale():
    """Scale the configured model, dump the new config and report both complexities."""
    cfg.defrost()
    complexity_before = net.complexity(builders.get_model())
    scaler.scale_model()
    complexity_after = net.complexity(builders.get_model())
    dumped_path = config.dump_cfg()
    print("Scaled config dumped to:", dumped_path)
    print("Original model complexity:", complexity_before)
    print("Scaled model complexity:", complexity_after)


def main():
    """Parse the command line, load and freeze the config, then run the chosen mode.

    Modes: info, train (train, then prune the last checkpoint), test, time, scale.
    """
    args = parse_args()
    mode = args.mode
    config.load_cfg(args.cfg)
    cfg.merge_from_list(args.opts)
    config.assert_cfg()
    cfg.freeze()
    # Looked up for every mode, although only "train" consumes it.
    model_class = builders.get_model()

    handlers = {
        "info": _run_info,
        "train": lambda: _run_train(args, model_class),
        "test": lambda: dist.multi_proc_run(num_proc=cfg.NUM_GPUS, fun=trainer.test_model),
        "time": lambda: dist.multi_proc_run(num_proc=cfg.NUM_GPUS, fun=trainer.time_model),
        "scale": _run_scale,
    }
    run = handlers.get(mode)
    if run is not None:
        run()
-
-
if __name__ == "__main__":
    # Script entry point: run only when executed directly, not on import.
    main()
|