Mercurial > repos > pmac > iterativepca
diff R_functions/plotting_functions.R @ 0:64e75e21466e draft default tip
Uploaded
author | pmac |
---|---|
date | Wed, 01 Jun 2016 03:38:39 -0400 |
parents | |
children |
line wrap: on
line diff
# --- /dev/null Thu Jan 01 00:00:00 1970 +0000
# +++ b/R_functions/plotting_functions.R Wed Jun 01 03:38:39 2016 -0400
# @@ -0,0 +1,225 @@
## Plotting and grouping ##
# input data: some number of 2d observations. Each row represents a single observation,
#   column 1 = variable 1, to be plotted on the x-axis,
#   column 2 = variable 2, to be plotted on the y-axis
# groups: Integer vector with same number of entries as there are rows in the input data,
#   representing which group each observation belongs to. Negative numbers are not plotted
# tags: the tag to put on the legend for each group
# plot_colors: colors to use for each group
# plot_symbols: symbols to use for each group
# plot_title: as name suggests
# plot_filename: if this is not null, graph is output to a png with the specified name
#
# Colours and symbols are matched to groups BY POSITION: the i-th smallest
# non-negative group id that actually occurs in `groups` is drawn with
# plot_colors[i] and plot_symbols[i]. Callers must therefore supply one
# tag/colour/symbol per group that is present, in ascending group-id order.
# Returns NULL invisibly; the function is called for its plotting side effect.
plot_by_groups = function(input_data, groups, tags, plot_colors, plot_symbols, plot_title, plot_filename=NULL) {
  if (!is.null(plot_filename)) {
    png(plot_filename)
    # make sure the png device is closed even if a plotting call below fails,
    # otherwise the device (and the half-written file) would be leaked
    on.exit(dev.off(), add = TRUE)
  }
  # leave some extra room on the RHS for the legend.
  # This is deliberately not restored: when drawing on the current device the
  # caller may keep adding to the plot, and resetting the margins would shift
  # the coordinate system of those later additions.
  par(mar = c(5.1, 4.1, 4.1, 8.1))
  x = as.numeric(input_data[, 1])
  y = as.numeric(input_data[, 2])
  gids = sort(unique(groups[which(groups >= 0)]))

  # first set up the plot area to the correct dimensions (nothing is drawn yet)
  plot(x, y, type = "n")

  # seq_along() keeps the loop from running when no group is non-negative
  for (i in seq_along(gids)) {
    members = which(groups == gids[i])
    points(x[members], y[members], col = plot_colors[i], pch = plot_symbols[i])
  }
  # legend() refuses an empty label vector, so skip it when there is nothing to label
  if (length(tags) > 0) {
    legend(x = "topright",
           xpd = TRUE,
           inset = c(-0.3, 0),
           col = plot_colors,
           pch = plot_symbols,
           legend = tags,
           text.col = plot_colors)
  }
  title(main = plot_title)
  invisible(NULL)
}

# Controls vs cases plot. Colour controls blue, cases red,
# Samples which are neither control nor case are black.
# Build the plot settings for the controls-vs-cases plot.
# pca_data: list with an `ids` element holding the sample ids
# control_tag / cases_tag: regular expressions (passed to grep) identifying
#   control and case samples by id; either may be NULL to skip that class
# Returns a list with groups, tags, plot_colors, plot_symbols and plot_title.
# Group codes are 1 = unknown, 2 = control, 3 = case. Tags, colours and symbols
# are returned only for the groups that actually occur, in ascending group
# order, so that they line up with the groups positionally.
setup_cvc_plot = function(pca_data, control_tag, cases_tag) {
  nsamples = length(pca_data$ids)
  groups = rep(1, nsamples)
  if (!is.null(control_tag)) {
    groups[grep(control_tag, pca_data$ids)] = 2
  }
  # cases are assigned last, so an id matching both tags is treated as a case
  if (!is.null(cases_tag)) {
    groups[grep(cases_tag, pca_data$ids)] = 3
  }
  res = sort(unique(groups))

  # Look the legend entries up by group code. This also covers the situation
  # where every sample is a control (or every sample is a case), which must not
  # be labelled "UNKNOWN".
  all_tags = c("UNKNOWN", paste0("CO: ", control_tag), paste0("CA: ", cases_tag))
  all_colors = c("black", "blue", "red")

  plot_info = list()
  plot_info$groups = groups
  plot_info$tags = all_tags[res]
  plot_info$plot_colors = all_colors[res]
  plot_info$plot_symbols = rep(1, length(res))
  plot_info$plot_title = "Control vs Cases Plot"
  return(plot_info)
}

