Skip to content

Commit

Permalink
Require GET in Upgrader.Upgrade.
Browse files Browse the repository at this point in the history
Return error if the request method is not GET.

Remove all request method tests from the examples.
  • Loading branch information
garyburd committed Nov 2, 2015
1 parent a4e0143 commit 567453a
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 18 deletions.
25 changes: 20 additions & 5 deletions client_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,6 @@ func newTLSServer(t *testing.T) *cstServer {
}

func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
t.Logf("method %s not allowed", r.Method)
http.Error(w, "method not allowed", 405)
return
}
subprotos := Subprotocols(r)
if !reflect.DeepEqual(subprotos, cstDialer.Subprotocols) {
t.Logf("subprotols=%v, want %v", subprotos, cstDialer.Subprotocols)
Expand Down Expand Up @@ -287,6 +282,26 @@ func TestDialBadHeader(t *testing.T) {
}
}

func TestBadMethod(t *testing.T) {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ws, err := cstUpgrader.Upgrade(w, r, nil)
if err == nil {
t.Errorf("handshake succeeded, expect fail")
ws.Close()
}
}))
defer s.Close()

resp, err := http.PostForm(s.URL, url.Values{})
if err != nil {
t.Fatalf("PostForm returned error %v", err)
}
resp.Body.Close()
if resp.StatusCode != http.StatusMethodNotAllowed {
t.Errorf("Status = %d, want %d", resp.StatusCode, http.StatusMethodNotAllowed)
}
}

func TestHandshake(t *testing.T) {
s := newServer(t)
defer s.Close()
Expand Down
4 changes: 0 additions & 4 deletions examples/chat/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,6 @@ func (c *connection) writePump() {

// serveWs handles websocket requests from the peer.
func serveWs(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
http.Error(w, "Method not allowed", 405)
return
}
ws, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Println(err)
Expand Down
5 changes: 0 additions & 5 deletions examples/command/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,6 @@ func internalError(ws *websocket.Conn, msg string, err error) {
var upgrader = websocket.Upgrader{}

func serveWs(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}

ws, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Println("upgrade:", err)
Expand Down
4 changes: 0 additions & 4 deletions examples/echo/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,6 @@ func echo(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Not found", 404)
return
}
if r.Method != "GET" {
http.Error(w, "Method not allowed", 405)
return
}
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Print("upgrade:", err)
Expand Down
3 changes: 3 additions & 0 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header
// request. Use the responseHeader to specify cookies (Set-Cookie) and the
// application negotiated subprotocol (Sec-Websocket-Protocol).
func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {
if r.Method != "GET" {
return u.returnError(w, r, http.StatusMethodNotAllowed, "websocket: method not GET")
}
if values := r.Header["Sec-Websocket-Version"]; len(values) == 0 || values[0] != "13" {
return u.returnError(w, r, http.StatusBadRequest, "websocket: version != 13")
}
Expand Down

0 comments on commit 567453a

Please sign in to comment.