Skip to content

Feature/fluentapi #16

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 52 additions & 29 deletions src/main/java/tech/amikos/chromadb/Client.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,28 +19,56 @@ public class Client {

DefaultApi api;

public Client(String basePath) {
private Client(String basePath) {
apiClient.setBasePath(basePath);
api = new DefaultApi(apiClient);
}

public Collection getCollection(String collectionName, EmbeddingFunction embeddingFunction) throws ApiException {
return new Collection(api, collectionName, embeddingFunction).fetch();

public static enum DistanceFunction {
L2,
COSINE,
IP
}

public static Client newClient(String basePath){
return new Client(basePath);
}

public Boolean reset() throws ApiException {
return api.reset();
}

public String version() throws ApiException {
return api.version();
}

public Map<String, BigDecimal> heartbeat() throws ApiException {
return api.heartbeat();
}

public List<Collection> listCollections() throws ApiException {
List<LinkedTreeMap> apiResponse = (List<LinkedTreeMap>) api.listCollections();
return apiResponse.stream().map((LinkedTreeMap m) -> {
try {
return getCollection((String) m.get("name"), null);
} catch (ApiException e) {
e.printStackTrace(); //this is not great as we're swallowing the exception
}
return null;
}).collect(Collectors.toList());
}


public Collection newCollection(){
return new Collection(api,this,null,null);
}


public Collection createCollection(String collectionName, Map<String, String> metadata, Boolean createOrGet, EmbeddingFunction embeddingFunction) throws ApiException {
return this.createCollection(collectionName, metadata, createOrGet, embeddingFunction, DistanceFunction.L2);
}

public static enum DistanceFunction {
L2,
COSINE,
IP
}

public Collection createCollection(String collectionName, Map<String, String> metadata, Boolean createOrGet, EmbeddingFunction embeddingFunction, DistanceFunction distanceFunction) throws ApiException {
CreateCollection req = new CreateCollection();
Expand All @@ -58,35 +86,30 @@ public Collection createCollection(String collectionName, Map<String, String> me
return new Collection(api, (String) resp.get("name"), embeddingFunction).fetch();
}

public Collection deleteCollection(String collectionName) throws ApiException {
Collection collection = Collection.getInstance(api, collectionName);
api.deleteCollection(collectionName);
return collection;




public Collection getCollection(String collectionName, EmbeddingFunction embeddingFunction) throws ApiException {
return new Collection(api, collectionName, embeddingFunction).fetch();
}

public Collection upsert(String collectionName, EmbeddingFunction ef) throws ApiException {


public Collection upsertCollection(String collectionName, EmbeddingFunction ef) throws ApiException {
Collection collection = getCollection(collectionName, ef);
// collection.upsert();
return collection;
}

public Boolean reset() throws ApiException {
return api.reset();
}

public List<Collection> listCollections() throws ApiException {
List<LinkedTreeMap> apiResponse = (List<LinkedTreeMap>) api.listCollections();
return apiResponse.stream().map((LinkedTreeMap m) -> {
try {
return getCollection((String) m.get("name"), null);
} catch (ApiException e) {
e.printStackTrace(); //this is not great as we're swallowing the exception
}
return null;
}).collect(Collectors.toList());
public Collection deleteCollection(String collectionName) throws ApiException {
Collection collection = Collection.getInstance(api, collectionName);
api.deleteCollection(collectionName);
return collection;
}

public String version() throws ApiException {
return api.version();
}



}
143 changes: 97 additions & 46 deletions src/main/java/tech/amikos/chromadb/Collection.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
public class Collection {
static Gson gson = new Gson();
DefaultApi api;

Client client;
String collectionName;

String collectionId;
Expand All @@ -22,13 +24,29 @@ public class Collection {

private EmbeddingFunction embeddingFunction;

public Collection(DefaultApi api, String collectionName, EmbeddingFunction embeddingFunction) {

public Collection(DefaultApi api,String collectionName, EmbeddingFunction embeddingFunction) {
this.api = api;
this.collectionName = collectionName;
this.embeddingFunction = embeddingFunction;

}

public Collection(DefaultApi api,Client client, String collectionName, EmbeddingFunction embeddingFunction) {
this.api = api;
this.client = client;
this.collectionName = collectionName;
this.embeddingFunction = embeddingFunction;

}



public Collection name(String collectionName){
this.collectionName = collectionName;
return this;
}

public String getName() {
return collectionName;
}
Expand All @@ -37,10 +55,27 @@ public String getId() {
return collectionId;
}

public Collection metadata(String key,String value){
metadata.put(key,value);
return this;
}
public Map<String, Object> getMetadata() {
return metadata;
}

public Collection ef(EmbeddingFunction embeddingFunction){
this.embeddingFunction = embeddingFunction;
return this;
}

public Collection createOrGet(){
return client.createCollection(this.collectionName,this.metadata,true,this.embeddingFunction);
}

public Collection create(){
return client.createCollection(this.collectionName,this.metadata,false,this.embeddingFunction);
}

public Collection fetch() throws ApiException {
try {
LinkedTreeMap<String, ?> resp = (LinkedTreeMap<String, ?>) api.getCollection(collectionName);
Expand All @@ -53,8 +88,34 @@ public Collection fetch() throws ApiException {
}
}

public Object update(){
return this.update(this.collectionName,this.metadata);
}
public Object update(String newName, Map<String, Object> newMetadata) throws ApiException {
UpdateCollection req = new UpdateCollection();
if (newName != null) {
req.setNewName(newName);
}
if (newMetadata != null && embeddingFunction != null) {
if (!newMetadata.containsKey("embedding_function")) {
newMetadata.put("embedding_function", embeddingFunction.getClass().getName());
}
req.setNewMetadata(newMetadata);
}
Object resp = api.updateCollection(req, this.collectionId);
this.collectionName = newName;
this.fetch(); //do we really need to fetch?
return resp;
}

public Collection remove(){
return client.deleteCollection(this.collectionName);
}



public static Collection getInstance(DefaultApi api, String collectionName) throws ApiException {
return new Collection(api, collectionName, null);
return new Collection(api,collectionName, null);
}

@Override
Expand All @@ -66,6 +127,29 @@ public String toString() {
'}';
}


public Embedding newEmbedding(){
return new Embedding(this);
}

public Object add(List<List<Float>> embeddings, List<Map<String, String>> metadatas, List<String> documents, List<String> ids) throws ApiException {
AddEmbedding req = new AddEmbedding();
List<List<Float>> _embeddings = embeddings;
if (_embeddings == null) {
_embeddings = this.embeddingFunction.createEmbedding(documents);
}
req.setEmbeddings((List<Object>) (Object) _embeddings);
req.setMetadatas((List<Map<String, Object>>) (Object) metadatas);
req.setDocuments(documents);
req.incrementIndex(true);
req.setIds(ids);
return api.add(req, this.collectionId);
}

public GetResult get() throws ApiException {
return this.get(null, null, null);
}

public GetResult get(List<String> ids, Map<String, String> where, Map<String, Object> whereDocument) throws ApiException {
GetEmbedding req = new GetEmbedding();
req.ids(ids).where(where).whereDocument(whereDocument);
Expand All @@ -74,13 +158,6 @@ public GetResult get(List<String> ids, Map<String, String> where, Map<String, Ob
return new Gson().fromJson(json, GetResult.class);
}

public GetResult get() throws ApiException {
return this.get(null, null, null);
}

public Object delete() throws ApiException {
return this.delete(null, null, null);
}

public Object upsert(List<List<Float>> embeddings, List<Map<String, String>> metadatas, List<String> documents, List<String> ids) throws ApiException {
AddEmbedding req = new AddEmbedding();
Expand All @@ -96,23 +173,21 @@ public Object upsert(List<List<Float>> embeddings, List<Map<String, String>> met
return api.upsert(req, this.collectionId);
}


public Object add(List<List<Float>> embeddings, List<Map<String, String>> metadatas, List<String> documents, List<String> ids) throws ApiException {
AddEmbedding req = new AddEmbedding();
public Object updateEmbeddings(List<List<Float>> embeddings, List<Map<String, String>> metadatas, List<String> documents, List<String> ids) throws ApiException {
UpdateEmbedding req = new UpdateEmbedding();
List<List<Float>> _embeddings = embeddings;
if (_embeddings == null) {
_embeddings = this.embeddingFunction.createEmbedding(documents);
}
req.setEmbeddings((List<Object>) (Object) _embeddings);
req.setMetadatas((List<Map<String, Object>>) (Object) metadatas);
req.setDocuments(documents);
req.incrementIndex(true);
req.setMetadatas((List<Object>) (Object) metadatas);
req.setIds(ids);
return api.add(req, this.collectionId);
return api.update(req, this.collectionId);
}

public Integer count() throws ApiException {
return api.count(this.collectionId);
public Object delete() throws ApiException {
return this.delete(null, null, null);
}

public Object delete(List<String> ids, Map<String, String> where, Map<String, Object> whereDocument) throws ApiException {
Expand Down Expand Up @@ -143,42 +218,18 @@ public Object deleteWhereDocuments(Map<String, Object> whereDocument) throws Api
return delete(null, null, whereDocument);
}

public Integer count() throws ApiException {
return api.count(this.collectionId);
}
@Deprecated
public Boolean createIndex() throws ApiException {
return (Boolean) api.createIndex(this.collectionId);
}

public Object update(String newName, Map<String, Object> newMetadata) throws ApiException {
UpdateCollection req = new UpdateCollection();
if (newName != null) {
req.setNewName(newName);
}
if (newMetadata != null && embeddingFunction != null) {
if (!newMetadata.containsKey("embedding_function")) {
newMetadata.put("embedding_function", embeddingFunction.getClass().getName());
}
req.setNewMetadata(newMetadata);
}
Object resp = api.updateCollection(req, this.collectionId);
this.collectionName = newName;
this.fetch(); //do we really need to fetch?
return resp;
public Query newQuery(){
return new Query();
}

public Object updateEmbeddings(List<List<Float>> embeddings, List<Map<String, String>> metadatas, List<String> documents, List<String> ids) throws ApiException {
UpdateEmbedding req = new UpdateEmbedding();
List<List<Float>> _embeddings = embeddings;
if (_embeddings == null) {
_embeddings = this.embeddingFunction.createEmbedding(documents);
}
req.setEmbeddings((List<Object>) (Object) _embeddings);
req.setDocuments(documents);
req.setMetadatas((List<Object>) (Object) metadatas);
req.setIds(ids);
return api.update(req, this.collectionId);
}


public QueryResponse query(List<String> queryTexts, Integer nResults, Map<String, String> where, Map<String, String> whereDocument, List<QueryEmbedding.IncludeEnum> include) throws ApiException {
QueryEmbedding body = new QueryEmbedding();
body.queryEmbeddings((List<Object>) (Object) this.embeddingFunction.createEmbedding(queryTexts));
Expand Down
Loading