Source code for world_engine.ae_nn

import torch
from torch import nn
import torch.nn.functional as F

from torch.nn.utils.parametrizations import weight_norm
import math

# === General Blocks ===

[docs] def WeightNormConv2d(*args, **kwargs): return weight_norm(nn.Conv2d(*args, **kwargs))
class ResBlock(nn.Module):
    """Residual bottleneck: 1x1 expand -> grouped 3x3 -> 1x1 project, added to the input.

    The inner width is ``2 * ch``. The 1x1 output projection has no bias, and
    the channel count (and, given the padding, the spatial size) of the input
    is preserved so the skip connection can be a plain add.
    """

    def __init__(self, ch):
        super().__init__()
        expanded = ch * 2
        # 16 channels per group (matches checkpoint shapes like [128,16,3,3]
        # when ch=64); always at least one group. nn.Conv2d requires
        # `expanded` to be divisible by `groups`.
        groups = max(1, expanded // 16)
        self.conv1 = WeightNormConv2d(ch, expanded, 1, 1, 0)
        self.conv2 = WeightNormConv2d(expanded, expanded, 3, 1, 1, groups=groups)
        self.conv3 = WeightNormConv2d(expanded, ch, 1, 1, 0, bias=False)
        self.act1 = nn.LeakyReLU(inplace=False)
        self.act2 = nn.LeakyReLU(inplace=False)

    def forward(self, x):
        branch = self.act1(self.conv1(x))
        branch = self.act2(self.conv2(branch))
        return x + self.conv3(branch)
# === Encoder ===
class LandscapeToSquare(nn.Module):
    """Resize the input to a fixed spatial size, then apply a 3x3 projection.

    The author's original note reads "Strict assumption of 360p". Nothing here
    checks the input size, though: ``F.interpolate`` resizes whatever it
    receives to ``size``.

    Args:
        ch_in: Number of input channels.
        ch_out: Number of output channels of the 3x3 projection.
        size: Target ``(height, width)`` of the bicubic resize. Defaults to
            ``(512, 512)``, the value this module always used.
    """

    def __init__(self, ch_in, ch_out, size=(512, 512)):
        super().__init__()
        self.proj = WeightNormConv2d(ch_in, ch_out, 3, 1, 1)
        # Plain attribute (not a buffer), so state_dict keys are unchanged.
        self.target_size = tuple(size)

    def forward(self, x):
        x = F.interpolate(x, size=self.target_size, mode='bicubic')
        return self.proj(x)
class Downsample(nn.Module):
    """Halve height and width with bicubic interpolation, then apply a bias-free 1x1 conv."""

    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.proj = WeightNormConv2d(ch_in, ch_out, 1, 1, 0, bias=False)

    def forward(self, x):
        halved = F.interpolate(x, scale_factor=0.5, mode='bicubic')
        return self.proj(halved)
class DownBlock(nn.Module):
    """Apply ``num_res`` residual blocks at ``ch_in`` channels, then downsample to ``ch_out``."""

    def __init__(self, ch_in, ch_out, num_res=1):
        super().__init__()
        self.down = Downsample(ch_in, ch_out)
        # ResBlocks preserve channel count, so all of them run at ch_in,
        # before the Downsample changes the width.
        self.blocks = nn.ModuleList([ResBlock(ch_in) for _ in range(num_res)])

    def forward(self, x):
        h = x
        for res_block in self.blocks:
            h = res_block(h)
        return self.down(h)
class SpaceToChannel(nn.Module):
    """3x3 conv to ``ch_out // 4`` channels, then a factor-2 pixel-unshuffle.

    ``F.pixel_unshuffle(..., 2)`` multiplies the channel count by 4 and halves
    height and width, so the output has ``4 * (ch_out // 4)`` channels (equal
    to ``ch_out`` when ``ch_out`` is divisible by 4).
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.proj = WeightNormConv2d(ch_in, ch_out // 4, 3, 1, 1)

    def forward(self, x):
        folded = F.pixel_unshuffle(self.proj(x), 2)
        return folded.contiguous()
class ChannelAverage(nn.Module):
    """Reduce channels ``ch_in -> ch_out`` with a 3x3 conv plus an averaged residual.

    Residual path: the input's channels are viewed as ``ch_in // ch_out``
    consecutive chunks of ``ch_out`` channels and averaged element-wise, so
    output channel ``k`` is the mean of input channels
    ``k, k + ch_out, k + 2 * ch_out, ...``. That mean is multiplied by
    ``sqrt(ch_in // ch_out)`` (presumably to keep the residual's scale
    comparable to a single channel's) and added to the conv output.

    Args:
        ch_in: Number of input channels; must be a positive multiple of ``ch_out``.
        ch_out: Number of output channels.

    Raises:
        ValueError: if ``ch_in`` or ``ch_out`` is not positive, or ``ch_out``
            does not evenly divide ``ch_in``. Otherwise the residual would
            have the wrong channel count (or ``ch_in // ch_out`` would be 0)
            and ``forward`` would fail.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        if ch_in <= 0 or ch_out <= 0 or ch_in % ch_out != 0:
            raise ValueError(
                f"ChannelAverage requires ch_in to be a positive multiple of "
                f"ch_out, got ch_in={ch_in}, ch_out={ch_out}"
            )
        self.proj = WeightNormConv2d(ch_in, ch_out, 3, 1, 1)
        self.grps = ch_in // ch_out
        self.scale = self.grps ** 0.5

    def forward(self, x):
        out = self.proj(x.contiguous())  # [b, ch_out, h, w]
        b, c, h, w = x.shape
        # reshape (not view) so non-contiguous inputs are accepted as well.
        res = x.reshape(b, self.grps, c // self.grps, h, w).mean(dim=1)
        res = res * self.scale  # [b, ch_out, h, w]
        return res + out
# === Decoder ===
class SquareToLandscape(nn.Module):
    """Apply a 3x3 projection, then bicubic-resize to a fixed output size.

    Args:
        ch_in: Number of input channels.
        ch_out: Number of output channels of the 3x3 projection.
        size: Target ``(height, width)`` of the final resize. Defaults to
            ``(360, 640)``, the value this module always used.
    """

    def __init__(self, ch_in, ch_out, size=(360, 640)):
        super().__init__()
        self.proj = WeightNormConv2d(ch_in, ch_out, 3, 1, 1)
        # Plain attribute (not a buffer), so state_dict keys are unchanged.
        self.target_size = tuple(size)

    def forward(self, x):
        x = self.proj(x)
        # TODO This ordering is wrong for both
        # NOTE(review): the author's TODO above presumably refers to the
        # proj/resize order here and in LandscapeToSquare; swapping it would
        # change outputs of existing trained weights, so it is left as-is.
        return F.interpolate(x, size=self.target_size, mode='bicubic')
class Upsample(nn.Module):
    """Optionally remap channels with a bias-free 1x1 conv, then double H and W (bicubic)."""

    def __init__(self, ch_in, ch_out):
        super().__init__()
        if ch_in == ch_out:
            # Channel count already matches: no projection (and no parameters).
            self.proj = nn.Identity()
        else:
            self.proj = WeightNormConv2d(ch_in, ch_out, 1, 1, 0, bias=False)

    def forward(self, x):
        projected = self.proj(x)
        return F.interpolate(projected, scale_factor=2.0, mode='bicubic')
class UpBlock(nn.Module):
    """Upsample ``ch_in -> ch_out`` (2x spatially), then apply ``num_res`` residual blocks at ``ch_out``."""

    def __init__(self, ch_in, ch_out, num_res=1):
        super().__init__()
        self.up = Upsample(ch_in, ch_out)
        # ResBlocks run after the upsample, so they operate at ch_out.
        self.blocks = nn.ModuleList([ResBlock(ch_out) for _ in range(num_res)])

    def forward(self, x):
        h = self.up(x)
        for res_block in self.blocks:
            h = res_block(h)
        return h
class ChannelToSpace(nn.Module):
    """3x3 conv to ``4 * ch_out`` channels, then a factor-2 pixel-shuffle down to ``ch_out``.

    ``F.pixel_shuffle(..., 2)`` divides the channel count by 4 and doubles
    height and width.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.proj = WeightNormConv2d(ch_in, ch_out * 4, 3, 1, 1)

    def forward(self, x):
        unfolded = F.pixel_shuffle(self.proj(x), 2)
        return unfolded.contiguous()
class ChannelDuplication(nn.Module):
    """Expand channels ``ch_in -> ch_out`` with a 3x3 conv plus a duplicated residual.

    Residual path: each input channel is repeated ``ch_out // ch_in`` times
    consecutively (``[c0, c0, ..., c1, c1, ...]``), scaled by
    ``1 / sqrt(ch_out // ch_in)`` and added to the conv output.

    Args:
        ch_in: Number of input channels.
        ch_out: Number of output channels; must be a positive multiple of ``ch_in``.

    Raises:
        ValueError: if ``ch_in`` or ``ch_out`` is not positive, or ``ch_in``
            does not evenly divide ``ch_out``. Otherwise ``ch_out // ch_in``
            could be 0 (and ``0 ** -0.5`` raises ZeroDivisionError), or the
            residual would have the wrong channel count in ``forward``.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        if ch_in <= 0 or ch_out <= 0 or ch_out % ch_in != 0:
            raise ValueError(
                f"ChannelDuplication requires ch_out to be a positive multiple of "
                f"ch_in, got ch_in={ch_in}, ch_out={ch_out}"
            )
        self.proj = WeightNormConv2d(ch_in, ch_out, 3, 1, 1)
        self.reps = ch_out // ch_in
        self.scale = self.reps ** -0.5

    def forward(self, x):
        out = self.proj(x.contiguous())
        b, c, h, w = x.shape
        res = x.unsqueeze(2)  # [b, c, 1, h, w]
        res = res.expand(b, c, self.reps, h, w)  # [b, c, reps, h, w]
        res = res.reshape(b, c * self.reps, h, w).contiguous()
        res = res * self.scale
        return res + out
# === Main AE ===
class Encoder(nn.Module):
    """Image-to-latent encoder.

    Config attributes read: ``channels``, ``ch_0``, ``ch_max``,
    ``encoder_blocks_per_stage``, ``latent_channels`` and, optionally,
    ``skip_logvar`` (default False).

    The input is resized to 512x512 (``LandscapeToSquare``). Each stage then
    halves height and width and doubles the width (capped at ``ch_max``), as
    the sum of a ``DownBlock`` and a ``SpaceToChannel`` shortcut. Finally
    ``ChannelAverage`` reduces to ``latent_channels``.
    """

    def __init__(self, config):
        super().__init__()
        self.conv_in = LandscapeToSquare(config.channels, config.ch_0)

        stages = []
        shortcuts = []
        width = config.ch_0
        for n_res in config.encoder_blocks_per_stage:
            wider = min(width * 2, config.ch_max)
            stages.append(DownBlock(width, wider, n_res))
            shortcuts.append(SpaceToChannel(width, wider))
            width = wider
        self.blocks = nn.ModuleList(stages)
        self.residuals = nn.ModuleList(shortcuts)

        self.conv_out = ChannelAverage(width, config.latent_channels)

        self.skip_logvar = bool(getattr(config, "skip_logvar", False))
        if not self.skip_logvar:
            # Checkpoint expects a 1-channel logvar head: [1, ch, 3, 3].
            # forward() below never calls this head.
            self.conv_out_logvar = WeightNormConv2d(width, 1, 3, 1, 1)

    def forward(self, x):
        h = self.conv_in(x)
        for stage, shortcut in zip(self.blocks, self.residuals):
            h = stage(h) + shortcut(h)
        return self.conv_out(h)
class Decoder(nn.Module):
    """Latent-to-image decoder.

    Config attributes read: ``latent_channels``, ``ch_0``, ``ch_max``,
    ``decoder_blocks_per_stage`` and ``channels``.

    ``ChannelDuplication`` expands the latent to the deepest stage's width.
    Each stage then doubles height and width as the sum of an ``UpBlock`` and
    a ``ChannelToSpace`` shortcut, going from the deepest width down to
    ``ch_0``. The result passes through SiLU, and ``SquareToLandscape``
    projects to ``channels`` and resizes to 360x640.
    """

    def __init__(self, config):
        super().__init__()
        # Width the deepest UpBlock / ChannelToSpace expect as input: the final
        # value of the doubling-capped-at-ch_max loop below. Using ch_max
        # directly breaks whenever ch_0 * 2**n_stages < ch_max.
        deepest_ch = config.ch_0
        for _ in config.decoder_blocks_per_stage:
            deepest_ch = min(deepest_ch * 2, config.ch_max)
        self.conv_in = ChannelDuplication(config.latent_channels, deepest_ch)

        blocks = []
        residuals = []
        ch = config.ch_0
        for block_count in reversed(config.decoder_blocks_per_stage):
            next_ch = min(ch * 2, config.ch_max)
            blocks.append(UpBlock(next_ch, ch, block_count))
            residuals.append(ChannelToSpace(next_ch, ch))
            ch = next_ch
        # Stages were built shallow-to-deep; reverse so forward runs deep-to-shallow.
        self.blocks = nn.ModuleList(reversed(blocks))
        self.residuals = nn.ModuleList(reversed(residuals))
        self.act_out = nn.SiLU()
        self.conv_out = SquareToLandscape(config.ch_0, config.channels)

    def forward(self, x):
        x = self.conv_in(x)
        for block, residual in zip(self.blocks, self.residuals):
            x = block(x) + residual(x)
        x = self.act_out(x)
        return self.conv_out(x)
class AutoEncoder(nn.Module):
    """Pair an ``Encoder`` and a ``Decoder``; ``forward`` returns ``decoder(encoder(x))``.

    Args:
        encoder_config: Config used to build the encoder.
        decoder_config: Config used to build the decoder; when ``None`` the
            encoder config is reused.
    """

    def __init__(self, encoder_config, decoder_config=None):
        super().__init__()
        self.encoder = Encoder(encoder_config)
        self.decoder = Decoder(
            encoder_config if decoder_config is None else decoder_config
        )

    def forward(self, x):
        latent = self.encoder(x)
        return self.decoder(latent)