diff --git a/src/main/java/tech/amikos/chromadb/Client.java b/src/main/java/tech/amikos/chromadb/Client.java index da2f153..792d30d 100644 --- a/src/main/java/tech/amikos/chromadb/Client.java +++ b/src/main/java/tech/amikos/chromadb/Client.java @@ -19,28 +19,56 @@ public class Client { DefaultApi api; - public Client(String basePath) { + private Client(String basePath) { apiClient.setBasePath(basePath); api = new DefaultApi(apiClient); } - public Collection getCollection(String collectionName, EmbeddingFunction embeddingFunction) throws ApiException { - return new Collection(api, collectionName, embeddingFunction).fetch(); + + public static enum DistanceFunction { + L2, + COSINE, + IP + } + + public static Client newClient(String basePath){ + return new Client(basePath); + } + + public Boolean reset() throws ApiException { + return api.reset(); + } + + public String version() throws ApiException { + return api.version(); } public Map heartbeat() throws ApiException { return api.heartbeat(); } + public List listCollections() throws ApiException { + List apiResponse = (List) api.listCollections(); + return apiResponse.stream().map((LinkedTreeMap m) -> { + try { + return getCollection((String) m.get("name"), null); + } catch (ApiException e) { + e.printStackTrace(); //this is not great as we're swallowing the exception + } + return null; + }).collect(Collectors.toList()); + } + + + public Collection newCollection(){ + return new Collection(api,this,null,null); + } + + public Collection createCollection(String collectionName, Map metadata, Boolean createOrGet, EmbeddingFunction embeddingFunction) throws ApiException { return this.createCollection(collectionName, metadata, createOrGet, embeddingFunction, DistanceFunction.L2); } - public static enum DistanceFunction { - L2, - COSINE, - IP - } public Collection createCollection(String collectionName, Map metadata, Boolean createOrGet, EmbeddingFunction embeddingFunction, DistanceFunction distanceFunction) throws ApiException { CreateCollection req = new CreateCollection(); @@ -58,35 +86,30 @@ public Collection createCollection(String collectionName, Map me return new Collection(api, (String) resp.get("name"), embeddingFunction).fetch(); } - public Collection deleteCollection(String collectionName) throws ApiException { - Collection collection = Collection.getInstance(api, collectionName); - api.deleteCollection(collectionName); - return collection; + + + + + public Collection getCollection(String collectionName, EmbeddingFunction embeddingFunction) throws ApiException { + return new Collection(api, collectionName, embeddingFunction).fetch(); } - public Collection upsert(String collectionName, EmbeddingFunction ef) throws ApiException { + + + public Collection upsertCollection(String collectionName, EmbeddingFunction ef) throws ApiException { Collection collection = getCollection(collectionName, ef); // collection.upsert(); return collection; } - public Boolean reset() throws ApiException { - return api.reset(); - } - public List listCollections() throws ApiException { - List apiResponse = (List) api.listCollections(); - return apiResponse.stream().map((LinkedTreeMap m) -> { - try { - return getCollection((String) m.get("name"), null); - } catch (ApiException e) { - e.printStackTrace(); //this is not great as we're swallowing the exception - } - return null; - }).collect(Collectors.toList()); + public Collection deleteCollection(String collectionName) throws ApiException { + Collection collection = Collection.getInstance(api, collectionName); + api.deleteCollection(collectionName); + return collection; } - public String version() throws ApiException { - return api.version(); - } + + + } diff --git a/src/main/java/tech/amikos/chromadb/Collection.java b/src/main/java/tech/amikos/chromadb/Collection.java index 0fb76d5..75f9527 100644 --- a/src/main/java/tech/amikos/chromadb/Collection.java +++ b/src/main/java/tech/amikos/chromadb/Collection.java @@ -14,6 +14,8 @@ public class Collection { static Gson gson = new Gson(); DefaultApi api; + + Client client; String collectionName; String collectionId; @@ -22,13 +24,29 @@ public class Collection { private EmbeddingFunction embeddingFunction; - public Collection(DefaultApi api, String collectionName, EmbeddingFunction embeddingFunction) { + + public Collection(DefaultApi api,String collectionName, EmbeddingFunction embeddingFunction) { + this.api = api; + this.collectionName = collectionName; + this.embeddingFunction = embeddingFunction; + + } + + public Collection(DefaultApi api,Client client, String collectionName, EmbeddingFunction embeddingFunction) { this.api = api; + this.client = client; this.collectionName = collectionName; this.embeddingFunction = embeddingFunction; } + + + public Collection name(String collectionName){ + this.collectionName = collectionName; + return this; + } + public String getName() { return collectionName; } @@ -37,10 +55,27 @@ public String getId() { return collectionId; } + public Collection metadata(String key,String value){ + metadata.put(key,value); + return this; + } public Map getMetadata() { return metadata; } + public Collection ef(EmbeddingFunction embeddingFunction){ + this.embeddingFunction = embeddingFunction; + return this; + } + + public Collection createOrGet(){ + return client.createCollection(this.collectionName,this.metadata,true,this.embeddingFunction); + } + + public Collection create(){ + return client.createCollection(this.collectionName,this.metadata,false,this.embeddingFunction); + } + public Collection fetch() throws ApiException { try { LinkedTreeMap resp = (LinkedTreeMap) api.getCollection(collectionName); @@ -53,8 +88,34 @@ public Collection fetch() throws ApiException { } } + public Object update(){ + return this.update(this.collectionName,this.metadata); + } + public Object update(String newName, Map newMetadata) throws ApiException { + UpdateCollection req = new UpdateCollection(); + if (newName != null) { + req.setNewName(newName); + } + if (newMetadata != null && embeddingFunction != null) { + if (!newMetadata.containsKey("embedding_function")) { + newMetadata.put("embedding_function", embeddingFunction.getClass().getName()); + } + req.setNewMetadata(newMetadata); + } + Object resp = api.updateCollection(req, this.collectionId); + this.collectionName = newName; + this.fetch(); //do we really need to fetch? + return resp; + } + + public Collection remove(){ + return client.deleteCollection(this.collectionName); + } + + + public static Collection getInstance(DefaultApi api, String collectionName) throws ApiException { - return new Collection(api, collectionName, null); + return new Collection(api,collectionName, null); } @Override @@ -66,6 +127,29 @@ public String toString() { '}'; } + + public Embedding newEmbedding(){ + return new Embedding(this); + } + + public Object add(List> embeddings, List> metadatas, List documents, List ids) throws ApiException { + AddEmbedding req = new AddEmbedding(); + List> _embeddings = embeddings; + if (_embeddings == null) { + _embeddings = this.embeddingFunction.createEmbedding(documents); + } + req.setEmbeddings((List) (Object) _embeddings); + req.setMetadatas((List>) (Object) metadatas); + req.setDocuments(documents); + req.incrementIndex(true); + req.setIds(ids); + return api.add(req, this.collectionId); + } + + public GetResult get() throws ApiException { + return this.get(null, null, null); + } + public GetResult get(List ids, Map where, Map whereDocument) throws ApiException { GetEmbedding req = new GetEmbedding(); req.ids(ids).where(where).whereDocument(whereDocument); @@ -74,13 +158,6 @@ public GetResult get(List ids, Map where, Map> embeddings, List> metadatas, List documents, List ids) throws ApiException { AddEmbedding req = new AddEmbedding(); @@ -96,23 +173,21 @@ public Object upsert(List> embeddings, List> met return api.upsert(req, this.collectionId); } - - public Object add(List> embeddings, List> metadatas, List documents, List ids) throws ApiException { - AddEmbedding req = new AddEmbedding(); + public Object updateEmbeddings(List> embeddings, List> metadatas, List documents, List ids) throws ApiException { + UpdateEmbedding req = new UpdateEmbedding(); List> _embeddings = embeddings; if (_embeddings == null) { _embeddings = this.embeddingFunction.createEmbedding(documents); } req.setEmbeddings((List) (Object) _embeddings); - req.setMetadatas((List>) (Object) metadatas); req.setDocuments(documents); - req.incrementIndex(true); + req.setMetadatas((List) (Object) metadatas); req.setIds(ids); - return api.add(req, this.collectionId); + return api.update(req, this.collectionId); } - public Integer count() throws ApiException { - return api.count(this.collectionId); + public Object delete() throws ApiException { + return this.delete(null, null, null); } public Object delete(List ids, Map where, Map whereDocument) throws ApiException { @@ -143,42 +218,18 @@ public Object deleteWhereDocuments(Map whereDocument) throws Api return delete(null, null, whereDocument); } + public Integer count() throws ApiException { + return api.count(this.collectionId); + } @Deprecated public Boolean createIndex() throws ApiException { return (Boolean) api.createIndex(this.collectionId); } - public Object update(String newName, Map newMetadata) throws ApiException { - UpdateCollection req = new UpdateCollection(); - if (newName != null) { - req.setNewName(newName); - } - if (newMetadata != null && embeddingFunction != null) { - if (!newMetadata.containsKey("embedding_function")) { - newMetadata.put("embedding_function", embeddingFunction.getClass().getName()); - } - req.setNewMetadata(newMetadata); - } - Object resp = api.updateCollection(req, this.collectionId); - this.collectionName = newName; - this.fetch(); //do we really need to fetch? - return resp; + public Query newQuery(){ + return new Query(); } - public Object updateEmbeddings(List> embeddings, List> metadatas, List documents, List ids) throws ApiException { - UpdateEmbedding req = new UpdateEmbedding(); - List> _embeddings = embeddings; - if (_embeddings == null) { - _embeddings = this.embeddingFunction.createEmbedding(documents); - } - req.setEmbeddings((List) (Object) _embeddings); - req.setDocuments(documents); - req.setMetadatas((List) (Object) metadatas); - req.setIds(ids); - return api.update(req, this.collectionId); - } - - public QueryResponse query(List queryTexts, Integer nResults, Map where, Map whereDocument, List include) throws ApiException { QueryEmbedding body = new QueryEmbedding(); body.queryEmbeddings((List) (Object) this.embeddingFunction.createEmbedding(queryTexts)); diff --git a/src/main/java/tech/amikos/chromadb/Embedding.java b/src/main/java/tech/amikos/chromadb/Embedding.java new file mode 100644 index 0000000..d349927 --- /dev/null +++ b/src/main/java/tech/amikos/chromadb/Embedding.java @@ -0,0 +1,99 @@ +package tech.amikos.chromadb; + +import com.google.gson.annotations.SerializedName; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class Embedding { + + + private Collection collection; + private List> embeddings = null; + private List> metadatas = null; + private List documents = null; + private List ids = null; + private Map wheres; + private Map whereDocuments; + + public Embedding(Collection collection){ + this.collection = collection; + } + + + public Embedding id(String id){ + if(this.ids == null) + this.ids = new ArrayList<>(); + this.ids.add(id); + return this; + } + + public Embedding metadata(String key,String value){ + if(this.metadatas == null) + metadatas = new ArrayList<>(); + if(this.metadatas.get(0) == null) + metadatas.add(new HashMap<>()); + metadatas.get(metadatas.size()-1).put(key,value); + return this; + } + + public Embedding document(String document){ + if(this.documents == null) + this.documents = new ArrayList<>(); + documents.add(document); + return this; + } + + public Embedding embedding(Float embedding){ + if(this.embeddings == null) + this.embeddings = new ArrayList<>(); + if(this.embeddings.get(0) == null) + this.embeddings.add(new ArrayList<>()); + this.embeddings.get(embeddings.size()-1).add(embedding); + return this; + } + + public Object add(){ + return this.collection.add(embeddings,metadatas,documents,ids); + } + + public Object batchAdd(){ + return this.collection.add(embeddings,metadatas,documents,ids); + } + + public Embedding where(String key,String value){ + if(this.wheres == null) + wheres = new HashMap<>(); + wheres.put(key,value); + return this; + } + + public Embedding whereDocument(String key,Object value){ + if(this.whereDocuments == null) + whereDocuments = new HashMap<>(); + whereDocuments.put(key,value); + return this; + } + + public Collection.GetResult get(){ + if(ids == null) + return this.collection.get(); + return this.collection.get(ids,wheres,whereDocuments); + } + + public Object upsert(){ + return this.collection.upsert(embeddings,metadatas,documents,ids); + } + + public Object update(){ + return this.collection.updateEmbeddings(embeddings,metadatas,documents,ids); + } + + public Object delete(){ + return this.collection.delete(ids,wheres,whereDocuments); + } + + +} diff --git a/src/main/java/tech/amikos/chromadb/Query.java b/src/main/java/tech/amikos/chromadb/Query.java new file mode 100644 index 0000000..6ab5fa5 --- /dev/null +++ b/src/main/java/tech/amikos/chromadb/Query.java @@ -0,0 +1,78 @@ +package tech.amikos.chromadb; + +import com.google.gson.annotations.SerializedName; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class Query { + + private Collection collection; + private List queryTexts; + private Integer nResults; + private Map wheres; + private Map whereDocuments; + private List includes; + public Query(){ + } + + public Query queryText(String queryText){ + if(this.queryTexts == null) + this.queryTexts = new ArrayList<>(); + this.queryTexts.add(queryText); + return this; + } + + public Query nResults(Integer nResults){ + this.nResults = nResults; + return this; + } + + public Query where(String key,String value){ + if(this.wheres == null) + wheres = new HashMap<>(); + wheres.put(key,value); + return this; + } + + public Query whereDocument(String key,String value){ + if(this.whereDocuments == null) + whereDocuments = new HashMap<>(); + whereDocuments.put(key,value); + return this; + } + + public Query includeDocuments(){ + if(this.includes == null) + this.includes = new ArrayList(); + this.includes.add(QueryEmbedding.IncludeEnum.DOCUMENTS); + return this; + } + + public Query includeEmbeddings(){ + if(this.includes == null) + this.includes = new ArrayList(); + this.includes.add(QueryEmbedding.IncludeEnum.EMBEDDINGS); + return this; + } + + public Query includeMetadatas(){ + if(this.includes == null) + this.includes = new ArrayList(); + this.includes.add(QueryEmbedding.IncludeEnum.METADATAS); + return this; + } + + public Query includeDistances(){ + if(this.includes == null) + this.includes = new ArrayList(); + this.includes.add(QueryEmbedding.IncludeEnum.DISTANCES); + return this; + } + + public Collection.QueryResponse query(){ + return this.collection.query(queryTexts,nResults,wheres,whereDocuments,includes); + } +} diff --git a/src/main/resources/openapi/api.yaml b/src/main/resources/openapi/api.yaml index ef5e6fc..e48cf5f 100644 --- a/src/main/resources/openapi/api.yaml +++ b/src/main/resources/openapi/api.yaml @@ -63,7 +63,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/RawSql' + $ref: e'#/components/schemas/RawSql' required: true responses: '200':