view GCMS/library_lookup.py @ 21:43902da5d00e

changed match_library location again
author linda.bakker@wur.nl <linda.bakker@wur.nl>
date Wed, 06 May 2015 08:06:53 +0200
parents f70b2c169e3a
children f0c6feab06e7
line wrap: on
line source

'''
Logic for searching a Retention Index database file given output from NIST
'''
import match_library
import re
import sys
import csv

__author__ = "Marcel Kempenaar"
__contact__ = "brs@nbic.nl"
__copyright__ = "Copyright, 2012, Netherlands Bioinformatics Centre"
__license__ = "MIT"

def create_lookup_table(library_file, column_type_name, statphase):
    '''
    Creates a dictionary holding the contents of the library to be searched,
    restricted to one column type and stationary phase and keyed on the
    numeric part of each entry's CAS value
    @param library_file: library to read
    @param column_type_name: the columns type name; rows whose 'columntype'
                             field differs are left out
    @param statphase: the columns stationary phase; rows whose
                      'columnphasetype' field differs are left out
    @return: dict mapping the CAS number (first run of digits, as str) to the
             list of matching library rows, in file order
    @raise IOError: if the library header lacks 'columntype',
                    'columnphasetype' or 'cas'
    '''
    (data, header) = match_library.read_library(library_file)
    # Test for presence of required columns
    required = ('columntype', 'columnphasetype', 'cas')
    missing = [name for name in required if name not in header]
    if missing:
        # Build one message string: two separate arguments to IOError are
        # interpreted as (errno, strerror) and give a garbled '[Errno ...]' text
        raise IOError('Missing columns (create_lookup_table) in %s: %s'
                      % (library_file, ', '.join(missing)))

    column_type_column = header.index('columntype')
    statphase_column = header.index('columnphasetype')
    cas_column = header.index('cas')

    lookup_dict = {}
    for element in data:
        if (element[column_type_column] != column_type_name or
                element[statphase_column] != statphase):
            continue
        # The key is the first run of digits in the CAS field, so a value of
        # 'C1433' is stored under '1433'.
        # NOTE(review): a dashed value such as '50-00-0' would be keyed as '50';
        # confirm the library never stores dashed CAS numbers.
        digits = re.findall(r'\d+', element[cas_column].strip())
        if not digits:
            # A CAS field without any digits can never match a lookup key;
            # skip the row instead of aborting with an IndexError
            continue
        lookup_dict.setdefault(str(digits[0]), []).append(element)
    return lookup_dict


def _preferred(hits, pref, ctype, polar, model, method):
    '''
    Returns the entry in hits that has the same column type, polarity and
    column name as given by the user. If there is none and regression models
    are given, an entry measured on another column of the same type is
    converted to a preferred column via regression. Regression candidates for
    a preferred column are tried in order of decreasing R-squared value (the
    last element of each model row).
    @param hits: all entries in the lookup_dict for the given CAS number
    @param pref: preferred GC-column, can be one or more names (tried in order)
    @param ctype: column type (capillary etc.)
    @param polar: polarity (polar / non-polar etc.)
    @param model: data loaded from file containing regression models, or a
                  false value to disable regression
    @param method: supported regression method (i.e. poly(nomial) or linear)
    @return: tuple (hit, regression_column). hit is a new list copied from the
             matched entry (with the RI at index 3 replaced if regression was
             applied) or False if nothing was found. regression_column is the
             column regressed from, or False if no regression was used.
    '''
    # Direct match on column type, polarity and one of the preferred columns
    for column in pref:
        for hit in hits:
            if hit[4] == ctype and hit[5] == polar and hit[6] == column:
                # Return a copy of the found hit since it is altered downstream
                return list(hit), False

    # No hit found for current CAS number, return if not performing regression
    if not model:
        return False, False

    # Names of the columns the available hits were measured on
    available = [hit[6] for hit in hits]

    # TODO: combine Rsquared and number of datapoints to get the best regression match
    for column in pref:
        if column not in model:
            # No model converts into this preferred column; try the next one
            continue
        # Order regression candidates by R-squared value (last element), best first.
        # The comprehension variable is deliberately not called 'column': under
        # Python 2 it would leak and overwrite the preferred column name.
        order = sorted(model[column].items(), key=lambda item: item[1][-1])
        regress_columns = [name for (name, _) in reversed(order)]

        for col in regress_columns:
            if col not in available:
                continue
            hit = list(hits[available.index(col)])
            if hit[4] != ctype:
                continue
            # Regress into the preferred column whose model produced this candidate
            if method == 'poly':
                # Polynomial is only possible within a set border; if the RI falls
                # outside of this border, skip this lookup. The helper signals that
                # with False, so compare by identity: 0.0 is a valid result.
                regressed = _apply_poly_regression(column, hit[6], float(hit[3]), model)
                if regressed is False:
                    return False, False
                hit[3] = regressed
            else:
                hit[3] = _apply_linear_regression(column, hit[6], float(hit[3]), model)
            return hit, hit[6]

    return False, False



