annotate clustalomega/clustal-omega-0.2.0/src/kmpp/KmTree.cpp @ 0:ff1768533a07

Migrated tool version 0.2 from old tool shed archive to new tool shed repository
author clustalomega
date Tue, 07 Jun 2011 17:04:25 -0400
parents
children
Ignore whitespace changes - Everywhere: Within whitespace: At end of lines:
rev   line source
0
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
1 // See KmTree.cpp
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
2 //
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
3 // Author: David Arthur (darthur@gmail.com), 2009
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
4
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
5 // Includes
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
6 #include "KmTree.h"
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
7 #include <iostream>
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
8 #include <stdlib.h>
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
9 #include <stdio.h>
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
10 using namespace std;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
11
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
12 KmTree::KmTree(int n, int d, Scalar *points): n_(n), d_(d), points_(points) {
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
13 // Initialize memory
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
14 // DD: need to cast to long otherwise malloc will fail
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
15 // if we need more than 2 gigabytes or so
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
16 int node_size = sizeof(Node) + d_ * 3 * sizeof(Scalar);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
17 node_data_ = (char*)malloc((2*(long unsigned int)n-1) * node_size);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
18 point_indices_ = (int*)malloc(n * sizeof(int));
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
19 for (int i = 0; i < n; i++)
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
20 point_indices_[i] = i;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
21 KM_ASSERT(node_data_ != 0 && point_indices_ != 0);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
22
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
23 // Calculate the bounding box for the points
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
24 Scalar *bound_v1 = PointAllocate(d_);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
25 Scalar *bound_v2 = PointAllocate(d_);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
26 KM_ASSERT(bound_v1 != 0 && bound_v2 != 0);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
27 PointCopy(bound_v1, points, d_);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
28 PointCopy(bound_v2, points, d_);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
29 for (int i = 1; i < n; i++)
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
30 for (int j = 0; j < d; j++) {
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
31 if (bound_v1[j] > points[i*d_ + j]) bound_v1[j] = points[i*d_ + j];
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
32 if (bound_v2[j] < points[i*d_ + j]) bound_v1[j] = points[i*d_ + j];
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
33 }
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
34
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
35 // Build the tree
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
36 char *temp_node_data = node_data_;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
37 top_node_ = BuildNodes(points, 0, n-1, &temp_node_data);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
38
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
39 // Cleanup
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
40 PointFree(bound_v1);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
41 PointFree(bound_v2);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
42 }
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
43
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
44 KmTree::~KmTree() {
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
45 free(point_indices_);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
46 free(node_data_);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
47 }
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
48
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
49 Scalar KmTree::DoKMeansStep(int k, Scalar *centers, int *assignment) const {
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
50 // Create an invalid center for comparison purposes
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
51 Scalar *bad_center = PointAllocate(d_);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
52 KM_ASSERT(bad_center != 0);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
53 memset(bad_center, 0xff, d_ * sizeof(Scalar));
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
54
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
55 // Allocate data
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
56 Scalar *sums = (Scalar*)calloc(k * d_, sizeof(Scalar));
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
57 int *counts = (int*)calloc(k, sizeof(int));
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
58 int num_candidates = 0;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
59 int *candidates = (int*)malloc(k * sizeof(int));
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
60 KM_ASSERT(sums != 0 && counts != 0 && candidates != 0);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
61 for (int i = 0; i < k; i++)
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
62 if (memcmp(centers + i*d_, bad_center, d_ * sizeof(Scalar)) != 0)
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
63 candidates[num_candidates++] = i;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
64
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
65 // Find nodes
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
66 Scalar result = DoKMeansStepAtNode(top_node_, num_candidates, candidates, centers, sums,
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
67 counts, assignment);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
68
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
69 // Set the new centers
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
70 for (int i = 0; i < k; i++) {
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
71 if (counts[i] > 0) {
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
72 PointScale(sums + i*d_, Scalar(1) / counts[i], d_);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
73 PointCopy(centers + i*d_, sums + i*d_, d_);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
74 } else {
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
75 memcpy(centers + i*d_, bad_center, d_ * sizeof(Scalar));
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
76 }
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
77 }
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
78
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
79 // Cleanup memory
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
80 PointFree(bad_center);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
81 free(candidates);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
82 free(counts);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
83 free(sums);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
84 return result;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
85 }
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
86
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
87 // Helper functions for constructor
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
88 // ================================
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
89
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
90 // Build a kd tree from the given set of points
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
91 KmTree::Node *KmTree::BuildNodes(Scalar *points, int first_index, int last_index,
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
92 char **next_node_data) {
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
93 // Allocate the node
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
94 Node *node = (Node*)(*next_node_data);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
95 (*next_node_data) += sizeof(Node);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
96 node->sum = (Scalar*)(*next_node_data);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
97 (*next_node_data) += sizeof(Scalar) * d_;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
98 node->median = (Scalar*)(*next_node_data);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
99 (*next_node_data) += sizeof(Scalar) * d_;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
100 node->radius = (Scalar*)(*next_node_data);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
101 (*next_node_data) += sizeof(Scalar) * d_;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
102
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
103 // Fill in basic info
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
104 node->num_points = (last_index - first_index + 1);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
105 node->first_point_index = first_index;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
106
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
107 // Calculate the bounding box
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
108 Scalar *first_point = points + point_indices_[first_index] * d_;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
109 Scalar *bound_p1 = PointAllocate(d_);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
110 Scalar *bound_p2 = PointAllocate(d_);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
111 KM_ASSERT(bound_p1 != 0 && bound_p2 != 0);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
112 PointCopy(bound_p1, first_point, d_);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
113 PointCopy(bound_p2, first_point, d_);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
114 for (int i = first_index+1; i <= last_index; i++)
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
115 for (int j = 0; j < d_; j++) {
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
116 Scalar c = points[point_indices_[i]*d_ + j];
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
117 if (bound_p1[j] > c) bound_p1[j] = c;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
118 if (bound_p2[j] < c) bound_p2[j] = c;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
119 }
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
120
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
121 // Calculate bounding box stats and delete the bounding box memory
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
122 Scalar max_radius = -1;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
123 int split_d = -1;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
124 for (int j = 0; j < d_; j++) {
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
125 node->median[j] = (bound_p1[j] + bound_p2[j]) / 2;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
126 node->radius[j] = (bound_p2[j] - bound_p1[j]) / 2;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
127 if (node->radius[j] > max_radius) {
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
128 max_radius = node->radius[j];
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
129 split_d = j;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
130 }
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
131 }
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
132 PointFree(bound_p2);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
133 PointFree(bound_p1);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
134
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
135 // If the max spread is 0, make this a leaf node
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
136 if (max_radius == 0) {
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
137 node->lower_node = node->upper_node = 0;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
138 PointCopy(node->sum, first_point, d_);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
139 if (last_index != first_index)
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
140 PointScale(node->sum, Scalar(last_index - first_index + 1), d_);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
141 node->opt_cost = 0;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
142 return node;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
143 }
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
144
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
145 // Partition the points around the midpoint in this dimension. The partitioning is done in-place
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
146 // by iterating from left-to-right and right-to-left in the same way that partioning is done for
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
147 // quicksort.
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
148 Scalar split_pos = node->median[split_d];
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
149 int i1 = first_index, i2 = last_index, size1 = 0;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
150 while (i1 <= i2) {
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
151 bool is_i1_good = (points[point_indices_[i1]*d_ + split_d] < split_pos);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
152 bool is_i2_good = (points[point_indices_[i2]*d_ + split_d] >= split_pos);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
153 if (!is_i1_good && !is_i2_good) {
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
154 int temp = point_indices_[i1];
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
155 point_indices_[i1] = point_indices_[i2];
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
156 point_indices_[i2] = temp;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
157 is_i1_good = is_i2_good = true;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
158 }
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
159 if (is_i1_good) {
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
160 i1++;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
161 size1++;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
162 }
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
163 if (is_i2_good) {
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
164 i2--;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
165 }
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
166 }
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
167
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
168 // Create the child nodes
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
169 KM_ASSERT(size1 >= 1 && size1 <= last_index - first_index);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
170 node->lower_node = BuildNodes(points, first_index, first_index + size1 - 1, next_node_data);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
171 node->upper_node = BuildNodes(points, first_index + size1, last_index, next_node_data);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
172
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
173 // Calculate the new sum and opt cost
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
174 PointCopy(node->sum, node->lower_node->sum, d_);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
175 PointAdd(node->sum, node->upper_node->sum, d_);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
176 Scalar *center = PointAllocate(d_);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
177 KM_ASSERT(center != 0);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
178 PointCopy(center, node->sum, d_);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
179 PointScale(center, Scalar(1) / node->num_points, d_);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
180 node->opt_cost = GetNodeCost(node->lower_node, center) + GetNodeCost(node->upper_node, center);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
181 PointFree(center);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
182 return node;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
183 }
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
184
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
185 // Returns the total contribution of all points in the given kd-tree node, assuming they are all
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
186 // assigned to a center at the given location. We need to return:
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
187 //
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
188 // sum_{x \in node} ||x - center||^2.
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
189 //
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
190 // If c denotes the center of mass of the points in this node and n denotes the number of points in
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
191 // it, then this quantity is given by
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
192 //
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
193 // n * ||c - center||^2 + sum_{x \in node} ||x - c||^2
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
194 //
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
195 // The sum is precomputed for each node as opt_cost. This formula follows from expanding both sides
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
196 // as dot products. See Kanungo/Mount for more info.
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
197 Scalar KmTree::GetNodeCost(const Node *node, Scalar *center) const {
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
198 Scalar dist_sq = 0;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
199 for (int i = 0; i < d_; i++) {
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
200 Scalar x = (node->sum[i] / node->num_points) - center[i];
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
201 dist_sq += x*x;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
202 }
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
203 return node->opt_cost + node->num_points * dist_sq;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
204 }
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
205
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
206 // Helper functions for DoKMeans step
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
207 // ==================================
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
208
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
209 // A recursive version of DoKMeansStep. This determines which clusters all points that are rooted
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
210 // node will be assigned to, and updates sums, counts and assignment (if not null) accordingly.
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
211 // candidates maintains the set of cluster indices which could possibly be the closest clusters
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
212 // for points in this subtree.
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
213 Scalar KmTree::DoKMeansStepAtNode(const Node *node, int k, int *candidates, Scalar *centers,
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
214 Scalar *sums, int *counts, int *assignment) const {
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
215 // Determine which center the node center is closest to
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
216 Scalar min_dist_sq = PointDistSq(node->median, centers + candidates[0]*d_, d_);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
217 int closest_i = candidates[0];
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
218 for (int i = 1; i < k; i++) {
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
219 Scalar dist_sq = PointDistSq(node->median, centers + candidates[i]*d_, d_);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
220 if (dist_sq < min_dist_sq) {
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
221 min_dist_sq = dist_sq;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
222 closest_i = candidates[i];
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
223 }
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
224 }
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
225
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
226 // If this is a non-leaf node, recurse if necessary
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
227 if (node->lower_node != 0) {
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
228 // Build the new list of candidates
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
229 int new_k = 0;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
230 int *new_candidates = (int*)malloc(k * sizeof(int));
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
231 KM_ASSERT(new_candidates != 0);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
232 for (int i = 0; i < k; i++)
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
233 if (!ShouldBePruned(node->median, node->radius, centers, closest_i, candidates[i]))
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
234 new_candidates[new_k++] = candidates[i];
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
235
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
236 // Recurse if there's at least two
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
237 if (new_k > 1) {
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
238 Scalar result = DoKMeansStepAtNode(node->lower_node, new_k, new_candidates, centers,
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
239 sums, counts, assignment) +
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
240 DoKMeansStepAtNode(node->upper_node, new_k, new_candidates, centers,
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
241 sums, counts, assignment);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
242 free(new_candidates);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
243 return result;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
244 } else {
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
245 free(new_candidates);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
246 }
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
247 }
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
248
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
249 // Assigns all points within this node to a single center
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
250 PointAdd(sums + closest_i*d_, node->sum, d_);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
251 counts[closest_i] += node->num_points;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
252 if (assignment != 0) {
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
253 for (int i = node->first_point_index; i < node->first_point_index + node->num_points; i++)
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
254 assignment[point_indices_[i]] = closest_i;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
255 }
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
256 return GetNodeCost(node, centers + closest_i*d_);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
257 }
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
258
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
259 // Determines whether every point in the box is closer to centers[best_index] than to
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
260 // centers[test_index].
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
261 //
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
262 // If x is a point, c_0 = centers[best_index], c = centers[test_index], then:
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
263 // (x-c).(x-c) < (x-c_0).(x-c_0)
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
264 // <=> (c-c_0).(c-c_0) < 2(x-c_0).(c-c_0)
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
265 //
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
266 // The right-hand side is maximized for a vertex of the box where for each dimension, we choose
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
267 // the low or high value based on the sign of x-c_0 in that dimension.
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
268 bool KmTree::ShouldBePruned(Scalar *box_median, Scalar *box_radius, Scalar *centers,
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
269 int best_index, int test_index) const {
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
270 if (best_index == test_index)
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
271 return false;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
272
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
273 Scalar *best = centers + best_index*d_;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
274 Scalar *test = centers + test_index*d_;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
275 Scalar lhs = 0, rhs = 0;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
276 for (int i = 0; i < d_; i++) {
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
277 Scalar component = test[i] - best[i];
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
278 lhs += component * component;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
279 if (component > 0)
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
280 rhs += (box_median[i] + box_radius[i] - best[i]) * component;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
281 else
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
282 rhs += (box_median[i] - box_radius[i] - best[i]) * component;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
283 }
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
284 return (lhs >= 2*rhs);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
285 }
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
286
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
287 Scalar KmTree::SeedKMeansPlusPlus(int k, Scalar *centers) const {
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
288 Scalar *dist_sq = (Scalar*)malloc(n_ * sizeof(Scalar));
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
289 KM_ASSERT(dist_sq != 0);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
290
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
291 // Choose an initial center uniformly at random
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
292 SeedKmppSetClusterIndex(top_node_, 0);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
293 int i = GetRandom(n_);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
294 memcpy(centers, points_ + point_indices_[i]*d_, d_*sizeof(Scalar));
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
295 Scalar total_cost = 0;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
296 for (int j = 0; j < n_; j++) {
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
297 dist_sq[j] = PointDistSq(points_ + point_indices_[j]*d_, centers, d_);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
298 total_cost += dist_sq[j];
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
299 }
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
300
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
301 // Repeatedly choose more centers
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
302 for (int new_cluster = 1; new_cluster < k; new_cluster++) {
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
303 while (1) {
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
304 Scalar cutoff = (rand() / Scalar(RAND_MAX)) * total_cost;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
305 Scalar cur_cost = 0;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
306 for (i = 0; i < n_; i++) {
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
307 cur_cost += dist_sq[i];
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
308 if (cur_cost >= cutoff)
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
309 break;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
310 }
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
311 if (i < n_)
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
312 break;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
313 }
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
314 memcpy(centers + new_cluster*d_, points_ + point_indices_[i]*d_, d_*sizeof(Scalar));
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
315 total_cost = SeedKmppUpdateAssignment(top_node_, new_cluster, centers, dist_sq);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
316 }
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
317
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
318 // Clean up and return
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
319 free(dist_sq);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
320 return total_cost;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
321 }
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
322
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
323 // Helper functions for SeedKMeansPlusPlus
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
324 // =======================================
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
325
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
326 // Sets kmpp_cluster_index to 0 for all nodes
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
327 void KmTree::SeedKmppSetClusterIndex(const Node *node, int value) const {
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
328 node->kmpp_cluster_index = value;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
329 if (node->lower_node != 0) {
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
330 SeedKmppSetClusterIndex(node->lower_node, value);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
331 SeedKmppSetClusterIndex(node->upper_node, value);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
332 }
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
333 }
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
334
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
335 Scalar KmTree::SeedKmppUpdateAssignment(const Node *node, int new_cluster, Scalar *centers,
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
336 Scalar *dist_sq) const {
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
337 // See if we can assign all points in this node to one cluster
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
338 if (node->kmpp_cluster_index >= 0) {
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
339 if (ShouldBePruned(node->median, node->radius, centers, node->kmpp_cluster_index, new_cluster))
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
340 return GetNodeCost(node, centers + node->kmpp_cluster_index*d_);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
341 if (ShouldBePruned(node->median, node->radius, centers, new_cluster,
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
342 node->kmpp_cluster_index)) {
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
343 SeedKmppSetClusterIndex(node, new_cluster);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
344 for (int i = node->first_point_index; i < node->first_point_index + node->num_points; i++)
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
345 dist_sq[i] = PointDistSq(points_ + point_indices_[i]*d_, centers + new_cluster*d_, d_);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
346 return GetNodeCost(node, centers + new_cluster*d_);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
347 }
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
348
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
349 // It may be that the a leaf-node point is equidistant from the new center or old
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
350 if (node->lower_node == 0)
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
351 return GetNodeCost(node, centers + node->kmpp_cluster_index*d_);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
352 }
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
353
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
354 // Recurse
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
355 Scalar cost = SeedKmppUpdateAssignment(node->lower_node, new_cluster, centers, dist_sq) +
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
356 SeedKmppUpdateAssignment(node->upper_node, new_cluster, centers, dist_sq);
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
357 int i1 = node->lower_node->kmpp_cluster_index, i2 = node->upper_node->kmpp_cluster_index;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
358 if (i1 == i2 && i1 != -1)
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
359 node->kmpp_cluster_index = i1;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
360 else
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
361 node->kmpp_cluster_index = -1;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
362 return cost;
ff1768533a07 Migrated tool version 0.2 from old tool shed archive to new tool shed repository
clustalomega
parents:
diff changeset
363 }