diff --git a/.env.example b/.env.example index 65541ae..442eed6 100644 --- a/.env.example +++ b/.env.example @@ -4,4 +4,4 @@ DB_USER=postgres DB_PASSWORD=postgres DB_NAME=nuklaivm DB_SSLMODE=require # Or "disable" if you don't want to use SSL -GRPC_WHITELISTED_BLOCKCHAIN_NODES="192.168.1.100" # "127.0.0.1,localhost,::1" is already included by default. You can even include something like myblockchain.aws.com +GRPC_WHITELISTED_BLOCKCHAIN_NODES="192.168.1.100,172.28.0.0/12," # "127.0.0.1,localhost,::1" is already included by default. You can even include something like myblockchain.aws.com diff --git a/config/config.go b/config/config.go index f95877a..a4441b5 100644 --- a/config/config.go +++ b/config/config.go @@ -2,6 +2,7 @@ package config import ( "fmt" + "log" "net" "net/url" "os" @@ -30,25 +31,33 @@ func GetDatabaseURL() string { // GetWhitelistIPs retrieves the list of whitelisted IPs from the environment variable // and resolves domain names to IPs. -func GetWhitelistIPs() []string { +// GetWhitelistIPs retrieves the list of whitelisted IPs and CIDR ranges +func GetWhitelistIPs() ([]string, []string) { ipList := getEnv("GRPC_WHITELISTED_BLOCKCHAIN_NODES", "127.0.0.1,localhost,::1") - entries := strings.Split(ipList, ",") // Split by comma + entries := strings.Split(ipList, ",") - whitelist := []string{} - defaultIPs := []string{"127.0.0.1", "localhost", "::1"} // Always include these + whitelistIPs := []string{} + whitelistCIDRs := []string{} + defaultEntries := []string{"127.0.0.1", "localhost", "::1"} - // Resolve domain names and add IPs to the whitelist - for _, entry := range append(defaultIPs, entries...) { + // Combine default entries and user-provided entries + for _, entry := range append(defaultEntries, entries...) { entry = strings.TrimSpace(entry) - ips, err := resolveToIPs(entry) - if err != nil { - // Log the error and skip unresolved entries - continue + if strings.Contains(entry, "/") { + // CIDR range + whitelistCIDRs = append(whitelistCIDRs, entry) + } else { + // IP or domain + ips, err := resolveToIPs(entry) + if err == nil { + whitelistIPs = append(whitelistIPs, ips...) + } else { + log.Printf("Failed to resolve: %s, skipping", entry) + } } - whitelist = append(whitelist, ips...) } - return uniqueStrings(whitelist) // Ensure no duplicates + return uniqueStrings(whitelistIPs), uniqueStrings(whitelistCIDRs) } // getEnv retrieves the value of the environment variable named by the key. diff --git a/grpc/whitelist.go b/grpc/whitelist.go index 0ba0f94..1eff6dc 100644 --- a/grpc/whitelist.go +++ b/grpc/whitelist.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "log" + "net" "strings" "github.com/nuklai/nuklaivm-external-subscriber/config" @@ -11,22 +12,50 @@ import ( "google.golang.org/grpc/peer" ) -var WhitelistedIPs = make(map[string]bool) +var ( + WhitelistedIPs = make(map[string]bool) + WhitelistedCIDRs []*net.IPNet +) // LoadWhitelist loads the whitelist using the config package func LoadWhitelist() { - ips := config.GetWhitelistIPs() - if len(ips) == 0 { - log.Println("No whitelisted IPs provided. The gRPC server will reject all connections.") - return - } + ips, cidrs := config.GetWhitelistIPs() - // Populate the whitelist map + // Load individual IPs for _, ip := range ips { WhitelistedIPs[ip] = true } - log.Printf("Loaded whitelisted IPs: %v\n", WhitelistedIPs) + // Load CIDR ranges + for _, cidr := range cidrs { + _, ipNet, err := net.ParseCIDR(cidr) + if err != nil { + log.Printf("Invalid CIDR range: %s", cidr) + continue + } + WhitelistedCIDRs = append(WhitelistedCIDRs, ipNet) + } + + log.Printf("Loaded whitelisted IPs: %v", WhitelistedIPs) + log.Printf("Loaded whitelisted CIDRs: %v", cidrs) +} + +// isAllowedIP checks if an IP is whitelisted +func isAllowedIP(clientIP string) bool { + // Check against individual IPs + if WhitelistedIPs[clientIP] { + return true + } + + // Check against CIDR ranges + ip := net.ParseIP(clientIP) + for _, ipNet := range WhitelistedCIDRs { + if ipNet.Contains(ip) { + return true + } + } + + return false } // UnaryInterceptor checks the IP of the client and allows/denies the connection @@ -36,8 +65,8 @@ func UnaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServ return nil, fmt.Errorf("could not retrieve peer info") } - clientIP := strings.Split(peerInfo.Addr.String(), ":")[0] // Extract IP address - if !WhitelistedIPs[clientIP] { + clientIP := strings.Split(peerInfo.Addr.String(), ":")[0] + if !isAllowedIP(clientIP) { log.Printf("Unauthorized connection attempt from IP: %s", clientIP) return nil, fmt.Errorf("unauthorized IP: %s", clientIP) }