Skip to content

Commit

Permalink
Support Minus(u) for arbitrary values of u, e.g. Minus(3). (huggingfa…
Browse files Browse the repository at this point in the history
…ce#2428)

* Support Minus(u) for arbitrary values of u, e.g. Minus(3).

* Forces u to be strictly positive.
  • Loading branch information
LaurentMazare authored Aug 17, 2024
1 parent b75ef05 commit 7cff589
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions candle-core/src/shape.rs
Original file line number Diff line number Diff line change
Expand Up @@ -304,13 +304,15 @@ impl Dim for usize {
pub enum D {
Minus1,
Minus2,
Minus(usize),
}

impl D {
fn out_of_range(&self, shape: &Shape, op: &'static str) -> Error {
let dim = match self {
Self::Minus1 => -1,
Self::Minus2 => -2,
Self::Minus(u) => -(*u as i32),
};
Error::DimOutOfRange {
shape: shape.clone(),
Expand All @@ -327,6 +329,7 @@ impl Dim for D {
match self {
Self::Minus1 if rank >= 1 => Ok(rank - 1),
Self::Minus2 if rank >= 2 => Ok(rank - 2),
Self::Minus(u) if *u > 0 && rank >= *u => Ok(rank - *u),
_ => Err(self.out_of_range(shape, op)),
}
}
Expand All @@ -336,6 +339,7 @@ impl Dim for D {
match self {
Self::Minus1 => Ok(rank),
Self::Minus2 if rank >= 1 => Ok(rank - 1),
Self::Minus(u) if *u > 0 && rank + 1 >= *u => Ok(rank + 1 - *u),
_ => Err(self.out_of_range(shape, op)),
}
}
Expand Down

0 comments on commit 7cff589

Please sign in to comment.