@@ -657,7 +657,7 @@ def test_operation_error(self):
657
657
instrument_mlkit_service (status = 200 , payload = OPERATION_ERROR_RESPONSE )
658
658
with pytest .raises (Exception ) as excinfo :
659
659
mlkit .update_model (MODEL_1 )
660
- # The http request succeeded, the operation returned contains a create failure
660
+ # The http request succeeded, the operation returned contains an update failure
661
661
check_operation_error (excinfo , OPERATION_ERROR_EXPECTED_STATUS , OPERATION_ERROR_MSG )
662
662
663
663
def test_malformed_operation (self ):
@@ -673,7 +673,7 @@ def test_malformed_operation(self):
673
673
assert recorder [1 ].method == 'GET'
674
674
assert recorder [1 ].url == TestUpdateModel ._url (PROJECT_ID , MODEL_ID_1 )
675
675
676
- def test_rpc_error_create (self ):
676
+ def test_rpc_error (self ):
677
677
create_recorder = instrument_mlkit_service (
678
678
status = 400 , payload = ERROR_RESPONSE_BAD_REQUEST )
679
679
with pytest .raises (Exception ) as excinfo :
@@ -712,6 +712,97 @@ def test_invalid_op_name(self, op_name):
712
712
check_error (excinfo , ValueError , 'Operation name format is invalid.' )
713
713
714
714
715
+ class TestPublishUnpublish (object ):
716
+ """Tests mlkit.publish_model and mlkit.unpublish_model."""
717
+
718
+ PUBLISH_UNPUBLISH_WITH_ARGS = [
719
+ (mlkit .publish_model , True ),
720
+ (mlkit .unpublish_model , False )
721
+ ]
722
+ PUBLISH_UNPUBLISH_FUNCS = [item [0 ] for item in PUBLISH_UNPUBLISH_WITH_ARGS ]
723
+
724
+ @classmethod
725
+ def setup_class (cls ):
726
+ cred = testutils .MockCredential ()
727
+ firebase_admin .initialize_app (cred , {'projectId' : PROJECT_ID })
728
+ mlkit ._MLKitService .POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test
729
+
730
+ @classmethod
731
+ def teardown_class (cls ):
732
+ testutils .cleanup_apps ()
733
+
734
+ @staticmethod
735
+ def _url (project_id , model_id ):
736
+ return BASE_URL + 'projects/{0}/models/{1}' .format (project_id , model_id )
737
+
738
+ @staticmethod
739
+ def _op_url (project_id , model_id ):
740
+ return BASE_URL + \
741
+ 'operations/project/{0}/model/{1}/operation/123' .format (project_id , model_id )
742
+
743
+ @pytest .mark .parametrize ('publish_function, published' , PUBLISH_UNPUBLISH_WITH_ARGS )
744
+ def test_immediate_done (self , publish_function , published ):
745
+ recorder = instrument_mlkit_service (status = 200 , payload = OPERATION_DONE_RESPONSE )
746
+ model = publish_function (MODEL_ID_1 )
747
+ assert model == CREATED_UPDATED_MODEL_1
748
+ assert len (recorder ) == 1
749
+ assert recorder [0 ].method == 'PATCH'
750
+ assert recorder [0 ].url == TestPublishUnpublish ._url (PROJECT_ID , MODEL_ID_1 )
751
+ body = json .loads (recorder [0 ].body .decode ())
752
+ assert body .get ('model' , {}).get ('state' , {}).get ('published' , None ) is published
753
+ assert body .get ('updateMask' , {}) == 'state.published'
754
+
755
+ @pytest .mark .parametrize ('publish_function' , PUBLISH_UNPUBLISH_FUNCS )
756
+ def test_returns_locked (self , publish_function ):
757
+ recorder = instrument_mlkit_service (
758
+ status = [200 , 200 ],
759
+ payload = [OPERATION_NOT_DONE_RESPONSE , LOCKED_MODEL_2_RESPONSE ])
760
+ expected_model = mlkit .Model .from_dict (LOCKED_MODEL_JSON_2 )
761
+ model = publish_function (MODEL_ID_1 )
762
+
763
+ assert model == expected_model
764
+ assert len (recorder ) == 2
765
+ assert recorder [0 ].method == 'PATCH'
766
+ assert recorder [0 ].url == TestPublishUnpublish ._url (PROJECT_ID , MODEL_ID_1 )
767
+ assert recorder [1 ].method == 'GET'
768
+ assert recorder [1 ].url == TestPublishUnpublish ._url (PROJECT_ID , MODEL_ID_1 )
769
+
770
+ @pytest .mark .parametrize ('publish_function' , PUBLISH_UNPUBLISH_FUNCS )
771
+ def test_operation_error (self , publish_function ):
772
+ instrument_mlkit_service (status = 200 , payload = OPERATION_ERROR_RESPONSE )
773
+ with pytest .raises (Exception ) as excinfo :
774
+ publish_function (MODEL_ID_1 )
775
+ # The http request succeeded, the operation returned contains an update failure
776
+ check_operation_error (excinfo , OPERATION_ERROR_EXPECTED_STATUS , OPERATION_ERROR_MSG )
777
+
778
+ @pytest .mark .parametrize ('publish_function' , PUBLISH_UNPUBLISH_FUNCS )
779
+ def test_malformed_operation (self , publish_function ):
780
+ recorder = instrument_mlkit_service (
781
+ status = [200 , 200 ],
782
+ payload = [OPERATION_MALFORMED_RESPONSE , LOCKED_MODEL_2_RESPONSE ])
783
+ expected_model = mlkit .Model .from_dict (LOCKED_MODEL_JSON_2 )
784
+ model = publish_function (MODEL_ID_1 )
785
+ assert model == expected_model
786
+ assert len (recorder ) == 2
787
+ assert recorder [0 ].method == 'PATCH'
788
+ assert recorder [0 ].url == TestPublishUnpublish ._url (PROJECT_ID , MODEL_ID_1 )
789
+ assert recorder [1 ].method == 'GET'
790
+ assert recorder [1 ].url == TestPublishUnpublish ._url (PROJECT_ID , MODEL_ID_1 )
791
+
792
+ @pytest .mark .parametrize ('publish_function' , PUBLISH_UNPUBLISH_FUNCS )
793
+ def test_rpc_error (self , publish_function ):
794
+ create_recorder = instrument_mlkit_service (
795
+ status = 400 , payload = ERROR_RESPONSE_BAD_REQUEST )
796
+ with pytest .raises (Exception ) as excinfo :
797
+ publish_function (MODEL_ID_1 )
798
+ check_firebase_error (
799
+ excinfo ,
800
+ ERROR_STATUS_BAD_REQUEST ,
801
+ ERROR_CODE_BAD_REQUEST ,
802
+ ERROR_MSG_BAD_REQUEST
803
+ )
804
+ assert len (create_recorder ) == 1
805
+
715
806
class TestGetModel (object ):
716
807
"""Tests mlkit.get_model."""
717
808
@classmethod
0 commit comments