@@ -18,6 +18,7 @@
 import torch.nn.functional as F
 from torch import nn
 
+from ..utils import is_torch_version
 from .attention import AdaGroupNorm
 from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
 from .dual_transformer_2d import DualTransformer2DModel
@@ -866,13 +867,27 @@ def custom_forward(*inputs):
 
                     return custom_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
-                    hidden_states,
-                    encoder_hidden_states,
-                    cross_attention_kwargs,
-                )[0]
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(attn, return_dict=False),
+                        hidden_states,
+                        encoder_hidden_states,
+                        cross_attention_kwargs,
+                        use_reentrant=False,
+                    )[0]
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(attn, return_dict=False),
+                        hidden_states,
+                        encoder_hidden_states,
+                        cross_attention_kwargs,
+                    )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
                 hidden_states = attn(
@@ -957,7 +972,14 @@ def custom_forward(*inputs):
 
                     return custom_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
             else:
                 hidden_states = resnet(hidden_states, temb)
 
@@ -1361,7 +1383,14 @@ def custom_forward(*inputs):
 
                     return custom_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
             else:
                 hidden_states = resnet(hidden_states, temb)
 
@@ -1558,7 +1587,14 @@ def custom_forward(*inputs):
 
                     return custom_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
             else:
                 hidden_states = resnet(hidden_states, temb)
 
@@ -1653,14 +1689,29 @@ def custom_forward(*inputs):
 
                     return custom_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
-                    hidden_states,
-                    encoder_hidden_states,
-                    attention_mask,
-                    cross_attention_kwargs,
-                )
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(attn, return_dict=False),
+                        hidden_states,
+                        encoder_hidden_states,
+                        attention_mask,
+                        cross_attention_kwargs,
+                        use_reentrant=False,
+                    )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(attn, return_dict=False),
+                        hidden_states,
+                        encoder_hidden_states,
+                        attention_mask,
+                        cross_attention_kwargs,
+                    )
             else:
                 hidden_states = resnet(hidden_states, temb)
                 hidden_states = attn(
@@ -1874,13 +1925,27 @@ def custom_forward(*inputs):
 
                     return custom_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
-                    hidden_states,
-                    encoder_hidden_states,
-                    cross_attention_kwargs,
-                )[0]
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(attn, return_dict=False),
+                        hidden_states,
+                        encoder_hidden_states,
+                        cross_attention_kwargs,
+                        use_reentrant=False,
+                    )[0]
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(attn, return_dict=False),
+                        hidden_states,
+                        encoder_hidden_states,
+                        cross_attention_kwargs,
+                    )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
                 hidden_states = attn(
@@ -1960,7 +2025,14 @@ def custom_forward(*inputs):
 
                     return custom_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
             else:
                 hidden_states = resnet(hidden_states, temb)
 
@@ -2388,7 +2460,14 @@ def custom_forward(*inputs):
 
                     return custom_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
             else:
                 hidden_states = resnet(hidden_states, temb)
 
@@ -2593,7 +2672,14 @@ def custom_forward(*inputs):
 
                     return custom_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
             else:
                 hidden_states = resnet(hidden_states, temb)
 
@@ -2714,14 +2800,29 @@ def custom_forward(*inputs):
 
                     return custom_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
-                    hidden_states,
-                    encoder_hidden_states,
-                    attention_mask,
-                    cross_attention_kwargs,
-                )[0]
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(attn, return_dict=False),
+                        hidden_states,
+                        encoder_hidden_states,
+                        attention_mask,
+                        cross_attention_kwargs,
+                        use_reentrant=False,
+                    )[0]
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(attn, return_dict=False),
+                        hidden_states,
+                        encoder_hidden_states,
+                        attention_mask,
+                        cross_attention_kwargs,
+                    )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
                 hidden_states = attn(
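Every hunk above applies the same pattern: gate `use_reentrant=False` (introduced in PyTorch 1.11) on the installed version via diffusers' `is_torch_version` helper, falling back to the original reentrant checkpoint call on older releases. A minimal, self-contained sketch of that pattern follows; the `checkpointed` helper and the `Linear` stand-in module are hypothetical, used only for illustration.

    import torch
    import torch.utils.checkpoint
    from diffusers.utils import is_torch_version  # compares against torch.__version__

    def checkpointed(module, *args):
        # Hypothetical helper mirroring the diff's branching.
        if is_torch_version(">=", "1.11.0"):
            # Non-reentrant checkpointing: recomputes the forward with autograd
            # hooks rather than a nested backward call, and also works when
            # some inputs do not require grad.
            return torch.utils.checkpoint.checkpoint(module, *args, use_reentrant=False)
        # PyTorch < 1.11 only supports the reentrant implementation.
        return torch.utils.checkpoint.checkpoint(module, *args)

    # Usage with a throwaway module:
    block = torch.nn.Linear(4, 4)
    x = torch.randn(2, 4, requires_grad=True)
    checkpointed(block, x).sum().backward()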