Skip to content

Commit

Permalink
NEOS-1504, NEOS-1373, NEOS-1249: update sql dialers to include tls ce…
Browse files Browse the repository at this point in the history
…rt support (#3028)
  • Loading branch information
nickzelei authored Dec 10, 2024
1 parent 7cfc8aa commit 09b1d57
Show file tree
Hide file tree
Showing 62 changed files with 2,519 additions and 962 deletions.
755 changes: 398 additions & 357 deletions backend/gen/go/protos/mgmt/v1alpha1/connection.pb.go

Large diffs are not rendered by default.

84 changes: 1 addition & 83 deletions backend/pkg/clienttls/clienttls.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,65 +19,7 @@ type ClientTlsFileConfig struct {
ClientKey *string
}

type ClientTlsFileHandler = func(config *mgmtv1alpha1.ClientTlsConfig) (*ClientTlsFileConfig, error)

func UpsertCLientTlsFiles(config *mgmtv1alpha1.ClientTlsConfig) (*ClientTlsFileConfig, error) {
if config == nil {
return nil, errors.New("config was nil")
}

errgrp := errgroup.Group{}

filenames := GetClientTlsFileNames(config)

errgrp.Go(func() error {
if filenames.RootCert == nil {
return nil
}
_, err := os.Stat(*filenames.RootCert)
if err != nil && !os.IsNotExist(err) {
return err
} else if err != nil && os.IsNotExist(err) {
if err := os.WriteFile(*filenames.RootCert, []byte(config.GetRootCert()), 0600); err != nil {
return err
}
}
return nil
})
errgrp.Go(func() error {
if filenames.ClientCert != nil && filenames.ClientKey != nil {
_, err := os.Stat(*filenames.ClientKey)
if err != nil && !os.IsNotExist(err) {
return err
} else if err != nil && os.IsNotExist(err) {
if err := os.WriteFile(*filenames.ClientKey, []byte(config.GetClientKey()), 0600); err != nil {
return err
}
}
}
return nil
})
errgrp.Go(func() error {
if filenames.ClientCert != nil && filenames.ClientKey != nil {
_, err := os.Stat(*filenames.ClientCert)
if err != nil && !os.IsNotExist(err) {
return err
} else if err != nil && os.IsNotExist(err) {
if err := os.WriteFile(*filenames.ClientCert, []byte(config.GetClientCert()), 0600); err != nil {
return err
}
}
}
return nil
})

err := errgrp.Wait()
if err != nil {
return nil, err
}

return &filenames, nil
}
type ClientTlsFileHandler func(config *mgmtv1alpha1.ClientTlsConfig) (*ClientTlsFileConfig, error)

// Joins the client cert and key into a single file
func UpsertClientTlsFileSingleClient(config *mgmtv1alpha1.ClientTlsConfig) (*ClientTlsFileConfig, error) {
Expand Down Expand Up @@ -125,30 +67,6 @@ func UpsertClientTlsFileSingleClient(config *mgmtv1alpha1.ClientTlsConfig) (*Cli
return &filenames, nil
}

func GetClientTlsFileNames(config *mgmtv1alpha1.ClientTlsConfig) ClientTlsFileConfig {
if config == nil {
return ClientTlsFileConfig{}
}

basedir := os.TempDir()

output := ClientTlsFileConfig{}
if config.GetRootCert() != "" {
content := hashContent(config.GetRootCert())
fullpath := filepath.Join(basedir, content)
output.RootCert = &fullpath
}
if config.GetClientCert() != "" && config.GetClientKey() != "" {
certContent := hashContent(config.GetClientCert())
certpath := filepath.Join(basedir, certContent)
keyContent := hashContent(config.GetClientKey())
keypath := filepath.Join(basedir, keyContent)
output.ClientCert = &certpath
output.ClientKey = &keypath
}
return output
}

// Joins the client cert and key into a single file
func GetClientTlsFileNamesSingleClient(config *mgmtv1alpha1.ClientTlsConfig) ClientTlsFileConfig {
if config == nil {
Expand Down
23 changes: 0 additions & 23 deletions backend/pkg/dbconnect-config/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"net/url"

mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1"
"github.com/nucleuscloud/neosync/backend/pkg/clienttls"
)

const postgresScheme = "postgres"
Expand Down Expand Up @@ -55,9 +54,6 @@ func NewFromPostgresConnection(
if cc.Connection.GetSslMode() != "" {
query.Set("sslmode", cc.Connection.GetSslMode())
}
if config.PgConfig.GetClientTls() != nil {
query = setPgClientTlsQueryParams(query, config.PgConfig.GetClientTls())
}
if connectionTimeout != nil {
query.Set("connect_timeout", fmt.Sprintf("%d", *connectionTimeout))
}
Expand All @@ -79,32 +75,13 @@ func NewFromPostgresConnection(
if !query.Has("connect_timeout") && connectionTimeout != nil {
query.Set("connect_timeout", fmt.Sprintf("%d", *connectionTimeout))
}
// todo: move this out of here into the driver
if config.PgConfig.GetClientTls() != nil {
query = setPgClientTlsQueryParams(query, config.PgConfig.GetClientTls())
}
uriconfig.RawQuery = query.Encode()
return &pgConnectConfig{url: uriconfig.String(), user: getUserFromInfo(uriconfig.User)}, nil
default:
return nil, fmt.Errorf("unsupported pg connection config: %T", cc)
}
}

func setPgClientTlsQueryParams(
query url.Values,
cfg *mgmtv1alpha1.ClientTlsConfig,
) url.Values {
filenames := clienttls.GetClientTlsFileNames(cfg)
if filenames.RootCert != nil {
query.Set("sslrootcert", *filenames.RootCert)
}
if filenames.ClientCert != nil && filenames.ClientKey != nil {
query.Set("sslcert", *filenames.ClientCert)
query.Set("sslkey", *filenames.ClientKey)
}
return query
}

func getUserFromInfo(u *url.Userinfo) string {
if u == nil {
return ""
Expand Down
Loading

0 comments on commit 09b1d57

Please sign in to comment.