/**
 * @file BAVL.h
 * @author Ambroz Bizjak
 *
 * @section LICENSE
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 * 3. Neither the name of the author nor the
 *    names of its contributors may be used to endorse or promote products
 *    derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 * @section DESCRIPTION
 *
 * AVL tree.
 */

#ifndef BADVPN_STRUCTURE_BAVL_H
#define BADVPN_STRUCTURE_BAVL_H

//#define BAVL_DEBUG

#include <stddef.h>
#include <stdint.h>

// NOTE(review): the targets of the three original #include directives were
// lost. <stdint.h> and <stddef.h> are required by the fixed-width integer
// types and NULL used below. The third is presumably the project header that
// provides ASSERT and ASSERT_FORCE (assumed to be <misc/debug.h>) - confirm.
#ifdef __has_include
#  if __has_include(<misc/debug.h>)
#    include <misc/debug.h>
#  endif
#else
#  include <misc/debug.h>
#endif

// Fallbacks so that the header is usable on its own when the project header
// did not provide the assertion macros. Both expand to a complete statement,
// because the code below invokes them without a trailing semicolon.
// NOTE(review): assumed semantics - ASSERT_FORCE always checks, ASSERT is
// compiled out under NDEBUG; verify against the project's definitions.
#ifndef ASSERT_FORCE
#include <stdio.h>
#include <stdlib.h>
#define ASSERT_FORCE(e) \
    { \
        if (!(e)) { \
            fprintf(stderr, "%s:%d Assertion failed\n", __FILE__, __LINE__); \
            abort(); \
        } \
    }
#endif

#ifndef ASSERT
#ifdef NDEBUG
#define ASSERT(e) {}
#else
#define ASSERT(e) ASSERT_FORCE(e)
#endif
#endif

/**
 * Handler function called by tree algorithms to compare two values.
 * For any two values, the comparator must always return the same result.
 * The <= relation defined by the comparator must be a total order.
 * Values are obtained as follows:
 *   - The value of a node in the tree, or of a node that is being inserted, is:
 *     (uint8_t *)node + offset.
 *   - The value being looked up is the same as given to the lookup function.
 *
 * @param user as in {@link BAVL_Init}
 * @param val1 first value
 * @param val2 second value
 * @return -1 if val1 < val2, 0 if val1 = val2, 1 if val1 > val2
 */
typedef int (*BAVL_comparator) (void *user, void *val1, void *val2);

struct BAVLNode;

/**
 * AVL tree.
 */
typedef struct {
    int offset;                 // byte offset from a node to its value
    BAVL_comparator comparator; // value comparison handler
    void *user;                 // passed as first argument to comparator
    struct BAVLNode *root;      // root node, NULL when the tree is empty
} BAVL;

/**
 * AVL tree node.
 */
typedef struct BAVLNode {
    struct BAVLNode *parent;  // NULL for the root
    struct BAVLNode *link[2]; // link[0] is the left child, link[1] the right child
    int8_t balance;           // height of right subtree minus height of left subtree
#ifdef BAVL_COUNT
    uint64_t count;           // number of nodes in the subtree rooted here (including this node)
#endif
} BAVLNode;

/**
 * Initializes the tree.
 *
 * @param o tree to initialize
 * @param offset offset of a value from its node
 * @param comparator value comparator handler function
 * @param user value to pass to comparator
 */
static void BAVL_Init (BAVL *o, int offset, BAVL_comparator comparator, void *user);

/**
 * Inserts a node into the tree.
 * Must not be called from comparator.
 *
 * @param o the tree
 * @param node uninitialized node to insert. Must have a valid value (its value
 *             may be passed to the comparator during insertion).
 * @param ref if not NULL, will return (regardless if insertion succeeded):
 *              - if a node with an equal value already exists (insertion
 *                fails), that node; otherwise
 *              - the greatest node lesser than the inserted value, or (not in order)
 *              - the smallest node greater than the inserted value, or
 *              - NULL meaning there were no nodes in the tree.
 * @return 1 on success, 0 if an element with an equal value is already in the tree
 */
static int BAVL_Insert (BAVL *o, BAVLNode *node, BAVLNode **ref);

/**
 * Removes a node from the tree.
 * Must not be called from comparator.
 *
 * @param o the tree
 * @param node node to remove
 */
