klee
DiscretePDF.inc
Go to the documentation of this file.
1//===- DiscretePDF.inc - --*- C++ -*---------------------------------------===//
2//
3// The KLEE Symbolic Virtual Machine
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9
10#include <cassert>
11namespace klee {
12
13template <class T, class Comparator>
14class DiscretePDF<T, Comparator>::Node
15{
16private:
17 bool m_mark;
18
19public:
20 Node *parent, *left, *right;
21 T key;
22 weight_type weight, sumWeights;
23
24public:
25 Node(T key_, weight_type weight_, Node *parent_);
26 ~Node();
27
28 Node *sibling() { return this==parent->left?parent->right:parent->left; }
29
30 void markRed() { m_mark = true; }
31 void markBlack() { m_mark = false; }
32 bool isBlack() { return !m_mark; }
33 bool leftIsBlack() { return !left || left->isBlack(); }
34 bool rightIsBlack() { return !right || right->isBlack(); }
35 void setSum() {
36 sumWeights = weight;
37 if (left) sumWeights += left->sumWeights;
38 if (right) sumWeights += right->sumWeights;
39 }
40};
41
43
44template <class T, class Comparator>
45DiscretePDF<T, Comparator>::Node::Node(T key_, weight_type weight_, Node *parent_) {
46 m_mark = false;
47
48 key = key_;
49 weight = weight_;
50 sumWeights = 0;
51 left = right = 0;
52 parent = parent_;
53}
54
55template <class T, class Comparator>
56DiscretePDF<T, Comparator>::Node::~Node() {
57 delete left;
58 delete right;
59}
60
61//
62
63template <class T, class Comparator>
65 m_root = 0;
66}
67
68template <class T, class Comparator>
70 delete m_root;
71}
72
73template <class T, class Comparator>
75 return m_root == 0;
76}
77
78template <class T, class Comparator>
79void DiscretePDF<T, Comparator>::insert(T item, weight_type weight) {
80 Comparator lessThan;
81 Node *p=0, *n=m_root;
82
83 while (n) {
84 if (!n->leftIsBlack() && !n->rightIsBlack())
85 split(n);
86
87 p = n;
88 if (n->key==item) {
89 assert(0 && "insert: argument(item) already in tree");
90 } else {
91 n = lessThan(item, n->key) ? n->left : n->right;
92 }
93 }
94
95 n = new Node(item, weight, p);
96
97 if (!p) {
98 m_root = n;
99 } else {
100 if (lessThan(item, p->key)) {
101 p->left = n;
102 } else {
103 p->right = n;
104 }
105
106 split(n);
107 }
108
109 propagateSumsUp(n);
110}
111
112template <class T, class Comparator>
114 Node **np = lookup(item, 0);
115 Node *child, *n = *np;
116
117 if (!n) {
118 assert(0 && "remove: argument(item) not in tree");
119 } else {
120 if (n->left) {
121 Node **leftMaxp = &n->left;
122
123 while ((*leftMaxp)->right)
124 leftMaxp = &(*leftMaxp)->right;
125
126 n->key = (*leftMaxp)->key;
127 n->weight = (*leftMaxp)->weight;
128
129 np = leftMaxp;
130 n = *np;
131 }
132
133 // node now has at most one child
134
135 child = n->left?n->left:n->right;
136 *np = child;
137
138 if (child) {
139 child->parent = n->parent;
140
141 if (n->isBlack()) {
142 lengthen(child);
143 }
144 }
145
146 propagateSumsUp(n->parent);
147
148 n->left = n->right = 0;
149 delete n;
150 }
151}
152
153template <class T, class Comparator>
154void DiscretePDF<T, Comparator>::update(T item, weight_type weight) {
155 Node *n = *lookup(item, 0);
156
157 if (!n) {
158 assert(0 && "update: argument(item) not in tree");
159 } else {
160 n->weight = weight;
161 propagateSumsUp(n);
162 }
163}
164
165template <class T, class Comparator>
167 assert (!((p < 0.0) || (p >= 1.0)) && "choose: argument(p) outside valid range");
168
169 if (!m_root)
170 assert(0 && "choose: choose() called on empty tree");
171
172 weight_type w = (weight_type) (m_root->sumWeights * p);
173 Node *n = m_root;
174
175 while (1) {
176 if (n->left) {
177 if (w<n->left->sumWeights) {
178 n = n->left;
179 continue;
180 } else {
181 w -= n->left->sumWeights;
182 }
183 }
184 if (w<n->weight || !n->right) {
185 break; // !n->right condition shouldn't be necessary, just sanity check
186 }
187 w -= n->weight;
188 n = n->right;
189 }
190
191 return n->key;
192}
193
194template <class T, class Comparator>
196 Node *n = *lookup(item, 0);
197
198 return !!n;
199}
200
201template <class T, class Comparator>
203 Node *n = *lookup(item, 0);
204 assert(n);
205 return n->weight;
206}
207
208//
209
210template <class T, class Comparator>
211typename DiscretePDF<T, Comparator>::Node **
212DiscretePDF<T, Comparator>::lookup(T item, Node **parent_out) {
213 Comparator lessThan;
214 Node *n, *p=0, **np=&m_root;
215
216 while ((n = *np)) {
217 if (n->key==item) {
218 break;
219 } else {
220 p = n;
221 if (lessThan(item, n->key)) {
222 np = &n->left;
223 } else {
224 np = &n->right;
225 }
226 }
227 }
228
229 if (parent_out)
230 *parent_out = p;
231 return np;
232}
233
234template <class T, class Comparator>
236 if (n->left) n->left->markBlack();
237 if (n->right) n->right->markBlack();
238
239 if (n->parent) {
240 Node *p = n->parent;
241
242 n->markRed();
243
244 if (!p->isBlack()) {
245 p->parent->markRed();
246
247 // not same direction
248 if (!((n==p->left && p==p->parent->left) ||
249 (n==p->right && p==p->parent->right))) {
250 rotate(n);
251 p = n;
252 }
253
254 rotate(p);
255 p->markBlack();
256 }
257 }
258}
259
260template <class T, class Comparator>
262 Node *p=n->parent, *pp=p->parent;
263
264 n->parent = pp;
265 p->parent = n;
266
267 if (n==p->left) {
268 p->left = n->right;
269 n->right = p;
270 if (p->left) p->left->parent = p;
271 } else {
272 p->right = n->left;
273 n->left = p;
274 if (p->right) p->right->parent = p;
275 }
276
277 n->setSum();
278 p->setSum();
279
280 if (!pp) {
281 m_root = n;
282 } else {
283 if (p==pp->left) {
284 pp->left = n;
285 } else {
286 pp->right = n;
287 }
288 }
289}
290
291template <class T, class Comparator>
293 if (!n->isBlack()) {
294 n->markBlack();
295 } else if (n->parent) {
296 Node *sibling = n->sibling();
297
298 if (sibling && !sibling->isBlack()) {
299 n->parent->markRed();
300 sibling->markBlack();
301
302 rotate(sibling); // node sibling is now old sibling child, must be black
303 sibling = n->sibling();
304 }
305
306 // sibling is black
307
308 if (!sibling) {
309 lengthen(n->parent);
310 } else if (sibling->leftIsBlack() && sibling->rightIsBlack()) {
311 if (n->parent->isBlack()) {
312 sibling->markRed();
313 lengthen(n->parent);
314 } else {
315 sibling->markRed();
316 n->parent->markBlack();
317 }
318 } else {
319 if (n==n->parent->left && sibling->rightIsBlack()) {
320 rotate(sibling->left); // sibling->left must be red
321 sibling->markRed();
322 sibling->parent->markBlack();
323 sibling = sibling->parent;
324 } else if (n==n->parent->right && sibling->leftIsBlack()) {
325 rotate(sibling->right); // sibling->right must be red
326 sibling->markRed();
327 sibling->parent->markBlack();
328 sibling = sibling->parent;
329 }
330
331 // sibling is black, and sibling's far child is red
332
333 rotate(sibling);
334 if (!n->parent->isBlack())
335 sibling->markRed();
336 sibling->left->markBlack();
337 sibling->right->markBlack();
338 }
339 }
340}
341
342template <class T, class Comparator>
344 for (; n; n=n->parent)
345 n->setSum();
346}
347
348}
349
void insert(T item, weight_type weight)
void update(T item, weight_type newWeight)
void lengthen(Node *node)
weight_type getWeight(T item)
void remove(T item)
T choose(double p)
void rotate(Node *node)
void propagateSumsUp(Node *n)
bool empty() const
bool inTree(T item)
Node ** lookup(T item, Node **parent_out)
void split(Node *node)
Definition: main.cpp:291