Skip to content

Commit

Permalink
DDRNet for Semantic Segmentation (#70)
Browse files Browse the repository at this point in the history
Co-authored-by: Martin Kozlovsky <martin.kozlovsky@luxonis.com>
Co-authored-by: GitHub Actions <actions@github.com>
  • Loading branch information
3 people committed Oct 9, 2024
1 parent 3a2d684 commit db8a3a9
Show file tree
Hide file tree
Showing 14 changed files with 1,166 additions and 1 deletion.
45 changes: 45 additions & 0 deletions configs/ddrnet_segmentation_model.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# DDRNet-23-slim model for segmentation
# Refer to here for optimal hyperparameters for this model: https://github.com/Deci-AI/super-gradients/blob/4797c974c7c445d12e2575c468848d9c3e04becd/src/super_gradients/recipes/cityscapes_ddrnet.yaml#L4

model:
name: ddrnet_segmentation
predefined_model:
name: DDRNetSegmentationModel
params:
task: binary
backbone_params:
use_aux_heads: True # set to False to disable auxiliary heads (for export)
variant: '23-slim'

loader:
params:
dataset_name: coco_test

trainer:
preprocessing:
train_image_size: [&height 256, &width 320]
keep_aspect_ratio: False
normalize:
active: True

batch_size: 4
epochs: &epochs 500
num_workers: 4
validation_interval: 10
num_log_images: 8

callbacks:
- name: TestOnTrainEnd
- name: ExportOnTrainEnd

optimizer:
name: SGD
params:
lr: 0.01
momentum: 0.9
weight_decay: 0.0005

scheduler:
name: CosineAnnealingLR
params:
T_max: *epochs
2 changes: 2 additions & 0 deletions luxonis_train/models/predefined_models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .base_predefined_model import BasePredefinedModel
from .classification_model import ClassificationModel
from .ddrnet_segmentation_model import DDRNetSegmentationModel
from .detection_model import DetectionModel
from .keypoint_detection_model import KeypointDetectionModel
from .segmentation_model import SegmentationModel
Expand All @@ -10,4 +11,5 @@
"DetectionModel",
"KeypointDetectionModel",
"ClassificationModel",
"DDRNetSegmentationModel",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from dataclasses import dataclass, field

from luxonis_train.utils.config import (
LossModuleConfig,
ModelNodeConfig,
)
from luxonis_train.utils.types import Kwargs

from .segmentation_model import SegmentationModel


@dataclass
class DDRNetSegmentationModel(SegmentationModel):
backbone: str = "DDRNet"
aux_head_params: Kwargs = field(default_factory=dict)

@property
def nodes(self) -> list[ModelNodeConfig]:
self.head_params.update({"attach_index": -1})

self.aux_head_params.update({"attach_index": -2})

node_list = [
ModelNodeConfig(
name=self.backbone,
alias="ddrnet_backbone",
freezing=self.backbone_params.pop("freezing", {}),
params=self.backbone_params,
),
ModelNodeConfig(
name="DDRNetSegmentationHead",
alias="segmentation_head",
inputs=["ddrnet_backbone"],
freezing=self.head_params.pop("freezing", {}),
params=self.head_params,
task=self.task_name,
),
]
if self.backbone_params.get("use_aux_heads", False):
node_list.append(
ModelNodeConfig(
name="DDRNetSegmentationHead",
alias="aux_segmentation_head",
inputs=["ddrnet_backbone"],
freezing=self.aux_head_params.pop("freezing", {}),
params=self.aux_head_params,
task=self.task_name,
)
)
return node_list

@property
def losses(self) -> list[LossModuleConfig]:
loss_list = [
LossModuleConfig(
name="BCEWithLogitsLoss"
if self.task == "binary"
else "CrossEntropyLoss",
alias="segmentation_loss",
attached_to="segmentation_head",
params=self.loss_params,
weight=1.0,
),
]
if self.backbone_params.get("use_aux_heads", False):
loss_list.append(
LossModuleConfig(
name="BCEWithLogitsLoss"
if self.task == "binary"
else "CrossEntropyLoss",
alias="aux_segmentation_loss",
attached_to="aux_segmentation_head",
params=self.loss_params,
weight=0.4,
)
)
return loss_list
2 changes: 2 additions & 0 deletions luxonis_train/nodes/backbones/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .contextspatial import ContextSpatial
from .ddrnet import DDRNet
from .efficientnet import EfficientNet
from .efficientrep import EfficientRep
from .micronet import MicroNet
Expand All @@ -18,4 +19,5 @@
"ReXNetV1_lite",
"RepVGG",
"ResNet",
"DDRNet",
]
3 changes: 3 additions & 0 deletions luxonis_train/nodes/backbones/ddrnet/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .ddrnet import DDRNet

__all__ = ["DDRNet"]
Loading

0 comments on commit db8a3a9

Please sign in to comment.