Merge pull request #64 from redhog/throttle
Throttle
shreyashankar authored Oct 7, 2024
2 parents efc4291 + 69b491e commit 9524e2e
Showing 29 changed files with 1,235 additions and 940 deletions.
69 changes: 69 additions & 0 deletions .github/workflows/stage.yml
@@ -0,0 +1,69 @@
name: Create or Update PR from staging to main

on:
  push:
    branches:
      - staging
  pull_request:
    types:
      - closed
    branches:
      - staging

jobs:
  create-or-update-pr:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
        with:
          fetch-depth: 0

      - name: Check for existing PR
        id: check_pr
        uses: actions/github-script@v6
        with:
          github-token: ${{secrets.GITHUB_TOKEN}}
          script: |
            const { data: pullRequests } = await github.rest.pulls.list({
              owner: context.repo.owner,
              repo: context.repo.repo,
              state: 'open',
              head: 'staging',
              base: 'main'
            });
            return pullRequests.length > 0 ? 'true' : 'false';
      - name: Create Pull Request
        if: steps.check_pr.outputs.result == 'false'
        uses: repo-sync/pull-request@v2
        with:
          source_branch: "staging"
          destination_branch: "main"
          pr_title: "Merge staging into main"
          pr_body: "This PR was automatically created to merge changes from staging into main."
          github_token: ${{ secrets.GITHUB_TOKEN }}

      - name: Update Pull Request
        if: steps.check_pr.outputs.result == 'true'
        uses: actions/github-script@v6
        with:
          github-token: ${{secrets.GITHUB_TOKEN}}
          script: |
            const { data: pullRequests } = await github.rest.pulls.list({
              owner: context.repo.owner,
              repo: context.repo.repo,
              state: 'open',
              head: 'staging',
              base: 'main'
            });
            if (pullRequests.length > 0) {
              const prNumber = pullRequests[0].number;
              await github.rest.pulls.update({
                owner: context.repo.owner,
                repo: context.repo.repo,
                pull_number: prNumber,
                body: 'This PR has been automatically updated with the latest changes from staging.'
              });
              console.log(`Updated PR #${prNumber}`);
            }
21 changes: 13 additions & 8 deletions docetl/builder.py
@@ -20,7 +20,7 @@
from docetl.optimizers.map_optimizer import MapOptimizer
from docetl.optimizers.reduce_optimizer import ReduceOptimizer
from docetl.optimizers.utils import LLMClient
from docetl.utils import load_config
from docetl.config_wrapper import ConfigWrapper

install(show_locals=True)

@@ -77,7 +77,7 @@ def items(self):
return [(key, self[key]) for key in self.keys()]


class Optimizer:
class Optimizer(ConfigWrapper):
@classmethod
def from_yaml(cls, yaml_file: str, **kwargs):
# check that file ends with .yaml or .yml
@@ -88,8 +88,9 @@ def from_yaml(cls, yaml_file: str, **kwargs):

base_name = yaml_file.rsplit(".", 1)[0]
suffix = yaml_file.split("/")[-1].split(".")[0]
config = load_config(yaml_file)
return cls(config, base_name, suffix, **kwargs)
return super(Optimizer, cls).from_yaml(
yaml_file, base_name=base_name, yaml_file_suffix=suffix, **kwargs
)

def __init__(
self,
@@ -135,11 +136,9 @@ def __init__(
The method also calls print_optimizer_config() to display the initial configuration.
"""
self.config = config
self.console = Console()
ConfigWrapper.__init__(self, config, max_threads)
self.optimized_config = copy.deepcopy(self.config)
self.llm_client = LLMClient(model)
self.max_threads = max_threads or (os.cpu_count() or 1) * 4
self.operations_cost = 0
self.timeout = timeout
self.selectivities = defaultdict(dict)
@@ -163,7 +162,6 @@ def __init__(
if self.config.get("optimizer_config", {}).get("sample_sizes", {}):
self.sample_size_map.update(self.config["optimizer_config"]["sample_sizes"])

self.status = None
self.step_op_to_optimized_ops = {}

self.print_optimizer_config()
@@ -193,6 +191,7 @@ def syntax_check(self):
try:
operation_class = get_operation(operation_type)
operation_class(
self,
operation_config,
self.config.get("default_model", "gpt-4o-mini"),
self.max_threads,
@@ -973,6 +972,7 @@ def _get_sample_data(
f"Dataset '{dataset_name}' not found in config or previous steps."
)
dataset = Dataset(
runner=self,
type=dataset_config["type"],
path_or_data=dataset_config["path"],
parsing=dataset_config.get("parsing", []),
@@ -1115,6 +1115,7 @@ def _optimize_reduce(
List[Dict[str, Any]]: The optimized operation configuration.
"""
reduce_optimizer = ReduceOptimizer(
self,
self.config,
self.console,
self.llm_client,
@@ -1157,6 +1158,7 @@ def _optimize_equijoin(
new_right_name = right_name
for _ in range(max_iterations):
join_optimizer = JoinOptimizer(
self,
self.config,
op_config,
self.console,
@@ -1278,6 +1280,7 @@ def _optimize_map(
List[Dict[str, Any]]: The optimized operation configuration.
"""
map_optimizer = MapOptimizer(
self,
self.config,
self.console,
self.llm_client,
@@ -1307,6 +1310,7 @@ def _optimize_resolve(
List[Dict[str, Any]]: The optimized operation configuration.
"""
optimized_config, cost = JoinOptimizer(
self,
self.config,
op_config,
self.console,
@@ -1354,6 +1358,7 @@ def _run_operation(
operation_class = get_operation(op_config["type"])

oc_kwargs = {
"runner": self,
"config": op_config,
"default_model": self.config["default_model"],
"max_threads": self.max_threads,
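Note: after this PR, every operation receives the runner (a ConfigWrapper) as its first constructor argument, matching the oc_kwargs above and BaseOperation below. A minimal sketch of the new call shape, assuming get_operation is importable as used in builder.py and that pipeline.yaml is a hypothetical config file:

from docetl.config_wrapper import ConfigWrapper
from docetl.operations import get_operation  # assumed import path

runner = ConfigWrapper.from_yaml("pipeline.yaml")  # hypothetical config file
op_config = {"name": "extract", "type": "map", "prompt": "..."}

operation = get_operation(op_config["type"])(
    runner,  # new: gives the operation access to runner.api and its rate limiter
    op_config,
    runner.default_model,
    runner.max_threads,
)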
68 changes: 68 additions & 0 deletions docetl/config_wrapper.py
@@ -0,0 +1,68 @@
import os
from docetl.utils import load_config
from typing import Any, Dict, List, Optional, Tuple, Union
from docetl.operations.utils import APIWrapper
from rich.console import Console
import pyrate_limiter
from inspect import isawaitable
import math


class BucketCollection(pyrate_limiter.BucketFactory):
    def __init__(self, **buckets):
        self.clock = pyrate_limiter.TimeClock()
        self.buckets = buckets

    def wrap_item(self, name: str, weight: int = 1) -> pyrate_limiter.RateItem:
        now = self.clock.now()

        async def wrap_async():
            return pyrate_limiter.RateItem(name, await now, weight=weight)

        def wrap_sync():
            return pyrate_limiter.RateItem(name, now, weight=weight)

        return wrap_async() if isawaitable(now) else wrap_sync()

    def get(self, item: pyrate_limiter.RateItem) -> pyrate_limiter.AbstractBucket:
        if item.name not in self.buckets:
            return self.buckets["unknown"]
        return self.buckets[item.name]


class ConfigWrapper(object):
    @classmethod
    def from_yaml(cls, yaml_file: str, **kwargs):
        config = load_config(yaml_file)
        return cls(config, **kwargs)

    def __init__(self, config: Dict, max_threads: int = None):
        self.config = config
        self.default_model = self.config.get("default_model", "gpt-4o-mini")
        self.console = Console()
        self.max_threads = max_threads or (os.cpu_count() or 1) * 4
        self.status = None

        buckets = {
            param: pyrate_limiter.InMemoryBucket(
                [
                    pyrate_limiter.Rate(
                        param_limit["count"],
                        param_limit["per"]
                        * getattr(
                            pyrate_limiter.Duration,
                            param_limit.get("unit", "SECOND").upper(),
                        ),
                    )
                    for param_limit in param_limits
                ]
            )
            for param, param_limits in self.config.get("rate_limits", {}).items()
        }
        buckets["unknown"] = pyrate_limiter.InMemoryBucket(
            [pyrate_limiter.Rate(math.inf, 1)]
        )
        bucket_factory = BucketCollection(**buckets)
        self.rate_limiter = pyrate_limiter.Limiter(bucket_factory, max_delay=math.inf)

        self.api = APIWrapper(self)
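Note: the rate_limits block consumed in __init__ above maps a bucket name to a list of {count, per, unit} entries, with unit defaulting to SECOND; any item name without a bucket falls through to the unlimited "unknown" bucket. A hedged sketch of such a config, where the bucket names ("llm_call", "embedding_call") are illustrative assumptions rather than names confirmed by this PR:

from docetl.config_wrapper import ConfigWrapper

config = {
    "default_model": "gpt-4o-mini",
    "rate_limits": {
        "llm_call": [
            {"count": 10, "per": 1, "unit": "SECOND"},   # burst ceiling
            {"count": 500, "per": 1, "unit": "MINUTE"},  # sustained ceiling
        ],
        "embedding_call": [{"count": 1000, "per": 1, "unit": "MINUTE"}],
    },
}

wrapper = ConfigWrapper(config)
# wrapper.rate_limiter blocks (up to max_delay=inf) any item whose named
# bucket is over its rate; unnamed items hit the infinite "unknown" bucket.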
36 changes: 19 additions & 17 deletions docetl/dataset.py
@@ -6,22 +6,6 @@
from docetl.schemas import ParsingTool


def process_item(
item: Dict[str, Any],
input_key: str,
output_key: str,
func: Callable,
**function_kwargs: Dict[str, Any],
):
if input_key not in item:
raise ValueError(f"Input key {input_key} not found in item: {item}")
result = func(item[input_key], **function_kwargs)
if isinstance(result, list):
return [item.copy() | {output_key: res} for res in result]
else:
return [item | {output_key: result}]


def create_parsing_tool_map(
parsing_tools: Optional[List[ParsingTool]],
) -> Dict[str, ParsingTool]:
@@ -57,6 +41,7 @@ class Dataset:

def __init__(
self,
runner,
type: str,
path_or_data: Union[str, List[Dict]],
source: str = "local",
@@ -73,6 +58,7 @@ def __init__(
parsing (List[Dict[str, str]], optional): A list of parsing tools to apply to the data.
user_defined_parsing_tool_map (Dict[str, ParsingTool], optional): A map of user-defined parsing tools.
"""
self.runner = runner
self.type = self._validate_type(type)
self.source = self._validate_source(source)
self.path_or_data = self._validate_path_or_data(path_or_data)
@@ -224,6 +210,22 @@ def load(self) -> List[Dict]:

return self._apply_parsing_tools(data)

def _process_item(
self,
item: Dict[str, Any],
input_key: str,
output_key: str,
func: Callable,
**function_kwargs: Dict[str, Any],
):
if input_key not in item:
raise ValueError(f"Input key {input_key} not found in item: {item}")
result = func(item[input_key], **function_kwargs)
if isinstance(result, list):
return [item.copy() | {output_key: res} for res in result]
else:
return [item | {output_key: result}]

def _apply_parsing_tools(self, data: List[Dict]) -> List[Dict]:
"""
Apply parsing tools to the data.
@@ -266,7 +268,7 @@ def _apply_parsing_tools(self, data: List[Dict]) -> List[Dict]:
with ThreadPoolExecutor() as executor:
futures = [
executor.submit(
process_item,
self._process_item,
item,
input_key,
output_key,
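Note: process_item moved from a module-level function to Dataset._process_item, presumably so parsing can reach self.runner in the future; its fan-out contract is unchanged. A small illustration with a hypothetical splitting parser (all names here are made up):

def split_paragraphs(text):
    return text.split("\n\n")

# Given item = {"doc": "intro\n\nbody"}, a call like
#   dataset._process_item(item, "doc", "chunk", split_paragraphs)
# returns one copy of the item per list element:
#   [{"doc": "intro\n\nbody", "chunk": "intro"},
#    {"doc": "intro\n\nbody", "chunk": "body"}]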
3 changes: 3 additions & 0 deletions docetl/operations/base.py
@@ -5,13 +5,15 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple

from docetl.operations.utils import APIWrapper
from rich.console import Console
from rich.status import Status


class BaseOperation(ABC):
def __init__(
self,
runner: "ConfigWrapper",
config: Dict,
default_model: str,
max_threads: int,
@@ -31,6 +33,7 @@ def __init__(
"""
assert "name" in config, "Operation must have a name"
assert "type" in config, "Operation must have a type"
self.runner = runner
self.config = config
self.default_model = default_model
self.max_threads = max_threads
17 changes: 11 additions & 6 deletions docetl/operations/clustering_utils.py
@@ -6,12 +6,12 @@

from typing import Dict, List, Tuple

from docetl.operations.utils import gen_embedding
from docetl.operations.utils import APIWrapper
from docetl.utils import completion_cost


def get_embeddings_for_clustering(
items: List[Dict], sampling_config: Dict
items: List[Dict], sampling_config: Dict, api_wrapper: APIWrapper
) -> Tuple[List[List[float]], float]:
embedding_model = sampling_config.get("embedding_model", "text-embedding-3-small")
embedding_keys = sampling_config.get("embedding_keys")
@@ -31,7 +31,7 @@
" ".join(str(item[key]) for key in embedding_keys if key in item)[:10000]
for item in batch
]
response = gen_embedding(embedding_model, texts)
response = api_wrapper.gen_embedding(embedding_model, texts)
embeddings.extend([data["embedding"] for data in response["data"]])
cost += completion_cost(response)

@@ -61,7 +61,10 @@ def get_embeddings_for_clustering_with_st(


def cluster_documents(
documents: List[Dict], sampling_config: Dict, sample_size: int
documents: List[Dict],
sampling_config: Dict,
sample_size: int,
api_wrapper: APIWrapper,
) -> Tuple[Dict[int, List[Dict]], float]:
"""
Cluster documents using KMeans clustering algorithm.
@@ -70,11 +73,13 @@
documents (List[Dict]): The list of documents to cluster.
sampling_config (Dict): The sampling configuration. If embedding_keys is not specified, all keys in the document are used. If embedding_model is not specified, text-embedding-3-small is used; if embedding_model is sentence-transformer, all-MiniLM-L6-v2 is used.
sample_size (int): The number of clusters to create.
api_wrapper (APIWrapper): The API wrapper to use for embedding.
Returns:
Tuple[Dict[int, List[Dict]], float]: A dictionary of clusters, where each cluster is a list of documents, and the total embedding cost.
"""
embeddings, cost = get_embeddings_for_clustering(documents, sampling_config)
embeddings, cost = get_embeddings_for_clustering(
documents, sampling_config, api_wrapper
)

from sklearn.cluster import KMeans

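Note: the embedding helpers now take the APIWrapper explicitly instead of calling the old module-level gen_embedding. A minimal sketch of the new call shape, assuming runner is a ConfigWrapper whose .api attribute is the APIWrapper created in config_wrapper.py:

from docetl.operations.clustering_utils import cluster_documents

docs = [{"title": "a", "body": "..."}, {"title": "b", "body": "..."}]
sampling_config = {
    "embedding_model": "text-embedding-3-small",
    "embedding_keys": ["title", "body"],
}

clusters, cost = cluster_documents(
    docs,
    sampling_config,
    sample_size=2,           # number of KMeans clusters to create
    api_wrapper=runner.api,  # previously implicit via gen_embedding
)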