Skip to content

Commit

Permalink
Feature/cicid (#11)
Browse files Browse the repository at this point in the history
* chore: cleanup test.yml

* (chore): pacyfing pipeline

* (fix): include python-package.yml

* (fix): fixing typo in noxfile

* (chore): cleaning pipeline

* (chore): cleaning code

* (chore): cleaning code

* (chore): moving experiment- and visualization scripts to corresponding folders

* (chore): cleaning code in stats module

* (fix): fixing tests

* (fix): pacifying pylint
  • Loading branch information
greinerth authored Aug 12, 2024
1 parent c3e2c49 commit ffc3ae9
Show file tree
Hide file tree
Showing 16 changed files with 87 additions and 67 deletions.
2 changes: 1 addition & 1 deletion .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"runArgs": ["--gpus", "all"],
// Use 'postCreateCommand' to run commands after the container is created.

"postCreateCommand": "pip3 install --user -e .[dev] && sudo apt update && sudo apt upgrade -y && sudo apt install texlive-xetex cm-super dvipng -y",
"postCreateCommand": "pip3 install --user -e .[dev,test] && sudo apt update && sudo apt upgrade -y && sudo apt install texlive-xetex cm-super dvipng -y",

"customizations": {
"vscode": {
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.10", "3.11", "3.12"]
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v4
Expand All @@ -30,4 +30,4 @@ jobs:
- name: Lint and test with nox
run: |
# stop the build if there are Python syntax errors or undefined names
nox
nox -s pylint tests
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ resources.
The easiest way to run the scripts is to use
[VSCode's devcontainer capability](https://code.visualstudio.com/docs/devcontainers/containers).
The project was tested on Ubuntu 22.04 (which also served as a host system for
the devcontainers) with the Python3.11 interpreter.
the devcontainers) with the Python3.11 and Python3.12 interpreter.

### Ubuntu 22.04

Expand All @@ -31,7 +31,7 @@ First, download this repository and execute the following commands

```
cd /path/to/repository
python3.11 -m pip --user install -e .
pip --user install -e .
```

For proper visualization of the results make sure LaTex is installed.\
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from typing import Any

import numpy as np

from varprodmdstatspy.util.experiment_utils import (
comp_checker,
dmd_stats,
Expand All @@ -21,11 +20,10 @@
)

logging.basicConfig(level=logging.INFO, filename=__name__)
# logging.root.setLevel(logging.INFO)

OPT_ARGS = {"method": "trf", "tr_solver": "exact", "loss": "linear"}


# OPT_ARGS = {"method": 'lm', "loss": 'linear'}
def test_high_dim_signal(
method: str, n_runs: int, std: float, eps: float
) -> dict[str, Any]:
Expand All @@ -34,13 +32,11 @@ def test_high_dim_signal(
__x, __time = np.meshgrid(x_loc, time)
z = signal(__x, __time).T

# time_stats, error_stats = dmd_stats(dmd, z, time, std, n_iter=n_runs)
mean_err, mean_dt, c_xx, c_xy, c_yy = dmd_stats(
method, z, time, std, OPT_ARGS, eps, n_iter=n_runs
)
return {
"case": "High dimensional signal",
# "omega_size": omega_size,
"method": method,
"compression": eps,
"n_runs": n_runs,
Expand Down Expand Up @@ -122,8 +118,7 @@ def run_mrse():
help="Scale the search directions with inverse jacobian, [Default: False]",
)
__args = parser.parse_args()
# manager = mp.Manager()
# results = manager.list()

if __args.scale_jac:
OPT_ARGS["x_scale"] = "jac"

Expand Down Expand Up @@ -156,9 +151,7 @@ def run_mrse():
c_xx_list = []
c_xy_list = []
c_yy_list = []
# exec_time_std_list = []
std_noise_list = []
# omega_list = []
mrse_mean_list = []

for res in starmap(test_high_dim_signal, args_in):
Expand All @@ -180,8 +173,10 @@ def run_mrse():

msg = f"{method} - Mean RSE: {mean_mrse}"
logging.info(msg)
stats = f"{
method} - Mean exec time: {mean_t} [s], Std exec time: {std_t} [s]"

stats = " ".join(
[f"{method} - Mean exec time: {mean_t} [s],", f"Std exec time: {std_t} [s]"]
)
logging.info(stats)

if std > 0:
Expand All @@ -197,16 +192,13 @@ def run_mrse():

data_dict = {
"Method": method_list,
# "N_eigs": omega_list,
"c": comp_list,
# "Experiment": case_list,
"E[t]": exec_time_mean_list,
"E[MRSE]": mrse_mean_list,
"STD_NOISE": std_noise_list,
"c_xx": c_xx_list,
"c_xy": c_xy_list,
"c_yy": c_yy_list,
# "N_RUNS": N_RUNS,
}
loss = OPT_ARGS["loss"]
opt = OPT_ARGS["method"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import netCDF4 as nc
import numpy as np
import wget

from varprodmdstatspy.util.experiment_utils import (
comp_checker,
dmd_stats,
Expand All @@ -29,13 +28,14 @@
OPT_ARGS = {"method": "trf", "tr_solver": "exact", "loss": "linear"}


# OPT_ARGS = {"method": 'lm', "loss": 'linear'}
def download(url: str, outdir: str):
"""Download dataset.
Found on: https://stackoverflow.com/questions/15644964/python-progress-bar-and-downloads
Args:
url (str): url
fname (str): Output
:param url: url
:type url: str
:param outdir: Output
:type outdir: str
"""
wget.download(url, outdir)

Expand All @@ -60,7 +60,6 @@ def test_complex2d_signal(
)
return {
"case": "Complex 2D signal",
# "omega_size": omega_size,
"method": method,
"compression": eps,
"n_runs": n_runs,
Expand Down Expand Up @@ -133,7 +132,6 @@ def test_global_temp(
)
return {
"case": "Global temperature",
# "omega_size": omega_size,
"method": method,
"compression": eps,
"n_runs": n_runs,
Expand All @@ -160,8 +158,6 @@ def run_ssim():

currentdir = Path(inspect.getfile(inspect.currentframe())).resolve().parent

# PATH = os.path.join(currentdir, "data")
# FILE = os.path.join(PATH, DATASET)
OUTDIR = currentdir / "output"
parser = argparse.ArgumentParser("VarProDMD vs BOPDMD stats")

Expand Down Expand Up @@ -236,8 +232,7 @@ def run_ssim():
if __args.fct not in fcts:
msg = "f{__args.fct} not implemented!"
raise KeyError(msg)
# manager = mp.Manager()
# results = manager.list()

if __args.scale_jac:
OPT_ARGS["x_scale"] = "jac"

Expand Down Expand Up @@ -287,13 +282,11 @@ def run_ssim():
c_xx_list = []
c_xy_list = []
c_yy_list = []
# exec_time_std_list = []

std_noise_list = []
# omega_list = []
ssim_mean_list = []

for res in starmap(fcts[__args.fct], __args_in):
# logging.info(Fore.CYAN + res["case"])
std = res["std"]
method = res["method"]
mean_ssim = res["mean_err"]
Expand All @@ -307,15 +300,15 @@ def run_ssim():
c_xy_list.append(res["c_xy"])
c_yy_list.append(res["c_yy"])
std_noise_list.append(std)
# case_list.append(res["case"])
# omega_list.append(omega_size)

std_ssim = np.sqrt(res["c_xx"])
msg = f"{method} - Mean SSIM: {mean_ssim}, Std SSIM: {std_ssim}"
logging.info(msg)

stats = f"{
method} - Mean exec time: {mean_t} [s], Std exec time: {std_t} [s]"
stats = " ".join(
[f"{method} - Mean exec time: {mean_t} [s],", f"Std exec time: {std_t} [s]"]
)

logging.info(stats)

if std > 0:
Expand Down
2 changes: 1 addition & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def pylint(session: nox.Session) -> None:
# This needs to be installed into the package environment, and is slower
# than a pre-commit check
session.install(".", "pylint>=3.2")
session.run("pylint", "varpdosmdstatspy", *session.posargs)
session.run("pylint", "varprodmdstatspy", *session.posargs)


@nox.session
Expand Down
23 changes: 14 additions & 9 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[build-system]
requires = ["hatchling"]
requires = ["hatchling", "hatch-vcs"]
build-backend = "hatchling.build"

[project]
Expand Down Expand Up @@ -35,14 +35,20 @@ dependencies = [
"seaborn",
"netCDF4",
"colorama",
"wget"
"wget",
"SciencePlots"
]

[tool.hatch.build]
include = [
"varprodmdstatspy"
[tool.hatch]
version.source = "vcs"
build.hooks.vcs.version-file = "varprodmdstatspy/_version.py"
build.include = [
"varprodmdstatspy",
"experiments",
"visualization"
]


[project.optional-dependencies]
test = [
"pytest >=6",
Expand All @@ -58,7 +64,6 @@ dev = ["anybadge",
"hatchling",
"nox",
"pre-commit",
"SciencePlots",
"memray"]

docs = [
Expand Down Expand Up @@ -149,6 +154,6 @@ messages_control.disable = [
]

[project.scripts]
run_ssim = "varprodmdstatspy.varprodmd_ssim_performance:run_ssim"
run_mrse = "varprodmdstatspy.varprodmd_mrse_performance:run_mrse"
visualize_stats = "varprodmdstatspy.visualize_results:visualize_stats"
run_ssim = "experiments.varprodmd_ssim_performance:run_ssim"
run_mrse = "experiments.varprodmd_mrse_performance:run_mrse"
visualize_stats = "visualization.visualize_results:visualize_stats"
32 changes: 31 additions & 1 deletion tests/test_ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,37 @@

import numpy as np
from varprodmdstatspy.util.experiment_utils import ssim_multi_images
from varprodmdstatspy.visualize_complex2d import generate_complex2d

generator = np.random.Generator(np.random.PCG64())


def generate_complex2d(
std: float = -1,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Generate damped oscillating signal
:param std: Standard deviation for data corruption, defaults to -1
:type std: float, optional
:return: snapshots, timestamps, data
:rtype: tuple[np.ndarray, np.ndarray, np.ndarray]
"""
timestamps = np.linspace(0, 6, 16)
x_1 = np.linspace(-3, 3, 128)
x_2 = np.linspace(-3, 3, 128)
x1grid, x2grid = np.meshgrid(x_1, x_2)

data = [
np.expand_dims(2 / np.cosh(x1grid) / np.cosh(x2grid) * (1.2j**-t), axis=0)
for t in timestamps
]
snapshots_flat = np.zeros((np.prod(data[0].shape), len(data)), dtype=complex)
for j, img in enumerate(data):
__img = img.copy()
if std > 0:
__img += generator.normal(0, std, img.shape)
data[j] = __img
snapshots_flat[:, j] = np.ravel(__img)
return snapshots_flat, timestamps, np.concatenate(data, axis=0)


def test_ssim() -> None:
Expand Down
5 changes: 5 additions & 0 deletions varprodmdstatspy/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from __future__ import annotations

from ._version import version as __version__

__all__ = ["__version__"]
4 changes: 3 additions & 1 deletion varprodmdstatspy/util/experiment_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import numpy as np
from pydmd.bopdmd import BOPDMD
from pydmd.varprodmd import VarProDMD
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import ( # pylint: disable=no-name-in-module
structural_similarity as ssim,
)

from varprodmdstatspy.util import stats

Expand Down
7 changes: 2 additions & 5 deletions varprodmdstatspy/util/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,8 @@ def __call__(self, *args: Any, **kwds: Any) -> Any:
delta_t = timeit.default_timer() - t_1
self.push(delta_t)

if delta_t < self._min:
self._min = delta_t

if delta_t > self._max:
self._max = delta_t
self._min = min(delta_t, self._min)
self._max = max(delta_t, self._max)

return res

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,10 @@ def generate_complex2d(
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Generate damped oscillating signal
Args:
std (float, optional): Standard deviatopm for noise.
If <= 0 no noise is added. Defaults to -1.
Returns:
Tuple[np.ndarray, np.ndarray, np.ndarray]: snapshots, timestamps, data
:param std: Standard deviation for data corruption, defaults to -1
:type std: float, optional
:return: snapshots, timestamps, data
:rtype: tuple[np.ndarray, np.ndarray, np.ndarray]
"""
timestamps = np.linspace(0, 6, 16)
x_1 = np.linspace(-3, 3, 128)
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
z_signal = signal(_x, _time).T
_x = _x.T
_time = _time.T
# OPT_ARGS["loss"] = "huber"

dmd = VarProDMD(compression=COMP, optargs=OPT_ARGS, exact=True)
dmd.fit(z_signal, time)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,10 @@ def generate_moving_points(
) -> tuple[np.ndarray, np.ndarray, list[np.ndarray]]:
"""Generate moving points example
Args:
std (float, optional):
Standard deviation, ignored when negative. Defaults to -1.
Returns:
Tuple[np.ndarray, np.ndarray, List[np.ndarray]]: snapshots, timestamps, data
:param std: Standard deviation, ignored when negative. Defaults to -1.
:type std: float, optional
:return: snapshots, timestamps, data
:rtype: tuple[np.ndarray, np.ndarray, list[np.ndarray]]
"""
fps = 30.0
total_time = 5.0
Expand Down
File renamed without changes.

0 comments on commit ffc3ae9

Please sign in to comment.