Skip to content

Commit d56861e

Browse files
authored
Merge pull request #888 from sfu-db/load_db
feat(eda): add get_db_names
2 parents ec857e7 + a7bf820 commit d56861e

File tree

3 files changed

+40
-5
lines changed

3 files changed

+40
-5
lines changed

dataprep/datasets/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""This module implements load dataset related functions"""
22

33
from ._base import load_dataset, _load_dataset_as_dask, load_db
4-
from ._base import get_dataset_names
4+
from ._base import get_dataset_names, get_db_names
55

6-
__all__ = ["load_dataset", "get_dataset_names", "_load_dataset_as_dask", "load_db"]
6+
__all__ = ["load_dataset", "get_dataset_names", "_load_dataset_as_dask", "load_db", "get_db_names"]

dataprep/datasets/_base.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,26 @@ def get_dataset_names() -> List[str]:
3030
return datasets
3131

3232

33+
def get_db_names() -> List[str]:
    """
    Get all available database names. These are the names of the '.db'
    files in the 'database' folder next to this module, with the '.db'
    suffix removed.

    Returns
    -------
    db_names: list
        A sorted list of all available database names.

    Raises
    ------
    FileNotFoundError
        If the 'database' folder does not exist.
    """
    # os.path.join (rather than string concatenation) keeps the path relative
    # when dirname(__file__) is empty instead of turning it into '/database'
    db_dir = os.path.join(dirname(__file__), "database")

    # keep only the '.db' files and strip the suffix to get the names;
    # os.listdir returns entries in arbitrary order, so sort for a stable result
    db_names = sorted(
        os.path.splitext(entry)[0]
        for entry in os.listdir(db_dir)
        if entry.endswith(".db")
    )

    return db_names
51+
52+
3353
def _get_dataset_path(name: str) -> str:
3454
"""
3555
Given a dataset name, output the file path.
@@ -80,20 +100,24 @@ def load_dataset(name: str) -> pd.DataFrame:
80100
return df
81101

82102

83-
def load_db(file_name: str) -> Engine:
103+
def load_db(name: str) -> Engine:
84104
"""
85105
Load a database file
86106
87107
Parameters
88108
----------
89-
file_name: str
109+
name: str
90110
Name of the database file
91111
92112
Returns
93113
-------
94114
db_url : str
95115
SQLite url
96116
"""
117+
file_name = name.lower()
118+
if not file_name.endswith(".db"):
119+
file_name += ".db"
120+
97121
db_file_path = str(
98122
os.path.join(os.path.dirname(os.path.abspath(__file__)), "database", file_name)
99123
)
Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,27 @@
11
"""
22
module for testing the functions inside datasets
33
"""
4-
from ...datasets import get_dataset_names, load_dataset
4+
from ...datasets import get_dataset_names, get_db_names, load_dataset, load_db
55

66

77
def test_get_dataset_names() -> None:
    """get_dataset_names must report at least one dataset name."""
    available = get_dataset_names()
    assert len(available) > 0
1010

1111

12+
def test_get_db_names() -> None:
    """get_db_names must report at least one database name."""
    available = get_db_names()
    assert len(available) > 0
15+
16+
1217
def test_load_dataset() -> None:
    """Every name from get_dataset_names must load into a non-empty result."""
    for dataset_name in get_dataset_names():
        frame = load_dataset(dataset_name)
        assert len(frame) > 0
22+
23+
24+
def test_load_db() -> None:
    """Every name from get_db_names must load without error and yield an object."""
    db_names = get_db_names()
    for name in db_names:
        db = load_db(name)
        # the original only checked that no exception was raised; also make
        # sure load_db actually handed something back
        assert db is not None

0 commit comments

Comments
 (0)