Skip to content

Commit 2410af3

Browse files
committed
migrate web::sitemap tests to axum test framework
1 parent 668dfb0 commit 2410af3

File tree

7 files changed

+173
-49
lines changed

7 files changed

+173
-49
lines changed

Cargo.lock

+1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+2
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,12 @@ procfs = "0.15.1"
113113
criterion = "0.5.1"
114114
kuchikiki = "0.8"
115115
http02 = { version = "0.2.11", package = "http"}
116+
http-body-util = "0.1.0"
116117
rand = "0.8"
117118
mockito = "1.0.2"
118119
test-case = "3.0.0"
119120
reqwest = { version = "0.12", features = ["blocking", "json"] }
121+
tower = { version = "0.5.1", features = ["util"] }
120122
aws-smithy-types = "1.0.1"
121123
aws-smithy-runtime = {version = "1.0.1", features = ["client", "test-util"]}
122124
aws-smithy-http = "0.60.0"

src/bin/cratesfyi.rs

+4
Original file line numberDiff line numberDiff line change
@@ -914,6 +914,10 @@ impl Context for BinContext {
914914
};
915915
}
916916

917+
async fn async_pool(&self) -> Result<Pool> {
918+
self.pool()
919+
}
920+
917921
fn pool(&self) -> Result<Pool> {
918922
Ok(self
919923
.pool

src/context.rs

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ pub trait Context {
1919
async fn async_storage(&self) -> Result<Arc<AsyncStorage>>;
2020
async fn cdn(&self) -> Result<Arc<CdnBackend>>;
2121
fn pool(&self) -> Result<Pool>;
22+
async fn async_pool(&self) -> Result<Pool>;
2223
fn service_metrics(&self) -> Result<Arc<ServiceMetrics>>;
2324
fn instance_metrics(&self) -> Result<Arc<InstanceMetrics>>;
2425
fn index(&self) -> Result<Arc<Index>>;

src/test/mod.rs

+108-4
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@ use crate::{
1212
ServiceMetrics,
1313
};
1414
use anyhow::Context as _;
15-
use axum::async_trait;
15+
use axum::{async_trait, body::Body, http::Request, response::Response as AxumResponse, Router};
1616
use fn_error_context::context;
1717
use futures_util::{stream::TryStreamExt, FutureExt};
18+
use http_body_util::BodyExt; // for `collect`
1819
use once_cell::sync::OnceCell;
1920
use reqwest::{
2021
blocking::{Client, ClientBuilder, RequestBuilder, Response},
@@ -27,6 +28,7 @@ use std::{
2728
};
2829
use tokio::runtime::{Builder, Runtime};
2930
use tokio::sync::oneshot::Sender;
31+
use tower::ServiceExt;
3032
use tracing::{debug, error, instrument, trace};
3133

3234
#[track_caller]
@@ -126,7 +128,6 @@ pub(crate) fn assert_success(path: &str, web: &TestFrontend) -> Result<()> {
126128
assert!(status.is_success(), "failed to GET {path}: {status}");
127129
Ok(())
128130
}
129-
130131
/// Make sure that a URL returns a status code between 200-299,
131132
/// also check the cache-control headers.
132133
pub(crate) fn assert_success_cached(
@@ -259,6 +260,96 @@ pub(crate) fn assert_redirect_cached(
259260
Ok(redirect_response)
260261
}
261262

263+
pub(crate) trait AxumResponseTestExt {
264+
async fn text(self) -> String;
265+
}
266+
267+
impl AxumResponseTestExt for axum::response::Response {
268+
async fn text(self) -> String {
269+
String::from_utf8_lossy(&self.into_body().collect().await.unwrap().to_bytes()).to_string()
270+
}
271+
}
272+
273+
pub(crate) trait AxumRouterTestExt {
274+
async fn assert_success(&self, path: &str) -> Result<()>;
275+
async fn get(&self, path: &str) -> Result<AxumResponse>;
276+
async fn assert_redirect_common(
277+
&self,
278+
path: &str,
279+
expected_target: &str,
280+
) -> Result<AxumResponse>;
281+
async fn assert_redirect(&self, path: &str, expected_target: &str) -> Result<AxumResponse>;
282+
}
283+
284+
impl AxumRouterTestExt for axum::Router {
285+
/// Make sure that a URL returns a status code between 200-299
286+
async fn assert_success(&self, path: &str) -> Result<()> {
287+
let response = self
288+
.clone()
289+
.oneshot(Request::builder().uri(path).body(Body::empty()).unwrap())
290+
.await?;
291+
292+
let status = response.status();
293+
assert!(status.is_success(), "failed to GET {path}: {status}");
294+
Ok(())
295+
}
296+
/// simple `get` method
297+
async fn get(&self, path: &str) -> Result<AxumResponse> {
298+
Ok(self
299+
.clone()
300+
.oneshot(Request::builder().uri(path).body(Body::empty()).unwrap())
301+
.await?)
302+
}
303+
304+
async fn assert_redirect_common(
305+
&self,
306+
path: &str,
307+
expected_target: &str,
308+
) -> Result<AxumResponse> {
309+
let response = self.get(path).await?;
310+
let status = response.status();
311+
if !status.is_redirection() {
312+
anyhow::bail!("non-redirect from GET {path}: {status}");
313+
}
314+
315+
let redirect_target = response
316+
.headers()
317+
.get("Location")
318+
.context("missing 'Location' header")?
319+
.to_str()
320+
.context("non-ASCII redirect")?;
321+
322+
// FIXME: not sure we need this
323+
// if !expected_target.starts_with("http") {
324+
// // TODO: Should be able to use Url::make_relative,
325+
// // but https://github.com/servo/rust-url/issues/766
326+
// let base = format!("http://{}", web.server_addr());
327+
// redirect_target = redirect_target
328+
// .strip_prefix(&base)
329+
// .unwrap_or(redirect_target);
330+
// }
331+
332+
if redirect_target != expected_target {
333+
anyhow::bail!("got redirect to {redirect_target}");
334+
}
335+
336+
Ok(response)
337+
}
338+
339+
#[context("expected redirect from {path} to {expected_target}")]
340+
async fn assert_redirect(&self, path: &str, expected_target: &str) -> Result<AxumResponse> {
341+
let redirect_response = self.assert_redirect_common(path, expected_target).await?;
342+
343+
let response = self.get(expected_target).await?;
344+
let status = response.status();
345+
if !status.is_success() {
346+
anyhow::bail!("failed to GET {expected_target}: {status}");
347+
}
348+
349+
Ok(redirect_response)
350+
}
351+
}
352+
262353
pub(crate) struct TestEnvironment {
263354
build_queue: OnceCell<Arc<BuildQueue>>,
264355
async_build_queue: tokio::sync::OnceCell<Arc<AsyncBuildQueue>>,
@@ -534,6 +625,13 @@ impl TestEnvironment {
534625
self.runtime().block_on(self.async_fake_release())
535626
}
536627

628+
pub(crate) async fn web_app(&self) -> Router {
629+
let template_data = Arc::new(TemplateData::new(1).unwrap());
630+
build_axum_app(self, template_data)
631+
.await
632+
.expect("could not build axum app")
633+
}
634+
537635
pub(crate) async fn async_fake_release(&self) -> fakes::FakeRelease {
538636
fakes::FakeRelease::new(
539637
self.async_db().await,
@@ -569,6 +667,10 @@ impl Context for TestEnvironment {
569667
Ok(TestEnvironment::cdn(self).await)
570668
}
571669

670+
async fn async_pool(&self) -> Result<Pool> {
671+
Ok(self.async_db().await.pool())
672+
}
673+
572674
fn pool(&self) -> Result<Pool> {
573675
Ok(self.db().pool())
574676
}
@@ -734,10 +836,12 @@ impl TestFrontend {
734836
let (tx, rx) = tokio::sync::oneshot::channel::<()>();
735837

736838
debug!("building axum app");
737-
let axum_app = build_axum_app(context, template_data).expect("could not build axum app");
839+
let runtime = context.runtime().unwrap();
840+
let axum_app = runtime
841+
.block_on(build_axum_app(context, template_data))
842+
.expect("could not build axum app");
738843

739844
let handle = thread::spawn({
740-
let runtime = context.runtime().unwrap();
741845
move || {
742846
runtime.block_on(async {
743847
axum::serve(axum_listener, axum_app.into_make_service())

src/web/mod.rs

+15-12
Original file line numberDiff line numberDiff line change
@@ -393,16 +393,16 @@ async fn set_sentry_transaction_name_from_axum_route(
393393
next.run(request).await
394394
}
395395

396-
fn apply_middleware(
396+
async fn apply_middleware(
397397
router: AxumRouter,
398398
context: &dyn Context,
399399
template_data: Option<Arc<TemplateData>>,
400400
) -> Result<AxumRouter> {
401401
let config = context.config()?;
402402
let has_templates = template_data.is_some();
403-
let runtime = context.runtime()?;
404-
let async_storage = runtime.block_on(context.async_storage())?;
405-
let build_queue = runtime.block_on(context.async_build_queue())?;
403+
404+
let async_storage = context.async_storage().await?;
405+
let build_queue = context.async_build_queue().await?;
406406

407407
Ok(router.layer(
408408
ServiceBuilder::new()
@@ -419,12 +419,11 @@ fn apply_middleware(
419419
.then_some(middleware::from_fn(log_timeouts_to_sentry)),
420420
))
421421
.layer(option_layer(config.request_timeout.map(TimeoutLayer::new)))
422-
.layer(Extension(context.pool()?))
422+
.layer(Extension(context.async_pool().await?))
423423
.layer(Extension(build_queue))
424424
.layer(Extension(context.service_metrics()?))
425425
.layer(Extension(context.instance_metrics()?))
426426
.layer(Extension(context.config()?))
427-
.layer(Extension(context.storage()?))
428427
.layer(Extension(async_storage))
429428
.layer(option_layer(template_data.map(Extension)))
430429
.layer(middleware::from_fn(csp::csp_middleware))
@@ -435,15 +434,15 @@ fn apply_middleware(
435434
))
436435
}
437436

438-
pub(crate) fn build_axum_app(
437+
pub(crate) async fn build_axum_app(
439438
context: &dyn Context,
440439
template_data: Arc<TemplateData>,
441440
) -> Result<AxumRouter, Error> {
442-
apply_middleware(routes::build_axum_routes(), context, Some(template_data))
441+
apply_middleware(routes::build_axum_routes(), context, Some(template_data)).await
443442
}
444443

445-
pub(crate) fn build_metrics_axum_app(context: &dyn Context) -> Result<AxumRouter, Error> {
446-
apply_middleware(routes::build_metric_routes(), context, None)
444+
pub(crate) async fn build_metrics_axum_app(context: &dyn Context) -> Result<AxumRouter, Error> {
445+
apply_middleware(routes::build_metric_routes(), context, None).await
447446
}
448447

449448
pub fn start_background_metrics_webserver(
@@ -458,8 +457,10 @@ pub fn start_background_metrics_webserver(
458457
axum_addr.port()
459458
);
460459

461-
let metrics_axum_app = build_metrics_axum_app(context)?.into_make_service();
462460
let runtime = context.runtime()?;
461+
let metrics_axum_app = runtime
462+
.block_on(build_metrics_axum_app(context))?
463+
.into_make_service();
463464

464465
runtime.spawn(async move {
465466
match tokio::net::TcpListener::bind(axum_addr)
@@ -501,8 +502,10 @@ pub fn start_web_server(addr: Option<SocketAddr>, context: &dyn Context) -> Resu
501502
context.storage()?;
502503
context.repository_stats_updater()?;
503504

504-
let app = build_axum_app(context, template_data)?.into_make_service();
505505
context.runtime()?.block_on(async {
506+
let app = build_axum_app(context, template_data)
507+
.await?
508+
.into_make_service();
506509
let listener = tokio::net::TcpListener::bind(axum_addr)
507510
.await
508511
.context("error binding socket for metrics web server")?;

0 commit comments

Comments
 (0)