diff --git a/README.md b/README.md index a29d13d..8e09058 100644 --- a/README.md +++ b/README.md @@ -44,6 +44,11 @@ sure that the PyTorch model handles a **two-dimensional** input matrix! Accordin many optimization problems. However, you can explicitly request the generation of the Hessian by passing `generate_jac_jac=True`. +L4CasADi v2 can use the new **torch compile** functionality starting from PyTorch 2.4. By passing `scripting=False`. This +will lead to a longer compile time on first L4CasADi function call but will lead to a overall faster +execution. However, currently this functionality is experimental and not fully stable across all models. In the long +term there is a good chance this will become the default over scripting once the functionality is stabilized by the +Torch developers. ## Table of Content - [Projects using L4CasADi](#projects-using-l4casadi) diff --git a/examples/acados.py b/examples/acados.py index d58ad03..25ba6f8 100644 --- a/examples/acados.py +++ b/examples/acados.py @@ -128,7 +128,7 @@ def ocp(self): ocp.cost.W = np.array([[1.]]) # Trivial PyTorch index 0 - l4c_y_expr = l4c.L4CasADi(lambda x: x[0], name='y_expr', model_expects_batch_dim=False) + l4c_y_expr = l4c.L4CasADi(lambda x: x[0], name='y_expr') ocp.model.cost_y_expr = l4c_y_expr(x) ocp.model.cost_y_expr_e = x[0] diff --git a/examples/cpp_usage/generate.py b/examples/cpp_usage/generate.py index fe4f41c..9748edb 100644 --- a/examples/cpp_usage/generate.py +++ b/examples/cpp_usage/generate.py @@ -10,7 +10,7 @@ def forward(self, x): def generate(): - l4casadi_model = l4c.L4CasADi(TorchModel(), model_expects_batch_dim=False, name='sin_l4c') + l4casadi_model = l4c.L4CasADi(TorchModel(), name='sin_l4c') sym_in = cs.MX.sym('x', 1, 1) diff --git a/examples/fish_turbulent_flow/utils.py b/examples/fish_turbulent_flow/utils.py index b69f1c4..5f9f4dd 100644 --- a/examples/fish_turbulent_flow/utils.py +++ b/examples/fish_turbulent_flow/utils.py @@ -266,7 +266,7 @@ def import_l4casadi_model(device): x = cs.MX.sym("x", 3) xn = (x - meanX) / stdX - y = l4c.L4CasADi(model, name="turbulent_model", model_expects_batch_dim=True)(xn) + y = l4c.L4CasADi(model, name="turbulent_model", generate_adj1=False, generate_jac_jac=True)(xn.T).T y = y * stdY + meanY fU = cs.Function("fU", [x], [y[0]]) fV = cs.Function("fV", [x], [y[1]]) diff --git a/examples/matlab/export.py b/examples/matlab/export.py index 66549f8..6350d9d 100644 --- a/examples/matlab/export.py +++ b/examples/matlab/export.py @@ -10,7 +10,7 @@ def forward(self, x): def generate(): - l4casadi_model = l4c.L4CasADi(TorchModel(), model_expects_batch_dim=False, name='sin_l4c') + l4casadi_model = l4c.L4CasADi(TorchModel(), name='sin_l4c') sym_in = cs.MX.sym('x', 1, 1) l4casadi_model.build(sym_in) return diff --git a/examples/nerf_trajectory_optimization/nerf_trajectory_optimization.py b/examples/nerf_trajectory_optimization/nerf_trajectory_optimization.py index f1172fb..d4cd455 100644 --- a/examples/nerf_trajectory_optimization/nerf_trajectory_optimization.py +++ b/examples/nerf_trajectory_optimization/nerf_trajectory_optimization.py @@ -10,6 +10,7 @@ CASE = 1 +os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' def polynomial(n, n_eval): """Generates a symbolic function for a polynomial of degree n-1""" @@ -175,7 +176,7 @@ def main(): strict=False, ) # -------------------------- Create L4CasADi Module -------------------------- # - l4c_nerf = l4c.L4CasADi(model) + l4c_nerf = l4c.L4CasADi(model, scripting=False) # ---------------------------------------------------------------------------- # # NLP warmup # diff --git a/examples/nerf_trajectory_optimization/nerf_trajectory_optimization_batched.py b/examples/nerf_trajectory_optimization/nerf_trajectory_optimization_batched.py index a9f8b10..c77a5c4 100644 --- a/examples/nerf_trajectory_optimization/nerf_trajectory_optimization_batched.py +++ b/examples/nerf_trajectory_optimization/nerf_trajectory_optimization_batched.py @@ -8,6 +8,7 @@ import l4casadi as l4c from density_nerf import DensityNeRF +import os CASE = 1 diff --git a/examples/readme.py b/examples/readme.py index 7bf012b..f68af36 100644 --- a/examples/readme.py +++ b/examples/readme.py @@ -25,9 +25,9 @@ def forward(self, x): pyTorch_model = MultiLayerPerceptron() -l4c_model = l4c.L4CasADi(pyTorch_model, model_expects_batch_dim=True, device='cpu') # device='cuda' for GPU +l4c_model = l4c.L4CasADi(pyTorch_model, device='cpu') # device='cuda' for GPU -x_sym = cs.MX.sym('x', 2, 1) +x_sym = cs.MX.sym('x', 1, 2) y_sym = l4c_model(x_sym) f = cs.Function('y', [x_sym], [y_sym]) df = cs.Function('dy', [x_sym], [cs.jacobian(y_sym, x_sym)]) diff --git a/examples/simple_nlp.py b/examples/simple_nlp.py index d8923be..f74eaf2 100644 --- a/examples/simple_nlp.py +++ b/examples/simple_nlp.py @@ -14,7 +14,7 @@ def forward(self, input): f = PyTorchObjectiveModel() # objective -f = l4c.L4CasADi(f, name='f', model_expects_batch_dim=False)(x) +f = l4c.L4CasADi(f, name='f')(x) class PyTorchConstraintModel(torch.nn.Module): @@ -23,7 +23,7 @@ def forward(self, input): g = PyTorchConstraintModel() # constraint -g = l4c.L4CasADi(g, name='g', model_expects_batch_dim=False)(x) +g = l4c.L4CasADi(g, name='g')(x) nlp = {'x': x, 'f': f, 'g': g} diff --git a/l4casadi/l4casadi.py b/l4casadi/l4casadi.py index 146de4b..e6cb9ed 100644 --- a/l4casadi/l4casadi.py +++ b/l4casadi/l4casadi.py @@ -48,6 +48,7 @@ def __init__(self, generate_adj1: bool = True, generate_jac_adj1: bool = True, generate_jac_jac: bool = False, + scripting: bool = True, mutable: bool = False): """ :param model: PyTorch model. @@ -65,11 +66,21 @@ def __init__(self, :param generate_adj1: If True, the Adjoint of the model is tried to be generated. :param generate_jac_adj1: If True, the Jacobain of the Adjoint of the model is tried to be generated. :param generate_jac_jac: If True, the Hessian of the model is tried to be generated. + :param scripting: If True, the model is traced using TorchScript. If False, the model is compiled. :param mutable: If True, enables updating the model online via the update method. """ if platform.system() == "Windows": warnings.warn("L4CasADi is currently not supported for Windows.") + if not scripting: + warnings.warn("L4CasADi with Torch AOT compilation is experimental at this point and might not work as " + "expected.") + if torch.__version__ < torch.torch_version.TorchVersion('2.4.0'): + raise RuntimeError("For PyTorch versions < 2.4.0 L4CasADi only supports jit scripting. Please pass " + "scripting=True.") + import torch._inductor.config as config + config.freezing = True + self.model = model self.naive = False if isinstance(self.model, NaiveL4CasADiModule): @@ -94,6 +105,8 @@ def __init__(self, self._generate_jac_adj1 = generate_jac_adj1 self._generate_jac_jac = generate_jac_jac + self._scripting = scripting + self._mutable = mutable self._input_shape: Tuple[int, int] = (-1, -1) @@ -284,6 +297,7 @@ def _generate_cpp_function_template(self, has_jac: bool, has_adj1: bool, has_jac 'has_adj1': 'true' if has_adj1 else 'false', 'has_jac_adj1': 'true' if has_jac_adj1 else 'false', 'has_jac_jac': 'true' if has_jac_jac else 'false', + 'scripting': 'true' if self._scripting else 'false', 'model_is_mutable': 'true' if self._mutable else 'false', 'batched': 'true' if self.batched else 'false', 'jac_ccs_len': len(jac_ccs) if self.batched else 0, @@ -327,8 +341,8 @@ def _trace_jac_model(self, inp): def with_batch_dim(x): return torch.func.vmap(jacrev(self.model))(x[:, None])[:, 0].permute(1, 0, 2, 3) - return make_fx(functionalize(with_batch_dim, remove='mutations_and_views'))(inp) - return make_fx(functionalize(jacrev(self.model), remove='mutations_and_views'))(inp) + return make_fx(functionalize(with_batch_dim, remove='mutations'))(inp) + return make_fx(functionalize(jacrev(self.model), remove='mutations'))(inp) def _trace_adj1_model(self): p_d = torch.zeros(self._input_shape).to(self.device) @@ -337,7 +351,7 @@ def _trace_adj1_model(self): def _vjp(p, x): return vjp(self.model, p)[1](x)[0] - return make_fx(functionalize(_vjp, remove='mutations_and_views'))(p_d, t_d) + return make_fx(functionalize(_vjp, remove='mutations'))(p_d, t_d) def _trace_jac_adj1_model(self): p_d = torch.zeros(self._input_shape).to(self.device) @@ -351,8 +365,8 @@ def _vjp(p, x): def with_batch_dim(p, x): return torch.func.vmap(jacfwd(_vjp))(p[:, None], x[:, None])[:, 0].permute(3, 2, 0, 1) - return make_fx(functionalize(with_batch_dim, remove='mutations_and_views'))(p_d, t_d) - return make_fx(functionalize(jacfwd(_vjp), remove='mutations_and_views'))(p_d, t_d) + return make_fx(functionalize(with_batch_dim, remove='mutations'))(p_d, t_d) + return make_fx(functionalize(jacfwd(_vjp), remove='mutations'))(p_d, t_d) def _trace_hess_model(self, inp): if self.batched: @@ -360,8 +374,8 @@ def with_batch_dim(x): # Permutation is trial and error return torch.func.vmap(jacrev(jacrev(self.model)))(x[:, None])[:, 0].permute(1, 3, 2, 0, 4, 5) - return make_fx(functionalize(with_batch_dim, remove='mutations_and_views'))(inp) - return make_fx(functionalize(jacrev(jacrev(self.model)), remove='mutations_and_views'))(inp) + return make_fx(functionalize(with_batch_dim, remove='mutations'))(inp) + return make_fx(functionalize(jacrev(jacrev(self.model)), remove='mutations'))(inp) def export_torch_traces(self) -> Tuple[bool, bool, bool, bool]: d_inp = torch.zeros(self._input_shape) @@ -372,7 +386,7 @@ def export_torch_traces(self) -> Tuple[bool, bool, bool, bool]: out_folder = self.build_dir - self._jit_compile_and_save(make_fx(functionalize(self.model, remove='mutations_and_views'))(d_inp), + self.model_compile( make_fx(functionalize(self.model, remove='mutations'))(d_inp), (out_folder / f'{self.name}.pt').as_posix(), (d_inp,)) @@ -380,7 +394,7 @@ def export_torch_traces(self) -> Tuple[bool, bool, bool, bool]: if self._generate_jac: jac_model = self._trace_jac_model(d_inp) - exported_jac = self._jit_compile_and_save( + exported_jac = self.model_compile( jac_model, (out_folder / f'jac_{self.name}.pt').as_posix(), (d_inp,) @@ -389,7 +403,7 @@ def export_torch_traces(self) -> Tuple[bool, bool, bool, bool]: exported_adj1 = False if self._generate_adj1: adj1_model = self._trace_adj1_model() - exported_adj1 = self._jit_compile_and_save( + exported_adj1 = self.model_compile( adj1_model, (out_folder / f'adj1_{self.name}.pt').as_posix(), (d_inp, d_out) @@ -398,7 +412,7 @@ def export_torch_traces(self) -> Tuple[bool, bool, bool, bool]: exported_jac_adj1 = False if self._generate_jac_adj1: jac_adj1_model = self._trace_jac_adj1_model() - exported_jac_adj1 = self._jit_compile_and_save( + exported_jac_adj1 = self.model_compile( jac_adj1_model, (out_folder / f'jac_adj1_{self.name}.pt').as_posix(), (d_inp, d_out) @@ -413,7 +427,7 @@ def export_torch_traces(self) -> Tuple[bool, bool, bool, bool]: pass if hess_model is not None: - exported_hess = self._jit_compile_and_save( + exported_hess = self.model_compile( hess_model, (out_folder / f'jac_jac_{self.name}.pt').as_posix(), (d_inp,) @@ -421,9 +435,27 @@ def export_torch_traces(self) -> Tuple[bool, bool, bool, bool]: return exported_jac, exported_adj1, exported_jac_adj1, exported_hess + def model_compile(self, model, file_path: str, dummy_inp: Tuple[torch.Tensor, ...]): + if self._scripting: + return self._jit_compile_and_save(model, file_path, dummy_inp) + else: + return self._aot_compile_and_save(model, file_path, dummy_inp) + + @staticmethod + def _aot_compile_and_save(model, file_path: str, dummy_inp: Tuple[torch.Tensor, ...]): + try: + with torch.no_grad(): + torch._export.aot_compile( + model, + dummy_inp, + options={"aot_inductor.output_path": file_path[:-2] + 'so'}, + ) + return True + except: # noqa + return False + @staticmethod - def _jit_compile_and_save(model, file_path: str, dummy_inp: torch.Tensor): - # TODO: Could switch to torch export https://pytorch.org/docs/stable/export.html + def _jit_compile_and_save(model, file_path: str, dummy_inp: Tuple[torch.Tensor, ...]): try: # Try scripting ts_compile(model).save(file_path) diff --git a/l4casadi/template_generation/templates/casadi_function.in.cpp b/l4casadi/template_generation/templates/casadi_function.in.cpp index f841bfa..e4d5f87 100644 --- a/l4casadi/template_generation/templates/casadi_function.in.cpp +++ b/l4casadi/template_generation/templates/casadi_function.in.cpp @@ -1,6 +1,6 @@ #include -L4CasADi l4casadi("{{ model_path }}", "{{ name }}", {{ rows_in }}, {{ cols_in }}, {{ rows_out }}, {{ cols_out }}, "{{ device }}", {{ has_jac }}, {{ has_adj1 }}, {{ has_jac_adj1 }}, {{ has_jac_jac }}, {{ model_is_mutable }}); +L4CasADi l4casadi("{{ model_path }}", "{{ name }}", {{ rows_in }}, {{ cols_in }}, {{ rows_out }}, {{ cols_out }}, "{{ device }}", {{ has_jac }}, {{ has_adj1 }}, {{ has_jac_adj1 }}, {{ has_jac_jac }}, {{ scripting }}, {{ model_is_mutable }}); #ifdef __cplusplus extern "C" { diff --git a/libl4casadi/CMakeLists.txt b/libl4casadi/CMakeLists.txt index 44545c6..9229efa 100644 --- a/libl4casadi/CMakeLists.txt +++ b/libl4casadi/CMakeLists.txt @@ -1,13 +1,7 @@ cmake_minimum_required(VERSION 3.0 FATAL_ERROR) project(L4CasADi) -# Load CUDA if it is installed -find_package(CUDAToolkit) -find_package(CUDA) - -if (USE_CUDA) - add_definitions(-DUSE_CUDA) -endif () +set(CMAKE_COMPILE_WARNING_AS_ERROR ON) if (WIN32) set (CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS TRUE) @@ -18,6 +12,22 @@ endif () set(CMAKE_PREFIX_PATH ${CMAKE_TORCH_PATH}) find_package(Torch REQUIRED) + +# Load CUDA if it is installed +find_package(CUDAToolkit) +find_package(CUDA) + +add_definitions(-DTORCH_VERSION_MAJOR=${Torch_VERSION_MAJOR}) +add_definitions(-DTORCH_VERSION_MINOR=${Torch_VERSION_MINOR}) +add_definitions(-DTORCH_VERSION_PATCH=${Torch_VERSION_PATCH}) + +if (Torch_VERSION_MAJOR GREATER_EQUAL 1 AND Torch_VERSION_MINOR GREATER_EQUAL 4) + #add_definitions(-DENABLE_TORCH_COMPILE) +endif () +if (USE_CUDA) + add_definitions(-DUSE_CUDA) +endif () + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") add_library(l4casadi SHARED src/l4casadi.cpp include/l4casadi.hpp) diff --git a/libl4casadi/include/l4casadi.hpp b/libl4casadi/include/l4casadi.hpp index a0d277d..b0a92ec 100644 --- a/libl4casadi/include/l4casadi.hpp +++ b/libl4casadi/include/l4casadi.hpp @@ -14,7 +14,7 @@ class L4CasADi int cols_out; public: L4CasADi(std::string, std::string, int, int, int, int, std::string = "cpu", bool = false, bool = false, bool = false, bool = false, - bool = false); + bool = false,bool = false); ~L4CasADi(); void forward(const double*, double*); void jac(const double*, double*); @@ -26,6 +26,8 @@ class L4CasADi // PImpl Idiom class L4CasADiImpl; + class L4CasADiScriptedImpl; + class L4CasADiCompiledImpl; std::unique_ptr pImpl; }; diff --git a/libl4casadi/src/l4casadi.cpp b/libl4casadi/src/l4casadi.cpp index e7f7ed2..3dc586c 100644 --- a/libl4casadi/src/l4casadi.cpp +++ b/libl4casadi/src/l4casadi.cpp @@ -6,10 +6,12 @@ #include //#include +#if ENABLE_TORCH_COMPILE +#include + #if USE_CUDA #include -#else -#include +#endif #endif #include "l4casadi.hpp" @@ -18,32 +20,24 @@ torch::Device cpu(torch::kCPU); class L4CasADi::L4CasADiImpl { - std::string model_path; - std::string model_prefix; +protected: + std::string path; + std::string function_name; bool has_jac; bool has_adj1; bool has_jac_adj1; bool has_hess; - torch::jit::script::Module adj1_model; - torch::jit::script::Module forward_model; - torch::jit::script::Module jac_model; - torch::jit::script::Module jac_adj1_model; - torch::jit::script::Module hess_model; + bool is_mutable; torch::Device device; - std::thread online_model_reloader_thread; - std::mutex model_update_mutex; - std::atomic reload_model_loop_running = false; - public: - L4CasADiImpl(std::string model_path, std::string model_prefix, std::string device, bool has_jac, bool has_adj1, - bool has_jac_adj1, bool has_hess, bool model_is_mutable): device{torch::kCPU}, model_path{model_path}, - model_prefix{model_prefix}, has_jac{has_jac}, has_adj1{has_adj1}, has_jac_adj1{has_jac_adj1}, - has_hess{has_hess} { - + L4CasADiImpl(std::string path, std::string function_name, std::string device, bool has_jac, bool has_adj1, + bool has_jac_adj1, bool has_hess, bool is_mutable): device{torch::kCPU}, path{path}, + function_name{function_name}, has_jac{has_jac}, has_adj1{has_adj1}, has_jac_adj1{has_jac_adj1}, + has_hess{has_hess}, is_mutable(is_mutable) { if (torch::cuda::is_available() && device.compare("cpu")) { std::cout << "CUDA is available! Using GPU " << device << "." << std::endl; this->device = torch::Device(device); @@ -56,16 +50,183 @@ class L4CasADi::L4CasADiImpl } else { this->device = torch::Device(device); } + } + virtual torch::Tensor forward(torch::Tensor) = 0; + virtual torch::Tensor jac(torch::Tensor) = 0; + virtual torch::Tensor adj1(torch::Tensor, torch::Tensor) = 0; + virtual torch::Tensor jac_adj1(torch::Tensor, torch::Tensor) = 0; + virtual torch::Tensor hess(torch::Tensor) = 0; + + virtual ~L4CasADiImpl() = default; +}; + +#if ENABLE_TORCH_COMPILE +class L4CasADi::L4CasADiCompiledImpl : public L4CasADi::L4CasADiImpl +{ + std::unique_ptr forward_model; + std::unique_ptr jac_model; + std::unique_ptr adj1_model; + std::unique_ptr jac_adj1_model; + std::unique_ptr hess_model; + + std::mutex model_update_mutex; + +public: + L4CasADiCompiledImpl(std::string path, std::string function_name, std::string device, bool has_jac, bool has_adj1, + bool has_jac_adj1, bool has_hess, bool is_mutable): L4CasADiImpl(path, function_name, device, has_jac, + has_adj1, has_jac_adj1, has_hess, is_mutable) { this->load_model_from_disk(); - if (model_is_mutable) { + if (is_mutable) { + throw std::invalid_argument("Mutable functions are not yet supported for compiled models."); + } + } + + ~L4CasADiCompiledImpl() = default; + + void load_model_from_disk() { + std::filesystem::path dir (this->path); + std::filesystem::path forward_model_file (this->function_name + ".so"); +#if USE_CUDA + if (this-> device == cpu) { + this->forward_model = std::make_unique((dir / forward_model_file).generic_string()); + } + else { + this->forward_model = std::make_unique((dir / forward_model_file).generic_string()); + } +#else + this->forward_model = std::make_unique((dir / forward_model_file).generic_string()); +#endif + if (this->has_adj1) { + std::filesystem::path adj1_model_file ("adj1_" + this->function_name + ".so"); +#if USE_CUDA + if (this-> device == cpu) { + this->adj1_model = std::make_unique((dir / adj1_model_file).generic_string()); + } + else { + this->adj1_model = std::make_unique((dir / adj1_model_file).generic_string()); + } +#else + this->adj1_model = std::make_unique((dir / adj1_model_file).generic_string()); +#endif + } + + if (this->has_jac_adj1) { + std::filesystem::path jac_adj1_model_file ("jac_adj1_" + this->function_name + ".so"); +#if USE_CUDA + if (this-> device == cpu) { + this->jac_adj1_model = std::make_unique((dir / jac_adj1_model_file).generic_string()); + } + else { + this->jac_adj1_model = std::make_unique((dir / jac_adj1_model_file).generic_string()); + } +#else + this->jac_adj1_model = std::make_unique((dir / jac_adj1_model_file).generic_string()); +#endif + } + + if (this->has_jac) { + std::filesystem::path jac_model_file ("jac_" + this->function_name + ".so"); +#if USE_CUDA + if (this-> device == cpu) { + this->jac_model = std::make_unique((dir / jac_model_file).generic_string()); + } + else { + this->jac_model = std::make_unique((dir / jac_model_file).generic_string()); + } +#else + this->jac_model = std::make_unique((dir / jac_model_file).generic_string()); +#endif + } + + if (this->has_hess) { + std::filesystem::path hess_model_file ("jac_jac_" + this->function_name + ".so"); +#if USE_CUDA + if (this-> device == cpu) { + this->hess_model = std::make_unique((dir / hess_model_file).generic_string()); + } + else { + this->hess_model = std::make_unique((dir / hess_model_file).generic_string()); + } +#else + this->hess_model = std::make_unique((dir / hess_model_file).generic_string()); +#endif + } + } + + torch::Tensor forward(torch::Tensor x) { + std::unique_lock lock(this->model_update_mutex); + c10::InferenceMode guard; + std::vector inputs; + inputs.push_back(x); + auto out = this->forward_model->run(inputs)[0].to(cpu); + return out; + } + + torch::Tensor jac(torch::Tensor x) { + std::unique_lock lock(this->model_update_mutex); + c10::InferenceMode guard; + std::vector inputs; + inputs.push_back(x.to(this->device)); + return this->jac_model->run(inputs)[0].to(cpu); + } + + torch::Tensor adj1(torch::Tensor primal, torch::Tensor tangent) { + std::unique_lock lock(this->model_update_mutex); + c10::InferenceMode guard; + std::vector inputs; + inputs.push_back(primal.to(this->device)); + inputs.push_back(tangent.to(this->device)); + return this->adj1_model->run(inputs)[0].to(cpu); + } + + torch::Tensor jac_adj1(torch::Tensor primal, torch::Tensor tangent){ + std::unique_lock lock(this->model_update_mutex); + c10::InferenceMode guard; + std::vector inputs; + inputs.push_back(primal.to(this->device)); + inputs.push_back(tangent.to(this->device)); + return this->jac_adj1_model->run(inputs)[0].to(cpu); + } + + torch::Tensor hess(torch::Tensor x) { + std::unique_lock lock(this->model_update_mutex); + c10::InferenceMode guard; + std::vector inputs; + inputs.push_back(x.to(this->device)); + return this->hess_model->run(inputs)[0].to(cpu); + } + +}; +#endif + +class L4CasADi::L4CasADiScriptedImpl : public L4CasADi::L4CasADiImpl +{ + torch::jit::script::Module adj1_model; + torch::jit::script::Module forward_model; + torch::jit::script::Module jac_model; + torch::jit::script::Module jac_adj1_model; + torch::jit::script::Module hess_model; + + std::thread online_model_reloader_thread; + std::mutex model_update_mutex; + std::atomic reload_model_loop_running = false; + +public: + L4CasADiScriptedImpl(std::string path, std::string function_name, std::string device, bool has_jac, bool has_adj1, + bool has_jac_adj1, bool has_hess, bool is_mutable): L4CasADiImpl(path, function_name, device, has_jac, + has_adj1, has_jac_adj1, has_hess, is_mutable) { + + this->load_model_from_disk(); + + if (is_mutable) { this->reload_model_loop_running = true; - this->online_model_reloader_thread = std::thread(&L4CasADiImpl::reload_runner, this); + this->online_model_reloader_thread = std::thread(&L4CasADiScriptedImpl::reload_runner, this); } } - ~ L4CasADiImpl() { + ~ L4CasADiScriptedImpl() { if (this->reload_model_loop_running == true) { this->reload_model_loop_running = false; this->online_model_reloader_thread.join(); @@ -73,8 +234,8 @@ class L4CasADi::L4CasADiImpl } void reload_runner() { - std::filesystem::path dir (this->model_path); - std::filesystem::path reload_file (this->model_prefix + ".reload"); + std::filesystem::path dir (this->path); + std::filesystem::path reload_file (this->function_name + ".reload"); while(this->reload_model_loop_running) { std::this_thread::sleep_for(std::chrono::milliseconds(200)); @@ -87,15 +248,15 @@ class L4CasADi::L4CasADiImpl } void load_model_from_disk() { - std::filesystem::path dir (this->model_path); - std::filesystem::path forward_model_file (this->model_prefix + ".pt"); + std::filesystem::path dir (this->path); + std::filesystem::path forward_model_file (this->function_name + ".pt"); this->forward_model = torch::jit::load((dir / forward_model_file).generic_string()); this->forward_model.to(this->device); this->forward_model.eval(); this->forward_model = torch::jit::optimize_for_inference(this->forward_model); if (this->has_adj1) { - std::filesystem::path adj1_model_file ("adj1_" + this->model_prefix + ".pt"); + std::filesystem::path adj1_model_file ("adj1_" + this->function_name + ".pt"); this->adj1_model = torch::jit::load((dir / adj1_model_file).generic_string()); this->adj1_model.to(this->device); this->adj1_model.eval(); @@ -103,7 +264,7 @@ class L4CasADi::L4CasADiImpl } if (this->has_jac_adj1) { - std::filesystem::path jac_adj1_model_file ("jac_adj1_" + this->model_prefix + ".pt"); + std::filesystem::path jac_adj1_model_file ("jac_adj1_" + this->function_name + ".pt"); this->jac_adj1_model = torch::jit::load((dir / jac_adj1_model_file).generic_string()); this->jac_adj1_model.to(this->device); this->jac_adj1_model.eval(); @@ -111,7 +272,7 @@ class L4CasADi::L4CasADiImpl } if (this->has_jac) { - std::filesystem::path jac_model_file ("jac_" + this->model_prefix + ".pt"); + std::filesystem::path jac_model_file ("jac_" + this->function_name + ".pt"); this->jac_model = torch::jit::load((dir / jac_model_file).generic_string()); this->jac_model.to(this->device); this->jac_model.eval(); @@ -119,7 +280,7 @@ class L4CasADi::L4CasADiImpl } if (this->has_hess) { - std::filesystem::path hess_model_file ("jac_jac_" + this->model_prefix + ".pt"); + std::filesystem::path hess_model_file ("jac_jac_" + this->function_name + ".pt"); this->hess_model = torch::jit::load((dir / hess_model_file).generic_string()); this->hess_model.to(this->device); this->hess_model.eval(); @@ -172,15 +333,23 @@ class L4CasADi::L4CasADiImpl } }; -L4CasADi::L4CasADi(std::string model_path, std::string model_prefix, int rows_in, int cols_in, int rows_out, int cols_out, - std::string device, bool has_jac, bool has_adj1, bool has_jac_adj1, bool has_hess, bool model_is_mutable): - pImpl{std::make_unique(model_path, model_prefix, device, has_jac, has_adj1, has_jac_adj1, has_hess, - model_is_mutable)}, rows_in{rows_in}, cols_in{cols_in}, rows_out{rows_out}, cols_out{cols_out} {} +L4CasADi::L4CasADi(std::string path, std::string function_name, int rows_in, int cols_in, int rows_out, int cols_out, + std::string device, bool has_jac, bool has_adj1, bool has_jac_adj1, bool has_hess, bool scripting, bool is_mutable): + rows_in{rows_in}, cols_in{cols_in}, rows_out{rows_out}, cols_out{cols_out} { +#if ENABLE_TORCH_COMPILE + if (scripting == true) { + this->pImpl = std::make_unique(path, function_name, device, has_jac, has_adj1, has_jac_adj1, has_hess, is_mutable); + } else { + this->pImpl = std::make_unique(path, function_name, device, has_jac, has_adj1, has_jac_adj1, has_hess, is_mutable); + } +#else + this->pImpl = std::make_unique(path, function_name, device, has_jac, has_adj1, has_jac_adj1, has_hess, is_mutable); +#endif +} void L4CasADi::forward(const double* x, double* out) { torch::Tensor x_tensor; x_tensor = torch::from_blob(( void * )x, {this->cols_in, this->rows_in}, at::kDouble).to(torch::kFloat).permute({1, 0}); - torch::Tensor out_tensor = this->pImpl->forward(x_tensor).to(torch::kDouble).permute({1, 0}).contiguous(); std::memcpy(out, out_tensor.data_ptr(), out_tensor.numel() * sizeof(double)); } @@ -188,7 +357,6 @@ void L4CasADi::forward(const double* x, double* out) { void L4CasADi::jac(const double* x, double* out) { torch::Tensor x_tensor; x_tensor = torch::from_blob(( void * )x, {this->cols_in, this->rows_in}, at::kDouble).to(torch::kFloat).permute({1, 0}); - // CasADi expects the return in Fortran order -> Transpose last two dimensions torch::Tensor out_tensor = this->pImpl->jac(x_tensor).to(torch::kDouble).permute({3, 2, 1, 0}).contiguous(); std::memcpy(out, out_tensor.data_ptr(), out_tensor.numel() * sizeof(double)); diff --git a/setup.py b/setup.py index 6f1938a..6e61f82 100644 --- a/setup.py +++ b/setup.py @@ -33,7 +33,7 @@ def compile_hook(manifest): setup( cmake_process_manifest_hook=compile_hook, cmake_source_dir='libl4casadi', - cmake_args=[f'-DCMAKE_TORCH_PATH={os.path.dirname(os.path.abspath(torch.__file__))}'], + cmake_args=['-DCMAKE_BUILD_TYPE=Release', f'-DCMAKE_TORCH_PATH={os.path.dirname(os.path.abspath(torch.__file__))}'], include_package_data=True, package_data={'': [ 'lib/**.dylib',