Skip to content

Commit

Permalink
Merge pull request #137 from ucl-bug/plum-update
Browse files Browse the repository at this point in the history
Plum update
  • Loading branch information
astanziola authored Aug 7, 2023
2 parents f2a4509 + 8cb1881 commit ee5ba44
Show file tree
Hide file tree
Showing 39 changed files with 1,450 additions and 1,092 deletions.
11 changes: 5 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,13 @@ repos:
hooks:
- id: pycln
args: [--config=pyproject.toml]
- repo: https://github.com/google/yapf
rev: v0.40.0
hooks:
- id: yapf
args: ['--style=pyproject.toml', '--parallel', '--in-place']
- repo: https://github.com/pycqa/isort
rev: '5.12.0'
hooks:
- id: isort
files: "\\.(py)$"
args: [--settings-path=pyproject.toml]
- repo: https://github.com/google/yapf
rev: v0.40.0
hooks:
- id: yapf
args: ['--style=pyproject.toml', '--parallel', '--in-place']
25 changes: 24 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,30 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]
### Changed
- The Quickstart tutorial has been upgdated.
- The `@new_discretization` decorator has been renamed `@discretization`
- The property `Field.ndim` has now been moved into `Field.domain.ndim`, as it is fundamentally a property of the domain
- Before, `OnGrid` fields were able to automatically add an extra dimension if needed at initialization. This however can easily clash with some of the internal operations of jax during compliation. This is now not possible, use `.from_grid` instead, which implements the same functionality.

### Removed
- The `__about__` file has been removed, as it is redundant
- The function `params_map` is removed, use `jax.tree_util.tree_map` instead.

### Added
- The new `operator.abstract` decorator can be used to define an unimplemented operator, with the goal of specifying input arguments and docstrings.
- `Linear` fields are now defined as equal if they have the same set of parameters.
- `Ongrid` fields now have the property `.add_dim`, which adds an extra tailing dimension to its parameters. The method returns a new field.
- The function `jaxdf.util.get_implemented` is now exposed to the user.
- Added `laplacian` operator for `FiniteDifferences` fields.

### Deprecated
- The property `.is_field_complex` is now deprecated in favor of `.is_complex`. Same argument for `.is_real`
- `Field.get_field` is now deprecated in favor of the `__call__` metho.

### Fixed
- `OnGrid.from_grid` now automatically adds a dimension at the end of the array for scalar fields, if needed
- Added a custom operator for `equinox.internal._omega._Metaω` objects and Fields, which makes the library compatible with `diffrax`

## [0.2.6] - 2023-06-28
### Changed
Expand Down Expand Up @@ -33,4 +57,3 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
[Unreleased]: https://github.com/ucl-bug/jaxdf/compare/0.2.6...master
[0.2.6]: https://github.com/ucl-bug/jaxdf/compare/0.2.5...0.2.6
[0.2.5]: https://github.com/ucl-bug/jaxdf/tree/0.2.5

866 changes: 463 additions & 403 deletions docs/notebooks/quickstart.ipynb

Large diffs are not rendered by default.

31 changes: 16 additions & 15 deletions docs/operators/differential.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,52 +3,52 @@

## `derivative`

Given a field $`u`$, it returns the field
Given a field $u$, it returns the field

```math
$$
\frac{\partial}{\partial \epsilon} u, \qquad \epsilon \in \{x, y, \dots \}
```
$$

{{ implementations('jaxdf.operators.differential', 'derivative') }}


## `diag_jacobian`

Given a vector field $`u = (u_x,u_y,\dots)`$ with the same dimensions as the dimensions of the domain, it returns the diagonal of the Jacobian matrix
Given a vector field $u = (u_x,u_y,\dots)$ with the same dimensions as the dimensions of the domain, it returns the diagonal of the Jacobian matrix

```math
$$
\left( \frac{\partial u_x}{\partial x}, \frac{\partial u_y}{\partial y}, \dots \right)
```
$$

{{ implementations('jaxdf.operators.differential', 'diag_jacobian') }}
## `gradient`

Given a field $`u`$, it returns the vector field
Given a field $u$, it returns the vector field

```math
$$
\nabla u = \left(\frac{\partial u}{\partial x}, \frac{\partial u}{\partial y}, \dots\right)
```
$$

{{ implementations('jaxdf.operators.differential', 'gradient') }}

## `heterog_laplacian`

Given a field $`u`$ and a cofficient field $`c`$, it returns the field
Given a field $u$ and a cofficient field $c$, it returns the field

```math
$$
\nabla_c^2 u = \nabla \cdot (c \nabla u)
```
$$

