| | import torch |
| | import numpy as np |
| | from torch_utils.ops import bias_act |
| | from torch_utils import misc |
| |
|
| |
|
| |
|
def normalize_2nd_moment(x, dim=1, eps=1e-8):
    """Rescale ``x`` so that its mean square along ``dim`` is about one.

    Args:
        x: Input tensor.
        dim: Dimension over which the second moment is computed.
        eps: Small constant added to the second moment before taking the
            reciprocal square root, which avoids division by zero.

    Returns:
        A tensor with the same shape as ``x``.
    """
    second_moment = x.square().mean(dim=dim, keepdim=True)
    inv_rms = (second_moment + eps).rsqrt()
    return x * inv_rms
| |
|
| |
|
class FullyConnectedLayer_normal(torch.nn.Module):
    """Thin wrapper around ``torch.nn.Linear`` with a constant bias initializer.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: If true, the layer has a learnable bias.
        bias_init: Constant written into the bias at construction time
            (ignored when ``bias`` is false).
    """

    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 bias_init=0,
                 ):
        super().__init__()
        self.fc = torch.nn.Linear(in_features, out_features, bias=bias)
        if not bias:
            return
        # Replace the default bias values. no_grad is required because the
        # bias is a leaf parameter that requires grad, and fill_ is in place.
        with torch.no_grad():
            self.fc.bias.fill_(bias_init)

    def forward(self, x):
        """Apply the wrapped linear layer to ``x``."""
        return self.fc(x)
| |
|
| |
|
class MappingNetwork_normal(torch.nn.Module):
    """Plain MLP mapping network built from ``torch.nn.Linear`` layers.

    Each linear layer is followed by ``LeakyReLU(0.2)``. The first layer maps
    ``in_features`` to ``int_dim``, and every later layer maps ``int_dim`` to
    ``int_dim``.

    Args:
        in_features: Input feature size.
        int_dim: Hidden and output feature size.
        num_layers: Number of linear layers. Values below 1 still build a
            single layer.
        mapping_normalization: If true, the input is rescaled with
            ``normalize_2nd_moment`` before entering the MLP.
    """

    def __init__(self,
                 in_features,
                 int_dim,
                 num_layers=8,
                 mapping_normalization=False
                 ):
        super().__init__()
        modules = []
        for layer_idx in range(max(num_layers, 1)):
            fan_in = in_features if layer_idx == 0 else int_dim
            modules += [torch.nn.Linear(fan_in, int_dim), torch.nn.LeakyReLU(0.2)]
        self.net = torch.nn.Sequential(*modules)
        self.normalization = mapping_normalization

    def forward(self, x):
        """Optionally normalize ``x``, then run it through the MLP."""
        inputs = normalize_2nd_moment(x) if self.normalization else x
        return self.net(inputs)
| |
|
| |
|
class DecodingNetwork(torch.nn.Module):
    """MLP decoder applied to L2-normalized inputs.

    The network is ``num_layers - 1`` hidden pairs of
    ``Linear(in_features, in_features)`` and ``ReLU``, followed by a final
    ``Linear(in_features, out_dim)``.

    Args:
        in_features: Input size, also used as the hidden size.
        out_dim: Output size.
        num_layers: Total number of linear layers, including the output layer.
    """

    def __init__(self,
                 in_features,
                 out_dim,
                 num_layers=8,
                 ):
        super().__init__()
        hidden = [
            module
            for _ in range(num_layers - 1)
            for module in (torch.nn.Linear(in_features, in_features), torch.nn.ReLU())
        ]
        head = torch.nn.Linear(in_features, out_dim)
        self.net = torch.nn.Sequential(*hidden, head)

    def forward(self, x):
        """Project each row of ``x`` to unit L2 norm (dim 1), then decode it."""
        unit = torch.nn.functional.normalize(x, dim=1)
        return self.net(unit)
| |
|
| |
|
class FullyConnectedLayer(torch.nn.Module):
    """Fully connected layer with runtime weight/bias scaling and an activation.

    The weight is stored divided by ``lr_multiplier`` and multiplied at
    runtime by ``lr_multiplier / sqrt(in_features)``. The bias is multiplied
    at runtime by ``lr_multiplier``.

    Args:
        in_features: Input feature size.
        out_features: Output feature size.
        bias: Whether to create a bias parameter.
        activation: Activation name. ``'linear'`` with a bias uses a single
            ``torch.addmm``. Every other case goes through ``bias_act.bias_act``.
        lr_multiplier: Learning-rate multiplier folded into the parameter scaling.
        bias_init: Initial constant value of the bias.
    """

    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 activation='linear',
                 lr_multiplier=1,
                 bias_init=0,
                 ):
        super().__init__()
        self.activation = activation
        init_weight = torch.randn([out_features, in_features]) / lr_multiplier
        self.weight = torch.nn.Parameter(init_weight)
        if bias:
            self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init)))
        else:
            self.bias = None
        self.weight_gain = lr_multiplier / np.sqrt(in_features)
        self.bias_gain = lr_multiplier

    def forward(self, x):
        """Compute ``act(x @ (W * weight_gain).T + b * bias_gain)``."""
        weight = self.weight.to(x.dtype) * self.weight_gain
        bias = None if self.bias is None else self.bias.to(x.dtype)
        if bias is not None and self.bias_gain != 1:
            bias = bias * self.bias_gain

        if bias is not None and self.activation == 'linear':
            # Fused path: bias + x @ weight.T in a single kernel call.
            return torch.addmm(bias.unsqueeze(0), x, weight.t())
        out = x.matmul(weight.t())
        return bias_act.bias_act(out, bias, act=self.activation)
