Skip to content

Commit 2a3be77

Browse files
authored
Firebase ML Kit Update Model API implementation (#343)
* Firebase ML Kit Create Model API implementation
1 parent e5cf14a commit 2a3be77

File tree

2 files changed

+156
-34
lines changed

2 files changed

+156
-34
lines changed

firebase_admin/mlkit.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,20 @@ def create_model(model, app=None):
7373
return Model.from_dict(mlkit_service.create_model(model), app=app)
7474

7575

76+
def update_model(model, app=None):
77+
"""Updates a model in Firebase ML Kit.
78+
79+
Args:
80+
model: The mlkit.Model to update.
81+
app: A Firebase app instance (or None to use the default app).
82+
83+
Returns:
84+
Model: The updated model.
85+
"""
86+
mlkit_service = _get_mlkit_service(app)
87+
return Model.from_dict(mlkit_service.update_model(model), app=app)
88+
89+
7690
def get_model(model_id, app=None):
7791
"""Gets a model from Firebase ML Kit.
7892
@@ -469,10 +483,10 @@ def _validate_and_parse_name(name):
469483
return matcher.group('project_id'), matcher.group('model_id')
470484

471485

472-
def _validate_model(model):
486+
def _validate_model(model, update_mask=None):
473487
if not isinstance(model, Model):
474488
raise TypeError('Model must be an mlkit.Model.')
475-
if not model.display_name:
489+
if update_mask is None and not model.display_name:
476490
raise ValueError('Model must have a display name.')
477491

478492

@@ -634,6 +648,17 @@ def create_model(self, model):
634648
except requests.exceptions.RequestException as error:
635649
raise _utils.handle_platform_error_from_requests(error)
636650

651+
def update_model(self, model, update_mask=None):
652+
_validate_model(model, update_mask)
653+
data = {'model': model.as_dict()}
654+
if update_mask is not None:
655+
data['updateMask'] = update_mask
656+
try:
657+
return self.handle_operation(
658+
self._client.body('patch', url='models/{0}'.format(model.model_id), json=data))
659+
except requests.exceptions.RequestException as error:
660+
raise _utils.handle_platform_error_from_requests(error)
661+
637662
def get_model(self, model_id):
638663
_validate_model_id(model_id)
639664
try:

tests/test_mlkit.py

Lines changed: 129 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from tests import testutils
2525

2626
BASE_URL = 'https://mlkit.googleapis.com/v1beta1/'
27-
2827
PROJECT_ID = 'myProject1'
2928
PAGE_TOKEN = 'pageToken'
3029
NEXT_PAGE_TOKEN = 'nextPageToken'
@@ -122,7 +121,7 @@
122121
}
123122
TFLITE_FORMAT_2 = mlkit.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON_2)
124123

125-
CREATED_MODEL_JSON_1 = {
124+
CREATED_UPDATED_MODEL_JSON_1 = {
126125
'name': MODEL_NAME_1,
127126
'displayName': DISPLAY_NAME_1,
128127
'createTime': CREATE_TIME_JSON,
@@ -132,7 +131,7 @@
132131
'modelHash': MODEL_HASH,
133132
'tags': TAGS,
134133
}
135-
CREATED_MODEL_1 = mlkit.Model.from_dict(CREATED_MODEL_JSON_1)
134+
CREATED_UPDATED_MODEL_1 = mlkit.Model.from_dict(CREATED_UPDATED_MODEL_JSON_1)
136135

137136
LOCKED_MODEL_JSON_1 = {
138137
'name': MODEL_NAME_1,
@@ -155,19 +154,16 @@
155154
OPERATION_DONE_MODEL_JSON_1 = {
156155
'name': OPERATION_NAME_1,
157156
'done': True,
158-
'response': CREATED_MODEL_JSON_1
157+
'response': CREATED_UPDATED_MODEL_JSON_1
159158
}
160-
161159
OPERATION_MALFORMED_JSON_1 = {
162160
'name': OPERATION_NAME_1,
163161
'done': True,
164162
# if done is true then either response or error should be populated
165163
}
166-
167164
OPERATION_MISSING_NAME = {
168165
'done': False
169166
}
170-
171167
OPERATION_ERROR_CODE = 400
172168
OPERATION_ERROR_MSG = "Invalid argument"
173169
OPERATION_ERROR_EXPECTED_STATUS = 'INVALID_ARGUMENT'
@@ -254,15 +250,33 @@
254250
}
255251
ERROR_RESPONSE_BAD_REQUEST = json.dumps(ERROR_JSON_BAD_REQUEST)
256252

257-
invalid_model_id_args = [
253+
INVALID_MODEL_ID_ARGS = [
258254
('', ValueError),
259255
('&_*#@:/?', ValueError),
260256
(None, TypeError),
261257
(12345, TypeError),
262258
]
259+
INVALID_MODEL_ARGS = [
260+
'abc',
261+
4.2,
262+
list(),
263+
dict(),
264+
True,
265+
-1,
266+
0,
267+
None
268+
]
269+
INVALID_OP_NAME_ARGS = [
270+
'abc',
271+
'123',
272+
'projects/operations/project/1234/model/abc/operation/123',
273+
'operations/project/model/abc/operation/123',
274+
'operations/project/123/model/$#@/operation/123',
275+
'operations/project/1234/model/abc/operation/123/extrathing',
276+
]
263277
PAGE_SIZE_VALUE_ERROR_MSG = 'Page size must be a positive integer between ' \
264278
'1 and {0}'.format(mlkit._MAX_PAGE_SIZE)
265-
invalid_string_or_none_args = [0, -1, 4.2, 0x10, False, list(), dict()]
279+
INVALID_STRING_OR_NONE_ARGS = [0, -1, 4.2, 0x10, False, list(), dict()]
266280

267281

268282
# For validation type errors
@@ -524,7 +538,7 @@ def _get_url(project_id, model_id):
524538
def test_immediate_done(self):
525539
instrument_mlkit_service(status=200, payload=OPERATION_DONE_RESPONSE)
526540
model = mlkit.create_model(MODEL_1)
527-
assert model == CREATED_MODEL_1
541+
assert model == CREATED_UPDATED_MODEL_1
528542

529543
def test_returns_locked(self):
530544
recorder = instrument_mlkit_service(
@@ -573,16 +587,7 @@ def test_rpc_error_create(self):
573587
)
574588
assert len(create_recorder) == 1
575589

576-
@pytest.mark.parametrize('model', [
577-
'abc',
578-
4.2,
579-
list(),
580-
dict(),
581-
True,
582-
-1,
583-
0,
584-
None
585-
])
590+
@pytest.mark.parametrize('model', INVALID_MODEL_ARGS)
586591
def test_not_model(self, model):
587592
with pytest.raises(Exception) as excinfo:
588593
mlkit.create_model(model)
@@ -599,14 +604,7 @@ def test_missing_op_name(self):
599604
mlkit.create_model(MODEL_1)
600605
check_error(excinfo, TypeError)
601606

602-
@pytest.mark.parametrize('op_name', [
603-
'abc',
604-
'123',
605-
'projects/operations/project/1234/model/abc/operation/123',
606-
'operations/project/model/abc/operation/123',
607-
'operations/project/123/model/$#@/operation/123',
608-
'operations/project/1234/model/abc/operation/123/extrathing',
609-
])
607+
@pytest.mark.parametrize('op_name', INVALID_OP_NAME_ARGS)
610608
def test_invalid_op_name(self, op_name):
611609
payload = json.dumps({'name': op_name})
612610
instrument_mlkit_service(status=200, payload=payload)
@@ -615,6 +613,105 @@ def test_invalid_op_name(self, op_name):
615613
check_error(excinfo, ValueError, 'Operation name format is invalid.')
616614

617615

616+
class TestUpdateModel(object):
617+
"""Tests mlkit.update_model."""
618+
@classmethod
619+
def setup_class(cls):
620+
cred = testutils.MockCredential()
621+
firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID})
622+
mlkit._MLKitService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test
623+
624+
@classmethod
625+
def teardown_class(cls):
626+
testutils.cleanup_apps()
627+
628+
@staticmethod
629+
def _url(project_id, model_id):
630+
return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id)
631+
632+
@staticmethod
633+
def _op_url(project_id, model_id):
634+
return BASE_URL + \
635+
'operations/project/{0}/model/{1}/operation/123'.format(project_id, model_id)
636+
637+
def test_immediate_done(self):
638+
instrument_mlkit_service(status=200, payload=OPERATION_DONE_RESPONSE)
639+
model = mlkit.update_model(MODEL_1)
640+
assert model == CREATED_UPDATED_MODEL_1
641+
642+
def test_returns_locked(self):
643+
recorder = instrument_mlkit_service(
644+
status=[200, 200],
645+
payload=[OPERATION_NOT_DONE_RESPONSE, LOCKED_MODEL_2_RESPONSE])
646+
expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2)
647+
model = mlkit.update_model(MODEL_1)
648+
649+
assert model == expected_model
650+
assert len(recorder) == 2
651+
assert recorder[0].method == 'PATCH'
652+
assert recorder[0].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1)
653+
assert recorder[1].method == 'GET'
654+
assert recorder[1].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1)
655+
656+
def test_operation_error(self):
657+
instrument_mlkit_service(status=200, payload=OPERATION_ERROR_RESPONSE)
658+
with pytest.raises(Exception) as excinfo:
659+
mlkit.update_model(MODEL_1)
660+
# The http request succeeded, the operation returned contains a create failure
661+
check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG)
662+
663+
def test_malformed_operation(self):
664+
recorder = instrument_mlkit_service(
665+
status=[200, 200],
666+
payload=[OPERATION_MALFORMED_RESPONSE, LOCKED_MODEL_2_RESPONSE])
667+
expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2)
668+
model = mlkit.update_model(MODEL_1)
669+
assert model == expected_model
670+
assert len(recorder) == 2
671+
assert recorder[0].method == 'PATCH'
672+
assert recorder[0].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1)
673+
assert recorder[1].method == 'GET'
674+
assert recorder[1].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1)
675+
676+
def test_rpc_error_create(self):
677+
create_recorder = instrument_mlkit_service(
678+
status=400, payload=ERROR_RESPONSE_BAD_REQUEST)
679+
with pytest.raises(Exception) as excinfo:
680+
mlkit.update_model(MODEL_1)
681+
check_firebase_error(
682+
excinfo,
683+
ERROR_STATUS_BAD_REQUEST,
684+
ERROR_CODE_BAD_REQUEST,
685+
ERROR_MSG_BAD_REQUEST
686+
)
687+
assert len(create_recorder) == 1
688+
689+
@pytest.mark.parametrize('model', INVALID_MODEL_ARGS)
690+
def test_not_model(self, model):
691+
with pytest.raises(Exception) as excinfo:
692+
mlkit.update_model(model)
693+
check_error(excinfo, TypeError, 'Model must be an mlkit.Model.')
694+
695+
def test_missing_display_name(self):
696+
with pytest.raises(Exception) as excinfo:
697+
mlkit.update_model(mlkit.Model.from_dict({}))
698+
check_error(excinfo, ValueError, 'Model must have a display name.')
699+
700+
def test_missing_op_name(self):
701+
instrument_mlkit_service(status=200, payload=OPERATION_MISSING_NAME_RESPONSE)
702+
with pytest.raises(Exception) as excinfo:
703+
mlkit.update_model(MODEL_1)
704+
check_error(excinfo, TypeError)
705+
706+
@pytest.mark.parametrize('op_name', INVALID_OP_NAME_ARGS)
707+
def test_invalid_op_name(self, op_name):
708+
payload = json.dumps({'name': op_name})
709+
instrument_mlkit_service(status=200, payload=payload)
710+
with pytest.raises(Exception) as excinfo:
711+
mlkit.update_model(MODEL_1)
712+
check_error(excinfo, ValueError, 'Operation name format is invalid.')
713+
714+
618715
class TestGetModel(object):
619716
"""Tests mlkit.get_model."""
620717
@classmethod
@@ -640,7 +737,7 @@ def test_get_model(self):
640737
assert model.model_id == MODEL_ID_1
641738
assert model.display_name == DISPLAY_NAME_1
642739

643-
@pytest.mark.parametrize('model_id, exc_type', invalid_model_id_args)
740+
@pytest.mark.parametrize('model_id, exc_type', INVALID_MODEL_ID_ARGS)
644741
def test_get_model_validation_errors(self, model_id, exc_type):
645742
with pytest.raises(exc_type) as excinfo:
646743
mlkit.get_model(model_id)
@@ -690,7 +787,7 @@ def test_delete_model(self):
690787
assert recorder[0].method == 'DELETE'
691788
assert recorder[0].url == TestDeleteModel._url(PROJECT_ID, MODEL_ID_1)
692789

693-
@pytest.mark.parametrize('model_id, exc_type', invalid_model_id_args)
790+
@pytest.mark.parametrize('model_id, exc_type', INVALID_MODEL_ID_ARGS)
694791
def test_delete_model_validation_errors(self, model_id, exc_type):
695792
with pytest.raises(exc_type) as excinfo:
696793
mlkit.delete_model(model_id)
@@ -771,7 +868,7 @@ def test_list_models_with_all_args(self):
771868
assert models_page.models[0] == MODEL_3
772869
assert not models_page.has_next_page
773870

774-
@pytest.mark.parametrize('list_filter', invalid_string_or_none_args)
871+
@pytest.mark.parametrize('list_filter', INVALID_STRING_OR_NONE_ARGS)
775872
def test_list_models_list_filter_validation(self, list_filter):
776873
with pytest.raises(TypeError) as excinfo:
777874
mlkit.list_models(list_filter=list_filter)
@@ -792,7 +889,7 @@ def test_list_models_page_size_validation(self, page_size, exc_type, error_messa
792889
mlkit.list_models(page_size=page_size)
793890
check_error(excinfo, exc_type, error_message)
794891

795-
@pytest.mark.parametrize('page_token', invalid_string_or_none_args)
892+
@pytest.mark.parametrize('page_token', INVALID_STRING_OR_NONE_ARGS)
796893
def test_list_models_page_token_validation(self, page_token):
797894
with pytest.raises(TypeError) as excinfo:
798895
mlkit.list_models(page_token=page_token)

0 commit comments

Comments
 (0)