|
- import os
- import torch
- import torchvision
- from PIL import Image
- from matplotlib import pyplot as plt
- from torch.utils.data import DataLoader
- import torch
- import numpy as np
- from torchvision import transforms
- from torchvision.transforms import ToTensor
- from torch.utils.data import Dataset, DataLoader
- from torch.nn.functional import one_hot
- import imageio
- import glob
- import os
-
-
class MyDataset(Dataset):
    """Paired image/label dataset read from two directories of ``.tif`` files.

    Images and labels are paired by position after each directory's ``*.tif``
    paths have been sorted by name, so corresponding files must sort into the
    same order (for example, by sharing file names).

    Args:
        images_path: Directory containing the input ``.tif`` images.
        labels_path: Directory containing the label ``.tif`` files.
        Transform: Currently ignored; accepted for API compatibility only.
            No transform is applied in ``__getitem__``.

    Raises:
        ValueError: If the two directories contain different numbers of
            ``.tif`` files, since positional pairing would then be wrong.
    """

    def __init__(self, images_path, labels_path, Transform=None):
        # glob returns paths in arbitrary (filesystem) order; both lists are
        # sorted so that index i refers to matching image and label files.
        self.images_path_list = self._list_tifs(images_path)
        self.labels_path_list = self._list_tifs(labels_path)
        if len(self.images_path_list) != len(self.labels_path_list):
            raise ValueError(
                f"found {len(self.images_path_list)} images in {images_path!r} "
                f"but {len(self.labels_path_list)} labels in {labels_path!r}"
            )
        # Kept for code that reads `.transform`; it is never applied because
        # __getitem__ already converts the arrays to tensors itself.
        self.transform = ToTensor()

    @staticmethod
    def _list_tifs(directory):
        """Return the sorted paths of all ``*.tif`` files directly in *directory*."""
        return sorted(glob.glob(os.path.join(directory, '*.tif')))

    @staticmethod
    def _to_chw(image):
        """Return *image* as a channels-first tensor.

        A 3-D ``(H, W, C)`` tensor is permuted to ``(C, H, W)``; a 2-D
        single-channel ``(H, W)`` tensor gets a leading channel axis of size 1.
        """
        if image.dim() == 2:
            return image.unsqueeze(0)
        return image.permute(2, 0, 1)

    def __getitem__(self, index):
        """Return the ``(image, label)`` tensor pair at *index*.

        The image is converted to channels-first layout. The label keeps the
        dtype it was read with; its first axis is squeezed only when that
        axis has size 1.
        """
        image = torch.from_numpy(imageio.imread(self.images_path_list[index]))
        label = torch.from_numpy(imageio.imread(self.labels_path_list[index]))

        image = self._to_chw(image)
        # torch.squeeze(x, 0) is a no-op unless dim 0 has size 1, so a plain
        # (H, W) mask with H > 1 is returned unchanged.
        label = torch.squeeze(label, 0)

        return image, label

    def __len__(self):
        """Return the number of image/label pairs."""
        return len(self.images_path_list)
-
def plot_images(images):
    """Display a batch of images side by side in one large matplotlib figure.

    *images* is iterated along its first axis and the items are concatenated
    along their last (width) axis into a single row, which is shown after
    permuting to ``(H, W, C)``. Each item must therefore be 3-D; presumably
    ``(C, H, W)`` — TODO confirm with callers.
    """
    plt.figure(figsize=(32, 32))
    frames = list(images.cpu())
    row = torch.cat(frames, dim=-1)
    plt.imshow(row.permute(1, 2, 0).cpu())
    plt.show()
-
-
def save_images(images, path, **kwargs):
    """Tile a batch of images into one grid and write it to *path*.

    Extra keyword arguments are forwarded unchanged to
    ``torchvision.utils.make_grid``. The grid is moved to the CPU, laid out as
    ``(H, W, C)`` and handed to PIL, so *images* must already have a dtype PIL
    accepts (presumably ``uint8`` — TODO confirm with callers).
    """
    hwc_grid = torchvision.utils.make_grid(images, **kwargs).permute(1, 2, 0)
    Image.fromarray(hwc_grid.to('cpu').numpy()).save(path)
-
-
def _pre_crop_size(image_size):
    """Return the resize target used before random-cropping to *image_size*.

    The image is resized 25% larger than the final crop (e.g. 80 for a
    64-pixel crop) so RandomResizedCrop has margin to sample from.
    """
    return image_size + image_size // 4


def get_data(args):
    """Build a shuffling DataLoader over an ImageFolder dataset.

    Each image is resized with ``Resize(image_size + image_size // 4)``, then
    randomly crop-resized to ``args.image_size`` (crop area 80-100%). It is
    then converted to a tensor and normalized with mean 0.5 and std 0.5 per
    channel.

    Args:
        args: Object providing ``dataset_path``, ``image_size`` and
            ``batch_size`` attributes.

    Returns:
        A ``torch.utils.data.DataLoader`` over the dataset with
        ``shuffle=True``.
    """
    # Named `transform` so it does not shadow the imported `transforms` module.
    transform = torchvision.transforms.Compose([
        # Previously hard-coded to 80, which is only correct for image_size=64.
        torchvision.transforms.Resize(_pre_crop_size(args.image_size)),
        torchvision.transforms.RandomResizedCrop(args.image_size, scale=(0.8, 1.0)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    dataset = torchvision.datasets.ImageFolder(args.dataset_path, transform=transform)
    return DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
-
-
def setup_logging(run_name):
    """Create the output directory tree for a training run.

    Ensures ``models/``, ``results/``, ``models/<run_name>/`` and
    ``results/<run_name>/`` exist relative to the current working directory,
    in that order. Directories that already exist are left as they are.
    """
    roots = ("models", "results")
    wanted = [*roots, *(os.path.join(root, run_name) for root in roots)]
    for directory in wanted:
        os.makedirs(directory, exist_ok=True)
|