Skip to content

Commit

Permalink
A few simple test case to get us started and run them in CI (#52)
Browse files Browse the repository at this point in the history
  • Loading branch information
hughns authored Jan 20, 2025
1 parent 673d744 commit 4d9574b
Show file tree
Hide file tree
Showing 5 changed files with 260 additions and 5 deletions.
21 changes: 21 additions & 0 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
name: Test

on:
pull_request:
push:
branches: [main]

jobs:
test:
name: Testing
runs-on: ubuntu-latest
permissions:
contents: read
steps:
- uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version-file: go.mod
- name: Test
run: go test -timeout 30s
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
module ec-lms
module lk-jwt-service

go 1.23

Expand Down Expand Up @@ -31,6 +31,7 @@ require (
github.com/go-jose/go-jose/v3 v3.0.3 // indirect
github.com/go-logr/logr v1.4.2 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/golang-jwt/jwt/v5 v5.2.1 // indirect
github.com/golang/protobuf v1.5.4 // indirect
github.com/google/cel-go v0.21.0 // indirect
github.com/google/uuid v1.6.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ4
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.5/go.mod h1:6O5/vntMXwX2lRkT1hjjk0nAC1IDOTvTlVgjlRvqsdk=
github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8=
Expand Down
17 changes: 13 additions & 4 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ func (h *Handler) handle(w http.ResponseWriter, r *http.Request) {
return
}

// TODO: we should be sanitising the input here before using it
// e.g. only allowing `https://` URL scheme
userInfo, err := exchangeOIDCToken(r.Context(), body.OpenIDToken, h.skipVerifyTLS)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
Expand All @@ -145,6 +147,7 @@ func (h *Handler) handle(w http.ResponseWriter, r *http.Request) {

log.Printf("Got user info for %s", userInfo.Sub)

// TODO: is DeviceID required? If so then we should have validated at the start of the request processing
token, err := getJoinToken(h.key, h.secret, body.Room, userInfo.Sub+":"+body.DeviceID)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
Expand All @@ -170,6 +173,15 @@ func (h *Handler) handle(w http.ResponseWriter, r *http.Request) {
}
}

func (h *Handler) prepareMux() (*http.ServeMux) {

mux := http.NewServeMux()
mux.HandleFunc("/sfu/get", h.handle)
mux.HandleFunc("/healthz", h.healthcheck)

return mux
}

func main() {
skipVerifyTLS := os.Getenv("LIVEKIT_INSECURE_SKIP_VERIFY_TLS") == "YES_I_KNOW_WHAT_I_AM_DOING"
if skipVerifyTLS {
Expand Down Expand Up @@ -203,10 +215,7 @@ func main() {
skipVerifyTLS: skipVerifyTLS,
}

http.HandleFunc("/sfu/get", handler.handle)
http.HandleFunc("/healthz", handler.healthcheck)

log.Fatal(http.ListenAndServe(fmt.Sprintf(":%s", lk_jwt_port), nil))
log.Fatal(http.ListenAndServe(fmt.Sprintf(":%s", lk_jwt_port), handler.prepareMux()))
}

func getJoinToken(apiKey, apiSecret, room, identity string) (string, error) {
Expand Down
222 changes: 222 additions & 0 deletions main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
// Copyright 2025 New Vector Ltd

// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.

// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.

// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

package main

import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"testing"

"github.com/golang-jwt/jwt/v5"
"github.com/matrix-org/gomatrix"
)

func TestHealthcheck(t *testing.T) {
handler := &Handler{}
req, err := http.NewRequest("GET", "/healthz", nil)
if err != nil {
t.Fatal(err)
}

rr := httptest.NewRecorder()
handler.prepareMux().ServeHTTP(rr, req)

if status := rr.Code; status != http.StatusOK {
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusOK)
}
}

func TestHandleOptions(t *testing.T) {
handler := &Handler{}
req, err := http.NewRequest("OPTIONS", "/sfu/get", nil)
if err != nil {
t.Fatal(err)
}

rr := httptest.NewRecorder()
handler.prepareMux().ServeHTTP(rr, req)

if status := rr.Code; status != http.StatusOK {
t.Errorf("handler returned wrong status code for OPTIONS: got %v want %v", status, http.StatusOK)
}

if accessControlAllowOrigin := rr.Header().Get("Access-Control-Allow-Origin"); accessControlAllowOrigin != "*" {
t.Errorf("handler returned wrong Access-Control-Allow-Origin: got %v want %v", accessControlAllowOrigin, "*")
}

if accessControlAllowMethods := rr.Header().Get("Access-Control-Allow-Methods"); accessControlAllowMethods != "POST" {
t.Errorf("handler returned wrong Access-Control-Allow-Methods: got %v want %v", accessControlAllowMethods, "POST")
}
}

func TestHandlePostMissingParams(t *testing.T) {
handler := &Handler{}

testCases := []map[string]interface{} {
{},
{
"room": "",
},
}

for _, testCase := range testCases {
jsonBody, _ := json.Marshal(testCase)

req, err := http.NewRequest("POST", "/sfu/get", bytes.NewBuffer(jsonBody))
if err != nil {
t.Fatal(err)
}

rr := httptest.NewRecorder()
handler.prepareMux().ServeHTTP(rr, req)

if status := rr.Code; status != http.StatusBadRequest {
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusBadRequest)
}

var resp gomatrix.RespError
err = json.NewDecoder(rr.Body).Decode(&resp)
if err != nil {
t.Errorf("failed to decode response body %v", err)
}

if resp.ErrCode != "M_BAD_JSON" {
t.Errorf("unexpected error code: got %v want %v", resp.ErrCode, "M_BAD_JSON")
}
}
}

func TestHandlePost(t *testing.T) {
handler := &Handler{
secret: "testSecret",
key: "testKey",
lk_url: "wss://lk.local:8080/foo",
skipVerifyTLS: true,
}

var matrixServerName = ""

testServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Log("Received request")
// Inspect the request
if r.URL.Path != "/_matrix/federation/v1/openid/userinfo" {
t.Errorf("unexpected request path: got %v want %v", r.URL.Path, "/_matrix/federation/v1/openid/userinfo")
}

if accessToken := r.URL.Query().Get("access_token"); accessToken != "testAccessToken" {
t.Errorf("unexpected access token: got %v want %v", accessToken, "testAccessToken")
}

// Mock response
w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", "application/json")
_, err := w.Write([]byte(fmt.Sprintf(`{"sub": "@user:%s"}`, matrixServerName)))
if err != nil {
t.Fatalf("failed to write response: %v", err)
}
}))
defer testServer.Close()

