From d7d079e2b03367982c3bb50ddcdf2e648b1f45e2 Mon Sep 17 00:00:00 2001 From: Romain Francois Date: Sat, 20 Apr 2024 13:40:33 +0200 Subject: [PATCH] stream() also returns chat_tibble --- NAMESPACE | 1 + R/stream.R | 45 +++++++++++++++++++++++++++++++-------------- R/zzz.R | 2 +- 3 files changed, 33 insertions(+), 15 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 4b2dd9f..7ca249b 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -20,4 +20,5 @@ importFrom(purrr,map) importFrom(purrr,map2) importFrom(purrr,map_chr) importFrom(purrr,pluck) +importFrom(purrr,walk) importFrom(utils,tail) diff --git a/R/stream.R b/R/stream.R index bf011eb..69e79ad 100644 --- a/R/stream.R +++ b/R/stream.R @@ -5,25 +5,42 @@ stream <- function(..., model = "mistral-tiny", error_call = current_env()) { messages <- as_messages(..., error_call = error_call) req <- req_chat(messages, model, stream = TRUE, error_call = error_call) - resp <- req_perform_stream(req, callback = stream_callback, round = "line", buffer_kb = 0.01) - invisible(resp) + streamer <- mistral_stream() + resp <- req_perform_stream(req, callback = streamer$callback, round = "line", buffer_kb = 0.01) + + tbl_req <- list_rbind(map(messages, as_tibble)) + tbl_resp <- tibble( + role = "assistant", + content = paste0(streamer$tokens, collapse = "") + ) + tbl <- list_rbind(list(tbl_req, tbl_resp)) + + class(tbl) <- c("stream_tibble", "chat_tibble", class(tbl)) + attr(tbl, "resp") <- resp + invisible(tbl) } -stream_callback <- function(x) { - txt <- rawToChar(x) +mistral_stream <- function() { + tokens <- list() + + callback <- function(x) { + txt <- rawToChar(x) - lines <- str_split(txt, "\n")[[1]] - lines <- lines[lines != ""] - lines <- str_replace_all(lines, "^data: ", "") - lines <- lines[lines != "[DONE]"] + lines <- str_split(txt, "\n")[[1]] + lines <- lines[lines != ""] + lines <- str_replace_all(lines, "^data: ", "") + lines <- lines[lines != "[DONE]"] - tokens <- map_chr(lines, \(line) { - chunk <- fromJSON(line) - chunk$choices$delta$content - }) + tok <- map_chr(lines, \(line) { + json <- fromJSON(line) + json$choices$delta$content + }) + tokens <<- c(tokens, tok) + cat(tok) - cat(tokens) + TRUE + } - TRUE + environment() } diff --git a/R/zzz.R b/R/zzz.R index aa92f2d..73cac6c 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -4,7 +4,7 @@ #' @import tibble #' @import stringr #' @import slap -#' @importFrom purrr list_rbind map map_chr pluck map2 list_flatten +#' @importFrom purrr list_rbind map map_chr pluck map2 list_flatten walk #' @importFrom jsonlite fromJSON #' @importFrom utils tail NULL