Skip to content

Commit c51b7a1

Browse files
committed
Added some image support
1 parent b41ff7d commit c51b7a1

22 files changed

+614
-83
lines changed

attachment.go

Lines changed: 82 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,13 @@ import (
88
"net/http"
99
"os"
1010
"path/filepath"
11+
"strings"
1112
)
1213

1314
///////////////////////////////////////////////////////////////////////////////
1415
// TYPES
1516

17+
// General attachment metadata
1618
type AttachmentMeta struct {
1719
Id string `json:"id,omitempty"`
1820
Filename string `json:"filename,omitempty"`
@@ -21,9 +23,17 @@ type AttachmentMeta struct {
2123
Data []byte `json:"data"`
2224
}
2325

26+
// OpenAI image metadata
27+
type ImageMeta struct {
28+
Url string `json:"url,omitempty"`
29+
Data []byte `json:"b64_json,omitempty"`
30+
Prompt string `json:"revised_prompt,omitempty"`
31+
}
32+
2433
// Attachment for messages
2534
type Attachment struct {
26-
meta AttachmentMeta
35+
meta *AttachmentMeta
36+
image *ImageMeta
2737
}
2838

2939
const (
@@ -38,6 +48,11 @@ func NewAttachment() *Attachment {
3848
return new(Attachment)
3949
}
4050

51+
// NewAttachment with OpenAI image
52+
func NewAttachmentWithImage(image *ImageMeta) *Attachment {
53+
return &Attachment{image: image}
54+
}
55+
4156
// ReadAttachment returns an attachment from a reader object.
4257
// It is the responsibility of the caller to close the reader.
4358
func ReadAttachment(r io.Reader) (*Attachment, error) {
@@ -50,7 +65,7 @@ func ReadAttachment(r io.Reader) (*Attachment, error) {
5065
filename = f.Name()
5166
}
5267
return &Attachment{
53-
meta: AttachmentMeta{
68+
meta: &AttachmentMeta{
5469
Filename: filename,
5570
Data: data,
5671
},
@@ -73,19 +88,25 @@ func (a *Attachment) MarshalJSON() ([]byte, error) {
7388
Filename string `json:"filename,omitempty"`
7489
Type string `json:"type"`
7590
Bytes uint64 `json:"bytes"`
76-
Caption string `json:"transcript,omitempty"`
91+
Caption string `json:"caption,omitempty"`
7792
}
78-
j.Id = a.meta.Id
79-
j.Filename = a.meta.Filename
93+
8094
j.Type = a.Type()
81-
j.Bytes = uint64(len(a.meta.Data))
82-
j.Caption = a.meta.Caption
95+
j.Caption = a.Caption()
96+
if a.meta != nil {
97+
j.Id = a.meta.Id
98+
j.Filename = a.meta.Filename
99+
j.Bytes = uint64(len(a.meta.Data))
100+
} else if a.image != nil {
101+
j.Bytes = uint64(len(a.image.Data))
102+
}
103+
83104
return json.Marshal(j)
84105
}
85106

86107
// Stringify an attachment
87108
func (a *Attachment) String() string {
88-
data, err := json.MarshalIndent(a.meta, "", " ")
109+
data, err := json.MarshalIndent(a, "", " ")
89110
if err != nil {
90111
return err.Error()
91112
}
@@ -95,68 +116,98 @@ func (a *Attachment) String() string {
95116
////////////////////////////////////////////////////////////////////////////////
96117
// PUBLIC METHODS
97118

119+
// Write out attachment
120+
func (a *Attachment) Write(w io.Writer) (int, error) {
121+
if a.meta != nil {
122+
return w.Write(a.meta.Data)
123+
}
124+
if a.image != nil {
125+
return w.Write(a.image.Data)
126+
}
127+
return 0, io.EOF
128+
}
129+
98130
// Return the filename of an attachment
99131
func (a *Attachment) Filename() string {
100-
return a.meta.Filename
132+
if a.meta != nil {
133+
return a.meta.Filename
134+
} else {
135+
return ""
136+
}
101137
}
102138

103139
// Return the raw attachment data
104140
func (a *Attachment) Data() []byte {
105-
return a.meta.Data
141+
if a.meta != nil {
142+
return a.meta.Data
143+
}
144+
if a.image != nil {
145+
return a.image.Data
146+
}
147+
return nil
106148
}
107149

108150
// Return the caption for the attachment
109151
func (a *Attachment) Caption() string {
110-
return a.meta.Caption
152+
if a.meta != nil {
153+
return strings.TrimSpace(a.meta.Caption)
154+
}
155+
if a.image != nil {
156+
return strings.TrimSpace(a.image.Prompt)
157+
}
158+
return ""
111159
}
112160

113161
// Return the mime media type for the attachment, based
114162
// on the data and/or filename extension. Returns an empty string if
115163
// there is no data or filename
116164
func (a *Attachment) Type() string {
117165
// If there's no data or filename, return empty
118-
if len(a.meta.Data) == 0 && a.meta.Filename == "" {
166+
if len(a.Data()) == 0 && a.Filename() == "" {
119167
return ""
120168
}
121169

122170
// Mimetype based on content
123171
mimetype := defaultMimetype
124-
if len(a.meta.Data) > 0 {
125-
mimetype = http.DetectContentType(a.meta.Data)
172+
if len(a.Data()) > 0 {
173+
mimetype = http.DetectContentType(a.Data())
126174
if mimetype != defaultMimetype {
127175
return mimetype
128176
}
129177
}
130178

131179
// Mimetype based on filename
132-
if a.meta.Filename != "" {
180+
if a.Filename() != "" {
133181
// Detect mimetype from extension
134-
mimetype = mime.TypeByExtension(filepath.Ext(a.meta.Filename))
182+
mimetype = mime.TypeByExtension(filepath.Ext(a.Filename()))
135183
}
136184

137185
// Return the default mimetype
138186
return mimetype
139187
}
140188

141189
func (a *Attachment) Url() string {
142-
return "data:" + a.Type() + ";base64," + base64.StdEncoding.EncodeToString(a.meta.Data)
190+
return "data:" + a.Type() + ";base64," + base64.StdEncoding.EncodeToString(a.Data())
143191
}
144192

145193
// Streaming includes the ability to append data
146194
func (a *Attachment) Append(other *Attachment) {
147-
if other.meta.Id != "" {
148-
a.meta.Id = other.meta.Id
149-
}
150-
if other.meta.Filename != "" {
151-
a.meta.Filename = other.meta.Filename
152-
}
153-
if other.meta.ExpiresAt != 0 {
154-
a.meta.ExpiresAt = other.meta.ExpiresAt
155-
}
156-
if other.meta.Caption != "" {
157-
a.meta.Caption += other.meta.Caption
158-
}
159-
if len(other.meta.Data) > 0 {
160-
a.meta.Data = append(a.meta.Data, other.meta.Data...)
195+
if a.meta != nil {
196+
if other.meta.Id != "" {
197+
a.meta.Id = other.meta.Id
198+
}
199+
if other.meta.Filename != "" {
200+
a.meta.Filename = other.meta.Filename
201+
}
202+
if other.meta.ExpiresAt != 0 {
203+
a.meta.ExpiresAt = other.meta.ExpiresAt
204+
}
205+
if other.meta.Caption != "" {
206+
a.meta.Caption += other.meta.Caption
207+
}
208+
if len(other.meta.Data) > 0 {
209+
a.meta.Data = append(a.meta.Data, other.meta.Data...)
210+
}
161211
}
212+
// TODO: Append for image
162213
}

cmd/llm/chat.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ type ChatCmd struct {
2626
// PUBLIC METHODS
2727

2828
func (cmd *ChatCmd) Run(globals *Globals) error {
29-
return run(globals, cmd.Model, func(ctx context.Context, model llm.Model) error {
29+
return run(globals, AudioType, cmd.Model, func(ctx context.Context, model llm.Model) error {
3030
// Current buffer
3131
var buf string
3232

cmd/llm/chat2.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ func NewTelegramServer(token string, model llm.Model, toolkit llm.ToolKit, opts
5757
// PUBLIC METHODS
5858

5959
func (cmd *Chat2Cmd) Run(globals *Globals) error {
60-
return run(globals, cmd.Model, func(ctx context.Context, model llm.Model) error {
60+
return run(globals, ChatType, cmd.Model, func(ctx context.Context, model llm.Model) error {
6161
server, err := NewTelegramServer(cmd.Token, model, globals.toolkit, telegram.WithDebug(globals.Debug))
6262
if err != nil {
6363
return err

cmd/llm/complete.go

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,26 +9,41 @@ import (
99

1010
// Packages
1111
llm "github.com/mutablelogic/go-llm"
12+
"github.com/mutablelogic/go-llm/pkg/openai"
1213
)
1314

1415
////////////////////////////////////////////////////////////////////////////////
1516
// TYPES
1617

1718
type CompleteCmd struct {
18-
Model string `arg:"" help:"Model name"`
1919
Prompt string `arg:"" optional:"" help:"Prompt"`
20+
Model string `flag:"model" help:"Model name"`
2021
File []string `type:"file" short:"f" help:"Files to attach"`
2122
System string `flag:"system" help:"Set the system prompt"`
2223
NoStream bool `flag:"no-stream" help:"Do not stream output"`
23-
Format string `flag:"format" enum:"text,markdown,json" default:"text" help:"Output format"`
24+
Format string `flag:"format" enum:"text,markdown,json,image,audio" default:"text" help:"Output format"`
25+
Size string `flag:"size" enum:"256x256,512x512,1024x1024,1792x1024,1024x1792" default:"1024x1024" help:"Image size"`
26+
Style string `flag:"style" enum:"vivid,natural" default:"vivid" help:"Image style"`
27+
Quality string `flag:"quality" enum:"standard,hd" default:"standard" help:"Image quality"`
2428
Temperature *float64 `flag:"temperature" short:"t" help:"Temperature for sampling"`
2529
}
2630

2731
////////////////////////////////////////////////////////////////////////////////
2832
// PUBLIC METHODS
2933

34+
func typeFromFormat(format string) Type {
35+
switch format {
36+
case "image":
37+
return ImageType
38+
case "audio":
39+
return AudioType
40+
default:
41+
return TextType
42+
}
43+
}
44+
3045
func (cmd *CompleteCmd) Run(globals *Globals) error {
31-
return run(globals, cmd.Model, func(ctx context.Context, model llm.Model) error {
46+
return run(globals, typeFromFormat(cmd.Format), cmd.Model, func(ctx context.Context, model llm.Model) error {
3247
var prompt []byte
3348

3449
// If we are pipeline content in via stdin
@@ -76,13 +91,20 @@ func (cmd *CompleteCmd) Run(globals *Globals) error {
7691
return err
7792
}
7893

79-
// Print the completion
94+
// Print the completion - text
8095
if cmd.NoStream {
8196
fmt.Println(completion.Text(0))
8297
} else {
8398
fmt.Println("")
8499
}
85100

101+
// Print the completion - attachments
102+
for i := 0; i < completion.Num(); i++ {
103+
if attachment := completion.Attachment(i); attachment != nil {
104+
fmt.Println(attachment)
105+
}
106+
}
107+
86108
// Return success
87109
return nil
88110
})
@@ -106,9 +128,12 @@ func (cmd *CompleteCmd) opts() []llm.Opt {
106128
}
107129

108130
// Set format
109-
if cmd.Format == "json" {
110-
opts = append(opts, llm.WithFormat("json"))
111-
}
131+
opts = append(opts, llm.WithFormat(cmd.Format))
132+
133+
// Set image parameters
134+
opts = append(opts, openai.WithSize(cmd.Size))
135+
opts = append(opts, openai.WithStyle(cmd.Style))
136+
opts = append(opts, openai.WithQuality(cmd.Quality))
112137

113138
// Set temperature
114139
if cmd.Temperature != nil {

0 commit comments

Comments
 (0)