Skip to content

Commit

Permalink
Merge pull request #27 from matching
Browse files Browse the repository at this point in the history
[WIP] Hooks on matching pairs
  • Loading branch information
ctongfei authored Oct 27, 2024
2 parents 431937d + f9d3db1 commit b67e1c5
Show file tree
Hide file tree
Showing 19 changed files with 1,268 additions and 1,061 deletions.
16 changes: 6 additions & 10 deletions .github/workflows/python-app.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,27 +20,23 @@ jobs:
steps:
- uses: actions/checkout@v3
- name: Setup Poetry
run: pipx install poetry==1.6.1
- name: Set up Python 3.9
run: pipx install poetry==1.8.4
- name: Set up Python 3.10
uses: actions/setup-python@v4
with:
python-version: "3.9"
python-version: "3.10"
cache: "poetry"
- name: Install dependencies
run: |
poetry env use python3.9
poetry env use python3.10
poetry install --with dev --extras torchmetrics
- name: Ensure Poetry envs
run: |
echo "$(poetry env info --path)/bin" >> $GITHUB_PATH
- name: Run ruff
uses: chartboost/ruff-action@v1
with:
version: 0.0.290
uses: astral-sh/ruff-action@v1
- name: Run pyright
uses: jakebailey/pyright-action@v1
with:
version: "1.1.327"
uses: jakebailey/pyright-action@v2
- name: Test with pytest
run: |
poetry run pytest
Expand Down
1,369 changes: 605 additions & 764 deletions poetry.lock

Large diffs are not rendered by default.

27 changes: 15 additions & 12 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,23 @@
line-length = 120
src = ["src", "tests"]

