Commit: update nn
ashawkey committed Mar 4, 2024
1 parent 92c7212 commit e02c97c
Showing 8 changed files with 1,332 additions and 44 deletions.
1 change: 1 addition & 0 deletions kiui/__init__.py
@@ -9,6 +9,7 @@
module_path = os.path.dirname(os.path.abspath(__file__))
submodules = [os.path.splitext(m)[0] for m in os.listdir(module_path) if not m.startswith('__')]  # splitext drops the .py suffix (str.strip would remove characters, not the suffix)
submodules.append('gridencoder')
submodules.append('nn')

# find out all function names without importing the module
utils_path = os.path.join(module_path, 'utils.py')
20 changes: 19 additions & 1 deletion kiui/nn/__init__.py
@@ -2,6 +2,8 @@
import torch.nn as nn
import torch.nn.functional as F

from kiui.nn.utils import *

class MLP(nn.Module):
def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):
super().__init__()
@@ -21,4 +23,20 @@ def forward(self, x):
x = self.net[l](x)
if l != self.num_layers - 1:
x = F.relu(x, inplace=True)
        return x


class _TruncExp(torch.autograd.Function):
@staticmethod
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
def forward(ctx, x):
ctx.save_for_backward(x)
return torch.exp(x)

@staticmethod
@torch.cuda.amp.custom_bwd
def backward(ctx, g):
x = ctx.saved_tensors[0]
return g * torch.exp(torch.clamp(x, max=15))

trunc_exp = _TruncExp.apply
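
A minimal usage sketch of trunc_exp (assuming the import path above): the forward pass is a plain torch.exp, run in float32 under autocast, while the backward pass clamps the input at 15, which keeps gradients finite for large activations such as density outputs.

import torch
from kiui.nn import trunc_exp

x = torch.tensor([1.0, 20.0, 80.0], requires_grad=True)
y = trunc_exp(x)       # forward: exact exp
y.sum().backward()
print(x.grad)          # backward: exp(clamp(x, max=15)), so no overflow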
16 changes: 0 additions & 16 deletions kiui/nn/activation.py

This file was deleted.

317 changes: 317 additions & 0 deletions kiui/nn/unet_2d.py
@@ -0,0 +1,317 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
from typing import Tuple, Literal
from functools import partial

from kiui.nn.attention import MemEffAttention

class ImageAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
proj_bias: bool = True,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
groups: int = 32,
eps: float = 1e-5,
residual: bool = True,
skip_scale: float = 1,
):
super().__init__()

self.residual = residual
self.skip_scale = skip_scale

self.norm = nn.GroupNorm(num_groups=groups, num_channels=dim, eps=eps, affine=True)
self.attn = MemEffAttention(dim, num_heads, qkv_bias, proj_bias, attn_drop, proj_drop)

def forward(self, x):
# x: [B, C, H, W]
B, C, H, W = x.shape

res = x
x = self.norm(x)

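        # flatten the spatial grid into a token sequence: [B, C, H, W] -> [B, H*W, C]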
x = x.permute(0, 2, 3, 1).reshape(B, -1, C)
x = self.attn(x)
x = x.reshape(B, H, W, C).permute(0, 3, 1, 2).reshape(B, C, H, W)

if self.residual:
x = (x + res) * self.skip_scale

return x

class ResnetBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
resample: Literal['default', 'up', 'down'] = 'default',
groups: int = 32,
eps: float = 1e-5,
        skip_scale: float = 1, # multiplied onto the block output
):
super().__init__()

self.in_channels = in_channels
self.out_channels = out_channels
self.skip_scale = skip_scale

self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

self.act = F.silu

self.resample = None
if resample == 'up':
self.resample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
elif resample == 'down':
self.resample = nn.AvgPool2d(kernel_size=2, stride=2)

self.shortcut = nn.Identity()
if self.in_channels != self.out_channels:
self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True)


def forward(self, x):
res = x

x = self.norm1(x)
x = self.act(x)

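        # resample the residual and the main branch together so their shapes stay matched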
if self.resample:
res = self.resample(res)
x = self.resample(x)

x = self.conv1(x)
x = self.norm2(x)
x = self.act(x)
x = self.conv2(x)

x = (x + self.shortcut(res)) * self.skip_scale

return x

class DownBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 1,
downsample: bool = True,
attention: bool = True,
attention_heads: int = 16,
skip_scale: float = 1,
):
super().__init__()

nets = []
attns = []
for i in range(num_layers):
cin = in_channels if i == 0 else out_channels
nets.append(ResnetBlock(cin, out_channels, skip_scale=skip_scale))
if attention:
attns.append(ImageAttention(out_channels, attention_heads, skip_scale=skip_scale))
else:
attns.append(None)
self.nets = nn.ModuleList(nets)
self.attns = nn.ModuleList(attns)

self.downsample = None
if downsample:
self.downsample = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)

def forward(self, x):
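        # xs gathers each layer's output (plus the downsampled map) as encoder skip features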
xs = []

for attn, net in zip(self.attns, self.nets):
x = net(x)
if attn:
x = attn(x)
xs.append(x)

if self.downsample:
x = self.downsample(x)
xs.append(x)

return x, xs


class MidBlock(nn.Module):
def __init__(
self,
in_channels: int,
num_layers: int = 1,
attention: bool = True,
attention_heads: int = 16,
skip_scale: float = 1,
):
super().__init__()

nets = []
attns = []
# first layer
nets.append(ResnetBlock(in_channels, in_channels, skip_scale=skip_scale))
# more layers
for i in range(num_layers):
nets.append(ResnetBlock(in_channels, in_channels, skip_scale=skip_scale))
if attention:
attns.append(ImageAttention(in_channels, attention_heads, skip_scale=skip_scale))
else:
attns.append(None)
self.nets = nn.ModuleList(nets)
self.attns = nn.ModuleList(attns)

def forward(self, x):
x = self.nets[0](x)
for attn, net in zip(self.attns, self.nets[1:]):
if attn:
x = attn(x)
x = net(x)
return x


class UpBlock(nn.Module):
def __init__(
self,
in_channels: int,
prev_out_channels: int,
out_channels: int,
num_layers: int = 1,
upsample: bool = True,
attention: bool = True,
attention_heads: int = 16,
skip_scale: float = 1,
):
super().__init__()

nets = []
attns = []
for i in range(num_layers):
cin = in_channels if i == 0 else out_channels
cskip = prev_out_channels if (i == num_layers - 1) else out_channels

nets.append(ResnetBlock(cin + cskip, out_channels, skip_scale=skip_scale))
if attention:
attns.append(ImageAttention(out_channels, attention_heads, skip_scale=skip_scale))
else:
attns.append(None)
self.nets = nn.ModuleList(nets)
self.attns = nn.ModuleList(attns)

self.upsample = None
if upsample:
self.upsample = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

def forward(self, x, xs):

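        # consume skip features from the end of xs: deepest encoder features first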
for attn, net in zip(self.attns, self.nets):
res_x = xs[-1]
xs = xs[:-1]
x = torch.cat([x, res_x], dim=1)
x = net(x)
if attn:
x = attn(x)

if self.upsample:
x = F.interpolate(x, scale_factor=2.0, mode='nearest')
x = self.upsample(x)

return x


# the down and up paths may have different depths (asymmetric U-Net)
class UNet(nn.Module):
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
down_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024),
down_attention: Tuple[bool, ...] = (False, False, False, True, True),
mid_attention: bool = True,
up_channels: Tuple[int, ...] = (1024, 512, 256),
up_attention: Tuple[bool, ...] = (True, True, False),
layers_per_block: int = 2,
skip_scale: float = np.sqrt(0.5),
):
super().__init__()

# first
self.conv_in = nn.Conv2d(in_channels, down_channels[0], kernel_size=3, stride=1, padding=1)

# down
down_blocks = []
cout = down_channels[0]
for i in range(len(down_channels)):
cin = cout
cout = down_channels[i]

down_blocks.append(DownBlock(
cin, cout,
num_layers=layers_per_block,
downsample=(i != len(down_channels) - 1), # not final layer
attention=down_attention[i],
skip_scale=skip_scale,
))
self.down_blocks = nn.ModuleList(down_blocks)

# mid
self.mid_block = MidBlock(down_channels[-1], attention=mid_attention, skip_scale=skip_scale)

# up
up_blocks = []
cout = up_channels[0]
for i in range(len(up_channels)):
cin = cout
cout = up_channels[i]
            cskip = down_channels[max(-2 - i, -len(down_channels))] # for the asymmetric case

up_blocks.append(UpBlock(
cin, cskip, cout,
num_layers=layers_per_block + 1, # one more layer for up
upsample=(i != len(up_channels) - 1), # not final layer
attention=up_attention[i],
skip_scale=skip_scale,
))
self.up_blocks = nn.ModuleList(up_blocks)

# last
self.norm_out = nn.GroupNorm(num_channels=up_channels[-1], num_groups=32, eps=1e-5)
self.conv_out = nn.Conv2d(up_channels[-1], out_channels, kernel_size=3, stride=1, padding=1)


def forward(self, x):
# x: [B, Cin, H, W]

# first
x = self.conv_in(x)

# down
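        # keep every intermediate feature; the up path pops them back in reverse order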
xss = [x]
for block in self.down_blocks:
x, xs = block(x)
xss.extend(xs)

# mid
x = self.mid_block(x)

# up
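        # each up block consumes len(block.nets) skip features from the tail of xss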
for block in self.up_blocks:
xs = xss[-len(block.nets):]
xss = xss[:-len(block.nets)]
x = block(x, xs)

# last
x = self.norm_out(x)
x = F.silu(x)
x = self.conv_out(x) # [B, Cout, H', W']

return x
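
A minimal shape-check sketch, assuming the default configuration above: the encoder downsamples four times while the decoder upsamples only twice, so the output comes back at 1/4 of the input resolution (input sides should be divisible by 16 so the skip concatenations line up).

import torch
from kiui.nn.unet_2d import UNet

net = UNet()                     # defaults: 5 down stages, 3 up stages
x = torch.randn(1, 3, 256, 256)
y = net(x)
print(y.shape)                   # torch.Size([1, 3, 64, 64])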