Skip to content

Commit

Permalink
Merge pull request #49 from vlm-run/sh/add-web-endpoint
Browse files Browse the repository at this point in the history
Added web predictions
  • Loading branch information
shahrear33 authored Feb 10, 2025
2 parents 15c8baa + 0442ebd commit 8cd873e
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 2 deletions.
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "vlmrun",
"version": "0.1.12",
"version": "0.1.13",
"description": "The official TypeScript library for the VlmRun API",
"author": "VlmRun <[email protected]>",
"main": "dist/index.js",
Expand Down
31 changes: 31 additions & 0 deletions src/client/predictions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import {
ListParams,
ImagePredictionParams,
FilePredictionParams,
WebPredictionParams,
} from "./types";
import { processImage } from "../utils/image";

Expand Down Expand Up @@ -138,6 +139,36 @@ export class FilePredictions extends Predictions {
}
}

export class WebPredictions extends Predictions {
async generate(params: WebPredictionParams): Promise<PredictionResponse> {
const { url, model, domain, mode, metadata, callbackUrl, config } = params;
const [response] = await this.requestor.request<PredictionResponse>(
"POST",
`/web/generate`,
undefined,
{
url,
model,
domain,
mode,
config: {
detail: config?.detail ?? "auto",
json_schema: config?.jsonSchema ?? null,
confidence: config?.confidence ?? false,
grounding: config?.grounding ?? false,
},
metadata: {
environment: metadata?.environment ?? "dev",
session_id: metadata?.sessionId,
allow_training: metadata?.allowTraining ?? true,
},
callback_url: callbackUrl,
}
);
return response;
}
}

// Create specialized instances for different file types
export const DocumentPredictions = (client: Client) =>
new FilePredictions(client, "document");
Expand Down
8 changes: 7 additions & 1 deletion src/client/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ export interface FeedbackSubmitParams {
export interface PredictionGenerateParams {
model?: string;
domain: string;
batch?: boolean;
config?: GenerationConfigParams;
metadata?: RequestMetadataParams;
callbackUrl?: string;
Expand Down Expand Up @@ -157,14 +156,21 @@ export class GenerationConfig {
export type GenerationConfigInput = GenerationConfig | GenerationConfigParams;

export interface ImagePredictionParams extends PredictionGenerateParams {
batch?: boolean;
images: string[];
}

export interface FilePredictionParams extends PredictionGenerateParams {
batch?: boolean;
fileId?: string;
url?: string;
}

export interface WebPredictionParams extends PredictionGenerateParams {
url: string;
mode: "fast" | "accurate";
}

export class APIError extends Error {
constructor(
message: string,
Expand Down
3 changes: 3 additions & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import {
DocumentPredictions,
AudioPredictions,
VideoPredictions,
WebPredictions,
} from "./client/predictions";
import { Feedback } from "./client/feedback";

Expand All @@ -33,6 +34,7 @@ export class VlmRun {
readonly document: ReturnType<typeof DocumentPredictions>;
readonly audio: ReturnType<typeof AudioPredictions>;
readonly video: ReturnType<typeof VideoPredictions>;
readonly web: WebPredictions;
readonly feedback: Feedback;

constructor(config: VlmRunConfig) {
Expand All @@ -49,5 +51,6 @@ export class VlmRun {
this.audio = AudioPredictions(this.client);
this.video = VideoPredictions(this.client);
this.feedback = new Feedback(this.client);
this.web = new WebPredictions(this.client);
}
}
101 changes: 101 additions & 0 deletions tests/unit/client/predictions.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import {
DocumentPredictions,
AudioPredictions,
VideoPredictions,
WebPredictions,
} from "../../../src/client/predictions";
import * as imageUtils from "../../../src/utils/image";

Expand Down Expand Up @@ -261,4 +262,104 @@ describe("Predictions", () => {
});
});
});

describe("WebPredictions", () => {
let webPredictions: WebPredictions;

beforeEach(() => {
webPredictions = new WebPredictions(client);
});

describe("generate", () => {
it("should generate web predictions with default options", async () => {
const mockResponse = { id: "pred_123", status: "completed" };
jest
.spyOn(webPredictions["requestor"], "request")
.mockResolvedValue([mockResponse, 200, {}]);

const result = await webPredictions.generate({
url: "https://example.com",
model: "model1",
domain: "domain1",
mode: "accurate",
});

expect(result).toEqual(mockResponse);
expect(webPredictions["requestor"].request).toHaveBeenCalledWith(
"POST",
"/web/generate",
undefined,
{
url: "https://example.com",
model: "model1",
domain: "domain1",
mode: "accurate",
config: {
detail: "auto",
json_schema: null,
confidence: false,
grounding: false,
},
metadata: {
environment: "dev",
session_id: undefined,
allow_training: true,
},
callback_url: undefined,
}
);
});

it("should generate web predictions with custom options", async () => {
const mockResponse = { id: "pred_123", status: "completed" };
jest
.spyOn(webPredictions["requestor"], "request")
.mockResolvedValue([mockResponse, 200, {}]);

const result = await webPredictions.generate({
url: "https://example.com",
model: "model1",
domain: "domain1",
mode: "fast",
config: {
detail: "hi",
jsonSchema: { type: "object" },
confidence: true,
grounding: true,
},
metadata: {
environment: "prod",
sessionId: "session123",
allowTraining: false,
},
callbackUrl: "https://callback.example.com",
});

expect(result).toEqual(mockResponse);
expect(webPredictions["requestor"].request).toHaveBeenCalledWith(
"POST",
"/web/generate",
undefined,
{
url: "https://example.com",
model: "model1",
domain: "domain1",
mode: "fast",
config: {
detail: "hi",
json_schema: { type: "object" },
confidence: true,
grounding: true,
},
metadata: {
environment: "prod",
session_id: "session123",
allow_training: false,
},
callback_url: "https://callback.example.com",
}
);
});
});
});
});

0 comments on commit 8cd873e

Please sign in to comment.