Skip to content

Commit

Permalink
2.7.3 (#63)
Browse files Browse the repository at this point in the history
  • Loading branch information
z7ye authored Jan 19, 2023
2 parents d086c59 + 826213b commit 7757da3
Show file tree
Hide file tree
Showing 37 changed files with 1,540 additions and 252 deletions.
2 changes: 1 addition & 1 deletion ads/ads_version.json
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
{
"version": "2.7.2"
"version": "2.7.3"
}
14 changes: 14 additions & 0 deletions ads/catalog/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
)
from ads.dataset.progress import DummyProgressBar, TqdmProgressBar
from ads.feature_engineering.schema import Schema
from ads.model.model_version_set import ModelVersionSet, _extract_model_version_set_id
from ads.model.deployment.model_deployer import ModelDeployer
from oci.data_science.data_science_client import DataScienceClient
from oci.data_science.models import (
Expand All @@ -72,6 +73,8 @@
"description",
"freeform_tags",
"defined_tags",
"model_version_set_id",
"version_label",
]
_MODEL_PROVENANCE_ATTRIBUTES = ModelProvenance().swagger_types.keys()
_ETAG_KEY = "ETag"
Expand Down Expand Up @@ -1284,6 +1287,8 @@ def upload_model(
bucket_uri: Optional[str] = None,
remove_existing_artifact: Optional[bool] = True,
overwrite_existing_artifact: Optional[bool] = True,
model_version_set: Optional[Union[str, ModelVersionSet]] = None,
version_label: Optional[str] = None,
):
"""
Uploads the model artifact to cloud storage.
Expand Down Expand Up @@ -1315,6 +1320,10 @@ def upload_model(
Whether artifacts uploaded to object storage bucket need to be removed or not.
overwrite_existing_artifact: (bool, optional). Defaults to `True`.
Overwrite target bucket artifact if exists.
model_version_set: (Union[str, ModelVersionSet], optional). Defaults to None.
The Model version set OCID, or name, or `ModelVersionSet` instance.
version_label: (str, optional). Defaults to None.
The model version label.
Returns
-------
Expand All @@ -1340,6 +1349,9 @@ def upload_model(
)
copy_artifact_to_os = True

# extract model_version_set_id from model_version_set attribute or environment
# variables in case of saving model in context of model version set.
model_version_set_id = _extract_model_version_set_id(model_version_set)
# Set default display_name if not specified - randomly generated easy to remember name generated
display_name = display_name or utils.get_random_name_for_resource()

Expand Down Expand Up @@ -1373,6 +1385,8 @@ def upload_model(
else '{"schema": []}',
freeform_tags=freeform_tags,
defined_tags=defined_tags,
model_version_set_id=model_version_set_id,
version_label=version_label,
)
model = self.ds_client.create_model(create_model_details)

Expand Down
4 changes: 4 additions & 0 deletions ads/catalog/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ def __init__(self, entity_list, datetime_format=utils.date_format):
self.df["compartment_id"] = "..." + self.df["compartment_id"].str[-6:]
if "project_id" in ordered_columns:
self.df["project_id"] = "..." + self.df["project_id"].str[-6:]
if "model_version_set_id" in ordered_columns:
self.df["model_version_set_id"] = (
"..." + self.df["model_version_set_id"].str[-6:]
)
self.df["time_created"] = pd.to_datetime(
self.df["time_created"]
).dt.strftime(datetime_format)
Expand Down
82 changes: 80 additions & 2 deletions ads/common/auth.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
#!/usr/bin/env python
# -*- coding: utf-8; -*-

# Copyright (c) 2021, 2022 Oracle and/or its affiliates.
# Copyright (c) 2021, 2023 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import copy
import os
from dataclasses import dataclass
from typing import Callable, Dict, Optional, Any
Expand All @@ -16,6 +17,7 @@

import ads.telemetry
from ads.common import logger
from ads.common.decorator.deprecate import deprecated
from ads.common.extended_enum import ExtendedEnumMeta


Expand Down Expand Up @@ -255,7 +257,7 @@ def create_signer(
Parameters
----------
auth: Optional[str], default 'api_key'
auth_type: Optional[str], default 'api_key'
'api_key', 'resource_principal' or 'instance_principal'. Enable/disable resource principal identity,
instance principal or keypair identity in a notebook session
oci_config_location: Optional[str], default oci.config.DEFAULT_LOCATION, which is '~/.oci/config'
Expand Down Expand Up @@ -378,6 +380,10 @@ def default_signer(client_kwargs: Optional[Dict] = None) -> Dict:
return signer_generator(signer_args).create_signer()


@deprecated(
"2.7.3",
details="Deprecated, use: from ads.common.auth import create_signer. https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html#overriding-defaults.",
)
def get_signer(
oci_config: Optional[str] = None, oci_profile: Optional[str] = None, **client_kwargs
) -> Dict:
Expand Down Expand Up @@ -686,6 +692,10 @@ class OCIAuthContext:
>>> df_run = DataFlowRun.from_ocid(run_id)
"""

