Skip to content

Commit

Permalink
Fix ensemble_time_series.py for case with only one time step
Browse files Browse the repository at this point in the history
  • Loading branch information
senesis committed Jun 25, 2024
1 parent 6ada762 commit 338114c
Showing 1 changed file with 61 additions and 33 deletions.
94 changes: 61 additions & 33 deletions scripts/ensemble_time_series_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,24 @@
# -- For this, we use the python library argparse
# --------------------------------------------------------------------------------------------------
from __future__ import print_function, division, unicode_literals, absolute_import
from netCDF4 import Dataset, num2date
import numpy as np
import cftime
from datetime import timedelta
import matplotlib.lines as mlines
import matplotlib.pyplot as plt

import argparse

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
from datetime import timedelta
try:
from datetime import datetime as cdatetime
except:
import datetime as cdatetime
import cftime
import numpy as np

# This use of netCDF4 should be changed to using xarray, for
# homogeneity with CliMAF main code
from netCDF4 import Dataset, num2date


# SS : Try to bypass use of netcdftime, which is not included in conda's NetCDF4, and
Expand All @@ -51,7 +51,8 @@

# -- Initialize the parser
# --------------------------------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='Plot script for CliMAF that handles CliMAF ensemble')
parser = argparse.ArgumentParser(
description='Plot script for CliMAF that handles CliMAF ensemble')

# -- Describe the arguments you need
# --------------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -287,12 +288,14 @@
if '-' in elt:
dumsplit = elt.split('-')
print((0, tuple(map(int, dumsplit[1:len(dumsplit)]))))
linestyles_list[linestyles_list.index(elt)] = (0, tuple(map(int, dumsplit[1:len(dumsplit)])))
linestyles_list[linestyles_list.index(elt)] = (
0, tuple(map(int, dumsplit[1:len(dumsplit)])))

if args.highlight_period_lw:
highlight_period_lw_list = args.highlight_period_lw.split(',')
if len(highlight_period_lw_list) == 1:
highlight_period_lw_list = highlight_period_lw_list * len(filenames_list)
highlight_period_lw_list = highlight_period_lw_list * \
len(filenames_list)
else:
highlight_period_lw_list = [2.5] * len(filenames_list)

Expand All @@ -313,13 +316,16 @@
else:
horizontal_lines_lw_list = [2] * len(horizontal_lines_values_list)
if len(horizontal_lines_lw_list) == 1 and len(horizontal_lines_values_list) > 1:
horizontal_lines_lw_list = horizontal_lines_lw_list * len(horizontal_lines_values_list)
horizontal_lines_lw_list = horizontal_lines_lw_list * \
len(horizontal_lines_values_list)
if args.horizontal_lines_colors:
horizontal_lines_colors_list = args.horizontal_lines_colors.split(',')
else:
horizontal_lines_colors_list = ['black'] * len(horizontal_lines_values_list)
horizontal_lines_colors_list = [
'black'] * len(horizontal_lines_values_list)
if len(horizontal_lines_colors_list) == 1 and len(horizontal_lines_values_list) > 1:
horizontal_lines_colors_list = horizontal_lines_colors_list * len(horizontal_lines_values_list)
horizontal_lines_colors_list = horizontal_lines_colors_list * \
len(horizontal_lines_values_list)

for ind in range(0, len(horizontal_lines_values_list)):
hline_val = horizontal_lines_values_list[ind]
Expand All @@ -341,7 +347,8 @@
tname = dim
break
nctime = dat.variables[tname][:]
t_unit = dat.variables[tname].units # get unit "days since 1950-01-01T00:00:00Z"
# get unit "days since 1950-01-01T00:00:00Z"
t_unit = dat.variables[tname].units
if 'months' in t_unit:
if len(nctime) == 12:
x = np.array(range(1, 13))
Expand Down Expand Up @@ -371,18 +378,22 @@
year = int(strdate.split('-')[0])
month = int(strdate.split('-')[1])
day = int(strdate.split('-')[2])
if year == 0:
year = 3000
datevar.append(cdatetime(year, month, day))
else:
datevar.append(elt)
print('datevar = ', datevar)
#
x = np.array(datevar)
# x = np.array(datevar)[0,:]

