Skip to content

Commit 65f64c0

Browse files
authored
Firebase ML Kit Get Model API implementation (#326)
* added GetModel * Added tests for get_model
1 parent dd3c4bd commit 65f64c0

File tree

3 files changed

+205
-3
lines changed

3 files changed

+205
-3
lines changed

.travis.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ before_install:
1616
- nvm install 8 && npm install -g firebase-tools
1717
script:
1818
- pytest
19-
- firebase emulators:exec --only database --project fake-project-id 'pytest integration/test_db.py'
2019
cache:
2120
pip: true
2221
npm: true

firebase_admin/mlkit.py

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,77 @@
1818
deleting, publishing and unpublishing Firebase ML Kit models.
1919
"""
2020

21+
import re
22+
import requests
23+
import six
24+
25+
from firebase_admin import _http_client
26+
from firebase_admin import _utils
27+
28+
29+
_MLKIT_ATTRIBUTE = '_mlkit'
30+
31+
32+
def _get_mlkit_service(app):
33+
""" Returns an _MLKitService instance for an App.
34+
35+
Args:
36+
app: A Firebase App instance (or None to use the default App).
37+
38+
Returns:
39+
_MLKitService: An _MLKitService for the specified App instance.
40+
41+
Raises:
42+
ValueError: If the app argument is invalid.
43+
"""
44+
return _utils.get_app_service(app, _MLKIT_ATTRIBUTE, _MLKitService)
45+
46+
47+
def get_model(model_id, app=None):
48+
mlkit_service = _get_mlkit_service(app)
49+
return Model(mlkit_service.get_model(model_id))
50+
51+
52+
class Model(object):
53+
"""A Firebase ML Kit Model object."""
54+
def __init__(self, data):
55+
"""Created from a data dictionary."""
56+
self._data = data
57+
58+
def __eq__(self, other):
59+
if isinstance(other, self.__class__):
60+
return self._data == other._data # pylint: disable=protected-access
61+
else:
62+
return False
63+
64+
def __ne__(self, other):
65+
return not self.__eq__(other)
66+
67+
#TODO(ifielker): define the Model properties etc
68+
69+
2170
class _MLKitService(object):
2271
"""Firebase MLKit service."""
2372

24-
BASE_URL = 'https://mlkit.googleapis.com'
25-
PROJECT_URL = 'https://mlkit.googleapis.com/projects/{0}/'
73+
PROJECT_URL = 'https://mlkit.googleapis.com/v1beta1/projects/{0}/'
74+
75+
def __init__(self, app):
76+
project_id = app.project_id
77+
if not project_id:
78+
raise ValueError(
79+
'Project ID is required to access MLKit service. Either set the '
80+
'projectId option, or use service account credentials.')
81+
self._project_url = _MLKitService.PROJECT_URL.format(project_id)
82+
self._client = _http_client.JsonHttpClient(
83+
credential=app.credential.get_credential(),
84+
base_url=self._project_url)
85+
86+
def get_model(self, model_id):
87+
if not isinstance(model_id, six.string_types):
88+
raise TypeError('Model ID must be a string.')
89+
if not re.match(r'^[A-Za-z0-9_-]{1,60}$', model_id):
90+
raise ValueError('Model ID format is invalid.')
91+
try:
92+
return self._client.body('get', url='models/{0}'.format(model_id))
93+
except requests.exceptions.RequestException as error:
94+
raise _utils.handle_platform_error_from_requests(error)

tests/test_mlkit.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# Copyright 2019 Google Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Test cases for the firebase_admin.mlkit module."""
16+
17+
import json
18+
import pytest
19+
20+
import firebase_admin
21+
from firebase_admin import exceptions
22+
from firebase_admin import mlkit
23+
from tests import testutils
24+
25+
BASE_URL = 'https://mlkit.googleapis.com/v1beta1/'
26+
27+
PROJECT_ID = 'myProject1'
28+
MODEL_ID_1 = 'modelId1'
29+
MODEL_NAME_1 = 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_1)
30+
DISPLAY_NAME_1 = 'displayName1'
31+
MODEL_JSON_1 = {
32+
'name': MODEL_NAME_1,
33+
'displayName': DISPLAY_NAME_1
34+
}
35+
MODEL_1 = mlkit.Model(MODEL_JSON_1)
36+
_DEFAULT_RESPONSE = json.dumps(MODEL_JSON_1)
37+
38+
ERROR_CODE = 404
39+
ERROR_MSG = 'The resource was not found'
40+
ERROR_STATUS = 'NOT_FOUND'
41+
ERROR_JSON = {
42+
'error': {
43+
'code': ERROR_CODE,
44+
'message': ERROR_MSG,
45+
'status': ERROR_STATUS
46+
}
47+
}
48+
_ERROR_RESPONSE = json.dumps(ERROR_JSON)
49+
50+
51+
class TestGetModel(object):
52+
"""Tests mlkit.get_model."""
53+
@classmethod
54+
def setup_class(cls):
55+
cred = testutils.MockCredential()
56+
firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID})
57+
58+
@classmethod
59+
def teardown_class(cls):
60+
testutils.cleanup_apps()
61+
62+
@staticmethod
63+
def check_error(err, err_type, msg):
64+
assert isinstance(err, err_type)
65+
assert str(err) == msg
66+
67+
@staticmethod
68+
def check_firebase_error(err, code, status, msg):
69+
assert isinstance(err, exceptions.FirebaseError)
70+
assert err.code == code
71+
assert err.http_response is not None
72+
assert err.http_response.status_code == status
73+
assert str(err) == msg
74+
75+
def _get_url(self, project_id, model_id):
76+
return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id)
77+
78+
def _instrument_mlkit_service(self, app=None, status=200, payload=_DEFAULT_RESPONSE):
79+
if not app:
80+
app = firebase_admin.get_app()
81+
mlkit_service = mlkit._get_mlkit_service(app)
82+
recorder = []
83+
mlkit_service._client.session.mount(
84+
'https://mlkit.googleapis.com',
85+
testutils.MockAdapter(payload, status, recorder)
86+
)
87+
return mlkit_service, recorder
88+
89+
def test_get_model(self):
90+
_, recorder = self._instrument_mlkit_service()
91+
model = mlkit.get_model(MODEL_ID_1)
92+
assert len(recorder) == 1
93+
assert recorder[0].method == 'GET'
94+
assert recorder[0].url == self._get_url(PROJECT_ID, MODEL_ID_1)
95+
assert model == MODEL_1
96+
assert model._data['name'] == MODEL_NAME_1
97+
assert model._data['displayName'] == DISPLAY_NAME_1
98+
99+
def test_get_model_validation_errors(self):
100+
#Empty model-id
101+
with pytest.raises(ValueError) as err:
102+
mlkit.get_model('')
103+
self.check_error(err.value, ValueError, 'Model ID format is invalid.')
104+
105+
#None model-id
106+
with pytest.raises(TypeError) as err:
107+
mlkit.get_model(None)
108+
self.check_error(err.value, TypeError, 'Model ID must be a string.')
109+
110+
#Wrong type
111+
with pytest.raises(TypeError) as err:
112+
mlkit.get_model(12345)
113+
self.check_error(err.value, TypeError, 'Model ID must be a string.')
114+
115+
#Invalid characters
116+
with pytest.raises(ValueError) as err:
117+
mlkit.get_model('&_*#@:/?')
118+
self.check_error(err.value, ValueError, 'Model ID format is invalid.')
119+
120+
def test_get_model_error(self):
121+
_, recorder = self._instrument_mlkit_service(status=404, payload=_ERROR_RESPONSE)
122+
with pytest.raises(exceptions.NotFoundError) as err:
123+
mlkit.get_model(MODEL_ID_1)
124+
self.check_firebase_error(err.value, ERROR_STATUS, ERROR_CODE, ERROR_MSG)
125+
assert len(recorder) == 1
126+
assert recorder[0].method == 'GET'
127+
assert recorder[0].url == self._get_url(PROJECT_ID, MODEL_ID_1)
128+
129+
def test_no_project_id(self):
130+
def evaluate():
131+
app = firebase_admin.initialize_app(testutils.MockCredential(), name='no_project_id')
132+
with pytest.raises(ValueError):
133+
mlkit.get_model(MODEL_ID_1, app)
134+
testutils.run_without_project_id(evaluate)

0 commit comments

Comments
 (0)