Skip to content

Commit

Permalink
feat: use request uri for hot-reload instead of window.location.host
Browse files Browse the repository at this point in the history
  • Loading branch information
amrbashir committed Feb 22, 2023
1 parent b41a957 commit 8c92ebc
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 8 deletions.
10 changes: 6 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ once_cell = "1"
open = "3"
remove_dir_all = "0.7"
reqwest = { version = "0.11", default-features = false, features = ["rustls-tls", "stream", "trust-dns"] }
http-body = "0.4"
hyper = "0.14"
seahash = "4"
serde = { version = "1", features = ["derive"] }
tar = "0.4"
Expand Down
2 changes: 1 addition & 1 deletion src/autoreload.js
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
(function () {
var protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
var url = protocol + '//' + window.location.host + '/_trunk/ws';
var url = protocol + '//' + '{{__TRUNK_ADDRESS__}}' + '/_trunk/ws';
var poll_interval = 5000;
var reload_upon_connect = () => {
window.setTimeout(
Expand Down
52 changes: 49 additions & 3 deletions src/serve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,18 @@ use std::path::PathBuf;
use std::sync::Arc;

use anyhow::{Context, Result};
use axum::body::{self, Body};
use axum::body::{self, Body, Bytes};
use axum::extract::ws::{WebSocket, WebSocketUpgrade};
use axum::http::StatusCode;
use axum::http::response::Parts;
use axum::http::{
header::{CONTENT_TYPE, HOST},
Request, StatusCode,
};
use axum::middleware::Next;
use axum::response::Response;
use axum::routing::{get, get_service, Router};
use axum::Server;
use hyper::header::CONTENT_LENGTH;
use tokio::sync::broadcast;
use tokio::task::JoinHandle;
use tower_http::services::{ServeDir, ServeFile};
Expand Down Expand Up @@ -194,7 +200,8 @@ fn router(state: Arc<State>, cfg: Arc<RtcServe>) -> Router {
tracing::error!(?error, "failed serving static file");
StatusCode::INTERNAL_SERVER_ERROR
})
.layer(TraceLayer::new_for_http()),
.layer(TraceLayer::new_for_http())
.layer(axum::middleware::from_fn(html_address_middleware)),
),
)
.route(
Expand Down Expand Up @@ -270,6 +277,45 @@ fn router(state: Arc<State>, cfg: Arc<RtcServe>) -> Router {
router
}

async fn html_address_middleware<B: std::fmt::Debug>(
request: Request<B>,
next: Next<B>,
) -> (Parts, Bytes) {
let uri = request.headers().get(HOST).cloned();
let response = next.run(request).await;
let (parts, body) = response.into_parts();

match hyper::body::to_bytes(body).await {
Err(_) => (parts, Bytes::default()),
Ok(bytes) => {
let (mut parts, mut bytes) = (parts, bytes);

if let Some(uri) = uri {
if parts
.headers
.get(CONTENT_TYPE)
.map(|t| t == "text/html")
.unwrap_or(false)
{
if let Ok(data_str) = std::str::from_utf8(&bytes) {
let data_str = data_str.replace(
"'{{__TRUNK_ADDRESS__}}'",
&uri.to_str()
.map(|s| format!("'{}'", s))
.unwrap_or_else(|_| "window.location.href".into()),
);
let bytes_vec = data_str.as_bytes().to_vec();
parts.headers.insert(CONTENT_LENGTH, bytes_vec.len().into());
bytes = Bytes::from(bytes_vec);
}
}
}

(parts, bytes)
}
}
}

async fn handle_ws(mut ws: WebSocket, state: Arc<State>) {
let mut rx = state.build_done_chan.subscribe();
tracing::debug!("autoreload websocket opened");
Expand Down

0 comments on commit 8c92ebc

Please sign in to comment.