From 41b745bc0408475938a796a61181a6f6d01b04b6 Mon Sep 17 00:00:00 2001 From: Hayden Date: Wed, 10 Jul 2024 18:14:47 +0700 Subject: [PATCH] test: retrieve, update, delete a conversation model --- typesense/api/client_gen.go | 8 +- typesense/api/generator/generator.yml | 14 +-- typesense/api/generator/openapi.yml | 14 +-- typesense/api/types_gen.go | 10 +-- typesense/conversation_model.go | 6 +- typesense/conversation_model_test.go | 122 ++++++++++++++++++++++++++ typesense/conversation_models.go | 6 +- typesense/conversation_models_test.go | 12 ++- typesense/conversation_test.go | 9 +- typesense/conversations.go | 1 + 10 files changed, 172 insertions(+), 30 deletions(-) create mode 100644 typesense/conversation_model_test.go diff --git a/typesense/api/client_gen.go b/typesense/api/client_gen.go index b0d94e5..d5a7b78 100644 --- a/typesense/api/client_gen.go +++ b/typesense/api/client_gen.go @@ -6027,7 +6027,7 @@ func (r RetrieveAllConversationModelsResponse) StatusCode() int { type CreateConversationModelResponse struct { Body []byte HTTPResponse *http.Response - JSON201 *ConversationModelCreateSchema + JSON201 *ConversationModelCreateAndUpdateSchema JSON400 *ApiResponse } @@ -6094,7 +6094,7 @@ func (r RetrieveConversationModelResponse) StatusCode() int { type UpdateConversationModelResponse struct { Body []byte HTTPResponse *http.Response - JSON200 *ConversationModelCreateSchema + JSON200 *ConversationModelCreateAndUpdateSchema } // Status returns HTTPResponse.Status @@ -8171,7 +8171,7 @@ func ParseCreateConversationModelResponse(rsp *http.Response) (*CreateConversati switch { case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 201: - var dest ConversationModelCreateSchema + var dest ConversationModelCreateAndUpdateSchema if err := json.Unmarshal(bodyBytes, &dest); err != nil { return nil, err } @@ -8256,7 +8256,7 @@ func ParseUpdateConversationModelResponse(rsp *http.Response) (*UpdateConversati switch { case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: - var dest ConversationModelCreateSchema + var dest ConversationModelCreateAndUpdateSchema if err := json.Unmarshal(bodyBytes, &dest); err != nil { return nil, err } diff --git a/typesense/api/generator/generator.yml b/typesense/api/generator/generator.yml index 4c248ff..90111c8 100644 --- a/typesense/api/generator/generator.yml +++ b/typesense/api/generator/generator.yml @@ -220,7 +220,7 @@ components: type: integer required: - id - ConversationModelCreateSchema: + ConversationModelCreateAndUpdateSchema: properties: account_id: description: LLM service's account ID (only applicable for Cloudflare) @@ -250,10 +250,12 @@ components: properties: id: type: string + required: + - id type: object ConversationModelSchema: allOf: - - $ref: '#/components/schemas/ConversationModelCreateSchema' + - $ref: '#/components/schemas/ConversationModelCreateAndUpdateSchema' - properties: id: type: string @@ -2518,14 +2520,14 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/ConversationModelCreateSchema' + $ref: '#/components/schemas/ConversationModelCreateAndUpdateSchema' required: true responses: 201: content: application/json: schema: - $ref: '#/components/schemas/ConversationModelCreateSchema' + $ref: '#/components/schemas/ConversationModelCreateAndUpdateSchema' description: Created Conversation Model 400: content: @@ -2590,14 +2592,14 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/ConversationModelCreateSchema' + $ref: '#/components/schemas/ConversationModelCreateAndUpdateSchema' required: true responses: 200: content: application/json: schema: - $ref: '#/components/schemas/ConversationModelCreateSchema' + $ref: '#/components/schemas/ConversationModelCreateAndUpdateSchema' description: The conversation model was successfully updated summary: Update a conversation model tags: diff --git a/typesense/api/generator/openapi.yml b/typesense/api/generator/openapi.yml index 707bc42..5daf10a 100644 --- a/typesense/api/generator/openapi.yml +++ b/typesense/api/generator/openapi.yml @@ -990,14 +990,14 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/ConversationModelCreateSchema' + $ref: '#/components/schemas/ConversationModelCreateAndUpdateSchema' required: true responses: 201: content: application/json: schema: - $ref: '#/components/schemas/ConversationModelCreateSchema' + $ref: '#/components/schemas/ConversationModelCreateAndUpdateSchema' description: Created Conversation Model 400: content: @@ -1035,7 +1035,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/ConversationModelCreateSchema' + $ref: '#/components/schemas/ConversationModelCreateAndUpdateSchema' required: true parameters: - name: modelId @@ -1049,7 +1049,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/ConversationModelCreateSchema' + $ref: '#/components/schemas/ConversationModelCreateAndUpdateSchema' description: The conversation model was successfully updated summary: Update a conversation model tags: @@ -3029,7 +3029,7 @@ components: items: $ref: '#/components/schemas/ConversationSchema' x-go-type: '[]*ConversationSchema' - ConversationModelCreateSchema: + ConversationModelCreateAndUpdateSchema: properties: model_name: description: Name of the LLM model offered by OpenAI, Cloudflare or vLLM @@ -3059,10 +3059,12 @@ components: properties: id: type: string + required: + - id type: object ConversationModelSchema: allOf: - - $ref: '#/components/schemas/ConversationModelCreateSchema' + - $ref: '#/components/schemas/ConversationModelCreateAndUpdateSchema' - properties: id: type: string diff --git a/typesense/api/types_gen.go b/typesense/api/types_gen.go index c614a70..5cea891 100644 --- a/typesense/api/types_gen.go +++ b/typesense/api/types_gen.go @@ -158,8 +158,8 @@ type ConversationDeleteSchema struct { Id int `json:"id"` } -// ConversationModelCreateSchema defines model for ConversationModelCreateSchema. -type ConversationModelCreateSchema struct { +// ConversationModelCreateAndUpdateSchema defines model for ConversationModelCreateAndUpdateSchema. +type ConversationModelCreateAndUpdateSchema struct { // AccountId LLM service's account ID (only applicable for Cloudflare) AccountId *string `json:"account_id,omitempty"` @@ -181,7 +181,7 @@ type ConversationModelCreateSchema struct { // ConversationModelDeleteSchema defines model for ConversationModelDeleteSchema. type ConversationModelDeleteSchema struct { - Id *string `json:"id,omitempty"` + Id string `json:"id"` } // ConversationModelSchema defines model for ConversationModelSchema. @@ -1001,10 +1001,10 @@ type UpsertSearchOverrideJSONRequestBody = SearchOverrideSchema type UpsertSearchSynonymJSONRequestBody = SearchSynonymSchema // CreateConversationModelJSONRequestBody defines body for CreateConversationModel for application/json ContentType. -type CreateConversationModelJSONRequestBody = ConversationModelCreateSchema +type CreateConversationModelJSONRequestBody = ConversationModelCreateAndUpdateSchema // UpdateConversationModelJSONRequestBody defines body for UpdateConversationModel for application/json ContentType. -type UpdateConversationModelJSONRequestBody = ConversationModelCreateSchema +type UpdateConversationModelJSONRequestBody = ConversationModelCreateAndUpdateSchema // UpdateConversationJSONRequestBody defines body for UpdateConversation for application/json ContentType. type UpdateConversationJSONRequestBody = ConversationUpdateSchema diff --git a/typesense/conversation_model.go b/typesense/conversation_model.go index 855f20e..b32a434 100644 --- a/typesense/conversation_model.go +++ b/typesense/conversation_model.go @@ -8,7 +8,7 @@ import ( type ConversationModelInterface interface { Retrieve(ctx context.Context) (*api.ConversationModelSchema, error) - Update(ctx context.Context, model *api.ConversationModelCreateSchema) (*api.ConversationModelCreateSchema, error) + Update(ctx context.Context, model *api.ConversationModelCreateAndUpdateSchema) (*api.ConversationModelCreateAndUpdateSchema, error) Delete(ctx context.Context) (*api.ConversationModelDeleteSchema, error) } @@ -28,8 +28,8 @@ func (c *conversationModel) Retrieve(ctx context.Context) (*api.ConversationMode return response.JSON200, nil } -func (c *conversationModel) Update(ctx context.Context, model *api.ConversationModelCreateSchema) (*api.ConversationModelCreateSchema, error) { - response, err := c.apiClient.UpdateConversationModelWithResponse(ctx, c.modelId, api.UpdateConversationModelJSONRequestBody(*model)) +func (c *conversationModel) Update(ctx context.Context, conversationModelCreateAndUpdateSchema *api.ConversationModelCreateAndUpdateSchema) (*api.ConversationModelCreateAndUpdateSchema, error) { + response, err := c.apiClient.UpdateConversationModelWithResponse(ctx, c.modelId, api.UpdateConversationModelJSONRequestBody(*conversationModelCreateAndUpdateSchema)) if err != nil { return nil, err } diff --git a/typesense/conversation_model_test.go b/typesense/conversation_model_test.go new file mode 100644 index 0000000..9718db1 --- /dev/null +++ b/typesense/conversation_model_test.go @@ -0,0 +1,122 @@ +package typesense + +import ( + "context" + "encoding/json" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/typesense/typesense-go/typesense/api" +) + +func TestConversationModelRetrieve(t *testing.T) { + accountId, systemPrompt := "CLOUDFLARE_ACCOUNT_ID", "..." + expectedData := &api.ConversationModelSchema{ + Id: "123", + ModelName: "cf/mistral/mistral-7b-instruct-v0.1", + ApiKey: "CLOUDFLARE_API_KEY", + AccountId: &accountId, + SystemPrompt: &systemPrompt, + MaxBytes: 16384, + } + + server, client := newTestServerAndClient(func(w http.ResponseWriter, r *http.Request) { + validateRequestMetadata(t, r, "/conversations/models/123", http.MethodGet) + data := jsonEncode(t, expectedData) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(data) + }) + defer server.Close() + + res, err := client.Conversations().Model("123").Retrieve(context.Background()) + assert.NoError(t, err) + assert.Equal(t, expectedData, res) +} + +func TestConversationModelRetrieveOnHttpStatusErrorCodeReturnsError(t *testing.T) { + server, client := newTestServerAndClient(func(w http.ResponseWriter, r *http.Request) { + validateRequestMetadata(t, r, "/conversations/models/123", http.MethodGet) + w.WriteHeader(http.StatusNotFound) + }) + defer server.Close() + + _, err := client.Conversations().Model("123").Retrieve(context.Background()) + assert.Error(t, err) +} + +func TestConversationModelUpdate(t *testing.T) { + accountId, systemPrompt := "CLOUDFLARE_ACCOUNT_ID", "..." + expectedData := &api.ConversationModelCreateAndUpdateSchema{ + ModelName: "cf/mistral/mistral-7b-instruct-v0.1", + ApiKey: "CLOUDFLARE_API_KEY", + AccountId: &accountId, + SystemPrompt: &systemPrompt, + MaxBytes: 16384, + } + + server, client := newTestServerAndClient(func(w http.ResponseWriter, r *http.Request) { + validateRequestMetadata(t, r, "/conversations/models/123", http.MethodPut) + + var reqBody api.ConversationModelCreateAndUpdateSchema + err := json.NewDecoder(r.Body).Decode(&reqBody) + + assert.NoError(t, err) + assert.Equal(t, expectedData, &reqBody) + + data := jsonEncode(t, expectedData) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(data) + }) + defer server.Close() + + res, err := client.Conversations().Model("123").Update(context.Background(), expectedData) + assert.NoError(t, err) + assert.Equal(t, expectedData, res) +} + +func TestConversationModelUpdateOnHttpStatusErrorCodeReturnsError(t *testing.T) { + server, client := newTestServerAndClient(func(w http.ResponseWriter, r *http.Request) { + validateRequestMetadata(t, r, "/conversations/models/123", http.MethodPut) + w.WriteHeader(http.StatusConflict) + }) + defer server.Close() + + _, err := client.Conversations().Model("123").Update(context.Background(), &api.ConversationModelCreateAndUpdateSchema{ + ModelName: "cf/mistral/mistral-7b-instruct-v0.1", + }) + assert.Error(t, err) +} + +func TestConversationModelDelete(t *testing.T) { + expectedData := &api.ConversationModelDeleteSchema{ + Id: "123", + } + + server, client := newTestServerAndClient(func(w http.ResponseWriter, r *http.Request) { + validateRequestMetadata(t, r, "/conversations/models/123", http.MethodDelete) + + data := jsonEncode(t, expectedData) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(data) + }) + defer server.Close() + + res, err := client.Conversations().Model("123").Delete(context.Background()) + assert.NoError(t, err) + assert.Equal(t, expectedData, res) +} + +func TestConversationModelDeleteOnHttpStatusErrorCodeReturnsError(t *testing.T) { + server, client := newTestServerAndClient(func(w http.ResponseWriter, r *http.Request) { + validateRequestMetadata(t, r, "/conversations/models/123", http.MethodDelete) + w.WriteHeader(http.StatusConflict) + }) + defer server.Close() + + _, err := client.Conversations().Model("123").Delete(context.Background()) + assert.Error(t, err) +} diff --git a/typesense/conversation_models.go b/typesense/conversation_models.go index 010cb32..c7b3596 100644 --- a/typesense/conversation_models.go +++ b/typesense/conversation_models.go @@ -8,7 +8,7 @@ import ( // ConversationModelsInterface is a type for ConversationModels API operations type ConversationModelsInterface interface { - Create(ctx context.Context, conversationModelCreateSchema *api.ConversationModelCreateSchema) (*api.ConversationModelCreateSchema, error) + Create(ctx context.Context, conversationModelCreateAndUpdateSchema *api.ConversationModelCreateAndUpdateSchema) (*api.ConversationModelCreateAndUpdateSchema, error) Retrieve(ctx context.Context) ([]*api.ConversationModelSchema, error) } @@ -17,8 +17,8 @@ type conversationModels struct { apiClient APIClientInterface } -func (c *conversationModels) Create(ctx context.Context, conversationModelCreateSchema *api.ConversationModelCreateSchema) (*api.ConversationModelCreateSchema, error) { - response, err := c.apiClient.CreateConversationModelWithResponse(ctx, api.CreateConversationModelJSONRequestBody(*conversationModelCreateSchema)) +func (c *conversationModels) Create(ctx context.Context, conversationModelCreateAndUpdateSchema *api.ConversationModelCreateAndUpdateSchema) (*api.ConversationModelCreateAndUpdateSchema, error) { + response, err := c.apiClient.CreateConversationModelWithResponse(ctx, api.CreateConversationModelJSONRequestBody(*conversationModelCreateAndUpdateSchema)) if err != nil { return nil, err } diff --git a/typesense/conversation_models_test.go b/typesense/conversation_models_test.go index 4fd008e..6420cc1 100644 --- a/typesense/conversation_models_test.go +++ b/typesense/conversation_models_test.go @@ -2,6 +2,7 @@ package typesense import ( "context" + "encoding/json" "net/http" "testing" @@ -47,7 +48,7 @@ func TestConversationModelsRetrieveOnHttpStatusErrorCodeReturnsError(t *testing. func TestConversationModelsCreate(t *testing.T) { accountId, systemPrompt := "CLOUDFLARE_ACCOUNT_ID", "..." - expectedData := &api.ConversationModelCreateSchema{ + expectedData := &api.ConversationModelCreateAndUpdateSchema{ ModelName: "cf/mistral/mistral-7b-instruct-v0.1", ApiKey: "CLOUDFLARE_API_KEY", AccountId: &accountId, @@ -57,6 +58,13 @@ func TestConversationModelsCreate(t *testing.T) { server, client := newTestServerAndClient(func(w http.ResponseWriter, r *http.Request) { validateRequestMetadata(t, r, "/conversations/models", http.MethodPost) + + var reqBody api.ConversationModelCreateAndUpdateSchema + err := json.NewDecoder(r.Body).Decode(&reqBody) + + assert.NoError(t, err) + assert.Equal(t, expectedData, &reqBody) + data := jsonEncode(t, expectedData) w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusCreated) @@ -76,7 +84,7 @@ func TestConversationModelsCreateOnHttpStatusErrorCodeReturnsError(t *testing.T) }) defer server.Close() - _, err := client.Conversations().Models().Create(context.Background(), &api.ConversationModelCreateSchema{ + _, err := client.Conversations().Models().Create(context.Background(), &api.ConversationModelCreateAndUpdateSchema{ ModelName: "cf/mistral/mistral-7b-instruct-v0.1", }) assert.Error(t, err) diff --git a/typesense/conversation_test.go b/typesense/conversation_test.go index e5fc6b8..3f24f25 100644 --- a/typesense/conversation_test.go +++ b/typesense/conversation_test.go @@ -2,6 +2,7 @@ package typesense import ( "context" + "encoding/json" "net/http" "testing" @@ -51,8 +52,14 @@ func TestConversationUpdateConversation(t *testing.T) { server, client := newTestServerAndClient(func(w http.ResponseWriter, r *http.Request) { validateRequestMetadata(t, r, "/conversations/123", http.MethodPut) - data := jsonEncode(t, expectedData) + var reqBody api.ConversationUpdateSchema + err := json.NewDecoder(r.Body).Decode(&reqBody) + + assert.NoError(t, err) + assert.Equal(t, expectedData, &reqBody) + + data := jsonEncode(t, expectedData) w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) w.Write(data) diff --git a/typesense/conversations.go b/typesense/conversations.go index cf13eac..20dcb1e 100644 --- a/typesense/conversations.go +++ b/typesense/conversations.go @@ -10,6 +10,7 @@ import ( type ConversationsInterface interface { Retrieve(ctx context.Context) ([]*api.ConversationSchema, error) Models() ConversationModelsInterface + Model(modelId string) ConversationModelInterface } // conversations is internal implementation of ConversationsInterface