-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathexplth.py
131 lines (101 loc) · 4.45 KB
/
explth.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
from sympy import Function, Matrix, diff, Expr
class ExpLTH(Function):
    """
    The exponent of a Lie transform of a Hamiltonian which is a function of
    positions only or of momenta only.

    This way, the operator is exactly given by a first order series expansion:
    e^(t:H:)*F=F+t:H:F

    Attributes set by ``__init__``:
        variable: the coefficient ``t`` multiplying ``:H:`` in the exponent.
        hamiltonian: the Hamiltonian ``H``.
        position_vars, momentum_vars: the coordinates, indexed so that
            ``position_vars[i]`` is paired with ``momentum_vars[i]``.
        h_q_dependence / h_p_dependence: True when every coordinate appearing
            in ``H`` is a position / a momentum.
        F_dependence: the coordinates an operand must contain for the operator
            to differ from the identity.
        d_hamiltonian: column Matrix of the derivatives of ``H`` used by
            ``__mul__`` (negated when ``H`` depends on momenta).
    """

    def __init__(self, variable: Expr, hamiltonian: Expr, position_vars, momentum_vars):
        """
        Classify the Hamiltonian and precompute its derivatives.

        Raises:
            ValueError: if the Hamiltonian contains both position and
                momentum variables.
        """
        # Only the coordinates matter; other free symbols are parameters.
        h_vars = hamiltonian.free_symbols.intersection(
            [*position_vars, *momentum_vars])
        self.h_q_dependence = h_vars.issubset(position_vars)
        self.h_p_dependence = h_vars.issubset(momentum_vars)
        # Check notes for detailed math. In short, if H depends on momentum we
        # want to check F dependence on position to see whether the operator
        # is equal to the identity, and vice versa when H depends on position.
        # NOTE: when H contains no coordinate at all both flags are True, the
        # first branch is taken and F_dependence ends up empty.
        if self.h_p_dependence:
            self.F_dependence = {
                position for idx, position in enumerate(position_vars)
                if momentum_vars[idx] in h_vars
            }
            self.d_hamiltonian = -Matrix(
                [diff(hamiltonian, p) for p in momentum_vars])
        elif self.h_q_dependence:
            self.F_dependence = {
                momentum for idx, momentum in enumerate(momentum_vars)
                if position_vars[idx] in h_vars
            }
            self.d_hamiltonian = Matrix(
                [diff(hamiltonian, q) for q in position_vars])
        else:
            raise ValueError(
                "Hamiltonian is not a function of only position or momenta.")
        self.hamiltonian = hamiltonian
        self.variable = variable
        self.position_vars = position_vars
        self.momentum_vars = momentum_vars

    def copy_replace_variable(self, new_variable):
        """Return a new ExpLTH built from ``new_variable`` and this object's remaining args."""
        return ExpLTH(new_variable, *self.args[1:])

    def __mul__(self, F: Expr):
        """
        Apply the operator to ``F`` or chain it with another operator.

        Returns:
            * a Matrix of the element-wise results when ``F`` is iterable;
            * a ProductExpLTH when ``F`` is itself an ExpLTH;
            * ``F`` unchanged when it contains none of ``F_dependence``;
            * otherwise ``F + variable * (d_hamiltonian . dF)``.
        """
        # Guard only the iteration itself, so that a TypeError raised while
        # transforming an element propagates instead of being mistaken for
        # "F is not iterable".
        try:
            elements = list(F)
        except TypeError:  # It is not iterable. Continue as normal
            elements = None
        if elements is not None:
            return Matrix([self * element for element in elements])

        if isinstance(F, ExpLTH):
            return ProductExpLTH(self, F)

        non_zero_vars = F.free_symbols.intersection(self.F_dependence)
        if not non_zero_vars:
            # Identity operator
            return F

        # H(q) differentiates F with respect to momenta, H(p) with respect to
        # positions.
        vars_to_diff = self.momentum_vars if self.h_q_dependence else self.position_vars
        gradient = Matrix(
            [diff(F, var) if var in non_zero_vars else 0 for var in vars_to_diff])
        return F + self.variable * self.d_hamiltonian.dot(gradient)

    def __pow__(self, other):
        """
        Repeat the operator ``other`` times.

        Returns:
            ``1`` for a zero power, otherwise a ProductExpLTH holding
            ``other`` copies of this operator.

        Raises:
            TypeError: if ``other`` is not an ``int``.
            ValueError: if ``other`` is negative.
        """
        if not isinstance(other, int):
            raise TypeError("ExpLTH raised to the power of a non-integer")
        elif other < 0:
            raise ValueError(
                "ExpLTH raised to the power of a non-positive integer")
        if other == 0:
            return 1  # Identity
        # Build the product from repeated copies of self; multiplying self by
        # the integer would go through __mul__ and fail on a plain int.
        return ProductExpLTH(*([self] * other))

    def __str__(self):
        """Return the plain-text form ``exp(<variable>:<hamiltonian>:)``."""
        return f"exp({self.variable}:{self.hamiltonian}:)"

    def _latex(self, printer):
        """Return the LaTeX form, printing both parts with ``printer``."""
        var = printer._print(self.variable)
        H = printer._print(self.hamiltonian)
        return r"\operatorname{exp}{\left( %s:%s: \right)}" % (var, H)
class ProductExpLTH(Function):
    """
    Class to handle multiple ExpLTHs multiplied in a row.

    In this case, we keep storing them until we execute it on something.
    The stored operators live in ``self.terms``, in left-to-right order.
    """

    def __init__(self, *args: ExpLTH):
        """Store the given operators, in order, as ``self.terms``."""
        self.terms = list(args)

    def __str__(self):
        """Return the terms' string forms separated by ``' * '``."""
        # join() places the separator only between terms. Stripping a trailing
        # ' * ' with rstrip would remove a character set, not a suffix.
        return ' * '.join(str(term) for term in self.terms)

    def _latex(self, printer):
        """Return the terms' LaTeX forms concatenated without a separator."""
        return ''.join(term._latex(printer) for term in self.terms)

    def copy_replace_variable(self, new_variable):
        """Return a new product in which every term has had its variable replaced."""
        return ProductExpLTH(
            *[term.copy_replace_variable(new_variable) for term in self.terms])

    def __mul__(self, other):
        """
        Extend the product or apply it to ``other``.

        Returns:
            A longer ProductExpLTH when ``other`` is a ProductExpLTH or an
            ExpLTH; otherwise the result of applying every stored term to
            ``other``.
        """
        if isinstance(other, ProductExpLTH):
            return ProductExpLTH(*self.terms, *other.terms)
        if isinstance(other, ExpLTH):
            return ProductExpLTH(*self.terms, other)
        # The right-most operator acts on the operand first.
        result = other
        for term in reversed(self.terms):
            result = term * result
        return result

    def __pow__(self, other):
        """
        Repeat the whole product ``other`` times.

        Returns:
            ``1`` for a zero power, otherwise a ProductExpLTH whose terms are
            this product's terms repeated ``other`` times.

        Raises:
            TypeError: if ``other`` is not an ``int``.
            ValueError: if ``other`` is negative.
        """
        if not isinstance(other, int):
            raise TypeError("ExpLTH raised to the power of a non-integer")
        elif other < 0:
            raise ValueError(
                "ExpLTH raised to the power of a non-positive integer")
        if other == 0:
            return 1  # Identity
        return ProductExpLTH(*(self.terms * other))