| """
|
| model.py — Liquid Chess Model (LCM) architecture.
|
|
|
| Hybrid transformer with 6 GQA attention blocks and 10 LIV convolution blocks,
|
| distributed evenly via a Bresenham-style integer accumulator. Trained with dual NTP + TOP objectives.
|
|
|
| Architecture highlights:
|
| - GQA (Grouped Query Attention) with RoPE positional embeddings
|
| - LIV (Local Input-dependent Value) causal convolution blocks
|
| - LRM (Learnable Rate Multipliers) on every block
|
| - Weight tying between embedding and NTP head
|
| - PyTorch SDPA for efficient attention
|
| """
|
|
|
| import math
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
|
|
| from config import ChessModelConfig
|
|
|
|
|
|
|
|
|
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization.

    Divides each vector along the last dimension by its root mean square
    (with ``eps`` added to the mean square for numerical stability), then
    multiplies by a learnable per-feature gain initialised to one. Unlike
    LayerNorm, no mean is subtracted and there is no bias.

    Args:
        d_model: Size of the last (feature) dimension.
        eps: Added to the mean square before taking the square root.
    """

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        # Learnable gain; starts as the identity scaling.
        self.weight = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        root_mean_square = torch.sqrt(mean_square + self.eps)
        return x / root_mean_square * self.weight
|
|
|
|
|
|
|
|
|
|
|
|
|
class LIVBlock(nn.Module):
    """
    Local Input-dependent Value (LIV) convolution block.

    A pre-norm residual block that mixes information over a short causal
    window: each output position depends only on its own input and the
    ``config.conv_kernel_size - 1`` positions before it. Two input-dependent
    gates surround a depthwise convolution.

    Data flow (``d = config.d_model``):
        input -> RMSNorm -> Linear(d, 3d) -> split into (b, c, v)
              -> b * v -> depthwise causal Conv1d -> c * (...)
              -> Linear(d, d) -> dropout -> optional LRM scale -> + input

    Args:
        config: Reads ``d_model``, ``conv_kernel_size``, ``dropout`` and
            ``use_lrm``.
    """

    def __init__(self, config: ChessModelConfig):
        super().__init__()
        width = config.d_model
        kernel = config.conv_kernel_size

        # Keep this construction order: it fixes parameter registration order.
        self.norm = RMSNorm(width)
        self.input_proj = nn.Linear(width, 3 * width, bias=False)
        # Depthwise conv (groups == channels). It pads kernel - 1 on both
        # sides; forward keeps only the first T outputs, which makes it causal.
        self.conv = nn.Conv1d(
            in_channels=width,
            out_channels=width,
            kernel_size=kernel,
            padding=kernel - 1,
            groups=width,
            bias=False,
        )
        self.output_proj = nn.Linear(width, width, bias=False)
        self.dropout = nn.Dropout(config.dropout)
        # Learnable Rate Multiplier: per-channel scale on the residual branch.
        if config.use_lrm:
            self.lrm = nn.Parameter(torch.ones(width))
        else:
            self.lrm = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        seq_len = x.shape[1]
        gate_b, gate_c, values = self.input_proj(self.norm(x)).chunk(3, dim=-1)

        # Conv1d wants (batch, channels, time). Dropping the trailing
        # kernel - 1 outputs means position t never sees inputs after t.
        mixed = self.conv((gate_b * values).transpose(1, 2))[:, :, :seq_len]
        branch = gate_c * mixed.transpose(1, 2)
        branch = self.dropout(self.output_proj(branch))

        if self.lrm is not None:
            branch = branch * self.lrm
        return x + branch
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_rope_cache(
    seq_len: int,
    head_dim: int,
    device: torch.device,
    base: float = 10000.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Precompute rotary-embedding (RoPE) cosine and sine tables.

    Pair ``i`` of a head's features rotates with frequency
    ``base ** (-2i / head_dim)``. At position ``p`` the rotation angle is
    ``p * freq_i``.

    Args:
        seq_len: Number of positions to precompute (0 .. seq_len - 1).
        head_dim: Per-head feature size. It must be even because features
            are rotated in pairs.
        device: Device on which to allocate the tables.
        base: Frequency base. Defaults to 10000, the previously hard-coded
            value.

    Returns:
        ``(cos, sin)``, each float32 with shape ``(seq_len, head_dim // 2)``.

    Raises:
        ValueError: If ``head_dim`` is odd. Such tables would not match the
            even/odd feature split used when applying RoPE.
    """
    if head_dim % 2 != 0:
        raise ValueError(f"RoPE requires an even head_dim, got {head_dim}")
    exponents = torch.arange(0, head_dim, 2, device=device).float() / head_dim
    inv_freq = 1.0 / (base ** exponents)
    positions = torch.arange(seq_len, device=device).float()
    angles = torch.outer(positions, inv_freq)
    return torch.cos(angles), torch.sin(angles)
