import cobra
import utils
import pandas as pd
from scipy.stats import spearmanr, pearsonr, gaussian_kde
import numpy as np


def get_flux_correlation(samples:pd.DataFrame, reference_rxn:str, method:str="spearman", fillNan:bool=True) -> pd.DataFrame:
    """Correlate every column of ``samples`` against one reference column.

    Args:
        samples: Table with one column per reaction (rows are presumably
            individual flux samples -- confirm against the caller).
        reference_rxn: Name of the column in ``samples`` that every column
            (including itself) is correlated with.
        method: ``"spearman"`` (default) or ``"pearson"``.
        fillNan: When true, NaN entries in the result are replaced by 0.
            Note that this applies to the ``p_value`` column as well as to
            ``coefficient``.

    Returns:
        A DataFrame indexed by the columns of ``samples`` with the columns
        ``p_value`` and ``coefficient`` (in that order).

    Raises:
        utils.ValueErr: If ``method`` is neither ``"spearman"`` nor
            ``"pearson"``.
    """
    correlation_table = pd.DataFrame(index=samples.columns)
    reference_values = samples[reference_rxn].values

    # Resolve the correlation routine once instead of re-testing the method
    # name for every reaction.
    correlate = {"spearman": spearmanr, "pearson": pearsonr}.get(method)
    if correlate is None:
        raise utils.ValueErr(method + " is not recognized, only spearman and pearson are supported.")

    # Each outcome unpacks as (coefficient, p-value).
    outcomes = [
        correlate(reference_values, samples[name].values)
        for name in samples.columns
    ]

    correlation_table["p_value"] = [p_value for _, p_value in outcomes]
    correlation_table["coefficient"] = [coefficient for coefficient, _ in outcomes]

    return correlation_table.fillna(0) if fillNan else correlation_table

def _kde_mode(data:pd.Series, grid_points:int=100):
    """Estimate the mode of ``data`` as the peak of a Gaussian KDE.

    The density is evaluated on ``grid_points`` evenly spaced values between
    the minimum and maximum of ``data``; the grid value with the highest
    density is returned.
    """
    # A column holding a single distinct value has zero variance, which a
    # Gaussian KDE cannot be fitted to; its mode is simply that value.
    # ``iloc`` is used because the index of ``data`` need not contain label 0.
    if data.nunique(dropna=False) == 1:
        return data.iloc[0]

    kde = gaussian_kde(data)
    x_grid = np.linspace(data.min(), data.max(), grid_points)
    return x_grid[np.argmax(kde(x_grid))]


def get_flux_statistics(samples:pd.DataFrame, method:str="mean")-> pd.DataFrame:
    """Summarise every column of ``samples`` with a single-row DataFrame.

    Args:
        samples: Table with one column per reaction (rows are presumably
            individual flux samples -- confirm against the caller).
        method: One of

            * ``"mean"``   -- arithmetic mean of each column,
            * ``"median"`` -- median of each column,
            * ``"mode"``   -- peak of a Gaussian KDE evaluated on a
              100-point grid (or the value itself for a constant column),
            * ``"dist"``   -- minimum, the three quartiles and maximum.

    Returns:
        A DataFrame with one row (index label ``0``). For ``"dist"`` each
        input column ``rxn`` yields the five columns ``rxn_min``, ``rxn_1q``,
        ``rxn_2q``, ``rxn_3q`` and ``rxn_max``; for every other method the
        columns are those of ``samples``.

    Raises:
        utils.ValueErr: If ``method`` is not one of the supported names.
    """
    # Validate once, up front, so an unsupported method is reported even when
    # ``samples`` has no columns and before any statistic is computed.
    if method not in ("mean", "median", "mode", "dist"):
        raise utils.ValueErr(method + " is not recognized, only mean, median, mode and dist are supported.")

    if method == "dist":
        suffixes = ("_min", "_1q", "_2q", "_3q", "_max")
        columns = [rxn + suffix for rxn in samples.columns for suffix in suffixes]
    else:
        columns = samples.columns

    df_result = pd.DataFrame(columns=columns)

    stat_rxns = []
    for rxn in samples.columns:
        data = samples[rxn]
        if method == "mean":
            stat_rxns.append(data.mean())
        elif method == "median":
            stat_rxns.append(data.median())
        elif method == "mode":
            stat_rxns.append(_kde_mode(data))
        else:
            # "dist": the order must match the column suffixes built above.
            stat_rxns.extend([
                data.min(),
                data.quantile(0.25),
                data.quantile(0.50),
                data.quantile(0.75),
                data.max(),
            ])

    df_result.loc[0] = stat_rxns

    return df_result



    