// Package llama provides Go bindings to llama.cpp via cgo.
// Forked from go-skynet/go-llama.cpp.
package llama

// #cgo CXXFLAGS: -I./llama.cpp/examples -I./llama.cpp
// #cgo LDFLAGS: -L./ -lbinding -lm -lstdc++
// #include <stdlib.h>
// #include "binding.h"
import "C"
import (
"fmt"
"strings"
"unsafe"
)
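
// LLama wraps an opaque pointer to the model state held on the C side.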
type LLama struct {
state unsafe.Pointer
}
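
// New loads the model at the given path with the supplied options and
// returns a handle to it, or an error if the C-side loader fails.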
func New(model string, opts ...ModelOption) (*LLama, error) {
mo := NewModelOptions(opts...)
	modelPath := C.CString(model)
	defer C.free(unsafe.Pointer(modelPath))
result := C.load_model(modelPath, C.int(mo.ContextSize), C.int(mo.Parts), C.int(mo.Seed), C.bool(mo.F16Memory), C.bool(mo.MLock))
if result == nil {
		return nil, fmt.Errorf("failed loading model %q", model)
}
return &LLama{state: result}, nil
}
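
// Free releases the C-side model state. The handle must not be used afterwards.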
func (l *LLama) Free() {
C.llama_free_model(l.state)
}
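
// Predict runs inference on the given prompt and returns the generated text,
// with the echoed prompt and leading whitespace trimmed from the output.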
func (l *LLama) Predict(text string, opts ...PredictOption) (string, error) {
po := NewPredictOptions(opts...)
	input := C.CString(text)
	defer C.free(unsafe.Pointer(input))
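	// No explicit token limit: fall back to a very large cap so the
	// output buffer below can hold any completion.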
if po.Tokens == 0 {
po.Tokens = 99999999
}
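	// Buffer the C side fills with a NUL-terminated result string.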
out := make([]byte, po.Tokens)
params := C.llama_allocate_params(input, C.int(po.Seed), C.int(po.Threads), C.int(po.Tokens), C.int(po.TopK),
C.float(po.TopP), C.float(po.Temperature), C.float(po.Penalty), C.int(po.Repeat), C.bool(po.IgnoreEOS), C.bool(po.F16KV))
	defer C.llama_free_params(params)
	ret := C.llama_predict(params, l.state, (*C.char)(unsafe.Pointer(&out[0])))
if ret != 0 {
return "", fmt.Errorf("inference failed")
}
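	// Strip the leading space, the echoed prompt, and a leading newline
	// from the raw output.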
res := C.GoString((*C.char)(unsafe.Pointer(&out[0])))
res = strings.TrimPrefix(res, " ")
res = strings.TrimPrefix(res, text)
res = strings.TrimPrefix(res, "\n")
return res, nil
}
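
// Example usage (a minimal sketch, not part of the original file; the model
// path is a hypothetical placeholder, and no options are passed, so the
// defaults from NewModelOptions and NewPredictOptions apply):
//
//	l, err := llama.New("./models/ggml-model.bin")
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer l.Free()
//
//	out, err := l.Predict("Say hello in Go")
//	if err != nil {
//		log.Fatal(err)
//	}
//	fmt.Println(out)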