Skip to content

Commit 6617275

Browse files
authored
Merge pull request #23 from acd/verify-remote
Validate IP address and port of received packets (take 2)
2 parents 54292c7 + b9cf328 commit 6617275

7 files changed

+63
-36
lines changed

client.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ type Client struct {
4949

5050
// Send starts outgoing file transmission. It returns io.ReaderFrom or error.
5151
func (c Client) Send(filename string, mode string) (io.ReaderFrom, error) {
52-
conn, err := net.ListenUDP(udp, &net.UDPAddr{})
52+
conn, err := net.ListenUDP("udp", &net.UDPAddr{})
5353
if err != nil {
5454
return nil, err
5555
}
@@ -79,7 +79,7 @@ func (c Client) Send(filename string, mode string) (io.ReaderFrom, error) {
7979

8080
// Receive starts incoming file transmission. It returns io.WriterTo or error.
8181
func (c Client) Receive(filename string, mode string) (io.WriterTo, error) {
82-
conn, err := net.ListenUDP(udp, &net.UDPAddr{})
82+
conn, err := net.ListenUDP("udp", &net.UDPAddr{})
8383
if err != nil {
8484
return nil, err
8585
}

conn.go

-11
This file was deleted.

conn_darwin.go

-17
This file was deleted.

receiver.go

+5-1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ type receiver struct {
4242
send []byte
4343
receive []byte
4444
addr *net.UDPAddr
45+
tid int
4546
conn *net.UDPConn
4647
block uint16
4748
retry *backoff
@@ -159,11 +160,14 @@ func (r *receiver) receiveDatagram(l int) (int, *net.UDPAddr, error) {
159160
if err != nil {
160161
return 0, nil, err
161162
}
162-
// TODO: compare addr here?
163+
if !addr.IP.Equal(r.addr.IP) || (r.tid != 0 && addr.Port != r.tid) {
164+
continue
165+
}
163166
p, err := parsePacket(r.receive[:c])
164167
if err != nil {
165168
return 0, addr, err
166169
}
170+
r.tid = addr.Port
167171
switch p := p.(type) {
168172
case pDATA:
169173
if p.block() == r.block {

sender.go

+5
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ type OutgoingTransfer interface {
3333
type sender struct {
3434
conn *net.UDPConn
3535
addr *net.UDPAddr
36+
tid int
3637
send []byte
3738
receive []byte
3839
retry *backoff
@@ -189,10 +190,14 @@ func (s *sender) sendDatagram(l int) (*net.UDPAddr, error) {
189190
if err != nil {
190191
return nil, err
191192
}
193+
if !addr.IP.Equal(s.addr.IP) || (s.tid != 0 && addr.Port != s.tid) {
194+
continue
195+
}
192196
p, err := parsePacket(s.receive[:n])
193197
if err != nil {
194198
continue
195199
}
200+
s.tid = addr.Port
196201
switch p := p.(type) {
197202
case pACK:
198203
if p.block() == s.block {

server.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ func (s *Server) processRequest(conn *net.UDPConn) error {
118118
return fmt.Errorf("unpack WRQ: %v", err)
119119
}
120120
//fmt.Printf("got WRQ (filename=%s, mode=%s, opts=%v)\n", filename, mode, opts)
121-
conn, err := net.ListenUDP(udp, &net.UDPAddr{})
121+
conn, err := net.ListenUDP("udp", &net.UDPAddr{})
122122
if err != nil {
123123
return err
124124
}
@@ -157,13 +157,14 @@ func (s *Server) processRequest(conn *net.UDPConn) error {
157157
return fmt.Errorf("unpack RRQ: %v", err)
158158
}
159159
//fmt.Printf("got RRQ (filename=%s, mode=%s, opts=%v)\n", filename, mode, opts)
160-
conn, err := net.ListenUDP(udp, &net.UDPAddr{})
160+
conn, err := net.ListenUDP("udp", &net.UDPAddr{})
161161
if err != nil {
162162
return err
163163
}
164164
rf := &sender{
165165
send: make([]byte, datagramLength),
166166
receive: make([]byte, datagramLength),
167+
tid: remoteAddr.Port,
167168
conn: conn,
168169
retry: &backoff{},
169170
timeout: s.timeout,

tftp_test.go

+48-3
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,57 @@ import (
88
"math/rand"
99
"net"
1010
"os"
11+
"strconv"
1112
"sync"
1213
"testing"
1314
"testing/iotest"
1415
"time"
1516
)
1617

18+
var localhost string = determineLocalhost()
19+
20+
func determineLocalhost() string {
21+
l, err := net.ListenTCP("tcp", nil)
22+
if err != nil {
23+
panic(fmt.Sprintf("ListenTCP error: %s", err))
24+
}
25+
_, lport, _ := net.SplitHostPort(l.Addr().String())
26+
defer l.Close()
27+
28+
lo := make(chan string)
29+
30+
go func() {
31+
for {
32+
conn, err := l.Accept()
33+
if err != nil {
34+
break
35+
}
36+
conn.Close()
37+
}
38+
}()
39+
40+
go func() {
41+
port, _ := strconv.Atoi(lport)
42+
for _, af := range []string{"tcp6", "tcp4"} {
43+
conn, err := net.DialTCP(af, &net.TCPAddr{}, &net.TCPAddr{Port: port})
44+
if err == nil {
45+
conn.Close()
46+
host, _, _ := net.SplitHostPort(conn.LocalAddr().String())
47+
lo <- host
48+
return
49+
}
50+
}
51+
panic("could not determine address family")
52+
}()
53+
54+
return <-lo
55+
}
56+
57+
func localSystem(c *net.UDPConn) string {
58+
_, port, _ := net.SplitHostPort(c.LocalAddr().String())
59+
return net.JoinHostPort(localhost, port)
60+
}
61+
1762
func TestPackUnpack(t *testing.T) {
1863
v := []string{"test-filename/with-subdir"}
1964
testOptsList := []options{
@@ -272,7 +317,7 @@ func TestSendTsizeFromSeek(t *testing.T) {
272317
return nil
273318
}, nil)
274319

275-
conn, err := net.ListenUDP(udp, &net.UDPAddr{})
320+
conn, err := net.ListenUDP("udp", &net.UDPAddr{})
276321
if err != nil {
277322
t.Fatalf("listening: %v", err)
278323
}
@@ -310,7 +355,7 @@ func makeTestServer() (*Server, *Client) {
310355
// Create server
311356
s := NewServer(b.handleRead, b.handleWrite)
312357

313-
conn, err := net.ListenUDP(udp, &net.UDPAddr{})
358+
conn, err := net.ListenUDP("udp", &net.UDPAddr{})
314359
if err != nil {
315360
panic(err)
316361
}
@@ -329,7 +374,7 @@ func makeTestServer() (*Server, *Client) {
329374
func TestNoHandlers(t *testing.T) {
330375
s := NewServer(nil, nil)
331376

332-
conn, err := net.ListenUDP(udp, &net.UDPAddr{})
377+
conn, err := net.ListenUDP("udp", &net.UDPAddr{})
333378
if err != nil {
334379
panic(err)
335380
}

0 commit comments

Comments
 (0)