Cross Validation Added #407
base: master
gramex/handlers/mlhandler.py
Outdated
@@ -20,6 +20,9 @@
from slugify import slugify
from tornado.gen import coroutine
from tornado.web import HTTPError
from sklearn.metrics import get_scorer
from sklearn.model_selection import cross_val_predict, cross_val_score
from sklearn.model_selection import cross_val_predict, cross_val_score
This line appears twice.
This extra line is unnecessary.
gramex/handlers/mlhandler.py
Outdated
# train the model
target = data[target_col]
train = data[[c for c in data if c != target_col]]
# cross validation
mod = cls.modelFunction()
This is not required. The model is already present as `cls.model`; see line 116.
gramex/handlers/mlhandler.py
Outdated
# cross validation
mod = cls.modelFunction()
CVscore = cross_val_score(mod, train, target)
CV = sum(CVscore)/len(CVscore)
- Use `CVscore.mean()`.
- Variable naming has to follow a specified style: do `pip install flake8`, run the `flake8` command against this file, i.e. `flake8 mlhandler.py`, and check the output.
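As a rough illustration of the first point: `cross_val_score` returns a NumPy array with one score per fold, so `.mean()` gives the same result as dividing the sum by the length. The dataset, estimator, and variable names below are placeholders chosen for this sketch and are not taken from the PR.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

# Placeholder data and model, used only for this illustration.
X, y = load_iris(return_X_y=True)
model = LogisticRegression(max_iter=1000)

# cross_val_score returns a NumPy array holding one score per fold.
cv_scores = cross_val_score(model, X, y, cv=5)

# The manual average used in the PR, and the equivalent array method.
manual_mean = sum(cv_scores) / len(cv_scores)
cv_mean = cv_scores.mean()
```

The snake_case names (`cv_scores`, `cv_mean`) are one way to follow PEP 8 naming; the exact checks reported depend on how flake8 is configured for the project.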
@prakrutisingh24 In this PR, we are just computing the cross val score when the model is set up for the first time, and simply printing the CV score. What we need is:
Thanks,
gramex/handlers/mlhandler.py
Outdated
@@ -40,6 +44,8 @@
'nums': [],
'cats': [],
'target_col': None,
'CV': True,
'CVargs': []
Let's have a single argument, `cv`, which can take any value, i.e. in gramex.yaml, users should be able to write any of the following:
cv: false       # disable cross validation
cv: 5           # Use 5 folds
cv:
    cv: 8       # Use 8 folds
    n_jobs: -1  # with an optional other parameter.
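One possible way to handle these shapes is to normalize the parsed YAML value into an on/off flag plus keyword arguments for `cross_val_score`. The helper name `parse_cv` is hypothetical and not part of MLHandler, and treating `cv: true` like the default is an assumption of this sketch.

```python
def parse_cv(cv=None):
    """Return (enabled, kwargs) for a `cv` value read from the config.

    Hypothetical helper for illustration; not part of MLHandler.
    """
    # Handle booleans first: bool is a subclass of int in Python,
    # so True/False would otherwise be mistaken for fold counts.
    if cv is False:
        return False, {}
    if cv is None or cv is True:
        # Assumption: nothing configured (or `cv: true`) means
        # "cross-validate with scikit-learn's defaults".
        return True, {}
    if isinstance(cv, int):
        return True, {'cv': cv}
    if isinstance(cv, dict):
        # e.g. {'cv': 8, 'n_jobs': -1}, passed through unchanged.
        return True, dict(cv)
    raise ValueError('Unsupported cv setting: {!r}'.format(cv))


enabled, kwargs = parse_cv({'cv': 8, 'n_jobs': -1})
```

With this shape, the caller can write `cross_val_score(model, X, y, **kwargs)` whenever `enabled` is true, without evaluating strings.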
gramex/handlers/mlhandler.py
Outdated
# cross validation
print('yayyy we are here')
cls.CrossValidation(train,target)
print('should have printed')
Please remove the prints.
The training is happening in `def _fit`. Cross validation should also happen there.
gramex/handlers/mlhandler.py
Outdated
@@ -20,6 +20,9 @@
from slugify import slugify
from tornado.gen import coroutine
from tornado.web import HTTPError
from sklearn.metrics import get_scorer
from sklearn.model_selection import cross_val_predict, cross_val_score
from ast import literal_eval
This should not be required.
gramex/handlers/mlhandler.py
Outdated
@@ -40,6 +43,7 @@
'nums': [],
'cats': [],
'target_col': None,
'CV': True,
Make it lowercase.
We have to support three cases for the `cv` option:
- If the user sets `cv: false`, then no cross validation happens.
- If the user sets `cv: 4` (or some other integer), pass it straight to `cross_val_score`.
- The default should be `cv: None`, and in this case, the user should not have to write anything in gramex.yaml.
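The three cases could be sketched roughly as follows. The function name `cv_scores_or_none`, the dataset, and the estimator are placeholders for this illustration, not the handler's actual code.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier


def cv_scores_or_none(model, X, y, cv=None):
    """Hypothetical helper covering the three `cv` cases."""
    if cv is False:
        # cv: false -> skip cross validation entirely
        return None
    # cv: <int> is passed straight through; cv=None lets scikit-learn
    # pick its own default (5 folds in recent releases).
    return cross_val_score(model, X, y, cv=cv)


X, y = load_iris(return_X_y=True)
model = DecisionTreeClassifier(random_state=0)

disabled = cv_scores_or_none(model, X, y, cv=False)
four_fold = cv_scores_or_none(model, X, y, cv=4)
default = cv_scores_or_none(model, X, y)  # nothing set in gramex.yaml
```

Checking `cv is False` rather than `not cv` keeps `None` (the default) distinct from an explicit request to disable cross validation.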
gramex/handlers/mlhandler.py
Outdated
# train the model
target = data[target_col]
train = data[[c for c in data if c != target_col]]
# cross validation
cls.CrossValidation(train,target)
Make it lowercase.
gramex/handlers/mlhandler.py
Outdated
mclass = model_kwargs.get('class', False)
if mclass:
    model = search_modelclass(mclass)(**model_kwargs.get('params', {}))
return model
This function is not required.
if CV:
    CVscore = cross_val_score(mod, X=train, y=target, **literal_eval(json.dumps(CV)))
    CVavg = sum(CVscore)/len(CVscore)
    print('Cross Validation Score : ',CVavg)
CV should take place within the train method only.
if cv:
    cvscore = cross_val_score(mod, X=train, y=target, cv=cv)
else:
    # Do the usual .fit
    mod.fit(train, target)
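One detail worth keeping in mind with a branch like the one above: `cross_val_score` fits clones of the estimator, so the estimator passed in stays unfitted afterwards. A training step that needs both a score and a usable model could therefore look something like this sketch, where `fit_with_optional_cv`, the dataset, and the estimator are illustrative assumptions rather than the handler's real code.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier


def fit_with_optional_cv(model, X, y, cv=None):
    """Hypothetical training step: optionally score, then always fit."""
    score = None
    if cv is not False:
        # Scores are computed on clones; `model` itself is not fitted here.
        score = cross_val_score(model, X, y, cv=cv).mean()
    # Fit on the full data so the model can be used for predictions.
    model.fit(X, y)
    return score


X, y = load_iris(return_X_y=True)
model = DecisionTreeClassifier(random_state=0)

# After cross validation alone, the original estimator is still unfitted:
cross_val_score(model, X, y, cv=3)
fitted_after_cv = hasattr(model, 'tree_')

mean_score = fit_with_optional_cv(model, X, y, cv=3)
```

Here `tree_` is the attribute a `DecisionTreeClassifier` gains once fitted, used only as a simple way to show the difference.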
# train the model
target = data[target_col]
train = data[[c for c in data if c != target_col]]
# cross validation
cls.cross_validation(train,target)
Not required here.
No description provided.