Skip to content

Commit

Permalink
akila's classifier and evaluation scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
akila-14 committed Sep 6, 2024
1 parent 3cc0eae commit 47e5ea6
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
2 changes: 1 addition & 1 deletion notebooks/classifier_tests.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"from matplotlib.colors import ListedColormap\n",
"import umap\n",
"from embed_time.model_VAE_resnet18 import VAEResNet18\n",
"from datasets.neuromast import NeuromastDatasetTest, NeuromastDatasetTrain, NeuromastDatasetTrain_T10\n",
"from embed_time.neuromast import NeuromastDatasetTest, NeuromastDatasetTrain, NeuromastDatasetTrain_T10\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.metrics import confusion_matrix\n",
"from sklearn.preprocessing import LabelEncoder\n",
Expand Down
13 changes: 7 additions & 6 deletions scripts/20240902_ab_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,21 +111,21 @@ def plot_image(model, dataloader, device):
#%%
def create_pca_plots(train_latents, val_latents, train_df, val_df):
# Step 1: Scale the features
scaler = StandardScaler()
train_latents_scaled = scaler.fit_transform(train_latents)
val_latents_scaled = scaler.transform(val_latents)
# scaler = StandardScaler()
# train_latents_scaled = scaler.fit_transform(train_latents)
# val_latents_scaled = scaler.transform(val_latents)

# Step 2: Perform PCA
pca = PCA(n_components=2)
train_latents_pca = pca.fit_transform(train_latents_scaled)
val_latents_pca = pca.transform(val_latents_scaled)
train_latents_pca = pca.fit_transform(train_latents)
val_latents_pca = pca.transform(val_latents)

# Step 3: Prepare the plot
fig, axes = plt.subplots(1,2, figsize=(25, 10))

# Helper function to create a color map
def create_color_map(n):
return ListedColormap(plt.cm.viridis(np.linspace(0, 1, n)))
return ListedColormap(plt.cm.Greens(np.linspace(0.25, 1, n)))
# Assuming you have 3 unique labels

# Step 3: Plot PCA for the training set
Expand Down Expand Up @@ -302,6 +302,7 @@ def create_color_map(n):
# plot_reconstructions(model, dataloader_train, device)

#%%
plot_reconstructions(model, dataloader_train, device)
plot_reconstructions(model, dataloader_val, device)


Expand Down

0 comments on commit 47e5ea6

Please sign in to comment.