Implemented wrapper classes
Michael Fuest committed Oct 9, 2024
1 parent b96323b commit 910fef2
Showing 3 changed files with 81 additions and 70 deletions.
2 changes: 1 addition & 1 deletion datasets/timeseries_dataset.py
@@ -33,7 +33,7 @@ def __len__(self):
     def __getitem__(self, idx):
         sample = self.data.iloc[idx]
 
-        time_series = sample[self.time_series_column]
+        time_series = sample[self.time_series_column_name]
         time_series = torch.tensor(time_series, dtype=torch.float32)
 
         conditioning_vars_dict = {}
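The rename above means __getitem__ now reads the raw series from the column named by time_series_column_name, matching the keyword that _prepare_dataset passes in generator/data_generator.py below. A minimal sketch of the resulting API, assuming a toy DataFrame whose timeseries column holds fixed-length arrays (the shapes and category counts are illustrative, not from the repo):

import numpy as np
import pandas as pd

from datasets.timeseries_dataset import TimeSeriesDataset

# Toy data: 8 rows, each holding a (96, 1) series plus two categorical
# conditioning columns. Real usage feeds PecanStreet data instead.
df = pd.DataFrame(
    {
        "timeseries": [np.random.rand(96, 1) for _ in range(8)],
        "month": np.random.randint(0, 12, size=8),
        "weekday": np.random.randint(0, 7, size=8),
    }
)

dataset = TimeSeriesDataset(
    dataframe=df,
    time_series_column_name="timeseries",  # renamed keyword from this commit
    # Assumption: values give the number of categories per variable, as in
    # the notebook's conditioning_dict further down.
    conditioning_vars={"month": 12, "weekday": 7},
)

# __getitem__ returns a float32 tensor plus a dict of conditioning values,
# which is exactly how DataGenerator.fit unpacks it: dataset[0] -> (ts, cond).
ts, cond = dataset[0]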
74 changes: 37 additions & 37 deletions generator/data_generator.py
@@ -1,11 +1,12 @@
+from typing import Any
 from typing import Dict
 
 import pandas as pd
 import torch
 
 from datasets.timeseries_dataset import TimeSeriesDataset
-from generator.diffcharge import DDPM
-from generator.diffusion_ts import Diffusion_TS
+from generator.diffcharge.diffusion import DDPM
+from generator.diffusion_ts.gaussian_diffusion import Diffusion_TS
 from generator.gan.acgan import ACGAN
 from generator.options import Options
 
@@ -49,41 +50,40 @@ def _initialize_model(self):
         else:
             raise ValueError(f"Model {self.model_name} not recognized.")
 
-    def fit(self, X):
-        """
-        Train the model on the given dataset.
-
-        Args:
-            df (Any): Input data. Should be a compatible dataset object or pandas DataFrame.
-        """
-        if isinstance(X, pd.DataFrame):
-            dataset = self._prepare_dataset(X)
+    def fit(self, X: Any, timeseries_colname: Any):
+        """
+        Train the model on the given dataset.
+        Args:
+            X: Input data. Should be a compatible dataset object or pandas DataFrame.
+        """
+        if isinstance(X, pd.DataFrame):
+            dataset = self._prepare_dataset(X, timeseries_colname)
         else:
             dataset = X
 
         sample_timeseries, sample_cond_vars = dataset[0]
         expected_seq_len = self.model.opt.seq_len
         assert (
             sample_timeseries.shape[0] == expected_seq_len
         ), f"Expected timeseries length {expected_seq_len}, but got {sample_timeseries.shape[0]}"
 
         if (
             hasattr(self.model_params, "conditioning_vars")
             and self.model_params.conditioning_vars
         ):
             for var in self.model_params.conditioning_vars:
                 assert (
                     var in sample_cond_vars.keys()
                 ), f"Conditioning variable '{var}' specified in model_params.conditioning_vars not found in dataset"
 
         expected_input_dim = self.model.opt.input_dim
         assert sample_timeseries.shape == (
             expected_seq_len,
             expected_input_dim,
         ), f"Expected timeseries shape ({expected_seq_len}, {expected_input_dim}), but got {sample_timeseries.shape}"
 
         self.model.train_model(dataset)
 
     def generate(self, conditioning_vars):
         """
Expand Down Expand Up @@ -147,8 +147,8 @@ def _prepare_dataset(
         elif isinstance(df, pd.DataFrame):
             dataset = TimeSeriesDataset(
                 dataframe=df,
+                time_series_column_name=timeseries_colname,
                 conditioning_vars=conditioning_vars,
-                time_series_column=timeseries_colname,
             )
             return dataset
         else:
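The new fit signature threads the time-series column name from the public API down to dataset construction in _prepare_dataset. A hedged sketch of the two call patterns it supports, reusing the toy df and dataset from the sketch above (the "acgan" model name and model_params=None mirror the notebook cell below):

from generator.data_generator import DataGenerator

generator = DataGenerator("acgan", model_params=None)

# Pattern 1: hand fit() a DataFrame; it builds the TimeSeriesDataset itself
# via self._prepare_dataset(X, timeseries_colname).
generator.fit(df, "timeseries")

# Pattern 2: hand fit() a ready-made dataset object; the column name is
# still passed, as in the notebook call generator.fit(dataset, "timeseries").
generator.fit(dataset, "timeseries")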
75 changes: 43 additions & 32 deletions playground.ipynb
@@ -101,54 +101,65 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 14,
+   "execution_count": 3,
    "metadata": {},
    "outputs": [],
    "source": [
     "from datasets.pecanstreet import PecanStreetDataManager\n",
     "\n",
-    "pv = PecanStreetDataManager(geography=\"austin\", include_generation=False).create_all_pv_user_dataset()\n",
-    "non_pv = PecanStreetDataManager(geography=\"austin\",include_generation=False).create_non_pv_user_dataset()"
+    "pv = PecanStreetDataManager(geography=\"austin\", include_generation=True).create_all_pv_user_dataset()\n",
+    "non_pv = PecanStreetDataManager(geography=\"austin\",include_generation=True).create_non_pv_user_dataset()"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 15,
+   "execution_count": 14,
    "metadata": {},
    "outputs": [
     {
-     "data": {
-      "text/plain": [
-       "19"
-      ]
-     },
-     "execution_count": 15,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "pv.data.dataid.nunique()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 16,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "6"
-      ]
-     },
-     "execution_count": 16,
-     "metadata": {},
-     "output_type": "execute_result"
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Epoch 1: 62%|██████▏ | 133/213 [00:06<00:03, 20.05it/s]\n"
+     ]
+    },
+    {
+     "ename": "KeyboardInterrupt",
+     "evalue": "",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
+      "Cell \u001b[0;32mIn[14], line 25\u001b[0m\n\u001b[1;32m 9\u001b[0m conditioning_dict \u001b[39m=\u001b[39m {\n\u001b[1;32m 10\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mmonth\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39m12\u001b[39m,\n\u001b[1;32m 11\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mweekday\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39m7\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mtotal_square_footage\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39m5\u001b[39m,\n\u001b[1;32m 19\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mhouse_construction_year\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39m5\u001b[39m\n\u001b[1;32m 19\u001b[0m }\n\u001b[1;32m 24\u001b[0m dataset \u001b[39m=\u001b[39m TimeSeriesDataset(df, \u001b[39m\"\u001b[39m\u001b[39mtimeseries\u001b[39m\u001b[39m\"\u001b[39m, conditioning_dict)\n\u001b[0;32m---> 25\u001b[0m generator\u001b[39m.\u001b[39;49mfit(dataset, \u001b[39m\"\u001b[39;49m\u001b[39mtimeseries\u001b[39;49m\u001b[39m\"\u001b[39;49m)\n",
+      "File \u001b[0;32m~/EnData/generator/data_generator.py:86\u001b[0m, in \u001b[0;36mDataGenerator.fit\u001b[0;34m(self, X, timeseries_colname)\u001b[0m\n\u001b[1;32m 80\u001b[0m expected_input_dim \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmodel\u001b[39m.\u001b[39mopt\u001b[39m.\u001b[39minput_dim\n\u001b[1;32m 81\u001b[0m \u001b[39massert\u001b[39;00m sample_timeseries\u001b[39m.\u001b[39mshape \u001b[39m==\u001b[39m (\n\u001b[1;32m 82\u001b[0m expected_seq_len,\n\u001b[1;32m 83\u001b[0m expected_input_dim,\n\u001b[1;32m 84\u001b[0m ), \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mExpected timeseries shape (\u001b[39m\u001b[39m{\u001b[39;00mexpected_seq_len\u001b[39m}\u001b[39;00m\u001b[39m, \u001b[39m\u001b[39m{\u001b[39;00mexpected_input_dim\u001b[39m}\u001b[39;00m\u001b[39m), but got \u001b[39m\u001b[39m{\u001b[39;00msample_timeseries\u001b[39m.\u001b[39mshape\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m\n\u001b[0;32m---> 86\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mmodel\u001b[39m.\u001b[39;49mtrain_model(dataset)\n",
+      "File \u001b[0;32m~/EnData/generator/gan/acgan.py:301\u001b[0m, in \u001b[0;36mACGAN.train_model\u001b[0;34m(self, dataset)\u001b[0m\n\u001b[1;32m 299\u001b[0m _lambda \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39msparse_conditioning_loss_weight\n\u001b[1;32m 300\u001b[0m N_r \u001b[39m=\u001b[39m rare_mask\u001b[39m.\u001b[39msum()\u001b[39m.\u001b[39mitem()\n\u001b[0;32m--> 301\u001b[0m N_nr \u001b[39m=\u001b[39m (torch\u001b[39m.\u001b[39;49mlogical_not(rare_mask))\u001b[39m.\u001b[39;49msum()\u001b[39m.\u001b[39;49mitem()\n\u001b[1;32m 302\u001b[0m N \u001b[39m=\u001b[39m current_batch_size\n\u001b[1;32m 303\u001b[0m g_loss \u001b[39m=\u001b[39m (\n\u001b[1;32m 304\u001b[0m _lambda \u001b[39m*\u001b[39m (N_r \u001b[39m/\u001b[39m N) \u001b[39m*\u001b[39m g_loss_rare\n\u001b[1;32m 305\u001b[0m \u001b[39m+\u001b[39m (\u001b[39m1\u001b[39m \u001b[39m-\u001b[39m _lambda) \u001b[39m*\u001b[39m (N_nr \u001b[39m/\u001b[39m N) \u001b[39m*\u001b[39m g_loss_non_rare\n\u001b[1;32m 306\u001b[0m )\n",
+      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
+     ]
     }
    ],
    "source": [
-    "non_pv.data.dataid.nunique()"
+    "from datasets.timeseries_dataset import TimeSeriesDataset\n",
+    "from generator.data_generator import DataGenerator\n",
+    "\n",
+    "df = pv.data\n",
+    "df\n",
+    "\n",
+    "generator = DataGenerator(\"acgan\", model_params=None)\n",
+    "\n",
+    "conditioning_dict = {\n",
+    "    \"month\": 12,\n",
+    "    \"weekday\": 7,\n",
+    "    \"city\": 6,\n",
+    "    \"building_type\": 3,\n",
+    "    \"has_solar\": 2,\n",
+    "    \"car1\": 2,\n",
+    "    \"state\": 3,\n",
+    "    \"total_square_footage\": 5,\n",
+    "    \"house_construction_year\": 5\n",
+    "}\n",
+    "\n",
+    "dataset = TimeSeriesDataset(df, \"timeseries\", conditioning_dict)\n",
+    "generator.fit(dataset, \"timeseries\")"
    ]
   },
   {
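The committed outputs show the training cell was stopped by hand (KeyboardInterrupt) partway through epoch 1, so the run mainly exercises the new wrapper end to end; conditioning_dict appears to map each conditioning variable to its number of categories (12 months, 7 weekdays, and so on). Once a run completes, the natural follow-up is the generate(self, conditioning_vars) method visible in the data_generator.py diff. Its expected argument format is not shown anywhere in this commit, so the sketch below assumes one category code per variable:

# Hypothetical follow-up cell; the dict format is an assumption, since the
# commit only shows the generate(self, conditioning_vars) signature.
synthetic = generator.generate({"month": 11, "weekday": 4, "has_solar": 1})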
