diff --git a/lib/attack.go b/lib/attack.go index 80df5e1a..8467b943 100644 --- a/lib/attack.go +++ b/lib/attack.go @@ -375,6 +375,33 @@ func DNSCaching(ttl time.Duration) func(*Attacker) { } } +// firstOfEachIPFamily returns the first IP of each IP family in the input slice. +func firstOfEachIPFamily(ips []string) []string { + if len(ips) == 0 { + return ips + } + + var ( + lastV4 bool + each = ips[:0] + ) + + for i := 0; i < len(ips) && len(each) < 2; i++ { + ip := net.ParseIP(ips[i]) + if ip == nil { + continue + } + + isV4 := ip.To4() != nil + if len(each) == 0 || isV4 != lastV4 { + each = append(each, ips[i]) + lastV4 = isV4 + } + } + + return each +} + type attack struct { name string began time.Time diff --git a/lib/attack_test.go b/lib/attack_test.go index c6f2c368..4ae0145f 100644 --- a/lib/attack_test.go +++ b/lib/attack_test.go @@ -17,6 +17,8 @@ import ( "strings" "testing" "time" + + "github.com/google/go-cmp/cmp" ) func TestAttackRate(t *testing.T) { @@ -420,3 +422,79 @@ func TestDNSCaching_Issue649(t *testing.T) { atk := NewAttacker(DNSCaching(0)) _ = atk.hit(tr, &attack{name: "TEST", began: time.Now()}) } + +func TestFirstOfEachIPFamily(t *testing.T) { + tests := []struct { + name string + input []string + want []string + }{ + { + name: "empty list", + input: []string{}, + want: []string{}, + }, + { + name: "single IPv4", + input: []string{"192.168.1.1"}, + want: []string{"192.168.1.1"}, + }, + { + name: "single IPv6", + input: []string{"fe80::1"}, + want: []string{"fe80::1"}, + }, + { + name: "multiple IPv6", + input: []string{"fe80::1", "fe80::2"}, + want: []string{"fe80::1"}, + }, + { + name: "one IPv4 and one IPv6", + input: []string{"192.168.1.1", "fe80::1"}, + want: []string{"192.168.1.1", "fe80::1"}, + }, + { + name: "one IPv6 and one IPv4", + input: []string{"fe80::1", "192.168.1.1"}, + want: []string{"fe80::1", "192.168.1.1"}, + }, + { + name: "multiple IPs with alternating versions", + input: []string{"192.168.1.1", "fe80::1", "192.168.1.2", "fe80::2"}, + want: []string{"192.168.1.1", "fe80::1"}, + }, + { + name: "multiple IPs with same versions", + input: []string{"192.168.1.1", "192.168.1.2", "192.168.1.3"}, + want: []string{"192.168.1.1"}, + }, + { + name: "multiple IPs with non-alternating versions", + input: []string{"192.168.1.1", "fe80::1", "192.168.1.2", "192.168.1.3", "fe80::2"}, + want: []string{"192.168.1.1", "fe80::1"}, + }, + { + name: "invalid IP addresses", + input: []string{"invalid", "192.168.1.1", "fe80::1"}, + want: []string{"192.168.1.1", "fe80::1"}, + }, + { + name: "IPv4 with embedded IPv6", + input: []string{"192.168.1.1", "::ffff:c000:280", "fe80::1"}, + want: []string{"192.168.1.1", "fe80::1"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := firstOfEachIPFamily(tt.input) + if len(result) != len(tt.want) { + t.Fatalf("want %v, got %v", tt.want, result) + } + if diff := cmp.Diff(tt.want, result); diff != "" { + t.Errorf("unexpected result (-want +got):\n%s", diff) + } + }) + } +}