diff --git a/gramex/handlers/mlhandler.py b/gramex/handlers/mlhandler.py
index b1a9e662b..364e8c5ac 100644
--- a/gramex/handlers/mlhandler.py
+++ b/gramex/handlers/mlhandler.py
@@ -98,7 +98,7 @@ def setup(cls, data=None, model={}, config_dir='', template=DEFAULT_TEMPLATE, **kwargs):
         if op.exists(cls.store.model_path):  # If the pkl exists, load it
             if op.isdir(cls.store.model_path):
                 mclass, wrapper = ml.search_modelclass(mclass)
-                cls.model = locate(wrapper).from_disk(mclass, cls.store.model_path)
+                cls.model = locate(wrapper).from_disk(path=cls.store.model_path, klass=mclass)
             else:
                 cls.model = get_model(cls.store.model_path, {})
         elif data is not None:
@@ -125,6 +125,8 @@ def _parse_multipart_form_data(self):
         for f in files:
             buff = BytesIO(f['body'])
             try:
+                if f['content_type'] in ['image/jpeg', 'image/jpg', 'image/png']:
+                    return buff
                 ext = re.sub(r'^.', '', op.splitext(f['filename'])[-1])
                 xdf = cache.open_callback['jsondata' if ext == 'json' else ext](buff)
                 dfs.append(xdf)
@@ -136,6 +138,13 @@
     def _parse_application_json(self):
         return pd.read_json(self.request.body.decode('utf8'))
 
+    def _parse_image_jpeg(self):
+        from PIL import Image
+        buff = BytesIO(self.request.body)
+        return Image.open(buff)
+
+    _parse_image_jpg = _parse_image_png = _parse_image_jpeg
+
     def _parse_data(self, _cache=True, append=False):
         header = self.request.headers.get('Content-Type', '').split(';')[0]
         header = slugify(header).replace('-', '_')
@@ -176,6 +185,9 @@ def _filterrows(cls, data, **kwargs):
         return data
 
     def _transform(self, data, **kwargs):
+        if not isinstance(data, (pd.DataFrame, pd.Series)):
+            return data
+
         orgdata = self.store.load_data()
         for col in np.intersect1d(data.columns, orgdata.columns):
             data[col] = data[col].astype(orgdata[col].dtype)
@@ -184,7 +196,10 @@
 
     def _predict(self, data=None, score_col=''):
-        self._check_model_path()
+        import io
+        if isinstance(data, io.BytesIO):
+            data = self.model.predict(data=data, mclass=self.store.load('class'))
+            return data
         metric = self.get_argument('_metric', False)
         if metric:
             scorer = get_scorer(metric)
@@ -204,6 +219,8 @@
+        except AttributeError:
+            return self.model.predict(data)
         except Exception as exc:
             app_log.exception(exc)
         return data
 
     def _check_model_path(self):
         try:
@@ -241,13 +258,18 @@ def get(self, *path_args, **path_kwargs):
         else:
             self._check_model_path()
             if '_download' in self.args:
-                self.set_header('Content-Type', 'application/octet-strem')
+                self.set_header('Content-Type', 'application/octet-stream')
                 self.set_header('Content-Disposition',
                                 f'attachment; filename={op.basename(self.store.model_path)}')
                 with open(self.store.model_path, 'rb') as fout:
                     self.write(fout.read())
             elif '_model' in self.args:
                 self.write(json.dumps(self.model.get_params(), indent=2))
+            # elif isinstance(self.model, ml.KerasApplication):
+            #     data = self._parse_multipart_form_data()
+            #     prediction = yield gramex.service.threadpool.submit(
+            #         self._predict, data)
+            #     self.write(json.dumps(prediction, indent=2, cls=CustomJSONEncoder))
             else:
                 try:
                     data_args = {k: v for k, v in self.args.items() if not k.startswith('_')}
