diff --git a/.golangci.yml b/.golangci.yml index 3da1aa7a..bc2b437a 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -204,6 +204,7 @@ issues: - funlen - forbidigo - gochecknoinits + - dupl - path: ".generated.go" linters: - typecheck diff --git a/cmd/handler/broker.go b/cmd/handler/broker.go index dc05b765..b135dab4 100644 --- a/cmd/handler/broker.go +++ b/cmd/handler/broker.go @@ -12,6 +12,12 @@ var broker = &cobra.Command{ Use: "broker", Short: "Run the Unchained client in broker mode", Long: `Run the Unchained client in broker mode`, + + PreRun: func(cmd *cobra.Command, _ []string) { + config.App.Network.CertFile = cmd.Flags().Lookup("cert-file").Value.String() + config.App.Network.KeyFile = cmd.Flags().Lookup("key-file").Value.String() + }, + Run: func(_ *cobra.Command, _ []string) { err := config.Load(config.App.System.ConfigPath, config.App.System.SecretsPath) if err != nil { @@ -30,9 +36,16 @@ func WithBrokerCmd(cmd *cobra.Command) { func init() { broker.Flags().StringP( - "broker", - "b", - "wss://shinobi.brokers.kenshi.io", - "Unchained broker to connect to", + "cert-file", + "C", + "", + "TLS certificate file", + ) + + broker.Flags().StringP( + "key-file", + "k", + "", + "TLS key file", ) } diff --git a/cmd/handler/plugin.go b/cmd/handler/plugin.go new file mode 100644 index 00000000..1f1dafb9 --- /dev/null +++ b/cmd/handler/plugin.go @@ -0,0 +1,50 @@ +package handler + +import ( + "os" + + "github.com/TimeleapLabs/unchained/cmd/handler/plugins" + "github.com/gorilla/websocket" + "github.com/spf13/cobra" +) + +var conn *websocket.Conn + +func Read() <-chan []byte { + out := make(chan []byte) + + go func() { + for { + _, payload, err := conn.ReadMessage() + if err != nil { + panic(err) + } + + out <- payload + } + }() + + return out +} + +// plugin represents the plugin command. +var plugin = &cobra.Command{ + Use: "plugin", + Short: "Run an Unchained plugin locally", + Long: `Run an Unchained plugin locally`, + + Run: func(cmd *cobra.Command, _ []string) { + os.Exit(1) + }, +} + +// WithPluginCmd appends the plugin command to the root command. +func WithPluginCmd(cmd *cobra.Command) { + cmd.AddCommand(plugin) +} + +func init() { + plugins.WithAIPluginCmd(plugin) + plugins.WithTextToImagePluginCmd(plugin) + plugins.WithTranslatePluginCmd(plugin) +} diff --git a/cmd/handler/plugins/ai.go b/cmd/handler/plugins/ai.go new file mode 100644 index 00000000..e2e4edff --- /dev/null +++ b/cmd/handler/plugins/ai.go @@ -0,0 +1,41 @@ +package plugins + +import ( + "log" + "os" + "os/signal" + "syscall" + + "github.com/TimeleapLabs/unchained/internal/service/ai" + "github.com/spf13/cobra" +) + +// worker represents the worker command. +var aiPlugin = &cobra.Command{ + Use: "ai", + Short: "Start the Unchained ai server for local invocation", + Long: `Start the Unchained ai server for local invocation`, + + Run: func(cmd *cobra.Command, _ []string) { + wg, cancel := ai.StartServer(cmd.Context()) + + // Set up signal handling + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) + + go func() { + sig := <-sigChan + log.Printf("Received signal: %v. Shutting down gracefully...", sig) + cancel() // Cancel the context to stop all managed processes + }() + + // Wait for all processes to finish + wg.Wait() + log.Println("All processes have been stopped.") + }, +} + +// WithRunCmd appends the run command to the root command. +func WithAIPluginCmd(cmd *cobra.Command) { + cmd.AddCommand(aiPlugin) +} diff --git a/cmd/handler/plugins/common.go b/cmd/handler/plugins/common.go new file mode 100644 index 00000000..8916ab15 --- /dev/null +++ b/cmd/handler/plugins/common.go @@ -0,0 +1,41 @@ +package plugins + +import "github.com/gorilla/websocket" + +var conn *websocket.Conn +var closed = false + +func Read() <-chan []byte { + out := make(chan []byte) + + go func() { + for { + _, payload, err := conn.ReadMessage() + if err != nil { + if !closed { + panic(err) + } + } + + out <- payload + } + }() + + return out +} + +func CloseSocket() { + if conn != nil { + closed = true + err := conn.WriteMessage( + websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + if err != nil { + return + } + err = conn.Close() + if err != nil { + return + } + } +} diff --git a/cmd/handler/plugins/text_to_image.go b/cmd/handler/plugins/text_to_image.go new file mode 100644 index 00000000..13e134a9 --- /dev/null +++ b/cmd/handler/plugins/text_to_image.go @@ -0,0 +1,93 @@ +package plugins + +import ( + "os" + + "github.com/TimeleapLabs/unchained/internal/service/ai" + "github.com/spf13/cobra" +) + +// worker represents the worker command. +var textToImagePlugin = &cobra.Command{ + Use: "text-to-image", + Short: "Run the text-to-image plugin locally", + Long: `Run the text-to-image plugin locally`, + + Run: func(cmd *cobra.Command, _ []string) { + + prompt := cmd.Flags().Lookup("prompt").Value.String() + negativePrompt := cmd.Flags().Lookup("negative-prompt").Value.String() + output := cmd.Flags().Lookup("output").Value.String() + model := cmd.Flags().Lookup("model").Value.String() + loraWeights := cmd.Flags().Lookup("lora-weights").Value.String() + steps, err := cmd.Flags().GetUint8("inference") + + if err != nil { + panic(err) + } + + outputBytes := ai.TextToImage(prompt, negativePrompt, model, loraWeights, steps) + + // write outputBytes as png to output file path of output flag + err = os.WriteFile(output, outputBytes, 0644) //nolint: gosec // Other users may need to read these files. + if err != nil { + panic(err) + } + + CloseSocket() + os.Exit(0) + }, +} + +// WithRunCmd appends the run command to the root command. +func WithTextToImagePluginCmd(cmd *cobra.Command) { + cmd.AddCommand(textToImagePlugin) +} + +func init() { + textToImagePlugin.Flags().StringP( + "prompt", + "p", + "", + "Prompt data to process", + ) + textToImagePlugin.Flags().StringP( + "negative-prompt", + "n", + "", + "Negative prompt data to process", + ) + textToImagePlugin.Flags().Uint8P( + "inference", + "i", + 16, + "Number of inference steps", + ) + textToImagePlugin.Flags().StringP( + "model", + "m", + "OEvortex/PixelGen", + "Model to use for inference", + ) + textToImagePlugin.Flags().StringP( + "lora-weights", + "w", + "", + "Lora weights model name (if applicable)", + ) + textToImagePlugin.Flags().StringP( + "output", + "o", + "", + "Output file path", + ) + + err := textToImagePlugin.MarkFlagRequired("prompt") + if err != nil { + panic(err) + } + err = textToImagePlugin.MarkFlagRequired("output") + if err != nil { + panic(err) + } +} diff --git a/cmd/handler/plugins/translate.go b/cmd/handler/plugins/translate.go new file mode 100644 index 00000000..46453f16 --- /dev/null +++ b/cmd/handler/plugins/translate.go @@ -0,0 +1,123 @@ +package plugins + +import ( + "fmt" + "os" + "time" + + "github.com/google/uuid" + "github.com/gorilla/websocket" + sia "github.com/pouya-eghbali/go-sia/v2/pkg" + "github.com/spf13/cobra" +) + +// worker represents the worker command. +var translatePlugin = &cobra.Command{ + Use: "translate", + Short: "Run the translate plugin locally", + Long: `Run the translate plugin locally`, + + Run: func(cmd *cobra.Command, _ []string) { + + input := cmd.Flags().Lookup("input").Value.String() + from := cmd.Flags().Lookup("from").Value.String() + to := cmd.Flags().Lookup("to").Value.String() + + var err error + conn, _, err = websocket.DefaultDialer.Dial( + "ws://localhost:8765", nil, + ) + + if err != nil { + panic(err) + } + + incoming := Read() + + requestUUID, err := uuid.NewV7() + if err != nil { + panic(err) + } + + uuidBytes, err := requestUUID.MarshalBinary() + if err != nil { + panic(err) + } + + payload := sia.New(). + AddUInt16(1). + AddByteArrayN(uuidBytes). + AddStringN(from). + AddStringN(to). + AddString16(input). + Bytes() + + err = conn.WriteMessage(websocket.BinaryMessage, payload) + if err != nil { + panic(err) + } + + // Start a goroutine to print dots every 3 seconds + stopDots := make(chan struct{}) + go func() { + for { + select { + case <-stopDots: + return + case <-time.After(1 * time.Second): + fmt.Print(".") + } + } + }() + + data := <-incoming + + // Stop the dot-printing goroutine + close(stopDots) + fmt.Println() + + // process data + s := sia.NewFromBytes(data) + uuidBytesFromResponse := s.ReadByteArrayN(16) + responseUUID, err := uuid.FromBytes(uuidBytesFromResponse) + if err != nil { + panic(err) + } + + if requestUUID != responseUUID { + panic("UUID mismatch") + } + + translated := s.ReadString16() + fmt.Println(translated) + + CloseSocket() + os.Exit(0) + }, +} + +// WithRunCmd appends the run command to the root command. +func WithTranslatePluginCmd(cmd *cobra.Command) { + cmd.AddCommand(translatePlugin) +} + +func init() { + translatePlugin.Flags().StringP( + "input", + "i", + "", + "Input text to translate", + ) + translatePlugin.Flags().StringP( + "from", + "f", + "en", + "From language", + ) + translatePlugin.Flags().StringP( + "to", + "t", + "fr", + "To language", + ) +} diff --git a/cmd/handler/worker.go b/cmd/handler/worker.go index c4f12148..4717326f 100644 --- a/cmd/handler/worker.go +++ b/cmd/handler/worker.go @@ -17,14 +17,14 @@ var worker = &cobra.Command{ config.App.Network.BrokerURI = cmd.Flags().Lookup("broker").Value.String() }, - Run: func(_ *cobra.Command, _ []string) { + Run: func(cmd *cobra.Command, _ []string) { err := config.Load(config.App.System.ConfigPath, config.App.System.SecretsPath) if err != nil { panic(err) } utils.SetupLogger(config.App.System.Log) - app.Worker() + app.Worker(cmd.Context()) }, } diff --git a/cmd/main.go b/cmd/main.go index 989a46b0..e12a4ec8 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -28,6 +28,7 @@ func main() { handler.WithBrokerCmd(root) handler.WithConsumerCmd(root) handler.WithWorkerCmd(root) + handler.WithPluginCmd(root) err := root.Execute() if err != nil { @@ -41,4 +42,5 @@ func init() { root.PersistentFlags().StringVarP(&config.App.System.SecretsPath, "secrets", "s", "./secrets.yaml", "Secrets file") root.PersistentFlags().BoolVarP(&config.App.System.AllowGenerateSecrets, "allow-generate-secrets", "a", false, "Allow to generate secrets file if not exists") root.PersistentFlags().StringVarP(&config.App.System.ContextPath, "context", "x", "./context", "Context DB") + root.PersistentFlags().StringVarP(&config.App.System.Home, "home", "H", "./unchained", "Unchained Home") } diff --git a/go.mod b/go.mod index a5b70934..e5d8d43e 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/dgraph-io/badger/v4 v4.2.0 github.com/ethereum/go-ethereum v1.13.14 github.com/go-co-op/gocron/v2 v2.2.6 + github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.1 github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/golang-lru/v2 v2.0.7 @@ -20,12 +21,13 @@ require ( github.com/lmittmann/tint v1.0.4 github.com/mattn/go-colorable v0.1.13 github.com/peterldowns/pgtestdb v0.0.14 - github.com/pouya-eghbali/go-sia/v2 v2.1.0 + github.com/pouya-eghbali/go-sia/v2 v2.3.0 github.com/puzpuzpuz/xsync/v3 v3.1.0 github.com/spf13/cobra v1.8.0 github.com/stretchr/testify v1.9.0 github.com/vektah/gqlparser/v2 v2.5.11 golang.org/x/crypto v0.21.0 + golang.org/x/exp v0.0.0-20240318143956-a85f2c67cd81 golang.org/x/sync v0.6.0 golang.org/x/text v0.14.0 gopkg.in/yaml.v3 v3.0.1 @@ -76,7 +78,6 @@ require ( github.com/golang/snappy v0.0.5-0.20220116011046-fa5810519dcb // indirect github.com/google/flatbuffers v24.3.7+incompatible // indirect github.com/google/go-cmp v0.6.0 // indirect - github.com/google/uuid v1.6.0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-bexpr v0.1.10 // indirect github.com/hashicorp/hcl/v2 v2.20.0 // indirect @@ -102,16 +103,18 @@ require ( github.com/mitchellh/pointerstructure v1.2.0 // indirect github.com/mmcloughlin/addchain v0.4.0 // indirect github.com/olekukonko/tablewriter v0.0.5 // indirect + github.com/onsi/ginkgo v1.14.2 // indirect + github.com/onsi/gomega v1.10.4 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect - github.com/prometheus/client_golang v1.12.0 // indirect + github.com/prometheus/client_golang v1.12.1 // indirect github.com/prometheus/client_model v0.2.1-0.20210607210712-147c58e9608a // indirect github.com/prometheus/common v0.32.1 // indirect github.com/prometheus/procfs v0.7.3 // indirect github.com/rivo/uniseg v0.2.0 // indirect github.com/robfig/cron/v3 v3.0.1 // indirect github.com/rogpeppe/go-internal v1.9.0 // indirect - github.com/rs/cors v1.7.0 // indirect + github.com/rs/cors v1.8.3 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/shirou/gopsutil v3.21.11+incompatible // indirect github.com/sosodev/duration v1.2.0 // indirect @@ -130,12 +133,12 @@ require ( github.com/yusufpapurcu/wmi v1.2.4 // indirect github.com/zclconf/go-cty v1.14.4 // indirect go.opencensus.io v0.24.0 // indirect - golang.org/x/exp v0.0.0-20240318143956-a85f2c67cd81 // indirect golang.org/x/mod v0.16.0 // indirect golang.org/x/net v0.22.0 // indirect golang.org/x/sys v0.18.0 // indirect golang.org/x/time v0.5.0 // indirect golang.org/x/tools v0.19.0 // indirect + golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect google.golang.org/protobuf v1.33.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.0.0 // indirect olympos.io/encoding/edn v0.0.0-20201019073823-d3554ca0b0a3 // indirect diff --git a/go.sum b/go.sum index 2ea0c99e..c871105e 100644 --- a/go.sum +++ b/go.sum @@ -463,12 +463,14 @@ github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+W github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= github.com/onsi/ginkgo v1.13.0/go.mod h1:+REjRxOmWfHCjfv9TTWB1jD1Frx4XydAD3zm1lskyM0= -github.com/onsi/ginkgo v1.14.0 h1:2mOpI4JVVPBN+WQRa0WKH2eXR+Ey+uK4n7Zj0aYpIQA= github.com/onsi/ginkgo v1.14.0/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY= +github.com/onsi/ginkgo v1.14.2 h1:8mVmC9kjFFmA8H4pKMUhcblgifdkOIXPvbhN1T36q1M= +github.com/onsi/ginkgo v1.14.2/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY= github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= -github.com/onsi/gomega v1.10.1 h1:o0+MgICZLuZ7xjH7Vx6zS/zcu93/BEp1VwkIW1mEXCE= github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= +github.com/onsi/gomega v1.10.4 h1:NiTx7EEvBzu9sFOD1zORteLSt3o8gnlvZZwSE9TnY9U= +github.com/onsi/gomega v1.10.4/go.mod h1:g/HbgYopi++010VEqkFgJHKC09uJiW9UkXvMUuKHUCQ= github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= github.com/peterldowns/pgtestdb v0.0.14 h1:myVNL8ethaPZG7CQIjZxZCXwOG428THYRbSm0mIelpU= github.com/peterldowns/pgtestdb v0.0.14/go.mod h1:aG99+zgvWKOdGH+vtEFTDNVmaPOJD8ldIleuwJOgacA= @@ -484,14 +486,16 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/pouya-eghbali/go-sia/v2 v2.1.0 h1:LdATSKXsEIhdZAoXRqfHJba+iIPjb8dhmRz3a11QFCA= -github.com/pouya-eghbali/go-sia/v2 v2.1.0/go.mod h1:E+hUvytS6uLa+HSBY+oi19zPvVGZdVzWSVW9zwzZnr8= +github.com/pouya-eghbali/go-sia/v2 v2.2.2 h1:W8K5CR00gLLMO3JLRfPBfeD93BSA5VpwuOOwfp5RKHw= +github.com/pouya-eghbali/go-sia/v2 v2.2.2/go.mod h1:E+hUvytS6uLa+HSBY+oi19zPvVGZdVzWSVW9zwzZnr8= +github.com/pouya-eghbali/go-sia/v2 v2.3.0 h1:x2bGKFnQOUOP9H+k8rr5hm3Qf2uWMnhLuDndwDApxGg= +github.com/pouya-eghbali/go-sia/v2 v2.3.0/go.mod h1:E+hUvytS6uLa+HSBY+oi19zPvVGZdVzWSVW9zwzZnr8= github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= -github.com/prometheus/client_golang v1.12.0 h1:C+UIj/QWtmqY13Arb8kwMt5j34/0Z2iKamrJ+ryC0Gg= -github.com/prometheus/client_golang v1.12.0/go.mod h1:3Z9XVyYiZYEO+YQWt3RD2R3jrbd179Rt297l4aS6nDY= +github.com/prometheus/client_golang v1.12.1 h1:ZiaPsmm9uiBeaSMRznKsCDNtPCS0T3JVDGF+06gjBzk= +github.com/prometheus/client_golang v1.12.1/go.mod h1:3Z9XVyYiZYEO+YQWt3RD2R3jrbd179Rt297l4aS6nDY= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= @@ -518,8 +522,8 @@ github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzG github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= -github.com/rs/cors v1.7.0 h1:+88SsELBHx5r+hZ8TCkggzSstaWNbDvThkVK8H6f9ik= -github.com/rs/cors v1.7.0/go.mod h1:gFx+x8UowdsKA9AchylcLynDq+nNFfI8FkUZdN/jGCU= +github.com/rs/cors v1.8.3 h1:O+qNyWn7Z+F9M0ILBHgMVPuB1xTOucVd5gtaYyXBpRo= +github.com/rs/cors v1.8.3/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= @@ -699,6 +703,7 @@ golang.org/x/net v0.0.0-20200813134508-3edf25e44fcc/go.mod h1:/O7V0waA8r7cgGh81R golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20201202161906-c7110b5ffcbb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc= golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= @@ -841,8 +846,9 @@ golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 h1:H2TDz8ibqkAF6YGhCdN3jS9O0/s90v0rJh3X/OLHEUk= +golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8= google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M= google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= diff --git a/internal/app/worker.go b/internal/app/worker.go index d89cc579..99bf6f5f 100644 --- a/internal/app/worker.go +++ b/internal/app/worker.go @@ -1,6 +1,10 @@ package app import ( + "context" + + "github.com/TimeleapLabs/unchained/internal/service/rpc" + "github.com/TimeleapLabs/unchained/internal/config" "github.com/TimeleapLabs/unchained/internal/consts" "github.com/TimeleapLabs/unchained/internal/crypto" @@ -17,7 +21,7 @@ import ( ) // Worker starts the Unchained worker and contains its DI. -func Worker() { +func Worker(_ context.Context) { utils.Logger. With("Mode", "Worker"). With("Version", consts.Version). @@ -40,6 +44,15 @@ func Worker() { evmLogService := evmlogService.New(ethRPC, pos, eventLogRepo, signerRepo, badger) uniswapService := uniswapService.New(ethRPC, pos, signerRepo, assetPrice) + rpcFunctions := []rpc.Option{} + for _, fun := range config.App.Functions { + switch fun.Type { //nolint: gocritic // This is a switch case for different types of rpc functions + case "unix": + rpcFunctions = append(rpcFunctions, rpc.WithUnixSocket(fun.Name, fun.Endpoint)) + } + } + rpcService := rpc.NewWorker(rpcFunctions...) + scheduler := scheduler.New( scheduler.WithEthLogs(evmLogService), scheduler.WithUniswapEvents(uniswapService), @@ -47,7 +60,7 @@ func Worker() { conn.Start() - handler := handler.NewWorkerHandler() + handler := handler.NewWorkerHandler(rpcService) client.NewRPC(handler) scheduler.Start() diff --git a/internal/config/model.go b/internal/config/model.go index 2c2efba9..ec20ba8a 100644 --- a/internal/config/model.go +++ b/internal/config/model.go @@ -13,6 +13,7 @@ type System struct { SecretsPath string AllowGenerateSecrets bool ContextPath string + Home string PrintVersion bool } @@ -69,6 +70,8 @@ type ProofOfStake struct { type Network struct { Bind string `env:"BIND" env-default:"0.0.0.0:9123" yaml:"bind"` + CertFile string `env:"CERT_FILE" env-default:"" yaml:"certFile"` + KeyFile string `env:"KEY_FILE" env-default:"" yaml:"keyFile"` BrokerURI string `env:"BROKER_URI" env-default:"wss://shinobi.brokers.kenshi.io" yaml:"brokerUri"` SubscribedChannel string `env:"SUBSCRIBED_CHANNEL" env-default:"unchained:" yaml:"subscribedChannel"` BrokerTimeout time.Duration `env:"BROKER_TIMEOUT" env-default:"3s" yaml:"brokerTimeout"` @@ -87,11 +90,19 @@ type Secret struct { EvmPrivateKey string `env:"EVM_PRIVATE_KEY" yaml:"evmPrivateKey"` } +// Function struct hold the function configuration of the application. +type Function struct { + Type string `json:"type"` + Name string `json:"name"` + Endpoint string `json:"endpoint"` +} + // Config struct is the main configuration struct of application. type Config struct { System System `yaml:"system"` Network Network `yaml:"network"` RPC []RPC `yaml:"rpc"` + Functions []Function `yaml:"functions"` Postgres Postgres `yaml:"postgres"` ProofOfStake ProofOfStake `yaml:"pos"` Plugins Plugins `yaml:"plugins"` diff --git a/internal/consts/errors.go b/internal/consts/errors.go index 55b0ced9..cb5585e6 100644 --- a/internal/consts/errors.go +++ b/internal/consts/errors.go @@ -28,4 +28,6 @@ var ( ErrDuplicateSignature = errors.New("duplicate signature") ErrCrossPriceIsNotZero = errors.New("cross price is not zero") ErrAlreadySynced = errors.New("already synced") + ErrCantSendRPCRequest = errors.New("can't send rpc request") + ErrCantReceiveRPCResponse = errors.New("can't receive rpc response") ) diff --git a/internal/consts/meta.go b/internal/consts/meta.go index 528747dc..5be62fb8 100644 --- a/internal/consts/meta.go +++ b/internal/consts/meta.go @@ -1,4 +1,4 @@ package consts -var Version = "0.12.0" -var ProtocolVersion = "0.12.0" +var Version = "0.13.0-ai-preview" +var ProtocolVersion = "0.13.0-ai-preview" diff --git a/internal/consts/opcodes.go b/internal/consts/opcodes.go index f0a1c4c0..ef89c430 100644 --- a/internal/consts/opcodes.go +++ b/internal/consts/opcodes.go @@ -22,4 +22,9 @@ const ( OpCodeCorrectnessReport OpCode = 10 OpCodeCorrectnessReportBroadcast OpCode = 11 + + OpCodeRegisterRPCFunction OpCode = 12 + + OpCodeRPCRequest OpCode = 13 + OpCodeRPCResponse OpCode = 14 ) diff --git a/internal/precompiles/ai/embed.go b/internal/precompiles/ai/embed.go new file mode 100644 index 00000000..b12b4120 --- /dev/null +++ b/internal/precompiles/ai/embed.go @@ -0,0 +1,41 @@ +// embed.go +package ai + +import ( + "embed" + "io/fs" + "os" + "path/filepath" +) + +//go:embed python_files/* +var pythonFiles embed.FS + +// ExtractPythonFiles extracts the embedded Python files to the target directory. +func ExtractPythonFiles(targetDir string) error { + return fs.WalkDir(pythonFiles, ".", func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if d.IsDir() { + return nil + } + data, err := pythonFiles.ReadFile(path) + if err != nil { + return err + } + destPath := filepath.Join(targetDir, path) + if err := os.MkdirAll(filepath.Dir(destPath), 0755); err != nil { + return err + } + return os.WriteFile(destPath, data, 0644) //nolint: gosec // Other users may need to read these files. + }) +} + +func GetRequirementsFile() []byte { + content, err := pythonFiles.ReadFile("python_files/requirements.txt") + if err != nil { + panic(err) + } + return content +} diff --git a/internal/precompiles/ai/python_files/chat.py b/internal/precompiles/ai/python_files/chat.py new file mode 100644 index 00000000..e4993117 --- /dev/null +++ b/internal/precompiles/ai/python_files/chat.py @@ -0,0 +1,56 @@ +import torch +from transformers import pipeline +import argparse + + +def initialize_pipeline(): + return pipeline( + "text-generation", + model="HuggingFaceH4/zephyr-7b-beta", + torch_dtype=torch.bfloat16, + device_map="mps" + ) + + +def generate_response(pipe, messages): + prompt = pipe.tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + outputs = pipe( + prompt, + max_new_tokens=512, + do_sample=True, + temperature=0.7, + top_k=50, + top_p=0.95, + ) + + return outputs[0]["generated_text"].strip() + + +def main(): + parser = argparse.ArgumentParser(description="CLI Chatbot") + parser.add_argument("--system_message", type=str, + default="You are god, and you are evil.") + args = parser.parse_args() + + pipe = initialize_pipeline() + + messages = [{"role": "system", "content": args.system_message}] + + print("Type your messages below (type 'exit' to quit):") + while True: + user_message = input("You: ") + if user_message.lower() == "exit": + break + + messages.append({"role": "user", "content": user_message}) + response = generate_response(pipe, messages) + messages.append({"role": "assistant", "content": response}) + + print(f"Bot: {response}") + + +if __name__ == "__main__": + main() diff --git a/internal/precompiles/ai/python_files/classify.py b/internal/precompiles/ai/python_files/classify.py new file mode 100644 index 00000000..788a1e02 --- /dev/null +++ b/internal/precompiles/ai/python_files/classify.py @@ -0,0 +1,37 @@ +import sys +from transformers import AutoImageProcessor, AutoModelForImageClassification +from PIL import Image +import torch + + +def classify_image(image_path): + # Load the image processor and model + processor = AutoImageProcessor.from_pretrained( + "google/vit-base-patch16-224") + model = AutoModelForImageClassification.from_pretrained( + "google/vit-base-patch16-224") + + # Open the image file + image = Image.open(image_path).convert("RGB") + + # Preprocess the image + inputs = processor(images=image, return_tensors="pt") + + # Forward pass through the model + with torch.no_grad(): + outputs = model(**inputs) + + # Get the predicted label + logits = outputs.logits + predicted_class_idx = logits.argmax(-1).item() + return model.config.id2label[predicted_class_idx] + + +if __name__ == "__main__": + if len(sys.argv) != 2: + print("Usage: python image_classifier.py ") + sys.exit(1) + + image_path = sys.argv[1] + predicted_label = classify_image(image_path) + print(f"Predicted class: {predicted_label}") diff --git a/internal/precompiles/ai/python_files/detect.py b/internal/precompiles/ai/python_files/detect.py new file mode 100644 index 00000000..1aaf29ba --- /dev/null +++ b/internal/precompiles/ai/python_files/detect.py @@ -0,0 +1,54 @@ +import sys +from transformers import AutoImageProcessor, AutoModelForObjectDetection +from PIL import Image, ImageDraw, ImageFont +import torch + + +def detect_objects(image_path, font_path='noto.ttf', font_size=32): + # Load the image processor and model + processor = AutoImageProcessor.from_pretrained("hustvl/yolos-small") + model = AutoModelForObjectDetection.from_pretrained("hustvl/yolos-small") + + # Open the image file + image = Image.open(image_path).convert("RGB") + + # Preprocess the image + inputs = processor(images=image, return_tensors="pt") + + # Forward pass through the model + with torch.no_grad(): + outputs = model(**inputs) + + # Process outputs + target_sizes = torch.tensor([image.size[::-1]]) + results = processor.post_process_object_detection( + outputs, target_sizes=target_sizes, threshold=0.9)[0] + + # Load the font + font = ImageFont.truetype(font_path, font_size) + + # Draw bounding boxes and labels on the image + draw = ImageDraw.Draw(image) + for score, label, box in zip(results["scores"], results["labels"], results["boxes"]): + box = [round(i, 2) for i in box.tolist()] + draw.rectangle(box, outline="red", width=3) + text = f"{model.config.id2label[label.item()]}: { + round(score.item(), 3)}" + text_bbox = draw.textbbox((box[0], box[1]), text, font=font) + text_location = (box[0], box[1] - (text_bbox[3] - text_bbox[1])) + draw.rectangle( + [text_location, (text_bbox[2], text_bbox[3])], fill="red") + draw.text((box[0], box[1] - (text_bbox[3] - text_bbox[1])), + text, fill="white", font=font) + + return image + + +if __name__ == "__main__": + if len(sys.argv) != 2: + print("Usage: python object_detector.py ") + sys.exit(1) + + image_path = sys.argv[1] + detected_image = detect_objects(image_path) + detected_image.show() # Display the image with detections diff --git a/internal/precompiles/ai/python_files/gen.py b/internal/precompiles/ai/python_files/gen.py new file mode 100644 index 00000000..5184db66 --- /dev/null +++ b/internal/precompiles/ai/python_files/gen.py @@ -0,0 +1,134 @@ +import sys +import torch +from diffusers import DiffusionPipeline +import io +import os +from sia import Sia +from torch_device import get_device + +pipelines = {} + +OPEN_SOURCE_MODELS = [ + "segmind/SSD-1B", + "Corcelio/mobius", + "segmind/Segmind-Vega", + "Corcelio/openvision", + "SimianLuo/LCM_Dreamshaper_v7", + "OEvortex/PixelGen" +] + +NON_FREE_MODELS = [ + "fluently/Fluently-XL-Final", + "alvdansen/littletinies", + "cagliostrolab/animagine-xl-3.1", + "SG161222/Realistic_Vision_V6.0_B1_noVAE", + "Lykon/dreamshaper-xl-v2-turbo", + "UnfilteredAI/NSFW-gen-v2.1", +] + + +def get_pipeline(model_name, lora_weights=None): + key = (model_name + ":::" + lora_weights) if lora_weights else model_name + + if key not in pipelines: + pipelines[key] = DiffusionPipeline.from_pretrained( + model_name, torch_dtype=torch.float16) + + if lora_weights: + pipelines[key].load_lora_weights(lora_weights) + + pipelines[key].to(get_device()) + + pipelines[key].safety_checker = lambda images, **kwargs: ( + images, [False] * len(images)) + + return pipelines[key] + + +def image_to_bytes(image): + byte_io = io.BytesIO() + image.save(byte_io, format="PNG") + return byte_io.getvalue() + + +def parse_packet(packet: Sia): + # get uuid v7 from packet + uuid = packet.read_byte_array_n(16) + model = packet.read_string8() + lora_weights = packet.read_string8() + steps = packet.read_uint8() + # get prompt length from packet (little endian uint16 at offset 17) + prompt = packet.read_string16() + negative_prompt = packet.read_string16() + return { + "uuid": uuid, + "prompt": prompt, + "model": model, + "lora_weights": lora_weights, + "negative_prompt": negative_prompt, + "steps": steps + } + + +def pack_response_packet(uuid, response: bytes): + return Sia().add_byte_array_n(uuid).add_byte_array32(response).content + + +def request_handler(packet): + parsed = parse_packet(packet) + pipe = get_pipeline(parsed["model"], parsed["lora_weights"]) + + images = pipe( + prompt=parsed["prompt"], + num_inference_steps=parsed["steps"], + guidance_scale=5.0, + lcm_origin_steps=50, + height=1024, + width=1024, + negative_prompt=parsed["negative_prompt"], + num_images_per_prompt=1, + output_type="pil").images + + response = image_to_bytes(images[0]) + + return pack_response_packet(parsed["uuid"], response) + + +if __name__ == "__main__": + if len(sys.argv) != 2: + print("Usage: python script.py ''") + sys.exit(1) + + caption = sys.argv[1] + + for model_name in [*OPEN_SOURCE_MODELS, *NON_FREE_MODELS]: + pipe = DiffusionPipeline.from_pretrained( + model_name, torch_dtype=torch.float32) + + pipe.to(get_device()) + + lora_weights = os.getenv("IMAGE_TO_TEXT_LORA_WEIGHTS") + + if lora_weights: + pipe.load_lora_weights(lora_weights) + + pipe.safety_checker = lambda images, **kwargs: ( + images, [False] * len(images)) + + num_inference_steps_str = os.getenv( + "IMAGE_TO_TEXT_STEPS") or "32" + num_inference_steps = int(num_inference_steps_str) + + images = pipe( + prompt=caption, + num_inference_steps=num_inference_steps, + guidance_scale=4.0, + lcm_origin_steps=50, + output_type="pil").images + + # Save the generated images + for idx, img in enumerate(images): + model_name_stripped = model_name.replace("/", "_") + img.save(f"generated_image_{model_name_stripped}_{idx}.png") + + print("Images saved successfully.") diff --git a/internal/precompiles/ai/python_files/main.py b/internal/precompiles/ai/python_files/main.py new file mode 100644 index 00000000..fd44c839 --- /dev/null +++ b/internal/precompiles/ai/python_files/main.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python + +import asyncio +from websockets.server import serve +import gen +import translate +from sia import Sia +import warnings + +# Ignore all warnings +warnings.filterwarnings("ignore") + + +async def echo(websocket): + async for message in websocket: + sia = Sia().set_content(bytearray(message)) + opcode = sia.read_uint16() + if opcode == 0: + response = gen.request_handler(sia) + await websocket.send(bytes(response)) + elif opcode == 1: + response = translate.request_handler(sia) + await websocket.send(bytes(response)) + else: + await websocket.send(b"\0") + + +async def main(): + async with serve(echo, "127.0.0.1", 8765): + await asyncio.Future() # run forever + +if __name__ == "__main__": + print("Server started on ws://localhost:8765") + asyncio.run(main()) diff --git a/internal/precompiles/ai/python_files/requirements.txt b/internal/precompiles/ai/python_files/requirements.txt new file mode 100644 index 00000000..6471c958 --- /dev/null +++ b/internal/precompiles/ai/python_files/requirements.txt @@ -0,0 +1,38 @@ +accelerate==0.31.0 +certifi==2024.6.2 +charset-normalizer==3.3.2 +click==8.1.7 +diffusers==0.29.1 +filelock==3.15.4 +flash-attention==1.0.0 +fsspec==2024.6.0 +huggingface-hub==0.23.4 +idna==3.7 +importlib-metadata==7.2.0 +jinja2==3.1.4 +joblib==1.4.2 +MarkupSafe==2.1.5 +mpmath==1.3.0 +networkx==3.1 +numpy==1.24.4 +packaging==24.1 +peft==0.11.1 +pillow==10.3.0 +psutil==6.0.0 +PyYAML==6.0.1 +regex==2024.5.15 +requests==2.32.3 +sacremoses==0.1.1 +safetensors==0.4.3 +sentencepiece==0.2.0 +sympy==1.12.1 +timm==1.0.7 +tokenizers==0.19.1 +torch==2.3.1 +torchvision==0.18.1 +tqdm==4.66.4 +transformers==4.41.2 +typing-extensions==4.12.2 +urllib3==2.2.2 +websockets==12.0 +zipp==3.19.2 diff --git a/internal/precompiles/ai/python_files/sia.py b/internal/precompiles/ai/python_files/sia.py new file mode 100644 index 00000000..628d3a54 --- /dev/null +++ b/internal/precompiles/ai/python_files/sia.py @@ -0,0 +1,256 @@ +class Sia: + def __init__(self): + self.index = 0 + self.content = bytearray() + + def seek(self, index): + self.index = index + return self + + def set_content(self, content): + self.content = bytearray(content) + return self + + def embed_sia(self, sia): + self.content.extend(sia.content) + return self + + def embed_bytes(self, bytes): + self.content.extend(bytes) + return self + + def add_uint8(self, n): + self.content.extend(n.to_bytes(1, 'little')) + return self + + def read_uint8(self): + if self.index >= len(self.content): + raise ValueError("Not enough data to read uint8") + value = int.from_bytes( + self.content[self.index:self.index + 1], 'little') + self.index += 1 + return value + + def add_int8(self, n): + self.content.extend(n.to_bytes(1, 'little', signed=True)) + return self + + def read_int8(self): + if self.index >= len(self.content): + raise ValueError("Not enough data to read int8") + value = int.from_bytes( + self.content[self.index:self.index + 1], 'little', signed=True) + self.index += 1 + return value + + def add_uint16(self, n): + self.content.extend(n.to_bytes(2, 'little')) + return self + + def read_uint16(self): + if self.index + 2 > len(self.content): + raise ValueError("Not enough data to read uint16") + value = int.from_bytes( + self.content[self.index:self.index + 2], 'little') + self.index += 2 + return value + + def add_int16(self, n): + self.content.extend(n.to_bytes(2, 'little', signed=True)) + return self + + def read_int16(self): + if self.index + 2 > len(self.content): + raise ValueError("Not enough data to read int16") + value = int.from_bytes( + self.content[self.index:self.index + 2], 'little', signed=True) + self.index += 2 + return value + + def add_uint32(self, n): + self.content.extend(n.to_bytes(4, 'little')) + return self + + def read_uint32(self): + if self.index + 4 > len(self.content): + raise ValueError("Not enough data to read uint32") + value = int.from_bytes( + self.content[self.index:self.index + 4], 'little') + self.index += 4 + return value + + def add_int32(self, n): + self.content.extend(n.to_bytes(4, 'little', signed=True)) + return self + + def read_int32(self): + if self.index + 4 > len(self.content): + raise ValueError("Not enough data to read int32") + value = int.from_bytes( + self.content[self.index:self.index + 4], 'little', signed=True) + self.index += 4 + return value + + def add_uint64(self, n): + self.content.extend(n.to_bytes(8, 'little')) + return self + + def read_uint64(self): + if self.index + 8 > len(self.content): + raise ValueError("Not enough data to read uint64") + value = int.from_bytes( + self.content[self.index:self.index + 8], 'little') + self.index += 8 + return value + + def add_int64(self, n): + self.content.extend(n.to_bytes(8, 'little', signed=True)) + return self + + def read_int64(self): + if self.index + 8 > len(self.content): + raise ValueError("Not enough data to read int64") + value = int.from_bytes( + self.content[self.index:self.index + 8], 'little', signed=True) + self.index += 8 + return value + + def add_string8(self, s): + encoded_string = s.encode('utf-8') + return self.add_byte_array8(encoded_string) + + def read_string_n(self, length): + if self.index + length > len(self.content): + raise ValueError("Not enough data to read string") + bytes = self.content[self.index:self.index + length] + self.index += length + return bytes.decode('utf-8') + + def write_string_n(self, s): + encoded_string = s.encode('utf-8') + self.content.extend(encoded_string) + return self + + def read_string8(self): + return self.read_byte_array8().decode('utf-8') + + def add_string16(self, s): + encoded_string = s.encode('utf-8') + return self.add_byte_array16(encoded_string) + + def read_string16(self): + return self.read_byte_array16().decode('utf-8') + + def add_string32(self, s): + encoded_string = s.encode('utf-8') + return self.add_byte_array32(encoded_string) + + def read_string32(self): + return self.read_byte_array32().decode('utf-8') + + def add_string64(self, s): + encoded_string = s.encode('utf-8') + return self.add_byte_array64(encoded_string) + + def read_string64(self): + return self.read_byte_array64().decode('utf-8') + + def add_byte_array_n(self, bytes): + self.content.extend(bytes) + return self + + def add_byte_array8(self, bytes): + return self.add_uint8(len(bytes)).add_byte_array_n(bytes) + + def add_byte_array16(self, bytes): + return self.add_uint16(len(bytes)).add_byte_array_n(bytes) + + def add_byte_array32(self, bytes): + return self.add_uint32(len(bytes)).add_byte_array_n(bytes) + + def add_byte_array64(self, bytes): + return self.add_uint64(len(bytes)).add_byte_array_n(bytes) + + def read_byte_array_n(self, length): + if self.index + length > len(self.content): + raise ValueError("Not enough data to read byte array") + bytes = self.content[self.index:self.index + length] + self.index += length + return bytes + + def read_byte_array8(self): + length = self.read_uint8() + return self.read_byte_array_n(length) + + def read_byte_array16(self): + length = self.read_uint16() + return self.read_byte_array_n(length) + + def read_byte_array32(self): + length = self.read_uint32() + return self.read_byte_array_n(length) + + def read_byte_array64(self): + length = self.read_uint64() + return self.read_byte_array_n(length) + + def add_bool(self, b): + bool_byte = 1 if b else 0 + self.content.extend(bool_byte.to_bytes(1, 'little')) + return self + + def read_bool(self): + if self.index >= len(self.content): + raise ValueError("Not enough data to read bool") + value = self.content[self.index] == 1 + self.index += 1 + return value + + def add_big_int(self, n): + hex_str = n.to_bytes((n.bit_length() + 7) // 8, 'little').hex() + bytes = bytearray.fromhex(hex_str) + return self.add_byte_array8(bytes) + + def read_big_int(self): + bytes = self.read_byte_array8() + return int.from_bytes(bytes, 'little') + + def add_array8(self, arr, fn): + self.add_uint8(len(arr)) + for item in arr: + fn(self, item) + return self + + def read_array8(self, fn): + length = self.read_uint8() + return [fn(self) for _ in range(length)] + + def add_array16(self, arr, fn): + self.add_uint16(len(arr)) + for item in arr: + fn(self, item) + return self + + def read_array16(self, fn): + length = self.read_uint16() + return [fn(self) for _ in range(length)] + + def add_array32(self, arr, fn): + self.add_uint32(len(arr)) + for item in arr: + fn(self, item) + return self + + def read_array32(self, fn): + length = self.read_uint32() + return [fn(self) for _ in range(length)] + + def add_array64(self, arr, fn): + self.add_uint64(len(arr)) + for item in arr: + fn(self, item) + return self + + def read_array64(self, fn): + length = self.read_uint64() + return [fn(self) for _ in range(length)] diff --git a/internal/precompiles/ai/python_files/torch_device.py b/internal/precompiles/ai/python_files/torch_device.py new file mode 100644 index 00000000..e6e69fcf --- /dev/null +++ b/internal/precompiles/ai/python_files/torch_device.py @@ -0,0 +1,13 @@ +import torch + + +def get_device(): + if torch.cuda.is_available(): + return "cuda" + # detect if m1/m2/m3 + if torch.backends.mps.is_available(): + return "mps" + # detect if vulkan (for android, raspberry pi, etc.) + if torch.is_vulkan_available(): + return "vulkan" + return "cpu" diff --git a/internal/precompiles/ai/python_files/translate.py b/internal/precompiles/ai/python_files/translate.py new file mode 100644 index 00000000..66f8e018 --- /dev/null +++ b/internal/precompiles/ai/python_files/translate.py @@ -0,0 +1,49 @@ +# Use a pipeline as a high-level helper +from transformers import pipeline +from sia import Sia +import sys + +loaded_models = {} + +def parse_packet(packet: Sia): + uuid = packet.read_byte_array_n(16) + fromLang = packet.read_string_n(2) + toLang = packet.read_string_n(2) + prompt = packet.read_string16() + return {"uuid": uuid, "from": fromLang, "to": toLang, "prompt": prompt} + + +def pack_response_packet(uuid, response: str): + return Sia().add_byte_array_n(uuid).add_string16(response).content + + +def request_handler(packet: Sia): + data = parse_packet(packet) + fromLang = data["from"] + toLang = data["to"] + prompt = data["prompt"] + + key = f"{fromLang}-{toLang}" + if key not in loaded_models: + loaded_models[key] = pipeline( + "translation", model=f"Helsinki-NLP/opus-mt-{fromLang}-{toLang}") + + output = loaded_models[key](prompt) + response = output[0]["translation_text"] + + return pack_response_packet(data["uuid"], response) + +if __name__ == "__main__": + if len(sys.argv) != 4: + print("Usage: python translate.py ") + sys.exit(1) + + fromLang = sys.argv[1] + toLang = sys.argv[2] + text = sys.argv[3] + + pipe = pipeline( + "translation", model=f"Helsinki-NLP/opus-mt-{fromLang}-{toLang}") + + output = pipe(text) + print(output[0]["translation_text"]) diff --git a/internal/service/ai/cache.go b/internal/service/ai/cache.go new file mode 100644 index 00000000..2e6bbf67 --- /dev/null +++ b/internal/service/ai/cache.go @@ -0,0 +1,35 @@ +package ai + +import ( + "sync" + + "github.com/ethereum/go-ethereum/common" +) + +// TxCache is a simple in-memory cache for transaction hashes. +type TxCache struct { + mu sync.Mutex + cache map[common.Hash]struct{} +} + +// NewTxCache creates a new TxCache. +func NewTxCache() *TxCache { + return &TxCache{ + cache: make(map[common.Hash]struct{}), + } +} + +// MarkExpired marks a transaction hash as expired. +func (tc *TxCache) MarkExpired(txHash common.Hash) { + tc.mu.Lock() + defer tc.mu.Unlock() + tc.cache[txHash] = struct{}{} +} + +// IsExpired checks if a transaction hash is marked as expired. +func (tc *TxCache) IsExpired(txHash common.Hash) bool { + tc.mu.Lock() + defer tc.mu.Unlock() + _, expired := tc.cache[txHash] + return expired +} diff --git a/internal/service/ai/common.go b/internal/service/ai/common.go new file mode 100644 index 00000000..92138de5 --- /dev/null +++ b/internal/service/ai/common.go @@ -0,0 +1,41 @@ +package ai + +import ( + "github.com/gorilla/websocket" +) + +func Read(conn *websocket.Conn, closed *bool) <-chan []byte { + out := make(chan []byte) + + go func() { + defer close(out) + for { + _, payload, err := conn.ReadMessage() + if err != nil { + if !*closed { + panic(err) + } + return + } + out <- payload + } + }() + + return out +} + +func CloseSocket(conn *websocket.Conn, closed *bool) { + if conn != nil { + *closed = true + err := conn.WriteMessage( + websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + if err != nil { + return + } + err = conn.Close() + if err != nil { + return + } + } +} diff --git a/internal/service/ai/fees.go b/internal/service/ai/fees.go new file mode 100644 index 00000000..f96a042e --- /dev/null +++ b/internal/service/ai/fees.go @@ -0,0 +1,75 @@ +package ai + +import ( + "context" + "fmt" + "math/big" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/ethclient" +) + +// TxChecker contains the Ethereum client and transaction cache. +type TxChecker struct { + client *ethclient.Client + txCache *TxCache +} + +// NewTxChecker creates a new TxChecker. +func NewTxChecker(clientURL string) (*TxChecker, error) { + client, err := ethclient.Dial(clientURL) + if err != nil { + return nil, fmt.Errorf("failed to connect to the Ethereum client: %w", err) + } + return &TxChecker{ + client: client, + txCache: NewTxCache(), + }, nil +} + +// CheckTransaction checks if a transaction meets the criteria. +func (tc *TxChecker) CheckTransaction(txHash common.Hash, toAddress common.Address, amount *big.Int) (bool, error) { + if tc.txCache.IsExpired(txHash) { + return false, fmt.Errorf("transaction is expired") + } + + tx, isPending, err := tc.client.TransactionByHash(context.Background(), txHash) + if err != nil { + return false, fmt.Errorf("could not retrieve transaction: %w", err) + } + + if isPending { + return false, fmt.Errorf("transaction is still pending") + } + + receipt, err := tc.client.TransactionReceipt(context.Background(), txHash) + if err != nil { + return false, fmt.Errorf("could not retrieve transaction receipt: %w", err) + } + + if receipt.Status != 1 { + return false, fmt.Errorf("transaction failed") + } + + header, err := tc.client.HeaderByNumber(context.Background(), receipt.BlockNumber) + if err != nil { + return false, fmt.Errorf("could not retrieve block header: %w", err) + } + + blockTime := time.Unix(int64(header.Time), 0) + if time.Since(blockTime) > 5*time.Minute { + tc.txCache.MarkExpired(txHash) + return false, fmt.Errorf("transaction is older than 5 minutes") + } + + if tx.To() == nil || *tx.To() != toAddress { + return false, nil + } + + if tx.Value().Cmp(amount) != 0 { + return false, nil + } + + return true, nil +} diff --git a/internal/service/ai/server.go b/internal/service/ai/server.go new file mode 100644 index 00000000..b34b60f8 --- /dev/null +++ b/internal/service/ai/server.go @@ -0,0 +1,153 @@ +package ai + +import ( + "context" + "fmt" + "log" + "os" + "os/exec" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/TimeleapLabs/unchained/internal/utils" + + "github.com/TimeleapLabs/unchained/internal/config" + "github.com/TimeleapLabs/unchained/internal/precompiles/ai" +) + +func startProcess(ctx context.Context, wg *sync.WaitGroup, env []string, cmdPath string, cmdArgs []string, cmdCwd string) { + defer wg.Done() + + for { + cmd := exec.CommandContext(ctx, cmdPath, cmdArgs...) + cmd.Dir = cmdCwd + cmd.Env = append(os.Environ(), env...) + + // Capture stdout and stderr + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + // Start the process + if err := cmd.Start(); err != nil { + log.Printf("Failed to start process: %v", err) + return + } + + log.Printf("Started process %d", cmd.Process.Pid) + + // Wait for the process to finish or be killed + err := cmd.Wait() + if ctx.Err() != nil { + log.Printf("Process %d was killed", cmd.Process.Pid) + return + } + + if err != nil { + log.Printf("Process %d exited with error: %v. Restarting...", cmd.Process.Pid, err) + } else { + log.Printf("Process %d exited successfully. Restarting...", cmd.Process.Pid) + } + + // Sleep before restarting to avoid rapid restart loops + time.Sleep(2 * time.Second) + } +} + +func StartServer(ctx context.Context) (*sync.WaitGroup, context.CancelFunc) { + targetDir := filepath.Join(config.App.System.Home, "ai") + + // Extract Python files + utils.Logger.Info("Extracting the plugin Python files...") + if err := ai.ExtractPythonFiles(targetDir); err != nil { + log.Fatalf("Failed to extract Python files: %v", err) + } + + // Check if pyenv is installed + _, err := exec.LookPath("pyenv") + if err != nil { + log.Fatal("pyenv not found in PATH") + } + + // Install Python 3.8.10 using pyenv if not already installed + checkPythonCmd := exec.Command("pyenv", "versions", "--bare") + output, err := checkPythonCmd.Output() + if err != nil || !strings.Contains(string(output), "3.8.10") { + utils.Logger.Info("Installing Python 3.8.10...") + installPythonCmd := exec.Command("pyenv", "install", "-s", "3.8.10") + // installPythonCmd.Stdout = os.Stdout + // installPythonCmd.Stderr = os.Stderr + if err := installPythonCmd.Run(); err != nil { + log.Fatalf("Failed to install Python 3.8.10: %v", err) + } + } + + // Select Python 3.8.10 as the local version + utils.Logger.Info("Selecting Python 3.8.10...") + selectPythonCmd := exec.Command("pyenv", "local", "3.8.10") + selectPythonCmd.Dir = targetDir + // selectPythonCmd.Stdout = os.Stdout + // selectPythonCmd.Stderr = os.Stderr + if err := selectPythonCmd.Run(); err != nil { + log.Fatalf("Failed to select Python 3.8.10: %v", err) + } + + // Get the path to the Python 3.8.10 interpreter + pythonPathCmd := exec.Command("pyenv", "which", "python3.8") + pythonPathCmd.Dir = targetDir + pythonPath, err := pythonPathCmd.Output() + if err != nil { + log.Fatalf("Failed to get Python 3.8.10 path: %v", err) + } + pythonPathStr := strings.TrimSpace(string(pythonPath)) + + // Create a virtual environment with Python 3.8.10 if not already created + venvPath := filepath.Join(targetDir, "venv") + if _, err := os.Stat(filepath.Join(venvPath, "bin", "python")); os.IsNotExist(err) { + utils.Logger.Info("Creating virtual environment...") + createVenvCmd := exec.Command(pythonPathStr, "-m", "venv", venvPath) + // createVenvCmd.Stdout = os.Stdout + // createVenvCmd.Stderr = os.Stderr + if err := createVenvCmd.Run(); err != nil { + log.Fatalf("Failed to create virtual environment: %v", err) + } + } + + // Activate the virtual environment + activateScript := filepath.Join(venvPath, "bin", "activate") + + // Install dependencies if not already installed + pipPath := filepath.Join(venvPath, "bin", "pip") + freezeCmd := exec.Command("bash", "-c", fmt.Sprintf("source %s && exec %s freeze", activateScript, pipPath)) //nolint: gosec // This is a trusted command + output, err = freezeCmd.Output() + if err != nil || string(output) != string(ai.GetRequirementsFile()) { + utils.Logger.Info("Installing dependencies...") + requirementPath := filepath.Join(targetDir, "python_files", "requirements.txt") + installDepsCmd := exec.Command("bash", "-c", fmt.Sprintf("source %s && %s install -r %s", activateScript, pipPath, requirementPath)) //nolint: gosec // This is a trusted command + // installDepsCmd.Stdout = os.Stdout + // installDepsCmd.Stderr = os.Stderr + if err := installDepsCmd.Run(); err != nil { + log.Fatalf("Failed to install dependencies: %v", err) + } + } + + // Set up the context and wait group for process management + ctx, cancel := context.WithCancel(ctx) + var wg sync.WaitGroup + + // Define environment variables + env := []string{ + "HF_HOME=" + filepath.Join(targetDir, "hf_home"), + "PYTHONWARNINGS=ignore", + } + + // Start the process in a separate goroutine + wg.Add(1) + activateScript = filepath.Join("venv", "bin", "activate") + mainPyPath := filepath.Join("python_files", "main.py") + pythonCommand := fmt.Sprintf("source %s && exec python %s", activateScript, mainPyPath) + go startProcess(ctx, &wg, env, "bash", []string{"-c", pythonCommand}, targetDir) + + return &wg, cancel +} diff --git a/internal/service/ai/text_to_image.go b/internal/service/ai/text_to_image.go new file mode 100644 index 00000000..cd811065 --- /dev/null +++ b/internal/service/ai/text_to_image.go @@ -0,0 +1,81 @@ +package ai + +import ( + "fmt" + "time" + + "github.com/google/uuid" + "github.com/gorilla/websocket" + sia "github.com/pouya-eghbali/go-sia/v2/pkg" +) + +func TextToImage(prompt string, negativePrompt string, model string, loraWeights string, steps uint8) []byte { + conn, _, err := websocket.DefaultDialer.Dial( + "ws://localhost:8765", nil, + ) + if err != nil { + panic(err) + } + + closed := false + defer CloseSocket(conn, &closed) + + incoming := Read(conn, &closed) + + requestUUID, err := uuid.NewV7() + if err != nil { + panic(err) + } + + uuidBytes, err := requestUUID.MarshalBinary() + if err != nil { + panic(err) + } + + payload := sia.New(). + AddUInt16(0). + AddByteArrayN(uuidBytes). + AddString8(model). + AddString8(loraWeights). + AddUInt8(steps). + AddString16(prompt). + AddString16(negativePrompt). + Bytes() + + err = conn.WriteMessage(websocket.BinaryMessage, payload) + if err != nil { + panic(err) + } + + // Start a goroutine to print dots every second + stopDots := make(chan struct{}) + go func() { + for { + select { + case <-stopDots: + return + case <-time.After(1 * time.Second): + fmt.Print(".") //nolint:forbidigo // This is a CLI tool + } + } + }() + + data := <-incoming + + // Stop the dot-printing goroutine + close(stopDots) + + // process data + s := sia.NewFromBytes(data) + uuidBytesFromResponse := s.ReadByteArrayN(16) + responseUUID, err := uuid.FromBytes(uuidBytesFromResponse) + if err != nil { + panic(err) + } + + if requestUUID != responseUUID { + panic("UUID mismatch") + } + + return s.ReadByteArray32() +} diff --git a/internal/service/rpc/coordinator.go b/internal/service/rpc/coordinator.go new file mode 100644 index 00000000..e73a8873 --- /dev/null +++ b/internal/service/rpc/coordinator.go @@ -0,0 +1,79 @@ +package rpc + +import ( + "github.com/TimeleapLabs/unchained/internal/transport/server/websocket/store" + "github.com/google/uuid" + "github.com/gorilla/websocket" + "golang.org/x/exp/rand" +) + +// Coordinator is a struct that holds the tasks and workers. +type Coordinator struct { + Tasks map[uuid.UUID]*websocket.Conn + Workers map[string][]*websocket.Conn +} + +// RegisterTask will register a task which a connection provide. +func (r *Coordinator) RegisterTask(taskID uuid.UUID, conn *websocket.Conn) { + r.Tasks[taskID] = conn +} + +// UnregisterTask will unregister a task which a connection provide. +func (r *Coordinator) UnregisterTask(taskID uuid.UUID) { + delete(r.Tasks, taskID) +} + +// GetTask will return a task which a connection provide. +func (r *Coordinator) GetTask(taskID uuid.UUID) *websocket.Conn { + return r.Tasks[taskID] +} + +// RegisterWorker will register a worker which a connection provide. +func (r *Coordinator) RegisterWorker(function string, conn *websocket.Conn) { + r.Workers[function] = append(r.Workers[function], conn) +} + +// UnregisterWorker will unregister a worker which a connection provide. +func (r *Coordinator) UnregisterWorker(function string, conn *websocket.Conn) { + workers := r.Workers[function] + for i, c := range workers { + if c == conn { + r.Workers[function] = append(workers[:i], workers[i+1:]...) + break + } + } +} + +// GetWorkers will return a list of workers which provide a function. +func (r *Coordinator) GetWorkers(function string) []*websocket.Conn { + return r.Workers[function] +} + +// GetRandomWorker will return a random worker which provide a function. +func (r *Coordinator) GetRandomWorker(function string) *websocket.Conn { + workers := r.Workers[function] + available := make([]*websocket.Conn, 0, len(workers)) + + for _, worker := range workers { + if _, ok := store.Signers.Load(worker); ok { + available = append(available, worker) + } + } + + if len(available) == 0 { + return nil + } + + r.Workers[function] = available + random := rand.Intn(len(available)) + + return available[random] +} + +// NewCoordinator creates a new Coordinator. +func NewCoordinator() *Coordinator { + return &Coordinator{ + Tasks: make(map[uuid.UUID]*websocket.Conn), + Workers: make(map[string][]*websocket.Conn), + } +} diff --git a/internal/service/rpc/coordinator_test.go b/internal/service/rpc/coordinator_test.go new file mode 100644 index 00000000..1c54d5c7 --- /dev/null +++ b/internal/service/rpc/coordinator_test.go @@ -0,0 +1,52 @@ +package rpc + +import ( + "testing" + + "github.com/TimeleapLabs/unchained/internal/utils" + "github.com/google/uuid" + "github.com/gorilla/websocket" + "github.com/stretchr/testify/suite" +) + +type CoordinatorTestSuite struct { + suite.Suite + service *Coordinator +} + +func (s *CoordinatorTestSuite) SetupTest() { + utils.SetupLogger("info") + + s.service = NewCoordinator() +} + +func (s *CoordinatorTestSuite) TestCoordinator_RegisterWorker() { + conn := &websocket.Conn{} + s.service.RegisterWorker("test-worker", conn) + gotConns := s.service.GetWorkers("test-worker") + s.Len(gotConns, 1) + s.Equal(conn, gotConns[0]) + + s.service.UnregisterWorker("test-worker", conn) + gotConns = s.service.GetWorkers("test-worker") + s.Len(gotConns, 0) +} + +func (s *CoordinatorTestSuite) TestCoordinator_RegisterTask() { + conn := &websocket.Conn{} + + taskID, err := uuid.NewUUID() + s.NoError(err) + + s.service.RegisterTask(taskID, conn) + gotConn := s.service.GetTask(taskID) + s.Equal(conn, gotConn) + + s.service.UnregisterTask(taskID) + gotConn = s.service.GetTask(taskID) + s.Nil(gotConn) +} + +func TestCoordinatorSuite(t *testing.T) { + suite.Run(t, new(CoordinatorTestSuite)) +} diff --git a/internal/service/rpc/dto.go b/internal/service/rpc/dto.go new file mode 100644 index 00000000..30bcdfed --- /dev/null +++ b/internal/service/rpc/dto.go @@ -0,0 +1,55 @@ +package rpc + +import ( + sia "github.com/pouya-eghbali/go-sia/v2/pkg" +) + +type TextToImageRPCRequestParams struct { + // The text to be converted to an image + Prompt string `json:"prompt"` + NegativePrompt string `json:"negativePrompt"` + // The model to be used + Model string `json:"model"` + // The weights of the model + LoraWeights string `json:"loraWeights"` + // The number of steps to run + Steps uint8 `json:"steps"` +} + +type TextToImageRPCResponseParams struct { + // The image in bytes + Image []byte `json:"image"` +} + +func (t *TextToImageRPCRequestParams) Sia() sia.Sia { + return sia.New(). + AddString16(t.Prompt). + AddString16(t.NegativePrompt). + AddString8(t.Model). + AddString8(t.LoraWeights). + AddUInt8(t.Steps) +} + +func (t *TextToImageRPCRequestParams) FromSiaBytes(bytes []byte) *TextToImageRPCRequestParams { + s := sia.NewFromBytes(bytes) + + t.Prompt = s.ReadString16() + t.NegativePrompt = s.ReadString16() + t.Model = s.ReadString8() + t.LoraWeights = s.ReadString8() + t.Steps = s.ReadUInt8() + + return t +} + +func (t *TextToImageRPCResponseParams) Sia() sia.Sia { + return sia.New(). + AddByteArray32(t.Image) +} + +func (t *TextToImageRPCResponseParams) FromSiaBytes(bytes []byte) *TextToImageRPCResponseParams { + s := sia.NewFromBytes(bytes) + t.Image = s.ReadByteArray32() + + return t +} diff --git a/internal/service/rpc/dto/register.go b/internal/service/rpc/dto/register.go new file mode 100644 index 00000000..52b16019 --- /dev/null +++ b/internal/service/rpc/dto/register.go @@ -0,0 +1,22 @@ +package dto + +import sia "github.com/pouya-eghbali/go-sia/v2/pkg" + +// RegisterFunction is a DTO for registering a function. +type RegisterFunction struct { + Function string `json:"function"` + Runtime string `json:"runtime"` +} + +func (t *RegisterFunction) Sia() sia.Sia { + return sia.New(). + AddString8(t.Function) +} + +func (t *RegisterFunction) FromSiaBytes(bytes []byte) *RegisterFunction { + s := sia.NewFromBytes(bytes) + + t.Function = s.ReadString8() + + return t +} diff --git a/internal/service/rpc/dto/request.go b/internal/service/rpc/dto/request.go new file mode 100644 index 00000000..28d16ab4 --- /dev/null +++ b/internal/service/rpc/dto/request.go @@ -0,0 +1,71 @@ +package dto + +import ( + "github.com/google/uuid" + sia "github.com/pouya-eghbali/go-sia/v2/pkg" +) + +// RPCRequest is the request of a RPC request. +type RPCRequest struct { + // The ID of the request + ID uuid.UUID `json:"id"` + // The signature of the request + Signature [48]byte `json:"signature"` + // Payment information + TxHash string `json:"tx_hash"` + // The method to be called + Method string `json:"method"` + // params to pass to the function + Params []byte `json:"params"` +} + +// NewRequest creates a new request with unique ID. +func NewRequest(method string, params []byte, signature [48]byte, txHash string) RPCRequest { + taskID, err := uuid.NewV7() + if err != nil { + panic(err) + } + + return RPCRequest{ + ID: taskID, + Method: method, + Params: params, + Signature: signature, + TxHash: txHash, + } +} + +func (t *RPCRequest) Sia() sia.Sia { + uuidBytes, err := t.ID.MarshalBinary() + + if err != nil { + panic(err) + } + + return sia.New(). + AddByteArray8(uuidBytes). + AddByteArray8(t.Signature[:]). + AddString8(t.TxHash). + AddString8(t.Method). + EmbedBytes(t.Params) +} + +func (t *RPCRequest) FromSiaBytes(bytes []byte) *RPCRequest { + s := sia.NewFromBytes(bytes) + + uuidBytes := s.ReadByteArray8() + err := t.ID.UnmarshalBinary(uuidBytes) + if err != nil { + panic(err) + // return nil + } + + t.Signature = [48]byte{} + copy(t.Signature[:], s.ReadByteArray8()) + + t.TxHash = s.ReadString8() + t.Method = s.ReadString8() + t.Params = s.Bytes()[s.Offset():] + + return t +} diff --git a/internal/service/rpc/dto/request_test.go b/internal/service/rpc/dto/request_test.go new file mode 100644 index 00000000..eae4779a --- /dev/null +++ b/internal/service/rpc/dto/request_test.go @@ -0,0 +1,17 @@ +package dto + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRequestSia(t *testing.T) { + req := NewRequest("test", []byte("hello world"), [48]byte{}, "txHash") + reqByte := req.Sia().Bytes() + gotReq := new(RPCRequest).FromSiaBytes(reqByte) + + t.Log(req) + t.Log(*gotReq) + assert.Equal(t, req, *gotReq) +} diff --git a/internal/service/rpc/dto/response.go b/internal/service/rpc/dto/response.go new file mode 100644 index 00000000..6837b1d3 --- /dev/null +++ b/internal/service/rpc/dto/response.go @@ -0,0 +1,43 @@ +package dto + +import ( + "github.com/google/uuid" + sia "github.com/pouya-eghbali/go-sia/v2/pkg" +) + +// RPCResponse is the response of a RPC request. +type RPCResponse struct { + // The ID of the request + ID uuid.UUID `json:"id"` + // The error of the function + Error uint64 `json:"error"` + // The response of the function + Response []byte `json:"response"` +} + +func (t *RPCResponse) Sia() sia.Sia { + uuidBytes, err := t.ID.MarshalBinary() + if err != nil { + panic(err) + } + + return sia.New(). + AddByteArray8(uuidBytes). + AddUInt64(t.Error). + EmbedBytes(t.Response) +} + +func (t *RPCResponse) FromSiaBytes(bytes []byte) *RPCResponse { + s := sia.NewFromBytes(bytes) + + uuidBytes := s.ReadByteArray8() + err := t.ID.UnmarshalBinary(uuidBytes) + if err != nil { + return nil + } + + t.Error = s.ReadUInt64() + t.Response = s.Bytes()[s.Offset():] + + return t +} diff --git a/internal/service/rpc/runtime.go b/internal/service/rpc/runtime.go new file mode 100644 index 00000000..e038f9fb --- /dev/null +++ b/internal/service/rpc/runtime.go @@ -0,0 +1,38 @@ +package rpc + +import ( + "net" +) + +// Runtime is a type that holds the runtime of a function. +type Runtime string + +const ( + Mock Runtime = "Mock" + Unix Runtime = "Unix" +) + +func WithMockTask(name string) func(s *Worker) { + return func(s *Worker) { + s.functions[name] = meta{ + runtime: Mock, + } + } +} + +func WithUnixSocket(name string, path string) func(s *Worker) { + return func(s *Worker) { + meta := meta{ + runtime: Unix, + path: path, + } + + var err error + meta.conn, err = net.Dial("unix", path) + if err != nil { + panic(err) + } + + s.functions[name] = meta + } +} diff --git a/internal/service/rpc/runtime/mock.go b/internal/service/rpc/runtime/mock.go new file mode 100644 index 00000000..9049210a --- /dev/null +++ b/internal/service/rpc/runtime/mock.go @@ -0,0 +1,5 @@ +package runtime + +func RunMock(params []byte) ([]byte, error) { + return params, nil +} diff --git a/internal/service/rpc/runtime/unix.go b/internal/service/rpc/runtime/unix.go new file mode 100644 index 00000000..0cdbf6ee --- /dev/null +++ b/internal/service/rpc/runtime/unix.go @@ -0,0 +1,69 @@ +package runtime + +import ( + "context" + "errors" + "io" + "net" + + "github.com/TimeleapLabs/unchained/internal/consts" + "github.com/TimeleapLabs/unchained/internal/service/rpc/dto" + "github.com/TimeleapLabs/unchained/internal/utils" + sia "github.com/pouya-eghbali/go-sia/v2/pkg" +) + +type UnixPayload struct { + Size uint32 + Params []byte +} + +func NewUnixPayload(params *dto.RPCRequest) *UnixPayload { + payload := params.Sia().Bytes() + return &UnixPayload{ + Size: uint32(len(payload)), + Params: payload, + } +} + +func (p *UnixPayload) Sia() sia.Sia { + return sia.New().AddUInt32(p.Size).EmbedBytes(p.Params) +} + +// RunUnixCall runs a function with the given name and parameters. +func RunUnixCall(_ context.Context, conn net.Conn, params *dto.RPCRequest) (*dto.RPCResponse, error) { + payload := NewUnixPayload(params) + _, err := conn.Write(payload.Sia().Bytes()) + if err != nil { + utils.Logger.With("err", err).Error("Error sending message") + return nil, consts.ErrCantSendRPCRequest + } + + // Wait for response + var response []byte + buf := make([]byte, 1024) + payloadSize := int(4) + + for payloadSize > len(response) { + n, err := conn.Read(buf) + if errors.Is(err, io.EOF) { + utils.Logger.Error("Connection closed") + break // End of file or connection closed + } else if err != nil { + utils.Logger.With("err", err).Error("Error receiving response") + return nil, consts.ErrCantReceiveRPCResponse + } + + response = append(response, buf[:n]...) + if payloadSize == 4 && len(response) >= 4 { + payloadSize += int(sia.NewFromBytes(response).ReadUInt32()) + } + + if len(response) >= payloadSize { + break + } + } + + utils.Logger.With("length", len(response)).Info("Received response") + + return new(dto.RPCResponse).FromSiaBytes(response[4:]), nil +} diff --git a/internal/service/rpc/worker.go b/internal/service/rpc/worker.go new file mode 100644 index 00000000..bc308ef0 --- /dev/null +++ b/internal/service/rpc/worker.go @@ -0,0 +1,71 @@ +package rpc + +import ( + "context" + "net" + + "github.com/TimeleapLabs/unchained/internal/consts" + "github.com/TimeleapLabs/unchained/internal/service/rpc/dto" + "github.com/TimeleapLabs/unchained/internal/service/rpc/runtime" + "github.com/TimeleapLabs/unchained/internal/transport/client/conn" + "github.com/TimeleapLabs/unchained/internal/utils" +) + +type Option func(s *Worker) + +// meta is a struct that holds the information of a function. +type meta struct { + runtime Runtime + path string + conn net.Conn +} + +// Worker is a struct that holds the functions that the worker can run. +type Worker struct { + functions map[string]meta +} + +// RunFunction runs a function with the given name and parameters. +func (w *Worker) RunFunction(ctx context.Context, name string, params *dto.RPCRequest) ([]byte, error) { + switch w.functions[name].runtime { + case Unix: + result, err := runtime.RunUnixCall(ctx, w.functions[name].conn, params) + if err != nil { + utils.Logger.With("err", err).Error("Failed to run wasm") + return nil, err + } + + return result.Sia().Bytes(), nil + case Mock: + return runtime.RunMock(params.Sia().Bytes()) + } + + return nil, consts.ErrInternalError +} + +// registerFunction registers a function with the broker. +func (w *Worker) registerFunction(name string, runtime string) { + payload := dto.RegisterFunction{Function: name, Runtime: runtime} + conn.Send(consts.OpCodeRegisterRPCFunction, payload.Sia().Bytes()) +} + +// RegisterFunctions registers the functions with the broker. +func (w *Worker) RegisterFunctions() { + // Register the functions + for name, fun := range w.functions { + w.registerFunction(name, string(fun.runtime)) + } +} + +// NewWorker creates a new worker. +func NewWorker(options ...Option) *Worker { + worker := &Worker{ + functions: make(map[string]meta), + } + + for _, o := range options { + o(worker) + } + + return worker +} diff --git a/internal/service/rpc/worker_test.go b/internal/service/rpc/worker_test.go new file mode 100644 index 00000000..1b145357 --- /dev/null +++ b/internal/service/rpc/worker_test.go @@ -0,0 +1,98 @@ +package rpc + +import ( + "context" + "net" + "os" + "testing" + + "github.com/TimeleapLabs/unchained/internal/consts" + "github.com/TimeleapLabs/unchained/internal/service/rpc/dto" + "github.com/TimeleapLabs/unchained/internal/utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +const ( + UnixSocketPath = "/tmp/test.sock" +) + +var ( + SamplePacket = dto.NewRequest("test", nil, [48]byte{}, "txHash") +) + +type WorkerTestSuite struct { + suite.Suite + service *Worker + server *net.UnixListener +} + +func handleConnection(t *testing.T, conn net.Conn) { + defer conn.Close() + + buf := make([]byte, 1024) + for { + n, err := conn.Read(buf) + assert.NoError(t, err) + + t.Log("Received: ", buf[0:n]) + _, err = conn.Write(buf[0:n]) + assert.NoError(t, err) + } +} + +func (s *WorkerTestSuite) SetupTest() { + utils.SetupLogger("info") + + _ = os.Remove(UnixSocketPath) + + var err error + s.server, err = net.ListenUnix("unix", &net.UnixAddr{Name: UnixSocketPath, Net: "unix"}) + s.Require().NoError(err) + + go func() { + for { + conn, err := s.server.Accept() + s.Require().NoError(err) + + go handleConnection(s.T(), conn) + } + }() + + s.service = NewWorker( + WithMockTask("test"), + WithUnixSocket("unix-test", UnixSocketPath), + ) +} + +func (s *WorkerTestSuite) TestRunFunction() { + s.Run("Should run successfully", func() { + result, err := s.service.RunFunction(context.TODO(), "test", &SamplePacket) + s.NoError(err) + s.Equal(SamplePacket.Sia().Bytes(), result) + }) + + s.Run("Run non-existing func, Should return err", func() { + _, err := s.service.RunFunction(context.TODO(), "non-existing-test", &SamplePacket) + s.Error(err, consts.ErrInternalError) + }) + + s.Run("Should run successfully", func() { + result, err := s.service.RunFunction(context.TODO(), "unix-test", &SamplePacket) + s.NoError(err) + + gotPacket := new(dto.RPCRequest).FromSiaBytes(result) + s.Equal(SamplePacket, *gotPacket) + }) +} + +func (s *WorkerTestSuite) TearDownTest() { + err := s.server.Close() + s.Require().NoError(err) + + _ = os.Remove(UnixSocketPath) +} + +func TestWorkerTestSuite(t *testing.T) { + suite.Run(t, new(WorkerTestSuite)) +} diff --git a/internal/transport/client/client.go b/internal/transport/client/client.go index 5470b5fe..68940435 100644 --- a/internal/transport/client/client.go +++ b/internal/transport/client/client.go @@ -37,13 +37,12 @@ func NewRPC(handler handler.Handler) { conn.Send(consts.OpCodeKoskResult, challenge) case consts.OpCodePriceReportBroadcast: handler.PriceReport(ctx, payload[1:]) - case consts.OpCodeEventLogBroadcast: handler.EventLog(ctx, payload[1:]) - case consts.OpCodeCorrectnessReportBroadcast: handler.CorrectnessReport(ctx, payload[1:]) - + case consts.OpCodeRPCRequest: + handler.RPCRequest(ctx, payload[1:]) default: utils.Logger. With("Code", payload[0]). diff --git a/internal/transport/client/handler/handler.go b/internal/transport/client/handler/handler.go index 71fad11e..61501bee 100644 --- a/internal/transport/client/handler/handler.go +++ b/internal/transport/client/handler/handler.go @@ -9,4 +9,6 @@ type Handler interface { CorrectnessReport(ctx context.Context, message []byte) EventLog(ctx context.Context, message []byte) PriceReport(ctx context.Context, message []byte) + RPCRequest(ctx context.Context, message []byte) + RPCResponse(ctx context.Context, message []byte) } diff --git a/internal/transport/client/handler/rpc.go b/internal/transport/client/handler/rpc.go new file mode 100644 index 00000000..95ba377a --- /dev/null +++ b/internal/transport/client/handler/rpc.go @@ -0,0 +1,49 @@ +package handler + +import ( + "context" + "math/big" + + "github.com/TimeleapLabs/unchained/internal/consts" + "github.com/TimeleapLabs/unchained/internal/service/rpc/dto" + "github.com/TimeleapLabs/unchained/internal/transport/client/conn" + "github.com/TimeleapLabs/unchained/internal/utils" + + "github.com/TimeleapLabs/unchained/internal/service/ai" + "github.com/ethereum/go-ethereum/common" +) + +var TimeleapRPC = "https://devnet.timeleap.swiss/rpc" +var CollectorAddress = common.HexToAddress("0xA2dEc4f8089f89F426e6beB76B555f3Cf9E7f499") + +func (h *consumer) RPCRequest(_ context.Context, _ []byte) {} + +func (w worker) RPCRequest(ctx context.Context, message []byte) { + utils.Logger.Info("RPC Request") + packet := new(dto.RPCRequest).FromSiaBytes(message) + + // check fees + checker, err := ai.NewTxChecker(TimeleapRPC) + if err != nil { + return + } + + // 0.1 TLP + fee, _ := new(big.Int).SetString("100000000000000000", 10) + + ok, err := checker.CheckTransaction(common.HexToHash(packet.TxHash), CollectorAddress, fee) + if err != nil || !ok { + return + } + + response, err := w.rpc.RunFunction(ctx, packet.Method, packet) + if err != nil { + return + } + + conn.Send(consts.OpCodeRPCResponse, response) +} + +func (w worker) RPCResponse(_ context.Context, _ []byte) {} + +func (h *consumer) RPCResponse(_ context.Context, _ []byte) {} diff --git a/internal/transport/client/handler/worker.go b/internal/transport/client/handler/worker.go index b341e66a..1a72f417 100644 --- a/internal/transport/client/handler/worker.go +++ b/internal/transport/client/handler/worker.go @@ -1,8 +1,16 @@ package handler +import "github.com/TimeleapLabs/unchained/internal/service/rpc" + type worker struct { + rpc *rpc.Worker } -func NewWorkerHandler() Handler { - return &worker{} +func NewWorkerHandler(rpc *rpc.Worker) Handler { + // Register the worker functions with the broker + rpc.RegisterFunctions() + + return &worker{ + rpc: rpc, + } } diff --git a/internal/transport/server/pubsub/pubsub.go b/internal/transport/server/pubsub/pubsub.go index 71f1596b..0c542937 100644 --- a/internal/transport/server/pubsub/pubsub.go +++ b/internal/transport/server/pubsub/pubsub.go @@ -7,9 +7,11 @@ import ( "github.com/TimeleapLabs/unchained/internal/consts" ) +// topics is a map of topics to a slice of channels that are subscribed to that topic. var topics = make(map[consts.Channels][]chan []byte) var mu sync.Mutex +// getTopicsByPrefix returns a map of topics that have the given prefix. func getTopicsByPrefix(topic consts.Channels) map[consts.Channels][]chan []byte { keys := make(map[consts.Channels][]chan []byte) for key := range topics { @@ -22,6 +24,7 @@ func getTopicsByPrefix(topic consts.Channels) map[consts.Channels][]chan []byte return keys } +// Publish sends a message to all subscribers of the given topic. func Publish(destinationTopic consts.Channels, operation consts.OpCode, message []byte) { mu.Lock() defer mu.Unlock() @@ -37,6 +40,7 @@ func Publish(destinationTopic consts.Channels, operation consts.OpCode, message } } +// Subscribe creates a new channel and appends it to the list of subscribers for the given topic. func Subscribe(topic string) chan []byte { mu.Lock() defer mu.Unlock() diff --git a/internal/transport/server/server.go b/internal/transport/server/server.go index 47e4eeef..c7bf4b45 100644 --- a/internal/transport/server/server.go +++ b/internal/transport/server/server.go @@ -4,25 +4,40 @@ import ( "fmt" "net/http" - "github.com/TimeleapLabs/unchained/internal/utils" - "github.com/TimeleapLabs/unchained/internal/config" + "github.com/TimeleapLabs/unchained/internal/utils" ) +// New creates a new HTTP server. func New(options ...func()) { for _, option := range options { option() } - utils.Logger. - With("Bind", fmt.Sprintf("http://%s", config.App.Network.Bind)). - Info("Starting a HTTP server") - server := &http.Server{ Addr: config.App.Network.Bind, ReadHeaderTimeout: config.App.Network.BrokerTimeout, } + if config.App.Network.CertFile != "" && config.App.Network.KeyFile != "" { + utils.Logger. + With("Bind", fmt.Sprintf("https://%s", config.App.Network.Bind)). + With("CertFile", config.App.Network.CertFile). + With("KeyFile", config.App.Network.KeyFile). + Info("Starting a HTTPS server") + + err := server.ListenAndServeTLS(config.App.Network.CertFile, config.App.Network.KeyFile) + if err != nil { + panic(err) + } + + return + } + + utils.Logger. + With("Bind", fmt.Sprintf("http://%s", config.App.Network.Bind)). + Info("Starting a HTTP server") + err := server.ListenAndServe() if err != nil { panic(err) diff --git a/internal/transport/server/websocket/handler/correctness.go b/internal/transport/server/websocket/handler/correctness.go index c99a9354..11d3bc6f 100644 --- a/internal/transport/server/websocket/handler/correctness.go +++ b/internal/transport/server/websocket/handler/correctness.go @@ -3,25 +3,30 @@ package handler import ( "github.com/TimeleapLabs/unchained/internal/consts" "github.com/TimeleapLabs/unchained/internal/model" + "github.com/TimeleapLabs/unchained/internal/transport/server/pubsub" "github.com/TimeleapLabs/unchained/internal/transport/server/websocket/middleware" "github.com/gorilla/websocket" ) -func CorrectnessRecord(conn *websocket.Conn, payload []byte) ([]byte, error) { +// CorrectnessRecord is a handler for correctness report. +func CorrectnessRecord(conn *websocket.Conn, payload []byte) { err := middleware.IsConnectionAuthenticated(conn) if err != nil { - return []byte{}, err + SendError(conn, consts.OpCodeError, err) + return } correctness := new(model.CorrectnessReportPacket).FromBytes(payload) correctnessHash, err := correctness.Correctness.Bls() if err != nil { - return []byte{}, consts.ErrInternalError + SendError(conn, consts.OpCodeError, consts.ErrInternalError) + return } signer, err := middleware.IsMessageValid(conn, correctnessHash, correctness.Signature) if err != nil { - return []byte{}, err + SendError(conn, consts.OpCodeError, err) + return } broadcastPacket := model.BroadcastCorrectnessPacket{ @@ -30,5 +35,6 @@ func CorrectnessRecord(conn *websocket.Conn, payload []byte) ([]byte, error) { Signer: signer, } - return broadcastPacket.Sia().Bytes(), nil + pubsub.Publish(consts.ChannelCorrectnessReport, consts.OpCodeCorrectnessReportBroadcast, broadcastPacket.Sia().Bytes()) + SendMessage(conn, consts.OpCodeFeedback, "signature.accepted") } diff --git a/internal/transport/server/websocket/handler/event.go b/internal/transport/server/websocket/handler/event.go index 4aa3d206..f28406a5 100644 --- a/internal/transport/server/websocket/handler/event.go +++ b/internal/transport/server/websocket/handler/event.go @@ -3,25 +3,30 @@ package handler import ( "github.com/TimeleapLabs/unchained/internal/consts" "github.com/TimeleapLabs/unchained/internal/model" + "github.com/TimeleapLabs/unchained/internal/transport/server/pubsub" "github.com/TimeleapLabs/unchained/internal/transport/server/websocket/middleware" "github.com/gorilla/websocket" ) -func EventLog(conn *websocket.Conn, payload []byte) ([]byte, error) { +// EventLog handles the event log packet from the client. +func EventLog(conn *websocket.Conn, payload []byte) { err := middleware.IsConnectionAuthenticated(conn) if err != nil { - return []byte{}, err + SendError(conn, consts.OpCodeError, err) + return } priceReport := new(model.EventLogReportPacket).FromBytes(payload) priceInfoHash, err := priceReport.EventLog.Bls() if err != nil { - return []byte{}, consts.ErrInternalError + SendError(conn, consts.OpCodeError, consts.ErrInternalError) + return } signer, err := middleware.IsMessageValid(conn, priceInfoHash, priceReport.Signature) if err != nil { - return []byte{}, err + SendError(conn, consts.OpCodeError, err) + return } broadcastPacket := model.BroadcastEventPacket{ @@ -30,5 +35,6 @@ func EventLog(conn *websocket.Conn, payload []byte) ([]byte, error) { Signer: signer, } - return broadcastPacket.Sia().Bytes(), nil + pubsub.Publish(consts.ChannelEventLog, consts.OpCodeEventLogBroadcast, broadcastPacket.Sia().Bytes()) + SendMessage(conn, consts.OpCodeFeedback, "signature.accepted") } diff --git a/internal/transport/server/websocket/handler/hello.go b/internal/transport/server/websocket/handler/hello.go index b76d5d4e..5b967bdf 100644 --- a/internal/transport/server/websocket/handler/hello.go +++ b/internal/transport/server/websocket/handler/hello.go @@ -8,12 +8,15 @@ import ( "github.com/gorilla/websocket" ) -func Hello(conn *websocket.Conn, payload []byte) ([]byte, error) { +// Hello handler store the new client in the Signers map and send it a challenge packet. +func Hello(conn *websocket.Conn, payload []byte) { + utils.Logger.With("IP", conn.RemoteAddr().String()).Info("New Client Registered") signer := new(model.Signer).FromBytes(payload) if signer.Name == "" { utils.Logger.Error("Signer name is empty Or public key is invalid") - return []byte{}, consts.ErrInvalidConfig + SendError(conn, consts.OpCodeError, consts.ErrInvalidConfig) + return } store.Signers.Range(func(conn *websocket.Conn, signerInMap model.Signer) bool { @@ -21,6 +24,7 @@ func Hello(conn *websocket.Conn, payload []byte) ([]byte, error) { if publicKeyInUse { Close(conn) } + return !publicKeyInUse }) @@ -30,5 +34,6 @@ func Hello(conn *websocket.Conn, payload []byte) ([]byte, error) { challenge := model.ChallengePacket{Random: utils.NewChallenge()} store.Challenges.Store(conn, challenge) - return challenge.Sia().Bytes(), nil + SendMessage(conn, consts.OpCodeFeedback, "conf.ok") + Send(conn, consts.OpCodeKoskChallenge, challenge.Sia().Bytes()) } diff --git a/internal/transport/server/websocket/handler/helper.go b/internal/transport/server/websocket/handler/helper.go index e6db66a5..7d458dc4 100644 --- a/internal/transport/server/websocket/handler/helper.go +++ b/internal/transport/server/websocket/handler/helper.go @@ -8,9 +8,10 @@ import ( "github.com/gorilla/websocket" ) -func Send(conn *websocket.Conn, messageType int, opCode consts.OpCode, payload []byte) { +// Send sends a packet to the client. +func Send(conn *websocket.Conn, opCode consts.OpCode, payload []byte) { err := conn.WriteMessage( - messageType, + websocket.BinaryMessage, append( []byte{byte(opCode)}, payload...), @@ -20,10 +21,12 @@ func Send(conn *websocket.Conn, messageType int, opCode consts.OpCode, payload [ } } -func SendMessage(conn *websocket.Conn, messageType int, opCode consts.OpCode, message string) { - Send(conn, messageType, opCode, []byte(message)) +// SendMessage sends a string packet to the client. +func SendMessage(conn *websocket.Conn, opCode consts.OpCode, message string) { + Send(conn, opCode, []byte(message)) } +// BroadcastListener listens for messages on the channel and sends them to the client. func BroadcastListener(ctx context.Context, conn *websocket.Conn, ch chan []byte) { for { select { @@ -40,10 +43,12 @@ func BroadcastListener(ctx context.Context, conn *websocket.Conn, ch chan []byte } } -func SendError(conn *websocket.Conn, messageType int, opCode consts.OpCode, err error) { - SendMessage(conn, messageType, opCode, err.Error()) +// SendError sends an error message to the client. +func SendError(conn *websocket.Conn, opCode consts.OpCode, err error) { + SendMessage(conn, opCode, err.Error()) } +// Close gracefully closes the connection. func Close(conn *websocket.Conn) { err := conn.WriteMessage( websocket.CloseMessage, diff --git a/internal/transport/server/websocket/handler/kosk.go b/internal/transport/server/websocket/handler/kosk.go index 9555ca70..a195e43d 100644 --- a/internal/transport/server/websocket/handler/kosk.go +++ b/internal/transport/server/websocket/handler/kosk.go @@ -9,21 +9,23 @@ import ( "github.com/gorilla/websocket" ) -func Kosk(conn *websocket.Conn, payload []byte) error { +// Kosk handler check the result of signer challenge and store it. +func Kosk(conn *websocket.Conn, payload []byte) { challenge := new(model.ChallengePacket).FromBytes(payload) hash, err := bls.Hash(challenge.Random[:]) if err != nil { - return err + SendError(conn, consts.OpCodeError, err) + return } _, err = middleware.IsMessageValid(conn, hash, challenge.Signature) if err != nil { - return consts.ErrInvalidKosk + SendError(conn, consts.OpCodeError, consts.ErrInvalidKosk) + return } challenge.Passed = true store.Challenges.Store(conn, *challenge) - - return nil + SendMessage(conn, consts.OpCodeFeedback, "kosk.ok") } diff --git a/internal/transport/server/websocket/handler/price.go b/internal/transport/server/websocket/handler/price.go index f56187e9..34a693ce 100644 --- a/internal/transport/server/websocket/handler/price.go +++ b/internal/transport/server/websocket/handler/price.go @@ -3,26 +3,30 @@ package handler import ( "github.com/TimeleapLabs/unchained/internal/consts" "github.com/TimeleapLabs/unchained/internal/model" + "github.com/TimeleapLabs/unchained/internal/transport/server/pubsub" "github.com/TimeleapLabs/unchained/internal/transport/server/websocket/middleware" "github.com/gorilla/websocket" ) // PriceReport check signature of message and return price info. -func PriceReport(conn *websocket.Conn, payload []byte) ([]byte, error) { +func PriceReport(conn *websocket.Conn, payload []byte) { err := middleware.IsConnectionAuthenticated(conn) if err != nil { - return []byte{}, err + SendError(conn, consts.OpCodeError, err) + return } priceReport := new(model.PriceReportPacket).FromBytes(payload) priceInfoHash, err := priceReport.PriceInfo.Bls() if err != nil { - return []byte{}, consts.ErrInternalError + SendError(conn, consts.OpCodeError, err) + return } signer, err := middleware.IsMessageValid(conn, priceInfoHash, priceReport.Signature) if err != nil { - return []byte{}, err + SendError(conn, consts.OpCodeError, err) + return } priceInfo := model.BroadcastPricePacket{ @@ -31,5 +35,6 @@ func PriceReport(conn *websocket.Conn, payload []byte) ([]byte, error) { Signer: signer, } - return priceInfo.Sia().Bytes(), nil + pubsub.Publish(consts.ChannelPriceReport, consts.OpCodePriceReportBroadcast, priceInfo.Sia().Bytes()) + SendMessage(conn, consts.OpCodeFeedback, "signature.accepted") } diff --git a/internal/transport/server/websocket/handler/rpc.go b/internal/transport/server/websocket/handler/rpc.go new file mode 100644 index 00000000..db5a0ad7 --- /dev/null +++ b/internal/transport/server/websocket/handler/rpc.go @@ -0,0 +1,67 @@ +package handler + +import ( + "context" + + "github.com/TimeleapLabs/unchained/internal/consts" + "github.com/TimeleapLabs/unchained/internal/service/rpc" + "github.com/TimeleapLabs/unchained/internal/service/rpc/dto" + "github.com/TimeleapLabs/unchained/internal/utils" + "github.com/gorilla/websocket" +) + +// unchainedRPC is a global variable that holds the rpc coordinator. +var unchainedRPC = rpc.NewCoordinator() + +// RegisterRPCFunction is a handler of network that registers a new worker. +func RegisterRPCFunction(_ context.Context, conn *websocket.Conn, payload []byte) { + request := new(dto.RegisterFunction). + FromSiaBytes(payload) + + utils.Logger. + With("IP", conn.RemoteAddr().String()). + With("Function", request.Function). + Info("New Worker registered") + + unchainedRPC.RegisterWorker(request.Function, conn) +} + +// CallFunction is a handler of network that calls a registered function. +func CallFunction(_ context.Context, conn *websocket.Conn, payload []byte) { + request := new(dto.RPCRequest). + FromSiaBytes(payload) + + utils.Logger. + With("IP", conn.RemoteAddr().String()). + With("ID", request.ID). + With("Function", request.Method). + Info("RPC Request") + + unchainedRPC.RegisterTask(request.ID, conn) + worker := unchainedRPC.GetRandomWorker(request.Method) + + if worker != nil { + utils.Logger. + With("IP", conn.RemoteAddr().String()). + With("Function", request.Method). + Info("RPC Request Sent to Worker") + + Send(worker, consts.OpCodeRPCRequest, payload) + } +} + +// ResponseFunction is a handler of network that sends a response to requester. +func ResponseFunction(_ context.Context, conn *websocket.Conn, payload []byte) { + response := new(dto.RPCResponse). + FromSiaBytes(payload) + + task := unchainedRPC.GetTask(response.ID) + if task != nil { + utils.Logger. + With("IP", conn.RemoteAddr().String()). + With("ID", response.ID). + Info("RPC Response") + + Send(task, consts.OpCodeRPCResponse, payload) + } +} diff --git a/internal/transport/server/websocket/middleware/authentication.go b/internal/transport/server/websocket/middleware/authentication.go index 246cae16..1866f547 100644 --- a/internal/transport/server/websocket/middleware/authentication.go +++ b/internal/transport/server/websocket/middleware/authentication.go @@ -6,6 +6,7 @@ import ( "github.com/gorilla/websocket" ) +// IsConnectionAuthenticated checks if the connection has passed the challenge or not. func IsConnectionAuthenticated(conn *websocket.Conn) error { challenge, ok := store.Challenges.Load(conn) if !ok || !challenge.Passed { diff --git a/internal/transport/server/websocket/middleware/signature.go b/internal/transport/server/websocket/middleware/signature.go index 4cf3dbde..3ed1297d 100644 --- a/internal/transport/server/websocket/middleware/signature.go +++ b/internal/transport/server/websocket/middleware/signature.go @@ -11,6 +11,7 @@ import ( "github.com/gorilla/websocket" ) +// IsMessageValid checks if the message's signature belong to signer or not. func IsMessageValid(conn *websocket.Conn, message bls12381.G1Affine, signature [48]byte) (model.Signer, error) { signer, ok := store.Signers.Load(conn) if !ok { diff --git a/internal/transport/server/websocket/websocket.go b/internal/transport/server/websocket/websocket.go index f94cd296..a23ad228 100644 --- a/internal/transport/server/websocket/websocket.go +++ b/internal/transport/server/websocket/websocket.go @@ -16,6 +16,7 @@ import ( var upgrader = websocket.Upgrader{} +// WithWebsocket is a function that starts a websocket server. func WithWebsocket() func() { return func() { utils.Logger.Info("Starting a websocket server") @@ -25,7 +26,10 @@ func WithWebsocket() func() { } } +// multiplexer is a function that routes incoming messages to the appropriate handler. func multiplexer(w http.ResponseWriter, r *http.Request) { + upgrader.CheckOrigin = func(r *http.Request) bool { return true } // remove this line in production + conn, err := upgrader.Upgrade(w, r, nil) if err != nil { utils.Logger.Error("Can't upgrade the HTTP connection: %v", err) @@ -39,7 +43,7 @@ func multiplexer(w http.ResponseWriter, r *http.Request) { defer cancel() for { - messageType, payload, err := conn.ReadMessage() + _, payload, err := conn.ReadMessage() if err != nil { utils.Logger.Error("Can't read message: %v", err) @@ -57,50 +61,15 @@ func multiplexer(w http.ResponseWriter, r *http.Request) { switch consts.OpCode(payload[0]) { case consts.OpCodeHello: - utils.Logger.With("IP", conn.RemoteAddr().String()).Info("New Client Registered") - result, err := handler.Hello(conn, payload[1:]) - if err != nil { - handler.SendError(conn, messageType, consts.OpCodeError, err) - continue - } - - handler.SendMessage(conn, messageType, consts.OpCodeFeedback, "conf.ok") - handler.Send(conn, messageType, consts.OpCodeKoskChallenge, result) + handler.Hello(conn, payload[1:]) case consts.OpCodePriceReport: - result, err := handler.PriceReport(conn, payload[1:]) - if err != nil { - handler.SendError(conn, messageType, consts.OpCodeError, err) - continue - } - - pubsub.Publish(consts.ChannelPriceReport, consts.OpCodePriceReportBroadcast, result) - handler.SendMessage(conn, messageType, consts.OpCodeFeedback, "signature.accepted") + handler.PriceReport(conn, payload[1:]) case consts.OpCodeEventLog: - result, err := handler.EventLog(conn, payload[1:]) - if err != nil { - handler.SendError(conn, messageType, consts.OpCodeError, err) - continue - } - - pubsub.Publish(consts.ChannelEventLog, consts.OpCodeEventLogBroadcast, result) - handler.SendMessage(conn, messageType, consts.OpCodeFeedback, "signature.accepted") + handler.EventLog(conn, payload[1:]) case consts.OpCodeCorrectnessReport: - result, err := handler.CorrectnessRecord(conn, payload[1:]) - if err != nil { - handler.SendError(conn, messageType, consts.OpCodeError, err) - continue - } - - pubsub.Publish(consts.ChannelCorrectnessReport, consts.OpCodeCorrectnessReportBroadcast, result) - handler.SendMessage(conn, messageType, consts.OpCodeFeedback, "signature.accepted") + handler.CorrectnessRecord(conn, payload[1:]) case consts.OpCodeKoskResult: - err := handler.Kosk(conn, payload[1:]) - if err != nil { - handler.SendError(conn, messageType, consts.OpCodeError, err) - continue - } - - handler.SendMessage(conn, messageType, consts.OpCodeFeedback, "kosk.ok") + handler.Kosk(conn, payload[1:]) case consts.OpCodeRegisterConsumer: utils.Logger. With("IP", conn.RemoteAddr().String()). @@ -108,8 +77,14 @@ func multiplexer(w http.ResponseWriter, r *http.Request) { Info("New Consumer registered") go handler.BroadcastListener(ctx, conn, pubsub.Subscribe(string(payload[1:]))) + case consts.OpCodeRegisterRPCFunction: + handler.RegisterRPCFunction(ctx, conn, payload[1:]) + case consts.OpCodeRPCRequest: + handler.CallFunction(ctx, conn, payload[1:]) + case consts.OpCodeRPCResponse: + handler.ResponseFunction(ctx, conn, payload[1:]) default: - handler.SendError(conn, messageType, consts.OpCodeError, consts.ErrNotSupportedInstruction) + handler.SendError(conn, consts.OpCodeError, consts.ErrNotSupportedInstruction) } } }