-
Notifications
You must be signed in to change notification settings - Fork 43
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix Strategy class to ensure consistent tensor operations for data normalization #403
Fix Strategy class to ensure consistent tensor operations for data normalization #403
Conversation
@JasonKChow has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I know a later change will be to add type hints to everything but if you're editing a function now, just add in those typehints right now. It'll help us catch possible errors from static typing.
Summary: This PR resolves mypy linter issues in `plotting.py` by adding `None` checks for `strat.model`, providing a missing type hint for the `locs` variable, and updating `matplotlib.markers` to use a string digit instead of an integer. This was done because adding more strict type hints in other parts of the code revealed these errors, which need to be fixed before proceeding with #403 . Pull Request resolved: #405 Reviewed By: crasanders Differential Revision: D64415474 Pulled By: JasonKChow fbshipit-source-id: 1bbb3cc28d7c15b193404f2ca6407a6996d4d062
01cfbb5
to
aeadbf1
Compare
@JasonKChow has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
@JasonKChow merged this pull request in 09b1d59. |
This PR addresses the second part of issue #365, focusing on the
Strategy
class and how data is added and normalized, transitioning the process to use tensors instead of NumPy operations.The changes were made specifically within the
normalize_inputs
method of theStrategy
class. Previously, this method had mismatched docstrings indicatingnp.array
usage. Now, it consistently accepts and returns tensors, performing all operations within tensors.The
normalize_inputs
method is called inadd_data()
(where the confusion arises), as the data passed can vary (either tensors ornp.array
). To resolve this, the method now acts as the first step, accepting both formats and then converting everything to tensors for consistent operations (model fitting later on). It’s also crucial to ensure the data type isfloat64
, asgpytorch
does not support other data types.Additionally, a detailed docstring was added to clarify the method's expectations and ensure its proper use going forward.