|
import os
import subprocess
import time

import torch
from transformers import BertTokenizer

from argparams import get_args
from data_process.data_store import SaveDataset, LoadDataset
from data_process.eval_data import EvaluateDataset
from data_process.train_data import PretrainDataset, ULTrainDataset, GLTrainDataset
from data_process.util import DataHelper
from evaluate import evaluate
from inference import inference
from models.THHGR import TransformerHHGR
from train import train
-
-
-
def main():
    """Script entry point.

    Parses command-line arguments, builds the tokenizer, datasets and
    model, then runs whichever of the train / eval / inference stages
    the flags request.
    """
    # NOTE(review): remounting the root filesystem read-write is a
    # platform-specific hack, and the option string looks malformed:
    # "remount rw" is two separate mount arguments, where "remount,rw"
    # is presumably intended -- confirm with whoever added this.
    # Fixed here: use subprocess with an explicit argv list instead of
    # os.system with a shell string (no shell injection surface); the
    # argv below reproduces exactly the tokens the shell would have
    # passed to mount, so runtime behavior is unchanged.
    subprocess.run(["mount", "-o", "remount", "rw", "/"], check=False)

    args = get_args()

    # Use the GPU only when one is available AND the user asked for it.
    device = torch.device(
        "cuda" if torch.cuda.is_available() and args.cuda else "cpu"
    )

    tokenizer = BertTokenizer.from_pretrained(args.bert_path)

    # Data pipeline: one shared helper, three training datasets
    # (pretraining, user-level, group-level) and the evaluation set.
    data_helper = DataHelper(args.dataset)
    pretrain = PretrainDataset(args.dataset, args.batch_size, args.num_negs)
    ul_train = ULTrainDataset(args.dataset, args.batch_size, data_helper.user_dict)
    gl_train = GLTrainDataset(args.dataset, args.batch_size)
    geval = EvaluateDataset(
        args.dataset, data_helper.num_items, args.num_negs, data_helper.item_dict
    )

    model = TransformerHHGR(args).to(device)

    if args.train:
        train(model, tokenizer, pretrain, ul_train, gl_train,
              data_helper.all_user, data_helper.all_group,
              data_helper.h_ul_coarse, data_helper.h_ul_fine,
              data_helper.train_hgg, data_helper.train_gu,
              data_helper.group_dict, data_helper.user_dict, args.epoches)

    if args.eval:
        # Evaluate at every cutoff k in [1, max_k].
        for k in range(1, args.max_k + 1):
            evaluate(model, tokenizer, geval,
                     data_helper.all_user, data_helper.all_group,
                     data_helper.eval_gu, data_helper.h_ul_coarse,
                     data_helper.h_ul_fine, data_helper.eval_hgg,
                     data_helper.group_dict, data_helper.user_dict, k)

    if args.infer:
        inference()
-
-
-
# Standard script guard: run main() only when executed directly, not
# when this module is imported by another.
if __name__ == "__main__":
    main()
|