# outliers plot; colour outliers red, non-outliers green
# pca_data: list whose `values` element has one row per sample
# outliers: row indices of the samples considered outliers (may be empty)
# Group codes are 1 = outlier, 2 = good data. Only the groups that occur are
# given a tag/colour/symbol, so the legend stays aligned with the points when
# there are no outliers (or nothing but outliers).
setup_ol_plot = function(pca_data, outliers) {
  nsamples = dim(pca_data$values)[1]
  groups = rep(2, nsamples)
  groups[outliers] = 1
  present = sort(unique(groups))

  plot_info = list()
  plot_info$groups = groups
  plot_info$tags = c("outliers", "good data")[present]
  plot_info$plot_colors = c("red", "green")[present]
  plot_info$plot_symbols = c(1, 20)[present]
  plot_info$plot_title = "Outliers Plot"
  return(plot_info)
}

# standard deviations plot; colour samples by s.dev
# pca_data: list whose `values` element has one row per sample, with the first
#   two columns holding the two components being plotted
# Each sample is put in group 1, 2 or 3 according to the helper get_sdset2d(),
# or in group 4 when either component is more than 3 standard deviations out.
# Only the groups that occur are given a tag/colour/symbol so that the legend
# colours match the plotted points even when a lower group is empty.
setup_sd_plot = function(pca_data) {
  nsamples = dim(pca_data$values)[1]
  pc1 = as.numeric(pca_data$values[, 1])
  pc2 = as.numeric(pca_data$values[, 2])
  # number of standard deviations from the median, per sample and component
  pc1_sds = vapply(pc1, compute_numsds, numeric(1), pc1)
  pc2_sds = vapply(pc2, compute_numsds, numeric(1), pc2)

  groups = seq_len(nsamples)
  groups[get_sdset2d(pc1_sds, pc2_sds, 1)] = 1
  groups[get_sdset2d(pc1_sds, pc2_sds, 2)] = 2
  groups[get_sdset2d(pc1_sds, pc2_sds, 3)] = 3
  groups[union(which(pc1_sds > 3), which(pc2_sds > 3))] = 4
  present = sort(unique(groups))

  plot_info = list()
  plot_info$groups = groups
  plot_info$tags = c("SD = 1", "SD = 2", "SD = 3", "SD > 3")[present]
  plot_info$plot_colors = rainbow(4)[present]
  plot_info$plot_symbols = rep(20, length(present))
  plot_info$plot_title = "Standard Deviations Plot"
  return(plot_info)
}

# Plot samples, with coloured clusters. Rejected clusters use
# a cross symbol instead of a filled circle
# pca_data: not used by this function; kept for a uniform setup_* signature
# clusters: cluster id of every sample; id 0 marks outliers, and negative ids
#   are left out of the legend because they are never plotted
# rc: optional ids of rejected clusters
setup_cluster_plot = function(pca_data, clusters, rc=NULL) {
  groups = clusters
  # only non-negative ids get a legend entry, so positions line up with the
  # groups that are actually drawn
  ids = sort(unique(groups[which(groups >= 0)]))
  n = length(ids)

  tags = sprintf("cluster %s", ids)
  plot_symbols = rep(20, n)

  # the outlier pseudo-cluster gets its own label and an open circle
  is_outlier = ids == 0
  tags[is_outlier] = "outliers"
  plot_symbols[is_outlier] = 1

  # labelling for rejected clusters (the outlier group is never "rejected")
  if (!is.null(rc)) {
    is_rejected = !is_outlier & (ids %in% as.numeric(rc))
    tags[is_rejected] = "rej. clust."
    plot_symbols[is_rejected] = 4
  }

  plot_info = list()
  plot_info$groups = groups
  plot_info$tags = tags
  plot_info$plot_colors = rainbow(n)
  plot_info$plot_symbols = plot_symbols
  plot_info$plot_title = "Cluster Plot"
  return(plot_info)
}

