diff --git a/packages/node/src/riblt/decoder.ts b/packages/node/src/riblt/decoder.ts index 1cb2bcb67..9393b3df7 100644 --- a/packages/node/src/riblt/decoder.ts +++ b/packages/node/src/riblt/decoder.ts @@ -1,106 +1,62 @@ -import type { CodedSymbol } from "./symbol.js"; +import type { SourceSymbol, CodedSymbol, SourceSymbolFactory } from "./symbol.js"; +import { RandomMapping } from "./mapping.js"; -export class Decoder> { - private cs: CodedSymbol[] = []; - private local: CodingWindow; - private window: CodingWindow; - private remote: CodingWindow; - private decodable: number[] = []; - private decoded: number = 0; +export class Decoder { + decodedSymbols: T[]; - constructor() { - this.local = new CodingWindow(); - this.window = new CodingWindow(); - this.remote = new CodingWindow(); + constructor(private readonly sourceSymbolFactory: SourceSymbolFactory) { + this.decodedSymbols = []; } - public decoded(): boolean { - return this.decoded === this.cs.length; - } - - public localSymbols(): HashedSymbol[] { - return this.local.symbols; - } - - public remoteSymbols(): HashedSymbol[] { - return this.remote.symbols; - } - - public addSymbol(s: T): void { - const th = new HashedSymbol(s, s.hash()); - this.addHashedSymbol(th); - } - - public addHashedSymbol(s: HashedSymbol): void { - this.window.addHashedSymbol(s); - } + tryDecode(local: CodedSymbol[], remote: CodedSymbol[]): boolean { + if (local.length !== remote.length) { + throw Error("The length of coded symbol sequences must be equal"); + } - public addCodedSymbol(c: CodedSymbol): void { - c = this.window.applyWindow(c, "remove"); - c = this.remote.applyWindow(c, "remove"); - c = this.local.applyWindow(c, "add"); - this.cs.push(c); - if ((c.count === 1 || c.count === -1) && c.hash === c.symbol.hash()) { - this.decodable.push(this.cs.length - 1); - } else if (c.count === 0 && c.hash === 0) { - this.decodable.push(this.cs.length - 1); + for (let i = 0; i < local.length; i++) { + local[i].xor(remote[i]); } - } - private applyNewSymbol(t: HashedSymbol, direction: number): RandomMapping { - const m = new RandomMapping(t.hash, 0); - while (m.lastIdx < this.cs.length) { - const cidx = m.lastIdx; - this.cs[cidx] = this.cs[cidx].apply(t, direction); - if ( - (this.cs[cidx].count === -1 || this.cs[cidx].count === 1) && - this.cs[cidx].hash === this.cs[cidx].symbol.hash() - ) { - this.decodable.push(cidx); + this.decodedSymbols = []; + const pureSymbols: CodedSymbol[] = []; + let remaining = 0; + const isDecoded = new Array(local.length).fill(false); + + for (let i = 0; i < local.length; i++) { + if (local[i].isZero()) { + isDecoded[i] = true; + } else { + remaining++; + if (local[i].isPure()) { + pureSymbols.push(local[i]); + } } - m.nextIndex(); } - return m; - } - public tryDecode(): void { - for (const didx of this.decodable) { - const cidx = this.decodable[didx]; - const c = this.cs[cidx]; - switch (c.count) { - case 1: - const ns1 = new HashedSymbol(); - ns1.symbol = ns1.symbol.xor(c.symbol); - ns1.hash = c.hash; - const m1 = this.applyNewSymbol(ns1, -1); - this.remote.addHashedSymbolWithMapping(ns1, m1); - this.decoded += 1; - break; - case -1: - const ns2 = new HashedSymbol(); - ns2.symbol = ns2.symbol.xor(c.symbol); - ns2.hash = c.hash; - const m2 = this.applyNewSymbol(ns2, 1); - this.local.addHashedSymbolWithMapping(ns2, m2); - this.decoded += 1; - break; - case 0: - this.decoded += 1; - break; - default: - throw new Error("Invalid degree for decodable coded symbol"); + while (pureSymbols.length > 0) { + const symbol = pureSymbols.pop() as CodedSymbol; + if (symbol.isZero()) { + continue; + } + this.decodedSymbols.push(this.sourceSymbolFactory.clone(symbol.sum)); + + const mapping = new RandomMapping(symbol.checksum, 0); + while (mapping.lastIdx < local.length) { + const idx = mapping.lastIdx; + if (isDecoded[idx]) { + continue; + } + local[idx].xor(symbol); + if (local[idx].isZero()) { + isDecoded[idx] = true; + remaining--; + } else if (local[idx].isPure()) { + pureSymbols.push(local[idx]); + } } } - this.decodable = []; - } - public reset(): void { - this.cs = []; - this.decodable = []; - this.local.reset(); - this.remote.reset(); - this.window.reset(); - this.decoded = 0; + return remaining === 0; } } diff --git a/packages/node/src/riblt/encoder.ts b/packages/node/src/riblt/encoder.ts index b7591baec..884751726 100644 --- a/packages/node/src/riblt/encoder.ts +++ b/packages/node/src/riblt/encoder.ts @@ -76,7 +76,7 @@ class MappingHeap { } } -class CodingPrefix { +export class CodingPrefix { private sourceSymbols: HashedSymbol[]; public codedSymbols: CodedSymbol[]; private mapGenerators: RandomMapping[]; @@ -105,12 +105,13 @@ class CodingPrefix { ): void { this.sourceSymbols.push(hashedSymbol); this.mapGenerators.push(mapping); + this.queue.push(new SymbolMapping(this.sourceSymbols.length - 1, mapping.lastIdx)); } extendPrefix(size: number): void { while (this.queue.size > 0 && this.queue.top.codedIdx < size) { const mapping = this.queue.pop(); - while (mapping !== undefined && mapping.codedIdx < size) { + while (mapping.codedIdx < size) { const sourceIdx = mapping.sourceIdx; const codedIdx = mapping.codedIdx; this.codedSymbols[codedIdx].apply(this.sourceSymbols[sourceIdx], 1); @@ -121,7 +122,7 @@ class CodingPrefix { } } -class Encoder extends CodingPrefix { +export class Encoder extends CodingPrefix { addSymbol(s: T): void { super.addSymbol(s); } diff --git a/packages/node/src/riblt/symbol.ts b/packages/node/src/riblt/symbol.ts index f85dc9bab..7dbfea978 100644 --- a/packages/node/src/riblt/symbol.ts +++ b/packages/node/src/riblt/symbol.ts @@ -1,56 +1,73 @@ export interface SourceSymbol { - xor(s: SourceSymbol): void; - hash(): Uint8Array; + xor(s: SourceSymbol): void; + hash(): Uint8Array; } export interface SourceSymbolFactory { - empty(): T; - emptyHash(): Uint8Array; - clone(s: T): T; + empty(): T; + emptyHash(): Uint8Array; + clone(s: T): T; } export class HashedSymbol { - sum: T; - checksum: Uint8Array; - - constructor(sum: T, checksum: Uint8Array) { - this.sum = sum; - this.checksum = checksum; - } - - xor(s: HashedSymbol) { - this.sum.xor(s.sum); - for (let i = 0; i < this.checksum.length; i++) { - this.checksum[i] ^= s.checksum[i]; - } - } - - isPure(): boolean { - const checksum = this.sum.hash(); - if (checksum.length !== this.checksum.length) { - return false; - } - for (let i = 0; i < checksum.length; i++) { - if (checksum[i] !== this.checksum[i]) { - return false; - } - } - return true; - } + sum: T; + checksum: Uint8Array; + + constructor(sum: T, checksum: Uint8Array) { + this.sum = sum; + this.checksum = checksum; + } + + xor(s: HashedSymbol): void { + this.sum.xor(s.sum); + for (let i = 0; i < this.checksum.length; i++) { + this.checksum[i] ^= s.checksum[i]; + } + } + + isPure(): boolean { + const checksum = this.sum.hash(); + if (checksum.length !== this.checksum.length) { + return false; + } + for (let i = 0; i < checksum.length; i++) { + if (checksum[i] !== this.checksum[i]) { + return false; + } + } + return true; + } } export class CodedSymbol extends HashedSymbol { - count: number; + count: number; + + constructor(sum: T, checksum: Uint8Array, count = 1) { + super(sum, checksum); + this.count = count; + } + + apply(s: HashedSymbol, direction: number) { + super.xor(s); + this.count += direction; + } - constructor(sum: T, checksum: Uint8Array, count = 1) { - super(sum, checksum); - this.count = count; - } + xor(s: CodedSymbol) { + super.xor(s); + this.count -= s.count; + } - apply(s: HashedSymbol, direction: number) { - this.xor(s); - this.count += direction; - } + isZero(): boolean { + if (this.count !== 0) { + return false; + } + for (let i = 0; i < this.checksum.length; i++) { + if (this.checksum[i] !== 0) { + return false; + } + } + return true; + } }