Skip to content

Commit 0bc734c

Browse files
committed
review suggestions #2
1 parent 99ccb6d commit 0bc734c

File tree

1 file changed

+76
-95
lines changed

1 file changed

+76
-95
lines changed

integration/test_ml.py

+76-95
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
import re
1717
import os
1818
import shutil
19-
import unittest
19+
import random
20+
import tempfile
2021
import pytest
2122

2223

@@ -34,29 +35,33 @@
3435

3536

3637
NAME_ONLY_ARGS = {
37-
'display_name': 'TestModel123'
38+
'display_name': 'TestModel123_{0}'.format(random.randint(1111, 9999))
39+
}
40+
NAME_ONLY_ARGS_UPDATED = {
41+
'display_name': 'TestModel123_updated_{0}'.format(random.randint(1111, 9999))
3842
}
3943
NAME_AND_TAGS_ARGS = {
40-
'display_name': 'TestModel123_tags',
44+
'display_name': 'TestModel123_tags_{0}'.format(random.randint(1111, 9999)),
4145
'tags': ['test_tag123']
42-
}
46+
}
4347
FULL_MODEL_ARGS = {
44-
'display_name': 'TestModel123_full',
48+
'display_name': 'TestModel123_full_{0}'.format(random.randint(1111, 9999)),
4549
'tags': ['test_tag567'],
4650
'file_name': 'model1.tflite'
47-
}
51+
}
4852
INVALID_FULL_MODEL_ARGS = {
49-
'display_name': 'TestModel123_invalid_full',
53+
'display_name': 'TestModel123_invalid_full_{0}'.format(random.randint(1111, 9999)),
5054
'tags': ['test_tag890'],
5155
'file_name': 'invalid_model.tflite'
52-
}
56+
}
5357

5458
@pytest.fixture
5559
def firebase_model(request):
5660
args = request.param
5761
tflite_format = None
58-
if args.get('file_name'):
59-
file_path = testutils.resource_filename(args.get('file_name'))
62+
file_name = args.get('file_name')
63+
if file_name:
64+
file_path = testutils.resource_filename(file_name)
6065
source = ml.TFLiteGCSModelSource.from_tflite_model_file(file_path)
6166
tflite_format = ml.TFLiteFormat(model_source=source)
6267

@@ -109,35 +114,44 @@ def check_operation_error(excinfo, msg):
109114
assert str(err) == msg
110115

111116

117+
def check_model(model, args):
118+
assert model.display_name == args.get('display_name')
119+
assert model.tags == args.get('tags')
120+
assert model.model_id is not None
121+
assert model.create_time is not None
122+
assert model.update_time is not None
123+
assert model.locked is False
124+
assert model.etag is not None
125+
126+
127+
def check_model_format(model, has_model_format, validation_error):
128+
if has_model_format:
129+
assert model.validation_error == validation_error
130+
assert model.published is False
131+
assert model.model_format.model_source.gcs_tflite_uri.startswith('gs://')
132+
if validation_error:
133+
assert model.model_format.size_bytes is None
134+
assert model.model_hash is None
135+
else:
136+
assert model.model_format.size_bytes is not None
137+
assert model.model_hash is not None
138+
else:
139+
assert model.model_format is None
140+
assert model.validation_error == 'No model file has been uploaded.'
141+
assert model.published is False
142+
assert model.model_hash is None
143+
144+
112145
@pytest.mark.parametrize('firebase_model', [NAME_AND_TAGS_ARGS], indirect=True)
113146
def test_create_simple_model(firebase_model):
114-
assert firebase_model.display_name == NAME_AND_TAGS_ARGS.get('display_name')
115-
assert firebase_model.tags == NAME_AND_TAGS_ARGS.get('tags')
116-
assert firebase_model.model_id is not None
117-
assert firebase_model.create_time is not None
118-
assert firebase_model.update_time is not None
119-
assert firebase_model.validation_error == 'No model file has been uploaded.'
120-
assert firebase_model.locked is False
121-
assert firebase_model.published is False
122-
assert firebase_model.etag is not None
123-
assert firebase_model.model_hash is None
124-
assert firebase_model.model_format is None
147+
check_model(firebase_model, NAME_AND_TAGS_ARGS)
148+
check_model_format(firebase_model, False, None)
125149

126150

