This repository has been archived by the owner on Oct 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 213
/
client.go
140 lines (124 loc) · 3.8 KB
/
client.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
package mysql
import (
"database/sql"
"errors"
"io/ioutil"
"net/url"
"os"
"strings"
"github.com/compose/transporter/client"
"github.com/compose/transporter/log"
//_ "github.com/go-sql-driver/mysql" // import mysql driver
"github.com/go-mysql-org/go-mysql/driver" // full import of alternative mysql driver
)
// DefaultURI is the MySQL endpoint a Client starts with when no WithURI option
// is supplied to NewClient: the root user on localhost port 3306, with no
// database named in the path.
const DefaultURI = "mysql://root@localhost:3306?"
// Compile-time assertion that *Client implements client.Client.
var _ client.Client = (*Client)(nil)
// ClientOptionFunc mutates a Client during construction. NewClient applies
// options in the order given and aborts on the first non-nil error.
type ClientOptionFunc func(c *Client) error
// Client holds the connection settings for a MySQL endpoint together with the
// *sql.DB that Connect opens on first use.
type Client struct {
	// uri is the connection string; Connect strips a leading "mysql://"
	// before handing it to the driver.
	uri string
	// db is the database name given to each Session; Connect replaces it with
	// the database named in the URI path, if any.
	db string
	// mysqlSession is the handle opened by Connect; nil until then.
	mysqlSession *sql.DB
}
// NewClient returns a MySQL Client that starts from DefaultURI and the
// database "test", then applies each option in order. If any option returns
// an error, construction stops and that error is returned with a nil Client.
// No connection is opened here; see Connect.
func NewClient(options ...ClientOptionFunc) (*Client, error) {
	c := &Client{
		uri: DefaultURI,
		// TODO: the default used to be `mysql`; it was changed to `test`
		// because the author's local instance had that database (before
		// connecting as root). Revisit the default.
		db: "test",
	}
	for _, option := range options {
		if err := option(c); err != nil {
			return nil, err
		}
	}
	return c, nil
}
// WithURI sets the connection string the Client will use. The value is stored
// even when it does not parse as a URL; the parse error (if any) is returned
// so NewClient can reject it.
//
// TODO: decide whether this should also accept the raw DSN formats of
//   - https://github.com/go-sql-driver/mysql#dsn-data-source-name
//   - https://github.com/go-mysql-org/go-mysql#driver
func WithURI(uri string) ClientOptionFunc {
	return func(c *Client) error {
		c.uri = uri
		_, parseErr := url.Parse(uri)
		return parseErr
	}
}
// WithCustomTLS registers a custom TLS configuration for uri with the
// go-mysql driver, using the PEM CA certificate read from the file at cert.
// The returned option does not modify the Client itself.
//
// An empty cert makes the option a no-op. A cert path that cannot be stat'ed
// yields the error "Cert file not found"; read errors are returned as-is.
// No client certificate or key is supplied (empty slices are passed).
//
// When serverName is empty the driver is called with insecureSkipVerify=true;
// otherwise with false and the given serverName.
// NOTE(review): with an empty serverName the supplied CA presumably is not
// used to verify the server at all — confirm this is intended.
func WithCustomTLS(uri string, cert string, serverName string) ClientOptionFunc {
	return func(c *Client) error {
		if cert == "" {
			// No TLS options to configure.
			return nil
		}
		if _, statErr := os.Stat(cert); statErr != nil {
			return errors.New("Cert file not found")
		}
		caPem, readErr := ioutil.ReadFile(cert)
		if readErr != nil {
			return readErr
		}
		log.Debugf("Cert: %s", caPem)

		skipVerify := serverName == ""
		if err := driver.SetCustomTLSConfig(uri, caPem, []byte{}, []byte{}, skipVerify, serverName); err != nil {
			return err
		}
		return nil
	}
}
// Close releases the underlying *sql.DB if Connect has opened one. It is
// safe to call on a Client that never connected. Any error from closing is
// discarded because Close has no error return.
func (c *Client) Close() {
	if c.mysqlSession == nil {
		return
	}
	_ = c.mysqlSession.Close()
}
// Connect opens the MySQL handle on first use and returns a Session bound to
// it and to the client's database.
//
// On the first call, c.uri is parsed as a URL. Its path, minus the leading
// "/", selects the database when non-empty. The "mysql://" prefix is stripped
// to form the driver DSN, and the handle is opened with sql.Open. Every call
// then executes "SET FOREIGN_KEY_CHECKS=0;". That statement disables
// foreign-key checks for imports and also serves as a connectivity check.
//
// Errors from parsing, opening or the SET statement are returned with a nil
// Session; this replaces the previous panics.
func (c *Client) Connect() (client.Session, error) {
	if c.mysqlSession == nil {
		// Parse first: the URI also supplies the database name, and an
		// unchecked parse error previously led to a nil-pointer dereference.
		uri, err := url.Parse(c.uri)
		if err != nil {
			return nil, err
		}
		// Redacted masks any password so credentials never reach the logs.
		log.Debugln("URI: " + uri.Redacted())

		// The driver expects a DSN rather than a URI, so drop only a leading
		// scheme prefix.
		dsn := strings.TrimPrefix(c.uri, "mysql://")
		db, err := sql.Open("mysql", dsn)
		if err != nil {
			return nil, err
		}
		c.mysqlSession = db

		// Keep the default database when the URI path is empty or just "/".
		if name := strings.TrimPrefix(uri.Path, "/"); name != "" {
			c.db = name
		}
	}
	// Ideally this would be sent once per session rather than on every call.
	// NOTE(review): c.mysqlSession is a *sql.DB, which may hand out different
	// pooled connections, so this session variable might not apply to the
	// connection later used for writes — confirm how Session executes.
	if _, err := c.mysqlSession.Exec("SET FOREIGN_KEY_CHECKS=0;"); err != nil {
		return nil, err
	}
	return &Session{c.mysqlSession, c.db}, nil
}