24
24
from tests import testutils
25
25
26
26
BASE_URL = 'https://mlkit.googleapis.com/v1beta1/'
27
-
28
27
PROJECT_ID = 'myProject1'
29
28
PAGE_TOKEN = 'pageToken'
30
29
NEXT_PAGE_TOKEN = 'nextPageToken'
122
121
}
123
122
TFLITE_FORMAT_2 = mlkit .TFLiteFormat .from_dict (TFLITE_FORMAT_JSON_2 )
124
123
125
- CREATED_MODEL_JSON_1 = {
124
+ CREATED_UPDATED_MODEL_JSON_1 = {
126
125
'name' : MODEL_NAME_1 ,
127
126
'displayName' : DISPLAY_NAME_1 ,
128
127
'createTime' : CREATE_TIME_JSON ,
132
131
'modelHash' : MODEL_HASH ,
133
132
'tags' : TAGS ,
134
133
}
135
- CREATED_MODEL_1 = mlkit .Model .from_dict (CREATED_MODEL_JSON_1 )
134
+ CREATED_UPDATED_MODEL_1 = mlkit .Model .from_dict (CREATED_UPDATED_MODEL_JSON_1 )
136
135
137
136
LOCKED_MODEL_JSON_1 = {
138
137
'name' : MODEL_NAME_1 ,
155
154
OPERATION_DONE_MODEL_JSON_1 = {
156
155
'name' : OPERATION_NAME_1 ,
157
156
'done' : True ,
158
- 'response' : CREATED_MODEL_JSON_1
157
+ 'response' : CREATED_UPDATED_MODEL_JSON_1
159
158
}
160
-
161
159
OPERATION_MALFORMED_JSON_1 = {
162
160
'name' : OPERATION_NAME_1 ,
163
161
'done' : True ,
164
162
# if done is true then either response or error should be populated
165
163
}
166
-
167
164
OPERATION_MISSING_NAME = {
168
165
'done' : False
169
166
}
170
-
171
167
OPERATION_ERROR_CODE = 400
172
168
OPERATION_ERROR_MSG = "Invalid argument"
173
169
OPERATION_ERROR_EXPECTED_STATUS = 'INVALID_ARGUMENT'
254
250
}
255
251
ERROR_RESPONSE_BAD_REQUEST = json .dumps (ERROR_JSON_BAD_REQUEST )
256
252
257
- invalid_model_id_args = [
253
+ INVALID_MODEL_ID_ARGS = [
258
254
('' , ValueError ),
259
255
('&_*#@:/?' , ValueError ),
260
256
(None , TypeError ),
261
257
(12345 , TypeError ),
262
258
]
259
+ INVALID_MODEL_ARGS = [
260
+ 'abc' ,
261
+ 4.2 ,
262
+ list (),
263
+ dict (),
264
+ True ,
265
+ - 1 ,
266
+ 0 ,
267
+ None
268
+ ]
269
+ INVALID_OP_NAME_ARGS = [
270
+ 'abc' ,
271
+ '123' ,
272
+ 'projects/operations/project/1234/model/abc/operation/123' ,
273
+ 'operations/project/model/abc/operation/123' ,
274
+ 'operations/project/123/model/$#@/operation/123' ,
275
+ 'operations/project/1234/model/abc/operation/123/extrathing' ,
276
+ ]
263
277
PAGE_SIZE_VALUE_ERROR_MSG = 'Page size must be a positive integer between ' \
264
278
'1 and {0}' .format (mlkit ._MAX_PAGE_SIZE )
265
- invalid_string_or_none_args = [0 , - 1 , 4.2 , 0x10 , False , list (), dict ()]
279
+ INVALID_STRING_OR_NONE_ARGS = [0 , - 1 , 4.2 , 0x10 , False , list (), dict ()]
266
280
267
281
268
282
# For validation type errors
@@ -524,7 +538,7 @@ def _get_url(project_id, model_id):
524
538
def test_immediate_done (self ):
525
539
instrument_mlkit_service (status = 200 , payload = OPERATION_DONE_RESPONSE )
526
540
model = mlkit .create_model (MODEL_1 )
527
- assert model == CREATED_MODEL_1
541
+ assert model == CREATED_UPDATED_MODEL_1
528
542
529
543
def test_returns_locked (self ):
530
544
recorder = instrument_mlkit_service (
@@ -573,16 +587,7 @@ def test_rpc_error_create(self):
573
587
)
574
588
assert len (create_recorder ) == 1
575
589
576
- @pytest .mark .parametrize ('model' , [
577
- 'abc' ,
578
- 4.2 ,
579
- list (),
580
- dict (),
581
- True ,
582
- - 1 ,
583
- 0 ,
584
- None
585
- ])
590
+ @pytest .mark .parametrize ('model' , INVALID_MODEL_ARGS )
586
591
def test_not_model (self , model ):
587
592
with pytest .raises (Exception ) as excinfo :
588
593
mlkit .create_model (model )
@@ -599,14 +604,7 @@ def test_missing_op_name(self):
599
604
mlkit .create_model (MODEL_1 )
600
605
check_error (excinfo , TypeError )
601
606
602
- @pytest .mark .parametrize ('op_name' , [
603
- 'abc' ,
604
- '123' ,
605
- 'projects/operations/project/1234/model/abc/operation/123' ,
606
- 'operations/project/model/abc/operation/123' ,
607
- 'operations/project/123/model/$#@/operation/123' ,
608
- 'operations/project/1234/model/abc/operation/123/extrathing' ,
609
- ])
607
+ @pytest .mark .parametrize ('op_name' , INVALID_OP_NAME_ARGS )
610
608
def test_invalid_op_name (self , op_name ):
611
609
payload = json .dumps ({'name' : op_name })
612
610
instrument_mlkit_service (status = 200 , payload = payload )
@@ -615,6 +613,105 @@ def test_invalid_op_name(self, op_name):
615
613
check_error (excinfo , ValueError , 'Operation name format is invalid.' )
616
614
617
615
616
+ class TestUpdateModel (object ):
617
+ """Tests mlkit.update_model."""
618
+ @classmethod
619
+ def setup_class (cls ):
620
+ cred = testutils .MockCredential ()
621
+ firebase_admin .initialize_app (cred , {'projectId' : PROJECT_ID })
622
+ mlkit ._MLKitService .POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test
623
+
624
+ @classmethod
625
+ def teardown_class (cls ):
626
+ testutils .cleanup_apps ()
627
+
628
+ @staticmethod
629
+ def _url (project_id , model_id ):
630
+ return BASE_URL + 'projects/{0}/models/{1}' .format (project_id , model_id )
631
+
632
+ @staticmethod
633
+ def _op_url (project_id , model_id ):
634
+ return BASE_URL + \
635
+ 'operations/project/{0}/model/{1}/operation/123' .format (project_id , model_id )
636
+
637
+ def test_immediate_done (self ):
638
+ instrument_mlkit_service (status = 200 , payload = OPERATION_DONE_RESPONSE )
639
+ model = mlkit .update_model (MODEL_1 )
640
+ assert model == CREATED_UPDATED_MODEL_1
641
+
642
+ def test_returns_locked (self ):
643
+ recorder = instrument_mlkit_service (
644
+ status = [200 , 200 ],
645
+ payload = [OPERATION_NOT_DONE_RESPONSE , LOCKED_MODEL_2_RESPONSE ])
646
+ expected_model = mlkit .Model .from_dict (LOCKED_MODEL_JSON_2 )
647
+ model = mlkit .update_model (MODEL_1 )
648
+
649
+ assert model == expected_model
650
+ assert len (recorder ) == 2
651
+ assert recorder [0 ].method == 'PATCH'
652
+ assert recorder [0 ].url == TestUpdateModel ._url (PROJECT_ID , MODEL_ID_1 )
653
+ assert recorder [1 ].method == 'GET'
654
+ assert recorder [1 ].url == TestUpdateModel ._url (PROJECT_ID , MODEL_ID_1 )
655
+
656
+ def test_operation_error (self ):
657
+ instrument_mlkit_service (status = 200 , payload = OPERATION_ERROR_RESPONSE )
658
+ with pytest .raises (Exception ) as excinfo :
659
+ mlkit .update_model (MODEL_1 )
660
+ # The http request succeeded, the operation returned contains a create failure
661
+ check_operation_error (excinfo , OPERATION_ERROR_EXPECTED_STATUS , OPERATION_ERROR_MSG )
662
+
663
+ def test_malformed_operation (self ):
664
+ recorder = instrument_mlkit_service (
665
+ status = [200 , 200 ],
666
+ payload = [OPERATION_MALFORMED_RESPONSE , LOCKED_MODEL_2_RESPONSE ])
667
+ expected_model = mlkit .Model .from_dict (LOCKED_MODEL_JSON_2 )
668
+ model = mlkit .update_model (MODEL_1 )
669
+ assert model == expected_model
670
+ assert len (recorder ) == 2
671
+ assert recorder [0 ].method == 'PATCH'
672
+ assert recorder [0 ].url == TestUpdateModel ._url (PROJECT_ID , MODEL_ID_1 )
673
+ assert recorder [1 ].method == 'GET'
674
+ assert recorder [1 ].url == TestUpdateModel ._url (PROJECT_ID , MODEL_ID_1 )
675
+
676
+ def test_rpc_error_create (self ):
677
+ create_recorder = instrument_mlkit_service (
678
+ status = 400 , payload = ERROR_RESPONSE_BAD_REQUEST )
679
+ with pytest .raises (Exception ) as excinfo :
680
+ mlkit .update_model (MODEL_1 )
681
+ check_firebase_error (
682
+ excinfo ,
683
+ ERROR_STATUS_BAD_REQUEST ,
684
+ ERROR_CODE_BAD_REQUEST ,
685
+ ERROR_MSG_BAD_REQUEST
686
+ )
687
+ assert len (create_recorder ) == 1
688
+
689
+ @pytest .mark .parametrize ('model' , INVALID_MODEL_ARGS )
690
+ def test_not_model (self , model ):
691
+ with pytest .raises (Exception ) as excinfo :
692
+ mlkit .update_model (model )
693
+ check_error (excinfo , TypeError , 'Model must be an mlkit.Model.' )
694
+
695
+ def test_missing_display_name (self ):
696
+ with pytest .raises (Exception ) as excinfo :
697
+ mlkit .update_model (mlkit .Model .from_dict ({}))
698
+ check_error (excinfo , ValueError , 'Model must have a display name.' )
699
+
700
+ def test_missing_op_name (self ):
701
+ instrument_mlkit_service (status = 200 , payload = OPERATION_MISSING_NAME_RESPONSE )
702
+ with pytest .raises (Exception ) as excinfo :
703
+ mlkit .update_model (MODEL_1 )
704
+ check_error (excinfo , TypeError )
705
+
706
+ @pytest .mark .parametrize ('op_name' , INVALID_OP_NAME_ARGS )
707
+ def test_invalid_op_name (self , op_name ):
708
+ payload = json .dumps ({'name' : op_name })
709
+ instrument_mlkit_service (status = 200 , payload = payload )
710
+ with pytest .raises (Exception ) as excinfo :
711
+ mlkit .update_model (MODEL_1 )
712
+ check_error (excinfo , ValueError , 'Operation name format is invalid.' )
713
+
714
+
618
715
class TestGetModel (object ):
619
716
"""Tests mlkit.get_model."""
620
717
@classmethod
@@ -640,7 +737,7 @@ def test_get_model(self):
640
737
assert model .model_id == MODEL_ID_1
641
738
assert model .display_name == DISPLAY_NAME_1
642
739
643
- @pytest .mark .parametrize ('model_id, exc_type' , invalid_model_id_args )
740
+ @pytest .mark .parametrize ('model_id, exc_type' , INVALID_MODEL_ID_ARGS )
644
741
def test_get_model_validation_errors (self , model_id , exc_type ):
645
742
with pytest .raises (exc_type ) as excinfo :
646
743
mlkit .get_model (model_id )
@@ -690,7 +787,7 @@ def test_delete_model(self):
690
787
assert recorder [0 ].method == 'DELETE'
691
788
assert recorder [0 ].url == TestDeleteModel ._url (PROJECT_ID , MODEL_ID_1 )
692
789
693
- @pytest .mark .parametrize ('model_id, exc_type' , invalid_model_id_args )
790
+ @pytest .mark .parametrize ('model_id, exc_type' , INVALID_MODEL_ID_ARGS )
694
791
def test_delete_model_validation_errors (self , model_id , exc_type ):
695
792
with pytest .raises (exc_type ) as excinfo :
696
793
mlkit .delete_model (model_id )
@@ -771,7 +868,7 @@ def test_list_models_with_all_args(self):
771
868
assert models_page .models [0 ] == MODEL_3
772
869
assert not models_page .has_next_page
773
870
774
- @pytest .mark .parametrize ('list_filter' , invalid_string_or_none_args )
871
+ @pytest .mark .parametrize ('list_filter' , INVALID_STRING_OR_NONE_ARGS )
775
872
def test_list_models_list_filter_validation (self , list_filter ):
776
873
with pytest .raises (TypeError ) as excinfo :
777
874
mlkit .list_models (list_filter = list_filter )
@@ -792,7 +889,7 @@ def test_list_models_page_size_validation(self, page_size, exc_type, error_messa
792
889
mlkit .list_models (page_size = page_size )
793
890
check_error (excinfo , exc_type , error_message )
794
891
795
- @pytest .mark .parametrize ('page_token' , invalid_string_or_none_args )
892
+ @pytest .mark .parametrize ('page_token' , INVALID_STRING_OR_NONE_ARGS )
796
893
def test_list_models_page_token_validation (self , page_token ):
797
894
with pytest .raises (TypeError ) as excinfo :
798
895
mlkit .list_models (page_token = page_token )
0 commit comments