Mercurial > repos > pmac > iterativepca
view R_functions/plotting_functions.R @ 0:64e75e21466e draft default tip
Uploaded
author | pmac |
---|---|
date | Wed, 01 Jun 2016 03:38:39 -0400 |
parents | |
children |
line wrap: on
line source
## Plotting and grouping ##

# Scatter-plot a set of 2d observations, drawing each group with its own
# colour and symbol and adding a legend to the right of the plot.
#
# input_data: some number of 2d observations. Each row represents a single
#             observation,
#             column 1 = variable 1, to be plotted on the x-axis,
#             column 2 = variable 2, to be plotted on the y-axis
# groups: Integer vector with same number of entries as there are rows in the
#         input data, representing which group each observation belongs to.
#         Negative numbers are not plotted
# tags: the tag to put on the legend for each group
# plot_colors: colors to use for each group
# plot_symbols: symbols to use for each group
# plot_title: as name suggests
# plot_filename: if this is not null, graph is output to a png with the
#                specified name
#
# plot_colors and plot_symbols are matched to groups by position: entry i is
# used for the i-th smallest non-negative group id present in 'groups'.
# Returns NULL invisibly; the function is called for its drawing side effect.
plot_by_groups = function(input_data, groups, tags, plot_colors, plot_symbols,
                          plot_title, plot_filename=NULL) {
  if (!is.null(plot_filename)) {
    png(plot_filename)
    # close the png device even if one of the drawing calls below fails,
    # so a failed plot does not leave a dangling open device
    on.exit(dev.off(), add=TRUE)
  }
  # leave some extra room on the RHS for the legend
  # (not restored afterwards: later low-level drawing on the same device
  # relies on the plot region staying where it is)
  par(mar=c(5.1, 4.1, 4.1, 8.1))
  x = as.numeric(input_data[, 1])
  y = as.numeric(input_data[, 2])
  gids = sort(unique(groups[which(groups >= 0)]))

  # first set up the plot area to the correct dimensions without drawing
  # anything, so that samples in negative groups really are not plotted
  plot(x, y, type="n")
  for (i in seq_along(gids)) {
    members = which(groups == gids[i])
    points(x[members], y[members], col=plot_colors[i], pch=plot_symbols[i])
  }
  # xpd + negative inset push the legend into the widened right margin
  legend(x="topright",
         xpd=TRUE, inset=c(-0.3, 0),
         col=plot_colors,
         pch=plot_symbols,
         legend=tags,
         text.col=plot_colors)
  title(main=plot_title)
  invisible(NULL)
}

# Controls vs cases plot. Colour controls blue, cases red,
# Samples which are neither control nor case are black.
# pca_data: list whose 'ids' element holds the sample ids
# control_tag: pattern (passed to grep) identifying control ids, or NULL
# cases_tag: pattern (passed to grep) identifying case ids, or NULL
# Returns a list with groups (1 = unknown, 2 = control, 3 = case), tags,
# plot_colors, plot_symbols and plot_title. The tags/colors/symbols only
# cover the groups that actually occur, in increasing group order, because
# plot_by_groups matches them to groups by position.
setup_cvc_plot = function(pca_data, control_tag, cases_tag) {
  plot_info = list()
  nsamples = length(pca_data$ids)
  groups = rep(1, nsamples)
  control_legend = paste0("CO: ", control_tag)
  cases_legend = paste0("CA: ", cases_tag)
  if (!is.null(control_tag)) {
    groups[grep(control_tag, pca_data$ids)] = 2
  }
  # cases are assigned last, so an id matching both patterns counts as a case
  if (!is.null(cases_tag)) {
    groups[grep(cases_tag, pca_data$ids)] = 3
  }
  res = sort(unique(groups))
  # look the legend entries up by group id rather than guessing from the
  # number of groups present: a data set made up only of controls (or only
  # of cases) must not be labelled "UNKNOWN"
  all_tags = c("UNKNOWN", control_legend, cases_legend)
  all_colors = c("black", "blue", "red")
  plot_info$groups = groups
  plot_info$tags = all_tags[res]
  plot_info$plot_colors = all_colors[res]
  plot_info$plot_symbols = rep(1, length(res))
  plot_info$plot_title = "Control vs Cases Plot"
  return(plot_info)
}

