Skip to content

Commit

Permalink
Archive without Export (#96)
Browse files Browse the repository at this point in the history
  • Loading branch information
kozlov721 committed Oct 9, 2024
1 parent ff06fc6 commit 03a8bda
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 15 deletions.
4 changes: 2 additions & 2 deletions luxonis_train/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,11 +179,11 @@ def inspect(
@app.command()
def archive(
executable: Annotated[
str,
str | None,
typer.Option(
help="Path to the model file.", show_default=False, metavar="FILE"
),
],
] = None,
config: ConfigType = None,
opts: OptsType = None,
):
Expand Down
17 changes: 10 additions & 7 deletions luxonis_train/callbacks/archive_on_train_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,19 @@ def on_train_end(
@param pl_module: Pytorch Lightning module.
"""

path = self.get_checkpoint(pl_module)
if path is None: # pragma: no cover
logger.warning("Skipping model archiving.")
return

onnx_path = pl_module.core._exported_models.get("onnx")
if onnx_path is None: # pragma: no cover
checkpoint = self.get_checkpoint(pl_module)
if checkpoint is None:
logger.warning("Skipping model archiving.")
return
logger.info("Exported model not found. Exporting to ONNX...")
pl_module.core.export(weights=checkpoint)
onnx_path = pl_module.core._exported_models.get("onnx")

if onnx_path is None: # pragma: no cover
logger.error(
"Model executable not found. "
"Make sure to run exporter callback before archiver callback. "
"Model executable not found and couldn't be created. "
"Skipping model archiving."
)
return
Expand Down
6 changes: 3 additions & 3 deletions luxonis_train/callbacks/export_on_train_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ def on_train_end(
@type pl_module: L{pl.LightningModule}
@param pl_module: Pytorch Lightning module.
"""
path = self.get_checkpoint(pl_module)
if path is None: # pragma: no cover
checkpoint = self.get_checkpoint(pl_module)
if checkpoint is None: # pragma: no cover
logger.warning("Skipping model export.")
return

pl_module.core.export(weights=self.get_checkpoint(pl_module))
pl_module.core.export(weights=checkpoint)
2 changes: 1 addition & 1 deletion luxonis_train/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ def get_config(
cls,
cfg: str | dict[str, Any] | None = None,
overrides: dict[str, Any] | list[str] | tuple[str, ...] | None = None,
):
) -> "Config":
instance = super().get_config(cfg, overrides)
if not isinstance(cfg, str):
return instance
Expand Down
1 change: 1 addition & 0 deletions luxonis_train/core/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,7 @@ def archive(self, path: str | Path | None = None) -> Path:
outputs = []

if path is None:
logger.warning("No model executable specified for archiving.")
if "onnx" not in self._exported_models:
logger.info("Exporting model to ONNX...")
self.export()
Expand Down
3 changes: 2 additions & 1 deletion luxonis_train/core/utils/archive_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,8 @@ def _get_head_outputs(
# TODO: Fix this, will require refactoring custom ONNX output names
logger.error(
"ONNX model uses custom output names, trying to determine outputs based on the head type. "
"This will likely result in incorrect archive for multi-head models."
"This will likely result in incorrect archive for multi-head models. "
"You can ignore this error if your model has only one head."
)

if head_type == "ClassificationHead":
Expand Down
11 changes: 10 additions & 1 deletion luxonis_train/nodes/heads/ddrnet_segmentation_head.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from typing import Literal

import torch
import torch.nn as nn
Expand All @@ -22,7 +23,15 @@ class DDRNetSegmentationHead(BaseNode[Tensor, Tensor]):
def __init__(
self,
inter_channels: int = 64,
inter_mode: str = "bilinear",
inter_mode: Literal[
"nearest",
"linear",
"bilinear",
"bicubic",
"trilinear",
"area",
"pixel_shuffle",
] = "bilinear",
**kwargs,
):
"""DDRNet segmentation head.
Expand Down

0 comments on commit 03a8bda

Please sign in to comment.