comparison munkres.py @ 33:5064f618ec1c

remove munkres dependency
author Richard Burhans <burhans@bx.psu.edu>
date Fri, 20 Sep 2013 14:01:30 -0400
parents
children
comparison
equal deleted inserted replaced
32:03c22b722882 33:5064f618ec1c
1 #!/usr/bin/env python
2 # -*- coding: iso-8859-1 -*-
3
4 # Documentation is intended to be processed by Epydoc.
5
6 """
7 Introduction
8 ============
9
10 The Munkres module provides an implementation of the Munkres algorithm
11 (also called the Hungarian algorithm or the Kuhn-Munkres algorithm),
12 useful for solving the Assignment Problem.
13
14 Assignment Problem
15 ==================
16
17 Let *C* be an *n*\ x\ *n* matrix representing the costs of each of *n* workers
18 to perform any of *n* jobs. The assignment problem is to assign jobs to
19 workers in a way that minimizes the total cost. Since each worker can perform
20 only one job and each job can be assigned to only one worker the assignments
21 represent an independent set of the matrix *C*.
22
23 One way to generate the optimal set is to create all permutations of
24 the indexes necessary to traverse the matrix so that no row and column
25 are used more than once. For instance, given this matrix (expressed in
26 Python)::
27
28 matrix = [[5, 9, 1],
29 [10, 3, 2],
30 [8, 7, 4]]
31
32 You could use this code to generate the traversal indexes::
33
34 def permute(a, results):
35 if len(a) == 1:
36 results.insert(len(results), a)
37
38 else:
39 for i in range(0, len(a)):
40 element = a[i]
41 a_copy = [a[j] for j in range(0, len(a)) if j != i]
42 subresults = []
43 permute(a_copy, subresults)
44 for subresult in subresults:
45 result = [element] + subresult
46 results.insert(len(results), result)
47
48 results = []
49 permute(range(len(matrix)), results) # [0, 1, 2] for a 3x3 matrix
50
51 After the call to permute(), the results matrix would look like this::
52
53 [[0, 1, 2],
54 [0, 2, 1],
55 [1, 0, 2],
56 [1, 2, 0],
57 [2, 0, 1],
58 [2, 1, 0]]
59
60 You could then use that index matrix to loop over the original cost matrix
61 and calculate the smallest cost of the combinations::
62
63 n = len(matrix)
64 minval = sys.maxint
65 for indexes in results:
66 cost = 0
67 for row in range(n):
68 cost += matrix[row][indexes[row]]
69 minval = min(cost, minval)
70
71 print minval
72
73 While this approach works fine for small matrices, it does not scale. It
74 executes in O(*n*!) time: Calculating the permutations for an *n*\ x\ *n*
75 matrix requires *n*! operations. For a 12x12 matrix, that's 479,001,600
76 traversals. Even if you could manage to perform each traversal in just one
77 millisecond, it would still take more than 133 hours to perform the entire
78 traversal. A 20x20 matrix would take 2,432,902,008,176,640,000 operations. At
79 an optimistic millisecond per operation, that's more than 77 million years.
80
81 The Munkres algorithm runs in O(*n*\ ^3) time, rather than O(*n*!). This
82 package provides an implementation of that algorithm.
83
84 This version is based on
85 http://www.public.iastate.edu/~ddoty/HungarianAlgorithm.html.
86
87 This version was written for Python by Brian Clapper from the (Ada) algorithm
88 at the above web site. (The ``Algorithm::Munkres`` Perl version, in CPAN, was
89 clearly adapted from the same web site.)
90
91 Usage
92 =====
93
94 Construct a Munkres object::
95
96 from munkres import Munkres
97
98 m = Munkres()
99
100 Then use it to compute the lowest cost assignment from a cost matrix. Here's
101 a sample program::
102
103 from munkres import Munkres, print_matrix
104
105 matrix = [[5, 9, 1],
106 [10, 3, 2],
107 [8, 7, 4]]
108 m = Munkres()
109 indexes = m.compute(matrix)
110 print_matrix(matrix, msg='Lowest cost through this matrix:')
111 total = 0
112 for row, column in indexes:
113 value = matrix[row][column]
114 total += value
115 print '(%d, %d) -> %d' % (row, column, value)
116 print 'total cost=%d' % total
117
118 Running that program produces::
119
120 Lowest cost through this matrix:
121 [5, 9, 1]
122 [10, 3, 2]
123 [8, 7, 4]
124 (0, 0) -> 5
125 (1, 1) -> 3
126 (2, 2) -> 4
127 total cost=12
128
129 The instantiated Munkres object can be used multiple times on different
130 matrices.
131
132 Non-square Cost Matrices
133 ========================
134
135 The Munkres algorithm assumes that the cost matrix is square. However, it's
136 possible to use a rectangular matrix if you first pad it with 0 values to make
137 it square. This module automatically pads rectangular cost matrices to make
138 them square.
139
140 Notes:
141
142 - The module operates on a *copy* of the caller's matrix, so any padding will
143 not be seen by the caller.
144 - The cost matrix must be rectangular or square. An irregular matrix will
145 *not* work.
146
147 Calculating Profit, Rather than Cost
148 ====================================
149
150 The cost matrix is just that: A cost matrix. The Munkres algorithm finds
151 the combination of elements (one from each row and column) that results in
152 the smallest cost. It's also possible to use the algorithm to maximize
153 profit. To do that, however, you have to convert your profit matrix to a
154 cost matrix. The simplest way to do that is to subtract all elements from a
155 large value. For example::
156
157 from munkres import Munkres, print_matrix
158
159 matrix = [[5, 9, 1],
160 [10, 3, 2],
161 [8, 7, 4]]
162 cost_matrix = []
163 for row in matrix:
164 cost_row = []
165 for col in row:
166 cost_row += [sys.maxint - col]
167 cost_matrix += [cost_row]
168
169 m = Munkres()
170 indexes = m.compute(cost_matrix)
171 print_matrix(matrix, msg='Highest profit through this matrix:')
172 total = 0
173 for row, column in indexes:
174 value = matrix[row][column]
175 total += value
176 print '(%d, %d) -> %d' % (row, column, value)
177
178 print 'total profit=%d' % total
179
180 Running that program produces::
181
182 Highest profit through this matrix:
183 [5, 9, 1]
184 [10, 3, 2]
185 [8, 7, 4]
186 (0, 1) -> 9
187 (1, 0) -> 10
188 (2, 2) -> 4
189 total profit=23
190
191 The ``munkres`` module provides a convenience method for creating a cost
192 matrix from a profit matrix. Since it doesn't know whether the matrix contains
193 floating point numbers, decimals, or integers, you have to provide the
194 conversion function; but the convenience method takes care of the actual
195 creation of the cost matrix::
196
197 import munkres
198
199 cost_matrix = munkres.make_cost_matrix(matrix,
200 lambda cost: sys.maxint - cost)
201
202 So, the above profit-calculation program can be recast as::
203
204 from munkres import Munkres, print_matrix, make_cost_matrix
205
206 matrix = [[5, 9, 1],
207 [10, 3, 2],
208 [8, 7, 4]]
209 cost_matrix = make_cost_matrix(matrix, lambda cost: sys.maxint - cost)
210 m = Munkres()
211 indexes = m.compute(cost_matrix)
212 print_matrix(matrix, msg='Highest profit through this matrix:')
213 total = 0
214 for row, column in indexes:
215 value = matrix[row][column]
216 total += value
217 print '(%d, %d) -> %d' % (row, column, value)
218 print 'total profit=%d' % total
219
220 References
221 ==========
222
223 1. http://www.public.iastate.edu/~ddoty/HungarianAlgorithm.html
224
225 2. Harold W. Kuhn. The Hungarian Method for the assignment problem.
226 *Naval Research Logistics Quarterly*, 2:83-97, 1955.
227
228 3. Harold W. Kuhn. Variants of the Hungarian method for assignment
229 problems. *Naval Research Logistics Quarterly*, 3: 253-258, 1956.
230
231 4. Munkres, J. Algorithms for the Assignment and Transportation Problems.
232 *Journal of the Society of Industrial and Applied Mathematics*,
233 5(1):32-38, March, 1957.
234
235 5. http://en.wikipedia.org/wiki/Hungarian_algorithm
236
237 Copyright and License
238 =====================
239
240 This software is released under a BSD license, adapted from
241 <http://opensource.org/licenses/bsd-license.php>
242
243 Copyright (c) 2008 Brian M. Clapper
244 All rights reserved.
245
246 Redistribution and use in source and binary forms, with or without
247 modification, are permitted provided that the following conditions are met:
248
249 * Redistributions of source code must retain the above copyright notice,
250 this list of conditions and the following disclaimer.
251
252 * Redistributions in binary form must reproduce the above copyright notice,
253 this list of conditions and the following disclaimer in the documentation
254 and/or other materials provided with the distribution.
255
256 * Neither the name "clapper.org" nor the names of its contributors may be
257 used to endorse or promote products derived from this software without
258 specific prior written permission.
259
260 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
261 AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
262 IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
263 ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
264 LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
265 CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
266 SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
267 INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
268 CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
269 ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
270 POSSIBILITY OF SUCH DAMAGE.
271 """
272
# Docstrings in this module are written in reStructuredText markup.
__docformat__ = 'restructuredtext'

