import time

import numpy
import torch
import tqdm
from transformers import BertTokenizer

from argparams import get_args
from data_process.train_data import GLTrainDataset, PretrainDataset, ULTrainDataset
from data_process.util import DataHelper
from models.THHGR import TransformerHHGR

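# Pairwise regression loss shared by the pre-training and group-level stages.
# The original expression, mean((pos_pred - neg_pred) ** 2), is minimised when
# the positive and negative scores coincide, so it cannot learn a ranking; the
# AGREE-style margin of 1 is restored here (an assumption based on that family
# of group recommenders -- adjust the margin if a different loss was intended).
def pairwise_regression_loss(pos_pred, neg_pred, margin=1.0):
    return torch.mean((pos_pred - neg_pred - margin) ** 2)
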
def train(
    model: TransformerHHGR,
    tokenizer,
    pretrain,
    ul_train,
    gl_train,
    all_user,
    all_group,
    h_ul_coarse,
    h_ul_fine,
    h_gl,
    train_gu,
    group_dict,
    user_dict,
    epochs,
):
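    """Three-stage training: (1) pre-train user-item scoring, (2) learn the
    user-level coarse/fine hypergraph channels with a mutual-information
    objective, (3) train group-item scoring on the group-level hypergraph."""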
    model.train()

    print("Pre-training model on user-item interactions")
    with tqdm.tqdm(range(epochs), desc="pretrain") as pre_tqdm:
        for epoch in pre_tqdm:
            start = time.time()

            uloss = []
            for user_in, pos_item, neg_item in pretrain:
                # Score each user against a positive and a sampled negative item.
                pos_pred = model("user", user_inputs=user_in, item_inputs=pos_item, tokenizer=tokenizer)
                neg_pred = model("user", user_inputs=user_in, item_inputs=neg_item, tokenizer=tokenizer)

                model.zero_grad()
                loss = pairwise_regression_loss(pos_pred, neg_pred)
                uloss.append(float(loss))
                # The graph is rebuilt every batch, so retain_graph is unnecessary here.
                loss.backward()
                model.optimizer.step()

            end = time.time()
            pre_tqdm.write("Pretrain | epoch {:^3d} | loss {:^7.5f} | time {:^4.2f}".format(epoch, numpy.mean(uloss), end - start))

    print("Starting user-level hypergraph learning")
    # In this stage the base user embeddings are frozen (detached); only the
    # hypergraph convolutions and the discriminator receive gradients.
    all_user_embedding = model.to_embedding(all_user, tokenizer).detach()
    user_embed_coarse = model.hgcn_coarse(all_user_embedding, h_ul_coarse)
    user_embed_fine = model.hgcn_fine(all_user_embedding, h_ul_fine)

    with tqdm.tqdm(range(epochs), desc="ul_train") as ul_tqdm:
        for epoch in ul_tqdm:
            start = time.time()

            ul_loss = []

            for ui_inst, pos_user, neg_user in ul_train:
                model.optimizer.zero_grad()

                # Contrast the coarse-channel view of a user with its own
                # fine-channel view (positive pair) and with another user's
                # coarse-channel view (negative pair).
                uembed_coarse = user_embed_coarse[pos_user]
                uembed_fine = user_embed_fine[pos_user]
                neg_uembed = user_embed_coarse[neg_user]

                pos_score = model.discriminator(uembed_coarse, uembed_fine)
                neg_score = model.discriminator(uembed_coarse, neg_uembed)
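                # (Assumption) mi_loss below is expected to be a Deep-Infomax-style
                # binary cross-entropy that pushes pos_score toward 1 and
                # neg_score toward 0, roughly:
                #   bce = torch.nn.BCEWithLogitsLoss()
                #   bce(pos_score, torch.ones_like(pos_score))
                #       + bce(neg_score, torch.zeros_like(neg_score))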
                mi_loss = model.discriminator.mi_loss(pos_score, neg_score)
                # detect_anomaly is a debugging aid and slows training; retain_graph
                # keeps the channel-embedding graph alive, because it is built once
                # above and reused across every batch and epoch.
                with torch.autograd.detect_anomaly():
                    mi_loss.backward(retain_graph=True)
                ul_loss.append(float(mi_loss))

                model.optimizer.step()

            end = time.time()
            ul_tqdm.write("UL train | epoch {:^3d} | loss {:^7.5f} | time {:^4.2f}".format(epoch, numpy.mean(ul_loss), end - start))

    print("Starting group-level hypergraph learning")
    # Freeze the user-level channels and fuse them by summation into the final
    # user representation used when aggregating group members.
    all_user_embedding = model.to_embedding(all_user, tokenizer).detach()
    user_embed_coarse = model.hgcn_coarse(all_user_embedding, h_ul_coarse).detach()
    user_embed_fine = model.hgcn_fine(all_user_embedding, h_ul_fine).detach()
    user_embedding = user_embed_coarse + user_embed_fine

    with tqdm.tqdm(range(epochs), desc="gl_train") as gl_tqdm:
        for epoch in gl_tqdm:
            start = time.time()
            gl_loss = []

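            # Group embeddings are refreshed once per epoch, since the
            # group-level HGCN itself is being updated in this stage.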
            all_group_embedding = model.to_embedding(all_group, tokenizer)
            group_embeds = model.hgcn_gl(all_group_embedding, h_gl)

            for group_in, pos_item, neg_item in gl_train:
                model.optimizer.zero_grad()

                pos_pred = model(
                    "group",
                    group_inputs=group_in,
                    item_inputs=pos_item,
                    user_embedding=user_embedding,
                    group_embedding=group_embeds,
                    group_users=train_gu,
                    group_dict=group_dict,
                    user_dict=user_dict,
                    tokenizer=tokenizer,
                )
                neg_pred = model(
                    "group",
                    group_inputs=group_in,
                    item_inputs=neg_item,
                    user_embedding=user_embedding,
                    group_embedding=group_embeds,
                    group_users=train_gu,
                    group_dict=group_dict,
                    user_dict=user_dict,
                    tokenizer=tokenizer,
                )

                loss = pairwise_regression_loss(pos_pred, neg_pred)
                gl_loss.append(float(loss))
                # retain_graph keeps the per-epoch group_embeds graph alive across
                # batches; detect_anomaly is a debugging aid and slows training.
                with torch.autograd.detect_anomaly():
                    loss.backward(retain_graph=True)
                model.optimizer.step()

            end = time.time()
            gl_tqdm.write("GL train | epoch {:^3d} | loss {:^7.5f} | time {:^4.2f}".format(epoch, numpy.mean(gl_loss), end - start))


def CloudBrainTrain():
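    """Entry point for training on the CloudBrain platform: builds the data
    pipeline, runs the three-stage training, and saves the trained model."""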
    args = get_args()

    if torch.cuda.is_available() and args.cuda:
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    tokenizer = BertTokenizer.from_pretrained(args.bert_path)

    data_helper = DataHelper(args.dataset)
    pretrain = PretrainDataset(args.dataset, args.batch_size, args.num_negs)
    ul_train = ULTrainDataset(args.dataset, args.batch_size, data_helper.user_dict)
    gl_train = GLTrainDataset(args.dataset, args.batch_size)

    model = TransformerHHGR(args).to(device)

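    # Note: only the model is moved to `device`; whether each batch also needs
    # an explicit .to(device) depends on the Dataset implementations above.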
    train(model, tokenizer, pretrain, ul_train, gl_train, data_helper.all_user, data_helper.all_group, data_helper.h_ul_coarse,
          data_helper.h_ul_fine, data_helper.train_hgg, data_helper.train_gu, data_helper.group_dict, data_helper.user_dict, args.epoches)

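    # "/model" is presumably the platform's mounted output directory; the full
    # module is saved for simplicity (a state_dict would be more portable).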
    torch.save(model, "/model/thhgr.pth")


if __name__ == "__main__":
    CloudBrainTrain()