u, _ := url.Parse(testServer.URL)

matrixServerName = u.Host

testCase := map[string]interface{} {
"room": "testRoom",
"openid_token": map[string]interface{} {
"access_token": "testAccessToken",
"token_type": "testTokenType",
"matrix_server_name": u.Host,
},
"device_id": "testDevice",
}

jsonBody, _ := json.Marshal(testCase)

req, err := http.NewRequest("POST", "/sfu/get", bytes.NewBuffer(jsonBody))
if err != nil {
t.Fatal(err)
}

rr := httptest.NewRecorder()
handler.prepareMux().ServeHTTP(rr, req)

if status := rr.Code; status != http.StatusOK {
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusOK)
}

if contentType := rr.Header().Get("Content-Type"); contentType != "application/json" {
t.Errorf("handler returned wrong Content-Type: got %v want %v", contentType, "application/json")
}

var resp SFUResponse
err = json.NewDecoder(rr.Body).Decode(&resp)
if err != nil {
t.Errorf("failed to decode response body %v", err)
}

if resp.URL != "wss://lk.local:8080/foo" {
t.Errorf("unexpected URL: got %v want %v", resp.URL, "wss://lk.local:8080/foo")
}

if resp.JWT == "" {
t.Error("expected JWT to be non-empty")
}

// parse JWT checking the shared secret
token, err := jwt.Parse(resp.JWT, func(token *jwt.Token) (interface{}, error) {
return []byte(handler.secret), nil
})

if err != nil {
t.Fatalf("failed to parse JWT: %v", err)
}

claims, ok := token.Claims.(jwt.MapClaims)

if !ok || !token.Valid {
t.Fatalf("failed to parse claims from JWT: %v", err)
}

if claims["sub"] != "@user:"+matrixServerName+":testDevice" {
t.Errorf("unexpected sub: got %v want %v", claims["sub"], "@user:"+matrixServerName+":testDevice")
}

// should have permission for the room
if claims["video"].(map[string]interface{})["room"] != "testRoom" {
t.Errorf("unexpected room: got %v want %v", claims["room"], "testRoom")
}
}

func TestGetJoinToken(t *testing.T) {
apiKey := "testKey"
apiSecret := "testSecret"
room := "testRoom"
identity := "[email protected]"

token, err := getJoinToken(apiKey, apiSecret, room, identity)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

if token == "" {
t.Error("expected token to be non-empty")
}
}

0 comments on commit 4d9574b

Please sign in to comment.