Skip to content

Commit

Permalink
feat: add data type generic
Browse files Browse the repository at this point in the history
  • Loading branch information
Debbl committed Dec 11, 2024
1 parent 4f953ef commit 5af24bb
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 4 deletions.
2 changes: 1 addition & 1 deletion examples/nextjs/src/app/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ export default function Home() {
const [input, setInput] = useState("I love transformers!");

const { data, isLoading, mutate } = useTransformers({
task: "sentiment-analysis",
task: "object-detection",
options: {
dtype: "q8",
},
Expand Down
15 changes: 12 additions & 3 deletions src/index.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
import { createBirpc } from "birpc";
import { useEffect, useRef, useState } from "react";
import type { PipelineType, ProgressInfo } from "@huggingface/transformers";
import type {
AllTasks,
PipelineType,
ProgressInfo,
} from "@huggingface/transformers";
import type { BirpcReturn } from "birpc";
import type { ClientFunctions, PipelineProps, ServerFunctions } from "./types";
import type {
ClientFunctions,
PipelineProps,
ServerFunctions,
UnwrapPromise,
} from "./types";

export function useTransformers<T extends PipelineType>({
task,
Expand All @@ -13,7 +22,7 @@ export function useTransformers<T extends PipelineType>({

const [isReady, setIsReady] = useState(false);
const [isLoading, setIsLoading] = useState(false);
const [data, setData] = useState<any>();
const [data, setData] = useState<UnwrapPromise<ReturnType<AllTasks[T]>>>();
const [progressInfo, setProgressInfo] = useState<ProgressInfo>();

const mutate = async (...data: any) => {
Expand Down
2 changes: 2 additions & 0 deletions src/types.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import type { PipelineType, ProgressInfo } from "@huggingface/transformers";

export type UnwrapPromise<T> = T extends Promise<infer U> ? U : T;

export interface PretrainedModelOptions {
progress_callback?: (progress: any) => void;
config?: any;
Expand Down

0 comments on commit 5af24bb

Please sign in to comment.