serializable gdac pre-processor
gmaze committed Sep 26, 2024
1 parent 6f74686 commit 4f688a8
Showing 2 changed files with 195 additions and 74 deletions.
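What this commit does, going by the title and the new @deprecated("Not serializable") marker further down: the multi-profile pre-processing step moves from a bound method of the fetcher (self._preprocess_multiprof) to a module-level function (pre_process_multiprof in gdac_data_processors), so the callable can be pickled and shipped to parallel workers. A minimal sketch of the underlying constraint, with an illustrative stand-in class rather than argopy's actual fetcher:

import pickle
import threading


class Fetcher:
    # Stand-in for a fetcher holding unpicklable state (locks, open stores).
    def __init__(self):
        self._lock = threading.Lock()  # a '_thread.lock' cannot be pickled

    def preprocess(self, ds):
        return ds


def preprocess_fn(ds):
    # Module-level equivalent: pickled by reference, no instance state attached.
    return ds


f = Fetcher()
try:
    pickle.dumps(f.preprocess)  # pickling a bound method pickles the instance too
except TypeError as err:
    print("bound method:", err)

pickle.dumps(preprocess_fn)  # fine: serialized as a reference to the module function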
149 changes: 75 additions & 74 deletions argopy/data_fetchers/gdac_data.py
@@ -4,6 +4,7 @@
This is not intended to be used directly, only by the facade at fetchers.py
"""

import numpy as np
import pandas as pd
import xarray as xr
@@ -13,10 +14,13 @@
import logging

from ..utils.format import format_oneline, argo_split_path
from ..utils.decorators import deprecated
from ..options import OPTIONS, check_gdac_path
from ..errors import DataNotFound
from ..stores import ArgoIndex
from .proto import ArgoDataFetcherProto
from .gdac_data_processors import pre_process_multiprof, filter_points


log = logging.getLogger("argopy.gdac.data")
access_points = ["wmo", "box"]
@@ -42,27 +46,27 @@ class GDACArgoDataFetcher(ArgoDataFetcherProto):
###
@abstractmethod
def init(self, *args, **kwargs):
""" Initialisation for a specific fetcher """
"""Initialisation for a specific fetcher"""
raise NotImplementedError("Not implemented")

###
# Methods that must not change
###
def __init__(
self,
gdac: str = "",
ds: str = "",
cache: bool = False,
cachedir: str = "",
dimension: str = "point",
errors: str = "raise",
parallel: bool = False,
parallel_method: str = "thread",
progress: bool = False,
api_timeout: int = 0,
**kwargs
self,
gdac: str = "",
ds: str = "",
cache: bool = False,
cachedir: str = "",
dimension: str = "point",
errors: str = "raise",
parallel: bool = False,
parallel_method: str = "thread",
progress: bool = False,
api_timeout: int = 0,
**kwargs
):
""" Init fetcher
"""Init fetcher
Parameters
----------
@@ -117,9 +121,7 @@ def __init__(
if "N_RECORDS" in kwargs:
nrows = kwargs["N_RECORDS"]
# Number of records in the index, this will force to load the index file:
self.N_RECORDS = self.indexfs.load(
nrows=nrows
).N_RECORDS
self.N_RECORDS = self.indexfs.load(nrows=nrows).N_RECORDS
self._post_filter_points = False

# Set method to download data:
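In passing, the index store used a few lines up (indexfs, an ArgoIndex per the imports at the top of the file) supports partial loads: load(nrows=...) caps how many index records are read, and N_RECORDS reports the record count once loaded. A quick illustrative use, assuming the default index file (my reading of the call above, not documented behaviour):

from argopy.stores import ArgoIndex

idx = ArgoIndex()       # default GDAC index
idx.load(nrows=100)     # read at most the first 100 index records
print(idx.N_RECORDS)    # record count after loading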
@@ -160,13 +162,13 @@ def __repr__(self):
return "\n".join(summary)

def cname(self):
""" Return a unique string defining the constraints """
"""Return a unique string defining the constraints"""
return self._cname()

@property
@abstractmethod
def uri(self):
""" Return the list of files to load
"""Return the list of files to load
Returns
-------
@@ -175,7 +177,7 @@ def uri(self):
raise NotImplementedError("Not implemented")

def uri_mono2multi(self, URIs: list):
""" Convert mono-profile URI files to multi-profile files
"""Convert mono-profile URI files to multi-profile files
Multi-profile file name is based on the dataset requested ('phy', 'bgc'/'bgc-s')
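The mapping this docstring describes can be sketched as follows, assuming the standard GDAC layout in which mono-profile files live under a float's profiles/ folder and the multi-profile file sits next to that folder; the suffix choice (prof vs Sprof for synthetic BGC files) and the helper below are illustrative, not argopy's actual implementation:

def mono2multi_sketch(mono_path: str, ds: str = "phy") -> str:
    # e.g. "dac/coriolis/6902746/profiles/R6902746_001.nc"
    float_dir = mono_path.rsplit("/profiles/", 1)[0]  # "dac/coriolis/6902746"
    wmo = float_dir.rsplit("/", 1)[-1]                # "6902746"
    suffix = "Sprof" if ds in ("bgc", "bgc-s") else "prof"
    return f"{float_dir}/{wmo}_{suffix}.nc"

print(mono2multi_sketch("dac/coriolis/6902746/profiles/R6902746_001.nc"))
# -> "dac/coriolis/6902746/6902746_prof.nc"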
@@ -225,7 +227,7 @@ def mono2multi(mono_path):

@property
def cachepath(self):
""" Return path to cache file(s) for this request
"""Return path to cache file(s) for this request
Returns
-------
@@ -234,13 +236,14 @@ def cachepath(self):
return [self.fs.cachepath(url) for url in self.uri]

def clear_cache(self):
""" Remove cached files and entries from resources opened with this fetcher """
"""Remove cached files and entries from resources opened with this fetcher"""
self.indexfs.clear_cache()
self.fs.clear_cache()
return self

@deprecated("Not serializable")
def _preprocess_multiprof(self, ds):
""" Pre-process one Argo multi-profile file as a collection of points
"""Pre-process one Argo multi-profile file as a collection of points
Parameters
----------
@@ -255,7 +258,7 @@ def _preprocess_multiprof(self, ds):
# Remove raw netcdf file attributes and replace them with argopy ones:
raw_attrs = ds.attrs
ds.attrs = {}
ds.attrs.update({'raw_attrs': raw_attrs})
ds.attrs.update({"raw_attrs": raw_attrs})

# Rename JULD and JULD_QC to TIME and TIME_QC
ds = ds.rename(
@@ -293,7 +296,7 @@ def _preprocess_multiprof(self, ds):
try:
ds.attrs["Fetched_by"] = getpass.getuser()
except: # noqa: E722
ds.attrs["Fetched_by"] = 'anonymous'
ds.attrs["Fetched_by"] = "anonymous"
ds.attrs["Fetched_date"] = pd.to_datetime("now", utc=True).strftime("%Y/%m/%d")
ds.attrs["Fetched_constraints"] = self.cname()
ds.attrs["Fetched_uri"] = ds.encoding["source"]
@@ -304,8 +307,11 @@ def _preprocess_multiprof(self, ds):

return ds

def pre_process(self, ds, *args, **kwargs):
return pre_process_multiprof(ds, *args, **kwargs)

def to_xarray(self, errors: str = "ignore"):
""" Load Argo data and return a :class:`xarray.Dataset`
"""Load Argo data and return a :class:`xarray.Dataset`
Parameters
----------
@@ -332,19 +338,36 @@ def to_xarray(self, errors: str = "ignore"):
elif len(self.uri) == 0:
raise DataNotFound("No data found for: %s" % self.indexfs.cname)

# Download data:
if hasattr(self, "BOX"):
access_point = "BOX"
access_point_opts = {'BOX': self.BOX}
elif hasattr(self, "CYC"):
access_point = "CYC"
access_point_opts = {'CYC': self.CYC}
elif hasattr(self, "WMO"):
access_point = "WMO"
access_point_opts = {'WMO': self.WMO}

# Download and pre-process data:
ds = self.fs.open_mfdataset(
self.uri,
method=self.method,
concat_dim="N_POINTS",
concat=True,
preprocess=self._preprocess_multiprof,
preprocess=pre_process_multiprof,
preprocess_opts={
"access_point": access_point,
"access_point_opts": access_point_opts,
"pre_filter_points": self._post_filter_points,
},
progress=self.progress,
errors=errors,
open_dataset_opts={'xr_opts': {'decode_cf': 1, 'use_cftime': 0, 'mask_and_scale': 1}},
open_dataset_opts={
"xr_opts": {"decode_cf": 1, "use_cftime": 0, "mask_and_scale": 1}
},
)

# Data post-processing:
# Meta-data processing:
ds["N_POINTS"] = np.arange(
0, len(ds["N_POINTS"])
) # Re-index to avoid duplicate values
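Note that both arguments added above are picklable: pre_process_multiprof is a module-level function and preprocess_opts a plain dict, which is exactly what a process-based open_mfdataset backend needs to ship work to subprocesses. A rough sketch of that dispatch pattern, assuming the pre-processor takes the dataset plus the keyword options shown above (this is not argopy's actual store code):

from concurrent.futures import ProcessPoolExecutor
from functools import partial

import xarray as xr


def _load_one(uri, preprocess, preprocess_opts):
    # Open one resource and apply the (picklable) pre-processor to it:
    ds = xr.open_dataset(uri)
    return preprocess(ds, **preprocess_opts)


def open_mfdataset_sketch(uris, preprocess, preprocess_opts):
    # partial() over module-level callables stays picklable, so workers can receive it:
    worker = partial(_load_one, preprocess=preprocess, preprocess_opts=preprocess_opts)
    with ProcessPoolExecutor() as pool:
        datasets = list(pool.map(worker, uris))
    return xr.concat(datasets, dim="N_POINTS")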
@@ -355,18 +378,20 @@ def to_xarray(self, errors: str = "ignore"):
if "Fetched_from" not in ds.attrs:
raw_attrs = ds.attrs
ds.attrs = {}
ds.attrs.update({'raw_attrs': raw_attrs})
ds.attrs.update({"raw_attrs": raw_attrs})
if self.dataset_id == "phy":
ds.attrs["DATA_ID"] = "ARGO"
if self.dataset_id == "bgc":
if self.dataset_id in ["bgc", "bgc-s"]:
ds.attrs["DATA_ID"] = "ARGO-BGC"
ds.attrs["DOI"] = "http://doi.org/10.17882/42182"
ds.attrs["Fetched_from"] = self.server
try:
ds.attrs["Fetched_by"] = getpass.getuser()
except:
ds.attrs["Fetched_by"] = 'anonymous'
ds.attrs["Fetched_date"] = pd.to_datetime("now", utc=True).strftime("%Y/%m/%d")
ds.attrs["Fetched_by"] = "anonymous"
ds.attrs["Fetched_date"] = pd.to_datetime("now", utc=True).strftime(
"%Y/%m/%d"
)

ds.attrs["Fetched_constraints"] = self.cname()
if len(self.uri) == 1:
@@ -376,42 +401,18 @@ def to_xarray(self, errors: str = "ignore"):

return ds

@deprecated("Refactored to GDAC pre-processor submodule")
def filter_points(self, ds):
""" Enforce request criteria
This may be necessary when, for download performance, we worked with multi- rather than mono-profile
files: having loaded and merged multi-profile files, we must then retain only the profiles requested.
"""
if hasattr(self, "BOX"):
# - box = [lon_min, lon_max, lat_min, lat_max, pres_min, pres_max]
# - box = [lon_min, lon_max, lat_min, lat_max, pres_min, pres_max, datim_min, datim_max]
ds = (
ds.where(ds["LONGITUDE"] >= self.BOX[0], drop=True)
.where(ds["LONGITUDE"] < self.BOX[1], drop=True)
.where(ds["LATITUDE"] >= self.BOX[2], drop=True)
.where(ds["LATITUDE"] < self.BOX[3], drop=True)
.where(ds["PRES"] >= self.BOX[4], drop=True) # todo what about PRES_ADJUSTED ?
.where(ds["PRES"] < self.BOX[5], drop=True)
)
if len(self.BOX) == 8:
ds = ds.where(
ds["TIME"] >= np.datetime64(self.BOX[6]), drop=True
).where(ds["TIME"] < np.datetime64(self.BOX[7]), drop=True)

if hasattr(self, "CYC"):
this_mask = xr.DataArray(
np.zeros_like(ds["N_POINTS"]),
dims=["N_POINTS"],
coords={"N_POINTS": ds["N_POINTS"]},
)
for cyc in self.CYC:
this_mask += ds["CYCLE_NUMBER"] == cyc
this_mask = this_mask >= 1 # any
ds = ds.where(this_mask, drop=True)

ds["N_POINTS"] = np.arange(0, len(ds["N_POINTS"]))

return ds
access_point = "BOX"
access_point_opts = {'BOX': self.BOX}
elif hasattr(self, "CYC"):
access_point = "CYC"
access_point_opts = {'CYC': self.CYC}
elif hasattr(self, "WMO"):
access_point = "WMO"
access_point_opts = {'WMO': self.WMO}
return filter_points(ds, access_point=access_point, **access_point_opts)

def transform_data_mode(self, ds: xr.Dataset, **kwargs):
"""Apply xarray argo accessor transform_data_mode method"""
@@ -436,9 +437,9 @@ def filter_qc(self, ds: xr.Dataset, **kwargs):
def filter_researchmode(self, ds: xr.Dataset, *args, **kwargs) -> xr.Dataset:
"""Filter dataset for research user mode
This filter will select only QC=1 delayed mode data with pressure errors smaller than 20db

Use this filter instead of transform_data_mode and filter_qc
"""
ds = ds.argo.filter_researchmode(**kwargs)
if ds.argo._type == "point":
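The docstring above pins research mode down to three criteria: delayed-mode data, QC=1, and a pressure error below 20 db. A hedged xarray rendition of those criteria (the real logic lives in the argo accessor's filter_researchmode; variable names follow Argo conventions and QC flags are assumed already cast to integers, as argopy does):

def filter_researchmode_sketch(ds):
    # Delayed-mode points only:
    ds = ds.where(ds["DATA_MODE"] == "D", drop=True)
    # Good-quality (QC=1) adjusted variables, where present:
    for v in ("PRES_ADJUSTED_QC", "TEMP_ADJUSTED_QC", "PSAL_ADJUSTED_QC"):
        if v in ds:
            ds = ds.where(ds[v] == 1, drop=True)
    # Pressure error below 20 db:
    return ds.where(ds["PRES_ADJUSTED_ERROR"] < 20, drop=True)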
@@ -457,7 +458,7 @@ class Fetch_wmo(GDACArgoDataFetcher):
"""

def init(self, WMO: list = [], CYC=None, **kwargs):
""" Create Argo data loader for WMOs
"""Create Argo data loader for WMOs
Parameters
----------
Expand All @@ -483,7 +484,7 @@ def init(self, WMO: list = [], CYC=None, **kwargs):

@property
def uri(self):
""" List of files to load for a request
"""List of files to load for a request
Returns
-------
@@ -513,7 +514,7 @@ class Fetch_box(GDACArgoDataFetcher):
"""

def init(self, box: list, nrows=None, **kwargs):
""" Create Argo data loader
"""Create Argo data loader
Parameters
----------
@@ -540,7 +541,7 @@ def init(self, box: list, nrows=None, **kwargs):

@property
def uri(self):
""" List of files to load for a request
"""List of files to load for a request
Returns
-------
[Diff of the second changed file did not render in this view.]
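The user-visible payoff of a picklable pre-processor is that process-based parallel fetching no longer has to drag fetcher state through the serializer. Assuming argopy's documented fetcher front end (and mirroring the parallel/parallel_method keywords visible in the constructor above), that would look like:

from argopy import DataFetcher

f = DataFetcher(src="gdac", parallel=True, parallel_method="process")
ds = f.float(6902746).to_xarray()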
