#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (C) Branson Stephens (2015)
#
# This file is part of lvalert-overseer
#
# lvalert-overseer is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# It is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with LIGO.ORG.  If not, see <http://www.gnu.org/licenses/>.

import json
from datetime import datetime
import logging
from hashlib import sha1

from twisted.internet.protocol import ServerFactory, Protocol
from twisted.internet import reactor, task
from twisted.internet.error import ReactorNotRunning

from threading import Thread
from optparse import OptionParser

# pubsub import must come first because it overloads part of the
# StanzaProcessor class
from ligo.lvalert import pubsub
from ligo.overseer.lvalert_client import LVAlertClient

from pyxmpp.all import JID
from pyxmpp.exceptions import SASLAuthenticationFailed

import pkg_resources

#-------------------------------------------------------------------------------------
#-------------------------------------------------------------------------------------
# Parse options
#-------------------------------------------------------------------------------------
#-------------------------------------------------------------------------------------

# Command-line interface of the overseer.  The parsed values end up in the
# module-level ``options`` object that the rest of this file reads.
parser = OptionParser()

# username and password
parser.add_option("-a", "--username", action="store", type="string",
    default="", help="the username of the publisher or listener")
parser.add_option("-b", "--password", action="store", type="string",
    default="", help="the password of the publisher or listener")
parser.add_option("-s", "--server", action="store", type="string",
    default="lvalert.cgca.uwm.edu", help="the pubsub server")
parser.add_option("-r", "--resource", action="store", type="string",
    default="overseer", help="resource to use in JID")

# server options
parser.add_option("-p", "--port", action="store", type="int",
    default=8000, help="port for the overseer server to listen on")
parser.add_option("-l", "--audit-filename", action="store", type="string",
    default="audit.log", help="name for an audit (verbose) log file")
parser.add_option("-e", "--error-filename", action="store", type="string",
    default="error.log", help="name for an error log file")

# debugging options
parser.add_option("-d", "--debug", action="store_true",
    default=False, help="should print out lots of information")

# timeout
# The numeric type is declared explicitly: optparse stores a value as a
# string when no type is given, so a value passed on the command line would
# otherwise not be a number like the defaults are.
parser.add_option("-m", "--max_attempts", action="store", type="int",
    default=10, help="max number of timeouts allowed for sending")
parser.add_option("-t", "--msg-timeout", action="store", type="float",
    default=10, help="time in seconds after which a message is considered lost")

# version
parser.add_option("-v", "--version", action="store_true",
    default=False, help="display version information")

options, args = parser.parse_args()

# --version short-circuits everything else, including the credential checks.
if options.version:
    version = pkg_resources.require("ligo-lvalert-overseer")[0].version
    print("LVAlert Overseer v. %s" % version)
    exit(0)

# Both credentials default to the empty string, so "missing" means falsy.
if not options.username:
    raise ValueError("--username is required")
if not options.password:
    raise ValueError("--password is required")

#-------------------------------------------------------------------------------------
#-------------------------------------------------------------------------------------
# Configure logging
#-------------------------------------------------------------------------------------
#-------------------------------------------------------------------------------------

# We are going to have the main thread as well as two subsidiary threads.
# Logs from each will be distinguished by the name of the logger.
ovrseer_logger = logging.getLogger('ovrseer')
lvalert_logger = logging.getLogger('lvalert')
pyxmpp_logger = logging.getLogger('pyxmpp')

formatter = logging.Formatter('%(asctime)s %(name)s %(levelname)s: %(message)s')

# Verbose file: accepts every record the loggers let through.
audit_fh = logging.FileHandler(options.audit_filename)
audit_fh.setFormatter(formatter)
audit_fh.setLevel(logging.DEBUG)

# Error file: only records at ERROR level or above.
error_fh = logging.FileHandler(options.error_filename)
error_fh.setFormatter(formatter)
error_fh.setLevel(logging.ERROR)


def _attach_handlers(log, quiet_level):
    """Set the level of *log* and hook it up to both log files.

    With --debug the logger is opened up to DEBUG; otherwise it uses
    *quiet_level*.
    """
    log.setLevel(logging.DEBUG if options.debug else quiet_level)
    log.addHandler(audit_fh)
    log.addHandler(error_fh)


# Our own loggers report at INFO by default; pyxmpp only reports errors
# unless debugging is switched on.
_attach_handlers(ovrseer_logger, logging.INFO)
_attach_handlers(lvalert_logger, logging.INFO)
_attach_handlers(pyxmpp_logger, logging.ERROR)

