Skip to content

Commit

Permalink
Merge pull request #12 from Arize-ai/arize-dev/azure-china
Browse files Browse the repository at this point in the history
Ability to set the azure base domain. Using env variable to match style used in this file.
  • Loading branch information
ddowker authored Aug 9, 2024
2 parents 056ebd7 + b099ea1 commit 72c2838
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion broker/fragment/store_azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,15 @@ type azureStoreConfig struct {
accountTenantID string // The tenant ID that owns the storage account that we're writing into
// NOTE: This is not the tenant ID that owns the servie principal
storageAccountName string // Storage accounts in Azure are the equivalent to a "bucket" in S3
blobDomain string // base storage domain for azure cloud (e.g. "blob.core.windows.net")
containerName string // In azure, blobs are stored inside of containers, which live inside accounts
prefix string // This is the path prefix for the blobs inside the container

RewriterConfig
}

func (cfg *azureStoreConfig) serviceUrl() string {
return fmt.Sprintf("https://%s.blob.core.windows.net", cfg.storageAccountName)
return fmt.Sprintf("https://%s.%s", cfg.storageAccountName, cfg.blobDomain)
}

func (cfg *azureStoreConfig) containerURL() string {
Expand Down Expand Up @@ -122,6 +123,7 @@ func (a *azureBackend) SignGet(endpoint *url.URL, fragment pb.Fragment, d time.D
log.WithFields(log.Fields{
"tenantId": cfg.accountTenantID,
"storageAccountName": cfg.storageAccountName,
"blobDomain": cfg.blobDomain,
"containerName": cfg.containerName,
"blobName": blobName,
"expiryTime": sasQueryParams.ExpiryTime(),
Expand Down Expand Up @@ -216,12 +218,14 @@ func (a *azureBackend) List(ctx context.Context, store pb.FragmentStore, ep *url
} else if frag, err := pb.ParseFragmentFromRelativePath(journal, blob.Name[len(*segmentList.Prefix):]); err != nil {
log.WithFields(log.Fields{
"storageAccountName": cfg.storageAccountName,
"blobDomain": cfg.blobDomain,
"name": blob.Name,
"err": err,
}).Warning("parsing fragment")
} else if *(blob.Properties.ContentLength) == 0 && frag.ContentLength() > 0 {
log.WithFields(log.Fields{
"storageAccountName": cfg.storageAccountName,
"blobDomain": cfg.blobDomain,
"name": blob.Name,
}).Warning("zero-length fragment")
} else {
Expand Down Expand Up @@ -293,6 +297,12 @@ func parseAzureEndpoint(endpoint *url.URL) (cfg azureStoreConfig, err error) {
// enforces that URL Paths end in '/'.
var splitPath = strings.Split(endpoint.Path[1:], "/")

// arize change to support china cloud
cfg.blobDomain = os.Getenv("AZURE_BLOB_DOMAIN")
if cfg.blobDomain == "" {
cfg.blobDomain = "blob.core.windows.net"
}

if endpoint.Scheme == "azure" {
// Since only one non-ad "Shared Key" credential can be injected via
// environment variables, we should only keep around one client for
Expand Down Expand Up @@ -444,6 +454,7 @@ func (a *azureBackend) getAzurePipeline(ep *url.URL) (cfg azureStoreConfig, clie
log.WithFields(log.Fields{
"tenant": cfg.accountTenantID,
"storageAccountName": cfg.storageAccountName,
"blobDomain": cfg.blobDomain,
"storageContainerName": cfg.containerName,
"pathPrefix": cfg.prefix,
}).Info("constructed new Azure Storage pipeline client")
Expand Down

0 comments on commit 72c2838

Please sign in to comment.