Skip to content

Commit

Permalink
add more test cases for socks5
Browse files Browse the repository at this point in the history
  • Loading branch information
ginuerzh committed Dec 31, 2018
1 parent 72d7850 commit 445889c
Show file tree
Hide file tree
Showing 23 changed files with 1,168 additions and 417 deletions.
4 changes: 2 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ go:

install: true
script:
- env GO111MODULE=on go test -race -v -coverprofile=coverage.txt -covermode=atomic
- cd cmd/gost && env GO111MODULE=on go build
- go test -race -v -coverprofile=coverage.txt -covermode=atomic
- cd cmd/gost && go build

after_success:
- bash <(curl -s https://codecov.io/bash)
15 changes: 11 additions & 4 deletions bypass_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package gost

import (
"bytes"
"fmt"
"io"
"testing"
"time"
Expand Down Expand Up @@ -158,10 +159,13 @@ var bypassContainTests = []struct {

func TestBypassContains(t *testing.T) {
for i, tc := range bypassContainTests {
bp := NewBypassPatterns(tc.reversed, tc.patterns...)
if bp.Contains(tc.addr) != tc.bypassed {
t.Errorf("#%d test failed: %v, %s", i, tc.patterns, tc.addr)
}
tc := tc
t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
bp := NewBypassPatterns(tc.reversed, tc.patterns...)
if bp.Contains(tc.addr) != tc.bypassed {
t.Errorf("#%d test failed: %v, %s", i, tc.patterns, tc.addr)
}
})
}
}

Expand Down Expand Up @@ -244,6 +248,9 @@ func TestByapssReload(t *testing.T) {
}
if tc.stopped {
bp.Stop()
if bp.Period() >= 0 {
t.Errorf("period of the stopped reloader should be minus value")
}
}
if bp.Stopped() != tc.stopped {
t.Errorf("#%d test failed: stopped value should be %v, got %v",
Expand Down
8 changes: 5 additions & 3 deletions cmd/gost/.config/dns.txt
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
# resolver timeout, default 30s.
timeout 10s

# resolver cache TTL, default 60s, minus value means that cache is disabled.
ttl 300s
# resolver cache TTL,
# minus value means that cache is disabled,
# default to the TTL in DNS server response.
# ttl 300s

# period for live reloading
reload 10s

# ip[:port] [protocol] [hostname]

https://1.0.0.1/dns-query
1.1.1.1:853 tls cloudflare-dns.com
https://1.0.0.1/dns-query https
8.8.8.8
8.8.8.8 tcp
1.1.1.1 udp
Expand Down
7 changes: 2 additions & 5 deletions cmd/gost/cfg.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"net/url"
"os"
"strings"
"time"

"github.com/ginuerzh/gost"
)
Expand Down Expand Up @@ -196,8 +195,6 @@ func parseResolver(cfg string) gost.Resolver {
if cfg == "" {
return nil
}
timeout := 30 * time.Second
ttl := 60 * time.Second
var nss []gost.NameServer

f, err := os.Open(cfg)
Expand Down Expand Up @@ -237,11 +234,11 @@ func parseResolver(cfg string) gost.Resolver {
}
}
}
return gost.NewResolver(timeout, ttl, nss...)
return gost.NewResolver(0, nss...)
}
defer f.Close()

resolver := gost.NewResolver(timeout, ttl)
resolver := gost.NewResolver(0)
resolver.Reload(f)

go gost.PeriodReload(resolver, cfg)
Expand Down
134 changes: 126 additions & 8 deletions common_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
package gost

import (
"bufio"
"bytes"
"crypto/tls"
"errors"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/url"
"sync"
"time"
)
Expand Down Expand Up @@ -37,6 +41,110 @@ var (
})
)

// proxyConn obtains a connection to the proxy server.
func proxyConn(client *Client, server *Server) (net.Conn, error) {
conn, err := client.Dial(server.Addr().String())
if err != nil {
return nil, err
}

cc, err := client.Handshake(conn, AddrHandshakeOption(server.Addr().String()))
if err != nil {
conn.Close()
return nil, err
}

return cc, nil
}

// httpRoundtrip does a HTTP request-response roundtrip, and checks the data received.
func httpRoundtrip(conn net.Conn, targetURL string, data []byte) (err error) {
req, err := http.NewRequest(
http.MethodGet,
targetURL,
bytes.NewReader(data),
)
if err != nil {
return
}
if err = req.Write(conn); err != nil {
return
}
resp, err := http.ReadResponse(bufio.NewReader(conn), req)
if err != nil {
return
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return errors.New(resp.Status)
}

recv, err := ioutil.ReadAll(resp.Body)
if err != nil {
return
}

if !bytes.Equal(data, recv) {
return fmt.Errorf("data not equal")
}
return
}

