16
16
import re
17
17
import os
18
18
import shutil
19
- import unittest
19
+ import random
20
+ import tempfile
20
21
import pytest
21
22
22
23
34
35
35
36
36
37
NAME_ONLY_ARGS = {
37
- 'display_name' : 'TestModel123'
38
+ 'display_name' : 'TestModel123_{0}' .format (random .randint (1111 , 9999 ))
39
+ }
40
+ NAME_ONLY_ARGS_UPDATED = {
41
+ 'display_name' : 'TestModel123_updated_{0}' .format (random .randint (1111 , 9999 ))
38
42
}
39
43
NAME_AND_TAGS_ARGS = {
40
- 'display_name' : 'TestModel123_tags' ,
44
+ 'display_name' : 'TestModel123_tags_{0}' . format ( random . randint ( 1111 , 9999 )) ,
41
45
'tags' : ['test_tag123' ]
42
- }
46
+ }
43
47
FULL_MODEL_ARGS = {
44
- 'display_name' : 'TestModel123_full' ,
48
+ 'display_name' : 'TestModel123_full_{0}' . format ( random . randint ( 1111 , 9999 )) ,
45
49
'tags' : ['test_tag567' ],
46
50
'file_name' : 'model1.tflite'
47
- }
51
+ }
48
52
INVALID_FULL_MODEL_ARGS = {
49
- 'display_name' : 'TestModel123_invalid_full' ,
53
+ 'display_name' : 'TestModel123_invalid_full_{0}' . format ( random . randint ( 1111 , 9999 )) ,
50
54
'tags' : ['test_tag890' ],
51
55
'file_name' : 'invalid_model.tflite'
52
- }
56
+ }
53
57
54
58
@pytest .fixture
55
59
def firebase_model (request ):
56
60
args = request .param
57
61
tflite_format = None
58
- if args .get ('file_name' ):
59
- file_path = testutils .resource_filename (args .get ('file_name' ))
62
+ file_name = args .get ('file_name' )
63
+ if file_name :
64
+ file_path = testutils .resource_filename (file_name )
60
65
source = ml .TFLiteGCSModelSource .from_tflite_model_file (file_path )
61
66
tflite_format = ml .TFLiteFormat (model_source = source )
62
67
@@ -109,35 +114,44 @@ def check_operation_error(excinfo, msg):
109
114
assert str (err ) == msg
110
115
111
116
117
+ def check_model (model , args ):
118
+ assert model .display_name == args .get ('display_name' )
119
+ assert model .tags == args .get ('tags' )
120
+ assert model .model_id is not None
121
+ assert model .create_time is not None
122
+ assert model .update_time is not None
123
+ assert model .locked is False
124
+ assert model .etag is not None
125
+
126
+
127
+ def check_model_format (model , has_model_format , validation_error ):
128
+ if has_model_format :
129
+ assert model .validation_error == validation_error
130
+ assert model .published is False
131
+ assert model .model_format .model_source .gcs_tflite_uri .startswith ('gs://' )
132
+ if validation_error :
133
+ assert model .model_format .size_bytes is None
134
+ assert model .model_hash is None
135
+ else :
136
+ assert model .model_format .size_bytes is not None
137
+ assert model .model_hash is not None
138
+ else :
139
+ assert model .model_format is None
140
+ assert model .validation_error == 'No model file has been uploaded.'
141
+ assert model .published is False
142
+ assert model .model_hash is None
143
+
144
+
112
145
@pytest .mark .parametrize ('firebase_model' , [NAME_AND_TAGS_ARGS ], indirect = True )
113
146
def test_create_simple_model (firebase_model ):
114
- assert firebase_model .display_name == NAME_AND_TAGS_ARGS .get ('display_name' )
115
- assert firebase_model .tags == NAME_AND_TAGS_ARGS .get ('tags' )
116
- assert firebase_model .model_id is not None
117
- assert firebase_model .create_time is not None
118
- assert firebase_model .update_time is not None
119
- assert firebase_model .validation_error == 'No model file has been uploaded.'
120
- assert firebase_model .locked is False
121
- assert firebase_model .published is False
122
- assert firebase_model .etag is not None
123
- assert firebase_model .model_hash is None
124
- assert firebase_model .model_format is None
147
+ check_model (firebase_model , NAME_AND_TAGS_ARGS )
148
+ check_model_format (firebase_model , False , None )
125
149
126
150
127
151
@pytest .mark .parametrize ('firebase_model' , [FULL_MODEL_ARGS ], indirect = True )
128
152
def test_create_full_model (firebase_model ):
129
- assert firebase_model .display_name == FULL_MODEL_ARGS .get ('display_name' )
130
- assert firebase_model .tags == FULL_MODEL_ARGS .get ('tags' )
131
- assert firebase_model .model_format .size_bytes is not None
132
- assert firebase_model .model_format .model_source .gcs_tflite_uri is not None
133
- assert firebase_model .model_id is not None
134
- assert firebase_model .create_time is not None
135
- assert firebase_model .update_time is not None
136
- assert firebase_model .validation_error is None
137
- assert firebase_model .locked is False
138
- assert firebase_model .published is False
139
- assert firebase_model .etag is not None
140
- assert firebase_model .model_hash is not None
153
+ check_model (firebase_model , FULL_MODEL_ARGS )
154
+ check_model_format (firebase_model , True , None )
141
155
142
156
143
157
@pytest .mark .parametrize ('firebase_model' , [FULL_MODEL_ARGS ], indirect = True )
@@ -151,33 +165,15 @@ def test_create_already_existing_fails(firebase_model):
151
165
152
166
@pytest .mark .parametrize ('firebase_model' , [INVALID_FULL_MODEL_ARGS ], indirect = True )
153
167
def test_create_invalid_model (firebase_model ):
154
- assert firebase_model .display_name == INVALID_FULL_MODEL_ARGS .get ('display_name' )
155
- assert firebase_model .tags == INVALID_FULL_MODEL_ARGS .get ('tags' )
156
- assert firebase_model .model_format .size_bytes is None
157
- assert firebase_model .model_format .model_source .gcs_tflite_uri is not None
158
- assert firebase_model .model_id is not None
159
- assert firebase_model .create_time is not None
160
- assert firebase_model .update_time is not None
161
- assert firebase_model .validation_error == 'Invalid flatbuffer format'
162
- assert firebase_model .locked is False
163
- assert firebase_model .published is False
164
- assert firebase_model .etag is not None
165
- assert firebase_model .model_hash is None
168
+ check_model (firebase_model , INVALID_FULL_MODEL_ARGS )
169
+ check_model_format (firebase_model , True , 'Invalid flatbuffer format' )
166
170
167
171
168
172
@pytest .mark .parametrize ('firebase_model' , [NAME_AND_TAGS_ARGS ], indirect = True )
169
173
def test_get_model (firebase_model ):
170
174
get_model = ml .get_model (firebase_model .model_id )
171
- assert get_model .display_name == firebase_model .display_name
172
- assert get_model .tags == firebase_model .tags
173
- assert get_model .model_id is not None
174
- assert get_model .create_time is not None
175
- assert get_model .update_time is not None
176
- assert get_model .validation_error == 'No model file has been uploaded.'
177
- assert get_model .etag is not None
178
- assert get_model .locked is False
179
- assert get_model .published is False
180
- assert get_model .model_hash is None
175
+ check_model (get_model , NAME_AND_TAGS_ARGS )
176
+ check_model_format (get_model , False , None )
181
177
182
178
183
179
@pytest .mark .parametrize ('firebase_model' , [NAME_ONLY_ARGS ], indirect = True )
@@ -192,29 +188,16 @@ def test_get_non_existing_model(firebase_model):
192
188
193
189
@pytest .mark .parametrize ('firebase_model' , [NAME_ONLY_ARGS ], indirect = True )
194
190
def test_update_model (firebase_model ):
195
- new_model_name = 'TestModel123_updated'
191
+ new_model_name = NAME_ONLY_ARGS_UPDATED . get ( 'display_name' )
196
192
firebase_model .display_name = new_model_name
197
-
198
193
updated_model = ml .update_model (firebase_model )
199
- assert updated_model .display_name == new_model_name
200
- assert updated_model .model_id == firebase_model .model_id
201
- assert updated_model .create_time == firebase_model .create_time
202
- assert updated_model .update_time != firebase_model .update_time
203
- assert updated_model .validation_error == firebase_model .validation_error
204
- assert updated_model .etag != firebase_model .etag
205
- assert updated_model .published == firebase_model .published
206
- assert updated_model .locked == firebase_model .locked
194
+ check_model (updated_model , NAME_ONLY_ARGS_UPDATED )
195
+ check_model_format (updated_model , False , None )
207
196
208
197
# Second call with same model does not cause error
209
198
updated_model2 = ml .update_model (updated_model )
210
- assert updated_model2 .display_name == updated_model .display_name
211
- assert updated_model2 .model_id == updated_model .model_id
212
- assert updated_model2 .create_time == updated_model .create_time
213
- assert updated_model2 .update_time != updated_model .update_time
214
- assert updated_model2 .validation_error == updated_model .validation_error
215
- assert updated_model2 .etag != updated_model .etag
216
- assert updated_model2 .published == updated_model .published
217
- assert updated_model2 .locked == updated_model .locked
199
+ check_model (updated_model2 , NAME_ONLY_ARGS_UPDATED )
200
+ check_model_format (updated_model2 , False , None )
218
201
219
202
220
203
@pytest .mark .parametrize ('firebase_model' , [NAME_ONLY_ARGS ], indirect = True )
@@ -272,11 +255,10 @@ def test_list_models(model_list):
272
255
filter_str = 'displayName={0} OR tags:{1}' .format (
273
256
model_list [0 ].display_name , model_list [1 ].tags [0 ])
274
257
275
- models_list = ml .list_models (list_filter = filter_str )
276
- assert len (models_list .models ) == 2
277
- for mdl in models_list .models :
278
- assert mdl == model_list [0 ] or mdl == model_list [1 ]
279
- assert models_list .models [0 ] != models_list .models [1 ]
258
+ all_models = ml .list_models (list_filter = filter_str )
259
+ all_model_ids = [mdl .model_id for mdl in all_models .iterate_all ()]
260
+ for mdl in model_list :
261
+ assert mdl .model_id in all_model_ids
280
262
281
263
282
264
def test_list_models_invalid_filter ():
@@ -302,12 +284,9 @@ def test_delete_model(firebase_model):
302
284
#'pip install tensorflow==2.0.0b' for version 2 etc.
303
285
304
286
305
- SAVED_MODEL_DIR = '/tmp/saved_model/1'
306
-
307
-
308
- def _clean_up_tmp_directory ():
309
- if os .path .exists (SAVED_MODEL_DIR ):
310
- shutil .rmtree (SAVED_MODEL_DIR )
287
+ def _clean_up_directory (save_dir ):
288
+ if save_dir .startswith (tempfile .gettempdir ()) and os .path .exists (save_dir ):
289
+ shutil .rmtree (save_dir )
311
290
312
291
313
292
@pytest .fixture
@@ -327,8 +306,8 @@ def saved_model_dir(keras_model):
327
306
assert _TF_ENABLED
328
307
# different versions have different model conversion capability
329
308
# pick something that works for each version
330
- save_dir = SAVED_MODEL_DIR
331
- _clean_up_tmp_directory () # previous failures may leave files
309
+ parent = tempfile . mkdtemp ()
310
+ save_dir = os . path . join ( parent , 'child' )
332
311
if tf .version .VERSION .startswith ('1.' ):
333
312
tf .reset_default_graph ()
334
313
x_var = tf .placeholder (tf .float32 , (None , 3 ), name = "x" )
@@ -340,28 +319,29 @@ def saved_model_dir(keras_model):
340
319
assert tf .version .VERSION .startswith ('2.' )
341
320
tf .saved_model .save (keras_model , save_dir )
342
321
yield save_dir
343
- _clean_up_tmp_directory ( )
322
+ _clean_up_directory ( parent )
344
323
345
324
346
- @unittest . skipUnless ( _TF_ENABLED , 'Tensor flow is required for this test.' )
325
+ @pytest . mark . skipif ( not _TF_ENABLED , reason = 'Tensor flow is required for this test.' )
347
326
def test_from_keras_model (keras_model ):
348
327
source = ml .TFLiteGCSModelSource .from_keras_model (keras_model , 'model2.tflite' )
349
328
assert re .search (
350
329
'^gs://.*/Firebase/ML/Models/model2.tflite$' ,
351
330
source .gcs_tflite_uri ) is not None
352
331
353
332
# Validate the conversion by creating a model
333
+ model_format = ml .TFLiteFormat (model_source = source )
334
+ model = ml .Model (display_name = "KerasModel1" , model_format = model_format )
335
+ created_model = ml .create_model (model )
336
+
354
337
try :
355
- model_format = ml .TFLiteFormat (model_source = source )
356
- model = ml .Model (display_name = "KerasModel1" , model_format = model_format )
357
- created_model = ml .create_model (model )
358
338
assert created_model .model_id is not None
359
339
assert created_model .validation_error is None
360
340
finally :
361
341
_clean_up_model (created_model )
362
342
363
343
364
- @unittest . skipUnless ( _TF_ENABLED , 'Tensor flow is required for this test.' )
344
+ @pytest . mark . skipif ( not _TF_ENABLED , reason = 'Tensor flow is required for this test.' )
365
345
def test_from_saved_model (saved_model_dir ):
366
346
# Test the conversion helper
367
347
source = ml .TFLiteGCSModelSource .from_saved_model (saved_model_dir , 'model3.tflite' )
@@ -370,10 +350,11 @@ def test_from_saved_model(saved_model_dir):
370
350
source .gcs_tflite_uri ) is not None
371
351
372
352
# Validate the conversion by creating a model
353
+ model_format = ml .TFLiteFormat (model_source = source )
354
+ model = ml .Model (display_name = "SavedModel1" , model_format = model_format )
355
+ created_model = ml .create_model (model )
356
+
373
357
try :
374
- model_format = ml .TFLiteFormat (model_source = source )
375
- model = ml .Model (display_name = "SavedModel1" , model_format = model_format )
376
- created_model = ml .create_model (model )
377
358
assert created_model .model_id is not None
378
359
assert created_model .validation_error is None
379
360
finally :
0 commit comments