Skip to content

Commit

Permalink
allowing custom dataset urls
Browse files Browse the repository at this point in the history
Signed-off-by: Giridhar Ganapavarapu <[email protected]>
  • Loading branch information
gganapavarapu committed Jul 31, 2023
1 parent 24548f1 commit 34a2906
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 9 deletions.
11 changes: 9 additions & 2 deletions aix360/datasets/climate_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ class ClimateDataset:
"""

def __init__(self):
def __init__(
self,
url: str = None,
):
self.data_folder = os.path.realpath(
os.path.join(
os.path.dirname(os.path.realpath(__file__)), "../data", "climate_data"
Expand All @@ -32,7 +35,11 @@ def __init__(self):
self.data_file = os.path.realpath(
os.path.join(self.data_folder, "jena_climate_2009_2016.csv")
)
climate_data_url = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/jena_climate_2009_2016.csv.zip"
climate_data_url = (
url
if url is not None
else "https://storage.googleapis.com/tensorflow/tf-keras-datasets/jena_climate_2009_2016.csv.zip"
)

self.input_length = 500
# download data
Expand Down
11 changes: 9 additions & 2 deletions aix360/datasets/diabetes_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ class DiabetesDataset:
"""

def __init__(self):
def __init__(
self,
url: str = None,
):
self.data_folder = os.path.realpath(
os.path.join(
os.path.dirname(os.path.realpath(__file__)), "../data", "diabetes_data"
Expand All @@ -27,7 +30,11 @@ def __init__(self):
self.data_file = os.path.realpath(
os.path.join(self.data_folder, "diabetes.csv")
)
diabetes_url = "https://www4.stat.ncsu.edu/~boos/var.select/diabetes.tab.txt"
diabetes_url = (
url
if url is not None
else "https://www4.stat.ncsu.edu/~boos/var.select/diabetes.tab.txt"
)

if not os.path.exists(self.data_file):
response = requests.get(diabetes_url)
Expand Down
10 changes: 7 additions & 3 deletions aix360/datasets/ford_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class FordDataset:
"""

def __init__(self, category_a: bool = True):
def __init__(self, url: str = None, category_a: bool = True):
self.data_folder = os.path.realpath(
os.path.join(
os.path.dirname(os.path.realpath(__file__)), "../data", "ford_data"
Expand All @@ -41,8 +41,12 @@ def __init__(self, category_a: bool = True):
)

self.category = "A" if category_a else "B"
ford_data_url = "http://timeseriesclassification.com/ClassificationDownloads/Ford{}.zip".format(
self.category
ford_data_url = (
url
if url is not None
else "https://timeseriesclassification.com/aeon-toolkit/Ford{}.zip".format(
self.category
)
)

self.input_length = 500
Expand Down
11 changes: 9 additions & 2 deletions aix360/datasets/sunspots_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ class SunspotDataset:
"""

def __init__(self):
def __init__(
self,
url: str = None,
):
self.data_folder = os.path.realpath(
os.path.join(
os.path.dirname(os.path.realpath(__file__)), "../data", "sunspots_data"
Expand All @@ -32,7 +35,11 @@ def __init__(self):
self.data_file = os.path.realpath(
os.path.join(self.data_folder, "sunspots.csv")
)
sunspots_url = "https://raw.githubusercontent.com/PacktPublishing/Practical-Time-Series-Analysis/master/Data%20Files/monthly-sunspot-number-zurich-17.csv"
sunspots_url = (
url
if url is not None
else "https://raw.githubusercontent.com/PacktPublishing/Practical-Time-Series-Analysis/master/Data%20Files/monthly-sunspot-number-zurich-17.csv"
)

if not os.path.exists(self.data_file):
response = requests.get(sunspots_url)
Expand Down

0 comments on commit 34a2906

Please sign in to comment.