-
Notifications
You must be signed in to change notification settings - Fork 532
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor to simplify input/output descriptors and decorators #6124
base: branch-24.12
Are you sure you want to change the base?
Conversation
@betatim @wphicks @divyegala this reflects our POC after our initial discussions, will be applying them to a real estimator alongside finishing some todos so we can see the full design in action and discuss any necessary aspects remaining. |
self.intercept_ = CumlArray.zeros(self.n_features_in_, | ||
dtype=self.dtype) | ||
|
||
# do awesome C++ fitting here :) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One weird thing while playing with this a bit: when I add a print(f"{self.coef_=}")
here I get the following:
Traceback (most recent call last):
File "/home/coder/cuml/../ff.py", line 9, in <module>
e.fit(X, y)
File "/home/coder/.conda/envs/rapids/lib/python3.12/site-packages/cuml/internals/api_decorators.py", line 190, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/coder/.conda/envs/rapids/lib/python3.12/site-packages/cuml/sample/estimator.py", line 32, in wrapper
result = func(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/coder/.conda/envs/rapids/lib/python3.12/site-packages/cuml/sample/estimator.py", line 115, in fit
print(f"{self.coef_=}")
^^^^^^^^^^
File "base.pyx", line 337, in cuml.internals.base.Base.__getattr__
AttributeError: coef_. Did you mean: '_coef_'?
I have stared at this for quite a while but can't work out what is going on??
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Realised about 2min after leaving the office: it is because _is_fitted
doesn't get set until fit
returns. Maybe something to improve as it makes for a tedious to debug thing :D - I'll ponder a suggestion
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That is the same behavior a scikit-learn, no?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think so. I was trying to access the coef_
attribute within fit
.
In scikit-learn these are normal attributes, so once they are set you can use them. Right now we define __getattribute__
which uses _is_fit
. I think it is a bit weird to have code like this fail, mostly because it makes you question your sanity and because the exception doesn't contain a clue (we get to see the AttributeError
from __getattr__
not __getattribute__
:():
self.foo_ = 42
print(self.foo_) # Error, `foo_` doesn't exist!
Maybe we can get around the need to checking _is_fit
and using __getattribute__
by recording inside the DynamicDescriptor
if it has been set or not:
class DynamicDescriptor:
def __init__(self, attribute_name):
self.set = False
self.attribute_name = f"{name}"
def __get__(self, obj, objtype=None):
if obj is None:
return self
if not self.set:
raise AttributeError(f"{obj.__class__.__name__} object has no attribute {self.attribute_name}")
else:
if GlobalSettings().is_internal
return self.raw
else:
return self.raw.to_output(obj._input_type)
def __set__(self, obj, value):
self.set = True
# we can even store the value inside the descriptor?!
self.raw = value
This might need a bit of tweaking to make the message in the exception look right ("'Estimator' object has no attribute 'foo_'"
).
This PR aims to refactor our descriptors and decorators to simplify them to make them significantly easier to test, debug and mantain.