-
Notifications
You must be signed in to change notification settings - Fork 6k
[DPMSolverSinglestepScheduler] correct get_order_list
for solver_order=2
and lower_order_final=True
#6953
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[DPMSolverSinglestepScheduler] correct get_order_list
for solver_order=2
and lower_order_final=True
#6953
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -151,7 +151,7 @@ def __init__( | |
sample_max_value: float = 1.0, | ||
algorithm_type: str = "dpmsolver++", | ||
solver_type: str = "midpoint", | ||
lower_order_final: bool = True, | ||
lower_order_final: bool = False, | ||
use_karras_sigmas: Optional[bool] = False, | ||
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" | ||
lambda_min_clipped: float = -float("inf"), | ||
|
@@ -233,7 +233,7 @@ def get_order_list(self, num_inference_steps: int) -> List[int]: | |
orders = [1, 2, 3] * (steps // 3) + [1, 2] | ||
elif order == 2: | ||
if steps % 2 == 0: | ||
orders = [1, 2] * (steps // 2) | ||
orders = [1, 2] * (steps // 2 - 1) + [1, 1] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this only necessary for DPMSingleStepSolver? Not for the other ones? |
||
else: | ||
orders = [1, 2] * (steps // 2) + [1] | ||
elif order == 1: | ||
|
@@ -320,7 +320,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic | |
|
||
if not self.config.lower_order_final and num_inference_steps % self.config.solver_order != 0: | ||
logger.warn( | ||
"Changing scheduler {self.config} to have `lower_order_final` set to True to handle uneven amount of inference steps. Please make sure to always use an even number of `num_inference steps when using `lower_order_final=True`." | ||
"Changing scheduler {self.config} to have `lower_order_final` set to True to handle uneven amount of inference steps. Please make sure to always use an even number of `num_inference steps when using `lower_order_final=False`." | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it's a typo here but let me know if it's not @patrickvonplaten There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes LGTM - thanks! |
||
) | ||
self.register_to_config(lower_order_final=True) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I changed the default value so that we keep the default behavior the same as before @patrickvonplaten
this way I don't have to update test