@deprecated(
"2.7.3",
details="Deprecated, use: from ads.common.auth import AuthContext",
)
def __init__(self, profile: str = None):
"""
Initialize class OCIAuthContext and saves global state of authentication type and configuration profile.
Expand All @@ -700,6 +710,10 @@ def __init__(self, profile: str = None):
self.prev_profile = AuthState().oci_key_profile
self.oci_cli_auth = AuthState().oci_cli_auth

@deprecated(
"2.7.3",
details="Deprecated, use: from ads.common.auth import AuthContext",
)
def __enter__(self):
"""
When called by the 'with' statement and if 'profile' provided - 'api_key' authentication with 'profile' used.
Expand All @@ -718,3 +732,67 @@ def __exit__(self, exc_type, exc_val, exc_tb):
When called by the 'with' statement restores initial state of authentication type and profile value.
"""
ads.set_auth(auth=self.prev_mode, profile=self.prev_profile)


class AuthContext:
    """Context manager that scopes ADS authentication settings to a ``with`` block.

    On entry it snapshots the current global authentication state (auth type,
    signer, config and related parameters), then applies the overrides passed
    to the constructor via ``ads.set_auth()``. On exit the snapshot is restored,
    so code outside the block is unaffected.

    Examples
    --------
    >>> from ads import set_auth
    >>> from ads.jobs import DataFlowRun
    >>> with AuthContext(auth='resource_principal'):
    >>>     df_run = DataFlowRun.from_ocid(run_id)
    >>> from ads.model.framework.sklearn_model import SklearnModel
    >>> model = SklearnModel.from_model_artifact(uri="model_artifact_path", artifact_dir="model_artifact_path")
    >>> set_auth(auth='api_key', oci_config_location="~/.oci/config")
    >>> with AuthContext(auth='api_key', oci_config_location="~/another_config_location/config"):
    >>>     # upload model to Object Storage using config from another_config_location/config
    >>>     model.upload_artifact(uri="oci://bucket@namespace/prefix/")
    >>> # upload model to Object Storage using config from ~/.oci/config, which was set before 'with AuthContext():'
    >>> model.upload_artifact(uri="oci://bucket@namespace/prefix/")
    """

    def __init__(self, **kwargs):
        """Stores the authentication overrides to apply inside the ``with`` block.

        Parameters
        ----------
        **kwargs: optional, list of parameters passed to ads.set_auth() method, which can be:
            auth: Optional[str], default 'api_key'
                'api_key', 'resource_principal' or 'instance_principal'. Enable/disable resource principal
                identity, instance principal or keypair identity
            oci_config_location: Optional[str], default oci.config.DEFAULT_LOCATION, which is '~/.oci/config'
                config file location
            profile: Optional[str], default is DEFAULT_PROFILE, which is 'DEFAULT'
                profile name for api keys config file
            config: Optional[Dict], default {}
                created config dictionary
            signer: Optional[Any], default None
                created signer, can be resource principals signer, instance principal signer or other
            signer_callable: Optional[Callable], default None
                a callable object that returns signer
            signer_kwargs: Optional[Dict], default None
                parameters accepted by the signer
        """
        self.kwargs = kwargs

    def __enter__(self):
        """Snapshots the current global auth state, then applies the stored overrides."""
        # Deep copy so later mutations of the live AuthState cannot corrupt the snapshot.
        self.previous_state = copy.deepcopy(AuthState())
        set_auth(**self.kwargs)

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Restores the auth state captured on entry, whether or not an exception occurred."""
        saved_attrs = self.previous_state.__dict__
        AuthState().__dict__.update(saved_attrs)
