勉強用にO/R Mapperを自作中 - MasaHeroの日記の続き
メタクラスなどを使えばもうちょっと使いやすくなるのではと思い、現時点まで出来ているものを廃棄。もったいないので、一応記念にここに残しておく。
#!/usr/bin/python
# -*- coding: utf8 -*-
"""A small, study-purpose O/R mapper hard-wired to sqlite3."""
import contextlib
import sqlite3


class DBField:
    """Base class describing one column of a DBTable subclass.

    The column name is not passed to the constructor; DBTable.get_rownames()
    assigns it from the class attribute name the field is bound to.
    """

    def __init__(self, unique=False, null=False, db_index=False, rel=None,
                 default=None, type="blob"):
        """Store the column options.

        unique   -- emit a UNIQUE constraint.
        null     -- when False (the default) emit NOT NULL.
        db_index -- when True, get_indexsql() returns a CREATE INDEX statement.
        rel      -- stored only; nothing in this file reads it.
        default  -- emitted verbatim (via str()) after DEFAULT, so string
                    literals must carry their own SQL quotes.
        type     -- SQL type name of the column ("blob" by default).
        """
        self.unique = unique
        self.null = null
        self.db_index = db_index
        self.rel = rel
        self.default = default
        self.name = None
        self.type = type

    def set_name(self, name):
        """Set the column name."""
        self.name = name

    def get_sql(self):
        """Return the column definition for CREATE TABLE, or "" if unnamed."""
        if not self.name:
            return ""
        parts = [self.name, self.type]
        if self.unique:
            parts.append("unique")
        if not self.null:
            parts.append("not null")
        # `is not None` so falsy defaults such as 0 are still emitted; str()
        # so non-string defaults do not break the join.
        if self.default is not None:
            parts.extend(["default", str(self.default)])
        return " ".join(parts)

    def get_indexsql(self, table=None):
        """Return a CREATE INDEX statement for this column, or None.

        Returns None when db_index is False or the field has no name yet.
        Without *table* the result keeps a single "%s" placeholder for the
        table name (the original contract).  With *table* the statement is
        complete and the index name is prefixed with the table name, because
        SQLite index names share one namespace across all tables.
        """
        if not self.db_index or not self.name:
            return None
        if table is None:
            return "create index if not exists %s_index on %%s ( %s );" % (
                self.name, self.name)
        return "create index if not exists %s_%s_index on %s ( %s );" % (
            table, self.name, table, self.name)


