diff --git a/crates/bevy_math/src/curve/cores.rs b/crates/bevy_math/src/curve/cores.rs new file mode 100644 index 0000000000000..101ce877c8929 --- /dev/null +++ b/crates/bevy_math/src/curve/cores.rs @@ -0,0 +1,550 @@ +//! Core data structures to be used internally in Curve implementations, encapsulating storage +//! and access patterns for reuse. + +use super::interval::Interval; +use core::fmt::Debug; +use thiserror::Error; + +#[cfg(feature = "bevy_reflect")] +use bevy_reflect::Reflect; + +/// This type expresses the relationship of a value to a fixed collection of values. It is a kind +/// of summary used intermediately by sampling operations. +#[derive(Debug, Copy, Clone, PartialEq)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "bevy_reflect", derive(Reflect))] +pub enum InterpolationDatum { + /// This value lies exactly on a value in the family. + Exact(T), + + /// This value is off the left tail of the family; the inner value is the family's leftmost. + LeftTail(T), + + /// This value is off the right tail of the family; the inner value is the family's rightmost. + RightTail(T), + + /// This value lies on the interior, in between two points, with a third parameter expressing + /// the interpolation factor between the two. + Between(T, T, f32), +} + +impl InterpolationDatum { + /// Map all values using a given function `f`, leaving the interpolation parameters in any + /// [`Between`] variants unchanged. + /// + /// [`Between`]: `InterpolationDatum::Between` + #[must_use] + pub fn map(self, f: impl Fn(T) -> S) -> InterpolationDatum { + match self { + InterpolationDatum::Exact(v) => InterpolationDatum::Exact(f(v)), + InterpolationDatum::LeftTail(v) => InterpolationDatum::LeftTail(f(v)), + InterpolationDatum::RightTail(v) => InterpolationDatum::RightTail(f(v)), + InterpolationDatum::Between(u, v, s) => InterpolationDatum::Between(f(u), f(v), s), + } + } +} + +/// The data core of a curve derived from evenly-spaced samples. The intention is to use this +/// in addition to explicit or inferred interpolation information in user-space in order to +/// implement curves using [`domain`] and [`sample_with`] +/// +/// The internals are made transparent to give curve authors freedom, but [the provided constructor] +/// enforces the required invariants. +/// +/// [the provided constructor]: EvenCore::new +/// [`domain`]: EvenCore::domain +/// [`sample_with`]: EvenCore::sample_with +/// +/// # Example +/// ```rust +/// # use bevy_math::curve::*; +/// # use bevy_math::curve::cores::*; +/// enum InterpolationMode { +/// Linear, +/// Step, +/// } +/// +/// trait LinearInterpolate { +/// fn lerp(&self, other: &Self, t: f32) -> Self; +/// } +/// +/// fn step(first: &T, second: &T, t: f32) -> T { +/// if t >= 1.0 { +/// second.clone() +/// } else { +/// first.clone() +/// } +/// } +/// +/// struct MyCurve { +/// core: EvenCore, +/// interpolation_mode: InterpolationMode, +/// } +/// +/// impl Curve for MyCurve +/// where +/// T: LinearInterpolate + Clone, +/// { +/// fn domain(&self) -> Interval { +/// self.core.domain() +/// } +/// +/// fn sample(&self, t: f32) -> T { +/// match self.interpolation_mode { +/// InterpolationMode::Linear => self.core.sample_with(t, ::lerp), +/// InterpolationMode::Step => self.core.sample_with(t, step), +/// } +/// } +/// } +/// ``` +#[derive(Debug, Clone, PartialEq)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "bevy_reflect", derive(Reflect))] +pub struct EvenCore { + /// The domain over which the samples are taken, which corresponds to the domain of the curve + /// formed by interpolating them. + /// + /// # Invariants + /// This must always be a bounded interval; i.e. its endpoints must be finite. + pub domain: Interval, + + /// The samples that are interpolated to extract values. + /// + /// # Invariants + /// This must always have a length of at least 2. + pub samples: Vec, +} + +/// An error indicating that a [`EvenCore`] could not be constructed. +#[derive(Debug, Error, PartialEq, Eq)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "bevy_reflect", derive(Reflect))] +pub enum EvenCoreError { + /// Not enough samples were provided. + #[error("Need at least two samples to create a EvenCore, but {samples} were provided")] + NotEnoughSamples { + /// The number of samples that were provided. + samples: usize, + }, + + /// Unbounded domains are not compatible with `EvenCore`. + #[error("Cannot create a EvenCore over a domain with an infinite endpoint")] + InfiniteDomain, +} + +impl EvenCore { + /// Create a new [`EvenCore`] from the specified `domain` and `samples`. An error is returned + /// if there are not at least 2 samples or if the given domain is unbounded. + #[inline] + pub fn new(domain: Interval, samples: impl Into>) -> Result { + let samples: Vec = samples.into(); + if samples.len() < 2 { + return Err(EvenCoreError::NotEnoughSamples { + samples: samples.len(), + }); + } + if !domain.is_finite() { + return Err(EvenCoreError::InfiniteDomain); + } + + Ok(EvenCore { domain, samples }) + } + + /// The domain of the curve derived from this core. + #[inline] + pub fn domain(&self) -> Interval { + self.domain + } + + /// Obtain a value from the held samples using the given `interpolation` to interpolate + /// between adjacent samples. + /// + /// The interpolation takes two values by reference together with a scalar parameter and + /// produces an owned value. The expectation is that `interpolation(&x, &y, 0.0)` and + /// `interpolation(&x, &y, 1.0)` are equivalent to `x` and `y` respectively. + #[inline] + pub fn sample_with(&self, t: f32, interpolation: I) -> T + where + T: Clone, + I: Fn(&T, &T, f32) -> T, + { + match even_interp(self.domain, self.samples.len(), t) { + InterpolationDatum::Exact(idx) + | InterpolationDatum::LeftTail(idx) + | InterpolationDatum::RightTail(idx) => self.samples[idx].clone(), + InterpolationDatum::Between(lower_idx, upper_idx, s) => { + interpolation(&self.samples[lower_idx], &self.samples[upper_idx], s) + } + } + } + + /// Given a time `t`, obtain a [`InterpolationDatum`] which governs how interpolation might recover + /// a sample at time `t`. For example, when a [`Between`] value is returned, its contents can + /// be used to interpolate between the two contained values with the given parameter. The other + /// variants give additional context about where the value is relative to the family of samples. + /// + /// [`Between`]: `InterpolationDatum::Between` + pub fn sample_interp(&self, t: f32) -> InterpolationDatum<&T> { + even_interp(self.domain, self.samples.len(), t).map(|idx| &self.samples[idx]) + } + + /// Like [`sample_interp`], but the returned values include the sample times. This can be + /// useful when sampling is not scale-invariant. + /// + /// [`sample_interp`]: EvenCore::sample_interp + pub fn sample_interp_timed(&self, t: f32) -> InterpolationDatum<(f32, &T)> { + let segment_len = self.domain.length() / (self.samples.len() - 1) as f32; + even_interp(self.domain, self.samples.len(), t).map(|idx| { + ( + self.domain.start() + segment_len * idx as f32, + &self.samples[idx], + ) + }) + } +} + +/// Given a domain and a number of samples taken over that interval, return a [`InterpolationDatum`] +/// that governs how samples are extracted relative to the stored data. +/// +/// `domain` must be a bounded interval (i.e. `domain.is_finite() == true`). +/// +/// `samples` must be at least 2. +/// +/// This function will never panic, but it may return invalid indices if its assumptions are violated. +pub fn even_interp(domain: Interval, samples: usize, t: f32) -> InterpolationDatum { + let subdivs = samples - 1; + let step = domain.length() / subdivs as f32; + let t_shifted = t - domain.start(); + let steps_taken = t_shifted / step; + + if steps_taken <= 0.0 { + // To the left side of all the samples. + InterpolationDatum::LeftTail(0) + } else if steps_taken >= subdivs as f32 { + // To the right side of all the samples + InterpolationDatum::RightTail(samples - 1) + } else { + let lower_index = steps_taken.floor() as usize; + // This upper index is always valid because `steps_taken` is a finite value + // strictly less than `samples - 1`, so its floor is at most `samples - 2` + let upper_index = lower_index + 1; + let s = steps_taken.fract(); + InterpolationDatum::Between(lower_index, upper_index, s) + } +} + +/// The data core of a curve defined by unevenly-spaced samples or keyframes. The intention is to +/// use this in concert with implicitly or explicitly-defined interpolation in user-space in +/// order to implement the curve interface using [`domain`] and [`sample_with`]. +/// +/// [`domain`]: UnevenCore::domain +/// [`sample_with`]: UnevenCore::sample_with +#[derive(Debug, Clone)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "bevy_reflect", derive(Reflect))] +pub struct UnevenCore { + /// The times for the samples of this curve. + /// + /// # Invariants + /// This must always have a length of at least 2, be sorted, and have no + /// duplicated or non-finite times. + pub times: Vec, + + /// The samples corresponding to the times for this curve. + /// + /// # Invariants + /// This must always have the same length as `times`. + pub samples: Vec, +} + +/// An error indicating that an [`UnevenCore`] could not be constructed. +#[derive(Debug, Error)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "bevy_reflect", derive(Reflect))] +pub enum UnevenCoreError { + /// Not enough samples were provided. + #[error("Need at least two samples to create an UnevenCore, but {samples} were provided")] + NotEnoughSamples { + /// The number of samples that were provided. + samples: usize, + }, +} + +impl UnevenCore { + /// Create a new [`UnevenCore`]. The given samples are filtered to finite times and + /// sorted internally; if there are not at least 2 valid timed samples, an error will be + /// returned. + /// + /// The interpolation takes two values by reference together with a scalar parameter and + /// produces an owned value. The expectation is that `interpolation(&x, &y, 0.0)` and + /// `interpolation(&x, &y, 1.0)` are equivalent to `x` and `y` respectively. + pub fn new(timed_samples: impl Into>) -> Result { + let timed_samples: Vec<(f32, T)> = timed_samples.into(); + + // Filter out non-finite sample times first so they don't interfere with sorting/deduplication. + let mut timed_samples: Vec<(f32, T)> = timed_samples + .into_iter() + .filter(|(t, _)| t.is_finite()) + .collect(); + timed_samples + .sort_by(|(t0, _), (t1, _)| t0.partial_cmp(t1).unwrap_or(std::cmp::Ordering::Equal)); + timed_samples.dedup_by_key(|(t, _)| *t); + + let (times, samples): (Vec, Vec) = timed_samples.into_iter().unzip(); + + if times.len() < 2 { + return Err(UnevenCoreError::NotEnoughSamples { + samples: times.len(), + }); + } + Ok(UnevenCore { times, samples }) + } + + /// The domain of the curve derived from this core. + /// + /// # Panics + /// This method may panic if the type's invariants aren't satisfied. + #[inline] + pub fn domain(&self) -> Interval { + let start = self.times.first().unwrap(); + let end = self.times.last().unwrap(); + Interval::new(*start, *end).unwrap() + } + + /// Obtain a value from the held samples using the given `interpolation` to interpolate + /// between adjacent samples. + /// + /// The interpolation takes two values by reference together with a scalar parameter and + /// produces an owned value. The expectation is that `interpolation(&x, &y, 0.0)` and + /// `interpolation(&x, &y, 1.0)` are equivalent to `x` and `y` respectively. + #[inline] + pub fn sample_with(&self, t: f32, interpolation: I) -> T + where + T: Clone, + I: Fn(&T, &T, f32) -> T, + { + match uneven_interp(&self.times, t) { + InterpolationDatum::Exact(idx) + | InterpolationDatum::LeftTail(idx) + | InterpolationDatum::RightTail(idx) => self.samples[idx].clone(), + InterpolationDatum::Between(lower_idx, upper_idx, s) => { + interpolation(&self.samples[lower_idx], &self.samples[upper_idx], s) + } + } + } + + /// Given a time `t`, obtain a [`InterpolationDatum`] which governs how interpolation might recover + /// a sample at time `t`. For example, when a [`Between`] value is returned, its contents can + /// be used to interpolate between the two contained values with the given parameter. The other + /// variants give additional context about where the value is relative to the family of samples. + /// + /// [`Between`]: `InterpolationDatum::Between` + pub fn sample_interp(&self, t: f32) -> InterpolationDatum<&T> { + uneven_interp(&self.times, t).map(|idx| &self.samples[idx]) + } + + /// Like [`sample_interp`], but the returned values include the sample times. This can be + /// useful when sampling is not scale-invariant. + /// + /// [`sample_interp`]: UnevenCore::sample_interp + pub fn sample_interp_timed(&self, t: f32) -> InterpolationDatum<(f32, &T)> { + uneven_interp(&self.times, t).map(|idx| (self.times[idx], &self.samples[idx])) + } + + /// This core, but with the sample times moved by the map `f`. + /// In principle, when `f` is monotone, this is equivalent to [`Curve::reparametrize`], + /// but the function inputs to each are inverses of one another. + /// + /// The samples are re-sorted by time after mapping and deduplicated by output time, so + /// the function `f` should generally be injective over the sample times of the curve. + /// + /// [`Curve::reparametrize`]: crate::curve::Curve::reparametrize + pub fn map_sample_times(mut self, f: impl Fn(f32) -> f32) -> UnevenCore { + let mut timed_samples: Vec<(f32, T)> = + self.times.into_iter().map(f).zip(self.samples).collect(); + timed_samples.dedup_by(|(t1, _), (t2, _)| (*t1).eq(t2)); + timed_samples.sort_by(|(t1, _), (t2, _)| t1.partial_cmp(t2).unwrap()); + self.times = timed_samples.iter().map(|(t, _)| t).copied().collect(); + self.samples = timed_samples.into_iter().map(|(_, x)| x).collect(); + self + } +} + +/// The data core of a curve using uneven samples (i.e. keyframes), where each sample time +/// yields some fixed number of values — the [sampling width]. This may serve as storage for +/// curves that yield vectors or iterators, and in some cases, it may be useful for cache locality +/// if the sample type can effectively be encoded as a fixed-length slice of values. +/// +/// [sampling width]: ChunkedUnevenCore::width +#[derive(Debug, Clone)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "bevy_reflect", derive(Reflect))] +pub struct ChunkedUnevenCore { + /// The times, one for each sample. + /// + /// # Invariants + /// This must always have a length of at least 2, be sorted, and have no duplicated or + /// non-finite times. + pub times: Vec, + + /// The values that are used in sampling. Each width-worth of these correspond to a single sample. + /// + /// # Invariants + /// The length of this vector must always be some fixed integer multiple of that of `times`. + pub values: Vec, +} + +/// An error that indicates that a [`ChunkedUnevenCore`] could not be formed. +#[derive(Debug, Error)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "bevy_reflect", derive(Reflect))] +pub enum ChunkedUnevenSampleCoreError { + /// The width of a `ChunkedUnevenCore` cannot be zero. + #[error("Chunk width must be at least 1")] + ZeroWidth, + + /// At least two sample times are necessary to interpolate in `ChunkedUnevenCore`. + #[error("Need at least two samples to create an UnevenCore, but {samples} were provided")] + NotEnoughSamples { + /// The number of samples that were provided. + samples: usize, + }, + + /// The length of the value buffer is supposed to be the `width` times the number of samples. + #[error("Expected {expected} total values based on width, but {actual} were provided")] + MismatchedLengths { + /// The expected length of the value buffer. + expected: usize, + /// The actual length of the value buffer. + actual: usize, + }, +} + +impl ChunkedUnevenCore { + /// Create a new [`ChunkedUnevenCore`]. The given `times` are sorted, filtered to finite times, + /// and deduplicated. See the [type-level documentation] for more information about this type. + /// + /// Produces an error in any of the following circumstances: + /// - `width` is zero. + /// - `times` has less than `2` valid entries. + /// - `values` has the incorrect length relative to `times`. + /// + /// [type-level documentation]: ChunkedUnevenCore + pub fn new( + times: impl Into>, + values: impl Into>, + width: usize, + ) -> Result { + let times: Vec = times.into(); + let values: Vec = values.into(); + + if width == 0 { + return Err(ChunkedUnevenSampleCoreError::ZeroWidth); + } + + let times = filter_sort_dedup_times(times); + + if times.len() < 2 { + return Err(ChunkedUnevenSampleCoreError::NotEnoughSamples { + samples: times.len(), + }); + } + + if values.len() != times.len() * width { + return Err(ChunkedUnevenSampleCoreError::MismatchedLengths { + expected: times.len() * width, + actual: values.len(), + }); + } + + Ok(Self { times, values }) + } + + /// The domain of the curve derived from this core. + /// + /// # Panics + /// This may panic if this type's invariants aren't met. + #[inline] + pub fn domain(&self) -> Interval { + let start = self.times.first().unwrap(); + let end = self.times.last().unwrap(); + Interval::new(*start, *end).unwrap() + } + + /// The sample width: the number of values that are contained in each sample. + #[inline] + pub fn width(&self) -> usize { + self.values.len() / self.times.len() + } + + /// Given a time `t`, obtain a [`InterpolationDatum`] which governs how interpolation might recover + /// a sample at time `t`. For example, when a [`Between`] value is returned, its contents can + /// be used to interpolate between the two contained values with the given parameter. The other + /// variants give additional context about where the value is relative to the family of samples. + /// + /// [`Between`]: `InterpolationDatum::Between` + #[inline] + pub fn sample_interp(&self, t: f32) -> InterpolationDatum<&[T]> { + uneven_interp(&self.times, t).map(|idx| self.time_index_to_slice(idx)) + } + + /// Like [`sample_interp`], but the returned values include the sample times. This can be + /// useful when sampling is not scale-invariant. + /// + /// [`sample_interp`]: ChunkedUnevenCore::sample_interp + pub fn sample_interp_timed(&self, t: f32) -> InterpolationDatum<(f32, &[T])> { + uneven_interp(&self.times, t).map(|idx| (self.times[idx], self.time_index_to_slice(idx))) + } + + /// Given an index in [times], returns the slice of [values] that correspond to the sample at + /// that time. + /// + /// [times]: ChunkedUnevenCore::times + /// [values]: ChunkedUnevenCore::values + #[inline] + fn time_index_to_slice(&self, idx: usize) -> &[T] { + let width = self.width(); + let lower_idx = width * idx; + let upper_idx = lower_idx + width; + &self.values[lower_idx..upper_idx] + } +} + +/// Sort the given times, deduplicate them, and filter them to only finite times. +fn filter_sort_dedup_times(times: Vec) -> Vec { + // Filter before sorting/deduplication so that NAN doesn't interfere with them. + let mut times: Vec = times.into_iter().filter(|t| t.is_finite()).collect(); + times.sort_by(|t0, t1| t0.partial_cmp(t1).unwrap()); + times.dedup(); + times +} + +/// Given a list of `times` and a target value, get the interpolation relationship for the +/// target value in terms of the indices of the starting list. In a sense, this encapsulates the +/// heart of uneven/keyframe sampling. +/// +/// `times` is assumed to be sorted, deduplicated, and consisting only of finite values. It is also +/// assumed to contain at least two values. +/// +/// # Panics +/// This function will panic if `times` contains NAN. +pub fn uneven_interp(times: &[f32], t: f32) -> InterpolationDatum { + match times.binary_search_by(|pt| pt.partial_cmp(&t).unwrap()) { + Ok(index) => InterpolationDatum::Exact(index), + Err(index) => { + if index == 0 { + // This is before the first keyframe. + InterpolationDatum::LeftTail(0) + } else if index >= times.len() { + // This is after the last keyframe. + InterpolationDatum::RightTail(times.len() - 1) + } else { + // This is actually in the middle somewhere. + let t_lower = times[index - 1]; + let t_upper = times[index]; + let s = (t - t_lower) / (t_upper - t_lower); + InterpolationDatum::Between(index - 1, index, s) + } + } + } +} diff --git a/crates/bevy_math/src/curve/interval.rs b/crates/bevy_math/src/curve/interval.rs new file mode 100644 index 0000000000000..0edbd42aad94a --- /dev/null +++ b/crates/bevy_math/src/curve/interval.rs @@ -0,0 +1,321 @@ +//! The [`Interval`] type for nonempty intervals used by the [`Curve`](super::Curve) trait. + +use std::{ + cmp::{max_by, min_by}, + ops::RangeInclusive, +}; +use thiserror::Error; + +#[cfg(feature = "bevy_reflect")] +use bevy_reflect::Reflect; +#[cfg(all(feature = "serialize", feature = "bevy_reflect"))] +use bevy_reflect::{ReflectDeserialize, ReflectSerialize}; + +/// A nonempty closed interval, possibly infinite in either direction. +#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "bevy_reflect", derive(Reflect), reflect(Debug, PartialEq))] +#[cfg_attr( + all(feature = "serialize", feature = "bevy_reflect"), + reflect(Serialize, Deserialize) +)] +pub struct Interval { + start: f32, + end: f32, +} + +/// An error that indicates that an operation would have returned an invalid [`Interval`]. +#[derive(Debug, Error)] +#[error("The resulting interval would be invalid (empty or with a NaN endpoint)")] +pub struct InvalidIntervalError; + +/// An error indicating that an infinite interval was used where it was inappropriate. +#[derive(Debug, Error)] +#[error("This operation does not make sense in the context of an infinite interval")] +pub struct InfiniteIntervalError; + +/// An error indicating that spaced points on an interval could not be formed. +#[derive(Debug, Error)] +#[error("Could not sample evenly-spaced points with these inputs")] +pub enum SpacedPointsError { + /// This operation failed because fewer than two points were requested. + #[error("Parameter `points` must be at least 2")] + NotEnoughPoints, + + /// This operation failed because the underlying interval is unbounded. + #[error("Cannot sample evenly-spaced points on an infinite interval")] + InfiniteInterval(InfiniteIntervalError), +} + +impl Interval { + /// Create a new [`Interval`] with the specified `start` and `end`. The interval can be infinite + /// but cannot be empty and neither endpoint can be NaN; invalid parameters will result in an error. + pub fn new(start: f32, end: f32) -> Result { + if start >= end || start.is_nan() || end.is_nan() { + Err(InvalidIntervalError) + } else { + Ok(Self { start, end }) + } + } + + /// Get the start of this interval. + #[inline] + pub fn start(self) -> f32 { + self.start + } + + /// Get the end of this interval. + #[inline] + pub fn end(self) -> f32 { + self.end + } + + /// Create an [`Interval`] by intersecting this interval with another. Returns an error if the + /// intersection would be empty (hence an invalid interval). + pub fn intersect(self, other: Interval) -> Result { + let lower = max_by(self.start, other.start, |x, y| x.partial_cmp(y).unwrap()); + let upper = min_by(self.end, other.end, |x, y| x.partial_cmp(y).unwrap()); + Self::new(lower, upper) + } + + /// Get the length of this interval. Note that the result may be infinite (`f32::INFINITY`). + #[inline] + pub fn length(self) -> f32 { + self.end - self.start + } + + /// Returns `true` if both endpoints of this interval are finite. + #[inline] + pub fn is_finite(self) -> bool { + self.length().is_finite() + } + + /// Returns `true` if this interval has a finite left endpoint. + #[inline] + pub fn is_left_finite(self) -> bool { + self.start.is_finite() + } + + /// Returns `true` if this interval has a finite right endpoint. + #[inline] + pub fn is_right_finite(self) -> bool { + self.end.is_finite() + } + + /// Returns `true` if `item` is contained in this interval. + #[inline] + pub fn contains(self, item: f32) -> bool { + (self.start..=self.end).contains(&item) + } + + /// Clamp the given `value` to lie within this interval. + #[inline] + pub fn clamp(self, value: f32) -> f32 { + value.clamp(self.start, self.end) + } + + /// Get the linear map which maps this curve onto the `other` one. Returns an error if either + /// interval is infinite. + pub fn linear_map_to(self, other: Self) -> Result f32, InfiniteIntervalError> { + if !self.is_finite() || !other.is_finite() { + return Err(InfiniteIntervalError); + } + let scale = other.length() / self.length(); + Ok(move |x| (x - self.start) * scale + other.start) + } + + /// Get an iterator over equally-spaced points from this interval in increasing order. + /// Returns an error if `points` is less than 2 or if the interval is unbounded. + pub fn spaced_points( + self, + points: usize, + ) -> Result, SpacedPointsError> { + if points < 2 { + return Err(SpacedPointsError::NotEnoughPoints); + } + if !self.is_finite() { + return Err(SpacedPointsError::InfiniteInterval(InfiniteIntervalError)); + } + let step = self.length() / (points - 1) as f32; + Ok((0..points).map(move |x| self.start + x as f32 * step)) + } +} + +impl TryFrom> for Interval { + type Error = InvalidIntervalError; + fn try_from(range: RangeInclusive) -> Result { + Interval::new(*range.start(), *range.end()) + } +} + +/// Create an [`Interval`] with a given `start` and `end`. Alias of [`Interval::new`]. +pub fn interval(start: f32, end: f32) -> Result { + Interval::new(start, end) +} + +/// The [`Interval`] from negative infinity to infinity. +pub fn everywhere() -> Interval { + Interval::new(f32::NEG_INFINITY, f32::INFINITY).unwrap() +} + +#[cfg(test)] +mod tests { + use super::*; + use approx::{assert_abs_diff_eq, AbsDiffEq}; + + #[test] + fn make_intervals() { + let ivl = Interval::new(2.0, -1.0); + assert!(ivl.is_err()); + + let ivl = Interval::new(-0.0, 0.0); + assert!(ivl.is_err()); + + let ivl = Interval::new(f32::NEG_INFINITY, 15.5); + assert!(ivl.is_ok()); + + let ivl = Interval::new(-2.0, f32::INFINITY); + assert!(ivl.is_ok()); + + let ivl = Interval::new(f32::NEG_INFINITY, f32::INFINITY); + assert!(ivl.is_ok()); + + let ivl = Interval::new(f32::INFINITY, f32::NEG_INFINITY); + assert!(ivl.is_err()); + + let ivl = Interval::new(-1.0, f32::NAN); + assert!(ivl.is_err()); + + let ivl = Interval::new(f32::NAN, -42.0); + assert!(ivl.is_err()); + + let ivl = Interval::new(f32::NAN, f32::NAN); + assert!(ivl.is_err()); + + let ivl = Interval::new(0.0, 1.0); + assert!(ivl.is_ok()); + } + + #[test] + fn lengths() { + let ivl = interval(-5.0, 10.0).unwrap(); + assert!((ivl.length() - 15.0).abs() <= f32::EPSILON); + + let ivl = interval(5.0, 100.0).unwrap(); + assert!((ivl.length() - 95.0).abs() <= f32::EPSILON); + + let ivl = interval(0.0, f32::INFINITY).unwrap(); + assert_eq!(ivl.length(), f32::INFINITY); + + let ivl = interval(f32::NEG_INFINITY, 0.0).unwrap(); + assert_eq!(ivl.length(), f32::INFINITY); + + let ivl = everywhere(); + assert_eq!(ivl.length(), f32::INFINITY); + } + + #[test] + fn intersections() { + let ivl1 = interval(-1.0, 1.0).unwrap(); + let ivl2 = interval(0.0, 2.0).unwrap(); + let ivl3 = interval(-3.0, 0.0).unwrap(); + let ivl4 = interval(0.0, f32::INFINITY).unwrap(); + let ivl5 = interval(f32::NEG_INFINITY, 0.0).unwrap(); + let ivl6 = everywhere(); + + assert!(ivl1 + .intersect(ivl2) + .is_ok_and(|ivl| ivl == interval(0.0, 1.0).unwrap())); + assert!(ivl1 + .intersect(ivl3) + .is_ok_and(|ivl| ivl == interval(-1.0, 0.0).unwrap())); + assert!(ivl2.intersect(ivl3).is_err()); + assert!(ivl1 + .intersect(ivl4) + .is_ok_and(|ivl| ivl == interval(0.0, 1.0).unwrap())); + assert!(ivl1 + .intersect(ivl5) + .is_ok_and(|ivl| ivl == interval(-1.0, 0.0).unwrap())); + assert!(ivl4.intersect(ivl5).is_err()); + assert_eq!(ivl1.intersect(ivl6).unwrap(), ivl1); + assert_eq!(ivl4.intersect(ivl6).unwrap(), ivl4); + assert_eq!(ivl5.intersect(ivl6).unwrap(), ivl5); + } + + #[test] + fn containment() { + let ivl = interval(0.0, 1.0).unwrap(); + assert!(ivl.contains(0.0)); + assert!(ivl.contains(1.0)); + assert!(ivl.contains(0.5)); + assert!(!ivl.contains(-0.1)); + assert!(!ivl.contains(1.1)); + assert!(!ivl.contains(f32::NAN)); + + let ivl = interval(3.0, f32::INFINITY).unwrap(); + assert!(ivl.contains(3.0)); + assert!(ivl.contains(2.0e5)); + assert!(ivl.contains(3.5e6)); + assert!(!ivl.contains(2.5)); + assert!(!ivl.contains(-1e5)); + assert!(!ivl.contains(f32::NAN)); + } + + #[test] + fn finiteness() { + assert!(!everywhere().is_finite()); + assert!(interval(0.0, 3.5e5).unwrap().is_finite()); + assert!(!interval(-2.0, f32::INFINITY).unwrap().is_finite()); + assert!(!interval(f32::NEG_INFINITY, 5.0).unwrap().is_finite()); + } + + #[test] + fn linear_maps() { + let ivl1 = interval(-3.0, 5.0).unwrap(); + let ivl2 = interval(0.0, 1.0).unwrap(); + let map = ivl1.linear_map_to(ivl2); + assert!(map.is_ok_and(|f| f(-3.0).abs_diff_eq(&0.0, f32::EPSILON) + && f(5.0).abs_diff_eq(&1.0, f32::EPSILON) + && f(1.0).abs_diff_eq(&0.5, f32::EPSILON))); + + let ivl1 = interval(0.0, 1.0).unwrap(); + let ivl2 = everywhere(); + assert!(ivl1.linear_map_to(ivl2).is_err()); + + let ivl1 = interval(f32::NEG_INFINITY, -4.0).unwrap(); + let ivl2 = interval(0.0, 1.0).unwrap(); + assert!(ivl1.linear_map_to(ivl2).is_err()); + } + + #[test] + fn spaced_points() { + let ivl = interval(0.0, 50.0).unwrap(); + let points_iter = ivl.spaced_points(1); + assert!(points_iter.is_err()); + let points_iter: Vec = ivl.spaced_points(2).unwrap().collect(); + assert_abs_diff_eq!(points_iter[0], 0.0); + assert_abs_diff_eq!(points_iter[1], 50.0); + let points_iter = ivl.spaced_points(21).unwrap(); + let step = ivl.length() / 20.0; + for (index, point) in points_iter.enumerate() { + let expected = ivl.start() + step * index as f32; + assert_abs_diff_eq!(point, expected); + } + + let ivl = interval(-21.0, 79.0).unwrap(); + let points_iter = ivl.spaced_points(10000).unwrap(); + let step = ivl.length() / 9999.0; + for (index, point) in points_iter.enumerate() { + let expected = ivl.start() + step * index as f32; + assert_abs_diff_eq!(point, expected); + } + + let ivl = interval(-1.0, f32::INFINITY).unwrap(); + let points_iter = ivl.spaced_points(25); + assert!(points_iter.is_err()); + + let ivl = interval(f32::NEG_INFINITY, -25.0).unwrap(); + let points_iter = ivl.spaced_points(9); + assert!(points_iter.is_err()); + } +} diff --git a/crates/bevy_math/src/curve/mod.rs b/crates/bevy_math/src/curve/mod.rs new file mode 100644 index 0000000000000..b4e7c917e0946 --- /dev/null +++ b/crates/bevy_math/src/curve/mod.rs @@ -0,0 +1,1050 @@ +//! The [`Curve`] trait, used to describe curves in a number of different domains. This module also +//! contains the [`Interval`] type, along with a selection of core data structures used to back +//! curves that are interpolated from samples. + +pub mod cores; +pub mod interval; + +pub use interval::{everywhere, interval, Interval}; + +use crate::StableInterpolate; +use cores::{EvenCore, EvenCoreError, UnevenCore, UnevenCoreError}; +use interval::{InfiniteIntervalError, InvalidIntervalError}; +use std::{marker::PhantomData, ops::Deref}; +use thiserror::Error; + +#[cfg(feature = "bevy_reflect")] +use bevy_reflect::Reflect; + +/// A trait for a type that can represent values of type `T` parametrized over a fixed interval. +/// Typical examples of this are actual geometric curves where `T: VectorSpace`, but other kinds +/// of interpolable data can be represented instead (or in addition). +pub trait Curve { + /// The interval over which this curve is parametrized. + fn domain(&self) -> Interval; + + /// Sample a point on this curve at the parameter value `t`, extracting the associated value. + fn sample(&self, t: f32) -> T; + + /// Sample a point on this curve at the parameter value `t`, returning `None` if the point is + /// outside of the curve's domain. + fn sample_checked(&self, t: f32) -> Option { + match self.domain().contains(t) { + true => Some(self.sample(t)), + false => None, + } + } + + /// Sample a point on this curve at the parameter value `t`, clamping `t` to lie inside the + /// domain of the curve. + fn sample_clamped(&self, t: f32) -> T { + let t = self.domain().clamp(t); + self.sample(t) + } + + /// Resample this [`Curve`] to produce a new one that is defined by interpolation over equally + /// spaced values, using the provided `interpolation` to interpolate between adjacent samples. + /// A total of `samples` samples are used, although at least two samples are required to produce + /// well-formed output. If fewer than two samples are provided, or if this curve has an unbounded + /// domain, then a [`ResamplingError`] is returned. + /// + /// The interpolation takes two values by reference together with a scalar parameter and + /// produces an owned value. The expectation is that `interpolation(&x, &y, 0.0)` and + /// `interpolation(&x, &y, 1.0)` are equivalent to `x` and `y` respectively. + /// + /// # Example + /// ``` + /// # use bevy_math::*; + /// # use bevy_math::curve::*; + /// let quarter_rotation = function_curve(interval(0.0, 90.0).unwrap(), |t| Rot2::degrees(t)); + /// // A curve which only stores three data points and uses `nlerp` to interpolate them: + /// let resampled_rotation = quarter_rotation.resample(3, |x, y, t| x.nlerp(*y, t)); + /// ``` + fn resample( + &self, + samples: usize, + interpolation: I, + ) -> Result, ResamplingError> + where + Self: Sized, + I: Fn(&T, &T, f32) -> T, + { + if samples < 2 { + return Err(ResamplingError::NotEnoughSamples(samples)); + } + if !self.domain().is_finite() { + return Err(ResamplingError::InfiniteInterval(InfiniteIntervalError)); + } + + let samples: Vec = self + .domain() + .spaced_points(samples) + .unwrap() + .map(|t| self.sample(t)) + .collect(); + Ok(SampleCurve { + core: EvenCore { + domain: self.domain(), + samples, + }, + interpolation, + }) + } + + /// Resample this [`Curve`] to produce a new one that is defined by interpolation over equally + /// spaced values. A total of `samples` samples are used, although at least two samples are + /// required in order to produce well-formed output. If fewer than two samples are provided, + /// or if this curve has an unbounded domain, then a [`ResamplingError`] is returned. + fn resample_auto(&self, samples: usize) -> Result, ResamplingError> + where + T: StableInterpolate, + { + if samples < 2 { + return Err(ResamplingError::NotEnoughSamples(samples)); + } + if !self.domain().is_finite() { + return Err(ResamplingError::InfiniteInterval(InfiniteIntervalError)); + } + + let samples: Vec = self + .domain() + .spaced_points(samples) + .unwrap() + .map(|t| self.sample(t)) + .collect(); + Ok(SampleAutoCurve { + core: EvenCore { + domain: self.domain(), + samples, + }, + }) + } + + /// Extract an iterator over evenly-spaced samples from this curve. If `samples` is less than 2 + /// or if this curve has unbounded domain, then an error is returned instead. + fn samples(&self, samples: usize) -> Result, ResamplingError> + where + Self: Sized, + { + if samples < 2 { + return Err(ResamplingError::NotEnoughSamples(samples)); + } + if !self.domain().is_finite() { + return Err(ResamplingError::InfiniteInterval(InfiniteIntervalError)); + } + + // Unwrap on `spaced_points` always succeeds because its error conditions are handled + // above. + Ok(self + .domain() + .spaced_points(samples) + .unwrap() + .map(|t| self.sample(t))) + } + + /// Resample this [`Curve`] to produce a new one that is defined by interpolation over samples + /// taken at a given set of times. The given `interpolation` is used to interpolate adjacent + /// samples, and the `sample_times` are expected to contain at least two valid times within the + /// curve's domain interval. + /// + /// Redundant sample times, non-finite sample times, and sample times outside of the domain + /// are simply filtered out. With an insufficient quantity of data, a [`ResamplingError`] is + /// returned. + /// + /// The domain of the produced curve stretches between the first and last sample times of the + /// iterator. + /// + /// The interpolation takes two values by reference together with a scalar parameter and + /// produces an owned value. The expectation is that `interpolation(&x, &y, 0.0)` and + /// `interpolation(&x, &y, 1.0)` are equivalent to `x` and `y` respectively. + fn resample_uneven( + &self, + sample_times: impl IntoIterator, + interpolation: I, + ) -> Result, ResamplingError> + where + Self: Sized, + I: Fn(&T, &T, f32) -> T, + { + let mut times: Vec = sample_times + .into_iter() + .filter(|t| t.is_finite() && self.domain().contains(*t)) + .collect(); + times.dedup_by(|t1, t2| (*t1).eq(t2)); + if times.len() < 2 { + return Err(ResamplingError::NotEnoughSamples(times.len())); + } + times.sort_by(|t1, t2| t1.partial_cmp(t2).unwrap()); + let samples = times.iter().copied().map(|t| self.sample(t)).collect(); + Ok(UnevenSampleCurve { + core: UnevenCore { times, samples }, + interpolation, + }) + } + + /// Resample this [`Curve`] to produce a new one that is defined by interpolation over samples + /// taken at the given set of times. The given `sample_times` are expected to contain at least + /// two valid times within the curve's domain interval. + /// + /// Redundant sample times, non-finite sample times, and sample times outside of the domain + /// are simply filtered out. With an insufficient quantity of data, a [`ResamplingError`] is + /// returned. + /// + /// The domain of the produced [`UnevenSampleAutoCurve`] stretches between the first and last + /// sample times of the iterator. + fn resample_uneven_auto( + &self, + sample_times: impl IntoIterator, + ) -> Result, ResamplingError> + where + Self: Sized, + T: StableInterpolate, + { + let mut times: Vec = sample_times + .into_iter() + .filter(|t| t.is_finite() && self.domain().contains(*t)) + .collect(); + times.dedup_by(|t1, t2| (*t1).eq(t2)); + if times.len() < 2 { + return Err(ResamplingError::NotEnoughSamples(times.len())); + } + times.sort_by(|t1, t2| t1.partial_cmp(t2).unwrap()); + let samples = times.iter().copied().map(|t| self.sample(t)).collect(); + Ok(UnevenSampleAutoCurve { + core: UnevenCore { times, samples }, + }) + } + + /// Create a new curve by mapping the values of this curve via a function `f`; i.e., if the + /// sample at time `t` for this curve is `x`, the value at time `t` on the new curve will be + /// `f(x)`. + fn map(self, f: F) -> MapCurve + where + Self: Sized, + F: Fn(T) -> S, + { + MapCurve { + preimage: self, + f, + _phantom: PhantomData, + } + } + + /// Create a new [`Curve`] whose parameter space is related to the parameter space of this curve + /// by `f`. For each time `t`, the sample from the new curve at time `t` is the sample from + /// this curve at time `f(t)`. The given `domain` will be the domain of the new curve. The + /// function `f` is expected to take `domain` into `self.domain()`. + /// + /// Note that this is the opposite of what one might expect intuitively; for example, if this + /// curve has a parameter interval of `[0, 1]`, then linearly mapping the parameter domain to + /// `[0, 2]` would be performed as follows, dividing by what might be perceived as the scaling + /// factor rather than multiplying: + /// ``` + /// # use bevy_math::curve::*; + /// let my_curve = constant_curve(interval(0.0, 1.0).unwrap(), 1.0); + /// let domain = my_curve.domain(); + /// let scaled_curve = my_curve.reparametrize(interval(0.0, 2.0).unwrap(), |t| t / 2.0); + /// ``` + /// This kind of linear remapping is provided by the convenience method + /// [`Curve::reparametrize_linear`], which requires only the desired domain for the new curve. + /// + /// # Examples + /// ``` + /// // Reverse a curve: + /// # use bevy_math::curve::*; + /// # use bevy_math::vec2; + /// let my_curve = constant_curve(interval(0.0, 1.0).unwrap(), 1.0); + /// let domain = my_curve.domain(); + /// let reversed_curve = my_curve.reparametrize(domain, |t| domain.end() - t); + /// + /// // Take a segment of a curve: + /// # let my_curve = constant_curve(interval(0.0, 1.0).unwrap(), 1.0); + /// let curve_segment = my_curve.reparametrize(interval(0.0, 0.5).unwrap(), |t| 0.5 + t); + /// + /// // Reparametrize by an easing curve: + /// # let my_curve = constant_curve(interval(0.0, 1.0).unwrap(), 1.0); + /// # let easing_curve = constant_curve(interval(0.0, 1.0).unwrap(), vec2(1.0, 1.0)); + /// let domain = my_curve.domain(); + /// let eased_curve = my_curve.reparametrize(domain, |t| easing_curve.sample(t).y); + /// ``` + fn reparametrize(self, domain: Interval, f: F) -> ReparamCurve + where + Self: Sized, + F: Fn(f32) -> f32, + { + ReparamCurve { + domain, + base: self, + f, + _phantom: PhantomData, + } + } + + /// Linearly reparametrize this [`Curve`], producing a new curve whose domain is the given + /// `domain` instead of the current one. This operation is only valid for curves with finite + /// domains; if either this curve's domain or the given `domain` is infinite, an + /// [`InfiniteIntervalError`] is returned. + fn reparametrize_linear( + self, + domain: Interval, + ) -> Result, InfiniteIntervalError> + where + Self: Sized, + { + if !domain.is_finite() { + return Err(InfiniteIntervalError); + } + + Ok(LinearReparamCurve { + base: self, + new_domain: domain, + _phantom: PhantomData, + }) + } + + /// Reparametrize this [`Curve`] by sampling from another curve. + /// + /// TODO: Figure out what the right signature for this is; currently, this is less flexible than + /// just using `C`, because `&C` is a curve anyway, but this version probably footguns less. + fn reparametrize_by_curve(self, other: &C) -> CurveReparamCurve + where + Self: Sized, + C: Curve, + { + CurveReparamCurve { + base: self, + reparam_curve: other, + _phantom: PhantomData, + } + } + + /// Create a new [`Curve`] which is the graph of this one; that is, its output includes the + /// parameter itself in the samples. For example, if this curve outputs `x` at time `t`, then + /// the produced curve will produce `(t, x)` at time `t`. + fn graph(self) -> GraphCurve + where + Self: Sized, + { + GraphCurve { + base: self, + _phantom: PhantomData, + } + } + + /// Create a new [`Curve`] by zipping this curve together with another. The sample at time `t` + /// in the new curve is `(x, y)`, where `x` is the sample of `self` at time `t` and `y` is the + /// sample of `other` at time `t`. The domain of the new curve is the intersection of the + /// domains of its constituents. If the domain intersection would be empty, an + /// [`InvalidIntervalError`] is returned. + fn zip(self, other: C) -> Result, InvalidIntervalError> + where + Self: Sized, + C: Curve + Sized, + { + let domain = self.domain().intersect(other.domain())?; + Ok(ProductCurve { + domain, + first: self, + second: other, + _phantom: PhantomData, + }) + } + + /// Create a new [`Curve`] by composing this curve end-to-end with another, producing another curve + /// with outputs of the same type. The domain of the other curve is translated so that its start + /// coincides with where this curve ends. A [`CompositionError`] is returned if this curve's domain + /// doesn't have a finite right endpoint or if `other`'s domain doesn't have a finite left endpoint. + fn compose(self, other: C) -> Result, CompositionError> + where + Self: Sized, + C: Curve, + { + if !self.domain().is_right_finite() { + return Err(CompositionError::RightInfiniteFirst); + } + if !other.domain().is_left_finite() { + return Err(CompositionError::LeftInfiniteSecond); + } + Ok(ComposeCurve { + first: self, + second: other, + _phantom: PhantomData, + }) + } + + /// Borrow this curve rather than taking ownership of it. This is essentially an alias for a + /// prefix `&`; the point is that intermediate operations can be performed while retaining + /// access to the original curve. + /// + /// # Example + /// ``` + /// # use bevy_math::curve::*; + /// let my_curve = function_curve(interval(0.0, 1.0).unwrap(), |t| t * t + 1.0); + /// // Borrow `my_curve` long enough to resample a mapped version. Note that `map` takes + /// // ownership of its input. + /// let samples = my_curve.by_ref().map(|x| x * 2.0).resample_auto(100).unwrap(); + /// // Do something else with `my_curve` since we retained ownership: + /// let new_curve = my_curve.reparametrize_linear(interval(-1.0, 1.0).unwrap()).unwrap(); + /// ``` + fn by_ref(&self) -> &Self + where + Self: Sized, + { + self + } +} + +impl Curve for D +where + C: Curve + ?Sized, + D: Deref, +{ + fn domain(&self) -> Interval { + >::domain(self) + } + + fn sample(&self, t: f32) -> T { + >::sample(self, t) + } +} + +/// An error indicating that a resampling operation could not be performed because of +/// malformed inputs. +#[derive(Debug, Error)] +#[error("Could not resample from this curve because of bad inputs")] +pub enum ResamplingError { + /// This resampling operation was not provided with enough samples to have well-formed output. + #[error("Not enough samples to construct resampled curve")] + NotEnoughSamples(usize), + + /// This resampling operation failed because of an unbounded interval. + #[error("Could not resample because this curve has unbounded domain")] + InfiniteInterval(InfiniteIntervalError), +} + +/// An error indicating that an end-to-end composition couldn't be performed because of +/// malformed inputs. +#[derive(Debug, Error)] +#[error("Could not compose these curves together")] +pub enum CompositionError { + /// The right endpoint of the first curve was infinite. + #[error("The first curve has an infinite right endpoint")] + RightInfiniteFirst, + + /// The left endpoint of the second curve was infinite. + #[error("The second curve has an infinite left endpoint")] + LeftInfiniteSecond, +} + +/// A curve which takes a constant value over its domain. +#[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "bevy_reflect", derive(Reflect))] +pub struct ConstantCurve { + domain: Interval, + value: T, +} + +impl Curve for ConstantCurve +where + T: Clone, +{ + #[inline] + fn domain(&self) -> Interval { + self.domain + } + + #[inline] + fn sample(&self, _t: f32) -> T { + self.value.clone() + } +} + +/// A curve defined by a function. +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "bevy_reflect", derive(Reflect))] +pub struct FunctionCurve { + domain: Interval, + f: F, + _phantom: PhantomData, +} + +impl Curve for FunctionCurve +where + F: Fn(f32) -> T, +{ + #[inline] + fn domain(&self) -> Interval { + self.domain + } + + #[inline] + fn sample(&self, t: f32) -> T { + (self.f)(t) + } +} + +/// A curve that is defined by explicit neighbor interpolation over a set of samples. +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "bevy_reflect", derive(Reflect))] +pub struct SampleCurve { + core: EvenCore, + interpolation: I, +} + +impl Curve for SampleCurve +where + T: Clone, + I: Fn(&T, &T, f32) -> T, +{ + #[inline] + fn domain(&self) -> Interval { + self.core.domain() + } + + #[inline] + fn sample(&self, t: f32) -> T { + self.core.sample_with(t, &self.interpolation) + } +} + +impl SampleCurve { + /// Create a new [`SampleCurve`] using the specified `interpolation` to interpolate between + /// the given `samples`. An error is returned if there are not at least 2 samples or if the + /// given `domain` is unbounded. + /// + /// The interpolation takes two values by reference together with a scalar parameter and + /// produces an owned value. The expectation is that `interpolation(&x, &y, 0.0)` and + /// `interpolation(&x, &y, 1.0)` are equivalent to `x` and `y` respectively. + pub fn new( + domain: Interval, + samples: impl Into>, + interpolation: I, + ) -> Result + where + I: Fn(&T, &T, f32) -> T, + { + Ok(Self { + core: EvenCore::new(domain, samples)?, + interpolation, + }) + } +} + +/// A curve that is defined by neighbor interpolation over a set of samples. +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "bevy_reflect", derive(Reflect))] +pub struct SampleAutoCurve { + core: EvenCore, +} + +impl Curve for SampleAutoCurve +where + T: StableInterpolate, +{ + #[inline] + fn domain(&self) -> Interval { + self.core.domain() + } + + #[inline] + fn sample(&self, t: f32) -> T { + self.core + .sample_with(t, ::interpolate_stable) + } +} + +impl SampleAutoCurve { + /// Create a new [`SampleCurve`] using type-inferred interpolation to interpolate between + /// the given `samples`. An error is returned if there are not at least 2 samples or if the + /// given `domain` is unbounded. + pub fn new(domain: Interval, samples: impl Into>) -> Result { + Ok(Self { + core: EvenCore::new(domain, samples)?, + }) + } +} + +/// A curve that is defined by interpolation over unevenly spaced samples with explicit +/// interpolation. +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "bevy_reflect", derive(Reflect))] +pub struct UnevenSampleCurve { + core: UnevenCore, + interpolation: I, +} + +impl Curve for UnevenSampleCurve +where + T: Clone, + I: Fn(&T, &T, f32) -> T, +{ + #[inline] + fn domain(&self) -> Interval { + self.core.domain() + } + + #[inline] + fn sample(&self, t: f32) -> T { + self.core.sample_with(t, &self.interpolation) + } +} + +impl UnevenSampleCurve { + /// Create a new [`UnevenSampleCurve`] using the provided `interpolation` to interpolate + /// between adjacent `timed_samples`. The given samples are filtered to finite times and + /// sorted internally; if there are not at least 2 valid timed samples, an error will be + /// returned. + /// + /// The interpolation takes two values by reference together with a scalar parameter and + /// produces an owned value. The expectation is that `interpolation(&x, &y, 0.0)` and + /// `interpolation(&x, &y, 1.0)` are equivalent to `x` and `y` respectively. + pub fn new( + timed_samples: impl Into>, + interpolation: I, + ) -> Result { + Ok(Self { + core: UnevenCore::new(timed_samples)?, + interpolation, + }) + } + + /// This [`UnevenSampleAutoCurve`], but with the sample times moved by the map `f`. + /// In principle, when `f` is monotone, this is equivalent to [`Curve::reparametrize`], + /// but the function inputs to each are inverses of one another. + /// + /// The samples are re-sorted by time after mapping and deduplicated by output time, so + /// the function `f` should generally be injective over the sample times of the curve. + pub fn map_sample_times(self, f: impl Fn(f32) -> f32) -> UnevenSampleCurve { + Self { + core: self.core.map_sample_times(f), + interpolation: self.interpolation, + } + } +} + +/// A curve that is defined by interpolation over unevenly spaced samples. +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "bevy_reflect", derive(Reflect))] +pub struct UnevenSampleAutoCurve { + core: UnevenCore, +} + +impl Curve for UnevenSampleAutoCurve +where + T: StableInterpolate, +{ + #[inline] + fn domain(&self) -> Interval { + self.core.domain() + } + + #[inline] + fn sample(&self, t: f32) -> T { + self.core + .sample_with(t, ::interpolate_stable) + } +} + +impl UnevenSampleAutoCurve { + /// Create a new [`UnevenSampleAutoCurve`] from a given set of timed samples, interpolated + /// using the The samples are filtered to finite times and + /// sorted internally; if there are not at least 2 valid timed samples, an error will be + /// returned. + pub fn new(timed_samples: impl Into>) -> Result { + Ok(Self { + core: UnevenCore::new(timed_samples)?, + }) + } + + /// This [`UnevenSampleAutoCurve`], but with the sample times moved by the map `f`. + /// In principle, when `f` is monotone, this is equivalent to [`Curve::reparametrize`], + /// but the function inputs to each are inverses of one another. + /// + /// The samples are re-sorted by time after mapping and deduplicated by output time, so + /// the function `f` should generally be injective over the sample times of the curve. + pub fn map_sample_times(self, f: impl Fn(f32) -> f32) -> UnevenSampleAutoCurve { + Self { + core: self.core.map_sample_times(f), + } + } +} + +/// A curve whose samples are defined by mapping samples from another curve through a +/// given function. +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "bevy_reflect", derive(Reflect))] +pub struct MapCurve { + preimage: C, + f: F, + _phantom: PhantomData<(S, T)>, +} + +impl Curve for MapCurve +where + C: Curve, + F: Fn(S) -> T, +{ + #[inline] + fn domain(&self) -> Interval { + self.preimage.domain() + } + + #[inline] + fn sample(&self, t: f32) -> T { + (self.f)(self.preimage.sample(t)) + } +} + +/// A curve whose sample space is mapped onto that of some base curve's before sampling. +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "bevy_reflect", derive(Reflect))] +pub struct ReparamCurve { + domain: Interval, + base: C, + f: F, + _phantom: PhantomData, +} + +impl Curve for ReparamCurve +where + C: Curve, + F: Fn(f32) -> f32, +{ + #[inline] + fn domain(&self) -> Interval { + self.domain + } + + #[inline] + fn sample(&self, t: f32) -> T { + self.base.sample((self.f)(t)) + } +} + +/// A curve that has had its domain altered by a linear remapping. +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "bevy_reflect", derive(Reflect))] +pub struct LinearReparamCurve { + base: C, + /// Invariants: This interval must always be bounded. + new_domain: Interval, + _phantom: PhantomData, +} + +impl Curve for LinearReparamCurve +where + C: Curve, +{ + #[inline] + fn domain(&self) -> Interval { + self.new_domain + } + + #[inline] + fn sample(&self, t: f32) -> T { + let f = self.new_domain.linear_map_to(self.base.domain()).unwrap(); + self.base.sample(f(t)) + } +} + +/// A curve that has been reparametrized by another curve, using that curve to transform the +/// sample times before sampling. +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "bevy_reflect", derive(Reflect))] +pub struct CurveReparamCurve { + base: C, + reparam_curve: D, + _phantom: PhantomData, +} + +impl Curve for CurveReparamCurve +where + C: Curve, + D: Curve, +{ + #[inline] + fn domain(&self) -> Interval { + self.reparam_curve.domain() + } + + #[inline] + fn sample(&self, t: f32) -> T { + let sample_time = self.reparam_curve.sample(t); + self.base.sample(sample_time) + } +} + +/// A curve that is the graph of another curve over its parameter space. +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "bevy_reflect", derive(Reflect))] +pub struct GraphCurve { + base: C, + _phantom: PhantomData, +} + +impl Curve<(f32, T)> for GraphCurve +where + C: Curve, +{ + #[inline] + fn domain(&self) -> Interval { + self.base.domain() + } + + #[inline] + fn sample(&self, t: f32) -> (f32, T) { + (t, self.base.sample(t)) + } +} + +/// A curve that combines the data from two constituent curves into a tuple output type. +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "bevy_reflect", derive(Reflect))] +pub struct ProductCurve { + domain: Interval, + first: C, + second: D, + _phantom: PhantomData<(S, T)>, +} + +impl Curve<(S, T)> for ProductCurve +where + C: Curve, + D: Curve, +{ + #[inline] + fn domain(&self) -> Interval { + self.domain + } + + #[inline] + fn sample(&self, t: f32) -> (S, T) { + (self.first.sample(t), self.second.sample(t)) + } +} + +/// The curve that results from composing one curve with another. The second curve is +/// effectively reparametrized so that its start is at the end of the first. +/// +/// For this to be well-formed, the first curve's domain must be right-finite and the second's +/// must be left-finite. +pub struct ComposeCurve { + first: C, + second: D, + _phantom: PhantomData, +} + +impl Curve for ComposeCurve +where + C: Curve, + D: Curve, +{ + #[inline] + fn domain(&self) -> Interval { + // This unwrap always succeeds because `first` has a valid Interval as its domain and the + // length of `second` cannot be NAN. It's still fine if it's infinity. + Interval::new( + self.first.domain().start(), + self.first.domain().end() + self.second.domain().length(), + ) + .unwrap() + } + + #[inline] + fn sample(&self, t: f32) -> T { + if t > self.first.domain().end() { + self.second.sample( + // `t - first.domain.end` computes the offset into the domain of the second. + t - self.first.domain().end() + self.second.domain().start(), + ) + } else { + self.first.sample(t) + } + } +} + +/// Create a [`Curve`] that constantly takes the given `value` over the given `domain`. +pub fn constant_curve(domain: Interval, value: T) -> ConstantCurve { + ConstantCurve { domain, value } +} + +/// Convert the given function `f` into a [`Curve`] with the given `domain`, sampled by +/// evaluating the function. +pub fn function_curve(domain: Interval, f: F) -> FunctionCurve +where + F: Fn(f32) -> T, +{ + FunctionCurve { + domain, + f, + _phantom: PhantomData, + } +} + +/// Flip a curve that outputs tuples so that the tuples are arranged the other way. +pub fn flip(curve: impl Curve<(S, T)>) -> impl Curve<(T, S)> { + curve.map(|(s, t)| (t, s)) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::Quat; + use approx::{assert_abs_diff_eq, AbsDiffEq}; + use std::f32::consts::TAU; + + #[test] + fn constant_curves() { + let curve = constant_curve(everywhere(), 5.0); + assert!(curve.sample(-35.0) == 5.0); + + let curve = constant_curve(interval(0.0, 1.0).unwrap(), true); + assert!(curve.sample(2.0)); + assert!(curve.sample_checked(2.0).is_none()); + } + + #[test] + fn function_curves() { + let curve = function_curve(everywhere(), |t| t * t); + assert!(curve.sample(2.0).abs_diff_eq(&4.0, f32::EPSILON)); + assert!(curve.sample(-3.0).abs_diff_eq(&9.0, f32::EPSILON)); + + let curve = function_curve(interval(0.0, f32::INFINITY).unwrap(), |t| t.log2()); + assert_eq!(curve.sample(3.5), f32::log2(3.5)); + assert!(curve.sample(-1.0).is_nan()); + assert!(curve.sample_checked(-1.0).is_none()); + } + + #[test] + fn mapping() { + let curve = function_curve(everywhere(), |t| t * 3.0 + 1.0); + let mapped_curve = curve.map(|x| x / 7.0); + assert_eq!(mapped_curve.sample(3.5), (3.5 * 3.0 + 1.0) / 7.0); + assert_eq!(mapped_curve.sample(-1.0), (-1.0 * 3.0 + 1.0) / 7.0); + assert_eq!(mapped_curve.domain(), everywhere()); + + let curve = function_curve(interval(0.0, 1.0).unwrap(), |t| t * TAU); + let mapped_curve = curve.map(Quat::from_rotation_z); + assert_eq!(mapped_curve.sample(0.0), Quat::IDENTITY); + assert!(mapped_curve.sample(1.0).is_near_identity()); + assert_eq!(mapped_curve.domain(), interval(0.0, 1.0).unwrap()); + } + + #[test] + fn reparametrization() { + let curve = function_curve(interval(1.0, f32::INFINITY).unwrap(), |t| t.log2()); + let reparametrized_curve = curve + .by_ref() + .reparametrize(interval(0.0, f32::INFINITY).unwrap(), |t| t.exp2()); + assert_abs_diff_eq!(reparametrized_curve.sample(3.5), 3.5); + assert_abs_diff_eq!(reparametrized_curve.sample(100.0), 100.0); + assert_eq!( + reparametrized_curve.domain(), + interval(0.0, f32::INFINITY).unwrap() + ); + + let reparametrized_curve = curve + .by_ref() + .reparametrize(interval(0.0, 1.0).unwrap(), |t| t + 1.0); + assert_abs_diff_eq!(reparametrized_curve.sample(0.0), 0.0); + assert_abs_diff_eq!(reparametrized_curve.sample(1.0), 1.0); + assert_eq!(reparametrized_curve.domain(), interval(0.0, 1.0).unwrap()); + } + + #[test] + fn multiple_maps() { + // Make sure these actually happen in the right order. + let curve = function_curve(interval(0.0, 1.0).unwrap(), |t| t.exp2()); + let first_mapped = curve.map(|x| x.log2()); + let second_mapped = first_mapped.map(|x| x * -2.0); + assert_abs_diff_eq!(second_mapped.sample(0.0), 0.0); + assert_abs_diff_eq!(second_mapped.sample(0.5), -1.0); + assert_abs_diff_eq!(second_mapped.sample(1.0), -2.0); + } + + #[test] + fn multiple_reparams() { + // Make sure these happen in the right order too. + let curve = function_curve(interval(0.0, 1.0).unwrap(), |t| t.exp2()); + let first_reparam = curve.reparametrize(interval(1.0, 2.0).unwrap(), |t| t.log2()); + let second_reparam = first_reparam.reparametrize(interval(0.0, 1.0).unwrap(), |t| t + 1.0); + assert_abs_diff_eq!(second_reparam.sample(0.0), 1.0); + assert_abs_diff_eq!(second_reparam.sample(0.5), 1.5); + assert_abs_diff_eq!(second_reparam.sample(1.0), 2.0); + } + + #[test] + fn resampling() { + let curve = function_curve(interval(1.0, 4.0).unwrap(), |t| t.log2()); + + // Need at least two points to sample. + let nice_try = curve.by_ref().resample_auto(1); + assert!(nice_try.is_err()); + + // The values of a resampled curve should be very close at the sample points. + // Because of denominators, it's not literally equal. + // (This is a tradeoff against O(1) sampling.) + let resampled_curve = curve.by_ref().resample_auto(101).unwrap(); + let step = curve.domain().length() / 100.0; + for index in 0..101 { + let test_pt = curve.domain().start() + index as f32 * step; + let expected = curve.sample(test_pt); + assert_abs_diff_eq!(resampled_curve.sample(test_pt), expected, epsilon = 1e-6); + } + + // Another example. + let curve = function_curve(interval(0.0, TAU).unwrap(), |t| t.cos()); + let resampled_curve = curve.by_ref().resample_auto(1001).unwrap(); + let step = curve.domain().length() / 1000.0; + for index in 0..1001 { + let test_pt = curve.domain().start() + index as f32 * step; + let expected = curve.sample(test_pt); + assert_abs_diff_eq!(resampled_curve.sample(test_pt), expected, epsilon = 1e-6); + } + } + + #[test] + fn uneven_resampling() { + let curve = function_curve(interval(0.0, f32::INFINITY).unwrap(), |t| t.exp()); + + // Need at least two points to resample. + let nice_try = curve.by_ref().resample_uneven_auto([1.0; 1]); + assert!(nice_try.is_err()); + + // Uneven sampling should produce literal equality at the sample points. + // (This is part of what you get in exchange for O(log(n)) sampling.) + let sample_points = (0..100).map(|idx| idx as f32 * 0.1); + let resampled_curve = curve.by_ref().resample_uneven_auto(sample_points).unwrap(); + for idx in 0..100 { + let test_pt = idx as f32 * 0.1; + let expected = curve.sample(test_pt); + assert_eq!(resampled_curve.sample(test_pt), expected); + } + assert_abs_diff_eq!(resampled_curve.domain().start(), 0.0); + assert_abs_diff_eq!(resampled_curve.domain().end(), 9.9, epsilon = 1e-6); + + // Another example. + let curve = function_curve(interval(1.0, f32::INFINITY).unwrap(), |t| t.log2()); + let sample_points = (0..10).map(|idx| (idx as f32).exp2()); + let resampled_curve = curve.by_ref().resample_uneven_auto(sample_points).unwrap(); + for idx in 0..10 { + let test_pt = (idx as f32).exp2(); + let expected = curve.sample(test_pt); + assert_eq!(resampled_curve.sample(test_pt), expected); + } + assert_abs_diff_eq!(resampled_curve.domain().start(), 1.0); + assert_abs_diff_eq!(resampled_curve.domain().end(), 512.0); + } +} diff --git a/crates/bevy_math/src/lib.rs b/crates/bevy_math/src/lib.rs index 868dae094510d..a81ee19c9b18a 100644 --- a/crates/bevy_math/src/lib.rs +++ b/crates/bevy_math/src/lib.rs @@ -17,6 +17,7 @@ pub mod bounding; mod common_traits; mod compass; pub mod cubic_splines; +pub mod curve; mod direction; mod float_ord; pub mod primitives;