Skip to content

Commit

Permalink
feat: Support langchain transformer on fabric (microsoft#2036)
Browse files Browse the repository at this point in the history
* support langchain transformer on fabric

* avoid addtional param

* format code

---------

Co-authored-by: cruise <[email protected]>
  • Loading branch information
lhrotk and mslhrotk authored Aug 10, 2023
1 parent 149c634 commit 8f794c8
Showing 1 changed file with 18 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
)
from pyspark.sql.functions import udf
from typing import cast, Optional, TypeVar, Type
from synapse.ml.core.platform import running_on_synapse_internal

OPENAI_API_VERSION = "2022-12-01"
RL = TypeVar("RL", bound="MLReadable")
Expand Down Expand Up @@ -125,6 +126,14 @@ def __init__(
self.subscriptionKey = Param(self, "subscriptionKey", "openai api key")
self.url = Param(self, "url", "openai api base")
self.apiVersion = Param(self, "apiVersion", "openai api version")
self.running_on_synapse_internal = running_on_synapse_internal()
if running_on_synapse_internal():
from synapse.ml.fabric.service_discovery import get_fabric_env_config

self._setDefault(
url=get_fabric_env_config().fabric_env_config.ml_workload_endpoint
+ "cognitive/openai"
)
kwargs = self._input_kwargs
if subscriptionKey:
kwargs["subscriptionKey"] = subscriptionKey
Expand Down Expand Up @@ -196,10 +205,15 @@ def _transform(self, dataset):
def udfFunction(x):
import openai

openai.api_type = "azure"
openai.api_key = self.getSubscriptionKey()
openai.api_base = self.getUrl()
openai.api_version = self.getApiVersion()
if self.running_on_synapse_internal and not self.isSet(self.url):
from synapse.ml.fabric.prerun.openai_prerun import OpenAIPrerun

OpenAIPrerun(api_base=self.getUrl()).init_personalized_session(None)
else:
openai.api_type = "azure"
openai.api_key = self.getSubscriptionKey()
openai.api_base = self.getUrl()
openai.api_version = self.getApiVersion()
return self.getChain().run(x)

outCol = self.getOutputCol()
Expand Down

0 comments on commit 8f794c8

Please sign in to comment.