Commit c63f9ae

Wauplin and mishig25 authored
Refacto Inference snippets tests (autogeneration + 1 snippet == 1 file) (#1046)
This PR is a refactor of the inference snippet tests. The main goal is to reduce friction when updating the inference snippets logic. In particular:

- Each generated snippet is saved into a single test file with the correct file extension, which should greatly improve reviews. Currently, snippets are saved as hardcoded strings in JS files, making them hard to review (no syntax highlighting, weird indentation).
- Snippets can be updated with a script, reducing the need to manually edit tests (which is quite time-consuming).

This PR is quite large given the auto-generated files. The main parts to review are:

- `generate-snippets-fixtures.ts` => the script that generates the snippets and tests them.
- `package.json`, pnpm lock files, etc. => I'm not entirely sure of what I did there (especially to make `tasks-gen` depend on `tasks`).
- `python.spec.ts`, `js.spec.ts` and `curl.spec.ts` have been removed.
- Everything in `packages/tasks-gen/snippets-fixtures/` => the test cases. I've only committed the ones that were previously tested.

Thanks to this PR, I fixed a typo in the JS snippets (a missing comma) and in the curl snippets (consistency between `"` and `'`).

cc @mishig25, with whom I quickly discussed this.

---------

Co-authored-by: Mishig <[email protected]>
1 parent afdfb0b · commit c63f9ae

34 files changed: +824, −399 lines

packages/tasks-gen/README.md

Lines changed: 67 additions & 0 deletions

## @huggingface.js/tasks-gen

This package is not published. It contains scripts that generate or test parts of the `@huggingface.js/tasks` package.

### generate-snippets-fixtures.ts

This script generates and tests Inference API snippets. The goal is to have a simple way to review changes in the snippets. When updating logic in `packages/tasks/src/snippets`, the test snippets must be updated and committed in the same PR.

To (re-)generate the snippets, run:

```
pnpm generate-snippets-fixtures
```

If some logic has been updated, you should see the changes with:

```
git diff
# the diff has to be committed if correct
```

To test the snippets, run:

```
pnpm test
```

Finally, if you want to add a test case, add an entry to the `TEST_CASES` array in `generate-snippets-fixtures.ts`; a sketch of such an entry follows.
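
For illustration, a hypothetical entry might look like this (the test name, model id, and tags below are invented for the example, not taken from this PR):

```
{
	testName: "summarization",
	model: {
		id: "facebook/bart-large-cnn",
		pipeline_tag: "summarization",
		tags: [],
		inference: "",
	},
	languages: ["sh", "js", "py"],
},
```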
### inference-codegen.ts

Generates JS and Python dataclasses based on the Inference Specs (jsonschema files).

This script is run by a cron job once a day and helps keep `@huggingface.js/tasks` and `huggingface_hub` up to date.

To update the specs manually, run:

```
pnpm inference-codegen
```

### inference-tei-import.ts

Fetches TEI specs and generates JSON schemas for the input and output of text-embeddings (also called feature-extraction). See https://huggingface.github.io/text-embeddings-inference/ for more details.

This script is run by a cron job once a day and helps keep `@huggingface.js/tasks` up to date with TEI updates.

To update the specs manually, run:

```
pnpm inference-tei-import
```

### inference-tgi-import.ts

Fetches TGI specs and generates JSON schemas for the input, output and stream_output of the text-generation and chat-completion tasks. See https://huggingface.github.io/text-generation-inference/ for more details.

This script is run by a cron job once a day and helps keep `@huggingface.js/tasks` up to date with TGI updates.

To update the specs manually, run:

```
pnpm inference-tgi-import
```

packages/tasks-gen/package.json

Lines changed: 6 additions & 1 deletion

@@ -11,9 +11,11 @@
 		"format": "prettier --write .",
 		"format:check": "prettier --check .",
 		"check": "tsc",
+		"generate-snippets-fixtures": "tsx scripts/generate-snippets-fixtures.ts",
 		"inference-codegen": "tsx scripts/inference-codegen.ts && prettier --write ../tasks/src/tasks/*/inference.ts",
 		"inference-tgi-import": "tsx scripts/inference-tgi-import.ts && prettier --write ../tasks/src/tasks/text-generation/spec/*.json && prettier --write ../tasks/src/tasks/chat-completion/spec/*.json",
-		"inference-tei-import": "tsx scripts/inference-tei-import.ts && prettier --write ../tasks/src/tasks/feature-extraction/spec/*.json"
+		"inference-tei-import": "tsx scripts/inference-tei-import.ts && prettier --write ../tasks/src/tasks/feature-extraction/spec/*.json",
+		"test": "vitest run"
 	},
 	"type": "module",
 	"author": "Hugging Face",
@@ -22,5 +24,8 @@
 		"@types/node": "^20.11.5",
 		"quicktype-core": "https://github.com/huggingface/quicktype/raw/pack-18.0.17/packages/quicktype-core/quicktype-core-18.0.17.tgz",
 		"type-fest": "^3.13.1"
+	},
+	"dependencies": {
+		"@huggingface/tasks": "workspace:^"
 	}
 }

packages/tasks-gen/pnpm-lock.yaml

Lines changed: 8 additions & 3 deletions
Some generated files are not rendered by default.
packages/tasks-gen/scripts/generate-snippets-fixtures.ts

Lines changed: 178 additions & 0 deletions

/*
 * Generates Inference API snippets using @huggingface/tasks snippets.
 *
 * If used in test mode ("pnpm test"), it compares the generated snippets with the expected ones.
 * If used in generation mode ("pnpm generate-snippets-fixtures"), it generates the expected snippets.
 *
 * Expected snippets are saved under ./snippets-fixtures and are meant to be versioned on GitHub.
 * Each snippet is saved in a separate file placed under "./{test-name}/{index}.{client}.{language}":
 *   - test-name: the name of the test (e.g. "text-to-image", "conversational-llm", etc.)
 *   - index: the order of the snippet in the array of snippets (0 if not an array)
 *   - client: the client name (e.g. "requests", "huggingface_hub", "openai", etc.). Defaults to "default" if client is not specified.
 *   - language: the language of the snippet (e.g. "sh", "js", "py", etc.)
 *
 * Example:
 *   ./packages/tasks-gen/snippets-fixtures/text-to-image/0.huggingface_hub.py
 */

import { existsSync as pathExists } from "node:fs";
import * as fs from "node:fs/promises";
import * as path from "node:path/posix";

import type { InferenceSnippet } from "@huggingface/tasks";
import { snippets } from "@huggingface/tasks";

type LANGUAGE = "sh" | "js" | "py";

const TEST_CASES: {
	testName: string;
	model: snippets.ModelDataMinimal;
	languages: LANGUAGE[];
	opts?: Record<string, unknown>;
}[] = [
	{
		testName: "conversational-llm-non-stream",
		model: {
			id: "meta-llama/Llama-3.1-8B-Instruct",
			pipeline_tag: "text-generation",
			tags: ["conversational"],
			inference: "",
		},
		languages: ["sh", "js", "py"],
		opts: { streaming: false },
	},
	{
		testName: "conversational-llm-stream",
		model: {
			id: "meta-llama/Llama-3.1-8B-Instruct",
			pipeline_tag: "text-generation",
			tags: ["conversational"],
			inference: "",
		},
		languages: ["sh", "js", "py"],
		opts: { streaming: true },
	},
	{
		testName: "conversational-vlm-non-stream",
		model: {
			id: "meta-llama/Llama-3.2-11B-Vision-Instruct",
			pipeline_tag: "image-text-to-text",
			tags: ["conversational"],
			inference: "",
		},
		languages: ["sh", "js", "py"],
		opts: { streaming: false },
	},
	{
		testName: "conversational-vlm-stream",
		model: {
			id: "meta-llama/Llama-3.2-11B-Vision-Instruct",
			pipeline_tag: "image-text-to-text",
			tags: ["conversational"],
			inference: "",
		},
		languages: ["sh", "js", "py"],
		opts: { streaming: true },
	},
	{
		testName: "text-to-image",
		model: {
			id: "black-forest-labs/FLUX.1-schnell",
			pipeline_tag: "text-to-image",
			tags: [],
			inference: "",
		},
		languages: ["sh", "js", "py"],
	},
] as const;

const GET_SNIPPET_FN = {
	sh: snippets.curl.getCurlInferenceSnippet,
	js: snippets.js.getJsInferenceSnippet,
	py: snippets.python.getPythonInferenceSnippet,
} as const;

const rootDirFinder = (): string => {
	let currentPath = path.normalize(import.meta.url).replace("file:", "");

	while (currentPath !== "/") {
		if (pathExists(path.join(currentPath, "package.json"))) {
			return currentPath;
		}

		currentPath = path.normalize(path.join(currentPath, ".."));
	}

	return "/";
};

function getFixtureFolder(testName: string): string {
	return path.join(rootDirFinder(), "snippets-fixtures", testName);
}

function generateInferenceSnippet(
	model: snippets.ModelDataMinimal,
	language: LANGUAGE,
	opts?: Record<string, unknown>
): InferenceSnippet[] {
	const generatedSnippets = GET_SNIPPET_FN[language](model, "api_token", opts);
	return Array.isArray(generatedSnippets) ? generatedSnippets : [generatedSnippets];
}

async function getExpectedInferenceSnippet(testName: string, language: LANGUAGE): Promise<InferenceSnippet[]> {
	const fixtureFolder = getFixtureFolder(testName);
	const files = await fs.readdir(fixtureFolder);

	const expectedSnippets: InferenceSnippet[] = [];
	for (const file of files.filter((file) => file.endsWith("." + language)).sort()) {
		const client = path.basename(file).split(".").slice(1, -1).join("."); // e.g. '0.huggingface.js.js' => "huggingface.js"
		const content = await fs.readFile(path.join(fixtureFolder, file), { encoding: "utf-8" });
		expectedSnippets.push(client === "default" ? { content } : { client, content });
	}
	return expectedSnippets;
}

async function saveExpectedInferenceSnippet(testName: string, language: LANGUAGE, snippets: InferenceSnippet[]) {
	const fixtureFolder = getFixtureFolder(testName);
	await fs.mkdir(fixtureFolder, { recursive: true });

	for (const [index, snippet] of snippets.entries()) {
		const file = path.join(fixtureFolder, `${index}.${snippet.client ?? "default"}.${language}`);
		await fs.writeFile(file, snippet.content);
	}
}

if (import.meta.vitest) {
	// Run tests if in test mode
	const { describe, expect, it } = import.meta.vitest;

	describe("inference API snippets", () => {
		TEST_CASES.forEach(({ testName, model, languages, opts }) => {
			describe(testName, () => {
				languages.forEach((language) => {
					it(language, async () => {
						const generatedSnippets = generateInferenceSnippet(model, language, opts);
						const expectedSnippets = await getExpectedInferenceSnippet(testName, language);
						expect(generatedSnippets).toEqual(expectedSnippets);
					});
				});
			});
		});
	});
} else {
	// Otherwise, generate the fixtures
	console.log("✨ Re-generating snippets");
	console.debug("  🚜 Removing existing fixtures...");
	await fs.rm(path.join(rootDirFinder(), "snippets-fixtures"), { recursive: true, force: true });

	console.debug("  🏭 Generating new fixtures...");
	TEST_CASES.forEach(({ testName, model, languages, opts }) => {
		console.debug(`      ${testName} (${languages.join(", ")})`);
		languages.forEach(async (language) => {
			const generatedSnippets = generateInferenceSnippet(model, language, opts);
			await saveExpectedInferenceSnippet(testName, language, generatedSnippets);
		});
	});
	console.log("✅ All done!");
	console.log("👉 Please check the generated fixtures before committing them.");
}
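
The `if (import.meta.vitest)` branch above relies on Vitest's in-source testing feature, which only picks up such blocks when the package's Vitest config lists the file in `includeSource`. A minimal sketch of such a config (the file name and glob are assumptions for illustration, not taken from this commit):

```
// vitest.config.ts — minimal sketch, assuming in-source tests live under scripts/
import { defineConfig } from "vitest/config";

export default defineConfig({
	test: {
		// include the `if (import.meta.vitest)` blocks in the test run
		includeSource: ["scripts/**/*.ts"],
	},
});
```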
packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/0.default.sh

Lines changed: 14 additions & 0 deletions

curl 'https://api-inference.huggingface.co/models/meta-llama/Llama-3.1-8B-Instruct/v1/chat/completions' \
-H 'Authorization: Bearer api_token' \
-H 'Content-Type: application/json' \
--data '{
	"model": "meta-llama/Llama-3.1-8B-Instruct",
	"messages": [
		{
			"role": "user",
			"content": "What is the capital of France?"
		}
	],
	"max_tokens": 500,
	"stream": false
}'
packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/0.huggingface.js.js

Lines changed: 16 additions & 0 deletions

import { HfInference } from "@huggingface/inference";

const client = new HfInference("api_token");

const chatCompletion = await client.chatCompletion({
	model: "meta-llama/Llama-3.1-8B-Instruct",
	messages: [
		{
			role: "user",
			content: "What is the capital of France?"
		}
	],
	max_tokens: 500
});

console.log(chatCompletion.choices[0].message);
packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/0.huggingface_hub.py

Lines changed: 18 additions & 0 deletions

from huggingface_hub import InferenceClient

client = InferenceClient(api_key="api_token")

messages = [
	{
		"role": "user",
		"content": "What is the capital of France?"
	}
]

completion = client.chat.completions.create(
	model="meta-llama/Llama-3.1-8B-Instruct",
	messages=messages,
	max_tokens=500
)

print(completion.choices[0].message)
packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/1.openai.js

Lines changed: 19 additions & 0 deletions

import { OpenAI } from "openai";

const client = new OpenAI({
	baseURL: "https://api-inference.huggingface.co/v1/",
	apiKey: "api_token"
});

const chatCompletion = await client.chat.completions.create({
	model: "meta-llama/Llama-3.1-8B-Instruct",
	messages: [
		{
			role: "user",
			content: "What is the capital of France?"
		}
	],
	max_tokens: 500
});

console.log(chatCompletion.choices[0].message);
packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/1.openai.py

Lines changed: 21 additions & 0 deletions

from openai import OpenAI

client = OpenAI(
	base_url="https://api-inference.huggingface.co/v1/",
	api_key="api_token"
)

messages = [
	{
		"role": "user",
		"content": "What is the capital of France?"
	}
]

completion = client.chat.completions.create(
	model="meta-llama/Llama-3.1-8B-Instruct",
	messages=messages,
	max_tokens=500
)

print(completion.choices[0].message)

0 commit comments
