とりあえず形になってきたが……

「勉強用にO/R Mapperを自作中 - MasaHeroの日記」の続き。
メタクラスなどを使えばもう少し使いやすくなるのではないかと思い、現時点までに出来ているものは廃棄することにした。もったいないので、一応記念としてここに残しておく。

#!/usr/bin/python
# -*- coding: utf8 -*-

import sqlite3

class DBField(object):
    """Base class describing one database column.

    A field does not know its column name when it is constructed; the name
    is supplied later through set_name(). Until then get_sql() yields an
    empty string and get_indexsql() yields None.
    """

    def __init__(self, unique=False, null=False, db_index=False, rel=None,
            default=None, type="blob"):
        """Store the column options.

        unique   -- emit a ``unique`` constraint.
        null     -- when False (the default) emit ``not null``.
        db_index -- when True, get_indexsql() produces a ``create index``
                    statement for this column.
        rel      -- stored as given; this class does not use it.
        default  -- SQL fragment placed after ``default``. It is inserted
                    verbatim, so string literals must already be quoted
                    (e.g. ``"'abc'"``). None or an empty string means
                    "no default clause".
        type     -- SQL column type.
        """
        self.unique = unique
        self.null = null
        self.db_index = db_index
        self.rel = rel
        self.default = default
        self.name = None
        self.type = type

    def set_name(self, name):
        """Set the column name."""
        self.name = name

    def get_sql(self):
        """Return the column definition, or "" while the field is unnamed."""
        if not self.name:
            return ""
        parts = [self.name, self.type]
        if self.unique:
            parts.append("unique")
        if not self.null:
            parts.append("not null")
        default = self.default
        # Test against None/"" rather than truthiness so that a default of
        # 0 is not silently dropped.
        if default is not None and default != "":
            if isinstance(default, bool):
                # 1/0 is accepted by every SQLite version, TRUE/FALSE is not.
                default = int(default)
            parts.append("default")
            # Format explicitly: str.join() rejects non-string items.
            parts.append("%s" % (default,))
        return " ".join(parts)

    def get_indexsql(self):
        """Return a ``create index`` template, or None if not applicable.

        The result contains a single ``%s`` placeholder that the caller
        must fill with the table name. None is returned when the field is
        not indexed or has not been named yet.
        """
        if not self.db_index or not self.name:
            return None
        # NOTE: the index name is derived from the column name only, so two
        # tables with an identically named indexed column share one index
        # name; with "if not exists" the second statement is a no-op.
        return "create index if not exists %s_index on %%s ( %s );" % (
            self.name, self.name)

