Skip to content

Commit

Permalink
feat: add finetune endpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
shahrear33 committed Feb 17, 2025
1 parent a27ad26 commit a5537a0
Show file tree
Hide file tree
Showing 7 changed files with 505 additions and 10 deletions.
5 changes: 3 additions & 2 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "vlmrun",
"version": "0.2.5",
"version": "0.2.6",
"description": "The official TypeScript library for the VlmRun API",
"author": "VlmRun <[email protected]>",
"main": "dist/index.js",
Expand Down Expand Up @@ -28,6 +28,7 @@
"dotenv": "^16.4.7",
"path": "^0.12.7",
"zod": "~3.24.2",
"mime-types": "^2.1.35",
"zod-to-json-schema": "~3.24.1"
},
"devDependencies": {
Expand Down
181 changes: 181 additions & 0 deletions src/client/fine_tuning.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
import { Client, APIRequestor } from "./base_requestor";
import { FinetuningResponse, FinetuningProvisionResponse, FinetuningGenerateParams, FinetuningListParams, PredictionResponse, FinetuningCreateParams, FinetuningProvisionParams } from "./types";
import { encodeImage, processImage } from "../utils";

export class Finetuning {
  private requestor: APIRequestor;

  constructor(client: Client) {
    // All fine-tuning routes are served under the /fine_tuning prefix.
    this.requestor = new APIRequestor({
      ...client,
      baseURL: `${client.baseURL}/fine_tuning`,
    });
  }

  /**
   * Create a fine-tuning job.
   *
   * @param params - Fine-tuning parameters
   * @param params.model - Base model to fine-tune
   * @param params.trainingFile - File ID for training data
   * @param params.validationFile - File ID for validation data (optional)
   * @param params.numEpochs - Number of training epochs (default: 1)
   * @param params.batchSize - Batch size for training (default: 1)
   * @param params.learningRate - Learning rate for training (default: 2e-4)
   * @param params.suffix - Suffix for the fine-tuned model name; must contain
   *   only alphanumerics, hyphens or underscores
   * @param params.callbackUrl - URL notified when the job completes (optional)
   * @param params.wandbApiKey - Weights & Biases API key (optional)
   * @param params.wandbBaseUrl - Weights & Biases base URL (optional)
   * @param params.wandbProjectName - Weights & Biases project name (optional)
   * @returns The created fine-tuning job.
   * @throws {Error} If `suffix` contains characters other than alphanumerics,
   *   hyphens or underscores.
   */
  async create(params: FinetuningCreateParams): Promise<FinetuningResponse> {
    if (params.suffix) {
      // The suffix becomes part of the fine-tuned model identifier, so
      // restrict it to characters that are safe in model names.
      if (!/^[a-zA-Z0-9_-]+$/.test(params.suffix)) {
        throw new Error(
          "Suffix must be alphanumeric, hyphens or underscores without spaces"
        );
      }
    }

    const [response] = await this.requestor.request<FinetuningResponse>(
      "POST",
      "create",
      undefined,
      {
        callback_url: params.callbackUrl,
        model: params.model,
        training_file: params.trainingFile,
        validation_file: params.validationFile,
        num_epochs: params.numEpochs ?? 1,
        batch_size: params.batchSize ?? 1,
        learning_rate: params.learningRate ?? 2e-4,
        suffix: params.suffix,
        wandb_api_key: params.wandbApiKey,
        wandb_base_url: params.wandbBaseUrl,
        wandb_project_name: params.wandbProjectName,
      }
    );

    return response;
  }

  /**
   * Provision a fine-tuned model for inference.
   *
   * @param params - Provisioning parameters
   * @param params.model - Model to provision
   * @param params.duration - How long the model stays provisioned, in seconds
   *   (default: 600, i.e. 10 minutes)
   * @param params.concurrency - Concurrency for the provisioned model (default: 1)
   * @returns The provisioning job details.
   */
  async provision(params: FinetuningProvisionParams): Promise<FinetuningProvisionResponse> {
    const [response] = await this.requestor.request<FinetuningProvisionResponse>(
      "POST",
      "provision",
      undefined,
      {
        model: params.model,
        duration: params.duration ?? 600, // 10 minutes default
        concurrency: params.concurrency ?? 1,
      }
    );

    return response;
  }

  /**
   * Generate a prediction using a fine-tuned model.
   *
   * Only a subset of the standard prediction options is supported here:
   * `prompt` and `jsonSchema` are required, while `domain`, a non-"auto"
   * `detail`, `batch` and `callbackUrl` are rejected up front.
   *
   * @param params - Generation parameters
   * @returns The prediction produced by the fine-tuned model.
   * @throws {Error} If a required option is missing or an unsupported option is set.
   */
  async generate(params: FinetuningGenerateParams): Promise<PredictionResponse> {
    if (!params.jsonSchema) {
      throw new Error("JSON schema is required for fine-tuned model predictions");
    }

    if (!params.prompt) {
      throw new Error("Prompt is required for fine-tuned model predictions");
    }

    if (params.domain) {
      throw new Error("Domain is not supported for fine-tuned model predictions");
    }

    if (params.detail && params.detail !== "auto") {
      throw new Error("Detail level is not supported for fine-tuned model predictions");
    }

    if (params.batch) {
      throw new Error("Batch mode is not supported for fine-tuned models");
    }

    if (params.callbackUrl) {
      throw new Error("Callback URL is not supported for fine-tuned model predictions");
    }

    const [response] = await this.requestor.request<PredictionResponse>(
      "POST",
      "generate",
      undefined,
      {
        images: params.images.map((image) => processImage(image)),
        model: params.model,
        config: {
          prompt: params.prompt,
          json_schema: params.jsonSchema,
          // The guard above rejects any non-"auto" detail, so this is
          // always "auto".
          detail: "auto",
          response_model: params.responseModel,
          confidence: params.confidence ?? false,
          grounding: params.grounding ?? false,
          max_retries: params.maxRetries ?? 3,
          max_tokens: params.maxTokens ?? 4096,
        },
        max_new_tokens: params.maxNewTokens ?? 1024,
        temperature: params.temperature ?? 0.0,
        metadata: {
          environment: params?.environment ?? "dev",
          session_id: params?.sessionId,
          allow_training: params?.allowTraining ?? true,
        },
        // The guard above rejects truthy `batch`, so this is always false.
        batch: false,
        callback_url: params.callbackUrl,
      }
    );

    return response;
  }

  /**
   * List fine-tuning jobs.
   *
   * @param params - Pagination parameters
   * @param params.skip - Number of jobs to skip (default: 0)
   * @param params.limit - Maximum number of jobs to return (default: 10)
   * @returns A page of fine-tuning jobs.
   */
  async list(params?: FinetuningListParams): Promise<FinetuningResponse[]> {
    const [response] = await this.requestor.request<FinetuningResponse[]>(
      "GET",
      "jobs",
      {
        skip: params?.skip ?? 0,
        limit: params?.limit ?? 10,
      }
    );

    return response;
  }

  /**
   * Get fine-tuning job details.
   *
   * @param jobId - ID of the job to retrieve
   * @returns The job details.
   */
  async get(jobId: string): Promise<FinetuningResponse> {
    const [response] = await this.requestor.request<FinetuningResponse>(
      "GET",
      `jobs/${jobId}`
    );

    return response;
  }

  /**
   * Cancel a fine-tuning job.
   *
   * @param jobId - ID of the job to cancel
   * @throws {Error} Always — cancellation is not implemented yet.
   */
  async cancel(jobId: string): Promise<Record<string, any>> {
    throw new Error("Not implemented");
  }
}
73 changes: 73 additions & 0 deletions src/client/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,79 @@ export interface WebPredictionParams extends PredictionGenerateParams {
mode: "fast" | "accurate";
}

/**
 * Server representation of a fine-tuning job, as returned by the
 * fine-tuning create/list/get endpoints.
 */
export interface FinetuningResponse {
  /** Unique identifier of the fine-tuning job. */
  id: string;
  /** Timestamp at which the job was created. */
  created_at: string;
  /** Timestamp at which the job completed, if it has finished. */
  completed_at?: string;
  /** Current status of the job. */
  status: JobStatus;
  /** Base model being fine-tuned. */
  model: string;
  /** File ID of the training dataset. */
  training_file_id: string;
  /** File ID of the validation dataset, if one was provided. */
  validation_file_id?: string;
  /** Number of training epochs. */
  num_epochs: number;
  /** Batch size used for training (a number, or a server-side keyword such as "auto"). */
  batch_size: number | string;
  /** Learning rate used for training. */
  learning_rate: number;
  /** Suffix applied to the fine-tuned model name, if one was provided. */
  suffix?: string;
  /** Weights & Biases run URL, if W&B tracking was configured. */
  wandb_url?: string;
  /** Optional human-readable status message from the server. */
  message?: string;
}

/**
 * Server representation of a model-provisioning request, as returned by the
 * fine-tuning provision endpoint.
 */
export interface FinetuningProvisionResponse {
  /** Unique identifier of the provisioning request. */
  id: string;
  /** Timestamp at which the request was created. */
  created_at: string;
  /** Model being provisioned. */
  model: string;
  /** Duration the model stays provisioned, in seconds. */
  duration: number;
  /** Concurrency level for the provisioned model. */
  concurrency: number;
  /** Current status of the provisioning request. */
  status: JobStatus;
  /** Optional human-readable status message from the server. */
  message?: string;
}

/**
 * Parameters accepted by {@link Finetuning.create}.
 */
export interface FinetuningCreateParams {
  /** URL notified when the job completes. */
  callbackUrl?: string;
  /** Base model to fine-tune. */
  model: string;
  /** File ID of the training dataset. */
  trainingFile: string;
  /** File ID of the validation dataset. */
  validationFile?: string;
  /** Number of training epochs (default: 1). */
  numEpochs?: number;
  /** Batch size for training (default: 1; a server-side keyword such as "auto" may be accepted). */
  batchSize?: number | string;
  /** Learning rate for training (default: 2e-4). */
  learningRate?: number;
  /** Suffix for the fine-tuned model name; alphanumerics, hyphens and underscores only. */
  suffix?: string;
  /** Weights & Biases API key. */
  wandbApiKey?: string;
  /** Weights & Biases base URL. */
  wandbBaseUrl?: string;
  /** Weights & Biases project name. */
  wandbProjectName?: string;
}

/**
 * Parameters accepted by {@link Finetuning.generate}.
 *
 * NOTE: several fields are declared for interface parity with the generic
 * prediction parameters but are rejected at runtime by
 * `Finetuning.generate`: `domain`, a non-"auto" `detail`, `batch` and
 * `callbackUrl` all cause an error to be thrown.
 */
export interface FinetuningGenerateParams {
  /** Images to run the prediction on. */
  images: string[];
  /** Fine-tuned model to use. */
  model: string;
  /** Prompt for the prediction; required by `Finetuning.generate`. */
  prompt?: string;
  /** Not supported for fine-tuned models; rejected if set. */
  domain?: string;
  /** JSON schema for the structured output; required by `Finetuning.generate`. */
  jsonSchema?: Record<string, any>;
  /** Maximum number of new tokens to generate (default: 1024). */
  maxNewTokens?: number;
  /** Sampling temperature (default: 0.0). */
  temperature?: number;
  /** Image detail level; only "auto" is accepted by `Finetuning.generate`. */
  detail?: "auto" | "hi" | "lo";
  /** Not supported for fine-tuned models; rejected if truthy. */
  batch?: boolean;
  /** Additional request metadata. */
  metadata?: Record<string, any>;
  /** Not supported for fine-tuned models; rejected if truthy. */
  callbackUrl?: string;
  /** Maximum number of retries (default: 3). */
  maxRetries?: number;
  /** Maximum number of tokens (default: 4096). */
  maxTokens?: number;
  /** Whether to return confidence scores (default: false). */
  confidence?: boolean;
  /** Whether to return grounding information (default: false). */
  grounding?: boolean;
  /** Deployment environment recorded in metadata (default: "dev"). */
  environment?: string;
  /** Session identifier recorded in metadata. */
  sessionId?: string;
  /** Whether the request may be used for training (default: true). */
  allowTraining?: boolean;
  /** Response model identifier forwarded in the request config. */
  responseModel?: string;
}

/**
 * Parameters accepted by {@link Finetuning.provision}.
 */
export interface FinetuningProvisionParams {
  /** Model to provision. */
  model: string;
  /** Duration the model stays provisioned, in seconds (default: 600). */
  duration?: number;
  /** Concurrency level for the provisioned model (default: 1). */
  concurrency?: number;
}

/**
 * Pagination parameters accepted by {@link Finetuning.list}.
 */
export interface FinetuningListParams {
  /** Number of jobs to skip (default: 0). */
  skip?: number;
  /** Maximum number of jobs to return (default: 10). */
  limit?: number;
}

export class APIError extends Error {
constructor(
message: string,
Expand Down
6 changes: 5 additions & 1 deletion src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@ import {
WebPredictions,
} from "./client/predictions";
import { Feedback } from "./client/feedback";
import { Finetuning } from "./client/fine_tuning";

export * from "./client/types";
export * from "./client/base_requestor";
export * from "./client/models";
export * from "./client/files";
export * from "./client/predictions";
export * from "./client/feedback";
export * from "./client/fine_tuning";

export * from "./utils";

Expand All @@ -36,6 +38,7 @@ export class VlmRun {
readonly video: ReturnType<typeof VideoPredictions>;
readonly web: WebPredictions;
readonly feedback: Feedback;
readonly finetuning: Finetuning;

constructor(config: VlmRunConfig) {
this.client = {
Expand All @@ -50,7 +53,8 @@ export class VlmRun {
this.document = DocumentPredictions(this.client);
this.audio = AudioPredictions(this.client);
this.video = VideoPredictions(this.client);
this.feedback = new Feedback(this.client);
this.web = new WebPredictions(this.client);
this.feedback = new Feedback(this.client);
this.finetuning = new Finetuning(this.client);
}
}
13 changes: 7 additions & 6 deletions src/utils/file.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,18 @@ export const readFileFromPathAsFile = async (filePath: string): Promise<File> =>
try {
if (typeof window === 'undefined') {
const fs = require("fs/promises");
const path = require('path');
const path = require("path");
const mime = require("mime-types");

const fileBuffer = await fs.readFile(filePath);
const fileName = path.basename(filePath);
return new File([fileBuffer], fileName, {
type: 'application/pdf',
});
const mimeType = mime.lookup(fileName) || "application/octet-stream";

return new File([fileBuffer], fileName, { type: mimeType });
} else {
throw new Error('File reading is not supported in the browser');
throw new Error("File reading is not supported in the browser");
}
} catch (error: any) {
throw new Error(`Error reading file at ${filePath}: ${error.message}`);
}
}
};
Loading

0 comments on commit a5537a0

Please sign in to comment.