Skip to content

Commit

Permalink
Merge pull request #138 from staru09/main
Browse files Browse the repository at this point in the history
feat: UDFs support added
  • Loading branch information
shreyashankar authored Nov 1, 2024
2 parents 8fb2123 + 33c2436 commit 182246a
Show file tree
Hide file tree
Showing 5 changed files with 418 additions and 0 deletions.
146 changes: 146 additions & 0 deletions docetl/operations/code_operations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
from typing import Any, Dict, List, Optional, Tuple
from concurrent.futures import ThreadPoolExecutor
from docetl.operations.base import BaseOperation
from docetl.operations.utils import RichLoopBar

class CodeMapOperation(BaseOperation):
class schema(BaseOperation.schema):
type: str = "code_map"
code: str
drop_keys: Optional[List[str]] = None

def syntax_check(self) -> None:
config = self.schema(**self.config)
try:
namespace = {}
exec(config.code, namespace)
if "transform" not in namespace:
raise ValueError("Code must define a 'transform' function")
if not callable(namespace["transform"]):
raise ValueError("'transform' must be a callable function")
except Exception as e:
raise ValueError(f"Invalid code configuration: {str(e)}")

def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
namespace = {}
exec(self.config["code"], namespace)
transform_fn = namespace["transform"]

results = []
with ThreadPoolExecutor() as executor:
futures = [executor.submit(transform_fn, doc) for doc in input_data]
pbar = RichLoopBar(
range(len(futures)),
desc=f"Processing {self.config['name']} (code_map)",
console=self.console,
)
for i in pbar:
result = futures[i].result()
if self.config.get("drop_keys"):
result = {
k: v for k, v in result.items()
if k not in self.config["drop_keys"]
}
doc = input_data[i]
merged_result = {**doc, **result}
results.append(merged_result)

return results, 0.0

class CodeReduceOperation(BaseOperation):
class schema(BaseOperation.schema):
type: str = "code_reduce"
code: str

def syntax_check(self) -> None:
config = self.schema(**self.config)
try:
namespace = {}
exec(config.code, namespace)
if "transform" not in namespace:
raise ValueError("Code must define a 'transform' function")
if not callable(namespace["transform"]):
raise ValueError("'transform' must be a callable function")
except Exception as e:
raise ValueError(f"Invalid code configuration: {str(e)}")

def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
namespace = {}
exec(self.config["code"], namespace)
reduce_fn = namespace["transform"]

reduce_keys = self.config.get("reduce_key", "_all")
if not isinstance(reduce_keys, list):
reduce_keys = [reduce_keys]

if reduce_keys == ["_all"] or reduce_keys == "_all":
grouped_data = [("_all", input_data)]
else:
def get_group_key(item):
return tuple(item[key] for key in reduce_keys)

grouped_data = {}
for item in input_data:
key = get_group_key(item)
if key not in grouped_data:
grouped_data[key] = []
grouped_data[key].append(item)

grouped_data = list(grouped_data.items())

results = []
with ThreadPoolExecutor() as executor:
futures = [executor.submit(reduce_fn, group) for _, group in grouped_data]
pbar = RichLoopBar(
range(len(futures)),
desc=f"Processing {self.config['name']} (code_reduce)",
console=self.console,
)
for i, (key, group) in zip(pbar, grouped_data):
result = futures[i].result()

# Apply pass-through at the group level
if self.config.get("pass_through", False) and group:
for k, v in group[0].items():
if k not in result:
result[k] = v

results.append(result)

return results, 0.0

class CodeFilterOperation(BaseOperation):
class schema(BaseOperation.schema):
type: str = "code_filter"
code: str

def syntax_check(self) -> None:
config = self.schema(**self.config)
try:
namespace = {}
exec(config.code, namespace)
if "transform" not in namespace:
raise ValueError("Code must define a 'transform' function")
if not callable(namespace["transform"]):
raise ValueError("'transform' must be a callable function")
except Exception as e:
raise ValueError(f"Invalid code configuration: {str(e)}")

