diff --git a/cities/utils/data_grabber.py b/cities/utils/data_grabber.py index d36c01cd..740275a9 100644 --- a/cities/utils/data_grabber.py +++ b/cities/utils/data_grabber.py @@ -76,7 +76,6 @@ def __init__(self): sys.path.insert(0, self.data_path) - def list_available_features(level="county"): root = find_repo_root() @@ -85,9 +84,11 @@ def list_available_features(level="county"): elif level == "msa": folder_path = f"{root}/data/MSA_level" elif level == "census_tract": - folder_path = f"{root}/data/Census_tract_level" + folder_path = f"{root}/data/Census_tract_level" else: - raise ValueError("Invalid level. Please choose 'county', 'census_tract' or 'msa'.") + raise ValueError( + "Invalid level. Please choose 'county', 'census_tract' or 'msa'." + ) file_names = [f for f in os.listdir(folder_path) if f != ".gitkeep"] processed_file_names = [] @@ -113,11 +114,13 @@ def list_tensed_features(level="county"): all_features = list_available_features(level="msa") elif level == "census_tract": - data = CTDataGrabberCSV() # TODO: Change to CTDataGrabber() in the future - all_features = list_available_features(level="Census_tract_level") + data = CTDataGrabberCSV() # TODO: Change to CTDataGrabber() in the future + all_features = list_available_features(level="census_tract") else: - raise ValueError("Invalid level. Please choose 'county', 'census_tract' or 'msa'.") + raise ValueError( + "Invalid level. Please choose 'county', 'census_tract' or 'msa'." 
+ ) data.get_features_wide(all_features) diff --git a/tests/test_data_grabber.py b/tests/test_data_grabber.py index 5a8959d8..305765fb 100644 --- a/tests/test_data_grabber.py +++ b/tests/test_data_grabber.py @@ -5,7 +5,7 @@ from cities.utils.data_grabber import ( DataGrabber, MSADataGrabber, - CTDataGrabberCSV, # TODO: Change to CTDataGrabber() in the future + CTDataGrabberCSV, # TODO: Change to CTDataGrabber() in the future list_available_features, list_interventions, list_outcomes, @@ -16,6 +16,7 @@ features_msa = list_available_features("msa") features_ct = list_available_features("census_tract") + def test_non_emptiness_DataGrabber(): assert features is not None @@ -41,9 +42,9 @@ def test_non_emptiness_DataGrabber(): ) -def test_non_emptiness_MSADataGrabber(): +def test_non_emptiness_MSADataGrabber(): os.chdir(os.path.dirname(os.getcwd())) - data_msa = MSADataGrabber() + data_msa = MSADataGrabber() data_msa.get_features_wide(features_msa) data_msa.get_features_std_wide(features_msa) @@ -57,10 +58,9 @@ def test_non_emptiness_MSADataGrabber(): assert data_msa.std_long[feature].shape[1] == 4 - def test_non_emptiness_CTDataGrabber(): os.chdir(os.path.dirname(os.getcwd())) - data_ct = CTDataGrabberCSV() # TODO: Change to CTDataGrabber() in the future + data_ct = CTDataGrabberCSV() # TODO: Change to CTDataGrabber() in the future data_ct.get_features_wide(features_ct) data_ct.get_features_std_wide(features_ct) @@ -74,7 +74,7 @@ def test_non_emptiness_CTDataGrabber(): assert data_ct.std_long[feature].shape[1] == 4 -def general_data_format_testing(data, features, level = "county_msa"): +def general_data_format_testing(data, features, level="county_msa"): assert features is not None data.get_features_wide(features) @@ -95,7 +95,6 @@ def general_data_format_testing(data, features, level = "county_msa"): for feature in features: if level == "county_msa": - namesFipsError = "FIPS codes and GeoNames don't match!" 
assert ( data.wide[feature]["GeoFIPS"].nunique() @@ -115,9 +114,7 @@ def general_data_format_testing(data, features, level = "county_msa"): ), namesFipsError elif level == "census_tract": - - pass # TODO: check whether the county number is correct as indicated by the CT number - + pass # TODO: check whether the county number is correct as indicated by the CT number for feature in features: for column in data.wide[feature].columns[2:]: @@ -173,11 +170,10 @@ def test_MSADataGrabber_data_types(): general_data_format_testing(data_msa, features_msa) - def test_CTDataGrabber_data_types(): - data_ct = CTDataGrabberCSV() # TODO: Change to CTDataGrabber() in the future + data_ct = CTDataGrabberCSV() # TODO: Change to CTDataGrabber() in the future - general_data_format_testing(data_ct, features_ct, level= "census_tract") + general_data_format_testing(data_ct, features_ct, level="census_tract") def test_feature_listing_runtime(): @@ -214,8 +210,7 @@ def test_GeoFIPS_ma_column_values(): assert all(value > 9999 and str(value)[-1] == "0" for value in column_values) - -data_ct = CTDataGrabberCSV() # TODO: Change to CTDataGrabber() in the future +data_ct = CTDataGrabberCSV() # TODO: Change to CTDataGrabber() in the future data_ct.get_features_wide(features_ct) @@ -225,5 +220,3 @@ def test_GeoFIPS_ct_column_values(): column_values = data_ct.wide[feature]["GeoFIPS"] assert all(value > 999999999 for value in column_values) - -