fix memory
malmans2 committed Feb 6, 2024
1 parent ffaca82 commit 1c56c2f
Showing 1 changed file with 114 additions and 15 deletions.
129 changes: 114 additions & 15 deletions notebooks/wp3/WIP-hit_rate.ipynb
@@ -23,9 +23,12 @@
"metadata": {},
"outputs": [],
"source": [
"import tempfile\n",
"\n",
"import cartopy.crs as ccrs\n",
"import matplotlib.pyplot as plt\n",
"import regionmask\n",
"import scipy.stats\n",
"import xarray as xr\n",
"from c3s_eqc_automatic_quality_control import diagnostics, download\n",
"\n",
@@ -48,7 +51,6 @@
"outputs": [],
"source": [
"# Time\n",
"year_forecast = 2023\n",
"year_start_hindcast = 1993\n",
"year_stop_hindcast = 2016\n",
"\n",
@@ -182,31 +184,63 @@
"metadata": {},
"outputs": [],
"source": [
"def compute_tercile_occupation(ds, region):\n",
" # Anomaly\n",
" ds = ds - diagnostics.time_weighted_mean(ds)\n",
"def mode(*args, axis=None, **kwargs):\n",
" return scipy.stats.mode(*args, axis=axis, **kwargs).mode\n",
"\n",
" # Reindex using year/month\n",
" time = ds[\"forecast_reference_time\"]\n",
" ds = ds.assign_coords(\n",
" year=(time.name, time.dt.year.data),\n",
" month=(time.name, time.dt.month.data),\n",
" )\n",
" ds = ds.set_index({time.name: (\"year\", \"month\")}).unstack(time.name)\n",
"\n",
"def compute_tercile_occupation(ds, region):\n",
" # Mask region\n",
" mask = regionmask.defined_regions.srex.mask(ds)\n",
" index = regionmask.defined_regions.srex.map_keys(region)\n",
" ds = ds.where((mask == index).compute(), drop=True)\n",
"\n",
" # Get valid and starting time\n",
" if \"leadtime_month\" in ds.dims:\n",
" ds = ds.rename(forecast_reference_time=\"starting_time\")\n",
" ds = ds.stack(\n",
" valid_time=(\"starting_time\", \"leadtime_month\"),\n",
" create_index=False,\n",
" )\n",
" ds[\"valid_time\"] = ds[\"starting_time\"].values\n",
" for shift in set(ds[\"leadtime_month\"].values):\n",
" shifted = ds.indexes[\"valid_time\"].shift(shift - 1, \"MS\")\n",
" ds[\"valid_time\"] = ds[\"valid_time\"].where(\n",
" ds[\"leadtime_month\"] != shift, shifted\n",
" )\n",
" else:\n",
" ds = ds.rename(forecast_reference_time=\"valid_time\")\n",
"\n",
" # Compute anomaly\n",
" climatology = diagnostics.time_weighted_mean(ds, time_name=\"valid_time\")\n",
" climatology = climatology.mean(set(climatology.dims) - {\"latitude\", \"longitude\"})\n",
" ds -= climatology\n",
"\n",
" # Reindex using year/month/starting month\n",
" time = ds[\"valid_time\"]\n",
" coords = {\n",
" \"year\": (time.name, time.dt.year.data),\n",
" \"month\": (time.name, time.dt.month.data),\n",
" }\n",
" if \"starting_time\" in ds.coords:\n",
" coords[\"starting_month\"] = (time.name, time[\"starting_time\"].dt.month.data)\n",
" ds = ds.assign_coords(coords)\n",
" ds = ds.set_index({time.name: tuple(coords)}).unstack(time.name)\n",
"\n",
" # Spatial mean\n",
" ds = diagnostics.spatial_weighted_mean(ds, weights=False)\n",
"\n",
" # Get quantiles\n",
" quantiles = ds.chunk(year=-1).quantile([1 / 3, 2 / 3], \"year\")\n",
" mask = xr.zeros_like(ds, None)\n",
" mask = xr.where(ds < quantiles.sel(quantile=1 / 3), -1, mask)\n",
" mask = xr.where(ds > quantiles.sel(quantile=2 / 3), 1, mask)\n",
" low = quantiles.sel(quantile=1 / 3)\n",
" high = quantiles.sel(quantile=2 / 3)\n",
" mask = xr.full_like(ds, None)\n",
" mask = xr.where(ds < low, -1, mask)\n",
" mask = xr.where((ds >= low) & (ds <= high), 0, mask)\n",
" mask = xr.where(ds > high, 1, mask)\n",
"\n",
" if \"realization\" in mask.dims:\n",
" # Get mode\n",
" mask = mask.reduce(mode, dim=\"realization\")\n",
"\n",
" return mask"
]
@@ -255,7 +289,72 @@
"metadata": {},
"outputs": [],
"source": [
"ds_reanalysis.reset_coords(drop=True)[\"2m_temperature\"].plot()"
"ds_reanalysis.reset_coords(drop=True)[\"2m_temperature\"].plot(cmap=\"viridis\")"
]
},
{
"cell_type": "markdown",
"id": "14",
"metadata": {},
"source": [
"## Download and transform seasonal forecast"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "15",
"metadata": {},
"outputs": [],
"source": [
"# Get the seasonal forecast data\n",
"datasets = []\n",
"for centre, request_kwargs in centres.items():\n",
" for region in regions:\n",
" dataarrays = []\n",
" for variable in variables:\n",
" print(f\"{centre=} {region=} {variable=}\")\n",
" if variable in missing_variables.get(centre, []):\n",
" print(\"SKIP\")\n",
" continue\n",
"\n",
" with tempfile.TemporaryDirectory() as TMPDIR:\n",
" ds = download.download_and_transform(\n",
" collection_id_seasonal,\n",
" request_seasonal\n",
" | {\"originating_centre\": centre, \"variable\": variable}\n",
" | request_kwargs,\n",
" chunks=chunks,\n",
" transform_chunks=False,\n",
" transform_func=compute_tercile_occupation,\n",
" transform_func_kwargs={\"region\": region},\n",
" backend_kwargs={\n",
" \"time_dims\": (\n",
" \"forecastMonth\",\n",
" (\n",
" \"indexing_time\"\n",
" if centre in [\"ukmo\", \"jma\", \"ncep\"]\n",
" else \"time\"\n",
" ),\n",
" )\n",
" },\n",
" )\n",
" (da,) = ds.data_vars.values()\n",
" dataarrays.append(da.rename(variable))\n",
" ds = xr.merge(dataarrays)\n",
" datasets.append(ds.expand_dims(centre=[centre], region=[region]).compute())\n",
"ds_seasonal = xr.merge(datasets)\n",
"del datasets"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "16",
"metadata": {},
"outputs": [],
"source": [
"ds_seasonal[\"2m_temperature\"].plot(col=\"year\", col_wrap=5, cmap=\"viridis\")"
]
}
],