Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/Eomys/SciDataTool
Browse files Browse the repository at this point in the history
  • Loading branch information
helene-t committed Nov 3, 2021
2 parents f370d37 + 64a7398 commit 2ed7e17
Show file tree
Hide file tree
Showing 14 changed files with 365 additions and 128 deletions.
27 changes: 23 additions & 4 deletions SciDataTool/Functions/Plot/plot_2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def plot_2D(
font_size_legend=8,
is_show_legend=True,
is_outside_legend=False,
is_frame_legend=None,
scale_units="x",
scale=None,
width=0.005,
Expand Down Expand Up @@ -108,6 +109,8 @@ def plot_2D(
True to show figure after plot
is_outside_legend : bool
True to display legend outside the graph
is_frame_legend : bool
True to display legend in a frame
win_title : str
Title of the plot window
scale_units : str
Expand Down Expand Up @@ -396,16 +399,32 @@ def get_cumulated_array(data, **kwargs):
if is_grid:
ax.grid()

# Determine if frame is displayed
if is_frame_legend is None:
if is_outside_legend:
is_frame_legend = False
else:
is_frame_legend = True

# if ndatas > 1 and not no_legend:
if not no_legend:
ax.legend(prop={"family": font_name, "size": font_size_legend})
if is_outside_legend:
ax.legend(
prop={"family": font_name, "size": font_size_legend},
loc="upper left",
bbox_to_anchor=(1, 1),
frameon=is_frame_legend,
)
else:
ax.legend(
prop={"family": font_name, "size": font_size_legend},
loc="center left",
frameon=is_frame_legend,
)

if not is_show_legend:
ax.get_legend().remove()

if is_outside_legend:
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))