@@ -275,14 +297,18 @@ def _train(self, data=None):
         target_col = self.get_argument('target_col', self.store.load('target_col'))
         index_col = self.get_argument('index_col', self.store.load('index_col'))
         self.store.dump('target_col', target_col)
-        data = self._parse_data(False) if data is None else data
-        data = self._filtercols(data)
-        data = self._filterrows(data)
-        self.model = get_model(
-            self.store.load('class'), self.store.load('params'),
-            data=data, target_col=target_col,
-            nums=self.store.load('nums'), cats=self.store.load('cats')
-        )
+        if isinstance(self.model, ml.KerasApplication):
+            result = self.model.fit(data, self.kwargs['config_dir'])
+            return result
+        else:
+            data = self._parse_data(False) if data is None else data
+            data = self._filtercols(data)
+            data = self._filterrows(data)
+            self.model = get_model(
+                self.store.load('class'), self.store.load('params'),
+                data=data, target_col=target_col,
+                nums=self.store.load('nums'), cats=self.store.load('cats')
+            )
         if not isinstance(self.model, ml.SklearnTransformer):
             target = data[target_col]
             train = data[[c for c in data if c not in (target_col, index_col)]]
@@ -306,11 +332,18 @@ def _score(self):
     @coroutine
     def post(self, *path_args, **path_kwargs):
         action = self.args.pop('_action', 'predict')
-        if action not in ACTIONS:
-            raise HTTPError(BAD_REQUEST, f'Action {action} not supported.')
-        res = yield gramex.service.threadpool.submit(getattr(self, f"_{action}"))
-        self.set_header('Content-Type', 'application/json')
-        self.write(json.dumps(res, indent=2, cls=CustomJSONEncoder))
+        if 'training_data' in self.args and action == 'train':
+            data = self.args['training_data']
+            training_results = yield gramex.service.threadpool.submit(
+                self._train, data=data)
+            self.set_header('Content-Type', 'application/json')
+            self.write(json.dumps(training_results, indent=2, cls=CustomJSONEncoder))
+        else:
+            if action not in ACTIONS:
+                raise HTTPError(BAD_REQUEST, f'Action {action} not supported.')
+            res = yield gramex.service.threadpool.submit(getattr(self, f"_{action}"))
+            self.set_header('Content-Type', 'application/json')
+            self.write(json.dumps(res, indent=2, cls=CustomJSONEncoder))
         super(MLHandler, self).post(*path_args, **path_kwargs)
 
     @coroutine
diff --git a/gramex/ml_api.py b/gramex/ml_api.py
index 98d3380b0..79aa6b6a9 100644
--- a/gramex/ml_api.py
+++ b/gramex/ml_api.py
@@ -14,7 +14,7 @@
 from sklearn.base import BaseEstimator
 from sklearn.compose import ColumnTransformer
 from sklearn.pipeline import Pipeline
-from sklearn.preprocessing import OneHotEncoder, StandardScaler
+from sklearn.preprocessing import OneHotEncoder, StandardScaler, LabelEncoder
 
 TRANSFORMS = {
     "include": [],
@@ -46,6 +46,9 @@
         "statsmodels.tsa.statespace.sarimax",
     ],
     "gramex.ml_api.HFTransformer": ["gramex.transformers"],
+    "gramex.ml_api.KerasApplication": [
+        "tensorflow.keras.applications"
+    ]
 }
@@ -204,7 +207,10 @@ class ModelStore(cache.JSONStore):
     def __init__(self, path, *args, **kwargs):
         _mkdir(path)
         self.data_store = op.join(path, "data.h5")
-        self.model_path = op.join(path, op.basename(path) + ".pkl")
+        if op.exists(op.join(path, op.basename(path) + ".pkl")):
+            self.model_path = op.join(path, op.basename(path) + ".pkl")
+        else:
+            self.model_path = path
         self.path = path
         super(ModelStore, self).__init__(op.join(path, "config.json"), *args, **kwargs)
 
@@ -404,12 +410,6 @@ def __init__(self, model, params=None, data=None, **kwargs):
         self.params = params
         self.kwargs = kwargs
 
