@@ -66,23 +66,23 @@
 PATH_TO_DIFFUSERS = PATH_TO_REPO / "src/diffusers"
 PATH_TO_TESTS = PATH_TO_REPO / "tests"
 
-# List here the pipelines to always test.
+# Ignore fixtures in tests folder
+# Ignore lora since they are always tested
+MODULES_TO_IGNORE = ["fixtures", "lora"]
+
 IMPORTANT_PIPELINES = [
     "controlnet",
     "stable_diffusion",
     "stable_diffusion_2",
     "stable_diffusion_xl",
+    "stable_video_diffusion",
     "deepfloyd_if",
     "kandinsky",
     "kandinsky2_2",
     "text_to_video_synthesis",
     "wuerstchen",
 ]
 
-# Ignore fixtures in tests folder
-# Ignore lora since they are always tested
-MODULES_TO_IGNORE = ["fixtures", "lora"]
-
 
 @contextmanager
 def checkout_commit(repo: Repo, commit_id: str):
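This hunk only moves `MODULES_TO_IGNORE` above `IMPORTANT_PIPELINES` and adds `stable_video_diffusion` to the core list; the ignore list's consumer lives elsewhere in the file. As a hedged sketch (the helper name and directory walk are illustrative, not taken from this file), an ignore list like this is typically applied when collecting test modules:

```python
from pathlib import Path

MODULES_TO_IGNORE = ["fixtures", "lora"]

def collect_test_modules(tests_root: Path) -> list:
    """Illustrative only: list test module folders, skipping ignored ones."""
    return sorted(
        entry.name
        for entry in tests_root.iterdir()
        if entry.is_dir() and entry.name not in MODULES_TO_IGNORE
    )
```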
@@ -289,10 +289,13 @@ def get_modified_python_files(diff_with_last_commit: bool = False) -> List[str]:
     repo = Repo(PATH_TO_REPO)
 
     if not diff_with_last_commit:
-        print(f"main is at {repo.refs.main.commit}")
+        # Need to fetch the refs for main via remotes when running with GitHub Actions.
+        upstream_main = repo.remotes.origin.refs.main
+
+        print(f"main is at {upstream_main.commit}")
         print(f"Current head is at {repo.head.commit}")
 
-        branching_commits = repo.merge_base(repo.refs.main, repo.head)
+        branching_commits = repo.merge_base(upstream_main, repo.head)
         for commit in branching_commits:
             print(f"Branching commit: {commit}")
         return get_diff(repo, repo.head.commit, branching_commits)
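The substantive change here is reading `main` through the remote. On a GitHub Actions checkout, HEAD is typically detached and no local `main` branch exists, so `repo.refs.main` fails while the remote-tracking ref is available. A minimal GitPython sketch of the difference (the repo path and the try/except fallback are illustrative assumptions, not code from this file):

```python
from git import Repo

repo = Repo(".")  # assumes the current directory is a git checkout

try:
    # Works on a normal clone that has a local main branch.
    main = repo.refs.main
except AttributeError:
    # CI checkouts often only carry the remote-tracking ref.
    main = repo.remotes.origin.refs.main

# merge_base returns the commit(s) where the current head branched off main.
for commit in repo.merge_base(main, repo.head):
    print(f"Branching commit: {commit}")
```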
@@ -415,10 +418,11 @@ def get_doctest_files(diff_with_last_commit: bool = False) -> List[str]:
 
     test_files_to_run = []  # noqa
     if not diff_with_last_commit:
-        print(f"main is at {repo.refs.main.commit}")
+        upstream_main = repo.remotes.origin.refs.main
+        print(f"main is at {upstream_main.commit}")
         print(f"Current head is at {repo.head.commit}")
 
-        branching_commits = repo.merge_base(repo.refs.main, repo.head)
+        branching_commits = repo.merge_base(upstream_main, repo.head)
         for commit in branching_commits:
             print(f"Branching commit: {commit}")
         test_files_to_run = get_diff_for_doctesting(repo, repo.head.commit, branching_commits)
@@ -432,7 +436,7 @@ def get_doctest_files(diff_with_last_commit: bool = False) -> List[str]:
     all_test_files_to_run = get_all_doctest_files()
 
     # Add to the test files to run any removed entry from "utils/not_doctested.txt".
-    new_test_files = get_new_doctest_files(repo, repo.head.commit, repo.refs.main.commit)
+    new_test_files = get_new_doctest_files(repo, repo.head.commit, upstream_main.commit)
     test_files_to_run = list(set(test_files_to_run + new_test_files))
 
     # Do not run slow doctest tests on CircleCI
@@ -774,18 +778,16 @@ def create_reverse_dependency_map() -> Dict[str, List[str]]:
     return reverse_map
 
 
-def create_module_to_test_map(
-    reverse_map: Dict[str, List[str]] = None, filter_models: bool = False
-) -> Dict[str, List[str]]:
+def create_module_to_test_map(reverse_map: Dict[str, List[str]] = None) -> Dict[str, List[str]]:
     """
     Extract the tests from the reverse_dependency_map and potentially filters the model tests.
 
     Args:
         reverse_map (`Dict[str, List[str]]`, *optional*):
             The reverse dependency map as created by `create_reverse_dependency_map`. Will default to the result of
             that function if not provided.
-        filter_models (`bool`, *optional*, defaults to `False`):
-            Whether or not to filter model tests to only include core models if a file impacts a lot of models.
+        filter_pipelines (`bool`, *optional*, defaults to `False`):
+            Whether or not to filter pipeline tests to only include core pipelines if a file impacts a lot of models.
 
     Returns:
         `Dict[str, List[str]]`: A dictionary that maps each file to the tests to execute if that file was modified.
@@ -804,21 +806,7 @@ def is_test(fname):
     # Build the test map
     test_map = {module: [f for f in deps if is_test(f)] for module, deps in reverse_map.items()}
 
-    if not filter_models:
-        return test_map
-
-    # Now we deal with the filtering if `filter_models` is True.
-    num_model_tests = len(list(PATH_TO_TESTS.glob("models/*")))
-
-    def has_many_models(tests):
-        # We filter to core models when a given file impacts more than half the model tests.
-        model_tests = {Path(t).parts[2] for t in tests if t.startswith("tests/models/")}
-        return len(model_tests) > num_model_tests // 2
-
-    def filter_tests(tests):
-        return [t for t in tests if not t.startswith("tests/models/") or Path(t).parts[2] in IMPORTANT_PIPELINES]
-
-    return {module: (filter_tests(tests) if has_many_models(tests) else tests) for module, tests in test_map.items()}
+    return test_map
 
 
 def check_imports_all_exist():
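With the model-count filtering removed, `create_module_to_test_map` now returns the raw module-to-tests mapping; the core-pipeline handling moves to the JSON post-processing step added further down. For reference, a hedged illustration of the returned shape (the file and test paths are invented):

```python
# Illustrative shape only; keys are modified source files, values are tests.
test_map = {
    "src/diffusers/models/unet_2d_condition.py": [
        "tests/models/test_models_unet_2d_condition.py",
        "tests/pipelines/stable_diffusion/test_stable_diffusion.py",
    ],
    # A widely imported file can map to the whole pipelines folder.
    "src/diffusers/pipelines/pipeline_utils.py": ["tests/pipelines"],
}
```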
@@ -844,7 +832,40 @@ def _print_list(l) -> str:
     return "\n".join([f"- {f}" for f in l])
 
 
-def create_json_map(test_files_to_run: List[str], json_output_file: str):
+def update_test_map_with_core_pipelines(json_output_file: str):
+    print(f"\n### ADD CORE PIPELINE TESTS ###\n{_print_list(IMPORTANT_PIPELINES)}")
+    with open(json_output_file, "rb") as fp:
+        test_map = json.load(fp)
+
+    # Add core pipelines as their own test group
+    test_map["core_pipelines"] = " ".join(
+        sorted([str(PATH_TO_TESTS / f"pipelines/{pipe}") for pipe in IMPORTANT_PIPELINES])
+    )
+
+    # If there are no existing pipeline tests, save the map and return early
+    if "pipelines" not in test_map:
+        with open(json_output_file, "w", encoding="UTF-8") as fp:
+            json.dump(test_map, fp, ensure_ascii=False)
+        return
+
+    pipeline_tests = test_map.pop("pipelines")
+    pipeline_tests = pipeline_tests.split(" ")
+
+    # Remove core pipeline tests from the fetched pipeline tests
+    updated_pipeline_tests = []
+    for pipe in pipeline_tests:
+        if pipe == "tests/pipelines" or Path(pipe).parts[2] in IMPORTANT_PIPELINES:
+            continue
+        updated_pipeline_tests.append(pipe)
+
+    if len(updated_pipeline_tests) > 0:
+        test_map["pipelines"] = " ".join(sorted(updated_pipeline_tests))
+
+    with open(json_output_file, "w", encoding="UTF-8") as fp:
+        json.dump(test_map, fp, ensure_ascii=False)
+
+
+def create_json_map(test_files_to_run: List[str], json_output_file: Optional[str] = None):
     """
     Creates a map from a list of tests to run to easily split them by category, when running parallelism of slow tests.
 
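Concretely, the new function rewrites the JSON map in place: it always adds a `core_pipelines` group covering `IMPORTANT_PIPELINES`, then strips those same folders (and the bare `tests/pipelines` entry) out of any fetched `pipelines` group so core tests are never scheduled twice. A hedged before/after with invented entries:

```python
# Before (as written by create_json_map); entries are invented.
before = {"pipelines": "tests/pipelines/controlnet tests/pipelines/shap_e"}

# After update_test_map_with_core_pipelines:
after = {
    # Always present: one path per IMPORTANT_PIPELINES entry (abridged here).
    "core_pipelines": "tests/pipelines/controlnet tests/pipelines/deepfloyd_if",
    # Only the non-core leftovers survive under "pipelines".
    "pipelines": "tests/pipelines/shap_e",
}
```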
@@ -881,14 +901,14 @@ def create_json_map(test_files_to_run: List[str], json_output_file: str):
     # sort the keys & values
     keys = sorted(test_map.keys())
     test_map = {k: " ".join(sorted(test_map[k])) for k in keys}
+
     with open(json_output_file, "w", encoding="UTF-8") as fp:
         json.dump(test_map, fp, ensure_ascii=False)
 
 
 def infer_tests_to_run(
     output_file: str,
     diff_with_last_commit: bool = False,
-    filter_models: bool = True,
     json_output_file: Optional[str] = None,
 ):
     """
@@ -929,8 +949,9 @@ def infer_tests_to_run(
     # Grab the corresponding test files:
     if any(x in modified_files for x in ["setup.py"]):
         test_files_to_run = ["tests", "examples"]
+
     # in order to trigger pipeline tests even if no code change at all
-    elif "tests/utils/tiny_model_summary.json" in modified_files:
+    if "tests/utils/tiny_model_summary.json" in modified_files:
         test_files_to_run = ["tests"]
         any(f.split(os.path.sep)[0] == "utils" for f in modified_files)
     else:
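Note the knock-on effect of `elif` becoming `if`: a `setup.py` change no longer short-circuits the tiny-model check, and the `else` now pairs with the new `if`. A condensed restatement of the resulting control flow, using the names from this function:

```python
# Condensed control flow after this hunk (not the full function body).
if any(x in modified_files for x in ["setup.py"]):
    test_files_to_run = ["tests", "examples"]

# Evaluated unconditionally now, so it can override the branch above.
if "tests/utils/tiny_model_summary.json" in modified_files:
    test_files_to_run = ["tests"]
else:
    ...  # fall through to the reverse-dependency lookup in the next hunk
```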
@@ -939,7 +960,7 @@ def infer_tests_to_run(
             f for f in modified_files if f.startswith("tests") and f.split(os.path.sep)[-1].startswith("test")
         ]
         # Then we grab the corresponding test files.
-        test_map = create_module_to_test_map(reverse_map=reverse_map, filter_models=filter_models)
+        test_map = create_module_to_test_map(reverse_map=reverse_map)
         for f in modified_files:
             if f in test_map:
                 test_files_to_run.extend(test_map[f])
@@ -1064,8 +1085,6 @@ def parse_commit_message(commit_message: str) -> Dict[str, bool]:
     args = parser.parse_args()
     if args.print_dependencies_of is not None:
         print_tree_deps_of(args.print_dependencies_of)
-    elif args.filter_tests:
-        filter_tests(args.output_file, ["pipelines", "repo_utils"])
     else:
         repo = Repo(PATH_TO_REPO)
         commit_message = repo.head.commit.message
@@ -1089,9 +1108,10 @@ def parse_commit_message(commit_message: str) -> Dict[str, bool]:
                 args.output_file,
                 diff_with_last_commit=diff_with_last_commit,
                 json_output_file=args.json_output_file,
-                filter_models=not commit_flags["no_filter"],
             )
             filter_tests(args.output_file, ["repo_utils"])
+            update_test_map_with_core_pipelines(json_output_file=args.json_output_file)
+
         except Exception as e:
             print(f"\nError when trying to grab the relevant tests: {e}\n\nRunning all tests.")
             commit_flags["test_all"] = True
@@ -1105,3 +1125,4 @@ def parse_commit_message(commit_message: str) -> Dict[str, bool]:
 
         test_files_to_run = get_all_tests()
         create_json_map(test_files_to_run, args.json_output_file)
+        update_test_map_with_core_pipelines(json_output_file=args.json_output_file)
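End to end, both the inferred-tests path and the test-all path now write the JSON map and then append the always-run core-pipeline group. A hedged sketch of how a CI step might consume the result (the filename and the one-job-per-key convention are assumptions about the workflow, not shown in this diff):

```python
import json

# Assumed filename: whatever was passed as --json_output_file.
with open("test_map.json", encoding="utf-8") as fp:
    test_map = json.load(fp)

# One runner per key; each value is a space-separated list of test paths.
for group, tests in sorted(test_map.items()):
    print(f"[{group}] pytest {tests}")
```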