数独ソルバー

何となく作ってみた。

class SudokuSolver:
    """9x9 sudoku solver.

    The grid is a flat list of 81 ints in row-major order: cell ``pos`` is
    at row ``pos // 9``, column ``pos % 9``, and 0 marks an empty cell.
    The list given to the constructor is used directly (not copied), so
    the solve_* methods fill in the caller's list in place.

    Integer division uses ``//`` and all output goes through
    single-argument ``print(...)``, so the class behaves the same on
    Python 2 and Python 3.
    """

    def __init__(self, question):
        """Use *question* (a list of 81 cell values) as the working grid.

        Raises:
            ValueError: if *question* does not contain exactly 81 cells.
        """
        if len(question) != 81:
            raise ValueError("question must have 81 cells, got %d"
                             % len(question))
        self.buf = question

    # -- index helpers ----------------------------------------------------

    def _row_positions(self, row):
        """Return the flat indices of row *row* (0-8), left to right."""
        return list(range(row * 9, row * 9 + 9))

    def _column_positions(self, column):
        """Return the flat indices of column *column* (0-8), top to bottom."""
        return list(range(column, 81, 9))

    def _group_positions(self, group):
        """Return the flat indices of 3x3 box *group* in row-major order.

        Boxes are numbered 0-8 left to right, top to bottom.
        """
        base = group // 3 * 27 + group % 3 * 3
        return [base + dy * 9 + dx for dy in range(3) for dx in range(3)]

    # -- grid accessors ---------------------------------------------------

    def get_rows(self, x):
        """Return the values of row *x* (0 <= x < 9); [] if out of range."""
        if x < 0 or x >= 9:
            return []
        return self.buf[x * 9:(x + 1) * 9]

    def get_columns(self, x):
        """Return the values of column *x* (0 <= x < 9); [] if out of range."""
        if x < 0 or x >= 9:
            return []
        return self.buf[x::9]

    def get_groups(self, pos):
        """Return the values of the 3x3 box containing cell *pos*.

        *pos* must satisfy 0 <= pos < 81; otherwise "range error" is
        printed and [] is returned.
        """
        if pos < 0 or pos >= 81:
            print("range error")
            return []
        group = pos // 27 * 3 + pos % 9 // 3
        return [self.buf[i] for i in self._group_positions(group)]

    def get_pos(self, x):
        """Return ``(column, row)`` of flat cell index *x* (0 <= x < 81)."""
        return x % 9, x // 9

    # -- output -----------------------------------------------------------

    def _print_grid(self, cells):
        """Print 81 *cells* as 9 lines of three 3-cell chunks.

        A blank line separates each band of three rows.
        """
        for row in range(9):
            start = row * 9
            chunks = [cells[start + col:start + col + 3] for col in (0, 3, 6)]
            print(" ".join([str(chunk) for chunk in chunks]))
            if row in (2, 5):
                print("")

    def prints(self):
        """Print the current grid."""
        self._print_grid(self.buf)

    def prints_memo(self):
        """Print the grid with every empty cell shown as its candidate set."""
        memo = [self.serch_num(i) if self.buf[i] == 0 else self.buf[i]
                for i in range(81)]
        self._print_grid(memo)

    # -- candidates -------------------------------------------------------

    def count_zero(self, buf):
        """Return how many entries of *buf* are 0 (empty cells)."""
        return sum(1 for value in buf if value == 0)

    def serch_num(self, pos):
        """Return the set of digits 1-9 that could go in cell *pos*.

        The cell's own value is ignored, so for a filled cell the result
        contains its value exactly when that value clashes with no peer
        (validate() relies on this).
        """
        saved = self.buf[pos]
        # Blank the cell while collecting peer values so its own value does
        # not exclude itself; always put it back, even if a lookup fails.
        self.buf[pos] = 0
        try:
            used = set(self.get_groups(pos))
            used.update(self.get_columns(pos % 9))
            used.update(self.get_rows(pos // 9))
        finally:
            self.buf[pos] = saved
        return set(range(1, 10)) - used

    # -- deduction rules --------------------------------------------------

    def solve_single(self):
        """Fill every empty cell that has exactly one candidate.

        Returns the number of cells filled.
        """
        before = self.count_zero(self.buf)
        for pos in range(81):
            if self.buf[pos] == 0:
                candidates = self.serch_num(pos)
                if len(candidates) == 1:
                    self.buf[pos] = candidates.pop()
        return before - self.count_zero(self.buf)

    def _solve_unit(self, positions):
        """Place hidden singles in one unit (row, column or box).

        For each empty cell, drop every candidate that another empty cell
        of the unit also allows; if exactly one digit is left, only this
        cell can take it.
        """
        candidates = [self.serch_num(pos) for pos in positions]
        for i, pos in enumerate(positions):
            remaining = candidates[i]
            for j, other in enumerate(positions):
                # Re-read self.buf: cells filled earlier in this loop no
                # longer compete for digits.
                if j != i and self.buf[other] == 0:
                    remaining = remaining - candidates[j]
            if len(remaining) == 1 and self.buf[pos] == 0:
                self.buf[pos] = next(iter(remaining))

    def _solve_units(self, positions_of):
        """Run _solve_unit on the 9 units given by *positions_of(0..8)*.

        Returns the number of cells filled.
        """
        before = self.count_zero(self.buf)
        for unit in range(9):
            self._solve_unit(positions_of(unit))
        return before - self.count_zero(self.buf)

    def solve_row(self):
        """Apply hidden singles to every row; return the number of cells filled."""
        return self._solve_units(self._row_positions)

    def solve_column(self):
        """Apply hidden singles to every column; return the number of cells filled."""
        return self._solve_units(self._column_positions)

    def solve_group(self):
        """Apply hidden singles to every 3x3 box; return the number of cells filled."""
        return self._solve_units(self._group_positions)

    def solve_all(self):
        """Run one pass of every deduction rule; return the total cells filled."""
        filled = self.solve_group()
        filled += self.solve_row()
        filled += self.solve_column()
        filled += self.solve_single()
        return filled

    def solve_all_try(self):
        """Repeat solve_all() until it makes no progress.

        Returns the number of cells filled.
        """
        before = self.count_zero(self.buf)
        while self.solve_all() != 0:
            pass
        return before - self.count_zero(self.buf)

    # -- trial and error --------------------------------------------------

    def _is_consistent(self):
        """Return True if no filled cell clashes and no empty cell is a dead end."""
        if not self.validate():
            return False
        return all(self.serch_num(i) for i in range(81) if self.buf[i] == 0)

    def solve_te(self):
        """Try and error: guess one cell, then propagate with solve_all_try().

        The empty cell with the fewest candidates is chosen (the first such
        cell on ties). Each candidate is tried from the same starting grid.
        A guess that completes the grid consistently wins. Otherwise the
        first guess that leaves a consistent partial grid is kept. If every
        guess fails, the grid is restored unchanged.

        A kept partial grid may still stem from a wrong guess if no
        contradiction has surfaced yet.

        Returns the number of cells filled.
        """
        zcnt = self.count_zero(self.buf)
        empty = [i for i in range(81) if self.buf[i] == 0]
        if not empty:
            return 0
        backup_buf = self.buf[:]
        pos = min(empty, key=lambda i: len(self.serch_num(i)))
        solution = None
        for guess in sorted(self.serch_num(pos)):
            # Start every guess from the same state; otherwise cells derived
            # from a failed guess would poison the following ones.
            self.buf[:] = backup_buf
            self.buf[pos] = guess
            self.solve_all_try()
            if not self._is_consistent():
                continue
            if self.count_zero(self.buf) == 0:
                solution = self.buf[:]
                break
            if solution is None:
                solution = self.buf[:]
        # Slice-assign so the caller's list (aliased in __init__) is updated.
        self.buf[:] = backup_buf if solution is None else solution
        return zcnt - self.count_zero(self.buf)

    def validate(self):
        """Return True if no filled cell conflicts with a peer.

        Empty cells are not checked.
        """
        for i in range(81):
            if self.buf[i] != 0 and self.buf[i] not in self.serch_num(i):
                return False
        return True

import unittest

class UnitTestSudokuSolver(unittest.TestCase):
    """Tests for SudokuSolver against one sample puzzle."""

    def setUp(self):
        # Sample puzzle in row-major order, 0 = empty cell. A fresh list is
        # built for every test because the solver writes into the list it
        # is given.
        self.buf = [0, 0, 0, 0, 2, 0, 0, 9, 0,
                    0, 0, 0, 0, 6, 3, 0, 0, 8,
                    3, 0, 0, 0, 0, 8, 1, 4, 0,
                    0, 0, 0, 0, 4, 0, 8, 0, 7,
                    0, 8, 4, 0, 0, 0, 6, 1, 0,
                    1, 0, 7, 0, 5, 0, 0, 0, 0,
                    0, 1, 5, 9, 0, 0, 0, 0, 2,
                    9, 0, 0, 4, 8, 0, 0, 0, 0,
                    0, 2, 0, 0, 1, 0, 0, 0, 0]

    def test_group(self):
        print("solve_group() test")
        s = SudokuSolver(self.buf)
        s.prints()
        filled = s.solve_group()
        print(filled)
        s.prints()
        # Box deductions must not introduce a conflict.
        self.assertTrue(filled >= 0)
        self.assertTrue(s.validate())

    def test_serch_num(self):
        s = SudokuSolver(self.buf)
        # Cell 3 (row 0, column 3): the row holds 2, 9; the column holds
        # 9, 4; the box holds 2, 6, 3, 8 -> only 1, 5, 7 remain.
        self.assertEqual(s.serch_num(3), set([1, 5, 7]))

    def test_get_groups(self):
        s = SudokuSolver(self.buf)
        # Cell 3 lies in the top-middle box (cells 3-5, 12-14, 21-23).
        self.assertEqual(s.get_groups(3), [0, 2, 0, 0, 6, 3, 0, 0, 8])

    def test_solves(self):
        print("test_solves()")
        s = SudokuSolver(self.buf)
        s.prints()
        print(s.solve_all_try())
        print(s.solve_te())
        s.prints_memo()
        self.assertTrue(s.validate())
        


if __name__ == '__main__':
    # Executed as a script: discover and run the TestCase classes above.
    unittest.main()