# outliers plot; colour outliers red, non-outliers green
# pca_data: list whose 'values' element has one row per sample
# outliers: row indices of the samples considered outliers
# Group 1 = outliers, group 2 = everything else. Legend entries are only
# emitted for groups that occur, so that when there are no outliers the
# good data is not drawn with the outlier colour and symbol.
setup_ol_plot = function(pca_data, outliers) {
  plot_info = list()
  nsamples = dim(pca_data$values)[1]
  groups = rep(2, nsamples)
  groups[outliers] = 1
  present = sort(unique(groups))
  plot_info$groups = groups
  plot_info$tags = c("outliers", "good data")[present]
  plot_info$plot_colors = c("red", "green")[present]
  plot_info$plot_symbols = c(1, 20)[present]
  plot_info$plot_title = "Outliers Plot"
  return(plot_info)
}

# standard deviations plot; colour samples by s.dev
# pca_data: list whose 'values' element has one row per sample; columns 1
#           and 2 are used
# Groups 1-3 are assigned via get_sdset2d, group 4 holds samples more than
# 3 s.devs out on either component. Relies on compute_numsds and
# get_sdset2d being defined elsewhere in this file.
setup_sd_plot = function(pca_data) {
  plot_info = list()
  nsamples = dim(pca_data$values)[1]
  pc1 = as.numeric(pca_data$values[, 1])
  pc2 = as.numeric(pca_data$values[, 2])
  pc1_sds = vapply(pc1, compute_numsds, numeric(1), x=pc1)
  pc2_sds = vapply(pc2, compute_numsds, numeric(1), x=pc2)

  groups = seq_len(nsamples)
  groups[get_sdset2d(pc1_sds, pc2_sds, 1)] = 1
  groups[get_sdset2d(pc1_sds, pc2_sds, 2)] = 2
  groups[get_sdset2d(pc1_sds, pc2_sds, 3)] = 3
  groups[union(which(pc1_sds > 3), which(pc2_sds > 3))] = 4

  # keep only the legend entries of bands that contain samples; otherwise
  # an empty band shifts every later band onto the wrong colour
  band_ids = 1:4
  present = band_ids[band_ids %in% groups]
  plot_info$groups = groups
  plot_info$tags = c("SD = 1", "SD = 2", "SD = 3", "SD > 3")[present]
  plot_info$plot_colors = rainbow(4)[present]
  plot_info$plot_symbols = rep(20, 4)[present]
  plot_info$plot_title = "Standard Deviations Plot"
  return(plot_info)
}

# Plot samples, with coloured clusters. Rejected clusters use
# a cross symbol instead of a filled circle
# pca_data: not used by this function
# clusters: cluster id of each sample; id 0 marks outliers, negative ids are
#           left out of the legend because plot_by_groups does not plot them
# rc: ids of rejected clusters, or NULL
setup_cluster_plot = function(pca_data, clusters, rc=NULL) {
  plot_info = list()
  groups = clusters
  # same selection plot_by_groups uses, so legend entries line up with the
  # groups that are actually drawn
  ids = sort(unique(groups[which(groups >= 0)]))
  n = length(ids)
  tags = sprintf("cluster %s", ids)
  plot_colors = rainbow(n)
  plot_symbols = rep(20, n)

  # locate the outlier entry by id instead of assuming it comes first
  is_outlier = ids == 0
  tags[is_outlier] = "outliers"
  plot_symbols[is_outlier] = 1

  # labelling for rejected clusters
  if (!is.null(rc)) {
    rejected = (ids != 0) & (ids %in% as.numeric(rc))
    tags[rejected] = "rej. clust."
    plot_symbols[rejected] = 4
  }
  plot_info$groups = groups
  plot_info$tags = tags
  plot_info$plot_colors = plot_colors
  plot_info$plot_symbols = plot_symbols
  plot_info$plot_title = "Cluster Plot"
  return(plot_info)
}

