diff --git a/docetl/operations/cluster.py b/docetl/operations/cluster.py index 09144e77..033e3bf4 100644 --- a/docetl/operations/cluster.py +++ b/docetl/operations/cluster.py @@ -1,3 +1,4 @@ +import numpy as np from jinja2 import Environment, Template from concurrent.futures import ThreadPoolExecutor from typing import Any, Dict, List, Optional, Tuple @@ -99,6 +100,9 @@ def execute( ) tree = self.agglomerative_cluster_of_embeddings(input_data, embeddings) + + if "collapse" in self.config: + tree = self.collapse_tree(tree, collapse = self.config["collapse"]) self.prompt_template = Template(self.config["summary_prompt"]) cost += self.annotate_clustering_tree(tree) @@ -122,7 +126,7 @@ def build_tree(i): # res["embedding"] = list(embeddings[i]) return res return { - "children": [ + "children": [ build_tree(cl.children_[i - nsamples, 0]), build_tree(cl.children_[i - nsamples, 1]), ], @@ -131,6 +135,40 @@ def build_tree(i): return build_tree(nsamples + len(cl.children_) - 1) + def get_tree_distances(self, t): + res = set() + if "distance" in t: + res.update(set([t["distance"] - child["distance"] for child in t["children"] if "distance" in child])) + if "children" in t: + for child in t["children"]: + res.update(self.get_tree_distances(child)) + return res + + def _collapse_tree(self, t, parent_dist = None, collapse = None): + if "children" in t: + if ( "distance" in t + and parent_dist is not None + and collapse is not None + and parent_dist - t["distance"] < collapse): + return [grandchild + for child in t["children"] + for grandchild in self._collapse_tree(child, parent_dist=parent_dist, collapse=collapse)] + else: + res = dict(t) + res["children"] = [grandchild + for idx, child in enumerate(t["children"]) + for grandchild in self._collapse_tree(child, parent_dist=t["distance"], collapse=collapse)] + return [res] + else: + return [t] + + def collapse_tree(self, tree, collapse = None): + if collapse is not None: + tree_distances = np.array(sorted(self.get_tree_distances(tree))) + collapse = tree_distances[int(len(tree_distances) * collapse)] + return self._collapse_tree(tree, collapse=collapse)[0] + + def annotate_clustering_tree(self, t): if "children" in t: with ThreadPoolExecutor(max_workers=self.max_batch_size) as executor: @@ -149,12 +187,8 @@ def annotate_clustering_tree(self, t): total_cost += futures[i].result() pbar.update(i) - assert len(t["children"]) == 2, ( - "Agglomerative clustering is supposed to generate clusters with 2 children each, but this cluster has %s" - % len(t["children"]) - ) prompt = self.prompt_template.render( - left=t["children"][0], right=t["children"][1] + inputs=t["children"] ) def validation_fn(response: Dict[str, Any]): @@ -167,31 +201,33 @@ def validation_fn(response: Dict[str, Any]): return output, True return output, False - output, cost, success = self.runner.api.call_llm_with_validation( - [{"role": "user", "content": prompt}], + response = self.runner.api.call_llm( model=self.config.get("model", self.default_model), - operation_type="cluster", - schema=self.config["summary_schema"], - llm_call_fn=lambda messages: self.runner.api.call_llm( - self.config.get("model", self.default_model), - "cluster", - messages, - self.config["summary_schema"], - tools=self.config.get("tools", None), - console=self.console, - timeout_seconds=self.config.get("timeout", 120), - max_retries_per_timeout=self.config.get( - "max_retries_per_timeout", 2 - ), + op_type="cluster", + messages=[{"role": "user", "content": prompt}], + output_schema=self.config["summary_schema"], + 
timeout_seconds=self.config.get("timeout", 120), + bypass_cache=self.config.get("bypass_cache", False), + max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2), + validation_config=( + { + "num_retries": self.num_retries_on_validate_failure, + "val_rule": self.config.get("validate", []), + "validation_fn": validation_fn, + } + if self.config.get("validate", None) + else None ), - validation_fn=validation_fn, - val_rule=self.config.get("validate", []), - num_retries=self.num_retries_on_validate_failure, - console=self.console, + verbose=self.config.get("verbose", False), ) - total_cost += cost - - t.update(output) + total_cost += response.total_cost + if response.validated: + output = self.runner.api.parse_llm_response( + response.response, + schema=self.config["summary_schema"], + manually_fix_errors=self.manually_fix_errors, + )[0] + t.update(output) return total_cost return 0 diff --git a/docetl/operations/equijoin.py b/docetl/operations/equijoin.py index e56a7b39..d6cf6469 100644 --- a/docetl/operations/equijoin.py +++ b/docetl/operations/equijoin.py @@ -82,9 +82,12 @@ def compare_pair( {"is_match": "bool"}, timeout_seconds=timeout_seconds, max_retries_per_timeout=max_retries_per_timeout, + bypass_cache=self.config.get("bypass_cache", False), ) - output = self.runner.api.parse_llm_response(response, {"is_match": "bool"})[0] - return output["is_match"], completion_cost(response) + output = self.runner.api.parse_llm_response( + response.response, {"is_match": "bool"} + )[0] + return output["is_match"], response.total_cost def syntax_check(self) -> None: """ diff --git a/docetl/operations/filter.py b/docetl/operations/filter.py index d67042e9..48037eef 100644 --- a/docetl/operations/filter.py +++ b/docetl/operations/filter.py @@ -5,13 +5,13 @@ from jinja2 import Template -from docetl.operations.base import BaseOperation +from docetl.operations.map import MapOperation from docetl.operations.utils import ( RichLoopBar, ) -class FilterOperation(BaseOperation): +class FilterOperation(MapOperation): def syntax_check(self) -> None: """ Checks the configuration of the FilterOperation for required keys and valid structure. 
@@ -110,77 +110,9 @@ def execute( ) ) - if self.status: - self.status.start() - - def _process_filter_item(item: Dict) -> Tuple[Optional[Dict], float]: - prompt_template = Template(self.config["prompt"]) - prompt = prompt_template.render(input=item) - - def validation_fn(response: Dict[str, Any]): - output = self.runner.api.parse_llm_response( - response, - self.config["output"]["schema"], - manually_fix_errors=self.manually_fix_errors, - )[0] - for key, value in item.items(): - if key not in self.config["output"]["schema"]: - output[key] = value - if self.runner.api.validate_output(self.config, output, self.console): - return output, True - return output, False - - output, cost, is_valid = self.runner.api.call_llm_with_validation( - [{"role": "user", "content": prompt}], - model=self.config.get("model", self.default_model), - operation_type="filter", - schema=self.config["output"]["schema"], - llm_call_fn=lambda messages: self.runner.api.call_llm( - self.config.get("model", self.default_model), - "filter", - messages, - self.config["output"]["schema"], - console=self.console, - timeout_seconds=self.config.get("timeout", 120), - max_retries_per_timeout=self.config.get( - "max_retries_per_timeout", 2 - ), - ), - validation_fn=validation_fn, - val_rule=self.config.get("validate", []), - num_retries=self.num_retries_on_validate_failure, - console=self.console, - ) + results, total_cost = super().execute(input_data) - if is_valid: - return output, cost - - return None, cost - - with ThreadPoolExecutor(max_workers=self.max_threads) as executor: - futures = [ - executor.submit(_process_filter_item, item) for item in input_data - ] - results = [] - total_cost = 0 - pbar = RichLoopBar( - range(len(futures)), - desc=f"Processing {self.config['name']} (filter) on all documents", - console=self.console, - ) - for i in pbar: - future = futures[i] - result, item_cost = future.result() - total_cost += item_cost - if result is not None: - if is_build: - results.append(result) - else: - if result.get(filter_key, False): - results.append(result) - pbar.update(1) - - if self.status: - self.status.start() + # Drop records with filter_key values that are False + results = [result for result in results if result[filter_key]] return results, total_cost diff --git a/docetl/operations/map.py b/docetl/operations/map.py index 3f077713..300e8419 100644 --- a/docetl/operations/map.py +++ b/docetl/operations/map.py @@ -153,59 +153,42 @@ def validation_fn(response: Dict[str, Any]): return output, False self.runner.rate_limiter.try_acquire("call", weight=1) - if "gleaning" in self.config: - output, cost, success = self.runner.api.call_llm_with_validation( - [{"role": "user", "content": prompt}], - model=self.config.get("model", self.default_model), - operation_type="map", - schema=self.config["output"]["schema"], - llm_call_fn=lambda messages: self.runner.api.call_llm_with_gleaning( - self.config.get("model", self.default_model), - "map", - messages, - self.config["output"]["schema"], - self.config["gleaning"]["validation_prompt"], - self.config["gleaning"]["num_rounds"], - self.console, - timeout_seconds=self.config.get("timeout", 120), - max_retries_per_timeout=self.config.get( - "max_retries_per_timeout", 2 - ), - verbose=self.config.get("verbose", False), - ), - validation_fn=validation_fn, - val_rule=self.config.get("validate", []), - num_retries=self.num_retries_on_validate_failure, - console=self.console, - ) - else: - output, cost, success = self.runner.api.call_llm_with_validation( - [{"role": "user", 
"content": prompt}], - model=self.config.get("model", self.default_model), - operation_type="map", - schema=self.config["output"]["schema"], - llm_call_fn=lambda messages: self.runner.api.call_llm( - self.config.get("model", self.default_model), - "map", - messages, - self.config["output"]["schema"], - tools=self.config.get("tools", None), - console=self.console, - timeout_seconds=self.config.get("timeout", 120), - max_retries_per_timeout=self.config.get( - "max_retries_per_timeout", 2 - ), - ), - validation_fn=validation_fn, - val_rule=self.config.get("validate", []), - num_retries=self.num_retries_on_validate_failure, - console=self.console, - ) + llm_result = self.runner.api.call_llm( + self.config.get("model", self.default_model), + "map", + [{"role": "user", "content": prompt}], + self.config["output"]["schema"], + tools=self.config.get("tools", None), + scratchpad=None, + timeout_seconds=self.config.get("timeout", 120), + max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2), + validation_config=( + { + "num_retries": self.num_retries_on_validate_failure, + "val_rule": self.config.get("validate", []), + "validation_fn": validation_fn, + } + if self.config.get("validate", None) + else None + ), + gleaning_config=self.config.get("gleaning", None), + verbose=self.config.get("verbose", False), + bypass_cache=self.config.get("bypass_cache", False), + ) - if success: - return output, cost + if llm_result.validated: + # Parse the response + output = self.runner.api.parse_llm_response( + llm_result.response, + schema=self.config["output"]["schema"], + tools=self.config.get("tools", None), + manually_fix_errors=self.manually_fix_errors, + )[0] + # Augment the output with the original item + output = {**item, **output} + return output, llm_result.total_cost - return None, cost + return None, llm_result.total_cost with ThreadPoolExecutor(max_workers=self.max_batch_size) as executor: futures = [executor.submit(_process_map_item, item) for item in input_data] @@ -375,17 +358,17 @@ def process_prompt(item, prompt_config): [{"role": "user", "content": prompt}], local_output_schema, tools=prompt_config.get("tools", None), - console=self.console, timeout_seconds=self.config.get("timeout", 120), max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2), + bypass_cache=self.config.get("bypass_cache", False), ) output = self.runner.api.parse_llm_response( - response, + response.response, schema=local_output_schema, tools=prompt_config.get("tools", None), manually_fix_errors=self.manually_fix_errors, )[0] - return output, completion_cost(response) + return output, response.total_cost with ThreadPoolExecutor(max_workers=self.max_threads) as executor: if "prompts" in self.config: diff --git a/docetl/operations/reduce.py b/docetl/operations/reduce.py index b4091865..682b5d70 100644 --- a/docetl/operations/reduce.py +++ b/docetl/operations/reduce.py @@ -12,7 +12,7 @@ from collections import deque from concurrent.futures import ThreadPoolExecutor, as_completed from threading import Lock -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import jinja2 import numpy as np @@ -404,7 +404,7 @@ def _cluster_based_sampling( return group_list, 0 clusters, cost = cluster_documents( - group_list, value_sampling, sample_size, self.api + group_list, value_sampling, sample_size, self.runner.api ) sampled_items = [] @@ -444,7 +444,7 @@ def _semantic_similarity_sampling( ) embeddings, cost = get_embeddings_for_clustering( - group_list, 
value_sampling, self.api + group_list, value_sampling, self.runner.api ) query_response = self.runner.api.gen_embedding(embedding_model, [query_text]) @@ -684,6 +684,15 @@ def _incremental_reduce( return current_output, total_cost + def validation_fn(self, response: Dict[str, Any]): + output = self.runner.api.parse_llm_response( + response, + schema=self.config["output"]["schema"], + )[0] + if self.runner.api.validate_output(self.config, output, self.console): + return output, True + return output, False + def _increment_fold( self, key: Tuple, @@ -715,29 +724,43 @@ def _increment_fold( output=current_output, reduce_key=dict(zip(self.config["reduce_key"], key)), ) + response = self.runner.api.call_llm( self.config.get("model", self.default_model), "reduce", [{"role": "user", "content": fold_prompt}], self.config["output"]["schema"], scratchpad=scratchpad, - console=self.console, timeout_seconds=self.config.get("timeout", 120), max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2), + validation_config=( + { + "num_retries": self.num_retries_on_validate_failure, + "val_rule": self.config.get("validate", []), + "validation_fn": self.validation_fn, + } + if self.config.get("validate", None) + else None + ), + bypass_cache=self.config.get("bypass_cache", False), + verbose=self.config.get("verbose", False), ) - folded_output = self.runner.api.parse_llm_response( - response, - self.config["output"]["schema"], - manually_fix_errors=self.manually_fix_errors, - )[0] - folded_output.update(dict(zip(self.config["reduce_key"], key))) - fold_cost = completion_cost(response) end_time = time.time() self._update_fold_time(end_time - start_time) - if self.runner.api.validate_output(self.config, folded_output, self.console): + if response.validated: + folded_output = self.runner.api.parse_llm_response( + response.response, + schema=self.config["output"]["schema"], + manually_fix_errors=self.manually_fix_errors, + )[0] + + folded_output.update(dict(zip(self.config["reduce_key"], key))) + fold_cost = response.total_cost + return folded_output, fold_cost + return None, fold_cost def _merge_results( @@ -766,20 +789,34 @@ def _merge_results( "merge", [{"role": "user", "content": merge_prompt}], self.config["output"]["schema"], - console=self.console, timeout_seconds=self.config.get("timeout", 120), max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2), + validation_config=( + { + "num_retries": self.num_retries_on_validate_failure, + "val_rule": self.config.get("validate", []), + "validation_fn": self.validation_fn, + } + if self.config.get("validate", None) + else None + ), + bypass_cache=self.config.get("bypass_cache", False), + verbose=self.config.get("verbose", False), ) - merged_output = self.runner.api.parse_llm_response( - response, self.config["output"]["schema"] - )[0] - merged_output.update(dict(zip(self.config["reduce_key"], key))) - merge_cost = completion_cost(response) + end_time = time.time() self._update_merge_time(end_time - start_time) - if self.runner.api.validate_output(self.config, merged_output, self.console): + if response.validated: + merged_output = self.runner.api.parse_llm_response( + response.response, + schema=self.config["output"]["schema"], + manually_fix_errors=self.manually_fix_errors, + )[0] + merged_output.update(dict(zip(self.config["reduce_key"], key))) + merge_cost = response.total_cost return merged_output, merge_cost + return None, merge_cost def get_fold_time(self) -> Tuple[float, bool]: @@ -854,41 +891,37 @@ def _batch_reduce( ) item_cost = 
0 - if "gleaning" in self.config: - response, gleaning_cost = self.runner.api.call_llm_with_gleaning( - self.config.get("model", self.default_model), - "reduce", - [{"role": "user", "content": prompt}], - self.config["output"]["schema"], - self.config["gleaning"]["validation_prompt"], - self.config["gleaning"]["num_rounds"], - console=self.console, - timeout_seconds=self.config.get("timeout", 120), - max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2), - verbose=self.config.get("verbose", False), - ) - item_cost += gleaning_cost - else: - response = self.runner.api.call_llm( - self.config.get("model", self.default_model), - "reduce", - [{"role": "user", "content": prompt}], - self.config["output"]["schema"], - console=self.console, - scratchpad=scratchpad, - timeout_seconds=self.config.get("timeout", 120), - max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2), - ) + response = self.runner.api.call_llm( + self.config.get("model", self.default_model), + "reduce", + [{"role": "user", "content": prompt}], + self.config["output"]["schema"], + scratchpad=scratchpad, + timeout_seconds=self.config.get("timeout", 120), + max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2), + bypass_cache=self.config.get("bypass_cache", False), + validation_config=( + { + "num_retries": self.num_retries_on_validate_failure, + "val_rule": self.config.get("validate", []), + "validation_fn": self.validation_fn, + } + if self.config.get("validate", None) + else None + ), + gleaning_config=self.config.get("gleaning", None), + verbose=self.config.get("verbose", False), + ) - item_cost += completion_cost(response) + item_cost += response.total_cost - output = self.runner.api.parse_llm_response( - response, - self.config["output"]["schema"], - manually_fix_errors=self.manually_fix_errors, - )[0] - output.update(dict(zip(self.config["reduce_key"], key))) + if response.validated: + output = self.runner.api.parse_llm_response( + response.response, + schema=self.config["output"]["schema"], + manually_fix_errors=self.manually_fix_errors, + )[0] + output.update(dict(zip(self.config["reduce_key"], key))) - if self.runner.api.validate_output(self.config, output, self.console): return output, item_cost return None, item_cost diff --git a/docetl/operations/resolve.py b/docetl/operations/resolve.py index e5f58121..0184f8b2 100644 --- a/docetl/operations/resolve.py +++ b/docetl/operations/resolve.py @@ -56,12 +56,13 @@ def compare_pair( {"is_match": "bool"}, timeout_seconds=timeout_seconds, max_retries_per_timeout=max_retries_per_timeout, + bypass_cache=self.config.get("bypass_cache", False), ) output = self.runner.api.parse_llm_response( - response, + response.response, {"is_match": "bool"}, )[0] - return output["is_match"], completion_cost(response) + return output["is_match"], response.total_cost def syntax_check(self) -> None: """ @@ -169,6 +170,15 @@ def syntax_check(self) -> None: if self.config["limit_comparisons"] <= 0: raise ValueError("'limit_comparisons' must be a positive integer") + def validation_fn(self, response: Dict[str, Any]): + output = self.runner.api.parse_llm_response( + response, + schema=self.config["output"]["schema"], + )[0] + if self.runner.api.validate_output(self.config, output, self.console): + return output, True + return output, False + def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: """ Executes the resolve operation on the provided dataset. 
@@ -401,22 +411,28 @@ def process_cluster(cluster): "reduce", [{"role": "user", "content": resolution_prompt}], self.config["output"]["schema"], - console=self.console, timeout_seconds=self.config.get("timeout", 120), max_retries_per_timeout=self.config.get( "max_retries_per_timeout", 2 ), + bypass_cache=self.config.get("bypass_cache", False), + validation_config=( + { + "val_rule": self.config.get("validate", []), + "validation_fn": self.validation_fn, + } + if self.config.get("validate", None) + else None + ), ) - reduction_output = self.runner.api.parse_llm_response( - reduction_response, - self.config["output"]["schema"], - manually_fix_errors=self.manually_fix_errors, - )[0] - reduction_cost = completion_cost(reduction_response) - - if self.runner.api.validate_output( - self.config, reduction_output, self.console - ): + reduction_cost = reduction_response.total_cost + + if reduction_response.validated: + reduction_output = self.runner.api.parse_llm_response( + reduction_response.response, + self.config["output"]["schema"], + manually_fix_errors=self.manually_fix_errors, + )[0] return ( [ { diff --git a/docetl/operations/sample.py b/docetl/operations/sample.py new file mode 100644 index 00000000..083870ff --- /dev/null +++ b/docetl/operations/sample.py @@ -0,0 +1,53 @@ +import sklearn.model_selection +from typing import Any, Dict, List, Optional, Tuple +from .base import BaseOperation + + +class SampleOperation(BaseOperation): + def __init__( + self, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + def syntax_check(self) -> None: + """ + Checks the configuration of the SampleOperation for required keys and valid structure. + + Raises: + ValueError: If required keys are missing or invalid in the configuration. + TypeError: If configuration values have incorrect types. + """ + pass + + def execute( + self, input_data: List[Dict], is_build: bool = False + ) -> Tuple[List[Dict], float]: + """ + Executes the sample operation on the input data. + + Args: + input_data (List[Dict]): A list of dictionaries to process. + is_build (bool): Whether the operation is being executed + in the build phase. Defaults to False. + + Returns: + Tuple[List[Dict], float]: A tuple containing the filtered + list of dictionaries and the total cost of the operation. + """ + + samples = self.config["samples"] + if isinstance(samples, list): + output_data = [input_data[sample] + for sample in samples] + else: + stratify=None + if "stratify" in self.config: + stratify = [data[self.config["stratify"]] for data in input_data] + output_data, dummy = sklearn.model_selection.train_test_split( + input_data, + train_size = samples, + random_state = self.config.get("random_state", None), + stratify = stratify) + return output_data, 0 diff --git a/docetl/operations/utils.py b/docetl/operations/utils.py index 163015ed..f7313690 100644 --- a/docetl/operations/utils.py +++ b/docetl/operations/utils.py @@ -20,6 +20,7 @@ from rich.console import Console from rich.prompt import Prompt from tqdm import tqdm +from pydantic import BaseModel from docetl.utils import completion_cost, count_tokens import time @@ -36,6 +37,12 @@ cache.close() +class LLMResult(BaseModel): + response: Any + total_cost: float + validated: bool + + def freezeargs(func): """ Decorator to convert mutable dictionary arguments into immutable. 
@@ -412,9 +419,7 @@ def gen_embedding(self, model: str, input: List[str]) -> List[float]: return result - # TODO: optimize this - @freezeargs - def cached_call_llm( + def _cached_call_llm( self, cache_key: str, model: str, @@ -423,11 +428,15 @@ def cached_call_llm( output_schema: Dict[str, str], tools: Optional[str] = None, scratchpad: Optional[str] = None, - ) -> str: + validation_config: Optional[Dict[str, Any]] = None, + gleaning_config: Optional[Dict[str, Any]] = None, + verbose: bool = False, + bypass_cache: bool = False, + ) -> LLMResult: """ Cached version of the call_llm function. - This function serves as a cached wrapper around call_llm_with_cache. It uses + This function serves as a cached wrapper around _call_llm_with_cache. It uses the @freezeargs decorator to ensure immutable arguments and @functools.lru_cache for caching results. @@ -439,80 +448,170 @@ def cached_call_llm( output_schema (Dict[str, str]): The output schema dictionary. tools (Optional[str]): The tools to pass to the LLM. scratchpad (Optional[str]): The scratchpad to use for the operation. + validation_config (Optional[Dict[str, Any]]): The validation configuration. + gleaning_config (Optional[Dict[str, Any]]): The gleaning configuration. + verbose (bool): Whether to print verbose output. + bypass_cache (bool): Whether to bypass the cache. Returns: - str: The result from call_llm_with_cache. + LLMResult: The response from _call_llm_with_cache. """ + total_cost = 0.0 + validated = False with cache as c: - result = c.get(cache_key) - if result is None: - result = self.call_llm_with_cache( + response = c.get(cache_key) + if response is not None and not bypass_cache: + validated = True + else: + response = self._call_llm_with_cache( model, op_type, messages, output_schema, tools, scratchpad ) - # Only set the cache if the result tool calls or output is not empty - if ( - result - and "tool_calls" in dir(result.choices[0].message) - and result.choices[0].message.tool_calls - ): - c.set(cache_key, result) + total_cost += completion_cost(response) + + if gleaning_config: + # Retry gleaning prompt + regular LLM + num_gleaning_rounds = gleaning_config.get("num_rounds", 2) + validator_prompt_template = Template(gleaning_config["prompt"]) + + parsed_output = self.parse_llm_response( + response, output_schema, tools + )[0] + + validator_messages = ( + [ + { + "role": "system", + "content": f"You are a helpful assistant, intelligently processing data. 
This is a {op_type} operation.", + } + ] + + messages + + [ + {"role": "assistant", "content": json.dumps(parsed_output)}, + ] + ) - return result + for rnd in range(num_gleaning_rounds): + # Prepare validator prompt + validator_prompt = validator_prompt_template.render( + output=parsed_output + ) + self.runner.rate_limiter.try_acquire("llm_call", weight=1) + + validator_response = completion( + model=gleaning_config.get("model", model), + messages=truncate_messages( + validator_messages + + [{"role": "user", "content": validator_prompt}], + model, + ), + response_format={ + "type": "json_schema", + "json_schema": { + "name": "response", + "strict": True, + "schema": { + "type": "object", + "properties": { + "should_refine": {"type": "boolean"}, + "improvements": {"type": "string"}, + }, + "required": ["should_refine", "improvements"], + "additionalProperties": False, + }, + }, + }, + ) + total_cost += completion_cost(validator_response) - def call_llm_with_validation( - self, - messages: List[str], - model: str, - operation_type: str, - schema: Dict[str, str], - llm_call_fn: Callable, - validation_fn: Callable, - val_rule: str, - num_retries: int, - console: Console, - scratchpad: Optional[str] = None, - ) -> Tuple[Any, float, bool]: - num_tries = num_retries + 1 - cost = 0.0 + # Parse the validator response + suggestion = json.loads( + validator_response.choices[0].message.content + ) + if not suggestion["should_refine"]: + break - key = cache_key(model, operation_type, messages, schema, scratchpad) + if verbose: + self.runner.console.log( + f"Validator improvements (gleaning round {rnd + 1}): {suggestion['improvements']}" + ) - for i in range(num_tries): - response = llm_call_fn(messages) - if isinstance(response, tuple): - response, curr_cost = response - cost += curr_cost + # Prompt for improvement + improvement_prompt = f"""Based on the validation feedback: - cost += completion_cost(response) + ``` + {suggestion['improvements']} + ``` - parsed_output, result = validation_fn(response) + Please improve your previous response. 
Ensure that the output adheres to the required schema and addresses any issues raised in the validation.""" + messages.append({"role": "user", "content": improvement_prompt}) - if result: - return parsed_output, cost, True + # Call LLM again + response = self._call_llm_with_cache( + model, op_type, messages, output_schema, tools, scratchpad + ) + parsed_output = self.parse_llm_response( + response, output_schema, tools + )[0] + validator_messages[-1] = [ + {"role": "assistant", "content": json.dumps(parsed_output)}, + ] + + total_cost += completion_cost(response) + + validated = True + + # If there's validation, handle it here + elif validation_config: + num_tries = validation_config.get("num_retries", 2) + validation_fn = validation_config.get("validation_fn") + val_rule = validation_config.get("val_rule") + + # Try validation + i = 0 + validation_result = False + while not validation_result and i < num_tries: + parsed_output, validation_result = validation_fn(response) + if validation_result: + validated = True + break + + # Append the validation result to messages + messages.append( + { + "role": "assistant", + "content": json.dumps(parsed_output), + } + ) + messages.append( + { + "role": "user", + "content": f"Your output {parsed_output} failed my validation rule: {str(val_rule)}\n\nPlease try again.", + } + ) + self.runner.console.log( + f"[bold red]Validation failed:[/bold red] {val_rule}\n" + f"\t[yellow]Output:[/yellow] {parsed_output}\n" + f"\t({i + 1}/{num_tries})" + ) + i += 1 - # Remove from cache - with cache as c: - c.delete(key) + response = self._call_llm_with_cache( + model, op_type, messages, output_schema, tools, scratchpad + ) + total_cost += completion_cost(response) - # Append the validation result to messages - messages.append( - { - "role": "assistant", - "content": json.dumps(parsed_output), - } - ) - messages.append( - { - "role": "user", - "content": f"Your output {parsed_output} failed my validation rule: {str(val_rule)}\n\nPlease try again.", - } - ) - console.log( - f"[bold red]Validation failed:[/bold red] {val_rule}\n" - f"\t[yellow]Output:[/yellow] {parsed_output}\n" - f"\tTrying again... ({i + 1}/{num_tries})" - ) + else: + # No validation, so we assume the result is valid + validated = True - return parsed_output, cost, False + # Only set the cache if the result tool calls or output is not empty + if ( + response + and "tool_calls" in dir(response.choices[0].message) + and response.choices[0].message.tool_calls + ): + c.set(cache_key, response) + + return LLMResult(response=response, total_cost=total_cost, validated=validated) def call_llm( self, @@ -522,10 +621,13 @@ def call_llm( output_schema: Dict[str, str], tools: Optional[List[Dict[str, str]]] = None, scratchpad: Optional[str] = None, - console: Console = Console(), timeout_seconds: int = 120, max_retries_per_timeout: int = 2, - ) -> Any: + validation_config: Optional[Dict[str, Any]] = None, + gleaning_config: Optional[Dict[str, Any]] = None, + verbose: bool = False, + bypass_cache: bool = False, + ) -> LLMResult: """ Wrapper function that uses caching for LLM calls. @@ -541,8 +643,9 @@ def call_llm( scratchpad (Optional[str]): The scratchpad to use for the operation. timeout_seconds (int): The timeout for the LLM call. max_retries_per_timeout (int): The maximum number of retries per timeout. + bypass_cache (bool): Whether to bypass the cache. Returns: - str: The result from the cached LLM call. + LLMResult: The result from the cached LLM call. 
Raises: TimeoutError: If the call times out after retrying. @@ -554,7 +657,7 @@ def call_llm( rate_limited_attempt = 0 while attempt <= max_retries: try: - return timeout(timeout_seconds)(self.cached_call_llm)( + return timeout(timeout_seconds)(self._cached_call_llm)( key, model, op_type, @@ -562,6 +665,10 @@ def call_llm( output_schema, json.dumps(tools) if tools else None, scratchpad, + validation_config=validation_config, + gleaning_config=gleaning_config, + verbose=verbose, + bypass_cache=bypass_cache, ) except RateLimitError: # TODO: this is a really hacky way to handle rate limits @@ -569,21 +676,21 @@ def call_llm( backoff_time = 4 * (2**rate_limited_attempt) # Exponential backoff max_backoff = 120 # Maximum backoff time of 60 seconds sleep_time = min(backoff_time, max_backoff) - console.log( + self.runner.console.log( f"[yellow]Rate limit hit. Retrying in {sleep_time:.2f} seconds...[/yellow]" ) time.sleep(sleep_time) rate_limited_attempt += 1 except TimeoutError: if attempt == max_retries: - console.log( + self.runner.console.log( f"[bold red]LLM call timed out after {max_retries + 1} attempts[/bold red]" ) # TODO: HITL - return {} + return LLMResult(response=None, total_cost=0.0, validated=False) attempt += 1 - def call_llm_with_cache( + def _call_llm_with_cache( self, model: str, op_type: str, @@ -591,7 +698,7 @@ def call_llm_with_cache( output_schema: Dict[str, str], tools: Optional[str] = None, scratchpad: Optional[str] = None, - ) -> str: + ) -> Any: """ Make an LLM call with caching. @@ -680,7 +787,6 @@ def call_llm_with_cache( Update the 'updated_scratchpad' field in your output with the new scratchpad content. Remember: The scratchpad should contain information necessary for processing future batches, not the final result.""" - messages = json.loads(messages) # Truncate messages if they exceed the model's context length messages = truncate_messages(messages, model) @@ -713,173 +819,6 @@ def call_llm_with_cache( return response - def call_llm_with_gleaning( - self, - model: str, - op_type: str, - messages: List[Dict[str, str]], - output_schema: Dict[str, str], - validator_prompt_template: str, - num_gleaning_rounds: int, - console: Console = Console(), - timeout_seconds: int = 120, - max_retries_per_timeout: int = 2, - verbose: bool = False, - ) -> Tuple[str, float]: - """ - Call LLM with a gleaning process, including validation and improvement rounds. - - This function performs an initial LLM call, followed by multiple rounds of - validation and improvement based on the validator prompt template. - - Args: - model (str): The model name. - op_type (str): The operation type. - messages (List[Dict[str, str]]): The messages to send to the LLM. - output_schema (Dict[str, str]): The output schema dictionary. - validator_prompt_template (str): Template for the validator prompt. - num_gleaning_rounds (int): Number of gleaning rounds to perform. - timeout_seconds (int): The timeout for the LLM call. - Returns: - Tuple[str, float]: A tuple containing the final LLM response and the total cost. - """ - if not litellm.supports_function_calling(model): - raise ValueError( - f"Model {model} does not support function calling (which we use for structured outputs). Please use a different model." 
- ) - - props = {key: convert_val(value) for key, value in output_schema.items()} - - parameters = {"type": "object", "properties": props} - parameters["required"] = list(props.keys()) - parameters["additionalProperties"] = False - - # Initial LLM call - response = self.call_llm( - model, - op_type, - messages, - output_schema, - console=console, - timeout_seconds=timeout_seconds, - max_retries_per_timeout=max_retries_per_timeout, - ) - - cost = 0.0 - - # Parse the response - parsed_response = self.parse_llm_response(response, output_schema) - output = parsed_response[0] - - messages = ( - [ - { - "role": "system", - "content": f"You are a helpful assistant, intelligently processing data. This is a {op_type} operation.", - } - ] - + messages - + [ - {"role": "assistant", "content": json.dumps(output)}, - ] - ) - - for rnd in range(num_gleaning_rounds): - cost += completion_cost(response) - - # Prepare validator prompt - validator_template = Template(validator_prompt_template) - validator_prompt = validator_template.render(output=output) - - # Call LLM for validation - self.runner.rate_limiter.try_acquire("llm_call", weight=1) - validator_response = completion( - model=model, - messages=truncate_messages( - messages + [{"role": "user", "content": validator_prompt}], model - ), - response_format={ - "type": "json_schema", - "json_schema": { - "name": "response", - "strict": True, - "schema": { - "type": "object", - "properties": { - "should_refine": {"type": "boolean"}, - "improvements": {"type": "string"}, - }, - "required": ["should_refine", "improvements"], - "additionalProperties": False, - }, - }, - }, - ) - cost += completion_cost(validator_response) - - # Parse the validator response - suggestion = json.loads(validator_response.choices[0].message.content) - if not suggestion["should_refine"]: - break - - if verbose: - console.log( - f"Validator improvements (gleaning round {rnd + 1}): {suggestion['improvements']}" - ) - - # Prompt for improvement - improvement_prompt = f"""Based on the validation feedback: - - ``` - {suggestion['improvements']} - ``` - - Please improve your previous response. 
Ensure that the output adheres to the required schema and addresses any issues raised in the validation.""" - messages.append({"role": "user", "content": improvement_prompt}) - - # Call LLM for improvement - # TODO: support gleaning and tools - self.runner.rate_limiter.try_acquire("llm_call", weight=1) - response = completion( - model=model, - messages=truncate_messages(messages, model), - # response_format={ - # "type": "json_schema", - # "json_schema": { - # "name": "write_output", - # "description": "Write processing output to a database", - # "strict": True, - # "schema": parameters, - # # "additionalProperties": False, - # }, - # }, - tools=[ - { - "type": "function", - "function": { - "name": "send_output", - "description": "Send output back to the user", - "strict": True, - "parameters": parameters, - "additionalProperties": False, - }, - } - ], - tool_choice={"type": "function", "function": {"name": "send_output"}}, - ) - - # Update messages with the new response - messages.append( - { - "role": "assistant", - "content": json.dumps( - self.parse_llm_response(response, output_schema)[0] - ), - } - ) - - return response, cost - def parse_llm_response( self, response: Any, @@ -892,7 +831,7 @@ def parse_llm_response( This function extracts the tool calls from the LLM response and returns the arguments """ try: - return self.parse_llm_response_helper(response, schema, tools) + return self._parse_llm_response_helper(response, schema, tools) except InvalidOutputError as e: if manually_fix_errors: rprint( @@ -909,7 +848,7 @@ def parse_llm_response( else: raise e - def parse_llm_response_helper( + def _parse_llm_response_helper( self, response: Any, schema: Dict[str, Any] = {}, diff --git a/docs/operators/sample.md b/docs/operators/sample.md new file mode 100644 index 00000000..3d5c0a91 --- /dev/null +++ b/docs/operators/sample.md @@ -0,0 +1,43 @@ +# Sample operation + +The Sample operation in DocETL samples items from the input. It is +meant mostly as a debugging tool: + +Insert it before the last operation, the one you're currently trying +to tack on to the end of a working pipeline, to limit the amount of +data it will be fed, so that the run time is small enough to +comfortably debug its prompt. Once it seems to be working, you can +remove the sample operation. You can then repeat this for each +operation you add while developing your pipeline! + + + +## 🚀 Example: + +```yaml +- name: cluster_concepts + type: sample + samples: 0.1 + random_state: 42 + stratify: category +``` + +This sample operation will return a pseudo-randomly selected 10% of +the samples (`samples: 0.1`). The random selection will be seeded with +a constant (42), meaning the same selection will be returned if you +rerun the pipeline (If no random state is given, a different sample +will be returned every time). Additionally, the random sampling will +sample each value of the `category` key equally. + +## Required Parameters + +- `name`: A unique name for the operation. +- `type`: Must be set to "sample". +- `samples`: Either a list of sample indices to just return those samples, an integer count of samples, or a float fraction of samples. 
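The list and integer forms of `samples` behave analogously to the fraction shown in the example above; a brief sketch of both (operation names here are illustrative):

```yaml
# Keep only the documents at indices 0, 5, and 10
- name: keep_three_docs
  type: sample
  samples: [0, 5, 10]

# Keep 20 pseudo-randomly selected documents, reproducibly seeded
- name: keep_twenty_docs
  type: sample
  samples: 20
  random_state: 42
```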
+ +## Optional Parameters + +| Parameter | Description | Default | +| ------------------------- | -------------------------------------------------------------------------------- | ----------------------------- | +| `random_state | An integer to seed the random generator with | Use the (numpy) global random state +| `stratify` | The key to stratify by | | diff --git a/pyproject.toml b/pyproject.toml index f0d65afd..09d81d85 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,6 +89,7 @@ resolve = "docetl.operations.resolve:ResolveOperation" gather = "docetl.operations.gather:GatherOperation" cluster = "docetl.operations.cluster:ClusterOperation" outliers = "docetl.operations.outliers:OutliersOperation" +sample = "docetl.operations.sample:SampleOperation" [tool.poetry.plugins."docetl.parser"] llama_index_simple_directory_reader = "docetl.parsing_tools:llama_index_simple_directory_reader" diff --git a/tests/basic/test_basic_filter_split_gather.py b/tests/basic/test_basic_filter_split_gather.py index dcb65694..1c4a3ca0 100644 --- a/tests/basic/test_basic_filter_split_gather.py +++ b/tests/basic/test_basic_filter_split_gather.py @@ -42,7 +42,6 @@ def test_filter_operation( assert len(results) < len(filter_sample_data) assert all(len(result["text"].split()) > 3 for result in results) - assert cost > 0 def test_filter_operation_empty_input( @@ -192,7 +191,6 @@ def test_equijoin_operation( assert len(results) == 2 # Only 2 matches assert all("name" in result and "email" in result for result in results) - assert cost > 0 def test_equijoin_operation_empty_input( diff --git a/tests/basic/test_basic_map.py b/tests/basic/test_basic_map.py index e8c4ce9f..21a21d95 100644 --- a/tests/basic/test_basic_map.py +++ b/tests/basic/test_basic_map.py @@ -23,14 +23,12 @@ def test_map_operation( map_sample_data, ): results, cost = test_map_operation_instance.execute(map_sample_data) - print(results) assert len(results) == len(map_sample_data) assert all("sentiment" in result for result in results) assert all( result["sentiment"] in ["positive", "negative", "neutral"] for result in results ) - assert cost > 0 def test_map_operation_empty_input(map_config, default_model, max_threads, api_wrapper): @@ -48,6 +46,7 @@ def test_map_operation_with_drop_keys( map_sample_data_with_extra_keys, api_wrapper, ): + map_config_with_drop_keys["bypass_cache"] = True operation = MapOperation( api_wrapper, map_config_with_drop_keys, default_model, max_threads ) @@ -55,11 +54,12 @@ def test_map_operation_with_drop_keys( assert len(results) == len(map_sample_data_with_extra_keys) assert all("sentiment" in result for result in results) - assert all("original_sentiment" not in result for result in results) - assert all("to_be_dropped" in result for result in results) + assert all("original_sentiment" in result for result in results) + assert all("to_be_dropped" not in result for result in results) assert all( result["sentiment"] in ["positive", "negative", "neutral"] for result in results ) + assert cost > 0 @@ -95,7 +95,6 @@ def test_map_operation_with_batching( results, cost = operation.execute(map_sample_data) assert len(results) == len(map_sample_data) - assert cost > 0 assert all("sentiment" in result for result in results) assert all( result["sentiment"] in ["positive", "negative", "neutral"] for result in results @@ -128,7 +127,6 @@ def test_map_operation_with_large_max_batch_size( results, cost = operation.execute(map_sample_data) assert len(results) == len(map_sample_data) - assert cost > 0 def 
test_map_operation_with_word_count_tool( @@ -140,7 +138,6 @@ def test_map_operation_with_word_count_tool( assert len(results) == len(synthetic_data) assert all("word_count" in result for result in results) assert [result["word_count"] for result in results] == [5, 6, 5, 1] - assert cost > 0 # Ensure there was some cost associated with the operation @pytest.fixture @@ -185,8 +182,8 @@ def test_map_operation_with_timeout(simple_map_config, simple_sample_data, api_w operation = MapOperation(api_wrapper, map_config_with_timeout, "gpt-4o-mini", 4) # Execute the operation and expect empty results - with pytest.raises(docetl.operations.utils.InvalidOutputError): - operation.execute(simple_sample_data) + results, cost = operation.execute(simple_sample_data) + assert len(results) == 0 def test_map_operation_with_gleaning(simple_map_config, map_sample_data, api_wrapper): @@ -215,6 +212,3 @@ def test_map_operation_with_gleaning(simple_map_config, map_sample_data, api_wra assert all( any(vs in result["sentiment"] for vs in valid_sentiments) for result in results ) - - # Assert that the operation had a cost - assert cost > 0 diff --git a/tests/basic/test_basic_parallel_map.py b/tests/basic/test_basic_parallel_map.py index 926fe20e..5edcaa2c 100644 --- a/tests/basic/test_basic_parallel_map.py +++ b/tests/basic/test_basic_parallel_map.py @@ -22,6 +22,7 @@ def test_parallel_map_operation( parallel_map_sample_data, api_wrapper, ): + parallel_map_config["bypass_cache"] = True operation = ParallelMapOperation( api_wrapper, parallel_map_config, default_model, max_threads ) diff --git a/tests/basic/test_basic_reduce_resolve.py b/tests/basic/test_basic_reduce_resolve.py index a2f8ab59..a0a93c2f 100644 --- a/tests/basic/test_basic_reduce_resolve.py +++ b/tests/basic/test_basic_reduce_resolve.py @@ -42,6 +42,7 @@ def reduce_sample_data_with_list_key(): def test_reduce_operation( reduce_config, default_model, max_threads, reduce_sample_data, api_wrapper ): + reduce_config["bypass_cache"] = True operation = ReduceOperation(api_wrapper, reduce_config, default_model, max_threads) results, cost = operation.execute(reduce_sample_data) @@ -61,7 +62,6 @@ def test_reduce_operation_with_all_key( results, cost = operation.execute(reduce_sample_data) assert len(results) == 1 - assert cost > 0 def test_reduce_operation_with_list_key( @@ -84,7 +84,6 @@ def test_reduce_operation_with_list_key( and "avg" in result for result in results ) - assert cost > 0 def test_reduce_operation_empty_input( @@ -134,7 +133,6 @@ def test_resolve_operation( distinct_names = set(result["name"] for result in results) assert len(distinct_names) < len(results) - assert cost > 0 def test_resolve_operation_empty_input(resolve_config, max_threads, api_wrapper): diff --git a/tests/basic/test_cluster.py b/tests/basic/test_cluster.py index 3db8424b..0bcaa717 100644 --- a/tests/basic/test_cluster.py +++ b/tests/basic/test_cluster.py @@ -17,11 +17,10 @@ def cluster_config(): these two concepts already encompasses the other; in that case, you should just use that concept. - {{left.concept}}: - {{left.description}} - - {{right.concept}}: - {{right.description}} + {% for input in inputs %} + {{input.concept}}: + {{input.description}} + {% endfor %} Provide the title of the super-concept, and a description. 
""", @@ -70,6 +69,7 @@ def sample_data(): def test_cluster_operation( cluster_config, sample_data, api_wrapper, default_model, max_threads ): + cluster_config["bypass_cache"] = True operation = ClusterOperation( api_wrapper, cluster_config, default_model, max_threads ) diff --git a/tests/conftest.py b/tests/conftest.py index 231775c7..f92c65dc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -199,7 +199,7 @@ def map_config_with_drop_keys(): "prompt": "Analyze the sentiment of the following text: '{{ input.text }}'. Classify it as either positive, negative, or neutral.", "output": {"schema": {"sentiment": "string"}}, "model": "gpt-4o-mini", - "drop_keys": ["original_sentiment"], + "drop_keys": ["to_be_dropped"], } diff --git a/tests/test_api.py b/tests/test_api.py index 921c2de9..64cc0243 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -257,7 +257,6 @@ def test_pipeline_execution( cost = pipeline.run(max_threads=4) assert isinstance(cost, float) - assert cost > 0 def test_parallel_map_pipeline( @@ -283,7 +282,6 @@ def test_parallel_map_pipeline( cost = pipeline.run(max_threads=4) assert isinstance(cost, float) - assert cost > 0 def test_filter_pipeline( @@ -309,7 +307,6 @@ def test_filter_pipeline( cost = pipeline.run(max_threads=4) assert isinstance(cost, float) - assert cost > 0 def test_reduce_pipeline( @@ -333,7 +330,6 @@ def test_reduce_pipeline( cost = pipeline.run(max_threads=4) assert isinstance(cost, float) - assert cost > 0 def test_resolve_pipeline( @@ -359,7 +355,6 @@ def test_resolve_pipeline( cost = pipeline.run(max_threads=4) assert isinstance(cost, float) - assert cost > 0 def test_equijoin_pipeline( @@ -404,4 +399,3 @@ def test_equijoin_pipeline( cost = pipeline.run(max_threads=4) assert isinstance(cost, float) - assert cost > 0 diff --git a/tests/test_eugene.py b/tests/test_eugene.py index dd08f6dc..b9928c17 100644 --- a/tests/test_eugene.py +++ b/tests/test_eugene.py @@ -185,5 +185,3 @@ def test_database_survey_pipeline( assert all("summary" in result for result in summarized_results) total_cost = extract_cost + unnest_cost + resolve_cost + summarize_cost - assert total_cost > 0 - print(total_cost) diff --git a/tests/test_reduce_scale.py b/tests/test_reduce_scale.py index 2d054388..412511a7 100644 --- a/tests/test_reduce_scale.py +++ b/tests/test_reduce_scale.py @@ -90,7 +90,6 @@ def test_reduce_operation( results, cost = operation.execute(reduce_sample_data) assert len(results) == 3, "Should have results for 3 unique categories" - assert cost > 0, "Cost should be greater than 0" for result in results: assert "category" in result, "Each result should have a 'category' key" @@ -112,7 +111,6 @@ def test_reduce_operation_pass_through( results, cost = operation.execute(reduce_sample_data) assert len(results) == 3, "Should have results for 3 unique categories" - assert cost > 0, "Cost should be greater than 0" for result in results: assert "category" in result, "Each result should have a 'category' key" @@ -176,7 +174,6 @@ def test_reduce_operation_non_associative(api_wrapper, default_model, max_thread results, cost = operation.execute(sample_data) assert len(results) == 1, "Should have one result for the 'story' sequence" - assert cost > 0, "Cost should be greater than 0" result = results[0] assert "combined_result" in result, "Result should have a 'combined_result' key" @@ -231,7 +228,6 @@ def test_reduce_operation_persist_intermediates( results, cost = operation.execute(sample_data) assert len(results) == 1, "Should have one result for the 'numbers' 
group" - assert cost > 0, "Cost should be greater than 0" result = results[0] assert "summary" in result, "Result should have a 'summary' key" diff --git a/tests/test_validation.py b/tests/test_validation.py index ed64af4c..fc6d924b 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -31,6 +31,7 @@ def sample_data(): def test_map_operation_with_validation( map_config_with_validation, sample_data, api_wrapper, default_model, max_threads ): + map_config_with_validation["bypass_cache"] = True operation = MapOperation( api_wrapper, map_config_with_validation, default_model, max_threads )