|
- import os
-
- import torch
- from torch.utils.data import DataLoader
- import torch.optim as optim
- from ..models import build_model
- from . import BaseFlow, register_flow
- from ..sampler import mg2vec_sampler
- import numpy as np
-
-
@register_flow('mg2vec_trainer')
class Mg2vecTrainer(BaseFlow):
    """Trainer flow for mg2vec node embeddings.

    Trains embeddings from metagraph samples read by ``Mg2vecSampler`` and
    evaluates them on a downstream edge-classification task. Embeddings are
    saved to ``<output_dir>/<dataset>_mg2vec_embeddings.npy`` (numpy) and
    ``<output_dir>/<dataset>_mg2vec_embeddings.txt``.

    Attributes such as ``self.args``, ``self.hg``, ``self.device``,
    ``self.task``, ``self.logger`` and ``self.model_name`` are presumably set by
    ``BaseFlow.__init__`` (not visible here — confirm).
    """

    def __init__(self, args):
        super().__init__(args)
        self.mg2vec_sampler = None  # built lazily by _build_sampler()
        self.dataloader = None
        self.model = None
        self.embeddings_file_path = os.path.join(
            self.args.output_dir, self.args.dataset + '_mg2vec_embeddings.npy')
        self.embeddings_file_path2 = os.path.join(
            self.args.output_dir, self.args.dataset + '_mg2vec_embeddings.txt')
        # When True and the .npy file already exists, reuse it instead of
        # retraining. Falls back to False (the previous hard-coded value) when
        # the option is absent from args.
        self.load_trained_embeddings = getattr(self.args, 'load_trained_embeddings', False)

    def _build_sampler(self):
        """Create the metagraph sampler and its DataLoader.

        Also publishes the sampler's vocabulary statistics (node count,
        metagraph count, unigram table) on ``self.args`` so the model
        builder can read them.
        """
        input_file = "./openhgnn/dataset/{}/meta.txt".format(self.args.dataset)
        # Passed to the sampler as its block size; presumably the number of
        # records read per block (semantics live in Mg2vecSampler — confirm).
        block_size = self.args.batch_size * 100000
        self.mg2vec_sampler = mg2vec_sampler.Mg2vecSampler(input_file, block_size, self.args.alpha)
        self.dataloader = DataLoader(
            self.mg2vec_sampler,
            batch_size=self.args.batch_size,
            shuffle=True,
            num_workers=self.args.num_workers,
        )
        self.args.node_num = self.mg2vec_sampler.data.node_count
        self.args.mg_num = self.mg2vec_sampler.data.mg_count
        self.args.unigram = self.mg2vec_sampler.data.unigram

    def preprocess(self):
        """Build the sampler/DataLoader, then the mg2vec model on ``self.device``."""
        self._build_sampler()
        # Must run after _build_sampler(): the model reads node_num/mg_num/unigram from args.
        self.model = build_model(self.model_name).build_model_from_args(self.args, self.hg).to(self.device)

    def train(self):
        """Obtain embeddings (trained or cached) and log the downstream test metric."""
        emb = self.load_embeddings()
        if self.mg2vec_sampler is None:
            # Embeddings were loaded from disk, so the sampler (which owns the
            # row-index -> node-id mapping) was never built. Assumes the sampler
            # assigns row indices deterministically, so the mapping matches the
            # saved file — TODO confirm against Mg2vecSampler.
            self._build_sampler()
        emb_dict = {
            int(node): emb[row]
            for row, node in self.mg2vec_sampler.data.node_reverse_dict.items()
        }
        # TODO: only edge classification is supported for now.
        edge_features = self.get_edge_embed(emb=emb_dict)
        metric = {
            'test': self.task.downstream_evaluate(logits=edge_features, evaluation_metric='acc_f1')}
        self.logger.train_info(self.logger.metric2str(metric))

    def load_embeddings(self):
        """Return the embedding matrix, training it first unless a cached file may be reused."""
        if not self.load_trained_embeddings or not os.path.exists(self.embeddings_file_path):
            self.train_embeddings()
        return np.load(self.embeddings_file_path)

    def train_embeddings(self):
        """Train the model for ``args.max_epoch`` epochs and save the embeddings.

        The sampler's data is consumed block by block: each pass of the
        DataLoader covers the current block, after which the next block is
        read. ``data.epoch_end`` signals that a full pass over the input has
        finished.
        """
        self.preprocess()
        log_interval = 10000  # steps between average-loss reports
        epoch_index = 1
        optimizer = optim.Adam(self.model.parameters(), lr=self.args.lr)
        average_loss = 0.0
        step = 0
        self.logger.train_info("train start")
        while True:
            for sampled_batch in self.dataloader:
                if len(sampled_batch) == 0:
                    continue
                train_a = sampled_batch[0].to(self.device)
                train_b = sampled_batch[1].to(self.device)
                train_label = sampled_batch[2].to(self.device)
                train_freq = sampled_batch[3].reshape(-1, 1).to(self.device)
                train_weight = sampled_batch[4].reshape(-1, 1).to(self.device)

                optimizer.zero_grad()
                # Call the module (not .forward) so registered hooks run.
                loss = self.model(train_a, train_b, train_label, train_freq, train_weight, self.device)
                loss.backward()
                optimizer.step()

                average_loss += loss.item()
                step += 1
                if step % log_interval == 0:
                    self.logger.train_info(
                        'Average loss at step {}: {}'.format(step, average_loss / log_interval))
                    average_loss = 0.0
            if self.mg2vec_sampler.data.epoch_end:
                self.logger.train_info("epoch %d end" % epoch_index)
                epoch_index += 1
                self.mg2vec_sampler.data.epoch_end = False
                if epoch_index > self.args.max_epoch:
                    break
            self.mg2vec_sampler.data.read_block()

        self.logger.train_info("total step: {}".format(step))
        # The output directory may not exist yet on a fresh run.
        os.makedirs(self.args.output_dir, exist_ok=True)
        self.model.save_embedding_np(self.embeddings_file_path)
        self.model.save_embedding(self.mg2vec_sampler.data.node_reverse_dict, self.embeddings_file_path2)

    def get_edge_embed(self, emb):
        """Build one feature row per edge of ``self.hg``.

        Args:
            emb: mapping from original node id (int) to its embedding vector.

        Returns:
            A numpy array whose i-th row is ``hstack([emb[src_i], emb[dst_i]])``,
            where src/dst ids come from the ``id2node`` data of the ``core1``
            and ``core2`` node types.
        """
        g = self.hg
        u, v = g.edges()
        # Gather on CPU so the lookup works even when the graph lives on an
        # accelerator; reshape(-1) tolerates an (N, 1)-shaped id2node tensor.
        src_ids = g.nodes['core1'].data['id2node'].cpu()[u.cpu()].reshape(-1).tolist()
        dst_ids = g.nodes['core2'].data['id2node'].cpu()[v.cpu()].reshape(-1).tolist()
        return np.array([
            np.hstack([emb[int(src)], emb[int(dst)]])
            for src, dst in zip(src_ids, dst_ids)
        ])
|