diff --git a/httpconn.c b/httpconn.c index dcf7900..64a10e5 100644 --- a/httpconn.c +++ b/httpconn.c @@ -44,6 +44,7 @@ struct http_conn { int tunnel_read_paused; int msg_complete_on_eof; int persistent; + int expect_continue; const struct http_cbs *cbs; void *cbarg; ev_int64_t body_length; @@ -208,6 +209,8 @@ http_conn_error_to_string(enum http_conn_error err) return "Connection failed"; case ERROR_IDLE_CONN_TIMEDOUT: return "Idle connection timed out"; + case ERROR_CLIENT_EXPECTATION_FAILED: + return "Can't statisfy client's expectation"; case ERROR_CLIENT_POST_WITHOUT_LENGTH: return "Client post with unknown length"; case ERROR_INCOMPLETE_HEADERS: @@ -238,6 +241,7 @@ begin_message(struct http_conn *conn) conn->state = HTTP_STATE_IDLE; if (!conn->read_paused) bufferevent_enable(conn->bev, EV_READ); + // XXX we should have a separate function to tell that server is idle. if (conn->type == HTTP_SERVER) bufferevent_set_timeouts(conn->bev, &idle_server_timeout, NULL); else @@ -249,8 +253,13 @@ end_message(struct http_conn *conn, enum http_conn_error err) { if (conn->firstline) mem_free(conn->firstline); - if (conn->headers) + if (conn->headers) { headers_clear(conn->headers); + mem_free(conn->headers); + } + + conn->firstline = NULL; + conn->headers = NULL; if (err != ERROR_NONE || !conn->persistent) { conn->state = HTTP_STATE_MANGLED; @@ -455,7 +464,7 @@ read_body(struct http_conn *conn) } } -static void +static enum http_conn_error check_headers(struct http_conn *conn, struct http_request *req, struct http_response *resp) { @@ -469,6 +478,7 @@ check_headers(struct http_conn *conn, struct http_request *req, conn->msg_complete_on_eof = 0; conn->data_remaining = -1; conn->body_length = -1; + conn->expect_continue = 0; tunnel = 0; if (conn->type == HTTP_CLIENT) { @@ -479,6 +489,25 @@ check_headers(struct http_conn *conn, struct http_request *req, conn->has_body = 1; else if (req->meth == METH_CONNECT) tunnel = 1; + + val = headers_find(conn->headers, "Expect"); + if (val) { + int cont; + + cont = !evutil_ascii_strcasecmp(val, "100-continue"); + mem_free(val); + if (cont == 0 || !conn->has_body) + return ERROR_CLIENT_EXPECTATION_FAILED; + + if (cont && req->vers != HTTP_11) { + cont = 0; + log_info("http: ignoring expect continue from " + "old client"); + headers_remove(conn->headers, "Expect"); + } + + conn->expect_continue = cont; + } } else { /* server */ vers = resp->vers; if ((resp->code >= 100 && resp->code < 200) || @@ -502,7 +531,7 @@ check_headers(struct http_conn *conn, struct http_request *req, ev_int64_t iv; iv = get_int(val, 10); if (iv < 0) { - log_warn("http_conn: mangled " + log_warn("http: mangled " "Content-Length"); headers_remove(conn->headers, "content-length"); @@ -517,11 +546,8 @@ check_headers(struct http_conn *conn, struct http_request *req, } if (conn->type == HTTP_CLIENT && conn->body_length < 0 && - conn->te != TE_CHUNKED) { - EVENT1(conn, on_error, - ERROR_CLIENT_POST_WITHOUT_LENGTH); - return; - } + conn->te != TE_CHUNKED) + return ERROR_CLIENT_POST_WITHOUT_LENGTH; } conn->data_remaining = conn->body_length; @@ -546,6 +572,8 @@ check_headers(struct http_conn *conn, struct http_request *req, } } conn->persistent = persistent; + + return ERROR_NONE; } static void @@ -555,6 +583,7 @@ read_headers(struct http_conn *conn) struct evbuffer *inbuf = bufferevent_get_input(conn->bev); struct http_request *req = NULL; struct http_response *resp = NULL; + enum http_conn_error err; assert(conn->state == HTTP_STATE_READ_HEADERS); @@ -588,20 +617,36 @@ read_headers(struct http_conn *conn) return; } - check_headers(conn, req, resp); + err = check_headers(conn, req, resp); conn->headers = NULL; - /* ownership of req or resp is now passed on */ - if (req) - EVENT1(conn, on_client_request, req); - if (resp) - EVENT1(conn, on_server_response, resp); + if (err == ERROR_NONE) { + int server_continuation = 0; + + /* ownership of req or resp is now passed on */ + if (req) + EVENT1(conn, on_client_request, req); + if (resp) { + if (resp->code == 100) { + http_response_free(resp); + EVENT0(conn, on_server_continuation); + begin_message(conn); + server_continuation = 1; + } else + EVENT1(conn, on_server_response, resp); + } - if (conn->state != HTTP_STATE_TUNNEL_CONNECTING) { - if (!conn->has_body) - end_message(conn, ERROR_NONE); - else - conn->state = HTTP_STATE_READ_BODY; + if (!server_continuation && + conn->state != HTTP_STATE_TUNNEL_CONNECTING) { + if (!conn->has_body) + end_message(conn, ERROR_NONE); + else + conn->state = HTTP_STATE_READ_BODY; + } + } else { + http_request_free(req); + http_response_free(resp); + end_message(conn, err); } } @@ -911,6 +956,25 @@ http_conn_write_request(struct http_conn *conn, struct http_request *req) headers_dump(req->headers, outbuf); } +int +http_conn_expect_continue(struct http_conn *conn) +{ + return conn->expect_continue; +} + +void +http_conn_write_continue(struct http_conn *conn) +{ + struct evbuffer *outbuf; + + if (conn->expect_continue) { + outbuf = bufferevent_get_output(conn->bev); + conn->expect_continue = 0; + assert(conn->vers == HTTP_11); + evbuffer_add_printf(outbuf, "HTTP/1.1 100 Continue\r\n\r\n"); + } +} + void http_conn_write_response(struct http_conn *conn, struct http_response *resp) { @@ -1109,6 +1173,9 @@ http_conn_start_tunnel(struct http_conn *conn, struct evdns_base *dns, void http_request_free(struct http_request *req) { + if (!req) + return; + url_free(req->url); headers_clear(req->headers); mem_free(req); @@ -1117,6 +1184,9 @@ http_request_free(struct http_request *req) void http_response_free(struct http_response *resp) { + if (!resp) + return; + headers_clear(resp->headers); mem_free(resp->headers); mem_free(resp->reason); diff --git a/httpconn.h b/httpconn.h index 6b574bc..40504c7 100644 --- a/httpconn.h +++ b/httpconn.h @@ -41,6 +41,7 @@ enum http_conn_error { ERROR_NONE, ERROR_CONNECT_FAILED, ERROR_IDLE_CONN_TIMEDOUT, + ERROR_CLIENT_EXPECTATION_FAILED, ERROR_CLIENT_POST_WITHOUT_LENGTH, ERROR_INCOMPLETE_HEADERS, ERROR_INCOMPLETE_BODY, @@ -78,6 +79,7 @@ struct http_cbs { void (*on_connect)(struct http_conn *, void *); void (*on_error)(struct http_conn *, enum http_conn_error, void *); void (*on_client_request)(struct http_conn *, struct http_request *, void *); + void (*on_server_continuation)(struct http_conn *, void *); void (*on_server_response)(struct http_conn *, struct http_response *, void *); void (*on_read_body)(struct http_conn *, struct evbuffer *, void *); void (*on_msg_complete)(struct http_conn *, void *); @@ -98,6 +100,8 @@ int http_conn_connect(struct http_conn *conn, struct evdns_base *dns, void http_conn_free(struct http_conn *conn); void http_conn_write_request(struct http_conn *conn, struct http_request *req); +int http_conn_expect_continue(struct http_conn *conn); +void http_conn_write_continue(struct http_conn *conn); void http_conn_write_response(struct http_conn *conn, struct http_response *resp); /* return: 0 on choaked, 1 on queued. */ diff --git a/proxy.c b/proxy.c index 9e212aa..5cb2523 100644 --- a/proxy.c +++ b/proxy.c @@ -51,6 +51,7 @@ static void on_client_flush(struct http_conn *, void *); static void on_server_connected(struct http_conn *, void *); static void on_server_error(struct http_conn *, enum http_conn_error, void *); +static void on_server_continuation(struct http_conn *, void *); static void on_server_response(struct http_conn *, struct http_response *, void *); static void on_server_read_body(struct http_conn *, struct evbuffer *, void *); static void on_server_msg_complete(struct http_conn *, void *); @@ -62,6 +63,7 @@ static const struct http_cbs client_methods = { on_client_error, on_client_request, 0, + 0, on_client_read_body, on_client_msg_complete, on_client_write_more, @@ -72,6 +74,7 @@ static const struct http_cbs server_methods = { on_server_connected, on_server_error, 0, + on_server_continuation, on_server_response, on_server_read_body, on_server_msg_complete, @@ -183,29 +186,31 @@ client_scrub_request(struct client *client, struct http_request *req) if (req->meth == METH_CONNECT) { assert(req->url->host && req->url->port >= 1); // XXX we could filter host/port here - } else { - if (!req->url->host) { - http_conn_send_error(client->conn, 403, "Forbidden"); - goto fail; - } - if (evutil_ascii_strcasecmp(req->url->scheme, "http")) { - http_conn_send_error(client->conn, 400, "Invalid URL"); - goto fail; - } + return 0; + } - if (req->url->port < 0) - req->url->port = 80; - - if (!headers_has_key(req->headers, "Host")) { - char *host; - size_t len = strlen(req->url->host) + 6; - host = mem_calloc(1, len); - evutil_snprintf(host, len, "%s:%d", req->url->host, - req->url->port); - headers_add_key_val(req->headers, "Host", host); - mem_free(host); - } + if (!req->url->host) { + http_conn_send_error(client->conn, 403, "Forbidden"); + goto fail; + } + if (evutil_ascii_strcasecmp(req->url->scheme, "http")) { + http_conn_send_error(client->conn, 400, "Invalid URL"); + goto fail; + } + + if (req->url->port < 0) + req->url->port = 80; + + if (!headers_has_key(req->headers, "Host")) { + char *host; + size_t len = strlen(req->url->host) + 6; + host = mem_calloc(1, len); + evutil_snprintf(host, len, "%s:%d", req->url->host, + req->url->port); + headers_add_key_val(req->headers, "Host", host); + mem_free(host); } + // XXX remove proxy auth msgs? return 0; @@ -277,6 +282,24 @@ client_associate_server(struct client *client) return server_connect(client->server); } +static void +client_start_reading_request_body(struct client *client, int on_continue) +{ + assert(client->server != NULL); + + /* should we wait for the server to send 100 continue? */ + if (!on_continue && http_conn_expect_continue(client->conn)) + return; + + if (http_conn_current_message_has_body(client->conn) && + client->nrequests == 1) { + http_conn_write_continue(client->conn); + http_conn_set_output_encoding(client->server->conn, + http_conn_get_current_message_body_encoding(client->conn)); + http_conn_start_reading(client->conn); + } +} + /* returns 1 when there's a request we can dispatch with the associated server. */ static int @@ -309,6 +332,8 @@ client_dispatch_request(struct client *client) req->url->path, server->client, server, server->host, server->port); http_conn_write_request(server->conn, req); + // XXX we may want to wait for 100-continue + client_start_reading_request_body(client, 0); server->state = SERVER_STATE_REQUEST_SENT; return 1; } @@ -316,15 +341,6 @@ client_dispatch_request(struct client *client) return 0; } -static void -client_start_reading_request_body(struct client *client) -{ - //XXX make sure server knows what transefer encodign to use. - if (http_conn_current_message_has_body(client->conn) && - client->nrequests == 1) - http_conn_start_reading(client->conn); -} - static void client_write_response(struct client *client, struct http_response *resp) { @@ -464,8 +480,10 @@ on_client_msg_complete(struct http_conn *conn, void *arg) { struct client *client = arg; - if (http_conn_current_message_has_body(conn)) + if (http_conn_current_message_has_body(conn)) { + log_debug("proxy: finished reading client's message"); http_conn_write_finished(client->server->conn); + } } static void @@ -535,6 +553,15 @@ on_server_error(struct http_conn *conn, enum http_conn_error err, void *arg) server_free(server); } +static void +on_server_continuation(struct http_conn *conn, void *arg) +{ + struct server *server = arg; + + log_debug("proxy: got 100 continue from server"); + client_start_reading_request_body(server->client, 1); +} + static void on_server_response(struct http_conn *conn, struct http_response *resp, void *arg) @@ -548,10 +575,7 @@ on_server_response(struct http_conn *conn, struct http_response *resp, server, server->client); - // XXX maybe not read body on error? - // XXX handle expect 100-continue, etc - - client_start_reading_request_body(server->client); + // XXX maybe stop reading body on error? http_response_free(resp); }