13
13
# limitations under the License.
14
14
15
15
"""Integration tests for firebase_admin.ml module."""
16
- import re
17
16
import os
18
- import shutil
19
17
import random
18
+ import re
19
+ import shutil
20
+ import string
20
21
import tempfile
21
22
import pytest
22
23
23
24
24
- from firebase_admin import ml
25
25
from firebase_admin import exceptions
26
+ from firebase_admin import ml
26
27
from tests import testutils
27
28
28
29
34
35
_TF_ENABLED = False
35
36
36
37
38
+ def _random_identifier (prefix ):
39
+ #pylint: disable=unused-variable
40
+ suffix = '' .join ([random .choice (string .ascii_letters + string .digits ) for n in range (8 )])
41
+ return '{0}_{1}' .format (prefix , suffix )
42
+
43
+
37
44
NAME_ONLY_ARGS = {
38
- 'display_name' : 'TestModel123_{0}' . format ( random . randint ( 1111 , 9999 ) )
45
+ 'display_name' : _random_identifier ( 'TestModel123_' )
39
46
}
40
47
NAME_ONLY_ARGS_UPDATED = {
41
- 'display_name' : 'TestModel123_updated_{0}' . format ( random . randint ( 1111 , 9999 ) )
48
+ 'display_name' : _random_identifier ( 'TestModel123_updated_' )
42
49
}
43
50
NAME_AND_TAGS_ARGS = {
44
- 'display_name' : 'TestModel123_tags_{0}' . format ( random . randint ( 1111 , 9999 ) ),
51
+ 'display_name' : _random_identifier ( 'TestModel123_tags_' ),
45
52
'tags' : ['test_tag123' ]
46
53
}
47
54
FULL_MODEL_ARGS = {
48
- 'display_name' : 'TestModel123_full_{0}' . format ( random . randint ( 1111 , 9999 ) ),
55
+ 'display_name' : _random_identifier ( 'TestModel123_full_' ),
49
56
'tags' : ['test_tag567' ],
50
57
'file_name' : 'model1.tflite'
51
58
}
52
59
INVALID_FULL_MODEL_ARGS = {
53
- 'display_name' : 'TestModel123_invalid_full_{0}' . format ( random . randint ( 1111 , 9999 ) ),
60
+ 'display_name' : _random_identifier ( 'TestModel123_invalid_full_' ),
54
61
'tags' : ['test_tag890' ],
55
62
'file_name' : 'invalid_model.tflite'
56
63
}
57
64
65
+
58
66
@pytest .fixture
59
67
def firebase_model (request ):
60
68
args = request .param
@@ -76,10 +84,11 @@ def firebase_model(request):
76
84
77
85
@pytest .fixture
78
86
def model_list ():
79
- ml_model_1 = ml .Model (display_name = "TestModel123" )
87
+ ml_model_1 = ml .Model (display_name = _random_identifier ( 'TestModel123_list1_' ) )
80
88
model_1 = ml .create_model (model = ml_model_1 )
81
89
82
- ml_model_2 = ml .Model (display_name = "TestModel123_tags" , tags = ['test_tag123' ])
90
+ ml_model_2 = ml .Model (display_name = _random_identifier ('TestModel123_list2_' ),
91
+ tags = ['test_tag123' ])
83
92
model_2 = ml .create_model (model = ml_model_2 )
84
93
85
94
yield [model_1 , model_2 ]
@@ -124,7 +133,7 @@ def check_model(model, args):
124
133
assert model .etag is not None
125
134
126
135
127
- def check_model_format (model , has_model_format , validation_error ):
136
+ def check_model_format (model , has_model_format = False , validation_error = None ):
128
137
if has_model_format :
129
138
assert model .validation_error == validation_error
130
139
assert model .published is False
@@ -145,13 +154,13 @@ def check_model_format(model, has_model_format, validation_error):
145
154
@pytest .mark .parametrize ('firebase_model' , [NAME_AND_TAGS_ARGS ], indirect = True )
146
155
def test_create_simple_model (firebase_model ):
147
156
check_model (firebase_model , NAME_AND_TAGS_ARGS )
148
- check_model_format (firebase_model , False , None )
157
+ check_model_format (firebase_model )
149
158
150
159
151
160
@pytest .mark .parametrize ('firebase_model' , [FULL_MODEL_ARGS ], indirect = True )
152
161
def test_create_full_model (firebase_model ):
153
162
check_model (firebase_model , FULL_MODEL_ARGS )
154
- check_model_format (firebase_model , True , None )
163
+ check_model_format (firebase_model , True )
155
164
156
165
157
166
@pytest .mark .parametrize ('firebase_model' , [FULL_MODEL_ARGS ], indirect = True )
@@ -173,7 +182,7 @@ def test_create_invalid_model(firebase_model):
173
182
def test_get_model (firebase_model ):
174
183
get_model = ml .get_model (firebase_model .model_id )
175
184
check_model (get_model , NAME_AND_TAGS_ARGS )
176
- check_model_format (get_model , False , None )
185
+ check_model_format (get_model )
177
186
178
187
179
188
@pytest .mark .parametrize ('firebase_model' , [NAME_ONLY_ARGS ], indirect = True )
@@ -192,12 +201,12 @@ def test_update_model(firebase_model):
192
201
firebase_model .display_name = new_model_name
193
202
updated_model = ml .update_model (firebase_model )
194
203
check_model (updated_model , NAME_ONLY_ARGS_UPDATED )
195
- check_model_format (updated_model , False , None )
204
+ check_model_format (updated_model )
196
205
197
206
# Second call with same model does not cause error
198
207
updated_model2 = ml .update_model (updated_model )
199
208
check_model (updated_model2 , NAME_ONLY_ARGS_UPDATED )
200
- check_model_format (updated_model2 , False , None )
209
+ check_model_format (updated_model2 )
201
210
202
211
203
212
@pytest .mark .parametrize ('firebase_model' , [NAME_ONLY_ARGS ], indirect = True )
@@ -304,10 +313,13 @@ def keras_model():
304
313
@pytest .fixture
305
314
def saved_model_dir (keras_model ):
306
315
assert _TF_ENABLED
307
- # different versions have different model conversion capability
308
- # pick something that works for each version
316
+ # Make a new parent directory. The child directory must not exist yet.
317
+ # The child directory gets created by tf. If it exists, the tf call fails.
309
318
parent = tempfile .mkdtemp ()
310
319
save_dir = os .path .join (parent , 'child' )
320
+
321
+ # different versions have different model conversion capability
322
+ # pick something that works for each version
311
323
if tf .version .VERSION .startswith ('1.' ):
312
324
tf .reset_default_graph ()
313
325
x_var = tf .placeholder (tf .float32 , (None , 3 ), name = "x" )
@@ -331,12 +343,12 @@ def test_from_keras_model(keras_model):
331
343
332
344
# Validate the conversion by creating a model
333
345
model_format = ml .TFLiteFormat (model_source = source )
334
- model = ml .Model (display_name = "KerasModel1" , model_format = model_format )
346
+ model = ml .Model (display_name = _random_identifier ( 'KerasModel_' ) , model_format = model_format )
335
347
created_model = ml .create_model (model )
336
348
337
349
try :
338
- assert created_model . model_id is not None
339
- assert created_model . validation_error is None
350
+ check_model ( created_model , { 'display_name' : model . display_name })
351
+ check_model_format ( created_model , True )
340
352
finally :
341
353
_clean_up_model (created_model )
342
354
@@ -351,7 +363,7 @@ def test_from_saved_model(saved_model_dir):
351
363
352
364
# Validate the conversion by creating a model
353
365
model_format = ml .TFLiteFormat (model_source = source )
354
- model = ml .Model (display_name = "SavedModel1" , model_format = model_format )
366
+ model = ml .Model (display_name = _random_identifier ( 'SavedModel_' ) , model_format = model_format )
355
367
created_model = ml .create_model (model )
356
368
357
369
try :
0 commit comments