Mercurial > repos > davidmurphy > codonlogo
comparison test_weblogo.py @ 4:4d47ab2b7bcc
Uploaded
| author | davidmurphy |
|---|---|
| date | Fri, 13 Jan 2012 07:18:19 -0500 |
| parents | c55bdc2fb9fa |
| children |
comparison
equal
deleted
inserted
replaced
| 3:09d2dac9ef73 | 4:4d47ab2b7bcc |
|---|---|
| 1 #!/usr/bin/env python | |
| 2 | |
| 3 # Copyright (c) 2006, The Regents of the University of California, through | |
| 4 # Lawrence Berkeley National Laboratory (subject to receipt of any required | |
| 5 # approvals from the U.S. Dept. of Energy). All rights reserved. | |
| 6 | |
| 7 # This software is distributed under the new BSD Open Source License. | |
| 8 # <http://www.opensource.org/licenses/bsd-license.html> | |
| 9 # | |
| 10 # Redistribution and use in source and binary forms, with or without | |
| 11 # modification, are permitted provided that the following conditions are met: | |
| 12 # | |
| 13 # (1) Redistributions of source code must retain the above copyright notice, | |
| 14 # this list of conditions and the following disclaimer. | |
| 15 # | |
| 16 # (2) Redistributions in binary form must reproduce the above copyright | |
| 17 # notice, this list of conditions and the following disclaimer in the | |
| 18 # documentation and or other materials provided with the distribution. | |
| 19 # | |
| 20 # (3) Neither the name of the University of California, Lawrence Berkeley | |
| 21 # National Laboratory, U.S. Dept. of Energy nor the names of its contributors | |
| 22 # may be used to endorse or promote products derived from this software | |
| 23 # without specific prior written permission. | |
| 24 # | |
| 25 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |
| 26 # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |
| 27 # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |
| 28 # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE | |
| 29 # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | |
| 30 # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | |
| 31 # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | |
| 32 # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | |
| 33 # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | |
| 34 # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE | |
| 35 # POSSIBILITY OF SUCH DAMAGE. | |
| 36 | |
| 37 | |
| 38 import unittest | |
| 39 | |
| 40 import weblogolib | |
| 41 from weblogolib import * | |
| 42 from weblogolib import parse_prior, GhostscriptAPI | |
| 43 from weblogolib.color import * | |
| 44 from weblogolib.colorscheme import * | |
| 45 from StringIO import StringIO | |
| 46 import sys | |
| 47 | |
| 48 from numpy import array, asarray, float64, ones, zeros, int32,all,any, shape | |
| 49 import numpy as na | |
| 50 | |
| 51 from corebio import seq_io | |
| 52 from corebio.seq import * | |
| 53 | |
| 54 # python2.3 compatibility | |
| 55 from corebio._future.subprocess import * | |
| 56 from corebio._future import resource_stream | |
| 57 | |
| 58 from corebio.moremath import entropy | |
| 59 from math import log, sqrt | |
# All 64 codons in lexicographic order, over the RNA bases (with U) and the
# DNA bases (with T).  test_which expects which_alphabet() to report these.
# Generator expressions (not list comprehensions) keep the loop variables
# from leaking into module scope under Python 2.
codon_alphabetU = list(x + y + z
                       for x in 'ACGU' for y in 'ACGU' for z in 'ACGU')
codon_alphabetT = list(x + y + z
                       for x in 'ACGT' for y in 'ACGT' for z in 'ACGT')
| 62 | |
| 63 | |
def testdata_stream(name):
    """Return a readable stream for *name* inside the tests/data directory.

    The lookup is delegated to corebio's resource_stream, anchored on this
    module's __name__ and __file__.
    """
    data_path = 'tests/data/' + name
    return resource_stream(__name__, data_path, __file__)
| 66 | |
class test_logoformat(unittest.TestCase):
    """Smoke test for default LogoOptions construction."""

    def test_options(self):
        # Passes as long as the default constructor does not raise.
        LogoOptions()
| 71 | |
| 72 | |
class test_ghostscript(unittest.TestCase):
    """Smoke test for the GhostscriptAPI wrapper."""

    def test_version(self):
        # Passes as long as constructing the wrapper and reading its
        # version attribute do not raise.
        gs_version = GhostscriptAPI().version
