Skip to content

Commit

Permalink
discojs/Image: add full type
Browse files Browse the repository at this point in the history
  • Loading branch information
tharvik committed Aug 22, 2024
1 parent a85cb4d commit b596bb5
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 18 deletions.
12 changes: 5 additions & 7 deletions discojs-node/src/loaders/image.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,14 @@ import * as fs from "node:fs/promises";

import { Dataset, Image } from "@epfml/discojs";

export async function load(path: string): Promise<Image> {
const { data, info } = await sharp(path).removeAlpha().raw().toBuffer({
export async function load(path: string): Promise<Image<1 | 3 | 4>> {
const { data, info } = await sharp(path).raw().toBuffer({
resolveWithObject: true,
});

return {
data,
width: info.width,
height: info.height,
};
if (info.channels === 2) throw new Error("unsupported channel count");

return new Image<1 | 3 | 4>(data, info.width, info.height, info.channels);
}

export async function loadAllInDir(dir: string): Promise<Dataset<Image>> {
Expand Down
6 changes: 3 additions & 3 deletions discojs-web/src/loaders/image.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import type { Image as DiscoImage } from "@epfml/discojs";
import { Image as DiscoImage } from "@epfml/discojs";

export async function load(file: Blob): Promise<DiscoImage> {
export async function load(file: Blob): Promise<DiscoImage<4>> {
const image = new Image();
const url = URL.createObjectURL(file);
image.src = url;
Expand All @@ -14,5 +14,5 @@ export async function load(file: Blob): Promise<DiscoImage> {
context.drawImage(image, 0, 0);
const data = new Uint8Array(context.getImageData(0, 0, width, height).data);

return { width, height, data };
return new DiscoImage(data, width, height, 4);
}
59 changes: 58 additions & 1 deletion discojs/src/convertors.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { PreTrainedTokenizer } from "@xenova/transformers";
import { List } from "immutable";
import { List, Repeat, Seq } from "immutable";
import { Image } from "./dataset/image.js";

export function convertToNumber(raw: string): number {
const num = Number.parseFloat(raw);
Expand Down Expand Up @@ -67,3 +68,59 @@ export function tokenizeAndLeftPad(

return padded;
}

/**
 * Return an image without a fourth (alpha) channel.
 *
 * - depth 4: every fourth byte of the data is dropped and the result has depth 3
 * - depth 1 or 3: a new image wrapping the same data and depth is returned
 *
 * @param image image of depth 1, 3 or 4
 * @returns image of depth 1 or 3 with the same width and height
 */
export function removeAlpha<W extends number, H extends number>(
  image: Image<4, W, H>,
): Image<3, W, H>;
export function removeAlpha<
  D extends 1 | 3,
  W extends number,
  H extends number,
>(image: Image<D | 4, W, H>): Image<D, W, H>;
export function removeAlpha<W extends number, H extends number>(
  image: Image<1 | 3 | 4, W, H>,
): Image<1 | 3, W, H> {
  // no fourth channel: hand back the same data and depth
  if (image.depth !== 4)
    return new Image(image.data, image.width, image.height, image.depth);

  // keep the first three bytes of every group of four
  const withoutFourth = image.data.filter((_, index) => index % 4 !== 3);
  return new Image(withoutFourth, image.width, image.height, 3);
}

/**
 * Ensure an image has more than one channel.
 *
 * - depth 1: every value is repeated three times, giving an image of depth 3
 * - depth 3 or 4: a new image wrapping the same data and depth is returned
 *
 * @param image image of depth 1, 3 or 4
 * @returns image of depth 3 or 4 with the same width and height
 */
export function expandToMulticolor<W extends number, H extends number>(
  image: Image<1, W, H>,
): Image<3, W, H>;
export function expandToMulticolor<
  D extends 3 | 4,
  W extends number,
  H extends number,
>(image: Image<1 | D, W, H>): Image<D, W, H>;
export function expandToMulticolor<W extends number, H extends number>(
  image: Image<1 | 3 | 4, W, H>,
): Image<3 | 4, W, H> {
  switch (image.depth) {
    case 1:
      return new Image(
        // duplicate the single value into three identical channels
        Uint8Array.from(Seq(image.data).flatMap((v) => Repeat(v, 3))),
        image.width,
        image.height,
        3,
      );
    case 3:
    case 4:
      // Already multicolor: keep the data untouched. Dropping bytes while
      // keeping depth 4 would fail Image's width * height * depth size check.
      return new Image(image.data, image.width, image.height, image.depth);
  }
}
16 changes: 11 additions & 5 deletions discojs/src/dataset/data/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ async function intoTFDataset<T extends tf.TensorContainer>(
return tf.data.array(materialized);
}

function imageToTensor(image: Image): tf.Tensor3D {
function imageToTensor(image: Image<3>): tf.Tensor3D {
return tf.tensor3d(image.data, [image.width, image.height, 3], "int32");
}

Expand All @@ -48,9 +48,12 @@ export async function datasetToData(
): Promise<Data> {
switch (t) {
case "image": {
const converted = dataset.map((image) => ({
xs: imageToTensor(image),
}));
const converted = dataset
.map(convertors.removeAlpha)
.map((image) => convertors.expandToMulticolor(image))
.map((image) => ({
xs: imageToTensor(image),
}));
return await ImageData.init(await intoTFDataset(converted), task);
}
case "tabular": {
Expand Down Expand Up @@ -78,7 +81,10 @@ export async function labeledDatasetToData(
const converted = dataset
.map(
([image, label]) =>
[image, convertors.indexInList(label, labels)] as const,
[
convertors.expandToMulticolor(convertors.removeAlpha(image)),
convertors.indexInList(label, labels),
] as const,
)
.map(
([image, label]) =>
Expand Down
15 changes: 15 additions & 0 deletions discojs/src/dataset/image.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/**
 * Raw image: a byte buffer together with its dimensions.
 *
 * The type parameters carry the depth (number of channels: 1, 3 or 4),
 * width and height so that they can be checked at compile time.
 */
export class Image<
  D extends 1 | 3 | 4 = 1 | 3 | 4,
  W extends number = number,
  H extends number = number,
> {
  /**
   * @param data bytes of the image, exactly `width * height * depth` long
   * @param width width of the image
   * @param height height of the image
   * @param depth number of channels
   * @throws {Error} if `data.length` differs from `width * height * depth`
   */
  constructor(
    public readonly data: Readonly<Uint8Array>,
    public readonly width: W,
    public readonly height: H,
    public readonly depth: D,
  ) {
    if (data.length !== width * height * depth)
      throw new Error("data isn't of expected size");
  }
}
5 changes: 3 additions & 2 deletions discojs/src/dataset/types.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// TODO use full type to check shape of array
export type Image = { width: number; height: number; data: Uint8Array };
import { Image } from "./image.js"

export { Image };
export type Tabular = Partial<Record<string, string>>;
export type Text = string;

0 comments on commit b596bb5

Please sign in to comment.