-
Notifications
You must be signed in to change notification settings - Fork 59
/
hubconf.py
64 lines (47 loc) · 2.89 KB
/
hubconf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os.path
from typing import Literal
from transformers import CLIPVisionModelWithProjection, CLIPTextModel, AutoConfig
from src.models.ConvNet_TPS import ConvNet_TPS
from src.models.UNet import UNetVanilla
from src.models.emasc import EMASC
# Package names torch.hub expects to be importable before it runs any entry point defined in this hubconf.
dependencies = ['torch', 'diffusers', 'transformers']
import torch
from diffusers import UNet2DConditionModel
from src.models.inversion_adapter import InversionAdapter
def inversion_adapter(dataset: Literal['dresscode', 'vitonhd']):
    """Build an ``InversionAdapter`` and load its released weights for ``dataset``.

    Args:
        dataset: Name of the checkpoint variant to load, ``'dresscode'`` or
            ``'vitonhd'``. It is interpolated into the release download URL.

    Returns:
        The ``InversionAdapter`` instance with the checkpoint's state dict
        loaded (tensors are mapped to CPU while loading).

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    # Fail fast: without this check a typo only surfaces as an HTTP error
    # after the remote configs have been fetched and the adapter built.
    if dataset not in ('dresscode', 'vitonhd'):
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected 'dresscode' or 'vitonhd'."
        )

    config = AutoConfig.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
    # NOTE(review): UNet2DConditionModel.load_config is used here only to read
    # the text encoder sub-folder's config as a mapping (it is indexed with
    # ['hidden_size'] below) — confirm this is intentional rather than a
    # leftover; no UNet is created in this function.
    text_encoder_config = UNet2DConditionModel.load_config(
        "stabilityai/stable-diffusion-2-inpainting", subfolder="text_encoder"
    )

    vision_config = config.vision_config
    adapter = InversionAdapter(
        input_dim=vision_config.hidden_size,
        hidden_dim=vision_config.hidden_size * 4,
        output_dim=text_encoder_config['hidden_size'] * 16,
        num_encoder_layers=1,
        config=vision_config,
    )

    checkpoint_url = f"https://github.com/miccunifi/ladi-vton/releases/download/weights/inversion_adapter_{dataset}.pth"
    state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location='cpu')
    adapter.load_state_dict(state_dict)
    return adapter
def extended_unet(dataset: Literal['dresscode', 'vitonhd']):
    """Build the UNet with an overridden input channel count and load its weights.

    The UNet config is read from the ``unet`` sub-folder of
    ``stabilityai/stable-diffusion-2-inpainting``, its ``in_channels`` entry is
    set to 31, and the released checkpoint for ``dataset`` is loaded into the
    resulting model.

    Args:
        dataset: Name of the checkpoint variant to load, ``'dresscode'`` or
            ``'vitonhd'``. It is interpolated into the release download URL.

    Returns:
        The ``UNet2DConditionModel`` with the checkpoint's state dict loaded
        (tensors are mapped to CPU while loading).

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    # Fail fast: otherwise the whole UNet is instantiated before the download
    # of a non-existent checkpoint fails.
    if dataset not in ('dresscode', 'vitonhd'):
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected 'dresscode' or 'vitonhd'."
        )

    config = UNet2DConditionModel.load_config(
        "stabilityai/stable-diffusion-2-inpainting", subfolder="unet"
    )
    # The released weights are loaded into a model built with this channel
    # count, so it has to be overridden before the model is created.
    config['in_channels'] = 31
    unet = UNet2DConditionModel.from_config(config)

    checkpoint_url = f"https://github.com/miccunifi/ladi-vton/releases/download/weights/unet_{dataset}.pth"
    state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location='cpu')
    unet.load_state_dict(state_dict)
    return unet
def emasc(dataset: Literal['dresscode', 'vitonhd']):
    """Build the ``EMASC`` module and load its released weights for ``dataset``.

    Args:
        dataset: Name of the checkpoint variant to load, ``'dresscode'`` or
            ``'vitonhd'``. It is interpolated into the release download URL.

    Returns:
        The ``EMASC`` instance with the checkpoint's state dict loaded
        (tensors are mapped to CPU while loading).

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    # Fail fast: without this check a typo only surfaces as an HTTP error
    # once the module has already been built.
    if dataset not in ('dresscode', 'vitonhd'):
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected 'dresscode' or 'vitonhd'."
        )

    # Per-level channel counts passed to EMASC; the two lists are parallel.
    in_feature_channels = [128, 128, 128, 256, 512]
    out_feature_channels = [128, 256, 512, 512, 512]

    module = EMASC(
        in_feature_channels,
        out_feature_channels,
        kernel_size=3,
        padding=1,
        stride=1,
        type='nonlinear',
    )

    checkpoint_url = f"https://github.com/miccunifi/ladi-vton/releases/download/weights/emasc_{dataset}.pth"
    state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location='cpu')
    module.load_state_dict(state_dict)
    return module
def warping_module(dataset: Literal['dresscode', 'vitonhd']):
    """Build the TPS and refinement networks and load their released weights.

    Both networks are stored in a single checkpoint file, under the keys
    ``'tps'`` and ``'refinement'``.

    Args:
        dataset: Name of the checkpoint variant to load, ``'dresscode'`` or
            ``'vitonhd'``. It is interpolated into the release download URL.

    Returns:
        A ``(tps, refinement)`` tuple: the ``ConvNet_TPS`` and ``UNetVanilla``
        instances with their state dicts loaded (tensors are mapped to CPU
        while loading).

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    # Fail fast: without this check a typo only surfaces as an HTTP error
    # once both networks have already been built.
    if dataset not in ('dresscode', 'vitonhd'):
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected 'dresscode' or 'vitonhd'."
        )

    # NOTE(review): positional arguments are passed through unchanged; their
    # meaning is defined by ConvNet_TPS — check its signature.
    tps = ConvNet_TPS(256, 192, 21, 3)
    refinement = UNetVanilla(n_channels=24, n_classes=3, bilinear=True)

    checkpoint_url = f"https://github.com/miccunifi/ladi-vton/releases/download/weights/warping_{dataset}.pth"
    # Load the checkpoint once and index it for both sub-modules instead of
    # fetching and deserialising the same file twice.
    checkpoint = torch.hub.load_state_dict_from_url(checkpoint_url, map_location='cpu')
    tps.load_state_dict(checkpoint['tps'])
    refinement.load_state_dict(checkpoint['refinement'])
    return tps, refinement