# ---------------------------------------------------------------------------
# Imports
# ---------------------------------------------------------------------------

import sys

# ---------------------------------------------------------------------------
# Exports
# ---------------------------------------------------------------------------

# Names exported by a star-import of this module.
# NOTE(review): print_matrix() is defined in this module but is not listed
# here, so a star-import will not provide it -- confirm that is intended.
__all__ = ['Munkres', 'make_cost_matrix']

# ---------------------------------------------------------------------------
# Globals
# ---------------------------------------------------------------------------

# Info about the module
__version__ = "1.0.5.4"
__author__ = "Brian Clapper, bmc@clapper.org"
__url__ = "http://bmc.github.com/munkres/"
__copyright__ = "(c) 2008 Brian M. Clapper"
__license__ = "BSD-style license"
297
298 # ---------------------------------------------------------------------------
299 # Classes
300 # ---------------------------------------------------------------------------
301
class Munkres:
    """
    Calculate the Munkres solution to the classical assignment problem.
    See the module documentation for usage.

    An instance keeps its working state in attributes and may be reused:
    every call to ``compute()`` re-initialises that state.
    """

    def __init__(self):
        """Create a new instance with empty working state."""
        self.C = None            # padded working copy of the cost matrix
        self.row_covered = []    # one boolean per row: is the row covered?
        self.col_covered = []    # one boolean per column: is it covered?
        self.n = 0               # dimension of the padded (square) matrix
        self.Z0_r = 0            # row of the primed zero handed to step 5
        self.Z0_c = 0            # column of the primed zero handed to step 5
        self.marked = None       # n x n marks: 0 = none, 1 = star, 2 = prime
        self.path = None         # scratch list of [row, col] pairs (step 5)
        self.original_length = 0  # row count of the caller's matrix
        self.original_width = 0   # length of the caller's first row

    def make_cost_matrix(profit_matrix, inversion_function):
        """
        **DEPRECATED**

        Please use the module function ``make_cost_matrix()``.

        The conversion is performed directly here rather than by importing
        this module by name, so it keeps working when the file is vendored
        under a different module name.

        :Parameters:
            profit_matrix : list of lists
                The matrix to convert from a profit to a cost matrix

            inversion_function : function
                Called once per entry; its return value becomes the cost

        :rtype: list of lists
        :return: The converted matrix
        """
        return [[inversion_function(value) for value in row]
                for row in profit_matrix]

    make_cost_matrix = staticmethod(make_cost_matrix)

    def pad_matrix(self, matrix, pad_value=0):
        """
        Pad a possibly non-square matrix to make it square.

        :Parameters:
            matrix : list of lists
                matrix to pad

            pad_value : int
                value to use to pad the matrix

        :rtype: list of lists
        :return: a new, possibly padded, matrix
        """
        max_columns = 0
        for row in matrix:
            max_columns = max(max_columns, len(row))

        # The square dimension is the larger of the row and column counts.
        total_rows = max(max_columns, len(matrix))

        new_matrix = []
        for row in matrix:
            new_row = list(row)  # copy, so the caller's rows are untouched
            row_len = len(new_row)
            if total_rows > row_len:
                # Row too short. Pad it.
                new_row += [pad_value] * (total_rows - row_len)
            new_matrix += [new_row]

        # Too few rows. Append rows made up entirely of the pad value.
        while len(new_matrix) < total_rows:
            new_matrix += [[pad_value] * total_rows]

        return new_matrix

    def compute(self, cost_matrix):
        """
        Compute the indexes for the lowest-cost pairings between rows and
        columns in the matrix. Returns a list of (row, column) tuples
        that can be used to traverse the matrix.

        :Parameters:
            cost_matrix : list of lists
                The cost matrix. If this cost matrix is not square, it
                will be padded with zeros, via a call to ``pad_matrix()``.
                (This method does *not* modify the caller's matrix. It
                operates on a copy of the matrix.)

                **WARNING**: This code handles square and rectangular
                matrices. It does *not* handle irregular matrices.

        :rtype: list
        :return: A list of ``(row, column)`` tuples that describe the lowest
                 cost path through the matrix. An empty matrix yields an
                 empty list.
        """
        if not cost_matrix:
            # Nothing to assign; also avoids indexing cost_matrix[0] below.
            return []

        self.C = self.pad_matrix(cost_matrix)
        self.n = len(self.C)
        self.original_length = len(cost_matrix)
        self.original_width = len(cost_matrix[0])
        self.row_covered = [False for i in range(self.n)]
        self.col_covered = [False for i in range(self.n)]
        self.Z0_r = 0
        self.Z0_c = 0
        self.path = self.__make_matrix(self.n * 2, 0)
        self.marked = self.__make_matrix(self.n, 0)

        steps = {1: self.__step1,
                 2: self.__step2,
                 3: self.__step3,
                 4: self.__step4,
                 5: self.__step5,
                 6: self.__step6}

        # Each step returns the number of the next step. A number with no
        # entry in the table (step 3 returns 7) means the algorithm is done.
        # The lookup is kept separate from the call so that a KeyError
        # raised inside a step is not mistaken for normal termination.
        step = 1
        while True:
            func = steps.get(step)
            if func is None:
                break
            step = func()

        # Collect the starred cells that lie inside the caller's matrix;
        # stars in padding rows or columns are not reported.
        results = []
        for i in range(self.original_length):
            for j in range(self.original_width):
                if self.marked[i][j] == 1:
                    results += [(i, j)]

        return results

    def __copy_matrix(self, matrix):
        """Return an exact (deep) copy of the supplied matrix"""
        # Imported locally: the module itself does not import ``copy``.
        import copy
        return copy.deepcopy(matrix)

    def __make_matrix(self, n, val):
        """Create an *n*x*n* matrix, populating it with the specific value."""
        matrix = []
        for i in range(n):
            matrix += [[val for j in range(n)]]
        return matrix

    def __step1(self):
        """
        For each row of the matrix, find the smallest element and
        subtract it from every element in its row. Go to Step 2.
        """
        n = self.n
        for i in range(n):
            # Find the minimum value for this row and subtract that minimum
            # from every element in the row.
            minval = min(self.C[i])
            for j in range(n):
                self.C[i][j] -= minval

        return 2

    def __step2(self):
        """
        Find a zero (Z) in the resulting matrix. If there is no starred
        zero in its row or column, star Z. Repeat for each element in the
        matrix. Go to Step 3.
        """
        n = self.n
        for i in range(n):
            for j in range(n):
                if (self.C[i][j] == 0) and \
                   (not self.col_covered[j]) and \
                   (not self.row_covered[i]):
                    self.marked[i][j] = 1
                    # The covers serve here only as "this row/column
                    # already holds a star" flags; cleared again below.
                    self.col_covered[j] = True
                    self.row_covered[i] = True

        self.__clear_covers()
        return 3

    def __step3(self):
        """
        Cover each column containing a starred zero. If K columns are
        covered, the starred zeros describe a complete set of unique
        assignments. In this case, Go to DONE, otherwise, Go to Step 4.
        """
        n = self.n
        count = 0
        for i in range(n):
            for j in range(n):
                if self.marked[i][j] == 1:
                    self.col_covered[j] = True
                    count += 1

        if count >= n:
            step = 7  # done: 7 has no entry in compute()'s step table
        else:
            step = 4

        return step

    def __step4(self):
        """
        Find a noncovered zero and prime it. If there is no starred zero
        in the row containing this primed zero, Go to Step 5. Otherwise,
        cover this row and uncover the column containing the starred
        zero. Continue in this manner until there are no uncovered zeros
        left, then Go to Step 6.
        """
        while True:
            (row, col) = self.__find_a_zero()
            if row < 0:
                # No uncovered zero is left.
                return 6

            self.marked[row][col] = 2
            star_col = self.__find_star_in_row(row)
            if star_col < 0:
                # No star in this row: remember the primed zero for step 5.
                self.Z0_r = row
                self.Z0_c = col
                return 5

            self.row_covered[row] = True
            self.col_covered[star_col] = False

    def __step5(self):
        """
        Construct a series of alternating primed and starred zeros as
        follows. Let Z0 represent the uncovered primed zero found in Step 4.
        Let Z1 denote the starred zero in the column of Z0 (if any).
        Let Z2 denote the primed zero in the row of Z1 (there will always
        be one). Continue until the series terminates at a primed zero
        that has no starred zero in its column. Unstar each starred zero
        of the series, star each primed zero of the series, erase all
        primes and uncover every line in the matrix. Return to Step 3
        """
        count = 0
        path = self.path
        path[count][0] = self.Z0_r
        path[count][1] = self.Z0_c
        done = False
        while not done:
            # Starred zero in the column of the last path entry, if any.
            row = self.__find_star_in_col(path[count][1])
            if row >= 0:
                count += 1
                path[count][0] = row
                path[count][1] = path[count - 1][1]
            else:
                done = True

            if not done:
                # Primed zero in the row of the star just appended.
                col = self.__find_prime_in_row(path[count][0])
                count += 1
                path[count][0] = path[count - 1][0]
                path[count][1] = col

        self.__convert_path(path, count)
        self.__clear_covers()
        self.__erase_primes()
        return 3

    def __step6(self):
        """
        Find the smallest uncovered value, add it to every element of each
        covered row, and subtract it from every element of each uncovered
        column. Return to Step 4 without altering any stars, primes, or
        covered lines.
        """
        minval = self.__find_smallest()
        for i in range(self.n):
            for j in range(self.n):
                if self.row_covered[i]:
                    self.C[i][j] += minval
                if not self.col_covered[j]:
                    self.C[i][j] -= minval
        return 4

    def __find_smallest(self):
        """
        Find the smallest uncovered value in the matrix. Returns
        ``sys.maxsize`` if no uncovered cell holds a smaller value.
        """
        # sys.maxsize exists on Python 2.6+ and on Python 3, where
        # sys.maxint is no longer available.
        minval = sys.maxsize
        for i in range(self.n):
            for j in range(self.n):
                if (not self.row_covered[i]) and (not self.col_covered[j]):
                    if minval > self.C[i][j]:
                        minval = self.C[i][j]
        return minval

    def __find_a_zero(self):
        """
        Find an uncovered element with value 0. Rows are scanned in order;
        within the first row that holds an uncovered zero, the *last* such
        zero is the one reported. Returns a ``(row, column)`` tuple, or
        ``(-1, -1)`` if there is no uncovered zero.
        """
        for i in range(self.n):
            found = -1
            for j in range(self.n):
                if (self.C[i][j] == 0) and \
                   (not self.row_covered[i]) and \
                   (not self.col_covered[j]):
                    # No break: later zeros in this row override earlier
                    # ones, which keeps the choice of zero stable.
                    found = j
            if found >= 0:
                return (i, found)

        return (-1, -1)

    def __find_star_in_row(self, row):
        """
        Find the first starred element in the specified row. Returns
        the column index, or -1 if no starred element was found.
        """
        col = -1
        for j in range(self.n):
            if self.marked[row][j] == 1:
                col = j
                break

        return col

    def __find_star_in_col(self, col):
        """
        Find the first starred element in the specified column. Returns
        the row index, or -1 if no starred element was found.
        """
        row = -1
        for i in range(self.n):
            if self.marked[i][col] == 1:
                row = i
                break

        return row

    def __find_prime_in_row(self, row):
        """
        Find the first primed element in the specified row. Returns
        the column index, or -1 if no primed element was found.
        """
        col = -1
        for j in range(self.n):
            if self.marked[row][j] == 2:
                col = j
                break

        return col

    def __convert_path(self, path, count):
        """
        Flip the marks along ``path[0] .. path[count]``: a starred cell
        becomes unmarked and any other cell becomes starred.
        """
        for i in range(count + 1):
            if self.marked[path[i][0]][path[i][1]] == 1:
                self.marked[path[i][0]][path[i][1]] = 0
            else:
                self.marked[path[i][0]][path[i][1]] = 1

    def __clear_covers(self):
        """Clear all row and column covers"""
        for i in range(self.n):
            self.row_covered[i] = False
            self.col_covered[i] = False

    def __erase_primes(self):
        """Erase all prime markings"""
        for i in range(self.n):
            for j in range(self.n):
                if self.marked[i][j] == 2:
                    self.marked[i][j] = 0
