-
-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #152 from TimeleapLabs/dirty-ai-pouya
Add AI Plugins
- Loading branch information
Showing
59 changed files
with
2,377 additions
and
106 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -204,6 +204,7 @@ issues: | |
- funlen | ||
- forbidigo | ||
- gochecknoinits | ||
- dupl | ||
- path: ".generated.go" | ||
linters: | ||
- typecheck | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
package handler | ||
|
||
import ( | ||
"os" | ||
|
||
"github.com/TimeleapLabs/unchained/cmd/handler/plugins" | ||
"github.com/gorilla/websocket" | ||
"github.com/spf13/cobra" | ||
) | ||
|
||
var conn *websocket.Conn | ||
|
||
func Read() <-chan []byte { | ||
out := make(chan []byte) | ||
|
||
go func() { | ||
for { | ||
_, payload, err := conn.ReadMessage() | ||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
out <- payload | ||
} | ||
}() | ||
|
||
return out | ||
} | ||
|
||
// plugin represents the plugin command. | ||
var plugin = &cobra.Command{ | ||
Use: "plugin", | ||
Short: "Run an Unchained plugin locally", | ||
Long: `Run an Unchained plugin locally`, | ||
|
||
Run: func(cmd *cobra.Command, _ []string) { | ||
os.Exit(1) | ||
}, | ||
} | ||
|
||
// WithPluginCmd appends the plugin command to the root command. | ||
func WithPluginCmd(cmd *cobra.Command) { | ||
cmd.AddCommand(plugin) | ||
} | ||
|
||
func init() { | ||
plugins.WithAIPluginCmd(plugin) | ||
plugins.WithTextToImagePluginCmd(plugin) | ||
plugins.WithTranslatePluginCmd(plugin) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
package plugins | ||
|
||
import ( | ||
"log" | ||
"os" | ||
"os/signal" | ||
"syscall" | ||
|
||
"github.com/TimeleapLabs/unchained/internal/service/ai" | ||
"github.com/spf13/cobra" | ||
) | ||
|
||
// worker represents the worker command. | ||
var aiPlugin = &cobra.Command{ | ||
Use: "ai", | ||
Short: "Start the Unchained ai server for local invocation", | ||
Long: `Start the Unchained ai server for local invocation`, | ||
|
||
Run: func(cmd *cobra.Command, _ []string) { | ||
wg, cancel := ai.StartServer(cmd.Context()) | ||
|
||
// Set up signal handling | ||
sigChan := make(chan os.Signal, 1) | ||
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) | ||
|
||
go func() { | ||
sig := <-sigChan | ||
log.Printf("Received signal: %v. Shutting down gracefully...", sig) | ||
cancel() // Cancel the context to stop all managed processes | ||
}() | ||
|
||
// Wait for all processes to finish | ||
wg.Wait() | ||
log.Println("All processes have been stopped.") | ||
}, | ||
} | ||
|
||
// WithRunCmd appends the run command to the root command. | ||
func WithAIPluginCmd(cmd *cobra.Command) { | ||
cmd.AddCommand(aiPlugin) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
package plugins | ||
|
||
import "github.com/gorilla/websocket" | ||
|
||
var conn *websocket.Conn | ||
var closed = false | ||
|
||
func Read() <-chan []byte { | ||
out := make(chan []byte) | ||
|
||
go func() { | ||
for { | ||
_, payload, err := conn.ReadMessage() | ||
if err != nil { | ||
if !closed { | ||
panic(err) | ||
} | ||
} | ||
|
||
out <- payload | ||
} | ||
}() | ||
|
||
return out | ||
} | ||
|
||
func CloseSocket() { | ||
if conn != nil { | ||
closed = true | ||
err := conn.WriteMessage( | ||
websocket.CloseMessage, | ||
websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) | ||
if err != nil { | ||
return | ||
} | ||
err = conn.Close() | ||
if err != nil { | ||
return | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
package plugins | ||
|
||
import ( | ||
"os" | ||
|
||
"github.com/TimeleapLabs/unchained/internal/service/ai" | ||
"github.com/spf13/cobra" | ||
) | ||
|
||
// worker represents the worker command. | ||
var textToImagePlugin = &cobra.Command{ | ||
Use: "text-to-image", | ||
Short: "Run the text-to-image plugin locally", | ||
Long: `Run the text-to-image plugin locally`, | ||
|
||
Run: func(cmd *cobra.Command, _ []string) { | ||
|
||
prompt := cmd.Flags().Lookup("prompt").Value.String() | ||
negativePrompt := cmd.Flags().Lookup("negative-prompt").Value.String() | ||
output := cmd.Flags().Lookup("output").Value.String() | ||
model := cmd.Flags().Lookup("model").Value.String() | ||
loraWeights := cmd.Flags().Lookup("lora-weights").Value.String() | ||
steps, err := cmd.Flags().GetUint8("inference") | ||
|
||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
outputBytes := ai.TextToImage(prompt, negativePrompt, model, loraWeights, steps) | ||
|
||
// write outputBytes as png to output file path of output flag | ||
err = os.WriteFile(output, outputBytes, 0644) //nolint: gosec // Other users may need to read these files. | ||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
CloseSocket() | ||
os.Exit(0) | ||
}, | ||
} | ||
|
||
// WithRunCmd appends the run command to the root command. | ||
func WithTextToImagePluginCmd(cmd *cobra.Command) { | ||
cmd.AddCommand(textToImagePlugin) | ||
} | ||
|
||
func init() { | ||
textToImagePlugin.Flags().StringP( | ||
"prompt", | ||
"p", | ||
"", | ||
"Prompt data to process", | ||
) | ||
textToImagePlugin.Flags().StringP( | ||
"negative-prompt", | ||
"n", | ||
"", | ||
"Negative prompt data to process", | ||
) | ||
textToImagePlugin.Flags().Uint8P( | ||
"inference", | ||
"i", | ||
16, | ||
"Number of inference steps", | ||
) | ||
textToImagePlugin.Flags().StringP( | ||
"model", | ||
"m", | ||
"OEvortex/PixelGen", | ||
"Model to use for inference", | ||
) | ||
textToImagePlugin.Flags().StringP( | ||
"lora-weights", | ||
"w", | ||
"", | ||
"Lora weights model name (if applicable)", | ||
) | ||
textToImagePlugin.Flags().StringP( | ||
"output", | ||
"o", | ||
"", | ||
"Output file path", | ||
) | ||
|
||
err := textToImagePlugin.MarkFlagRequired("prompt") | ||
if err != nil { | ||
panic(err) | ||
} | ||
err = textToImagePlugin.MarkFlagRequired("output") | ||
if err != nil { | ||
panic(err) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
package plugins | ||
|
||
import ( | ||
"fmt" | ||
"os" | ||
"time" | ||
|
||
"github.com/google/uuid" | ||
"github.com/gorilla/websocket" | ||
sia "github.com/pouya-eghbali/go-sia/v2/pkg" | ||
"github.com/spf13/cobra" | ||
) | ||
|
||
// worker represents the worker command. | ||
var translatePlugin = &cobra.Command{ | ||
Use: "translate", | ||
Short: "Run the translate plugin locally", | ||
Long: `Run the translate plugin locally`, | ||
|
||
Run: func(cmd *cobra.Command, _ []string) { | ||
|
||
input := cmd.Flags().Lookup("input").Value.String() | ||
from := cmd.Flags().Lookup("from").Value.String() | ||
to := cmd.Flags().Lookup("to").Value.String() | ||
|
||
var err error | ||
conn, _, err = websocket.DefaultDialer.Dial( | ||
"ws://localhost:8765", nil, | ||
) | ||
|
||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
incoming := Read() | ||
|
||
requestUUID, err := uuid.NewV7() | ||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
uuidBytes, err := requestUUID.MarshalBinary() | ||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
payload := sia.New(). | ||
AddUInt16(1). | ||
AddByteArrayN(uuidBytes). | ||
AddStringN(from). | ||
AddStringN(to). | ||
AddString16(input). | ||
Bytes() | ||
|
||
err = conn.WriteMessage(websocket.BinaryMessage, payload) | ||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
// Start a goroutine to print dots every 3 seconds | ||
stopDots := make(chan struct{}) | ||
go func() { | ||
for { | ||
select { | ||
case <-stopDots: | ||
return | ||
case <-time.After(1 * time.Second): | ||
fmt.Print(".") | ||
} | ||
} | ||
}() | ||
|
||
data := <-incoming | ||
|
||
// Stop the dot-printing goroutine | ||
close(stopDots) | ||
fmt.Println() | ||
|
||
// process data | ||
s := sia.NewFromBytes(data) | ||
uuidBytesFromResponse := s.ReadByteArrayN(16) | ||
responseUUID, err := uuid.FromBytes(uuidBytesFromResponse) | ||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
if requestUUID != responseUUID { | ||
panic("UUID mismatch") | ||
} | ||
|
||
translated := s.ReadString16() | ||
fmt.Println(translated) | ||
|
||
CloseSocket() | ||
os.Exit(0) | ||
}, | ||
} | ||
|
||
// WithRunCmd appends the run command to the root command. | ||
func WithTranslatePluginCmd(cmd *cobra.Command) { | ||
cmd.AddCommand(translatePlugin) | ||
} | ||
|
||
func init() { | ||
translatePlugin.Flags().StringP( | ||
"input", | ||
"i", | ||
"", | ||
"Input text to translate", | ||
) | ||
translatePlugin.Flags().StringP( | ||
"from", | ||
"f", | ||
"en", | ||
"From language", | ||
) | ||
translatePlugin.Flags().StringP( | ||
"to", | ||
"t", | ||
"fr", | ||
"To language", | ||
) | ||
} |
Oops, something went wrong.