Skip to content

Mechanisms to register components of the code base that have debugging functionality #3549

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions examples/dynamo/debugging_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""
.. _debugging_torchtrt:

Using the Torch-TensorRT Debugger
==========================================================

This tutorial demonstrates how the use the Torch-TensorRT debugger to inspect different components of the
compiler.
"""

# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

import torch
import torch_tensorrt
import torchvision.models as models

# %%

model = models.resnet18(pretrained=True).half().eval().to("cuda")
inputs = [torch.randn((1, 3, 224, 224)).to("cuda").half()]
enabled_precisions = {torch.half}
workspace_size = 20 << 30
min_block_size = 7

# %%
# Compilation with `torch_tensorrt.compile`
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Compile within the debugging context to control different aspects of the compilation pipeline
with torch_tensorrt.dynamo._Debugger.Debugger(engine_builder_monitor=True, break_in_remove_assert_nodes=True):
optimized_model = torch_tensorrt.compile(
model,
#ir="torch_compile",
inputs=inputs,
enabled_precisions=enabled_precisions,
workspace_size=workspace_size,
min_block_size=min_block_size,
)

# Does not cause recompilation (same batch size as input)
new_inputs = [torch.randn((1, 3, 224, 224)).half().to("cuda")]
new_outputs = optimized_model(*new_inputs)
print(new_outputs)

# Does not cause recompilation (same batch size as input)
new_inputs = [torch.randn((1, 3, 224, 224)).half().to("cuda")]
new_outputs = optimized_model(*new_inputs)
print(new_outputs)


optimized_model = torch_tensorrt.compile(
model,
#ir="torch_compile",
inputs=inputs,
enabled_precisions=enabled_precisions,
workspace_size=workspace_size,
min_block_size=min_block_size,
)

# Does not cause recompilation (same batch size as input)
new_inputs = [torch.randn((1, 3, 224, 224)).half().to("cuda")]
new_outputs = optimized_model(*new_inputs)
print(new_outputs)
156 changes: 156 additions & 0 deletions py/torch_tensorrt/dynamo/_Debugger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
from typing import Any, List, Optional, Dict
import copy
import functools
import contextlib

from unittest import mock
from dataclasses import dataclass
import logging
import os
import tempfile
from logging.config import dictConfig

import torch
import torch_tensorrt
from torch_tensorrt.dynamo._DebuggerConfig import DebuggerConfig
from torch_tensorrt.dynamo._supports_debugger import _DEBUG_ENABLED_CLS, _DEBUG_ENABLED_FUNCS
from torch_tensorrt.dynamo.conversion._TRTInterpreter import TRTInterpreter
from torch_tensorrt.dynamo.lowering.passes.constant_folding import constant_fold
from torch_tensorrt.dynamo.lowering import (
ATEN_POST_LOWERING_PASSES,
ATEN_PRE_LOWERING_PASSES,
)

# Logger for messages emitted while TensorRT conversion is in progress.
_LOGGER = logging.getLogger("torch_tensorrt [TensorRT Conversion Context]")
# Custom verbosity level below DEBUG (10); used for full FX-graph dumps.
GRAPH_LEVEL = 5
logging.addLevelName(GRAPH_LEVEL, "GRAPH")

# Debugger States


class Debugger:
    """Context manager that enables debug hooks across the Torch-TensorRT
    compilation pipeline.

    On entry it snapshots the current lowering-pass lists and logging levels,
    inserts the configured FX-graph capture passes, and injects the debugger
    settings into every registered debug-enabled function and class
    (see ``_supports_debugger``). On exit all of that state is restored.
    """

    def __init__(self, **kwargs: Any) -> None:
        # NOTE: ``**kwargs`` values are arbitrary (Any), not Dict — the old
        # ``Dict[str, Any]`` annotation described the mapping, not each value.
        # All options are forwarded to DebuggerConfig so that new debug
        # settings only need to be declared in one place.
        self.cfg = DebuggerConfig(**kwargs)

    def __enter__(self) -> "Debugger":
        # Snapshot logging state so __exit__ can restore it.
        self.original_lvl = _LOGGER.getEffectiveLevel()
        self.rt_level = torch.ops.tensorrt.get_logging_level()

        # Snapshot the lowering-pass lists before inserting debug passes.
        self.old_pre_passes, self.old_post_passes = (
            ATEN_PRE_LOWERING_PASSES.passes,
            ATEN_POST_LOWERING_PASSES.passes,
        )

        # Insert FX-graph capture passes before/after the passes named in the
        # debugger configuration.
        ATEN_PRE_LOWERING_PASSES.insert_debug_pass_before(
            self.cfg.capture_fx_graph_before
        )
        ATEN_POST_LOWERING_PASSES.insert_debug_pass_before(
            self.cfg.capture_fx_graph_before
        )

        ATEN_PRE_LOWERING_PASSES.insert_debug_pass_after(
            self.cfg.capture_fx_graph_after
        )
        ATEN_POST_LOWERING_PASSES.insert_debug_pass_after(
            self.cfg.capture_fx_graph_after
        )

        self._context_stack = contextlib.ExitStack()

        # Inject the settings as the default of the keyword-only
        # ``_debugger_settings`` parameter of every debug-enabled function.
        for f in _DEBUG_ENABLED_FUNCS:
            f.__kwdefaults__["_debugger_settings"] = self.cfg

        # Patch each debug-enabled class so instances constructed inside this
        # context automatically receive the debugger settings; the ExitStack
        # unwinds the patches on exit. (A for-loop is used instead of a
        # side-effect-only list comprehension.)
        for cls in _DEBUG_ENABLED_CLS:
            self._context_stack.enter_context(
                mock.patch.object(
                    cls,
                    "__init__",
                    functools.partialmethod(
                        cls.__init__, _debugger_settings=self.cfg
                    ),
                )
            )
        # Return self so ``with Debugger(...) as dbg:`` exposes the instance
        # (previously returned None).
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
        # Restore the TensorRT runtime logging level captured on entry.
        torch.ops.tensorrt.set_logging_level(self.rt_level)

        # Restore the original lowering passes, discarding the debug passes.
        ATEN_PRE_LOWERING_PASSES.passes, ATEN_POST_LOWERING_PASSES.passes = (
            self.old_pre_passes,
            self.old_post_passes,
        )

        # Remove the injected settings from debug-enabled functions.
        for f in _DEBUG_ENABLED_FUNCS:
            f.__kwdefaults__["_debugger_settings"] = None

        # Undo the class __init__ patches installed in __enter__.
        self._context_stack.close()



# def get_config(self) -> dict[str, Any]:
# config = {
# "version": 1,
# "disable_existing_loggers": False,
# "formatters": {
# "brief": {
# "format": "%(asctime)s - %(levelname)s - %(message)s",
# "datefmt": "%H:%M:%S",
# },
# "standard": {
# "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
# "datefmt": "%Y-%m-%d %H:%M:%S",
# },
# },
# "handlers": {
# "file": {
# "level": self.level,
# "class": "logging.FileHandler",
# "filename": f"{DEBUG_FILE_DIR}/torch_tensorrt_logging.log",
# "formatter": "standard",
# },
# "console": {
# "level": self.level,
# "class": "logging.StreamHandler",
# "formatter": "brief",
# },
# },
# "loggers": {
# "": { # root logger
# "handlers": ["file", "console"],
# "level": self.level,
# "propagate": True,
# },
# },
# "force": True,
# }
# return config

# def get_default_config(self) -> dict[str, Any]:
# config = {
# "version": 1,
# "disable_existing_loggers": False,
# "formatters": {
# "brief": {
# "format": "%(asctime)s - %(levelname)s - %(message)s",
# "datefmt": "%H:%M:%S",
# },
# "standard": {
# "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
# "datefmt": "%Y-%m-%d %H:%M:%S",
# },
# },
# "handlers": {
# "console": {
# "level": self.original_lvl,
# "class": "logging.StreamHandler",
# "formatter": "brief",
# },
# },
# "loggers": {
# "": { # root logger
# "handlers": ["console"],
# "level": self.original_lvl,
# "propagate": True,
# },
# },
# "force": True,
# }
# return config
13 changes: 13 additions & 0 deletions py/torch_tensorrt/dynamo/_DebuggerConfig.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from typing import Any, List, Optional, Dict

import logging
from dataclasses import dataclass, field

@dataclass
class DebuggerConfig:
    """Configuration options consumed by the Torch-TensorRT ``Debugger``
    context manager.
    """

    # Logging verbosity while the debugger is active. ``logging.DEBUG`` is
    # the canonical constant (value 10); the old reverse name lookup via
    # ``logging.getLevelName('DEBUG')`` is discouraged by the logging docs.
    log_level: int = logging.DEBUG
    # Names of lowering passes BEFORE which the FX graph is captured.
    # ``default_factory=list`` replaces the equivalent ``lambda: []``.
    capture_fx_graph_before: List[str] = field(default_factory=list)
    # Names of lowering passes AFTER which the FX graph is captured.
    capture_fx_graph_after: List[str] = field(default_factory=list)
    # Whether to save TensorRT engine profiling information.
    save_engine_profile: bool = False
    # Whether to show TensorRT engine-build progress.
    engine_builder_monitor: bool = True
    # Whether to break inside the remove-assert-nodes lowering pass
    # (presumably for interactive debugging — TODO confirm semantics).
    break_in_remove_assert_nodes: bool = True
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
load_cross_compiled_exported_program,
save_cross_compiled_exported_program,
)
from ._debugger import Debugger
from ._Debugger import Debugger
from ._exporter import export
from ._refit import refit_module_weights
from ._settings import CompilationSettings
Expand Down
34 changes: 17 additions & 17 deletions py/torch_tensorrt/dynamo/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -907,23 +907,23 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
)

trt_modules[name] = trt_module
from torch_tensorrt.dynamo._debugger import (
DEBUG_FILE_DIR,
SAVE_ENGINE_PROFILE,
)

if SAVE_ENGINE_PROFILE:
if settings.use_python_runtime:
logger.warning(
"Profiling can only be enabled when using the C++ runtime"
)
else:
path = os.path.join(DEBUG_FILE_DIR, "engine_visualization")
os.makedirs(path, exist_ok=True)
trt_module.enable_profiling(
profiling_results_dir=path,
profile_format="trex",
)
# from torch_tensorrt.dynamo._debugger import (
# DEBUG_FILE_DIR,
# SAVE_ENGINE_PROFILE,
# )

# if SAVE_ENGINE_PROFILE:
# if settings.use_python_runtime:
# logger.warning(
# "Profiling can only be enabled when using the C++ runtime"
# )
# else:
# path = os.path.join(DEBUG_FILE_DIR, "engine_visualization")
# os.makedirs(path, exist_ok=True)
# trt_module.enable_profiling(
# profiling_results_dir=path,
# profile_format="trex",
# )

# Parse the graph I/O and store it in dryrun tracker
parse_graph_io(gm, dryrun_tracker)
Expand Down
Loading
Loading