Skip to content

Commit

Permalink
Make the SpecifiedLayout class opaque.
Browse files Browse the repository at this point in the history
Also need to enabling pickling to xc.Layout so that AOT serialization continues to work.

PiperOrigin-RevId: 583684299
  • Loading branch information
yashk2810 authored and copybara-github committed Nov 18, 2023
1 parent c4a813e commit a01af1a
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 2 deletions.
2 changes: 1 addition & 1 deletion xla/python/xla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@

# Just an internal arbitrary increasing number to help with backward-compatible
# changes. In JAX, reference this via jax._src.lib.xla_extension_version.
_version = 216
_version = 217

# Version number for MLIR:Python components.
mlir_api_version = 54
Expand Down
20 changes: 19 additions & 1 deletion xla/python/xla_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_module_group.h"
#include "xla/hlo/ir/hlo_sharding.h"
#include "xla/layout.h"
#include "xla/layout_util.h"
#include "xla/python/exceptions.h"
#include "xla/python/py_client.h"
Expand Down Expand Up @@ -299,7 +300,24 @@ void BuildXlaCompilerSubmodule(py::module& m) {
const Layout& other) { return layout != other; })
.def("__hash__",
[](const Layout& layout) { return absl::HashOf(layout); })
.def("to_string", &Layout::ToString);
.def("to_string", &Layout::ToString)
.def(py::pickle(
[](const Layout& self) -> py::tuple {
auto proto = self.ToProto();
std::string result;
if (!tsl::SerializeToStringDeterministic(proto, &result)) {
// throw converted by PyBind to a Python RuntimeError.
throw XlaRuntimeError(
absl::StrCat("Layout.py_pickle: ",
"SerializeToStringDeterministic failed"));
}
return py::make_tuple(py::bytes(result));
},
[](py::tuple t) {
LayoutProto result;
result.ParseFromString(t[0].cast<std::string>());
return Layout::CreateFromProto(result);
}));

py::class_<Shape> shape_class(m, "Shape");
shape_class
Expand Down

0 comments on commit a01af1a

Please sign in to comment.