{{ implementations('jaxdf.operators.differential', 'heterog_laplacian') }}


## `laplacian`

Given a scalar field $`u`$, it returns the scalar field
Given a scalar field $u$, it returns the scalar field

```math
$$
\nabla^2 u = \nabla \cdot \nabla u = \sum_{\epsilon \in \{x,y,\dots\}} \frac{\partial^2 u}{\partial \epsilon^2}
```
$$

{{ implementations('jaxdf.operators.differential', 'laplacian') }}

Expand All @@ -61,6 +61,7 @@ Given a scalar field $`u`$, it returns the scalar field
members:
- get_fd_coefficients
- fd_derivative_init
- fd_diag_jacobian_init
show_root_heading: false
show_root_toc_entry: false
show_source: false
Expand Down
4 changes: 0 additions & 4 deletions docs/operators/dummy.md

This file was deleted.

24 changes: 12 additions & 12 deletions docs/operators/functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@

## `compose`

Implements a decorator that allows to compose `jax` functions with fields. Given a function $`f`$ and a `Field` $`x`$, the result is a new field representing
Implements a decorator that allows to compose `jax` functions with fields. Given a function $f$ and a `Field` $x$, the result is a new field representing

```math
$$
y = f(x)
```
$$

The usage of the decorator is as follows:
```python
Expand All @@ -32,32 +32,32 @@ It is useful to improve the readibility of the code.

## `get_component`

This operator $A(u, \text{dim})$ which has the signature `(u: Field, dim: int) -> Field`. It returns the component of the field $`u`$ at the dimension $`dim`$.
This operator $A(u, \text{dim})$ which has the signature `(u: Field, dim: int) -> Field`. It returns the component of the field $u$ at the dimension $dim$.

```math
$$
u(x) = (u_0(x), u_1(x), \ldots, u_N(x)) \to u_{\text{dim}}(x)
```
$$

{{ implementations('jaxdf.operators.functions', 'get_component') }}

## `shift_operator`

Implements the shift operator $`S(\Delta x)`$ which is used to shift (spatially) a field $`u`$ by a constant $`\Delta x`$:
Implements the shift operator $S(\Delta x)$ which is used to shift (spatially) a field $u$ by a constant $\Delta x$:

```math
$$
v = S(\Delta x) u = u(x - \Delta x)
```
$$

{{ implementations('jaxdf.operators.functions', 'shift_operator') }}


## `sum_over_dims`

Reduces a vector field $`u = (u_x, u_y, \dots)`$ to a scalar field by summing over the dimensions:
Reduces a vector field $u = (u_x, u_y, \dots)$ to a scalar field by summing over the dimensions:

```math
$$
v = \sum_{i \in \{x,y,\dots\}} u_i
```
$$

{{ implementations('jaxdf.operators.functions', 'sum_over_dims') }}

Expand Down
6 changes: 3 additions & 3 deletions docs/operators/linear_algebra.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# `jaxdf.operators.differential`
# `jaxdf.operators.linear_algebra`

## `dot_product`

Returns the dot product $`u \cdot v`$ between two vector fields $`u`$ and $`v`$.
Returns the dot product $u \cdot v$ between two vector fields $u$ and $v$.

{{ implementations('jaxdf.operators.differential', 'dot_product') }}
{{ implementations('jaxdf.operators.linear_algebra', 'dot_product') }}
6 changes: 3 additions & 3 deletions docs/operators/magic.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ Implements the `bool()` operator for fields.

## `__divmod__`

Implements the `divmod()` operator for fields, which for two fields $`u`$ and $`v`$ returns a pair of fields `(u // v, u % v)`.
Implements the `divmod()` operator for fields, which for two fields $u$ and $v$ returns a pair of fields `(u // v, u % v)`.

{{ implementations('jaxdf.operators.magic', '__divmod__') }}

Expand All @@ -29,13 +29,13 @@ Implements the `*` operator for fields.

## `__neg__`

Given a field $`u`$, returns the field $`-u`$, using the syntax `-u`.
Given a field $u$, returns the field $-u$, using the syntax `-u`.

{{ implementations('jaxdf.operators.magic', '__neg__') }}

## `__pow__`

Given a field $`u`$ and a generic $`c`$, returns the field $`u^c`$, using the syntax `u**c`.
Given a field $u$ and a generic $c$, returns the field $u^c$, using the syntax `u**c`.

{{ implementations('jaxdf.operators.magic', '__pow__') }}

Expand Down
1 change: 0 additions & 1 deletion jaxdf/__about__.py

This file was deleted.

10 changes: 5 additions & 5 deletions jaxdf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# nopycln: file
from jaxdf.core import operator, debug_config, constants # isort:skip
from jaxdf import util, geometry # isort:skip
from jaxdf.discretization import * # isort:skip
from jaxdf.core import operator, debug_config, constants # isort:skip
from jaxdf import util, geometry # isort:skip
from jaxdf.discretization import * # isort:skip

# Must be imported after discretization
from jaxdf.operators.magic import * # isort:skip
from jaxdf import operators # isort:skip
from jaxdf.operators.magic import * # isort:skip
from jaxdf import operators # isort:skip
38 changes: 17 additions & 21 deletions jaxdf/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from jax import scipy as jsp


def reflection_conv(
kernel: jnp.ndarray, array: jnp.ndarray, reverse: bool = True
) -> jnp.ndarray:
def reflection_conv(kernel: jnp.ndarray,
array: jnp.ndarray,
reverse: bool = True) -> jnp.ndarray:
r"""Convolves an array with a kernel, using reflection padding.
The kernel is supposed to have the same number of dimensions as the array.
Expand All @@ -33,8 +33,7 @@ def reflection_conv(


def bubble_sort_abs_value(
points_list: List[Union[float, int]]
) -> List[Union[float, int]]:
points_list: List[Union[float, int]]) -> List[Union[float, int]]:
r"""Sorts a sequence of grid points by their absolute value.
Sorting is done __in place__. This function is written with numpy, so it can't
Expand All @@ -59,7 +58,8 @@ def bubble_sort_abs_value(
for j in range(0, len(points_list) - i - 1):
magnitude_condition = abs(points_list[j]) > abs(points_list[j + 1])
same_mag_condition = abs(points_list[j]) == abs(points_list[j + 1])
sign_condition = np.sign(points_list[j]) < np.sign(points_list[j + 1])
sign_condition = np.sign(points_list[j]) < np.sign(points_list[j +
1])
if magnitude_condition or (same_mag_condition and sign_condition):
temp = points_list[j]
points_list[j] = points_list[j + 1]
Expand All @@ -71,8 +71,8 @@ def bubble_sort_abs_value(
# TODO (astanziola): This fails on mypy for some reason, but can't work out how to fix.
@no_type_check
def fd_coefficients_fornberg(
order: int, grid_points: List[Union[float, int]], x0: Union[float, int]
) -> Tuple[List[None], List[Union[float, int]]]:
order: int, grid_points: List[Union[float, int]],
x0: Union[float, int]) -> Tuple[List[None], List[Union[float, int]]]:
r"""Generate finite difference stencils for a given order and grid points, using
the Fornberg algorithm described in [[Fornberg, 2018]](https://web.njit.edu/~jiang/math712/fornberg.pdf).
Expand Down Expand Up @@ -107,7 +107,7 @@ def fd_coefficients_fornberg(

# Sort the grid points
alpha = bubble_sort_abs_value(grid_points)
delta = dict() # key: (m,n,v)
delta = dict() # key: (m,n,v)
delta[(0, 0, 0)] = 1.0
c1 = 1.0

Expand All @@ -119,21 +119,17 @@ def fd_coefficients_fornberg(
if n < M:
delta[(n, n - 1, v)] = 0.0
for m in range(min([n, M]) + 1):
d1 = delta[(m, n - 1, v)] if (m, n - 1, v) in delta.keys() else 0.0
d2 = (
delta[(m - 1, n - 1, v)]
if (m - 1, n - 1, v) in delta.keys()
else 0.0
)
d1 = delta[(m, n - 1, v)] if (m, n - 1,
v) in delta.keys() else 0.0
d2 = (delta[(m - 1, n - 1, v)] if
(m - 1, n - 1, v) in delta.keys() else 0.0)
delta[(m, n, v)] = ((alpha[n] - x0) * d1 - m * d2) / c3

for m in range(min([n, M]) + 1):
d1 = (
delta[(m - 1, n - 1, n - 1)]
if (m - 1, n - 1, n - 1) in delta.keys()
else 0.0
)
d2 = delta[(m, n - 1, n - 1)] if (m, n - 1, n - 1) in delta.keys() else 0.0
d1 = (delta[(m - 1, n - 1, n - 1)] if
(m - 1, n - 1, n - 1) in delta.keys() else 0.0)
d2 = delta[(m, n - 1, n - 1)] if (m, n - 1,
n - 1) in delta.keys() else 0.0
delta[(m, n, n)] = (c1 / c2) * (m * d1 - (alpha[n - 1] - x0) * d2)
c1 = c2

Expand Down
Loading

0 comments on commit ee5ba44

Please sign in to comment.