Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/Eomys/SciDataTool
Browse files Browse the repository at this point in the history
  • Loading branch information
helene-t committed Feb 11, 2022
2 parents 709b746 + 129650a commit a1a51b1
Show file tree
Hide file tree
Showing 6 changed files with 249 additions and 15 deletions.
56 changes: 56 additions & 0 deletions SciDataTool/Functions/derivation_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,62 @@ def integrate_local(values, ax_val, index, Nper, is_aper, is_phys, is_freqs):
return values_integ


def integrate_local_pattern(values, ax_val, index):
"""Returns the local integral of values along given axis (does not change the axis)
Parameters
----------
values: ndarray
array to derivate
ax_val: ndarray
axis values
index: int
index of axis along which to derivate
Nper: int
number of periods to replicate
is_aper: bool
True if values is anti-periodic along axis
is_phys: bool
True if physical quantity (time/angle/z)
is_freqs: bool
True if frequency axis
Returns
-------
values_integ: ndarray
local integration of values
"""

if ax_val.size > 1:
# Swap axis to always have integration axis on 1st position
values = np.swapaxes(values, index, 0)

# Init output arrays
shape = list(values.shape[1:])
shape.insert(0, values.shape[0] - 1)
values_int = np.zeros(shape, dtype=values.dtype)
ax_int = np.zeros(shape[0])

# Trapezoidal integration on each segment and calculate position of each segment middle
for ii in range(shape[0]):
values_int[ii, ...] = scp_int.trapezoid(
values[ii : ii + 2, ...], x=ax_val[ii : ii + 2], axis=0
)
ax_int[ii] = np.mean(ax_val[ii : ii + 2])

# Remove zero length interval
Ia = np.nonzero(np.diff(ax_val))[0]
values_int = values_int[Ia, ...]
ax_int = ax_int[Ia]

values_int = np.swapaxes(values_int, index, 0)

else:
raise Exception("Cannot locally integrate along axis if axis size is 1")

return values_int, ax_int


def integrate(values, ax_val, index, Nper, is_aper, is_phys, is_mean=False):
"""Returns the integral of values along given axis
Expand Down
4 changes: 2 additions & 2 deletions SciDataTool/Functions/set_routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,10 @@ def get_relative_tolerance(a, atol):
"""

a_max = np.max(np.abs(a))
if a_max > 0:
if a_max >= 0:
rtol = atol * np.max(np.abs(a))

if rtol > 1:
if rtol > 1 or rtol == 0:
# threshold tol to 1
rtol = 1

