Skip to content

Firebase ML Kit TFLiteGCSModelSource.from_tflite_model implementation #346

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Sep 17, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 156 additions & 13 deletions firebase_admin/mlkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
deleting, publishing and unpublishing Firebase ML Kit models.
"""


import datetime
import numbers
import os
import re
Expand All @@ -30,13 +31,27 @@
from firebase_admin import _utils
from firebase_admin import exceptions

# pylint: disable=import-error,no-name-in-module
try:
from firebase_admin import storage
_GCS_ENABLED = True
except ImportError:
_GCS_ENABLED = False

# pylint: disable=import-error,no-name-in-module
try:
import tensorflow as tf
_TF_ENABLED = True
except ImportError:
_TF_ENABLED = False

_MLKIT_ATTRIBUTE = '_mlkit'
_MAX_PAGE_SIZE = 100
_MODEL_ID_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$')
_DISPLAY_NAME_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$')
_TAG_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$')
_GCS_TFLITE_URI_PATTERN = re.compile(r'^gs://[a-z0-9_.-]{3,63}/.+')
_GCS_TFLITE_URI_PATTERN = re.compile(
r'^gs://(?P<bucket_name>[a-z0-9_.-]{3,63})/(?P<blob_name>.+)$')
_RESOURCE_NAME_PATTERN = re.compile(
r'^projects/(?P<project_id>[^/]+)/models/(?P<model_id>[A-Za-z0-9_-]{1,60})$')
_OPERATION_NAME_PATTERN = re.compile(
Expand Down Expand Up @@ -301,16 +316,16 @@ def model_format(self, model_format):
self._model_format = model_format #Can be None
return self

def as_dict(self):
def as_dict(self, for_upload=False):
copy = dict(self._data)
if self._model_format:
copy.update(self._model_format.as_dict())
copy.update(self._model_format.as_dict(for_upload=for_upload))
return copy


class ModelFormat(object):
"""Abstract base class representing a Model Format such as TFLite."""
def as_dict(self):
def as_dict(self, for_upload=False):
raise NotImplementedError


Expand Down Expand Up @@ -364,22 +379,70 @@ def model_source(self, model_source):
def size_bytes(self):
return self._data.get('sizeBytes')

def as_dict(self):
def as_dict(self, for_upload=False):
copy = dict(self._data)
if self._model_source:
copy.update(self._model_source.as_dict())
copy.update(self._model_source.as_dict(for_upload=for_upload))
return {'tfliteModel': copy}


class TFLiteModelSource(object):
"""Abstract base class representing a model source for TFLite format models."""
def as_dict(self):
def as_dict(self, for_upload=False):
raise NotImplementedError


class _CloudStorageClient(object):
    """Cloud Storage helper that uploads TFLite model files and signs their GCS URIs."""

    GCS_URI = 'gs://{0}/{1}'
    BLOB_NAME = 'Firebase/MLKit/Models/{0}'

    @staticmethod
    def _assert_gcs_enabled():
        """Raises ImportError if the firebase_admin storage module could not be imported."""
        if not _GCS_ENABLED:
            raise ImportError('Failed to import the Cloud Storage library for Python. Make sure '
                              'to install the "google-cloud-storage" module.')

    @staticmethod
    def _parse_gcs_tflite_uri(uri):
        """Splits a gs:// URI into its bucket name and blob name.

        Args:
            uri: A string of the form ``gs://<bucket_name>/<blob_name>``.

        Returns:
            tuple: A ``(bucket_name, blob_name)`` pair.

        Raises:
            ValueError: If the URI does not match the expected pattern.
        """
        # GCS Bucket naming rules are complex. The regex is not comprehensive.
        # See https://cloud.google.com/storage/docs/naming for full details.
        matcher = _GCS_TFLITE_URI_PATTERN.match(uri)
        if not matcher:
            raise ValueError('GCS TFLite URI format is invalid.')
        return matcher.group('bucket_name'), matcher.group('blob_name')

    @staticmethod
    def upload(bucket_name, model_file_name, app):
        """Uploads a local model file to a bucket and returns the resulting gs:// URI.

        Args:
            bucket_name: Name handed to ``storage.bucket`` to select the target bucket.
            model_file_name: Name or path of the local file to upload.
            app: A Firebase app instance handed to ``storage.bucket``.

        Returns:
            string: The ``gs://`` URI of the uploaded blob.

        Raises:
            ImportError: If the Cloud Storage library has not been installed.
        """
        _CloudStorageClient._assert_gcs_enabled()
        bucket = storage.bucket(bucket_name, app=app)
        # model_file_name may be a path; only its final component belongs in the blob name,
        # otherwise local directory names would leak into the bucket layout.
        file_name = os.path.basename(model_file_name)
        blob_name = _CloudStorageClient.BLOB_NAME.format(file_name)
        blob = bucket.blob(blob_name)
        blob.upload_from_filename(model_file_name)
        return _CloudStorageClient.GCS_URI.format(bucket.name, blob_name)

    @staticmethod
    def sign_uri(gcs_tflite_uri, app):
        """Makes the gcs_tflite_uri readable for GET for 10 minutes via signed_uri.

        Args:
            gcs_tflite_uri: A ``gs://<bucket_name>/<blob_name>`` URI.
            app: A Firebase app instance handed to ``storage.bucket``.

        Returns:
            The value returned by ``blob.generate_signed_url`` for a v4 signed GET URL.

        Raises:
            ImportError: If the Cloud Storage library has not been installed.
            ValueError: If the URI does not match the expected pattern.
        """
        _CloudStorageClient._assert_gcs_enabled()
        bucket_name, blob_name = _CloudStorageClient._parse_gcs_tflite_uri(gcs_tflite_uri)
        bucket = storage.bucket(bucket_name, app=app)
        blob = bucket.blob(blob_name)
        return blob.generate_signed_url(
            version='v4',
            expiration=datetime.timedelta(minutes=10),
            method='GET'
        )


class TFLiteGCSModelSource(TFLiteModelSource):
"""TFLite model source representing a tflite model file stored in GCS."""
def __init__(self, gcs_tflite_uri):

_STORAGE_CLIENT = _CloudStorageClient()

def __init__(self, gcs_tflite_uri, app=None):
self._app = app
self._gcs_tflite_uri = _validate_gcs_tflite_uri(gcs_tflite_uri)

def __eq__(self, other):
Expand All @@ -391,6 +454,81 @@ def __eq__(self, other):
def __ne__(self, other):
return not self.__eq__(other)

@classmethod
def from_tflite_model_file(cls, model_file_name, bucket_name=None, app=None):
    """Uploads a TFLite model file to Cloud Storage and wraps the resulting URI.

    Args:
        model_file_name: The name of the model file.
        bucket_name: The name of an existing bucket. None to use the default bucket configured
            in the app.
        app: A Firebase app instance (or None to use the default app).

    Returns:
        TFLiteGCSModelSource: A source pointing at the uploaded model file.

    Raises:
        ImportError: If the Cloud Storage Library has not been installed.
    """
    storage_client = TFLiteGCSModelSource._STORAGE_CLIENT
    uploaded_uri = storage_client.upload(bucket_name, model_file_name, app)
    return TFLiteGCSModelSource(gcs_tflite_uri=uploaded_uri, app=app)

@staticmethod
def _assert_tf_version_1_enabled():
    """Checks that tensorflow was imported and that it is a 1.x release.

    Raises:
        ImportError: If tensorflow could not be imported, or its version is not 1.x.
    """
    if not _TF_ENABLED:
        raise ImportError('Failed to import the tensorflow library for Python. Make sure '
                          'to install the tensorflow module.')
    # tf.VERSION is not defined by tensorflow 2.x, so reading it there raises AttributeError
    # and hides the ImportError below. tf.__version__ is defined by both 1.x and 2.x.
    tf_version = tf.__version__
    if not tf_version.startswith('1.'):
        raise ImportError('Expected tensorflow version 1.x, but found {0}'.format(tf_version))

@classmethod
def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None):
    """Creates a TensorFlow Lite model from the saved model, and uploads the model to GCS.

    The converted model is written to ``firebase_mlkit_model.tflite`` (a relative path)
    before it is uploaded.

    Args:
        saved_model_dir: The saved model directory.
        bucket_name: The name of an existing bucket. None to use the default bucket configured
            in the app.
        app: Optional. A Firebase app instance (or None to use the default app)

    Returns:
        TFLiteGCSModelSource: The source created from the saved_model_dir

    Raises:
        ImportError: If the TensorFlow or Cloud Storage Libraries have not been installed.
    """
    TFLiteGCSModelSource._assert_tf_version_1_enabled()
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    tflite_model = converter.convert()
    # Close the file explicitly so the contents are fully written before the upload reads it.
    with open('firebase_mlkit_model.tflite', 'wb') as model_file:
        model_file.write(tflite_model)
    return TFLiteGCSModelSource.from_tflite_model_file(
        'firebase_mlkit_model.tflite', bucket_name, app)

@classmethod
def from_keras_model(cls, keras_model, bucket_name=None, app=None):
    """Creates a TensorFlow Lite model from the keras model, and uploads the model to GCS.

    The keras model is first saved to ``keras_model.h5`` and the converted model is written
    to ``firebase_mlkit_model.tflite`` (both relative paths) before the upload.

    Args:
        keras_model: A tf.keras model.
        bucket_name: The name of an existing bucket. None to use the default bucket configured
            in the app.
        app: Optional. A Firebase app instance (or None to use the default app)

    Returns:
        TFLiteGCSModelSource: The source created from the keras_model

    Raises:
        ImportError: If the TensorFlow or Cloud Storage Libraries have not been installed.
    """
    TFLiteGCSModelSource._assert_tf_version_1_enabled()
    keras_file = 'keras_model.h5'
    tf.keras.models.save_model(keras_model, keras_file)
    converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file)
    tflite_model = converter.convert()
    # Close the file explicitly so the contents are fully written before the upload reads it.
    with open('firebase_mlkit_model.tflite', 'wb') as model_file:
        model_file.write(tflite_model)
    return TFLiteGCSModelSource.from_tflite_model_file(
        'firebase_mlkit_model.tflite', bucket_name, app)

@property
def gcs_tflite_uri(self):
return self._gcs_tflite_uri
Expand All @@ -399,10 +537,15 @@ def gcs_tflite_uri(self):
def gcs_tflite_uri(self, gcs_tflite_uri):
self._gcs_tflite_uri = _validate_gcs_tflite_uri(gcs_tflite_uri)

def as_dict(self):
return {"gcsTfliteUri": self._gcs_tflite_uri}
def _get_signed_gcs_tflite_uri(self):
"""Signs the GCS uri, so the model file can be uploaded to Firebase ML Kit and verified."""
return TFLiteGCSModelSource._STORAGE_CLIENT.sign_uri(self._gcs_tflite_uri, self._app)

def as_dict(self, for_upload=False):
if for_upload:
return {'gcsTfliteUri': self._get_signed_gcs_tflite_uri()}

#TODO(ifielker): implement from_saved_model etc.
return {'gcsTfliteUri': self._gcs_tflite_uri}


class ListModelsPage(object):
Expand Down Expand Up @@ -671,13 +814,13 @@ def create_model(self, model):
_validate_model(model)
try:
return self.handle_operation(
self._client.body('post', url='models', json=model.as_dict()))
self._client.body('post', url='models', json=model.as_dict(for_upload=True)))
except requests.exceptions.RequestException as error:
raise _utils.handle_platform_error_from_requests(error)

def update_model(self, model, update_mask=None):
_validate_model(model, update_mask)
data = {'model': model.as_dict()}
data = {'model': model.as_dict(for_upload=True)}
if update_mask is not None:
data['updateMask'] = update_mask
try:
Expand Down
50 changes: 49 additions & 1 deletion tests/test_mlkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@
}
}

GCS_TFLITE_URI = 'gs://my_bucket/mymodel.tflite'
GCS_BUCKET_NAME = 'my_bucket'
GCS_BLOB_NAME = 'mymodel.tflite'
GCS_TFLITE_URI = 'gs://{0}/{1}'.format(GCS_BUCKET_NAME, GCS_BLOB_NAME)
GCS_TFLITE_URI_JSON = {'gcsTfliteUri': GCS_TFLITE_URI}
GCS_TFLITE_MODEL_SOURCE = mlkit.TFLiteGCSModelSource(GCS_TFLITE_URI)
TFLITE_FORMAT_JSON = {
Expand All @@ -112,6 +114,10 @@
}
TFLITE_FORMAT = mlkit.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON)

GCS_TFLITE_SIGNED_URI_PATTERN = (
'https://storage.googleapis.com/{0}/{1}?X-Goog-Algorithm=GOOG4-RSA-SHA256&foo')
GCS_TFLITE_SIGNED_URI = GCS_TFLITE_SIGNED_URI_PATTERN.format(GCS_BUCKET_NAME, GCS_BLOB_NAME)

GCS_TFLITE_URI_2 = 'gs://my_bucket/mymodel2.tflite'
GCS_TFLITE_URI_JSON_2 = {'gcsTfliteUri': GCS_TFLITE_URI_2}
GCS_TFLITE_MODEL_SOURCE_2 = mlkit.TFLiteGCSModelSource(GCS_TFLITE_URI_2)
Expand Down Expand Up @@ -325,6 +331,18 @@ def instrument_mlkit_service(status=200, payload=None, operations=False, app=Non
session_url, adapter(payload, status, recorder))
return recorder

class _TestStorageClient(object):
    """Stand-in for mlkit._CloudStorageClient that only formats URIs."""

    @staticmethod
    def upload(bucket_name, model_file_name, app):
        """Returns the gs:// URI built from the bucket and model file names."""
        del app # unused variable
        real_client = mlkit._CloudStorageClient
        return real_client.GCS_URI.format(
            bucket_name, real_client.BLOB_NAME.format(model_file_name))

    @staticmethod
    def sign_uri(gcs_tflite_uri, app):
        """Returns the fake signed URI for the bucket and blob named in gcs_tflite_uri."""
        del app # unused variable
        bucket_and_blob = mlkit._CloudStorageClient._parse_gcs_tflite_uri(gcs_tflite_uri)
        return GCS_TFLITE_SIGNED_URI_PATTERN.format(*bucket_and_blob)