def default_hit(row, cas_nr, compound_id):
    '''
    Builds the placeholder output record used when _preferred() found no
    retention index for a CAS number, neither directly nor via regression
    @param row: the NIST row being processed (not used)
    @param cas_nr: numeric CAS number, written to the first field with a 'C' prefix
    @param compound_id: written to the 'Centrotype' field
    @return: list of 14 strings in the output column order
    '''
    # Every output field starts out empty; only a few get a non-empty default
    record = [''] * 14
    record[0] = 'C' + cas_nr   # CAS
    record[3] = '0.0'          # RI
    record[7] = ' '            # phase.coding
    # NOTE: compound_id is not ALWAYS the centrotype; that depends on the MsClust
    # algorithm used. For now only one MsClust algorithm is used, so it is not an
    # issue, but this should be updated/corrected once that changes.
    record[9] = compound_id    # Centrotype
    return record
    

def format_result(lookup_dict, nist_tabular_filename, pref, ctype, polar, model, method):
    '''
    Looks up the compounds in the library lookup table and formats the results
    @param lookup_dict: dictionary containing the library to be searched
    @param nist_tabular_filename: NIST output file to be matched
    @param pref: (list of) column-name(s) to look for
    @param ctype: column type of interest
    @param polar: polarity of the used column
    @param model: data loaded from file containing regression models
    @param method: supported regression method (i.e. poly(nomial) or linear)
    @return: list with one output row per row of the NIST file; rows without a
             (regressed) library hit are filled in by default_hit()
    @raise IOError: if the NIST file header has no 'cas' or 'id' column
    '''
    (nist_tabular_list, header_clean) = match_library.read_library(nist_tabular_filename)
    # Retrieve indices of the CAS and compound_id columns (exit if not present)
    try:
        casi = header_clean.index("cas")
        idi = header_clean.index("id")
    except ValueError:
        # list.index raises ValueError when a column name is missing
        raise IOError("'CAS' or 'compound_id' not found in header of NIST tabular file %s"
                      % nist_tabular_filename)

    data = []
    for row in nist_tabular_list:
        # Drop the dashes from the NIST CAS number to form the lookup key
        casf = str(row[casi].replace('-', '').strip())
        # Keep only the part of the id before the first '-'
        compound_id = str(row[idi].split('-')[0])

        found_hit, regress = False, False
        if casf in lookup_dict:
            found_hit, regress = _preferred(lookup_dict[casf], pref, ctype, polar,
                                            model, method)
        if not found_hit:
            data.append(default_hit(row, casf, compound_id))
            continue

        # Keep cas nr as 'C'+ numeric part:
        found_hit[0] = 'C' + casf
        # Add compound id
        found_hit.insert(9, compound_id)
        # Add information on regression process
        found_hit.insert(10, regress if regress else 'None')
        # Replace column index references (comma separated) with a number of duplicates.
        # NOTE(review): n > 1 references are reported as n + 1, but a single
        # reference as '0'; confirm this asymmetry is intended.
        dups = len(found_hit[-1].split(','))
        found_hit[-1] = str(dups + 1) if dups > 1 else '0'
        data.append(found_hit)
    return data


def _save_data(content, outfile):
    '''
    Write the results as a tab separated file preceded by a header row
    @param content: rows to write, one sequence of fields per output line
    @param outfile: path of the file to write (overwritten if it exists)
    '''
    # header
    header = ['CAS',
              'NAME',
              'FORMULA',
              'RI',
              'Column.type',
              'Column.phase.type',
              'Column.name',
              'phase.coding',
              'CAS_column.Name',
              'Centrotype',
              'Regression.Column.Name',
              'min',
              'max',
              'nr.duplicates']
    # The csv module needs a binary file on Python 2 and a text file opened
    # with newline='' on Python 3
    if sys.version_info[0] < 3:
        output_file = open(outfile, 'wb')
    else:
        output_file = open(outfile, 'w', newline='')
    # 'with' guarantees the output is flushed and the file is closed
    with output_file:
        output_handle = csv.writer(output_file, delimiter="\t")
        output_handle.writerow(header)
        output_handle.writerows(content)


