The BSTree data structure

Agenda

  • API
  • Implementation
    • Search
    • Addition
    • Removal
    • Iteration / Traversal

API

In [82]:
class BSTree:
    """Binary search tree API.

    In this cell `add`, `__contains__`, `__delitem__` and `__iter__` are only
    placeholders (their bodies are `pass`); `__len__`, `pprint` and `height`
    are already functional.
    """

    class Node:
        """One tree node: a value plus links to its left and right subtrees."""
        def __init__(self, val, left=None, right=None):
            self.val = val
            self.left = left    # subtree holding smaller values (or None)
            self.right = right  # subtree holding larger values (or None)

    def __init__(self):
        self.size = 0     # number of stored values, reported by len()
        self.root = None  # None represents the empty tree

    def add(self, val):
        """Adds `val` to this tree while maintaining BSTree properties."""
        assert(val not in self)
        pass

    def __contains__(self, val):
        """Returns `True` if val is in this tree and `False` otherwise."""
        pass

    def __delitem__(self, val):
        """Removes `val` from this tree while maintaining BSTree properties."""
        assert(val in self)
        pass

    def __iter__(self):
        """Returns an iterator over all the values in the tree, in ascending order."""
        pass

    def __len__(self):
        """Returns the number of values stored in the tree."""
        return self.size

    def pprint(self, width=64):
        """Attempts to pretty-print this tree's contents.

        Nodes are printed breadth-first, one tree level per line. Each value is
        centred in a column of `width // 2**level` characters, so a child sits
        roughly beneath its parent. Missing children are drawn as '-' down to
        the deepest level so the columns of lower rows stay aligned.

        Args:
            width: character width of the root row.
        """
        from collections import deque

        height = self.height()
        # FIFO queue of (node, level). deque.popleft() is O(1); list.pop(0)
        # shifts the whole list on every call and made printing quadratic.
        pending = deque([(self.root, 0)])
        prev_level = 0
        repr_str = ''
        while pending:
            n, level = pending.popleft()
            if prev_level != level:
                # first node of a new level starts a new output line
                prev_level = level
                repr_str += '\n'
            if not n:
                # keep placeholder children so deeper rows line up
                if level < height-1:
                    pending.extend([(None, level+1), (None, level+1)])
                repr_str += '{val:^{width}}'.format(val='-', width=width//2**level)
            else:
                if n.left or level < height-1:
                    pending.append((n.left, level+1))
                if n.right or level < height-1:
                    pending.append((n.right, level+1))
                repr_str += '{val:^{width}}'.format(val=n.val, width=width//2**level)
        print(repr_str)

    def height(self):
        """Returns the height of the longest branch of the tree.

        Height is counted in nodes: an empty tree has height 0 and a lone
        root has height 1.
        """
        def height_rec(t):
            if not t:
                return 0
            return 1 + max(height_rec(t.left), height_rec(t.right))
        return height_rec(self.root)
In [79]:
# Build a three-node tree by hand (add() is only a stub at this point):
#       5
#      / \
#     2   10
t = BSTree()
t.root = BSTree.Node(5, BSTree.Node(2), BSTree.Node(10))
t.size = 3  # set manually because the nodes were linked directly, not via add()
In [80]:
# Print the hand-built tree, one level per line.
t.pprint()
                               5                                
               2                               10               
In [81]:
# Height counts nodes on the longest root-to-leaf path, so this two-level tree has height 2.
t.height()
Out[81]:
2

Implementation

In [97]:
class BSTree(BSTree):
    def __contains__(self, val):
        """Returns `True` if `val` is stored in this tree, `False` otherwise.

        Follows a single root-to-leaf path: larger values live in the right
        subtree and smaller values in the left one, so one comparison per
        level decides where to continue.
        """
        node = self.root
        while node:
            if node.val == val:
                return True
            node = node.right if node.val < val else node.left
        return False
In [100]:
# find() copied from the previous notebook. It performs the same walk as
# BSTree.__contains__ but starts from an explicit node, so `find(t.root, x)`
# gives the same answer as `x in t`.

def find(t, x):
    """Returns True if `x` occurs in the BST rooted at node `t`, else False."""
    while t:
        if t.val == x:
            return True
        if t.val < x:
            t = t.right
        else:  # t.val > x
            t = t.left
    return False
In [90]:
# Rebuild the fixture: the `t` created earlier is an instance of the previous
# BSTree class, so it would still use the placeholder __contains__.
t = BSTree()
t.root = BSTree.Node(5,
                    left=BSTree.Node(2),
                    right=BSTree.Node(10))
t.size = 3
In [92]:
# `in` dispatches to BSTree.__contains__
10 in t
Out[92]:
True

Addition

In [101]:
class BSTree(BSTree):
    def add(self, val):
        """Adds `val` to this tree while maintaining BSTree properties.

        The new node is placed at the empty spot reached by the usual BST
        search. `val` must not already be in the tree (checked by assert);
        the size counter only grows when a node is actually created.

        Args:
            val: value to insert; must be comparable with the stored values.
        """
        inserted = False

        def add_rec(t):
            # Returns the (possibly new) root of subtree `t` after insertion,
            # so each caller can re-link its child pointer along the path.
            nonlocal inserted
            if not t:
                inserted = True
                return BSTree.Node(val)  # empty spot in the tree
            elif t.val == val:  # already present: leave the subtree unchanged
                return t
            elif t.val < val:
                t.right = add_rec(t.right)
                return t
            else:  # t.val > val
                t.left = add_rec(t.left)
                return t

        assert(val not in self)
        self.root = add_rec(self.root)  # return the root of the tree
        # Count only real insertions, so a duplicate cannot inflate len()
        # when asserts are disabled (python -O).
        if inserted:
            self.size += 1
In [102]:
# Start from an empty tree, so add() has to create the root node.
t = BSTree()
t.add(10)
In [103]:
# The tree now holds only the root, 10.
t.pprint()
                               10                               
In [104]:
# 5 < 10, so 5 becomes the root's left child (the empty right slot prints as '-').
t.add(5)
t.pprint()
                               10                               
               5                               -                
In [105]:
# 15 > 10, so 15 fills the root's right slot.
t.add(15)
t.pprint()
                               10                               
               5                               15               

[Note] Insertion must handle two cases: placing the first node at the root, or placing a new node at a leaf.

That is why add_rec() returns a node: the new node is created at an empty spot (possibly the root), and each parent along the search path is then “re-connected” to the subtree returned by the recursive call.

In [132]:
# RUN MULTIPLE TIMES, NOTE DIFFERENT BSTs
# (same values, different insertion order -> different tree shapes)

import random
t = BSTree()
vals = list(range(5))
random.shuffle(vals)
for x in vals:
    t.add(x)
t.pprint()
                               0                                
               -                               1                
       -               -               -               3        
   -       -       -       -       -       -       2       4    
In [133]:
import random
t = BSTree()
vals = list(range(1, 10, 2)) # odd numbers from 1 to 9
random.shuffle(vals)
for x in vals:
    t.add(x)

t.pprint()

# unit tests: every inserted (odd) value is found, no even value is,
# and the size counter matches the number of insertions
assert(all(x in t for x in range(1, 10, 2)))
assert(all(x not in t for x in range(0, 12, 2)))
assert(len(t) == len(vals))
                               5                                
               1                               9                
       -               3               7               -        

Removal

In [144]:
class BSTree(BSTree):
    def __delitem__(self, val):
        """Removes `val` from this tree while maintaining BSTree properties.

        First version: only nodes with zero or one child can be removed. A
        node with two children raises NotImplementedError and the tree is
        left unchanged.

        Raises:
            AssertionError: if `val` is not in the tree (when asserts are on).
            NotImplementedError: if the node holding `val` has two children.
        """
        removed = False

        def del_rec(t):
            # Returns the new root of subtree `t` with `val` removed, so the
            # caller can re-link its child pointer.
            nonlocal removed
            if not t: #empty spot in the tree, nothing to do
                return None
            elif t.val < val:
                t.right = del_rec(t.right)
                return t
            elif t.val > val:
                t.left = del_rec(t.left)
                return t
            else: #t.val == val
                if t.left and t.right:
                    # Returning None here would silently discard both
                    # subtrees. Raising happens before any parent link is
                    # reassigned, so the tree stays intact.
                    raise NotImplementedError(
                        'removing a node with two children is not supported yet')
                removed = True
                if not t.left and not t.right:
                    return None
                elif not t.left:
                    return t.right  # the only child takes t's place
                else:
                    return t.left   # the only child takes t's place

        assert(val in self)
        #deal with relatively simple case first!
        self.root = del_rec(self.root)
        if removed:
            self.size -= 1  # keep len() in sync with the node count
In [147]:
# Fixture: 10 at the root with children 5 and 15; 5 has only a left child (2)
# and 15 has only a right child (17). Only the root has two children.
t = BSTree()
for x in [10, 5, 15, 2, 17]:
    t.add(x)
t.pprint()
                               10                               
               5                               15               
       2               -               -               17       
In [148]:
# 2 is a leaf: del_rec returns None in its place.
del t[2]
t.pprint()
                               10                               
               5                               15               
       -               -               -               17       
In [149]:
# 5 has only a left child (2), which takes its place.
t = BSTree()
for x in [10, 5, 15, 2, 17]:
    t.add(x)
del t[5]
t.pprint()
                               10                               
               2                               15               
       -               -               -               17       
In [150]:
# 15 has only a right child (17), which takes its place.
t = BSTree()
for x in [10, 5, 15, 2, 17]:
    t.add(x)
del t[15]
t.pprint()
                               10                               
               5                               17               
       2               -               -               -        
In [174]:
class BSTree(BSTree):
    def __delitem__(self, val):
        """Removes `val` from this tree while maintaining BSTree properties.

        A node with two children is not unlinked. Its value is overwritten
        with its in-order predecessor (the largest value in its left subtree),
        and that predecessor node, which has no right child, is unlinked.

        Raises:
            AssertionError: if `val` is not in the tree (when asserts are on).
        """
        removed = False

        def del_rec(t):
            # Returns the new root of subtree `t` with `val` removed, so the
            # caller can re-link its child pointer.
            nonlocal removed
            if not t: # empty spot in tree, nothing to delete
                return None
            elif t.val < val:
                t.right = del_rec(t.right)
                return t
            elif t.val > val:
                t.left = del_rec(t.left)
                return t
            # t.val == val
            removed = True
            if not t.left and not t.right:
                return None
            elif not t.left:
                return t.right
            elif not t.right:
                return t.left
            # value to remove has both children
            n = t.left  # predecessor candidate
            if not n.right:
                # the left child itself is the predecessor
                t.left = n.left
                t.val = n.val
                return t
            # advance until n.right is the rightmost node of the left subtree
            while n.right.right:
                n = n.right
            t.val = n.right.val
            n.right = n.right.left  # splice the predecessor out
            return t

        assert(val in self)
        self.root = del_rec(self.root)
        if removed:
            self.size -= 1  # keep len() in sync with the node count
In [175]:
# Fixture covering every removal case: 10, 5 and 15 each have two children.
t = BSTree()
for x in [10, 5, 2, 7, 9, 8, 1, 15, 12, 18]:
    t.add(x)
t.pprint()
                               10                               
               5                               15               
       2               7               12              18       
   1       -       -       9       -       -       -       -    
 -   -   -   -   -   -   8   -   -   -   -   -   -   -   -   -  
In [176]:
# 15's left child 12 has no right child, so 12 is the predecessor and moves up.
del t[15]
t.pprint()
                               10                               
               5                               12               
       2               7               -               18       
   1       -       -       9       -       -       -       -    
 -   -   -   -   -   -   8   -   -   -   -   -   -   -   -   -  
In [177]:
# Rebuild the fixture, since the previous cell mutated `t`.
t = BSTree()
for x in [10, 5, 2, 7, 9, 8, 1, 15, 12, 18]:
    t.add(x)
t.pprint()
                               10                               
               5                               15               
       2               7               12              18       
   1       -       -       9       -       -       -       -    
 -   -   -   -   -   -   8   -   -   -   -   -   -   -   -   -  
In [178]:
# 5's left child 2 has no right child: 2 moves up and its left child 1 takes its old place.
del t[5]
t.pprint()
                               10                               
               2                               15               
       1               7               12              18       
   -       -       -       9       -       -       -       -    
 -   -   -   -   -   -   8   -   -   -   -   -   -   -   -   -  
In [179]:
# Rebuild the fixture, since the previous cell mutated `t`.
t = BSTree()
for x in [10, 5, 2, 7, 9, 8, 1, 15, 12, 18]:
    t.add(x)
t.pprint()
                               10                               
               5                               15               
       2               7               12              18       
   1       -       -       9       -       -       -       -    
 -   -   -   -   -   -   8   -   -   -   -   -   -   -   -   -  
In [180]:
# 10's predecessor is 9 (rightmost node of its left subtree); 9's left child 8 is re-attached under 7.
del t[10]
t.pprint()
                               9                                
               5                               15               
       2               7               12              18       
   1       -       -       8       -       -       -       -    

Iteration / Traversal

In [157]:
#Version 1
class BSTree(BSTree):
    def __iter__(self):
        """Returns an iterator over all the values in the tree, in ascending order.

        In-order traversal: everything in the left subtree is smaller than the
        node's value and everything in the right subtree is larger, so visiting
        left subtree, node, right subtree yields the values sorted.
        """
        def iter_rec(t):
            if t:
                for x in iter_rec(t.left):
                    yield x # 're-yielding' of values up the recursion

                yield t.val

                for x in iter_rec(t.right):
                    yield x

        for x in iter_rec(self.root):
            yield x
In [189]:
#Version 2
class BSTree(BSTree):
    def __iter__(self):
        """Yields the tree's values in ascending order (in-order traversal).

        `yield from` delegates to the nested generator, doing the same job as
        an explicit `for x in ...: yield x` loop.
        """
        def walk(node):
            if not node:
                return
            yield from walk(node.left)   # smaller values first
            yield node.val
            yield from walk(node.right)  # then larger values

        yield from walk(self.root)
In [192]:
# Insert 0..9 in random order; the tree shape varies between runs.
import random
t = BSTree()
vals = list(range(10))
random.shuffle(vals)
for x in vals:
    t.add(x)
t.pprint()
                               4                                
               1                               5                
       0               2               -               9        
   -       -       -       3       -       -       8       -    
 -   -   -   -   -   -   -   -   -   -   -   -   7   -   -   -  
- - - - - - - - - - - - - - - - - - - - - - - - 6 - - - - - - - 
In [193]:
# Iteration uses __iter__, which walks the tree in order, so values come out ascending.
for x in t:
    print(x)
0
1
2
3
4
5
6
7
8
9
In [163]:
class BSTree(BSTree):
    def __iter__(self):
        """Returns an iterator over all the values in the tree, in ascending order.

        Iterative in-order traversal with an explicit stack. The recursive
        `yield from` version passes every value up through one generator frame
        per tree level. This one yields each value in O(1) amortized time and
        never builds a deep chain of nested generators, so a degenerate
        (list-shaped) tree does not run into Python's recursion limit.
        """
        stack = []
        node = self.root
        while stack or node:
            # descend as far left as possible, remembering the path back up
            while node:
                stack.append(node)
                node = node.left
            node = stack.pop()
            yield node.val
            # the next larger value is in this node's right subtree (if any)
            node = node.right
In [194]:
class BSTree(BSTree):
    def printt(self):
        """Prints every value in the tree, one per line, in ascending order.

        Performs the in-order walk with an explicit stack instead of recursion.
        """
        pending = []
        current = self.root
        while current or pending:
            if current:
                pending.append(current)
                current = current.left   # left subtree is printed first
            else:
                current = pending.pop()
                print(current.val)
                current = current.right  # then the right subtree
In [195]:
# Fresh random insertion order, this time to exercise printt().
import random
t = BSTree()
vals = list(range(10))
random.shuffle(vals)
for x in vals:
    t.add(x)
t.pprint()
                               1                                
               0                               5                
       -               -               4               6        
   -       -       -       -       3       -       -       9    
 -   -   -   -   -   -   -   -   2   -   -   -   -   -   7   -  
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - 8 - - 
In [196]:
# printt() prints the same in-order sequence directly, without building a generator.
t.printt()
0
1
2
3
4
5
6
7
8
9