class TestModel(object):
"""Tests mlkit.Model class."""
Expand All @@ -333,6 +351,7 @@ def setup_class(cls):
cred = testutils.MockCredential()
firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID})
mlkit._MLKitService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test
mlkit.TFLiteGCSModelSource._STORAGE_CLIENT = _TestStorageClient()

@classmethod
def teardown_class(cls):
Expand Down Expand Up @@ -404,6 +423,13 @@ def test_model_format_source_creation(self):
}
}

def test_source_creation_from_tflite_file(self):
    """A source created from a model file reports the URI returned by the upload."""
    expected = {'gcsTfliteUri': 'gs://my_bucket/Firebase/MLKit/Models/my_model.tflite'}
    source = mlkit.TFLiteGCSModelSource.from_tflite_model_file(
        'my_model.tflite', 'my_bucket')
    assert source.as_dict() == expected

def test_model_source_setters(self):
model_source = mlkit.TFLiteGCSModelSource(GCS_TFLITE_URI)
model_source.gcs_tflite_uri = GCS_TFLITE_URI_2
Expand All @@ -420,6 +446,27 @@ def test_model_format_setters(self):
}
}

def test_model_as_dict_for_upload(self):
    """as_dict(for_upload=True) reports the signed URI in place of the gs:// URI."""
    expected = {
        'displayName': DISPLAY_NAME_1,
        'tfliteModel': {
            'gcsTfliteUri': GCS_TFLITE_SIGNED_URI
        }
    }
    model = mlkit.Model(
        display_name=DISPLAY_NAME_1,
        model_format=mlkit.TFLiteFormat(
            model_source=mlkit.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI)))
    assert model.as_dict(for_upload=True) == expected

@pytest.mark.parametrize('helper_func', [
    mlkit.TFLiteGCSModelSource.from_keras_model,
    mlkit.TFLiteGCSModelSource.from_saved_model
])
def test_tf_not_enabled(self, helper_func, monkeypatch):
    """The tensorflow-based helpers raise ImportError when tensorflow is unavailable."""
    # Force the flag off for reliability, even when tensorflow is installed. monkeypatch
    # restores the original value afterwards so the override does not leak into other tests.
    monkeypatch.setattr(mlkit, '_TF_ENABLED', False)
    with pytest.raises(ImportError) as excinfo:
        helper_func(None)
    check_error(excinfo, ImportError)

@pytest.mark.parametrize('display_name, exc_type', [
('', ValueError),
('&_*#@:/?', ValueError),
Expand Down Expand Up @@ -803,6 +850,7 @@ def test_rpc_error(self, publish_function):
)
assert len(create_recorder) == 1


class TestGetModel(object):
"""Tests mlkit.get_model."""
@classmethod
Expand Down