static void BAVL_Remove (BAVL *o, BAVLNode *node);

/**
 * Checks if the tree is empty.
 * Must not be called from comparator.
 *
 * @param o the tree
 * @return 1 if empty, 0 if not
 */
static int BAVL_IsEmpty (const BAVL *o);

/**
 * Looks for a value in the tree.
 * Must not be called from comparator.
 *
 * @param o the tree
 * @param val value to look for
 * @return If a node is in the tree with an equal value, that node.
 *         Else if the tree is not empty:
 *           - the greatest node lesser than the given value, or (not in order)
 *           - the smallest node greater than the given value.
 *         NULL if the tree is empty.
 */
static BAVLNode * BAVL_Lookup (const BAVL *o, void *val);

/**
 * Looks for a value in the tree.
 * Must not be called from comparator.
 *
 * @param o the tree
 * @param val value to look for
 * @return If a node is in the tree with an equal value, that node.
 *         Else NULL.
 */
static BAVLNode * BAVL_LookupExact (const BAVL *o, void *val);

/**
 * Returns the smallest node in the tree, or NULL if the tree is empty.
 *
 * @param o the tree
 * @return smallest node or NULL
 */
static BAVLNode * BAVL_GetFirst (const BAVL *o);

/**
 * Returns the greatest node in the tree, or NULL if the tree is empty.
 *
 * @param o the tree
 * @return greatest node or NULL
 */
static BAVLNode * BAVL_GetLast (const BAVL *o);

/**
 * Returns the node that follows the given node, or NULL if it's the
 * last node.
 *
 * @param o the tree
 * @param n node
 * @return next node, or NULL
 */
static BAVLNode * BAVL_GetNext (const BAVL *o, BAVLNode *n);

/**
 * Returns the node that precedes the given node, or NULL if it's the
 * first node.
 *
 * @param o the tree
 * @param n node
 * @return previous node, or NULL
 */
static BAVLNode * BAVL_GetPrev (const BAVL *o, BAVLNode *n);

#ifdef BAVL_COUNT

/**
 * Returns the number of nodes in the tree.
 *
 * @param o the tree
 * @return the count stored in the root, or 0 if the tree is empty
 */
static uint64_t BAVL_Count (const BAVL *o);

/**
 * Returns the zero-based in-order position of a node.
 *
 * @param o the tree
 * @param n node in the tree
 * @return number of nodes that precede n in order
 */
static uint64_t BAVL_IndexOf (const BAVL *o, BAVLNode *n);

/**
 * Returns the node at a zero-based in-order position.
 *
 * @param o the tree
 * @param index position to look up
 * @return the node at that position, or NULL if index >= BAVL_Count(o)
 */
static BAVLNode * BAVL_GetAt (const BAVL *o, uint64_t index);

#endif

/**
 * Checks the structural invariants of the whole tree (parent links, ordering,
 * balance factors and, with BAVL_COUNT, subtree counts) using ASSERT_FORCE.
 * Calls the comparator.
 *
 * @param o the tree
 */
static void BAVL_Verify (BAVL *o);

// Note: both macros evaluate their arguments more than once; only pass
// side-effect-free expressions.
#define BAVL_MAX(_a, _b) ((_a) > (_b) ? (_a) : (_b))
#define BAVL_OPTNEG(_a, _neg) ((_neg) ? -(_a) : (_a))

// Returns the address of the value that belongs to node n.
static void * _BAVL_node_value (const BAVL *o, BAVLNode *n)
{
    return ((uint8_t *)n + o->offset);
}

// Compares two values with the tree's comparator and checks that the result
// is one of -1, 0, 1.
static int _BAVL_compare_values (const BAVL *o, void *v1, void *v2)
{
    int res = o->comparator(o->user, v1, v2);

    ASSERT(res == -1 || res == 0 || res == 1)

    return res;
}

// Compares the values of two nodes.
static int _BAVL_compare_nodes (BAVL *o, BAVLNode *n1, BAVLNode *n2)
{
    return _BAVL_compare_values(o, _BAVL_node_value(o, n1), _BAVL_node_value(o, n2));
}

#ifdef BAVL_AUTO_VERIFY
#define BAVL_ASSERT(_h) BAVL_Verify((_h));
#else
#define BAVL_ASSERT(_h)
#endif

