Skip to content

Commit ae3fd02

Browse files
authored
Merge pull request #2376 from oesteban/maint/engine-base
[MAINT] Cleanup EngineBase
2 parents 80d3f05 + 1d38127 commit ae3fd02

File tree

13 files changed

+627
-585
lines changed

13 files changed

+627
-585
lines changed

CHANGES

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
Upcoming release (0.14.1)
22
=========================
33

4+
* MAINT: Cleanup EngineBase (https://github.com/nipy/nipype/pull/2376)
45
* FIX: Robustly handled outputs of 3dFWHMx across different versions of AFNI (https://github.com/nipy/nipype/pull/2373)
56
* FIX: Cluster threshold in randomise + change default prefix (https://github.com/nipy/nipype/pull/2369)
6-
* MAINT: Cleaning / simplify ``Node`` (https://github.com/nipy/nipype/pull/#2325)
7+
* MAINT: Cleaning / simplify ``Node`` (https://github.com/nipy/nipype/pull/2325)
78
* STY: Cleanup of PEP8 violations (https://github.com/nipy/nipype/pull/2358)
89
* STY: Cleanup of trailing spaces and adding of missing newlines at end of files (https://github.com/nipy/nipype/pull/2355)
910

nipype/interfaces/tests/test_io.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,8 @@ def test_s3datagrabber():
147147
"node_output": ["model"]
148148
}),
149149
])
150-
def test_selectfiles(SF_args, inputs_att, expected):
150+
def test_selectfiles(tmpdir, SF_args, inputs_att, expected):
151+
tmpdir.chdir()
151152
base_dir = op.dirname(nipype.__file__)
152153
dg = nio.SelectFiles(base_directory=base_dir, **SF_args)
153154
for key, val in inputs_att.items():

nipype/pipeline/engine/base.py

Lines changed: 25 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,14 @@
1818
absolute_import)
1919
from builtins import object
2020

21-
from future import standard_library
22-
standard_library.install_aliases()
23-
2421
from copy import deepcopy
2522
import re
2623
import numpy as np
27-
from ... import logging
24+
25+
from ... import config
2826
from ...interfaces.base import DynamicTraitedSpec
2927
from ...utils.filemanip import loadpkl, savepkl
3028

31-
logger = logging.getLogger('workflow')
32-
3329

3430
class EngineBase(object):
3531
"""Defines common attributes and functions for workflows and nodes."""
@@ -47,35 +43,36 @@ def __init__(self, name=None, base_dir=None):
4743
default=None, which results in the use of mkdtemp
4844
4945
"""
46+
self._hierarchy = None
47+
self._name = None
48+
5049
self.base_dir = base_dir
51-
self.config = None
52-
self._verify_name(name)
50+
self.config = deepcopy(config._sections)
5351
self.name = name
54-
# for compatibility with node expansion using iterables
55-
self._id = self.name
56-
self._hierarchy = None
5752

5853
@property
59-
def inputs(self):
60-
raise NotImplementedError
54+
def name(self):
55+
return self._name
6156

62-
@property
63-
def outputs(self):
64-
raise NotImplementedError
57+
@name.setter
58+
def name(self, name):
59+
if not name or not re.match(r'^[\w-]+$', name):
60+
raise ValueError('[Workflow|Node] name "%s" is not valid.' % name)
61+
self._name = name
6562

6663
@property
6764
def fullname(self):
68-
fullname = self.name
6965
if self._hierarchy:
70-
fullname = self._hierarchy + '.' + self.name
71-
return fullname
66+
return '%s.%s' % (self._hierarchy, self.name)
67+
return self.name
7268

7369
@property
74-
def itername(self):
75-
itername = self._id
76-
if self._hierarchy:
77-
itername = self._hierarchy + '.' + self._id
78-
return itername
70+
def inputs(self):
71+
raise NotImplementedError
72+
73+
@property
74+
def outputs(self):
75+
raise NotImplementedError
7976

8077
def clone(self, name):
8178
"""Clone an EngineBase object
@@ -86,13 +83,10 @@ def clone(self, name):
8683
name : string (mandatory)
8784
A clone of node or workflow must have a new name
8885
"""
89-
if (name is None) or (name == self.name):
90-
raise Exception('Cloning requires a new name')
91-
self._verify_name(name)
86+
if name == self.name:
87+
raise ValueError('Cloning requires a new name, "%s" is in use.' % name)
9288
clone = deepcopy(self)
9389
clone.name = name
94-
clone._id = name
95-
clone._hierarchy = None
9690
return clone
9791

9892
def _check_outputs(self, parameter):
@@ -103,17 +97,8 @@ def _check_inputs(self, parameter):
10397
return True
10498
return hasattr(self.inputs, parameter)
10599

106-
def _verify_name(self, name):
107-
valid_name = bool(re.match('^[\w-]+$', name))
108-
if not valid_name:
109-
raise ValueError('[Workflow|Node] name \'%s\' contains'
110-
' special characters' % name)
111-
112-
def __repr__(self):
113-
if self._hierarchy:
114-
return '.'.join((self._hierarchy, self._id))
115-
else:
116-
return '{}'.format(self._id)
100+
def __str__(self):
101+
return self.fullname
117102

118103
def save(self, filename=None):
119104
if filename is None:

nipype/pipeline/engine/nodes.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,12 +161,13 @@ def __init__(self,
161161

162162
super(Node, self).__init__(name, kwargs.get('base_dir'))
163163

164-
self.name = name
165164
self._interface = interface
166165
self._hierarchy = None
167166
self._got_inputs = False
168167
self._originputs = None
169168
self._output_dir = None
169+
self._id = self.name # for compatibility with node expansion using iterables
170+
170171
self.iterables = iterables
171172
self.synchronize = synchronize
172173
self.itersource = itersource
@@ -228,7 +229,6 @@ def n_procs(self):
228229
if hasattr(self._interface.inputs, 'num_threads') and isdefined(
229230
self._interface.inputs.num_threads):
230231
return self._interface.inputs.num_threads
231-
232232
return 1
233233

234234
@n_procs.setter
@@ -240,6 +240,13 @@ def n_procs(self, value):
240240
if hasattr(self._interface.inputs, 'num_threads'):
241241
self._interface.inputs.num_threads = self._n_procs
242242

243+
@property
244+
def itername(self):
245+
itername = self._id
246+
if self._hierarchy:
247+
itername = self._hierarchy + '.' + self._id
248+
return itername
249+
243250
def output_dir(self):
244251
"""Return the location of the output directory for the node"""
245252
# Output dir is cached
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# -*- coding: utf-8 -*-
2+
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
3+
# vi: set ft=python sts=4 ts=4 sw=4 et:
4+
from __future__ import print_function, unicode_literals
5+
6+
import pytest
7+
from ..base import EngineBase
8+
from ....interfaces import base as nib
9+
10+
11+
class InputSpec(nib.TraitedSpec):
12+
input1 = nib.traits.Int(desc='a random int')
13+
input2 = nib.traits.Int(desc='a random int')
14+
input_file = nib.traits.File(desc='Random File')
15+
16+
17+
class OutputSpec(nib.TraitedSpec):
18+
output1 = nib.traits.List(nib.traits.Int, desc='outputs')
19+
20+
21+
class EngineTestInterface(nib.BaseInterface):
22+
input_spec = InputSpec
23+
output_spec = OutputSpec
24+
25+
def _run_interface(self, runtime):
26+
runtime.returncode = 0
27+
return runtime
28+
29+
def _list_outputs(self):
30+
outputs = self._outputs().get()
31+
outputs['output1'] = [1, self.inputs.input1]
32+
return outputs
33+
34+
35+
@pytest.mark.parametrize(
36+
'name', ['valid1', 'valid_node', 'valid-node', 'ValidNode0'])
37+
def test_create(name):
38+
base = EngineBase(name=name)
39+
assert base.name == name
40+
41+
42+
@pytest.mark.parametrize(
43+
'name', ['invalid*1', 'invalid.1', 'invalid@', 'in/valid', None])
44+
def test_create_invalid(name):
45+
with pytest.raises(ValueError):
46+
EngineBase(name=name)
47+
48+
49+
def test_hierarchy():
50+
base = EngineBase(name='nodename')
51+
base._hierarchy = 'some.history.behind'
52+
53+
assert base.name == 'nodename'
54+
assert base.fullname == 'some.history.behind.nodename'
55+
56+
57+
def test_clone():
58+
base = EngineBase(name='nodename')
59+
base2 = base.clone('newnodename')
60+
61+
assert (base.base_dir == base2.base_dir and
62+
base.config == base2.config and
63+
base2.name == 'newnodename')
64+
65+
with pytest.raises(ValueError):
66+
base.clone('nodename')

0 commit comments

Comments
 (0)