#encoding=utf8
import importlib
import inspect
import json
import logging.config
import os
import random
import socket
import ssl
import sys
import threading
import time
import uuid    

class Protocol:
    """String and bit-mask constants shared by the messaging code in this file."""

    # ---- Command values -------------------------------------------------
    # MQ produce / consume
    PRODUCE = "produce"
    CONSUME = "consume"
    RPC = "rpc"
    ROUTE = "route"  # route back message to sender, designed for RPC

    # Topic / consume-group control
    DECLARE = "declare"
    QUERY = "query"
    REMOVE = "remove"
    EMPTY = "empty"

    # Tracker
    TRACK_PUB = "track_pub"
    TRACK_SUB = "track_sub"

    # ---- Header names ---------------------------------------------------
    COMMAND = "cmd"
    TOPIC = "topic"
    TOPIC_MASK = "topic_mask"
    TAG = "tag"
    OFFSET = "offset"

    DELAY = "delay"

    CONSUME_GROUP = "consume_group"
    GROUP_NAME_AUTO = "group_name_auto"
    GROUP_START_COPY = "group_start_copy"
    GROUP_START_OFFSET = "group_start_offset"
    GROUP_START_MSGID = "group_start_msgid"
    GROUP_START_TIME = "group_start_time"
    GROUP_FILTER = "group_filter"
    GROUP_MASK = "group_mask"
    CONSUME_WINDOW = "consume_window"
    GROUP_ACK_WINDOW = "group_ack_window"
    GROUP_ACK_TIMEOUT = "group_ack_timeout"

    SENDER = "sender"
    RECVER = "recver"
    ID = "id"

    ACK = "ack"
    ENCODING = "encoding"

    ORIGIN_ID = "origin_id"
    ORIGIN_URL = "origin_url"
    ORIGIN_STATUS = "origin_status"

    # Security
    TOKEN = "token"

    # ---- Mask bits ------------------------------------------------------
    MASK_DISK = 0
    MASK_MEMORY = 1 << 0
    MASK_RPC = 1 << 1
    MASK_PROXY = 1 << 2
    MASK_PAUSE = 1 << 3
    MASK_EXCLUSIVE = 1 << 4
    MASK_DELETE_ON_EXIT = 1 << 5
    MASK_ACK_REQUIRED = 1 << 6
    
##########################################################################
# support both python2 and python3
# Interpreter-specific shims: the queue and urllib modules are loaded by name
# so the same source imports cleanly on both Python 2 and Python 3.
if sys.version_info[0] >= 3:
    Queue = importlib.import_module('queue')
    myurllib = importlib.import_module('urllib.parse')

    def _bytes(buf, encoding='utf8'):
        """Encode text *buf* to bytes using *encoding*."""
        return bytes(buf, encoding)

    def _urldecode(url):
        """Reverse percent-encoding of *url*."""
        return myurllib.unquote(url)
else:
    Queue = importlib.import_module('Queue')
    myurllib = importlib.import_module('urllib')

    def _bytes(buf, encoding='utf8'):
        """Encode text *buf* to bytes using *encoding*."""
        return buf.encode(encoding)

    def _urldecode(val):
        """Reverse percent-encoding of *val* and decode the result as UTF-8."""
        return myurllib.unquote(val).decode('utf8')


def _urlencode(val):
    """Percent-encode *val* (spaces become '+'), leaving '/' untouched."""
    return myurllib.quote_plus(val, safe='/')
# Logging bootstrap: prefer a 'log.conf' in the current working directory,
# then one located next to this file; if neither can be loaded fall back to
# logging.basicConfig with a fixed format.
try:
    log_file = 'log.conf'
    if not os.path.exists(log_file):
        log_dir = os.path.dirname(os.path.realpath(__file__))
        log_file = os.path.join(log_dir, 'log.conf')
    logging.config.fileConfig(log_file)
except Exception:
    # Only ordinary errors select the fallback; SystemExit and
    # KeyboardInterrupt are no longer swallowed here.
    logging.basicConfig(
        format='%(asctime)s - %(filename)s-%(lineno)s - %(levelname)s - %(message)s')


'''
HTTP header extensions (strongly typed).
'''
class MessageCtrl:
    """Plain holder for control fields that can be copied onto a message.

    Every field starts as ``None``; only fields that have been assigned a
    value are copied by :meth:`to_message`.
    """

    def __init__(self):
        self.topic = None
        self.topic_mask = None
        self.token = None
        self.consume_group = None
        self.group_name_auto = None
        self.group_mask = None
        self.group_filter = None
        self.group_start_copy = None
        self.group_start_offset = None
        self.group_start_msgid = None
        self.group_start_time = None
        self.group_ack_window = None
        self.group_ack_timeout = None

        self.offset = None

    def to_message(self, msg):
        """Copy every assigned field into *msg* under the field's own name.

        :param msg: mapping to populate (item assignment is all that is used).

        Fields left as ``None`` are skipped.  Falsy-but-meaningful values
        such as ``0`` or ``False`` are copied; testing for truthiness instead
        would silently drop e.g. ``offset = 0``.
        """
        members = inspect.getmembers(self, lambda a: not inspect.isroutine(a))
        for name, val in members:
            if name.startswith('__'):
                continue
            if val is not None:
                msg[name] = val
                
class Message(dict):
    """Dictionary whose entries can also be read and written as attributes.

    Reading a missing key or attribute yields ``None`` rather than raising.
    """

    # Reason phrases keyed by integer status code.
    http_status = {
        200: "OK",
        201: "Created",
        202: "Accepted",
        204: "No Content",
        206: "Partial Content",
        301: "Moved Permanently",
        304: "Not Modified",
        400: "Bad Request",
        401: "Unauthorized",
        403: "Forbidden",
        404: "Not Found",
        405: "Method Not Allowed",
        416: "Requested Range Not Satisfiable",
        500: "Internal Server Error",
    }
    # Keys that msg_encode never writes as ordinary header lines.
    reserved_keys = {'status', 'method', 'url', 'body'}
    # Header names whose values are URL-encoded / URL-decoded on the wire.
    codec_keys = {
        Protocol.TOPIC,
        Protocol.CONSUME_GROUP,
        Protocol.TAG,
        Protocol.GROUP_FILTER,
        Protocol.GROUP_START_COPY,
        Protocol.TOKEN,
    }

    def __init__(self, opt=None):
        """Create a message, optionally seeded from a dict or a MessageCtrl."""
        self.body = None
        if not opt:
            return
        if isinstance(opt, dict):
            self.update(opt)
        elif isinstance(opt, MessageCtrl):
            opt.to_message(self)

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails.
        return dict.get(self, name)

    def __setattr__(self, name, value):
        # Attributes are stored as dictionary entries.
        self[name] = value

    def __delattr__(self, name):
        # Deleting an absent attribute is a no-op.
        self.pop(name, None)

    def __getitem__(self, key):
        # Absent keys read as None instead of raising KeyError.
        return dict.get(self, key)

    def encode(self):
        """Return the wire representation produced by msg_encode."""
        return msg_encode(self)

    @staticmethod
    def decode(buf, start=0):
        """Delegate to msg_decode, parsing *buf* from offset *start*."""
        return msg_decode(buf, start)

def parse_content_type(text):
    """Split a Content-Type value into ``(media_type, charset)``.

    :param text: header value such as ``'text/plain; charset=utf8'``.
    :return: ``(None, None)`` when *text* is empty; otherwise the stripped
        media type and the charset, or ``None`` for the charset when no
        non-empty ``charset`` parameter is present.

    All parameters are scanned (not just the first one) and surrounding
    quotes around the charset value are removed.
    """
    if not text:
        return None, None
    parts = text.split(';')
    media_type = parts[0].strip()
    charset = None
    for param in parts[1:]:
        name, sep, value = param.partition('=')
        if sep and name.strip().lower() == 'charset':
            # An empty value is reported as "no charset" so callers fall
            # back to their default instead of using '' as a codec name.
            charset = value.strip().strip('"\'') or None
            break
    return media_type, charset