| 76 | |
| 77 | |
| 78 | |
class test_parse_prior(unittest.TestCase):
    """Tests for parse_prior(spec, alphabet[, weight]).

    The spec may be None/'none', 'equiprobable', 'auto'/'automatic', a
    percentage, a fraction, or an explicit dict literal.  The result is
    compared element-wise against an array of per-symbol prior weights,
    or is None when no prior is requested.
    """

    # Python 2.3 compatibility, using the same alias as test_color and
    # test_Dirichlet.  This replaces a hand-written override whose parameter
    # shadowed the builtin ``bool``.
    assertTrue = unittest.TestCase.failUnless

    def test_parse_prior_none(self):
        # 'none' is matched case-insensitively; None input also gives None.
        self.assertEquals(None,
                          parse_prior(None, unambiguous_protein_alphabet))
        self.assertEquals(None,
                          parse_prior('none', unambiguous_protein_alphabet))
        self.assertEquals(None,
                          parse_prior('noNe', None))

    def test_parse_prior_equiprobable(self):
        # Without an explicit weight the total weight equals the alphabet
        # size (20 for protein); whitespace and case are ignored.
        self.assertTrue(all(20. * equiprobable_distribution(20) ==
                            parse_prior('equiprobable',
                                        unambiguous_protein_alphabet)))

        self.assertTrue(
            all(1.2 * equiprobable_distribution(3)
                == parse_prior(' equiprobablE ', Alphabet('123'), 1.2)))

    def test_parse_prior_percentage(self):
        # For DNA the percentage behaves as a GC fraction: 40% gives
        # A=T=0.3, C=G=0.2 (assuming the alphabet is ordered ACGT).
        self.assertTrue(all(equiprobable_distribution(4)
                            == parse_prior('50%',
                                           unambiguous_dna_alphabet, 1.)))

        self.assertTrue(all(equiprobable_distribution(4)
                            == parse_prior(' 50.0 % ',
                                           unambiguous_dna_alphabet, 1.)))

        self.assertTrue(all(array((0.3, 0.2, 0.2, 0.3), float64)
                            == parse_prior(' 40.0 % ',
                                           unambiguous_dna_alphabet, 1.)))

    def test_parse_prior_float(self):
        # Same as the percentage form, written as a fraction in [0, 1].
        self.assertTrue(all(equiprobable_distribution(4)
                            == parse_prior('0.5',
                                           unambiguous_dna_alphabet, 1.)))

        self.assertTrue(all(equiprobable_distribution(4)
                            == parse_prior(' 0.500 ',
                                           unambiguous_dna_alphabet, 1.)))

        self.assertTrue(all(array((0.3, 0.2, 0.2, 0.3), float64)
                            == parse_prior(' 0.40 ',
                                           unambiguous_dna_alphabet, 1.)))

    def test_auto(self):
        # 'auto' and 'automatic' are synonyms.
        self.assertTrue(all(4. * equiprobable_distribution(4) ==
                            parse_prior('auto', unambiguous_dna_alphabet)))
        self.assertTrue(all(4. * equiprobable_distribution(4) ==
                            parse_prior('automatic',
                                        unambiguous_dna_alphabet)))

    def test_weight(self):
        # The optional third argument scales the total prior weight.
        self.assertTrue(all(4. * equiprobable_distribution(4) ==
                            parse_prior('automatic',
                                        unambiguous_dna_alphabet)))
        self.assertTrue(all(123.123 * equiprobable_distribution(4) ==
                            parse_prior('auto', unambiguous_dna_alphabet,
                                        123.123)))

    def test_explicit(self):
        # Explicit counts are normalised so the weights sum to the default
        # total of 4 for DNA.
        s = "{'A':10, 'C':40, 'G':40, 'T':10}"
        p = array((10, 40, 40, 10), float64) * 4. / 100.
        self.assertTrue(all(
            p == parse_prior(s, unambiguous_dna_alphabet)))
| 137 | |
| 138 | |
class test_logooptions(unittest.TestCase):
    """Construction and repr of LogoOptions."""

    def test_create(self):
        opt = LogoOptions()
        opt.small_fontsize = 10
        # repr() must not fail on a modified options object.
        repr(opt)

        opt = LogoOptions(title="sometitle")
        # A TestCase assertion, not a bare ``assert``: bare asserts are
        # stripped under ``python -O`` and the test would silently pass.
        self.assertEqual(opt.title, "sometitle")
