ENH: Computer Vision pipelines #560
base: master
* TODO: Need a robust way to find if the MLHandler expects image data for training/testing
* TODO: Separate class is required for training and testing CV models
* TODO: Code has to be made generic to dynamically select Resnet50/VGG, etc. models
* TODO: Model training parameters viz. epochs, batch size have to be taken from the user with defaults set
* TODO: setup and get_model functions have to be made more scalable
* TODO: ModelStore __init__ has to be made scalable
@jaidevd @sanand0 Please add/edit/delete any more TODOs that you feel may be required.

@jaidevd Please look at the recent commit fe2c85e. Is this the correct approach to implement it?
* Predict is working with default VGG16 and Resnet models.
* TODO: Test and add all other models supported by Keras
* TODO: Implement training functionality
* TODO: Implement functionality to predict from trained models provided by user/trained in gramex
d7791e6 to fe2c85e
if op.exists(cls.store.model_path):  # If the pkl exists, load it
    if op.isdir(cls.store.model_path):
        mclass, wrapper = ml.search_modelclass(mclass)
        cls.model = locate(wrapper).from_disk(mclass, cls.store.model_path)
This condition is required for loading transformers. Why is this removed?
gramex/ml_api.py
Outdated
@@ -46,6 +46,10 @@
        "statsmodels.tsa.statespace.sarimax",
    ],
    "gramex.ml_api.HFTransformer": ["gramex.transformers"],
    "gramex.ml_api.KerasApplications": [
        "tensorflow.keras.applications.vgg16",
        "tensorflow.keras.applications.resnet50"
You can just leave this at `tensorflow.keras.applications`, because `from tensorflow.keras.applications import *` covers everything.
gramex/ml_api.py
Outdated
@@ -426,3 +435,49 @@ def _predict(
    ):
        text = X["text"]
        return self.model.predict(text)


class KerasApplications(AbstractModel):
Call this `KerasApplication` - singular.
gramex/handlers/mlhandler.py
Outdated
data = imutils.resize(cv2.imdecode(np.fromstring(
    self.request.files['image'][0].body, np.uint8), cv2.IMREAD_UNCHANGED),
    width=224)
data = cv2.resize(data, (224, 224))
- All of this logic should be in the wrapper class.
- OpenCV should not be a dependency. For loading images from files / streams, use `PIL.Image.open`.
- For resizing, use `skimage.transform.resize` or `tf.image.resize` or `PIL.Image.resize`.
- The target size after resizing should not be hardcoded to `[224, 224]` - there are other sizes in Keras apps too. For this, you can check the shape of the input tensor in the corresponding model with `model.input_shape`.
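For illustration, here is a minimal sketch of what the points above could look like with Pillow and NumPy only. The helper name `load_image` is hypothetical and is not part of gramex; `input_shape` is assumed to be a channels-last tuple of the kind Keras reports in `model.input_shape`, e.g. `(None, 224, 224, 3)`.

```python
import io

import numpy as np
from PIL import Image


def load_image(stream, input_shape):
    """Read an image from a file-like object and resize it for a model.

    `input_shape` is assumed to be channels-last, like Keras'
    `model.input_shape` for image models: (None, height, width, channels).
    """
    _, height, width, _ = input_shape
    img = Image.open(stream).convert("RGB")
    # PIL expects (width, height), the reverse of the tensor layout.
    img = img.resize((width, height))
    # Add a leading batch axis: (1, height, width, 3).
    return np.expand_dims(np.asarray(img, dtype="float32"), axis=0)


# An in-memory PNG stands in for an uploaded file.
buf = io.BytesIO()
Image.new("RGB", (640, 480), color=(255, 0, 0)).save(buf, format="PNG")
buf.seek(0)
batch = load_image(buf, (None, 299, 299, 3))
```

In a real wrapper the second argument would come from the loaded model rather than a literal, so that models with a different input size need no code change. Any model-specific preprocessing (scaling, mean subtraction) is deliberately left out of this sketch.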
gramex/handlers/mlhandler.py
Outdated
    self.set_header('Content-Disposition',
                    f'attachment; filename={op.basename(self.store.model_path)}')
    with open(self.store.model_path, 'rb') as fout:
        self.write(fout.read())
elif '_model' in self.args:
    self.write(json.dumps(self.model.get_params(), indent=2))
elif len(self.request.files.keys()) and \
        self.request.files['image'][0].content_type in \
        ['image/jpeg', 'image/jpg', 'image/png']:
This check should not happen here. As far as possible, send the request payload blindly to the `_fit` or `_predict` methods. The wrapper class should take care of everything if it is written correctly.
Also note that there are two ways one can send images in a request.
- Send files (multipart form data) - in which case you look at the mimetypes of the files received and open them accordingly. For this, take a look at the definition of `self._parse_multipart_form_data`.
- Send the raw bytestream of an image with a `Content-Type: image/whatever` header. In this case, write functions called `_parse_image_jpeg`, `_parse_image_png`, etc. The handler will automatically call them.
Basically the handler knows how to parse data given the content type of the request.
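The naming convention described above (a `_parse_image_jpeg` method for `Content-Type: image/jpeg`) can be sketched as a simple name-based lookup. This is only an illustration of the idea under assumed names: the `PayloadParser` class and its `parse` method are hypothetical, and this is not gramex's actual dispatch code.

```python
import io


class PayloadParser:
    """Hypothetical stand-in for a handler that dispatches on Content-Type."""

    def parse(self, content_type, body):
        # "image/jpeg; charset=..." -> "image/jpeg" -> "_parse_image_jpeg"
        mime = content_type.split(";")[0].strip().lower()
        name = "_parse_" + mime.replace("/", "_").replace("-", "_")
        method = getattr(self, name, None)
        if method is None:
            raise ValueError(f"Unsupported content type: {mime}")
        return method(body)

    def _parse_image_jpeg(self, body):
        # Wrap the raw bytes; decoding is left to the model wrapper.
        return io.BytesIO(body)

    # PNG bytes can be wrapped in exactly the same way.
    _parse_image_png = _parse_image_jpeg


parser = PayloadParser()
stream = parser.parse("image/png", b"\x89PNG...")
```

With this shape, supporting a new image type means adding one method, and the request handler itself never needs to list the accepted mimetypes.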
gramex/handlers/mlhandler.py
Outdated
if 'training_data' in data.keys():
    training_results = yield gramex.service.threadpool.submit(
        self._train, data=data['training_data'].iloc[0])
    self.write(json.dumps(training_results, indent=2, cls=CustomJSONEncoder))
What does this do?
gramex/handlers/mlhandler.py
Outdated
    json.dump(class_names, fout)
keras_model.save(config_dir)
return class_names
Remove this and move it to the wrapper class.
* Training keras models would be done in ml_api in KerasApplication wrapper
* _parse_multipart_form_data used for parsing images
970a8cc to c479f08
* Removed unnecessary code from ModelStore
gramex/handlers/mlhandler.py
Outdated
@@ -184,7 +186,10 @@ def _transform(self, data, **kwargs):
        return data

    def _predict(self, data=None, score_col=''):
        self._check_model_path()
        import io
        if type(data) == io.BytesIO:
Use `isinstance` for checking types.
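A short illustration of why this matters: comparing with `type(...) ==` rejects subclasses, while `isinstance` accepts them. The `NamedBytesIO` class below is hypothetical and exists only to show the difference.

```python
import io


class NamedBytesIO(io.BytesIO):
    """Hypothetical subclass, e.g. an upload that remembers its filename."""


data = NamedBytesIO(b"raw image bytes")

exact_match = type(data) == io.BytesIO         # False: the subclass is rejected
subclass_aware = isinstance(data, io.BytesIO)  # True: the subclass is accepted
```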
gramex/handlers/mlhandler.py
Outdated
data = self.args['training_data']
training_results = yield gramex.service.threadpool.submit(
    self._train, data=data)
self.write(json.dumps(training_results, indent=2, cls=CustomJSONEncoder))
Training should not happen in GET. Only in POST or PUT.
    self.store.load('class'), self.store.load('params'),
    data=data, target_col=target_col,
    nums=self.store.load('nums'), cats=self.store.load('cats')
)
@jaidevd - find a way to remove dataframe-specific code from MLHandler - it should deal only with train / predict semantics.
gramex/ml_api.py
Outdated
    input_tensor=None,
    input_shape=None,
    pooling=None,
    classes=1000)
Move this to `__init__`.
gramex/ml_api.py
Outdated
    input_tensor=None,
    input_shape=None,
    pooling=None,
    classes=1000)
This should do only `self.model.predict`.
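One way to read the last two comments together: build the model once in `__init__`, and keep `_predict` a thin call to `self.model.predict`. The sketch below is an assumption about the intended shape, not the code that was merged. `DummyModel` is a stand-in so the example runs without TensorFlow, and the `model_factory` parameter is hypothetical.

```python
class DummyModel:
    """Stand-in for a Keras model so the sketch runs without TensorFlow."""

    def __init__(self, classes=1000):
        self.classes = classes

    def predict(self, batch):
        # Pretend every input belongs to class 0.
        return [0 for _ in batch]


class KerasApplication:
    """Hypothetical wrapper: construct once, predict many times."""

    def __init__(self, model_factory, **params):
        # Construction arguments (input_shape, pooling, classes, ...)
        # are consumed here, once, instead of on every prediction.
        self.model = model_factory(**params)

    def _predict(self, X):
        return self.model.predict(X)


wrapper = KerasApplication(DummyModel, classes=1000)
labels = wrapper._predict(["img-a", "img-b"])
```

Keeping construction out of `_predict` avoids rebuilding the network on every request, which would be costly for a real Keras application.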
* Model initialisation code moved to __init__
* Training added to POST request
No description provided.