#coding=utf-8

'''
Helper wrappers for MySQL (via pymysql) and MongoDB (via pymongo).

Dependencies:
    pip install pymysql pymongo
'''

import pymysql
from threadlock import ThreadLock

class DB(object):
    """Thin wrapper around one pymysql connection using a ``DictCursor``.

    * ``ispersist=True``: the constructor opens a connection that every
      ``update``/``query`` reuses after checking it with ``ping(True)``.
    * ``ispersist=False``: every ``update``/``query`` opens its own
      connection and closes it before returning, including on failure.

    With ``is_mutex=True`` each ``update``/``query`` holds ``ThreadLock`` for
    its whole duration. Errors are reported through ``log.error`` when a
    logger was supplied, then re-raised unchanged.
    """

    __slots__ = ['__host', '__port', '__user', '__passwd', '__dbName', '__conn', '__cur', '__ispersist', '__log', '__is_mutex']

    def __init__(self, host, port, user, passwd, dbName, ispersist=False, log=None, is_mutex=False):
        self.__host = host
        self.__port = port
        self.__user = user
        self.__passwd = passwd
        self.__dbName = dbName
        self.__conn = None
        self.__cur = None
        self.__ispersist = ispersist
        self.__log = log
        self.__is_mutex = is_mutex

        if self.__ispersist:
            # Open the long-lived connection up front.
            DB.connect(self)

    def log_error(self, err_info):
        """Forward *err_info* to ``log.error``; does nothing without a logger."""
        if self.__log:
            self.__log.error(err_info)

    def connect(self):
        """Open a new connection plus cursor and store them on the instance.

        Any previously stored connection is replaced, not closed.
        Connection errors are logged and re-raised.
        """
        try:
            self.__conn = pymysql.Connect(
                host=self.__host,
                port=self.__port,
                user=self.__user,
                passwd=self.__passwd,
                db=self.__dbName,
                # NOTE(review): MySQL's 'utf8' is the 3-byte utf8mb3 charset and
                # cannot store 4-byte characters (e.g. emoji); consider 'utf8mb4'.
                charset='utf8',
                cursorclass=pymysql.cursors.DictCursor,
            )
            self.__cur = self.__conn.cursor()
        except Exception as ex:
            self.log_error('db connect exception: ' + str(ex))
            raise

    def close(self):
        """Close the stored cursor and connection, skipping any never opened.

        Close errors are logged and re-raised. The cursor reference is not
        reset, so ``lastrowid()`` can still read it after a non-persistent
        ``update``.
        """
        try:
            if self.__cur is not None:
                self.__cur.close()
            if self.__conn is not None:
                self.__conn.close()
        except Exception as ex:
            self.log_error('db close exception: ' + str(ex))
            raise

    def lastrowid(self):
        """Return ``lastrowid`` of the cursor used by the most recent statement."""
        return self.__cur.lastrowid

    def lock(self):
        """Acquire ``ThreadLock`` when this instance was created with is_mutex."""
        if self.__is_mutex:
            ThreadLock.lock()

    def unlock(self):
        """Release ``ThreadLock`` when this instance was created with is_mutex."""
        if self.__is_mutex:
            ThreadLock.unlock()

    def _prepare(self):
        """Make sure a usable connection and cursor exist before running SQL."""
        if not self.__ispersist:
            self.connect()
            return
        try:
            # ping(True) asks pymysql to reconnect a dropped connection.
            self.__conn.ping(True)
        except Exception as ex:
            # Also reached when the constructor never managed to connect
            # (the connection is still None): open a fresh connection.
            self.log_error(str(ex))
            self.connect()

    def _abort(self):
        """Best-effort cleanup after a failed statement.

        Rolls back the open transaction and, in non-persistent mode, closes
        the per-call connection so it is not leaked. Failures here are logged
        but never replace the statement's own exception.
        """
        try:
            self.__conn.rollback()
        except Exception as ex:
            self.log_error('db rollback exception: ' + str(ex))
        if not self.__ispersist:
            try:
                self.close()
            except Exception:
                pass  # close() already logged it; keep the original error

    def _execute(self, sql, args, fetch, label):
        """Run one statement under the optional mutex (shared by update/query).

        Returns ``fetchall()`` rows when *fetch* is true, otherwise the row
        count returned by ``execute``. *label* only names the operation in
        the error log message.
        """
        self.lock()
        try:
            self._prepare()
            try:
                result = self.__cur.execute(sql, args)
                if fetch:
                    result = self.__cur.fetchall()
                self.__conn.commit()
            except Exception as ex:
                self.log_error('db ' + label + ' exception: ' + str(ex))
                self._abort()
                raise
            if not self.__ispersist:
                self.close()
            return result
        finally:
            # Released on every path, including connect/close failures.
            self.unlock()

    def update(self, sql, args=None):
        """Execute a data-modifying statement and commit it.

        :param sql: SQL text. When *args* is given it may contain pymysql
            ``%s`` / ``%(name)s`` placeholders (literal ``%`` must then be
            written ``%%``).
        :param args: optional sequence or mapping of values that pymysql
            escapes into the placeholders; prefer this over formatting
            untrusted values into *sql*.
        :returns: the row count returned by ``cursor.execute``.
        :raises Exception: whatever connecting, executing or committing raised.
        """
        return self._execute(sql, args, False, 'update')

    def query(self, sql, args=None):
        """Execute a statement and return all fetched rows.

        Rows come from ``cursor.fetchall()`` on the ``DictCursor`` configured
        in ``connect``. *sql* and *args* are as for ``update``.

        :raises Exception: whatever connecting, executing or committing raised.
        """
        return self._execute(sql, args, True, 'query')

    def __enter__(self):
        """Open a connection for a ``with`` block and return self.

        NOTE(review): in persistent mode this replaces (without closing) the
        connection opened by the constructor; in non-persistent mode every
        update/query still opens and closes its own connection.
        """
        self.connect()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the connection; exceptions from the ``with`` body propagate."""
        self.close()

class DBCluster(object):
    """Send reads round-robin to slave databases and writes to the master.

    ``masterDb`` needs ``query(sql)`` and ``update(sql)``; each slave needs
    ``query(sql)``. When no slaves are configured, reads go to the master.
    """

    def __init__(self, masterDb, slaveDbList = None):
        # None instead of a shared mutable [] default. A caller-supplied list
        # is kept by reference, so later additions to it are picked up.
        if slaveDbList is None:
            slaveDbList = []
        self._masterDb = masterDb
        self._slaveDbList = slaveDbList
        self._slaveDbCount = len(slaveDbList)
        self._curSlave = 0

    def _nextSlaveIndex(self):
        """Advance the round-robin counter and map it onto the live slave list.

        Raises ZeroDivisionError when no slaves are configured.
        """
        self._curSlave = self._curSlave + 1
        # Re-read the length each time so it never goes stale.
        self._slaveDbCount = len(self._slaveDbList)
        return self._curSlave % self._slaveDbCount

    def query(self, sql):
        """Run a read on the next slave, or on the master if there are none."""
        if not self._slaveDbList:
            return self._masterDb.query(sql)
        return self._slaveDbList[self._nextSlaveIndex()].query(sql)

    def queryInstance(self):
        """Advance the round-robin and return the chosen slave's index.

        Raises ZeroDivisionError when no slaves are configured.
        """
        return self._nextSlaveIndex()

    def update(self, sql):
        """Run a write on the master database."""
        return self._masterDb.update(sql)

class DBProxy(object):
    """Send reads round-robin to slave databases and writes to the master.

    Same routing behavior as ``DBCluster``. ``masterDb`` needs ``query(sql)``
    and ``update(sql)``; each slave needs ``query(sql)``. When no slaves are
    configured, reads go to the master.
    """

    def __init__(self, masterDb, slaveDbList = None):
        # None instead of a shared mutable [] default. A caller-supplied list
        # is kept by reference, so later additions to it are picked up.
        if slaveDbList is None:
            slaveDbList = []
        self._masterDb = masterDb
        self._slaveDbList = slaveDbList
        self._slaveDbCount = len(slaveDbList)
        self._curSlave = 0

    def _nextSlaveIndex(self):
        """Advance the round-robin counter and map it onto the live slave list.

        Raises ZeroDivisionError when no slaves are configured.
        """
        self._curSlave = self._curSlave + 1
        # Re-read the length each time so it never goes stale.
        self._slaveDbCount = len(self._slaveDbList)
        return self._curSlave % self._slaveDbCount

    def query(self, sql):
        """Run a read on the next slave, or on the master if there are none."""
        if not self._slaveDbList:
            return self._masterDb.query(sql)
        return self._slaveDbList[self._nextSlaveIndex()].query(sql)

    def queryInstance(self):
        """Advance the round-robin and return the chosen slave's index.

        Raises ZeroDivisionError when no slaves are configured.
        """
        return self._nextSlaveIndex()

    def update(self, sql):
        """Run a write on the master database."""
        return self._masterDb.update(sql)

from pymongo import MongoClient

class MongoCollection:
    """Convenience wrapper around one pymongo collection.

    Uses the CRUD methods available since PyMongo 3.0 (``insert_one``,
    ``delete_many``, ``update_one``/``update_many``). The legacy ``insert``,
    ``remove`` and ``update`` methods were removed in PyMongo 4.
    """

    def __init__(self, coll_name, db):
        self._coll_name = coll_name
        self._db = db
        self._coll = self._db[coll_name]

    def add(self, obj):
        """Insert one document (a dict) or each document of a list.

        Returns the number of documents inserted. Anything that is neither a
        list nor a dict is ignored and 0 is returned. pymongo adds an ``_id``
        to each inserted document that lacks one.
        """
        if isinstance(obj, dict):
            docs = [obj]
        elif isinstance(obj, list):
            docs = obj
        else:
            return 0
        for doc in docs:
            self._coll.insert_one(doc)
        return len(docs)

    def delete(self, condition):
        """Delete every document matching *condition*; return the deleted count.

        Like the legacy ``remove``, ``None`` matches all documents and a
        non-mapping value is treated as an ``_id`` to match.
        """
        from collections.abc import Mapping

        if condition is None:
            condition = {}
        elif not isinstance(condition, Mapping):
            condition = {"_id": condition}
        return self._coll.delete_many(condition).deleted_count

    def update(self, condition, to_update, multi = True):
        """Apply ``{"$set": to_update}`` to documents matching *condition*.

        Updates every match when *multi* is true, otherwise only the first
        one. Returns the number of documents actually modified.
        """
        change = {"$set": to_update}
        if multi:
            result = self._coll.update_many(condition, change)
        else:
            result = self._coll.update_one(condition, change)
        return result.modified_count

    def find_one(self, condition):
        """Return the first document matching *condition*, or None."""
        return self._coll.find_one(condition)

    def find(self, condition, limit_cnt = -1, skip_cnt = -1):
        """Return matching documents as a list.

        A positive *limit_cnt* caps the result size and a positive
        *skip_cnt* skips that many documents first. Non-positive values
        disable the respective option.
        """
        cursor = self._coll.find(condition)
        if limit_cnt > 0:
            cursor = cursor.limit(limit_cnt)
        if skip_cnt > 0:
            cursor = cursor.skip(skip_cnt)
        return list(cursor)

class MongoDB:
    """Handle on one database obtained from a MongoClient-like connection."""

    def __init__(self, db_name, conn):
        self._db_name = db_name
        self._conn = conn
        self._db = self._conn[db_name]

    def select_coll(self, collection_name):
        """Return a ``MongoCollection`` wrapper for *collection_name*."""
        return MongoCollection(collection_name, self._db)

    def delete_coll(self, coll_name):
        """Drop collection *coll_name* and return self so calls can be chained."""
        self._db.drop_collection(coll_name)
        return self

class Mongo:
    """Owns a single ``MongoClient`` and hands out per-database wrappers."""

    def __init__(self, host, port = 27017):
        self._host, self._port = host, port
        # The client is built here once and shared by every MongoDB handle
        # that select_db returns.
        self._conn = MongoClient(self._host, self._port)

    def select_db(self, db_name):
        """Wrap database *db_name* of this client in a ``MongoDB`` object."""
        return MongoDB(db_name, self._conn)