
Commit b217292

Update Tests Fetcher (#5950)
* update setup and deps table * update * update * update * up * up * update * up * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * quality fix * fix failure reporting --------- Co-authored-by: Patrick von Platen <[email protected]>
1 parent 8a812e4 commit b217292

File tree

.github/workflows/pr_test_fetcher.yml
setup.py
src/diffusers/dependency_versions_table.py
utils/tests_fetcher.py

4 files changed: +65 −40 lines


.github/workflows/pr_test_fetcher.yml

Lines changed: 4 additions & 3 deletions
@@ -35,14 +35,15 @@ jobs:
       - name: Checkout diffusers
         uses: actions/checkout@v3
         with:
-          fetch-depth: 2
+          fetch-depth: 0
       - name: Install dependencies
         run: |
           apt-get update && apt-get install libsndfile1-dev libgl1 -y
-          python -m pip install -e .
+          python -m pip install -e .[quality,test]
       - name: Environment
         run: |
           python utils/print_env.py
+          echo $(git --version)
       - name: Fetch Tests
         run: |
           python utils/tests_fetcher.py | tee test_preparation.txt
@@ -110,7 +111,7 @@ jobs:
         continue-on-error: true
         run: |
           cat reports/${{ matrix.modules }}_tests_cpu_stats.txt
-          cat reports/${{ matrix.modules }}_tests_cpu/failures_short.txt
+          cat reports/${{ matrix.modules }}_tests_cpu_failures_short.txt

       - name: Test suite reports artifacts
         if: ${{ always() }}
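
Note: `fetch-depth: 0` replaces the previous shallow checkout because the fetcher now resolves `origin/main` and computes a merge base with GitPython, which needs the full history. A minimal sanity-check sketch (not part of the workflow; run from the repository root) would be:

from git import Repo

repo = Repo(".")
# A shallow clone (fetch-depth: 2) would print "true" here and break merge-base lookups.
print("shallow clone:", repo.git.rev_parse("--is-shallow-repository"))
# The remote-tracking ref the updated fetcher relies on.
print("origin/main is at:", repo.remotes.origin.refs.main.commit.hexsha[:7])
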

setup.py

Lines changed: 2 additions & 0 deletions
@@ -121,6 +121,7 @@
     "ruff>=0.1.5,<=0.2",
     "safetensors>=0.3.1",
     "sentencepiece>=0.1.91,!=0.1.92",
+    "GitPython<3.1.19",
     "scipy",
     "onnx",
     "regex!=2019.12.17",
@@ -206,6 +207,7 @@ def run(self):
 extras["training"] = deps_list("accelerate", "datasets", "protobuf", "tensorboard", "Jinja2")
 extras["test"] = deps_list(
     "compel",
+    "GitPython",
     "datasets",
     "Jinja2",
     "invisible-watermark",

src/diffusers/dependency_versions_table.py

Lines changed: 1 addition & 0 deletions
@@ -33,6 +33,7 @@
     "ruff": "ruff>=0.1.5,<=0.2",
     "safetensors": "safetensors>=0.3.1",
     "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
+    "GitPython": "GitPython<3.1.19",
     "scipy": "scipy",
     "onnx": "onnx",
     "regex": "regex!=2019.12.17",

utils/tests_fetcher.py

Lines changed: 58 additions & 37 deletions
@@ -66,23 +66,23 @@
 PATH_TO_DIFFUSERS = PATH_TO_REPO / "src/diffusers"
 PATH_TO_TESTS = PATH_TO_REPO / "tests"

-# List here the pipelines to always test.
+# Ignore fixtures in tests folder
+# Ignore lora since they are always tested
+MODULES_TO_IGNORE = ["fixtures", "lora"]
+
 IMPORTANT_PIPELINES = [
     "controlnet",
     "stable_diffusion",
     "stable_diffusion_2",
     "stable_diffusion_xl",
+    "stable_video_diffusion",
     "deepfloyd_if",
     "kandinsky",
     "kandinsky2_2",
     "text_to_video_synthesis",
     "wuerstchen",
 ]

-# Ignore fixtures in tests folder
-# Ignore lora since they are always tested
-MODULES_TO_IGNORE = ["fixtures", "lora"]
-

 @contextmanager
 def checkout_commit(repo: Repo, commit_id: str):
@@ -289,10 +289,13 @@ def get_modified_python_files(diff_with_last_commit: bool = False) -> List[str]:
     repo = Repo(PATH_TO_REPO)

     if not diff_with_last_commit:
-        print(f"main is at {repo.refs.main.commit}")
+        # Need to fetch refs for main using remotes when running with github actions.
+        upstream_main = repo.remotes.origin.refs.main
+
+        print(f"main is at {upstream_main.commit}")
         print(f"Current head is at {repo.head.commit}")

-        branching_commits = repo.merge_base(repo.refs.main, repo.head)
+        branching_commits = repo.merge_base(upstream_main, repo.head)
         for commit in branching_commits:
             print(f"Branching commit: {commit}")
         return get_diff(repo, repo.head.commit, branching_commits)
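
The comment in the hunk above captures the motivation: under actions/checkout the workspace usually has no local `main` branch (the PR head is checked out directly), so `repo.refs.main` is not available and the remote-tracking ref must be used instead. A rough illustration of the difference, assuming a clone with full history and a fetched `origin/main`:

from git import Repo

repo = Repo(".")

try:
    main_ref = repo.refs.main  # only exists if a local branch named "main" was created
except AttributeError:
    main_ref = repo.remotes.origin.refs.main  # remote-tracking ref fetched by actions/checkout

branching_commits = repo.merge_base(main_ref, repo.head)
for commit in branching_commits:
    print(f"Branching commit: {commit}")
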
@@ -415,10 +418,11 @@ def get_doctest_files(diff_with_last_commit: bool = False) -> List[str]:

     test_files_to_run = []  # noqa
     if not diff_with_last_commit:
-        print(f"main is at {repo.refs.main.commit}")
+        upstream_main = repo.remotes.origin.refs.main
+        print(f"main is at {upstream_main.commit}")
         print(f"Current head is at {repo.head.commit}")

-        branching_commits = repo.merge_base(repo.refs.main, repo.head)
+        branching_commits = repo.merge_base(upstream_main, repo.head)
         for commit in branching_commits:
             print(f"Branching commit: {commit}")
         test_files_to_run = get_diff_for_doctesting(repo, repo.head.commit, branching_commits)
@@ -432,7 +436,7 @@ def get_doctest_files(diff_with_last_commit: bool = False) -> List[str]:
     all_test_files_to_run = get_all_doctest_files()

     # Add to the test files to run any removed entry from "utils/not_doctested.txt".
-    new_test_files = get_new_doctest_files(repo, repo.head.commit, repo.refs.main.commit)
+    new_test_files = get_new_doctest_files(repo, repo.head.commit, upstream_main.commit)
     test_files_to_run = list(set(test_files_to_run + new_test_files))

     # Do not run slow doctest tests on CircleCI
@@ -774,18 +778,16 @@ def create_reverse_dependency_map() -> Dict[str, List[str]]:
     return reverse_map


-def create_module_to_test_map(
-    reverse_map: Dict[str, List[str]] = None, filter_models: bool = False
-) -> Dict[str, List[str]]:
+def create_module_to_test_map(reverse_map: Dict[str, List[str]] = None) -> Dict[str, List[str]]:
     """
     Extract the tests from the reverse_dependency_map and potentially filters the model tests.

     Args:
         reverse_map (`Dict[str, List[str]]`, *optional*):
             The reverse dependency map as created by `create_reverse_dependency_map`. Will default to the result of
             that function if not provided.
-        filter_models (`bool`, *optional*, defaults to `False`):
-            Whether or not to filter model tests to only include core models if a file impacts a lot of models.
+        filter_pipelines (`bool`, *optional*, defaults to `False`):
+            Whether or not to filter pipeline tests to only include core pipelines if a file impacts a lot of models.

     Returns:
         `Dict[str, List[str]]`: A dictionary that maps each file to the tests to execute if that file was modified.
@@ -804,21 +806,7 @@ def is_test(fname):
     # Build the test map
     test_map = {module: [f for f in deps if is_test(f)] for module, deps in reverse_map.items()}

-    if not filter_models:
-        return test_map
-
-    # Now we deal with the filtering if `filter_models` is True.
-    num_model_tests = len(list(PATH_TO_TESTS.glob("models/*")))
-
-    def has_many_models(tests):
-        # We filter to core models when a given file impacts more than half the model tests.
-        model_tests = {Path(t).parts[2] for t in tests if t.startswith("tests/models/")}
-        return len(model_tests) > num_model_tests // 2
-
-    def filter_tests(tests):
-        return [t for t in tests if not t.startswith("tests/models/") or Path(t).parts[2] in IMPORTANT_PIPELINES]
-
-    return {module: (filter_tests(tests) if has_many_models(tests) else tests) for module, tests in test_map.items()}
+    return test_map


 def check_imports_all_exist():
@@ -844,7 +832,39 @@ def _print_list(l) -> str:
     return "\n".join([f"- {f}" for f in l])


-def create_json_map(test_files_to_run: List[str], json_output_file: str):
+def update_test_map_with_core_pipelines(json_output_file: str):
+    print(f"\n### ADD CORE PIPELINE TESTS ###\n{_print_list(IMPORTANT_PIPELINES)}")
+    with open(json_output_file, "rb") as fp:
+        test_map = json.load(fp)
+
+    # Add core pipelines as their own test group
+    test_map["core_pipelines"] = " ".join(
+        sorted([str(PATH_TO_TESTS / f"pipelines/{pipe}") for pipe in IMPORTANT_PIPELINES])
+    )
+
+    # If there are no existing pipeline tests save the map
+    if "pipelines" not in test_map:
+        with open(json_output_file, "w", encoding="UTF-8") as fp:
+            json.dump(test_map, fp, ensure_ascii=False)
+
+    pipeline_tests = test_map.pop("pipelines")
+    pipeline_tests = pipeline_tests.split(" ")
+
+    # Remove core pipeline tests from the fetched pipeline tests
+    updated_pipeline_tests = []
+    for pipe in pipeline_tests:
+        if pipe == "tests/pipelines" or Path(pipe).parts[2] in IMPORTANT_PIPELINES:
+            continue
+        updated_pipeline_tests.append(pipe)
+
+    if len(updated_pipeline_tests) > 0:
+        test_map["pipelines"] = " ".join(sorted(updated_pipeline_tests))
+
+    with open(json_output_file, "w", encoding="UTF-8") as fp:
+        json.dump(test_map, fp, ensure_ascii=False)
+
+
+def create_json_map(test_files_to_run: List[str], json_output_file: Optional[str] = None):
     """
     Creates a map from a list of tests to run to easily split them by category, when running parallelism of slow tests.

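
To make the effect of `update_test_map_with_core_pipelines` concrete, here is a hypothetical test map before and after the update; keys and paths are illustrative only, not taken from a real run:

# Hypothetical map as written by create_json_map: one space-separated string per group.
before = {
    "models": "tests/models/test_modeling_common.py",
    "pipelines": "tests/pipelines/controlnet tests/pipelines/dance_diffusion",
}

# After update_test_map_with_core_pipelines: the pipelines in IMPORTANT_PIPELINES get their
# own "core_pipelines" group, and core entries are dropped from the generic "pipelines" group.
# (Truncated here; the real group contains every IMPORTANT_PIPELINES entry, resolved under PATH_TO_TESTS.)
after = {
    "core_pipelines": "tests/pipelines/controlnet tests/pipelines/stable_diffusion tests/pipelines/wuerstchen",
    "models": "tests/models/test_modeling_common.py",
    "pipelines": "tests/pipelines/dance_diffusion",
}
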
@@ -881,14 +901,14 @@ def create_json_map(test_files_to_run: List[str], json_output_file: str):
     # sort the keys & values
     keys = sorted(test_map.keys())
     test_map = {k: " ".join(sorted(test_map[k])) for k in keys}
+
     with open(json_output_file, "w", encoding="UTF-8") as fp:
         json.dump(test_map, fp, ensure_ascii=False)


 def infer_tests_to_run(
     output_file: str,
     diff_with_last_commit: bool = False,
-    filter_models: bool = True,
     json_output_file: Optional[str] = None,
 ):
     """
@@ -929,8 +949,9 @@ def infer_tests_to_run(
     # Grab the corresponding test files:
     if any(x in modified_files for x in ["setup.py"]):
         test_files_to_run = ["tests", "examples"]
+
     # in order to trigger pipeline tests even if no code change at all
-    elif "tests/utils/tiny_model_summary.json" in modified_files:
+    if "tests/utils/tiny_model_summary.json" in modified_files:
         test_files_to_run = ["tests"]
         any(f.split(os.path.sep)[0] == "utils" for f in modified_files)
     else:
@@ -939,7 +960,7 @@ def infer_tests_to_run(
             f for f in modified_files if f.startswith("tests") and f.split(os.path.sep)[-1].startswith("test")
         ]
         # Then we grab the corresponding test files.
-        test_map = create_module_to_test_map(reverse_map=reverse_map, filter_models=filter_models)
+        test_map = create_module_to_test_map(reverse_map=reverse_map)
         for f in modified_files:
             if f in test_map:
                 test_files_to_run.extend(test_map[f])
@@ -1064,8 +1085,6 @@ def parse_commit_message(commit_message: str) -> Dict[str, bool]:
     args = parser.parse_args()
     if args.print_dependencies_of is not None:
         print_tree_deps_of(args.print_dependencies_of)
-    elif args.filter_tests:
-        filter_tests(args.output_file, ["pipelines", "repo_utils"])
     else:
         repo = Repo(PATH_TO_REPO)
         commit_message = repo.head.commit.message
@@ -1089,9 +1108,10 @@ def parse_commit_message(commit_message: str) -> Dict[str, bool]:
                 args.output_file,
                 diff_with_last_commit=diff_with_last_commit,
                 json_output_file=args.json_output_file,
-                filter_models=not commit_flags["no_filter"],
             )
             filter_tests(args.output_file, ["repo_utils"])
+            update_test_map_with_core_pipelines(json_output_file=args.json_output_file)
+
         except Exception as e:
             print(f"\nError when trying to grab the relevant tests: {e}\n\nRunning all tests.")
             commit_flags["test_all"] = True
@@ -1105,3 +1125,4 @@ def parse_commit_message(commit_message: str) -> Dict[str, bool]:

         test_files_to_run = get_all_tests()
         create_json_map(test_files_to_run, args.json_output_file)
+        update_test_map_with_core_pipelines(json_output_file=args.json_output_file)
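
Downstream, a setup job reads the JSON map the fetcher writes and turns each key into a matrix entry (the `matrix.modules` referenced in the report step of the workflow above). A minimal sketch of that consumption, with the output file name assumed rather than taken from the workflow:

import json

# File name is an assumption; the real path comes from the --json_output_file argument.
with open("test_map.json", encoding="utf-8") as fp:
    test_map = json.load(fp)

# Each key (e.g. "core_pipelines", "pipelines", "models") would become one CI job; its value
# is the space-separated list of test paths handed to pytest.
for module, tests in sorted(test_map.items()):
    print(f"{module}: python -m pytest {tests}")
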