select = [
lint.select = [
"E", # pycodestyle
"F", # pyflakes
"UP", # pyupgrade
"D", # pydocstyle
]
ignore = [
lint.ignore = [
"D102", # Missing docstring in public method
"D107", # Missing docstring in `__init__`
]

fixable = ["ALL"]
lint.fixable = ["ALL"]

[tool.ruff.pydocstyle]
[tool.ruff.lint.pydocstyle]
convention = "google"

[tool.ruff.isort]
[tool.ruff.lint.isort]
lines-after-imports = 2

[tool.black]
Expand All @@ -33,7 +33,7 @@ testpaths = ["tests"]

[tool.poetry]
name = "metametric"
version = "0.1.2"
version = "0.2.0"
description = "A Unified View of Evaluation Metrics for Structured Prediction"
authors = ["Tongfei Chen <[email protected]>", "Yunmo Chen <[email protected]>", "Will Gantt <[email protected]>"]
readme = "README.md"
Expand All @@ -42,26 +42,29 @@ repository = "https://github.com/wanmok/metametric"
packages = [{ include = "metametric", from = "src" }]

[tool.poetry.dependencies]
python = ">=3.9"
scipy = "^1.11.2"
python = ">=3.10"
scipy = "^1.14.0"
numpy = "^1.25.2"
torchmetrics = { version = "^1.1.2", optional = true }

[tool.poetry.group.dev]
optional = true

[tool.poetry.group.dev.dependencies]
pytest = "^7.4.2"
pyright = "^1.1.327"
ruff = "^0.0.290"
pyright = "^1.1.385"
networkx = "^3.1"
mkdocs = "^1.5.3"
mkdocs-material = "^9.4.2"
mkdocstrings = {extras = ["python"], version = "^0.23.0"}
mkdocstrings = {extras = ["python"], version = "^0.26.2"}
ruff = "^0.7.0"
griffe = "^1.5.1"

[tool.poetry.extras]
torchmetrics = ["torchmetrics"]

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

[tool.ruff.lint.per-file-ignores]
"tests/*" = ["D100", "D103", "F841"]
2 changes: 1 addition & 1 deletion src/metametric/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""The core functionality of metametric package."""

__version__ = "0.1.0"
__version__ = "0.2.0"

from metametric.core.metric import Metric, Variable # noqa: F401
from metametric.core.reduction import Reduction # noqa: F401
Expand Down
23 changes: 16 additions & 7 deletions src/metametric/core/_ilp.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from dataclasses import dataclass, fields, is_dataclass
from typing import (Any, Callable, Collection, Generic, Iterator, List,
Optional, Sequence, Type, TypeVar)

import numpy as np
import scipy as sp

from metametric.core.constraint import MatchingConstraint
from metametric.core._problem import MatchingProblem
from metametric.core.metric import Variable

T = TypeVar('T')
Expand Down Expand Up @@ -122,6 +123,9 @@ class LatentVariableConstraintBuilder(ConstraintBuilder, Generic[T]):
cls: Type[T]
gram_matrix: np.ndarray # R[n_x, n_y]

def __post_init__(self):
assert is_dataclass(self.cls)

def build(self) -> Optional[sp.optimize.LinearConstraint]:
x_vars = list(_all_variables(self.x))
y_vars = list(_all_variables(self.y))
Expand All @@ -131,7 +135,7 @@ def build(self) -> Optional[sp.optimize.LinearConstraint]:
for i, a in enumerate(self.x):
for j, b in enumerate(self.y):
if self.gram_matrix[i, j] > 0:
for fld in fields(self.cls):
for fld in fields(self.cls): # pyright: ignore
a_fld = getattr(a, fld.name, None)
b_fld = getattr(b, fld.name, None)
if isinstance(a_fld, Variable) and isinstance(b_fld, Variable):
Expand All @@ -152,7 +156,7 @@ def build(self) -> Optional[sp.optimize.LinearConstraint]:
)


class MatchingProblem(Generic[T]):
class ILPMatchingProblem(MatchingProblem[T]):
"""Creates a matching problem that is solved by ILP.
The constrained ILP problem has variables
Expand All @@ -166,11 +170,9 @@ def __init__(
gram_matrix: np.ndarray,
has_vars: bool = False,
):
self.x = x
self.y = y
super().__init__(x, y, gram_matrix)
self.n_x = len(x)
self.n_y = len(y)
self.gram_matrix = gram_matrix
if has_vars:
self.x_vars = list(_all_variables(x))
self.y_vars = list(_all_variables(y))
Expand Down Expand Up @@ -257,7 +259,14 @@ def solve(self):
bounds=sp.optimize.Bounds(lb=0, ub=1),
integrality=np.ones_like(coef),
)
return -result.fun
solution = result.x[:self.n_x * self.n_y].reshape([self.n_x, self.n_y])
matching = [
(i, j, self.gram_matrix[i, j].item())
for i in range(self.n_x)
for j in range(self.n_y)
if solution[i, j] > 0
]
return -result.fun, matching


def _get_one_to_many_constraint_matrix(n_x: int, n_y: int) -> np.ndarray: # [Y, X * Y]
Expand Down
55 changes: 55 additions & 0 deletions src/metametric/core/_problem.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from abc import abstractmethod
from typing import Sequence, TypeVar, Generic, Tuple, Collection

import numpy as np
import scipy.optimize as spo

from metametric.core.constraint import MatchingConstraint

T = TypeVar("T")


class MatchingProblem(Generic[T]):
"""A matching between two collections of objects."""

def __init__(self, x: Sequence[T], y: Sequence[T], gram_matrix: np.ndarray):
self.x = x
self.y = y
self.gram_matrix = gram_matrix

@abstractmethod
def solve(self) -> Tuple[float, Collection[Tuple[int, int, float]]]:
"""Solves the matching problem."""
raise NotImplementedError


class AssignmentProblem(MatchingProblem[T]):
def __init__(self, x: Sequence[T], y: Sequence[T], gram_matrix: np.ndarray, constraint: MatchingConstraint):
super().__init__(x, y, gram_matrix)
self.constraint = constraint

def solve(self) -> Tuple[float, Collection[Tuple[int, int, float]]]:
m = self.gram_matrix
if self.constraint == MatchingConstraint.ONE_TO_ONE:
row_idx, col_idx = spo.linear_sum_assignment(
cost_matrix=m,
maximize=True,
)
total = m[row_idx, col_idx].sum()
matching = [(i.item(), j.item(), m[i, j].item()) for i, j in zip(row_idx, col_idx)]
return total, matching
if self.constraint == MatchingConstraint.ONE_TO_MANY:
total = m.max(axis=0).sum().item()
selected_x = m.argmax(axis=0)
matching = [(selected_x[j].item(), j, m[selected_x[j], j].item()) for j in range(m.shape[1])]
return total, matching
if self.constraint == MatchingConstraint.MANY_TO_ONE:
total = m.max(axis=1).sum().item()
selected_y = m.argmax(axis=1)
matching = [(i, selected_y[i].item(), m[i, selected_y[i]].item()) for i in range(m.shape[0])]
return total, matching
if self.constraint == MatchingConstraint.MANY_TO_MANY:
total = m.sum().item()
matching = [(i, j, m[i, j].item()) for i in range(m.shape[0]) for j in range(m.shape[1])]
return total, matching
raise ValueError(f"Invalid constraint: {self.constraint}")
15 changes: 9 additions & 6 deletions src/metametric/core/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from typing import (Annotated, Any, Callable, Collection, Literal, Type,
TypeVar, Union, get_args, get_origin, Optional)

from metametric.core.matching import (MatchingConstraint,
LatentSetMatchingMetric,
SetMatchingMetric)
from metametric.core.matching_metrics import (MatchingConstraint,
LatentSetMatchingMetric,
SetMatchingMetric)
from metametric.core.metric import (DiscreteMetric, HasLatentMetric, HasMetric,
Metric, ProductMetric, UnionMetric,
Variable)
Expand Down Expand Up @@ -37,8 +37,7 @@ def dataclass_has_variable(cls: Type) -> bool:
return False



def derive_metric(cls: Type, constraint: MatchingConstraint) -> Metric:
def derive_metric(cls: Type, constraint: MatchingConstraint) -> Metric: # dependent type, can't enforce
"""Derive a unified metric from any type.
Args:
Expand All @@ -65,7 +64,11 @@ def derive_metric(cls: Type, constraint: MatchingConstraint) -> Metric:
# derive product metric from dataclass
elif is_dataclass(cls):
return ProductMetric(
cls=cls, field_metrics={fld.name: derive_metric(fld.type, constraint=constraint) for fld in fields(cls)}
cls=cls,
field_metrics={
fld.name: derive_metric(fld.type, constraint=constraint) # pyright: ignore
for fld in fields(cls)
}
)

# derive union metric from unions
Expand Down
4 changes: 2 additions & 2 deletions src/metametric/core/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def predecessors(self, x: T) -> Iterator[T]:
raise NotImplementedError()


def _adjacency_matrix(graph: Graph[T]) -> np.ndarray:
def _adjacency_matrix(graph: Graph) -> np.ndarray:
"""Get the adjacency matrix of a graph."""
nodes = list(graph.nodes())
node_to_id = {x: i for i, x in enumerate(nodes)}
Expand All @@ -33,7 +33,7 @@ def _adjacency_matrix(graph: Graph[T]) -> np.ndarray:
return adj


def _reachability_matrix(graph: Graph[T]) -> np.ndarray:
def _reachability_matrix(graph: Graph) -> np.ndarray:
"""Get the reachability matrix of a graph."""
a = _adjacency_matrix(graph)
b = np.eye(a.shape[0], dtype=bool) + a
Expand Down
Loading

0 comments on commit b67e1c5

Please sign in to comment.