-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
DDRNet for Semantic Segmentation (#70)
Co-authored-by: Martin Kozlovsky <martin.kozlovsky@luxonis.com> Co-authored-by: GitHub Actions <actions@github.com>
- Loading branch information
1 parent
3a2d684
commit db8a3a9
Showing
14 changed files
with
1,166 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# DDRNet-23-slim model for segmentation | ||
# Refer to here for optimal hyperparameters for this model: https://github.com/Deci-AI/super-gradients/blob/4797c974c7c445d12e2575c468848d9c3e04becd/src/super_gradients/recipes/cityscapes_ddrnet.yaml#L4 | ||
|
||
model: | ||
name: ddrnet_segmentation | ||
predefined_model: | ||
name: DDRNetSegmentationModel | ||
params: | ||
task: binary | ||
backbone_params: | ||
use_aux_heads: True # set to False to disable auxiliary heads (for export) | ||
variant: '23-slim' | ||
|
||
loader: | ||
params: | ||
dataset_name: coco_test | ||
|
||
trainer: | ||
preprocessing: | ||
train_image_size: [&height 256, &width 320] | ||
keep_aspect_ratio: False | ||
normalize: | ||
active: True | ||
|
||
batch_size: 4 | ||
epochs: &epochs 500 | ||
num_workers: 4 | ||
validation_interval: 10 | ||
num_log_images: 8 | ||
|
||
callbacks: | ||
- name: TestOnTrainEnd | ||
- name: ExportOnTrainEnd | ||
|
||
optimizer: | ||
name: SGD | ||
params: | ||
lr: 0.01 | ||
momentum: 0.9 | ||
weight_decay: 0.0005 | ||
|
||
scheduler: | ||
name: CosineAnnealingLR | ||
params: | ||
T_max: *epochs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
77 changes: 77 additions & 0 deletions
77
luxonis_train/models/predefined_models/ddrnet_segmentation_model.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
from dataclasses import dataclass, field | ||
|
||
from luxonis_train.utils.config import ( | ||
LossModuleConfig, | ||
ModelNodeConfig, | ||
) | ||
from luxonis_train.utils.types import Kwargs | ||
|
||
from .segmentation_model import SegmentationModel | ||
|
||
|
||
@dataclass | ||
class DDRNetSegmentationModel(SegmentationModel): | ||
backbone: str = "DDRNet" | ||
aux_head_params: Kwargs = field(default_factory=dict) | ||
|
||
@property | ||
def nodes(self) -> list[ModelNodeConfig]: | ||
self.head_params.update({"attach_index": -1}) | ||
|
||
self.aux_head_params.update({"attach_index": -2}) | ||
|
||
node_list = [ | ||
ModelNodeConfig( | ||
name=self.backbone, | ||
alias="ddrnet_backbone", | ||
freezing=self.backbone_params.pop("freezing", {}), | ||
params=self.backbone_params, | ||
), | ||
ModelNodeConfig( | ||
name="DDRNetSegmentationHead", | ||
alias="segmentation_head", | ||
inputs=["ddrnet_backbone"], | ||
freezing=self.head_params.pop("freezing", {}), | ||
params=self.head_params, | ||
task=self.task_name, | ||
), | ||
] | ||
if self.backbone_params.get("use_aux_heads", False): | ||
node_list.append( | ||
ModelNodeConfig( | ||
name="DDRNetSegmentationHead", | ||
alias="aux_segmentation_head", | ||
inputs=["ddrnet_backbone"], | ||
freezing=self.aux_head_params.pop("freezing", {}), | ||
params=self.aux_head_params, | ||
task=self.task_name, | ||
) | ||
) | ||
return node_list | ||
|
||
@property | ||
def losses(self) -> list[LossModuleConfig]: | ||
loss_list = [ | ||
LossModuleConfig( | ||
name="BCEWithLogitsLoss" | ||
if self.task == "binary" | ||
else "CrossEntropyLoss", | ||
alias="segmentation_loss", | ||
attached_to="segmentation_head", | ||
params=self.loss_params, | ||
weight=1.0, | ||
), | ||
] | ||
if self.backbone_params.get("use_aux_heads", False): | ||
loss_list.append( | ||
LossModuleConfig( | ||
name="BCEWithLogitsLoss" | ||
if self.task == "binary" | ||
else "CrossEntropyLoss", | ||
alias="aux_segmentation_loss", | ||
attached_to="aux_segmentation_head", | ||
params=self.loss_params, | ||
weight=0.4, | ||
) | ||
) | ||
return loss_list |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .ddrnet import DDRNet | ||
|
||
__all__ = ["DDRNet"] |
Oops, something went wrong.