Skip to content

Adding File naming capability to from_saved_model and from_keras_model. #375

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Dec 13, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions firebase_admin/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,11 +524,13 @@ def _tf_convert_from_keras_model(keras_model):
return converter.convert()

@classmethod
def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None):
def from_saved_model(cls, saved_model_dir, model_file_name='firebase_ml_model.tflite',
bucket_name=None, app=None):
"""Creates a Tensor Flow Lite model from the saved model, and uploads the model to GCS.

Args:
saved_model_dir: The saved model directory.
model_file_name: The name that the tflite model will be saved as in Cloud Storage.
bucket_name: The name of an existing bucket. None to use the default bucket configured
in the app.
app: Optional. A Firebase app instance (or None to use the default app)
Expand All @@ -541,16 +543,18 @@ def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None):
"""
TFLiteGCSModelSource._assert_tf_enabled()
tflite_model = TFLiteGCSModelSource._tf_convert_from_saved_model(saved_model_dir)
open('firebase_ml_model.tflite', 'wb').write(tflite_model)
return TFLiteGCSModelSource.from_tflite_model_file(
'firebase_ml_model.tflite', bucket_name, app)
with open(model_file_name, 'wb') as model_file:
model_file.write(tflite_model)
return TFLiteGCSModelSource.from_tflite_model_file(model_file_name, bucket_name, app)

@classmethod
def from_keras_model(cls, keras_model, bucket_name=None, app=None):
def from_keras_model(cls, keras_model, model_file_name='firebase_ml_model.tflite',
bucket_name=None, app=None):
"""Creates a Tensor Flow Lite model from the keras model, and uploads the model to GCS.

Args:
keras_model: A tf.keras model.
model_file_name: The name that the tflite model will be saved as in Cloud Storage.
bucket_name: The name of an existing bucket. None to use the default bucket configured
in the app.
app: Optional. A Firebase app instance (or None to use the default app)
Expand All @@ -563,9 +567,9 @@ def from_keras_model(cls, keras_model, bucket_name=None, app=None):
"""
TFLiteGCSModelSource._assert_tf_enabled()
tflite_model = TFLiteGCSModelSource._tf_convert_from_keras_model(keras_model)
open('firebase_ml_model.tflite', 'wb').write(tflite_model)
return TFLiteGCSModelSource.from_tflite_model_file(
'firebase_ml_model.tflite', bucket_name, app)
with open(model_file_name, 'wb') as model_file:
model_file.write(tflite_model)
return TFLiteGCSModelSource.from_tflite_model_file(model_file_name, bucket_name, app)

@property
def gcs_tflite_uri(self):
Expand Down