Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add finetune endpoints #57

Merged
merged 1 commit into from
Feb 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "vlmrun",
"version": "0.2.5",
"version": "0.2.6",
"description": "The official TypeScript library for the VlmRun API",
"author": "VlmRun <[email protected]>",
"main": "dist/index.js",
Expand Down Expand Up @@ -28,6 +28,7 @@
"dotenv": "^16.4.7",
"path": "^0.12.7",
"zod": "~3.24.2",
"mime-types": "^2.1.35",
"zod-to-json-schema": "~3.24.1"
},
"devDependencies": {
Expand Down
181 changes: 181 additions & 0 deletions src/client/fine_tuning.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
import { Client, APIRequestor } from "./base_requestor";
import { FinetuningResponse, FinetuningProvisionResponse, FinetuningGenerateParams, FinetuningListParams, PredictionResponse, FinetuningCreateParams, FinetuningProvisionParams } from "./types";
import { encodeImage, processImage } from "../utils";

export class Finetuning {
private requestor: APIRequestor;

constructor(client: Client) {
this.requestor = new APIRequestor({
...client,
baseURL: `${client.baseURL}/fine_tuning`,
});
}

/**
* Create a fine-tuning job
* @param {Object} params - Fine-tuning parameters
* @param {string} params.model - Base model to fine-tune
* @param {string} params.training_file_id - File ID for training data
* @param {string} [params.validation_file_id] - File ID for validation data
* @param {number} [params.num_epochs=1] - Number of epochs
* @param {number|string} [params.batch_size="auto"] - Batch size for training
* @param {number} [params.learning_rate=2e-4] - Learning rate for training
* @param {string} [params.suffix] - Suffix for the fine-tuned model
* @param {string} [params.wandb_api_key] - Weights & Biases API key
* @param {string} [params.wandb_base_url] - Weights & Biases base URL
* @param {string} [params.wandb_project_name] - Weights & Biases project name
*/
async create(params: FinetuningCreateParams): Promise<FinetuningResponse> {
if (params.suffix) {
// Ensure suffix contains only alphanumeric, hyphens or underscores
if (!/^[a-zA-Z0-9_-]+$/.test(params.suffix)) {
throw new Error(
"Suffix must be alphanumeric, hyphens or underscores without spaces"
);
}
}

const [response] = await this.requestor.request<FinetuningResponse>(
"POST",
"create",
undefined,
{
callback_url: params.callbackUrl,
model: params.model,
training_file: params.trainingFile,
validation_file: params.validationFile,
num_epochs: params.numEpochs ?? 1,
batch_size: params.batchSize ?? 1,
learning_rate: params.learningRate ?? 2e-4,
suffix: params.suffix,
wandb_api_key: params.wandbApiKey,
wandb_base_url: params.wandbBaseUrl,
wandb_project_name: params.wandbProjectName,
}
);

return response;
}

/**
* Provision a fine-tuning model
* @param {Object} params - Provisioning parameters
* @param {string} params.model - Model to provision
* @param {number} [params.duration=600] - Duration for the provisioned model (in seconds)
* @param {number} [params.concurrency=1] - Concurrency for the provisioned model
*/
async provision(params: FinetuningProvisionParams): Promise<FinetuningProvisionResponse> {
const [response] = await this.requestor.request<FinetuningProvisionResponse>(
"POST",
"provision",
undefined,
{
model: params.model,
duration: params.duration ?? 600, // 10 minutes default
concurrency: params.concurrency ?? 1,
}
);

return response;
}

/**
* Generate a prediction using a fine-tuned model
* @param {FinetuningGenerateParams} params - Generation parameters
*/
async generate(params: FinetuningGenerateParams): Promise<PredictionResponse> {
if (!params.jsonSchema) {
throw new Error("JSON schema is required for fine-tuned model predictions");
}

if (!params.prompt) {
throw new Error("Prompt is required for fine-tuned model predictions");
}

if (params.domain) {
throw new Error("Domain is not supported for fine-tuned model predictions");
}

if (params.detail && params.detail !== "auto") {
throw new Error("Detail level is not supported for fine-tuned model predictions");
}

if (params.batch) {
throw new Error("Batch mode is not supported for fine-tuned models");
}

if (params.callbackUrl) {
throw new Error("Callback URL is not supported for fine-tuned model predictions");
}

const [response] = await this.requestor.request<PredictionResponse>(
"POST",
"generate",
undefined,
{
images: params.images.map((image) => processImage(image)),
model: params.model,
config: {
prompt: params.prompt,
json_schema: params.jsonSchema,
detail: params.detail ?? "auto",
response_model: params.responseModel,
confidence: params.confidence ?? false,
grounding: params.grounding ?? false,
max_retries: params.maxRetries ?? 3,
max_tokens: params.maxTokens ?? 4096,
},
max_new_tokens: params.maxNewTokens ?? 1024,
temperature: params.temperature ?? 0.0,
metadata: {
environment: params?.environment ?? "dev",
session_id: params?.sessionId,
allow_training: params?.allowTraining ?? true,
},
batch: params.batch ?? false,
callback_url: params.callbackUrl,
}
);

return response;
}

/**
* List all fine-tuning jobs
* @param {FinetuningListParams} params - List parameters
*/
async list(params?: FinetuningListParams): Promise<FinetuningResponse[]> {
const [response] = await this.requestor.request<FinetuningResponse[]>(
"GET",
"jobs",
{
skip: params?.skip ?? 0,
limit: params?.limit ?? 10,
}
);

return response;
}

/**
* Get fine-tuning job details
* @param {string} jobId - ID of job to retrieve
*/
async get(jobId: string): Promise<FinetuningResponse> {
const [response] = await this.requestor.request<FinetuningResponse>(
"GET",
`jobs/${jobId}`
);

return response;
}

/**
* Cancel a fine-tuning job
* @param {string} jobId - ID of job to cancel
*/
async cancel(jobId: string): Promise<Record<string, any>> {
throw new Error("Not implemented");
}
}
73 changes: 73 additions & 0 deletions src/client/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,79 @@ export interface WebPredictionParams extends PredictionGenerateParams {
mode: "fast" | "accurate";
}

export interface FinetuningResponse {
id: string;
created_at: string;
completed_at?: string;
status: JobStatus;
model: string;
training_file_id: string;
validation_file_id?: string;
num_epochs: number;
batch_size: number | string;
learning_rate: number;
suffix?: string;
wandb_url?: string;
message?: string;
}

export interface FinetuningProvisionResponse {
id: string;
created_at: string;
model: string;
duration: number;
concurrency: number;
status: JobStatus;
message?: string;
}

export interface FinetuningCreateParams {
callbackUrl?: string;
model: string;
trainingFile: string;
validationFile?: string;
numEpochs?: number;
batchSize?: number | string;
learningRate?: number;
suffix?: string;
wandbApiKey?: string;
wandbBaseUrl?: string;
wandbProjectName?: string;
}

export interface FinetuningGenerateParams {
images: string[];
model: string;
prompt?: string;
domain?: string;
jsonSchema?: Record<string, any>;
maxNewTokens?: number;
temperature?: number;
detail?: "auto" | "hi" | "lo";
batch?: boolean;
metadata?: Record<string, any>;
callbackUrl?: string;
maxRetries?: number;
maxTokens?: number;
confidence?: boolean;
grounding?: boolean;
environment?: string;
sessionId?: string;
allowTraining?: boolean;
responseModel?: string;
}

export interface FinetuningProvisionParams {
model: string;
duration?: number;
concurrency?: number;
}

export interface FinetuningListParams {
skip?: number;
limit?: number;
}

export class APIError extends Error {
constructor(
message: string,
Expand Down
6 changes: 5 additions & 1 deletion src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@ import {
WebPredictions,
} from "./client/predictions";
import { Feedback } from "./client/feedback";
import { Finetuning } from "./client/fine_tuning";

export * from "./client/types";
export * from "./client/base_requestor";
export * from "./client/models";
export * from "./client/files";
export * from "./client/predictions";
export * from "./client/feedback";
export * from "./client/fine_tuning";

export * from "./utils";

Expand All @@ -36,6 +38,7 @@ export class VlmRun {
readonly video: ReturnType<typeof VideoPredictions>;
readonly web: WebPredictions;
readonly feedback: Feedback;
readonly finetuning: Finetuning;

constructor(config: VlmRunConfig) {
this.client = {
Expand All @@ -50,7 +53,8 @@ export class VlmRun {
this.document = DocumentPredictions(this.client);
this.audio = AudioPredictions(this.client);
this.video = VideoPredictions(this.client);
this.feedback = new Feedback(this.client);
this.web = new WebPredictions(this.client);
this.feedback = new Feedback(this.client);
this.finetuning = new Finetuning(this.client);
}
}
13 changes: 7 additions & 6 deletions src/utils/file.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,18 @@ export const readFileFromPathAsFile = async (filePath: string): Promise<File> =>
try {
if (typeof window === 'undefined') {
const fs = require("fs/promises");
const path = require('path');
const path = require("path");
const mime = require("mime-types");

const fileBuffer = await fs.readFile(filePath);
const fileName = path.basename(filePath);
return new File([fileBuffer], fileName, {
type: 'application/pdf',
});
const mimeType = mime.lookup(fileName) || "application/octet-stream";

return new File([fileBuffer], fileName, { type: mimeType });
} else {
throw new Error('File reading is not supported in the browser');
throw new Error("File reading is not supported in the browser");
}
} catch (error: any) {
throw new Error(`Error reading file at ${filePath}: ${error.message}`);
}
}
};
Loading