diff --git a/cmd/main.go b/cmd/main.go index 31c83b2..3fb3bda 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -9,14 +9,17 @@ import ( "log/slog" "os" "os/signal" + "runtime" "syscall" + "time" "github.com/absmach/mproxy" "github.com/absmach/mproxy/examples/simple" - "github.com/absmach/mproxy/pkg/http" - "github.com/absmach/mproxy/pkg/mqtt" - "github.com/absmach/mproxy/pkg/mqtt/websocket" - "github.com/absmach/mproxy/pkg/session" + "github.com/absmach/mproxy/forwarder" + "github.com/absmach/mproxy/http" + "github.com/absmach/mproxy/mqtt" + "github.com/absmach/mproxy/mqtt/websocket" + "github.com/absmach/mproxy/streamer" "github.com/caarlos0/env/v11" "github.com/joho/godotenv" "golang.org/x/sync/errgroup" @@ -37,6 +40,18 @@ const ( ) func main() { + go func() { + for { + time.Sleep(time.Second * 3) + fmt.Println("RTN", runtime.NumGoroutine()) + var m runtime.MemStats + runtime.ReadMemStats(&m) + fmt.Printf("Alloc = %v MiB", m.Alloc/1024/1024) + fmt.Printf("\tTotalAlloc = %v MiB", m.TotalAlloc/1024/1024) + fmt.Printf("\tSys = %v MiB", m.Sys/1024/1024) + fmt.Printf("\tNumGC = %v\n", m.NumGC) + } + }() ctx, cancel := context.WithCancel(context.Background()) g, ctx := errgroup.WithContext(ctx) @@ -47,7 +62,7 @@ func main() { handler := simple.New(logger) - var interceptor session.Interceptor + var interceptor mproxy.Interceptor // Loading .env file to environment err := godotenv.Load() @@ -62,9 +77,10 @@ func main() { } // mProxy server for MQTT without TLS - mqttProxy := mqtt.New(mqttConfig, handler, interceptor, logger) + mqttProxy := mqtt.New(handler, interceptor) + g.Go(func() error { - return mqttProxy.Listen(ctx) + return streamer.Listen(ctx, "MQTT", mqttConfig, mqttProxy, logger) }) // mProxy server Configuration for MQTT with TLS @@ -74,9 +90,9 @@ func main() { } // mProxy server for MQTT with TLS - mqttTLSProxy := mqtt.New(mqttTLSConfig, handler, interceptor, logger) + mqttTLSProxy := mqtt.New(handler, interceptor) g.Go(func() error { - return mqttTLSProxy.Listen(ctx) + return streamer.Listen(ctx, "MQTT", mqttTLSConfig, mqttTLSProxy, logger) }) // mProxy server Configuration for MQTT with mTLS @@ -86,9 +102,9 @@ func main() { } // mProxy server for MQTT with mTLS - mqttMTlsProxy := mqtt.New(mqttMTLSConfig, handler, interceptor, logger) + mqttMTlsProxy := mqtt.New(handler, interceptor) g.Go(func() error { - return mqttMTlsProxy.Listen(ctx) + return streamer.Listen(ctx, "MQTT", mqttMTLSConfig, mqttMTlsProxy, logger) }) // mProxy server Configuration for MQTT over Websocket without TLS @@ -98,9 +114,9 @@ func main() { } // mProxy server for MQTT over Websocket without TLS - wsProxy := websocket.New(wsConfig, handler, interceptor, logger) + wsProxy := websocket.New(wsConfig.Target, handler, interceptor, logger) g.Go(func() error { - return wsProxy.Listen(ctx) + return forwarder.Listen(ctx, "WS", wsConfig, wsProxy, logger) }) // mProxy server Configuration for MQTT over Websocket with TLS @@ -110,9 +126,9 @@ func main() { } // mProxy server for MQTT over Websocket with TLS - wsTLSProxy := websocket.New(wsTLSConfig, handler, interceptor, logger) + wsTLSProxy := websocket.New(wsTLSConfig.Target, handler, interceptor, logger) g.Go(func() error { - return wsTLSProxy.Listen(ctx) + return forwarder.Listen(ctx, "WS", wsTLSConfig, wsTLSProxy, logger) }) // mProxy server Configuration for MQTT over Websocket with mTLS @@ -122,9 +138,9 @@ func main() { } // mProxy server for MQTT over Websocket with mTLS - wsMTLSProxy := websocket.New(wsMTLSConfig, handler, interceptor, logger) + wsMTLSProxy := websocket.New(wsMTLSConfig.Target, handler, interceptor, logger) g.Go(func() error { - return wsMTLSProxy.Listen(ctx) + return forwarder.Listen(ctx, "WS", wsMTLSConfig, wsMTLSProxy, logger) }) // mProxy server Configuration for HTTP without TLS @@ -139,7 +155,7 @@ func main() { panic(err) } g.Go(func() error { - return httpProxy.Listen(ctx) + return forwarder.Listen(ctx, "HTTP", httpConfig, httpProxy, logger) }) // mProxy server Configuration for HTTP with TLS @@ -154,7 +170,7 @@ func main() { panic(err) } g.Go(func() error { - return httpTLSProxy.Listen(ctx) + return forwarder.Listen(ctx, "HTTP", httpTLSConfig, httpTLSProxy, logger) }) // mProxy server Configuration for HTTP with mTLS @@ -169,7 +185,7 @@ func main() { panic(err) } g.Go(func() error { - return httpMTLSProxy.Listen(ctx) + return forwarder.Listen(ctx, "HTTP", httpMTLSConfig, httpMTLSProxy, logger) }) g.Go(func() error { diff --git a/coap/streamer.go b/coap/streamer.go new file mode 100644 index 0000000..451e510 --- /dev/null +++ b/coap/streamer.go @@ -0,0 +1,95 @@ +package main + +import ( + "bytes" + "context" + "crypto/tls" + "fmt" + "log" + + piondtls "github.com/pion/dtls/v2" + coap "github.com/plgd-dev/go-coap/v3" + "github.com/plgd-dev/go-coap/v3/message" + "github.com/plgd-dev/go-coap/v3/message/codes" + "github.com/plgd-dev/go-coap/v3/mux" + "github.com/plgd-dev/go-coap/v3/options" + + dtlsServer "github.com/plgd-dev/go-coap/v3/dtls/server" + tcpServer "github.com/plgd-dev/go-coap/v3/tcp/server" + udpClient "github.com/plgd-dev/go-coap/v3/udp/client" +) + +func handleA(w mux.ResponseWriter, r *mux.Message) { + log.Printf("got message in handleA: %+v from %v\n", r, w.Conn().RemoteAddr()) + err := w.SetResponse(codes.GET, message.TextPlain, bytes.NewReader([]byte("A hello world"))) + if err != nil { + log.Printf("cannot set response: %v", err) + } +} + +func handleB(w mux.ResponseWriter, r *mux.Message) { + log.Printf("got message in handleB: %+v from %v\n", r, w.Conn().RemoteAddr()) + customResp := w.Conn().AcquireMessage(r.Context()) + defer w.Conn().ReleaseMessage(customResp) + customResp.SetCode(codes.Content) + customResp.SetToken(r.Token()) + customResp.SetContentFormat(message.TextPlain) + customResp.SetBody(bytes.NewReader([]byte("B hello world"))) + err := w.Conn().WriteMessage(customResp) + if err != nil { + log.Printf("cannot set response: %v", err) + } +} + +func handleOnNewConn(cc *udpClient.Conn) { + dtlsConn, ok := cc.NetConn().(*piondtls.Conn) + if !ok { + log.Fatalf("invalid type %T", cc.NetConn()) + } + clientId := dtlsConn.ConnectionState().IdentityHint + cc.SetContextValue("clientId", clientId) + cc.AddOnClose(func() { + clientId := dtlsConn.ConnectionState().IdentityHint + log.Printf("closed connection clientId: %s", clientId) + }) +} + +func main() { + m := mux.NewRouter() + m.Handle("/a", mux.HandlerFunc(handleA)) + m.Handle("/b", mux.HandlerFunc(handleB)) + + tcpOpts := []tcpServer.Option{} + tcpOpts = append(tcpOpts, + options.WithMux(m), + options.WithContext(context.Background())) + + dtlsOpts := []dtlsServer.Option{} + dtlsOpts = append(dtlsOpts, + options.WithMux(m), + options.WithContext(context.Background()), + options.WithOnNewConn(handleOnNewConn), + ) + + go func() { + // serve a tcp server on :5686 + log.Fatal(coap.ListenAndServeWithOptions("tcp", ":5686", tcpOpts)) + }() + + go func() { + // serve a tls tcp server on :5687 + log.Fatal(coap.ListenAndServeTCPTLSWithOptions("tcp", "5687", &tls.Config{}, tcpOpts...)) + }() + + go func() { + // serve a udp dtls server on :5688 + log.Fatal(coap.ListenAndServeDTLSWithOptions("udp", ":5688", &piondtls.Config{ + PSK: func(hint []byte) ([]byte, error) { + fmt.Printf("Client's hint: %s \n", hint) + return []byte{0xAB, 0xC1, 0x23}, nil + }, + PSKIdentityHint: []byte("Pion DTLS Client"), + CipherSuites: []piondtls.CipherSuiteID{piondtls.TLS_PSK_WITH_AES_128_CCM_8}, + }, dtlsOpts...)) + }() +} diff --git a/config.go b/config.go index a0fda0f..127146c 100644 --- a/config.go +++ b/config.go @@ -6,7 +6,7 @@ package mproxy import ( "crypto/tls" - mptls "github.com/absmach/mproxy/pkg/tls" + mptls "github.com/absmach/mproxy/tls" "github.com/caarlos0/env/v11" ) diff --git a/examples/simple/simple.go b/examples/simple/simple.go index 51b887a..bf5aa62 100644 --- a/examples/simple/simple.go +++ b/examples/simple/simple.go @@ -8,6 +8,7 @@ import ( "errors" "log/slog" + "github.com/absmach/mproxy" "github.com/absmach/mproxy/pkg/session" ) @@ -71,7 +72,7 @@ func (h *Handler) Disconnect(ctx context.Context) error { } func (h *Handler) logAction(ctx context.Context, action string, topics *[]string, payload *[]byte) error { - s, ok := session.FromContext(ctx) + s, ok := mproxy.FromContext(ctx) args := []interface{}{ slog.Group("session", slog.String("id", s.ID), slog.String("username", s.Username)), } diff --git a/forwarder/proxy.go b/forwarder/proxy.go new file mode 100644 index 0000000..aeb390c --- /dev/null +++ b/forwarder/proxy.go @@ -0,0 +1,53 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package forwarder + +import ( + "context" + "crypto/tls" + "fmt" + "log/slog" + "net" + "net/http" + + "github.com/absmach/mproxy" + mptls "github.com/absmach/mproxy/tls" + "golang.org/x/sync/errgroup" +) + +func Listen(ctx context.Context, name string, config mproxy.Config, passer mproxy.Forwarder, logger *slog.Logger) error { + l, err := net.Listen("tcp", config.Address) + if err != nil { + return err + } + + if config.TLSConfig != nil { + l = tls.NewListener(l, config.TLSConfig) + } + status := mptls.SecurityStatus(config.TLSConfig) + + logger.Info(fmt.Sprintf("%s Proxy server started at %s%s with %s", name, config.Address, config.PathPrefix, status)) + + var server http.Server + g, ctx := errgroup.WithContext(ctx) + + mux := http.NewServeMux() + mux.HandleFunc(config.PathPrefix, passer.Forward) + server.Handler = mux + + g.Go(func() error { + return server.Serve(l) + }) + + g.Go(func() error { + <-ctx.Done() + return server.Close() + }) + if err := g.Wait(); err != nil { + logger.Info(fmt.Sprintf("%s Proxy server at %s%s with %s exiting with errors", name, config.Address, config.PathPrefix, status), slog.String("error", err.Error())) + } else { + logger.Info(fmt.Sprintf("%s Proxy server at %s%s with %s exiting...", name, config.Address, config.PathPrefix, status)) + } + return nil +} diff --git a/go.mod b/go.mod index a6b3b3e..a28fc60 100644 --- a/go.mod +++ b/go.mod @@ -5,13 +5,26 @@ go 1.21 toolchain go1.21.4 require ( - github.com/caarlos0/env/v11 v11.0.0 - github.com/eclipse/paho.mqtt.golang v1.4.3 + github.com/caarlos0/env/v11 v11.2.0 + github.com/eclipse/paho.mqtt.golang v1.5.0 github.com/google/uuid v1.6.0 - github.com/gorilla/websocket v1.5.1 + github.com/gorilla/websocket v1.5.3 + github.com/inetaf/tcpproxy v0.0.0-20240214030015-3ce58045626c github.com/joho/godotenv v1.5.1 - golang.org/x/crypto v0.22.0 - golang.org/x/sync v0.7.0 + github.com/pion/dtls/v2 v2.2.8-0.20240501061905-2c36d63320a0 + github.com/plgd-dev/go-coap/v3 v3.3.4 + golang.org/x/crypto v0.25.0 + golang.org/x/sync v0.8.0 ) -require golang.org/x/net v0.24.0 // indirect +require ( + github.com/dsnet/golib/memfile v1.0.0 // indirect + github.com/hashicorp/errwrap v1.1.0 // indirect + github.com/hashicorp/go-multierror v1.1.1 // indirect + github.com/pion/logging v0.2.2 // indirect + github.com/pion/transport/v3 v3.0.2 // indirect + go.uber.org/atomic v1.11.0 // indirect + golang.org/x/exp v0.0.0-20240416160154-fe59bbe5cc7f // indirect + golang.org/x/net v0.27.0 // indirect + golang.org/x/sys v0.22.0 // indirect +) diff --git a/go.sum b/go.sum index aa55fb1..fa5bcab 100644 --- a/go.sum +++ b/go.sum @@ -1,16 +1,101 @@ -github.com/caarlos0/env/v11 v11.0.0 h1:ZIlkOjuL3xoZS0kmUJlF74j2Qj8GMOq3CDLX/Viak8Q= -github.com/caarlos0/env/v11 v11.0.0/go.mod h1:2RC3HQu8BQqtEK3V4iHPxj0jOdWdbPpWJ6pOueeU1xM= -github.com/eclipse/paho.mqtt.golang v1.4.3 h1:2kwcUGn8seMUfWndX0hGbvH8r7crgcJguQNCyp70xik= -github.com/eclipse/paho.mqtt.golang v1.4.3/go.mod h1:CSYvoAlsMkhYOXh/oKyxa8EcBci6dVkLCbo5tTC1RIE= +github.com/armon/go-proxyproto v0.0.0-20210323213023-7e956b284f0a/go.mod h1:QmP9hvJ91BbJmGVGSbutW19IC0Q9phDCLGaomwTJbgU= +github.com/caarlos0/env/v11 v11.2.0 h1:kvB1ZmwdWgI3JsuuVUE7z4cY/6Ujr03D0w2WkOOH4Xs= +github.com/caarlos0/env/v11 v11.2.0/go.mod h1:LwgkYk1kDvfGpHthrWWLof3Ny7PezzFwS4QrsJdHTMo= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dsnet/golib/memfile v1.0.0 h1:J9pUspY2bDCbF9o+YGwcf3uG6MdyITfh/Fk3/CaEiFs= +github.com/dsnet/golib/memfile v1.0.0/go.mod h1:tXGNW9q3RwvWt1VV2qrRKlSSz0npnh12yftCSCy2T64= +github.com/eclipse/paho.mqtt.golang v1.5.0 h1:EH+bUVJNgttidWFkLLVKaQPGmkTUfQQqjOsyvMGvD6o= +github.com/eclipse/paho.mqtt.golang v1.5.0/go.mod h1:du/2qNQVqJf/Sqs4MEL77kR8QTqANF7XU7Fk0aOTAgk= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= -github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= +github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= +github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= +github.com/inetaf/tcpproxy v0.0.0-20240214030015-3ce58045626c h1:gYfYE403/nlrGNYj6BEOs9ucLCAGB9gstlSk92DttTg= +github.com/inetaf/tcpproxy v0.0.0-20240214030015-3ce58045626c/go.mod h1:Di7LXRyUcnvAcLicFhtM9/MlZl/TNgRSDHORM2c6CMI= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= -golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30= -golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= -golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w= -golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8= -golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +github.com/pion/dtls/v2 v2.2.8-0.20240501061905-2c36d63320a0 h1:050ahk2K4HqwxPi2YM6Yc4lIttwNSY2+n9xPVsS3zoQ= +github.com/pion/dtls/v2 v2.2.8-0.20240501061905-2c36d63320a0/go.mod h1:tjBBbkwKGSQQZl36HQa2va5HqR9rWhujhlJMrgE2b/o= +github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY= +github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= +github.com/pion/transport/v3 v3.0.2 h1:r+40RJR25S9w3jbA6/5uEPTzcdn7ncyU44RWCbHkLg4= +github.com/pion/transport/v3 v3.0.2/go.mod h1:nIToODoOlb5If2jF9y2Igfx3PFYWfuXi37m0IlWa/D0= +github.com/plgd-dev/go-coap/v3 v3.3.4 h1:clDLFOXXmXfhZqB0eSk6WJs2iYfjC2J22Ixwu5MHiO0= +github.com/plgd-dev/go-coap/v3 v3.3.4/go.mod h1:vxBvAgXxL+Au/58XYTM+8ftqO/ycFC9/Dh+uI72xYjA= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= +golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= +golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= +golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= +golang.org/x/exp v0.0.0-20240416160154-fe59bbe5cc7f h1:99ci1mjWVBWwJiEKYY6jWa4d2nTQVIEhZIptnrVb1XY= +golang.org/x/exp v0.0.0-20240416160154-fe59bbe5cc7f/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= +golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= +golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= +golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= +golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= +golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= +golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/http/forwarder.go b/http/forwarder.go new file mode 100644 index 0000000..e2d8fec --- /dev/null +++ b/http/forwarder.go @@ -0,0 +1,109 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package http + +import ( + "bytes" + "encoding/json" + "errors" + "io" + "log/slog" + "net/http" + "net/http/httputil" + "net/url" + "strings" + + "github.com/absmach/mproxy" +) + +const contentType = "application/json" + +// ErrMissingAuthentication returned when no basic or Authorization header is set. +var ErrMissingAuthentication = errors.New("missing authorization") + +// proxy represents HTTP proxy. +type proxy struct { + config mproxy.Config + target *httputil.ReverseProxy + session mproxy.Handler + logger *slog.Logger +} + +func NewProxy(config mproxy.Config, handler mproxy.Handler, logger *slog.Logger) (mproxy.Forwarder, error) { + target, err := url.Parse(config.Target) + if err != nil { + return proxy{}, err + } + + return proxy{ + config: config, + target: httputil.NewSingleHostReverseProxy(target), + session: handler, + logger: logger, + }, nil +} + +func (p proxy) Forward(w http.ResponseWriter, r *http.Request) { + // Metrics and health endpoints are served directly. + if r.URL.Path == "/metrics" || r.URL.Path == "/health" { + p.target.ServeHTTP(w, r) + return + } + + if !strings.HasPrefix(r.URL.Path, p.config.PathPrefix) { + http.NotFound(w, r) + return + } + + username, password, ok := r.BasicAuth() + switch { + case ok: + break + case r.Header.Get("Authorization") != "": + password = r.Header.Get("Authorization") + default: + encodeError(w, http.StatusBadGateway, ErrMissingAuthentication) + return + } + + s := &mproxy.Session{ + Password: []byte(password), + Username: username, + } + ctx := mproxy.NewContext(r.Context(), s) + payload, err := io.ReadAll(r.Body) + if err != nil { + encodeError(w, http.StatusBadRequest, err) + p.logger.Error("Failed to read body", slog.Any("error", err)) + return + } + if err := r.Body.Close(); err != nil { + encodeError(w, http.StatusInternalServerError, err) + p.logger.Error("Failed to close body", slog.Any("error", err)) + return + } + + // r.Body is reset to ensure it can be safely copied by httputil.ReverseProxy. + // no close method is required since NopClose Close() always returns nill. + r.Body = io.NopCloser(bytes.NewBuffer(payload)) + if err := p.session.AuthConnect(ctx); err != nil { + encodeError(w, http.StatusUnauthorized, err) + p.logger.Error("Failed to authorize connect", slog.Any("error", err)) + return + } + if err := p.session.Publish(ctx, &r.RequestURI, &payload); err != nil { + encodeError(w, http.StatusBadRequest, err) + p.logger.Error("Failed to publish", slog.Any("error", err)) + return + } + p.target.ServeHTTP(w, r) +} + +func encodeError(w http.ResponseWriter, statusCode int, err error) { + w.WriteHeader(statusCode) + w.Header().Set("Content-Type", contentType) + if err := json.NewEncoder(w).Encode(err); err != nil { + w.WriteHeader(http.StatusInternalServerError) + } +} diff --git a/mproxy.go b/mproxy.go new file mode 100644 index 0000000..a73b3c2 --- /dev/null +++ b/mproxy.go @@ -0,0 +1,59 @@ +package mproxy + +import ( + "context" + "net" + "net/http" +) + +// Handler is an interface for mProxy hooks. +type Handler interface { + // Authorization on client `CONNECT` + // Each of the params are passed by reference, so that it can be changed + AuthConnect(ctx context.Context) error + + // Authorization on client `PUBLISH` + // Topic is passed by reference, so that it can be modified + AuthPublish(ctx context.Context, topic *string, payload *[]byte) error + + // Authorization on client `SUBSCRIBE` + // Topics are passed by reference, so that they can be modified + AuthSubscribe(ctx context.Context, topics *[]string) error + + // After client successfully connected + Connect(ctx context.Context) error + + // After client successfully published + Publish(ctx context.Context, topic *string, payload *[]byte) error + + // After client successfully subscribed + Subscribe(ctx context.Context, topics *[]string) error + + // After client unsubscribed + Unsubscribe(ctx context.Context, topics *[]string) error + + // Disconnect on connection with client lost + Disconnect(ctx context.Context) error +} + +// Interceptor is an interface for mProxy intercept hook. +type Interceptor interface { + // Intercept is called on every packet flowing through the mProxy. + // Packets can be modified before being sent to the broker or the client. + // The error indicates unsuccessful interception and mProxy is cancelling the packet. + Intercept(ctx context.Context, pkt interface{}) error +} + +// Streamer is used for streaming traffic. +type Streamer interface { + // Stream streams the traffic between conn1 and conn2 in any direction (or both) + // providing Handler and Interceptos. + Stream(ctx context.Context, conn1, conn2 net.Conn) error +} + +// Forwarder is used for request-response protocols. +type Forwarder interface { + // Forward forwards the HTTP request and response for HTTP and + // WS based protocols. + Forward(rw http.ResponseWriter, r *http.Request) +} diff --git a/mqtt/streamer.go b/mqtt/streamer.go new file mode 100644 index 0000000..27bd199 --- /dev/null +++ b/mqtt/streamer.go @@ -0,0 +1,176 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package mqtt + +import ( + "context" + "errors" + "fmt" + "io" + "net" + + "github.com/absmach/mproxy" + "github.com/absmach/mproxy/tls" + "github.com/eclipse/paho.mqtt.golang/packets" +) + +type Direction int + +const ( + Up Direction = iota + Down +) + +const unknownID = "unknown" + +var ( + errBroker = "failed to proxy from MQTT client with id %s to MQTT broker with error: %s" + errClient = "failed to proxy from MQTT broker to client with id %s with error: %s" +) + +type streamer struct { + h mproxy.Handler + ic mproxy.Interceptor +} + +func New(h mproxy.Handler, ic mproxy.Interceptor) mproxy.Streamer { + return &streamer{ + h: h, + ic: ic, + } +} + +// Stream starts proxy between client and broker. +func (s *streamer) Stream(ctx context.Context, in, out net.Conn) error { + cert, err := tls.ClientCert(in) + if err != nil { + return err + } + session := mproxy.Session{ + Cert: cert, + } + ctx = mproxy.NewContext(ctx, &session) + errs := make(chan error, 2) + go stream(ctx, Up, in, out, s.h, s.ic, errs) + go stream(ctx, Down, out, in, s.h, s.ic, errs) + + // Handle whichever error happens first. + // The other routine won't be blocked when writing + // to the errors channel because it is buffered. + err = <-errs + + disconnectErr := s.h.Disconnect(ctx) + + return errors.Join(err, disconnectErr) +} + +func stream(ctx context.Context, dir Direction, r, w net.Conn, h mproxy.Handler, ic mproxy.Interceptor, errs chan error) { + for { + // Read from one connection. + pkt, err := packets.ReadPacket(r) + if err != nil { + errs <- wrap(ctx, err, dir) + return + } + + switch dir { + case Up: + if err = authorize(ctx, pkt, h); err != nil { + errs <- wrap(ctx, err, dir) + return + } + default: + if p, ok := pkt.(*packets.PublishPacket); ok { + if err = h.AuthPublish(ctx, &p.TopicName, &p.Payload); err != nil { + pkt = packets.NewControlPacket(packets.Disconnect).(*packets.DisconnectPacket) + err = pkt.Write(w) + errs <- wrap(ctx, err, dir) + return + } + } + } + + if ic != nil { + if err := ic.Intercept(ctx, &pkt); err != nil { + errs <- wrap(ctx, err, dir) + return + } + } + + // Send to another. + if err := pkt.Write(w); err != nil { + errs <- wrap(ctx, err, dir) + return + } + + // Notify only for packets sent from client to broker (incoming packets). + if dir == Up { + if err := notify(ctx, pkt, h); err != nil { + errs <- wrap(ctx, err, dir) + } + } + } +} + +func authorize(ctx context.Context, pkt packets.ControlPacket, h mproxy.Handler) error { + switch p := pkt.(type) { + case *packets.ConnectPacket: + s, ok := mproxy.FromContext(ctx) + if ok { + s.ID = p.ClientIdentifier + s.Username = p.Username + s.Password = p.Password + } + + ctx = mproxy.NewContext(ctx, s) + if err := h.AuthConnect(ctx); err != nil { + return err + } + // Copy back to the packet in case values are changed by Event handler. + // This is specific to CONN, as only that package type has credentials. + p.ClientIdentifier = s.ID + p.Username = s.Username + p.Password = s.Password + return nil + case *packets.PublishPacket: + return h.AuthPublish(ctx, &p.TopicName, &p.Payload) + case *packets.SubscribePacket: + return h.AuthSubscribe(ctx, &p.Topics) + default: + return nil + } +} + +func notify(ctx context.Context, pkt packets.ControlPacket, h mproxy.Handler) error { + switch p := pkt.(type) { + case *packets.ConnectPacket: + return h.Connect(ctx) + case *packets.PublishPacket: + return h.Publish(ctx, &p.TopicName, &p.Payload) + case *packets.SubscribePacket: + return h.Subscribe(ctx, &p.Topics) + case *packets.UnsubscribePacket: + return h.Unsubscribe(ctx, &p.Topics) + default: + return nil + } +} + +func wrap(ctx context.Context, err error, dir Direction) error { + if err == io.EOF { + return err + } + cid := unknownID + if s, ok := mproxy.FromContext(ctx); ok { + cid = s.ID + } + switch dir { + case Up: + return fmt.Errorf(errClient, cid, err.Error()) + case Down: + return fmt.Errorf(errBroker, cid, err.Error()) + default: + return err + } +} diff --git a/mqtt/websocket/conn.go b/mqtt/websocket/conn.go new file mode 100644 index 0000000..7852018 --- /dev/null +++ b/mqtt/websocket/conn.go @@ -0,0 +1,78 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package websocket + +import ( + "io" + "net" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +// wsWrapper is a websocket wrapper so it satisfies the net.Conn interface. +type wsWrapper struct { + *websocket.Conn + r io.Reader + rio sync.Mutex + wio sync.Mutex +} + +func newConn(ws *websocket.Conn) net.Conn { + return &wsWrapper{ + Conn: ws, + } +} + +// SetDeadline sets both the read and write deadlines. +func (c *wsWrapper) SetDeadline(t time.Time) error { + if err := c.SetReadDeadline(t); err != nil { + return err + } + return c.SetWriteDeadline(t) +} + +// Write writes data to the websocket. +func (c *wsWrapper) Write(p []byte) (int, error) { + c.wio.Lock() + defer c.wio.Unlock() + + err := c.WriteMessage(websocket.BinaryMessage, p) + if err != nil { + return 0, err + } + return len(p), nil +} + +// Read reads the current websocket frame. +func (c *wsWrapper) Read(p []byte) (int, error) { + c.rio.Lock() + defer c.rio.Unlock() + for { + if c.r == nil { + // Advance to next message. + var err error + _, c.r, err = c.NextReader() + if err != nil { + return 0, err + } + } + n, err := c.r.Read(p) + if err == io.EOF { + // At end of message. + c.r = nil + if n > 0 { + return n, nil + } + // No data read, continue to next message. + continue + } + return n, err + } +} + +func (c *wsWrapper) Close() error { + return c.Conn.Close() +} diff --git a/mqtt/websocket/forwarder.go b/mqtt/websocket/forwarder.go new file mode 100644 index 0000000..7cf0521 --- /dev/null +++ b/mqtt/websocket/forwarder.go @@ -0,0 +1,82 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package websocket + +import ( + "context" + "log/slog" + "net/http" + "time" + + "github.com/absmach/mproxy" + "github.com/absmach/mproxy/mqtt" + "github.com/gorilla/websocket" +) + +type proxy struct { + handler mproxy.Handler + interceptor mproxy.Interceptor + logger *slog.Logger + target string +} + +// New - creates new WS proxy passer. +func New(target string, handler mproxy.Handler, interceptor mproxy.Interceptor, logger *slog.Logger) mproxy.Forwarder { + return &proxy{ + target: target, + handler: handler, + interceptor: interceptor, + logger: logger, + } +} + +var upgrader = websocket.Upgrader{ + // Timeout for WS upgrade request handshake. + HandshakeTimeout: 10 * time.Second, + // Paho JS client expecting header Sec-WebSocket-Protocol:mqtt in Upgrade response during handshake. + Subprotocols: []string{"mqttv3.1", "mqtt"}, + // Allow CORS + CheckOrigin: func(r *http.Request) bool { + return true + }, +} + +func (p proxy) Forward(w http.ResponseWriter, r *http.Request) { + cconn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + p.logger.Error("Error upgrading connection", slog.Any("error", err)) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + go p.pass(cconn) +} + +func (p proxy) pass(in *websocket.Conn) { + defer in.Close() + // Using a new context so as to avoiding infinitely long traces. + // And also avoiding proxy cancellation due to parent context cancellation. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + dialer := &websocket.Dialer{ + Subprotocols: []string{"mqtt"}, + } + out, _, err := dialer.Dial(p.target, nil) + if err != nil { + p.logger.Error("Unable to connect to broker", slog.Any("error", err)) + return + } + + errc := make(chan error, 1) + inboundConn := newConn(in) + outboundConn := newConn(out) + + defer inboundConn.Close() + defer outboundConn.Close() + + streamer := mqtt.New(p.handler, p.interceptor) + err = streamer.Stream(ctx, inboundConn, outboundConn) + errc <- err + p.logger.Warn("Broken connection for client", slog.Any("error", err)) +} diff --git a/session.go b/session.go new file mode 100644 index 0000000..2d1d974 --- /dev/null +++ b/session.go @@ -0,0 +1,37 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package mproxy + +import ( + "context" + "crypto/x509" +) + +// The sessionKey type is unexported to prevent collisions with context keys defined in +// other packages. +type sessionKey struct{} + +// Session stores session data. +type Session struct { + ID string + Username string + Password []byte + Cert x509.Certificate +} + +// NewContext stores Session in context.Context values. +// It uses pointer to the session so it can be modified by handler. +func NewContext(ctx context.Context, s *Session) context.Context { + return context.WithValue(ctx, sessionKey{}, s) +} + +// FromContext retrieves Session from context.Context. +// Second value indicates if session is present in the context +// and if it's safe to use it (it's not nil). +func FromContext(ctx context.Context) (*Session, bool) { + if s, ok := ctx.Value(sessionKey{}).(*Session); ok && s != nil { + return s, true + } + return nil, false +} diff --git a/streamer/proxy.go b/streamer/proxy.go new file mode 100644 index 0000000..4d0353a --- /dev/null +++ b/streamer/proxy.go @@ -0,0 +1,83 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package streamer + +import ( + "context" + "crypto/tls" + "fmt" + "io" + "log/slog" + "net" + + "github.com/absmach/mproxy" + mptls "github.com/absmach/mproxy/tls" + "golang.org/x/sync/errgroup" +) + +// Listen of the server, this will block. +func Listen(ctx context.Context, name string, config mproxy.Config, streamer mproxy.Streamer, logger *slog.Logger) error { + l, err := net.Listen("tcp", config.Address) + if err != nil { + return err + } + + if config.TLSConfig != nil { + l = tls.NewListener(l, config.TLSConfig) + } + status := mptls.SecurityStatus(config.TLSConfig) + logger.Info(fmt.Sprintf("Proxy server started at %s with %s", config.Address, status)) + g, ctx := errgroup.WithContext(ctx) + + // Acceptor loop + g.Go(func() error { + return accept(ctx, streamer, config.Target, l, *logger) + }) + + g.Go(func() error { + <-ctx.Done() + return l.Close() + }) + if err := g.Wait(); err != nil { + logger.Info(fmt.Sprintf("%s Proxy server at %s with %s exiting with errors", name, config.Address, status), slog.String("error", err.Error())) + } else { + logger.Info(fmt.Sprintf("%s Proxy server at %s with %s exiting...", name, config.Address, status)) + } + return nil +} + +func accept(ctx context.Context, streamer mproxy.Streamer, target string, l net.Listener, logger slog.Logger) error { + for { + select { + case <-ctx.Done(): + return nil + default: + in, err := l.Accept() + if err != nil { + logger.Warn("Accept error " + err.Error()) + continue + } + logger.Info("Accepted new client") + go func() { + defer close(in, logger) + out, err := net.Dial("tcp", target) + if err != nil { + logger.Error("Cannot connect to remote broker " + target + " due to: " + err.Error()) + return + } + defer close(out, logger) + + if err = streamer.Stream(ctx, in, out); err != io.EOF { + logger.Warn(err.Error()) + } + }() + } + } +} + +func close(conn net.Conn, logger slog.Logger) { + if err := conn.Close(); err != nil { + logger.Warn(fmt.Sprintf("Error closing connection %s", err.Error())) + } +} diff --git a/tcp/proxy.go b/tcp/proxy.go new file mode 100644 index 0000000..ce7431b --- /dev/null +++ b/tcp/proxy.go @@ -0,0 +1,17 @@ +package main + +import ( + "log" + + "github.com/inetaf/tcpproxy" +) + +func main() { + var p tcpproxy.Proxy + p.AddRoute(":1884", tcpproxy.To("localhost:1883")) // fallback + p.AddRoute(":8083", tcpproxy.To("localhost:8000")) // fallback + // p.AddSNIRoute(":443", "foo.com", tcpproxy.To("10.0.0.1:4431")) + // p.AddSNIRoute(":443", "bar.com", tcpproxy.To("10.0.0.2:4432")) + // p.AddRoute(":443", tcpproxy.To("10.0.0.1:4431")) // fallback + log.Fatal(p.Run()) +} diff --git a/tls/config.go b/tls/config.go new file mode 100644 index 0000000..93d40c5 --- /dev/null +++ b/tls/config.go @@ -0,0 +1,32 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package tls + +import ( + "github.com/absmach/mproxy/pkg/tls/verifier" + "github.com/caarlos0/env/v11" +) + +type Config struct { + CertFile string `env:"CERT_FILE" envDefault:""` + KeyFile string `env:"KEY_FILE" envDefault:""` + ServerCAFile string `env:"SERVER_CA_FILE" envDefault:""` + ClientCAFile string `env:"CLIENT_CA_FILE" envDefault:""` + Validator verifier.Validator +} + +func NewConfig(opts env.Options) (Config, error) { + c := Config{} + var err error + if err = env.ParseWithOptions(&c, opts); err != nil { + return Config{}, err + } + verifiers, err := newVerifiers(opts) + if err != nil { + return Config{}, err + } + c.Validator = verifier.NewValidator(verifiers) + + return c, nil +} diff --git a/tls/tls.go b/tls/tls.go new file mode 100644 index 0000000..be66aad --- /dev/null +++ b/tls/tls.go @@ -0,0 +1,114 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package tls + +import ( + "crypto/tls" + "crypto/x509" + "errors" + "net" + "os" +) + +var ( + errTLSdetails = errors.New("failed to get TLS details of connection") + errLoadCerts = errors.New("failed to load certificates") + errLoadServerCA = errors.New("failed to load Server CA") + errLoadClientCA = errors.New("failed to load Client CA") + errAppendCA = errors.New("failed to append root ca tls.Config") +) + +// Load return a TLS configuration that can be used in TLS servers. +func Load(c *Config) (*tls.Config, error) { + if c.CertFile == "" || c.KeyFile == "" { + return nil, nil + } + + tlsConfig := &tls.Config{} + + certificate, err := tls.LoadX509KeyPair(c.CertFile, c.KeyFile) + if err != nil { + return nil, errors.Join(errLoadCerts, err) + } + tlsConfig = &tls.Config{ + Certificates: []tls.Certificate{certificate}, + } + + // Loading Server CA file + rootCA, err := loadCertFile(c.ServerCAFile) + if err != nil { + return nil, errors.Join(errLoadServerCA, err) + } + if len(rootCA) > 0 { + if tlsConfig.RootCAs == nil { + tlsConfig.RootCAs = x509.NewCertPool() + } + if !tlsConfig.RootCAs.AppendCertsFromPEM(rootCA) { + return nil, errAppendCA + } + } + + // Loading Client CA File + clientCA, err := loadCertFile(c.ClientCAFile) + if err != nil { + return nil, errors.Join(errLoadClientCA, err) + } + if len(clientCA) > 0 { + if tlsConfig.ClientCAs == nil { + tlsConfig.ClientCAs = x509.NewCertPool() + } + if !tlsConfig.ClientCAs.AppendCertsFromPEM(clientCA) { + return nil, errAppendCA + } + tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert + if c.Validator != nil { + tlsConfig.VerifyPeerCertificate = c.Validator + } + } + return tlsConfig, nil +} + +// ClientCert returns client certificate. +func ClientCert(conn net.Conn) (x509.Certificate, error) { + switch connVal := conn.(type) { + case *tls.Conn: + if err := connVal.Handshake(); err != nil { + return x509.Certificate{}, err + } + state := connVal.ConnectionState() + if state.Version == 0 { + return x509.Certificate{}, errTLSdetails + } + if len(state.PeerCertificates) == 0 { + return x509.Certificate{}, nil + } + cert := *state.PeerCertificates[0] + return cert, nil + default: + return x509.Certificate{}, nil + } +} + +// SecurityStatus returns log message from TLS config. +func SecurityStatus(c *tls.Config) string { + if c == nil { + return "no TLS" + } + ret := "TLS" + // It is possible to establish TLS with client certificates only. + if c.Certificates == nil || len(c.Certificates) == 0 { + ret = "no server certificates" + } + if c.ClientCAs != nil { + ret += " and " + c.ClientAuth.String() + } + return ret +} + +func loadCertFile(certFile string) ([]byte, error) { + if certFile != "" { + return os.ReadFile(certFile) + } + return []byte{}, nil +} diff --git a/tls/verifications.go b/tls/verifications.go new file mode 100644 index 0000000..25e85db --- /dev/null +++ b/tls/verifications.go @@ -0,0 +1,96 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package tls + +import ( + "errors" + "reflect" + "strings" + + "github.com/absmach/mproxy/pkg/tls/verifier" + "github.com/absmach/mproxy/pkg/tls/verifier/crl" + "github.com/absmach/mproxy/pkg/tls/verifier/ocsp" + "github.com/caarlos0/env/v11" +) + +// ErrInvalidCertVerification represents an error during the cert verification +// method loading. Supported are OCSP and CRL verification methods. +var ErrInvalidCertVerification = errors.New("invalid certificate verification method") + +type verification int + +const ( + OCSP verification = iota + 1 + CRL +) + +func newVerifiers(opts env.Options) ([]verifier.Verifier, error) { + if opts.FuncMap == nil { + opts.FuncMap = make(map[reflect.Type]env.ParserFunc) + } + opts.FuncMap[reflect.TypeOf(make([]verification, 0))] = envParseSliceValidate + opts.FuncMap[reflect.TypeOf(new(verification))] = envParseValidation + + var c struct { + Verifications []verification `env:"CERT_VERIFICATION_METHODS" envDefault:""` + } + if err := env.ParseWithOptions(&c, opts); err != nil { + return nil, err + } + if len(c.Verifications) == 0 { + return nil, nil + } + + var vms []verifier.Verifier + for _, v := range c.Verifications { + switch v { + case OCSP: + vm, err := ocsp.New(opts) + if err != nil { + return nil, err + } + vms = append(vms, vm) + case CRL: + vm, err := crl.New(opts) + if err != nil { + return nil, err + } + vms = append(vms, vm) + default: + return nil, ErrInvalidCertVerification + } + } + + return vms, nil +} + +func parseValidation(v string) (verification, error) { + v = strings.ToUpper(strings.TrimSpace(v)) + switch v { + case "OCSP": + return OCSP, nil + case "CRL": + return CRL, nil + default: + return 0, ErrInvalidCertVerification + } +} + +func envParseSliceValidate(v string) (interface{}, error) { + var vms []verification + v = strings.TrimSpace(v) + vmss := strings.Split(v, ",") + for _, vm := range vmss { + v, err := parseValidation(vm) + if err != nil { + return nil, err + } + vms = append(vms, v) + } + return vms, nil +} + +func envParseValidation(v string) (interface{}, error) { + return parseValidation(v) +} diff --git a/tls/verifier/crl/crl.go b/tls/verifier/crl/crl.go new file mode 100644 index 0000000..b5bdcda --- /dev/null +++ b/tls/verifier/crl/crl.go @@ -0,0 +1,285 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package crl + +import ( + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "os" + "time" + + "github.com/absmach/mproxy/pkg/tls/verifier" + "github.com/caarlos0/env/v11" +) + +var ( + errRetrieveCRL = errors.New("failed to retrieve CRL") + errReadCRL = errors.New("failed to read CRL") + errParseCRL = errors.New("failed to parse CRL") + errExpiredCRL = errors.New("crl expired") + errCRLSign = errors.New("failed to verify CRL signature") + errOfflineCRLLoad = errors.New("failed to load offline CRL file") + errOfflineCRLIssuer = errors.New("failed to load offline CRL issuer cert file") + errOfflineCRLIssuerPEM = errors.New("failed to decode PEM block in offline CRL issuer cert file") + errCRLDistIssuer = errors.New("failed to load CRL distribution points issuer cert file") + errCRLDistIssuerPEM = errors.New("failed to decode PEM block in CRL distribution points issuer cert file") + errNoCRL = errors.New("neither offline crl file nor crl distribution points in certificate / environmental variable CRL_DISTRIBUTION_POINTS & CRL_DISTRIBUTION_POINTS_ISSUER_CERT_FILE have values") + errCertRevoked = errors.New("certificate revoked") +) + +var ( + errParseCert = errors.New("failed to parse Certificate") + errClientCrt = errors.New("client certificate not received") +) + +type config struct { + CRLDepth uint `env:"CRL_DEPTH" envDefault:"1"` + OfflineCRLFile string `env:"OFFLINE_CRL_FILE" envDefault:""` + OfflineCRLIssuerCertFile string `env:"OFFLINE_CRL_ISSUER_CERT_FILE" envDefault:""` + CRLDistributionPoints url.URL `env:"CRL_DISTRIBUTION_POINTS" envDefault:""` + CRLDistributionPointsIssuerCertFile string `env:"CRL_DISTRIBUTION_POINTS_ISSUER_CERT_FILE" envDefault:""` +} + +var _ verifier.Verifier = (*config)(nil) + +func New(opts env.Options) (verifier.Verifier, error) { + var c config + if err := env.ParseWithOptions(&c, opts); err != nil { + return nil, err + } + return &c, nil +} + +func (c *config) VerifyPeerCertificate(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + switch { + case len(verifiedChains) > 0: + return c.VerifyVerifiedPeerCertificates(verifiedChains) + case len(rawCerts) > 0: + var peerCertificates []*x509.Certificate + peerCertificates, err := parseCertificates(rawCerts) + if err != nil { + return err + } + return c.VerifyRawPeerCertificates(peerCertificates) + default: + return errClientCrt + } +} + +func (c *config) VerifyVerifiedPeerCertificates(verifiedPeerCertificateChains [][]*x509.Certificate) error { + offlineCRL, err := c.loadOfflineCRL() + if err != nil { + return err + } + for _, verifiedChain := range verifiedPeerCertificateChains { + for i := range verifiedChain { + cert := verifiedChain[i] + issuer := cert + if i+1 < len(verifiedChain) { + issuer = verifiedChain[i+1] + } + + crl, err := c.getCRLFromDistributionPoint(cert, issuer) + if err != nil { + return err + } + switch { + case crl == nil && offlineCRL != nil: + crl = offlineCRL + case crl == nil && offlineCRL == nil: + return errNoCRL + } + + if err := c.crlVerify(cert, crl); err != nil { + return err + } + } + } + return nil +} + +func (c *config) VerifyRawPeerCertificates(peerCertificates []*x509.Certificate) error { + offlineCRL, err := c.loadOfflineCRL() + if err != nil { + return err + } + for i, peerCertificate := range peerCertificates { + issuerCert := retrieveIssuerCert(peerCertificate.Issuer, peerCertificates) + crl, err := c.getCRLFromDistributionPoint(peerCertificate, issuerCert) + if err != nil { + return err + } + switch { + case crl == nil && offlineCRL != nil: + crl = offlineCRL + case crl == nil && offlineCRL == nil: + return errNoCRL + } + + if err := c.crlVerify(peerCertificate, crl); err != nil { + return err + } + if i+1 == int(c.CRLDepth) { + return nil + } + } + return nil +} + +func (c *config) crlVerify(peerCertificate *x509.Certificate, crl *x509.RevocationList) error { + for _, revokedCertificate := range crl.RevokedCertificateEntries { + if revokedCertificate.SerialNumber.Cmp(peerCertificate.SerialNumber) == 0 { + return errCertRevoked + } + } + return nil +} + +func (c *config) loadOfflineCRL() (*x509.RevocationList, error) { + offlineCRLBytes, err := loadCertFile(c.OfflineCRLFile) + if err != nil { + return nil, errors.Join(errOfflineCRLLoad, err) + } + if len(offlineCRLBytes) == 0 { + return nil, nil + } + fmt.Println(c.OfflineCRLIssuerCertFile) + issuer, err := c.loadOfflineCRLIssuerCert() + if err != nil { + return nil, err + } + _ = issuer + offlineCRL, err := parseVerifyCRL(offlineCRLBytes, nil, false) + if err != nil { + return nil, err + } + return offlineCRL, nil +} + +func (c *config) getCRLFromDistributionPoint(cert, issuer *x509.Certificate) (*x509.RevocationList, error) { + switch { + case len(cert.CRLDistributionPoints) > 0: + return retrieveCRL(cert.CRLDistributionPoints[0], issuer, true) + case c.CRLDistributionPoints.String() != "" && c.CRLDistributionPointsIssuerCertFile != "": + var crlIssuerCrt *x509.Certificate + var err error + if crlIssuerCrt, err = c.loadDistPointCRLIssuerCert(); err != nil { + return nil, err + } + return retrieveCRL(c.CRLDistributionPoints.String(), crlIssuerCrt, true) + default: + return nil, nil + } +} + +func (c *config) loadDistPointCRLIssuerCert() (*x509.Certificate, error) { + crlIssuerCertBytes, err := loadCertFile(c.CRLDistributionPointsIssuerCertFile) + if err != nil { + return nil, errors.Join(errCRLDistIssuer, err) + } + if len(crlIssuerCertBytes) == 0 { + return nil, nil + } + crlIssuerCertPEM, _ := pem.Decode(crlIssuerCertBytes) + if crlIssuerCertPEM == nil { + return nil, errCRLDistIssuerPEM + } + crlIssuerCert, err := x509.ParseCertificate(crlIssuerCertPEM.Bytes) + if err != nil { + return nil, errors.Join(errCRLDistIssuer, err) + } + return crlIssuerCert, nil +} + +func (c *config) loadOfflineCRLIssuerCert() (*x509.Certificate, error) { + offlineCrlIssuerCertBytes, err := loadCertFile(c.OfflineCRLIssuerCertFile) + if err != nil { + return nil, errors.Join(errOfflineCRLIssuer, err) + } + if len(offlineCrlIssuerCertBytes) == 0 { + return nil, nil + } + offlineCrlIssuerCertPEM, _ := pem.Decode(offlineCrlIssuerCertBytes) + if offlineCrlIssuerCertPEM == nil { + return nil, errOfflineCRLIssuerPEM + } + crlIssuerCert, err := x509.ParseCertificate(offlineCrlIssuerCertPEM.Bytes) + if err != nil { + return nil, errors.Join(errOfflineCRLIssuer, err) + } + return crlIssuerCert, nil +} + +func retrieveCRL(crlDistributionPoints string, issuerCert *x509.Certificate, checkSign bool) (*x509.RevocationList, error) { + resp, err := http.Get(crlDistributionPoints) + if err != nil { + return nil, errors.Join(errRetrieveCRL, err) + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errors.Join(errReadCRL, err) + } + return parseVerifyCRL(body, issuerCert, checkSign) +} + +func parseVerifyCRL(clrB []byte, issuerCert *x509.Certificate, checkSign bool) (*x509.RevocationList, error) { + block, _ := pem.Decode(clrB) + if block == nil { + return nil, errParseCRL + } + + crl, err := x509.ParseRevocationList(block.Bytes) + if err != nil { + return nil, errors.Join(errParseCRL, err) + } + + if checkSign { + if err := crl.CheckSignatureFrom(issuerCert); err != nil { + return nil, errors.Join(errCRLSign, err) + } + } + + if crl.NextUpdate.Before(time.Now()) { + return nil, errExpiredCRL + } + return crl, nil +} + +func loadCertFile(certFile string) ([]byte, error) { + if certFile != "" { + return os.ReadFile(certFile) + } + return []byte{}, nil +} + +func retrieveIssuerCert(issuerSubject pkix.Name, certs []*x509.Certificate) *x509.Certificate { + for _, cert := range certs { + if cert.Subject.SerialNumber != "" && issuerSubject.SerialNumber != "" && cert.Subject.SerialNumber == issuerSubject.SerialNumber { + return cert + } + if (cert.Subject.SerialNumber == "" || issuerSubject.SerialNumber == "") && cert.Subject.String() == issuerSubject.String() { + return cert + } + } + return nil +} + +func parseCertificates(rawCerts [][]byte) ([]*x509.Certificate, error) { + var certs []*x509.Certificate + for _, rawCert := range rawCerts { + cert, err := x509.ParseCertificate(rawCert) + if err != nil { + return nil, errors.Join(errParseCert, err) + } + certs = append(certs, cert) + } + return certs, nil +} diff --git a/tls/verifier/ocsp/ocsp.go b/tls/verifier/ocsp/ocsp.go new file mode 100644 index 0000000..c51719e --- /dev/null +++ b/tls/verifier/ocsp/ocsp.go @@ -0,0 +1,239 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package ocsp + +import ( + "bytes" + "crypto" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "errors" + "fmt" + "io" + "net/http" + "net/url" + + "github.com/absmach/mproxy/pkg/tls/verifier" + "github.com/caarlos0/env/v11" + "golang.org/x/crypto/ocsp" +) + +var ( + errParseIssuerCrt = errors.New("failed to parse issuer certificate") + errCreateOCSPReq = errors.New("failed to create OCSP Request") + errCreateOCSPHTTPReq = errors.New("failed to create OCSP HTTP Request") + errParseOCSPUrl = errors.New("failed to parse OCSP server URL") + errOCSPReq = errors.New("OCSP request failed") + errOCSPReadResp = errors.New("failed to read OCSP response") + errParseOCSPRespForCert = errors.New("failed to parse OCSP Response for Certificate") + errIssuerCert = errors.New("neither the issuer certificate is present in the chain nor is the issuer certificate URL present in AIA") + errNoOCSPURL = errors.New("neither OCSP Server/Responder URL is not present AIA of certificate nor environmental variable OCSP_RESPONDER_URL have value") + errOCSPServerFailed = errors.New("OCSP Server Failed") + errOCSPUnknown = errors.New("OCSP status unknown") + errCertRevoked = errors.New("certificate revoked") + errRetrieveIssuerCrt = errors.New("failed to retrieve issuer certificate") + errReadIssuerCrt = errors.New("failed to read issuer certificate") + errIssuerCrtPEM = errors.New("failed to decode issuer certificate PEM") + + errParseCert = errors.New("failed to parse Certificate") + errClientCrt = errors.New("client certificate not received") +) + +type config struct { + OCSPDepth uint `env:"OCSP_DEPTH" envDefault:"0"` + OCSPResponderURL url.URL `env:"OCSP_RESPONDER_URL" envDefault:""` +} + +var _ verifier.Verifier = (*config)(nil) + +func New(opts env.Options) (verifier.Verifier, error) { + var c config + if err := env.ParseWithOptions(&c, opts); err != nil { + return nil, err + } + return &c, nil +} + +func (c *config) VerifyPeerCertificate(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + switch { + case len(verifiedChains) > 0: + return c.VerifyVerifiedPeerCertificates(verifiedChains) + case len(rawCerts) > 0: + var peerCertificates []*x509.Certificate + peerCertificates, err := parseCertificates(rawCerts) + if err != nil { + return err + } + return c.VerifyRawPeerCertificates(peerCertificates) + default: + return errClientCrt + } +} + +func (c *config) VerifyRawPeerCertificates(peerCertificates []*x509.Certificate) error { + for i, peerCertificate := range peerCertificates { + issuer := retrieveIssuerCert(peerCertificate.Issuer, peerCertificates) + if err := c.ocspVerify(peerCertificate, issuer); err != nil { + return err + } + if i+1 == int(c.OCSPDepth) { + return nil + } + } + return nil +} + +func (c *config) VerifyVerifiedPeerCertificates(verifiedPeerCertificateChains [][]*x509.Certificate) error { + for _, verifiedChain := range verifiedPeerCertificateChains { + for i := range verifiedChain { + cert := verifiedChain[i] + issuer := cert + if i+1 < len(verifiedChain) { + issuer = verifiedChain[i+1] + } + if err := c.ocspVerify(cert, issuer); err != nil { + return err + } + } + } + return nil +} + +func (c *config) ocspVerify(peerCertificate, issuerCert *x509.Certificate) error { + opts := &ocsp.RequestOptions{Hash: crypto.SHA256} + var err error + + if !isRootCA(peerCertificate) { + if issuerCert == nil { + if len(peerCertificate.IssuingCertificateURL) < 1 { + return fmt.Errorf("%w common name %s and serial number %x", errIssuerCert, peerCertificate.Subject.CommonName, peerCertificate.SerialNumber) + } + issuerCert, err = retrieveIssuingCertificate(peerCertificate.IssuingCertificateURL[0]) + if err != nil { + return err + } + } + } else { + issuerCert = peerCertificate + } + + buffer, err := ocsp.CreateRequest(peerCertificate, issuerCert, opts) + if err != nil { + return errors.Join(errCreateOCSPReq, err) + } + + ocspURL := "" + ocspURLHost := "" + if c.OCSPResponderURL.String() == "" { + if len(peerCertificate.OCSPServer) < 1 { + return fmt.Errorf("%w common name %s and serial number %x", errNoOCSPURL, peerCertificate.Subject.CommonName, peerCertificate.SerialNumber) + } + ocspURL = peerCertificate.OCSPServer[0] + ocspParsedURL, err := url.Parse(peerCertificate.OCSPServer[0]) + if err != nil { + return errors.Join(errParseOCSPUrl, err) + } + ocspURLHost = ocspParsedURL.Host + } else { + ocspURLHost = c.OCSPResponderURL.Host + ocspURL = c.OCSPResponderURL.String() + } + + httpRequest, err := http.NewRequest(http.MethodPost, ocspURL, bytes.NewBuffer(buffer)) + if err != nil { + return errors.Join(errCreateOCSPHTTPReq, err) + } + httpRequest.Header.Add("Content-Type", "application/ocsp-request") + httpRequest.Header.Add("Accept", "application/ocsp-response") + httpRequest.Header.Add("host", ocspURLHost) + + httpClient := &http.Client{} + httpResponse, err := httpClient.Do(httpRequest) + if err != nil { + return errors.Join(errOCSPReq, err) + } + defer httpResponse.Body.Close() + output, err := io.ReadAll(httpResponse.Body) + if err != nil { + return errors.Join(errOCSPReadResp, err) + } + ocspResponse, err := ocsp.ParseResponseForCert(output, peerCertificate, issuerCert) + if err != nil { + return errors.Join(errParseOCSPRespForCert, err) + } + switch ocspResponse.Status { + case ocsp.Good: + return nil + case ocsp.Revoked: + return fmt.Errorf("%w command name %s and serial number %x revoked at %v", errCertRevoked, peerCertificate.Subject.CommonName, peerCertificate.SerialNumber, ocspResponse.RevokedAt) + case ocsp.ServerFailed: + return errOCSPServerFailed + case ocsp.Unknown: + fallthrough + default: + return errOCSPUnknown + } +} + +func retrieveIssuerCert(issuerSubject pkix.Name, certs []*x509.Certificate) *x509.Certificate { + for _, cert := range certs { + if cert.Subject.SerialNumber != "" && issuerSubject.SerialNumber != "" && cert.Subject.SerialNumber == issuerSubject.SerialNumber { + return cert + } + if (cert.Subject.SerialNumber == "" || issuerSubject.SerialNumber == "") && cert.Subject.String() == issuerSubject.String() { + return cert + } + } + return nil +} + +func retrieveIssuingCertificate(issuingCertificateURL string) (*x509.Certificate, error) { + resp, err := http.Get(issuingCertificateURL) + if err != nil { + return nil, errors.Join(errRetrieveIssuerCrt, err) + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errors.Join(errReadIssuerCrt, err) + } + + block, _ := pem.Decode(body) + if block == nil { + return nil, errIssuerCrtPEM + } + + issCert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return nil, errors.Join(errParseIssuerCrt, err) + } + return issCert, nil +} + +func isRootCA(cert *x509.Certificate) bool { + if cert.IsCA { + // Check AuthorityKeyId and SubjectKeyId are same. + if len(cert.AuthorityKeyId) > 0 && len(cert.SubjectKeyId) > 0 && bytes.Equal(cert.AuthorityKeyId, cert.SubjectKeyId) { + return true + } + // Alternatively, check Issuer and Subject are same. + if cert.Issuer.String() == cert.Subject.String() { + return true + } + } + return false +} + +func parseCertificates(rawCerts [][]byte) ([]*x509.Certificate, error) { + var certs []*x509.Certificate + for _, rawCert := range rawCerts { + cert, err := x509.ParseCertificate(rawCert) + if err != nil { + return nil, errors.Join(errParseCert, err) + } + certs = append(certs, cert) + } + return certs, nil +} diff --git a/tls/verifier/verifier.go b/tls/verifier/verifier.go new file mode 100644 index 0000000..ebd9bcc --- /dev/null +++ b/tls/verifier/verifier.go @@ -0,0 +1,24 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package verifier + +import "crypto/x509" + +type Verifier interface { + // VerifyPeerCertificate is used to verify certificates in TLS config. + VerifyPeerCertificate(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error +} + +type Validator func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error + +func NewValidator(verifiers []Verifier) Validator { + return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + for _, vm := range verifiers { + if err := vm.VerifyPeerCertificate(rawCerts, verifiedChains); err != nil { + return err + } + } + return nil + } +} diff --git a/websockets/forwarder.go b/websockets/forwarder.go new file mode 100644 index 0000000..903569b --- /dev/null +++ b/websockets/forwarder.go @@ -0,0 +1,120 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package websockets + +import ( + "context" + "errors" + "fmt" + "log/slog" + "net/http" + + "github.com/absmach/mproxy" + "github.com/absmach/mproxy/pkg/session" + "github.com/gorilla/websocket" + "golang.org/x/sync/errgroup" +) + +var ( + upgrader = websocket.Upgrader{} + ErrAuthorizationNotSet = errors.New("authorization not set") +) + +type Proxy struct { + target string + handler mproxy.Handler + logger *slog.Logger +} + +func (p *Proxy) Forward(w http.ResponseWriter, r *http.Request) { + var token string + headers := http.Header{} + switch { + case len(r.URL.Query()["authorization"]) != 0: + token = r.URL.Query()["authorization"][0] + case r.Header.Get("Authorization") != "": + token = r.Header.Get("Authorization") + headers.Add("Authorization", token) + default: + http.Error(w, ErrAuthorizationNotSet.Error(), http.StatusUnauthorized) + return + } + + target := fmt.Sprintf("%s%s", p.target, r.RequestURI) + + targetConn, _, err := websocket.DefaultDialer.Dial(target, headers) + if err != nil { + http.Error(w, err.Error(), http.StatusBadGateway) + return + } + defer targetConn.Close() + + topic := r.URL.Path + s := mproxy.Session{Password: []byte(token)} + ctx := mproxy.NewContext(context.Background(), &s) + if err := p.handler.AuthConnect(ctx); err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + if err := p.handler.AuthSubscribe(ctx, &[]string{topic}); err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + if err := p.handler.Subscribe(ctx, &[]string{topic}); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + inConn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + p.logger.Warn("WS Proxy failed to upgrade connection", slog.Any("error", err)) + return + } + defer inConn.Close() + + g, ctx := errgroup.WithContext(ctx) + + g.Go(func() error { + return p.stream(ctx, topic, inConn, targetConn, true) + }) + g.Go(func() error { + return p.stream(ctx, topic, targetConn, inConn, false) + }) + + if err := g.Wait(); err != nil { + if err := p.handler.Unsubscribe(ctx, &[]string{topic}); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + p.logger.Error("WS Proxy terminated", slog.Any("error", err)) + return + } +} + +func (p *Proxy) stream(ctx context.Context, topic string, src, dest *websocket.Conn, upstream bool) error { + for { + messageType, payload, err := src.ReadMessage() + if err != nil { + return err + } + if upstream { + if err := p.handler.AuthPublish(ctx, &topic, &payload); err != nil { + return err + } + if err := p.handler.Publish(ctx, &topic, &payload); err != nil { + return err + } + } + if err := dest.WriteMessage(messageType, payload); err != nil { + return err + } + } +} + +func NewProxy(target string, logger *slog.Logger, handler session.Handler) mproxy.Forwarder { + return &Proxy{ + target: target, + logger: logger, + handler: handler, + } +}