def msg_encode(msg):
    """Serialise *msg* into an HTTP/1.1-style byte sequence.

    :param msg: a ``Message`` or plain ``dict`` (wrapped into a ``Message``).
    :return: ``bytearray`` holding start line, headers, blank line and body.
    :raises ValueError: if *msg* is not a dict.

    A message carrying ``status`` is written as a response line, otherwise as
    a request line (method defaults to ``GET``, url to ``/``).  Header lines
    are always encoded as UTF-8, which is what ``decode_headers`` expects; the
    charset taken from ``content-type`` applies to the body only.
    """
    if not isinstance(msg, dict):
        raise ValueError('%s must be dict type' % msg)
    if not isinstance(msg, Message):
        msg = Message(msg)

    header_charset = 'utf8'
    res = bytearray()
    status = msg.status
    if status is not None:
        # http_status is keyed by int, so normalise before the lookup.
        try:
            desc = Message.http_status.get(int(status))
        except (TypeError, ValueError):
            desc = None
        if desc is None:
            desc = "Unknown Status"
        res += _bytes("HTTP/1.1 %s %s\r\n" % (status, desc), header_charset)
    else:
        method = msg.method or 'GET'
        url = _urlencode(msg.url or '/')
        res += _bytes("%s %s HTTP/1.1\r\n" % (method, url), header_charset)

    body_len = 0
    content_type, charset = parse_content_type(msg['content-type'] or 'text/plain')
    if charset is None:
        charset = 'utf8'
    is_json = content_type.startswith('application/json')
    msg_body = msg.body
    if msg_body:
        if is_json or not isinstance(msg_body, (bytes, bytearray, str)):
            msg_body = json.dumps(msg_body, ensure_ascii=False).encode(encoding=charset)
            content_type = 'application/json'
        elif isinstance(msg_body, str):
            msg_body = _bytes(msg_body, charset)
        else:
            # bytes / bytearray are already encoded; pass them through as-is
            # (str() on them would produce the "b'...'" repr on Python 3).
            msg_body = bytes(msg_body)
        body_len = len(msg_body)
    elif is_json:
        # Falsy JSON payloads (None, 0, [], ...) still need a JSON body.
        msg_body = json.dumps(msg_body, ensure_ascii=False).encode(encoding=charset)
        body_len = len(msg_body)

    managed = ('content-type', 'content-length', 'encoding')
    for key in msg:
        name = key.lower()
        if name in Message.reserved_keys or name in managed:
            continue
        # Read with the original key so mixed-case keys keep their value.
        value = msg[key]
        if value is None:
            continue
        if name in Message.codec_keys:
            value = _urlencode(value)
        res += _bytes('%s: %s\r\n' % (name, value), header_charset)

    content_type = '%s; charset=%s' % (content_type, charset)
    res += _bytes('content-type: %s\r\n' % content_type, header_charset)
    res += _bytes('content-length: %s\r\n' % body_len, header_charset)
    res += _bytes('\r\n', header_charset)

    if msg_body:
        res += msg_body
    return res


def find_header_end(buf, start=0):
    """Locate the blank line (CR LF CR LF) that terminates a header block.

    :param buf: ``bytes`` or ``bytearray`` to search.
    :param start: offset at which the search begins.
    :return: index of the final LF of the terminator, or ``-1`` if absent.
    """
    pos = buf.find(b'\r\n\r\n', start)
    if pos < 0:
        return -1
    return pos + 3


def decode_headers(buf):
    """Parse a start line plus header lines into a ``Message``.

    :param buf: UTF-8 encoded bytes holding the start line and headers.
    :return: ``Message`` with ``status`` (response) or ``method``/``url``
        (request) set, plus one entry per header line.
    :raises ValueError: if a response start line has a non-numeric status.

    Header names are lower-cased so lookups such as ``content-length`` work
    regardless of how the peer capitalised them.  Lines without a ``:`` are
    logged and ignored; a blank start line leaves status/method unset.
    """
    msg = Message()
    lines = buf.decode('utf8').splitlines()
    start_line = lines[0] if lines else ''
    blocks = start_line.split()
    if start_line.upper().startswith('HTTP'):
        if len(blocks) > 1:
            msg.status = int(blocks[1])
    elif blocks:
        msg.method = blocks[0].upper()
        if len(blocks) > 1:
            msg.url = _urldecode(blocks[1])

    for line in lines[1:]:
        if not line:
            continue
        name, sep, value = line.partition(':')
        if not sep:
            logging.error('Ignoring malformed header line: %r', line)
            continue
        try:
            key = str(name).strip().lower()
            val = str(value).strip()
            if key in Message.codec_keys:
                val = _urldecode(val)
        except ValueError as e:  # includes Unicode encode/decode errors
            logging.error(e)
            continue
        msg[key] = val

    return msg


def msg_decode(buf, start=0):
    """Try to decode one complete message from *buf* beginning at *start*.

    :param buf: ``bytes``/``bytearray`` holding received data.
    :param start: offset of the first byte of the candidate message.
    :return: ``(msg, next_offset)`` on success, or ``(None, start)`` when the
        buffer does not yet contain a complete message.

    ``text*`` bodies are decoded to text with the declared charset (UTF-8 by
    default); ``application/json`` bodies are parsed, and left as raw bytes
    when they cannot be decoded or parsed.  Any other body stays raw.
    """
    p = find_header_end(buf, start)
    if p < 0:
        return (None, start)
    msg = decode_headers(buf[start: p])
    if msg is None:
        return (None, start)
    p += 1  # first byte after the header terminator

    body_len = msg['content-length']
    if body_len is None:
        return (msg, p)
    body_len = int(body_len)
    if len(buf) - p < body_len:
        return (None, start)  # body not fully received yet
    msg['content-length'] = body_len

    body = buf[p: p + body_len]
    content_type, charset = parse_content_type(msg['content-type'] or 'text/plain')
    if charset is None:
        charset = 'utf8'
    msg.encoding = charset

    if content_type:
        if content_type.startswith('text'):
            body = body.decode(charset)
        elif content_type.startswith('application/json'):
            # json.loads() no longer accepts an ``encoding`` argument on
            # Python 3.9+, so decode explicitly before parsing.
            try:
                body = json.loads(body.decode(charset))
            except (ValueError, LookupError):
                pass  # undecodable or invalid JSON: keep the raw bytes
    msg.body = body
    return (msg, p + body_len)

class ServerAddress:
    """Server address string paired with an SSL flag; usable as a dict key.

    :param address: a string, a dict with ``address`` and ``sslEnabled``
        keys, or another ``ServerAddress`` to copy.
    :param ssl_enabled: SSL flag, used only when *address* is a string.
    :raises TypeError: for any other *address* type or a dict missing a key.
    """

    def __init__(self, address, ssl_enabled=False):
        # type(u'') is ``unicode`` on Python 2 and ``str`` on Python 3.
        if isinstance(address, (str, type(u''))):
            self.address = address
            self.ssl_enabled = ssl_enabled
        elif isinstance(address, dict):
            if 'address' not in address:
                raise TypeError('missing address in dictionary')
            if 'sslEnabled' not in address:  # camel style from java/js
                raise TypeError('missing sslEnabled in dictionary')

            self.address = address['address']
            self.ssl_enabled = address['sslEnabled']
        elif isinstance(address, ServerAddress):
            self.address = address.address
            self.ssl_enabled = address.ssl_enabled
        else:
            # Format with %r: concatenating a non-string would itself raise
            # and hide the intended message.
            raise TypeError("%r address not support" % (address,))

    def __key(self):
        if self.ssl_enabled:
            return '[SSL]%s' % self.address
        return self.address

    def __hash__(self):
        # Equal objects share the same address, so this is consistent
        # with __eq__.
        return hash(self.address)

    def __eq__(self, other):
        if not isinstance(other, ServerAddress):
            return NotImplemented
        return self.address == other.address and self.ssl_enabled == other.ssl_enabled

    def __ne__(self, other):
        # Python 2 does not derive != from ==.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __str__(self):
        return self.__key()

    def __repr__(self):
        return self.__str__()


