Skip to content

Commit

Permalink
move feature name assignment to right spot
Browse files Browse the repository at this point in the history
  • Loading branch information
MatthiasSchmidtblaicherQC committed Dec 8, 2023
1 parent b0b2d3e commit 16bd925
Showing 1 changed file with 25 additions and 10 deletions.
35 changes: 25 additions & 10 deletions src/glum/_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,20 @@ def _check_offset(
return offset


def _name_categorical_variables(
categories: tuple[str], column_name: str, drop_first: bool
):
new_names = [
f"{column_name}__{category}" for category in categories[int(drop_first) :]
]
if len(new_names) == 0:
raise ValueError(
f"Categorical column: {column_name}, contains only one category. "
+ "This should be dropped from the feature matrix."
)
return new_names


def _parse_formula(
formula: FormulaSpec, include_intercept: bool = True
) -> tuple[Optional[Formula], Formula]:
Expand Down Expand Up @@ -2696,16 +2710,6 @@ def _set_up_and_check_fit_args(
self.term_names_ = list(
chain.from_iterable(
[term] * len(cols) for term, _, cols in X.model_spec.structure

if any(X.dtypes == "category"):
self.feature_names_ = list(
chain.from_iterable(
_name_categorical_variables(
dtype.categories, column, getattr(self, "drop_first", False)
)
if isinstance(dtype, pd.CategoricalDtype)
else [column]
for column, dtype in zip(X.columns, X.dtypes)
)
)

Expand All @@ -2715,6 +2719,17 @@ def _set_up_and_check_fit_args(
self.feature_dtypes_ = X.dtypes.to_dict()

if any(X.dtypes == "category"):

self.feature_names_ = list(
chain.from_iterable(
_name_categorical_variables(
dtype.categories, column, getattr(self, "drop_first", False)
)
if isinstance(dtype, pd.CategoricalDtype)
else [column]
for column, dtype in zip(X.columns, X.dtypes)
)
)

def _expand_categorical_penalties(penalty, X, drop_first):
"""
Expand Down

0 comments on commit 16bd925

Please sign in to comment.