#-------------------------------------------------------------------------------------
#-------------------------------------------------------------------------------------
# LVAlert Overseer 
#-------------------------------------------------------------------------------------
#-------------------------------------------------------------------------------------

class LVAOverseerProtocol(Protocol):
    """Handle one client connection to the overseer server.

    Every chunk handed to dataReceived is parsed as a single JSON object
    with the keys 'node_name', 'message' and 'action'.  The reply written
    back to the client is a JSON object with a boolean 'success' key, plus
    an 'error' key when the request was rejected.
    """

    # Byte strings and text strings, spelled so that it works whether the
    # JSON decoder hands back str or unicode objects.
    _TEXT_TYPES = (bytes, type(u''))

    def badRequest(self, msg):
        """Log a rejected request and report it to the client.

        msg -- text describing the problem; it is logged and returned to
               the client under the 'error' key with 'success' set to False.
        """
        rdict = {'error': msg, 'success': False}
        ovrseer_logger.error("Bad request: " + msg)
        self.transport.write(json.dumps(rdict))

    @staticmethod
    def _messageId(node, message):
        """Return the hex SHA-1 digest identifying a (node, message) pair.

        Text is encoded as UTF-8 before hashing so that non-ASCII content
        does not raise; for ASCII input the digest is the same as hashing
        the concatenated strings directly.
        """
        digest = sha1()
        for part in (node, message):
            if not isinstance(part, bytes):
                part = part.encode('utf-8')
            digest.update(part)
        return digest.hexdigest()

    def dataReceived(self, data):
        """Handle data delivered to the server.

        The data is parsed as JSON in one go, so a request is expected to
        arrive in a single delivery.  Malformed requests are answered via
        badRequest; valid ones are answered with {'success': True}.

        data -- the raw request as received from the transport.
        """
        try:
            master_dict = json.loads(data)
        except Exception as e:
            self.badRequest(str(e))
            return

        # Valid JSON need not be an object (it may be a list, a number, ...).
        if not isinstance(master_dict, dict):
            self.badRequest('request must be a JSON object')
            return

        node = master_dict.get('node_name', None)
        if not node:
            self.badRequest('node name missing')
            return
        message = master_dict.get('message', None)
        if not message:
            self.badRequest('message missing')
            return
        action = master_dict.get('action', None)
        if action not in ['push', 'pop']:
            self.badRequest('action must be "push" or "pop"')
            return

        # The id is a hash over both values, so both have to be strings.
        if not isinstance(node, self._TEXT_TYPES):
            self.badRequest('node name must be a string')
            return
        if not isinstance(message, self._TEXT_TYPES):
            self.badRequest('message must be a string')
            return

        # The same node/message pair always maps to the same id; that is
        # how a later 'pop' finds the entry made by the 'push'.
        message_id = self._messageId(node, message)

        if action == 'push':
            # Record the send time first so the entry exists even if the
            # send below fails.
            self.factory.outstanding_messages[message_id] = datetime.now()

            # Now send out the lvalert message.  A failed send is only
            # logged: the entry stays outstanding and the client is still
            # told 'success'.
            try:
                ovrseer_logger.info("sending %s" % message_id)
                self.factory.client.sendMessage(node, message)
            except Exception as e:
                ovrseer_logger.error("SEND FAILED: %s" % str(e))
            self.transport.write(json.dumps({'success': True}))
        elif action == 'pop':
            # Remove the message from the outstanding dict; an unknown id
            # (never pushed, or already culled) is silently accepted.
            send_time = self.factory.outstanding_messages.pop(message_id, None)
            if send_time is not None:
                dt = datetime.now() - send_time
                # NOTE: total_seconds only works for python 2.7 or later.
                ovrseer_logger.info("received %s with latency: %s"
                    % (message_id, dt.total_seconds()))
            self.transport.write(json.dumps({'success': True}))

