|
31 | 31 | import shutil
|
32 | 32 | import subprocess
|
33 | 33 | import re
|
| 34 | +import copy |
34 | 35 | import tempfile
|
35 | 36 | from os.path import join, dirname
|
36 | 37 | from warnings import warn
|
37 | 38 |
|
38 | 39 | import sqlite3
|
39 | 40 |
|
40 | 41 | from .. import config, logging
|
41 |
| -from ..utils.filemanip import copyfile, list_to_filename, filename_to_list |
| 42 | +from ..utils.filemanip import ( |
| 43 | + copyfile, list_to_filename, filename_to_list, |
| 44 | + get_related_files, related_filetype_sets) |
42 | 45 | from ..utils.misc import human_order_sorted, str2bool
|
43 | 46 | from .base import (
|
44 | 47 | TraitedSpec, traits, Str, File, Directory, BaseInterface, InputMultiPath,
|
@@ -2412,6 +2415,65 @@ def __init__(self, infields=None, outfields=None, **kwargs):
|
2412 | 2415 | and self.inputs.template[-1] != '$'):
|
2413 | 2416 | self.inputs.template += '$'
|
2414 | 2417 |
|
| 2418 | + def _get_files_over_ssh(self, template): |
| 2419 | + """Get the files matching template over an SSH connection.""" |
| 2420 | + # Connect over SSH |
| 2421 | + client = self._get_ssh_client() |
| 2422 | + sftp = client.open_sftp() |
| 2423 | + sftp.chdir(self.inputs.base_directory) |
| 2424 | + |
| 2425 | + # Get all files in the dir, and filter for desired files |
| 2426 | + template_dir = os.path.dirname(template) |
| 2427 | + template_base = os.path.basename(template) |
| 2428 | + every_file_in_dir = sftp.listdir(template_dir) |
| 2429 | + if self.inputs.template_expression == 'fnmatch': |
| 2430 | + outfiles = fnmatch.filter(every_file_in_dir, template_base) |
| 2431 | + elif self.inputs.template_expression == 'regexp': |
| 2432 | + regexp = re.compile(template_base) |
| 2433 | + outfiles = list(filter(regexp.match, every_file_in_dir)) |
| 2434 | + else: |
| 2435 | + raise ValueError('template_expression value invalid') |
| 2436 | + |
| 2437 | + if len(outfiles) == 0: |
| 2438 | + # no files |
| 2439 | + msg = 'Output template: %s returned no files' % template |
| 2440 | + if self.inputs.raise_on_empty: |
| 2441 | + raise IOError(msg) |
| 2442 | + else: |
| 2443 | + warn(msg) |
| 2444 | + |
| 2445 | + # return value |
| 2446 | + outfiles = None |
| 2447 | + |
| 2448 | + else: |
| 2449 | + # found files, sort and save to outputs |
| 2450 | + if self.inputs.sort_filelist: |
| 2451 | + outfiles = human_order_sorted(outfiles) |
| 2452 | + |
| 2453 | + # actually download the files, if desired |
| 2454 | + if self.inputs.download_files: |
| 2455 | + files_to_download = copy.copy(outfiles) # make sure new list! |
| 2456 | + |
| 2457 | + # check to see if there are any related files to download |
| 2458 | + for file_to_download in files_to_download: |
| 2459 | + related_to_current = get_related_files( |
| 2460 | + file_to_download, include_this_file=False) |
| 2461 | + existing_related_not_downloading = [ |
| 2462 | + f for f in related_to_current |
| 2463 | + if f in every_file_in_dir and f not in files_to_download] |
| 2464 | + files_to_download.extend(existing_related_not_downloading) |
| 2465 | + |
| 2466 | + for f in files_to_download: |
| 2467 | + try: |
| 2468 | + sftp.get(os.path.join(template_dir, f), f) |
| 2469 | + except IOError: |
| 2470 | + iflogger.info('remote file %s not found' % f) |
| 2471 | + |
| 2472 | + # return value |
| 2473 | + outfiles = list_to_filename(outfiles) |
| 2474 | + |
| 2475 | + return outfiles |
| 2476 | + |
2415 | 2477 | def _list_outputs(self):
|
2416 | 2478 | try:
|
2417 | 2479 | paramiko
|
@@ -2439,32 +2501,10 @@ def _list_outputs(self):
|
2439 | 2501 | isdefined(self.inputs.field_template) and \
|
2440 | 2502 | key in self.inputs.field_template:
|
2441 | 2503 | template = self.inputs.field_template[key]
|
| 2504 | + |
2442 | 2505 | if not args:
|
2443 |
| - client = self._get_ssh_client() |
2444 |
| - sftp = client.open_sftp() |
2445 |
| - sftp.chdir(self.inputs.base_directory) |
2446 |
| - filelist = sftp.listdir() |
2447 |
| - if self.inputs.template_expression == 'fnmatch': |
2448 |
| - filelist = fnmatch.filter(filelist, template) |
2449 |
| - elif self.inputs.template_expression == 'regexp': |
2450 |
| - regexp = re.compile(template) |
2451 |
| - filelist = list(filter(regexp.match, filelist)) |
2452 |
| - else: |
2453 |
| - raise ValueError('template_expression value invalid') |
2454 |
| - if len(filelist) == 0: |
2455 |
| - msg = 'Output key: %s Template: %s returned no files' % ( |
2456 |
| - key, template) |
2457 |
| - if self.inputs.raise_on_empty: |
2458 |
| - raise IOError(msg) |
2459 |
| - else: |
2460 |
| - warn(msg) |
2461 |
| - else: |
2462 |
| - if self.inputs.sort_filelist: |
2463 |
| - filelist = human_order_sorted(filelist) |
2464 |
| - outputs[key] = list_to_filename(filelist) |
2465 |
| - if self.inputs.download_files: |
2466 |
| - for f in filelist: |
2467 |
| - sftp.get(f, f) |
| 2506 | + outputs[key] = self._get_files_over_ssh(template) |
| 2507 | + |
2468 | 2508 | for argnum, arglist in enumerate(args):
|
2469 | 2509 | maxlen = 1
|
2470 | 2510 | for arg in arglist:
|
@@ -2498,44 +2538,18 @@ def _list_outputs(self):
|
2498 | 2538 | e.message +
|
2499 | 2539 | ": Template %s failed to convert with args %s"
|
2500 | 2540 | % (template, str(tuple(argtuple))))
|
2501 |
| - client = self._get_ssh_client() |
2502 |
| - sftp = client.open_sftp() |
2503 |
| - sftp.chdir(self.inputs.base_directory) |
2504 |
| - filledtemplate_dir = os.path.dirname(filledtemplate) |
2505 |
| - filledtemplate_base = os.path.basename(filledtemplate) |
2506 |
| - filelist = sftp.listdir(filledtemplate_dir) |
2507 |
| - if self.inputs.template_expression == 'fnmatch': |
2508 |
| - outfiles = fnmatch.filter(filelist, |
2509 |
| - filledtemplate_base) |
2510 |
| - elif self.inputs.template_expression == 'regexp': |
2511 |
| - regexp = re.compile(filledtemplate_base) |
2512 |
| - outfiles = list(filter(regexp.match, filelist)) |
2513 |
| - else: |
2514 |
| - raise ValueError('template_expression value invalid') |
2515 |
| - if len(outfiles) == 0: |
2516 |
| - msg = 'Output key: %s Template: %s returned no files' % ( |
2517 |
| - key, filledtemplate) |
2518 |
| - if self.inputs.raise_on_empty: |
2519 |
| - raise IOError(msg) |
2520 |
| - else: |
2521 |
| - warn(msg) |
2522 |
| - outputs[key].append(None) |
2523 |
| - else: |
2524 |
| - if self.inputs.sort_filelist: |
2525 |
| - outfiles = human_order_sorted(outfiles) |
2526 |
| - outputs[key].append(list_to_filename(outfiles)) |
2527 |
| - if self.inputs.download_files: |
2528 |
| - for f in outfiles: |
2529 |
| - try: |
2530 |
| - sftp.get( |
2531 |
| - os.path.join(filledtemplate_dir, f), f) |
2532 |
| - except IOError: |
2533 |
| - iflogger.info('remote file %s not found', |
2534 |
| - f) |
| 2541 | + |
| 2542 | + outputs[key].append(self._get_files_over_ssh(filledtemplate)) |
| 2543 | + |
| 2544 | + # disclude where there was any invalid matches |
2535 | 2545 | if any([val is None for val in outputs[key]]):
|
2536 | 2546 | outputs[key] = []
|
| 2547 | + |
| 2548 | + # no outputs is None, not empty list |
2537 | 2549 | if len(outputs[key]) == 0:
|
2538 | 2550 | outputs[key] = None
|
| 2551 | + |
| 2552 | + # one output is the item, not a list |
2539 | 2553 | elif len(outputs[key]) == 1:
|
2540 | 2554 | outputs[key] = outputs[key][0]
|
2541 | 2555 |
|
|
0 commit comments