plt.tight_layout()
for item in (
[ax.xaxis.label, ax.yaxis.label] + ax.get_xticklabels() + ax.get_yticklabels()
Expand Down
9 changes: 8 additions & 1 deletion SciDataTool/Functions/Plot/plot_3D.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def plot_3D(
is_disp_title=True,
type_plot="stem",
is_contour=False,
is_shading_flat=False,
save_path=None,
is_show_fig=None,
is_switch_axes=False,
Expand Down Expand Up @@ -96,6 +97,8 @@ def plot_3D(
type of 3D graph : "stem", "surf", "pcolor" or "scatter"
is_contour : bool
True to show contour line if type_plot = "pcolor"
is_shading_flat : bool
True to use flat shading instead of Gouraud
save_path : str
full path including folder, name and extension of the file to save if save_path is not None
is_show_fig : bool
Expand Down Expand Up @@ -254,12 +257,16 @@ def plot_3D(
ax.set_xlim([x_min, x_max])
ax.set_ylim([y_min, y_max])
elif type_plot == "pcolormesh":
if is_shading_flat:
shading = "flat"
else:
shading = "gouraud"
c = ax.pcolormesh(
Xdata,
Ydata,
Zdata,
cmap=colormap,
shading="gouraud",
shading=shading,
antialiased=True,
picker=True,
vmin=z_min,
Expand Down
28 changes: 26 additions & 2 deletions SciDataTool/Functions/Plot/plot_4D.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# -*- coding: utf-8 -*-

from numpy import log10, abs as np_abs, max as np_max, NaN, zeros_like
import matplotlib.pyplot as plt

from SciDataTool.Functions.Plot.init_fig import init_fig
import numpy as np


def plot_4D(
Expand All @@ -26,6 +29,7 @@ def plot_4D(
xticklabels=None,
yticklabels=None,
annotations=None,
annotation_threshold=0.01,
fig=None,
ax=None,
is_logscale_x=False,
Expand All @@ -42,6 +46,8 @@ def plot_4D(
font_size_label=10,
font_size_legend=8,
is_grid=False,
grid_xlw=None,
grid_ylw=None,
):
"""Plots a 4D graph
Expand Down Expand Up @@ -89,6 +95,8 @@ def plot_4D(
list of tick labels to use for the x-axis
annotations : list
list of annotations to apply to data
annotation_threshold : float
threshold to plot annotation (percentage of the maximum value)
fig : Matplotlib.figure.Figure
existing figure to use if None create a new one
ax : Matplotlib.axes.Axes object
Expand All @@ -109,6 +117,12 @@ def plot_4D(
True to show figure after plot
is_switch_axes : bool
to switch x and y axes
is_grid : bool
to plot grid
grid_xlw : float
grid linewidth along x
grid_ylw : float
grid linewidth along y
"""

# Set figure/subplot
Expand Down Expand Up @@ -194,7 +208,7 @@ def plot_4D(
ax.set_yticklabels(yticklabels)
if annotations is not None:
for i, txt in enumerate(annotations):
if Zdata[i] > z_max * 0.01:
if Zdata[i] > z_max * annotation_threshold and txt is not None:
ax.annotate(
str(txt) + " [Hz]",
(Xdata[i], Ydata[i]),
Expand Down Expand Up @@ -236,7 +250,17 @@ def plot_4D(
ax.title.set_fontname(font_name)

if is_grid:
ax.grid()
if grid_xlw is not None:
ax.xaxis.grid(lw=grid_xlw)
else:
ax.xaxis.grid()
if grid_ylw is not None:
ax.yaxis.grid(lw=grid_ylw)
else:
ax.yaxis.grid()
# ax.xaxis.grid(False) #To remove grid along x
# Plot grid below data
ax.set_axisbelow(True)

if save_path is not None:
save_path = save_path.replace("\\", "/")
Expand Down
4 changes: 2 additions & 2 deletions SciDataTool/Functions/conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,8 +413,8 @@ def dB_to_dBA(values, freqs, noct=None):
)
Aweight = 2.0 + 20.0 * log10(RA)
Aweight[isnan(Aweight)] = -100 # replacing NaN by -100 dB
values += Aweight
return values
values_dBA = values + Aweight
return values_dBA


def to_noct(values, freqs, noct=3):
Expand Down
46 changes: 32 additions & 14 deletions SciDataTool/Functions/sum_mean.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import numpy as np

from SciDataTool.Functions.derivation_integration import integrate
from SciDataTool.Functions.conversions import convert


def my_sum(values, index, Nper, is_aper):
def my_sum(values, index, Nper, is_aper, unit):
"""Returns the arithmetic sum of values along given axis
Parameters
Expand All @@ -29,11 +30,17 @@ def my_sum(values, index, Nper, is_aper):
shape0 = [s for ii, s in enumerate(shape) if ii != index]
values = np.zeros(shape0, dtype=values.dtype)
else:
# Take sum value multiplied by periodicity
if Nper is None:
# Set Nper to 1 in case of non-periodic axis
Nper = 1
values = Nper * np.sum(values, axis=index, keepdims=True)
# To sum dB or dBA
if "dB" in unit:
values = 10 * np.log10(
np.sum(10 ** (values / 10), axis=index, keepdims=True)
)
else:
# Take sum value multiplied by periodicity
if Nper is None:
# Set Nper to 1 in case of non-periodic axis
Nper = 1
values = Nper * np.sum(values, axis=index, keepdims=True)

return values

Expand Down Expand Up @@ -62,6 +69,9 @@ def my_mean(values, ax_val, index, Nper, is_aper, is_phys):
mean of values
"""

if ax_val.size == 1: # Do not use integrate for single point axes
is_phys = False

if is_phys:
# Integrate values and take mean value by dividing by integration interval in integrate()
values = integrate(values, ax_val, index, Nper, is_aper, is_phys, is_mean=True)
Expand Down Expand Up @@ -112,7 +122,7 @@ def root_mean_square(values, ax_val, index, Nper, is_aper, is_phys):
return np.sqrt(my_mean(values ** 2, ax_val, index, Nper, is_aper, is_phys))


def root_sum_square(values, ax_val, index, Nper, is_aper, is_phys):
def root_sum_square(values, ax_val, index, Nper, is_aper, is_phys, unit):
"""Returns the root sum square (arithmetic or integral) of values along given axis
Parameters
Expand All @@ -136,13 +146,21 @@ def root_sum_square(values, ax_val, index, Nper, is_aper, is_phys):
root sum square of values
"""

if is_aper and Nper is not None:
# Remove anti-periodicity since values is squared
is_aper = False
# To sum dB or dBA
if "dB" in unit:
return my_sum(values, index, Nper, is_aper, unit)

if is_phys:
values = integrate(values ** 2, ax_val, index, Nper, is_aper, is_phys)
else:
values = my_sum(values ** 2, index, Nper, is_aper)
if is_aper and Nper is not None:
# Remove anti-periodicity since values is squared
is_aper = False

if ax_val.size == 1: # Do not use integrate for single point axes
is_phys = False

if is_phys:
values = integrate(values ** 2, ax_val, index, Nper, is_aper, is_phys)
else:
values = my_sum(values ** 2, index, Nper, is_aper, unit)

return np.sqrt(values)
return np.sqrt(values)
17 changes: 15 additions & 2 deletions SciDataTool/Methods/DataND/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def convert(self, values, unit, is_norm, is_squeeze, axes_list):
ref_value = 1.0
if "ref" in self.normalizations:
ref_value *= self.normalizations["ref"].ref
is_match = False
for axis in axes_list:
is_match = False
if axis.name == "freqs" or axis.corr_name == "freqs":
if axis.corr_values is not None and axis.unit not in [
"SI",
Expand All @@ -66,7 +66,20 @@ def convert(self, values, unit, is_norm, is_squeeze, axes_list):
)
is_match = True
if not is_match:
raise UnitError("dBA conversion only available for fft with frequencies")
axis_names = [axis.name for axis in self.axes]
if "speed" in axis_names and "order" in axis_names:
freqs = self.get_freqs()
freqs = freqs.ravel("C")
shape = values.shape
values = values.reshape(freqs.shape + shape[2:])
values = np.apply_along_axis(
to_dBA, 0, values, freqs, self.unit, ref_value
)
values = values.reshape(shape)
else:
raise UnitError(
"dBA conversion only available for fft with frequencies"
)
elif unit in self.normalizations:
values = self.normalizations.get(unit).normalize(values)
else:
Expand Down
2 changes: 1 addition & 1 deletion SciDataTool/Methods/DataND/get_along.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def get_along(
# Interpolate over axis values
values = self.interpolate(values, axes_list)
# Sums
values = self.summing(values, axes_list, is_magnitude)
values = self.summing(values, axes_list, is_magnitude, unit=self.unit)
# Conversions
values = self.convert(values, unit, is_norm, is_squeeze, axes_list)
# Return axes and values
Expand Down
27 changes: 16 additions & 11 deletions SciDataTool/Methods/DataND/get_data_along.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,21 +50,21 @@ def get_data_along(self, *args, unit="SI", is_norm=False, axis_data=[]):
name = axis.name
is_components = axis.is_components
axis_values = results[axis_name]
unit = axis.unit
ax_unit = axis.unit
elif axis_name in axes_dict:
if axes_dict[axis_name][0] == axis.name:
index = i
name = axis_name
is_components = axis.is_components
axis_values = results[axis_name]
unit = axes_dict[axis_name][2]
ax_unit = axes_dict[axis_name][2]
elif axis_name in rev_axes_dict:
if rev_axes_dict[axis_name][0] == axis.name:
index = i
name = axis_name
is_components = axis.is_components
axis_values = results[axis_name]
unit = rev_axes_dict[axis_name][2]
ax_unit = rev_axes_dict[axis_name][2]
# Update symmetries
if "smallestperiod" in args[index] or args[index] in [
"freqs",
Expand All @@ -76,20 +76,25 @@ def get_data_along(self, *args, unit="SI", is_norm=False, axis_data=[]):
Axes.append(
Data1D(
name=name,
unit=unit,
unit=ax_unit,
values=axis_values,
is_components=is_components,
normalizations=self.axes[index].normalizations,
symmetries=symmetries,
).to_linspace()
)
# Update unit if derivation or integration
unit = self.unit
for axis in axes_list:
if axis.extension in ["antiderivate", "integrate"]:
unit = get_unit_integrate(unit, axis.corr_unit)
elif axis.extension == "derivate":
unit = get_unit_derivate(unit, axis.corr_unit)

# Update unit if dB/dBA conversion
if "dB" in unit:
unit = unit
else:
# Update unit if derivation or integration
unit = self.unit
for axis in axes_list:
if axis.extension in ["antiderivate", "integrate"]:
unit = get_unit_integrate(unit, axis.corr_unit)
elif axis.extension == "derivate":
unit = get_unit_derivate(unit, axis.corr_unit)

return DataClass(
name=self.name,
Expand Down
Loading

0 comments on commit 2ed7e17

Please sign in to comment.