# y = test_dat[:,0,0]
y = np.squeeze(test_dat)
if len(y.shape) > 1:
print("input data is not 1D")
if len(y.shape) == 0:
y = np.array([y])
print('lw_list[dataset_number]', lw_list[dataset_number])
print('colors[dataset_number]', colors[dataset_number])
print('linestyles_list[dataset_number]', linestyles_list[dataset_number])
Expand Down Expand Up @@ -410,8 +421,8 @@
print('X :', x[0])
print('X :', np.shape(x))
print('X :', type(x[0]))
print('Y :', int(y[0]))
print('Y :', type(int(y[0])))
#print('Y :', int(y[0]))
#print('Y :', type(int(y[0])))

# handles_for_legend.append(
# #plt.plot(x,y,lw=lw_list[filenames_list.index(pathfilename)], color=colors[filenames_list.index(pathfilename)],
Expand All @@ -426,15 +437,18 @@
# -- Highlight the period used to compute the climatology
if args.highlight_period:
# highlight_period = highlight_period_list[filenames_list.index(pathfilename)]
highlight_period = highlight_period_list[dataset_number] # filenames_list.index(pathfilename)]
# filenames_list.index(pathfilename)]
highlight_period = highlight_period_list[dataset_number]
sep = ('_' if '_' in highlight_period else '-')
dum = highlight_period.split(sep)
startyear = int(dum[0])
endyear = int(dum[1])
#
ind = np.argwhere((x > cdatetime(startyear, 1, 1)) & (x < cdatetime(endyear, 12, 31))).flatten()
ind = np.argwhere((x > cdatetime(startyear, 1, 1)) & (
x < cdatetime(endyear, 12, 31))).flatten()
print('highlight_period = ', highlight_period)
print("highlight_period_lw_list[dataset_number] = ", highlight_period_lw_list[dataset_number])
print("highlight_period_lw_list[dataset_number] = ",
highlight_period_lw_list[dataset_number])
plt.plot(x[ind], y[ind],
lw=highlight_period_lw_list[dataset_number],
alpha=alphas_list[dataset_number],
Expand Down Expand Up @@ -463,13 +477,16 @@
if len(split_x) == 2:
x_date = cdatetime(int(split_x[0]), int(split_x[1]))
elif len(split_x) == 3:
x_date = cdatetime(int(split_x[0]), int(split_x[1]), int(split_x[2]))
x_date = cdatetime(int(split_x[0]), int(
split_x[1]), int(split_x[2]))
elif len(x_text) == 6:
x_date = cdatetime(int(x_text[0:4]), int(x_text[4:6]), 15)
elif len(x_text) == 8:
x_date = cdatetime(int(x_text[0:4]), int(x_text[4:6]), int(x_text[6:8]))
x_date = cdatetime(int(x_text[0:4]), int(
x_text[4:6]), int(x_text[6:8]))
else:
print('--> Date provided as x value could not be interpreted: ', xlim_date)
print(
'--> Date provided as x value could not be interpreted: ', xlim_date)
xlim_period.append(x_date)
plt.xlim(xlim_period)

Expand All @@ -494,7 +511,8 @@
if args.right_string:
plt.title(args.right_string, loc='right', fontsize=right_string_fontsize)
else:
plt.title('Variable = ' + variable, loc='right', fontsize=right_string_fontsize)
plt.title('Variable = ' + variable, loc='right',
fontsize=right_string_fontsize)

# -- X and Y axis labels
if args.xlabel:
Expand All @@ -504,7 +522,8 @@
plt.ylabel(args.ylabel,
fontsize=(float(args.ylabel_fontsize) if args.ylabel_fontsize else default_ylabel_fontsize))

plt.tick_params(labelsize=(float(args.tick_size) if args.tick_size else default_tick_size))
plt.tick_params(labelsize=(float(args.tick_size)
if args.tick_size else default_tick_size))
#
# -- Draw legend by hand
draw_legend = (False if args.draw_legend.lower() in ['false'] else True)
Expand All @@ -514,7 +533,8 @@
legend_fontsize = (args.legend_fontsize if args.legend_fontsize else '12')
legend_ncol = (args.legend_ncol if args.legend_ncol else 1)
legend_frame = (True if args.legend_frame.lower() in ['true'] else False)
legend_colors_list = (args.legend_colors.split(',') if args.legend_colors else colors)
legend_colors_list = (args.legend_colors.split(
',') if args.legend_colors else colors)
leg_dict = dict(bbox_to_anchor=(float(legend_xy_pos.split(',')[0]), float(legend_xy_pos.split(',')[1])),
loc=int(legend_loc), borderaxespad=0., prop={'size': float(legend_fontsize)}, ncol=int(legend_ncol),
frameon=legend_frame)
Expand All @@ -523,7 +543,8 @@
legend_labels_list = args.legend_labels.split(',')
# if add_custom_legend_to_default:
# -- Do we start a new legend or append to the existing one?
legend_handles = (handles_for_legend if args.append_custom_legend_to_default.lower() in ['true'] else [])
legend_handles = (
handles_for_legend if args.append_custom_legend_to_default.lower() in ['true'] else [])
print('colors = ', colors)
print('legend_colors_list = ', legend_colors_list)
for legend_label in legend_labels_list:
Expand All @@ -544,7 +565,8 @@
if len(legend_lw_list) == len(legend_labels_list):
legend_lw_list = [2] * len(filenames_list) + legend_lw_list
if len(legend_lw_list) == (len(legend_labels_list) + 1):
legend_lw_list = [legend_lw_list[0]] * len(filenames_list) + legend_lw_list[1:len(legend_lw_list)]
legend_lw_list = [
legend_lw_list[0]] * len(filenames_list) + legend_lw_list[1:len(legend_lw_list)]

# for legobj in leg.legendHandles:
print('legend_lw_list = ', legend_lw_list)
Expand All @@ -554,20 +576,24 @@
# -- Add some text
if args.text:
text_list = args.text.split('|')
text_fontsize_list = (args.text_fontsize.split(',') if args.text_fontsize else [12] * len(text_list))
text_fontsize_list = (args.text_fontsize.split(
',') if args.text_fontsize else [12] * len(text_list))
if len(text_fontsize_list) == 1 and len(text_list) > 1:
text_fontsize_list = text_fontsize_list * len(text_list)
text_colors_list = (args.text_colors.split(',') if args.text_colors else ['black'] * len(text_list))
text_colors_list = (args.text_colors.split(
',') if args.text_colors else ['black'] * len(text_list))
if len(text_colors_list) == 1 and len(text_list) > 1:
text_colors_list = text_colors_list * len(text_list)
text_verticalalignment_list = (args.text_verticalalignment.split(',') if args.text_verticalalignment
else ['bottom'] * len(text_list))
if len(text_verticalalignment_list) == 1 and len(text_list) > 1:
text_verticalalignment_list = text_verticalalignment_list * len(text_list)
text_verticalalignment_list = text_verticalalignment_list * \
len(text_list)
text_horizontalalignment_list = (args.text_horizontalalignment.split(',') if args.text_horizontalalignment
else ['left'] * len(text_list))
if len(text_horizontalalignment_list) == 1 and len(text_list) > 1:
text_horizontalalignment_list = text_horizontalalignment_list * len(text_list)
text_horizontalalignment_list = text_horizontalalignment_list * \
len(text_list)
for text_elt in text_list:
text_ind = text_list.index(text_elt)
# -- treatment of the x value = date
Expand All @@ -581,11 +607,13 @@
if len(split_x) == 2:
x_date = cdatetime(int(split_x[0]), int(split_x[1]))
elif len(split_x) == 3:
x_date = cdatetime(int(split_x[0]), int(split_x[1]), int(split_x[2]))
x_date = cdatetime(int(split_x[0]), int(
split_x[1]), int(split_x[2]))
elif len(x_text) == 6:
x_date = cdatetime(int(x_text[0:4]), int(x_text[4:6]))
elif len(x_text) == 8:
x_date = cdatetime(int(x_text[0:4]), int(x_text[4:6]), int(x_text[6:8]))
x_date = cdatetime(int(x_text[0:4]), int(
x_text[4:6]), int(x_text[6:8]))
else:
print('--> Date provided as x value could not be interpreted: ', x_text)
# -- y
Expand Down

0 comments on commit 338114c

Please sign in to comment.