19 changes: 17 additions & 2 deletions ads/common/ipython.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,22 @@
# Copyright (c) 2022 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import logging
import sys
from ads.common import logger

# TODO - Revisit this as part of ADS logging changes https://jira.oci.oraclecorp.com/browse/ODSC-36245
# Use a unique logger that we can individually configure without impacting other log statements.
# We don't want the logger name to mention "ads", since this logger will report any exception that happens in a
# notebook cell, and we don't want customers incorrectly assuming that ADS is somehow responsible for every error.
logger = logging.getLogger("ipython.traceback")
# Set propagate to False so logs aren't passed back up to the root logger handlers. There are some places in ADS
# where logging.basicConfig() is called. This changes root logger configurations. The user could import/use code that
# invokes the logging.basicConfig() function at any time, making the behavior of the root logger unpredictable.
logger.propagate = False
logger.handlers.clear()
traceback_handler = logging.StreamHandler()
traceback_handler.setFormatter(logging.Formatter("%(levelname)s - %(message)s"))
logger.addHandler(traceback_handler)


def _log_traceback(self, exc_tuple=None, **kwargs):
Expand All @@ -15,7 +29,8 @@ def _log_traceback(self, exc_tuple=None, **kwargs):
print("No traceback available to show.", file=sys.stderr)
return
msg = etype.__name__, str(value)
logger.error("ADS Exception", exc_info=(etype, value, tb))
# Use a generic message that makes no mention of ADS.
logger.error("Exception", exc_info=(etype, value, tb))
sys.stderr.write("{0}: {1}".format(*msg))


Expand Down
9 changes: 9 additions & 0 deletions ads/common/model_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
from ads.common.decorator.deprecate import deprecated
from ads.feature_engineering.schema import DataSizeTooWide, Schema, SchemaSizeTooLarge
from ads.model.extractor.model_info_extractor_factory import ModelInfoExtractorFactory
from ads.model.model_version_set import ModelVersionSet
from ads.model.common.utils import fetch_manifest_from_conda_location
from git import InvalidGitRepositoryError, Repo

