Skip to content

Commit

Permalink
always returning log
Browse files Browse the repository at this point in the history
  • Loading branch information
anushka255 committed Oct 24, 2024
1 parent 2cf75c9 commit 1bdf271
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 22 deletions.
33 changes: 13 additions & 20 deletions src/paste3/paste.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,6 @@ def pairwise_align(
m=s,
G0=G_init,
loss_fun="square_loss",
log=True,
numItermax=maxIter if s else numItermax,
use_gpu=use_gpu,
)
Expand Down Expand Up @@ -446,7 +445,6 @@ def my_fused_gromov_wasserstein(
G0=None,
loss_fun="square_loss",
armijo=False,
log=False,
numItermax=200,
tol_rel=1e-9,
tol_abs=1e-9,
Expand Down Expand Up @@ -474,8 +472,7 @@ def my_fused_gromov_wasserstein(
" equal to min(|p|_1, |q|_1)."
)

if log:
_log = {"err": []}
_log = {"err": []}
count = 0
dummy = 1
_p = torch.cat([p, torch.Tensor([(q.sum() - m) / dummy] * dummy).to(p.device)])
Expand Down Expand Up @@ -512,10 +509,9 @@ def df(G):
def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
if m:
nonlocal count
if log:
# keep track of error only on every 10th iteration
if count % 10 == 0:
_log["err"].append(torch.norm(deltaG))
# keep track of error only on every 10th iteration
if count % 10 == 0:
_log["err"].append(torch.norm(deltaG))
count += 1

if armijo:
Expand Down Expand Up @@ -557,25 +553,22 @@ def lp_solver(a, b, M, **kwargs):
lp_solver,
line_search,
G0,
log=log,
log=True,
numItermax=numItermax,
stopThr=tol_rel,
stopThr2=tol_abs,
**kwargs,
)

if log:
res, log = return_val
if m:
log["partial_fgw_cost"] = log["loss"][-1]
log["err"] = _log["err"]
else:
log["fgw_dist"] = log["loss"][-1]
log["u"] = log["u"]
log["v"] = log["v"]
return res, log
res, log = return_val
if m:
log["partial_fgw_cost"] = log["loss"][-1]
log["err"] = _log["err"]
else:
return return_val
log["fgw_dist"] = log["loss"][-1]
log["u"] = log["u"]
log["v"] = log["v"]
return res, log


def solve_gromov_linesearch(
Expand Down
1 change: 0 additions & 1 deletion tests/test_paste.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,6 @@ def test_fused_gromov_wasserstein(slices, spot_distance_matrix):
alpha=0.1,
G0=None,
loss_fun="square_loss",
log=True,
numItermax=10,
)
pd.DataFrame(pairwise_info).to_csv(
Expand Down
1 change: 0 additions & 1 deletion tests/test_paste2.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ def test_partial_fused_gromov_wasserstein(slices, armijo, expected_log, filename
G0=None,
loss_fun="square_loss",
armijo=armijo,
log=True,
)

assert np.allclose(
Expand Down

0 comments on commit 1bdf271

Please sign in to comment.