From 0442ebd2f04c4676c8ac7426d44e0d6e8af87acd Mon Sep 17 00:00:00 2001 From: shahrear33 Date: Mon, 10 Feb 2025 22:55:11 +0600 Subject: [PATCH] feat: added web predictions --- package.json | 2 +- src/client/predictions.ts | 31 ++++++++ src/client/types.ts | 8 +- src/index.ts | 3 + tests/unit/client/predictions.test.ts | 101 ++++++++++++++++++++++++++ 5 files changed, 143 insertions(+), 2 deletions(-) diff --git a/package.json b/package.json index 983ed8f..c6a222b 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "vlmrun", - "version": "0.1.12", + "version": "0.1.13", "description": "The official TypeScript library for the VlmRun API", "author": "VlmRun ", "main": "dist/index.js", diff --git a/src/client/predictions.ts b/src/client/predictions.ts index e86221f..e9a39f6 100644 --- a/src/client/predictions.ts +++ b/src/client/predictions.ts @@ -5,6 +5,7 @@ import { ListParams, ImagePredictionParams, FilePredictionParams, + WebPredictionParams, } from "./types"; import { processImage } from "../utils/image"; @@ -138,6 +139,36 @@ export class FilePredictions extends Predictions { } } +export class WebPredictions extends Predictions { + async generate(params: WebPredictionParams): Promise { + const { url, model, domain, mode, metadata, callbackUrl, config } = params; + const [response] = await this.requestor.request( + "POST", + `/web/generate`, + undefined, + { + url, + model, + domain, + mode, + config: { + detail: config?.detail ?? "auto", + json_schema: config?.jsonSchema ?? null, + confidence: config?.confidence ?? false, + grounding: config?.grounding ?? false, + }, + metadata: { + environment: metadata?.environment ?? "dev", + session_id: metadata?.sessionId, + allow_training: metadata?.allowTraining ?? true, + }, + callback_url: callbackUrl, + } + ); + return response; + } +} + // Create specialized instances for different file types export const DocumentPredictions = (client: Client) => new FilePredictions(client, "document"); diff --git a/src/client/types.ts b/src/client/types.ts index dc325db..3bc6c38 100644 --- a/src/client/types.ts +++ b/src/client/types.ts @@ -63,7 +63,6 @@ export interface FeedbackSubmitParams { export interface PredictionGenerateParams { model?: string; domain: string; - batch?: boolean; config?: GenerationConfigParams; metadata?: RequestMetadataParams; callbackUrl?: string; @@ -157,14 +156,21 @@ export class GenerationConfig { export type GenerationConfigInput = GenerationConfig | GenerationConfigParams; export interface ImagePredictionParams extends PredictionGenerateParams { + batch?: boolean; images: string[]; } export interface FilePredictionParams extends PredictionGenerateParams { + batch?: boolean; fileId?: string; url?: string; } +export interface WebPredictionParams extends PredictionGenerateParams { + url: string; + mode: "fast" | "accurate"; +} + export class APIError extends Error { constructor( message: string, diff --git a/src/index.ts b/src/index.ts index b8343e9..95532d9 100644 --- a/src/index.ts +++ b/src/index.ts @@ -7,6 +7,7 @@ import { DocumentPredictions, AudioPredictions, VideoPredictions, + WebPredictions, } from "./client/predictions"; import { Feedback } from "./client/feedback"; @@ -33,6 +34,7 @@ export class VlmRun { readonly document: ReturnType; readonly audio: ReturnType; readonly video: ReturnType; + readonly web: WebPredictions; readonly feedback: Feedback; constructor(config: VlmRunConfig) { @@ -49,5 +51,6 @@ export class VlmRun { this.audio = AudioPredictions(this.client); this.video = VideoPredictions(this.client); this.feedback = new Feedback(this.client); + this.web = new WebPredictions(this.client); } } diff --git a/tests/unit/client/predictions.test.ts b/tests/unit/client/predictions.test.ts index 339fffc..4598915 100644 --- a/tests/unit/client/predictions.test.ts +++ b/tests/unit/client/predictions.test.ts @@ -4,6 +4,7 @@ import { DocumentPredictions, AudioPredictions, VideoPredictions, + WebPredictions, } from "../../../src/client/predictions"; import * as imageUtils from "../../../src/utils/image"; @@ -261,4 +262,104 @@ describe("Predictions", () => { }); }); }); + + describe("WebPredictions", () => { + let webPredictions: WebPredictions; + + beforeEach(() => { + webPredictions = new WebPredictions(client); + }); + + describe("generate", () => { + it("should generate web predictions with default options", async () => { + const mockResponse = { id: "pred_123", status: "completed" }; + jest + .spyOn(webPredictions["requestor"], "request") + .mockResolvedValue([mockResponse, 200, {}]); + + const result = await webPredictions.generate({ + url: "https://example.com", + model: "model1", + domain: "domain1", + mode: "accurate", + }); + + expect(result).toEqual(mockResponse); + expect(webPredictions["requestor"].request).toHaveBeenCalledWith( + "POST", + "/web/generate", + undefined, + { + url: "https://example.com", + model: "model1", + domain: "domain1", + mode: "accurate", + config: { + detail: "auto", + json_schema: null, + confidence: false, + grounding: false, + }, + metadata: { + environment: "dev", + session_id: undefined, + allow_training: true, + }, + callback_url: undefined, + } + ); + }); + + it("should generate web predictions with custom options", async () => { + const mockResponse = { id: "pred_123", status: "completed" }; + jest + .spyOn(webPredictions["requestor"], "request") + .mockResolvedValue([mockResponse, 200, {}]); + + const result = await webPredictions.generate({ + url: "https://example.com", + model: "model1", + domain: "domain1", + mode: "fast", + config: { + detail: "hi", + jsonSchema: { type: "object" }, + confidence: true, + grounding: true, + }, + metadata: { + environment: "prod", + sessionId: "session123", + allowTraining: false, + }, + callbackUrl: "https://callback.example.com", + }); + + expect(result).toEqual(mockResponse); + expect(webPredictions["requestor"].request).toHaveBeenCalledWith( + "POST", + "/web/generate", + undefined, + { + url: "https://example.com", + model: "model1", + domain: "domain1", + mode: "fast", + config: { + detail: "hi", + json_schema: { type: "object" }, + confidence: true, + grounding: true, + }, + metadata: { + environment: "prod", + session_id: "session123", + allow_training: false, + }, + callback_url: "https://callback.example.com", + } + ); + }); + }); + }); });