###############################################################################
##
##  Copyright (C) 2014 Tavendo GmbH
##
##  This program is free software: you can redistribute it and/or modify
##  it under the terms of the GNU Affero General Public License, version 3,
##  as published by the Free Software Foundation.
##
##  This program is distributed in the hope that it will be useful,
##  but WITHOUT ANY WARRANTY; without even the implied warranty of
##  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
##  GNU Affero General Public License for more details.
##
##  You should have received a copy of the GNU Affero General Public License
##  along with this program. If not, see <http://www.gnu.org/licenses/>.
##
###############################################################################

from __future__ import absolute_import

__all__ = ['CrossbarWampWebSocketServerFactory',
           'CrossbarWampRawSocketServerFactory']

import datetime

from twisted.python import log
from autobahn.twisted.websocket import WampWebSocketServerProtocol, \
                                       WampWebSocketServerFactory

from autobahn.twisted.rawsocket import WampRawSocketServerProtocol, \
                                       WampRawSocketServerFactory

from twisted.internet.defer import Deferred

import json
import urllib
import Cookie

from autobahn import util
from autobahn.websocket import http
from autobahn.websocket.compress import *

from autobahn.wamp import types
from autobahn.wamp import message
from autobahn.wamp.router import RouterFactory
from autobahn.twisted.wamp import RouterSession, RouterSessionFactory

import crossbar



def set_websocket_options(factory, options):
   """
   Apply Crossbar.io transport options to a WebSocket or WAMP-WebSocket
   server factory.

   The factory's ``setProtocolOptions`` is called up to three times: once
   with the general protocol options, once to install a compression
   acceptor that rejects every offer, and - only when
   ``options['compression']['deflate']`` is present - once more to install a
   permessage-deflate acceptor built from that configuration.

   :param factory: The WebSocket or WAMP-WebSocket factory to set options on.
   :type factory:  Instance of :class:`autobahn.twisted.websocket.WampWebSocketServerFactory`
                   or :class:`autobahn.twisted.websocket.WebSocketServerFactory`.
   :param options: Options from Crossbar.io transport configuration.
   :type options: dict
   """
   def timeout_in_seconds(key):
      ## timeouts are configured in milliseconds, a missing key counts as 0
      return float(options.get(key, 0)) / 1000.

   ## WebSocket protocol versions to speak, each one enabled unless the
   ## configuration switches it off
   ##
   version_switches = (("enable_hixie76", 0),
                       ("enable_hybi10", 8),
                       ("enable_rfc6455", 13))
   enabled_versions = [version for (switch, version) in version_switches
                       if options.get(switch, True)]

   ## FIXME: the "max_connections" option is not enforced yet
   ##

   protocol_options = dict(
      versions = enabled_versions,
      allowHixie76 = options.get("enable_hixie76", True),
      webStatus = options.get("enable_webstatus", True),
      utf8validateIncoming = options.get("validate_utf8", True),
      maskServerFrames = options.get("mask_server_frames", False),
      requireMaskedClientFrames = options.get("require_masked_client_frames", True),
      applyMask = options.get("apply_mask", True),
      maxFramePayloadSize = options.get("max_frame_size", 0),
      maxMessagePayloadSize = options.get("max_message_size", 0),
      autoFragmentSize = options.get("auto_fragment_size", 0),
      failByDrop = options.get("fail_by_drop", False),
      echoCloseCodeReason = options.get("echo_close_codereason", False),
      openHandshakeTimeout = timeout_in_seconds("open_handshake_timeout"),
      closeHandshakeTimeout = timeout_in_seconds("close_handshake_timeout"),
      tcpNoDelay = options.get("tcp_nodelay", True))

   factory.setProtocolOptions(**protocol_options)

   ## WebSocket compression: start out with an acceptor that turns down
   ## every offer, and only replace it when deflate is configured
   ##
   factory.setProtocolOptions(perMessageCompressionAccept = lambda _: None)

   if 'compression' not in options or 'deflate' not in options['compression']:
      return

   ## permessage-deflate
   ##
   log.msg("enabling WebSocket compression (permessage-deflate)")

   deflate = options['compression']['deflate']

   request_no_context_takeover = deflate.get('request_no_context_takeover', False)
   request_max_window_bits = deflate.get('request_max_window_bits', 0)
   no_context_takeover = deflate.get('no_context_takeover', None)
   window_bits = deflate.get('max_window_bits', None)
   mem_level = deflate.get('memory_level', None)

   def accept(offers):
      ## answer the first deflate offer compatible with what we want to
      ## request from the client; fall through (returning None) otherwise
      for offer in offers:
         if not isinstance(offer, PerMessageDeflateOffer):
            continue
         if request_max_window_bits != 0 and not offer.acceptMaxWindowBits:
            continue
         if request_no_context_takeover and not offer.acceptNoContextTakeover:
            continue
         return PerMessageDeflateOfferAccept(offer,
                                             requestMaxWindowBits = request_max_window_bits,
                                             requestNoContextTakeover = request_no_context_takeover,
                                             noContextTakeover = no_context_takeover,
                                             windowBits = window_bits,
                                             memLevel = mem_level)

   factory.setProtocolOptions(perMessageCompressionAccept = accept)



