view scripts/modules/utils.py @ 3:0cbed1c0a762 draft default tip

planemo upload commit 15239f1674081ab51ab8dd75a9a40cf1bfaa93e8
author cstrittmatter
date Tue, 28 Jan 2020 10:42:31 -0500
parents 965517909457
children
line wrap: on
line source

import pickle
import traceback
import shlex
import subprocess
from threading import Timer
import shutil
import time
import functools
import os.path
import sys
import argparse


def start_logger(workdir):
    """
    Replace sys.stdout with a Logger that also writes to a timestamped log file.

    Parameters
    ----------
    workdir : str
        Directory in which the 'run.<timestamp>.log' file is created.

    Returns
    -------
    tuple
        (path of the log file, timestamp string formatted as "%Y%m%d-%H%M%S")
    """
    stamp = time.strftime("%Y%m%d-%H%M%S")
    tee = Logger(workdir, stamp)
    sys.stdout = tee
    return tee.getLogFile(), stamp


class Logger(object):
    """
    File-like object that echoes every write to the original stdout and to a log file.

    Installed as ``sys.stdout`` by start_logger so that all printed output is also
    captured in ``<out_directory>/run.<time_str>.log``.

    Parameters
    ----------
    out_directory : str
        Directory where the log file is created.
    time_str : str
        Timestamp embedded in the log file name.
    """

    def __init__(self, out_directory, time_str):
        self.logfile = os.path.join(out_directory, str('run.' + time_str + '.log'))
        # Keep the stream being replaced so output still reaches the terminal.
        self.terminal = sys.stdout
        self.log = open(self.logfile, "w")

    def write(self, message):
        """Write ``message`` to the terminal and to the log file."""
        self.terminal.write(message)
        self.log.write(message)
        # Flush the log on every write so it is complete even if the run crashes.
        self.log.flush()

    def flush(self):
        """Flush both underlying streams.

        Forwarding to the terminal makes ``sys.stdout.flush()`` and
        ``print(..., flush=True)`` behave as they would on the real stdout.
        """
        self.terminal.flush()
        self.log.flush()

    def getLogFile(self):
        """Return the path of the log file."""
        return self.logfile


def _version_at_least(found_version, required_version):
    """
    Return True if the dotted version ``found_version`` is >= ``required_version``.

    Components are compared numerically from left to right and the first differing
    component decides. Missing trailing components of ``found_version`` count as 0.
    Anything after an underscore in a component is ignored (e.g. "1_8" -> 1); without
    this, int() would read "1_8" as 18.

    Raises
    ------
    ValueError
        If a compared component is not an integer.
    """
    found = found_version.split('.')
    required = required_version.split('.')
    # Pad so that e.g. "4" is compared like "4.0" against "4.0".
    found += ['0'] * (len(required) - len(found))
    for found_part, required_part in zip(found, required):
        found_number = int(found_part.split('_')[0])
        required_number = int(required_part.split('_')[0])
        if found_number != required_number:
            return found_number > required_number
    return True


