|
- #!/usr/bin/env python36
- # -*- coding: utf-8 -*-
- """
- Created on July, 2018
-
- @author: Tangrizzly
- """
-
- import argparse
- import pickle
- import time
- from utils import Data
- from model import *
-
-
def init_seed(seed=None):
    """Seed the NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed. When ``None``, a seed is derived from the
            current wall-clock time.
    """
    if seed is None:
        # Multiplying by 1000 and floor-dividing by 1000 again leaves the
        # timestamp at (effectively) whole-second resolution.
        seed = int(time.time() * 1000 // 1000)

    np.random.seed(seed)
    # CPU generator first, then the current CUDA device, then all devices.
    for seed_fn in (torch.manual_seed,
                    torch.cuda.manual_seed,
                    torch.cuda.manual_seed_all):
        seed_fn(seed)
-
-
# Command-line options. They are parsed at import time into the module-level
# ``opt`` namespace, which ``main`` reads (dataset, epoch, patience) and also
# hands to the model constructor.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='Caixin_Neg', help='dataset name: Caixin')
parser.add_argument('--batchSize', type=int, default=4, help='input batch size')
parser.add_argument('--textHiddenSize', type=int, default=300, help='hidden state size')
parser.add_argument('--epoch', type=int, default=30, help='the number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.001, help='learning rate')  # [0.001, 0.0005, 0.0001]
parser.add_argument('--lr_dc', type=float, default=0.1, help='learning rate decay rate')
parser.add_argument('--lr_dc_step', type=int, default=3, help='the number of steps after which the learning rate decay')
parser.add_argument('--l2', type=float, default=1e-5, help='l2 penalty')  # [0.001, 0.0005, 0.0001, 0.00005, 0.00001]
# ``type=list`` would split a command-line value character by character
# (``--topk 10`` -> ['1', '0']). Accept one or more integers instead, e.g.
# ``--topk 5 10 20``; the default stays ``[5]``.
parser.add_argument('--topk', type=int, nargs='+', default=[5], help='topk recommendation')  # [5, 10, 20]
parser.add_argument('--patience', type=int, default=3, help='the number of epoch to wait before early stop ')
parser.add_argument('--MAX_TITLE', type=int, default=15)
parser.add_argument('--MAX_CONTENT', type=int, default=50)
opt = parser.parse_args()
print(opt)
-
-
def _load_pickle(path):
    """Return the object unpickled from ``path``, closing the file afterwards."""
    # NOTE: unpickling can execute arbitrary code; only load dataset files
    # that come from a trusted source.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


def main():
    """Run the train/evaluate loop and report the best test metrics.

    Loads the pickled dataset files for ``opt.dataset`` from ``/dataset/``,
    builds the model, then calls ``train_test`` once per epoch for up to
    ``opt.epoch`` epochs. The best value seen so far (and the epoch it was
    seen in) is tracked per metric and printed after every epoch. The loop
    stops early once ``opt.patience`` epochs in total have passed without
    any metric matching or beating its best value.
    """
    init_seed(2023)

    data_dir = '/dataset/' + opt.dataset + '/'
    train = _load_pickle(data_dir + 'train.txt')
    test = _load_pickle(data_dir + 'test.txt')

    # Items 0 and 3 of each pickled split are what ``Data`` is built from.
    train_hist, train_neg = train[0], train[3]
    test_hist, test_neg = test[0], test[3]

    train_data = Data(train_hist, train_neg, shuffle=False)
    test_data = Data(test_hist, test_neg, shuffle=False)

    # Author's note: indexed by news index; each entry is the list of
    # segmented words for that news item.
    content_cut = _load_pickle(data_dir + 'content_cut.pkl')

    # Author's note: {word: vector-300d}, vocabulary_size: 82700
    word_vectors_300d = _load_pickle(data_dir + 'word_vectors_300d.pkl')

    model = trans_to_cuda(proposed_model(opt))

    start = time.time()
    # Slot order in both lists: AUC, MRR, NDCG@5, NDCG@10.
    best_result = [0, 0, 0, 0]
    best_epoch = [0, 0, 0, 0]
    bad_counter = 0
    for epoch in range(opt.epoch):
        print('-------------------------------------------------------')
        print('epoch: ', epoch)
        auc_tes, mrr_tes, ndcg5_tes, ndcg10_tes = train_test(
            model, train_data, test_data, content_cut, word_vectors_300d)

        improved = False
        for slot, value in enumerate((auc_tes, mrr_tes, ndcg5_tes, ndcg10_tes)):
            # ``>=`` means a tie also counts and moves the recorded epoch.
            if value >= best_result[slot]:
                best_result[slot] = value
                best_epoch[slot] = epoch
                improved = True

        print('Best Result:')
        # NOTE(review): the 'MMR' label is kept verbatim; the value printed
        # next to it is the MRR slot.
        print('\tAuc:\t%.4f\tMMR:\t%.4f\tNDCG@5:\t%.4f\tNDCG@10:\t%.4f\tEpoch:\t%d,\t%d,\t%d,\t%d'
              % (best_result[0], best_result[1], best_result[2], best_result[3],
                 best_epoch[0], best_epoch[1], best_epoch[2], best_epoch[3]))

        # The counter is cumulative: it is never reset after an improvement.
        if not improved:
            bad_counter += 1
        if bad_counter >= opt.patience:
            break
    print('-------------------------------------------------------')
    end = time.time()
    print("Run time: %f s" % (end - start))
-
-
# Start the training run only when this file is executed as a script.
if __name__ == "__main__":
    main()
|