From d75a08f2673890fcd0ce72995a9da38f8ba07db4 Mon Sep 17 00:00:00 2001 From: joyce-yuan Date: Tue, 12 Nov 2024 18:27:53 +0000 Subject: [PATCH] fixing time elapse for plot utils --- src/utils/plot_combined_exp.py | 16 ++++++++++------ src/utils/post_hoc_plot_utils.py | 12 +++++++++--- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/src/utils/plot_combined_exp.py b/src/utils/plot_combined_exp.py index 182e26a..919d5ba 100644 --- a/src/utils/plot_combined_exp.py +++ b/src/utils/plot_combined_exp.py @@ -12,6 +12,7 @@ def combine_and_plot( xlabels: List[str], ylabels: List[str], output_dir: str, + include_logs: bool = True ) -> None: """ Combine and plot metrics from multiple experiments with 95% confidence intervals. @@ -28,6 +29,9 @@ def combine_and_plot( # Ensure the output directory exists if not os.path.exists(output_dir): os.makedirs(output_dir) + + if not os.path.exists(os.path.join(output_dir, 'plots')): + os.makedirs(os.path.join(output_dir, 'plots')) for metric_index, metric_name in enumerate(metrics_list): plt.figure(figsize=(12, 8), dpi=300) @@ -36,21 +40,21 @@ def combine_and_plot( # Load the aggregated metric DataFrame for the experiment metric_df = pd.read_csv(os.path.join(experiment_path, f"{metric_name}_avg.csv")) # TODO: modify this to be CI - std_df = pd.read_csv(os.path.join(experiment_path, f"{metric_name}_std.csv")) - # ci_df = pd.read_csv(os.path.join(experiment_path, f"{metric_name}_ci95.csv")) + # std_df = pd.read_csv(os.path.join(experiment_path, f"{metric_name}_std.csv")) + ci_df = pd.read_csv(os.path.join(experiment_path, f"{metric_name}_ci95.csv")) # Assuming rounds or time steps are the index rounds = np.arange(len(metric_df)) # Plot each experiment's mean with 95% CI mean_metric = metric_df.values.flatten() - std_metric = std_df.values.flatten() - # ci_95 = ci_df.values.flatten() + # std_metric = std_df.values.flatten() + ci_95 = ci_df.values.flatten() # Plot the mean with confidence interval as a shaded area plt.plot(rounds, mean_metric, label=f'{experiment_key}', linestyle='--', linewidth=1.5) - # plt.fill_between(rounds, mean_metric - ci_95, mean_metric + ci_95, alpha=0.2) - plt.fill_between(rounds, mean_metric - std_metric, mean_metric + std_metric, alpha=0.2) + plt.fill_between(rounds, mean_metric - ci_95, mean_metric + ci_95, alpha=0.2) + # plt.fill_between(rounds, mean_metric - std_metric, mean_metric + std_metric, alpha=0.2) # Plot customization plt.xlabel(xlabels[metric_index], fontsize=14) diff --git a/src/utils/post_hoc_plot_utils.py b/src/utils/post_hoc_plot_utils.py index de57122..176414d 100644 --- a/src/utils/post_hoc_plot_utils.py +++ b/src/utils/post_hoc_plot_utils.py @@ -426,11 +426,17 @@ def plot_all_metrics(logs_dir: str, per_round: bool = True, per_time: bool = Tru rounds=all_users_data['rounds'], metric_name=key, ylabel=display_name, - output_dir=f'{logs_dir}plots/', + output_dir=f'{logs_dir}/plots/', plot_avg_only=plot_avg_only, **kwargs ) + if per_time: + time_data = os.path.join(logs_dir, 'node_1/csv/time_elapsed.csv') + # check if time elapsed data exists + if not os.path.exists(time_data): + print("Time elapsed data not found. Skipping per-time plotting.") + return all_users_data = aggregate_per_realtime_data(logs_dir, **kwargs) for key, display_name in metrics_map.items(): @@ -439,7 +445,7 @@ def plot_all_metrics(logs_dir: str, per_round: bool = True, per_time: bool = Tru time_ticks=all_users_data[key].index.values, metric_name=key, ylabel=display_name, - output_dir=f'{logs_dir}plots/', + output_dir=f'{logs_dir}/plots/', plot_avg_only=plot_avg_only, **kwargs ) @@ -468,7 +474,7 @@ def plot_all_metrics(logs_dir: str, per_round: bool = True, per_time: bool = Tru try: print(f"Processing logs in: {logs_dir}") avg_metrics, std_metrics, df_metrics = aggregate_metrics_across_users(logs_dir) - plot_all_metrics(logs_dir) + plot_all_metrics(logs_dir, per_round=True, per_time=True, plot_avg_only=True) except Exception as e: print(f"Error processing {logs_dir}: {e}") continue \ No newline at end of file