diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8d63b8e0..3efb1797 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -48,6 +48,7 @@ Subheadings to categorize changes are `added, changed, deprecated, removed, fixe
- Updated CLI argument parser to clap v0.4.
- Reduce error to warning when processing a project without Cargo.toml and no `` (fixes #487)
- Add wasm-bindgen URLs for all supported architectures
+- Changed hot-reload to use request host instead of `window.location.host`
### fixed
- Nested WS proxies - if `backend=ws://localhost:8000/ws` is set, queries for `ws://localhost:8080/ws/entityX` will be linked with `ws://localhost:8000/ws/entityX`
diff --git a/Cargo.lock b/Cargo.lock
index 743a4e00..516cd124 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2659,6 +2659,8 @@ dependencies = [
"flate2",
"fs_extra",
"futures-util",
+ "http-body",
+ "hyper",
"local-ip-address",
"nipper",
"notify",
diff --git a/Cargo.toml b/Cargo.toml
index b8880ae5..d657bb75 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -44,6 +44,8 @@ reqwest = { version = "0.11", default-features = false, features = [
"stream",
"trust-dns",
] }
+http-body = "0.4"
+hyper = "0.14"
seahash = "4"
serde = { version = "1", features = ["derive"] }
tar = "0.4"
diff --git a/src/autoreload.js b/src/autoreload.js
index d141fbad..e5a62d08 100644
--- a/src/autoreload.js
+++ b/src/autoreload.js
@@ -1,6 +1,6 @@
(function () {
var protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
- var url = protocol + '//' + window.location.host + '/_trunk/ws';
+ var url = protocol + '//' + '{{__TRUNK_ADDRESS__}}' + '/_trunk/ws';
var poll_interval = 5000;
var reload_upon_connect = () => {
window.setTimeout(
diff --git a/src/serve.rs b/src/serve.rs
index fa41e1c2..712dc3e1 100644
--- a/src/serve.rs
+++ b/src/serve.rs
@@ -3,9 +3,14 @@ use std::path::PathBuf;
use std::sync::Arc;
use anyhow::{Context, Result};
-use axum::body::{self, Body};
+use axum::body::{self, Body, Bytes};
use axum::extract::ws::{WebSocket, WebSocketUpgrade};
-use axum::http::StatusCode;
+use axum::http::response::Parts;
+use axum::http::{
+ header::{CONTENT_LENGTH, CONTENT_TYPE, HOST},
+ Request, StatusCode,
+};
+use axum::middleware::Next;
use axum::response::Response;
use axum::routing::{get, get_service, Router};
use axum::Server;
@@ -227,7 +232,8 @@ fn router(state: Arc, cfg: Arc) -> Router {
tracing::error!(?error, "failed serving static file");
StatusCode::INTERNAL_SERVER_ERROR
})
- .layer(TraceLayer::new_for_http()),
+ .layer(TraceLayer::new_for_http())
+ .layer(axum::middleware::from_fn(html_address_middleware)),
),
)
.route(
@@ -303,6 +309,45 @@ fn router(state: Arc, cfg: Arc) -> Router {
router
}
+async fn html_address_middleware(
+ request: Request,
+ next: Next,
+) -> (Parts, Bytes) {
+ let uri = request.headers().get(HOST).cloned();
+ let response = next.run(request).await;
+ let (parts, body) = response.into_parts();
+
+ match hyper::body::to_bytes(body).await {
+ Err(_) => (parts, Bytes::default()),
+ Ok(bytes) => {
+ let (mut parts, mut bytes) = (parts, bytes);
+
+ if let Some(uri) = uri {
+ if parts
+ .headers
+ .get(CONTENT_TYPE)
+ .map(|t| t == "text/html")
+ .unwrap_or(false)
+ {
+ if let Ok(data_str) = std::str::from_utf8(&bytes) {
+ let data_str = data_str.replace(
+ "'{{__TRUNK_ADDRESS__}}'",
+ &uri.to_str()
+ .map(|s| format!("'{}'", s))
+ .unwrap_or_else(|_| "window.location.href".into()),
+ );
+ let bytes_vec = data_str.as_bytes().to_vec();
+ parts.headers.insert(CONTENT_LENGTH, bytes_vec.len().into());
+ bytes = Bytes::from(bytes_vec);
+ }
+ }
+ }
+
+ (parts, bytes)
+ }
+ }
+}
+
async fn handle_ws(mut ws: WebSocket, state: Arc) {
let mut rx = state.build_done_chan.subscribe();
tracing::debug!("autoreload websocket opened");