Commit c63f9ae

Refactor Inference snippets tests (autogeneration + 1 snippet == 1 file) (#1046)

This PR is a refactor of the inference snippet tests. The main goal is to
reduce friction when updating the inference snippets logic. In
particular:
- each generated snippet is saved to its own test file with the
correct file extension => should greatly improve reviews. Until now,
snippets have been saved as hardcoded strings in JS files, making them
hard to review (no syntax highlighting, weird indentation).
- snippets can be updated with a script, reducing the need to manually
edit tests (which is quite time-consuming)

This PR is quite large given the auto-generated files. The main parts to
review are:
- `generate-snippets-fixtures.ts` => the script that generates the
snippets and tests them.
- package.json, pnpm lock files, etc. => I'm not entirely sure of what I
did there (especially the changes to make `tasks-gen` depend on `tasks`)
- `python.spec.ts`, `js.spec.ts` and `curl.spec.ts` have been removed
- everything in `packages/tasks-gen/snippets-fixtures/` => the test
cases. I've only committed the ones that were previously tested (see the
example layout below).
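
For illustration, here is what the fixtures for a single test case look like, following the `{test-name}/{index}.{client}.{language}` naming convention implemented by the script:

```
packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/
├── 0.default.sh
├── 0.huggingface.js.js
├── 0.huggingface_hub.py
├── 1.openai.js
└── 1.openai.py
```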

Thanks to this PR, I fixed a typo in the JS snippets (a missing comma)
and an inconsistency in the curl snippets (mixed use of `"` and `'`).


cc @mishig25 with whom I quickly discussed this

---------

Co-authored-by: Mishig <[email protected]>
Wauplin and mishig25 authored Nov 25, 2024
1 parent afdfb0b commit c63f9ae
Showing 34 changed files with 824 additions and 399 deletions.
67 changes: 67 additions & 0 deletions packages/tasks-gen/README.md
@@ -0,0 +1,67 @@
## @huggingface.js/tasks-gen

This package is not published. It contains scripts that generate or test parts of the `@huggingface.js/tasks` package.

### generate-snippets-fixtures.ts

This script generates and tests Inference API snippets. The goal is to provide a simple way to review changes in the snippets.
When the logic in `packages/tasks/src/snippets` is updated, the test snippets must be regenerated and committed in the same PR.

To (re-)generate the snippets, run:

```
pnpm generate-snippets-fixtures
```

If some logic has been updated, you should see the changes with:
```
git diff
# the diff must be committed if correct
```

To test the snippets, run:

```
pnpm test
```

Finally, to add a test case you must add an entry to the `TEST_CASES` array in `generate-snippets-fixtures.ts`.
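
For example, the `text-to-image` test case is defined in the script as:

```
{
	testName: "text-to-image",
	model: {
		id: "black-forest-labs/FLUX.1-schnell",
		pipeline_tag: "text-to-image",
		tags: [],
		inference: "",
	},
	languages: ["sh", "js", "py"],
},
```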

### inference-codegen.ts

Generates JS and Python dataclasses based on the Inference Specs (JSON schema files).

This script is run by a cron job once a day and helps keep `@huggingface.js/tasks` and `huggingface_hub` up to date.

To update the specs manually, run:

```
pnpm inference-codegen
```

### inference-tei-import.ts

Fetches the TEI specs and generates JSON schemas for the input and output of the text-embeddings task (also called feature-extraction).
See https://huggingface.github.io/text-embeddings-inference/ for more details.

This script is run by a cron job once a day and helps keep `@huggingface.js/tasks` up to date with TEI updates.

To update the specs manually, run:

```
pnpm inference-tei-import
```

### inference-tgi-import.ts

Fetches the TGI specs and generates JSON schemas for the input, output and stream_output of the text-generation and chat-completion tasks.
See https://huggingface.github.io/text-generation-inference/ for more details.

This script is run by a cron job once a day and helps keep `@huggingface.js/tasks` up to date with TGI updates.

To update the specs manually, run:

```
pnpm inference-tgi-import
```

7 changes: 6 additions & 1 deletion packages/tasks-gen/package.json
@@ -11,9 +11,11 @@
 		"format": "prettier --write .",
 		"format:check": "prettier --check .",
 		"check": "tsc",
+		"generate-snippets-fixtures": "tsx scripts/generate-snippets-fixtures.ts",
 		"inference-codegen": "tsx scripts/inference-codegen.ts && prettier --write ../tasks/src/tasks/*/inference.ts",
 		"inference-tgi-import": "tsx scripts/inference-tgi-import.ts && prettier --write ../tasks/src/tasks/text-generation/spec/*.json && prettier --write ../tasks/src/tasks/chat-completion/spec/*.json",
-		"inference-tei-import": "tsx scripts/inference-tei-import.ts && prettier --write ../tasks/src/tasks/feature-extraction/spec/*.json"
+		"inference-tei-import": "tsx scripts/inference-tei-import.ts && prettier --write ../tasks/src/tasks/feature-extraction/spec/*.json",
+		"test": "vitest run"
 	},
 	"type": "module",
 	"author": "Hugging Face",
@@ -22,5 +24,8 @@
 		"@types/node": "^20.11.5",
 		"quicktype-core": "https://github.com/huggingface/quicktype/raw/pack-18.0.17/packages/quicktype-core/quicktype-core-18.0.17.tgz",
 		"type-fest": "^3.13.1"
+	},
+	"dependencies": {
+		"@huggingface/tasks": "workspace:^"
 	}
 }
11 changes: 8 additions & 3 deletions packages/tasks-gen/pnpm-lock.yaml

178 changes: 178 additions & 0 deletions packages/tasks-gen/scripts/generate-snippets-fixtures.ts
@@ -0,0 +1,178 @@
/*
 * Generates Inference API snippets using @huggingface/tasks snippets.
 *
 * If used in test mode ("pnpm test"), it compares the generated snippets with the expected ones.
 * If used in generation mode ("pnpm generate-snippets-fixtures"), it generates the expected snippets.
 *
 * Expected snippets are saved under ./snippets-fixtures and are meant to be versioned on GitHub.
 * Each snippet is saved in a separate file placed under "./{test-name}/{index}.{client}.{language}":
 * - test-name: the name of the test (e.g. "text-to-image", "conversational-llm", etc.)
 * - index: the order of the snippet in the array of snippets (0 if not an array)
 * - client: the client name (e.g. "requests", "huggingface_hub", "openai", etc.). Default to "default" if client is not specified.
 * - language: the language of the snippet (e.g. "sh", "js", "py", etc.)
 *
 * Example:
 * ./packages/tasks-gen/snippets-fixtures/text-to-image/0.huggingface_hub.py
 */

import { existsSync as pathExists } from "node:fs";
import * as fs from "node:fs/promises";
import * as path from "node:path/posix";

import type { InferenceSnippet } from "@huggingface/tasks";
import { snippets } from "@huggingface/tasks";

type LANGUAGE = "sh" | "js" | "py";

const TEST_CASES: {
	testName: string;
	model: snippets.ModelDataMinimal;
	languages: LANGUAGE[];
	opts?: Record<string, unknown>;
}[] = [
	{
		testName: "conversational-llm-non-stream",
		model: {
			id: "meta-llama/Llama-3.1-8B-Instruct",
			pipeline_tag: "text-generation",
			tags: ["conversational"],
			inference: "",
		},
		languages: ["sh", "js", "py"],
		opts: { streaming: false },
	},
	{
		testName: "conversational-llm-stream",
		model: {
			id: "meta-llama/Llama-3.1-8B-Instruct",
			pipeline_tag: "text-generation",
			tags: ["conversational"],
			inference: "",
		},
		languages: ["sh", "js", "py"],
		opts: { streaming: true },
	},
	{
		testName: "conversational-vlm-non-stream",
		model: {
			id: "meta-llama/Llama-3.2-11B-Vision-Instruct",
			pipeline_tag: "image-text-to-text",
			tags: ["conversational"],
			inference: "",
		},
		languages: ["sh", "js", "py"],
		opts: { streaming: false },
	},
	{
		testName: "conversational-vlm-stream",
		model: {
			id: "meta-llama/Llama-3.2-11B-Vision-Instruct",
			pipeline_tag: "image-text-to-text",
			tags: ["conversational"],
			inference: "",
		},
		languages: ["sh", "js", "py"],
		opts: { streaming: true },
	},
	{
		testName: "text-to-image",
		model: {
			id: "black-forest-labs/FLUX.1-schnell",
			pipeline_tag: "text-to-image",
			tags: [],
			inference: "",
		},
		languages: ["sh", "js", "py"],
	},
] as const;

const GET_SNIPPET_FN = {
	sh: snippets.curl.getCurlInferenceSnippet,
	js: snippets.js.getJsInferenceSnippet,
	py: snippets.python.getPythonInferenceSnippet,
} as const;

const rootDirFinder = (): string => {
	let currentPath = path.normalize(import.meta.url).replace("file:", "");

	while (currentPath !== "/") {
		if (pathExists(path.join(currentPath, "package.json"))) {
			return currentPath;
		}

		currentPath = path.normalize(path.join(currentPath, ".."));
	}

	return "/";
};

function getFixtureFolder(testName: string): string {
	return path.join(rootDirFinder(), "snippets-fixtures", testName);
}

function generateInferenceSnippet(
	model: snippets.ModelDataMinimal,
	language: LANGUAGE,
	opts?: Record<string, unknown>
): InferenceSnippet[] {
	const generatedSnippets = GET_SNIPPET_FN[language](model, "api_token", opts);
	return Array.isArray(generatedSnippets) ? generatedSnippets : [generatedSnippets];
}

async function getExpectedInferenceSnippet(testName: string, language: LANGUAGE): Promise<InferenceSnippet[]> {
	const fixtureFolder = getFixtureFolder(testName);
	const files = await fs.readdir(fixtureFolder);

	const expectedSnippets: InferenceSnippet[] = [];
	for (const file of files.filter((file) => file.endsWith("." + language)).sort()) {
		const client = path.basename(file).split(".").slice(1, -1).join("."); // e.g. '0.huggingface.js.js' => "huggingface.js"
		const content = await fs.readFile(path.join(fixtureFolder, file), { encoding: "utf-8" });
		expectedSnippets.push(client === "default" ? { content } : { client, content });
	}
	return expectedSnippets;
}

async function saveExpectedInferenceSnippet(testName: string, language: LANGUAGE, snippets: InferenceSnippet[]) {
	const fixtureFolder = getFixtureFolder(testName);
	await fs.mkdir(fixtureFolder, { recursive: true });

	for (const [index, snippet] of snippets.entries()) {
		const file = path.join(fixtureFolder, `${index}.${snippet.client ?? "default"}.${language}`);
		await fs.writeFile(file, snippet.content);
	}
}

if (import.meta.vitest) {
	// Run test if in test mode
	const { describe, expect, it } = import.meta.vitest;

	describe("inference API snippets", () => {
		TEST_CASES.forEach(({ testName, model, languages, opts }) => {
			describe(testName, () => {
				languages.forEach((language) => {
					it(language, async () => {
						const generatedSnippets = generateInferenceSnippet(model, language, opts);
						const expectedSnippets = await getExpectedInferenceSnippet(testName, language);
						expect(generatedSnippets).toEqual(expectedSnippets);
					});
				});
			});
		});
	});
} else {
	// Otherwise, generate the fixtures
	console.log("✨ Re-generating snippets");
	console.debug("  🚜 Removing existing fixtures...");
	await fs.rm(path.join(rootDirFinder(), "snippets-fixtures"), { recursive: true, force: true });

	console.debug("  🏭 Generating new fixtures...");
	TEST_CASES.forEach(({ testName, model, languages, opts }) => {
		console.debug(`    ${testName} (${languages.join(", ")})`);
		languages.forEach(async (language) => {
			const generatedSnippets = generateInferenceSnippet(model, language, opts);
			await saveExpectedInferenceSnippet(testName, language, generatedSnippets);
		});
	});
	console.log("✅ All done!");
	console.log("👉 Please check the generated fixtures before committing them.");
}
14 changes: 14 additions & 0 deletions packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/0.default.sh
@@ -0,0 +1,14 @@
curl 'https://api-inference.huggingface.co/models/meta-llama/Llama-3.1-8B-Instruct/v1/chat/completions' \
-H 'Authorization: Bearer api_token' \
-H 'Content-Type: application/json' \
--data '{
    "model": "meta-llama/Llama-3.1-8B-Instruct",
    "messages": [
        {
            "role": "user",
            "content": "What is the capital of France?"
        }
    ],
    "max_tokens": 500,
    "stream": false
}'
16 changes: 16 additions & 0 deletions packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/0.huggingface.js.js
@@ -0,0 +1,16 @@
import { HfInference } from "@huggingface/inference";

const client = new HfInference("api_token");

const chatCompletion = await client.chatCompletion({
	model: "meta-llama/Llama-3.1-8B-Instruct",
	messages: [
		{
			role: "user",
			content: "What is the capital of France?"
		}
	],
	max_tokens: 500
});

console.log(chatCompletion.choices[0].message);
18 changes: 18 additions & 0 deletions packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/0.huggingface_hub.py
@@ -0,0 +1,18 @@
from huggingface_hub import InferenceClient

client = InferenceClient(api_key="api_token")

messages = [
	{
		"role": "user",
		"content": "What is the capital of France?"
	}
]

completion = client.chat.completions.create(
	model="meta-llama/Llama-3.1-8B-Instruct",
	messages=messages,
	max_tokens=500
)

print(completion.choices[0].message)
19 changes: 19 additions & 0 deletions packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/1.openai.js
@@ -0,0 +1,19 @@
import { OpenAI } from "openai";

const client = new OpenAI({
	baseURL: "https://api-inference.huggingface.co/v1/",
	apiKey: "api_token"
});

const chatCompletion = await client.chat.completions.create({
	model: "meta-llama/Llama-3.1-8B-Instruct",
	messages: [
		{
			role: "user",
			content: "What is the capital of France?"
		}
	],
	max_tokens: 500
});

console.log(chatCompletion.choices[0].message);
21 changes: 21 additions & 0 deletions packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/1.openai.py
@@ -0,0 +1,21 @@
from openai import OpenAI

client = OpenAI(
	base_url="https://api-inference.huggingface.co/v1/",
	api_key="api_token"
)

messages = [
	{
		"role": "user",
		"content": "What is the capital of France?"
	}
]

completion = client.chat.completions.create(
	model="meta-llama/Llama-3.1-8B-Instruct",
	messages=messages,
	max_tokens=500
)

print(completion.choices[0].message)