class CrossbarWampWebSocketServerProtocol(WampWebSocketServerProtocol):
   """
   Crossbar.io WAMP-over-WebSocket server protocol.

   Adds optional cookie tracking and cookie-based transport authentication
   on top of the base protocol. The cookie store is kept on the factory in
   ``factory._cookies`` and maps a cookie ID to a dict with the keys
   ``created``, ``authid``, ``authrole``, ``authmethod``, ``max_age`` and
   ``connections`` (the set of protocol instances sharing that cookie).
   """

   def onConnect(self, request):
      """
      Process the WebSocket opening handshake.

      Delegates subprotocol negotiation to the base class, resets the
      transport authentication attributes and, when the transport
      configuration has a ``cookie`` section, associates the connection with
      a tracking cookie (reusing a known one sent by the client or issuing a
      new one via a ``Set-Cookie`` response header).

      :param request: The connection request; its ``origin`` and ``headers``
                      attributes are read.
      :returns: Tuple ``(protocol, headers)`` - the subprotocol to speak and
                the HTTP headers to send with the handshake reply.
      """
      protocol, headers = WampWebSocketServerProtocol.onConnect(self, request)

      self._origin = request.origin

      ## transport authentication
      ##
      self._authid = None
      self._authrole = None
      self._authmethod = None

      ## cookie tracking
      ##
      self._cbtid = None

      if 'cookie' not in self.factory._config:
         if self.debug:
            log.msg("Cookie tracking disabled on WebSocket connection {}".format(self))
         return (protocol, headers)

      cookie_config = self.factory._config['cookie']
      cookie_id_field = cookie_config.get('name', 'cbtid')
      cookie_id_field_length = int(cookie_config.get('length', 24))
      cookies = self.factory._cookies

      ## see if there already is a cookie set ..
      if 'cookie' in request.headers:
         try:
            cookie = Cookie.SimpleCookie()
            cookie.load(str(request.headers['cookie']))
         except Cookie.CookieError:
            ## a malformed Cookie header is deliberately treated like no
            ## cookie at all: a fresh tracking cookie gets issued below
            if self.debug:
               log.msg("Ignoring malformed cookie header")
         else:
            if cookie_id_field in cookie:
               cbtid = cookie[cookie_id_field].value
               ## only IDs present in our store are honored; anything
               ## else is replaced by a newly generated ID
               if cbtid in cookies:
                  self._cbtid = cbtid
                  if self.debug:
                     log.msg("Cookie already set: %s" % self._cbtid)

      ## if no cookie is set, create a new one ..
      if self._cbtid is None:

         self._cbtid = util.newid(cookie_id_field_length)

         ## http://tools.ietf.org/html/rfc6265#page-20
         ## 0: delete cookie
         ## -1: preserve cookie until browser is closed

         max_age = cookie_config.get('max_age', 86400 * 30 * 12)

         cookies[self._cbtid] = {'created': util.utcnow(),
                                 'authid': None,
                                 'authrole': None,
                                 'authmethod': None,
                                 'max_age': max_age,
                                 'connections': set()}

         ## do NOT add the "secure" cookie attribute! "secure" refers to the
         ## scheme of the Web page that triggered the WS, not WS itself!!
         ##
         headers['Set-Cookie'] = '%s=%s;max-age=%d' % (cookie_id_field, self._cbtid, max_age)
         if self.debug:
            log.msg("Setting new cookie: %s" % headers['Set-Cookie'])

      cookie_record = cookies[self._cbtid]

      ## add this WebSocket connection to the set of connections
      ## associated with the same cookie
      cookie_record['connections'].add(self)

      if self.debug:
         log.msg("Cookie tracking enabled on WebSocket connection {}".format(self))

      ## if cookie-based authentication is enabled, set auth info from cookie store
      ##
      if 'auth' in self.factory._config and 'cookie' in self.factory._config['auth']:
         self._authid = cookie_record['authid']
         self._authrole = cookie_record['authrole']
         ## NOTE: for a freshly issued cookie the stored authmethod is None,
         ## so this yields the string "cookie.None"
         self._authmethod = "cookie.{}".format(cookie_record['authmethod'])
         if self.debug:
            log.msg("Authenticating client via cookie")
      else:
         if self.debug:
            log.msg("Cookie-based authentication disabled")

      ## accept the WebSocket connection, speaking subprotocol `protocol`
      ## and setting HTTP headers `headers`
      return (protocol, headers)


   def sendServerStatus(self, redirectUrl = None, redirectAfter = 0):
      """
      Used to send out server status/version upon receiving a HTTP/GET without
      upgrade to WebSocket header.

      Renders the ``cb_ws_status.html`` template from the factory's template
      environment. Any error while rendering or sending is logged rather than
      propagated.

      :param redirectUrl: Passed through to the template as ``redirectUrl``.
      :param redirectAfter: Passed through to the template as ``redirectAfter``.
      """
      try:
         page = self.factory._templates.get_template('cb_ws_status.html')
         self.sendHtml(page.render(redirectUrl = redirectUrl,
                                   redirectAfter = redirectAfter,
                                   cbVersion = crossbar.__version__,
                                   wsUri = self.factory.url))
      except Exception as e:
         log.msg("Error rendering WebSocket status page template: %s" % e)


   def onDisconnect(self):
      """
      Remove this connection from the set of connections associated with
      its tracking cookie, if it has one.
      """
      ## `_cbtid` is only assigned in onConnect; if that never ran for this
      ## connection the attribute does not exist, hence the getattr guard
      cbtid = getattr(self, '_cbtid', None)
      if cbtid:
         self.factory._cookies[cbtid]['connections'].discard(self)