670
671 # ---------------------------------------------------------------------------
672 # Functions
673 # ---------------------------------------------------------------------------
674
def make_cost_matrix(profit_matrix, inversion_function):
    """
    Create a cost matrix from a profit matrix by calling
    'inversion_function' to invert each value. The inversion
    function must take one numeric argument (of any type) and return
    another numeric argument which is presumed to be the cost inverse
    of the original profit.

    This is a module-level function. Call it like this:

    .. python::

        cost_matrix = make_cost_matrix(matrix, inversion_func)

    For example:

    .. python::

        cost_matrix = make_cost_matrix(matrix, lambda x : sys.maxsize - x)

    :Parameters:
        profit_matrix : list of lists
            The matrix to convert from a profit to a cost matrix

        inversion_function : function
            The function to use to invert each entry in the profit matrix

    :rtype: list of lists
    :return: The converted matrix (a new list; the input is not modified)
    """
    return [[inversion_function(value) for value in row]
            for row in profit_matrix]
709
def print_matrix(matrix, msg=None):
    """
    Convenience function: Displays the contents of a matrix of integers.

    Every cell is right-aligned to the width of the widest formatted value
    (a minus sign counts towards the width), so zero and negative entries
    are handled as well as positive ones.

    :Parameters:
        matrix : list of lists
            Matrix to print

        msg : str
            Optional message to print before displaying the matrix
    """
    if msg is not None:
        sys.stdout.write('%s\n' % (msg,))

    # Calculate the appropriate format width from the rendered text of each
    # value; a logarithm cannot be taken of zero or of negative values.
    width = 0
    for row in matrix:
        for val in row:
            width = max(width, len('%d' % val))

    # Make the format string
    cell_format = '%%%dd' % width

    # Print the matrix
    for row in matrix:
        cells = [cell_format % val for val in row]
        sys.stdout.write('[' + ', '.join(cells) + ']\n')
