Skip to content

Commit f3f9002

Browse files
Daniel Ohayonfacebook-github-bot
Daniel Ohayon
authored andcommitted
handle components whose implementation lives in a different file
Summary: Add support for cases like: ```lang=python # some_file.py # ==================== def my_component(...) -> specs.AppDef: ... # other_file.py # ==================== from some_file import my_component ``` where the component is invoked with `torchx run ... other_file.py:my_component` This was currently failing with a validation error because in the step where we inspect the AST of the component, we assume that the file where the component is being looked up is the same as the file where it is implemented. Differential Revision: D75496839
1 parent 24dc0d5 commit f3f9002

File tree

1 file changed

+19
-3
lines changed

1 file changed

+19
-3
lines changed

torchx/specs/finder.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from types import ModuleType
2020
from typing import Any, Callable, Dict, Generator, List, Optional, Union
2121

22+
from torchx.specs.api import AppDef
2223
from torchx.specs.file_linter import get_fn_docstring, TorchxFunctionValidator, validate
2324
from torchx.util import entrypoints
2425
from torchx.util.io import read_conf_file
@@ -274,12 +275,21 @@ def _get_validation_errors(
274275
linter_errors = validate(path, function_name, validators)
275276
return [linter_error.description for linter_error in linter_errors]
276277

278+
def _get_path_to_function_decl(self, function: Callable[..., AppDef]) -> str:
279+
"""
280+
Attempts to return the path to the file where the function is implemented.
281+
This can be different from the path where the function is looked up, for example if we have:
282+
my_component defined in some_file.py, imported in other_file.py
283+
and the component is invoked as other_file.py:my_component
284+
"""
285+
path_to_function_decl = function.__code__.co_filename # pyre-ignore[16]
286+
if not os.path.isfile(path_to_function_decl):
287+
return self._filepath
288+
return path_to_function_decl
289+
277290
def find(
278291
self, validators: Optional[List[TorchxFunctionValidator]]
279292
) -> List[_Component]:
280-
validation_errors = self._get_validation_errors(
281-
self._filepath, self._function_name, validators
282-
)
283293

284294
file_source = read_conf_file(self._filepath)
285295
namespace = copy.copy(globals())
@@ -292,6 +302,12 @@ def find(
292302
)
293303
app_fn = namespace[self._function_name]
294304
fn_desc, _ = get_fn_docstring(app_fn)
305+
306+
func_path = self._get_path_to_function_decl(app_fn)
307+
validation_errors = self._get_validation_errors(
308+
func_path, self._function_name, validators
309+
)
310+
295311
return [
296312
_Component(
297313
name=f"{self._filepath}:{self._function_name}",

0 commit comments

Comments
 (0)