[WASI-NN] Add support for a PyTorch backend for wasi-nn #9234
Conversation
This is a good start. The main thing to fix is the handling of the input and output tensors.
Thanks for the review, Andrew. I've marked the smaller nits as resolved, and I've addressed the other comments as well, but kept them unresolved for now until you take a look.
```rust
    &input_tensor.data,
    &dimensions,
    kind,
));
```
This is almost there: we're still ignoring the index, though. What might need to happen is that we set up `inputs` as a `Vec<Option<TchTensor>>` filled with `None` and then set the right item based on the index. We are able to retrieve the number of inputs, right?
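The suggested storage scheme might look roughly like this minimal sketch. Names here are illustrative, not the PR's actual types: `Tensor` is a local stand-in for `tch::Tensor`, and `ExecutionContext` is a hypothetical container.

```rust
// Hypothetical sketch: store inputs sparsely by index, padding unset
// slots with `None`. `Tensor` is a stand-in, not the real tch type.
#[derive(Debug, Clone)]
struct Tensor(Vec<f32>);

struct ExecutionContext {
    inputs: Vec<Option<Tensor>>,
}

impl ExecutionContext {
    fn new() -> Self {
        Self { inputs: Vec::new() }
    }

    // Grow the vector with `None` as needed, then fill the slot at `index`.
    fn set_input(&mut self, index: usize, tensor: Tensor) {
        if self.inputs.len() <= index {
            self.inputs.resize(index + 1, None);
        }
        self.inputs[index] = Some(tensor);
    }
}

fn main() {
    let mut ctx = ExecutionContext::new();
    ctx.set_input(2, Tensor(vec![1.0]));
    assert_eq!(ctx.inputs.len(), 3);
    assert!(ctx.inputs[0].is_none());
    assert!(ctx.inputs[2].is_some());
    println!("ok");
}
```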
The index is being ignored because JIT-compiled PyTorch models expect tensors in order, similar to what the vector of tensors in `set_input()` achieves. Do we need to use the index here to keep things consistent with wasi-nn? I don't think I saw a direct way to retrieve the number of inputs to a model; I can look into that further. However, someone using PyTorch would probably not intend to use the index.
This was my previous comment on ignoring the index:
This is one of the differences for this backend. The module's `forward` method should handle multiple inputs appropriately if it supports them. The vector of input tensors passed to `forward` should be sufficient; no index or name is needed.
Did you mean to use the index as the position in the vector of tensors? That sounds good. I'll see if there's a way to determine the maximum number of inputs for a given model.
Alright, so here's what I've changed:

- `set_input()` now uses the `id` passed to it. If it is a `u32`, it assigns the tensor to `inputs[id]`, assigning `None` to any empty spots below the given `id`. This gives the user the flexibility to set input tensors non-sequentially.
- `compute()` will check the vector for any `None` values, and give an error if any are present, before calling `forward_ts()`.
- There is no reliable way at this time to get the max number of inputs for a model. However, assigning more inputs than available returns an error message at run time detailing the max inputs the model expects, so this is helpful for the user. The expectation, similar to other backends, is for the user to be aware of input size/shape etc.
- If `id` is a string, I've currently permitted only a single input to be set, effectively ignoring the index. If more inputs are assigned, or if `u32` and `String` indexes are used together, we give out an error.

Let me know if this looks fine. Thanks.
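The compute-time check described above might be sketched like this (hypothetical code: `Tensor` is a local stand-in, and `forward` is a placeholder for the jit-compiled model call, not tch's actual `CModule::forward_ts` signature):

```rust
// Hypothetical sketch: before running the model, reject any input slot
// that was never set (still `None`), then pass the dense vector onward.
#[derive(Debug, Clone)]
struct Tensor(Vec<f32>);

fn compute(inputs: &[Option<Tensor>]) -> Result<(), String> {
    if let Some(missing) = inputs.iter().position(|t| t.is_none()) {
        return Err(format!("input at index {missing} was never set"));
    }
    let dense: Vec<&Tensor> = inputs.iter().map(|t| t.as_ref().unwrap()).collect();
    forward(&dense)
}

// Placeholder for the actual jit-compiled model invocation.
fn forward(_inputs: &[&Tensor]) -> Result<(), String> {
    Ok(())
}

fn main() {
    let sparse = vec![Some(Tensor(vec![1.0])), None];
    assert!(compute(&sparse).is_err());
    let dense = vec![Some(Tensor(vec![1.0])), Some(Tensor(vec![2.0]))];
    assert!(compute(&dense).is_ok());
    println!("ok");
}
```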
Yeah, looking at 727a4aa, I think that makes more sense. I was kind of hoping that `CModule::named_parameters` would return the list of inputs by name, but, if that's not the case, then let's just return an error for the `Id::Name` side of things. In any case, we probably don't need `id_type`: if we get an `Id::Index` we know where to put it in the vector, and if we get an `Id::Name` then we either (a) look up its index in `named_parameters` or (b) if that's not possible, return an error.
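The dispatch described here might be sketched as follows (hypothetical names: `Id` mirrors the wasi-nn tensor-id shape, and `slot_for` is illustrative, taking option (b) of rejecting names outright):

```rust
// Hypothetical sketch: an Id::Index maps directly to a vector position,
// while Id::Name errors out since this backend can't map names to slots.
#[derive(Debug)]
enum Id {
    Index(u32),
    Name(String),
}

fn slot_for(id: &Id) -> Result<usize, String> {
    match id {
        Id::Index(i) => Ok(*i as usize),
        Id::Name(n) => Err(format!("named input '{n}' is not supported by this backend")),
    }
}

fn main() {
    assert_eq!(slot_for(&Id::Index(1)).unwrap(), 1);
    assert!(slot_for(&Id::Name("data".into())).is_err());
    println!("ok");
}
```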
I had added the `Id::Name` case because the wit test cases (`nn_wit_image_classification_pytorch` in this case) need a Name instead of an Index to pass; see nn.rs and wasi-nn.wit. Although it looks like `set-input` might be going away anyway.
Force-pushed from 786b073 to 7f84ed4 (compare)
Can we find a smaller model or download this instead? Not all Wasmtime users probably want to download this file... twice.
I've made the following changes:

- Created a repository, https://github.com/rahulchaphalkar/libtorch-models, to generate jit-compiled libtorch models, and download models from its release (https://github.com/rahulchaphalkar/libtorch-models/releases/tag/v0.1) for tests and examples.
- Switched from Resnet18 (45MB) to Squeezenet1.1 (4.5MB) as the file to download from the above repo.
- This `libtorch-models` repo can issue new releases if we update the pytorch/libtorch version in this backend from the current 2.4.0, to ensure model compatibility.
Here's the identical 44.7MB file checked in again.
Now downloading from the external repo, same as the tests.
As described in the [contribution guidelines], Wasmtime will exempt dependencies from vetting that receive at least 10,000 downloads a day. This substantially reduces the burden for vetting this PR, so I've tallied up daily downloads (across all versions) for the crates in this PR, listed below. This change then exempts the new dependencies that meet the 10K+ criteria.

[contribution guidelines]: https://docs.wasmtime.dev/contributing-coding-guidelines.html#policy-for-adding-cargo-vet-entries

```
> aes
2024-10-02 111734 2024-10-03 107324 2024-10-04 104299 2024-10-05 32397 2024-10-06 29507 2024-10-07 123368 2024-10-08 125732
> base64ct
2024-10-02 179848 2024-10-03 157938 2024-10-04 149495 2024-10-05 48118 2024-10-06 43389 2024-10-07 183254 2024-10-08 175378
> bzip2
2024-10-02 89309 2024-10-03 85112 2024-10-04 76573 2024-10-05 27152 2024-10-06 24124 2024-10-07 90228 2024-10-08 93314
> bzip2-sys
2024-10-02 109664 2024-10-03 102677 2024-10-04 94485 2024-10-05 33196 2024-10-06 28417 2024-10-07 110195 2024-10-08 110951
> cipher
2024-10-02 1119 2024-10-03 377 2024-10-04 270 2024-10-05 178 2024-10-06 271 2024-10-07 2105 2024-10-08 1777
> constant_time_eq
2024-10-02 137462 2024-10-03 126300 2024-10-04 121927 2024-10-05 169156 2024-10-06 139559 2024-10-07 304529 2024-10-08 246533
> crunchy
2024-10-02 197832 2024-10-03 176586 2024-10-04 172053 2024-10-05 187875 2024-10-06 153647 2024-10-07 359240 2024-10-08 304777
> deranged
2024-10-02 319691 2024-10-03 285298 2024-10-04 267760 2024-10-05 104537 2024-10-06 92306 2024-10-07 309831 2024-10-08 308869
> digest
2024-10-02 2128 2024-10-03 1335 2024-10-04 1474 2024-10-05 594 2024-10-06 726 2024-10-07 3079 2024-10-08 2855
> half
2024-10-02 161525 2024-10-03 144013 2024-10-04 137296 2024-10-05 49246 2024-10-06 42437 2024-10-07 157366 2024-10-08 165013
> hmac
2024-10-02 1254 2024-10-03 394 2024-10-04 322 2024-10-05 230 2024-10-06 424 2024-10-07 2068 2024-10-08 1907
> inout
2024-10-02 1114 2024-10-03 366 2024-10-04 281 2024-10-05 184 2024-10-06 285 2024-10-07 2000 2024-10-08 1782
> matrixmultiply
2024-10-02 52273 2024-10-03 49931 2024-10-04 48408 2024-10-05 17219 2024-10-06 13950 2024-10-07 53916 2024-10-08 52644
> ndarray
2024-10-02 28922 2024-10-03 29354 2024-10-04 27397 2024-10-05 10480 2024-10-06 9074 2024-10-07 30988 2024-10-08 32344
> num-complex
2024-10-02 178444 2024-10-03 159144 2024-10-04 146722 2024-10-05 48522 2024-10-06 39138 2024-10-07 171363 2024-10-08 172915
> num-conv
2024-10-02 298495 2024-10-03 267134 2024-10-04 250350 2024-10-05 97809 2024-10-06 87399 2024-10-07 293150 2024-10-08 290661
> num-integer
2024-10-02 333731 2024-10-03 300418 2024-10-04 287516 2024-10-05 227416 2024-10-06 190413 2024-10-07 487348 2024-10-08 433744
> password-hash
2024-10-02 22429 2024-10-03 20702 2024-10-04 21550 2024-10-05 9061 2024-10-06 8660 2024-10-07 25743 2024-10-08 22404
> pbkdf2
2024-10-02 77885 2024-10-03 76192 2024-10-04 72278 2024-10-05 148944 2024-10-06 119322 2024-10-07 248354 2024-10-08 190649
> powerfmt
2024-10-02 310293 2024-10-03 277178 2024-10-04 259885 2024-10-05 101195 2024-10-06 89789 2024-10-07 302058 2024-10-08 300192
> rawpointer
2024-10-02 53917 2024-10-03 50649 2024-10-04 48439 2024-10-05 17375 2024-10-06 14761 2024-10-07 56228 2024-10-08 55013
> safetensors
2024-10-02 2253 2024-10-03 1737 2024-10-04 1798 2024-10-05 1085 2024-10-06 1544 2024-10-07 1742 2024-10-08 2024
> sha1
2024-10-02 1410 2024-10-03 673 2024-10-04 772 2024-10-05 230 2024-10-06 416 2024-10-07 2125 2024-10-08 2204
> tch
2024-10-02 1930 2024-10-03 2295 2024-10-04 2834 2024-10-05 1274 2024-10-06 455 2024-10-07 2290 2024-10-08 2181
> time
2024-10-02 303042 2024-10-03 271434 2024-10-04 255795 2024-10-05 100194 2024-10-06 88810 2024-10-07 297807 2024-10-08 295315
> time-core
2024-10-02 334979 2024-10-03 302165 2024-10-04 282918 2024-10-05 109319 2024-10-06 96522 2024-10-07 324779 2024-10-08 322102
> torch-sys
2024-10-02 1911 2024-10-03 2300 2024-10-04 2843 2024-10-05 1271 2024-10-06 452 2024-10-07 2292 2024-10-08 2177
> zip
2024-10-02 22520 2024-10-03 23201 2024-10-04 20946 2024-10-05 9067 2024-10-06 8470 2024-10-07 24674 2024-10-08 24870
> zstd
2024-10-02 175155 2024-10-03 167766 2024-10-04 157489 2024-10-05 52753 2024-10-06 44844 2024-10-07 177411 2024-10-08 173785
> zstd-safe
2024-10-02 179288 2024-10-03 170379 2024-10-04 159352 2024-10-05 52820 2024-10-06 45835 2024-10-07 180535 2024-10-08 177703
```
For dependencies that did not have clear 10k+ daily downloads, this change audits them for `safe-to-deploy`.
This adds external audits pulled in automatically by `cargo vet` for the remainder of the dependencies not covered by previous commits.
Force-pushed from aaf6597 to da1f467 (compare)
Force-pushed from da1f467 to 39f9271 (compare)
I rolled back 2 commits, but there's an issue with the lock file. I had previously attempted to fix this by deleting my lockfile, rebasing off of latest main, and then doing a
The failing tests fail due to
I'm hoping a rerun of CI would help.
@abrown, can you take a look?
This change adds a PyTorch backend for wasi-nn.
The `tch` crate is used for Libtorch bindings. I have added an image classification example to demonstrate its usage, which uses a torchscript model.
This backend is currently gated behind a wasi-nn feature flag (`--features pytorch`) because, due to dynamic linking, a Libtorch v2.4.0 installation on the system (specified by `LIBTORCH=/path/to/libtorch`) is needed for building.
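A build invocation might look roughly like the following sketch. The libtorch path is a placeholder, and the exact feature name and crate directory may differ from this PR:

```shell
# Hypothetical sketch: point the build at a local libtorch v2.4.0 install.
export LIBTORCH=/path/to/libtorch
# Libtorch is dynamically linked, so the loader must find its shared libraries.
export LD_LIBRARY_PATH="$LIBTORCH/lib:$LD_LIBRARY_PATH"
# Enable the gated backend via the feature flag mentioned above.
cargo build --features pytorch
```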