diff --git a/tests/waku_store_sync/sync_utils.nim b/tests/waku_store_sync/sync_utils.nim
index aa56ff2e51..1dea8edfbc 100644
--- a/tests/waku_store_sync/sync_utils.nim
+++ b/tests/waku_store_sync/sync_utils.nim
@@ -25,10 +25,12 @@ proc newTestWakuRecon*(
     idsRx: AsyncQueue[SyncID],
     wantsTx: AsyncQueue[(PeerId, Fingerprint)],
     needsTx: AsyncQueue[(PeerId, Fingerprint)],
+    shards: seq[uint16] = @[0, 1, 2, 3, 4, 5, 6, 7],
 ): Future[SyncReconciliation] {.async.} =
   let peerManager = PeerManager.new(switch)
 
   let res = await SyncReconciliation.new(
+    shards = shards,
     peerManager = peerManager,
     wakuArchive = nil,
     relayJitter = 0.seconds,
diff --git a/tests/waku_store_sync/test_protocol.nim b/tests/waku_store_sync/test_protocol.nim
index d3ffa187f3..af921a08a9 100644
--- a/tests/waku_store_sync/test_protocol.nim
+++ b/tests/waku_store_sync/test_protocol.nim
@@ -157,6 +157,40 @@ suite "Waku Sync: reconciliation":
       localWants.contains((clientPeerInfo.peerId, hash3)) == true
       localWants.contains((serverPeerInfo.peerId, hash2)) == true
 
+  asyncTest "sync 2 nodes different shards":
+    let
+      msg1 = fakeWakuMessage(ts = now(), contentTopic = DefaultContentTopic)
+      msg2 = fakeWakuMessage(ts = now() + 1, contentTopic = DefaultContentTopic)
+      msg3 = fakeWakuMessage(ts = now() + 2, contentTopic = DefaultContentTopic)
+      hash1 = computeMessageHash(DefaultPubsubTopic, msg1)
+      hash2 = computeMessageHash(DefaultPubsubTopic, msg2)
+      hash3 = computeMessageHash(DefaultPubsubTopic, msg3)
+
+    server.messageIngress(hash1, msg1)
+    server.messageIngress(hash2, msg2)
+    client.messageIngress(hash1, msg1)
+    client.messageIngress(hash3, msg3)
+
+    check:
+      remoteNeeds.contains((serverPeerInfo.peerId, hash3)) == false
+      remoteNeeds.contains((clientPeerInfo.peerId, hash2)) == false
+      localWants.contains((clientPeerInfo.peerId, hash3)) == false
+      localWants.contains((serverPeerInfo.peerId, hash2)) == false
+
+    server = await newTestWakuRecon(
+      serverSwitch, idsChannel, localWants, remoteNeeds, @[0.uint16, 1, 2, 3]
+    )
+    client = await newTestWakuRecon(
+      clientSwitch, idsChannel, localWants, remoteNeeds, @[4.uint16, 5, 6, 7]
+    )
+
+    var syncRes = await client.storeSynchronization(some(serverPeerInfo))
+    assert syncRes.isOk(), $syncRes.error
+
+    check:
+      remoteNeeds.len == 0
+      localWants.len == 0
+
   asyncTest "sync 2 nodes same hashes":
     let
       msg1 = fakeWakuMessage(ts = now(), contentTopic = DefaultContentTopic)
diff --git a/waku/node/waku_node.nim b/waku/node/waku_node.nim
index 6dc68ea2bf..a7dc3d200f 100644
--- a/waku/node/waku_node.nim
+++ b/waku/node/waku_node.nim
@@ -216,9 +216,17 @@ proc mountStoreSync*(
   let wantsChannel = newAsyncQueue[(PeerId, WakuMessageHash)](100)
   let needsChannel = newAsyncQueue[(PeerId, WakuMessageHash)](100)
 
+  var shards: seq[uint16]
+  let enrRes = node.enr.toTyped()
+  if enrRes.isOk():
+    let shardingRes = enrRes.get().relaySharding()
+    if shardingRes.isSome():
+      let relayShard = shardingRes.get()
+      shards = relayShard.shardIds
+
   let recon =
     ?await SyncReconciliation.new(
-      node.peerManager, node.wakuArchive, storeSyncRange.seconds,
+      shards, node.peerManager, node.wakuArchive, storeSyncRange.seconds,
       storeSyncInterval.seconds, storeSyncRelayJitter.seconds, idsChannel,
       wantsChannel, needsChannel,
     )
diff --git a/waku/waku_store_sync/codec.nim b/waku/waku_store_sync/codec.nim
index ee0b926a3a..254b915af7 100644
--- a/waku/waku_store_sync/codec.nim
+++ b/waku/waku_store_sync/codec.nim
@@ -52,6 +52,14 @@ proc deltaEncode*(value: RangesData): seq[byte] =
     i = 0
     j = 0
 
+  # encode shards
+  buf = uint64(value.shards.len).toBytes(Leb128)
+  output &= @buf
+
+  for shard in value.shards:
+    buf = uint64(shard).toBytes(Leb128)
+    output &= @buf
+
   # the first range is implicit but must be explicit when encoded
   let (bound, _) = value.ranges[0]
 
@@ -209,6 +217,28 @@ proc getReconciled(idx: var int, buffer: seq[byte]): Result[bool, string] =
 
   return ok(recon)
 
+proc getShards(idx: var int, buffer: seq[byte]): Result[seq[uint16], string] =
+  if idx + VarIntLen > buffer.len:
+    return err("Cannot decode shards")
+
+  let slice = buffer[idx ..< idx + VarIntLen]
+  let (val, len) = uint64.fromBytes(slice, Leb128)
+  idx += len
+  let shardsLen = val
+
+  var shards: seq[uint16]
+  for _ in 0 ..< shardsLen:
+    if idx + VarIntLen > buffer.len:
+      return err("Cannot decode shards")
+
+    let slice = buffer[idx ..< idx + VarIntLen]
+    let (val, len) = uint64.fromBytes(slice, Leb128)
+    idx += len
+
+    shards.add(uint16(val))
+
+  return ok(shards)
+
 proc deltaDecode*(
     itemSet: var ItemSet, buffer: seq[byte], setLength: int
 ): Result[int, string] =
@@ -242,7 +272,7 @@ proc getItemSet(
   return ok(itemSet)
 
 proc deltaDecode*(T: type RangesData, buffer: seq[byte]): Result[T, string] =
-  if buffer.len == 1:
+  if buffer.len <= 1:
     return ok(RangesData())
 
   var
@@ -250,6 +280,8 @@ proc deltaDecode*(T: type RangesData, buffer: seq[byte]): Result[T, string] =
     lastTime = Timestamp(0)
     idx = 0
 
+  payload.shards = ?getShards(idx, buffer)
+
   lastTime = ?getTimestamp(idx, buffer)
 
   # implicit first hash is always 0
diff --git a/waku/waku_store_sync/common.nim b/waku/waku_store_sync/common.nim
index 2795450786..a45330ea75 100644
--- a/waku/waku_store_sync/common.nim
+++ b/waku/waku_store_sync/common.nim
@@ -26,6 +26,8 @@ type
     ItemSet = 2
 
   RangesData* = object
+    shards*: seq[uint16]
+
     ranges*: seq[(Slice[SyncID], RangeType)]
     fingerprints*: seq[Fingerprint] # Range type fingerprint stored here in order
     itemSets*: seq[ItemSet] # Range type itemset stored here in order
diff --git a/waku/waku_store_sync/reconciliation.nim b/waku/waku_store_sync/reconciliation.nim
index 5ad6260c9d..ef1a9aea10 100644
--- a/waku/waku_store_sync/reconciliation.nim
+++ b/waku/waku_store_sync/reconciliation.nim
@@ -1,7 +1,7 @@
 {.push raises: [].}
 
 import
-  std/sequtils,
+  std/[sequtils, packedsets],
   stew/byteutils,
   results,
   chronicles,
@@ -37,6 +37,8 @@ logScope:
 const DefaultStorageCap = 50_000
 
 type SyncReconciliation* = ref object of LPProtocol
+  shards: PackedSet[uint16]
+
   peerManager: PeerManager
 
   wakuArchive: WakuArchive
@@ -114,16 +116,22 @@ proc processRequest(
   var
     hashToRecv: seq[WakuMessageHash]
     hashToSend: seq[WakuMessageHash]
+    sendPayload: RangesData
+    rawPayload: seq[byte]
+
+  # Only process the ranges IF the shards matches
+  if recvPayload.shards.toPackedSet() == self.shards:
+    sendPayload = self.storage.processPayload(recvPayload, hashToSend, hashToRecv)
 
-  let sendPayload = self.storage.processPayload(recvPayload, hashToSend, hashToRecv)
+    sendPayload.shards = self.shards.toSeq()
 
-  for hash in hashToSend:
-    await self.remoteNeedsTx.addLast((conn.peerId, hash))
+    for hash in hashToSend:
+      await self.remoteNeedsTx.addLast((conn.peerId, hash))
 
-  for hash in hashToRecv:
-    await self.localWantstx.addLast((conn.peerId, hash))
+    for hash in hashToRecv:
+      await self.localWantstx.addLast((conn.peerId, hash))
 
-  let rawPayload = sendPayload.deltaEncode()
+  rawPayload = sendPayload.deltaEncode()
 
   total_bytes_exchanged.observe(
     rawPayload.len, labelValues = [Reconciliation, Sending]
@@ -162,6 +170,7 @@ proc initiate(
       fingerprint = self.storage.computeFingerprint(bounds)
       initPayload = RangesData(
+        shards: self.shards.toSeq(),
         ranges: @[(bounds, RangeType.Fingerprint)],
         fingerprints: @[fingerprint],
         itemSets: @[],
       )
@@ -261,6 +270,7 @@ proc initFillStorage(
 
 proc new*(
     T: type SyncReconciliation,
+    shards: seq[uint16],
     peerManager: PeerManager,
     wakuArchive: WakuArchive,
     syncRange: timer.Duration = DefaultSyncRange,
@@ -279,6 +289,7 @@ proc new*(
       SeqStorage.new(res.get())
 
   var sync = SyncReconciliation(
+    shards: shards.toPackedSet(),
     peerManager: peerManager,
     storage: storage,
     syncRange: syncRange,