MoleculeACE linear probe evaluation callback #31

Merged: 2 commits into main, Feb 19, 2025

Conversation

karinazad (Collaborator)

No description provided.

from lobster.datasets import MoleculeACEDataset


class LinearProbeCallback(Callback):
Collaborator: 🔥

Collaborator: @karinazad what do you think about separating the base class LinearProbeCallback (everything but on_validation_epoch_end) into its own file? Then MoleculeAceLinearProbe or any others could inherit from it.

Collaborator (Author): Good idea, I can refactor.
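For context, a minimal sketch of the split being discussed: a generic probe callback in one file, with a MoleculeACE-specific subclass supplying the validation-time evaluation. Names other than LinearProbeCallback, and the _fit_probe helper, are illustrative rather than the PR's exact API.

```python
import torch
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback
from sklearn.linear_model import LinearRegression


class LinearProbeCallback(Callback):
    """Generic probe callback: fits a probe on frozen embeddings.

    Subclasses decide which data to embed and how to evaluate the probe.
    """

    def _fit_probe(self, embeddings: torch.Tensor, targets: torch.Tensor) -> LinearRegression:
        # The backbone stays frozen; only the scikit-learn probe is trained.
        probe = LinearRegression()
        probe.fit(embeddings.numpy(), targets.numpy())
        return probe


class MoleculeACELinearProbeCallback(LinearProbeCallback):
    def on_validation_epoch_end(self, trainer: Trainer, module: LightningModule) -> None:
        # Gather MoleculeACE embeddings/targets, call self._fit_probe, and log metrics here.
        ...
```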

embeddings = embeddings.numpy()
targets = targets.numpy()

probe = LinearRegression()
Collaborator: Can you add a TODO to extend this to logistic regression as well?

Collaborator (Author): Added logistic regression.
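A sketch of how both probe types could be supported, assuming a simple task_type switch (the function name and flag are illustrative, not the PR's exact interface):

```python
import numpy as np
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.metrics import accuracy_score, r2_score


def fit_probe(embeddings: np.ndarray, targets: np.ndarray, task_type: str = "regression"):
    """Fit a linear (regression) or logistic (classification) probe on frozen embeddings."""
    if task_type == "regression":
        probe = LinearRegression().fit(embeddings, targets)
        score = r2_score(targets, probe.predict(embeddings))
    else:
        probe = LogisticRegression(max_iter=1000).fit(embeddings, targets)
        score = accuracy_score(targets, probe.predict(embeddings))
    return probe, score
```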

x, y = batch

# TODO: Handle multiple modalities in ModernBERT
batch_embeddings = module.sequences_to_latents(x)
Collaborator: Can you add a TODO to use the pooling API as well to handle the embeddings?

Collaborator (Author): Added mean pooling for now.
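A minimal sketch of mask-aware mean pooling over per-token latents, assuming the usual (batch, seq_len, hidden) encoder output; this is a standard pattern, not necessarily the exact code merged in the PR:

```python
import torch


def mean_pool(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average token embeddings per sequence, ignoring padding positions."""
    mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)  # (batch, seq_len, 1)
    summed = (hidden_states * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)  # avoid division by zero for empty masks
    return summed / counts  # (batch, hidden)
```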

@taylormjs (Collaborator): It's looking great! Just confirming, does this have the most recent commit (separating LinearProbe, logistic regression, mean pooling)? @karinazad

karinazad changed the title from "Draft: MoleculeACE linear probe evaluation callback" to "MoleculeACE linear probe evaluation callback" on Feb 19, 2025.
karinazad merged commit 4fc1ac5 into main on Feb 19, 2025; 5 checks passed.
karinazad deleted the ume-eval-linear-probe-moleculeace branch on February 19, 2025 at 13:38.
@@ -136,6 +136,30 @@ def configure_optimizers(self):

return {"optimizer": optimizer, "lr_scheduler": scheduler}

def tokens_to_latents(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
Collaborator: This will be a commonly used method. Can you add a short docstring with an amino acid and SMILES example?
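One way the requested docstring could look; the tokenizer call, its keyword arguments, and the output shape are assumptions for illustration, not confirmed by the diff shown here:

```python
import torch


def tokens_to_latents(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Encode tokenized sequences into per-token latent representations.

    Example (amino acid sequence and SMILES string; tokenizer usage is illustrative):
        >>> batch = tokenizer(
        ...     ["MKTVRQERLKSIVRILERSKEPVSGAQ", "CC(=O)OC1=CC=CC=C1C(=O)O"],
        ...     return_tensors="pt", padding=True,
        ... )
        >>> latents = model.tokens_to_latents(batch["input_ids"], batch["attention_mask"])
        >>> latents.shape  # (batch_size, sequence_length, hidden_dim)
    """
    ...
```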

)

return hidden_states

def sequences_to_latents(self, sequences: list[str]) -> list[torch.Tensor]:
Collaborator: Can this be refactored to use tokens_to_latents?

Collaborator (Author): Ah, I meant to do that but forgot.
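A sketch of the suggested refactor: tokenize once, then delegate to tokens_to_latents. The self.tokenizer attribute, its keyword arguments, and the per-sequence split of the output are assumptions for illustration.

```python
import torch


def sequences_to_latents(self, sequences: list[str]) -> list[torch.Tensor]:
    # Tokenize the raw strings, then reuse tokens_to_latents for the forward pass.
    batch = self.tokenizer(sequences, return_tensors="pt", padding=True, truncation=True)
    hidden_states = self.tokens_to_latents(batch["input_ids"], batch["attention_mask"])
    # Return one (seq_len, hidden) tensor per input sequence.
    return list(hidden_states)
```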