| 147 | |
class test_logosize(unittest.TestCase):
    """Construction and repr of LogoSize."""

    def test_create(self):
        s = LogoSize(101.0, 10.0)
        # A TestCase assertion survives ``python -O``; a bare assert does not.
        self.assertEqual(s.stack_width, 101.0)
        # repr() must not raise.
        repr(s)
| 153 | |
| 154 | |
class test_seqlogo(unittest.TestCase):
    """Command-line tests that run the ./weblogo script as a subprocess."""

    # FIXME: The version of python used by Popen may not be the
    # same as that used to run this test.
    def _exec(self, args, outputtext, returncode=0, stdin=None):
        """Run ./weblogo with *args* and check the result.

        args       -- list of extra command-line arguments.
        outputtext -- iterable of substrings that must all appear in stdout.
        returncode -- expected exit status (default 0).  When 0 is
                      expected, stderr must also be empty.
        stdin      -- stream fed to the process.  Defaults to the cap.fa
                      test data, which this method then closes itself.
        """
        opened_here = not stdin
        if opened_here:
            stdin = testdata_stream("cap.fa")
        try:
            p = Popen(["./weblogo"] + args,
                      stdin=stdin, stdout=PIPE, stderr=PIPE)
            (out, err) = p.communicate()
        finally:
            # Only close the stream we opened; callers own their own.
            if opened_here:
                stdin.close()

        if returncode == 0 and p.returncode != 0:
            # Show the tool's own error output to explain the failure
            # (including negative codes, i.e. death by signal).
            sys.stderr.write(err)
        self.assertEquals(returncode, p.returncode)
        if returncode == 0:
            self.assertEquals(len(err), 0)

        for item in outputtext:
            self.failUnless(item in out,
                            "%r not found in weblogo output" % (item,))

    def test_malformed_options(self):
        # Bad arguments must make weblogo exit with status 2.
        self._exec(["--notarealoption"], [], 2)
        self._exec(["extrajunk"], [], 2)
        self._exec(["-I"], [], 2)

    def test_help_option(self):
        self._exec(["-h"], ["options"])
        self._exec(["--help"], ["options"])

    def test_version_option(self):
        # Wrap the version in a list: passing the bare string would make
        # _exec check each character separately instead of the whole
        # version string.
        self._exec(['--version'], [weblogolib.__version__])

    def test_default_build(self):
        self._exec([], ["%%Title: Sequence Logo:"])

    # Format options
    def test_width(self):
        self._exec(['-W', '1234'], ["/stack_width 1234"])
        self._exec(['--stack-width', '1234'], ["/stack_width 1234"])

    def test_height(self):
        self._exec(['-H', '1234'], ["/stack_height 1234"])
        self._exec(['--stack-height', '1234'], ["/stack_height 1234"])

    def test_stacks_per_line(self):
        self._exec(['-n', '7'], ["/stacks_per_line 7 def"])
        self._exec(['--stacks-per-line', '7'], ["/stacks_per_line 7 def"])

    def test_title(self):
        # A non-empty title is emitted and shown; an empty one is hidden.
        self._exec(['-t', '3456'], ['/logo_title (3456) def',
                                    '/show_title True def'])
        self._exec(['-t', ''], ['/logo_title () def',
                                '/show_title False def'])
        self._exec(['--title', '3456'], ['/logo_title (3456) def',
                                         '/show_title True def'])
| 213 | |
| 214 | |
| 215 | |
| 216 | |
| 217 | |
class test_which(unittest.TestCase):
    """Checks which_alphabet() against the codon test data files."""

    def test_which(self):
        # (data file, codon alphabet which_alphabet() should report).
        # Protein cases for cox2.msf and Rv3829c.fasta
        # (unambiguous_protein_alphabet) are currently disabled.
        expectations = (
            ('cap.fa', codon_alphabetT),
            ('capu.fa', codon_alphabetU),
        )
        # Every file is parsed before any assertion runs.
        parsed = [(seq_io.read(testdata_stream(fname)), alphabet)
                  for fname, alphabet in expectations]
        for seqs, alphabet in parsed:
            self.failUnlessEqual(which_alphabet(seqs), alphabet)
