From 6b99a287a67d9c713c67051c84813558e8a5eda7 Mon Sep 17 00:00:00 2001 From: lanthora Date: Fri, 19 Jul 2024 22:56:37 +0800 Subject: [PATCH] feat: add net api --- .gitignore | 1 + api/net.go | 149 ++++++ {user => api}/user.go | 81 ++- candy/device.go | 10 + candy/message.go | 74 +++ candy/net.go | 193 ++++++++ candy/util.go | 17 + candy/websocket.go | 465 ++++++++++++++++++ go.mod | 2 + go.sum | 4 + logger/logger.go | 24 +- main.go | 19 +- configs/configs.go => model/config.go | 19 +- .../configs_test.go => model/config_test.go | 2 +- model/device.go | 35 ++ model/net.go | 58 +++ model/user.go | 29 ++ status/status.go | 30 +- storage/storage.go | 4 +- 19 files changed, 1126 insertions(+), 90 deletions(-) create mode 100644 api/net.go rename {user => api}/user.go (71%) create mode 100644 candy/device.go create mode 100644 candy/message.go create mode 100644 candy/net.go create mode 100644 candy/util.go create mode 100644 candy/websocket.go rename configs/configs.go => model/config.go (91%) rename configs/configs_test.go => model/config_test.go (95%) create mode 100644 model/device.go create mode 100644 model/net.go create mode 100644 model/user.go diff --git a/.gitignore b/.gitignore index be66f20..3cc6ab7 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ cacao sqlite.db +cookie diff --git a/api/net.go b/api/net.go new file mode 100644 index 0000000..e21f91c --- /dev/null +++ b/api/net.go @@ -0,0 +1,149 @@ +package api + +import ( + "github.com/gin-gonic/gin" + "github.com/lanthora/cacao/candy" + "github.com/lanthora/cacao/model" + "github.com/lanthora/cacao/status" + "github.com/lanthora/cacao/storage" + "gorm.io/gorm" +) + +func NetShow(c *gin.Context) { + user := c.MustGet("user").(*model.User) + nets := model.GetNetsByUserID(user.ID) + + type netinfo struct { + ID uint `json:"netid"` + Name string `json:"netname"` + Password string `json:"password"` + DHCP string `json:"dhcp"` + Broadcast bool `json:"broadcast"` + } + + response := make([]netinfo, 0) + for _, n := range nets { + response = append(response, netinfo{ + ID: n.ID, + Name: n.Name, + Password: n.Password, + DHCP: n.DHCP, + Broadcast: n.Broadcast, + }) + } + + status.UpdateSuccess(c, gin.H{ + "nets": response, + }) +} + +func NetInsert(c *gin.Context) { + var request struct { + Name string `json:"netname"` + Password string `json:"password"` + DHCP string `json:"dhcp"` + Broadcast bool `json:"broadcast"` + } + + if err := c.BindJSON(&request); err != nil { + status.UpdateCode(c, status.InvalidRequest) + return + } + + user := c.MustGet("user").(*model.User) + modelNet := &model.Net{ + UserID: user.ID, + Name: request.Name, + } + + db := storage.Get() + result := db.Where(modelNet).Take(modelNet) + if result.Error != gorm.ErrRecordNotFound { + status.UpdateCode(c, status.NetworkAlreadyExists) + return + } + + modelNet.Password = request.Password + modelNet.DHCP = request.DHCP + modelNet.Broadcast = request.Broadcast + modelNet.Create() + candy.InsertNet(modelNet) + + status.UpdateSuccess(c, gin.H{ + "netid": modelNet.ID, + "netname": modelNet.Name, + "password": modelNet.Password, + "dhcp": modelNet.DHCP, + "broadcast": modelNet.Broadcast, + }) +} + +func NetEdit(c *gin.Context) { + var request struct { + ID uint `json:"netid"` + Name string `json:"netname"` + Password string `json:"password"` + DHCP string `json:"dhcp"` + Broadcast bool `json:"broadcast"` + } + + if err := c.BindJSON(&request); err != nil { + status.UpdateCode(c, status.InvalidRequest) + return + } + + user := c.MustGet("user").(*model.User) + modelNet := &model.Net{} + modelNet.ID = request.ID + db := storage.Get() + result := db.Where(modelNet).Take(modelNet) + + if result.Error != nil || modelNet.UserID != user.ID { + status.UpdateCode(c, status.NetworkNotExists) + return + } + + modelNet.Name = request.Name + modelNet.Password = request.Password + modelNet.DHCP = request.DHCP + modelNet.Broadcast = request.Broadcast + modelNet.Update() + candy.UpdateNet(modelNet) + + status.UpdateSuccess(c, gin.H{ + "netid": modelNet.ID, + "netname": modelNet.Name, + "password": modelNet.Password, + "dhcp": modelNet.DHCP, + "broadcast": modelNet.Broadcast, + }) +} + +func NetDelete(c *gin.Context) { + var request struct { + ID uint `json:"netid"` + } + + if err := c.BindJSON(&request); err != nil { + status.UpdateCode(c, status.InvalidRequest) + return + } + + user := c.MustGet("user").(*model.User) + modelNet := &model.Net{} + modelNet.ID = request.ID + db := storage.Get() + result := db.Where(modelNet).Take(modelNet) + + if result.Error != nil || modelNet.UserID != user.ID { + status.UpdateCode(c, status.NetworkNotExists) + return + } + + modelNet.Delete() + candy.DeleteNet(modelNet.ID) + + status.UpdateSuccess(c, gin.H{ + "id": modelNet.ID, + }) +} diff --git a/user/user.go b/api/user.go similarity index 71% rename from user/user.go rename to api/user.go index 7c64141..3a226c5 100644 --- a/user/user.go +++ b/api/user.go @@ -1,48 +1,18 @@ -package user +package api import ( - "crypto/sha256" - "fmt" - "regexp" "strconv" "strings" "time" "github.com/gin-gonic/gin" "github.com/google/uuid" - "github.com/lanthora/cacao/logger" + "github.com/lanthora/cacao/candy" + "github.com/lanthora/cacao/model" "github.com/lanthora/cacao/status" "github.com/lanthora/cacao/storage" - "gorm.io/gorm" ) -type User struct { - gorm.Model - Name string `gorm:"uniqueIndex"` - Password string - Token string - Role string - IP string -} - -func init() { - db := storage.Get() - err := db.AutoMigrate(User{}) - if err != nil { - logger.Fatal("auto migrate users failed: %v", err) - } -} - -func isAlphanumeric(s string) bool { - match, _ := regexp.MatchString("^[a-zA-Z0-9]+$", s) - return match -} - -func sha256sum(data []byte) string { - hash := sha256.Sum256(data) - return fmt.Sprintf("%x", hash[:]) -} - func LoginMiddleware() gin.HandlerFunc { return func(c *gin.Context) { path := c.Request.URL.String() @@ -67,7 +37,7 @@ func LoginMiddleware() gin.HandlerFunc { c.Abort() return } - user := &User{} + user := &model.User{} user.ID = uint(id) db := storage.Get() @@ -82,7 +52,7 @@ func LoginMiddleware() gin.HandlerFunc { } } -func Register(c *gin.Context) { +func UserRegister(c *gin.Context) { var request struct { Username string `json:"username"` Password string `json:"password"` @@ -91,7 +61,7 @@ func Register(c *gin.Context) { status.UpdateCode(c, status.InvalidRequest) return } - if len(request.Username) < 3 || !isAlphanumeric(request.Username) { + if len(request.Username) < 3 || !candy.IsAlphanumeric(request.Username) { status.UpdateCode(c, status.InvalidUsername) return } @@ -103,7 +73,7 @@ func Register(c *gin.Context) { db := storage.Get() if func() bool { count := int64(0) - db.Model(&User{}).Where(&User{IP: c.ClientIP(), Role: "normal"}).Where("created_at > ?", time.Now().Add(-24*time.Hour)).Count(&count) + db.Model(&model.User{}).Where(&model.User{IP: c.ClientIP(), Role: "normal"}).Where("created_at > ?", time.Now().Add(-24*time.Hour)).Count(&count) return count > 0 }() { status.UpdateCode(c, status.RegisterTooFrequently) @@ -112,7 +82,7 @@ func Register(c *gin.Context) { if func() bool { count := int64(0) - db.Model(&User{}).Where(&User{Name: request.Username}).Count(&count) + db.Model(&model.User{}).Where(&model.User{Name: request.Username}).Count(&count) return count > 0 }() { status.UpdateCode(c, status.UsernameAlreadyTaken) @@ -121,16 +91,16 @@ func Register(c *gin.Context) { role := func() string { count := int64(0) - db.Model(&User{}).Count(&count) + db.Model(&model.User{}).Count(&count) if count == 0 { return "admin" } return "normal" }() - user := User{ + user := model.User{ Name: request.Username, - Password: sha256sum([]byte(request.Password)), + Password: candy.Sha256sum([]byte(request.Password)), Token: uuid.NewString(), Role: role, IP: c.ClientIP(), @@ -148,9 +118,21 @@ func Register(c *gin.Context) { "name": user.Name, "role": user.Role, }) + + if role == "normal" { + modelNet := &model.Net{ + UserID: user.ID, + Name: "@", + Password: request.Password, + DHCP: "192.168.202.0/24", + Broadcast: true, + } + modelNet.Create() + candy.InsertNet(modelNet) + } } -func Login(c *gin.Context) { +func UserLogin(c *gin.Context) { var request struct { Username string `json:"username"` Password string `json:"password"` @@ -160,19 +142,20 @@ func Login(c *gin.Context) { return } - user := User{ + user := model.User{ Name: request.Username, - Password: sha256sum([]byte(request.Password)), + Password: candy.Sha256sum([]byte(request.Password)), } db := storage.Get() + if result := db.Where(user).Take(&user); result.Error != nil { status.UpdateCode(c, status.UsernameOrPasswordIncorrect) return } user.Token = uuid.NewString() - db.Save(user) + user.Save() c.SetCookie("id", strconv.FormatUint(uint64(user.ID), 10), 86400, "/", "", false, true) c.SetCookie("token", user.Token, 86400, "/", "", false, true) @@ -183,12 +166,10 @@ func Login(c *gin.Context) { }) } -func Logout(c *gin.Context) { - user := c.MustGet("user").(*User) +func UserLogout(c *gin.Context) { + user := c.MustGet("user").(*model.User) user.Token = uuid.NewString() - - db := storage.Get() - db.Save(user) + user.Save() c.SetCookie("id", "", -1, "/", "", false, true) c.SetCookie("token", "", -1, "/", "", false, true) diff --git a/candy/device.go b/candy/device.go new file mode 100644 index 0000000..66c3679 --- /dev/null +++ b/candy/device.go @@ -0,0 +1,10 @@ +package candy + +import ( + "github.com/lanthora/cacao/model" +) + +type Device struct { + model *model.Device + ip uint32 +} diff --git a/candy/message.go b/candy/message.go new file mode 100644 index 0000000..9d7aeb8 --- /dev/null +++ b/candy/message.go @@ -0,0 +1,74 @@ +package candy + +const ( + AUTH uint8 = 0 + FORWARD uint8 = 1 + DHCP uint8 = 2 + PEER uint8 = 3 + VMAC uint8 = 4 + DISCOVERY uint8 = 5 + ROUTE uint8 = 6 + GENERAL uint8 = 255 +) + +type AuthMessage struct { + Type uint8 `struc:"uint8"` + IP uint32 `struc:"uint32"` + Timestamp int64 `struc:"int64"` + Hash [32]byte `struc:"[32]byte"` +} + +type ForwardMessage struct { + Type uint8 `struc:"uint8"` + Unused [12]byte `struc:"[12]byte"` + Src uint32 `struc:"uint32"` + Dst uint32 `struc:"uint32"` +} + +type DHCPMessage struct { + Type uint8 `struc:"uint8"` + Timestamp int64 `struc:"int64"` + Cidr []byte `struc:"[32]byte"` + Hash [32]byte `struc:"[32]byte"` +} + +type PeerConnMessage struct { + Type uint8 `struc:"uint8"` + Src uint32 `struc:"uint32"` + Dst uint32 `struc:"uint32"` + IP uint32 `struc:"uint32"` + Port uint16 `struc:"uint16"` +} + +type VMacMessage struct { + Type uint8 `struc:"uint8"` + VMac string `struc:"[16]byte"` + Timestamp int64 `struc:"int64"` + Hash [32]byte `struc:"[32]byte"` +} + +type DiscoveryMessage struct { + Type uint8 `struc:"uint8"` + Src uint32 `struc:"uint32"` + Dst uint32 `struc:"uint32"` +} + +type RouteMessage struct { + Type uint8 `struc:"uint8"` + Size uint8 `struc:"uint8"` + Reserved uint16 `struc:"uint16"` +} + +type RouteMessageEntry struct { + Dest uint32 `struc:"uint32"` + Mask uint32 `struc:"uint32"` + NextHop uint32 `struc:"uint32"` +} + +type GeneralMessage struct { + Type uint8 `struc:"uint8"` + Subtype uint8 `struc:"uint8"` + Extra uint16 `struc:"uint16"` + Src uint32 `struc:"uint32"` + Dst uint32 `struc:"uint32"` +} diff --git a/candy/net.go b/candy/net.go new file mode 100644 index 0000000..5114ca8 --- /dev/null +++ b/candy/net.go @@ -0,0 +1,193 @@ +package candy + +import ( + "crypto/sha256" + "encoding/binary" + "fmt" + "math/rand/v2" + "net" + "strconv" + "sync" + "time" + + "github.com/lanthora/cacao/logger" + "github.com/lanthora/cacao/model" + "github.com/lanthora/cacao/storage" + "gorm.io/gorm" +) + +func init() { + idNetMap = map[uint]*Net{} + + for _, netModel := range model.GetNets() { + InsertNet(&netModel) + } +} + +type Net struct { + model *model.Net + ipWsMap map[uint32]*candysocket + ipWsMapMutex sync.RWMutex + + net uint32 + host uint32 + mask uint32 +} + +var idNetMap map[uint]*Net +var idNetMapMutex sync.RWMutex + +func (net *Net) ipConflict(ip, vmac string) bool { + db := storage.Get() + device := &model.Device{NetID: net.model.ID, IP: ip} + result := db.Where(device).Take(device) + if result.Error == gorm.ErrRecordNotFound { + return false + } + if result.Error == nil && device.VMac == vmac { + return false + } + + return true +} + +func (net *Net) checkAuthMessage(message *AuthMessage) error { + if absInt64(time.Now().Unix(), message.Timestamp) > 30 { + return fmt.Errorf("auth check failed: timestamp: %v", message.Timestamp) + } + + reported := message.Hash + + var data []byte + data = append(data, net.model.Password...) + data = binary.BigEndian.AppendUint32(data, message.IP) + data = binary.BigEndian.AppendUint64(data, uint64(message.Timestamp)) + + if sha256.Sum256([]byte(data)) != reported { + return fmt.Errorf("auth check failed: hash does not match") + } + return nil +} + +func (net *Net) checkDHCPMessage(message *DHCPMessage) error { + if absInt64(time.Now().Unix(), message.Timestamp) > 30 { + return fmt.Errorf("dhcp check failed: timestamp: %v", message.Timestamp) + } + + reported := message.Hash + + var data []byte + data = append(data, net.model.Password...) + data = binary.BigEndian.AppendUint64(data, uint64(message.Timestamp)) + + if sha256.Sum256([]byte(data)) != reported { + return fmt.Errorf("dhcp check failed: hash does not match") + } + return nil +} + +func (net *Net) checkVMacMessage(message *VMacMessage) error { + if absInt64(time.Now().Unix(), message.Timestamp) > 30 { + return fmt.Errorf("vmac check failed: timestamp: %v", message.Timestamp) + } + + if _, err := strconv.ParseUint(message.VMac, 16, 64); err != nil { + return fmt.Errorf("vmac check failed: invalid vmac") + } + + reported := message.Hash + + var data []byte + data = append(data, net.model.Password...) + data = append(data, message.VMac...) + data = binary.BigEndian.AppendUint64(data, uint64(message.Timestamp)) + + if sha256.Sum256([]byte(data)) != reported { + return fmt.Errorf("vmac check failed: hash does not match") + } + return nil +} + +func (net *Net) updateHost() string { + for ok := true; ok; ok = (net.host == 0 || net.host == ^net.mask) { + net.host = (net.host + 1) & (^net.mask) + } + return uint32ToStrIP(net.net | net.host) +} + +func (net *Net) Close() { + net.ipWsMapMutex.Lock() + defer net.ipWsMapMutex.Unlock() + for ip, ws := range net.ipWsMap { + ws.writeCloseMessage("net close") + ws.conn.Close() + delete(net.ipWsMap, ip) + } +} + +func InsertNet(netModel *model.Net) { + idNetMapMutex.Lock() + defer idNetMapMutex.Unlock() + + _, ipNet, err := net.ParseCIDR(netModel.DHCP) + if err != nil { + logger.Fatal("insert net failed: %v", err) + } + + netid := binary.BigEndian.Uint32(ipNet.IP) + mask := binary.BigEndian.Uint32(ipNet.Mask) + hostid := rand.Uint32() & ^mask + + if ^mask < 2 { + logger.Fatal("invalid net cidr: %v", netModel.DHCP) + } + + net := &Net{ + model: netModel, + ipWsMap: make(map[uint32]*candysocket), + net: netid, + host: hostid, + mask: mask, + } + net.updateHost() + idNetMap[netModel.ID] = net +} + +func UpdateNet(netModel *model.Net) { + DeleteNet(netModel.ID) + InsertNet(netModel) +} + +func DeleteNet(netid uint) { + idNetMapMutex.Lock() + defer idNetMapMutex.Unlock() + + if net, ok := idNetMap[netid]; ok { + net.Close() + } + + delete(idNetMap, netid) +} + +func getNetById(netid uint) *Net { + idNetMapMutex.RLock() + defer idNetMapMutex.RUnlock() + + if net, ok := idNetMap[netid]; ok { + return net + } + return nil +} + +func absInt64(a, b int64) int64 { + if a > b { + return a - b + } + return b - a +} + +func uint32ToStrIP(ip uint32) string { + var buffer []byte = make([]byte, 4) + binary.BigEndian.PutUint32(buffer, ip) + return net.IP(buffer).String() +} diff --git a/candy/util.go b/candy/util.go new file mode 100644 index 0000000..c286b1c --- /dev/null +++ b/candy/util.go @@ -0,0 +1,17 @@ +package candy + +import ( + "crypto/sha256" + "fmt" + "regexp" +) + +func IsAlphanumeric(s string) bool { + match, _ := regexp.MatchString("^[a-zA-Z0-9]+$", s) + return match +} + +func Sha256sum(data []byte) string { + hash := sha256.Sum256(data) + return fmt.Sprintf("%x", hash[:]) +} diff --git a/candy/websocket.go b/candy/websocket.go new file mode 100644 index 0000000..eeb9b6b --- /dev/null +++ b/candy/websocket.go @@ -0,0 +1,465 @@ +package candy + +import ( + "bytes" + "encoding/binary" + "fmt" + "net" + "net/http" + "strings" + "sync" + "time" + + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" + "github.com/lanthora/cacao/logger" + "github.com/lanthora/cacao/model" + "github.com/lanthora/cacao/storage" + "github.com/lunixbochs/struc" +) + +func WebsocketMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + if c.GetHeader("Upgrade") == "websocket" { + handleWebsocket(c) + c.Abort() + } else { + c.Next() + } + } +} + +func handleWebsocket(c *gin.Context) { + upgrader := websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, + } + conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) + if err != nil { + logger.Debug("websocket upgrade failed: %v", err) + return + } + defer conn.Close() + net := getNetFromPath(c.Request.URL.Path) + if net == nil { + logger.Debug("net not found: %v", c.Request.URL.Path) + return + } + ws := &candysocket{ctx: c, conn: conn, net: net} + conn.SetPingHandler(func(buffer string) error { return ws.handlePingMessage(buffer) }) + + for { + ws.updateReadDeadline() + messageType, buffer, err := conn.ReadMessage() + if err != nil { + logger.Debug("read websocket failed: %v", err) + break + } + if messageType != websocket.BinaryMessage { + continue + } + switch uint8(buffer[0]) { + case AUTH: + err = ws.handleAuthMessage(buffer) + case FORWARD: + err = ws.handleForwardMessage(buffer) + case DHCP: + err = ws.handleDHCPMessage(buffer) + case PEER: + err = ws.handlePeerConnMessage(buffer) + case VMAC: + err = ws.handleVMacMessage(buffer) + case DISCOVERY: + err = ws.handleDiscoveryMessage(buffer) + case GENERAL: + err = ws.handleGeneralMessage(buffer) + } + if err != nil { + logger.Debug("handle client message failed: %v", err) + break + } + } + + if ws.dev != nil && ws.dev.model.Online { + ws.dev.model.Online = false + ws.dev.model.Save() + + net.ipWsMapMutex.Lock() + defer net.ipWsMapMutex.Unlock() + delete(net.ipWsMap, ws.dev.ip) + } +} + +type candysocket struct { + ctx *gin.Context + conn *websocket.Conn + connMutex sync.Mutex + dev *Device + net *Net +} + +func (ws *candysocket) updateReadDeadline() error { + ws.connMutex.Lock() + defer ws.connMutex.Unlock() + return ws.conn.SetReadDeadline((time.Now().Add(60 * time.Second))) +} + +func (ws *candysocket) writeCloseMessage(text string) error { + ws.connMutex.Lock() + defer ws.connMutex.Unlock() + return ws.conn.WriteMessage(websocket.CloseMessage, []byte(text)) +} + +func (ws *candysocket) writeMessage(buffer []byte) error { + ws.connMutex.Lock() + defer ws.connMutex.Unlock() + return ws.conn.WriteMessage(websocket.BinaryMessage, buffer) +} + +func (ws *candysocket) writePong(buffer []byte) error { + ws.connMutex.Lock() + defer ws.connMutex.Unlock() + return ws.conn.WriteMessage(websocket.PongMessage, buffer) +} + +func (ws *candysocket) handlePingMessage(buffer string) error { + ws.updateReadDeadline() + + err := func() error { + ws.net.ipWsMapMutex.RLock() + defer ws.net.ipWsMapMutex.RUnlock() + + if ws.dev.model == nil { + return fmt.Errorf("ping failed: the client is not logged in: %v", buffer) + } + + if ws.dev.model.Online { + ws.dev.model.Save() + } + + info := strings.Split(buffer, "::") + if len(info) < 3 || info[0] != "candy" { + return fmt.Errorf("ping failed: invalid format: %v", buffer) + } + + ws.dev.model.OS = info[1] + ws.dev.model.Version = info[2] + + if len(info) > 3 { + ws.dev.model.Hostname = info[3] + } + + return nil + }() + + if err != nil { + logger.Debug("client exception: %v", err) + } + + ws.writePong([]byte(buffer)) + return nil +} + +func (ws *candysocket) handleAuthMessage(buffer []byte) error { + r := bytes.NewReader(buffer) + message := &AuthMessage{} + if err := struc.Unpack(r, message); err != nil { + return err + } + + if err := ws.net.checkAuthMessage(message); err != nil { + return err + } + + if ws.dev == nil { + return fmt.Errorf("auth failed: vmac not received") + } + + if ws.net.net != ws.net.mask&message.IP { + return fmt.Errorf("auth failed: network does not match") + } + + if ws.net.ipConflict(uint32ToStrIP(message.IP), ws.dev.model.VMac) { + ws.writeCloseMessage("ip conflict") + return fmt.Errorf("auth failed: ip conflict: %v", uint32ToStrIP(message.IP)) + } + + ws.net.ipWsMapMutex.Lock() + defer ws.net.ipWsMapMutex.Unlock() + + if oldws, ok := ws.net.ipWsMap[message.IP]; ok { + oldws.dev.model.Online = false + oldws.dev.model.Save() + oldws.writeCloseMessage("vmac conflict") + oldws.conn.Close() + } + + ws.dev.ip = message.IP + ws.net.ipWsMap[message.IP] = ws + + db := storage.Get() + db.Where(ws.dev.model).Find(ws.dev.model) + ws.dev.model.IP = uint32ToStrIP(message.IP) + ws.dev.model.Online = true + ws.dev.model.Save() + return nil +} + +func (ws *candysocket) handleForwardMessage(buffer []byte) error { + if ws.dev == nil { + return fmt.Errorf("forward failed: conn is not logged in") + } + + if !ws.dev.model.Online { + return nil + } + + r := bytes.NewReader(buffer) + message := &ForwardMessage{} + if err := struc.Unpack(r, message); err != nil { + return err + } + + if ws.dev.ip != message.Src { + return fmt.Errorf("forward failed: source address does not match login information") + } + + ws.dev.model.TX += uint64(len(buffer)) + + ws.net.ipWsMapMutex.RLock() + defer ws.net.ipWsMapMutex.RUnlock() + + if dstWs, ok := ws.net.ipWsMap[message.Dst]; ok { + dstWs.writeMessage(buffer) + dstWs.dev.model.RX += uint64(len(buffer)) + } + + broadcast := func() bool { + if !ws.net.model.Broadcast { + return false + } + if ws.net.net|^ws.net.mask == message.Dst { + return true + } + if message.Dst&0xF0000000 == 0xE0000000 { + return true + } + return false + }() + + if broadcast { + for _, dstWs := range ws.net.ipWsMap { + if dstWs != ws && dstWs.dev.model.Online { + dstWs.writeMessage(buffer) + dstWs.dev.model.RX += uint64(len(buffer)) + } + } + } + + return nil +} + +func (ws *candysocket) handleDHCPMessage(buffer []byte) error { + r := bytes.NewReader(buffer) + message := &DHCPMessage{} + if err := struc.Unpack(r, message); err != nil { + return err + } + + if err := ws.net.checkDHCPMessage(message); err != nil { + return err + } + + if ws.net.model.DHCP == "" { + return fmt.Errorf("dhcp failed: DHCP is not enabled") + } + + cidr := func(input []byte) string { + return string(input[:bytes.IndexByte(input[:], 0)]) + }(message.Cidr) + + if ws.dev.model == nil { + return fmt.Errorf("dhcp failed: vmac not received") + } + db := storage.Get() + ip, ipNet, err := net.ParseCIDR(cidr) + needGenNewAddr := func() bool { + if err != nil { + return true + } + if binary.BigEndian.Uint32(ipNet.IP) != ws.net.net { + return true + } + if binary.BigEndian.Uint32(ipNet.Mask) != ws.net.mask { + return true + } + devices := []model.Device{} + db.Where(&model.Device{NetID: ws.net.model.ID, IP: ip.String()}).Find(&devices) + if len(devices) > 1 { + return true + } + if len(devices) == 0 { + return false + } + if len(devices) == 1 && devices[0].VMac == ws.dev.model.VMac { + return false + } + return true + }() + + var oldHost = ws.net.host + for needGenNewAddr { + + result := db.Where(&model.Device{NetID: ws.net.model.ID, IP: ws.net.updateHost()}) + if result.RowsAffected == 0 { + break + } + if oldHost == ws.net.host { + return fmt.Errorf("dhcp failed: not enough addresses") + } + } + + if needGenNewAddr { + ipNet := net.IPNet{ + IP: make(net.IP, 4), + Mask: make(net.IPMask, 4), + } + binary.BigEndian.PutUint32(ipNet.IP, ws.net.net|ws.net.host) + binary.BigEndian.PutUint32(ipNet.Mask, ws.net.mask) + message.Cidr = []byte(ipNet.String()) + } + + var output bytes.Buffer + struc.Pack(&output, message) + ws.writeMessage(output.Bytes()) + return nil +} + +func (ws *candysocket) handlePeerConnMessage(buffer []byte) error { + if ws.dev == nil { + return fmt.Errorf("peer conn failed: conn is not logged in") + } + + r := bytes.NewReader(buffer) + message := &PeerConnMessage{} + if err := struc.Unpack(r, message); err != nil { + return err + } + + if ws.dev.ip != message.Src { + return fmt.Errorf("peer conn failed: source address does not match login information") + } + + ws.net.ipWsMapMutex.RLock() + defer ws.net.ipWsMapMutex.RUnlock() + + if dstWs, ok := ws.net.ipWsMap[message.Dst]; ok { + dstWs.writeMessage(buffer) + } + + return nil +} + +func (ws *candysocket) handleVMacMessage(buffer []byte) error { + r := bytes.NewReader(buffer) + message := &VMacMessage{} + if err := struc.Unpack(r, message); err != nil { + return err + } + + if err := ws.net.checkVMacMessage(message); err != nil { + return err + } + ws.dev = &Device{model: &model.Device{NetID: ws.net.model.ID, VMac: message.VMac}} + return nil +} + +func (ws *candysocket) handleDiscoveryMessage(buffer []byte) error { + if ws.dev == nil || !ws.dev.model.Online { + return nil + } + + r := bytes.NewReader(buffer) + message := &DiscoveryMessage{} + if err := struc.Unpack(r, message); err != nil { + return err + } + + if ws.dev.ip != message.Src { + return fmt.Errorf("discovery failed: source address does not match login information") + } + + ws.dev.model.TX += uint64(len(buffer)) + + ws.net.ipWsMapMutex.RLock() + defer ws.net.ipWsMapMutex.RUnlock() + + if dstWs, ok := ws.net.ipWsMap[message.Dst]; ok { + dstWs.writeMessage(buffer) + dstWs.dev.model.RX += uint64(len(buffer)) + } + + if uint32(0xFFFFFFFF) == message.Dst { + for _, dstWs := range ws.net.ipWsMap { + if dstWs != ws && dstWs.dev.model.Online { + dstWs.writeMessage(buffer) + dstWs.dev.model.RX += uint64(len(buffer)) + } + } + } + + return nil +} + +func (ws *candysocket) handleGeneralMessage(buffer []byte) error { + if ws.dev == nil || !ws.dev.model.Online { + return nil + } + + r := bytes.NewReader(buffer) + message := &GeneralMessage{} + if err := struc.Unpack(r, message); err != nil { + return err + } + + if ws.dev.ip != message.Src { + return fmt.Errorf("general failed: source address does not match login information") + } + + ws.dev.model.TX += uint64(len(buffer)) + + ws.net.ipWsMapMutex.RLock() + defer ws.net.ipWsMapMutex.RUnlock() + + if dstWs, ok := ws.net.ipWsMap[message.Dst]; ok { + dstWs.writeMessage(buffer) + dstWs.dev.model.RX += uint64(len(buffer)) + } + + if ws.net.model.Broadcast && uint32(0xFFFFFFFF) == message.Dst { + for _, dstWs := range ws.net.ipWsMap { + if dstWs != ws && dstWs.dev.model.Online { + dstWs.writeMessage(buffer) + dstWs.dev.model.RX += uint64(len(buffer)) + } + } + } + + return nil +} + +func getNetFromPath(path string) *Net { + result := strings.Split(strings.Trim(path, "/"), "/") + if len(result) < 1 { + return nil + } + username := result[0] + netname := "@" + if len(result) > 1 { + if !IsAlphanumeric(result[1]) { + return nil + } + netname = result[1] + } + netid := model.GetNetIdByUsernameAndNetname(username, netname) + return getNetById(netid) +} diff --git a/go.mod b/go.mod index 4c739d4..23f9316 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,8 @@ require ( github.com/gin-gonic/gin v1.10.0 github.com/glebarez/sqlite v1.11.0 github.com/google/uuid v1.6.0 + github.com/gorilla/websocket v1.5.3 + github.com/lunixbochs/struc v0.0.0-20200707160740-784aaebc1d40 github.com/sirupsen/logrus v1.9.3 gorm.io/gorm v1.25.11 ) diff --git a/go.sum b/go.sum index dd27fa3..07c5777 100644 --- a/go.sum +++ b/go.sum @@ -38,6 +38,8 @@ github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbu github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= @@ -50,6 +52,8 @@ github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZY github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= +github.com/lunixbochs/struc v0.0.0-20200707160740-784aaebc1d40 h1:EnfXoSqDfSNJv0VBNqY/88RNnhSGYkrHaO0mmFGbVsc= +github.com/lunixbochs/struc v0.0.0-20200707160740-784aaebc1d40/go.mod h1:vy1vK6wD6j7xX6O6hXe621WabdtNkou2h7uRtTfRMyg= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= diff --git a/logger/logger.go b/logger/logger.go index 996b0c3..c43f3ed 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -8,18 +8,6 @@ import ( "github.com/sirupsen/logrus" ) -var logger *logrus.Logger - -type logFormatter struct{} - -func (f *logFormatter) Format(entry *logrus.Entry) ([]byte, error) { - b := &bytes.Buffer{} - timestamp := entry.Time.Format("2006-01-02 15:04:05") - msg := fmt.Sprintf("[%s] [%s] %s\n", timestamp, entry.Level, entry.Message) - b.WriteString(msg) - return b.Bytes(), nil -} - func init() { logger = logrus.New() logger.SetReportCaller(true) @@ -35,6 +23,18 @@ func init() { Info("loglevel=[%v]", logger.GetLevel().String()) } +var logger *logrus.Logger + +type logFormatter struct{} + +func (f *logFormatter) Format(entry *logrus.Entry) ([]byte, error) { + b := &bytes.Buffer{} + timestamp := entry.Time.Format("2006-01-02 15:04:05") + msg := fmt.Sprintf("[%s] [%s] %s\n", timestamp, entry.Level, entry.Message) + b.WriteString(msg) + return b.Bytes(), nil +} + func Fatal(format string, args ...interface{}) { logger.Fatalf(format, args...) } diff --git a/main.go b/main.go index 2ec4a63..6b0bfb9 100644 --- a/main.go +++ b/main.go @@ -2,9 +2,10 @@ package main import ( "github.com/gin-gonic/gin" + "github.com/lanthora/cacao/api" "github.com/lanthora/cacao/argp" + "github.com/lanthora/cacao/candy" "github.com/lanthora/cacao/logger" - "github.com/lanthora/cacao/user" ) func init() { @@ -16,10 +17,18 @@ func main() { logger.Info("listen=[%v]", addr) r := gin.New() - r.Use(user.LoginMiddleware()) - r.POST("/api/user/register", user.Register) - r.POST("/api/user/login", user.Login) - r.POST("/api/user/logout", user.Logout) + r.Use(candy.WebsocketMiddleware(), api.LoginMiddleware()) + + user := r.Group("/api/user") + user.POST("/register", api.UserRegister) + user.POST("/login", api.UserLogin) + user.POST("/logout", api.UserLogout) + + net := r.Group("/api/net") + net.POST("/show", api.NetShow) + net.POST("/insert", api.NetInsert) + net.POST("/edit", api.NetEdit) + net.POST("/delete", api.NetDelete) if err := r.Run(addr); err != nil { logger.Fatal("service run failed: %v", err) diff --git a/configs/configs.go b/model/config.go similarity index 91% rename from configs/configs.go rename to model/config.go index 7e0c800..9f3df3e 100644 --- a/configs/configs.go +++ b/model/config.go @@ -1,4 +1,4 @@ -package configs +package model import ( "github.com/lanthora/cacao/logger" @@ -6,18 +6,23 @@ import ( "gorm.io/gorm" ) +func init() { + db := storage.Get() + + if err := db.AutoMigrate(Config{}); err != nil { + logger.Fatal("auto migrate configs failed: %v", err) + } +} + type Config struct { gorm.Model Key string `gorm:"uniqueIndex"` Value string } -func init() { +func (c *Config) Save() { db := storage.Get() - - if err := db.AutoMigrate(Config{}); err != nil { - logger.Fatal("auto migrate configs failed: %v", err) - } + db.Save(c) } func SetString(key string, value string) { @@ -25,7 +30,7 @@ func SetString(key string, value string) { config := &Config{Key: key} db.Take(config) config.Value = value - db.Save(config) + config.Save() } func GetString(key string) (value string, ok bool) { diff --git a/configs/configs_test.go b/model/config_test.go similarity index 95% rename from configs/configs_test.go rename to model/config_test.go index 3811318..da205aa 100644 --- a/configs/configs_test.go +++ b/model/config_test.go @@ -1,4 +1,4 @@ -package configs +package model import ( "testing" diff --git a/model/device.go b/model/device.go new file mode 100644 index 0000000..94c3a4b --- /dev/null +++ b/model/device.go @@ -0,0 +1,35 @@ +package model + +import ( + "github.com/lanthora/cacao/logger" + "github.com/lanthora/cacao/storage" + "gorm.io/gorm" +) + +func init() { + db := storage.Get() + err := db.AutoMigrate(Device{}) + if err != nil { + logger.Fatal("auto migrate devices failed: %v", err) + } +} + +type Device struct { + gorm.Model + NetID uint + VMac string + IP string + Online bool + RX uint64 + TX uint64 + OS string + Version string + Hostname string +} + +func (d *Device) Save() { + db := storage.Get() + if result := db.Save(d); result.Error != nil { + logger.Debug("save device failed: %v", result.Error) + } +} diff --git a/model/net.go b/model/net.go new file mode 100644 index 0000000..2fba6f3 --- /dev/null +++ b/model/net.go @@ -0,0 +1,58 @@ +package model + +import ( + "github.com/lanthora/cacao/logger" + "github.com/lanthora/cacao/storage" + "gorm.io/gorm" +) + +func init() { + db := storage.Get() + err := db.AutoMigrate(Net{}) + if err != nil { + logger.Fatal("auto migrate nets failed: %v", err) + } +} + +type Net struct { + gorm.Model + UserID uint `gorm:"index:idx_net"` + Name string `gorm:"index:idx_net"` + Password string + DHCP string + Broadcast bool +} + +func (n *Net) Create() { + db := storage.Get() + db.Create(n) +} + +func (n *Net) Update() { + db := storage.Get() + db.Updates(n) +} + +func (n *Net) Delete() { + db := storage.Get() + db.Unscoped().Delete(n) +} + +func GetNets() (nets []Net) { + db := storage.Get() + db.Find(&nets) + return +} + +func GetNetsByUserID(userid uint) (nets []Net) { + db := storage.Get() + db.Where(&Net{UserID: userid}).Find(&nets) + return +} + +func GetNetIdByUsernameAndNetname(username, netname string) uint { + netid := uint(0) + db := storage.Get() + db.Model(&Net{}).Select("nets.id").Joins("left join users on users.id = nets.user_id").Where("users.name=? and nets.name = ?", username, netname).Take(&netid) + return netid +} diff --git a/model/user.go b/model/user.go new file mode 100644 index 0000000..a7f5a67 --- /dev/null +++ b/model/user.go @@ -0,0 +1,29 @@ +package model + +import ( + "github.com/lanthora/cacao/logger" + "github.com/lanthora/cacao/storage" + "gorm.io/gorm" +) + +func init() { + db := storage.Get() + err := db.AutoMigrate(User{}) + if err != nil { + logger.Fatal("auto migrate users failed: %v", err) + } +} + +type User struct { + gorm.Model + Name string `gorm:"uniqueIndex"` + Password string + Token string + Role string + IP string +} + +func (u *User) Save() { + db := storage.Get() + db.Save(u) +} diff --git a/status/status.go b/status/status.go index 1eafc30..e99856c 100644 --- a/status/status.go +++ b/status/status.go @@ -6,6 +6,21 @@ import ( "github.com/gin-gonic/gin" ) +func init() { + statusMessage = make(map[int]string) + statusMessage[Success] = "success" + statusMessage[Unexpected] = "unexpected" + statusMessage[NotLoggedIn] = "not logged in" + statusMessage[InvalidRequest] = "invalid request " + statusMessage[InvalidUsername] = "invalid username" + statusMessage[InvalidPassword] = "invalid password" + statusMessage[RegisterTooFrequently] = "register too frequently" + statusMessage[UsernameAlreadyTaken] = "username already taken" + statusMessage[UsernameOrPasswordIncorrect] = "username or password is incorret" + statusMessage[NetworkAlreadyExists] = "network already exists" + statusMessage[NetworkNotExists] = "network not exists" +} + const ( Success int = iota Unexpected @@ -16,23 +31,12 @@ const ( RegisterTooFrequently UsernameAlreadyTaken UsernameOrPasswordIncorrect + NetworkAlreadyExists + NetworkNotExists ) var statusMessage map[int]string -func init() { - statusMessage = make(map[int]string) - statusMessage[Success] = "success" - statusMessage[Unexpected] = "unexpected" - statusMessage[NotLoggedIn] = "not logged in" - statusMessage[InvalidRequest] = "invalid request " - statusMessage[InvalidUsername] = "invalid username" - statusMessage[InvalidPassword] = "invalid password" - statusMessage[RegisterTooFrequently] = "register too frequently" - statusMessage[UsernameAlreadyTaken] = "username already taken" - statusMessage[UsernameOrPasswordIncorrect] = "username or password is incorret" -} - func UpdateSuccess(c *gin.Context, data gin.H) { c.JSON(http.StatusOK, gin.H{ "status": Success, diff --git a/storage/storage.go b/storage/storage.go index 9be42e1..da86783 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -11,8 +11,6 @@ import ( gormlogger "gorm.io/gorm/logger" ) -var db *gorm.DB - func init() { storageDir := argp.Get("storage", ".") err := os.MkdirAll(storageDir, os.ModeDir|os.ModePerm) @@ -29,6 +27,8 @@ func init() { logger.Info("storage=[%v]", storageDir) } +var db *gorm.DB + func Get() *gorm.DB { return db }