Browse Source

update high-level apis

pull/68/head
caixiaochen 1 year ago
parent
commit
b4ef358cb8
16 changed files with 451 additions and 138 deletions
  1. +23
    -1
      .idea/deployment.xml
  2. BIN
      autox/.DS_Store
  3. BIN
      autox/autox_video/.DS_Store
  4. +1
    -1
      autox/autox_video/.idea/autox.iml
  5. +1
    -1
      autox/autox_video/.idea/deployment.xml
  6. +1
    -1
      autox/autox_video/.idea/misc.xml
  7. +6
    -0
      autox/autox_video/.idea/vcs.xml
  8. +3
    -58
      autox/autox_video/AutoTest.py
  9. +2
    -2
      autox/autox_video/AutoTrain.py
  10. +308
    -0
      autox/autox_video/AutoXVideo.py
  11. +1
    -0
      autox/autox_video/__init__.py
  12. +0
    -0
      autox/autox_video/mmaction2/__init__.py
  13. BIN
      autox/autox_video/resources/.DS_Store
  14. +19
    -19
      autox/autox_video/tools/GetPipeline.py
  15. +61
    -0
      autox/autox_video/tools/Inference.py
  16. +25
    -55
      autox/autox_video/tools/temp.py

+ 23
- 1
.idea/deployment.xml View File

@@ -1,6 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="PublishConfigData" autoUpload="Always" serverName="root@region-3.autodl.com:38670" remoteFilesAllowedToDisappearOnAutoupload="false">
<component name="PublishConfigData" autoUpload="Always" serverName="root@region-3.autodl.com:38670" remoteFilesAllowedToDisappearOnAutoupload="false" confirmBeforeUploading="false">
<option name="confirmBeforeUploading" value="false" />
<serverData>
<paths name="root@region-3.autodl.com:38670">
<serverdata>
@@ -9,6 +10,27 @@
</mappings>
</serverdata>
</paths>
<paths name="root@region-3.autodl.com:40829">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="root@region-3.autodl.com:40829 (2)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="root@region-3.autodl.com:51822">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
</serverData>
<option name="myAutoUpload" value="ALWAYS" />
</component>

BIN
autox/.DS_Store View File


BIN
autox/autox_video/.DS_Store View File


+ 1
- 1
autox/autox_video/.idea/autox.iml View File

@@ -2,7 +2,7 @@
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" />
<orderEntry type="jdk" jdkName="Remote Python 3.8.10 (sftp://root@region-3.autodl.com:38670/root/miniconda3/bin/python)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

+ 1
- 1
autox/autox_video/.idea/deployment.xml View File

@@ -1,6 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="PublishConfigData" autoUpload="Always" serverName="root@region-3.autodl.com:51822" remoteFilesAllowedToDisappearOnAutoupload="false" confirmBeforeUploading="false" autoUploadExternalChanges="true">
<component name="PublishConfigData" autoUpload="Always" remoteFilesAllowedToDisappearOnAutoupload="false" confirmBeforeUploading="false" autoUploadExternalChanges="true">
<option name="confirmBeforeUploading" value="false" />
<serverData>
<paths name="root@region-3.autodl.com:51822">


+ 1
- 1
autox/autox_video/.idea/misc.xml View File

@@ -1,4 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="Remote Python 3.8.10 (sftp://root@region-3.autodl.com:51822/root/miniconda3/bin/python)" project-jdk-type="Python SDK" />
<component name="ProjectRootManager" version="2" project-jdk-name="Remote Python 3.8.10 (sftp://root@region-3.autodl.com:38670/root/miniconda3/bin/python)" project-jdk-type="Python SDK" />
</project>

+ 6
- 0
autox/autox_video/.idea/vcs.xml View File

@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$/../.." vcs="Git" />
</component>
</project>

+ 3
- 58
autox/autox_video/AutoTest.py View File

