| | from dataclasses import dataclass |
| |
|
| | import numpy as np |
| | import timm |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from einops import rearrange, repeat |
| | from segmentation_models_pytorch.base import SegmentationHead |
| | from segmentation_models_pytorch.decoders.unet.decoder import UnetDecoder |
| | from timm.layers.create_act import create_act_layer |
| | from transformers import PretrainedConfig, PreTrainedModel |
| | from transformers.modeling_outputs import SemanticSegmenterOutput |
| |
|
| | from .convlstm import ConvLSTM |
| |
|
| |
|
class ACTUConfig(PretrainedConfig):
    """Configuration for ``ACTUForImageSegmentation``.

    ``in_channels`` is given as the number of channels of one image frame.
    When ``use_dem_input`` is True the stored ``in_channels`` attribute is
    incremented by one, because the model concatenates a DEM tensor onto every
    frame along the channel axis before encoding. ``to_dict`` reverses that
    increment, so a config written by ``save_pretrained`` and read back through
    ``__init__`` ends up with the same channel count instead of adding the DEM
    channel a second time.
    """

    model_type = "actu"

    def __init__(
        self,
        # Image / ConvLSTM / decoder settings.
        in_channels: int = 3,
        kernel_size: tuple[int, int] = (3, 3),
        padding="same",
        stride=(1, 1),
        backbone="resnet34",
        bias=True,
        batch_first=True,
        bidirectional=False,
        original_resolution=(256, 256),
        act_layer="sigmoid",
        n_classes=1,
        # Optional auxiliary inputs.
        use_dem_input: bool = False,
        use_climate_branch: bool = False,
        # Climate-branch LSTM settings (only used when use_climate_branch is True).
        climate_seq_len=5,
        climate_input_dim=6,
        lstm_hidden_dim=128,
        num_lstm_layers=1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.backbone = backbone
        self.bias = bias
        self.batch_first = batch_first
        self.bidirectional = bidirectional
        self.original_resolution = original_resolution
        self.act_layer = act_layer
        self.n_classes = n_classes

        self.use_dem_input = use_dem_input
        self.use_climate_branch = use_climate_branch
        self.climate_seq_len = climate_seq_len
        self.climate_input_dim = climate_input_dim
        self.lstm_hidden_dim = lstm_hidden_dim
        self.num_lstm_layers = num_lstm_layers

        # The DEM is concatenated as one extra input channel, so the encoder
        # must be built for in_channels + 1.
        if self.use_dem_input:
            self.in_channels += 1

    def to_dict(self):
        """Serialize the config, storing ``in_channels`` as passed to ``__init__``.

        ``__init__`` adds one channel when ``use_dem_input`` is set. Without
        undoing that here, a round trip through ``save_pretrained`` /
        ``from_pretrained`` would add the DEM channel again on every reload.
        """
        output = super().to_dict()
        if output.get("use_dem_input"):
            output["in_channels"] = output["in_channels"] - 1
        return output
| |
|
| |
|
class ACTUForImageSegmentation(PreTrainedModel):
    """Segment an image time series with a timm encoder, per-scale ConvLSTMs and a U-Net decoder.

    Pipeline (see ``forward``):
      1. optionally concatenate a DEM tensor onto every frame (channel axis);
      2. encode every frame independently with a ``features_only`` timm backbone;
      3. optionally fuse climate features into every feature level with ``GatedFusion``;
      4. run one ``ConvLSTM`` per feature level over the time axis, keeping the last step;
      5. decode with an smp ``UnetDecoder`` followed by ``SegmentationHead`` and
         ``config.act_layer``, then resize bilinearly to the input resolution.
    """

    config_class = ACTUConfig

    def __init__(self, config: ACTUConfig):
        super().__init__(config)
        self.config = config

        self.encoder: nn.Module = timm.create_model(
            config.backbone, features_only=True, in_chans=config.in_channels
        )

        # Push one dummy frame through the encoder to discover the shape of every
        # feature level; the ConvLSTMs, climate branch and decoder are sized from it.
        with torch.no_grad():
            dummy_input = torch.randn(
                1, config.in_channels, *config.original_resolution, device=self.device
            )
            embs = self.encoder(dummy_input)
            self.embs_shape = [e.shape for e in embs]
            self.encoder_channels = [e[1] for e in self.embs_shape]

        # One ConvLSTM per feature level; hidden size equals that level's channels.
        self.convlstm = nn.ModuleList(
            [
                ConvLSTM(
                    in_channels=shape[1],
                    hidden_channels=shape[1],
                    kernel_size=config.kernel_size,
                    padding=config.padding,
                    stride=config.stride,
                    bias=config.bias,
                    batch_first=config.batch_first,
                    bidirectional=config.bidirectional,
                )
                for shape in self.embs_shape
            ]
        )

        if self.config.use_climate_branch:
            self.climate_branch = ClimateBranchLSTM(
                output_shapes=[e[1:] for e in self.embs_shape],
                lstm_hidden_dim=config.lstm_hidden_dim,
                climate_seq_len=config.climate_seq_len,
                climate_input_dim=config.climate_input_dim,
                num_lstm_layers=config.num_lstm_layers,
            )
            self.fusers = nn.ModuleList(
                GatedFusion(enc, enc) for enc in self.encoder_channels
            )

        # The leading 1 is a placeholder stage, paired with the ``None`` that
        # ``_decode`` passes as the first feature.
        self.decoder = UnetDecoder(
            encoder_channels=[1] + self.encoder_channels,
            decoder_channels=self.encoder_channels[::-1],
            n_blocks=len(self.encoder_channels),
        )

        self.seg_head = nn.Sequential(
            SegmentationHead(
                in_channels=self.encoder_channels[0],
                out_channels=config.n_classes,
            ),
            create_act_layer(config.act_layer, inplace=True),
        )

    def forward(
        self,
        pixel_values: torch.Tensor,
        climate: torch.Tensor = None,
        dem: torch.Tensor = None,
        labels: torch.Tensor = None,
        **kwargs,
    ) -> SemanticSegmenterOutput:
        """Run the model on a batch of image sequences.

        Args:
            pixel_values: image sequences laid out as ``(b, t, c, h, w)``.
            climate: required when ``config.use_climate_branch``; passed to
                ``ClimateBranchLSTM``. Its first two dims are folded together with
                the image ``(b, t)`` dims, so presumably it is
                ``(b, t, climate_seq_len, climate_input_dim)`` — confirm with caller.
            dem: required when ``config.use_dem_input``; laid out as
                ``(b, c, h, w)`` and repeated over ``t`` before channel concat.
            labels: optional targets. For ``n_classes == 1`` values in ``[0, 1]``
                shaped like the output with or without the channel dim; otherwise
                integer class indices ``(b, h, w)``.
            **kwargs: accepted and ignored.

        Returns:
            ``SemanticSegmenterOutput`` whose ``logits`` are the *activated* head
            output (``config.act_layer`` already applied) at the input resolution,
            and ``loss`` when ``labels`` is given.

        Raises:
            ValueError: if a tensor required by the config is missing.
        """
        b, t = pixel_values.shape[:2]
        original_size = pixel_values.shape[-2:]

        if self.config.use_dem_input:
            if dem is None:
                raise ValueError(
                    "DEM tensor must be provided when use_dem_input is True."
                )
            dem_repeated = repeat(dem, "b c h w -> b t c h w", t=t)
            pixel_values = torch.cat([pixel_values, dem_repeated], dim=2)

        encoded_sequence = self._encode_images(pixel_values)

        if self.config.use_climate_branch:
            if climate is None:
                raise ValueError(
                    "Climate tensor must be provided when use_climate_branch is True."
                )
            encoded_sequence = self._fuse_climate(encoded_sequence, climate, b)

        temporal_features = self._encode_timeseries(encoded_sequence)
        logits = self._decode(temporal_features, size=original_size)

        loss = self._compute_loss(logits, labels) if labels is not None else None

        return SemanticSegmenterOutput(
            loss=loss,
            logits=logits,
        )

    def _fuse_climate(
        self,
        encoded_sequence: list[torch.Tensor],
        climate: torch.Tensor,
        batch_size: int,
    ) -> list[torch.Tensor]:
        """Gate-fuse climate feature maps into each ``(b, t, c, h, w)`` feature level."""
        climate_features = self.climate_branch(climate)
        fused = []
        for fuser, img, clim in zip(self.fusers, encoded_sequence, climate_features):
            # GatedFusion works on 4-D maps, so fold time into the batch and back.
            out = fuser(
                rearrange(img, "b t c h w -> (b t) c h w"),
                rearrange(clim, "b t c h w -> (b t) c h w"),
            )
            fused.append(rearrange(out, "(b t) c h w -> b t c h w", b=batch_size))
        return fused

    def _compute_loss(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Compute the training loss on the activated head output.

        ``logits`` have already gone through ``config.act_layer`` (sigmoid by
        default). With one output channel, ``CrossEntropyLoss`` is identically
        zero (softmax over one channel is always 1), so binary cross-entropy is
        used instead. It is computed manually in float32 because
        ``F.binary_cross_entropy`` refuses to run inside an autocast region.
        This branch assumes outputs in ``[0, 1]``, i.e. a sigmoid-like act_layer.
        """
        if self.config.n_classes == 1:
            target = labels.float()
            if target.ndim == logits.ndim - 1:
                target = target.unsqueeze(1)
            eps = 1e-7
            probs = logits.float().clamp(eps, 1.0 - eps)
            return -(
                target * probs.log() + (1.0 - target) * (1.0 - probs).log()
            ).mean()

        # Multi-class: integer class-index targets of shape (b, h, w).
        # NOTE(review): logits are post-activation here; cross_entropy normally
        # expects raw scores — consider act_layer="identity" for this case.
        target = labels.long()
        if target.ndim == logits.ndim:
            target = target.squeeze(1)
        return F.cross_entropy(logits, target)

    def _encode_images(self, x: torch.Tensor) -> list[torch.Tensor]:
        """Encode every frame of ``x`` (``b t c h w``); return one ``b t c h w`` tensor per level."""
        B = x.size(0)
        encoded_frames = self.encoder(rearrange(x, "b t c h w -> (b t) c h w"))
        return [
            rearrange(frames, "(b t) c h w -> b t c h w", b=B)
            for frames in encoded_frames
        ]

    def _encode_timeseries(self, timeseries: list[torch.Tensor]) -> list[torch.Tensor]:
        """Run each level's ConvLSTM over time and keep the last time step.

        Returns the per-level outputs in reverse level order (last encoder
        level first).
        """
        outs = []
        for convlstm, encoded in reversed(list(zip(self.convlstm, timeseries))):
            lstm_out, (_, _) = convlstm(encoded)
            outs.append(lstm_out[:, -1, :, :, :])
        return outs

    def _decode(self, x: list[torch.Tensor], size: tuple[int, int]) -> torch.Tensor:
        """Decode per-level features (as ordered by ``_encode_timeseries``) to ``size``."""
        # x[::-1] restores encoder order; the leading None fills the placeholder
        # stage declared as encoder_channels[0] == 1 (presumably dropped by the
        # smp UnetDecoder — verify against the installed version).
        trend_map = self.decoder(*[None] + x[::-1])
        trend_map = self.seg_head(trend_map)
        trend_map = F.interpolate(
            trend_map, size=size, mode="bilinear", align_corners=False
        )
        return trend_map
| |
|
| |
|
class ClimateBranchLSTM(nn.Module):
    """Turn climate time series into one spatial feature map per encoder level.

    ``forward`` takes a 4-D tensor read as ``(B_img, B_cli, T, C)``. It folds
    the first two dims together and runs each length-``T`` sequence of
    ``C = climate_input_dim`` features through a unidirectional LSTM. The top
    layer's final hidden state is projected to ``proj_dim`` channels, treated
    as a 1x1 map, and upsampled by one upsampler per entry of
    ``output_shapes``.

    Returns a list with one tensor per ``output_shapes`` entry ``(C_i, H_i, W_i)``,
    shaped ``(B_img, B_cli, C_i, H_i, W_i)``.
    """

    def __init__(
        self,
        output_shapes: list[tuple[int, int, int]],
        climate_input_dim=5,
        climate_seq_len=6,
        lstm_hidden_dim=64,
        num_lstm_layers=1,
        proj_dim=128,
    ):
        super().__init__()
        self.climate_seq_len = climate_seq_len
        self.climate_input_dim = climate_input_dim
        self.lstm_hidden_dim = lstm_hidden_dim
        self.num_lstm_layers = num_lstm_layers
        self.proj_dim = proj_dim
        self.output_shapes = output_shapes

        self.lstm = nn.LSTM(
            input_size=climate_input_dim,
            hidden_size=lstm_hidden_dim,
            num_layers=num_lstm_layers,
            batch_first=True,
            # nn.LSTM only applies dropout between stacked layers.
            dropout=0.3 if num_lstm_layers > 1 else 0,
            bidirectional=False,
        )

        # Project the final hidden state to a fixed channel count shared by all upsamplers.
        self.fc = nn.Linear(lstm_hidden_dim, self.proj_dim)

        # One upsampler per level, built from that level's (channels, height).
        self.upsamples = nn.ModuleList(
            _build_upsampler(self.proj_dim, *shape[:2]) for shape in output_shapes
        )

    def forward(self, climate_data: torch.Tensor) -> list[torch.Tensor]:
        b_img, b_cli, _, _ = climate_data.shape

        lstm_input = rearrange(climate_data, "Bi Bc T C -> (Bi Bc) T C")

        # Call the module (not .forward) so registered hooks run.
        _, (hidden, _) = self.lstm(lstm_input)
        # The LSTM is unidirectional, so hidden is (num_layers, N, hidden);
        # keep the top layer's final state.
        last_hidden = hidden[-1]

        climate_map = rearrange(self.fc(last_hidden), "b c -> b c 1 1")

        outputs = []
        for upsample, shape in zip(self.upsamples, self.output_shapes):
            feat = upsample(climate_map)
            # The upsampler only tracks height; snap to the exact (H, W) of the
            # target level so non-square or non-power-of-two maps still line up
            # with the image features.
            target_hw = tuple(int(s) for s in shape[-2:])
            if tuple(feat.shape[-2:]) != target_hw:
                feat = F.interpolate(feat, size=target_hw, mode="nearest")
            outputs.append(
                rearrange(feat, "(Bi Bc) C H W -> Bi Bc C H W", Bi=b_img, Bc=b_cli)
            )
        return outputs
| |
|
| |
|
class GatedFusion(nn.Module):
    """Blend image and climate feature maps with a learned elementwise gate.

    The gate ``g`` (values in (0, 1) from a final sigmoid) is predicted from the
    channel-wise concatenation of both inputs. The output is
    ``g * img_feat + (1 - g) * clim_feat``. ``g`` has ``img_channels`` channels,
    so ``clim_feat`` must be broadcast-compatible with ``img_feat``.
    """

    def __init__(self, img_channels, clim_channels):
        super().__init__()
        concat_channels = img_channels + clim_channels
        gate_layers = nn.Sequential(
            nn.Conv2d(concat_channels, img_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(img_channels, img_channels, kernel_size=1),
            nn.Sigmoid(),
        )
        # The outer single-element Sequential determines the state_dict key
        # names ("gate.0.<i>.*"); keep it so existing checkpoints still load.
        self.gate = nn.Sequential(gate_layers)

    def forward(self, img_feat, clim_feat):
        joint = torch.cat((img_feat, clim_feat), dim=1)
        weight = self.gate(joint)
        return weight * img_feat + (1 - weight) * clim_feat
| |
|
| |
|
def _build_upsampler(
    in_channels: int,
    target_channels: int,
    target_h: int,
    target_w: "int | None" = None,
) -> nn.Sequential:
    """Build a network that grows a 1x1 map to ``(target_channels, target_h, target_w)``.

    A 1x1 conv + GELU first maps ``in_channels`` to ``target_channels``. Then
    each stage roughly doubles the spatial size (nearest upsample, 3x3 conv,
    GELU) until ``target_h`` is reached. The last stage resizes to the exact
    target instead of overshooting: e.g. ``target_h=75`` yields 75, not 128.
    The number of stages depends only on ``target_h``, so parameter names match
    earlier square, power-of-two builds.

    Args:
        in_channels: channels of the 1x1 input map.
        target_channels: channels of the output map.
        target_h: output height.
        target_w: output width; defaults to ``target_h`` (square output).
    """
    if target_w is None:
        target_w = target_h

    layers = [nn.Conv2d(in_channels, target_channels, kernel_size=1), nn.GELU()]

    cur_h, cur_w = 1, 1
    while cur_h < target_h:
        next_h = min(cur_h * 2, target_h)
        # On the final stage jump straight to the requested width.
        next_w = target_w if next_h == target_h else min(cur_w * 2, target_w)
        layers += [
            nn.Upsample(size=(next_h, next_w), mode="nearest"),
            nn.Conv2d(target_channels, target_channels, kernel_size=3, padding=1),
            nn.GELU(),
        ]
        cur_h, cur_w = next_h, next_w

    if cur_w != target_w:
        # Only reachable when target_h == 1 (no stage ran): widen with a
        # parameter-free resize so no new weights are introduced.
        layers.append(nn.Upsample(size=(cur_h, target_w), mode="nearest"))

    return nn.Sequential(*layers)
| |
|