Skip to content

Commit 0344172

Browse files
authored
Firebase ML Kit Publish and Unpublish Implementation (#345)
* Firebase ML Kit Publish and Unpublish Implementation
1 parent 2a3be77 commit 0344172

File tree

2 files changed

+135
-6
lines changed

2 files changed

+135
-6
lines changed

firebase_admin/mlkit.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,34 @@ def update_model(model, app=None):
8787
return Model.from_dict(mlkit_service.update_model(model), app=app)
8888

8989

90+
def publish_model(model_id, app=None):
91+
"""Publishes a model in Firebase ML Kit.
92+
93+
Args:
94+
model_id: The id of the model to publish.
95+
app: A Firebase app instance (or None to use the default app).
96+
97+
Returns:
98+
Model: The published model.
99+
"""
100+
mlkit_service = _get_mlkit_service(app)
101+
return Model.from_dict(mlkit_service.set_published(model_id, publish=True), app=app)
102+
103+
104+
def unpublish_model(model_id, app=None):
105+
"""Unpublishes a model in Firebase ML Kit.
106+
107+
Args:
108+
model_id: The id of the model to unpublish.
109+
app: A Firebase app instance (or None to use the default app).
110+
111+
Returns:
112+
Model: The unpublished model.
113+
"""
114+
mlkit_service = _get_mlkit_service(app)
115+
return Model.from_dict(mlkit_service.set_published(model_id, publish=False), app=app)
116+
117+
90118
def get_model(model_id, app=None):
91119
"""Gets a model from Firebase ML Kit.
92120
@@ -562,12 +590,12 @@ class _MLKitService(object):
562590
POLL_BASE_WAIT_TIME_SECONDS = 3
563591

564592
def __init__(self, app):
565-
project_id = app.project_id
566-
if not project_id:
593+
self._project_id = app.project_id
594+
if not self._project_id:
567595
raise ValueError(
568596
'Project ID is required to access MLKit service. Either set the '
569597
'projectId option, or use service account credentials.')
570-
self._project_url = _MLKitService.PROJECT_URL.format(project_id)
598+
self._project_url = _MLKitService.PROJECT_URL.format(self._project_id)
571599
self._client = _http_client.JsonHttpClient(
572600
credential=app.credential.get_credential(),
573601
base_url=self._project_url)
@@ -595,7 +623,6 @@ def _exponential_backoff(self, current_attempt, stop_time):
595623
wait_time_seconds = min(wait_time_seconds, max_seconds_left - 1)
596624
time.sleep(wait_time_seconds)
597625

598-
599626
def handle_operation(self, operation, wait_for_operation=False, max_time_seconds=None):
600627
"""Handles long running operations.
601628
@@ -659,6 +686,17 @@ def update_model(self, model, update_mask=None):
659686
except requests.exceptions.RequestException as error:
660687
raise _utils.handle_platform_error_from_requests(error)
661688

689+
def set_published(self, model_id, publish):
690+
_validate_model_id(model_id)
691+
model_name = 'projects/{0}/models/{1}'.format(self._project_id, model_id)
692+
model = Model.from_dict({
693+
'name': model_name,
694+
'state': {
695+
'published': publish
696+
}
697+
})
698+
return self.update_model(model, update_mask='state.published')
699+
662700
def get_model(self, model_id):
663701
_validate_model_id(model_id)
664702
try:

tests/test_mlkit.py

Lines changed: 93 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -657,7 +657,7 @@ def test_operation_error(self):
657657
instrument_mlkit_service(status=200, payload=OPERATION_ERROR_RESPONSE)
658658
with pytest.raises(Exception) as excinfo:
659659
mlkit.update_model(MODEL_1)
660-
# The http request succeeded, the operation returned contains a create failure
660+
# The http request succeeded, the operation returned contains an update failure
661661
check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG)
662662

663663
def test_malformed_operation(self):
@@ -673,7 +673,7 @@ def test_malformed_operation(self):
673673
assert recorder[1].method == 'GET'
674674
assert recorder[1].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1)
675675

676-
def test_rpc_error_create(self):
676+
def test_rpc_error(self):
677677
create_recorder = instrument_mlkit_service(
678678
status=400, payload=ERROR_RESPONSE_BAD_REQUEST)
679679
with pytest.raises(Exception) as excinfo:
@@ -712,6 +712,97 @@ def test_invalid_op_name(self, op_name):
712712
check_error(excinfo, ValueError, 'Operation name format is invalid.')
713713

