Browse Source

更新 'model_forward.py'

main
mohenghui 1 month ago
parent
commit
be9b1efe9e
1 changed files with 17 additions and 5 deletions
  1. +17
    -5
      model_forward.py

+ 17
- 5
model_forward.py View File

@@ -520,12 +520,22 @@ class Generator(nn.Module):
device = self.input.input.device

noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
noises.append(torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device))
for i in range(3, self.log_size + 1):
for _ in range(3):
for _ in range(2):
noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))

return noises
forward_noises=[]
target=[2,2+4,2+4+6,2+4+6+8,2+4+6+8+10]
j=0
for i in range(1,(self.log_size-1)*(self.log_size-2)+1):
if i>target[j]:
j+=1
# forward_noises.append(torch.randn(1, 1, 2 ** (j+3), 2 ** (j+3), device=device))
# else:
# j+=1
forward_noises.append(torch.randn(1, 1, 2 ** (j+3), 2 ** (j+3), device=device))
# forward_noise=[None]* ()
return noises,forward_noises

def mean_latent(self, n_latent):
latent_in = torch.randn(
@@ -547,6 +557,7 @@ class Generator(nn.Module):
truncation_latent=None,
input_is_latent=False,
noise=None,
forward_noise=None,
randomize_noise=True,
):
if not input_is_latent:
@@ -587,7 +598,8 @@ class Generator(nn.Module):
latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)

latent = torch.cat([latent, latent2], 1)
forward_noise=[None]* ((self.log_size-1)*(self.log_size-2))
if forward_noise is None:
forward_noise=[None]* ((self.log_size-1)*(self.log_size-2))
out = self.input(latent)
out = self.conv1(out, latent[:, 0], noise=noise[0])
out = self.conv1_1(out, latent[:, 1], noise=noise[1])


Loading…
Cancel
Save