Skip to content

Commit 2de4ecd

Browse files
anjali411facebook-github-bot
authored andcommitted
Add serialization logic for complex numbers (pytorch#50885)
Summary: Pull Request resolved: pytorch#50885 Test Plan: Imported from OSS Reviewed By: SplitInfinity Differential Revision: D26094906 Pulled By: anjali411 fbshipit-source-id: 7b2614f3ee4a30c4b4cf04aaa3432988b38a0721
1 parent 3b6f308 commit 2de4ecd

File tree

5 files changed

+35
-3
lines changed

5 files changed

+35
-3
lines changed

test/test_complex.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,18 @@ def fn(a: complex):
1212

1313
self.checkScript(fn, (3 + 5j,))
1414

15+
def test_pickle(self):
16+
class ComplexModule(torch.jit.ScriptModule):
17+
def __init__(self):
18+
super().__init__()
19+
self.a = 3 + 5j
20+
21+
def forward(self, b: int):
22+
return b
23+
24+
loaded = self.getExportImportCopy(ComplexModule())
25+
self.assertEqual(loaded.a, 3 + 5j)
26+
1527
class TestComplexTensor(TestCase):
1628
@dtypes(*torch.testing.get_all_complex_dtypes())
1729
def test_to_list(self, device, dtype):

torch/csrc/jit/python/pybind_utils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,8 @@ inline InferredType tryToInferType(py::handle input) {
293293
return InferredType(IntType::get());
294294
} else if (py::isinstance<py::float_>(input)) {
295295
return InferredType(FloatType::get());
296+
} else if (PyComplex_CheckExact(input.ptr())) {
297+
return InferredType(ComplexDoubleType::get());
296298
} else if (py::isinstance<py::str>(input)) {
297299
return InferredType(StringType::get());
298300
} else if (THPLayout_Check(input.ptr())) {

torch/csrc/jit/serialization/pickler.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ void Pickler::pushIValueImpl(const IValue& ivalue) {
4949
pushTuple(ivalue);
5050
} else if (ivalue.isDouble()) {
5151
pushDouble(ivalue.toDouble());
52+
} else if (ivalue.isComplexDouble()) {
53+
pushComplexDouble(ivalue);
5254
} else if (ivalue.isInt()) {
5355
pushInt(ivalue.toInt());
5456
} else if (ivalue.isBool()) {
@@ -464,6 +466,14 @@ void Pickler::pushDouble(double value) {
464466
// Python pickle format is big endian, swap.
465467
push<double>(swapDouble(value));
466468
}
469+
void Pickler::pushComplexDouble(const IValue& value) {
470+
c10::complex<double> d = value.toComplexDouble();
471+
pushGlobal("builtins", "complex");
472+
pushIValue(d.real());
473+
pushIValue(d.imag());
474+
push<PickleOpCode>(PickleOpCode::TUPLE2);
475+
push<PickleOpCode>(PickleOpCode::REDUCE);
476+
}
467477

468478
void Pickler::pushLong(const std::string& data) {
469479
uint64_t size = data.size();

torch/csrc/jit/serialization/pickler.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ class TORCH_API Pickler {
160160
void endTypeTag(const IValue& value);
161161
void pushBool(bool value);
162162
void pushDouble(double value);
163+
void pushComplexDouble(const IValue& value);
163164
void pushGenericList(const IValue& ivalue);
164165
void pushIntList(const IValue& ivalue);
165166
void pushList(const IValue& ivalue);

torch/csrc/jit/serialization/unpickler.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) {
5757
case StorageType::Kind:
5858
case NumberType::Kind:
5959
case FloatType::Kind:
60+
case ComplexDoubleType::Kind:
6061
case IntType::Kind:
6162
case NoneType::Kind:
6263
case GeneratorType::Kind:
@@ -80,9 +81,6 @@ void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) {
8081
case AnyEnumType::Kind:
8182
// no op, there is nothing to tag
8283
break;
83-
// TODO(@anjali411): Implement serialization/deserialization for complex
84-
// numbers
85-
case ComplexDoubleType::Kind:
8684
case EnumType::Kind:
8785
// TODO(gmagogsfm): Implement serialization/deserialization of Enum.
8886
AT_ASSERT(false);
@@ -543,6 +541,15 @@ void Unpickler::readGlobal(
543541
// Unpickle a tensor
544542
bool quantized = class_name == "_rebuild_qtensor";
545543
rebuildTensor(quantized);
544+
} else if (module_name == "builtins" && class_name == "complex") {
545+
globals_.emplace_back([this] {
546+
auto elems = pop(stack_).toTuple()->elements();
547+
AT_ASSERT(elems.size() == 2);
548+
auto complex =
549+
c10::complex<double>(elems.at(0).toDouble(), elems.at(1).toDouble());
550+
stack_.emplace_back(complex);
551+
});
552+
546553
} else if (module_name == "collections" && class_name == "OrderedDict") {
547554
// collections.OrderedDict is used in tensor serialization for a tensor's
548555
// backward hooks (but they are not actually saved with this Pickler)

0 commit comments

Comments
 (0)