class DBTable:
    """O/R mapper base class hard-wired to sqlite3.

    Subclass it and declare columns as DBField class attributes.  The table
    name is the subclass name, and every table gets an integer primary key
    column ``id``.

    Class attributes:
        connection -- sqlite3 connection used by this class and by any
                      subclass that has not set its own.
    """

    connection = sqlite3.connect(":memory:")  # in-memory database as a sample

    @classmethod
    def get_connection(cls):
        """Return the connection used by this table class."""
        return cls.connection

    @classmethod
    def get_conection(cls):
        """Misspelled alias of get_connection(), kept for existing callers."""
        return cls.get_connection()

    @classmethod
    def set_connection(cls, con):
        """Replace the connection used by this class."""
        cls.connection = con

    @classmethod
    def get_rownames(cls):
        """Return the sorted names of the DBField attributes declared on cls.

        Side effect: every field that has no name yet is named after the
        attribute it is bound to.  Sorting keeps column order stable between
        calls, so column lists and value lists built separately line up.
        """
        # Attributes the subclass adds on top of DBTable itself.
        candidates = set(dir(cls)) - set(dir(DBTable))
        names = []
        for attr in sorted(candidates):
            field = getattr(cls, attr)
            if isinstance(field, DBField):
                if not field.name:
                    field.set_name(attr)
                names.append(attr)
        return names

    @classmethod
    def get_table_sql(cls):
        """Return the CREATE TABLE IF NOT EXISTS statement for cls."""
        # AUTOINCREMENT is only accepted on a column typed exactly INTEGER.
        columns = ["id INTEGER PRIMARY KEY AUTOINCREMENT"]
        columns.extend(getattr(cls, name).get_sql() for name in cls.get_rownames())
        return "create table if not exists %s (\n\t%s\n);" % (
            cls.__name__, ",\n\t".join(columns))

    @classmethod
    def make_table_index_sql(cls):
        """Create (if missing) an index for every field with db_index=True."""
        con = cls.get_connection()
        with contextlib.closing(con.cursor()) as cur:
            for name in cls.get_rownames():
                sql = getattr(cls, name).get_indexsql(cls.__name__)
                if sql:
                    cur.execute(sql)
        con.commit()

    @classmethod
    def create_table(cls):
        """Create the table for cls if it does not exist yet."""
        con = cls.get_connection()
        with contextlib.closing(con.cursor()) as cur:
            cur.execute(cls.get_table_sql())
        con.commit()

    def __init__(self, **kw):
        """Build an instance, optionally loading a stored row.

        id=N    -- load the row with that id.  If the row does not exist, or
                   the table does not exist yet (it is then created), the
                   instance stays unsaved (self.id is None).
        field=v -- any keyword naming a declared field sets that attribute.
                   It is applied after loading, so it overrides the stored
                   value.  Unknown keywords are ignored.
        """
        rownames = self.get_rownames()
        # Instance-level None so unset fields never fall through to the
        # class-level DBField object, which sqlite3 cannot store.
        for name in rownames:
            setattr(self, name, None)
        self.id = None
        if "id" in kw:
            self._load(rownames, kw["id"])
        for name in set(rownames) & set(kw):
            setattr(self, name, kw[name])

    def _load(self, rownames, row_id):
        """Copy the stored row *row_id* into self; leave self unsaved if absent."""
        sql = "select %s from %s where id=?" % (
            ", ".join(rownames) or "id", self.__class__.__name__)
        con = self.get_connection()
        try:
            with contextlib.closing(con.cursor()) as cur:
                cur.execute(sql, (row_id,))
                row = cur.fetchone()
        except sqlite3.OperationalError as err:
            # Only a missing table is expected here; other errors (e.g. a
            # missing column) are real problems and must not be hidden.
            if "no such table" not in str(err):
                raise
            self.create_table()
            return
        if row is None:
            return
        for name, value in zip(rownames, row):
            setattr(self, name, value)
        self.id = row_id

    def _row_exists(self):
        """Return True if a row with self.id is stored."""
        con = self.get_connection()
        with contextlib.closing(con.cursor()) as cur:
            cur.execute("select 1 from %s where id=?" % self.__class__.__name__,
                        (self.id,))
            return cur.fetchone() is not None

    def save(self):
        """Write self to the database: UPDATE if its row exists, else INSERT.

        Does nothing when the class declares no fields.  The table must
        already exist (see create_table()); otherwise sqlite3 raises
        OperationalError.
        """
        if not self.get_rownames():
            return
        if self.id is not None and self._row_exists():
            self.__update()
        else:
            self.__insert()

    def __update(self):
        """UPDATE every declared column of the row identified by self.id."""
        rownames = self.get_rownames()
        sql = "update %s set %s where id=?" % (
            self.__class__.__name__,
            ", ".join("%s=?" % name for name in rownames))
        values = [getattr(self, name) for name in rownames]
        values.append(self.id)
        con = self.get_connection()
        with contextlib.closing(con.cursor()) as cur:
            cur.execute(sql, values)
        con.commit()

    def __insert(self):
        """INSERT self as a new row and store the generated id in self.id."""
        tablename = self.__class__.__name__
        # Columns still None are omitted so the column's DEFAULT applies; a
        # column without DEFAULT ends up NULL either way.
        names = [n for n in self.get_rownames() if getattr(self, n) is not None]
        if names:
            sql = "insert into %s (%s) values (%s)" % (
                tablename, ", ".join(names), ", ".join("?" * len(names)))
        else:
            sql = "insert into %s default values" % tablename
        con = self.get_connection()
        with contextlib.closing(con.cursor()) as cur:
            cur.execute(sql, [getattr(self, n) for n in names])
            # lastrowid belongs to this insert; max(id) could be another writer's.
            self.id = cur.lastrowid
        con.commit()

    def __repr__(self):
        """Return "<ClassName: id>", or "<ClassName: not save>" when unsaved."""
        if self.id is None:
            return "<%s: not save>" % self.__class__.__name__
        return "<%s: %d>" % (self.__class__.__name__, self.id)