Skip to content

Commit

Permalink
saving notebooks on this branch
Browse files Browse the repository at this point in the history
  • Loading branch information
Cryaaa committed Sep 3, 2024
1 parent e4a734e commit c6350b7
Show file tree
Hide file tree
Showing 8 changed files with 1,992 additions and 24 deletions.
8 changes: 7 additions & 1 deletion CPC/models/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@ def __init__(self, n_dims, n_classes, n_hidden_layers, n_hidden_dims):
last = nn.Linear(n_hidden_dims, n_classes)
self.layers.append(last)

self.activation = nn.ReLU()
self.final_activation = nn.Softmax(1)


# Forward pass: feed x through the layers in self.layers and return the result.
# NOTE(review): this block is a rendered diff hunk whose indentation and +/-
# markers are not visible here, so old and new lines appear interleaved —
# confirm the final method body against the repository.
def forward(self, x):
for layer in self.layers[:-1]:
x = layer(x)
# NOTE(review): presumably the pre-commit return that this commit replaces with
# the three lines below — TODO confirm.
return self.layers[-1](x)
# NOTE(review): presumably applied after each non-final layer (inside the loop) — TODO confirm.
x = self.activation(x)
x = self.layers[-1](x)
# NOTE(review): self.final_activation is nn.Softmax(1) in this hunk; if training
# uses a loss that applies softmax itself (e.g. nn.CrossEntropyLoss), the output
# would be normalized twice — verify against the training code.
return self.final_activation(x)
2 changes: 1 addition & 1 deletion CPC/test.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
"seed = 1\n",
"crop_size = 280\n",
"batch_size = 32\n",
"version = 10\n",
"version = 10 # first training this will be zero\n",
"checkpoint_dir = f\"checkpoints/version_{str(version)}\"\n",
"latent_dims = 2\n",
"\n",
Expand Down
346 changes: 327 additions & 19 deletions CPC/train.ipynb

Large diffs are not rendered by default.

9 changes: 7 additions & 2 deletions CPC/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,17 @@ def train(
device,
checkpoint_dir,
patience,
metadata_training = None,
):
patience_counter = 0
best_loss = torch.inf
checkpoint_dir = get_next_version(checkpoint_dir)
writer = SummaryWriter(checkpoint_dir)

