|
- # -*- coding: utf-8 -*-
- # @Author : William
- # @Project : TextGAN-william
- # @FileName : SentiGAN_G.py
- # @Time : Created at 2019-07-26
- # @Blog : http://zhiweil.ml/
- # @Description : SentiGAN generator (LSTMGenerator subclass with a policy-gradient loss).
- # Copyright (C) 2018. All Rights Reserved.
-
-
- import torch
- import torch.nn.functional as F
-
- from models.generator import LSTMGenerator
-
-
class SentiGAN_G(LSTMGenerator):
    """LSTM generator used by SentiGAN.

    Relies on ``self.embeddings``, ``self.lstm``, ``self.lstm2out``,
    ``self.hidden_dim``, ``self.vocab_size`` and ``self.init_hidden``, none of
    which are defined here (presumably provided by ``LSTMGenerator``), and adds
    a reward-weighted policy-gradient loss (:meth:`batchPGLoss`).
    """

    def __init__(self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu=False):
        super().__init__(embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu)
        self.name = 'sentigan'

    def forward(self, inp, hidden, need_hidden=False, use_log=True):
        """
        Embed the input tokens, run the LSTM and project onto the vocabulary.

        :param inp: batch_size * seq_len token ids, or batch_size for a single step
        :param hidden: (h, c) state passed straight to ``self.lstm``
        :param need_hidden: if True, also return the updated hidden state (used for sampling)
        :param use_log: if True return log_softmax scores, otherwise softmax probabilities
        :return: pred of shape (batch_size * seq_len) * vocab_size, or (pred, hidden) if need_hidden
        """
        emb = self.embeddings(inp)  # batch_size * seq_len * embedding_dim
        if inp.dim() == 1:
            # Single time step: add a length-1 sequence axis for the LSTM.
            emb = emb.unsqueeze(1)  # batch_size * 1 * embedding_dim

        out, hidden = self.lstm(emb, hidden)  # out: batch_size * seq_len * hidden_dim
        out = out.contiguous().view(-1, self.hidden_dim)  # (batch_size * seq_len) * hidden_dim
        out = self.lstm2out(out)  # (batch_size * seq_len) * vocab_size
        # No temperature scaling is applied to the logits in this generator.
        pred = F.log_softmax(out, dim=-1) if use_log else F.softmax(out, dim=-1)

        if need_hidden:
            return pred, hidden
        return pred

    def batchPGLoss(self, inp, target, reward):
        """
        Returns a policy gradient loss

        :param inp: batch_size x seq_len, inp should be target with <s> (start letter) prepended
        :param target: batch_size x seq_len token ids (int64)
        :param reward: either batch_size (one reward per sentence, applied to every token of
            that sentence), batch_size x seq_len (one reward per token), or a scalar
        :return loss: scalar policy loss, -sum over batch and time of p(target_t) * (1 - reward)
        """
        batch_size, seq_len = inp.size()
        hidden = self.init_hidden(batch_size)

        # Softmax probabilities (not log-probabilities) over the vocabulary at each step.
        # Use the actual seq_len so inputs shorter than max_seq_len are handled.
        out = self.forward(inp, hidden, use_log=False).view(batch_size, seq_len, self.vocab_size)
        # Probability of the target token at each position; gather avoids building a
        # batch_size x seq_len x vocab_size one-hot tensor.
        pred = torch.gather(out, 2, target.unsqueeze(-1)).squeeze(-1)  # batch_size * seq_len

        if torch.is_tensor(reward) and reward.dim() == 1:
            # Per-sentence reward: broadcast across the time axis rather than
            # (incorrectly) along it.
            reward = reward.unsqueeze(-1)  # batch_size * 1
        # NOTE(review): minimizing this loss increases p(target) in proportion to
        # (1 - reward); confirm this sign matches the intended SentiGAN penalty objective.
        loss = -torch.sum(pred * (1 - reward))

        return loss
|