Mercurial > repos > jason-ellul > iterativepca
comparison R_functions/plotting_functions.R @ 0:cb54350e76ae draft default tip
Uploaded
| author | jason-ellul |
|---|---|
| date | Wed, 01 Jun 2016 03:24:56 -0400 |
| parents | |
| children |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:cb54350e76ae |
|---|---|
## Plotting and grouping ##
# Scatter-plot 2d observations, drawing every group with its own colour
# and symbol, plus a legend to the right of the plot area.
#
# input_data:    some number of 2d observations, one row per observation;
#                column 1 is plotted on the x-axis, column 2 on the y-axis
#                (both are coerced with as.numeric)
# groups:        vector with one entry per row of input_data giving the group
#                each observation belongs to. Negative numbers are not plotted
# tags:          the tag to put on the legend for each group
# plot_colors:   colors to use for each group
# plot_symbols:  symbols (pch) to use for each group
# plot_title:    main title of the plot
# plot_filename: if this is not NULL, the graph is output to a png with the
#                specified name; otherwise the current device is used
#
# The i-th entry of plot_colors / plot_symbols is applied to the i-th smallest
# non-negative group id found in 'groups'.
# Returns NULL invisibly.
plot_by_groups <- function(input_data, groups, tags, plot_colors, plot_symbols,
                           plot_title, plot_filename = NULL) {
  if (!is.null(plot_filename)) {
    png(plot_filename)
    # close the png device even when a drawing call below fails, so an
    # error does not leave an open device behind
    on.exit(dev.off(), add = TRUE)
  }
  # leave some extra room on the RHS for the legend. The setting is left in
  # place so that anything added to the plot afterwards uses the same region.
  par(mar = c(5.1, 4.1, 4.1, 8.1))
  x <- as.numeric(input_data[, 1])
  y <- as.numeric(input_data[, 2])
  gids <- sort(unique(groups[which(groups >= 0)]))

  # first set up the plot area to the correct dimensions without drawing
  # any points
  plot(x, y, type = "n")

  # seq_along() (rather than 1:n) makes this a no-op when nothing is plotted
  for (i in seq_along(gids)) {
    in_group <- which(groups == gids[i])
    points(x[in_group], y[in_group],
           col = plot_colors[i], pch = plot_symbols[i])
  }
  # xpd + negative inset push the legend into the enlarged right margin
  legend(x = "topright",
         xpd = TRUE,
         inset = c(-0.3, 0),
         col = plot_colors,
         pch = plot_symbols,
         legend = tags,
         text.col = plot_colors)
  title(main = plot_title)
  invisible(NULL)
}
| 46 | |
# Controls vs cases plot. Colour controls blue, cases red,
# Samples which are neither control nor case are black.
#
# pca_data:    list whose 'ids' element holds the sample ids
# control_tag: pattern (passed to grep) identifying control ids, or NULL
# cases_tag:   pattern (passed to grep) identifying case ids, or NULL
#
# Samples are put in group 1 (unknown), 2 (control) or 3 (case); an id that
# matches both patterns ends up as a case because cases are assigned last.
# Returns a list with groups, tags, plot_colors, plot_symbols and plot_title.
# tags / plot_colors / plot_symbols contain one entry per group that actually
# occurs, ordered by group number, which is the order plot_by_groups uses.
setup_cvc_plot <- function(pca_data, control_tag, cases_tag) {
  nsamples <- length(pca_data$ids)
  groups <- rep(1, nsamples)
  if (!is.null(control_tag)) {
    groups[grep(control_tag, pca_data$ids)] <- 2
  }
  if (!is.null(cases_tag)) {
    groups[grep(cases_tag, pca_data$ids)] <- 3
  }

  # legend text and colour for group 1, 2 and 3 respectively
  control_legend <- paste0("CO: ", control_tag)
  cases_legend <- paste0("CA: ", cases_tag)
  tag_by_group <- c("UNKNOWN", control_legend, cases_legend)
  color_by_group <- c("black", "blue", "red")

  # keep only the groups that are present; indexing by group number (instead
  # of guessing from the number of distinct groups) keeps every combination,
  # including "controls only", "cases only" and "no samples", correct
  present <- sort(unique(groups))

  plot_info <- list()
  plot_info$groups <- groups
  plot_info$tags <- tag_by_group[present]
  plot_info$plot_colors <- color_by_group[present]
  plot_info$plot_symbols <- rep(1, length(present))
  plot_info$plot_title <- "Control vs Cases Plot"
  plot_info
}
| 87 | |
# outliers plot; colour outliers red, non-outliers green
#
# pca_data: list whose 'values' element has one row per sample
# outliers: indices of the samples that are outliers
#
# Outliers are put in group 1, all other samples in group 2.
# Returns a list with groups, tags, plot_colors, plot_symbols and plot_title.
# tags / plot_colors / plot_symbols only describe the groups that occur:
# plot_by_groups hands out colours by position among the groups present, so
# listing an absent "outliers" entry would colour the good data red.
setup_ol_plot <- function(pca_data, outliers) {
  nsamples <- dim(pca_data$values)[1]
  groups <- rep(2, nsamples)
  groups[outliers] <- 1
  present <- sort(unique(groups))

  plot_info <- list()
  plot_info$groups <- groups
  plot_info$tags <- c("outliers", "good data")[present]
  plot_info$plot_colors <- c("red", "green")[present]
  plot_info$plot_symbols <- c(1, 20)[present]
  plot_info$plot_title <- "Outliers Plot"
  plot_info
}
| 102 | |
# standard deviations plot; colour samples by s.dev
#
# pca_data: list whose 'values' element has one row per sample; columns 1
#           and 2 are used
#
# For both columns the number of standard deviations of every sample is
# obtained from compute_numsds. A sample goes into group 1, 2 or 3 according
# to get_sdset2d, and into group 4 when either column is more than 3 away.
# Returns a list with groups, tags, plot_colors, plot_symbols and plot_title.
# tags / plot_colors / plot_symbols only describe the groups that occur:
# plot_by_groups hands out colours by position among the groups present, so
# an empty band would otherwise shift the colours of all later bands.
setup_sd_plot <- function(pca_data) {
  nsamples <- dim(pca_data$values)[1]
  pc1 <- as.numeric(pca_data$values[, 1])
  pc2 <- as.numeric(pca_data$values[, 2])
  pc1_sds <- vapply(pc1, compute_numsds, numeric(1), pc1)
  pc2_sds <- vapply(pc2, compute_numsds, numeric(1), pc2)

  groups <- numeric(nsamples)
  for (band in 1:3) {
    groups[get_sdset2d(pc1_sds, pc2_sds, band)] <- band
  }
  groups[union(which(pc1_sds > 3), which(pc2_sds > 3))] <- 4
  present <- sort(unique(groups))

  plot_info <- list()
  plot_info$groups <- groups
  plot_info$tags <- c("SD = 1", "SD = 2", "SD = 3", "SD > 3")[present]
  plot_info$plot_colors <- rainbow(4)[present]
  plot_info$plot_symbols <- rep(20, length(present))
  plot_info$plot_title <- "Standard Deviations Plot"
  plot_info
}
| 124 | |
# Plot samples, with coloured clusters. Rejected clusters use
# a cross symbol instead of a filled circle
#
# pca_data: not used by this function
# clusters: cluster id of every sample; id 0 is labelled "outliers"
# rc:       ids of rejected clusters, or NULL; id 0 is never marked rejected
#
# Returns a list with groups, tags, plot_colors, plot_symbols and plot_title.
# One legend entry is produced per distinct non-negative cluster id, in
# increasing id order. Negative ids get no entry because plot_by_groups does
# not plot them and hands out colours by position among the ids it does plot.
setup_cluster_plot <- function(pca_data, clusters, rc = NULL) {
  groups <- clusters
  ids <- sort(unique(groups[which(groups >= 0)]))
  n <- length(ids)

  # sprintf is vectorised, so this also yields an empty vector for n == 0
  tags <- sprintf("cluster %s", ids)
  plot_symbols <- rep(20, n)

  # cluster 0 holds the outliers: open circle and its own label
  is_outlier_group <- ids == 0
  tags[is_outlier_group] <- "outliers"
  plot_symbols[is_outlier_group] <- 1

  # labelling for rejected clusters
  if (!is.null(rc)) {
    rejected <- !is_outlier_group & ids %in% as.numeric(rc)
    tags[rejected] <- "rej. clust."
    plot_symbols[rejected] <- 4
  }

  plot_info <- list()
  plot_info$groups <- groups
  plot_info$tags <- tags
  plot_info$plot_colors <- rainbow(n)
  plot_info$plot_symbols <- plot_symbols
  plot_info$plot_title <- "Cluster Plot"
  plot_info
}
| 161 | |
# Plot samples, colouring by ethnicity. Different ethnicities also
# have different symbols.
#
# pca_data:       list with 'values' (one row per sample) and 'ids' (the
#                 sample ids)
# ethnicity_data: table indexed by sample id with a "population" column
#
# Samples whose population comes back as NA are labelled "UNKNOWN".
# Returns a list with groups, tags, plot_colors, plot_symbols and plot_title.
# Group numbers are the factor codes of the population labels, so the i-th
# tag (labels in sorted order) belongs to group i.
setup_ethnicity_plot <- function(pca_data, ethnicity_data) {
  nsamples <- dim(pca_data$values)[1]

  # look all samples up in one go; seq_len() keeps zero samples a valid input
  sample_ids <- pca_data$ids[seq_len(nsamples)]
  eth <- as.character(ethnicity_data[sample_ids, "population"])
  eth[is.na(eth)] <- "UNKNOWN"

  labels <- sort(unique(eth))
  n <- length(labels)

  plot_info <- list()
  plot_info$groups <- as.numeric(as.factor(eth))
  plot_info$tags <- labels
  plot_info$plot_colors <- rainbow(n)
  plot_info$plot_symbols <- seq_len(n)
  plot_info$plot_title <- "Ethnicity Plot"
  plot_info
}
| 184 | |
# Add cutoff lines to the current plot at median +/- numsds standard
# deviations: vertical lines for the first series, horizontal for the second.
#
# input_data: table holding the two series
# x, y:       row selectors of the series used for the vertical / horizontal
#             lines. NOTE(review): rows are indexed here, whereas
#             plot_by_groups reads columns - confirm the layout callers pass
# numsds:     number of standard deviations between the median and a cutoff
draw_cutoffs <- function(input_data, x, y, numsds) {
  # lower and upper cutoff for one numeric series
  cutoff_pair <- function(series) {
    centre <- median(series)
    spread <- numsds * sd(series)
    c(centre - spread, centre + spread)
  }

  abline(v = cutoff_pair(as.numeric(input_data[x, ])))
  abline(h = cutoff_pair(as.numeric(input_data[y, ])))
}
| 196 | |
# Following helper functions are used in the 'setup_sd_plot' function
# Given two vectors of standard-deviation counts (one per dimension), work
# out which points are n standard deviations away.
#
# x1, x2: standard-deviation counts of every point in dimension 1 and 2
# n:      the number of standard deviations asked for
#
# For n == 1 a point qualifies when both counts equal 1. For larger n a point
# qualifies when either count equals n, unless it already belongs to the set
# for n - 1 or has a count above n. Returns the indices of those points.
get_sdset2d <- function(x1, x2, n) {
  if (n == 1) {
    return(intersect(which(x1 == 1), which(x2 == 1)))
  }

  # points to leave out: already claimed by n - 1, or further out than n
  claimed_below <- get_sdset2d(x1, x2, n - 1)
  further_out <- union(which(x1 > n), which(x2 > n))
  excluded <- union(claimed_below, further_out)

  candidates <- union(which(x1 == n), which(x2 == n))
  return(setdiff(candidates, excluded))
}
| 211 | |
# Work out how many standard deviations away from the sample median a single
# point is. The accuracy of this decreases for outliers, as the error in the
# estimated sd is multiplied.
#
# point: the value to classify
# x:     the sample the median and standard deviation are taken from
#
# Returns the smallest count k >= 1 for which the distance of 'point' to the
# median does not exceed the k-th running total of the standard deviation.
compute_numsds <- function(point, x) {
  step <- sd(x)
  centre <- median(x)
  distance <- abs(point - centre)

  # widen the band one standard deviation at a time until it covers the point
  count <- 1
  reach <- step
  while (distance > reach) {
    count <- count + 1
    reach <- reach + step
  }
  return(count)
}
