Mechanisms to register components of the code base that have debugging functionality #3549
…debugging features that will be active in a context manager (compile time only for now)
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/debugging_example.py 2025-06-03 04:51:19.066284+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/debugging_example.py 2025-06-03 04:51:41.306273+00:00
@@ -27,14 +27,16 @@
# %%
# Compilation with `torch_tensorrt.compile`
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# Compile within the debugging context to control different aspects of the compilation pipeline
-with torch_tensorrt.dynamo._Debugger.Debugger(engine_builder_monitor=True, break_in_remove_assert_nodes=True):
+with torch_tensorrt.dynamo._Debugger.Debugger(
+ engine_builder_monitor=True, break_in_remove_assert_nodes=True
+):
optimized_model = torch_tensorrt.compile(
model,
- #ir="torch_compile",
+ # ir="torch_compile",
inputs=inputs,
enabled_precisions=enabled_precisions,
workspace_size=workspace_size,
min_block_size=min_block_size,
)
@@ -50,11 +52,11 @@
print(new_outputs)
optimized_model = torch_tensorrt.compile(
model,
- #ir="torch_compile",
+ # ir="torch_compile",
inputs=inputs,
enabled_precisions=enabled_precisions,
workspace_size=workspace_size,
min_block_size=min_block_size,
)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_DebuggerConfig.py 2025-06-03 04:51:19.076284+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_DebuggerConfig.py 2025-06-03 04:51:41.757086+00:00
@@ -1,13 +1,14 @@
from typing import Any, List, Optional, Dict
import logging
from dataclasses import dataclass, field
+
@dataclass
class DebuggerConfig:
- log_level: int = logging.getLevelName('DEBUG')
+ log_level: int = logging.getLevelName("DEBUG")
capture_fx_graph_before: List[str] = field(default_factory=lambda: [])
capture_fx_graph_after: List[str] = field(default_factory=lambda: [])
save_engine_profile: bool = False
engine_builder_monitor: bool = True
break_in_remove_assert_nodes: bool = True
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_Debugger.py 2025-06-03 04:51:19.076284+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_Debugger.py 2025-06-03 04:51:41.803440+00:00
@@ -11,11 +11,14 @@
from logging.config import dictConfig
import torch
import torch_tensorrt
from torch_tensorrt.dynamo._DebuggerConfig import DebuggerConfig
-from torch_tensorrt.dynamo._supports_debugger import _DEBUG_ENABLED_CLS, _DEBUG_ENABLED_FUNCS
+from torch_tensorrt.dynamo._supports_debugger import (
+ _DEBUG_ENABLED_CLS,
+ _DEBUG_ENABLED_FUNCS,
+)
from torch_tensorrt.dynamo.conversion._TRTInterpreter import TRTInterpreter
from torch_tensorrt.dynamo.lowering.passes.constant_folding import constant_fold
from torch_tensorrt.dynamo.lowering import (
ATEN_POST_LOWERING_PASSES,
ATEN_PRE_LOWERING_PASSES,
@@ -33,46 +36,52 @@
self.cfg = DebuggerConfig(**kwargs)
def __enter__(self) -> None:
self.original_lvl = _LOGGER.getEffectiveLevel()
self.rt_level = torch.ops.tensorrt.get_logging_level()
- #dictConfig(self.get_config())
+ # dictConfig(self.get_config())
self.old_pre_passes, self.old_post_passes = (
ATEN_PRE_LOWERING_PASSES.passes,
ATEN_POST_LOWERING_PASSES.passes,
)
pre_pass_names = [p.__name__ for p in self.old_pre_passes]
post_pass_names = [p.__name__ for p in self.old_post_passes]
- #path = os.path.join(DEBUG_FILE_DIR, "lowering_passes_visualization")
+ # path = os.path.join(DEBUG_FILE_DIR, "lowering_passes_visualization")
- ATEN_PRE_LOWERING_PASSES.insert_debug_pass_before(self.cfg.capture_fx_graph_before)
- ATEN_POST_LOWERING_PASSES.insert_debug_pass_before(self.cfg.capture_fx_graph_before)
+ ATEN_PRE_LOWERING_PASSES.insert_debug_pass_before(
+ self.cfg.capture_fx_graph_before
+ )
+ ATEN_POST_LOWERING_PASSES.insert_debug_pass_before(
+ self.cfg.capture_fx_graph_before
+ )
- ATEN_PRE_LOWERING_PASSES.insert_debug_pass_after(self.cfg.capture_fx_graph_after)
- ATEN_POST_LOWERING_PASSES.insert_debug_pass_after(self.cfg.capture_fx_graph_after)
+ ATEN_PRE_LOWERING_PASSES.insert_debug_pass_after(
+ self.cfg.capture_fx_graph_after
+ )
+ ATEN_POST_LOWERING_PASSES.insert_debug_pass_after(
+ self.cfg.capture_fx_graph_after
+ )
self._context_stack = contextlib.ExitStack()
for f in _DEBUG_ENABLED_FUNCS:
f.__kwdefaults__["_debugger_settings"] = self.cfg
[
self._context_stack.enter_context(
mock.patch.object(
c,
- '__init__',
- functools.partialmethod(
- c.__init__,
- _debugger_settings=self.cfg
- )
+ "__init__",
+ functools.partialmethod(c.__init__, _debugger_settings=self.cfg),
)
- ) for c in _DEBUG_ENABLED_CLS
+ )
+ for c in _DEBUG_ENABLED_CLS
]
def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
- #dictConfig(self.get_default_config())
+ # dictConfig(self.get_default_config())
torch.ops.tensorrt.set_logging_level(self.rt_level)
ATEN_PRE_LOWERING_PASSES.passes, ATEN_POST_LOWERING_PASSES.passes = (
self.old_pre_passes,
self.old_post_passes,
@@ -80,12 +89,10 @@
for f in _DEBUG_ENABLED_FUNCS:
f.__kwdefaults__["_debugger_settings"] = None
self._context_stack.close()
-
-
# def get_config(self) -> dict[str, Any]:
# config = {
# "version": 1,
# "disable_existing_loggers": False,
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_supports_debugger.py 2025-06-03 04:51:19.077284+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_supports_debugger.py 2025-06-03 04:51:41.917136+00:00
@@ -4,12 +4,14 @@
_DEBUG_ENABLED_FUNCS = []
_DEBUG_ENABLED_CLS = []
+
def fn_supports_debugger(func):
_DEBUG_ENABLED_FUNCS.append(func)
return func
-def cls_supports_debugger(cls: Type[T]) -> Type[T]:
+
+def cls_supports_debugger(cls: Type[T]) -> Type[T]:
_DEBUG_ENABLED_CLS.append(cls)
return cls
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTBuilderMonitor.py 2025-06-03 04:51:19.077284+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTBuilderMonitor.py 2025-06-03 04:51:42.103174+00:00
@@ -90,10 +90,11 @@
clear_line()
print()
move_cursor_up(len(self._active_phases) + blank_lines)
sys.stdout.flush()
+
# try:
# from rich.progress import BarColumn, Progress, TaskID, TextColumn, TimeElapsedColumn
# class _RichMonitor(trt.IProgressMonitor): # type: ignore
# def __init__(self, engine_name: str = "") -> None:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py 2025-06-03 04:51:19.078285+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py 2025-06-03 04:51:42.829249+00:00
@@ -69,10 +69,11 @@
input_names: Sequence[str]
output_names: Sequence[str]
weight_name_map: Optional[dict[Any, Any]]
requires_output_allocator: bool
+
@cls_supports_debugger
class TRTInterpreter(torch.fx.Interpreter): # type: ignore[misc]
def __init__(
self,
module: torch.fx.GraphModule,
@@ -207,11 +208,14 @@
algorithm_selector: Optional[trt.IAlgorithmSelector] = None,
tactic_sources: Optional[int] = None,
) -> trt.IBuilderConfig:
builder_config = self.builder.create_builder_config()
- if self._debugger_settings is not None and self._debugger_settings.engine_builder_monitor:
+ if (
+ self._debugger_settings is not None
+ and self._debugger_settings.engine_builder_monitor
+ ):
builder_config.progress_monitor = TRTBulderMonitor()
if self.compilation_settings.workspace_size != 0:
builder_config.set_memory_pool_limit(
trt.MemoryPoolType.WORKSPACE, self.compilation_settings.workspace_size
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/remove_assert_nodes.py 2025-06-03 04:51:19.081285+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/remove_assert_nodes.py 2025-06-03 04:51:43.303996+00:00
@@ -9,16 +9,23 @@
clean_up_graph_after_modifications,
)
logger = logging.getLogger(__name__)
+
@fn_supports_debugger
def remove_assert_nodes(
- gm: torch.fx.GraphModule, settings: CompilationSettings, *, _debugger_settings: Optional[DebuggerConfig]=None
+ gm: torch.fx.GraphModule,
+ settings: CompilationSettings,
+ *,
+ _debugger_settings: Optional[DebuggerConfig] = None,
) -> torch.fx.GraphModule:
"""Remove assert_scalar ops in the graph"""
- if _debugger_settings is not None and _debugger_settings.break_in_remove_assert_nodes:
+ if (
+ _debugger_settings is not None
+ and _debugger_settings.break_in_remove_assert_nodes
+ ):
breakpoint()
count = 0
for node in gm.graph.nodes:
if (
node.target == torch.ops.aten._assert_scalar.default
Description
This PR demonstrates how a context manager can automatically patch the compiler to toggle debugging settings scattered through the codebase. It operates through three main contact points for developers: an optional kwarg, a settings struct, and a decorator that declares a component as part of the debugging system.
The config is then available to be stored as an attribute on the class or used inside the function.
Any specific settings should be added to DebuggerConfig, which becomes the user-facing API for the Debugger class; the developer can use these settings to manage debugging behavior. Users set the config values through the context manager:
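The function-side contact point can be sketched in a minimal, self-contained form. This is an illustrative reconstruction based on the diff above, not the PR's exact code; the real pieces live in torch_tensorrt.dynamo._supports_debugger and torch_tensorrt.dynamo._Debugger:

```python
from dataclasses import dataclass

# Stand-in for torch_tensorrt.dynamo._DebuggerConfig.DebuggerConfig;
# field names follow the diff, everything else is illustrative.
@dataclass
class DebuggerConfig:
    engine_builder_monitor: bool = True
    break_in_remove_assert_nodes: bool = True

_DEBUG_ENABLED_FUNCS = []

def fn_supports_debugger(func):
    # Registration decorator: record the function so Debugger can patch it.
    _DEBUG_ENABLED_FUNCS.append(func)
    return func

@fn_supports_debugger
def remove_assert_nodes(gm, *, _debugger_settings=None):
    # The pass sees the active config only while a Debugger context is open.
    return _debugger_settings

class Debugger:
    def __init__(self, **kwargs):
        self.cfg = DebuggerConfig(**kwargs)

    def __enter__(self):
        # Edit the keyword-only default of every registered function.
        for f in _DEBUG_ENABLED_FUNCS:
            f.__kwdefaults__["_debugger_settings"] = self.cfg

    def __exit__(self, exc_type, exc_value, exc_tb):
        # Restore the default so the config is inert outside the context.
        for f in _DEBUG_ENABLED_FUNCS:
            f.__kwdefaults__["_debugger_settings"] = None

with Debugger(engine_builder_monitor=True, break_in_remove_assert_nodes=True):
    inside = remove_assert_nodes("gm")   # config injected via __kwdefaults__
outside = remove_assert_nodes("gm")      # default restored to None
```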
Here engine_builder_monitor is a feature of TRTInterpreter, and break_in_remove_assert_nodes is a feature of the remove_assert_nodes pass. Both components register as part of the debugging system. When the context is entered, the default kwargs of registered functions are edited to contain the DebuggerConfig, and registered classes are mocked so the argument is filled in at the constructor. On exit, these modifications are cleaned up.
In this PR (which still needs cleanup), the example debugging_example.py triggers both the TRTInterpreter setting that turns on the builder monitor and the lowering-pass setting. Breakpoints are set to verify the behavior: correct execution means two breakpoints are hit on the first compilation, one in the lowering pass and one in the TRTBuilderMonitor, and none on the second compilation, after which the script runs to completion.
Fixes # (issue)