
RandomForestModel object has no attribute 'predict_class' #144

Open
sermomon opened this issue Nov 7, 2024 · 2 comments


sermomon commented Nov 7, 2024

I am trying to replicate the Random Forest classification example provided in the documentation (Classification - YDF documentation).

There is an error in the last line of code in the example:

```python
model.predict_class(test_ds)
```

```
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[26], line 1
----> 1 model.predict_class(test_ds)

AttributeError: 'RandomForestModel' object has no attribute 'predict_class'
```

rstz (Collaborator) commented Nov 7, 2024

This feature will be available in the next version of YDF (or in a build you create yourself); it looks like the documentation was updated too early :(

sermomon (Author) commented Nov 7, 2024

Since model.predict() returns a NumPy array of class probabilities, the argmax function can be used to get the index of the class with the highest probability. For a multi-class problem, a tentative workaround could therefore be:

```python
import numpy as np

# Predict class probabilities: one column per class
y_prob = model.predict(test_ds)

# Get the index of the most probable class for each example
y_pred = np.argmax(y_prob, axis=1)
```
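To make the argmax step concrete, here is a small self-contained sketch that uses a hand-written probability array instead of a real model. The class names are hypothetical; with a real model they would presumably come from `model.label_classes()`, assuming the columns of `model.predict()` follow that order.

```python
import numpy as np

# Hypothetical output of model.predict() for a 3-class problem:
# one row per example, one column per class.
y_prob = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.3, 0.6],
    [0.2, 0.5, 0.3],
])

# Hypothetical class names, in the same order as the columns.
class_names = np.array(["cat", "dog", "bird"])

# Column index of the highest probability in each row.
y_pred = np.argmax(y_prob, axis=1)

# Map the indices to class names.
y_pred_labels = class_names[y_pred]

print(y_pred)         # [0 2 1]
print(y_pred_labels)  # ['cat' 'bird' 'dog']
```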

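For a binary classification problem, model.predict() returns a single probability per example instead of one column per class, so argmax cannot be applied directly. YDF documents this value as the probability of the positive class, i.e. the second class in model.label_classes(), so the label index can be obtained by thresholding at 0.5:

```python
import numpy as np

# Predict probabilities (binary classification: one probability per example,
# for the second class in model.label_classes())
y_prob = model.predict(test_ds)

# Label index: 1 (second class) if the probability is at least 0.5, else 0
y_pred = np.where(y_prob >= 0.5, 1, 0)
```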
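The binary case can also be turned back into the two-column shape so that the multi-class argmax code is reused unchanged. The sketch below uses a hand-written probability array and hypothetical class names, and assumes that the single value is the probability of the second class in `model.label_classes()`, which is how the YDF documentation describes binary predictions. If your version behaves differently, swap the two columns.

```python
import numpy as np

# Hypothetical output of model.predict() for a binary problem:
# one probability per example, assumed to be P(second class).
y_prob = np.array([0.9, 0.2, 0.51, 0.49])

# Hypothetical class names, in label_classes() order.
class_names = np.array(["negative", "positive"])

# Option 1: threshold at 0.5.
y_pred = (y_prob >= 0.5).astype(int)

# Option 2: rebuild a [P(first), P(second)] array and reuse argmax.
y_prob_2col = np.stack([1.0 - y_prob, y_prob], axis=1)
y_pred_argmax = np.argmax(y_prob_2col, axis=1)

print(class_names[y_pred])  # ['positive' 'negative' 'positive' 'negative']
```

The two options only disagree for a probability of exactly 0.5: the threshold picks the second class, while `np.argmax` returns the first index on a tie.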

I could not find a better solution myself, but maybe this will help someone in the meantime.
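Both workarounds can be wrapped in one helper that prefers `predict_class()` when the installed YDF version provides it and otherwise falls back to `predict()`. This is a sketch under stated assumptions: `predict_labels` and the `Fake...` classes are hypothetical names used only for illustration, `model.label_classes()` is assumed to list the classes in the same order as the columns of `model.predict()`, and a 1-D binary output is assumed to be the probability of the second class.

```python
import numpy as np


def predict_labels(model, data):
    """Return predicted class names (hypothetical helper).

    Uses model.predict_class() if it exists; otherwise thresholds (binary)
    or takes the argmax (multi-class) of model.predict(). Assumes the
    probability columns follow model.label_classes() and that a 1-D binary
    output is the probability of the second class.
    """
    predict_class = getattr(model, "predict_class", None)
    if callable(predict_class):
        return np.asarray(predict_class(data))

    classes = np.asarray(model.label_classes())
    y_prob = np.asarray(model.predict(data))
    if y_prob.ndim == 1:
        # Binary: one probability per example.
        idx = (y_prob >= 0.5).astype(int)
    else:
        # Multi-class: one column per class.
        idx = np.argmax(y_prob, axis=1)
    return classes[idx]


# Hypothetical stand-ins for models, so the helper can run without YDF.
class FakeOldMulticlassModel:
    def label_classes(self):
        return ["a", "b", "c"]

    def predict(self, data):
        return np.array([[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]])


class FakeOldBinaryModel:
    def label_classes(self):
        return ["no", "yes"]

    def predict(self, data):
        return np.array([0.8, 0.1])


class FakeNewModel(FakeOldMulticlassModel):
    # Returns fixed labels so it is visible that predict_class() was used.
    def predict_class(self, data):
        return np.array(["c", "c"])


print(predict_labels(FakeOldMulticlassModel(), None))  # ['b' 'a']
print(predict_labels(FakeOldBinaryModel(), None))      # ['yes' 'no']
print(predict_labels(FakeNewModel(), None))            # ['c' 'c']
```

With a real model, the call would be `predict_labels(model, test_ds)`; once a YDF release with `predict_class()` is installed, the helper simply delegates to it.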
