| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from einops import rearrange, reduce |
| |
|
| | _EPS = 1e-8 |
| |
|
| |
|
| | class DifferentiableEntropyFunction(torch.autograd.Function): |
| | @staticmethod |
| | def forward(ctx, zq, basis, K, eps): |
| | zb = (zq + 1) / 2 |
| | zi = ((zb * basis).sum(-1)).to(torch.int64) |
| | cnt = torch.scatter_reduce( |
| | torch.zeros(2**K, device=zq.device, dtype=zq.dtype), |
| | 0, |
| | zi.flatten(), |
| | torch.ones_like(zi.flatten()).to(zq.dtype), |
| | "sum", |
| | ) |
| | prob = (cnt + eps) / (cnt + eps).sum() |
| | H = torch.special.entr(prob).sum() |
| | ctx.save_for_backward(zq, zi, prob) |
| | ctx.K = K |
| | return H |
| |
|
| | @staticmethod |
| | def backward(ctx, grad_output): |
| | zq, zi, prob = ctx.saved_tensors |
| | grad_array = -grad_output * (torch.log(prob) + 1) / zi.numel() / ctx.K |
| | reord_grad = grad_array[zi.flatten()].reshape(zi.shape) |
| | grad_input = reord_grad.unsqueeze(-1) * zq |
| | return grad_input, None, None, None, None |
| |
|
| |
|
| | def codebook_entropy(zq, basis, K, eps=1e-8): |
| | return DifferentiableEntropyFunction.apply(zq, basis, K, eps) |
| |
|
| |
|
| | class BinarySphericalQuantizer(nn.Module): |
| | def __init__( |
| | self, |
| | embed_dim: int = 18, |
| | group_size: int = 9, |
| | soft_entropy: bool = True, |
| | beta: float = 0.0, |
| | gamma_0: float = 1.0, |
| | gamma_1: float = 1.0, |
| | input_format: str = "bchw", |
| | persample_entropy_compute: str = "group", |
| | l2_norm: bool = True, |
| | inv_temperature: float = 100.0, |
| | ): |
| | super().__init__() |
| | self.embed_dim = embed_dim |
| | self.group_size = group_size |
| | assert embed_dim % group_size == 0, "embed_dim must be divisible by group_size" |
| | self.soft_entropy = soft_entropy |
| | self.beta = beta |
| | self.gamma_0 = gamma_0 |
| | self.gamma_1 = gamma_1 |
| | assert input_format in ["bchw", "blc"] |
| | self.input_format = input_format |
| | assert persample_entropy_compute in [ |
| | "group", |
| | "analytical", |
| | ], "persample_entropy_compute must be either 'group' or 'analytical'" |
| | self.persample_entropy_compute = persample_entropy_compute |
| | self.l2_norm = l2_norm |
| | self.inv_temperature = inv_temperature |
| |
|
| | self.register_buffer("basis", 2 ** torch.arange(embed_dim - 1, -1, -1), persistent=False) |
| | self.register_buffer( |
| | "group_basis", 2 ** torch.arange(group_size - 1, -1, -1), persistent=False |
| | ) |
| |
|
| | group_codes = torch.arange(2**self.group_size) |
| | group_codebook = self.indexes_to_codes(group_codes).float()[:, -group_size:] |
| | self.register_buffer("group_codebook", group_codebook, persistent=False) |
| |
|
| | def quantize(self, z): |
| | assert ( |
| | z.shape[-1] == self.embed_dim |
| | ), f"Expected {self.embed_dim} dimensions, got {z.shape[-1]}" |
| | zhat = torch.where(z > 0, torch.ones_like(z), -torch.ones_like(z)) |
| | return z + (zhat - z).detach() |
| |
|
| | def forward(self, z): |
| | if self.input_format == "bchw": |
| | z = rearrange(z, "b c h w -> b h w c") |
| | zq = self.quantize(z) |
| |
|
| | indices = self.codes_to_indexes(zq.detach()) |
| | group_indices = self.codes_to_group_indexes(zq.detach()) |
| |
|
| | if not self.training: |
| | used_codes = torch.unique(indices, return_counts=False) |
| | else: |
| | used_codes = None |
| |
|
| | if self.soft_entropy: |
| | persample_entropy, cb_entropy = self.soft_entropy_loss(z) |
| | else: |
| | persample_entropy, cb_entropy = self.hard_entropy_loss(z) |
| | entropy_penalty = self.gamma_0 * persample_entropy - self.gamma_1 * cb_entropy |
| |
|
| | q_scale = 1.0 / (self.embed_dim**0.5) if self.l2_norm else 1.0 |
| | zq = zq * q_scale |
| | commit_loss = self.beta * torch.mean(((zq.detach() - z) ** 2).sum(dim=-1)) |
| |
|
| | if self.input_format == "bchw": |
| | zq = rearrange(zq, "b h w c -> b c h w") |
| |
|
| | return ( |
| | zq, |
| | commit_loss + entropy_penalty / self.inv_temperature, |
| | { |
| | "H": cb_entropy, |
| | "used_codes": used_codes, |
| | "indices": indices, |
| | "group_indices": group_indices, |
| | }, |
| | ) |
| |
|
| | def soft_entropy_loss(self, z): |
| | group_codebook = self.group_codebook / (self.embed_dim**0.5 if self.l2_norm else 1) |
| | divided_z = rearrange(z, "... (g c) -> ... g c", c=self.group_size) |
| |
|
| | if self.persample_entropy_compute == "group": |
| | distance = -2 * torch.einsum("... g c, d c -> ... g d", divided_z, group_codebook) |
| | prob = (-distance * self.inv_temperature).softmax(dim=-1) |
| | persample_entropy = torch.special.entr(prob + _EPS).sum((-1, -2)).mean() |
| | else: |
| | p = torch.sigmoid( |
| | -4 * z / (self.embed_dim**0.5 if self.l2_norm else 1) * self.inv_temperature |
| | ) |
| | prob = torch.stack([p, 1 - p], dim=-1) |
| | persample_entropy = torch.special.entr(prob + _EPS).sum((-1, -2)).mean() |
| |
|
| | |
| | avg_prob = reduce(prob, "... g d -> g d", "mean") |
| | cb_entropy = torch.special.entr(avg_prob + _EPS).sum() |
| |
|
| | return persample_entropy, cb_entropy |
| |
|
| | def hard_entropy_loss(self, z): |
| | zb = ((z + 1) / 2).reshape(z.shape[0], -1, z.shape[-1]).to(torch.float32) |
| | prob_per_dim = zb.sum(1) / zb.shape[1] |
| | prob = torch.stack([prob_per_dim, 1 - prob_per_dim], dim=-1) |
| | persample_entropy = torch.special.entr(prob + _EPS).sum((-1, -2)).mean() |
| | cb_entropy = codebook_entropy(z, self.basis, self.embed_dim) |
| |
|
| | return persample_entropy, cb_entropy |
| |
|
| | def codes_to_indexes(self, zhat): |
| | """Converts a `code` to an index in the codebook. |
| | Args: |
| | zhat: A tensor of shape (B, ..., C) containing the codes. must be in {-1, 1} |
| | """ |
| | assert ( |
| | zhat.shape[-1] == self.embed_dim |
| | ), f"Expected {self.embed_dim} dimensions, got {zhat.shape[-1]}" |
| | return ((zhat.int() + 1) / 2 * self.basis).sum(axis=-1).to(torch.int64) |
| |
|
| | def codes_to_group_indexes(self, zhat): |
| | """Converts a `code` to a list of indexes (in groups) in the codebook. |
| | Args: |
| | zhat: A tensor of shape (B, ..., C) containing the codes. must be in {-1, 1} |
| | """ |
| | zhat_in_group = rearrange(zhat, "b ... (g c) -> b ... g c", c=self.group_size) |
| | return ((zhat_in_group.int() + 1) / 2 * self.group_basis).sum(axis=-1).to(torch.int64) |
| |
|
| | def indexes_to_codes(self, indices): |
| | """Inverse of `codes_to_indexes`.""" |
| | indices = indices.unsqueeze(-1) |
| | codes_non_centered = torch.remainder(torch.floor_divide(indices, self.basis), 2) |
| | return codes_non_centered * 2 - 1 |
| |
|
| | def group_indexes_to_codes(self, group_indices): |
| | """Inverse of `codes_to_group_indexes`.""" |
| | group_indices = group_indices.unsqueeze(-1) |
| | codes_non_centered = torch.remainder(torch.floor_divide(group_indices, self.group_basis), 2) |
| | codes_non_centered = rearrange(codes_non_centered, "b ... g c -> b ... (g c)") |
| | return codes_non_centered * 2 - 1 |
| |
|
| | def get_group_codebook_entry(self, group_indices, one_hot=False): |
| | """ |
| | Args: |
| | group_indices: A tensor of shape (B, L, G, C) containing the group indices. |
| | """ |
| | if one_hot: |
| | z_q = group_indices @ self.group_codebook |
| | else: |
| | z_q = self.group_indexes_to_codes(group_indices) |
| | q_scale = 1.0 / (self.embed_dim**0.5) if self.l2_norm else 1.0 |
| | z_q = z_q * q_scale |
| | if self.input_format == "bchw": |
| | h, w = int(z_q.shape[1] ** 0.5) |
| | assert h * w == z_q.shape[1], "Invalid sequence length" |
| | z_q = rearrange(z_q, "b (h w) c -> b c h w", h=h) |
| | return z_q |
| |
|
| | def get_codebook_entry(self, indices, one_hot=False): |
| | """ |
| | Args: |
| | group_indices: A tensor of shape (B, L, C) containing the indices. |
| | """ |
| | if one_hot: |
| | assert self.embed_dim == self.group_size, "one_hot is only supported for group_size == embed_dim" |
| | z_q = indices @ self.group_codebook |
| | else: |
| | z_q = self.indexes_to_codes(indices) |
| | q_scale = 1.0 / (self.embed_dim**0.5) if self.l2_norm else 1.0 |
| | z_q = z_q * q_scale |
| | if self.input_format == "bchw": |
| | h, w = int(z_q.shape[1] ** 0.5) |
| | assert h * w == z_q.shape[1], "Invalid sequence length" |
| | z_q = rearrange(z_q, "b (h w) c -> b c h w", h=h) |
| | return z_q |
| |
|