class DBTable(object):
    """
    Minimal O/R mapper hard-wired to sqlite3.

    Subclass it and declare DBField instances as class attributes. The
    table name is the subclass name, and every table gets an
    autoincrementing integer ``id`` primary key.

    Class attributes:
    connection : connection to the database
    """

    connection = sqlite3.connect(":memory:")  # in-memory database as a sample

    @classmethod
    def get_connection(cls):
        """Return the database connection used by this class."""
        return cls.connection

    # The accessor was originally defined under this misspelled name; keep
    # it available for any caller that used that spelling.
    get_conection = get_connection

    @classmethod
    def set_connection(cls, con):
        """Replace the database connection used by this class."""
        cls.connection = con

    @classmethod
    def get_rownames(cls):
        """Return the names of the DBField class attributes, sorted.

        Fields that have no column name yet are given their attribute name
        as a side effect. Sorting keeps the column order stable between
        calls and between runs.
        """
        candidates = set(dir(cls)) ^ set(dir(DBTable))
        names = []
        for attr in sorted(candidates):
            field = getattr(cls, attr)
            if isinstance(field, DBField):
                names.append(attr)
                if not field.name:
                    field.set_name(attr)
        return names

    @classmethod
    def get_table_sql(cls):
        """Return the ``create table if not exists`` statement."""
        columns = ["id INTEGER PRIMARY KEY AUTOINCREMENT"]
        for name in cls.get_rownames():
            columns.append(getattr(cls, name).get_sql())
        # Joining avoids the trailing comma that SQLite rejects.
        return "create table if not exists %s (\n\t%s\n);" % (
            cls.__name__, ",\n\t".join(columns))

    @classmethod
    def make_table_index_sql(cls):
        """Create an index for every field that asks for one, then commit."""
        con = cls.get_connection()
        cur = con.cursor()
        try:
            for name in cls.get_rownames():
                template = getattr(cls, name).get_indexsql()
                if template:
                    # The template has one %s placeholder for the table name.
                    cur.execute(template % (cls.__name__,))
        finally:
            cur.close()
        con.commit()

    @classmethod
    def create_table(cls):
        """Create the table if it does not exist yet, then commit."""
        con = cls.get_connection()
        cur = con.cursor()
        try:
            cur.execute(cls.get_table_sql())
        finally:
            cur.close()
        con.commit()

    def __init__(self, **kw):
        """Build an instance, optionally loading the row with the given id.

        Keyword arguments:
        id -- when present, the row with this id is read and its column
              values become instance attributes. If the table does not
              exist yet it is created; if the row does not exist nothing
              is loaded.
        Any other keyword matching a field name is assigned afterwards and
        overrides the loaded value. Unknown keywords are ignored.
        """
        rownames = self.get_rownames()
        # Always define self.id so that save() and __repr__ can rely on it.
        self.id = kw.get("id")
        if "id" in kw:
            # "id" is selected first so the statement stays valid even for
            # a table without any declared field.
            sql = "select %s from %s where id=?" % (
                ", ".join(["id"] + rownames), self.__class__.__name__)
            row = None
            try:
                cur = self.get_connection().cursor()
                try:
                    cur.execute(sql, (self.id,))
                    row = cur.fetchone()
                finally:
                    cur.close()
            except sqlite3.OperationalError:
                # Most likely the table does not exist yet; creating it is
                # harmless otherwise because of "if not exists".
                self.create_table()
            if row is not None:
                for name, value in zip(rownames, row[1:]):
                    setattr(self, name, value)
        for name in rownames:
            if name in kw:
                setattr(self, name, kw[name])

    def _assigned(self):
        """Return (column, value) pairs for fields set on this instance.

        A field that was never assigned still resolves to the DBField
        declared on the class; such fields are left out so that they are
        not bound as SQL parameters.
        """
        pairs = []
        for name in self.get_rownames():
            value = getattr(self, name)
            if not isinstance(value, DBField):
                pairs.append((name, value))
        return pairs

    def _row_exists(self):
        """Return True if a row with this instance's id is stored."""
        sql = "select id from %s where id=?" % (self.__class__.__name__,)
        cur = self.get_connection().cursor()
        try:
            cur.execute(sql, (self.id,))
            return cur.fetchone() is not None
        finally:
            cur.close()

    def save(self):
        """Write the instance to the database.

        Updates the row when the instance has an id that exists in the
        table, otherwise inserts a new row (which assigns a new id). Does
        nothing for a class without any declared field.
        """
        if not self.get_rownames():
            return
        if self.id is not None and self._row_exists():
            self.__update()
        else:
            self.__insert()

    def __update(self):
        """Update the stored row from the assigned field values, then commit."""
        pairs = self._assigned()
        if not pairs:
            return
        assignments = ", ".join(["%s=?" % name for name, _ in pairs])
        sql = "update %s set %s where id=?" % (
            self.__class__.__name__, assignments)
        params = [value for _, value in pairs] + [self.id]
        con = self.get_connection()
        cur = con.cursor()
        try:
            cur.execute(sql, params)
        finally:
            cur.close()
        con.commit()

    def __insert(self):
        """Insert a new row, store its id on the instance, then commit."""
        tablename = self.__class__.__name__
        pairs = self._assigned()
        if pairs:
            sql = "insert into %s (%s) values (%s)" % (
                tablename,
                ", ".join([name for name, _ in pairs]),
                ", ".join(["?"] * len(pairs)))
        else:
            # Nothing assigned: let the database fill in column defaults.
            sql = "insert into %s default values" % (tablename,)
        con = self.get_connection()
        cur = con.cursor()
        try:
            cur.execute(sql, [value for _, value in pairs])
            # lastrowid is the id of the row just inserted through this
            # cursor, which "select max(id)" does not guarantee.
            self.id = cur.lastrowid
        finally:
            cur.close()
        con.commit()

    def __repr__(self):
        """Return ``<ClassName: id>`` or ``<ClassName: not save>``."""
        row_id = getattr(self, "id", None)
        if row_id:
            return "<%s: %s>" % (self.__class__.__name__, row_id)
        return "<%s: not save>" % (self.__class__.__name__,)