| |
|
| |
|
class MappingNetwork(torch.nn.Module):
    """MLP that maps latent codes ``z`` to ``w``, with optional broadcast/truncation.

    Args:
        z_dim: Size of the input latent; 0 means no latent input.
        c_dim: Size of the conditioning label. Conditioning is not supported:
            ``forward`` raises ``ValueError`` when ``c_dim > 0``.
        w_dim: Output size.
        num_ws: If not None, the output is repeated to ``[batch, num_ws, w_dim]``.
            Otherwise the output is ``[batch, w_dim]``.
        num_layers: Number of ``FullyConnectedLayer`` layers.
        embed_features: Label-embedding size. Defaults to ``w_dim`` and is
            forced to 0 when ``c_dim == 0``.
        layer_features: Hidden size. Defaults to ``w_dim``.
        activation: Activation name passed to every layer.
        lr_multiplier: Learning-rate multiplier passed to every layer.
        w_avg_beta: Weight on the previous running average of ``w``
            (``w_avg``) when updating it. None disables tracking.
        normalization: If truthy, ``z`` is rescaled with
            ``normalize_2nd_moment`` before the first layer.
    """

    def __init__(self,
                 z_dim,
                 c_dim,
                 w_dim,
                 num_ws,
                 num_layers=8,
                 embed_features=None,
                 layer_features=None,
                 activation='lrelu',
                 lr_multiplier=0.01,
                 w_avg_beta=0.995,
                 normalization=None
                 ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.num_ws = num_ws
        self.num_layers = num_layers
        self.w_avg_beta = w_avg_beta
        self.normalization = normalization

        if embed_features is None:
            embed_features = w_dim
        if c_dim == 0:
            embed_features = 0
        if layer_features is None:
            layer_features = w_dim
        # Per-layer widths: input (latent + embedding), hidden layers, then w_dim.
        features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]

        if c_dim > 0:
            # Kept so the parameter layout / state-dict keys are unchanged, even
            # though forward() rejects c_dim > 0.
            self.embed = FullyConnectedLayer(c_dim, embed_features)
        for idx in range(num_layers):
            in_features = features_list[idx]
            out_features = features_list[idx + 1]
            layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)
            setattr(self, f'fc{idx}', layer)

        if num_ws is not None and w_avg_beta is not None:
            self.register_buffer('w_avg', torch.zeros([w_dim]))

    def forward(self, z, c=None, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False):
        """Map ``z`` to ``w``.

        Args:
            z: Latent tensor; its shape is checked against ``[None, z_dim]``.
            c: Unused. Conditioning labels are not supported.
            truncation_psi: Interpolation factor from ``w_avg`` towards ``w``.
                1 disables truncation.
            truncation_cutoff: If given (and ``num_ws`` is not None), only the
                first ``truncation_cutoff`` entries along the ``num_ws`` axis
                are truncated.
            skip_w_avg_update: If true, ``w_avg`` is not updated even in
                training mode.

        Returns:
            ``[batch, w_dim]`` if ``num_ws`` is None, else ``[batch, num_ws, w_dim]``.

        Raises:
            ValueError: If the network was built with ``c_dim > 0``.
        """
        x = None
        with torch.autograd.profiler.record_function('input'):
            if self.z_dim > 0:
                misc.assert_shape(z, [None, self.z_dim])
                # Cast first, then normalize. Previously an unconditional cast
                # after the if/else could discard the normalized value.
                x = z.to(torch.float32)
                if self.normalization:
                    x = normalize_2nd_moment(x)
            if self.c_dim > 0:
                raise ValueError("This implementation does not need class index")

        # Main MLP.
        for idx in range(self.num_layers):
            x = getattr(self, f'fc{idx}')(x)

        # Update the moving average of W. The num_ws check matches the
        # condition under which the w_avg buffer is registered in __init__.
        # Without it, training with num_ws=None raised AttributeError.
        if (self.w_avg_beta is not None and self.num_ws is not None
                and self.training and not skip_w_avg_update):
            with torch.autograd.profiler.record_function('update_w_avg'):
                self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))

        # Broadcast one w per synthesis input.
        if self.num_ws is not None:
            with torch.autograd.profiler.record_function('broadcast'):
                x = x.unsqueeze(1).repeat([1, self.num_ws, 1])

        # Truncation: interpolate from w_avg towards x by truncation_psi.
        if truncation_psi != 1:
            with torch.autograd.profiler.record_function('truncate'):
                assert self.w_avg_beta is not None
                if self.num_ws is None or truncation_cutoff is None:
                    x = self.w_avg.lerp(x, truncation_psi)
                else:
                    x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
        return x
| |
|