Skip to content

Commit

Permalink
Merge pull request #152 from TimeleapLabs/dirty-ai-pouya
Browse files Browse the repository at this point in the history
Add AI Plugins
  • Loading branch information
pouya-eghbali authored Aug 24, 2024
2 parents 91d0177 + 8616d80 commit 06fd614
Show file tree
Hide file tree
Showing 59 changed files with 2,377 additions and 106 deletions.
1 change: 1 addition & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ issues:
- funlen
- forbidigo
- gochecknoinits
- dupl
- path: ".generated.go"
linters:
- typecheck
Expand Down
21 changes: 17 additions & 4 deletions cmd/handler/broker.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ var broker = &cobra.Command{
Use: "broker",
Short: "Run the Unchained client in broker mode",
Long: `Run the Unchained client in broker mode`,

PreRun: func(cmd *cobra.Command, _ []string) {
config.App.Network.CertFile = cmd.Flags().Lookup("cert-file").Value.String()
config.App.Network.KeyFile = cmd.Flags().Lookup("key-file").Value.String()
},

Run: func(_ *cobra.Command, _ []string) {
err := config.Load(config.App.System.ConfigPath, config.App.System.SecretsPath)
if err != nil {
Expand All @@ -30,9 +36,16 @@ func WithBrokerCmd(cmd *cobra.Command) {

func init() {
broker.Flags().StringP(
"broker",
"b",
"wss://shinobi.brokers.kenshi.io",
"Unchained broker to connect to",
"cert-file",
"C",
"",
"TLS certificate file",
)

broker.Flags().StringP(
"key-file",
"k",
"",
"TLS key file",
)
}
50 changes: 50 additions & 0 deletions cmd/handler/plugin.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package handler

import (
"os"

"github.com/TimeleapLabs/unchained/cmd/handler/plugins"
"github.com/gorilla/websocket"
"github.com/spf13/cobra"
)

var conn *websocket.Conn

func Read() <-chan []byte {
out := make(chan []byte)

go func() {
for {
_, payload, err := conn.ReadMessage()
if err != nil {
panic(err)
}

out <- payload
}
}()

return out
}

// plugin represents the plugin command.
var plugin = &cobra.Command{
Use: "plugin",
Short: "Run an Unchained plugin locally",
Long: `Run an Unchained plugin locally`,

Run: func(cmd *cobra.Command, _ []string) {
os.Exit(1)
},
}

// WithPluginCmd appends the plugin command to the root command.
func WithPluginCmd(cmd *cobra.Command) {
cmd.AddCommand(plugin)
}

func init() {
plugins.WithAIPluginCmd(plugin)
plugins.WithTextToImagePluginCmd(plugin)
plugins.WithTranslatePluginCmd(plugin)
}
41 changes: 41 additions & 0 deletions cmd/handler/plugins/ai.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package plugins

import (
"log"
"os"
"os/signal"
"syscall"

"github.com/TimeleapLabs/unchained/internal/service/ai"
"github.com/spf13/cobra"
)

// worker represents the worker command.
var aiPlugin = &cobra.Command{
Use: "ai",
Short: "Start the Unchained ai server for local invocation",
Long: `Start the Unchained ai server for local invocation`,

Run: func(cmd *cobra.Command, _ []string) {
wg, cancel := ai.StartServer(cmd.Context())

// Set up signal handling
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)

go func() {
sig := <-sigChan
log.Printf("Received signal: %v. Shutting down gracefully...", sig)
cancel() // Cancel the context to stop all managed processes
}()

// Wait for all processes to finish
wg.Wait()
log.Println("All processes have been stopped.")
},
}

// WithRunCmd appends the run command to the root command.
func WithAIPluginCmd(cmd *cobra.Command) {
cmd.AddCommand(aiPlugin)
}
41 changes: 41 additions & 0 deletions cmd/handler/plugins/common.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package plugins

import "github.com/gorilla/websocket"

var conn *websocket.Conn
var closed = false

func Read() <-chan []byte {
out := make(chan []byte)

go func() {
for {
_, payload, err := conn.ReadMessage()
if err != nil {
if !closed {
panic(err)
}
}

out <- payload
}
}()

return out
}

func CloseSocket() {
if conn != nil {
closed = true
err := conn.WriteMessage(
websocket.CloseMessage,
websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
if err != nil {
return
}
err = conn.Close()
if err != nil {
return
}
}
}
93 changes: 93 additions & 0 deletions cmd/handler/plugins/text_to_image.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package plugins

import (
"os"

"github.com/TimeleapLabs/unchained/internal/service/ai"
"github.com/spf13/cobra"
)

// worker represents the worker command.
var textToImagePlugin = &cobra.Command{
Use: "text-to-image",
Short: "Run the text-to-image plugin locally",
Long: `Run the text-to-image plugin locally`,

Run: func(cmd *cobra.Command, _ []string) {

prompt := cmd.Flags().Lookup("prompt").Value.String()
negativePrompt := cmd.Flags().Lookup("negative-prompt").Value.String()
output := cmd.Flags().Lookup("output").Value.String()
model := cmd.Flags().Lookup("model").Value.String()
loraWeights := cmd.Flags().Lookup("lora-weights").Value.String()
steps, err := cmd.Flags().GetUint8("inference")

if err != nil {
panic(err)
}

outputBytes := ai.TextToImage(prompt, negativePrompt, model, loraWeights, steps)

// write outputBytes as png to output file path of output flag
err = os.WriteFile(output, outputBytes, 0644) //nolint: gosec // Other users may need to read these files.
if err != nil {
panic(err)
}

CloseSocket()
os.Exit(0)
},
}

// WithRunCmd appends the run command to the root command.
func WithTextToImagePluginCmd(cmd *cobra.Command) {
cmd.AddCommand(textToImagePlugin)
}

func init() {
textToImagePlugin.Flags().StringP(
"prompt",
"p",
"",
"Prompt data to process",
)
textToImagePlugin.Flags().StringP(
"negative-prompt",
"n",
"",
"Negative prompt data to process",
)
textToImagePlugin.Flags().Uint8P(
"inference",
"i",
16,
"Number of inference steps",
)
textToImagePlugin.Flags().StringP(
"model",
"m",
"OEvortex/PixelGen",
"Model to use for inference",
)
textToImagePlugin.Flags().StringP(
"lora-weights",
"w",
"",
"Lora weights model name (if applicable)",
)
textToImagePlugin.Flags().StringP(
"output",
"o",
"",
"Output file path",
)

err := textToImagePlugin.MarkFlagRequired("prompt")
if err != nil {
panic(err)
}
err = textToImagePlugin.MarkFlagRequired("output")
if err != nil {
panic(err)
}
}
123 changes: 123 additions & 0 deletions cmd/handler/plugins/translate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package plugins

import (
"fmt"
"os"
"time"

"github.com/google/uuid"
"github.com/gorilla/websocket"
sia "github.com/pouya-eghbali/go-sia/v2/pkg"
"github.com/spf13/cobra"
)

// worker represents the worker command.
var translatePlugin = &cobra.Command{
Use: "translate",
Short: "Run the translate plugin locally",
Long: `Run the translate plugin locally`,

Run: func(cmd *cobra.Command, _ []string) {

input := cmd.Flags().Lookup("input").Value.String()
from := cmd.Flags().Lookup("from").Value.String()
to := cmd.Flags().Lookup("to").Value.String()

var err error
conn, _, err = websocket.DefaultDialer.Dial(
"ws://localhost:8765", nil,
)

if err != nil {
panic(err)
}

incoming := Read()

requestUUID, err := uuid.NewV7()
if err != nil {
panic(err)
}

uuidBytes, err := requestUUID.MarshalBinary()
if err != nil {
panic(err)
}

payload := sia.New().
AddUInt16(1).
AddByteArrayN(uuidBytes).
AddStringN(from).
AddStringN(to).
AddString16(input).
Bytes()

err = conn.WriteMessage(websocket.BinaryMessage, payload)
if err != nil {
panic(err)
}

// Start a goroutine to print dots every 3 seconds
stopDots := make(chan struct{})
go func() {
for {
select {
case <-stopDots:
return
case <-time.After(1 * time.Second):
fmt.Print(".")
}
}
}()

data := <-incoming

// Stop the dot-printing goroutine
close(stopDots)
fmt.Println()

// process data
s := sia.NewFromBytes(data)
uuidBytesFromResponse := s.ReadByteArrayN(16)
responseUUID, err := uuid.FromBytes(uuidBytesFromResponse)
if err != nil {
panic(err)
}

if requestUUID != responseUUID {
panic("UUID mismatch")
}

translated := s.ReadString16()
fmt.Println(translated)

CloseSocket()
os.Exit(0)
},
}

// WithRunCmd appends the run command to the root command.
func WithTranslatePluginCmd(cmd *cobra.Command) {
cmd.AddCommand(translatePlugin)
}

func init() {
translatePlugin.Flags().StringP(
"input",
"i",
"",
"Input text to translate",
)
translatePlugin.Flags().StringP(
"from",
"f",
"en",
"From language",
)
translatePlugin.Flags().StringP(
"to",
"t",
"fr",
"To language",
)
}
Loading

0 comments on commit 06fd614

Please sign in to comment.