1
1
#!/usr/bin/env python
2
2
3
- # Copyright (c) 2024 Oracle and/or its affiliates.
3
+ # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
4
4
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
5
6
6
import logging
@@ -43,7 +43,11 @@ def _add_unit(num, unit):
43
43
def _fit_model (data , params , additional_regressors ):
44
44
from prophet import Prophet
45
45
46
+ monthly_seasonality = params .pop ("monthly_seasonality" , False )
46
47
model = Prophet (** params )
48
+ if monthly_seasonality :
49
+ model .add_seasonality (name = "monthly" , period = 30.5 , fourier_order = 5 )
50
+ params ["monthly_seasonality" ] = monthly_seasonality
47
51
for add_reg in additional_regressors :
48
52
model .add_regressor (add_reg )
49
53
model .fit (data )
@@ -256,7 +260,7 @@ def _generate_report(self):
256
260
self .outputs [s_id ], include_legend = True
257
261
),
258
262
series_ids = series_ids ,
259
- target_category_column = self .target_cat_col
263
+ target_category_column = self .target_cat_col ,
260
264
)
261
265
section_1 = rc .Block (
262
266
rc .Heading ("Forecast Overview" , level = 2 ),
@@ -269,7 +273,7 @@ def _generate_report(self):
269
273
sec2 = _select_plot_list (
270
274
lambda s_id : self .models [s_id ].plot_components (self .outputs [s_id ]),
271
275
series_ids = series_ids ,
272
- target_category_column = self .target_cat_col
276
+ target_category_column = self .target_cat_col ,
273
277
)
274
278
section_2 = rc .Block (
275
279
rc .Heading ("Forecast Broken Down by Trend Component" , level = 2 ), sec2
@@ -285,7 +289,7 @@ def _generate_report(self):
285
289
sec3 = _select_plot_list (
286
290
lambda s_id : sec3_figs [s_id ],
287
291
series_ids = series_ids ,
288
- target_category_column = self .target_cat_col
292
+ target_category_column = self .target_cat_col ,
289
293
)
290
294
section_3 = rc .Block (rc .Heading ("Forecast Changepoints" , level = 2 ), sec3 )
291
295
@@ -299,7 +303,9 @@ def _generate_report(self):
299
303
pd .Series (
300
304
m .seasonalities ,
301
305
index = pd .Index (m .seasonalities .keys (), dtype = "object" ),
302
- name = s_id if self .target_cat_col else self .original_target_column ,
306
+ name = s_id
307
+ if self .target_cat_col
308
+ else self .original_target_column ,
303
309
dtype = "object" ,
304
310
)
305
311
)
@@ -330,11 +336,15 @@ def _generate_report(self):
330
336
self .formatted_local_explanation = aggregate_local_explanations
331
337
332
338
if not self .target_cat_col :
333
- self .formatted_global_explanation = self .formatted_global_explanation .rename (
334
- {"Series 1" : self .original_target_column },
335
- axis = 1 ,
339
+ self .formatted_global_explanation = (
340
+ self .formatted_global_explanation .rename (
341
+ {"Series 1" : self .original_target_column },
342
+ axis = 1 ,
343
+ )
344
+ )
345
+ self .formatted_local_explanation .drop (
346
+ "Series" , axis = 1 , inplace = True
336
347
)
337
- self .formatted_local_explanation .drop ("Series" , axis = 1 , inplace = True )
338
348
339
349
# Create a markdown section for the global explainability
340
350
global_explanation_section = rc .Block (
0 commit comments