if metadata_training is not None:
torch.save(
metadata_training,
os.path.join(checkpoint_dir, "metadata_training.pt"),
)
encoder = encoder.to(device)
ar_model = ar_model.to(device)
query_weights = query_weights.to(device)
Expand Down Expand Up @@ -188,7 +193,7 @@ def train(
break
except KeyboardInterrupt:
print(f"Keyboard interrupt")
print(f"{epoch} epochs. Version {checkpoint_dir.split("_")[-1]}")
print(f"{epoch} epochs. Version {checkpoint_dir.split('_')[-1]}")

if __name__ == "__main__":
transform = torchvision.transforms.RandomCrop(crop_size)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# %%
import pandas as pd
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score, StratifiedKFold, GridSearchCV, train_test_split
from sklearn.ensemble import RandomForestClassifier

from sklearn.svm import SVC
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
import seaborn as sns
import matplotlib.pyplot as plt
from pathlib import Path
tabular_data = "/mnt/efs/dlmbl/G-et/tabular_data"

# Display name of each latent space -> CSV file (inside `tabular_data`) holding it.
_latent_space_files = {
    "UNet_20z_Old_Normalisation": "UNet_VAE_01_old_normalisation.csv",
    "UNet_20z_New_Normalisation": "UNet_VAE_02_new_normalisation.csv",
    "Resnet18_26000z_Old_Normalisation": "LinearVAE_01_bicubic_latents_w_annot.csv",
    "Resnet18_26000z_New_Normalisation": "LinearVAE_02_bicubic_latents_w_annot.csv",
}

# Load every table up front; each is read with pandas' default integer index.
latent_spaces = {
    space_name: pd.read_csv(Path(tabular_data) / csv_name)
    for space_name, csv_name in _latent_space_files.items()
}

# Take one table for interactive inspection and group its rows by the
# (Run, Plate, ID) columns.
df = latent_spaces["UNet_20z_Old_Normalisation"]
well_columns = ["Run", "Plate", "ID"]
grouped_by_well = df.groupby(well_columns)
n_samples = len(grouped_by_well)

# %%
# Map each group, keyed by its (Run, Plate, ID) tuple, to the index labels of
# its rows, and peek at the first key.
group_dict = grouped_by_well.groups
group_keys = list(group_dict)
group_keys[0]
# %%
# One label per group: the first entry of that group's unique "Label" values.
# NOTE(review): assumes every group carries a single label — confirm.
labels = [
    well_labels[0]
    for well_labels in grouped_by_well["Label"].unique().to_numpy()
]
len(labels)
# %%
from sklearn.metrics import balanced_accuracy_score

# Ground-truth / metadata columns; every remaining column is used as a feature.
gt_keys = ["Label", "Time", "Axes", "Run", "Plate", "ID"]
results = {}

model_name = ["SVC", "RF", "LDA"]

models = [
    SVC(C=40, kernel='rbf'),
    RandomForestClassifier(random_state=1, n_jobs=10, n_estimators=500, max_features=300),
    LDA(solver='svd'),
]


def _rows_for_wells(group_dict, group_keys, well_positions):
    """Return the row indices of all groups selected by *well_positions*.

    Args:
        group_dict: mapping of group key -> row indices (``GroupBy.groups``).
        group_keys: list of group keys; ``well_positions`` index into it.
        well_positions: iterable of integer positions into ``group_keys``.

    Returns:
        A 1-D array concatenating the row indices of every selected group.
    """
    # np.concatenate rather than np.concat: the latter only exists in
    # NumPy >= 2.0, where it is an alias of the former.
    return np.concatenate([group_dict[group_keys[j]] for j in well_positions])


for name, df in latent_spaces.items():
    results[name] = {}
    y = df[gt_keys]
    X = df.drop(gt_keys, axis=1)
    grouped_by_well = df.groupby(["Run", "Plate", "ID"])
    n_samples = len(grouped_by_well)
    # NOTE(review): `.groups` holds index labels while the rows below are
    # selected with `.iloc`; the two coincide because the tables are read with
    # the default integer index — revisit if the index ever changes.
    group_dict = grouped_by_well.groups
    group_keys = list(group_dict.keys())
    # One label per group, used to stratify the folds.
    # NOTE(review): assumes this follows the same order as `group_keys` and
    # that each group has a single label — confirm.
    labels = [lab[0] for lab in grouped_by_well["Label"].unique().to_numpy()]

    for mod_name, model in zip(model_name, models):
        print(f"{name}, {mod_name}")

        # Split over groups rather than rows, so every row of a group lands
        # on the same side of a fold.
        cv_outer = StratifiedKFold(n_splits=10, shuffle=True, random_state=1)
        scores = []
        for idx_keys_train, idx_keys_test in cv_outer.split(range(n_samples), labels):
            train_indices_df = _rows_for_wells(group_dict, group_keys, idx_keys_train)
            test_indices_df = _rows_for_wells(group_dict, group_keys, idx_keys_test)

            # Binary target: is the row labelled "good"?
            y_train = y.iloc[train_indices_df]["Label"] == "good"
            y_test = y.iloc[test_indices_df]["Label"] == "good"

            model.fit(X.iloc[train_indices_df], y_train)
            predictions = model.predict(X.iloc[test_indices_df])
            score = balanced_accuracy_score(
                y_true=y_test,
                y_pred=predictions,
            )
            print(score)
            scores.append(score)

        # report performance (mean and standard deviation over the folds)
        print('Accuracy: %.3f (%.3f)' % (np.mean(scores), np.std(scores)))
        results[name][mod_name] = scores
# %%
# Flatten the nested {feature set: {classifier: fold scores}} results into one
# long-format table with a row per fold score.
all_data = [
    pd.DataFrame(
        {
            "Accuracy": fold_scores,
            "Classifier": np.full(len(fold_scores), classifier),
            "Feature Set": np.full(len(fold_scores), feature_set),
        }
    )
    for feature_set, per_classifier in results.items()
    for classifier, fold_scores in per_classifier.items()
]

all_classifier_results_df = pd.concat(all_data, axis=0, ignore_index=True)

# Persist the table next to the input CSVs.
results_csv = Path(tabular_data) / "classification results first good latent spaces.csv"
all_classifier_results_df.to_csv(results_csv)
all_classifier_results_df
# %%
import matplotlib as mpl

mpl.rcParams["figure.dpi"] = 300

#out_fig_2 = home_directory + "/" + r"Plots\Fig 2"

sns.set()
fig, ax = plt.subplots(figsize=(5, 3))

# Grouped bars: one cluster per feature set, one bar per classifier, with
# standard-deviation error bars.
# NOTE(review): `errwidth` is reported as deprecated in newer seaborn releases
# — confirm against the installed version.
ax = sns.barplot(
    all_classifier_results_df,
    x="Feature Set",
    y="Accuracy",
    hue="Classifier",
    ax=ax,
    width=0.8,
    saturation=1,
    errorbar=("sd", 1),
    capsize=0.1,
    errwidth=1,
)
# Place the legend outside the axes, to the right of the plot.
sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))
ax.set_ylim([0.4, 1])
plt.xticks(rotation=90)

#plt.savefig(f"{out_fig_2}/Morph Prediction.pdf", format="pdf", bbox_inches="tight")
plt.show()
# %%
Loading

0 comments on commit c6350b7

Please sign in to comment.