diff --git a/entity/src/conversation.rs b/entity/src/conversation.rs new file mode 100644 index 000000000..b9311a77b --- /dev/null +++ b/entity/src/conversation.rs @@ -0,0 +1,19 @@ +use sea_orm::entity::prelude::*; +use time::OffsetDateTime; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "conversation")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: Uuid, + pub user_id: String, + pub state: serde_json::Value, + pub seq: i32, + pub summary: String, + pub updated_at: OffsetDateTime, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/entity/src/lib.rs b/entity/src/lib.rs index d0b671f7e..2223e5a9a 100644 --- a/entity/src/lib.rs +++ b/entity/src/lib.rs @@ -1,6 +1,7 @@ pub mod advisory; pub mod advisory_vulnerability; pub mod base_purl; +pub mod conversation; pub mod cpe; pub mod cpe_license_assertion; pub mod cvss3; diff --git a/migration/src/lib.rs b/migration/src/lib.rs index 47f55e4c1..24a223ed2 100644 --- a/migration/src/lib.rs +++ b/migration/src/lib.rs @@ -98,6 +98,7 @@ mod m0000780_alter_source_document_time; mod m0000790_alter_sbom_alter_document_id; mod m0000800_alter_product_version_range_scheme; mod m0000810_fix_get_purl; +mod m0000820_create_conversation; pub struct Migrator; @@ -203,6 +204,7 @@ impl MigratorTrait for Migrator { Box::new(m0000790_alter_sbom_alter_document_id::Migration), Box::new(m0000800_alter_product_version_range_scheme::Migration), Box::new(m0000810_fix_get_purl::Migration), + Box::new(m0000820_create_conversation::Migration), ] } } diff --git a/migration/src/m0000820_create_conversation.rs b/migration/src/m0000820_create_conversation.rs new file mode 100644 index 000000000..2d0506e34 --- /dev/null +++ b/migration/src/m0000820_create_conversation.rs @@ -0,0 +1,76 @@ +use crate::UuidV4; +use sea_orm_migration::prelude::*; + +#[derive(DeriveMigrationName)] +pub struct Migration; + +#[async_trait::async_trait] +impl MigrationTrait for Migration { + async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> { + manager + .create_table( + Table::create() + .table(Conversation::Table) + .col( + ColumnDef::new(Conversation::Id) + .uuid() + .not_null() + .default(Func::cust(UuidV4)) + .primary_key(), + ) + .col(ColumnDef::new(Conversation::UserId).string().not_null()) + .col(ColumnDef::new(Conversation::State).json_binary().not_null()) + .col(ColumnDef::new(Conversation::Seq).integer().not_null()) + .col(ColumnDef::new(Conversation::Summary).string().not_null()) + .col( + ColumnDef::new(Conversation::UpdatedAt) + .timestamp_with_time_zone() + .not_null(), + ) + .to_owned(), + ) + .await?; + + // this index should speed up lookup up the most recent conversations for a user + manager + .create_index( + Index::create() + .table(Conversation::Table) + .name(Conversation::ConverstationUserIdUpdatedAtIdx.to_string()) + .col(Conversation::UserId) + .col(Conversation::UpdatedAt) + .to_owned(), + ) + .await?; + + Ok(()) + } + + async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> { + manager + .drop_index( + Index::drop() + .if_exists() + .table(Conversation::Table) + .name(Conversation::ConverstationUserIdUpdatedAtIdx.to_string()) + .to_owned(), + ) + .await?; + + manager + .drop_table(Table::drop().table(Conversation::Table).to_owned()) + .await + } +} + +#[derive(DeriveIden)] +enum Conversation { + Table, + Id, + UserId, + ConverstationUserIdUpdatedAtIdx, + State, + Seq, + Summary, + UpdatedAt, +} diff --git a/modules/fundamental/src/ai/service/mod.rs b/modules/fundamental/src/ai/service/mod.rs index 235a943b6..728f52e8c 100644 --- a/modules/fundamental/src/ai/service/mod.rs +++ b/modules/fundamental/src/ai/service/mod.rs @@ -19,11 +19,21 @@ use langchain_rust::{ prompt_args, tools::Tool, }; -use sea_orm::ConnectionTrait; +use sea_orm::{prelude::Uuid, ColumnTrait, EntityTrait, QueryFilter, QueryOrder}; +use sea_orm::{ActiveModelTrait, ConnectionTrait, Set}; + +use serde_json::Value; use std::env; use std::sync::Arc; +use time::OffsetDateTime; use tokio::sync::OnceCell; + +use trustify_common::db::{limiter::LimiterTrait, query::Filtering}; + +use trustify_common::db::query::q; use trustify_common::db::Database; +use trustify_common::model::{Paginated, PaginatedResults}; +use trustify_entity::conversation; pub const PREFIX: &str = include_str!("prefix.txt"); @@ -284,6 +294,88 @@ impl AiService { Ok(response) } + + pub async fn create_conversation( + &self, + user_id: String, + state: Value, + summary: String, + connection: &C, + ) -> Result { + let model = conversation::ActiveModel { + id: Default::default(), + user_id: Set(user_id), + state: Set(state), + seq: Set(0), + summary: Set(summary), + updated_at: Set(OffsetDateTime::now_utc()), + }; + Ok(model.insert(connection).await?) + } + + pub async fn update_conversation( + &self, + conversation_id: Uuid, + state: Value, + summary: String, + seq: i32, + connection: &C, + ) -> Result { + let model = conversation::ActiveModel { + id: Set(conversation_id), + state: Set(state), + summary: Set(summary), + seq: Set(seq), + updated_at: Set(OffsetDateTime::now_utc()), + ..Default::default() + }; + + let result = conversation::Entity::update(model) + .filter(conversation::Column::Seq.lte(seq)) + .exec(connection) + .await?; + + Ok(result) + } + + pub async fn fetch_conversation( + &self, + id: Uuid, + connection: &C, + ) -> Result, Error> { + let select = conversation::Entity::find().filter(conversation::Column::Id.eq(id)); + + Ok(select.one(connection).await?) + } + + pub async fn fetch_conversations( + &self, + user_id: String, + paginated: Paginated, + connection: &C, + ) -> Result, Error> { + let limiter = conversation::Entity::find() + .order_by_desc(conversation::Column::UpdatedAt) + .filtering(q(format!("user_id={}", user_id).as_str()))? + .limiting(connection, paginated.offset, paginated.limit); + + let total = limiter.total().await?; + + Ok(PaginatedResults { + total, + items: limiter.fetch().await?, + }) + } + + pub async fn delete_conversation( + &self, + id: Uuid, + connection: &C, + ) -> Result { + let query = conversation::Entity::delete_by_id(id); + let result = query.exec(connection).await?; + Ok(result.rows_affected) + } } #[cfg(test)] diff --git a/modules/fundamental/src/ai/service/test.rs b/modules/fundamental/src/ai/service/test.rs index 0f904426e..fb3a34655 100644 --- a/modules/fundamental/src/ai/service/test.rs +++ b/modules/fundamental/src/ai/service/test.rs @@ -1,9 +1,12 @@ use crate::ai::model::ChatState; use crate::ai::service::AiService; +use serde_json::json; use test_context::test_context; use test_log::test; + use trustify_common::hashing::Digests; +use trustify_common::model::Paginated; use trustify_module_ingestor::graph::product::ProductInformation; use trustify_test_context::TrustifyContext; @@ -178,3 +181,83 @@ async fn test_completions_advisory_info(ctx: &TrustifyContext) -> Result<(), any Ok(()) } + +#[test_context(TrustifyContext)] +#[test(actix_web::test)] +async fn conversation_crud(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { + let service = AiService::new(ctx.db.clone()); + + // create a conversation + let value1 = json!({"test":"value1"}); + let conversation = service + .create_conversation("user_a".into(), value1.clone(), "summary".into(), &ctx.db) + .await?; + + assert_eq!("user_a", conversation.user_id); + assert_eq!(value1, conversation.state); + assert_eq!("summary", conversation.summary); + assert_eq!(0i32, conversation.seq); + let conversation_id = conversation.id; + + // get the created conversation + let fetched = service.fetch_conversation(conversation_id, &ctx.db).await?; + + assert_eq!(Some(conversation.clone()), fetched); + + // list the conversations of the user + let converstations = service + .fetch_conversations( + "user_a".into(), + Paginated { + offset: 0, + limit: 10, + }, + &ctx.db, + ) + .await?; + + assert_eq!(1, converstations.total); + assert_eq!(1, converstations.items.len()); + assert_eq!(conversation, converstations.items[0]); + + let value2 = json!({"test":"value2"}); + service + .update_conversation( + conversation_id, + value2.clone(), + "summary2".into(), + 1, + &ctx.db, + ) + .await?; + + // get the updated conversation + let fetched = service.fetch_conversation(conversation_id, &ctx.db).await?; + + assert_eq!(value2, fetched.unwrap().state); + + // verify that the update fails due to old seq + service + .update_conversation( + conversation_id, + json!({"test":"bad"}), + "summary2".into(), + 0, + &ctx.db, + ) + .await + .expect_err("should fail due to old seq"); + + // delete the conversation + let delete_count = service + .delete_conversation(conversation_id, &ctx.db) + .await?; + assert_eq!(delete_count, 1u64); + + // get the deleted conversation + let fetched = service.fetch_conversation(conversation_id, &ctx.db).await?; + + assert_eq!(None, fetched); + + Ok(()) +}