127151
@pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True)
128152
def test_create_full_model(firebase_model):
129-
assert firebase_model.display_name == FULL_MODEL_ARGS.get('display_name')
130-
assert firebase_model.tags == FULL_MODEL_ARGS.get('tags')
131-
assert firebase_model.model_format.size_bytes is not None
132-
assert firebase_model.model_format.model_source.gcs_tflite_uri is not None
133-
assert firebase_model.model_id is not None
134-
assert firebase_model.create_time is not None
135-
assert firebase_model.update_time is not None
136-
assert firebase_model.validation_error is None
137-
assert firebase_model.locked is False
138-
assert firebase_model.published is False
139-
assert firebase_model.etag is not None
140-
assert firebase_model.model_hash is not None
153+
check_model(firebase_model, FULL_MODEL_ARGS)
154+
check_model_format(firebase_model, True, None)
141155

142156

143157
@pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True)
@@ -151,33 +165,15 @@ def test_create_already_existing_fails(firebase_model):
151165

152166
@pytest.mark.parametrize('firebase_model', [INVALID_FULL_MODEL_ARGS], indirect=True)
153167
def test_create_invalid_model(firebase_model):
154-
assert firebase_model.display_name == INVALID_FULL_MODEL_ARGS.get('display_name')
155-
assert firebase_model.tags == INVALID_FULL_MODEL_ARGS.get('tags')
156-
assert firebase_model.model_format.size_bytes is None
157-
assert firebase_model.model_format.model_source.gcs_tflite_uri is not None
158-
assert firebase_model.model_id is not None
159-
assert firebase_model.create_time is not None
160-
assert firebase_model.update_time is not None
161-
assert firebase_model.validation_error == 'Invalid flatbuffer format'
162-
assert firebase_model.locked is False
163-
assert firebase_model.published is False
164-
assert firebase_model.etag is not None
165-
assert firebase_model.model_hash is None
168+
check_model(firebase_model, INVALID_FULL_MODEL_ARGS)
169+
check_model_format(firebase_model, True, 'Invalid flatbuffer format')
166170

167171

168172
@pytest.mark.parametrize('firebase_model', [NAME_AND_TAGS_ARGS], indirect=True)
169173
def test_get_model(firebase_model):
170174
get_model = ml.get_model(firebase_model.model_id)
171-
assert get_model.display_name == firebase_model.display_name
172-
assert get_model.tags == firebase_model.tags
173-
assert get_model.model_id is not None
174-
assert get_model.create_time is not None
175-
assert get_model.update_time is not None
176-
assert get_model.validation_error == 'No model file has been uploaded.'
177-
assert get_model.etag is not None
178-
assert get_model.locked is False
179-
assert get_model.published is False
180-
assert get_model.model_hash is None
175+
check_model(get_model, NAME_AND_TAGS_ARGS)
176+
check_model_format(get_model, False, None)
181177

182178

183179
@pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True)
@@ -192,29 +188,16 @@ def test_get_non_existing_model(firebase_model):
192188

193189
@pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True)
194190
def test_update_model(firebase_model):
195-
new_model_name = 'TestModel123_updated'
191+
new_model_name = NAME_ONLY_ARGS_UPDATED.get('display_name')
196192
firebase_model.display_name = new_model_name
197-
198193
updated_model = ml.update_model(firebase_model)
199-
assert updated_model.display_name == new_model_name
200-
assert updated_model.model_id == firebase_model.model_id
201-
assert updated_model.create_time == firebase_model.create_time
202-
assert updated_model.update_time != firebase_model.update_time
203-
assert updated_model.validation_error == firebase_model.validation_error
204-
assert updated_model.etag != firebase_model.etag
205-
assert updated_model.published == firebase_model.published
206-
assert updated_model.locked == firebase_model.locked
194+
check_model(updated_model, NAME_ONLY_ARGS_UPDATED)
195+
check_model_format(updated_model, False, None)
207196

208197
# Second call with same model does not cause error
209198
updated_model2 = ml.update_model(updated_model)
210-
assert updated_model2.display_name == updated_model.display_name
211-
assert updated_model2.model_id == updated_model.model_id
212-
assert updated_model2.create_time == updated_model.create_time
213-
assert updated_model2.update_time != updated_model.update_time
214-
assert updated_model2.validation_error == updated_model.validation_error
215-
assert updated_model2.etag != updated_model.etag
216-
assert updated_model2.published == updated_model.published
217-
assert updated_model2.locked == updated_model.locked
199+
check_model(updated_model2, NAME_ONLY_ARGS_UPDATED)
200+
check_model_format(updated_model2, False, None)
218201

219202

220203
@pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True)
@@ -272,11 +255,10 @@ def test_list_models(model_list):
272255
filter_str = 'displayName={0} OR tags:{1}'.format(
273256
model_list[0].display_name, model_list[1].tags[0])
274257