@@ -6,14 +6,11 @@ import warnings
from tools.GetPipeline import get_dataset_cfg
import mmcv
import torch
from mmcv import Config, DictAction
from mmcv.cnn import fuse_conv_bn
from mmcv import Config
from mmcv.fileio.io import file_handlers
from mmcv.runner import get_dist_info, init_dist, load_checkpoint
from mmcv.runner.fp16_utils import wrap_fp16_model
os.chdir('mmaction2')
from mmcv.runner import get_dist_info, init_dist
from tools.Inference import turn_off_pretrained, inference_pytorch
from mmaction.datasets import build_dataloader, build_dataset
from mmaction.models import build_model
from mmaction.utils import (build_ddp, build_dp, default_device,
register_module_hooks, setup_multi_processes)

@@ -25,8 +22,6 @@ except (ImportError, ModuleNotFoundError):
'collect_results_cpu, collect_results_gpu from mmaction2 will be '
'deprecated. Please install mmcv through master branch.')
from mmaction.apis import multi_gpu_test, single_gpu_test

os.chdir('..')
import yaml


@@ -45,52 +40,6 @@ def parse_args():
return args


def turn_off_pretrained(cfg):
# recursively find all pretrained in the model config,
# and set them None to avoid redundant pretrain steps for testing
if 'pretrained' in cfg:
cfg.pretrained = None

# recursively turn off pretrained value
for sub_cfg in cfg.values():
if isinstance(sub_cfg, dict):
turn_off_pretrained(sub_cfg)


def inference_pytorch(checkpoints, cfg, distributed, data_loader):
"""Get predictions by pytorch models."""
# remove redundant pretrain steps for testing
turn_off_pretrained(cfg.model)

