Skip to content

Commit

Permalink
eval cleanups
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Fuest committed Oct 4, 2024
1 parent 1630e7e commit 35c367b
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 5 deletions.
4 changes: 2 additions & 2 deletions config/model_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ diffusion_ts:

acgan:
batch_size: 32
n_epochs: 10
n_epochs: 200
lr_gen: 3e-4
lr_discr: 1e-4
warm_up_epochs: 5
warm_up_epochs: 50
80 changes: 80 additions & 0 deletions eval/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
from eval.metrics import calculate_mmd
from eval.metrics import calculate_period_bound_mse
from eval.metrics import dynamic_time_warping_dist
from eval.metrics import plot_range_with_syn_values
from eval.metrics import plot_syn_with_closest_real_ts
from eval.metrics import visualization
from eval.predictive_metric import predictive_score_metrics
from generator.diffcharge.diffusion import DDPM
from generator.diffusion_ts.gaussian_diffusion import Diffusion_TS
Expand Down Expand Up @@ -197,6 +200,9 @@ def evaluate_subset(
# Compute metrics
self.compute_metrics(real_data_array, syn_data_array, real_data_subset, writer)

# Generate plots
self.create_visualizations(real_data_inv, syn_data_inv, dataset, model, writer)

# Close the writer
writer.flush()
writer.close()
Expand Down Expand Up @@ -250,6 +256,80 @@ def compute_metrics(
writer.add_scalar("Predictive/score", pred_score)
self.metrics["predictive"].append(pred_score)

def create_visualizations(
    self,
    real_data_df: pd.DataFrame,
    syn_data_df: pd.DataFrame,
    dataset: Any,
    model: Any,
    writer: SummaryWriter,
    num_samples: int = 100,
    num_runs: int = 3,
):
    """
    Create various visualizations for the evaluation results.

    Args:
        real_data_df (pd.DataFrame): Inverse-transformed real data.
        syn_data_df (pd.DataFrame): Inverse-transformed synthetic data.
        dataset (Any): The dataset object (must provide inverse_transform).
        model (Any): The trained model (must provide generate and categorical_dims).
        writer (SummaryWriter): TensorBoard writer for logging visualizations.
        num_samples (int): Number of synthetic samples generated per run.
        num_runs (int): Number of per-run visualizations to produce.
    """
    for i in range(num_runs):
        # Sample one conditioning-variable combination from the real data and
        # replicate it num_samples times so every generated series shares it.
        sample_row = real_data_df.sample(n=1).iloc[0]
        conditioning_vars_sample = {
            var_name: torch.tensor(
                [sample_row[var_name]] * num_samples,
                dtype=torch.long,
                device=device,  # NOTE(review): `device` appears to be a module-level global — confirm
            )
            for var_name in model.categorical_dims.keys()
        }

        generated_samples = model.generate(conditioning_vars_sample).cpu().numpy()
        # Normalize 2-D output to (num_samples, seq_len, n_features) for the
        # downstream plotting helpers.
        if generated_samples.ndim == 2:
            generated_samples = generated_samples.reshape(
                generated_samples.shape[0], -1, generated_samples.shape[1]
            )

        generated_samples_df = pd.DataFrame(
            {
                var_name: [sample_row[var_name]] * num_samples
                for var_name in model.categorical_dims.keys()
            }
        )
        generated_samples_df["timeseries"] = list(generated_samples)
        # "dataid" is required by dataset.inverse_transform.
        generated_samples_df["dataid"] = sample_row["dataid"]
        generated_samples_df = dataset.inverse_transform(generated_samples_df)

        # Extract month and weekday for plotting (None when the column is absent).
        month = sample_row.get("month", None)
        weekday = sample_row.get("weekday", None)

        # Visualization 1: real-data range overlaid with synthetic values.
        range_plot = plot_range_with_syn_values(
            real_data_df, generated_samples_df, month, weekday
        )
        writer.add_figure(f"Visualizations/Range_Plot_{i}", range_plot)

        # Visualization 2: synthetic series next to their closest real series.
        closest_plot = plot_syn_with_closest_real_ts(
            real_data_df, generated_samples_df, month, weekday
        )
        writer.add_figure(f"Visualizations/Closest_Real_TS_{i}", closest_plot)

    # Visualization 3: KDE plot of real vs. synthetic data. This depends only
    # on the full real/synthetic frames — not on the per-run samples — so it is
    # computed and logged once instead of being recomputed (and overwriting the
    # same un-indexed tag) on every run, as the original loop did.
    real_data_array = np.stack(real_data_df["timeseries"])
    syn_data_array = np.stack(syn_data_df["timeseries"])
    kde_plot = visualization(real_data_array, syn_data_array, "kernel")
    writer.add_figure("Visualizations/KDE", kde_plot)

def get_trained_model(self, dataset: Any) -> Any:
"""
Get a trained model for the dataset.
Expand Down
8 changes: 5 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,19 @@ def evaluate_single_dataset_model(
# evaluator.evaluate_all_users()
# evaluator.evaluate_all_non_pv_users()
non_pv_user_evaluator.evaluate_model(
None, distinguish_rare=True, data_label="non_pv_users"
None, distinguish_rare=False, data_label="non_pv_users"
)
pv_user_evaluator.evaluate_model(
None, distinguish_rare=False, data_label="pv_users"
)
pv_user_evaluator.evaluate_model(None, distinguish_rare=True, data_label="pv_users")


def main():
# evaluate_individual_user_models("gpt", include_generation=False)
# evaluate_individual_user_models("acgan", include_generation=True)
# evaluate_individual_user_models("acgan", include_generation=False, normalization_method="date")
evaluate_single_dataset_model(
"diffusion_ts",
"acgan",
geography="california",
include_generation=False,
normalization_method="group",
Expand Down

0 comments on commit 35c367b

Please sign in to comment.