Skip to content

Commit

Permalink
feat(sim/graphs): add dtw alignment (#214)
Browse files Browse the repository at this point in the history
  • Loading branch information
Martin-Hoppe authored Jan 8, 2025
1 parent cfc1e86 commit 95e46eb
Show file tree
Hide file tree
Showing 3 changed files with 334 additions and 42 deletions.
202 changes: 160 additions & 42 deletions src/cbrkit/sim/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,20 +86,40 @@ def __call__(self, x: Sequence[V], y: Sequence[V]) -> float:
return 0.0


@dataclass(slots=True, frozen=True)
class SequenceSim[V, S: Float](StructuredValue[float]):
"""
A class representing sequence similarity with optional mapping and similarity scores.
Attributes:
value: The overall similarity score as a float.
similarities: Optional local similarity scores as a sequence of floats.
mapping: Optional alignment information as a sequence of tuples.
"""

value: float
similarities: Sequence[S] | None = field(default=None)
mapping: Sequence[tuple[V, V]] | None = field(default=None)


@dataclass(slots=True)
class dtw[V](SimFunc[Collection[V] | np.ndarray, float]):
"""Dynamic Time Warping similarity function.
class dtw[V](SimFunc[Collection[V] | np.ndarray, SequenceSim[V, float]]):
"""
Dynamic Time Warping similarity function with optional backtracking for alignment.
Examples:
>>> sim = dtw()
>>> sim([1, 2, 3], [1, 2, 3, 4])
0.5
SequenceSim(value=0.5, similarities=None, mapping=None)
>>> sim = dtw(distance_func=lambda a, b: abs(a - b))
>>> sim([1, 2, 3], [3, 4, 5])
0.14285714285714285
SequenceSim(value=0.14285714285714285, similarities=None, mapping=None)
>>> sim = dtw(distance_func=lambda a, b: abs(len(str(a)) - len(str(b))))
>>> sim(["a", "bb", "ccc", "ddd", "ee", "fff"], ["cffffffffffcj", "dfffffffffffffffffffffffffffffffffded"])
SequenceSim(value=0.011235955056179775, similarities=None, mapping=None)
>>> sim(["a", "bb", "ccc"], ["aa", "bbb", "c"], return_alignment=True)
SequenceSim(value=0.25, similarities=[0.5, 1.0, 1.0, 0.3333333333333333], mapping=[('a', 'aa'), ('bb', 'aa'), ('ccc', 'bbb'), ('ccc', 'c')])
>>> sim = dtw(distance_func=lambda a, b: abs(a - b))
>>> sim([1, 2, 3], [1, 2, 3, 4], return_alignment=True)
SequenceSim(value=0.5, similarities=[1.0, 1.0, 1.0, 0.5], mapping=[(1, 1), (2, 2), (3, 3), (3, 4)])
"""

distance_func: Callable[[V, V], float] | None = None
Expand All @@ -108,21 +128,53 @@ def __call__(
self,
x: Collection[V] | np.ndarray,
y: Collection[V] | np.ndarray,
) -> float:
return_alignment: bool = False,
) -> SequenceSim[V, float]:
"""
Perform DTW and optionally return alignment information.
Args:
x: The first sequence as a collection or numpy array.
y: The second sequence as a collection or numpy array.
return_alignment: Whether to compute and return the alignment (default: False).
Returns:
A SequenceSim object containing the similarity value and, only when return_alignment is True, the local similarities and the alignment (both are None otherwise).
"""
if not isinstance(x, np.ndarray):
x = np.array(x, dtype=object) # Allow non-numeric types
if not isinstance(y, np.ndarray):
y = np.array(y, dtype=object)

# Compute the DTW distance manually using the custom distance
dtw_distance = self.compute_dtw(x, y)
# Compute the DTW distance
dtw_distance, alignment, local_similarities = self.compute_dtw(
x, y, return_alignment
)

# Convert DTW distance to similarity
similarity = dist2sim(dtw_distance)

return float(similarity)
# Return SequenceSim with updated attributes
return SequenceSim(
value=float(similarity),
similarities=local_similarities,
mapping=alignment if return_alignment else None,
)

def compute_dtw(
self, x: np.ndarray, y: np.ndarray, return_alignment: bool
) -> tuple[float, list[tuple[V, V]] | None, list[float] | None]:
"""
Compute DTW distance and optionally compute the best alignment and local similarities.
Args:
x: The first sequence as a numpy array.
y: The second sequence as a numpy array.
return_alignment: Whether to compute the alignment.
def compute_dtw(self, x: np.ndarray, y: np.ndarray) -> float:
Returns:
A tuple of (DTW distance, best alignment or None, local similarities or None).
"""
n, m = len(x), len(y)
dtw_matrix = np.full((n + 1, m + 1), np.inf)
dtw_matrix[0, 0] = 0
Expand All @@ -142,7 +194,71 @@ def compute_dtw(self, x: np.ndarray, y: np.ndarray) -> float:
)
dtw_matrix[i, j] = cost + last_min

return dtw_matrix[n, m]
# If alignment is not requested, skip backtracking
if not return_alignment:
return dtw_matrix[n, m], None, None

# Backtracking to find the best alignment and local similarities
mapping, local_similarities = self.backtrack(dtw_matrix, x, y, n, m)

return dtw_matrix[n, m], mapping, local_similarities

def backtrack(
    self, dtw_matrix: np.ndarray, x: np.ndarray, y: np.ndarray, n: int, m: int
) -> tuple[list[tuple[V | None, V | None]], list[float]]:
    """
    Backtrack through the DTW matrix to recover the alignment and local similarities.

    Starting at cell (n, m), the path repeatedly steps to the cheapest
    predecessor cell. The diagonal step is taken unless a vertical or
    horizontal step is strictly cheaper, so ties resolve to the shortest
    warping path (e.g. identical sequences align one-to-one).

    Args:
        dtw_matrix: The accumulated-cost matrix, indexed up to [n, m].
        x: The first sequence as a numpy array.
        y: The second sequence as a numpy array.
        n: The length of the first sequence.
        m: The length of the second sequence.

    Returns:
        A tuple of (alignment, local similarities), both ordered from the
        start of the sequences. Elements left over once the other index has
        reached zero are paired with None and given a similarity of 0.0.
    """
    i, j = n, m
    alignment: list[tuple[V | None, V | None]] = []
    local_similarities: list[float] = []

    while i > 0 and j > 0:
        alignment.append((x[i - 1], y[j - 1]))
        cost = (
            self.distance_func(x[i - 1], y[j - 1])
            if self.distance_func
            else abs(x[i - 1] - y[j - 1])
        )
        local_similarities.append(dist2sim(cost))

        # Read the three predecessor cells once per step.
        diagonal = dtw_matrix[i - 1, j - 1]
        up = dtw_matrix[i - 1, j]
        left = dtw_matrix[i, j - 1]

        # Leave the diagonal only for a strictly cheaper neighbour; this keeps
        # the path as short as possible when several predecessors tie.
        if up < diagonal and up <= left:
            i -= 1
        elif left < diagonal:
            j -= 1
        else:
            i -= 1
            j -= 1

    # Pad with None for whatever is left of either sequence.
    while i > 0:
        alignment.append((x[i - 1], None))
        local_similarities.append(0.0)
        i -= 1
    while j > 0:
        alignment.append((None, y[j - 1]))
        local_similarities.append(0.0)
        j -= 1

    # The path was collected end-to-start; flip it into sequence order.
    alignment.reverse()
    local_similarities.reverse()
    return alignment, local_similarities

__all__ += ["dtw"]


@dataclass(slots=True, frozen=True)
Expand Down Expand Up @@ -246,12 +362,6 @@ def __call__(self, query: Sequence[V], case: Sequence[V]) -> float:
return best_score


@dataclass(slots=True, frozen=True)
class SequenceSim[S: Float](StructuredValue[float]):
value: float
local_similarities: Sequence[S] = field(default_factory=tuple)


@dataclass
class Weight:
weight: float
Expand All @@ -263,20 +373,23 @@ class Weight:


@dataclass(slots=True, frozen=True)
class sequence_mapping[V, S: Float](SimFunc[Sequence[V], SequenceSim[S]], HasMetadata):
"""List Mapping similarity function.
class sequence_mapping[V, S: Float](
SimFunc[Sequence[V], SequenceSim[V, S]], HasMetadata
):
"""
List Mapping similarity function.
Parameters:
element_similarity: The similarity function to use for comparing elements.
exact: Whether to use exact or inexact comparison. Default is False (inexact).
weights: Optional list of weights for weighted similarity calculation.
element_similarity: The similarity function to use for comparing elements.
exact: Whether to use exact or inexact comparison. Default is False (inexact).
weights: Optional list of weights for weighted similarity calculation.
Examples:
>>> sim = sequence_mapping(lambda x, y: 1.0 if x == y else 0.0, True)
>>> result = sim(["a", "b", "c"], ["a", "b", "c"])
>>> result.value
1.0
>>> result.local_similarities
>>> result.similarities
[1.0, 1.0, 1.0]
"""

Expand All @@ -285,7 +398,6 @@ class sequence_mapping[V, S: Float](SimFunc[Sequence[V], SequenceSim[S]], HasMet
weights: list[Weight] | None = None

@property
@override
def metadata(self) -> JsonDict:
return {
"element_similarity": get_metadata(self.element_similarity),
Expand All @@ -297,9 +409,9 @@ def metadata(self) -> JsonDict:

def compute_contains_exact(
self, list1: Sequence[V], list2: Sequence[V]
) -> SequenceSim[S]:
) -> SequenceSim[V, S]:
if len(list1) != len(list2):
return SequenceSim(value=0.0)
return SequenceSim(value=0.0, similarities=None, mapping=None)

sim_sum = 0.0
local_similarities: list[S] = []
Expand All @@ -310,29 +422,30 @@ def compute_contains_exact(
local_similarities.append(sim)

return SequenceSim(
value=sim_sum / len(list1), local_similarities=local_similarities
value=sim_sum / len(list1),
similarities=local_similarities,
mapping=None,
)

def compute_contains_inexact(
self, larger_list: Sequence[V], smaller_list: Sequence[V]
) -> SequenceSim[S]:
) -> SequenceSim[V, S]:
max_similarity = -1.0
best_local_similarities = []
best_local_similarities: list[S] = []

for i in range(len(larger_list) - len(smaller_list) + 1):
sublist = larger_list[i : i + len(smaller_list)]
sim_result = self.compute_contains_exact(sublist, smaller_list)

if sim_result.value > max_similarity:
max_similarity = sim_result.value
best_local_similarities = sim_result.local_similarities
best_local_similarities = sim_result.similarities or []

return SequenceSim(
value=max_similarity, local_similarities=best_local_similarities
value=max_similarity, similarities=best_local_similarities, mapping=None
)

@override
def __call__(self, x: Sequence[V], y: Sequence[V]) -> SequenceSim[S]:
def __call__(self, x: Sequence[V], y: Sequence[V]) -> SequenceSim[V, S]:
if self.exact:
result = self.compute_contains_exact(x, y)
else:
Expand All @@ -341,7 +454,7 @@ def __call__(self, x: Sequence[V], y: Sequence[V]) -> SequenceSim[S]:
else:
result = self.compute_contains_inexact(y, x)

if self.weights:
if self.weights and result.similarities:
total_weighted_sim = 0.0
total_weight = 0.0

Expand All @@ -350,21 +463,24 @@ def __call__(self, x: Sequence[V], y: Sequence[V]) -> SequenceSim[S]:
weight_range = weight.upper_bound - weight.lower_bound
weight.normalized_weight = weight.weight / weight_range

for sim in result.local_similarities:
sim = unpack_float(sim)
for sim in result.similarities:
sim_val = unpack_float(sim)

for weight in self.weights:
lower_bound = weight.lower_bound
upper_bound = weight.upper_bound
inclusive_lower = weight.inclusive_lower
inclusive_upper = weight.inclusive_upper

# Check if sim_val falls within weight's bounds
if (
(inclusive_lower and lower_bound <= sim <= upper_bound)
or (not inclusive_lower and lower_bound < sim <= upper_bound)
) and (inclusive_upper or sim < upper_bound):
(inclusive_lower and lower_bound <= sim_val <= upper_bound)
or (
not inclusive_lower and lower_bound < sim_val <= upper_bound
)
) and (inclusive_upper or sim_val < upper_bound):
assert weight.normalized_weight is not None
weighted_sim = weight.normalized_weight * sim
weighted_sim = weight.normalized_weight * sim_val
total_weighted_sim += weighted_sim
total_weight += weight.normalized_weight

Expand All @@ -374,7 +490,9 @@ def __call__(self, x: Sequence[V], y: Sequence[V]) -> SequenceSim[S]:
final_similarity = result.value

return SequenceSim(
value=final_similarity, local_similarities=result.local_similarities
value=final_similarity,
similarities=result.similarities,
mapping=result.mapping,
)

return result
Expand Down
Loading

0 comments on commit 95e46eb

Please sign in to comment.