class MessageClient(object):
    """Blocking socket client that sends and receives ``Message`` objects.

    The connection is opened lazily by :meth:`connect` (called from
    :meth:`send` and :meth:`recv`).  A ``threading.Timer`` sends a
    ``heartbeat`` command every ``heartbeat_interval`` seconds while a socket
    exists.  :meth:`start` runs a receive loop in a background thread that
    dispatches to ``on_message`` and reconnects on socket errors when
    ``auto_reconnect`` is set.

    :param address: anything accepted by ``ServerAddress``; the address
        string is ``host[:port]`` and the port defaults to 80.
    :param ssl_cert_file: CA certificate file used when SSL is enabled.
    """
    log = logging.getLogger(__name__)

    def __init__(self, address='localhost:15555', ssl_cert_file=None):
        self.server_address = ServerAddress(address)
        self.ssl_cert_file = ssl_cert_file

        parts = self.server_address.address.split(':')
        self.host = parts[0]
        self.port = 80
        if len(parts) > 1:
            self.port = int(parts[1])

        self.read_buf = bytearray()
        self.sock = None
        self.pid = os.getpid()
        self.auto_reconnect = True
        self.reconnect_interval = 3  # 3 seconds
        self.heartbeat_interval = 60  # 1 minute

        # Messages received while waiting for a different message id.
        self.result_table = {}

        self.connect_lock = threading.Lock()
        self.read_lock = threading.Lock()
        self.write_lock = threading.Lock()

        self.on_connected = None
        self.on_disconnected = None
        self.on_message = None
        self.manually_closed = False

        self.heartbeat_timer()

    def close(self):
        """Close the socket (if any), stop reconnecting and stop heartbeats."""
        self.manually_closed = True
        self.auto_reconnect = False
        self.on_disconnected = None
        try:
            sock = self.sock
            if sock is not None:  # never connected: nothing to close
                sock.close()
        finally:
            self.read_buf = bytearray()
            self.heartbeator.cancel()

    def invoke(self, msg, timeout=3):
        """Send *msg* and return the reply carrying the same message id."""
        msgid = self.send(msg, timeout)
        return self.recv(msgid, timeout)

    def send(self, msg, timeout=3):
        """Connect if needed, send *msg* and return its message id."""
        self.connect()  # connect if needed
        with self.write_lock:
            return self._send(msg, timeout)

    def heartbeat_timer(self):
        """Schedule the next heartbeat tick, then send one heartbeat now."""
        if self.manually_closed:
            return
        self.heartbeator = threading.Timer(self.heartbeat_interval, self.heartbeat_timer)
        self.heartbeator.start()
        if self.manually_closed:
            # close() ran while this tick was being scheduled.
            self.heartbeator.cancel()
            return
        self.heartbeat()

    def heartbeat(self):
        """Send a ``heartbeat`` command if a socket exists; failures are logged."""
        if self.sock is None:
            return
        msg = Message()
        msg.cmd = 'heartbeat'
        self.log.debug('Heartbeat: %s' % msg)
        try:
            self.send(msg)
        except Exception as e:
            # Best effort: a broken connection is detected by the reader.
            self.log.debug('Heartbeat failed: %s', e)

    def recv(self, msgid=None, timeout=3):
        """Connect if needed and return the next message (or the one with *msgid*)."""
        self.connect()  # connect if needed
        with self.read_lock:
            return self._recv(msgid, timeout)

    def connect(self):
        """Open the connection unless a socket is already installed.

        The new socket is published on ``self.sock`` only after the connect
        succeeded, so a failed attempt does not leave a dead socket behind.
        ``on_connected`` is invoked after a successful connect.
        """
        if self.sock is not None:
            return
        with self.connect_lock:
            if self.sock is not None:  # another thread connected meanwhile
                return
            self.manually_closed = False
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            try:
                if self.server_address.ssl_enabled:
                    sock = self._wrap_ssl(sock)
                self.log.debug('Trying connect to (%s)' % self.server_address)
                sock.connect((self.host, self.port))
            except Exception:
                sock.close()
                raise
            self.read_buf = bytearray()
            self.sock = sock
            self.log.debug('Connected to (%s)' % self.server_address)

        if self.on_connected:
            self.on_connected()

    def _wrap_ssl(self, sock):
        """Wrap *sock* for TLS, verifying the peer against ``ssl_cert_file``."""
        if not hasattr(ssl, 'SSLContext'):
            # Very old interpreters only offer the module-level helper.
            return ssl.wrap_socket(
                sock, ca_certs=self.ssl_cert_file, cert_reqs=ssl.CERT_REQUIRED)
        # ssl.wrap_socket() was removed in Python 3.12; build a context with
        # the same policy: certificate required, no host name matching.
        protocol = getattr(ssl, 'PROTOCOL_TLS_CLIENT', ssl.PROTOCOL_SSLv23)
        context = ssl.SSLContext(protocol)
        context.check_hostname = False
        context.verify_mode = ssl.CERT_REQUIRED
        if self.ssl_cert_file:
            context.load_verify_locations(cafile=self.ssl_cert_file)
        return context.wrap_socket(sock)

    def _send(self, msg, timeout=10):
        """Assign an id if missing, write the encoded message, return the id."""
        msgid = msg.id
        if not msgid:
            msgid = msg.id = str(uuid.uuid4())

        self.sock.sendall(msg_encode(msg))
        return msgid

    def _take_buffered(self, msgid):
        """Pop the next wanted message from ``read_buf``; ``None`` if incomplete.

        Messages whose id differs from *msgid* are parked in ``result_table``.
        """
        while True:
            # Decoding always restarts at offset 0 because consumed bytes
            # are removed from the buffer right after each message.
            msg, end = msg_decode(self.read_buf, 0)
            if msg is None:
                return None
            self.read_buf = self.read_buf[end:]
            if msgid and msg.id != msgid:
                self.result_table[msg.id] = msg
                continue
            return msg

    def _recv(self, msgid=None, timeout=3):
        """Return a parked or newly received message.

        :raises socket.timeout: when nothing arrives within *timeout*.
        :raises socket.error: when the peer closed the connection.
        """
        if not msgid and len(self.result_table) > 0:
            try:
                return self.result_table.popitem()[1]
            except KeyError:
                pass

        if msgid:
            # pop() so delivered replies do not accumulate in the table.
            parked = self.result_table.pop(msgid, None)
            if parked is not None:
                return parked

        self.sock.settimeout(timeout)
        while True:
            # Drain what is already buffered before blocking on the socket.
            msg = self._take_buffered(msgid)
            if msg is not None:
                return msg

            buf = self.sock.recv(1024)
            # An idle-closed remote socket yields an empty read.
            if buf is None or len(buf) == 0:
                raise socket.error('remote server socket status error, possible idle closed')
            self.read_buf += buf

    def _drop_socket(self):
        """Detach and close the current socket, ignoring close errors."""
        sock, self.sock = self.sock, None
        if sock is not None:
            try:
                sock.close()
            except socket.error:
                pass

    def start(self, recv_timeout=60):
        """Start the background receive loop.

        Each received message is passed to ``on_message``.  On a socket error
        ``on_disconnected`` is called and, while ``auto_reconnect`` is set,
        the connection is re-established every ``reconnect_interval`` seconds.
        The loop ends after :meth:`close` or when reconnecting is disabled.
        """
        def serve():
            while True:
                try:
                    msg = self.recv(None, recv_timeout)
                    if msg and self.on_message:
                        self.on_message(msg)
                except socket.timeout:
                    continue
                except socket.error as e:
                    if self.manually_closed:
                        break

                    self.log.warning('%s: %s' % (self.server_address, e))

                    on_disconnected = self.on_disconnected
                    if on_disconnected:
                        on_disconnected()
                    if not self.auto_reconnect:
                        break
                    while self.auto_reconnect:
                        try:
                            self._drop_socket()
                            self.connect()
                            break
                        except socket.error as e:
                            self.log.warning('%s: %s' % (self.server_address, e))
                            time.sleep(self.reconnect_interval)

        self._thread = threading.Thread(target=serve)
        self._thread.start()
 


