-
-
Notifications
You must be signed in to change notification settings - Fork 635
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Introduce a variable skip_unrolling in class Metric #3258
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR @simeetnayan81
Left a comment about the implementation
ignite/metrics/metric.py
Outdated
@@ -300,7 +300,7 @@ def compute(self): | |||
_required_output_keys = required_output_keys | |||
|
|||
def __init__( | |||
self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu") | |||
self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu"), skip_unrolling : bool =False |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's also update above docstring by adding the new argument. Please also check CONTRIBUTING guideline the part about how to add .. versionadded::
tag in the bottom of the docstring. Version to put should be 0.5.1
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should I add .. versionchanged::
instead?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, correct, it should be versionchanged tag, https://github.com/pytorch/ignite/blob/master/CONTRIBUTING.md#writing-documentation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On it. Thanks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Have made the changes. Kindly review.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the update! I added few minor comments. Let's add new feature tests and run the CI to see if any failures
Co-authored-by: vfdev <[email protected]>
Co-authored-by: vfdev <[email protected]>
Tests should be added to end of the test_metric.py file? |
Yes, you can add it in the end of the file |
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good @simeetnayan81 , thanks, let's just add an example of usage of the flag in the docstring and it will be good to go, once CI is green (except unrelated failures)
@vfdev-5 Before adding the example in the docstring, I wanted to confirm, to make skip_unrolling effective for the loss function, we might also need to change this.
Change to:
|
@simeetnayan81 yes, you are right, we need to add this new arg to all metrics defining a constructor. Let's update Loss metric here and update other metrics in a follow-up PR. |
… skip_unrolling arg
Things to do in a follow-up PR.
|
Thanks for the updates and the TODO. Can we do this point here ?
|
Alright @vfdev-5 |
Have made the changes, the new test works locally. |
The test is failing because |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @simeetnayan81 , lgtm
Fixes #2940
Description:
Introduce a variable skip_unrolling in class Metric as discussed here https://discord.com/channels/831462531327328276/1110662056622964860/1253769540710567977
Check list: