
Commit

merge EricLBuehler
Jeadie committed Jul 1, 2024
1 parent b7a3e34 commit f07ddc5
Showing 2 changed files with 1 addition and 28 deletions.
27 changes: 0 additions & 27 deletions candle-nn/src/ops.rs
@@ -927,33 +927,6 @@ pub fn replication_pad2d(xs: &Tensor, pad: usize) -> Result<Tensor> {
     }
 }
 
-#[cfg(feature = "cuda")]
-pub fn kvconcat(ltensor: &Tensor, rtensor: &Tensor, concat_dim: usize) -> Result<Tensor> {
-    if !ltensor.device().is_cuda() {
-        return Tensor::cat(&[ltensor, &rtensor], concat_dim as usize)?.contiguous();
-    }
-    use candle::cuda_backend::KVConcat;
-    let op = KVConcat { concat_dim };
-    //inputs for kvconcat must be contiguous tensors
-    if ltensor.is_contiguous() && rtensor.is_contiguous() {
-        ltensor.apply_op2(&rtensor, op)
-    } else if ltensor.is_contiguous() {
-        ltensor.apply_op2(&rtensor.contiguous()?, op)
-    } else if rtensor.is_contiguous() {
-        let ltensor = ltensor.contiguous()?;
-        ltensor.apply_op2(&rtensor, op)
-    } else {
-        let ltensor = ltensor.contiguous()?;
-        let rtensor = rtensor.contiguous()?;
-        ltensor.apply_op2(&rtensor, op)
-    }
-}
-
-#[cfg(not(feature = "cuda"))]
-pub fn kvconcat(ltensor: &Tensor, rtensor: &Tensor, concat_dim: i32) -> Result<Tensor> {
-    Tensor::cat(&[ltensor, rtensor], concat_dim as usize)?.contiguous()
-}
-
 #[derive(Clone, Debug)]
 pub struct Identity;

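For callers that used the removed helper, the non-CUDA fallback above is the whole observable behavior: concatenate the two tensors along concat_dim and force the result contiguous. A minimal replacement sketch using candle's public Tensor::cat API; the crate alias (candle-core renamed to candle, as in this repo's deleted code), the KV-cache shapes, and the dtype below are illustrative, not taken from this commit:

use candle::{DType, Device, Result, Tensor};

// Equivalent of the removed non-CUDA `kvconcat`: concatenate along
// `concat_dim` and keep the result contiguous.
fn kvconcat(ltensor: &Tensor, rtensor: &Tensor, concat_dim: usize) -> Result<Tensor> {
    Tensor::cat(&[ltensor, rtensor], concat_dim)?.contiguous()
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Illustrative KV-cache shapes: (batch, heads, seq_len, head_dim).
    let cache = Tensor::zeros((1, 8, 16, 64), DType::F32, &dev)?;
    let step = Tensor::zeros((1, 8, 1, 64), DType::F32, &dev)?;
    let merged = kvconcat(&cache, &step, 2)?; // seq_len grows 16 -> 17
    assert_eq!(merged.dims(), &[1, 8, 17, 64]);
    Ok(())
}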
2 changes: 1 addition & 1 deletion candle-transformers/src/models/vgg.rs
@@ -54,7 +54,7 @@ impl ModuleT for Vgg<'_> {
 fn conv2d_block(convs: &[(usize, usize, &str)], vb: &VarBuilder) -> Result<FuncT<'static>> {
     let layers = convs
         .iter()
-        .map(|(in_c, out_c, name)| {
+        .map(|&(in_c, out_c, name)| {
             candle_nn::conv2d(
                 *in_c,
                 *out_c,
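The one changed line swaps the closure's binding mode: convs.iter() yields &(usize, usize, &str), and the &(in_c, out_c, name) pattern destructures through that reference so the Copy fields bind by value (in_c: usize) instead of by reference (in_c: &usize, which needs *in_c at each use site). A self-contained sketch of the two forms, not code from this repository:

fn main() {
    // Same element type as the VGG helper: (in_channels, out_channels, name).
    let convs: &[(usize, usize, &str)] = &[(3, 64, "conv1_1"), (64, 64, "conv1_2")];

    // Old pattern: fields bind as references, so each use needs a deref.
    let old: Vec<usize> = convs.iter().map(|(in_c, _, _)| *in_c).collect();

    // New pattern: `&(..)` destructures the yielded reference; the Copy
    // fields bind by value and no deref is needed.
    let new: Vec<usize> = convs.iter().map(|&(in_c, _, _)| in_c).collect();

    assert_eq!(old, new);
}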