class MqClient(MessageClient):
    """``MessageClient`` with helpers for the MQ commands in ``Protocol``.

    ``token`` is attached to every outgoing message that has none of its own.
    """

    def __init__(self, address='localhost:15555', ssl_cert_file=None):
        MessageClient.__init__(self, address, ssl_cert_file)
        self.token = None

    def _normalize_msg(self, msg):
        """Return *msg* as a ``Message`` with the client token filled in.

        :raises Exception: if *msg* is neither a ``Message`` nor a dict.
        """
        if isinstance(msg, Message):
            pass
        elif isinstance(msg, dict):
            msg = Message(msg)
        else:
            raise Exception('msg type not support')
        if not msg.token:
            msg.token = self.token
        return msg

    # msg could be dict or Message
    def invoke(self, msg, timeout=3):
        """Send *msg* and wait for the matching reply."""
        msg = self._normalize_msg(msg)
        return super(MqClient, self).invoke(msg, timeout=timeout)

    def invoke_object(self, msg, timeout=3):
        """Invoke and return the reply body, or ``{'error': ...}`` on non-200.

        No exception is raised for error statuses, for the convenience of
        batch operations.
        """
        res = self.invoke(msg, timeout=timeout)
        if res.status != 200:
            error = res.body
            # Text bodies are already decoded by msg_decode; only raw bytes
            # still need decoding.
            if isinstance(error, (bytes, bytearray)):
                error = error.decode(res.encoding or 'utf8')
            return {'error': error}
        return res.body

    def produce(self, msg, timeout=3):
        """Send *msg* with the ``produce`` command and return the reply."""
        msg = self._normalize_msg(msg)
        msg.cmd = Protocol.PRODUCE
        return self.invoke(msg, timeout=timeout)

    def consume(self, msg, timeout=3):
        """Request one message for the topic/group/offset given in *msg*."""
        msg = self._normalize_msg(msg)
        ctrl = Message()
        ctrl.topic = msg.topic
        ctrl.consume_group = msg.consume_group
        ctrl.offset = msg.offset  # case when pulling message by offset
        ctrl.token = msg.token  # keep a token supplied by the caller

        ctrl.cmd = Protocol.CONSUME
        return self.invoke(ctrl, timeout=timeout)

    def query(self, msg, timeout=3):
        """Run the ``query`` command; see :meth:`invoke_object` for the result."""
        msg = self._normalize_msg(msg)
        msg.cmd = Protocol.QUERY
        return self.invoke_object(msg, timeout=timeout)

    def declare(self, msg, timeout=3):
        """Run the ``declare`` command; see :meth:`invoke_object` for the result."""
        msg = self._normalize_msg(msg)
        msg.cmd = Protocol.DECLARE
        return self.invoke_object(msg, timeout=timeout)

    def remove(self, msg, timeout=3):
        """Run the ``remove`` command; see :meth:`invoke_object` for the result."""
        msg = self._normalize_msg(msg)
        msg.cmd = Protocol.REMOVE
        return self.invoke_object(msg, timeout=timeout)

    def empty(self, msg, timeout=3):
        """Run the ``empty`` command; see :meth:`invoke_object` for the result."""
        msg = self._normalize_msg(msg)
        msg.cmd = Protocol.EMPTY
        return self.invoke_object(msg, timeout=timeout)

    def route(self, msg, timeout=3):
        """Send *msg* with the ``route`` command without waiting for a reply.

        A ``status`` on the message is moved to ``origin_status`` so that the
        message is encoded as a request.
        """
        msg = self._normalize_msg(msg)
        msg.cmd = Protocol.ROUTE
        if msg.status:
            msg.origin_status = msg.status
            msg.status = None

        self.send(msg, timeout)

    def _build_ack_msg(self, msg):
        """Build the ``ack`` message for the topic/group/offset of *msg*."""
        ack_msg = Message()
        ack_msg.cmd = Protocol.ACK
        ack_msg.topic = msg.topic
        ack_msg.consume_group = msg.consume_group
        ack_msg.offset = msg.offset
        ack_msg.token = self.token
        return ack_msg

    def ack(self, msg, timeout=3):
        """Acknowledge *msg* and wait for the reply."""
        ack_msg = self._build_ack_msg(msg)
        return self.invoke(ack_msg, timeout=timeout)

    def ack_async(self, msg, timeout=3):
        """Acknowledge *msg* without waiting; returns the ack's message id."""
        ack_msg = self._build_ack_msg(msg)
        ack_msg.ack = False  # ack's ack
        return self.send(ack_msg, timeout=timeout)


class MqClientPool:
    """Bounded LIFO pool of ``MqClient`` objects for one server address.

    The queue is pre-filled with ``maxsize`` ``None`` placeholders; a client
    is created the first time a placeholder is borrowed.  When the process id
    changes, the pool closes its clients and starts over.

    :param server_address: anything accepted by ``ServerAddress``.
    :param ssl_cert_file: passed through to every created client.
    :param maxsize: maximum number of clients handed out concurrently.
    :param timeout: seconds :meth:`borrow_client` waits for a free slot.
    """
    log = logging.getLogger(__name__)

    def __init__(self, server_address='localhost:15555', ssl_cert_file=None, maxsize=50, timeout=3):
        self.server_address = ServerAddress(server_address)

        self.maxsize = maxsize
        self.timeout = timeout
        self.ssl_cert_file = ssl_cert_file
        self.reset()

    def make_client(self):
        """Create a new client for this pool's server."""
        return MqClient(self.server_address, self.ssl_cert_file)

    def _check_pid(self):
        """Rebuild the pool if the current process differs from its creator."""
        if self.pid != os.getpid():
            with self._check_lock:
                if self.pid == os.getpid():
                    return
                self.log.debug('new process, pid changed')
                self.close()
                self.reset()

    def reset(self):
        """(Re)initialise the pool state for the current process."""
        self.pid = os.getpid()
        self._check_lock = threading.Lock()

        self.client_pool = Queue.LifoQueue(self.maxsize)
        # Fill every slot with a placeholder; clients are created lazily.
        while True:
            try:
                self.client_pool.put_nowait(None)
            except Queue.Full:
                break
        self.clients = []

    def borrow_client(self):
        """Take a client from the pool, creating one for an unused slot.

        :raises Exception: if no slot frees up within ``timeout`` seconds.
        """
        self._check_pid()
        try:
            client = self.client_pool.get(block=True, timeout=self.timeout)
        except Queue.Empty:
            raise Exception('No client available')
        if client is None:
            client = self.make_client()
            self.clients.append(client)
        return client

    def return_client(self, client):
        """Give back one client, or a list/tuple of clients.

        Clients created in another process are dropped instead of re-queued.
        """
        self._check_pid()
        # Normalise first: a list has no ``pid`` attribute to inspect.
        clients = client if isinstance(client, (tuple, list)) else [client]
        for c in clients:
            if c is None or c.pid != self.pid:
                continue
            try:
                self.client_pool.put_nowait(c)
            except Queue.Full:
                pass

    def close(self):
        """Close every client this pool created; failures are logged."""
        clients, self.clients = self.clients, []
        for client in clients:
            try:
                client.close()
            except Exception as e:
                # Keep going so one bad client cannot leak the others.
                self.log.warning('%s: error closing client: %s' % (self.server_address, e))




class BrokerRouteTable:
    """Server and topic tables assembled from the reports of several trackers."""

    class Vote:
        """Latest server list reported by one tracker, with its version."""

        def __init__(self, version):
            self.version = version
            self.server_list = []

    def __init__(self):
        self.topic_table = {}     # { TopicName => [TopicInfo] }
        self.server_table = {}    # { ServerAddress => ServerInfo }
        self.votes_table = {}     # { TrackerAddress => Vote }
        # Fraction of the voted trackers that must list a server to keep it.
        self.vote_factor = 0.5

        self.voted_trackers = {}

    def update_tracker(self, tracker_info):
        """Apply one tracker report and return the server addresses removed.

        A report whose ``infoVersion`` is not newer than the stored vote of
        the same tracker is ignored and yields an empty list.
        """
        # 1) Record the tracker's vote.
        tracker = ServerAddress(tracker_info['serverAddress'])
        self.voted_trackers[tracker] = True
        current_vote = self.votes_table.get(tracker)
        reported = tracker_info['serverTable']
        version = tracker_info['infoVersion']
        if current_vote and current_vote.version >= version:
            return []

        listed = [ServerAddress(info['serverAddress']) for info in reported.values()]

        if not current_vote:
            current_vote = BrokerRouteTable.Vote(version)
            self.votes_table[tracker] = current_vote

        current_vote.version = version
        current_vote.server_list = listed

        # 2) Merge the reported servers, keeping the newest info per server.
        merged = dict(self.server_table)
        for info in reported.values():
            address = ServerAddress(info['serverAddress'])
            known = merged.get(address)
            if not (known and known['infoVersion'] >= info['infoVersion']):
                merged[address] = info
        self.server_table = merged

        # 3) Drop servers without enough votes and rebuild the topic table.
        return self._purge()

    def remove_tracker(self, tracker_address):
        """Discard a tracker's vote and return the server addresses removed."""
        self.votes_table.pop(ServerAddress(tracker_address), None)
        return self._purge()

    def _purge(self):
        """Remove under-voted servers, rebuild ``topic_table``, return removals."""
        required = len(self.voted_trackers) * self.vote_factor
        votes = list(self.votes_table.values())

        expired = []
        for address in self.server_table:
            supporters = sum(1 for vote in votes if address in vote.server_list)
            if supporters < required:
                expired.append(address)

        for address in expired:
            self.server_table.pop(address, None)

        rebuilt = {}
        for info in self.server_table.values():
            for topic_name, topic in info['topicTable'].items():
                rebuilt.setdefault(topic_name, []).append(topic)

        self.topic_table = rebuilt
        return expired


