- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- # @File : text_classification_tl.py
- # @Date : 2021/12/9
- # @Desc   : Text classification based on TensorLayer
- # @Author : Hou
-
- # The same code can switch backends by changing a single line (the TL_BACKEND environment variable)
- import os
- import sys
- import argparse
- import numpy as np
- import pickle
-
- parser = argparse.ArgumentParser(description=__doc__)
- parser.add_argument('--tl_backend', choices=['tensorflow', 'mindspore', 'paddle'], default="tensorflow",
-                     help="Select which backend TensorLayer uses, defaults to tensorflow.")
- parser.add_argument('--device', choices=['cpu', 'gpu', 'xpu', 'npu'], default="cpu",
-                     help="Select which device to train the model on, defaults to cpu.")
- 
- parser.add_argument('--network', choices=['rnn', 'lstm', 'gru', 'transformerEncoder'],
-                     default="lstm", help="Select which network to train, defaults to lstm.")
- 
- parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate used to train.")
- parser.add_argument("--epochs", type=int, default=1, help="Number of epochs for training.")
- parser.add_argument("--seq_len", type=int, default=200, help="Input sequence length.")
- parser.add_argument("--batch_size", type=int, default=64, help="Number of examples per training batch.")
- 
- parser.add_argument("--print_freq", type=int, default=1, help="Logging frequency, in epochs.")
- parser.add_argument("--model_path", type=str, default='model.npz', help="Path used to save/load the model checkpoint.")
- 
- parser.add_argument("--run_mode", choices=['train', 'eval'], default='train',
-                     help="Select run mode, defaults to train.")
-
- args = parser.parse_args()
- print(args)
-
- # TL_BACKEND must be set before `tensorlayer` is imported; valid values are
- # 'tensorflow', 'mindspore' and 'paddle'.
- os.environ['TL_BACKEND'] = args.tl_backend
-
- import tensorlayer as tl
- from tensorlayer.dataflow import Dataset
- from model import initial_network
-
-
- def get_word_set(contents):
-     """Return the set of distinct word ids appearing in `contents`."""
-     word_set = set()
-     for seq in contents:
-         word_set.update(seq)
-     return word_set
-
-
- def get_word_freq(contents):
-     """Count how often each word id occurs across all sequences in `contents`."""
-     word_dict = dict()
-     for seq in contents:
-         for word in seq:
-             if word in word_dict:
-                 word_dict[word] += 1
-             else:
-                 word_dict[word] = 1
-     return word_dict
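- 
- # For reference, a minimal standard-library equivalent of get_word_freq
- # (illustrative only; it is not used by the pipeline):
- #
- #     from collections import Counter
- #     from itertools import chain
- #     word_freq = Counter(chain.from_iterable(contents))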
-
-
- class ImdbDataset(Dataset):
-     """Wraps the raw IMDB arrays as a Dataset of fixed-length (data, label) pairs."""
- 
-     def __init__(self, X, y, seq_len):
-         self.X = X
-         self.y = y
-         self.seq_len = seq_len
- 
-     def __getitem__(self, index):
-         # Truncate to `seq_len`, then right-pad with zeros. For sequences longer
-         # than `seq_len` the pad length is negative, so the pad list is empty and
-         # the sequence is only cut.
-         data = self.X[index]
-         data = np.concatenate([data[:self.seq_len], [0] * (self.seq_len - len(data))]).astype('int64')
-         label = self.y[index].astype('int64')
-         return data, label
- 
-     def __len__(self):
-         return len(self.y)
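- 
- 
- # Quick self-check of the padding logic above; a hypothetical helper for
- # illustration, not called anywhere in the script.
- def _demo_pad_and_truncate():
-     X = [np.array([3, 1, 4, 1, 5, 9, 2]), np.array([2, 7])]
-     y = np.array([1, 0])
-     ds = ImdbDataset(X, y, seq_len=5)
-     assert ds[0][0].tolist() == [3, 1, 4, 1, 5]  # longer sequence is truncated
-     assert ds[1][0].tolist() == [2, 7, 0, 0, 0]  # shorter sequence is zero-padded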
-
-
- def create_data_loader(x, y, seq_len, batch_size):
-     """Build a shuffled Dataloader over (x, y).
- 
-     :param x: list of word-id sequences
-     :param y: array of integer class labels
-     :param seq_len: fixed length to which every sequence is truncated or padded
-     :param batch_size: number of examples per batch
-     :return: a tl.dataflow.Dataloader yielding (data, label) batches
-     """
-     dataset = ImdbDataset(x, y, seq_len)
-     dataset = tl.dataflow.FromGenerator(
-         dataset, output_types=[tl.int64, tl.int64], column_names=['data', 'label']
-     )
-     loader = tl.dataflow.Dataloader(dataset, batch_size=batch_size, shuffle=True)
-     return loader
-
-
- def imdb_data_load(path='data', nb_words=20000, test_split=0.2):
-     """Load the IMDB split, caching it as a pickle so later runs skip the download."""
-     filename = "imdb_split.pkl"
-     filepath = os.path.join(path, 'imdb', filename)
-     if os.path.exists(filepath):
-         with open(filepath, 'rb') as fin:
-             return pickle.load(fin)
-     data = tl.files.load_imdb_dataset(path, nb_words=nb_words, test_split=test_split)
-     os.makedirs(os.path.dirname(filepath), exist_ok=True)
-     with open(filepath, 'wb') as fout:
-         pickle.dump(data, fout)
-     return data
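- 
- # Note: the pickle cache ignores `nb_words` and `test_split`; after changing
- # either argument, delete data/imdb/imdb_split.pkl or the stale split is
- # returned.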
-
-
- def model_initial(net_name, vocab_size, label_num):
-     """Assemble the network, optimizer, metric and loss into a trainable Model."""
-     net = initial_network(net_name, vocab_size, label_num)
- 
-     optimizer = tl.optimizers.Adam(args.lr)  # honour --lr instead of a hard-coded rate
-     metric = tl.metric.Accuracy()
-     loss_fn = tl.cost.softmax_cross_entropy_with_logits
-     model = tl.models.Model(network=net, loss_fn=loss_fn, optimizer=optimizer, metrics=metric)
-     return model
-
-
- def model_train():
-     X_train, y_train, X_test, y_test = imdb_data_load('data', nb_words=20000, test_split=0.2)
-     train_loader = create_data_loader(X_train, y_train, args.seq_len, args.batch_size)
-     test_loader = create_data_loader(X_test, y_test, args.seq_len, args.batch_size)
- 
-     label_num = len(set(y_train.tolist()))
-     w_dict = get_word_freq(X_train)
-     print("train instances:", len(X_train), "distinct words:", len(w_dict))
-     # The dataset is loaded with nb_words=20000, so word ids can be as large as
-     # 20000 and the embedding table needs 20000 + 1 rows. Sizing it from the
-     # number of distinct words (len(w_dict) + 1) can be smaller than the largest
-     # id and breaks the embedding lookup.
-     vocab_size = 20000 + 1
- 
-     model = model_initial(args.network, vocab_size, label_num)
- 
-     print("Model train ...")
-     model.train(n_epoch=args.epochs, train_dataset=train_loader, test_dataset=test_loader,
-                 print_freq=args.print_freq, print_train_batch=False)
- 
-     print("Model eval ...")
-     model.eval(test_loader)
- 
-     print("Model save ...")
-     model.save_weights(args.model_path, format='npz')
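- 
- 
- # Example invocations (flags as defined by the parser above):
- #
- #     python text_classification_tl.py --tl_backend tensorflow --network lstm --epochs 5
- #     python text_classification_tl.py --run_mode eval --model_path model.npz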
-
-
- def model_eval():
-     X_train, y_train, X_test, y_test = imdb_data_load('data', nb_words=20000, test_split=0.2)
-     test_loader = create_data_loader(X_test, y_test, args.seq_len, args.batch_size)
-     label_num = len(set(y_train.tolist()))
-     print("Train instances:", len(y_train))
-     print("Test instances:", len(y_test))
-     print("Train positive labels:", sum(y_train))
-     print("Test positive labels:", sum(y_test))
- 
-     # Must match the vocabulary size used at training time.
-     vocab_size = 20000 + 1
- 
-     model = model_initial(args.network, vocab_size, label_num)
- 
-     print("Model load ...")
-     model.load_weights(args.model_path, format='npz')
- 
-     model.eval(test_loader)
-
-
- if __name__ == "__main__":
-     if args.run_mode == "train":
-         model_train()
-     elif args.run_mode == "eval":
-         model_eval()
-     else:
-         # Unreachable in practice: argparse restricts --run_mode to train/eval.
-         print(args)
-     print("Done.")