Skip to content

Commit

Permalink
Merge branch 'main' into refactor/arg_check
Browse files Browse the repository at this point in the history
  • Loading branch information
selmanozleyen authored Jul 3, 2024
2 parents add9ca6 + b864111 commit 085eae9
Show file tree
Hide file tree
Showing 9 changed files with 8 additions and 67 deletions.
4 changes: 0 additions & 4 deletions src/moscot/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,6 @@
"pnorm_p",
"sq_pnorm",
"cosine",
"elastic_l1",
"elastic_l2",
"elastic_stvs",
"elastic_sqk_overlap",
"geodesic",
]
OttCostFnMap_t = Union[OttCostFn_t, Mapping[Literal["xy", "x", "y"], OttCostFn_t]]
Expand Down
4 changes: 0 additions & 4 deletions src/moscot/backends/ott/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,3 @@
register_cost("cosine", backend="ott")(costs.Cosine)
register_cost("pnorm_p", backend="ott")(costs.PNormP)
register_cost("sq_pnorm", backend="ott")(costs.SqPNorm)
register_cost("elastic_l1", backend="ott")(costs.ElasticL1)
register_cost("elastic_l2", backend="ott")(costs.ElasticL2)
register_cost("elastic_stvs", backend="ott")(costs.ElasticSTVS)
register_cost("elastic_sqk_overlap", backend="ott")(costs.ElasticSqKOverlap)
4 changes: 1 addition & 3 deletions tests/backends/ott/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,7 @@ def test_solver_rank(self, y: Geom_t, rank: Optional[int], initializer: str):
np.testing.assert_allclose(solver._problem.geom.cost_matrix, problem.geom.cost_matrix, rtol=RTOL, atol=ATOL)
np.testing.assert_allclose(gt.matrix, pred.transport_matrix, rtol=RTOL, atol=ATOL)

@pytest.mark.parametrize(
("rank", "cost_fn"), [(2, costs.Euclidean()), (3, costs.SqPNorm(p=1.5)), (5, costs.ElasticL1(0.1))]
)
@pytest.mark.parametrize(("rank", "cost_fn"), [(2, costs.Euclidean()), (3, costs.SqPNorm(p=1.5))])
def test_geometry_rank(self, x: Geom_t, rank: int, cost_fn: costs.CostFn):
eps = 0.05
geom = PointCloud(x, epsilon=eps, cost_fn=cost_fn).to_LRCGeometry(rank=rank)
Expand Down
4 changes: 0 additions & 4 deletions tests/costs/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@ class TestCostUtils:
"cosine",
"pnorm_p",
"sq_pnorm",
"elastic_l1",
"elastic_l2",
"elastic_stvs",
"elastic_sqk_overlap",
),
}

Expand Down
17 changes: 1 addition & 16 deletions tests/problems/generic/test_fgw_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,7 @@
import numpy as np
import pandas as pd
from ott.geometry import epsilon_scheduler
from ott.geometry.costs import (
Cosine,
ElasticL1,
ElasticL2,
ElasticSTVS,
Euclidean,
PNormP,
SqEuclidean,
SqPNorm,
)
from ott.geometry.costs import Cosine, Euclidean, PNormP, SqEuclidean, SqPNorm
from ott.solvers.linear import acceleration

