Skip to content

Commit

Permalink
initial
Browse files Browse the repository at this point in the history
  • Loading branch information
Dmitry Razdoburdin committed Aug 4, 2023
1 parent 33f1a3e commit 9db8414
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 12 deletions.
42 changes: 31 additions & 11 deletions examples/daal4py/log_reg_model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ===============================================================================
import sys

print("KNOWN BUG IN EXAMPLES. TODO: fixme")
sys.exit()

import numpy as np
from sklearn.datasets import load_iris
Expand Down Expand Up @@ -44,26 +40,50 @@ def main():
nClasses=n_classes, resultsToEvaluate="computeClassLabels"
)
# set parameters and compute predictions
predict_result_daal = predict_alg.compute(X, builder.model)
daal4py_prediction = predict_alg.compute(X, builder.model).prediction
predict_result_sklearn = clf.predict(X)
assert np.allclose(predict_result_daal.prediction.flatten(), predict_result_sklearn)
return (builder, predict_result_daal)
assert np.allclose(daal4py_prediction.flatten(), predict_result_sklearn)

# set parameters and compute predictions
predict_alg = d4p.logistic_regression_prediction(
nClasses=n_classes, resultsToEvaluate="computeClassProbabilities"
)
# set parameters and compute predictions
daal4py_probabilities = predict_alg.compute(X, builder.model).probabilities
predict_result_sklearn = clf.predict_proba(X)
assert np.allclose(daal4py_probabilities, predict_result_sklearn)

# set parameters and compute predictions
predict_alg = d4p.logistic_regression_prediction(
nClasses=n_classes, resultsToEvaluate="computeClassLogProbabilities"
)
# set parameters and compute predictions
daal4py_logProbabilities = predict_alg.compute(X, builder.model).logProbabilities
predict_result_sklearn = clf.predict_log_proba(X)
assert np.allclose(daal4py_logProbabilities, predict_result_sklearn)

return (builder, daal4py_prediction, daal4py_probabilities, daal4py_logProbabilities)


if __name__ == "__main__":
if daal_check_version(((2021, "P", 1))):
(builder, predict_result_daal) = main()
(
builder,
daal4py_prediction,
daal4py_probabilities,
daal4py_logProbabilities,
) = main()
print("\nLogistic Regression coefficients:\n", builder.model)
print(
"\nLogistic regression prediction results (first 10 rows):\n",
predict_result_daal.prediction[0:10],
daal4py_prediction[0:10],
)
print(
"\nLogistic regression prediction probabilities (first 10 rows):\n",
predict_result_daal.probabilities[0:10],
daal4py_probabilities[0:10],
)
print(
"\nLogistic regression prediction log probabilities (first 10 rows):\n",
predict_result_daal.logProbabilities[0:10],
daal4py_logProbabilities[0:10],
)
print("All looks good!")
2 changes: 1 addition & 1 deletion tests/run_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def check_library(rule):

req_os = defaultdict(lambda: [])

skiped_files = ["log_reg_model_builder.py"]
skiped_files = []


def get_exe_cmd(ex, nodist, nostream):
Expand Down

0 comments on commit 9db8414

Please sign in to comment.