Skip to content

Commit

Permalink
Fix time-elapsed handling for plot utils
Browse files: browse the repository at this point in the history
  • Loading branch information
joyce-yuan committed Nov 12, 2024
1 parent 95eae15 commit d75a08f
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 9 deletions.
16 changes: 10 additions & 6 deletions src/utils/plot_combined_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def combine_and_plot(
xlabels: List[str],
ylabels: List[str],
output_dir: str,
include_logs: bool = True
) -> None:
"""
Combine and plot metrics from multiple experiments with 95% confidence intervals.
Expand All @@ -28,6 +29,9 @@ def combine_and_plot(
# Ensure the output directory exists
if not os.path.exists(output_dir):
os.makedirs(output_dir)

if not os.path.exists(os.path.join(output_dir, 'plots')):
os.makedirs(os.path.join(output_dir, 'plots'))

for metric_index, metric_name in enumerate(metrics_list):
plt.figure(figsize=(12, 8), dpi=300)
Expand All @@ -36,21 +40,21 @@ def combine_and_plot(
# Load the aggregated metric DataFrame for the experiment
metric_df = pd.read_csv(os.path.join(experiment_path, f"{metric_name}_avg.csv"))
# TODO: modify this to be CI
std_df = pd.read_csv(os.path.join(experiment_path, f"{metric_name}_std.csv"))
# ci_df = pd.read_csv(os.path.join(experiment_path, f"{metric_name}_ci95.csv"))
# std_df = pd.read_csv(os.path.join(experiment_path, f"{metric_name}_std.csv"))
ci_df = pd.read_csv(os.path.join(experiment_path, f"{metric_name}_ci95.csv"))

# Assuming rounds or time steps are the index
rounds = np.arange(len(metric_df))

# Plot each experiment's mean with 95% CI
mean_metric = metric_df.values.flatten()
std_metric = std_df.values.flatten()
# ci_95 = ci_df.values.flatten()
# std_metric = std_df.values.flatten()
ci_95 = ci_df.values.flatten()

# Plot the mean with confidence interval as a shaded area
plt.plot(rounds, mean_metric, label=f'{experiment_key}', linestyle='--', linewidth=1.5)
# plt.fill_between(rounds, mean_metric - ci_95, mean_metric + ci_95, alpha=0.2)
plt.fill_between(rounds, mean_metric - std_metric, mean_metric + std_metric, alpha=0.2)
plt.fill_between(rounds, mean_metric - ci_95, mean_metric + ci_95, alpha=0.2)
# plt.fill_between(rounds, mean_metric - std_metric, mean_metric + std_metric, alpha=0.2)

# Plot customization
plt.xlabel(xlabels[metric_index], fontsize=14)
Expand Down
12 changes: 9 additions & 3 deletions src/utils/post_hoc_plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,11 +426,17 @@ def plot_all_metrics(logs_dir: str, per_round: bool = True, per_time: bool = Tru
rounds=all_users_data['rounds'],
metric_name=key,
ylabel=display_name,
output_dir=f'{logs_dir}plots/',
output_dir=f'{logs_dir}/plots/',
plot_avg_only=plot_avg_only,
**kwargs
)

if per_time:
time_data = os.path.join(logs_dir, 'node_1/csv/time_elapsed.csv')
# check if time elapsed data exists
if not os.path.exists(time_data):
print("Time elapsed data not found. Skipping per-time plotting.")
return
all_users_data = aggregate_per_realtime_data(logs_dir, **kwargs)

for key, display_name in metrics_map.items():
Expand All @@ -439,7 +445,7 @@ def plot_all_metrics(logs_dir: str, per_round: bool = True, per_time: bool = Tru
time_ticks=all_users_data[key].index.values,
metric_name=key,
ylabel=display_name,
output_dir=f'{logs_dir}plots/',
output_dir=f'{logs_dir}/plots/',
plot_avg_only=plot_avg_only,
**kwargs
)
Expand Down Expand Up @@ -468,7 +474,7 @@ def plot_all_metrics(logs_dir: str, per_round: bool = True, per_time: bool = Tru
try:
print(f"Processing logs in: {logs_dir}")
avg_metrics, std_metrics, df_metrics = aggregate_metrics_across_users(logs_dir)
plot_all_metrics(logs_dir)
plot_all_metrics(logs_dir, per_round=True, per_time=True, plot_avg_only=True)
except Exception as e:
print(f"Error processing {logs_dir}: {e}")
continue

0 comments on commit d75a08f

Please sign in to comment.