742
743 # ---------------------------------------------------------------------------
744 # Main
745 # ---------------------------------------------------------------------------
746
if __name__ == '__main__':

    # Self-test: each entry pairs a cost matrix with the total cost of its
    # optimal assignment.
    matrices = [
        # Square
        ([[400, 150, 400],
          [400, 450, 600],
          [300, 225, 300]],
         850  # expected cost
        ),

        # Rectangular variant
        ([[400, 150, 400, 1],
          [400, 450, 600, 2],
          [300, 225, 300, 3]],
         452  # expected cost
        ),

        # Square
        ([[10, 10,  8],
          [ 9,  8,  1],
          [ 9,  7,  4]],
         18
        ),

        # Rectangular variant
        ([[10, 10,  8, 11],
          [ 9,  8,  1,  1],
          [ 9,  7,  4, 10]],
         15
        ),
    ]

    m = Munkres()
    for cost_matrix, expected_total in matrices:
        print_matrix(cost_matrix, msg='cost matrix')
        indexes = m.compute(cost_matrix)
        total_cost = 0
        for r, c in indexes:
            x = cost_matrix[r][c]
            total_cost += x
            # A single parenthesised argument prints the same text whether
            # print is a statement or a function.
            print('(%d, %d) -> %d' % (r, c, x))
        print('lowest cost=%d' % total_cost)
        assert expected_total == total_cost
791