Skip to content

Commit ca4999e

Browse files
authored
Merge pull request #2104 from ashgillman/sshdatagrabber-grab-related
ENH: Update SSHDataGrabber to fetch related files
2 parents e679215 + 7b5d667 commit ca4999e

File tree

5 files changed

+155
-64
lines changed

5 files changed

+155
-64
lines changed

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ python:
1212
env:
1313
- INSTALL_DEB_DEPENDECIES=true NIPYPE_EXTRAS="doc,tests,fmri,profiler" CI_SKIP_TEST=1
1414
- INSTALL_DEB_DEPENDECIES=false NIPYPE_EXTRAS="doc,tests,fmri,profiler" CI_SKIP_TEST=1
15-
- INSTALL_DEB_DEPENDECIES=true NIPYPE_EXTRAS="doc,tests,fmri,profiler,duecredit" CI_SKIP_TEST=1
15+
- INSTALL_DEB_DEPENDECIES=true NIPYPE_EXTRAS="doc,tests,fmri,profiler,duecredit,ssh" CI_SKIP_TEST=1
1616
- INSTALL_DEB_DEPENDECIES=true NIPYPE_EXTRAS="doc,tests,fmri,profiler" PIP_FLAGS="--pre" CI_SKIP_TEST=1
1717

1818
addons:

docker/generate_dockerfiles.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ function generate_main_dockerfile() {
103103
--arg PYTHON_VERSION_MAJOR=3 PYTHON_VERSION_MINOR=6 BUILD_DATE VCS_REF VERSION \
104104
--miniconda env_name=neuro \
105105
conda_install='python=${PYTHON_VERSION_MAJOR}.${PYTHON_VERSION_MINOR}
106-
icu=58.1 libxml2 libxslt matplotlib mkl numpy
106+
icu=58.1 libxml2 libxslt matplotlib mkl numpy paramiko
107107
pandas psutil scikit-learn scipy traits=4.6.0' \
108108
pip_opts="-e" \
109109
pip_install="/src/nipype[all]" \

nipype/info.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,8 @@ def get_nipype_gitversion():
163163
'profiler': ['psutil>=5.0'],
164164
'duecredit': ['duecredit'],
165165
'xvfbwrapper': ['xvfbwrapper'],
166-
'pybids': ['pybids']
166+
'pybids': ['pybids'],
167+
'ssh': ['paramiko'],
167168
# 'mesh': ['mayavi'] # Enable when it works
168169
}
169170

nipype/interfaces/io.py

Lines changed: 74 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,17 @@
3131
import shutil
3232
import subprocess
3333
import re
34+
import copy
3435
import tempfile
3536
from os.path import join, dirname
3637
from warnings import warn
3738

3839
import sqlite3
3940

4041
from .. import config, logging
41-
from ..utils.filemanip import copyfile, list_to_filename, filename_to_list
42+
from ..utils.filemanip import (
43+
copyfile, list_to_filename, filename_to_list,
44+
get_related_files, related_filetype_sets)
4245
from ..utils.misc import human_order_sorted, str2bool
4346
from .base import (
4447
TraitedSpec, traits, Str, File, Directory, BaseInterface, InputMultiPath,
@@ -2412,6 +2415,65 @@ def __init__(self, infields=None, outfields=None, **kwargs):
24122415
and self.inputs.template[-1] != '$'):
24132416
self.inputs.template += '$'
24142417

2418+
def _get_files_over_ssh(self, template):
    """Get the files matching template over an SSH connection.

    Parameters
    ----------
    template : str
        Path pattern, interpreted relative to ``inputs.base_directory`` on
        the remote host. Its directory part selects the remote directory
        that is listed; its basename is matched against that listing
        according to ``inputs.template_expression`` (``'fnmatch'`` or
        ``'regexp'``).

    Returns
    -------
    ``None`` when nothing matched and ``inputs.raise_on_empty`` is false;
    otherwise whatever ``list_to_filename`` produces from the matched
    basenames (sorted with ``human_order_sorted`` when
    ``inputs.sort_filelist`` is set).

    Raises
    ------
    ValueError
        If ``inputs.template_expression`` is neither ``'fnmatch'`` nor
        ``'regexp'``.
    IOError
        If nothing matched and ``inputs.raise_on_empty`` is true.
    """
    # Connect over SSH
    client = self._get_ssh_client()
    sftp = client.open_sftp()
    # NOTE(review): only the SFTP session opened here is closed; whether
    # ``client`` itself may be closed depends on _get_ssh_client -- confirm.
    try:
        sftp.chdir(self.inputs.base_directory)

        # Get all files in the dir, and filter for desired files
        template_dir = os.path.dirname(template)
        template_base = os.path.basename(template)
        every_file_in_dir = sftp.listdir(template_dir)
        if self.inputs.template_expression == 'fnmatch':
            outfiles = fnmatch.filter(every_file_in_dir, template_base)
        elif self.inputs.template_expression == 'regexp':
            regexp = re.compile(template_base)
            outfiles = list(filter(regexp.match, every_file_in_dir))
        else:
            raise ValueError('template_expression value invalid')

        if not outfiles:
            # no files
            msg = 'Output template: %s returned no files' % template
            if self.inputs.raise_on_empty:
                raise IOError(msg)
            warn(msg)
            return None

        # found files, sort before reporting them
        if self.inputs.sort_filelist:
            outfiles = human_order_sorted(outfiles)

        # actually download the files, if desired
        if self.inputs.download_files:
            # Separate list: companions are downloaded but not reported.
            files_to_download = list(outfiles)

            # Check for related files present in the remote listing. The
            # list grows while it is iterated, so newly added companions
            # are examined as well; the membership test prevents repeats.
            for file_to_download in files_to_download:
                related_to_current = get_related_files(
                    file_to_download, include_this_file=False)
                files_to_download.extend(
                    f for f in related_to_current
                    if f in every_file_in_dir
                    and f not in files_to_download)

            for f in files_to_download:
                try:
                    # fetched into the current working directory
                    sftp.get(os.path.join(template_dir, f), f)
                except IOError:
                    iflogger.info('remote file %s not found', f)

        return list_to_filename(outfiles)
    finally:
        # Release the SFTP session on success and on every error path.
        sftp.close()
