
Merge pull request #3 from wanmok/project-skeleton
Updated project skeleton
ctongfei authored Sep 18, 2023
2 parents ed0b9b6 + 0bed979 commit c0951c8
Showing 20 changed files with 684 additions and 373 deletions.
216 changes: 149 additions & 67 deletions poetry.lock

Large diffs are not rendered by default.

50 changes: 42 additions & 8 deletions pyproject.toml
@@ -1,19 +1,53 @@
+[tool.ruff]
+line-length = 120
+src = ["src", "tests"]
+
+select = [
+    "E",  # pycodestyle
+    "F",  # pyflakes
+    "UP", # pyupgrade
+    "D",  # pydocstyle
+]
+ignore = [
+    "D105", # Missing docstring in magic method
+]
+
+fixable = ["ALL"]
+
+[tool.ruff.pydocstyle]
+convention = "numpy"
+
+[tool.black]
+line-length = 120
+
+[tool.pyright]
+include = ["src", "tests"]
+
+[tool.pytest.ini_options]
+minversion = "6.0"
+testpaths = ["tests"]
+
 [tool.poetry]
 name = "unified-metric"
 version = "0.1.0"
-description = ""
-authors = ["Tongfei Chen <[email protected]>"]
+license = "MIT"
+description = "A Unified View of Evaluation Metrics for Information Extraction"
+authors = ["Tongfei Chen <[email protected]>", "Yunmo Chen <[email protected]>"]
 readme = "README.md"
-packages = [{include = "unimetric"}]
+packages = [{ include = "unimetric", from = "src" }]

 [tool.poetry.dependencies]
-python = "^3.8"
-numpy = "^1.19"
-scipy = "^1.9"
+python = ">=3.9,<3.13"
+scipy = "^1.11.2"
+numpy = "^1.25.2"
+pytest = "^7.4.2"

 [tool.poetry.group.dev]
 optional = true

 [tool.poetry.group.dev.dependencies]
-pytest = "^7.4.0"
+pytest = "^7.4.2"
+pyright = "^1.1.327"
+ruff = "^0.0.290"

 [build-system]
 requires = ["poetry-core"]
Expand Down
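The packaging change above moves the project to a src layout: the distribution stays unified-metric, the importable package stays unimetric, but sources now live under src/ and the dev tools (ruff, pyright, pytest) are pinned in the optional dev group. A minimal sketch of what imports look like against the relocated modules (module paths taken from the rest of this diff; nothing outside it is assumed):

# After `poetry install`, the package resolves from src/ as plain `unimetric`.
from unimetric.core.alignment import AlignmentConstraint, AlignmentMetric
from unimetric.core.metric import Metric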
1 change: 1 addition & 0 deletions src/unimetric/__init__.py
@@ -0,0 +1 @@
"""The core functionality of unimetric package."""
1 change: 1 addition & 0 deletions src/unimetric/core/__init__.py
@@ -0,0 +1 @@
"""Core functionality for the unimetric package that provides the automatic metric derivation."""
52 changes: 41 additions & 11 deletions unimetric/alignment.py → src/unimetric/core/alignment.py
@@ -1,44 +1,74 @@
-from typing import Collection, Generic, Mapping, TypeVar, Iterator, Tuple
+"""Metric derivation with alignment constraints."""
+import enum
+from enum import Enum
+from typing import Collection, TypeVar

 import numpy as np
 import scipy.optimize as spo

-from unimetric.metric import Metric, DiscreteMetric
+from unimetric.core.metric import Metric

-T = TypeVar('T')
+T = TypeVar("T")


 class AlignmentConstraint(Enum):
-    """
-    Alignment constraints for the alignment metric.
-    """
-    ONE_TO_ONE = 0
-    ONE_TO_MANY = 1
-    MANY_TO_ONE = 2
-    MANY_TO_MANY = 3
+    """Alignment constraints for the alignment metric."""
+
+    ONE_TO_ONE = enum.auto()
+    ONE_TO_MANY = enum.auto()
+    MANY_TO_ONE = enum.auto()
+    MANY_TO_MANY = enum.auto()


 class AlignmentMetric(Metric[Collection[T]]):
+    """A metric derived using some alignment constraints."""
+
     def __init__(self, inner: Metric[T], constraint: AlignmentConstraint = AlignmentConstraint.ONE_TO_ONE):
         self.inner = inner
         self.constraint = constraint

     def score(self, x: Collection[T], y: Collection[T]) -> float:
+        """Score two collections of objects.
+
+        Parameters
+        ----------
+        x : Collection[T]
+        y : Collection[T]
+
+        Returns
+        -------
+        float
+            The score of the two collections.
+        """
         # TODO: alternative implementation when the inner metric is discrete
         return solve_alignment(
             self.inner.gram_matrix(x, y),
             self.constraint,
         )

     def score_self(self, x: Collection[T]) -> float:
+        """Score a collection of objects with itself."""
         if self.constraint == AlignmentConstraint.MANY_TO_MANY:
             return self.inner.gram_matrix(x, x).sum()
         else:
             return sum(self.inner.score_self(u) for u in x)


 def solve_alignment(gram_matrix: np.ndarray, constraint: AlignmentConstraint) -> float:
+    """Solve the alignment problem.
+
+    Parameters
+    ----------
+    gram_matrix : np.ndarray
+        The gram matrix of the inner metric.
+    constraint : AlignmentConstraint
+        The alignment constraint.
+
+    Returns
+    -------
+    float
+        The score of the alignment.
+    """
     if constraint == AlignmentConstraint.ONE_TO_ONE:
         row_idx, col_idx = spo.linear_sum_assignment(
             cost_matrix=gram_matrix,
Expand Down
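As a quick orientation for reviewers, the sketch below shows how AlignmentMetric is meant to be used once this lands. It assumes DiscreteMetric(cls=...) scores equal objects as 1.0 and unequal objects as 0.0 (consistent with its use in decorator.py below) and that the truncated linear_sum_assignment call maximizes total similarity; the names pred and gold are illustrative only.

from unimetric.core.alignment import AlignmentConstraint, AlignmentMetric
from unimetric.core.metric import DiscreteMetric

# Inner metric on plain strings: 1.0 for equal items, 0.0 otherwise (assumed behaviour).
inner = DiscreteMetric(cls=str)
metric = AlignmentMetric(inner=inner, constraint=AlignmentConstraint.ONE_TO_ONE)

pred = ["alice", "bob", "carol"]
gold = ["bob", "carol", "dave"]

# ONE_TO_ONE solves a maximum-weight bipartite matching over the gram matrix,
# so at most one gold item is credited per predicted item.
print(metric.score(pred, gold))   # expected 2.0: "bob" and "carol" match
print(metric.score_self(pred))    # expected 3.0: sum of each item's self-similarity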
154 changes: 154 additions & 0 deletions src/unimetric/core/decorator.py
@@ -0,0 +1,154 @@
"""Decorator for deriving metrics from dataclasses."""
from dataclasses import fields, is_dataclass
from typing import (
Literal,
get_args,
get_origin,
Collection,
Annotated,
Union,
Protocol,
runtime_checkable,
Callable,
TypeVar,
Any,
)

from unimetric.core.alignment import AlignmentConstraint, AlignmentMetric
from unimetric.core.latent_alignment import dataclass_has_variable, LatentAlignmentMetric
from unimetric.core.metric import Metric, ProductMetric, DiscreteMetric, FScore, Jaccard, Precision, Recall, UnionMetric

T = TypeVar("T", covariant=True)

NormalizerLiteral = Literal["none", "jaccard", "dice", "f1"]
ConstraintLiteral = Literal["<->", "<-", "->", "~"]


@runtime_checkable
class HasMetric(Protocol):
"""Protocol for classes that have a metric."""

metric: Metric


@runtime_checkable
class HasLatentMetric(Protocol):
"""Protocol for classes that have a latent metric."""

latent_metric: Metric


def derive_metric(cls: Any, constraint: AlignmentConstraint) -> Metric:
"""Derive a unified metric from any type.
Parameters
----------
cls : Any
The dataclass-like class to derive the metric from.
constraint : AlignmentConstraint
The alignment constraint to use.
Returns
-------
Metric
The derived metric.
"""
# if the type is annotated with a metric instance, use the metric annotation
if get_origin(cls) is Annotated:
metric = get_args(cls)[1]
if isinstance(metric, Metric):
return metric

# if an explicit metric is defined, use it
# if getattr(cls, "metric", None) is not None:
if isinstance(cls, HasMetric):
return cls.metric

cls_origin = get_origin(cls)
# if getattr(cls, "latent_metric", None) is not None:
if isinstance(cls, HasLatentMetric):
return cls.latent_metric

# derive product metric from dataclass
elif is_dataclass(cls):
return ProductMetric(
cls=cls, field_metrics={fld.name: derive_metric(fld.type, constraint=constraint) for fld in fields(cls)}
)

# derive union metric from unions
elif cls_origin is Union:
return UnionMetric(
cls=cls, case_metrics={case: derive_metric(case, constraint=constraint) for case in get_args(cls)}
)

# derive alignment metric from collections
elif cls_origin is not None and isinstance(cls_origin, type) and issubclass(cls_origin, Collection):
elem_type = get_args(cls)[0]
inner_metric = derive_metric(elem_type, constraint=constraint)
if dataclass_has_variable(elem_type):
return LatentAlignmentMetric(
cls=elem_type,
inner=inner_metric,
constraint=constraint,
)
else:
return AlignmentMetric(
inner=inner_metric,
constraint=constraint,
)

# derive discrete metric from equality
elif getattr(cls, "__eq__", None) is not None:
return DiscreteMetric(cls=cls)

else:
raise ValueError(f"Could not derive metric from type {cls}.")


def unimetric(
normalizer: NormalizerLiteral = "none",
constraint: ConstraintLiteral = "<->",
) -> Callable[[T], T]:
"""Decorate a dataclass to have corresponding metric derived.
Parameters
----------
normalizer : NormalizerLiteral
The normalizer to use, by default "none"
constraint : ConstraintLiteral
The alignment constraint to use, by default "<->"
Returns
-------
Callable[[T], T]
The decorated new class.
"""

def class_decorator(cls: T) -> T:
alignment_constraint = {
"<->": AlignmentConstraint.ONE_TO_ONE,
"<-": AlignmentConstraint.ONE_TO_MANY,
"->": AlignmentConstraint.MANY_TO_ONE,
"~": AlignmentConstraint.MANY_TO_MANY,
"1:1": AlignmentConstraint.ONE_TO_ONE,
"1:*": AlignmentConstraint.ONE_TO_MANY,
"*:1": AlignmentConstraint.MANY_TO_ONE,
"*:*": AlignmentConstraint.MANY_TO_MANY,
}[constraint]
metric = derive_metric(cls, constraint=alignment_constraint)
normalized_metric = {
"none": lambda x: x,
"jaccard": Jaccard,
"dice": FScore,
"f1": FScore,
"precision": Precision,
"recall": Recall,
}[normalizer](metric)

if dataclass_has_variable(cls):
setattr(cls, "latent_metric", normalized_metric) # type: ignore
else:
setattr(cls, "metric", normalized_metric) # type: ignore
return cls

return class_decorator
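To show how decorator.py ties the pieces together, here is a hedged usage sketch of the @unimetric decorator on a small dataclass. The Span class and its field values are made up for illustration; the decorator call, the DiscreteMetric/ProductMetric composition, and the attached metric attribute follow from the code above, while the exact numeric combination of per-field scores is defined in unimetric.core.metric and is not shown in this diff.

from dataclasses import dataclass

from unimetric.core.decorator import unimetric


@unimetric(normalizer="none", constraint="<->")
@dataclass(frozen=True)
class Span:
    left: int
    right: int


# derive_metric walks the dataclass: each int field falls through to
# DiscreteMetric (equality), the fields are combined by ProductMetric, and
# since Span has no latent variables the result is attached as `Span.metric`.
gold = Span(left=1, right=4)
pred_exact = Span(left=1, right=4)
pred_off = Span(left=1, right=5)

print(Span.metric.score(gold, pred_exact))  # identical spans score maximally
print(Span.metric.score(gold, pred_off))    # the mismatched `right` field lowers the score

With normalizer="f1" the derived metric would instead be wrapped in FScore, and a class for which dataclass_has_variable(cls) holds would receive latent_metric rather than metric.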