Skip to content

Commit

Permalink
remove unsupported save,load,read,write from api docs for knn estimat… (
Browse files Browse the repository at this point in the history
#646)

* remove unsupported save,load,read,write from api docs for knn estimator, model classes

Signed-off-by: Erik Ordentlich <[email protected]>

* fix class names in error messages

Signed-off-by: Erik Ordentlich <[email protected]>

* typo

Signed-off-by: Erik Ordentlich <[email protected]>

---------

Signed-off-by: Erik Ordentlich <[email protected]>
  • Loading branch information
eordentlich authored May 9, 2024
1 parent c59795a commit 0198d3f
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 3 deletions.
48 changes: 47 additions & 1 deletion python/src/spark_rapids_ml/knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,12 +361,27 @@ def _get_cuml_fit_func(self, dataset: DataFrame) -> Callable[ # type: ignore
pass

def write(self) -> MLWriter:
"""Unsupported."""
raise NotImplementedError(
"NearestNeighbors does not support saving/loading, just re-create the estimator."
)

@classmethod
def read(cls) -> MLReader:
"""Unsupported."""
raise NotImplementedError(
"NearestNeighbors does not support saving/loading, just re-create the estimator."
)

def save(self, path: str) -> None:
"""Unsupported."""
raise NotImplementedError(
"NearestNeighbors does not support saving/loading, just re-create the estimator."
)

@classmethod
def load(cls, path: str) -> MLReader:
"""Unsupported."""
raise NotImplementedError(
"NearestNeighbors does not support saving/loading, just re-create the estimator."
)
Expand Down Expand Up @@ -442,14 +457,29 @@ def _nearest_neighbors_join(
return knnjoin_df

def write(self) -> MLWriter:
"""Unsupported."""
raise NotImplementedError(
f"{self.__class__} does not support saving/loading, just re-fit the estimator to re-create a model."
)

@classmethod
def read(cls) -> MLReader:
"""Unsupported."""
raise NotImplementedError(
f"{cls} does not support loading/loading, just re-fit the estimator to re-create a model."
f"{cls} does not support saving/loading, just re-fit the estimator to re-create a model."
)

def save(self, path: str) -> None:
"""Unsupported."""
raise NotImplementedError(
f"{self.__class__} does not support saving/loading, just re-create the estimator."
)

@classmethod
def load(cls, path: str) -> MLReader:
"""Unsupported."""
raise NotImplementedError(
f"{cls} does not support saving/loading, just re-create the estimator."
)


Expand Down Expand Up @@ -1040,13 +1070,29 @@ def _get_cuml_fit_func(self, dataset: DataFrame) -> Callable[ # type: ignore
"""
pass

# for the following 4 methods leave doc string as below so that they are filtered out from api docs
def write(self) -> MLWriter:
"""Unsupported."""
raise NotImplementedError(
"ApproximateNearestNeighbors does not support saving/loading, just re-create the estimator."
)

@classmethod
def read(cls) -> MLReader:
"""Unsupported."""
raise NotImplementedError(
"ApproximateNearestNeighbors does not support saving/loading, just re-create the estimator."
)

@classmethod
def load(cls, path: str) -> MLReader:
"""Unsupported."""
raise NotImplementedError(
"ApproximateNearestNeighbors does not support saving/loading, just re-create the estimator."
)

def save(self, path: str) -> None:
"""Unsupported."""
raise NotImplementedError(
"ApproximateNearestNeighbors does not support saving/loading, just re-create the estimator."
)
Expand Down
12 changes: 11 additions & 1 deletion python/src/spark_rapids_ml/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,17 @@ def _unsupported_methods_attributes(clazz: Any) -> Set[str]:
_unsupported_methods: List[str] = sum(
[_method_names_from_param(k) for k in _unsupported_params], []
)
return set(_unsupported_params + _unsupported_methods)
methods_and_functions = inspect.getmembers(
clazz,
predicate=lambda member: inspect.isfunction(member)
or inspect.ismethod(member),
)
_other_unsupported = [
entry[0]
for entry in methods_and_functions
if entry and (entry[1].__doc__) == "Unsupported."
]
return set(_unsupported_params + _unsupported_methods + _other_unsupported)
else:
return set()

Expand Down
29 changes: 28 additions & 1 deletion python/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,35 @@ class A:
def _param_mapping(cls) -> Dict[str, Optional[str]]:
return {"param1": "param2", "param3": None, "param4": ""}

@classmethod
def unsupported_method(cls) -> None:
"""Unsupported."""
pass

def unsupported_function(self) -> None:
"""Unsupported."""
pass

@classmethod
def supported_method(cls) -> None:
"""supported"""
pass

def supported_function(self) -> None:
"""supported"""
pass

assert _unsupported_methods_attributes(A) == set(
["param3", "getParam3", "setParam3", "param4", "getParam4", "setParam4"]
[
"param3",
"getParam3",
"setParam3",
"param4",
"getParam4",
"setParam4",
"unsupported_method",
"unsupported_function",
]
)


Expand Down

0 comments on commit 0198d3f

Please sign in to comment.