Skip to content

Commit

Permalink
Merge pull request #353 from Kuadrant/optimize-metrics-layer
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-cattermole authored Jun 17, 2024
2 parents a5a13fa + 1a579fb commit 71ea11e
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 141 deletions.
150 changes: 91 additions & 59 deletions limitador-server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ use std::path::Path;
use std::sync::Arc;
use std::time::Duration;
use std::{env, process};
use tracing_subscriber::Layer;

#[cfg(feature = "distributed_storage")]
use clap::parser::ValuesRef;
Expand All @@ -52,9 +51,10 @@ use sysinfo::{MemoryRefreshKind, RefreshKind, System};
use thiserror::Error;
use tokio::runtime::Handle;
use tracing::level_filters::LevelFilter;
use tracing::Subscriber;
use tracing_subscriber::fmt::format::FmtSpan;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{layer::SubscriberExt, Layer};

mod envoy_rls;
mod http_api;
Expand Down Expand Up @@ -217,63 +217,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let (config, version) = create_config();
println!("{LIMITADOR_HEADER} {version}");

let level = config.log_level.unwrap_or_else(|| {
tracing_subscriber::filter::EnvFilter::from_default_env()
.max_level_hint()
.unwrap_or(LevelFilter::ERROR)
});

let fmt_layer = tracing_subscriber::fmt::layer()
.with_span_events(if level >= LevelFilter::DEBUG {
FmtSpan::CLOSE
} else {
FmtSpan::NONE
})
.with_filter(level);

let metrics_layer = MetricsLayer::new()
.gather(
"should_rate_limit",
PrometheusMetrics::record_datastore_latency,
vec!["datastore"],
)
.gather(
"flush_batcher_and_update_counters",
PrometheusMetrics::record_datastore_latency,
vec!["datastore"],
);

if !config.tracing_endpoint.is_empty() {
global::set_text_map_propagator(TraceContextPropagator::new());

let tracer = opentelemetry_otlp::new_pipeline()
.tracing()
.with_exporter(
opentelemetry_otlp::new_exporter()
.tonic()
.with_endpoint(config.tracing_endpoint.clone()),
)
.with_trace_config(trace::config().with_resource(Resource::new(vec![
KeyValue::new("service.name", "limitador"),
])))
.install_batch(opentelemetry_sdk::runtime::Tokio)?;

let telemetry_layer = tracing_opentelemetry::layer().with_tracer(tracer);

// Init tracing subscriber with telemetry
tracing_subscriber::registry()
.with(metrics_layer)
.with(fmt_layer)
.with(level.max(LevelFilter::INFO))
.with(telemetry_layer)
.init();
} else {
// Init tracing subscriber without telemetry
tracing_subscriber::registry()
.with(metrics_layer)
.with(fmt_layer)
.init();
};
configure_tracing_subscriber(&config);

info!("Version: {}", version);
info!("Using config: {:?}", config);
Expand Down Expand Up @@ -808,3 +752,91 @@ fn guess_cache_size() -> Option<u64> {
/// Renders `s` via its `Display` impl and leaks the resulting string,
/// yielding a `&'static str` that lives for the rest of the process.
///
/// The leak is intentional: callers need `'static` strings (e.g. for CLI
/// metadata) whose lifetime matches the program's. Do not call this in a
/// loop or with unbounded input.
fn leak<D: Display>(s: D) -> &'static str {
    // `to_string()` goes through `Display`, same as `format!("{}", s)`.
    Box::leak(s.to_string().into_boxed_str())
}

/// Installs the global tracing subscriber for the process.
///
/// Layer stack (in order): metrics -> console formatter -> OTLP telemetry.
/// * Metrics are only gathered when a real datastore is configured; for
///   in-memory storage the metrics layer is omitted entirely.
/// * Telemetry is only attached when `config.tracing_endpoint` is set.
///
/// `tracing-subscriber` implements `Layer` for `Option<L>` (a `None` layer
/// is a no-op), which lets us build each optional layer once and use a
/// single `registry()` chain instead of four near-duplicate branches.
fn configure_tracing_subscriber(config: &Configuration) {
    // Explicit log level wins; otherwise fall back to RUST_LOG, then ERROR.
    let level = config.log_level.unwrap_or_else(|| {
        tracing_subscriber::filter::EnvFilter::from_default_env()
            .max_level_hint()
            .unwrap_or(LevelFilter::ERROR)
    });

    // Datastore latency metrics: skipped for in-memory storage, where there
    // is no external datastore to measure.
    let metrics_layer = match config.storage {
        StorageConfiguration::InMemory(_) => None,
        _ => Some(
            MetricsLayer::default()
                .gather(
                    "should_rate_limit",
                    PrometheusMetrics::record_datastore_latency,
                    vec!["datastore"],
                )
                .gather(
                    "flush_batcher_and_update_counters",
                    PrometheusMetrics::record_datastore_latency,
                    vec!["datastore"],
                ),
        ),
    };

    // OTLP telemetry is opt-in via a non-empty endpoint. `telemetry_layer`
    // is only invoked (and the propagator only installed) when enabled.
    let telemetry_layer = (!config.tracing_endpoint.is_empty())
        .then(|| telemetry_layer(&config.tracing_endpoint, level));

    tracing_subscriber::registry()
        .with(metrics_layer)
        .with(fmt_layer(level))
        .with(telemetry_layer)
        .init();
}

