Skip to content

Commit

Permalink
feat: 添加 Dialect.Backup 用于备份数据库
Browse files Browse the repository at this point in the history
  • Loading branch information
caixw committed May 3, 2024
1 parent e0a676a commit 89270e3
Show file tree
Hide file tree
Showing 13 changed files with 197 additions and 5 deletions.
8 changes: 7 additions & 1 deletion core/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ type Dialect interface {
// DriverName 与当前实例关联的驱动名称
//
// 原则上驱动名和 Dialect 应该是一一对应的,但是也会有例外,比如:
// github.com/lib/pq 和 github.com/jackc/pgx/v4/stdlib 功能上是相同的,
// github.com/lib/pq 和 github.com/jackc/pgx/v5/stdlib 功能上是相同的,
// 仅注册的名称的不同。
DriverName() string

Expand Down Expand Up @@ -153,6 +153,12 @@ type Dialect interface {
//
// NOTE: query 中不能同时存在 ? 和命名参数。因为如果是命名参数,则 Exec 等的参数顺序可以是随意的。
Prepare(sql string) (query string, orders map[string]int, err error)

// Backup 备份数据库
//
// dsn 初始化数据库的参数,主要从其中获取数据库名称等参数;
// dest 备份的文件名,格式由实现者决定;
Backup(dsn, dest string) error
}

// ErrConstraintExists 返回约束名已经存在的错误
Expand Down
8 changes: 8 additions & 0 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type DB struct {
tablePrefix string
sqlBuilder *sqlbuilder.SQLBuilder
models *model.Models
dsn string
}

// NewDB 声明一个新的 [DB] 实例
Expand All @@ -44,9 +45,15 @@ func NewDB(tablePrefix, dsn string, dialect Dialect) (*DB, error) {
tablePrefix: tablePrefix,
sqlBuilder: sqlbuilder.New(e),
models: ms,
dsn: dsn,
}, nil
}

// Backup 备份数据库至 dest
//
// 具体格式由各个数据库自行决定。
func (db *DB) Backup(dest string) error { return db.Dialect().Backup(db.dsn, dest) }

// New 重新指定表名前缀为 tablePrefix
//
// 如果要复用表模型,可以采此方法创建一个不同表名前缀的 [DB] 对表模型进行操作。
Expand All @@ -61,6 +68,7 @@ func (db *DB) New(tablePrefix string) *DB {
tablePrefix: tablePrefix,
sqlBuilder: sqlbuilder.New(e),
models: db.models,
dsn: db.dsn,
}
}

Expand Down
21 changes: 21 additions & 0 deletions dialect/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@

package dialect

import (
"os/exec"

"github.com/issue9/sliceutil"
)

type base struct {
driverName string
name string
Expand All @@ -24,3 +30,18 @@ func (b *base) Name() string { return b.name }
func (b *base) DriverName() string { return b.driverName }

func (b *base) Quotes() (byte, byte) { return b.quoteL, b.quoteR }

func buildCmdArgs(k, v string) string {
if v == "" {
return ""
}
return k + "=" + v
}

func newCommand(name string, env, kv []string) *exec.Cmd {
env = sliceutil.Filter(env, func(i string, _ int) bool { return i != "" })
kv = sliceutil.Filter(kv, func(i string, _ int) bool { return i != "" })
cmd := exec.Command(name, kv...)
cmd.Env = append(cmd.Env, env...)
return cmd
}
17 changes: 17 additions & 0 deletions dialect/base_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// SPDX-FileCopyrightText: 2024 caixw
//
// SPDX-License-Identifier: MIT

package dialect

import (
"testing"

"github.com/issue9/assert/v4"
)

func TestBuildCmdArg(t *testing.T) {
a := assert.New(t, false)
a.Equal(buildCmdArgs("-p", ""), "").
Equal(buildCmdArgs("-p", "123"), "-p=123")
}
39 changes: 39 additions & 0 deletions dialect/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,14 @@ import (
"database/sql/driver"
"errors"
"fmt"
"net"
"os"
"strconv"
"strings"
"time"

xm "github.com/go-sql-driver/mysql"

"github.com/issue9/orm/v6/core"
"github.com/issue9/orm/v6/sqlbuilder"
)
Expand Down Expand Up @@ -307,3 +311,38 @@ func formatTime(col *core.Column, t time.Time) (string, error) {
}
return "'" + t.Format(datetimeLayouts[index]) + "'", nil
}

func (m *mysql) Backup(dsn, dest string) error {
conf, err := xm.ParseDSN(dsn)
if err != nil {
return err
}

if conf.DBName == "" {
panic("未指定数据库名")
}

h, p, err := net.SplitHostPort(conf.Addr)
if err != nil {
return err
}

file, err := os.Create(dest)
if err != nil {
return err
}
defer file.Close()

cmd := newCommand("mysqldump", []string{}, []string{
buildCmdArgs("--host", h),
buildCmdArgs("--port", p),
buildCmdArgs("--protocol", conf.Net),
buildCmdArgs("--user", conf.User),
buildCmdArgs("--password", conf.Passwd),
conf.DBName,
})
cmd.Stderr = os.Stderr
cmd.Stdout = file

return cmd.Run()
}
12 changes: 12 additions & 0 deletions dialect/mysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package dialect_test

import (
"os"
"testing"

"github.com/issue9/assert/v4"
Expand Down Expand Up @@ -284,3 +285,14 @@ func TestMysql_TypesDefault(t *testing.T) {
testTypesDefault(t)
})
}

