Skip to content

Commit

Permalink
Implement decoder
Browse files Browse the repository at this point in the history
  • Loading branch information
magnified103 committed Nov 11, 2024
1 parent 61f246b commit e5a0ad5
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 134 deletions.
136 changes: 46 additions & 90 deletions packages/node/src/riblt/decoder.ts
Original file line number Diff line number Diff line change
@@ -1,106 +1,62 @@
import type { CodedSymbol } from "./symbol.js";
import type { SourceSymbol, CodedSymbol, SourceSymbolFactory } from "./symbol.js";
import { RandomMapping } from "./mapping.js";


export class Decoder<T extends Symbol<T>> {
private cs: CodedSymbol<T>[] = [];
private local: CodingWindow<T>;
private window: CodingWindow<T>;
private remote: CodingWindow<T>;
private decodable: number[] = [];
private decoded: number = 0;
export class Decoder<T extends SourceSymbol> {
decodedSymbols: T[];

constructor() {
this.local = new CodingWindow<T>();
this.window = new CodingWindow<T>();
this.remote = new CodingWindow<T>();
constructor(private readonly sourceSymbolFactory: SourceSymbolFactory<T>) {
this.decodedSymbols = [];
}

public decoded(): boolean {
return this.decoded === this.cs.length;
}

public localSymbols(): HashedSymbol<T>[] {
return this.local.symbols;
}

public remoteSymbols(): HashedSymbol<T>[] {
return this.remote.symbols;
}

public addSymbol(s: T): void {
const th = new HashedSymbol<T>(s, s.hash());
this.addHashedSymbol(th);
}

public addHashedSymbol(s: HashedSymbol<T>): void {
this.window.addHashedSymbol(s);
}
tryDecode(local: CodedSymbol<T>[], remote: CodedSymbol<T>[]): boolean {
if (local.length !== remote.length) {
throw Error("The length of coded symbol sequences must be equal");
}

public addCodedSymbol(c: CodedSymbol<T>): void {
c = this.window.applyWindow(c, "remove");
c = this.remote.applyWindow(c, "remove");
c = this.local.applyWindow(c, "add");
this.cs.push(c);
if ((c.count === 1 || c.count === -1) && c.hash === c.symbol.hash()) {
this.decodable.push(this.cs.length - 1);
} else if (c.count === 0 && c.hash === 0) {
this.decodable.push(this.cs.length - 1);
for (let i = 0; i < local.length; i++) {
local[i].xor(remote[i]);
}
}

private applyNewSymbol(t: HashedSymbol<T>, direction: number): RandomMapping {
const m = new RandomMapping(t.hash, 0);
while (m.lastIdx < this.cs.length) {
const cidx = m.lastIdx;
this.cs[cidx] = this.cs[cidx].apply(t, direction);
if (
(this.cs[cidx].count === -1 || this.cs[cidx].count === 1) &&
this.cs[cidx].hash === this.cs[cidx].symbol.hash()
) {
this.decodable.push(cidx);
this.decodedSymbols = [];
const pureSymbols: CodedSymbol<T>[] = [];
let remaining = 0;
const isDecoded = new Array(local.length).fill(false);

for (let i = 0; i < local.length; i++) {
if (local[i].isZero()) {
isDecoded[i] = true;
} else {
remaining++;
if (local[i].isPure()) {
pureSymbols.push(local[i]);
}
}
m.nextIndex();
}
return m;
}

public tryDecode(): void {
for (const didx of this.decodable) {
const cidx = this.decodable[didx];
const c = this.cs[cidx];
switch (c.count) {
case 1:
const ns1 = new HashedSymbol<T>();
ns1.symbol = ns1.symbol.xor(c.symbol);
ns1.hash = c.hash;
const m1 = this.applyNewSymbol(ns1, -1);
this.remote.addHashedSymbolWithMapping(ns1, m1);
this.decoded += 1;
break;
case -1:
const ns2 = new HashedSymbol<T>();
ns2.symbol = ns2.symbol.xor(c.symbol);
ns2.hash = c.hash;
const m2 = this.applyNewSymbol(ns2, 1);
this.local.addHashedSymbolWithMapping(ns2, m2);
this.decoded += 1;
break;
case 0:
this.decoded += 1;
break;
default:
throw new Error("Invalid degree for decodable coded symbol");
while (pureSymbols.length > 0) {
const symbol = pureSymbols.pop() as CodedSymbol<T>;
if (symbol.isZero()) {
continue;
}
this.decodedSymbols.push(this.sourceSymbolFactory.clone(symbol.sum));

const mapping = new RandomMapping(symbol.checksum, 0);
while (mapping.lastIdx < local.length) {
const idx = mapping.lastIdx;
if (isDecoded[idx]) {
continue;
}
local[idx].xor(symbol);
if (local[idx].isZero()) {
isDecoded[idx] = true;
remaining--;
} else if (local[idx].isPure()) {
pureSymbols.push(local[idx]);
}
}
}
this.decodable = [];
}

public reset(): void {
this.cs = [];
this.decodable = [];
this.local.reset();
this.remote.reset();
this.window.reset();
this.decoded = 0;
return remaining === 0;
}
}
7 changes: 4 additions & 3 deletions packages/node/src/riblt/encoder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class MappingHeap {
}
}

class CodingPrefix<T extends SourceSymbol> {
export class CodingPrefix<T extends SourceSymbol> {
private sourceSymbols: HashedSymbol<T>[];
public codedSymbols: CodedSymbol<T>[];
private mapGenerators: RandomMapping[];
Expand Down Expand Up @@ -105,12 +105,13 @@ class CodingPrefix<T extends SourceSymbol> {
): void {
this.sourceSymbols.push(hashedSymbol);
this.mapGenerators.push(mapping);
this.queue.push(new SymbolMapping(this.sourceSymbols.length - 1, mapping.lastIdx));
}

extendPrefix(size: number): void {
while (this.queue.size > 0 && this.queue.top.codedIdx < size) {
const mapping = this.queue.pop();
while (mapping !== undefined && mapping.codedIdx < size) {
while (mapping.codedIdx < size) {
const sourceIdx = mapping.sourceIdx;
const codedIdx = mapping.codedIdx;
this.codedSymbols[codedIdx].apply(this.sourceSymbols[sourceIdx], 1);
Expand All @@ -121,7 +122,7 @@ class CodingPrefix<T extends SourceSymbol> {
}
}

class Encoder<T extends SourceSymbol> extends CodingPrefix<T> {
export class Encoder<T extends SourceSymbol> extends CodingPrefix<T> {
addSymbol(s: T): void {
super.addSymbol(s);
}
Expand Down
99 changes: 58 additions & 41 deletions packages/node/src/riblt/symbol.ts
Original file line number Diff line number Diff line change
@@ -1,56 +1,73 @@
export interface SourceSymbol {
xor(s: SourceSymbol): void;
hash(): Uint8Array;
xor(s: SourceSymbol): void;
hash(): Uint8Array;
}


export interface SourceSymbolFactory<T> {
empty(): T;
emptyHash(): Uint8Array;
clone(s: T): T;
empty(): T;
emptyHash(): Uint8Array;
clone(s: T): T;
}


export class HashedSymbol<T extends SourceSymbol> {
sum: T;
checksum: Uint8Array;

constructor(sum: T, checksum: Uint8Array) {
this.sum = sum;
this.checksum = checksum;
}

xor(s: HashedSymbol<T>) {
this.sum.xor(s.sum);
for (let i = 0; i < this.checksum.length; i++) {
this.checksum[i] ^= s.checksum[i];
}
}

isPure(): boolean {
const checksum = this.sum.hash();
if (checksum.length !== this.checksum.length) {
return false;
}
for (let i = 0; i < checksum.length; i++) {
if (checksum[i] !== this.checksum[i]) {
return false;
}
}
return true;
}
sum: T;
checksum: Uint8Array;

constructor(sum: T, checksum: Uint8Array) {
this.sum = sum;
this.checksum = checksum;
}

xor(s: HashedSymbol<T>): void {
this.sum.xor(s.sum);
for (let i = 0; i < this.checksum.length; i++) {
this.checksum[i] ^= s.checksum[i];
}
}

isPure(): boolean {
const checksum = this.sum.hash();
if (checksum.length !== this.checksum.length) {
return false;
}
for (let i = 0; i < checksum.length; i++) {
if (checksum[i] !== this.checksum[i]) {
return false;
}
}
return true;
}
}

export class CodedSymbol<T extends SourceSymbol> extends HashedSymbol<T> {
count: number;
count: number;

constructor(sum: T, checksum: Uint8Array, count = 1) {
super(sum, checksum);
this.count = count;
}

apply(s: HashedSymbol<T>, direction: number) {
super.xor(s);
this.count += direction;
}

constructor(sum: T, checksum: Uint8Array, count = 1) {
super(sum, checksum);
this.count = count;
}
xor(s: CodedSymbol<T>) {
super.xor(s);
this.count -= s.count;
}

apply(s: HashedSymbol<T>, direction: number) {
this.xor(s);
this.count += direction;
}
isZero(): boolean {
if (this.count !== 0) {
return false;
}
for (let i = 0; i < this.checksum.length; i++) {
if (this.checksum[i] !== 0) {
return false;
}
}
return true;
}
}

0 comments on commit e5a0ad5

Please sign in to comment.