import random

import torch


class PromptSpell(torch.nn.Module):
    """Learnable prompt ("spell") embeddings, optionally reparameterized by an LSTM or MLP head."""

    def __init__(self, spell_length, hidden_size, spell_func):
        super(PromptSpell, self).__init__()
        self.spell_length = spell_length
        self.hidden_size = hidden_size
        self.spell_embeddings = torch.nn.Embedding(self.spell_length, self.hidden_size)
        self.spell_func = spell_func
        if self.spell_func == "lstm":
            # Bidirectional LSTM encoder; its 2 * hidden_size output is projected back
            # to hidden_size by a two-layer MLP.
            self.lstm_head = torch.nn.LSTM(input_size=self.hidden_size,
                                           hidden_size=self.hidden_size,
                                           num_layers=2,
                                           # dropout=self.lstm_dropout,
                                           bidirectional=True,
                                           batch_first=True)  # .to(torch.device("cuda"))
            self.mlp_head = torch.nn.Sequential(torch.nn.Linear(2 * self.hidden_size, self.hidden_size),
                                                torch.nn.ReLU(),
                                                torch.nn.Linear(self.hidden_size, self.hidden_size))
        elif self.spell_func == "mlp":
            self.mlp_head = torch.nn.Sequential(torch.nn.Linear(self.hidden_size, self.hidden_size),
                                                torch.nn.ReLU(),
                                                torch.nn.Linear(self.hidden_size, self.hidden_size))
        elif self.spell_func != "none":
            raise NotImplementedError("Prompt function " + self.spell_func)

    def init_embedding(self, word_embeddings=None, task_tokens=None):
        # Initialize each prompt position from a randomly drawn word embedding; if task
        # tokens are given, interpolate between the word embedding and a random task-token
        # embedding with a random mixing ratio. `word_embeddings` is the backbone model's
        # (vocab_size, hidden_size) embedding matrix and must be provided.
        num_words = 5000
        with torch.no_grad():
            for i in range(self.spell_length):
                rand_token = random.randrange(num_words)
                if task_tokens is None:
                    target_embedding = word_embeddings[rand_token]
                else:
                    word_embedding = word_embeddings[rand_token]
                    task_token = random.choice(task_tokens)
                    task_embedding = word_embeddings[task_token]
                    ratio = random.random()
                    target_embedding = word_embedding * ratio + task_embedding * (1 - ratio)
                self.spell_embeddings.weight.data[i] = target_embedding

    def forward(self):
        # Prompt embeddings with a leading batch dimension: (1, spell_length, hidden_size).
        prompt_embeds = self.spell_embeddings.weight.unsqueeze(0)
        if self.spell_func == "lstm":
            prompt_embeds = self.lstm_head(prompt_embeds)[0]
        if self.spell_func == "lstm" or self.spell_func == "mlp":
            prompt_embeds = self.mlp_head(prompt_embeds)
        return prompt_embeds
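

# Minimal usage sketch (illustration only, not part of the module): the spell length,
# hidden size, task-token ids, and the random stand-in embedding matrix below are
# assumed values; vocab_size matches the num_words constant used in init_embedding.
if __name__ == "__main__":
    hidden_size = 64
    vocab_size = 5000
    spell = PromptSpell(spell_length=8, hidden_size=hidden_size, spell_func="lstm")
    # Stand-in for the backbone model's word-embedding matrix.
    word_embeddings = torch.randn(vocab_size, hidden_size)
    spell.init_embedding(word_embeddings=word_embeddings, task_tokens=[1, 2, 3])
    prompt = spell()  # expected shape: (1, 8, hidden_size)
    print(prompt.shape)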