Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modularize plotting APIs #27

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 44 additions & 29 deletions atmospheric_explorer/apis_test.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,11 @@
"metadata": {},
"outputs": [],
"source": [
"from atmospheric_explorer.cams_interfaces import InversionOptimisedGreenhouseGas, EAC4Instance\n",
"from atmospheric_explorer.shapefile import ShapefilesDownloader\n",
"from atmospheric_explorer.utils import get_local_folder\n",
"from atmospheric_explorer.units_conversion import convert_units_array\n",
"import plotly.express as px\n",
"import seaborn as sns\n",
"import matplotlib.pyplot as plt\n",
"import os\n",
"from glob import glob\n",
"import xarray as xr\n",
"from datetime import datetime\n",
"import geopandas as gpd\n",
"from shapely.geometry import mapping\n",
"import plotly.graph_objects as go\n",
"import shutil\n",
"import numpy as np\n",
"import statsmodels.stats.api as sms\n",
"import pandas as pd\n",
"from plotly.subplots import make_subplots\n",
"from math import ceil\n",
"from atmospheric_explorer.plotting_apis import line_with_ci_subplots, clip_and_concat_countries, ghg_surface_satellite_yearly_plot, eac4_anomalies_plot, eac4_hovmoeller_latitude_plot, eac4_hovmoeller_levels_plot\n",
"\n",
"from atmospheric_explorer.utils import hex_to_rgb\n",
"from atmospheric_explorer.data_transformations import confidence_interval"
"import plotly.express as px\n",
"from atmospheric_explorer.plotting_apis import ghg_surface_satellite_yearly_plot, eac4_anomalies_plot, eac4_hovmoeller_latitude_plot, eac4_hovmoeller_levels_plot"
]
},
{
Expand Down Expand Up @@ -175,7 +156,20 @@
"metadata": {},
"outputs": [],
"source": [
"import plotly.express as px"
"fig = px.imshow(\n",
" img = [[1,2], [3,4]],\n",
" x = ['A', 'B'],\n",
" y = [1, 2],\n",
" text_auto=True\n",
")\n",
"fig.update_yaxes(type='log')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# New modularized APIs test"
]
},
{
Expand All @@ -184,21 +178,42 @@
"metadata": {},
"outputs": [],
"source": [
"fig = px.imshow(\n",
" img = [[1,2], [3,4]],\n",
" x = ['A', 'B'],\n",
" y = [1, 2],\n",
" text_auto=True\n",
"from atmospheric_explorer.plotting_apis_new import TimeSeriesPlotInstance, DatasetName, EAC4Parameters"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"test = TimeSeriesPlotInstance(\n",
" DatasetName.eac4,\n",
" \"Test title\",\n",
" [\"Italy\"],\n",
" \"1MS\",\n",
" EAC4Parameters(\n",
" \"total_column_nitrogen_dioxide\",\n",
" \"tcno2\",\n",
" \"netcdf\",\n",
" \"2020-01-01/2020-07-01\",\n",
" \"00:00\",\n",
" ),\n",
" None,\n",
")\n",
"fig.update_yaxes(type='log')"
"\n",
"test.download_data()\n",
"test.transform_data()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
"source": [
"test.plot()"
]
}
],
"metadata": {
Expand Down
4 changes: 2 additions & 2 deletions atmospheric_explorer/cams_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ class InversionOptimisedGreenhouseGas(CAMSDataInterface):

def __init__(
self: InversionOptimisedGreenhouseGas,
data_variables: str | set[str] | list[str],
data_variable: str | set[str] | list[str],
file_format: str,
quantity: str,
input_observations: str,
Expand All @@ -344,7 +344,7 @@ def __init__(
files_dir: str | None = None,
version: str = "latest",
):
super().__init__(data_variables, file_format)
super().__init__(data_variable, file_format)
self.quantity = quantity
self.input_observations = input_observations
self.time_aggregation = time_aggregation
Expand Down
7 changes: 3 additions & 4 deletions atmospheric_explorer/plotting_apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def _ghg_surface_satellite_yearly_data(
# pylint: disable=invalid-name
# Download surface data file
surface_data = InversionOptimisedGreenhouseGas(
data_variables=data_variable,
data_variable=data_variable,
file_format="zip",
quantity="surface_flux",
input_observations="surface",
Expand All @@ -216,7 +216,7 @@ def _ghg_surface_satellite_yearly_data(
)
surface_data.download()
satellite_data = InversionOptimisedGreenhouseGas(
data_variables=data_variable,
data_variable=data_variable,
file_format="zip",
quantity="surface_flux",
input_observations="satellite",
Expand Down Expand Up @@ -415,7 +415,6 @@ def eac4_hovmoeller_latitude_plot(
return fig


# Generate a vertical Hovmoeller plot (levels vs time) for a quantity from the Global Reanalysis EAC4 dataset.
def eac4_hovmoeller_levels_plot(
data_variable: str,
var_name: str,
Expand All @@ -427,7 +426,7 @@ def eac4_hovmoeller_levels_plot(
resampling: str = "1MS",
base_colorscale: list[str] = px.colors.sequential.RdBu_r,
) -> go.Figure:
"""Hoevmoeller plot of EAC4 multilevel variables, time vs pressure level"""
"""Generate a vertical Hovmoeller plot (levels vs time) for a quantity from the Global Reanalysis EAC4 dataset."""
# pylint: disable=too-many-arguments
# pylint: disable=too-many-locals
# pylint: disable=dangerous-default-value
Expand Down
199 changes: 171 additions & 28 deletions atmospheric_explorer/plotting_apis_new.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,36 +8,122 @@

import plotly.express as px
import plotly.graph_objects as go
from xarray import DataArray

from atmospheric_explorer.cams_interfaces import (
EAC4Instance,
InversionOptimisedGreenhouseGas,
)
from atmospheric_explorer.data_transformations import clip_and_concat_countries
from atmospheric_explorer.units_conversion import convert_units_array

PlotType = Enum("PlotType", ["time_series", "hovmoeller"])
DatasetName = Enum("DatasetName", ["eac4", "ghg"])


class PlottingInterface(ABC):
"""Generic interface for all plotting APIs"""
class EAC4Parameters:
"""Parameters for EAC4 dataset"""

_plot_type: PlotType
# pylint: disable=too-few-public-methods

def __init__(
self: PlottingInterface,
self: EAC4Parameters,
_data_variable: str,
_var_name: str,
_time_period: str,
):
_file_format: str,
_dates_range: str,
_time_values: str,
) -> None:
self.data_variable = _data_variable
self.var_name = _var_name
self.time_period = _time_period
self.file_format = _file_format
self.dates_range = _dates_range
self.time_values = _time_values

@property
def data_variable(self: PlottingInterface) -> str:
"""Data variable name"""
return self._data_variable

@data_variable.setter
def data_variable(
self: PlottingInterface,
data_variable_input: str,
class GHGParameters:
"""Parameters for GHG dataset"""

# pylint: disable=too-few-public-methods

def __init__(
self: GHGParameters,
_data_variable: str,
_file_format: str,
_quantity: str,
_input_observations: str,
_time_aggregation: str,
_year: list[str],
_month: list[str],
) -> None:
self._data_variable = data_variable_input
self.data_variable = _data_variable
self.file_format = _file_format
self.quantity = _quantity
self.input_observations = _input_observations
self.time_aggregation = _time_aggregation
self.year = _year
self.month = _month


class PlottingInterface(ABC):
"""Generic interface for all plotting APIs"""

_plot_type: PlotType

def __init__(
self: PlottingInterface,
_dataset_name: DatasetName,
_eac4_parameters: EAC4Parameters,
_ghg_parameters: GHGParameters,
_countries: list[str],
_title: str,
):
self.dataset_name = _dataset_name
self.title = _title
self.countries = _countries
# Below: might it still make sense to keep some generic arguments that exist
# for both datasets and can always be useful for plotting/downloading data?
match self.dataset_name:
case DatasetName.eac4:
self.eac4_parameters = _eac4_parameters
self.data_variable = _eac4_parameters.data_variable
self.var_name = _eac4_parameters.var_name
case DatasetName.ghg:
self.ghg_parameters = _ghg_parameters
self.data_variable = _ghg_parameters.data_variable
self.var_name = _ghg_parameters.var_name

@abstractmethod
def download_data(self: PlottingInterface):
"""Downloads data"""
match self.dataset_name:
case DatasetName.eac4:
assert self.eac4_parameters is not None
data = EAC4Instance(
self.eac4_parameters.data_variable,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would make sense if we used named args here as well, in the EAC4Instance class. I'd also suggest we make these class names more similar (e.g. EAC4Instance and GHGInstance?)

"netcdf",
dates_range=self.eac4_parameters.dates_range,
time_values=self.eac4_parameters.time_values,
)
data.download()
case DatasetName.ghg:
assert self.ghg_parameters is not None
data = InversionOptimisedGreenhouseGas(
data_variable=self.ghg_parameters.data_variable,
file_format="zip",
quantity=self.ghg_parameters.quantity,
input_observations=self.ghg_parameters.input_observations,
time_aggregation=self.ghg_parameters.time_aggregation,
year=self.ghg_parameters.year,
month=self.ghg_parameters.month,
)
data.download()
return data.read_dataset()

@abstractmethod
def transform_data(self: PlottingInterface):
"""Transforms data as needed"""
raise NotImplementedError("Method not implemented")

@abstractmethod
def plot(self: PlottingInterface):
Expand All @@ -48,18 +134,75 @@ def plot(self: PlottingInterface):
class TimeSeriesPlotInstance(PlottingInterface):
"""Time series plot object"""

# pylint: disable=too-many-arguments

_plot_type: PlotType = PlotType.time_series
_data_array: DataArray
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure this is the right type: it becomes a DataFrame on line 188.


def __init__(
self: PlottingInterface,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd like to use named arguments here

dataset_name: DatasetName,
title: str,
countries: list[str],
resampling: str,
eac4_parameters: EAC4Parameters,
ghg_parameters: GHGParameters,
):
match dataset_name:
case DatasetName.eac4:
super().__init__(
dataset_name,
eac4_parameters,
None,
countries,
title,
)
case DatasetName.ghg:
super().__init__(
dataset_name,
None,
ghg_parameters,
countries,
title,
)
self.resampling = resampling

def download_data(self: TimeSeriesPlotInstance):
self._data_array = super().download_data()

def transform_data(self: TimeSeriesPlotInstance):
df_down = self._data_array.rio.write_crs("EPSG:4326")
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still have to change variable names here

df_clipped = clip_and_concat_countries(df_down, self.countries).sel(
admin=self.countries[0]
)
df_agg = (
df_clipped.mean(dim=["latitude", "longitude"])
.resample(time=self.resampling, restore_coord_dims=True)
.mean(dim="time")
)
reference_value = df_agg.mean(dim="time")
df_converted = convert_units_array(df_agg[self.var_name], self.data_variable)
reference_value = df_converted.mean().values
df_anomalies = df_converted - reference_value
df_anomalies.attrs = df_converted.attrs
self._data_array = df_anomalies

# def __init__(
# self: PlottingInterface,
# data_variable: str,
# var_name: str,
# time_period: str,
# ):
# super().__init__(data_variable, var_name, time_period)

def plot(
self: TimeSeriesPlotInstance,
) -> go.Figure:
fig = px.line()
def plot(self: TimeSeriesPlotInstance) -> go.Figure:
fig = px.line(
y=self._data_array.values,
x=self._data_array.coords["time"],
markers="o",
)
fig.update_xaxes(title="Month")
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The x-axis title ("Month") should be a variable rather than a hard-coded string.

fig.update_yaxes(title=self._data_array.attrs["units"])
fig.update_layout(
title={
"text": self.title,
"x": 0.45,
"y": 0.95,
"automargin": True,
"yref": "container",
"font": {"size": 19},
}
)
return fig