Skip to content

Commit

Permalink
Merge pull request #1273 from weaviate/modules/add-generative-databricks
Browse files Browse the repository at this point in the history
Add `generative-databricks` support in `Configure` factory
  • Loading branch information
tsmith023 committed Sep 2, 2024
2 parents a4b3214 + b206dcd commit 9a53a32
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 0 deletions.
26 changes: 26 additions & 0 deletions test/collection/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,32 @@ def test_config_with_vectorizer_and_properties(
}
},
),
(
Configure.Generative.databricks(endpoint="https://api.databricks.com"),
{
"generative-databricks": {
"endpoint": "https://api.databricks.com",
}
},
),
(
Configure.Generative.databricks(
endpoint="https://api.databricks.com",
max_tokens=100,
temperature=0.5,
top_k=10,
top_p=0.5,
),
{
"generative-databricks": {
"endpoint": "https://api.databricks.com",
"maxTokens": 100,
"temperature": 0.5,
"topK": 10,
"topP": 0.5,
}
},
),
]


Expand Down
90 changes: 90 additions & 0 deletions weaviate/collections/classes/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,10 @@ class GenerativeSearches(str, Enum):
Weaviate module backed by Anyscale generative models.
`COHERE`
Weaviate module backed by Cohere generative models.
`DATABRICKS`
Weaviate module backed by Databricks generative models.
`FRIENDLIAI`
Weaviate module backed by FriendliAI generative models.
`MISTRAL`
Weaviate module backed by Mistral generative models.
`OCTOAI`
Expand All @@ -178,6 +182,7 @@ class GenerativeSearches(str, Enum):
ANTHROPIC = "generative-anthropic"
ANYSCALE = "generative-anyscale"
COHERE = "generative-cohere"
DATABRICKS = "generative-databricks"
FRIENDLIAI = "generative-friendliai"
MISTRAL = "generative-mistral"
OCTOAI = "generative-octoai"
Expand Down Expand Up @@ -426,6 +431,17 @@ def _to_dict(self) -> Dict[str, Any]:
return self.module_config


class _GenerativeDatabricks(_GenerativeConfigCreate):
    # Config-create model for the `generative-databricks` module.
    # `generative` pins the module identifier; it is frozen and excluded from
    # serialization so only the user-settable fields end up in the module config.
    generative: Union[GenerativeSearches, _EnumLikeStr] = Field(
        default=GenerativeSearches.DATABRICKS, frozen=True, exclude=True
    )
    # Required URL of the Databricks endpoint the generation request is sent to.
    endpoint: str
    # Optional generation parameters. Field names are camelCase because they are
    # serialized verbatim as the server-side module-config keys (the test in
    # test_config.py expects "maxTokens", "topK", "topP").
    # NOTE(review): these Optional fields carry no explicit `= None` default —
    # under pydantic v2 that makes them required at construction time. The
    # Configure.Generative.databricks factory always passes all of them, but
    # confirm this matches the convention used by the sibling config classes.
    maxTokens: Optional[int]
    temperature: Optional[float]
    topK: Optional[int]
    topP: Optional[float]


class _GenerativeOctoai(_GenerativeConfigCreate):
generative: Union[GenerativeSearches, _EnumLikeStr] = Field(
default=GenerativeSearches.OCTOAI, frozen=True, exclude=True
Expand Down Expand Up @@ -611,6 +627,14 @@ def anyscale(
model: Optional[str] = None,
temperature: Optional[float] = None,
) -> _GenerativeConfigCreate:
"""Create a `_GenerativeAnyscale` object for use when generating using the `generative-anyscale` module.
Arguments:
`model`
The model to use. Defaults to `None`, which uses the server-defined default
`temperature`
The temperature to use. Defaults to `None`, which uses the server-defined default
"""
return _GenerativeAnyscale(model=model, temperature=temperature)

@staticmethod
Expand All @@ -628,6 +652,37 @@ def custom(
"""
return _GenerativeCustom(generative=_EnumLikeStr(module_name), module_config=module_config)

@staticmethod
def databricks(
    *,
    endpoint: str,
    max_tokens: Optional[int] = None,
    temperature: Optional[float] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
) -> _GenerativeConfigCreate:
    """Create a `_GenerativeDatabricks` object for use when performing AI generation using the `generative-databricks` module.

    Arguments:
        `endpoint`
            The URL of the Databricks model-serving endpoint where the API request should go. Required.
        `max_tokens`
            The maximum number of tokens to generate. Defaults to `None`, which uses the server-defined default
        `temperature`
            The temperature to use. Defaults to `None`, which uses the server-defined default
        `top_k`
            The top K value to use. Defaults to `None`, which uses the server-defined default
        `top_p`
            The top P value to use. Defaults to `None`, which uses the server-defined default
    """
    # Map the Pythonic snake_case keyword arguments onto the camelCase field
    # names that the server-side module config expects (maxTokens, topK, topP).
    return _GenerativeDatabricks(
        endpoint=endpoint,
        maxTokens=max_tokens,
        temperature=temperature,
        topK=top_k,
        topP=top_p,
    )

@staticmethod
def friendliai(
*,
Expand All @@ -636,6 +691,19 @@ def friendliai(
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
) -> _GenerativeConfigCreate:
"""
Create a `_GenerativeFriendliai` object for use when performing AI generation using the `generative-friendliai` module.
Arguments:
`base_url`
The base URL where the API request should go. Defaults to `None`, which uses the server-defined default
`model`
The model to use. Defaults to `None`, which uses the server-defined default
`temperature`
The temperature to use. Defaults to `None`, which uses the server-defined default
`max_tokens`
The maximum number of tokens to generate. Defaults to `None`, which uses the server-defined default
"""
return _GenerativeFriendliai(
model=model, temperature=temperature, maxTokens=max_tokens, baseURL=base_url
)
Expand All @@ -646,6 +714,16 @@ def mistral(
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
) -> _GenerativeConfigCreate:
"""Create a `_GenerativeMistral` object for use when performing AI generation using the `generative-mistral` module.
Arguments:
`model`
The model to use. Defaults to `None`, which uses the server-defined default
`temperature`
The temperature to use. Defaults to `None`, which uses the server-defined default
`max_tokens`
The maximum number of tokens to generate. Defaults to `None`, which uses the server-defined default
"""
return _GenerativeMistral(model=model, temperature=temperature, maxTokens=max_tokens)

@staticmethod
Expand All @@ -656,6 +734,18 @@ def octoai(
model: Optional[str] = None,
temperature: Optional[float] = None,
) -> _GenerativeConfigCreate:
"""Create a `_GenerativeOctoai` object for use when performing AI generation using the `generative-octoai` module.
Arguments:
`base_url`
The base URL where the API request should go. Defaults to `None`, which uses the server-defined default
`max_tokens`
The maximum number of tokens to generate. Defaults to `None`, which uses the server-defined default
`model`
The model to use. Defaults to `None`, which uses the server-defined default
`temperature`
The temperature to use. Defaults to `None`, which uses the server-defined default
"""
return _GenerativeOctoai(
baseURL=base_url, maxTokens=max_tokens, model=model, temperature=temperature
)
Expand Down

0 comments on commit 9a53a32

Please sign in to comment.