diff --git a/host_test.go b/host_test.go index d29a875a..320802e0 100644 --- a/host_test.go +++ b/host_test.go @@ -9,12 +9,22 @@ import ( "io/ioutil" "strings" "testing" + "time" "github.com/shazow/ssh-chat/chat/message" "github.com/shazow/ssh-chat/sshd" "golang.org/x/crypto/ssh" ) +func nextScanToken(scanner *bufio.Scanner, i int) *bufio.Scanner { + count := 0 + for count < i { + scanner.Scan() + count++ + } + return scanner +} + func stripPrompt(s string) string { pos := strings.LastIndex(s, "\033[K") if pos < 0 { @@ -107,9 +117,7 @@ func TestHostNameCollision(t *testing.T) { scanner := bufio.NewScanner(r) // Consume the initial buffer - scanner.Scan() - scanner.Scan() - scanner.Scan() + nextScanToken(scanner, 3) actual := scanner.Text() if !strings.HasPrefix(actual, "[Guest1] ") { @@ -124,6 +132,87 @@ func TestHostNameCollision(t *testing.T) { <-done } +func TestMotdCommand(t *testing.T) { + key, err := sshd.NewRandomSigner(512) + if err != nil { + t.Fatal(err) + } + + auth := NewAuth() + config := sshd.MakeAuth(auth) + config.AddHostKey(key) + + s, err := sshd.ListenSSH("localhost:0", config) + if err != nil { + t.Fatal(err) + } + defer s.Close() + host := NewHost(s, auth) + go host.Serve() + + err = sshd.ConnectShell(s.Addr().String(), "baz", func(r io.Reader, w io.WriteCloser) error { + if err != nil { + t.Error(err) + } + member, _ := host.Room.MemberById("baz") + if member == nil { + return errors.New("failed to load MemberById") + } + + scanner := bufio.NewScanner(r) + testMotd := "foobar" + host.motd = testMotd + + // Test as regular user with no parameters - expected behaviour: should print the MOTD + w.Write([]byte("/motd\r\n")) + + // Consuming buffer + nextScanToken(scanner, 3) + + actual := scanner.Text() + actual = stripPrompt(actual)[3:] + expected := "foobar" + if strings.Compare(actual, expected) != 0 { + t.Error("failed to print MOTD using /motd with no parameters", "actual:", actual, "expected:", expected) + } + + // Test as regular user - expected behaviour: should return an error + w.Write([]byte("/motd foobarbaz\r\n")) + if strings.Compare(host.motd, "foobar") != 0 { + t.Error("failed to hinder non-OPs to modify the MOTD") + } + + // Test as OP - expected behaviour: should modify the MOTD + host.Room.Ops.Add(member) + testMotd = "barfoo" + w.Write([]byte("/motd barfoo\r\n")) + + // Fix this during the code-review process + time.Sleep(time.Millisecond * 500) + + if strings.Compare(host.motd, testMotd) != 0 { + t.Error("failed to allow OPs to modify the MOTD") + } + + // Get around rate limitation + time.Sleep(time.Second * 3) + + // Test as OP - expected behaviour: should print the MOTD even if OP + w.Write([]byte("/motd\r\n")) + + nextScanToken(scanner, 8) + + actual = scanner.Text() + actual = stripPrompt(actual)[3:] + expected = "barfoo" + if strings.Compare(actual, expected) != 0 { + t.Error("failed to print MOTD using /motd with no parameters - as OP") + } + + return nil + }) +} + func TestHostWhitelist(t *testing.T) { key, err := sshd.NewRandomSigner(512) if err != nil {