
Commit ae5ea56

Merge pull request #57 from vlm-run/sh/add-finetuning
Add finetune endpoints
2 parents a27ad26 + a5537a0 commit ae5ea56

File tree

7 files changed: +505 −10 lines


package-lock.json

Lines changed: 3 additions & 2 deletions
Some generated files are not rendered by default.

package.json

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,6 @@
 {
   "name": "vlmrun",
-  "version": "0.2.5",
+  "version": "0.2.6",
   "description": "The official TypeScript library for the VlmRun API",
   "author": "VlmRun <[email protected]>",
   "main": "dist/index.js",
@@ -28,6 +28,7 @@
     "dotenv": "^16.4.7",
     "path": "^0.12.7",
     "zod": "~3.24.2",
+    "mime-types": "^2.1.35",
     "zod-to-json-schema": "~3.24.1"
   },
   "devDependencies": {

src/client/fine_tuning.ts

Lines changed: 181 additions & 0 deletions
@@ -0,0 +1,181 @@
+import { Client, APIRequestor } from "./base_requestor";
+import { FinetuningResponse, FinetuningProvisionResponse, FinetuningGenerateParams, FinetuningListParams, PredictionResponse, FinetuningCreateParams, FinetuningProvisionParams } from "./types";
+import { encodeImage, processImage } from "../utils";
+
+export class Finetuning {
+  private requestor: APIRequestor;
+
+  constructor(client: Client) {
+    this.requestor = new APIRequestor({
+      ...client,
+      baseURL: `${client.baseURL}/fine_tuning`,
+    });
+  }
+
+  /**
+   * Create a fine-tuning job
+   * @param {Object} params - Fine-tuning parameters
+   * @param {string} params.model - Base model to fine-tune
+   * @param {string} params.training_file_id - File ID for training data
+   * @param {string} [params.validation_file_id] - File ID for validation data
+   * @param {number} [params.num_epochs=1] - Number of epochs
+   * @param {number|string} [params.batch_size="auto"] - Batch size for training
+   * @param {number} [params.learning_rate=2e-4] - Learning rate for training
+   * @param {string} [params.suffix] - Suffix for the fine-tuned model
+   * @param {string} [params.wandb_api_key] - Weights & Biases API key
+   * @param {string} [params.wandb_base_url] - Weights & Biases base URL
+   * @param {string} [params.wandb_project_name] - Weights & Biases project name
+   */
+  async create(params: FinetuningCreateParams): Promise<FinetuningResponse> {
+    if (params.suffix) {
+      // Ensure suffix contains only alphanumeric, hyphens or underscores
+      if (!/^[a-zA-Z0-9_-]+$/.test(params.suffix)) {
+        throw new Error(
+          "Suffix must be alphanumeric, hyphens or underscores without spaces"
+        );
+      }
+    }
+
+    const [response] = await this.requestor.request<FinetuningResponse>(
+      "POST",
+      "create",
+      undefined,
+      {
+        callback_url: params.callbackUrl,
+        model: params.model,
+        training_file: params.trainingFile,
+        validation_file: params.validationFile,
+        num_epochs: params.numEpochs ?? 1,
+        batch_size: params.batchSize ?? 1,
+        learning_rate: params.learningRate ?? 2e-4,
+        suffix: params.suffix,
+        wandb_api_key: params.wandbApiKey,
+        wandb_base_url: params.wandbBaseUrl,
+        wandb_project_name: params.wandbProjectName,
+      }
+    );
+
+    return response;
+  }
+
+  /**
+   * Provision a fine-tuning model
+   * @param {Object} params - Provisioning parameters
+   * @param {string} params.model - Model to provision
+   * @param {number} [params.duration=600] - Duration for the provisioned model (in seconds)
+   * @param {number} [params.concurrency=1] - Concurrency for the provisioned model
+   */
+  async provision(params: FinetuningProvisionParams): Promise<FinetuningProvisionResponse> {
+    const [response] = await this.requestor.request<FinetuningProvisionResponse>(
+      "POST",
+      "provision",
+      undefined,
+      {
+        model: params.model,
+        duration: params.duration ?? 600, // 10 minutes default
+        concurrency: params.concurrency ?? 1,
+      }
+    );
+
+    return response;
+  }
+
+  /**
+   * Generate a prediction using a fine-tuned model
+   * @param {FinetuningGenerateParams} params - Generation parameters
+   */
+  async generate(params: FinetuningGenerateParams): Promise<PredictionResponse> {
+    if (!params.jsonSchema) {
+      throw new Error("JSON schema is required for fine-tuned model predictions");
+    }
+
+    if (!params.prompt) {
+      throw new Error("Prompt is required for fine-tuned model predictions");
+    }
+
+    if (params.domain) {
+      throw new Error("Domain is not supported for fine-tuned model predictions");
+    }
+
+    if (params.detail && params.detail !== "auto") {
+      throw new Error("Detail level is not supported for fine-tuned model predictions");
+    }
+
+    if (params.batch) {
+      throw new Error("Batch mode is not supported for fine-tuned models");
+    }
+
+    if (params.callbackUrl) {
+      throw new Error("Callback URL is not supported for fine-tuned model predictions");
+    }
+
+    const [response] = await this.requestor.request<PredictionResponse>(
+      "POST",
+      "generate",
+      undefined,
+      {
+        images: params.images.map((image) => processImage(image)),
+        model: params.model,
+        config: {
+          prompt: params.prompt,
+          json_schema: params.jsonSchema,
+          detail: params.detail ?? "auto",
+          response_model: params.responseModel,
+          confidence: params.confidence ?? false,
+          grounding: params.grounding ?? false,
+          max_retries: params.maxRetries ?? 3,
+          max_tokens: params.maxTokens ?? 4096,
+        },
+        max_new_tokens: params.maxNewTokens ?? 1024,
+        temperature: params.temperature ?? 0.0,
+        metadata: {
+          environment: params?.environment ?? "dev",
+          session_id: params?.sessionId,
+          allow_training: params?.allowTraining ?? true,
+        },
+        batch: params.batch ?? false,
+        callback_url: params.callbackUrl,
+      }
+    );
+
+    return response;
+  }
+
+  /**
+   * List all fine-tuning jobs
+   * @param {FinetuningListParams} params - List parameters
+   */
+  async list(params?: FinetuningListParams): Promise<FinetuningResponse[]> {
+    const [response] = await this.requestor.request<FinetuningResponse[]>(
+      "GET",
+      "jobs",
+      {
+        skip: params?.skip ?? 0,
+        limit: params?.limit ?? 10,
+      }
+    );
+
+    return response;
+  }
+
+  /**
+   * Get fine-tuning job details
+   * @param {string} jobId - ID of job to retrieve
+   */
+  async get(jobId: string): Promise<FinetuningResponse> {
+    const [response] = await this.requestor.request<FinetuningResponse>(
+      "GET",
+      `jobs/${jobId}`
+    );
+
+    return response;
+  }
+
+  /**
+   * Cancel a fine-tuning job
+   * @param {string} jobId - ID of job to cancel
+   */
+  async cancel(jobId: string): Promise<Record<string, any>> {
+    throw new Error("Not implemented");
+  }
+}
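
For orientation, a minimal usage sketch of the new endpoints (create, provision, generate via the class above, plus jobs listing/retrieval). This is not part of the diff; the model name, file ID and suffix are hypothetical placeholders, and the snippet assumes a Finetuning instance constructed with a valid Client.

// Editor's sketch, not part of the commit. Assumes `finetuning` is an instance of
// the Finetuning class above; model names and file IDs are hypothetical placeholders.
import { Finetuning } from "vlmrun"; // re-exported from the package root by this commit

async function runFinetuningFlow(finetuning: Finetuning) {
  // Kick off a job; numEpochs / batchSize / learningRate fall back to 1 / 1 / 2e-4.
  const job = await finetuning.create({
    model: "base-model",         // placeholder base model name
    trainingFile: "file_abc123", // placeholder training-file ID
    suffix: "receipts-v1",       // must match /^[a-zA-Z0-9_-]+$/
  });

  // Fetch the job later to check its status (and the W&B run, if configured).
  const refreshed = await finetuning.get(job.id);
  console.log(refreshed.status, refreshed.wandb_url);

  // Provision a model for inference; duration defaults to 600 s, concurrency to 1.
  await finetuning.provision({ model: "base-model:receipts-v1" }); // placeholder model name

  // Page through existing jobs (skip/limit default to 0/10).
  return finetuning.list({ skip: 0, limit: 10 });
}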

src/client/types.ts

Lines changed: 73 additions & 0 deletions
@@ -174,6 +174,79 @@ export interface WebPredictionParams extends PredictionGenerateParams {
   mode: "fast" | "accurate";
 }

+export interface FinetuningResponse {
+  id: string;
+  created_at: string;
+  completed_at?: string;
+  status: JobStatus;
+  model: string;
+  training_file_id: string;
+  validation_file_id?: string;
+  num_epochs: number;
+  batch_size: number | string;
+  learning_rate: number;
+  suffix?: string;
+  wandb_url?: string;
+  message?: string;
+}
+
+export interface FinetuningProvisionResponse {
+  id: string;
+  created_at: string;
+  model: string;
+  duration: number;
+  concurrency: number;
+  status: JobStatus;
+  message?: string;
+}
+
+export interface FinetuningCreateParams {
+  callbackUrl?: string;
+  model: string;
+  trainingFile: string;
+  validationFile?: string;
+  numEpochs?: number;
+  batchSize?: number | string;
+  learningRate?: number;
+  suffix?: string;
+  wandbApiKey?: string;
+  wandbBaseUrl?: string;
+  wandbProjectName?: string;
+}
+
+export interface FinetuningGenerateParams {
+  images: string[];
+  model: string;
+  prompt?: string;
+  domain?: string;
+  jsonSchema?: Record<string, any>;
+  maxNewTokens?: number;
+  temperature?: number;
+  detail?: "auto" | "hi" | "lo";
+  batch?: boolean;
+  metadata?: Record<string, any>;
+  callbackUrl?: string;
+  maxRetries?: number;
+  maxTokens?: number;
+  confidence?: boolean;
+  grounding?: boolean;
+  environment?: string;
+  sessionId?: string;
+  allowTraining?: boolean;
+  responseModel?: string;
+}
+
+export interface FinetuningProvisionParams {
+  model: string;
+  duration?: number;
+  concurrency?: number;
+}
+
+export interface FinetuningListParams {
+  skip?: number;
+  limit?: number;
+}
+
 export class APIError extends Error {
   constructor(
     message: string,
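
These params interfaces are camelCase on the TypeScript side; the Finetuning client above translates them to snake_case request fields (e.g. trainingFile → training_file). For generate, jsonSchema and prompt are required, while domain, batch, callbackUrl and any detail other than "auto" are rejected. A sketch of a value that passes that validation (not part of the diff; schema, model and image path are placeholders):

// Editor's sketch, not part of the commit. Schema, model name and image path
// are hypothetical placeholders; the type is re-exported from the package root.
import { FinetuningGenerateParams } from "vlmrun";

const params: FinetuningGenerateParams = {
  images: ["/path/to/receipt.jpg"],  // run through processImage() by the client
  model: "base-model:receipts-v1",   // placeholder fine-tuned model name
  prompt: "Extract the merchant, date and total.",
  jsonSchema: {
    type: "object",
    properties: {
      merchant: { type: "string" },
      date: { type: "string" },
      total: { type: "number" },
    },
  },
  detail: "auto",      // anything other than "auto" is rejected for fine-tuned models
  temperature: 0.0,
  maxNewTokens: 1024,
};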

src/index.ts

Lines changed: 5 additions & 1 deletion
@@ -10,13 +10,15 @@ import {
   WebPredictions,
 } from "./client/predictions";
 import { Feedback } from "./client/feedback";
+import { Finetuning } from "./client/fine_tuning";

 export * from "./client/types";
 export * from "./client/base_requestor";
 export * from "./client/models";
 export * from "./client/files";
 export * from "./client/predictions";
 export * from "./client/feedback";
+export * from "./client/fine_tuning";

 export * from "./utils";

@@ -36,6 +38,7 @@ export class VlmRun {
   readonly video: ReturnType<typeof VideoPredictions>;
   readonly web: WebPredictions;
   readonly feedback: Feedback;
+  readonly finetuning: Finetuning;

   constructor(config: VlmRunConfig) {
     this.client = {
@@ -50,7 +53,8 @@
     this.document = DocumentPredictions(this.client);
     this.audio = AudioPredictions(this.client);
     this.video = VideoPredictions(this.client);
-    this.feedback = new Feedback(this.client);
     this.web = new WebPredictions(this.client);
+    this.feedback = new Feedback(this.client);
+    this.finetuning = new Finetuning(this.client);
   }
 }
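
With this wiring, the new surface is reachable from the top-level client. A minimal sketch (not part of the diff), assuming VlmRunConfig accepts an apiKey field — the exact config fields are not visible in this diff:

// Editor's sketch, not part of the commit. Assumes VlmRunConfig takes an apiKey
// (field name not shown in this diff).
import { VlmRun } from "vlmrun";

const vlm = new VlmRun({ apiKey: process.env.VLMRUN_API_KEY! });

// The constructor now exposes `finetuning` alongside feedback, web, etc.
vlm.finetuning.list().then((jobs) => console.log(jobs.length));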

src/utils/file.ts

Lines changed: 7 additions & 6 deletions
@@ -2,17 +2,18 @@ export const readFileFromPathAsFile = async (filePath: string): Promise<File> =>
   try {
     if (typeof window === 'undefined') {
       const fs = require("fs/promises");
-      const path = require('path');
+      const path = require("path");
+      const mime = require("mime-types");

       const fileBuffer = await fs.readFile(filePath);
       const fileName = path.basename(filePath);
-      return new File([fileBuffer], fileName, {
-        type: 'application/pdf',
-      });
+      const mimeType = mime.lookup(fileName) || "application/octet-stream";
+
+      return new File([fileBuffer], fileName, { type: mimeType });
     } else {
-      throw new Error('File reading is not supported in the browser');
+      throw new Error("File reading is not supported in the browser");
     }
   } catch (error: any) {
     throw new Error(`Error reading file at ${filePath}: ${error.message}`);
   }
-}
+};
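
This is what the new mime-types dependency in package.json is for: the resulting File's type now follows the file extension instead of always being application/pdf. A small before/after sketch (not part of the diff; file paths are placeholders, and the import path assumes the utils barrel re-exports this helper):

// Editor's sketch, not part of the commit. File paths are placeholders; the
// import assumes readFileFromPathAsFile is re-exported via the utils barrel.
import { readFileFromPathAsFile } from "vlmrun";

async function demo() {
  const pdf = await readFileFromPathAsFile("./statement.pdf");
  console.log(pdf.type); // "application/pdf" — same result as before this change

  const png = await readFileFromPathAsFile("./receipt.png");
  console.log(png.type); // "image/png" via mime.lookup(); unknown extensions fall back to "application/octet-stream"
}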
