From 5af24bb760590eb7d24d875a4026109c81b80f5a Mon Sep 17 00:00:00 2001 From: Brendan Dash Date: Wed, 11 Dec 2024 23:10:15 +0800 Subject: [PATCH] feat: add data type generic --- examples/nextjs/src/app/page.tsx | 2 +- src/index.ts | 15 ++++++++++++--- src/types.ts | 2 ++ 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/examples/nextjs/src/app/page.tsx b/examples/nextjs/src/app/page.tsx index 77f052f..1143233 100644 --- a/examples/nextjs/src/app/page.tsx +++ b/examples/nextjs/src/app/page.tsx @@ -8,7 +8,7 @@ export default function Home() { const [input, setInput] = useState("I love transformers!"); const { data, isLoading, mutate } = useTransformers({ - task: "sentiment-analysis", + task: "object-detection", options: { dtype: "q8", }, diff --git a/src/index.ts b/src/index.ts index f814bf2..ee3a6d7 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,8 +1,17 @@ import { createBirpc } from "birpc"; import { useEffect, useRef, useState } from "react"; -import type { PipelineType, ProgressInfo } from "@huggingface/transformers"; +import type { + AllTasks, + PipelineType, + ProgressInfo, +} from "@huggingface/transformers"; import type { BirpcReturn } from "birpc"; -import type { ClientFunctions, PipelineProps, ServerFunctions } from "./types"; +import type { + ClientFunctions, + PipelineProps, + ServerFunctions, + UnwrapPromise, +} from "./types"; export function useTransformers({ task, @@ -13,7 +22,7 @@ export function useTransformers({ const [isReady, setIsReady] = useState(false); const [isLoading, setIsLoading] = useState(false); - const [data, setData] = useState(); + const [data, setData] = useState>>(); const [progressInfo, setProgressInfo] = useState(); const mutate = async (...data: any) => { diff --git a/src/types.ts b/src/types.ts index 06093f4..f1576b2 100644 --- a/src/types.ts +++ b/src/types.ts @@ -1,5 +1,7 @@ import type { PipelineType, ProgressInfo } from "@huggingface/transformers"; +export type UnwrapPromise = T extends Promise ? U : T; + export interface PretrainedModelOptions { progress_callback?: (progress: any) => void; config?: any;