Skip to content

Firebase ML Kit Update Model API implementation #343

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 11, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions firebase_admin/mlkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,20 @@ def create_model(model, app=None):
return Model.from_dict(mlkit_service.create_model(model), app=app)


def update_model(model, app=None):
"""Updates a model in Firebase ML Kit.

Args:
model: The mlkit.Model to update.
app: A Firebase app instance (or None to use the default app).

Returns:
Model: The updated model.
"""
mlkit_service = _get_mlkit_service(app)
return Model.from_dict(mlkit_service.update_model(model), app=app)


def get_model(model_id, app=None):
"""Gets a model from Firebase ML Kit.

Expand Down Expand Up @@ -469,10 +483,10 @@ def _validate_and_parse_name(name):
return matcher.group('project_id'), matcher.group('model_id')


def _validate_model(model):
def _validate_model(model, update_mask=None):
if not isinstance(model, Model):
raise TypeError('Model must be an mlkit.Model.')
if not model.display_name:
if update_mask is None and not model.display_name:
raise ValueError('Model must have a display name.')


Expand Down Expand Up @@ -634,6 +648,17 @@ def create_model(self, model):
except requests.exceptions.RequestException as error:
raise _utils.handle_platform_error_from_requests(error)

def update_model(self, model, update_mask=None):
_validate_model(model, update_mask)
data = {'model': model.as_dict()}
if update_mask is not None:
data['updateMask'] = update_mask
try:
return self.handle_operation(
self._client.body('patch', url='models/{0}'.format(model.model_id), json=data))
except requests.exceptions.RequestException as error:
raise _utils.handle_platform_error_from_requests(error)

def get_model(self, model_id):
_validate_model_id(model_id)
try:
Expand Down
161 changes: 129 additions & 32 deletions tests/test_mlkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from tests import testutils

BASE_URL = 'https://mlkit.googleapis.com/v1beta1/'

PROJECT_ID = 'myProject1'
PAGE_TOKEN = 'pageToken'
NEXT_PAGE_TOKEN = 'nextPageToken'
Expand Down Expand Up @@ -122,7 +121,7 @@
}
TFLITE_FORMAT_2 = mlkit.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON_2)

CREATED_MODEL_JSON_1 = {
CREATED_UPDATED_MODEL_JSON_1 = {
'name': MODEL_NAME_1,
'displayName': DISPLAY_NAME_1,
'createTime': CREATE_TIME_JSON,
Expand All @@ -132,7 +131,7 @@
'modelHash': MODEL_HASH,
'tags': TAGS,
}
CREATED_MODEL_1 = mlkit.Model.from_dict(CREATED_MODEL_JSON_1)
CREATED_UPDATED_MODEL_1 = mlkit.Model.from_dict(CREATED_UPDATED_MODEL_JSON_1)

LOCKED_MODEL_JSON_1 = {
'name': MODEL_NAME_1,
Expand All @@ -155,19 +154,16 @@
OPERATION_DONE_MODEL_JSON_1 = {
'name': OPERATION_NAME_1,
'done': True,
'response': CREATED_MODEL_JSON_1
'response': CREATED_UPDATED_MODEL_JSON_1
}

OPERATION_MALFORMED_JSON_1 = {
'name': OPERATION_NAME_1,
'done': True,
# if done is true then either response or error should be populated
}

OPERATION_MISSING_NAME = {
'done': False
}

OPERATION_ERROR_CODE = 400
OPERATION_ERROR_MSG = "Invalid argument"
OPERATION_ERROR_EXPECTED_STATUS = 'INVALID_ARGUMENT'
Expand Down Expand Up @@ -254,15 +250,33 @@
}
ERROR_RESPONSE_BAD_REQUEST = json.dumps(ERROR_JSON_BAD_REQUEST)

invalid_model_id_args = [
INVALID_MODEL_ID_ARGS = [
('', ValueError),
('&_*#@:/?', ValueError),
(None, TypeError),
(12345, TypeError),
]
INVALID_MODEL_ARGS = [
'abc',
4.2,
list(),
dict(),
True,
-1,
0,
None
]
INVALID_OP_NAME_ARGS = [
'abc',
'123',
'projects/operations/project/1234/model/abc/operation/123',
'operations/project/model/abc/operation/123',
'operations/project/123/model/$#@/operation/123',
'operations/project/1234/model/abc/operation/123/extrathing',
]
PAGE_SIZE_VALUE_ERROR_MSG = 'Page size must be a positive integer between ' \
'1 and {0}'.format(mlkit._MAX_PAGE_SIZE)
invalid_string_or_none_args = [0, -1, 4.2, 0x10, False, list(), dict()]
INVALID_STRING_OR_NONE_ARGS = [0, -1, 4.2, 0x10, False, list(), dict()]


# For validation type errors
Expand Down Expand Up @@ -524,7 +538,7 @@ def _get_url(project_id, model_id):
def test_immediate_done(self):
instrument_mlkit_service(status=200, payload=OPERATION_DONE_RESPONSE)
model = mlkit.create_model(MODEL_1)
assert model == CREATED_MODEL_1
assert model == CREATED_UPDATED_MODEL_1

def test_returns_locked(self):
recorder = instrument_mlkit_service(
Expand Down Expand Up @@ -573,16 +587,7 @@ def test_rpc_error_create(self):
)
assert len(create_recorder) == 1

@pytest.mark.parametrize('model', [
'abc',
4.2,
list(),
dict(),
True,
-1,
0,
None
])
@pytest.mark.parametrize('model', INVALID_MODEL_ARGS)
def test_not_model(self, model):
with pytest.raises(Exception) as excinfo:
mlkit.create_model(model)
Expand All @@ -599,14 +604,7 @@ def test_missing_op_name(self):
mlkit.create_model(MODEL_1)
check_error(excinfo, TypeError)

@pytest.mark.parametrize('op_name', [
'abc',
'123',
'projects/operations/project/1234/model/abc/operation/123',
'operations/project/model/abc/operation/123',
'operations/project/123/model/$#@/operation/123',
'operations/project/1234/model/abc/operation/123/extrathing',
])
@pytest.mark.parametrize('op_name', INVALID_OP_NAME_ARGS)
def test_invalid_op_name(self, op_name):
payload = json.dumps({'name': op_name})
instrument_mlkit_service(status=200, payload=payload)
Expand All @@ -615,6 +613,105 @@ def test_invalid_op_name(self, op_name):
check_error(excinfo, ValueError, 'Operation name format is invalid.')


class TestUpdateModel(object):
"""Tests mlkit.update_model."""
@classmethod
def setup_class(cls):
cred = testutils.MockCredential()
firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID})
mlkit._MLKitService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test

@classmethod
def teardown_class(cls):
testutils.cleanup_apps()

@staticmethod
def _url(project_id, model_id):
return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id)

@staticmethod
def _op_url(project_id, model_id):
return BASE_URL + \
'operations/project/{0}/model/{1}/operation/123'.format(project_id, model_id)

def test_immediate_done(self):
instrument_mlkit_service(status=200, payload=OPERATION_DONE_RESPONSE)
model = mlkit.update_model(MODEL_1)
assert model == CREATED_UPDATED_MODEL_1

def test_returns_locked(self):
recorder = instrument_mlkit_service(
status=[200, 200],
payload=[OPERATION_NOT_DONE_RESPONSE, LOCKED_MODEL_2_RESPONSE])
expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2)
model = mlkit.update_model(MODEL_1)

assert model == expected_model
assert len(recorder) == 2
assert recorder[0].method == 'PATCH'
assert recorder[0].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1)
assert recorder[1].method == 'GET'
assert recorder[1].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1)

def test_operation_error(self):
instrument_mlkit_service(status=200, payload=OPERATION_ERROR_RESPONSE)
with pytest.raises(Exception) as excinfo:
mlkit.update_model(MODEL_1)
# The http request succeeded, the operation returned contains a create failure
check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG)

def test_malformed_operation(self):
recorder = instrument_mlkit_service(
status=[200, 200],
payload=[OPERATION_MALFORMED_RESPONSE, LOCKED_MODEL_2_RESPONSE])
expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2)
model = mlkit.update_model(MODEL_1)
assert model == expected_model
assert len(recorder) == 2
assert recorder[0].method == 'PATCH'
assert recorder[0].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1)
assert recorder[1].method == 'GET'
assert recorder[1].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1)

def test_rpc_error_create(self):
create_recorder = instrument_mlkit_service(
status=400, payload=ERROR_RESPONSE_BAD_REQUEST)
with pytest.raises(Exception) as excinfo:
mlkit.update_model(MODEL_1)
check_firebase_error(
excinfo,
ERROR_STATUS_BAD_REQUEST,
ERROR_CODE_BAD_REQUEST,
ERROR_MSG_BAD_REQUEST
)
assert len(create_recorder) == 1

@pytest.mark.parametrize('model', INVALID_MODEL_ARGS)
def test_not_model(self, model):
with pytest.raises(Exception) as excinfo:
mlkit.update_model(model)
check_error(excinfo, TypeError, 'Model must be an mlkit.Model.')

def test_missing_display_name(self):
with pytest.raises(Exception) as excinfo:
mlkit.update_model(mlkit.Model.from_dict({}))
check_error(excinfo, ValueError, 'Model must have a display name.')

def test_missing_op_name(self):
instrument_mlkit_service(status=200, payload=OPERATION_MISSING_NAME_RESPONSE)
with pytest.raises(Exception) as excinfo:
mlkit.update_model(MODEL_1)
check_error(excinfo, TypeError)

@pytest.mark.parametrize('op_name', INVALID_OP_NAME_ARGS)
def test_invalid_op_name(self, op_name):
payload = json.dumps({'name': op_name})
instrument_mlkit_service(status=200, payload=payload)
with pytest.raises(Exception) as excinfo:
mlkit.update_model(MODEL_1)
check_error(excinfo, ValueError, 'Operation name format is invalid.')


class TestGetModel(object):
"""Tests mlkit.get_model."""
@classmethod
Expand All @@ -640,7 +737,7 @@ def test_get_model(self):
assert model.model_id == MODEL_ID_1
assert model.display_name == DISPLAY_NAME_1

@pytest.mark.parametrize('model_id, exc_type', invalid_model_id_args)
@pytest.mark.parametrize('model_id, exc_type', INVALID_MODEL_ID_ARGS)
def test_get_model_validation_errors(self, model_id, exc_type):
with pytest.raises(exc_type) as excinfo:
mlkit.get_model(model_id)
Expand Down Expand Up @@ -690,7 +787,7 @@ def test_delete_model(self):
assert recorder[0].method == 'DELETE'
assert recorder[0].url == TestDeleteModel._url(PROJECT_ID, MODEL_ID_1)

@pytest.mark.parametrize('model_id, exc_type', invalid_model_id_args)
@pytest.mark.parametrize('model_id, exc_type', INVALID_MODEL_ID_ARGS)
def test_delete_model_validation_errors(self, model_id, exc_type):
with pytest.raises(exc_type) as excinfo:
mlkit.delete_model(model_id)
Expand Down Expand Up @@ -771,7 +868,7 @@ def test_list_models_with_all_args(self):
assert models_page.models[0] == MODEL_3
assert not models_page.has_next_page

@pytest.mark.parametrize('list_filter', invalid_string_or_none_args)
@pytest.mark.parametrize('list_filter', INVALID_STRING_OR_NONE_ARGS)
def test_list_models_list_filter_validation(self, list_filter):
with pytest.raises(TypeError) as excinfo:
mlkit.list_models(list_filter=list_filter)
Expand All @@ -792,7 +889,7 @@ def test_list_models_page_size_validation(self, page_size, exc_type, error_messa
mlkit.list_models(page_size=page_size)
check_error(excinfo, exc_type, error_message)

@pytest.mark.parametrize('page_token', invalid_string_or_none_args)
@pytest.mark.parametrize('page_token', INVALID_STRING_OR_NONE_ARGS)
def test_list_models_page_token_validation(self, page_token):
with pytest.raises(TypeError) as excinfo:
mlkit.list_models(page_token=page_token)
Expand Down