func TestMysql_Backup(t *testing.T) {
a := assert.New(t, false)
suite := test.NewSuite(a, "", test.Mysql, test.Mariadb)
suite.Run(func(d *test.Driver) {
path := "./testdata/mysql.sql"
d.Assertion.NotError(d.DB.Backup(path))
d.Assertion.FileExists(path)
d.Assertion.NotError(os.Remove(path))
})
}
48 changes: 48 additions & 0 deletions dialect/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@ import (
"database/sql/driver"
"errors"
"fmt"
"os"
"strconv"
"strings"
"time"

"github.com/lib/pq"

"github.com/issue9/orm/v6/core"
)

Expand Down Expand Up @@ -261,3 +264,48 @@ func (p *postgres) formatSQL(col *core.Column) (f string, err error) {

return fmt.Sprint(v), nil
}

func (p *postgres) Backup(dsn, dest string) error {
// http://www.postgres.cn/docs/14/app-pgdump.html

opt, err := parsePostgresDSN(dsn)
if err != nil {
return err
}

cmd := newCommand("pg_dump", []string{
buildCmdArgs("PGPASSWORD", opt["password"]),
}, []string{
buildCmdArgs("--format", "c"),
buildCmdArgs("--file", dest),
buildCmdArgs("--host", opt["host"]),
buildCmdArgs("--port", opt["port"]),
buildCmdArgs("--username", opt["user"]),
buildCmdArgs("--dbname", opt["dbname"]),
"--no-password",
})
cmd.Stderr = os.Stderr
cmd.Stdout = os.Stdout

return cmd.Run()
}

func parsePostgresDSN(dsn string) (opt map[string]string, err error) {
if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
if dsn, err = pq.ParseURL(dsn); err != nil {
return nil, err
}
}

items := strings.Fields(dsn)
opt = make(map[string]string, len(items))
for _, item := range items {
kv := strings.Split(item, "=")
if len(kv) != 2 {
panic("参数格式错误:" + item)
}
opt[kv[0]] = kv[1]
}

return opt, nil
}
12 changes: 12 additions & 0 deletions dialect/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package dialect_test

import (
"database/sql"
"os"
"testing"

"github.com/issue9/assert/v4"
Expand Down Expand Up @@ -317,3 +318,14 @@ func TestPostgres_TypesDefault(t *testing.T) {
testTypesDefault(t)
})
}

func TestPostgres_Backup(t *testing.T) {
a := assert.New(t, false)
suite := test.NewSuite(a, "", test.Postgres)
suite.Run(func(d *test.Driver) {
path := "./testdata/postgres.sql"
d.Assertion.NotError(d.DB.Backup(path))
d.Assertion.FileExists(path)
d.Assertion.NotError(os.Remove(path))
})
}
15 changes: 15 additions & 0 deletions dialect/sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ import (
"database/sql/driver"
"errors"
"fmt"
"os"
"strconv"
"strings"
"time"

"github.com/issue9/orm/v6/core"
Expand Down Expand Up @@ -343,3 +345,16 @@ func (s *sqlite3) formatSQL(v any) (f string, err error) {

return fmt.Sprint(v), nil
}

func (s *sqlite3) Backup(dsn, dest string) error {
if index := strings.IndexByte(dsn, '?'); index >= 0 {
dsn = dsn[:index]
}

data,err :=os.ReadFile(dsn)
if err!=nil{
return err
}

return os.WriteFile(dest, data , os.ModePerm)
}
12 changes: 12 additions & 0 deletions dialect/sqlite3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package dialect_test

import (
"os"
"testing"

"github.com/issue9/assert/v4"
Expand Down Expand Up @@ -328,3 +329,14 @@ func TestSqlite3_TypesDefault(t *testing.T) {
testTypesDefault(t)
})
}

func TestSqlite3_Backup(t *testing.T) {
a := assert.New(t, false)
suite := test.NewSuite(a, "", test.Sqlite3)
suite.Run(func(d *test.Driver) {
path := "./testdata/sqlite3.db"
d.Assertion.NotError(d.DB.Backup(path))
d.Assertion.FileExists(path)
d.Assertion.NotError(os.Remove(path))
})
}
2 changes: 2 additions & 0 deletions dialect/testdata/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
*.db
*.sql
Empty file added dialect/testdata/.gitkeep
Empty file.
8 changes: 4 additions & 4 deletions internal/test/test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ package test

import (
"os"
"slices"

"github.com/issue9/assert/v4"
"github.com/issue9/sliceutil"

"github.com/issue9/orm/v6"
"github.com/issue9/orm/v6/core"
Expand Down Expand Up @@ -92,16 +92,16 @@ func NewSuite(a *assert.Assertion, tablePrefix string, dialect ...core.Dialect)
s := &Suite{a: a}
a.TB().Cleanup(func() { s.close() })

fs := flags
for _, c := range cases {
name := c.dialect.Name()
driver := c.dialect.DriverName()

if len(dialect) > 0 && sliceutil.Count(dialect, func(i core.Dialect, _ int) bool { return i.Name() == name && i.DriverName() == driver }) <= 0 {
if len(dialect) > 0 && slices.IndexFunc(dialect, func(i core.Dialect) bool { return i.Name() == name && i.DriverName() == driver }) < 0 {
continue
}

fs := flags
if len(fs) > 0 && sliceutil.Count(fs, func(i *flagVar, _ int) bool { return i.Name == name && i.DriverName == driver }) <= 0 {
if len(fs) > 0 && slices.IndexFunc(fs, func(i *flagVar) bool { return i.Name == name && i.DriverName == driver }) < 0 {
continue
}

Expand Down

0 comments on commit 89270e3

Please sign in to comment.