| 229 | |
| 230 | |
| 231 | |
| 232 | |
class test_colorscheme(unittest.TestCase):
    """Tests for ColorGroup and ColorScheme."""

    def test_colorgroup(self):
        group = ColorGroup("ABC", "black", "Because")
        self.assertEquals(group.description, "Because")

    def test_colorscheme(self):
        groups = [
            ColorGroup("G", "orange"),
            ColorGroup("TU", "red"),
            ColorGroup("C", "blue"),
            ColorGroup("A", "green"),
        ]
        scheme = ColorScheme(groups,
                             title="title",
                             description="description")

        # A symbol listed in a group takes that group's color ...
        self.assertEquals(scheme.color('A'), Color.by_name("green"))
        # ... and a symbol in no group falls back to the default color.
        self.assertEquals(scheme.color('X'), scheme.default_color)
| 252 | |
| 253 | |
| 254 | |
class test_color(unittest.TestCase):
    """Tests for Color: named colors, RGB/HSL constructors, string parsing,
    clipping and equality."""
    # 2.3 Python compatibility
    assertTrue = unittest.TestCase.failUnless
    assertFalse = unittest.TestCase.failIf

    def test_color_names(self):
        names = Color.names()
        self.failUnlessEqual(len(names), 147)

        for n in names:
            # Identity test: ``!= None`` may dispatch to Color's own
            # comparison methods.
            self.assertTrue(Color.by_name(n) is not None)

    def test_color_components(self):
        white = Color.by_name("white")
        self.failUnlessEqual(1.0, white.red)
        self.failUnlessEqual(1.0, white.green)
        self.failUnlessEqual(1.0, white.blue)

        # Float components are kept as given ...
        c = Color(0.3, 0.4, 0.2)
        self.failUnlessEqual(0.3, c.red)
        self.failUnlessEqual(0.4, c.green)
        self.failUnlessEqual(0.2, c.blue)

        # ... integer components are scaled by 1/255.
        c = Color(0, 128, 0)
        self.failUnlessEqual(0.0, c.red)
        self.failUnlessEqual(128. / 255., c.green)
        self.failUnlessEqual(0.0, c.blue)

    def test_color_from_rgb(self):
        white = Color.by_name("white")

        self.failUnlessEqual(white, Color(1., 1., 1.))
        self.failUnlessEqual(white, Color(255, 255, 255))
        self.failUnlessEqual(white, Color.from_rgb(1., 1., 1.))
        self.failUnlessEqual(white, Color.from_rgb(255, 255, 255))

    def test_color_from_hsl(self):
        red = Color.by_name("red")
        lime = Color.by_name("lime")
        saddlebrown = Color.by_name("saddlebrown")
        darkgreen = Color.by_name("darkgreen")
        blue = Color.by_name("blue")

        self.failUnlessEqual(red, Color.from_hsl(0, 1.0, 0.5))
        self.failUnlessEqual(lime, Color.from_hsl(120, 1.0, 0.5))
        self.failUnlessEqual(blue, Color.from_hsl(240, 1.0, 0.5))
        self.failUnlessEqual(Color.by_name("gray"), Color.from_hsl(0, 0, 0.5))

        self.failUnlessEqual(saddlebrown, Color.from_hsl(25, 0.76, 0.31))

        self.failUnlessEqual(darkgreen, Color.from_hsl(120, 1.0, 0.197))

    def test_color_by_name(self):
        # Lookup ignores case and surrounding whitespace.
        white = Color.by_name("white")
        self.failUnlessEqual(white, Color.by_name("white"))
        self.failUnlessEqual(white, Color.by_name("WHITE"))
        self.failUnlessEqual(white, Color.by_name(" wHiTe \t\n\t"))

        self.failUnlessEqual(Color(255, 255, 240), Color.by_name("ivory"))
        self.failUnlessEqual(Color(70, 130, 180), Color.by_name("steelblue"))

        self.failUnlessEqual(Color(0, 128, 0), Color.by_name("green"))

    def test_color_from_invalid_name(self):
        self.failUnlessRaises(ValueError, Color.by_name, "not_a_color")

    def test_color_clipping(self):
        # Out-of-range components are clamped to [0, 1].
        red = Color.by_name("red")
        self.failUnlessEqual(red, Color(255, 0, 0))
        self.failUnlessEqual(red, Color(260, -10, 0))
        self.failUnlessEqual(red, Color(1.1, -0., -1.))

        self.failUnlessEqual(Color(1.0001, 213.0, 1.2).red, 1.0)
        self.failUnlessEqual(Color(-0.001, -2183.0, -1.0).red, 0.0)
        self.failUnlessEqual(Color(1.0001, 213.0, 1.2).green, 1.0)
        self.failUnlessEqual(Color(-0.001, -2183.0, -1.0).green, 0.0)
        self.failUnlessEqual(Color(1.0001, 213.0, 1.2).blue, 1.0)
        self.failUnlessEqual(Color(-0.001, -2183.0, -1.0).blue, 0.0)

    def test_color_fail_on_mixed_type(self):
        # Mixing int and float components is rejected.
        self.failUnlessRaises(TypeError, Color.from_rgb, 1, 1, 1.0)
        self.failUnlessRaises(TypeError, Color.from_rgb, 1.0, 1, 1.0)

    def test_color_red(self):
        # Check Usage comment in Color
        red = Color.by_name("red")
        self.failUnlessEqual(red, Color(255, 0, 0))
        self.failUnlessEqual(red, Color(1., 0., 0.))

        self.failUnlessEqual(red, Color.from_rgb(1., 0., 0.))
        self.failUnlessEqual(red, Color.from_rgb(255, 0, 0))
        self.failUnlessEqual(red, Color.from_hsl(0., 1., 0.5))

        self.failUnlessEqual(red, Color.from_string("red"))
        self.failUnlessEqual(red, Color.from_string("RED"))
        self.failUnlessEqual(red, Color.from_string("#F00"))
        self.failUnlessEqual(red, Color.from_string("#FF0000"))
        self.failUnlessEqual(red, Color.from_string("rgb(255, 0, 0)"))
        self.failUnlessEqual(red, Color.from_string("rgb(100%, 0%, 0%)"))
        self.failUnlessEqual(red, Color.from_string("hsl(0, 100%, 50%)"))

    def test_color_from_string(self):
        red = Color(255, 0, 0)
        skyblue = Color(135, 206, 235)

        red_strings = ("red",
                       "ReD",
                       "RED",
                       "   Red \t",
                       "#F00",
                       "#FF0000",
                       "rgb(255, 0, 0)",
                       "rgb(100%, 0%, 0%)",
                       "hsl(0, 100%, 50%)")

        for s in red_strings:
            self.failUnlessEqual(red, Color.from_string(s))

        skyblue_strings = ("skyblue",
                           "SKYBLUE",
                           "  \t\n SkyBlue  \t",
                           "#87ceeb",
                           "rgb(135,206,235)")

        for s in skyblue_strings:
            self.failUnlessEqual(skyblue, Color.from_string(s))

    def test_color_equality(self):
        c1 = Color(123, 99, 12)
        c2 = Color(123, 99, 12)

        self.failUnlessEqual(c1, c2)