def _read_model(model_file):
    '''
    Creates an easy to search dictionary for getting the regression parameters
    for each valid combination of GC-columns
    @param model_file: tab separated file containing the regression models; the
                       first row is a header, each further row holds two column
                       names followed by the model values
    @return: tuple (model, method): model maps first column name -> second column
             name -> list of floats (9 values for 'poly', 7 for 'linear'); method is
             'poly' when the header has more than 9 fields, otherwise 'linear'
    @raise IOError: if the file contains no rows at all
    '''
    # Open the file the way the csv module expects on each Python version
    # (mode 'rU' is rejected by Python 3.11+)
    if sys.version_info[0] < 3:
        model_handle = open(model_file, 'rU')
    else:
        model_handle = open(model_file, 'r', newline='')
    with model_handle:
        # Blank lines come back from csv as empty rows; they carry no model
        regress = [line for line in csv.reader(model_handle, delimiter='\t') if line]
    if not regress:
        raise IOError('Empty regression model file: %s' % model_file)

    # The header width tells which kind of models the file contains
    method = 'poly' if len(regress.pop(0)) > 9 else 'linear'
    # End (exclusive) of the numeric values following the two column names
    end = 11 if method == 'poly' else 9

    model = {}
    for line in regress:
        model.setdefault(line[0], {})[line[1]] = [float(col) for col in line[2:end]]
    return model, method


def _apply_poly_regression(column1, column2, retention_index, model):
    '''
    Converts a retention index (RI) measured on column2 to column1 using the
    3rd-degree polynomial stored in model[column1][column2]
    @param column1: name of the selected GC-column
    @param column2: name of the GC-column to use for regression
    @param retention_index: RI to convert
    @param model: dictionary containing model information for all GC-columns;
                  values 0-3 are the polynomial coefficients (constant term
                  first), values 4 and 5 the exclusive bounds of valid input
    @return: the converted RI, or False if retention_index is not strictly
             between the bounds
    '''
    coefficients = model[column1][column2]
    # Only extrapolate within the range of data the model was built on
    if not coefficients[4] < retention_index < coefficients[5]:
        return False
    cubic_term = coefficients[3] * (retention_index ** 3)
    square_term = coefficients[2] * (retention_index ** 2)
    linear_term = retention_index * coefficients[1]
    return cubic_term + square_term + linear_term + coefficients[0]


def _apply_linear_regression(column1, column2, retention_index, model):
    '''
    Converts a retention index (RI) measured on column2 to column1 using the
    straight line stored in model[column1][column2]
    @param column1: name of the selected GC-column
    @param column2: name of the GC-column to use for regression
    @param retention_index: RI to convert
    @param model: dictionary containing model information for all GC-columns;
                  value 0 is the intercept and value 1 the slope
    @return: slope * retention_index + intercept
    '''
    # TODO: No use of limits
    coefficients = model[column1][column2]
    slope = coefficients[1]
    intercept = coefficients[0]
    return slope * retention_index + intercept


def main():
    '''
    Library Lookup command line entry point

    Positional arguments: library_file nist_tabular_file column_type polarity
    output_file, then zero or more preferred column names, and finally either a
    regression model file or the literal string 'False' to disable regression.
    '''
    library_file, nist_tabular_filename = sys.argv[1], sys.argv[2]
    ctype, polar, outfile = sys.argv[3], sys.argv[4], sys.argv[5]
    # Everything between the fixed arguments and the last one is a preferred column
    pref, regress = sys.argv[6:-1], sys.argv[-1]

    if regress == 'False':
        model, method = False, None
    else:
        model, method = _read_model(regress)

    lookup_table = create_lookup_table(library_file, ctype, polar)
    rows = format_result(lookup_table, nist_tabular_filename, pref, ctype, polar,
                         model, method)
    _save_data(rows, outfile)


# Run the command line interface only when executed as a script, not on import
if __name__ == "__main__":
    main()