-
Notifications
You must be signed in to change notification settings - Fork 0
/
keras_tfserving.py
37 lines (31 loc) · 1.68 KB
/
keras_tfserving.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import tensorflow as tf
import tensorflow.keras.backend as K
import tensorflow.keras.models as models
import os
def save_model_to_serving(model, export_version, export_path='prod_models'):
print(model.input, model.output)
signature = tf.saved_model.signature_def_utils.predict_signature_def(
inputs={'image': model.input}, outputs={'label': model.output})
export_path = os.path.join(
tf.compat.as_bytes(export_path),
tf.compat.as_bytes(str(export_version)))
builder = tf.saved_model.builder.SavedModelBuilder(export_path)
legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
builder.add_meta_graph_and_variables(
sess=K.get_session(),
tags=[tf.saved_model.tag_constants.SERVING],
signature_def_map={
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature,
},
legacy_init_op=legacy_init_op)
builder.save()
if __name__ == '__main__':
model = models.load_model('../weights/my_modelnew.h5')
model.summary()
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)
model.evaluate(x_test, y_test)
export_path = "export_model"
save_model_to_serving(model, "1", export_path)