[Inference] Move provider-specific logic away from makeRequestOptions (1 provider == 1 module) #1208

Open · wants to merge 15 commits into main
246 changes: 66 additions & 180 deletions packages/inference/src/lib/makeRequestOptions.ts
@@ -1,15 +1,15 @@
import { HF_HUB_URL, HF_ROUTER_URL } from "../config";
import { FAL_AI_API_BASE_URL } from "../providers/fal-ai";
import { NEBIUS_API_BASE_URL } from "../providers/nebius";
import { REPLICATE_API_BASE_URL } from "../providers/replicate";
import { SAMBANOVA_API_BASE_URL } from "../providers/sambanova";
import { TOGETHER_API_BASE_URL } from "../providers/together";
import { NOVITA_API_BASE_URL } from "../providers/novita";
import { FIREWORKS_AI_API_BASE_URL } from "../providers/fireworks-ai";
import { HYPERBOLIC_API_BASE_URL } from "../providers/hyperbolic";
import { BLACKFORESTLABS_AI_API_BASE_URL } from "../providers/black-forest-labs";
import type { InferenceProvider } from "../types";
import type { InferenceTask, Options, RequestArgs } from "../types";
import { BLACK_FOREST_LABS_CONFIG } from "../providers/black-forest-labs";
import { FAL_AI_CONFIG } from "../providers/fal-ai";
import { FIREWORKS_AI_CONFIG } from "../providers/fireworks-ai";
import { HF_INFERENCE_CONFIG } from "../providers/hf-inference";
import { HYPERBOLIC_CONFIG } from "../providers/hyperbolic";
import { NEBIUS_CONFIG } from "../providers/nebius";
import { NOVITA_CONFIG } from "../providers/novita";
import { REPLICATE_CONFIG } from "../providers/replicate";
import { SAMBANOVA_CONFIG } from "../providers/sambanova";
import { TOGETHER_CONFIG } from "../providers/together";
import type { InferenceProvider, InferenceTask, Options, ProviderConfig, RequestArgs } from "../types";
import { isUrl } from "./isUrl";
import { version as packageVersion, name as packageName } from "../../package.json";
import { getProviderModelId } from "./getProviderModelId";
@@ -22,6 +22,22 @@ const HF_HUB_INFERENCE_PROXY_TEMPLATE = `${HF_ROUTER_URL}/{{PROVIDER}}`;
*/
let tasks: Record<string, { models: { id: string }[] }> | null = null;

/**
* Config to define how to serialize requests for each provider
*/
const providerConfigs: Record<InferenceProvider, ProviderConfig> = {
"black-forest-labs": BLACK_FOREST_LABS_CONFIG,
"fal-ai": FAL_AI_CONFIG,
"fireworks-ai": FIREWORKS_AI_CONFIG,
"hf-inference": HF_INFERENCE_CONFIG,
hyperbolic: HYPERBOLIC_CONFIG,
nebius: NEBIUS_CONFIG,
novita: NOVITA_CONFIG,
replicate: REPLICATE_CONFIG,
sambanova: SAMBANOVA_CONFIG,
together: TOGETHER_CONFIG,
};
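
The ProviderConfig type and its parameter types are defined in ../types, whose diff is not part of this excerpt. A plausible sketch of their shape, inferred purely from the call sites in this file and in the provider modules below:

```ts
// Sketch (assumption): inferred from usage; the real definitions live in ../types.
type AuthMethod = "none" | "hf-token" | "credentials-include" | "provider-key";

interface UrlParams {
	baseUrl: string;
	model: string;
	taskHint?: InferenceTask;
	chatCompletion?: boolean;
}

interface HeaderParams {
	accessToken?: string;
	authMethod: AuthMethod;
}

interface BodyParams {
	args: Record<string, unknown>;
	model: string;
	taskHint?: InferenceTask;
	chatCompletion?: boolean;
}

interface ProviderConfig {
	baseUrl: string;
	makeUrl: (params: UrlParams) => string;
	makeHeaders: (params: HeaderParams) => Record<string, string>;
	makeBody: (params: BodyParams) => Record<string, unknown>;
}
```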

/**
* Helper that prepares request arguments
*/
@@ -37,10 +53,10 @@ export async function makeRequestOptions(
}
): Promise<{ url: string; info: RequestInit }> {
const { accessToken, endpointUrl, provider: maybeProvider, model: maybeModel, ...remainingArgs } = args;
let otherArgs = remainingArgs;
const provider = maybeProvider ?? "hf-inference";
const providerConfig = providerConfigs[provider];

const { includeCredentials, taskHint, chatCompletion } = options ?? {};
const { includeCredentials, taskHint, chatCompletion, signal } = options ?? {};

if (endpointUrl && provider !== "hf-inference") {
throw new Error(`Cannot use endpointUrl with a third-party provider.`);
@@ -51,6 +67,9 @@
if (!maybeModel && !taskHint) {
throw new Error("No model provided, and no task has been specified.");
}
if (!providerConfig) {
throw new Error(`No provider config found for provider ${provider}`);
}
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const hfModel = maybeModel ?? (await loadDefaultModel(taskHint!));
const model = await getProviderModelId({ model: hfModel, provider }, args, {
@@ -68,44 +87,52 @@
? "credentials-include"
: "none";

// Make URL
const url = endpointUrl
? chatCompletion
? endpointUrl + `/v1/chat/completions`
: endpointUrl
: makeUrl({
authMethod,
chatCompletion: chatCompletion ?? false,
: providerConfig.makeUrl({
baseUrl:
authMethod !== "provider-key"
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider)
: providerConfig.baseUrl,
model,
provider: provider ?? "hf-inference",
taskHint,
chatCompletion,
});

const headers: Record<string, string> = {};
if (accessToken) {
if (provider === "fal-ai" && authMethod === "provider-key") {
headers["Authorization"] = `Key ${accessToken}`;
} else if (provider === "black-forest-labs" && authMethod === "provider-key") {
headers["X-Key"] = accessToken;
} else {
headers["Authorization"] = `Bearer ${accessToken}`;
}
}

// e.g. @huggingface/inference/3.1.3
const ownUserAgent = `${packageName}/${packageVersion}`;
headers["User-Agent"] = [ownUserAgent, typeof navigator !== "undefined" ? navigator.userAgent : undefined]
.filter((x) => x !== undefined)
.join(" ");

// Make headers
const binary = "data" in args && !!args.data;
const headers = providerConfig.makeHeaders({
accessToken,
authMethod,
});

// Add content-type to headers
if (!binary) {
headers["Content-Type"] = "application/json";
}

if (provider === "replicate") {
headers["Prefer"] = "wait";
}
// Add user-agent to headers
// e.g. @huggingface/inference/3.1.3
const ownUserAgent = `${packageName}/${packageVersion}`;
const userAgent = [ownUserAgent, typeof navigator !== "undefined" ? navigator.userAgent : undefined]
.filter((x) => x !== undefined)
.join(" ");
headers["User-Agent"] = userAgent;

// Make body
const body = binary
? args.data
: JSON.stringify(
providerConfig.makeBody({
args: remainingArgs as Record<string, unknown>,
model,
taskHint,
chatCompletion,
})
);

Comment on lines +125 to 136 (Contributor):

Independent from the changes introduced here: I dislike this binary data special case; it's not obvious from an external point of view that passing data in the arguments will result in the payload being passed as raw bytes in the body.
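
For context, a minimal sketch of the special case this comment describes (model IDs, payloads, and the import path are illustrative):

```ts
import { makeRequestOptions } from "../lib/makeRequestOptions";

// No `data` key: the body is JSON, serialized via providerConfig.makeBody.
const { info: jsonReq } = await makeRequestOptions(
	{ model: "gpt2", inputs: "Hello" },
	{ taskHint: "text-generation" }
);
// jsonReq.body ~= JSON.stringify({ inputs: "Hello" })

// `data` present: the `binary` flag flips and the payload is sent as raw bytes.
const audioBytes = new Uint8Array([0x52, 0x49, 0x46, 0x46]); // placeholder bytes
const { info: rawReq } = await makeRequestOptions(
	{ model: "openai/whisper-large-v3", data: new Blob([audioBytes]) },
	{ taskHint: "automatic-speech-recognition" }
);
// rawReq.body === the Blob itself, untouched by JSON.stringify
```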

/**
* For edge runtimes, leave 'credentials' undefined, otherwise cloudflare workers will error
@@ -117,158 +144,17 @@ export async function makeRequestOptions(
credentials = "include";
}

/**
* Replicate models wrap all inputs inside { input: ... }
* Versioned Replicate models in the format `owner/model:version` expect the version in the body
*/
if (provider === "replicate") {
const version = model.includes(":") ? model.split(":")[1] : undefined;
(otherArgs as unknown) = { input: otherArgs, version };
}
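
The wrapping removed here presumably moves into providers/replicate.ts, whose diff is not shown in this excerpt; under that assumption, its makeBody would look roughly like:

```ts
// Sketch (assumption): Replicate wraps all inputs inside { input: ... }, and
// versioned models ("owner/model:version") pass the version in the body.
const makeBody = (params: BodyParams): Record<string, unknown> => {
	const version = params.model.includes(":") ? params.model.split(":")[1] : undefined;
	return { input: params.args, ...(version !== undefined ? { version } : undefined) };
};
```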

const info: RequestInit = {
headers,
method: "POST",
body: binary
? args.data
: JSON.stringify({
...otherArgs,
...(taskHint === "text-to-image" && provider === "hyperbolic"
? { model_name: model }
: chatCompletion || provider === "together" || provider === "nebius" || provider === "hyperbolic"
? { model }
: undefined),
}),
body,
...(credentials ? { credentials } : undefined),
signal: options?.signal,
signal,
};

return { url, info };
}

function makeUrl(params: {
authMethod: "none" | "hf-token" | "credentials-include" | "provider-key";
chatCompletion: boolean;
model: string;
provider: InferenceProvider;
taskHint: InferenceTask | undefined;
}): string {
if (params.authMethod === "none" && params.provider !== "hf-inference") {
throw new Error("Authentication is required when requesting a third-party provider. Please provide accessToken");
}

const shouldProxy = params.provider !== "hf-inference" && params.authMethod !== "provider-key";
switch (params.provider) {
case "black-forest-labs": {
const baseUrl = shouldProxy
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
: BLACKFORESTLABS_AI_API_BASE_URL;
return `${baseUrl}/${params.model}`;
}
case "fal-ai": {
const baseUrl = shouldProxy
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
: FAL_AI_API_BASE_URL;
return `${baseUrl}/${params.model}`;
}
case "nebius": {
const baseUrl = shouldProxy
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
: NEBIUS_API_BASE_URL;

if (params.taskHint === "text-to-image") {
return `${baseUrl}/v1/images/generations`;
}
if (params.taskHint === "text-generation") {
if (params.chatCompletion) {
return `${baseUrl}/v1/chat/completions`;
}
return `${baseUrl}/v1/completions`;
}
return baseUrl;
}
case "replicate": {
const baseUrl = shouldProxy
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
: REPLICATE_API_BASE_URL;
if (params.model.includes(":")) {
/// Versioned model
return `${baseUrl}/v1/predictions`;
}
/// Evergreen / Canonical model
return `${baseUrl}/v1/models/${params.model}/predictions`;
}
case "sambanova": {
const baseUrl = shouldProxy
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
: SAMBANOVA_API_BASE_URL;
/// Sambanova API matches OpenAI-like APIs: model is defined in the request body
if (params.taskHint === "text-generation" && params.chatCompletion) {
return `${baseUrl}/v1/chat/completions`;
}
return baseUrl;
}
case "together": {
const baseUrl = shouldProxy
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
: TOGETHER_API_BASE_URL;
/// Together API matches OpenAI-like APIs: model is defined in the request body
if (params.taskHint === "text-to-image") {
return `${baseUrl}/v1/images/generations`;
}
if (params.taskHint === "text-generation") {
if (params.chatCompletion) {
return `${baseUrl}/v1/chat/completions`;
}
return `${baseUrl}/v1/completions`;
}
return baseUrl;
}

case "fireworks-ai": {
const baseUrl = shouldProxy
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
: FIREWORKS_AI_API_BASE_URL;
if (params.taskHint === "text-generation" && params.chatCompletion) {
return `${baseUrl}/v1/chat/completions`;
}
return baseUrl;
}
case "hyperbolic": {
const baseUrl = shouldProxy
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
: HYPERBOLIC_API_BASE_URL;

if (params.taskHint === "text-to-image") {
return `${baseUrl}/v1/images/generations`;
}
return `${baseUrl}/v1/chat/completions`;
}
case "novita": {
const baseUrl = shouldProxy
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
: NOVITA_API_BASE_URL;
if (params.taskHint === "text-generation") {
if (params.chatCompletion) {
return `${baseUrl}/chat/completions`;
}
return `${baseUrl}/completions`;
}
return baseUrl;
}
default: {
const baseUrl = HF_HUB_INFERENCE_PROXY_TEMPLATE.replaceAll("{{PROVIDER}}", "hf-inference");
if (params.taskHint && ["feature-extraction", "sentence-similarity"].includes(params.taskHint)) {
/// when deployed on hf-inference, those two tasks are automatically compatible with one another.
return `${baseUrl}/pipeline/${params.taskHint}/${params.model}`;
}
if (params.taskHint === "text-generation" && params.chatCompletion) {
return `${baseUrl}/models/${params.model}/v1/chat/completions`;
}
return `${baseUrl}/models/${params.model}`;
}
}
}
async function loadDefaultModel(task: InferenceTask): Promise<string> {
if (!tasks) {
tasks = await loadTaskInfo();
28 changes: 26 additions & 2 deletions packages/inference/src/providers/black-forest-labs.ts
@@ -1,5 +1,3 @@
export const BLACKFORESTLABS_AI_API_BASE_URL = "https://api.us1.bfl.ai/v1";

/**
* See the registered mapping of HF model ID => Black Forest Labs model ID here:
*
@@ -16,3 +14,29 @@ export const BLACKFORESTLABS_AI_API_BASE_URL = "https://api.us1.bfl.ai/v1";
*
* Thanks!
*/
import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";

const BLACK_FOREST_LABS_AI_API_BASE_URL = "https://api.us1.bfl.ai/v1";

const makeBody = (params: BodyParams): Record<string, unknown> => {
return params.args;
};

const makeHeaders = (params: HeaderParams): Record<string, string> => {
if (params.authMethod === "provider-key") {
return { "X-Key": `${params.accessToken}` };
} else {
return { Authorization: `Bearer ${params.accessToken}` };
}
};

const makeUrl = (params: UrlParams): string => {
return `${params.baseUrl}/${params.model}`;
};

export const BLACK_FOREST_LABS_CONFIG: ProviderConfig = {
baseUrl: BLACK_FOREST_LABS_AI_API_BASE_URL,
makeBody,
makeHeaders,
makeUrl,
};
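
As a usage note, calling this config directly would produce, roughly (model ID and tokens are placeholders):

```ts
BLACK_FOREST_LABS_CONFIG.makeUrl({ baseUrl: "https://api.us1.bfl.ai/v1", model: "flux-pro-1.1" });
// -> "https://api.us1.bfl.ai/v1/flux-pro-1.1"

BLACK_FOREST_LABS_CONFIG.makeHeaders({ accessToken: "bfl-key", authMethod: "provider-key" });
// -> { "X-Key": "bfl-key" }

BLACK_FOREST_LABS_CONFIG.makeHeaders({ accessToken: "hf_token", authMethod: "hf-token" });
// -> { Authorization: "Bearer hf_token" }
```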
2 changes: 1 addition & 1 deletion packages/inference/src/providers/consts.ts
@@ -22,8 +22,8 @@ export const HARDCODED_MODEL_ID_MAPPING: Record<InferenceProvider, Record<ModelI
"hf-inference": {},
hyperbolic: {},
nebius: {},
novita: {},
replicate: {},
sambanova: {},
together: {},
novita: {},
};
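
For illustration, a hypothetical entry in this mapping (the model IDs here are placeholders, not registered mappings):

```ts
// Hypothetical: map an HF model ID to a provider-side ID for local testing.
HARDCODED_MODEL_ID_MAPPING["together"]["meta-llama/Llama-3.3-70B-Instruct"] =
	"meta-llama/Llama-3.3-70B-Instruct-Turbo";
```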
26 changes: 24 additions & 2 deletions packages/inference/src/providers/fal-ai.ts
@@ -1,5 +1,3 @@
export const FAL_AI_API_BASE_URL = "https://fal.run";

/**
* See the registered mapping of HF model ID => Fal model ID here:
*
@@ -16,3 +14,27 @@ export const FAL_AI_API_BASE_URL = "https://fal.run";
*
* Thanks!
*/
import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";

const FAL_AI_API_BASE_URL = "https://fal.run";

const makeBody = (params: BodyParams): Record<string, unknown> => {
return params.args;
};

const makeHeaders = (params: HeaderParams): Record<string, string> => {
return {
Authorization: params.authMethod === "provider-key" ? `Key ${params.accessToken}` : `Bearer ${params.accessToken}`,
};
};

const makeUrl = (params: UrlParams): string => {
return `${params.baseUrl}/${params.model}`;
};

export const FAL_AI_CONFIG: ProviderConfig = {
baseUrl: FAL_AI_API_BASE_URL,
makeBody,
makeHeaders,
makeUrl,
};
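
The notable difference from most providers is fal.ai's `Key` authorization scheme when called directly; routed through Hugging Face, the usual Bearer token applies (tokens below are placeholders):

```ts
FAL_AI_CONFIG.makeHeaders({ accessToken: "fal-key", authMethod: "provider-key" });
// -> { Authorization: "Key fal-key" }

FAL_AI_CONFIG.makeHeaders({ accessToken: "hf_token", authMethod: "hf-token" });
// -> { Authorization: "Bearer hf_token" }
```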