Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AI Plugins #152

Merged
merged 14 commits into from
Aug 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ issues:
- funlen
- forbidigo
- gochecknoinits
- dupl
- path: ".generated.go"
linters:
- typecheck
Expand Down
21 changes: 17 additions & 4 deletions cmd/handler/broker.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ var broker = &cobra.Command{
Use: "broker",
Short: "Run the Unchained client in broker mode",
Long: `Run the Unchained client in broker mode`,

PreRun: func(cmd *cobra.Command, _ []string) {
config.App.Network.CertFile = cmd.Flags().Lookup("cert-file").Value.String()
config.App.Network.KeyFile = cmd.Flags().Lookup("key-file").Value.String()
},

Run: func(_ *cobra.Command, _ []string) {
err := config.Load(config.App.System.ConfigPath, config.App.System.SecretsPath)
if err != nil {
Expand All @@ -30,9 +36,16 @@ func WithBrokerCmd(cmd *cobra.Command) {

func init() {
broker.Flags().StringP(
"broker",
"b",
"wss://shinobi.brokers.kenshi.io",
"Unchained broker to connect to",
"cert-file",
"C",
"",
"TLS certificate file",
)

broker.Flags().StringP(
"key-file",
"k",
"",
"TLS key file",
)
}
50 changes: 50 additions & 0 deletions cmd/handler/plugin.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package handler

import (
"os"

"github.com/TimeleapLabs/unchained/cmd/handler/plugins"
"github.com/gorilla/websocket"
"github.com/spf13/cobra"
)

var conn *websocket.Conn

func Read() <-chan []byte {
out := make(chan []byte)

go func() {
for {
_, payload, err := conn.ReadMessage()
if err != nil {
panic(err)
}

out <- payload
}
}()

return out
}

// plugin represents the plugin command.
var plugin = &cobra.Command{
Use: "plugin",
Short: "Run an Unchained plugin locally",
Long: `Run an Unchained plugin locally`,

Run: func(cmd *cobra.Command, _ []string) {
os.Exit(1)
},
}

// WithPluginCmd appends the plugin command to the root command.
func WithPluginCmd(cmd *cobra.Command) {
cmd.AddCommand(plugin)
}

func init() {
plugins.WithAIPluginCmd(plugin)
plugins.WithTextToImagePluginCmd(plugin)
plugins.WithTranslatePluginCmd(plugin)
}
41 changes: 41 additions & 0 deletions cmd/handler/plugins/ai.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package plugins

import (
"log"
"os"
"os/signal"
"syscall"

"github.com/TimeleapLabs/unchained/internal/service/ai"
"github.com/spf13/cobra"
)

// worker represents the worker command.
var aiPlugin = &cobra.Command{
Use: "ai",
Short: "Start the Unchained ai server for local invocation",
Long: `Start the Unchained ai server for local invocation`,

Run: func(cmd *cobra.Command, _ []string) {
wg, cancel := ai.StartServer(cmd.Context())

// Set up signal handling
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)

go func() {
sig := <-sigChan
log.Printf("Received signal: %v. Shutting down gracefully...", sig)
cancel() // Cancel the context to stop all managed processes
}()

// Wait for all processes to finish
wg.Wait()
log.Println("All processes have been stopped.")
},
}

// WithRunCmd appends the run command to the root command.
func WithAIPluginCmd(cmd *cobra.Command) {
cmd.AddCommand(aiPlugin)
}
41 changes: 41 additions & 0 deletions cmd/handler/plugins/common.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package plugins

import "github.com/gorilla/websocket"

var conn *websocket.Conn
var closed = false

func Read() <-chan []byte {
out := make(chan []byte)

go func() {
for {
_, payload, err := conn.ReadMessage()
if err != nil {
if !closed {
panic(err)
}
}

out <- payload
}
}()

return out
}

func CloseSocket() {
if conn != nil {
closed = true
err := conn.WriteMessage(
websocket.CloseMessage,
websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
if err != nil {
return
}
err = conn.Close()
if err != nil {
return
}
}
}
93 changes: 93 additions & 0 deletions cmd/handler/plugins/text_to_image.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package plugins

import (
"os"

"github.com/TimeleapLabs/unchained/internal/service/ai"
"github.com/spf13/cobra"
)

// worker represents the worker command.
var textToImagePlugin = &cobra.Command{
Use: "text-to-image",
Short: "Run the text-to-image plugin locally",
Long: `Run the text-to-image plugin locally`,

Run: func(cmd *cobra.Command, _ []string) {

prompt := cmd.Flags().Lookup("prompt").Value.String()
negativePrompt := cmd.Flags().Lookup("negative-prompt").Value.String()
output := cmd.Flags().Lookup("output").Value.String()
model := cmd.Flags().Lookup("model").Value.String()
loraWeights := cmd.Flags().Lookup("lora-weights").Value.String()
steps, err := cmd.Flags().GetUint8("inference")

if err != nil {
panic(err)
}

outputBytes := ai.TextToImage(prompt, negativePrompt, model, loraWeights, steps)

// write outputBytes as png to output file path of output flag
err = os.WriteFile(output, outputBytes, 0644) //nolint: gosec // Other users may need to read these files.
if err != nil {
panic(err)
}

CloseSocket()
os.Exit(0)
},
}

// WithRunCmd appends the run command to the root command.
func WithTextToImagePluginCmd(cmd *cobra.Command) {
cmd.AddCommand(textToImagePlugin)
}

func init() {
textToImagePlugin.Flags().StringP(
"prompt",
"p",
"",
"Prompt data to process",
)
textToImagePlugin.Flags().StringP(
"negative-prompt",
"n",
"",
"Negative prompt data to process",
)
textToImagePlugin.Flags().Uint8P(
"inference",
"i",
16,
"Number of inference steps",
)
textToImagePlugin.Flags().StringP(
"model",
"m",
"OEvortex/PixelGen",
"Model to use for inference",
)
textToImagePlugin.Flags().StringP(
"lora-weights",
"w",
"",
"Lora weights model name (if applicable)",
)
textToImagePlugin.Flags().StringP(
"output",
"o",
"",
"Output file path",
)

err := textToImagePlugin.MarkFlagRequired("prompt")
if err != nil {
panic(err)
}
err = textToImagePlugin.MarkFlagRequired("output")
if err != nil {
panic(err)
}
}
123 changes: 123 additions & 0 deletions cmd/handler/plugins/translate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package plugins

import (
"fmt"
"os"
"time"

"github.com/google/uuid"
"github.com/gorilla/websocket"
sia "github.com/pouya-eghbali/go-sia/v2/pkg"
"github.com/spf13/cobra"
)

// worker represents the worker command.
var translatePlugin = &cobra.Command{
Use: "translate",
Short: "Run the translate plugin locally",
Long: `Run the translate plugin locally`,

Run: func(cmd *cobra.Command, _ []string) {

input := cmd.Flags().Lookup("input").Value.String()
from := cmd.Flags().Lookup("from").Value.String()
to := cmd.Flags().Lookup("to").Value.String()

var err error
conn, _, err = websocket.DefaultDialer.Dial(
"ws://localhost:8765", nil,
)

if err != nil {
panic(err)
}

incoming := Read()

requestUUID, err := uuid.NewV7()
if err != nil {
panic(err)
}

uuidBytes, err := requestUUID.MarshalBinary()
if err != nil {
panic(err)
}

payload := sia.New().
AddUInt16(1).
AddByteArrayN(uuidBytes).
AddStringN(from).
AddStringN(to).
AddString16(input).
Bytes()

err = conn.WriteMessage(websocket.BinaryMessage, payload)
if err != nil {
panic(err)
}

// Start a goroutine to print dots every 3 seconds
stopDots := make(chan struct{})
go func() {
for {
select {
case <-stopDots:
return
case <-time.After(1 * time.Second):
fmt.Print(".")
}
}
}()

data := <-incoming

// Stop the dot-printing goroutine
close(stopDots)
fmt.Println()

// process data
s := sia.NewFromBytes(data)
uuidBytesFromResponse := s.ReadByteArrayN(16)
responseUUID, err := uuid.FromBytes(uuidBytesFromResponse)
if err != nil {
panic(err)
}

if requestUUID != responseUUID {
panic("UUID mismatch")
}

translated := s.ReadString16()
fmt.Println(translated)

CloseSocket()
os.Exit(0)
},
}

// WithRunCmd appends the run command to the root command.
func WithTranslatePluginCmd(cmd *cobra.Command) {
cmd.AddCommand(translatePlugin)
}

func init() {
translatePlugin.Flags().StringP(
"input",
"i",
"",
"Input text to translate",
)
translatePlugin.Flags().StringP(
"from",
"f",
"en",
"From language",
)
translatePlugin.Flags().StringP(
"to",
"t",
"fr",
"To language",
)
}
Loading
Loading