Skip to content

Commit

Permalink
Update parameters of the ResNet18Classifier class.
Browse files Browse the repository at this point in the history
  • Loading branch information
bojan-karlas committed Mar 1, 2024
1 parent 9e80711 commit 62ec7cf
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions experiments/datascope/experiments/pipelines/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def compute_metrics(eval_pred):


class ResNet18Classifier(BaseEstimator, ClassifierMixin):
def __init__(self, n_epochs: int = 10, eval_split: float = 0.2, logger: Optional[Logger] = None) -> None:
def __init__(self, n_epochs: int = 10, eval_split: float = 0.1, logger: Optional[Logger] = None) -> None:
self.n_epochs = n_epochs
self.eval_split = eval_split
self.logger = logger
Expand Down Expand Up @@ -206,8 +206,8 @@ def fit(self, X: NDArray, y: NDArray) -> None:
output_dir=self.tempdir,
num_train_epochs=self.n_epochs,
evaluation_strategy=IntervalStrategy.STEPS,
eval_steps=50, # Evaluation and Save happens every 50 steps
save_total_limit=8,
eval_steps=150, # Evaluation and Save happens every 150 steps
save_total_limit=15,
load_best_model_at_end=True,
metric_for_best_model="roc_auc",
)
Expand All @@ -218,7 +218,7 @@ def fit(self, X: NDArray, y: NDArray) -> None:
eval_dataset=self.eval_dataset,
compute_metrics=compute_metrics,
callbacks=[
EarlyStoppingCallback(early_stopping_patience=5),
EarlyStoppingCallback(early_stopping_patience=10),
EvalLoggerCallback(self.logger, prefix=self.__class__.__name__),
],
)
Expand Down

0 comments on commit 62ec7cf

Please sign in to comment.