Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial values for the hidden/cell state for LSTM and GRU models in Pytorch #1120

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

JanFSchulte
Copy link
Contributor

@JanFSchulte JanFSchulte commented Nov 11, 2024

This PR addresses #1074 and implements the passing of initial values for the hidden and cell states in GRU and LSTM models, which is supported in pytorch. This first version implements this only for the pytorch parser, but it should be able to be extended it for keras and other parsers.

I have tested this for Vivado, Vitis, and Quartus. I don't have access to Catapult or oneAPI at the moment, so I haven't implemented this fort those backends.

Note that this currently only works in io_parallel. In io_stream I was having some conceptual issues and was unsure if I should treat these initial states are streamed inputs or not. Might be good enough for now and I can revisit io_stream if there are any suggestions how to tackle that.

Type of change

  • Bug fix (non-breaking change that fixes an issue)
  • New feature (non-breaking change which adds functionality)

Tests

Tested in both standalone scripts and also the pytests to ensure that model parsing and evaluation work with and without passing these optional tensors.

Checklist

  • I have read the guidelines for contributing.
  • I have commented my code, particularly in hard-to-understand areas.
  • I have made corresponding changes to the documentation.
  • My changes generate no new warnings.
  • I have installed and run pre-commit on the files I edited or added.
  • I have added tests that prove my fix is effective or that my feature works.

@JanFSchulte JanFSchulte added the please test Trigger testing by creating local PR branch label Nov 11, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
please test Trigger testing by creating local PR branch
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant