TN equalize_norms add check_zero option and turn on for simplifying
jcmgray committed Dec 18, 2024
1 parent 1dc14d5 commit 69f28e5
Showing 1 changed file with 72 additions and 17 deletions.
quimb/tensor/tensor_core.py
@@ -9011,7 +9011,8 @@ def insert_compressor_between_regions(

# then form the 'oblique' projectors
Pl, Pr = decomp.compute_oblique_projectors(
Rl, Rr,
Rl,
Rr,
max_bond=max_bond,
cutoff=cutoff,
**compress_opts,
@@ -9389,7 +9390,7 @@ def randomize(self, dtype=None, seed=None, inplace=False, **randn_opts):

randomize_ = functools.partialmethod(randomize, inplace=True)

def strip_exponent(self, tid_or_tensor, value=None):
def strip_exponent(self, tid_or_tensor, value=None, check_zero=False):
"""Scale the elements of tensor corresponding to ``tid`` so that the
norm of the array is some value, which defaults to ``1``. The log of
the scaling factor, base 10, is then accumulated in the ``exponent``
@@ -9401,6 +9402,11 @@ def strip_exponent(self, tid_or_tensor, value=None):
The tensor identifier or actual tensor.
value : None or float, optional
The value to scale the norm of the tensor to.
check_zero : bool, optional
Whether to check if the tensor has zero norm and in that case do
nothing, since the `exponent` would be -inf. Off by default to
avoid data dependent computational graphs when tracing and
computing gradients etc.
"""
if (value is None) or (value is True):
value = 1.0
@@ -9411,6 +9417,10 @@ def strip_exponent(self, tid_or_tensor, value=None):
t = self.tensor_map[tid_or_tensor]

stripped_factor = t.norm() / value

if check_zero and (stripped_factor == 0.0):
return

t.modify(apply=lambda data: data / stripped_factor)
self.exponent = self.exponent + do("log10", stripped_factor)
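
For context, a minimal usage sketch of the new flag; the network constructor and all values here are illustrative assumptions, not part of the commit:

```python
import quimb.tensor as qtn

# hypothetical small example: a random 3-regular tensor network
tn = qtn.TN_rand_reg(n=6, reg=3, D=2, seed=42)

# scale one tensor to unit norm, moving log10 of the factor into tn.exponent
tid = next(iter(tn.tensor_map))
tn.strip_exponent(tid)
print(tn.exponent)

# zero out that tensor's data to show the guard: without check_zero this
# call would set the exponent to -inf and fill the tensor with NaNs
t = tn.tensor_map[tid]
t.modify(data=t.data * 0.0)
tn.strip_exponent(tid, check_zero=True)  # skipped, exponent unchanged
```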

@@ -9425,7 +9435,7 @@ def distribute_exponent(self):
# reset the exponent to zero
self.exponent = 0.0

def equalize_norms(self, value=None, inplace=False):
def equalize_norms(self, value=None, check_zero=False, inplace=False):
"""Make the Frobenius norm of every tensor in this TN equal without
changing the overall value if ``value=None``, or set the norm of every
tensor to ``value`` by scalar multiplication only.
@@ -9436,6 +9446,11 @@ def equalize_norms(self, value=None, inplace=False):
Set the norm of each tensor to this value specifically. If supplied
the change in overall scaling will be accumulated in
``tn.exponent`` in the form of a base 10 power.
check_zero : bool, optional
Whether, if and when equalizing norms, to check if tensors have
zero norm and in that case do nothing, since the `exponent` would
be -inf. Off by default to avoid data dependent computational
graphs when tracing and computing gradients etc.
inplace : bool, optional
Whether to perform the norm equalization inplace or not.
@@ -9446,7 +9461,7 @@
tn = self if inplace else self.copy()

for tid in tn.tensor_map:
tn.strip_exponent(tid, value=value)
tn.strip_exponent(tid, value=value, check_zero=check_zero)

if value is None:
tn.distribute_exponent()
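
A corresponding sketch for `equalize_norms` itself, under the same assumed constructor as above:

```python
import quimb.tensor as qtn

tn = qtn.TN_rand_reg(n=6, reg=3, D=2, seed=0)

# value=None: equalize all tensor norms without changing the overall value,
# redistributing the collected exponent back across the tensors
tne = tn.equalize_norms()

# value=1.0: set every norm to 1 and accumulate the scaling in tn.exponent,
# guarding against zero-norm tensors with the new flag
tn.equalize_norms_(value=1.0, check_zero=True)
print(tn.exponent)
```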
@@ -9591,6 +9606,7 @@ def rank_simplify(
equalize_norms=False,
cache=None,
max_combinations=500,
check_zero=False,
inplace=False,
):
"""Simplify this tensor network by performing contractions that don't
@@ -9607,6 +9623,11 @@
exponent in ``tn.exponent``.
cache : None or set
Persistent cache used to mark already checked tensors.
check_zero : bool, optional
Whether, if and when equalizing norms, to check if tensors have
zero norm and in that case do nothing, since the `exponent` would
be -inf. Off by default to avoid data dependent computational
graphs when tracing and computing gradients etc.
inplace : bool, optional
Whether to perform the rank reduction inplace.
@@ -9752,18 +9773,24 @@ def rank_weight(ind):
tn |= tab

if equalize_norms:
tn.strip_exponent(tab, equalize_norms)
tn.strip_exponent(tab, equalize_norms, check_zero=check_zero)

for ix in out_ab:
# now we need to check outputs indices again
queue.add(ix)

if scalars:
if equalize_norms:
# move overall scaling factor into exponent, absorb phase
signs = []
for s in scalars:
signs.append(s / do("abs", s))
tn.exponent += do("log10", do("abs", s))
sa = do("abs", s)
if check_zero and (sa == 0.0):
# whole contraction is zero
signs = [0.0]
break

signs.append(s / sa)
tn.exponent += do("log10", sa)
scalars = signs

if tn.num_tensors:
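
The default of `check_zero=False` in the low-level routines exists because a value-dependent branch, like the `sa == 0.0` test above, cannot be traced by frameworks such as JAX. An illustrative (non-quimb) snippet:

```python
import jax
import jax.numpy as jnp

def normalize_checked(x):
    sa = jnp.linalg.norm(x)
    if sa == 0.0:  # data-dependent Python branch on a traced value
        return x
    return x / sa

# eager execution works fine
print(normalize_checked(jnp.zeros(3)))

# under jit the comparison is an abstract tracer: converting it to a
# Python bool raises an error, so such checks must be opt-in
try:
    jax.jit(normalize_checked)(jnp.ones(3))
except Exception as e:
    print("cannot trace:", type(e).__name__)
```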
@@ -10023,6 +10050,7 @@ def split_simplify(
atol=1e-12,
equalize_norms=False,
cache=None,
check_zero=False,
inplace=False,
**split_opts,
):
@@ -10039,6 +10067,11 @@
exponent in ``tn.exponent``.
cache : None or set
Persistent cache used to mark already checked tensors.
check_zero : bool, optional
Whether, if and when equalizing norms, to check if tensors have
zero norm and in that case do nothing, since the `exponent` would
be -inf. Off by default to avoid data dependent computational
graphs when tracing and computing gradients etc.
inplace : bool, optional
Whether to perform the split simplification inplace.
"""
@@ -10075,8 +10108,12 @@
tn |= tr

if equalize_norms:
tn.strip_exponent(tl, equalize_norms)
tn.strip_exponent(tr, equalize_norms)
tn.strip_exponent(
tl, equalize_norms, check_zero=check_zero
)
tn.strip_exponent(
tr, equalize_norms, check_zero=check_zero
)

else:
cache.add(cache_key)
@@ -10093,6 +10130,7 @@ def pair_simplify(
cache=None,
equalize_norms=False,
max_combinations=500,
check_zero=False,
inplace=False,
**split_opts,
):
@@ -10180,8 +10218,8 @@ def gen_pairs():

tensor_fuse_squeeze(tl, tr)
if equalize_norms:
tn.strip_exponent(tl, equalize_norms)
tn.strip_exponent(tr, equalize_norms)
tn.strip_exponent(tl, equalize_norms, check_zero=check_zero)
tn.strip_exponent(tr, equalize_norms, check_zero=check_zero)

queue.extend(tl.inds)
queue.extend(tr.inds)
@@ -10199,6 +10237,7 @@ def loop_simplify(
loops=None,
cache=None,
equalize_norms=False,
check_zero=False,
inplace=False,
**split_opts,
):
@@ -10218,6 +10257,11 @@
cache : set, optional
For performance reasons can supply a cache for already checked
loops.
check_zero : bool, optional
Whether, if and when equalizing norms, to check if tensors have
zero norm and in that case do nothing, since the `exponent` would
be -inf. Off by default to avoid data dependent computational
graphs when tracing and computing gradients etc.
inplace : bool, optional
Whether to replace the loops inplace.
split_opts
@@ -10298,8 +10342,8 @@ def loop_simplify(

tensor_fuse_squeeze(tl, tr)
if equalize_norms:
tn.strip_exponent(tl, equalize_norms)
tn.strip_exponent(tr, equalize_norms)
tn.strip_exponent(tl, equalize_norms, check_zero=check_zero)
tn.strip_exponent(tr, equalize_norms, check_zero=check_zero)

return tn
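
A hedged usage sketch for this routine; the 2D lattice constructor is assumed here simply to provide a network with many short loops:

```python
import quimb.tensor as qtn

# a random 4x4 square-lattice tensor network contains many length-4 loops
tn = qtn.TN2D_rand(4, 4, D=2, seed=3)

# compress each loop, equalizing norms as we go and skipping any
# zero-norm tensors rather than producing -inf exponents
tns = tn.loop_simplify(cutoff=1e-12, equalize_norms=1.0, check_zero=True)
```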

@@ -10312,13 +10356,14 @@ def full_simplify(
atol=1e-12,
equalize_norms=False,
cache=None,
inplace=False,
progbar=False,
rank_simplify_opts=None,
loop_simplify_opts=None,
split_simplify_opts=None,
custom_methods=(),
split_method="svd",
check_zero=True,
inplace=False,
progbar=False,
):
"""Perform a series of tensor network 'simplifications' in a loop until
there is no more reduction in the number of tensors or indices. Note
@@ -10357,6 +10402,9 @@
cache : None or set
A persistent cache for each simplification process to mark
already processed tensors.
check_zero : bool, optional
Whether, if and when equalizing norms, to check if tensors have
zero norm and in that case do nothing, rather than generating a NaN.
progbar : bool, optional
Show a live progress bar of the simplification process.
inplace : bool, optional
@@ -10422,6 +10470,7 @@ def full_simplify(
output_inds=ix_o,
cache=cache,
equalize_norms=equalize_norms,
check_zero=check_zero,
**rank_simplify_opts,
)
elif meth == "A":
@@ -10435,6 +10484,7 @@
atol=atol,
cache=cache,
equalize_norms=equalize_norms,
check_zero=check_zero,
**split_simplify_opts,
)
elif meth == "L":
@@ -10443,6 +10493,7 @@
cutoff=atol,
cache=cache,
equalize_norms=equalize_norms,
check_zero=check_zero,
**loop_simplify_opts,
)
elif meth == "P":
@@ -10451,6 +10502,7 @@
cutoff=atol,
cache=cache,
equalize_norms=equalize_norms,
check_zero=check_zero,
**loop_simplify_opts,
)
else:
@@ -10462,9 +10514,10 @@
if equalize_norms:
if equalize_norms is True:
# this also redistributes the collected exponents
tn.equalize_norms_()
value = None
else:
tn.equalize_norms_(value=equalize_norms)
value = equalize_norms

tn.equalize_norms_(value=value, check_zero=check_zero)

if progbar:
pbar.close()
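
At the top level the guard now defaults to on (`check_zero=True`). A sketch, again with an assumed example network:

```python
import quimb.tensor as qtn

tn = qtn.TN_rand_reg(n=10, reg=3, D=2, seed=7)

# run antidiagonal ('A'), diagonal ('D'), column ('C') and rank ('R')
# reductions in a loop, forwarding check_zero to every sub-routine
tns = tn.full_simplify(seq="ADCR", equalize_norms=True)
print(tns.num_tensors, tns.exponent)
```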
@@ -10594,6 +10647,7 @@ def compress_simplify(
max_simplification_iterations=100,
converged_tol=0.01,
equalize_norms=True,
check_zero=True,
progbar=False,
inplace=False,
**full_simplify_opts,
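
Finally, a sketch of the highest-level wrapper, which now forwards both flags (same assumed constructor as above):

```python
import quimb.tensor as qtn

tn = qtn.TN_rand_reg(n=10, reg=3, D=2, seed=1)

# alternate full simplification rounds until the tensor/index count stops
# improving by more than converged_tol; both flags default to True here
tnc = tn.compress_simplify(equalize_norms=True, check_zero=True)
print(tnc.num_tensors)
```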
@@ -10606,6 +10660,7 @@
simplify_opts = {
"atol": atol,
"equalize_norms": equalize_norms,
"check_zero": check_zero,
"progbar": progbar,
"output_inds": output_inds,
"cache": set(),
