
add maximum mean discrepancy metric #56


Open · samuelstanton wants to merge 9 commits into main

Conversation

samuelstanton (Contributor)

Estimates maximum mean discrepancy (MMD) from samples.
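For reference, here is a minimal sketch of the standard unbiased squared-MMD estimator with a Gaussian kernel and a median-heuristic bandwidth (Gretton et al., 2012). This illustrates the quantity being estimated; the function name, the torch backend, and the bandwidth heuristic are assumptions, not the PR's actual code:

```python
import torch

def mmd_unbiased(x, y, kernel_width=None):
    # unbiased MMD^2 estimate between samples x (m, d) and y (n, d)
    m, n = x.shape[0], y.shape[0]
    xx = torch.cdist(x, x) ** 2  # pairwise squared distances within x
    yy = torch.cdist(y, y) ** 2  # pairwise squared distances within y
    xy = torch.cdist(x, y) ** 2  # pairwise squared distances across samples
    if kernel_width is None:
        # median heuristic: bandwidth from the median pairwise distance
        all_d2 = torch.cat([xx.flatten(), yy.flatten(), xy.flatten()])
        kernel_width = all_d2.median().sqrt()
    k_xx = torch.exp(-xx / (2 * kernel_width**2))
    k_yy = torch.exp(-yy / (2 * kernel_width**2))
    k_xy = torch.exp(-xy / (2 * kernel_width**2))
    # drop the diagonal (i == j) terms for the unbiased within-sample means
    return (
        (k_xx.sum() - k_xx.diagonal().sum()) / (m * (m - 1))
        + (k_yy.sum() - k_yy.diagonal().sum()) / (n * (n - 1))
        - 2 * k_xy.mean()
    )
```

Note the unbiased form can come out slightly negative, which is what the eps discussion further down is about; a biased V-statistic variant would keep the diagonal terms and stay nonnegative.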

@@ -0,0 +1,89 @@
import numpy
Collaborator
probably better to use torch instead of numpy for consistency, unless numpy is needed for a particular reason

samuelstanton (Contributor, Author)
would be nice if this worked with arrays of strings as well, since that's a common data structure. We allowed numpy arrays here for the same reason: https://github.com/Genentech/beignet/blob/main/src/beignet/_farthest_first_traversal.py
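To illustrate the string use case, a hypothetical call with a custom distance over numpy string arrays; the maximum_mean_discrepancy name and its distance_fn parameter are assumed from the signature fragment shown in the diff below, so the call is left commented out:

```python
import numpy

def hamming(a: str, b: str) -> float:
    # per-character mismatch count for equal-length strings
    return sum(c1 != c2 for c1, c2 in zip(a, b))

X = numpy.array(["ACGT", "ACGG", "TTGA"])
Y = numpy.array(["TCGT", "ACTT"])

# hypothetical call, assuming the distance_fn parameter from the diff below:
# mmd = maximum_mean_discrepancy(X, Y, distance_fn=hamming)
```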

samuelstanton (Contributor, Author)

that being said, I could see wanting this to be both differentiable and GPU-enabled... any thoughts on that? implement two versions?
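For what it's worth, a torch-backed implementation like the sketch above is already differentiable and GPU-capable without a second version; a quick check, assuming the hypothetical mmd_unbiased sketch from earlier:

```python
import torch

x = torch.randn(128, 8, requires_grad=True)
y = torch.randn(256, 8)

mmd2 = mmd_unbiased(x, y)  # sketch defined above
mmd2.backward()            # gradients flow back to x through cdist/exp
assert x.grad is not None

# the same code runs on GPU just by moving the inputs:
# mmd_unbiased(x.to("cuda"), y.to("cuda"))
```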

samuelstanton (Contributor, Author)
casting to tensor opens the tokenization-and-padding can of worms, and usually you just want something simple

    Y,
    distance_fn=None,
    kernel_width: float | None = None,
    eps: float = 1e-16,

Collaborator

What's the motivation for lower bounding the MMD at sqrt(eps), rather than 0?

samuelstanton (Contributor, Author)

I suppose I was trying to forestall division-by-zero errors, but upon reflection that's a decision downstream users should make; will revert the lower bound to 0
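Concretely, the change under discussion amounts to something like the following (variable names assumed, not the PR's exact code):

```python
# before: clamping the squared estimate at eps floors the returned
# distance at sqrt(eps), guarding later divisions by zero
# mmd = torch.sqrt(torch.clamp(mmd_squared, min=eps))

# after: the unbiased MMD^2 estimate can be slightly negative, so clamp
# at exactly 0 and let downstream users guard their own divisions
mmd = torch.sqrt(torch.clamp(mmd_squared, min=0.0))
```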

samuelstanton (Contributor, Author)
ok, I have attempted to make this compliant with the Python array API standard and NEP 56, to support both numpy and pytorch arrays
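For anyone unfamiliar with the pattern, array-API-compliant code dispatches on the inputs' namespace instead of importing numpy or torch directly. A minimal sketch using the array-api-compat helper; the helper function and its name are illustrative, and whether the PR uses array-api-compat or raw `__array_namespace__` is an assumption:

```python
from array_api_compat import array_namespace

def pairwise_squared_distances(x, y):
    # resolve a numpy- or torch-flavored namespace from the inputs;
    # numpy >= 2.0 ndarrays (per NEP 56) and torch tensors both work
    xp = array_namespace(x, y)
    diff = x[:, None, :] - y[None, :, :]  # (m, n, d) broadcast differences
    return xp.sum(diff**2, axis=-1)       # (m, n) squared Euclidean distances
```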

@@ -9,6 +9,7 @@ requires = [
[project]
authors = [{ email = "[email protected]", name = "Allen Goodman" }]
dependencies = [
"numpy>=2.0.0",
Collaborator

Let torch manage the numpy dependency

Collaborator

I think torch doesn't depend on numpy now

samuelstanton (Contributor, Author)

yeah, that's what I was seeing

samuelstanton (Contributor, Author)
@kleinhenz finally fixed the broken test; are we good to merge?

samuelstanton (Contributor, Author) commented Apr 3, 2025

I do think there is a valid question around whether we want to force beignet users to upgrade to NumPy 2.0 (to support the array API standard). @0x00b1 thoughts?