class LVAOverseerFactory(ServerFactory):
    """Server factory that tracks the messages awaiting confirmation.

    client -- object used by the protocol to send messages (sendMessage)
              and disconnected when the factory stops.
    thread -- thread running the client; it is told to finish via its
              kill_switch attribute and joined when the factory stops.
    """

    protocol = LVAOverseerProtocol

    def __init__(self, client = None, thread = None):
        self.client = client
        self.thread = thread
        # Maps message id -> datetime at which the message was pushed.
        self.outstanding_messages = {}

    def stopFactory(self):
        """Kill the connection to the LVAlert server.

        Any failure while doing so is logged and ends the process with
        exit status 1.
        """
        try:
            # Raise the flag before disconnecting so the client thread does
            # not try to reconnect.
            self.thread.kill_switch = True
            self.client.disconnect()
            self.thread.join()
        except Exception as e:
            ovrseer_logger.error("Problem killing the LVAlert loop: %s" % str(e))
            exit(1)

    def buildProtocol(self, addr):
        """Only accept connections from localhost.

        Returns a protocol instance for peers at 127.0.0.1 and None for
        everybody else.
        """
        if addr.host == "127.0.0.1":
            return ServerFactory.buildProtocol(self, addr)
        return None

    # Routine to check for old outstanding messages and write errors.
    # We might want to hook this up to Nagios, by touching a file that
    # Nagios would be able to check for. Or just have it send texts or emails.
    # This should be run periodically on the same interval as the wait time.
    def checkOutstandingMessages(self):
        """Drop, and log as errors, the messages older than the timeout."""
        outstanding = self.outstanding_messages
        ovrseer_logger.debug("before culling outstanding, N = %d" % len(outstanding))

        # The option may reach us as a string when it was given on the
        # command line; comparing a number with a string does not yield a
        # meaningful result, so make it a number first.
        timeout = float(options.msg_timeout)
        now = datetime.now()

        # Collect first, remove afterwards: the dict must not change size
        # while it is being walked.
        expired = [message_id
                   for message_id, send_time in list(outstanding.items())
                   if (now - send_time).total_seconds() > timeout]
        for message_id in expired:
            ovrseer_logger.error(
                "message %s not received after %d seconds. Popping..."
                % (message_id, timeout))
            outstanding.pop(message_id, None)

        ovrseer_logger.debug("after culling outstanding, N = %d" % len(outstanding))

#-------------------------------------------------------------------------------------
#-------------------------------------------------------------------------------------
# Thread object for starting up an LVAlert client
#-------------------------------------------------------------------------------------
#-------------------------------------------------------------------------------------

class LVAlertThread(Thread):
    """Daemon thread that runs the connect/loop cycle of an LVAlert client.

    lvalert_client -- client object providing connect(), loop() and a
                      ``logger`` attribute.
    """

    def __init__(self, lvalert_client):
        Thread.__init__(self)
        self.lvalert_client = lvalert_client
        # Set to True, from this thread or from outside, to end run().
        self.kill_switch = False
        self.daemon = True

    @staticmethod
    def _stopReactor():
        """Stop the reactor, ignoring the case that it is not running.

        Meant to be executed in the reactor thread: that is where
        reactor.stop() is actually called, so that is where
        ReactorNotRunning has to be caught.
        """
        try:
            reactor.stop()
        except ReactorNotRunning:
            pass

    def _haltOverseer(self):
        """End this thread's loop and ask the reactor to shut down."""
        self.kill_switch = True
        reactor.callFromThread(self._stopReactor)

    def run(self):
        """Connect and run the client loop until kill_switch is set.

        If the client loop returns without raising (we got kicked off),
        just reconnect and start over.  Any exception is logged and takes
        the whole overseer down.
        """
        log = self.lvalert_client.logger
        while not self.kill_switch:
            log.info("connecting... ")
            try:
                self.lvalert_client.connect()
                self.lvalert_client.loop(1)
            except SASLAuthenticationFailed:
                log.error("Authentication Error: Wrong password.")
                self._haltOverseer()
            except Exception as e:
                log.error("Error: %s" % str(e))
                log.error("Exception type: %s" % type(e).__name__)
                self._haltOverseer()

#-------------------------------------------------------------------------------------
#-------------------------------------------------------------------------------------
# Main: start the LVAlert client thread and the overseer server
#-------------------------------------------------------------------------------------
#-------------------------------------------------------------------------------------

if __name__ == '__main__':

    # The jabber client logs in as user@server/resource and runs in its
    # own thread, started before the reactor takes over this one.
    jid = JID("%s@%s/%s" % (options.username, options.server, options.resource))
    xmpp_client = LVAlertClient(
        jid,
        options.password,
        options.max_attempts,
        lvalert_logger,
        options.port)
    xmpp_thread = LVAlertThread(xmpp_client)
    xmpp_thread.start()

    # The twisted server listens for new data to transmit via LVAlert.
    factory = LVAOverseerFactory(xmpp_client, xmpp_thread)
    reactor.listenTCP(options.port, factory)

    # Cull outstanding messages that are too old, once per timeout period.
    culler = task.LoopingCall(factory.checkOutstandingMessages)
    culler.start(options.msg_timeout)

    ovrseer_logger.info("starting the overseer reactor...")
    reactor.run()
