|
- #!/usr/bin/env python
- # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
- """
- A script to benchmark builtin models.
-
- Note: this script has an extra dependency of psutil.
- """
-
- import itertools
- import logging
- import psutil
- import torch
- import tqdm
- from fvcore.common.timer import Timer
- from torch.nn.parallel import DistributedDataParallel
-
- from detectron2.checkpoint import DetectionCheckpointer
- from detectron2.config import get_cfg
- from detectron2.data import (
- DatasetFromList,
- build_detection_test_loader,
- build_detection_train_loader,
- )
- from detectron2.engine import SimpleTrainer, default_argument_parser, hooks, launch
- from detectron2.modeling import build_model
- from detectron2.solver import build_optimizer
- from detectron2.utils import comm
- from detectron2.utils.events import CommonMetricPrinter
- from detectron2.utils.logger import setup_logger
-
# Shared logger for this script; it uses the "detectron2" logger name, the
# same name used by the detectron2 library's own messages.
logger = logging.getLogger(name="detectron2")
-
-
def setup(args):
    """Build the frozen config for a benchmark run and set up logging.

    The config file named by ``args.config_file`` is merged first. A small
    base learning rate is then forced. Finally the ``args.opts`` overrides are
    applied, so a user-supplied ``SOLVER.BASE_LR`` still wins. After freezing
    the config, the logger is configured for this process's distributed rank.

    Returns:
        The frozen config node.
    """
    config = get_cfg()
    config.merge_from_file(args.config_file)
    # Avoid NaNs. The learning rate itself is irrelevant to benchmarking.
    config.SOLVER.BASE_LR = 0.001
    config.merge_from_list(args.opts)
    config.freeze()

    rank = comm.get_rank()
    setup_logger(distributed_rank=rank)
    return config
-
-
def _log_data_throughput(itr, cfg, max_iter):
    """Pull ``max_iter`` batches from ``itr`` and log how long it took.

    The logged image count is ``max_iter * cfg.SOLVER.IMS_PER_BATCH``.
    """
    timer = Timer()
    for _ in tqdm.trange(max_iter):
        next(itr)
    logger.info(
        "{} iters ({} images) in {} seconds.".format(
            max_iter, max_iter * cfg.SOLVER.IMS_PER_BATCH, timer.seconds()
        )
    )


def benchmark_data(args):
    """Benchmark the training data loader on its own, without any model.

    The following are logged:
    - the time to construct the loader;
    - the time until the first batch is produced ("startup time");
    - host memory usage as reported by ``psutil.virtual_memory()``;
    - the time to fetch 1000 batches, measured for 11 rounds in total
      (one round before the startup/RAM report, ten after).
    """
    cfg = setup(args)

    timer = Timer()
    dataloader = build_detection_train_loader(cfg)
    logger.info("Initialize loader using {} seconds.".format(timer.seconds()))

    timer.reset()
    itr = iter(dataloader)
    for i in range(10):  # warmup
        next(itr)
        if i == 0:
            # Time from creating the iterator until the first batch arrives.
            startup_time = timer.seconds()

    max_iter = 1000
    _log_data_throughput(itr, cfg, max_iter)
    logger.info("Startup time: {} seconds".format(startup_time))

    # psutil.virtual_memory() reports system (host) memory, not GPU memory.
    ram = psutil.virtual_memory()
    logger.info(
        "RAM Usage: {:.2f}/{:.2f} GB".format(
            (ram.total - ram.available) / 1024 ** 3, ram.total / 1024 ** 3
        )
    )

    # test for a few more rounds
    for _ in range(10):
        _log_data_throughput(itr, cfg, max_iter)
