Skip to content

Commit 5de962f

Browse files
Daniel Ohayonfacebook-github-bot
Daniel Ohayon
authored andcommitted
handle components whose implementation lives in a different file (pytorch#1075)
Summary: Add support for cases like: ```lang=python # some_file.py # ==================== def my_component(...) -> specs.AppDef: ... # other_file.py # ==================== from some_file import my_component ``` where the component is invoked with `torchx run ... other_file.py:my_component` This was currently failing with a validation error because in the step where we inspect the AST of the component, we assume that the file where the component is being looked up is the same as the file where it is implemented. Reviewed By: kiukchung Differential Revision: D75496839
1 parent 5c43bc6 commit 5de962f

File tree

2 files changed

+23
-3
lines changed

2 files changed

+23
-3
lines changed

torchx/specs/finder.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -274,12 +274,21 @@ def _get_validation_errors(
274274
linter_errors = validate(path, function_name, validators)
275275
return [linter_error.description for linter_error in linter_errors]
276276

277+
def _get_path_to_function_decl(self, function: Callable[..., Any]) -> str: # pyre-ignore
278+
"""
279+
Attempts to return the path to the file where the function is implemented.
280+
This can be different from the path where the function is looked up, for example if we have:
281+
my_component defined in some_file.py, imported in other_file.py
282+
and the component is invoked as other_file.py:my_component
283+
"""
284+
path_to_function_decl = inspect.getabsfile(function)
285+
if path_to_function_decl is None or not os.path.isfile(path_to_function_decl):
286+
return self._filepath
287+
return path_to_function_decl
288+
277289
def find(
278290
self, validators: Optional[List[TorchxFunctionValidator]]
279291
) -> List[_Component]:
280-
validation_errors = self._get_validation_errors(
281-
self._filepath, self._function_name, validators
282-
)
283292

284293
file_source = read_conf_file(self._filepath)
285294
namespace = copy.copy(globals())
@@ -292,6 +301,12 @@ def find(
292301
)
293302
app_fn = namespace[self._function_name]
294303
fn_desc, _ = get_fn_docstring(app_fn)
304+
305+
func_path = self._get_path_to_function_decl(app_fn)
306+
validation_errors = self._get_validation_errors(
307+
func_path, self._function_name, validators
308+
)
309+
295310
return [
296311
_Component(
297312
name=f"{self._filepath}:{self._function_name}",

torchx/specs/test/finder_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
get_components,
3030
ModuleComponentsFinder,
3131
)
32+
from torchx.specs.test.components.a import comp_a
3233
from torchx.util.test.entrypoints_test import EntryPoint_from_text
3334
from torchx.util.types import none_throws
3435

@@ -238,6 +239,10 @@ def test_get_component_invalid(self) -> None:
238239
with self.assertRaises(ComponentValidationException):
239240
get_component(f"{current_file_path()}:invalid_component")
240241

242+
def test_get_component_imported_from_other_file(self) -> None:
243+
component = get_component(f"{current_file_path()}:comp_a")
244+
self.assertListEqual([], component.validation_errors)
245+
241246

242247
class GetBuiltinSourceTest(unittest.TestCase):
243248
def setUp(self) -> None:

0 commit comments

Comments
 (0)