import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.parametrizations import weight_norm
import math
# === General Blocks ===
[docs]
def WeightNormConv2d(*args, **kwargs):
return weight_norm(nn.Conv2d(*args, **kwargs))
class ResBlock(nn.Module):
    """Residual bottleneck block: 1x1 expand -> grouped 3x3 -> 1x1 project, plus identity.

    Args:
        ch: channel count of both the input and the output. The hidden width
            inside the block is ``2 * ch``.
    """

    def __init__(self, ch):
        super().__init__()
        width = ch * 2
        # 16 channels per group (matches checkpoint shapes like [128,16,3,3] when ch=64)
        groups = max(1, width // 16)
        self.conv1 = WeightNormConv2d(ch, width, 1, 1, 0)
        self.conv2 = WeightNormConv2d(width, width, 3, 1, 1, groups=groups)
        # The last projection has no bias; its output is added onto the input.
        self.conv3 = WeightNormConv2d(width, ch, 1, 1, 0, bias=False)
        self.act1 = nn.LeakyReLU(inplace=False)
        self.act2 = nn.LeakyReLU(inplace=False)

    def forward(self, x):
        branch = self.act1(self.conv1(x))
        branch = self.act2(self.conv2(branch))
        return x + self.conv3(branch)
# === Encoder ===
class LandscapeToSquare(nn.Module):
    """Bicubic-resize a frame to a fixed grid, then apply a 3x3 weight-normed conv.

    The resize runs before the projection, so ``proj`` always operates at
    ``size`` resolution whatever the input's spatial shape is.

    Original note: "Strict assumption of 360p". Presumably the pipeline feeds
    360p frames here; the resize itself accepts any input size.

    Args:
        ch_in: channels of the incoming image tensor.
        ch_out: channels produced by the projection.
        size: target ``(height, width)`` of the resize. Defaults to
            ``(512, 512)``, the value this block previously hard-coded.
    """

    def __init__(self, ch_in, ch_out, size=(512, 512)):
        super().__init__()
        self.size = tuple(size)
        self.proj = WeightNormConv2d(ch_in, ch_out, 3, 1, 1)

    def forward(self, x):
        # A 2-element size means x must carry two spatial dims, i.e. (N, C, H, W).
        x = F.interpolate(x, self.size, mode='bicubic')
        return self.proj(x)
class Downsample(nn.Module):
    """Halve height and width with bicubic interpolation, then remap channels.

    Args:
        ch_in: input channels.
        ch_out: output channels of the bias-free 1x1 projection.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.proj = WeightNormConv2d(ch_in, ch_out, 1, 1, 0, bias=False)

    def forward(self, x):
        # Resizing first means the 1x1 conv runs on the smaller grid.
        halved = F.interpolate(x, scale_factor=0.5, mode='bicubic')
        return self.proj(halved)
class DownBlock(nn.Module):
    """Apply ``num_res`` ResBlocks at the input width, then downsample.

    Args:
        ch_in: channels of the input (and of every ResBlock).
        ch_out: channels after the final ``Downsample``.
        num_res: how many ResBlocks to stack before downsampling.
    """

    def __init__(self, ch_in, ch_out, num_res=1):
        super().__init__()
        # Constructed before the ResBlocks, as in the original ordering.
        self.down = Downsample(ch_in, ch_out)
        self.blocks = nn.ModuleList([ResBlock(ch_in) for _ in range(num_res)])

    def forward(self, x):
        h = x
        for res_block in self.blocks:
            h = res_block(h)
        return self.down(h)
class SpaceToChannel(nn.Module):
    """Shortcut for a downsampling stage: conv, then fold 2x2 patches into channels.

    The conv produces ``ch_out // 4`` channels and ``pixel_unshuffle(…, 2)``
    multiplies that by 4 while halving height and width. The result has
    ``ch_out`` channels when ``ch_out`` is divisible by 4.

    Args:
        ch_in: input channels.
        ch_out: target channel count after unshuffling.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        quarter = ch_out // 4
        self.proj = WeightNormConv2d(ch_in, quarter, 3, 1, 1)

    def forward(self, x):
        folded = F.pixel_unshuffle(self.proj(x), 2)
        return folded.contiguous()
class ChannelAverage(nn.Module):
    """Reduce ``ch_in`` channels to ``ch_out``: a 3x3 conv plus an averaging shortcut.

    The shortcut splits the channel axis into ``ch_in // ch_out`` consecutive
    chunks of ``ch_out`` channels and averages them element-wise. It then
    multiplies by ``sqrt(ch_in // ch_out)``. Output channel ``j`` of the
    shortcut therefore mixes input channels ``j, j + ch_out, j + 2 * ch_out, ...``.

    Args:
        ch_in: input channels; must be a positive multiple of ``ch_out``.
        ch_out: output channels.

    Raises:
        ValueError: if ``ch_in`` is not a positive multiple of ``ch_out``.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        # Any other ratio makes forward() fail later with a ZeroDivisionError or
        # a shape-mismatch error, so reject it here with a clear message.
        if ch_in <= 0 or ch_out <= 0 or ch_in % ch_out != 0:
            raise ValueError(
                "ChannelAverage needs ch_in to be a positive multiple of ch_out, "
                f"got ch_in={ch_in}, ch_out={ch_out}"
            )
        self.proj = WeightNormConv2d(ch_in, ch_out, 3, 1, 1)
        self.grps = ch_in // ch_out
        # Averaging grps uncorrelated channels divides their variance by grps;
        # the sqrt(grps) factor restores the original scale.
        self.scale = (self.grps) ** 0.5

    def forward(self, x):
        out = self.proj(x.contiguous())  # [b, ch_out, h, w]
        b, c, h, w = x.shape
        # Residual goes through channel avg. Splitting one dim is always a
        # valid view, and mean() allocates a fresh tensor, so no .contiguous()
        # is needed.
        res = x.view(b, self.grps, c // self.grps, h, w)
        res = res.mean(dim=1) * self.scale  # [b, ch_out, h, w]
        return res + out
# === Decoder ===
class SquareToLandscape(nn.Module):
    """Apply a 3x3 weight-normed conv, then bicubic-resize to a fixed frame size.

    Args:
        ch_in: input channels.
        ch_out: channels produced by the projection (e.g. image channels).
        size: target ``(height, width)`` of the resize. Defaults to
            ``(360, 640)``, the value this block previously hard-coded.
    """

    def __init__(self, ch_in, ch_out, size=(360, 640)):
        super().__init__()
        self.size = tuple(size)
        self.proj = WeightNormConv2d(ch_in, ch_out, 3, 1, 1)

    def forward(self, x):
        # TODO This ordering is wrong for both
        # NOTE(review): presumably refers to conv-then-resize here versus
        # resize-then-conv in LandscapeToSquare. Left unchanged, since reordering
        # would change the outputs of any already-trained weights.
        x = self.proj(x)
        return F.interpolate(x, self.size, mode='bicubic')
class Upsample(nn.Module):
    """Remap channels, then double height and width with bicubic interpolation.

    The channel remap is a bias-free 1x1 weight-normed conv. It is replaced by
    ``nn.Identity`` when ``ch_in == ch_out``.

    Args:
        ch_in: input channels.
        ch_out: output channels.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        if ch_in != ch_out:
            self.proj = WeightNormConv2d(ch_in, ch_out, 1, 1, 0, bias=False)
        else:
            self.proj = nn.Identity()

    def forward(self, x):
        # Projecting first keeps the 1x1 conv on the smaller grid.
        projected = self.proj(x)
        return F.interpolate(projected, scale_factor=2.0, mode='bicubic')
class UpBlock(nn.Module):
    """Upsample first, then apply ``num_res`` ResBlocks at the output width.

    Args:
        ch_in: input channels.
        ch_out: channels after upsampling (and of every ResBlock).
        num_res: how many ResBlocks to stack after upsampling.
    """

    def __init__(self, ch_in, ch_out, num_res=1):
        super().__init__()
        # Constructed before the ResBlocks, as in the original ordering.
        self.up = Upsample(ch_in, ch_out)
        self.blocks = nn.ModuleList([ResBlock(ch_out) for _ in range(num_res)])

    def forward(self, x):
        h = self.up(x)
        for res_block in self.blocks:
            h = res_block(h)
        return h
class ChannelToSpace(nn.Module):
    """Shortcut for an upsampling stage: conv, then unfold channels into 2x2 patches.

    The conv produces ``4 * ch_out`` channels; ``pixel_shuffle(…, 2)`` turns
    them into ``ch_out`` channels at twice the height and width.

    Args:
        ch_in: input channels.
        ch_out: channels after shuffling.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.proj = WeightNormConv2d(ch_in, 4 * ch_out, 3, 1, 1)

    def forward(self, x):
        return F.pixel_shuffle(self.proj(x), 2).contiguous()
class ChannelDuplication(nn.Module):
    """Expand ``ch_in`` channels to ``ch_out``: a 3x3 conv plus a duplicating shortcut.

    The shortcut repeats every input channel ``reps = ch_out // ch_in`` times
    in a row. Input channel ``i`` fills outputs ``i * reps`` through
    ``i * reps + reps - 1``. The result is scaled by ``reps ** -0.5``, so the
    shortcut's per-pixel sum of squares over channels equals the input's.

    Args:
        ch_in: input channels.
        ch_out: output channels; must be a positive multiple of ``ch_in``.

    Raises:
        ValueError: if ``ch_out`` is not a positive multiple of ``ch_in``.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        # Any other ratio either divides by zero below (reps == 0) or makes
        # forward() fail with a shape mismatch, so reject it with a clear message.
        if ch_in <= 0 or ch_out <= 0 or ch_out % ch_in != 0:
            raise ValueError(
                "ChannelDuplication needs ch_out to be a positive multiple of ch_in, "
                f"got ch_in={ch_in}, ch_out={ch_out}"
            )
        self.proj = WeightNormConv2d(ch_in, ch_out, 3, 1, 1)
        self.reps = ch_out // ch_in
        self.scale = (self.reps) ** -0.5

    def forward(self, x):
        out = self.proj(x.contiguous())
        b, c, h, w = x.shape
        res = x.unsqueeze(2).expand(b, c, self.reps, h, w)  # [b, c, reps, h, w]
        # reshape() copies the expanded view into a contiguous tensor.
        res = res.reshape(b, c * self.reps, h, w) * self.scale
        return res + out
# === Main AE ===
class Encoder(nn.Module):
    """Map images to latents through a stack of downsampling stages.

    Each stage is ``DownBlock(x) + SpaceToChannel(x)``. Both branches halve
    the spatial size and move from ``width`` to ``min(2 * width, ch_max)``
    channels.

    Config attributes read: ``channels``, ``ch_0``, ``ch_max``,
    ``encoder_blocks_per_stage`` (ResBlock count per stage, first entry =
    first stage), ``latent_channels`` and the optional ``skip_logvar``.
    """

    def __init__(self, config):
        super().__init__()
        self.conv_in = LandscapeToSquare(config.channels, config.ch_0)
        width = config.ch_0
        stages, shortcuts = [], []
        for n_res in config.encoder_blocks_per_stage:
            wider = min(width * 2, config.ch_max)
            stages.append(DownBlock(width, wider, n_res))
            shortcuts.append(SpaceToChannel(width, wider))
            width = wider
        self.blocks = nn.ModuleList(stages)
        self.residuals = nn.ModuleList(shortcuts)
        self.conv_out = ChannelAverage(width, config.latent_channels)
        self.skip_logvar = bool(getattr(config, "skip_logvar", False))
        if not self.skip_logvar:
            # Checkpoint expects a 1-channel logvar head: [1, ch, 3, 3].
            # forward() never calls it; presumably it exists only so that such
            # checkpoints load.
            self.conv_out_logvar = WeightNormConv2d(width, 1, 3, 1, 1)

    def forward(self, x):
        h = self.conv_in(x)
        for down, shortcut in zip(self.blocks, self.residuals):
            h = down(h) + shortcut(h)
        return self.conv_out(h)
class Decoder(nn.Module):
    """Map latents back to images through a stack of upsampling stages.

    Each stage is ``UpBlock(x) + ChannelToSpace(x)``. Both branches double the
    spatial size. The channel schedule mirrors the encoder's: starting from
    ``ch_0`` at the output side, each stage deeper doubles the width, capped
    at ``ch_max``. ``decoder_blocks_per_stage`` is read deepest-first: its
    first entry is the ResBlock count of the stage that runs first, at the
    lowest resolution.

    Config attributes read: ``latent_channels``, ``ch_0``, ``ch_max``,
    ``decoder_blocks_per_stage`` and ``channels``.
    """

    def __init__(self, config):
        super().__init__()
        # Stage counts ordered shallowest (output side) first.
        stage_counts = list(reversed(config.decoder_blocks_per_stage))
        # widths[i] is the output width of shallowest-first stage i and
        # widths[i + 1] its input width.
        widths = [config.ch_0]
        for _ in stage_counts:
            widths.append(min(widths[-1] * 2, config.ch_max))
        # The first stage to run (the deepest) expects widths[-1] channels.
        # Using config.ch_max here broke every config whose ch_0 doubled per
        # stage never reaches ch_max; for all other configs the value is the same.
        # conv_in is still built first so module order and init RNG use are unchanged.
        self.conv_in = ChannelDuplication(config.latent_channels, widths[-1])
        blocks = []
        residuals = []
        for i, block_count in enumerate(stage_counts):
            blocks.append(UpBlock(widths[i + 1], widths[i], block_count))
            residuals.append(ChannelToSpace(widths[i + 1], widths[i]))
        # Run deepest stage first.
        self.blocks = nn.ModuleList(reversed(blocks))
        self.residuals = nn.ModuleList(reversed(residuals))
        self.act_out = nn.SiLU()
        self.conv_out = SquareToLandscape(config.ch_0, config.channels)

    def forward(self, x):
        x = self.conv_in(x)
        for block, residual in zip(self.blocks, self.residuals):
            x = block(x) + residual(x)
        x = self.act_out(x)
        return self.conv_out(x)
class AutoEncoder(nn.Module):
    """An ``Encoder`` followed by a ``Decoder``.

    Args:
        encoder_config: config used to build the encoder.
        decoder_config: config used to build the decoder. When ``None``,
            ``encoder_config`` is reused.
    """

    def __init__(self, encoder_config, decoder_config=None):
        super().__init__()
        self.encoder = Encoder(encoder_config)
        self.decoder = Decoder(encoder_config if decoder_config is None else decoder_config)

    def forward(self, x):
        latent = self.encoder(x)
        return self.decoder(latent)