Merge branch 'staging' into outliers
shreyashankar authored Oct 12, 2024
2 parents b2ee5a2 + 70b7ddc commit 40c853e
Showing 20 changed files with 519 additions and 500 deletions.
92 changes: 64 additions & 28 deletions docetl/operations/cluster.py
@@ -1,3 +1,4 @@
import numpy as np
from jinja2 import Environment, Template
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional, Tuple
@@ -99,6 +100,9 @@ def execute(
)

tree = self.agglomerative_cluster_of_embeddings(input_data, embeddings)

if "collapse" in self.config:
    tree = self.collapse_tree(tree, collapse=self.config["collapse"])

self.prompt_template = Template(self.config["summary_prompt"])
cost += self.annotate_clustering_tree(tree)
@@ -122,7 +126,7 @@ def build_tree(i):
# res["embedding"] = list(embeddings[i])
return res
return {
    "children": [
build_tree(cl.children_[i - nsamples, 0]),
build_tree(cl.children_[i - nsamples, 1]),
],
@@ -131,6 +135,40 @@ def build_tree(i):

return build_tree(nsamples + len(cl.children_) - 1)

def get_tree_distances(self, t):
res = set()
if "distance" in t:
res.update(set([t["distance"] - child["distance"] for child in t["children"] if "distance" in child]))
if "children" in t:
for child in t["children"]:
res.update(self.get_tree_distances(child))
return res

def _collapse_tree(self, t, parent_dist=None, collapse=None):
if "children" in t:
if ("distance" in t
and parent_dist is not None
and collapse is not None
and parent_dist - t["distance"] < collapse):
return [grandchild
for child in t["children"]
for grandchild in self._collapse_tree(child, parent_dist=parent_dist, collapse=collapse)]
else:
res = dict(t)
res["children"] = [grandchild
for idx, child in enumerate(t["children"])
for grandchild in self._collapse_tree(child, parent_dist=t["distance"], collapse=collapse)]
return [res]
else:
return [t]

def collapse_tree(self, tree, collapse=None):
if collapse is not None:
tree_distances = np.array(sorted(self.get_tree_distances(tree)))
collapse = tree_distances[int(len(tree_distances) * collapse)]
return self._collapse_tree(tree, collapse=collapse)[0]


def annotate_clustering_tree(self, t):
if "children" in t:
with ThreadPoolExecutor(max_workers=self.max_batch_size) as executor:
@@ -149,12 +187,8 @@ def annotate_clustering_tree(self, t):
total_cost += futures[i].result()
pbar.update(i)

assert len(t["children"]) == 2, (
"Agglomerative clustering is supposed to generate clusters with 2 children each, but this cluster has %s"
% len(t["children"])
)
prompt = self.prompt_template.render(
left=t["children"][0], right=t["children"][1]
inputs=t["children"]
)

def validation_fn(response: Dict[str, Any]):
@@ -167,31 +201,33 @@ def validation_fn(response: Dict[str, Any]):
return output, True
return output, False

output, cost, success = self.runner.api.call_llm_with_validation(
[{"role": "user", "content": prompt}],
response = self.runner.api.call_llm(
model=self.config.get("model", self.default_model),
operation_type="cluster",
schema=self.config["summary_schema"],
llm_call_fn=lambda messages: self.runner.api.call_llm(
self.config.get("model", self.default_model),
"cluster",
messages,
self.config["summary_schema"],
tools=self.config.get("tools", None),
console=self.console,
timeout_seconds=self.config.get("timeout", 120),
max_retries_per_timeout=self.config.get(
"max_retries_per_timeout", 2
),
op_type="cluster",
messages=[{"role": "user", "content": prompt}],
output_schema=self.config["summary_schema"],
timeout_seconds=self.config.get("timeout", 120),
bypass_cache=self.config.get("bypass_cache", False),
max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2),
validation_config=(
{
"num_retries": self.num_retries_on_validate_failure,
"val_rule": self.config.get("validate", []),
"validation_fn": validation_fn,
}
if self.config.get("validate", None)
else None
),
validation_fn=validation_fn,
val_rule=self.config.get("validate", []),
num_retries=self.num_retries_on_validate_failure,
console=self.console,
verbose=self.config.get("verbose", False),
)
total_cost += cost

t.update(output)
total_cost += response.total_cost
if response.validated:
output = self.runner.api.parse_llm_response(
response.response,
schema=self.config["summary_schema"],
manually_fix_errors=self.manually_fix_errors,
)[0]
t.update(output)

return total_cost
return 0
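Note on this file's changes: the new collapse option prunes shallow levels of the agglomerative tree. get_tree_distances collects the distance gaps between each parent and its children, collapse_tree converts the configured fraction into a concrete threshold by indexing into the sorted gaps, and _collapse_tree splices out any internal node whose gap to its parent falls below that threshold, promoting its children. Because collapsed nodes can then have more than two children, the later hunk drops the old two-children assert and renders the summary prompt with inputs=t["children"] instead of left/right. The annotate_clustering_tree hunk also moves from call_llm_with_validation to a single call_llm that takes a validation_config and returns a result object; a fuller sketch of that API appears after the map.py section below. Here is a minimal standalone sketch of the collapse logic, on a toy tree with hypothetical distances (not the DocETL API):

import numpy as np

def get_tree_distances(t):
    # Gaps between a node's merge distance and each child's merge distance.
    res = set()
    if "distance" in t:
        res.update(t["distance"] - c["distance"]
                   for c in t.get("children", []) if "distance" in c)
    for c in t.get("children", []):
        res.update(get_tree_distances(c))
    return res

def collapse(t, parent_dist=None, threshold=None):
    if "children" not in t:
        return [t]  # leaves pass through untouched
    if (parent_dist is not None and threshold is not None
            and parent_dist - t["distance"] < threshold):
        # Small gap to the parent: splice this node out, promote its children.
        return [g for c in t["children"] for g in collapse(c, parent_dist, threshold)]
    out = dict(t)
    out["children"] = [g for c in t["children"]
                       for g in collapse(c, t["distance"], threshold)]
    return [out]

tree = {"distance": 1.0, "children": [
    {"distance": 0.95, "children": [{"leaf": "a"}, {"leaf": "b"}]},
    {"distance": 0.2, "children": [{"leaf": "c"}, {"leaf": "d"}]},
]}
gaps = np.array(sorted(get_tree_distances(tree)))  # [0.05, 0.8]
threshold = gaps[int(len(gaps) * 0.5)]             # "collapse: 0.5" -> 0.8
print(len(collapse(tree, threshold=threshold)[0]["children"]))  # 3

With collapse: 0.5 in the op config, the threshold lands at the median gap, so the near-duplicate split (gap 0.05) is folded into the root while the well-separated one (gap 0.8) survives.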
7 changes: 5 additions & 2 deletions docetl/operations/equijoin.py
@@ -82,9 +82,12 @@ def compare_pair(
{"is_match": "bool"},
timeout_seconds=timeout_seconds,
max_retries_per_timeout=max_retries_per_timeout,
bypass_cache=self.config.get("bypass_cache", False),
)
output = self.runner.api.parse_llm_response(response, {"is_match": "bool"})[0]
return output["is_match"], completion_cost(response)
output = self.runner.api.parse_llm_response(
response.response, {"is_match": "bool"}
)[0]
return output["is_match"], response.total_cost

def syntax_check(self) -> None:
"""
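Note: the only substantive change here is the calling convention. call_llm now returns a result object rather than a raw completion, so the caller reads the completion from response.response and the pre-computed cost from response.total_cost instead of pricing the raw completion with completion_cost, and bypass_cache is threaded through from the config. A sketch of the before/after with stand-in objects (only the .response and .total_cost fields are attested by this diff):

from types import SimpleNamespace

api = SimpleNamespace(  # stub with the same parse_llm_response shape
    parse_llm_response=lambda raw, schema: [{"is_match": True}]
)
result = SimpleNamespace(response={"choices": []}, total_cost=0.0003)

# Old: parse the raw completion, then price it with completion_cost(response).
# New: the result object carries both the completion and its accumulated cost.
output = api.parse_llm_response(result.response, {"is_match": "bool"})[0]
print(output["is_match"], result.total_cost)  # True 0.0003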
78 changes: 5 additions & 73 deletions docetl/operations/filter.py
@@ -5,13 +5,13 @@

from jinja2 import Template

from docetl.operations.base import BaseOperation
from docetl.operations.map import MapOperation
from docetl.operations.utils import (
RichLoopBar,
)


class FilterOperation(BaseOperation):
class FilterOperation(MapOperation):
def syntax_check(self) -> None:
"""
Checks the configuration of the FilterOperation for required keys and valid structure.
@@ -110,77 +110,9 @@ def execute(
)
)

if self.status:
self.status.start()

def _process_filter_item(item: Dict) -> Tuple[Optional[Dict], float]:
prompt_template = Template(self.config["prompt"])
prompt = prompt_template.render(input=item)

def validation_fn(response: Dict[str, Any]):
output = self.runner.api.parse_llm_response(
response,
self.config["output"]["schema"],
manually_fix_errors=self.manually_fix_errors,
)[0]
for key, value in item.items():
if key not in self.config["output"]["schema"]:
output[key] = value
if self.runner.api.validate_output(self.config, output, self.console):
return output, True
return output, False

output, cost, is_valid = self.runner.api.call_llm_with_validation(
[{"role": "user", "content": prompt}],
model=self.config.get("model", self.default_model),
operation_type="filter",
schema=self.config["output"]["schema"],
llm_call_fn=lambda messages: self.runner.api.call_llm(
self.config.get("model", self.default_model),
"filter",
messages,
self.config["output"]["schema"],
console=self.console,
timeout_seconds=self.config.get("timeout", 120),
max_retries_per_timeout=self.config.get(
"max_retries_per_timeout", 2
),
),
validation_fn=validation_fn,
val_rule=self.config.get("validate", []),
num_retries=self.num_retries_on_validate_failure,
console=self.console,
)
results, total_cost = super().execute(input_data)

if is_valid:
return output, cost

return None, cost

with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
futures = [
executor.submit(_process_filter_item, item) for item in input_data
]
results = []
total_cost = 0
pbar = RichLoopBar(
range(len(futures)),
desc=f"Processing {self.config['name']} (filter) on all documents",
console=self.console,
)
for i in pbar:
future = futures[i]
result, item_cost = future.result()
total_cost += item_cost
if result is not None:
if is_build:
results.append(result)
else:
if result.get(filter_key, False):
results.append(result)
pbar.update(1)

if self.status:
self.status.start()
# Drop records with filter_key values that are False
results = [result for result in results if result[filter_key]]

return results, total_cost
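Note: this is the largest simplification in the commit. FilterOperation now subclasses MapOperation instead of BaseOperation, so the hand-rolled thread pool, progress bar, and per-item validation loop shown above are deleted; execute delegates to MapOperation.execute and then drops records whose boolean filter_key is false. A minimal sketch of the inheritance pattern, with stand-in classes rather than DocETL code (the real filter_key comes from the output schema's single boolean key):

from typing import Any, Dict, List, Tuple

class MapOp:
    def execute(self, input_data: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], float]:
        # Stand-in for the real map step: pretend an LLM tagged each record.
        return [dict(item, keep=item["x"] % 2 == 1) for item in input_data], 0.0

class FilterOp(MapOp):
    filter_key = "keep"  # stand-in for the boolean key in the output schema

    def execute(self, input_data):
        results, total_cost = super().execute(input_data)
        # Drop records whose filter_key is False, as the new code does.
        return [r for r in results if r[self.filter_key]], total_cost

out, cost = FilterOp().execute([{"x": 1}, {"x": 2}, {"x": 3}])
print([r["x"] for r in out])  # [1, 3]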
91 changes: 37 additions & 54 deletions docetl/operations/map.py
@@ -153,59 +153,42 @@ def validation_fn(response: Dict[str, Any]):
return output, False

self.runner.rate_limiter.try_acquire("call", weight=1)
if "gleaning" in self.config:
output, cost, success = self.runner.api.call_llm_with_validation(
[{"role": "user", "content": prompt}],
model=self.config.get("model", self.default_model),
operation_type="map",
schema=self.config["output"]["schema"],
llm_call_fn=lambda messages: self.runner.api.call_llm_with_gleaning(
self.config.get("model", self.default_model),
"map",
messages,
self.config["output"]["schema"],
self.config["gleaning"]["validation_prompt"],
self.config["gleaning"]["num_rounds"],
self.console,
timeout_seconds=self.config.get("timeout", 120),
max_retries_per_timeout=self.config.get(
"max_retries_per_timeout", 2
),
verbose=self.config.get("verbose", False),
),
validation_fn=validation_fn,
val_rule=self.config.get("validate", []),
num_retries=self.num_retries_on_validate_failure,
console=self.console,
)
else:
output, cost, success = self.runner.api.call_llm_with_validation(
[{"role": "user", "content": prompt}],
model=self.config.get("model", self.default_model),
operation_type="map",
schema=self.config["output"]["schema"],
llm_call_fn=lambda messages: self.runner.api.call_llm(
self.config.get("model", self.default_model),
"map",
messages,
self.config["output"]["schema"],
tools=self.config.get("tools", None),
console=self.console,
timeout_seconds=self.config.get("timeout", 120),
max_retries_per_timeout=self.config.get(
"max_retries_per_timeout", 2
),
),
validation_fn=validation_fn,
val_rule=self.config.get("validate", []),
num_retries=self.num_retries_on_validate_failure,
console=self.console,
)
llm_result = self.runner.api.call_llm(
self.config.get("model", self.default_model),
"map",
[{"role": "user", "content": prompt}],
self.config["output"]["schema"],
tools=self.config.get("tools", None),
scratchpad=None,
timeout_seconds=self.config.get("timeout", 120),
max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2),
validation_config=(
{
"num_retries": self.num_retries_on_validate_failure,
"val_rule": self.config.get("validate", []),
"validation_fn": validation_fn,
}
if self.config.get("validate", None)
else None
),
gleaning_config=self.config.get("gleaning", None),
verbose=self.config.get("verbose", False),
bypass_cache=self.config.get("bypass_cache", False),
)

if success:
return output, cost
if llm_result.validated:
# Parse the response
output = self.runner.api.parse_llm_response(
llm_result.response,
schema=self.config["output"]["schema"],
tools=self.config.get("tools", None),
manually_fix_errors=self.manually_fix_errors,
)[0]
# Augment the output with the original item
output = {**item, **output}
return output, llm_result.total_cost

return None, cost
return None, llm_result.total_cost

with ThreadPoolExecutor(max_workers=self.max_batch_size) as executor:
futures = [executor.submit(_process_map_item, item) for item in input_data]
@@ -375,17 +358,17 @@ def process_prompt(item, prompt_config):
[{"role": "user", "content": prompt}],
local_output_schema,
tools=prompt_config.get("tools", None),
console=self.console,
timeout_seconds=self.config.get("timeout", 120),
max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2),
bypass_cache=self.config.get("bypass_cache", False),
)
output = self.runner.api.parse_llm_response(
response,
response.response,
schema=local_output_schema,
tools=prompt_config.get("tools", None),
manually_fix_errors=self.manually_fix_errors,
)[0]
return output, completion_cost(response)
return output, response.total_cost

with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
if "prompts" in self.config:
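Note: map.py shows the shape of the refactor most clearly. The two branches that chose between call_llm_with_gleaning and call_llm_with_validation collapse into a single call_llm that accepts optional validation_config and gleaning_config dicts plus bypass_cache, and returns a result object with response, total_cost, and validated fields; parsing moves to the caller via parse_llm_response. A hedged sketch of what such a unified entry point can look like (only the parameter and field names visible in the diff are taken from it; the internals are assumed stand-ins, not DocETL's implementation):

from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional

@dataclass
class LLMResult:
    response: Any      # raw completion, handed to parse_llm_response by callers
    total_cost: float  # accumulated cost across retries and gleaning rounds
    validated: bool    # False only if a configured validation ran out of retries

def call_llm(model: str, op_type: str, messages: List[Dict[str, str]],
             output_schema: Dict[str, str],
             validation_config: Optional[Dict[str, Any]] = None,
             gleaning_config: Optional[Dict[str, Any]] = None,
             bypass_cache: bool = False) -> LLMResult:
    # Stand-in completion; a real implementation would call the model here.
    response, cost = {"content": messages[-1]["content"]}, 0.001
    if gleaning_config:
        # Refinement rounds that call_llm_with_gleaning used to own.
        cost += 0.001 * gleaning_config.get("num_rounds", 1)
    validated = True
    if validation_config:
        validate: Callable = validation_config["validation_fn"]
        for _ in range(validation_config.get("num_retries", 0) + 1):
            _, validated = validate(response)
            if validated:
                break
    return LLMResult(response=response, total_cost=cost, validated=validated)

result = call_llm("gpt-4o-mini", "map", [{"role": "user", "content": "hi"}],
                  {"summary": "string"},
                  validation_config={"num_retries": 2,
                                     "validation_fn": lambda r: (r, True)})
print(result.validated, result.total_cost)  # True 0.001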
(Diff for the remaining 16 changed files not shown.)
