diff --git a/xla/python/xla_client.py b/xla/python/xla_client.py index 9e13fe5e4b0d6..4a476977f1811 100644 --- a/xla/python/xla_client.py +++ b/xla/python/xla_client.py @@ -48,7 +48,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 216 +_version = 217 # Version number for MLIR:Python components. mlir_api_version = 54 diff --git a/xla/python/xla_compiler.cc b/xla/python/xla_compiler.cc index 2342ace4ca1c7..ac0b0933eae6a 100644 --- a/xla/python/xla_compiler.cc +++ b/xla/python/xla_compiler.cc @@ -43,6 +43,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_module_group.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/layout.h" #include "xla/layout_util.h" #include "xla/python/exceptions.h" #include "xla/python/py_client.h" @@ -299,7 +300,24 @@ void BuildXlaCompilerSubmodule(py::module& m) { const Layout& other) { return layout != other; }) .def("__hash__", [](const Layout& layout) { return absl::HashOf(layout); }) - .def("to_string", &Layout::ToString); + .def("to_string", &Layout::ToString) + .def(py::pickle( + [](const Layout& self) -> py::tuple { + auto proto = self.ToProto(); + std::string result; + if (!tsl::SerializeToStringDeterministic(proto, &result)) { + // throw converted by PyBind to a Python RuntimeError. + throw XlaRuntimeError( + absl::StrCat("Layout.py_pickle: ", + "SerializeToStringDeterministic failed")); + } + return py::make_tuple(py::bytes(result)); + }, + [](py::tuple t) { + LayoutProto result; + result.ParseFromString(t[0].cast()); + return Layout::CreateFromProto(result); + })); py::class_ shape_class(m, "Shape"); shape_class