Skip to content

Commit

Permalink
refactoring data grabber
Browse files Browse the repository at this point in the history
  • Loading branch information
Niklewa committed Jun 28, 2024
1 parent beddecf commit af986a9
Show file tree
Hide file tree
Showing 2 changed files with 347 additions and 27 deletions.
80 changes: 65 additions & 15 deletions cities/utils/data_grabber.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ def find_repo_root() -> Path:


def check_if_tensed(df: "pd.DataFrame") -> bool:
    """Report whether ``df`` has a column named after one of the reference years.

    Args:
        df: Table whose column labels are inspected. Labels are compared
            against the *string* forms of the years, so integer column
            labels such as ``2015`` do not match.

    Returns:
        ``True`` if at least one of ``"2015"``, ``"2018"``, ``"2019"`` or
        ``"2020"`` is a column label of ``df``, otherwise ``False``.
    """
    years_to_check = ("2015", "2018", "2019", "2020")
    # Membership is tested against every column label, not only those after
    # the first two columns.
    return any(year in df.columns for year in years_to_check)


Expand Down Expand Up @@ -67,16 +67,55 @@ def __init__(self):
self.data_path = os.path.join(self.repo_root, "data/MSA_level")
sys.path.insert(0, self.data_path)


class CTDataGrabberCSV(DataGrabberCSV):
    """CSV-backed data grabber for census-tract-level feature tables.

    Files are read from ``<repo_root>/data/Census_tract_level`` and are
    expected to be named
    ``{feature}_{pre2020|post2020}_CT_{table_suffix}.csv``.

    Args:
        level_DG: Which set of census-tract files to read, either
            ``"pre_2020"`` (default) or ``"post_2020"``.

    NOTE(review): the ``wide``, ``std_wide``, ``long`` and ``std_long``
    containers written to below are presumably created by
    ``DataGrabberCSV.__init__`` -- confirm against the base class.
    """

    # Maps the public ``level_DG`` value to the tag embedded in file names.
    _LEVEL_TAGS = {"pre_2020": "pre2020", "post_2020": "post2020"}
    # Valid table suffixes; each one is also the name of the attribute that
    # stores the loaded tables of that kind.
    _TABLE_SUFFIXES = ("wide", "std_wide", "long", "std_long")

    def __init__(self, level_DG: str = "pre_2020"):
        super().__init__()
        self.repo_root = find_repo_root()
        self.data_path = os.path.join(self.repo_root, "data/Census_tract_level")
        self.level_DG = level_DG
        sys.path.insert(0, self.data_path)

    def _get_features(self, features: List[str], table_suffix: str) -> None:
        """Load the CSV table of each feature into the matching container.

        Args:
            features: Feature names to load.
            table_suffix: One of ``"wide"``, ``"std_wide"``, ``"long"`` or
                ``"std_long"``; selects both the file and the attribute the
                resulting DataFrame is stored in (keyed by feature name).

        Raises:
            ValueError: If ``self.level_DG`` or ``table_suffix`` is not one
                of the supported values.

        A feature whose file does not exist is reported on stdout and
        skipped; the remaining features are still loaded.
        """
        # Validate once, up front, so that a bad argument is reported as
        # such instead of being masked by an empty feature list or by a
        # misleading "File not found" message.
        if self.level_DG not in self._LEVEL_TAGS:
            raise ValueError(
                "Invalid level_DG. Please choose 'pre_2020' or 'post_2020'."
            )
        if table_suffix not in self._TABLE_SUFFIXES:
            raise ValueError(
                "Invalid table suffix. Please choose 'wide', 'std_wide', 'long', or 'std_long'."
            )

        level_tag = self._LEVEL_TAGS[self.level_DG]
        for feature in features:
            file_path = os.path.join(
                self.data_path, f"{feature}_{level_tag}_CT_{table_suffix}.csv"
            )
            if os.path.exists(file_path):
                # The container attribute shares its name with the suffix.
                getattr(self, table_suffix)[feature] = pd.read_csv(file_path)
            else:
                print(f"File not found: {file_path}")

    def get_features_wide(self, features: List[str]) -> None:
        """Load the ``wide`` tables of ``features``."""
        self._get_features(features, "wide")

    def get_features_std_wide(self, features: List[str]) -> None:
        """Load the ``std_wide`` tables of ``features``."""
        self._get_features(features, "std_wide")

    def get_features_long(self, features: List[str]) -> None:
        """Load the ``long`` tables of ``features``."""
        self._get_features(features, "long")

    def get_features_std_long(self, features: List[str]) -> None:
        """Load the ``std_long`` tables of ``features``."""
        self._get_features(features, "std_long")



def list_available_features(level="county", level_DG="pre_2020"):
root = find_repo_root()

if level == "county":
Expand All @@ -86,25 +125,36 @@ def list_available_features(level="county"):
elif level == "census_tract":
folder_path = f"{root}/data/Census_tract_level"
else:
raise ValueError(
"Invalid level. Please choose 'county', 'census_tract' or 'msa'."
)
raise ValueError("Invalid level. Please choose 'county', 'census_tract' or 'msa'.")

file_names = [f for f in os.listdir(folder_path) if f != ".gitkeep"]
processed_file_names = []

for file_name in file_names:
# Use regular expressions to find the patterns and split accordingly
matches = re.split(r"_wide|_long|_std", file_name)
if level == "census_tract":
if level_DG == "pre_2020" and "pre2020" in file_name:
matches = re.split(r"_wide|_long|_std|_pre2020", file_name)
elif level_DG == "post_2020" and "pre2020" not in file_name:
matches = re.split(r"_wide|_long|_std|_post2020", file_name)
else:
continue
else:
matches = re.split(r"_wide|_long|_std", file_name)

if matches:
processed_file_names.append(matches[0])
base_name = matches[0]
processed_file_names.append(base_name)



# Remove any remaining suffixes from the base names
feature_names = list(set(processed_file_names))
feature_names = [re.sub(r'(_pre2020|_post2020)$', '', name) for name in feature_names]

return sorted(feature_names)


def list_tensed_features(level="county"):
def list_tensed_features(level="county", level_DG="pre_2020"):
if level == "county":
data = DataGrabber()
all_features = list_available_features(level="county")
Expand All @@ -114,8 +164,8 @@ def list_tensed_features(level="county"):
all_features = list_available_features(level="msa")

elif level == "census_tract":
data = CTDataGrabberCSV() # TODO: Change to CTDataGrabber() in the future
all_features = list_available_features(level="Census_tract_level")
data = CTDataGrabberCSV(level_DG=level_DG) # TODO: Change to CTDataGrabber() in the future
all_features = list_available_features(level="census_tract", level_DG=level_DG)

else:
raise ValueError(
Expand Down
Loading

0 comments on commit af986a9

Please sign in to comment.