|
- # coding: utf-8
- from PIL import Image
- from torch.utils.data import Dataset
- import os.path as osp
- # import torch
- # import json
- # import prototype.spring.linklink as link
- # import os
-
-
# Dataset subclass for image classification driven by an index text file.
class ImageDataset(Dataset):
    """Image-classification dataset backed by a plain-text index file.

    Each line of the index file is ``"<relative/image/path> <integer-label>"``.
    Images are opened lazily in ``__getitem__`` and converted to RGB.
    """

    def __init__(self, root, txt_path, transform=None, target_transform=None):
        """
        Args:
            root: directory that the image paths in the index file are
                relative to.
            txt_path: path to the index text file; it holds one
                ``"filename label"`` pair per line.
            transform: optional callable applied to the PIL image
                (e.g. random crop / conversion to tensor).
            target_transform: optional callable applied to the integer label.
        """
        self.root = root
        self.transform = transform
        self.target_transform = target_transform

        # Parse the whole index once up front; the images themselves are
        # loaded lazily, one at a time, in __getitem__.
        with open(txt_path) as f:
            lines = f.readlines()

        imgs = []
        for line in lines:
            filename, label = line.rstrip().split()
            imgs.append((filename, int(label)))
        self.imgs = imgs
        self.num = len(imgs)

    def __getitem__(self, index):
        """Return one ``(image, label)`` sample for the given index."""
        filename, label = self.imgs[index]
        filename = osp.join(self.root, filename)

        img = Image.open(filename).convert('RGB')  # normalize to RGB
        if self.transform is not None:
            img = self.transform(img)
        # BUG FIX: the original applied self.transform to the label here,
        # ignoring target_transform entirely. The label must go through
        # target_transform.
        if self.target_transform is not None:
            label = self.target_transform(label)

        return img, label

    def __len__(self):
        # Number of index-file entries; DataLoader uses this to bound index.
        return self.num
|