Skip to content

Implementation of Model, ModelFormat, TFLiteModelSource and subclasses #335

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Aug 29, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
264 changes: 250 additions & 14 deletions firebase_admin/mlkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
deleting, publishing and unpublishing Firebase ML Kit models.
"""

import datetime
import numbers
import re
import requests
import six
Expand All @@ -28,6 +30,12 @@

_MLKIT_ATTRIBUTE = '_mlkit'
_MAX_PAGE_SIZE = 100
_MODEL_ID_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$')
_DISPLAY_NAME_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$')
_TAG_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$')
_GCS_TFLITE_URI_PATTERN = re.compile(r'^gs://[a-z0-9_.-]{3,63}/.+')
_RESOURCE_NAME_PATTERN = re.compile(
r'^projects/(?P<project_id>[^/]+)/models/(?P<model_id>[A-Za-z0-9_-]{1,60})$')


def _get_mlkit_service(app):
Expand All @@ -47,7 +55,7 @@ def _get_mlkit_service(app):

def get_model(model_id, app=None):
mlkit_service = _get_mlkit_service(app)
return Model(mlkit_service.get_model(model_id))
return Model.from_dict(mlkit_service.get_model(model_id))


def list_models(list_filter=None, page_size=None, page_token=None, app=None):
Expand All @@ -62,29 +70,222 @@ def delete_model(model_id, app=None):


class Model(object):
"""A Firebase ML Kit Model object."""
def __init__(self, data):
"""Created from a data dictionary."""
self._data = data
"""A Firebase ML Kit Model object.

Args:
display_name: The display name of your model - used to identify your model in code.
tags: Optional list of strings associated with your model. Can be used in list queries.
model_format: A subclass of ModelFormat. (e.g. TFLiteFormat) Specifies the model details.
"""
def __init__(self, display_name=None, tags=None, model_format=None):
self._data = {}
self._model_format = None

if display_name is not None:
self.display_name = display_name
if tags is not None:
self.tags = tags
if model_format is not None:
self.model_format = model_format

@classmethod
def from_dict(cls, data):
data_copy = dict(data)
tflite_format = None
tflite_format_data = data_copy.pop('tfliteModel', None)
if tflite_format_data:
tflite_format = TFLiteFormat.from_dict(tflite_format_data)
model = Model(model_format=tflite_format)
model._data = data_copy # pylint: disable=protected-access
return model

def __eq__(self, other):
if isinstance(other, self.__class__):
return self._data == other._data # pylint: disable=protected-access
# pylint: disable=protected-access
return self._data == other._data and self._model_format == other._model_format
else:
return False

def __ne__(self, other):
return not self.__eq__(other)

@property
def name(self):
return self._data['name']
def model_id(self):
if not self._data.get('name'):
return None
_, model_id = _validate_and_parse_name(self._data.get('name'))
return model_id

@property
def display_name(self):
return self._data['displayName']
return self._data.get('displayName')

@display_name.setter
def display_name(self, display_name):
self._data['displayName'] = _validate_display_name(display_name)
return self

@property
def create_time(self):
"""Returns the creation timestamp"""
seconds = self._data.get('createTime', {}).get('seconds')
if not isinstance(seconds, numbers.Number):
return None

return datetime.datetime.fromtimestamp(float(seconds))

@property
def update_time(self):
"""Returns the last update timestamp"""
seconds = self._data.get('updateTime', {}).get('seconds')
if not isinstance(seconds, numbers.Number):
return None

#TODO(ifielker): define the rest of the Model properties etc
return datetime.datetime.fromtimestamp(float(seconds))

@property
def validation_error(self):
return self._data.get('state', {}).get('validationError', {}).get('message')

@property
def published(self):
return bool(self._data.get('state', {}).get('published'))

@property
def etag(self):
return self._data.get('etag')

@property
def model_hash(self):
return self._data.get('modelHash')

@property
def tags(self):
return self._data.get('tags')

@tags.setter
def tags(self, tags):
self._data['tags'] = _validate_tags(tags)
return self

@property
def locked(self):
return bool(self._data.get('activeOperations') and
len(self._data.get('activeOperations')) > 0)

@property
def model_format(self):
return self._model_format

@model_format.setter
def model_format(self, model_format):
if model_format is not None:
_validate_model_format(model_format)
self._model_format = model_format #Can be None
return self

def as_dict(self):
copy = dict(self._data)
if self._model_format:
copy.update(self._model_format.as_dict())
return copy


class ModelFormat(object):
"""Abstract base class representing a Model Format such as TFLite."""
def as_dict(self):
raise NotImplementedError


class TFLiteFormat(ModelFormat):
"""Model format representing a TFLite model.

