17
17
public class HuggingFaceEmbeddingFunction implements EmbeddingFunction {
18
18
public static final String DEFAULT_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2" ;
19
19
public static final String DEFAULT_BASE_API = "https://api-inference.huggingface.co/pipeline/feature-extraction/" ;
20
+ public static final String HFEI_API_PATH = "/embed" ;
20
21
public static final String HF_API_KEY_ENV = "HF_API_KEY" ;
22
+ public static final String API_TYPE_CONFIG_KEY = "apiType" ;
21
23
private final OkHttpClient client = new OkHttpClient ();
22
24
private final Map <String , Object > configParams = new HashMap <>();
23
25
private static final Gson gson = new Gson ();
24
26
25
27
private static final List <WithParam > defaults = Arrays .asList (
28
+ new WithAPIType (APIType .HF_API ),
26
29
WithParam .baseAPI (DEFAULT_BASE_API ),
27
30
WithParam .defaultModel (DEFAULT_MODEL_NAME )
28
31
);
@@ -46,14 +49,21 @@ public HuggingFaceEmbeddingFunction(WithParam... params) throws EFException {
46
49
}
47
50
48
51
public CreateEmbeddingResponse createEmbedding (CreateEmbeddingRequest req ) throws EFException {
49
- Request request = new Request .Builder ()
50
- . url ( this . configParams . get ( Constants . EF_PARAMS_BASE_API ). toString () + this . configParams . get ( Constants . EF_PARAMS_MODEL ). toString ())
52
+ Request . Builder rb = new Request .Builder ()
53
+
51
54
.post (RequestBody .create (req .json (), JSON ))
52
55
.addHeader ("Accept" , "application/json" )
53
56
.addHeader ("Content-Type" , "application/json" )
54
- .addHeader ("User-Agent" , Constants .HTTP_AGENT )
55
- .addHeader ("Authorization" , "Bearer " + configParams .get (Constants .EF_PARAMS_API_KEY ).toString ())
56
- .build ();
57
+ .addHeader ("User-Agent" , Constants .HTTP_AGENT );
58
+ if (configParams .containsKey (API_TYPE_CONFIG_KEY ) && configParams .get (API_TYPE_CONFIG_KEY ).equals (APIType .HFEI_API )) {
59
+ rb .url (this .configParams .get (Constants .EF_PARAMS_BASE_API ).toString () + HFEI_API_PATH );
60
+ } else {
61
+ rb .url (this .configParams .get (Constants .EF_PARAMS_BASE_API ).toString () + this .configParams .get (Constants .EF_PARAMS_MODEL ).toString ());
62
+ }
63
+ if (configParams .containsKey (Constants .EF_PARAMS_API_KEY )) {
64
+ rb .addHeader ("Authorization" , "Bearer " + configParams .get (Constants .EF_PARAMS_API_KEY ).toString ());
65
+ }
66
+ Request request = rb .build ();
57
67
try (Response response = client .newCall (request ).execute ()) {
58
68
if (!response .isSuccessful ()) {
59
69
throw new IOException ("Unexpected code " + response );
@@ -86,4 +96,22 @@ public List<Embedding> embedDocuments(String[] documents) throws EFException {
86
96
CreateEmbeddingResponse response = this .createEmbedding (new CreateEmbeddingRequest ().inputs (documents ));
87
97
return response .getEmbeddings ().stream ().map (Embedding ::fromList ).collect (Collectors .toList ());
88
98
}
99
+
100
+ public static class WithAPIType extends WithParam {
101
+ private final APIType apiType ;
102
+
103
+ public WithAPIType (APIType apitype ) {
104
+ this .apiType = apitype ;
105
+ }
106
+
107
+ @ Override
108
+ public void apply (Map <String , Object > params ) {
109
+ params .put (API_TYPE_CONFIG_KEY , apiType );
110
+ }
111
+ }
112
+
113
+ public static enum APIType {
114
+ HF_API ,
115
+ HFEI_API
116
+ }
89
117
}
0 commit comments