Update hubconf.py
thias15 committed Jan 2, 2023
1 parent ed20cc6 commit 1645b7e
Showing 1 changed file with 147 additions and 4 deletions.
151 changes: 147 additions & 4 deletions hubconf.py
@@ -6,9 +6,9 @@
from midas.midas_net import MidasNet
from midas.midas_net_custom import MidasNet_small

-def DPT_BEit_L_512(pretrained=True, **kwargs):
+def DPT_BEiT_L_512(pretrained=True, **kwargs):
    """ # This docstring shows up in hub.help()
-    MiDaS DPT_BEit_L_512 model for monocular depth estimation
+    MiDaS DPT_BEiT_L_512 model for monocular depth estimation
    pretrained (bool): load pretrained weights into model
    """

@@ -29,9 +29,9 @@ def DPT_BEit_L_512(pretrained=True, **kwargs):

    return model

-def DPT_BEit_L_384(pretrained=True, **kwargs):
+def DPT_BEiT_L_384(pretrained=True, **kwargs):
    """ # This docstring shows up in hub.help()
-    MiDaS DPT_BEit_L_384 model for monocular depth estimation
+    MiDaS DPT_BEiT_L_384 model for monocular depth estimation
    pretrained (bool): load pretrained weights into model
    """

@@ -52,6 +52,29 @@ def DPT_BEit_L_384(pretrained=True, **kwargs):

    return model
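
Since torch.hub entry points are addressed by function name, the BEit → BEiT rename also changes the string callers pass to torch.hub.load. A minimal usage sketch, assuming this hubconf.py is consumed through torch.hub from the isl-org/MiDaS repository:

import torch

# The entry point is now spelled "DPT_BEiT_L_512" (capital T); the old
# "DPT_BEit_L_512" name no longer exists after this commit.
model = torch.hub.load("isl-org/MiDaS", "DPT_BEiT_L_512", pretrained=True)
model.eval()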

def DPT_BEiT_B_384(pretrained=True, **kwargs):
    """ # This docstring shows up in hub.help()
    MiDaS DPT_BEiT_B_384 model for monocular depth estimation
    pretrained (bool): load pretrained weights into model
    """

    model = DPTDepthModel(
        path=None,
        backbone="beitb16_384",
        non_negative=True,
    )

    if pretrained:
        checkpoint = (
            "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_base_384.pt"
        )
        state_dict = torch.hub.load_state_dict_from_url(
            checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
        )
        model.load_state_dict(state_dict)

    return model
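
As a quick check of the new entry point, a minimal sketch that downloads the released dpt_beit_base_384.pt weights and runs a forward pass on a random 384×384 input (shape check only; real images should go through a matching MiDaS transform):

import torch

model = torch.hub.load("isl-org/MiDaS", "DPT_BEiT_B_384", pretrained=True)
model.eval()

# Dummy RGB batch; MiDaS DPT models return inverse relative depth of shape (batch, H, W).
x = torch.randn(1, 3, 384, 384)
with torch.no_grad():
    depth = model(x)
print(depth.shape)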

def DPT_SwinV2_L_384(pretrained=True, **kwargs):
""" # This docstring shows up in hub.help()
MiDaS DPT_SwinV2_L_384 model for monocular depth estimation
@@ -75,6 +98,29 @@ def DPT_SwinV2_L_384(pretrained=True, **kwargs):

    return model

def DPT_SwinV2_B_384(pretrained=True, **kwargs):
    """ # This docstring shows up in hub.help()
    MiDaS DPT_SwinV2_B_384 model for monocular depth estimation
    pretrained (bool): load pretrained weights into model
    """

    model = DPTDepthModel(
        path=None,
        backbone="swin2b24_384",
        non_negative=True,
    )

    if pretrained:
        checkpoint = (
            "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_base_384.pt"
        )
        state_dict = torch.hub.load_state_dict_from_url(
            checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
        )
        model.load_state_dict(state_dict)

    return model
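
If the release checkpoint has already been downloaded manually, the same entry point can be called with pretrained=False and the weights loaded from disk — a sketch assuming a local copy saved as dpt_swin2_base_384.pt (hypothetical path):

import torch

model = torch.hub.load("isl-org/MiDaS", "DPT_SwinV2_B_384", pretrained=False)
state_dict = torch.load("dpt_swin2_base_384.pt", map_location="cpu")  # hypothetical local path
model.load_state_dict(state_dict)
model.eval()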

def DPT_SwinV2_T_256(pretrained=True, **kwargs):
""" # This docstring shows up in hub.help()
MiDaS DPT_SwinV2_T_256 model for monocular depth estimation
@@ -98,6 +144,29 @@ def DPT_SwinV2_T_256(pretrained=True, **kwargs):

    return model

def DPT_Swin_L_384(pretrained=True, **kwargs):
    """ # This docstring shows up in hub.help()
    MiDaS DPT_Swin_L_384 model for monocular depth estimation
    pretrained (bool): load pretrained weights into model
    """

    model = DPTDepthModel(
        path=None,
        backbone="swinl12_384",
        non_negative=True,
    )

    if pretrained:
        checkpoint = (
            "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin_large_384.pt"
        )
        state_dict = torch.hub.load_state_dict_from_url(
            checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
        )
        model.load_state_dict(state_dict)

    return model
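
Because these docstrings are surfaced by torch.hub, the new entry points can be listed and inspected without reading the source — a sketch assuming network access to the repository:

import torch

print(torch.hub.list("isl-org/MiDaS"))                    # includes DPT_Swin_L_384, DPT_BEiT_B_384, ...
print(torch.hub.help("isl-org/MiDaS", "DPT_Swin_L_384"))  # prints the docstring defined above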

def DPT_Next_ViT_L_384(pretrained=True, **kwargs):
""" # This docstring shows up in hub.help()
MiDaS DPT_Next_ViT_L_384 model for monocular depth estimation
@@ -131,6 +200,8 @@ def DPT_LeViT_224(pretrained=True, **kwargs):
        path=None,
        backbone="levit_384",
        non_negative=True,
        head_features_1=64,
        head_features_2=8,
    )

    if pretrained:
@@ -289,4 +360,76 @@ def transforms():
        ]
    )

    transforms.beit512_transform = Compose(
        [
            lambda img: {"image": img / 255.0},
            Resize(
                512,
                512,
                resize_target=None,
                keep_aspect_ratio=True,
                ensure_multiple_of=32,
                resize_method="minimal",
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            PrepareForNet(),
            lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0),
        ]
    )

    transforms.swin384_transform = Compose(
        [
            lambda img: {"image": img / 255.0},
            Resize(
                384,
                384,
                resize_target=None,
                keep_aspect_ratio=False,
                ensure_multiple_of=32,
                resize_method="minimal",
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            PrepareForNet(),
            lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0),
        ]
    )

    transforms.swin256_transform = Compose(
        [
            lambda img: {"image": img / 255.0},
            Resize(
                256,
                256,
                resize_target=None,
                keep_aspect_ratio=False,
                ensure_multiple_of=32,
                resize_method="minimal",
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            PrepareForNet(),
            lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0),
        ]
    )

    transforms.levit_transform = Compose(
        [
            lambda img: {"image": img / 255.0},
            Resize(
                224,
                224,
                resize_target=None,
                keep_aspect_ratio=False,
                ensure_multiple_of=32,
                resize_method="minimal",
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            PrepareForNet(),
            lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0),
        ]
    )

    return transforms
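
To tie the new transforms to a model, the following sketch mirrors the usage pattern from the MiDaS README: it assumes an OpenCV-readable image at a hypothetical path input.jpg and pairs swin384_transform with one of the 384-pixel Swin models added above.

import cv2
import torch

model = torch.hub.load("isl-org/MiDaS", "DPT_SwinV2_B_384", pretrained=True)
model.eval()
transforms = torch.hub.load("isl-org/MiDaS", "transforms")

img = cv2.imread("input.jpg")               # hypothetical input path
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # MiDaS transforms expect RGB

input_batch = transforms.swin384_transform(img)  # resized, normalized, batched tensor
with torch.no_grad():
    prediction = model(input_batch)
    # Resize the (1, H', W') prediction back to the original image resolution.
    prediction = torch.nn.functional.interpolate(
        prediction.unsqueeze(1),
        size=img.shape[:2],
        mode="bicubic",
        align_corners=False,
    ).squeeze()

depth = prediction.cpu().numpy()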