-
-
def benchmark_train(args):
    """Measure training iteration speed on a fixed, in-memory set of batches.

    Steps:
    1. Build the model from the config. When the world size is greater than
       one, wrap it in ``DistributedDataParallel``.
    2. Load ``cfg.MODEL.WEIGHTS`` (optimizer state included) through
       ``DetectionCheckpointer``.
    3. Cache the first 100 batches of the training loader, built with
       ``NUM_WORKERS = 0``.
    4. Run a ``SimpleTrainer`` for ``max_iter`` iterations on an endless cycle
       over those cached batches. ``IterationTimer`` and
       ``CommonMetricPrinter`` hooks report the timing.
    """
    cfg = setup(args)
    model = build_model(cfg)
    logger.info("Model:\n{}".format(model))
    if comm.get_world_size() > 1:
        model = DistributedDataParallel(
            model,
            device_ids=[comm.get_local_rank()],
            broadcast_buffers=False,
            find_unused_parameters=True,
        )
    optimizer = build_optimizer(cfg, model)
    checkpointer = DetectionCheckpointer(model, optimizer=optimizer)
    checkpointer.load(cfg.MODEL.WEIGHTS)

    cfg.defrost()
    # Load the sample batches in the main process; only 100 are needed.
    cfg.DATALOADER.NUM_WORKERS = 0
    data_loader = build_detection_train_loader(cfg)
    dummy_data = list(itertools.islice(data_loader, 100))

    def f():
        # Cycle over the cached batches forever. Presumably this keeps
        # data-loading cost out of the measured training speed.
        data = DatasetFromList(dummy_data, copy=False)
        while True:
            yield from data

    max_iter = 400
    trainer = SimpleTrainer(model, f(), optimizer)
    trainer.register_hooks(
        [hooks.IterationTimer(), hooks.PeriodicWriter([CommonMetricPrinter(max_iter)])]
    )
    trainer.train(1, max_iter)
-
-
@torch.no_grad()
def benchmark_eval(args):
    """Time inference over a small cached sample of test batches.

    Steps:
    1. Build the model in eval mode and load ``cfg.MODEL.WEIGHTS``.
    2. Cache the first 100 batches of ``cfg.DATASETS.TEST[0]``, loaded with
       ``NUM_WORKERS = 0``.
    3. Run 5 warmup forward passes on the first cached batch.
    4. Log the wall-clock time for 400 forward passes, cycling over the
       cached batches.

    Gradient tracking is disabled by the ``torch.no_grad`` decorator.
    """
    cfg = setup(args)
    net = build_model(cfg)
    net.eval()
    logger.info("Model:\n{}".format(net))
    DetectionCheckpointer(net).load(cfg.MODEL.WEIGHTS)

    cfg.defrost()
    cfg.DATALOADER.NUM_WORKERS = 0
    test_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])
    cached_batches = list(itertools.islice(test_loader, 100))

    def endless_batches():
        # Repeat the cached batches forever; the timing loop decides when
        # to stop.
        while True:
            yield from DatasetFromList(cached_batches, copy=False)

    warmup_iters = 5
    for _ in range(warmup_iters):
        net(cached_batches[0])

    max_iter = 400
    timer = Timer()
    with tqdm.tqdm(total=max_iter) as progress:
        for step, batch in enumerate(endless_batches()):
            if step == max_iter:
                break
            net(batch)
            progress.update()
    elapsed = timer.seconds()
    logger.info("{} iters in {} seconds.".format(max_iter, elapsed))
-
-
if __name__ == "__main__":
    parser = default_argument_parser()
    parser.add_argument("--task", choices=["train", "eval", "data"], required=True)
    args = parser.parse_args()
    # Validate with parser.error rather than assert. Asserts are stripped
    # under `python -O`, whereas parser.error prints usage and exits with
    # status 2.
    if args.eval_only:
        parser.error("--eval-only is not supported by this script; use --task eval.")

    if args.task == "data":
        f = benchmark_data
    elif args.task == "train":
        # Note: training speed may not be representative.
        # The training cost of an R-CNN model varies with the content of the
        # data and the quality of the model.
        f = benchmark_train
    elif args.task == "eval":
        f = benchmark_eval
        # only benchmark single-GPU inference.
        if args.num_gpus != 1 or args.num_machines != 1:
            parser.error("--task eval requires --num-gpus 1 and --num-machines 1.")
    launch(f, args.num_gpus, args.num_machines, args.machine_rank, args.dist_url, args=(args,))
|