|
|
@@ -520,12 +520,22 @@ class Generator(nn.Module): |
|
|
|
device = self.input.input.device |
|
|
|
|
|
|
|
noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)] |
|
|
|
|
|
|
|
noises.append(torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)) |
|
|
|
for i in range(3, self.log_size + 1): |
|
|
|
for _ in range(3): |
|
|
|
for _ in range(2): |
|
|
|
noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device)) |
|
|
|
|
|
|
|
return noises |
|
|
|
forward_noises=[] |
|
|
|
target=[2,2+4,2+4+6,2+4+6+8,2+4+6+8+10] |
|
|
|
j=0 |
|
|
|
for i in range(1,(self.log_size-1)*(self.log_size-2)+1): |
|
|
|
if i>target[j]: |
|
|
|
j+=1 |
|
|
|
# forward_noises.append(torch.randn(1, 1, 2 ** (j+3), 2 ** (j+3), device=device)) |
|
|
|
# else: |
|
|
|
# j+=1 |
|
|
|
forward_noises.append(torch.randn(1, 1, 2 ** (j+3), 2 ** (j+3), device=device)) |
|
|
|
# forward_noise=[None]* () |
|
|
|
return noises,forward_noises |
|
|
|
|
|
|
|
def mean_latent(self, n_latent): |
|
|
|
latent_in = torch.randn( |
|
|
@@ -547,6 +557,7 @@ class Generator(nn.Module): |
|
|
|
truncation_latent=None, |
|
|
|
input_is_latent=False, |
|
|
|
noise=None, |
|
|
|
forward_noise=None, |
|
|
|
randomize_noise=True, |
|
|
|
): |
|
|
|
if not input_is_latent: |
|
|
@@ -587,7 +598,8 @@ class Generator(nn.Module): |
|
|
|
latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1) |
|
|
|
|
|
|
|
latent = torch.cat([latent, latent2], 1) |
|
|
|
forward_noise=[None]* ((self.log_size-1)*(self.log_size-2)) |
|
|
|
if forward_noise is None: |
|
|
|
forward_noise=[None]* ((self.log_size-1)*(self.log_size-2)) |
|
|
|
out = self.input(latent) |
|
|
|
out = self.conv1(out, latent[:, 0], noise=noise[0]) |
|
|
|
out = self.conv1_1(out, latent[:, 1], noise=noise[1]) |
|
|
|