Commit: Merge branch 'ds_refactor' of github.com:microsoft/RD-Agent into ds_refactor
Showing 17 changed files with 151 additions and 107 deletions.
@@ -1,133 +1,166 @@
+import os
 from pathlib import Path
+
+import platform
 import pandas as pd
+import fire
+import shutil
 
 from rdagent.app.kaggle.conf import KAGGLE_IMPLEMENT_SETTING
 
 
 class DataHandler:
     """Base DataHandler interface."""
 
     def load(self, path) -> pd.DataFrame:
-        ...
+        raise NotImplementedError
 
     def dump(self, df: pd.DataFrame, path):
-        ...
+        raise NotImplementedError
 
 
-class CSVDataHandler(DataHandler):
+class GenericDataHandler(DataHandler):
+    """
+    A generic data handler that automatically detects file type based on suffix
+    and uses the correct pandas method for load/dump.
+    """
 
     def load(self, path) -> pd.DataFrame:
-        return pd.read_csv(path)
+        path = Path(path)
+        suffix = path.suffix.lower()
+
+        if suffix == ".csv":
+            return pd.read_csv(path)
+        elif suffix == ".pkl":
+            return pd.read_pickle(path)
+        elif suffix == ".parquet":
+            return pd.read_parquet(path)
+        elif suffix in [".h5", ".hdf", ".hdf5"]:
+            # Note: for HDF, you need a 'key' in read_hdf. If you expect a single key,
+            # you might do: pd.read_hdf(path, key='df') or something similar.
+            # Adjust as needed based on your HDF structure.
+            return pd.read_hdf(path, key='data')
+        else:
+            raise ValueError(f"Unsupported file type: {suffix}")
 
     def dump(self, df: pd.DataFrame, path):
-        df.to_csv(path, index=False)
+        path = Path(path)
+        suffix = path.suffix.lower()
+
+        if suffix == ".csv":
+            df.to_csv(path, index=False)
+        elif suffix == ".pkl":
+            df.to_pickle(path)
+        elif suffix == ".parquet":
+            df.to_parquet(path, index=True)
+        elif suffix in [".h5", ".hdf", ".hdf5"]:
+            # Similarly, you need a key for HDF.
+            df.to_hdf(path, key="data", mode="w")
+        else:
+            raise ValueError(f"Unsupported file type: {suffix}")
 
 
 class DataReducer:
     """Base DataReducer interface."""
 
-    def reduce(self, df) -> pd.DataFrame:
-        ...
+    def reduce(self, df: pd.DataFrame) -> pd.DataFrame:
+        raise NotImplementedError
 
 
 class RandDataReducer(DataReducer):
     """
     Example random sampler: ensures at least `min_num` rows
     or at least `min_frac` fraction of the data (whichever is larger).
     """
 
     def __init__(self, min_frac=0.05, min_num=100):
         self.min_frac = min_frac
         self.min_num = min_num
 
-    def reduce(self, df) -> pd.DataFrame:
-        # Calculate the fraction to sample
+    def reduce(self, df: pd.DataFrame) -> pd.DataFrame:
         frac = max(self.min_frac, self.min_num / len(df))
-        # Sample the data
         if frac >= 1:
             return df
         return df.sample(frac=frac, random_state=1)
 
 
-def create_debug_data(
-    competition,
-    original_file_name,
-    dh_cls: type[DataHandler],
-    dr_cls: type[DataReducer],
-    dr_cls_kwargs={},
-    dataset_path=KAGGLE_IMPLEMENT_SETTING.local_data_path,
-):
-    # Define the path to the original data file
-    data_path = Path(dataset_path) / competition / original_file_name
-
-    # Automatically generate full and sampled file names based on the original file name
-    original_suffix = Path(original_file_name).suffix
-    full_file_name = original_file_name.replace(original_suffix, f'.full{original_suffix}')
-    sampled_file_name = original_file_name.replace(original_suffix, f'.sampled{original_suffix}')
-
-    # Define the path to the .full data file
-    full_data_path = data_path.with_name(full_file_name)
-
-    # Check if the .full file exists
-    if not full_data_path.exists():
-        # Initialize handlers
-        data_handler = dh_cls()
-        data_reducer = dr_cls(**dr_cls_kwargs)
-
-        # Load the data file
-        df = data_handler.load(data_path)
-
-        # Reduce the data
-        df_sampled = data_reducer.reduce(df)
-
-        # Save the sampled data to a new data file
-        sampled_data_path = data_path.with_name(sampled_file_name)
-        data_handler.dump(df_sampled, sampled_data_path)
-
-        # Rename the original file with .full
-        data_path.rename(full_data_path)
-
-        # Move the sampled data to replace the original one
-        sampled_data_path.rename(data_path)
-
-
-class PickleDataHandler(DataHandler):
-
-    def load(self, path) -> pd.DataFrame:
-        return pd.read_pickle(path)
-
-    def dump(self, df: pd.DataFrame, path):
-        df.to_pickle(path)
-
-
 class ColumnReducer(DataReducer):
     """
     Example column reducer: keep only the first 5 columns.
     """
 
-    def reduce(self, df) -> pd.DataFrame:
+    def reduce(self, df: pd.DataFrame) -> pd.DataFrame:
         return df.iloc[:, :5]
 
 
-def new_york_city_taxi_fare_prediction_creator():
-    create_debug_data(competition="new-york-city-taxi-fare-prediction",
-                      original_file_name="train.csv",
-                      dh_cls=CSVDataHandler,
-                      dr_cls=RandDataReducer,
-                      dr_cls_kwargs=dict(min_frac=0.05, min_num=100))
+class RowReducer(DataReducer):
+    """
+    Example row reducer: keep only the first 10% rows.
+    """
+
+    def reduce(self, df: pd.DataFrame) -> pd.DataFrame:
+        ten_percent = int(max(len(df) * 0.1, 100))
+        return df.iloc[:ten_percent]
 
-def amc_debug_data_creator():
-    create_debug_data(
-        competition="amc",
-        original_file_name="train_feature_with_label.pkl",
-        dh_cls=PickleDataHandler,
-        dr_cls=ColumnReducer,
-    )
-
-    create_debug_data(
-        competition="amc",
-        original_file_name="test_feature.pkl",
-        dh_cls=PickleDataHandler,
-        dr_cls=ColumnReducer,
-    )
+
+def create_debug_data(
+    competition: str,
+    dr_cls: type[DataReducer] = RandDataReducer,
+    dr_cls_kwargs=None,
+    dataset_path=None,
+    sample_path=None,
+):
+    """
+    Reads the original data file, creates a reduced sample,
+    and renames/moves files for easier debugging.
+    Automatically detects file type (csv, pkl, parquet, hdf, etc.).
+    """
+    if dr_cls_kwargs is None:
+        dr_cls_kwargs = {}
+
+    if dataset_path is None:
+        dataset_path = KAGGLE_IMPLEMENT_SETTING.local_data_path
+
+    if sample_path is None:
+        sample_path = Path(dataset_path) / "sample"
+
+    data_folder = Path(dataset_path) / competition
+    sample_folder = Path(sample_path) / competition
+
+    # Traverse the folder and exclude specific file types
+    included_extensions = {".csv", ".pkl", ".parquet", ".h5", ".hdf", ".hdf5"}
+    files_to_process = [
+        file for file in data_folder.rglob("*")
+        if file.is_file()
+    ]
+
+    for file_path in files_to_process:
+        sampled_file_path = sample_folder / file_path.relative_to(data_folder)
+        if sampled_file_path.exists():
+            continue
+
+        sampled_file_path.parent.mkdir(parents=True, exist_ok=True)
+        if file_path.suffix not in included_extensions:
+            if platform.system() == "Linux":
+                os.symlink(file_path, sampled_file_path)
+            if platform.system() == "Windows":
+                os.link(file_path, sampled_file_path)
+            continue
+
+        # Initialize the generic data handler
+        data_handler = GenericDataHandler()
+
+        # Initialize the data reducer (e.g., RandDataReducer or ColumnReducer)
+        data_reducer = dr_cls(**dr_cls_kwargs)
+
+        # Load the original data
+        df = data_handler.load(file_path)
+
+        # Create a sampled subset
+        df_sampled = data_reducer.reduce(df)
+
+        # Dump the sampled data
+        data_handler.dump(df_sampled, sampled_file_path)
 
 
-# competition to data handler & Reducer mapping
-# find a place to store reduced data.
-# - <local_data_path>, <local_data_path>.debug
-
-import fire
 if __name__ == "__main__":
-    # fire.Fire(create_debug_data)
-    fire.Fire(amc_debug_data_creator)
+    fire.Fire(create_debug_data)
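Net effect of the hunk above: the per-format handlers (CSVDataHandler, PickleDataHandler) collapse into one suffix-dispatching GenericDataHandler, the per-competition creator functions are removed, and create_debug_data now walks a competition folder and mirrors it into a sample/ tree, with fire exposing it as a command-line entry point. Below is a minimal usage sketch, not part of the commit; it assumes the module is importable as data (the commit view does not show the real file name) and uses placeholder paths.

# A minimal usage sketch, not part of the commit. Assumes this module is
# saved as data.py on sys.path and that /tmp/datasets/my-competition/
# contains the raw files; both names are placeholders.
from pathlib import Path

import pandas as pd

from data import GenericDataHandler, RandDataReducer, create_debug_data

# Suffix-based dispatch: one handler round-trips any supported format.
handler = GenericDataHandler()
df = pd.DataFrame({"a": range(1000), "b": range(1000)})
handler.dump(df, Path("/tmp/example.parquet"))  # routed to df.to_parquet
assert len(handler.load(Path("/tmp/example.parquet"))) == 1000

# Walk every file under /tmp/datasets/my-competition/, write reduced
# copies under /tmp/datasets/sample/my-competition/, and link (rather
# than sample) any file whose suffix is not a supported tabular format.
create_debug_data(
    competition="my-competition",
    dr_cls=RandDataReducer,
    dr_cls_kwargs=dict(min_frac=0.05, min_num=100),
    dataset_path="/tmp/datasets",
)

Since the module now ends with fire.Fire(create_debug_data), the same flow is available from a shell, e.g. python data.py --competition=my-competition (file name again assumed); python-fire maps the function's keyword arguments to flags.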
@@ -1,4 +1,4 @@
-from .scen import DataScienceScen
 from .kaggle import KaggleScen
+from .scen import DataScienceScen
 
 __all__ = ["DataScienceScen", "KaggleScen"]