class CrossbarWampWebSocketServerFactory(WampWebSocketServerFactory):
   """
   Factory producing :class:`CrossbarWampWebSocketServerProtocol` instances,
   configured from a Crossbar.io transport configuration.
   """

   protocol = CrossbarWampWebSocketServerProtocol

   def __init__(self, factory, config, templates, debug = False):
      """
      Ctor.

      :param factory: WAMP session factory, handed to the base class.
      :type factory: An instance of ..
      :param config: Crossbar transport configuration. The keys ``debug``,
                     ``options``, ``url`` and ``cookie`` are looked at here.
      :type config: dict
      :param templates: Template environment, stored for rendering status pages.
      :param debug: Fallback debug flag used when ``config`` has no ``debug`` key.
      :type debug: bool

      :raises Exception: if ``options['serializers']`` is given and either no
                         serializer could be set up or names are left over
                         that were not successfully set up.
      """
      self.debug = config.get('debug', debug)

      options = config.get('options', {})

      server = "Crossbar/{}".format(crossbar.__version__)
      externalPort = options.get('external_port', None)

      if 'serializers' not in options:
         ## no explicit list configured: hand None to the base factory
         serializers = None

      else:
         ## explicit list of WAMP serializers: every name that was set up
         ## successfully is removed from `pending`, so whatever remains at
         ## the end is either unknown or failed to import
         ##
         serializers = []
         pending = set(options['serializers'])

         if 'json' in pending:
            ## try JSON WAMP serializer
            try:
               from autobahn.wamp.serializer import JsonSerializer
               serializers.append(JsonSerializer())
            except ImportError:
               print("Warning: could not load WAMP-JSON serializer")
            else:
               pending.discard('json')

         if 'msgpack' in pending:
            ## try MsgPack WAMP serializer
            try:
               from autobahn.wamp.serializer import MsgPackSerializer
               serializers.append(MsgPackSerializer())
            except ImportError:
               print("Warning: could not load WAMP-MsgPack serializer")
            else:
               pending.discard('msgpack')

         if len(serializers) == 0:
            raise Exception("no valid WAMP serializers specified")

         if pending:
            raise Exception("invalid WAMP serializers specified: {}".format(pending))

      WampWebSocketServerFactory.__init__(self,
                                          factory,
                                          serializers = serializers,
                                          url = config.get('url', None),
                                          server = server,
                                          externalPort = externalPort,
                                          debug = self.debug,
                                          debug_wamp = self.debug)

      ## keep the transport configuration and the Jinja2 templates
      ## (for 404 etc) around for the protocol instances
      self._config = config
      self._templates = templates

      ## the cookie store only exists when cookie tracking is configured
      if 'cookie' in config:
         self._cookies = {}

      ## finally apply the WebSocket level options
      set_websocket_options(self, options)



