Skip to content

Commit

Permalink
feat: add Expr.dt.add_business_days and Series.dt.add_business_days
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Apr 12, 2024
1 parent 0b84b14 commit d76e501
Show file tree
Hide file tree
Showing 14 changed files with 875 additions and 14 deletions.
2 changes: 1 addition & 1 deletion crates/polars-ops/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ is_unique = []
unique_counts = []
is_between = []
approx_unique = []
business = ["dtype-date"]
business = ["dtype-date", "chrono"]
fused = []
cutqcut = ["dtype-categorical", "dtype-struct"]
rle = ["dtype-struct"]
Expand Down
289 changes: 280 additions & 9 deletions crates/polars-ops/src/series/ops/business.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,22 @@
use polars_core::prelude::arity::binary_elementwise_values;
#[cfg(feature = "dtype-date")]
use chrono::DateTime;
use polars_core::prelude::arity::{binary_elementwise_values, try_binary_elementwise};
use polars_core::prelude::*;
#[cfg(feature = "dtype-date")]
use polars_core::utils::arrow::temporal_conversions::SECONDS_IN_DAY;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

#[cfg(feature = "timezones")]
use crate::prelude::replace_time_zone;

#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum Roll {
Forward,
Backward,
Raise,
}

/// Count the number of business days between `start` and `end`, excluding `end`.
///
Expand Down Expand Up @@ -93,19 +110,19 @@ fn business_day_count_impl(
Err(x) => x as i32 + holidays_begin,
};

