Skip to content

Commit

Permalink
(WIP) Remove extra keyword from IC data loader get func
Browse files Browse the repository at this point in the history
  • Loading branch information
john-b-yang committed Aug 8, 2023
1 parent c02afdc commit 485c962
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 17 deletions.
2 changes: 1 addition & 1 deletion experiments/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def handicap_bash(record: Dict) -> str:
def handicap_sql(record: Dict) -> str:
# Custom handicap for spider dev dataset
handicap = "MySQL tables, with their properties\n"
tables = record["extra"]["db_tables"]
tables = record["db_tables"]
for name, columns in tables.items():
handicap += f'- {name}: {str(columns)}\n'
return handicap
Expand Down
10 changes: 5 additions & 5 deletions intercode/envs/python/python_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,16 +98,16 @@ def get_reward_mbpp(self):

# Run tests against submitted function
results_pred = {}
self.conn.root.execute(self.record["extra"]["test_setup_code"])
for test in self.record["extra"]["tests"]:
self.conn.root.execute(self.record["test_setup_code"])
for test in self.record["tests"]:
results_pred[test] = self.conn.root.execute(test)

# Load gold + run tests
results_gold = {}
self.conn.root.execute(RESET_KEYWORD)
self.conn.root.execute(self.record["extra"]["test_setup_code"])
self.conn.root.execute(self.record["test_setup_code"])
self.conn.root.execute(self.gold)
for test in self.record["extra"]["tests"]:
for test in self.record["tests"]:
results_gold[test] = self.conn.root.execute(test)

self.info["submitted_function"] = func_name
Expand All @@ -125,4 +125,4 @@ def get_reward_mbpp(self):

self.logger.info(f"Info: {self.info}")
self.logger.info(f"Reward: {self.reward}")
return 0.0, self.info
return self.reward, self.info
14 changes: 4 additions & 10 deletions intercode/utils/data_loader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import os
import math, os
import numpy as np
import pandas as pd

Expand All @@ -15,17 +15,11 @@ def get(self, index: int = None) -> Dict:
"""Get query, gold pair (+ extra data) at index (or random index if None)"""
if index is None:
index = np.random.randint(0, len(self.data))
record = self.data.iloc[index].to_dict()
record = {
"query": self.data.iloc[index]["query"],
"gold": self.data.iloc[index]["gold"],
key: value for key, value in record.items()
if not (isinstance(value, float) and math.isnan(value))
}
if len(self.data.iloc[index]) > 2:
columns = self.data.columns.tolist()
extras = {}
for i in range(len(columns)):
if columns[i] not in ["query", "gold"]:
extras[columns[i]] = self.data.iloc[index,i]
record["extra"] = extras
return record

def _load_data(self):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_get():
record = data_loader.get(0)
assert(record["query"] == "Find the first name of students who have both cat and dog pets .")
assert(record["gold"].startswith("select t1.fname from Student as t1 join Has_Pet as t2 on t1.stuid"))
assert(record["extra"]["db"] == "pets_1")
assert(record["db"] == "pets_1")

data_path = "./data/test/bash_queries.json"
data_loader = IntercodeDataLoader(data_path)
Expand Down

0 comments on commit 485c962

Please sign in to comment.