Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added GODEL support with T5 model type #376

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,8 @@
//! <details>
//! <summary> <b>4. Dialogue Model </b> </summary>
//!
//! Conversation model based on Microsoft's [DialoGPT](https://github.com/microsoft/DialoGPT).
//! Conversation model based on Microsoft's [DialoGPT](https://github.com/microsoft/DialoGPT) or
//! [GODEL](https://github.com/microsoft/GODEL).
//! This pipeline allows the generation of single or multi-turn conversations between a human and a model.
//! DialoGPT's page states that
//! > The human evaluation results indicate that the response generated from DialoGPT is comparable to human response quality
Expand Down
48 changes: 41 additions & 7 deletions src/pipelines/conversation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
// limitations under the License.

//! # Multi-turn dialogue
//! Conversation model based on Microsoft's [DialoGPT](https://github.com/microsoft/DialoGPT).
//! Conversation model based on Microsoft's [DialoGPT](https://github.com/microsoft/DialoGPT) or
//! [GODEL](https://github.com/microsoft/GODEL).
//! This pipeline allows the generation of single or multi-turn conversations between a human and a model.
//! DialoGPT's page states that
//! > The human evaluation results indicate that the response generated from DialoGPT is comparable to human response quality
Expand Down Expand Up @@ -59,6 +60,7 @@ use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator;
use crate::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
use crate::resources::ResourceProvider;
use crate::t5::T5Generator;
use std::collections::HashMap;
use tch::{Device, Kind, Tensor};
use uuid::Uuid;
Expand Down Expand Up @@ -695,14 +697,16 @@ impl Default for ConversationManager {
pub enum ConversationOption {
/// Conversation based on GPT2 model
GPT2(GPT2Generator),
T5(T5Generator),
}

impl ConversationOption {
pub fn new(config: ConversationConfig) -> Result<Self, RustBertError> {
match config.model_type {
ModelType::GPT2 => Ok(ConversationOption::GPT2(GPT2Generator::new(config.into())?)),
ModelType::T5 => Ok(ConversationOption::T5(T5Generator::new(config.into())?)),
_ => Err(RustBertError::InvalidConfigurationError(
"GPT2 is currently the only supported model for conversation generation"
"GPT-2 and T5 are currently the only supported model for conversation generation"
.to_string(),
)),
}
Expand All @@ -717,8 +721,12 @@ impl ConversationOption {
config.into(),
tokenizer,
)?)),
ModelType::T5 => Ok(ConversationOption::T5(T5Generator::new_with_tokenizer(
config.into(),
tokenizer,
)?)),
_ => Err(RustBertError::InvalidConfigurationError(
"GPT2 is currently the only supported model for conversation generation"
"GPT-2 and T5 are currently the only supported model for conversation generation"
.to_string(),
)),
}
Expand All @@ -729,27 +737,31 @@ impl ConversationOption {
Self::GPT2(model_ref) => {
Ok(*model_ref.get_eos_ids().as_ref().unwrap().first().unwrap())
}
Self::T5(model_ref) => Ok(*model_ref.get_eos_ids().as_ref().unwrap().first().unwrap()),
}
}

/// Get a reference to the model tokenizer.
pub fn get_tokenizer(&self) -> &TokenizerOption {
match self {
Self::GPT2(model_ref) => model_ref._get_tokenizer(),
Self::T5(model_ref) => model_ref._get_tokenizer(),
}
}

/// Get a mutable reference to the model tokenizer.
pub fn get_tokenizer_mut(&mut self) -> &TokenizerOption {
match self {
Self::GPT2(model_ref) => model_ref._get_tokenizer_mut(),
Self::T5(model_ref) => model_ref._get_tokenizer_mut(),
}
}

/// Returns the `ModelType` for this ConversationOption
pub fn model_type(&self) -> ModelType {
match *self {
Self::GPT2(_) => ModelType::GPT2,
Self::T5(_) => ModelType::T5,
}
}

Expand All @@ -765,6 +777,19 @@ impl ConversationOption {
.into_iter()
.map(|output| output.indices)
.collect(),
Self::T5(ref model) => model
.generate_from_ids_and_past(input_ids, attention_mask, None)
.into_iter()
.map(|output| output.indices)
.collect(),
}
}