class _CountDownLatch(object):
    """Counter that releases waiting threads once it has reached zero.

    :param count: number of :meth:`count_down` calls needed to open the latch.
    """

    def __init__(self, count=1):
        self.count = count
        self.lock = threading.Condition()
        self.is_set = False

    def count_down(self):
        """Decrement the counter; wake all waiters when it reaches zero.

        Calls made after the latch has opened are ignored.
        """
        # The whole check-and-update runs under the condition's lock, and
        # ``with`` guarantees the lock is released on any exit path.
        with self.lock:
            if self.is_set:
                return
            self.count -= 1
            if self.count <= 0:
                self.is_set = True
                self.lock.notify_all()

    def wait(self, timeout=3):
        """Block until the latch opens or *timeout* seconds have passed."""
        with self.lock:
            if self.count > 0:
                self.lock.wait(timeout)

class Broker:
    """Keeps one ``MqClientPool`` per known server.

    Servers are either discovered through tracker subscriptions
    (:meth:`add_tracker`) or registered directly (:meth:`add_server`, which
    switches the broker to direct mode).

    :param tracker_list: optional ``;``-separated tracker addresses; blank
        entries are ignored.
    """
    log = logging.getLogger(__name__)

    def __init__(self, tracker_list=None):
        self.pool_table = {}
        self.route_table = BrokerRouteTable()
        self.ssl_cert_file_table = {}
        self.tracker_subscribers = {}
        self.ready_timeout = 3  # 3 seconds
        self.ready_event = _CountDownLatch(1)

        self.on_server_join = None
        self.on_server_leave = None
        self.on_server_updated = None
        self.direct_mode = False

        if tracker_list:
            subscribed = False
            for tracker in tracker_list.split(';'):
                tracker = tracker.strip()
                if not tracker:  # e.g. a trailing ';'
                    continue
                self.add_tracker(tracker, None, False)
                subscribed = True
            if subscribed:
                # Wait once for the first tracker report.
                self.ready_event.wait(self.ready_timeout)

    def add_tracker(self, tracker_address, cert_file=None, wait_ready=True):
        """Subscribe to a tracker; a tracker already subscribed is ignored.

        :param tracker_address: anything accepted by ``ServerAddress``.
        :param cert_file: certificate file for this tracker's address.
        :param wait_ready: wait up to ``ready_timeout`` seconds for the first
            tracker report before returning.
        """
        tracker_address = ServerAddress(tracker_address)
        if tracker_address in self.tracker_subscribers:
            return
        if cert_file:
            self.ssl_cert_file_table[tracker_address.address] = cert_file

        client = MqClient(tracker_address, cert_file)
        self.tracker_subscribers[tracker_address] = client

        def tracker_connected():
            # (Re)subscribe every time the connection is established.
            msg = Message()
            msg.cmd = Protocol.TRACK_SUB
            client.send(msg)

        def tracker_disconnected():
            to_remove = self.route_table.remove_tracker(client.server_address)
            for server_address in to_remove:
                self._remove_server(server_address)

        def on_message(msg):
            if msg.status != 200:
                self.log.error(msg)
                return
            tracker_info = msg.body
            if not isinstance(tracker_info, dict):
                # Body could not be parsed into a tracker report.
                self.log.error(msg)
                return
            client.server_address = ServerAddress(tracker_info['serverAddress'])
            to_remove = self.route_table.update_tracker(tracker_info)
            for server_info in list(self.route_table.server_table.values()):
                self._add_server(server_info)

            for server_address in to_remove:
                self._remove_server(server_address)

            self.ready_event.count_down()

        client.on_connected = tracker_connected
        client.on_disconnected = tracker_disconnected
        client.on_message = on_message
        client.start()

        if wait_ready:
            self.ready_event.wait(self.ready_timeout)

    def add_server(self, address):
        """Register a server directly and switch to direct mode."""
        self.direct_mode = True
        server_address = address
        if not isinstance(server_address, ServerAddress):
            server_address = ServerAddress(server_address)
        self._add_server0(server_address)

    def _add_server0(self, server_address):
        """Create the pool for *server_address* unless one already exists."""
        if server_address in self.pool_table:
            return

        self.log.debug('%s joined' % server_address)
        pool = MqClientPool(server_address, self.ssl_cert_file_table.get(server_address.address))
        self.pool_table[server_address] = pool

        if self.on_server_join:
            self.on_server_join(pool)

    def _add_server(self, server_info):
        """Add the server described by a tracker-supplied info dict."""
        server_address = ServerAddress(server_info['serverAddress'])
        self._add_server0(server_address)

    def _remove_server(self, server_address):
        """Drop and close the pool of *server_address*, if there is one."""
        self.log.debug('%s left' % server_address)
        pool = self.pool_table.pop(server_address, None)
        if pool:
            if self.on_server_leave:
                self.on_server_leave(server_address)
            pool.close()

    def select(self, selector, msg):
        """Return the pools to use for *msg*.

        In direct mode the first registered pool is returned (or an empty
        list).  Otherwise ``selector(route_table, msg)`` supplies the server
        keys, and pools are returned for the keys that are known.

        :raises Exception: if the selector yields no keys.
        """
        if self.direct_mode:
            values = list(self.pool_table.values())
            if len(values) == 0:
                return []
            return [values[0]]  # change to random selection?

        keys = selector(self.route_table, msg)
        if not keys:
            raise Exception("Missing MqServer for: %s" % msg)
        return [self.pool_table[key] for key in keys if key in self.pool_table]

    def close(self):
        """Close every tracker subscription and every server pool."""
        for client in list(self.tracker_subscribers.values()):
            client.close()
        self.tracker_subscribers.clear()

        for pool in list(self.pool_table.values()):
            pool.close()
        self.pool_table.clear()


class MqAdmin:
    """Runs topic administration commands against servers chosen via a broker."""

    def __init__(self, broker):
        self.broker = broker

        def admin_selector(route_table, msg):
            # Administrative commands go to every server in the route table.
            return list(route_table.server_table.keys())

        self.admin_selector = admin_selector
        self.token = None

    def _normalize_msg(self, msg):
        """Coerce *msg* into a ``Message`` carrying a token.

        Accepts a ``Message``, a dict, a ``MessageCtrl`` or a topic-name
        string; any other type raises ``Exception``.
        """
        if isinstance(msg, Message):
            normalized = msg
        elif isinstance(msg, dict):
            normalized = Message(msg)
        elif isinstance(msg, MessageCtrl):
            normalized = Message()
            msg.to_message(normalized)
        elif isinstance(msg, str):
            normalized = Message()
            normalized.topic = msg
        else:
            raise Exception('msg type not support')

        if not normalized.token:
            normalized.token = self.token
        return normalized

    def invoke_object(self, msg, timeout=3, selector=None):
        """Invoke *msg* on each selected pool; return the list of results."""
        chosen = selector or self.admin_selector
        results = []
        for pool in self.broker.select(chosen, msg):
            client = None
            try:
                client = pool.borrow_client()
                results.append(client.invoke_object(msg, timeout=timeout))
            finally:
                # Hand the client back even when the invocation failed.
                if client:
                    pool.return_client(client)
        return results

    def _run_command(self, command, msg, timeout, selector):
        """Normalise *msg*, stamp it with *command* and invoke it."""
        request = self._normalize_msg(msg)
        request.cmd = command
        return self.invoke_object(request, timeout=timeout, selector=selector)

    def declare(self, msg, timeout=3, selector=None):
        """Issue the ``declare`` command."""
        return self._run_command(Protocol.DECLARE, msg, timeout, selector)

    def query(self, msg, timeout=3, selector=None):
        """Issue the ``query`` command."""
        return self._run_command(Protocol.QUERY, msg, timeout, selector)

    def remove(self, msg, timeout=3, selector=None):
        """Issue the ``remove`` command."""
        return self._run_command(Protocol.REMOVE, msg, timeout, selector)

    def empty(self, msg, timeout=3, selector=None):
        """Issue the ``empty`` command."""
        return self._run_command(Protocol.EMPTY, msg, timeout, selector)


