Skip to content

Commit

Permalink
Fix buffer.UDP destination override (XTLS#2356)
Browse files Browse the repository at this point in the history
  • Loading branch information
dyhkwong authored Aug 29, 2023
1 parent e013dce commit b8bd243
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 94 deletions.
107 changes: 28 additions & 79 deletions app/dispatcher/default.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ package dispatcher

import (
"context"
"fmt"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -135,77 +134,10 @@ func (*DefaultDispatcher) Start() error {
// Close implements common.Closable.
func (*DefaultDispatcher) Close() error { return nil }

func (d *DefaultDispatcher) getLink(ctx context.Context, network net.Network, sniffing session.SniffingRequest) (*transport.Link, *transport.Link) {
downOpt := pipe.OptionsFromContext(ctx)
upOpt := downOpt

if network == net.Network_UDP {
var ip2domain *sync.Map // net.IP.String() => domain, this map is used by server side when client turn on fakedns
// Client will send domain address in the buffer.UDP.Address, server record all possible target IP addrs.
// When target replies, server will restore the domain and send back to client.
// Note: this map is not global but per connection context
upOpt = append(upOpt, pipe.OnTransmission(func(mb buf.MultiBuffer) buf.MultiBuffer {
for i, buffer := range mb {
if buffer.UDP == nil {
continue
}
addr := buffer.UDP.Address
if addr.Family().IsIP() {
if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && fkr0.IsIPInIPPool(addr) && sniffing.Enabled {
domain := fkr0.GetDomainFromFakeDNS(addr)
if len(domain) > 0 {
buffer.UDP.Address = net.DomainAddress(domain)
newError("[fakedns client] override with domain: ", domain, " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx))
} else {
newError("[fakedns client] failed to find domain! :", addr.String(), " for xUDP buffer at ", i).AtWarning().WriteToLog(session.ExportIDToError(ctx))
}
}
} else {
if ip2domain == nil {
ip2domain = new(sync.Map)
newError("[fakedns client] create a new map").WriteToLog(session.ExportIDToError(ctx))
}
domain := addr.Domain()
ips, err := d.dns.LookupIP(domain, dns.IPOption{true, true, false})
if err == nil {
for _, ip := range ips {
ip2domain.Store(ip.String(), domain)
}
newError("[fakedns client] candidate ip: "+fmt.Sprintf("%v", ips), " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx))
} else {
newError("[fakedns client] failed to look up IP for ", domain, " for xUDP buffer at ", i).Base(err).WriteToLog(session.ExportIDToError(ctx))
}
}
}
return mb
}))
downOpt = append(downOpt, pipe.OnTransmission(func(mb buf.MultiBuffer) buf.MultiBuffer {
for i, buffer := range mb {
if buffer.UDP == nil {
continue
}
addr := buffer.UDP.Address
if addr.Family().IsIP() {
if ip2domain == nil {
continue
}
if domain, found := ip2domain.Load(addr.IP().String()); found {
buffer.UDP.Address = net.DomainAddress(domain.(string))
newError("[fakedns client] restore domain: ", domain.(string), " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx))
}
} else {
if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok {
fakeIp := fkr0.GetFakeIPForDomain(addr.Domain())
buffer.UDP.Address = fakeIp[0]
newError("[fakedns client] restore FakeIP: ", buffer.UDP, fmt.Sprintf("%v", fakeIp), " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx))
}
}
}
return mb
}))
}
uplinkReader, uplinkWriter := pipe.New(upOpt...)
downlinkReader, downlinkWriter := pipe.New(downOpt...)
func (d *DefaultDispatcher) getLink(ctx context.Context) (*transport.Link, *transport.Link) {
opt := pipe.OptionsFromContext(ctx)
uplinkReader, uplinkWriter := pipe.New(opt...)
downlinkReader, downlinkWriter := pipe.New(opt...)

inboundLink := &transport.Link{
Reader: downlinkReader,
Expand Down Expand Up @@ -263,7 +195,7 @@ func (d *DefaultDispatcher) shouldOverride(ctx context.Context, result SniffResu
protocolString = resComp.ProtocolForDomainResult()
}
for _, p := range request.OverrideDestinationForProtocol {
if strings.HasPrefix(protocolString, p) {
if strings.HasPrefix(protocolString, p) || strings.HasPrefix(protocolString, p) {
return true
}
if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && protocolString != "bittorrent" && p == "fakedns" &&
Expand All @@ -287,17 +219,17 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin
panic("Dispatcher: Invalid destination.")
}
ob := &session.Outbound{
Target: destination,
OriginalTarget: destination,
Target: destination,
}
ctx = session.ContextWithOutbound(ctx, ob)
content := session.ContentFromContext(ctx)
if content == nil {
content = new(session.Content)
ctx = session.ContextWithContent(ctx, content)
}

sniffingRequest := content.SniffingRequest
inbound, outbound := d.getLink(ctx, destination.Network, sniffingRequest)
inbound, outbound := d.getLink(ctx)
if !sniffingRequest.Enabled {
go d.routedDispatch(ctx, outbound, destination)
} else {
Expand All @@ -314,7 +246,15 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin
domain := result.Domain()
newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
destination.Address = net.ParseAddress(domain)
if sniffingRequest.RouteOnly && result.Protocol() != "fakedns" {
protocol := result.Protocol()
if resComp, ok := result.(SnifferResultComposite); ok {
protocol = resComp.ProtocolForDomainResult()
}
isFakeIP := false
if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && ob.Target.Address.Family().IsIP() && fkr0.IsIPInIPPool(ob.Target.Address) {
isFakeIP = true
}
if sniffingRequest.RouteOnly && protocol != "fakedns" && protocol != "fakedns+others" && !isFakeIP {
ob.RouteTarget = destination
} else {
ob.Target = destination
Expand All @@ -332,7 +272,8 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
return newError("Dispatcher: Invalid destination.")
}
ob := &session.Outbound{
Target: destination,
OriginalTarget: destination,
Target: destination,
}
ctx = session.ContextWithOutbound(ctx, ob)
content := session.ContentFromContext(ctx)
Expand All @@ -356,7 +297,15 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
domain := result.Domain()
newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
destination.Address = net.ParseAddress(domain)
if sniffingRequest.RouteOnly && result.Protocol() != "fakedns" {
protocol := result.Protocol()
if resComp, ok := result.(SnifferResultComposite); ok {
protocol = resComp.ProtocolForDomainResult()
}
isFakeIP := false
if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && ob.Target.Address.Family().IsIP() && fkr0.IsIPInIPPool(ob.Target.Address) {
isFakeIP = true
}
if sniffingRequest.RouteOnly && protocol != "fakedns" && protocol != "fakedns+others" && !isFakeIP {
ob.RouteTarget = destination
} else {
ob.Target = destination
Expand Down
7 changes: 6 additions & 1 deletion app/proxyman/outbound/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/xtls/xray-core/app/proxyman"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/buf"
"github.com/xtls/xray-core/common/mux"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/net/cnc"
Expand Down Expand Up @@ -166,6 +167,11 @@ func (h *Handler) Tag() string {

// Dispatch implements proxy.Outbound.Dispatch.
func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) {
outbound := session.OutboundFromContext(ctx)
if outbound.Target.Network == net.Network_UDP && outbound.OriginalTarget.Address != nil && outbound.OriginalTarget.Address != outbound.Target.Address {
link.Reader = &buf.EndpointOverrideReader{Reader: link.Reader, Dest: outbound.Target.Address, OriginalDest: outbound.OriginalTarget.Address}
link.Writer = &buf.EndpointOverrideWriter{Writer: link.Writer, Dest: outbound.Target.Address, OriginalDest: outbound.OriginalTarget.Address}
}
if h.mux != nil {
test := func(err error) {
if err != nil {
Expand All @@ -175,7 +181,6 @@ func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) {
common.Interrupt(link.Writer)
}
}
outbound := session.OutboundFromContext(ctx)
if outbound.Target.Network == net.Network_UDP && outbound.Target.Port == 443 {
switch h.udp443 {
case "reject":
Expand Down
38 changes: 38 additions & 0 deletions common/buf/override.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package buf

import (
"github.com/xtls/xray-core/common/net"
)

type EndpointOverrideReader struct {
Reader
Dest net.Address
OriginalDest net.Address
}

func (r *EndpointOverrideReader) ReadMultiBuffer() (MultiBuffer, error) {
mb, err := r.Reader.ReadMultiBuffer()
if err == nil {
for _, b := range mb {
if b.UDP != nil && b.UDP.Address == r.OriginalDest {
b.UDP.Address = r.Dest
}
}
}
return mb, err
}

type EndpointOverrideWriter struct {
Writer
Dest net.Address
OriginalDest net.Address
}

func (w *EndpointOverrideWriter) WriteMultiBuffer(mb MultiBuffer) error {
for _, b := range mb {
if b.UDP != nil && b.UDP.Address == w.Dest {
b.UDP.Address = w.OriginalDest
}
}
return w.Writer.WriteMultiBuffer(mb)
}
5 changes: 3 additions & 2 deletions common/session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,9 @@ type Inbound struct {
// Outbound is the metadata of an outbound connection.
type Outbound struct {
// Target address of the outbound connection.
Target net.Destination
RouteTarget net.Destination
OriginalTarget net.Destination
Target net.Destination
RouteTarget net.Destination
// Gateway address
Gateway net.Address
}
Expand Down
5 changes: 0 additions & 5 deletions transport/pipe/impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ const (
type pipeOption struct {
limit int32 // maximum buffer size in bytes
discardOverflow bool
onTransmission func(buffer buf.MultiBuffer) buf.MultiBuffer
}

func (o *pipeOption) isFull(curSize int32) bool {
Expand Down Expand Up @@ -141,10 +140,6 @@ func (p *pipe) WriteMultiBuffer(mb buf.MultiBuffer) error {
return nil
}

if p.option.onTransmission != nil {
mb = p.option.onTransmission(mb)
}

for {
err := p.writeMultiBufferInternal(mb)
if err == nil {
Expand Down
7 changes: 0 additions & 7 deletions transport/pipe/pipe.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package pipe
import (
"context"

"github.com/xtls/xray-core/common/buf"
"github.com/xtls/xray-core/common/signal"
"github.com/xtls/xray-core/common/signal/done"
"github.com/xtls/xray-core/features/policy"
Expand All @@ -26,12 +25,6 @@ func WithSizeLimit(limit int32) Option {
}
}

func OnTransmission(hook func(mb buf.MultiBuffer) buf.MultiBuffer) Option {
return func(option *pipeOption) {
option.onTransmission = hook
}
}

// DiscardOverflow returns an Option for Pipe to discard writes if full.
func DiscardOverflow() Option {
return func(opt *pipeOption) {
Expand Down

0 comments on commit b8bd243

Please sign in to comment.