# Plot samples, colouring by ethnicity. Different ethnicities also
# have different symbols.
# Build the plot settings for the ethnicity plot: one colour and one symbol
# per population.
# pca_data: list with `values` (one row per sample) and `ids` (sample ids)
# ethnicity_data: table whose row names are sample ids and which has a
#   "population" column
# Samples without an entry (or with a missing population) are labelled "UNKNOWN".
setup_ethnicity_plot = function(pca_data, ethnicity_data) {
  nsamples = dim(pca_data$values)[1]
  sample_ids = as.character(pca_data$ids)[seq_len(nsamples)]

  # Look rows up with match() so that ids are compared exactly. Indexing a data
  # frame by a character row name uses partial matching, which would silently
  # give sample "S1" the population of row "S10".
  row_idx = match(sample_ids, rownames(ethnicity_data))
  eth = as.character(ethnicity_data[, "population"])[row_idx]
  eth[is.na(eth)] = "UNKNOWN"

  # factor levels are the sorted unique labels, so group i <-> tag i
  eth_factor = factor(eth)
  n = nlevels(eth_factor)

  plot_info = list()
  plot_info$groups = as.numeric(eth_factor)
  plot_info$tags = levels(eth_factor)
  plot_info$plot_colors = rainbow(n)
  plot_info$plot_symbols = seq_len(n)
  plot_info$plot_title = "Ethnicity Plot"
  return(plot_info)
}

# Draw cutoff lines on the current plot at median +/- numsds standard
# deviations, vertically for the x variable and horizontally for the y variable.
# input_data: table of values; x and y select the ROWS holding the two variables
# NOTE(review): rows (not columns) are selected here, unlike plot_by_groups
#   where each row is one observation - confirm callers pass data in this layout.
# numsds: number of standard deviations at which to draw the lines
draw_cutoffs = function(input_data, x, y, numsds) {
  pcx = as.numeric(input_data[x, ])
  pcy = as.numeric(input_data[y, ])

  x_offset = numsds * sd(pcx)
  y_offset = numsds * sd(pcy)
  vlines = median(pcx) + c(-1, 1) * x_offset
  hlines = median(pcy) + c(-1, 1) * y_offset
  abline(v = vlines)
  abline(h = hlines)
}

# Following helper functions are used in the 'setup_sd_plot' function
# given a list of standard deviations, work out which points are n standard deviations away
# x1, x2: per-point standard-deviation counts for the two plotted variables
# n: the level of interest
# Returns the indices of the points at level n:
#   n == 1: both variables are at 1
#   n >= 2: at least one variable is at n and neither is above n
# Points belonging to a lower level have both values below n, so they can never
# appear in the level-n candidates; no recursion over the lower levels is needed.
get_sdset2d = function(x1, x2, n) {
  if (n == 1) {
    return(intersect(which(x1 == 1), which(x2 == 1)))
  }
  at_n = union(which(x1 == n), which(x2 == n))
  beyond_n = union(which(x1 > n), which(x2 > n))
  return(setdiff(at_n, beyond_n))
}

# work out how many standard deviations away from the sample median a single point is
# accuracy of this decreases for outliers, as the error in the estimated sd is
# multiplied
# point: a single numeric value
# x: the sample the median and standard deviation are taken from
# Returns the smallest i >= 1 such that |point - median(x)| <= i * sd(x).
# Stops with an error when that count is undefined (non-finite inputs, or a
# zero standard deviation with the point away from the median) instead of
# looping forever.
compute_numsds = function(point, x) {
  x_sd = sd(x)
  distance = abs(point - median(x))
  if (!is.finite(distance) || !is.finite(x_sd)) {
    stop("compute_numsds: the point and the standard deviation of x must be finite",
         call. = FALSE)
  }
  if (x_sd == 0 && distance > 0) {
    stop("compute_numsds: x has zero standard deviation but the point differs from its median",
         call. = FALSE)
  }
  # step outwards one standard deviation at a time until the point is covered
  threshold = x_sd
  i = 1
  while (distance > threshold) {
    i = i + 1
    threshold = threshold + x_sd
  }
  return(i)
}