import torch
from torch import nn

from vit.vision_transformer import Mlp, DropPath

class ResnetBlockFC(nn.Module):
    """
    Fully connected ResNet block (pre-activation residual MLP).
    Taken from the DVR (Differentiable Volumetric Rendering) code.

    :param size_in (int): input dimension
    :param size_out (int): output dimension (defaults to size_in)
    :param size_h (int): hidden dimension (defaults to min(size_in, size_out))
    :param beta (float): if > 0, use Softplus(beta=beta) instead of ReLU
    :param init_as_zero (bool): if True, zero-initialize fc_0 as well as fc_1
    """

    def __init__(self, size_in, size_out=None, size_h=None, beta=0.0, init_as_zero=False):
        super().__init__()

        if size_out is None:
            size_out = size_in
        if size_h is None:
            size_h = min(size_in, size_out)

        self.size_in = size_in
        self.size_h = size_h
        self.size_out = size_out

        self.fc_0 = nn.Linear(size_in, size_h)
        self.fc_1 = nn.Linear(size_h, size_out)

        # fc_1 is always zero-initialized, so the residual branch outputs zero
        # at initialization and a fresh block reduces to its shortcut path.
        nn.init.constant_(self.fc_0.bias, 0.0)
        if init_as_zero:
            nn.init.zeros_(self.fc_0.weight)
        else:
            nn.init.kaiming_normal_(self.fc_0.weight, a=0, mode="fan_in")
        nn.init.constant_(self.fc_1.bias, 0.0)
        nn.init.zeros_(self.fc_1.weight)

        # Softplus with a large beta is a smooth stand-in for ReLU.
        if beta > 0:
            self.activation = nn.Softplus(beta=beta)
        else:
            self.activation = nn.ReLU()

        # Project the shortcut only when input and output widths differ.
        if size_in == size_out:
            self.shortcut = None
        else:
            self.shortcut = nn.Linear(size_in, size_out, bias=False)
            nn.init.kaiming_normal_(self.shortcut.weight, a=0, mode="fan_in")

    def forward(self, x):
        # Pre-activation ordering: activation -> linear, twice.
        net = self.fc_0(self.activation(x))
        dx = self.fc_1(self.activation(net))

        if self.shortcut is not None:
            x_s = self.shortcut(x)
        else:
            x_s = x
        return x_s + dx
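
# Usage sketch (illustrative; the sizes below are arbitrary example values).
# Because fc_1 is zero-initialized, a freshly constructed block computes only
# the shortcut projection:
#
#   block = ResnetBlockFC(size_in=128, size_out=64, beta=100.0)
#   y = block(torch.randn(4, 128))  # -> shape (4, 64)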

class ResnetBlockFCViT(ResnetBlockFC):
    """
    Identical to ResnetBlockFC above (same layers, initialization, and forward
    pass); kept under a separate class name.
    """
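
# Equivalence sketch (illustrative): the two classes build identical parameter
# sets, so their state dicts are interchangeable:
#
#   a = ResnetBlockFC(64, 32)
#   b = ResnetBlockFCViT(64, 32)
#   b.load_state_dict(a.state_dict())  # keys and shapes match one-to-one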

class ResMlp(nn.Module):
    """
    Residual MLP block in the transformer style: LayerNorm -> Mlp -> DropPath,
    added to a (possibly projected and separately normalized) shortcut.

    :param size_in (int): input dimension
    :param size_out (int): output dimension (defaults to size_in)
    :param size_h (int): hidden dimension (defaults to min(size_in, size_out))
    """

    def __init__(
        self,
        size_in,
        size_out=None,
        size_h=None,
        drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()

        if size_out is None:
            size_out = size_in
        if size_h is None:
            size_h = min(size_in, size_out)
        self.size_in = size_in
        self.size_h = size_h
        self.size_out = size_out

        self.norm1 = norm_layer(size_in)

        # size_h sets the hidden width of the Mlp.
        self.mlp = Mlp(
            in_features=size_in,
            hidden_features=size_h,
            out_features=size_out,
            act_layer=act_layer,
            drop=drop,
        )

        # Project (and separately normalize) the shortcut when widths differ.
        if size_in == size_out:
            self.shortcut = None
        else:
            self.shortcut = nn.Linear(size_in, size_out, bias=False)
            self.norm2 = norm_layer(size_in)

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x):
        # Residual branch: normalize, MLP, stochastic depth.
        dx = self.mlp(self.norm1(x))

        if self.shortcut is not None:
            x_s = self.shortcut(self.norm2(x))
        else:
            x_s = x

        return x_s + self.drop_path(dx)
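
# Minimal smoke test (illustrative sketch; assumes vit.vision_transformer is
# importable). Run this file directly to check output shapes end to end.
if __name__ == "__main__":
    x = torch.randn(2, 10, 96)

    for Block in (ResnetBlockFC, ResnetBlockFCViT):
        out = Block(96, 48)(x)
        assert out.shape == (2, 10, 48)

    out = ResMlp(96, size_out=48, drop_path=0.1)(x)
    assert out.shape == (2, 10, 48)
    print("all blocks OK")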