|
- #!/usr/bin/env python36
- # -*- coding: utf-8 -*-
- """
- Created on July, 2018
-
- @author: Tangrizzly
- """
- import numpy as np
-
-
def data_masks(all_usr_pois, item_tail):
    """Right-pad variable-length sequences to a common length and build masks.

    Args:
        all_usr_pois: iterable of sequences of item ids (lists, tuples or
            1-D NumPy arrays are all accepted).
        item_tail: list repeated once per missing position to pad a
            sequence; callers in this file pass ``[0]``.

    Returns:
        ``(us_pois, us_msks, len_max)``:
        - ``us_pois``: each sequence as a list, padded with ``item_tail`` to
          ``len_max`` positions.
        - ``us_msks``: matching 0/1 lists, 1 for real items and 0 for padding.
        - ``len_max``: the longest input length.
        An empty input yields ``([], [], 0)`` instead of raising.
    """
    # Materialise once: the input is walked twice (lengths, then padding), and
    # list(...) makes NumPy rows concatenate rather than broadcast-add.
    seqs = [list(upois) for upois in all_usr_pois]
    if not seqs:
        return [], [], 0
    us_lens = [len(seq) for seq in seqs]
    len_max = max(us_lens)
    us_pois = [seq + item_tail * (len_max - le) for seq, le in zip(seqs, us_lens)]
    us_msks = [[1] * le + [0] * (len_max - le) for le in us_lens]
    return us_pois, us_msks, len_max
-
-
class Data():
    """In-memory dataset pairing session inputs with candidate item lists.

    Each example's candidate list holds the ground-truth item at column 0
    followed by its sampled negatives, so every target label is 0.
    """

    def __init__(self, hist, neg_sample, shuffle=False, graph=None):
        """Collect inputs, candidates and targets from ``hist``/``neg_sample``.

        Args:
            hist: sequence of examples; ``sess[0]`` is the input sequence of
                item ids and ``sess[1]`` is the ground-truth item.
            neg_sample: per-example iterable of negative item ids, aligned
                with ``hist`` (``zip`` stops at the shorter of the two).
            shuffle: accepted but not used by this class;
                NOTE(review): ``generate_batch`` never shuffles — confirm
                callers passing ``shuffle=True`` do not expect it to.
            graph: accepted but not used by this class.
        """
        inputs = []
        candidates = []  # (n_samples, K + 1); ground-truth item at column 0
        for sess, neg in zip(hist, neg_sample):
            inputs.append(sess[0])
            # list(neg) so an ndarray of negatives is concatenated, not
            # broadcast-added to the ground-truth id.
            candidates.append([sess[1]] + list(neg))
        # Always a 1-D object array so self.inputs[idx] yields the original
        # per-session sequences whether or not sessions share a length;
        # np.asarray on ragged nested lists raises ValueError on NumPy >= 1.24.
        self.inputs = np.empty(len(inputs), dtype=object)
        for pos, seq in enumerate(inputs):
            self.inputs[pos] = seq
        self.candidates = np.asarray(candidates)
        # The truth sits at column 0 of every candidate row, hence all zeros.
        self.targets = np.zeros((len(inputs),), dtype=int)
        self.length = len(inputs)

    def generate_batch(self, batch_size):
        """Split ``range(self.length)`` into consecutive index arrays.

        Every array has ``batch_size`` indices except possibly the last,
        which holds the remainder. An empty dataset yields ``[]``.

        Raises:
            ValueError: if ``batch_size`` is not positive.
        """
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")
        return [np.arange(start, min(start + batch_size, self.length))
                for start in range(0, self.length, batch_size)]

    def get_slice(self, i):
        """Return ``(inputs, mask, candidates, targets)`` for index array ``i``.

        ``i`` is expected to be one of the index arrays produced by
        ``generate_batch``. Inputs are right-padded with 0 to the longest
        session in the slice, and ``mask`` marks real positions with 1.
        """
        inputs, candidates, targets = self.inputs[i], self.candidates[i], self.targets[i]
        inputs, mask, _ = data_masks(inputs, [0])
        return np.asarray(inputs), np.asarray(mask), np.asarray(candidates), targets
-
-
def get_text(inputs, w2v_model, text_cut, MAX_LEN, emb_dim=300):
    """Build a zero-padded word-embedding tensor for every news item in a batch.

    Args:
        inputs: batch of sessions, each a sequence of news ids. Sessions of
            unequal length are zero-padded up to the longest one.
        w2v_model: object where ``w2v_model[word]`` returns a vector of
            length ``emb_dim`` (presumably a word2vec model — confirm
            with the caller).
        text_cut: mapping from news id to its sequence of tokens. It is read
            only; the caller's mapping is not modified.
        MAX_LEN: maximum number of tokens kept per news item; longer token
            sequences are truncated to their first ``MAX_LEN`` tokens.
        emb_dim: embedding width. It defaults to 300, the value previously
            hard-coded here.

    Returns:
        ``(text, text_length)``:
        - ``text``: float32 array of shape
          ``(batch, max_sess_len, max_text_len, emb_dim)``, zero-padded.
        - ``text_length``: per-session lists of truncated token counts.
        An empty batch gives an array of shape ``(0, 0, 0, emb_dim)`` and ``[]``.
    """
    batch_size = len(inputs)
    if batch_size == 0:
        return np.zeros((0, 0, 0, emb_dim), dtype=np.float32), []

    # Truncate locally instead of writing back into the caller's text_cut.
    truncated = [[text_cut[news_id][:MAX_LEN] for news_id in sess] for sess in inputs]
    text_length = [[len(words) for words in sess] for sess in truncated]
    max_sess_len = max(len(sess) for sess in truncated)
    max_text_len = max((n for lens in text_length for n in lens), default=0)

    text = np.zeros((batch_size, max_sess_len, max_text_len, emb_dim), dtype=np.float32)
    for i, sess in enumerate(truncated):
        for j, words in enumerate(sess):
            for k, word in enumerate(words):
                text[i, j, k] = w2v_model[word]

    return text, text_length
|