275-
models_list = ml.list_models(list_filter=filter_str)
276-
assert len(models_list.models) == 2
277-
for mdl in models_list.models:
278-
assert mdl == model_list[0] or mdl == model_list[1]
279-
assert models_list.models[0] != models_list.models[1]
258+
all_models = ml.list_models(list_filter=filter_str)
259+
all_model_ids = [mdl.model_id for mdl in all_models.iterate_all()]
260+
for mdl in model_list:
261+
assert mdl.model_id in all_model_ids
280262

281263

282264
def test_list_models_invalid_filter():
@@ -302,12 +284,9 @@ def test_delete_model(firebase_model):
302284
#'pip install tensorflow==2.0.0b' for version 2 etc.
303285

304286

305-
SAVED_MODEL_DIR = '/tmp/saved_model/1'
306-
307-
308-
def _clean_up_tmp_directory():
309-
if os.path.exists(SAVED_MODEL_DIR):
310-
shutil.rmtree(SAVED_MODEL_DIR)
287+
def _clean_up_directory(save_dir):
288+
if save_dir.startswith(tempfile.gettempdir()) and os.path.exists(save_dir):
289+
shutil.rmtree(save_dir)
311290

312291

313292
@pytest.fixture
@@ -327,8 +306,8 @@ def saved_model_dir(keras_model):
327306
assert _TF_ENABLED
328307
# different versions have different model conversion capability
329308
# pick something that works for each version
330-
save_dir = SAVED_MODEL_DIR
331-
_clean_up_tmp_directory() # previous failures may leave files
309+
parent = tempfile.mkdtemp()
310+
save_dir = os.path.join(parent, 'child')
332311
if tf.version.VERSION.startswith('1.'):
333312
tf.reset_default_graph()
334313
x_var = tf.placeholder(tf.float32, (None, 3), name="x")
@@ -340,28 +319,29 @@ def saved_model_dir(keras_model):
340319
assert tf.version.VERSION.startswith('2.')
341320
tf.saved_model.save(keras_model, save_dir)
342321
yield save_dir
343-
_clean_up_tmp_directory()
322+
_clean_up_directory(parent)
344323

345324

346-
@unittest.skipUnless(_TF_ENABLED, 'Tensor flow is required for this test.')
325+
@pytest.mark.skipif(not _TF_ENABLED, reason='Tensor flow is required for this test.')
347326
def test_from_keras_model(keras_model):
348327
source = ml.TFLiteGCSModelSource.from_keras_model(keras_model, 'model2.tflite')
349328
assert re.search(
350329
'^gs://.*/Firebase/ML/Models/model2.tflite$',
351330
source.gcs_tflite_uri) is not None
352331

353332
# Validate the conversion by creating a model
333+
model_format = ml.TFLiteFormat(model_source=source)
334+
model = ml.Model(display_name="KerasModel1", model_format=model_format)
335+
created_model = ml.create_model(model)
336+
354337
try:
355-
model_format = ml.TFLiteFormat(model_source=source)
356-
model = ml.Model(display_name="KerasModel1", model_format=model_format)
357-
created_model = ml.create_model(model)
358338
assert created_model.model_id is not None
359339
assert created_model.validation_error is None
360340
finally:
361341
_clean_up_model(created_model)
362342

363343

364-
@unittest.skipUnless(_TF_ENABLED, 'Tensor flow is required for this test.')
344+
@pytest.mark.skipif(not _TF_ENABLED, reason='Tensor flow is required for this test.')
365345
def test_from_saved_model(saved_model_dir):
366346
# Test the conversion helper
367347
source = ml.TFLiteGCSModelSource.from_saved_model(saved_model_dir, 'model3.tflite')
@@ -370,10 +350,11 @@ def test_from_saved_model(saved_model_dir):
370350
source.gcs_tflite_uri) is not None
371351

372352
# Validate the conversion by creating a model
353+
model_format = ml.TFLiteFormat(model_source=source)
354+
model = ml.Model(display_name="SavedModel1", model_format=model_format)
355+
created_model = ml.create_model(model)
356+
373357
try:
374-
model_format = ml.TFLiteFormat(model_source=source)
375-
model = ml.Model(display_name="SavedModel1", model_format=model_format)
376-
created_model = ml.create_model(model)
377358
assert created_model.model_id is not None
378359
assert created_model.validation_error is None
379360
finally:

0 commit comments

Comments
 (0)