Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Query and parsing fixes #130

Merged
merged 1 commit into from
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion data/instruct_advanced_postgres.csv
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ ewallet,instructions_cte_window,What is the LUB for each user.,"LUB = Latest Use
LUB = Latest User Balance, which is the most recent balance for each user
To determine user notification preferences, use a join between the users and user_setting_snapshot tables in a CTE, focusing on selecting the most recent snapshot for each user.
For analyzing coupon usage, start with a join between the coupons and wallet_transactions_daily tables in a CTE, apply filtering as needed, and then perform aggregation for the total discount amount"
ewallet,instructions_cte_window,"What is the MRR for each merchant? Return the merchant name, category, revenue amount, and revenue rank.","MRR = Merchant Revenue Rank, which ranks merchants based on their total successful received transaction amounts. Filter receiver_type=1 in consumer_div.wallet_transactions_daily for merchants. Merchant with rank 1 has the highest revenue.","WITH merchant_revenue AS (SELECT {m.mid, m.name}, m.category AS merchant_category, SUM(w.amount) AS total_revenue FROM consumer_div.merchants m INNER JOIN consumer_div.wallet_transactions_daily w ON m.mid = w.receiver_id AND w.receiver_type = 1 WHERE w.status = 'success' GROUP BY {m.mid, m.name}, m.category) SELECT *, RANK() OVER (ORDER BY total_revenue DESC) AS mrr FROM merchant_revenue","To get user notification preferences, join the users and user_setting_snapshot tables in a CTE, then select the latest snapshot for each user
ewallet,instructions_cte_window,"What is the MRR for each merchant? Return the merchant name, category, revenue amount, and revenue rank.","MRR = Merchant Revenue Rank, which ranks merchants based on their total successful received transaction amounts. Filter receiver_type=1 in consumer_div.wallet_transactions_daily for merchants. Merchant with rank 1 has the highest revenue.","WITH merchant_revenue AS (SELECT {m.mid, m.name}, m.category AS merchant_category, SUM(w.amount) AS total_revenue FROM consumer_div.merchants m INNER JOIN consumer_div.wallet_transactions_daily w ON m.mid = w.receiver_id AND w.receiver_type = 1 WHERE w.status = 'success' GROUP BY {}, m.category) SELECT *, RANK() OVER (ORDER BY total_revenue DESC) AS mrr FROM merchant_revenue","To get user notification preferences, join the users and user_setting_snapshot tables in a CTE, then select the latest snapshot for each user
Merchant category should be matched case-insensitively with wildcards, e.g., using LOWER(merchants.category) LIKE '%...%'.
MRR = Merchant Revenue Rank, which ranks merchants based on their total successful received transaction amounts. Filter receiver_type=1 in consumer_div.wallet_transactions_daily for merchants. Merchant with rank 1 has the highest revenue.
To analyze user engagement, join the users and user_sessions tables in a CTE, then aggregate to calculate total session duration per user"
Expand Down
5 changes: 2 additions & 3 deletions eval/api_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def mk_vllm_json(prompt, num_beams):
"use_beam_search": num_beams > 1,
"best_of": num_beams,
# "temperature": 0,
# "stop": [";", "```"],
"stop": [";", "```"],
"max_tokens": 1024,
}

Expand Down Expand Up @@ -62,8 +62,7 @@ def process_row(row, api_url: str, api_type: str, num_beams: int, decimal_points
generated_query = ""
elif "[SQL]" not in row["prompt"]:
generated_query = (
r.json()["text"][0].split("```")[-1].split("```")[0].split(";")[0].strip()
+ ";"
r.json()["text"][0].split("```", 1)[0].split(";")[0].strip() + ";"
)
else:
generated_query = r.json()["text"][0]
Expand Down
29 changes: 29 additions & 0 deletions tests/test_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,35 @@ def test_get_all_minimal_queries():
option5,
option6,
]
query6 = """WITH merchant_revenue AS (SELECT {m.mid,m.name}, m.category AS merchant_category, SUM(w.amount) AS total_revenue
FROM consumer_div.merchants m
INNER JOIN consumer_div.wallet_transactions_daily w ON m.mid = w.receiver_id AND w.receiver_type = 1
WHERE w.status = 'success'
GROUP BY {}, m.category)
SELECT *, RANK() OVER (ORDER BY total_revenue DESC) AS mrr FROM merchant_revenue"""
option7 = """WITH merchant_revenue AS (SELECT m.mid, m.category AS merchant_category, SUM(w.amount) AS total_revenue
FROM consumer_div.merchants m
INNER JOIN consumer_div.wallet_transactions_daily w ON m.mid = w.receiver_id AND w.receiver_type = 1
WHERE w.status = 'success'
GROUP BY m.mid, m.category)
SELECT *, RANK() OVER (ORDER BY total_revenue DESC) AS mrr FROM merchant_revenue"""
option8 = """WITH merchant_revenue AS (SELECT m.name, m.category AS merchant_category, SUM(w.amount) AS total_revenue
FROM consumer_div.merchants m
INNER JOIN consumer_div.wallet_transactions_daily w ON m.mid = w.receiver_id AND w.receiver_type = 1
WHERE w.status = 'success'
GROUP BY m.name, m.category)
SELECT *, RANK() OVER (ORDER BY total_revenue DESC) AS mrr FROM merchant_revenue"""
option9 = """WITH merchant_revenue AS (SELECT m.mid, m.name, m.category AS merchant_category, SUM(w.amount) AS total_revenue
FROM consumer_div.merchants m
INNER JOIN consumer_div.wallet_transactions_daily w ON m.mid = w.receiver_id AND w.receiver_type = 1
WHERE w.status = 'success'
GROUP BY m.mid, m.name, m.category)
SELECT *, RANK() OVER (ORDER BY total_revenue DESC) AS mrr FROM merchant_revenue"""
for expected, result in zip(
get_all_minimal_queries(query6), [option7, option8, option9]
):
assert expected == result
assert get_all_minimal_queries(query6) == [option7, option8, option9]


@mock.patch("pandas.read_sql_query")
Expand Down
Loading