Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix relative path in hubconf.py to get cfg_file. #92

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions hubconf.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
dependencies = ['torch', 'torchvision']

import torch
import os
try:
from mmcv.utils import Config, DictAction
except:
Expand Down Expand Up @@ -41,14 +42,15 @@ def metric3d_convnext_large(pretrain=False, **kwargs):
Returns:
model (nn.Module): a Metric3D model.
'''
cfg_file = MODEL_TYPE['ConvNeXt-Large']['cfg_file']
dirname = os.path.dirname(__file__)
cfg_file = os.path.join(dirname, MODEL_TYPE['ConvNeXt-Large']['cfg_file'])
ckpt_file = MODEL_TYPE['ConvNeXt-Large']['ckpt_file']

cfg = Config.fromfile(cfg_file)
model = get_configured_monodepth_model(cfg)
if pretrain:
model.load_state_dict(
torch.hub.load_state_dict_from_url(ckpt_file)['model_state_dict'],
torch.hub.load_state_dict_from_url(ckpt_file)['model_state_dict'],
strict=False,
)
return model
Expand All @@ -62,14 +64,15 @@ def metric3d_vit_small(pretrain=False, **kwargs):
Returns:
model (nn.Module): a Metric3D model.
'''
cfg_file = MODEL_TYPE['ViT-Small']['cfg_file']
dirname = os.path.dirname(__file__)
cfg_file = os.path.join(dirname, MODEL_TYPE['ViT-Small']['cfg_file'])
ckpt_file = MODEL_TYPE['ViT-Small']['ckpt_file']

cfg = Config.fromfile(cfg_file)
model = get_configured_monodepth_model(cfg)
if pretrain:
model.load_state_dict(
torch.hub.load_state_dict_from_url(ckpt_file)['model_state_dict'],
torch.hub.load_state_dict_from_url(ckpt_file)['model_state_dict'],
strict=False,
)
return model
Expand All @@ -83,14 +86,15 @@ def metric3d_vit_large(pretrain=False, **kwargs):
Returns:
model (nn.Module): a Metric3D model.
'''
cfg_file = MODEL_TYPE['ViT-Large']['cfg_file']
dirname = os.path.dirname(__file__)
cfg_file = os.path.join(dirname, MODEL_TYPE['ViT-Large']['cfg_file'])
ckpt_file = MODEL_TYPE['ViT-Large']['ckpt_file']

cfg = Config.fromfile(cfg_file)
model = get_configured_monodepth_model(cfg)
if pretrain:
model.load_state_dict(
torch.hub.load_state_dict_from_url(ckpt_file)['model_state_dict'],
torch.hub.load_state_dict_from_url(ckpt_file)['model_state_dict'],
strict=False,
)
return model
Expand All @@ -104,14 +108,15 @@ def metric3d_vit_giant2(pretrain=False, **kwargs):
Returns:
model (nn.Module): a Metric3D model.
'''
cfg_file = MODEL_TYPE['ViT-giant2']['cfg_file']
dirname = os.path.dirname(__file__)
cfg_file = os.path.join(dirname, MODEL_TYPE['ViT-giant2']['cfg_file'])
ckpt_file = MODEL_TYPE['ViT-giant2']['ckpt_file']

cfg = Config.fromfile(cfg_file)
model = get_configured_monodepth_model(cfg)
if pretrain:
model.load_state_dict(
torch.hub.load_state_dict_from_url(ckpt_file)['model_state_dict'],
torch.hub.load_state_dict_from_url(ckpt_file)['model_state_dict'],
strict=False,
)
return model
Expand Down Expand Up @@ -163,7 +168,7 @@ def metric3d_vit_giant2(pretrain=False, **kwargs):
# un pad
pred_depth = pred_depth.squeeze()
pred_depth = pred_depth[pad_info[0] : pred_depth.shape[0] - pad_info[1], pad_info[2] : pred_depth.shape[1] - pad_info[3]]

# upsample to original size
pred_depth = torch.nn.functional.interpolate(pred_depth[None, None, :, :], rgb_origin.shape[:2], mode='bilinear').squeeze()
###################### canonical camera space ######################
Expand All @@ -173,14 +178,14 @@ def metric3d_vit_giant2(pretrain=False, **kwargs):
pred_depth = pred_depth * canonical_to_real_scale # now the depth is metric
pred_depth = torch.clamp(pred_depth, 0, 300)

#### you can now do anything with the metric depth
#### you can now do anything with the metric depth
# such as evaluate predicted depth
if depth_file is not None:
gt_depth = cv2.imread(depth_file, -1)
gt_depth = gt_depth / gt_depth_scale
gt_depth = torch.from_numpy(gt_depth).float().cuda()
assert gt_depth.shape == pred_depth.shape

mask = (gt_depth > 1e-8)
abs_rel_err = (torch.abs(pred_depth[mask] - gt_depth[mask]) / gt_depth[mask]).mean()
print('abs_rel_err:', abs_rel_err.item())
2 changes: 1 addition & 1 deletion mono/configs/HourglassDecoder/convlarge.0.3_150.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
focal_length=1000.0,
),
depth_range=(0, 1),
depth_normalize=(0.3, 150),
depth_normalize=(0.3, 500),
crop_size = (544, 1216),
)

Expand Down
4 changes: 2 additions & 2 deletions mono/configs/HourglassDecoder/vit.raft5.giant2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)


max_value = 200
max_value = 500
# configs of the canonical space
data_basic=dict(
canonical_space = dict(
Expand All @@ -25,7 +25,7 @@
depth_range=(0, 1),
depth_normalize=(0.1, max_value),
crop_size = (616, 1064), # %28 = 0
clip_depth_range=(0.1, 200),
clip_depth_range=(0.1, 500),
vit_size=(616,1064)
)

Expand Down
4 changes: 2 additions & 2 deletions mono/model/decode_heads/RAFTDepthNormalDPTDecoder5.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def compute_depth_expectation(prob, depth_values):
return depth

def interpolate_float32(x, size=None, scale_factor=None, mode='nearest', align_corners=None):
with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=False):
with torch.autocast(device_type='cuda', dtype=torch.float, enabled=False):
return F.interpolate(x.float(), size=size, scale_factor=scale_factor, mode=mode, align_corners=align_corners)

# def upflow8(flow, mode='bilinear'):
Expand All @@ -225,7 +225,7 @@ def interpolate_float32(x, size=None, scale_factor=None, mode='nearest', align_c

def upflow4(flow, mode='bilinear'):
new_size = (4 * flow.shape[2], 4 * flow.shape[3])
with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=False):
with torch.autocast(device_type='cuda', dtype=torch.float, enabled=False):
return F.interpolate(flow, size=new_size, mode=mode, align_corners=True)

def coords_grid(batch, ht, wd):
Expand Down