# build the model and load checkpoint
model = build_model(
cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg'))

if len(cfg.module_hooks) > 0:
register_module_hooks(model, cfg.module_hooks)

fp16_cfg = cfg.get('fp16', None)
if fp16_cfg is not None:
wrap_fp16_model(model)
load_checkpoint(model, checkpoints, map_location='cpu')


if not distributed:
model = build_dp(
model, default_device, default_args=dict(device_ids=cfg.gpu_ids))
outputs = single_gpu_test(model, data_loader)
else:
model = build_ddp(
model,
default_device,
default_args=dict(
device_ids=[int(os.environ['LOCAL_RANK'])],
broadcast_buffers=False))
outputs = multi_gpu_test(model, data_loader, 'tmp',
True)

return outputs

def main():
args = parse_args()
default_cfg = 'mmaction2/configs/TASK/swin.py'
@@ -116,10 +65,6 @@ def main():
assert checkpoints is not None and os.path.exists(checkpoints), \
'checkpoints not found in work dir %s, please check the work dir in config.yaml'%work_dir

resume_path = os.path.join(cus_cfg['work_dir'], 'latest.pth')
if os.path.exists(os.path.join(cus_cfg['work_dir'], 'latest.pth')):
cfg.resume_from = resume_path

# set multi-process settings
setup_multi_processes(cfg)



+ 2
- 2
autox/autox_video/AutoTrain.py View File

@@ -15,14 +15,14 @@ from mmcv import Config, DictAction
from mmcv.runner import get_dist_info, init_dist, set_random_seed
from mmcv.utils import get_git_hash

os.chdir('mmaction2')
from mmaction import __version__
from mmaction.apis import init_random_seed, train_model
from mmaction.datasets import build_dataset
from mmaction.models import build_model
from mmaction.utils import (collect_env, get_root_logger,
register_module_hooks, setup_multi_processes)
os.chdir('..')


def parse_args():


+ 308
- 0
autox/autox_video/AutoXVideo.py View File

@@ -0,0 +1,308 @@
import copy
import argparse
from .tools.GetPipeline import get_dataset_cfg
from .tools.Inference import inference_pytorch
import os
import os.path as osp
import time
import warnings

import mmcv
import torch
import yaml
import pickle
import torch.distributed as dist
from mmcv import Config, DictAction
from mmcv.runner import get_dist_info, init_dist, set_random_seed
from mmcv.utils import get_git_hash
from mmcv.fileio.io import file_handlers

from mmaction import __version__
from mmaction.apis import init_random_seed, train_model
from mmaction.datasets import build_dataset, build_dataloader
from mmaction.models import build_model
from mmaction.utils import (collect_env, get_root_logger,
register_module_hooks, setup_multi_processes)


class AutoXVideo():
default_cfg = os.path.join(os.path.dirname(__file__), 'mmaction2/configs/TASK/swin.py')
def __init__(self):
self.cfg = Config.fromfile(self.default_cfg)
if not os.path.exists(self.cfg.load_from):
self.cfg.load_from = None

# set cudnn_benchmark
if self.cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True

# The flag is used to determine whether it is omnisource training
self.cfg.setdefault('omnisource', False)

# The flag is used to register module's hooks
self.cfg.setdefault('module_hooks', [])

# default setting
self.cfg.total_epochs = 50
self.cfg.work_dir = 'work_dirs/default_work_dir'

self.cfg.num_class = -1
self.cfg.video_length = 2

def read_cfg(self, cfg_file):
with open(cfg_file, 'r') as cus_cfg:
cus_cfg = yaml.load(cus_cfg, Loader=yaml.FullLoader)
if 'epoch' in cus_cfg.keys():
self.cfg.total_epochs = cus_cfg['epoch']
if 'work_dir' in cus_cfg.keys():
self.cfg.work_dir = cus_cfg['work_dir']
resume_path = os.path.join(cus_cfg['work_dir'], 'latest.pth')
if os.path.exists(os.path.join(cus_cfg['work_dir'], 'latest.pth')):
self.cfg.resume_from = resume_path
if 'data_root' in cus_cfg.keys():
self.cfg.model.cls_head.num_classes = cus_cfg['num_class']
self.cfg.num_class = cus_cfg['num_class']
self.cfg.video_length=cus_cfg['video_length']
self.cfg.data = get_dataset_cfg(
data_root=cus_cfg['data_root'],
ann_file_train=cus_cfg.get('ann_file_train', None),
ann_file_val=cus_cfg.get('ann_file_val', None),
ann_file_test=cus_cfg.get('ann_file_test', None),
videos_per_gpu=cus_cfg.get('videos_per_gpu', 8),
video_length=cus_cfg.get('video_length', 2))

def fit(self,
data_root=None,
ann_file_train=None,
ann_file_val=None,
video_length=-1,
num_class=-1,
epoch=-1,
videos_per_gpu=-1,
evaluation=5,
work_dir=None,
distributed=False,
gpus=1):
if data_root is not None:
# specify custom dataset
assert ann_file_train is not None and video_length != -1 and num_class != -1,\
'Yon need to specify all dataset config by params or read_cfg().'
self.cfg.model.cls_head.num_classes = num_class
if ann_file_val != None:
flag_val = True
else:
flag_val = False
ann_file_val = ann_file_train
if self.cfg.num_class == -1 and videos_per_gpu == -1:
videos_per_gpu = 1
elif self.cfg.num_class != -1 and videos_per_gpu == -1:
videos_per_gpu = self.cfg.data.videos_per_gpu
self.cfg.data = get_dataset_cfg(
data_root=data_root,
ann_file_train=ann_file_train,
ann_file_val=ann_file_val,
ann_file_test=(self.cfg.get('data', {})).get('ann_file_test', None),
videos_per_gpu=videos_per_gpu,
video_length=video_length)

if epoch != -1:
self.cfg.total_epochs = epoch

if work_dir is not None:
self.cfg.work_dir = work_dir

assert self.cfg.work_dir is not None, 'work dir not specified.'

resume_path = os.path.join(self.cfg.work_dir, 'latest.pth')
if os.path.exists(os.path.join(self.cfg.work_dir, 'latest.pth')):
self.cfg.resume_from = resume_path

if evaluation == 0:
flag_val = False
else:
self.cfg.evaluation.interval = evaluation
flag_val = True

# set multi-process settings
setup_multi_processes(self.cfg)

# create work_dir
mmcv.mkdir_or_exist(osp.abspath(self.cfg.work_dir))
# init logger before other steps
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
log_file = osp.join(self.cfg.work_dir, f'{timestamp}.log')
logger = get_root_logger(log_file=log_file, log_level=self.cfg.log_level)

# init the meta dict to record some important information such as
# environment info and seed, which will be logged
meta = dict()
# log env info
env_info_dict = collect_env()
env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()])
dash_line = '-' * 60 + '\n'
logger.info('Environment info:\n' + dash_line + env_info + '\n' +
dash_line)
meta['env_info'] = env_info

