Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement sparse index support #95

Merged
merged 9 commits into from
Feb 4, 2025
66 changes: 57 additions & 9 deletions pinecone/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,8 @@ func (c *Client) ListIndexes(ctx context.Context) ([]*Index, error) {
// - Dimension: (Required) The [dimensionality] of the vectors to be inserted in the Index.
// - Metric: (Required) The distance metric to be used for [similarity] search. You can use
// 'euclidean', 'cosine', or 'dotproduct'.
// - DeletionProtection: (Optional) determines whether [deletion protection] is "enabled" or "disabled" for the index.
// When "enabled", the index cannot be deleted. Defaults to "disabled".
// - Environment: (Required) The [cloud environment] where the Index will be hosted.
// - PodType: (Required) The [type of pod] to use for the [Index]. One of `s1`, `p1`, or `p2` appended with `.` and
// one of `x1`, `x2`, `x4`, or `x8`.
Expand All @@ -413,8 +415,6 @@ func (c *Client) ListIndexes(ctx context.Context) ([]*Index, error) {
// default, all metadata is indexed; when `metadata_config` is present,
// only specified metadata fields are indexed. These configurations are
// only valid for use with pod-based Indexes.
// - DeletionProtection: (Optional) determines whether [deletion protection] is "enabled" or "disabled" for the index.
// When "enabled", the index cannot be deleted. Defaults to "disabled".
// - Tags: (Optional) A map of tags to associate with the Index.
//
// To create a new pods-based Index, use the [Client.CreatePodIndex] method.
Expand Down Expand Up @@ -540,7 +540,7 @@ func (req CreatePodIndexRequest) TotalCount() int {
// fmt.Printf("Successfully created pod index: %s", idx.Name)
// }
func (c *Client) CreatePodIndex(ctx context.Context, in *CreatePodIndexRequest) (*Index, error) {
if in.Name == "" || in.Dimension == 0 || in.Metric == "" || in.Environment == "" || in.PodType == "" {
if in.Name == "" || in.Dimension <= 0 || in.Metric == "" || in.Environment == "" || in.PodType == "" {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refining the check on Dimension here: if it's zero or negative, there's no need to make a request at all.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Will `Dimension <= 0` return the error "fields Name, Dimension, Metric, Environment, and Podtype must be included in CreatePodIndexRequest" even though Dimension was included in the request? That message would be misleading for a negative value.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call. I can update the message to say "positive Dimension" so it's clearer.

return nil, fmt.Errorf("fields Name, Dimension, Metric, Environment, and Podtype must be included in CreatePodIndexRequest")
}

Expand All @@ -549,6 +549,7 @@ func (c *Client) CreatePodIndex(ctx context.Context, in *CreatePodIndexRequest)
pods := in.TotalCount()
replicas := in.ReplicaCount()
shards := in.ShardCount()
vectorType := "dense"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`sparse` cannot be selected for pod indexes, so we always default to `dense` on the pod creation path.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice


var tags *db_control.IndexTags
if in.Tags != nil {
Expand All @@ -561,6 +562,7 @@ func (c *Client) CreatePodIndex(ctx context.Context, in *CreatePodIndexRequest)
Metric: metric,
DeletionProtection: deletionProtection,
Tags: tags,
VectorType: &vectorType,
}

req.Spec = db_control.IndexSpec{
Expand Down Expand Up @@ -601,13 +603,15 @@ func (c *Client) CreatePodIndex(ctx context.Context, in *CreatePodIndexRequest)
// - Name: (Required) The name of the [Index]. Resource name must be 1-45 characters long,
// start and end with an alphanumeric character,
// and consist only of lower case alphanumeric characters or '-'.
// - Dimension: (Required) The [dimensionality] of the vectors to be inserted in the [Index].
// - Metric: (Required) The metric used to measure the [similarity] between vectors ('euclidean', 'cosine', or 'dotproduct').
// - DeletionProtection: (Optional) Determines whether [deletion protection] is "enabled" or "disabled" for the index.
// When "enabled", the index cannot be deleted. Defaults to "disabled".
// - Cloud: (Required) The public [cloud provider] where you would like your [Index] hosted.
// For serverless Indexes, you define only the cloud and region where the [Index] should be hosted.
// - Region: (Required) The [region] where you would like your [Index] to be created.
// - Dimension: (Optional) The [dimensionality] of the vectors to be inserted in the [Index].
// - VectorType: (Optional) The index vector type. You can use `dense` or `sparse`. If `dense`, the vector dimension must be specified.
// If `sparse`, the vector dimension should not be specified, and the Metric must be set to `dotproduct`. Defaults to `dense`.
// - Tags: (Optional) A map of tags to associate with the Index.
//
// To create a new Serverless Index, use the [Client.CreateServerlessIndex] method.
Expand Down Expand Up @@ -652,11 +656,12 @@ func (c *Client) CreatePodIndex(ctx context.Context, in *CreatePodIndexRequest)
// [deletion protection]: https://docs.pinecone.io/guides/indexes/prevent-index-deletion#enable-deletion-protection
type CreateServerlessIndexRequest struct {
Name string
Dimension int32
Metric IndexMetric
DeletionProtection DeletionProtection
Cloud Cloud
Region string
Dimension *int32
VectorType *string
Tags *IndexTags
}

Expand Down Expand Up @@ -701,8 +706,29 @@ type CreateServerlessIndexRequest struct {
// fmt.Printf("Successfully created serverless index: %s", idx.Name)
// }
func (c *Client) CreateServerlessIndex(ctx context.Context, in *CreateServerlessIndexRequest) (*Index, error) {
if in.Name == "" || in.Dimension == 0 || in.Metric == "" || in.Cloud == "" || in.Region == "" {
return nil, fmt.Errorf("fields Name, Dimension, Metric, Cloud, and Region must be included in CreateServerlessIndexRequest")
if in.Name == "" || in.Metric == "" || in.Cloud == "" || in.Region == "" {
return nil, fmt.Errorf("fields Name, Metric, Cloud, and Region must be included in CreateServerlessIndexRequest")
}

// default to "dense" if VectorType is not specified
vectorType := "dense"

// validate VectorType options
if in.VectorType != nil {
switch *in.VectorType {
case "sparse":
if in.Dimension != nil {
return nil, fmt.Errorf("dimension should not be specified when VectorType is 'sparse'")
} else if in.Metric != Dotproduct {
return nil, fmt.Errorf("metric should be 'dotproduct' when VectorType is 'sparse'")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

metric should be optional for sparse.

}
vectorType = "sparse"
case "dense":
vectorType = "dense"
}
}
if in.Dimension == nil && vectorType == "dense" {
return nil, fmt.Errorf("dimension should be specified when VectorType is 'dense'")
}

deletionProtection := pointerOrNil(db_control.DeletionProtection(in.DeletionProtection))
Expand All @@ -715,9 +741,10 @@ func (c *Client) CreateServerlessIndex(ctx context.Context, in *CreateServerless

req := db_control.CreateIndexRequest{
Name: in.Name,
Dimension: &in.Dimension,
Dimension: in.Dimension,
Metric: metric,
DeletionProtection: deletionProtection,
VectorType: &vectorType,
Spec: db_control.IndexSpec{
Serverless: &db_control.ServerlessSpec{
Cloud: db_control.ServerlessSpecCloud(in.Cloud),
Expand Down Expand Up @@ -1561,18 +1588,39 @@ func toIndex(idx *db_control.IndexModel) *Index {
Ready: idx.Status.Ready,
State: IndexStatusState(idx.Status.State),
}
var embed *IndexEmbed
if idx.Embed != nil {
var metric *IndexMetric
if idx.Embed.Metric != nil {
convertedMetric := IndexMetric(*idx.Embed.Metric)
metric = &convertedMetric
}

embed = &IndexEmbed{
Dimension: idx.Embed.Dimension,
FieldMap: idx.Embed.FieldMap,
Metric: metric,
Model: idx.Embed.Model,
ReadParameters: idx.Embed.ReadParameters,
VectorType: idx.Embed.VectorType,
WriteParameters: idx.Embed.WriteParameters,
}
}

tags := (*IndexTags)(idx.Tags)
deletionProtection := derefOrDefault(idx.DeletionProtection, "disabled")

return &Index{
Name: idx.Name,
Dimension: *idx.Dimension,
Host: idx.Host,
Metric: IndexMetric(idx.Metric),
VectorType: idx.VectorType,
DeletionProtection: DeletionProtection(deletionProtection),
Dimension: idx.Dimension,
Spec: spec,
Status: status,
Tags: tags,
Embed: embed,
}
}

Expand Down
140 changes: 113 additions & 27 deletions pinecone/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ func (ts *IntegrationTests) TestListIndexes() {
require.Greater(ts.T(), len(indexes), 0, "Expected at least one index to exist")
}

func (ts *IntegrationTests) TestCreatePodIndex() {
func (ts *IntegrationTests) TestCreatePodIndexDense() {
if ts.indexType == "serverless" {
ts.T().Skip("Skipping pod index tests for serverless suite")
}

name := uuid.New().String()

defer func(ts *IntegrationTests, name string) {
Expand All @@ -48,53 +52,78 @@ func (ts *IntegrationTests) TestCreatePodIndex() {
})
require.NoError(ts.T(), err)
require.Equal(ts.T(), name, idx.Name, "Index name does not match")
// create index should default to "dense" if no VectorType is specified
require.Equal(ts.T(), "dense", idx.VectorType, "Index vector type does not match")
}

func (ts *IntegrationTests) TestCreatePodIndexInvalidDimension() {
name := uuid.New().String()

_, err := ts.client.CreatePodIndex(context.Background(), &CreatePodIndexRequest{
Name: name,
Dimension: -1,
Metric: Cosine,
Environment: "us-east1-gcp",
PodType: "p1.x1",
})
require.Error(ts.T(), err)
require.Equal(ts.T(), reflect.TypeOf(err), reflect.TypeOf(&PineconeError{}), "Expected error to be of type PineconeError")
}
func (ts *IntegrationTests) TestCreateServerlessIndexDense() {
if ts.indexType == "pod" {
ts.T().Skip("Skipping serverless index tests for pod suite")
}

func (ts *IntegrationTests) TestCreateServerlessIndexInvalidDimension() {
name := uuid.New().String()
dimension := int32(10)

_, err := ts.client.CreateServerlessIndex(context.Background(), &CreateServerlessIndexRequest{
defer func(ts *IntegrationTests, name string) {
err := ts.deleteIndex(name)
require.NoError(ts.T(), err)
}(ts, name)

idx, err := ts.client.CreateServerlessIndex(context.Background(), &CreateServerlessIndexRequest{
Name: name,
Dimension: -1,
Dimension: &dimension,
Metric: Cosine,
Cloud: Aws,
Region: "us-west-2",
})
require.Error(ts.T(), err)
require.Equal(ts.T(), reflect.TypeOf(err), reflect.TypeOf(&PineconeError{}), "Expected error to be of type PineconeError")
require.NoError(ts.T(), err)
require.Equal(ts.T(), name, idx.Name, "Index name does not match")
// create index should default to "dense" if no VectorType is specified
require.Equal(ts.T(), "dense", idx.VectorType, "Index vector type does not match")
}

func (ts *IntegrationTests) TestCreateServerlessIndex() {
// TestCreateServerlessIndexSparse creates a serverless index with VectorType
// "sparse" and the Dotproduct metric, then checks that the returned index
// reports the requested name and vector type. The test is skipped when the
// suite is configured for pod indexes.
func (ts *IntegrationTests) TestCreateServerlessIndexSparse() {
	if ts.indexType == "pod" {
		ts.T().Skip("Skipping serverless index tests for pod suite")
	}

	indexName := uuid.New().String()
	sparse := "sparse"

	// Remove the index once the test finishes, whether or not it passed.
	defer func() {
		require.NoError(ts.T(), ts.deleteIndex(indexName))
	}()

	// No Dimension is supplied: it is left nil for the sparse request.
	req := &CreateServerlessIndexRequest{
		Name:       indexName,
		Metric:     Dotproduct,
		Cloud:      Aws,
		Region:     "us-west-2",
		VectorType: &sparse,
	}

	idx, err := ts.client.CreateServerlessIndex(context.Background(), req)
	require.NoError(ts.T(), err)
	require.Equal(ts.T(), indexName, idx.Name, "Index name does not match")
	require.Equal(ts.T(), sparse, idx.VectorType, "Index vector type does not match")
}

func (ts *IntegrationTests) TestCreateServerlessIndexInvalidDimension() {
if ts.indexType == "pod" {
ts.T().Skip("Skipping serverless index tests for pod suite")
}

name := uuid.New().String()
dimension := int32(-1)

_, err := ts.client.CreateServerlessIndex(context.Background(), &CreateServerlessIndexRequest{
Name: name,
Dimension: 10,
Dimension: &dimension,
Metric: Cosine,
Cloud: Aws,
Region: "us-west-2",
})
require.NoError(ts.T(), err)
require.Equal(ts.T(), name, idx.Name, "Index name does not match")
require.Error(ts.T(), err)
require.Equal(ts.T(), reflect.TypeOf(err), reflect.TypeOf(&PineconeError{}), "Expected error to be of type PineconeError")
}

func (ts *IntegrationTests) TestDescribeIndex() {
Expand Down Expand Up @@ -781,14 +810,71 @@ func TestCreatePodIndexMissingReqdFieldsUnit(t *testing.T) {
client := &Client{}
_, err := client.CreatePodIndex(context.Background(), &CreatePodIndexRequest{})
require.Error(t, err)
require.ErrorContainsf(t, err, "fields Name, Dimension, Metric, Environment, and Podtype must be included in CreatePodIndexRequest", err.Error()) //_, err := ts.client.CreatePodIndex(context.Background(), &CreatePodIndexRequest{})
require.ErrorContainsf(t, err, "fields Name, Dimension, Metric, Environment, and Podtype must be included in CreatePodIndexRequest", err.Error())
}

func TestCreateServerlessIndexMissingReqdFieldsUnit(t *testing.T) {
client := &Client{}
_, err := client.CreateServerlessIndex(context.Background(), &CreateServerlessIndexRequest{})
require.Error(t, err)
require.ErrorContainsf(t, err, "fields Name, Dimension, Metric, Cloud, and Region must be included in CreateServerlessIndexRequest", err.Error())
require.ErrorContainsf(t, err, "fields Name, Metric, Cloud, and Region must be included in CreateServerlessIndexRequest", err.Error())
}

func TestCreateServerlessIndexInvalidSparseDimensionUnit(t *testing.T) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The unit tests that follow all exercise the request validation in the serverless creation path.

vectorType := "sparse"
dimension := int32(1)
client := &Client{}
_, err := client.CreateServerlessIndex(context.Background(), &CreateServerlessIndexRequest{
Name: "test-invalid-dimension",
Metric: Dotproduct,
Cloud: "aws",
Region: "us-east-1",
Dimension: &dimension,
VectorType: &vectorType,
})
require.Error(t, err)
require.ErrorContains(t, err, "dimension should not be specified when VectorType is 'sparse'")
}

func TestCreateServerlessIndexInvalidSparseMetricUnit(t *testing.T) {
vectorType := "sparse"
client := &Client{}
_, err := client.CreateServerlessIndex(context.Background(), &CreateServerlessIndexRequest{
Name: "test-invalid-dimension",
Metric: Cosine,
Cloud: "aws",
Region: "us-east-1",
VectorType: &vectorType,
})
require.Error(t, err)
require.ErrorContains(t, err, "metric should be 'dotproduct' when VectorType is 'sparse'")
}

func TestCreateServerlessIndexInvalidDenseDimensionUnit(t *testing.T) {
vectorType := "dense"
client := &Client{}
_, err := client.CreateServerlessIndex(context.Background(), &CreateServerlessIndexRequest{
Name: "test-invalid-dimension",
Metric: Cosine,
Cloud: "aws",
Region: "us-east-1",
VectorType: &vectorType,
})
require.Error(t, err)
require.ErrorContains(t, err, "dimension should be specified when VectorType is 'dense'")
}

func TestCreatePodIndexInvalidDimensionUnit(t *testing.T) {
client := &Client{}
_, err := client.CreatePodIndex(context.Background(), &CreatePodIndexRequest{
Name: "test-invalid-dimension",
Dimension: -1,
Metric: Cosine,
Environment: "us-east1-gcp",
PodType: "p1.x1",
})
require.Error(t, err)
require.ErrorContains(t, err, "fields Name, Dimension, Metric, Environment, and Podtype must be included in CreatePodIndexRequest")
}

func TestCreateCollectionMissingReqdFieldsUnit(t *testing.T) {
Expand Down Expand Up @@ -1044,7 +1130,7 @@ func TestToIndexUnit(t *testing.T) {
},
expectedOutput: &Index{
Name: "testIndex",
Dimension: 128,
Dimension: &dimension,
Host: "test-host",
Metric: "cosine",
DeletionProtection: "disabled",
Expand Down Expand Up @@ -1092,7 +1178,7 @@ func TestToIndexUnit(t *testing.T) {
},
expectedOutput: &Index{
Name: "testIndex",
Dimension: 128,
Dimension: &dimension,
Host: "test-host",
Metric: "cosine",
DeletionProtection: "enabled",
Expand Down
Loading
Loading