def checkPrograms(programs_version_dictionary):
    """
    Check that each program is on PATH and, when requested, that its version is acceptable.

    Parameters
    ----------
    programs_version_dictionary : dict
        Maps program name -> [version_flag, comparison, required_version].
        - A version_flag of None skips the version check.
        - A comparison of '>=' requires at least required_version; any other value
          requires an exact string match.
        - For programs ending in '.jar', the path reported by `which` is appended to
          the program's list (the dictionary is modified in place).

    Returns
    -------
    list of str
        One message per missing program or unsatisfied version requirement
        (empty when everything is available).
    """
    print('\n' + 'Checking dependencies...')
    programs = programs_version_dictionary
    listMissings = []
    for program in programs:
        run_successfully, stdout, stderr = runCommandPopenCommunicate(['which', program], False, None, False)
        if not run_successfully:
            listMissings.append(program + ' not found in PATH.')
            continue

        program_path = stdout.splitlines()[0]
        print(program_path)
        if programs[program][0] is None:
            print(program + ' (impossible to determine programme version) found at: ' + program_path)
            continue

        if program.endswith('.jar'):
            check_version = ['java', '-jar', program_path, programs[program][0]]
            programs[program].append(program_path)
        else:
            check_version = [program_path, programs[program][0]]
        run_successfully, stdout, stderr = runCommandPopenCommunicate(check_version, False, None, False)
        # Some programs print their version on stderr.
        if stdout == '':
            stdout = stderr
        output_lines = stdout.splitlines()

        # The version token's position differs between programs.
        try:
            if program in ['wget', 'awk']:
                version_line = output_lines[0].split(' ', 3)[2]
            elif program in ['prefetch', 'fastq-dump']:
                version_line = output_lines[1].split(' ')[-1]
            else:
                version_line = output_lines[0].split(' ')[-1]
        except IndexError:
            listMissings.append('Impossible to determine the version of ' + program)
            continue

        for character in ['"', 'v', 'V', '+', ',']:
            version_line = version_line.replace(character, '')
        print(program + ' (' + version_line + ') found')

        required_message = ('It is required ' + program + ' with version ' +
                            programs[program][1] + ' ' + programs[program][2])
        if programs[program][1] == '>=':
            try:
                satisfied = _version_at_least(version_line, programs[program][2])
            except ValueError:
                # A non-numeric version cannot be shown to satisfy the requirement.
                satisfied = False
            if not satisfied:
                listMissings.append(required_message)
        elif version_line != programs[program][2]:
            listMissings.append(required_message)
    return listMissings


def requiredPrograms():
    """
    Verify the external programs patho_typing depends on.

    Currently only ``rematch.py`` (queried with ``--version``) at version >= 4.0
    is required. Exits through sys.exit with one line per problem if any check fails.
    """
    required = {'rematch.py': ['--version', '>=', '4.0']}
    problems = checkPrograms(required)
    if problems:
        sys.exit('\n' + 'Errors:' + '\n' + '\n'.join(problems))


def general_information(logfile, version, outdir, time_str):
    """
    Print the run header and check the required external programs.

    The header shows the start time, log file, command line, working directory and
    version (plus git information).

    Parameters
    ----------
    logfile : str
        Path of the log file, echoed in the header.
    version : str
        Version string passed on to script_version_git.
    outdir : str
        Not used by this function (kept for interface compatibility).
    time_str : str
        Not used by this function (kept for interface compatibility).

    Returns
    -------
    str
        Path to patho_typing.py, located in the parent of this module's directory.

    Exits through requiredPrograms (sys.exit) if a required program check fails.
    """
    def show_section(title, value):
        print('\n' + title + ':')
        print(value)

    print('\n' + '==========> patho_typing <==========')
    print('\n' + 'Program start: ' + time.ctime())

    show_section('LOGFILE', logfile)
    show_section('COMMAND', sys.executable + ' ' + ' '.join(sys.argv))

    launch_directory = os.path.abspath(os.getcwd())
    show_section('PRESENT DIRECTORY', launch_directory)

    module_parent = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
    main_script = os.path.join(module_parent, 'patho_typing.py')

    print('\n' + 'VERSION:')
    script_version_git(version, launch_directory, main_script)

    requiredPrograms()

    return main_script


def setPATHvariable(doNotUseProvidedSoftware, script_path):
    """
    Optionally prepend the bundled bowtie2/samtools/bcftools folders to PATH, then print PATH.

    Parameters
    ----------
    doNotUseProvidedSoftware : bool
        When True, PATH is left unchanged.
    script_path : str
        Path of the main script; the bundled tools are expected under
        ``<dirname(script_path)>/src``.
    """
    path_variable = os.environ.get('PATH', '')
    script_folder = os.path.dirname(script_path)

    if not doNotUseProvidedSoftware:
        provided = [
            os.path.join(script_folder, 'src', 'bowtie2-2.2.9'),
            os.path.join(script_folder, 'src', 'samtools-1.3.1', 'bin'),
            os.path.join(script_folder, 'src', 'bcftools-1.3.1', 'bin'),
        ]
        # Do not append an empty inherited PATH: a trailing separator yields an
        # empty entry, which means "current directory" when searching executables.
        entries = provided + ([path_variable] if path_variable else [])
        os.environ['PATH'] = os.pathsep.join(entries)

    print('\n' + 'PATH variable:')
    print(os.environ.get('PATH', ''))