// Recursively verifies the subtree rooted at n and returns its height.
static int _BAVL_assert_recurser (BAVL *o, BAVLNode *n)
{
    ASSERT_FORCE(n->balance >= -1)
    ASSERT_FORCE(n->balance <= 1)

    int height_left = 0;
    int height_right = 0;
#ifdef BAVL_COUNT
    uint64_t count_left = 0;
    uint64_t count_right = 0;
#endif

    // check left subtree
    if (n->link[0]) {
        // check parent link
        ASSERT_FORCE(n->link[0]->parent == n)
        // check binary search tree
        ASSERT_FORCE(_BAVL_compare_nodes(o, n->link[0], n) == -1)
        // recursively calculate height
        height_left = _BAVL_assert_recurser(o, n->link[0]);
#ifdef BAVL_COUNT
        count_left = n->link[0]->count;
#endif
    }

    // check right subtree
    if (n->link[1]) {
        // check parent link
        ASSERT_FORCE(n->link[1]->parent == n)
        // check binary search tree
        ASSERT_FORCE(_BAVL_compare_nodes(o, n->link[1], n) == 1)
        // recursively calculate height
        height_right = _BAVL_assert_recurser(o, n->link[1]);
#ifdef BAVL_COUNT
        count_right = n->link[1]->count;
#endif
    }

    // check balance factor
    ASSERT_FORCE(n->balance == height_right - height_left)

#ifdef BAVL_COUNT
    // check count
    ASSERT_FORCE(n->count == 1 + count_left + count_right)
#endif

    return (BAVL_MAX(height_left, height_right) + 1);
}

#ifdef BAVL_COUNT
// Recomputes n->count from the counts stored in its children.
static void _BAVL_update_count_from_children (BAVLNode *n)
{
    n->count = 1 + (n->link[0] ? n->link[0]->count : 0) + (n->link[1] ? n->link[1]->count : 0);
}
#endif

// Rotates the subtree rooted at r towards direction dir: the child on the
// opposite side (!dir) becomes the new subtree root and r becomes its
// dir-side child. Balance factors are NOT updated here.
static void _BAVL_rotate (BAVL *tree, BAVLNode *r, uint8_t dir)
{
    BAVLNode *nr = r->link[!dir];

    r->link[!dir] = nr->link[dir];
    if (r->link[!dir]) {
        r->link[!dir]->parent = r;
    }
    nr->link[dir] = r;
    nr->parent = r->parent;
    if (nr->parent) {
        // r->parent is still the old parent here
        nr->parent->link[r == r->parent->link[1]] = nr;
    } else {
        tree->root = nr;
    }
    r->parent = nr;

#ifdef BAVL_COUNT
    // update counts
    _BAVL_update_count_from_children(r); // first r, it is now below nr
    _BAVL_update_count_from_children(nr);
#endif
}

// Returns the rightmost (greatest) node of the subtree rooted at n.
static BAVLNode * _BAVL_subtree_max (BAVLNode *n)
{
    ASSERT(n)

    while (n->link[1]) {
        n = n->link[1];
    }

    return n;
}

// Puts n (may be NULL) in the place of dest: fixes the link in dest's parent
// (or the root pointer) and n's parent pointer. With BAVL_COUNT, also adjusts
// the counts of all former ancestors of dest. dest itself is not modified.
static void _BAVL_replace_subtree (BAVL *tree, BAVLNode *dest, BAVLNode *n)
{
    ASSERT(dest)

    if (dest->parent) {
        dest->parent->link[dest == dest->parent->link[1]] = n;
    } else {
        tree->root = n;
    }
    if (n) {
        n->parent = dest->parent;
    }

#ifdef BAVL_COUNT
    // update counts
    for (BAVLNode *c = dest->parent; c; c = c->parent) {
        ASSERT(c->count >= dest->count)
        c->count -= dest->count;
        if (n) {
            ASSERT(n->count <= UINT64_MAX - c->count)
            c->count += n->count;
        }
    }
#endif
}