# log some basic info
logger.info(f'Distributed training: {distributed}')
logger.info(f'Config: {self.cfg.pretty_text}')

# set random seeds
seed = init_random_seed(42, distributed=distributed)
logger.info(f'Set random seed to {seed}')
set_random_seed(seed, deterministic=False)

self.cfg.seed = seed
meta['seed'] = seed
meta['work_dir'] = osp.basename(self.cfg.work_dir.rstrip('/\\'))

model = build_model(
self.cfg.model,
train_cfg=self.cfg.get('train_cfg'),
test_cfg=self.cfg.get('test_cfg'))

if len(self.cfg.module_hooks) > 0:
register_module_hooks(model, self.cfg.module_hooks)

if self.cfg.omnisource:
# If omnisource flag is set, cfg.data.train should be a list
assert isinstance(self.cfg.data.train, list)
datasets = [build_dataset(dataset) for dataset in self.cfg.data.train]
else:
datasets = [build_dataset(self.cfg.data.train)]

if len(self.cfg.workflow) == 2:
# For simplicity, omnisource is not compatible with val workflow,
# we recommend you to use `--validate`
assert not self.cfg.omnisource
val_dataset = copy.deepcopy(self.cfg.data.val)
datasets.append(build_dataset(val_dataset))
if self.cfg.checkpoint_config is not None:
# save mmaction version, config file content and class names in
# checkpoints as meta data
self.cfg.checkpoint_config.meta = dict(
mmaction_version=__version__ + get_git_hash(digits=7),
config=self.cfg.pretty_text)

with open(os.path.join(self.cfg.work_dir, 'cfg.pkl'),'wb') as file_pkl:
pickle.dump(self.cfg, file_pkl)

test_option = dict(test_last=False, test_best=False)
train_model(
model,
datasets,
self.cfg,
distributed=distributed,
validate=flag_val,
test=test_option,
timestamp=timestamp,
meta=meta)

def transform(self,
data_root=None, ann_file_test=None, video_length=-1,
checkpoints=None,
output_path='results.json',
distributed=False, gpus=1):
if checkpoints is not None:
assert os.path.exists(checkpoints), '%s not found.' % checkpoints
with open(os.path.join(os.path.dirname(checkpoints), 'cfg.pkl'), 'rb') as file_pkl:
cfg = pickle.load(file_pkl)
else:
cfg = self.cfg
assert os.path.exists(cfg.work_dir), 'checkpoints not specified.'
for file in os.listdir(cfg.work_dir):
if file.startswith('best') and file.endswith('pth'):
checkpoints = os.path.join(cfg.work_dir, file)
if checkpoints is None:
checkpoints = os.path.join(cfg.work_dir, 'latest.pth')
assert checkpoints is not None and os.path.exists(checkpoints), \
'checkpoints not found in work dir %s. ' % cfg.work_dir

if data_root is not None:
if ann_file_test is None:
warnings.warn('No annotations provided, top_k_accuracy is invalid.')
with open(os.path.join(data_root, 'temp_list.txt'), 'w') as file_temp:
for video in os.listdir(data_root):
if not os.path.isfile(os.path.join(data_root, video)) or video.endswith('txt'):
continue
file_temp.write('%s 0\n'%video)
ann_file_test = os.path.join(data_root, 'temp_list.txt')
if video_length == -1:
video_length = self.cfg.video_length
cfg.data = get_dataset_cfg(
data_root=data_root,
ann_file_test=ann_file_test,
videos_per_gpu=1,
video_length=video_length)
else:
cfg.data = self.cfg.data

# set multi-process settings
setup_multi_processes(cfg)

# Load output_config from cfg
output_config = cfg.get('output_config', {})
output_config = Config._merge_a_into_b(
dict(out=output_path), output_config)

# Load eval_config from cfg
eval_config = cfg.get('eval_config', {})
eval_config = Config._merge_a_into_b(
dict(metrics='top_k_accuracy'), eval_config)