class CrossbarWampRawSocketServerProtocol(WampRawSocketServerProtocol):
   """
   WAMP-over-RawSocket server protocol carrying the transport
   authentication attributes ``_authid``, ``_authrole`` and ``_authmethod``.
   """

   def connectionMade(self):
      """
      Run the base class handler, then initialize the transport
      authentication attributes to ``None``.
      """
      WampRawSocketServerProtocol.connectionMade(self)

      ## transport authentication info starts out unset
      ##
      self._authid = self._authrole = self._authmethod = None



class CrossbarWampRawSocketServerFactory(WampRawSocketServerFactory):
   """
   Factory producing :class:`CrossbarWampRawSocketServerProtocol` instances
   that speak a single WAMP serializer chosen by the transport configuration.
   """

   protocol = CrossbarWampRawSocketServerProtocol

   def __init__(self, factory, config):
      """
      Ctor.

      :param factory: WAMP session factory, handed to the base class.
      :param config: Crossbar transport configuration. The ``serializer``
                     key selects ``'json'`` or ``'msgpack'`` (the default).
      :type config: dict

      :raises Exception: if the serializer id is unknown or the selected
                         serializer cannot be imported.
      """
      ## transport configuration
      self._config = config

      ## WAMP serializer
      ##
      serid = config.get('serializer', 'msgpack')

      if serid not in ('json', 'msgpack'):
         raise Exception("invalid WAMP serializer '{}'".format(serid))

      if serid == 'msgpack':
         ## try MsgPack WAMP serializer
         try:
            from autobahn.wamp.serializer import MsgPackSerializer
            serializer = MsgPackSerializer()
            ## NOTE(review): ENABLE_V5 is switched off on the underlying
            ## serializer object; the original marks this as FIXME without
            ## stating the reason
            serializer._serializer.ENABLE_V5 = False ## FIXME
         except ImportError:
            raise Exception("could not load WAMP-MsgPack serializer")

      else:
         ## try JSON WAMP serializer
         try:
            from autobahn.wamp.serializer import JsonSerializer
            serializer = JsonSerializer()
         except ImportError:
            raise Exception("could not load WAMP-JSON serializer")

      WampRawSocketServerFactory.__init__(self, factory, serializer)
