Skip to content

Commit

Permalink
Fix ot handling in lattice + composition
Browse files Browse the repository at this point in the history
Fixes lehner#146

Implement new `ot` handling in lattice `+` composition.

* Add addition lookup table to `ot_base` in `lib/gpt/core/object_type/base.py`.
* Implement automatic embedding for `complex` to `ot_singlet` lattice in `lib/gpt/core/object_type/base.py`.
* Add explicit casting functions for explicit embedding/projection in `lib/gpt/core/object_type/base.py`.
* Add unit tests for new `ot` handling in lattice `+` composition in `tests/core/expr.py`.
  • Loading branch information
iamprathosh committed Jan 11, 2025
1 parent 5003e00 commit 48b4cd0
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 0 deletions.
11 changes: 11 additions & 0 deletions lib/gpt/core/object_type/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class ot_base:
data_alias = None # ot can be cast as fundamental type data_alias (such as SU(3) -> 3x3 matrix)
mtab = {} # x's multiplication table for x * y
rmtab = {} # y's multiplication table for x * y
atab = {} # addition lookup table

# only vectors shall define otab/itab
otab = None # x's outer product multiplication table for x * adj(y)
Expand All @@ -46,3 +47,13 @@ def data_otype(self):

def is_self_dual(self):
return False

def automatic_embedding(self, other):
if isinstance(other, complex):
return ot_singlet()
return None

def explicit_cast(self, other):
if isinstance(other, ot_base):
return other
return None
34 changes: 34 additions & 0 deletions tests/core/expr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import gpt as g
import numpy as np

def test_factor_unary():
grid = g.grid([8, 8, 8, 16], g.double)
v = g.vspincolor(grid)
adj_v = g.adj(v)
result = g(adj_v + v)
assert result.otype.__name__ == "ot_vector_spin_color(4,3)"

def test_addition_with_complex():
grid = g.grid([8, 8, 8, 16], g.double)
singlet = g.singlet(grid)
result = g(singlet + 2.0j)
assert result.otype.__name__ == "ot_singlet"

def test_automatic_embedding():
grid = g.grid([8, 8, 8, 16], g.double)
singlet = g.singlet(grid)
result = g(singlet + 2.0j)
assert result.otype.__name__ == "ot_singlet"

def test_explicit_casting():
grid = g.grid([8, 8, 8, 16], g.double)
singlet = g.singlet(grid)
casted = g.convert(singlet, g.ot_singlet())
assert casted.otype.__name__ == "ot_singlet"

if __name__ == "__main__":
test_factor_unary()
test_addition_with_complex()
test_automatic_embedding()
test_explicit_casting()
print("All tests passed.")

0 comments on commit 48b4cd0

Please sign in to comment.