Expand Down
11 changes: 8 additions & 3 deletions SciDataTool/Methods/DataND/_apply_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
derivate,
integrate,
integrate_local,
integrate_local_pattern,
antiderivate,
)
from SciDataTool.Functions.sum_mean import (
Expand Down Expand Up @@ -92,9 +93,13 @@ def _apply_operations(self, values, axes_list, is_magnitude, unit, corr_unit):
values = integrate(values, ax_val, index, Nper, is_aper, is_phys)
# local integration over integration axes
elif extension == "integrate_local":
values = integrate_local(
values, ax_val, index, Nper, is_aper, is_phys, is_freqs
)
if axis_requested.name == "z":
values, ax_val = integrate_local_pattern(values, ax_val, index)
axis_requested.values = ax_val
else:
values = integrate_local(
values, ax_val, index, Nper, is_aper, is_phys, is_freqs
)
# antiderivation over antiderivation axes
elif extension == "antiderivate":
values = antiderivate(
Expand Down
16 changes: 10 additions & 6 deletions SciDataTool/Methods/DataND/plot_2D_Data.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,13 +319,17 @@ def plot_2D_Data(
if isinstance(result_0[axis.name], str):
title2 += name + "=" + result_0[axis.name]
else:
if result_0[axis.name][0] > 10:
fmt = "{:.5g}"
if isinstance(result_0[axis.name][0], str):
axis_str = result_0[axis.name][0]
else:
fmt = "{:.3g}"
axis_str = array2string(
result_0[axis.name], formatter={"float_kind": fmt.format}
).replace(" ", ", ")
if result_0[axis.name][0] > 10:
fmt = "{:.5g}"
else:
fmt = "{:.3g}"
axis_str = array2string(
result_0[axis.name], formatter={"float_kind": fmt.format}
).replace(" ", ", ")

if len(result_0[axis.name]) == 1:
axis_str = axis_str.strip("[]")

Expand Down
89 changes: 87 additions & 2 deletions Tests/Validation/test_fft.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
from pandas import NA
import pytest
from SciDataTool import DataTime, DataFreq, DataLinspace, Data1D, VectorField
from SciDataTool import (
DataTime,
DataFreq,
DataLinspace,
Data1D,
VectorField,
Norm_indices,
)
import numpy as np
from numpy.testing import assert_array_almost_equal
from numpy import zeros, exp, real, pi, take, insert, delete
Expand Down Expand Up @@ -851,5 +859,82 @@ def test_fft2_anti_period_random():
)


@pytest.mark.validation
def test_fft1d_non_uniform(per_a=2, is_apera=True, is_add_zero_freq=True):
"""check non uniform fft1d
TODO: solve bug for a single frequency vector"""
# %%
f = 50
Na = 4 * 10
slip = 0.01
Nt = 100
A0 = 10

sym_dict = dict()
per_a0 = per_a
if is_apera:
per_a *= 2
sym_dict["antiperiod"] = per_a
elif per_a > 1:
sym_dict["period"] = per_a

# Creating the data object
Phase = DataLinspace(
name="phase",
unit="rad",
initial=0,
final=2 * pi / per_a,
number=int(Na / per_a),
include_endpoint=False,
symmetries=sym_dict,
normalizations={"bar_id": Norm_indices()},
is_overlay=True,
)

Time = DataLinspace(
name="time",
unit="s",
initial=0,
final=1 / f,
number=Nt,
symmetries={"antiperiod": 4},
)

angle_bars = Phase.get_values(is_smallestperiod=True)

if is_add_zero_freq:
values = np.zeros((2, angle_bars.size), dtype=complex)
values[1, :] = A0 * np.exp(1j * per_a0 * angle_bars)
freqs_val = np.array([0, slip * f])
else:
values = A0 * np.exp(1j * per_a0 * angle_bars[None, :])
freqs_val = np.array([slip * f])

Freqs = Data1D(
name="freqs",
symbol="",
unit="Hz",
values=freqs_val,
normalizations=dict(),
)

Data = DataFreq(
name="field", unit="A", symbol="X", axes=[Freqs, Phase], values=values
)

val_time = Data.get_data_along(
"time=axis_data", "phase", axis_data={"time": Time.get_values()}
)

# Plot
# val_time.plot_2D_Data("phase", "time[0]", type_plot="bargraph")
# val_time.plot_2D_Data("time", "phase[0,1,2,3]")

val_check = A0 * np.cos(2 * np.pi * slip * f * Time.get_values())
assert_array_almost_equal(val_time.values[:, 0], val_check)


if __name__ == "__main__":
test_ifft2d_period()
# test_ifft2d_period()
test_fft1d_non_uniform(is_add_zero_freq=True)
test_fft1d_non_uniform(is_add_zero_freq=False)
88 changes: 86 additions & 2 deletions Tests/Validation/test_get_data_along.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np
from numpy.testing import assert_array_almost_equal, assert_equal

from SciDataTool import DataLinspace, DataTime, Norm_ref, Data1D
from SciDataTool import DataLinspace, DataTime, Norm_ref, Data1D, DataPattern


@pytest.mark.validation
Expand Down Expand Up @@ -491,10 +491,94 @@ def test_get_data_along_integrate_local():
assert Field_int_loc.unit == "ms"


@pytest.mark.validation
def test_get_data_along_integrate_local_pattern():

# Test integration per step with DataPattern
f = 50
A = 5
Nt = 3
Time = DataLinspace(
name="time",
unit="s",
initial=0,
final=1 / f,
number=Nt,
include_endpoint=False,
)

z = DataPattern(
name="z",
unit="m",
values=np.array([-0.045, -0.09]),
rebuild_indices=[1, 1, 0, 0, 0, 0, 1, 1],
unique_indices=[2, 0],
values_whole=np.array([-0.09, -0.045, -0.045, 0.0, 0.0, 0.045, 0.045, 0.09]),
is_step=True,
)

time = Time.get_values()
field = np.zeros((Nt, 2))
field[:, 0] = A * np.cos(2 * np.pi * f * time)
field[:, 1] = 0.5 * A * np.cos(2 * np.pi * f * time)

Field = DataTime(
name="Example field",
symbol="X",
unit="T/m",
normalizations={"ref": Norm_ref(ref=2e-5)},
axes=[Time, z],
values=field,
)

# Field.plot_3D_Data("time", "z")
# Field.plot_2D_Data("z", "time[0]")
Field_int_loc = Field.get_data_along("time", "z=integrate_local")
assert_equal(Field_int_loc.values.shape, (Nt, 4))
assert_array_almost_equal(
2 * Field_int_loc.values[:, [0, 3]], Field_int_loc.values[:, [1, 2]]
)

z2 = DataPattern(
name="z",
unit="m",
values=np.array([-0.09, -0.045, 0.0, 0.045, 0.09]),
rebuild_indices=[0, 1, 2, 3, 4],
unique_indices=[0, 1, 2, 3, 4],
values_whole=np.array([-0.09, -0.045, 0.0, 0.045, 0.09]),
is_step=False,
)

field2 = np.zeros((Nt, 5))
field2[:, 0] = A * np.cos(2 * np.pi * f * time)
field2[:, 1] = 0.8 * A * np.cos(2 * np.pi * f * time)
field2[:, 2] = 0.6 * A * np.cos(2 * np.pi * f * time)
field2[:, 3] = 0.4 * A * np.cos(2 * np.pi * f * time)
field2[:, 4] = 0.2 * A * np.cos(2 * np.pi * f * time)

Field2 = DataTime(
name="Example field 2",
symbol="X",
unit="T/m",
normalizations={"ref": Norm_ref(ref=2e-5)},
axes=[Time, z2],
values=field2,
)

Field2.plot_3D_Data("time", "z")
Field2.plot_2D_Data("z", "time[0]")
Field_int_loc2 = Field2.get_data_along("time", "z=integrate_local")
assert_equal(Field_int_loc2.values.shape, (Nt, 4))
# assert_array_almost_equal(
# 0.8 * Field_int_loc2.values[:, 0], Field_int_loc2.values[:, 1]
# )


if __name__ == "__main__":
# test_get_data_along_single()
# test_get_data_along_integrate()
# test_get_data_along_derivate()
# test_get_data_along_antiderivate()
# test_get_data_along_to_linspace()
test_get_data_along_integrate_local()
# test_get_data_along_integrate_local()
test_get_data_along_integrate_local_pattern()

0 comments on commit a1a51b1

Please sign in to comment.