Skip to content

Commit ae9527f

Browse files
committed
review suggestions #3
1 parent 0bc734c commit ae9527f

File tree

1 file changed

+34
-22
lines changed

1 file changed

+34
-22
lines changed

integration/test_ml.py

+34-22
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,17 @@
1313
# limitations under the License.
1414

1515
"""Integration tests for firebase_admin.ml module."""
16-
import re
1716
import os
18-
import shutil
1917
import random
18+
import re
19+
import shutil
20+
import string
2021
import tempfile
2122
import pytest
2223

2324

24-
from firebase_admin import ml
2525
from firebase_admin import exceptions
26+
from firebase_admin import ml
2627
from tests import testutils
2728

2829

@@ -34,27 +35,34 @@
3435
_TF_ENABLED = False
3536

3637

38+
def _random_identifier(prefix):
39+
#pylint: disable=unused-variable
40+
suffix = ''.join([random.choice(string.ascii_letters + string.digits) for n in range(8)])
41+
return '{0}_{1}'.format(prefix, suffix)
42+
43+
3744
NAME_ONLY_ARGS = {
38-
'display_name': 'TestModel123_{0}'.format(random.randint(1111, 9999))
45+
'display_name': _random_identifier('TestModel123_')
3946
}
4047
NAME_ONLY_ARGS_UPDATED = {
41-
'display_name': 'TestModel123_updated_{0}'.format(random.randint(1111, 9999))
48+
'display_name': _random_identifier('TestModel123_updated_')
4249
}
4350
NAME_AND_TAGS_ARGS = {
44-
'display_name': 'TestModel123_tags_{0}'.format(random.randint(1111, 9999)),
51+
'display_name': _random_identifier('TestModel123_tags_'),
4552
'tags': ['test_tag123']
4653
}
4754
FULL_MODEL_ARGS = {
48-
'display_name': 'TestModel123_full_{0}'.format(random.randint(1111, 9999)),
55+
'display_name': _random_identifier('TestModel123_full_'),
4956
'tags': ['test_tag567'],
5057
'file_name': 'model1.tflite'
5158
}
5259
INVALID_FULL_MODEL_ARGS = {
53-
'display_name': 'TestModel123_invalid_full_{0}'.format(random.randint(1111, 9999)),
60+
'display_name': _random_identifier('TestModel123_invalid_full_'),
5461
'tags': ['test_tag890'],
5562
'file_name': 'invalid_model.tflite'
5663
}
5764

65+
5866
@pytest.fixture
5967
def firebase_model(request):
6068
args = request.param
@@ -76,10 +84,11 @@ def firebase_model(request):
7684

7785
@pytest.fixture
7886
def model_list():
79-
ml_model_1 = ml.Model(display_name="TestModel123")
87+
ml_model_1 = ml.Model(display_name=_random_identifier('TestModel123_list1_'))
8088
model_1 = ml.create_model(model=ml_model_1)
8189

82-
ml_model_2 = ml.Model(display_name="TestModel123_tags", tags=['test_tag123'])
90+
ml_model_2 = ml.Model(display_name=_random_identifier('TestModel123_list2_'),
91+
tags=['test_tag123'])
8392
model_2 = ml.create_model(model=ml_model_2)
8493

8594
yield [model_1, model_2]
@@ -124,7 +133,7 @@ def check_model(model, args):
124133
assert model.etag is not None
125134

126135

127-
def check_model_format(model, has_model_format, validation_error):
136+
def check_model_format(model, has_model_format=False, validation_error=None):
128137
if has_model_format:
129138
assert model.validation_error == validation_error
130139
assert model.published is False
@@ -145,13 +154,13 @@ def check_model_format(model, has_model_format, validation_error):
145154
@pytest.mark.parametrize('firebase_model', [NAME_AND_TAGS_ARGS], indirect=True)
146155
def test_create_simple_model(firebase_model):
147156
check_model(firebase_model, NAME_AND_TAGS_ARGS)
148-
check_model_format(firebase_model, False, None)
157+
check_model_format(firebase_model)
149158

150159

151160
@pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True)
152161
def test_create_full_model(firebase_model):
153162
check_model(firebase_model, FULL_MODEL_ARGS)
154-
check_model_format(firebase_model, True, None)
163+
check_model_format(firebase_model, True)
155164

156165

157166
@pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True)
@@ -173,7 +182,7 @@ def test_create_invalid_model(firebase_model):
173182
def test_get_model(firebase_model):
174183
get_model = ml.get_model(firebase_model.model_id)
175184
check_model(get_model, NAME_AND_TAGS_ARGS)
176-
check_model_format(get_model, False, None)
185+
check_model_format(get_model)
177186

178187

179188
@pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True)
@@ -192,12 +201,12 @@ def test_update_model(firebase_model):
192201
firebase_model.display_name = new_model_name
193202
updated_model = ml.update_model(firebase_model)
194203
check_model(updated_model, NAME_ONLY_ARGS_UPDATED)
195-
check_model_format(updated_model, False, None)
204+
check_model_format(updated_model)
196205

197206
# Second call with same model does not cause error
198207
updated_model2 = ml.update_model(updated_model)
199208
check_model(updated_model2, NAME_ONLY_ARGS_UPDATED)
200-
check_model_format(updated_model2, False, None)
209+
check_model_format(updated_model2)
201210

202211

203212
@pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True)
@@ -304,10 +313,13 @@ def keras_model():
304313
@pytest.fixture
305314
def saved_model_dir(keras_model):
306315
assert _TF_ENABLED
307-
# different versions have different model conversion capability
308-
# pick something that works for each version
316+
# Make a new parent directory. The child directory must not exist yet.
317+
# The child directory gets created by tf. If it exists, the tf call fails.
309318
parent = tempfile.mkdtemp()
310319
save_dir = os.path.join(parent, 'child')
320+
321+
# different versions have different model conversion capability
322+
# pick something that works for each version
311323
if tf.version.VERSION.startswith('1.'):
312324
tf.reset_default_graph()
313325
x_var = tf.placeholder(tf.float32, (None, 3), name="x")
@@ -331,12 +343,12 @@ def test_from_keras_model(keras_model):
331343

332344
# Validate the conversion by creating a model
333345
model_format = ml.TFLiteFormat(model_source=source)
334-
model = ml.Model(display_name="KerasModel1", model_format=model_format)
346+
model = ml.Model(display_name=_random_identifier('KerasModel_'), model_format=model_format)
335347
created_model = ml.create_model(model)
336348

337349
try:
338-
assert created_model.model_id is not None
339-
assert created_model.validation_error is None
350+
check_model(created_model, {'display_name': model.display_name})
351+
check_model_format(created_model, True)
340352
finally:
341353
_clean_up_model(created_model)
342354

@@ -351,7 +363,7 @@ def test_from_saved_model(saved_model_dir):
351363

352364
# Validate the conversion by creating a model
353365
model_format = ml.TFLiteFormat(model_source=source)
354-
model = ml.Model(display_name="SavedModel1", model_format=model_format)
366+
model = ml.Model(display_name=_random_identifier('SavedModel_'), model_format=model_format)
355367
created_model = ml.create_model(model)
356368

357369
try:

0 commit comments

Comments
 (0)