
Update Tests Fetcher #5950


Merged: 31 commits, Dec 4, 2023
Changes from all commits
7 changes: 4 additions & 3 deletions .github/workflows/pr_test_fetcher.yml
@@ -35,14 +35,15 @@ jobs:
       - name: Checkout diffusers
         uses: actions/checkout@v3
         with:
-          fetch-depth: 2
+          fetch-depth: 0
       - name: Install dependencies
         run: |
           apt-get update && apt-get install libsndfile1-dev libgl1 -y
-          python -m pip install -e .
+          python -m pip install -e .[quality,test]
       - name: Environment
         run: |
           python utils/print_env.py
+          echo $(git --version)
       - name: Fetch Tests
         run: |
           python utils/tests_fetcher.py | tee test_preparation.txt
@@ -110,7 +111,7 @@ jobs:
         continue-on-error: true
         run: |
           cat reports/${{ matrix.modules }}_tests_cpu_stats.txt
-          cat reports/${{ matrix.modules }}_tests_cpu/failures_short.txt
+          cat reports/${{ matrix.modules }}_tests_cpu_failures_short.txt
 
       - name: Test suite reports artifacts
         if: ${{ always() }}
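Note: the switch to fetch-depth: 0 is what makes the fetcher changes below work. tests_fetcher.py computes a merge base between the PR head and origin/main, and on a shallow clone that branching commit is usually missing from history. A minimal sketch of the failure mode, assuming it runs from the root of a GitPython-managed clone (not part of the PR):

    # Checks whether the clone is deep enough for the tests fetcher.
    # On a shallow checkout (e.g. fetch-depth: 2), merge_base against
    # origin/main typically comes back empty, so no diff can be computed.
    from git import Repo

    repo = Repo(".")  # assumes the current directory is the repo root
    if repo.git.rev_parse("--is-shallow-repository") == "true":
        print("Shallow clone: use fetch-depth: 0 or `git fetch --unshallow`")
    else:
        merge_bases = repo.merge_base(repo.remotes.origin.refs.main, repo.head)
        print(f"Branching commit(s): {merge_bases}")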
2 changes: 2 additions & 0 deletions setup.py
@@ -121,6 +121,7 @@
         "ruff>=0.1.5,<=0.2",
         "safetensors>=0.3.1",
         "sentencepiece>=0.1.91,!=0.1.92",
+        "GitPython<3.1.19",
         "scipy",
         "onnx",
         "regex!=2019.12.17",
@@ -206,6 +207,7 @@ def run(self):
 extras["training"] = deps_list("accelerate", "datasets", "protobuf", "tensorboard", "Jinja2")
 extras["test"] = deps_list(
     "compel",
+    "GitPython",
     "datasets",
     "Jinja2",
     "invisible-watermark",
1 change: 1 addition & 0 deletions src/diffusers/dependency_versions_table.py
@@ -33,6 +33,7 @@
     "ruff": "ruff>=0.1.5,<=0.2",
     "safetensors": "safetensors>=0.3.1",
     "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
+    "GitPython": "GitPython<3.1.19",
     "scipy": "scipy",
     "onnx": "onnx",
     "regex": "regex!=2019.12.17",
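Note: the version pin only needs to be written once. dependency_versions_table.py is auto-generated from the _deps list in setup.py, and the extras are built by looking names up in that table. A simplified sketch of the lookup (the real deps_list helper lives in diffusers' setup.py):

    # deps mirrors dependency_versions_table.py: package name -> pinned specifier.
    deps = {
        "GitPython": "GitPython<3.1.19",
        "datasets": "datasets",
        "Jinja2": "Jinja2",
    }

    def deps_list(*pkgs):
        # Resolve each requested package to its pinned requirement string.
        return [deps[pkg] for pkg in pkgs]

    print(deps_list("GitPython", "datasets"))  # ['GitPython<3.1.19', 'datasets']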
95 changes: 58 additions & 37 deletions utils/tests_fetcher.py
@@ -66,23 +66,23 @@
 PATH_TO_DIFFUSERS = PATH_TO_REPO / "src/diffusers"
 PATH_TO_TESTS = PATH_TO_REPO / "tests"
 
-# List here the pipelines to always test.
+# Ignore fixtures in tests folder
+# Ignore lora since they are always tested
+MODULES_TO_IGNORE = ["fixtures", "lora"]
+
 IMPORTANT_PIPELINES = [
     "controlnet",
     "stable_diffusion",
     "stable_diffusion_2",
     "stable_diffusion_xl",
     "stable_video_diffusion",
     "deepfloyd_if",
     "kandinsky",
     "kandinsky2_2",
     "text_to_video_synthesis",
     "wuerstchen",
 ]
 
-# Ignore fixtures in tests folder
-# Ignore lora since they are always tested
-MODULES_TO_IGNORE = ["fixtures", "lora"]
 
 
 @contextmanager
 def checkout_commit(repo: Repo, commit_id: str):
@@ -289,10 +289,13 @@ def get_modified_python_files(diff_with_last_commit: bool = False) -> List[str]:
     repo = Repo(PATH_TO_REPO)
 
     if not diff_with_last_commit:
-        print(f"main is at {repo.refs.main.commit}")
+        # Need to fetch refs for main using remotes when running with github actions.
+        upstream_main = repo.remotes.origin.refs.main
+
+        print(f"main is at {upstream_main.commit}")
         print(f"Current head is at {repo.head.commit}")
 
-        branching_commits = repo.merge_base(repo.refs.main, repo.head)
+        branching_commits = repo.merge_base(upstream_main, repo.head)
         for commit in branching_commits:
             print(f"Branching commit: {commit}")
         return get_diff(repo, repo.head.commit, branching_commits)
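Note: the swap from repo.refs.main to repo.remotes.origin.refs.main is the core fix in this hunk. On a GitHub Actions pull-request checkout, actions/checkout leaves HEAD detached and creates no local main branch, so repo.refs.main raises AttributeError, while the remote-tracking ref is available after a full fetch. A hedged sketch of the difference, assuming a clone whose origin has a main branch:

    from git import Repo

    repo = Repo(".")
    try:
        main_ref = repo.refs.main  # missing on a typical Actions checkout
    except AttributeError:
        main_ref = repo.remotes.origin.refs.main  # remote-tracking origin/main
    print(f"main is at {main_ref.commit}")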
@@ -415,10 +418,11 @@ def get_doctest_files(diff_with_last_commit: bool = False) -> List[str]:
 
     test_files_to_run = []  # noqa
     if not diff_with_last_commit:
-        print(f"main is at {repo.refs.main.commit}")
+        upstream_main = repo.remotes.origin.refs.main
+        print(f"main is at {upstream_main.commit}")
         print(f"Current head is at {repo.head.commit}")
 
-        branching_commits = repo.merge_base(repo.refs.main, repo.head)
+        branching_commits = repo.merge_base(upstream_main, repo.head)
         for commit in branching_commits:
             print(f"Branching commit: {commit}")
         test_files_to_run = get_diff_for_doctesting(repo, repo.head.commit, branching_commits)
@@ -432,7 +436,7 @@ def get_doctest_files(diff_with_last_commit: bool = False) -> List[str]:
         all_test_files_to_run = get_all_doctest_files()
 
     # Add to the test files to run any removed entry from "utils/not_doctested.txt".
-    new_test_files = get_new_doctest_files(repo, repo.head.commit, repo.refs.main.commit)
+    new_test_files = get_new_doctest_files(repo, repo.head.commit, upstream_main.commit)
     test_files_to_run = list(set(test_files_to_run + new_test_files))
 
     # Do not run slow doctest tests on CircleCI
@@ -774,18 +778,16 @@ def create_reverse_dependency_map() -> Dict[str, List[str]]:
     return reverse_map
 
 
-def create_module_to_test_map(
-    reverse_map: Dict[str, List[str]] = None, filter_models: bool = False
-) -> Dict[str, List[str]]:
+def create_module_to_test_map(reverse_map: Dict[str, List[str]] = None) -> Dict[str, List[str]]:
     """
     Extract the tests from the reverse_dependency_map and potentially filters the model tests.
 
     Args:
         reverse_map (`Dict[str, List[str]]`, *optional*):
             The reverse dependency map as created by `create_reverse_dependency_map`. Will default to the result of
             that function if not provided.
-        filter_models (`bool`, *optional*, defaults to `False`):
-            Whether or not to filter model tests to only include core models if a file impacts a lot of models.
+        filter_pipelines (`bool`, *optional*, defaults to `False`):
+            Whether or not to filter pipeline tests to only include core pipelines if a file impacts a lot of models.
 
     Returns:
         `Dict[str, List[str]]`: A dictionary that maps each file to the tests to execute if that file was modified.
@@ -804,21 +806,7 @@ def is_test(fname):
     # Build the test map
     test_map = {module: [f for f in deps if is_test(f)] for module, deps in reverse_map.items()}
 
-    if not filter_models:
-        return test_map
-
-    # Now we deal with the filtering if `filter_models` is True.
-    num_model_tests = len(list(PATH_TO_TESTS.glob("models/*")))
-
-    def has_many_models(tests):
-        # We filter to core models when a given file impacts more than half the model tests.
-        model_tests = {Path(t).parts[2] for t in tests if t.startswith("tests/models/")}
-        return len(model_tests) > num_model_tests // 2
-
-    def filter_tests(tests):
-        return [t for t in tests if not t.startswith("tests/models/") or Path(t).parts[2] in IMPORTANT_PIPELINES]
-
-    return {module: (filter_tests(tests) if has_many_models(tests) else tests) for module, tests in test_map.items()}
+    return test_map
 
 
 def check_imports_all_exist():
@@ -844,7 +832,39 @@ def _print_list(l) -> str:
     return "\n".join([f"- {f}" for f in l])
 
 
-def create_json_map(test_files_to_run: List[str], json_output_file: str):
+def update_test_map_with_core_pipelines(json_output_file: str):
+    print(f"\n### ADD CORE PIPELINE TESTS ###\n{_print_list(IMPORTANT_PIPELINES)}")
+    with open(json_output_file, "rb") as fp:
+        test_map = json.load(fp)
+
+    # Add core pipelines as their own test group
+    test_map["core_pipelines"] = " ".join(
+        sorted([str(PATH_TO_TESTS / f"pipelines/{pipe}") for pipe in IMPORTANT_PIPELINES])
+    )
+
+    # If there are no existing pipeline tests save the map
+    if "pipelines" not in test_map:
+        with open(json_output_file, "w", encoding="UTF-8") as fp:
+            json.dump(test_map, fp, ensure_ascii=False)
+
+    pipeline_tests = test_map.pop("pipelines")
+    pipeline_tests = pipeline_tests.split(" ")
+
+    # Remove core pipeline tests from the fetched pipeline tests
+    updated_pipeline_tests = []
+    for pipe in pipeline_tests:
+        if pipe == "tests/pipelines" or Path(pipe).parts[2] in IMPORTANT_PIPELINES:
+            continue
+        updated_pipeline_tests.append(pipe)
+
+    if len(updated_pipeline_tests) > 0:
+        test_map["pipelines"] = " ".join(sorted(updated_pipeline_tests))
+
+    with open(json_output_file, "w", encoding="UTF-8") as fp:
+        json.dump(test_map, fp, ensure_ascii=False)
+
+
+def create_json_map(test_files_to_run: List[str], json_output_file: Optional[str] = None):
     """
     Creates a map from a list of tests to run to easily split them by category, when running parallelism of slow tests.
 
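Note: a worked example of what the new update_test_map_with_core_pipelines does. Given a fetched map like {"pipelines": "tests/pipelines tests/pipelines/controlnet tests/pipelines/shap_e"}, the core entries move into their own core_pipelines group and only non-core leftovers stay under pipelines. A standalone sketch of that split, with hypothetical paths for illustration:

    from pathlib import Path

    IMPORTANT_PIPELINES = ["controlnet", "stable_diffusion"]
    fetched = ["tests/pipelines", "tests/pipelines/controlnet", "tests/pipelines/shap_e"]

    # Same filter as the PR: drop the catch-all folder and every core pipeline.
    leftovers = [
        p for p in fetched
        if p != "tests/pipelines" and Path(p).parts[2] not in IMPORTANT_PIPELINES
    ]
    print(leftovers)  # ['tests/pipelines/shap_e']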
@@ -881,14 +901,14 @@ def create_json_map(test_files_to_run: List[str], json_output_file: Optional[str] = None):
     # sort the keys & values
     keys = sorted(test_map.keys())
     test_map = {k: " ".join(sorted(test_map[k])) for k in keys}
+
     with open(json_output_file, "w", encoding="UTF-8") as fp:
         json.dump(test_map, fp, ensure_ascii=False)
 
 
 def infer_tests_to_run(
     output_file: str,
     diff_with_last_commit: bool = False,
-    filter_models: bool = True,
     json_output_file: Optional[str] = None,
 ):
     """
@@ -929,8 +949,9 @@ def infer_tests_to_run(
     # Grab the corresponding test files:
     if any(x in modified_files for x in ["setup.py"]):
         test_files_to_run = ["tests", "examples"]
+
     # in order to trigger pipeline tests even if no code change at all
-    elif "tests/utils/tiny_model_summary.json" in modified_files:
+    if "tests/utils/tiny_model_summary.json" in modified_files:
         test_files_to_run = ["tests"]
         any(f.split(os.path.sep)[0] == "utils" for f in modified_files)
     else:
@@ -939,7 +960,7 @@ def infer_tests_to_run(
             f for f in modified_files if f.startswith("tests") and f.split(os.path.sep)[-1].startswith("test")
         ]
         # Then we grab the corresponding test files.
-        test_map = create_module_to_test_map(reverse_map=reverse_map, filter_models=filter_models)
+        test_map = create_module_to_test_map(reverse_map=reverse_map)
         for f in modified_files:
             if f in test_map:
                 test_files_to_run.extend(test_map[f])
@@ -1064,8 +1085,6 @@ def parse_commit_message(commit_message: str) -> Dict[str, bool]:
     args = parser.parse_args()
     if args.print_dependencies_of is not None:
         print_tree_deps_of(args.print_dependencies_of)
-    elif args.filter_tests:
-        filter_tests(args.output_file, ["pipelines", "repo_utils"])
     else:
         repo = Repo(PATH_TO_REPO)
         commit_message = repo.head.commit.message
@@ -1089,9 +1108,10 @@ def parse_commit_message(commit_message: str) -> Dict[str, bool]:
                 args.output_file,
                 diff_with_last_commit=diff_with_last_commit,
                 json_output_file=args.json_output_file,
-                filter_models=not commit_flags["no_filter"],
             )
-            filter_tests(args.output_file, ["repo_utils"])
+            update_test_map_with_core_pipelines(json_output_file=args.json_output_file)
+
         except Exception as e:
             print(f"\nError when trying to grab the relevant tests: {e}\n\nRunning all tests.")
             commit_flags["test_all"] = True
@@ -1105,3 +1125,4 @@ def parse_commit_message(commit_message: str) -> Dict[str, bool]:
 
         test_files_to_run = get_all_tests()
         create_json_map(test_files_to_run, args.json_output_file)
+        update_test_map_with_core_pipelines(json_output_file=args.json_output_file)
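Note: net effect on the artifact the CI matrix consumes. After this change the JSON map always carries a core_pipelines group, even when the diff would not otherwise touch pipeline code. An illustrative shape of the output, not captured from a real run and with a hypothetical file name:

    import json

    # Hypothetical output path and contents, for illustration only.
    example_test_map = {
        "core_pipelines": "tests/pipelines/controlnet tests/pipelines/stable_diffusion",
        "pipelines": "tests/pipelines/shap_e",  # non-core leftovers, if any
    }
    with open("test_map.json", "w", encoding="UTF-8") as fp:
        json.dump(example_test_map, fp, ensure_ascii=False)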