-
Notifications
You must be signed in to change notification settings - Fork 364
chore: Re-test BF16 fixes on main, refactor test suite #3490
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
peri044
wants to merge
17
commits into
main
Choose a base branch
from
bf16_fix
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 15 commits
Commits
Show all changes
17 commits
Select commit
Hold shift + click to select a range
d18c013
fix: Fix BF16 compilation issues
peri044 daa97a8
chore: minor fixes
peri044 008f3d4
chore: minor fix
peri044 a7b6304
chore: revert bf16 enum fix
peri044 062a94b
chore: fix CI failures
peri044 7105474
chore: bug fix
peri044 7e9d388
chore: fix CI test failures
peri044 088e7c8
Merge branch 'main' into bf16_fix
peri044 2c61ccb
chore: additional CI test failure fixes
peri044 7bc6eaf
chore: updates
peri044 0d5e91f
chore: updates
peri044 c748fac
chore: updates
peri044 6693778
chore: rebase
peri044 e923627
chore: updates
peri044 8deb1f9
chore: add modelopt tests file
peri044 a65d255
Merge branch 'main' into bf16_fix
peri044 3a0aa18
chore: updates
peri044 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
# type: ignore | ||
import importlib | ||
import platform | ||
import unittest | ||
from importlib import metadata | ||
|
||
import pytest | ||
import torch | ||
import torch_tensorrt as torchtrt | ||
|
||
from packaging.version import Version | ||
|
||
assertions = unittest.TestCase() | ||
|
||
|
||
@unittest.skipIf( | ||
torch.cuda.get_device_capability() < (8, 9), | ||
"FP8 quantization requires compute capability 8.9 or later", | ||
) | ||
@unittest.skipIf( | ||
not importlib.util.find_spec("modelopt"), | ||
"ModelOpt is required to run this test", | ||
) | ||
@pytest.mark.unit | ||
def test_base_fp8(): | ||
import modelopt.torch.quantization as mtq | ||
from modelopt.torch.quantization.utils import export_torch_mode | ||
|
||
class SimpleNetwork(torch.nn.Module): | ||
def __init__(self): | ||
super(SimpleNetwork, self).__init__() | ||
self.linear1 = torch.nn.Linear(in_features=10, out_features=5) | ||
self.linear2 = torch.nn.Linear(in_features=5, out_features=1) | ||
|
||
def forward(self, x): | ||
x = self.linear1(x) | ||
x = torch.nn.ReLU()(x) | ||
x = self.linear2(x) | ||
return x | ||
|
||
def calibrate_loop(model): | ||
"""Simple calibration function for testing.""" | ||
model(input_tensor) | ||
|
||
input_tensor = torch.randn(1, 10).cuda() | ||
model = SimpleNetwork().eval().cuda() | ||
|
||
quant_cfg = mtq.FP8_DEFAULT_CFG | ||
mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) | ||
# model has FP8 qdq nodes at this point | ||
output_pyt = model(input_tensor) | ||
|
||
with torch.no_grad(): | ||
with export_torch_mode(): | ||
exp_program = torch.export.export(model, (input_tensor,), strict=False) | ||
trt_model = torchtrt.dynamo.compile( | ||
exp_program, | ||
inputs=[input_tensor], | ||
enabled_precisions={torch.float8_e4m3fn}, | ||
min_block_size=1, | ||
cache_built_engines=False, | ||
reuse_cached_engines=False, | ||
) | ||
outputs_trt = trt_model(input_tensor) | ||
assert torch.allclose(output_pyt, outputs_trt, rtol=5e-3, atol=1e-2) | ||
|
||
|
||
@unittest.skipIf( | ||
platform.system() != "Linux" | ||
or not importlib.util.find_spec("modelopt") | ||
or Version(metadata.version("nvidia-modelopt")) < Version("0.27.0"), | ||
"modelopt 0.17.0 or later is required, Int8 quantization is supported in modelopt since 0.17.0 or later for linux", | ||
) | ||
@pytest.mark.unit | ||
def test_base_int8(): | ||
import modelopt.torch.quantization as mtq | ||
from modelopt.torch.quantization.utils import export_torch_mode | ||
|
||
class SimpleNetwork(torch.nn.Module): | ||
def __init__(self): | ||
super(SimpleNetwork, self).__init__() | ||
self.linear1 = torch.nn.Linear(in_features=10, out_features=5) | ||
self.linear2 = torch.nn.Linear(in_features=5, out_features=1) | ||
|
||
def forward(self, x): | ||
x = self.linear1(x) | ||
x = torch.nn.ReLU()(x) | ||
x = self.linear2(x) | ||
return x | ||
|
||
def calibrate_loop(model): | ||
"""Simple calibration function for testing.""" | ||
model(input_tensor) | ||
|
||
input_tensor = torch.randn(1, 10).cuda() | ||
model = SimpleNetwork().eval().cuda() | ||
|
||
quant_cfg = mtq.INT8_DEFAULT_CFG | ||
mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) | ||
# model has INT8 qdq nodes at this point | ||
output_pyt = model(input_tensor) | ||
|
||
with torchtrt.logging.debug(), torch.no_grad(): | ||
with export_torch_mode(): | ||
exp_program = torch.export.export(model, (input_tensor,), strict=False) | ||
trt_model = torchtrt.dynamo.compile( | ||
exp_program, | ||
inputs=[input_tensor], | ||
enabled_precisions={torch.int8}, | ||
min_block_size=1, | ||
cache_built_engines=False, | ||
reuse_cached_engines=False, | ||
truncate_double=True, | ||
debug=True, | ||
) | ||
outputs_trt = trt_model(input_tensor) | ||
assert torch.allclose(output_pyt, outputs_trt, rtol=5e-3, atol=1e-2) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do these need to be separated?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Only issue is that people will forget to add their tests to the list here