-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #3 from wanmok/project-skeleton
Updated project skeleton
- Loading branch information
Showing
20 changed files
with
684 additions
and
373 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,19 +1,53 @@ | ||
[tool.ruff] | ||
line-length = 120 | ||
src = ["src", "tests"] | ||
|
||
select = [ | ||
"E", # pycodestyle | ||
"F", # pyflakes | ||
"UP", # pyupgrade | ||
"D", # pydocstyle | ||
] | ||
ignore = [ | ||
"D105", # Missing docstring in magic method | ||
] | ||
|
||
fixable = ["ALL"] | ||
|
||
[tool.ruff.pydocstyle] | ||
convention = "numpy" | ||
|
||
[tool.black] | ||
line-length = 120 | ||
|
||
[tool.pyright] | ||
include = ["src", "tests"] | ||
|
||
[tool.pytest.ini_options] | ||
minversion = "6.0" | ||
testpaths = ["tests"] | ||
|
||
[tool.poetry] | ||
name = "unified-metric" | ||
version = "0.1.0" | ||
description = "" | ||
authors = ["Tongfei Chen <[email protected]>"] | ||
license = "MIT" | ||
description = "A Unified View of Evaluation Metrics for Information Extraction" | ||
authors = ["Tongfei Chen <[email protected]>", "Yunmo Chen <[email protected]>"] | ||
readme = "README.md" | ||
packages = [{include = "unimetric"}] | ||
packages = [{ include = "unimetric", from = "src" }] | ||
|
||
[tool.poetry.dependencies] | ||
python = "^3.8" | ||
numpy = "^1.19" | ||
scipy = "^1.9" | ||
python = ">=3.9,<3.13" | ||
scipy = "^1.11.2" | ||
numpy = "^1.25.2" | ||
pytest = "^7.4.2" | ||
|
||
[tool.poetry.group.dev] | ||
optional = true | ||
|
||
[tool.poetry.group.dev.dependencies] | ||
pytest = "^7.4.0" | ||
pytest = "^7.4.2" | ||
pyright = "^1.1.327" | ||
ruff = "^0.0.290" | ||
|
||
[build-system] | ||
requires = ["poetry-core"] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""The core functionality of unimetric package.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Core functionality for the unimetric package that provides the automatic metric derivation.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
"""Decorator for deriving metrics from dataclasses.""" | ||
from dataclasses import fields, is_dataclass | ||
from typing import ( | ||
Literal, | ||
get_args, | ||
get_origin, | ||
Collection, | ||
Annotated, | ||
Union, | ||
Protocol, | ||
runtime_checkable, | ||
Callable, | ||
TypeVar, | ||
Any, | ||
) | ||
|
||
from unimetric.core.alignment import AlignmentConstraint, AlignmentMetric | ||
from unimetric.core.latent_alignment import dataclass_has_variable, LatentAlignmentMetric | ||
from unimetric.core.metric import Metric, ProductMetric, DiscreteMetric, FScore, Jaccard, Precision, Recall, UnionMetric | ||
|
||
T = TypeVar("T", covariant=True)

# Names accepted for the ``normalizer`` argument of ``unimetric``. "precision" and "recall" are listed
# because the decorator's lookup table contains them; omitting them makes type checkers reject valid calls.
NormalizerLiteral = Literal["none", "jaccard", "dice", "f1", "precision", "recall"]
# Spellings accepted for the ``constraint`` argument of ``unimetric``: the arrow forms and the
# cardinality forms ("1:1", "1:*", "*:1", "*:*") are both keys of the decorator's lookup table.
ConstraintLiteral = Literal["<->", "<-", "->", "~", "1:1", "1:*", "*:1", "*:*"]
|
||
|
||
@runtime_checkable
class HasMetric(Protocol):
    """Structural type for any object that exposes a ``metric`` attribute.

    The protocol is runtime-checkable, so ``isinstance(obj, HasMetric)`` is allowed. That runtime check
    only verifies that the attribute is present; it does not verify that the value is a ``Metric``.
    """

    metric: Metric
|
||
|
||
@runtime_checkable
class HasLatentMetric(Protocol):
    """Structural type for any object that exposes a ``latent_metric`` attribute.

    The protocol is runtime-checkable, so ``isinstance(obj, HasLatentMetric)`` is allowed. That runtime
    check only verifies that the attribute is present; it does not verify that the value is a ``Metric``.
    """

    latent_metric: Metric
