
Commit
Added mixed precision training and removed torch.cuda.empty_cache() due to performance drop.
JSabadin committed Oct 10, 2024
1 parent dc7ae5f commit d839190
Showing 4 changed files with 9 additions and 8 deletions.
9 changes: 6 additions & 3 deletions luxonis_train/attached_modules/losses/adaptive_detection_loss.py

@@ -270,8 +270,11 @@ def forward(
             self.alpha * pred_score.pow(self.gamma) * (1 - label)
             + target_score * label
         )
-        ce_loss = F.binary_cross_entropy(
-            pred_score.float(), target_score.float(), reduction="none"
-        )
+        with torch.amp.autocast(
+            device_type=pred_score.device.type, enabled=False
+        ):
+            ce_loss = F.binary_cross_entropy(
+                pred_score.float(), target_score.float(), reduction="none"
+            )
         loss = (ce_loss * weight).sum()
         return loss

Check failure on line 273 (GitHub Actions / type-check): "amp" is not a known attribute of module "torch" (reportAttributeAccessIssue)
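The autocast(enabled=False) guard is needed because F.binary_cross_entropy is not autocast-safe: PyTorch raises a RuntimeError if it is reached inside a float16 autocast region, so the BCE term is forced back to float32. A minimal standalone sketch of the same pattern (the tensor shapes and the CUDA device are illustrative assumptions; torch.autocast is an equivalent documented spelling that would also sidestep the type-check failure above on torch stubs that lack a torch.amp attribute):

import torch
import torch.nn.functional as F

# Illustrative stand-ins for the loss inputs (probabilities in [0, 1]).
pred_score = torch.rand(8, 4, device="cuda")
target_score = torch.rand(8, 4, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.float16):
    # Other loss terms may safely run in float16 here...
    # ...but BCE must not, so autocast is disabled just for this call:
    with torch.autocast(device_type="cuda", enabled=False):
        ce_loss = F.binary_cross_entropy(
            pred_score.float(), target_score.float(), reduction="none"
        )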
1 change: 1 addition & 0 deletions luxonis_train/config/config.py
@@ -340,6 +340,7 @@ class TrainerConfig(BaseModelExtraForbid):
     preprocessing: PreprocessingConfig = PreprocessingConfig()
     use_rich_progress_bar: bool = True
 
+    precision: Literal["16-mixed", "32"] = "32"
     accelerator: Literal["auto", "cpu", "gpu", "tpu"] = "auto"
     devices: int | list[int] | str = "auto"
     strategy: Literal["auto", "ddp"] = "auto"
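The new field is validated by Pydantic, so only the two literal strings are accepted. A minimal sketch of selecting mixed precision through the config (assuming TrainerConfig is importable from this module and that its other fields all have defaults):

from luxonis_train.config.config import TrainerConfig

cfg = TrainerConfig(precision="16-mixed")
print(cfg.precision)             # "16-mixed"; the default is "32"
TrainerConfig(precision="bf16")  # rejected: raises a pydantic ValidationError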
1 change: 1 addition & 0 deletions luxonis_train/core/core.py
@@ -108,6 +108,7 @@ def __init__(
             callbacks=LuxonisRichProgressBar()
             if self.cfg.trainer.use_rich_progress_bar
             else LuxonisTQDMProgressBar(),
+            precision=self.cfg.trainer.precision,
         )
 
         self.train_augmentations = Augmentations(
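The config value is forwarded to Lightning's Trainer, whose precision argument accepts the same literal strings. Outside this codebase, the equivalent direct call would be:

import lightning.pytorch as pl

# With precision="16-mixed", Lightning manages autocast and gradient
# scaling itself; "32" keeps everything in full float32.
trainer = pl.Trainer(precision="16-mixed", accelerator="auto", devices="auto")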
6 changes: 1 addition & 5 deletions luxonis_train/models/luxonis_lightning.py
@@ -10,7 +10,6 @@
 from lightning.pytorch.utilities import rank_zero_only  # type: ignore
 from luxonis_ml.data import LuxonisDataset
 from torch import Size, Tensor, nn
-from torch.amp import autocast
 
 import luxonis_train
 from luxonis_train.attached_modules import (
@@ -394,8 +393,7 @@ def forward(
             else:
                 node_inputs.append({"features": [inputs[pred]]})
 
-            with autocast(device_type=self.device.type):
-                outputs = node.run(node_inputs)
+            outputs = node.run(node_inputs)
 
             computed[node_name] = outputs
 
@@ -445,8 +443,6 @@ def forward(
                 if node_name in self.outputs
             }
 
-        torch.cuda.empty_cache()
-
         return LuxonisOutput(
             outputs=outputs_dict, losses=losses, visualizations=visualizations
         )
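Both removals follow from the switch to Lightning-managed mixed precision. With precision="16-mixed", Lightning already wraps each training step in an autocast context, so the manual per-node autocast is redundant; and calling torch.cuda.empty_cache() on every forward pass forces the CUDA caching allocator to release and re-reserve memory, an expensive, implicitly synchronizing operation, which is the performance drop named in the commit message. A rough, simplified sketch of what a 16-mixed step looks like (illustrative names only, not Lightning's actual source):

import torch

scaler = torch.cuda.amp.GradScaler()  # scales the loss to avoid float16 underflow

def run_training_step(model, batch, optimizer):
    # Forward and loss run under autocast, so node.run() above needs no wrapper.
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = model.training_step(batch)
    scaler.scale(loss).backward()
    scaler.step(optimizer)  # unscales gradients, skips the step on inf/NaN
    scaler.update()
    optimizer.zero_grad()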