Expand Down Expand Up @@ -714,6 +715,8 @@ def save(
defined_tags=None,
bucket_uri: Optional[str] = None,
remove_existing_artifact: Optional[bool] = True,
model_version_set: Optional[Union[str, ModelVersionSet]] = None,
version_label: Optional[str] = None,
):
"""
Saves the model artifact in the model catalog.
Expand Down Expand Up @@ -757,6 +760,10 @@ def save(
size is greater than 2GB. Example: `oci://<bucket_name>@<namespace>/prefix/`
remove_existing_artifact: (bool, optional). Defaults to `True`.
Whether artifacts uploaded to object storage bucket need to be removed or not.
model_version_set: (Union[str, ModelVersionSet], optional). Defaults to None.
The Model version set OCID, or name, or `ModelVersionSet` instance.
version_label: (str, optional). Defaults to None.
The model version label.
Examples
________
Expand Down Expand Up @@ -894,6 +901,8 @@ def save(
defined_tags=defined_tags,
bucket_uri=bucket_uri,
remove_existing_artifact=remove_existing_artifact,
model_version_set=model_version_set,
version_label=version_label,
)
except oci.exceptions.RequestException as e:
if "The write operation timed out" in str(e):
Expand Down
4 changes: 3 additions & 1 deletion ads/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,7 +737,9 @@ def default(self, obj):
),
):
return int(obj)
elif isinstance(obj, (np.float_, np.float16, np.float32, np.float64)):
elif isinstance(
obj, (np.float_, np.float16, np.float32, np.float64, np.double)
):
return float(obj)
elif isinstance(obj, (np.ndarray,)):
return obj.tolist()
Expand Down
2 changes: 1 addition & 1 deletion ads/dataset/classification_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def mapper(df, column_name, arg):
df[column_name] = df[column_name].map(arg)
return df

df = df.map_partitions(mapper, target, update_arg)
df = mapper(df, target, update_arg)
sampled_df = mapper(sampled_df, target, update_arg)
ClassificationDataset.__init__(
self, df, sampled_df, target, target_type, shape, **kwargs
Expand Down
2 changes: 1 addition & 1 deletion ads/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1034,7 +1034,7 @@ def _convert_dtypes_to_avro_types(self):
avro_dtype = "double"
elif "float" in str(dtype):
avro_dtype = "float"
elif dtype == np.bool:
elif dtype == np.bool_:
avro_dtype = "boolean"
else:
avro_dtype = "string"
Expand Down
3 changes: 0 additions & 3 deletions ads/dataset/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,9 +828,6 @@ def read_log(path, **kwargs):
},
**kwargs,
)
df["time"] = df["time"].map_partitions(
pd.to_datetime, utc=True, meta="datetime64[ns]"
)
return df

@staticmethod
Expand Down
9 changes: 6 additions & 3 deletions ads/dbmixin/db_pandas_accessor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*--

# Copyright (c) 2021, 2022 Oracle and/or its affiliates.
# Copyright (c) 2021, 2023 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

from ads.bds.big_data_service import ADSHiveConnection
Expand Down Expand Up @@ -33,7 +33,7 @@ def get(cls, engine="oracle"):

if not Connection:
if engine == "mysql":
print("Requires mysql-connection-python package to use mysql engine")
print("Requires mysql-connector-python package to use mysql engine")
elif engine == "oracle":
print(
f"The `oracledb` or `cx_Oracle` module was not found. Please run "
Expand Down Expand Up @@ -102,6 +102,7 @@ def to_sql(
if_exists: str = "fail",
batch_size=100000,
engine="oracle",
encoding="utf-8",
):
"""To save the dataframe df to database.
Expand All @@ -120,6 +121,8 @@ def to_sql(
Inserting in batches improves insertion performance. Choose this value based on available memory and network bandwidth.
engine: {'oracle', 'mysql'}, default 'oracle'
Select the database type - MySQL or Oracle to store the data
encoding: str, default is "utf-8"
The provided encoding will be used for encoding all columns when inserting into the table
Returns
Expand All @@ -146,5 +149,5 @@ def to_sql(

Connection = ConnectionFactory.get(engine)
return Connection(**connection_parameters).insert(
table_name, self._obj, if_exists, batch_size
table_name, self._obj, if_exists, batch_size, encoding
)
14 changes: 8 additions & 6 deletions ads/jobs/builders/infrastructure/dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@
CONDA_PACK_SUFFIX = "#conda"


def conda_pack_name_to_dataflow_config(conda_uri):
    """Builds the Data Flow runtime configuration entries for a published conda pack.

    Parameters
    ----------
    conda_uri: str
        Object Storage URI of the published conda pack.

    Returns
    -------
    dict
        Configuration dict with ``spark.archives`` pointing at the conda pack
        (suffixed with ``CONDA_PACK_SUFFIX``) and resource principal auth enabled.
    """
    return {
        # Percent-encode spaces so the URI remains valid in `spark.archives`.
        # NOTE(review): the pre-refactor inline code applied this replace; the
        # extracted helper had dropped it, leaving only a stale hint comment.
        "spark.archives": conda_uri.replace(" ", "%20") + CONDA_PACK_SUFFIX,
        "dataflow.auth": "resource_principal",
    }


class DataFlowApp(OCIModelMixin, oci.data_flow.models.Application):
@classmethod
def init_client(cls, **kwargs) -> oci.data_flow.data_flow_client.DataFlowClient:
Expand Down Expand Up @@ -778,12 +785,7 @@ def create(self, runtime: DataFlowRuntime, **kwargs) -> "DataFlow":
else:
raise ValueError(f"Conda built type not understood: {conda_type}.")
runtime_config = runtime.configuration or dict()
runtime_config.update(
{
"spark.archives": conda_uri.replace(" ", "%20") + CONDA_PACK_SUFFIX,
"dataflow.auth": "resource_principal",
}
)
runtime_config.update(conda_pack_name_to_dataflow_config(conda_uri))
runtime.with_configuration(runtime_config)
payload.update(
{
Expand Down
Loading

0 comments on commit 7757da3

Please sign in to comment.