// Exchanges the positions of two distinct nodes of the tree, including their
// balance factors (and counts with BAVL_COUNT). The values of the nodes are
// not looked at.
static void _BAVL_swap_nodes (BAVL *tree, BAVLNode *n1, BAVLNode *n2)
{
    if (n2->parent == n1 || n1->parent == n2) {
        // when the nodes are directly connected we need special handling

        // make sure n1 is above n2
        if (n1->parent == n2) {
            BAVLNode *t = n1;
            n1 = n2;
            n2 = t;
        }

        uint8_t side = (n2 == n1->link[1]);
        BAVLNode *other = n1->link[!side]; // n1's child that is not n2

        // n1 takes over n2's children
        n1->link[0] = n2->link[0];
        if (n1->link[0]) {
            n1->link[0]->parent = n1;
        }
        n1->link[1] = n2->link[1];
        if (n1->link[1]) {
            n1->link[1]->parent = n1;
        }

        // n2 takes n1's place under n1's parent
        n2->parent = n1->parent;
        if (n2->parent) {
            n2->parent->link[n1 == n1->parent->link[1]] = n2;
        } else {
            tree->root = n2;
        }

        // n1 becomes n2's child on the side where n2 used to be
        n2->link[side] = n1;
        n1->parent = n2;

        // n2 adopts n1's other child
        n2->link[!side] = other;
        if (other) {
            other->parent = n2;
        }
    } else {
        // Capture both parents and the side each node hangs on BEFORE
        // touching any link. If the two nodes are siblings they share a
        // parent, and deciding the second side after the first link has been
        // rewritten would pick the wrong slot and leave the nodes unswapped.
        BAVLNode *parent1 = n1->parent;
        BAVLNode *parent2 = n2->parent;
        uint8_t side1 = (parent1 && n1 == parent1->link[1]);
        uint8_t side2 = (parent2 && n2 == parent2->link[1]);
        BAVLNode *temp;

        // swap parents
        n1->parent = parent2;
        if (parent2) {
            parent2->link[side2] = n1;
        } else {
            tree->root = n1;
        }
        n2->parent = parent1;
        if (parent1) {
            parent1->link[side1] = n2;
        } else {
            tree->root = n2;
        }

        // swap left children
        temp = n1->link[0];
        n1->link[0] = n2->link[0];
        if (n1->link[0]) {
            n1->link[0]->parent = n1;
        }
        n2->link[0] = temp;
        if (n2->link[0]) {
            n2->link[0]->parent = n2;
        }

        // swap right children
        temp = n1->link[1];
        n1->link[1] = n2->link[1];
        if (n1->link[1]) {
            n1->link[1]->parent = n1;
        }
        n2->link[1] = temp;
        if (n2->link[1]) {
            n2->link[1]->parent = n2;
        }
    }

    // swap balance factors
    int8_t saved_balance = n1->balance;
    n1->balance = n2->balance;
    n2->balance = saved_balance;

#ifdef BAVL_COUNT
    // swap counts
    uint64_t saved_count = n1->count;
    n1->count = n2->count;
    n2->count = saved_count;
#endif
}