|
|
|
|
|
def apply_rope(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
    """Apply rotary position embeddings to a query or key tensor.

    Features are rotated in interleaved pairs ``(x[..., 2i], x[..., 2i + 1])``
    by the angle whose cosine and sine are ``cos[t, i]`` and ``sin[t, i]``.

    Args:
        x: Tensor whose last two dims are ``(seq_len, head_dim)``, e.g.
            ``(batch, heads, seq_len, head_dim)``. Any leading dims are
            supported.
        cos: Table of shape ``(seq_len, head_dim // 2)``.
        sin: Table of shape ``(seq_len, head_dim // 2)``.

    Returns:
        The rotated tensor, with the same shape and dtype as ``x``.
    """
    even, odd = x[..., ::2], x[..., 1::2]
    # The (seq_len, head_dim // 2) tables broadcast against the trailing dims
    # of even/odd, so no unsqueeze is needed for any rank of x.
    rotated_even = even * cos - odd * sin
    rotated_odd = even * sin + odd * cos
    # Re-interleave the pairs. Then cast back: float32 tables would otherwise
    # promote half/bfloat16 q/k to float32 and mismatch v's dtype.
    return torch.stack([rotated_even, rotated_odd], dim=-1).flatten(-2).to(x.dtype)
|
|
|
|
|
class SwiGLU(nn.Module):
    """SwiGLU feed-forward network: ``dropout(down(silu(gate(x)) * up(x)))``.

    Args:
        config: Reads ``d_model`` (input/output width), ``ffn_hidden_size``
            (hidden width) and ``dropout``.
    """

    def __init__(self, config: ChessModelConfig):
        super().__init__()
        model_dim = config.d_model
        hidden_dim = config.ffn_hidden_size
        self.gate_proj = nn.Linear(model_dim, hidden_dim, bias=False)
        self.up_proj = nn.Linear(model_dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, model_dim, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The SiLU-activated gate scales the linear "up" branch elementwise.
        gated = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.dropout(self.down_proj(gated))
|
|
|
|
|
class GQABlock(nn.Module):
    """
    Grouped Query Attention block with SwiGLU FFN and RoPE.

    Pre-norm residual layout:
        x = x + [LRM *] o_proj(causal_attention(RoPE(q), RoPE(k), v))
        x = x + [LRM *] SwiGLU(RMSNorm(x))
    Here q, k and v are projections of RMSNorm(x). There are ``n_kv_heads``
    key/value heads, and each is shared by ``n_heads // n_kv_heads``
    consecutive query heads. Attention uses PyTorch's
    ``scaled_dot_product_attention`` with a causal mask. During training,
    ``config.dropout`` is applied to the attention weights.

    Args:
        config: Reads ``d_model``, ``n_heads``, ``n_kv_heads``, ``head_dim``,
            ``dropout`` and ``use_lrm``. It is also passed to SwiGLU.

    Raises:
        ValueError: If ``n_kv_heads`` is not positive or does not divide
            ``n_heads``.
    """

    def __init__(self, config: ChessModelConfig):
        super().__init__()
        if config.n_kv_heads <= 0 or config.n_heads % config.n_kv_heads != 0:
            raise ValueError(
                f"n_heads ({config.n_heads}) must be a multiple of "
                f"n_kv_heads ({config.n_kv_heads})"
            )
        d = config.d_model
        self.n_heads = config.n_heads
        self.n_kv_heads = config.n_kv_heads
        self.head_dim = config.head_dim
        # Number of consecutive query heads that share one key/value head.
        self.repeats = config.n_heads // config.n_kv_heads

        self.attn_norm = RMSNorm(d)
        self.ffn_norm = RMSNorm(d)

        self.q_proj = nn.Linear(d, self.n_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(d, self.n_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(d, self.n_kv_heads * self.head_dim, bias=False)
        # Input is the concatenation of all query-head outputs. It equals
        # d_model only when n_heads * head_dim == d_model.
        self.o_proj = nn.Linear(self.n_heads * self.head_dim, d, bias=False)

        self.ffn = SwiGLU(config)
        # Only its probability is used, as the SDPA attention-dropout rate.
        self.dropout = nn.Dropout(config.dropout)

        self.attn_lrm = nn.Parameter(torch.ones(d)) if config.use_lrm else None
        self.ffn_lrm = nn.Parameter(torch.ones(d)) if config.use_lrm else None

    def forward(
        self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
    ) -> torch.Tensor:
        """Run attention and FFN sub-layers on ``x`` of shape (batch, seq_len, d_model).

        ``freqs_cos`` and ``freqs_sin`` are the RoPE tables for the first
        ``seq_len`` positions. Returns a tensor with the same shape as ``x``.
        """
        batch, seq_len, _ = x.shape

        # --- Attention sub-layer ---
        h = self.attn_norm(x)
        # (B, T, heads * D) -> (B, heads, T, D): the layout SDPA expects.
        q = self.q_proj(h).view(batch, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(h).view(batch, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(h).view(batch, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)

        q = apply_rope(q, freqs_cos, freqs_sin)
        k = apply_rope(k, freqs_cos, freqs_sin)

        # Expand the KV heads so that query head i reads KV head i // repeats.
        k = k.repeat_interleave(self.repeats, dim=1)
        v = v.repeat_interleave(self.repeats, dim=1)

        attn_out = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.dropout.p if self.training else 0.0,
            is_causal=True,
        )
        # (B, heads, T, D) -> (B, T, heads * D)
        attn_out = attn_out.transpose(1, 2).reshape(
            batch, seq_len, self.n_heads * self.head_dim
        )
        attn_out = self.o_proj(attn_out)
        if self.attn_lrm is not None:
            attn_out = attn_out * self.attn_lrm
        x = x + attn_out

        # --- Feed-forward sub-layer ---
        ffn_out = self.ffn(self.ffn_norm(x))
        if self.ffn_lrm is not None:
            ffn_out = ffn_out * self.ffn_lrm
        return x + ffn_out
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_layer_types(n_layers: int, n_gqa: int) -> list[str]:
    """
    Return the block type ("gqa" or "liv") for each of ``n_layers`` layers.

    GQA layers are spread evenly through the network with a Bresenham-style
    integer accumulator, which avoids floating-point rounding collisions.
    When ``0 < n_gqa < n_layers``, the first layer is always GQA. With at
    least two GQA layers, the last layer is GQA too. ``n_gqa == 0`` gives
    all LIV layers, and ``n_gqa >= n_layers`` gives all GQA layers.

    Example (16 layers, 6 GQA):
        GQA LIV LIV GQA LIV LIV GQA LIV LIV GQA LIV LIV GQA LIV LIV GQA

    Raises:
        ValueError: If ``n_layers`` or ``n_gqa`` is negative. A negative
            ``n_gqa`` previously produced a single GQA layer without warning.
    """
    if n_layers < 0 or n_gqa < 0:
        raise ValueError(
            f"n_layers and n_gqa must be non-negative, got {n_layers} and {n_gqa}"
        )
    if n_gqa == 0:
        return ["liv"] * n_layers
    if n_gqa >= n_layers:
        return ["gqa"] * n_layers

    layer_types = ["liv"] * n_layers
    layer_types[0] = "gqa"

    # Spread the remaining n_gqa - 1 GQA layers over the n_layers - 1 slots
    # after layer 0. Each slot adds `remaining` to the accumulator, and a GQA
    # layer is placed whenever the accumulator reaches `slots`. Because
    # remaining < slots, exactly `remaining` layers are placed in total.
    remaining = n_gqa - 1
    slots = n_layers - 1
    accumulator = 0
    for i in range(1, n_layers):
        accumulator += remaining
        if accumulator >= slots:
            layer_types[i] = "gqa"
            accumulator -= slots

    return layer_types
|
|
|
|
|
|
|
|
|
|
|
|
|
class ChessModel(nn.Module):
    """
    Liquid Chess Model (LCM).

    Input:  token IDs (batch_size, seq_len), with seq_len <= config.max_seq_len
    Output: ntp_logits (batch_size, seq_len, vocab_size) — move generation
            top_logits (batch_size, seq_len, vocab_size) — auxiliary training only

    The NTP head shares its weight matrix with the token embedding.
    """

    def __init__(self, config: ChessModelConfig):
        super().__init__()
        self.config = config

        self.embedding = nn.Embedding(
            config.vocab_size, config.d_model, padding_idx=config.pad_id
        )

        # One entry per layer, "gqa" or "liv"; GQA blocks are interleaved
        # among the LIV blocks.
        layer_types = get_layer_types(config.n_layers, config.n_gqa_layers)
        self.blocks = nn.ModuleList([
            GQABlock(config) if lt == "gqa" else LIVBlock(config)
            for lt in layer_types
        ])
        self.layer_types = layer_types

        self.norm = RMSNorm(config.d_model)
        self.ntp_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.top_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Weight tying: the NTP head reuses the embedding matrix.
        self.ntp_head.weight = self.embedding.weight

        # RoPE tables are built once for max_seq_len and registered as buffers.
        freqs_cos, freqs_sin = build_rope_cache(
            config.max_seq_len, config.head_dim, device=torch.device("cpu")
        )
        self.register_buffer("freqs_cos", freqs_cos)
        self.register_buffer("freqs_sin", freqs_sin)

        self._init_weights()

    def _init_weights(self):
        """Initialise all weights.

        - Linear and Embedding weights: N(0, 0.02); Linear biases: zero.
        - ``o_proj`` / ``down_proj`` weights: N(0, 0.02 / sqrt(2 * n_layers)).
        - Embedding padding rows: zero.
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)

        # Smaller init for the projections that write into the residual stream.
        for name, param in self.named_parameters():
            if "o_proj" in name or "down_proj" in name:
                nn.init.normal_(param, mean=0.0,
                                std=0.02 / math.sqrt(2 * self.config.n_layers))

        # Zero padding rows last. ntp_head.weight *is* embedding.weight, and
        # the Linear branch above re-initialises it after the Embedding
        # branch, so zeroing inside the first loop would be overwritten.
        with torch.no_grad():
            for module in self.modules():
                if isinstance(module, nn.Embedding) and module.padding_idx is not None:
                    module.weight[module.padding_idx].zero_()

    def forward(
        self, token_ids: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute NTP and TOP logits.

        Args:
            token_ids: Integer tensor of shape (batch_size, seq_len).

        Returns:
            ``(ntp_logits, top_logits)``, each (batch_size, seq_len, vocab_size).

        Raises:
            ValueError: If seq_len exceeds ``config.max_seq_len``, the number
                of positions covered by the RoPE tables.
        """
        _, seq_len = token_ids.shape
        if seq_len > self.config.max_seq_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds maximum {self.config.max_seq_len}"
            )

        x = self.embedding(token_ids)
        freqs_cos = self.freqs_cos[:seq_len]
        freqs_sin = self.freqs_sin[:seq_len]

        for block, lt in zip(self.blocks, self.layer_types):
            # Only GQA blocks consume the RoPE tables.
            x = block(x, freqs_cos, freqs_sin) if lt == "gqa" else block(x)

        x = self.norm(x)
        return self.ntp_head(x), self.top_head(x)

    def count_parameters(self) -> int:
        """Return the number of trainable parameters.

        ``Module.parameters()`` yields each shared tensor once, so the tied
        embedding / NTP-head matrix is counted a single time.
        """
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the default model, report its size, and check the
    # output shapes of one forward pass on random tokens. ChessModelConfig
    # comes from the module-level import.
    config = ChessModelConfig()
    model = ChessModel(config)
    params = model.count_parameters()
    print(f"Parameters: {params:,} ({params/1e6:.1f}M)")

    # Stay within the RoPE table length the model was built with.
    seq_len = min(255, config.max_seq_len)
    x = torch.randint(0, config.vocab_size, (2, seq_len))
    with torch.no_grad():
        ntp, top = model(x)
    assert ntp.shape == (2, seq_len, config.vocab_size)
    assert top.shape == (2, seq_len, config.vocab_size)
    print(f"Forward pass: {ntp.shape} ✓")