-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmysql.go
104 lines (82 loc) · 2.05 KB
/
mysql.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
package main
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"log"
	"net"
	"os"
	"time"

	_ "github.com/go-sql-driver/mysql"
)
// DatabaseConfig holds the settings needed to reach the MySQL server.
// getDatabaseConfig populates it from environment variables, and createDSN
// turns it into a driver DSN.
type DatabaseConfig struct {
	// Credentials used to authenticate against the server.
	username, password string
	// Host and port of the server; createDSN places these in the tcp(...)
	// address part of the DSN.
	hostname, port string
	// Name of the application database (schema).
	dbName string
}
// databaseConn is the package-wide MySQL connection pool. connectDatabase
// assigns it and createDatabase executes statements through it; it is nil
// until connectDatabase assigns it.
// NOTE(review): nothing guards this variable with a lock; this assumes it is
// set once at startup before any concurrent use — confirm with callers.
var databaseConn *sql.DB
// connectDatabase initialises the package-level databaseConn.
//
// It works in three steps:
//  1. Connect to the MySQL server without selecting a database, because the
//     target database may not exist yet.
//  2. Create the database named by DATABASE_NAME if it is missing.
//  3. Open and verify a second pool bound to that database.
//
// On success, databaseConn points at the database-bound pool. On failure an
// error is returned and databaseConn is left nil. The server-level bootstrap
// pool is always closed before returning.
func connectDatabase() error {
	log.Println("Trying to connect to DB")

	serverDB, err := sql.Open("mysql", createDSN(true))
	if err != nil {
		return fmt.Errorf("failed to open mysql connection: %w", err)
	}
	// The server-level pool is only needed for bootstrapping; release its
	// connections on every return path.
	defer serverDB.Close()

	// sql.Open does not dial. Ping first so an unreachable server is reported
	// as such rather than as a failure inside createDatabase.
	if err := serverDB.Ping(); err != nil {
		return fmt.Errorf("failed to ping db: %w", err)
	}

	// createDatabase executes through the package-level databaseConn, so point
	// it at the bootstrap pool only for the duration of that call.
	databaseConn = serverDB
	err = createDatabase(os.Getenv("DATABASE_NAME"))
	databaseConn = nil
	if err != nil {
		return err
	}

	db, err := sql.Open("mysql", createDSN(false))
	if err != nil {
		return fmt.Errorf("failed to open mysql connection using databasename: %w", err)
	}
	// Verify the database-bound pool too. This catches a missing database or
	// missing privileges on it.
	if err := db.Ping(); err != nil {
		// Best-effort cleanup; the ping error is the one worth reporting.
		_ = db.Close()
		return fmt.Errorf("failed to ping db using databasename: %w", err)
	}

	log.Println("connected to database")
	databaseConn = db
	return nil
}
// createDatabase runs CREATE DATABASE IF NOT EXISTS for dbname through the
// package-level databaseConn, with a 5-second timeout.
//
// An already-existing database is not an error, so the call is safe to
// repeat. It returns an error in three cases:
//   - dbname is empty;
//   - databaseConn is nil;
//   - the statement fails.
func createDatabase(dbname string) error {
	log.Println("Creating database")
	if dbname == "" {
		return errors.New("cannot create database: name is empty")
	}
	if databaseConn == nil {
		return errors.New("cannot create database: no database connection")
	}

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// No transaction is used: MySQL DDL statements commit implicitly, so a tx
	// adds nothing.
	//
	// The rows-affected count is not checked either. It is not a reliable
	// success signal for DDL: some servers report 0 when the database already
	// exists, which would wrongly fail every restart. Only the Exec error
	// matters.
	stmt := "CREATE DATABASE IF NOT EXISTS " + quoteIdentifier(dbname)
	if _, err := databaseConn.ExecContext(ctx, stmt); err != nil {
		return fmt.Errorf("failed to create database %q: %w", dbname, err)
	}
	return nil
}

// quoteIdentifier wraps name in MySQL backtick quotes and doubles any
// embedded backticks. The result is parsed as a single identifier rather than
// as SQL text.
func quoteIdentifier(name string) string {
	quoted := make([]byte, 0, len(name)+2)
	quoted = append(quoted, '`')
	for i := 0; i < len(name); i++ {
		if name[i] == '`' {
			quoted = append(quoted, '`')
		}
		quoted = append(quoted, name[i])
	}
	quoted = append(quoted, '`')
	return string(quoted)
}
// createDSN builds a go-sql-driver/mysql DSN of the form
// "user:password@tcp(address)/dbname" from getDatabaseConfig.
//
// When skipDB is true, the database name is left empty. The connection then
// targets the server without selecting a database, which is needed before
// the database has been created.
//
// The address is the configured hostname, joined with the configured port
// when a port is set and the hostname does not already include one. With no
// port, the driver applies its own default.
func createDSN(skipDB bool) string {
	cfg := getDatabaseConfig()

	address := cfg.hostname
	if cfg.port != "" {
		// SplitHostPort succeeds only if hostname already carries a port;
		// don't append a second one in that case. JoinHostPort also brackets
		// IPv6 literals correctly.
		if _, _, err := net.SplitHostPort(cfg.hostname); err != nil {
			address = net.JoinHostPort(cfg.hostname, cfg.port)
		}
	}

	dbName := cfg.dbName
	if skipDB {
		dbName = ""
	}
	return fmt.Sprintf("%s:%s@tcp(%s)/%s", cfg.username, cfg.password, address, dbName)
}
// getDatabaseConfig reads the MySQL connection settings from the environment:
//
//	DATABASE_USERNAME   -> username
//	DATABASE_PASSWORD   -> password
//	DATABASE_NAME       -> dbName
//	MYSQL_SERVICE_HOST  -> hostname
//	MYSQL_SERVICE_PORT  -> port (falls back to the legacy MYSQL_SERVICE_port)
//
// Unset variables yield empty strings; no validation is done here.
// NOTE(review): the MYSQL_SERVICE_* names look like Kubernetes service-link
// variables — confirm with the deployment manifests.
func getDatabaseConfig() DatabaseConfig {
	// Environment variable names are case-sensitive on Unix. Earlier code read
	// the mis-cased "MYSQL_SERVICE_port", so keep honouring it as a fallback.
	port := os.Getenv("MYSQL_SERVICE_PORT")
	if port == "" {
		port = os.Getenv("MYSQL_SERVICE_port")
	}
	return DatabaseConfig{
		username: os.Getenv("DATABASE_USERNAME"),
		password: os.Getenv("DATABASE_PASSWORD"),
		dbName:   os.Getenv("DATABASE_NAME"),
		hostname: os.Getenv("MYSQL_SERVICE_HOST"),
		port:     port,
	}
}