|
"""Run a trained detector checkpoint on a single spike ``.dat`` sample.

Decodes the spike file, resizes and pads it to the network input size,
averages the model's prediction over ``T`` forward passes, applies
non-maximum suppression and writes the mean spike image with the detected
boxes drawn on it to ``out_path``.
"""
import torch
import cv2
from models.yolo import Model
from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, select_device, smart_DDP, smart_optimizer,
                               smart_resume, torch_distributed_zero_first)
from spikingjelly.clock_driven.functional import reset_net
from utils.encoding import decodeFromDat
from utils.general import (LOGGER, TQDM_BAR_FORMAT, Profile, check_dataset, check_img_size, check_requirements,
                           check_yaml, coco80_to_coco91_class, colorstr, increment_path, non_max_suppression,
                           print_args, scale_boxes, xywh2xyxy, xyxy2xywh)

weight_path = 'runs/train/exp32/weights/last.pt'
spike_path = '/home/yanwq/Documents/github/matrix_vidar_host_subsystem/algorithm/example/detector/spike_sample/5.dat'
out_path = 'xx.png'
T = 8  # number of forward passes whose predictions are averaged
conf_thres = 0.1  # NMS confidence threshold
iou_thres = 0.45  # NMS IoU threshold

# NOTE(review): torch.load unpickles arbitrary Python objects -- only load trusted checkpoints.
# eval() is required for inference: it switches BatchNorm/Dropout to inference behaviour and
# (for a YOLO-style Detect head) makes the forward pass return the (pred, train_out) pair used below.
model = torch.load(weight_path)['model'].cuda().float().eval()

# The axis order after transpose(1, 2, 0) is presumably (H, W, frames) given the (32, 250, 400)
# shape argument -- TODO confirm against decodeFromDat.
spike = decodeFromDat(spike_path, (32, 250, 400), 0).transpose(1, 2, 0)
# cv2.resize takes dsize as (width, height); nearest-neighbour keeps the original sample values.
spike = cv2.resize(spike, (480, 300), interpolation=cv2.INTER_NEAREST).transpose(2, 0, 1)
spike = torch.FloatTensor(spike.copy()).cuda()[None, ...]
# F.pad on the last two dims is (left, right, top, bottom): append 180 rows at the bottom (300 -> 480).
spike = torch.nn.functional.pad(spike, (0, 0, 0, 180))

# Visualisation image: mean over the channel axis scaled by 255 (assumes spike values lie in [0, 1]),
# converted explicitly to uint8 so PNG output does not rely on OpenCV's implicit float conversion.
image = spike.cpu().numpy()[0].mean(0) * 255
image = image.clip(0, 255).round().astype('uint8')

with torch.no_grad():
    # Clear any neuron state that may have been pickled with the model; state is then carried
    # across the T passes (no reset between them).
    reset_net(model)
    # The model returns a (pred, train_out) pair; only pred is averaged.
    mean_pred = sum(model(spike)[0] for _ in range(T)) / T

detections = non_max_suppression(mean_pred, conf_thres, iou_thres)[0]  # detections for the single image
for det in detections:
    x1, y1, x2, y2 = (int(v) for v in det[:4].tolist())
    cv2.rectangle(image, (x1, y1), (x2, y2), (255, 255, 255))
cv2.imwrite(out_path, image)
|