From 7e00a22f8827cbb24b963c13fdf60caa3ba1b4ff Mon Sep 17 00:00:00 2001 From: Dario Cancelliere Date: Thu, 11 May 2023 02:54:21 +0200 Subject: [PATCH 1/5] Added GODEL support --- src/pipelines/conversation.rs | 21 ++++++++++++++++++++- src/t5/t5_model.rs | 15 +++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/src/pipelines/conversation.rs b/src/pipelines/conversation.rs index b8f46e5f..62f9e7ce 100644 --- a/src/pipelines/conversation.rs +++ b/src/pipelines/conversation.rs @@ -12,7 +12,8 @@ // limitations under the License. //! # Multi-turn dialogue -//! Conversation model based on Microsoft's [DialoGPT](https://github.com/microsoft/DialoGPT). +//! Conversation model based on Microsoft's [DialoGPT](https://github.com/microsoft/DialoGPT) or +//! [GODEL](https://github.com/microsoft/GODEL). //! This pipeline allows the generation of single or multi-turn conversations between a human and a model. //! The DialoGPT's page states that //! > The human evaluation results indicate that the response generated from DialoGPT is comparable to human response quality @@ -55,6 +56,7 @@ //! from the 3rd party utilization of the pretrained system. 
use crate::common::error::RustBertError; use crate::gpt2::GPT2Generator; +use crate::t5::T5Generator; use crate::pipelines::common::{ModelType, TokenizerOption}; use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator; use crate::pipelines::generation_utils::{GenerateConfig, LanguageGenerator}; @@ -695,12 +697,14 @@ impl Default for ConversationManager { pub enum ConversationOption { /// Conversation based on GPT2 model GPT2(GPT2Generator), + T5(T5Generator), } impl ConversationOption { pub fn new(config: ConversationConfig) -> Result { match config.model_type { ModelType::GPT2 => Ok(ConversationOption::GPT2(GPT2Generator::new(config.into())?)), + ModelType::T5 => Ok(ConversationOption::T5(T5Generator::new(config.into())?)), _ => Err(RustBertError::InvalidConfigurationError( "GPT2 is currently the only supported model for conversation generation" .to_string(), @@ -717,6 +721,10 @@ impl ConversationOption { config.into(), tokenizer, )?)), + ModelType::T5 => Ok(ConversationOption::T5(T5Generator::new_with_tokenizer( + config.into(), + tokenizer, + )?)), _ => Err(RustBertError::InvalidConfigurationError( "GPT2 is currently the only supported model for conversation generation" .to_string(), @@ -729,6 +737,9 @@ impl ConversationOption { Self::GPT2(model_ref) => { Ok(*model_ref.get_eos_ids().as_ref().unwrap().first().unwrap()) } + Self::T5(model_ref) => { + Ok(*model_ref.get_eos_ids().as_ref().unwrap().first().unwrap()) + } } } @@ -736,6 +747,7 @@ impl ConversationOption { pub fn get_tokenizer(&self) -> &TokenizerOption { match self { Self::GPT2(model_ref) => model_ref._get_tokenizer(), + Self::T5(model_ref) => model_ref._get_tokenizer(), } } @@ -743,6 +755,7 @@ impl ConversationOption { pub fn get_tokenizer_mut(&mut self) -> &TokenizerOption { match self { Self::GPT2(model_ref) => model_ref._get_tokenizer_mut(), + Self::T5(model_ref) => model_ref._get_tokenizer_mut(), } } @@ -750,6 +763,7 @@ impl ConversationOption { pub fn 
model_type(&self) -> ModelType { match *self { Self::GPT2(_) => ModelType::GPT2, + Self::T5(_) => ModelType::T5, } } @@ -765,6 +779,11 @@ impl ConversationOption { .into_iter() .map(|output| output.indices) .collect(), + Self::T5(ref model) => model + .generate_from_ids_and_past(input_ids, attention_mask, None) + .into_iter() + .map(|output| output.indices) + .collect(), } } } diff --git a/src/t5/t5_model.rs b/src/t5/t5_model.rs index 204c4560..59d11807 100644 --- a/src/t5/t5_model.rs +++ b/src/t5/t5_model.rs @@ -61,6 +61,11 @@ impl T5ModelResources { "sentence-t5-base/model", "https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/rust_model.ot", ); + /// Shared under MIT license by the Microsoft team at . Modified with conversion to C-array format. + pub const GODEL_V1_1_BASE: (&'static str, &'static str) = ( + "godel-v1-1-base/model", + "https://huggingface.co/microsoft/GODEL-v1_1-base-seq2seq/resolve/main/rust_model.ot", + ); } impl T5ConfigResources { @@ -79,6 +84,11 @@ impl T5ConfigResources { "sentence-t5-base/config", "https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/config.json", ); + /// Shared under MIT license by the Microsoft team at . Modified with conversion to C-array format. + pub const GODEL_V1_1_BASE: (&'static str, &'static str) = ( + "godel-v1-1-base/config", + "https://huggingface.co/microsoft/GODEL-v1_1-base-seq2seq/resolve/main/config.json", + ); } impl T5VocabResources { @@ -97,6 +107,11 @@ impl T5VocabResources { "sentence-t5-base/spiece", "https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/spiece.model", ); + /// Shared under MIT license by the Microsoft team at . Modified with conversion to C-array format. 
+ pub const GODEL_V1_1_BASE: (&'static str, &'static str) = ( + "godel-v1-1-base/spiece", + "https://huggingface.co/microsoft/GODEL-v1_1-base-seq2seq/resolve/main/spiece.model", + ); } const T5LANGUAGES: [Language; 3] = [Language::English, Language::French, Language::German]; From a3f484e1afa73ce0070196706fa61f828a395445 Mon Sep 17 00:00:00 2001 From: Dario Cancelliere Date: Thu, 11 May 2023 03:47:10 +0200 Subject: [PATCH 2/5] Added other missing resources --- src/pipelines/conversation.rs | 2 ++ src/t5/t5_model.rs | 19 +++++++++++++++++-- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/src/pipelines/conversation.rs b/src/pipelines/conversation.rs index 62f9e7ce..ce450160 100644 --- a/src/pipelines/conversation.rs +++ b/src/pipelines/conversation.rs @@ -925,6 +925,8 @@ impl ConversationModel { let mut output = HashMap::with_capacity(active_uuid.len()); + println!("generated: {:#?}, prompt_ids: {:#?}", &generated, &prompt_ids); + for ( ((conversation, (generated_sequence, conversation_promp_ids)), uuid), removed_padding, diff --git a/src/t5/t5_model.rs b/src/t5/t5_model.rs index 59d11807..f93e210f 100644 --- a/src/t5/t5_model.rs +++ b/src/t5/t5_model.rs @@ -66,6 +66,11 @@ impl T5ModelResources { "godel-v1-1-base/model", "https://huggingface.co/microsoft/GODEL-v1_1-base-seq2seq/resolve/main/rust_model.ot", ); + /// Shared under MIT license by the Microsoft team at . Modified with conversion to C-array format. + pub const GODEL_V1_1_LARGE: (&'static str, &'static str) = ( + "godel-v1-1-large/model", + "https://huggingface.co/microsoft/GODEL-v1_1-large-seq2seq/resolve/main/rust_model.ot", + ); } impl T5ConfigResources { @@ -89,6 +94,11 @@ impl T5ConfigResources { "godel-v1-1-base/config", "https://huggingface.co/microsoft/GODEL-v1_1-base-seq2seq/resolve/main/config.json", ); + /// Shared under MIT license by the Microsoft team at . Modified with conversion to C-array format. 
+ pub const GODEL_V1_1_LARGE: (&'static str, &'static str) = ( + "godel-v1-1-large/config", + "https://huggingface.co/microsoft/GODEL-v1_1-large-seq2seq/resolve/main/config.json", + ); } impl T5VocabResources { @@ -107,10 +117,15 @@ impl T5VocabResources { "sentence-t5-base/spiece", "https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/spiece.model", ); - /// Shared under MIT license by the Microsoft team at . Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the Google team at . pub const GODEL_V1_1_BASE: (&'static str, &'static str) = ( "godel-v1-1-base/spiece", - "https://huggingface.co/microsoft/GODEL-v1_1-base-seq2seq/resolve/main/spiece.model", + "https://huggingface.co/t5-base/resolve/main/spiece.model", + ); + /// Shared under Apache 2.0 license by the Google team at . + pub const GODEL_V1_1_LARGE: (&'static str, &'static str) = ( + "godel-v1-1-large/spiece", + "https://huggingface.co/t5-large/resolve/main/spiece.model", ); } From 564ae85df0e16158cc3d0fdac9fc1a40bf3a69ed Mon Sep 17 00:00:00 2001 From: Guillaume Becquin Date: Sun, 14 May 2023 09:05:34 +0100 Subject: [PATCH 3/5] - Remove debugging print statement - Skip truncation of prompt for encoder-decoder models - Add right padding logic for encoder-decoder models --- src/pipelines/conversation.rs | 33 +++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/src/pipelines/conversation.rs b/src/pipelines/conversation.rs index 9c943c66..6e026cc2 100644 --- a/src/pipelines/conversation.rs +++ b/src/pipelines/conversation.rs @@ -56,11 +56,11 @@ //! from the 3rd party utilization of the pretrained system. 
use crate::common::error::RustBertError; use crate::gpt2::GPT2Generator; -use crate::t5::T5Generator; use crate::pipelines::common::{ModelType, TokenizerOption}; use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator; use crate::pipelines::generation_utils::{GenerateConfig, LanguageGenerator}; use crate::resources::ResourceProvider; +use crate::t5::T5Generator; use std::collections::HashMap; use tch::{Device, Kind, Tensor}; use uuid::Uuid; @@ -737,9 +737,7 @@ impl ConversationOption { Self::GPT2(model_ref) => { Ok(*model_ref.get_eos_ids().as_ref().unwrap().first().unwrap()) } - Self::T5(model_ref) => { - Ok(*model_ref.get_eos_ids().as_ref().unwrap().first().unwrap()) - } + Self::T5(model_ref) => Ok(*model_ref.get_eos_ids().as_ref().unwrap().first().unwrap()), } } @@ -786,6 +784,14 @@ impl ConversationOption { .collect(), } } + + /// Interface method to get the model family (encoder-decoder or decoder) + fn is_encoder_decoder(&self) -> bool { + match *self { + Self::GPT2(ref generator) => generator.is_encoder_decoder(), + Self::T5(ref generator) => generator.is_encoder_decoder(), + } + } } /// # Conversation model @@ -925,8 +931,6 @@ impl ConversationModel { let mut output = HashMap::with_capacity(active_uuid.len()); - println!("generated: {:#?}, prompt_ids: {:#?}", &generated, &prompt_ids); - for ( ((conversation, (generated_sequence, conversation_promp_ids)), uuid), removed_padding, @@ -936,7 +940,11 @@ impl ConversationModel { .zip(active_uuid.into_iter()) .zip(removed_padding_quantities.into_iter()) { - let generated_response = &generated_sequence[input_length - removed_padding.0..]; + let generated_response = if self.model.is_encoder_decoder() { + generated_sequence.as_slice() + } else { + &generated_sequence[input_length - removed_padding.0..] 
+ }; conversation .generated_responses .push( @@ -1044,9 +1052,14 @@ impl ConversationModel { .get(input_idx as i64) .slice(0, 0, (max_len - input.len()) as i64, 1) .fill_(0); - let mut padded_input = vec![pad_token; max_len - input.len()]; - padded_input.extend(input); - padded_input + let padding = vec![pad_token; max_len - input.len()]; + if self.model.is_encoder_decoder() { + // right padding assumed for encoder-decoders + [input, &padding].concat() + } else { + // left padding assumed for decoders + [&padding, input].concat() + } }) .map(|tokens| Tensor::of_slice(&tokens).to(self.device)) .collect::>(); From 971d4018fd46c24d2dcd3116f9fee74073541b71 Mon Sep 17 00:00:00 2001 From: Dario Cancelliere Date: Sun, 14 May 2023 14:13:05 +0200 Subject: [PATCH 4/5] Updated error message --- src/lib.rs | 3 ++- src/pipelines/conversation.rs | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 3308455c..fdc8a96b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -256,7 +256,8 @@ //!
//! 4. Dialogue Model //! -//! Conversation model based on Microsoft's [DialoGPT](https://github.com/microsoft/DialoGPT). +//! Conversation model based on Microsoft's [DialoGPT](https://github.com/microsoft/DialoGPT) or +//! [GODEL](https://github.com/microsoft/GODEL). //! This pipeline allows the generation of single or multi-turn conversations between a human and a model. //! The DialoGPT's page states that //! > The human evaluation results indicate that the response generated from DialoGPT is comparable to human response quality diff --git a/src/pipelines/conversation.rs b/src/pipelines/conversation.rs index 6e026cc2..f26d7c70 100644 --- a/src/pipelines/conversation.rs +++ b/src/pipelines/conversation.rs @@ -706,7 +706,7 @@ impl ConversationOption { ModelType::GPT2 => Ok(ConversationOption::GPT2(GPT2Generator::new(config.into())?)), ModelType::T5 => Ok(ConversationOption::T5(T5Generator::new(config.into())?)), _ => Err(RustBertError::InvalidConfigurationError( - "GPT2 is currently the only supported model for conversation generation" + "GPT-2 and T5 are currently the only supported models for conversation generation" .to_string(), )), } @@ -726,7 +726,7 @@ impl ConversationOption { tokenizer, )?)), _ => Err(RustBertError::InvalidConfigurationError( - "GPT2 is currently the only supported model for conversation generation" + "GPT-2 and T5 are currently the only supported models for conversation generation" .to_string(), )), } } From aca49b783f9160e5434ad85ecb920cb53dc01500 Mon Sep 17 00:00:00 2001 From: Dario Cancelliere Date: Sun, 14 May 2023 15:13:33 +0200 Subject: [PATCH 5/5] Added tests for GODEL T5 model --- src/t5/t5_model.rs | 10 ++ tests/gpt2.rs | 4 + tests/t5.rs | 230 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 244 insertions(+) diff --git a/src/t5/t5_model.rs b/src/t5/t5_model.rs index f93e210f..fdf8bd33 100644 --- a/src/t5/t5_model.rs +++ b/src/t5/t5_model.rs @@ -112,11 +112,21 @@ impl T5VocabResources { "t5-base/spiece", 
"https://huggingface.co/t5-base/resolve/main/spiece.model", ); + /// Shared under Apache 2.0 license by the Google team at . + pub const T5_LARGE: (&'static str, &'static str) = ( + "t5-large/spiece", + "https://huggingface.co/t5-large/resolve/main/spiece.model", + ); /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. pub const SENTENCE_T5_BASE: (&'static str, &'static str) = ( "sentence-t5-base/spiece", "https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/spiece.model", ); + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const SENTENCE_T5_LARGE: (&'static str, &'static str) = ( + "sentence-t5-large/spiece", + "https://huggingface.co/sentence-transformers/sentence-t5-large/resolve/main/spiece.model", + ); /// Shared under Apache 2.0 license by the Google team at . pub const GODEL_V1_1_BASE: (&'static str, &'static str) = ( "godel-v1-1-base/spiece", diff --git a/tests/gpt2.rs b/tests/gpt2.rs index b820348f..25d9e371 100644 --- a/tests/gpt2.rs +++ b/tests/gpt2.rs @@ -723,6 +723,7 @@ fn gpt2_beam_search_token_scores() -> anyhow::Result<()> { fn dialogpt_single_multi_turn_conversation() -> anyhow::Result<()> { // Set-up conversation model let conversation_config = ConversationConfig { + model_type: ModelType::GPT2, do_sample: false, device: Device::Cpu, ..Default::default() @@ -760,6 +761,7 @@ fn dialogpt_single_multi_turn_conversation() -> anyhow::Result<()> { fn dialogpt_multiple_multi_turn_conversation() -> anyhow::Result<()> { // Set-up conversation model let conversation_config = ConversationConfig { + model_type: ModelType::GPT2, do_sample: false, device: Device::Cpu, ..Default::default() @@ -802,6 +804,7 @@ fn dialogpt_multiple_multi_turn_conversation() -> anyhow::Result<()> { fn dialogpt_multiple_multi_turn_conversation_with_truncation() -> anyhow::Result<()> { // Set-up conversation model let conversation_config = ConversationConfig { + model_type: 
ModelType::GPT2, max_length: Some(36), min_length_for_response: 24, do_sample: false, @@ -851,6 +854,7 @@ fn dialogpt_multiple_multi_turn_conversation_with_truncation() -> anyhow::Result fn dialogpt_multiple_multi_turn_conversation_with_conversation_deletion() -> anyhow::Result<()> { // Set-up conversation model let conversation_config = ConversationConfig { + model_type: ModelType::GPT2, do_sample: false, device: Device::Cpu, ..Default::default() diff --git a/tests/t5.rs b/tests/t5.rs index 1cebe74d..cd33ae5f 100644 --- a/tests/t5.rs +++ b/tests/t5.rs @@ -1,4 +1,7 @@ use rust_bert::pipelines::common::ModelType; +use rust_bert::pipelines::conversation::{ + ConversationConfig, ConversationManager, ConversationModel, +}; use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel}; use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel}; use rust_bert::resources::RemoteResource; @@ -111,3 +114,230 @@ about exoplanets like K2-18b."]; Ok(()) } + +#[test] +#[cfg_attr(not(feature = "all-tests"), ignore)] +fn godel_single_multi_turn_conversation() -> anyhow::Result<()> { + // Set-up conversation model + let conversation_config = ConversationConfig { + model_type: ModelType::T5, + do_sample: false, + device: Device::Cpu, + model_resource: Box::new(RemoteResource::from_pretrained( + T5ModelResources::GODEL_V1_1_LARGE, + )), + config_resource: Box::new(RemoteResource::from_pretrained( + T5ConfigResources::GODEL_V1_1_LARGE, + )), + vocab_resource: Box::new(RemoteResource::from_pretrained( + T5VocabResources::GODEL_V1_1_LARGE, + )), + merges_resource: None, + ..Default::default() + }; + let conversation_model = ConversationModel::new(conversation_config)?; + + // Set-up conversation manager and add a conversation + let mut conversation_manager = ConversationManager::new(); + let conversation_id = + conversation_manager.create("Going to the movies tonight - any suggestions?"); + + // Turn 1 + let output = 
conversation_model.generate_responses(&mut conversation_manager); + assert_eq!(output.len(), 1); + assert_eq!(output.get(&conversation_id).unwrap(), &" I'd recommend The Last Airbender. It's a great comedy and a great movie if you like comedy."); + + // Turn 2 + let _ = conversation_manager + .get(&conversation_id) + .unwrap() + .add_user_input("Is it an action movie?"); + let output = conversation_model.generate_responses(&mut conversation_manager); + assert_eq!(output.len(), 1); + assert_eq!( + output.get(&conversation_id).unwrap(), + &" I'm not sure, but I've heard it's a great comedy." + ); + + // Turn 3 (no new user input) + let output = conversation_model.generate_responses(&mut conversation_manager); + assert_eq!(output.len(), 0); + + Ok(()) +} + +#[test] +#[cfg_attr(not(feature = "all-tests"), ignore)] +fn godel_multiple_multi_turn_conversation() -> anyhow::Result<()> { + // Set-up conversation model + let conversation_config = ConversationConfig { + model_type: ModelType::T5, + do_sample: false, + device: Device::Cpu, + model_resource: Box::new(RemoteResource::from_pretrained( + T5ModelResources::GODEL_V1_1_LARGE, + )), + config_resource: Box::new(RemoteResource::from_pretrained( + T5ConfigResources::GODEL_V1_1_LARGE, + )), + vocab_resource: Box::new(RemoteResource::from_pretrained( + T5VocabResources::GODEL_V1_1_LARGE, + )), + merges_resource: None, + ..Default::default() + }; + let conversation_model = ConversationModel::new(conversation_config)?; + + // Set-up conversation manager and add a conversation + let mut conversation_manager = ConversationManager::new(); + let conversation_1_id = + conversation_manager.create("Going to the movies tonight - any suggestions?"); + let conversation_2_id = conversation_manager.create("What's the last book you have read?"); + + // Turn 1 + let output = conversation_model.generate_responses(&mut conversation_manager); + assert_eq!(output.len(), 2); + assert_eq!(output.get(&conversation_1_id).unwrap(), &" I'd recommend 
The Last Airbender. It's a great comedy and a great movie if you like comedy."); + assert_eq!( + output.get(&conversation_2_id).unwrap(), + &" I read The Last of Us. It was a great book." + ); + + // Turn 2 + let _ = conversation_manager + .get(&conversation_1_id) + .unwrap() + .add_user_input("Is it an action movie?"); + let output = conversation_model.generate_responses(&mut conversation_manager); + assert_eq!(output.len(), 1); + assert_eq!( + output.get(&conversation_1_id).unwrap(), + &" I'm not sure, but I've heard it's a great comedy." + ); + + // Turn 3 (no new user input) + let output = conversation_model.generate_responses(&mut conversation_manager); + assert_eq!(output.len(), 0); + + Ok(()) +} + +#[test] +#[cfg_attr(not(feature = "all-tests"), ignore)] +fn godel_multiple_multi_turn_conversation_with_truncation() -> anyhow::Result<()> { + // Set-up conversation model + let conversation_config = ConversationConfig { + model_type: ModelType::T5, + max_length: Some(36), + min_length_for_response: 24, + do_sample: false, + device: Device::Cpu, + model_resource: Box::new(RemoteResource::from_pretrained( + T5ModelResources::GODEL_V1_1_LARGE, + )), + config_resource: Box::new(RemoteResource::from_pretrained( + T5ConfigResources::GODEL_V1_1_LARGE, + )), + vocab_resource: Box::new(RemoteResource::from_pretrained( + T5VocabResources::GODEL_V1_1_LARGE, + )), + merges_resource: None, + ..Default::default() + }; + let conversation_model = ConversationModel::new(conversation_config)?; + + // Set-up conversation manager and add a conversation + let mut conversation_manager = ConversationManager::new(); + let conversation_1_id = + conversation_manager.create("Going to the movies tonight - any suggestions?"); + let conversation_2_id = conversation_manager.create("Hello how are you today?"); + + // Turn 1 + let output = conversation_model.generate_responses(&mut conversation_manager); + assert_eq!(output.len(), 2); + assert_eq!(output.get(&conversation_1_id).unwrap(), &" I'd 
recommend The Last Airbender. It's a great comedy and a great movie if you like comedy."); + assert_eq!( + output.get(&conversation_2_id).unwrap(), + &" i am a little tired from work" + ); + + // Turn 2 + let _ = conversation_manager + .get(&conversation_1_id) + .unwrap() + .add_user_input("Is it an action movie?"); + let _ = conversation_manager + .get(&conversation_2_id) + .unwrap() + .add_user_input("Fine."); + + let output = conversation_model.generate_responses(&mut conversation_manager); + assert_eq!(output.len(), 2); + assert_eq!( + output.get(&conversation_1_id).unwrap(), + &" No, it's a comedy." + ); + + // Turn 3 (no new user input) + let output = conversation_model.generate_responses(&mut conversation_manager); + assert_eq!(output.len(), 0); + + Ok(()) +} + +#[test] +#[cfg_attr(not(feature = "all-tests"), ignore)] +fn godel_multiple_multi_turn_conversation_with_conversation_deletion() -> anyhow::Result<()> { + // Set-up conversation model + let conversation_config = ConversationConfig { + model_type: ModelType::T5, + do_sample: false, + device: Device::Cpu, + model_resource: Box::new(RemoteResource::from_pretrained( + T5ModelResources::GODEL_V1_1_LARGE, + )), + config_resource: Box::new(RemoteResource::from_pretrained( + T5ConfigResources::GODEL_V1_1_LARGE, + )), + vocab_resource: Box::new(RemoteResource::from_pretrained( + T5VocabResources::GODEL_V1_1_LARGE, + )), + merges_resource: None, + ..Default::default() + }; + let conversation_model = ConversationModel::new(conversation_config)?; + + // Set-up conversation manager and add a conversation + let mut conversation_manager = ConversationManager::new(); + let conversation_1_id = + conversation_manager.create("Going to the movies tonight - any suggestions?"); + let conversation_2_id = conversation_manager.create("What's the last book you have read?"); + + // Turn 1 + let output = conversation_model.generate_responses(&mut conversation_manager); + assert_eq!(output.len(), 2); + 
assert_eq!(output.get(&conversation_1_id).unwrap(), &" I'd recommend The Last Airbender. It's a great comedy and a great movie if you like comedy."); + assert_eq!( + output.get(&conversation_2_id).unwrap(), + &" I read The Last of Us. It was a great book." + ); + + // Turn 2 + let _ = conversation_manager.remove(&conversation_1_id); + let _ = conversation_manager + .get(&conversation_2_id) + .unwrap() + .add_user_input("Why do you recommend it?"); + let output = conversation_model.generate_responses(&mut conversation_manager); + assert_eq!(output.len(), 1); + assert_eq!( + output.get(&conversation_2_id).unwrap(), + &" I've read it, but I'm not sure if I'd like it again. I'm not a huge fan of the genre." + ); + + // Turn 3 (no new user input) + let output = conversation_model.generate_responses(&mut conversation_manager); + assert_eq!(output.len(), 0); + + Ok(()) +}