library(flashpcaR)
library(dbscan)
library(cluster)

## MAIN ###
# Get the run configuration. With CLI_FLAG == 1 every setting comes from the
# trailing command line arguments (all 14 are required, in the positional
# order read below); otherwise the hard-coded development defaults in the
# else-branch are used.
CLI_FLAG = 1
if (CLI_FLAG == 1) {
  # Parse the trailing arguments once rather than once per setting.
  cli_args = commandArgs(trailingOnly=TRUE)
  # Missing arguments would otherwise silently become NA and fail later with
  # an unrelated error, so check the count up front.
  if (length(cli_args) < 14) {
    stop(sprintf("expected 14 command line arguments, got %d",
                 length(cli_args)), call.=FALSE)
  }
  # n arrives as a string; convert it so it can drive seq_len()/vector().
  n = as.integer(cli_args[1])
  if (is.na(n) || n < 0) {
    stop("number of iterations (argument 1) must be a non-negative integer, got '",
         cli_args[1], "'", call.=FALSE)
  }
  outdir = cli_args[2]
  basename = cli_args[3]
  data_source = cli_args[4]
  data_type = cli_args[5]
  eth_filename = cli_args[6]
  # The literal string "None" means "no tag"; identical() never yields NA,
  # so it is safe inside if().
  control_tag = cli_args[7]
  if (identical(control_tag, "None")) {control_tag = NULL}
  cases_tag = cli_args[8]
  if (identical(cases_tag, "None")) {cases_tag = NULL}
  numsds = as.numeric(cli_args[9])
  cmethod = cli_args[10]
  tmethod = cli_args[11]
  path_to_r_functions = cli_args[12]
  xsamples_filename = cli_args[13]
  xsnps_filename = cli_args[14]
} else {
  # Development defaults, used when running the script interactively.
  n = 10L
  basename = "test_eth2"
  data_source = "./data/halo1_numeric.ped"
  data_type = "numeric_ped"
  outdir = paste0(getwd(), "/full_output_", basename)
  #data_source = "./data/HapMap3_flashPCA_data.rds"
  #data_type = "rds"
  #eth_filename = "./data/HapMap3_ethnicity_rf.txt"
  eth_filename = "./data/Halo_ethnicity_rf.txt"
  control_tag = "HAPS"
  cases_tag = NULL
  numsds = 1.1
  cmethod = "hclust"
  tmethod = "mcd"
  path_to_r_functions = paste0(getwd(), "/R_functions/")
  xsamples_filename = "./xfiles/halo1_xsamples.txt"
  xsnps_filename = "./xfiles/halo1_xsnps.txt"
}

# Load the project helper functions. source() with its default local=FALSE
# evaluates each file in the global environment, so everything the helpers
# define is visible to the rest of this script even though the calls happen
# inside an anonymous function.
invisible(lapply(
  c("plotting_functions.R",
    "pca_helpers.R",
    "pipeline_code.R",
    "clustering_functions.R",
    "outlier_trimming.R"),
  function(helper_file) source(file.path(path_to_r_functions, helper_file))
))

# In development (non-CLI) mode, start from a clean output directory.
if (CLI_FLAG != 1) {
  unlink(file.path(getwd(), paste0("full_output_", basename)), recursive=TRUE)
}

# Read in the input data: the genotype data, the ethnicity annotations, and
# the first column of the sample / SNP exclusion files.
ped_data = get_source_data(data_source, data_type)
eth_data = parse_ethnicity_file(eth_filename)
xsamples = get_first_column(xsamples_filename)
xsnps = get_first_column(xsnps_filename)

# Do the PCA and prepare plots.
# n may arrive as a command line string; make it an integer count so that
# vector()/seq_len() work. seq_len() runs zero times when n == 0, whereas
# 1:n would iterate over c(1, 0).
n = as.integer(n)
iterations = vector("list", n)
for (i in seq_len(n)) {
  # Each iteration filters with the xsamples returned by the previous one,
  # so the iterations have to run sequentially.
  fpd = filter_ped_data(ped_data, xsamples, xsnps)
  iterations[[i]] = single_iteration(outdir, basename, fpd, xsamples, numsds,
                                     cmethod, tmethod, control_tag, cases_tag,
                                     ethnicity_data=eth_data)
  iterations[[i]]$dirname = generate_directory_name(outdir, basename, i)
  xsamples = iterations[[i]]$xsamples
}

# Create folders and plots: for every iteration result, make its output
# directory, draw each of its plots as a PNG, and write its xsamples and
# outliers tables as CSV files.
# seq_along()/seq_len() also behave correctly for zero iterations or zero
# plots, where 1:n / 1:num_plots would count down to 0.
for (i in seq_along(iterations)) {
  titer = iterations[[i]]
  # dir.create() only warns and returns FALSE on failure, and also warns when
  # the directory already exists; create it only when missing and stop if
  # that fails, instead of failing obscurely while writing files later.
  if (!dir.exists(titer$dirname) &&
      !dir.create(titer$dirname, recursive=TRUE)) {
    stop("could not create output directory: ", titer$dirname, call.=FALSE)
  }
  num_plots = titer$num_plots

  for (j in seq_len(num_plots)) {
    tplot = titer$plots[[j]]
    plot_filename = sprintf("%s/%s_plot_number_%d.png",
                            titer$dirname, basename, j)
    # Plot the first two PCA columns with this plot's grouping and styling.
    plot_by_groups(titer$pca_data$values[, c(1, 2)],
                   tplot$groups,
                   tplot$tags,
                   tplot$plot_colors,
                   tplot$plot_symbols,
                   tplot$plot_title,
                   plot_filename=plot_filename)
  }

  # Write the iteration's old_xsamples and outliers, each passed through
  # add_ethnicity_data(), as comma-separated files with a header row.
  xfilename = file.path(titer$dirname, paste0(basename, "_xfile.txt"))
  outliers_filename = file.path(titer$dirname,
                                paste0(basename, "_outliers.txt"))
  xscon = add_ethnicity_data(titer$old_xsamples, eth_data)
  olcon = add_ethnicity_data(titer$outliers, eth_data)
  write.table(xscon, file=xfilename, row.names=FALSE, col.names=TRUE, sep=",")
  write.table(olcon, file=outliers_filename, row.names=FALSE, col.names=TRUE,
              sep=",")
}