class Producer(MqAdmin):
    '''
    Admin client that can also publish messages.

    The default ``produce_selector`` routes a message to a single server:
    among the servers listed for ``msg.topic`` in the route table, the one
    reporting the highest ``consumerCount`` (the first one wins on ties).
    '''

    def __init__(self, broker):
        MqAdmin.__init__(self, broker)
        # NOTE(review): this reseeds the process-wide RNG each time a Producer
        # is built; kept unchanged because other code may rely on it.
        random.seed(int(time.time()))

        def produce_selector(route_table, msg):
            '''Return a one-element list with the best server for msg.topic, or [].'''
            server_table = route_table.server_table
            topic_table = route_table.topic_table
            if len(server_table) < 1:
                return []

            if msg.topic not in topic_table:
                return []
            topic_server_list = topic_table[msg.topic]
            if not topic_server_list:
                # Topic is known but no server currently hosts it; the
                # original code raised IndexError on this case.
                return []
            # max() returns the first maximal item, matching the original scan.
            target = max(topic_server_list, key=lambda entry: entry['consumerCount'])
            return [ServerAddress(target['serverAddress'])]

        self.produce_selector = produce_selector

    def publish(self, msg, timeout=3, selector=None):
        '''
        Publish ``msg`` to the servers chosen by ``selector``.

        ``msg`` is normalized and tagged with ``Protocol.PRODUCE``. ``selector``
        falls back to ``self.produce_selector`` when falsy. Every borrowed
        client is returned to its pool even if ``invoke`` raises.

        Returns the single reply when exactly one server replied, otherwise
        the list of replies (possibly empty).
        '''
        msg = self._normalize_msg(msg)
        msg.cmd = Protocol.PRODUCE

        pools = self.broker.select(selector or self.produce_selector, msg)
        res = []
        for pool in pools:
            client = None
            try:
                client = pool.borrow_client()
                res.append(client.invoke(msg, timeout=timeout))
            finally:
                if client:
                    pool.return_client(client)
        if len(res) == 1:
            return res[0]
        return res


class ConsumeThread:
    '''
    Group of consumer loops running against one server pool.

    ``start`` creates ``connection_count`` clients from ``pool`` and runs one
    thread per client; each thread declares the consume group, then keeps
    taking messages and passing them to ``message_handler(msg, client)``.
    '''
    log = logging.getLogger(__name__)

    def __init__(self, pool, msg_ctrl, message_handler=None, connection_count=1, timeout=60):
        self.pool = pool
        self.message_handler = message_handler
        self.connection_count = connection_count
        # Passed as ``timeout`` to client.consume / client.declare.
        self.consume_timeout = timeout
        # Private copy of the control message, so group updates stay local.
        self.msg_ctrl = Message(msg_ctrl)

        self.client_threads = []

    def take(self, client):
        '''
        Consume one message through ``client``.

        On a 404 reply the consume group is (re)declared and the take is
        retried. On a 200 reply the ``origin_*`` headers are folded back into
        ``id`` / ``url`` / ``method`` and the message is returned. Any other
        status raises ``Exception`` carrying the reply body.
        '''
        request = Message(self.msg_ctrl)
        reply = client.consume(request, timeout=self.consume_timeout)

        if reply.status == 404:
            declared = client.declare(request, timeout=self.consume_timeout)
            if 'error' in declared and declared['error']:
                raise Exception(declared['error'])
            self.msg_ctrl.consume_group = declared['groupName']
            return self.take(client)

        if reply.status != 200:
            raise Exception(reply.body)

        # Restore the original message identity.
        reply.id = reply.origin_id
        del reply.origin_id

        if reply.origin_url:
            reply.url = reply.origin_url
            reply.status = None
            del reply.origin_url

        if reply.origin_method:
            reply.method = reply.origin_method or 'GET'
            del reply.origin_method

        return reply

    def _run_client(self, client):
        '''
        Thread body: declare the consume group, then loop taking messages.

        Socket timeouts are ignored and the loop continues; any other
        exception is logged and ends the loop.
        '''
        if not self.message_handler:
            raise Exception("missing message_handler")

        declare_req = Message(self.msg_ctrl)
        declared = client.declare(declare_req, timeout=self.consume_timeout)
        if 'error' in declared and declared['error']:
            raise Exception(declared['error'])
        self.msg_ctrl.consume_group = declared['groupName']

        while True:
            try:
                taken = self.take(client)
                if taken:
                    self.message_handler(taken, client)
            except socket.timeout:
                continue
            except Exception as e:
                self.log.error(e)
                break

    def start(self):
        '''Create one client and one running thread per configured connection.'''
        for _ in range(self.connection_count):
            client = self.pool.make_client()
            worker = threading.Thread(target=self._run_client, args=(client,))
            # Keep the client on the thread object so close() can reach it.
            worker.client = client
            worker.start()
            self.client_threads.append(worker)

    def close(self):
        '''Close the client attached to every started thread.'''
        for worker in self.client_threads:
            worker.client.close()



class Consumer(MqAdmin):
    '''
    Consumes a topic from every server known to the broker.

    One ``ConsumeThread`` is kept per server address in
    ``consume_thread_groups``; ``start`` also hooks the broker's join/leave
    callbacks so that the set follows the servers coming and going.
    '''
    log = logging.getLogger(__name__)

    def __init__(self, broker, msg_ctrl):
        MqAdmin.__init__(self, broker)

        ctrl = self._normalize_msg(msg_ctrl)
        self.msg_ctrl = ctrl
        # Default the consume group to the topic unless a group is given or
        # an auto-generated group name was requested.
        if not (ctrl.consume_group or ctrl.group_name_auto):
            ctrl.consume_group = ctrl.topic

        def all_servers(route_table, msg):
            return list(route_table.server_table.keys())

        self.consume_selector = all_servers
        self.connection_count = 1
        self.consume_timeout = 60*1000 #milliseconds
        self.message_handler = None

        self.consume_thread_groups = {}

    def start_consume_thread(self, pool):
        '''Start a ``ConsumeThread`` for ``pool`` unless its address already has one.'''
        address = pool.server_address
        if address in self.consume_thread_groups:
            return

        worker = ConsumeThread(
            pool,
            self.msg_ctrl,
            message_handler=self.message_handler,
            connection_count=self.connection_count,
            # consume_timeout is in milliseconds; ConsumeThread gets it divided by 1000.
            timeout=self.consume_timeout/1000.0)

        self.consume_thread_groups[address] = worker
        worker.start()

    def start(self):
        '''Install the broker callbacks and start consuming from the selected servers.'''
        def on_server_join(pool):
            self.start_consume_thread(pool)

        def on_server_leave(server_address):
            worker = self.consume_thread_groups.pop(server_address, None)
            if worker:
                worker.close()

        self.broker.on_server_join = on_server_join
        self.broker.on_server_leave = on_server_leave

        for pool in self.broker.select(self.consume_selector, self.msg_ctrl):
            self.start_consume_thread(pool)


