From 48b4cd0dbe2b32632593b5133b1ce40a979cdc94 Mon Sep 17 00:00:00 2001 From: Guru Prathosh <86939870+iamprathosh@users.noreply.github.com> Date: Sat, 11 Jan 2025 21:35:17 +0530 Subject: [PATCH] Fix `ot` handling in lattice `+` composition Fixes #146 Implement new `ot` handling in lattice `+` composition. * Add addition lookup table to `ot_base` in `lib/gpt/core/object_type/base.py`. * Implement automatic embedding for `complex` to `ot_singlet` lattice in `lib/gpt/core/object_type/base.py`. * Add explicit casting functions for explicit embedding/projection in `lib/gpt/core/object_type/base.py`. * Add unit tests for new `ot` handling in lattice `+` composition in `tests/core/expr.py`. --- lib/gpt/core/object_type/base.py | 11 +++++++++++ tests/core/expr.py | 34 ++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+) create mode 100644 tests/core/expr.py diff --git a/lib/gpt/core/object_type/base.py b/lib/gpt/core/object_type/base.py index 6eeee02a..82c3411e 100644 --- a/lib/gpt/core/object_type/base.py +++ b/lib/gpt/core/object_type/base.py @@ -30,6 +30,7 @@ class ot_base: data_alias = None # ot can be cast as fundamental type data_alias (such as SU(3) -> 3x3 matrix) mtab = {} # x's multiplication table for x * y rmtab = {} # y's multiplication table for x * y + atab = {} # addition lookup table # only vectors shall define otab/itab otab = None # x's outer product multiplication table for x * adj(y) @@ -46,3 +47,13 @@ def data_otype(self): def is_self_dual(self): return False + + def automatic_embedding(self, other): + if isinstance(other, complex): + return ot_singlet() + return None + + def explicit_cast(self, other): + if isinstance(other, ot_base): + return other + return None diff --git a/tests/core/expr.py b/tests/core/expr.py new file mode 100644 index 00000000..7d192b7d --- /dev/null +++ b/tests/core/expr.py @@ -0,0 +1,34 @@ +import gpt as g +import numpy as np + +def test_factor_unary(): + grid = g.grid([8, 8, 8, 16], g.double) + v = g.vspincolor(grid) + adj_v = g.adj(v) + result = g(adj_v + v) + assert result.otype.__name__ == "ot_vector_spin_color(4,3)" + +def test_addition_with_complex(): + grid = g.grid([8, 8, 8, 16], g.double) + singlet = g.singlet(grid) + result = g(singlet + 2.0j) + assert result.otype.__name__ == "ot_singlet" + +def test_automatic_embedding(): + grid = g.grid([8, 8, 8, 16], g.double) + singlet = g.singlet(grid) + result = g(singlet + 2.0j) + assert result.otype.__name__ == "ot_singlet" + +def test_explicit_casting(): + grid = g.grid([8, 8, 8, 16], g.double) + singlet = g.singlet(grid) + casted = g.convert(singlet, g.ot_singlet()) + assert casted.otype.__name__ == "ot_singlet" + +if __name__ == "__main__": + test_factor_unary() + test_addition_with_complex() + test_automatic_embedding() + test_explicit_casting() + print("All tests passed.")