Source code for world_engine.world_engine

from typing import Dict, Optional, Set, Tuple
import torch
from torch import Tensor
from dataclasses import dataclass, field

from .model import WorldModel, StaticKVCache, PromptEncoder
from .ae import InferenceAE
from .patch_model import apply_inference_patches
from .quantize import quantize_model


# Global torch optimizations
torch._dynamo.config.recompile_limit = 64
torch.set_float32_matmul_precision("medium")  # "medium" allows bf16, "high" allows tf32, "highest" keeps full fp32

# Avoid a graph break when compiled code reads scalar outputs (e.g. Tensor.item())
torch._dynamo.config.capture_scalar_outputs = True


@dataclass
class CtrlInput:
    button: Set[int] = field(default_factory=set)  # pressed button IDs
    mouse: Tuple[float, float] = (0.0, 0.0)        # (x, y) velocity
    scroll_wheel: int = 0                          # bwd, stationary, or fwd -> (-1, 0, 1)

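# Illustrative example (not part of the module): the button IDs below are
# arbitrary placeholders; valid IDs range over 0..model_cfg.n_buttons - 1.
#
#   ctrl = CtrlInput(button={0, 3}, mouse=(2.0, -1.0), scroll_wheel=1)
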
class WorldEngine:
    def __init__(
        self,
        model_uri: str,
        quant: Optional[str] = None,
        model_config_overrides: Optional[Dict] = None,
        device=None,
        dtype=torch.bfloat16,
    ):
        """
        model_uri: HF URI or local folder containing model.safetensors and config.yaml
        quant: None | w8a8 | nvfp4
        """
        self.device, self.dtype = device, dtype

        self.model_cfg = WorldModel.load_config(model_uri)
        if model_config_overrides:
            self.model_cfg.merge_with(model_config_overrides)

        # Model
        self.vae = InferenceAE.from_pretrained(self.model_cfg.ae_uri, device=device, dtype=dtype)
        self.prompt_encoder = None
        if self.model_cfg.prompt_conditioning is not None:
            self.prompt_encoder = PromptEncoder("google/umt5-xl", dtype=dtype).to(device).eval()  # TODO: don't hardcode
        self.model = WorldModel.from_pretrained(model_uri, cfg=self.model_cfg).to(device=device, dtype=dtype).eval()
        apply_inference_patches(self.model)
        if quant is not None:
            quantize_model(self.model, quant)

        # Inference scheduler
        self.scheduler_sigmas = torch.tensor(self.model_cfg.scheduler_sigmas, device=device, dtype=dtype)
        pH, pW = getattr(self.model_cfg, "patch", [1, 1])
        self.frm_shape = 1, 1, self.model_cfg.channels, self.model_cfg.height * pH, self.model_cfg.width * pW

        # State
        self.kv_cache = StaticKVCache(self.model_cfg, batch_size=1, dtype=dtype).to(device)
        self.frame_ts = torch.tensor([[0]], dtype=torch.long, device=device)

        # Static input context tensors
        self._ctx = {
            "button": torch.zeros((1, 1, self.model_cfg.n_buttons), device=device, dtype=dtype),
            "mouse": torch.zeros((1, 1, 2), device=device, dtype=dtype),
            "scroll": torch.zeros((1, 1, 1), device=device, dtype=dtype),
            "frame_timestamp": torch.empty((1, 1), device=device, dtype=torch.long),
        }
        self._prompt_ctx = {"prompt_emb": None, "prompt_pad_mask": None}

    @torch.inference_mode()
    def reset(self):
        """Reset state for a new generation"""
        self.kv_cache.reset()
        self.frame_ts.zero_()
        for v in self._ctx.values():
            v.zero_()

    def set_prompt(self, prompt: str):
        """Apply text conditioning for T2V"""
        if self.prompt_encoder is None:
            raise RuntimeError("prompt conditioning is not enabled for this model, so prompt_encoder was never initialized")
        self._prompt_ctx["prompt_emb"], self._prompt_ctx["prompt_pad_mask"] = self.prompt_encoder([prompt])

    @torch.inference_mode()
    def append_frame(self, img: Tensor, ctrl: Optional[CtrlInput] = None):
        """Encode a real frame and commit it to the KV cache context."""
        assert img.dtype == torch.uint8, img.dtype
        x0 = self.vae.encode(img).unsqueeze(1)
        inputs = self.prep_inputs(x=x0, ctrl=ctrl)
        self._cache_pass(x0, inputs, self.kv_cache)
        return img

    @torch.inference_mode()
    def gen_frame(self, ctrl: Optional[CtrlInput] = None, return_img: bool = True):
        """Denoise a new latent frame, commit it to the KV cache, and optionally decode it."""
        x = torch.randn(self.frm_shape, device=self.device, dtype=self.dtype)
        inputs = self.prep_inputs(x=x, ctrl=ctrl)
        x0 = self._denoise_pass(x, inputs, self.kv_cache).clone()
        self._cache_pass(x0, inputs, self.kv_cache)
        return self.vae.decode(x0.squeeze(1)) if return_img else x0.squeeze(1)

    @torch.compile
    def _prep_inputs(self, x, ctrl=None):
        # Write control values into the static context tensors and advance the frame clock
        self._ctx["mouse"][0, 0, 0] = ctrl.mouse[0]
        self._ctx["mouse"][0, 0, 1] = ctrl.mouse[1]
        self._ctx["scroll"][0, 0, 0] = ctrl.scroll_wheel
        self._ctx["frame_timestamp"].copy_(self.frame_ts)
        self.frame_ts.add_(1)
        return self._ctx

    def prep_inputs(self, x, ctrl: Optional[CtrlInput] = None):
        ctrl = ctrl if ctrl is not None else CtrlInput()
        self._ctx["button"].zero_()
        if ctrl.button:
            self._ctx["button"][..., list(ctrl.button)] = 1.0
        ctrl.mouse = torch.tensor(ctrl.mouse, device=x.device, dtype=self._ctx["mouse"].dtype)
        ctrl.scroll_wheel = torch.sign(torch.tensor(ctrl.scroll_wheel, device=x.device, dtype=self._ctx["scroll"].dtype))
        ctx = self._prep_inputs(x, ctrl)

        # Prepare prompt conditioning, falling back to a default prompt if none was set
        if self.model_cfg.prompt_conditioning is None:
            return ctx
        if self._prompt_ctx["prompt_emb"] is None:
            self.set_prompt("An explorable world")
        return {**ctx, **self._prompt_ctx}

    @torch.compile(fullgraph=True, mode="max-autotune", dynamic=False)
    def _denoise_pass(self, x, ctx: Dict[str, Tensor], kv_cache):
        kv_cache.set_frozen(True)
        sigma = x.new_empty((x.size(0), x.size(1)))
        # Euler steps over the sigma schedule: x <- x + dsigma * v
        for step_sig, step_dsig in zip(self.scheduler_sigmas, self.scheduler_sigmas.diff()):
            v = self.model(x, sigma.fill_(step_sig), **ctx, kv_cache=kv_cache)
            x = x + step_dsig * v
        return x

    @torch.compile(fullgraph=True, mode="max-autotune", dynamic=False)
    def _cache_pass(self, x, ctx: Dict[str, Tensor], kv_cache):
        """Side effect: updates the KV cache"""
        kv_cache.set_frozen(False)
        self.model(x, x.new_zeros((x.size(0), x.size(1))), **ctx, kv_cache=kv_cache)
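

# Usage sketch (illustrative only, not part of the module). The checkpoint path,
# device, seed-frame shape, and controls below are assumptions for the example;
# substitute your own model URI and real frames.
if __name__ == "__main__":
    engine = WorldEngine("path/to/checkpoint", device="cuda")  # hypothetical local folder
    engine.reset()

    # Seed the context with one real frame; the layout must match what
    # InferenceAE.encode expects (a placeholder uint8 tensor is used here).
    seed = torch.zeros((1, 3, 256, 256), dtype=torch.uint8, device="cuda")
    engine.append_frame(seed)

    # Autoregressively generate frames, steering each step with controls
    for _ in range(8):
        frame = engine.gen_frame(CtrlInput(mouse=(1.0, 0.0), scroll_wheel=0))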