# Plot samples, colouring by ethnicity. Different ethnicities also
# have different symbols.
# pca_data: list with 'values' (one row per sample) and 'ids' (sample ids)
# ethnicity_data: table with a "population" column whose row names are
#                 sample ids
# Samples whose id has no exactly matching row (or whose population is NA)
# are put in the "UNKNOWN" group. Groups are numbered in the sorted order of
# the population names, which is also the order of the returned tags.
setup_ethnicity_plot = function(pca_data, ethnicity_data) {
  plot_info = list()
  nsamples = dim(pca_data$values)[1]
  sample_ids = as.character(pca_data$ids[seq_len(nsamples)])
  # look rows up with an exact match: indexing a data.frame by a character
  # row name partially matches, so e.g. "S1" could pick up the row of "S10"
  row_idx = match(sample_ids, rownames(ethnicity_data))
  eth = as.character(ethnicity_data[row_idx, "population"])
  eth[is.na(eth)] = "UNKNOWN"

  eth_factor = factor(eth)
  n = nlevels(eth_factor)
  plot_info$groups = as.numeric(eth_factor)
  plot_info$tags = levels(eth_factor)
  plot_info$plot_colors = rainbow(n)
  plot_info$plot_symbols = seq_len(n)
  plot_info$plot_title = "Ethnicity Plot"
  return(plot_info)
}

# Draw vertical and horizontal cutoff lines on the current plot at
# median +/- numsds standard deviations of two variables.
# input_data: table of values; x and y select the two variables
# numsds: number of standard deviations away from the median to draw at
# NOTE(review): this reads ROWS x and y of input_data, whereas
# plot_by_groups reads variables from columns - confirm that callers pass
# the data with variables in rows.
draw_cutoffs = function(input_data, x, y, numsds) {
  pcx = as.numeric(input_data[x, ])
  pcy = as.numeric(input_data[y, ])

  vlines = c(median(pcx) - numsds*sd(pcx),
             median(pcx) + numsds*sd(pcx))
  hlines = c(median(pcy) - numsds*sd(pcy),
             median(pcy) + numsds*sd(pcy))
  abline(v=vlines)
  abline(h=hlines)
}

# Following helper functions are used in the 'setup_sd_plot' function

# given a list of standard deviations, work out which points are n standard
# deviations away
# x1, x2: number of standard deviations of each point on the two variables
# n: the band to select
# Returns the indices of the points in band n: for n == 1 both values must
# equal 1; for larger n at least one value equals n, and the point is
# neither in band n - 1 nor beyond n on either variable.
get_sdset2d = function(x1, x2, n) {
  if (n == 1) {
    ind = intersect(which(x1 == 1), which(x2 == 1))
  } else {
    lower = get_sdset2d(x1, x2, n - 1)
    upper = union(which(x1 > n), which(x2 > n))
    xset = union(lower, upper)
    bigset = union(which(x1 == n), which(x2 == n))
    ind = setdiff(bigset, xset)
  }
  return(ind)
}

# work out how many standard deviations away from the sample median a single
# point is
# accuracy of this decreases for outliers, as the error in the estimated sd
# is multiplied
# point: the value to classify
# x: the sample the median and standard deviation are taken from
# Returns the smallest whole number i >= 1 with |point - median| <= i * sd.
# Returns Inf when no such number exists (zero spread with the point off the
# median, or an infinite distance) instead of looping forever, and stops
# with an error when the distance or the standard deviation is NA.
compute_numsds = function(point, x) {
  x_sd = sd(x)
  m = median(x)
  offset = abs(point - m)
  if (anyNA(x_sd) || anyNA(offset)) {
    stop("compute_numsds: standard deviation or distance from median is NA",
         call.=FALSE)
  }
  if (is.infinite(offset)) {
    return(Inf)
  }
  if (x_sd == 0) {
    # adding a zero sd never reaches a non-zero offset
    if (offset > 0) {
      return(Inf)
    }
    return(1)
  }
  reach = x_sd
  i = 1
  while (offset > reach) {
    i = i + 1
    reach = reach + x_sd
  }
  return(i)
}