from typing import Optional

import torch
import torch.nn as nn

# TODO: enable specific quants based on model config, which should specify compatible quants [None, "w8a8", "fp8"]
QUANTS = [None]

try:
    from flashinfer import nvfp4_quantize, mm_fp4, SfLayout

    QUANTS.append("nvfp4")
except ImportError:
    pass

@torch.library.custom_op("world_engine::fp4_linear", mutates_args=())
def fp4_linear(
    a_bf16: torch.Tensor,
    b_fp4_T: torch.Tensor,
    a_global_sf: torch.Tensor,
    b_sf_T: torch.Tensor,
    alpha: torch.Tensor,
) -> torch.Tensor:
    # Quantize the bf16 activation to NVFP4, then run the FP4 GEMM against the
    # pre-quantized, transposed weight.
    a_fp4, a_sf = nvfp4_quantize(
        a_bf16,
        a_global_sf,
        sfLayout=SfLayout.layout_128x4,
        do_shuffle=False,
    )
    return mm_fp4(a_fp4, b_fp4_T, a_sf, b_sf_T, alpha, out_dtype=torch.bfloat16, backend="cutlass")


@fp4_linear.register_fake
def _fp4_linear_fake(
    a_bf16: torch.Tensor,
    b_fp4_T: torch.Tensor,
    a_global_sf: torch.Tensor,
    b_sf_T: torch.Tensor,
    alpha: torch.Tensor,
) -> torch.Tensor:
    # Shape-only implementation used for meta/fake-tensor tracing (e.g. under torch.compile).
    return torch.empty((a_bf16.shape[0], b_fp4_T.shape[1]), device=a_bf16.device, dtype=torch.bfloat16)
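
# Illustrative call (not executed here): with b_fp4_T / b_sf_T / alpha prepared the way
# FP4Linear.__init__ does below, the op maps an [M, K] bf16 activation to an [M, N] bf16 output:
#
#     y = fp4_linear(x_bf16, b_fp4_T, a_global_sf, b_sf_T, alpha)  # y: [M, N], torch.bfloat16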


class FP4Linear(nn.Module):
    """FP4 Linear layer using FlashInfer's NVFP4 quantization."""

    def __init__(self, lin: nn.Linear):
        super().__init__()
        self.in_features = lin.in_features
        self.out_features = lin.out_features

        # NVFP4 TMA alignment: both dimensions must be multiples of 32.
        assert self.in_features % 32 == 0 and self.out_features % 32 == 0, (
            "in_features and out_features must be multiples of 32 for NVFP4"
        )

        # Store the weight from the original linear layer
        self.weight = nn.Parameter(lin.weight.detach().clone())

        # FP4 weight, scales, and alpha (populated eagerly below)
        self._weight_fp4_T: Optional[torch.Tensor] = None
        self._weight_scales_T: Optional[torch.Tensor] = None
        self._alpha: Optional[torch.Tensor] = None
        self._dummy_scale: Optional[torch.Tensor] = None
        self._weight_global_sf: Optional[torch.Tensor] = None

        assert self.weight.is_cuda, "Weights need to be on GPU before quantization"
        with torch.no_grad():
            # Quantize weights eagerly (no lazy path)
            self._dummy_scale = torch.full((1,), 1.0, device=self.weight.device, dtype=torch.float32)
            weight_bf16 = self.weight.to(torch.bfloat16).contiguous()
            weight_amax = weight_bf16.float().abs().nan_to_num().max()
            self._weight_global_sf = 1.0 / weight_amax
            self._alpha = 1.0 / (self._weight_global_sf * self._dummy_scale)
            w_fp4, w_sf = nvfp4_quantize(
                weight_bf16,
                self._weight_global_sf,
                sfLayout=SfLayout.layout_128x4,
                do_shuffle=False,
            )
            self._weight_fp4_T = w_fp4.t()
            self._weight_scales_T = w_sf.t()

            # Warm up FlashInfer's FP4 kernels
            # TODO: test actual-shape warmup, might perform better
            lazy_x = torch.zeros((1, lin.in_features), device=self.weight.device, dtype=torch.bfloat16)
            fp4_linear(
                lazy_x,
                self._weight_fp4_T,
                self._dummy_scale,
                self._weight_scales_T,
                self._alpha,
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass using FP4 quantization and FlashInfer GEMM."""
        x_flat = x.reshape(-1, x.shape[-1])
        y = fp4_linear(
            x_flat.to(torch.bfloat16).contiguous(),
            self._weight_fp4_T,
            self._dummy_scale,
            self._weight_scales_T,
            self._alpha,
        )
        return y.reshape(x.shape[:-1] + (-1,))
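
# Usage sketch (illustrative only, assumes a CUDA device and a bf16 nn.Linear whose
# dimensions are multiples of 32; the shapes below are arbitrary examples):
#
#     lin = nn.Linear(4096, 4096, bias=False, device="cuda", dtype=torch.bfloat16)
#     fp4_lin = FP4Linear(lin)
#     y = fp4_lin(torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16))  # [8, 4096] bf16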


class FP8W8A8Linear(nn.Module):
    """FP8 (e4m3) Linear using torch._scaled_mm with a per-tensor weight scale and a dynamic per-tensor activation scale."""

    __constants__ = ("in_features", "out_features")

    def __init__(self, lin: nn.Linear):
        super().__init__()
        self.in_features, self.out_features = lin.in_features, lin.out_features
        f8 = torch.float8_e4m3fn
        inv = 1.0 / float(torch.finfo(f8).max)
        self._inv = inv
        w = lin.weight.detach()
        ws = (w.abs().amax() * inv).clamp_min(1e-8).float()  # 0-d per-tensor weight scale
        wf8 = (w / ws.to(w.dtype)).to(f8).contiguous()  # row-major
        self.register_buffer("wT", wf8.t())  # column-major view (no contiguous)
        self.register_buffer("ws", ws)
        if lin.bias is None:
            self.bias = None
        else:
            self.register_buffer("bias", lin.bias.detach().to(torch.float16))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        s = x.shape
        x2 = x.reshape(-1, s[-1])
        # Dynamic per-tensor activation scale, then cast to FP8 and run the scaled GEMM.
        xs = (x2.abs().amax() * self._inv).clamp_min(1e-8).float()  # 0-d
        xf8 = (x2 / xs.to(x2.dtype)).to(torch.float8_e4m3fn).contiguous()
        y = torch._scaled_mm(
            xf8, self.wT, xs, self.ws,
            bias=self.bias, out_dtype=torch.float16, use_fast_accum=True
        )
        return y.reshape(*s[:-1], self.out_features).to(x.dtype)
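
# Scaling note for FP8W8A8Linear.forward (sketch): torch._scaled_mm multiplies the FP8
# product by both per-tensor scales, so the output approximately recovers the
# unquantized GEMM:
#
#     y ≈ ((x2 / xs) @ (w / ws).T) * xs * ws ≈ x2 @ w.T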


class FP8Linear(nn.Module):
    """FP8 (e4m3) Linear with per-tensor weight scaling; activations are cast directly to FP8 without scaling."""

    def __init__(self, lin: nn.Linear):
        super().__init__()
        self.in_features, self.out_features = lin.in_features, lin.out_features
        # torch._scaled_mm only accepts a Half or BFloat16 bias, so keep it in bf16
        # (matching out_dtype) rather than casting it to FP8.
        self.bias = (
            nn.Parameter(lin.bias.data.clone().to(torch.bfloat16))
            if lin.bias is not None
            else None
        )
        # Per-tensor absmax weight scale; the FP8 weight is rescaled by w_amax inside _scaled_mm.
        w_amax = lin.weight.data.abs().amax().float()
        w = lin.weight.data.clone().div(w_amax).to(torch.float8_e4m3fn)
        self.register_buffer("w_amax", w_amax)
        self.register_buffer("weightT", w.t())
        # Registered as a buffer so it follows the module across .to(device) moves.
        self.register_buffer("dummy_scale", torch.ones((), device=lin.weight.device, dtype=torch.float32))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass using FP8 matmul.

        Args:
            x: Input tensor of shape [..., in_features] (flattened internally if > 2D).

        Returns:
            Output tensor of shape [..., out_features] in BF16, reshaped back if the input was > 2D.
        """
        # Cast the input directly to FP8 e4m3 (no activation scaling).
        x_fp8 = x.to(torch.float8_e4m3fn).reshape(-1, x.size(-1)).contiguous()
        result = torch._scaled_mm(
            x_fp8,
            self.weightT,
            bias=self.bias,
            scale_a=self.dummy_scale,
            scale_b=self.w_amax,
            out_dtype=torch.bfloat16,
            use_fast_accum=True,
        )
        return result.reshape(x.shape[:-1] + (-1,))
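
# Scaling note for FP8Linear.forward (sketch): with dummy_scale = 1 the GEMM output is
#
#     y ≈ (x_fp8 @ (w / w_amax).T) * 1 * w_amax ≈ x @ w.T
#
# so only the weight is rescaled; activation error comes from the direct FP8 cast of x.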


def quantize_model(model: nn.Module, quant: Optional[str]) -> nn.Module:
    """Recursively replace eligible nn.Linear modules with the quantized variant named by ``quant``."""
    if quant is None:
        return model

    def eligible(m: nn.Module) -> bool:
        if not isinstance(m, nn.Linear):
            return False
        w = m.weight
        if w.dtype != torch.bfloat16:
            return False
        o, k = w.shape
        return (o % 32 == 0) and (k % 32 == 0)

    new_linear = {
        "w8a8": FP8W8A8Linear,
        "nvfp4": FP4Linear,
        "fp8": FP8Linear,
    }[quant]

    for name, child in model.named_children():
        if eligible(child):
            setattr(model, name, new_linear(child))
        else:
            quantize_model(child, quant)
    return model
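

# Minimal smoke test (illustrative only): builds a toy bf16 model and swaps its Linear
# layers for the FP8 W8A8 variant. The shapes and the "w8a8" choice are arbitrary
# examples, not requirements of this module. Requires a CUDA GPU with FP8 support.
if __name__ == "__main__":
    model = nn.Sequential(
        nn.Linear(256, 512, dtype=torch.bfloat16),
        nn.GELU(),
        nn.Linear(512, 256, dtype=torch.bfloat16),
    ).cuda()
    model = quantize_model(model, "w8a8")
    x = torch.randn(4, 256, device="cuda", dtype=torch.bfloat16)
    print(model(x).shape)  # expected: torch.Size([4, 256])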