Skip to content

Commit

Permalink
lib: fix IP selection in DNS caching
Browse files Browse the repository at this point in the history
Fixes #677
  • Loading branch information
tsenart committed Jul 27, 2024
1 parent b0b14b9 commit 1799abc
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 0 deletions.
27 changes: 27 additions & 0 deletions lib/attack.go
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,33 @@ func DNSCaching(ttl time.Duration) func(*Attacker) {
}
}

// firstOfEachIPFamily returns the first IP of each IP family in the input slice.
func firstOfEachIPFamily(ips []string) []string {
if len(ips) == 0 {
return ips
}

var (
lastV4 bool
each = ips[:0]
)

for i := 0; i < len(ips) && len(each) < 2; i++ {
ip := net.ParseIP(ips[i])
if ip == nil {
continue
}

isV4 := ip.To4() != nil
if len(each) == 0 || isV4 != lastV4 {
each = append(each, ips[i])
lastV4 = isV4
}
}

return each
}

type attack struct {
name string
began time.Time
Expand Down
78 changes: 78 additions & 0 deletions lib/attack_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ import (
"strings"
"testing"
"time"

"github.com/google/go-cmp/cmp"
)

func TestAttackRate(t *testing.T) {
Expand Down Expand Up @@ -420,3 +422,79 @@ func TestDNSCaching_Issue649(t *testing.T) {
atk := NewAttacker(DNSCaching(0))
_ = atk.hit(tr, &attack{name: "TEST", began: time.Now()})
}

func TestFirstOfEachIPFamily(t *testing.T) {
tests := []struct {
name string
input []string
want []string
}{
{
name: "empty list",
input: []string{},
want: []string{},
},
{
name: "single IPv4",
input: []string{"192.168.1.1"},
want: []string{"192.168.1.1"},
},
{
name: "single IPv6",
input: []string{"fe80::1"},
want: []string{"fe80::1"},
},
{
name: "multiple IPv6",
input: []string{"fe80::1", "fe80::2"},
want: []string{"fe80::1"},
},
{
name: "one IPv4 and one IPv6",
input: []string{"192.168.1.1", "fe80::1"},
want: []string{"192.168.1.1", "fe80::1"},
},
{
name: "one IPv6 and one IPv4",
input: []string{"fe80::1", "192.168.1.1"},
want: []string{"fe80::1", "192.168.1.1"},
},
{
name: "multiple IPs with alternating versions",
input: []string{"192.168.1.1", "fe80::1", "192.168.1.2", "fe80::2"},
want: []string{"192.168.1.1", "fe80::1"},
},
{
name: "multiple IPs with same versions",
input: []string{"192.168.1.1", "192.168.1.2", "192.168.1.3"},
want: []string{"192.168.1.1"},
},
{
name: "multiple IPs with non-alternating versions",
input: []string{"192.168.1.1", "fe80::1", "192.168.1.2", "192.168.1.3", "fe80::2"},
want: []string{"192.168.1.1", "fe80::1"},
},
{
name: "invalid IP addresses",
input: []string{"invalid", "192.168.1.1", "fe80::1"},
want: []string{"192.168.1.1", "fe80::1"},
},
{
name: "IPv4 with embedded IPv6",
input: []string{"192.168.1.1", "::ffff:c000:280", "fe80::1"},
want: []string{"192.168.1.1", "fe80::1"},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := firstOfEachIPFamily(tt.input)
if len(result) != len(tt.want) {
t.Fatalf("want %v, got %v", tt.want, result)
}
if diff := cmp.Diff(tt.want, result); diff != "" {
t.Errorf("unexpected result (-want +got):\n%s", diff)
}
})
}
}

0 comments on commit 1799abc

Please sign in to comment.