Skip to content

Commit

Permalink
Refactor plugin/yaml/main.go and plugin/yaml/yaml.go to support multi…
Browse files Browse the repository at this point in the history
…ple YAML config files
  • Loading branch information
tg123 committed Oct 28, 2024
1 parent 23e18f4 commit cdfd580
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 41 deletions.
6 changes: 3 additions & 3 deletions plugin/yaml/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ func main() {
Name: "yaml",
Usage: "sshpiperd yaml plugin",
Flags: []cli.Flag{
&cli.StringFlag{
&cli.StringSliceFlag{
Name: "config",
Usage: "path to yaml config file",
Usage: "path to yaml config files, can be globs as well",
Required: true,
EnvVars: []string{"SSHPIPERD_YAML_CONFIG"},
Destination: &plugin.File,
Destination: &plugin.FileGlobs,
},
&cli.BoolFlag{
Name: "no-check-perm",
Expand Down
37 changes: 19 additions & 18 deletions plugin/yaml/skel.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,11 @@ import (
)

type skelpipeWrapper struct {
plugin *plugin

pipe *yamlPipe
pipe *yamlPipe
config *piperConfig
}
type skelpipeFromWrapper struct {
plugin *plugin
config *piperConfig

from *yamlPipeFrom
to *yamlPipeTo
Expand All @@ -28,7 +27,7 @@ type skelpipePublicKeyWrapper struct {
}

type skelpipeToWrapper struct {
plugin *plugin
config *piperConfig

username string
to *yamlPipeTo
Expand All @@ -39,7 +38,7 @@ func (s *skelpipeWrapper) From() []libplugin.SkelPipeFrom {
for _, f := range s.pipe.From {

w := &skelpipeFromWrapper{
plugin: s.plugin,
config: s.config,
from: &f,
to: &s.pipe.To,
}
Expand Down Expand Up @@ -70,7 +69,7 @@ func (s *skelpipeToWrapper) IgnoreHostKey(conn libplugin.ConnMetadata) bool {
}

func (s *skelpipeToWrapper) KnownHosts(conn libplugin.ConnMetadata) ([]byte, error) {
return s.plugin.loadFileOrDecodeMany(s.to.KnownHosts, s.to.KnownHostsData, map[string]string{
return s.config.loadFileOrDecodeMany(s.to.KnownHosts, s.to.KnownHostsData, map[string]string{
"DOWNSTREAM_USER": conn.User(),
"UPSTREAM_USER": s.username,
})
Expand Down Expand Up @@ -101,7 +100,7 @@ func (s *skelpipeFromWrapper) MatchConn(conn libplugin.ConnMetadata) (libplugin.

if matched {
return &skelpipeToWrapper{
plugin: s.plugin,
config: s.config,
username: targetuser,
to: s.to,
}, nil
Expand All @@ -115,19 +114,19 @@ func (s *skelpipePasswordWrapper) TestPassword(conn libplugin.ConnMetadata, pass
}

func (s *skelpipePublicKeyWrapper) AuthorizedKeys(conn libplugin.ConnMetadata) ([]byte, error) {
return s.plugin.loadFileOrDecodeMany(s.from.AuthorizedKeys, s.from.AuthorizedKeysData, map[string]string{
return s.config.loadFileOrDecodeMany(s.from.AuthorizedKeys, s.from.AuthorizedKeysData, map[string]string{
"DOWNSTREAM_USER": conn.User(),
})
}

func (s *skelpipePublicKeyWrapper) TrustedUserCAKeys(conn libplugin.ConnMetadata) ([]byte, error) {
return s.plugin.loadFileOrDecodeMany(s.from.TrustedUserCAKeys, s.from.TrustedUserCAKeysData, map[string]string{
return s.config.loadFileOrDecodeMany(s.from.TrustedUserCAKeys, s.from.TrustedUserCAKeysData, map[string]string{
"DOWNSTREAM_USER": conn.User(),
})
}

func (s *skelpipeToWrapper) PrivateKey(conn libplugin.ConnMetadata) ([]byte, []byte, error) {
p, err := s.plugin.loadFileOrDecode(s.to.PrivateKey, s.to.PrivateKeyData, map[string]string{
p, err := s.config.loadFileOrDecode(s.to.PrivateKey, s.to.PrivateKeyData, map[string]string{
"DOWNSTREAM_USER": conn.User(),
"UPSTREAM_USER": s.username,
})
Expand All @@ -144,19 +143,21 @@ func (s *skelpipeToWrapper) OverridePassword(conn libplugin.ConnMetadata) ([]byt
}

func (p *plugin) listPipe(_ libplugin.ConnMetadata) ([]libplugin.SkelPipe, error) {
config, err := p.loadConfig()
configs, err := p.loadConfig()
if err != nil {
return nil, err
}

var pipes []libplugin.SkelPipe
for _, pipe := range config.Pipes {
wrapper := &skelpipeWrapper{
plugin: p,
pipe: &pipe,
}
pipes = append(pipes, wrapper)
for _, config := range configs {
for _, pipe := range config.Pipes {
wrapper := &skelpipeWrapper{
config: &config,
pipe: &pipe,
}
pipes = append(pipes, wrapper)

}
}

return pipes, nil
Expand Down
58 changes: 38 additions & 20 deletions plugin/yaml/yaml.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"os"
"path/filepath"

"github.com/urfave/cli/v2"
"gopkg.in/yaml.v3"
)

Expand Down Expand Up @@ -76,19 +77,20 @@ type yamlPipe struct {
type piperConfig struct {
Version string `yaml:"version"`
Pipes []yamlPipe `yaml:"pipes,flow"`

filename string
}

type plugin struct {
File string
FileGlobs cli.StringSlice
NoCheckPerm bool
}

func newYamlPlugin() *plugin {
return &plugin{}
}

func (p *plugin) checkPerm() error {
filename := p.File
func (p *plugin) checkPerm(filename string) error {
f, err := os.Open(filename)
if err != nil {
return err
Expand All @@ -111,28 +113,44 @@ func (p *plugin) checkPerm() error {
return nil
}

func (p *plugin) loadConfig() (piperConfig, error) {
var config piperConfig
func (p *plugin) loadConfig() ([]piperConfig, error) {
var allconfig []piperConfig

err := p.checkPerm()
if err != nil {
return config, err
}
for _, fg := range p.FileGlobs.Value() {
files, err := filepath.Glob(fg)
if err != nil {
return nil, err
}

configbyte, err := os.ReadFile(p.File)
if err != nil {
return config, err
}
for _, file := range files {

err = yaml.Unmarshal(configbyte, &config)
if err != nil {
return config, err
if err := p.checkPerm(file); err != nil {
return nil, err
}

configbyte, err := os.ReadFile(file)
if err != nil {
return nil, err
}

var config piperConfig

err = yaml.Unmarshal(configbyte, &config)
if err != nil {
return nil, err
}

config.filename = file

allconfig = append(allconfig, config)

}
}

return config, nil
return allconfig, nil
}

func (p *plugin) loadFileOrDecode(file string, base64data string, vars map[string]string) ([]byte, error) {
func (p *piperConfig) loadFileOrDecode(file string, base64data string, vars map[string]string) ([]byte, error) {
if file != "" {

file = os.Expand(file, func(placeholderName string) string {
Expand All @@ -145,7 +163,7 @@ func (p *plugin) loadFileOrDecode(file string, base64data string, vars map[strin
})

if !filepath.IsAbs(file) {
file = filepath.Join(filepath.Dir(p.File), file)
file = filepath.Join(filepath.Dir(p.filename), file)
}

return os.ReadFile(file)
Expand All @@ -158,7 +176,7 @@ func (p *plugin) loadFileOrDecode(file string, base64data string, vars map[strin
return nil, nil
}

func (p *plugin) loadFileOrDecodeMany(files listOrString, base64data listOrString, vars map[string]string) ([]byte, error) {
func (p *piperConfig) loadFileOrDecodeMany(files listOrString, base64data listOrString, vars map[string]string) ([]byte, error) {
var byteSlices [][]byte

for _, file := range files.Combine() {
Expand Down

0 comments on commit cdfd580

Please sign in to comment.