21
21
import datetime
22
22
import numbers
23
23
import re
24
+ import time
24
25
import requests
25
26
import six
26
27
28
+
27
29
from firebase_admin import _http_client
28
30
from firebase_admin import _utils
31
+ from firebase_admin import exceptions
29
32
30
33
31
34
_MLKIT_ATTRIBUTE = '_mlkit'
36
39
_GCS_TFLITE_URI_PATTERN = re .compile (r'^gs://[a-z0-9_.-]{3,63}/.+' )
37
40
_RESOURCE_NAME_PATTERN = re .compile (
38
41
r'^projects/(?P<project_id>[^/]+)/models/(?P<model_id>[A-Za-z0-9_-]{1,60})$' )
42
+ _OPERATION_NAME_PATTERN = re .compile (
43
+ r'^operations/project/(?P<project_id>[^/]+)/model/(?P<model_id>[A-Za-z0-9_-]{1,60})' +
44
+ r'/operation/[^/]+$' )
39
45
40
46
41
47
def _get_mlkit_service (app ):
@@ -53,18 +59,60 @@ def _get_mlkit_service(app):
53
59
return _utils .get_app_service (app , _MLKIT_ATTRIBUTE , _MLKitService )
54
60
55
61
62
+ def create_model (model , app = None ):
63
+ """Creates a model in Firebase ML Kit.
64
+
65
+ Args:
66
+ model: An mlkit.Model to create.
67
+ app: A Firebase app instance (or None to use the default app).
68
+
69
+ Returns:
70
+ Model: The model that was created in Firebase ML Kit.
71
+ """
72
+ mlkit_service = _get_mlkit_service (app )
73
+ return Model .from_dict (mlkit_service .create_model (model ), app = app )
74
+
75
+
56
76
def get_model (model_id , app = None ):
77
+ """Gets a model from Firebase ML Kit.
78
+
79
+ Args:
80
+ model_id: The id of the model to get.
81
+ app: A Firebase app instance (or None to use the default app).
82
+
83
+ Returns:
84
+ Model: The requested model.
85
+ """
57
86
mlkit_service = _get_mlkit_service (app )
58
- return Model .from_dict (mlkit_service .get_model (model_id ))
87
+ return Model .from_dict (mlkit_service .get_model (model_id ), app = app )
59
88
60
89
61
90
def list_models (list_filter = None , page_size = None , page_token = None , app = None ):
91
+ """Lists models from Firebase ML Kit.
92
+
93
+ Args:
94
+ list_filter: a list filter string such as "tags:'tag_1'". None will return all models.
95
+ page_size: A number between 1 and 100 inclusive that specifies the maximum
96
+ number of models to return per page. None for default.
97
+ page_token: A next page token returned from a previous page of results. None
98
+ for first page of results.
99
+ app: A Firebase app instance (or None to use the default app).
100
+
101
+ Returns:
102
+ ListModelsPage: A (filtered) list of models.
103
+ """
62
104
mlkit_service = _get_mlkit_service (app )
63
105
return ListModelsPage (
64
- mlkit_service .list_models , list_filter , page_size , page_token )
106
+ mlkit_service .list_models , list_filter , page_size , page_token , app = app )
65
107
66
108
67
109
def delete_model (model_id , app = None ):
110
+ """Deletes a model from Firebase ML Kit.
111
+
112
+ Args:
113
+ model_id: The id of the model you wish to delete.
114
+ app: A Firebase app instance (or None to use the default app).
115
+ """
68
116
mlkit_service = _get_mlkit_service (app )
69
117
mlkit_service .delete_model (model_id )
70
118
@@ -78,6 +126,7 @@ class Model(object):
78
126
model_format: A subclass of ModelFormat. (e.g. TFLiteFormat) Specifies the model details.
79
127
"""
80
128
def __init__ (self , display_name = None , tags = None , model_format = None ):
129
+ self ._app = None # Only needed for wait_for_unlo
81
130
self ._data = {}
82
131
self ._model_format = None
83
132
@@ -89,16 +138,22 @@ def __init__(self, display_name=None, tags=None, model_format=None):
89
138
self .model_format = model_format
90
139
91
140
@classmethod
92
- def from_dict (cls , data ):
141
+ def from_dict (cls , data , app = None ):
93
142
data_copy = dict (data )
94
143
tflite_format = None
95
144
tflite_format_data = data_copy .pop ('tfliteModel' , None )
96
145
if tflite_format_data :
97
146
tflite_format = TFLiteFormat .from_dict (tflite_format_data )
98
147
model = Model (model_format = tflite_format )
99
148
model ._data = data_copy # pylint: disable=protected-access
149
+ model ._app = app # pylint: disable=protected-access
100
150
return model
101
151
152
+ def _update_from_dict (self , data ):
153
+ copy = Model .from_dict (data )
154
+ self .model_format = copy .model_format
155
+ self ._data = copy ._data # pylint: disable=protected-access
156
+
102
157
def __eq__ (self , other ):
103
158
if isinstance (other , self .__class__ ):
104
159
# pylint: disable=protected-access
@@ -173,6 +228,26 @@ def locked(self):
173
228
return bool (self ._data .get ('activeOperations' ) and
174
229
len (self ._data .get ('activeOperations' )) > 0 )
175
230
231
+ def wait_for_unlocked (self , max_time_seconds = None ):
232
+ """Waits for the model to be unlocked. (All active operations complete)
233
+
234
+ Args:
235
+ max_time_seconds: The maximum number of seconds to wait for the model to unlock.
236
+ (None for no limit)
237
+
238
+ Raises:
239
+ exceptions.DeadlineExceeded: If max_time_seconds passed and the model is still locked.
240
+ """
241
+ if not self .locked :
242
+ return
243
+ mlkit_service = _get_mlkit_service (self ._app )
244
+ op_name = self ._data .get ('activeOperations' )[0 ].get ('name' )
245
+ model_dict = mlkit_service .handle_operation (
246
+ mlkit_service .get_operation (op_name ),
247
+ wait_for_operation = True ,
248
+ max_time_seconds = max_time_seconds )
249
+ self ._update_from_dict (model_dict )
250
+
176
251
@property
177
252
def model_format (self ):
178
253
return self ._model_format
@@ -296,17 +371,20 @@ class ListModelsPage(object):
296
371
``iterate_all()`` can be used to iterate through all the models in the
297
372
Firebase project starting from this page.
298
373
"""
299
- def __init__ (self , list_models_func , list_filter , page_size , page_token ):
374
+ def __init__ (self , list_models_func , list_filter , page_size , page_token , app ):
300
375
self ._list_models_func = list_models_func
301
376
self ._list_filter = list_filter
302
377
self ._page_size = page_size
303
378
self ._page_token = page_token
379
+ self ._app = app
304
380
self ._list_response = list_models_func (list_filter , page_size , page_token )
305
381
306
382
@property
307
383
def models (self ):
308
384
"""A list of Models from this page."""
309
- return [Model .from_dict (model ) for model in self ._list_response .get ('models' , [])]
385
+ return [
386
+ Model .from_dict (model , app = self ._app ) for model in self ._list_response .get ('models' , [])
387
+ ]
310
388
311
389
@property
312
390
def list_filter (self ):
@@ -333,7 +411,8 @@ def get_next_page(self):
333
411
self ._list_models_func ,
334
412
self ._list_filter ,
335
413
self ._page_size ,
336
- self .next_page_token )
414
+ self .next_page_token ,
415
+ self ._app )
337
416
return None
338
417
339
418
def iterate_all (self ):
@@ -390,11 +469,25 @@ def _validate_and_parse_name(name):
390
469
return matcher .group ('project_id' ), matcher .group ('model_id' )
391
470
392
471
472
+ def _validate_model (model ):
473
+ if not isinstance (model , Model ):
474
+ raise TypeError ('Model must be an mlkit.Model.' )
475
+ if not model .display_name :
476
+ raise ValueError ('Model must have a display name.' )
477
+
478
+
393
479
def _validate_model_id (model_id ):
394
480
if not _MODEL_ID_PATTERN .match (model_id ):
395
481
raise ValueError ('Model ID format is invalid.' )
396
482
397
483
484
+ def _validate_and_parse_operation_name (op_name ):
485
+ matcher = _OPERATION_NAME_PATTERN .match (op_name )
486
+ if not matcher :
487
+ raise ValueError ('Operation name format is invalid.' )
488
+ return matcher .group ('project_id' ), matcher .group ('model_id' )
489
+
490
+
398
491
def _validate_display_name (display_name ):
399
492
if not _DISPLAY_NAME_PATTERN .match (display_name ):
400
493
raise ValueError ('Display name format is invalid.' )
@@ -417,11 +510,13 @@ def _validate_gcs_tflite_uri(uri):
417
510
raise ValueError ('GCS TFLite URI format is invalid.' )
418
511
return uri
419
512
513
+
420
514
def _validate_model_format (model_format ):
421
515
if not isinstance (model_format , ModelFormat ):
422
516
raise TypeError ('Model format must be a ModelFormat object.' )
423
517
return model_format
424
518
519
+
425
520
def _validate_list_filter (list_filter ):
426
521
if list_filter is not None :
427
522
if not isinstance (list_filter , six .string_types ):
@@ -448,6 +543,9 @@ class _MLKitService(object):
448
543
"""Firebase MLKit service."""
449
544
450
545
PROJECT_URL = 'https://mlkit.googleapis.com/v1beta1/projects/{0}/'
546
+ OPERATION_URL = 'https://mlkit.googleapis.com/v1beta1/'
547
+ POLL_EXPONENTIAL_BACKOFF_FACTOR = 1.5
548
+ POLL_BASE_WAIT_TIME_SECONDS = 3
451
549
452
550
def __init__ (self , app ):
453
551
project_id = app .project_id
@@ -459,6 +557,82 @@ def __init__(self, app):
459
557
self ._client = _http_client .JsonHttpClient (
460
558
credential = app .credential .get_credential (),
461
559
base_url = self ._project_url )
560
+ self ._operation_client = _http_client .JsonHttpClient (
561
+ credential = app .credential .get_credential (),
562
+ base_url = _MLKitService .OPERATION_URL )
563
+
564
+ def get_operation (self , op_name ):
565
+ _validate_and_parse_operation_name (op_name )
566
+ try :
567
+ return self ._operation_client .body ('get' , url = op_name )
568
+ except requests .exceptions .RequestException as error :
569
+ raise _utils .handle_platform_error_from_requests (error )
570
+
571
+ def _exponential_backoff (self , current_attempt , stop_time ):
572
+ """Sleeps for the appropriate amount of time. Or throws deadline exceeded."""
573
+ delay_factor = pow (_MLKitService .POLL_EXPONENTIAL_BACKOFF_FACTOR , current_attempt )
574
+ wait_time_seconds = delay_factor * _MLKitService .POLL_BASE_WAIT_TIME_SECONDS
575
+
576
+ if stop_time is not None :
577
+ max_seconds_left = (stop_time - datetime .datetime .now ()).total_seconds ()
578
+ if max_seconds_left < 1 : # allow a bit of time for rpc
579
+ raise exceptions .DeadlineExceededError ('Polling max time exceeded.' )
580
+ else :
581
+ wait_time_seconds = min (wait_time_seconds , max_seconds_left - 1 )
582
+ time .sleep (wait_time_seconds )
583
+
584
+
585
+ def handle_operation (self , operation , wait_for_operation = False , max_time_seconds = None ):
586
+ """Handles long running operations.
587
+
588
+ Args:
589
+ operation: The operation to handle.
590
+ wait_for_operation: Should we allow polling for the operation to complete.
591
+ If no polling is requested, a locked model will be returned instead.
592
+ max_time_seconds: The maximum seconds to try polling for operation complete.
593
+ (None for no limit)
594
+
595
+ Returns:
596
+ dict: A dictionary of the returned model properties.
597
+
598
+ Raises:
599
+ TypeError: if the operation is not a dictionary.
600
+ ValueError: If the operation is malformed.
601
+ err: If the operation exceeds polling attempts or stop_time
602
+ """
603
+ if not isinstance (operation , dict ):
604
+ raise TypeError ('Operation must be a dictionary.' )
605
+ op_name = operation .get ('name' )
606
+ _ , model_id = _validate_and_parse_operation_name (op_name )
607
+
608
+ current_attempt = 0
609
+ start_time = datetime .datetime .now ()
610
+ stop_time = (None if max_time_seconds is None else
611
+ start_time + datetime .timedelta (seconds = max_time_seconds ))
612
+ while wait_for_operation and not operation .get ('done' ):
613
+ # We just got this operation. Wait before getting another
614
+ # so we don't exceed the GetOperation maximum request rate.
615
+ self ._exponential_backoff (current_attempt , stop_time )
616
+ operation = self .get_operation (op_name )
617
+ current_attempt += 1
618
+
619
+ if operation .get ('done' ):
620
+ if operation .get ('response' ):
621
+ return operation .get ('response' )
622
+ elif operation .get ('error' ):
623
+ raise _utils .handle_operation_error (operation .get ('error' ))
624
+
625
+ # If the operation is not complete or timed out, return a (locked) model instead
626
+ return get_model (model_id ).as_dict ()
627
+
628
+
629
+ def create_model (self , model ):
630
+ _validate_model (model )
631
+ try :
632
+ return self .handle_operation (
633
+ self ._client .body ('post' , url = 'models' , json = model .as_dict ()))
634
+ except requests .exceptions .RequestException as error :
635
+ raise _utils .handle_platform_error_from_requests (error )
462
636
463
637
def get_model (self , model_id ):
464
638
_validate_model_id (model_id )
0 commit comments