2476+
24152477
def _list_outputs(self):
24162478
try:
24172479
paramiko
@@ -2439,32 +2501,10 @@ def _list_outputs(self):
24392501
isdefined(self.inputs.field_template) and \
24402502
key in self.inputs.field_template:
24412503
template = self.inputs.field_template[key]
2504+
24422505
if not args:
2443-
client = self._get_ssh_client()
2444-
sftp = client.open_sftp()
2445-
sftp.chdir(self.inputs.base_directory)
2446-
filelist = sftp.listdir()
2447-
if self.inputs.template_expression == 'fnmatch':
2448-
filelist = fnmatch.filter(filelist, template)
2449-
elif self.inputs.template_expression == 'regexp':
2450-
regexp = re.compile(template)
2451-
filelist = list(filter(regexp.match, filelist))
2452-
else:
2453-
raise ValueError('template_expression value invalid')
2454-
if len(filelist) == 0:
2455-
msg = 'Output key: %s Template: %s returned no files' % (
2456-
key, template)
2457-
if self.inputs.raise_on_empty:
2458-
raise IOError(msg)
2459-
else:
2460-
warn(msg)
2461-
else:
2462-
if self.inputs.sort_filelist:
2463-
filelist = human_order_sorted(filelist)
2464-
outputs[key] = list_to_filename(filelist)
2465-
if self.inputs.download_files:
2466-
for f in filelist:
2467-
sftp.get(f, f)
2506+
outputs[key] = self._get_files_over_ssh(template)
2507+
24682508
for argnum, arglist in enumerate(args):
24692509
maxlen = 1
24702510
for arg in arglist:
@@ -2498,44 +2538,18 @@ def _list_outputs(self):
24982538
e.message +
24992539
": Template %s failed to convert with args %s"
25002540
% (template, str(tuple(argtuple))))
2501-
client = self._get_ssh_client()
2502-
sftp = client.open_sftp()
2503-
sftp.chdir(self.inputs.base_directory)
2504-
filledtemplate_dir = os.path.dirname(filledtemplate)
2505-
filledtemplate_base = os.path.basename(filledtemplate)
2506-
filelist = sftp.listdir(filledtemplate_dir)
2507-
if self.inputs.template_expression == 'fnmatch':
2508-
outfiles = fnmatch.filter(filelist,
2509-
filledtemplate_base)
2510-
elif self.inputs.template_expression == 'regexp':
2511-
regexp = re.compile(filledtemplate_base)
2512-
outfiles = list(filter(regexp.match, filelist))
2513-
else:
2514-
raise ValueError('template_expression value invalid')
2515-
if len(outfiles) == 0:
2516-
msg = 'Output key: %s Template: %s returned no files' % (
2517-
key, filledtemplate)
2518-
if self.inputs.raise_on_empty:
2519-
raise IOError(msg)
2520-
else:
2521-
warn(msg)
2522-
outputs[key].append(None)
2523-
else:
2524-
if self.inputs.sort_filelist:
2525-
outfiles = human_order_sorted(outfiles)
2526-
outputs[key].append(list_to_filename(outfiles))
2527-
if self.inputs.download_files:
2528-
for f in outfiles:
2529-
try:
2530-
sftp.get(
2531-
os.path.join(filledtemplate_dir, f), f)
2532-
except IOError:
2533-
iflogger.info('remote file %s not found',
2534-
f)
2541+
2542+
outputs[key].append(self._get_files_over_ssh(filledtemplate))
2543+
2544+
# discard every result for this key if any match was invalid
25352545
if any([val is None for val in outputs[key]]):
25362546
outputs[key] = []
2547+
2548+
# no outputs is None, not empty list
25372549
if len(outputs[key]) == 0:
25382550
outputs[key] = None
2551+
2552+
# one output is the item, not a list
25392553
elif len(outputs[key]) == 1:
25402554
outputs[key] = outputs[key][0]
25412555