from anndata import AnnData
Expand Down Expand Up @@ -191,9 +182,6 @@ def test_set_xy(self, adata_time: AnnData, tag: Literal["cost_matrix", "kernel"]
("cosine", Cosine, {}),
("pnorm_p", PNormP, {"p": 3}),
("sq_pnorm", SqPNorm, {"xy": {"p": 5}, "x": {"p": 3}, "y": {"p": 4}}),
("elastic_l1", ElasticL1, {"scaling_reg": 1.1}),
("elastic_l2", ElasticL2, {"scaling_reg": 1.1}),
("elastic_stvs", ElasticSTVS, {"scaling_reg": 1.2}),
],
)
def test_prepare_costs(self, adata_time: AnnData, cost_str: str, cost_inst: Any, cost_kwargs: CostKwargs_t):
Expand Down Expand Up @@ -233,9 +221,6 @@ def test_prepare_costs(self, adata_time: AnnData, cost_str: str, cost_inst: Any,
("cosine", Cosine, {}),
("pnorm_p", PNormP, {"p": 3}),
("sq_pnorm", SqPNorm, {"xy": {"p": 5}, "x": {"p": 3}, "y": {"p": 4}}),
("elastic_l1", ElasticL1, {"scaling_reg": 1.1}),
("elastic_l2", ElasticL2, {"scaling_reg": 1.1}),
("elastic_stvs", ElasticSTVS, {"scaling_reg": 1.2}),
],
)
def test_prepare_costs_with_callback(
Expand Down
17 changes: 1 addition & 16 deletions tests/problems/generic/test_gw_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,7 @@
import numpy as np
import pandas as pd
from ott.geometry import epsilon_scheduler
from ott.geometry.costs import (
Cosine,
ElasticL1,
ElasticL2,
ElasticSTVS,
Euclidean,
PNormP,
SqEuclidean,
SqPNorm,
)
from ott.geometry.costs import Cosine, Euclidean, PNormP, SqEuclidean, SqPNorm
from ott.solvers.linear import acceleration

from anndata import AnnData
Expand Down Expand Up @@ -165,9 +156,6 @@ def test_pass_arguments(self, adata_space_rotate: AnnData, args_to_check: Mappin
("cosine", Cosine, {}),
("pnorm_p", PNormP, {"p": 3}),
("sq_pnorm", SqPNorm, {"x": {"p": 3}, "y": {"p": 4}}),
("elastic_l1", ElasticL1, {"x": {"scaling_reg": 3}, "y": {"scaling_reg": 4}}),
("elastic_l2", ElasticL2, {"x": {"scaling_reg": 3}, "y": {"scaling_reg": 4}}),
("elastic_stvs", ElasticSTVS, {"x": {"scaling_reg": 3}, "y": {"scaling_reg": 4}}),
],
)
def test_prepare_costs(self, adata_time: AnnData, cost_str: str, cost_inst: Any, cost_kwargs: CostKwargs_t):
Expand Down Expand Up @@ -206,9 +194,6 @@ def test_prepare_marginals(self, adata_time: AnnData, marginal_keys):
("cosine", Cosine, {}),
("pnorm_p", PNormP, {"p": 3}),
("sq_pnorm", SqPNorm, {"x": {"p": 3}, "y": {"p": 4}}),
("elastic_l1", ElasticL1, {"x": {"scaling_reg": 3}, "y": {"scaling_reg": 4}}),
("elastic_l2", ElasticL2, {"x": {"scaling_reg": 3}, "y": {"scaling_reg": 4}}),
("elastic_stvs", ElasticSTVS, {"x": {"scaling_reg": 3}, "y": {"scaling_reg": 4}}),
],
)
def test_prepare_costs_with_callback(
Expand Down
17 changes: 1 addition & 16 deletions tests/problems/generic/test_sinkhorn_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,7 @@
import numpy as np
import pandas as pd
from ott.geometry import epsilon_scheduler
from ott.geometry.costs import (
Cosine,
ElasticL1,
ElasticL2,
ElasticSTVS,
Euclidean,
PNormP,
SqEuclidean,
SqPNorm,
)
from ott.geometry.costs import Cosine, Euclidean, PNormP, SqEuclidean, SqPNorm
from ott.solvers.linear import acceleration

from anndata import AnnData
Expand Down Expand Up @@ -79,9 +70,6 @@ def test_solve_balanced(self, adata_time: AnnData, marginal_keys):
("cosine", Cosine, {}),
("pnorm_p", PNormP, {"p": 3}),
("sq_pnorm", SqPNorm, {"p": 3}),
("elastic_l1", ElasticL1, {"scaling_reg": 1.1}),
("elastic_l2", ElasticL2, {"scaling_reg": 1.1}),
("elastic_stvs", ElasticSTVS, {"scaling_reg": 1.2}),
],
)
def test_prepare_costs(self, adata_time: AnnData, cost_str: str, cost_inst: Any, cost_kwargs: Mapping[str, int]):
Expand All @@ -104,9 +92,6 @@ def test_prepare_costs(self, adata_time: AnnData, cost_str: str, cost_inst: Any,
("cosine", Cosine, {}),
("pnorm_p", PNormP, {"p": 3}),
("sq_pnorm", SqPNorm, {"p": 3}),
("elastic_l1", ElasticL1, {"scaling_reg": 1.1}),
("elastic_l2", ElasticL2, {"scaling_reg": 1.1}),
("elastic_stvs", ElasticSTVS, {"scaling_reg": 1.2}),
],
)
def test_prepare_costs_with_callback(
Expand Down
4 changes: 4 additions & 0 deletions tests/problems/time/test_mixins.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
from typing import Tuple

import pytest
Expand Down Expand Up @@ -235,6 +236,7 @@ def test_compute_interpolated_distance_pipeline(self, gt_temporal_adata: AnnData
assert isinstance(interpolation_result, float)
assert interpolation_result > 0

@pytest.mark.skipif(sys.version_info < (3, 9), reason="requires python3.9 or higher")
def test_compute_interpolated_distance_regression(self, gt_temporal_adata: AnnData):
config = gt_temporal_adata.uns
key = config["key"]
Expand Down Expand Up @@ -262,6 +264,7 @@ def test_compute_interpolated_distance_regression(self, gt_temporal_adata: AnnDa
interpolation_result, gt_temporal_adata.uns["interpolated_distance_10_105_11"], rtol=1e-6, atol=1e-6
)

@pytest.mark.skipif(sys.version_info < (3, 9), reason="requires python3.9 or higher")
def test_compute_time_point_distances_regression(self, gt_temporal_adata: AnnData):
config = gt_temporal_adata.uns
key = config["key"]
Expand Down Expand Up @@ -313,6 +316,7 @@ def test_compute_batch_distances_regression(self, gt_temporal_adata: AnnData):
assert isinstance(result, float)
np.testing.assert_allclose(result, gt_temporal_adata.uns["batch_distances_10"], rtol=1e-5)

@pytest.mark.skipif(sys.version_info < (3, 9), reason="requires python3.9 or higher")
def test_compute_random_distance_regression(self, gt_temporal_adata: AnnData):
config = gt_temporal_adata.uns
key = config["key"]
Expand Down
4 changes: 0 additions & 4 deletions tests/utils/test_tagged_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@ class TestTaggedArray:
("cosine", {}),
("pnorm_p", {"p": 3}),
("sq_pnorm", {"p": 2}),
("elastic_l1", {"scaling_reg": 1.3}),
("elastic_l2", {}),
("elastic_stvs", {}),
("elastic_sqk_overlap", {"k": 1}),
],
)
def test_from_adata_ott_cost_from_pointcloud(self, adata_time, cost: str, cost_kwargs: Mapping[str, Any]):
Expand Down

0 comments on commit 085eae9

Please sign in to comment.