diff R_functions/plotting_functions.R @ 0:64e75e21466e draft default tip

Uploaded
author pmac
date Wed, 01 Jun 2016 03:38:39 -0400
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/R_functions/plotting_functions.R	Wed Jun 01 03:38:39 2016 -0400
@@ -0,0 +1,225 @@
+## Plotting and grouping ##
+# input data: some number of 2d observations. Each row represents a single observation,
+# column 1 = variable 1, to be plotted on the x-axis,
+# column 2 = variable 2, to be plotted on the y-axis
+# groups: Integer vector with same number of entries as there are rows in the input data,
+# representing which group each observation belongs to. Negative numbers are not plotted
+# tags: the tag to put on the legend for each group
+# plot_colors: colors to use for each group
+# plot_symbols: symbols to use for each group
+# plot_title: as name suggests
+# plot_filename: if this is not null, graph is output to a png with the specified name
+plot_by_groups = function(input_data, groups, tags, plot_colors, plot_symbols, plot_title, plot_filename=NULL) {
+  if(!is.null(plot_filename)) {
+    png(plot_filename)
+  }
+  # leave some extra room on the RHS for the legend
+  par(mar=c(5.1, 4.1, 4.1, 8.1))
+  x = as.numeric(input_data[, 1])
+  y = as.numeric(input_data[, 2])
+  gids = sort(unique(groups[which(groups >= 0)]))
+  n = length(gids)
+  
+  # first set up the plot area to the correct dimensions
+  plot(x, y, col="white")
+  
+  for (i in 1:n) {
+    gid = gids[i]
+    pts_x = x[which(groups == gid)]
+    pts_y = y[which(groups == gid)]
+    pts_color = plot_colors[i]
+    pts_symbol = plot_symbols[i]
+    points(pts_x, pts_y, col=pts_color, pch=pts_symbol)
+  }
+  legend(x="topright",
+         xpd=TRUE,
+         inset=c(-0.3, 0),
+         col=plot_colors, 
+         pch=plot_symbols,
+         legend=tags,
+         text.col=plot_colors)
+  title(main=plot_title)
+  if(!is.null(plot_filename)) {
+    dev.off()
+  }
+}
+
+# Controls vs cases plot. Colour controls blue, cases red,
+# Samples which are neither control nor case are black.
+setup_cvc_plot = function(pca_data, control_tag, cases_tag) {
+  plot_info = list()
+  nsamples = length(pca_data$ids)
+  groups = rep(1, nsamples)
+  control_legend = paste0("CO: ", control_tag)
+  cases_legend = paste0("CA: ", cases_tag)
+  if (!is.null(control_tag)) {
+    groups[grep(control_tag, pca_data$ids)] = 2
+  } 
+  if (!is.null(cases_tag)) {
+    groups[grep(cases_tag, pca_data$ids)] = 3
+  } 
+  res = sort(unique(groups))
+  if (length(res) == 1) {
+    tags = c("UNKNOWN")
+    plot_colors = c("black")
+  } else if (length(res) == 3) {
+    tags = c("UNKNOWN", control_legend, cases_legend)
+    plot_colors = c("black", "blue", "red")
+  } else {
+    if (all(res == c(1, 2))) {
+      tags = c("UNKNOWN", control_legend)
+      plot_colors = c("black", "blue")
+    } else if (all(res == c(1, 3))) {
+      tags = c("UNKNOWN", cases_legend)
+      plot_colors = c("black", "red")
+    } else {
+      tags = c(control_legend, cases_legend)
+      plot_colors = c("blue", "red")
+    }
+  }
+  plot_info$groups = groups
+  plot_info$tags = tags
+  plot_info$plot_colors = plot_colors
+  plot_info$plot_symbols = rep(1, length(res))
+  plot_info$plot_title = "Control vs Cases Plot"
+  return(plot_info)
+}
+
+# outliers plot; colour outliers red, non-outliers green
+setup_ol_plot = function(pca_data, outliers) {
+  plot_info = list()
+  nsamples = dim(pca_data$values)[1]
+  groups = 1:nsamples
+  groups[outliers] = 1
+  groups[setdiff(1:nsamples, outliers)] = 2
+  plot_info$groups = groups
+  plot_info$tags = c("outliers", "good data")
+  plot_info$plot_colors = c("red", "green")
+  plot_info$plot_symbols = c(1, 20)
+  plot_info$plot_title = "Outliers Plot"
+  return(plot_info)
+}
+
+# standard deviations plot; colour samples by s.dev
+setup_sd_plot = function(pca_data) {
+  plot_info = list()
+  nsamples = dim(pca_data$values)[1]
+  pc1 = as.numeric(pca_data$values[, 1])
+  pc2 = as.numeric(pca_data$values[, 2])
+  pc1_sds = as.numeric(lapply(pc1, compute_numsds, pc1))
+  pc2_sds = as.numeric(lapply(pc2, compute_numsds, pc2))
+  
+  groups = 1:nsamples
+  groups[get_sdset2d(pc1_sds, pc2_sds, 1)] = 1
+  groups[get_sdset2d(pc1_sds, pc2_sds, 2)] = 2
+  groups[get_sdset2d(pc1_sds, pc2_sds, 3)] = 3
+  groups[union(which(pc1_sds > 3), which(pc2_sds > 3))] = 4
+  plot_info$groups = groups
+  plot_info$tags = c("SD = 1", "SD = 2", "SD = 3", "SD > 3")
+  plot_info$plot_colors = rainbow(4)
+  plot_info$plot_symbols = rep(20, 4)
+  plot_info$plot_title = "Standard Deviations Plot"
+  return(plot_info)
+}
+
+# Plot samples, with coloured clusters. Rejected clusters use
+# a cross symbol instead of a filled circle
+setup_cluster_plot = function(pca_data, clusters, rc=NULL) {
+  plot_info = list()
+  groups = clusters
+  ids = sort(unique(groups))
+  n = length(ids)
+  tags = 1:n
+  for (i in 1:n) {
+    tags[i] = sprintf("cluster %s", ids[i])
+  }
+  outliers = which(groups == 0)
+  if (length(outliers) != 0) {
+    tags[1] = "outliers"
+  }
+  plot_colors = rainbow(n)
+  plot_symbols = rep(20, n)
+  if (length(outliers) != 0) {
+    plot_symbols[1] = 1
+  }
+  # labelling for rejected clusters
+  if(!is.null(rc)) {
+    for(i in 1:n) {
+      if((ids[i] != 0) && (ids[i] %in% as.numeric(rc))) {
+        tags[i] = "rej. clust."
+        plot_symbols[i] = 4
+      }
+    }
+  }
+  plot_info$groups = groups
+  plot_info$tags = tags
+  plot_info$plot_colors = plot_colors
+  plot_info$plot_symbols = plot_symbols
+  plot_info$plot_title = "Cluster Plot"
+  return(plot_info)
+}
+
+# Plot samples, colouring by ethnicity. Different ethnicities also
+# have different symbols.
+setup_ethnicity_plot = function(pca_data, ethnicity_data) {
+  plot_info = list()
+  nsamples = dim(pca_data$values)[1]
+  eth = 1:nsamples
+  
+  for (i in 1:nsamples) {
+    sample_id = pca_data$ids[i]
+    eth[i] = as.character(ethnicity_data[sample_id, "population"])
+    if(is.na(eth[i])) {
+      eth[i] = "UNKNOWN"
+    }
+  }
+  n = length(unique(eth))
+  plot_info$groups = as.numeric(as.factor(eth))
+  plot_info$tags = sort(unique(eth))
+  plot_info$plot_colors = rainbow(n)
+  plot_info$plot_symbols = 1:n
+  plot_info$plot_title = "Ethnicity Plot"
+  return(plot_info)
+}
+
+draw_cutoffs = function(input_data, x, y, numsds) {
+  pcx = as.numeric(input_data[x, ])
+  pcy = as.numeric(input_data[y, ])
+  
+  vlines = c(median(pcx) - numsds*sd(pcx), 
+             median(pcx) + numsds*sd(pcx))
+  hlines = c(median(pcy) - numsds*sd(pcy), 
+             median(pcy) + numsds*sd(pcy))  
+  abline(v=vlines)
+  abline(h=hlines)
+}
+
+# Following helper functions are used in the 'setup_sd_plot' function
+# given a list of standard deviations, work out which points are n standard deviations away
+get_sdset2d = function(x1, x2, n) {
+  if (n == 1) {
+    ind = intersect(which(x1 == 1), which(x2 == 1))
+  } else {
+    lower = get_sdset2d(x1, x2, n - 1)
+    upper = union(which(x1 > n), which(x2 > n))
+    xset = union(lower, upper)
+    bigset = union(which(x1 == n), which(x2 == n))
+    ind = setdiff(bigset, xset)
+  }
+  return(ind)
+}
+
+# work out how many standard deviations away from the sample median a single point is
+# accuracy of this decreases for outliers, as the error in the estimated sd is 
+# multiplied
+compute_numsds = function(point, x) {
+  x_sd = sd(x)
+  sum = x_sd
+  m = median(x)
+  i = 1
+  while(abs(point - m) > sum) {
+    i = i + 1
+    sum = sum + x_sd
+  }
+  return(i)
+}
\ No newline at end of file