From 677f1cb541a328075397fb929c339b37886d3918 Mon Sep 17 00:00:00 2001 From: Roman Sharkov Date: Fri, 14 Jun 2024 20:05:36 +0200 Subject: [PATCH] fix: Watch dirs recursively and dynamically Automatically start watching created directories. Automatically stop watching removed/renamed directories. --- config.go | 25 +- go.mod | 2 +- go.sum | 6 +- debounce.go => internal/debounce/debounce.go | 10 +- internal/watcher/watcher.go | 206 +++++++++ internal/watcher/watcher_test.go | 252 +++++++++++ main.go | 441 +++++-------------- server.go | 210 +++++++++ 8 files changed, 801 insertions(+), 351 deletions(-) rename debounce.go => internal/debounce/debounce.go (68%) create mode 100644 internal/watcher/watcher.go create mode 100644 internal/watcher/watcher_test.go create mode 100644 server.go diff --git a/config.go b/config.go index 1a26b37..2f19aca 100644 --- a/config.go +++ b/config.go @@ -4,17 +4,26 @@ import ( "encoding" "flag" "fmt" + "os" + "path" + "path/filepath" "strings" "time" "github.com/romshark/yamagiconf" ) +var config Config + type Config struct { + serverOutPath string // Initialized from os.Getwd and os.TempDir + App struct { // DirSrcRoot is the source root directory for the application server. DirSrcRoot string `yaml:"dir-src-root" validate:"dirpath,required"` + dirSrcRootAbsolute string // Initialized from DirSrcRoot + // DirCmd is the server cmd directory containing the `main` function. DirCmd string `yaml:"dir-cmd" validate:"dirpath,required"` @@ -62,11 +71,6 @@ type Config struct { } `yaml:"tls"` } -var ( - serverOutPath string - config Config -) - func mustParseConfig() { var fConfigPath string flag.StringVar(&fConfigPath, "config", "./templier.yml", "config file path") @@ -89,6 +93,17 @@ func mustParseConfig() { panic(fmt.Errorf("reading config file: %w", err)) } } + + workingDir, err := os.Getwd() + if err != nil { + panic(fmt.Errorf("getting working dir: %w", err)) + } + config.serverOutPath = path.Join(os.TempDir(), workingDir) + + config.App.dirSrcRootAbsolute, err = filepath.Abs(config.App.DirSrcRoot) + if err != nil { + panic(fmt.Errorf("getting absolute path for app.dir-src-root: %w", err)) + } } type SpaceSeparatedList []string diff --git a/go.mod b/go.mod index ffc5fdc..e3dc4b1 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/a-h/templ v0.2.707 github.com/fatih/color v1.17.0 github.com/fsnotify/fsnotify v1.7.0 - github.com/gorilla/websocket v1.5.2 + github.com/gorilla/websocket v1.5.3 github.com/romshark/yamagiconf v0.10.4 github.com/stretchr/testify v1.9.0 ) diff --git a/go.sum b/go.sum index c875076..1d28533 100644 --- a/go.sum +++ b/go.sum @@ -29,8 +29,8 @@ github.com/go-playground/validator/v10 v10.21.0/go.mod h1:dbuPbCMFw/DrkbEynArYaC github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/gorilla/websocket v1.5.2 h1:qoW6V1GT3aZxybsbC6oLnailWnB+qTMVwMreOso9XUw= -github.com/gorilla/websocket v1.5.2/go.mod h1:0n9H61RBAcf5/38py2MCYbxzPIY9rOkpvvMT24Rqs30= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= @@ -70,8 +70,6 @@ go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= -golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= -golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0= golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= diff --git a/debounce.go b/internal/debounce/debounce.go similarity index 68% rename from debounce.go rename to internal/debounce/debounce.go index 8f84a12..c868c4b 100644 --- a/debounce.go +++ b/internal/debounce/debounce.go @@ -1,4 +1,4 @@ -package main +package debounce import ( "context" @@ -6,9 +6,15 @@ import ( "time" ) -func NewDebouncedSync(duration time.Duration) ( +// NewSync creates a new concurrency-safe debouncer. +func NewSync(duration time.Duration) ( runDebouncer func(ctx context.Context), trigger func(fn func()), ) { + if duration == 0 { + // Debounce disabled, execute fn immediately. + return func(context.Context) { /*Noop*/ }, func(fn func()) { fn() } + } + var lock sync.Mutex var fn func() ticker := time.NewTicker(duration) diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go new file mode 100644 index 0000000..55b057f --- /dev/null +++ b/internal/watcher/watcher.go @@ -0,0 +1,206 @@ +package watcher + +import ( + "context" + "errors" + "fmt" + "os" + "path/filepath" + "strings" + "sync" + + "github.com/fsnotify/fsnotify" +) + +// Watcher is a recursive file watcher. +type Watcher struct { + lock sync.Mutex + closed bool + watchedDirs map[string]struct{} // dir path -> closer channel + onChange func(ctx context.Context, e fsnotify.Event) + watcher *fsnotify.Watcher +} + +// New creates a new file watcher that executes onChange for any +// remove/create/change/chmod filesystem event. +// onChange will receive the ctx that was passed to Run. +func New(onChange func(ctx context.Context, e fsnotify.Event)) (*Watcher, error) { + watcher, err := fsnotify.NewWatcher() + if err != nil { + return nil, err + } + return &Watcher{ + watchedDirs: make(map[string]struct{}), + onChange: onChange, + watcher: watcher, + }, nil +} + +var ErrClosed = errors.New("closed") + +// RangeWatchedDirs calls fn for every currently watched directory. +// Noop if the watcher is closed. +func (w *Watcher) RangeWatchedDirs(fn func(path string) (continueIter bool)) { + w.lock.Lock() + defer w.lock.Unlock() + if w.closed { + return + } + for p := range w.watchedDirs { + if !fn(p) { + return + } + } +} + +// Close stops watching everything and closes the watcher. +// Noop if the watcher is closed. +func (w *Watcher) Close() error { + w.lock.Lock() + defer w.lock.Unlock() + if w.closed { + return nil + } + w.closed = true + return w.watcher.Close() +} + +// Run runs the watcher. +// Returns ErrClosed if already closed. +func (w *Watcher) Run(ctx context.Context) error { + w.lock.Lock() + if w.closed { + w.lock.Unlock() + return ErrClosed + } + w.lock.Unlock() + + defer w.Close() + for { + select { + case <-ctx.Done(): + return ctx.Err() // Watching canceled + case e := <-w.watcher.Events: + switch e.Op { + case fsnotify.Create, fsnotify.Remove, fsnotify.Rename: + if w.isDirEvent(e) { + switch e.Op { + case fsnotify.Create: + // New sub-directory was created, start watching it. + if err := w.Add(e.Name); err != nil { + return fmt.Errorf("adding created directory: %w", err) + } + case fsnotify.Remove, fsnotify.Rename: + // Sub-directory was removed or renamed, stop watching it. + // A new create notification will readd it. + if err := w.Remove(e.Name); err != nil { + return fmt.Errorf("removing directory: %w", err) + } + } + } + case 0: + continue + } + w.onChange(ctx, e) + case err := <-w.watcher.Errors: + if err != nil { + return fmt.Errorf("watching: %w", err) + } + } + } +} + +// Add starts watching the directory and all of its subdirectories recursively. +// Returns ErrClosed if the watcher is already closed. +func (w *Watcher) Add(dir string) error { + w.lock.Lock() + defer w.lock.Unlock() + if w.closed { + return ErrClosed + } + err := forEachDir(dir, func(dir string) error { + if _, ok := w.watchedDirs[dir]; ok { + return errAlreadyWatched // Directory already watched + } + w.watchedDirs[dir] = struct{}{} + return w.watcher.Add(dir) + }) + if err == errAlreadyWatched { + return nil + } + return err +} + +var errAlreadyWatched = errors.New("directory already watched") + +// Remove stops watching the directory and all of its subdirectories recursively. +// Returns ErrClosed if the watcher is already closed. +func (w *Watcher) Remove(dir string) error { + w.lock.Lock() + defer w.lock.Unlock() + if w.closed { + return ErrClosed + } + + if _, ok := w.watchedDirs[dir]; !ok { + return nil + } + delete(w.watchedDirs, dir) + if err := w.removeWatcher(dir); err != nil { + return err + } + + // Stop all sub-directory watchers + for p := range w.watchedDirs { + if strings.HasPrefix(p, dir) { + delete(w.watchedDirs, p) + if err := w.removeWatcher(dir); err != nil { + return err + } + } + } + + return nil +} + +// removeWatcher ignores ErrNonExistentWatch when removing a watcher. +func (w *Watcher) removeWatcher(dir string) error { + if err := w.watcher.Remove(dir); err != nil { + if !errors.Is(err, fsnotify.ErrNonExistentWatch) { + return err + } + } + return nil +} + +func (w *Watcher) isDirEvent(e fsnotify.Event) bool { + switch e.Op { + case fsnotify.Create, fsnotify.Write, fsnotify.Chmod: + fileInfo, err := os.Stat(e.Name) + if err != nil { + return false + } + return fileInfo.IsDir() + } + _, ok := w.watchedDirs[e.Name] + return ok +} + +// forEachDir executes fn for every subdirectory of pathDir, +// including pathDir itself, recursively. +func forEachDir(pathDir string, fn func(dir string) error) error { + // Use filepath.Walk to traverse directories + err := filepath.Walk(pathDir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err // Stop walking the directory tree. + } + if !info.IsDir() { + return nil // Continue walking. + } + if err = fn(path); err != nil { + return err + } + return nil + }) + return err +} diff --git a/internal/watcher/watcher_test.go b/internal/watcher/watcher_test.go new file mode 100644 index 0000000..21b732f --- /dev/null +++ b/internal/watcher/watcher_test.go @@ -0,0 +1,252 @@ +package watcher_test + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/romshark/templier/internal/watcher" + + "github.com/fsnotify/fsnotify" + "github.com/stretchr/testify/require" +) + +func TestWatcher(t *testing.T) { + notifications := make(chan fsnotify.Event, 10) // Expect 10 events + w, err := watcher.New(func(ctx context.Context, e fsnotify.Event) { + notifications <- e + }) + require.NoError(t, err) + defer func() { require.NoError(t, w.Close()) }() + + go func() { require.NoError(t, w.Run(context.Background())) }() + + base := t.TempDir() + + // Create a sub-directory that exists even before Run + MustMkdir(t, base, "existing-subdir") + + require.NoError(t, w.Add(base)) + + ExpectWatched(t, w, []string{ + base, + filepath.Join(base, "existing-subdir"), + }) + + events := make([]fsnotify.Event, cap(notifications)) + + // After every operation, wait for fsnotify to trigger, + // otherwise events might get lost. + MustCreateFile(t, base, "newfile") + events[0] = <-notifications + + MustMkdir(t, base, "newdir") + events[1] = <-notifications + ExpectWatched(t, w, []string{ + base, + filepath.Join(base, "existing-subdir"), + filepath.Join(base, "newdir"), + }) + + MustMkdir(t, base, "newdir", "subdir") + events[2] = <-notifications + ExpectWatched(t, w, []string{ + base, + filepath.Join(base, "existing-subdir"), + filepath.Join(base, "newdir"), + filepath.Join(base, "newdir", "subdir"), + }) + + MustCreateFile(t, base, "newdir", "subdir", "subfile") + events[3] = <-notifications + + MustCreateFile(t, base, "newdir", "subdir", "subfile2") + events[4] = <-notifications + + MustCreateFile(t, base, "existing-subdir", "subfile3") + events[5] = <-notifications + + MustRemove(t, base, "existing-subdir", "subfile3") + events[6] = <-notifications + + MustRemove(t, base, "existing-subdir") + events[7] = <-notifications + + // Renaming will generate two events, first the renaming event and later + // the event of creation of a new directory. + MustRename(t, filepath.Join(base, "newdir"), filepath.Join(base, "newname")) + events[8] = <-notifications + events[9] = <-notifications + ExpectWatched(t, w, []string{ + base, + filepath.Join(base, "newname"), + filepath.Join(base, "newname/subdir"), + }) + + require.Len(t, notifications, 0, + "notifications channel buffer must now be empty") + + // Event 0 + require.Contains(t, events, fsnotify.Event{ + Op: fsnotify.Create, + Name: filepath.Join(base, "newfile"), + }) + // Event 1 + require.Contains(t, events, fsnotify.Event{ + Op: fsnotify.Create, + Name: filepath.Join(base, "newdir"), + }) + // Event 2 + require.Contains(t, events, fsnotify.Event{ + Op: fsnotify.Create, + Name: filepath.Join(base, "newdir", "subdir"), + }) + // Event 3 + require.Contains(t, events, fsnotify.Event{ + Op: fsnotify.Create, + Name: filepath.Join(base, "newdir", "subdir", "subfile"), + }) + // Event 4 + require.Contains(t, events, fsnotify.Event{ + Op: fsnotify.Create, + Name: filepath.Join(base, "newdir", "subdir", "subfile2"), + }) + // Event 5 + require.Contains(t, events, fsnotify.Event{ + Op: fsnotify.Create, + Name: filepath.Join(base, "existing-subdir", "subfile3"), + }) + // Event 6 + require.Contains(t, events, fsnotify.Event{ + Op: fsnotify.Remove, + Name: filepath.Join(base, "existing-subdir", "subfile3"), + }) + // Event 7 + require.Contains(t, events, fsnotify.Event{ + Op: fsnotify.Remove, + Name: filepath.Join(base, "existing-subdir"), + }) + // Event 8 + require.Contains(t, events, fsnotify.Event{ + Op: fsnotify.Rename, + Name: filepath.Join(base, "newdir"), + }) + // Event 9 + require.Contains(t, events, fsnotify.Event{ + Op: fsnotify.Create, + Name: filepath.Join(base, "newname"), + }) +} + +func TestWatcherClosed(t *testing.T) { + w, err := watcher.New(func(ctx context.Context, e fsnotify.Event) {}) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + go func() { require.ErrorIs(t, w.Run(ctx), context.Canceled) }() + require.NoError(t, w.Add(t.TempDir())) + cancel() // Close + + tempDir := t.TempDir() + + require.ErrorIs(t, w.Add(filepath.Join(tempDir, "new")), watcher.ErrClosed) + require.ErrorIs(t, w.Remove(filepath.Join(tempDir, "new")), watcher.ErrClosed) + require.ErrorIs(t, w.Run(context.Background()), watcher.ErrClosed) + + ExpectWatched(t, w, []string{}) +} + +func TestWatcherAdd_AlreadyWatched(t *testing.T) { + w, err := watcher.New(func(ctx context.Context, e fsnotify.Event) {}) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { require.ErrorIs(t, w.Run(ctx), context.Canceled) }() + + tempDir := t.TempDir() + ExpectWatched(t, w, []string{}) + require.NoError(t, w.Add(tempDir)) + ExpectWatched(t, w, []string{tempDir}) + require.NoError(t, w.Add(tempDir)) // Add again + ExpectWatched(t, w, []string{tempDir}) +} + +func TestWatcherRemove(t *testing.T) { + w, err := watcher.New(func(ctx context.Context, e fsnotify.Event) {}) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { require.ErrorIs(t, w.Run(ctx), context.Canceled) }() + + base := t.TempDir() + MustMkdir(t, base, "sub") + MustMkdir(t, base, "sub", "subsub") + MustMkdir(t, base, "sub", "subsub2") + MustMkdir(t, base, "sub", "subsub2", "subsubsub") + MustMkdir(t, base, "sub2") + + ExpectWatched(t, w, []string{}) + + require.NoError(t, w.Add(base)) + ExpectWatched(t, w, []string{ + base, + filepath.Join(base, "sub"), + filepath.Join(base, "sub", "subsub"), + filepath.Join(base, "sub", "subsub2"), + filepath.Join(base, "sub", "subsub2", "subsubsub"), + filepath.Join(base, "sub2"), + }) + + require.NoError(t, w.Remove(filepath.Join(base, "sub", "subsub2", "subsubsub"))) + ExpectWatched(t, w, []string{ + base, + filepath.Join(base, "sub"), + filepath.Join(base, "sub", "subsub"), + filepath.Join(base, "sub", "subsub2"), + filepath.Join(base, "sub2"), + }) + + require.NoError(t, w.Remove(base)) + ExpectWatched(t, w, []string{}) +} + +func ExpectWatched(t *testing.T, w *watcher.Watcher, expect []string) { + t.Helper() + actual := []string{} + w.RangeWatchedDirs(func(path string) (continueIter bool) { + actual = append(actual, path) + return true + }) + require.Len(t, actual, len(expect), "actual: %v", actual) + for _, exp := range expect { + require.Contains(t, actual, exp) + } +} + +func MustMkdir(t *testing.T, pathParts ...string) { + t.Helper() + err := os.Mkdir(filepath.Join(pathParts...), 0o777) + require.NoError(t, err) +} + +func MustCreateFile(t *testing.T, pathParts ...string) *os.File { + t.Helper() + f, err := os.Create(filepath.Join(pathParts...)) + require.NoError(t, err) + return f +} + +func MustRemove(t *testing.T, pathParts ...string) { + t.Helper() + err := os.Remove(filepath.Join(pathParts...)) + require.NoError(t, err) +} + +func MustRename(t *testing.T, from, to string) { + t.Helper() + err := os.Rename(from, to) + require.NoError(t, err) +} diff --git a/main.go b/main.go index b1098b4..22d6f84 100644 --- a/main.go +++ b/main.go @@ -1,31 +1,28 @@ package main import ( - "bytes" "context" "crypto/tls" _ "embed" "errors" "fmt" - "io" - "net" "net/http" "net/url" "os" "os/exec" "os/signal" "path" - "path/filepath" "strconv" "strings" "sync" "sync/atomic" - "syscall" "time" + "github.com/romshark/templier/internal/debounce" + "github.com/romshark/templier/internal/watcher" + "github.com/fatih/color" "github.com/fsnotify/fsnotify" - "github.com/gorilla/websocket" ) const ( @@ -85,20 +82,6 @@ var ( func main() { mustParseConfig() - workingDir, err := os.Getwd() - if err != nil { - panic(fmt.Errorf("getting working dir: %w", err)) - } - serverOutPath = path.Join(os.TempDir(), workingDir) - - templierBaseURL := url.URL{ - Scheme: "http", - Host: config.TemplierHost, - } - if config.TLS != nil { - templierBaseURL.Scheme = "https" - } - ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) defer cancel() @@ -128,22 +111,52 @@ func main() { go runTemplierServer() go runAppLauncher() - // Watch all files except "*.templ" and the .git subdir - mustGoWatchDir( - ctx, config.App.DirSrcRoot, - config.Debounce.Go, onFileChangedRebuildServer, - ".git", - ) - // Watch all .templ and regenerate templates - mustGoWatchDir( - ctx, config.App.DirSrcRoot, - config.Debounce.Templ, onTemplFileChangedGenTemplates, - ) + debouncerTempl, debouncedTempl := debounce.NewSync(config.Debounce.Templ) + go debouncerTempl(ctx) + + debouncerGo, debouncedGo := debounce.NewSync(config.Debounce.Go) + go debouncerGo(ctx) + + watcher, err := watcher.New(func(ctx context.Context, e fsnotify.Event) { + debounce := debouncedGo + if isTemplFile(e.Name) { + // Use different debouncer for .templ files + debounce = debouncedTempl + } + debounce(func() { onFileChanged(ctx, e) }) + }) + if err != nil { + fmt.Printf("🤖 ERR: initializing file watcher: %v", err) + os.Exit(1) + } - fmt.Print("🤖 templier ") - fGreen.Print("started") - fmt.Print(" on ") - fBlueUnderline.Println(templierBaseURL.String()) + go func() { + if err := watcher.Run(ctx); err != nil && !errors.Is(err, context.Canceled) { + panic(fmt.Errorf("running file watcher:%w", err)) + } + }() + + fmt.Println("ADD", config.App.dirSrcRootAbsolute) + if err := watcher.Add(config.App.dirSrcRootAbsolute); err != nil { + fmt.Printf("🤖 ERR: setting up file watcher for app.dir-src-root(%q): %v", + config.App.dirSrcRootAbsolute, err) + os.Exit(1) + } + + { + templierBaseURL := url.URL{ + Scheme: "http", + Host: config.TemplierHost, + } + if config.TLS != nil { + templierBaseURL.Scheme = "https" + } + + fmt.Print("🤖 templier ") + fGreen.Print("started") + fmt.Print(" on ") + fBlueUnderline.Println(templierBaseURL.String()) + } <-ctx.Done() chStopServer <- struct{}{} @@ -306,135 +319,67 @@ AWAIT_COMMAND: // fileChangedLock prevents more than one rebuilder goroutine at a time. var fileChangedLock sync.Mutex -func onFileChangedRebuildServer(e fsnotify.Event) { - switch currentState.Load().(State).Type { - case StateTypeErrTempl, StateTypeErrGolangCILint: - return - } +func onFileChanged(ctx context.Context, e fsnotify.Event) { + switch { + case isTemplFile(e.Name): + fileChangedLock.Lock() + defer fileChangedLock.Unlock() + + var operation string + switch e.Op { + case fsnotify.Create: + operation = "created" + case fsnotify.Write: + operation = "changed" + case fsnotify.Remove: + operation = "removed" + default: + return + } - fileChangedLock.Lock() - defer fileChangedLock.Unlock() + fmt.Print("🤖 template file ") + fmt.Print(operation) + fmt.Print(": ") + fCyanUnderline.Println(e.Name) - chMsgClients <- bytesMsgReloadInitiated + runTemplGenerate(ctx, e.Name) - var operation string - switch e.Op { - case fsnotify.Create: - operation = "created" - case fsnotify.Write: - operation = "changed" - case fsnotify.Remove: - operation = "removed" default: - return - } + switch currentState.Load().(State).Type { + case StateTypeErrTempl, StateTypeErrGolangCILint: + return + } - // Ignore .templ files, another watcher will take care of them. - if strings.HasSuffix(e.Name, ".templ") { - return - } + fileChangedLock.Lock() + defer fileChangedLock.Unlock() - fmt.Print("🤖 file ") - fmt.Print(operation) - fmt.Print(": ") - fBlueUnderline.Println(e.Name) + chMsgClients <- bytesMsgReloadInitiated - ctx := context.Background() - func() { - if config.Lint && !runGolangCILint(ctx) { - return - } - if !buildAndRerunServer(ctx) { + var operation string + switch e.Op { + case fsnotify.Create: + operation = "created" + case fsnotify.Write: + operation = "changed" + case fsnotify.Remove: + operation = "removed" + default: return } - }() -} - -func onTemplFileChangedGenTemplates(e fsnotify.Event) { - fileChangedLock.Lock() - defer fileChangedLock.Unlock() - - var operation string - switch e.Op { - case fsnotify.Create: - operation = "created" - case fsnotify.Write: - operation = "changed" - case fsnotify.Remove: - operation = "removed" - default: - return - } - if !strings.HasSuffix(e.Name, ".templ") { - return - } - - fmt.Print("🤖 template file ") - fmt.Print(operation) - fmt.Print(": ") - fCyanUnderline.Println(e.Name) - runTemplGenerate(context.Background(), e.Name) -} + fmt.Print("🤖 file ") + fmt.Print(operation) + fmt.Print(": ") + fBlueUnderline.Println(e.Name) -// mustGoWatchDir watches all directories in pathDir ignoring ignoredDirs -// and debounces calls to fn. -func mustGoWatchDir( - ctx context.Context, - pathDir string, - debounceDur time.Duration, - fn func(e fsnotify.Event), - ignoredDirs ...string, -) { - debounce, do := NewDebouncedSync(debounceDur) - go debounce(ctx) - - forEachDir(pathDir, func(dir string) { - go func() { - watcher, err := fsnotify.NewWatcher() - if err != nil { - panic(fmt.Errorf("initializing file watcher: %w", err)) - } - defer watcher.Close() - if err := watcher.Add(dir); err != nil { - panic(fmt.Errorf("setting up file watcher: %w", err)) + func() { + if config.Lint && !runGolangCILint(ctx) { + return } - for { - select { - case <-ctx.Done(): - return - case e := <-watcher.Events: - do(func() { fn(e) }) - case err := <-watcher.Errors: - panic(fmt.Errorf("watching file: %w", err)) - } + if !buildAndRerunServer(ctx) { + return } }() - }, ignoredDirs...) -} - -// forEachDir executes fn for every subdirectory of pathDir, -// including pathDir itself, recursively. -func forEachDir(pathDir string, fn func(dir string), ignore ...string) { - // Use filepath.Walk to traverse directories - err := filepath.Walk(pathDir, func(path string, info os.FileInfo, err error) error { - if err != nil { - return err // Stop walking the directory tree. - } - if !info.IsDir() { - return nil // Continue walking. - } - for _, ignored := range ignore { - if strings.HasPrefix(path, ignored) { - return nil - } - } - fn(path) - return nil - }) - // Handle potential errors from walking the directory tree. - if err != nil { - panic(err) // For simplicity, panic on error. Adjust error handling as needed. } } @@ -469,12 +414,12 @@ func runGolangCILint(ctx context.Context) (ok bool) { } func buildAndRerunServer(ctx context.Context) (ok bool) { - if err := os.MkdirAll(serverOutPath, os.ModePerm); err != nil { + if err := os.MkdirAll(config.serverOutPath, os.ModePerm); err != nil { panic(fmt.Errorf("creating go binary output file path in %q: %w", - serverOutPath, err)) + config.serverOutPath, err)) } - binaryPath := makeUniqueServerOutPath(serverOutPath) + binaryPath := makeUniqueServerOutPath(config.serverOutPath) // Register the binary path to make sure it's defer-deleted filesToBeDeletedBeforeExit.Store(binaryPath) @@ -494,193 +439,11 @@ func buildAndRerunServer(ctx context.Context) (ok bool) { return true } +func isTemplFile(filePath string) bool { + return strings.HasSuffix(filePath, ".templ") +} + func makeUniqueServerOutPath(basePath string) string { tm := time.Now() return path.Join(basePath, "server_"+strconv.FormatInt(tm.UnixNano(), 16)) } - -type Server struct { - httpClient *http.Client - appHostAddr string - broadcasterRegister chan chan []byte - jsInjection []byte - webSocketUpgrader websocket.Upgrader -} - -func NewServer( - httpClient *http.Client, - appHostAddr string, - printDebugLogs bool, - broadcasterRegister chan chan []byte, - connectionRefusedTimeout time.Duration, -) *Server { - var jsInjectionBuf bytes.Buffer - err := jsInjection(printDebugLogs, PathProxyEvents).Render( - context.Background(), &jsInjectionBuf, - ) - if err != nil { - panic(fmt.Errorf("rendering the live reload injection template: %w", err)) - } - return &Server{ - httpClient: httpClient, - appHostAddr: appHostAddr, - broadcasterRegister: broadcasterRegister, - jsInjection: jsInjectionBuf.Bytes(), - webSocketUpgrader: websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { return true }, // Ignore CORS - }, - } -} - -func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == PathProxyEvents { - // This request comes from the injected JavaScript, - // Don't forward, handle it instead. - s.handleProxyEvents(w, r) - return - } - - state := currentState.Load().(State) - if state.Type.IsErr() { - s.handleErrPage(w, r) - return - } - - u, err := url.JoinPath(s.appHostAddr, r.URL.Path) - if err != nil { - internalErr(w, "joining path", err) - return - } - - proxyReq, err := http.NewRequestWithContext( - r.Context(), r.Method, u, r.Body, - ) - if err != nil { - internalErr(w, "initializing request", err) - return - } - - // Copy original request headers - proxyReq.Header = r.Header.Clone() - - var resp *http.Response - for start := time.Now(); ; { - resp, err = s.httpClient.Do(proxyReq) - if err != nil { - if isConnRefused(err) { - if time.Since(start) < config.ProxyTimeout { - continue - } - } - http.Error(w, - fmt.Sprintf("proxy: sending request: %v", err), - http.StatusInternalServerError) - return - } - break - } - defer resp.Body.Close() - - // Copy response headers - for key, values := range resp.Header { - for _, value := range values { - w.Header().Add(key, value) - } - } - - // Check if the response content type is HTML to inject the script - if resp.StatusCode == http.StatusOK && - strings.Contains(resp.Header.Get("Content-Type"), "text/html") { - b, err := io.ReadAll(resp.Body) - if err != nil { - internalErr(w, "reading response body", err) - return - } - // Inject JavaScript - modified := bytes.Replace(b, bytesBodyClosingTag, s.jsInjection, 1) - w.Header().Set("Content-Length", strconv.Itoa(len(modified))) - _, _ = w.Write(modified) - - } else { - // For non-HTML responses, just proxy the response - _, _ = io.Copy(w, resp.Body) - w.WriteHeader(resp.StatusCode) - } -} - -func isConnRefused(err error) bool { - var opErr *net.OpError - if errors.As(err, &opErr) && opErr.Op == "dial" { - const c = syscall.ECONNREFUSED - if sysErr, ok := opErr.Err.(*os.SyscallError); ok && sysErr.Err == c { - return true - } - } - return false -} - -func (s *Server) handleProxyEvents(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - http.Error(w, - "expecting method GET on templier proxy route 'events'", - http.StatusMethodNotAllowed) - return - } - c, err := s.webSocketUpgrader.Upgrade(w, r, nil) - if err != nil { - fmt.Println("🤖 ERR: upgrading to websocket:", err) - internalErr(w, "upgrading to websocket", err) - return - } - defer c.Close() - - messages := make(chan []byte) - s.broadcasterRegister <- messages - - for msg := range messages { - err = c.SetWriteDeadline(time.Now().Add(10 * time.Second)) - if err != nil { - fmt.Println("🤖 ERR: setting websocket write deadline:", err) - } - err = c.WriteMessage(websocket.TextMessage, msg) - if err != nil { - return // Disconnected - } - } -} - -func (s *Server) handleErrPage(w http.ResponseWriter, r *http.Request) { - state := currentState.Load().(State) - - var header string - switch state.Type { - case StateTypeErrTempl: - header = "Error: Templ" - case StateTypeErrCompile: - header = "Error: Compiling" - case StateTypeErrGolangCILint: - header = "Error: golangci-lint" - default: - header = "Error" - } - title := header - - comp := errpage( - title, header, string(state.Msg), - config.PrintJSDebugLogs, PathProxyEvents, - ) - err := comp.Render(r.Context(), w) - if err != nil { - panic(fmt.Errorf("rendering errpage: %w", err)) - } -} - -var bytesBodyClosingTag = []byte("") - -func internalErr(w http.ResponseWriter, msg string, err error) { - http.Error(w, - fmt.Sprintf("proxy: %s: %v", msg, err), - http.StatusInternalServerError) -} diff --git a/server.go b/server.go new file mode 100644 index 0000000..2f3382d --- /dev/null +++ b/server.go @@ -0,0 +1,210 @@ +package main + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "os" + "strconv" + "strings" + "syscall" + "time" + + "github.com/gorilla/websocket" +) + +type Server struct { + httpClient *http.Client + appHostAddr string + broadcasterRegister chan chan []byte + jsInjection []byte + webSocketUpgrader websocket.Upgrader +} + +func NewServer( + httpClient *http.Client, + appHostAddr string, + printDebugLogs bool, + broadcasterRegister chan chan []byte, + connectionRefusedTimeout time.Duration, +) *Server { + var jsInjectionBuf bytes.Buffer + err := jsInjection(printDebugLogs, PathProxyEvents).Render( + context.Background(), &jsInjectionBuf, + ) + if err != nil { + panic(fmt.Errorf("rendering the live reload injection template: %w", err)) + } + return &Server{ + httpClient: httpClient, + appHostAddr: appHostAddr, + broadcasterRegister: broadcasterRegister, + jsInjection: jsInjectionBuf.Bytes(), + webSocketUpgrader: websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { return true }, // Ignore CORS + }, + } +} + +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == PathProxyEvents { + // This request comes from the injected JavaScript, + // Don't forward, handle it instead. + s.handleProxyEvents(w, r) + return + } + + state := currentState.Load().(State) + if state.Type.IsErr() { + s.handleErrPage(w, r) + return + } + + u, err := url.JoinPath(s.appHostAddr, r.URL.Path) + if err != nil { + internalErr(w, "joining path", err) + return + } + + proxyReq, err := http.NewRequestWithContext( + r.Context(), r.Method, u, r.Body, + ) + if err != nil { + internalErr(w, "initializing request", err) + return + } + + // Copy original request headers + proxyReq.Header = r.Header.Clone() + + var resp *http.Response + for start := time.Now(); ; { + resp, err = s.httpClient.Do(proxyReq) + if err != nil { + if isConnRefused(err) { + if time.Since(start) < config.ProxyTimeout { + continue + } + } + http.Error(w, + fmt.Sprintf("proxy: sending request: %v", err), + http.StatusInternalServerError) + return + } + break + } + defer resp.Body.Close() + + // Copy response headers + for key, values := range resp.Header { + for _, value := range values { + w.Header().Add(key, value) + } + } + + // Check if the response content type is HTML to inject the script + if resp.StatusCode == http.StatusOK && + strings.Contains(resp.Header.Get("Content-Type"), "text/html") { + b, err := io.ReadAll(resp.Body) + if err != nil { + internalErr(w, "reading response body", err) + return + } + // Inject JavaScript + modified := bytes.Replace(b, bytesBodyClosingTag, s.jsInjection, 1) + w.Header().Set("Content-Length", strconv.Itoa(len(modified))) + _, _ = w.Write(modified) + + } else { + // For non-HTML responses, just proxy the response + if _, err = io.Copy(w, resp.Body); err != nil { + internalErr(w, "copying response body", err) + return + } + if resp.StatusCode != http.StatusOK { + w.WriteHeader(resp.StatusCode) + } + } +} + +func isConnRefused(err error) bool { + var opErr *net.OpError + if errors.As(err, &opErr) && opErr.Op == "dial" { + const c = syscall.ECONNREFUSED + if sysErr, ok := opErr.Err.(*os.SyscallError); ok && sysErr.Err == c { + return true + } + } + return false +} + +func (s *Server) handleProxyEvents(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, + "expecting method GET on templier proxy route 'events'", + http.StatusMethodNotAllowed) + return + } + c, err := s.webSocketUpgrader.Upgrade(w, r, nil) + if err != nil { + fmt.Println("🤖 ERR: upgrading to websocket:", err) + internalErr(w, "upgrading to websocket", err) + return + } + defer c.Close() + + messages := make(chan []byte) + s.broadcasterRegister <- messages + + for msg := range messages { + err = c.SetWriteDeadline(time.Now().Add(10 * time.Second)) + if err != nil { + fmt.Println("🤖 ERR: setting websocket write deadline:", err) + } + err = c.WriteMessage(websocket.TextMessage, msg) + if err != nil { + return // Disconnected + } + } +} + +func (s *Server) handleErrPage(w http.ResponseWriter, r *http.Request) { + state := currentState.Load().(State) + + var header string + switch state.Type { + case StateTypeErrTempl: + header = "Error: Templ" + case StateTypeErrCompile: + header = "Error: Compiling" + case StateTypeErrGolangCILint: + header = "Error: golangci-lint" + default: + header = "Error" + } + title := header + + comp := errpage( + title, header, string(state.Msg), + config.PrintJSDebugLogs, PathProxyEvents, + ) + err := comp.Render(r.Context(), w) + if err != nil { + panic(fmt.Errorf("rendering errpage: %w", err)) + } +} + +var bytesBodyClosingTag = []byte("") + +func internalErr(w http.ResponseWriter, msg string, err error) { + http.Error(w, + fmt.Sprintf("proxy: %s: %v", msg, err), + http.StatusInternalServerError) +}