#-*- coding:utf8 -*-
__author__ = 'zhonghong'

import decimal
import functools
import numbers
import re
import time

import pymysql
from pymysql import converters
from DBUtils.PooledDB import PooledDB

from zapi.core.utils import FlyweightMixin


class Pool(FlyweightMixin):
    """Own a DBUtils ``PooledDB`` whose connections are opened with PyMySQL.

    All keyword arguments go to ``PooledDB`` unchanged (DBUtils hands the
    ones it does not recognise to the creator's ``connect()``). Any instance
    sharing comes from ``FlyweightMixin``, which is defined elsewhere —
    presumably one Pool per distinct argument set; confirm in zapi.core.utils.
    """

    def __init__(self, **kwargs):
        creator = pymysql  # DB-API module the pool opens connections with
        self.pool = PooledDB(creator, **kwargs)


def reconnect(func):
    """Decorate a method so its connection is pinged before every call.

    ``obj.conn.ping(True)`` is tried up to three times, sleeping half a
    second between failed attempts. The wrapped method runs afterwards
    whether or not a ping succeeded, so a connection that is still dead
    surfaces as the method's own error.

    @param func: callable whose first positional argument exposes ``.conn``.
    @return: the wrapper, carrying func's name and docstring.
    """
    attempts = 3

    @functools.wraps(func)
    def _wrapper(obj=None, *args, **kwargs):
        for attempt in range(attempts):
            try:
                # With PyMySQL, ping(reconnect=True) reopens a dropped link.
                obj.conn.ping(True)
                break
            except Exception:
                # No point sleeping after the last attempt; we proceed anyway.
                if attempt < attempts - 1:
                    time.sleep(0.5)
        return func(obj, *args, **kwargs)
    return _wrapper


class Base(object):
    """Run raw SQL over a single connection borrowed from a DBUtils pool.

    Every statement run through exec_sql() is appended to ``queries`` (unless
    ``save_queries`` is False). The outcome of the latest statement is exposed
    through ``result()``, ``get_last_row_id()`` and ``error_msg``. Query
    errors are recorded in ``error_msg`` rather than raised.
    """

    # When True, the next execute()/exec_sql() call returns its SQL text
    # instead of running it, then clears the flag.
    only_sql = False

    # Leading keywords that exec_sql() treats as data-modifying statements.
    _write_types = ('update', 'insert', 'delete', 'replace')

    def __init__(self, opts, cursorclass=pymysql.cursors.Cursor, autocommit=False):
        """Borrow a connection from the pool built from *opts*.

        @param opts: connection keyword arguments forwarded, together with
            ``cursorclass`` and ``autocommit``, to ``Pool``/``PooledDB``.
        @param cursorclass: PyMySQL cursor class used for every cursor.
        @param autocommit: initial autocommit mode of the connection.
        """
        self.conn = Pool(cursorclass=cursorclass, autocommit=autocommit, **opts).pool.connection()
        self.queries = list()
        self.save_queries = True
        # Initialised up front so the accessors work before the first query.
        self.error_msg = None
        self.res = False
        self.last_row_id = 0

    def close_conn(self):
        """Close the borrowed pooled connection."""
        self.conn.close()

    @staticmethod
    def _quote_table(table):
        """Backtick-quote a ``table`` or ``db.table`` name, doubling any backtick."""
        parts = [p.strip().strip('`').replace('`', '``') for p in table.split('.')]
        return '.'.join('`%s`' % p for p in parts)

    @reconnect
    def get_fields(self, table=None):
        """Return the column names of *table*, or False on error / no table.

        Handles both dict cursors (``Field`` key) and tuple cursors, where the
        column name is the first value of each SHOW FIELDS row.
        """
        if table is None:
            return False
        cur = self.conn.cursor()
        try:
            cur.execute('SHOW FIELDS FROM %s' % self._quote_table(table))
            rows = cur.fetchall()
        except Exception as e:
            self.error_msg = e
            return False
        finally:
            cur.close()
        fields = []
        for item in rows:
            if isinstance(item, dict):
                fields.append(item.get('Field') or item.get('field'))
            else:
                fields.append(item[0])
        return fields

    @reconnect
    def select_db(self, db=None):
        """Switch the connection's default database when *db* is given."""
        if db is not None:
            self.conn.select_db(db)

    @reconnect
    def execute(self, sqlstr):
        """Execute *sqlstr* and return the driver's affected/returned row count.

        With ``only_sql`` set, return the SQL text instead (and clear the flag).
        """
        if self.only_sql is True:
            self.only_sql = False
            return sqlstr
        cur = self.conn.cursor()
        try:
            return cur.execute(sqlstr)
        finally:
            cur.close()

    @reconnect
    def exec_sql(self, sqlstr):
        """Execute *sqlstr* and report its outcome.

        @return:
            - For SELECT: the fetched rows, or False when nothing matched or
              the query failed.
            - For INSERT/UPDATE/DELETE/REPLACE: True when at least one row
              was affected, otherwise False.
            - For any other statement: None.
            - With ``only_sql`` set: the SQL text itself.
        The result is mirrored in ``self.res``. On failure, the exception is
        stored in ``error_msg`` instead of being raised.
        """
        if self.only_sql is True:
            self.only_sql = False
            return sqlstr
        if not self.conn:
            return False
        error = False
        count = 0
        sqlstr = sqlstr.strip()
        # Classify by the whole leading keyword. A fixed six-character
        # prefix would turn REPLACE into "replac" and miss it.
        match = re.match(r'[A-Za-z]+', sqlstr)
        sqltype = match.group(0).lower() if match else ''
        cur = self.conn.cursor()
        try:
            if self.save_queries:
                self.queries.append(sqlstr)
            try:
                count = cur.execute(sqlstr)
            except Exception as e:
                error = True
                self.error_msg = e
            if sqltype == 'select':
                if count == 0 or error:
                    self.res = False
                    return False
                self.res = cur.fetchall()
                return self.res
            if sqltype in self._write_types:
                try:
                    self.last_row_id = cur.lastrowid
                except Exception:
                    self.last_row_id = 0
                self.res = not (count == 0 or error)
                return self.res
            # Other statements (SHOW, CREATE, ...) are executed but yield None.
            return None
        finally:
            # Release the cursor on every path, whatever the statement type.
            cur.close()

    def get_last_row_id(self):
        """Return the lastrowid captured by the latest write statement (0 if none)."""
        return self.last_row_id

    def last_query(self):
        """Return the most recently recorded SQL string, or None if none yet."""
        return self.queries[-1] if self.queries else None

    def result(self):
        """Return the outcome stored by the latest exec_sql() call."""
        return self.res

    def commit(self):
        """Commit the current transaction."""
        self.conn.commit()

    def rollback(self):
        """Roll back the current transaction."""
        self.conn.rollback()

    def _raw_connection(self):
        # The autocommit switch lives on the innermost PyMySQL connection,
        # reached through two DBUtils wrapper layers. NOTE(review): assumes a
        # dedicated (non-shared) pooled connection, as the original code did.
        return self.conn._con._con

    def begin(self):
        """Start an explicit transaction by switching autocommit off.

        The previous autocommit mode is remembered so end() can restore it.
        """
        raw = self._raw_connection()
        self._autocommit_before_begin = raw.get_autocommit()
        raw.autocommit(False)

    def end(self, option='commit'):
        """Finish the transaction started by begin().

        @param option: 'commit' commits; any other value rolls back.
        The autocommit mode in effect before begin() is restored even if the
        commit/rollback raises (True when begin() was never called).
        """
        try:
            if option == 'commit':
                self.commit()
            else:
                self.rollback()
        finally:
            self._raw_connection().autocommit(
                getattr(self, '_autocommit_before_begin', True))


