diff options
author | pzread <netfirewall@gmail.com> | 2013-05-23 00:06:58 +0800 |
---|---|---|
committer | pzread <netfirewall@gmail.com> | 2013-05-23 00:06:58 +0800 |
commit | 0ab07657e48593b569880895d7ae461f3c6b77f5 (patch) | |
tree | c889530f0a741c61729ce0e39a5d8a1f3e8c172d /src | |
parent | e2a22bcdb9cc965dfe0f6e9b53e85b61261dcb8a (diff) | |
download | taiwan-online-judge-0ab07657e48593b569880895d7ae461f3c6b77f5.tar taiwan-online-judge-0ab07657e48593b569880895d7ae461f3c6b77f5.tar.gz taiwan-online-judge-0ab07657e48593b569880895d7ae461f3c6b77f5.tar.bz2 taiwan-online-judge-0ab07657e48593b569880895d7ae461f3c6b77f5.tar.lz taiwan-online-judge-0ab07657e48593b569880895d7ae461f3c6b77f5.tar.xz taiwan-online-judge-0ab07657e48593b569880895d7ae461f3c6b77f5.tar.zst taiwan-online-judge-0ab07657e48593b569880895d7ae461f3c6b77f5.zip |
Add AsyncDB transaction support. Update TOJAuth
Diffstat (limited to 'src')
-rw-r--r-- | src/py/asyncdb.py | 237 | ||||
-rw-r--r-- | src/py/imc/async.py | 8 | ||||
-rw-r--r-- | src/py/tojauth.py | 159 |
3 files changed, 360 insertions, 44 deletions
diff --git a/src/py/asyncdb.py b/src/py/asyncdb.py index 408a2eb..725aba4 100644 --- a/src/py/asyncdb.py +++ b/src/py/asyncdb.py @@ -1,6 +1,9 @@ from collections import deque +import time +import random import tornado.ioloop +import tornado.stack_context import psycopg2 import imc.async @@ -9,13 +12,11 @@ class RestrictCursor: def __init__(self,db,cur): self._db = db self._cur = cur + self._ori_cur = cur + self._in_transaction = False + + self._init_implement() - self.fetchone = self._cur.fetchone - self.fetchmany = self._cur.fetchmany - self.fetchall = self._cur.fetchall - self.scroll = self._cur.scroll - self.cast = self._cur.cast - def __iter__(self): return self._cur @@ -29,87 +30,239 @@ class RestrictCursor: self.lastrowid = self._cur.lastrowid self.query = self._cur.query self.statusmessage = self._cur.statusmessage + + def begin(self): + if self._in_transaction == True: + return + + self._cur = self._db.begin_transaction() + self._init_implement() + + self._db.execute(self._cur,'BEGIN;') + + self._in_transaction = True + + def commit(self): + if self._in_transaction == False: + return + + self._db.execute(self._cur,'COMMIT;') + if self._cur.statusmessage == 'COMMIT': + ret = True + + else: + ret = False + + self._db.end_transaction(self._cur.connection) + self._cur = self._ori_cur + + self._in_transaction = False + + return ret + + def rollback(self): + if self._in_transaction == False: + return + + self._db.execute(self._cur,'ROLLBACK;') + + self._db.end_transaction(self._cur.connection) + self._cur = self._ori_cur + + self._in_transaction = False + + def auto_transaction(self,f): + def wrapper(*args,**kwargs): + retry = + while True: + self.begin() + + try: + ret = f(*args,**kwargs) + + except psycopg2.Error: + self.rollback() + continue + + except Exception: + self.rollback() + raise + + if self.commit() == True: + break + + return ret + + return wrapper + + def _init_implement(self): + self.fetchone = self._cur.fetchone + self.fetchmany = self._cur.fetchmany + self.fetchall = self._cur.fetchall + self.scroll = self._cur.scroll + self.cast = self._cur.cast self.tzinfo_factory = self._cur.tzinfo_factory + self.arraysize = 0 + self.itersize = 0 + self.rowcount = 0 + self.rownumber = 0 + self.lastrowid = None + self.query = '' + self.statusmessage = '' + class AsyncDB: def __init__(self,dbname,user,password): self.OPER_CURSOR = 0 self.OPER_EXECUTE = 1 self._ioloop = tornado.ioloop.IOLoop.instance() - self._conn = psycopg2.connect(database = dbname, - user = user, - password = password, - async = 1) - self._connno = self._conn.fileno() - self._pend_oper = deque() - self._oper_callback = None + self._dbname = dbname + self._user = user + self._password = password + self._conn_fdmap = {} + self._free_connpool = [] + self._share_connpool = [] + self._pendoper_fdmap = {} + self._opercallback_fdmap = {} - self._ioloop.add_handler(self._connno, - self._oper_dispatch, - tornado.ioloop.IOLoop.ERROR) + for i in range(8): + conn = self._create_conn() + self._free_connpool.append(conn) - self._oper_dispatch(self._connno,0) + self._ioloop.add_handler(conn.fileno(), + self._oper_dispatch, + tornado.ioloop.IOLoop.ERROR) - @imc.async.callee - def cursor(self,_grid): - self._pend_oper.append((self.OPER_CURSOR,None,_grid)) - self._oper_dispatch(self._connno,0) + self._oper_dispatch(conn.fileno(),0) - cur = imc.async.switchtop() - return RestrictCursor(self,cur) + for i in range(2): + conn = self._create_conn() + self._share_connpool.append(conn) + + self._ioloop.add_handler(conn.fileno(), + self._oper_dispatch, + tornado.ioloop.IOLoop.ERROR) + + self._oper_dispatch(conn.fileno(),0) + + def cursor(self): + return RestrictCursor(self,self._cursor()) @imc.async.callee def execute(self,cur,sql,param = None,_grid = None): - self._pend_oper.append((self.OPER_EXECUTE,(cur,sql,param),_grid)) - self._oper_dispatch(self._connno,0) + fd = cur.connection.fileno() + + self._pendoper_fdmap[fd].append((self.OPER_EXECUTE,(cur,sql,param),_grid)) + self._oper_dispatch(fd,0) imc.async.switchtop() + def begin_transaction(self): + if len(self._free_connpool) > 0: + conn = self._free_connpool.pop() + + else: + conn = self._create_conn() + self._ioloop.add_handler(conn.fileno(), + self._oper_dispatch, + tornado.ioloop.IOLoop.ERROR) + + return self._cursor() + + def end_transaction(self,conn): + if len(self._free_connpool) < 16: + self._free_connpool.append(conn) + + else: + conn.close() + + @imc.async.callee + def _cursor(self,conn = None,_grid = None): + if conn != None: + fd = conn.fileno() + + else: + fd = self._share_connpool[random.randrange(len(self._share_connpool))].fileno() + + self._pendoper_fdmap[fd].append((self.OPER_CURSOR,None,_grid)) + self._oper_dispatch(fd,0) + + cur = imc.async.switchtop() + return cur + + def _create_conn(self): + conn = psycopg2.connect(database = self._dbname, + user = self._user, + password = self._password, + async = 1) + + fd = conn.fileno() + self._conn_fdmap[fd] = conn + self._pendoper_fdmap[fd] = deque() + self._opercallback_fdmap[fd] = None + + return conn + def _oper_dispatch(self,fd,evt): - stat = self._conn.poll() - if stat == psycopg2.extensions.POLL_OK: - self._ioloop.update_handler(self._connno, + err = None + conn = self._conn_fdmap[fd] + try: + stat = conn.poll() + + except Exception as e: + err = e + + if err != None or stat == psycopg2.extensions.POLL_OK: + self._ioloop.update_handler(fd, tornado.ioloop.IOLoop.ERROR) elif stat == psycopg2.extensions.POLL_READ: - self._ioloop.update_handler(self._connno, + self._ioloop.update_handler(fd, tornado.ioloop.IOLoop.READ | tornado.ioloop.IOLoop.ERROR) return elif stat == psycopg2.extensions.POLL_WRITE: - self._ioloop.update_handler(self._connno, + self._ioloop.update_handler(fd, tornado.ioloop.IOLoop.WRITE | tornado.ioloop.IOLoop.ERROR) return - if self._oper_callback != None: - cb = self._oper_callback - self._oper_callback = None - cb() + cb = self._opercallback_fdmap[fd] + if cb != None: + self._opercallback_fdmap[fd] = None + cb(err) else: try: - oper,data,grid = self._pend_oper.popleft() + oper,data,grid = self._pendoper_fdmap[fd].popleft() except IndexError: return if oper == self.OPER_CURSOR: - def _ret_cursor(): - imc.async.retcall(grid,self._conn.cursor()) + def _ret_cursor(err = None): + if err == None: + imc.async.retcall(grid,conn.cursor()) + + else: + imc.async.retcall(grid,err = err) - self._oper_callback = _ret_cursor + self._opercallback_fdmap[fd] = _ret_cursor elif oper == self.OPER_EXECUTE: - def _ret_execute(): - imc.async.retcall(grid,None) + def _ret_execute(err = None): + if err == None: + imc.async.retcall(grid,None) + + else: + imc.async.retcall(grid,err = err) cur,sql,param = data cur.execute(sql,param) - self._oper_callback = _ret_execute + self._opercallback_fdmap[fd] = _ret_execute - self._ioloop.add_callback(self._oper_dispatch,self._connno,0) + self._ioloop.add_callback(self._oper_dispatch,fd,0) diff --git a/src/py/imc/async.py b/src/py/imc/async.py index 1eb0409..59cfc5e 100644 --- a/src/py/imc/async.py +++ b/src/py/imc/async.py @@ -61,7 +61,7 @@ def caller(f): return wrapper -def retcall(grid,result): +def retcall(grid,value = None,err = None): global gr_waitmap try: @@ -70,8 +70,12 @@ def retcall(grid,result): old_iden = auth.current_iden auth.current_iden = iden - gr.switch(result) + if err == None: + gr.switch(value) + else: + gr.throw(err) + auth.current_iden = old_iden except Exception as err: diff --git a/src/py/tojauth.py b/src/py/tojauth.py index d5c13c3..77dd2df 100644 --- a/src/py/tojauth.py +++ b/src/py/tojauth.py @@ -1,6 +1,15 @@ from imc.auth import Auth +import config +from asyncdb import AsyncDB class TOJAuth(Auth): + ACCESS_READ = 0x1 + ACCESS_WRITE = 0x2 + ACCESS_CREATE = 0x4 + ACCESS_DELETE = 0x8 + ACCESS_SETPER = 0x10 + ACCESS_EXECUTE = 0x20 + def __init__(self,pubkey,privkey = None): super().__init__() @@ -9,6 +18,8 @@ class TOJAuth(Auth): self.set_signkey(privkey) TOJAuth.instance = self + TOJAuth.db = AsyncDB(config.CORE_DBNAME,config.CORE_DBUSER, + config.CORE_DBPASSWORD) def create_iden(self,linkclass,linkid): iden = { @@ -26,3 +37,151 @@ class TOJAuth(Auth): return None return iden + + def check_access(self, accessid, access_mask): + def wrapper(f): + idenid = self.current_iden['idenid'] + ok = False + + cur = self.db.cursor() + + if not ok: + sqlstr = ('SELECT "owner_idenid" FROM "ACCESS" WHERE ' + '"accessid"=%s;') + sqlarr = (accessid, ) + cur.execute(sqlstr, sqlarr) + for data in cur: + owner_idenid = data[0] + if owner_idenid == idenid: + ok = True + + if not ok: + sqlstr = ('SELECT "ACCESS_ROLE"."permission" FROM "ACCESS_ROLE"' + ' INNER JOIN "IDEN_ROLE" ON "ACCESS_ROLE"."roleid" = ' + '"IDEN_ROLE"."roleid" WHERE "ACCESS_ROLE"."accessid"=%s' + ' AND "IDEN_ROLE"."idenid"=%s;') + sqlarr = (accessid, idenid) + cur.execute(sqlstr, sqlarr) + + for data in cur: + permission = data[0] + if (permission & access_mask) == access_mask: + ok = True + break + + if ok: + return f + else: + raise Exception('TOJAuth.check_access() : PERMISSION DENIED') + + return wrapper + + def create_access(self): + self.check_access(self.auth_accessid, self.ACCESS_EXECUTE)(0) + cur = self.db.cursor() + sqlstr = ('INSERT INTO "ACCESS" ("owner_idenid") VALUES (%s) ' + 'RETURNING "accessid";') + sqlarr = (self.current_iden['idenid'], ) + cur.execute(sqlstr, sqlarr) + for data in cur: + accessid = data[0] + return accessid + + def set_access_list(self, accessid, roleid, permission): + self.check_access(accessid, self.ACCESS_SETPER)(0) + + def _db_write(accessid, roleid, permission): + cur = self.db.cursor() + if not self._does_access_list_exist(cur, accessid, roleid): + sqlstr = ('INSERT INTO "ACCESS_ROLE" ("accessid", "roleid", ' + '"permission") VALUES (%s, %s, %s);') + sqlarr = (accessid, roleid, permission) + else: + sqlstr = ('UPDATE "ACCESS_ROLE" SET "permission"=%s ' + 'WHERE "accessid"=%s AND "roleid"=%s;') + sqlarr = (permission, accessid, roleid) + cur.execute(sqlstr, sqlarr) + + _db_write(accessid, roleid, permission) + + def del_access_list(self, accessid, roleid): + self.check_access(accessid, self.ACCESS_SETPER)(0) + + def _db_write(accessid, roleid): + cur = self.db.cursor() + if self._does_access_list_exist(cur, accessid, roleid): + sqlstr = ('DELETE FROM "ACCESS_ROLE" WHERE "accessid"=%s ' + 'AND "roleid"=%s;') + sqlarr = (accessid, roleid) + cur.execute(sqlstr, sqlarr) + else: + raise Exception('TOJAuth.del_access_list() : Access object ' + 'doesn\'t exist') + + _db_write(accessid, roleid) + + def _does_access_list_exist(self, cur, accessid, roleid): + sqlstr = ('SELECT COUNT(*) FROM "ACCESS_ROLE" WHERE ' + '"accessid"=%s AND "roleid"=%s;') + sqlarr = (accessid, roleid) + cur.execute(sqlstr, sqlarr) + for data in cur: + count = data[0] + return count>0 + + def create_role(self, rolename, roletype): + self.check_access(self.auth_accessid, self.ACCESS_EXECUTE)(0) + cur = self.db.cursor() + sqlstr = ('INSERT INTO "ROLE" ("rolename") VALUES (%s)' + ' RETURNING "roleid";') + sqlarr = (rolename, ) + cur.execute(sqlstr, sqlarr) + for data in cur: + roleid = data[0] + return roleid + + def set_role_relation(self, idenid, roleid): + self.check_access(self.auth_accessid, self.ACCESS_EXECUTE)(0) + + def _db_write(idenid, roleid): + cur = self.db.cursor() + if not self._does_role_relation_exist(cur, idenid, roleid): + sqlstr = ('INSERT INTO "IDEN_ROLE" ("idenid", "roleid") ' + 'VALUES (%s, %s);') + sqlarr = (idenid, roleid) + cur.execute(sqlstr, sqlarr) + + _db_write(idenid, roleid) + + def del_role_relation(self, idenid, roleid): + self.check_access(self.auth_accessid, self.ACCESS_EXECUTE)(0) + + def _db_write(idenid, roleid): + cur = self.db.cursor() + if self._does_role_relation_exist(cur, idenid, roleid): + sqlstr = ('DELETE FROM "IDEN_ROLE" WHERE "idenid"=%s ' + 'AND "roleid"=%s;') + sqlarr = (idenid, roleid) + cur.execute(sqlstr, sqlarr) + else: + raise Exception('TOJAuth.del_role_relation() : Role relation ' + 'doesn\'t exist') + + _db_write(idenid, roleid) + + def _does_role_relation_exist(self, cur, idenid, roleid): + sqlstr = ('SELECT COUNT(*) FROM "IDEN_ROLE" WHERE "idenid"=%s ' + 'AND "roleid"=%s;') + sqlarr = (idenid, roleid) + cur.execute(sqlstr, sqlarr) + for data in cur: + count = data[0] + return count>0 + + def set_owner(self, idenid, accessid): + self.check_access(accessid, self.ACCESS_SETPER)(0) + cur = self.db.cursor() + sqlstr = ('UPDATE "ACCESS" SET "owner_idenid"=%s WHERE "accessid"=%s;') + sqlarr = (idenid, accessid) + cur.execute(sqlstr, sqlarr) + |