def script_version_git(version, current_directory, script_path, no_git_info=False):
    """
    Print script version and get GitHub commit information

    Parameters
    ----------
    version : str
        Version of the script, e.g. "4.0"
    current_directory : str
        Directory to return to after querying git (the directory where the script started to run)
    script_path : str
        Path to the script running; git is queried from the parent of its directory
    no_git_info : bool, default False
        True if it is not necessary to retrieve the GitHub commit information

    Returns
    -------
    None
    """
    print('Version {}'.format(version))

    if no_git_info:
        return

    try:
        os.chdir(os.path.dirname(os.path.dirname(script_path)))
        command = ['git', 'log', '-1', '--date=local', '--pretty=format:"%h (%H) - Commit by %cn, %cd) : %s"']
        run_successfully, stdout, stderr = runCommandPopenCommunicate(command, False, 15, False)
        print(stdout)
        command = ['git', 'remote', 'show', 'origin']
        run_successfully, stdout, stderr = runCommandPopenCommunicate(command, False, 15, False)
        print(stdout)
    except Exception:
        # Best effort: any failure (missing git, bad directory, ...) only produces a
        # warning. Catching Exception rather than everything lets Ctrl-C and
        # sys.exit propagate.
        print('HARMLESS WARNING: git command possibly not found. The GitHub repository information will not be'
              ' obtained.')
    finally:
        os.chdir(current_directory)


def runTime(start_time):
    """
    Print the time elapsed since ``start_time`` as hours, minutes and seconds.

    Parameters
    ----------
    start_time : float
        Reference timestamp as returned by time.time().

    Returns
    -------
    float
        Elapsed seconds rounded to two decimals.
    """
    elapsed = time.time() - start_time
    hours, remainder = divmod(elapsed, 3600)
    minutes, seconds = divmod(remainder, 60)
    print('Runtime :{0}h:{1}m:{2}s'.format(hours, minutes, round(seconds, 2)))
    return round(elapsed, 2)


def timer(function, name):
    """
    Wrap ``function`` so each call is announced, timed and prefixed with its runtime.

    The wrapped function's result must be iterable; it is materialized into a list
    before timing stops. The wrapper returns ``[elapsed_seconds] + list(result)``,
    where the elapsed time is the value returned by runTime.

    Parameters
    ----------
    function : callable
        Function to wrap.
    name : str
        Label printed in the RUNNING/END messages.
    """
    @functools.wraps(function)
    def timed(*args, **kwargs):
        print('\n' + 'RUNNING {0}\n'.format(name))
        started_at = time.time()

        # Consume the result inside the timed section (covers generators too).
        produced = list(function(*args, **kwargs))

        elapsed = runTime(started_at)
        print('END {0}'.format(name))

        return [elapsed] + produced
    return timed


def removeDirectory(directory):
    """Recursively delete ``directory`` if it is an existing directory; otherwise do nothing."""
    if not os.path.isdir(directory):
        return
    shutil.rmtree(directory)


def saveVariableToPickle(variableToStore, pickleFile):
    """Serialize ``variableToStore`` into ``pickleFile`` (overwritten) with pickle's default protocol."""
    with open(pickleFile, 'wb') as handle:
        pickle.Pickler(handle).dump(variableToStore)


def extractVariableFromPickle(pickleFile):
    """
    Load and return the object stored in ``pickleFile``.

    NOTE: unpickling can execute arbitrary code; only use this on files written by
    this pipeline, never on untrusted input.
    """
    with open(pickleFile, 'rb') as handle:
        return pickle.Unpickler(handle).load()


