Skip to content

Commit

Permalink
Fix Removed Tensor Metadata (#12)
Browse files Browse the repository at this point in the history
* option to source custom code in CLI

* removed empty dicts

* fixed issue with removed tensor metadata in match case statements
  • Loading branch information
kozlov721 committed Oct 9, 2024
1 parent 92fd295 commit 31d7c42
Show file tree
Hide file tree
Showing 5 changed files with 5 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def forward(
) -> tuple[Tensor, Tensor]:
for visualizer in self.visualizers:
match visualizer.run(label_canvas, prediction_canvas, outputs, labels):
case Tensor(data=prediction_viz):
case Tensor() as prediction_viz:
prediction_canvas = prediction_viz
case (Tensor(data=label_viz), Tensor(data=prediction_viz)):
label_canvas = label_viz
Expand Down
2 changes: 1 addition & 1 deletion luxonis_train/attached_modules/visualizers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ def resize_to_match(
return fst_resized, snd_resized

match visualization:
case Tensor(data=viz):
case Tensor() as viz:
return viz
case (Tensor(data=viz_labels), Tensor(data=viz_predictions)):
viz_labels, viz_predictions = resize_to_match(viz_labels, viz_predictions)
Expand Down
2 changes: 1 addition & 1 deletion luxonis_train/core/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def export(self, onnx_path: str | None = None):
model_onnx = onnx.load(onnx_path)
onnx_model, check = onnxsim.simplify(model_onnx)
if not check:
raise RuntimeError("Onnx simplify failed.")
raise RuntimeError("ONNX simplify failed.")
onnx.save(onnx_model, onnx_path)
logger.info(f"ONNX model saved to {onnx_path}")

Expand Down
2 changes: 1 addition & 1 deletion luxonis_train/models/luxonis_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def compute_metrics(self) -> dict[str, dict[str, Tensor]]:
computed_submetrics = {
metric_name: metric_value,
} | submetrics
case Tensor(data=metric_value):
case Tensor() as metric_value:
computed_submetrics = {metric_name: metric_value}
case dict(submetrics):
computed_submetrics = submetrics
Expand Down
2 changes: 1 addition & 1 deletion luxonis_train/nodes/base_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def wrap(self, output: ForwardOutputT) -> Packet[Tensor]:
"""

match output:
case Tensor(data=out):
case Tensor() as out:
outputs = [out]
case list(tensors) if all(isinstance(t, Tensor) for t in tensors):
outputs = tensors
Expand Down

0 comments on commit 31d7c42

Please sign in to comment.