-
Notifications
You must be signed in to change notification settings - Fork 124
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #138 from staru09/main
feat: UDFs support added
- Loading branch information
Showing
5 changed files
with
418 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
from typing import Any, Dict, List, Optional, Tuple | ||
from concurrent.futures import ThreadPoolExecutor | ||
from docetl.operations.base import BaseOperation | ||
from docetl.operations.utils import RichLoopBar | ||
|
||
class CodeMapOperation(BaseOperation): | ||
class schema(BaseOperation.schema): | ||
type: str = "code_map" | ||
code: str | ||
drop_keys: Optional[List[str]] = None | ||
|
||
def syntax_check(self) -> None: | ||
config = self.schema(**self.config) | ||
try: | ||
namespace = {} | ||
exec(config.code, namespace) | ||
if "transform" not in namespace: | ||
raise ValueError("Code must define a 'transform' function") | ||
if not callable(namespace["transform"]): | ||
raise ValueError("'transform' must be a callable function") | ||
except Exception as e: | ||
raise ValueError(f"Invalid code configuration: {str(e)}") | ||
|
||
def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: | ||
namespace = {} | ||
exec(self.config["code"], namespace) | ||
transform_fn = namespace["transform"] | ||
|
||
results = [] | ||
with ThreadPoolExecutor() as executor: | ||
futures = [executor.submit(transform_fn, doc) for doc in input_data] | ||
pbar = RichLoopBar( | ||
range(len(futures)), | ||
desc=f"Processing {self.config['name']} (code_map)", | ||
console=self.console, | ||
) | ||
for i in pbar: | ||
result = futures[i].result() | ||
if self.config.get("drop_keys"): | ||
result = { | ||
k: v for k, v in result.items() | ||
if k not in self.config["drop_keys"] | ||
} | ||
doc = input_data[i] | ||
merged_result = {**doc, **result} | ||
results.append(merged_result) | ||
|
||
return results, 0.0 | ||
|
||
class CodeReduceOperation(BaseOperation): | ||
class schema(BaseOperation.schema): | ||
type: str = "code_reduce" | ||
code: str | ||
|
||
def syntax_check(self) -> None: | ||
config = self.schema(**self.config) | ||
try: | ||
namespace = {} | ||
exec(config.code, namespace) | ||
if "transform" not in namespace: | ||
raise ValueError("Code must define a 'transform' function") | ||
if not callable(namespace["transform"]): | ||
raise ValueError("'transform' must be a callable function") | ||
except Exception as e: | ||
raise ValueError(f"Invalid code configuration: {str(e)}") | ||
|
||
def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: | ||
namespace = {} | ||
exec(self.config["code"], namespace) | ||
reduce_fn = namespace["transform"] | ||
|
||
reduce_keys = self.config.get("reduce_key", "_all") | ||
if not isinstance(reduce_keys, list): | ||
reduce_keys = [reduce_keys] | ||
|
||
if reduce_keys == ["_all"] or reduce_keys == "_all": | ||
grouped_data = [("_all", input_data)] | ||
else: | ||
def get_group_key(item): | ||
return tuple(item[key] for key in reduce_keys) | ||
|
||
grouped_data = {} | ||
for item in input_data: | ||
key = get_group_key(item) | ||
if key not in grouped_data: | ||
grouped_data[key] = [] | ||
grouped_data[key].append(item) | ||
|
||
grouped_data = list(grouped_data.items()) | ||
|
||
results = [] | ||
with ThreadPoolExecutor() as executor: | ||
futures = [executor.submit(reduce_fn, group) for _, group in grouped_data] | ||
pbar = RichLoopBar( | ||
range(len(futures)), | ||
desc=f"Processing {self.config['name']} (code_reduce)", | ||
console=self.console, | ||
) | ||
for i, (key, group) in zip(pbar, grouped_data): | ||
result = futures[i].result() | ||
|
||
# Apply pass-through at the group level | ||
if self.config.get("pass_through", False) and group: | ||
for k, v in group[0].items(): | ||
if k not in result: | ||
result[k] = v | ||
|
||
results.append(result) | ||
|
||
return results, 0.0 | ||
|
||
class CodeFilterOperation(BaseOperation): | ||
class schema(BaseOperation.schema): | ||
type: str = "code_filter" | ||
code: str | ||
|
||
def syntax_check(self) -> None: | ||
config = self.schema(**self.config) | ||
try: | ||
namespace = {} | ||
exec(config.code, namespace) | ||
if "transform" not in namespace: | ||
raise ValueError("Code must define a 'transform' function") | ||
if not callable(namespace["transform"]): | ||
raise ValueError("'transform' must be a callable function") | ||
except Exception as e: | ||
raise ValueError(f"Invalid code configuration: {str(e)}") | ||
|
||
def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: | ||
namespace = {} | ||
exec(self.config["code"], namespace) | ||
filter_fn = namespace["transform"] | ||
|
||
results = [] | ||
with ThreadPoolExecutor() as executor: | ||
futures = [executor.submit(filter_fn, doc) for doc in input_data] | ||
pbar = RichLoopBar( | ||
range(len(futures)), | ||
desc=f"Processing {self.config['name']} (code_filter)", | ||
console=self.console, | ||
) | ||
for i in pbar: | ||
should_keep = futures[i].result() | ||
if should_keep: | ||
results.append(input_data[i]) | ||
return results, 0.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
# Code Operations | ||
|
||
Code operations in DocETL allow you to define transformations using Python code rather than LLM prompts. This is useful when you need deterministic processing, complex calculations, or want to leverage existing Python libraries. | ||
|
||
## Motivation | ||
|
||
While LLM-powered operations are powerful for natural language tasks, sometimes you need operations that are: | ||
|
||
- Deterministic and reproducible | ||
- Integrated with external Python libraries | ||
- Focused on structured data transformations | ||
- Math-based or computationally intensive (something an LLM is not good at) | ||
|
||
Code operations provide a way to handle these cases efficiently without LLM overhead. | ||
|
||
## Types of Code Operations | ||
|
||
### Code Map Operation | ||
|
||
The Code Map operation applies a Python function to each item in your input data independently. | ||
|
||
??? example "Example Code Map Operation" | ||
|
||
```yaml | ||
- name: extract_keywords | ||
type: code_map | ||
code: | | ||
def transform(doc) -> dict: | ||
# Your transformation code here | ||
keywords = doc['text'].lower().split() | ||
return { | ||
'keywords': keywords, | ||
'keyword_count': len(keywords) | ||
} | ||
``` | ||
|
||
The code must define a `transform` function that takes a single document as input and returns a dictionary of transformed values. | ||
|
||
### Code Reduce Operation | ||
|
||
The Code Reduce operation aggregates multiple items into a single result using a Python function. | ||
|
||
??? example "Example Code Reduce Operation" | ||
|
||
```yaml | ||
- name: aggregate_stats | ||
type: code_reduce | ||
reduce_key: category | ||
code: | | ||
def transform(items) -> dict: | ||
total = sum(item['value'] for item in items) | ||
avg = total / len(items) | ||
return { | ||
'total': total, | ||
'average': avg, | ||
'count': len(items) | ||
} | ||
``` | ||
|
||
The transform function for reduce operations takes a list of items as input and returns a single aggregated result. | ||
|
||
### Code Filter Operation | ||
|
||
The Code Filter operation allows you to filter items based on custom Python logic. | ||
|
||
??? example "Example Code Filter Operation" | ||
|
||
```yaml | ||
- name: filter_valid_entries | ||
type: code_filter | ||
code: | | ||
def transform(doc) -> bool: | ||
# Return True to keep the document, False to filter it out | ||
return doc['score'] >= 0.5 and len(doc['text']) > 100 | ||
``` | ||
|
||
The transform function should return True for items to keep and False for items to filter out. | ||
|
||
## Configuration | ||
|
||
### Required Parameters | ||
|
||
- type: Must be "code_map", "code_reduce", or "code_filter" | ||
- code: Python code containing the transform function. For map, the function must take a single document as input and return a document (a dictionary). For reduce, the function must take a list of documents as input and return a single aggregated document (a dictionary). For filter, the function must take a single document as input and return a boolean value indicating whether to keep the document. | ||
|
||
### Optional Parameters | ||
|
||
| Parameter | Description | Default | | ||
|-----------|-------------|---------| | ||
| drop_keys | List of keys to remove from output (code_map only) | None | | ||
| reduce_key | Key(s) to group by (code_reduce only) | "_all" | | ||
| pass_through | Pass through unmodified keys from first item in group (code_reduce only) | false | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.