let mut start_weekday = weekday(start_date);
let mut start_day_of_week = get_day_of_week(start_date);
let diff = end_date - start_date;
let whole_weeks = diff / 7;
let mut count = -(holidays_end - holidays_begin);
count += whole_weeks * n_business_days_in_week_mask;
start_date += whole_weeks * 7;
while start_date < end_date {
// SAFETY: week_mask is length 7, start_weekday is between 0 and 6
if unsafe { *week_mask.get_unchecked(start_weekday) } {
// SAFETY: week_mask is length 7, start_day_of_week is between 0 and 6
if unsafe { *week_mask.get_unchecked(start_day_of_week) } {
count += 1;
}
start_date += 1;
start_weekday = increment_weekday(start_weekday);
start_day_of_week = increment_day_of_week(start_day_of_week);
}
if swapped {
-count
Expand All @@ -114,14 +131,260 @@ fn business_day_count_impl(
}
}

/// Add a given number of business days.
///
/// # Arguments
/// - `start`: Series holding start dates.
/// - `n`: Number of business days to add.
/// - `week_mask`: A boolean array of length 7, where `true` indicates that the day is a business day.
/// - `holidays`: timestamps that are holidays. Must be provided as i32, i.e. the number of
/// days since the UNIX epoch.
/// - `roll`: what to do when the start date doesn't land on a business day:
/// - `Roll::Forward`: roll forward to the next business day.
/// - `Roll::Backward`: roll backward to the previous business day.
/// - `Roll::Raise`: raise an error.
pub fn add_business_days(
start: &Series,
n: &Series,
week_mask: [bool; 7],
holidays: &[i32],
roll: Roll,
) -> PolarsResult<Series> {
if !week_mask.iter().any(|&x| x) {
polars_bail!(ComputeError:"`week_mask` must have at least one business day");
}

match start.dtype() {
DataType::Date => {},
#[cfg(feature = "dtype-datetime")]
DataType::Datetime(time_unit, None) => {
let result_date =
add_business_days(&start.cast(&DataType::Date)?, n, week_mask, holidays, roll)?;
let start_time = start
.cast(&DataType::Time)?
.cast(&DataType::Duration(*time_unit))?;
return Ok(result_date.cast(&DataType::Datetime(*time_unit, None))? + start_time);
},
#[cfg(feature = "timezones")]
DataType::Datetime(time_unit, Some(time_zone)) => {
let start_naive = replace_time_zone(
start.datetime().unwrap(),
None,
&StringChunked::from_iter(std::iter::once("raise")),
NonExistent::Raise,
)?;
let result_date = add_business_days(
&start_naive.cast(&DataType::Date)?,
n,
week_mask,
holidays,
roll,
)?;
let start_time = start_naive
.cast(&DataType::Time)?
.cast(&DataType::Duration(*time_unit))?;
let result_naive =
result_date.cast(&DataType::Datetime(*time_unit, None))? + start_time;
let result_tz_aware = replace_time_zone(
result_naive.datetime().unwrap(),
Some(time_zone),
&StringChunked::from_iter(std::iter::once("raise")),
NonExistent::Raise,
)?;
return Ok(result_tz_aware.into_series());
},
_ => polars_bail!(InvalidOperation: "expected date or datetime, got {}", start.dtype()),
}

let holidays = normalise_holidays(holidays, &week_mask);
let start_dates = start.date()?;
let n = match &n.dtype() {
DataType::Int64 | DataType::UInt64 | DataType::UInt32 => n.cast(&DataType::Int32)?,
DataType::Int32 => n.clone(),
_ => {
polars_bail!(InvalidOperation: "expected Int64, Int32, UInt64, or UInt32, got {}", n.dtype())
},
};
let n = n.i32()?;
let n_business_days_in_week_mask = week_mask.iter().filter(|&x| *x).count() as i32;

let out: Int32Chunked = match (start_dates.len(), n.len()) {
(_, 1) => {
if let Some(n) = n.get(0) {
start_dates.try_apply_nonnull_values_generic(|start_date| {
let (start_date, day_of_week) =
roll_start_date(start_date, roll, &week_mask, &holidays)?;
Ok::<i32, PolarsError>(add_business_days_impl(
start_date,
day_of_week,
n,
&week_mask,
n_business_days_in_week_mask,
&holidays,
))
})?
} else {
Int32Chunked::full_null(start_dates.name(), start_dates.len())
}
},
(1, _) => {
if let Some(start_date) = start_dates.get(0) {
let (start_date, day_of_week) =
roll_start_date(start_date, roll, &week_mask, &holidays)?;
n.try_apply_nonnull_values_generic(|n| {
Ok::<i32, PolarsError>(add_business_days_impl(
start_date,
day_of_week,
n,
&week_mask,
n_business_days_in_week_mask,
&holidays,
))
})?
} else {
Int32Chunked::full_null(start_dates.name(), n.len())
}
},
_ => try_binary_elementwise(start_dates, n, |opt_start_date, opt_n| {
match (opt_start_date, opt_n) {
(Some(start_date), Some(n)) => {
let (start_date, day_of_week) =
roll_start_date(start_date, roll, &week_mask, &holidays)?;
Ok::<Option<i32>, PolarsError>(Some(add_business_days_impl(
start_date,
day_of_week,
n,
&week_mask,
n_business_days_in_week_mask,
&holidays,
)))
},
_ => Ok(None),
}
})?,
};
Ok(out.into_date().into_series())
}

/// Ported from:
/// https://github.com/numpy/numpy/blob/e59c074842e3f73483afa5ddef031e856b9fd313/numpy/_core/src/multiarray/datetime_busday.c#L265-L353
fn add_business_days_impl(
mut date: i32,
mut day_of_week: usize,
mut n: i32,
week_mask: &[bool; 7],
n_business_days_in_week_mask: i32,
holidays: &[i32],
) -> i32 {
if n > 0 {
let holidays_begin = match holidays.binary_search(&date) {
Ok(_) => {
unreachable!("`add_business_days` would have errored if `date` was a holiday.")
},
Err(x) => x,
};

date += (n / n_business_days_in_week_mask) * 7;
n %= n_business_days_in_week_mask;

let holidays_temp = match holidays[holidays_begin..].binary_search(&date) {
Ok(x) => x + 1,
Err(x) => x,
} + holidays_begin;

n += (holidays_temp - holidays_begin) as i32;
let holidays_begin = holidays_temp;

while n > 0 {
date += 1;
day_of_week = increment_day_of_week(day_of_week);
// SAFETY: week_mask is length 7, day_of_week is between 0 and 6
if unsafe {
(*week_mask.get_unchecked(day_of_week))
&& (!holidays[holidays_begin..].contains(&date))
} {
n -= 1;
}
}
date
} else {
let holidays_end = match holidays.binary_search(&date) {
Ok(x) => x + 1,
Err(x) => x,
};

date += (n / n_business_days_in_week_mask) * 7;
n %= n_business_days_in_week_mask;

let holidays_temp = match holidays[..holidays_end].binary_search(&date) {
Ok(x) => x,
Err(x) => x,
};

n -= (holidays_end - holidays_temp) as i32;
let holidays_end = holidays_temp;

while n < 0 {
date -= 1;
day_of_week = decrement_day_of_week(day_of_week);
// SAFETY: week_mask is length 7, day_of_week is between 0 and 6
if unsafe {
(*week_mask.get_unchecked(day_of_week))
&& (!holidays[..holidays_end].contains(&date))
} {
n += 1;
}
}
date
}
}

fn roll_start_date(
mut date: i32,
roll: Roll,
week_mask: &[bool; 7],
holidays: &[i32],
) -> PolarsResult<(i32, usize)> {
let mut day_of_week = get_day_of_week(date);
match roll {
Roll::Raise => {
// SAFETY: week_mask is length 7, day_of_week is between 0 and 6
if holidays.contains(&date) | unsafe { !*week_mask.get_unchecked(day_of_week) } {
let date = DateTime::from_timestamp(date as i64 * SECONDS_IN_DAY, 0)
.unwrap()
.format("%Y-%m-%d");
polars_bail!(ComputeError:
"date {} is not a business date; use `roll` to roll forwards (or backwards) to the next (or previous) valid date.", date
)
};
},
Roll::Forward => {
// SAFETY: week_mask is length 7, day_of_week is between 0 and 6
while holidays.contains(&date) | unsafe { !*week_mask.get_unchecked(day_of_week) } {
date += 1;
day_of_week = increment_day_of_week(day_of_week);
}
},
Roll::Backward => {
// SAFETY: week_mask is length 7, day_of_week is between 0 and 6
while holidays.contains(&date) | unsafe { !*week_mask.get_unchecked(day_of_week) } {
date -= 1;
day_of_week = decrement_day_of_week(day_of_week);
}
},
}
Ok((date, day_of_week))
}

/// Sort and deduplicate holidays and remove holidays that are not business days.
fn normalise_holidays(holidays: &[i32], week_mask: &[bool; 7]) -> Vec<i32> {
let mut holidays: Vec<i32> = holidays.to_vec();
holidays.sort_unstable();
let mut previous_holiday: Option<i32> = None;
holidays.retain(|&x| {
// SAFETY: week_mask is length 7, start_weekday is between 0 and 6
if (Some(x) == previous_holiday) || !unsafe { *week_mask.get_unchecked(weekday(x)) } {
// SAFETY: week_mask is length 7, get_day_of_week result is between 0 and 6
if (Some(x) == previous_holiday) || !unsafe { *week_mask.get_unchecked(get_day_of_week(x)) }
{
return false;
}
previous_holiday = Some(x);
Expand All @@ -130,17 +393,25 @@ fn normalise_holidays(holidays: &[i32], week_mask: &[bool; 7]) -> Vec<i32> {
holidays
}

fn weekday(x: i32) -> usize {
fn get_day_of_week(x: i32) -> usize {
// the first modulo might return a negative number, so we add 7 and take
// the modulo again so we're sure we have something between 0 (Monday)
// and 6 (Sunday)
(((x - 4) % 7 + 7) % 7) as usize
}

fn increment_weekday(x: usize) -> usize {
fn increment_day_of_week(x: usize) -> usize {
if x == 6 {
0
} else {
x + 1
}
}

fn decrement_day_of_week(x: usize) -> usize {
if x == 0 {
6
} else {
x - 1
}
}
21 changes: 21 additions & 0 deletions crates/polars-plan/src/dsl/dt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,27 @@ use super::*;
pub struct DateLikeNameSpace(pub(crate) Expr);

impl DateLikeNameSpace {
/// Add a given number of business days.
#[cfg(feature = "business")]
pub fn add_business_days(
self,
n: Expr,
week_mask: [bool; 7],
holidays: Vec<i32>,
roll: Roll,
) -> Expr {
self.0.map_many_private(
FunctionExpr::Business(BusinessFunction::AddBusinessDay {
week_mask,
holidays,
roll,
}),
&[n],
false,
false,
)
}

/// Convert from Date/Time/Datetime into String with the given format.
/// See [chrono strftime/strptime](https://docs.rs/chrono/0.4.19/chrono/format/strftime/index.html).
pub fn to_string(self, format: &str) -> Expr {
Expand Down
Loading

0 comments on commit d76e501

Please sign in to comment.