Skip to content

Commit 0b70687

Browse files
authored
Integration tests for Firebase ML (#394)
* Integration tests for Firebase ML
1 parent cf748c8 commit 0b70687

File tree

3 files changed

+374
-0
lines changed

3 files changed

+374
-0
lines changed

integration/test_ml.py

Lines changed: 373 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,373 @@
1+
# Copyright 2020 Google Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Integration tests for firebase_admin.ml module."""
16+
import os
17+
import random
18+
import re
19+
import shutil
20+
import string
21+
import tempfile
22+
import pytest
23+
24+
25+
from firebase_admin import exceptions
26+
from firebase_admin import ml
27+
from tests import testutils
28+
29+
30+
# pylint: disable=import-error,no-name-in-module
31+
try:
32+
import tensorflow as tf
33+
_TF_ENABLED = True
34+
except ImportError:
35+
_TF_ENABLED = False
36+
37+
38+
def _random_identifier(prefix):
39+
#pylint: disable=unused-variable
40+
suffix = ''.join([random.choice(string.ascii_letters + string.digits) for n in range(8)])
41+
return '{0}_{1}'.format(prefix, suffix)
42+
43+
44+
NAME_ONLY_ARGS = {
45+
'display_name': _random_identifier('TestModel123_')
46+
}
47+
NAME_ONLY_ARGS_UPDATED = {
48+
'display_name': _random_identifier('TestModel123_updated_')
49+
}
50+
NAME_AND_TAGS_ARGS = {
51+
'display_name': _random_identifier('TestModel123_tags_'),
52+
'tags': ['test_tag123']
53+
}
54+
FULL_MODEL_ARGS = {
55+
'display_name': _random_identifier('TestModel123_full_'),
56+
'tags': ['test_tag567'],
57+
'file_name': 'model1.tflite'
58+
}
59+
INVALID_FULL_MODEL_ARGS = {
60+
'display_name': _random_identifier('TestModel123_invalid_full_'),
61+
'tags': ['test_tag890'],
62+
'file_name': 'invalid_model.tflite'
63+
}
64+
65+
66+
@pytest.fixture
67+
def firebase_model(request):
68+
args = request.param
69+
tflite_format = None
70+
file_name = args.get('file_name')
71+
if file_name:
72+
file_path = testutils.resource_filename(file_name)
73+
source = ml.TFLiteGCSModelSource.from_tflite_model_file(file_path)
74+
tflite_format = ml.TFLiteFormat(model_source=source)
75+
76+
ml_model = ml.Model(
77+
display_name=args.get('display_name'),
78+
tags=args.get('tags'),
79+
model_format=tflite_format)
80+
model = ml.create_model(model=ml_model)
81+
yield model
82+
_clean_up_model(model)
83+
84+
85+
@pytest.fixture
86+
def model_list():
87+
ml_model_1 = ml.Model(display_name=_random_identifier('TestModel123_list1_'))
88+
model_1 = ml.create_model(model=ml_model_1)
89+
90+
ml_model_2 = ml.Model(display_name=_random_identifier('TestModel123_list2_'),
91+
tags=['test_tag123'])
92+
model_2 = ml.create_model(model=ml_model_2)
93+
94+
yield [model_1, model_2]
95+
96+
_clean_up_model(model_1)
97+
_clean_up_model(model_2)
98+
99+
100+
def _clean_up_model(model):
101+
try:
102+
# Try to delete the model.
103+
# Some tests delete the model as part of the test.
104+
ml.delete_model(model.model_id)
105+
except exceptions.NotFoundError:
106+
pass
107+
108+
109+
# For rpc errors
110+
def check_firebase_error(excinfo, status, msg):
111+
err = excinfo.value
112+
assert isinstance(err, exceptions.FirebaseError)
113+
assert err.cause is not None
114+
assert err.http_response is not None
115+
assert err.http_response.status_code == status
116+
assert str(err) == msg
117+
118+
119+
# For operation errors
120+
def check_operation_error(excinfo, msg):
121+
err = excinfo.value
122+
assert isinstance(err, exceptions.FirebaseError)
123+
assert str(err) == msg
124+
125+
126+
def check_model(model, args):
127+
assert model.display_name == args.get('display_name')
128+
assert model.tags == args.get('tags')
129+
assert model.model_id is not None
130+
assert model.create_time is not None
131+
assert model.update_time is not None
132+
assert model.locked is False
133+
assert model.etag is not None
134+
135+
136+
def check_model_format(model, has_model_format=False, validation_error=None):
137+
if has_model_format:
138+
assert model.validation_error == validation_error
139+
assert model.published is False
140+
assert model.model_format.model_source.gcs_tflite_uri.startswith('gs://')
141+
if validation_error:
142+
assert model.model_format.size_bytes is None
143+
assert model.model_hash is None
144+
else:
145+
assert model.model_format.size_bytes is not None
146+
assert model.model_hash is not None
147+
else:
148+
assert model.model_format is None
149+
assert model.validation_error == 'No model file has been uploaded.'
150+
assert model.published is False
151+
assert model.model_hash is None
152+
153+
154+
@pytest.mark.parametrize('firebase_model', [NAME_AND_TAGS_ARGS], indirect=True)
155+
def test_create_simple_model(firebase_model):
156+
check_model(firebase_model, NAME_AND_TAGS_ARGS)
157+
check_model_format(firebase_model)
158+
159+
160+
@pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True)
161+
def test_create_full_model(firebase_model):
162+
check_model(firebase_model, FULL_MODEL_ARGS)
163+
check_model_format(firebase_model, True)
164+
165+
166+
@pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True)
167+
def test_create_already_existing_fails(firebase_model):
168+
with pytest.raises(exceptions.AlreadyExistsError) as excinfo:
169+
ml.create_model(model=firebase_model)
170+
check_operation_error(
171+
excinfo,
172+
'Model \'{0}\' already exists'.format(firebase_model.display_name))
173+
174+
175+
@pytest.mark.parametrize('firebase_model', [INVALID_FULL_MODEL_ARGS], indirect=True)
176+
def test_create_invalid_model(firebase_model):
177+
check_model(firebase_model, INVALID_FULL_MODEL_ARGS)
178+
check_model_format(firebase_model, True, 'Invalid flatbuffer format')
179+
180+
181+
@pytest.mark.parametrize('firebase_model', [NAME_AND_TAGS_ARGS], indirect=True)
182+
def test_get_model(firebase_model):
183+
get_model = ml.get_model(firebase_model.model_id)
184+
check_model(get_model, NAME_AND_TAGS_ARGS)
185+
check_model_format(get_model)
186+
187+
188+
@pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True)
189+
def test_get_non_existing_model(firebase_model):
190+
# Get a valid model_id that no longer exists
191+
ml.delete_model(firebase_model.model_id)
192+
193+
with pytest.raises(exceptions.NotFoundError) as excinfo:
194+
ml.get_model(firebase_model.model_id)
195+
check_firebase_error(excinfo, 404, 'Requested entity was not found.')
196+
197+
198+
@pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True)
199+
def test_update_model(firebase_model):
200+
new_model_name = NAME_ONLY_ARGS_UPDATED.get('display_name')
201+
firebase_model.display_name = new_model_name
202+
updated_model = ml.update_model(firebase_model)
203+
check_model(updated_model, NAME_ONLY_ARGS_UPDATED)
204+
check_model_format(updated_model)
205+
206+
# Second call with same model does not cause error
207+
updated_model2 = ml.update_model(updated_model)
208+
check_model(updated_model2, NAME_ONLY_ARGS_UPDATED)
209+
check_model_format(updated_model2)
210+
211+
212+
@pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True)
213+
def test_update_non_existing_model(firebase_model):
214+
ml.delete_model(firebase_model.model_id)
215+
216+
firebase_model.tags = ['tag987']
217+
with pytest.raises(exceptions.NotFoundError) as excinfo:
218+
ml.update_model(firebase_model)
219+
check_operation_error(
220+
excinfo,
221+
'Model \'{0}\' was not found'.format(firebase_model.as_dict().get('name')))
222+
223+
224+
@pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True)
225+
def test_publish_unpublish_model(firebase_model):
226+
assert firebase_model.published is False
227+
228+
published_model = ml.publish_model(firebase_model.model_id)
229+
assert published_model.published is True
230+
231+
unpublished_model = ml.unpublish_model(published_model.model_id)
232+
assert unpublished_model.published is False
233+
234+
235+
@pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True)
236+
def test_publish_invalid_fails(firebase_model):
237+
assert firebase_model.validation_error is not None
238+
239+
with pytest.raises(exceptions.FailedPreconditionError) as excinfo:
240+
ml.publish_model(firebase_model.model_id)
241+
check_operation_error(
242+
excinfo,
243+
'Cannot publish a model that is not verified.')
244+
245+
246+
@pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True)
247+
def test_publish_unpublish_non_existing_model(firebase_model):
248+
ml.delete_model(firebase_model.model_id)
249+
250+
with pytest.raises(exceptions.NotFoundError) as excinfo:
251+
ml.publish_model(firebase_model.model_id)
252+
check_operation_error(
253+
excinfo,
254+
'Model \'{0}\' was not found'.format(firebase_model.as_dict().get('name')))
255+
256+
with pytest.raises(exceptions.NotFoundError) as excinfo:
257+
ml.unpublish_model(firebase_model.model_id)
258+
check_operation_error(
259+
excinfo,
260+
'Model \'{0}\' was not found'.format(firebase_model.as_dict().get('name')))
261+
262+
263+
def test_list_models(model_list):
264+
filter_str = 'displayName={0} OR tags:{1}'.format(
265+
model_list[0].display_name, model_list[1].tags[0])
266+
267+
all_models = ml.list_models(list_filter=filter_str)
268+
all_model_ids = [mdl.model_id for mdl in all_models.iterate_all()]
269+
for mdl in model_list:
270+
assert mdl.model_id in all_model_ids
271+
272+
273+
def test_list_models_invalid_filter():
274+
invalid_filter = 'InvalidFilterParam=123'
275+
276+
with pytest.raises(exceptions.InvalidArgumentError) as excinfo:
277+
ml.list_models(list_filter=invalid_filter)
278+
check_firebase_error(excinfo, 400, 'Request contains an invalid argument.')
279+
280+
281+
@pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True)
282+
def test_delete_model(firebase_model):
283+
ml.delete_model(firebase_model.model_id)
284+
285+
# Second delete of same model will fail
286+
with pytest.raises(exceptions.NotFoundError) as excinfo:
287+
ml.delete_model(firebase_model.model_id)
288+
check_firebase_error(excinfo, 404, 'Requested entity was not found.')
289+
290+
291+
# Test tensor flow conversion functions if tensor flow is enabled.
292+
#'pip install tensorflow' in the environment if you want _TF_ENABLED = True
293+
#'pip install tensorflow==2.0.0b' for version 2 etc.
294+
295+
296+
def _clean_up_directory(save_dir):
297+
if save_dir.startswith(tempfile.gettempdir()) and os.path.exists(save_dir):
298+
shutil.rmtree(save_dir)
299+
300+
301+
@pytest.fixture
302+
def keras_model():
303+
assert _TF_ENABLED
304+
x_array = [-1, 0, 1, 2, 3, 4]
305+
y_array = [-3, -1, 1, 3, 5, 7]
306+
model = tf.keras.models.Sequential(
307+
[tf.keras.layers.Dense(units=1, input_shape=[1])])
308+
model.compile(optimizer='sgd', loss='mean_squared_error')
309+
model.fit(x_array, y_array, epochs=3)
310+
return model
311+
312+
313+
@pytest.fixture
314+
def saved_model_dir(keras_model):
315+
assert _TF_ENABLED
316+
# Make a new parent directory. The child directory must not exist yet.
317+
# The child directory gets created by tf. If it exists, the tf call fails.
318+
parent = tempfile.mkdtemp()
319+
save_dir = os.path.join(parent, 'child')
320+
321+
# different versions have different model conversion capability
322+
# pick something that works for each version
323+
if tf.version.VERSION.startswith('1.'):
324+
tf.reset_default_graph()
325+
x_var = tf.placeholder(tf.float32, (None, 3), name="x")
326+
y_var = tf.multiply(x_var, x_var, name="y")
327+
with tf.Session() as sess:
328+
tf.saved_model.simple_save(sess, save_dir, {"x": x_var}, {"y": y_var})
329+
else:
330+
# If it's not version 1.x or version 2.x we need to update the test.
331+
assert tf.version.VERSION.startswith('2.')
332+
tf.saved_model.save(keras_model, save_dir)
333+
yield save_dir
334+
_clean_up_directory(parent)
335+
336+
337+
@pytest.mark.skipif(not _TF_ENABLED, reason='Tensor flow is required for this test.')
338+
def test_from_keras_model(keras_model):
339+
source = ml.TFLiteGCSModelSource.from_keras_model(keras_model, 'model2.tflite')
340+
assert re.search(
341+
'^gs://.*/Firebase/ML/Models/model2.tflite$',
342+
source.gcs_tflite_uri) is not None
343+
344+
# Validate the conversion by creating a model
345+
model_format = ml.TFLiteFormat(model_source=source)
346+
model = ml.Model(display_name=_random_identifier('KerasModel_'), model_format=model_format)
347+
created_model = ml.create_model(model)
348+
349+
try:
350+
check_model(created_model, {'display_name': model.display_name})
351+
check_model_format(created_model, True)
352+
finally:
353+
_clean_up_model(created_model)
354+
355+
356+
@pytest.mark.skipif(not _TF_ENABLED, reason='Tensor flow is required for this test.')
357+
def test_from_saved_model(saved_model_dir):
358+
# Test the conversion helper
359+
source = ml.TFLiteGCSModelSource.from_saved_model(saved_model_dir, 'model3.tflite')
360+
assert re.search(
361+
'^gs://.*/Firebase/ML/Models/model3.tflite$',
362+
source.gcs_tflite_uri) is not None
363+
364+
# Validate the conversion by creating a model
365+
model_format = ml.TFLiteFormat(model_source=source)
366+
model = ml.Model(display_name=_random_identifier('SavedModel_'), model_format=model_format)
367+
created_model = ml.create_model(model)
368+
369+
try:
370+
assert created_model.model_id is not None
371+
assert created_model.validation_error is None
372+
finally:
373+
_clean_up_model(created_model)

tests/data/invalid_model.tflite

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
This is not a tflite file.

tests/data/model1.tflite

736 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)