diff --git a/my_dataset.py b/my_dataset.py
index 9a3f3c7..d545892 100644
--- a/my_dataset.py
+++ b/my_dataset.py
@@ -1,5 +1,9 @@
 # -*- coding: utf-8 -*-
-
+"""
+Dataset definitions for each dataset
+"""
+import numpy as np
+import torch
 import os
 import random
 from PIL import Image
@@ -7,98 +11,79 @@ from torch.utils.data import Dataset
 random.seed(1)
 rmb_label = {"1": 0, "100": 1}
-ants_label={'ants':0, 'bees':1}
-
-class RMBDataset(Dataset):
-    def __init__(self, data_dir, transform=None):
-        """
-        Dataset for the RMB denomination classification task
-        :param data_dir: str, path to the dataset
-        :param transform: torch.transform, data preprocessing
-        """
-        # data_info stores every image path and label; samples are read by index in the DataLoader
-        self.data_info = self.get_img_info(data_dir)
-        self.transform = transform
-
-    def __getitem__(self, index):
-        # read a sample by index
-        path_img, label = self.data_info[index]
-        # note that convert('RGB') is needed here
-        img = Image.open(path_img).convert('RGB')     # 0~255
-        if self.transform is not None:
-            img = self.transform(img)   # apply the transform here, e.g. convert to tensor
-        # return the sample and its label
-        return img, label
-    # return the number of samples
-    def __len__(self):
-        return len(self.data_info)
-
-    @staticmethod
-    def get_img_info(data_dir):
-        data_info = list()
-        # data_dir is the path to the training, validation or test set
-        for root, dirs, _ in os.walk(data_dir):
-            # iterate over the classes
-            # dirs ['1', '100']
-            for sub_dir in dirs:
-                # file list
-                img_names = os.listdir(os.path.join(root, sub_dir))
-                # keep only the files ending in .jpg
-                img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))
-                # iterate over the images
-                for i in range(len(img_names)):
-                    img_name = img_names[i]
-                    # absolute path of the image
-                    path_img = os.path.join(root, sub_dir, img_name)
-                    # label, mapped here to the two classes 0 and 1
-                    label = rmb_label[sub_dir]
-                    # store it in data_info
-                    data_info.append((path_img, int(label)))
-        return data_info
-class AntsDataset(Dataset):
-    def __init__(self,data_dir, transform=None):
-        self.label_name={'ants':0, 'bees':1}
-        self.data_info=self.get_item_info(data_dir)
-        self.transform=transform
-
-    def __getitem__(self, index):
-        path,label=self.data_info[index]
-        img = Image.open(path).convert('RGB')
-        if self.transform is not None:
-            img=self.transform(img)
-        return img, label
-    @staticmethod
-    def get_item_info(data_dir):
-        data_info=list()
-        for root,dirs,_ in os.walk(data_dir):
-            for sub_dir in dirs:
-                img_names=os.listdir(os.path.join(root,sub_dir))
-                img_names=list(filter(lambda x:x.endswith('.jpg'), img_names))
+class PennFudanDataset(object):
+    def __init__(self, data_dir, transforms):
-                for i in range(len(img_names)):
-                    path_img=os.path.join(root,sub_dir,img_names[i])
-                    label=ants_label[sub_dir]
-                    data_info.append((path_img, int(label)))
+        self.data_dir = data_dir
+        self.transforms = transforms
+        self.img_dir = os.path.join(data_dir, "PNGImages")
+        self.txt_dir = os.path.join(data_dir, "Annotation")
+        # keep the file names of all images; used later to look up the matching txt annotation file
+        self.names = [name[:-4] for name in list(filter(lambda x: x.endswith(".png"), os.listdir(self.img_dir)))]
-        if len(data_info)==0:
-            raise Exception('\ndata_dir:{} is a empty dir! please check your image paths!'.format(data_dir))
+
+    def __getitem__(self, index):
+        """
+        Return img and target
+        :param index:
+        :return:
+        """
-        return data_info
+        name = self.names[index]
+        path_img = os.path.join(self.img_dir, name + ".png")
+        path_txt = os.path.join(self.txt_dir, name + ".txt")
+
+        # load img
+        img = Image.open(path_img).convert("RGB")
+
+        # load boxes and labels
+        import re
+        # only the lines containing "Xmin" carry digits, i.e. the bounding-box annotations
+        with open(path_txt, "r") as f:
+            points = [re.findall(r"\d+", line) for line in f.readlines() if "Xmin" in line]
+        boxes_list = list()
+        for point in points:
+            box = [int(p) for p in point]
+            boxes_list.append(box[-4:])
+        boxes = torch.tensor(boxes_list, dtype=torch.float)
+        labels = torch.ones((boxes.shape[0],), dtype=torch.long)
+
+        # iscrowd = torch.zeros((num_objs,), dtype=torch.int64)
+        # assemble the target: a dict holding boxes and labels
+        target = {}
+        target["boxes"] = boxes
+        target["labels"] = labels
+        # target["iscrowd"] = iscrowd
+
+        if self.transforms is not None:
+            img, target = self.transforms(img, target)
+
+        return img, target

     def __len__(self):
-        return len(self.data_info)
-
-
-
-
-
+        if len(self.names) == 0:
+            raise Exception("\ndata_dir:{} is an empty dir! Please check your path to images!".format(self.data_dir))
+        return len(self.names)
+
+
+class CelebADataset(object):
+    def __init__(self, data_dir, transforms):
+        self.data_dir = data_dir
+        self.transform = transforms
+        self.img_names = [name for name in list(filter(lambda x: x.endswith(".jpg"), os.listdir(self.data_dir)))]
+
+    def __getitem__(self, index):
+        path_img = os.path.join(self.data_dir, self.img_names[index])
+        img = Image.open(path_img).convert('RGB')
+        if self.transform is not None:
+            img = self.transform(img)
+        return img
+
+    def __len__(self):
+        if len(self.img_names) == 0:
+            raise Exception("\ndata_dir:{} is an empty dir! Please check your path to images!".format(self.data_dir))
+        return len(self.img_names)
\ No newline at end of file
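
PennFudanDataset pairs each image with a dict target, so the DataLoader's default collate, which stacks samples into one tensor, breaks as soon as two images carry different numbers of boxes. A minimal usage sketch follows; the data path and the ToTensorPair wrapper are illustrative assumptions, since the diff itself ships no transform that accepts (img, target) pairs:

import torchvision.transforms.functional as F
from torch.utils.data import DataLoader

from my_dataset import PennFudanDataset

class ToTensorPair(object):
    # hypothetical transform: PennFudanDataset calls transforms(img, target)
    def __call__(self, img, target):
        return F.to_tensor(img), target

def collate_fn(batch):
    # keep images and targets as tuples instead of stacking,
    # since each image can contain a different number of boxes
    return tuple(zip(*batch))

dataset = PennFudanDataset(data_dir="path/to/PennFudanPed", transforms=ToTensorPair())
loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)
imgs, targets = next(iter(loader))  # imgs: tuple of tensors, targets: tuple of dicts
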
diff --git a/resnet_inference.py b/resnet_inference.py
new file mode 100644
index 0000000..a421bc9
--- /dev/null
+++ b/resnet_inference.py
@@ -0,0 +1,145 @@
+# -*- coding: utf-8 -*-
+
+import os
+import time
+import torch.nn as nn
+import torch
+import torchvision.transforms as transforms
+from PIL import Image
+from matplotlib import pyplot as plt
+import torchvision.models as models
+import enviroments
+import torchsummary
+
+BASE_DIR = os.path.dirname(os.path.abspath(__file__))
+# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+device = torch.device("cpu")
+
+# config
+vis = True
+# vis = False
+vis_row = 4
+
+norm_mean = [0.485, 0.456, 0.406]
+norm_std = [0.229, 0.224, 0.225]
+
+inference_transform = transforms.Compose([
+    transforms.Resize(256),
+    transforms.CenterCrop(224),
+    transforms.ToTensor(),
+    transforms.Normalize(norm_mean, norm_std),
+])
+
+classes = ["ants", "bees"]
+
+
+def img_transform(img_rgb, transform=None):
+    """
+    Convert the data into the form the model expects
+    :param img_rgb: PIL Image
+    :param transform: torchvision.transform
+    :return: tensor
+    """
+    if transform is None:
+        raise ValueError("transform not found! A transform is required to process the img")
+    img_t = transform(img_rgb)
+    return img_t
+
+
+def get_img_name(img_dir, format="jpg"):
+    """
+    Collect the names of all files under img_dir with the given format
+    :param img_dir: str
+    :param format: str
+    :return: list
+    """
+    file_names = os.listdir(img_dir)
+    # use list(filter(lambda ...)) to keep only the files with the jpg suffix
+    img_names = list(filter(lambda x: x.endswith(format), file_names))
+
+    if len(img_names) < 1:
+        raise ValueError("no {} format data found under {}".format(format, img_dir))
+    return img_names
+
+
+def get_model(m_path, vis_model=False):
+
+    resnet18 = models.resnet18()
+    # torchsummary.summary(resnet18, (3, 224, 224))
+    # replace the fully connected layer so it outputs 2 classes
+    num_ftrs = resnet18.fc.in_features
+    resnet18.fc = nn.Linear(num_ftrs, 2)
+
+    # load the model parameters; map_location lets a GPU-trained checkpoint load on CPU
+    checkpoint = torch.load(m_path, map_location=device)
+    resnet18.load_state_dict(checkpoint['model_state_dict'])
+
+    if vis_model:
+        from torchsummary import summary
+        summary(resnet18, input_size=(3, 224, 224), device="cpu")
+
+    return resnet18
+
+
+if __name__ == "__main__":
+
+    img_dir = os.path.join(enviroments.hymenoptera_data_dir, "val/bees")
+    model_path = "./checkpoint_14_epoch.pkl"
+    time_total = 0
+    img_list, img_pred = list(), list()
+
+    # 1. data
+    img_names = get_img_name(img_dir)
+    num_img = len(img_names)
+
+    # 2. model
+    resnet18 = get_model(model_path, True)
+    resnet18.to(device)
+    resnet18.eval()
+
+    with torch.no_grad():
+        for idx, img_name in enumerate(img_names):
+
+            path_img = os.path.join(img_dir, img_name)
+
+            # step 1/4 : path --> img
+            img_rgb = Image.open(path_img).convert('RGB')
+
+            # step 2/4 : img --> tensor
+            img_tensor = img_transform(img_rgb, inference_transform)
+            img_tensor.unsqueeze_(0)
+            img_tensor = img_tensor.to(device)
+
+            # step 3/4 : tensor --> vector
+            time_tic = time.time()
+            outputs = resnet18(img_tensor)
+            time_toc = time.time()
+
+            # step 4/4 : visualization
+            _, pred_int = torch.max(outputs.data, 1)
+            pred_str = classes[int(pred_int)]
+
+            if vis:
+                img_list.append(img_rgb)
+                img_pred.append(pred_str)
+
+                if (idx+1) % (vis_row*vis_row) == 0 or num_img == idx+1:
+                    for i in range(len(img_list)):
+                        plt.subplot(vis_row, vis_row, i+1).imshow(img_list[i])
+                        plt.title("predict:{}".format(img_pred[i]))
+                    plt.show()
+                    plt.close()
+                    img_list, img_pred = list(), list()
+
+            time_s = time_toc-time_tic
+            time_total += time_s
+
+            print('{:d}/{:d}: {} {:.3f}s '.format(idx + 1, num_img, img_name, time_s))
+
+    print("\ndevice:{} total time:{:.1f}s mean:{:.3f}s".
+          format(device, time_total, time_total/num_img))
+    if torch.cuda.is_available():
+        print("GPU name:{}".format(torch.cuda.get_device_name()))
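
get_model reads the checkpoint as a dict keyed by 'model_state_dict'. A minimal sketch of how a compatible checkpoint could be written, assuming the same two-class ResNet-18 head; the fine-tuning loop itself is not part of this diff:

import torch
import torch.nn as nn
import torchvision.models as models

# assumed save format: a dict keyed by 'model_state_dict', matching what get_model reads
resnet18 = models.resnet18()
resnet18.fc = nn.Linear(resnet18.fc.in_features, 2)  # same 2-class head as get_model builds
# ... fine-tuning loop omitted ...
torch.save({"model_state_dict": resnet18.state_dict()}, "./checkpoint_14_epoch.pkl")
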
diff --git a/rnn_demo.py b/rnn_demo.py
new file mode 100644
index 0000000..0fecf6e
--- /dev/null
+++ b/rnn_demo.py
@@ -0,0 +1,234 @@
+# -*- coding: utf-8 -*-
+"""
+Classifying surnames with an RNN
+"""
+from io import open
+import glob
+import unicodedata
+import string
+import math
+import os
+import time
+import torch.nn as nn
+import torch
+import random
+import matplotlib.pyplot as plt
+import torch.utils.data
+from common_tools import set_seed
+import enviroments
+
+set_seed(1)  # set the random seed
+BASE_DIR = os.path.dirname(os.path.abspath(__file__))
+# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+device = torch.device("cpu")
+
+
+# Read a file and split into lines
+def readLines(filename):
+    lines = open(filename, encoding='utf-8').read().strip().split('\n')
+    return [unicodeToAscii(line) for line in lines]
+
+
+def unicodeToAscii(s):
+    return ''.join(
+        c for c in unicodedata.normalize('NFD', s)
+        if unicodedata.category(c) != 'Mn'
+        and c in all_letters)
+
+
+# Find letter index from all_letters, e.g. "a" = 0
+def letterToIndex(letter):
+    return all_letters.find(letter)
+
+
+# Just for demonstration, turn a letter into a <1 x n_letters> Tensor
+def letterToTensor(letter):
+    tensor = torch.zeros(1, n_letters)
+    tensor[0][letterToIndex(letter)] = 1
+    return tensor
+
+
+# Turn a line into a <line_length x 1 x n_letters> tensor,
+# or an array of one-hot letter vectors
+def lineToTensor(line):
+    tensor = torch.zeros(len(line), 1, n_letters)
+    for li, letter in enumerate(line):
+        tensor[li][0][letterToIndex(letter)] = 1
+    return tensor
+
+
+def categoryFromOutput(output):
+    top_n, top_i = output.topk(1)
+    category_i = top_i[0].item()
+    return all_categories[category_i], category_i
+
+
+def randomChoice(l):
+    return l[random.randint(0, len(l) - 1)]
+
+
+def randomTrainingExample():
+    category = randomChoice(all_categories)  # pick a category
+    line = randomChoice(category_lines[category])  # pick one sample from it
+    category_tensor = torch.tensor([all_categories.index(category)], dtype=torch.long)
+    line_tensor = lineToTensor(line)  # str to one-hot
+    return category, line, category_tensor, line_tensor
+
+
+def timeSince(since):
+    now = time.time()
+    s = now - since
+    m = math.floor(s / 60)
+    s -= m * 60
+    return '%dm %ds' % (m, s)
"a" = 0 +def letterToIndex(letter): + return all_letters.find(letter) + + +# Just for demonstration, turn a letter into a <1 x n_letters> Tensor +def letterToTensor(letter): + tensor = torch.zeros(1, n_letters) + tensor[0][letterToIndex(letter)] = 1 + return tensor + + +# Turn a line into a , +# or an array of one-hot letter vectors +def lineToTensor(line): + tensor = torch.zeros(len(line), 1, n_letters) + for li, letter in enumerate(line): + tensor[li][0][letterToIndex(letter)] = 1 + return tensor + + +def categoryFromOutput(output): + top_n, top_i = output.topk(1) + category_i = top_i[0].item() + return all_categories[category_i], category_i + + +def randomChoice(l): + return l[random.randint(0, len(l) - 1)] + + +def randomTrainingExample(): + category = randomChoice(all_categories) # 选类别 + line = randomChoice(category_lines[category]) # 选一个样本 + category_tensor = torch.tensor([all_categories.index(category)], dtype=torch.long) + line_tensor = lineToTensor(line) # str to one-hot + return category, line, category_tensor, line_tensor + + +def timeSince(since): + now = time.time() + s = now - since + m = math.floor(s / 60) + s -= m * 60 + return '%dm %ds' % (m, s) + + +# Just return an output given a line +def evaluate(line_tensor): + hidden = rnn.initHidden() + + for i in range(line_tensor.size()[0]): + output, hidden = rnn(line_tensor[i], hidden) + + return output + + +def predict(input_line, n_predictions=3): + print('\n> %s' % input_line) + with torch.no_grad(): + output = evaluate(lineToTensor(input_line)) + + # Get top N categories + topv, topi = output.topk(n_predictions, 1, True) + + for i in range(n_predictions): + value = topv[0][i].item() + category_index = topi[0][i].item() + print('(%.2f) %s' % (value, all_categories[category_index])) + + +def get_lr(iter, learning_rate): + lr_iter = learning_rate if iter < n_iters else learning_rate*0.1 + return lr_iter + +class RNN(nn.Module): + def __init__(self, input_size, hidden_size, output_size): + super(RNN, self).__init__() + + self.hidden_size = hidden_size + + self.u = nn.Linear(input_size, hidden_size) + self.w = nn.Linear(hidden_size, hidden_size) + self.v = nn.Linear(hidden_size, output_size) + + self.tanh = nn.Tanh() + self.softmax = nn.LogSoftmax(dim=1) + + def forward(self, inputs, hidden): + + u_x = self.u(inputs) + + hidden = self.w(hidden) + hidden = self.tanh(hidden + u_x) + + output = self.softmax(self.v(hidden)) + + return output, hidden + + def initHidden(self): + return torch.zeros(1, self.hidden_size) + + +def train(category_tensor, line_tensor): + hidden = rnn.initHidden() + + rnn.zero_grad() + + line_tensor = line_tensor.to(device) + hidden = hidden.to(device) + category_tensor = category_tensor.to(device) + + for i in range(line_tensor.size()[0]): + output, hidden = rnn(line_tensor[i], hidden) + + loss = criterion(output, category_tensor) + loss.backward() + + # Add parameters' gradients to their values, multiplied by learning rate + for p in rnn.parameters(): + p.data.add_(-learning_rate, p.grad.data) + + return output, loss.item() + + +if __name__ == "__main__": + # config + path_txt = os.path.join(enviroments.names,"*.txt") + all_letters = string.ascii_letters + " .,;'" + n_letters = len(all_letters) # 52 + 5 字符总数 + print_every = 5000 + plot_every = 5000 + learning_rate = 0.005 + n_iters = 200000 + + # step 1 data + # Build the category_lines dictionary, a list of names per language + category_lines = {} + all_categories = [] + for filename in glob.glob(path_txt): + category = 
+
+
+if __name__ == "__main__":
+    # config
+    path_txt = os.path.join(enviroments.names, "*.txt")
+    all_letters = string.ascii_letters + " .,;'"
+    n_letters = len(all_letters)  # 52 letters + 5 punctuation characters
+    print_every = 5000
+    plot_every = 5000
+    learning_rate = 0.005
+    n_iters = 200000
+
+    # step 1 data
+    # Build the category_lines dictionary, a list of names per language
+    category_lines = {}
+    all_categories = []
+    for filename in glob.glob(path_txt):
+        category = os.path.splitext(os.path.basename(filename))[0]
+        all_categories.append(category)
+        lines = readLines(filename)
+        category_lines[category] = lines
+
+    n_categories = len(all_categories)
+
+    # step 2 model
+    n_hidden = 128
+    rnn = RNN(n_letters, n_hidden, n_categories)
+    rnn.to(device)
+
+    # step 3 loss
+    criterion = nn.NLLLoss()
+
+    # step 4 optimize by hand (see the manual parameter update inside train())
+
+    # step 5 iteration
+    current_loss = 0
+    all_losses = []
+    start = time.time()
+    for iter in range(1, n_iters + 1):
+        # sample
+        category, line, category_tensor, line_tensor = randomTrainingExample()
+
+        # training
+        output, loss = train(category_tensor, line_tensor)
+
+        current_loss += loss
+
+        # Print iter number, loss, name and guess
+        if iter % print_every == 0:
+            guess, guess_i = categoryFromOutput(output)
+            correct = '✓' if guess == category else '✗ (%s)' % category
+            print('Iter: {:<7} time: {:>8s} loss: {:.4f} name: {:>10s} pred: {:>8s} label: {:>8s}'.format(
+                iter, timeSince(start), loss, line, guess, correct))
+
+        # Add current loss avg to list of losses
+        if iter % plot_every == 0:
+            all_losses.append(current_loss / plot_every)
+            current_loss = 0
+
+    path_model = os.path.join(BASE_DIR, "rnn_state_dict.pkl")
+    torch.save(rnn.state_dict(), path_model)
+    plt.plot(all_losses)
+    plt.show()
+
+    predict('Yue Tingsong')
+    predict('Yue tingsong')
+    predict('yutingsong')
+
+    predict('test your name')
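
rnn_demo.py saves only the state_dict, so inference in a fresh process has to rebuild the model with matching sizes before loading the weights. A minimal sketch, assuming the same n_letters, n_hidden and n_categories values as the training run:

import torch

# rebuild the model with the training-time dimensions, then load the saved weights
rnn = RNN(n_letters, n_hidden, n_categories)
rnn.load_state_dict(torch.load("rnn_state_dict.pkl", map_location="cpu"))
rnn.eval()

predict('Yue Tingsong')  # prints the top-3 language scores for the name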