Refactor inference snippets tests (autogeneration + 1 snippet == 1 file) (#1046)

This PR is a refactor of the inference snippet tests. The main goal is to reduce friction when updating the inference snippet logic. In particular:

- Each generated snippet is saved in its own test file with the correct file extension, which should greatly improve reviews. Currently, snippets are saved as hardcoded strings in JS files, making them hard to review (no syntax highlighting, weird indentation).
- Snippets can be updated with a script, reducing the need to manually edit tests (which is quite time-consuming).

This PR is quite large given the auto-generated files. The main parts to review are:

- `generate-snippets-fixtures.ts`: the script that generates the snippets and tests them.
- `package.json`, pnpm lock files, etc.: I'm not entirely sure of what I did there (especially to make `tasks-gen` depend on `tasks`).
- `python.spec.ts`, `js.spec.ts`, and `curl.spec.ts` have been removed.
- Everything in `packages/tasks-gen/snippets-fixtures/`: the test cases. I've only committed the ones that were previously tested.

Thanks to this PR, I fixed a typo in the JS snippets (a missing comma) and in the curl snippets (consistency between `"` and `'`).

cc @mishig25, with whom I quickly discussed this.

Co-authored-by: Mishig <[email protected]>
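The one-snippet-per-file layout introduced here can be sketched as a small path helper. The helper below is ours for illustration and is not part of the PR; the `{test-name}/{index}.{client}.{language}` convention is the one the new script uses:

```typescript
// Illustrative helper (not part of the PR code): builds a fixture path
// following the "./{test-name}/{index}.{client}.{language}" convention
// that the new test layout relies on.
function fixturePath(testName: string, index: number, client: string | undefined, language: string): string {
	// Unnamed clients fall back to "default", mirroring the script's behavior.
	return `snippets-fixtures/${testName}/${index}.${client ?? "default"}.${language}`;
}

console.log(fixturePath("text-to-image", 0, "huggingface_hub", "py"));
// → snippets-fixtures/text-to-image/0.huggingface_hub.py
```

Because each path ends in a real file extension, GitHub renders the fixtures with syntax highlighting, which is what makes reviews easier.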
Showing 34 changed files with 824 additions and 399 deletions.
## @huggingface.js/tasks-gen

This package is not published. It contains scripts that generate or test parts of the `@huggingface.js/tasks` package.

### generate-snippets-fixtures.ts

This script generates and tests Inference API snippets. The goal is to provide a simple way to review changes to the snippets. When updating the logic in `packages/tasks/src/snippets`, the test snippets must be regenerated and committed in the same PR.

To (re-)generate the snippets, run:

```
pnpm generate-snippets-fixtures
```

If any snippet logic has changed, review the result with:

```
git diff
# commit the diff if it looks correct
```

To test the snippets, run:

```
pnpm test
```

Finally, to add a test case, add an entry to the `TEST_CASES` array in `generate-snippets-fixtures.ts`.
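A new entry follows the shape of the existing ones. As a hedged illustration (the model id and task below are hypothetical, not part of the current `TEST_CASES`), an entry could look like this:

```typescript
// Hypothetical new entry for the TEST_CASES array in
// generate-snippets-fixtures.ts; the model id is an example only.
const newTestCase = {
	testName: "summarization",
	model: {
		id: "facebook/bart-large-cnn",
		pipeline_tag: "summarization",
		tags: [] as string[],
		inference: "",
	},
	// Generate curl, JS, and Python snippets for this model.
	languages: ["sh", "js", "py"] as const,
};

console.log(newTestCase.testName); // → summarization
```

Re-running `pnpm generate-snippets-fixtures` would then create a `snippets-fixtures/summarization/` folder with one file per generated snippet.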
### inference-codegen.ts

Generates JS and Python dataclasses from the Inference Specs (JSON schema files).

This script is run by a daily cron job and helps keep `@huggingface.js/tasks` and `huggingface_hub` up to date.

To update the specs manually, run:

```
pnpm inference-codegen
```

### inference-tei-import.ts

Fetches the TEI specs and generates JSON schemas for the input and output of text-embeddings (also called feature-extraction). See https://huggingface.github.io/text-embeddings-inference/ for more details.

This script is run by a daily cron job and helps keep `@huggingface.js/tasks` up to date with TEI updates.

To update the specs manually, run:

```
pnpm inference-tei-import
```

### inference-tgi-import.ts

Fetches the TGI specs and generates JSON schemas for the input, output, and stream_output of the text-generation and chat-completion tasks. See https://huggingface.github.io/text-generation-inference/ for more details.

This script is run by a daily cron job and helps keep `@huggingface.js/tasks` up to date with TGI updates.

To update the specs manually, run:

```
pnpm inference-tgi-import
```
packages/tasks-gen/scripts/generate-snippets-fixtures.ts (178 additions, 0 deletions)

```ts
/*
 * Generates Inference API snippets using @huggingface/tasks snippets.
 *
 * If used in test mode ("pnpm test"), it compares the generated snippets with the expected ones.
 * If used in generation mode ("pnpm generate-snippets-fixtures"), it generates the expected snippets.
 *
 * Expected snippets are saved under ./snippets-fixtures and are meant to be versioned on GitHub.
 * Each snippet is saved in a separate file placed under "./{test-name}/{index}.{client}.{language}":
 * - test-name: the name of the test (e.g. "text-to-image", "conversational-llm", etc.)
 * - index: the order of the snippet in the array of snippets (0 if not an array)
 * - client: the client name (e.g. "requests", "huggingface_hub", "openai", etc.). Defaults to "default" if client is not specified.
 * - language: the language of the snippet (e.g. "sh", "js", "py", etc.)
 *
 * Example:
 * ./packages/tasks-gen/snippets-fixtures/text-to-image/0.huggingface_hub.py
 */

import { existsSync as pathExists } from "node:fs";
import * as fs from "node:fs/promises";
import * as path from "node:path/posix";

import type { InferenceSnippet } from "@huggingface/tasks";
import { snippets } from "@huggingface/tasks";

type LANGUAGE = "sh" | "js" | "py";

const TEST_CASES: {
	testName: string;
	model: snippets.ModelDataMinimal;
	languages: LANGUAGE[];
	opts?: Record<string, unknown>;
}[] = [
	{
		testName: "conversational-llm-non-stream",
		model: {
			id: "meta-llama/Llama-3.1-8B-Instruct",
			pipeline_tag: "text-generation",
			tags: ["conversational"],
			inference: "",
		},
		languages: ["sh", "js", "py"],
		opts: { streaming: false },
	},
	{
		testName: "conversational-llm-stream",
		model: {
			id: "meta-llama/Llama-3.1-8B-Instruct",
			pipeline_tag: "text-generation",
			tags: ["conversational"],
			inference: "",
		},
		languages: ["sh", "js", "py"],
		opts: { streaming: true },
	},
	{
		testName: "conversational-vlm-non-stream",
		model: {
			id: "meta-llama/Llama-3.2-11B-Vision-Instruct",
			pipeline_tag: "image-text-to-text",
			tags: ["conversational"],
			inference: "",
		},
		languages: ["sh", "js", "py"],
		opts: { streaming: false },
	},
	{
		testName: "conversational-vlm-stream",
		model: {
			id: "meta-llama/Llama-3.2-11B-Vision-Instruct",
			pipeline_tag: "image-text-to-text",
			tags: ["conversational"],
			inference: "",
		},
		languages: ["sh", "js", "py"],
		opts: { streaming: true },
	},
	{
		testName: "text-to-image",
		model: {
			id: "black-forest-labs/FLUX.1-schnell",
			pipeline_tag: "text-to-image",
			tags: [],
			inference: "",
		},
		languages: ["sh", "js", "py"],
	},
] as const;

const GET_SNIPPET_FN = {
	sh: snippets.curl.getCurlInferenceSnippet,
	js: snippets.js.getJsInferenceSnippet,
	py: snippets.python.getPythonInferenceSnippet,
} as const;

const rootDirFinder = (): string => {
	let currentPath = path.normalize(import.meta.url).replace("file:", "");

	while (currentPath !== "/") {
		if (pathExists(path.join(currentPath, "package.json"))) {
			return currentPath;
		}

		currentPath = path.normalize(path.join(currentPath, ".."));
	}

	return "/";
};

function getFixtureFolder(testName: string): string {
	return path.join(rootDirFinder(), "snippets-fixtures", testName);
}

function generateInferenceSnippet(
	model: snippets.ModelDataMinimal,
	language: LANGUAGE,
	opts?: Record<string, unknown>
): InferenceSnippet[] {
	const generatedSnippets = GET_SNIPPET_FN[language](model, "api_token", opts);
	return Array.isArray(generatedSnippets) ? generatedSnippets : [generatedSnippets];
}

async function getExpectedInferenceSnippet(testName: string, language: LANGUAGE): Promise<InferenceSnippet[]> {
	const fixtureFolder = getFixtureFolder(testName);
	const files = await fs.readdir(fixtureFolder);

	const expectedSnippets: InferenceSnippet[] = [];
	for (const file of files.filter((file) => file.endsWith("." + language)).sort()) {
		const client = path.basename(file).split(".").slice(1, -1).join("."); // e.g. '0.huggingface.js.js' => "huggingface.js"
		const content = await fs.readFile(path.join(fixtureFolder, file), { encoding: "utf-8" });
		expectedSnippets.push(client === "default" ? { content } : { client, content });
	}
	return expectedSnippets;
}

async function saveExpectedInferenceSnippet(testName: string, language: LANGUAGE, snippets: InferenceSnippet[]) {
	const fixtureFolder = getFixtureFolder(testName);
	await fs.mkdir(fixtureFolder, { recursive: true });

	for (const [index, snippet] of snippets.entries()) {
		const file = path.join(fixtureFolder, `${index}.${snippet.client ?? "default"}.${language}`);
		await fs.writeFile(file, snippet.content);
	}
}

if (import.meta.vitest) {
	// Run tests if in test mode
	const { describe, expect, it } = import.meta.vitest;

	describe("inference API snippets", () => {
		TEST_CASES.forEach(({ testName, model, languages, opts }) => {
			describe(testName, () => {
				languages.forEach((language) => {
					it(language, async () => {
						const generatedSnippets = generateInferenceSnippet(model, language, opts);
						const expectedSnippets = await getExpectedInferenceSnippet(testName, language);
						expect(generatedSnippets).toEqual(expectedSnippets);
					});
				});
			});
		});
	});
} else {
	// Otherwise, generate the fixtures
	console.log("✨ Re-generating snippets");
	console.debug(" 🚜 Removing existing fixtures...");
	await fs.rm(path.join(rootDirFinder(), "snippets-fixtures"), { recursive: true, force: true });

	console.debug(" 🏭 Generating new fixtures...");
	TEST_CASES.forEach(({ testName, model, languages, opts }) => {
		console.debug(`  ${testName} (${languages.join(", ")})`);
		languages.forEach(async (language) => {
			const generatedSnippets = generateInferenceSnippet(model, language, opts);
			await saveExpectedInferenceSnippet(testName, language, generatedSnippets);
		});
	});
	console.log("✅ All done!");
	console.log("👉 Please check the generated fixtures before committing them.");
}
```
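The client name round-trips through the fixture filename: `saveExpectedInferenceSnippet` writes `{index}.{client}.{language}`, and `getExpectedInferenceSnippet` recovers the client by dropping the leading index and the trailing language extension. A standalone sketch of that parsing step (the helper name is ours; the script inlines this logic):

```typescript
// Mirrors the client extraction in getExpectedInferenceSnippet:
// split on ".", drop the index (first part) and the extension (last part),
// and re-join what remains — so client names containing dots survive.
function clientFromFixtureFilename(filename: string): string {
	return filename.split(".").slice(1, -1).join(".");
}

console.log(clientFromFixtureFilename("0.huggingface.js.js")); // → huggingface.js
console.log(clientFromFixtureFilename("1.openai.py"));         // → openai
console.log(clientFromFixtureFilename("0.default.sh"));        // → default
```

The re-join with `"."` is why a client like `huggingface.js` can be encoded in a filename without any escaping.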
packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/0.default.sh (14 additions, 0 deletions)

```sh
curl 'https://api-inference.huggingface.co/models/meta-llama/Llama-3.1-8B-Instruct/v1/chat/completions' \
-H 'Authorization: Bearer api_token' \
-H 'Content-Type: application/json' \
--data '{
    "model": "meta-llama/Llama-3.1-8B-Instruct",
    "messages": [
        {
            "role": "user",
            "content": "What is the capital of France?"
        }
    ],
    "max_tokens": 500,
    "stream": false
}'
```
packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/0.huggingface.js.js (16 additions, 0 deletions)

```js
import { HfInference } from "@huggingface/inference";

const client = new HfInference("api_token");

const chatCompletion = await client.chatCompletion({
	model: "meta-llama/Llama-3.1-8B-Instruct",
	messages: [
		{
			role: "user",
			content: "What is the capital of France?"
		}
	],
	max_tokens: 500
});

console.log(chatCompletion.choices[0].message);
```
packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/0.huggingface_hub.py (18 additions, 0 deletions)

```py
from huggingface_hub import InferenceClient

client = InferenceClient(api_key="api_token")

messages = [
    {
        "role": "user",
        "content": "What is the capital of France?"
    }
]

completion = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    messages=messages,
    max_tokens=500
)

print(completion.choices[0].message)
```
packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/1.openai.js (19 additions, 0 deletions)

```js
import { OpenAI } from "openai";

const client = new OpenAI({
	baseURL: "https://api-inference.huggingface.co/v1/",
	apiKey: "api_token"
});

const chatCompletion = await client.chat.completions.create({
	model: "meta-llama/Llama-3.1-8B-Instruct",
	messages: [
		{
			role: "user",
			content: "What is the capital of France?"
		}
	],
	max_tokens: 500
});

console.log(chatCompletion.choices[0].message);
```
packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/1.openai.py (21 additions, 0 deletions)

```py
from openai import OpenAI

client = OpenAI(
    base_url="https://api-inference.huggingface.co/v1/",
    api_key="api_token"
)

messages = [
    {
        "role": "user",
        "content": "What is the capital of France?"
    }
]

completion = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    messages=messages,
    max_tokens=500
)

print(completion.choices[0].message)
```