Skip to content

Commit

Permalink
Consolidate send analytics request methods
Browse files Browse the repository at this point in the history
  • Loading branch information
Westwooo committed Nov 1, 2024
1 parent 8503280 commit 5d9a1b6
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 124 deletions.
133 changes: 50 additions & 83 deletions src/cli/analytics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ use crate::cli::util::{
cluster_identifiers_from, convert_json_value_to_nu_value, convert_row_to_nu_value,
duration_to_golang_string, get_active_cluster,
};
use crate::client::{AnalyticsQueryRequest, HTTPClient, HttpResponse};
use crate::client::{http_handler::read_stream, AnalyticsQueryRequest};
use crate::state::State;
use crate::RemoteCluster;
use bytes::Bytes;
use futures::StreamExt;
use futures_core::Stream;
use log::debug;
use nu_engine::CallExt;
use nu_protocol::ast::Call;
Expand All @@ -21,7 +23,6 @@ use std::ops::Add;
use std::str::from_utf8;
use std::sync::atomic::AtomicBool;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::runtime::Runtime;
use tokio::time::Instant;
use tokio_stream::StreamMap;
Expand Down Expand Up @@ -138,44 +139,43 @@ fn run(
let statement: String = call.req(engine_state, stack, 0)?;

let scope: Option<String> = call.get_flag(engine_state, stack, "scope")?;
let _with_meta = call.has_flag(engine_state, stack, "with-meta")?;

debug!("Running analytics query {}", &statement);

let mut streams = StreamMap::new();
let rt = Runtime::new().unwrap();
let rt = Arc::new(Runtime::new().unwrap());
for identifier in cluster_identifiers.clone() {
let active_cluster = get_active_cluster(identifier.clone(), &guard, span)?;
let bucket = call
.get_flag(engine_state, stack, "bucket")?
.or_else(|| active_cluster.active_bucket());
let maybe_scope = bucket.and_then(|b| scope.clone().map(|s| (b, s)));

// Needs doing outside the 'block_on' as trying dns resolution also blocks, causing a
// panic
let client = active_cluster.cluster().http_client();
let timeout = active_cluster.timeouts().analytics_timeout();

let stream = rt.block_on(async {
send_analytics_query_stream(
client,
timeout,
maybe_scope,
statement.clone(),
ctrl_c.clone(),
span,
)
.await
})?;
let stream = send_analytics_query(
active_cluster,
maybe_scope,
statement.clone(),
ctrl_c.clone(),
span,
rt.clone(),
)?;

streams.insert(identifier, stream);
let json_stream = JsonRowStream::new(stream);
let mut json_streamer = RawJsonRowStreamer::new(json_stream, "".to_string());

rt.block_on(async {
// Read prelude and signature
json_streamer.read_prelude().await.unwrap();
json_streamer.read_row().await.unwrap();

// Read row containing `results: [`
json_streamer.read_row().await.unwrap();
});

streams.insert(identifier, json_streamer);
}

let result_stream = AnalyticsStream {
streams,
span,
rt: Arc::new(rt),
};
let result_stream = AnalyticsStream { streams, span, rt };

Ok(PipelineData::from(ListStream::new(
result_stream,
Expand All @@ -184,80 +184,40 @@ fn run(
)))
}

pub async fn send_analytics_query_stream(
client: HTTPClient,
timeout: Duration,
scope: impl Into<Option<(String, String)>>,
statement: impl Into<String>,
ctrl_c: Arc<AtomicBool>,
span: Span,
) -> Result<RawJsonRowStreamer, ShellError> {
let (stream, status) = client
.analytics_query_stream_request(
AnalyticsQueryRequest::Execute {
statement: statement.into(),
scope: scope.into(),
timeout: duration_to_golang_string(timeout),
},
Instant::now().add(timeout),
ctrl_c.clone(),
)
.await
.map_err(|e| client_error_to_shell_error(e, span))?;

let json_stream = JsonRowStream::new(stream);
let mut json_streamer = RawJsonRowStreamer::new(json_stream, "".to_string());

// Read prelude and signature
json_streamer.read_prelude().await.unwrap();
json_streamer.read_row().await.unwrap();

if status != 200 {
let error_msg = if let Some(chunk) = json_streamer.read_row().await? {
from_utf8(&chunk).unwrap().to_string()
} else {
"could not parse errors from stream".to_string()
};

return Err(unexpected_status_code_error(status, error_msg, span));
}

// Read row containing `results: [`
json_streamer.read_row().await.unwrap();

Ok(json_streamer)
}

pub fn send_analytics_query(
active_cluster: &RemoteCluster,
scope: impl Into<Option<(String, String)>>,
statement: impl Into<String>,
ctrl_c: Arc<AtomicBool>,
span: Span,
) -> Result<HttpResponse, ShellError> {
let response = active_cluster
rt: Arc<Runtime>,
) -> Result<impl Stream<Item = Result<Bytes, reqwest::Error>> + Sized, ShellError> {
let (stream, status) = active_cluster
.cluster()
.http_client()
.analytics_query_request(
.analytics_query_stream_request(
AnalyticsQueryRequest::Execute {
statement: statement.into(),
scope: scope.into(),
timeout: duration_to_golang_string(active_cluster.timeouts().analytics_timeout()),
},
Instant::now().add(active_cluster.timeouts().analytics_timeout()),
ctrl_c,
ctrl_c.clone(),
rt.clone(),
)
.map_err(|e| client_error_to_shell_error(e, span))?;

if response.status() != 200 {
return Err(unexpected_status_code_error(
response.status(),
response.content(),
span,
));
if status != 200 {
let error_msg = rt.block_on(async {
read_stream(stream)
.await
.map_err(|e| client_error_to_shell_error(e, span))
})?;

return Err(unexpected_status_code_error(status, error_msg, span));
}

Ok(response)
Ok(stream)
}

pub fn do_analytics_query(
Expand All @@ -270,9 +230,16 @@ pub fn do_analytics_query(
with_meta: bool,
could_contain_mutations: bool,
) -> Result<Vec<Value>, ShellError> {
let response = send_analytics_query(active_cluster, scope, statement, ctrl_c, span)?;
let rt = Arc::new(Runtime::new().unwrap());
let response =
send_analytics_query(active_cluster, scope, statement, ctrl_c, span, rt.clone())?;
let response_content = rt.block_on(async {
read_stream(response)
.await
.map_err(|e| client_error_to_shell_error(e, span))
})?;

let content: serde_json::Value = serde_json::from_str(response.content())
let content: serde_json::Value = serde_json::from_str(&response_content)
.map_err(|e| deserialize_error(e.to_string(), span))?;

let mut results: Vec<Value> = vec![];
Expand Down
81 changes: 42 additions & 39 deletions src/client/http_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,58 +338,61 @@ impl HTTPClient {
})
}

pub async fn analytics_query_stream_request(
pub fn analytics_query_stream_request(
&self,
request: AnalyticsQueryRequest,
deadline: Instant,
ctrl_c: Arc<AtomicBool>,
rt: Arc<Runtime>,
) -> Result<
(
impl Stream<Item = Result<Bytes, reqwest::Error>> + Sized,
u16,
),
ClientError,
> {
let config: ClusterConfig = HTTPClient::get_config(
&self.seeds,
self.tls_enabled,
&self.http_client,
None,
deadline,
ctrl_c.clone(),
)
.await?;

let path = request.path();
if let Some(seed) = config.random_analytics_seed(self.tls_enabled) {
let uri = format!("{}:{}{}", seed.hostname(), seed.port(), &path);
let (stream, status) = match request.verb() {
// HttpVerb::Get => self.http_client.http_get(&uri, deadline, ctrl_c).await?,
HttpVerb::Post => {
self.http_client
.http_post_stream(
&uri,
request.payload(),
request.headers(),
deadline,
ctrl_c,
)
.await?
}
_ => {
return Err(ClientError::RequestFailed {
reason: Some("Method not allowed for analytics queries".to_string()),
key: None,
});
}
};
rt.block_on(async {
let config: ClusterConfig = HTTPClient::get_config(
&self.seeds,
self.tls_enabled,
&self.http_client,
None,
deadline,
ctrl_c.clone(),
)
.await?;

return Ok((stream, status));
}
let path = request.path();
if let Some(seed) = config.random_analytics_seed(self.tls_enabled) {
let uri = format!("{}:{}{}", seed.hostname(), seed.port(), &path);
let (stream, status) = match request.verb() {
// HttpVerb::Get => self.http_client.http_get(&uri, deadline, ctrl_c).await?,
HttpVerb::Post => {
self.http_client
.http_post_stream(
&uri,
request.payload(),
request.headers(),
deadline,
ctrl_c,
)
.await?
}
_ => {
return Err(ClientError::RequestFailed {
reason: Some("Method not allowed for analytics queries".to_string()),
key: None,
});
}
};

return Ok((stream, status));
}

Err(ClientError::RequestFailed {
reason: Some("No nodes found for service".to_string()),
key: None,
Err(ClientError::RequestFailed {
reason: Some("No nodes found for service".to_string()),
key: None,
})
})
}

Expand Down
2 changes: 1 addition & 1 deletion src/client/http_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ impl HTTPHandler {
}
}

async fn read_stream(
pub async fn read_stream(
mut stream: impl Stream<Item = Result<Bytes, reqwest::Error>> + Sized + std::marker::Unpin,
) -> Result<String, ClientError> {
let mut content = "".to_string();
Expand Down
2 changes: 1 addition & 1 deletion src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ mod crc;
mod error;
mod gemini_client;
pub(crate) mod http_client;
mod http_handler;
pub(crate) mod http_handler;
mod kv;
mod kv_client;
mod llm_client;
Expand Down

0 comments on commit 5d9a1b6

Please sign in to comment.