Skip to content

Commit e5cf14a

Browse files
authored
Firebase ML Kit Create Model API implementation (#337)
* create model plus long running operation handling * Model.wait_for_unlocked
1 parent 4618b1e commit e5cf14a

File tree

3 files changed

+494
-40
lines changed

3 files changed

+494
-40
lines changed

firebase_admin/_utils.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,27 @@ def handle_platform_error_from_requests(error, handle_func=None):
106106
return exc if exc else _handle_func_requests(error, message, error_dict)
107107

108108

109+
def handle_operation_error(error):
110+
"""Constructs a ``FirebaseError`` from the given operation error.
111+
112+
Args:
113+
error: An error returned by a long running operation.
114+
115+
Returns:
116+
FirebaseError: A ``FirebaseError`` that can be raised to the user code.
117+
"""
118+
if not isinstance(error, dict):
119+
return exceptions.UnknownError(
120+
message='Unknown error while making a remote service call: {0}'.format(error),
121+
cause=error)
122+
123+
status_code = error.get('code')
124+
message = error.get('message')
125+
error_code = _http_status_to_error_code(status_code)
126+
err_type = _error_code_to_exception_type(error_code)
127+
return err_type(message=message)
128+
129+
109130
def _handle_func_requests(error, message, error_dict):
110131
"""Constructs a ``FirebaseError`` from the given GCP error.
111132

firebase_admin/mlkit.py

Lines changed: 180 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,14 @@
2121
import datetime
2222
import numbers
2323
import re
24+
import time
2425
import requests
2526
import six
2627

28+
2729
from firebase_admin import _http_client
2830
from firebase_admin import _utils
31+
from firebase_admin import exceptions
2932

3033

3134
_MLKIT_ATTRIBUTE = '_mlkit'
@@ -36,6 +39,9 @@
3639
_GCS_TFLITE_URI_PATTERN = re.compile(r'^gs://[a-z0-9_.-]{3,63}/.+')
3740
_RESOURCE_NAME_PATTERN = re.compile(
3841
r'^projects/(?P<project_id>[^/]+)/models/(?P<model_id>[A-Za-z0-9_-]{1,60})$')
42+
_OPERATION_NAME_PATTERN = re.compile(
43+
r'^operations/project/(?P<project_id>[^/]+)/model/(?P<model_id>[A-Za-z0-9_-]{1,60})' +
44+
r'/operation/[^/]+$')
3945

4046

4147
def _get_mlkit_service(app):
@@ -53,18 +59,60 @@ def _get_mlkit_service(app):
5359
return _utils.get_app_service(app, _MLKIT_ATTRIBUTE, _MLKitService)
5460

5561

62+
def create_model(model, app=None):
63+
"""Creates a model in Firebase ML Kit.
64+
65+
Args:
66+
model: An mlkit.Model to create.
67+
app: A Firebase app instance (or None to use the default app).
68+
69+
Returns:
70+
Model: The model that was created in Firebase ML Kit.
71+
"""
72+
mlkit_service = _get_mlkit_service(app)
73+
return Model.from_dict(mlkit_service.create_model(model), app=app)
74+
75+
5676
def get_model(model_id, app=None):
77+
"""Gets a model from Firebase ML Kit.
78+
79+
Args:
80+
model_id: The id of the model to get.
81+
app: A Firebase app instance (or None to use the default app).
82+
83+
Returns:
84+
Model: The requested model.
85+
"""
5786
mlkit_service = _get_mlkit_service(app)
58-
return Model.from_dict(mlkit_service.get_model(model_id))
87+
return Model.from_dict(mlkit_service.get_model(model_id), app=app)
5988

6089

6190
def list_models(list_filter=None, page_size=None, page_token=None, app=None):
91+
"""Lists models from Firebase ML Kit.
92+
93+
Args:
94+
list_filter: a list filter string such as "tags:'tag_1'". None will return all models.
95+
page_size: A number between 1 and 100 inclusive that specifies the maximum
96+
number of models to return per page. None for default.
97+
page_token: A next page token returned from a previous page of results. None
98+
for first page of results.
99+
app: A Firebase app instance (or None to use the default app).
100+
101+
Returns:
102+
ListModelsPage: A (filtered) list of models.
103+
"""
62104
mlkit_service = _get_mlkit_service(app)
63105
return ListModelsPage(
64-
mlkit_service.list_models, list_filter, page_size, page_token)
106+
mlkit_service.list_models, list_filter, page_size, page_token, app=app)
65107

66108

67109
def delete_model(model_id, app=None):
110+
"""Deletes a model from Firebase ML Kit.
111+
112+
Args:
113+
model_id: The id of the model you wish to delete.
114+
app: A Firebase app instance (or None to use the default app).
115+
"""
68116
mlkit_service = _get_mlkit_service(app)
69117
mlkit_service.delete_model(model_id)
70118

@@ -78,6 +126,7 @@ class Model(object):
78126
model_format: A subclass of ModelFormat. (e.g. TFLiteFormat) Specifies the model details.
79127
"""
80128
def __init__(self, display_name=None, tags=None, model_format=None):
129+
self._app = None # Only needed for wait_for_unlo
81130
self._data = {}
82131
self._model_format = None
83132

@@ -89,16 +138,22 @@ def __init__(self, display_name=None, tags=None, model_format=None):
89138
self.model_format = model_format
90139

91140
@classmethod
92-
def from_dict(cls, data):
141+
def from_dict(cls, data, app=None):
93142
data_copy = dict(data)
94143
tflite_format = None
95144
tflite_format_data = data_copy.pop('tfliteModel', None)
96145
if tflite_format_data:
97146
tflite_format = TFLiteFormat.from_dict(tflite_format_data)
98147
model = Model(model_format=tflite_format)
99148
model._data = data_copy # pylint: disable=protected-access
149+
model._app = app # pylint: disable=protected-access
100150
return model
101151

152+
def _update_from_dict(self, data):
153+
copy = Model.from_dict(data)
154+
self.model_format = copy.model_format
155+
self._data = copy._data # pylint: disable=protected-access
156+
102157
def __eq__(self, other):
103158
if isinstance(other, self.__class__):
104159
# pylint: disable=protected-access
@@ -173,6 +228,26 @@ def locked(self):
173228
return bool(self._data.get('activeOperations') and
174229
len(self._data.get('activeOperations')) > 0)
175230

231+
def wait_for_unlocked(self, max_time_seconds=None):
232+
"""Waits for the model to be unlocked. (All active operations complete)
233+
234+
Args:
235+
max_time_seconds: The maximum number of seconds to wait for the model to unlock.
236+
(None for no limit)
237+
238+
Raises:
239+
exceptions.DeadlineExceeded: If max_time_seconds passed and the model is still locked.
240+
"""
241+
if not self.locked:
242+
return
243+
mlkit_service = _get_mlkit_service(self._app)
244+
op_name = self._data.get('activeOperations')[0].get('name')
245+
model_dict = mlkit_service.handle_operation(
246+
mlkit_service.get_operation(op_name),
247+
wait_for_operation=True,
248+
max_time_seconds=max_time_seconds)
249+
self._update_from_dict(model_dict)
250+
176251
@property
177252
def model_format(self):
178253
return self._model_format
@@ -296,17 +371,20 @@ class ListModelsPage(object):
296371
``iterate_all()`` can be used to iterate through all the models in the
297372
Firebase project starting from this page.
298373
"""
299-
def __init__(self, list_models_func, list_filter, page_size, page_token):
374+
def __init__(self, list_models_func, list_filter, page_size, page_token, app):
300375
self._list_models_func = list_models_func
301376
self._list_filter = list_filter
302377
self._page_size = page_size
303378
self._page_token = page_token
379+
self._app = app
304380
self._list_response = list_models_func(list_filter, page_size, page_token)
305381

306382
@property
307383
def models(self):
308384
"""A list of Models from this page."""
309-
return [Model.from_dict(model) for model in self._list_response.get('models', [])]
385+
return [
386+
Model.from_dict(model, app=self._app) for model in self._list_response.get('models', [])
387+
]
310388

311389
@property
312390
def list_filter(self):
@@ -333,7 +411,8 @@ def get_next_page(self):
333411
self._list_models_func,
334412
self._list_filter,
335413
self._page_size,
336-
self.next_page_token)
414+
self.next_page_token,
415+
self._app)
337416
return None
338417

339418
def iterate_all(self):
@@ -390,11 +469,25 @@ def _validate_and_parse_name(name):
390469
return matcher.group('project_id'), matcher.group('model_id')
391470

392471

472+
def _validate_model(model):
473+
if not isinstance(model, Model):
474+
raise TypeError('Model must be an mlkit.Model.')
475+
if not model.display_name:
476+
raise ValueError('Model must have a display name.')
477+
478+
393479
def _validate_model_id(model_id):
394480
if not _MODEL_ID_PATTERN.match(model_id):
395481
raise ValueError('Model ID format is invalid.')
396482

397483

484+
def _validate_and_parse_operation_name(op_name):
485+
matcher = _OPERATION_NAME_PATTERN.match(op_name)
486+
if not matcher:
487+
raise ValueError('Operation name format is invalid.')
488+
return matcher.group('project_id'), matcher.group('model_id')
489+
490+
398491
def _validate_display_name(display_name):
399492
if not _DISPLAY_NAME_PATTERN.match(display_name):
400493
raise ValueError('Display name format is invalid.')
@@ -417,11 +510,13 @@ def _validate_gcs_tflite_uri(uri):
417510
raise ValueError('GCS TFLite URI format is invalid.')
418511
return uri
419512

513+
420514
def _validate_model_format(model_format):
421515
if not isinstance(model_format, ModelFormat):
422516
raise TypeError('Model format must be a ModelFormat object.')
423517
return model_format
424518

519+
425520
def _validate_list_filter(list_filter):
426521
if list_filter is not None:
427522
if not isinstance(list_filter, six.string_types):
@@ -448,6 +543,9 @@ class _MLKitService(object):
448543
"""Firebase MLKit service."""
449544

450545
PROJECT_URL = 'https://mlkit.googleapis.com/v1beta1/projects/{0}/'
546+
OPERATION_URL = 'https://mlkit.googleapis.com/v1beta1/'
547+
POLL_EXPONENTIAL_BACKOFF_FACTOR = 1.5
548+
POLL_BASE_WAIT_TIME_SECONDS = 3
451549

452550
def __init__(self, app):
453551
project_id = app.project_id
@@ -459,6 +557,82 @@ def __init__(self, app):
459557
self._client = _http_client.JsonHttpClient(
460558
credential=app.credential.get_credential(),
461559
base_url=self._project_url)
560+
self._operation_client = _http_client.JsonHttpClient(
561+
credential=app.credential.get_credential(),
562+
base_url=_MLKitService.OPERATION_URL)
563+
564+
def get_operation(self, op_name):
565+
_validate_and_parse_operation_name(op_name)
566+
try:
567+
return self._operation_client.body('get', url=op_name)
568+
except requests.exceptions.RequestException as error:
569+
raise _utils.handle_platform_error_from_requests(error)
570+
571+
def _exponential_backoff(self, current_attempt, stop_time):
572+
"""Sleeps for the appropriate amount of time. Or throws deadline exceeded."""
573+
delay_factor = pow(_MLKitService.POLL_EXPONENTIAL_BACKOFF_FACTOR, current_attempt)
574+
wait_time_seconds = delay_factor * _MLKitService.POLL_BASE_WAIT_TIME_SECONDS
575+
576+
if stop_time is not None:
577+
max_seconds_left = (stop_time - datetime.datetime.now()).total_seconds()
578+
if max_seconds_left < 1: # allow a bit of time for rpc
579+
raise exceptions.DeadlineExceededError('Polling max time exceeded.')
580+
else:
581+
wait_time_seconds = min(wait_time_seconds, max_seconds_left - 1)
582+
time.sleep(wait_time_seconds)
583+
584+
585+
def handle_operation(self, operation, wait_for_operation=False, max_time_seconds=None):
586+
"""Handles long running operations.
587+
588+
Args:
589+
operation: The operation to handle.
590+
wait_for_operation: Should we allow polling for the operation to complete.
591+
If no polling is requested, a locked model will be returned instead.
592+
max_time_seconds: The maximum seconds to try polling for operation complete.
593+
(None for no limit)
594+
595+
Returns:
596+
dict: A dictionary of the returned model properties.
597+
598+
Raises:
599+
TypeError: if the operation is not a dictionary.
600+
ValueError: If the operation is malformed.
601+
err: If the operation exceeds polling attempts or stop_time
602+
"""
603+
if not isinstance(operation, dict):
604+
raise TypeError('Operation must be a dictionary.')
605+
op_name = operation.get('name')
606+
_, model_id = _validate_and_parse_operation_name(op_name)
607+
608+
current_attempt = 0
609+
start_time = datetime.datetime.now()
610+
stop_time = (None if max_time_seconds is None else
611+
start_time + datetime.timedelta(seconds=max_time_seconds))
612+
while wait_for_operation and not operation.get('done'):
613+
# We just got this operation. Wait before getting another
614+
# so we don't exceed the GetOperation maximum request rate.
615+
self._exponential_backoff(current_attempt, stop_time)
616+
operation = self.get_operation(op_name)
617+
current_attempt += 1
618+
619+
if operation.get('done'):
620+
if operation.get('response'):
621+
return operation.get('response')
622+
elif operation.get('error'):
623+
raise _utils.handle_operation_error(operation.get('error'))
624+
625+
# If the operation is not complete or timed out, return a (locked) model instead
626+
return get_model(model_id).as_dict()
627+
628+
629+
def create_model(self, model):
630+
_validate_model(model)
631+
try:
632+
return self.handle_operation(
633+
self._client.body('post', url='models', json=model.as_dict()))
634+
except requests.exceptions.RequestException as error:
635+
raise _utils.handle_platform_error_from_requests(error)
462636

463637
def get_model(self, model_id):
464638
_validate_model_id(model_id)

0 commit comments

Comments
 (0)