Skip to content

Commit

Permalink
feat[ai]: Add crud operations for a conversations table
Browse files Browse the repository at this point in the history
Signed-off-by: Hiram Chirino <[email protected]>
  • Loading branch information
chirino committed Jan 3, 2025
1 parent ff1e5a8 commit b987e9e
Show file tree
Hide file tree
Showing 6 changed files with 274 additions and 1 deletion.
19 changes: 19 additions & 0 deletions entity/src/conversation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
use sea_orm::entity::prelude::*;
use time::OffsetDateTime;

#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
#[sea_orm(table_name = "conversation")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: Uuid,
pub user_id: String,
pub state: serde_json::Value,
pub seq: i32,
pub summary: String,
pub updated_at: OffsetDateTime,
}

#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}

impl ActiveModelBehavior for ActiveModel {}
1 change: 1 addition & 0 deletions entity/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
pub mod advisory;
pub mod advisory_vulnerability;
pub mod base_purl;
pub mod conversation;
pub mod cpe;
pub mod cpe_license_assertion;
pub mod cvss3;
Expand Down
2 changes: 2 additions & 0 deletions migration/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ mod m0000780_alter_source_document_time;
mod m0000790_alter_sbom_alter_document_id;
mod m0000800_alter_product_version_range_scheme;
mod m0000810_fix_get_purl;
mod m0000820_create_conversation;

pub struct Migrator;

Expand Down Expand Up @@ -203,6 +204,7 @@ impl MigratorTrait for Migrator {
Box::new(m0000790_alter_sbom_alter_document_id::Migration),
Box::new(m0000800_alter_product_version_range_scheme::Migration),
Box::new(m0000810_fix_get_purl::Migration),
Box::new(m0000820_create_conversation::Migration),
]
}
}
Expand Down
76 changes: 76 additions & 0 deletions migration/src/m0000820_create_conversation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
use crate::UuidV4;
use sea_orm_migration::prelude::*;

#[derive(DeriveMigrationName)]
pub struct Migration;

#[async_trait::async_trait]
impl MigrationTrait for Migration {
async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> {
manager
.create_table(
Table::create()
.table(Conversation::Table)
.col(
ColumnDef::new(Conversation::Id)
.uuid()
.not_null()
.default(Func::cust(UuidV4))
.primary_key(),
)
.col(ColumnDef::new(Conversation::UserId).string().not_null())
.col(ColumnDef::new(Conversation::State).json_binary().not_null())
.col(ColumnDef::new(Conversation::Seq).integer().not_null())
.col(ColumnDef::new(Conversation::Summary).string().not_null())
.col(
ColumnDef::new(Conversation::UpdatedAt)
.timestamp_with_time_zone()
.not_null(),
)
.to_owned(),
)
.await?;

// this index should speed up lookup up the most recent conversations for a user
manager
.create_index(
Index::create()
.table(Conversation::Table)
.name(Conversation::ConverstationUserIdUpdatedAtIdx.to_string())
.col(Conversation::UserId)
.col(Conversation::UpdatedAt)
.to_owned(),
)
.await?;

Ok(())
}

async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> {
manager
.drop_index(
Index::drop()
.if_exists()
.table(Conversation::Table)
.name(Conversation::ConverstationUserIdUpdatedAtIdx.to_string())
.to_owned(),
)
.await?;

manager
.drop_table(Table::drop().table(Conversation::Table).to_owned())
.await
}
}

#[derive(DeriveIden)]
enum Conversation {
Table,
Id,
UserId,
ConverstationUserIdUpdatedAtIdx,
State,
Seq,
Summary,
UpdatedAt,
}
94 changes: 93 additions & 1 deletion modules/fundamental/src/ai/service/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,21 @@ use langchain_rust::{
prompt_args,
tools::Tool,
};
use sea_orm::ConnectionTrait;
use sea_orm::{prelude::Uuid, ColumnTrait, EntityTrait, QueryFilter, QueryOrder};
use sea_orm::{ActiveModelTrait, ConnectionTrait, Set};

use serde_json::Value;
use std::env;
use std::sync::Arc;
use time::OffsetDateTime;
use tokio::sync::OnceCell;

use trustify_common::db::{limiter::LimiterTrait, query::Filtering};

use trustify_common::db::query::q;
use trustify_common::db::Database;
use trustify_common::model::{Paginated, PaginatedResults};
use trustify_entity::conversation;

pub const PREFIX: &str = include_str!("prefix.txt");

Expand Down Expand Up @@ -284,6 +294,88 @@ impl AiService {

Ok(response)
}

