Skip to content

Commit

Permalink
Merge branch 'develop'
Browse files Browse the repository at this point in the history
  • Loading branch information
maximilian-kruse committed Feb 2, 2025
2 parents 8d1be15 + d7292f3 commit d1e8c9c
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 22 deletions.
14 changes: 12 additions & 2 deletions examples/parametric_derivatives.ipynb

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions src/eikonax/derivator.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class PartialDerivatorData:


# ==================================================================================================
class PartialDerivator(eqx.Module):
class PartialDerivator(eqx.Module, strict=True):
r"""Component for computing partial derivatives of the Godunov Update operator.
Given a tensor field $M$ and a solution vector $u$, the partial derivator computes the partial
Expand Down Expand Up @@ -586,7 +586,8 @@ class DerivativeSolver:
\mathbf{u}(\mathbf{m}) = \mathbf{G}(\mathbf{u}(\mathbf{m}), \mathbf{M}(\mathbf{m}))
$$
To obtain the Jacobian $\mathbf{J} = \frac{d\mathbf{u}}{d\mathbf{m}}\in\mathbb{R}^{N_V\times M}$,
To obtain the Jacobian
$\mathbf{J} = \frac{d\mathbf{u}}{d\mathbf{m}}\in\mathbb{R}^{N_V\times M}$,
we simply differentiate the fixed point relation,
$$
Expand Down
5 changes: 3 additions & 2 deletions src/eikonax/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,12 @@ class Solution:


# ==================================================================================================
class Solver(eqx.Module):
class Solver(eqx.Module, strict=True):
r"""Eikonax solver class.
The solver class is the main component for computing the solution $u$ of the Eikonal equation
for given geometry $\Omega$ of dimension $d$, tensor field $\mathbf{M}$, and initial sites $\Gamma$,
for given geometry $\Omega$ of dimension $d$, tensor field $\mathbf{M}$, and initial sites
$\Gamma$,
$$
\begin{gather*}
Expand Down
65 changes: 49 additions & 16 deletions src/eikonax/tensorfield.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
TensorField: Tensor field component
"""

from abc import ABC, abstractmethod
from abc import abstractmethod
from typing import final

import equinox as eqx
import jax
Expand All @@ -42,7 +43,7 @@


# ==================================================================================================
class BaseVectorToSimplicesMap(ABC, eqx.Module):
class AbstractVectorToSimplicesMap(eqx.Module, strict=True):
"""ABC interface contract for vector-to-simplices maps.
Every component derived from this class needs to implement the `map` method, which maps returns
Expand Down Expand Up @@ -82,7 +83,8 @@ def map(


# --------------------------------------------------------------------------------------------------
class LinearScalarMap(BaseVectorToSimplicesMap):
@final
class LinearScalarMap(AbstractVectorToSimplicesMap):
r"""Simple one-to-one map from global to simplex parameters.
Every simplex takes exactly one parameter $m_s$, which is sorted in the global parameter
Expand All @@ -109,7 +111,7 @@ def map(


# ==================================================================================================
class BaseSimplexTensor(ABC, eqx.Module):
class AbstractSimplexTensor(eqx.Module, strict=True):
"""ABC interface contract for assembly of the tensor field.
`SimplexTensor` components assemble the tensor field for a given simplex and a set of parameters
Expand All @@ -127,13 +129,15 @@ class BaseSimplexTensor(ABC, eqx.Module):
"""

# Equinox modules are data classes, so we have to define attributes at the class level
_dimension: int
_dimension: eqx.AbstractVar[int]

# ----------------------------------------------------------------------------------------------
def __init__(self, dimension: int) -> None:
"""Constructor, simply fixes the dimension of the tensor field."""
self._dimension = dimension
def __check_init__(self) -> None:
"""Check that dimension is initialized correctly in subclasses."""
if not isinstance(self._dimension, int):
raise TypeError("Dimension must be an integer")

# ----------------------------------------------------------------------------------------------
@abstractmethod
def assemble(
self,
Expand All @@ -158,6 +162,7 @@ def assemble(
"""
raise NotImplementedError

# ----------------------------------------------------------------------------------------------
@abstractmethod
def derivative(
self,
Expand All @@ -184,8 +189,9 @@ def derivative(
raise NotImplementedError


# --------------------------------------------------------------------------------------------------
class LinearScalarSimplexTensor(BaseSimplexTensor):
# ==================================================================================================
@final
class LinearScalarSimplexTensor(AbstractSimplexTensor):
r"""SimplexTensor implementation relying on one parameter per simplex.
Given a scalar parameter $m_s$, the tensor field is assembled as $m_s \cdot \mathbf{I}$, where
Expand All @@ -196,6 +202,18 @@ class LinearScalarSimplexTensor(BaseSimplexTensor):
derivative: Parametric derivative of the `assemble` method
"""

_dimension: int

# ----------------------------------------------------------------------------------------------
def __init__(self, dimension: int) -> None:
"""Constructor.
Args:
dimension (int): Dimension of the tensor field
"""
self._dimension = dimension

# ----------------------------------------------------------------------------------------------
def assemble(
self, _simplex_ind: jtInt[jax.Array, ""], parameters: jtFloat[jax.Array, ""]
) -> jtFloat[jax.Array, "dim dim"]:
Expand All @@ -213,6 +231,7 @@ def assemble(
tensor = parameters * jnp.identity(self._dimension, dtype=jnp.float32)
return tensor

# ----------------------------------------------------------------------------------------------
def derivative(
self, _simplex_ind: jtInt[jax.Array, ""], _parameters: jtFloat[jax.Array, ""]
) -> jtFloat[jax.Array, "dim dim num_parameters_local"]:
Expand All @@ -229,8 +248,9 @@ def derivative(
return derivative


# --------------------------------------------------------------------------------------------------
class InvLinearScalarSimplexTensor(BaseSimplexTensor):
# ==================================================================================================
@final
class InvLinearScalarSimplexTensor(AbstractSimplexTensor):
r"""SimplexTensor implementation relying on one parameter per simplex.
Given a scalar parameter $m_s$, the tensor field is assembled as
Expand All @@ -241,6 +261,18 @@ class InvLinearScalarSimplexTensor(BaseSimplexTensor):
derivative: Parametric derivative of the `assemble` method
"""

_dimension: int

# ----------------------------------------------------------------------------------------------
def __init__(self, dimension: int) -> None:
"""Constructor.
Args:
dimension (int): Dimension of the tensor field
"""
self._dimension = dimension

# ----------------------------------------------------------------------------------------------
def assemble(
self, _simplex_ind: jtInt[jax.Array, ""], parameters: jtFloat[jax.Array, ""]
) -> jtFloat[jax.Array, "dim dim"]:
Expand All @@ -258,6 +290,7 @@ def assemble(
tensor = 1 / parameters * jnp.identity(self._dimension, dtype=jnp.float32)
return tensor

# ----------------------------------------------------------------------------------------------
def derivative(
self, _simplex_ind: jtInt[jax.Array, ""], parameters: jtFloat[jax.Array, ""]
) -> jtFloat[jax.Array, "dim dim num_parameters_local"]:
Expand Down Expand Up @@ -302,15 +335,15 @@ class TensorField(eqx.Module):
# Equinox modules are data classes, so we have to define attributes at the class level
_num_simplices: int
_simplex_inds: jtFloat[jax.Array, "num_simplices"]
_vector_to_simplices_map: BaseVectorToSimplicesMap
_simplex_tensor: BaseSimplexTensor
_vector_to_simplices_map: AbstractVectorToSimplicesMap
_simplex_tensor: AbstractSimplexTensor

# ----------------------------------------------------------------------------------------------
def __init__(
self,
num_simplices: int,
vector_to_simplices_map: BaseVectorToSimplicesMap,
simplex_tensor: BaseSimplexTensor,
vector_to_simplices_map: AbstractVectorToSimplicesMap,
simplex_tensor: AbstractSimplexTensor,
) -> None:
"""Constructor.
Expand Down

0 comments on commit d1e8c9c

Please sign in to comment.