MoleculeACE linear probe evaluation callback #31
Conversation
```python
from lobster.datasets import MoleculeACEDataset


class LinearProbeCallback(Callback):
```
🔥
@karinazad wdyt about separating the base class LinearProbeCallback (everything but on_validation_epoch_end) into its own file? Then have MoleculeAceLinearProbe or any others inherit
good idea, I can refactor
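A minimal sketch of the proposed split: a base `LinearProbeCallback` that owns the shared probe-fitting logic, and a dataset-specific subclass that implements only `on_validation_epoch_end`. Everything here besides the name `LinearProbeCallback` (the subclass name, method names, and toy logic) is an assumption for illustration, not the repo's actual code.

```python
class LinearProbeCallback:
    """Base class: everything except on_validation_epoch_end lives here."""

    def fit_probe(self, embeddings, targets):
        # Shared probe-fitting logic; the real version would fit an
        # sklearn estimator on (embeddings, targets).
        self.probe = ("fitted", len(embeddings))
        return self.probe

    def on_validation_epoch_end(self, trainer=None, module=None):
        # Subclasses decide how to gather embeddings for their dataset.
        raise NotImplementedError


class MoleculeACELinearProbe(LinearProbeCallback):
    """Dataset-specific subclass: only the evaluation hook."""

    def on_validation_epoch_end(self, trainer=None, module=None):
        # Gather MoleculeACE embeddings/targets, then reuse the base logic.
        return self.fit_probe([0.1, 0.2], [0, 1])
```

The payoff of the split is that a new benchmark only needs to override the hook, not re-implement probe fitting.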
```python
embeddings = embeddings.numpy()
targets = targets.numpy()

probe = LinearRegression()
```
can you add a TODO to extend this to logistic regression as well?
added logistic regression
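A sketch of what "added logistic regression" could look like: picking the probe by task type, using scikit-learn estimators. The `make_probe` helper and the task-type strings are assumptions for illustration, not the callback's actual API.

```python
import numpy as np
from sklearn.linear_model import LinearRegression, LogisticRegression


def make_probe(task_type: str):
    # Hypothetical helper: regression tasks get a linear probe,
    # classification tasks a logistic one.
    if task_type == "regression":
        return LinearRegression()
    return LogisticRegression(max_iter=1000)


# Tiny separable toy data: embeddings below 1.5 are class 0, above are class 1.
embeddings = np.array([[0.0], [1.0], [2.0], [3.0]])
labels = np.array([0, 0, 1, 1])

probe = make_probe("classification")
probe.fit(embeddings, labels)
preds = probe.predict(embeddings)
```

Both estimators expose the same `fit`/`predict` surface, so the rest of the callback does not need to branch on task type after construction.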
```python
x, y = batch

# TODO: Handle multiple modalities in ModernBERT
batch_embeddings = module.sequences_to_latents(x)
```
can you add a TODO to use the pooling API as well to handle the embeddings?
added mean pooling for now
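Mean pooling over non-padding tokens can be sketched like this (a NumPy stand-in mirroring the usual torch logic; the function name and the mask convention of 1 = real token, 0 = padding are assumptions):

```python
import numpy as np


def masked_mean_pool(hidden_states, attention_mask):
    """Average token embeddings, ignoring padding positions.

    hidden_states: (batch, seq_len, hidden) token embeddings
    attention_mask: (batch, seq_len), 1 for real tokens, 0 for padding
    """
    mask = attention_mask[..., None].astype(hidden_states.dtype)
    summed = (hidden_states * mask).sum(axis=1)
    # Clip the token counts to avoid division by zero on all-padding rows.
    counts = np.clip(mask.sum(axis=1), 1e-9, None)
    return summed / counts


# One sequence of three tokens where the last token is padding.
hidden = np.array([[[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]]])
mask = np.array([[1, 1, 0]])
pooled = masked_mean_pool(hidden, mask)  # averages only the first two tokens
```

Dividing by the masked token count, rather than `seq_len`, is what keeps padded positions from diluting the sequence embedding.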
It's looking great! Just confirming, does this have the most recent commit? (separating LinearProbe, logreg, mean pooling) @karinazad
```diff
@@ -136,6 +136,30 @@ def configure_optimizers(self):
         return {"optimizer": optimizer, "lr_scheduler": scheduler}

     def tokens_to_latents(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
```
this will be a commonly used method - can you add a short docstring with an amino acid and a SMILES example?
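One possible shape for the requested docstring (the tokenizer call, variable names, and example sequences are illustrative assumptions; the amino acid string and the aspirin SMILES are just sample inputs):

```python
def tokens_to_latents(self, input_ids, attention_mask):
    """Convert tokenized sequences into latent embeddings.

    Examples (illustrative; the tokenizer interface is an assumption):

        >>> # amino acid sequence
        >>> ids, mask = tokenizer("MKTAYIAKQR")
        >>> latents = model.tokens_to_latents(ids, mask)

        >>> # SMILES string (aspirin)
        >>> ids, mask = tokenizer("CC(=O)Oc1ccccc1C(=O)O")
        >>> latents = model.tokens_to_latents(ids, mask)
    """
    ...
```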
) | ||
|
||
return hidden_states | ||
|
||
def sequences_to_latents(self, sequences: list[str]) -> list[torch.Tensor]: |
can this be refactored to use tokens_to_latents?
ah I meant to do that but forgot
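The suggested refactor can be sketched as `sequences_to_latents` tokenizing and then delegating to `tokens_to_latents`, so the forward-pass logic lives in one place. The `Model` class and tokenizer interface here are stand-ins, not the repo's actual implementation:

```python
class Model:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def tokens_to_latents(self, input_ids, attention_mask):
        # Stand-in for the real forward pass: keep only unmasked tokens.
        return [i for i, m in zip(input_ids, attention_mask) if m]

    def sequences_to_latents(self, sequences):
        # Tokenize, then reuse tokens_to_latents instead of duplicating
        # the forward-pass logic.
        latents = []
        for seq in sequences:
            input_ids, attention_mask = self.tokenizer(seq)
            latents.append(self.tokens_to_latents(input_ids, attention_mask))
        return latents


def toy_tokenizer(seq):
    # Hypothetical tokenizer: one id per character, no padding.
    ids = [ord(c) for c in seq]
    return ids, [1] * len(ids)


model = Model(toy_tokenizer)
out = model.sequences_to_latents(["AB"])
```

With this shape, any fix to the forward pass (pooling, device placement, modality handling) automatically applies to both entry points.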
No description provided.