
General model experiments #3

Open · EchteRobert opened this issue Feb 18, 2022 · 17 comments

@EchteRobert (Collaborator)

This issue is used to test more general aspects of model development that are not directly related to, but are likely still influenced by, the dataset and model hyperparameters in use at the time.

@EchteRobert commented Feb 18, 2022

Experiment

During training I sample cells from each well to create uniform tensors; this is a requirement for creating batches of data larger than 1. It is generally known in contrastive/metric learning that a large batch size is essential for good model performance, as I also showed in the first issue: #1.
To test whether the model's performance is affected by this sampling, I evaluate the model for a range of sampled-cell counts and also without sampling, i.e. collapsing all existing cell features into a profile (as with the mean).
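As a sketch of the two modes (hypothetical helper names and array shapes, not the actual training code):

```python
import numpy as np

def sample_well(cells: np.ndarray, n_cells: int, rng: np.random.Generator) -> np.ndarray:
    """Subsample a well's cell-by-feature matrix to a fixed size so wells can be batched."""
    replace = cells.shape[0] < n_cells  # resample with replacement if the well has too few cells
    idx = rng.choice(cells.shape[0], size=n_cells, replace=replace)
    return cells[idx]

def collapse_well(cells: np.ndarray) -> np.ndarray:
    """No-sampling alternative: aggregate all cells into one profile (here via the mean)."""
    return cells.mean(axis=0)

rng = np.random.default_rng(0)
well = rng.normal(size=(537, 1745))   # hypothetical well: 537 cells x 1745 features
fixed = sample_well(well, 400, rng)   # (400, 1745): uniform tensors allow batch size > 1
profile = collapse_well(well)         # (1745,): uses every cell, no sampling
```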

Hypothesis

Model performance will improve slightly when using all existing cells without extra sampling, because sampling may pick out cells that are not representative of the perturbation.

Main takeaways

  • Up to a certain threshold, sampling more or fewer cells appears to have little effect on the PR metric.
  • As the number of sampled cells decreases, the correlation between replicates starts to decrease. However, this only shows up in the PR score once the number of sampled cells becomes very low; for this dataset that point lies somewhere between 200 and 300 cells. That number may be tied to the distribution of cells per well per plate, shown in the histogram below.

From here on evaluation will be done without sampling, simply by collapsing all cells into a feature representation.

Exciting stuff here!

no sampling
[figure: MLP_BS64_noSampling_PR]

600 cells sampled
[figure: MLP_BS64_600cells_PR]

400 cells sampled
[figure: MLP_BS64_400cells_PR]

200 cells sampled
[figure: MLP_BS64_200cells_PR]

100 cells sampled
[figure: MLP_BS64_100cells_PR]

50 cells sampled
[figure: MLP_BS64_50cells_PR]

cell distribution histogram
[figure: NrCellsHist]

@EchteRobert commented Apr 13, 2022

Model capacity experiment 1

Goal: Test whether the current training method (model architecture, optimizer, and loss function) is capable of learning to distinguish point sets from each other when their means are identical.

Background:
This experiment is the first in a series that aims to figure out what the feature aggregation model is learning. We expect that it learns to select both cells and features when aggregating single-cell feature data, but this is hard to verify. Instead, this experiment tests the model's capacity to learn the covariance matrix of the input data.

The general setup is as follows:

  • Generate n classes (say n=10), where each sample is a low-dimensional (e.g. d=2, for visualization purposes) matrix generated from a multivariate Gaussian distribution.
  • All classes share the same mean feature vector but have different covariance matrices.
  • The goal of the model is to learn to cluster these classes based on their covariance matrices.

To create the data:

  • I generate a low-dimensional (in this exp. d=2) covariance matrix [[1, 0.5], [0.5, 1]] and create the remaining 9 classes (n=10 classes total) by rotating this matrix.
  • I then generate m random points (in this exp. m=800) and, for each class, transform them by taking the dot product of the points with the class covariance matrix.
  • For each class I generate these m points k times (in this exp. k=4).
  • The mean of all generated points is the same, in this case set to [0, 0].
  • In summary: each of the n classes has k samples of m points in d-dimensional space (see the generation sketch below).
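A minimal sketch of this generation process (assumptions: evenly spaced angles over half a turn, and "rotating the matrix" meaning C_i = R C Rᵀ):

```python
import numpy as np

def rotation(theta: float) -> np.ndarray:
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s], [s, c]])

rng = np.random.default_rng(42)
base_cov = np.array([[1.0, 0.5], [0.5, 1.0]])
n, k, m, d = 10, 4, 800, 2

classes = []
for i in range(n):
    # assumed rotation scheme: evenly spaced angles, C_i = R C R^T
    R = rotation(i * np.pi / n)
    cov_i = R @ base_cov @ R.T
    # k samples per class: m standard-normal points, transformed via a dot product
    samples = [rng.standard_normal((m, d)) @ cov_i for _ in range(k)]
    classes.append(np.stack(samples))

data = np.stack(classes)  # shape (n, k, m, d); every point cloud has mean ~[0, 0]
```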

Training process:

  • The training setup is similar to the feature aggregation method I am developing.
  • 6 of the n classes are used to train the model and the other 4 serve as a validation set. These 4 validation classes are the ones shown in the figure below at top left, top right, bottom left, and bottom middle.
  • The model is the same, although with far fewer parameters:
    `model = MLPsumV2(input_dim=2, latent_dim=64, output_dim=2, k=4, dropout=0, cell_layers=1, proj_layers=2, reduction='sum')`
  • The optimizer and loss function are identical. The model is trained for 20 epochs with a batch size of nr_train_samples//4 (in this exp. bs=6).

If the model can complete this task, we will have verified that this architecture is able to learn covariance matrices, and most likely the feature aggregation model I am proposing is also able to do (or is doing) this for the single-cell feature data.
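For reference, a minimal Deep Sets-style stand-in for that call (hypothetical class: the real MLPsumV2 has more options, e.g. dropout and k; this only mirrors the shapes and the sum reduction):

```python
import torch
import torch.nn as nn

class SumPoolMLP(nn.Module):
    """Per-point MLP -> sum pooling -> projection head (permutation invariant)."""
    def __init__(self, input_dim: int = 2, latent_dim: int = 64, output_dim: int = 2):
        super().__init__()
        self.cell_net = nn.Sequential(nn.Linear(input_dim, latent_dim), nn.ReLU())  # ~cell_layers=1
        self.proj = nn.Sequential(                                                  # ~proj_layers=2
            nn.Linear(latent_dim, latent_dim), nn.ReLU(),
            nn.Linear(latent_dim, output_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (batch, points, input_dim)
        pooled = self.cell_net(x).sum(dim=1)              # reduction='sum' over the point set
        return self.proj(pooled)

emb = SumPoolMLP()(torch.randn(6, 800, 2))  # -> (6, 2): one embedding per point set
```

The sum over per-point nonlinear features is what gives such a model access to statistics beyond the mean (the usual Deep Sets argument).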

Main takeaways

  • The model is able to accurately distinguish the 4 classes in the validation set from each other, meaning that it is able to learn (something related to) the covariance matrices.
  • Decreasing the number of points used per sample from 800 to 80 significantly decreases the model's performance, although it is still better than the baseline.

Results

Cool ellipses!

[figure: CovarianceClasses]

Model performance

Total model mAP: 0.9589285714285715
Total model precision at R: 0.9166666666666666

| sample | class | AP | precision at R |
| --- | --- | --- | --- |
| 0 | 0 | 1 | 1 |
| 1 | 0 | 1 | 1 |
| 2 | 0 | 1 | 1 |
| 3 | 0 | 1 | 1 |
| 7 | 4 | 1 | 1 |
| 8 | 5 | 1 | 1 |
| 9 | 5 | 1 | 1 |
| 10 | 5 | 1 | 1 |
| 11 | 5 | 1 | 1 |
| 12 | 7 | 1 | 1 |
| 13 | 7 | 1 | 1 |
| 14 | 7 | 1 | 1 |
| 15 | 7 | 1 | 1 |
| 4 | 4 | 0.866667 | 0.666667 |
| 5 | 4 | 0.833333 | 0.666667 |
| 6 | 4 | 0.642857 | 0.333333 |
Baseline (mean) performance

Total baseline (mean) mAP: 0.3086709586709586
Total baseline (mean) precision at R: 0.20833333333333331

| sample | class | AP | precision at R |
| --- | --- | --- | --- |
| 11 | 5 | 0.471429 | 0.333333 |
| 8 | 5 | 0.465368 | 0.333333 |
| 12 | 7 | 0.460317 | 0.333333 |
| 2 | 0 | 0.447619 | 0.333333 |
| 3 | 0 | 0.447619 | 0.333333 |
| 14 | 7 | 0.316667 | 0.333333 |
| 13 | 7 | 0.305556 | 0.333333 |
| 0 | 0 | 0.280952 | 0.333333 |
| 7 | 4 | 0.256614 | 0.333333 |
| 4 | 4 | 0.25 | 0 |
| 9 | 5 | 0.240909 | 0 |
| 15 | 7 | 0.233333 | 0.333333 |
| 1 | 0 | 0.206044 | 0 |
| 5 | 4 | 0.205026 | 0 |
| 10 | 5 | 0.189377 | 0 |
| 6 | 4 | 0.161905 | 0 |

@EchteRobert commented Apr 14, 2022

Model capacity experiment 2

Continuing the previous idea of learning higher-order statistics (i.e. the covariance matrix and beyond) with the current model setup, this experiment repeats the previous one on real data. More specifically, the Stain3 data is used as described in #6. I train on plates BR00115134_FS, BR00115125_FS, and BR00115133highexp_FS and validate on the rest.

Note: This experiment is not ideal, as I am using data that is normalized to zero mean and unit variance per feature at the plate level. I then additionally zero-mean the data at the well level, which means that taking the average profile is useless. I used this data because it is faster than preprocessing all of the data again.
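A hypothetical reconstruction of that preprocessing in pandas (the metadata column names are assumed):

```python
import pandas as pd

def normalize(cells: pd.DataFrame, feature_cols: list[str]) -> pd.DataFrame:
    out = cells.copy()
    # 1) plate level: zero mean, unit variance per feature
    plate = out.groupby("Metadata_Plate")[feature_cols]
    out[feature_cols] = (out[feature_cols] - plate.transform("mean")) / plate.transform("std")
    # 2) well level: subtract the per-well mean only (no scaling), so the
    #    average profile of every well becomes the zero vector
    well = out.groupby(["Metadata_Plate", "Metadata_Well"])[feature_cols]
    out[feature_cols] = out[feature_cols] - well.transform("mean")
    return out
```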

I am not too sure what this means for the standard deviation (or covariance matrix) on the well level. What do you think @shntnu ?

Results

Table Time!
| plate | Training mAP (model) | Training mAP (BM) | Validation mAP (model) | Validation mAP (BM) | PR (model) | PR (BM) |
| --- | --- | --- | --- | --- | --- | --- |
| **Training plates** | | | | | | |
| BR00115134 | 0.44 | 0.03 | 0.24 | 0.04 | 90 | 12.2 |
| BR00115125 | 0.37 | 0.03 | 0.25 | 0.04 | 94.4 | 6.7 |
| BR00115133highexp | 0.38 | 0.02 | 0.17 | 0.02 | 92.2 | 2.2 |
| **Validation plates** | | | | | | |
| BR00115128highexp | 0.33 | 0.03 | 0.25 | 0.04 | 81.1 | 11.1 |
| BR00115125highexp | 0.29 | 0.03 | 0.22 | 0.02 | 78.9 | 5.6 |
| BR00115131 | 0.32 | 0.03 | 0.22 | 0.03 | 85.6 | 11.1 |
| BR00115133 | 0.32 | 0.03 | 0.16 | 0.04 | 83.3 | 5.6 |
| BR00115127 | 0.29 | 0.03 | 0.22 | 0.05 | 84.4 | 3.3 |
| BR00115128 | 0.33 | 0.03 | 0.29 | 0.04 | 86.7 | 2.2 |
| BR00115129 | 0.3 | 0.03 | 0.26 | 0.04 | 82.2 | 5.6 |
| BR00115126 | 0.2 | 0.03 | 0.22 | 0.04 | 50 | 7.8 |

@shntnu commented Apr 20, 2022

> I am not too sure what this means for the standard deviation (or covariance matrix) on the well level.

My notes are below

> Note: This experiment is not ideal, as I am using data that is normalized to zero mean and unit variance per feature at the plate level.

^^^ This is fine

> I then additionally zero-mean the data at the well level, which means that taking the average profile is useless.

Perfect, because you've not scaled to unit variance (otherwise you'd be looking for structure in the correlation matrix instead of the covariance matrix)

> I used this data because it is faster than preprocessing all of the data again.

That worked out well!

I've not peeked into the results but I am eagerly looking forward to your talk tomorrow where you might discuss more 🎉

@shntnu commented Apr 20, 2022

Ok, I couldn't contain my excitement so I looked at the results :D

I just picked one at random, and focused only on

  • Validation mAP (model): 0.25
  • Validation mAP (BM): 0.04
| plate | Training mAP (model) | Training mAP (BM) | Validation mAP (model) | Validation mAP (BM) | PR (model) | PR (BM) |
| --- | --- | --- | --- | --- | --- | --- |
| BR00115128highexp (validation plate) | 0.33 | 0.03 | 0.25 | 0.04 | 81.1 | 11.1 |

This is fantastic, right?!

@EchteRobert

Yes, I think so too! I think this is proof that we can learn more than the mean. The mAP of the benchmark looks random (I believe it should be ~1/30, but the math around mAP is not as intuitive to me as precision at K :) ).
Perhaps we can now try fixing the covariance matrix as well and see if we can still learn?

@shntnu commented Apr 20, 2022

> Perhaps we can now try fixing the covariance matrix as well and see if we can still learn?

I didn't understand – you're already learning the covariance because you only mean-subtract wells, right?

@EchteRobert

Yes, you're right. I meant seeing if we can learn third-order interactions. Probably easiest if we discuss it tomorrow.

@shntnu commented Apr 20, 2022

> Yes, you're right. I meant seeing if we can learn third-order interactions. Probably easiest if we discuss it tomorrow.

Ah by fixing you mean factoring out – got it

For that, you'd spherize instead of mean subtracting

That will be totally shocking if it works!!

Even second order is pretty awesome (IMO, unless there's something trivial happening here https://broadinstitute.slack.com/archives/C025JFCBQAK/p1650466733918839?thread_ts=1649774854.681729&cid=C025JFCBQAK)

@shntnu commented Apr 20, 2022

PS – unless something super trivial is happening here that we haven't caught, I think you might be pretty close to having something you can write up. Let's get together with @johnarevalo for his advice, maybe next week

@EchteRobert

Worth giving it a shot ;) Sounds good!

@shntnu commented Apr 24, 2022

For the toy data, also standardize after rotating and see what happens. The idea is that we don't yet know whether it is learning the covariance or just the standard deviations.

@EchteRobert

Repeat of same experiment with standardized feature dimensions

Based on @shntnu's previous comment.
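Concretely, this means z-scoring each rotated sample per feature dimension, which removes the per-dimension standard deviations and leaves only the correlation structure (a sketch; the function name is mine):

```python
import numpy as np

def standardize(points: np.ndarray) -> np.ndarray:
    """Per-feature z-score of one rotated sample: kills the per-dimension
    standard deviation, so only correlation structure remains."""
    return (points - points.mean(axis=0)) / points.std(axis=0)
```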

Main takeaways

Total model mAP: 0.93
Total model precision at R: 0.92

Total baseline (mean) mAP: 0.33
Total baseline (mean) precision at R: 0.21

The model is still beating the baseline (random) performance.

Ellipsoid classes after standardizing points

[figure: CovarianceClassesStandardized]

mean Average Precision scores

Total model mAP: 0.9282512626262627
Total model precision at R: 0.9166666666666666

| sample | compound | AP | precision at R |
| --- | --- | --- | --- |
| 0 | 0 | 1 | 1 |
| 1 | 0 | 1 | 1 |
| 2 | 0 | 1 | 1 |
| 3 | 0 | 1 | 1 |
| 4 | 4 | 1 | 1 |
| 5 | 4 | 1 | 1 |
| 6 | 4 | 1 | 1 |
| 7 | 4 | 1 | 1 |
| 12 | 7 | 1 | 1 |
| 13 | 7 | 1 | 1 |
| 15 | 7 | 1 | 1 |
| 9 | 5 | 0.916667 | 1 |
| 10 | 5 | 0.866667 | 0.666667 |
| 11 | 5 | 0.866667 | 1 |
| 14 | 7 | 0.757576 | 0.666667 |
| 8 | 5 | 0.444444 | 0.333333 |

Total baseline (mean) mAP: 0.32838989713989714
Total baseline (mean) precision at R: 0.20833333333333331

| sample | compound | AP | precision at R |
| --- | --- | --- | --- |
| 8 | 5 | 0.638889 | 0.666667 |
| 14 | 7 | 0.54359 | 0.333333 |
| 1 | 0 | 0.474074 | 0.333333 |
| 12 | 7 | 0.4 | 0.333333 |
| 13 | 7 | 0.383333 | 0.333333 |
| 9 | 5 | 0.361111 | 0.333333 |
| 11 | 5 | 0.354701 | 0.333333 |
| 4 | 4 | 0.347222 | 0.333333 |
| 6 | 4 | 0.289683 | 0.333333 |
| 10 | 5 | 0.277778 | 0 |
| 0 | 0 | 0.244444 | 0 |
| 7 | 4 | 0.214286 | 0 |
| 3 | 0 | 0.197619 | 0 |
| 2 | 0 | 0.188889 | 0 |
| 5 | 4 | 0.185606 | 0 |
| 15 | 7 | 0.153014 | 0 |

@shntnu commented Apr 25, 2022

> The model is still beating the baseline (random) performance.

Great!

And I verified, as a sanity check, that the baseline hasn't changed (much) from before: #3 (comment)

The model results don't change much either (correlation vs covariance; details below)

BTW, you show 10 ellipses but you have 16 rows in your results. Why is that?

Correlation

From the most recent results #3 (comment)

Total model mAP: 0.9282512626262627
Total model precision at R: 0.9166666666666666

Covariance

From the results 12 days ago #3 (comment)

Total model mAP: 0.9589285714285715
Total model precision at R: 0.9166666666666666

@EchteRobert commented Apr 25, 2022

Great, thanks for checking!

There are 10 classes, but 4 samples (of 800 points each) per class. The validation set consists of 4 classes, so 16 samples in total. I report all samples individually here; normally I take the mean per class (compound).

Interesting to note (perhaps for my future self): I had to reduce the learning rate by a factor of 100 (lr = 1e-5) for the model to learn the correlation adequately, compared to learning the covariance (lr = 1e-3).

@EchteRobert commented Apr 26, 2022

Experiment 3 - Sphering the toy data

In this experiment I sample 800*4 points using each covariance matrix class for the distribution, then sphere this sample and subsequently subsample it to create pairs for training. I increased the number of epochs from 40 to 1000, as the model is otherwise not able to fit the data.
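For reference, a sketch of regularized (ZCA-style) sphering in which a smaller regularizer whitens more aggressively; the exact formulation used in my code may differ:

```python
import numpy as np

def sphere(points: np.ndarray, reg: float) -> np.ndarray:
    """Regularized ZCA sphering; smaller `reg` means heavier whitening."""
    centered = points - points.mean(axis=0)
    cov = np.cov(centered, rowvar=False)
    eigvals, eigvecs = np.linalg.eigh(cov)
    whitener = eigvecs @ np.diag(1.0 / np.sqrt(eigvals + reg)) @ eigvecs.T
    return centered @ whitener

rng = np.random.default_rng(0)
sample = rng.standard_normal((800 * 4, 2)) @ np.array([[1.0, 0.5], [0.5, 1.0]])
heavy = sphere(sample, reg=0.01)  # covariance ~ identity: class signal mostly erased
low = sphere(sample, reg=0.3)     # residual second-order structure survives
```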

Main takeaways

  • After heavy sphering of the data the model is no longer able to learn how to discern the different classes.
  • I tried a few hyperparameters, like varying model width, learning rate, and number of epochs, but with no (consistent) success.
  • After medium or low sphering the model is still able to beat the baseline, although it requires many more training steps with a smaller learning rate.

Regularization 0.01 - heavy sphering

[figure: Spherize0_01]

Total model mAP: 0.2943837412587413
Total model precision at R: 0.125

Total baseline (mean) mAP: 0.25055043336293337
Total baseline (mean) precision at R: 0.125

full tables

Model

| sample | compound | AP | precision at R |
| --- | --- | --- | --- |
| 10 | 5 | 0.535354 | 0.333333 |
| 11 | 5 | 0.535354 | 0.333333 |
| 13 | 7 | 0.516667 | 0.333333 |
| 2 | 0 | 0.293651 | 0.333333 |
| 3 | 0 | 0.289683 | 0 |
| 4 | 4 | 0.288889 | 0.333333 |
| 8 | 5 | 0.268687 | 0 |
| 12 | 7 | 0.266667 | 0 |
| 9 | 5 | 0.25 | 0 |
| 0 | 0 | 0.238095 | 0 |
| 1 | 0 | 0.233333 | 0 |
| 5 | 4 | 0.22906 | 0.333333 |
| 14 | 7 | 0.227273 | 0 |
| 15 | 7 | 0.227273 | 0 |
| 6 | 4 | 0.165568 | 0 |
| 7 | 4 | 0.144589 | 0 |

Baseline

| sample | compound | AP | precision at R |
| --- | --- | --- | --- |
| 2 | 0 | 0.455556 | 0.333333 |
| 3 | 0 | 0.451282 | 0.333333 |
| 5 | 4 | 0.273016 | 0.333333 |
| 11 | 5 | 0.255495 | 0 |
| 8 | 5 | 0.254701 | 0.333333 |
| 13 | 7 | 0.251852 | 0.333333 |
| 4 | 4 | 0.240741 | 0 |
| 6 | 4 | 0.240741 | 0 |
| 9 | 5 | 0.229798 | 0 |
| 0 | 0 | 0.225397 | 0.333333 |
| 1 | 0 | 0.215812 | 0 |
| 14 | 7 | 0.212169 | 0 |
| 15 | 7 | 0.199074 | 0 |
| 12 | 7 | 0.177778 | 0 |
| 10 | 5 | 0.169841 | 0 |
| 7 | 4 | 0.155556 | 0 |
Regularization 0.1 - medium sphering

[figure: Spherize0_1]

Total model mAP: 0.6088789682539683
Total model precision at R: 0.5

Total baseline (mean) mAP: 0.2500837125837126
Total baseline (mean) precision at R: 0.125

full tables

Model

| sample | compound | AP | precision at R |
| --- | --- | --- | --- |
| 0 | 0 | 1 | 1 |
| 2 | 0 | 1 | 1 |
| 3 | 0 | 1 | 1 |
| 12 | 7 | 0.791667 | 0.666667 |
| 9 | 5 | 0.722222 | 0.666667 |
| 1 | 0 | 0.638889 | 0.666667 |
| 11 | 5 | 0.638889 | 0.666667 |
| 5 | 4 | 0.591667 | 0.333333 |
| 8 | 5 | 0.588889 | 0.666667 |
| 15 | 7 | 0.569444 | 0.333333 |
| 10 | 5 | 0.555556 | 0.666667 |
| 14 | 7 | 0.533333 | 0.333333 |
| 6 | 4 | 0.341667 | 0 |
| 4 | 4 | 0.319444 | 0 |
| 13 | 7 | 0.302778 | 0 |
| 7 | 4 | 0.147619 | 0 |

Baseline

| sample | compound | AP | precision at R |
| --- | --- | --- | --- |
| 2 | 0 | 0.455556 | 0.333333 |
| 3 | 0 | 0.455556 | 0.333333 |
| 5 | 4 | 0.273016 | 0.333333 |
| 8 | 5 | 0.254701 | 0.333333 |
| 13 | 7 | 0.251852 | 0.333333 |
| 11 | 5 | 0.24359 | 0 |
| 6 | 4 | 0.240741 | 0 |
| 9 | 5 | 0.229798 | 0 |
| 14 | 7 | 0.228836 | 0 |
| 0 | 0 | 0.225397 | 0.333333 |
| 1 | 0 | 0.220862 | 0 |
| 4 | 4 | 0.220539 | 0 |
| 15 | 7 | 0.205026 | 0 |
| 12 | 7 | 0.177778 | 0 |
| 10 | 5 | 0.165568 | 0 |
| 7 | 4 | 0.152525 | 0 |
Regularization 0.3 - low sphering

[figure: Spherize0_03]

Total model mAP: 0.722172619047619
Total model precision at R: 0.625

Total baseline (mean) mAP: 0.25495106745106744
Total baseline (mean) precision at R: 0.10416666666666666

full tables

Model

| sample | compound | AP | precision at R |
| --- | --- | --- | --- |
| 0 | 0 | 1 | 1 |
| 1 | 0 | 1 | 1 |
| 2 | 0 | 1 | 1 |
| 3 | 0 | 1 | 1 |
| 11 | 5 | 0.916667 | 0.666667 |
| 13 | 7 | 0.916667 | 0.666667 |
| 14 | 7 | 0.916667 | 0.666667 |
| 12 | 7 | 0.755556 | 0.666667 |
| 5 | 4 | 0.7 | 0.333333 |
| 9 | 5 | 0.626984 | 0.666667 |
| 8 | 5 | 0.622222 | 0.666667 |
| 6 | 4 | 0.588889 | 0.666667 |
| 7 | 4 | 0.5 | 0.333333 |
| 4 | 4 | 0.47619 | 0.333333 |
| 15 | 7 | 0.31746 | 0.333333 |
| 10 | 5 | 0.21746 | 0 |

Baseline

| sample | compound | AP | precision at R |
| --- | --- | --- | --- |
| 2 | 0 | 0.455556 | 0.333333 |
| 3 | 0 | 0.455556 | 0.333333 |
| 5 | 4 | 0.333333 | 0.333333 |
| 13 | 7 | 0.288889 | 0.333333 |
| 11 | 5 | 0.24359 | 0 |
| 6 | 4 | 0.240741 | 0 |
| 1 | 0 | 0.233333 | 0 |
| 9 | 5 | 0.229798 | 0 |
| 0 | 0 | 0.22906 | 0.333333 |
| 8 | 5 | 0.226923 | 0 |
| 4 | 4 | 0.212602 | 0 |
| 14 | 7 | 0.210606 | 0 |
| 15 | 7 | 0.205026 | 0 |
| 12 | 7 | 0.184615 | 0 |
| 10 | 5 | 0.17033 | 0 |
| 7 | 4 | 0.159259 | 0 |

@shntnu commented Apr 26, 2022

> After heavy sphering of the data the model is no longer able to learn how to discern the different classes.

Ah, this is expected (and thus, good!) because your data is almost surely fully explained by its second-order moments (that's how you generated it), and sphering factors that out.

The story is different with your real data – there, it will almost surely not be fully explained by second-order moments (although that doesn't mean you can do better)

> After medium or low sphering the model is still able to beat the baseline, although it requires many more training steps with a smaller learning rate.

Perfect! As expected, and it's great that you quantified how much harder the learning becomes.

Note that medium / low sphering should be roughly equivalent to a medium / low value for the major/minor axes.