def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
namespace = {}
exec(self.config["code"], namespace)
filter_fn = namespace["transform"]

results = []
with ThreadPoolExecutor() as executor:
futures = [executor.submit(filter_fn, doc) for doc in input_data]
pbar = RichLoopBar(
range(len(futures)),
desc=f"Processing {self.config['name']} (code_filter)",
console=self.console,
)
for i in pbar:
should_keep = futures[i].result()
if should_keep:
results.append(input_data[i])
return results, 0.0
92 changes: 92 additions & 0 deletions docs/operators/code.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Code Operations

Code operations in DocETL allow you to define transformations using Python code rather than LLM prompts. This is useful when you need deterministic processing, complex calculations, or want to leverage existing Python libraries.

## Motivation

While LLM-powered operations are powerful for natural language tasks, sometimes you need operations that are:

- Deterministic and reproducible
- Integrated with external Python libraries
- Focused on structured data transformations
- Math-based or computationally intensive (something an LLM is not good at)

Code operations provide a way to handle these cases efficiently without LLM overhead.

## Types of Code Operations

### Code Map Operation

The Code Map operation applies a Python function to each item in your input data independently.

??? example "Example Code Map Operation"

```yaml
- name: extract_keywords
type: code_map
code: |
def transform(doc) -> dict:
# Your transformation code here
keywords = doc['text'].lower().split()
return {
'keywords': keywords,
'keyword_count': len(keywords)
}
```

The code must define a `transform` function that takes a single document as input and returns a dictionary of transformed values.

### Code Reduce Operation

The Code Reduce operation aggregates multiple items into a single result using a Python function.

??? example "Example Code Reduce Operation"

```yaml
- name: aggregate_stats
type: code_reduce
reduce_key: category
code: |
def transform(items) -> dict:
total = sum(item['value'] for item in items)
avg = total / len(items)
return {
'total': total,
'average': avg,
'count': len(items)
}
```

The transform function for reduce operations takes a list of items as input and returns a single aggregated result.

### Code Filter Operation

The Code Filter operation allows you to filter items based on custom Python logic.

??? example "Example Code Filter Operation"

```yaml
- name: filter_valid_entries
type: code_filter
code: |
def transform(doc) -> bool:
# Return True to keep the document, False to filter it out
return doc['score'] >= 0.5 and len(doc['text']) > 100
```

The transform function should return True for items to keep and False for items to filter out.

## Configuration

### Required Parameters

- type: Must be "code_map", "code_reduce", or "code_filter"
- code: Python code containing the transform function. For map, the function must take a single document as input and return a document (a dictionary). For reduce, the function must take a list of documents as input and return a single aggregated document (a dictionary). For filter, the function must take a single document as input and return a boolean value indicating whether to keep the document.

### Optional Parameters

| Parameter | Description | Default |
|-----------|-------------|---------|
| drop_keys | List of keys to remove from output (code_map only) | None |
| reduce_key | Key(s) to group by (code_reduce only) | "_all" |
| pass_through | Pass through unmodified keys from first item in group (code_reduce only) | false |
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ nav:
- Gather: operators/gather.md
- Unnest: operators/unnest.md
- Sample: operators/sample.md
- Code: operators/code.md
- Optimization:
- Overview: optimization/overview.md
- Example: optimization/example.md
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ gather = "docetl.operations.gather:GatherOperation"
cluster = "docetl.operations.cluster:ClusterOperation"
sample = "docetl.operations.sample:SampleOperation"
link_resolve = "docetl.operations.link_resolve:LinkResolveOperation"
code_map = "docetl.operations.code_operations:CodeMapOperation"
code_reduce = "docetl.operations.code_operations:CodeReduceOperation"
code_filter = "docetl.operations.code_operations:CodeFilterOperation"

[tool.poetry.plugins."docetl.parser"]
llama_index_simple_directory_reader = "docetl.parsing_tools:llama_index_simple_directory_reader"
Expand Down
Loading

0 comments on commit 182246a

Please sign in to comment.