Commit

Fix lints for clippy 1.75. (huggingface#1494)
LaurentMazare committed Dec 28, 2023
1 parent cd889c0 commit 1e442d4
Showing 6 changed files with 38 additions and 40 deletions.
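The fixes fall into four groups: the #[cfg(test)] module in candle-core/src/shape.rs moves below the ShapeWithOneHole implementations, the musicgen example drops an enumerate() whose index was never used, segment_anything/sam.rs replaces a tuple-rebuilding map closure with explicit field borrows, and the llama2-c, whisper, and yolo wasm examples rename the Msg variants WorkerInMsg and WorkerOutMsg to WorkerIn and WorkerOut.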
34 changes: 17 additions & 17 deletions candle-core/src/shape.rs
@@ -478,23 +478,6 @@ extract_dims!(
     (usize, usize, usize, usize, usize)
 );
 
-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    #[test]
-    fn stride() {
-        let shape = Shape::from(());
-        assert_eq!(shape.stride_contiguous(), Vec::<usize>::new());
-        let shape = Shape::from(42);
-        assert_eq!(shape.stride_contiguous(), [1]);
-        let shape = Shape::from((42, 1337));
-        assert_eq!(shape.stride_contiguous(), [1337, 1]);
-        let shape = Shape::from((299, 792, 458));
-        assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]);
-    }
-}
-
 pub trait ShapeWithOneHole {
     fn into_shape(self, el_count: usize) -> Result<Shape>;
 }
@@ -627,3 +610,20 @@ impl ShapeWithOneHole for (usize, usize, usize, usize, ()) {
         Ok((d1, d2, d3, d4, d).into())
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn stride() {
+        let shape = Shape::from(());
+        assert_eq!(shape.stride_contiguous(), Vec::<usize>::new());
+        let shape = Shape::from(42);
+        assert_eq!(shape.stride_contiguous(), [1]);
+        let shape = Shape::from((42, 1337));
+        assert_eq!(shape.stride_contiguous(), [1337, 1]);
+        let shape = Shape::from((299, 792, 458));
+        assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]);
+    }
+}
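The shape.rs change only relocates the existing test module to the end of the file; the tests themselves are unchanged. This is consistent with clippy's items_after_test_module lint, which warns when ordinary items are declared after a #[cfg(test)] module, although the commit message does not name the lint. A minimal free-standing sketch of the layout it asks for, with stride_contiguous as a stand-in helper rather than the candle implementation:

    // Regular items come first; the test module sits at the very end of the
    // file so that clippy's `items_after_test_module` lint (assumed trigger)
    // has nothing to warn about.
    pub fn stride_contiguous(dims: &[usize]) -> Vec<usize> {
        // Build row-major strides from the innermost dimension outwards.
        let mut strides: Vec<usize> = dims
            .iter()
            .rev()
            .scan(1, |acc, &d| {
                let stride = *acc;
                *acc *= d;
                Some(stride)
            })
            .collect();
        strides.reverse();
        strides
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn stride() {
            assert_eq!(stride_contiguous(&[299, 792, 458]), [458 * 792, 458, 1]);
        }
    }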
2 changes: 1 addition & 1 deletion candle-examples/examples/musicgen/musicgen_model.rs
@@ -321,7 +321,7 @@ impl MusicgenDecoder {
         let positions = self.embed_positions.forward(&input)?.to_device(dev)?;
         let mut xs = inputs_embeds.broadcast_add(&positions)?;
         let attention_mask = self.prepare_decoder_attention_mask(b_sz, seq_len)?;
-        for (_layer_idx, decoder_layer) in self.layers.iter_mut().enumerate() {
+        for decoder_layer in self.layers.iter_mut() {
             xs = decoder_layer.forward(&xs, &attention_mask, None)?;
         }
         let xs = self.layer_norm.forward(&xs)?;
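The musicgen change removes an enumerate() whose index was only bound to _layer_idx and never read, which matches clippy's unused_enumerate_index lint (an assumption; the commit message only says clippy 1.75). A free-standing sketch of the pattern:

    fn main() {
        let mut layers = vec![1u32, 2, 3];

        // Flagged form (kept as a comment): the index from `enumerate()` is
        // bound to `_i` and never used.
        // for (_i, layer) in layers.iter_mut().enumerate() { *layer += 1; }

        // Preferred form: iterate over the items directly.
        for layer in layers.iter_mut() {
            *layer += 1;
        }
        assert_eq!(layers, [2, 3, 4]);
    }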
2 changes: 1 addition & 1 deletion candle-transformers/src/models/segment_anything/sam.rs
@@ -184,7 +184,7 @@ impl Sam {
             let labels = Tensor::from_vec(labels, (1, n_points), img_embeddings.device())?;
             Some((points, labels))
         };
-        let points = points.as_ref().map(|(x, y)| (x, y));
+        let points = points.as_ref().map(|xy| (&xy.0, &xy.1));
         let (sparse_prompt_embeddings, dense_prompt_embeddings) =
             self.prompt_encoder.forward(points, None, None)?;
         self.mask_decoder.forward(
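The sam.rs change rewrites map(|(x, y)| (x, y)) as map(|xy| (&xy.0, &xy.1)). Both forms convert Option<&(A, B)> into Option<(&A, &B)>, but the destructure-and-rebuild closure reads as an identity map to clippy (presumably map_identity; the exact lint is an assumption), so the field borrows are spelled out instead. A small sketch of the same conversion on plain types:

    fn main() {
        // Stand-ins for the two tensors held by `points` in sam.rs.
        let points: Option<(Vec<f32>, Vec<f32>)> = Some((vec![0.5, 0.4], vec![1.0, 1.0]));

        // Goal: Option<&(A, B)> -> Option<(&A, &B)>. Writing the closure as
        // `|(x, y)| (x, y)` does this through match ergonomics but looks like
        // an identity map to clippy, so the fields are borrowed explicitly.
        let borrowed: Option<(&Vec<f32>, &Vec<f32>)> = points.as_ref().map(|xy| (&xy.0, &xy.1));

        assert_eq!(borrowed.map(|xy| (xy.0.len(), xy.1.len())), Some((2, 2)));
    }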
12 changes: 6 additions & 6 deletions candle-wasm-examples/llama2-c/src/app.rs
@@ -34,8 +34,8 @@ pub enum Msg {
     Run,
     UpdateStatus(String),
     SetModel(ModelData),
-    WorkerInMsg(WorkerInput),
-    WorkerOutMsg(Result<WorkerOutput, String>),
+    WorkerIn(WorkerInput),
+    WorkerOut(Result<WorkerOutput, String>),
 }
 
 pub struct CurrentDecode {
@@ -75,7 +75,7 @@ impl Component for App {
         let status = "loading weights".to_string();
         let cb = {
             let link = ctx.link().clone();
-            move |e| link.send_message(Self::Message::WorkerOutMsg(e))
+            move |e| link.send_message(Self::Message::WorkerOut(e))
         };
         let worker = Worker::bridge(std::rc::Rc::new(cb));
         Self {
@@ -128,11 +128,11 @@ impl Component for App {
                     let prompt = self.prompt.borrow().clone();
                     console_log!("temp: {}, top_p: {}, prompt: {}", temp, top_p, prompt);
                     ctx.link()
-                        .send_message(Msg::WorkerInMsg(WorkerInput::Run(temp, top_p, prompt)))
+                        .send_message(Msg::WorkerIn(WorkerInput::Run(temp, top_p, prompt)))
                 }
                 true
             }
-            Msg::WorkerOutMsg(output) => {
+            Msg::WorkerOut(output) => {
                 match output {
                     Ok(WorkerOutput::WeightsLoaded) => self.status = "weights loaded!".to_string(),
                     Ok(WorkerOutput::GenerationDone(Err(err))) => {
@@ -165,7 +165,7 @@ impl Component for App {
                 }
                 true
             }
-            Msg::WorkerInMsg(inp) => {
+            Msg::WorkerIn(inp) => {
                 self.worker.send(inp);
                 true
             }
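Across the three wasm examples the Msg variants WorkerInMsg and WorkerOutMsg become WorkerIn and WorkerOut, together with every construction and match on them. The likely trigger is clippy's enum_variant_names lint, which flags variant names that repeat the enum name (here the Msg suffix); this attribution is an assumption, as the commit only states that it fixes clippy 1.75 lints. A reduced sketch of the rename, with WorkerInput and WorkerOutput as stand-ins for the real worker message types:

    struct WorkerInput(String);
    struct WorkerOutput(String);

    enum Msg {
        // Before: WorkerInMsg(WorkerInput) and WorkerOutMsg(Result<WorkerOutput, String>)
        // repeated the enum name `Msg`, which `enum_variant_names` flags.
        WorkerIn(WorkerInput),
        WorkerOut(Result<WorkerOutput, String>),
    }

    fn handle(msg: Msg) {
        match msg {
            Msg::WorkerIn(inp) => println!("to worker: {}", inp.0),
            Msg::WorkerOut(Ok(out)) => println!("from worker: {}", out.0),
            Msg::WorkerOut(Err(err)) => println!("worker error: {err}"),
        }
    }

    fn main() {
        handle(Msg::WorkerIn(WorkerInput("run".to_string())));
        handle(Msg::WorkerOut(Ok(WorkerOutput("done".to_string()))));
    }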
16 changes: 7 additions & 9 deletions candle-wasm-examples/whisper/src/app.rs
@@ -42,8 +42,8 @@ pub enum Msg {
     Run(usize),
     UpdateStatus(String),
     SetDecoder(ModelData),
-    WorkerInMsg(WorkerInput),
-    WorkerOutMsg(Result<WorkerOutput, String>),
+    WorkerIn(WorkerInput),
+    WorkerOut(Result<WorkerOutput, String>),
 }
 
 pub struct CurrentDecode {
@@ -116,7 +116,7 @@ impl Component for App {
         let status = "loading weights".to_string();
         let cb = {
             let link = ctx.link().clone();
-            move |e| link.send_message(Self::Message::WorkerOutMsg(e))
+            move |e| link.send_message(Self::Message::WorkerOut(e))
         };
         let worker = Worker::bridge(std::rc::Rc::new(cb));
         Self {
@@ -165,18 +165,16 @@ impl Component for App {
                             Err(err) => {
                                 let output = Err(format!("decoding error: {err:?}"));
                                 // Mimic a worker output to so as to release current_decode
-                                Msg::WorkerOutMsg(output)
-                            }
-                            Ok(wav_bytes) => {
-                                Msg::WorkerInMsg(WorkerInput::DecodeTask { wav_bytes })
+                                Msg::WorkerOut(output)
                             }
+                            Ok(wav_bytes) => Msg::WorkerIn(WorkerInput::DecodeTask { wav_bytes }),
                         }
                     })
                 }
                 //
                 true
             }
-            Msg::WorkerOutMsg(output) => {
+            Msg::WorkerOut(output) => {
                 let dt = self.current_decode.as_ref().and_then(|current_decode| {
                     current_decode.start_time.and_then(|start_time| {
                         performance_now().map(|stop_time| stop_time - start_time)
@@ -198,7 +196,7 @@
                 }
                 true
             }
-            Msg::WorkerInMsg(inp) => {
+            Msg::WorkerIn(inp) => {
                 self.worker.send(inp);
                 true
             }
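The whisper example receives the same WorkerIn/WorkerOut rename; in addition, the Ok(wav_bytes) arm collapses from a braced block to a single expression, which is why this file has two more deletions than additions.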
12 changes: 6 additions & 6 deletions candle-wasm-examples/yolo/src/app.rs
@@ -33,8 +33,8 @@ pub enum Msg {
     Run,
     UpdateStatus(String),
     SetModel(ModelData),
-    WorkerInMsg(WorkerInput),
-    WorkerOutMsg(Result<WorkerOutput, String>),
+    WorkerIn(WorkerInput),
+    WorkerOut(Result<WorkerOutput, String>),
 }
 
 pub struct CurrentDecode {
@@ -117,7 +117,7 @@ impl Component for App {
         let status = "loading weights".to_string();
         let cb = {
             let link = ctx.link().clone();
-            move |e| link.send_message(Self::Message::WorkerOutMsg(e))
+            move |e| link.send_message(Self::Message::WorkerOut(e))
         };
         let worker = Worker::bridge(std::rc::Rc::new(cb));
         Self {
@@ -166,7 +166,7 @@ impl Component for App {
                                 let status = format!("{err:?}");
                                 Msg::UpdateStatus(status)
                             }
-                            Ok(image_data) => Msg::WorkerInMsg(WorkerInput::RunData(RunData {
+                            Ok(image_data) => Msg::WorkerIn(WorkerInput::RunData(RunData {
                                 image_data,
                                 conf_threshold: 0.5,
                                 iou_threshold: 0.5,
@@ -176,7 +176,7 @@
                 }
                 true
             }
-            Msg::WorkerOutMsg(output) => {
+            Msg::WorkerOut(output) => {
                 match output {
                     Ok(WorkerOutput::WeightsLoaded) => self.status = "weights loaded!".to_string(),
                     Ok(WorkerOutput::ProcessingDone(Err(err))) => {
@@ -218,7 +218,7 @@ impl Component for App {
                 }
                 true
             }
-            Msg::WorkerInMsg(inp) => {
+            Msg::WorkerIn(inp) => {
                 self.worker.send(inp);
                 true
             }
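The yolo example applies the same Msg variant rename to its own worker plumbing; the detection thresholds and the rest of the inference path are untouched.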
