Feat/cross validation (#121)
* define interface

* basic ho iterator

* move obtaining data for train from node optimizer to modules themselves

* stage progress

* implement cv iterator

* minor bug fix

* implement cv iterator for decision node

* move cv iteration to base module definition

* implement cv iterator for embedding node

* add training to `score_ho` of each node

* properly define base module

* fix codestyle

* remove regexp node

* remove regexp validator

* fix typing problems (except `DataHandler._split_cv`)

* add ignore oos decorator

* fix codestyle

* fix typing

* add oos handling to cv iterator

* remove `DataHandler.dump()`

* minor bug fix

* implement splitting to cv folds

* fix codestyle

* remove regex tests

* bug fix

* bug fix

* update tests

* fix typing

* bug fix

* basic test on cv folding

* add tests for metrics to ignore oos samples

* add tests for cv iterator

* fix codestyle

* minor bug fix

* fix codestyle

* add test for cv

* bug fix

* implement cv iterator for description scorer

* refactor cv iterator for description node

* fix typing

* add cache cleaning before refitting

* bug fix

* implement refitting the whole pipeline with all train data

* fix typing

* bug fix

* fix typing

* respond to samoed

* create `ValidationType` in `autointent.custom_types`

* fix docstring

* properly expose `n_folds` argument

* `ValidationType` -> `ValidationScheme`

* `make schema`
voorhs authored Feb 10, 2025
1 parent 6a478cd commit 0979234
Showing 42 changed files with 545 additions and 629 deletions.
32 changes: 29 additions & 3 deletions autointent/_pipeline/_pipeline.py
@@ -10,7 +10,7 @@

from autointent import Context, Dataset
from autointent.configs import CrossEncoderConfig, EmbedderConfig, InferenceNodeConfig, LoggingConfig, VectorIndexConfig
from autointent.custom_types import ListOfGenericLabels, NodeType
from autointent.custom_types import ListOfGenericLabels, NodeType, ValidationScheme
from autointent.metrics import PREDICTION_METRICS_MULTILABEL
from autointent.nodes import InferenceNode, NodeOptimizer
from autointent.nodes.schemes import OptimizationConfig
@@ -122,7 +122,9 @@ def _is_inference(self) -> bool:
"""
return isinstance(self.nodes[NodeType.scoring], InferenceNode)

def fit(self, dataset: Dataset) -> Context:
def fit(
self, dataset: Dataset, scheme: ValidationScheme = "ho", n_folds: int = 3, refit_after: bool = False
) -> Context:
"""
Optimize the pipeline from dataset.
@@ -134,7 +136,7 @@ def fit(self, dataset: Dataset) -> Context:
raise RuntimeError(msg)

context = Context()
context.set_dataset(dataset)
context.set_dataset(dataset, scheme, n_folds)
context.configure_logging(self.logging_config)
context.configure_vector_index(self.vector_index_config, self.embedder_config)
context.configure_cross_encoder(self.cross_encoder_config)
@@ -150,6 +152,9 @@

self.nodes = {node.node_type: node for node in nodes_list}

if refit_after:
self._refit(context)

predictions = self.predict(context.data_handler.test_utterances())
for metric_name, metric in PREDICTION_METRICS_MULTILABEL.items():
context.optimization_info.pipeline_metrics[metric_name] = metric(
@@ -220,6 +225,27 @@ def predict(self, utterances: list[str]) -> ListOfGenericLabels:
scores = scoring_module.predict(utterances)
return decision_module.predict(scores)

def _refit(self, context: Context) -> None:
"""
Fit the pipeline of already-selected modules on all train data.

:param context: context object to take data from
"""
if not self._is_inference():
msg = "Pipeline in optimization mode cannot perform inference"
raise RuntimeError(msg)

scoring_module: ScoringModule = self.nodes[NodeType.scoring].module # type: ignore[assignment,union-attr]
decision_module: DecisionModule = self.nodes[NodeType.decision].module # type: ignore[assignment,union-attr]

context.data_handler.prepare_for_refit()

scoring_module.fit(*scoring_module.get_train_data(context))
scores = scoring_module.predict(context.data_handler.train_utterances(1))

decision_module.fit(scores, context.data_handler.train_labels(1), context.data_handler.tags)

def predict_with_metadata(self, utterances: list[str]) -> InferencePipelineOutput:
"""
Predict the labels for the utterances with metadata.
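For reference, a minimal usage sketch of the new `fit()` signature. Only the `scheme`, `n_folds`, and `refit_after` parameters come from this diff; the dataset loading and pipeline construction calls are assumptions.

```python
from autointent import Dataset, Pipeline  # Pipeline import path is an assumption

dataset = Dataset.from_json("intents.json")       # hypothetical loader
pipeline = Pipeline.from_search_space("default")  # hypothetical constructor

# Hold-out validation (the default, as before):
context = pipeline.fit(dataset)

# 5-fold cross-validation, then refit the selected modules on all train data:
context = pipeline.fit(dataset, scheme="cv", n_folds=5, refit_after=True)
```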
7 changes: 7 additions & 0 deletions autointent/_ranker.py
@@ -3,6 +3,7 @@
Can be used to rank retrieved sentences by meaning closeness to provided utterance.
"""

import gc
import itertools as it
import json
import logging
@@ -272,3 +273,9 @@ def load(cls, path: Path) -> "Ranker":
metadata: CrossEncoderMetadata = json.load(file)

return cls(**metadata, classifier_head=clf)

def clear_ram(self) -> None:
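"""Move the cross-encoder to CPU and free GPU memory."""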
self.cross_encoder.model.cpu()
del self.cross_encoder
gc.collect()
torch.cuda.empty_cache()
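A sketch of the intended lifecycle: per the commit history above, the cache is cleared before refitting. `Ranker.load()` appears in this diff; everything between loading and cleanup is elided.

```python
ranker = Ranker.load(path)  # path: a Path to a previously dumped Ranker

# ... score utterance pairs with the cross-encoder ...

ranker.clear_ram()  # model to CPU, reference dropped, CUDA cache emptied
```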
6 changes: 6 additions & 0 deletions autointent/configs/_optimization.py
@@ -4,6 +4,8 @@

from pydantic import BaseModel, Field

from autointent.custom_types import ValidationScheme

from ._name import get_run_name


@@ -12,6 +14,10 @@ class DataConfig(BaseModel):

train_path: str | Path
"""Path to the training data. Can be local path or HF repo."""
scheme: ValidationScheme
"""Hold-out or cross-validation."""
n_folds: int = 3
"""Number of folds in cross-validation."""


class TaskConfig(BaseModel):
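A sketch of the extended config. The field names and defaults come from this diff; the path value and import path are illustrative assumptions.

```python
from autointent.configs import DataConfig  # import path is an assumption

config = DataConfig(
    train_path="my-org/my-intents",  # local path or HF repo (illustrative)
    scheme="cv",                     # ValidationScheme: "ho" (hold-out) or "cv" (cross-validation)
    n_folds=5,                       # defaults to 3
)
```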
10 changes: 6 additions & 4 deletions autointent/context/_context.py
@@ -16,6 +16,7 @@
LoggingConfig,
VectorIndexConfig,
)
from autointent.custom_types import ValidationScheme

from ._utils import NumpyEncoder, load_dataset
from .data_handler import DataHandler
@@ -81,11 +82,10 @@ def configure_data(self, config: DataConfig) -> None:
:param config: Configuration for the data handling process.
"""
self.data_handler = DataHandler(
dataset=load_dataset(config.train_path),
random_seed=self.seed,
dataset=load_dataset(config.train_path), random_seed=self.seed, scheme=config.scheme
)

def set_dataset(self, dataset: Dataset) -> None:
def set_dataset(self, dataset: Dataset, scheme: ValidationScheme = "ho", n_folds: int = 3) -> None:
"""
Set the datasets for training, validation and testing.
@@ -94,6 +94,8 @@ def set_dataset(self, dataset: Dataset) -> None:
self.data_handler = DataHandler(
dataset=dataset,
random_seed=self.seed,
scheme=scheme,
n_folds=n_folds,
)

def get_inference_config(self) -> dict[str, Any]:
@@ -137,7 +139,7 @@ def dump(self) -> None:
# self._logger.info(make_report(optimization_results, nodes=nodes))

# dump train and test data splits
self.data_handler.dump(logs_dir / "dataset.json")
self.data_handler.dataset.to_json(logs_dir / "dataset.json")

self._logger.info("logs and other assets are saved to %s", logs_dir)

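The same parameters flow through `Context.set_dataset()` directly; a minimal sketch (the default `Context()` construction is an assumption):

```python
from autointent import Context, Dataset

context = Context()  # assumed default construction
context.set_dataset(dataset, scheme="cv", n_folds=5)  # dataset: an autointent Dataset
```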
120 changes: 92 additions & 28 deletions autointent/context/data_handler/_data_handler.py
@@ -1,14 +1,14 @@
"""Data Handler file."""

import logging
from pathlib import Path
from collections.abc import Generator
from typing import TypedDict, cast

from datasets import concatenate_datasets
from transformers import set_seed

from autointent import Dataset
from autointent.custom_types import ListOfGenericLabels, Split
from autointent.custom_types import ListOfGenericLabels, ListOfLabels, Split, ValidationScheme

from ._stratification import split_dataset

@@ -26,10 +26,17 @@ class RegexPatterns(TypedDict):
"""Partial match regex patterns."""


class DataHandler:
class DataHandler: # TODO rename to Validator
"""Data handler class."""

def __init__(self, dataset: Dataset, random_seed: int = 0, split_train: bool = True) -> None:
def __init__(
self,
dataset: Dataset,
scheme: ValidationScheme = "ho",
split_train: bool = True,
random_seed: int = 0,
n_folds: int = 3,
) -> None:
"""
Initialize the data handler.
@@ -39,12 +46,18 @@ def __init__(self, dataset: Dataset, random_seed: int = 0, split_train: bool = True) -> None:
threshold search).
"""
set_seed(random_seed)
self.random_seed = random_seed

self.dataset = dataset

self.n_classes = self.dataset.n_classes
self.scheme = scheme
self.n_folds = n_folds

self._split(random_seed, split_train)
if scheme == "ho":
self._split_ho(split_train)
elif scheme == "cv":
self._split_cv()

self.regexp_patterns = [
RegexPatterns(
@@ -97,6 +110,9 @@ def train_labels(self, idx: int | None = None) -> ListOfGenericLabels:
split = f"{Split.TRAIN}_{idx}" if idx is not None else Split.TRAIN
return cast(ListOfGenericLabels, self.dataset[split][self.dataset.label_feature])

def train_labels_folded(self) -> list[ListOfGenericLabels]:
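"""Retrieve the labels of each cross-validation training fold."""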
return [self.train_labels(j) for j in range(self.n_folds)]

def validation_utterances(self, idx: int | None = None) -> list[str]:
"""
Retrieve validation utterances from the dataset.
@@ -153,28 +169,37 @@ def test_labels(self, idx: int | None = None) -> ListOfGenericLabels:
split = f"{Split.TEST}_{idx}" if idx is not None else Split.TEST
return cast(ListOfGenericLabels, self.dataset[split][self.dataset.label_feature])

def dump(self, filepath: str | Path) -> None:
"""
Save the dataset splits and intents to a JSON file.

:param filepath: The path to the file where the JSON data will be saved.
"""
self.dataset.to_json(filepath)

def validation_iterator(self) -> Generator[tuple[list[str], ListOfLabels, list[str], ListOfLabels]]:
if self.scheme == "ho":
msg = "Cannot call cross-validation on hold-out DataHandler"
raise RuntimeError(msg)

for j in range(self.n_folds):
val_utterances = self.train_utterances(j)
val_labels = self.train_labels(j)
train_folds = [i for i in range(self.n_folds) if i != j]
train_utterances = [ut for i_fold in train_folds for ut in self.train_utterances(i_fold)]
train_labels = [lab for i_fold in train_folds for lab in self.train_labels(i_fold)]

# filter out all OOS samples from train
train_utterances = [ut for ut, lab in zip(train_utterances, train_labels, strict=True) if lab is not None]
train_labels = [lab for lab in train_labels if lab is not None]
yield train_utterances, train_labels, val_utterances, val_labels  # type: ignore[misc]

def _split(self, random_seed: int, split_train: bool) -> None:
def _split_ho(self, split_train: bool) -> None:
has_validation_split = any(split.startswith(Split.VALIDATION) for split in self.dataset)

if split_train and Split.TRAIN in self.dataset:
self._split_train(random_seed)
self._split_train()

if Split.TEST not in self.dataset:
test_size = 0.1 if has_validation_split else 0.2
self._split_test(test_size, random_seed)
self._split_test(test_size)

if not has_validation_split:
self._split_validation_from_train(random_seed)
self._split_validation_from_train()
elif Split.VALIDATION in self.dataset:
self._split_validation(random_seed)
self._split_validation()

for split in self.dataset:
n_classes_split = self.dataset.get_n_classes(split)
@@ -185,7 +210,7 @@ def _split(self, random_seed: int, split_train: bool) -> None:
)
raise ValueError(message)

def _split_train(self, random_seed: int) -> None:
def _split_train(self) -> None:
"""
Split into two sets.
@@ -195,12 +220,12 @@ def _split_train(self, random_seed: int) -> None:
self.dataset,
split=Split.TRAIN,
test_size=0.5,
random_seed=random_seed,
random_seed=self.random_seed,
allow_oos_in_train=False, # only train data for decision node should contain OOS
)
self.dataset.pop(Split.TRAIN)

def _split_validation(self, random_seed: int) -> None:
def _split_validation(self) -> None:
"""
Split into two sets.
@@ -210,27 +235,49 @@ def _split_validation(self, random_seed: int) -> None:
self.dataset,
split=Split.VALIDATION,
test_size=0.5,
random_seed=random_seed,
random_seed=self.random_seed,
allow_oos_in_train=False, # only val data for decision node should contain OOS
)
self.dataset.pop(Split.VALIDATION)

def _split_validation_from_test(self, random_seed: int) -> None:
def _split_validation_from_test(self) -> None:
self.dataset[Split.TEST], self.dataset[Split.VALIDATION] = split_dataset(
self.dataset,
split=Split.TEST,
test_size=0.5,
random_seed=random_seed,
random_seed=self.random_seed,
allow_oos_in_train=True, # both test and validation splits can contain OOS
)

def _split_validation_from_train(self, random_seed: int) -> None:
def _split_cv(self) -> None:
extra_splits = [split_name for split_name in self.dataset if split_name not in [Split.TRAIN, Split.TEST]]
if extra_splits:
self.dataset[Split.TRAIN] = concatenate_datasets(
[self.dataset.pop(split_name) for split_name in extra_splits]
)

if Split.TEST not in self.dataset:
self.dataset[Split.TRAIN], self.dataset[Split.TEST] = split_dataset(
self.dataset, split=Split.TRAIN, test_size=0.2, random_seed=self.random_seed, allow_oos_in_train=True
)

for j in range(self.n_folds - 1):
self.dataset[Split.TRAIN], self.dataset[f"{Split.TRAIN}_{j}"] = split_dataset(
self.dataset,
split=Split.TRAIN,
test_size=1 / (self.n_folds - j),
random_seed=self.random_seed,
allow_oos_in_train=True,
)
self.dataset[f"{Split.TRAIN}_{self.n_folds-1}"] = self.dataset.pop(Split.TRAIN)

def _split_validation_from_train(self) -> None:
if Split.TRAIN in self.dataset:
self.dataset[Split.TRAIN], self.dataset[Split.VALIDATION] = split_dataset(
self.dataset,
split=Split.TRAIN,
test_size=0.2,
random_seed=random_seed,
random_seed=self.random_seed,
allow_oos_in_train=True,
)
else:
@@ -239,27 +286,44 @@ def _split_validation_from_train(self, random_seed: int) -> None:
self.dataset,
split=f"{Split.TRAIN}_{idx}",
test_size=0.2,
random_seed=random_seed,
random_seed=self.random_seed,
allow_oos_in_train=idx == 1, # for decision node it's ok to have oos in train
)

def _split_test(self, test_size: float, random_seed: int) -> None:
def _split_test(self, test_size: float) -> None:
"""Obtain test set from train."""
self.dataset[f"{Split.TRAIN}_0"], self.dataset[f"{Split.TEST}_0"] = split_dataset(
self.dataset,
split=f"{Split.TRAIN}_0",
test_size=test_size,
random_seed=random_seed,
random_seed=self.random_seed,
)
self.dataset[f"{Split.TRAIN}_1"], self.dataset[f"{Split.TEST}_1"] = split_dataset(
self.dataset,
split=f"{Split.TRAIN}_1",
test_size=test_size,
random_seed=random_seed,
random_seed=self.random_seed,
allow_oos_in_train=True,
)
self.dataset[Split.TEST] = concatenate_datasets(
[self.dataset[f"{Split.TEST}_0"], self.dataset[f"{Split.TEST}_1"]],
)
self.dataset.pop(f"{Split.TEST}_0")
self.dataset.pop(f"{Split.TEST}_1")

def prepare_for_refit(self) -> None:
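"""Merge all cross-validation folds back together and re-split once for refitting on the full train data."""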
if self.scheme == "ho":
return

train_folds = [split_name for split_name in self.dataset if split_name.startswith(Split.TRAIN)]
self.dataset[Split.TRAIN] = concatenate_datasets([self.dataset.pop(name) for name in train_folds])

self.dataset[f"{Split.TRAIN}_0"], self.dataset[f"{Split.TRAIN}_1"] = split_dataset(
self.dataset,
split=Split.TRAIN,
test_size=0.5,
random_seed=self.random_seed,
allow_oos_in_train=False,
)

self.dataset.pop(Split.TRAIN)
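A sketch of consuming the new iterator during cross-validation. The `DataHandler` construction and the yielded tuple order come from this diff; the module API and metric call are assumptions.

```python
from autointent.context.data_handler import DataHandler  # import path is an assumption

handler = DataHandler(dataset, scheme="cv", n_folds=5)
for train_utts, train_labels, val_utts, val_labels in handler.validation_iterator():
    module.fit(train_utts, train_labels)       # hypothetical module API
    preds = module.predict(val_utts)
    fold_metric = score_fn(val_labels, preds)  # hypothetical metric
```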
3 changes: 3 additions & 0 deletions autointent/context/optimization_info/_data_models.py
@@ -42,6 +42,9 @@ class ScorerArtifact(Artifact):
train_scores: NDArray[np.float64] | None = Field(None, description="Scorer outputs for train utterances")
validation_scores: NDArray[np.float64] | None = Field(None, description="Scorer outputs for validation utterances")
test_scores: NDArray[np.float64] | None = Field(None, description="Scorer outputs for test utterances")
folded_scores: list[NDArray[np.float64]] | None = Field(
None, description="Scores for each fold from cross-validation"
)


class DecisionArtifact(Artifact):
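A sketch of consuming the new `folded_scores` field; the aggregation shown is illustrative, not the library's own.

```python
import numpy as np

def stack_folds(artifact: "ScorerArtifact") -> np.ndarray:
    """Concatenate per-fold scorer outputs into one (n_samples, n_classes) array."""
    assert artifact.folded_scores is not None  # populated only under scheme="cv"
    return np.concatenate(artifact.folded_scores, axis=0)
```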
