Resurrect JIT functionality and loading of models #146
Comments
@NiklasGustafsson What's your thinking on this for the roadmap for TorchSharp? It feels like it should be there?
I think this is one area that requires very careful consideration. I'm not convinced that TorchSharp needs to follow the PyTorch recipe for externalizing models exactly. While we have spent a lot of effort making the eager evaluation of models as close as possible to what PyTorch developers are doing, we do not have to make the externalized model experience the same. For example, there may be a role for Roslyn analyzers for the C# solution, rather than a clone of the Python solution. Likewise for F# -- there may be an F#-specific solution. I think this should come post-v1.0, and we also need to consider how we will support ONNX export from TorchSharp. To me, they are two sides of the same coin.
Currently (I could be wrong) it seems the continuity from a PyTorch saved model to TorchSharp is through ONNX. There is a need for a proof of concept that the state dictionary exported from PyTorch (using the TorchSharp exportsd.py script) can be loaded back into TorchSharp, with the requirement that the TorchSharp and PyTorch models are compatible. As Niklas has implied on a few occasions, one possible option is TorchSharp support for importing ONNX exported from PyTorch.
It seems TorchSharp ONNX import is equally important as, and perhaps more important than, TorchSharp ONNX export. Again, I hope this issue is evaluated post-v1.0. Congratulations on the rapid progress towards v1.0!
@NiklasGustafsson I see, for loading (though I'd still think we need to support loading of the PyTorch format; if the Java, Rust, or Scala bindings do, we should check). What about JIT? That's really what I was asking about; I forgot this issue referred to loading too.
Here are relevant links from the Rust PyTorch bindings: https://github.com/LaurentMazare/tch-rs#using-some-pre-trained-model https://github.com/LaurentMazare/tch-rs/tree/main/examples/jit I haven't checked how the ".ot" files are extracted. At least, we should be matching this.
Looks like they have
Based on the new initiative from @dsyme, I recall the challenges Netron faced dealing with PyTorch support.
Rust binding approach

Then I looked into the Rust binding approach as suggested by @dsyme. Rust seems to work only with TorchScript instead of the PyTorch model. Example (tracing an existing module):

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)

n = Net()
example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)

# Trace a specific method and construct `ScriptModule` with
# a single `forward` method
module = torch.jit.trace(n.forward, example_forward_input)

# Trace a module (implicitly traces `forward`) and construct a
# `ScriptModule` with a single `forward` method
module = torch.jit.trace(n, example_forward_input)
```
Given a "PyTorch" trained and save model, (e.g. pth or pt) the extension itself does not reveal if it is a PyTorch model or a TorchScript serialized model. (lots of discussion on that) To address that in Rust.
The rust CModule is described here However, it seems in practice, Rust goes from a save Jit model (e.g. "pt"), through an intermediate format ".nz", before ending up with the "ot" format. This is what I could gather and I am not 100% sure if my understanding is correct. |
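The actual conversion tooling on the Rust side is not reproduced here, but as an illustrative and unverified sketch of such an intermediate step, one could dump the tensors of a saved TorchScript module into a NumPy archive; the function name and file names below are made up for illustration.

```python
# Illustrative sketch only: load a saved TorchScript module and dump its
# tensors into a NumPy .npz archive, the kind of intermediate described above
# before a further conversion to the Rust ".ot" format.
import numpy as np
import torch

def jit_to_npz(jit_path: str, npz_path: str) -> None:
    module = torch.jit.load(jit_path, map_location="cpu")
    arrays = {name: t.detach().cpu().numpy() for name, t in module.state_dict().items()}
    np.savez(npz_path, **arrays)

# jit_to_npz("traced_model.pt", "traced_model.npz")
```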
@dsyme, @GeorgeS2019 -- we can and should look at JIT post-v1.0. In my opinion, we should decide on what is right in the context of .NET and the tooling options we have at our disposal. Tracing is simple; it's the scripting option of TorchScript that is going to take some thought and consideration. I think ONNX is a good solution for exchange between languages (that's what it's for), but the reason for JIT support is primarily about performance and being able to deploy models for inference on a wide variety of platforms, so the JIT question is distinct from loading, IMO. I'm working on an example (for the examples repo) with a simple ONNX -> C# converter that relies on the .NET model binary format. It should be ready in a week or so.
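The converter mentioned above is not shown here, but for reference, a minimal sketch of the export side in Python (reusing the toy Net module from the tracing example above, so that name is an assumption) could look like this:

```python
# Export the toy module to a standard ONNX graph that a separate tool
# (for example an ONNX -> C# converter) could then consume.
import torch

n = Net()  # the small Conv2d module from the earlier tracing snippet
example_forward_input = torch.rand(1, 1, 3, 3)

torch.onnx.export(
    n,
    example_forward_input,
    "net.onnx",
    input_names=["input"],
    output_names=["output"],
    opset_version=11,
)
```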
In some ways, I am thinking along the lines of what Niklas has said. We may need to learn from the PyTorch experience faced by both Netron and Rust. To avoid not knowing the origin of a PyTorch "model", and the confusion caused by that, Netron recommends ONNX and does not officially support PyTorch. Rust, for its part, requires a controlled workflow of migrating a TorchScript file to a proprietary Rust ".ot" format that is non-standard and only used within Rust, or perhaps within a subset of Rust projects dealing with PyTorch.

Recommendation

If there is a need to deal with TorchScript, provide a tool to detect whether a given file is a PyTorch model or a TorchScript file, and provide a brief explanation of why only TorchScript is dealt with. Instead of following the path of Rust, TorchScript would be converted to ONNX within TorchSharp. There would then be no need to deal with a proprietary, non-standard TorchSharp format, which is currently the case.
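A minimal sketch of what such a detection tool could look like on the Python side; the file name and classification labels are illustrative, and the deliberately broad exception handling is only appropriate for a sketch:

```python
import torch

def classify_checkpoint(path: str) -> str:
    """Best-effort guess at whether a file is a TorchScript archive or an eager checkpoint."""
    try:
        # TorchScript archives open with torch.jit.load.
        torch.jit.load(path, map_location="cpu")
        return "torchscript"
    except Exception:
        pass
    # Anything torch.jit.load rejects is tried as an ordinary pickled object.
    obj = torch.load(path, map_location="cpu")
    if isinstance(obj, dict):
        return "state_dict (pickled dict of tensors)"
    return "pickled module or other pickled object"

# print(classify_checkpoint("model.pt"))
```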
The core design philosophy for TorchSharp seems to be "if it's in LibTorch, and stable, then TorchSharp should expose it". That is from the perspective of the coherence of the TorchSharp design: TorchSharp is the one and only wrapper for LibTorch in the .NET ecosystem. As a .NET Foundation project I think that's the right approach. So for any LibTorch functionality (incl. JIT and loading) I'd say "if it's there, if it's stable, if it's feasible, if other languages bind to loading, then TorchSharp should in theory do so too". That means I believe we should say we would accept community PRs for these. It's not saying we need to put it into v1.0 ourselves, but it's hard to argue against accepting these things if someone contributes them. I guess there's a chance that @gbaydin and I may like to use the JIT functionality from DiffSharp to allow quotation-based compilation. Or I might try to build a different toolchain that compiles models written in F# statically using FSharp.Compiler.Service and JIT.
What I learned from this recent use case suggests the following challenges when using a saved PyTorch model (more precisely, saved dict_states) in TorchSharp.

```python
if __name__ == '__main__':
    # Create model
    model = BasicConv1d(1, 32)

    # Export model to .dat file for ingestion into TorchSharp
    f = open("bug510.dat", "wb")
    exportsd.save_state_dict(model.to("cpu").state_dict(), f)
    f.close()
```

One has to examine whether all the saved dict states have been implemented internally in TorchSharp for the model I discuss here, so that more users can check for any missing internal implementations (and, if possible, submit PRs), making model.load more reliable when loading saved PyTorch dict_states. Once we have a reliable model.load for almost all scenarios, it will be less challenging to implement TorchSharp code to load ONNX in TorchSharp. Any comments? How can we coordinate with the community to write unit tests that increase code coverage, so that model.load is reliable when exporting using the provided TorchSharp Python code?
Model.load() has nothing to do with loading ONNX models in TorchSharp; it relates only to loading model state_dicts() generated from either TorchSharp or PyTorch.
I got the part about ONNX and state_dict() having nothing to do with Model.Load(). Model.Load() ONLY works if
We have discussed the limitation of the PyTorch pth model. Here, @fwaris addresses that by creating a parser that goes through the TensorFlow saved model.
Since @michaelgsharp's ML.NET can integrate a TensorFlow saved model, I wonder if the necessary parser to extract the graph structure, similar to that of TfCheckpoint, is already available?
@GeorgeS2019 at this time, TfCheckpoint only reads the tensor data (variables) - not the model structure. TfCheckpoint has minimal dependencies (really just IronSnappy and the required protobuf definitions). However, Tensorflow.NET seems to have the required functionality, but the documentation is sparse. The "Restoring the Model" part is not written yet.
This has now been implemented -- TorchSharp can load and save TorchScript files, but not create them from scratch. See PR #644
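Since the files still have to be created in Python, a minimal sketch of producing a TorchScript archive for TorchSharp (or libtorch) to load, reusing the toy Net module from the tracing example earlier in this thread (so that name is an assumption), could look like this:

```python
# Script (or trace) the module and save a self-contained TorchScript archive;
# loading it later does not require the original Python class definition.
import torch

scripted = torch.jit.script(Net())
torch.jit.save(scripted, "net_scripted.pt")
```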
The PyTorch C++ guide shows that loading of modules means using the JIT loading: https://pytorch.org/tutorials/advanced/cpp_export.html
The JIT functionality was removed after the extensive churn in the PyTorch C++ API from v1.01 --> 1.50.
It needs to be resurrected.
The more direct Load/Save that was present on modules should probably not be resurrected directly.
Also need to add ONNX load/save support