Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Strategy class to ensure consistent tensor operations for data normalization #403

Closed

Conversation

yalsaffar
Copy link
Contributor

This PR addresses the second part of issue #365, focusing on the Strategy class and how data is added and normalized, transitioning the process to use tensors instead of NumPy operations.

The changes were made specifically within the normalize_inputs method of the Strategy class. Previously, this method had mismatched docstrings indicating np.array usage. Now, it consistently accepts and returns tensors, performing all operations within tensors.

The normalize_inputs method is called in add_data() (where the confusion arises), as the data passed can vary (either tensors or np.array). To resolve this, the method now acts as the first step, accepting both formats and then converting everything to tensors for consistent operations (model fitting later on). It’s also crucial to ensure the data type is float64, as gpytorch does not support other data types.

Additionally, a detailed docstring was added to clarify the method's expectations and ensure its proper use going forward.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 14, 2024
@facebook-github-bot
Copy link
Contributor

@JasonKChow has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

Copy link
Contributor

@JasonKChow JasonKChow left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know a later change will be to add type hints to everything but if you're editing a function now, just add in those typehints right now. It'll help us catch possible errors from static typing.

aepsych/strategy.py Outdated Show resolved Hide resolved
aepsych/strategy.py Show resolved Hide resolved
aepsych/strategy.py Show resolved Hide resolved
aepsych/strategy.py Show resolved Hide resolved
aepsych/strategy.py Show resolved Hide resolved
facebook-github-bot pushed a commit that referenced this pull request Oct 16, 2024
Summary:
This PR resolves mypy linter issues in `plotting.py` by adding `None` checks for `strat.model`, providing a missing type hint for the `locs` variable, and updating `matplotlib.markers` to use a string digit instead of an integer.

This was done because adding more strict type hints in other parts of the code revealed these errors, which need to be fixed before proceeding with #403 .

Pull Request resolved: #405

Reviewed By: crasanders

Differential Revision: D64415474

Pulled By: JasonKChow

fbshipit-source-id: 1bbb3cc28d7c15b193404f2ca6407a6996d4d062
@yalsaffar yalsaffar force-pushed the strategy-normalize_inputs-fix branch from 01cfbb5 to aeadbf1 Compare October 16, 2024 20:17
@facebook-github-bot
Copy link
Contributor

@JasonKChow has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@JasonKChow JasonKChow self-requested a review October 16, 2024 23:21
@facebook-github-bot
Copy link
Contributor

@JasonKChow merged this pull request in 09b1d59.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. Merged
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants