|
| 1 | +import { Client, APIRequestor } from "./base_requestor"; |
| 2 | +import { FinetuningResponse, FinetuningProvisionResponse, FinetuningGenerateParams, FinetuningListParams, PredictionResponse, FinetuningCreateParams, FinetuningProvisionParams } from "./types"; |
| 3 | +import { encodeImage, processImage } from "../utils"; |
| 4 | + |
/**
 * Client for the fine-tuning endpoints.
 *
 * Every request is issued through an `APIRequestor` whose base URL is the
 * client's `baseURL` with `/fine_tuning` appended.
 */
export class Finetuning {
  private readonly requestor: APIRequestor;

  /**
   * @param client - Client whose settings are spread into the requestor
   *   configuration; only `baseURL` is overridden.
   */
  constructor(client: Client) {
    this.requestor = new APIRequestor({
      ...client,
      baseURL: `${client.baseURL}/fine_tuning`,
    });
  }

  /**
   * Create a fine-tuning job (POST `create`).
   *
   * @param params - Fine-tuning parameters
   * @param params.model - Sent as `model`
   * @param params.trainingFile - Sent as `training_file`
   * @param params.validationFile - Sent as `validation_file`
   * @param params.numEpochs - Sent as `num_epochs`; defaults to 1
   * @param params.batchSize - Sent as `batch_size`; defaults to 1
   * @param params.learningRate - Sent as `learning_rate`; defaults to 2e-4
   * @param params.suffix - Sent as `suffix`; when non-empty it may contain
   *   only letters, digits, hyphens and underscores
   * @param params.callbackUrl - Sent as `callback_url`
   * @param params.wandbApiKey - Sent as `wandb_api_key`
   * @param params.wandbBaseUrl - Sent as `wandb_base_url`
   * @param params.wandbProjectName - Sent as `wandb_project_name`
   * @returns The first element of the requestor's result tuple
   * @throws Error if `params.suffix` is non-empty and contains any other character
   */
  async create(params: FinetuningCreateParams): Promise<FinetuningResponse> {
    // An empty or absent suffix is passed through unvalidated.
    if (params.suffix && !/^[a-zA-Z0-9_-]+$/.test(params.suffix)) {
      throw new Error(
        "Suffix must be alphanumeric, hyphens or underscores without spaces"
      );
    }

    const [response] = await this.requestor.request<FinetuningResponse>(
      "POST",
      "create",
      undefined,
      {
        callback_url: params.callbackUrl,
        model: params.model,
        training_file: params.trainingFile,
        validation_file: params.validationFile,
        num_epochs: params.numEpochs ?? 1,
        batch_size: params.batchSize ?? 1,
        learning_rate: params.learningRate ?? 2e-4,
        suffix: params.suffix,
        wandb_api_key: params.wandbApiKey,
        wandb_base_url: params.wandbBaseUrl,
        wandb_project_name: params.wandbProjectName,
      }
    );

    return response;
  }

  /**
   * Provision a fine-tuned model (POST `provision`).
   *
   * @param params - Provisioning parameters
   * @param params.model - Sent as `model`
   * @param params.duration - Sent as `duration`; defaults to 600
   * @param params.concurrency - Sent as `concurrency`; defaults to 1
   * @returns The first element of the requestor's result tuple
   */
  async provision(params: FinetuningProvisionParams): Promise<FinetuningProvisionResponse> {
    const [response] = await this.requestor.request<FinetuningProvisionResponse>(
      "POST",
      "provision",
      undefined,
      {
        model: params.model,
        duration: params.duration ?? 600, // 10 minutes default
        concurrency: params.concurrency ?? 1,
      }
    );

    return response;
  }

  /**
   * Generate a prediction using a fine-tuned model (POST `generate`).
   *
   * `jsonSchema`, `prompt` and `images` are required. `domain`, `batch`,
   * `callbackUrl` and any `detail` other than "auto" are rejected.
   *
   * Defaults applied when a field is null or undefined: `detail` "auto",
   * `confidence` false, `grounding` false, `maxRetries` 3, `maxTokens` 4096,
   * `maxNewTokens` 1024, `temperature` 0, `environment` "dev",
   * `allowTraining` true.
   *
   * @param params - Generation parameters
   * @returns The first element of the requestor's result tuple
   * @throws Error if a required field is missing or an unsupported option is set
   */
  async generate(params: FinetuningGenerateParams): Promise<PredictionResponse> {
    if (!params.jsonSchema) {
      throw new Error("JSON schema is required for fine-tuned model predictions");
    }

    if (!params.prompt) {
      throw new Error("Prompt is required for fine-tuned model predictions");
    }

    if (params.domain) {
      throw new Error("Domain is not supported for fine-tuned model predictions");
    }

    if (params.detail && params.detail !== "auto") {
      throw new Error("Detail level is not supported for fine-tuned model predictions");
    }

    if (params.batch) {
      throw new Error("Batch mode is not supported for fine-tuned models");
    }

    if (params.callbackUrl) {
      throw new Error("Callback URL is not supported for fine-tuned model predictions");
    }

    // Checked last so the existing validation errors keep their precedence;
    // without this guard a missing `images` surfaces as an opaque TypeError.
    if (!Array.isArray(params.images)) {
      throw new Error("Images are required for fine-tuned model predictions");
    }

    const [response] = await this.requestor.request<PredictionResponse>(
      "POST",
      "generate",
      undefined,
      {
        images: params.images.map((image) => processImage(image)),
        model: params.model,
        config: {
          prompt: params.prompt,
          json_schema: params.jsonSchema,
          detail: params.detail ?? "auto",
          response_model: params.responseModel,
          confidence: params.confidence ?? false,
          grounding: params.grounding ?? false,
          max_retries: params.maxRetries ?? 3,
          max_tokens: params.maxTokens ?? 4096,
        },
        max_new_tokens: params.maxNewTokens ?? 1024,
        temperature: params.temperature ?? 0.0,
        metadata: {
          environment: params.environment ?? "dev",
          session_id: params.sessionId,
          allow_training: params.allowTraining ?? true,
        },
        // Both values are falsy here because the guards above reject truthy ones.
        batch: params.batch ?? false,
        callback_url: params.callbackUrl,
      }
    );

    return response;
  }

  /**
   * List fine-tuning jobs (GET `jobs`).
   *
   * @param params - Optional list parameters
   * @param params.skip - Sent as `skip`; defaults to 0
   * @param params.limit - Sent as `limit`; defaults to 10
   * @returns The first element of the requestor's result tuple
   */
  async list(params?: FinetuningListParams): Promise<FinetuningResponse[]> {
    const [response] = await this.requestor.request<FinetuningResponse[]>(
      "GET",
      "jobs",
      {
        skip: params?.skip ?? 0,
        limit: params?.limit ?? 10,
      }
    );

    return response;
  }

  /**
   * Get fine-tuning job details (GET `jobs/<jobId>`).
   *
   * @param jobId - ID of the job to retrieve; percent-encoded before being
   *   placed in the request path
   * @returns The first element of the requestor's result tuple
   * @throws Error if `jobId` is not a string or is blank
   */
  async get(jobId: string): Promise<FinetuningResponse> {
    // A blank ID would resolve to `jobs/`, which is a different route.
    if (typeof jobId !== "string" || jobId.trim() === "") {
      throw new Error("Job ID is required to retrieve a fine-tuning job");
    }

    // Encode so that characters such as "/", "?" or "#" cannot alter the route.
    const [response] = await this.requestor.request<FinetuningResponse>(
      "GET",
      `jobs/${encodeURIComponent(jobId)}`
    );

    return response;
  }

  /**
   * Cancel a fine-tuning job. Not implemented: the returned promise always
   * rejects and no request is made.
   *
   * @param jobId - ID of the job to cancel (currently unused)
   * @throws Error always, with the message "Not implemented"
   */
  async cancel(jobId: string): Promise<Record<string, any>> {
    throw new Error("Not implemented");
  }
}
0 commit comments