diff --git a/dht.go b/dht.go index 95f028ec..0a6f2ecb 100644 --- a/dht.go +++ b/dht.go @@ -20,7 +20,6 @@ import ( "github.com/libp2p/go-libp2p-kad-dht/internal" dhtcfg "github.com/libp2p/go-libp2p-kad-dht/internal/config" - "github.com/libp2p/go-libp2p-kad-dht/internal/net" "github.com/libp2p/go-libp2p-kad-dht/metrics" "github.com/libp2p/go-libp2p-kad-dht/netsize" pb "github.com/libp2p/go-libp2p-kad-dht/pb" @@ -206,7 +205,7 @@ func New(ctx context.Context, h host.Host, options ...Option) (*IpfsDHT, error) dht.disableFixLowPeers = cfg.DisableFixLowPeers dht.Validator = cfg.Validator - dht.msgSender = net.NewMessageSenderImpl(h, dht.protocols) + dht.msgSender = cfg.MsgSenderBuilder(h, dht.protocols) dht.protoMessenger, err = pb.NewProtocolMessenger(dht.msgSender) if err != nil { return nil, err diff --git a/dht_options.go b/dht_options.go index 250e4ca6..4be746ff 100644 --- a/dht_options.go +++ b/dht_options.go @@ -6,9 +6,11 @@ import ( "time" dhtcfg "github.com/libp2p/go-libp2p-kad-dht/internal/config" + pb "github.com/libp2p/go-libp2p-kad-dht/pb" "github.com/libp2p/go-libp2p-kad-dht/providers" "github.com/libp2p/go-libp2p-kbucket/peerdiversity" record "github.com/libp2p/go-libp2p-record" + "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/protocol" @@ -356,3 +358,12 @@ func AddressFilter(f func([]ma.Multiaddr) []ma.Multiaddr) Option { return nil } } + +// WithCustomMessageSender configures the pb.MessageSender of the IpfsDHT to use the +// custom implementation of the pb.MessageSender +func WithCustomMessageSender(messageSenderBuilder func(h host.Host, protos []protocol.ID) pb.MessageSenderWithDisconnect) Option { + return func(c *dhtcfg.Config) error { + c.MsgSenderBuilder = messageSenderBuilder + return nil + } +} diff --git a/internal/config/config.go b/internal/config/config.go index bacd2e4d..d9a5794c 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -7,6 +7,8 @@ import ( "github.com/ipfs/boxo/ipns" ds "github.com/ipfs/go-datastore" dssync "github.com/ipfs/go-datastore/sync" + "github.com/libp2p/go-libp2p-kad-dht/internal/net" + pb "github.com/libp2p/go-libp2p-kad-dht/pb" "github.com/libp2p/go-libp2p-kad-dht/providers" "github.com/libp2p/go-libp2p-kbucket/peerdiversity" record "github.com/libp2p/go-libp2p-record" @@ -48,6 +50,7 @@ type Config struct { ProviderStore providers.ProviderStore QueryPeerFilter QueryFilterFunc LookupCheckConcurrency int + MsgSenderBuilder func(h host.Host, protos []protocol.ID) pb.MessageSenderWithDisconnect RoutingTable struct { RefreshQueryTimeout time.Duration @@ -114,6 +117,7 @@ var Defaults = func(o *Config) error { o.EnableProviders = true o.EnableValues = true o.QueryPeerFilter = EmptyQueryFilter + o.MsgSenderBuilder = net.NewMessageSenderImpl o.RoutingTable.LatencyTolerance = 10 * time.Second o.RoutingTable.RefreshQueryTimeout = 10 * time.Second