change samples interface
shreyashankar committed Oct 12, 2024
1 parent 4a12935 commit cf6d06f
Showing 2 changed files with 22 additions and 17 deletions.
25 changes: 16 additions & 9 deletions docetl/operations/sample.py
@@ -1,6 +1,5 @@
-import sklearn.model_selection
 from typing import Any, Dict, List, Optional, Tuple
-from .base import BaseOperation
+from docetl.operations.base import BaseOperation


 class SampleOperation(BaseOperation):
@@ -20,7 +19,7 @@ def syntax_check(self) -> None:
             TypeError: If configuration values have incorrect types.
         """
         pass
-
+
     def execute(
         self, input_data: List[Dict], is_build: bool = False
     ) -> Tuple[List[Dict], float]:
@@ -39,15 +38,23 @@ def execute(

         samples = self.config["samples"]
         if isinstance(samples, list):
-            output_data = [input_data[sample]
-                           for sample in samples]
+            keys = list(samples[0].keys())
+            key_to_doc = {tuple([doc[key] for key in keys]): doc for doc in input_data}
+
+            output_data = [
+                key_to_doc[tuple([sample[key] for key in keys])] for sample in samples
+            ]
         else:
-            stratify=None
+            stratify = None
             if "stratify" in self.config:
                 stratify = [data[self.config["stratify"]] for data in input_data]
+
+            import sklearn.model_selection
+
             output_data, dummy = sklearn.model_selection.train_test_split(
                 input_data,
-                train_size = samples,
-                random_state = self.config.get("random_state", None),
-                stratify = stratify)
+                train_size=samples,
+                random_state=self.config.get("random_state", None),
+                stratify=stratify,
+            )
         return output_data, 0
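
Note on the list branch of `execute` in this diff: it now matches documents by key values instead of by list position. The following is a minimal standalone sketch of that lookup, not the docetl implementation itself. The function name `select_by_keys` and the sample documents are hypothetical, and the sketch assumes, as the diff does, that every sample dict uses the same keys as the first one.

```python
from typing import Any, Dict, List


def select_by_keys(
    input_data: List[Dict[str, Any]], samples: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Return the documents that match each sample on the sample's keys."""
    # The keys of the first sample define the lookup key for all samples.
    keys = list(samples[0].keys())
    # Index documents by a tuple of their values for those keys.
    # If two documents share the same key values, the later one wins.
    key_to_doc = {tuple(doc[key] for key in keys): doc for doc in input_data}
    # A sample with no matching document raises KeyError.
    return [key_to_doc[tuple(sample[key] for key in keys)] for sample in samples]


# Hypothetical documents, used only for illustration.
docs = [
    {"id": 1, "text": "alpha"},
    {"id": 2, "text": "beta"},
    {"id": 3, "text": "gamma"},
]
selected = select_by_keys(docs, [{"id": 3}, {"id": 1}])
print(selected)
```

Output order follows the order of `samples`, not the order of the input. An empty `samples` list would fail at `samples[0]`, so a caller relying on this pattern may want to guard against that case.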
14 changes: 6 additions & 8 deletions docs/operators/sample.md
@@ -10,9 +10,7 @@ comfortably debug its prompt. Once it seems to be working, you can
 remove the sample operation. You can then repeat this for each
 operation you add while developing your pipeline!

-
-
-## 🚀 Example:
+## 🚀 Example:

 ```yaml
 - name: cluster_concepts
@@ -33,11 +31,11 @@ sample each value of the `category` key equally.

 - `name`: A unique name for the operation.
 - `type`: Must be set to "sample".
-- `samples`: Either a list of sample indices to just return those samples, an integer count of samples, or a float fraction of samples.
+- `samples`: Either a list of key-value pairs representing document ids and values, an integer count of samples, or a float fraction of samples.

 ## Optional Parameters

-| Parameter | Description | Default |
-| ------------------------- | -------------------------------------------------------------------------------- | ----------------------------- |
-| `random_state | An integer to seed the random generator with | Use the (numpy) global random state
-| `stratify` | The key to stratify by | |
+| Parameter | Description | Default |
+| ------------- | -------------------------------------------- | ----------------------------------- |
+| `random_state | An integer to seed the random generator with | Use the (numpy) global random state |
+| `stratify` | The key to stratify by | |
