diff --git a/mux_test.go b/mux_test.go index b078f5f8..bac758bc 100644 --- a/mux_test.go +++ b/mux_test.go @@ -1769,7 +1769,70 @@ func TestPanicOnCapturingGroups(t *testing.T) { } func TestRouterInContext(t *testing.T) { - // TODO Write tests for router in context + router := NewRouter() + router.HandleFunc("/r1", func(w http.ResponseWriter, r *http.Request) { + contextRouter := CurrentRouter(r) + if contextRouter == nil { + t.Fatal("Router not found in context") + return + } + + route := contextRouter.Get("r2") + if route == nil { + t.Fatal("Route with name not found") + return + } + + url, err := route.URL() + if err != nil { + t.Fatal("Error while getting url for r2: ", err) + return + } + + _, err = w.Write([]byte(url.String())) + if err != nil { + t.Fatalf("Failed writing HTTP response: %v", err) + } + }).Name("r1") + + noRouterMsg := []byte("no-router") + haveRouterMsg := []byte("have-router") + router.HandleFunc("/r2", func(w http.ResponseWriter, r *http.Request) { + var msg []byte + + contextRouter := CurrentRouter(r) + if contextRouter == nil { + msg = noRouterMsg + } else { + msg = haveRouterMsg + } + + _, err := w.Write(msg) + if err != nil { + t.Fatalf("Failed writing HTTP response: %v", err) + } + }).Name("r2") + + t.Run("router in request context get route by name", func(t *testing.T) { + rw := NewRecorder() + req := newRequest("GET", "/r1") + + router.ServeHTTP(rw, req) + if !bytes.Equal(rw.Body.Bytes(), []byte("/r2")) { + t.Fatalf("Expected output to be '/r1' but got '%s'", rw.Body.String()) + } + }) + + t.Run("omit router from request context", func(t *testing.T) { + rw := NewRecorder() + req := newRequest("GET", "/r2") + + router.OmitRouterFromContext(true) + router.ServeHTTP(rw, req) + if !bytes.Equal(rw.Body.Bytes(), noRouterMsg) { + t.Fatal("Router not omitted from context") + } + }) } // ----------------------------------------------------------------------------