Bツリー

今までこれといってデータ構造について勉強してこなかったので勉強ついでに書いてみた。

#!/usr/bin/python
# -*- coding: utf8 -*-

class BTree:
    """Binary Tree"""
    def __init__(self,index=None,value=None):
        self.gt=None
        self.lt=None
        self.index=index
        self.value=value
    def add(self,index,value):
        """ノードの追加"""
        if not self.index:
            self.index=index
            self.value=value
        else:
            if index > self.index:
                if self.gt:
                    self.gt.add(index,value)
                else:
                    self.gt=BTree(index,value)
            elif index < self.index:
                if self.lt:
                    self.lt.add(index,value)
                else:
                    self.lt=BTree(index,value)
            elif index == self.index:
                self.value=value
            else:
                 raise Exception()
    def get_depth(self):
        """深さの計測"""
        ret=1
        if self.gt:
            if self.lt:
                if self.gt.get_depth() > self.lt.get_depth():
                    ret=ret+self.gt.get_depth()
                else:
                    ret=ret+self.lt.get_depth()
            else:
                ret=ret+self.gt.get_depth()
        elif self.lt:
            ret=ret+self.lt.get_depth()
        return ret
    def get_value(self,index):
        """indexに対応するvalueの探索"""
        if index == self.index:
            return self.value
        elif index > self.index and self.gt :
            return self.gt.get_value(index)
        elif index < self.index and self.lt:
            return self.lt.get_value(index)
        else:
            return None
        return None
    def optimize(self):
        """ツリーの最適化"""
        len_gt=0
        len_lt=0
        ret=self
        if self.gt:
            self.gt=self.gt.optimize()
            len_gt=self.gt.get_depth()
        if self.lt:
            self.lt=self.lt.optimize()
            len_lt=self.lt.get_depth()
        if len_gt-len_lt>=2:
            ret=self.gt
            self.gt=None
            ret.relocate(self)
            ret.lt=ret.lt.optimize()
            ret=ret.optimize()
        elif len_lt-len_gt>=2:
            ret=self.lt
            self.lt=None
            ret.relocate(self)
            ret.gt=ret.gt.optimize()
            ret=ret.optimize()
        return ret
    def relocate(self,node):
        """ノードを接ぎ木"""
        if self.index < node.index:
            if self.gt:
                self.gt.relocate(node)
            else:
                self.gt=node
        elif self.index > node.index:
            if self.lt:
                self.lt.relocate(node)
            else:
                self.lt=node
        else:
            raise Exception()
    def __str__(self):
        return "%s(%s,%s)"%(self.index,self.gt,self.lt) 

Pythonインタープリターの制限か、大量にデータを挿入するとこれ以上再帰できないとエラー*1が出る。まあ勉強用なので問題なし。

*1:RuntimeError: maximum recursion depth exceeded