diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 8c83239..3fb2754 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -15,7 +15,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v4 with: - go-version: 1.20.x + go-version: 1.21.x - name: Check out code uses: actions/checkout@v3 diff --git a/.github/workflows/changelog-update.yml b/.github/workflows/changelog-update.yml index 7300c4a..ad71564 100644 --- a/.github/workflows/changelog-update.yml +++ b/.github/workflows/changelog-update.yml @@ -13,7 +13,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v4 with: - go-version: 1.20.x + go-version: 1.21.x - name: Checkout code uses: actions/checkout@v3 diff --git a/.github/workflows/lint-test.yml b/.github/workflows/lint-test.yml index dd86e01..ac6b5e2 100644 --- a/.github/workflows/lint-test.yml +++ b/.github/workflows/lint-test.yml @@ -15,7 +15,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v4 with: - go-version: 1.20.x + go-version: 1.21.x - name: Run golangci-lint uses: golangci/golangci-lint-action@v3.7.0 with: diff --git a/context/NContext.go b/context/NContext.go new file mode 100644 index 0000000..876e268 --- /dev/null +++ b/context/NContext.go @@ -0,0 +1,91 @@ +package contextutil + +import "context" + +// A problematic situation when implementing context in a function +// is when that function has more than one return values +// if function has only one return value we can safely wrap it something like this +/* + func DoSomething() error {} + ch := make(chan error) + go func() { + ch <- DoSomething() + }() + select { + case err := <-ch: + // handle error + case <-ctx.Done(): + // handle context cancelation + } +*/ +// but what if we have more than one value to return? +// we can use generics and a struct and that is what we are doing here +// here we use struct and generics to store return values of a function +// instead of storing it in a []interface{} + +type twoValueCtx[T1 any, T2 any] struct { + var1 T1 + var2 T2 +} + +type threeValueCtx[T1 any, T2 any, T3 any] struct { + var1 T1 + var2 T2 + var3 T3 +} + +// ExecFunc implements context for a function which has no return values +// and executes that function. if context is cancelled before function returns +// it will return context error otherwise it will return nil +func ExecFunc(ctx context.Context, fn func()) error { + ch := make(chan struct{}) + go func() { + fn() + ch <- struct{}{} + }() + select { + case <-ch: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +// ExecFuncWithTwoReturns wraps a function which has two return values given that last one is error +// and executes that function in a goroutine there by implementing context +// if context is cancelled before function returns it will return context error +// otherwise it will return function's return values +func ExecFuncWithTwoReturns[T1 any](ctx context.Context, fn func() (T1, error)) (T1, error) { + ch := make(chan twoValueCtx[T1, error]) + go func() { + x, y := fn() + ch <- twoValueCtx[T1, error]{var1: x, var2: y} + }() + select { + case <-ctx.Done(): + var tmp T1 + return tmp, ctx.Err() + case v := <-ch: + return v.var1, v.var2 + } +} + +// ExecFuncWithThreeReturns wraps a function which has three return values given that last one is error +// and executes that function in a goroutine there by implementing context +// if context is cancelled before function returns it will return context error +// otherwise it will return function's return values +func ExecFuncWithThreeReturns[T1 any, T2 any](ctx context.Context, fn func() (T1, T2, error)) (T1, T2, error) { + ch := make(chan threeValueCtx[T1, T2, error]) + go func() { + x, y, z := fn() + ch <- threeValueCtx[T1, T2, error]{var1: x, var2: y, var3: z} + }() + select { + case <-ctx.Done(): + var tmp1 T1 + var tmp2 T2 + return tmp1, tmp2, ctx.Err() + case v := <-ch: + return v.var1, v.var2, v.var3 + } +} diff --git a/context/Ncontext_test.go b/context/Ncontext_test.go new file mode 100644 index 0000000..2fbec61 --- /dev/null +++ b/context/Ncontext_test.go @@ -0,0 +1,110 @@ +package contextutil_test + +import ( + "context" + "errors" + "testing" + "time" + + contextutil "github.com/projectdiscovery/utils/context" +) + +func TestExecFuncWithTwoReturns(t *testing.T) { + t.Run("function completes before context cancellation", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + fn := func() (int, error) { + time.Sleep(1 * time.Second) + return 42, nil + } + + val, err := contextutil.ExecFuncWithTwoReturns(ctx, fn) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if val != 42 { + t.Errorf("Unexpected return value: got %v, want 42", val) + } + }) + + t.Run("context cancelled before function completes", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + fn := func() (int, error) { + time.Sleep(2 * time.Second) + return 42, nil + } + + _, err := contextutil.ExecFuncWithTwoReturns(ctx, fn) + if !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("Expected context deadline exceeded error, got: %v", err) + } + }) +} + +func TestExecFuncWithThreeReturns(t *testing.T) { + t.Run("function completes before context cancellation", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + fn := func() (int, string, error) { + time.Sleep(1 * time.Second) + return 42, "hello", nil + } + + val1, val2, err := contextutil.ExecFuncWithThreeReturns(ctx, fn) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if val1 != 42 || val2 != "hello" { + t.Errorf("Unexpected return values: got %v and %v, want 42 and 'hello'", val1, val2) + } + }) + + t.Run("context cancelled before function completes", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + fn := func() (int, string, error) { + time.Sleep(2 * time.Second) + return 42, "hello", nil + } + + _, _, err := contextutil.ExecFuncWithThreeReturns(ctx, fn) + if !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("Expected context deadline exceeded error, got: %v", err) + } + }) +} + +func TestExecFunc(t *testing.T) { + t.Run("function completes before context cancellation", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + fn := func() { + time.Sleep(1 * time.Second) + } + + err := contextutil.ExecFunc(ctx, fn) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + }) + + t.Run("context cancelled before function completes", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + fn := func() { + time.Sleep(2 * time.Second) + } + + err := contextutil.ExecFunc(ctx, fn) + if err != context.DeadlineExceeded { + t.Errorf("Expected context deadline exceeded error, got: %v", err) + } + }) +} diff --git a/errors/errors.go b/errors/errors.go index e1f790f..8d07b22 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -1,7 +1,11 @@ package errorutil import ( + "context" + "errors" "fmt" + "net" + "os" "strings" ) @@ -52,3 +56,9 @@ func WrapwithNil(err error, errx ...error) Error { ee := NewWithErr(err) return ee.Wrap(errx...) } + +// IsTimeout checks if error is timeout error +func IsTimeout(err error) bool { + var net net.Error + return (errors.As(err, &net) && net.Timeout()) || errors.Is(err, context.DeadlineExceeded) || errors.Is(err, os.ErrDeadlineExceeded) +} diff --git a/go.mod b/go.mod index 745f4f9..2a02b9f 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/projectdiscovery/utils -go 1.20 +go 1.21 require ( github.com/Masterminds/semver/v3 v3.2.1 diff --git a/go.sum b/go.sum index ec8de8d..5c9e5bf 100644 --- a/go.sum +++ b/go.sum @@ -17,7 +17,9 @@ github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk= github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4= github.com/bits-and-blooms/bitset v1.8.0 h1:FD+XqgOZDUxxZ8hzoBFuV9+cGWY9CslN6d5MS5JVb4c= +github.com/bits-and-blooms/bitset v1.8.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= github.com/bits-and-blooms/bloom/v3 v3.5.0 h1:AKDvi1V3xJCmSR6QhcBfHbCN4Vf8FfxeWkMNQfmAGhY= +github.com/bits-and-blooms/bloom/v3 v3.5.0/go.mod h1:Y8vrn7nk1tPIlmLtW2ZPV+W7StdVMor6bC1xgpjMZFs= github.com/charmbracelet/glamour v0.6.0 h1:wi8fse3Y7nfcabbbDuwolqTqMQPMnVPeZhDM273bISc= github.com/charmbracelet/glamour v0.6.0/go.mod h1:taqWV4swIMMbWALc0m7AfE9JkPSU8om2538k9ITBxOc= github.com/cheggaaa/pb/v3 v3.1.4 h1:DN8j4TVVdKu3WxVwcRKu0sG00IIU6FewoABZzXbRQeo= @@ -38,6 +40,7 @@ github.com/ebitengine/purego v0.4.0/go.mod h1:ah1In8AOtksoNK6yk5z1HTJeUkC1Ez4Wk2 github.com/fatih/color v1.15.0 h1:kOqh6YHBtK8aywxGerMG2Eq3H6Qgoqeo13Bk2Mv/nBs= github.com/fatih/color v1.15.0/go.mod h1:0h5ZqXfHYED7Bhv2ZJamyIOUej9KtShiJESRwBDUSsw= github.com/frankban/quicktest v1.11.3 h1:8sXhOn0uLys67V8EsXLc6eszDs8VXWxL3iRvebPhedY= +github.com/frankban/quicktest v1.11.3/go.mod h1:wRf/ReqHper53s+kmmSZizM8NamnL3IM0I9ntUbOk+k= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= @@ -68,6 +71,7 @@ github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/ github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY= github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c= github.com/hashicorp/golang-lru/v2 v2.0.6 h1:3xi/Cafd1NaoEnS/yDssIiuVeDVywU0QdFGl3aQaQHM= +github.com/hashicorp/golang-lru/v2 v2.0.6/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/hdm/jarm-go v0.0.7 h1:Eq0geenHrBSYuKrdVhrBdMMzOmA+CAMLzN2WrF3eL6A= github.com/hdm/jarm-go v0.0.7/go.mod h1:kinGoS0+Sdn1Rr54OtanET5E5n7AlD6T6CrJAKDjJSQ= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= @@ -85,6 +89,7 @@ github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORN github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/logrusorgru/aurora v2.0.3+incompatible h1:tOpm7WcpBTn4fjmVfgpQq0EfczGlG91VSDkswnjF5A8= github.com/logrusorgru/aurora v2.0.3+incompatible/go.mod h1:7rIyQOR62GCctdiQpZ/zOJlFyk6y+94wXzv6RNZgaR4= github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= @@ -130,8 +135,10 @@ github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6 github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.16.4 h1:29JGrr5oVBm5ulCWet69zQkzWipVXIol6ygQUe/EzNc= +github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0= github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= github.com/onsi/gomega v1.16.0 h1:6gjqkI8iiRHMvdccRJM8rVKjCWk6ZIm6FTm3ddIe4/c= +github.com/onsi/gomega v1.16.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY= github.com/op/go-logging v0.0.0-20160315200505-970db520ece7/go.mod h1:HzydrMdWErDVzsI23lYNej1Htcns9BCg93Dk0bBINWk= github.com/pierrec/lz4 v2.6.1+incompatible h1:9UY3+iC23yxF0UfGaYrGplQ+79Rg+h/q9FV9ix19jjM= github.com/pierrec/lz4 v2.6.1+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= @@ -180,6 +187,7 @@ github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXl github.com/syndtr/goleveldb v1.0.0 h1:fBdIW9lB4Iz0n9khmH8w27SJ3QEJ7+IgjPEwGSZiFdE= github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ= github.com/tidwall/assert v0.1.0 h1:aWcKyRBUAdLoVebxo95N7+YZVTFF/ASTr7BN4sLP6XI= +github.com/tidwall/assert v0.1.0/go.mod h1:QLYtGyeqse53vuELQheYl9dngGCJQ+mTtlxcktb+Kj8= github.com/tidwall/btree v1.4.3 h1:Lf5U/66bk0ftNppOBjVoy/AIPBrLMkheBp4NnSNiYOo= github.com/tidwall/btree v1.4.3/go.mod h1:LGm8L/DZjPLmeWGjv5kFrY8dL4uVhMmzmmLYmsObdKE= github.com/tidwall/buntdb v1.3.0 h1:gdhWO+/YwoB2qZMeAU9JcWWsHSYU3OvcieYgFRS0zwA= @@ -190,6 +198,7 @@ github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vl github.com/tidwall/grect v0.1.4 h1:dA3oIgNgWdSspFzn1kS4S/RDpZFLrIxAZOdJKjYapOg= github.com/tidwall/grect v0.1.4/go.mod h1:9FBsaYRaR0Tcy4UwefBX/UDcDcDy9V5jUcxHzv2jd5Q= github.com/tidwall/lotsa v1.0.2 h1:dNVBH5MErdaQ/xd9s769R31/n2dXavsQ0Yf4TMEHHw8= +github.com/tidwall/lotsa v1.0.2/go.mod h1:X6NiU+4yHA3fE3Puvpnn1XMDrFZrE9JO2/w+UMuqgR8= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= @@ -263,6 +272,7 @@ golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= diff --git a/reader/conn_read.go b/reader/conn_read.go new file mode 100644 index 0000000..15a4187 --- /dev/null +++ b/reader/conn_read.go @@ -0,0 +1,100 @@ +package reader + +import ( + "context" + "errors" + "io" + "syscall" + "time" + + contextutil "github.com/projectdiscovery/utils/context" + errorutil "github.com/projectdiscovery/utils/errors" +) + +const ( + // although this is more than enough for most cases + MaxReadSize = 1 << 23 // 8MB +) + +var ( + ErrTooLarge = errors.New("reader: too large only 8MB allowed as per MaxReadSize") +) + +// ConnReadN reads at most N bytes from reader and it optimized +// for connection based readers like net.Conn it should not be used +// for file/buffer based reading, ConnReadN should be preferred +// instead of 'conn.Read() without loop' . It ignores EOF, UnexpectedEOF and timeout errors +// Note: you are responsible for adding a timeout to context +func ConnReadN(ctx context.Context, reader io.Reader, N int64) ([]byte, error) { + if N == -1 { + N = MaxReadSize + } else if N < -1 { + return nil, errors.New("reader: N cannot be less than -1") + } else if N == 0 { + return []byte{}, nil + } else if N > MaxReadSize { + return nil, ErrTooLarge + } + var readErr error + pr, pw := io.Pipe() + + // When using the Nuclei network protocol to read all available data from a connection, + // there may be a timeout error after data has been sent by server. In this scenario, + // we should return the data and ignore the error (if it is a timeout error). + // To avoid race conditions, we use io.Pipe() along with a goroutine. + // For an example of this scenario, refer to TestConnReadN#6. + + go func() { + defer pw.Close() + fn := func() (int64, error) { + return io.CopyN(pw, io.LimitReader(reader, N), N) + } + // ExecFuncWithTwoReturns will execute the function but errors if context is done + _, readErr = contextutil.ExecFuncWithTwoReturns(ctx, fn) + }() + + // read from pipe and return + bin, err2 := io.ReadAll(pr) + if err2 != nil { + return nil, errorutil.NewWithErr(err2).Msgf("something went wrong while reading from pipe") + } + + if readErr != nil { + if errorutil.IsTimeout(readErr) && len(bin) > 0 { + // if error is a timeout error and we have some data already + // then return data and ignore error + return bin, nil + } else if IsAcceptedError(readErr) { + // if error is accepted error ex: EOF, UnexpectedEOF, connection refused + // then return data and ignore error + return bin, nil + } else { + return nil, errorutil.WrapfWithNil(readErr, "reader: error while reading from connection") + } + } else { + return bin, nil + } +} + +// ConnReadNWithTimeout is same as ConnReadN but it takes timeout +// instead of context and it returns error if read does not finish in given time +func ConnReadNWithTimeout(reader io.Reader, N int64, after time.Duration) ([]byte, error) { + ctx, cancel := context.WithTimeout(context.Background(), after) + defer cancel() + return ConnReadN(ctx, reader, N) +} + +// IsAcceptedError checks if the error is accepted error +// for example: connection refused, io.EOF, io.ErrUnexpectedEOF +// while reading from connection +func IsAcceptedError(err error) bool { + if err == io.EOF || err == io.ErrUnexpectedEOF { + // ideally we should error out if we get a timeout error but + // that's different for our use case + return true + } + if errors.Is(err, syscall.ECONNREFUSED) { + return true + } + return false +} diff --git a/reader/conn_read_test.go b/reader/conn_read_test.go new file mode 100644 index 0000000..64dd291 --- /dev/null +++ b/reader/conn_read_test.go @@ -0,0 +1,79 @@ +package reader + +import ( + "bytes" + "crypto/tls" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestConnReadN(t *testing.T) { + timeout := time.Duration(5) * time.Second + + t.Run("Test with N as -1", func(t *testing.T) { + reader := strings.NewReader("Hello, World!") + data, err := ConnReadNWithTimeout(reader, -1, timeout) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if string(data) != "Hello, World!" { + t.Errorf("Expected 'Hello, World!', got '%s'", string(data)) + } + }) + + t.Run("Test with N as 0", func(t *testing.T) { + reader := strings.NewReader("Hello, World!") + data, err := ConnReadNWithTimeout(reader, 0, timeout) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if len(data) != 0 { + t.Errorf("Expected empty, got '%s'", string(data)) + } + }) + + t.Run("Test with N greater than MaxReadSize", func(t *testing.T) { + reader := bytes.NewReader(make([]byte, MaxReadSize+1)) + _, err := ConnReadNWithTimeout(reader, MaxReadSize+1, timeout) + if err != ErrTooLarge { + t.Errorf("Expected 'ErrTooLarge', got '%v'", err) + } + }) + + t.Run("Test with N less than MaxReadSize", func(t *testing.T) { + reader := strings.NewReader("Hello, World!") + data, err := ConnReadNWithTimeout(reader, 5, timeout) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if string(data) != "Hello" { + t.Errorf("Expected 'Hello', got '%s'", string(data)) + } + }) + t.Run("Read From Connection", func(t *testing.T) { + conn, err := tls.Dial("tcp", "projectdiscovery.io:443", &tls.Config{InsecureSkipVerify: true}) + _ = conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + require.Nil(t, err, "could not connect to projectdiscovery.io over tls") + defer conn.Close() + _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: projectdiscovery.io\r\nConnection: close\r\n\r\n")) + require.Nil(t, err, "could not write to connection") + data, err := ConnReadNWithTimeout(conn, -1, timeout) + require.Nilf(t, err, "could not read from connection: %s", err) + require.NotEmpty(t, data, "could not read from connection") + }) + + t.Run("Read From Connection which times out", func(t *testing.T) { + conn, err := tls.Dial("tcp", "projectdiscovery.io:443", &tls.Config{InsecureSkipVerify: true}) + _ = conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + require.Nil(t, err, "could not connect to projectdiscovery.io over tls") + defer conn.Close() + _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: projectdiscovery.io\r\n\r\n")) + require.Nil(t, err, "could not write to connection") + data, err := ConnReadNWithTimeout(conn, -1, timeout) + require.Nilf(t, err, "could not read from connection: %s", err) + require.NotEmpty(t, data, "could not read from connection") + }) +} diff --git a/url/README.md b/url/README.md index b20b2da..ed71ed0 100644 --- a/url/README.md +++ b/url/README.md @@ -41,6 +41,13 @@ scanme.sh/%invalid/path - `.UpdateRelPath(newrelpath string, unsafe bool)` - `.Clone()` and more +- Dealing with Double URL Encoding of chars like `%0A` when `.Path` is directly updated + + when `url.Parse` is used to parse url like `https://127.0.0.1/%0A` it internally calls `u.setPath` which decodes `%0A` to `\n` and saves it in `u.Path` and when final url is created at time of writing to connection in http.Request Path is then escaped again thus `\n` becomes `%0A` and final url becomes `https://127.0.0.1/%0A` which is expected/required behavior. + + If `u.Path` is changed/updated directly after `url.Parse` ex: `u.Path = "%0A"` then at time of writing to connection in http.Request, Path is escaped again thus `%0A` becomes `%250A` and final url becomes `https://127.0.0.1/%250A` which is not expected/required behavior to avoid this we manually unescape/decode `u.Path` and we set `u.Path = unescape(u.Path)` which takes care of this edgecase. + + This is how `utils/url/URL` handles this edgecase when `u.Path` is directly updated. ### Note diff --git a/url/url.go b/url/url.go index b03867d..2c8ee66 100644 --- a/url/url.go +++ b/url/url.go @@ -100,7 +100,7 @@ func (u *URL) Clone() *URL { // String func (u *URL) String() string { var buff bytes.Buffer - if u.Scheme != "" { + if u.Scheme != "" && u.Host != "" { buff.WriteString(u.Scheme + "://") } if u.User != nil { @@ -308,10 +308,6 @@ func ParseURL(inputURL string, unsafe bool) (*URL, error) { } if u.IsRelative { return ParseRelativePath(inputURL, unsafe) - } else if unsafe { - // we are not relative, but we still need to call this in order to call - // the internal parser for paths url.Parse will not handle. - u.parseUnsafeRelativePath() } return u, nil } diff --git a/url/url_test.go b/url/url_test.go index 8f992fe..e3234a9 100644 --- a/url/url_test.go +++ b/url/url_test.go @@ -146,7 +146,7 @@ func TestParseInvalidUnsafe(t *testing.T) { for _, input := range testcases { u, err := ParseURL(input, true) require.Nilf(t, err, "got error for url %v", input) - require.Equal(t, input, u.String()) + require.Equal(t, input, u.URL.String()) } }