|
11 | 11 | import yaml
|
12 | 12 | from torch.utils.data import Dataset
|
13 | 13 |
|
14 |
| -from datasets.utils import encode_categorical_variables |
15 |
| -from datasets.utils import encode_numerical_variables |
| 14 | +from datasets.utils import encode_conditioning_variables |
16 | 15 |
|
17 | 16 | warnings.filterwarnings("ignore", category=pd.errors.SettingWithCopyWarning)
|
18 | 17 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
@@ -90,18 +89,16 @@ def load_and_preprocess_data(
|
90 | 89 | raise FileNotFoundError(f"Metadata file not found at {metadata_csv_path}")
|
91 | 90 |
|
92 | 91 | metadata = pd.read_csv(metadata_csv_path, usecols=metadata_columns)
|
| 92 | + |
93 | 93 | if "solar" in metadata.columns:
|
94 | 94 | metadata.rename(columns={"solar": "has_solar"}, inplace=True)
|
95 | 95 |
|
96 | 96 | data = self._load_full_data(path, data_columns)
|
97 | 97 | user_flags = self._set_user_flags(metadata, data)
|
98 | 98 | data = self._preprocess_data(data)
|
99 | 99 | data = pd.merge(data, metadata, on="dataid", how="left")
|
100 |
| - data = encode_categorical_variables(data) |
101 |
| - data = encode_numerical_variables( |
102 |
| - data, |
103 |
| - ["total_square_footage", "house_construction_year", "total_amount_of_pv"], |
104 |
| - ) |
| 100 | + data = self._handle_missing_data(data) |
| 101 | + data = encode_conditioning_variables(data) |
105 | 102 | return data, metadata, user_flags
|
106 | 103 |
|
107 | 104 | def _load_full_data(self, path: str, columns: List[str]) -> pd.DataFrame:
|
@@ -193,31 +190,24 @@ def _calculate_and_store_statistics(self, data: pd.DataFrame, column: str) -> Di
|
193 | 190 | """
|
194 | 191 |
|
195 | 192 | def calculate_stats(group):
|
196 |
| - # Concatenate all time series arrays in the group |
197 | 193 | all_values = np.concatenate(group[column].values)
|
198 |
| - |
199 |
| - # Standardization statistics |
200 | 194 | mean = np.mean(all_values)
|
201 | 195 | std = np.std(all_values)
|
202 | 196 |
|
203 |
| - # Perform standardization on all_values |
204 | 197 | standardized = (all_values - mean) / (std + 1e-8)
|
205 | 198 |
|
206 |
| - # Min-Max scaling statistics on standardized data |
207 | 199 | z_min = np.min(standardized)
|
208 | 200 | z_max = np.max(standardized)
|
209 | 201 |
|
210 | 202 | return pd.Series({"mean": mean, "std": std, "z_min": z_min, "z_max": z_max})
|
211 | 203 |
|
212 | 204 | if self.normalization_method == "group":
|
213 |
| - # Group by dataid, month, and weekday |
214 | 205 | grouped_stats = data.groupby(["dataid", "month", "weekday"]).apply(
|
215 | 206 | calculate_stats
|
216 | 207 | )
|
217 | 208 | return grouped_stats.to_dict(orient="index")
|
218 | 209 |
|
219 | 210 | elif self.normalization_method == "date":
|
220 |
| - # Group by month and weekday |
221 | 211 | grouped_stats = data.groupby(["month", "weekday"]).apply(calculate_stats)
|
222 | 212 | return grouped_stats.to_dict(orient="index")
|
223 | 213 |
|
@@ -251,15 +241,12 @@ def normalize_and_scale_row(row):
|
251 | 241 | z_min = stats["z_min"]
|
252 | 242 | z_max = stats["z_max"]
|
253 | 243 |
|
254 |
| - # Standardization |
255 | 244 | values = np.array(row[column])
|
256 | 245 | standardized = (values - mean) / (std + 1e-8)
|
257 | 246 |
|
258 |
| - # Optional Clipping after Standardization |
259 | 247 | if self.threshold:
|
260 | 248 | standardized = np.clip(standardized, *self.threshold)
|
261 | 249 |
|
262 |
| - # Min-Max Scaling on standardized data |
263 | 250 | scaled = (standardized - z_min) / (z_max - z_min + 1e-8)
|
264 | 251 |
|
265 | 252 | return scaled
|
@@ -293,6 +280,13 @@ def _preprocess_solar(self, data: pd.DataFrame) -> pd.DataFrame:
|
293 | 280 |
|
294 | 281 | return solar_data
|
295 | 282 |
|
| 283 | + def _handle_missing_data(self, data: pd.DataFrame) -> pd.DataFrame: |
| 284 | + data["car1"] = data["car1"].fillna("no") |
| 285 | + data["has_solar"] = data["has_solar"].fillna("no") |
| 286 | + |
| 287 | + assert data.isna().sum().sum() == 0, "Missing data remaining!" |
| 288 | + return data |
| 289 | + |
296 | 290 | @staticmethod
|
297 | 291 | def _merge_columns_into_timeseries(df: pd.DataFrame) -> pd.DataFrame:
|
298 | 292 | """
|
@@ -632,7 +626,13 @@ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
632 | 626 | "car1": torch.tensor(sample["car1"], dtype=torch.long),
|
633 | 627 | "city": torch.tensor(sample["city"], dtype=torch.long),
|
634 | 628 | "state": torch.tensor(sample["state"], dtype=torch.long),
|
635 |
| - "has_solar": torch.tensor(sample["has_solar"], dtype=torch.long), # Updated |
| 629 | + "has_solar": torch.tensor(sample["has_solar"], dtype=torch.long), |
| 630 | + "total_square_footage": torch.tensor( |
| 631 | + sample["total_square_footage"], dtype=torch.long |
| 632 | + ), |
| 633 | + "house_construction_year": torch.tensor( |
| 634 | + sample["house_construction_year"], dtype=torch.long |
| 635 | + ), |
636 | 636 | }
|
637 | 637 |
|
638 | 638 | return (torch.tensor(time_series, dtype=torch.float32), conditioning_vars)
|
0 commit comments