forked from huggingface/candle
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add LLaVA support (huggingface#2234)
* first commit * llava * clippy and fmt * some fixes * minor fixes * remove useless file * refactor: Remove llava/constants.rs and update llava/mod.rs * modify variable name * modify code after clippy * Minor tweaks. --------- Co-authored-by: laurent <[email protected]>
- Loading branch information
1 parent
03344d3
commit cd4d941
Showing
12 changed files
with
1,567 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
// Literal marker strings shared by the LLaVA code.
// NOTE(review): how each marker is consumed is not visible in this file;
// the descriptions below are inferred from the names and literal values only.

/// Marker string `<image>`; presumably the spot in a prompt where an image goes.
pub const DEFAULT_IMAGE_TOKEN: &str = "<image>";
/// Marker string `<im_start>`; presumably opens an image span.
pub const DEFAULT_IM_START_TOKEN: &str = "<im_start>";
/// Marker string `<im_end>`; presumably closes an image span.
pub const DEFAULT_IM_END_TOKEN: &str = "<im_end>";
/// Marker string `<image-placeholder>`; presumably a user-facing stand-in for an image.
pub const IMAGE_PLACEHOLDER: &str = "<image-placeholder>";
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
/// Selects how [`Conversation::get_prompt`] lays out roles, messages and separators.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SeparatorStyle {
    /// `role: message` turns, alternating between `sep` and `sep2` after each
    /// completed turn.
    Two,
    /// `role` immediately followed by `message`, with `sep` after each
    /// completed turn.
    Mpt,
}

/// A chat transcript plus the template pieces needed to render it as a prompt.
#[derive(Debug, Clone)]
pub struct Conversation {
    /// System text emitted at the very start of the prompt.
    pub system: String,
    /// Role labels; index 0 is used for user turns and index 1 for assistant turns.
    pub roles: Vec<String>,
    /// `(role, message)` pairs in order; a `None` message is a turn still to be filled in.
    pub messages: Vec<(String, Option<String>)>,
    /// Stored as given; not read by any method in this file.
    /// NOTE(review): presumably mirrors the upstream template's `offset` - confirm.
    pub offset: i32,
    /// Layout used by [`Conversation::get_prompt`].
    pub sep_style: SeparatorStyle,
    /// Primary separator: follows the system text and (depending on the style)
    /// completed turns.
    pub sep: String,
    /// Secondary separator used by [`SeparatorStyle::Two`] after odd-indexed turns.
    pub sep2: Option<String>,
    /// Free-form template label; not read by any method in this file.
    pub version: String,
}

impl Conversation {
    /// Builds an empty conversation (no messages) from the given template pieces.
    ///
    /// All string arguments are copied into owned storage.
    pub fn new(
        system: &str,
        roles: &[String],
        offset: i32,
        sep_style: SeparatorStyle,
        sep: &str,
        sep2: Option<&str>,
        version: &str,
    ) -> Self {
        Conversation {
            system: system.to_string(),
            roles: roles.to_vec(),
            messages: Vec::new(),
            offset,
            sep_style,
            sep: sep.to_string(),
            sep2: sep2.map(str::to_string),
            version: version.to_string(),
        }
    }

    /// ChatML-style template: `<|im_start|>`-prefixed roles, `<|im_end|>` separator,
    /// [`SeparatorStyle::Mpt`] layout.
    pub fn conv_chatml_direct() -> Self {
        Conversation::new(
            "<|im_start|>system\nAnswer the questions.",
            &[
                "<|im_start|>user\n".to_string(),
                "<|im_start|>assistant\n".to_string(),
            ],
            0,
            SeparatorStyle::Mpt,
            "<|im_end|>",
            None,
            "mpt",
        )
    }

    /// `USER` / `ASSISTANT` template using [`SeparatorStyle::Two`] with `" "` and
    /// `"</s>"` as the two separators.
    pub fn conv_llava_v1() -> Self {
        Conversation::new(
            "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.",
            &["USER".to_string(), "ASSISTANT".to_string()],
            0,
            SeparatorStyle::Two,
            " ",
            Some("</s>"),
            "v1",
        )
    }

    /// Appends a turn for `role`; `None` records a turn whose text is not yet known.
    pub fn append_message(&mut self, role: String, message: Option<&str>) {
        self.messages.push((role, message.map(str::to_string)))
    }

    /// Appends a turn using the first role label.
    ///
    /// # Panics
    ///
    /// Panics if `roles` is empty.
    pub fn append_user_message(&mut self, message: Option<&str>) {
        self.append_message(self.roles[0].clone(), message);
    }

    /// Appends a turn using the second role label.
    ///
    /// # Panics
    ///
    /// Panics if `roles` has fewer than two entries.
    pub fn append_assistant_message(&mut self, message: Option<&str>) {
        self.append_message(self.roles[1].clone(), message);
    }

    /// Renders the conversation as a single prompt string.
    ///
    /// The prompt always starts with `system` followed by `sep`. After that:
    ///
    /// * [`SeparatorStyle::Mpt`]: each turn is `role` + `message` + `sep`. A turn
    ///   whose message is `None` contributes only `role`, so a pending turn is left
    ///   open instead of being terminated by the separator.
    /// * [`SeparatorStyle::Two`]: each turn is `role` + `": "` + `message` followed
    ///   by `sep` (even index) or `sep2` (odd index). A turn whose message is `None`
    ///   contributes only `role` + `":"`. If `sep2` is `None`, `sep` is used in its
    ///   place rather than panicking.
    pub fn get_prompt(&self) -> String {
        let mut ret = String::new();
        ret.push_str(&self.system);
        ret.push_str(&self.sep);
        match self.sep_style {
            SeparatorStyle::Mpt => {
                for (role, message) in &self.messages {
                    ret.push_str(role);
                    // Only a completed turn is closed with the separator; closing a
                    // pending turn would tell the model that the turn is already over.
                    if let Some(message) = message {
                        ret.push_str(message);
                        ret.push_str(&self.sep);
                    }
                }
            }
            SeparatorStyle::Two => {
                // Borrow both separators; fall back to `sep` when `sep2` is absent.
                let seps = [
                    self.sep.as_str(),
                    self.sep2.as_deref().unwrap_or(self.sep.as_str()),
                ];
                for (i, (role, message)) in self.messages.iter().enumerate() {
                    ret.push_str(role);
                    match message {
                        Some(message) => {
                            // strictly follow the python implementation, otherwise it
                            // will cause some minor difference between tokens
                            ret.push_str(": ");
                            ret.push_str(message);
                            ret.push_str(seps[i % 2]);
                        }
                        None => ret.push(':'),
                    }
                }
            }
        }
        ret
    }
}
Oops, something went wrong.