diff --git a/CHANGELOG.md b/CHANGELOG.md index 8d63b8e0..3efb1797 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -48,6 +48,7 @@ Subheadings to categorize changes are `added, changed, deprecated, removed, fixe - Updated CLI argument parser to clap v0.4. - Reduce error to warning when processing a project without Cargo.toml and no `` (fixes #487) - Add wasm-bindgen URLs for all supported architectures +- Changed hot-reload to use request host instead of `window.location.host` ### fixed - Nested WS proxies - if `backend=ws://localhost:8000/ws` is set, queries for `ws://localhost:8080/ws/entityX` will be linked with `ws://localhost:8000/ws/entityX` diff --git a/Cargo.lock b/Cargo.lock index 743a4e00..516cd124 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2659,6 +2659,8 @@ dependencies = [ "flate2", "fs_extra", "futures-util", + "http-body", + "hyper", "local-ip-address", "nipper", "notify", diff --git a/Cargo.toml b/Cargo.toml index b8880ae5..d657bb75 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,6 +44,8 @@ reqwest = { version = "0.11", default-features = false, features = [ "stream", "trust-dns", ] } +http-body = "0.4" +hyper = "0.14" seahash = "4" serde = { version = "1", features = ["derive"] } tar = "0.4" diff --git a/src/autoreload.js b/src/autoreload.js index d141fbad..e5a62d08 100644 --- a/src/autoreload.js +++ b/src/autoreload.js @@ -1,6 +1,6 @@ (function () { var protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:'; - var url = protocol + '//' + window.location.host + '/_trunk/ws'; + var url = protocol + '//' + '{{__TRUNK_ADDRESS__}}' + '/_trunk/ws'; var poll_interval = 5000; var reload_upon_connect = () => { window.setTimeout( diff --git a/src/serve.rs b/src/serve.rs index fa41e1c2..712dc3e1 100644 --- a/src/serve.rs +++ b/src/serve.rs @@ -3,9 +3,14 @@ use std::path::PathBuf; use std::sync::Arc; use anyhow::{Context, Result}; -use axum::body::{self, Body}; +use axum::body::{self, Body, Bytes}; use axum::extract::ws::{WebSocket, WebSocketUpgrade}; -use axum::http::StatusCode; +use axum::http::response::Parts; +use axum::http::{ + header::{CONTENT_LENGTH, CONTENT_TYPE, HOST}, + Request, StatusCode, +}; +use axum::middleware::Next; use axum::response::Response; use axum::routing::{get, get_service, Router}; use axum::Server; @@ -227,7 +232,8 @@ fn router(state: Arc, cfg: Arc) -> Router { tracing::error!(?error, "failed serving static file"); StatusCode::INTERNAL_SERVER_ERROR }) - .layer(TraceLayer::new_for_http()), + .layer(TraceLayer::new_for_http()) + .layer(axum::middleware::from_fn(html_address_middleware)), ), ) .route( @@ -303,6 +309,45 @@ fn router(state: Arc, cfg: Arc) -> Router { router } +async fn html_address_middleware( + request: Request, + next: Next, +) -> (Parts, Bytes) { + let uri = request.headers().get(HOST).cloned(); + let response = next.run(request).await; + let (parts, body) = response.into_parts(); + + match hyper::body::to_bytes(body).await { + Err(_) => (parts, Bytes::default()), + Ok(bytes) => { + let (mut parts, mut bytes) = (parts, bytes); + + if let Some(uri) = uri { + if parts + .headers + .get(CONTENT_TYPE) + .map(|t| t == "text/html") + .unwrap_or(false) + { + if let Ok(data_str) = std::str::from_utf8(&bytes) { + let data_str = data_str.replace( + "'{{__TRUNK_ADDRESS__}}'", + &uri.to_str() + .map(|s| format!("'{}'", s)) + .unwrap_or_else(|_| "window.location.href".into()), + ); + let bytes_vec = data_str.as_bytes().to_vec(); + parts.headers.insert(CONTENT_LENGTH, bytes_vec.len().into()); + bytes = Bytes::from(bytes_vec); + } + } + } + + (parts, bytes) + } + } +} + async fn handle_ws(mut ws: WebSocket, state: Arc) { let mut rx = state.build_done_chan.subscribe(); tracing::debug!("autoreload websocket opened");