diff --git a/cmd/dashboard/controller/fm.go b/cmd/dashboard/controller/fm.go index 8124f47548..87699114e3 100644 --- a/cmd/dashboard/controller/fm.go +++ b/cmd/dashboard/controller/fm.go @@ -5,11 +5,11 @@ import ( "time" "github.com/gin-gonic/gin" + "github.com/goccy/go-json" "github.com/gorilla/websocket" "github.com/hashicorp/go-uuid" "github.com/nezhahq/nezha/model" - "github.com/nezhahq/nezha/pkg/utils" "github.com/nezhahq/nezha/pkg/websocketx" "github.com/nezhahq/nezha/proto" "github.com/nezhahq/nezha/service/rpc" @@ -48,7 +48,7 @@ func createFM(c *gin.Context) (*model.CreateFMResponse, error) { rpc.NezhaHandlerSingleton.CreateStream(streamId) - fmData, _ := utils.Json.Marshal(&model.TaskFM{ + fmData, _ := json.Marshal(&model.TaskFM{ StreamID: streamId, }) if err := server.TaskStream.Send(&proto.Task{ diff --git a/cmd/dashboard/controller/server.go b/cmd/dashboard/controller/server.go index 5b039de72b..35f2499c17 100644 --- a/cmd/dashboard/controller/server.go +++ b/cmd/dashboard/controller/server.go @@ -7,11 +7,11 @@ import ( "time" "github.com/gin-gonic/gin" + "github.com/goccy/go-json" "github.com/jinzhu/copier" "gorm.io/gorm" "github.com/nezhahq/nezha/model" - "github.com/nezhahq/nezha/pkg/utils" pb "github.com/nezhahq/nezha/proto" "github.com/nezhahq/nezha/service/singleton" ) @@ -81,13 +81,13 @@ func updateServer(c *gin.Context) (any, error) { s.DDNSProfiles = sf.DDNSProfiles s.OverrideDDNSDomains = sf.OverrideDDNSDomains - ddnsProfilesRaw, err := utils.Json.Marshal(s.DDNSProfiles) + ddnsProfilesRaw, err := json.Marshal(s.DDNSProfiles) if err != nil { return nil, err } s.DDNSProfilesRaw = string(ddnsProfilesRaw) - overrideDomainsRaw, err := utils.Json.Marshal(sf.OverrideDDNSDomains) + overrideDomainsRaw, err := json.Marshal(sf.OverrideDDNSDomains) if err != nil { return nil, err } diff --git a/cmd/dashboard/controller/setting.go b/cmd/dashboard/controller/setting.go index 333b321328..cd5bf471a8 100644 --- a/cmd/dashboard/controller/setting.go +++ b/cmd/dashboard/controller/setting.go @@ -45,7 +45,7 @@ func listConfig(c *gin.Context) (model.SettingResponse[any], error) { Oauth2Providers: config.Oauth2Providers, } if authorized { - configForGuests.TLS = singleton.Conf.TLS + configForGuests.AgentTLS = singleton.Conf.AgentTLS configForGuests.InstallHost = singleton.Conf.InstallHost } conf = model.SettingResponse[any]{ @@ -98,7 +98,7 @@ func updateConfig(c *gin.Context) (any, error) { singleton.Conf.CustomCode = sf.CustomCode singleton.Conf.CustomCodeDashboard = sf.CustomCodeDashboard singleton.Conf.RealIPHeader = sf.RealIPHeader - singleton.Conf.TLS = sf.TLS + singleton.Conf.AgentTLS = sf.AgentTLS singleton.Conf.UserTemplate = sf.UserTemplate if err := singleton.Conf.Save(); err != nil { diff --git a/cmd/dashboard/controller/terminal.go b/cmd/dashboard/controller/terminal.go index ec6a7d1000..7899f1b783 100644 --- a/cmd/dashboard/controller/terminal.go +++ b/cmd/dashboard/controller/terminal.go @@ -4,11 +4,11 @@ import ( "time" "github.com/gin-gonic/gin" + "github.com/goccy/go-json" "github.com/gorilla/websocket" "github.com/hashicorp/go-uuid" "github.com/nezhahq/nezha/model" - "github.com/nezhahq/nezha/pkg/utils" "github.com/nezhahq/nezha/pkg/websocketx" "github.com/nezhahq/nezha/proto" "github.com/nezhahq/nezha/service/rpc" @@ -46,7 +46,7 @@ func createTerminal(c *gin.Context) (*model.CreateTerminalResponse, error) { rpc.NezhaHandlerSingleton.CreateStream(streamId) - terminalData, _ := utils.Json.Marshal(&model.TerminalTask{ + terminalData, _ := json.Marshal(&model.TerminalTask{ StreamID: streamId, }) if err := server.TaskStream.Send(&proto.Task{ diff --git a/cmd/dashboard/controller/ws.go b/cmd/dashboard/controller/ws.go index d409ec5c97..4d0d211803 100644 --- a/cmd/dashboard/controller/ws.go +++ b/cmd/dashboard/controller/ws.go @@ -9,6 +9,7 @@ import ( "unicode/utf8" "github.com/gin-gonic/gin" + "github.com/goccy/go-json" "github.com/gorilla/websocket" "github.com/hashicorp/go-uuid" "golang.org/x/sync/singleflight" @@ -183,7 +184,7 @@ func getServerStat(withPublicNote, authorized bool) ([]byte, error) { }) } - return utils.Json.Marshal(model.StreamServerData{ + return json.Marshal(model.StreamServerData{ Now: time.Now().Unix() * 1000, Online: singleton.GetOnlineUserCount(), Servers: servers, diff --git a/cmd/dashboard/main.go b/cmd/dashboard/main.go index c7c15b3917..42c5b950a4 100644 --- a/cmd/dashboard/main.go +++ b/cmd/dashboard/main.go @@ -16,8 +16,6 @@ import ( "github.com/gin-gonic/gin" "github.com/ory/graceful" "golang.org/x/crypto/bcrypt" - "golang.org/x/net/http2" - "golang.org/x/net/http2/h2c" "github.com/nezhahq/nezha/cmd/dashboard/controller" "github.com/nezhahq/nezha/cmd/dashboard/controller/waf" @@ -133,11 +131,22 @@ func main() { controller.InitUpgrader() muxHandler := newHTTPandGRPCMux(httpHandler, grpcHandler) - http2Server := &http2.Server{} - muxServer := &http.Server{Handler: h2c.NewHandler(muxHandler, http2Server), ReadHeaderTimeout: time.Second * 5} + muxServer := &http.Server{ + Handler: muxHandler, + ReadHeaderTimeout: time.Second * 5, + } + muxServer.Protocols.SetHTTP1(true) + if singleton.Conf.EnableTLS { + muxServer.Protocols.SetHTTP2(true) + } else { + muxServer.Protocols.SetUnencryptedHTTP2(true) + } if err := graceful.Graceful(func() error { log.Printf("NEZHA>> Dashboard::START ON %s:%d", singleton.Conf.ListenHost, singleton.Conf.ListenPort) + if singleton.Conf.EnableTLS { + return muxServer.ServeTLS(l, singleton.Conf.TLSCertPath, singleton.Conf.TLSKeyPath) + } return muxServer.Serve(l) }, func(c context.Context) error { log.Println("NEZHA>> Graceful::START") diff --git a/cmd/dashboard/rpc/rpc.go b/cmd/dashboard/rpc/rpc.go index 79157f89be..f4e559c750 100644 --- a/cmd/dashboard/rpc/rpc.go +++ b/cmd/dashboard/rpc/rpc.go @@ -2,6 +2,7 @@ package rpc import ( "context" + "encoding/json" "fmt" "log" "net/http" @@ -169,7 +170,7 @@ func ServeNAT(w http.ResponseWriter, r *http.Request, natConfig *model.NAT) { rpcService.NezhaHandlerSingleton.CreateStream(streamId) defer rpcService.NezhaHandlerSingleton.CloseStream(streamId) - taskData, err := utils.Json.Marshal(model.TaskNAT{ + taskData, err := json.Marshal(model.TaskNAT{ StreamID: streamId, Host: natConfig.Host, }) diff --git a/go.mod b/go.mod index d1a3663f9f..5704e20381 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/nezhahq/nezha -go 1.23.6 +go 1.24.0 require ( github.com/appleboy/gin-jwt/v2 v2.10.1 @@ -8,10 +8,10 @@ require ( github.com/dustinkirkland/golang-petname v0.0.0-20240428194347-eebcea082ee0 github.com/gin-contrib/pprof v1.5.2 github.com/gin-gonic/gin v1.10.0 + github.com/goccy/go-json v0.10.5 github.com/gorilla/websocket v1.5.3 github.com/hashicorp/go-uuid v1.0.3 github.com/jinzhu/copier v0.4.0 - github.com/json-iterator/go v1.1.12 github.com/knadh/koanf/parsers/yaml v0.1.0 github.com/knadh/koanf/providers/env v1.0.0 github.com/knadh/koanf/providers/file v1.1.2 @@ -57,11 +57,11 @@ require ( github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.25.0 // indirect github.com/go-viper/mapstructure/v2 v2.2.1 // indirect - github.com/goccy/go-json v0.10.5 // indirect github.com/golang-jwt/jwt/v4 v4.5.1 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/josharian/intern v1.0.0 // indirect + github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.9 // indirect github.com/knadh/koanf/maps v0.1.1 // indirect github.com/leodido/go-urn v1.4.0 // indirect diff --git a/model/alertrule.go b/model/alertrule.go index 9b4fca083c..459c04da4a 100644 --- a/model/alertrule.go +++ b/model/alertrule.go @@ -1,7 +1,7 @@ package model import ( - "github.com/nezhahq/nezha/pkg/utils" + "github.com/goccy/go-json" "gorm.io/gorm" ) @@ -25,17 +25,17 @@ type AlertRule struct { } func (r *AlertRule) BeforeSave(tx *gorm.DB) error { - if data, err := utils.Json.Marshal(r.Rules); err != nil { + if data, err := json.Marshal(r.Rules); err != nil { return err } else { r.RulesRaw = string(data) } - if data, err := utils.Json.Marshal(r.FailTriggerTasks); err != nil { + if data, err := json.Marshal(r.FailTriggerTasks); err != nil { return err } else { r.FailTriggerTasksRaw = string(data) } - if data, err := utils.Json.Marshal(r.RecoverTriggerTasks); err != nil { + if data, err := json.Marshal(r.RecoverTriggerTasks); err != nil { return err } else { r.RecoverTriggerTasksRaw = string(data) @@ -45,13 +45,13 @@ func (r *AlertRule) BeforeSave(tx *gorm.DB) error { func (r *AlertRule) AfterFind(tx *gorm.DB) error { var err error - if err = utils.Json.Unmarshal([]byte(r.RulesRaw), &r.Rules); err != nil { + if err = json.Unmarshal([]byte(r.RulesRaw), &r.Rules); err != nil { return err } - if err = utils.Json.Unmarshal([]byte(r.FailTriggerTasksRaw), &r.FailTriggerTasks); err != nil { + if err = json.Unmarshal([]byte(r.FailTriggerTasksRaw), &r.FailTriggerTasks); err != nil { return err } - if err = utils.Json.Unmarshal([]byte(r.RecoverTriggerTasksRaw), &r.RecoverTriggerTasks); err != nil { + if err = json.Unmarshal([]byte(r.RecoverTriggerTasksRaw), &r.RecoverTriggerTasks); err != nil { return err } return nil diff --git a/model/config.go b/model/config.go index f7ec9251a4..aab9ee64ee 100644 --- a/model/config.go +++ b/model/config.go @@ -31,7 +31,7 @@ type ConfigForGuests struct { Oauth2Providers []string `json:"oauth2_providers,omitempty"` InstallHost string `json:"install_host,omitempty"` - TLS bool `json:"tls,omitempty"` + AgentTLS bool `json:"tls,omitempty"` } type Config struct { @@ -47,7 +47,7 @@ type Config struct { ListenPort uint `mapstructure:"listen_port" json:"listen_port,omitempty"` ListenHost string `mapstructure:"listen_host" json:"listen_host,omitempty"` InstallHost string `mapstructure:"install_host" json:"install_host,omitempty"` - TLS bool `mapstructure:"tls" json:"tls,omitempty"` + AgentTLS bool `mapstructure:"tls" json:"tls,omitempty"` // 用于前端判断生成的安装命令是否启用 TLS Location string `mapstructure:"location" json:"location,omitempty"` // 时区,默认为 Asia/Shanghai ForceAuth bool `mapstructure:"force_auth" json:"force_auth,omitempty"` // 强制要求认证 @@ -71,6 +71,11 @@ type Config struct { // oauth2 供应商列表,无需配置,自动生成 Oauth2Providers []string `yaml:"-" json:"oauth2_providers,omitempty"` + // TLS 证书配置 + EnableTLS bool `mapstructure:"enable_tls" json:"enable_tls,omitempty"` + TLSCertPath string `mapstructure:"tls_cert_path" json:"tls_cert_path,omitempty"` + TLSKeyPath string `mapstructure:"tls_key_path" json:"tls_key_path,omitempty"` + k *koanf.Koanf `json:"-"` filePath string `json:"-"` } diff --git a/model/cron.go b/model/cron.go index 76a9c30858..418c4215af 100644 --- a/model/cron.go +++ b/model/cron.go @@ -3,7 +3,7 @@ package model import ( "time" - "github.com/nezhahq/nezha/pkg/utils" + "github.com/goccy/go-json" "github.com/robfig/cron/v3" "gorm.io/gorm" ) @@ -34,7 +34,7 @@ type Cron struct { } func (c *Cron) BeforeSave(tx *gorm.DB) error { - if data, err := utils.Json.Marshal(c.Servers); err != nil { + if data, err := json.Marshal(c.Servers); err != nil { return err } else { c.ServersRaw = string(data) @@ -43,5 +43,5 @@ func (c *Cron) BeforeSave(tx *gorm.DB) error { } func (c *Cron) AfterFind(tx *gorm.DB) error { - return utils.Json.Unmarshal([]byte(c.ServersRaw), &c.Servers) + return json.Unmarshal([]byte(c.ServersRaw), &c.Servers) } diff --git a/model/ddns.go b/model/ddns.go index fe942b9cec..601476e997 100644 --- a/model/ddns.go +++ b/model/ddns.go @@ -1,7 +1,7 @@ package model import ( - "github.com/nezhahq/nezha/pkg/utils" + "github.com/goccy/go-json" "gorm.io/gorm" ) @@ -39,7 +39,7 @@ func (d DDNSProfile) TableName() string { } func (d *DDNSProfile) BeforeSave(tx *gorm.DB) error { - if data, err := utils.Json.Marshal(d.Domains); err != nil { + if data, err := json.Marshal(d.Domains); err != nil { return err } else { d.DomainsRaw = string(data) @@ -48,5 +48,5 @@ func (d *DDNSProfile) BeforeSave(tx *gorm.DB) error { } func (d *DDNSProfile) AfterFind(tx *gorm.DB) error { - return utils.Json.Unmarshal([]byte(d.DomainsRaw), &d.Domains) + return json.Unmarshal([]byte(d.DomainsRaw), &d.Domains) } diff --git a/model/notification.go b/model/notification.go index dcde788456..e6f570db6a 100644 --- a/model/notification.go +++ b/model/notification.go @@ -9,6 +9,7 @@ import ( "strings" "time" + "github.com/goccy/go-json" "github.com/nezhahq/nezha/pkg/utils" ) @@ -66,7 +67,7 @@ func (ns *NotificationServerBundle) reqBody(message string) (string, error) { switch n.RequestType { case NotificationRequestTypeJSON: return ns.replaceParamsInString(n.RequestBody, message, func(msg string) string { - msgBytes, _ := utils.Json.Marshal(msg) + msgBytes, _ := json.Marshal(msg) return string(msgBytes)[1 : len(msgBytes)-1] }), nil case NotificationRequestTypeForm: diff --git a/model/server.go b/model/server.go index 67f655e007..635b9000bd 100644 --- a/model/server.go +++ b/model/server.go @@ -5,9 +5,9 @@ import ( "slices" "time" + "github.com/goccy/go-json" "gorm.io/gorm" - "github.com/nezhahq/nezha/pkg/utils" pb "github.com/nezhahq/nezha/proto" ) @@ -58,13 +58,13 @@ func (s *Server) CopyFromRunningServer(old *Server) { func (s *Server) AfterFind(tx *gorm.DB) error { if s.DDNSProfilesRaw != "" { - if err := utils.Json.Unmarshal([]byte(s.DDNSProfilesRaw), &s.DDNSProfiles); err != nil { + if err := json.Unmarshal([]byte(s.DDNSProfilesRaw), &s.DDNSProfiles); err != nil { log.Println("NEZHA>> Server.AfterFind:", err) return nil } } if s.OverrideDDNSDomainsRaw != "" { - if err := utils.Json.Unmarshal([]byte(s.OverrideDDNSDomainsRaw), &s.OverrideDDNSDomains); err != nil { + if err := json.Unmarshal([]byte(s.OverrideDDNSDomainsRaw), &s.OverrideDDNSDomains); err != nil { log.Println("NEZHA>> Server.AfterFind:", err) return nil } diff --git a/model/service.go b/model/service.go index 76e423c913..0be9f383c9 100644 --- a/model/service.go +++ b/model/service.go @@ -4,10 +4,10 @@ import ( "fmt" "log" + "github.com/goccy/go-json" "github.com/robfig/cron/v3" "gorm.io/gorm" - "github.com/nezhahq/nezha/pkg/utils" pb "github.com/nezhahq/nezha/proto" ) @@ -91,17 +91,17 @@ func (m *Service) CronSpec() string { } func (m *Service) BeforeSave(tx *gorm.DB) error { - if data, err := utils.Json.Marshal(m.SkipServers); err != nil { + if data, err := json.Marshal(m.SkipServers); err != nil { return err } else { m.SkipServersRaw = string(data) } - if data, err := utils.Json.Marshal(m.FailTriggerTasks); err != nil { + if data, err := json.Marshal(m.FailTriggerTasks); err != nil { return err } else { m.FailTriggerTasksRaw = string(data) } - if data, err := utils.Json.Marshal(m.RecoverTriggerTasks); err != nil { + if data, err := json.Marshal(m.RecoverTriggerTasks); err != nil { return err } else { m.RecoverTriggerTasksRaw = string(data) @@ -111,16 +111,16 @@ func (m *Service) BeforeSave(tx *gorm.DB) error { func (m *Service) AfterFind(tx *gorm.DB) error { m.SkipServers = make(map[uint64]bool) - if err := utils.Json.Unmarshal([]byte(m.SkipServersRaw), &m.SkipServers); err != nil { + if err := json.Unmarshal([]byte(m.SkipServersRaw), &m.SkipServers); err != nil { log.Println("NEZHA>> Service.AfterFind:", err) return nil } // 加载触发任务列表 - if err := utils.Json.Unmarshal([]byte(m.FailTriggerTasksRaw), &m.FailTriggerTasks); err != nil { + if err := json.Unmarshal([]byte(m.FailTriggerTasksRaw), &m.FailTriggerTasks); err != nil { return err } - if err := utils.Json.Unmarshal([]byte(m.RecoverTriggerTasksRaw), &m.RecoverTriggerTasks); err != nil { + if err := json.Unmarshal([]byte(m.RecoverTriggerTasksRaw), &m.RecoverTriggerTasks); err != nil { return err } diff --git a/model/setting_api.go b/model/setting_api.go index bb445d74a9..ea7c977497 100644 --- a/model/setting_api.go +++ b/model/setting_api.go @@ -13,7 +13,7 @@ type SettingForm struct { RealIPHeader string `json:"real_ip_header,omitempty" validate:"optional"` // 真实IP UserTemplate string `json:"user_template,omitempty" validate:"optional"` - TLS bool `json:"tls,omitempty" validate:"optional"` + AgentTLS bool `json:"tls,omitempty" validate:"optional"` EnableIPChangeNotification bool `json:"enable_ip_change_notification,omitempty" validate:"optional"` EnablePlainIPInNotification bool `json:"enable_plain_ip_in_notification,omitempty" validate:"optional"` } diff --git a/pkg/utils/gjson.go b/pkg/utils/gjson.go index 77594a7025..15df2eb267 100644 --- a/pkg/utils/gjson.go +++ b/pkg/utils/gjson.go @@ -33,9 +33,7 @@ func GjsonIter(json string) (iter.Seq2[string, string], error) { return nil, ErrGjsonWrongType } - return func(yield func(string, string) bool) { - result.ForEach(func(k, v gjson.Result) bool { - return yield(k.String(), v.String()) - }) - }, nil + return ConvertSeq2(result.ForEach, func(k, v gjson.Result) (string, string) { + return k.String(), v.String() + }), nil } diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index ad3bfc95ca..2d250e0a04 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -13,24 +13,19 @@ import ( "strings" "golang.org/x/exp/constraints" - - jsoniter "github.com/json-iterator/go" ) var ( - Json = jsoniter.ConfigCompatibleWithStandardLibrary - DNSServers = []string{"8.8.8.8:53", "8.8.4.4:53", "1.1.1.1:53", "1.0.0.1:53"} -) -var ipv4Re = regexp.MustCompile(`(\d*\.).*(\.\d*)`) + ipv4Re = regexp.MustCompile(`(\d*\.).*(\.\d*)`) + ipv6Re = regexp.MustCompile(`(\w*:\w*:).*(:\w*:\w*)`) +) func ipv4Desensitize(ipv4Addr string) string { return ipv4Re.ReplaceAllString(ipv4Addr, "$1****$2") } -var ipv6Re = regexp.MustCompile(`(\w*:\w*:).*(:\w*:\w*)`) - func ipv6Desensitize(ipv6Addr string) string { return ipv6Re.ReplaceAllString(ipv6Addr, "$1****$2") } @@ -51,9 +46,11 @@ func IPStringToBinary(ip string) ([]byte, error) { } func BinaryToIPString(b []byte) string { - var addr16 [16]byte - copy(addr16[:], b) - addr := netip.AddrFrom16(addr16) + if len(b) < 16 { + return "::" + } + + addr := netip.AddrFrom16([16]byte(b)) return addr.Unmap().String() } @@ -129,10 +126,20 @@ func Unique[T comparable](s []T) []T { return ret } -func ConvertSeq[T, U any](seq iter.Seq[T], f func(e T) U) iter.Seq[U] { - return func(yield func(U) bool) { - for e := range seq { - if !yield(f(e)) { +func ConvertSeq[In, Out any](seq iter.Seq[In], f func(In) Out) iter.Seq[Out] { + return func(yield func(Out) bool) { + for in := range seq { + if !yield(f(in)) { + return + } + } + } +} + +func ConvertSeq2[KIn, VIn, KOut, VOut any](seq iter.Seq2[KIn, VIn], f func(KIn, VIn) (KOut, VOut)) iter.Seq2[KOut, VOut] { + return func(yield func(KOut, VOut) bool) { + for k, v := range seq { + if !yield(f(k, v)) { return } }