-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathwhisper.go
239 lines (204 loc) · 5.6 KB
/
whisper.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
// Package whisper provides a client for interacting with the OpenAI Whisper ASR API.
package whisper
import (
"bytes"
"compress/flate"
"compress/gzip"
"encoding/json"
"errors"
"fmt"
"io"
"mime/multipart"
"net/http"
"os"
"path/filepath"
"strings"
)
const (
	// DefaultBase is the base URL that Client.URL falls back to when the client has no base URL configured.
	DefaultBase = "https://api.openai.com/v1"
	// DefaultModel is the model name Transcribe sends when no model option was given.
	DefaultModel = "whisper-1"
)
// Client is the main structure for interacting with the Whisper ASR API.
// Create one with NewClient, which fills in the fields the options left unset.
type Client struct {
	// apiKey is sent by Transcribe as the bearer token; Transcribe returns an error when it is empty.
	apiKey string
	// baseURL is the prefix that URL joins relative paths onto; when empty, URL uses DefaultBase.
	baseURL string
	// httpClient performs the HTTP requests; NewClient defaults it to http.DefaultClient.
	httpClient *http.Client
}
// ClientOption is a function that sets an option on a Client.
// NewClient calls the options in the order they are passed, before it applies
// its own defaults to whatever is still unset.
type ClientOption func(*Client)
// WithKey returns a ClientOption that stores key as the client's API key.
// The key is only written when the returned option is applied to a Client.
func WithKey(key string) ClientOption {
	apply := func(client *Client) {
		client.apiKey = key
	}
	return apply
}
// WithBaseURL returns a ClientOption that stores url as the client's base URL,
// the prefix that Client.URL joins relative paths onto.
func WithBaseURL(url string) ClientOption {
	return ClientOption(func(client *Client) {
		client.baseURL = url
	})
}
// WithHTTPClient returns a ClientOption that makes the Client send its
// requests through httpClient. NewClient substitutes http.DefaultClient when
// the stored value is nil.
func WithHTTPClient(httpClient *http.Client) ClientOption {
	apply := func(client *Client) {
		client.httpClient = httpClient
	}
	return apply
}
// NewClient builds a Whisper ASR API client.
//
// The options are applied in the order given. Afterwards every setting that
// is still unset gets its default: the API key is read from OPENAI_API_KEY,
// the base URL from OPENAI_BASE_URL, and the HTTP client becomes
// http.DefaultClient.
func NewClient(opts ...ClientOption) *Client {
	client := new(Client)
	for i := range opts {
		opts[i](client)
	}

	// The three fallbacks are independent of each other.
	if client.httpClient == nil {
		client.httpClient = http.DefaultClient
	}
	if client.baseURL == "" {
		client.baseURL = os.Getenv("OPENAI_BASE_URL")
	}
	if client.apiKey == "" {
		client.apiKey = os.Getenv("OPENAI_API_KEY")
	}
	return client
}
// URL resolves relPath against the client's base URL.
//
// A relPath that contains "://" is treated as already absolute and returned
// unchanged. Otherwise it is appended to the configured base URL (DefaultBase
// when none is set), with exactly one "/" between the two parts: trailing
// slashes of the base and leading slashes of relPath are removed first.
func (c *Client) URL(relPath string) string {
	if strings.Contains(relPath, "://") {
		return relPath
	}

	base := DefaultBase
	if c.baseURL != "" {
		base = c.baseURL
	}

	parts := []string{
		strings.TrimRight(base, "/"),
		strings.TrimLeft(relPath, "/"),
	}
	return strings.Join(parts, "/")
}
// Segment represents a segment of transcribed text in the TranscribeResponse.
// Each field is decoded from the JSON key named in its struct tag.
// NOTE(review): the field descriptions below are inferred from the field
// names only — confirm them against the API documentation.
type Segment struct {
	// ID identifies the segment (presumably its position in the response).
	ID int `json:"id"`
	// Seek is presumably the seek offset the segment was decoded at.
	Seek int `json:"seek"`
	// Start is presumably the start time of the segment, in seconds.
	Start float64 `json:"start"`
	// End is presumably the end time of the segment, in seconds.
	End float64 `json:"end"`
	// Text is the transcribed text of the segment.
	Text string `json:"text"`
	// Tokens presumably holds the model token IDs that make up Text.
	Tokens []int `json:"tokens"`
	// Temperature is presumably the sampling temperature used for the segment.
	Temperature float64 `json:"temperature"`
	// AvgLogprob is presumably the average log probability of the tokens.
	AvgLogprob float64 `json:"avg_logprob"`
	// CompressionRatio is presumably the compression ratio of the segment text.
	CompressionRatio float64 `json:"compression_ratio"`
	// NoSpeechProb is presumably the probability that the segment has no speech.
	NoSpeechProb float64 `json:"no_speech_prob"`
	// Transient is decoded from the "transient" key; its meaning is not visible here — TODO confirm.
	Transient bool `json:"transient"`
}
// TranscribeResponse represents the response from the Whisper ASR API.
// Transcribe requests the "verbose_json" response format and decodes the
// reply body into this structure.
type TranscribeResponse struct {
	// Task is decoded from the "task" key (presumably the kind of job performed — TODO confirm).
	Task string `json:"task"`
	// Language is the language reported by the API for the audio.
	Language string `json:"language"`
	// Duration is presumably the length of the audio in seconds — TODO confirm the unit.
	Duration float64 `json:"duration"`
	// Segments lists the individual segments of the transcription.
	Segments []Segment `json:"segments"`
	// Text is the transcribed text as a whole.
	Text string `json:"text"`
}
// TranscribeFile opens the file at the given path and passes its contents to
// Transcribe.
//
// The base name of the path is supplied as the upload file name through
// WithFile. It is placed ahead of opts, so a WithFile among the caller's
// options is applied later and overrides it. The error from os.Open is
// returned as is when the file cannot be opened.
func (c *Client) TranscribeFile(file string, opts ...TranscribeOption) (*TranscribeResponse, error) {
	audio, err := os.Open(file)
	if err != nil {
		return nil, err
	}
	defer audio.Close()

	all := make([]TranscribeOption, 0, len(opts)+1)
	all = append(all, WithFile(filepath.Base(file)))
	all = append(all, opts...)
	return c.Transcribe(audio, all...)
}
// TranscribeConfig is a structure that holds the configuration for the Transcribe method.
// Transcribe starts from the zero value and lets each TranscribeOption modify it.
type TranscribeConfig struct {
	// Model is the model name to request; Transcribe substitutes DefaultModel when it is empty.
	Model string
	// Language is the language selected with WithLanguage; empty means none was chosen.
	Language string
	// File is the file name given to the uploaded audio part; Transcribe returns an error when it is empty.
	File string
}
// TranscribeOption is a function that sets an option on a TranscribeConfig.
// Transcribe applies the options in the order they are passed, so a later
// option overrides an earlier one that sets the same field.
type TranscribeOption func(*TranscribeConfig)
// WithModel returns a TranscribeOption that selects the model to request.
// Transcribe falls back to DefaultModel when the resulting model is empty.
func WithModel(model string) TranscribeOption {
	apply := func(cfg *TranscribeConfig) {
		cfg.Model = model
	}
	return apply
}
// WithLanguage returns a TranscribeOption that records lang in the Language
// field of the TranscribeConfig.
func WithLanguage(lang string) TranscribeOption {
	return TranscribeOption(func(cfg *TranscribeConfig) {
		cfg.Language = lang
	})
}
// WithFile returns a TranscribeOption that sets the file name under which the
// audio is uploaded. Transcribe refuses to run without one.
func WithFile(file string) TranscribeOption {
	apply := func(cfg *TranscribeConfig) {
		cfg.File = file
	}
	return apply
}
// Transcribe transcribes the given audio stream using the Whisper ASR API.
//
// The audio read from h is buffered in memory as a multipart/form-data body
// together with the model, the "verbose_json" response format and, when one
// was selected with WithLanguage, the language. The body is POSTed to the
// audio/transcriptions endpoint with the API key as bearer token, and the
// JSON reply is decoded into a TranscribeResponse.
//
// An error is returned when the API key or the file name (see WithFile) is
// missing, when building or sending the request fails, when the server does
// not answer 200 OK (the error then carries the status line and the start of
// the response body), or when the reply cannot be decoded.
func (c *Client) Transcribe(h io.Reader, opts ...TranscribeOption) (*TranscribeResponse, error) {
	// Upper bound on how much of an error reply is copied into the error text.
	const maxErrorBody = 4 << 10

	if c.apiKey == "" {
		return nil, errors.New("missing API key (set OPENAI_API_KEY in env)")
	}

	tc := &TranscribeConfig{}
	for _, opt := range opts {
		opt(tc)
	}
	if tc.Model == "" {
		tc.Model = DefaultModel
	}
	if tc.File == "" {
		return nil, errors.New("filename is not set")
	}

	// Assemble the multipart body: plain fields first, then the audio part.
	b := &bytes.Buffer{}
	mp := multipart.NewWriter(b)
	if err := mp.WriteField("model", tc.Model); err != nil {
		return nil, err
	}
	if err := mp.WriteField("response_format", "verbose_json"); err != nil {
		return nil, err
	}
	// Only send a language when the caller chose one, so the request is
	// otherwise unchanged.
	if tc.Language != "" {
		if err := mp.WriteField("language", tc.Language); err != nil {
			return nil, err
		}
	}
	fp, err := mp.CreateFormFile("file", tc.File)
	if err != nil {
		return nil, err
	}
	if _, err = io.Copy(fp, h); err != nil {
		return nil, err
	}
	// Close writes the terminating boundary; without it the body is invalid.
	if err = mp.Close(); err != nil {
		return nil, err
	}

	req, err := http.NewRequest(http.MethodPost, c.URL("audio/transcriptions"), b)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", mp.FormDataContentType())
	// Setting Accept-Encoding explicitly turns off net/http's transparent
	// gzip handling, so the reply is decompressed by hand below.
	req.Header.Set("Accept-Encoding", "gzip, deflate")
	req.Header.Set("Accept", "*/*")
	req.Header.Set("Authorization", "Bearer "+c.apiKey)

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	var r io.Reader = resp.Body
	switch strings.ToLower(resp.Header.Get("Content-Encoding")) {
	case "gzip":
		zr, err := gzip.NewReader(resp.Body)
		if err != nil {
			return nil, err
		}
		defer zr.Close()
		r = zr
	case "deflate":
		// NOTE(review): this reads a raw DEFLATE stream; servers that send
		// the zlib-wrapped form of "deflate" would need compress/zlib —
		// confirm which form the API actually uses.
		fr := flate.NewReader(resp.Body)
		defer fr.Close()
		r = fr
	}

	if resp.StatusCode != http.StatusOK {
		// Best effort: a failure while reading the error body must not hide
		// the status, so the read error is deliberately ignored.
		raw, _ := io.ReadAll(io.LimitReader(r, maxErrorBody))
		if detail := strings.TrimSpace(string(raw)); detail != "" {
			return nil, fmt.Errorf("unexpected response: %s: %s", resp.Status, detail)
		}
		return nil, fmt.Errorf("unexpected response: %s", resp.Status)
	}

	var tr TranscribeResponse
	if err = json.NewDecoder(r).Decode(&tr); err != nil {
		return nil, err
	}
	return &tr, nil
}