Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: add wait endpoint #52

Merged
merged 3 commits into from
Feb 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "vlmrun",
"version": "0.2.0",
"version": "0.2.1",
"description": "The official TypeScript library for the VlmRun API",
"author": "VlmRun <[email protected]>",
"main": "dist/index.js",
Expand Down
8 changes: 8 additions & 0 deletions src/client/files.ts
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,14 @@ export class Files {
return response;
}

async get(fileId: string): Promise<FileResponse> {
const [response] = await this.requestor.request<FileResponse>(
"GET",
`files/${fileId}`
);
return response;
}

async delete(fileId: string): Promise<void> {
await this.requestor.request<void>("DELETE", `files/${fileId}`);
}
Expand Down
33 changes: 31 additions & 2 deletions src/client/predictions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,42 @@ export class Predictions {
return response;
}

async get(params: { id: string }): Promise<PredictionResponse> {
async get(id: string): Promise<PredictionResponse> {
const [response] = await this.requestor.request<PredictionResponse>(
"GET",
`predictions/${params.id}`
`predictions/${id}`
);
return response;
}

/**
* Wait for prediction to complete
* @param params.id - ID of prediction to wait for
* @param params.timeout - Timeout in seconds (default: 60)
* @param params.sleep - Sleep time in seconds (default: 1)
* @returns Promise containing the prediction response
* @throws TimeoutError if prediction doesn't complete within timeout
*/
async wait(
id: string,
timeout: number = 60,
sleep: number = 1
): Promise<PredictionResponse> {
const startTime = Date.now();
const timeoutMs = timeout * 1000;

while (Date.now() - startTime < timeoutMs) {
const response = await this.get(id);
if (response.status === "completed") {
return response;
}
await new Promise((resolve) => setTimeout(resolve, sleep * 1000));
}

throw new Error(
`Prediction ${id} did not complete within ${timeout} seconds`
);
}
}

export class ImagePredictions extends Predictions {
Expand Down
69 changes: 46 additions & 23 deletions tests/integration/client/files.test.ts
Original file line number Diff line number Diff line change
@@ -1,38 +1,38 @@
import { config } from 'dotenv';
config({ path: '.env.test' });
import { config } from "dotenv";
config({ path: ".env.test" });

import { VlmRun } from '../../../src/index';
import { FileResponse, FilePurpose } from '../../../src/client/types';
import { VlmRun } from "../../../src/index";
import { FileResponse, FilePurpose } from "../../../src/client/types";

jest.setTimeout(60000);

describe('Integration: Files', () => {
describe("Integration: Files", () => {
let client: VlmRun;

beforeAll(() => {
client = new VlmRun({
apiKey: process.env.TEST_API_KEY ?? '',
baseURL: process.env.TEST_BASE_URL ?? '',
apiKey: process.env.TEST_API_KEY ?? "",
baseURL: process.env.TEST_BASE_URL ?? "",
});
});

describe('list', () => {
it('should list files with default pagination', async () => {
describe("list", () => {
it("should list files with default pagination", async () => {
const result = await client.files.list({});
expect(Array.isArray(result)).toBe(true);

if (result.length > 0) {
const file: FileResponse = result[0];
expect(file).toHaveProperty('id');
expect(file).toHaveProperty('filename');
expect(file).toHaveProperty('bytes');
expect(file).toHaveProperty('purpose');
expect(file).toHaveProperty('created_at');
expect(file).toHaveProperty('object');
expect(file).toHaveProperty("id");
expect(file).toHaveProperty("filename");
expect(file).toHaveProperty("bytes");
expect(file).toHaveProperty("purpose");
expect(file).toHaveProperty("created_at");
expect(file).toHaveProperty("object");
}
});

it('should list files with custom pagination', async () => {
it("should list files with custom pagination", async () => {
const skip = 0;
const limit = 5;
const result = await client.files.list({ skip, limit });
Expand All @@ -41,22 +41,45 @@ describe('Integration: Files', () => {
});
});

describe('upload', () => {
const testFilePath = 'tests/integration/assets/google_invoice.pdf';
describe("upload", () => {
const testFilePath = "tests/integration/assets/google_invoice.pdf";

it('should return existing file if found and checkDuplicate is true', async () => {
it("should upload file and get file details", async () => {
const result = await client.files.upload({
filePath: testFilePath,
purpose: 'vision',
purpose: "vision",
checkDuplicate: true,
});

expect(result.id).toBeTruthy();
expect(result.filename).toBe('google_invoice.pdf');
expect(result.filename).toBe("google_invoice.pdf");
expect(result.created_at).toBeTruthy();
expect(result.object).toBe('file');
expect(result.object).toBe("file");
expect(result.bytes).toBeTruthy();
expect(result.purpose).toBe('vision' as FilePurpose);
expect(result.purpose).toBe("vision" as FilePurpose);

// Test get endpoint
const getResponse = await client.files.get(result.id);
expect(getResponse.id).toBe(result.id);
expect(getResponse).toHaveProperty("filename");
expect(getResponse).toHaveProperty("created_at");
expect(getResponse).toHaveProperty("object");
expect(getResponse).toHaveProperty("bytes");
});

it("should return existing file if found and checkDuplicate is true", async () => {
const result = await client.files.upload({
filePath: testFilePath,
purpose: "vision",
checkDuplicate: true,
});

expect(result.id).toBeTruthy();
expect(result.filename).toBe("google_invoice.pdf");
expect(result.created_at).toBeTruthy();
expect(result.object).toBe("file");
expect(result.bytes).toBeTruthy();
expect(result.purpose).toBe("vision" as FilePurpose);
});
});
});
47 changes: 47 additions & 0 deletions tests/integration/client/predictions.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -226,5 +226,52 @@ describe("Integration: Predictions", () => {
expect(result.response).not.toHaveProperty("customer_billing_address");
expect(result.response).not.toHaveProperty("customer_shipping_address");
});

it("should generate document predictions when batch is true using url from custom zod schema", async () => {
const documentUrl =
"https://storage.googleapis.com/vlm-data-public-prod/hub/examples/document.invoice/google_invoice.pdf";

const schema = z.object({
invoice_id: z.string(),
total: z.number(),
sub_total: z.number(),
tax: z.number(),
items: z.array(
z.object({
name: z.string(),
quantity: z.number(),
price: z.number(),
total: z.number(),
})
),
});

const result = await client.document.generate({
url: documentUrl,
domain: "document.invoice",
batch: true,
config: {
jsonSchema: schema,
},
});

expect(result.status).toBe("pending");

const waitResponse = await client.predictions.wait(result.id);
const response = waitResponse.response as z.infer<typeof schema>;

console.log(waitResponse);
expect(waitResponse.status).toBe("completed");
expect(waitResponse.response).toHaveProperty("invoice_id");
expect(waitResponse.response).toHaveProperty("total");
expect(waitResponse.response).toHaveProperty("sub_total");
expect(waitResponse.response).toHaveProperty("tax");
expect(waitResponse.response).toHaveProperty("items");

// Test get endpoint
const getResponse = await client.predictions.get(result.id);
expect(getResponse.status).toBe("completed");
expect(getResponse.response).toHaveProperty("invoice_id");
});
});
});