Skip to content

Commit

Permalink
fix f-string issues for testing earlier python versions; resolve some…
Browse files Browse the repository at this point in the history
… mypy errors for stage/base.py
  • Loading branch information
blsmxiu47 committed Oct 28, 2024
1 parent 18d8cad commit 0108d3f
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 10 deletions.
24 changes: 16 additions & 8 deletions src/onemod/stage/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ def dataif(self) -> DataInterface:
return self._dataif

def set_dataif(self, config_path: Path | str) -> None:
if self.input is None:
return

directory = Path(config_path).parent
self._dataif = DataInterface(
directory=directory,
Expand All @@ -88,14 +91,14 @@ def set_dataif(self, config_path: Path | str) -> None:
if not (directory / self.name).exists():
(directory / self.name).mkdir()

@computed_field
@property
def module(self) -> str | None:
@computed_field
def module(self) -> Path | None:
if self._module is None and not hasattr(
onemod_stages, self.type
): # custom stage
try:
return getfile(self.__class__)
return Path(getfile(self.__class__))
except TypeError:
raise TypeError(f"Could not find module for {self.name} stage")
return self._module
Expand All @@ -104,8 +107,8 @@ def module(self) -> str | None:
def skip(self) -> set[str]:
return self._skip

@computed_field
@property
@computed_field
def input(self) -> Input | None:
if self._input is None:
self._input = Input(
Expand All @@ -129,8 +132,8 @@ def dependencies(self) -> set[str]:
return set()
return self.input.dependencies

@computed_field
@property
@computed_field
def type(self) -> str:
return type(self).__name__

Expand All @@ -157,7 +160,7 @@ def from_json(cls, config_path: Path | str, stage_name: str) -> Stage:
stage_config = pipeline_config["stages"][stage_name]
except KeyError:
raise AttributeError(
f"{pipeline_config["name"]} does not contain a stage named '{stage_name}'"
f"{pipeline_config['name']} does not contain a stage named '{stage_name}'"
)
stage = cls(**stage_config)
stage.config.inherit(pipeline_config["config"])
Expand Down Expand Up @@ -345,9 +348,9 @@ class ModelStage(Stage, ABC):
_required_input: set[str] = set() # data required for groupby
_collect_after: set[str] = set() # defined by class

@computed_field
@property
def crossby(self) -> set[str]:
@computed_field
def crossby(self) -> set[str] | None:
return self._crossby

@property
Expand Down Expand Up @@ -380,6 +383,11 @@ def collect_after(self) -> set[str]:

def create_stage_subsets(self, data: Path | str) -> None:
"""Create stage data subsets from groupby."""
if self.groupby is None:
raise AttributeError(
f"{self.name} does not have a groupby attribute"
)

subsets = create_subsets(
self.groupby, self.dataif.load(data, columns=self.groupby)
)
Expand Down
4 changes: 2 additions & 2 deletions src/onemod/stage/model_stages/spxmod_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def _fit(
# Create and fit submodel
logger.info(f"Fitting {self.name} submodel {subset_id}")
train = data.query(
f"({self.config["test_column"]} == 0) & {self.config["observation_column"]}.notnull()"
f"({self.config['test_column']} == 0) & {self.config['observation_column']}.notnull()"
)
model = XModel.from_config(xmodel_args)
model.fit(data=train, data_span=data, **self.config.xmodel_fit)
Expand Down Expand Up @@ -214,7 +214,7 @@ def _get_submodel_data(
# Add offset to data
offset = False
if "offset" in self.input:
logger.info(f"Adding offset from {self.input["offset"].stage}")
logger.info(f"Adding offset from {self.input['offset'].stage}")
data = data.merge(
right=self.dataif.load_offset(
columns=list(self.config["id_columns"])
Expand Down

0 comments on commit 0108d3f

Please sign in to comment.