class RpcInvoker:
    '''
    Client-side RPC proxy: publishes a request to a topic and returns the
    decoded reply body.

    Attribute access yields a child invoker bound to that method name, so
    ``rpc.plus(1, 2)`` is the same as ``rpc.invoke(method='plus', params=(1, 2))``.
    Child invokers share the parent's producer.
    '''
    log = logging.getLogger(__name__)

    def __init__(self, broker=None, topic=None, module=None, method=None, timeout=3, selector=None, token=None, producer=None):
        self.producer = producer or Producer(broker)
        # Only overwrite the producer's token when one is actually supplied.
        # __getattr__ builds child invokers around the same producer without
        # passing a token; assigning unconditionally would reset the shared
        # producer's token to None on the first attribute-style call.
        if token is not None:
            self.producer.token = token

        self.topic = topic
        self.timeout = timeout
        self.server_selector = selector

        self.method = method
        self.module = module

    def __getattr__(self, name):
        # Dunder lookups (copy/pickle/introspection probes) must fail normally
        # instead of being turned into remote method invokers.
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)
        return RpcInvoker(method=name, topic=self.topic, module=self.module,
                          producer=self.producer, timeout=self.timeout, selector=self.server_selector)

    def invoke(self, method=None, params=None, module='', topic=None, selector=None):
        '''
        Perform one remote call and return the reply body.

        ``method``, ``module``, ``topic`` and ``selector`` fall back to the
        values stored on this invoker. Raises ``Exception`` when no topic is
        known, when no server answered, or when the reply status is not 200
        (the exception then carries the reply body).
        '''
        topic = topic or self.topic
        if not topic:
            raise Exception("missing topic")

        selector = selector or self.server_selector
        req = {
            'method': method or self.method,
            'params': params,
            'module': module or self.module,
        }

        msg = Message()
        msg.topic = topic
        msg.ack = False  # RPC ack must set to False to wait return
        msg.body = req

        msg_res = self.producer.publish(msg, timeout=self.timeout, selector=selector)
        if isinstance(msg_res, list):
            # publish() returns a list unless exactly one server replied.
            if not msg_res:
                raise Exception("no server available for topic: %s" % topic)
            msg_res = msg_res[0]

        body = msg_res.body
        if isinstance(body, bytearray):
            msg_res.body = json.loads(body.decode(msg_res.encoding or 'utf8'))
        elif isinstance(body, (str, bytes)) and msg_res['content-type'] == 'application/json':
            if isinstance(body, bytes):
                # str(bytes) on Python 3 yields "b'...'", which is not JSON;
                # decode explicitly instead.
                body = body.decode(msg_res.encoding or 'utf8')
            msg_res.body = json.loads(body)

        if msg_res.status != 200:
            raise Exception(msg_res.body)

        return msg_res.body

    def __call__(self, *args, **kv_args):
        '''Invoke the bound method with positional ``args`` as the RPC params.'''
        return self.invoke(params=args, **kv_args)


def Remote(_id=None):
    '''
    Decorator factory that tags a function with a ``remote_id`` attribute.

    The attribute is ``_id`` when it is truthy, otherwise the function's own
    ``__name__``. The decorated function itself is returned unchanged.
    '''
    def tag(fn):
        remote_name = _id if _id else fn.__name__
        fn.remote_id = remote_name
        return fn
    return tag

# HTML page listing the registered RPC methods. The five %s placeholders are
# filled by _build_rpc_info, in order: topic (title), style block, topic (URL
# pattern), topic (service home link), table rows.
RpcInfoTemplate = '''
<html><head>
<meta http-equiv="Content-type" content="text/html; charset=utf-8">
<title>%s Python</title>     
%s
</head>
<body>    
<div>  
<div class="url">
    <span>URL=/%s/[module]/[method]/[param1]/[param2]/...</span>
    <a href="/">zbus</a>
    <a href="/%s">service home</a>
</div>
<table class="table"> 
<thead>
<tr class="table-info"> 
    <th class="returnType">Return Type</th> 
    <th class="methodParams">Method and Params</th> 
    <th class="modules">Module</th>
</tr> 
</thead> 
<tbody> 
%s 
</tbody> 
</table> </div> </body></html>
'''

# Inline CSS for the RPC info page; _build_rpc_info substitutes it into the
# <head> placeholder of RpcInfoTemplate.
RpcStyleTemplate = '''
<style type="text/css"> 
body {
    font-family: -apple-system,system-ui,BlinkMacSystemFont,"Segoe UI",Roboto,"Helvetica Neue",Arial,sans-serif;
    font-size: 1rem;
    font-weight: 400;
    line-height: 1.5;
    color: #292b2c;
    background-color: #fff;
    margin: 0px;
    padding: 0px;
}
table {  background-color: transparent;  display: table; border-collapse: separate;  border-color: grey; } 
.table { width: 100%; max-width: 100%;  margin-bottom: 1rem; }  
.table th {  height: 30px; }
.table td, .table th {    border-bottom: 1px solid #eceeef;   text-align: left; } 
th.returnType {  width: 20%; }
th.methodParams {   width: 40%; }
th.modules {  width: 40%; }
thead { display: table-header-group; vertical-align: middle; border-color: inherit;}
tbody { display: table-row-group; vertical-align: middle; border-color: inherit;}  
tr { display: table-row;  vertical-align: inherit; border-color: inherit; }
.table-info, .table-info>td, .table-info>th { background-color: #dff0d8; }
.url { margin: 4px 0; } 
</style>
'''

# One table row per RPC method. The six %s placeholders are filled by
# _build_module_info, in order: method link, method name, parameter list,
# topic, module (link target), module (link text).
RpcModuleTemplate = '''
<tr>
    <td class="returnType"></td>
    <td class="methodParams">
        <code><strong><a href="%s">%s</a></strong>(%s)</code>     
    </td> 
    <td class="modules"> <a href="/%s/%s">%s</a>  </td>
</tr> 
'''

def _build_module_info(rpc, module):
    '''
    Render the HTML table rows for every method registered under ``module``.

    Reads ``rpc.modules[module]`` (method name -> info dict with 'method' and
    'params') and ``rpc.topic_context``; returns the rows concatenated in
    iteration order, or '' when the module has no methods.
    '''
    rows = []
    registered = rpc.modules[module]
    for name in registered:
        info = registered[name]
        href = '/%s/%s/%s'%(rpc.topic_context, module, info['method'])
        signature = ', '.join(info['params'])
        rows.append(RpcModuleTemplate%(href, info['method'], signature, rpc.topic_context, module, module))
    return ''.join(rows)

def _build_rpc_info(rpc, module=None):
    '''
    Build the full HTML info page for ``rpc``.

    With ``module`` left as None the rows of every registered module are
    included; otherwise only the rows of that module.
    '''
    if module is not None:
        rows = _build_module_info(rpc, module)
    else:
        rows = ''.join(_build_module_info(rpc, name) for name in rpc.modules)
    return RpcInfoTemplate%(rpc.topic_context, RpcStyleTemplate, rpc.topic_context, rpc.topic_context, rows)

class RpcRootInfo:
    '''Serves the HTML page that lists the methods of all modules of a processor.'''

    def __init__(self, rpc_processor):
        self.rpc_processor = rpc_processor

    def index(self):
        '''Return a 200 ``Message`` whose body is the info page for every module.'''
        html = _build_rpc_info(self.rpc_processor)

        page = Message()
        page.status = 200
        page['content-type'] = 'text/html; charset=utf-8'
        page.body = html
        return page

class RpcModuleInfo:
    '''Serves the HTML page that lists the methods of one module of a processor.'''

    def __init__(self, module, rpc_processor):
        self.rpc_processor = rpc_processor
        self.module = module

    def index(self):
        '''Return a 200 ``Message`` whose body is the info page for ``self.module``.'''
        html = _build_rpc_info(self.rpc_processor, self.module)

        page = Message()
        page.status = 200
        page['content-type'] = 'text/html; charset=utf-8'
        page.body = html
        return page

