Skip to content

Commit

Permalink
Make EvalLoggerCallback more configurable by allowing it to drop some…
Browse files Browse the repository at this point in the history
… specified metrics and to trim some specified prefix of metric names.
  • Loading branch information
bojan-karlas committed Jul 4, 2024
1 parent fe3e053 commit ac37c03
Showing 1 changed file with 24 additions and 1 deletion.
25 changes: 24 additions & 1 deletion experiments/datascope/experiments/pipelines/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,38 @@ def predict_proba(self, X: Union[NDArray, DataFrame]) -> NDArray:


class EvalLoggerCallback(TrainerCallback):
def __init__(self, logger: Optional[Logger] = None, prefix: str = "") -> None:
def __init__(
self,
logger: Optional[Logger] = None,
prefix: str = "",
trim_metric_prefix: Optional[str] = None,
drop_metrics_with_suffix: Optional[List[str]] = None,
) -> None:
self.logger = logger
self.prefix = prefix
self.trim_metric_prefix = trim_metric_prefix
self.drop_metrics_with_suffix = drop_metrics_with_suffix

def on_evaluate(
self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics: Dict[str, float], **kwargs
):
if self.logger is not None:

# Add prefix to the message.
prefix = "[%s] " % self.prefix if self.prefix else ""

# Drop metrics with specified suffix is provided.
if self.drop_metrics_with_suffix is not None:
metrics = {
k: v
for k, v in metrics.items()
if not any(k.endswith(suffix) for suffix in self.drop_metrics_with_suffix)
}
# Trim metric prefix if specified.
if self.trim_metric_prefix is not None:
metrics = {k.removeprefix(self.trim_metric_prefix): v for k, v in metrics.items()}

# Construct the message and log it.
message = prefix + ", ".join(["%s=%.3f" % (k, v) for k, v in metrics.items()])
self.logger.debug(message)

Expand Down

0 comments on commit ac37c03

Please sign in to comment.