/// Interface method to get the model family (encoder-decoder or decoder)
fn is_encoder_decoder(&self) -> bool {
match *self {
Self::GPT2(ref generator) => generator.is_encoder_decoder(),
Self::T5(ref generator) => generator.is_encoder_decoder(),
}
}
}
Expand Down Expand Up @@ -915,7 +940,11 @@ impl ConversationModel {
.zip(active_uuid.into_iter())
.zip(removed_padding_quantities.into_iter())
{
let generated_response = &generated_sequence[input_length - removed_padding.0..];
let generated_response = if self.model.is_encoder_decoder() {
generated_sequence.as_slice()
} else {
&generated_sequence[input_length - removed_padding.0..]
};
conversation
.generated_responses
.push(
Expand Down Expand Up @@ -1023,9 +1052,14 @@ impl ConversationModel {
.get(input_idx as i64)
.slice(0, 0, (max_len - input.len()) as i64, 1)
.fill_(0);
let mut padded_input = vec![pad_token; max_len - input.len()];
padded_input.extend(input);
padded_input
let padding = vec![pad_token; max_len - input.len()];
if self.model.is_encoder_decoder() {
// right padding assumed for encoder-decoders
[input, &padding].concat()
} else {
// left padding assumed for decoders
[&padding, input].concat()
}
})
.map(|tokens| Tensor::from_slice(&tokens).to(self.device))
.collect::<Vec<Tensor>>();
Expand Down
40 changes: 40 additions & 0 deletions src/t5/t5_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,16 @@ impl T5ModelResources {
"sentence-t5-base/model",
"https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/rust_model.ot",
);
/// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/GODEL-v1_1-base-seq2seq>. Modified with conversion to C-array format.
pub const GODEL_V1_1_BASE: (&'static str, &'static str) = (
"godel-v1-1-base/model",
"https://huggingface.co/microsoft/GODEL-v1_1-base-seq2seq/resolve/main/rust_model.ot",
);
/// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/GODEL-v1_1-base-seq2seq>. Modified with conversion to C-array format.
pub const GODEL_V1_1_LARGE: (&'static str, &'static str) = (
"godel-v1-1-large/model",
"https://huggingface.co/microsoft/GODEL-v1_1-large-seq2seq/resolve/main/rust_model.ot",
);
}

impl T5ConfigResources {
Expand All @@ -79,6 +89,16 @@ impl T5ConfigResources {
"sentence-t5-base/config",
"https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/config.json",
);
/// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/GODEL-v1_1-base-seq2seq>. Modified with conversion to C-array format.
pub const GODEL_V1_1_BASE: (&'static str, &'static str) = (
"godel-v1-1-base/config",
"https://huggingface.co/microsoft/GODEL-v1_1-base-seq2seq/resolve/main/config.json",
);
/// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/GODEL-v1_1-base-seq2seq>. Modified with conversion to C-array format.
pub const GODEL_V1_1_LARGE: (&'static str, &'static str) = (
"godel-v1-1-large/config",
"https://huggingface.co/microsoft/GODEL-v1_1-large-seq2seq/resolve/main/config.json",
);
}

impl T5VocabResources {
Expand All @@ -92,11 +112,31 @@ impl T5VocabResources {
"t5-base/spiece",
"https://huggingface.co/t5-base/resolve/main/spiece.model",
);
/// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/text-to-text-transfer-transformer>.
pub const T5_LARGE: (&'static str, &'static str) = (
"t5-large/spiece",
"https://huggingface.co/t5-large/resolve/main/spiece.model",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/sentence-t5-base>. Modified with conversion to C-array format.
pub const SENTENCE_T5_BASE: (&'static str, &'static str) = (
"sentence-t5-base/spiece",
"https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/spiece.model",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/sentence-t5-base>. Modified with conversion to C-array format.
pub const SENTENCE_T5_LARGE: (&'static str, &'static str) = (
"sentence-t5-large/spiece",
"https://huggingface.co/sentence-transformers/sentence-t5-large/resolve/main/spiece.model",
);
/// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/text-to-text-transfer-transformer>.
pub const GODEL_V1_1_BASE: (&'static str, &'static str) = (
"godel-v1-1-base/spiece",
"https://huggingface.co/t5-base/resolve/main/spiece.model",
);
/// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/text-to-text-transfer-transformer>.
pub const GODEL_V1_1_LARGE: (&'static str, &'static str) = (
"godel-v1-1-large/spiece",
"https://huggingface.co/t5-large/resolve/main/spiece.model",
);
}

const T5LANGUAGES: [Language; 3] = [Language::English, Language::French, Language::German];
Expand Down
4 changes: 4 additions & 0 deletions tests/gpt2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,7 @@ fn gpt2_beam_search_token_scores() -> anyhow::Result<()> {
fn dialogpt_single_multi_turn_conversation() -> anyhow::Result<()> {
// Set-up conversation model
let conversation_config = ConversationConfig {
model_type: ModelType::GPT2,
do_sample: false,
device: Device::Cpu,
..Default::default()
Expand Down Expand Up @@ -760,6 +761,7 @@ fn dialogpt_single_multi_turn_conversation() -> anyhow::Result<()> {
fn dialogpt_multiple_multi_turn_conversation() -> anyhow::Result<()> {
// Set-up conversation model
let conversation_config = ConversationConfig {
model_type: ModelType::GPT2,
do_sample: false,
device: Device::Cpu,
..Default::default()
Expand Down Expand Up @@ -802,6 +804,7 @@ fn dialogpt_multiple_multi_turn_conversation() -> anyhow::Result<()> {
fn dialogpt_multiple_multi_turn_conversation_with_truncation() -> anyhow::Result<()> {
// Set-up conversation model
let conversation_config = ConversationConfig {
model_type: ModelType::GPT2,
max_length: Some(36),
min_length_for_response: 24,
do_sample: false,
Expand Down Expand Up @@ -851,6 +854,7 @@ fn dialogpt_multiple_multi_turn_conversation_with_truncation() -> anyhow::Result
fn dialogpt_multiple_multi_turn_conversation_with_conversation_deletion() -> anyhow::Result<()> {
// Set-up conversation model
let conversation_config = ConversationConfig {
model_type: ModelType::GPT2,
do_sample: false,
device: Device::Cpu,
..Default::default()
Expand Down
Loading