|
||
|
||
def derive_metric(cls: Any, constraint: AlignmentConstraint) -> Metric:
    """Derive a unified metric from any type.

    The first rule that matches wins:

    1. ``Annotated[X, ...]`` whose metadata contains a ``Metric`` instance: that instance is returned.
       If no metadata item is a ``Metric``, the annotation is stripped and derivation continues on ``X``.
    2. An object exposing a ``metric`` attribute: that attribute is returned.
    3. An object exposing a ``latent_metric`` attribute: that attribute is returned.
    4. A dataclass: a ``ProductMetric`` over the metrics derived for each of its fields.
    5. A union (``Union[A, B]``, ``Optional[A]`` or ``A | B``): a ``UnionMetric`` over the cases.
    6. A parameterized collection such as ``list[X]``: a ``LatentAlignmentMetric`` when
       ``dataclass_has_variable(X)`` is true, otherwise an ``AlignmentMetric``. Only the first type
       argument is used as the element type.
    7. Anything else that defines ``__eq__``: a ``DiscreteMetric``.

    Parameters
    ----------
    cls : Any
        The type (or dataclass-like class) to derive the metric from.
    constraint : AlignmentConstraint
        The alignment constraint to use; it is forwarded unchanged to every nested derivation.

    Returns
    -------
    Metric
        The derived metric.

    Raises
    ------
    ValueError
        If none of the rules applies to ``cls``.
    """
    import types  # function-scope import: only needed to recognise PEP 604 unions below

    # Rule 1: honour a metric given through Annotated metadata, wherever it sits in the metadata list.
    if get_origin(cls) is Annotated:
        underlying, *metadata = get_args(cls)
        for annotation in metadata:
            if isinstance(annotation, Metric):
                return annotation
        # No metric in the metadata: keep deriving from the wrapped type, so that e.g.
        # Annotated[list[X], "doc"] is still treated as a collection rather than compared by equality.
        cls = underlying

    # Rule 2: an explicitly defined metric takes precedence over anything derived structurally.
    if isinstance(cls, HasMetric):
        return cls.metric

    # Rule 3: same for an explicitly defined latent metric.
    if isinstance(cls, HasLatentMetric):
        return cls.latent_metric

    # Rule 4: derive a product metric from a dataclass, one component per field.
    if is_dataclass(cls):
        return ProductMetric(
            cls=cls,
            field_metrics={fld.name: derive_metric(fld.type, constraint=constraint) for fld in fields(cls)},
        )

    cls_origin = get_origin(cls)

    # Rule 5: derive a union metric. ``A | B`` reports ``types.UnionType`` as its origin instead of
    # ``typing.Union``; the attribute does not exist on Python 3.9, hence the getattr with a default.
    pep604_union = getattr(types, "UnionType", None)
    if cls_origin is Union or (pep604_union is not None and cls_origin is pep604_union):
        return UnionMetric(
            cls=cls,
            case_metrics={case: derive_metric(case, constraint=constraint) for case in get_args(cls)},
        )

    # Rule 6: derive an alignment metric from a parameterized collection.
    if cls_origin is not None and isinstance(cls_origin, type) and issubclass(cls_origin, Collection):
        elem_type = get_args(cls)[0]
        inner_metric = derive_metric(elem_type, constraint=constraint)
        if dataclass_has_variable(elem_type):
            return LatentAlignmentMetric(
                cls=elem_type,
                inner=inner_metric,
                constraint=constraint,
            )
        return AlignmentMetric(
            inner=inner_metric,
            constraint=constraint,
        )

    # Rule 7: fall back to a discrete metric based on equality.
    if getattr(cls, "__eq__", None) is not None:
        return DiscreteMetric(cls=cls)

    raise ValueError(f"Could not derive metric from type {cls}.")
|
||
|
||
def unimetric(
    normalizer: NormalizerLiteral = "none",
    constraint: ConstraintLiteral = "<->",
) -> Callable[[T], T]:
    """Build a class decorator that derives a metric for a class and attaches it to that class.

    Parameters
    ----------
    normalizer : NormalizerLiteral
        Name of the wrapper applied to the derived metric, by default "none" (no wrapping).
        "dice" and "f1" both select ``FScore``.
    constraint : ConstraintLiteral
        The alignment constraint to use, by default "<->". Arrow spellings and cardinality
        spellings ("1:1", "1:*", "*:1", "*:*") are both accepted.

    Returns
    -------
    Callable[[T], T]
        A decorator that returns the same class object after setting ``latent_metric`` on it when
        ``dataclass_has_variable(cls)`` is true, or ``metric`` otherwise.

    Raises
    ------
    KeyError
        Raised when the returned decorator is applied, if ``constraint`` or ``normalizer`` is not a
        recognised key.
    """

    def class_decorator(cls: T) -> T:
        # Translate the textual constraint into the enum member used by the derivation.
        constraint_table = {
            "<->": AlignmentConstraint.ONE_TO_ONE,
            "<-": AlignmentConstraint.ONE_TO_MANY,
            "->": AlignmentConstraint.MANY_TO_ONE,
            "~": AlignmentConstraint.MANY_TO_MANY,
            "1:1": AlignmentConstraint.ONE_TO_ONE,
            "1:*": AlignmentConstraint.ONE_TO_MANY,
            "*:1": AlignmentConstraint.MANY_TO_ONE,
            "*:*": AlignmentConstraint.MANY_TO_MANY,
        }
        resolved_constraint = constraint_table[constraint]
        base_metric = derive_metric(cls, constraint=resolved_constraint)

        # Pick the wrapper after derivation, so an unknown normalizer only fails once derivation succeeded.
        normalizer_table = {
            "none": lambda x: x,
            "jaccard": Jaccard,
            "dice": FScore,
            "f1": FScore,
            "precision": Precision,
            "recall": Recall,
        }
        wrap = normalizer_table[normalizer]
        final_metric = wrap(base_metric)

        # The attribute name decides which protocol (HasLatentMetric / HasMetric) the class will satisfy.
        attribute_name = "latent_metric" if dataclass_has_variable(cls) else "metric"
        setattr(cls, attribute_name, final_metric)
        return cls

    return class_decorator
Oops, something went wrong.