Skip to content

Commit

Permalink
Add tls support for mysql client (#186)
Browse files Browse the repository at this point in the history
* Add tls support for mysql client

* Update cmd/go-tpc/main.go

Co-authored-by: Daniël van Eeden <[email protected]>

* Update cmd/go-tpc/main.go

Co-authored-by: Daniël van Eeden <[email protected]>

* Update cmd/go-tpc/main.go

Co-authored-by: Daniël van Eeden <[email protected]>

* Update cmd/go-tpc/main.go

Co-authored-by: Daniël van Eeden <[email protected]>

* Update cmd/go-tpc/main.go

Co-authored-by: Daniël van Eeden <[email protected]>

* Update cmd/go-tpc/main.go

Co-authored-by: Daniël van Eeden <[email protected]>

* Adjust code based on reviews

* Update cmd/go-tpc/main.go

Co-authored-by: Daniël van Eeden <[email protected]>

* Update cmd/go-tpc/main.go

Co-authored-by: Daniël van Eeden <[email protected]>

* Update error message

* Update cmd/go-tpc/main.go

Co-authored-by: Daniël van Eeden <[email protected]>

* Update cmd/go-tpc/main.go

---------

Co-authored-by: Daniël van Eeden <[email protected]>
  • Loading branch information
db-will and dveeden authored Jan 9, 2025
1 parent 01c0653 commit 6cd9f74
Showing 1 changed file with 69 additions and 5 deletions.
74 changes: 69 additions & 5 deletions cmd/go-tpc/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package main
import (
"context"
"crypto/sha1"
"crypto/tls"
"crypto/x509"
"database/sql"
sqldrv "database/sql/driver"
"encoding/hex"
Expand All @@ -18,6 +20,7 @@ import (
"github.com/pingcap/go-tpc/pkg/util"
"github.com/spf13/cobra"
_ "go.uber.org/automaxprocs"

// mysql package
"github.com/go-sql-driver/mysql"
// pg
Expand Down Expand Up @@ -47,15 +50,19 @@ var (
connParams string
outputStyle string
targets []string
sslCA string
sslCert string
sslKey string

globalDB *sql.DB
globalCtx context.Context
)

const (
createDBDDL = "CREATE DATABASE "
mysqlDriver = "mysql"
pgDriver = "postgres"
createDBDDL = "CREATE DATABASE "
mysqlDriver = "mysql"
pgDriver = "postgres"
customTlsName = "custom"
)

type MuxDriver struct {
Expand Down Expand Up @@ -93,18 +100,31 @@ func newDB(targets []string, driver string, user string, password string, dbName
hash.Write([]byte(password))
hash.Write([]byte(dbName))
hash.Write([]byte(connParams))

if driver == mysqlDriver && (len(sslCA) > 0 || len(sslCert) > 0 || len(sslKey) > 0) {
registerMysqlTLSConfig()
}

for i, addr := range targets {
hash.Write([]byte(addr))
switch driver {
case mysqlDriver:
var tlsName string = "preferred"
if len(sslCA) > 0 {
tlsName = customTlsName
}
// allow multiple statements in one query to allow q15 on the TPC-H
dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s?multiStatements=true&tls=preferred", user, password, addr, dbName)
dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s?multiStatements=true&tls=%s", user, password, addr, dbName, tlsName)
if len(connParams) > 0 {
dsn = dsn + "&" + connParams
}
names[i] = dsn
drv = &mysql.MySQLDriver{}
case pgDriver:
if len(sslCA) > 0 || len(sslKey) > 0 || len(sslCert) > 0 {
panic("postgresql driver doesn't support TLS yet")
}

dsn := fmt.Sprintf("postgres://%s:%s@%s/%s", user, password, addr, dbName)
if len(connParams) > 0 {
dsn = dsn + "?" + connParams
Expand Down Expand Up @@ -150,9 +170,10 @@ func openDB() {
tmpDB, _ = newDB(targets, driver, user, password, "", connParams)
defer tmpDB.Close()
if _, err := tmpDB.Exec(createDBDDL + dbName); err != nil {
panic(fmt.Errorf("failed to create database, err %v\n", err))
panic(fmt.Errorf("failed to create database, err %v", err))
}
} else {
fmt.Printf("failed to ping db, err %v\n", err)
globalDB = nil
}
} else {
Expand Down Expand Up @@ -209,6 +230,9 @@ func main() {
rootCmd.PersistentFlags().StringVar(&outputStyle, "output", util.OutputStylePlain, "output style, valid values can be { plain | table | json }")
rootCmd.PersistentFlags().StringSliceVar(&targets, "targets", nil, "Target database addresses")
rootCmd.PersistentFlags().MarkHidden("targets")
rootCmd.PersistentFlags().StringVar(&sslCA, "ssl-ca", "", "Path of file that contains list of trusted SSL CAs for connection")
rootCmd.PersistentFlags().StringVar(&sslCert, "ssl-cert", "", "Path of file that contains X509 certificate in PEM format for connection")
rootCmd.PersistentFlags().StringVar(&sslKey, "ssl-key", "", "Path of file that contains X509 key in PEM format for connection")

cobra.EnablePrefixMatching = true

Expand Down Expand Up @@ -251,3 +275,43 @@ func main() {

cancel()
}

// registerMysqlTLSConfig constructs a `*tls.Config` from the CA, certification and key
// paths, and register to mysql client.
func registerMysqlTLSConfig() {
// Load the client certificates from disk
var certificates []tls.Certificate
if len(sslCert) != 0 && len(sslKey) != 0 {
cert, err := tls.LoadX509KeyPair(sslCert, sslKey)
if err != nil {
panic(fmt.Errorf("could not load client key pair, err %v", err))
}
certificates = []tls.Certificate{cert}
} else if len(sslCert) > 0 || len(sslKey) > 0 {
panic("incomplete key pair configuration")
}

// Create a certificate pool from CA
certPool := x509.NewCertPool()
ca, err := os.ReadFile(sslCA)
if err != nil {
panic(fmt.Errorf("could not read CA certificate, err %v", err))
}

// Append the certificates from the CA
if !certPool.AppendCertsFromPEM(ca) {
panic("failed to append CA certs")
}

tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS12,
Certificates: certificates,
RootCAs: certPool,
ClientCAs: certPool,
}

err = mysql.RegisterTLSConfig(customTlsName, tlsConfig)
if err != nil {
panic(fmt.Errorf("failed to register TLS config, err %v", err))
}
}

0 comments on commit 6cd9f74

Please sign in to comment.