/// Builds the console (fmt) layer, filtered to `level`.
///
/// Span-close events carry timing information that is only useful when
/// debugging, so they are emitted at DEBUG verbosity and above; otherwise
/// span lifecycle events are suppressed.
fn fmt_layer<S>(level: LevelFilter) -> impl Layer<S>
where
    S: Subscriber + for<'a> tracing_subscriber::registry::LookupSpan<'a>,
{
    let span_events = if level >= LevelFilter::DEBUG {
        FmtSpan::CLOSE
    } else {
        FmtSpan::NONE
    };

    tracing_subscriber::fmt::layer()
        .with_span_events(span_events)
        .with_filter(level)
}

/// Builds an OpenTelemetry layer that exports spans over OTLP/gRPC to
/// `tracing_endpoint`, tagged with `service.name = "limitador"`.
///
/// Side effect: installs the W3C trace-context propagator globally, so
/// incoming/outgoing requests can join distributed traces.
///
/// The layer's filter is raised to at least INFO so that enabling tracing
/// actually produces spans even when the log level is more restrictive.
///
/// # Panics
/// Panics if the batch span exporter cannot be installed on the Tokio
/// runtime (e.g. called outside a Tokio context).
fn telemetry_layer<S>(tracing_endpoint: &str, level: LevelFilter) -> impl Layer<S>
where
    S: Subscriber + for<'a> tracing_subscriber::registry::LookupSpan<'a>,
{
    global::set_text_map_propagator(TraceContextPropagator::new());

    let tracer = opentelemetry_otlp::new_pipeline()
        .tracing()
        .with_exporter(
            opentelemetry_otlp::new_exporter()
                .tonic()
                // `with_endpoint` accepts `impl Into<String>`, so `&str`
                // works here; taking `&str` instead of `&String` keeps the
                // signature idiomatic and callers coerce transparently.
                .with_endpoint(tracing_endpoint),
        )
        .with_trace_config(
            trace::config().with_resource(Resource::new(vec![KeyValue::new(
                "service.name",
                "limitador",
            )])),
        )
        .install_batch(opentelemetry_sdk::runtime::Tokio)
        .expect("error installing tokio tracing exporter");

    // Clamp the filter up to INFO: tracing was explicitly enabled, so emit
    // spans even if the configured log level is WARN/ERROR.
    tracing_opentelemetry::layer()
        .with_tracer(tracer)
        .with_filter(level.max(LevelFilter::INFO))
}
105 changes: 46 additions & 59 deletions limitador-server/src/metrics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,16 @@ struct SpanState {

impl SpanState {
fn new(group: String) -> Self {
Self {
group_times: HashMap::from([(group, Timings::new())]),
}
let mut group_times = HashMap::new();
group_times.insert(group, Timings::new());
Self { group_times }
}

fn increment(&mut self, group: &String, timings: Timings) -> &mut Self {
fn increment(&mut self, group: String, timings: Timings) {
self.group_times
.entry(group.to_string())
.entry(group)
.and_modify(|x| *x += timings)
.or_insert(timings);
self
}
}

Expand All @@ -82,23 +81,18 @@ impl MetricsGroup {
}
}

#[derive(Default)]
pub struct MetricsLayer {
groups: HashMap<String, MetricsGroup>,
}

impl MetricsLayer {
pub fn new() -> Self {
Self {
groups: HashMap::new(),
}
}

pub fn gather(mut self, aggregate: &str, consumer: fn(Timings), records: Vec<&str>) -> Self {
// TODO(adam-cattermole): does not handle case where aggregate already exists
let rec = records.iter().map(|r| r.to_string()).collect();
let rec = records.iter().map(|&r| r.to_string()).collect();
self.groups
.entry(aggregate.to_string())
.or_insert(MetricsGroup::new(Box::new(consumer), rec));
.or_insert_with(|| MetricsGroup::new(Box::new(consumer), rec));
self
}
}
Expand All @@ -111,7 +105,7 @@ where
fn on_new_span(&self, _attrs: &Attributes<'_>, id: &Id, ctx: Context<'_, S>) {
let span = ctx.span(id).expect("Span not found, this is a bug");
let mut extensions = span.extensions_mut();
let name = span.name();
let name = span.name().to_string();

// if there's a parent
if let Some(parent) = span.parent() {
Expand All @@ -122,88 +116,81 @@ where
}

// if we are an aggregator
if self.groups.contains_key(name) {
if self.groups.contains_key(&name) {
if let Some(span_state) = extensions.get_mut::<SpanState>() {
// if the SpanState has come from parent and we must append
// (we are a second level aggregator)
span_state
.group_times
.entry(name.to_string())
.or_insert(Timings::new());
.entry(name.clone())
.or_insert_with(Timings::new);
} else {
// otherwise create a new SpanState with ourselves
extensions.insert(SpanState::new(name.to_string()))
extensions.insert(SpanState::new(name.to_owned()))
}
}

if let Some(span_state) = extensions.get_mut::<SpanState>() {
// either we are an aggregator or nested within one
for group in span_state.group_times.keys() {
for record in &self
if self
.groups
.get(group)
.expect("Span state contains group times for an unconfigured group")
.records
.contains(&name)
{
if name == record {
extensions.insert(Timings::new());
return;
}
extensions.insert(Timings::new());
return;
}
}
// if here we are an intermediate span that should not be recorded
}
}

fn on_enter(&self, id: &Id, ctx: Context<'_, S>) {
let span = ctx.span(id).expect("Span not found, this is a bug");
let mut extensions = span.extensions_mut();
if let Some(span) = ctx.span(id) {
if let Some(timings) = span.extensions_mut().get_mut::<Timings>() {
let now = Instant::now();
timings.idle += (now - timings.last).as_nanos() as u64;
timings.last = now;

if let Some(timings) = extensions.get_mut::<Timings>() {
let now = Instant::now();
timings.idle += (now - timings.last).as_nanos() as u64;
timings.last = now;
timings.updated = true;
timings.updated = true;
}
}
}

fn on_exit(&self, id: &Id, ctx: Context<'_, S>) {
let span = ctx.span(id).expect("Span not found, this is a bug");
let mut extensions = span.extensions_mut();

if let Some(timings) = extensions.get_mut::<Timings>() {
let now = Instant::now();
timings.busy += (now - timings.last).as_nanos() as u64;
timings.last = now;
timings.updated = true;
if let Some(span) = ctx.span(id) {
if let Some(timings) = span.extensions_mut().get_mut::<Timings>() {
let now = Instant::now();
timings.busy += (now - timings.last).as_nanos() as u64;
timings.last = now;
timings.updated = true;
}
}
}

fn on_close(&self, id: Id, ctx: Context<'_, S>) {
let span = ctx.span(&id).expect("Span not found, this is a bug");
let mut extensions = span.extensions_mut();
let name = span.name();
let mut t: Option<Timings> = None;
let name = span.name().to_string();

if let Some(timing) = extensions.get_mut::<Timings>() {
let mut time = *timing;
time.idle += (Instant::now() - time.last).as_nanos() as u64;
t = Some(time);
}
let timing = extensions.get_mut::<Timings>().map(|t| {
let now = Instant::now();
t.idle += (now - t.last).as_nanos() as u64;
*t
});

if let Some(span_state) = extensions.get_mut::<SpanState>() {
if let Some(timing) = t {
let group_times = span_state.group_times.clone();
if let Some(timing) = timing {
// iterate over the groups this span belongs to
'aggregate: for group in group_times.keys() {
for group in span_state.group_times.keys().cloned().collect::<Vec<_>>() {
// find the set of records related to these groups in the layer
for record in &self.groups.get(group).unwrap().records {
if self.groups.get(&group).unwrap().records.contains(&name) {
// if we are a record for this group then increment the relevant
// span-local timing and continue to the next group
if name == record {
span_state.increment(group, timing);
continue 'aggregate;
}
span_state.increment(group, timing);
}
}
}
Expand All @@ -214,8 +201,8 @@ where
parent.extensions_mut().replace(span_state.clone());
}
// IF we are aggregator call consume function
if let Some(metrics_group) = self.groups.get(name) {
if let Some(t) = span_state.group_times.get(name).filter(|&t| t.updated) {
if let Some(metrics_group) = self.groups.get(&name) {
if let Some(t) = span_state.group_times.get(&name).filter(|&t| t.updated) {
(metrics_group.consumer)(*t);
}
}
Expand Down Expand Up @@ -285,22 +272,22 @@ mod tests {
#[test]
fn span_state_increment() {
let group = String::from("group");
let mut span_state = SpanState::new(group.clone());
let mut span_state = SpanState::new(group.to_owned());
let t1 = Timings {
idle: 5,
busy: 5,
last: Instant::now(),
updated: true,
};
span_state.increment(&group, t1);
span_state.increment(group.to_owned(), t1);
assert_eq!(span_state.group_times.get(&group).unwrap().idle, t1.idle);
assert_eq!(span_state.group_times.get(&group).unwrap().busy, t1.busy);
}

#[test]
fn metrics_layer() {
let consumer = |_| println!("group/record");
let ml = MetricsLayer::new().gather("group", consumer, vec!["record"]);
let ml = MetricsLayer::default().gather("group", consumer, vec!["record"]);
assert_eq!(ml.groups.get("group").unwrap().records, vec!["record"]);
}
}
Loading

0 comments on commit 71ea11e

Please sign in to comment.