-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathclient.go
130 lines (107 loc) · 2.99 KB
/
client.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
package openrouter
import (
"context"
"encoding/json"
"fmt"
utils "github.com/casibase/go-openrouter/internal"
"io"
"net/http"
)
// Client is the package's API client. Build one with NewClient or
// NewClientWithConfig.
type Client struct {
	// config supplies everything the client reads when sending requests:
	// HTTPClient, BaseURL, HttpReferer, XTitle and authToken.
	config         ClientConfig
	// requestBuilder turns (ctx, method, url, body) into an *http.Request;
	// it is used by newStreamRequest.
	requestBuilder utils.RequestBuilder
}
// NewClient creates a Client from DefaultConfig(auth, xTitle, httpReferer).
// If DefaultConfig fails, its error is returned unchanged together with a
// nil Client.
func NewClient(auth, xTitle, httpReferer string) (*Client, error) {
	cfg, cfgErr := DefaultConfig(auth, xTitle, httpReferer)
	if cfgErr != nil {
		return nil, cfgErr
	}
	client := NewClientWithConfig(cfg)
	return client, nil
}
// NewClientWithConfig returns a Client that stores a copy of config and a
// fresh request builder obtained from utils.NewRequestBuilder.
func NewClientWithConfig(config ClientConfig) *Client {
	client := new(Client)
	client.config = config
	client.requestBuilder = utils.NewRequestBuilder()
	return client
}
func (c *Client) sendRequest(req *http.Request, v any) error {
req.Header.Set("Accept", "application/json; charset=utf-8")
// Check whether Content-Type is already set, Upload Files API requires
// Content-Type == multipart/form-data
contentType := req.Header.Get("Content-Type")
if contentType == "" {
req.Header.Set("Content-Type", "application/json; charset=utf-8")
}
c.setCommonHeaders(req)
res, err := c.config.HTTPClient.Do(req)
if err != nil {
return err
}
defer res.Body.Close()
if isFailureStatusCode(res) {
return c.handleErrorResp(res)
}
return decodeResponse(res.Body, v)
}
// setCommonHeaders writes the headers shared by every request: HTTP-Referer
// and X-Title from the client config, and an Authorization header of the
// form "Bearer <token>" built from the config's auth token. Any existing
// values for these keys are replaced.
func (c *Client) setCommonHeaders(req *http.Request) {
	cfg := &c.config
	hdr := req.Header
	hdr.Set("HTTP-Referer", cfg.HttpReferer)
	hdr.Set("X-Title", cfg.XTitle)
	hdr.Set("Authorization", fmt.Sprintf("Bearer %s", cfg.authToken))
}
// isFailureStatusCode reports whether resp's status lies outside [200, 400).
// Informational (1xx), client-error (4xx) and server-error (5xx) codes count
// as failures; 2xx and 3xx do not.
func isFailureStatusCode(resp *http.Response) bool {
	code := resp.StatusCode
	succeeded := code >= http.StatusOK && code < http.StatusBadRequest
	return !succeeded
}
func decodeResponse(body io.Reader, v any) error {
if v == nil {
return nil
}
if result, ok := v.(*string); ok {
return decodeString(body, result)
}
return json.NewDecoder(body).Decode(v)
}
// decodeString reads body to the end and stores the bytes, converted to a
// string, in *output. If reading fails, the error is returned and *output is
// left unchanged.
func decodeString(body io.Reader, output *string) error {
	raw, readErr := io.ReadAll(body)
	if readErr == nil {
		*output = string(raw)
	}
	return readErr
}
// fullURL returns the absolute request URL for an API path by appending
// suffix to the configured BaseURL.
//
// When BaseURL ends with "/" and suffix begins with "/", one trailing slash
// is trimmed from BaseURL first. For example, "https://host/v1/" plus "/x"
// yields "https://host/v1/x" rather than "https://host/v1//x". In all other
// cases the two strings are concatenated unchanged.
func (c *Client) fullURL(suffix string) string {
	base := c.config.BaseURL
	if len(base) > 0 && base[len(base)-1] == '/' && len(suffix) > 0 && suffix[0] == '/' {
		base = base[:len(base)-1]
	}
	return fmt.Sprintf("%s%s", base, suffix)
}
// newStreamRequest builds a request for a streaming (text/event-stream)
// endpoint at c.fullURL(urlSuffix).
//
// ctx, method and body are passed to c.requestBuilder.Build, and any error it
// returns is propagated. The request is then given a JSON Content-Type, an
// event-stream Accept header, Cache-Control: no-cache and
// Connection: keep-alive, followed by the shared headers from
// setCommonHeaders.
func (c *Client) newStreamRequest(
	ctx context.Context,
	method string,
	urlSuffix string,
	body any,
) (*http.Request, error) {
	target := c.fullURL(urlSuffix)
	req, err := c.requestBuilder.Build(ctx, method, target, body)
	if err != nil {
		return nil, err
	}

	streamHeaders := [...]struct{ key, value string }{
		{"Content-Type", "application/json"},
		{"Accept", "text/event-stream"},
		{"Cache-Control", "no-cache"},
		{"Connection", "keep-alive"},
	}
	for _, hdr := range streamHeaders {
		req.Header.Set(hdr.key, hdr.value)
	}
	c.setCommonHeaders(req)
	return req, nil
}
// handleErrorResp converts a failed HTTP response into an error.
//
// If the body decodes as an ErrorResponse whose Error field is non-nil, that
// API error is returned with HTTPStatusCode set to resp.StatusCode.
//
// Otherwise a *RequestError carrying the status code is returned. Its Err is
// chosen in this order:
//   - the API error, if one was decoded before a JSON error occurred;
//   - else the JSON decode error;
//   - else, when the body was valid JSON but held no API error, a
//     descriptive error.
//
// The last case ensures callers never receive a RequestError with a nil Err.
func (c *Client) handleErrorResp(resp *http.Response) error {
	var errRes ErrorResponse
	decodeErr := json.NewDecoder(resp.Body).Decode(&errRes)
	if decodeErr == nil && errRes.Error != nil {
		errRes.Error.HTTPStatusCode = resp.StatusCode
		return errRes.Error
	}

	reqErr := &RequestError{
		HTTPStatusCode: resp.StatusCode,
		Err:            decodeErr,
	}
	switch {
	case errRes.Error != nil:
		// Decoding failed part-way but still produced an API error, which is
		// more useful to the caller than the JSON error.
		reqErr.Err = errRes.Error
	case decodeErr == nil:
		// Valid JSON without an API error (e.g. `{}`): without this branch
		// Err would be nil and the caller would learn nothing.
		reqErr.Err = fmt.Errorf("response body contained no API error object (status %d)", resp.StatusCode)
	}
	return reqErr
}