Skip to content

Commit

Permalink
Implement the module trait directly for QMatMul. (huggingface#1372)
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare authored Nov 25, 2023
1 parent 762e996 commit bfa7c8f
Show file tree
Hide file tree
Showing 7 changed files with 11 additions and 18 deletions.
11 changes: 5 additions & 6 deletions candle-core/examples/basics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@ use anyhow::Result;
use candle_core::{Device, Tensor};

// NOTE(review): this span is a rendered diff hunk, not a single coherent function.
// The next six statements are the PRE-commit body of examples/basics.rs
// (a conv2d timing micro-benchmark); the five after them are the POST-commit
// body (a slice_scatter demo). Only one set existed in the file at a time.
fn main() -> Result<()> {
// --- old body (deleted by this commit): time a 3x3 conv2d on random CPU tensors ---
let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
let start = std::time::Instant::now();
let res = inp.conv2d(&w, 0, 1, 1, 1)?;
println!("{:?}", start.elapsed());
println!("{res:?}");
// --- new body (added by this commit): demonstrate slice_scatter ---
let a = Tensor::new(&[[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]], &Device::Cpu)?;
let b = Tensor::new(&[[88.0f32, 99.0]], &Device::Cpu)?;
let new_a = a.slice_scatter(&b, 1, 2)?;
// slice_scatter returns a new tensor; `a` itself is unchanged, as asserted here.
assert_eq!(a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
// NOTE(review): suspicious — new_a is asserted to equal the ORIGINAL `a`,
// i.e. the scattered values 88.0/99.0 from `b` appear nowhere in the expected
// output. Verify against the candle slice_scatter docs whether the expected
// value should instead reflect `b`'s values written into `a` — TODO confirm
// against the actual commit (this may also be a scrape artifact).
assert_eq!(new_a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
Ok(())
}
6 changes: 0 additions & 6 deletions candle-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,6 @@ pub trait Module {
fn forward(&self, xs: &Tensor) -> Result<Tensor>;
}

// NOTE(review): these lines are DELETED by this commit (removed from
// candle-core/src/lib.rs). Previously the Module trait impl simply delegated
// to an inherent `QMatMul::forward` method; the L55-58 hunk in
// candle-core/src/quantized/mod.rs shows that inherent method being turned
// into the trait impl itself (`impl crate::Module for QMatMul`), so this
// delegating impl becomes redundant — and, had it been kept, `self.forward(xs)`
// would have resolved to the trait method and recursed infinitely.
impl Module for quantized::QMatMul {
// Delegates trait `forward` to the (pre-commit) inherent method of QMatMul.
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
self.forward(xs)
}
}

impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
self(xs)
Expand Down
4 changes: 2 additions & 2 deletions candle-core/src/quantized/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,8 @@ impl crate::CustomOp1 for QTensor {
}
}

impl QMatMul {
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
impl crate::Module for QMatMul {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match self {
Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
Self::Tensor(w) => {
Expand Down
2 changes: 1 addition & 1 deletion candle-core/tests/quantized_tests.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use candle_core::{
quantized::{self, GgmlDType},
test_utils::to_vec2_round,
Device, Result, Tensor,
Device, Module, Result, Tensor,
};
use quantized::{k_quants, GgmlType};
use rand::prelude::*;
Expand Down
2 changes: 1 addition & 1 deletion candle-nn/examples/cpu_benchmarks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ extern crate intel_mkl_src;
extern crate accelerate_src;

use candle::quantized::GgmlType;
use candle::{CpuStorage, Device, Layout, Result, Shape, Tensor, D};
use candle::{CpuStorage, Device, Layout, Module, Result, Shape, Tensor, D};
use clap::{Parser, Subcommand};

const CHECK_CONV2D: bool = false;
Expand Down
2 changes: 1 addition & 1 deletion candle-pyo3/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType};
use ::candle::{quantized::QTensor, DType, Device, Module, Tensor, WithDType};

mod utils;
use utils::wrap_err;
Expand Down
2 changes: 1 addition & 1 deletion candle-wasm-tests/tests/quantized_tests.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use candle::{
quantized::{self, k_quants, GgmlDType, GgmlType},
test_utils::to_vec2_round,
Device, Result, Tensor,
Device, Module, Result, Tensor,
};

use wasm_bindgen_test::*;
Expand Down

0 comments on commit bfa7c8f

Please sign in to comment.