test_modelopt_models.py
# type: ignore
import importlib
import platform
import unittest
from importlib import metadata

import pytest
import torch
import torch_tensorrt as torchtrt
from packaging.version import Version

assertions = unittest.TestCase()
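
# test_base_fp8: quantize a small MLP to FP8 with ModelOpt, export it with
# torch.export, compile the graph with Torch-TensorRT, and check numerical
# parity against the eager PyTorch output.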
@unittest.skipIf(
    torch.cuda.get_device_capability() < (8, 9),
    "FP8 quantization requires compute capability 8.9 or later",
)
@unittest.skipIf(
    not importlib.util.find_spec("modelopt"),
    "ModelOpt is required to run this test",
)
@pytest.mark.unit
def test_base_fp8():
    import modelopt.torch.quantization as mtq
    from modelopt.torch.quantization.utils import export_torch_mode

    class SimpleNetwork(torch.nn.Module):
        def __init__(self):
            super(SimpleNetwork, self).__init__()
            self.linear1 = torch.nn.Linear(in_features=10, out_features=5)
            self.linear2 = torch.nn.Linear(in_features=5, out_features=1)

        def forward(self, x):
            x = self.linear1(x)
            x = torch.nn.ReLU()(x)
            x = self.linear2(x)
            return x

    def calibrate_loop(model):
        """Simple calibration function for testing."""
        model(input_tensor)

    input_tensor = torch.randn(1, 10).cuda()
    model = SimpleNetwork().eval().cuda()
    quant_cfg = mtq.FP8_DEFAULT_CFG
    mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
    # model has FP8 qdq nodes at this point
    output_pyt = model(input_tensor)

    # export_torch_mode() puts the ModelOpt quantizers into an export-friendly
    # state so torch.export can trace the inserted Q/DQ ops.
    with torch.no_grad():
        with export_torch_mode():
            exp_program = torch.export.export(model, (input_tensor,), strict=False)
            trt_model = torchtrt.dynamo.compile(
                exp_program,
                inputs=[input_tensor],
                enabled_precisions={torch.float8_e4m3fn},
                min_block_size=1,
                cache_built_engines=False,
                reuse_cached_engines=False,
            )
            outputs_trt = trt_model(input_tensor)
            assert torch.allclose(output_pyt, outputs_trt, rtol=5e-3, atol=1e-2)
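
# test_base_int8: same MLP and flow as above, but calibrated to INT8 and
# compiled with torch.int8 as the enabled precision.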
@unittest.skipIf(
    platform.system() != "Linux"
    or not importlib.util.find_spec("modelopt")
    or Version(metadata.version("nvidia-modelopt")) < Version("0.27.0"),
    "modelopt 0.27.0 or later is required; INT8 quantization is supported in modelopt on Linux since 0.27.0",
)
@pytest.mark.unit
def test_base_int8():
    import modelopt.torch.quantization as mtq
    from modelopt.torch.quantization.utils import export_torch_mode

    class SimpleNetwork(torch.nn.Module):
        def __init__(self):
            super(SimpleNetwork, self).__init__()
            self.linear1 = torch.nn.Linear(in_features=10, out_features=5)
            self.linear2 = torch.nn.Linear(in_features=5, out_features=1)

        def forward(self, x):
            x = self.linear1(x)
            x = torch.nn.ReLU()(x)
            x = self.linear2(x)
            return x

    def calibrate_loop(model):
        """Simple calibration function for testing."""
        model(input_tensor)

    input_tensor = torch.randn(1, 10).cuda()
    model = SimpleNetwork().eval().cuda()
    quant_cfg = mtq.INT8_DEFAULT_CFG
    mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
    # model has INT8 qdq nodes at this point
    output_pyt = model(input_tensor)

    with torchtrt.logging.debug(), torch.no_grad():
        with export_torch_mode():
            exp_program = torch.export.export(model, (input_tensor,), strict=False)
            trt_model = torchtrt.dynamo.compile(
                exp_program,
                inputs=[input_tensor],
                enabled_precisions={torch.int8},
                min_block_size=1,
                cache_built_engines=False,
                reuse_cached_engines=False,
                truncate_double=True,
                debug=True,
            )
            outputs_trt = trt_model(input_tensor)
            assert torch.allclose(output_pyt, outputs_trt, rtol=5e-3, atol=1e-2)
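

# A minimal run sketch (assumes pytest is installed along with the CUDA,
# torch_tensorrt, and nvidia-modelopt dependencies above; the file's exact
# path within the repo may differ):
#
#   pytest test_modelopt_models.py -v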