dataset_type = cfg.data.test.type
if output_config.get('out', None):
if 'output_format' in output_config:
# ugly workround to make recognition and localization the same
warnings.warn(
'Skip checking `output_format` in localization task.')
else:
out = output_config['out']
# make sure the dirname of the output path exists
mmcv.mkdir_or_exist(osp.dirname(out))
_, suffix = osp.splitext(out)
if dataset_type == 'AVADataset':
assert suffix[1:] == 'csv', ('For AVADataset, the format of '
'the output file should be csv')
else:
assert suffix[1:] in file_handlers, (
'The format of the output '
'file should be json, pickle or yaml')

# set cudnn benchmark
if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True
cfg.data.test.test_mode = True

# The flag is used to register module's hooks
cfg.setdefault('module_hooks', [])

# build the dataloader
dataset = build_dataset(cfg.data.test, dict(test_mode=True))
dataloader_setting = dict(
videos_per_gpu=cfg.data.get('videos_per_gpu', 1),
workers_per_gpu=cfg.data.get('workers_per_gpu', 1),
dist=distributed,
shuffle=False)
dataloader_setting = dict(dataloader_setting,
**cfg.data.get('test_dataloader', {}))
data_loader = build_dataloader(dataset, **dataloader_setting)

outputs = inference_pytorch(checkpoints, cfg, distributed, data_loader)

rank, _ = get_dist_info()
if rank == 0:
if output_config.get('out', None):
out = output_config['out']
print(f'\nwriting results to {out}')
dataset.dump_results(outputs, **output_config)
if eval_config:
eval_res = dataset.evaluate(outputs, **eval_config)
for name, val in eval_res.items():
print(f'{name}: {val:.04f}')
if os.path.exists(os.path.join(data_root, 'temp_list.txt')):
os.remove(os.path.join(data_root, 'temp_list.txt'))

return outputs


+ 1
- 0
autox/autox_video/__init__.py View File

@@ -0,0 +1 @@
from .AutoXVideo import AutoXVideo

+ 0
- 0
autox/autox_video/mmaction2/__init__.py View File


BIN
autox/autox_video/resources/.DS_Store View File


+ 19
- 19
autox/autox_video/tools/GetPipeline.py View File

@@ -34,9 +34,9 @@ def get_pipeline_cfg(

def get_dataset_cfg(
data_root,
ann_file_train,
ann_file_val,
ann_file_test,
ann_file_train=None,
ann_file_val=None,
ann_file_test=None,
videos_per_gpu=8,
video_length=2):
train_pipeline = get_pipeline_cfg(video_length*4, 6, 1, False)
@@ -51,20 +51,20 @@ def get_dataset_cfg(
test_dataloader=dict(
videos_per_gpu=1,
workers_per_gpu=2
),
train=dict(
type='VideoDataset',
ann_file=ann_file_train,
data_prefix=data_root,
pipeline=train_pipeline),
val=dict(
type='VideoDataset',
ann_file=ann_file_val,
data_prefix=data_root,
pipeline=test_pipeline),
test=dict(
type='VideoDataset',
ann_file=ann_file_test,
data_prefix=data_root,
pipeline=test_pipeline))
))
if ann_file_train is not None:
data['train'] = dict(type='VideoDataset',
ann_file=ann_file_train,
data_prefix=data_root,
pipeline=train_pipeline)
if ann_file_val is not None:
data['val'] = dict(type='VideoDataset',
ann_file=ann_file_val,
data_prefix=data_root,
pipeline=test_pipeline)
if ann_file_test is not None:
data['test'] = dict(type='VideoDataset',
ann_file=ann_file_test,
data_prefix=data_root,
pipeline=test_pipeline)
return data

+ 61
- 0
autox/autox_video/tools/Inference.py View File

@@ -0,0 +1,61 @@
import os
from mmcv.runner import load_checkpoint
from mmcv.runner.fp16_utils import wrap_fp16_model
from mmaction.models import build_model
from mmaction.utils import (build_ddp, build_dp, default_device,
register_module_hooks, setup_multi_processes)
import warnings
try:
from mmcv.engine import multi_gpu_test, single_gpu_test
except (ImportError, ModuleNotFoundError):
warnings.warn(
'DeprecationWarning: single_gpu_test, multi_gpu_test, '
'collect_results_cpu, collect_results_gpu from mmaction2 will be '
'deprecated. Please install mmcv through master branch.')
from mmaction.apis import multi_gpu_test, single_gpu_test

