Skip to content

Commit

Permalink
refactor: do not refresh documents when .edit rag-docs (#1032)
Browse files Browse the repository at this point in the history
  • Loading branch information
sigoden authored Dec 3, 2024
1 parent 46348df commit e86a9f3
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 109 deletions.
2 changes: 1 addition & 1 deletion src/config/input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ async fn load_documents(
let mut medias = vec![];
let mut data_urls = HashMap::new();
let loaders = config.read().document_loaders.clone();
let local_files = expand_glob_paths(&local_paths).await?;
let local_files = expand_glob_paths(&local_paths, true).await?;
for file_path in local_files {
if is_image(&file_path) {
let data_url = read_media_to_data_url(&file_path)
Expand Down
4 changes: 2 additions & 2 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1356,7 +1356,7 @@ impl Config {
if new_document_paths.is_empty() || new_document_paths == document_paths {
bail!("No changes")
}
rag.refresh_document_paths(&new_document_paths, config, abort_signal)
rag.refresh_document_paths(&new_document_paths, false, config, abort_signal)
.await?;
config.write().rag = Some(Arc::new(rag));
Ok(())
Expand All @@ -1368,7 +1368,7 @@ impl Config {
None => bail!("No RAG"),
};
let document_paths = rag.document_paths().to_vec();
rag.refresh_document_paths(&document_paths, config, abort_signal)
rag.refresh_document_paths(&document_paths, true, config, abort_signal)
.await?;
config.write().rag = Some(Arc::new(rag));
Ok(())
Expand Down
69 changes: 0 additions & 69 deletions src/rag/loader.rs
Original file line number Diff line number Diff line change
@@ -1,53 +1,11 @@
use super::*;

use anyhow::{Context, Result};
use path_absolutize::Absolutize;
use std::collections::HashMap;

pub const EXTENSION_METADATA: &str = "__extension__";
pub const PATH_METADATA: &str = "__path__";

/// Load a single RAG document source, which may be a URL, a recursive URL
/// (trailing `**`), or a local filesystem path.
///
/// Returns the (possibly absolutized) source path together with the loaded
/// `(content, metadata)` pairs. A load failure is reported as a warning,
/// flips `has_error`, and yields an empty file list instead of aborting.
pub async fn load_document(
    loaders: &HashMap<String, String>,
    path: &str,
    has_error: &mut bool,
) -> (String, Vec<(String, RagMetadata)>) {
    let mut resolved = path.to_string();
    let mut files = vec![];
    // Funnel both branches' failures into one Option so reporting happens once.
    let failure = if is_url(&resolved) {
        match resolved.strip_suffix("**") {
            // A trailing `**` means "crawl every page under this start URL".
            Some(start_url) => load_recursive_url(loaders, start_url)
                .await
                .map(|v| files.extend(v))
                .err(),
            None => load_url(loaders, &resolved)
                .await
                .map(|v| files.push(v))
                .err(),
        }
    } else if let Ok(abs) = Path::new(&resolved).absolutize() {
        // Normalize local paths so the returned path is absolute.
        resolved = abs.display().to_string();
        load_path(loaders, &resolved, has_error)
            .await
            .map(|v| files.extend(v))
            .err()
    } else {
        Some(anyhow!("Invalid path"))
    };
    if let Some(err) = failure {
        *has_error = true;
        println!("{}", warning_text(&format!("⚠️ {err:?}")));
    }
    (resolved, files)
}

pub async fn load_recursive_url(
loaders: &HashMap<String, String>,
path: &str,
Expand Down Expand Up @@ -76,33 +34,6 @@ pub async fn load_recursive_url(
Ok(output)
}

/// Expand a local path (which may contain glob patterns) and load every
/// matching file.
///
/// When the path expands to exactly one file, its load error propagates to
/// the caller; with multiple matches, per-file failures only warn and set
/// `has_error`, so a single bad file does not abort the bulk load.
pub async fn load_path(
    loaders: &HashMap<String, String>,
    path: &str,
    has_error: &mut bool,
) -> Result<Vec<(String, RagMetadata)>> {
    let expanded = expand_glob_paths(&[path]).await?;
    let mut loaded = vec![];
    if expanded.len() == 1 {
        // Single match: a failure here is fatal for this path.
        loaded.push(load_file(loaders, &expanded[0]).await?);
    } else {
        // Zero matches: the loop simply never runs. Many matches: best-effort.
        for path in expanded {
            println!("Load {path}");
            match load_file(loaders, &path).await {
                Ok(v) => loaded.push(v),
                Err(err) => {
                    *has_error = true;
                    println!("{}", warning_text(&format!("Error: {err:?}")));
                }
            }
        }
    }
    Ok(loaded)
}

pub async fn load_file(
loaders: &HashMap<String, String>,
path: &str,
Expand Down
155 changes: 126 additions & 29 deletions src/rag/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use hnsw_rs::prelude::*;
use indexmap::{IndexMap, IndexSet};
use inquire::{required, validator::Validation, Confirm, Select, Text};
use parking_lot::RwLock;
use path_absolutize::Absolutize;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::{collections::HashMap, env, fmt::Debug, fs, hash::Hash, path::Path, time::Duration};
Expand Down Expand Up @@ -90,7 +91,7 @@ impl Rag {
let loaders = config.read().document_loaders.clone();
let (spinner, spinner_rx) = Spinner::create("");
abortable_run_with_spinner_rx(
rag.sync_documents(loaders, &paths, Some(spinner)),
rag.sync_documents(&paths, true, loaders, Some(spinner)),
spinner_rx,
abort_signal,
)
Expand Down Expand Up @@ -129,19 +130,17 @@ impl Rag {
&self.data.document_paths
}

pub async fn refresh_document_paths<T>(
pub async fn refresh_document_paths(
&mut self,
document_paths: &[T],
document_paths: &[String],
refresh: bool,
config: &GlobalConfig,
abort_signal: AbortSignal,
) -> Result<()>
where
T: AsRef<str>,
{
) -> Result<()> {
let loaders = config.read().document_loaders.clone();
let (spinner, spinner_rx) = Spinner::create("");
abortable_run_with_spinner_rx(
self.sync_documents(loaders, document_paths, Some(spinner)),
self.sync_documents(document_paths, refresh, loaders, Some(spinner)),
spinner_rx,
abort_signal,
)
Expand Down Expand Up @@ -320,31 +319,90 @@ impl Rag {
Ok((embeddings, ids))
}

pub async fn sync_documents<T: AsRef<str>>(
pub async fn sync_documents(
&mut self,
paths: &[String],
refresh: bool,
loaders: HashMap<String, String>,
paths: &[T],
spinner: Option<Spinner>,
) -> Result<()> {
if let Some(spinner) = &spinner {
let _ = spinner.set_message(String::new());
}
let (document_paths, mut recursive_urls, mut urls, mut local_paths) =
resolve_paths(paths).await?;
let mut to_deleted: IndexMap<String, Vec<FileId>> = Default::default();
if refresh {
for (file_id, file) in &self.data.files {
to_deleted
.entry(file.hash.clone())
.or_default()
.push(*file_id);
}
} else {
let recursive_urls_cloned = recursive_urls.clone();
let match_recursive_url = |v: &str| {
recursive_urls_cloned
.iter()
.any(|start_url| v.starts_with(start_url))
};
recursive_urls = recursive_urls
.into_iter()
.filter(|v| !self.data.document_paths.contains(&format!("{v}**")))
.collect();
for (file_id, file) in &self.data.files {
if is_url(&file.path) {
if !urls.swap_remove(&file.path) && !match_recursive_url(&file.path) {
to_deleted
.entry(file.hash.clone())
.or_default()
.push(*file_id);
}
} else if !local_paths.swap_remove(&file.path) {
to_deleted
.entry(file.hash.clone())
.or_default()
.push(*file_id);
}
}
}

let mut document_paths = vec![];
let mut files = vec![];
let paths_len = paths.len();
let mut has_error = false;
for (index, path) in paths.iter().enumerate() {
let path = path.as_ref();
println!("Load {path} [{}/{paths_len}]", index + 1);
let (path, document_files) = load_document(&loaders, path, &mut has_error).await;
files.extend(document_files);
document_paths.push(path);
let mut index = 0;
let total = recursive_urls.len() + urls.len() + local_paths.len();
let handle_error = |error: anyhow::Error, has_error: &mut bool| {
println!("{}", warning_text(&format!("⚠️ {error}")));
*has_error = true;
};
for start_url in recursive_urls {
index += 1;
println!("Load {start_url}** [{index}/{total}]");
match load_recursive_url(&loaders, &start_url).await {
Ok(v) => files.extend(v),
Err(err) => handle_error(err, &mut has_error),
}
}
for url in urls {
index += 1;
println!("Load {url} [{index}/{total}]");
match load_url(&loaders, &url).await {
Ok(v) => files.push(v),
Err(err) => handle_error(err, &mut has_error),
}
}
for local_path in local_paths {
index += 1;
println!("Load {local_path} [{index}/{total}]");
match load_file(&loaders, &local_path).await {
Ok(v) => files.push(v),
Err(err) => handle_error(err, &mut has_error),
}
}

if has_error {
let mut aborted = true;
if *IS_STDOUT_TERMINAL && !document_paths.is_empty() {
if *IS_STDOUT_TERMINAL && total > 0 {
let ans = Confirm::new("Some documents failed to load. Continue?")
.with_default(false)
.prompt()?;
Expand All @@ -355,21 +413,24 @@ impl Rag {
}
}

let mut to_deleted: IndexMap<String, FileId> = Default::default();
for (file_id, file) in &self.data.files {
to_deleted.insert(file.hash.clone(), *file_id);
}

let mut rag_files = vec![];
for (contents, mut metadata) in files {
let path = match metadata.swap_remove(PATH_METADATA) {
Some(v) => v,
None => continue,
};
let hash = sha256(&contents);
if let Some(file_id) = to_deleted.get(&hash) {
if self.data.files[file_id].path == path {
to_deleted.swap_remove(&hash);
if let Some(file_ids) = to_deleted.get_mut(&hash) {
if let Some((i, _)) = file_ids
.iter()
.enumerate()
.find(|(_, v)| self.data.files[*v].path == path)
{
if file_ids.len() == 1 {
to_deleted.swap_remove(&hash);
} else {
file_ids.remove(i);
}
continue;
}
}
Expand Down Expand Up @@ -415,9 +476,10 @@ impl Rag {
.await?;
}

self.data.del(to_deleted.values().cloned().collect());
let to_delete_file_ids: Vec<_> = to_deleted.values().flatten().copied().collect();
self.data.del(to_delete_file_ids);
self.data.add(next_file_id, files, document_ids, embeddings);
self.data.document_paths = document_paths;
self.data.document_paths = document_paths.into_iter().collect();

if self.data.files.is_empty() {
bail!("No RAG files");
Expand Down Expand Up @@ -845,6 +907,41 @@ fn add_documents() -> Result<Vec<String>> {
Ok(paths)
}

/// Classify raw document paths into URL / recursive-URL / local buckets.
///
/// Returns `(document_paths, recursive_urls, urls, local_paths)`:
/// - `document_paths`: every input, trimmed, with local paths absolutized
/// - `recursive_urls`: URLs that ended in `**`, with the suffix stripped
/// - `urls`: plain (non-recursive) URLs
/// - `local_paths`: the local inputs after glob expansion
///
/// Fails if a local path cannot be absolutized or glob expansion errors.
async fn resolve_paths<T: AsRef<str>>(
    paths: &[T],
) -> Result<(
    IndexSet<String>,
    IndexSet<String>,
    IndexSet<String>,
    IndexSet<String>,
)> {
    let mut document_paths: IndexSet<String> = IndexSet::new();
    let mut recursive_urls: IndexSet<String> = IndexSet::new();
    let mut urls: IndexSet<String> = IndexSet::new();
    let mut absolute_paths: Vec<String> = Vec::with_capacity(paths.len());
    for raw in paths {
        let path = raw.as_ref().trim();
        if !is_url(path) {
            // Local path: absolutize before storing so later comparisons
            // against indexed file paths are exact.
            let absolute_path = Path::new(path)
                .absolutize()
                .with_context(|| format!("Invalid path '{path}'"))?
                .display()
                .to_string();
            absolute_paths.push(absolute_path.clone());
            document_paths.insert(absolute_path);
            continue;
        }
        match path.strip_suffix("**") {
            // `https://site/**` requests a recursive crawl from the start URL.
            Some(start_url) => {
                recursive_urls.insert(start_url.to_string());
            }
            None => {
                urls.insert(path.to_string());
            }
        }
        // The document list keeps the original trimmed form, `**` included.
        document_paths.insert(path.to_string());
    }
    let local_paths = expand_glob_paths(&absolute_paths, false).await?;
    Ok((document_paths, recursive_urls, urls, local_paths))
}

fn progress(spinner: &Option<Spinner>, message: String) {
if let Some(spinner) = spinner {
let _ = spinner.set_message(message);
Expand Down
Loading

0 comments on commit e86a9f3

Please sign in to comment.