diff --git a/go.mod b/go.mod index a0d758f..e68d251 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,8 @@ require github.com/pires/go-proxyproto v0.7.0 require github.com/pelletier/go-toml/v2 v2.2.2 +require github.com/fsnotify/fsnotify v1.7.0 + require github.com/julienschmidt/httprouter v1.3.0 require golang.org/x/sys v0.23.0 // indirect diff --git a/go.sum b/go.sum index d7e3269..aff84cd 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U= github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= github.com/libp2p/go-reuseport v0.4.0 h1:nR5KU7hD0WxXCJbmw7r2rhRYruNRl2koHw8fQscQm2s= diff --git a/main.go b/main.go index 1118cf0..6fc4076 100644 --- a/main.go +++ b/main.go @@ -3,13 +3,49 @@ package main import ( "encoding/json" "flag" + "fmt" "log" "os" "strings" + "github.com/fsnotify/fsnotify" "github.com/pelletier/go-toml/v2" ) +func watchConfigFile(path string) (chan fsnotify.Event, error) { + watcher, err := fsnotify.NewWatcher() + if err != nil { + return nil, err + } + err = watcher.Add(path) + if err != nil { + return nil, err + } + return watcher.Events, nil +} + +func replaceServer(sshmux *Server, event fsnotify.Event) (*Server, error) { + if event.Has(fsnotify.Write) || event.Has(fsnotify.Create) { + log.Printf("info: %s has been changed, reloading...\n", event.Name) + // Start new server instance + newServer, err := sshmuxServer(event.Name) + if err != nil { + return sshmux, fmt.Errorf("failed to parse %s: %w", event.Name, err) + } + err = newServer.Start() + if err != nil { + return sshmux, fmt.Errorf("failed to start sshmux server: %w", err) + } + // Replace old server + go sshmux.Shutdown() + return newServer, nil + } + if event.Has(fsnotify.Remove) || event.Has(fsnotify.Rename) { + log.Printf("warn: %s has been deleted\n", event.Name) + } + return sshmux, nil +} + func sshmuxServer(configFile string) (*Server, error) { var config Config configFileBytes, err := os.ReadFile(configFile) @@ -35,7 +71,9 @@ func sshmuxServer(configFile string) (*Server, error) { func main() { var configFile string + var reload bool flag.StringVar(&configFile, "c", "/etc/sshmux/config.toml", "config file") + flag.BoolVar(&reload, "r", false, "auto reload") flag.Parse() sshmux, err := sshmuxServer(configFile) if err != nil { @@ -45,5 +83,18 @@ func main() { if err != nil { log.Fatal(err) } - sshmux.Wait() + if reload { + events, err := watchConfigFile(configFile) + if err != nil { + log.Fatal(err) + } + for event := range events { + sshmux, err = replaceServer(sshmux, event) + if err != nil { + log.Print(err) + } + } + } else { + sshmux.Wait() + } }