diff --git a/atmospheric_explorer/apis_test.ipynb b/atmospheric_explorer/apis_test.ipynb index d8aea0d..4e234ea 100644 --- a/atmospheric_explorer/apis_test.ipynb +++ b/atmospheric_explorer/apis_test.ipynb @@ -6,30 +6,11 @@ "metadata": {}, "outputs": [], "source": [ - "from atmospheric_explorer.cams_interfaces import InversionOptimisedGreenhouseGas, EAC4Instance\n", - "from atmospheric_explorer.shapefile import ShapefilesDownloader\n", "from atmospheric_explorer.utils import get_local_folder\n", - "from atmospheric_explorer.units_conversion import convert_units_array\n", - "import plotly.express as px\n", - "import seaborn as sns\n", - "import matplotlib.pyplot as plt\n", "import os\n", - "from glob import glob\n", - "import xarray as xr\n", - "from datetime import datetime\n", - "import geopandas as gpd\n", - "from shapely.geometry import mapping\n", - "import plotly.graph_objects as go\n", "import shutil\n", - "import numpy as np\n", - "import statsmodels.stats.api as sms\n", - "import pandas as pd\n", - "from plotly.subplots import make_subplots\n", - "from math import ceil\n", - "from atmospheric_explorer.plotting_apis import line_with_ci_subplots, clip_and_concat_countries, ghg_surface_satellite_yearly_plot, eac4_anomalies_plot, eac4_hovmoeller_latitude_plot, eac4_hovmoeller_levels_plot\n", - "\n", - "from atmospheric_explorer.utils import hex_to_rgb\n", - "from atmospheric_explorer.data_transformations import confidence_interval" + "import plotly.express as px\n", + "from atmospheric_explorer.plotting_apis import ghg_surface_satellite_yearly_plot, eac4_anomalies_plot, eac4_hovmoeller_latitude_plot, eac4_hovmoeller_levels_plot" ] }, { @@ -175,7 +156,20 @@ "metadata": {}, "outputs": [], "source": [ - "import plotly.express as px" + "fig = px.imshow(\n", + " img = [[1,2], [3,4]],\n", + " x = ['A', 'B'],\n", + " y = [1, 2],\n", + " text_auto=True\n", + ")\n", + "fig.update_yaxes(type='log')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# New modularized APIs test" ] }, { @@ -184,13 +178,32 @@ "metadata": {}, "outputs": [], "source": [ - "fig = px.imshow(\n", - " img = [[1,2], [3,4]],\n", - " x = ['A', 'B'],\n", - " y = [1, 2],\n", - " text_auto=True\n", + "from atmospheric_explorer.plotting_apis_new import TimeSeriesPlotInstance, DatasetName, EAC4Parameters" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test = TimeSeriesPlotInstance(\n", + " DatasetName.eac4,\n", + " \"Test title\",\n", + " [\"Italy\"],\n", + " \"1MS\",\n", + " EAC4Parameters(\n", + " \"total_column_nitrogen_dioxide\",\n", + " \"tcno2\",\n", + " \"netcdf\",\n", + " \"2020-01-01/2020-07-01\",\n", + " \"00:00\",\n", + " ),\n", + " None,\n", ")\n", - "fig.update_yaxes(type='log')" + "\n", + "test.download_data()\n", + "test.transform_data()" ] }, { @@ -198,7 +211,9 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "test.plot()" + ] } ], "metadata": { diff --git a/atmospheric_explorer/cams_interfaces.py b/atmospheric_explorer/cams_interfaces.py index 76524b9..76da354 100644 --- a/atmospheric_explorer/cams_interfaces.py +++ b/atmospheric_explorer/cams_interfaces.py @@ -334,7 +334,7 @@ class InversionOptimisedGreenhouseGas(CAMSDataInterface): def __init__( self: InversionOptimisedGreenhouseGas, - data_variables: str | set[str] | list[str], + data_variable: str | set[str] | list[str], file_format: str, quantity: str, input_observations: str, @@ -344,7 +344,7 @@ def __init__( files_dir: str | None = None, version: str = "latest", ): - super().__init__(data_variables, file_format) + super().__init__(data_variable, file_format) self.quantity = quantity self.input_observations = input_observations self.time_aggregation = time_aggregation diff --git a/atmospheric_explorer/plotting_apis.py b/atmospheric_explorer/plotting_apis.py index f53ff7f..3da7db7 100644 --- a/atmospheric_explorer/plotting_apis.py +++ b/atmospheric_explorer/plotting_apis.py @@ -206,7 +206,7 @@ def _ghg_surface_satellite_yearly_data( # pylint: disable=invalid-name # Download surface data file surface_data = InversionOptimisedGreenhouseGas( - data_variables=data_variable, + data_variable=data_variable, file_format="zip", quantity="surface_flux", input_observations="surface", @@ -216,7 +216,7 @@ def _ghg_surface_satellite_yearly_data( ) surface_data.download() satellite_data = InversionOptimisedGreenhouseGas( - data_variables=data_variable, + data_variable=data_variable, file_format="zip", quantity="surface_flux", input_observations="satellite", @@ -415,7 +415,6 @@ def eac4_hovmoeller_latitude_plot( return fig -# Generate a vertical Hovmoeller plot (levels vs time) for a quantity from the Global Reanalysis EAC4 dataset. def eac4_hovmoeller_levels_plot( data_variable: str, var_name: str, @@ -427,7 +426,7 @@ def eac4_hovmoeller_levels_plot( resampling: str = "1MS", base_colorscale: list[str] = px.colors.sequential.RdBu_r, ) -> go.Figure: - """Hoevmoeller plot of EAC4 multilevel variables, time vs pressure level""" + """Generate a vertical Hovmoeller plot (levels vs time) for a quantity from the Global Reanalysis EAC4 dataset.""" # pylint: disable=too-many-arguments # pylint: disable=too-many-locals # pylint: disable=dangerous-default-value diff --git a/atmospheric_explorer/plotting_apis_new.py b/atmospheric_explorer/plotting_apis_new.py index b4d82dd..6d3c413 100644 --- a/atmospheric_explorer/plotting_apis_new.py +++ b/atmospheric_explorer/plotting_apis_new.py @@ -8,36 +8,122 @@ import plotly.express as px import plotly.graph_objects as go +from xarray import DataArray + +from atmospheric_explorer.cams_interfaces import ( + EAC4Instance, + InversionOptimisedGreenhouseGas, +) +from atmospheric_explorer.data_transformations import clip_and_concat_countries +from atmospheric_explorer.units_conversion import convert_units_array PlotType = Enum("PlotType", ["time_series", "hovmoeller"]) +DatasetName = Enum("DatasetName", ["eac4", "ghg"]) -class PlottingInterface(ABC): - """Generic interface for all plotting APIs""" +class EAC4Parameters: + """Parameters for EAC4 dataset""" - _plot_type: PlotType + # pylint: disable=too-few-public-methods def __init__( - self: PlottingInterface, + self: EAC4Parameters, _data_variable: str, _var_name: str, - _time_period: str, - ): + _file_format: str, + _dates_range: str, + _time_values: str, + ) -> None: self.data_variable = _data_variable self.var_name = _var_name - self.time_period = _time_period + self.file_format = _file_format + self.dates_range = _dates_range + self.time_values = _time_values - @property - def data_variable(self: PlottingInterface) -> str: - """Data variable name""" - return self._data_variable - @data_variable.setter - def data_variable( - self: PlottingInterface, - data_variable_input: str, +class GHGParameters: + """Parameters for GHG dataset""" + + # pylint: disable=too-few-public-methods + + def __init__( + self: GHGParameters, + _data_variable: str, + _file_format: str, + _quantity: str, + _input_observations: str, + _time_aggregation: str, + _year: list[str], + _month: list[str], ) -> None: - self._data_variable = data_variable_input + self.data_variable = _data_variable + self.file_format = _file_format + self.quantity = _quantity + self.input_observations = _input_observations + self.time_aggregation = _time_aggregation + self.year = _year + self.month = _month + + +class PlottingInterface(ABC): + """Generic interface for all plotting APIs""" + + _plot_type: PlotType + + def __init__( + self: PlottingInterface, + _dataset_name: DatasetName, + _eac4_parameters: EAC4Parameters, + _ghg_parameters: GHGParameters, + _countries: list[str], + _title: str, + ): + self.dataset_name = _dataset_name + self.title = _title + self.countries = _countries + # Qui sotto: puo` avere senso tenere comunque degli argomenti generici, + # che ci sono per entrambi i dataset e possono servire sempre nei plot/dowload dati? + match self.dataset_name: + case DatasetName.eac4: + self.eac4_parameters = _eac4_parameters + self.data_variable = _eac4_parameters.data_variable + self.var_name = _eac4_parameters.var_name + case DatasetName.ghg: + self.ghg_parameters = _ghg_parameters + self.data_variable = _ghg_parameters.data_variable + self.var_name = _ghg_parameters.var_name + + @abstractmethod + def download_data(self: PlottingInterface): + """Downloads data""" + match self.dataset_name: + case DatasetName.eac4: + assert self.eac4_parameters is not None + data = EAC4Instance( + self.eac4_parameters.data_variable, + "netcdf", + dates_range=self.eac4_parameters.dates_range, + time_values=self.eac4_parameters.time_values, + ) + data.download() + case DatasetName.ghg: + assert self.ghg_parameters is not None + data = InversionOptimisedGreenhouseGas( + data_variable=self.ghg_parameters.data_variable, + file_format="zip", + quantity=self.ghg_parameters.quantity, + input_observations=self.ghg_parameters.input_observations, + time_aggregation=self.ghg_parameters.time_aggregation, + year=self.ghg_parameters.year, + month=self.ghg_parameters.month, + ) + data.download() + return data.read_dataset() + + @abstractmethod + def transform_data(self: PlottingInterface): + """Transforms data as needed""" + raise NotImplementedError("Method not implemented") @abstractmethod def plot(self: PlottingInterface): @@ -48,18 +134,75 @@ def plot(self: PlottingInterface): class TimeSeriesPlotInstance(PlottingInterface): """Time series plot object""" + # pylint: disable=too-many-arguments + _plot_type: PlotType = PlotType.time_series + _data_array: DataArray + + def __init__( + self: PlottingInterface, + dataset_name: DatasetName, + title: str, + countries: list[str], + resampling: str, + eac4_parameters: EAC4Parameters, + ghg_parameters: GHGParameters, + ): + match dataset_name: + case DatasetName.eac4: + super().__init__( + dataset_name, + eac4_parameters, + None, + countries, + title, + ) + case DatasetName.ghg: + super().__init__( + dataset_name, + None, + ghg_parameters, + countries, + title, + ) + self.resampling = resampling + + def download_data(self: TimeSeriesPlotInstance): + self._data_array = super().download_data() + + def transform_data(self: TimeSeriesPlotInstance): + df_down = self._data_array.rio.write_crs("EPSG:4326") + df_clipped = clip_and_concat_countries(df_down, self.countries).sel( + admin=self.countries[0] + ) + df_agg = ( + df_clipped.mean(dim=["latitude", "longitude"]) + .resample(time=self.resampling, restore_coord_dims=True) + .mean(dim="time") + ) + reference_value = df_agg.mean(dim="time") + df_converted = convert_units_array(df_agg[self.var_name], self.data_variable) + reference_value = df_converted.mean().values + df_anomalies = df_converted - reference_value + df_anomalies.attrs = df_converted.attrs + self._data_array = df_anomalies - # def __init__( - # self: PlottingInterface, - # data_variable: str, - # var_name: str, - # time_period: str, - # ): - # super().__init__(data_variable, var_name, time_period) - - def plot( - self: TimeSeriesPlotInstance, - ) -> go.Figure: - fig = px.line() + def plot(self: TimeSeriesPlotInstance) -> go.Figure: + fig = px.line( + y=self._data_array.values, + x=self._data_array.coords["time"], + markers="o", + ) + fig.update_xaxes(title="Month") + fig.update_yaxes(title=self._data_array.attrs["units"]) + fig.update_layout( + title={ + "text": self.title, + "x": 0.45, + "y": 0.95, + "automargin": True, + "yref": "container", + "font": {"size": 19}, + } + ) return fig