-    @classmethod
-    def from_disk(cls, path, klass):
-        model = op.join(path, "model")
-        tokenizer = op.join(path, "tokenizer")
-        return cls(klass(model, tokenizer))
-
     def fit(
         self,
         X: Union[pd.DataFrame, np.ndarray],
@@ -426,3 +426,119 @@ def _predict(
     ):
         text = X["text"]
         return self.model.predict(text)
+
+    @classmethod
+    def from_disk(cls, path, klass):
+        # Load model from disk
+        model = op.join(path, "model")
+        tokenizer = op.join(path, "tokenizer")
op.join(path, "tokenizer") + return cls(klass(model, tokenizer)) + + +class KerasApplication(AbstractModel): + def __init__(self, model, params=None, data=None, **kwargs): + from tensorflow.keras.models import load_model + if params is None: + params = {} + self.params = params + self.kwargs = kwargs + self.preprocess_input = locate('preprocess_input', [model.__module__]) + self.decode_predictions = locate('decode_predictions', [model.__module__]) + self.model_path = kwargs['path'] + + try: + self.model = load_model(self.model_path) + self.custom_model = True + except OSError: + self.custom_model = False + self.model = model(include_top=True, + weights="imagenet", + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000) + + @classmethod + def from_disk(cls, path, klass): + # Load model from disk + return cls(klass, path=path) + + def predict(self, data=None, **kwargs): + from tensorflow.keras.preprocessing import image + + _, height, width, _ = self.model.input_shape + data = data.resize((height, width)) + x = image.img_to_array(data) + x = np.expand_dims(x, axis=0) + x = self.preprocess_input(x) + preds = self.model.predict(x) + if not self.custom_model: + # decode the results into a list of tuples (class, description, probability) + results = self.decode_predictions(preds)[0][0][1] + else: + # If the model is trained, provide it with the relevant class names + class_names = joblib.load(op.join(self.model_path, 'class_names.pkl')) + results = class_names.inverse_transform([np.argmax(preds)])[0] + return results + + def fit(self, data, model_path, *args, **kwargs): + import tensorflow as tf + from tensorflow.keras.optimizers import Adam + from tensorflow.keras.models import Sequential + from tensorflow.python.keras.layers import Dense, Flatten + import pathlib + + data_dir = pathlib.Path(data) + img_height, img_width = 224, 224 + batch_size = 32 + train_ds = tf.keras.preprocessing.image_dataset_from_directory( + data_dir, + validation_split=0.2, + subset="training", + seed=123, + image_size=(img_height, img_width), + batch_size=batch_size) + + val_ds = tf.keras.preprocessing.image_dataset_from_directory( + data_dir, + validation_split=0.2, + subset="validation", + seed=123, + image_size=(img_height, img_width), + batch_size=batch_size) + + class_names = train_ds.class_names + keras_model = Sequential() + pretrained_model = tf.keras.applications.ResNet50(include_top=False, + input_shape=(224, 224, 3), + pooling='avg', classes=len(class_names), + weights='imagenet') + for layer in pretrained_model.layers: + layer.trainable = False + + keras_model.add(pretrained_model) + keras_model.add(Flatten()) + keras_model.add(Dense(512, activation='relu')) + keras_model.add(Dense(len(class_names), activation='softmax')) + keras_model.compile(optimizer=Adam(lr=0.001), + loss='sparse_categorical_crossentropy', metrics=['accuracy']) + epochs = 1 + keras_model.fit( + train_ds, + validation_data=val_ds, + epochs=epochs + ) + le = LabelEncoder() + le.fit(class_names) + keras_model.save(model_path) + joblib.dump(le, op.join(self.model_path, 'class_names.pkl')) + return class_names + + def get_params(self, **kwargs): + super().get_params(**kwargs) + + def score(self, X, y_true, **kwargs): + super().score(X, y_true, **kwargs) + + def get_attributes(self): + super().get_attributes()