func udpRoundtrip(client *Client, server *Server, host string, data []byte) (err error) {
conn, err := proxyConn(client, server)
if err != nil {
return
}
defer conn.Close()

conn, err = client.Connect(conn, host)
if err != nil {
return
}

conn.SetDeadline(time.Now().Add(3 * time.Second))
defer conn.SetDeadline(time.Time{})

if _, err = conn.Write(data); err != nil {
return
}

recv := make([]byte, len(data))
if _, err = conn.Read(recv); err != nil {
return
}

if !bytes.Equal(data, recv) {
return fmt.Errorf("data not equal")
}

return
}

func proxyRoundtrip(client *Client, server *Server, targetURL string, data []byte) (err error) {
conn, err := proxyConn(client, server)
if err != nil {
return err
}
defer conn.Close()

u, err := url.Parse(targetURL)
if err != nil {
return
}

conn, err = client.Connect(conn, u.Host)
if err != nil {
return
}

conn.SetDeadline(time.Now().Add(500 * time.Millisecond))
defer conn.SetDeadline(time.Time{})

return httpRoundtrip(conn, targetURL, data)
}

type udpRequest struct {
Body io.Reader
RemoteAddr string
Expand All @@ -55,11 +163,12 @@ type udpHandlerFunc func(w io.Writer, r *udpRequest)

// udpTestServer is a UDP server for test.
type udpTestServer struct {
ln net.PacketConn
handler udpHandlerFunc
wg sync.WaitGroup
mu sync.Mutex // guards closed and conns
closed bool
ln net.PacketConn
handler udpHandlerFunc
wg sync.WaitGroup
mu sync.Mutex // guards closed and conns
closed bool
exitChan chan struct{}
}

func newUDPTestServer(handler udpHandlerFunc) *udpTestServer {
Expand All @@ -68,9 +177,13 @@ func newUDPTestServer(handler udpHandlerFunc) *udpTestServer {
if err != nil {
panic(fmt.Sprintf("udptest: failed to listen on a port: %v", err))
}
ln.SetReadBuffer(1024 * 1024)
ln.SetWriteBuffer(1024 * 1024)

return &udpTestServer{
ln: ln,
handler: handler,
ln: ln,
handler: handler,
exitChan: make(chan struct{}),
}
}

Expand All @@ -83,7 +196,7 @@ func (s *udpTestServer) serve() {
data := make([]byte, 1024)
n, raddr, err := s.ln.ReadFrom(data)
if err != nil {
return
break
}
if s.handler != nil {
s.wg.Add(1)
Expand All @@ -101,6 +214,9 @@ func (s *udpTestServer) serve() {
}()
}
}

// signal the listener has been exited.
close(s.exitChan)
}

func (s *udpTestServer) Addr() string {
Expand All @@ -119,6 +235,8 @@ func (s *udpTestServer) Close() error {
s.closed = true
s.mu.Unlock()

<-s.exitChan

s.wg.Wait()

return err
Expand Down
11 changes: 0 additions & 11 deletions forward.go
Original file line number Diff line number Diff line change
Expand Up @@ -662,10 +662,6 @@ func TCPRemoteForwardListener(addr string, chain *Chain) (Listener, error) {

go ln.listenLoop()

// if err = <-ln.errChan; err != nil {
// ln.Close()
// }

return ln, err
}

Expand All @@ -680,17 +676,10 @@ func (l *tcpRemoteForwardListener) isChainValid() bool {

func (l *tcpRemoteForwardListener) listenLoop() {
var tempDelay time.Duration
// var once sync.Once

for {
conn, err := l.accept()

// once.Do(func() {
// l.errChan <- err
// log.Log("once.Do error:", err)
// close(l.errChan)
// })

select {
case <-l.closed:
if conn != nil {
Expand Down
34 changes: 0 additions & 34 deletions forward_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
package gost

import (
"bytes"
"crypto/rand"
"fmt"
"net/http/httptest"
"net/url"
"testing"
"time"
)

func tcpDirectForwardRoundtrip(targetURL string, data []byte) error {
Expand Down Expand Up @@ -122,37 +119,6 @@ func BenchmarkTCPDirectForwardParallel(b *testing.B) {
})
}

func udpRoundtrip(client *Client, server *Server, host string, data []byte) (err error) {
conn, err := proxyConn(client, server)
if err != nil {
return
}
defer conn.Close()

conn.SetDeadline(time.Now().Add(1 * time.Second))
defer conn.SetDeadline(time.Time{})

conn, err = client.Connect(conn, host)
if err != nil {
return
}

if _, err = conn.Write(data); err != nil {
return
}

recv := make([]byte, len(data))
if _, err = conn.Read(recv); err != nil {
return
}

if !bytes.Equal(data, recv) {
return fmt.Errorf("data not equal")
}

return
}

func udpDirectForwardRoundtrip(host string, data []byte) error {
ln, err := UDPDirectForwardListener("localhost:0", 0)
if err != nil {
Expand Down
Loading

0 comments on commit 445889c

Please sign in to comment.