Skip to content

Firebase ML Kit Publish and Unpublish Implementation #345

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Sep 11, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 42 additions & 4 deletions firebase_admin/mlkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,34 @@ def update_model(model, app=None):
return Model.from_dict(mlkit_service.update_model(model), app=app)


def publish_model(model_id, app=None):
"""Publishes a model in Firebase ML Kit.

Args:
model_id: The id of the model to publish.
app: A Firebase app instance (or None to use the default app).

Returns:
Model: The published model.
"""
mlkit_service = _get_mlkit_service(app)
return Model.from_dict(mlkit_service.set_published(model_id, publish=True), app=app)


def unpublish_model(model_id, app=None):
"""Unpublishes a model in Firebase ML Kit.

Args:
model_id: The id of the model to unpublish.
app: A Firebase app instance (or None to use the default app).

Returns:
Model: The unpublished model.
"""
mlkit_service = _get_mlkit_service(app)
return Model.from_dict(mlkit_service.set_published(model_id, publish=False), app=app)


def get_model(model_id, app=None):
"""Gets a model from Firebase ML Kit.

Expand Down Expand Up @@ -562,12 +590,12 @@ class _MLKitService(object):
POLL_BASE_WAIT_TIME_SECONDS = 3

def __init__(self, app):
project_id = app.project_id
if not project_id:
self._project_id = app.project_id
if not self._project_id:
raise ValueError(
'Project ID is required to access MLKit service. Either set the '
'projectId option, or use service account credentials.')
self._project_url = _MLKitService.PROJECT_URL.format(project_id)
self._project_url = _MLKitService.PROJECT_URL.format(self._project_id)
self._client = _http_client.JsonHttpClient(
credential=app.credential.get_credential(),
base_url=self._project_url)
Expand Down Expand Up @@ -595,7 +623,6 @@ def _exponential_backoff(self, current_attempt, stop_time):
wait_time_seconds = min(wait_time_seconds, max_seconds_left - 1)
time.sleep(wait_time_seconds)


def handle_operation(self, operation, wait_for_operation=False, max_time_seconds=None):
"""Handles long running operations.

Expand Down Expand Up @@ -659,6 +686,17 @@ def update_model(self, model, update_mask=None):
except requests.exceptions.RequestException as error:
raise _utils.handle_platform_error_from_requests(error)

def set_published(self, model_id, publish):
_validate_model_id(model_id)
model_name = 'projects/{0}/models/{1}'.format(self._project_id, model_id)
model = Model.from_dict({
'name': model_name,
'state': {
'published': publish
}
})
return self.update_model(model, update_mask='state.published')

def get_model(self, model_id):
_validate_model_id(model_id)
try:
Expand Down
95 changes: 93 additions & 2 deletions tests/test_mlkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,7 @@ def test_operation_error(self):
instrument_mlkit_service(status=200, payload=OPERATION_ERROR_RESPONSE)
with pytest.raises(Exception) as excinfo:
mlkit.update_model(MODEL_1)
# The http request succeeded, the operation returned contains a create failure
# The http request succeeded, the operation returned contains an update failure
check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG)

def test_malformed_operation(self):
Expand All @@ -673,7 +673,7 @@ def test_malformed_operation(self):
assert recorder[1].method == 'GET'
assert recorder[1].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1)

def test_rpc_error_create(self):
def test_rpc_error(self):
create_recorder = instrument_mlkit_service(
status=400, payload=ERROR_RESPONSE_BAD_REQUEST)
with pytest.raises(Exception) as excinfo:
Expand Down Expand Up @@ -712,6 +712,97 @@ def test_invalid_op_name(self, op_name):
check_error(excinfo, ValueError, 'Operation name format is invalid.')


class TestPublishUnpublish(object):
"""Tests mlkit.publish_model and mlkit.unpublish_model."""

PUBLISH_UNPUBLISH_WITH_ARGS = [
(mlkit.publish_model, True),
(mlkit.unpublish_model, False)
]
PUBLISH_UNPUBLISH_FUNCS = [item[0] for item in PUBLISH_UNPUBLISH_WITH_ARGS]

@classmethod
def setup_class(cls):
cred = testutils.MockCredential()
firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID})
mlkit._MLKitService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test

@classmethod
def teardown_class(cls):
testutils.cleanup_apps()

@staticmethod
def _url(project_id, model_id):
return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id)

@staticmethod
def _op_url(project_id, model_id):
return BASE_URL + \
'operations/project/{0}/model/{1}/operation/123'.format(project_id, model_id)

@pytest.mark.parametrize('publish_function, published', PUBLISH_UNPUBLISH_WITH_ARGS)
def test_immediate_done(self, publish_function, published):
recorder = instrument_mlkit_service(status=200, payload=OPERATION_DONE_RESPONSE)
model = publish_function(MODEL_ID_1)
assert model == CREATED_UPDATED_MODEL_1
assert len(recorder) == 1
assert recorder[0].method == 'PATCH'
assert recorder[0].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1)
body = json.loads(recorder[0].body.decode())
assert body.get('model', {}).get('state', {}).get('published', None) is published
assert body.get('updateMask', {}) == 'state.published'

@pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS)
def test_returns_locked(self, publish_function):
recorder = instrument_mlkit_service(
status=[200, 200],
payload=[OPERATION_NOT_DONE_RESPONSE, LOCKED_MODEL_2_RESPONSE])
expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2)
model = publish_function(MODEL_ID_1)

assert model == expected_model
assert len(recorder) == 2
assert recorder[0].method == 'PATCH'
assert recorder[0].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1)
assert recorder[1].method == 'GET'
assert recorder[1].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1)

@pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS)
def test_operation_error(self, publish_function):
instrument_mlkit_service(status=200, payload=OPERATION_ERROR_RESPONSE)
with pytest.raises(Exception) as excinfo:
publish_function(MODEL_ID_1)
# The http request succeeded, the operation returned contains an update failure
check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG)

@pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS)
def test_malformed_operation(self, publish_function):
recorder = instrument_mlkit_service(
status=[200, 200],
payload=[OPERATION_MALFORMED_RESPONSE, LOCKED_MODEL_2_RESPONSE])
expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2)
model = publish_function(MODEL_ID_1)
assert model == expected_model
assert len(recorder) == 2
assert recorder[0].method == 'PATCH'
assert recorder[0].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1)
assert recorder[1].method == 'GET'
assert recorder[1].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1)

@pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS)
def test_rpc_error(self, publish_function):
create_recorder = instrument_mlkit_service(
status=400, payload=ERROR_RESPONSE_BAD_REQUEST)
with pytest.raises(Exception) as excinfo:
publish_function(MODEL_ID_1)
check_firebase_error(
excinfo,
ERROR_STATUS_BAD_REQUEST,
ERROR_CODE_BAD_REQUEST,
ERROR_MSG_BAD_REQUEST
)
assert len(create_recorder) == 1

class TestGetModel(object):
"""Tests mlkit.get_model."""
@classmethod
Expand Down