Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
jkroepke committed Nov 19, 2024
1 parent 8c1fb30 commit a0748d0
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 74 deletions.
1 change: 1 addition & 0 deletions .golangci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ linters:
- funlen
- execinquery
- mnd
- exportloopref

issues:
exclude-rules:
Expand Down
211 changes: 139 additions & 72 deletions internal/smtp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"mime"
"mime/multipart"
"net"
Expand All @@ -27,21 +28,22 @@ type CallBackFn func(mail *MailMessage) error

// Server represents a basic SMTP server.
type Server struct {
Address string
listener net.Listener
CallBackFn CallBackFn
address string
callBackFn CallBackFn

wg sync.WaitGroup
done chan struct{}
listener net.Listener
logger *slog.Logger
wg sync.WaitGroup
done chan struct{}
}

// NewServer creates a new SMTP server instance.
func NewServer(address string, callback CallBackFn) *Server {
func NewServer(address string, logger *slog.Logger, callback CallBackFn) *Server {
return &Server{
Address: address,
CallBackFn: callback,

done: make(chan struct{}, 1),
address: address,
callBackFn: callback,
done: make(chan struct{}, 1),
logger: logger,
}
}

Expand All @@ -52,7 +54,7 @@ type Session struct {

// Start starts the SMTP server and handles incoming connections.
func (s *Server) Start() error {
listener, err := net.Listen("tcp", s.Address)
listener, err := net.Listen("tcp", s.address)
if err != nil {
return fmt.Errorf("error starting SMTP server: %w", err)
}
Expand All @@ -76,7 +78,9 @@ func (s *Server) Start() error {
go func() {
defer s.wg.Done()

s.handleConnection(conn)
if err := s.handleConnection(conn); err != nil {
s.logger.Error("error handling connection", slog.Any("err", err))
}
}()
}
}
Expand All @@ -103,24 +107,36 @@ func (s *Server) Shutdown() error {
}

// handleConnection processes an SMTP client connection.
func (s *Server) handleConnection(conn net.Conn) {
func (s *Server) handleConnection(conn net.Conn) error {
var (
err error
from string
to string
)

defer conn.Close()
defer func(conn net.Conn) {
_ = conn.Close()
}(conn)

reader := bufio.NewReader(conn)
writer := bufio.NewWriter(conn)

conn.SetDeadline(time.Now().Add(30 * time.Second))
writer.WriteString("220 Welcome to the SMTP server\r\n")
writer.Flush()
if err = conn.SetDeadline(time.Now().Add(30 * time.Second)); err != nil {
return fmt.Errorf("error setting connection deadline: %w", err)
}

if _, err = writer.WriteString("220 Welcome to the SMTP server\r\n"); err != nil {
return fmt.Errorf("error writing welcome message: %w", err)
}

if err = writer.Flush(); err != nil {
return fmt.Errorf("error flushing writer: %w", err)
}

for {
line, err := reader.ReadString('\n')
if err != nil {
return
return fmt.Errorf("error reading command: %w", err)
}
line = strings.TrimSpace(line)

Expand All @@ -132,44 +148,73 @@ func (s *Server) handleConnection(conn net.Conn) {
case strings.HasPrefix(line, "MAIL FROM"):
// Handle MAIL FROM command
if from, err = s.parseAddress(line); err != nil {
_, _ = writer.WriteString(fmt.Sprintf("550 Error: %v\r\n", err))
_ = writer.Flush()
if _, err = writer.WriteString(fmt.Sprintf("550 Error: %v\r\n", err)); err != nil {
return fmt.Errorf("error writing error message: %w", err)
}

if err := writer.Flush(); err != nil {
return fmt.Errorf("error flushing writer: %w", err)
}

return
return nil
}
_, _ = writer.WriteString("250 OK\r\n")
if err := writer.Flush(); err != nil {
return

if _, err = writer.WriteString("250 OK\r\n"); err != nil {
return fmt.Errorf("error writing OK response: %w", err)
}

if err = writer.Flush(); err != nil {
return fmt.Errorf("error flushing writer: %w", err)
}
case strings.HasPrefix(line, "RCPT TO"):
if to, err = s.parseAddress(line); err != nil {
_, _ = writer.WriteString(fmt.Sprintf("550 Error: %v\r\n", err))
_ = writer.Flush()
if _, err = writer.WriteString(fmt.Sprintf("550 Error: %v\r\n", err)); err != nil {
return fmt.Errorf("error writing error message: %w", err)
}

if err = writer.Flush(); err != nil {
return fmt.Errorf("error flushing writer: %w", err)
}

return
return nil
}
_, _ = writer.WriteString("250 OK\r\n")
if err := writer.Flush(); err != nil {
return
if err = writer.Flush(); err != nil {
return fmt.Errorf("error flushing writer: %w", err)
}
case strings.HasPrefix(line, "DATA"):
_, _ = writer.WriteString("354 Start mail input; end with <CRLF>.<CRLF>\r\n")
if err := writer.Flush(); err != nil {
return
if _, err = writer.WriteString("354 Start mail input; end with <CRLF>.<CRLF>\r\n"); err != nil {
return fmt.Errorf("error flushing writer: %w", err)
}

if err = writer.Flush(); err != nil {
return fmt.Errorf("error flushing writer: %w", err)
}

mailData := collectMailData(reader)
if mailData == "" {
_, _ = writer.WriteString("550 Error reading mail data\r\n")
_ = writer.Flush()
return
if _, err = writer.WriteString("550 Error reading mail data\r\n"); err != nil {
return fmt.Errorf("error writing error message: %w", err)
}

if err = writer.Flush(); err != nil {
return fmt.Errorf("error flushing writer: %w", err)
}

return nil
}

msg, err := parseMailData(mailData)
if err != nil {
_, _ = writer.WriteString(fmt.Sprintf("550 Error processing mail: %v\r\n", err))
_ = writer.Flush()
return
if _, err = writer.WriteString(fmt.Sprintf("550 Error processing mail: %v\r\n", err)); err != nil {
return fmt.Errorf("error writing error message: %w", err)
}

if err = writer.Flush(); err != nil {
return fmt.Errorf("error flushing writer: %w", err)
}

return nil
}

if msg.From != from {
Expand All @@ -181,26 +226,33 @@ func (s *Server) handleConnection(conn net.Conn) {
}

// Invoke the callback function
if err := s.CallBackFn(msg); err != nil {
if err := s.callBackFn(msg); err != nil {
_, _ = writer.WriteString(fmt.Sprintf("550 Error processing mail: %v\r\n", err))
_ = writer.Flush()
return
if err = writer.Flush(); err != nil {
return fmt.Errorf("error flushing writer: %w", err)
}

return nil
}

_, _ = writer.WriteString("250 OK\r\n")
if err := writer.Flush(); err != nil {
return
if err = writer.Flush(); err != nil {
return fmt.Errorf("error flushing writer: %w", err)
}

case strings.HasPrefix(line, "QUIT"):
_, _ = writer.WriteString("221 Bye\r\n")
_ = writer.Flush()
return // Close connection after QUIT command
if _, err = writer.WriteString("221 Bye\r\n"); err != nil {
return fmt.Errorf("error writing QUIT response: %w", err)
}

if err = writer.Flush(); err != nil {
return fmt.Errorf("error flushing writer: %w", err)
}

return nil // Close connection after QUIT command
case strings.HasPrefix(line, "NOOP"):
_, _ = writer.WriteString("250 OK\r\n")
if err := writer.Flush(); err != nil {
return
if err = writer.Flush(); err != nil {
return fmt.Errorf("error flushing writer: %w", err)
}

case line == ".":
Expand All @@ -209,8 +261,8 @@ func (s *Server) handleConnection(conn net.Conn) {

default:
_, _ = writer.WriteString("250 OK\r\n")
if err := writer.Flush(); err != nil {
return
if err = writer.Flush(); err != nil {
return fmt.Errorf("error flushing writer: %w", err)
}
}
}
Expand Down Expand Up @@ -306,26 +358,40 @@ func processMultipartMessage(bodyReader io.Reader, boundary string, mailMessage
multipartReader := multipart.NewReader(bodyReader, boundary)

for {
part, err := multipartReader.NextPart()
if errors.Is(err, io.EOF) {
break
}
if err != nil {
return fmt.Errorf("error reading multipart message: %w", err)
}
defer part.Close()
if err := func() error {
part, err := multipartReader.NextPart()
if errors.Is(err, io.EOF) {
return io.EOF
}

// Process each part
partContentType := part.Header.Get("Content-Type")
partData, err := io.ReadAll(part)
if err != nil {
return fmt.Errorf("error reading part: %w", err)
}
if err != nil {
return fmt.Errorf("error reading multipart message: %w", err)
}

defer func(part *multipart.Part) {
_ = part.Close()
}(part)

// Process each part
partContentType := part.Header.Get("Content-Type")
partData, err := io.ReadAll(part)
if err != nil {
return fmt.Errorf("error reading part: %w", err)
}

if strings.HasPrefix(partContentType, "text/plain") {
mailMessage.PlainText = strings.TrimSpace(string(partData))
} else if strings.HasPrefix(partContentType, "text/html") {
mailMessage.HTMLText = strings.TrimSpace(string(partData))
}

if strings.HasPrefix(partContentType, "text/plain") {
mailMessage.PlainText = strings.TrimSpace(string(partData))
} else if strings.HasPrefix(partContentType, "text/html") {
mailMessage.HTMLText = strings.TrimSpace(string(partData))
return nil
}(); err != nil {
if errors.Is(err, io.EOF) {
break
}

return err
}
}

Expand All @@ -340,11 +406,12 @@ func (s *Server) handleEHLO(writer *bufio.Writer) {
}

// Send the EHLO response
writer.WriteString("250-Hello\r\n") // "250" is the response code for a successful command
_, _ = writer.WriteString("250-Hello\r\n") // "250" is the response code for a successful command
for _, ext := range extensions {
writer.WriteString(ext + "\r\n")
_, _ = writer.WriteString(ext + "\r\n")
}
writer.WriteString("250 OK\r\n") // End of EHLO response

_, _ = writer.WriteString("250 OK\r\n") // End of EHLO response
_ = writer.Flush()
}

Expand Down
6 changes: 5 additions & 1 deletion internal/smtp/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package smtp_test

import (
"fmt"
"log/slog"
"net/smtp"
"os"
"strings"
"testing"

Expand Down Expand Up @@ -79,8 +81,10 @@ Your Name`,
} {
t.Run(test.name, func(t *testing.T) {

logger := slog.New(slog.NewTextHandler(os.Stderr, nil))

// Create a new server
s := smtpserver.NewServer("localhost:1515", func(mail *smtpserver.MailMessage) error {
s := smtpserver.NewServer("localhost:1515", logger, func(mail *smtpserver.MailMessage) error {
assert.Equal(t, test.expectedPlainText, strings.ReplaceAll(mail.PlainText, "\r\n", "\n"))
assert.Equal(t, test.expectedHTMLText, strings.ReplaceAll(mail.HTMLText, "\r\n", "\n"))
assert.Equal(t, test.subject, mail.Subject)
Expand Down
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func run(logger *slog.Logger) error {

go func() {
if err := server.Start(); err != nil {
fmt.Printf("Error: %v\n", err)
errCh <- fmt.Errorf("failed to start SMTP server: %w", err)
}

close(errCh)
Expand Down

0 comments on commit a0748d0

Please sign in to comment.