nipype/interfaces/tests/test_io.py

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from builtins import str, zip, range, open
66
from future import standard_library
77
import os
8+
import copy
89
import simplejson
910
import glob
1011
import shutil
@@ -37,6 +38,32 @@
3738
except ImportError:
3839
noboto3 = True
3940

41+
# Check for paramiko, and for an SSH server on localhost.
# Sets two module-level flags used by skipif markers:
#   no_paramiko  -- True when the paramiko library cannot be imported
#   no_local_ssh -- True when paramiko is missing or connecting to
#                   127.0.0.1 as the current user fails
try:
    import paramiko
except ImportError:
    no_paramiko = True
    no_local_ssh = True
else:
    no_paramiko = False

    # Check for localhost SSH Server
    # FIXME: Tests requiring this are never run on CI
    proxy = None
    client = paramiko.SSHClient()
    try:
        client.load_system_host_keys()
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        client.connect('127.0.0.1', username=os.getenv('USER'), sock=proxy,
                       timeout=10)
    except (paramiko.SSHException,
            paramiko.ssh_exception.NoValidConnectionsError,
            OSError):
        no_local_ssh = True
    else:
        no_local_ssh = False
    finally:
        # The probe only answers "is a server reachable?"; do not keep the
        # connection open for the rest of the test session.
        client.close()
66+
4067
# Check for fakes3
4168
standard_library.install_aliases()
4269
from subprocess import check_call, CalledProcessError
@@ -316,7 +343,7 @@ def test_datasink_to_s3(dummy_input, tmpdir):
316343
aws_access_key_id='mykey',
317344
aws_secret_access_key='mysecret',
318345
service_name='s3',
319-
endpoint_url='http://localhost:4567',
346+
endpoint_url='http://127.0.0.1:4567',
320347
use_ssl=False)
321348
resource.meta.client.meta.events.unregister('before-sign.s3', fix_s3_host)
322349

@@ -611,3 +638,52 @@ def test_bids_infields_outfields(tmpdir):
611638
bg = nio.BIDSDataGrabber()
612639
for outfield in ['anat', 'func']:
613640
assert outfield in bg._outputs().traits()
641+
642+
643+
@pytest.mark.skipif(no_paramiko, reason="paramiko library is not available")
644+
@pytest.mark.skipif(no_local_ssh, reason="SSH Server is not running")
645+
def test_SSHDataGrabber(tmpdir):
646+
"""Test SSHDataGrabber by connecting to localhost and collecting some data.
647+
"""
648+
old_cwd = tmpdir.chdir()
649+
650+
source_dir = tmpdir.mkdir('source')
651+
source_hdr = source_dir.join('somedata.hdr')
652+
source_dat = source_dir.join('somedata.img')
653+
source_hdr.ensure() # create
654+
source_dat.ensure() # create
655+
656+
# ssh client that connects to localhost, current user, regardless of
657+
# ~/.ssh/config
658+
def _mock_get_ssh_client(self):
659+
proxy = None
660+
client = paramiko.SSHClient()
661+
client.load_system_host_keys()
662+
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
663+
client.connect('127.0.0.1', username=os.getenv('USER'), sock=proxy,
664+
timeout=10)
665+
return client
666+
MockSSHDataGrabber = copy.copy(nio.SSHDataGrabber)
667+
MockSSHDataGrabber._get_ssh_client = _mock_get_ssh_client
668+
669+
# grabber to get files from source_dir matching test.hdr
670+
ssh_grabber = MockSSHDataGrabber(infields=['test'],
671+
outfields=['test_file'])
672+
ssh_grabber.inputs.base_directory = str(source_dir)
673+
ssh_grabber.inputs.hostname = '127.0.0.1'
674+
ssh_grabber.inputs.field_template = dict(test_file='%s.hdr')
675+
ssh_grabber.inputs.template = ''
676+
ssh_grabber.inputs.template_args = dict(test_file=[['test']])
677+
ssh_grabber.inputs.test = 'somedata'
678+
ssh_grabber.inputs.sort_filelist = True
679+
680+
runtime = ssh_grabber.run()
681+
682+
# did we successfully get the header?
683+
assert runtime.outputs.test_file == str(tmpdir.join(source_hdr.basename))
684+
# did we successfully get the data?
685+
assert (tmpdir.join(source_hdr.basename) # header file
686+
.new(ext='.img') # data file
687+
.check(file=True, exists=True)) # exists?
688+
689+
old_cwd.chdir()

0 commit comments

Comments
 (0)