Skip to content

Commit

Permalink
Forbid Extra Fields (#13)
Browse files Browse the repository at this point in the history
* forbid extra fields in config

* fixed configs
  • Loading branch information
kozlov721 committed Oct 9, 2024
1 parent 31d7c42 commit d6081d4
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 23 deletions.
1 change: 0 additions & 1 deletion configs/coco_model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ trainer:
validation_interval: 10
num_log_images: 8
skip_last_batch: True
main_head_index: 0
log_sub_losses: True
save_top_k: 3

Expand Down
5 changes: 3 additions & 2 deletions configs/resnet_model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ model:
name: resnet50_classification
nodes:
- name: ResNet
variant: "50"
download_weights: True
params:
variant: "50"
download_weights: True

- name: ClassificationHead
inputs:
Expand Down
44 changes: 24 additions & 20 deletions luxonis_train/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,19 @@

from luxonis_ml.data import BucketStorage, BucketType
from luxonis_ml.utils import Environ, LuxonisConfig, LuxonisFileSystem, setup_logging
from pydantic import BaseModel, Field, field_serializer, model_validator
from pydantic import BaseModel, ConfigDict, Field, field_serializer, model_validator

from luxonis_train.utils.general import is_acyclic
from luxonis_train.utils.registry import MODELS

logger = logging.getLogger(__name__)


class AttachedModuleConfig(BaseModel):
class CustomBaseModel(BaseModel):
    """Shared base for all config models; rejects unknown fields.

    ``extra="forbid"`` makes pydantic raise a ``ValidationError`` when a
    config file contains keys not declared on the model, instead of
    silently ignoring them (the point of this commit).
    """

    model_config = ConfigDict(extra="forbid")


class AttachedModuleConfig(CustomBaseModel):
name: str
attached_to: str
alias: str | None = None
Expand All @@ -28,20 +32,20 @@ class MetricModuleConfig(AttachedModuleConfig):
is_main_metric: bool = False


class FreezingConfig(BaseModel):
class FreezingConfig(CustomBaseModel):
    """Config for freezing a model node's weights during training."""

    # Whether freezing is enabled for the node.
    active: bool = False
    # When to unfreeze the node again; None means it stays frozen.
    # NOTE(review): presumably an int is an epoch index and a float a
    # fraction of total epochs — confirm against the trainer code,
    # which is not visible in this diff.
    unfreeze_after: int | float | None = None


class ModelNodeConfig(BaseModel):
class ModelNodeConfig(CustomBaseModel):
    """Config for a single node of the model graph."""

    # Node type name (e.g. "ResNet" in the bundled configs).
    name: str
    # Optional unique alias used to reference this node from others.
    alias: str | None = None
    # Names/aliases of the nodes whose outputs feed this node.
    inputs: list[str] = []
    # Keyword arguments forwarded to the node's constructor; after this
    # commit, node-specific options (e.g. "variant") must live here
    # rather than as top-level keys, which are now forbidden.
    params: dict[str, Any] = {}
    freezing: FreezingConfig = FreezingConfig()


class PredefinedModelConfig(BaseModel):
class PredefinedModelConfig(CustomBaseModel):
name: str
params: dict[str, Any] = {}
include_nodes: bool = True
Expand All @@ -50,7 +54,7 @@ class PredefinedModelConfig(BaseModel):
include_visualizers: bool = True


class ModelConfig(BaseModel):
class ModelConfig(CustomBaseModel):
name: str
predefined_model: PredefinedModelConfig | None = None
weights: str | None = None
Expand Down Expand Up @@ -114,7 +118,7 @@ def check_unique_names(self):
return self


class TrackerConfig(BaseModel):
class TrackerConfig(CustomBaseModel):
project_name: str | None = None
project_id: str | None = None
run_name: str | None = None
Expand All @@ -126,7 +130,7 @@ class TrackerConfig(BaseModel):
is_mlflow: bool = False


class DatasetConfig(BaseModel):
class DatasetConfig(CustomBaseModel):
name: str | None = None
id: str | None = None
team_name: str | None = None
Expand All @@ -143,20 +147,20 @@ def get_enum_value(self, v: Enum, _) -> str:
return str(v.value)


class NormalizeAugmentationConfig(BaseModel):
class NormalizeAugmentationConfig(CustomBaseModel):
    """Config for the input-normalization augmentation."""

    active: bool = True
    # Defaults are the standard ImageNet per-channel mean/std values.
    params: dict[str, Any] = {
        "mean": [0.485, 0.456, 0.406],
        "std": [0.229, 0.224, 0.225],
    }


class AugmentationConfig(BaseModel):
class AugmentationConfig(CustomBaseModel):
    """Config for a single data augmentation."""

    # Augmentation name; resolved elsewhere (registry not visible here).
    name: str
    # Keyword arguments forwarded to the augmentation.
    params: dict[str, Any] = {}


class PreprocessingConfig(BaseModel):
class PreprocessingConfig(CustomBaseModel):
train_image_size: Annotated[
list[int], Field(default=[256, 256], min_length=2, max_length=2)
] = [256, 256]
Expand All @@ -174,23 +178,23 @@ def check_normalize(self):
return self


class CallbackConfig(BaseModel):
class CallbackConfig(CustomBaseModel):
    """Config for a single training callback."""

    # Callback name; resolved elsewhere (registry not visible here).
    name: str
    # Allows disabling a callback without removing it from the config.
    active: bool = True
    # Keyword arguments forwarded to the callback.
    params: dict[str, Any] = {}


class OptimizerConfig(BaseModel):
class OptimizerConfig(CustomBaseModel):
    """Config for the training optimizer."""

    # Optimizer name; presumably a torch.optim class name — confirm
    # against the resolver, which is not visible in this diff.
    name: str = "Adam"
    # Keyword arguments forwarded to the optimizer.
    params: dict[str, Any] = {}


class SchedulerConfig(BaseModel):
class SchedulerConfig(CustomBaseModel):
    """Config for the learning-rate scheduler."""

    # Scheduler name; presumably a torch.optim.lr_scheduler class name —
    # confirm against the resolver, which is not visible in this diff.
    name: str = "ConstantLR"
    # Keyword arguments forwarded to the scheduler.
    params: dict[str, Any] = {}


class TrainerConfig(BaseModel):
class TrainerConfig(CustomBaseModel):
preprocessing: PreprocessingConfig = PreprocessingConfig()

accelerator: Literal["auto", "cpu", "gpu"] = "auto"
Expand Down Expand Up @@ -229,17 +233,17 @@ def check_num_workes_platform(self):
return self


class OnnxExportConfig(BaseModel):
class OnnxExportConfig(CustomBaseModel):
    """Config for ONNX export."""

    # ONNX opset version passed to the exporter.
    opset_version: int = 12
    # Dynamic-axes mapping forwarded to the ONNX exporter; None keeps
    # all axes static.
    dynamic_axes: dict[str, Any] | None = None


class BlobconverterExportConfig(BaseModel):
class BlobconverterExportConfig(CustomBaseModel):
    """Config for blobconverter export (OAK-device blob format)."""

    active: bool = False
    # Number of SHAVE cores to compile the blob for.
    shaves: int = 6


class ExportConfig(BaseModel):
class ExportConfig(CustomBaseModel):
export_save_directory: str = "output_export"
input_shape: list[int] | None = None
export_model_name: str = "model"
Expand All @@ -265,12 +269,12 @@ def pad_values(values: float | list[float] | None):
return self


class StorageConfig(BaseModel):
class StorageConfig(CustomBaseModel):
    """Config for where tuner/run artifacts are stored."""

    active: bool = True
    # Storage backend selector; only these two literals are accepted.
    storage_type: Literal["local", "remote"] = "local"


class TunerConfig(BaseModel):
class TunerConfig(CustomBaseModel):
study_name: str = "test-study"
use_pruner: bool = True
n_trials: int | None = 15
Expand Down

0 comments on commit d6081d4

Please sign in to comment.