Skip to content

Commit 7a7a2a5

Browse files
committed
Integration tests for Firebase ML
1 parent cf748c8 commit 7a7a2a5

File tree

3 files changed

+390
-0
lines changed

3 files changed

+390
-0
lines changed

integration/test_ml.py

Lines changed: 389 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,389 @@
1+
# Copyright 2020 Google Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Integration tests for firebase_admin.ml module."""
16+
import re
17+
import pytest
18+
19+
20+
from firebase_admin import ml
21+
from firebase_admin import exceptions
22+
from tests import testutils
23+
24+
25+
# pylint: disable=import-error,no-name-in-module
26+
try:
27+
import tensorflow as tf
28+
import os # This is only needed for the tensorflow testing
29+
import shutil # This is only needed for the tensorflow testing
30+
_TF_ENABLED = True
31+
except ImportError:
32+
_TF_ENABLED = False
33+
34+
35+
@pytest.fixture
36+
def name_only_model():
37+
model = ml.Model(display_name="TestModel123")
38+
yield model
39+
40+
41+
@pytest.fixture
42+
def name_and_tags_model():
43+
model = ml.Model(display_name="TestModel123_tags", tags=['test_tag123'])
44+
yield model
45+
46+
47+
@pytest.fixture
48+
def full_model():
49+
tflite_file_name = testutils.resource_filename('model1.tflite')
50+
source1 = ml.TFLiteGCSModelSource.from_tflite_model_file(tflite_file_name)
51+
format1 = ml.TFLiteFormat(model_source=source1)
52+
model = ml.Model(
53+
display_name="TestModel123_full",
54+
tags=['test_tag567'],
55+
model_format=format1)
56+
yield model
57+
58+
59+
@pytest.fixture
60+
def invalid_full_model():
61+
tflite_file_name = testutils.resource_filename('invalid_model.tflite')
62+
source1 = ml.TFLiteGCSModelSource.from_tflite_model_file(tflite_file_name)
63+
format1 = ml.TFLiteFormat(model_source=source1)
64+
model = ml.Model(
65+
display_name="TestModel123_invalid_full",
66+
tags=['test_tag890'],
67+
model_format=format1)
68+
yield model
69+
70+
71+
# For rpc errors
72+
def check_firebase_error(excinfo, status, msg):
73+
err = excinfo.value
74+
assert isinstance(err, exceptions.FirebaseError)
75+
assert err.cause is not None
76+
assert err.http_response is not None
77+
assert err.http_response.status_code == status
78+
assert str(err) == msg
79+
80+
81+
# For operation errors
82+
def check_operation_error(excinfo, msg):
83+
err = excinfo.value
84+
assert isinstance(err, exceptions.FirebaseError)
85+
assert str(err) == msg
86+
87+
88+
def _ensure_model_exists(model):
89+
# Delete any previously existing model with the same name because
90+
# it may be modified from the model that is passed in.
91+
_delete_if_exists(model)
92+
93+
# And recreate using the model passed in
94+
created_model = ml.create_model(model=model)
95+
return created_model
96+
97+
98+
# Use this when you know the model_id and are sure it exists.
99+
def _clean_up_model(model):
100+
ml.delete_model(model.model_id)
101+
102+
103+
# Use this when you don't know the model_id or it may not exist.
104+
def _delete_if_exists(model):
105+
filter_str = 'displayName={0}'.format(model.display_name)
106+
models_list = ml.list_models(list_filter=filter_str)
107+
for mdl in models_list.models:
108+
ml.delete_model(mdl.model_id)
109+
110+
111+
def test_create_simple_model(name_and_tags_model):
112+
_delete_if_exists(name_and_tags_model)
113+
114+
firebase_model = ml.create_model(model=name_and_tags_model)
115+
assert firebase_model.display_name == name_and_tags_model.display_name
116+
assert firebase_model.tags == name_and_tags_model.tags
117+
assert firebase_model.model_id is not None
118+
assert firebase_model.create_time is not None
119+
assert firebase_model.update_time is not None
120+
assert firebase_model.validation_error == 'No model file has been uploaded.'
121+
assert firebase_model.locked is False
122+
assert firebase_model.published is False
123+
assert firebase_model.etag is not None
124+
assert firebase_model.model_hash is None
125+
126+
_clean_up_model(firebase_model)
127+
128+
def test_create_full_model(full_model):
129+
_delete_if_exists(full_model)
130+
131+
firebase_model = ml.create_model(model=full_model)
132+
assert firebase_model.display_name == full_model.display_name
133+
assert firebase_model.tags == full_model.tags
134+
assert firebase_model.model_format.size_bytes is not None
135+
assert firebase_model.model_format.model_source == full_model.model_format.model_source
136+
assert firebase_model.model_id is not None
137+
assert firebase_model.create_time is not None
138+
assert firebase_model.update_time is not None
139+
assert firebase_model.validation_error is None
140+
assert firebase_model.locked is False
141+
assert firebase_model.published is False
142+
assert firebase_model.etag is not None
143+
assert firebase_model.model_hash is not None
144+
145+
_clean_up_model(firebase_model)
146+
147+
148+
def test_create_already_existing_fails(full_model):
149+
_ensure_model_exists(full_model)
150+
with pytest.raises(exceptions.AlreadyExistsError) as excinfo:
151+
ml.create_model(model=full_model)
152+
check_operation_error(
153+
excinfo,
154+
'Model \'{0}\' already exists'.format(full_model.display_name))
155+
156+
157+
def test_create_invalid_model(invalid_full_model):
158+
_delete_if_exists(invalid_full_model)
159+
160+
firebase_model = ml.create_model(model=invalid_full_model)
161+
assert firebase_model.display_name == invalid_full_model.display_name
162+
assert firebase_model.tags == invalid_full_model.tags
163+
assert firebase_model.model_format.size_bytes is None
164+
assert firebase_model.model_format.model_source == invalid_full_model.model_format.model_source
165+
assert firebase_model.model_id is not None
166+
assert firebase_model.create_time is not None
167+
assert firebase_model.update_time is not None
168+
assert firebase_model.validation_error == 'Invalid flatbuffer format'
169+
assert firebase_model.locked is False
170+
assert firebase_model.published is False
171+
assert firebase_model.etag is not None
172+
assert firebase_model.model_hash is None
173+
174+
_clean_up_model(firebase_model)
175+
176+
def test_get_model(name_only_model):
177+
existing_model = _ensure_model_exists(name_only_model)
178+
179+
firebase_model = ml.get_model(existing_model.model_id)
180+
assert firebase_model.display_name == name_only_model.display_name
181+
assert firebase_model.model_id is not None
182+
assert firebase_model.create_time is not None
183+
assert firebase_model.update_time is not None
184+
assert firebase_model.validation_error == 'No model file has been uploaded.'
185+
assert firebase_model.etag is not None
186+
assert firebase_model.locked is False
187+
assert firebase_model.published is False
188+
assert firebase_model.model_hash is None
189+
190+
_clean_up_model(firebase_model)
191+
192+
193+
def test_get_non_existing_model(name_only_model):
194+
# Get a valid model_id that no longer exists
195+
model = _ensure_model_exists(name_only_model)
196+
ml.delete_model(model.model_id)
197+
198+
with pytest.raises(exceptions.NotFoundError) as excinfo:
199+
ml.get_model(model.model_id)
200+
check_firebase_error(excinfo, 404, 'Requested entity was not found.')
201+
202+
203+
def test_update_model(name_only_model):
204+
new_model_name = 'TestModel123_updated'
205+
_delete_if_exists(ml.Model(display_name=new_model_name))
206+
existing_model = _ensure_model_exists(name_only_model)
207+
existing_model.display_name = new_model_name
208+
209+
firebase_model = ml.update_model(existing_model)
210+
assert firebase_model.display_name == new_model_name
211+
assert firebase_model.model_id == existing_model.model_id
212+
assert firebase_model.create_time == existing_model.create_time
213+
assert firebase_model.update_time != existing_model.update_time
214+
assert firebase_model.validation_error == existing_model.validation_error
215+
assert firebase_model.etag != existing_model.etag
216+
assert firebase_model.published == existing_model.published
217+
assert firebase_model.locked == existing_model.locked
218+
219+
# Second call with same model does not cause error
220+
firebase_model2 = ml.update_model(firebase_model)
221+
assert firebase_model2.display_name == firebase_model.display_name
222+
assert firebase_model2.model_id == firebase_model.model_id
223+
assert firebase_model2.create_time == firebase_model.create_time
224+
assert firebase_model2.update_time != firebase_model.update_time
225+
assert firebase_model2.validation_error == firebase_model.validation_error
226+
assert firebase_model2.etag != existing_model.etag
227+
assert firebase_model2.published == firebase_model.published
228+
assert firebase_model2.locked == firebase_model.locked
229+
230+
_clean_up_model(firebase_model)
231+
232+
233+
def test_update_non_existing_model(name_only_model):
234+
model = _ensure_model_exists(name_only_model)
235+
ml.delete_model(model.model_id)
236+
237+
model.tags = ['tag987']
238+
with pytest.raises(exceptions.NotFoundError) as excinfo:
239+
ml.update_model(model)
240+
check_operation_error(
241+
excinfo,
242+
'Model \'{0}\' was not found'.format(model.as_dict().get('name')))
243+
244+
def test_publish_unpublish_model(full_model):
245+
model = _ensure_model_exists(full_model)
246+
assert model.published is False
247+
248+
published_model = ml.publish_model(model.model_id)
249+
assert published_model.published is True
250+
251+
unpublished_model = ml.unpublish_model(published_model.model_id)
252+
assert unpublished_model.published is False
253+
254+
_clean_up_model(unpublished_model)
255+
256+
257+
def test_publish_invalid_fails(name_only_model):
258+
model = _ensure_model_exists(name_only_model)
259+
assert model.validation_error is not None
260+
261+
with pytest.raises(exceptions.FailedPreconditionError) as excinfo:
262+
ml.publish_model(model.model_id)
263+
check_operation_error(
264+
excinfo,
265+
'Cannot publish a model that is not verified.')
266+
267+
268+
def test_publish_unpublish_non_existing_model(full_model):
269+
model = _ensure_model_exists(full_model)
270+
ml.delete_model(model.model_id)
271+
272+
with pytest.raises(exceptions.NotFoundError) as excinfo:
273+
ml.publish_model(model.model_id)
274+
check_operation_error(
275+
excinfo,
276+
'Model \'{0}\' was not found'.format(model.as_dict().get('name')))
277+
278+
with pytest.raises(exceptions.NotFoundError) as excinfo:
279+
ml.unpublish_model(model.model_id)
280+
check_operation_error(
281+
excinfo,
282+
'Model \'{0}\' was not found'.format(model.as_dict().get('name')))
283+
284+
285+
def test_list_models(name_only_model, name_and_tags_model):
286+
existing_model1 = _ensure_model_exists(name_only_model)
287+
existing_model2 = _ensure_model_exists(name_and_tags_model)
288+
filter_str = 'displayName={0} OR tags:{1}'.format(
289+
existing_model1.display_name, existing_model2.tags[0])
290+
291+
models_list = ml.list_models(list_filter=filter_str)
292+
assert len(models_list.models) == 2
293+
for mdl in models_list.models:
294+
assert mdl == existing_model1 or mdl == existing_model2
295+
assert models_list.models[0] != models_list.models[1]
296+
297+
_clean_up_model(existing_model1)
298+
_clean_up_model(existing_model2)
299+
300+
301+
def test_list_models_invalid_filter():
302+
invalid_filter = 'InvalidFilterParam=123'
303+
304+
with pytest.raises(exceptions.InvalidArgumentError) as excinfo:
305+
ml.list_models(list_filter=invalid_filter)
306+
check_firebase_error(excinfo, 400, 'Request contains an invalid argument.')
307+
308+
309+
def test_delete_model(name_only_model):
310+
existing_model = _ensure_model_exists(name_only_model)
311+
312+
ml.delete_model(existing_model.model_id)
313+
314+
# Second delete of same model will fail
315+
with pytest.raises(exceptions.NotFoundError) as excinfo:
316+
ml.delete_model(existing_model.model_id)
317+
check_firebase_error(excinfo, 404, 'Requested entity was not found.')
318+
319+
320+
#'pip install tensorflow' in the environment if you want _TF_ENABLED = True
321+
#'pip install tensorflow=2.0.0' for version 2 etc.
322+
if _TF_ENABLED:
323+
# Test tensor flow conversion functions if tensor flow is enabled.
324+
SAVED_MODEL_DIR = '/tmp/saved_model/1'
325+
326+
def _clean_up_tmp_directory():
327+
if os.path.exists(SAVED_MODEL_DIR):
328+
shutil.rmtree(SAVED_MODEL_DIR)
329+
330+
@pytest.fixture
331+
def keras_model():
332+
x_array = [-1, 0, 1, 2, 3, 4]
333+
y_array = [-3, -1, 1, 3, 5, 7]
334+
model = tf.keras.models.Sequential(
335+
[tf.keras.layers.Dense(units=1, input_shape=[1])])
336+
model.compile(optimizer='sgd', loss='mean_squared_error')
337+
model.fit(x_array, y_array, epochs=3)
338+
yield model
339+
340+
@pytest.fixture
341+
def saved_model_dir(keras_model):
342+
# different versions have different model conversion capability
343+
# pick something that works for each version
344+
save_dir = SAVED_MODEL_DIR
345+
_clean_up_tmp_directory() # previous failures may leave files
346+
if tf.version.VERSION.startswith('1.'):
347+
tf.reset_default_graph()
348+
x_var = tf.placeholder(tf.float32, (None, 3), name="x")
349+
y_var = tf.multiply(x_var, x_var, name="y")
350+
with tf.Session() as sess:
351+
tf.saved_model.simple_save(sess, save_dir, {"x": x_var}, {"y": y_var})
352+
else:
353+
# If it's not version 1.x or version 2.x we need to update the test.
354+
assert tf.version.VERSION.startswith('2.')
355+
tf.saved_model.save(keras_model, save_dir)
356+
yield save_dir
357+
358+
359+
def test_from_keras_model(keras_model):
360+
source1 = ml.TFLiteGCSModelSource.from_keras_model(keras_model, 'model2.tflite')
361+
assert re.search(
362+
'^gs://.*/Firebase/ML/Models/model2.tflite$',
363+
source1.gcs_tflite_uri) is not None
364+
365+
# Validate the conversion by creating a model
366+
format1 = ml.TFLiteFormat(model_source=source1)
367+
model1 = ml.Model(display_name="KerasModel1", model_format=format1)
368+
firebase_model = ml.create_model(model1)
369+
assert firebase_model.model_id is not None
370+
assert firebase_model.validation_error is None
371+
372+
_clean_up_model(firebase_model)
373+
374+
def test_from_saved_model(saved_model_dir):
375+
# Test the conversion helper
376+
source1 = ml.TFLiteGCSModelSource.from_saved_model(saved_model_dir, 'model3.tflite')
377+
assert re.search(
378+
'^gs://.*/Firebase/ML/Models/model3.tflite$',
379+
source1.gcs_tflite_uri) is not None
380+
381+
# Validate the conversion by creating a model
382+
format1 = ml.TFLiteFormat(model_source=source1)
383+
model1 = ml.Model(display_name="SavedModel1", model_format=format1)
384+
firebase_model = ml.create_model(model1)
385+
assert firebase_model.model_id is not None
386+
assert firebase_model.validation_error is None
387+
388+
_clean_up_model(firebase_model)
389+
_clean_up_tmp_directory()

tests/data/invalid_model.tflite

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
This is not a tflite file.

tests/data/model1.tflite

736 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)