class DB(Base):
    """Chainable SQL query builder over a pooled MySQL connection.

    Builder methods (select, from_, join_, where*, like*, group_by, having,
    order_by, limit, set_, ...) accumulate clause fragments in the ``ar_*``
    attributes and return ``self``. The terminal methods (get, insert,
    insert_duplicate, replace, update, delete, insert_batch, replace_batch)
    compile those fragments into one SQL string, run it through exec_sql()
    and reset the builder state. ``db.to_sql.<chain>`` returns the SQL text
    instead of executing it.

    NOTE(review): the ar_* naming and helpers such as _protect_identifiers
    look ported from CodeIgniter's Active Record class.
    """
    # The list/dict values below are only defaults and documentation.
    # __init__ and the _reset_* helpers give every instance its own
    # containers, because mutating a class-level list would leak clauses
    # between DB objects.
    ar_select = list()            # SELECT expressions
    ar_distinct = False           # emit SELECT DISTINCT
    ar_from = list()              # protected FROM tables
    ar_join = list()              # compiled JOIN clauses
    ar_where = list()             # WHERE fragments, each with its AND/OR prefix
    ar_like = list()              # LIKE fragments, each with its AND/OR prefix
    ar_groupby = list()           # GROUP BY items
    ar_having = list()            # HAVING fragments
    ar_keys = list()              # column list for *_batch statements
    ar_limit = False              # int LIMIT, or False for none
    ar_offset = 0                 # LIMIT offset used by SELECT
    ar_order = False              # legacy global order flag ('desc'/other)
    ar_orderby = list()           # ORDER BY items (with direction)
    ar_set = dict()               # column -> SQL value for writes
    ar_wherein = list()           # kept for compatibility; not used as scratch
    ar_where_find_in_set = list()
    ar_aliased_tables = list()    # table aliases seen in FROM/JOIN
    ar_stroe_array = list()

    # Active Record caching variables (only active while ar_caching is True)
    ar_caching = False
    ar_cache_exists = list()
    ar_cache_select = list()
    ar_cache_from = list()
    ar_cache_join = list()
    ar_cache_where = list()
    ar_cache_like = list()
    ar_cache_groupby = list()
    ar_cache_having = list()
    ar_cache_orderby = list()
    ar_cache_set = list()

    ar_no_escape = list()         # per-SELECT-item protect flag, index-aligned with ar_select
    ar_cache_no_escape = list()

    # Private variables
    _protect_identifiers_ = True
    _reserved_identifiers = ['*']  # Identifiers that should NOT be escaped
    _escape_char = '`'
    _random_keyword = ' RAND()'
    _count_string = 'SELECT COUNT(*) AS '

    def __init__(self, opts, cursorclass=pymysql.cursors.Cursor, autocommit=False):
        """Create a builder with fresh per-instance state and borrow a connection.

        @param opts: connection options, forwarded to Base.__init__.
        @param cursorclass: PyMySQL cursor class used for all queries.
        @param autocommit: initial autocommit mode.
        """
        self.ar_where_find_in_set = list()
        self.ar_stroe_array = list()
        self._flush_cache()
        self._reset_select()
        self._reset_write()
        super(DB, self).__init__(opts, cursorclass=cursorclass, autocommit=autocommit)

    def _flush_cache(self):
        """Give this instance its own empty Active Record cache containers."""
        self.ar_cache_exists = list()
        self.ar_cache_select = list()
        self.ar_cache_from = list()
        self.ar_cache_join = list()
        self.ar_cache_where = list()
        self.ar_cache_like = list()
        self.ar_cache_groupby = list()
        self.ar_cache_having = list()
        self.ar_cache_orderby = list()
        self.ar_cache_set = list()
        self.ar_cache_no_escape = list()

    def __getattr__(self, in_field):
        """Resolve the dynamic finders ``find_by_<field>`` and ``delete_by_<field>``.

        ``db.find_by_id(5, 'users')`` selects the rows of ``users`` where
        ``id`` equals 5. ``db.delete_by_id(5, 'users')`` deletes them. Any
        other missing attribute raises AttributeError.
        """
        handlers = (('find_by_', '_find'), ('delete_by_', '_delete_by'))
        for prefix, handler_name in handlers:
            if in_field.startswith(prefix):
                field = in_field[len(prefix):]
                if field == '':
                    break
                handler = getattr(self, handler_name)
                return lambda value, table='': handler(field, value, table)
        raise AttributeError(in_field)

    def _find(self, field, value, table):
        """SELECT rows of *table* where *field* matches *value*."""
        if table == '':
            self.error_msg = 'db must set table'
            return False
        self.from_(table)
        self.where(field, value)
        return self.get()

    def _delete_by(self, field, value, table):
        """DELETE rows of *table* where *field* matches *value*."""
        if table == '':
            self.error_msg = 'db must set table'
            return False
        self.where(field, value)
        return self.delete(table)

    @property
    def to_sql(self):
        """Make the next executed statement return its SQL text instead of running."""
        self.only_sql = True
        return self

    def select(self, select='*', escape=None):
        """Add SELECT expressions (comma-separated string or iterable).

        @param escape: False disables identifier quoting for these items.
        """
        if isinstance(select, basestring):
            select = select.split(',')
        for val in select:
            val = val.strip()
            if val != '':
                self.ar_select.append(val)
                self.ar_no_escape.append(escape)
                if self.ar_caching:
                    self.ar_cache_select.append(val)
                    self.ar_cache_exists.append('select')
                    self.ar_cache_no_escape.append(escape)
        return self

    def select_max(self, select='', alias=''):
        return self._max_min_avg_sum(select, alias, 'MAX')

    def select_min(self, select='', alias=''):
        return self._max_min_avg_sum(select, alias, 'MIN')

    def select_avg(self, select='', alias=''):
        return self._max_min_avg_sum(select, alias, 'AVG')

    def select_sum(self, select='', alias=''):
        return self._max_min_avg_sum(select, alias, 'SUM')

    def _max_min_avg_sum(self, select='', alias='', _type='MAX'):
        """Add ``<_type>(<select>) AS <alias>`` to the SELECT list.

        Returns None (and sets ``error_msg``) for an empty or non-string
        column or an unknown aggregate. Otherwise returns self.
        """
        if not isinstance(select, basestring) or select.strip() == '':
            self.error_msg = 'db_invalid_query'
            return
        _type = _type.upper()
        if _type not in ['MAX', 'MIN', 'AVG', 'SUM']:
            self.error_msg = 'Invalid function type: %s' % _type
            return
        if alias == '':
            alias = self._create_alias_from_table(select.strip())
        sql = _type + '(' + self._protect_identifiers(select.strip()) + ') AS ' + alias
        self.ar_select.append(sql)
        # Keep ar_no_escape index-aligned with ar_select (see _compile_select).
        self.ar_no_escape.append(None)

        if self.ar_caching is True:
            self.ar_cache_select.append(sql)
            self.ar_cache_exists.append('select')
            self.ar_cache_no_escape.append(None)

        return self

    def _create_alias_from_table(self, item):
        """Return the last dotted component of *item* (``t.col`` -> ``col``)."""
        if item.find('.') != -1:
            return item.split('.')[-1]
        return item

    def distinct(self, val=True):
        """Toggle SELECT DISTINCT (non-bool values count as True)."""
        self.ar_distinct = val if isinstance(val, bool) else True
        return self

    def from_(self, from_str):
        """Add FROM tables (comma-separated string or iterable), tracking aliases."""
        if isinstance(from_str, basestring):
            from_str = from_str.split(',')
        for val in from_str:
            v = val.strip()
            self._track_aliases(v)
            self.ar_from.append(self._protect_identifiers(v, True, None, False))
            if self.ar_caching:
                self.ar_cache_from.append(self._protect_identifiers(v, True, None, False))
                self.ar_cache_exists.append('from')
        return self

    def join_(self, table, cond, _type=''):
        """Add ``[<type>] JOIN <table> ON <cond>``.

        @param _type: LEFT, RIGHT, OUTER, INNER, LEFT OUTER or RIGHT OUTER;
            anything else yields a plain JOIN.
        """
        if _type != '':
            _type = _type.strip().upper()
            if _type not in ['LEFT', 'RIGHT', 'OUTER', 'INNER', 'LEFT OUTER', 'RIGHT OUTER']:
                _type = ''
            else:
                _type += ' '
        self._track_aliases(table)

        # Protect both sides of a simple "a.b <op> c.d" condition.
        match = re.match(r'([\w\.]+)([\W\s]+)(.+)', cond)
        if match:
            cond = "%s%s%s" % (self._protect_identifiers(match.group(1)),
                               match.group(2), self._protect_identifiers(match.group(3)))

        _join = _type + 'JOIN ' + self._protect_identifiers(table, True, None, False) + ' ON ' + cond
        self.ar_join.append(_join)
        if self.ar_caching is True:
            self.ar_cache_join.append(_join)
            self.ar_cache_exists.append('join')

        return self

    def where(self, key, value=None, escape=True):
        return self._where(key, value, 'AND ', escape)

    def or_where(self, key, value=None, escape=True):
        return self._where(key, value, 'OR ', escape)

    def _where(self, key, value=None, type='AND ', escape=None):
        """Append WHERE conditions joined with *type* (``'AND '``/``'OR '``).

        @param key: a column (optionally followed by an operator, e.g.
            ``'age >'``), a complete condition string, or a dict of key->value.
        @param value: the value to compare with. ``None`` or ``''`` mean
            "no value": a bare column becomes ``IS NULL``, and a key that
            already contains an operator is used as the whole condition.
        @param escape: True protects the identifier and escapes the value.
        """
        if not isinstance(key, dict):
            key = {key: value}
        for k, v in key.items():
            prefix = '' if (len(self.ar_where) == 0 and len(self.ar_cache_where) == 0) else type
            if v is None or v == '':
                if not self._has_operator(k):
                    k += ' IS NULL'
                k = self._protect_identifiers(k, False, escape)
                # Nothing to append: previously None leaked in as the text "None".
                v = ''
            else:
                if escape is True:
                    k = self._protect_identifiers(k, False, escape)
                    v = ' %s' % self.escape(v)
                if not self._has_operator(k):
                    k += ' = '
            condition = '%s%s%s' % (prefix, k, v)
            self.ar_where.append(condition)
            if self.ar_caching:
                self.ar_cache_where.append(condition)
                self.ar_cache_exists.append('where')
        return self

    def where_in(self, key=None, values=None):
        return self._where_in(key, values)

    def or_where_in(self, key=None, values=None):
        return self._where_in(key, values, False, 'OR ')

    def where_not_in(self, key=None, values=None):
        return self._where_in(key, values, True)

    def or_where_not_in(self, key=None, values=None):
        return self._where_in(key, values, True, 'OR ')

    def _where_in(self, key=None, values=None, not_=False, type='AND '):
        """Append ``<key> [NOT] IN (...)`` joined with *type*.

        A scalar (including a string) is treated as a one-element list, and
        any other iterable is materialised. Returns None when key or values
        is None, as before.
        """
        if key is None or values is None:
            return
        if isinstance(values, basestring) or not hasattr(values, '__iter__'):
            values = [values]
        escaped = [self.escape(value) for value in values]
        not_sql = ' NOT' if not_ else ''

        prefix = '' if len(self.ar_where) == 0 else type
        if escaped:
            where_in = prefix + self._escape_identifiers(key) + not_sql + " IN (" + ', '.join(escaped) + ") "
        else:
            # "IN ()" is a MySQL syntax error. An empty set matches nothing
            # for IN and everything for NOT IN.
            where_in = prefix + ('1 = 1 ' if not_ else '0 = 1 ')
        self.ar_where.append(where_in)
        if self.ar_caching:
            self.ar_cache_where.append(where_in)
            self.ar_cache_exists.append('where')

        return self

    def where_find_in_set(self, key=None, value=None):
        return self._where_find_in_set(key, value)

    def or_where_find_in_set(self, key=None, value=None):
        return self._where_find_in_set(key, value, False, 'OR ')

    def where_not_find_in_set(self, key=None, value=None):
        return self._where_find_in_set(key, value, True)

    def or_where_not_find_in_set(self, key=None, value=None):
        return self._where_find_in_set(key, value, True, 'OR ')

    def _where_find_in_set(self, key=None, value=None, not_=False, type='AND '):
        """Append ``[NOT] FIND_IN_SET(<value>, <key>)`` joined with *type*."""
        if key is None or value is None:
            return
        value = self.escape(value)
        not_ = ' NOT' if not_ else ''

        prefix = '' if len(self.ar_where) == 0 else type
        # FIND_IN_SET(value, field)
        where_find_in_set = prefix + not_ + " FIND_IN_SET(" + value + ", " + self._escape_identifiers(key) + ") "
        self.ar_where.append(where_find_in_set)
        if self.ar_caching:
            self.ar_cache_where.append(where_find_in_set)
            self.ar_cache_exists.append('where')

        return self

    def like(self, field, match='', side='both'):
        return self._like(field, match, 'AND ', side)

    def not_like(self, field, match='', side='both'):
        return self._like(field, match, 'AND ', side, 'NOT')

    def or_like(self, field, match='', side='both'):
        return self._like(field, match, 'OR ', side)

    def or_not_like(self, field, match='', side='both'):
        return self._like(field, match, 'OR ', side, 'NOT')

    def _like(self, field, match='', _type='AND ', side='both', not_=''):
        """Append ``<field> [NOT] LIKE '<pattern>'`` conditions.

        @param side: where to put the ``%`` wildcard: 'before', 'after',
            'none' or (default) both sides.
        The value is escaped (quotes, backslashes) so it cannot break out of
        the literal. ``%``/``_`` inside it are left as wildcards.
        """
        if not isinstance(field, dict):
            field = {field: match}
        for k, v in field.items():
            k = self._protect_identifiers(k)
            prefix = '' if len(self.ar_like) == 0 else _type
            v = converters.escape_string(v if isinstance(v, basestring) else str(v))
            if side == 'none':
                like_statement = prefix + " %s %s LIKE '%s'" % (k, not_, v)
            elif side == 'before':
                like_statement = prefix + " %s %s LIKE '%%%s'" % (k, not_, v)
            elif side == 'after':
                like_statement = prefix + " %s %s LIKE '%s%%'" % (k, not_, v)
            else:
                like_statement = prefix + " %s %s LIKE '%%%s%%'" % (k, not_, v)

            self.ar_like.append(like_statement)
            if self.ar_caching is True:
                self.ar_cache_like.append(like_statement)
                self.ar_cache_exists.append('like')

        return self

    def group_by(self, by):
        """Add GROUP BY items (comma-separated string or iterable)."""
        if isinstance(by, basestring):
            by = by.split(',')
        for val in by:
            val = val.strip()
            if val != '':
                self.ar_groupby.append(self._protect_identifiers(val))
                if self.ar_caching is True:
                    self.ar_cache_groupby.append(self._protect_identifiers(val))
                    self.ar_cache_exists.append('groupby')
        return self

    def having(self, key, value='', escape=True):
        return self._having(key, value, 'AND ', escape)

    def or_having(self, key, value='', escape=True):
        return self._having(key, value, 'OR ', escape)

    def _having(self, key, value='', _type='AND ', escape=True):
        """Append HAVING conditions (a key/value pair or a dict of them)."""
        if not isinstance(key, dict):
            key = {key: value}
        for k, v in key.items():
            prefix = '' if len(self.ar_having) == 0 else _type
            if escape is True:
                k = self._protect_identifiers(k)
            if not self._has_operator(k):
                k += ' = '
            if v != '':
                v = ' %s' % self.escape(v)
            self.ar_having.append(prefix + k + v)
            if self.ar_caching is True:
                self.ar_cache_having.append(prefix + k + v)
                self.ar_cache_exists.append('having')
        return self

    def order_by(self, orderby, direction=''):
        """Add an ORDER BY item.

        @param direction: 'ASC'/'DESC' (case-insensitive). Any other
            non-blank value becomes ASC, and 'random' orders by RAND().
        """
        if direction.lower() == 'random':
            orderby = ''
            direction = self._random_keyword
        elif direction.strip() != '':
            normalized = direction.strip().upper()
            direction = ' %s' % (normalized if normalized in ['ASC', 'DESC'] else 'ASC')

        if orderby.find(',') != -1:
            # NOTE(review): only parts equal to a known alias get protected;
            # everything else passes through verbatim, which keeps
            # expressions such as FIELD(id, 1, 2) intact.
            temp = list()
            for part in orderby.split(','):
                part = part.strip()
                if part in self.ar_aliased_tables:
                    part = self._protect_identifiers(part)
                temp.append(part)
            orderby = ', '.join(temp)
        elif direction != self._random_keyword:
            orderby = self._protect_identifiers(orderby)

        orderby_statement = orderby + direction
        self.ar_orderby.append(orderby_statement)
        if self.ar_caching is True:
            self.ar_cache_orderby.append(orderby_statement)
            self.ar_cache_exists.append('orderby')

        return self

    def limit(self, value, offset=''):
        """Set LIMIT (and optionally the offset used by SELECT)."""
        self.ar_limit = int(value)
        if offset != '':
            self.ar_offset = int(offset)
        return self

    def offset(self, offset):
        """Set the SELECT offset."""
        self.ar_offset = int(offset)
        return self

    def set_(self, key, value='', escape=True):
        """Record column values for the next write (pair or dict).

        @param escape: False stores the values verbatim (e.g. ``'cnt + 1'``).
        """
        if not isinstance(key, dict):
            key = {key: value}
        for k, v in key.items():
            if escape is False:
                self.ar_set[self._protect_identifiers(k)] = v
            else:
                self.ar_set[self._protect_identifiers(k, False, True)] = self.escape(v)
        return self

    def escape(self, str_):
        """Return *str_* rendered as a SQL literal (always a string).

        - Strings are quoted and escaped.
        - Booleans become 1/0 and None becomes NULL.
        - Integers and Decimals use their numeric text; floats use repr().
        - Any other object is str()-converted, then quoted and escaped.
        """
        if isinstance(str_, basestring):
            return converters.escape_str(str_)
        if isinstance(str_, bool):
            return converters.escape_bool(str_)
        if str_ is None:
            return converters.escape_None(str_)
        if isinstance(str_, (numbers.Integral, decimal.Decimal)):
            return str(str_)
        if isinstance(str_, float):
            return repr(str_)
        return converters.escape_str(str(str_))

    def escape_like_str(self, str_):
        """Return *str_* quoted and escaped (note: includes the surrounding quotes)."""
        return converters.escape_str(str_)

    @staticmethod
    def _as_sql_text(value):
        """Render an already-prepared SQL value as text for joining.

        None becomes NULL; non-string values are str()-converted so that
        ``', '.join`` never fails on numbers.
        """
        if value is None:
            return 'NULL'
        if isinstance(value, basestring):
            return value
        return str(value)

    def _escape_identifiers(self, item):
        """
        Wrap column and table names in the escape character, per dotted part.

        Duplicate escape characters (e.g. when the caller already quoted the
        name) are collapsed.
        @param item: identifier such as ``col``, ``t.col`` or ``t.*``
        """
        if self._escape_char == '':
            return item

        for field in self._reserved_identifiers:
            if item.find('.%s' % field) != -1:
                _str = "%s%s" % (self._escape_char, item.replace('.', '%s.' % self._escape_char))
                # remove duplicates if the user already included the escape
                return re.sub(r'[%s]+' % self._escape_char, self._escape_char, _str)

        if item.find('.') != -1:
            _str = "%s%s%s" % (self._escape_char,
                               item.replace('.', '%s.%s' % (self._escape_char, self._escape_char)),
                               self._escape_char)
        else:
            _str = self._escape_char + item + self._escape_char
        # remove duplicates if the user already included the escape
        return re.sub(r'[%s]+' % self._escape_char, self._escape_char, _str)

    def _protect_identifiers(self, item, prefix_single=False, protect_identifiers=None, field_exists=True):
        """
        Quote the identifier part of *item*, keeping everything after it.

        The text from the first space on (alias, operator, ASC/DESC, ...) is
        kept verbatim. Items containing "(" (function calls) are returned
        untouched. Dicts are processed key- and value-wise.

        @param item: identifier expression such as ``'t.col alias'``
        @param prefix_single: unused; kept for interface compatibility
        @param protect_identifiers: force quoting on/off; non-bool values
            fall back to ``_protect_identifiers_``
        @param field_exists: unused; kept for interface compatibility
        """
        if not isinstance(protect_identifiers, bool):
            protect_identifiers = self._protect_identifiers_

        if isinstance(item, dict):
            escaped_dict = dict()
            for k, v in item.items():
                escaped_dict[self._protect_identifiers(k)] = self._protect_identifiers(v)
            return escaped_dict

        # Convert tabs or multiple spaces into single spaces
        item = re.sub(r'[\t ]+', ' ', item)

        if item.find(' ') != -1:
            alias = item[item.find(' '):]
            item = item[0:item.find(' ')]
        else:
            alias = ''

        if item.find('(') != -1:
            return item + alias

        if item.find('.') != -1:
            parts = item.split('.')
            if parts[0] in self.ar_aliased_tables:
                if protect_identifiers is True:
                    for key, val in enumerate(parts):
                        if val not in self._reserved_identifiers:
                            parts[key] = self._escape_identifiers(val)
                    item = '.'.join(parts)
                return item + alias
            if protect_identifiers is True:
                item = self._escape_identifiers(item)
            return item + alias
        if protect_identifiers is True and item not in self._reserved_identifiers:
            item = self._escape_identifiers(item)
        return item + alias

    def _has_operator(self, str_):
        """Return True when *str_* contains whitespace, a comparison or IS [NOT] NULL."""
        if not re.search(r'(\s|<|>|!|=|is null|is not null)', str_.strip(), re.I):
            return False
        return True

    def get(self, table='', limit=None, offset=''):
        """Compile and run the SELECT built so far, then reset the select state."""
        if table != '':
            self._track_aliases(table)
            self.from_(table)

        if limit is not None:
            self.limit(limit, offset)

        sql = self._compile_select()
        result = self.exec_sql(sql)
        self._reset_select()
        return result

    def insert(self, table='', _set=None):
        """INSERT the values recorded with set_ (plus *_set*) into *table*."""
        self._merge_cache()
        if _set is not None:
            self.set_(_set, escape=True)

        if len(self.ar_set) == 0:
            self.error_msg = 'insert columns and values are empty'
            return False

        if table == '':
            self.error_msg = 'db must set table'
            return False

        sql = self._insert(self._protect_identifiers(table, True, None, False), self.ar_set.keys(), self.ar_set.values())
        self._reset_write()
        return self.exec_sql(sql)

    def _insert(self, table, keys, values):
        """Build an INSERT statement from protected keys and SQL values."""
        esc_values = [self._as_sql_text(val) for val in values]
        return """INSERT INTO %s
                (%s)
                VALUES (%s);
        """ % (
            table,
            ', '.join(keys),
            ', '.join(esc_values)
        )

    def insert_duplicate(self, table='', _set=None):
        """INSERT ... ON DUPLICATE KEY UPDATE with the recorded values."""
        self._merge_cache()
        if _set is not None:
            self.set_(_set, escape=True)

        if len(self.ar_set) == 0:
            self.error_msg = 'insert columns and values are empty'
            return False

        if table == '':
            self.error_msg = 'db must set table'
            return False

        sql = self._insert_duplicate(self._protect_identifiers(table, True, None, False), self.ar_set.keys(), self.ar_set.values(), self.ar_set)
        self._reset_write()
        return self.exec_sql(sql)

    def _insert_duplicate(self, table, keys, values, _set):
        """Build INSERT ... ON DUPLICATE KEY UPDATE, updating every given column."""
        esc_values = [self._as_sql_text(val) for val in values]
        update_data = ["%s = %s" % (k, self._as_sql_text(v)) for k, v in _set.items()]

        return """INSERT INTO %s
                (%s)
                VALUES (%s)
                ON DUPLICATE KEY UPDATE %s;
        """ % (
            table,
            ', '.join(keys),
            ', '.join(esc_values),
            ', '.join(update_data)
        )

    def replace(self, table, _set=None):
        """REPLACE INTO *table* with the recorded values."""
        self._merge_cache()
        if _set is not None:
            self.set_(_set, escape=True)

        if len(self.ar_set) == 0:
            self.error_msg = 'replace columns and values are empty'
            return False

        if table == '':
            self.error_msg = 'db must set table'
            return False

        sql = self._replace(self._protect_identifiers(table, True, None, False), self.ar_set.keys(), self.ar_set.values())
        self._reset_write()
        return self.exec_sql(sql)

    def _replace(self, table, keys, values):
        """Build a REPLACE statement from protected keys and SQL values."""
        esc_values = [self._as_sql_text(val) for val in values]
        return """REPLACE INTO %s
                (%s)
                VALUES (%s);
        """ % (
            table,
            ', '.join(keys),
            ', '.join(esc_values)
        )

    def update(self, table='', _set=None, where=None, limit=None):
        """UPDATE *table* with the recorded values, WHERE/ORDER BY/LIMIT applied."""
        self._merge_cache()
        if _set is not None:
            self.set_(_set, escape=True)

        if len(self.ar_set) == 0:
            self.error_msg = 'update columns and values are empty'
            return False

        if table == '':
            self.error_msg = 'db must set table'
            return False

        if where is not None:
            self.where(where)
        if limit is not None:
            self.limit(limit)
        sql = self._update(self._protect_identifiers(table, True, None, False), self.ar_set, self.ar_where, self.ar_orderby, self.ar_limit)
        self._reset_write()
        return self.exec_sql(sql)

    def _update(self, table, values, where, orderby=None, limit=False):
        """Build an UPDATE statement."""
        valstr = ['%s = %s' % (key, self._as_sql_text(value)) for key, value in values.items()]
        limit = ' LIMIT %s' % limit if limit is not False else ''
        if not isinstance(orderby, (list, tuple)):
            orderby = ''
        else:
            orderby = ' ORDER BY %s' % ', '.join(orderby) if len(orderby) >= 1 else ''
        sql = "UPDATE " + table + " SET " + ', '.join(valstr)
        sql += " WHERE " + ' '.join(where) if where else ''
        sql += orderby + limit
        return sql

    def delete(self, table='', where='', limit=None, reset_data=True):
        """DELETE from *table* (or each table of a list). A WHERE/LIKE is required.

        @param reset_data: False keeps the builder state (used internally when
            deleting from several tables with the same conditions).
        """
        self._merge_cache()
        if table == '':
            self.error_msg = 'db must set table'
            return False
        elif isinstance(table, (list, tuple)):
            for single_table in table:
                self.delete(single_table, where, limit, False)
            self._reset_write()
            return True
        else:
            table = self._protect_identifiers(table, True, None, False)

        if where != '' and where is not None:
            self.where(where)
        if limit is not None:
            self.limit(limit)

        # Refuse to emit an unconditional DELETE.
        if len(self.ar_where) == 0 and len(self.ar_wherein) == 0 and len(self.ar_like) == 0:
            self.error_msg = 'db del must use where'
            return False

        sql = self._delete(table, self.ar_where, self.ar_like, self.ar_limit)

        if reset_data:
            self._reset_write()
        return self.exec_sql(sql)

    def _delete(self, table, where=None, like=None, limit=False):
        """Build a DELETE statement from the given WHERE and LIKE fragments."""
        has_where = isinstance(where, (list, tuple)) and len(where) > 0
        has_like = isinstance(like, (list, tuple)) and len(like) > 0
        conditions = ''
        if has_where or has_like:
            conditions = "\nWHERE "
            if has_where:
                conditions += "\n".join(where)
            if has_where and has_like:
                conditions += " AND "
            if has_like:
                conditions += "\n".join(like)
        limit = ' LIMIT %s' % limit if limit is not False else ''

        return "DELETE FROM " + table + conditions + limit

    def _run_batch(self, table, _set, build):
        """Shared body of insert_batch/replace_batch.

        @param build: callable(table, keys, rows) returning the SQL text.
        """
        if _set is None:
            self.error_msg = 'db must use values'
            return False

        self.set_insert_batch(_set)
        if len(self.ar_set) == 0:
            # Empty input or rows with differing columns.
            self.error_msg = 'batch rows are empty or have mismatched columns'
            self._reset_write()
            return False

        if table == '':
            if len(self.ar_from) == 0:
                self.error_msg = 'db must set table'
                # ar_set is a list here; reset so later set_() calls work.
                self._reset_write()
                return False
            table = self.ar_from[0]

        sql = build(self._protect_identifiers(table, True, None, False), self.ar_keys, self.ar_set)

        self._reset_write()

        return self.exec_sql(sql)

    def insert_batch(self, table='', _set=None):
        """Multi-row INSERT of a list of same-keyed dicts."""
        return self._run_batch(table, _set, self._insert_batch)

    def _insert_batch(self, table, keys, values):
        return """INSERT INTO %s
                (%s)
                VALUES %s;
        """ % (
            table,
            ', '.join(keys),
            ', '.join(values)
        )

    def set_insert_batch(self, key, escape=True):
        """Prepare ``ar_keys``/``ar_set`` from a list of same-keyed dicts.

        Columns are ordered by sorted key name. If the input is not a
        non-empty list, or any row's keys differ from the first row's,
        ``ar_set`` is left empty and None is returned.
        """
        self.ar_set = []
        self.ar_keys = []
        if not isinstance(key, list) or len(key) == 0:
            return

        keys = sorted(key[0].keys())

        for row in key:
            if set(keys) != set(row.keys()):
                self.ar_set = []
                return
            ordered = [row[k] for k in keys]
            if escape is False:
                values = [self._as_sql_text(v) for v in ordered]
            else:
                values = [self.escape(v) for v in ordered]
            self.ar_set.append('(' + ','.join(values) + ')')

        for k in keys:
            self.ar_keys.append(self._protect_identifiers(k))

        return self

    def replace_batch(self, table='', _set=None):
        """Multi-row REPLACE of a list of same-keyed dicts."""
        return self._run_batch(table, _set, self._replace_batch)

    def _replace_batch(self, table, keys, values):
        return """REPLACE INTO %s
                (%s)
                VALUES %s;
        """ % (
            table,
            ', '.join(keys),
            ', '.join(values)
        )

    def _track_aliases(self, table):
        """Record the alias of ``table alias`` / ``table AS alias`` entries."""
        if isinstance(table, (list, tuple)):
            for t in table:
                self._track_aliases(t)
            return
        if table.find(',') != -1:
            return self._track_aliases(table.split(','))

        table = table.strip()
        if table.find(" ") != -1:
            table = re.sub(r'\s+AS\s+', ' ', table, flags=re.I)
            # The alias is the last word (previously only one character was taken).
            alias = table[table.rfind(" "):].strip()
            if alias and alias not in self.ar_aliased_tables:
                self.ar_aliased_tables.append(alias)

    def _compile_select(self, select_override=None):
        """Assemble the SELECT statement from the accumulated fragments."""
        self._merge_cache()
        if select_override:
            sql = select_override
        else:
            sql = 'SELECT ' if not self.ar_distinct else 'SELECT DISTINCT '
            if len(self.ar_select) == 0:
                sql += '*'
            else:
                for key, val in enumerate(self.ar_select):
                    try:
                        no_escape = self.ar_no_escape[key]
                    except IndexError:
                        no_escape = None
                    self.ar_select[key] = self._protect_identifiers(val, False, no_escape)
                sql += ', '.join(self.ar_select)

        if len(self.ar_from) > 0:
            sql += "\nFROM "
            sql += ', '.join(self.ar_from)

        if len(self.ar_join) > 0:
            sql += "\n"
            sql += "\n".join(self.ar_join)

        if len(self.ar_where) > 0 or len(self.ar_like) > 0:
            sql += "\nWHERE "
        sql += '\n'.join(self.ar_where)

        if len(self.ar_like) > 0:
            if len(self.ar_where) > 0:
                sql += "\nAND "
            sql += "\n".join(self.ar_like)

        # GROUP BY / ORDER BY items must be comma-separated in SQL.
        if len(self.ar_groupby) > 0:
            sql += "\nGROUP BY "
            sql += ', '.join(self.ar_groupby)

        if len(self.ar_having) > 0:
            sql += "\nHAVING "
            sql += '\n'.join(self.ar_having)

        if len(self.ar_orderby) > 0:
            sql += "\nORDER BY "
            sql += ', '.join(self.ar_orderby)
            if self.ar_order is not False:
                sql += ' DESC' if self.ar_order == 'desc' else ' ASC'

        # ar_limit is False when unset; bool is excluded explicitly because
        # it is an Integral subclass.
        if isinstance(self.ar_limit, numbers.Integral) and not isinstance(self.ar_limit, bool):
            sql += "\n"
            sql = self._limit(sql, self.ar_limit, self.ar_offset)

        return sql

    def _limit(self, sql, limit, offset):
        """Append ``LIMIT [offset, ]limit`` to *sql*."""
        if offset == 0:
            offset = ''
        else:
            offset = "%s, " % offset

        return sql + "LIMIT %s%s" % (offset, limit)

    def _merge_cache(self):
        """Fold cached Active Record fragments into the live query state.

        Does nothing unless something was recorded while ``ar_caching`` was
        True. For each cached clause type, the live list becomes the cached
        entries followed by the live ones, without duplicates.
        ``ar_variable`` maps each merged cache attribute to its live one.
        """
        self.ar_variable = dict()
        if len(self.ar_cache_exists) == 0:
            return
        for val in self.ar_cache_exists:
            ar_variable = 'ar_' + val
            ar_cache_var = 'ar_cache_' + val
            if ar_cache_var in self.ar_variable:
                continue  # this clause type is already merged
            cached = getattr(self, ar_cache_var)
            if len(cached) == 0:
                continue
            self.ar_variable[ar_cache_var] = ar_variable
            merged = []
            for item in list(cached) + list(getattr(self, ar_variable)):
                if item not in merged:
                    merged.append(item)
            setattr(self, ar_variable, merged)
        # Copy instead of aliasing, so later select() calls (which append
        # to both lists) do not double-append into the cache.
        self.ar_no_escape = list(self.ar_cache_no_escape)

    def _reset_select(self):
        """Clear all per-query SELECT state."""
        self.ar_select = list()
        self.ar_from = list()
        self.ar_join = list()
        self.ar_where = list()
        self.ar_like = list()
        self.ar_groupby = list()
        self.ar_having = list()
        self.ar_orderby = list()
        self.ar_wherein = list()
        self.ar_aliased_tables = list()
        self.ar_no_escape = list()
        self.ar_distinct = False
        self.ar_limit = False
        self.ar_offset = 0
        self.ar_order = False

    def _reset_write(self):
        """Clear all per-statement write state."""
        self.ar_set = {}
        self.ar_from = list()
        self.ar_where = list()
        self.ar_like = list()
        self.ar_groupby = list()
        self.ar_orderby = list()
        self.ar_keys = list()
        self.ar_limit = False
        self.ar_order = False

if __name__ == "__main__":
    # Manual smoke test. It needs a reachable MySQL server; the options below
    # are placeholders handed to PooledDB, which passes them on to PyMySQL.
    db_opts = {
        'host': '127.0.0.1',
        'port': 3306,
        'user': 'root',
        'password': '',
        'database': 'test',
        'charset': 'utf8',
    }
    db = DB(db_opts, cursorclass=pymysql.cursors.DictCursor, autocommit=True)
    # to_sql makes the next statement return its SQL text instead of running it.
    print(db.to_sql.select('access_num').from_('z_acl')
          .where('access_name', 'vm_platform')
          .where_find_in_set('ip', '172.16.3.92').get())
    print(db.to_sql.find_by_vm_id(61, 'z_vm'))
    db.close_conn()