diff --git a/connector.go b/connector.go index 77dcda7..f8aa759 100644 --- a/connector.go +++ b/connector.go @@ -19,13 +19,20 @@ var _ driver.Connector = &Connector{} // Connector is an implementation of driver.Connector type Connector struct { - Session *session.Session - Config *mysql.Config + // Session is AWS Session. + Session *session.Session + + // Config is a configure for connecting to MySQL servers. + Config *mysql.Config + + // MaxConnsPerSecond is a limit for creating new connections. + // Zero means no limit. MaxConnsPerSecond int mu sync.Mutex limiter *rate.Limiter - config *mysql.Config + // config is same as Config, but TLS configured + config *mysql.Config } // Connect returns a connection to the database. @@ -37,47 +44,62 @@ func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) { } } + connector, err := c.newConnector() + if err != nil { + return nil, err + } + return connector.Connect(ctx) +} + +func (c *Connector) newConnector() (driver.Connector, error) { + config, err := c.newConfig() + if err != nil { + return nil, err + } + + // refresh token cred := c.Session.Config.Credentials region := c.Session.Config.Region if region == nil { return nil, errors.New("rdsmysql: region is missing") } + token, err := rdsutils.BuildAuthToken(config.Addr, *region, config.User, cred) + if err != nil { + return nil, fmt.Errorf("rdsmysql: fail to build auth token: %w", err) + } + config.Passwd = token + // create new connector + connector, err := mysql.NewConnector(config) + if err != nil { + return nil, fmt.Errorf("rdsmysql: fail to created new connector: %w", err) + } + return connector, nil +} + +func (c *Connector) newConfig() (*mysql.Config, error) { c.mu.Lock() + defer c.mu.Unlock() + if c.config == nil { copy := *c.Config // shallow copy, but ok. we rewrite only shallow fields. - // format and parse dns. - // because TLS config is loaded by ParseDNS. + // override configure for Amazon RDS + // see https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAMDBAuth.Connecting.AWSCLI.html copy.TLSConfig = "rdsmysql" + copy.AllowCleartextPasswords = true + + // format and reparse dns. + // because we can't write TLS config directly, ParseDNS does. config, err := mysql.ParseDSN(copy.FormatDSN()) if err != nil { - c.mu.Unlock() - return nil, fmt.Errorf("fail to parse dsn: %w", err) + return nil, fmt.Errorf("rdsmysql: fail to parse dsn: %w", err) } c.config = config } - config := c.config - - token, err := rdsutils.BuildAuthToken(config.Addr, *region, config.User, cred) - if err != nil { - c.mu.Unlock() - return nil, fmt.Errorf("fail to build auth token: %w", err) - } - // override configure - config.AllowCleartextPasswords = true - config.Passwd = token - config.TLSConfig = "rdsmysql" - - connector, err := mysql.NewConnector(config) - if err != nil { - c.mu.Unlock() - return nil, fmt.Errorf("fail to created new connector: %w", err) - } - c.mu.Unlock() - - return connector.Connect(ctx) + copy := *c.config // shallow copy, but ok. we rewrite only shallow fields. + return ©, nil } func (c *Connector) getlimiter() *rate.Limiter { @@ -85,12 +107,13 @@ func (c *Connector) getlimiter() *rate.Limiter { return nil } c.mu.Lock() + defer c.mu.Unlock() + limiter := c.limiter if limiter == nil { limiter = rate.NewLimiter(rate.Limit(c.MaxConnsPerSecond), 1) c.limiter = limiter } - c.mu.Unlock() return limiter } diff --git a/connector_test.go b/connector_test.go new file mode 100644 index 0000000..f1ef62e --- /dev/null +++ b/connector_test.go @@ -0,0 +1,40 @@ +package rdsmysql + +import ( + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/go-sql-driver/mysql" +) + +func newTestConnector() *Connector { + cred := credentials.NewStaticCredentials("AKIAIOSFODNN7EXAMPLE", "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", "") + awsConfig := aws.NewConfig().WithRegion("ap-northeast-1").WithCredentials(cred) + awsSession := session.Must(session.NewSession(awsConfig)) + mysqlConfig, err := mysql.ParseDSN("user:@tcp(db-foobar.ap-northeast-1.rds.amazonaws.com:3306)/") + if err != nil { + panic(err) + } + return &Connector{ + Session: awsSession, + Config: mysqlConfig, + } +} + +func TestNewConnector(t *testing.T) { + _, err := newTestConnector().newConnector() + if err != nil { + t.Fatal(err) + } +} + +func BenchmarkNewConnector(b *testing.B) { + c := newTestConnector() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + c.newConnector() + } + }) +} diff --git a/go.mod b/go.mod index 8c9237f..26003f4 100644 --- a/go.mod +++ b/go.mod @@ -3,10 +3,9 @@ module github.com/shogo82148/rdsmysql go 1.13 require ( - github.com/aws/aws-sdk-go v1.25.16 + github.com/aws/aws-sdk-go v1.29.11 github.com/davecgh/go-spew v1.1.1 // indirect - github.com/go-sql-driver/mysql v1.4.1-0.20190405001506-d3a0b0fcd73c - github.com/stretchr/testify v1.3.0 // indirect - golang.org/x/time v0.0.0-20190921001708-c4c64cad1fd0 - google.golang.org/appengine v1.6.5 // indirect + github.com/go-sql-driver/mysql v1.5.0 + golang.org/x/text v0.3.2 // indirect + golang.org/x/time v0.0.0-20191024005414-555d28b269f0 ) diff --git a/go.sum b/go.sum index 8c8d633..2c857fe 100644 --- a/go.sum +++ b/go.sum @@ -1,29 +1,31 @@ -github.com/aws/aws-sdk-go v1.25.16 h1:k7Fy6T/uNuLX6zuayU/TJoP7yMgGcJSkZpF7QVjwYpA= -github.com/aws/aws-sdk-go v1.25.16/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= +github.com/aws/aws-sdk-go v1.29.11 h1:f1QJRPu30p0i1lzKhkSSaZFudFGCra2HKgdE442nN6c= +github.com/aws/aws-sdk-go v1.29.11/go.mod h1:1KvfttTE3SPKMpo8g2c6jL3ZKfXtFvKscTgahTma5Xg= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/go-sql-driver/mysql v1.4.1-0.20190405001506-d3a0b0fcd73c h1:4ys3t0WZZ9YX5V2Ky7fdev6gzoU8Em6Q4S/1TeZaSqo= -github.com/go-sql-driver/mysql v1.4.1-0.20190405001506-d3a0b0fcd73c/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= -github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs= +github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af h1:pmfjZENx5imkbgOkpRUYLnmbU7UEFbjtDA2hxJ1ichM= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= -github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/net v0.0.0-20190603091049-60506f45cf65 h1:+rhAzEzT3f4JtomfC371qB+0Ola2caSKcY69NUBZrRQ= -golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= +golang.org/x/net v0.0.0-20200202094626-16171245cfb2 h1:CCH4IOTTfewWjGOlSp+zGcjutRKlBEZQ6wTn8ozI/nI= +golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= -golang.org/x/time v0.0.0-20190921001708-c4c64cad1fd0 h1:xQwXv67TxFo9nC1GJFyab5eq/5B590r6RlnL/G8Sz7w= -golang.org/x/time v0.0.0-20190921001708-c4c64cad1fd0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= +golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -google.golang.org/appengine v1.6.5 h1:tycE03LOZYQNhDpS27tcQdAzLCVMaj7QT2SXxebnpCM= -google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/rdamysql_test.go b/rdamysql_test.go new file mode 100644 index 0000000..5d4799d --- /dev/null +++ b/rdamysql_test.go @@ -0,0 +1,27 @@ +package rdsmysql_test + +import ( + "database/sql" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/shogo82148/rdsmysql" +) + +func ExampleOpen() { + // register authentication infomation + c := aws.NewConfig().WithRegion("ap-northeast-1") + s := session.Must(session.NewSession(c)) + d := &rdsmysql.Driver{ + Session: s, + } + sql.Register("rdsmysql", d) + + db, err := sql.Open("rdsmysql", "user:@tcp(db-foobar.ap-northeast-1.rds.amazonaws.com:3306)/") + if err != nil { + panic(err) + } + defer db.Close() + + // do something with db +}