@@ -1,6 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="PublishConfigData" autoUpload="Always" serverName="root@region-3.autodl.com:38670" remoteFilesAllowedToDisappearOnAutoupload="false">
  <component name="PublishConfigData" autoUpload="Always" serverName="root@region-3.autodl.com:38670" remoteFilesAllowedToDisappearOnAutoupload="false" confirmBeforeUploading="false">
    <option name="confirmBeforeUploading" value="false" />
    <serverData>
      <paths name="root@region-3.autodl.com:38670">
        <serverdata>
@@ -9,6 +10,27 @@
          </mappings>
        </serverdata>
      </paths>
      <paths name="root@region-3.autodl.com:40829">
        <serverdata>
          <mappings>
            <mapping local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
      <paths name="root@region-3.autodl.com:40829 (2)">
        <serverdata>
          <mappings>
            <mapping local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
      <paths name="root@region-3.autodl.com:51822">
        <serverdata>
          <mappings>
            <mapping local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
    </serverData>
    <option name="myAutoUpload" value="ALWAYS" />
  </component>
@@ -2,7 +2,7 @@
<module type="PYTHON_MODULE" version="4">
  <component name="NewModuleRootManager">
    <content url="file://$MODULE_DIR$" />
    <orderEntry type="inheritedJdk" />
    <orderEntry type="jdk" jdkName="Remote Python 3.8.10 (sftp://root@region-3.autodl.com:38670/root/miniconda3/bin/python)" jdkType="Python SDK" />
    <orderEntry type="sourceFolder" forTests="false" />
  </component>
</module>
@@ -1,6 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="PublishConfigData" autoUpload="Always" serverName="root@region-3.autodl.com:51822" remoteFilesAllowedToDisappearOnAutoupload="false" confirmBeforeUploading="false" autoUploadExternalChanges="true">
  <component name="PublishConfigData" autoUpload="Always" remoteFilesAllowedToDisappearOnAutoupload="false" confirmBeforeUploading="false" autoUploadExternalChanges="true">
    <option name="confirmBeforeUploading" value="false" />
    <serverData>
      <paths name="root@region-3.autodl.com:51822">
@@ -1,4 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectRootManager" version="2" project-jdk-name="Remote Python 3.8.10 (sftp://root@region-3.autodl.com:51822/root/miniconda3/bin/python)" project-jdk-type="Python SDK" />
  <component name="ProjectRootManager" version="2" project-jdk-name="Remote Python 3.8.10 (sftp://root@region-3.autodl.com:38670/root/miniconda3/bin/python)" project-jdk-type="Python SDK" />
</project>
@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="VcsDirectoryMappings">
    <mapping directory="$PROJECT_DIR$/../.." vcs="Git" />
  </component>
</project>
@@ -6,14 +6,11 @@ import warnings
from tools.GetPipeline import get_dataset_cfg
import mmcv
import torch
from mmcv import Config, DictAction
from mmcv.cnn import fuse_conv_bn
from mmcv import Config
from mmcv.fileio.io import file_handlers
from mmcv.runner import get_dist_info, init_dist, load_checkpoint
from mmcv.runner.fp16_utils import wrap_fp16_model
os.chdir('mmaction2')
from mmcv.runner import get_dist_info, init_dist
from tools.Inference import turn_off_pretrained, inference_pytorch
from mmaction.datasets import build_dataloader, build_dataset
from mmaction.models import build_model
from mmaction.utils import (build_ddp, build_dp, default_device,
                            register_module_hooks, setup_multi_processes)
@@ -25,8 +22,6 @@ except (ImportError, ModuleNotFoundError):
        'collect_results_cpu, collect_results_gpu from mmaction2 will be '
        'deprecated. Please install mmcv through master branch.')
    from mmaction.apis import multi_gpu_test, single_gpu_test
os.chdir('..')
import yaml
@@ -45,52 +40,6 @@ def parse_args():
    return args
def turn_off_pretrained(cfg):
    # recursively find all pretrained in the model config,
    # and set them None to avoid redundant pretrain steps for testing
    if 'pretrained' in cfg:
        cfg.pretrained = None
    # recursively turn off pretrained value
    for sub_cfg in cfg.values():
        if isinstance(sub_cfg, dict):
            turn_off_pretrained(sub_cfg)
def inference_pytorch(checkpoints, cfg, distributed, data_loader):
    """Get predictions by pytorch models."""
    # remove redundant pretrain steps for testing
    turn_off_pretrained(cfg.model)
    # build the model and load checkpoint
    model = build_model(
        cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg'))
    if len(cfg.module_hooks) > 0:
        register_module_hooks(model, cfg.module_hooks)
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)
    load_checkpoint(model, checkpoints, map_location='cpu')
    if not distributed:
        model = build_dp(
            model, default_device, default_args=dict(device_ids=cfg.gpu_ids))
        outputs = single_gpu_test(model, data_loader)
    else:
        model = build_ddp(
            model,
            default_device,
            default_args=dict(
                device_ids=[int(os.environ['LOCAL_RANK'])],
                broadcast_buffers=False))
        outputs = multi_gpu_test(model, data_loader, 'tmp',
                                 True)
    return outputs
def main():
    args = parse_args()
    default_cfg = 'mmaction2/configs/TASK/swin.py'
@@ -116,10 +65,6 @@ def main():
    assert checkpoints is not None and os.path.exists(checkpoints), \
        'checkpoints not found in work dir %s, please check the work dir in config.yaml' % work_dir
    resume_path = os.path.join(cus_cfg['work_dir'], 'latest.pth')
    if os.path.exists(os.path.join(cus_cfg['work_dir'], 'latest.pth')):
        cfg.resume_from = resume_path
    # set multi-process settings
    setup_multi_processes(cfg)
@@ -15,14 +15,14 @@ from mmcv import Config, DictAction
from mmcv.runner import get_dist_info, init_dist, set_random_seed
from mmcv.utils import get_git_hash
os.chdir('mmaction2')
from mmaction import __version__
from mmaction.apis import init_random_seed, train_model
from mmaction.datasets import build_dataset
from mmaction.models import build_model
from mmaction.utils import (collect_env, get_root_logger,
                            register_module_hooks, setup_multi_processes)
os.chdir('..')
def parse_args():
@@ -0,0 +1,308 @@
import copy
import argparse
from .tools.GetPipeline import get_dataset_cfg
from .tools.Inference import inference_pytorch
import os
import os.path as osp
import time
import warnings
import mmcv
import torch
import yaml
import pickle
import torch.distributed as dist
from mmcv import Config, DictAction
from mmcv.runner import get_dist_info, init_dist, set_random_seed
from mmcv.utils import get_git_hash
from mmcv.fileio.io import file_handlers
from mmaction import __version__
from mmaction.apis import init_random_seed, train_model
from mmaction.datasets import build_dataset, build_dataloader
from mmaction.models import build_model
from mmaction.utils import (collect_env, get_root_logger,
                            register_module_hooks, setup_multi_processes)
class AutoXVideo():
    default_cfg = os.path.join(os.path.dirname(__file__), 'mmaction2/configs/TASK/swin.py')
    def __init__(self):
        self.cfg = Config.fromfile(self.default_cfg)
        if not os.path.exists(self.cfg.load_from):
            self.cfg.load_from = None
        # set cudnn_benchmark
        if self.cfg.get('cudnn_benchmark', False):
            torch.backends.cudnn.benchmark = True
        # The flag is used to determine whether it is omnisource training
        self.cfg.setdefault('omnisource', False)
        # The flag is used to register module's hooks
        self.cfg.setdefault('module_hooks', [])
        # default setting
        self.cfg.total_epochs = 50
        self.cfg.work_dir = 'work_dirs/default_work_dir'
        self.cfg.num_class = -1
        self.cfg.video_length = 2
    def read_cfg(self, cfg_file):
        with open(cfg_file, 'r') as cus_cfg:
            cus_cfg = yaml.load(cus_cfg, Loader=yaml.FullLoader)
        if 'epoch' in cus_cfg.keys():
            self.cfg.total_epochs = cus_cfg['epoch']
        if 'work_dir' in cus_cfg.keys():
            self.cfg.work_dir = cus_cfg['work_dir']
            resume_path = os.path.join(cus_cfg['work_dir'], 'latest.pth')
            if os.path.exists(os.path.join(cus_cfg['work_dir'], 'latest.pth')):
                self.cfg.resume_from = resume_path
        if 'data_root' in cus_cfg.keys():
            self.cfg.model.cls_head.num_classes = cus_cfg['num_class']
            self.cfg.num_class = cus_cfg['num_class']
            self.cfg.video_length = cus_cfg['video_length']
            self.cfg.data = get_dataset_cfg(
                data_root=cus_cfg['data_root'],
                ann_file_train=cus_cfg.get('ann_file_train', None),
                ann_file_val=cus_cfg.get('ann_file_val', None),
                ann_file_test=cus_cfg.get('ann_file_test', None),
                videos_per_gpu=cus_cfg.get('videos_per_gpu', 8),
                video_length=cus_cfg.get('video_length', 2))
    def fit(self,
            data_root=None,
            ann_file_train=None,
            ann_file_val=None,
            video_length=-1,
            num_class=-1,
            epoch=-1,
            videos_per_gpu=-1,
            evaluation=5,
            work_dir=None,
            distributed=False,
            gpus=1):
        if data_root is not None:
            # specify custom dataset
            assert ann_file_train is not None and video_length != -1 and num_class != -1, \
                'You need to specify the full dataset config via params or read_cfg().'
            self.cfg.model.cls_head.num_classes = num_class
            if ann_file_val is not None:
                flag_val = True
            else:
                flag_val = False
                ann_file_val = ann_file_train
            if self.cfg.num_class == -1 and videos_per_gpu == -1:
                videos_per_gpu = 1
            elif self.cfg.num_class != -1 and videos_per_gpu == -1:
                videos_per_gpu = self.cfg.data.videos_per_gpu
            self.cfg.data = get_dataset_cfg(
                data_root=data_root,
                ann_file_train=ann_file_train,
                ann_file_val=ann_file_val,
                ann_file_test=(self.cfg.get('data', {})).get('ann_file_test', None),
                videos_per_gpu=videos_per_gpu,
                video_length=video_length)
        if epoch != -1:
            self.cfg.total_epochs = epoch
        if work_dir is not None:
            self.cfg.work_dir = work_dir
        assert self.cfg.work_dir is not None, 'work dir not specified.'
        resume_path = os.path.join(self.cfg.work_dir, 'latest.pth')
        if os.path.exists(os.path.join(self.cfg.work_dir, 'latest.pth')):
            self.cfg.resume_from = resume_path
        if evaluation == 0:
            flag_val = False
        else:
            self.cfg.evaluation.interval = evaluation
            flag_val = True
        # set multi-process settings
        setup_multi_processes(self.cfg)
        # create work_dir
        mmcv.mkdir_or_exist(osp.abspath(self.cfg.work_dir))
        # init logger before other steps
        timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
        log_file = osp.join(self.cfg.work_dir, f'{timestamp}.log')
        logger = get_root_logger(log_file=log_file, log_level=self.cfg.log_level)
        # init the meta dict to record some important information such as
        # environment info and seed, which will be logged
        meta = dict()
        # log env info
        env_info_dict = collect_env()
        env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()])
        dash_line = '-' * 60 + '\n'
        logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                    dash_line)
        meta['env_info'] = env_info
        # log some basic info
        logger.info(f'Distributed training: {distributed}')
        logger.info(f'Config: {self.cfg.pretty_text}')
        # set random seeds
        seed = init_random_seed(42, distributed=distributed)
        logger.info(f'Set random seed to {seed}')
        set_random_seed(seed, deterministic=False)
        self.cfg.seed = seed
        meta['seed'] = seed
        meta['work_dir'] = osp.basename(self.cfg.work_dir.rstrip('/\\'))
        model = build_model(
            self.cfg.model,
            train_cfg=self.cfg.get('train_cfg'),
            test_cfg=self.cfg.get('test_cfg'))
        if len(self.cfg.module_hooks) > 0:
            register_module_hooks(model, self.cfg.module_hooks)
        if self.cfg.omnisource:
            # If omnisource flag is set, cfg.data.train should be a list
            assert isinstance(self.cfg.data.train, list)
            datasets = [build_dataset(dataset) for dataset in self.cfg.data.train]
        else:
            datasets = [build_dataset(self.cfg.data.train)]
        if len(self.cfg.workflow) == 2:
            # For simplicity, omnisource is not compatible with val workflow,
            # we recommend you to use `--validate`
            assert not self.cfg.omnisource
            val_dataset = copy.deepcopy(self.cfg.data.val)
            datasets.append(build_dataset(val_dataset))
        if self.cfg.checkpoint_config is not None:
            # save mmaction version, config file content and class names in
            # checkpoints as meta data
            self.cfg.checkpoint_config.meta = dict(
                mmaction_version=__version__ + get_git_hash(digits=7),
                config=self.cfg.pretty_text)
        with open(os.path.join(self.cfg.work_dir, 'cfg.pkl'), 'wb') as file_pkl:
            pickle.dump(self.cfg, file_pkl)
        test_option = dict(test_last=False, test_best=False)
        train_model(
            model,
            datasets,
            self.cfg,
            distributed=distributed,
            validate=flag_val,
            test=test_option,
            timestamp=timestamp,
            meta=meta)
    def transform(self,
                  data_root=None, ann_file_test=None, video_length=-1,
                  checkpoints=None,
                  output_path='results.json',
                  distributed=False, gpus=1):
        if checkpoints is not None:
            assert os.path.exists(checkpoints), '%s not found.' % checkpoints
            with open(os.path.join(os.path.dirname(checkpoints), 'cfg.pkl'), 'rb') as file_pkl:
                cfg = pickle.load(file_pkl)
        else:
            cfg = self.cfg
            assert os.path.exists(cfg.work_dir), 'checkpoints not specified.'
            for file in os.listdir(cfg.work_dir):
                if file.startswith('best') and file.endswith('pth'):
                    checkpoints = os.path.join(cfg.work_dir, file)
            if checkpoints is None:
                checkpoints = os.path.join(cfg.work_dir, 'latest.pth')
            assert checkpoints is not None and os.path.exists(checkpoints), \
                'checkpoints not found in work dir %s. ' % cfg.work_dir
        if data_root is not None:
            if ann_file_test is None:
                warnings.warn('No annotations provided, top_k_accuracy is invalid.')
                with open(os.path.join(data_root, 'temp_list.txt'), 'w') as file_temp:
                    for video in os.listdir(data_root):
                        if not os.path.isfile(os.path.join(data_root, video)) or video.endswith('txt'):
                            continue
                        file_temp.write('%s 0\n' % video)
                ann_file_test = os.path.join(data_root, 'temp_list.txt')
            if video_length == -1:
                video_length = self.cfg.video_length
            cfg.data = get_dataset_cfg(
                data_root=data_root,
                ann_file_test=ann_file_test,
                videos_per_gpu=1,
                video_length=video_length)
        else:
            cfg.data = self.cfg.data
        # set multi-process settings
        setup_multi_processes(cfg)
        # Load output_config from cfg
        output_config = cfg.get('output_config', {})
        output_config = Config._merge_a_into_b(
            dict(out=output_path), output_config)
        # Load eval_config from cfg
        eval_config = cfg.get('eval_config', {})
        eval_config = Config._merge_a_into_b(
            dict(metrics='top_k_accuracy'), eval_config)
        dataset_type = cfg.data.test.type
        if output_config.get('out', None):
            if 'output_format' in output_config:
                # ugly workaround to make recognition and localization the same
                warnings.warn(
                    'Skip checking `output_format` in localization task.')
            else:
                out = output_config['out']
                # make sure the dirname of the output path exists
                mmcv.mkdir_or_exist(osp.dirname(out))
                _, suffix = osp.splitext(out)
                if dataset_type == 'AVADataset':
                    assert suffix[1:] == 'csv', ('For AVADataset, the format of '
                                                 'the output file should be csv')
                else:
                    assert suffix[1:] in file_handlers, (
                        'The format of the output '
                        'file should be json, pickle or yaml')
        # set cudnn benchmark
        if cfg.get('cudnn_benchmark', False):
            torch.backends.cudnn.benchmark = True
        cfg.data.test.test_mode = True
        # The flag is used to register module's hooks
        cfg.setdefault('module_hooks', [])
        # build the dataloader
        dataset = build_dataset(cfg.data.test, dict(test_mode=True))
        dataloader_setting = dict(
            videos_per_gpu=cfg.data.get('videos_per_gpu', 1),
            workers_per_gpu=cfg.data.get('workers_per_gpu', 1),
            dist=distributed,
            shuffle=False)
        dataloader_setting = dict(dataloader_setting,
                                  **cfg.data.get('test_dataloader', {}))
        data_loader = build_dataloader(dataset, **dataloader_setting)
        outputs = inference_pytorch(checkpoints, cfg, distributed, data_loader)
        rank, _ = get_dist_info()
        if rank == 0:
            if output_config.get('out', None):
                out = output_config['out']
                print(f'\nwriting results to {out}')
                dataset.dump_results(outputs, **output_config)
            if eval_config:
                eval_res = dataset.evaluate(outputs, **eval_config)
                for name, val in eval_res.items():
                    print(f'{name}: {val:.04f}')
            if data_root is not None and os.path.exists(os.path.join(data_root, 'temp_list.txt')):
                os.remove(os.path.join(data_root, 'temp_list.txt'))
        return outputs
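Note: read_cfg() above only reacts to the keys epoch, work_dir, data_root, num_class, video_length, ann_file_train, ann_file_val, ann_file_test and videos_per_gpu. A minimal sketch of producing such a config.yaml with PyYAML follows; the paths and values are illustrative assumptions, not part of this change.

import yaml

# All keys below are the ones read_cfg() looks for; the values are made up.
demo_cfg = dict(
    epoch=50,                                    # copied into cfg.total_epochs
    work_dir='work_dirs/demo',                   # training resumes from work_dir/latest.pth if it exists
    data_root='data/demo/videos',                # presence of data_root triggers get_dataset_cfg(...)
    num_class=24,                                # copied into model.cls_head.num_classes
    video_length=2,
    ann_file_train='data/demo/annotations/train.txt',
    ann_file_val='data/demo/annotations/val.txt',
    videos_per_gpu=8,
)
with open('config.yaml', 'w') as f:
    yaml.safe_dump(demo_cfg, f, sort_keys=False)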
@@ -0,0 +1 @@
from .AutoXVideo import AutoXVideo
@@ -34,9 +34,9 @@ def get_pipeline_cfg(
def get_dataset_cfg(
        data_root,
        ann_file_train,
        ann_file_val,
        ann_file_test,
        ann_file_train=None,
        ann_file_val=None,
        ann_file_test=None,
        videos_per_gpu=8,
        video_length=2):
    train_pipeline = get_pipeline_cfg(video_length*4, 6, 1, False)
@@ -51,20 +51,20 @@ def get_dataset_cfg(
        test_dataloader=dict(
            videos_per_gpu=1,
            workers_per_gpu=2
        ),
        train=dict(
            type='VideoDataset',
            ann_file=ann_file_train,
            data_prefix=data_root,
            pipeline=train_pipeline),
        val=dict(
            type='VideoDataset',
            ann_file=ann_file_val,
            data_prefix=data_root,
            pipeline=test_pipeline),
        test=dict(
            type='VideoDataset',
            ann_file=ann_file_test,
            data_prefix=data_root,
            pipeline=test_pipeline))
        ))
    if ann_file_train is not None:
        data['train'] = dict(type='VideoDataset',
                             ann_file=ann_file_train,
                             data_prefix=data_root,
                             pipeline=train_pipeline)
    if ann_file_val is not None:
        data['val'] = dict(type='VideoDataset',
                           ann_file=ann_file_val,
                           data_prefix=data_root,
                           pipeline=test_pipeline)
    if ann_file_test is not None:
        data['test'] = dict(type='VideoDataset',
                            ann_file=ann_file_test,
                            data_prefix=data_root,
                            pipeline=test_pipeline)
    return data
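With the annotation-file arguments now optional, only the splits whose annotation file is actually given end up in the returned dict. A small usage sketch follows; the import path is the one used in the test script above, and the file paths are placeholders.

from tools.GetPipeline import get_dataset_cfg

# Test-only setup: no train/val annotations are passed, so the returned dict
# carries only the shared dataloader settings plus a 'test' entry.
data_cfg = get_dataset_cfg(
    data_root='data/demo/videos',                      # placeholder path
    ann_file_test='data/demo/annotations/test.txt',    # placeholder path
    videos_per_gpu=1,
    video_length=2)
assert 'test' in data_cfg and 'train' not in data_cfg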
@@ -0,0 +1,61 @@
import os
from mmcv.runner import load_checkpoint
from mmcv.runner.fp16_utils import wrap_fp16_model
from mmaction.models import build_model
from mmaction.utils import (build_ddp, build_dp, default_device,
                            register_module_hooks, setup_multi_processes)
import warnings
try:
    from mmcv.engine import multi_gpu_test, single_gpu_test
except (ImportError, ModuleNotFoundError):
    warnings.warn(
        'DeprecationWarning: single_gpu_test, multi_gpu_test, '
        'collect_results_cpu, collect_results_gpu from mmaction2 will be '
        'deprecated. Please install mmcv through master branch.')
    from mmaction.apis import multi_gpu_test, single_gpu_test
def turn_off_pretrained(cfg):
    # recursively find all pretrained in the model config,
    # and set them None to avoid redundant pretrain steps for testing
    if 'pretrained' in cfg:
        cfg.pretrained = None
    # recursively turn off pretrained value
    for sub_cfg in cfg.values():
        if isinstance(sub_cfg, dict):
            turn_off_pretrained(sub_cfg)
def inference_pytorch(checkpoints, cfg, distributed, data_loader):
    """Get predictions by pytorch models."""
    # remove redundant pretrain steps for testing
    turn_off_pretrained(cfg.model)
    # build the model and load checkpoint
    model = build_model(
        cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg'))
    if len(cfg.module_hooks) > 0:
        register_module_hooks(model, cfg.module_hooks)
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)
    load_checkpoint(model, checkpoints, map_location='cpu')
    if not distributed:
        model = build_dp(
            model, default_device, default_args=dict(device_ids=cfg.gpu_ids))
        outputs = single_gpu_test(model, data_loader)
    else:
        model = build_ddp(
            model,
            default_device,
            default_args=dict(
                device_ids=[int(os.environ['LOCAL_RANK'])],
                broadcast_buffers=False))
        outputs = multi_gpu_test(model, data_loader, 'tmp',
                                 True)
    return outputs
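A quick illustration of what turn_off_pretrained() does before inference (a minimal sketch: the config contents here are invented, only the nesting matters; the import path is the one used in the test script).

from mmcv import Config
from tools.Inference import turn_off_pretrained

cfg = Config(dict(model=dict(
    pretrained='torchvision://resnet50',          # made-up value
    backbone=dict(pretrained='swin_base.pth'))))  # made-up value

turn_off_pretrained(cfg.model)
# Every nested 'pretrained' entry is now None, so no pretrained weights are
# fetched before load_checkpoint() restores the trained weights.
print(cfg.model.pretrained, cfg.model.backbone.pretrained)  # None None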
@@ -1,60 +1,30 @@
import os
import shutil
data_root = 'mmaction2/data/MMDS/MMDS-VIDEO'
ann_file_train = 'mmaction2/data/MMDS/annotations/mmds_cls_train_video_list.txt'
ann_file_test = 'mmaction2/data/MMDS/annotations/mmds_cls_val_video_list.txt'
from autox.autox_video import AutoXVideo
autox_video = AutoXVideo()
output_ann_dir = 'data/demo/annotations'
output_video_dir = 'data/demo/videos'
# Load dataset from cfg
autox_video.read_cfg('config.yaml')
autox_video.fit()
autox_video.transform()
count = [0 for _ in range(240)]
file_input = open(ann_file_train, 'r')
file_output = open(os.path.join(output_ann_dir, 'train_list.txt'), 'w')
for line in file_input:
    content = line.strip().split()
    path = content[0]
    label = int(content[1])
    if label > 24:
        break
    if count[label] < 6:
        count[label] += 1
        name = os.path.basename(path)
        file_output.write(' '.join([name, str(label)]) + '\n')
        shutil.copy(os.path.join(data_root, content[0]), os.path.join(output_video_dir, name))
file_input.close()
file_output.close()
# Manually specify datasets
autox_video.fit(
    data_root='data/demo/videos',
    ann_file_train='data/demo/annotations/train.txt',
    ann_file_val='data/demo/annotations/val.txt',
    video_length=2,
    num_class=24,
    videos_per_gpu=8
)
autox_video.transform(
    data_root='data/demo/videos',
    ann_file_test='data/demo/annotations/test.txt',
)
count = [0 for _ in range(240)]
file_input = open(ann_file_test, 'r')
file_output = open(os.path.join(output_ann_dir, 'val_list.txt'), 'w')
for line in file_input:
    content = line.strip().split()
    path = content[0]
    label = int(content[1])
    if label > 24:
        break
    if count[label] < 2:
        count[label] += 1
        name = os.path.basename(path)
        file_output.write(' '.join([name, str(label)]) + '\n')
        shutil.copy(os.path.join(data_root, content[0]), os.path.join(output_video_dir, name))
file_input.close()
file_output.close()
# transform only
autox_video.transform(
    data_root='data/demo/videos',
    ann_file_test='data/demo/annotations/test.txt',
    checkpoints='work_dirs/demo/latest.pth'
)
count = [0 for _ in range(240)]
file_input = open(ann_file_test, 'r')
file_output = open(os.path.join(output_ann_dir, 'test_list.txt'), 'w')
for line in file_input:
    content = line.strip().split()
    path = content[0]
    label = int(content[1])
    if label > 24:
        break
    if count[label] < 2:
        count[label] += 1
        name = os.path.basename(path)
        file_output.write(' '.join([name, str(label)]) + '\n')
        shutil.copy(os.path.join(data_root, content[0]), os.path.join(output_video_dir, name))
file_input.close()
file_output.close()