| 403 | |
| 404 | |
| 405 | |
| 406 | |
| 407 | |
| 408 | |
class test_Dirichlet(unittest.TestCase):
    """Tests for the Dirichlet class: moments, covariance, and entropy
    statistics, partly checked against Monte Carlo samples."""
    # 2.3 Python compatibility
    assertTrue = unittest.TestCase.failUnless
    assertFalse = unittest.TestCase.failIf

    def test_init(self):
        # Construction alone must not raise.
        Dirichlet((1, 1, 1, 1,))

    def test_random(self):

        def do_test(alpha, samples=1000):
            """Compare the sample mean/variance of entropy(d.sample()) with
            d.mean_entropy() and d.variance_entropy()."""
            ent = zeros((samples,), float64)
            d = Dirichlet(alpha)
            for s in range(samples):
                ent[s] = entropy(d.sample())

            m = mean(ent)
            v = var(ent)

            dm = d.mean_entropy()
            dv = d.variance_entropy()

            # Four standard errors of the sample mean.  The same bound is
            # reused for the variance, where it is only a rough tolerance.
            error = 4. * sqrt(v / samples)
            self.assertTrue(abs(m - dm) < error)
            self.assertTrue(abs(v - dv) < error)  # dodgy error estimate

        do_test((1., 1.))
        do_test((2., 1.))
        do_test((3., 1.))
        do_test((4., 1.))
        do_test((5., 1.))
        do_test((6., 1.))

        do_test((20., 20.))
        do_test((1., 1., 1., 1., 1., 1., 1., 1., 1., 1.))
        do_test((.1, .1, .1, .1, .1, .1, .1, .1, .1, .1))
        do_test((.01, .01, .01, .01, .01, .01, .01, .01, .01, .01))
        do_test((2.0, 6.0, 1.0, 1.0))

    def test_mean(self):
        # A symmetric alpha gives a uniform mean of 1/K.
        alpha = ones((10,), float64) * 23.
        d = Dirichlet(alpha)
        m = d.mean()
        self.assertAlmostEqual(m[2], 1. / 10)
        self.assertAlmostEqual(sum(m), 1.0)

    def test_covariance(self):
        # Expected values follow the standard Dirichlet covariance with
        # alpha_0 = sum(alpha) = 4:
        #   cov[i,i] = a_i (a_0 - a_i) / (a_0^2 (a_0 + 1))
        #   cov[i,j] = -a_i a_j / (a_0^2 (a_0 + 1))
        alpha = ones((4,), float64)
        d = Dirichlet(alpha)
        cv = d.covariance()
        self.assertEqual(cv.shape, (4, 4))
        self.assertAlmostEqual(cv[0, 0], 1.0 * (1.0 - 1. / 4.0) / (4.0 * 5.0))
        self.assertAlmostEqual(cv[0, 1], - 1 / (4. * 4. * 5.))

    def test_mean_x(self):
        # mean_x is a float weighted sum, so compare approximately rather
        # than with exact equality.
        alpha = (1.0, 2.0, 3.0, 4.0)
        xx = (2.0, 2.0, 2.0, 2.0)
        m = Dirichlet(alpha).mean_x(xx)
        self.assertAlmostEqual(m, 2.0)

        alpha = (1.0, 1.0, 1.0, 1.0)
        xx = (2.0, 3.0, 4.0, 3.0)
        m = Dirichlet(alpha).mean_x(xx)
        self.assertAlmostEqual(m, 3.0)

    def test_variance_x(self):
        alpha = (1.0, 1.0, 1.0, 1.0)
        xx = (2.0, 2.0, 2.0, 2.0)
        v = Dirichlet(alpha).variance_x(xx)
        self.assertAlmostEquals(v, 0.0)

        alpha = (1.0, 2.0, 3.0, 4.0)
        xx = (2.0, 0.0, 1.0, 10.0)
        v = Dirichlet(alpha).variance_x(xx)
        # TODO: Don't actually know if this is correct

    def test_relative_entropy(self):
        alpha = (2.0, 10.0, 1.0, 1.0)
        d = Dirichlet(alpha)
        pvec = (0.1, 0.2, 0.3, 0.4)

        rent = d.mean_relative_entropy(pvec)
        vrent = d.variance_relative_entropy(pvec)
        low, high = d.interval_relative_entropy(pvec, 0.95)

        # This test can fail randomly, but the precision from a few
        # thousand samples is low. Increasing samples, 1000->2000
        samples = 2000
        sent = zeros((samples,), float64)

        # Monte Carlo estimate of sum_k post[k] * log(post[k] / pvec[k]).
        for s in range(samples):
            post = d.sample()
            e = -entropy(post)
            for k in range(4):
                e += - post[k] * log(pvec[k])
            sent[s] = e
        sent.sort()
        self.assertTrue(abs(sent.mean() - rent) < 4. * sqrt(vrent))
        self.assertAlmostEqual(sent.std(), sqrt(vrent), 1)
        # Compare the reported 95% interval with the empirical 2.5% and
        # 97.5% quantiles of the sorted samples.
        self.assertTrue(abs(low - sent[int(samples * 0.025)]) < 0.2)
        self.assertTrue(abs(high - sent[int(samples * 0.975)]) < 0.2)
| 531 | |
| 532 | |
| 533 | |
def mean(a):
    """Return the arithmetic mean of the values in *a*.

    *a* may be any sequence of numbers.  It is converted to a float64 array
    first, so integer input does not hit Python 2 integer division.  An
    empty *a* raises ZeroDivisionError.
    """
    a = asarray(a, float64)
    return sum(a) / len(a)


def var(a):
    """Return the population variance of *a* (mean of squares minus the
    squared mean, dividing by len(a)).

    Accepts any sequence of numbers, not only numpy arrays: the element-wise
    square below needs an array, so *a* is converted to float64 first.
    """
    a = asarray(a, float64)
    return (sum(a * a) / len(a)) - mean(a) ** 2
| 539 | |
| 540 | |
| 541 | |
| 542 | |
if __name__ == '__main__':
    # Run every TestCase in this module when it is executed as a script.
    unittest.main(module='__main__')
