Source code for world_engine.model.world_model

import math
from typing import List, Optional

import einops as eo
import torch
import torch.nn.functional as F
from tensordict import TensorDict
from torch import Tensor, nn

from .attn import Attn, CrossAttention
from .nn import AdaLN, ada_gate, ada_rmsnorm, NoiseConditioner
from .base_model import BaseModel


class PromptEncoder(nn.Module):
    """Callable for text -> UMT5 embedding"""

    def __init__(self, model_id="google/umt5-xl", dtype=torch.bfloat16):
        import os

        os.environ["TOKENIZERS_PARALLELISM"] = "false"  # silence HF tokenizers fork warnings
        from transformers import AutoTokenizer, UMT5EncoderModel

        super().__init__()
        self.dtype = dtype
        self.tok = AutoTokenizer.from_pretrained(model_id)
        self.encoder = UMT5EncoderModel.from_pretrained(model_id, torch_dtype=dtype).eval()

    @torch.compile
    def encode(self, inputs):
        return self.encoder(**inputs).last_hidden_state

    @torch.inference_mode()
    def forward(self, texts: List[str]):
        import ftfy

        texts = [ftfy.fix_text(t) for t in texts]
        inputs = self.tok(
            texts,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=512,
        ).to(self.encoder.device)
        attn = inputs["attention_mask"]  # [B, S] 1 = token, 0 = pad
        emb = self.encode(inputs).to(self.dtype)
        emb = emb * attn.unsqueeze(-1).type_as(emb)  # zero padded positions (preserve old behavior)
        pad_mask = attn.eq(0)
        return emb, pad_mask

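# Illustrative usage sketch (not part of the module; assumes a CUDA device and
# locally cached "google/umt5-xl" weights):
#
#     encoder = PromptEncoder().to("cuda")
#     emb, pad_mask = encoder(["a rover crossing a dusty plain"])
#     # emb:      [1, 512, hidden_dim] in bfloat16, with padded positions zeroed
#     # pad_mask: [1, 512] bool, True where the position is padding
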
def rms_norm(x: torch.Tensor) -> torch.Tensor:
    return F.rms_norm(x, (x.size(-1),))

class CFG(nn.Module):
    def __init__(self, d_model: int, dropout: float):
        super().__init__()
        self.dropout = dropout
        self.null_emb = nn.Parameter(torch.zeros(1, 1, d_model))

    def forward(self, x: torch.Tensor, is_conditioned: Optional[bool] = None) -> torch.Tensor:
        """
        x: [B, L, D]
        is_conditioned:
          - None: training-style random dropout
          - bool: whole batch conditioned / unconditioned at sampling
        """
        B, L, _ = x.shape
        null = self.null_emb.expand(B, L, -1)

        # training-style dropout OR unspecified
        if self.training or is_conditioned is None:
            if self.dropout == 0.0:
                return x
            drop = torch.rand(B, 1, 1, device=x.device) < self.dropout  # [B, 1, 1]
            return torch.where(drop, null, x)

        # sampling-time switch
        return x if is_conditioned else null

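# Illustrative sketch of the two CFG modes (hypothetical sizes):
#
#     cfg = CFG(d_model=256, dropout=0.1)
#     cond = torch.randn(4, 16, 256)
#     cfg.train()
#     y = cfg(cond)                               # each batch row has a 10% chance of
#                                                 # being replaced by the learned null_emb
#     cfg.eval()
#     y_uncond = cfg(cond, is_conditioned=False)  # whole batch -> null_emb
#     y_cond = cfg(cond, is_conditioned=True)     # passthrough
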
class MLP(nn.Module):
    def __init__(self, dim_in, dim_middle, dim_out):
        super().__init__()
        self.fc1 = nn.Linear(dim_in, dim_middle, bias=False)
        self.fc2 = nn.Linear(dim_middle, dim_out, bias=False)

    def forward(self, x):
        return self.fc2(F.silu(self.fc1(x)))

class ControllerInputEmbedding(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.mlp = MLP(
            config.n_buttons + 3,  # mouse velocity (x, y) + scroll sign
            config.d_model * config.mlp_ratio,
            config.d_model,
        )

    def forward(self, mouse: Tensor, button: Tensor, scroll: Tensor):
        x = torch.cat((mouse, button, scroll), dim=-1)
        return self.mlp(x)

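# Assumed input widths (not asserted by the module): mouse velocity [B, N, 2],
# scroll sign [B, N, 1], button [B, N, config.n_buttons]; the concatenated
# feature width then matches the config.n_buttons + 3 expected by the MLP above.
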
class MLPFusion(nn.Module):
    """Fuses per-group conditioning into tokens by applying an MLP to cat([x, cond])"""

    def __init__(self, config):
        super().__init__()
        self.mlp = MLP(2 * config.d_model, config.d_model, config.d_model)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        B, _, D = x.shape
        L = cond.shape[1]
        Wx, Wc = self.mlp.fc1.weight.chunk(2, dim=1)  # each [D, D]
        x = x.view(B, L, -1, D)
        h = F.linear(x, Wx) + F.linear(cond, Wc).unsqueeze(2)  # broadcast, no repeat/cat
        h = F.silu(h)
        y = F.linear(h, self.mlp.fc2.weight)
        return y.flatten(1, 2)

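# Equivalence sketch: because fc1/fc2 have no bias, the fused forward above matches
# the docstring's cat([x, cond]) formulation when x holds a fixed number of tokens
# per conditioning group (hypothetical sizes):
#
#     B, L, G, D = 2, 4, 8, config.d_model      # G tokens per conditioning group
#     fusion = MLPFusion(config)
#     x, cond = torch.randn(B, L * G, D), torch.randn(B, L, D)
#     ref = fusion.mlp(torch.cat((x.view(B, L, G, D),
#                                 cond[:, :, None, :].expand(B, L, G, D)), dim=-1))
#     torch.testing.assert_close(fusion(x, cond), ref.flatten(1, 2))
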
class CondHead(nn.Module):
    """Per-layer conditioning head: bias_in → SiLU → n_cond parallel Linear projections."""

    n_cond = 6

    def __init__(self, config):
        super().__init__()
        self.bias_in = nn.Parameter(torch.zeros(config.d_model)) if config.noise_conditioning == "wan" else None
        self.cond_proj = nn.ModuleList(
            [nn.Linear(config.d_model, config.d_model, bias=False) for _ in range(self.n_cond)]
        )

    def forward(self, cond):
        cond = cond + self.bias_in if self.bias_in is not None else cond
        h = F.silu(cond)
        return tuple(p(h) for p in self.cond_proj)

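# The six outputs are consumed in WorldDiTBlock.forward as two (scale, bias, gate)
# triples: (s0, b0, g0) modulate the attention sub-block, (s1, b1, g1) the MLP sub-block.
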
class WorldDiTBlock(nn.Module):
    def __init__(self, config, layer_idx):
        super().__init__()
        self.config = config

        self.attn = Attn(config, layer_idx)
        self.mlp = MLP(config.d_model, config.d_model * config.mlp_ratio, config.d_model)
        self.cond_head = CondHead(config)

        do_prompt_cond = (
            config.prompt_conditioning is not None
            and layer_idx % config.prompt_conditioning_period == 0
        )
        self.prompt_cross_attn = CrossAttention(config, config.prompt_embedding_dim) if do_prompt_cond else None

        do_ctrl_cond = layer_idx % config.ctrl_conditioning_period == 0
        self.ctrl_mlpfusion = MLPFusion(config) if do_ctrl_cond else None

    def forward(self, x, pos_ids, cond, ctx, v, kv_cache=None):
        """
        0) Causal Frame Attention
        1) Frame->CTX Cross Attention
        2) MLP
        """
        s0, b0, g0, s1, b1, g1 = self.cond_head(cond)

        # Self / Causal Attention
        residual = x
        x = ada_rmsnorm(x, s0, b0)
        x, v = self.attn(x, pos_ids, v, kv_cache=kv_cache)
        x = ada_gate(x, g0) + residual

        # Cross Attention Prompt Conditioning
        if self.prompt_cross_attn is not None:
            x = self.prompt_cross_attn(
                rms_norm(x),
                context=rms_norm(ctx["prompt_emb"]),
                context_pad_mask=ctx["prompt_pad_mask"],
            ) + x

        # MLPFusion Controller Conditioning
        if self.ctrl_mlpfusion is not None:
            x = self.ctrl_mlpfusion(rms_norm(x), rms_norm(ctx["ctrl_emb"])) + x

        # MLP
        x = ada_gate(self.mlp(ada_rmsnorm(x, s1, b1)), g1) + x
        return x, v

class WorldDiT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.blocks = nn.ModuleList([WorldDiTBlock(config, idx) for idx in range(config.n_layers)])

        # Share conditioning-head projection weights across layers
        if self.config.noise_conditioning in ("dit_air", "wan"):
            ref_proj = self.blocks[0].cond_head.cond_proj
            for blk in self.blocks[1:]:
                for blk_mod, ref_mod in zip(blk.cond_head.cond_proj, ref_proj):
                    blk_mod.weight = ref_mod.weight

        # Shared RoPE buffers
        ref_rope = self.blocks[0].attn.rope
        for blk in self.blocks[1:]:
            blk.attn.rope = ref_rope

    def forward(self, x, pos_ids, cond, ctx, kv_cache=None):
        v = None
        for block in self.blocks:
            x, v = block(x, pos_ids, cond, ctx, v, kv_cache=kv_cache)
        return x

class WorldModel(BaseModel):
    """
    WORLD: Wayfarer Operator-driven Rectified-flow Long-context Diffuser

    Denoise a frame given
      - All previous frames
      - The prompt embedding
      - The controller input embedding
      - The current noise level
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        assert config.tokens_per_frame == config.height * config.width

        self.denoise_step_emb = NoiseConditioner(config.d_model)
        self.ctrl_emb = ControllerInputEmbedding(config)

        if config.ctrl_conditioning is not None:
            self.ctrl_cfg = CFG(config.d_model, config.ctrl_cond_dropout)
        if config.prompt_conditioning is not None:
            self.prompt_cfg = CFG(config.prompt_embedding_dim, config.prompt_cond_dropout)

        self.transformer = WorldDiT(config)

        self.patch = tuple(getattr(config, "patch", (1, 1)))
        C, D = config.channels, config.d_model
        self.patchify = nn.Conv2d(C, D, kernel_size=self.patch, stride=self.patch, bias=False)
        self.unpatchify = nn.Linear(D, C * math.prod(self.patch), bias=True)
        self.out_norm = AdaLN(config.d_model)

        # Cached 1-frame pos_ids (buffers + cached TensorDict view)
        T = config.tokens_per_frame
        idx = torch.arange(T, dtype=torch.long)
        self.register_buffer("_t_pos_1f", torch.empty(T, dtype=torch.long), persistent=False)
        self.register_buffer("_y_pos_1f", idx.div(config.width, rounding_mode="floor"), persistent=False)
        self.register_buffer("_x_pos_1f", idx.remainder(config.width), persistent=False)

    def forward(
        self,
        x: Tensor,
        sigma: Tensor,
        frame_timestamp: Tensor,
        prompt_emb: Optional[Tensor] = None,
        prompt_pad_mask: Optional[Tensor] = None,
        mouse: Optional[Tensor] = None,
        button: Optional[Tensor] = None,
        scroll: Optional[Tensor] = None,
        kv_cache=None,
        ctrl_cond: Optional[bool] = None,
        prompt_cond: Optional[bool] = None,
    ):
        """
        x: [B, N, C, H, W]
        sigma: [B, N]
        frame_timestamp: [B, N]
        prompt_emb: [B, P, D]
        mouse / button / scroll: [B, N, I] controller inputs
        ctrl_cond / prompt_cond: inference only, whether to apply the corresponding conditioning
        """
        B, N, C, H, W = x.shape
        ph, pw = self.patch
        assert (H % ph == 0) and (W % pw == 0), "H, W must be divisible by patch"
        Hp, Wp = H // ph, W // pw
        torch._assert(Hp * Wp == self.config.tokens_per_frame, f"{Hp} * {Wp} != {self.config.tokens_per_frame}")
        torch._assert(B == 1 and N == 1, "WorldModel.forward currently supports B==1, N==1")

        self._t_pos_1f.copy_(frame_timestamp[0, 0].expand_as(self._t_pos_1f))
        pos_ids = TensorDict(
            {"t_pos": self._t_pos_1f[None], "y_pos": self._y_pos_1f[None], "x_pos": self._x_pos_1f[None]},
            batch_size=[1, self._t_pos_1f.numel()],
        )

        cond = self.denoise_step_emb(sigma)  # [B, N, d]

        assert button is not None
        ctx = {
            "ctrl_emb": self.ctrl_emb(mouse, button, scroll),
            "prompt_emb": prompt_emb,
            "prompt_pad_mask": prompt_pad_mask,
        }

        D = self.unpatchify.in_features
        x = self.patchify(x.reshape(B * N, C, H, W))
        x = eo.rearrange(x.view(B, N, D, Hp, Wp), 'b n d hp wp -> b (n hp wp) d')

        x = self.transformer(x, pos_ids, cond, ctx, kv_cache)

        x = F.silu(self.out_norm(x, cond))
        x = eo.rearrange(
            self.unpatchify(x),
            'b (n hp wp) (c ph pw) -> b n c (hp ph) (wp pw)',
            n=N, hp=Hp, wp=Wp, ph=ph, pw=pw,
        )
        return x

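# Illustrative single-frame call (hypothetical config; assumes patch == (1, 1) and
# config.prompt_conditioning is None so no prompt embedding is needed; B == N == 1
# is enforced by the asserts above):
#
#     model = WorldModel(config)
#     x = torch.randn(1, 1, config.channels, config.height, config.width)
#     sigma = torch.rand(1, 1)
#     frame_timestamp = torch.zeros(1, 1, dtype=torch.long)
#     mouse = torch.zeros(1, 1, 2)
#     scroll = torch.zeros(1, 1, 1)
#     button = torch.zeros(1, 1, config.n_buttons)
#     pred = model(x, sigma, frame_timestamp, mouse=mouse, button=button, scroll=scroll)
#     # pred: [1, 1, channels, height, width]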