add maximum mean discrepancy metric #56
base: main
Conversation
@@ -0,0 +1,89 @@
import numpy
probably better to use torch instead of numpy for consistency unless numpy is needed for a particular reason
would be nice if this worked with arrays of strings as well since that's a common data structure. We allowed numpy arrays here for the same reason: https://github.com/Genentech/beignet/blob/main/src/beignet/_farthest_first_traversal.py
that being said I could see wanting this both differentiable and GPU-enabled... any thoughts on that? implement two versions?
casting strings to tensors opens the tokenization and padding can of worms, and usually you just want something simple
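For context, here is a minimal NumPy sketch of the kind of estimator under discussion (hypothetical names, not the PR's actual implementation). A pluggable `distance_fn` is what lets arrays of strings work: the kernel only ever sees pairwise distances, so no tokenization or tensor conversion is needed:

```python
import numpy as np

def mmd_sketch(x, y, distance_fn=None, kernel_width=1.0):
    """Biased MMD estimate between samples x and y (illustrative sketch).

    distance_fn(a, b) must return a (len(a), len(b)) matrix of pairwise
    distances; it defaults to Euclidean, but any metric works, which is
    what makes object/string arrays usable without casting to tensors.
    """
    if distance_fn is None:
        def distance_fn(a, b):
            a = np.asarray(a, dtype=float)
            b = np.asarray(b, dtype=float)
            if a.ndim == 1:
                a, b = a[:, None], b[:, None]
            # pairwise Euclidean distances via broadcasting
            return np.sqrt(np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1))

    def rbf(a, b):
        # RBF kernel applied to the pairwise distance matrix
        d = distance_fn(a, b)
        return np.exp(-(d ** 2) / (2.0 * kernel_width ** 2))

    mmd2 = rbf(x, x).mean() + rbf(y, y).mean() - 2.0 * rbf(x, y).mean()
    return float(np.sqrt(np.maximum(mmd2, 0.0)))
```

Decoupling the metric from the kernel this way sidesteps the numpy-vs-torch question for non-numeric data, at the cost of the differentiable/GPU path raised above.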
Y,
distance_fn=None,
kernel_width: float | None = None,
eps: float = 1e-16,
What's the motivation for lower bounding the MMD at sqrt(eps), rather than 0?
I suppose I was trying to forestall division-by-zero errors, but on reflection that's a decision downstream users should make; I'll revert the lower bound to 0.
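To illustrate the two clamping choices in the exchange above (hypothetical helper, not the PR's code): clamping the squared MMD at a small `eps` before the square root keeps later divisions by the result safe, but silently floors the metric at `sqrt(eps)`; clamping at 0 only guards against small negative values from floating-point noise:

```python
import numpy as np

def mmd_from_mmd2(mmd2, lower_bound=0.0):
    # mmd2 can come out slightly negative from floating-point noise,
    # so some lower bound before sqrt is needed.  lower_bound=eps floors
    # the result at sqrt(eps); lower_bound=0.0 is the behavior the
    # review settles on, leaving division-by-zero handling downstream.
    return np.sqrt(np.maximum(mmd2, lower_bound))
```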
ok, I have attempted to make this compliant with the Python array API standard and NEP 56, to support both NumPy and PyTorch arrays
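A sketch of the dispatch pattern the array API standard enables (hypothetical helper; under NEP 56, NumPy >= 2.0 arrays expose `__array_namespace__`, and in practice PyTorch tensors are usually routed through the `array-api-compat` package rather than this raw protocol):

```python
def mean_pairwise_sq_dist(x):
    # Fetch the array namespace from the array itself instead of
    # importing numpy or torch directly; the same code then runs on
    # any array type implementing the array API standard.
    xp = x.__array_namespace__()
    d = x[:, None, :] - x[None, :, :]
    return xp.mean(xp.sum(d * d, axis=-1))
```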
@@ -9,6 +9,7 @@ requires = [
[project]
authors = [{ email = "[email protected]", name = "Allen Goodman" }]
dependencies = [
"numpy>=2.0.0",
Let torch manage the numpy dependency
I think torch doesn't depend on numpy now
yeah, that's what I was seeing
@kleinhenz finally fixed the broken test, are we good to merge?
I do think there is a valid question around whether we want to force beignet users to upgrade to NumPy 2.0 (to support the array API standard). @0x00b1 thoughts?
estimates maximum mean discrepancy from samples
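For reference, the quantity the PR description refers to is typically computed via the sample estimator below (standard biased form for a kernel $k$, given samples $x_1,\dots,x_m$ and $y_1,\dots,y_n$; not taken from the PR diff):

$$\widehat{\mathrm{MMD}}^2(X, Y) = \frac{1}{m^2}\sum_{i,j=1}^{m} k(x_i, x_j) + \frac{1}{n^2}\sum_{i,j=1}^{n} k(y_i, y_j) - \frac{2}{mn}\sum_{i=1}^{m}\sum_{j=1}^{n} k(x_i, y_j)$$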