Skip to content

Commit

Permalink
feat: oci genai (#1102)
Browse files Browse the repository at this point in the history
Signed-off-by: Anders Swanson <[email protected]>
Co-authored-by: Alex Jones <[email protected]>
  • Loading branch information
anders-swanson and AlexsJones authored May 16, 2024
1 parent eda5231 commit 047afd4
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 1 deletion.
3 changes: 3 additions & 0 deletions cmd/auth/add.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ var addCmd = &cobra.Command{
Temperature: temperature,
ProviderRegion: providerRegion,
ProviderId: providerId,
CompartmentId: compartmentId,
TopP: topP,
TopK: topK,
MaxTokens: maxTokens,
Expand Down Expand Up @@ -173,4 +174,6 @@ func init() {
addCmd.Flags().StringVarP(&providerRegion, "providerRegion", "r", "", "Provider Region name (only for amazonbedrock, googlevertexai backend)")
//add flag for vertexAI Project ID
addCmd.Flags().StringVarP(&providerId, "providerId", "i", "", "Provider specific ID for e.g. project (only for googlevertexai backend)")
//add flag for OCI Compartment ID
addCmd.Flags().StringVarP(&compartmentId, "compartmentId", "k", "", "Compartment ID for generative AI model (only for oci backend)")
}
1 change: 1 addition & 0 deletions cmd/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ var (
temperature float32
providerRegion string
providerId string
compartmentId string
topP float32
topK int32
maxTokens int
Expand Down
3 changes: 3 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ require (
github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.1
github.com/hupe1980/go-huggingface v0.0.15
github.com/olekukonko/tablewriter v0.0.5
github.com/oracle/oci-go-sdk/v65 v65.65.1
github.com/prometheus/prometheus v0.49.1
github.com/pterm/pterm v0.12.79
google.golang.org/api v0.172.0
Expand Down Expand Up @@ -70,6 +71,7 @@ require (
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/go-kit/log v0.2.1 // indirect
github.com/go-logfmt/logfmt v0.6.0 // indirect
github.com/gofrs/flock v0.8.1 // indirect
github.com/golang-jwt/jwt/v5 v5.2.1 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/google/gnostic-models v0.6.9-0.20230804172637-c7be7c783f49 // indirect
Expand All @@ -87,6 +89,7 @@ require (
github.com/prometheus/common/sigv4 v0.1.0 // indirect
github.com/sagikazarmark/locafero v0.4.0 // indirect
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
github.com/sony/gobreaker v0.5.0 // indirect
github.com/sourcegraph/conc v0.3.0 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
go.opencensus.io v0.24.0 // indirect
Expand Down
6 changes: 6 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -1545,6 +1545,8 @@ github.com/gobuffalo/packr/v2 v2.8.3/go.mod h1:0SahksCVcx4IMnigTjiFuyldmTrdTctXs
github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y=
github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8=
github.com/goccy/go-json v0.9.11/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/gofrs/flock v0.8.1 h1:+gYjHKf32LDeiEEFhQaotPbLuUXjY5ZqxKgXy7n59aw=
github.com/gofrs/flock v0.8.1/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU=
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
Expand Down Expand Up @@ -1934,6 +1936,8 @@ github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.0-rc5 h1:Ygwkfw9bpDvs+c9E34SdgGOj41dX/cbdlwvlWt0pnFI=
github.com/opencontainers/image-spec v1.1.0-rc5/go.mod h1:X4pATf0uXsnn3g5aiGIsVnJBR4mxhKzfwmvK/B2NTm8=
github.com/oracle/oci-go-sdk/v65 v65.65.1 h1:sv7uD844tJGa2Vc+2KaByoXQ0FllZDGV/2+9MdxN6nA=
github.com/oracle/oci-go-sdk/v65 v65.65.1/go.mod h1:IBEV9l1qBzUpo7zgGaRUhbB05BVfcDGYRFBCPlTcPp0=
github.com/ovh/go-ovh v1.4.3 h1:Gs3V823zwTFpzgGLZNI6ILS4rmxZgJwJCz54Er9LwD0=
github.com/ovh/go-ovh v1.4.3/go.mod h1:AkPXVtgwB6xlKblMjRKJJmjRp+ogrE7fz2lVgcQY8SY=
github.com/owenrumney/squealer v1.2.1 h1:4ryMMT59aaz8VMsqsD+FDkarADJz0F1dcq2fd0DRR+c=
Expand Down Expand Up @@ -2050,6 +2054,8 @@ github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/skeema/knownhosts v1.2.1 h1:SHWdIUa82uGZz+F+47k8SY4QhhI291cXCpopT1lK2AQ=
github.com/skeema/knownhosts v1.2.1/go.mod h1:xYbVRSPxqBZFrdmDyMmsOs+uX1UZC3nTN3ThzgDxUwo=
github.com/sony/gobreaker v0.5.0 h1:dRCvqm0P490vZPmy7ppEk2qCnCieBooFJ+YoXGYB+yg=
github.com/sony/gobreaker v0.5.0/go.mod h1:ZKptC7FHNvhBz7dN2LGjPVBz2sZJmc0/PkyDJOjmxWY=
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
Expand Down
10 changes: 9 additions & 1 deletion pkg/ai/iai.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ var (
&GoogleGenAIClient{},
&HuggingfaceClient{},
&GoogleVertexAIClient{},
&OCIGenAIClient{},
}
Backends = []string{
openAIClientName,
Expand All @@ -41,6 +42,7 @@ var (
noopAIClientName,
huggingfaceAIClientName,
googleVertexAIClientName,
ociClientName,
}
)

Expand Down Expand Up @@ -75,6 +77,7 @@ type IAIConfig interface {
GetTopK() int32
GetMaxTokens() int
GetProviderId() string
GetCompartmentId() string
}

func NewClient(provider string) IAI {
Expand Down Expand Up @@ -104,6 +107,7 @@ type AIProvider struct {
Temperature float32 `mapstructure:"temperature" yaml:"temperature,omitempty"`
ProviderRegion string `mapstructure:"providerregion" yaml:"providerregion,omitempty"`
ProviderId string `mapstructure:"providerid" yaml:"providerid,omitempty"`
CompartmentId string `mapstructure:"compartmentid" yaml:"compartmentid,omitempty"`
TopP float32 `mapstructure:"topp" yaml:"topp,omitempty"`
TopK int32 `mapstructure:"topk" yaml:"topk,omitempty"`
MaxTokens int `mapstructure:"maxtokens" yaml:"maxtokens,omitempty"`
Expand Down Expand Up @@ -156,7 +160,11 @@ func (p *AIProvider) GetProviderId() string {
return p.ProviderId
}

var passwordlessProviders = []string{"localai", "amazonsagemaker", "amazonbedrock", "googlevertexai"}
func (p *AIProvider) GetCompartmentId() string {
return p.CompartmentId
}

var passwordlessProviders = []string{"localai", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci"}

func NeedPassword(backend string) bool {
for _, b := range passwordlessProviders {
Expand Down
97 changes: 97 additions & 0 deletions pkg/ai/ocigenai.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
Copyright 2024 The K8sGPT Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package ai

import (
"context"
"errors"
"github.com/oracle/oci-go-sdk/v65/common"
"github.com/oracle/oci-go-sdk/v65/generativeaiinference"
"strings"
)

const ociClientName = "oci"

type OCIGenAIClient struct {
nopCloser

client *generativeaiinference.GenerativeAiInferenceClient
model string
compartmentId string
temperature float32
topP float32
maxTokens int
}

func (c *OCIGenAIClient) GetName() string {
return ociClientName
}

func (c *OCIGenAIClient) Configure(config IAIConfig) error {
config.GetEndpointName()
c.model = config.GetModel()
c.temperature = config.GetTemperature()
c.topP = config.GetTopP()
c.maxTokens = config.GetMaxTokens()
c.compartmentId = config.GetCompartmentId()
provider := common.DefaultConfigProvider()
client, err := generativeaiinference.NewGenerativeAiInferenceClientWithConfigurationProvider(provider)
if err != nil {
return err
}
c.client = &client
return nil
}

func (c *OCIGenAIClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
generateTextRequest := c.newGenerateTextRequest(prompt)
generateTextResponse, err := c.client.GenerateText(ctx, generateTextRequest)
if err != nil {
return "", err
}
return extractGeneratedText(generateTextResponse.InferenceResponse)
}

func (c *OCIGenAIClient) newGenerateTextRequest(prompt string) generativeaiinference.GenerateTextRequest {
temperatureF64 := float64(c.temperature)
topPF64 := float64(c.topP)
return generativeaiinference.GenerateTextRequest{
GenerateTextDetails: generativeaiinference.GenerateTextDetails{
CompartmentId: &c.compartmentId,
ServingMode: generativeaiinference.OnDemandServingMode{
ModelId: &c.model,
},
InferenceRequest: generativeaiinference.CohereLlmInferenceRequest{
Prompt: &prompt,
MaxTokens: &c.maxTokens,
Temperature: &temperatureF64,
TopP: &topPF64,
},
},
}
}

func extractGeneratedText(llmInferenceResponse generativeaiinference.LlmInferenceResponse) (string, error) {
response, ok := llmInferenceResponse.(generativeaiinference.CohereLlmInferenceResponse)
if !ok {
return "", errors.New("failed to extract generated text from backed response")
}
sb := strings.Builder{}
for _, text := range response.GeneratedTexts {
if text.Text != nil {
sb.WriteString(*text.Text)
}
}
return sb.String(), nil
}

0 comments on commit 047afd4

Please sign in to comment.