pub async fn create_conversation<C: ConnectionTrait>(
&self,
user_id: String,
state: Value,
summary: String,
connection: &C,
) -> Result<conversation::Model, Error> {
let model = conversation::ActiveModel {
id: Default::default(),
user_id: Set(user_id),
state: Set(state),
seq: Set(0),
summary: Set(summary),
updated_at: Set(OffsetDateTime::now_utc()),
};
Ok(model.insert(connection).await?)
}

pub async fn update_conversation<C: ConnectionTrait>(
&self,
conversation_id: Uuid,
state: Value,
summary: String,
seq: i32,
connection: &C,
) -> Result<conversation::Model, Error> {
let model = conversation::ActiveModel {
id: Set(conversation_id),
state: Set(state),
summary: Set(summary),
seq: Set(seq),
updated_at: Set(OffsetDateTime::now_utc()),
..Default::default()
};

let result = conversation::Entity::update(model)
.filter(conversation::Column::Seq.lte(seq))
.exec(connection)
.await?;

Ok(result)
}

pub async fn fetch_conversation<C: ConnectionTrait>(
&self,
id: Uuid,
connection: &C,
) -> Result<Option<conversation::Model>, Error> {
let select = conversation::Entity::find().filter(conversation::Column::Id.eq(id));

Ok(select.one(connection).await?)
}

pub async fn fetch_conversations<C: ConnectionTrait + Sync + Send>(
&self,
user_id: String,
paginated: Paginated,
connection: &C,
) -> Result<PaginatedResults<conversation::Model>, Error> {
let limiter = conversation::Entity::find()
.order_by_desc(conversation::Column::UpdatedAt)
.filtering(q(format!("user_id={}", user_id).as_str()))?
.limiting(connection, paginated.offset, paginated.limit);

let total = limiter.total().await?;

Ok(PaginatedResults {
total,
items: limiter.fetch().await?,
})
}

pub async fn delete_conversation<C: ConnectionTrait>(
&self,
id: Uuid,
connection: &C,
) -> Result<u64, Error> {
let query = conversation::Entity::delete_by_id(id);
let result = query.exec(connection).await?;
Ok(result.rows_affected)
}
}

#[cfg(test)]
Expand Down
83 changes: 83 additions & 0 deletions modules/fundamental/src/ai/service/test.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
use crate::ai::model::ChatState;
use crate::ai::service::AiService;
use serde_json::json;

use test_context::test_context;
use test_log::test;

use trustify_common::hashing::Digests;
use trustify_common::model::Paginated;
use trustify_module_ingestor::graph::product::ProductInformation;
use trustify_test_context::TrustifyContext;

Expand Down Expand Up @@ -178,3 +181,83 @@ async fn test_completions_advisory_info(ctx: &TrustifyContext) -> Result<(), any

Ok(())
}

#[test_context(TrustifyContext)]
#[test(actix_web::test)]
async fn conversation_crud(ctx: &TrustifyContext) -> Result<(), anyhow::Error> {
let service = AiService::new(ctx.db.clone());

// create a conversation
let value1 = json!({"test":"value1"});
let conversation = service
.create_conversation("user_a".into(), value1.clone(), "summary".into(), &ctx.db)
.await?;

assert_eq!("user_a", conversation.user_id);
assert_eq!(value1, conversation.state);
assert_eq!("summary", conversation.summary);
assert_eq!(0i32, conversation.seq);
let conversation_id = conversation.id;

// get the created conversation
let fetched = service.fetch_conversation(conversation_id, &ctx.db).await?;

assert_eq!(Some(conversation.clone()), fetched);

// list the conversations of the user
let converstations = service
.fetch_conversations(
"user_a".into(),
Paginated {
offset: 0,
limit: 10,
},
&ctx.db,
)
.await?;

assert_eq!(1, converstations.total);
assert_eq!(1, converstations.items.len());
assert_eq!(conversation, converstations.items[0]);

let value2 = json!({"test":"value2"});
service
.update_conversation(
conversation_id,
value2.clone(),
"summary2".into(),
1,
&ctx.db,
)
.await?;

// get the updated conversation
let fetched = service.fetch_conversation(conversation_id, &ctx.db).await?;

assert_eq!(value2, fetched.unwrap().state);

// verify that the update fails due to old seq
service
.update_conversation(
conversation_id,
json!({"test":"bad"}),
"summary2".into(),
0,
&ctx.db,
)
.await
.expect_err("should fail due to old seq");

// delete the conversation
let delete_count = service
.delete_conversation(conversation_id, &ctx.db)
.await?;
assert_eq!(delete_count, 1u64);

// get the deleted conversation
let fetched = service.fetch_conversation(conversation_id, &ctx.db).await?;

assert_eq!(None, fetched);

Ok(())
}

0 comments on commit b987e9e

Please sign in to comment.