diff --git a/benchmark/dbally_benchmark/context_benchmark.py b/benchmark/dbally_benchmark/context_benchmark.py
deleted file mode 100644
index e9108759..00000000
--- a/benchmark/dbally_benchmark/context_benchmark.py
+++ /dev/null
@@ -1,252 +0,0 @@
-# pylint: disable=missing-return-doc, missing-param-doc, missing-function-docstring, missing-class-docstring, broad-exception-caught
-import asyncio
-import json
-import os
-import traceback
-import typing
-from copy import deepcopy
-from dataclasses import dataclass, field
-
-import polars as pl
-import sqlalchemy
-import tqdm.asyncio
-from sqlalchemy import create_engine
-from sqlalchemy.ext.automap import AutomapBase, automap_base
-from typing_extensions import TypeAlias
-
-import dbally
-from dbally import SqlAlchemyBaseView, decorators
-from dbally.collection import Collection
-from dbally.context import BaseCallerContext
-from dbally.iql import IQLError
-from dbally.llms.litellm import LiteLLM
-
-SQLITE_DB_FILE_REL_PATH = "../../examples/recruiting/data/candidates.db"
-engine = create_engine(f"sqlite:///{os.path.abspath(SQLITE_DB_FILE_REL_PATH)}")
-
-Base: AutomapBase = automap_base()
-Base.prepare(autoload_with=engine)
-
-Candidate = Base.classes.candidates
-
-
-@dataclass
-class MyData(BaseCallerContext):
-    first_name: str
-    surname: str
-    position: str
-    years_of_experience: int
-    university: str
-    skills: typing.List[str]
-    country: str
-
-
-@dataclass
-class OpenPosition(BaseCallerContext):
-    position: str
-    min_years_of_experience: int
-    graduated_from_university: str
-    required_skills: typing.List[str]
-
-
-class CandidateView(SqlAlchemyBaseView):
-    """
-    A view for retrieving candidates from the database.
-    """
-
-    def get_select(self) -> sqlalchemy.Select:
-        """
-        Creates the initial SqlAlchemy select object, which will be used to build the query.
-        """
-        return sqlalchemy.select(Candidate)
-
-    @decorators.view_filter()
-    def at_least_experience(self, years: typing.Union[int, OpenPosition]) -> sqlalchemy.ColumnElement:
-        """
-        Filters candidates with at least `years` of experience.
-        """
-        if isinstance(years, OpenPosition):
-            years = years.min_years_of_experience
-
-        return Candidate.years_of_experience >= years
-
-    @decorators.view_filter()
-    def at_most_experience(self, years: typing.Union[int, MyData]) -> sqlalchemy.ColumnElement:
-        if isinstance(years, MyData):
-            years = years.years_of_experience
-
-        return Candidate.years_of_experience <= years
-
-    @decorators.view_filter()
-    def has_position(self, position: typing.Union[str, OpenPosition]) -> sqlalchemy.ColumnElement:
-        if isinstance(position, OpenPosition):
-            position = position.position
-
-        return Candidate.position == position
-
-    @decorators.view_filter()
-    def senior_data_scientist_position(self) -> sqlalchemy.ColumnElement:
-        """
-        Filters candidates that can be considered for a senior data scientist position.
-        """
-        return sqlalchemy.and_(
-            Candidate.position.in_(["Data Scientist", "Machine Learning Engineer", "Data Engineer"]),
-            Candidate.years_of_experience >= 3,
-        )
-
-    @decorators.view_filter()
-    def from_country(self, country: typing.Union[str, MyData]) -> sqlalchemy.ColumnElement:
-        """
-        Filters candidates from a specific country.
- """ - if isinstance(country, MyData): - return Candidate.country == country.country - - return Candidate.country == country - - @decorators.view_filter() - def graduated_from_university(self, university: typing.Union[str, MyData]) -> sqlalchemy.ColumnElement: - if isinstance(university, MyData): - university = university.university - - return Candidate.university == university - - @decorators.view_filter() - def has_skill(self, skill: str) -> sqlalchemy.ColumnElement: - return Candidate.skills.like(f"%{skill}%") - - @decorators.view_filter() - def knows_data_analysis(self) -> sqlalchemy.ColumnElement: - return Candidate.tags.like("%Data Analysis%") - - @decorators.view_filter() - def knows_python(self) -> sqlalchemy.ColumnElement: - return Candidate.skills.like("%Python%") - - @decorators.view_filter() - def first_name_is(self, first_name: typing.Union[str, MyData]) -> sqlalchemy.ColumnElement: - if isinstance(first_name, MyData): - first_name = first_name.first_name - - return Candidate.name.startswith(first_name) - - -OpenAILLMName: TypeAlias = typing.Literal["gpt-3.5-turbo", "gpt-3.5-turbo-instruct", "gpt-4-turbo", "gpt-4o"] - - -def setup_collection(model_name: OpenAILLMName) -> Collection: - llm = LiteLLM(model_name=model_name) - - collection = dbally.create_collection("recruitment", llm) - collection.add(CandidateView, lambda: CandidateView(engine)) - - return collection - - -async def generate_iql_from_question( - collection: Collection, - model_name: OpenAILLMName, - question: str, - contexts: typing.Optional[typing.List[BaseCallerContext]], -) -> typing.Tuple[str, OpenAILLMName, typing.Optional[str]]: - try: - result = await collection.ask(question, contexts=contexts, dry_run=True) - except IQLError as e: - exc_pretty = traceback.format_exception_only(e.__class__, e)[0] - return question, model_name, f"FAILED: {exc_pretty}({e.source})" - except Exception as e: - exc_pretty = traceback.format_exception_only(e.__class__, e)[0] - return question, model_name, f"FAILED: {exc_pretty}" - - out = result.metadata.get("iql") - if out is None: - return question, model_name, None - - return question, model_name, out.replace('"', "'") - - -@dataclass -class BenchmarkConfig: - dataset_path: str - out_path: str - n_repeats: int = 5 - llms: typing.List[OpenAILLMName] = field(default_factory=lambda: ["gpt-3.5-turbo", "gpt-4-turbo", "gpt-4o"]) - - -async def main(config_: BenchmarkConfig): - test_set = None - with open(config_.dataset_path, encoding="utf-8") as file: - test_set = json.load(file) - - contexts = [ - MyData( - first_name="John", - surname="Smith", - years_of_experience=4, - position="Data Engineer", - university="University of Toronto", - skills=["Python"], - country="United Kingdom", - ), - OpenPosition( - position="Machine Learning Engineer", - graduated_from_university="Stanford Univeristy", - min_years_of_experience=1, - required_skills=["Python", "SQL"], - ), - ] - - tasks: typing.List[asyncio.Task] = [] - for model_name in config.llms: - collection = setup_collection(model_name) - for test_case in test_set: - for _ in range(config.n_repeats): - task = asyncio.create_task( - generate_iql_from_question(collection, model_name, test_case["question"], contexts=contexts) - ) - tasks.append(task) - - output_data = {test_case["question"]: test_case for test_case in test_set} - empty_answers = {str(llm_name): [] for llm_name in config.llms} - - total_iter = len(config.llms) * len(test_set) * config.n_repeats - for task in tqdm.asyncio.tqdm.as_completed(tasks, total=total_iter): - 
-        question, llm_name, answer = await task
-        if "answers" not in output_data[question]:
-            output_data[question]["answers"] = deepcopy(empty_answers)
-
-        output_data[question]["answers"][llm_name].append(answer)
-
-    df_out_raw = pl.DataFrame(list(output_data.values()))
-
-    df_out = (
-        df_out_raw.unnest("answers")
-        .unpivot(
-            on=pl.selectors.starts_with("gpt"),
-            index=["question", "correct_answer", "context"],
-            variable_name="model",
-            value_name="answer",
-        )
-        .explode("answer")
-        .group_by(["context", "model"])
-        .agg(
-            [
-                (pl.col("correct_answer") == pl.col("answer")).mean().alias("frac_hits"),
-                (pl.col("correct_answer") == pl.col("answer")).sum().alias("n_hits"),
-            ]
-        )
-        .sort(["model", "context"])
-    )
-
-    print(df_out)
-
-    with open(config_.out_path, "w", encoding="utf-8") as file:
-        file.write(json.dumps(df_out_raw.to_dicts(), indent=2))
-
-
-if __name__ == "__main__":
-    config = BenchmarkConfig(
-        dataset_path="dataset/context_dataset.json", out_path="../../context_benchmark_output.json"
-    )
-
-    asyncio.run(main(config))
diff --git a/benchmark/dbally_benchmark/dataset/context_dataset.json b/benchmark/dbally_benchmark/dataset/context_dataset.json
deleted file mode 100644
index c37f38d1..00000000
--- a/benchmark/dbally_benchmark/dataset/context_dataset.json
+++ /dev/null
@@ -1,62 +0,0 @@
-[
-    {
-        "question": "Find me French candidates suitable for my position with at least 1 year of experience.",
-        "correct_answer": "from_country('France') AND has_position(AskerContext()) AND at_least_experience(1)",
-        "context": false
-    },
-    {
-        "question": "Please find me candidates from my country who have at most 4 years of experience.",
-        "correct_answer": "from_country(AskerContext()) AND at_most_experience(4)",
-        "context": true
-    },
-    {
-        "question": "Find me candidates who graduated from Stanford University and work as Software Engineers.",
-        "correct_answer": "graduated_from_university('Stanford University') AND has_position('Software Engineer')",
-        "context": false
-    },
-    {
-        "question": "Find me candidates who graduated from my university",
-        "correct_answer": "graduated_from_university(AskerContext())",
-        "context": true
-    },
-    {
-        "question": "Could you find me candidates with at most as much experience as I have who also know Python?",
-        "correct_answer": "at_most_experience(AskerContext()) AND knows_python()",
-        "context": true
-    },
-    {
-        "question": "Please find me candidates who know Data Analysis and Python",
-        "correct_answer": "knows_python() AND knows_data_analysis()",
-        "context": false
-    },
-    {
-        "question": "Find me candidates with at least minimal required experience for the currently open position.",
-        "correct_answer": "at_least_experience(AskerContext())",
-        "context": true
-    },
-    {
-        "question": "List candidates with between 2 and 6 years of experience.",
-        "correct_answer": "at_least_experience(2) AND at_most_experience(6)",
-        "context": false
-    },
-    {
-        "question": "Find me candidates who currently have the same position as we look for in our company?",
-        "correct_answer": "has_position(AskerContext())",
-        "context": true
-    },
-    {
-        "question": "Please find me senior data scientist candidates who know Data Analysis and come from my country",
-        "correct_answer": "senior_data_scientist_position() AND has_skill('Data Analysis') AND from_country(AskerContext())",
-        "context": true
-    },
-    {
-        "question": "Find me candidates that have the same first name as me",
-        "correct_answer": "first_name_is(AskerContext())",
-        "context": true
-    },
-    {
-        "question": "List candidates named Mohammed from India",
-        "correct_answer": "first_name_is('Mohammed') AND from_country('India')",
-        "context": false
-    }
-]
diff --git a/benchmark/dbally_benchmark/e2e_benchmark.py b/benchmark/dbally_benchmark/e2e_benchmark.py
deleted file mode 100644
index f2d86b58..00000000
--- a/benchmark/dbally_benchmark/e2e_benchmark.py
+++ /dev/null
@@ -1,153 +0,0 @@
-import asyncio
-import json
-import os
-from functools import partial
-from pathlib import Path
-from typing import Any, List
-
-import hydra
-import neptune
-from dbally_benchmark.config import BenchmarkConfig
-from dbally_benchmark.constants import VIEW_REGISTRY, EvaluationType, ViewName
-from dbally_benchmark.dataset.bird_dataset import BIRDDataset, BIRDExample
-from dbally_benchmark.paths import PATH_EXPERIMENTS
-from dbally_benchmark.text2sql.metrics import calculate_dataset_metrics
-from dbally_benchmark.text2sql.text2sql_result import Text2SQLResult
-from dbally_benchmark.utils import batch, get_datetime_str, set_up_gitlab_metadata
-from hydra.utils import instantiate
-from loguru import logger
-from neptune.utils import stringify_unsupported
-from omegaconf import DictConfig
-from sqlalchemy import create_engine
-
-import dbally
-from dbally.collection import Collection
-from dbally.collection.exceptions import NoViewFoundError
-from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, UnsupportedQueryError
-from dbally.llms.litellm import LiteLLM
-from dbally.view_selection.prompt import VIEW_SELECTION_TEMPLATE
-
-
-async def _run_dbally_for_single_example(example: BIRDExample, collection: Collection) -> Text2SQLResult:
-    try:
-        result = await collection.ask(example.question, dry_run=True)
-        sql = result.metadata["sql"]
-    except UnsupportedQueryError:
-        sql = "UnsupportedQueryError"
-    except NoViewFoundError:
-        sql = "NoViewFoundError"
-    except Exception:  # pylint: disable=broad-exception-caught
-        sql = "Error"
-
-    return Text2SQLResult(
-        db_id=example.db_id, question=example.question, ground_truth_sql=example.SQL, predicted_sql=sql
-    )
-
-
-async def run_dbally_for_dataset(dataset: BIRDDataset, collection: Collection) -> List[Text2SQLResult]:
-    """
-    Transforms questions into SQL queries using an IQL approach.
-
-    Args:
-        dataset: The dataset containing questions to be transformed into SQL queries.
-        collection: Container for a set of views used by db-ally.
-
-    Returns:
-        A list of Text2SQLResult objects representing the predictions.
-    """
-
-    results: List[Text2SQLResult] = []
-
-    for group in batch(dataset, 5):
-        current_results = await asyncio.gather(
-            *[_run_dbally_for_single_example(example, collection) for example in group]
-        )
-        results = [*current_results, *results]
-
-    return results
-
-
-async def evaluate(cfg: DictConfig) -> Any:
-    """
-    Runs db-ally evaluation for a single dataset defined in hydra config.
-
-    Args:
-        cfg: hydra config, loads automatically from path passed on to the decorator
-    """
-
-    output_dir = PATH_EXPERIMENTS / cfg.output_path / get_datetime_str()
-    output_dir.mkdir(exist_ok=True, parents=True)
-    cfg = instantiate(cfg)
-    benchmark_cfg = BenchmarkConfig()
-
-    engine = create_engine(benchmark_cfg.pg_connection_string + f"/{cfg.db_name}")
-
-    llm = LiteLLM(
-        model_name="gpt-4",
-        api_key=benchmark_cfg.openai_api_key,
-    )
-
-    db = dbally.create_collection(cfg.db_name, llm)
-
-    for view_name in cfg.view_names:
-        view = VIEW_REGISTRY[ViewName(view_name)]
-        db.add(view, partial(view, engine))
-
-    run = None
-    if cfg.neptune.log:
-        run = neptune.init_run(
-            project=benchmark_cfg.neptune_project,
-            api_token=benchmark_cfg.neptune_api_token,
-        )
-        run["config"] = stringify_unsupported(cfg)
-        tags = list(cfg.neptune.get("tags", [])) + [EvaluationType.END2END.value, cfg.model_name, cfg.db_name]
-        run["sys/tags"].add(tags)
-
-        if "CI_MERGE_REQUEST_IID" in os.environ:
-            run = set_up_gitlab_metadata(run)
-
-    metrics_file_name, results_file_name = "metrics.json", "eval_results.json"
-
-    logger.info(f"Running db-ally predictions for dataset {cfg.dataset_path}")
-    evaluation_dataset = BIRDDataset.from_json_file(
-        Path(cfg.dataset_path), difficulty_levels=cfg.get("difficulty_levels")
-    )
-    dbally_results = await run_dbally_for_dataset(dataset=evaluation_dataset, collection=db)
-
-    with open(output_dir / results_file_name, "w", encoding="utf-8") as outfile:
-        json.dump([result.model_dump() for result in dbally_results], outfile, indent=4)
-
-    logger.info("Calculating metrics")
-    metrics = calculate_dataset_metrics(dbally_results, engine)
-
-    with open(output_dir / metrics_file_name, "w", encoding="utf-8") as outfile:
-        json.dump(metrics, outfile, indent=4)
-
-    logger.info(f"db-ally predictions saved under directory: {output_dir}")
-
-    if run:
-        run["config/iql_prompt_template"] = stringify_unsupported(IQL_GENERATION_TEMPLATE.chat)
-        run["config/view_selection_prompt_template"] = stringify_unsupported(VIEW_SELECTION_TEMPLATE.chat)
-        run["config/iql_prompt_template"] = stringify_unsupported(IQL_GENERATION_TEMPLATE)
-        run[f"evaluation/{metrics_file_name}"].upload((output_dir / metrics_file_name).as_posix())
-        run[f"evaluation/{results_file_name}"].upload((output_dir / results_file_name).as_posix())
-        run["evaluation/metrics"] = stringify_unsupported(metrics)
-        logger.info(f"Evaluation results logged to neptune at {run.get_url()}")
-
-
-@hydra.main(version_base=None, config_path="experiment_config", config_name="evaluate_e2e_config")
-def main(cfg: DictConfig):
-    """
-    Runs db-ally evaluation for a single dataset defined in hydra config.
-    The following metrics are calculated during evaluation: exact match, valid SQL,
-    execution accuracy and valid efficiency score.
-
-    Args:
-        cfg: hydra config, loads automatically from path passed on to the decorator.
-    """
-
-    asyncio.run(evaluate(cfg))
-
-
-if __name__ == "__main__":
-    main()  # pylint: disable=E1120