Skip to content

Commit

Permalink
Refactor plot formatting in MetaModelGenerator
Browse files Browse the repository at this point in the history
  • Loading branch information
leonvanbokhorst committed Nov 8, 2024
1 parent 0ada788 commit b2fc926
Showing 1 changed file with 22 additions and 29 deletions.
51 changes: 22 additions & 29 deletions src/maml_model_agnostic_meta_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,14 @@ def compute_metrics(

return {"mse": mse, "mae": mae, "r2": r2, "rmse": np.sqrt(mse)}

def _setup_plot_formatting(self, xlabel: str, ylabel: str, title: str, add_legend: bool = True):
"""Helper method to set up common plot formatting"""
plt.xlabel(xlabel)
plt.ylabel(ylabel)
plt.title(title)
if add_legend:
plt.legend()

def visualize_adaptation(
self,
support_x: torch.Tensor,
Expand All @@ -206,7 +214,7 @@ def visualize_adaptation(
try:
plt.figure(figsize=(20, 5))

# Plot 1: Predictions (now including support points)
# Plot 1: Predictions
plt.subplot(1, 4, 1)
with torch.no_grad():
initial_pred = self.forward(query_x)
Expand All @@ -222,9 +230,10 @@ def visualize_adaptation(
plt.plot([query_y.min().item(), query_y.max().item()],
[query_y.min().item(), query_y.max().item()],
'r--', label='Perfect')
self._extracted_from_visualize_adaptation_30(
'True Values', 'Predicted Values', "Predictions vs True Values", 2
)
self._setup_plot_formatting('True Values', 'Predicted Values', 'Predictions vs True Values')

# Plot 2: Feature Importance
plt.subplot(1, 4, 2)
with torch.no_grad():
feature_importance = torch.zeros(query_x.shape[1])
for i in range(query_x.shape[1]):
Expand All @@ -233,20 +242,19 @@ def visualize_adaptation(
perturbed_pred = self.forward_with_fast_weights(perturbed_x, fast_weights)
feature_importance[i] = F.mse_loss(perturbed_pred, adapted_pred)

plt.bar(range(len(feature_importance)),
feature_importance.cpu().numpy())
self._extracted_from_visualize_adaptation_30(
'Feature Index', 'Importance (MSE Impact)', 'Feature Importance'
)
plt.bar(range(len(feature_importance)), feature_importance.cpu().numpy())
self._setup_plot_formatting('Feature Index', 'Importance (MSE Impact)', 'Feature Importance', False)

# Plot 3: Error Distribution
plt.subplot(1, 4, 3)
initial_errors = (initial_pred - query_y).cpu().numpy()
adapted_errors = (adapted_pred - query_y).cpu().numpy()
plt.hist(initial_errors, alpha=0.5, label='Pre-Adaptation', bins=20)
plt.hist(adapted_errors, alpha=0.5, label='Post-Adaptation', bins=20)
self._extracted_from_visualize_adaptation_30(
'Prediction Error', 'Count', 'Error Distribution', 4
)
self._setup_plot_formatting('Prediction Error', 'Count', 'Error Distribution')

# Plot 4: Adaptation Progress
plt.subplot(1, 4, 4)
progress_x = query_x[:5] # Track few points for visualization
progress_y = query_y[:5]
adaptation_steps = []
Expand All @@ -266,9 +274,8 @@ def visualize_adaptation(
temp_weights[name] = weight - self.inner_lr * grad

plt.plot(adaptation_steps, marker='o')
self._extracted_from_visualize_adaptation_30(
'Adaptation Step', 'MSE Loss', 'Adaptation Progress'
)
self._setup_plot_formatting('Adaptation Step', 'MSE Loss', 'Adaptation Progress')

# Add overall metrics
plt.suptitle(f"{task_name}\n"
f"MSE Before: {F.mse_loss(initial_pred, query_y):.4f}, "
Expand All @@ -286,20 +293,6 @@ def visualize_adaptation(
logger.error(f"Shapes - query_x: {query_x.shape}, query_y: {query_y.shape}")
return None

# TODO Rename this here and in `visualize_adaptation`
def _extracted_from_visualize_adaptation_30(self, arg0, arg1, arg2, arg3):
self._extracted_from_visualize_adaptation_30(arg0, arg1, arg2)
plt.legend()

# Plot 2: Feature Importance
plt.subplot(1, 4, arg3)

# TODO Rename this here and in `visualize_adaptation`
def _extracted_from_visualize_adaptation_30(self, arg0, arg1, arg2):
plt.xlabel(arg0)
plt.ylabel(arg1)
plt.title(arg2)


def create_synthetic_tasks(
num_tasks: int = 100,
Expand Down

0 comments on commit b2fc926

Please sign in to comment.