// Restores the AVL balance after the subtree on the given side of node
// changed its height by deltac (-1, 0 or 1), rotating where needed, and
// continues towards the root while the height keeps changing.
static void _BAVL_rebalance (BAVL *o, BAVLNode *node, uint8_t side, int8_t deltac)
{
    ASSERT(side == 0 || side == 1)
    ASSERT(deltac >= -1 && deltac <= 1)
    ASSERT(node->balance >= -1 && node->balance <= 1)

    // if no subtree changed its height, no more rebalancing is needed
    if (deltac == 0) {
        return;
    }

    // calculate how much our height changed
    int8_t delta = BAVL_MAX(deltac, BAVL_OPTNEG(node->balance, side)) - BAVL_MAX(0, BAVL_OPTNEG(node->balance, side));
    ASSERT(delta >= -1 && delta <= 1)

    // update our balance factor
    node->balance -= BAVL_OPTNEG(deltac, side);

    BAVLNode *child;
    BAVLNode *gchild;

    // perform transformations if the balance factor is wrong
    if (node->balance == 2 || node->balance == -2) {
        uint8_t bside;  // the heavier side
        int8_t bsidef;  // sign of the heavier side: 1 for right, -1 for left
        if (node->balance == 2) {
            bside = 1;
            bsidef = 1;
        } else {
            bside = 0;
            bsidef = -1;
        }

        ASSERT(node->link[bside])
        child = node->link[bside];
        switch (child->balance * bsidef) {
            case 1:
                // child leans the same way: single rotation, height shrinks
                _BAVL_rotate(o, node, !bside);
                node->balance = 0;
                child->balance = 0;
                node = child;
                delta -= 1;
                break;
            case 0:
                // child is balanced: single rotation, height unchanged
                _BAVL_rotate(o, node, !bside);
                node->balance = 1 * bsidef;
                child->balance = -1 * bsidef;
                node = child;
                break;
            case -1:
                // child leans the other way: double rotation, height shrinks
                ASSERT(child->link[!bside])
                gchild = child->link[!bside];
                _BAVL_rotate(o, child, bside);
                _BAVL_rotate(o, node, !bside);
                node->balance = -BAVL_MAX(0, gchild->balance * bsidef) * bsidef;
                child->balance = BAVL_MAX(0, -gchild->balance * bsidef) * bsidef;
                gchild->balance = 0;
                node = gchild;
                delta -= 1;
                break;
            default:
                ASSERT(0);
        }
    }

    ASSERT(delta >= -1 && delta <= 1)
    // Transformations above preserve this. Proof:
    //     - if a child subtree gained 1 height and rebalancing was needed,
    //       it was the heavier subtree. Then delta was originally 1, because
    //       the heaviest subtree gained one height. If the transformation reduces
    //       delta by one, it becomes 0.
    //     - if a child subtree lost 1 height and rebalancing was needed, it
    //       was the lighter subtree. Then delta was originally 0, because
    //       the height of the heaviest subtree was unchanged. If the transformation
    //       reduces delta by one, it becomes -1.

    if (node->parent) {
        _BAVL_rebalance(o, node->parent, node == node->parent->link[1], delta);
    }
}

static void BAVL_Init (BAVL *o, int offset, BAVL_comparator comparator, void *user)
{
    o->offset = offset;
    o->comparator = comparator;
    o->user = user;
    o->root = NULL;

    BAVL_ASSERT(o)
}

static int BAVL_Insert (BAVL *o, BAVLNode *node, BAVLNode **ref)
{
    // insert to root?
    if (!o->root) {
        o->root = node;
        node->parent = NULL;
        node->link[0] = NULL;
        node->link[1] = NULL;
        node->balance = 0;
#ifdef BAVL_COUNT
        node->count = 1;
#endif

        BAVL_ASSERT(o)

        if (ref) {
            *ref = NULL;
        }
        return 1;
    }

    // find node to insert to
    BAVLNode *c = o->root;
    int side;
    while (1) {
        // compare
        int comp = _BAVL_compare_nodes(o, node, c);

        // have we found a node that compares equal?
        if (comp == 0) {
            if (ref) {
                *ref = c;
            }
            return 0;
        }

        side = (comp == 1);

        // have we reached a leaf?
        if (!c->link[side]) {
            break;
        }

        c = c->link[side];
    }

    // insert
    c->link[side] = node;
    node->parent = c;
    node->link[0] = NULL;
    node->link[1] = NULL;
    node->balance = 0;
#ifdef BAVL_COUNT
    node->count = 1;
#endif

#ifdef BAVL_COUNT
    // update counts
    for (BAVLNode *p = c; p; p = p->parent) {
        ASSERT(p->count < UINT64_MAX)
        p->count++;
    }
#endif

    // rebalance
    _BAVL_rebalance(o, c, side, 1);

    BAVL_ASSERT(o)

    if (ref) {
        *ref = c;
    }
    return 1;
}

static void BAVL_Remove (BAVL *o, BAVLNode *node)
{
    // if we have both subtrees, swap the node and the largest node
    // in the left subtree, so we have at most one subtree
    if (node->link[0] && node->link[1]) {
        // find the largest node in the left subtree
        BAVLNode *max = _BAVL_subtree_max(node->link[0]);
        // swap the nodes
        _BAVL_swap_nodes(o, node, max);
    }

    // have at most one child now
    ASSERT(!(node->link[0] && node->link[1]))

    BAVLNode *parent = node->parent;
    BAVLNode *child = (node->link[0] ? node->link[0] : node->link[1]);

    if (parent) {
        // remember on which side node is
        int side = (node == parent->link[1]);
        // replace node with child
        _BAVL_replace_subtree(o, node, child);
        // rebalance
        _BAVL_rebalance(o, parent, side, -1);
    } else {
        // replace node with child
        _BAVL_replace_subtree(o, node, child);
    }

    BAVL_ASSERT(o)
}

