Skip to content

Commit 2a129b6

Browse files
Daniel Ohayonfacebook-github-bot
Daniel Ohayon
authored andcommitted
handle components whose implementation lives in a different file (#1075)
Summary: Add support for cases like: ```lang=python # some_file.py # ==================== def my_component(...) -> specs.AppDef: ... # other_file.py # ==================== from some_file import my_component ``` where the component is invoked with `torchx run ... other_file.py:my_component` This was currently failing with a validation error because in the step where we inspect the AST of the component, we assume that the file where the component is being looked up is the same as the file where it is implemented. Differential Revision: D75496839
1 parent 24dc0d5 commit 2a129b6

File tree

2 files changed

+27
-4
lines changed

2 files changed

+27
-4
lines changed

torchx/specs/finder.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from inspect import getmembers, isfunction
1818
from pathlib import Path
1919
from types import ModuleType
20-
from typing import Any, Callable, Dict, Generator, List, Optional, Union
20+
from typing import Any, Callable, Dict, Generator, List, Optional, TypeVar, Union
2121

2222
from torchx.specs.file_linter import get_fn_docstring, TorchxFunctionValidator, validate
2323
from torchx.util import entrypoints
@@ -36,6 +36,9 @@ class ComponentNotFoundException(Exception):
3636
pass
3737

3838

39+
T = TypeVar("T")
40+
41+
3942
@dataclass
4043
class _Component:
4144
"""
@@ -274,12 +277,21 @@ def _get_validation_errors(
274277
linter_errors = validate(path, function_name, validators)
275278
return [linter_error.description for linter_error in linter_errors]
276279

280+
def _get_path_to_function_decl(self, function: Callable[..., T]) -> str:
281+
"""
282+
Attempts to return the path to the file where the function is implemented.
283+
This can be different from the path where the function is looked up, for example if we have:
284+
my_component defined in some_file.py, imported in other_file.py
285+
and the component is invoked as other_file.py:my_component
286+
"""
287+
path_to_function_decl = inspect.getabsfile(function)
288+
if path_to_function_decl is None or not os.path.isfile(path_to_function_decl):
289+
return self._filepath
290+
return path_to_function_decl
291+
277292
def find(
278293
self, validators: Optional[List[TorchxFunctionValidator]]
279294
) -> List[_Component]:
280-
validation_errors = self._get_validation_errors(
281-
self._filepath, self._function_name, validators
282-
)
283295

284296
file_source = read_conf_file(self._filepath)
285297
namespace = copy.copy(globals())
@@ -292,6 +304,12 @@ def find(
292304
)
293305
app_fn = namespace[self._function_name]
294306
fn_desc, _ = get_fn_docstring(app_fn)
307+
308+
func_path = self._get_path_to_function_decl(app_fn)
309+
validation_errors = self._get_validation_errors(
310+
func_path, self._function_name, validators
311+
)
312+
295313
return [
296314
_Component(
297315
name=f"{self._filepath}:{self._function_name}",

torchx/specs/test/finder_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
get_components,
3030
ModuleComponentsFinder,
3131
)
32+
from torchx.specs.test.components.a import comp_a
3233
from torchx.util.test.entrypoints_test import EntryPoint_from_text
3334
from torchx.util.types import none_throws
3435

@@ -238,6 +239,10 @@ def test_get_component_invalid(self) -> None:
238239
with self.assertRaises(ComponentValidationException):
239240
get_component(f"{current_file_path()}:invalid_component")
240241

242+
def test_get_component_imported_from_other_file(self) -> None:
243+
component = get_component(f"{current_file_path()}:comp_a")
244+
self.assertListEqual([], component.validation_errors)
245+
241246

242247
class GetBuiltinSourceTest(unittest.TestCase):
243248
def setUp(self) -> None:

0 commit comments

Comments
 (0)