Source code for world_engine.patch_model

import torch
from torch import nn, Tensor
import torch.nn.functional as F

from .model.nn import rms_norm
from .model.attn import Attn
from .model.world_model import MLPFusion
from torch.nn.attention.flex_attention import flex_attention


def _bf16_u16(x: Tensor) -> Tensor:
    # reinterpret bf16 storage as int16 -> unsigned 0..65535 in int32
    return x.contiguous().view(torch.int16).to(torch.int32) & 0xFFFF
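

# Illustrative sketch (not part of the patching API; the function name is an
# assumption): the caches below key on raw bf16 bit patterns, so equal bf16
# values always map to the same 16-bit code and distinct bf16 values never collide.
def _example_bf16_u16_codes() -> None:
    sigmas = torch.tensor([0.002, 0.5, 80.0], dtype=torch.bfloat16)
    codes = _bf16_u16(sigmas)                             # int32, each in [0, 65535]
    assert codes.dtype is torch.int32
    assert int(codes.min()) >= 0 and int(codes.max()) <= 0xFFFF
    assert torch.unique(codes).numel() == 3               # distinct values -> distinct codes
    assert torch.equal(_bf16_u16(sigmas.clone()), codes)  # a pure function of the bits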


class CachedDenoiseStepEmb(nn.Module):
    """bf16 sigma -> bf16 embedding via a 64k LUT; an invalid sigma maps to an
    out-of-bounds index and errors instead of silently returning a wrong embedding."""

    def __init__(self, base: nn.Module, sigmas: list[float]):
        super().__init__()
        device = next(base.parameters()).device
        levels = torch.tensor(sigmas, device=device, dtype=torch.bfloat16)  # [S]
        bits = _bf16_u16(levels)  # [S]
        if torch.unique(bits).numel() != bits.numel():
            raise ValueError("scheduler_sigmas collide in bf16; caching would be ambiguous")
        with torch.no_grad():
            table = base(levels[:, None]).squeeze(1).to(torch.bfloat16).contiguous()  # [S, D]
        lut = torch.full((65536,), -1, device=device, dtype=torch.int32)
        lut[bits] = torch.arange(bits.numel(), device=device, dtype=torch.int32)
        self.register_buffer("table", table, persistent=False)  # [S, D] bf16
        self.register_buffer("lut", lut, persistent=False)      # [65536] int32
        self.register_buffer("oob", torch.tensor(bits.numel(), device=device, dtype=torch.int32), persistent=False)

    def forward(self, sigma: Tensor) -> Tensor:
        if sigma.dtype is not torch.bfloat16:
            raise RuntimeError("CachedDenoiseStepEmb expects bf16 sigma")
        idx = self.lut[_bf16_u16(sigma)]
        idx = torch.where(idx >= 0, idx, self.oob)  # invalid sigma -> index S (out of bounds)
        return self.table[idx.to(torch.int64)]  # [..., D] bf16
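

# Minimal sanity-check sketch (illustrative only, not used by the patches; the
# function name is an assumption): on exactly the scheduler sigmas it was built
# from, the LUT path reproduces the original denoise_step_emb, assuming `base`
# is deterministic and maps a [S, 1] bf16 tensor to [S, 1, D] embeddings as
# __init__ above does.
@torch.no_grad()
def _check_cached_denoise_step_emb(base: nn.Module, sigmas: list[float]) -> None:
    cached = CachedDenoiseStepEmb(base, sigmas)
    device = next(base.parameters()).device
    levels = torch.tensor(sigmas, device=device, dtype=torch.bfloat16)
    ref = base(levels[:, None]).squeeze(1).to(torch.bfloat16)
    out = cached(levels)          # pure table lookup, no matmul
    assert torch.equal(out, ref)  # bit-exact: both come from the same bf16 table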


class CachedCondHead(nn.Module):
    """bf16 cond -> cached (s0, b0, g0, s1, b1, g1); an invalid cond maps to an
    out-of-bounds index and errors instead of silently returning wrong values."""

    def __init__(self, base, cached_denoise_step_emb: CachedDenoiseStepEmb, max_key_dims: int = 8):
        super().__init__()
        table = cached_denoise_step_emb.table  # [S, D] bf16
        S, D = table.shape
        with torch.no_grad():
            emb = table[:, None, :]  # [S, 1, D]
            cache = torch.stack([t.squeeze(1) for t in base(emb)], 0).to(torch.bfloat16).contiguous()  # [6, S, D]
        # Pick a single embedding dimension whose bf16 bits uniquely identify sigma.
        key_dim = None
        for d in range(min(D, max_key_dims)):
            b = _bf16_u16(table[:, d])
            if torch.unique(b).numel() == S:
                key_dim = d
                key_bits = b
                break
        if key_dim is None:
            raise ValueError("Could not find a unique bf16 key dim for cond->sigma mapping; increase max_key_dims")
        lut = torch.full((65536,), -1, device=table.device, dtype=torch.int32)
        lut[key_bits] = torch.arange(S, device=table.device, dtype=torch.int32)
        self.key_dim = int(key_dim)
        self.register_buffer("cache", cache, persistent=False)  # [6, S, D] bf16
        self.register_buffer("lut", lut, persistent=False)      # [65536] int32
        self.register_buffer("oob", torch.tensor(S, device=table.device, dtype=torch.int32), persistent=False)

    def forward(self, cond: Tensor):
        if cond.dtype is not torch.bfloat16:
            raise RuntimeError("CachedCondHead expects bf16 cond")
        idx = self.lut[_bf16_u16(cond[..., self.key_dim])]
        idx = torch.where(idx >= 0, idx, self.oob)  # invalid cond -> index S (out of bounds)
        g = self.cache[:, idx.to(torch.int64)]  # [6, ..., D] bf16 (or an indexing error)
        return tuple(g.unbind(0))  # (s0, b0, g0, s1, b1, g1)
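

# Minimal sanity-check sketch (illustrative only, not used by the patches; the
# function name is an assumption): CachedCondHead is exact on the only
# conditioning vectors it can ever see, the rows of the sigma-embedding table,
# assuming `base` is deterministic and maps a [S, 1, D] tensor to six
# [S, 1, D] tensors as __init__ above does.
@torch.no_grad()
def _check_cached_cond_head(base, cached_emb: CachedDenoiseStepEmb) -> None:
    cached = CachedCondHead(base, cached_emb)
    cond = cached_emb.table[:, None, :]                  # [S, 1, D] bf16
    ref = [t.squeeze(1).to(torch.bfloat16) for t in base(cond)]
    out = cached(cond.squeeze(1))                        # LUT + gather only
    assert all(torch.equal(o, r) for o, r in zip(out, ref))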


def patch_cached_noise_conditioning(model) -> None:
    # Call AFTER: model.to(device="cuda", dtype=torch.bfloat16).eval()
    cached_denoise_step_emb = CachedDenoiseStepEmb(model.denoise_step_emb, model.config.scheduler_sigmas)
    model.denoise_step_emb = cached_denoise_step_emb
    for blk in model.transformer.blocks:
        blk.cond_head = CachedCondHead(blk.cond_head, cached_denoise_step_emb)
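

# Illustrative usage sketch (assumed call pattern, not part of the API): the
# caches are baked from the live weights of denoise_step_emb and each cond_head,
# so the model should already be on its inference device and in bf16, per the
# comment above.
def _example_patch_noise_conditioning(model) -> None:
    model = model.to(device="cuda", dtype=torch.bfloat16).eval()
    patch_cached_noise_conditioning(model)
    # From here on, sigma embeddings and per-block cond_head outputs come from
    # table lookups instead of small matmuls at every denoise step.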


class MergedQKVAttn(Attn):

    def __init__(self, src: Attn, config):
        super().__init__(config, src.layer_idx)  # makes fresh q/k/v/out/etc
        self.to(device=src.q_proj.weight.device, dtype=src.q_proj.weight.dtype)
        self.load_state_dict(src.state_dict(), strict=False)  # copies trained weights/buffers
        self.train(src.training)  # preserve train/eval mode
        self.q_out = self.n_heads * self.d_head
        self.kv_out = self.n_kv_heads * self.d_head
        self.qkv_proj = nn.Linear(
            self.q_proj.in_features,
            self.q_out + 2 * self.kv_out,
            bias=False,
            device=self.q_proj.weight.device,
            dtype=self.q_proj.weight.dtype,
        )
        with torch.no_grad():
            self.qkv_proj.weight.copy_(torch.cat(
                [self.q_proj.weight, self.k_proj.weight, self.v_proj.weight], dim=0
            ))
        del self.q_proj, self.k_proj, self.v_proj

    def forward(self, x, pos_ids, v1, kv_cache):
        q, k, v = self.qkv_proj(x).split((self.q_out, self.kv_out, self.kv_out), dim=-1)
        B, T = x.shape[:2]
        q = q.reshape(B, T, self.n_heads, self.d_head).transpose(1, 2)
        k = k.reshape(B, T, self.n_kv_heads, self.d_head).transpose(1, 2)
        v = v.reshape(B, T, self.n_kv_heads, self.d_head).transpose(1, 2)
        if self.value_residual:
            v1 = v if v1 is None else v1
            v = torch.lerp(v, v1.view_as(v), self.v_lamb)
        q, k = rms_norm(q), rms_norm(k)
        q, k = self.rope(q, pos_ids), self.rope(k, pos_ids)
        k, v, bm = kv_cache.upsert(k, v, pos_ids, self.layer_idx)
        y = flex_attention(q, k, v, block_mask=bm, enable_gqa=self.enable_gqa)
        if self.gated_attn:
            gates = torch.sigmoid(self.gate_proj(x[..., : self.n_heads]))
            y = y * gates.permute(0, 2, 1).unsqueeze(-1)
        y = y.transpose(1, 2).reshape(B, T, -1)
        y = self.out_proj(y)
        return y, v1
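

# Minimal equivalence sketch (illustrative only, not used by the patches; the
# function name is an assumption): since the merged weight is just [Wq; Wk; Wv]
# stacked along the output dimension, the fused projection matches the three
# separate projections of the source Attn up to kernel-dependent accumulation
# differences, hence the loose bf16 tolerances.
@torch.no_grad()
def _check_merged_qkv_projection(src: Attn, config, x: Tensor) -> None:
    merged = MergedQKVAttn(src, config)
    q, k, v = merged.qkv_proj(x).split((merged.q_out, merged.kv_out, merged.kv_out), dim=-1)
    assert torch.allclose(q, src.q_proj(x), rtol=1e-2, atol=1e-2)
    assert torch.allclose(k, src.k_proj(x), rtol=1e-2, atol=1e-2)
    assert torch.allclose(v, src.v_proj(x), rtol=1e-2, atol=1e-2)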


def patch_Attn_merge_qkv(model) -> None:
    for name, mod in list(model.named_modules()):
        if isinstance(mod, Attn) and not isinstance(mod, MergedQKVAttn):
            model.set_submodule(name, MergedQKVAttn(mod, model.config))


class SplitMLPFusion(nn.Module):
    """Packed MLPFusion -> split linears (no cat, quant-friendly)."""

    def __init__(self, src: MLPFusion):
        super().__init__()
        D = src.mlp.fc2.in_features
        dev, dt = src.mlp.fc2.weight.device, src.mlp.fc2.weight.dtype
        self.fc1_x = nn.Linear(D, D, bias=False, device=dev, dtype=dt)
        self.fc1_c = nn.Linear(D, D, bias=False, device=dev, dtype=dt)
        self.fc2 = nn.Linear(D, D, bias=False, device=dev, dtype=dt)
        with torch.no_grad():
            Wx, Wc = src.mlp.fc1.weight.chunk(2, dim=1)
            self.fc1_x.weight.copy_(Wx)
            self.fc1_c.weight.copy_(Wc)
            self.fc2.weight.copy_(src.mlp.fc2.weight)
        self.train(src.training)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        B, _, D = x.shape
        L = cond.shape[1]
        x = x.reshape(B, L, -1, D)
        return self.fc2(F.silu(self.fc1_x(x) + self.fc1_c(cond).unsqueeze(2))).flatten(1, 2)
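

# Minimal worked example (illustrative only, standalone; the function name is an
# assumption): the split relies on the identity W @ cat([x, c]) == Wx @ x + Wc @ c,
# where [Wx | Wc] is W chunked along the input dimension, demonstrated here on a
# plain nn.Linear.
@torch.no_grad()
def _example_split_linear_identity(D: int = 16) -> None:
    fc1 = nn.Linear(2 * D, D, bias=False)
    Wx, Wc = fc1.weight.chunk(2, dim=1)       # each [D, D]
    x, c = torch.randn(4, D), torch.randn(4, D)
    packed = fc1(torch.cat([x, c], dim=-1))   # packed MLPFusion-style formulation
    split = x @ Wx.T + c @ Wc.T               # SplitMLPFusion-style formulation
    assert torch.allclose(packed, split, atol=1e-5)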


def patch_MLPFusion_split(model) -> None:
    for name, mod in list(model.named_modules()):
        if isinstance(mod, MLPFusion) and not isinstance(mod, SplitMLPFusion):
            model.set_submodule(name, SplitMLPFusion(mod))


def apply_inference_patches(model) -> None:
    patch_cached_noise_conditioning(model)
    patch_Attn_merge_qkv(model)
    patch_MLPFusion_split(model)
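

# Minimal post-patch sketch (illustrative only, not part of the API; the
# function name is an assumption): after apply_inference_patches, the swapped-in
# module types can be verified directly.
def _check_patched(model) -> None:
    assert isinstance(model.denoise_step_emb, CachedDenoiseStepEmb)
    assert all(isinstance(b.cond_head, CachedCondHead) for b in model.transformer.blocks)
    assert all(not isinstance(m, Attn) or isinstance(m, MergedQKVAttn) for m in model.modules())
    assert not any(isinstance(m, MLPFusion) for m in model.modules())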