Args:
model_source: A TFLiteModelSource sub class. Specifies the details of the model source.
"""
def __init__(self, model_source=None):
self._data = {}
self._model_source = None

if model_source is not None:
self.model_source = model_source

@classmethod
def from_dict(cls, data):
data_copy = dict(data)
model_source = None
gcs_tflite_uri = data_copy.pop('gcsTfliteUri', None)
if gcs_tflite_uri:
model_source = TFLiteGCSModelSource(gcs_tflite_uri=gcs_tflite_uri)
tflite_format = TFLiteFormat(model_source=model_source)
tflite_format._data = data_copy # pylint: disable=protected-access
return tflite_format


def __eq__(self, other):
if isinstance(other, self.__class__):
# pylint: disable=protected-access
return self._data == other._data and self._model_source == other._model_source
else:
return False

def __ne__(self, other):
return not self.__eq__(other)

@property
def model_source(self):
return self._model_source

@model_source.setter
def model_source(self, model_source):
if model_source is not None:
if not isinstance(model_source, TFLiteModelSource):
raise TypeError('Model source must be a TFLiteModelSource object.')
self._model_source = model_source # Can be None

@property
def size_bytes(self):
return self._data.get('sizeBytes')

def as_dict(self):
copy = dict(self._data)
if self._model_source:
copy.update(self._model_source.as_dict())
return {'tfliteModel': copy}


class TFLiteModelSource(object):
"""Abstract base class representing a model source for TFLite format models."""
def as_dict(self):
raise NotImplementedError


class TFLiteGCSModelSource(TFLiteModelSource):
"""TFLite model source representing a tflite model file stored in GCS."""
def __init__(self, gcs_tflite_uri):
self._gcs_tflite_uri = _validate_gcs_tflite_uri(gcs_tflite_uri)

def __eq__(self, other):
if isinstance(other, self.__class__):
return self._gcs_tflite_uri == other._gcs_tflite_uri # pylint: disable=protected-access
else:
return False

def __ne__(self, other):
return not self.__eq__(other)

@property
def gcs_tflite_uri(self):
return self._gcs_tflite_uri

@gcs_tflite_uri.setter
def gcs_tflite_uri(self, gcs_tflite_uri):
self._gcs_tflite_uri = _validate_gcs_tflite_uri(gcs_tflite_uri)

def as_dict(self):
return {"gcsTfliteUri": self._gcs_tflite_uri}

#TODO(ifielker): implement from_saved_model etc.


class ListModelsPage(object):
Expand All @@ -105,7 +306,7 @@ def __init__(self, list_models_func, list_filter, page_size, page_token):
@property
def models(self):
"""A list of Models from this page."""
return [Model(model) for model in self._list_response.get('models', [])]
return [Model.from_dict(model) for model in self._list_response.get('models', [])]

@property
def list_filter(self):
Expand Down Expand Up @@ -179,13 +380,48 @@ def __iter__(self):
return self


def _validate_and_parse_name(name):
# The resource name is added automatically from API call responses.
# The only way it could be invalid is if someone tries to
# create a model from a dictionary manually and does it incorrectly.
matcher = _RESOURCE_NAME_PATTERN.match(name)
if not matcher:
raise ValueError('Model resource name format is invalid.')
return matcher.group('project_id'), matcher.group('model_id')


def _validate_model_id(model_id):
if not isinstance(model_id, six.string_types):
raise TypeError('Model ID must be a string.')
if not re.match(r'^[A-Za-z0-9_-]{1,60}$', model_id):
if not _MODEL_ID_PATTERN.match(model_id):
raise ValueError('Model ID format is invalid.')


def _validate_display_name(display_name):
if not _DISPLAY_NAME_PATTERN.match(display_name):
raise ValueError('Display name format is invalid.')
return display_name


def _validate_tags(tags):
if not isinstance(tags, list) or not \
all(isinstance(tag, six.string_types) for tag in tags):
raise TypeError('Tags must be a list of strings.')
if not all(_TAG_PATTERN.match(tag) for tag in tags):
raise ValueError('Tag format is invalid.')
return tags


def _validate_gcs_tflite_uri(uri):
# GCS Bucket naming rules are complex. The regex is not comprehensive.
# See https://cloud.google.com/storage/docs/naming for full details.
if not _GCS_TFLITE_URI_PATTERN.match(uri):
raise ValueError('GCS TFLite URI format is invalid.')
return uri

def _validate_model_format(model_format):
if not isinstance(model_format, ModelFormat):
raise TypeError('Model format must be a ModelFormat object.')
return model_format

def _validate_list_filter(list_filter):
if list_filter is not None:
if not isinstance(list_filter, six.string_types):
Expand Down
Loading