diff --git a/cmd/step/main.go b/cmd/step/main.go index daeb4a5fe..473c6867f 100644 --- a/cmd/step/main.go +++ b/cmd/step/main.go @@ -19,6 +19,7 @@ import ( "github.com/smallstep/cli/usage" // Enabled commands + _ "github.com/smallstep/cli/command/base64" _ "github.com/smallstep/cli/command/ca" _ "github.com/smallstep/cli/command/certificate" _ "github.com/smallstep/cli/command/crypto" diff --git a/command/base64/base64.go b/command/base64/base64.go new file mode 100644 index 000000000..d5278bd79 --- /dev/null +++ b/command/base64/base64.go @@ -0,0 +1,138 @@ +package base64 + +import ( + "bytes" + "encoding/base64" + "fmt" + "os" + "strings" + + "github.com/pkg/errors" + "github.com/smallstep/cli/command" + "github.com/smallstep/cli/utils" + "github.com/urfave/cli" +) + +func init() { + cmd := cli.Command{ + Name: "base64", + Action: command.ActionFunc(base64Action), + Usage: "encodes and decodes using base64 representation", + UsageText: `**step base64** [**-d**|**--decode**] [**-r**|**--raw**] [**-u**|**--url**]`, + Description: `**step base64** implements base64 encoding as specified by RFC 4648. + +## Examples + +Encode to base64 using the standard encoding: +''' +$ echo -n This is the string to encode | step base64 +VGhpcyBpcyB0aGUgc3RyaW5nIHRvIGVuY29kZQ== +$ step base64 This is the string to encode +VGhpcyBpcyB0aGUgc3RyaW5nIHRvIGVuY29kZQ== +''' + +Decode a base64 encoded string: +''' +$ echo VGhpcyBpcyB0aGUgc3RyaW5nIHRvIGVuY29kZQ== | step base64 -d +This is the string to encode +''' + +Encode to base64 without padding: +''' +$ echo -n This is the string to encode | step base64 -r +VGhpcyBpcyB0aGUgc3RyaW5nIHRvIGVuY29kZQ +$ step base64 -r This is the string to encode +VGhpcyBpcyB0aGUgc3RyaW5nIHRvIGVuY29kZQ +''' + +Encode to base64 using the url encoding: +''' +$ echo 'abc123$%^&*()_+-=~' | step base64 -u +YWJjMTIzJCVeJiooKV8rLT1-Cg== +''' + +Decode an url encoded base64 string. The encoding type can be enforced +using the '-u' or '-r' flags, but it will be autodetected if they are not +passed: +''' +$ echo YWJjMTIzJCVeJiooKV8rLT1-Cg== | step base64 -d +abc123$%^&*()_+-=~ +$ echo YWJjMTIzJCVeJiooKV8rLT1-Cg== | step base64 -d -u +abc123$%^&*()_+-=~ +'''`, + Flags: []cli.Flag{ + cli.BoolFlag{ + Name: "d,decode", + Usage: "decode base64 input", + }, + cli.BoolFlag{ + Name: "r,raw", + Usage: "use the unpadded base64 encoding", + }, + cli.BoolFlag{ + Name: "u,url", + Usage: "use the encoding format typically used in URLs and file names", + }, + }, + } + + command.Register(cmd) +} + +func base64Action(ctx *cli.Context) error { + var err error + var data []byte + isDecode := ctx.Bool("decode") + + if ctx.NArg() > 0 { + data = []byte(strings.Join(ctx.Args(), " ")) + } else { + var prompt string + if isDecode { + prompt = "Please enter text to decode" + } else { + prompt = "Please enter text to encode" + } + + if data, err = utils.ReadInput(prompt); err != nil { + return err + } + } + + enc := getEncoder(ctx, data) + if isDecode { + b, err := enc.DecodeString(string(data)) + if err != nil { + return errors.Wrap(err, "error decoding input") + } + os.Stdout.Write(b) + } else { + fmt.Println(enc.EncodeToString(data)) + } + + return nil +} + +func getEncoder(ctx *cli.Context, data []byte) *base64.Encoding { + raw := ctx.Bool("raw") + url := ctx.Bool("url") + isDecode := ctx.Bool("decode") + + // Detect encoding + if isDecode && !ctx.IsSet("raw") && !ctx.IsSet("url") { + raw = !bytes.HasSuffix(bytes.TrimSpace(data), []byte("=")) + url = bytes.Contains(data, []byte("-")) || bytes.Contains(data, []byte("_")) + } + + if raw { + if url { + return base64.RawURLEncoding + } + return base64.RawStdEncoding + } + if url { + return base64.URLEncoding + } + + return base64.StdEncoding +} diff --git a/command/crypto/jwt/sign.go b/command/crypto/jwt/sign.go index 7726b44fa..74b75438f 100644 --- a/command/crypto/jwt/sign.go +++ b/command/crypto/jwt/sign.go @@ -372,7 +372,7 @@ func readPayload(filename string) (interface{}, error) { if err != nil { return nil, errors.Wrap(err, "error reading data") } - if st.Size() == 0 { + if st.Size() == 0 && st.Mode()&os.ModeNamedPipe == 0 { return make(map[string]interface{}), nil } r = os.Stdin diff --git a/utils/read.go b/utils/read.go index 4e56b6022..95414e52d 100644 --- a/utils/read.go +++ b/utils/read.go @@ -18,6 +18,9 @@ import ( // indicates STDIN as a file to be read. const stdinFilename = "-" +// stdin points to os.Stdin. +var stdin = os.Stdin + // FileExists is a wrapper on os.Stat that returns false if os.Stat returns an // error, it returns true otherwise. This method does not care if os.Stat // returns any other kind of errors. @@ -69,26 +72,24 @@ func ReadStringPasswordFromFile(filename string) (string, error) { // ReadInput from stdin if something is detected or ask the user for an input // using the given prompt. func ReadInput(prompt string) ([]byte, error) { - st, err := os.Stdin.Stat() + st, err := stdin.Stat() if err != nil { return nil, errors.Wrap(err, "error reading data") } - if st.Size() > 0 { - return ReadAll(os.Stdin) + if st.Size() == 0 && st.Mode()&os.ModeNamedPipe == 0 { + return ui.PromptPassword(prompt) } - return ui.PromptPassword(prompt) + return ReadAll(stdin) } -var _osStdin = os.Stdin - // ReadFile returns the contents of the file identified by name. It reads from // STDIN if name is a hyphen ("-"). func ReadFile(name string) (b []byte, err error) { if name == stdinFilename { name = "/dev/stdin" - b, err = ioutil.ReadAll(_osStdin) + b, err = ioutil.ReadAll(stdin) } else { b, err = ioutil.ReadFile(name) } diff --git a/utils/read_test.go b/utils/read_test.go index d04481427..b80cb8037 100644 --- a/utils/read_test.go +++ b/utils/read_test.go @@ -2,19 +2,42 @@ package utils import ( "bytes" + "fmt" "io" "io/ioutil" "os" + "reflect" "testing" "github.com/stretchr/testify/require" ) +type mockReader struct { + n int + err error +} + +func (r *mockReader) Read(p []byte) (int, error) { + return r.n, r.err +} + // Helper function for setting os.Stdin for mocking in tests. func setStdin(new *os.File) (cleanup func()) { - old := _osStdin - _osStdin = new - return func() { _osStdin = old } + old := stdin + stdin = new + return func() { stdin = old } +} + +// Returns a temp file and a cleanup function to delete it. +func newFile(t *testing.T, data []byte) (file *os.File, cleanup func()) { + f, err := ioutil.TempFile("" /* dir */, "utils-read-test") + require.NoError(t, err) + // write to temp file and reset read cursor to beginning of file + _, err = f.Write(data) + require.NoError(t, err) + _, err = f.Seek(0, io.SeekStart) + require.NoError(t, err) + return f, func() { os.Remove(f.Name()) } } func TestFileExists(t *testing.T) { @@ -43,6 +66,66 @@ func TestFileExists(t *testing.T) { } } +func TestReadAll(t *testing.T) { + content := []byte("read all this") + + type args struct { + r io.Reader + } + tests := []struct { + name string + args args + want []byte + wantErr bool + }{ + {"ok", args{bytes.NewReader(content)}, content, false}, + {"fail", args{&mockReader{err: fmt.Errorf("this is an error")}}, []byte{}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ReadAll(tt.args.r) + if (err != nil) != tt.wantErr { + t.Errorf("ReadAll() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("ReadAll() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestReadString(t *testing.T) { + c1 := []byte("read all this") + c2 := []byte("read all this\n and all that") + + type args struct { + r io.Reader + } + tests := []struct { + name string + args args + want string + wantErr bool + }{ + {"ok", args{bytes.NewReader(c1)}, "read all this", false}, + {"ok with new line", args{bytes.NewReader(c2)}, "read all this", false}, + {"fail", args{&mockReader{err: fmt.Errorf("this is an error")}}, "", true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ReadString(tt.args.r) + if (err != nil) != tt.wantErr { + t.Errorf("ReadString() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("ReadString() = %v, want %v", got, tt.want) + } + }) + } +} + func TestReadFile(t *testing.T) { content := []byte("my file content") f, cleanup := newFile(t, content) @@ -84,14 +167,40 @@ func TestStringReadPasswordFromFile(t *testing.T) { require.Equal(t, "my-password-on-file", s, "expected %s to equal %s", s, content) } -// Returns a temp file and a cleanup function to delete it. -func newFile(t *testing.T, data []byte) (file *os.File, cleanup func()) { - f, err := ioutil.TempFile("" /* dir */, "utils-read-test") - require.NoError(t, err) - // write to temp file and reset read cursor to beginning of file - _, err = f.Write(data) - require.NoError(t, err) - _, err = f.Seek(0, io.SeekStart) - require.NoError(t, err) - return f, func() { os.Remove(f.Name()) } +func TestReadInput(t *testing.T) { + + type args struct { + prompt string + } + tests := []struct { + name string + args args + before func() func() + want []byte + wantErr bool + }{ + {"ok", args{"Write input"}, func() func() { + content := []byte("my file content") + mockStdin, cleanup := newFile(t, content) + reset := setStdin(mockStdin) + return func() { + defer cleanup() + reset() + } + }, []byte("my file content"), false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cleanup := tt.before() + defer cleanup() + got, err := ReadInput(tt.args.prompt) + if (err != nil) != tt.wantErr { + t.Errorf("ReadInput() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("ReadInput() = %v, want %v", got, tt.want) + } + }) + } }