From cf6d06f71395cdedda60c9d7fdf243a06b381b41 Mon Sep 17 00:00:00 2001 From: Shreya Shankar Date: Sat, 12 Oct 2024 12:54:23 -0700 Subject: [PATCH] change samples interface --- docetl/operations/sample.py | 25 ++++++++++++++++--------- docs/operators/sample.md | 14 ++++++-------- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/docetl/operations/sample.py b/docetl/operations/sample.py index 083870ff..91ebc344 100644 --- a/docetl/operations/sample.py +++ b/docetl/operations/sample.py @@ -1,6 +1,5 @@ -import sklearn.model_selection from typing import Any, Dict, List, Optional, Tuple -from .base import BaseOperation +from docetl.operations.base import BaseOperation class SampleOperation(BaseOperation): @@ -20,7 +19,7 @@ def syntax_check(self) -> None: TypeError: If configuration values have incorrect types. """ pass - + def execute( self, input_data: List[Dict], is_build: bool = False ) -> Tuple[List[Dict], float]: @@ -39,15 +38,23 @@ def execute( samples = self.config["samples"] if isinstance(samples, list): - output_data = [input_data[sample] - for sample in samples] + keys = list(samples[0].keys()) + key_to_doc = {tuple([doc[key] for key in keys]): doc for doc in input_data} + + output_data = [ + key_to_doc[tuple([sample[key] for key in keys])] for sample in samples + ] else: - stratify=None + stratify = None if "stratify" in self.config: stratify = [data[self.config["stratify"]] for data in input_data] + + import sklearn.model_selection + output_data, dummy = sklearn.model_selection.train_test_split( input_data, - train_size = samples, - random_state = self.config.get("random_state", None), - stratify = stratify) + train_size=samples, + random_state=self.config.get("random_state", None), + stratify=stratify, + ) return output_data, 0 diff --git a/docs/operators/sample.md b/docs/operators/sample.md index 3d5c0a91..62d852a2 100644 --- a/docs/operators/sample.md +++ b/docs/operators/sample.md @@ -10,9 +10,7 @@ comfortably debug its prompt. 
Once it seems to be working, you can remove the sample operation. You can then repeat this for each operation you add while developing your pipeline! - - -## 🚀 Example: +## 🚀 Example: ```yaml - name: cluster_concepts @@ -33,11 +31,11 @@ sample each value of the `category` key equally. - `name`: A unique name for the operation. - `type`: Must be set to "sample". -- `samples`: Either a list of sample indices to just return those samples, an integer count of samples, or a float fraction of samples. +- `samples`: Either a list of key-value pairs representing document ids and values, an integer count of samples, or a float fraction of samples. ## Optional Parameters -| Parameter | Description | Default | -| ------------------------- | -------------------------------------------------------------------------------- | ----------------------------- | -| `random_state | An integer to seed the random generator with | Use the (numpy) global random state -| `stratify` | The key to stratify by | | +| Parameter | Description | Default | +| ------------- | -------------------------------------------- | ----------------------------------- | +| `random_state` | An integer to seed the random generator with | Use the (numpy) global random state | +| `stratify` | The key to stratify by | |