def turn_off_pretrained(cfg):
# recursively find all pretrained in the model config,
# and set them None to avoid redundant pretrain steps for testing
if 'pretrained' in cfg:
cfg.pretrained = None

# recursively turn off pretrained value
for sub_cfg in cfg.values():
if isinstance(sub_cfg, dict):
turn_off_pretrained(sub_cfg)


def inference_pytorch(checkpoints, cfg, distributed, data_loader):
"""Get predictions by pytorch models."""
# remove redundant pretrain steps for testing
turn_off_pretrained(cfg.model)

# build the model and load checkpoint
model = build_model(
cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg'))

if len(cfg.module_hooks) > 0:
register_module_hooks(model, cfg.module_hooks)

fp16_cfg = cfg.get('fp16', None)
if fp16_cfg is not None:
wrap_fp16_model(model)
load_checkpoint(model, checkpoints, map_location='cpu')


if not distributed:
model = build_dp(
model, default_device, default_args=dict(device_ids=cfg.gpu_ids))
outputs = single_gpu_test(model, data_loader)
else:
model = build_ddp(
model,
default_device,
default_args=dict(
device_ids=[int(os.environ['LOCAL_RANK'])],
broadcast_buffers=False))
outputs = multi_gpu_test(model, data_loader, 'tmp',
True)

return outputs

+ 25
- 55
autox/autox_video/tools/temp.py View File

@@ -1,60 +1,30 @@
import os
import shutil
data_root = 'mmaction2/data/MMDS/MMDS-VIDEO'
ann_file_train = 'mmaction2/data/MMDS/annotations/mmds_cls_train_video_list.txt'
ann_file_test = 'mmaction2/data/MMDS/annotations/mmds_cls_val_video_list.txt'
from autox.autox_video import AutoXVideo

autox_video = AutoXVideo()

output_ann_dir = 'data/demo/annotations'
output_video_dir = 'data/demo/videos'
# Load dataset from cfg
autox_video.read_cfg('config.yaml')
autox_video.fit()
autox_video.transform()

count = [0 for _ in range(240)]
file_input = open(ann_file_train, 'r')
file_output = open(os.path.join(output_ann_dir,'train_list.txt'), 'w')
for line in file_input:
content = line.strip().split()
path = content[0]
label = int(content[1])
if label > 24:
break
if count[label] < 6:
count[label] += 1
name = os.path.basename(path)
file_output.write(' '.join([name,str(label)])+'\n')
shutil.copy(os.path.join(data_root, content[0]), os.path.join(output_video_dir,name))
file_input.close()
file_output.close()
# Manually specify datasets
autox_video.fit(
data_root='data/demo/videos',
ann_file_train='data/demo/annotations/train.txt',
ann_file_val='data/demo/annotations/val.txt',
video_length=2,
num_class=24,
videos_per_gpu=8
)
autox_video.transform(
data_root='data/demo/videos',
ann_file_test='data/demo/annotations/test.txt',
)

count = [0 for _ in range(240)]
file_input = open(ann_file_test, 'r')
file_output = open(os.path.join(output_ann_dir,'val_list.txt'), 'w')
for line in file_input:
content = line.strip().split()
path = content[0]
label = int(content[1])
if label > 24:
break
if count[label] < 2:
count[label] += 1
name = os.path.basename(path)
file_output.write(' '.join([name,str(label)])+'\n')
shutil.copy(os.path.join(data_root, content[0]), os.path.join(output_video_dir,name))
file_input.close()
file_output.close()
# transform only
autox_video.transform(
data_root='data/demo/videos',
ann_file_train='data/demo/annotations/test.txt',
checkpoints='work_dirs/demo/latest.pth'
)

count = [0 for _ in range(240)]
file_input = open(ann_file_test, 'r')
file_output = open(os.path.join(output_ann_dir,'test_list.txt'), 'w')
for line in file_input:
content = line.strip().split()
path = content[0]
label = int(content[1])
if label > 24:
break
if count[label] < 2:
count[label] += 1
name = os.path.basename(path)
file_output.write(' '.join([name,str(label)])+'\n')
shutil.copy(os.path.join(data_root, content[0]), os.path.join(output_video_dir,name))
file_input.close()
file_output.close()

Loading…
Cancel
Save