714714

715+
class TestPublishUnpublish(object):
716+
"""Tests mlkit.publish_model and mlkit.unpublish_model."""
717+
718+
PUBLISH_UNPUBLISH_WITH_ARGS = [
719+
(mlkit.publish_model, True),
720+
(mlkit.unpublish_model, False)
721+
]
722+
PUBLISH_UNPUBLISH_FUNCS = [item[0] for item in PUBLISH_UNPUBLISH_WITH_ARGS]
723+
724+
@classmethod
725+
def setup_class(cls):
726+
cred = testutils.MockCredential()
727+
firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID})
728+
mlkit._MLKitService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test
729+
730+
@classmethod
731+
def teardown_class(cls):
732+
testutils.cleanup_apps()
733+
734+
@staticmethod
735+
def _url(project_id, model_id):
736+
return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id)
737+
738+
@staticmethod
739+
def _op_url(project_id, model_id):
740+
return BASE_URL + \
741+
'operations/project/{0}/model/{1}/operation/123'.format(project_id, model_id)
742+
743+
@pytest.mark.parametrize('publish_function, published', PUBLISH_UNPUBLISH_WITH_ARGS)
744+
def test_immediate_done(self, publish_function, published):
745+
recorder = instrument_mlkit_service(status=200, payload=OPERATION_DONE_RESPONSE)
746+
model = publish_function(MODEL_ID_1)
747+
assert model == CREATED_UPDATED_MODEL_1
748+
assert len(recorder) == 1
749+
assert recorder[0].method == 'PATCH'
750+
assert recorder[0].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1)
751+
body = json.loads(recorder[0].body.decode())
752+
assert body.get('model', {}).get('state', {}).get('published', None) is published
753+
assert body.get('updateMask', {}) == 'state.published'
754+
755+
@pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS)
756+
def test_returns_locked(self, publish_function):
757+
recorder = instrument_mlkit_service(
758+
status=[200, 200],
759+
payload=[OPERATION_NOT_DONE_RESPONSE, LOCKED_MODEL_2_RESPONSE])
760+
expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2)
761+
model = publish_function(MODEL_ID_1)
762+
763+
assert model == expected_model
764+
assert len(recorder) == 2
765+
assert recorder[0].method == 'PATCH'
766+
assert recorder[0].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1)
767+
assert recorder[1].method == 'GET'
768+
assert recorder[1].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1)
769+
770+
@pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS)
771+
def test_operation_error(self, publish_function):
772+
instrument_mlkit_service(status=200, payload=OPERATION_ERROR_RESPONSE)
773+
with pytest.raises(Exception) as excinfo:
774+
publish_function(MODEL_ID_1)
775+
# The http request succeeded, the operation returned contains an update failure
776+
check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG)
777+
778+
@pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS)
779+
def test_malformed_operation(self, publish_function):
780+
recorder = instrument_mlkit_service(
781+
status=[200, 200],
782+
payload=[OPERATION_MALFORMED_RESPONSE, LOCKED_MODEL_2_RESPONSE])
783+
expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2)
784+
model = publish_function(MODEL_ID_1)
785+
assert model == expected_model
786+
assert len(recorder) == 2
787+
assert recorder[0].method == 'PATCH'
788+
assert recorder[0].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1)
789+
assert recorder[1].method == 'GET'
790+
assert recorder[1].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1)
791+
792+
@pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS)
793+
def test_rpc_error(self, publish_function):
794+
create_recorder = instrument_mlkit_service(
795+
status=400, payload=ERROR_RESPONSE_BAD_REQUEST)
796+
with pytest.raises(Exception) as excinfo:
797+
publish_function(MODEL_ID_1)
798+
check_firebase_error(
799+
excinfo,
800+
ERROR_STATUS_BAD_REQUEST,
801+
ERROR_CODE_BAD_REQUEST,
802+
ERROR_MSG_BAD_REQUEST
803+
)
804+
assert len(create_recorder) == 1
805+
715806
class TestGetModel(object):
716807
"""Tests mlkit.get_model."""
717808
@classmethod

0 commit comments

Comments
 (0)