static int BAVL_IsEmpty (const BAVL *o)
{
    return (!o->root);
}

static BAVLNode * BAVL_Lookup (const BAVL *o, void *val)
{
    if (!o->root) {
        return NULL;
    }

    BAVLNode *c = o->root;
    while (1) {
        // compare
        int comp = _BAVL_compare_values(o, val, _BAVL_node_value(o, c));

        // have we found a node that compares equal?
        if (comp == 0) {
            return c;
        }

        int side = (comp == 1);

        // have we reached a leaf? then return the last node visited
        if (!c->link[side]) {
            return c;
        }

        c = c->link[side];
    }
}

static BAVLNode * BAVL_LookupExact (const BAVL *o, void *val)
{
    if (!o->root) {
        return NULL;
    }

    BAVLNode *c = o->root;
    while (1) {
        // compare
        int comp = _BAVL_compare_values(o, val, _BAVL_node_value(o, c));

        // have we found a node that compares equal?
        if (comp == 0) {
            return c;
        }

        int side = (comp == 1);

        // have we reached a leaf? then there is no equal node
        if (!c->link[side]) {
            return NULL;
        }

        c = c->link[side];
    }
}

static BAVLNode * BAVL_GetFirst (const BAVL *o)
{
    if (!o->root) {
        return NULL;
    }

    BAVLNode *n = o->root;
    while (n->link[0]) {
        n = n->link[0];
    }

    return n;
}

static BAVLNode * BAVL_GetLast (const BAVL *o)
{
    if (!o->root) {
        return NULL;
    }

    BAVLNode *n = o->root;
    while (n->link[1]) {
        n = n->link[1];
    }

    return n;
}

static BAVLNode * BAVL_GetNext (const BAVL *o, BAVLNode *n)
{
    (void)o; // the tree itself is not needed, only the node links

    if (n->link[1]) {
        // leftmost node of the right subtree
        n = n->link[1];
        while (n->link[0]) {
            n = n->link[0];
        }
    } else {
        // climb while we are a right child; the parent reached next is the successor
        while (n->parent && n == n->parent->link[1]) {
            n = n->parent;
        }
        n = n->parent;
    }

    return n;
}

static BAVLNode * BAVL_GetPrev (const BAVL *o, BAVLNode *n)
{
    (void)o; // the tree itself is not needed, only the node links

    if (n->link[0]) {
        // rightmost node of the left subtree
        n = n->link[0];
        while (n->link[1]) {
            n = n->link[1];
        }
    } else {
        // climb while we are a left child; the parent reached next is the predecessor
        while (n->parent && n == n->parent->link[0]) {
            n = n->parent;
        }
        n = n->parent;
    }

    return n;
}

#ifdef BAVL_COUNT

static uint64_t BAVL_Count (const BAVL *o)
{
    return (o->root ? o->root->count : 0);
}

static uint64_t BAVL_IndexOf (const BAVL *o, BAVLNode *n)
{
    (void)o; // the tree itself is not needed, only the node links

    // everything in our left subtree precedes us
    uint64_t index = (n->link[0] ? n->link[0]->count : 0);

    for (BAVLNode *c = n; c->parent; c = c->parent) {
        if (c == c->parent->link[1]) {
            // the parent and its left subtree precede us
            ASSERT(c->parent->count > c->count)
            ASSERT(c->parent->count - c->count <= UINT64_MAX - index)
            index += c->parent->count - c->count;
        }
    }

    return index;
}

static BAVLNode * BAVL_GetAt (const BAVL *o, uint64_t index)
{
    if (index >= BAVL_Count(o)) {
        return NULL;
    }

    BAVLNode *c = o->root;

    while (1) {
        ASSERT(c)
        ASSERT(index < c->count)

        uint64_t left_count = (c->link[0] ? c->link[0]->count : 0);

        if (index == left_count) {
            return c;
        }

        if (index < left_count) {
            c = c->link[0];
        } else {
            c = c->link[1];
            // skip the left subtree and the node itself
            index -= left_count + 1;
        }
    }
}

#endif

static void BAVL_Verify (BAVL *o)
{
    if (o->root) {
        ASSERT(!o->root->parent)
        _BAVL_assert_recurser(o, o->root);
    }
}

#endif