diff --git a/Cargo.lock b/Cargo.lock index 5b3b521f..d7a6f30d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -832,9 +832,9 @@ checksum = "c4a1e36c821dbe04574f602848a19f742f4fb3c98d40449f11bcad18d6b17421" [[package]] name = "hyper" -version = "0.14.20" +version = "0.14.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02c929dc5c39e335a03c405292728118860721b10190d98c2a0f0efd5baafbac" +checksum = "5e011372fa0b68db8350aa7a248930ecc7839bf46d8485577d69f117a75f164c" dependencies = [ "bytes", "futures-channel", @@ -2305,9 +2305,9 @@ dependencies = [ [[package]] name = "tower-http" -version = "0.3.4" +version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c530c8675c1dbf98facee631536fa116b5fb6382d7dd6dc1b118d970eafe3ba" +checksum = "f873044bf02dd1e8239e9c1293ea39dad76dc594ec16185d0a1bf31d8dc8d858" dependencies = [ "bitflags", "bytes", @@ -2422,6 +2422,8 @@ dependencies = [ "flate2", "fs_extra", "futures-util", + "http-body", + "hyper", "nipper", "notify", "once_cell", diff --git a/Cargo.toml b/Cargo.toml index 5ec70fd4..4048a644 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,6 +36,8 @@ once_cell = "1" open = "3" remove_dir_all = "0.7" reqwest = { version = "0.11", default-features = false, features = ["rustls-tls", "stream", "trust-dns"] } +http-body = "0.4" +hyper = "0.14" seahash = "4" serde = { version = "1", features = ["derive"] } tar = "0.4" diff --git a/src/autoreload.js b/src/autoreload.js index d141fbad..e5a62d08 100644 --- a/src/autoreload.js +++ b/src/autoreload.js @@ -1,6 +1,6 @@ (function () { var protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:'; - var url = protocol + '//' + window.location.host + '/_trunk/ws'; + var url = protocol + '//' + '{{__TRUNK_ADDRESS__}}' + '/_trunk/ws'; var poll_interval = 5000; var reload_upon_connect = () => { window.setTimeout( diff --git a/src/serve.rs b/src/serve.rs index 402174a0..5409f57d 100644 --- a/src/serve.rs +++ b/src/serve.rs @@ -2,12 +2,18 @@ use std::path::PathBuf; use std::sync::Arc; use anyhow::{Context, Result}; -use axum::body::{self, Body}; +use axum::body::{self, Body, Bytes}; use axum::extract::ws::{WebSocket, WebSocketUpgrade}; -use axum::http::StatusCode; +use axum::http::response::Parts; +use axum::http::{ + header::{CONTENT_TYPE, HOST}, + Request, StatusCode, +}; +use axum::middleware::Next; use axum::response::Response; use axum::routing::{get, get_service, Router}; use axum::Server; +use hyper::header::CONTENT_LENGTH; use tokio::sync::broadcast; use tokio::task::JoinHandle; use tower_http::services::{ServeDir, ServeFile}; @@ -194,7 +200,8 @@ fn router(state: Arc, cfg: Arc) -> Router { tracing::error!(?error, "failed serving static file"); StatusCode::INTERNAL_SERVER_ERROR }) - .layer(TraceLayer::new_for_http()), + .layer(TraceLayer::new_for_http()) + .layer(axum::middleware::from_fn(html_address_middleware)), ), ) .route( @@ -270,6 +277,45 @@ fn router(state: Arc, cfg: Arc) -> Router { router } +async fn html_address_middleware( + request: Request, + next: Next, +) -> (Parts, Bytes) { + let uri = request.headers().get(HOST).cloned(); + let response = next.run(request).await; + let (parts, body) = response.into_parts(); + + match hyper::body::to_bytes(body).await { + Err(_) => (parts, Bytes::default()), + Ok(bytes) => { + let (mut parts, mut bytes) = (parts, bytes); + + if let Some(uri) = uri { + if parts + .headers + .get(CONTENT_TYPE) + .map(|t| t == "text/html") + .unwrap_or(false) + { + if let Ok(data_str) = std::str::from_utf8(&bytes) { + let data_str = data_str.replace( + "'{{__TRUNK_ADDRESS__}}'", + &uri.to_str() + .map(|s| format!("'{}'", s)) + .unwrap_or_else(|_| "window.location.href".into()), + ); + let bytes_vec = data_str.as_bytes().to_vec(); + parts.headers.insert(CONTENT_LENGTH, bytes_vec.len().into()); + bytes = Bytes::from(bytes_vec); + } + } + } + + (parts, bytes) + } + } +} + async fn handle_ws(mut ws: WebSocket, state: Arc) { let mut rx = state.build_done_chan.subscribe(); tracing::debug!("autoreload websocket opened");