Skip to content

Commit

Permalink
Added the capability to handle cidrs in addition to ip addresses
Browse files Browse the repository at this point in the history
  • Loading branch information
kpachhai committed Dec 2, 2024
1 parent 5c86ecf commit 61f8b43
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 23 deletions.
2 changes: 1 addition & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ DB_USER=postgres
DB_PASSWORD=postgres
DB_NAME=nuklaivm
DB_SSLMODE=require # Or "disable" if you don't want to use SSL
GRPC_WHITELISTED_BLOCKCHAIN_NODES="192.168.1.100" # "127.0.0.1,localhost,::1" is already included by default. You can even include something like myblockchain.aws.com
GRPC_WHITELISTED_BLOCKCHAIN_NODES="192.168.1.100,172.28.0.0/12," # "127.0.0.1,localhost,::1" is already included by default. You can even include something like myblockchain.aws.com
33 changes: 21 additions & 12 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package config

import (
"fmt"
"log"
"net"
"net/url"
"os"
Expand Down Expand Up @@ -30,25 +31,33 @@ func GetDatabaseURL() string {

// GetWhitelistIPs retrieves the list of whitelisted IPs from the environment variable
// and resolves domain names to IPs.
func GetWhitelistIPs() []string {
// GetWhitelistIPs retrieves the list of whitelisted IPs and CIDR ranges
func GetWhitelistIPs() ([]string, []string) {
ipList := getEnv("GRPC_WHITELISTED_BLOCKCHAIN_NODES", "127.0.0.1,localhost,::1")
entries := strings.Split(ipList, ",") // Split by comma
entries := strings.Split(ipList, ",")

whitelist := []string{}
defaultIPs := []string{"127.0.0.1", "localhost", "::1"} // Always include these
whitelistIPs := []string{}
whitelistCIDRs := []string{}
defaultEntries := []string{"127.0.0.1", "localhost", "::1"}

// Resolve domain names and add IPs to the whitelist
for _, entry := range append(defaultIPs, entries...) {
// Combine default entries and user-provided entries
for _, entry := range append(defaultEntries, entries...) {
entry = strings.TrimSpace(entry)
ips, err := resolveToIPs(entry)
if err != nil {
// Log the error and skip unresolved entries
continue
if strings.Contains(entry, "/") {
// CIDR range
whitelistCIDRs = append(whitelistCIDRs, entry)
} else {
// IP or domain
ips, err := resolveToIPs(entry)
if err == nil {
whitelistIPs = append(whitelistIPs, ips...)
} else {
log.Printf("Failed to resolve: %s, skipping", entry)
}
}
whitelist = append(whitelist, ips...)
}

return uniqueStrings(whitelist) // Ensure no duplicates
return uniqueStrings(whitelistIPs), uniqueStrings(whitelistCIDRs)
}

// getEnv retrieves the value of the environment variable named by the key.
Expand Down
49 changes: 39 additions & 10 deletions grpc/whitelist.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,58 @@ import (
"context"
"fmt"
"log"
"net"
"strings"

"github.com/nuklai/nuklaivm-external-subscriber/config"
"google.golang.org/grpc"
"google.golang.org/grpc/peer"
)

var WhitelistedIPs = make(map[string]bool)
var (
WhitelistedIPs = make(map[string]bool)
WhitelistedCIDRs []*net.IPNet
)

// LoadWhitelist loads the whitelist using the config package
func LoadWhitelist() {
ips := config.GetWhitelistIPs()
if len(ips) == 0 {
log.Println("No whitelisted IPs provided. The gRPC server will reject all connections.")
return
}
ips, cidrs := config.GetWhitelistIPs()

// Populate the whitelist map
// Load individual IPs
for _, ip := range ips {
WhitelistedIPs[ip] = true
}

log.Printf("Loaded whitelisted IPs: %v\n", WhitelistedIPs)
// Load CIDR ranges
for _, cidr := range cidrs {
_, ipNet, err := net.ParseCIDR(cidr)
if err != nil {
log.Printf("Invalid CIDR range: %s", cidr)
continue
}
WhitelistedCIDRs = append(WhitelistedCIDRs, ipNet)
}

log.Printf("Loaded whitelisted IPs: %v", WhitelistedIPs)
log.Printf("Loaded whitelisted CIDRs: %v", cidrs)
}

// isAllowedIP checks if an IP is whitelisted
func isAllowedIP(clientIP string) bool {
// Check against individual IPs
if WhitelistedIPs[clientIP] {
return true
}

// Check against CIDR ranges
ip := net.ParseIP(clientIP)
for _, ipNet := range WhitelistedCIDRs {
if ipNet.Contains(ip) {
return true
}
}

return false
}

// UnaryInterceptor checks the IP of the client and allows/denies the connection
Expand All @@ -36,8 +65,8 @@ func UnaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServ
return nil, fmt.Errorf("could not retrieve peer info")
}

clientIP := strings.Split(peerInfo.Addr.String(), ":")[0] // Extract IP address
if !WhitelistedIPs[clientIP] {
clientIP := strings.Split(peerInfo.Addr.String(), ":")[0]
if !isAllowedIP(clientIP) {
log.Printf("Unauthorized connection attempt from IP: %s", clientIP)
return nil, fmt.Errorf("unauthorized IP: %s", clientIP)
}
Expand Down

0 comments on commit 61f8b43

Please sign in to comment.