def trace_unhandled_exceptions(func):
    """
    Decorator that prints the traceback of any exception raised by ``func`` instead of propagating it.

    On success the wrapped call returns ``func``'s return value; when an exception is
    caught, the wrapper prints "Exception in <name>" plus the traceback and returns None.
    """
    @functools.wraps(func)
    def wrapped_func(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except BaseException:
            # Deliberately catch everything (including SystemExit from sys.exit in helpers),
            # keeping the original catch-all behaviour; the traceback is still reported.
            print('Exception in ' + func.__name__)
            traceback.print_exc()
    return wrapped_func


def kill_subprocess_Popen(subprocess_Popen, command):
    """Report that ``command`` exceeded its time limit and kill its process."""
    message = 'Command run out of time: {}'.format(command)
    print(message)
    subprocess_Popen.kill()


def runCommandPopenCommunicate(command, shell_True, timeout_sec_None, print_comand_True):
    """
    Run a command, capture its output and report failures.

    Parameters
    ----------
    command : str or list of str
        A list is joined with spaces and then re-split with shlex, so quoting inside
        list elements is interpreted by shlex.
    shell_True : bool
        Run the (re-joined) command through the shell.
    timeout_sec_None : float or None
        Seconds before the process is killed; None waits indefinitely.
    print_comand_True : bool
        Print the command before running it.

    Returns
    -------
    tuple
        (run_successfully, stdout, stderr). run_successfully is True when the return
        code is 0; the outputs are decoded as UTF-8 (undecodable bytes replaced).
    """
    run_successfully = False
    if not isinstance(command, str):
        command = ' '.join(command)
    command = shlex.split(command)

    if print_comand_True:
        print('Running: ' + ' '.join(command))

    if shell_True:
        command = ' '.join(command)
        proc = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
    else:
        proc = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

    not_killed_by_timer = True
    try:
        # communicate(timeout=None) waits indefinitely, matching the no-timeout case.
        stdout, stderr = proc.communicate(timeout=timeout_sec_None)
    except subprocess.TimeoutExpired:
        # The child is not killed automatically on timeout: kill it, then collect what it wrote.
        kill_subprocess_Popen(proc, command)
        stdout, stderr = proc.communicate()
        not_killed_by_timer = False

    stdout = stdout.decode("utf-8", errors="replace")
    stderr = stderr.decode("utf-8", errors="replace")

    if proc.returncode == 0:
        run_successfully = True
    else:
        # kill_subprocess_Popen already printed the command when the timer fired.
        if not print_comand_True and not_killed_by_timer:
            print('Running: ' + str(command))
        if len(stdout) > 0:
            print('STDOUT')
            print(stdout)
        if len(stderr) > 0:
            print('STDERR')
            print(stderr)
    return run_successfully, stdout, stderr


def required_length(tuple_length_options, argument_name):
    """
    Build an argparse Action that accepts only certain numbers of values.

    Parameters
    ----------
    tuple_length_options : tuple of int
        Allowed numbers of values for the option.
    argument_name : str
        Option name used in the error message.

    Returns
    -------
    type
        An argparse.Action subclass, to be passed as ``action=`` (typically with nargs='+').
        Invalid counts raise argparse.ArgumentError, which the parser reports as a usage error.
    """
    class RequiredLength(argparse.Action):
        def __call__(self, parser, args, values, option_string=None):
            if len(values) not in tuple_length_options:
                msg = ('Option {argument_name} requires one of the following number of'
                       ' arguments: {tuple_length_options}').format(argument_name=argument_name,
                                                                    tuple_length_options=tuple_length_options)
                # ArgumentError (unlike ArgumentTypeError) is turned into a parser error by argparse.
                raise argparse.ArgumentError(self, msg)
            setattr(args, self.dest, values)
    return RequiredLength


def get_sequence_information(fasta_file, length_extra_seq):
    """
    Parse a fasta file into numbered sequence records.

    Parameters
    ----------
    fasta_file : str
        Path to the fasta file.
    length_extra_seq : int
        Length of the extra sequence at each end. Records whose length is not greater
        than ``2 * length_extra_seq`` are skipped, with a printed message.

    Returns
    -------
    tuple
        - sequence_dict: ``{counter: {'header': str, 'sequence': str, 'length': int}}``.
          Headers are lowercased without the leading '>' and sequences are uppercased.
          The counter increments for every header, including skipped ones.
        - headers: ``{header: counter}`` for the kept records.

    Exits (sys.exit) on duplicated headers, on content after a blank line, or on
    sequence data before the first header.
    """
    sequence_dict = {}
    headers = {}

    def store_record(counter, record, chunks):
        # Keep the record only if something remains besides the extra flanking sequence.
        record['sequence'] = ''.join(chunks)
        if record['length'] - 2 * length_extra_seq > 0:
            sequence_dict[counter] = record
            headers[record['header'].lower()] = counter
        else:
            print(record['header'] + ' sequence ignored due to length <= 0')

    # Text mode already applies universal-newline translation; the old 'U' mode flag
    # was removed in Python 3.11.
    with open(fasta_file, 'rt') as reader:
        blank_line_found = False
        sequence_counter = 0
        current = None   # record being built for the most recent header
        chunks = []      # its sequence lines, joined once when the record is stored
        for line in reader:
            line = line.splitlines()[0]
            if len(line) == 0:
                blank_line_found = True
                continue
            if blank_line_found:
                sys.exit('It was found a blank line between the fasta file above line ' + line)

            if line.startswith('>'):
                if current is not None:
                    store_record(sequence_counter, current, chunks)
                if line[1:].lower() in headers:
                    sys.exit('Found duplicated sequence headers')
                sequence_counter += 1
                current = {'header': line[1:].lower(), 'sequence': '', 'length': 0}
                chunks = []
            else:
                if current is None:
                    sys.exit('Found sequence data before the first fasta header: ' + line)
                chunks.append(line.upper())
                current['length'] += len(line)

        if current is not None:
            store_record(sequence_counter, current, chunks)

    return sequence_dict, headers


def simplify_sequence_dict(sequence_dict):
    """
    Re-key sequence records by header.

    Parameters
    ----------
    sequence_dict : dict
        ``{counter: {'header': str, ...}}`` as produced by get_sequence_information.

    Returns
    -------
    dict
        ``{header: record_without_header}``. Each value is a shallow copy, so the input
        records are left unmodified (they keep their 'header' key).
    """
    simple_sequence_dict = {}
    for info in sequence_dict.values():
        record = dict(info)
        header = record.pop('header')
        simple_sequence_dict[header] = record
    return simple_sequence_dict


def chunkstring(string, length):
    """
    Return a generator over consecutive ``length``-sized pieces of ``string``.

    The last piece may be shorter. A ``length`` of 0 raises ValueError immediately.
    """
    starts = range(0, len(string), length)
    return (string[start:start + length] for start in starts)


def clean_headers_sequences(sequence_dict):
    """
    Replace characters that are problematic in sequence headers with '_'.

    Parameters
    ----------
    sequence_dict : dict
        ``{counter: {'header': str, ...}}``; headers are modified in place.

    Returns
    -------
    tuple
        (sequence_dict, new_headers). sequence_dict is the same object that was passed
        in; new_headers maps each lowercased cleaned header to its counter.

    Exits (sys.exit) if two headers become identical after cleaning, since one of
    them would otherwise silently disappear from ``new_headers``.
    """
    problematic_characters = ["|", " ", ",", ".", "(", ")", "'", "/", ":"]
    translation = str.maketrans({character: '_' for character in problematic_characters})

    headers_changed = False
    new_headers = {}
    for counter, info in sequence_dict.items():
        cleaned = info['header'].translate(translation)
        if cleaned != info['header']:
            info['header'] = cleaned
            headers_changed = True
        key = cleaned.lower()
        if key in new_headers:
            sys.exit('Found duplicated sequence headers after replacing problematic characters: ' + cleaned)
        new_headers[key] = counter

    if headers_changed:
        print('At least one of the characters ' + str(problematic_characters) +
              ' was found in the sequence headers. Replacing those with _' + '\n')

    return sequence_dict, new_headers