class RpcProcessor:
    '''
    Server-side RPC dispatcher.

    Services are registered per module with ``add_module``; an incoming
    request message names a module/method/params triple, the matching bound
    method is called and the outcome is routed back through the client.
    Instances are callable as ``processor(msg, client)``.
    '''
    log = logging.getLogger(__name__)

    def __init__(self):
        self.modules = {}   # module name -> {method name -> {'method': ..., 'params': [...]}}
        self.methods = {}   # 'module:method' -> (bound method, argument names including self)
        self.topic_context = ''

        self.add_module('index', RpcRootInfo(self))

    @staticmethod
    def _argspec(func):
        '''Return the argument spec of ``func`` on both Python 2 and Python 3.'''
        # inspect.getargspec was removed in Python 3.11, while getfullargspec
        # does not exist on Python 2; both expose ``.args``.
        getspec = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
        return getspec(func)

    def add_module(self, module, service):
        '''
        Register every method of ``service`` under ``module``.

        ``service`` may be an instance or a class (instantiated without
        arguments). Methods whose name starts with '__' are skipped; a
        ``remote_id`` attribute on a method overrides its exposed name. A
        duplicate 'module:method' key is logged and overwritten. If the module
        ends up without an ``index`` method, a ``RpcModuleInfo`` page is
        registered for it.
        '''
        if inspect.isclass(service):
            service = service()

        module_methods = self.modules.setdefault(module, {})

        for name, bound in inspect.getmembers(service, predicate=inspect.ismethod):
            method_name = str(name)
            if method_name.startswith('__'):
                continue

            if hasattr(bound, 'remote_id'):
                method_name = getattr(bound, 'remote_id')

            key = '%s:%s' % (module, method_name)
            if key in self.methods:
                self.log.warning('%s duplicated' % key)

            arg_names = self._argspec(bound).args
            self.methods[key] = (bound, arg_names)

            # Drop the leading 'self' for the human-readable listing.
            module_methods[method_name] = {'method': method_name, 'params': arg_names[1:]}

        module_index_key = '%s:index'%module
        if module_index_key not in self.methods:
            module_info = RpcModuleInfo(module, self)
            self.methods[module_index_key] = (module_info.index, ['self'])

    def _get_value(self, req, name, default=None):
        '''Return ``req[name]``, or ``default`` when the key is missing or its value is falsy.'''
        if name not in req:
            return default
        return req[name] or default

    def _invoke(self, msg, charset):
        '''
        Decode the request carried by ``msg``, call the target method and
        return its result; raises on any decoding or dispatch failure.
        '''
        body = msg.body
        if isinstance(body, (bytes, bytearray, str)):
            if isinstance(body, (bytes, bytearray)):
                # json.loads lost its ``encoding`` argument in Python 3.9, so
                # decode explicitly with the message charset.
                body = body.decode(charset)
            msg.body = json.loads(body)

        req = msg.body
        method = self._get_value(req, 'method', 'index')
        module = self._get_value(req, 'module', 'index')
        params = self._get_value(req, 'params', [])

        key = '%s:%s' % (module, method)
        if key not in self.methods:
            raise Exception('%s method not found' % key)

        target, arg_names = self.methods[key]
        # arg_names includes 'self'; when the method declares more arguments
        # than were supplied, pass the request message as the trailing one.
        # A copy is used so the request body is not mutated.
        if len(arg_names) - 1 > len(params):
            params = list(params) + [msg]
        return target(*params)

    def handle_request(self, msg, client):
        '''
        Process one RPC request message and route the response via ``client``.

        Any failure while decoding or executing the request produces a
        response with status 600 and the error text as body. A ``Message``
        returned by the method is sent as is (status defaults to 200);
        any other result is sent as the body of a 200 JSON response.
        Failures of ``client.route`` are logged, not raised.
        '''
        charset = msg.encoding or 'utf8'
        msg_res = Message()
        msg_res.recver = msg.sender
        msg_res.id = msg.id

        error = None
        result = None
        try:
            result = self._invoke(msg, charset)
        except Exception as e:
            error = e

        if error is not None:
            msg_res.status = 600
            msg_res.body = str(error)
        elif isinstance(result, Message):
            if not result.status:
                result.status = 200
            result.encoding = charset
            result.recver = msg_res.recver
            result.id = msg_res.id
            msg_res = result
        else:
            msg_res.status = 200
            msg_res.body = result #json.dumps(result,ensure_ascii=False).encode(charset)
            msg_res['content-type'] = 'application/json'
            msg_res['encoding'] = charset

        try:
            client.route(msg_res)
        except Exception as e:
            # Best effort: the consume loop must survive a failed reply.
            self.log.error('failed to route rpc response: %s', e)

    def __call__(self, *args, **kv_args):
        return self.handle_request(*args, **kv_args)


class ServiceBootstrap:
    '''
    Convenience wrapper that exposes registered modules as an RPC service.

    Configure with the setter-style methods, register services with
    ``add_module`` and call ``start`` to create the broker and the consumer
    that feeds requests to the internal ``RpcProcessor``.
    '''

    def __init__(self):
        self._processor = RpcProcessor()
        self._service_address = None
        self._connection_count = 1
        self.consume_headers = Message()
        self.consume_headers.topic_mask = Protocol.MASK_RPC | Protocol.MASK_MEMORY
        # Created by start(); initialised here so close() is safe before start().
        self.broker = None
        self.consumer = None

    def service_name(self, value):
        '''Set the topic to consume; also used as the processor's topic context.'''
        self.consume_headers.topic = value
        self._processor.topic_context = value

    def service_mask(self, value):
        '''Set the topic mask; ``Protocol.MASK_RPC`` is always OR-ed in.'''
        self.consume_headers.topic_mask = Protocol.MASK_RPC | value

    def service_token(self, value):
        '''Set the token carried by the consume headers.'''
        self.consume_headers.token = value

    def service_address(self, value):
        '''Set the address handed to ``Broker`` in ``start``.'''
        self._service_address = value

    def connection_count(self, value):
        '''Set the consumer's connection count.'''
        self._connection_count = value

    def start(self):
        '''Create the broker and consumer, wire in the processor and start consuming.'''
        self.broker = Broker(self._service_address)
        self.consumer = Consumer(self.broker, self.consume_headers)
        self.consumer.connection_count = self._connection_count
        self.consumer.message_handler = self._processor
        self.consumer.start()

    def add_module(self, module, service):
        '''Register ``service`` under ``module`` on the internal processor.'''
        self._processor.add_module(module, service)

    def close(self):
        '''Close the broker if one was created; safe to call before ``start``.'''
        # getattr keeps this safe even on instances whose __init__ did not run.
        consumer = getattr(self, 'consumer', None)
        if consumer:
            pass #TODO
        broker = getattr(self, 'broker', None)
        if broker:
            broker.close()




class ClientBootstrap:
    '''
    Convenience wrapper that builds ``RpcInvoker`` objects for a service.

    Settings are stored through the setter-style methods; the broker is
    created lazily on the first ``invoker`` call and reused afterwards.
    '''

    def __init__(self):
        self._broker = None
        self._service_address = None
        self._service_name = None
        self._service_token = None
        self._module = None
        self._direct_mode = False
        self._timeout = 10 #10 seconds

    def service_address(self, value):
        '''Set the server address used when the broker is created.'''
        self._service_address = value

    def service_token(self, value):
        '''Set the default token for invokers.'''
        self._service_token = value

    def service_name(self, value):
        '''Set the default service name (used as the invoker topic).'''
        self._service_name = value

    def module(self, value):
        '''Set the default module for invokers.'''
        self._module = value

    def timeout(self, value):
        '''
        Set the invoker timeout, in seconds.
        '''
        self._timeout = value

    def ha(self, value):
        '''
        Choose how the broker is built: truthy -> ``Broker(address)``,
        falsy -> ``Broker()`` followed by ``add_server(address)``.
        '''
        self._direct_mode = value

    def invoker(self, service_name=None, service_token=None, module=None):
        '''
        Return an ``RpcInvoker`` for the service.

        Arguments fall back to the values stored on this bootstrap. Raises
        ``Exception`` when no service name is available.
        '''
        topic = service_name or self._service_name
        if not topic:
            raise Exception("Missing service name")

        token = service_token or self._service_token

        # Create the shared broker on first use only.
        if not self._broker:
            if not self._direct_mode:
                self._broker = Broker()
                self._broker.add_server(self._service_address)
            else:
                self._broker = Broker(self._service_address)

        rpc = RpcInvoker(broker=self._broker,
                         topic=topic,
                         token=token,
                         timeout=self._timeout)
        chosen_module = module or self._module
        if chosen_module:
            rpc.module = chosen_module
        return rpc

    def close(self):
        '''Close the broker if one has been created.'''
        if self._broker:
            self._broker.close()
        
    

# Public API for ``from <module> import *``. Entries must be name strings:
# listing the objects themselves makes the star-import raise TypeError.
__all__ = [
    'Message', 'MessageClient', 'MqClient', 'MqClientPool', 'ServerAddress', 'MessageCtrl',
    'Broker', 'MqAdmin', 'Producer', 'Consumer', 'RpcInvoker', 'RpcProcessor', 'Remote',
    'ConsumeThread', 'Protocol', 'ServiceBootstrap', 'ClientBootstrap',
]
