# ==== twisted/plugins/vumi_worker_starter.py ====
"""Plugins for starting Vumi workers from twistd."""

from vumi.servicemaker import (
    VumiWorkerServiceMaker, DeprecatedStartWorkerServiceMaker)

# Having instances of IServiceMaker present magically announces the
# service makers to twistd.
# See: http://twistedmatrix.com/documents/current/core/howto/tap.html
vumi_worker = VumiWorkerServiceMaker()
start_worker = DeprecatedStartWorkerServiceMaker()


# ==== vumi-0.6.9.data/scripts/vumi_redis_tools.py ====
#!python
# -*- test-case-name: vumi.scripts.tests.test_vumi_redis_tools -*-

import re
import sys

import yaml

from twisted.python import usage

from vumi.persist.redis_manager import RedisManager


class TaskError(Exception):
    """Raised when an error is encountered while using tasks."""


class Task(object):
    """
    A task to perform on a redis key.
    """

    name = None
    hidden = False  # set to True to hide from docs

    runner = None
    redis = None

    @classmethod
    def parse(cls, task_desc):
        """
        Parse a string description into a task.

        Task description format::

            <task_type>[:<param>=<value>[,...]]
        """
        task_type, _, param_desc = task_desc.partition(':')
        task_cls = cls._parse_task_type(task_type)
        params = {}
        if param_desc:
            params = cls._parse_param_desc(param_desc)
        return task_cls(**params)

    @classmethod
    def task_types(cls):
        return cls.__subclasses__()

    @classmethod
    def _parse_task_type(cls, task_type):
        names = dict((t.name, t) for t in cls.task_types())
        if task_type not in names:
            raise TaskError("Unknown task type %r" % (task_type,))
        return names[task_type]

    @classmethod
    def _parse_param_desc(cls, param_desc):
        params = [x.partition('=') for x in param_desc.split(',')]
        params = [(p, v) for p, _sep, v in params]
        return dict(params)

    def init(self, runner, redis):
        self.runner = runner
        self.redis = redis

    def before(self):
        """Run once before the task is applied to any keys."""

    def after(self):
        """Run once after the task has been applied to all keys."""

    def process_key(self, key):
        """Run once for each key.

        May return either the name of the key (if the key should be
        processed by later tasks), the new name of the key (if the key was
        renamed and should be processed by later tasks) or ``None`` (if the
        key has been deleted or should not be processed by further tasks).
        """
        return key


class Count(Task):
    """A task that counts the number of keys."""

    name = "count"

    def __init__(self):
        self.count = None

    def before(self):
        self.count = 0

    def after(self):
        self.runner.emit("Found %d matching keys." % (self.count,))

    def process_key(self, key):
        self.count += 1
        return key


class Expire(Task):
    """A task that sets an expiry time on each key."""

    name = "expire"

    def __init__(self, seconds):
        self.seconds = int(seconds)

    def process_key(self, key):
        self.redis.expire(key, self.seconds)
        return key


class Persist(Task):
    """A task that persists each key."""

    name = "persist"

    def process_key(self, key):
        self.redis.persist(key)
        return key


class ListKeys(Task):
    """A task that prints out each key."""

    name = "list"

    def process_key(self, key):
        self.runner.emit(key)
        return key


class Skip(Task):
    """A task that skips keys that match a regular expression."""

    name = "skip"

    def __init__(self, pattern):
        self.regex = re.compile(pattern)

    def process_key(self, key):
        if self.regex.match(key):
            return None
        return key


class Options(usage.Options):
    synopsis = "<config-file> <match-pattern> [-t <task> ...]"

    longdesc = "Perform tasks on Redis keys."
    def __init__(self):
        usage.Options.__init__(self)
        self['tasks'] = []

    def getUsage(self, width=None):
        doc = usage.Options.getUsage(self, width=width)
        header = "Available tasks:"
        tasks = sorted(Task.task_types(), key=lambda t: t.name)
        tasks_doc = "".join(usage.docMakeChunks([{
            'long': task.name,
            'doc': task.__doc__,
        } for task in tasks if not task.hidden]))
        return "\n".join([doc, header, tasks_doc])

    def parseArgs(self, config_file, match_pattern):
        self['config'] = yaml.safe_load(open(config_file))
        self['match_pattern'] = match_pattern

    def opt_task(self, task_desc):
        """A task to perform on all matching keys."""
        task = Task.parse(task_desc)
        self['tasks'].append(task)

    opt_t = opt_task

    def postOptions(self):
        if not self['tasks']:
            raise usage.UsageError("Please specify a task.")


def scan_keys(redis, match):
    """Iterate over matching keys."""
    prev_cursor = None
    while True:
        cursor, keys = redis.scan(prev_cursor, match=match)
        for key in keys:
            yield key
        if cursor is None:
            break
        if cursor == prev_cursor:
            raise TaskError("Redis scan stuck on cursor %r" % (cursor,))
        prev_cursor = cursor


class TaskRunner(object):

    stdout = sys.stdout

    def __init__(self, options):
        self.options = options
        self.match_pattern = options['match_pattern']
        self.tasks = options['tasks']
        self.redis = self.get_redis(options['config'])

    def emit(self, s):
        """
        Print the given string and then a newline.
        """
        self.stdout.write(s)
        self.stdout.write("\n")

    def get_redis(self, config):
        """
        Create and return a redis manager.
        """
        redis_config = config.get('redis_manager', {})
        return RedisManager.from_config(redis_config)

    def run(self):
        """
        Apply all tasks to all keys.
        """
        for task in self.tasks:
            task.init(self, self.redis)
        for task in self.tasks:
            task.before()
        for key in scan_keys(self.redis, self.match_pattern):
            for task in self.tasks:
                key = task.process_key(key)
                if key is None:
                    break
        for task in self.tasks:
            task.after()


if __name__ == '__main__':
    try:
        options = Options()
        options.parseOptions()
    except usage.UsageError, errortext:
        print '%s: %s' % (sys.argv[0], errortext)
        print '%s: Try --help for usage details.' % (sys.argv[0])
        sys.exit(1)

    tasks = TaskRunner(options)
    tasks.run()


# ==== vumi-0.6.9.data/scripts/vumi_count_models.py ====
#!python
# -*- test-case-name: vumi.scripts.tests.test_vumi_count_models -*-

import re
import sys

from twisted.internet.defer import inlineCallbacks, succeed
from twisted.internet.task import react
from twisted.python import usage

from vumi.utils import load_class_by_string
from vumi.persist.txriak_manager import TxRiakManager


class Options(usage.Options):
    optParameters = [
        ["model", "m", None,
         "Full Python name of the model class to count."
         " E.g. 'vumi.components.message_store.InboundMessage'."],
        ["bucket-prefix", "b", None,
         "The bucket prefix for the Riak manager."],
        ["index-field", None, None,
         "Field with index to query. If omitted, all keys in the bucket will"
         " be counted and no `index-value-*' parameters are allowed."],
        ["index-value", None, None,
         "Exact match value or start of range."],
        ["index-value-end", None, None,
         "End of range. If omitted, an exact match query will be used."],
        ["index-value-regex", None, None,
         "Regex to filter index values."],
        ["index-page-size", None, "1000",
         "The number of keys to fetch in each index query."],
    ]

    longdesc = """
    Index-based model counter. This makes paginated index queries, optionally
    filters the results by applying a regex to the index value, and returns a
    count of all matching models.
    """

    def ensure_dependent_option(self, needed, needs):
        """
        Raise UsageError if `needs` is provided without `needed`.
        """
        if self[needed] is None and self[needs] is not None:
            raise usage.UsageError("%s requires %s to be specified." % (
                needs, needed))

    def postOptions(self):
        if self["model"] is None:
            raise usage.UsageError("Please specify a model class.")
        if self["bucket-prefix"] is None:
            raise usage.UsageError("Please specify a bucket prefix.")
        self.ensure_dependent_option("index-field", "index-value")
        self.ensure_dependent_option("index-value", "index-field")
        self.ensure_dependent_option("index-value", "index-value-end")
        self.ensure_dependent_option("index-value-end", "index-value-regex")
        self["index-page-size"] = int(self['index-page-size'])


class ProgressEmitter(object):
    """Report progress as the number of items processed to an emitter."""

    def __init__(self, emit, batch_size):
        self.emit = emit
        self.batch_size = batch_size
        self.processed = 0

    def update(self, value):
        if (value / self.batch_size) > (self.processed / self.batch_size):
            self.emit(value)
        self.processed = value


class ModelCounter(object):
    def __init__(self, options):
        self.options = options
        model_cls = load_class_by_string(options['model'])
        riak_config = {
            'bucket_prefix': options['bucket-prefix'],
        }
        self.manager = self.get_riak_manager(riak_config)
        self.model = self.manager.proxy(model_cls)

    def cleanup(self):
        return self.manager.close_manager()

    def get_riak_manager(self, riak_config):
        return TxRiakManager.from_config(riak_config)

    def emit(self, s):
        print s

    def count_keys(self, keys, filter_regex):
        """
        Count keys in an index page, filtering by regex if necessary.
        """
        if filter_regex is not None:
            keys = [(v, k) for v, k in keys if filter_regex.match(v)]
        return len(keys)

    @inlineCallbacks
    def count_pages(self, index_page, filter_regex):
        emit_progress = lambda t: self.emit(
            "%s object%s counted." % (t, "" if t == 1 else "s"))
        progress = ProgressEmitter(
            emit_progress, self.options["index-page-size"])
        counted = 0
        while index_page is not None:
            if index_page.has_next_page():
                next_page_d = index_page.next_page()
            else:
                next_page_d = succeed(None)
            counted += self.count_keys(list(index_page), filter_regex)
            progress.update(counted)
            index_page = yield next_page_d
        self.emit("Done, %s object%s found." % (
            counted, "" if counted == 1 else "s"))

    @inlineCallbacks
    def count_all_keys(self):
        """
        Perform an index query to get all keys and count them.
        """
        self.emit("Counting all keys ...")
        index_page = yield self.model.all_keys_page(
            max_results=self.options["index-page-size"])
        yield self.count_pages(index_page, filter_regex=None)

    @inlineCallbacks
    def count_index_keys(self):
        """
        Perform an index query to get all matching keys and count them.
        """
        filter_regex = self.options["index-value-regex"]
        if filter_regex is not None:
            filter_regex = re.compile(filter_regex)
        self.emit("Counting ...")
        index_page = yield self.model.index_keys_page(
            field_name=self.options["index-field"],
            value=self.options["index-value"],
            end_value=self.options["index-value-end"],
            max_results=self.options["index-page-size"],
            return_terms=True)
        yield self.count_pages(index_page, filter_regex=filter_regex)

    def _run(self):
        if self.options["index-field"] is None:
            return self.count_all_keys()
        else:
            return self.count_index_keys()

    @inlineCallbacks
    def run(self):
        try:
            yield self._run()
        finally:
            yield self.cleanup()


def main(_reactor, name, *args):
    try:
        options = Options()
        options.parseOptions(args)
    except usage.UsageError, errortext:
        print '%s: %s' % (name, errortext)
        print '%s: Try --help for usage details.' % (name,)
        sys.exit(1)

    model_counter = ModelCounter(options)
    return model_counter.run()


if __name__ == '__main__':
    react(main, sys.argv)


# ==== vumi-0.6.9.data/scripts/vumi_tagpools.py ====
#!python
# -*- test-case-name: vumi.scripts.tests.test_vumi_tagpools -*-

import sys
import re
import itertools

import yaml

from twisted.python import usage

from vumi.components.tagpool import TagpoolManager
from vumi.persist.redis_manager import RedisManager


class PoolSubCmd(usage.Options):

    synopsis = "<pool>"

    def parseArgs(self, pool):
        self.pool = pool


class CreatePoolCmd(PoolSubCmd):
    def run(self, cfg):
        local_tags = cfg.tags(self.pool)
        tags = [(self.pool, local_tag) for local_tag in local_tags]
        metadata = cfg.metadata(self.pool)

        cfg.emit("Creating pool %s ..." % self.pool)
        cfg.emit("  Setting metadata ...")
        cfg.tagpool.set_metadata(self.pool, metadata)
        cfg.emit("  Declaring %d tag(s) ..." % len(tags))
        cfg.tagpool.declare_tags(tags)
        cfg.emit("  Done.")


class UpdatePoolMetadataCmd(PoolSubCmd):
    def run(self, cfg):
        metadata = cfg.metadata(self.pool)
        cfg.emit("Updating metadata for pool %s ..." % self.pool)
        cfg.tagpool.set_metadata(self.pool, metadata)
        cfg.emit("  Done.")


class UpdateAllPoolMetadataCmd(usage.Options):
    def run(self, cfg):
        pools_in_tagpool = cfg.tagpool.list_pools()
        pools_in_cfg = set(cfg.pools.keys())
        pools_in_both = sorted(pools_in_tagpool.intersection(pools_in_cfg))
        cfg.emit("Updating pool metadata.")
        cfg.emit("Note: Pools not present in both the config and tagpool"
                 " store will not be updated.")
        if not pools_in_both:
            cfg.emit("No pools found.")
            return
        for pool in pools_in_both:
            cfg.emit("  Updating metadata for pool %s ..." % pool)
            metadata = cfg.metadata(pool)
            cfg.tagpool.set_metadata(pool, metadata)
        cfg.emit("Done.")


class PurgePoolCmd(PoolSubCmd):
    def run(self, cfg):
        cfg.emit("Purging pool %s ..." % self.pool)
        cfg.tagpool.purge_pool(self.pool)
        cfg.emit("  Done.")


def key_ranges(keys):
    """Take a list of keys and convert them to a compact output string.

    E.g. foo100, foo101, ..., foo200, foo300 becomes foo[100-200], foo300
    """
    keys.sort()
    last_digits_re = re.compile("^(?P<pre>()|(.*[^\d]))(?P<digits>\d+)"
                                "(?P<post>.*)$")

    def group(x):
        i, key = x
        match = last_digits_re.match(key)
        if not match:
            return None
        pre, post = match.group('pre'), match.group('post')
        digits = match.group('digits')
        dlen, value = len(digits), int(digits)
        return pre, post, dlen, value - i

    key_ranges = []
    for grp_key, grp_list in itertools.groupby(enumerate(keys), group):
        grp_list = list(grp_list)
        if len(grp_list) == 1 or grp_key is None:
            key_ranges.extend(g[1] for g in grp_list)
        else:
            pre, post, dlen, _cnt = grp_key
            start = last_digits_re.match(grp_list[0][1]).group('digits')
            end = last_digits_re.match(grp_list[-1][1]).group('digits')
            key_range = "%s[%s-%s]%s" % (pre, start, end, post)
            key_ranges.append(key_range)

    return ", ".join(key_ranges)
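

# Example of the compaction performed by key_ranges (a sketch derived from
# the code above, not part of the original script):
#
#     >>> key_ranges(['foo100', 'foo101', 'foo102', 'foo300'])
#     'foo[100-102], foo300'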


class ListKeysCmd(PoolSubCmd):
    def run(self, cfg):
        free_tags = cfg.tagpool.free_tags(self.pool)
        inuse_tags = cfg.tagpool.inuse_tags(self.pool)
        cfg.emit("Listing tags for pool %s ..." % self.pool)
        cfg.emit("Free tags:")
        cfg.emit("   " + (key_ranges([tag[1] for tag in free_tags])
                          or "-- None --"))
        cfg.emit("Tags in use:")
        cfg.emit("   " + (key_ranges([tag[1] for tag in inuse_tags])
                          or "-- None --"))


class ListPoolsCmd(usage.Options):
    def run(self, cfg):
        pools_in_tagpool = cfg.tagpool.list_pools()
        pools_in_cfg = set(cfg.pools.keys())
        cfg.emit("Pools defined in cfg and tagpool:")
        cfg.emit("   " +
                 ', '.join(sorted(pools_in_tagpool.intersection(pools_in_cfg))
                           or ['-- None --']))
        cfg.emit("Pools only in cfg:")
        cfg.emit("   " +
                 ', '.join(sorted(pools_in_cfg.difference(pools_in_tagpool))
                           or ['-- None --']))
        cfg.emit("Pools only in tagpool:")
        cfg.emit("   " +
                 ', '.join(sorted(pools_in_tagpool.difference(pools_in_cfg))
                           or ['-- None --']))


class ReleaseTagCmd(usage.Options):

    synopsis = "<pool> <tag>"

    def parseArgs(self, pool, tag):
        self.pool = pool
        self.tag = tag

    def run(self, cfg):
        free_tags = cfg.tagpool.free_tags(self.pool)
        inuse_tags = cfg.tagpool.inuse_tags(self.pool)
        tag_tuple = (self.pool, self.tag)
        if tag_tuple not in inuse_tags:
            if tag_tuple not in free_tags:
                cfg.emit('Unknown tag %s.' % (tag_tuple,))
            else:
                cfg.emit('Tag %s not in use.' % (tag_tuple,))
        else:
            cfg.tagpool.release_tag(tag_tuple)
            cfg.emit('Released %s.' % (tag_tuple,))


class Options(usage.Options):
    subCommands = [
        ["create-pool", None, CreatePoolCmd,
         "Declare tags for a tag pool."],
        ["update-pool-metadata", None, UpdatePoolMetadataCmd,
         "Update a pool's metadata from config."],
        ["update-all-metadata", None, UpdateAllPoolMetadataCmd,
         "Update all pool metadata from config."],
        ["purge-pool", None, PurgePoolCmd,
         "Purge all tags from a tag pool."],
        ["list-keys", None, ListKeysCmd,
         "List the free and in-use keys associated with a tag pool."],
        ["list-pools", None, ListPoolsCmd,
         "List all pools defined in config and in the tag store."],
        ["release-tag", None, ReleaseTagCmd,
         "Release a single tag, moving it from the in-use set to the free"
         " set. Use only if you know what you are doing."]
    ]

    optParameters = [
        ["config", "c", "tagpools.yaml",
         "A config file describing the available pools."],
    ]

    longdesc = """Utilities for working with
                  vumi.components.tagpool.TagpoolManager."""

    def postOptions(self):
        if self.subCommand is None:
            raise usage.UsageError("Please specify a sub-command.")


class ConfigHolder(object):
    def __init__(self, options):
        self.options = options
        self.config = yaml.safe_load(open(options['config'], "rb"))
        self.pools = self.config.get('pools', {})
        redis = RedisManager.from_config(self.config.get('redis_manager', {}))
        self.tagpool = TagpoolManager(redis.sub_manager(
                self.config.get('tagpool_prefix', 'vumi')))

    def emit(self, s):
        print s

    def tags(self, pool):
        tags = self.pools[pool]['tags']
        if isinstance(tags, basestring):
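            # Tags may be given in the config as a Python expression string
            # (e.g. a list comprehension) which is evaluated here with empty
            # globals and locals.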
            tags = eval(tags, {}, {})
        return tags

    def metadata(self, pool):
        return self.pools[pool].get('metadata', {})

    def run(self):
        self.options.subOptions.run(self)
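

# A hypothetical tagpools.yaml for this script (a sketch: the top-level keys
# come from ConfigHolder above; the pool name, tags expression and metadata
# values are invented for illustration):
#
#     redis_manager: {}
#     tagpool_prefix: vumi
#     pools:
#       shortcode_pool:
#         tags: "['shortcode%d' % i for i in range(10)]"
#         metadata:
#           display_name: Example shortcodes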


if __name__ == '__main__':
    try:
        options = Options()
        options.parseOptions()
    except usage.UsageError, errortext:
        print '%s: %s' % (sys.argv[0], errortext)
        print '%s: Try --help for usage details.' % (sys.argv[0])
        sys.exit(1)

    cfg = ConfigHolder(options)
    cfg.run()
PK]axG299-vumi-0.6.9.data/scripts/vumi_list_messages.py#!python
# -*- test-case-name: vumi.scripts.tests.test_vumi_list_messages -*-

import sys

from twisted.internet.defer import inlineCallbacks
from twisted.internet.task import react
from twisted.python import usage

from vumi.components.message_store import MessageStore
from vumi.persist.txriak_manager import TxRiakManager


class Options(usage.Options):
    optParameters = [
        ["batch", None, None,
         "Batch identifier to list messages for."],
        ["bucket-prefix", "b", None,
         "The bucket prefix for the Riak manager."],
        ["direction", None, None,
         "Message direction. Valid values are `inbound' and `outbound'."],
        ["index-page-size", None, "1000",
         "The number of keys to fetch in each index query."],
    ]

    longdesc = """
    Index-based message store lister. For each message, the timestamp, remote
    address, and message_id are returned in a comma-separated format.
    """

    def postOptions(self):
        if self["batch"] is None:
            raise usage.UsageError("Please specify a batch.")
        if self["direction"] not in ["inbound", "outbound"]:
            raise usage.UsageError("Please specify a valid direction.")
        if self["bucket-prefix"] is None:
            raise usage.UsageError("Please specify a bucket prefix.")
        self["index-page-size"] = int(self['index-page-size'])
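

# Example invocation (a sketch; the bucket prefix and batch identifier are
# placeholders):
#
#     vumi_list_messages.py -b vumi. --batch my-batch-id --direction inbound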


class MessageLister(object):
    def __init__(self, options):
        self.options = options
        riak_config = {
            'bucket_prefix': options['bucket-prefix'],
        }
        self.manager = self.get_riak_manager(riak_config)
        self.mdb = MessageStore(self.manager, None)

    def cleanup(self):
        return self.manager.close_manager()

    def get_riak_manager(self, riak_config):
        return TxRiakManager.from_config(riak_config)

    def emit(self, s):
        print s

    @inlineCallbacks
    def list_pages(self, index_page):
        while index_page is not None:
            next_page_d = index_page.next_page()
            for message_id, timestamp, addr in index_page:
                self.emit(",".join([timestamp, addr, message_id]))
            index_page = yield next_page_d

    @inlineCallbacks
    def _run(self):
        index_func = {
            "inbound": self.mdb.batch_inbound_keys_with_addresses,
            "outbound": self.mdb.batch_outbound_keys_with_addresses,
        }[self.options["direction"]]
        index_page = yield index_func(
            self.options["batch"], max_results=self.options["index-page-size"])
        yield self.list_pages(index_page)

    @inlineCallbacks
    def run(self):
        try:
            yield self._run()
        finally:
            yield self.cleanup()


def main(_reactor, name, *args):
    try:
        options = Options()
        options.parseOptions(args)
    except usage.UsageError, errortext:
        print '%s: %s' % (name, errortext)
        print '%s: Try --help for usage details.' % (name,)
        sys.exit(1)

    lister = MessageLister(options)
    return lister.run()


if __name__ == '__main__':
    react(main, sys.argv)

# ==== vumi-0.6.9.data/scripts/vumi_model_migrator.py ====
#!python
# -*- test-case-name: vumi.scripts.tests.test_vumi_model_migrator -*-
import sys

from twisted.internet.defer import inlineCallbacks, gatherResults, succeed
from twisted.internet.task import react
from twisted.python import usage

from vumi.utils import load_class_by_string
from vumi.persist.txriak_manager import TxRiakManager


class Options(usage.Options):
    optParameters = [
        ["model", "m", None,
         "Full Python name of the model class to migrate."
         " E.g. 'vumi.components.message_store.InboundMessage'."],
        ["bucket-prefix", "b", None,
         "The bucket prefix for the Riak manager."],
        ["keys", None, None,
         "Migrate these specific keys rather than the whole bucket."
         " E.g. --keys 'foo,bar,baz'"],
        ["concurrent-migrations", None, "20",
         "The number of concurrent migrations to perform."],
        ["index-page-size", None, "1000",
         "The number of keys to fetch in each index query."],
        ["continuation-token", None, None,
         "A continuation token for resuming an interrupted migration."],
        ["post-migrate-function", None, None,
         "Full Python name of a callable to post-process each migrated object."
         " Should update the model object and return a (possibly deferred)"
         " boolean to indicate whether the object has been modified."],
    ]

    optFlags = [
        ["dry-run", None, "Don't save anything back to Riak."],
    ]

    longdesc = """Offline model migrator. Necessary for updating
                  models when index names change so that old model
                  instances remain findable by index searches.
                  """

    def postOptions(self):
        if self['model'] is None:
            raise usage.UsageError("Please specify a model class.")
        if self['bucket-prefix'] is None:
            raise usage.UsageError("Please specify a bucket prefix.")
        self['concurrent-migrations'] = int(self['concurrent-migrations'])
        self['index-page-size'] = int(self['index-page-size'])
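

# Example invocation (a sketch; the bucket prefix is a placeholder and the
# model name is the one used as an example in the option help above):
#
#     vumi_model_migrator.py -b vumi. \
#         -m vumi.components.message_store.InboundMessage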


class ProgressEmitter(object):
    """Report progress as the number of items processed to an emitter."""

    def __init__(self, emit, batch_size):
        self.emit = emit
        self.batch_size = batch_size
        self.processed = 0

    def update(self, value):
        if (value / self.batch_size) > (self.processed / self.batch_size):
            self.emit(value)
        self.processed = value
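

# Note: the comparison above uses Python 2 integer division, so an emit
# happens only when `value` crosses a multiple of batch_size. E.g. with
# batch_size=1000: update(999) emits nothing (0 > 0 is false), while a
# later update(1001) emits 1001 (1 > 0).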


class FakeIndexPage(object):
    def __init__(self, keys, page_size):
        self._keys = keys
        self._page_size = page_size

    def __iter__(self):
        return iter(self._keys[:self._page_size])

    def has_next_page(self):
        return len(self._keys) > self._page_size

    def next_page(self):
        return succeed(
            type(self)(self._keys[self._page_size:], self._page_size))


class ModelMigrator(object):
    def __init__(self, options):
        self.options = options
        model_cls = load_class_by_string(options['model'])
        riak_config = {
            'bucket_prefix': options['bucket-prefix'],
        }
        self.manager = self.get_riak_manager(riak_config)
        self.model = self.manager.proxy(model_cls)

        # The default post-migrate-function does nothing and returns True if
        # and only if the object was migrated.
        self.post_migrate_function = lambda obj: obj.was_migrated
        if options['post-migrate-function'] is not None:
            self.post_migrate_function = load_class_by_string(
                options['post-migrate-function'])

    def cleanup(self):
        return self.manager.close_manager()

    def get_riak_manager(self, riak_config):
        return TxRiakManager.from_config(riak_config)

    def emit(self, s):
        print s

    @inlineCallbacks
    def migrate_key(self, key, dry_run):
        try:
            obj = yield self.model.load(key)
            if obj is not None:
                should_save = yield self.post_migrate_function(obj)
                if should_save and not dry_run:
                    yield obj.save()
            else:
                self.emit("Skipping tombstone key %r." % (key,))
        except Exception, e:
            self.emit("Failed to migrate key %r:" % (key,))
            self.emit("  %s: %s" % (type(e).__name__, e))

    @inlineCallbacks
    def migrate_keys(self, _result, keys_list, dry_run):
        """
        Migrate keys from `keys_list` until there are none left.

        This method is expected to be called multiple times concurrently with
        all instances sharing the same `keys_list`.
        """
        # keys_list is a shared mutable list, so we can't just iterate over it.
        while keys_list:
            key = keys_list.pop(0)
            yield self.migrate_key(key, dry_run)

    def migrate_page(self, keys, dry_run):
        # Depending on our Riak client, Python version, and JSON library we may
        # get bytes or unicode here.
        keys = [k.decode('utf-8') if isinstance(k, str) else k for k in keys]
        return gatherResults([
            self.migrate_keys(None, keys, dry_run)
            for _ in xrange(self.options["concurrent-migrations"])])

    @inlineCallbacks
    def migrate_pages(self, index_page, emit_progress):
        dry_run = self.options["dry-run"]
        progress = ProgressEmitter(
            emit_progress, self.options["index-page-size"])
        processed = 0
        while index_page is not None:
            if index_page.has_next_page():
                next_page_d = index_page.next_page()
            else:
                next_page_d = succeed(None)
            keys = list(index_page)
            yield self.migrate_page(keys, dry_run)
            processed += len(keys)
            progress.update(processed)
            continuation = getattr(index_page, 'continuation', None)
            if continuation is not None:
                self.emit("Continuation token: '%s'" % (continuation,))
            index_page = yield next_page_d
        self.emit("Done, %s object%s migrated." % (
            processed, "" if processed == 1 else "s"))

    def migrate_specified_keys(self, keys):
        """
        Migrate specified keys.
        """
        self.emit("Migrating %d specified keys ..." % len(keys))
        emit_progress = lambda t: self.emit(
            "%s of %s objects migrated." % (t, len(keys)))
        index_page = FakeIndexPage(keys, self.options["index-page-size"])
        return self.migrate_pages(index_page, emit_progress)

    @inlineCallbacks
    def migrate_all_keys(self, continuation=None):
        """
        Perform an index query to get all keys and migrate them.

        If `continuation` is provided, it will be used as the starting point
        for the query.
        """
        self.emit("Migrating ...")
        emit_progress = lambda t: self.emit(
            "%s object%s migrated." % (t, "" if t == 1 else "s"))
        index_page = yield self.model.all_keys_page(
            max_results=self.options["index-page-size"],
            continuation=continuation)
        yield self.migrate_pages(index_page, emit_progress)

    def _run(self):
        if self.options["keys"] is not None:
            return self.migrate_specified_keys(self.options["keys"].split(","))
        else:
            return self.migrate_all_keys(self.options["continuation-token"])

    def run(self):
        return self._run().addBoth(lambda _: self.cleanup())


def main(_reactor, name, *args):
    try:
        options = Options()
        options.parseOptions(args)
    except usage.UsageError, errortext:
        print '%s: %s' % (name, errortext)
        print '%s: Try --help for usage details.' % (name,)
        sys.exit(1)

    model_migrator = ModelMigrator(options)
    return model_migrator.run()


if __name__ == '__main__':
    react(main, sys.argv)

# ==== vumi/connectors.py ====
from twisted.internet.defer import gatherResults, inlineCallbacks, returnValue

from vumi import log
from vumi.middleware import MiddlewareStack
from vumi.message import (
    TransportMessage, TransportEvent, TransportUserMessage, TransportStatus)


class IgnoreMessage(Exception):
    pass


class BaseConnector(object):
    """Base class for 'connector' objects.

    A connector encapsulates the 'inbound', 'outbound' and 'event' publishers
    and consumers required by vumi workers and avoids having to operate on them
    individually all over the place.
    """
    def __init__(self, worker, connector_name, prefetch_count=None,
                 middlewares=None):
        self.name = connector_name
        self.worker = worker
        self._consumers = {}
        self._publishers = {}
        self._endpoint_handlers = {}
        self._default_handlers = {}
        self._prefetch_count = prefetch_count
        self._middlewares = MiddlewareStack(middlewares
                                            if middlewares is not None else [])

    def _rkey(self, mtype):
        return '%s.%s' % (self.name, mtype)

    def setup(self):
        raise NotImplementedError()

    def teardown(self):
        d = gatherResults([c.stop() for c in self._consumers.values()])
        d.addCallback(lambda r: self._middlewares.teardown())
        return d

    @property
    def paused(self):
        return all(consumer.paused
                   for consumer in self._consumers.itervalues())

    def pause(self):
        return gatherResults([
            consumer.pause() for consumer in self._consumers.itervalues()])

    def unpause(self):
        # This doesn't return a deferred.
        for consumer in self._consumers.values():
            consumer.unpause()

    @inlineCallbacks
    def _setup_publisher(self, mtype):
        publisher = yield self.worker.publish_to(self._rkey(mtype))
        self._publishers[mtype] = publisher
        returnValue(publisher)

    @inlineCallbacks
    def _setup_consumer(self, mtype, msg_class, default_handler):
        def handler(msg):
            return self._consume_message(mtype, msg)

        consumer = yield self.worker.consume(
            self._rkey(mtype), handler, message_class=msg_class, paused=True,
            prefetch_count=self._prefetch_count)
        self._consumers[mtype] = consumer
        self._set_default_endpoint_handler(mtype, default_handler)
        returnValue(consumer)

    def _set_endpoint_handler(self, mtype, handler, endpoint_name):
        if endpoint_name is None:
            endpoint_name = TransportMessage.DEFAULT_ENDPOINT_NAME
        handlers = self._endpoint_handlers.setdefault(mtype, {})
        handlers[endpoint_name] = handler

    def _set_default_endpoint_handler(self, mtype, handler):
        self._endpoint_handlers.setdefault(mtype, {})
        self._default_handlers[mtype] = handler

    def _consume_message(self, mtype, msg):
        endpoint_name = msg.get_routing_endpoint()
        handler = self._endpoint_handlers[mtype].get(endpoint_name)
        if handler is None:
            handler = self._default_handlers.get(mtype)
        d = self._middlewares.apply_consume(mtype, msg, self.name)
        d.addCallback(handler)
        return d.addErrback(self._ignore_message, msg)

    def _publish_message(self, mtype, msg, endpoint_name):
        if endpoint_name is not None:
            msg.set_routing_endpoint(endpoint_name)
        d = self._middlewares.apply_publish(mtype, msg, self.name)
        return d.addCallback(self._publishers[mtype].publish_message)

    def _ignore_message(self, failure, msg):
        failure.trap(IgnoreMessage)
        log.debug("Ignoring msg due to %r: %r" % (failure.value, msg))


class ReceiveInboundConnector(BaseConnector):
    def setup(self):
        outbound_d = self._setup_publisher('outbound')
        inbound_d = self._setup_consumer('inbound', TransportUserMessage,
                                         self.default_inbound_handler)
        event_d = self._setup_consumer('event', TransportEvent,
                                       self.default_event_handler)
        return gatherResults([outbound_d, inbound_d, event_d])

    def default_inbound_handler(self, msg):
        log.warning("No inbound handler for %r: %r" % (self.name, msg))

    def default_event_handler(self, msg):
        log.warning("No event handler for %r: %r" % (self.name, msg))

    def set_inbound_handler(self, handler, endpoint_name=None):
        self._set_endpoint_handler('inbound', handler, endpoint_name)

    def set_default_inbound_handler(self, handler):
        self._set_default_endpoint_handler('inbound', handler)

    def set_event_handler(self, handler, endpoint_name=None):
        self._set_endpoint_handler('event', handler, endpoint_name)

    def set_default_event_handler(self, handler):
        self._set_default_endpoint_handler('event', handler)

    def publish_outbound(self, msg, endpoint_name=None):
        return self._publish_message('outbound', msg, endpoint_name)
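

# A minimal usage sketch for ReceiveInboundConnector (hypothetical wiring;
# `worker` is assumed to provide the `consume` and `publish_to` methods that
# BaseConnector relies on, and `handle_inbound` is a placeholder callback):
#
#     connector = ReceiveInboundConnector(worker, "sms_transport")
#     yield connector.setup()
#     connector.set_inbound_handler(handle_inbound)
#     connector.unpause()
#     ...
#     yield connector.publish_outbound(reply_msg)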


class ReceiveOutboundConnector(BaseConnector):
    def setup(self):
        inbound_d = self._setup_publisher('inbound')
        event_d = self._setup_publisher('event')
        outbound_d = self._setup_consumer('outbound', TransportUserMessage,
                                          self.default_outbound_handler)
        return gatherResults([outbound_d, inbound_d, event_d])

    def default_outbound_handler(self, msg):
        log.warning("No outbound handler for %r: %r" % (self.name, msg))

    def set_outbound_handler(self, handler, endpoint_name=None):
        self._set_endpoint_handler('outbound', handler, endpoint_name)

    def set_default_outbound_handler(self, handler):
        self._set_default_endpoint_handler('outbound', handler)

    def publish_inbound(self, msg, endpoint_name=None):
        return self._publish_message('inbound', msg, endpoint_name)

    def publish_event(self, msg, endpoint_name=None):
        return self._publish_message('event', msg, endpoint_name)

    def _ignore_message(self, failure, msg):
        failure.trap(IgnoreMessage)
        log.debug("Ignoring msg (with NACK) due to %r: %r" % (
            failure.value, msg))
        return self.publish_event(TransportEvent(
            user_message_id=msg['message_id'], nack_reason=str(failure.value),
            event_type='nack'))


class PublishStatusConnector(BaseConnector):
    @inlineCallbacks
    def setup(self):
        yield self._setup_publisher('status')

    def publish_status(self, msg):
        return self._publish_message('status', msg, endpoint_name=None)


class ReceiveStatusConnector(BaseConnector):
    @inlineCallbacks
    def setup(self):
        yield self._setup_consumer(
            'status', TransportStatus, self.default_status_handler)

    def default_status_handler(self, msg):
        log.warning("No status handler for %r: %r" % (self.name, msg))

    def set_status_handler(self, handler):
        self._set_default_endpoint_handler('status', handler)

# ==== vumi/rpc.py ====
# -*- coding: utf-8 -*-

"""Utilities for marking up RPC methods."""

import inspect
import textwrap
import functools
import itertools

from twisted.internet.defer import Deferred


class RpcCheckError(Exception):
    """Raised when a value fails a type check."""


class Signature(object):

    NO_DEFAULT = object()
    NO_ARG = object()

    def __init__(self, f, returns=None, requires_self=True, **kw):
        self.returns = returns if returns is not None else Null()
        self.requires_self = requires_self
        self.params = kw
        self.argspec = inspect.getargspec(f)
        self.defaults = [self.NO_DEFAULT] * (
            len(self.argspec.args) - len(self.argspec.defaults or ()))
        self.defaults += list(self.argspec.defaults or ())

    def check_params(self, args, kw):
        if kw:
            raise RpcCheckError("Keyword parameters not yet supported.")
        if len(args) > len(self.argspec.args):
            raise RpcCheckError("Too many positional arguments.")

        missing_arg_count = len(self.argspec.args) - len(args)
        args = list(args) + [self.NO_ARG] * missing_arg_count
        arg_tuples = itertools.izip(self.argspec.args, self.defaults, args)
        if self.requires_self:
            next(arg_tuples)

        for arg_name, default, arg_value in arg_tuples:
            if arg_value is self.NO_ARG:
                arg_value = default
            if arg_value is self.NO_DEFAULT:
                raise RpcCheckError("Positional argument %r missing"
                                    " but no default is available." % arg_name)
            arg_type = self.params[arg_name]
            arg_type.check(arg_name, arg_value)

    def check_result(self, result):
        self.returns.check('return value', result)
        return result

    def _wrap_help(self, help_text):
        indent = '    '
        return textwrap.wrap(help_text, initial_indent=indent,
                             subsequent_indent=indent)

    def _format_param(self, param_name, param_type, default):
        lines = [":param %s %s:" % (param_type.name, param_name)]
        help_text = param_type.help()
        if param_type.nullable():
            help_text += " May be null."
        if default is not self.NO_DEFAULT:
            help_text += " Default: %r." % (default,)
        lines.extend(self._wrap_help(help_text))
        return lines

    def _format_return(self, param_type):
        lines = [":rtype %s:" % (param_type.name,)]
        lines.extend(self._wrap_help(param_type.help()))
        return lines

    def _args_with_defaults(self):
        args_defaults = itertools.izip(self.argspec.args, self.defaults)
        if self.requires_self:
            next(args_defaults)

        for arg, default in args_defaults:
            yield arg, self.params[arg], default

    def param_doc(self):
        lines = []
        for arg, arg_type, default in self._args_with_defaults():
            lines.extend(self._format_param(arg, self.params[arg], default))
        lines.extend(self._format_return(self.returns))
        return lines

    def jsonrpc_signature(self):
        sig = [self.returns.jsonrpc_type]
        sig.extend(arg_type.jsonrpc_type for _, arg_type, _
                   in self._args_with_defaults())
        return [sig]


def signature(**kw):
    def decorator(f):
        sig = Signature(f, **kw)

        def wrapper(*args, **kw):
            sig.check_params(args, kw)
            result = f(*args, **kw)
            if isinstance(result, Deferred):
                result.addCallback(sig.check_result)
            else:
                sig.check_result(result)
            return result

        functools.update_wrapper(wrapper, f)
        doc = textwrap.wrap(wrapper.__doc__ or '')
        doc.append("")
        doc.extend(sig.param_doc())
        wrapper.__doc__ = "\n".join(doc)
        wrapper.signature = sig.jsonrpc_signature()
        wrapper.signature_object = sig
        return wrapper

    return decorator
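

# A usage sketch for the signature decorator (the method and help strings are
# invented; Unicode is defined further down in this module):
#
#     class Api(object):
#         @signature(name=Unicode(help="Name to greet."),
#                    returns=Unicode(help="A greeting."))
#         def greet(self, name):
#             "Greet someone by name."
#             return u"Hello, %s!" % (name,)
#
# Api().greet(u"world") passes both checks; Api().greet(5) raises
# RpcCheckError because 5 is not a unicode string.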


class RpcType(object):

    # See: http://xmlrpc.scripting.com/spec.html
    # valid simple types are:
    #    int, boolean, string, double, base64 and dateTime.iso8601
    # valid compound types are:
    #    array, struct
    jsonrpc_type = None

    def __init__(self, help=None, null=False):
        self._help = help
        self._null = null

    @property
    def name(self):
        return self.__class__.__name__

    def help(self):
        return self._help or ''

    def nullable(self):
        return self._null

    def check(self, name, value):
        if value is None:
            if not self._null:
                raise RpcCheckError("%s may not be None (got None)" % (name,))
            return
        self.nonnull_check(name, value)

    def nonnull_check(self, name, value):
        raise RpcCheckError("The base class RpcType accepts no values.")


class Null(RpcType):
    jsonrpc_type = 'null'

    def __init__(self, *args, **kw):
        kw.setdefault('null', True)
        super(Null, self).__init__(*args, **kw)

    def nonnull_check(self, name, value):
        if value is not None:
            raise RpcCheckError("Null value expected for %s (got %r)"
                                % (name, value))


class Unicode(RpcType):
    jsonrpc_type = 'string'

    def nonnull_check(self, name, value):
        if not isinstance(value, unicode):
            raise RpcCheckError("Unicode value expected for %s (got %r)"
                                % (name, value))


class Int(RpcType):
    jsonrpc_type = 'int'

    def nonnull_check(self, name, value):
        if not isinstance(value, (int, long)):
            raise RpcCheckError("Int value expected for %s (got %r)"
                                % (name, value))


class List(RpcType):
    jsonrpc_type = 'array'

    def __init__(self, *args, **kw):
        self._item_type = kw.pop('item_type', None)
        self._length = kw.pop('length', None)
        super(List, self).__init__(*args, **kw)

    def nonnull_check(self, name, value):
        if not isinstance(value, list):
            raise RpcCheckError("List value expected for %s (got %r)"
                                % (name, value))
        if self._length is not None and len(value) != self._length:
            raise RpcCheckError("List value for %s expected to have"
                                " length %d (got %r)"
                                % (name, self._length, value))
        if self._item_type is not None:
            item_name = 'items of %s' % (name,)
            for item in value:
                self._item_type.check(item_name, item)


class Dict(RpcType):
    jsonrpc_type = 'struct'

    def __init__(self, *args, **kw):
        self._item_type = kw.pop('item_type', None)
        self._required_fields = kw.pop('required_fields', {})
        self._optional_fields = kw.pop('optional_fields', {})
        self._closed = kw.pop('closed', False)
        self._no_checks = all(not x for x in (
            self._item_type, self._required_fields, self._optional_fields,
            self._closed))
        super(Dict, self).__init__(*args, **kw)

    def nonnull_check(self, name, value):
        if not isinstance(value, dict):
            raise RpcCheckError("Dict value expected for %s (got %r)"
                                % (name, value))
        if self._no_checks:
            return
        for key in value:
            field_type = self._required_fields.get(key)
            field_type = (self._optional_fields.get(key)
                          if field_type is None else field_type)
            if field_type is None:
                if self._closed:
                    raise RpcCheckError("Dict received unexpected key %s"
                                        " (got %r)" % (key, value))
                field_type = self._item_type
            if field_type is not None:
                field_type.check('item %s of %s' % (key, name), value[key])
        for key in self._required_fields:
            if key not in value:
                raise RpcCheckError("Dict requires key %s (got %r)"
                                    % (key, value))
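

# A sketch of how Dict checks behave (field names invented for illustration):
#
#     d = Dict(required_fields={'name': Unicode()},
#              optional_fields={'age': Int()}, closed=True)
#     d.check('opts', {u'name': u'Bob', u'age': 30})   # passes
#     d.check('opts', {u'age': 30})        # raises RpcCheckError (no 'name')
#     d.check('opts', {u'name': u'Bob', u'extra': 1})  # raises (closed dict)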


class Tag(RpcType):
    jsonrpc_type = 'array'

    def nonnull_check(self, name, value):
        if not isinstance(value, (list, tuple)):
            raise RpcCheckError("Tag %s must be a list or tuple (got %r)"
                                % (name, value))
        if len(value) != 2:
            raise RpcCheckError("Tag %s must contain two elements, a pool name"
                                " and a tag name (got %r)"
                                % (name, value))
        for item in value:
            if not isinstance(item, unicode):
                raise RpcCheckError("Tag %s must have unicode pool and tag"
                                    " name (got %r)" % (name, value))

# ==== vumi/utils.py ====
# -*- test-case-name: vumi.tests.test_utils -*-

import os.path
import re
import sys
import base64
import pkg_resources
import warnings
from functools import wraps

from zope.interface import implements
from twisted.internet import defer
from twisted.internet import protocol
from twisted.internet.defer import succeed
from twisted.python.failure import Failure
from twisted.web.client import Agent, ResponseDone, WebClientContextFactory
from twisted.web.server import Site
from twisted.web.http_headers import Headers
from twisted.web.iweb import IBodyProducer
from twisted.web.http import PotentialDataLoss
from twisted.web.resource import Resource

from vumi.errors import VumiError


# Stop Agent from logging two useless lines for every request.
# This is hacky, but there's no better way to do it for now.
from twisted.web import client
client._HTTP11ClientFactory.noisy = False


def import_module(name):
    """
    This is a simpler version of `importlib.import_module` and does
    not support relative imports.

    It's here so that we can avoid using importlib and not have to
    juggle different deps between Python versions.
    """
    __import__(name)
    return sys.modules[name]


def to_kwargs(kwargs):
    """
    Convert top-level keys from unicode to string for older Python versions.

    See http://bugs.python.org/issue2646 for details.
    """
    return dict((k.encode('utf8'), v) for k, v in kwargs.iteritems())


class HttpError(VumiError):
    """Base class for errors raised by http_request_full."""


class HttpDataLimitError(VumiError):
    """Raised by http_request_full if too much data is received."""


class HttpTimeoutError(VumiError):
    """Raised by http_request_full if the request times out."""


class SimplishReceiver(protocol.Protocol):
    def __init__(self, response, data_limit=None):
        self.deferred = defer.Deferred(canceller=self.cancel_on_timeout)
        self.response = response
        self.data_limit = data_limit
        self.data_recvd_len = 0
        self.response.delivered_body = ''
        if response.code == 204:
            self.deferred.callback(self.response)
        else:
            response.deliverBody(self)

    def cancel_on_timeout(self, d):
        self.cancel_receiving(HttpTimeoutError("Timeout while receiving data"))

    def cancel_on_data_limit(self):
        self.cancel_receiving(HttpDataLimitError(
            "More than %d bytes received" % (self.data_limit,)))

    def cancel_receiving(self, err):
        self.transport.stopProducing()
        self.deferred.errback(err)

    def data_limit_exceeded(self):
        return (self.data_limit is not None and
                self.data_recvd_len > self.data_limit)

    def dataReceived(self, data):
        self.data_recvd_len += len(data)
        if self.data_limit_exceeded():
            self.cancel_on_data_limit()
        self.response.delivered_body += data

    def connectionLost(self, reason):
        if self.deferred.called:
            # this happens when the deferred is cancelled and this
            # triggers connection closing
            return
        if reason.check(ResponseDone):
            self.deferred.callback(self.response)
        elif reason.check(PotentialDataLoss):
            # This is only (and always!) raised if we get a response with no
            # Content-Length header and no other way of determining when the
            # response body is finished (such as chunked transfer encoding).
            # See http://twistedmatrix.com/trac/ticket/4840 for sadness.
            #
            # We ignore this and treat the call as success. If we care about
            # checking for potential data loss, we should do that in all cases
            # rather than trying to figure out if we might need to.
            self.deferred.callback(self.response)
        else:
            self.deferred.errback(reason)


def http_request_full(url, data=None, headers={}, method='POST',
                      timeout=None, data_limit=None, context_factory=None,
                      agent_class=None, reactor=None):
    if reactor is None:
        # The import replaces the local variable.
        from twisted.internet import reactor
    if agent_class is None:
        agent_class = Agent
    context_factory = context_factory or WebClientContextFactory()
    agent = agent_class(reactor, contextFactory=context_factory)
    d = agent.request(method,
                      url,
                      mkheaders(headers),
                      StringProducer(data) if data else None)

    def handle_response(response):
        return SimplishReceiver(response, data_limit).deferred

    d.addCallback(handle_response)

    if timeout is not None:
        cancelling_on_timeout = [False]

        def raise_timeout(reason):
            if not cancelling_on_timeout[0] or reason.check(HttpTimeoutError):
                return reason
            return Failure(HttpTimeoutError("Timeout while connecting"))

        def cancel_on_timeout():
            cancelling_on_timeout[0] = True
            d.cancel()

        def cancel_timeout(r, delayed_call):
            if delayed_call.active():
                delayed_call.cancel()
            return r

        d.addErrback(raise_timeout)
        delayed_call = reactor.callLater(timeout, cancel_on_timeout)
        d.addCallback(cancel_timeout, delayed_call)

    return d
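

# Usage sketch (assumes a running reactor; the URL and payload are
# placeholders):
#
#     d = http_request_full(
#         'http://example.com/api', data='{"x": 1}',
#         headers={'Content-Type': 'application/json'},
#         method='POST', timeout=30, data_limit=1024 * 1024)
#     d.addCallback(lambda response: response.delivered_body)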


def mkheaders(headers):
    """
    Turn a dict of HTTP headers into an instance of Headers.

    Twisted expects a list of values, not a single value. We should
    support both.
    """
    raw_headers = {}
    for k, v in headers.iteritems():
        if isinstance(v, basestring):
            v = [v]
        raw_headers[k] = v
    return Headers(raw_headers)


def http_request(url, data, headers={}, method='POST', agent_class=None):
    d = http_request_full(
        url, data, headers=headers, method=method, agent_class=agent_class)
    return d.addCallback(lambda r: r.delivered_body)


def basic_auth_string(username, password):
    """
    Encode a username and password for use in an HTTP Basic Authentication
    header.
    """
    b64 = base64.encodestring('%s:%s' % (username, password)).strip()
    return 'Basic %s' % b64


def normalize_msisdn(raw, country_code=''):
    # don't touch shortcodes
    if len(raw) <= 5:
        return raw

    raw = ''.join([c for c in raw if c.isdigit() or c == '+'])
    if raw.startswith('00'):
        return '+' + raw[2:]
    if raw.startswith('0'):
        return '+' + country_code + raw[1:]
    if raw.startswith('+'):
        return raw
    if raw.startswith(country_code):
        return '+' + raw
    return raw
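

# Examples of the normalisation above (derived from the branches in
# normalize_msisdn; the numbers are placeholders):
#
#     normalize_msisdn('0821234567', country_code='27')  # -> '+27821234567'
#     normalize_msisdn('0027821234567')                  # -> '+27821234567'
#     normalize_msisdn('12345')                          # -> '12345' (shortcode)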


class StringProducer(object):
    """
    For various twisted.web mechanics we need a producer to produce
    content for HTTP requests. This is a helper class to quickly
    create a producer for a bit of content.
    """
    implements(IBodyProducer)

    def __init__(self, body):
        self.body = body
        self.length = len(body)

    def startProducing(self, consumer):
        consumer.write(self.body)
        return succeed(None)

    def pauseProducing(self):
        pass

    def stopProducing(self):
        pass


def build_web_site(resources, site_class=None):
    """Build a Twisted web Site instance for a specified dictionary of
    resources.

    :param dict resources:
        Dictionary of path -> resource class mappings to create the site from.
    :type site_class: Sub-class of Twisted's Site
    :param site_class:
        Site class to create. Defaults to :class:`LogFilterSite`.
    """
    if site_class is None:
        site_class = LogFilterSite

    root = Resource()
    # sort by ascending path length to make sure we create
    # resources lower down in the path earlier
    resources = resources.items()
    resources = sorted(resources, key=lambda r: len(r[0]))

    def create_node(node, path):
        if path in node.children:
            return node.children.get(path)
        else:
            new_node = Resource()
            node.putChild(path, new_node)
            return new_node

    for path, resource in resources:
        request_path = filter(None, path.split('/'))
        nodes, leaf = request_path[0:-1], request_path[-1]
        parent = reduce(create_node, nodes, root)
        parent.putChild(leaf, resource)

    site_factory = site_class(root)
    return site_factory
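

# Usage sketch for build_web_site (the paths and resource instances are
# hypothetical placeholders):
#
#     site = build_web_site({
#         'health': health_resource,
#         'api/v1/messages': messages_resource,
#     })
#     reactor.listenTCP(8080, site)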


class LogFilterSite(Site):
    def log(self, request):
        if getattr(request, 'do_not_log', None):
            return
        return Site.log(self, request)


class PkgResources(object):
    """
    A helper for accessing a package's data files.

    :param str modname:
        The full dotted name of the module. E.g.
        ``vumi.resources``.
    """
    def __init__(self, modname):
        self.modname = modname

    def path(self, path):
        """
        Return the absolute path to a package resource.

        If path is already absolute, it is returned unmodified.

        :param str path:
            The relative or absolute path to the resource.
        """
        if os.path.isabs(path):
            return path
        return pkg_resources.resource_filename(self.modname, path)


vumi_resource_path = PkgResources("vumi.resources").path


def load_class(module_name, class_name):
    """
    Load a class when given its module and its class name.

    >>> load_class('vumi.workers.example','ExampleWorker') # doctest: +ELLIPSIS
    <class vumi.workers.example.ExampleWorker at ...>
    >>>

    """
    mod = import_module(module_name)
    return getattr(mod, class_name)


def load_class_by_string(class_path):
    """
    Load a class when given its full name, including modules in Python
    dot notation.

    >>> cls = 'vumi.workers.example.ExampleWorker'
    >>> load_class_by_string(cls) # doctest: +ELLIPSIS
    <class vumi.workers.example.ExampleWorker at ...>
    >>>

    """
    parts = class_path.split('.')
    module_name = '.'.join(parts[:-1])
    class_name = parts[-1]
    return load_class(module_name, class_name)


def redis_from_config(redis_config):
    """
    Return a redis client instance from a config.

    If redis_config:

    * equals 'FAKE_REDIS', a new instance of :class:`FakeRedis` is returned.
    * is an instance of :class:`FakeRedis`, that instance is returned.

    Otherwise a new real redis client is returned.
    """
    warnings.warn("Use of redis directly is deprecated. Use vumi.persist "
                  "instead.", category=DeprecationWarning)

    import redis
    from vumi.persist import fake_redis
    if redis_config == "FAKE_REDIS":
        return fake_redis.FakeRedis()
    if isinstance(redis_config, fake_redis.FakeRedis):
        return redis_config
    return redis.Redis(**redis_config)


def flatten_generator(generator_func):
    """
    This is a synchronous version of @inlineCallbacks.

    NOTE: It doesn't correctly handle returnValue() being called in a
    non-decorated function called from the function we're decorating. We could
    copy the Twisted code to do that, but it's messy.
    """
    @wraps(generator_func)
    def wrapped(*args, **kw):
        gen = generator_func(*args, **kw)
        result = None
        while True:
            try:
                result = gen.send(result)
            except StopIteration:
                # Fell off the end, or "return" statement.
                return None
            except defer._DefGen_Return, e:
                # returnValue() called.
                return e.value

    return wrapped
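

# Usage sketch for flatten_generator (the decorated function is illustrative):
#
#     @flatten_generator
#     def add_one(x):
#         y = yield x   # plain values come straight back from gen.send()
#         defer.returnValue(y + 1)
#
#     add_one(5)  # -> 6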


def filter_options_on_prefix(options, prefix, delimiter='-'):
    """
    Split an options dict based on key prefixes.

    >>> filter_options_on_prefix({'foo-bar-1': 'ok'}, 'foo')
    {'bar-1': 'ok'}
    >>>

    """
    return dict((key.split(delimiter, 1)[1], value)
                for key, value in options.items()
                if key.startswith(prefix))


def get_first_word(content, delimiter=' '):
    """
    Returns the first word from a string.

    Example::

      >>> get_first_word('KEYWORD rest of message')
      'KEYWORD'

    :type content: str or None
    :param content:
        Content from which the first word will be retrieved. If the
        content is None it is treated as an empty string (this is a
        convenience for dealing with content-less messages).
    :param str delimiter:
        Delimiter to split the string on. Default is ' '.
        Passed to :func:`string.partition`.
    :returns:
        A string containing the first word.
    """
    return (content or '').partition(delimiter)[0]


def cleanup_msisdn(number, country_code):
    """Strip any '+' characters and replace a leading '0' with the
    country code."""
    number = re.sub(r'\+', '', number)
    number = re.sub('^0', country_code, number)
    return number


def get_operator_name(msisdn, mapping):
    for key, value in mapping.items():
        if msisdn.startswith(str(key)):
            if isinstance(value, dict):
                return get_operator_name(msisdn, value)
            return value
    return 'UNKNOWN'


def get_operator_number(msisdn, country_code, mapping, numbers):
    msisdn = cleanup_msisdn(msisdn, country_code)
    operator = get_operator_name(msisdn, mapping)
    number = numbers.get(operator)
    return number
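
# Sketch of the nested prefix lookup (prefixes and codes hypothetical):
#
#     mapping = {'27': {'2782': 'VODACOM', '2783': 'MTN'}}
#     numbers = {'VODACOM': '*120*1#', 'MTN': '*130*1#'}
#     get_operator_number('0821234567', '27', mapping, numbers)  # -> '*120*1#'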


def safe_routing_key(routing_key):
    """
    >>> safe_routing_key(u'*32323#')
    u's32323h'
    >>>

    """
    return reduce(lambda r_key, kv: r_key.replace(*kv),
                  [('*', 's'), ('#', 'h')], routing_key)


def generate_worker_id(system_id, worker_id):
    return "%s:%s" % (system_id, worker_id,)


class StatusEdgeDetector(object):
    '''Assists with finding if a TransportStatus is a change in the status,
    compared to previous statuses, or just a repeat. Will be useful to only
    publish statuses on state change.'''

    def __init__(self):
        self.state = dict()
        self.types = dict()

    def check_status(self, **status):
        '''
        Checks to see if the current status is a repeat. If it is, None is
        returned. If it isn't, the status is returned.

        :param status: The status to check.
        :type status: :class:`TransportStatus`
        '''
        self._check_state(status['status'], status['component'])
        if self._check_type(status['type'], status['component']):
            return status

    def _get_state(self, component):
        return self.state.get(component, None)

    def _set_state(self, component, state):
        self.state[component] = state

    def _get_types(self, component):
        return self.types.get(component, set())

    def _add_type(self, component, type_):
        if component not in self.types:
            self.types[component] = set()
        self.types[component].add(type_)

    def _remove_types(self, component):
        self.types.pop(component, None)

    def _check_state(self, status, component):
        state = self._get_state(component)
        if state != status:
            self._remove_types(component)
            self._set_state(component, status)

    def _check_type(self, type_, component):
        types = self._get_types(component)
        if type_ not in types:
            self._add_type(component, type_)
            return True
        return False
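
# Illustrative edge detection (field values hypothetical):
#
#     sed = StatusEdgeDetector()
#     s = {'component': 'smpp', 'status': 'down',
#          'type': 'connection_lost', 'message': 'connection lost'}
#     sed.check_status(**s)  # -> s: first time this status/type is seen
#     sed.check_status(**s)  # -> None: same status and type, a repeat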

vumi/message.py
# -*- test-case-name: vumi.tests.test_message -*-

import json
from uuid import uuid4
from datetime import datetime

from vumi.errors import MissingMessageField, InvalidMessageField

from vumi.utils import to_kwargs


# This is the date format we work with internally
VUMI_DATE_FORMAT = "%Y-%m-%d %H:%M:%S.%f"
# Same as above, but without microseconds (for more permissive parsing).
_VUMI_DATE_FORMAT_NO_MICROSECONDS = "%Y-%m-%d %H:%M:%S"


def format_vumi_date(timestamp):
    """Format a datetime object using the Vumi date format.

    :param datetime timestamp:
        The datetime object to format.
    :return str:
        The timestamp formatted as a string.
    """
    return timestamp.strftime(VUMI_DATE_FORMAT)


def parse_vumi_date(value):
    """Parse a timestamp string using the Vumi date format.

    Timestamps without microseconds are also parsed correctly.

    :param str value:
        The string to parse.
    :return datetime:
        A datetime object representing the timestamp.
    """
    date_format = VUMI_DATE_FORMAT
    # We only look at the last ten characters, because that's where the "."
    # will be in a valid serialised timestamp with microseconds.
    if "." not in value[-10:]:
        date_format = _VUMI_DATE_FORMAT_NO_MICROSECONDS
    return datetime.strptime(value, date_format)
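
# Round-trip sketch: parsing tolerates missing microseconds, while
# formatting always includes them:
#
#     parse_vumi_date("2015-01-01 12:00:00")  # -> datetime(2015, 1, 1, 12, 0)
#     format_vumi_date(parse_vumi_date("2015-01-01 12:00:00"))
#     # -> '2015-01-01 12:00:00.000000'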


def date_time_decoder(json_object):
    for key, value in json_object.items():
        try:
            json_object[key] = parse_vumi_date(value)
        except ValueError:
            continue
        except TypeError:
            continue
    return json_object


class JSONMessageEncoder(json.JSONEncoder):
    """A JSON encoder that is able to serialize datetime"""
    def default(self, obj):
        if isinstance(obj, datetime):
            return format_vumi_date(obj)
        return super(JSONMessageEncoder, self).default(obj)


def from_json(json_string):
    return json.loads(json_string, object_hook=date_time_decoder)


def to_json(obj):
    return json.dumps(obj, cls=JSONMessageEncoder)
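
# Round-trip sketch: datetimes survive because JSONMessageEncoder writes
# them in VUMI_DATE_FORMAT and date_time_decoder parses any value that
# looks like such a timestamp:
#
#     data = {'timestamp': datetime(2015, 1, 1, 12, 0, 0)}
#     from_json(to_json(data)) == data  # -> True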


class Message(object):
    """
    A unified message object used by Vumi when transmitting messages over AMQP
    and occasionally as a standardised JSON format for use in external APIs.

    The special ``.cache`` property stores a dictionary of data that is not
    stored by the :class:`vumi.fields.VumiMessage` field and hence not stored
    by Vumi's message store.
    """

    # name of the special attribute that isn't stored by the message store
    _CACHE_ATTRIBUTE = "__cache__"

    def __init__(self, _process_fields=True, **kwargs):
        if _process_fields:
            kwargs = self.process_fields(kwargs)
        self.payload = kwargs
        self.validate_fields()

    def process_fields(self, fields):
        return fields

    def validate_fields(self):
        pass

    def assert_field_present(self, *fields):
        for field in fields:
            if field not in self.payload:
                raise MissingMessageField(field)

    def assert_field_value(self, field, *values):
        self.assert_field_present(field)
        if self.payload[field] not in values:
            raise InvalidMessageField(field)

    def to_json(self):
        return to_json(self.payload)

    @classmethod
    def from_json(cls, json_string):
        return cls(_process_fields=False, **to_kwargs(from_json(json_string)))

    def __str__(self):
        return u"" % repr(self.payload)

    def __repr__(self):
        return str(self)

    def __eq__(self, other):
        if isinstance(other, Message):
            return self.payload == other.payload
        return False

    def __contains__(self, key):
        return key in self.payload

    def __getitem__(self, key):
        return self.payload[key]

    def __setitem__(self, key, value):
        self.payload[key] = value

    def get(self, key, default=None):
        return self.payload.get(key, default)

    def items(self):
        return self.payload.items()

    def copy(self):
        return self.from_json(self.to_json())

    @property
    def cache(self):
        """
        A special payload attribute that isn't stored by the message store.
        """
        return self.payload.setdefault(self._CACHE_ATTRIBUTE, {})
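
# Minimal usage sketch:
#
#     msg = Message(content='hello')
#     msg['content']                           # -> 'hello'
#     msg.cache['seen'] = True                 # skipped by the message store
#     Message.from_json(msg.to_json()) == msg  # -> True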


class TransportMessage(Message):
    """Common base class for messages sent to or from a transport."""

    # sub-classes should set the message type
    MESSAGE_TYPE = None
    MESSAGE_VERSION = '20110921'
    DEFAULT_ENDPOINT_NAME = 'default'

    @staticmethod
    def generate_id():
        """
        Generate a unique message id.

        There are places where we want a message id before we can
        build a complete message. This lets us do that in a consistent
        manner.
        """
        return uuid4().get_hex()

    def process_fields(self, fields):
        fields.setdefault('message_version', self.MESSAGE_VERSION)
        fields.setdefault('message_type', self.MESSAGE_TYPE)
        fields.setdefault('timestamp', datetime.utcnow())
        fields.setdefault('routing_metadata', {})
        fields.setdefault('helper_metadata', {})
        return fields

    def validate_fields(self):
        self.assert_field_value('message_version', self.MESSAGE_VERSION)
        # We might get older event messages without the `helper_metadata`
        # field.
        self.payload.setdefault('helper_metadata', {})
        self.assert_field_present(
            'message_type',
            'timestamp',
            'helper_metadata',
            )
        if self['message_type'] is None:
            raise InvalidMessageField('message_type')

    @property
    def routing_metadata(self):
        return self.payload.setdefault('routing_metadata', {})

    @classmethod
    def check_routing_endpoint(cls, endpoint_name):
        if endpoint_name is None:
            return cls.DEFAULT_ENDPOINT_NAME
        return endpoint_name

    def set_routing_endpoint(self, endpoint_name=None):
        endpoint_name = self.check_routing_endpoint(endpoint_name)
        self.routing_metadata['endpoint_name'] = endpoint_name

    def get_routing_endpoint(self):
        endpoint_name = self.routing_metadata.get('endpoint_name')
        return self.check_routing_endpoint(endpoint_name)
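
    # Endpoint sketch ('sms_1' is hypothetical): an unset endpoint falls
    # back to DEFAULT_ENDPOINT_NAME:
    #
    #     msg.set_routing_endpoint('sms_1')
    #     msg.get_routing_endpoint()  # -> 'sms_1'; 'default' if never set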


class TransportUserMessage(TransportMessage):
    """Message to or from a user.

    transport_type = sms, ussd, etc
    helper_metadata = for use by dispatchers and off-to-the-side
                      components like failure workers (not for use
                      by transports or message workers).
    """

    MESSAGE_TYPE = 'user_message'

    # session event constants
    #
    # SESSION_NONE, SESSION_NEW, SESSION_RESUME, and SESSION_CLOSE
    # may be sent from the transport to a worker. SESSION_NONE indicates
    # there is no relevant session for this message.
    #
    # SESSION_NONE and SESSION_CLOSE may be sent from the worker to
    # the transport. SESSION_NONE indicates any existing session
    # should be continued. SESSION_CLOSE indicates that any existing
    # session should be terminated after sending the message.
    SESSION_NONE, SESSION_NEW, SESSION_RESUME, SESSION_CLOSE = (
        None, 'new', 'resume', 'close')

    # list of valid session events
    SESSION_EVENTS = frozenset([SESSION_NONE, SESSION_NEW, SESSION_RESUME,
                                SESSION_CLOSE])

    # canonical transport types
    TT_HTTP_API = 'http_api'
    TT_IRC = 'irc'
    TT_TELNET = 'telnet'
    TT_TWITTER = 'twitter'
    TT_SMS = 'sms'
    TT_USSD = 'ussd'
    TT_XMPP = 'xmpp'
    TT_MXIT = 'mxit'
    TT_WECHAT = 'wechat'
    TRANSPORT_TYPES = set([TT_HTTP_API, TT_IRC, TT_TELNET, TT_TWITTER, TT_SMS,
                           TT_USSD, TT_XMPP, TT_MXIT, TT_WECHAT])

    AT_IRC_NICKNAME = 'irc_nickname'
    AT_TWITTER_HANDLE = 'twitter_handle'
    AT_MSISDN = 'msisdn'
    AT_GTALK_ID = 'gtalk_id'
    AT_JABBER_ID = 'jabber_id'
    AT_MXIT_ID = 'mxit_id'
    AT_WECHAT_ID = 'wechat_id'
    ADDRESS_TYPES = set([
        AT_IRC_NICKNAME, AT_TWITTER_HANDLE, AT_MSISDN, AT_GTALK_ID,
        AT_JABBER_ID, AT_MXIT_ID, AT_WECHAT_ID])

    def process_fields(self, fields):
        fields = super(TransportUserMessage, self).process_fields(fields)
        fields.setdefault('message_id', self.generate_id())
        fields.setdefault('in_reply_to', None)
        fields.setdefault('provider', None)
        fields.setdefault('session_event', None)
        fields.setdefault('content', None)
        fields.setdefault('transport_metadata', {})
        fields.setdefault('group', None)
        fields.setdefault('to_addr_type', None)
        fields.setdefault('from_addr_type', None)
        return fields

    def validate_fields(self):
        super(TransportUserMessage, self).validate_fields()
        # We might get older message versions without the `group` or `provider`
        # fields.
        self.payload.setdefault('group', None)
        self.payload.setdefault('provider', None)
        self.assert_field_present(
            'message_id',
            'to_addr',
            'from_addr',
            'in_reply_to',
            'session_event',
            'content',
            'transport_name',
            'transport_type',
            'transport_metadata',
            'group',
            'provider',
            )
        if self['session_event'] not in self.SESSION_EVENTS:
            raise InvalidMessageField("Invalid session_event %r"
                                      % (self['session_event'],))

    def user(self):
        return self['from_addr']

    def reply(self, content, continue_session=True, **kw):
        """Construct a reply message.

        The reply message will have its `to_addr` field set to the original
        message's `from_addr`. This means that even if the original message is
        directed to the group only (i.e. it has `to_addr` set to `None`), the
        reply will be directed to the sender of the original message.

        :meth:`reply` is suitable for constructing both one-to-one messages
        (such as SMS) and directed messages within a group chat (such as
        name-prefixed content in an IRC channel message).

        If `session_event` is provided in the keyword args,
        `continue_session` will be ignored.

        NOTE: Certain fields are required to come from the message being
              replied to and may not be overridden by this method:

              # If we're not using this addressing, we shouldn't be replying.
              'to_addr', 'from_addr', 'group', 'in_reply_to', 'provider'
              # These three belong together and are supposed to be opaque.
              'transport_name', 'transport_type', 'transport_metadata'

        FIXME: `helper_metadata` should *not* be copied to the reply message.
               We only do it here because a bunch of legacy code relies on it.
        """
        session_event = None if continue_session else self.SESSION_CLOSE

        for field in [
                # If we're not using this addressing, we shouldn't be replying.
                'to_addr', 'from_addr', 'group', 'in_reply_to', 'provider',
                # These three belong together and are supposed to be opaque.
                'transport_name', 'transport_type', 'transport_metadata']:
            if field in kw:
                # Other "bad keyword argument" conditions cause TypeErrors.
                raise TypeError("'%s' may not be overridden." % (field,))

        fields = {
            'helper_metadata': self['helper_metadata'],  # XXX: See above.
            'session_event': session_event,
            'to_addr': self['from_addr'],
            'from_addr': self['to_addr'],
            'group': self['group'],
            'in_reply_to': self['message_id'],
            'provider': self['provider'],
            'transport_name': self['transport_name'],
            'transport_type': self['transport_type'],
            'transport_metadata': self['transport_metadata'],
        }
        fields.update(kw)

        out_msg = TransportUserMessage(content=content, **fields)
        # The reply should go out the same endpoint it came in.
        out_msg.set_routing_endpoint(self.get_routing_endpoint())
        return out_msg

    def reply_group(self, *args, **kw):
        """Construct a group reply message.

        If the `group` field is set to `None`, :meth:`reply_group` is identical
        to :meth:`reply`.

        If the `group` field is not set to `None`, the reply message will have
        its `to_addr` field set to `None`. This means that even if the original
        message is directed to an individual within the group (i.e. its
        `to_addr` is not set to `None`), the reply will be directed to the
        group as a whole.

        :meth:`reply_group` is suitable for both one-to-one messages (such as
        SMS) and undirected messages within a group chat (such as IRC channel
        messages).
        """
        out_msg = self.reply(*args, **kw)
        if self['group'] is not None:
            out_msg['to_addr'] = None
        return out_msg

    @classmethod
    def send(cls, to_addr, content, **kw):
        kw.setdefault('from_addr', None)
        kw.setdefault('transport_name', None)
        kw.setdefault('transport_type', None)
        kw.setdefault('session_event', cls.SESSION_NONE)
        out_msg = cls(
            to_addr=to_addr,
            in_reply_to=None,
            content=content,
            **kw)
        return out_msg
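
    # Reply sketch (addresses hypothetical): addressing fields are swapped
    # and the opaque transport fields are copied from the original:
    #
    #     msg = TransportUserMessage.send(
    #         '+27831234567', 'hi?', from_addr='12345',
    #         transport_name='trans', transport_type='sms')
    #     reply = msg.reply('hello back')
    #     reply['to_addr']    # -> '12345'
    #     reply['from_addr']  # -> '+27831234567'
    #     reply['in_reply_to'] == msg['message_id']  # -> True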


class TransportEvent(TransportMessage):
    """Message about a TransportUserMessage.
    """
    MESSAGE_TYPE = 'event'

    # list of valid delivery statuses
    DELIVERY_STATUSES = frozenset(('pending', 'failed', 'delivered'))

    # map of event_types -> extra fields
    EVENT_TYPES = {
        'ack': {'sent_message_id': lambda v: v is not None},
        'nack': {
            'nack_reason': lambda v: v is not None,
        },
        'delivery_report': {
            'delivery_status': lambda v: v in TransportEvent.DELIVERY_STATUSES,
            },
        }

    def process_fields(self, fields):
        fields = super(TransportEvent, self).process_fields(fields)
        fields.setdefault('event_id', self.generate_id())
        return fields

    def validate_fields(self):
        super(TransportEvent, self).validate_fields()
        self.assert_field_present(
            'user_message_id',
            'event_id',
            'event_type',
            )
        event_type = self.payload['event_type']
        if event_type not in self.EVENT_TYPES:
            raise InvalidMessageField("Unknown event_type %r" % (event_type,))
        for extra_field, check in self.EVENT_TYPES[event_type].items():
            self.assert_field_present(extra_field)
            if not check(self[extra_field]):
                raise InvalidMessageField(extra_field)

    def status(self):
        status = self['event_type']
        if status == "delivery_report":
            status = "%s.%s" % (status, self['delivery_status'])
        return status
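
    # Event sketch (field values hypothetical): delivery reports extend the
    # status with the delivery status:
    #
    #     ev = TransportEvent(event_type='delivery_report',
    #                         user_message_id='123',
    #                         delivery_status='delivered')
    #     ev.status()  # -> 'delivery_report.delivered'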


class TransportStatus(TransportMessage):
    """Message about a status event emitted by a transport.
    """
    MESSAGE_TYPE = 'status_event'
    STATUSES = frozenset(('ok', 'degraded', 'down'))

    def process_fields(self, fields):
        super(TransportStatus, self).process_fields(fields)
        fields.setdefault('reasons', [])
        fields.setdefault('details', {})
        return fields

    def validate_fields(self):
        super(TransportStatus, self).validate_fields()
        self.assert_field_present('component')
        self.assert_field_present('status')
        self.assert_field_present('type')
        self.assert_field_present('message')

        if self.payload['status'] not in self.STATUSES:
            raise InvalidMessageField(
                "Unknown status %r" % (self.payload['status'],))

vumi/__init__.py
"""
Vumi scalable text messaging engine.
"""

__version__ = "0.6.9"

vumi/errors.py
class VumiError(Exception):
    pass


class InvalidMessage(VumiError):
    pass


class InvalidMessageType(VumiError):
    pass


class MissingMessageField(InvalidMessage):
    pass


class InvalidMessageField(InvalidMessage):
    pass


class DuplicateConnectorError(VumiError):
    pass


class InvalidEndpoint(VumiError):
    """Raised when attempting to send a message to an invalid endpoint."""


class DispatcherError(VumiError):
    """Raised when an error is encounter while dispatching a message."""


# Re-export this for compatibility.
from confmodel.errors import ConfigError

ConfigError  # referenced so linters don't flag the import as unused

vumi/service.py
# -*- test-case-name: vumi.tests.test_service -*-

import json
import warnings
from copy import deepcopy

from twisted.python import log
from twisted.application.service import MultiService
from twisted.application.internet import TCPClient
from twisted.internet.defer import (
    inlineCallbacks, returnValue, Deferred, succeed)
from twisted.internet import protocol, reactor
import txamqp
from txamqp.client import TwistedDelegate
from txamqp.content import Content
from txamqp.protocol import AMQClient

from vumi.errors import VumiError
from vumi.message import Message
from vumi.utils import load_class_by_string, vumi_resource_path, build_web_site


SPECS = {}


def get_spec(specfile):
    """
    Cache the generated part of txamqp, because generating it is expensive.

    This is important for tests, which create lots of txamqp clients,
    and therefore generate lots of specs. Just doing this results in a
    significant reduction in test run time.
    """
    if specfile not in SPECS:
        SPECS[specfile] = txamqp.spec.load(specfile)
    return SPECS[specfile]


class AmqpFactory(protocol.ReconnectingClientFactory):

    def __init__(self, worker):
        self.options = worker.options
        self.config = worker.config
        self.spec = get_spec(vumi_resource_path(worker.options['specfile']))
        self.delegate = TwistedDelegate()
        self.worker = worker
        self.amqp_client = None

    def buildProtocol(self, addr):
        self.amqp_client = WorkerAMQClient(
            self.delegate, self.options['vhost'],
            self.spec, self.options.get('heartbeat', 0))
        self.amqp_client.factory = self
        self.amqp_client.vumi_options = self.options
        self.amqp_client.connected_callback = self.worker._amqp_connected
        self.resetDelay()
        return self.amqp_client

    def clientConnectionFailed(self, connector, reason):
        log.err("AmqpFactory connection failed (%s)" % (
            reason.getErrorMessage(),))
        self.worker._amqp_connection_failed()
        self.amqp_client = None
        protocol.ReconnectingClientFactory.clientConnectionFailed(
            self, connector, reason)

    def clientConnectionLost(self, connector, reason):
        if not self.worker.running:
            # We've specifically asked for this disconnect.
            return
        log.err("AmqpFactory client connection lost (%s)" % (
            reason.getErrorMessage(),))
        self.worker._amqp_connection_failed()
        self.amqp_client = None
        protocol.ReconnectingClientFactory.clientConnectionLost(
            self, connector, reason)


class WorkerAMQClient(AMQClient):
    @inlineCallbacks
    def connectionMade(self):
        AMQClient.connectionMade(self)
        yield self.authenticate(self.vumi_options['username'],
                                self.vumi_options['password'])
        # authentication was successful
        log.msg("Got an authenticated connection")
        yield self.connected_callback(self)

    @inlineCallbacks
    def get_channel(self, channel_id=None):
        """If channel_id is None a new channel is created"""
        if channel_id:
            channel = self.channels[channel_id]
        else:
            channel_id = self.get_new_channel_id()
            channel = yield self.channel(channel_id)
            yield channel.channel_open()
            self.channels[channel_id] = channel
        returnValue(channel)

    def get_new_channel_id(self):
        """
        AMQClient keeps track of channels in a dictionary keyed by
        channel id. Return one more than the highest existing id, or
        zero for the first channel.
        """
        return (max(self.channels) + 1) if self.channels else 0

    def _declare_exchange(self, source, channel):
        # get the details for AMQP
        exchange_name = source.exchange_name
        exchange_type = source.exchange_type
        durable = source.durable
        return channel.exchange_declare(exchange=exchange_name,
                                        type=exchange_type, durable=durable)

    @inlineCallbacks
    def start_consumer(self, consumer_class, *args, **kwargs):
        channel = yield self.get_channel()

        consumer = consumer_class(channel, *args, **kwargs)
        consumer.vumi_options = self.vumi_options

        # get the details for AMQP
        exchange_name = consumer.exchange_name
        durable = consumer.durable
        queue_name = consumer.queue_name
        routing_key = consumer.routing_key

        # declare the exchange, doesn't matter if it already exists
        yield self._declare_exchange(consumer, channel)

        # declare the queue
        yield channel.queue_declare(queue=queue_name, durable=durable)
        # bind it to the exchange with the routing key
        yield channel.queue_bind(queue=queue_name, exchange=exchange_name,
                                 routing_key=routing_key)
        yield consumer.start()
        # return the newly created & consuming consumer
        returnValue(consumer)

    @inlineCallbacks
    def start_publisher(self, publisher_class, *args, **kwargs):
        # much more braindead than start_consumer
        # get a channel
        channel = yield self.get_channel()
        # start the publisher
        publisher = publisher_class(*args, **kwargs)
        publisher.vumi_options = self.vumi_options
        # declare the exchange, doesn't matter if it already exists
        yield self._declare_exchange(publisher, channel)
        # start!
        yield publisher.start(channel)
        # return the publisher
        returnValue(publisher)


class Worker(MultiService, object):
    """
    The Worker is responsible for starting consumers & publishers
    as needed.
    """

    def __init__(self, options, config=None):
        super(Worker, self).__init__()
        self.options = options
        if config is None:
            config = {}
        self.config = config
        self._amqp_client = None

    def _amqp_connected(self, amqp_client):
        self._amqp_client = amqp_client
        return self.startWorker()

    def _amqp_connection_failed(self):
        pass

    def _amqp_connection_lost(self):
        self._amqp_client = None

    def startWorker(self):
        # I hate camelCased methods, but since Twisted has it as a
        # standard I'm voting to stick with it
        raise VumiError("You need to subclass Worker and its "
                        "startWorker method")

    def stopWorker(self):
        pass

    @inlineCallbacks
    def stopService(self):
        if self.running:
            yield self.stopWorker()
        yield super(Worker, self).stopService()

    def routing_key_to_class_name(self, routing_key):
        return ''.join(map(lambda s: s.capitalize(), routing_key.split('.')))

    def consume(self, routing_key, callback, queue_name=None,
                exchange_name='vumi', exchange_type='direct', durable=True,
                message_class=None, paused=False, prefetch_count=None):

        # use the routing key to generate the name for the class
        # amq.routing.key -> AmqRoutingKey
        dynamic_name = self.routing_key_to_class_name(routing_key)
        class_name = "%sDynamicConsumer" % str(dynamic_name)
        kwargs = {
            'routing_key': routing_key,
            'queue_name': queue_name or routing_key,
            'exchange_name': exchange_name,
            'exchange_type': exchange_type,
            'durable': durable,
            'start_paused': paused,
            'prefetch_count': prefetch_count,
        }
        log.msg('Starting %s with %s' % (class_name, kwargs))
        klass = type(class_name, (DynamicConsumer,), kwargs)
        if message_class is not None:
            klass.message_class = message_class
        return self.start_consumer(klass, callback)
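
    # Illustrative call from a worker subclass (names hypothetical):
    #
    #     self.consume('myworker.inbound', self.handle_inbound,
    #                  message_class=TransportUserMessage)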

    def start_consumer(self, consumer_class, *args, **kw):
        return self._amqp_client.start_consumer(consumer_class, *args, **kw)

    @inlineCallbacks
    def publish_to(self, routing_key):
        channel = yield self._amqp_client.get_channel()
        publisher = DynamicPublisher(channel, routing_key)
        yield self._amqp_client._declare_exchange(publisher, channel)
        # return the publisher
        returnValue(publisher)
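
    # Illustrative call from a worker subclass (routing key hypothetical):
    #
    #     pub = yield self.publish_to('myworker.outbound')
    #     yield pub.publish_message(msg)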

    def start_publisher(self, publisher_class, *args, **kw):
        return self._amqp_client.start_publisher(publisher_class, *args, **kw)

    def start_web_resources(self, resources, port, site_class=None):
        resources = dict((path, resource) for resource, path in resources)
        site_factory = build_web_site(resources, site_class=site_class)
        return reactor.listenTCP(port, site_factory)


class QueueCloseMarker(object):
    "This is a marker for closing consumer queues."


class Consumer(object):

    exchange_name = "vumi"
    exchange_type = "direct"
    durable = False

    queue_name = "queue"
    routing_key = "routing_key"

    message_class = Message
    start_paused = False
    prefetch_count = None

    def __init__(self, channel):
        self.channel = channel
        self._fake_channel = getattr(self.channel, '_fake_channel', None)
        self._notify_paused_and_quiet = []
        self.keep_consuming = False
        self.queue = None
        self._consumer_tag = None

    @inlineCallbacks
    def start(self):
        self._in_progress = 0
        self.keep_consuming = True
        self.paused = self.start_paused
        self._unpause_d = None
        if self.prefetch_count is not None:
            yield self.channel.basic_qos(0, self.prefetch_count, False)
        if not self.paused:
            yield self.unpause()
        returnValue(self)

    @inlineCallbacks
    def _read_messages(self):
        try:
            while self.keep_consuming:
                message = yield self.queue.get()
                if isinstance(message, QueueCloseMarker):
                    break
                if self.paused:
                    yield self._unpause_d
                yield self.consume(message)
        except txamqp.queue.Closed as e:
            log.err("Queue has closed", e)
        except Exception:
            # Log this explicitly instead of waiting for the deferred to be
            # garbage-collected, because that might only happen later on pypy.
            log.err()

    @inlineCallbacks
    def _channel_consume(self):
        if self._consumer_tag is not None:
            raise RuntimeError("Consumer already registered.")
        reply = yield self.channel.basic_consume(queue=self.queue_name)
        self._consumer_tag = reply.consumer_tag
        self.queue = yield self.channel.client.queue(self._consumer_tag)
        self.keep_consuming = True
        self._read_messages()

    @inlineCallbacks
    def pause(self):
        self.paused = True
        if self._unpause_d is None:
            self._unpause_d = Deferred()
        yield self.notify_paused_and_quiet()

    def unpause(self):
        self.paused = False
        d, self._unpause_d = self._unpause_d, None
        if d is not None:
            d.callback(None)
        if self._consumer_tag is None:
            return self._channel_consume()

    def notify_paused_and_quiet(self):
        d = Deferred()
        self._notify_paused_and_quiet.append(d)
        self._check_notify()
        return d

    def _check_notify(self):
        if self.paused and not self._in_progress:
            while self._notify_paused_and_quiet:
                self._notify_paused_and_quiet.pop(0).callback(None)

    @inlineCallbacks
    def consume(self, message):
        self._in_progress += 1
        try:
            result = yield self.consume_message(
                self.message_class.from_json(message.content.body))
        finally:
            # If we get an exception here the consumer's already pretty much
            # broken, but we still decrement the _in_progress counter so we
            # don't wait forever for it during shutdown.
            self._in_progress -= 1
            if self._fake_channel is not None:
                self._fake_channel.message_processed()
        if result is not False:
            yield self.channel.basic_ack(message.delivery_tag, False)
        else:
            log.msg('Received %s as a return value from consume_message. '
                    'Not acknowledging AMQ message' % result)
        self._check_notify()

    def consume_message(self, message):
        """helper method, override in implementation"""
        log.msg("Received message: %s" % message)

    @inlineCallbacks
    def stop(self):
        log.msg("Consumer stopping...")
        self.keep_consuming = False
        yield self.pause()
        # This actually closes the channel on the server
        yield self.channel.channel_close()
        # This just marks the channel as closed on the client
        self.channel.close(None)
        returnValue(self.keep_consuming)


class DynamicConsumer(Consumer):
    def __init__(self, channel, callback):
        super(DynamicConsumer, self).__init__(channel)
        self.callback = callback

    def consume_message(self, message):
        return self.callback(message)


class RoutingKeyError(Exception):
    def __init__(self, value):
        self.value = value

    def __str__(self):
        return repr(self.value)


class _Publisher(object):
    exchange_name = "vumi"
    exchange_type = "direct"
    durable = False
    auto_delete = False
    delivery_mode = 2  # save to disk

    def check_routing_key(self, routing_key):
        if routing_key != routing_key.lower():
            raise RoutingKeyError(
                "The routing_key: %s is not all lower case!" % (routing_key,))


class Publisher(_Publisher):
    """
    An old-style publisher to subclass for special-purpose publishers.
    This is deprecated in favour of using :meth:`Worker.publish_to`, although
    it will stay around for a while.
    """

    routing_key = "routing_key"

    def start(self, channel):
        warnings.warn(
            "Subclassing the Publisher class is deprecated. Please use"
            " Worker.publish_to() instead.", category=DeprecationWarning)
        log.msg("Started the publisher")
        self.channel = channel
        self.bound_routing_keys = {}

    @inlineCallbacks
    def _publish(self, message, routing_key=None):
        if routing_key is None:
            routing_key = self.routing_key
            self.check_routing_key(routing_key)
        yield self.channel.basic_publish(
            exchange=self.exchange_name, content=message,
            routing_key=routing_key)

    def publish_message(self, message, routing_key=None):
        d = self.publish_raw(message.to_json(), routing_key=routing_key)
        d.addCallback(lambda r: message)
        return d

    def publish_json(self, data, routing_key=None):
        """helper method"""
        return self.publish_raw(json.dumps(data, cls=json.JSONEncoder),
                                routing_key=routing_key)

    def publish_raw(self, data, routing_key=None):
        amq_message = Content(data)
        amq_message['delivery mode'] = self.delivery_mode
        return self._publish(amq_message, routing_key=routing_key)


class DynamicPublisher(_Publisher):
    """
    A single-routing-key publisher.
    """

    durable = True

    def __init__(self, channel, routing_key):
        self.channel = channel
        self.check_routing_key(routing_key)
        self.routing_key = routing_key

    def publish_message(self, message):
        self.publish_raw(message.to_json())
        return succeed(message)

    def publish_json(self, data):
        self.publish_raw(json.dumps(data, cls=json.JSONEncoder))

    def publish_raw(self, data):
        amq_message = Content(data)
        amq_message['delivery mode'] = self.delivery_mode
        self._publish(amq_message)

    def _publish(self, message):
        return self.channel.basic_publish(
            exchange=self.exchange_name, content=message,
            routing_key=self.routing_key)


class WorkerCreator(object):
    """
    Creates workers
    """

    def __init__(self, vumi_options):
        self.options = vumi_options

    def create_worker(self, worker_class, config, timeout=30,
                      bindAddress=None):
        """
        Create a worker, connect it to AMQP and return it.

        The return value is the worker instance, with a TCP client
        service (running an AmqpFactory) attached as a child service.
        """
        return self.create_worker_by_class(
            load_class_by_string(worker_class), config, timeout=timeout,
            bindAddress=bindAddress)

    def create_worker_by_class(self, worker_class, config, timeout=30,
                               bindAddress=None):
        worker = worker_class(deepcopy(self.options), config)
        self._connect(worker, timeout=timeout, bindAddress=bindAddress)
        return worker

    def _connect(self, worker, timeout, bindAddress):
        service = TCPClient(self.options['hostname'], self.options['port'],
                            AmqpFactory(worker), timeout, bindAddress)
        service.setServiceParent(worker)

vumi/sentry.py
# -*- test-case-name: vumi.tests.test_sentry -*-

import logging

from twisted.python import log
from twisted.web.client import HTTPClientFactory, _makeGetterFactory
from twisted.internet.defer import DeferredList
from twisted.application.service import Service


DEFAULT_LOG_CONTEXT_SENTINEL = "_SENTRY_CONTEXT_"


class QuietHTTPClientFactory(HTTPClientFactory):
    """HTTP client factory that doesn't log starting and stopping."""
    noisy = False


def quiet_get_page(url, contextFactory=None, *args, **kwargs):
    """A version of getPage that uses QuietHTTPClientFactory."""
    return _makeGetterFactory(
        url,
        QuietHTTPClientFactory,
        contextFactory=contextFactory,
        *args, **kwargs).deferred


def vumi_raven_client(dsn, log_context_sentinel=None):
    """Construct a custom raven client and transport-set pair.

    The raven client assumes that sends via transports return success or
    failure immediately in a blocking fashion and doesn't give transports
    access to the client.

    We circumvent this by constructing a once-off transport class and
    raven client pair that work together. Instances of the transport feed
    success and failure information back to the client instance once
    deferreds complete.

    Pull-requests with better solutions welcomed.
    """

    import raven
    from raven.transport.base import TwistedHTTPTransport
    from raven.transport.registry import TransportRegistry

    remaining_deferreds = set()
    if log_context_sentinel is None:
        log_context_sentinel = DEFAULT_LOG_CONTEXT_SENTINEL
    log_context = {log_context_sentinel: True}

    class VumiRavenHTTPTransport(TwistedHTTPTransport):

        scheme = ['http', 'https']

        def _get_page(self, data, headers):
            d = quiet_get_page(self._url, method='POST', postdata=data,
                               headers=headers)
            self._track_deferred(d)
            self._track_client_state(d)
            return d

        def _track_deferred(self, d):
            remaining_deferreds.add(d)
            d.addBoth(self._untrack_deferred, d)

        def _untrack_deferred(self, result, d):
            remaining_deferreds.discard(d)
            return result

        def _track_client_state(self, d):
            d.addCallbacks(self._set_client_success, self._set_client_fail)

        def _set_client_success(self, result):
            client.state.set_success()
            return result

        def _set_client_fail(self, result):
            client.state.set_fail()
            return result

        def send(self, data, headers):
            d = self._get_page(data, headers)
            d.addErrback(lambda f: log.err(f, **log_context))

    class VumiRavenClient(raven.Client):

        _registry = TransportRegistry(transports=[
            VumiRavenHTTPTransport
        ])

        def teardown(self):
            return DeferredList(remaining_deferreds)

    client = VumiRavenClient(dsn)
    return client


class SentryLogObserver(object):
    """Twisted log observer that logs to a Raven Sentry client."""

    DEFAULT_ERROR_LEVEL = logging.ERROR
    DEFAULT_LOG_LEVEL = logging.INFO
    LOG_LEVEL_THRESHOLD = logging.WARN

    def __init__(self, client, logger_name, worker_id,
                 log_context_sentinel=None):
        if log_context_sentinel is None:
            log_context_sentinel = DEFAULT_LOG_CONTEXT_SENTINEL
        self.client = client
        self.logger_name = logger_name
        self.worker_id = worker_id
        self.log_context_sentinel = log_context_sentinel
        self.log_context = {self.log_context_sentinel: True}

    def level_for_event(self, event):
        level = event.get('logLevel')
        if level is not None:
            return level
        if event.get('isError'):
            return self.DEFAULT_ERROR_LEVEL
        return self.DEFAULT_LOG_LEVEL

    def logger_for_event(self, event):
        system = event.get('system', '-')
        parts = [self.logger_name]
        if system != '-':
            parts.extend(system.split(','))
        logger = ".".join(parts)
        return logger.lower()
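
    # e.g. a logger_name of 'vumi' and an event system of 'SMPPTransport,1'
    # (hypothetical values) yield the logger name 'vumi.smpptransport.1'.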

    def _log_to_sentry(self, event):
        level = self.level_for_event(event)
        if level < self.LOG_LEVEL_THRESHOLD:
            return
        data = {
            "logger": self.logger_for_event(event),
            "level": level,
        }
        tags = {
            "worker-id": self.worker_id,
        }
        failure = event.get('failure')
        if failure:
            exc_info = (failure.type, failure.value, failure.tb)
            self.client.captureException(exc_info, data=data, tags=tags)
        else:
            msg = log.textFromEventDict(event)
            self.client.captureMessage(msg, data=data, tags=tags)

    def __call__(self, event):
        if self.log_context_sentinel in event:
            return
        log.callWithContext(self.log_context, self._log_to_sentry, event)


class SentryLoggerService(Service):

    def __init__(self, dsn, logger_name, worker_id, logger=None):
        self.setName('Sentry Logger')
        self.dsn = dsn
        self.client = vumi_raven_client(dsn=dsn)
        self.sentry_log_observer = SentryLogObserver(self.client,
                                                     logger_name,
                                                     worker_id)
        self.logger = logger if logger is not None else log.theLogPublisher

    def startService(self):
        self.logger.addObserver(self.sentry_log_observer)
        return Service.startService(self)

    def stopService(self):
        if self.running:
            self.logger.removeObserver(self.sentry_log_observer)
            return self.client.teardown()
        return Service.stopService(self)

    def registered(self):
        return self.sentry_log_observer in self.logger.observers
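
# Sketch: attach Sentry logging to a worker (the DSN is made up):
#
#     service = SentryLoggerService(
#         'http://key:secret@sentry.example.com/1', 'mylogger', 'worker-1')
#     worker.addService(service)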

vumi/multiworker.py
# -*- test-case-name: vumi.tests.test_multiworker -*-

from copy import deepcopy

from vumi.service import Worker, WorkerCreator


class MultiWorker(Worker):
    """A worker whose job it is to start other workers.

    Config options:

    :type workers: dict
    :param workers:
        Dict of worker_name -> fully-qualified class name.
    :type defaults: dict
    :param defaults:
        Default configuration for child workers.

    Each entry in the ``workers`` config dict defines a child worker to start.
    A child worker's configuration should be provided in a config dict keyed by
    its name. Common configuration across child workers should go in the
    ``defaults`` config dict.
    """

    WORKER_CREATOR = WorkerCreator

    def construct_worker_config(self, worker_name):
        """
        Construct an appropriate configuration for the child worker.
        """
        config = deepcopy(self.config.get('defaults', {}))
        config.update(self.config.get(worker_name, {}))
        return config

    def create_worker(self, worker_name, worker_class):
        """
        Create a child worker.
        """
        config = self.construct_worker_config(worker_name)
        worker = self.worker_creator.create_worker(worker_class, config)
        worker.setName(worker_name)
        worker.setServiceParent(self)
        return worker

    def startService(self):
        super(MultiWorker, self).startService()
        self.workers = []
        self.worker_creator = self.WORKER_CREATOR(self.options)
        for wname, wclass in self.config.get('workers', {}).items():
            worker = self.create_worker(wname, wclass)
            self.workers.append(worker)

    def startWorker(self):
        pass

vumi/servicemaker.py
# -*- test-case-name: vumi.tests.test_servicemaker -*-
import os
import sys
import warnings

import yaml
from zope.interface import implements
from twisted.python import usage
from twisted.application.service import IServiceMaker
from twisted.plugin import IPlugin

from vumi.service import WorkerCreator
from vumi.utils import (load_class_by_string,
                        generate_worker_id)
from vumi.errors import VumiError
from vumi.sentry import SentryLoggerService


class SafeLoaderWithInclude(yaml.SafeLoader):
    def __init__(self, *args, **kwargs):
        super(SafeLoaderWithInclude, self).__init__(*args, **kwargs)
        self.add_constructor('!include', self._include)
        if isinstance(self.stream, file):
            self._root = os.path.dirname(self.stream.name)
        else:
            self._root = os.path.curdir

    def _include(self, loader, node):
        filename = os.path.join(self._root, self.construct_scalar(node))
        filename = os.path.normpath(filename)
        with open(filename) as f:
            return yaml.load(f, Loader=SafeLoaderWithInclude)
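
# Sketch of the !include constructor (file names hypothetical). Included
# paths are resolved relative to the directory of the including file:
#
#     # main.yaml
#     transport: !include transport.yaml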


def overlay_configs(*configs):
    """Non-recursively overlay a set of configuration dictionaries"""
    config = {}

    for overlay in configs:
        config.update(overlay)

    return config
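
# Overlay sketch: later configs win, but only at the top level (values are
# not merged recursively):
#
#     overlay_configs({'a': 1, 'b': 1}, {'b': 2})  # -> {'a': 1, 'b': 2}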


def filter_null_values(config):
    """Remove keys with None values from a dictionary."""
    return dict(item for item in config.iteritems() if item[1] is not None)


def read_yaml_config(config_file, optional=True):
    """Parse an (usually) optional YAML config file."""
    if optional and config_file is None:
        return {}
    with open(config_file, 'r') as stream:
        # Assume we get a dict out of this.
        return yaml.load(stream, Loader=SafeLoaderWithInclude)


class VumiOptions(usage.Options):
    """
    Options global to everything vumi.
    """
    optParameters = [
        ["hostname", None, None, "AMQP broker (*)"],
        ["port", None, None, "AMQP port (*)", int],
        ["username", None, None, "AMQP username (*)"],
        ["password", None, None, "AMQP password (*)"],
        ["vhost", None, None, "AMQP virtual host (*)"],
        ["specfile", None, None, "AMQP spec file (*)"],
        ["sentry", None, None, "Sentry DSN (*)"],
        ["vumi-config", None, None,
         "YAML config file for setting core vumi options (any command-line"
         " parameter marked with an asterisk)"],
        ["system-id", None, None,
         "An identifier for a collection of Vumi workers"],
    ]

    default_vumi_options = {
        "hostname": "127.0.0.1",
        "port": 5672,
        "username": "vumi",
        "password": "vumi",
        "vhost": "/develop",
        "specfile": "amqp-spec-0-8.xml",
        "sentry": None,
        }

    def get_vumi_options(self):
        # We don't want these to get in the way later.
        vumi_option_params = {}
        for opt in (i[0] for i in VumiOptions.optParameters):
            vumi_option_params[opt] = self.pop(opt)

        config_file = vumi_option_params.pop('vumi-config')

        # non-recursive overlay is safe because vumi options are
        # all simple key-value pairs
        return overlay_configs(
            self.default_vumi_options,
            read_yaml_config(config_file),
            filter_null_values(vumi_option_params))

    def postOptions(self):
        self.vumi_options = self.get_vumi_options()


class StartWorkerOptions(VumiOptions):
    """
    Options to the vumi_worker twistd plugin.
    """

    optFlags = [
        ["worker-help", None,
         "Print out a usage message for the worker-class and exit"],
        ]

    optParameters = [
        ["worker-class", None, None, "Class of a worker to start"],
        ["worker_class", None, None, "Deprecated. See --worker-class instead"],
        ["config", None, None, "YAML config file for worker configuration"
         " options"],
        ["maxthreads", None, None, "Maximum size of reactor thread pool", int],
    ]

    longdesc = """Launch an instance of a vumi worker process."""

    def __init__(self):
        VumiOptions.__init__(self)
        self.set_options = {}

    def opt_set_option(self, keyvalue):
        """Set a worker configuration option (overrides values
        specified in the file passed to --config)."""
        key, _sep, value = keyvalue.partition(':')
        self.set_options[key] = value
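
    # e.g. --set-option worker_name:foo (twistd derives the flag name from
    # opt_set_option) overrides worker_name from the --config file.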

    def exit(self):
        # So we can stub it out in tests.
        sys.exit(0)

    def emit(self, text):
        # So we can stub it out in tests.
        print text

    def do_worker_help(self):
        """Print out a usage message for the worker-class and exit"""
        worker_class = load_class_by_string(self.worker_class)
        self.emit(worker_class.__doc__)
        config_class = getattr(worker_class, 'CONFIG_CLASS', None)
        if config_class is not None:
            self.emit(config_class.__doc__)
        self.emit("")
        self.exit()

    def get_worker_class(self):
        worker_class = self.opts.pop('worker-class')
        depr_worker_class = self.opts.pop('worker_class')

        if depr_worker_class is not None:
            warnings.warn("The --worker_class option is deprecated since"
                          " Vumi 0.3. Please use --worker-class instead.",
                          category=DeprecationWarning)
            if worker_class is None:
                worker_class = depr_worker_class

        if worker_class is None:
            raise VumiError("please specify --worker-class")

        return worker_class

    def get_worker_config(self):
        config_file = self.opts.pop('config')

        # non-recursive overlay is safe because set_options
        # can only contain simple key-value pairs.
        return overlay_configs(
            read_yaml_config(config_file),
            self.set_options)

    def get_maxthreads(self):
        return self.opts.pop("maxthreads")

    def postOptions(self):
        VumiOptions.postOptions(self)

        self.worker_class = self.get_worker_class()

        if self.opts.pop('worker-help'):
            self.do_worker_help()

        self.worker_config = self.get_worker_config()

        self.maxthreads = self.get_maxthreads()


class VumiWorkerServiceMaker(object):
    implements(IServiceMaker, IPlugin)
    # the name of our plugin, this will be the subcommand for twistd
    # e.g. $ twistd -n vumi_worker --option1= ...
    tapname = "vumi_worker"
    # description, also for twistd
    description = "Start a Vumi worker"
    # what command line options does this service expose
    options = StartWorkerOptions

    def set_maxthreads(self, maxthreads):
        from twisted.internet import reactor

        if maxthreads is not None:
            reactor.suggestThreadPoolSize(maxthreads)

    def makeService(self, options):
        sentry_dsn = options.vumi_options.pop('sentry', None)
        class_name = options.worker_class.rpartition('.')[2].lower()
        logger_name = options.worker_config.get('worker_name', class_name)
        system_id = options.vumi_options.get('system-id', 'global')
        worker_id = generate_worker_id(system_id, logger_name)

        self.set_maxthreads(options.maxthreads)

        worker_creator = WorkerCreator(options.vumi_options)
        worker = worker_creator.create_worker(options.worker_class,
                                              options.worker_config)

        if sentry_dsn is not None:
            sentry_service = SentryLoggerService(sentry_dsn,
                                                 logger_name,
                                                 worker_id)
            worker.addService(sentry_service)

        return worker


class DeprecatedStartWorkerServiceMaker(VumiWorkerServiceMaker):
    tapname = "start_worker"
    description = "Deprecated copy of vumi_worker. Use vumi_worker instead."

vumi/log.py
# -*- test-case-name: vumi.tests.test_log -*-
import logging
from functools import partial

from twisted.python import log


debug = partial(log.msg, logLevel=logging.DEBUG)
info = partial(log.msg, logLevel=logging.INFO)
warning = partial(log.msg, logLevel=logging.WARNING)
error = partial(log.err, logLevel=logging.ERROR)
critical = partial(log.err, logLevel=logging.CRITICAL)

# make transition from twisted.python.log easier
msg = info
err = error


class WrappingLogger(object):
    '''A logger that adds the keyword arguments it was initialised with
    to every logging call.'''
    def __init__(self, **kwargs):
        self.debug = partial(debug, **kwargs)
        self.info = partial(info, **kwargs)
        self.warning = partial(warning, **kwargs)
        self.error = partial(error, **kwargs)
        self.critical = partial(critical, **kwargs)
        self.msg = partial(msg, **kwargs)
        self.err = partial(err, **kwargs)
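
# Sketch: stamp every log call with the same context kwargs:
#
#     wlog = WrappingLogger(system='my-worker')
#     wlog.info('started')  # same as info('started', system='my-worker')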

vumi/worker.py
# -*- test-case-name: vumi.tests.test_worker -*-

"""Basic tools for workers that handle TransportMessages."""

import time
import os
import socket

from twisted.internet.defer import (
    inlineCallbacks, succeed, maybeDeferred, gatherResults)

from vumi.log import WrappingLogger
from vumi.service import Worker
from vumi.middleware import setup_middlewares_from_config
from vumi.connectors import (
    ReceiveInboundConnector, ReceiveOutboundConnector,
    PublishStatusConnector, ReceiveStatusConnector)
from vumi.config import Config, ConfigInt
from vumi.errors import DuplicateConnectorError
from vumi.utils import generate_worker_id
from vumi.blinkenlights.heartbeat import (HeartBeatPublisher,
                                          HeartBeatMessage)


def then_call(d, func, *args, **kw):
    return d.addCallback(lambda r: func(*args, **kw))


class BaseConfig(Config):
    """Base config definition for workers.

    You should subclass this and add worker-specific fields.
    """

    amqp_prefetch_count = ConfigInt(
        "The number of messages fetched concurrently from each AMQP queue"
        " by each worker instance.",
        default=20, static=True)


class BaseWorker(Worker):
    """Base class for a message processing worker.

    This contains common functionality used by application, transport and
    dispatcher workers. It should be subclassed by workers that need to
    manage their own connectors.
    """

    CONFIG_CLASS = BaseConfig

    def __init__(self, options, config=None):
        super(BaseWorker, self).__init__(options, config=config)
        self.connectors = {}
        self.middlewares = []
        self._static_config = self.CONFIG_CLASS(self.config, static=True)
        self._hb_pub = None
        self._worker_id = None
        self.log = WrappingLogger(system=self.config.get('worker_name'))

    def startWorker(self):
        self.log.msg(
            'Starting a %s worker with config: %s'
            % (self.__class__.__name__, self.config))
        d = maybeDeferred(self._validate_config)
        then_call(d, self.setup_heartbeat)
        then_call(d, self.setup_middleware)
        then_call(d, self.setup_connectors)
        then_call(d, self.setup_worker)
        return d

    def stopWorker(self):
        self.log.msg('Stopping a %s worker.' % (self.__class__.__name__,))
        d = succeed(None)
        then_call(d, self.teardown_worker)
        then_call(d, self.teardown_connectors)
        then_call(d, self.teardown_middleware)
        then_call(d, self.teardown_heartbeat)
        return d

    def setup_connectors(self):
        raise NotImplementedError()

    @inlineCallbacks
    def setup_heartbeat(self):
        # Disable heartbeats if worker_name is not set. We're
        # currently using it as the primary identifier for a worker
        if 'worker_name' in self.config:
            self._worker_name = self.config.get("worker_name")
            self._system_id = self.options.get("system-id", "global")
            self._worker_id = generate_worker_id(self._system_id,
                                                 self._worker_name)
            self.log.msg(
                "Starting HeartBeat publisher with worker_name=%s"
                % self._worker_name)
            self._hb_pub = yield self.start_publisher(
                HeartBeatPublisher, self._gen_heartbeat_attrs)
        else:
            self.log.msg(
                "HeartBeat publisher disabled. No worker_id field found in "
                "config.")

    def teardown_heartbeat(self):
        if self._hb_pub is not None:
            self._hb_pub.stop()
            self._hb_pub = None

    def _gen_heartbeat_attrs(self):
        # worker_name is guaranteed to be set here, otherwise this function
        # would not have been called.
        attrs = {
            'version': HeartBeatMessage.VERSION_20130319,
            'worker_id': self._worker_id,
            'system_id': self._system_id,
            'worker_name': self._worker_name,
            'hostname': socket.gethostname(),
            'timestamp': time.time(),
            'pid': os.getpid(),
        }
        attrs.update(self.custom_heartbeat_attrs())
        return attrs

    def custom_heartbeat_attrs(self):
        """Worker subclasses can override this to add custom attributes"""
        return {}

    def teardown_connectors(self):
        d = succeed(None)
        for connector_name in self.connectors.keys():
            then_call(d, self.teardown_connector, connector_name)
        return d

    def setup_worker(self):
        raise NotImplementedError()

    def teardown_worker(self):
        raise NotImplementedError()

    def setup_middleware(self):
        """Create middlewares from config."""
        d = setup_middlewares_from_config(self, self.config)
        d.addCallback(self.middlewares.extend)
        return d

    def teardown_middleware(self):
        """Teardown middlewares."""
        d = succeed(None)
        for mw in reversed(self.middlewares):
            then_call(d, mw.teardown_middleware)
        return d

    def get_static_config(self):
        """Return static (message independent) configuration."""
        return self._static_config

    def get_config(self, msg, ctxt=None):
        """This should return a message and context specific config object.

        It deliberately returns a deferred, even when this isn't strictly
        necessary, to ensure that workers will continue to work when
        per-message configuration needs to be fetched from elsewhere.
        """
        return succeed(self.CONFIG_CLASS(self.config))

    def _validate_config(self):
        """Once subclasses call `super().validate_config` properly,
           this method can be removed.
           """
        # TODO: remove this once all uses of validate_config have been fixed.
        self.validate_config()

    def validate_config(self):
        """
        Application-specific config validation happens in here.

        Subclasses may override this method to perform extra config
        validation.
        """
        # TODO: deprecate this in favour of a similar method on
        #       config classes.
        pass

    def setup_connector(self, connector_cls, connector_name, middleware=False):
        if connector_name in self.connectors:
            raise DuplicateConnectorError("Attempt to add duplicate connector"
                                          " with name %r" % (connector_name,))
        prefetch_count = self.get_static_config().amqp_prefetch_count
        middlewares = self.middlewares if middleware else None

        connector = connector_cls(self, connector_name,
                                  prefetch_count=prefetch_count,
                                  middlewares=middlewares)
        self.connectors[connector_name] = connector

        d = connector.setup()
        d.addCallback(lambda r: connector)
        return d

    def teardown_connector(self, connector_name):
        connector = self.connectors.pop(connector_name)
        d = connector.teardown()
        d.addCallback(lambda r: connector)
        return d

    def setup_ri_connector(self, connector_name, middleware=True):
        return self.setup_connector(ReceiveInboundConnector, connector_name,
                                    middleware=middleware)

    def setup_ro_connector(self, connector_name, middleware=True):
        return self.setup_connector(ReceiveOutboundConnector, connector_name,
                                    middleware=middleware)

    def setup_publish_status_connector(self, connector_name, middleware=True):
        return self.setup_connector(PublishStatusConnector, connector_name,
                                    middleware=middleware)

    def setup_receive_status_connector(self, connector_name, middleware=True):
        return self.setup_connector(ReceiveStatusConnector, connector_name,
                                    middleware=middleware)

    def pause_connectors(self):
        return gatherResults([
            connector.pause() for connector in self.connectors.itervalues()])

    def unpause_connectors(self):
        for connector in self.connectors.itervalues():
            connector.unpause()
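
# Sketch of a subclass wiring up its own connectors (the class and connector
# names are illustrative, not part of vumi):
#
#     class MyDispatcher(BaseWorker):
#         def setup_connectors(self):
#             d = self.setup_ri_connector('transport1')
#             then_call(d, self.setup_ro_connector, 'app1')
#             return d
#
#         def setup_worker(self):
#             pass
#
#         def teardown_worker(self):
#             pass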

vumi/config.py
# -*- test-case-name: vumi.tests.test_config -*-

from confmodel.fields import ConfigField
from confmodel.fallbacks import FieldFallback

from vumi.utils import load_class_by_string

from confmodel import Config
from confmodel.errors import ConfigError
from confmodel.fields import (
    ConfigInt, ConfigFloat, ConfigBool, ConfigList, ConfigDict, ConfigText,
    ConfigUrl, ConfigRegex)
from confmodel.interfaces import IConfigData


class ConfigClassName(ConfigField):
    field_type = 'Class'

    def __init__(self, doc, required=False, default=None, static=False,
                 implements=None):
        super(ConfigClassName, self).__init__(doc, required, default, static)
        self.interface = implements

    def clean(self, value):
        try:
            cls = load_class_by_string(value)
        except (ValueError, ImportError), e:
            # ValueError for empty module name
            self.raise_config_error(str(e))

        if self.interface and not self.interface.implementedBy(cls):
            self.raise_config_error('does not implement %r.' % (
                self.interface,))
        return cls

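# A minimal sketch of using ConfigClassName (MyConfig is illustrative;
# 'vumi.service.Worker' is simply a known importable class):
#
#     class MyConfig(Config):
#         worker_class = ConfigClassName("Dotted path of the class to load.")
#
#     cfg = MyConfig({'worker_class': 'vumi.service.Worker'})
#     cfg.worker_class  # -> the Worker class object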

class ConfigServerEndpoint(ConfigField):
    field_type = 'twisted_endpoint'

    def clean(self, value):
        from twisted.internet.endpoints import serverFromString
        from twisted.internet import reactor
        try:
            return serverFromString(reactor, value)
        except ValueError:
            self.raise_config_error('is not a valid server endpoint')


class ServerEndpointFallback(FieldFallback):
    def __init__(self, host_field="host", port_field="port"):
        self.host_field = host_field
        self.port_field = port_field
        self.required_fields = [port_field]

    def build_value(self, config):
        fields = {
            "host": getattr(config, self.host_field),
            "port": getattr(config, self.port_field),
        }

        formatstr = "tcp:port={port}"
        if fields["host"] is not None:
            formatstr += ":interface={host}"
        return formatstr.format(**fields)

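# For example (a sketch): with port=8080 and host="127.0.0.1" configured on
# the fallback's source fields, build_value returns
# "tcp:port=8080:interface=127.0.0.1"; with host unset it returns
# "tcp:port=8080".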

class ConfigClientEndpoint(ConfigField):
    field_type = 'twisted_endpoint'

    def clean(self, value):
        from twisted.internet.endpoints import clientFromString
        from twisted.internet.interfaces import IStreamClientEndpoint
        from twisted.internet import reactor
        if IStreamClientEndpoint.providedBy(value):
            # We got an actual endpoint object, useful for testing.
            return value
        try:
            return clientFromString(reactor, value)
        except ValueError:
            self.raise_config_error('is not a valid client endpoint')


class ClientEndpointFallback(FieldFallback):
    def __init__(self, host_field="host", port_field="port"):
        self.host_field = host_field
        self.port_field = port_field
        self.required_fields = [self.host_field, self.port_field]

    def build_value(self, config):
        fields = {
            "host": getattr(config, self.host_field),
            "port": getattr(config, self.port_field),
        }
        return "tcp:host={host}:port={port}".format(**fields)


class ConfigContext(object):
    """Context within which a configuration object can be retrieved.

    For example, configuration may depend on the message being processed
    or on the HTTP URL being accessed.
    """
    def __init__(self, **kw):
        for k, v in kw.items():
            setattr(self, k, v)


class ConfigRiak(ConfigDict):
    """Riak configuration.

    Ensures that there is at least a ``bucket_prefix`` key.
    """
    field_type = 'riak'

    def clean(self, value):
        if "bucket_prefix" not in value:
            self.raise_config_error(
                "does not contain the `bucket_prefix` key.")
        return super(ConfigRiak, self).clean(value)

# Re-export these for compatibility.
Config
ConfigError
ConfigInt
ConfigFloat
ConfigBool
ConfigList
ConfigDict
ConfigText
ConfigUrl
ConfigRegex
IConfigData

vumi/reconnecting_client.py
# -*- coding: utf-8 -*-
# -*- test-case-name: vumi.tests.test_reconnecting_client -*-

"""A service to provide the functionality of ReconnectingClientFactory
   when using Twisted's endpoints.

   Melded together from code and ideas from:

   * Twisted's existing ReconnectingClientFactory code.
   * https://github.com/keturn/twisted/blob/persistent-client-service-4735/
     twisted/application/internet.py
   """

import random

from twisted.application.service import Service
from twisted.internet.defer import gatherResults, Deferred
from twisted.python import log


class _RestartableProtocolProxy(object):
    """A proxy for a Protocol to provide connectionLost notification."""

    def __init__(self, protocol, clientService):
        self.__protocol = protocol
        self.__clientService = clientService

    def connectionLost(self, reason):
        result = self.__protocol.connectionLost(reason)
        self.__clientService.clientConnectionLost(reason)
        return result

    def __getattr__(self, item):
        return getattr(self.__protocol, item)

    def __repr__(self):
        return '<%s.%s wraps %r>' % (__name__, self.__class__.__name__,
            self.__protocol)



class _RestartableProtocolFactoryProxy(object):
    """A wrapper for a ProtocolFactory to facilitate restarting Protocols."""

    _protocolProxyFactory = _RestartableProtocolProxy

    def __init__(self, protocolFactory, clientService):
        self.protocolFactory = protocolFactory
        self.clientService = clientService


    def buildProtocol(self, addr):
        protocol = self.protocolFactory.buildProtocol(addr)
        wrappedProtocol = self._protocolProxyFactory(
            protocol, self.clientService)
        return wrappedProtocol


    def __getattr__(self, item):
        # maybe components.proxyForInterface is the thing to do here, but that
        # gave me a metaclass conflict.
        return getattr(self.protocolFactory, item)


    def __repr__(self):
        return '<%s.%s wraps %r>' % (__name__, self.__class__.__name__,
            self.protocolFactory)



class ReconnectingClientService(Service):
    """
    Service which auto-reconnects clients with an exponential back-off.

    Note that clients should call my resetDelay method after they have
    connected successfully.

    @ivar factory: A L{protocol.Factory} which will be used to create clients
        for the endpoint.
    @ivar endpoint: An L{IStreamClientEndpoint
        <twisted.internet.interfaces.IStreamClientEndpoint>} provider
        which will be used to connect when the service starts.

    @ivar maxDelay: Maximum number of seconds between connection attempts.
    @ivar initialDelay: Delay for the first reconnection attempt.
    @ivar factor: A multiplicative factor by which the delay grows.
    @ivar jitter: Percentage of randomness to introduce into the delay length
        to prevent stampeding.
    @ivar clock: The clock used to schedule reconnection. It's mainly useful to
        be parametrized in tests. If the factory is serialized, this attribute
        will not be serialized, and the default value (the reactor) will be
        restored when deserialized.
    @type clock: L{IReactorTime}
    @ivar maxRetries: Maximum number of consecutive unsuccessful connection
        attempts, after which no further connection attempts will be made. If
        this is not explicitly set, no maximum is applied.
    """
    maxDelay = 3600
    initialDelay = 1.0
    # Note: These highly sensitive factors have been precisely measured by
    # the National Institute of Science and Technology.  Take extreme care
    # in altering them, or you may damage your Internet!
    # (Seriously: <http://physics.nist.gov/cuu/Constants/index.html>)
    factor = 2.7182818284590451 # (math.e)
    # Phi = 1.6180339887498948 # (Phi is acceptable for use as a
    # factor if e is too large for your application.)
    jitter = 0.11962656472 # molar Planck constant times c, joule meter/mole

    delay = initialDelay
    retries = 0
    maxRetries = None
    clock = None
    noisy = False

    continueTrying = False

    _delayedRetry = None
    _connectingDeferred = None
    _protocol = None
    _protocolStoppingDeferred = None


    def __init__(self, endpoint, factory):
        self.endpoint = endpoint
        self.factory = factory

        if self.clock is None:
            from twisted.internet import reactor
            self.clock = reactor


    def startService(self):
        Service.startService(self)
        self.continueTrying = True
        self.retry(delay=0.0)


    def stopService(self):
        """
        Stop attempting to reconnect and close any existing connections.
        """
        self.continueTrying = False

        waitFor = []

        if self._delayedRetry is not None and self._delayedRetry.active():
            self._delayedRetry.cancel()
            self._delayedRetry = None

        if self._connectingDeferred is not None:
            waitFor.append(self._connectingDeferred)
            self._connectingDeferred.cancel()
            self._connectingDeferred = None

        if self._protocol is not None:
            self._protocolStoppingDeferred = Deferred()
            waitFor.append(self._protocolStoppingDeferred)
            self._protocol.transport.loseConnection()

        d = gatherResults(waitFor)
        return d.addCallback(lambda _: Service.stopService(self))


    def clientConnected(self, protocol):
        self._protocol = protocol
        # TODO: do we want to provide a hook for the protocol
        #       to call resetDelay itself?
        self.resetDelay()


    def clientConnectionFailed(self, unused_reason):
        # TODO: log the reason?
        self.retry()


    def clientConnectionLost(self, unused_reason):
        # TODO: log the reason?
        self._protocol = None
        if self._protocolStoppingDeferred is not None:
            d = self._protocolStoppingDeferred
            self._protocolStoppingDeferred = None
            d.callback(None)
        self.retry()


    def retry(self, delay=None):
        """
        Have this connector connect again, after a suitable delay.
        """
        if not self.continueTrying:
            if self.noisy:
                log.msg("Abandoning %s on explicit request" % (self.endpoint,))
            return

        if self.maxRetries is not None and (self.retries >= self.maxRetries):
            if self.noisy:
                log.msg("Abandoning %s after %d retries." %
                        (self.endpoint, self.retries))
            return

        self.retries += 1

        if delay is None:
            self.delay = min(self.delay * self.factor, self.maxDelay)
            if self.jitter:
                self.delay = random.normalvariate(self.delay,
                                                  self.delay * self.jitter)
            delay = self.delay

        if self.noisy:
            log.msg("Will retry %s in %g seconds"
                    % (self.endpoint, delay))

        def reconnector():
            proxied_factory = _RestartableProtocolFactoryProxy(
                self.factory, self)
            self._connectingDeferred = self.endpoint.connect(proxied_factory)
            self._connectingDeferred.addCallback(self.clientConnected)
            self._connectingDeferred.addErrback(self.clientConnectionFailed)

        self._delayedRetry = self.clock.callLater(delay, reconnector)


    def resetDelay(self):
        """
        Call this method after a successful connection: it resets the delay and
        the retry counter.
        """
        self.delay = self.initialDelay
        self.retries = 0
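
# Sketch of typical usage (MyProtocolFactory and the endpoint string are
# illustrative):
#
#     from twisted.internet import reactor
#     from twisted.internet.endpoints import clientFromString
#
#     endpoint = clientFromString(reactor, "tcp:host=127.0.0.1:port=8125")
#     service = ReconnectingClientService(endpoint, MyProtocolFactory())
#     service.startService()  # connect now; reconnect with back-off on loss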

vumi/scripts/vumi_redis_tools.py
#!/usr/bin/env python
# -*- test-case-name: vumi.scripts.tests.test_vumi_redis_tools -*-
import re
import sys

import yaml
from twisted.python import usage

from vumi.persist.redis_manager import RedisManager


class TaskError(Exception):
    """Raised when an error is encoutered while using tasks."""


class Task(object):
    """
    A task to perform on a redis key.
    """

    name = None
    hidden = False  # set to True to hide from docs
    runner = None
    redis = None

    @classmethod
    def parse(cls, task_desc):
        """
        Parse a string description into a task.

        Task description format::

          <task-type>[:<param>=<value>[,...]]
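
        For example, ``expire:seconds=3600`` parses into an ``Expire`` task
        with ``seconds=3600``.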
        """
        task_type, _, param_desc = task_desc.partition(':')
        task_cls = cls._parse_task_type(task_type)
        params = {}
        if param_desc:
            params = cls._parse_param_desc(param_desc)
        return task_cls(**params)

    @classmethod
    def task_types(cls):
        return cls.__subclasses__()

    @classmethod
    def _parse_task_type(cls, task_type):
        names = dict((t.name, t) for t in cls.task_types())
        if task_type not in names:
            raise TaskError("Unknown task type %r" % (task_type,))
        return names[task_type]

    @classmethod
    def _parse_param_desc(cls, param_desc):
        params = [x.partition('=') for x in param_desc.split(',')]
        params = [(p, v) for p, _sep, v in params]
        return dict(params)

    def init(self, runner, redis):
        self.runner = runner
        self.redis = redis

    def before(self):
        """Run once before the task applied to any keys."""

    def after(self):
        """Run once afer the task has been applied to all keys."""

    def process_key(self, key):
        """Run once for each key.

        May return either the name of the key (if the key should
        be processed by later tasks), the new name of the key (if
        the key was renamed and should be processed by later tasks)
        or ``None`` (if the key has been deleted or should not be
        processed by further tasks).
        """
        return key


class Count(Task):
    """A task that counts the number of keys."""

    name = "count"

    def __init__(self):
        self.count = None

    def before(self):
        self.count = 0

    def after(self):
        self.runner.emit("Found %d matching keys." % (self.count,))

    def process_key(self, key):
        self.count += 1
        return key


class Expire(Task):
    """A task that sets an expiry time on each key."""

    name = "expire"

    def __init__(self, seconds):
        self.seconds = int(seconds)

    def process_key(self, key):
        self.redis.expire(key, self.seconds)
        return key


class Persist(Task):
    """A task that persists each key."""

    name = "persist"

    def process_key(self, key):
        self.redis.persist(key)
        return key


class ListKeys(Task):
    """A task that prints out each key."""

    name = "list"

    def process_key(self, key):
        self.runner.emit(key)
        return key


class Skip(Task):
    """A task that skips keys that match a regular expression."""

    name = "skip"

    def __init__(self, pattern):
        self.regex = re.compile(pattern)

    def process_key(self, key):
        if self.regex.match(key):
            return None
        return key


class Options(usage.Options):

    synopsis = "  [-t  ...]"

    longdesc = "Perform tasks on Redis keys."

    def __init__(self):
        usage.Options.__init__(self)
        self['tasks'] = []

    def getUsage(self, width=None):
        doc = usage.Options.getUsage(self, width=width)
        header = "Available tasks:"
        tasks = sorted(Task.task_types(), key=lambda t: t.name)
        tasks_doc = "".join(usage.docMakeChunks([{
            'long': task.name,
            'doc': task.__doc__,
        } for task in tasks if not task.hidden]))
        return "\n".join([doc, header, tasks_doc])

    def parseArgs(self, config_file, match_pattern):
        self['config'] = yaml.safe_load(open(config_file))
        self['match_pattern'] = match_pattern

    def opt_task(self, task_desc):
        """A task to perform on all matching keys."""
        task = Task.parse(task_desc)
        self['tasks'].append(task)

    opt_t = opt_task

    def postOptions(self):
        if not self['tasks']:
            raise usage.UsageError("Please specify a task.")

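# Example invocation (a sketch; config.yaml needs a `redis_manager` section
# and the pattern is a Redis MATCH glob):
#
#     python vumi_redis_tools.py config.yaml 'session:*' -t count -t expire:seconds=3600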

def scan_keys(redis, match):
    """Iterate over matching keys."""
    prev_cursor = None
    while True:
        cursor, keys = redis.scan(prev_cursor, match=match)
        for key in keys:
            yield key
        if cursor is None:
            break
        if cursor == prev_cursor:
            raise TaskError("Redis scan stuck on cursor %r" % (cursor,))
        prev_cursor = cursor


class TaskRunner(object):

    stdout = sys.stdout

    def __init__(self, options):
        self.options = options
        self.match_pattern = options['match_pattern']
        self.tasks = options['tasks']
        self.redis = self.get_redis(options['config'])

    def emit(self, s):
        """
        Print the given string and then a newline.
        """
        self.stdout.write(s)
        self.stdout.write("\n")

    def get_redis(self, config):
        """
        Create and return a redis manager.
        """
        redis_config = config.get('redis_manager', {})
        return RedisManager.from_config(redis_config)

    def run(self):
        """
        Apply all tasks to all keys.
        """
        for task in self.tasks:
            task.init(self, self.redis)

        for task in self.tasks:
            task.before()

        for key in scan_keys(self.redis, self.match_pattern):
            for task in self.tasks:
                key = task.process_key(key)
                if key is None:
                    break

        for task in self.tasks:
            task.after()


if __name__ == '__main__':
    try:
        options = Options()
        options.parseOptions()
    except usage.UsageError, errortext:
        print '%s: %s' % (sys.argv[0], errortext)
        print '%s: Try --help for usage details.' % (sys.argv[0])
        sys.exit(1)

    tasks = TaskRunner(options)
    tasks.run()

vumi/scripts/db_backup.py
# -*- test-case-name: vumi.scripts.tests.test_db_backup -*-
import sys
import json
import pkg_resources
import traceback
import re
import time
import calendar
import copy
from datetime import datetime

import yaml
from twisted.python import usage

from vumi.persist.redis_manager import RedisManager
from vumi.errors import ConfigError


def vumi_version():
    vumi = pkg_resources.get_distribution("vumi")
    return str(vumi)


class KeyHandler(object):

    REDIS_TYPES = ('string', 'list', 'set', 'zset', 'hash')

    def __init__(self):
        self._get_handlers = dict((ktype, getattr(self, '%s_get' % ktype))
                                  for ktype in self.REDIS_TYPES)
        self._set_handlers = dict((ktype, getattr(self, '%s_set' % ktype))
                                  for ktype in self.REDIS_TYPES)

    def dump_key(self, redis, key):
        key_type = redis.type(key)
        record = {
            'type': key_type,
            'key': key,
            'value': self._get_handlers[key_type](redis, key),
            'ttl': redis.ttl(key),
        }
        return record

    def restore_key(self, redis, record, ttl_offset=0):
        key, key_type, ttl = record['key'], record['type'], record['ttl']
        if ttl is not None:
            ttl -= ttl_offset
            if ttl <= 0:
                return
        self._set_handlers[key_type](redis, key, record['value'])
        if ttl is not None:
            redis.expire(key, int(round(ttl)))

    def record_okay(self, record):
        if not isinstance(record, dict):
            return False
        for key in ('type', 'key', 'value', 'ttl'):
            if key not in record:
                return False
        return True

    def string_get(self, redis, key):
        return redis.get(key)

    def string_set(self, redis, key, value):
        redis.set(key, value)

    def list_get(self, redis, key):
        return redis.lrange(key, 0, -1)

    def list_set(self, redis, key, value):
        for item in value:
            redis.rpush(key, item)

    def set_get(self, redis, key):
        return sorted(redis.smembers(key))

    def set_set(self, redis, key, value):
        for item in value:
            redis.sadd(key, item)

    def zset_get(self, redis, key):
        return redis.zrange(key, 0, -1, withscores=True)

    def zset_set(self, redis, key, value):
        for item, score in value:
            redis.zadd(key, **{item.encode('utf8'): score})

    def hash_get(self, redis, key):
        return redis.hgetall(key)

    def hash_set(self, redis, key, value):
        redis.hmset(key, value)

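# Each backup line is a single JSON record of the shape produced by
# KeyHandler.dump_key, e.g. (illustrative values):
#
#     {"type": "string", "key": "foo", "value": "bar", "ttl": null}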

class BackupDbsCmd(usage.Options):

    synopsis = " "

    optFlags = [
        ["not-sorted", None, "Don't sort keys when doing backup."],
    ]

    def parseArgs(self, db_config, db_backup):
        self.db_config = yaml.safe_load(open(db_config))
        self.db_backup = open(db_backup, "wb")
        self.redis_config = self.db_config.get('redis_manager', {})

    def header(self, cfg):
        return {
            'vumi_version': vumi_version(),
            'format': 'LF separated JSON',
            'backup_type': 'redis',
            'timestamp': cfg.get_utcnow().isoformat(),
            'sorted': not bool(self['not-sorted']),
            'redis_config': self.redis_config,
        }

    def write_line(self, data):
        self.db_backup.write(json.dumps(data))
        self.db_backup.write("\n")

    def run(self, cfg):
        cfg.emit("Backing up dbs ...")
        redis = cfg.get_redis(self.redis_config)
        key_handler = KeyHandler()
        keys = redis.keys()
        if not self.opts['not-sorted']:
            keys = sorted(keys)
        self.write_line(self.header(cfg))
        for key in keys:
            record = key_handler.dump_key(redis, key)
            self.write_line(record)
        self.db_backup.close()
        cfg.emit("Backed up %d keys." % (len(keys),))


class RestoreDbsCmd(usage.Options):

    synopsis = " "

    optFlags = [
        ["purge", None, "Purge all keys from the redis manager before "
                        "restoring."],
        ["frozen-ttls", None, "Restore TTLs of keys to the same value they "
                              "had when the backup was created, disregarding "
                              "how much time has passed since the backup was "
                              "created. The default is adjust TTLs by the "
                              "amount of time that has passed and to expire "
                              "keys whose TTLs are then zero or negative."],
    ]

    def parseArgs(self, db_config, db_backup):
        self.db_config = yaml.safe_load(open(db_config))
        self.db_backup = open(db_backup, "rb")
        self.redis_config = self.db_config.get('redis_manager', {})

    def check_header(self, header):
        if header is None:
            return None, "Header not found."
        try:
            header = json.loads(header)
        except Exception:
            return None, "Header not JSON."
        if not isinstance(header, dict):
            return None, "Header not JSON dict."
        if 'backup_type' not in header:
            return None, "Header missing backup_type."
        if header['backup_type'] != 'redis':
            return None, "Only redis backup type currently supported."
        return header, None

    def seconds_from_now(self, iso_timestamp):
        seconds_timestamp, _dot, _milliseconds = iso_timestamp.partition('.')
        time_of_backup = time.strptime(seconds_timestamp, "%Y-%m-%dT%H:%M:%S")
        return time.time() - calendar.timegm(time_of_backup)

    def run(self, cfg):
        line_iter = iter(self.db_backup)
        try:
            header = line_iter.next()
        except StopIteration:
            header = None

        header, error = self.check_header(header)
        if error is not None:
            cfg.emit(error)
            cfg.emit("Aborting restore.")
            return

        if self.opts['frozen-ttls']:
            ttl_offset = 0
        else:
            ttl_offset = self.seconds_from_now(header['timestamp'])

        cfg.emit("Restoring dbs ...")
        redis = cfg.get_redis(self.redis_config)
        if self.opts['purge']:
            redis._purge_all()
        key_handler = KeyHandler()
        keys, skipped = 0, 0
        for i, line in enumerate(line_iter):
            try:
                record = json.loads(line)
            except Exception:
                excinfo = sys.exc_info()
                for s in traceback.format_exception(*excinfo):
                    cfg.emit(s)
                skipped += 1
                continue
            if not key_handler.record_okay(record):
                cfg.emit("Skipping bad backup record on line %d." % (i + 1,))
                skipped += 1
                continue
            key_handler.restore_key(redis, record, ttl_offset)
            keys += 1

        cfg.emit("%d keys successfully restored." % keys)
        if skipped != 0:
            cfg.emit("WARNING: %d bad backup lines skipped." % skipped)


class MigrateDbsCmd(usage.Options):

    synopsis = (" "
                " ")

    def parseArgs(self, migration_config, db_backup, migrated_backup):
        self.migration_config = yaml.safe_load(open(migration_config))
        self.db_backup = open(db_backup, "rb")
        self.migrated_backup = open(migrated_backup, "wb")

    def postOptions(self):
        self.rules = self.create_rules(self.migration_config)

    def make_rule_drop(self, kw):
        key_regex = re.compile(kw['key'])

        def rule(record):
            if key_regex.match(record['key']):
                return True, None
            return False, record

        return rule

    def make_rule_rename(self, kw):
        from_regex = re.compile(kw['from'])
        to_template = kw['to']

        def rule(record):
            key = record['key']
            record['key'] = from_regex.sub(to_template, key)
            if record['key'] == key:
                return False, record
            return True, record

        return rule

    def create_rules(self, migration_config):
        rules = []
        for rule in migration_config['rules']:
            kw = rule.copy()
            rule_name = kw.pop('type')
            rule_maker = getattr(self, "make_rule_%s" % rule_name, None)
            if rule_maker is None:
                raise ConfigError("Unknown rule type %r" % rule_name)
            rules.append(rule_maker(kw))
        return rules

    def apply_rules(self, record):
        for rule in self.rules:
            done, record = rule(record)
            if done:
                break
        return record

    def write_line(self, record):
        self.migrated_backup.write(json.dumps(record))
        self.migrated_backup.write("\n")

    def run(self, cfg):
        line_iter = iter(self.db_backup)
        try:
            header = line_iter.next()
        except StopIteration:
            cfg.emit("No header in backup.")
            return
        self.write_line(json.loads(header))

        cfg.emit("Migrating backup ...")
        changed, processed = 0, 0
        for data in line_iter:
            record = json.loads(data)
            new_record = self.apply_rules(copy.deepcopy(record))
            if record != new_record:
                changed += 1
            processed += 1
            if new_record is not None:
                self.write_line(new_record)
        self.migrated_backup.close()

        cfg.emit("Summary of changes:")
        cfg.emit("  %d records processed." % processed)
        cfg.emit("  %d records altered." % changed)

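# A sketch of a migration config for the rule types above (YAML; the regexes
# are illustrative):
#
#     rules:
#       - type: drop
#         key: "^temp:"
#       - type: rename
#         from: "^old:(.*)$"
#         to: "new:\\1"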

class PrefixTree(object):
    def __init__(self):
        self._root = {}

    def add_key(self, key):
        node = self._root
        for c in key:
            if c not in node:
                node[c] = {}
            node = node[c]

    def compress_edges(self, edge_pattern):
        edge_regex = re.compile(edge_pattern)
        new_root = {}

        stack = [(self._root, new_root, '')]
        while stack:
            current_old, current_new, prefix = stack.pop()
            for c, next_old in current_old.iteritems():
                edge = prefix + c
                if edge_regex.match(edge):
                    current_new[edge] = next_new = {}
                    stack.append((next_old, next_new, ''))
                elif next_old:
                    stack.append((next_old, current_new, edge))
                else:
                    current_new[edge] = {}

        self._root = new_root

    def _print_tree(self, emit, indent, node, level):
        full_indent = indent * level
        for edge, node_edge in sorted(node.items()):
            sub_trees = dict((k, v) for k, v in node_edge.items() if v)
            leaves = len(node_edge) - len(sub_trees)
            emit("%s%s%s" % (full_indent, edge,
                             " (%d leaves)" % leaves if leaves else ""))
            if sub_trees:
                self._print_tree(emit, indent, sub_trees, level + 1)

    def print_tree(self, emit, indent="  "):
        return self._print_tree(emit, indent, self._root, 0)

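# Sketch: after add_key("vumi:tag:1") and add_key("vumi:tag:2") followed by
# compress_edges(r".*[:#]"), print_tree emits:
#
#     vumi:
#       tag: (2 leaves)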

class AnalyzeCmd(usage.Options):

    synopsis = ""

    optParameters = [
        ["separators", "s", "[:#]",
         "Regular expression for allowed key part separators."],
    ]

    def parseArgs(self, db_backup):
        self.db_backup = open(db_backup, "rb")

    def run(self, cfg):
        backup_lines = iter(self.db_backup)
        try:
            backup_lines.next()  # skip header
        except StopIteration:
            cfg.emit("No header found. Aborting.")
            return

        tree = PrefixTree()
        for i, line in enumerate(backup_lines):
            try:
                key = json.loads(line)['key']
            except Exception:
                cfg.emit("Bad record %d: %r" % (i, line))
                continue
            tree.add_key(key)
        edge_pattern = r".*%s" % self.opts['separators']
        tree.compress_edges(edge_pattern)

        cfg.emit("Keys:")
        cfg.emit("-----")
        tree.print_tree(cfg.emit)


class Options(usage.Options):
    subCommands = [
        ["backup", None, BackupDbsCmd,
         "Backup databases."],
        ["restore", None, RestoreDbsCmd,
         "Restore databases."],
        ["migrate", None, MigrateDbsCmd,
         "Rename keys in a database backup."],
        ["analyze", None, AnalyzeCmd,
         "Analyze a database backup."],
    ]

    longdesc = """Back-up and restore utility for Vumi
                  Redis (and maybe later Riak) data stores."""

    def postOptions(self):
        if self.subCommand is None:
            raise usage.UsageError("Please specify a sub-command.")


class ConfigHolder(object):
    def __init__(self, options):
        self.options = options

    def emit(self, s):
        print s

    def get_utcnow(self):
        return datetime.utcnow()

    def get_redis(self, config):
        return RedisManager.from_config(config)

    def run(self):
        self.options.subOptions.run(self)


if __name__ == '__main__':
    try:
        options = Options()
        options.parseOptions()
    except usage.UsageError, errortext:
        print '%s: %s' % (sys.argv[0], errortext)
        print '%s: Try --help for usage details.' % (sys.argv[0])
        sys.exit(1)

    cfg = ConfigHolder(options)
    cfg.run()

vumi/scripts/vumi_count_models.py
#!/usr/bin/env python
# -*- test-case-name: vumi.scripts.tests.test_vumi_count_models -*-

import re
import sys

from twisted.internet.defer import inlineCallbacks, succeed
from twisted.internet.task import react
from twisted.python import usage

from vumi.utils import load_class_by_string
from vumi.persist.txriak_manager import TxRiakManager


class Options(usage.Options):
    optParameters = [
        ["model", "m", None,
         "Full Python name of the model class to count."
         " E.g. 'vumi.components.message_store.InboundMessage'."],
        ["bucket-prefix", "b", None,
         "The bucket prefix for the Riak manager."],
        ["index-field", None, None,
         "Field with index to query. If omitted, all keys in the bucket will"
         " be counted and no `index-value-*' parameters are allowed."],
        ["index-value", None, None, "Exact match value or start of range."],
        ["index-value-end", None, None,
         "End of range. If ommitted, an exact match query will be used."],
        ["index-value-regex", None, None, "Regex to filter index values."],
        ["index-page-size", None, "1000",
         "The number of keys to fetch in each index query."],
    ]

    longdesc = """
    Index-based model counter. This makes paginated index queries, optionally
    filters the results by applying a regex to the index value, and returns a
    count of all matching models.
    """

    def ensure_dependent_option(self, needed, needs):
        """
        Raise UsageError if `needs` is provided without `needed`.
        """
        if self[needed] is None and self[needs] is not None:
            raise usage.UsageError("%s requires %s to be specified." % (
                needs, needed))

    def postOptions(self):
        if self["model"] is None:
            raise usage.UsageError("Please specify a model class.")
        if self["bucket-prefix"] is None:
            raise usage.UsageError("Please specify a bucket prefix.")
        self.ensure_dependent_option("index-field", "index-value")
        self.ensure_dependent_option("index-value", "index-field")
        self.ensure_dependent_option("index-value", "index-value-end")
        self.ensure_dependent_option("index-value-end", "index-value-regex")
        self["index-page-size"] = int(self['index-page-size'])

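# Example invocation (a sketch; the index field and values are illustrative):
#
#     python vumi_count_models.py \
#         -m vumi.components.message_store.InboundMessage -b mystore. \
#         --index-field batches --index-value mybatch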

class ProgressEmitter(object):
    """Report progress as the number of items processed to an emitter."""

    def __init__(self, emit, batch_size):
        self.emit = emit
        self.batch_size = batch_size
        self.processed = 0

    def update(self, value):
        if (value / self.batch_size) > (self.processed / self.batch_size):
            self.emit(value)
        self.processed = value

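# For example (a sketch; `show` is any one-argument callable):
#
#     progress = ProgressEmitter(show, 1000)
#     progress.update(900)   # below the first batch boundary: no report
#     progress.update(1500)  # crossed 1000: reports 1500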

class ModelCounter(object):
    def __init__(self, options):
        self.options = options
        model_cls = load_class_by_string(options['model'])
        riak_config = {
            'bucket_prefix': options['bucket-prefix'],
        }
        self.manager = self.get_riak_manager(riak_config)
        self.model = self.manager.proxy(model_cls)

    def cleanup(self):
        return self.manager.close_manager()

    def get_riak_manager(self, riak_config):
        return TxRiakManager.from_config(riak_config)

    def emit(self, s):
        print s

    def count_keys(self, keys, filter_regex):
        """
        Count keys in an index page, filtering by regex if necessary.
        """
        if filter_regex is not None:
            keys = [(v, k) for v, k in keys if filter_regex.match(v)]
        return len(keys)

    @inlineCallbacks
    def count_pages(self, index_page, filter_regex):
        emit_progress = lambda t: self.emit(
            "%s object%s counted." % (t, "" if t == 1 else "s"))
        progress = ProgressEmitter(
            emit_progress, self.options["index-page-size"])
        counted = 0
        while index_page is not None:
            if index_page.has_next_page():
                next_page_d = index_page.next_page()
            else:
                next_page_d = succeed(None)
            counted += self.count_keys(list(index_page), filter_regex)
            progress.update(counted)
            index_page = yield next_page_d
        self.emit("Done, %s object%s found." % (
            counted, "" if counted == 1 else "s"))

    @inlineCallbacks
    def count_all_keys(self):
        """
        Perform an index query to get all keys and count them.
        """
        self.emit("Counting all keys ...")
        index_page = yield self.model.all_keys_page(
            max_results=self.options["index-page-size"])
        yield self.count_pages(index_page, filter_regex=None)

    @inlineCallbacks
    def count_index_keys(self):
        """
        Perform an index query to get all matching keys and count them.
        """
        filter_regex = self.options["index-value-regex"]
        if filter_regex is not None:
            filter_regex = re.compile(filter_regex)
        self.emit("Counting ...")
        index_page = yield self.model.index_keys_page(
            field_name=self.options["index-field"],
            value=self.options["index-value"],
            end_value=self.options["index-value-end"],
            max_results=self.options["index-page-size"],
            return_terms=True)
        yield self.count_pages(index_page, filter_regex=filter_regex)

    def _run(self):
        if self.options["index-field"] is None:
            return self.count_all_keys()
        else:
            return self.count_index_keys()

    @inlineCallbacks
    def run(self):
        try:
            yield self._run()
        finally:
            yield self.cleanup()


def main(_reactor, name, *args):
    try:
        options = Options()
        options.parseOptions(args)
    except usage.UsageError, errortext:
        print '%s: %s' % (name, errortext)
        print '%s: Try --help for usage details.' % (name,)
        sys.exit(1)

    model_counter = ModelCounter(options)
    return model_counter.run()


if __name__ == '__main__':
    react(main, sys.argv)

vumi/scripts/vumi_tagpools.py
#!/usr/bin/env python
# -*- test-case-name: vumi.scripts.tests.test_vumi_tagpools -*-
import sys
import re
import itertools

import yaml
from twisted.python import usage

from vumi.components.tagpool import TagpoolManager
from vumi.persist.redis_manager import RedisManager


class PoolSubCmd(usage.Options):

    synopsis = ""

    def parseArgs(self, pool):
        self.pool = pool


class CreatePoolCmd(PoolSubCmd):
    def run(self, cfg):
        local_tags = cfg.tags(self.pool)
        tags = [(self.pool, local_tag) for local_tag in local_tags]
        metadata = cfg.metadata(self.pool)

        cfg.emit("Creating pool %s ..." % self.pool)
        cfg.emit("  Setting metadata ...")
        cfg.tagpool.set_metadata(self.pool, metadata)
        cfg.emit("  Declaring %d tag(s) ..." % len(tags))
        cfg.tagpool.declare_tags(tags)
        cfg.emit("  Done.")


class UpdatePoolMetadataCmd(PoolSubCmd):
    def run(self, cfg):
        metadata = cfg.metadata(self.pool)

        cfg.emit("Updating metadata for pool %s ..." % self.pool)
        cfg.tagpool.set_metadata(self.pool, metadata)
        cfg.emit("  Done.")


class UpdateAllPoolMetadataCmd(usage.Options):
    def run(self, cfg):
        pools_in_tagpool = cfg.tagpool.list_pools()
        pools_in_cfg = set(cfg.pools.keys())
        pools_in_both = sorted(pools_in_tagpool.intersection(pools_in_cfg))

        cfg.emit("Updating pool metadata.")
        cfg.emit("Note: Pools not present in both the config and tagpool"
                 " store will not be updated.")

        if not pools_in_both:
            cfg.emit("No pools found.")
            return

        for pool in pools_in_both:
            cfg.emit("  Updating metadata for pool %s ..." % pool)
            metadata = cfg.metadata(pool)
            cfg.tagpool.set_metadata(pool, metadata)

        cfg.emit("Done.")


class PurgePoolCmd(PoolSubCmd):
    def run(self, cfg):
        cfg.emit("Purging pool %s ..." % self.pool)
        cfg.tagpool.purge_pool(self.pool)
        cfg.emit("  Done.")


def key_ranges(keys):
    """Take a list of keys and convert them to a compact
    output string.

    E.g. foo100, foo101, ..., foo200, foo300
         becomes
         foo[100-200], foo300
    """
    keys.sort()
    last_digits_re = re.compile("^(?P
()|(.*[^\d]))(?P\d+)"
                                "(?P.*)$")

    def group(x):
        i, key = x
        match = last_digits_re.match(key)
        if not match:
            return None
        pre, post = match.group('pre'), match.group('post')
        digits = match.group('digits')
        dlen, value = len(digits), int(digits)
        return pre, post, dlen, value - i

    key_ranges = []
    for grp_key, grp_list in itertools.groupby(enumerate(keys), group):
        grp_list = list(grp_list)
        if len(grp_list) == 1 or grp_key is None:
            key_ranges.extend(g[1] for g in grp_list)
        else:
            pre, post, dlen, _cnt = grp_key
            start = last_digits_re.match(grp_list[0][1]).group('digits')
            end = last_digits_re.match(grp_list[-1][1]).group('digits')
            key_range = "%s[%s-%s]%s" % (pre, start, end, post)
            key_ranges.append(key_range)

    return ", ".join(key_ranges)


class ListKeysCmd(PoolSubCmd):
    def run(self, cfg):
        free_tags = cfg.tagpool.free_tags(self.pool)
        inuse_tags = cfg.tagpool.inuse_tags(self.pool)
        cfg.emit("Listing tags for pool %s ..." % self.pool)
        cfg.emit("Free tags:")
        cfg.emit("   " + (key_ranges([tag[1] for tag in free_tags])
                          or "-- None --"))
        cfg.emit("Tags in use:")
        cfg.emit("   " + (key_ranges([tag[1] for tag in inuse_tags])
                          or "-- None --"))


class ListPoolsCmd(usage.Options):
    def run(self, cfg):
        pools_in_tagpool = cfg.tagpool.list_pools()
        pools_in_cfg = set(cfg.pools.keys())
        cfg.emit("Pools defined in cfg and tagpool:")
        cfg.emit("   " +
                 ', '.join(sorted(pools_in_tagpool.intersection(pools_in_cfg))
                           or ['-- None --']))
        cfg.emit("Pools only in cfg:")
        cfg.emit("   " +
                 ', '.join(sorted(pools_in_cfg.difference(pools_in_tagpool))
                           or ['-- None --']))
        cfg.emit("Pools only in tagpool:")
        cfg.emit("   " +
                 ', '.join(sorted(pools_in_tagpool.difference(pools_in_cfg))
                           or ['-- None --']))


class ReleaseTagCmd(usage.Options):

    synopsis = " "

    def parseArgs(self, pool, tag):
        self.pool = pool
        self.tag = tag

    def run(self, cfg):
        free_tags = cfg.tagpool.free_tags(self.pool)
        inuse_tags = cfg.tagpool.inuse_tags(self.pool)
        tag_tuple = (self.pool, self.tag)
        if tag_tuple not in inuse_tags:
            if tag_tuple not in free_tags:
                cfg.emit('Unknown tag %s.' % (tag_tuple,))
            else:
                cfg.emit('Tag %s not in use.' % (tag_tuple,))
        else:
            cfg.tagpool.release_tag(tag_tuple)
            cfg.emit('Released %s.' % (tag_tuple,))


class Options(usage.Options):
    subCommands = [
        ["create-pool", None, CreatePoolCmd,
         "Declare tags for a tag pool."],
        ["update-pool-metadata", None, UpdatePoolMetadataCmd,
         "Update a pool's metadata from config."],
        ["update-all-metadata", None, UpdateAllPoolMetadataCmd,
         "Update all pool meta data from config."],
        ["purge-pool", None, PurgePoolCmd,
         "Purge all tags from a tag pool."],
        ["list-keys", None, ListKeysCmd,
         "List the free and inuse keys associated with a tag pool."],
        ["list-pools", None, ListPoolsCmd,
         "List all pools defined in config and in the tag store."],
        ["release-tag", None, ReleaseTagCmd,
         "Release a single tag, moves it from the in-use to the free set. "
         "Use only if you know what you are doing."]
    ]

    optParameters = [
        ["config", "c", "tagpools.yaml",
         "A config file describing the available pools."],
    ]

    longdesc = """Utilities for working with
                  vumi.application.TagPoolManager."""

    def postOptions(self):
        if self.subCommand is None:
            raise usage.UsageError("Please specify a sub-command.")


class ConfigHolder(object):
    def __init__(self, options):
        self.options = options
        self.config = yaml.safe_load(open(options['config'], "rb"))
        self.pools = self.config.get('pools', {})
        redis = RedisManager.from_config(self.config.get('redis_manager', {}))
        self.tagpool = TagpoolManager(redis.sub_manager(
                self.config.get('tagpool_prefix', 'vumi')))

    def emit(self, s):
        print s

    def tags(self, pool):
        tags = self.pools[pool]['tags']
        if isinstance(tags, basestring):
            tags = eval(tags, {}, {})
        return tags

    def metadata(self, pool):
        return self.pools[pool].get('metadata', {})

    def run(self):
        self.options.subOptions.run(self)

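# A sketch of a tagpools.yaml config (pool name, tag range and metadata are
# illustrative; a string `tags` value is eval'd as a Python expression):
#
#     redis_manager: {}
#     pools:
#       shortcode_pool:
#         tags: "['shortcode%d' % i for i in range(100, 110)]"
#         metadata:
#           transport_type: sms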

if __name__ == '__main__':
    try:
        options = Options()
        options.parseOptions()
    except usage.UsageError, errortext:
        print '%s: %s' % (sys.argv[0], errortext)
        print '%s: Try --help for usage details.' % (sys.argv[0])
        sys.exit(1)

    cfg = ConfigHolder(options)
    cfg.run()

vumi/scripts/benchmark_persist.py
# -*- test-case-name: vumi.scripts.tests.test_benchmark_persist -*-
import sys
import time
from twisted.python import usage
from twisted.internet import reactor
from twisted.internet.defer import maybeDeferred, inlineCallbacks, DeferredList

from vumi.message import TransportUserMessage
from vumi.persist.model import Model
from vumi.persist.txriak_manager import TxRiakManager
from vumi.persist.fields import VumiMessage


class Options(usage.Options):
    optParameters = [
        ["messages", "m", "1000",
         "Total number of messages to write and read back."],
        ["concurrent-messages", "c", "100",
         "Number of messages to read and write concurrently"],
    ]

    longdesc = """Benchmarks vumi.persist.model.Model"""

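# Example invocation (a sketch; requires a reachable Riak configured with the
# 'test.bench.' bucket prefix used below):
#
#     python benchmark_persist.py -m 1000 -c 100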

class MessageModel(Model):
    msg = VumiMessage(TransportUserMessage)


class WriteReadBenchmark(object):
    """
    Writes messages to Riak and then reads them back.
    """

    def __init__(self, options):
        self.messages = int(options['messages'])
        self.concurrent = int(options['concurrent-messages'])

    def make_batches(self):
        num_batches, rem = divmod(self.messages, self.concurrent)
        batches = [self.make_batch(i, self.concurrent)
                   for i in range(num_batches)]
        if rem:
            batches.append(self.make_batch(num_batches, rem))
        return batches

    def make_batch(self, batch_no, num_msgs):
        return [TransportUserMessage(to_addr="1234", from_addr="5678",
                    transport_name="bench", transport_type="sms",
                    content="Batch: %d. Msg: %d" % (batch_no, i))
                for i in range(num_msgs)]

    def write_batch(self, model, msgs):
        print "  Writing %d messages." % len(msgs)
        deferreds = []
        for msg in msgs:
            msg_obj = model(key=msg['message_id'], msg=msg)
            deferreds.append(msg_obj.save())
        return DeferredList(deferreds)

    def read_batch(self, model, msgs):
        print "  Reading %d messages." % len(msgs)
        deferreds = []
        for msg in msgs:
            deferreds.append(model.load(msg['message_id']))
        return DeferredList(deferreds)

    @inlineCallbacks
    def run(self):
        manager = TxRiakManager.from_config({'bucket_prefix': 'test.bench.'})
        model = manager.proxy(MessageModel)
        yield manager.purge_all()

        msg_batches = self.make_batches()

        start = time.time()

        for batch in msg_batches:
            yield self.write_batch(model, batch)

        write_done = time.time()
        write_time = write_done - start
        print "Write took %.2f seconds (%.2f msgs/s)" % (
                write_time, self.messages / write_time)

        result_batches = []
        for batch in msg_batches:
            r = yield self.read_batch(model, batch)
            result_batches.append(r)

        read_done = time.time()
        read_time = read_done - write_done
        print "Read took %.2f seconds (%.2f msgs/s)" % (
                read_time, self.messages / read_time)

        for batch, result_batch in zip(msg_batches, result_batches):
            for msg, (good, stored_msg) in zip(batch, result_batch):
                if not good or stored_msg is None:
                    raise RuntimeError("Failed to retrieve message (%s)"
                                       % msg['content'])
                if not (msg == stored_msg.msg):  # TODO: fix message !=
                    raise RuntimeError("Message %r does not equal stored"
                                       " message %r" % (msg, stored_msg.msg))

        print "Messages retrieved successfully."

        yield manager.purge_all()
        print "Messages purged."

if __name__ == '__main__':
    try:
        options = Options()
        options.parseOptions()
    except usage.UsageError, errortext:
        print '%s: %s' % (sys.argv[0], errortext)
        print '%s: Try --help for usage details.' % (sys.argv[0])
        sys.exit(1)

    bench = WriteReadBenchmark(options)

    def _eb(f):
        f.printTraceback()

    def _main():
        d = maybeDeferred(bench.run)
        d.addErrback(_eb)
        d.addBoth(lambda _: reactor.stop())

    reactor.callLater(0, _main)
    reactor.run()

vumi/scripts/model_migrator.py
# -*- test-case-name: vumi.scripts.tests.test_model_migrator -*-
import sys

from twisted.python import usage

from vumi.utils import load_class_by_string
from vumi.persist.riak_manager import RiakManager


class Options(usage.Options):
    optParameters = [
        ["model", "m", None,
         "Full Python name of the model class to migrate."
         " E.g. 'vumi.components.message_store.InboundMessage'."],
        ["bucket-prefix", "b", None,
         "The bucket prefix for the Riak manager."],
        ["keys", None, None,
         "Migrate these specific keys rather than the whole bucket."
         " E.g. --keys 'foo,bar,baz'"],
    ]

    optFlags = [
        ["dry-run", None, "Don't save anything back to Riak."],
    ]

    longdesc = """Offline model migrator. Necessary for updating
                  models when index names change so that old model
                  instances remain findable by index searches.
                  """

    def postOptions(self):
        if self['model'] is None:
            raise usage.UsageError("Please specify a model class.")
        if self['bucket-prefix'] is None:
            raise usage.UsageError("Please specify a bucket prefix.")

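# Example invocation (a sketch; the model path comes from the option help
# above and the bucket prefix is illustrative):
#
#     python model_migrator.py \
#         -m vumi.components.message_store.InboundMessage -b mystore. --dry-run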

class ProgressEmitter(object):
    """Report progress as a percentage to an emitter."""

    def __init__(self, total, emit):
        self.emit = emit
        self.total = total
        self.percentage = 0

    def _calculate_percentage(self, value):
        if value == 0:
            return 0
        return int(value * 100.0 / self.total)

    def update(self, value):
        old_percentage = self.percentage
        self.percentage = self._calculate_percentage(value)
        if self.percentage != old_percentage:
            self.emit(self.percentage)


class ModelMigrator(object):
    def __init__(self, options):
        self.options = options
        model_cls = load_class_by_string(options['model'])
        riak_config = {
            'bucket_prefix': options['bucket-prefix'],
        }
        self.manager = self.get_riak_manager(riak_config)
        self.model = self.manager.proxy(model_cls)

    def cleanup(self):
        self.manager.close_manager()

    def get_riak_manager(self, riak_config):
        return RiakManager.from_config(riak_config)

    def emit(self, s):
        print s

    def _run(self):
        dry_run = self.options["dry-run"]
        if self.options["keys"] is not None:
            keys = self.options["keys"].split(",")
            self.emit("Migrating %d specified keys ..." % len(keys))
        else:
            keys = self.model.all_keys()
            self.emit("%d keys found. Migrating ..." % len(keys))
        # Depending on our Riak client, Python version, and JSON library we may
        # get bytes or unicode here.
        keys = [k.decode('utf-8') if isinstance(k, str) else k for k in keys]
        progress = ProgressEmitter(
            len(keys),
            lambda p: self.emit("%s%% complete." % (p,))
        )
        for i, key in enumerate(keys):
            try:
                obj = self.model.load(key)
                if obj is not None:
                    if not dry_run:
                        obj.save()
                else:
                    self.emit("Skipping tombstone key %r." % (key,))
            except Exception, e:
                self.emit("Failed to migrate key %r:" % (key,))
                self.emit("  %s: %s" % (type(e).__name__, e))
            progress.update(i)
        self.emit("Done.")

    def run(self):
        try:
            self._run()
        finally:
            self.cleanup()


if __name__ == '__main__':
    try:
        options = Options()
        options.parseOptions()
    except usage.UsageError, errortext:
        print '%s: %s' % (sys.argv[0], errortext)
        print '%s: Try --help for usage details.' % (sys.argv[0])
        sys.exit(1)

    cfg = ModelMigrator(options)
    cfg.run()
# ==== vumi/scripts/__init__.py ==== (empty file)

# ==== vumi/scripts/inject_messages.py ====
# -*- test-case-name: vumi.scripts.tests.test_inject_messages -*-
import sys
import json
from twisted.python import usage
from twisted.internet import reactor, threads
from twisted.internet.defer import (maybeDeferred, DeferredQueue,
                                    inlineCallbacks)
from vumi.message import TransportUserMessage
from vumi.service import Worker, WorkerCreator
from vumi.servicemaker import VumiOptions
from vumi.utils import to_kwargs


class InjectorOptions(VumiOptions):
    optParameters = [
        ["transport-name", None, None,
            "Name of the transport to inject messages for"],
        ["direction", None, "inbound",
            "Direction to inject messages in ('inbound' or 'outbound')"],
    ]

    optFlags = [
        ["verbose", "v", "Output the JSON being injected"],
    ]

    def postOptions(self):
        VumiOptions.postOptions(self)
        if not self['transport-name']:
            raise usage.UsageError("Please provide the "
                                    "transport-name parameter.")


class MessageInjector(Worker):

    WORKER_QUEUE = DeferredQueue()

    @inlineCallbacks
    def startWorker(self):
        self.transport_name = self.config['transport-name']
        self.direction = self.config['direction']
        self.publisher = yield self.publish_to(
            '%s.%s' % (self.transport_name, self.direction))
        self.WORKER_QUEUE.put(self)

    def process_file(self, in_file, out_file=None):
        return threads.deferToThread(self._process_file_in_thread,
                                     in_file, out_file)

    def _process_file_in_thread(self, in_file, out_file):
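        # Runs in a worker thread: each line is bounced back to the
        # reactor thread for publishing, and this loop blocks until the
        # publish call completes.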
        for line in in_file:
            line = line.strip()
            self.emit(out_file, line)
            threads.blockingCallFromThread(reactor, self.process_line, line)

    def emit(self, out_file, obj):
        if out_file is not None:
            out_file.write('%s\n' % (obj,))

    def process_line(self, line):
        data = {
            'transport_name': self.transport_name,
            'transport_metadata': {},
        }
        data.update(json.loads(line))
        self.publisher.publish_message(
            TransportUserMessage(**to_kwargs(data)))
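

# Hedged example of one stdin line this worker consumes (values are
# illustrative; transport_name and transport_metadata are supplied by
# process_line above):
#
#   {"to_addr": "1234", "from_addr": "+27831234567", "content": "hello",
#    "transport_type": "sms"}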


@inlineCallbacks
def main(options):
    verbose = options['verbose']

    worker_creator = WorkerCreator(options.vumi_options)
    service = worker_creator.create_worker_by_class(
        MessageInjector, options)
    yield service.startService()

    in_file = sys.stdin
    out_file = sys.stdout if verbose else None

    worker = yield MessageInjector.WORKER_QUEUE.get()
    yield worker.process_file(in_file, out_file)


if __name__ == '__main__':
    try:
        options = InjectorOptions()
        options.parseOptions()
    except usage.UsageError, errortext:
        print '%s: %s' % (sys.argv[0], errortext)
        print '%s: Try --help for usage details.' % (sys.argv[0])
        sys.exit(1)

    def _eb(f):
        f.printTraceback()

    def _main():
        d = maybeDeferred(main, options)
        d.addErrback(_eb)
        d.addBoth(lambda _: reactor.stop())

    reactor.callLater(0, _main)
    reactor.run()
# ==== vumi/scripts/vumi_list_messages.py ====
#!/usr/bin/env python
# -*- test-case-name: vumi.scripts.tests.test_vumi_list_messages -*-

import sys

from twisted.internet.defer import inlineCallbacks
from twisted.internet.task import react
from twisted.python import usage

from vumi.components.message_store import MessageStore
from vumi.persist.txriak_manager import TxRiakManager


class Options(usage.Options):
    optParameters = [
        ["batch", None, None,
         "Batch identifier to list messages for."],
        ["bucket-prefix", "b", None,
         "The bucket prefix for the Riak manager."],
        ["direction", None, None,
         "Message direction. Valid values are `inbound' and `outbound'."],
        ["index-page-size", None, "1000",
         "The number of keys to fetch in each index query."],
    ]

    longdesc = """
    Index-based message store lister. For each message, the timestamp, remote
    address, and message_id are returned in a comma-separated format.
    """

    def postOptions(self):
        if self["batch"] is None:
            raise usage.UsageError("Please specify a batch.")
        if self["direction"] not in ["inbound", "outbound"]:
            raise usage.UsageError("Please specify a valid direction.")
        if self["bucket-prefix"] is None:
            raise usage.UsageError("Please specify a bucket prefix.")
        self["index-page-size"] = int(self['index-page-size'])


class MessageLister(object):
    def __init__(self, options):
        self.options = options
        riak_config = {
            'bucket_prefix': options['bucket-prefix'],
        }
        self.manager = self.get_riak_manager(riak_config)
        self.mdb = MessageStore(self.manager, None)

    def cleanup(self):
        return self.manager.close_manager()

    def get_riak_manager(self, riak_config):
        return TxRiakManager.from_config(riak_config)

    def emit(self, s):
        print s

    @inlineCallbacks
    def list_pages(self, index_page):
        while index_page is not None:
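            # Kick off the next index query immediately so the fetch
            # overlaps with emitting this page's rows.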
            next_page_d = index_page.next_page()
            for message_id, timestamp, addr in index_page:
                self.emit(",".join([timestamp, addr, message_id]))
            index_page = yield next_page_d

    @inlineCallbacks
    def _run(self):
        index_func = {
            "inbound": self.mdb.batch_inbound_keys_with_addresses,
            "outbound": self.mdb.batch_outbound_keys_with_addresses,
        }[self.options["direction"]]
        index_page = yield index_func(
            self.options["batch"], max_results=self.options["index-page-size"])
        yield self.list_pages(index_page)

    @inlineCallbacks
    def run(self):
        try:
            yield self._run()
        finally:
            yield self.cleanup()


def main(_reactor, name, *args):
    try:
        options = Options()
        options.parseOptions(args)
    except usage.UsageError, errortext:
        print '%s: %s' % (name, errortext)
        print '%s: Try --help for usage details.' % (name,)
        sys.exit(1)

    lister = MessageLister(options)
    return lister.run()


if __name__ == '__main__':
    react(main, sys.argv)
# ==== vumi/scripts/vumi_model_migrator.py ====
#!/usr/bin/env python
# -*- test-case-name: vumi.scripts.tests.test_vumi_model_migrator -*-
import sys

from twisted.internet.defer import inlineCallbacks, gatherResults, succeed
from twisted.internet.task import react
from twisted.python import usage

from vumi.utils import load_class_by_string
from vumi.persist.txriak_manager import TxRiakManager


class Options(usage.Options):
    optParameters = [
        ["model", "m", None,
         "Full Python name of the model class to migrate."
         " E.g. 'vumi.components.message_store.InboundMessage'."],
        ["bucket-prefix", "b", None,
         "The bucket prefix for the Riak manager."],
        ["keys", None, None,
         "Migrate these specific keys rather than the whole bucket."
         " E.g. --keys 'foo,bar,baz'"],
        ["concurrent-migrations", None, "20",
         "The number of concurrent migrations to perform."],
        ["index-page-size", None, "1000",
         "The number of keys to fetch in each index query."],
        ["continuation-token", None, None,
         "A continuation token for resuming an interrupted migration."],
        ["post-migrate-function", None, None,
         "Full Python name of a callable to post-process each migrated object."
         " Should update the model object and return a (possibly deferred)"
         " boolean to indicate whether the object has been modified."],
    ]

    optFlags = [
        ["dry-run", None, "Don't save anything back to Riak."],
    ]

    longdesc = """Offline model migrator. Necessary for updating
                  models when index names change so that old model
                  instances remain findable by index searches.
                  """

    def postOptions(self):
        if self['model'] is None:
            raise usage.UsageError("Please specify a model class.")
        if self['bucket-prefix'] is None:
            raise usage.UsageError("Please specify a bucket prefix.")
        self['concurrent-migrations'] = int(self['concurrent-migrations'])
        self['index-page-size'] = int(self['index-page-size'])
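

# A hedged invocation sketch (names are illustrative):
#
#   python vumi/scripts/vumi_model_migrator.py \
#       -m vumi.components.message_store.InboundMessage \
#       -b mybucketprefix --concurrent-migrations 20 --index-page-size 1000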


class ProgressEmitter(object):
    """Report progress as the number of items processed to an emitter."""

    def __init__(self, emit, batch_size):
        self.emit = emit
        self.batch_size = batch_size
        self.processed = 0

    def update(self, value):
        if (value / self.batch_size) > (self.processed / self.batch_size):
            self.emit(value)
        self.processed = value
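

def _example_progress_emitter():  # pragma: no cover
    # Hypothetical helper for illustration only: with a batch size of 2,
    # an emit fires whenever the processed count crosses a multiple of
    # the batch size (integer division under Python 2).
    lines = []
    progress = ProgressEmitter(lines.append, 2)
    for processed in [1, 2, 3, 4, 5]:
        progress.update(processed)
    return lines  # [2, 4]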


class FakeIndexPage(object):
    def __init__(self, keys, page_size):
        self._keys = keys
        self._page_size = page_size

    def __iter__(self):
        return iter(self._keys[:self._page_size])

    def has_next_page(self):
        return len(self._keys) > self._page_size

    def next_page(self):
        return succeed(
            type(self)(self._keys[self._page_size:], self._page_size))
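

def _example_fake_index_page():  # pragma: no cover
    # Hypothetical helper for illustration only: FakeIndexPage serves a
    # fixed key list in page-sized slices, mimicking a real index page.
    page = FakeIndexPage(["k1", "k2", "k3"], 2)
    assert list(page) == ["k1", "k2"]
    assert page.has_next_page()
    # next_page() returns an already-fired Deferred with the next page.
    return page.next_page().addCallback(list)  # fires with ["k3"]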


class ModelMigrator(object):
    def __init__(self, options):
        self.options = options
        model_cls = load_class_by_string(options['model'])
        riak_config = {
            'bucket_prefix': options['bucket-prefix'],
        }
        self.manager = self.get_riak_manager(riak_config)
        self.model = self.manager.proxy(model_cls)

        # The default post-migrate-function does nothing and returns True if
        # and only if the object was migrated.
        self.post_migrate_function = lambda obj: obj.was_migrated
        if options['post-migrate-function'] is not None:
            self.post_migrate_function = load_class_by_string(
                options['post-migrate-function'])

    def cleanup(self):
        return self.manager.close_manager()

    def get_riak_manager(self, riak_config):
        return TxRiakManager.from_config(riak_config)

    def emit(self, s):
        print s

    @inlineCallbacks
    def migrate_key(self, key, dry_run):
        try:
            obj = yield self.model.load(key)
            if obj is not None:
                should_save = yield self.post_migrate_function(obj)
                if should_save and not dry_run:
                    yield obj.save()
            else:
                self.emit("Skipping tombstone key %r." % (key,))
        except Exception, e:
            self.emit("Failed to migrate key %r:" % (key,))
            self.emit("  %s: %s" % (type(e).__name__, e))

    @inlineCallbacks
    def migrate_keys(self, _result, keys_list, dry_run):
        """
        Migrate keys from `keys_list` until there are none left.

        This method is expected to be called multiple times concurrently with
        all instances sharing the same `keys_list`.
        """
        # keys_list is a shared mutable list, so we can't just iterate over it.
        while keys_list:
            key = keys_list.pop(0)
            yield self.migrate_key(key, dry_run)

    def migrate_page(self, keys, dry_run):
        # Depending on our Riak client, Python version, and JSON library we may
        # get bytes or unicode here.
        keys = [k.decode('utf-8') if isinstance(k, str) else k for k in keys]
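        # Launch the configured number of workers; each pops keys off the
        # shared list until it is empty, bounding concurrency per page.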
        return gatherResults([
            self.migrate_keys(None, keys, dry_run)
            for _ in xrange(self.options["concurrent-migrations"])])

    @inlineCallbacks
    def migrate_pages(self, index_page, emit_progress):
        dry_run = self.options["dry-run"]
        progress = ProgressEmitter(
            emit_progress, self.options["index-page-size"])
        processed = 0
        while index_page is not None:
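            # Start fetching the next page (if any) before migrating this
            # one, so the index query overlaps with the migration work.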
            if index_page.has_next_page():
                next_page_d = index_page.next_page()
            else:
                next_page_d = succeed(None)
            keys = list(index_page)
            yield self.migrate_page(keys, dry_run)
            processed += len(keys)
            progress.update(processed)
            continuation = getattr(index_page, 'continuation', None)
            if continuation is not None:
                self.emit("Continuation token: '%s'" % (continuation,))
            index_page = yield next_page_d
        self.emit("Done, %s object%s migrated." % (
            processed, "" if processed == 1 else "s"))

    def migrate_specified_keys(self, keys):
        """
        Migrate specified keys.
        """
        self.emit("Migrating %d specified keys ..." % len(keys))
        emit_progress = lambda t: self.emit(
            "%s of %s objects migrated." % (t, len(keys)))
        index_page = FakeIndexPage(keys, self.options["index-page-size"])
        return self.migrate_pages(index_page, emit_progress)

    @inlineCallbacks
    def migrate_all_keys(self, continuation=None):
        """
        Perform an index query to get all keys and migrate them.

        If `continuation` is provided, it will be used as the starting point
        for the query.
        """
        self.emit("Migrating ...")
        emit_progress = lambda t: self.emit(
            "%s object%s migrated." % (t, "" if t == 1 else "s"))
        index_page = yield self.model.all_keys_page(
            max_results=self.options["index-page-size"],
            continuation=continuation)
        yield self.migrate_pages(index_page, emit_progress)

    def _run(self):
        if self.options["keys"] is not None:
            return self.migrate_specified_keys(self.options["keys"].split(","))
        else:
            return self.migrate_all_keys(self.options["continuation-token"])

    def run(self):
        return self._run().addBoth(lambda _: self.cleanup())


def main(_reactor, name, *args):
    try:
        options = Options()
        options.parseOptions(args)
    except usage.UsageError, errortext:
        print '%s: %s' % (name, errortext)
        print '%s: Try --help for usage details.' % (name,)
        sys.exit(1)

    model_migrator = ModelMigrator(options)
    return model_migrator.run()


if __name__ == '__main__':
    react(main, sys.argv)
# ==== vumi/scripts/parse_log_messages.py ====
# -*- test-case-name: vumi.scripts.tests.test_parse_log_messages -*-
import sys
import re
import json
import warnings
from twisted.python import usage
from twisted.internet import reactor
from twisted.internet.defer import maybeDeferred
import datetime
from vumi.message import to_json


DATE_PATTERN = re.compile(
    r'(?P<year>\d{4})-(?P<month>\d{2})-(?P<day>\d{2}) '
    r'(?P<hour>\d{2}):(?P<minute>\d{2}):(?P<second>\d{2})')
LOG_PATTERN = {
    'vumi': re.compile(
        r'(?P<date>[\d\-\:\s]+)\+0000 .* '
        r'Inbound: (?P<message>.*)'),
    'smpp_inbound': re.compile(
        r'(?P<date>[\d\-\:\s]+)\+0000 .* '
        r'PUBLISHING INBOUND: (?P<message>.*)'),
    'smpp_outbound': re.compile(
        r'(?P<date>[\d\-\:\s]+)\+0000 .* '
        r'Consumed outgoing message (?P<message>.*)'),
    'dispatcher_inbound_message': re.compile(
        r'(?P<date>[\d\-\:\s]+)\+0000 Processed inbound message'
        r' for [a-zA-Z0-9_]+: (?P<message>.*)'),
    'dispatcher_outbound_message': re.compile(
        r'(?P<date>[\d\-\:\s]+)\+0000 Processed outbound message'
        r' for [a-zA-Z0-9_]+: (?P<message>.*)'),
    'dispatcher_event': re.compile(
        r'(?P<date>[\d\-\:\s]+)\+0000 Processed event message'
        r' for [a-zA-Z0-9_]+: (?P<message>.*)'),
}
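
# Example of a line the 'smpp_inbound' pattern matches (taken from the
# sample log shipped with the tests):
#
#   2011-11-15 00:23:52+0000 [EsmeTransceiver,client] PUBLISHING INBOUND:
#   {'content': u'CODE1', 'transport_type': 'sms', ...}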


class Options(usage.Options):
    optParameters = [
        ["from", "f", None,
         "Ignore any log lines prior to timestamp [YYYY-MM-DD HH:MM:SS]"],
        ["until", "u", None,
         "Ignore any log lines after timestamp [YYYY-MM-DD HH:MM:SS]"],
        ["format", None, "vumi",
         "Message format, one of: [vumi, smpp] (default vumi)"],
    ]

    longdesc = """Parses inbound messages logged by a Vumi worker from stdin
    and outputs them as JSON encoded Vumi messages to stdout. Useful
    along with the `inject_messages.py` script to replay failed inbound
    messages. The two formats supported currently are 'vumi' (which is a
    simple custom format used by some third-party workers) and 'smpp' (which is
    used for logging inbound messages by the SMPP transport).
    """


def parse_date(string, pattern):
    match = pattern.match(string)
    if match:
        return dict((k, int(v)) for k, v in match.groupdict().items())
    return {}
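

def _example_parse_date():  # pragma: no cover
    # Hypothetical helper for illustration only: parse_date turns a
    # timestamp prefix into keyword arguments for datetime.datetime.
    parts = parse_date("2012-04-12 10:52:23", DATE_PATTERN)
    # parts == {'year': 2012, 'month': 4, 'day': 12,
    #           'hour': 10, 'minute': 52, 'second': 23}
    return datetime.datetime(**parts)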


class LogParser(object):
    """
    Parses Vumi TransportUserMessages from a log file and writes
    simple JSON serialized Vumi TransportUserMessages to stdout.

    Regular expression may be passed in to specify log and date
    format.

    Two common output formats are the one used by the SMPP transport for
    logging inbound messages:

    `YYYY-MM-DD HH:MM:SS+0000 [...] PUBLISHING INBOUND: <message>`

    and the one used by some Vumi campaign workers:

    `YYYY-MM-DD HH:MM:SS+0000 [...] Inbound: <message>`
    """

    def __init__(self, options, date_pattern=None, log_pattern=None):
        self.date_pattern = date_pattern or DATE_PATTERN
        if options['format'] == 'smpp':
            warnings.warn(
                'smpp format is deprecated, use smpp_inbound instead',
                category=DeprecationWarning)
            options['format'] = 'smpp_inbound'

        self.log_pattern = log_pattern or LOG_PATTERN.get(options['format'])
        self.start = options['from']
        if self.start:
            self.start = datetime.datetime(**parse_date(self.start,
                                            self.date_pattern))
        self.stop = options['until']
        if self.stop:
            self.stop = datetime.datetime(**parse_date(self.stop,
                                            self.date_pattern))
        self.parse()

    def parse(self):
        while True:
            line = sys.stdin.readline()
            if not line:
                break
            self.readline(line)

    def emit(self, obj):
        sys.stdout.write('%s\n' % (obj,))

    def readline(self, line):
        match = self.log_pattern.match(line)
        if match:
            data = match.groupdict()
            date = datetime.datetime(**parse_date(data['date'],
                                        self.date_pattern))
            if self.start and self.start > date:
                return
            if self.stop and date > self.stop:
                return
            try:
                # JSON-encoded message payload.
                self.emit(to_json(json.loads(data['message'])))
            except ValueError:
                # Fall back to a raw Python dict repr being printed.
                self.emit(to_json(eval(data['message'])))


if __name__ == '__main__':
    try:
        options = Options()
        options.parseOptions()
    except usage.UsageError, errortext:
        print '%s: %s' % (sys.argv[0], errortext)
        print '%s: Try --help for usage details.' % (sys.argv[0])
        sys.exit(1)

    def _eb(f):
        f.printTraceback()

    def _main():
        d = maybeDeferred(LogParser, options)
        d.addErrback(_eb)
        d.addCallback(lambda _: reactor.stop())

    reactor.callLater(0, _main)
    reactor.run()
# ==== vumi/scripts/tests/test_parse_log_messages.py ====
import json
from pkg_resources import resource_string

from vumi.scripts.parse_log_messages import LogParser
from vumi.tests.helpers import VumiTestCase


class DummyLogParser(LogParser):
    def __init__(self, *args, **kwargs):
        super(DummyLogParser, self).__init__(*args, **kwargs)
        self.emit_log = []

    def parse(self):
        pass

    def emit(self, obj):
        self.emit_log.append(obj)


SAMPLE_INBOUND_LINE = (
    "2012-04-12 10:52:23+0000 [WorkerAMQClient,client] "
    "Inbound: ")

SAMPLE_SMPP_OUTBOUND_LINE = (
    "2013-09-02 07:07:36+0000 [VumiRedis,client] Consumed outgoing message "
    ""
)

LOGGING_MW_INBOUND_LINE = (
    "2014-08-11 06:55:12+0000 Processed inbound message for jsbox_transport: "
    "{\"transport_name\": \"aat_ussd_transport\", \"in_reply_to\": null, "
    "\"group\": null, \"from_addr\": \"2783XXXXXXX\", \"timestamp\": "
    "\"2014-08-11 06:25:08.616561\", \"to_addr\": \"*134*550#\", \"content\":"
    " null, \"routing_metadata\": {}, \"message_version\": \"20110921\","
    " \"transport_type\": \"ussd\", \"helper_metadata\": {}, "
    "\"transport_metadata\": {\"aat_ussd\": "
    "{\"provider\": \"MTN\", \"ussd_session_id\": \"XXXX\"}}, "
    "\"session_event\": \"new\", \"message_id\": \"XXX\", "
    "\"message_type\": \"user_message\"}"
)

LOGGING_MW_OUTBOUND_LINE = (
    "2014-08-11 06:55:12+0000 Processed outbound message for jsbox_transport: "
    "{\"transport_name\": \"aat_ussd_transport\", \"in_reply_to\": null, "
    "\"group\": null, \"from_addr\": \"*134*550#\", \"timestamp\": "
    "\"2014-08-11 06:25:08.616561\", \"to_addr\": \"2783XXXXXXX\", \"content\":"
    " \"hello\", \"routing_metadata\": {}, \"message_version\": \"20110921\","
    " \"transport_type\": \"ussd\", \"helper_metadata\": {}, "
    "\"transport_metadata\": {\"aat_ussd\": "
    "{\"provider\": \"MTN\", \"ussd_session_id\": \"XXXX\"}}, "
    "\"session_event\": \"new\", \"message_id\": \"XXX\", "
    "\"message_type\": \"user_message\"}"
)

LOGGING_MW_EVENT_LINE = (
    "2014-08-11 06:55:12+0000 Processed event message for "
    "billing_dispatcher_ro: {\"transport_name\": \"mtech_ng_smpp_transport\","
    " \"event_type\": \"delivery_report\", \"event_id\":"
    " \"XXX\", \"timestamp\":"
    " \"2014-08-11 06:51:22.927352\", \"routing_metadata\": {},"
    " \"message_version\": \"20110921\", \"helper_metadata\": {},"
    " \"delivery_status\": \"delivered\", \"transport_metadata\": {},"
    " \"user_message_id\": \"XXX\","
    " \"message_type\": \"event\"}"
)


class TestParseSMPPLogMessages(VumiTestCase):

    def test_parsing_of_line(self):
        parser = DummyLogParser({
            'from': None,
            'until': None,
            'format': 'vumi',
        })
        parser.readline(SAMPLE_INBOUND_LINE)

        parsed = json.loads(parser.emit_log[0])
        expected = {
            "content": "hello world",
            "transport_type": "ussd",
            "to_addr": "*120*12345*489665#",
            "message_id": "b1893fa98ff4485299e3781f73ebfbb6",
            "from_addr": "+27123456780"
        }
        for key in expected.keys():
            self.assertEqual(parsed.get(key), expected.get(key))

    def test_parsing_of_smpp_inbound_line(self):
        parser = DummyLogParser({
            'from': None,
            'until': None,
            'format': 'smpp_inbound',
        })
        parser.readline(
            "2011-11-15 02:04:48+0000 [EsmeTransceiver,client] "
            "PUBLISHING INBOUND: {'content': u'AFN9WH79', 'transport_type': "
            "'sms', 'to_addr': '1458', 'message_id': 'ec443820-62a8-4051-92e7"
            "-66adaa487d20', 'from_addr': '23xxxxxxxx'}")

        self.assertEqual(json.loads(parser.emit_log[0]), {
            "content": "AFN9WH79",
            "transport_type": "sms",
            "to_addr": "1458",
            "message_id": "ec443820-62a8-4051-92e7-66adaa487d20",
            "from_addr": "23xxxxxxxx"
        })

    def test_parsing_of_smpp_outbound_line(self):
        parser = DummyLogParser({
            'from': None,
            'until': None,
            'format': 'smpp_outbound'
        })
        parser.readline(SAMPLE_SMPP_OUTBOUND_LINE)
        parsed = json.loads(parser.emit_log[0])
        expected = {
            "content": "hello world",
            "transport_type": "sms",
            "to_addr": "+27123456780",
            "message_id": "baz",
            "from_addr": "default10141"
        }
        for key in expected.keys():
            self.assertEqual(parsed.get(key), expected.get(key))

    def test_parse_of_smpp_lines_with_limits(self):
        sample = resource_string(__name__, 'sample-smpp-output.log')
        parser = DummyLogParser({
            'from': '2011-11-15 00:23:59',
            'until': '2011-11-15 00:24:26',
            'format': 'smpp',
            })
        for line in sample.split('\n'):
            parser.readline(line)

        self.assertEqual(len(parser.emit_log), 2)
        self.assertEqual(json.loads(parser.emit_log[0].strip())['content'],
                         "CODE2")
        self.assertEqual(json.loads(parser.emit_log[1].strip())['content'],
                         "CODE3")

    def test_parse_of_logging_mw_inbound(self):
        parser = DummyLogParser({
            'from': None,
            'until': None,
            'format': 'dispatcher_inbound_message'
        })
        parser.readline(LOGGING_MW_INBOUND_LINE)
        parsed = json.loads(parser.emit_log[0])
        expected = {
            "content": None,
            "transport_type": "ussd",
            "to_addr": "*134*550#",
            "message_id": "XXX",
            "from_addr": "2783XXXXXXX"
        }
        for key in expected.keys():
            self.assertEqual(parsed.get(key), expected.get(key))

    def test_parse_of_logging_mw_outbound(self):
        parser = DummyLogParser({
            'from': None,
            'until': None,
            'format': 'dispatcher_outbound_message'
        })
        parser.readline(LOGGING_MW_OUTBOUND_LINE)
        parsed = json.loads(parser.emit_log[0])
        expected = {
            "content": "hello",
            "transport_type": "ussd",
            "to_addr": "2783XXXXXXX",
            "message_id": "XXX",
            "from_addr": "*134*550#"
        }
        for key in expected.keys():
            self.assertEqual(parsed.get(key), expected.get(key))

    def test_parse_of_logging_mw_event(self):
        parser = DummyLogParser({
            'from': None,
            'until': None,
            'format': 'dispatcher_event'
        })
        parser.readline(LOGGING_MW_EVENT_LINE)
        parsed = json.loads(parser.emit_log[0])
        expected = {
            "event_type": "delivery_report",
            "event_id": "XXX",
        }
        for key in expected.keys():
            self.assertEqual(parsed.get(key), expected.get(key))
# ==== vumi/scripts/tests/test_vumi_count_models.py ====
"""Tests for vumi.scripts.vumi_count_models."""

import sys
from uuid import uuid4
from StringIO import StringIO

from twisted.internet.defer import inlineCallbacks
from twisted.python import usage

from vumi.persist.model import Model
from vumi.persist.fields import Unicode
from vumi.scripts.vumi_count_models import ModelCounter, Options, main
from vumi.tests.helpers import VumiTestCase, PersistenceHelper


class SimpleModel(Model):
    VERSION = 1
    a = Unicode(index=True)
    even_odd = Unicode(index=True)


class StubbedModelCounter(ModelCounter):
    def __init__(self, testcase, *args, **kwargs):
        self.testcase = testcase
        self.output = []
        super(StubbedModelCounter, self).__init__(*args, **kwargs)

    def emit(self, s):
        self.output.append(s)

    def get_riak_manager(self, riak_config):
        return self.testcase.get_riak_manager(riak_config)


class TestModelCounter(VumiTestCase):

    def setUp(self):
        self.persistence_helper = self.add_helper(
            PersistenceHelper(use_riak=True, is_sync=False))
        # Since we're never loading the actual objects, we can't detect
        # tombstones. Therefore, each test needs its own bucket prefix.
        self.expected_bucket_prefix = "bucket-%s" % (uuid4().hex,)
        self.riak_manager = self.persistence_helper.get_riak_manager({
            "bucket_prefix": self.expected_bucket_prefix,
        })
        self.add_cleanup(self.riak_manager.close_manager)
        self.model = self.riak_manager.proxy(SimpleModel)
        self.model_cls_path = ".".join([
            SimpleModel.__module__, SimpleModel.__name__])
        self.default_args = [
            "-m", self.model_cls_path,
            "-b", self.expected_bucket_prefix,
        ]

    def make_counter(self, args=None, index_page_size=None, index_field=None,
                     index_value=None, index_value_end=None,
                     index_value_regex=None):
        if args is None:
            args = self.default_args
        if index_field is not None:
            args.extend(["--index-field", index_field])
        if index_value is not None:
            args.extend(["--index-value", index_value])
        if index_value_end is not None:
            args.extend(["--index-value-end", index_value_end])
        if index_value_regex is not None:
            args.extend(["--index-value-regex", index_value_regex])
        if index_page_size is not None:
            args.extend(["--index-page-size", str(index_page_size)])
        options = Options()
        options.parseOptions(args)
        return StubbedModelCounter(self, options)

    def get_riak_manager(self, config):
        self.assertEqual(config["bucket_prefix"], self.expected_bucket_prefix)
        return self.persistence_helper.get_riak_manager(config)

    @inlineCallbacks
    def mk_simple_models(self, n, start=0):
        for i in range(start, start + n):
            even_odd = {0: u"even", 1: u"odd"}[i % 2]
            obj = self.model(
                u"key-%d" % i, a=u"value-%d" % i, even_odd=even_odd)
            yield obj.save()

    def test_model_class_required(self):
        self.assertRaises(usage.UsageError, self.make_counter, [
            "-b", self.expected_bucket_prefix,
        ])

    def test_bucket_required(self):
        self.assertRaises(usage.UsageError, self.make_counter, [
            "-m", self.model_cls_path,
        ])

    def test_index_value_requires_index(self):
        """
        index-value without index-field is invalid.
        """
        self.assertRaises(
            usage.UsageError, self.make_counter, index_value="foo")

    def test_index_requires_index_value(self):
        """
        index-field without index-value is invalid.
        """
        self.assertRaises(
            usage.UsageError, self.make_counter, index_field="foo")

    def test_index_value_end_requires_index_value(self):
        """
        index-value-end without index-value is invalid.
        """
        self.assertRaises(
            usage.UsageError, self.make_counter, index_value_end="foo")

    def test_index_value_regex_requires_index_value_end(self):
        """
        index-value-regex without a range query is pointless.
        """
        self.assertRaises(
            usage.UsageError, self.make_counter, index_field="foo",
            index_value="foo", index_value_regex="foo")

    @inlineCallbacks
    def test_main(self):
        """
        The counter runs via `main()`.
        """
        yield self.mk_simple_models(3)
        self.patch(sys, "stdout", StringIO())
        yield main(
            None, "name",
            "-m", self.model_cls_path,
            "-b", self.riak_manager.bucket_prefix)
        self.assertEqual(
            sys.stdout.getvalue(),
            "Counting all keys ...\nDone, 3 objects found.\n")

    @inlineCallbacks
    def test_count_all_keys(self):
        """
        All keys are counted.
        """
        yield self.mk_simple_models(5)
        counter = self.make_counter()
        yield counter.run()
        self.assertEqual(counter.output, [
            "Counting all keys ...",
            "Done, 5 objects found.",
        ])

    @inlineCallbacks
    def test_count_by_index_value(self):
        yield self.mk_simple_models(5)
        counter = self.make_counter(index_field="even_odd", index_value="odd")
        yield counter.run()
        self.assertEqual(counter.output, [
            "Counting ...",
            "Done, 2 objects found.",
        ])

    @inlineCallbacks
    def test_count_by_index_range(self):
        yield self.mk_simple_models(5)
        counter = self.make_counter(
            index_field="a", index_value="value-1", index_value_end="value-3")
        yield counter.run()
        self.assertEqual(counter.output, [
            "Counting ...",
            "Done, 3 objects found.",
        ])

    @inlineCallbacks
    def test_count_with_filter_range(self):
        yield self.mk_simple_models(5)
        counter = self.make_counter(
            index_field="a", index_value="value-1", index_value_end="value-3",
            index_value_regex=r"value-[0134]")
        yield counter.run()
        self.assertEqual(counter.output, [
            "Counting ...",
            "Done, 2 objects found.",
        ])

    @inlineCallbacks
    def test_count_all_small_pages(self):
        yield self.mk_simple_models(3)
        counter = self.make_counter(index_page_size=2)
        yield counter.run()
        self.assertEqual(counter.output, [
            "Counting all keys ...",
            "2 objects counted.",
            "Done, 3 objects found.",
        ])

    @inlineCallbacks
    def test_count_range_small_pages(self):
        yield self.mk_simple_models(5)
        counter = self.make_counter(
            index_field="a", index_value="value-1", index_value_end="value-3",
            index_page_size=2)
        yield counter.run()
        self.assertEqual(counter.output, [
            "Counting ...",
            "2 objects counted.",
            "Done, 3 objects found.",
        ])

    @inlineCallbacks
    def test_count_all_tiny_pages(self):
        yield self.mk_simple_models(3)
        counter = self.make_counter(index_page_size=1)
        yield counter.run()
        self.assertEqual(counter.output, [
            "Counting all keys ...",
            "1 object counted.",
            "2 objects counted.",
            "3 objects counted.",
            "Done, 3 objects found.",
        ])
# ==== vumi/scripts/tests/test_vumi_list_messages.py ====
"""Tests for vumi.scripts.vumi_list_messages."""

import sys
from datetime import datetime, timedelta
from uuid import uuid4
from StringIO import StringIO

from twisted.internet.defer import inlineCallbacks
from twisted.python import usage

from vumi.components.message_store import MessageStore
from vumi.scripts.vumi_list_messages import MessageLister, Options, main
from vumi.tests.helpers import VumiTestCase, PersistenceHelper, MessageHelper


class StubbedMessageLister(MessageLister):
    def __init__(self, testcase, *args, **kwargs):
        self.testcase = testcase
        self.output = []
        super(StubbedMessageLister, self).__init__(*args, **kwargs)

    def emit(self, s):
        self.output.append(s)

    def get_riak_manager(self, riak_config):
        return self.testcase.get_riak_manager(riak_config)


class TestMessageLister(VumiTestCase):

    @inlineCallbacks
    def setUp(self):
        self.persistence_helper = self.add_helper(
            PersistenceHelper(use_riak=True, is_sync=False))
        self.msg_helper = self.add_helper(MessageHelper())
        # Since we're never loading the actual objects, we can't detect
        # tombstones. Therefore, each test needs its own bucket prefix.
        self.expected_bucket_prefix = "bucket-%s" % (uuid4().hex,)
        self.riak_manager = self.persistence_helper.get_riak_manager({
            "bucket_prefix": self.expected_bucket_prefix,
        })
        self.add_cleanup(self.riak_manager.close_manager)
        self.redis_manager = yield self.persistence_helper.get_redis_manager()
        self.mdb = MessageStore(self.riak_manager, self.redis_manager)
        self.default_args = [
            "-b", self.expected_bucket_prefix,
        ]

    def make_lister(self, args=None, batch=None, direction=None,
                    index_page_size=None):
        if args is None:
            args = self.default_args
        if batch is not None:
            args.extend(["--batch", batch])
        if direction is not None:
            args.extend(["--direction", direction])
        if index_page_size is not None:
            args.extend(["--index-page-size", str(index_page_size)])
        options = Options()
        options.parseOptions(args)
        return StubbedMessageLister(self, options)

    def get_riak_manager(self, config):
        self.assertEqual(config["bucket_prefix"], self.expected_bucket_prefix)
        return self.persistence_helper.get_riak_manager(config)

    def make_inbound(self, batch_id, from_addr, timestamp=None):
        if timestamp is None:
            timestamp = datetime.utcnow()
        msg = self.msg_helper.make_inbound(
            None, from_addr=from_addr, timestamp=timestamp)
        d = self.mdb.add_inbound_message(msg, batch_id=batch_id)
        timestamp_str = timestamp.strftime("%Y-%m-%d %H:%M:%S.%f")
        d.addCallback(
            lambda _: (timestamp_str, from_addr, msg["message_id"]))
        return d

    def make_outbound(self, batch_id, to_addr, timestamp=None):
        if timestamp is None:
            timestamp = datetime.utcnow()
        msg = self.msg_helper.make_outbound(
            None, to_addr=to_addr, timestamp=timestamp)
        d = self.mdb.add_outbound_message(msg, batch_id=batch_id)
        timestamp_str = timestamp.strftime("%Y-%m-%d %H:%M:%S.%f")
        d.addCallback(
            lambda _: (timestamp_str, to_addr, msg["message_id"]))
        return d

    def test_batch_required(self):
        self.assertRaises(usage.UsageError, self.make_lister, [
            "--direction", "inbound",
            "-b", self.expected_bucket_prefix,
        ])

    def test_valid_direction_required(self):
        self.assertRaises(usage.UsageError, self.make_lister, [
            "--batch", "gingercoookies",
            "-b", self.expected_bucket_prefix,
        ])
        self.assertRaises(usage.UsageError, self.make_lister, [
            "--batch", "gingercoookies",
            "--direction", "widdershins",
            "-b", self.expected_bucket_prefix,
        ])

    def test_bucket_required(self):
        self.assertRaises(usage.UsageError, self.make_lister, [
            "--batch", "gingercoookies",
            "--direction", "inbound",
        ])

    @inlineCallbacks
    def test_main(self):
        """
        The lister runs via `main()`.
        """
        msg_data = yield self.make_inbound("gingercookies", "12345")
        self.patch(sys, "stdout", StringIO())
        yield main(
            None, "name",
            "--batch", "gingercookies",
            "--direction", "inbound",
            "-b", self.riak_manager.bucket_prefix)
        self.assertEqual(
            sys.stdout.getvalue(),
            "%s\n" % (",".join(msg_data),))

    @inlineCallbacks
    def test_list_inbound(self):
        """
        Inbound messages can be listed.
        """
        start = datetime.utcnow() - timedelta(seconds=10)
        msg_datas = [
            (yield self.make_inbound(
                "gingercookies", "1234%d" % i, start + timedelta(seconds=i)))
            for i in range(5)
        ]
        lister = self.make_lister(batch="gingercookies", direction="inbound")
        yield lister.run()
        self.assertEqual(
            lister.output, [",".join(msg_data) for msg_data in msg_datas])

    @inlineCallbacks
    def test_list_inbound_small_pages(self):
        """
        Inbound messages can be listed across multiple index pages.
        """
        start = datetime.utcnow() - timedelta(seconds=10)
        msg_datas = [
            (yield self.make_inbound(
                "gingercookies", "1234%d" % i, start + timedelta(seconds=i)))
            for i in range(5)
        ]
        lister = self.make_lister(
            batch="gingercookies", direction="inbound", index_page_size=2)
        yield lister.run()
        self.assertEqual(
            lister.output, [",".join(msg_data) for msg_data in msg_datas])

    @inlineCallbacks
    def test_list_outbound(self):
        """
        Outbound messages can be listed.
        """
        start = datetime.utcnow() - timedelta(seconds=10)
        msg_datas = [
            (yield self.make_outbound(
                "gingercookies", "1234%d" % i, start + timedelta(seconds=i)))
            for i in range(5)
        ]
        lister = self.make_lister(batch="gingercookies", direction="outbound")
        yield lister.run()
        self.assertEqual(
            lister.output, [",".join(msg_data) for msg_data in msg_datas])
# ==== vumi/scripts/tests/test_vumi_tagpools.py ====
"""Tests for vumi.scripts.vumi_tagpools."""

from pkg_resources import resource_filename

from vumi.tests.helpers import VumiTestCase, PersistenceHelper


def make_cfg(args):
    from vumi.scripts.vumi_tagpools import ConfigHolder, Options

    class TestConfigHolder(ConfigHolder):
        def __init__(self, *args, **kwargs):
            self.output = []
            super(TestConfigHolder, self).__init__(*args, **kwargs)

        def emit(self, s):
            self.output.append(s)

    args = ["--config",
            resource_filename(__name__, "sample-tagpool-cfg.yaml")] + args
    options = Options()
    options.parseOptions(args)
    return TestConfigHolder(options)


class TagPoolBaseTestCase(VumiTestCase):
    def setUp(self):
        self.persistence_helper = self.add_helper(
            PersistenceHelper(is_sync=True))
        # Make sure we start fresh.
        self.persistence_helper.get_redis_manager()._purge_all()


class TestCreatePoolCmd(TagPoolBaseTestCase):
    def test_create_pool_range_tags(self):
        cfg = make_cfg(["create-pool", "shortcode"])
        cfg.run()
        self.assertEqual(cfg.output, [
            'Creating pool shortcode ...',
            '  Setting metadata ...',
            '  Declaring 1000 tag(s) ...',
            '  Done.',
            ])
        self.assertEqual(cfg.tagpool.get_metadata("shortcode"),
                         {'transport_type': 'sms'})
        self.assertEqual(sorted(cfg.tagpool.free_tags("shortcode")),
                         [("shortcode", str(d)) for d in range(10001, 11001)])
        self.assertEqual(cfg.tagpool.inuse_tags("shortcode"), [])

    def test_create_pool_explicit_tags(self):
        cfg = make_cfg(["create-pool", "xmpp"])
        cfg.run()
        self.assertEqual(cfg.output, [
            'Creating pool xmpp ...',
            '  Setting metadata ...',
            '  Declaring 1 tag(s) ...',
            '  Done.',
            ])
        self.assertEqual(cfg.tagpool.get_metadata("xmpp"),
                         {'transport_type': 'xmpp'})
        self.assertEqual(sorted(cfg.tagpool.free_tags("xmpp")),
                         [("xmpp", "me@example.com")])
        self.assertEqual(cfg.tagpool.inuse_tags("xmpp"), [])


class TestUpdatePoolMetadataCmd(TagPoolBaseTestCase):
    def test_update_tagpool_metadata(self):
        cfg = make_cfg(["update-pool-metadata", "shortcode"])
        cfg.run()
        self.assertEqual(cfg.output, [
            'Updating metadata for pool shortcode ...',
            '  Done.',
            ])
        self.assertEqual(cfg.tagpool.get_metadata("shortcode"),
                         {'transport_type': 'sms'})


class TestUpdateAllPoolMetadataCmd(TagPoolBaseTestCase):
    def test_update_all_metadata(self):
        cfg = make_cfg(["update-all-metadata"])
        cfg.tagpool.declare_tags([("xmpp", "tag"), ("longcode", "tag")])
        cfg.run()
        self.assertEqual(cfg.output, [
            'Updating pool metadata.',
            'Note: Pools not present in both the config and tagpool'
            ' store will not be updated.',
            '  Updating metadata for pool longcode ...',
            '  Updating metadata for pool xmpp ...',
            'Done.'
            ])
        self.assertEqual(cfg.tagpool.get_metadata("longcode"),
                         {u'transport_type': u'sms'})
        self.assertEqual(cfg.tagpool.get_metadata("xmpp"),
                         {u'transport_type': u'xmpp'})
        self.assertEqual(cfg.tagpool.get_metadata("shortcode"), {})

    def test_no_pools(self):
        cfg = make_cfg(["update-all-metadata"])
        cfg.run()
        self.assertEqual(cfg.output, [
            'Updating pool metadata.',
            'Note: Pools not present in both the config and tagpool'
            ' store will not be updated.',
            'No pools found.',
            ])


class TestPurgePoolCmd(TagPoolBaseTestCase):
    def test_purge_pool(self):
        cfg = make_cfg(["purge-pool", "foo"])
        cfg.tagpool.declare_tags([("foo", "tag1"), ("foo", "tag2")])
        cfg.run()
        self.assertEqual(cfg.output, [
            'Purging pool foo ...',
            '  Done.',
            ])
        self.assertEqual(cfg.tagpool.free_tags("foo"), [])
        self.assertEqual(cfg.tagpool.inuse_tags("foo"), [])
        self.assertEqual(cfg.tagpool.get_metadata("foo"), {})


class TestListKeysCmd(TagPoolBaseTestCase):
    def setUp(self):
        super(TestListKeysCmd, self).setUp()
        self.test_tags = [("foo", "tag%d" % i) for
                          i in [1, 2, 3, 5, 6, 7, 9]]

    def test_list_keys_all_free(self):
        cfg = make_cfg(["list-keys", "foo"])
        cfg.tagpool.declare_tags(self.test_tags)
        cfg.run()
        self.assertEqual(cfg.output, [
            'Listing tags for pool foo ...',
            'Free tags:',
            '   tag[1-3], tag[5-7], tag9',
            'Tags in use:',
            '   -- None --',
            ])

    def test_list_keys_all_in_use(self):
        cfg = make_cfg(["list-keys", "foo"])
        cfg.tagpool.declare_tags(self.test_tags)
        for _tag in self.test_tags:
            cfg.tagpool.acquire_tag("foo")
        cfg.run()
        self.assertEqual(cfg.output, [
            'Listing tags for pool foo ...',
            'Free tags:',
            '   -- None --',
            'Tags in use:',
            '   tag[1-3], tag[5-7], tag9',
            ])


class TestListPoolsCmd(TagPoolBaseTestCase):
    def test_list_pools_with_only_pools_in_config(self):
        cfg = make_cfg(["list-pools"])
        cfg.run()
        self.assertEqual(cfg.output, [
            'Pools defined in cfg and tagpool:',
            '   -- None --',
            'Pools only in cfg:',
            '   longcode, shortcode, xmpp',
            'Pools only in tagpool:',
            '   -- None --',
            ])

    def test_list_pools_with_all_pools_in_tagpool(self):
        cfg = make_cfg(["list-pools"])
        cfg.tagpool.declare_tags([("xmpp", "tag"), ("longcode", "tag"),
                                  ("shortcode", "tag")])
        cfg.run()
        self.assertEqual(cfg.output, [
            'Pools defined in cfg and tagpool:',
            '   longcode, shortcode, xmpp',
            'Pools only in cfg:',
            '   -- None --',
            'Pools only in tagpool:',
            '   -- None --',
            ])

    def test_list_pools_with_all_sorts_of_pools(self):
        cfg = make_cfg(["list-pools"])
        cfg.tagpool.declare_tags([("xmpp", "tag"), ("other", "tag")])
        cfg.run()
        self.assertEqual(cfg.output, [
            'Pools defined in cfg and tagpool:',
            '   xmpp',
            'Pools only in cfg:',
            '   longcode, shortcode',
            'Pools only in tagpool:',
            '   other',
            ])


class TestReleaseTagCmd(TagPoolBaseTestCase):

    def setUp(self):
        super(TestReleaseTagCmd, self).setUp()
        self.test_tags = [("foo", "tag%d" % i) for
                          i in [1, 2, 3, 5, 6, 7, 9]]

    def test_release_tag_not_in_use(self):
        cfg = make_cfg(["release-tag", "foo", "tag1"])
        cfg.tagpool.declare_tags(self.test_tags)
        cfg.run()
        self.assertEqual(cfg.output,
                         ["Tag ('foo', 'tag1') not in use."])

    def test_release_unknown_tag(self):
        cfg = make_cfg(["release-tag", "foo", "tag1"])
        cfg.run()
        self.assertEqual(cfg.output,
                         ["Unknown tag ('foo', 'tag1')."])

    def test_release_tag(self):
        cfg = make_cfg(["release-tag", "foo", "tag1"])
        cfg.tagpool.declare_tags(self.test_tags)
        cfg.tagpool.acquire_specific_tag(('foo', 'tag1'))
        self.assertEqual(cfg.tagpool.inuse_tags('foo'), [('foo', 'tag1')])
        cfg.run()
        self.assertEqual(cfg.tagpool.inuse_tags('foo'), [])
        self.assertEqual(cfg.output, ["Released ('foo', 'tag1')."])
# ==== vumi/scripts/tests/__init__.py ==== (empty file)

# ==== vumi/scripts/tests/sample-smpp-output.log ====
2011-11-15 00:23:52+0000 [EsmeTransceiver,client] PUBLISHING INBOUND: {'content': u'CODE1', 'transport_type': 'sms', 'to_addr': '1458', 'message_id': '8b400347-408a-47b5-9663-ff719881d73d', 'from_addr': '23xxxxxxxx'}
2011-11-15 00:23:59+0000 [EsmeTransceiver,client] PUBLISHING INBOUND: {'content': u'CODE2', 'transport_type': 'sms', 'to_addr': '1458', 'message_id': 'c7590d37-6156-4e1e-8ad8-f11f97cc1a5b', 'from_addr': '24xxxxxxxx'}
2011-11-15 00:24:26+0000 [EsmeTransceiver,client] PUBLISHING INBOUND: {'content': u'CODE3', 'transport_type': 'sms', 'to_addr': '1458', 'message_id': 'b70b6698-ee6c-463e-be24-a86246e88a7a', 'from_addr': '25xxxxxxxx'}
2011-11-15 00:25:26+0000 [EsmeTransceiver,client] PUBLISHING INBOUND: {'content': u'CODE4', 'transport_type': 'sms', 'to_addr': '1458', 'message_id': 'f499a00d-83c2-4baf-8af3-cab27fb04394', 'from_addr': '26xxxxxxxx'}
# ==== vumi/scripts/tests/test_model_migrator.py ====
"""Tests for vumi.scripts.model_migrator."""

from twisted.python import usage

from vumi.persist.model import Model
from vumi.persist.fields import Unicode
from vumi.scripts.model_migrator import ModelMigrator, Options
from vumi.tests.helpers import VumiTestCase, PersistenceHelper


class SimpleModel(Model):
    a = Unicode()


class StubbedModelMigrator(ModelMigrator):
    def __init__(self, testcase, *args, **kwargs):
        # So we can patch the manager's load function to simulate failures.
        self._manager_load_func = kwargs.pop("manager_load_func", None)
        self.testcase = testcase
        self.output = []
        self.recorded_loads = []
        self.recorded_stores = []
        super(StubbedModelMigrator, self).__init__(*args, **kwargs)

    def emit(self, s):
        self.output.append(s)

    def get_riak_manager(self, riak_config):
        manager = self.testcase.get_riak_manager(riak_config)
        if self._manager_load_func is not None:
            self.testcase.patch(manager, "load", self._manager_load_func)
        self.testcase.persistence_helper.record_load_and_store(
            manager, self.recorded_loads, self.recorded_stores)
        return manager


class TestModelMigrator(VumiTestCase):

    def setUp(self):
        self.persistence_helper = self.add_helper(
            PersistenceHelper(use_riak=True, is_sync=True))
        self.expected_bucket_prefix = "bucket"
        self.riak_manager = self.persistence_helper.get_riak_manager({
            "bucket_prefix": self.expected_bucket_prefix,
        })
        self.add_cleanup(self.riak_manager.close_manager)
        self.model = self.riak_manager.proxy(SimpleModel)
        self.model_cls_path = ".".join([
            SimpleModel.__module__, SimpleModel.__name__])
        self.default_args = [
            "-m", self.model_cls_path,
            "-b", self.expected_bucket_prefix,
        ]

    def make_migrator(self, args=None, manager_load_func=None):
        if args is None:
            args = self.default_args
        options = Options()
        options.parseOptions(args)
        return StubbedModelMigrator(
            self, options, manager_load_func=manager_load_func)

    def get_riak_manager(self, config):
        self.assertEqual(config["bucket_prefix"], self.expected_bucket_prefix)
        return self.persistence_helper.get_riak_manager(config)

    def recorded_loads_and_stores(self, model_migrator):
        return model_migrator.recorded_loads, model_migrator.recorded_stores

    def mk_simple_models(self, n):
        for i in range(n):
            obj = self.model(u"key-%d" % i, a=u"value-%d" % i)
            obj.save()

    def test_model_class_required(self):
        self.assertRaises(usage.UsageError, self.make_migrator, [
            "-b", self.expected_bucket_prefix,
        ])

    def test_bucket_required(self):
        self.assertRaises(usage.UsageError, self.make_migrator, [
            "-m", self.model_cls_path,
        ])

    def test_successful_migration(self):
        self.mk_simple_models(3)
        cfg = self.make_migrator()
        loads, stores = self.recorded_loads_and_stores(cfg)
        cfg.run()
        self.assertEqual(cfg.output, [
            "3 keys found. Migrating ...",
            "33% complete.", "66% complete.",
            "Done.",
        ])
        self.assertEqual(sorted(loads), [u"key-%d" % i for i in range(3)])
        self.assertEqual(sorted(stores), [u"key-%d" % i for i in range(3)])

    def test_migration_with_tombstones(self):
        self.mk_simple_models(3)

        def tombstone_load(modelcls, key, result=None):
            return None

        cfg = self.make_migrator(manager_load_func=tombstone_load)
        cfg.run()
        for i in range(3):
            self.assertTrue(("Skipping tombstone key u'key-%d'." % i)
                            in cfg.output)
        self.assertEqual(cfg.output[:1], [
            "3 keys found. Migrating ...",
        ])
        self.assertEqual(cfg.output[-2:], [
            "66% complete.",
            "Done.",
        ])

    def test_migration_with_failures(self):
        self.mk_simple_models(3)

        def error_load(modelcls, key, result=None):
            raise ValueError("Failed to load.")

        cfg = self.make_migrator(manager_load_func=error_load)
        cfg.run()
        line_pairs = zip(cfg.output, cfg.output[1:])
        for i in range(3):
            self.assertTrue((
                "Failed to migrate key u'key-%d':" % (i,),
                "  ValueError: Failed to load.",
            ) in line_pairs)
        self.assertEqual(cfg.output[:1], [
            "3 keys found. Migrating ...",
        ])
        self.assertEqual(cfg.output[-2:], [
            "66% complete.",
            "Done.",
        ])

    def test_migrating_specific_keys(self):
        self.mk_simple_models(3)
        cfg = self.make_migrator(self.default_args + ["--keys", "key-1,key-2"])
        loads, stores = self.recorded_loads_and_stores(cfg)
        cfg.run()
        self.assertEqual(cfg.output, [
            "Migrating 2 specified keys ...",
            "50% complete.",
            "Done.",
        ])
        self.assertEqual(sorted(loads), [u"key-1", u"key-2"])
        self.assertEqual(sorted(stores), [u"key-1", u"key-2"])

    def test_dry_run(self):
        self.mk_simple_models(3)
        cfg = self.make_migrator(self.default_args + ["--dry-run"])
        loads, stores = self.recorded_loads_and_stores(cfg)
        cfg.run()
        self.assertEqual(cfg.output, [
            "3 keys found. Migrating ...",
            "33% complete.", "66% complete.",
            "Done.",
        ])
        self.assertEqual(sorted(loads), [u"key-%d" % i for i in range(3)])
        self.assertEqual(sorted(stores), [])
# ==== vumi/scripts/tests/test_vumi_model_migrator.py ====
"""Tests for vumi.scripts.vumi_model_migrator."""

import sys
from StringIO import StringIO

from twisted.internet.defer import inlineCallbacks, succeed
from twisted.internet.task import deferLater
from twisted.python import usage

from vumi.persist import model
from vumi.persist.fields import Unicode
from vumi.scripts.vumi_model_migrator import ModelMigrator, Options, main
from vumi.tests.helpers import VumiTestCase, PersistenceHelper


def post_migrate_function(obj):
    """
    Post-migrate-function for use in tests.
    """
    obj.a = obj.a + u"-modified"
    return True


def post_migrate_function_deferred(obj):
    """
    Post-migrate-function for use in tests; returns its result via a
    Deferred.
    """
    from twisted.internet import reactor
    return deferLater(reactor, 0.1, post_migrate_function, obj)


def post_migrate_function_new_only(obj):
    """
    Post-migrate-function for use in tests; only modifies objects that
    were actually migrated.
    """
    if obj.was_migrated:
        return post_migrate_function(obj)
    return False


def fqpn(thing):
    """
    Get the fully-qualified name of a thing.
    """
    return ".".join([thing.__module__, thing.__name__])


class SimpleModelMigrator(model.ModelMigrator):
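    # Invoked for objects stored at VERSION 1; upgrades them to VERSION 2.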
    def migrate_from_1(self, migration_data):
        migration_data.set_value('$VERSION', 2)
        migration_data.copy_values("a")
        return migration_data


class SimpleModelOld(model.Model):
    VERSION = 1
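    # Use SimpleModel's bucket so old-version objects live alongside new ones.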
    bucket = 'simplemodel'
    a = Unicode()


class SimpleModel(model.Model):
    VERSION = 2
    MIGRATOR = SimpleModelMigrator
    a = Unicode()


class StubbedModelMigrator(ModelMigrator):
    def __init__(self, testcase, *args, **kwargs):
        # So we can patch the manager's load function to simulate failures.
        self._manager_load_func = kwargs.pop("manager_load_func", None)
        self.testcase = testcase
        self.output = []
        self.recorded_loads = []
        self.recorded_stores = []
        super(StubbedModelMigrator, self).__init__(*args, **kwargs)

    def emit(self, s):
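        # Collect output lines so tests can assert on the script's output.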
        self.output.append(s)

    def get_riak_manager(self, riak_config):
        manager = self.testcase.get_riak_manager(riak_config)
        if self._manager_load_func is not None:
            self.testcase.patch(manager, "load", self._manager_load_func)
        self.testcase.persistence_helper.record_load_and_store(
            manager, self.recorded_loads, self.recorded_stores)
        return manager


class TestVumiModelMigrator(VumiTestCase):

    def setUp(self):
        self.persistence_helper = self.add_helper(
            PersistenceHelper(use_riak=True, is_sync=False))
        self.expected_bucket_prefix = "bucket"
        self.riak_manager = self.persistence_helper.get_riak_manager({
            "bucket_prefix": self.expected_bucket_prefix,
        })
        self.add_cleanup(self.riak_manager.close_manager)
        self.old_model = self.riak_manager.proxy(SimpleModelOld)
        self.model = self.riak_manager.proxy(SimpleModel)
        self.model_cls_path = fqpn(SimpleModel)
        self.default_args = [
            "-m", self.model_cls_path,
            "-b", self.expected_bucket_prefix,
        ]

    def make_migrator(self, args=None, index_page_size=None,
                      concurrent_migrations=None, continuation_token=None,
                      post_migrate_function=None, manager_load_func=None):
        if args is None:
            # Copy so that extending args below doesn't mutate default_args.
            args = list(self.default_args)
        if index_page_size is not None:
            args.extend(
                ["--index-page-size", str(index_page_size)])
        if concurrent_migrations is not None:
            args.extend(
                ["--concurrent-migrations", str(concurrent_migrations)])
        if continuation_token is not None:
            args.extend(
                ["--continuation-token", continuation_token])
        if post_migrate_function is not None:
            args.extend(
                ["--post-migrate-function", post_migrate_function])
        options = Options()
        options.parseOptions(args)
        return StubbedModelMigrator(
            self, options, manager_load_func=manager_load_func)

    def get_riak_manager(self, config):
        self.assertEqual(config["bucket_prefix"], self.expected_bucket_prefix)
        return self.persistence_helper.get_riak_manager(config)

    def recorded_loads_and_stores(self, model_migrator):
        return model_migrator.recorded_loads, model_migrator.recorded_stores

    @inlineCallbacks
    def mk_simple_models_old(self, n, start=0):
        for i in range(start, start + n):
            obj = self.old_model(u"key-%d" % i, a=u"value-%d" % i)
            yield obj.save()

    @inlineCallbacks
    def mk_simple_models_new(self, n, start=0):
        for i in range(start, start + n):
            obj = self.model(u"key-%d" % i, a=u"value-%d" % i)
            yield obj.save()

    def test_model_class_required(self):
        self.assertRaises(usage.UsageError, self.make_migrator, [
            "-b", self.expected_bucket_prefix,
        ])

    def test_bucket_required(self):
        self.assertRaises(usage.UsageError, self.make_migrator, [
            "-m", self.model_cls_path,
        ])

    @inlineCallbacks
    def test_main(self):
        yield self.mk_simple_models_old(3)
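        # Capture stdout so we can assert on what main() prints.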
        self.patch(sys, "stdout", StringIO())
        yield main(
            None, "name",
            "-m", self.model_cls_path,
            "-b", self.riak_manager.bucket_prefix)
        self.assertEqual(
            sys.stdout.getvalue(),
            "Migrating ...\nDone, 3 objects migrated.\n")

    @inlineCallbacks
    def test_successful_migration(self):
        yield self.mk_simple_models_old(3)
        model_migrator = self.make_migrator()
        loads, stores = self.recorded_loads_and_stores(model_migrator)
        yield model_migrator.run()
        self.assertEqual(model_migrator.output, [
            "Migrating ...",
            "Done, 3 objects migrated.",
        ])
        self.assertEqual(sorted(loads), [u"key-%d" % i for i in range(3)])
        self.assertEqual(sorted(stores), [u"key-%d" % i for i in range(3)])

    @inlineCallbacks
    def test_successful_migration_small_pages(self):
        yield self.mk_simple_models_old(3)
        model_migrator = self.make_migrator(index_page_size=2)
        loads, stores = self.recorded_loads_and_stores(model_migrator)
        yield model_migrator.run()
        [continuation] = [line for line in model_migrator.output
                          if line.startswith("Continuation token:")]
        self.assertEqual(model_migrator.output, [
            "Migrating ...",
            "2 objects migrated.",
            continuation,
            "Done, 3 objects migrated.",
        ])
        self.assertEqual(sorted(loads), [u"key-%d" % i for i in range(3)])
        self.assertEqual(sorted(stores), [u"key-%d" % i for i in range(3)])

    @inlineCallbacks
    def test_successful_migration_tiny_pages(self):
        yield self.mk_simple_models_old(3)
        model_migrator = self.make_migrator(index_page_size=1)
        loads, stores = self.recorded_loads_and_stores(model_migrator)
        yield model_migrator.run()
        [ct1, ct2, ct3] = [line for line in model_migrator.output
                           if line.startswith("Continuation token:")]
        self.assertEqual(model_migrator.output, [
            "Migrating ...",
            "1 object migrated.",
            ct1,
            "2 objects migrated.",
            ct2,
            "3 objects migrated.",
            ct3,
            "Done, 3 objects migrated.",
        ])
        self.assertEqual(sorted(loads), [u"key-%d" % i for i in range(3)])
        self.assertEqual(sorted(stores), [u"key-%d" % i for i in range(3)])

    @inlineCallbacks
    def test_successful_migration_with_continuation(self):
        yield self.mk_simple_models_old(3)

        # Run a migration all the way through to get a continuation token
        model_migrator = self.make_migrator(index_page_size=2)
        loads, stores = self.recorded_loads_and_stores(model_migrator)
        yield model_migrator.run()
        [continuation] = [line for line in model_migrator.output
                          if line.startswith("Continuation token:")]
        self.assertEqual(model_migrator.output, [
            "Migrating ...",
            "2 objects migrated.",
            continuation,
            "Done, 3 objects migrated.",
        ])
        self.assertEqual(sorted(loads), [u"key-%d" % i for i in range(3)])
        self.assertEqual(sorted(stores), [u"key-%d" % i for i in range(3)])

        # Recreate key-2 because it was already migrated and would otherwise be
        # skipped.
        yield self.mk_simple_models_old(1, start=2)
        # Run a migration starting from the continuation point.
        loads[:] = []
        stores[:] = []
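        # Extract the token from the "Continuation token: '<token>'" line.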
        continuation_token = continuation.split()[-1][1:-1]
        cont_model_migrator = self.make_migrator(
            index_page_size=2, continuation_token=continuation_token)
        cloads, cstores = self.recorded_loads_and_stores(cont_model_migrator)
        yield cont_model_migrator.run()
        self.assertEqual(cont_model_migrator.output, [
            "Migrating ...",
            "Done, 1 object migrated.",
        ])
        self.assertEqual(cloads, [u"key-2"])
        self.assertEqual(cstores, [u"key-2"])

    @inlineCallbacks
    def test_migration_with_tombstones(self):
        yield self.mk_simple_models_old(3)

        def tombstone_load(modelcls, key, result=None):
            return succeed(None)

        model_migrator = self.make_migrator(manager_load_func=tombstone_load)
        yield model_migrator.run()
        for i in range(3):
            self.assertTrue(("Skipping tombstone key u'key-%d'." % i)
                            in model_migrator.output)
        self.assertEqual(model_migrator.output[:1], [
            "Migrating ...",
        ])
        self.assertEqual(model_migrator.output[-1:], [
            "Done, 3 objects migrated.",
        ])

    @inlineCallbacks
    def test_migration_with_failures(self):
        yield self.mk_simple_models_old(3)

        def error_load(modelcls, key, result=None):
            raise ValueError("Failed to load.")

        model_migrator = self.make_migrator(manager_load_func=error_load)
        yield model_migrator.run()
        line_pairs = zip(model_migrator.output, model_migrator.output[1:])
        for i in range(3):
            self.assertTrue((
                "Failed to migrate key u'key-%d':" % (i,),
                "  ValueError: Failed to load.",
            ) in line_pairs)
        self.assertEqual(model_migrator.output[:1], [
            "Migrating ...",
        ])
        self.assertEqual(model_migrator.output[-1:], [
            "Done, 3 objects migrated.",
        ])

    @inlineCallbacks
    def test_migrating_specific_keys(self):
        yield self.mk_simple_models_old(3)
        model_migrator = self.make_migrator(
            self.default_args + ["--keys", "key-1,key-2"])
        loads, stores = self.recorded_loads_and_stores(model_migrator)
        yield model_migrator.run()
        self.assertEqual(model_migrator.output, [
            "Migrating 2 specified keys ...",
            "Done, 2 objects migrated.",
        ])
        self.assertEqual(sorted(loads), [u"key-1", u"key-2"])
        self.assertEqual(sorted(stores), [u"key-1", u"key-2"])

    @inlineCallbacks
    def test_dry_run(self):
        yield self.mk_simple_models_old(3)
        model_migrator = self.make_migrator(self.default_args + ["--dry-run"])
        loads, stores = self.recorded_loads_and_stores(model_migrator)
        yield model_migrator.run()
        self.assertEqual(model_migrator.output, [
            "Migrating ...",
            "Done, 3 objects migrated.",
        ])
        self.assertEqual(sorted(loads), [u"key-%d" % i for i in range(3)])
        self.assertEqual(sorted(stores), [])

    @inlineCallbacks
    def test_migrating_old_and_new_keys(self):
        """
        Models that haven't been migrated don't need to be stored.
        """
        yield self.mk_simple_models_old(1)
        yield self.mk_simple_models_new(1, start=1)
        yield self.mk_simple_models_old(1, start=2)
        model_migrator = self.make_migrator(self.default_args)
        loads, stores = self.recorded_loads_and_stores(model_migrator)

        yield model_migrator.run()
        self.assertEqual(model_migrator.output, [
            "Migrating ...",
            "Done, 3 objects migrated.",
        ])
        self.assertEqual(sorted(loads), [u"key-0", u"key-1", u"key-2"])
        self.assertEqual(sorted(stores), [u"key-0", u"key-2"])

    @inlineCallbacks
    def test_migrating_with_post_migrate_function(self):
        """
        If post-migrate-function is provided, it should be called for every
        object.
        """
        yield self.mk_simple_models_old(3)
        model_migrator = self.make_migrator(
            post_migrate_function=fqpn(post_migrate_function))
        loads, stores = self.recorded_loads_and_stores(model_migrator)

        yield model_migrator.run()
        self.assertEqual(model_migrator.output, [
            "Migrating ...",
            "Done, 3 objects migrated.",
        ])
        self.assertEqual(sorted(loads), [u"key-%d" % i for i in range(3)])
        self.assertEqual(sorted(stores), [u"key-%d" % i for i in range(3)])
        for i in range(3):
            obj = yield self.model.load(u"key-%d" % i)
            self.assertEqual(obj.a, u"value-%d-modified" % i)

    @inlineCallbacks
    def test_migrating_with_deferred_post_migrate_function(self):
        """
        A post-migrate-function may return a Deferred.
        """
        yield self.mk_simple_models_old(3)
        model_migrator = self.make_migrator(
            post_migrate_function=fqpn(post_migrate_function_deferred))
        loads, stores = self.recorded_loads_and_stores(model_migrator)

        yield model_migrator.run()
        self.assertEqual(model_migrator.output, [
            "Migrating ...",
            "Done, 3 objects migrated.",
        ])
        self.assertEqual(sorted(loads), [u"key-%d" % i for i in range(3)])
        self.assertEqual(sorted(stores), [u"key-%d" % i for i in range(3)])
        for i in range(3):
            obj = yield self.model.load(u"key-%d" % i)
            self.assertEqual(obj.a, u"value-%d-modified" % i)

    @inlineCallbacks
    def test_migrating_old_and_new_with_post_migrate_function(self):
        """
        A post-migrate-function may choose to modify objects that were not
        migrated.
        """
        yield self.mk_simple_models_old(1)
        yield self.mk_simple_models_new(1, start=1)
        yield self.mk_simple_models_old(1, start=2)
        model_migrator = self.make_migrator(
            post_migrate_function=fqpn(post_migrate_function))
        loads, stores = self.recorded_loads_and_stores(model_migrator)

        yield model_migrator.run()
        self.assertEqual(model_migrator.output, [
            "Migrating ...",
            "Done, 3 objects migrated.",
        ])
        self.assertEqual(sorted(loads), [u"key-%d" % i for i in range(3)])
        self.assertEqual(sorted(stores), [u"key-%d" % i for i in range(3)])
        for i in range(3):
            obj = yield self.model.load(u"key-%d" % i)
            self.assertEqual(obj.a, u"value-%d-modified" % i)

    @inlineCallbacks
    def test_migrating_old_and_new_with_new_only_post_migrate_function(self):
        """
        A post-migrate-function may choose to leave objects that were not
        migrated unmodified.
        """
        yield self.mk_simple_models_old(1)
        yield self.mk_simple_models_new(1, start=1)
        yield self.mk_simple_models_old(1, start=2)
        model_migrator = self.make_migrator(
            post_migrate_function=fqpn(post_migrate_function_new_only))
        loads, stores = self.recorded_loads_and_stores(model_migrator)

        yield model_migrator.run()
        self.assertEqual(model_migrator.output, [
            "Migrating ...",
            "Done, 3 objects migrated.",
        ])
        self.assertEqual(sorted(loads), [u"key-0", u"key-1", u"key-2"])
        self.assertEqual(sorted(stores), [u"key-0", u"key-2"])

        obj_0 = yield self.model.load(u"key-0")
        self.assertEqual(obj_0.a, u"value-0-modified")
        obj_1 = yield self.model.load(u"key-1")
        self.assertEqual(obj_1.a, u"value-1")
        obj_2 = yield self.model.load(u"key-2")
        self.assertEqual(obj_2.a, u"value-2-modified")

# ==== vumi/scripts/tests/test_db_backup.py ====
"""Tests for vumi.scripts.db_backup."""

import json
import datetime

import yaml

from vumi.scripts.db_backup import ConfigHolder, Options, vumi_version
from vumi.tests.helpers import VumiTestCase, PersistenceHelper


class ConfigHolderWrapper(ConfigHolder):
    def __init__(self, testcase, *args, **kwargs):
        self.testcase = testcase
        self.output = []
        self.utcnow = None
        super(ConfigHolderWrapper, self).__init__(*args, **kwargs)

    def emit(self, s):
        self.output.append(s)

    def get_utcnow(self):
        self.utcnow = super(ConfigHolderWrapper, self).get_utcnow()
        return self.utcnow

    def get_redis(self, config):
        return self.testcase.get_sub_redis(config)


class DbBackupBaseTestCase(VumiTestCase):
    def setUp(self):
        self.persistence_helper = self.add_helper(
            PersistenceHelper(is_sync=True))
        self.redis = self.persistence_helper.get_redis_manager()
        # Make sure we start fresh.
        self.redis._purge_all()

    def make_cfg(self, args):
        options = Options()
        options.parseOptions(args)
        return ConfigHolderWrapper(self, options)

    def mkfile(self, data):
        name = self.mktemp()
        with open(name, "wb") as data_file:
            data_file.write(data)
        return name

    def mkdbconfig(self, key_prefix):
        config = {
            'redis_manager': {
                'key_prefix': key_prefix,
            },
        }
        return self.mkfile(yaml.safe_dump(config))

    def get_sub_redis(self, config):
        config = config.copy()
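        # Point the sub-manager at the test's redis client and key prefix.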
        config['FAKE_REDIS'] = self.redis._client
        config['key_prefix'] = self.redis._key(config['key_prefix'])
        return self.persistence_helper.get_redis_manager(config)

    def mkdbbackup(self, data=None, raw=False):
        if data is None:
            data = self.DB_BACKUP
        if raw:
            dumps = lambda x: x
        else:
            dumps = json.dumps
        return self.mkfile("\n".join([dumps(x) for x in data]))


class TestBackupDbCmd(DbBackupBaseTestCase):
    def test_backup_db(self):
        self.redis.set("foo", 1)
        self.redis.set("bar:bar", 2)
        self.redis.set("bar:baz", "bar")
        db_backup = self.mktemp()
        cfg = self.make_cfg(["backup", self.mkdbconfig("bar"), db_backup])
        cfg.run()
        self.assertEqual(cfg.output, [
            'Backing up dbs ...',
            'Backed up 2 keys.',
        ])
        with open(db_backup) as backup:
            self.assertEqual([json.loads(x) for x in backup], [
                {"vumi_version": vumi_version(),
                 "format": "LF separated JSON",
                 "backup_type": "redis",
                 "timestamp": cfg.utcnow.isoformat(),
                 "sorted": True,
                 "redis_config": {"key_prefix": "bar"},
                 },
                {'key': 'bar', 'type': 'string', 'value': '2', 'ttl': None},
                {'key': 'baz', 'type': 'string', 'value': 'bar', 'ttl': None},
            ])

    def check_backup(self, key_prefix, expected):
        db_backup = self.mktemp()
        cfg = self.make_cfg(["backup", self.mkdbconfig(key_prefix), db_backup])
        cfg.run()
        with open(db_backup) as backup:
            self.assertEqual([json.loads(x) for x in backup][1:], expected)

    def test_backup_string(self):
        self.redis.set("bar:s", "foo")
        self.check_backup("bar", [{'key': 's', 'type': 'string',
                                   'value': "foo", 'ttl': None}])

    def test_backup_list(self):
        lvalue = ["a", "c", "b"]
        for item in lvalue:
            self.redis.rpush("bar:l", item)
        self.check_backup("bar", [{'key': 'l', 'type': 'list',
                                   'value': lvalue, 'ttl': None}])

    def test_backup_set(self):
        svalue = set(["a", "c", "b"])
        for item in svalue:
            self.redis.sadd("bar:s", item)
        self.check_backup("bar", [{'key': 's', 'type': 'set',
                                   'value': sorted(svalue), 'ttl': None}])

    def test_backup_zset(self):
        zvalue = [['z', 1], ['a', 2], ['c', 3]]
        for item, score in zvalue:
            self.redis.zadd("bar:z", **{item: score})
        self.check_backup("bar", [{'key': 'z', 'type': 'zset',
                                   'value': zvalue, 'ttl': None}])

    def test_hash_backup(self):
        self.redis.hmset("bar:set", {"foo": "1", "baz": "2"})
        self.check_backup("bar", [{'key': 'set', 'type': 'hash',
                                   'value': {"foo": "1", "baz": "2"},
                                   'ttl': None}])

    def test_ttl_backup(self):
        self.redis.set("bar:s", "foo")
        self.redis.expire("bar:s", 30)
        db_backup = self.mktemp()
        cfg = self.make_cfg(["backup", self.mkdbconfig("bar"), db_backup])
        cfg.run()
        with open(db_backup) as backup:
            [record] = [json.loads(x) for x in backup][1:]
            self.assertTrue(0 < record.pop('ttl') <= 30)
            self.assertEqual(record, {'key': 's', 'type': 'string',
                                      'value': "foo"})


class TestRestoreDbCmd(DbBackupBaseTestCase):

    DB_BACKUP = [
        {'backup_type': 'redis',
         'timestamp': '2012-08-21T23:18:52.413504'},
        {'key': 'bar', 'type': 'string', 'value': "2", 'ttl': None},
        {'key': 'baz', 'type': 'string', 'value': "bar", 'ttl': None},
    ]

    RESTORED_DATA = [
        {'bar': '2'},
        {'baz': "bar"},
    ]

    def _bad_header_test(self, data, expected_response, raw=False):
        cfg = self.make_cfg(["restore", self.mkdbconfig("bar"),
                             self.mkdbbackup(data, raw=raw)])
        cfg.run()
        self.assertEqual(cfg.output, expected_response)

    def test_empty_backup(self):
        self._bad_header_test([], [
            'Header not found.',
            'Aborting restore.',
        ])

    def test_header_not_json(self):
        self._bad_header_test(["."], [
            'Header not JSON.',
            'Aborting restore.',
        ], raw=True)

    def test_non_json_dict(self):
        self._bad_header_test(["."], [
            'Header not JSON dict.',
            'Aborting restore.',
        ])

    def test_header_missing_backup_type(self):
        self._bad_header_test([{}], [
            'Header missing backup_type.',
            'Aborting restore.',
        ])

    def test_unsupported_backup_type(self):
        self._bad_header_test([{'backup_type': 'notredis'}], [
            'Only redis backup type currently supported.',
            'Aborting restore.',
        ])

    def test_restore_backup(self):
        cfg = self.make_cfg(["restore", self.mkdbconfig("bar"),
                             self.mkdbbackup()])
        cfg.run()
        self.assertEqual(cfg.output, [
            'Restoring dbs ...',
            '2 keys successfully restored.',
        ])
        redis_data = sorted(
            (key, self.redis.get(key)) for key in self.redis.keys())
        expected_data = [tuple(x.items()[0]) for x in self.RESTORED_DATA]
        expected_data = [("bar:%s" % (k,), v) for k, v in expected_data]
        self.assertEqual(redis_data, expected_data)

    def test_restore_with_purge(self):
        redis = self.redis.sub_manager("bar")
        redis.set("foo", 1)
        cfg = self.make_cfg(["restore", "--purge", self.mkdbconfig("bar"),
                             self.mkdbbackup()])
        cfg.run()
        self.assertEqual(redis.get("foo"), None)

    def check_restore(self, backup_data, restored_data, redis_get,
                      timestamp=None, args=(), key_prefix="bar"):
        if timestamp is None:
            timestamp = datetime.datetime.utcnow()
        backup_data = [{'backup_type': 'redis',
                        'timestamp': timestamp.isoformat(),
                        }] + backup_data
        cfg = self.make_cfg(["restore"] + list(args) +
                            [self.mkdbconfig(key_prefix),
                             self.mkdbbackup(backup_data)])
        cfg.run()
        redis_data = sorted((key, redis_get(key)) for key in self.redis.keys())
        restored_data = sorted([("%s:%s" % (key_prefix, k), v)
                                for k, v in restored_data.items()])
        self.assertEqual(redis_data, restored_data)

    def test_restore_string(self):
        self.check_restore([{'key': 's', 'type': 'string', 'value': 'ping',
                             'ttl': None}],
                           {'s': 'ping'}, self.redis.get)

    def test_restore_list(self):
        lvalue = ['z', 'a', 'c']
        self.check_restore([{'key': 'l', 'type': 'list', 'value': lvalue,
                             'ttl': None}],
                           {'l': lvalue},
                           lambda k: self.redis.lrange(k, 0, -1))

    def test_restore_set(self):
        svalue = set(['z', 'a', 'c'])
        self.check_restore([{'key': 's', 'type': 'set',
                             'value': list(svalue), 'ttl': None}],
                           {'s': svalue}, self.redis.smembers)

    def test_restore_zset(self):
        def get_zset(k):
            return self.redis.zrange(k, 0, -1, withscores=True)
        zvalue = [('z', 1), ('a', 2), ('c', 3)]
        self.check_restore([{'key': 'z', 'type': 'zset', 'value': zvalue,
                             'ttl': None}],
                           {'z': zvalue}, get_zset)

    def test_restore_hash(self):
        hvalue = {'a': 'foo', 'b': 'bing'}
        self.check_restore([{'key': 'h', 'type': 'hash', 'value': hvalue,
                             'ttl': None}],
                           {'h': hvalue}, self.redis.hgetall)

    def test_restore_ttl(self):
        self.check_restore([{'key': 's', 'type': 'string', 'value': 'ping',
                             'ttl': 30}],
                           {'s': 'ping'}, self.redis.get, key_prefix="bar")
        self.assertTrue(0 < self.redis.ttl("bar:s") <= 30)

    def test_restore_ttl_frozen(self):
        yesterday = datetime.datetime.utcnow() - datetime.timedelta(days=1)
        self.check_restore([{'key': 's', 'type': 'string', 'value': 'ping',
                             'ttl': 30}],
                           {'s': 'ping'}, self.redis.get,
                           timestamp=yesterday,
                           args=["--frozen-ttls"], key_prefix="bar")
        self.assertTrue(0 < self.redis.ttl("bar:s") <= 30)


class TestMigrateDbCmd(DbBackupBaseTestCase):

    def mkrules(self, rules):
        config = {
            "rules": rules,
        }
        return self.mkfile(yaml.safe_dump(config))

    def check_rules(self, rules, data, output, expected):
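        # Every backup file starts with a single header record.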
        header = [{"backup_type": "redis"}]
        result_file = self.mkfile("")
        cfg = self.make_cfg(["migrate",
                             self.mkrules(rules),
                             self.mkdbbackup(header + data),
                             result_file])
        cfg.run()
        self.assertEqual(cfg.output, ["Migrating backup ...",
                                      "Summary of changes:"
                                      ] + output)

        result = [json.loads(x) for x in open(result_file)]
        self.assertEqual(result, header + expected)

    def test_single_regex_rename(self):
        self.check_rules([{"type": "rename", "from": r"foo:", "to": r"baz:"}],
                         [{"key": "foo:bar", "value": "foobar"},
                          {"key": "bar:foo", "value": "barfoo"}],
                         ["  2 records processed.",
                          "  1 records altered."],
                         [{"key": "baz:bar", "value": "foobar"},
                          {"key": "bar:foo", "value": "barfoo"}])

    def test_multiple_renames(self):
        self.check_rules([{"type": "rename", "from": r"foo:", "to": r"baz:"},
                          {"type": "rename", "from": r"bar:", "to": r"rab:"}],
                         [{"key": "foo:bar", "value": "foobar"},
                          {"key": "bar:foo", "value": "barfoo"}],
                         ["  2 records processed.",
                          "  2 records altered."],
                         [{"key": "baz:bar", "value": "foobar"},
                          {"key": "rab:foo", "value": "barfoo"}])

    def test_single_drop(self):
        self.check_rules([{"type": "drop", "key": r"foo:"}],
                         [{"key": "foo:bar", "value": "foobar"},
                          {"key": "bar:foo", "value": "barfoo"}],
                         ["  2 records processed.",
                          "  1 records altered."],
                         [{"key": "bar:foo", "value": "barfoo"}])

    def test_multiple_drops(self):
        self.check_rules([{"type": "drop", "key": r"foo:"},
                          {"type": "drop", "key": r"bar:"}],
                         [{"key": "foo:bar", "value": "foobar"},
                          {"key": "bar:foo", "value": "barfoo"}],
                         ["  2 records processed.",
                          "  2 records altered."],
                         [])


class TestAnalyzeCmd(DbBackupBaseTestCase):
    def mkkeysbackup(self, keys):
        records = [{'backup_type': 'redis'}]
        records.extend({'key': k} for k in keys)
        return self.mkdbbackup(records)

    def check_tree(self, keys, output):
        db_backup = self.mkkeysbackup(keys)
        cfg = self.make_cfg(["analyze", db_backup])
        cfg.run()
        self.assertEqual(cfg.output, ["Keys:", "-----"] + output)

    def test_no_keys(self):
        self.check_tree([], [])

    def test_one_key(self):
        self.check_tree(["foo"], ["foo"])

    def test_two_distinct_keys(self):
        self.check_tree(["foo", "bar"], ["bar", "foo"])

    def test_two_keys_that_share_prefix(self):
        self.check_tree(["foo:bar", "foo:baz"], [
            "foo: (2 leaves)",
        ])

    def test_full_tree(self):
        keys = (["foo:%d" % i for i in range(10)] +
                ["foo:bar:%d" % i for i in range(3)] +
                ["bar:%d" % i for i in range(4)])
        self.check_tree(keys, [
            "bar: (4 leaves)",
            "foo: (10 leaves)",
            "  bar: (3 leaves)",
        ])

# ==== vumi/scripts/tests/test_vumi_redis_tools.py ====
"""Tests for vumi.scripts.vumi_redis_tools."""

import StringIO

import yaml

from twisted.python.usage import UsageError

from vumi.scripts.vumi_redis_tools import (
    scan_keys, TaskRunner, Options, Task, TaskError,
    Count, Expire, Persist, ListKeys, Skip)
from vumi.tests.helpers import VumiTestCase, PersistenceHelper


class DummyTaskRunner(object):
    def __init__(self):
        self.output = []

    def emit(self, s):
        self.output.append(s)


class DummyTask(Task):
    """Dummy task for testing."""

    name = "dummy"
    hidden = True

    def __init__(self, a=None, b=None):
        self.a = a
        self.b = b


class TestTask(VumiTestCase):
    def test_name(self):
        t = Task()
        self.assertEqual(t.name, None)

    def test_parse_with_args(self):
        t = Task.parse("dummy:a=foo,b=bar")
        self.assertEqual(t.name, "dummy")
        self.assertEqual(t.a, "foo")
        self.assertEqual(t.b, "bar")
        self.assertEqual(type(t), DummyTask)

    def test_parse_without_args(self):
        t = Task.parse("dummy")
        self.assertEqual(t.name, "dummy")
        self.assertEqual(t.a, None)
        self.assertEqual(t.b, None)
        self.assertEqual(type(t), DummyTask)

    def test_parse_no_task(self):
        self.assertRaises(TaskError, Task.parse, "unknown")

    def test_init(self):
        t = Task()
        runner = object()
        redis = object()
        t.init(runner, redis)
        self.assertEqual(t.runner, runner)
        self.assertEqual(t.redis, redis)


class TestCount(VumiTestCase):

    def setUp(self):
        self.runner = DummyTaskRunner()

    def mk_count(self):
        t = Count()
        t.init(self.runner, None)
        t.before()
        return t

    def test_name(self):
        t = Count()
        self.assertEqual(t.name, "count")

    def test_create(self):
        t = Task.parse("count")
        self.assertEqual(t.name, "count")
        self.assertEqual(type(t), Count)

    def test_before(self):
        t = self.mk_count()
        self.assertEqual(t.count, 0)

    def test_process_key(self):
        t = self.mk_count()
        key = t.process_key("foo")
        self.assertEqual(t.count, 1)
        self.assertEqual(key, "foo")

    def test_after(self):
        t = self.mk_count()
        for i in range(5):
            t.process_key(str(i))
        t.after()
        self.assertEqual(self.runner.output, [
            "Found 5 matching keys.",
        ])


class TestExpire(VumiTestCase):

    def setUp(self):
        self.runner = DummyTaskRunner()
        self.persistence_helper = self.add_helper(
            PersistenceHelper(is_sync=True))
        self.redis = self.persistence_helper.get_redis_manager()
        self.redis._purge_all()  # Make sure we start fresh.

    def mk_expire(self, seconds=10):
        t = Expire(seconds)
        t.init(self.runner, self.redis)
        t.before()
        return t

    def test_name(self):
        t = Expire(seconds=20)
        self.assertEqual(t.name, "expire")

    def test_create(self):
        t = Task.parse("expire:seconds=20")
        self.assertEqual(t.name, "expire")
        self.assertEqual(t.seconds, 20)
        self.assertEqual(type(t), Expire)

    def test_process_key(self):
        t = self.mk_expire(seconds=10)
        self.redis.set("key1", "bar")
        self.redis.set("key2", "baz")
        key = t.process_key("key1")
        self.assertEqual(key, "key1")
        self.assertTrue(
            0 < self.redis.ttl("key1") <= 10)
        self.assertEqual(
            self.redis.ttl("key2"), None)


class TestPersist(VumiTestCase):

    def setUp(self):
        self.runner = DummyTaskRunner()
        self.persistence_helper = self.add_helper(
            PersistenceHelper(is_sync=True))
        self.redis = self.persistence_helper.get_redis_manager()
        self.redis._purge_all()  # Make sure we start fresh.

    def mk_persist(self):
        t = Persist()
        t.init(self.runner, self.redis)
        t.before()
        return t

    def test_name(self):
        t = Persist()
        self.assertEqual(t.name, "persist")

    def test_create(self):
        t = Task.parse("persist")
        self.assertEqual(t.name, "persist")
        self.assertEqual(type(t), Persist)

    def test_process_key(self):
        t = self.mk_persist()
        self.redis.setex("key1", 10, "bar")
        self.redis.setex("key2", 20, "baz")
        key = t.process_key("key1")
        self.assertEqual(key, "key1")
        self.assertEqual(
            self.redis.ttl("key1"), None)
        self.assertTrue(
            0 < self.redis.ttl("key2") <= 20)


class TestListKeys(VumiTestCase):

    def setUp(self):
        self.runner = DummyTaskRunner()

    def mk_list(self):
        t = ListKeys()
        t.init(self.runner, None)
        t.before()
        return t

    def test_name(self):
        t = ListKeys()
        self.assertEqual(t.name, "list")

    def test_create(self):
        t = Task.parse("list")
        self.assertEqual(t.name, "list")
        self.assertEqual(type(t), ListKeys)

    def test_process_key(self):
        t = self.mk_list()
        key = t.process_key("key1")
        self.assertEqual(key, "key1")
        self.assertEqual(self.runner.output, [
            "key1",
        ])


class TestSkip(VumiTestCase):

    def setUp(self):
        self.runner = DummyTaskRunner()

    def mk_skip(self, pattern=".*"):
        t = Skip(pattern)
        t.init(self.runner, None)
        t.before()
        return t

    def test_name(self):
        t = Skip(".*")
        self.assertEqual(t.name, "skip")

    def test_create(self):
        t = Task.parse("skip:pattern=.*")
        self.assertEqual(t.name, "skip")
        self.assertEqual(t.regex.pattern, ".*")
        self.assertEqual(type(t), Skip)

    def test_process_key(self):
        t = self.mk_skip("skip_.*")
        self.assertEqual(t.process_key("skip_this"), None)
        self.assertEqual(t.process_key("dont_skip"), "dont_skip")


class TestOptions(VumiTestCase):
    def mk_file(self, data):
        name = self.mktemp()
        with open(name, "wb") as data_file:
            data_file.write(data)
        return name

    def mk_redis_config(self, key_prefix):
        config = {
            'redis_manager': {
                'key_prefix': key_prefix,
            },
        }
        return self.mk_file(yaml.safe_dump(config))

    def mk_opts_raw(self, args):
        opts = Options()
        opts.parseOptions(args)
        return opts

    def mk_opts(self, args):
        config = self.mk_redis_config("foo")
        return self.mk_opts_raw(args + [config, "*"])

    def test_no_config(self):
        exc = self.assertRaises(
            UsageError,
            self.mk_opts_raw, [])
        self.assertEqual(str(exc), "Wrong number of arguments.")

    def test_no_pattern(self):
        exc = self.assertRaises(
            UsageError,
            self.mk_opts_raw, ["config.yaml"])
        self.assertEqual(str(exc), "Wrong number of arguments.")

    def test_no_tasks(self):
        exc = self.assertRaises(
            UsageError,
            self.mk_opts, [])
        self.assertEqual(str(exc), "Please specify a task.")

    def test_one_task(self):
        opts = self.mk_opts(["-t", "count"])
        self.assertEqual(
            [t.name for t in opts["tasks"]],
            ["count"]
        )

    def test_multiple_tasks(self):
        opts = self.mk_opts(["-t", "list", "-t", "count"])
        self.assertEqual(
            [t.name for t in opts["tasks"]],
            ["list", "count"]
        )

    def test_help(self):
        opts = Options()
        lines = opts.getUsage().splitlines()
        self.assertEqual(lines[-6:], [
            "Available tasks:",
            "      --count    A task that counts the number of keys.",
            "      --expire   A task that sets an expiry time on each key.",
            "      --list     A task that prints out each key.",
            "      --persist  A task that persists each key.",
            "      --skip     A task that skips keys that match a regular"
            " expression.",
        ])


class TestTaskRunner(VumiTestCase):
    def make_runner(self, tasks, redis=None, pattern="*"):
        if redis is None:
            redis = self.mk_redis_config()
        args = tasks + [redis, pattern]
        options = Options()
        options.parseOptions(args)
        runner = TaskRunner(options)
        runner.redis._purge_all()   # Make sure we start fresh.
        runner.stdout = StringIO.StringIO()
        return runner

    def output(self, runner):
        return runner.stdout.getvalue().splitlines()

    def mk_file(self, data):
        name = self.mktemp()
        with open(name, "wb") as data_file:
            data_file.write(data)
        return name

    def mk_redis_config(self):
        config = {
            'redis_manager': {
                'FAKE_REDIS': True,
            },
        }
        return self.mk_file(yaml.safe_dump(config))

    def test_single_task(self):
        runner = self.make_runner([
            "-t", "count",
        ])
        runner.run()
        self.assertEqual(self.output(runner), [
            'Found 0 matching keys.',
        ])

    def test_multiple_task(self):
        runner = self.make_runner([
            "-t", "expire:seconds=10",
            "-t", "count",
        ])
        runner.redis.set("key1", "k1")
        runner.redis.set("key2", "k2")
        runner.run()
        self.assertEqual(self.output(runner), [
            'Found 2 matching keys.',
        ])
        self.assertTrue(0 < runner.redis.ttl("key1") <= 10)
        self.assertTrue(0 < runner.redis.ttl("key2") <= 10)

    def test_match(self):
        runner = self.make_runner([
            "-t", "expire:seconds=10",
            "-t", "count",
        ], pattern="coffee:*")
        runner.redis.set("coffee:key1", "k1")
        runner.redis.set("tea:key2", "k2")
        runner.run()
        self.assertEqual(self.output(runner), [
            'Found 1 matching keys.',
        ])
        self.assertTrue(0 < runner.redis.ttl("coffee:key1") <= 10)
        self.assertEqual(runner.redis.ttl("tea:key2"), None)

    def test_key_skipping(self):
        runner = self.make_runner([
            "-t", "skip:pattern=key1",
            "-t", "list",
        ])
        runner.redis.set("key1", "k1")
        runner.redis.set("key2", "k2")
        runner.run()
        self.assertEqual(self.output(runner), [
            'key2',
        ])


class TestScanKeys(VumiTestCase):
    def setUp(self):
        self.persistence_helper = self.add_helper(
            PersistenceHelper(is_sync=True))
        self.redis = self.persistence_helper.get_redis_manager()
        self.redis._purge_all()  # Make sure we start fresh.

    def test_no_keys(self):
        keys = list(scan_keys(self.redis, "*"))
        self.assertEqual(keys, [])

    def test_single_scan_loop(self):
        expected_keys = ["key%d" % i for i in range(5)]
        for key in expected_keys:
            self.redis.set(key, "foo")
        keys = sorted(scan_keys(self.redis, "*"))
        self.assertEqual(keys, expected_keys)

    def test_multiple_scan_loops(self):
        expected_keys = ["key%02d" % i for i in range(100)]
        for key in expected_keys:
            self.redis.set(key, "foo")
        keys = sorted(scan_keys(self.redis, "*"))
        self.assertEqual(keys, expected_keys)

    def test_match(self):
        self.redis.set("coffee:latte", "yes")
        self.redis.set("tea:rooibos", "yes")
        keys = list(scan_keys(self.redis, "coffee:*"))
        self.assertEqual(keys, ["coffee:latte"])

# ==== vumi/scripts/tests/test_inject_messages.py ====
import StringIO
import json

from twisted.internet.defer import inlineCallbacks

from vumi.scripts.inject_messages import MessageInjector
from vumi.tests.helpers import VumiTestCase, WorkerHelper


class TestMessageInjector(VumiTestCase):

    DEFAULT_DATA = {
        'content': 'CODE2',
        'transport_type': 'sms',
        'to_addr': '1458',
        'message_id': '1',
        'from_addr': '1234',
    }

    def setUp(self):
        self.worker_helper = self.add_helper(WorkerHelper('sphex'))

    def get_worker(self, direction):
        return self.worker_helper.get_worker(MessageInjector, {
            'transport-name': 'sphex',
            'direction': direction,
        })

    def make_data(self, **kw):
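        # Note: DEFAULT_DATA values take precedence over keyword arguments.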
        kw.update(self.DEFAULT_DATA)
        return kw

    def check_msg(self, msg, data):
        for key in data:
            self.assertEqual(msg[key], data[key])

    @inlineCallbacks
    def test_process_line_inbound(self):
        worker = yield self.get_worker('inbound')
        data = self.make_data()
        worker.process_line(json.dumps(data))
        [msg] = yield self.worker_helper.wait_for_dispatched_inbound()
        self.check_msg(msg, data)

    @inlineCallbacks
    def test_process_line_outbound(self):
        worker = yield self.get_worker('outbound')
        data = self.make_data()
        worker.process_line(json.dumps(data))
        [msg] = yield self.worker_helper.wait_for_dispatched_outbound()
        self.check_msg(msg, data)

    @inlineCallbacks
    def test_process_file_inbound(self):
        worker = yield self.get_worker('inbound')
        data = [self.make_data(message_id=i) for i in range(10)]
        data_string = "\n".join(json.dumps(datum) for datum in data)
        in_file = StringIO.StringIO(data_string)
        out_file = StringIO.StringIO()
        yield worker.process_file(in_file, out_file)
        msgs = yield self.worker_helper.wait_for_dispatched_inbound()
        for msg, datum in zip(msgs, data):
            self.check_msg(msg, datum)
        self.assertEqual(out_file.getvalue(), data_string + "\n")

    @inlineCallbacks
    def test_process_file_outbound(self):
        worker = yield self.get_worker('outbound')
        data = [self.make_data(message_id=i) for i in range(10)]
        data_string = "\n".join(json.dumps(datum) for datum in data)
        in_file = StringIO.StringIO(data_string)
        out_file = StringIO.StringIO()
        yield worker.process_file(in_file, out_file)
        msgs = yield self.worker_helper.wait_for_dispatched_outbound()
        for msg, datum in zip(msgs, data):
            self.check_msg(msg, datum)
        self.assertEqual(out_file.getvalue(), data_string + "\n")

# ==== vumi/components/message_store_resource.py ====
# -*- test-case-name: vumi.components.tests.test_message_store_resource -*-

import iso8601

from twisted.application.internet import StreamServerEndpointService
from twisted.internet.defer import DeferredList, inlineCallbacks
from twisted.web.resource import NoResource, Resource
from twisted.web.server import NOT_DONE_YET

from vumi.components.message_store import MessageStore
from vumi.components.message_formatters import (
    JsonFormatter, CsvFormatter, CsvEventFormatter)
from vumi.config import (
    ConfigDict, ConfigText, ConfigServerEndpoint, ConfigInt,
    ServerEndpointFallback)
from vumi.message import format_vumi_date
from vumi.persist.txriak_manager import TxRiakManager
from vumi.persist.txredis_manager import TxRedisManager
from vumi.transports.httprpc import httprpc
from vumi.utils import build_web_site
from vumi.worker import BaseWorker


# NOTE: Thanks Ned http://stackoverflow.com/a/312464!
def chunks(l, n):
    """ Yield successive n-sized chunks from l.
    """
    for i in xrange(0, len(l), n):
        yield l[i:i + n]


class ParameterError(Exception):
    """
    Exception raised while trying to parse a parameter.
    """
    pass


class MessageStoreProxyResource(Resource):

    isLeaf = True
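    # Used when the request has no ?concurrency= query parameter.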
    default_concurrency = 1

    def __init__(self, message_store, batch_id, formatter):
        Resource.__init__(self)
        self.message_store = message_store
        self.batch_id = batch_id
        self.formatter = formatter

    def _extract_date_arg(self, request, argname):
        if argname not in request.args:
            return None
        if len(request.args[argname]) > 1:
            raise ParameterError(
                "Invalid '%s' parameter: Too many values" % (argname,))
        [value] = request.args[argname]
        try:
            timestamp = iso8601.parse_date(value)
            return format_vumi_date(timestamp)
        except iso8601.ParseError as e:
            raise ParameterError(
                "Invalid '%s' parameter: %s" % (argname, str(e)))

    def render_GET(self, request):
        if 'concurrency' in request.args:
            concurrency = int(request.args['concurrency'][0])
        else:
            concurrency = self.default_concurrency

        try:
            start = self._extract_date_arg(request, 'start')
            end = self._extract_date_arg(request, 'end')
        except ParameterError as e:
            request.setResponseCode(400)
            return str(e)

        self.formatter.add_http_headers(request)
        self.formatter.write_row_header(request)

        if not (start or end):
            d = self.get_keys_page(self.message_store, self.batch_id)
        else:
            d = self.get_keys_page_for_time(
                self.message_store, self.batch_id, start, end)
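        # Track client disconnects so page fetching can stop early.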
        request.connection_has_been_closed = False
        request.notifyFinish().addBoth(
            lambda _: setattr(request, 'connection_has_been_closed', True))
        d.addCallback(self.fetch_pages, concurrency, request)
        return NOT_DONE_YET

    def get_keys_page(self, message_store, batch_id):
        raise NotImplementedError('To be implemented by sub-class.')

    def get_keys_page_for_time(self, message_store, batch_id, start, end):
        raise NotImplementedError('To be implemented by sub-class.')

    def get_message(self, message_store, message_id):
        raise NotImplementedError('To be implemented by sub-class.')

    def fetch_pages(self, keys_page, concurrency, request):
        """
        Process a page of keys and each subsequent page.

        The keys for the current page are handed off to :meth:`fetch_page` for
        processing. If there is another page, we fetch that while the current
        page is being handled and add a callback to process it when the
        current page is finished.

        When there are no more pages, we add a callback to close the request.
        """
        if request.connection_has_been_closed:
            # We're no longer connected, so stop doing work.
            return
        d = self.fetch_page(keys_page, concurrency, request)
        if keys_page.has_next_page():
            # We fetch the next page before waiting for the current page to be
            # processed.
            next_page_d = keys_page.next_page()
            d.addCallback(lambda _: next_page_d)
            # Add this method as a callback to operate on the next page. It's
            # like recursion, but without worrying about stack size.
            d.addCallback(self.fetch_pages, concurrency, request)
        else:
            # No more pages, so close the request.
            d.addCallback(self.finish_request_cb, request)
        return d

    def finish_request_cb(self, _result, request):
        if not request.connection_has_been_closed:
            # We need to check for this here in case we lose the connection
            # while delivering the last page.
            return request.finish()

    @inlineCallbacks
    def fetch_page(self, keys_page, concurrency, request):
        """
        Process a page of keys in chunks of concurrently-fetched messages.
        """
        for keys in chunks(list(keys_page), concurrency):
            if request.connection_has_been_closed:
                # We're no longer connected, so stop doing work.
                return
            yield self.handle_chunk(keys, request)

    def handle_chunk(self, message_keys, request):
        """
        Concurrently fetch a chunk of messages and write each to the response.
        """
        return DeferredList([
            self.handle_message(key, request) for key in message_keys])

    def handle_message(self, message_key, request):
        d = self.get_message(self.message_store, message_key)
        d.addCallback(self.write_message, request)
        return d

    def write_message(self, message, request):
        if not request.content.closed:
            self.formatter.write_row(request, message)


class InboundResource(MessageStoreProxyResource):

    def get_keys_page(self, message_store, batch_id):
        return message_store.batch_inbound_keys_page(batch_id)

    def get_keys_page_for_time(self, message_store, batch_id, start, end):
        return message_store.batch_inbound_keys_with_timestamps(
            batch_id, max_results=message_store.DEFAULT_MAX_RESULTS,
            start=start, end=end, with_timestamps=False)

    def get_message(self, message_store, message_id):
        return message_store.get_inbound_message(message_id)


class OutboundResource(MessageStoreProxyResource):

    def get_keys_page(self, message_store, batch_id):
        return message_store.batch_outbound_keys_page(batch_id)

    def get_keys_page_for_time(self, message_store, batch_id, start, end):
        return message_store.batch_outbound_keys_with_timestamps(
            batch_id, max_results=message_store.DEFAULT_MAX_RESULTS,
            start=start, end=end, with_timestamps=False)

    def get_message(self, message_store, message_id):
        return message_store.get_outbound_message(message_id)


class EventResource(MessageStoreProxyResource):

    def get_keys_page(self, message_store, batch_id):
        return message_store.batch_event_keys_with_statuses_reverse(
            batch_id, max_results=message_store.DEFAULT_MAX_RESULTS,
            start=None, end=None)

    def get_keys_page_for_time(self, message_store, batch_id, start, end):
        return message_store.batch_event_keys_with_statuses_reverse(
            batch_id, max_results=message_store.DEFAULT_MAX_RESULTS,
            start=start, end=end)

    def get_message(self, message_store, event_index):
        event_id, _, _ = event_index
        return message_store.get_event(event_id)


class BatchResource(Resource):

    RESOURCES = {
        'inbound.json': (InboundResource, JsonFormatter),
        'outbound.json': (OutboundResource, JsonFormatter),
        'events.json': (EventResource, JsonFormatter),
        'inbound.csv': (InboundResource, CsvFormatter),
        'outbound.csv': (OutboundResource, CsvFormatter),
        'events.csv': (EventResource, CsvEventFormatter),
    }

    def __init__(self, message_store, batch_id):
        Resource.__init__(self)
        self.message_store = message_store
        self.batch_id = batch_id

    def getChild(self, path, request):
        if path not in self.RESOURCES:
            return NoResource()
        resource_class, message_formatter = self.RESOURCES.get(path)
        return resource_class(
            self.message_store, self.batch_id, message_formatter())


class MessageStoreResource(Resource):

    def __init__(self, message_store):
        Resource.__init__(self)
        self.message_store = message_store

    def getChild(self, path, request):
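        # The child path segment is the message batch identifier.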
        return BatchResource(self.message_store, path)


class MessageStoreResourceWorker(BaseWorker):

    class CONFIG_CLASS(BaseWorker.CONFIG_CLASS):
        worker_name = ConfigText(
            'Name of this message store resource worker.',
            required=True, static=True)
        twisted_endpoint = ConfigServerEndpoint(
            'Twisted endpoint to listen on.', required=True, static=True,
            fallbacks=[ServerEndpointFallback()])
        web_path = ConfigText(
            'The path to serve this resource on.', required=True, static=True)
        health_path = ConfigText(
            'The path to serve the health resource on.', default='/health/',
            static=True)
        riak_manager = ConfigDict(
            'Riak client configuration.', default={}, static=True)
        redis_manager = ConfigDict(
            'Redis client configuration.', default={}, static=True)

        # TODO: Deprecate these fields when confmodel#5 is done.
        host = ConfigText(
            "*DEPRECATED* 'host' and 'port' fields may be used in place of"
            " the 'twisted_endpoint' field.", static=True)
        port = ConfigInt(
            "*DEPRECATED* 'host' and 'port' fields may be used in place of"
            " the 'twisted_endpoint' field.", static=True)

    def get_health_response(self):
        return 'OK'

    @inlineCallbacks
    def setup_worker(self):
        config = self.get_static_config()
        self._riak = yield TxRiakManager.from_config(config.riak_manager)
        redis = yield TxRedisManager.from_config(config.redis_manager)
        self.store = MessageStore(self._riak, redis)

        site = build_web_site({
            config.web_path: MessageStoreResource(self.store),
            config.health_path: httprpc.HttpRpcHealthResource(self),
        })
        self.addService(
            StreamServerEndpointService(config.twisted_endpoint, site))

    @inlineCallbacks
    def teardown_worker(self):
        yield self._riak.close_manager()

    def setup_connectors(self):
        # NOTE: not doing anything AMQP
        pass

# ==== vumi/components/tagpool_api.py ====
# -*- coding: utf-8 -*-

"""JSON RPC API for vumi.components.tagpool."""

from txjsonrpc.web.jsonrpc import JSONRPC
from txjsonrpc.jsonrpc import addIntrospection

from twisted.application.internet import StreamServerEndpointService
from twisted.internet.defer import inlineCallbacks

from vumi.worker import BaseWorker
from vumi.config import (
    ConfigDict, ConfigText, ConfigServerEndpoint, ConfigInt,
    ServerEndpointFallback)
from vumi.persist.txredis_manager import TxRedisManager
from vumi.components.tagpool import TagpoolManager
from vumi.rpc import signature, Unicode, Tag, List, Dict
from vumi.transports.httprpc import httprpc
from vumi.utils import build_web_site


class TagpoolApiServer(JSONRPC):
    def __init__(self, tagpool):
        JSONRPC.__init__(self)
        self.tagpool = tagpool

    @signature(pool=Unicode("Name of pool to acquire tag from."),
               owner=Unicode("Owner acquiring tag (or None).", null=True),
               reason=Dict("Metadata on why tag is being acquired (or None).",
                           null=True),
               returns=Tag("Tag acquired (or None).", null=True))
    def jsonrpc_acquire_tag(self, pool, owner=None, reason=None):
        """Acquire a tag from the pool (returns None if no tags are avaliable).
           """
        d = self.tagpool.acquire_tag(pool, owner, reason)
        return d

    @signature(tag=Tag("Tag to acquire as [pool, tagname] pair."),
               owner=Unicode("Owner acquiring tag (or None).", null=True),
               reason=Dict("Metadata on why tag is being acquired (or None).",
                           null=True),
               returns=Tag("Tag acquired (or None).", null=True))
    def jsonrpc_acquire_specific_tag(self, tag, owner=None, reason=None):
        """Acquire the specific tag (returns None if the tag is unavailable).
           """
        d = self.tagpool.acquire_specific_tag(tag, owner, reason)
        return d

    @signature(tag=Tag("Tag to release."))
    def jsonrpc_release_tag(self, tag):
        """Release the specified tag if it exists and is inuse."""
        return self.tagpool.release_tag(tag)

    @signature(tags=List("List of tags to declare.", item_type=Tag()))
    def jsonrpc_declare_tags(self, tags):
        """Declare all of the listed tags."""
        return self.tagpool.declare_tags(tags)

    @signature(pool=Unicode("Name of pool to retreive metadata for."),
               returns=Dict("Retrieved metadata."))
    def jsonrpc_get_metadata(self, pool):
        """Retrieve the metadata for the given pool."""
        return self.tagpool.get_metadata(pool)

    @signature(pool=Unicode("Name of pool to update metadata for."),
               metadata=Dict("New value of metadata."))
    def jsonrpc_set_metadata(self, pool, metadata):
        """Set the metadata for the given pool."""
        return self.tagpool.set_metadata(pool, metadata)

    @signature(pool=Unicode("Name of the pool to purge."))
    def jsonrpc_purge_pool(self, pool):
        """Delete the given pool and all associated metadata and tags.

           No tags from the pool may be in use.
           """
        return self.tagpool.purge_pool(pool)

    @signature(returns=List("List of pool names.", item_type=Unicode()))
    def jsonrpc_list_pools(self):
        """Return a list of all available pools."""
        d = self.tagpool.list_pools()
        d.addCallback(list)
        return d

    @signature(pool=Unicode("Name of pool."),
               returns=List("List of free tags.", item_type=Tag()))
    def jsonrpc_free_tags(self, pool):
        """Return a list of free tags in the given pool."""
        d = self.tagpool.free_tags(pool)
        return d

    @signature(pool=Unicode("Name of pool."),
               returns=List("List of tags inuse.", item_type=Tag()))
    def jsonrpc_inuse_tags(self, pool):
        """Return a list of tags currently in use within the given pool."""
        d = self.tagpool.inuse_tags(pool)
        return d

    @signature(tag=Tag("Tag to return ownership information on."),
               returns=List("List of owner and reason.", length=2, null=True))
    def jsonrpc_acquired_by(self, tag):
        """Returns the owner of an acquired tag and why is was acquired."""
        d = self.tagpool.acquired_by(tag)
        d.addCallback(list)
        return d

    @signature(owner=Unicode("Owner of tags (or None for unowned tags).",
                             null=True),
               returns=List("List of tags owned.", item_type=Tag()))
    def jsonrpc_owned_tags(self, owner):
        """Return a list of tags currently owned by an owner."""
        return self.tagpool.owned_tags(owner)
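

# A minimal client sketch (illustrative, not part of this module): txjsonrpc's
# HTTP proxy can drive TagpoolApiServer. Assuming the worker below is
# configured to serve the RPC resource at http://localhost:8888/tagpool:
#
#   from txjsonrpc.web.jsonrpc import Proxy
#
#   proxy = Proxy('http://localhost:8888/tagpool')
#   d = proxy.callRemote('acquire_tag', u'pool-1')
#   # d fires with the acquired [pool, tagname] pair, or None.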


class TagpoolApiWorker(BaseWorker):

    class CONFIG_CLASS(BaseWorker.CONFIG_CLASS):
        worker_name = ConfigText(
            "Name of this tagpool API worker.", required=True, static=True)
        twisted_endpoint = ConfigServerEndpoint(
            "Twisted endpoint to listen on.", required=True, static=True,
            fallbacks=[ServerEndpointFallback()])
        web_path = ConfigText(
            "The path to serve this resource on.", required=True, static=True)
        health_path = ConfigText(
            "The path to server the health resource on.", default='/health/',
            static=True)
        redis_manager = ConfigDict(
            "Redis client configuration.", default={}, static=True)

        # TODO: Deprecate these fields when confmodel#5 is done.
        host = ConfigText(
            "*DEPRECATED* 'host' and 'port' fields may be used in place of"
            " the 'twisted_endpoint' field.", static=True)
        port = ConfigInt(
            "*DEPRECATED* 'host' and 'port' fields may be used in place of"
            " the 'twisted_endpoint' field.", static=True)

    def get_health_response(self):
        return "OK"

    @inlineCallbacks
    def setup_worker(self):
        config = self.get_static_config()
        self.redis_manager = yield TxRedisManager.from_config(
            config.redis_manager)
        tagpool = TagpoolManager(self.redis_manager)
        rpc = TagpoolApiServer(tagpool)
        addIntrospection(rpc)
        site = build_web_site({
            config.web_path: rpc,
            config.health_path: httprpc.HttpRpcHealthResource(self),
        })
        self.addService(
            StreamServerEndpointService(config.twisted_endpoint, site))

    def teardown_worker(self):
        pass

    def setup_connectors(self):
        pass
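
# An illustrative YAML config for TagpoolApiWorker (a sketch: the field names
# come from CONFIG_CLASS above, but the values are examples only):
#
#   worker_name: tagpool_api
#   twisted_endpoint: tcp:8888
#   web_path: /tagpool
#   health_path: /health/
#   redis_manager:
#     key_prefix: vumi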
PK=JG``&vumi/components/message_store_cache.py# -*- test-case-name: vumi.components.tests.test_message_store_cache -*-
# -*- coding: utf-8 -*-

from datetime import datetime
import hashlib
import json
import time

from twisted.internet.defer import returnValue

from vumi.persist.redis_base import Manager
from vumi.message import TransportEvent, parse_vumi_date
from vumi.errors import VumiError


class MessageStoreCacheException(VumiError):
    pass


class MessageStoreCache(object):
    """
    A helper class to provide a view on information in the message store
    that is difficult to query straight from riak.
    """
    BATCH_KEY = 'batches'
    OUTBOUND_KEY = 'outbound'
    OUTBOUND_COUNT_KEY = 'outbound_count'
    INBOUND_KEY = 'inbound'
    INBOUND_COUNT_KEY = 'inbound_count'
    TO_ADDR_KEY = 'to_addr_hll'
    FROM_ADDR_KEY = 'from_addr_hll'
    EVENT_KEY = 'event'
    EVENT_COUNT_KEY = 'event_count'
    STATUS_KEY = 'status'
    SEARCH_TOKEN_KEY = 'search_token'
    SEARCH_RESULT_KEY = 'search_result'
    TRUNCATE_MESSAGE_KEY_COUNT_AT = 2000

    # Cache search results for 24 hrs
    DEFAULT_SEARCH_RESULT_TTL = 60 * 60 * 24

    def __init__(self, redis):
        # Store redis as `manager` as well since @Manager.calls_manager
        # requires it to be named as such.
        self.redis = self.manager = redis

    def key(self, *args):
        return ':'.join([unicode(a) for a in args])

    def batch_key(self, *args):
        return self.key(self.BATCH_KEY, *args)

    def outbound_key(self, batch_id):
        return self.batch_key(self.OUTBOUND_KEY, batch_id)

    def outbound_count_key(self, batch_id):
        return self.batch_key(self.OUTBOUND_COUNT_KEY, batch_id)

    def inbound_key(self, batch_id):
        return self.batch_key(self.INBOUND_KEY, batch_id)

    def inbound_count_key(self, batch_id):
        return self.batch_key(self.INBOUND_COUNT_KEY, batch_id)

    def to_addr_key(self, batch_id):
        return self.batch_key(self.TO_ADDR_KEY, batch_id)

    def from_addr_key(self, batch_id):
        return self.batch_key(self.FROM_ADDR_KEY, batch_id)

    def status_key(self, batch_id):
        return self.batch_key(self.STATUS_KEY, batch_id)

    def event_key(self, batch_id):
        return self.batch_key(self.EVENT_KEY, batch_id)

    def event_count_key(self, batch_id):
        return self.batch_key(self.EVENT_COUNT_KEY, batch_id)

    def search_token_key(self, batch_id):
        return self.batch_key(self.SEARCH_TOKEN_KEY, batch_id)

    def search_result_key(self, batch_id, token):
        return self.batch_key(self.SEARCH_RESULT_KEY, batch_id, token)
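
    # For example (illustrative, assuming a batch_id of 'batch-1'), the key
    # helpers above derive colon-joined Redis keys like:
    #
    #   batch_key()                   -> 'batches'
    #   inbound_key('batch-1')        -> 'batches:inbound:batch-1'
    #   inbound_count_key('batch-1')  -> 'batches:inbound_count:batch-1'
    #   status_key('batch-1')         -> 'batches:status:batch-1'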

    def uses_counters(self, batch_id):
        """
        Returns ``True`` if ``batch_id`` has moved to the new system
        of using counters instead of assuming all keys are in Redis
        and doing a `zcard` on that.

        The test for this is to see if `inbound_count_key(batch_id)`
        exists. If it does, we've moved to the new system and are
        using counters.
        """
        return self.redis.exists(self.inbound_count_key(batch_id))

    def uses_event_counters(self, batch_id):
        """
        Returns ``True`` if ``batch_id`` has moved to the new system of using
        counters for events instead of assuming all keys are in Redis and doing
        a `zcard` on that.

        The test for this is to see if `event_count_key(batch_id)` exists. If
        it does, we've moved to the new system and are using counters.
        """
        return self.redis.exists(self.event_count_key(batch_id))

    @Manager.calls_manager
    def switch_to_counters(self, batch_id):
        """
        Actively switch a batch from the old ``zcard()`` based approach
        to the new ``redis.incr()`` counter based approach.
        """
        uses_counters = yield self.uses_counters(batch_id)
        if uses_counters:
            return

        # NOTE:     Under high load this may result in the counter being off
        #           by a few. Considering this is a cache that is to be
        #           reconciled we're happy for that to be the case.
        inbound_count = yield self.count_inbound_message_keys(batch_id)
        outbound_count = yield self.count_outbound_message_keys(batch_id)

        # We do `*_count or 0` because there's a chance of getting back
        # a None if this is a new batch that hasn't received any traffic yet.
        yield self.redis.set(self.inbound_count_key(batch_id),
                             inbound_count or 0)
        yield self.redis.set(self.outbound_count_key(batch_id),
                             outbound_count or 0)

        yield self.truncate_inbound_message_keys(batch_id)
        yield self.truncate_outbound_message_keys(batch_id)
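
    # An illustrative flow for the switch above (a sketch, assuming `cache`
    # is a MessageStoreCache used in an inlineCallbacks context):
    #
    #   yield cache.switch_to_counters('batch-1')
    #   # The counter keys are now seeded from the zset sizes, the zsets are
    #   # truncated, and uses_counters('batch-1') is truthy from here on.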

    @Manager.calls_manager
    def _truncate_keys(self, redis_key, truncate_at):
        # Indexes are zero based
        truncate_at = (truncate_at or self.TRUNCATE_MESSAGE_KEY_COUNT_AT) + 1
        # NOTE: Doing this because ZCARD is O(1) where ZREMRANGEBYRANK is
        #       O(log(N)+M)
        current_size = yield self.redis.zcard(redis_key)
        if current_size <= truncate_at:
            returnValue(0)

        keys_removed = yield self.redis.zremrangebyrank(
            redis_key, 0, truncate_at * -1)
        returnValue(keys_removed)

    def truncate_inbound_message_keys(self, batch_id, truncate_at=None):
        return self._truncate_keys(self.inbound_key(batch_id), truncate_at)

    def truncate_outbound_message_keys(self, batch_id, truncate_at=None):
        return self._truncate_keys(self.outbound_key(batch_id), truncate_at)

    def truncate_event_keys(self, batch_id, truncate_at=None):
        return self._truncate_keys(self.event_key(batch_id), truncate_at)

    @Manager.calls_manager
    def batch_start(self, batch_id, use_counters=True):
        """
        Does various setup work in order to be able to accurately
        store cached data for a batch_id.

        A call to this isn't necessary but is good for general housekeeping.

        :param bool use_counters:
            If ``True`` this batch is started and will use counters
            rather than Redis zsets() to keep track of message counts.

            Defaults to ``True``.


        This operation is idempotent.
        """
        yield self.redis.sadd(self.batch_key(), batch_id)
        yield self.init_status(batch_id)
        if use_counters:
            yield self.redis.set(self.inbound_count_key(batch_id), 0)
            yield self.redis.set(self.outbound_count_key(batch_id), 0)
            yield self.redis.set(self.event_count_key(batch_id), 0)

    @Manager.calls_manager
    def init_status(self, batch_id):
        """
        Set up the hash for event tracking on this batch. It primes the
        hash with the bare minimum of expected keys, with their values
        all set to 0. If a value already exists, it is left untouched.
        """
        events = (TransportEvent.EVENT_TYPES.keys() +
                  ['delivery_report.%s' % status
                   for status in TransportEvent.DELIVERY_STATUSES] +
                  ['sent'])
        for event in events:
            yield self.redis.hsetnx(self.status_key(batch_id), event, 0)

    def get_batch_ids(self):
        """
        Return the set of known batch_ids.
        """
        return self.redis.smembers(self.batch_key())

    def batch_exists(self, batch_id):
        return self.redis.sismember(self.batch_key(), batch_id)

    @Manager.calls_manager
    def clear_batch(self, batch_id):
        """
        Removes all cached values for the given batch_id, useful before
        a reconciliation happens to ensure that we start from scratch.

        NOTE:   This will reset all counters back to zero, and they will be
                incremented again as messages are received. If your UI
                depends on the cached values, those values may be off while
                the reconciliation is taking place.
        """
        yield self.redis.delete(self.inbound_key(batch_id))
        yield self.redis.delete(self.inbound_count_key(batch_id))
        yield self.redis.delete(self.outbound_key(batch_id))
        yield self.redis.delete(self.outbound_count_key(batch_id))
        yield self.redis.delete(self.event_key(batch_id))
        yield self.redis.delete(self.event_count_key(batch_id))
        yield self.redis.delete(self.status_key(batch_id))
        yield self.redis.delete(self.to_addr_key(batch_id))
        yield self.redis.delete(self.from_addr_key(batch_id))
        yield self.redis.srem(self.batch_key(), batch_id)

    def get_timestamp(self, timestamp):
        """
        Return a timestamp value for a datetime value.
        """
        if isinstance(timestamp, basestring):
            timestamp = parse_vumi_date(timestamp)
        return time.mktime(timestamp.timetuple())

    @Manager.calls_manager
    def add_outbound_message(self, batch_id, msg):
        """
        Add an outbound message to the cache for the given batch_id
        """
        timestamp = self.get_timestamp(msg['timestamp'])
        yield self.add_outbound_message_key(
            batch_id, msg['message_id'], timestamp)
        yield self.add_to_addr(batch_id, msg['to_addr'])

    @Manager.calls_manager
    def add_outbound_message_key(self, batch_id, message_key, timestamp):
        """
        Add a message key, weighted by its timestamp, to the batch_id.
        """
        new_entry = yield self.redis.zadd(self.outbound_key(batch_id), **{
            message_key.encode('utf-8'): timestamp,
        })
        if new_entry:
            yield self.increment_event_status(batch_id, 'sent')

            uses_counters = yield self.uses_counters(batch_id)
            if uses_counters:
                yield self.redis.incr(self.outbound_count_key(batch_id))
                yield self.truncate_outbound_message_keys(batch_id)

    @Manager.calls_manager
    def add_outbound_message_count(self, batch_id, count):
        """
        Add a count to all outbound message counters. (Used for recon.)
        """
        yield self.increment_event_status(batch_id, 'sent', count)
        yield self.redis.incr(self.outbound_count_key(batch_id), count)

    @Manager.calls_manager
    def add_event_count(self, batch_id, status, count):
        """
        Add a count to all relevant event counters. (Used for recon.)
        """
        yield self.increment_event_status(batch_id, status, count)
        yield self.redis.incr(self.event_count_key(batch_id), count)

    @Manager.calls_manager
    def add_event(self, batch_id, event):
        """
        Add an event to the cache for the given batch_id
        """
        event_id = event['event_id']
        timestamp = self.get_timestamp(event['timestamp'])
        new_entry = yield self.add_event_key(batch_id, event_id, timestamp)
        if new_entry:
            event_type = event['event_type']
            yield self.increment_event_status(batch_id, event_type)
            if event_type == 'delivery_report':
                yield self.increment_event_status(
                    batch_id, '%s.%s' % (event_type, event['delivery_status']))

    @Manager.calls_manager
    def add_event_key(self, batch_id, event_key, timestamp):
        """
        Add the event key to the set of known event keys.
        Returns 0 if the key already exists in the set, 1 if it doesn't.
        """
        uses_event_counters = yield self.uses_event_counters(batch_id)
        if uses_event_counters:
            new_entry = yield self.redis.zadd(self.event_key(batch_id), **{
                event_key.encode('utf-8'): timestamp,
            })
            if new_entry:
                yield self.redis.incr(self.event_count_key(batch_id))
                yield self.truncate_event_keys(batch_id)
            returnValue(new_entry)
        else:
            # HACK: Disabling this because of unbounded growth.
            #       Please perform reconciliation on all batches that still use
            #       SET-based event tracking.
            # NOTE: Cheaper recon is coming Real Soon Now.
            returnValue(False)
            # This uses a set, not a sorted set.
            new_entry = yield self.redis.sadd(
                self.event_key(batch_id), event_key)
            returnValue(new_entry)

    def increment_event_status(self, batch_id, event_type, count=1):
        """
        Increment the status for the given event_type for the given batch_id.
        """
        return self.redis.hincrby(self.status_key(batch_id), event_type, count)

    @Manager.calls_manager
    def get_event_status(self, batch_id):
        """
        Return a dictionary containing the latest event stats for the given
        batch_id.
        """
        stats = yield self.redis.hgetall(self.status_key(batch_id))
        returnValue(dict([(k, int(v)) for k, v in stats.iteritems()]))

    @Manager.calls_manager
    def add_inbound_message(self, batch_id, msg):
        """
        Add an inbound message to the cache for the given batch_id
        """
        timestamp = self.get_timestamp(msg['timestamp'])
        yield self.add_inbound_message_key(
            batch_id, msg['message_id'], timestamp)
        yield self.add_from_addr(batch_id, msg['from_addr'])

    @Manager.calls_manager
    def add_inbound_message_key(self, batch_id, message_key, timestamp):
        """
        Add a message key, weighted by its timestamp, to the batch_id.
        """
        new_entry = yield self.redis.zadd(self.inbound_key(batch_id), **{
            message_key.encode('utf-8'): timestamp,
        })

        if new_entry:
            uses_counters = yield self.uses_counters(batch_id)
            if uses_counters:
                yield self.redis.incr(self.inbound_count_key(batch_id))
                yield self.truncate_inbound_message_keys(batch_id)

    @Manager.calls_manager
    def add_inbound_message_count(self, batch_id, count):
        """
        Add a count to all inbound message counters. (Used for recon.)
        """
        yield self.redis.incr(self.inbound_count_key(batch_id), count)

    def add_from_addr(self, batch_id, from_addr):
        """
        Add a from_addr to this batch_id using Redis's HyperLogLog
        functionality. Generally this information is set when
        `add_inbound_message()` is called.
        """
        return self.redis.pfadd(
            self.from_addr_key(batch_id), from_addr.encode('utf-8'))

    def get_from_addrs(self, batch_id, asc=False):
        """
        Return a set of all known from_addrs sorted by timestamp.

        Currently disabled: always returns an empty list (see NOTE below).
        """
        # NOTE: Disabled because this doesn't scale to large batches.
        #       See https://github.com/praekelt/vumi/issues/877

        # return self.redis.zrange(self.from_addr_key(batch_id), 0, -1,
        #                          desc=not asc)
        return []

    def count_from_addrs(self, batch_id):
        """
        Return count of the unique from_addrs in this batch. Note that the
        returned count is not exact. We use Redis's HyperLogLog functionality,
        so the count is subject to a standard error of 0.81%:
        http://redis.io/commands/pfcount
        """
        return self.redis.pfcount(self.from_addr_key(batch_id))

    def add_to_addr(self, batch_id, to_addr):
        """
        Add a to_addr to this batch_id using Redis's HyperLogLog
        functionality. Generally this information is set when
        `add_outbound_message()` is called.
        """
        return self.redis.pfadd(
            self.to_addr_key(batch_id), to_addr.encode('utf-8'))

    def get_to_addrs(self, batch_id, asc=False):
        """
        Return a set of unique to_addrs addressed in this batch ordered
        by the most recent timestamp.

        Currently disabled: always returns an empty list (see NOTE below).
        """
        # NOTE: Disabled because this doesn't scale to large batches.
        #       See https://github.com/praekelt/vumi/issues/877

        # return self.redis.zrange(self.to_addr_key(batch_id), 0, -1,
        #                          desc=not asc)
        return []

    def count_to_addrs(self, batch_id):
        """
        Return count of the unique to_addrs in this batch. Note that the
        returned count is not exact. We use Redis's HyperLogLog functionality,
        so the count is subject to a standard error of 0.81%:
        http://redis.io/commands/pfcount
        """
        return self.redis.pfcount(self.to_addr_key(batch_id))
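
    # For example (a sketch, assuming an inlineCallbacks context): pfadd()
    # records each address in a HyperLogLog and pfcount() estimates the
    # cardinality, so the count is approximate but uses constant memory:
    #
    #   yield cache.add_to_addr('batch-1', u'+27831234567')
    #   count = yield cache.count_to_addrs('batch-1')  # approximately 1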

    def get_inbound_message_keys(self, batch_id, start=0, stop=-1, asc=False,
                                 with_timestamp=False):
        """
        Return a list of keys ordered according to their timestamps
        """
        return self.redis.zrange(self.inbound_key(batch_id),
                                 start, stop, desc=not asc,
                                 withscores=with_timestamp)

    @Manager.calls_manager
    def inbound_message_count(self, batch_id):
        count = yield self.redis.get(self.inbound_count_key(batch_id))
        returnValue(0 if count is None else int(count))

    def inbound_message_keys_size(self, batch_id):
        return self.redis.zcard(self.inbound_key(batch_id))

    @Manager.calls_manager
    def count_inbound_message_keys(self, batch_id):
        """
        Return the count of the unique inbound message keys for this batch_id
        """
        if not (yield self.uses_counters(batch_id)):
            returnValue((yield self.inbound_message_keys_size(batch_id)))

        count = yield self.inbound_message_count(batch_id)
        returnValue(count)

    def get_outbound_message_keys(self, batch_id, start=0, stop=-1, asc=False,
                                  with_timestamp=False):
        """
        Return a list of keys ordered according to their timestamps.
        """
        return self.redis.zrange(self.outbound_key(batch_id),
                                 start, stop, desc=not asc,
                                 withscores=with_timestamp)

    @Manager.calls_manager
    def outbound_message_count(self, batch_id):
        count = yield self.redis.get(self.outbound_count_key(batch_id))
        returnValue(0 if count is None else int(count))

    def outbound_message_keys_size(self, batch_id):
        return self.redis.zcard(self.outbound_key(batch_id))

    @Manager.calls_manager
    def count_outbound_message_keys(self, batch_id):
        """
        Return the count of the unique outbound message keys for this batch_id
        """
        if not (yield self.uses_counters(batch_id)):
            returnValue((yield self.outbound_message_keys_size(batch_id)))

        count = yield self.outbound_message_count(batch_id)
        returnValue(count)

    @Manager.calls_manager
    def event_count(self, batch_id):
        count = yield self.redis.get(self.event_count_key(batch_id))
        returnValue(0 if count is None else int(count))

    @Manager.calls_manager
    def count_event_keys(self, batch_id):
        """
        Return the count of the unique event keys for this batch_id
        """
        uses_event_counters = yield self.uses_event_counters(batch_id)
        if uses_event_counters:
            count = yield self.event_count(batch_id)
            returnValue(count)
        else:
            count = yield self.redis.scard(self.event_key(batch_id))
            returnValue(count)

    @Manager.calls_manager
    def count_inbound_throughput(self, batch_id, sample_time=300):
        """
        Calculate the number of messages seen in the last `sample_time`
        seconds.

        :param int sample_time:
            How far to look back to calculate the throughput.
            Defaults to 300 seconds (5 minutes)
        """
        last_seen = yield self.redis.zrange(
            self.inbound_key(batch_id), 0, 0, desc=True,
            withscores=True)
        if not last_seen:
            returnValue(0)

        [(latest, timestamp)] = last_seen
        count = yield self.redis.zcount(
            self.inbound_key(batch_id), timestamp - sample_time, timestamp)
        returnValue(int(count))

    @Manager.calls_manager
    def count_outbound_throughput(self, batch_id, sample_time=300):
        """
        Calculate the number of messages seen in the last `sample_time`
        seconds.

        :param int sample_time:
            How far to look back to calculate the throughput.
            Defaults to 300 seconds (5 minutes)
        """
        last_seen = yield self.redis.zrange(
            self.outbound_key(batch_id), 0, 0, desc=True, withscores=True)
        if not last_seen:
            returnValue(0)

        [(latest, timestamp)] = last_seen
        count = yield self.redis.zcount(
            self.outbound_key(batch_id), timestamp - sample_time, timestamp)
        returnValue(int(count))

    def get_query_token(self, direction, query):
        """
        Return a token for the query.

        The query is a list of dictionaries. To ensure consistent tokens,
        the input is always ordered the same way before creating a cache
        key.

        :param str direction:
            Namespace to store this query under.
            Generally 'inbound' or 'outbound'.
        :param list query:
            A list of dictionaries with query information.

        """
        ordered_query = sorted([sorted(part.items()) for part in query])
        # TODO: figure out if JSON is necessary here or if something like str()
        #       will work just as well.
        return '%s-%s' % (
            direction, hashlib.md5(json.dumps(ordered_query)).hexdigest())
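
    # For example (illustrative; the digest is abbreviated):
    #
    #   get_query_token('inbound', [{'key': 'msg.content', 'pattern': 'hi'}])
    #   # -> 'inbound-<md5 hexdigest of the sorted, JSON-encoded query>'
    #
    # Sorting the query parts first means logically-equal queries always
    # map to the same token.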

    @Manager.calls_manager
    def start_query(self, batch_id, direction, query):
        """
        Start a query operation on the messages for the given batch_id and
        direction. Returns a token with which the results of the query can
        be fetched as soon as they arrive.
        """
        token = self.get_query_token(direction, query)
        yield self.redis.sadd(self.search_token_key(batch_id), token)
        returnValue(token)

    @Manager.calls_manager
    def store_query_results(self, batch_id, token, keys, direction,
                            ttl=None):
        """
        Store the query results for a query that was started with
        `start_query`. Internally this grabs the timestamps from
        the cache (there is an assumption that it has already been reconciled)
        and orders the results accordingly.

        :param str token:
            The token to store the results under.
        :param list keys:
            The list of keys to store.
        :param str direction:
            Which messages to search, either inbound or outbound.
        :param int ttl:
            How long to store the results for.
            Defaults to DEFAULT_SEARCH_RESULT_TTL.
        """
        ttl = ttl or self.DEFAULT_SEARCH_RESULT_TTL
        result_key = self.search_result_key(batch_id, token)
        if direction == 'inbound':
            score_set_key = self.inbound_key(batch_id)
        elif direction == 'outbound':
            score_set_key = self.outbound_key(batch_id)
        else:
            raise MessageStoreCacheException('Invalid direction')

        # populate the results set weighted according to the timestamps
        # that are already known in the cache.
        for key in keys:
            timestamp = yield self.redis.zscore(score_set_key, key)
            yield self.redis.zadd(result_key, **{
                key.encode('utf-8'): timestamp,
            })

        # Auto expire after TTL
        yield self.redis.expire(result_key, ttl)
        # Remove from the list of in progress search operations.
        yield self.redis.srem(self.search_token_key(batch_id), token)

    def is_query_in_progress(self, batch_id, token):
        """
        Check whether a search is still in progress for the given token.
        """
        return self.redis.sismember(self.search_token_key(batch_id), token)

    def get_query_results(self, batch_id, token, start=0, stop=-1,
                          asc=False):
        """
        Return the results for the query token. Will return an empty list
        if no results are available.
        """
        result_key = self.search_result_key(batch_id, token)
        return self.redis.zrange(result_key, start, stop, desc=not asc)

    def count_query_results(self, batch_id, token):
        """
        Return the number of results for the query token.
        """
        result_key = self.search_result_key(batch_id, token)
        return self.redis.zcard(result_key)
PK=JGBSvv!vumi/components/window_manager.py# -*- test-case-name: vumi.components.tests.test_window_manager -*-
import json
import uuid

from twisted.internet import reactor
from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.internet.task import LoopingCall

from vumi import log


class WindowException(Exception):
    pass


class WindowManager(object):

    WINDOW_KEY = 'windows'
    FLIGHT_KEY = 'inflight'
    FLIGHT_STATS_KEY = 'flightstats'
    MAP_KEY = 'keymap'

    def __init__(self, redis, window_size=100, flight_lifetime=None,
                gc_interval=10):
        self.window_size = window_size
        self.flight_lifetime = flight_lifetime or (gc_interval * window_size)
        self.redis = redis
        self.clock = self.get_clock()
        self.gc = LoopingCall(self.clear_expired_flight_keys)
        self.gc.clock = self.clock
        self.gc.start(gc_interval)
        self._monitor = None

    def noop(self, *args, **kwargs):
        pass

    def stop(self):
        if self._monitor and self._monitor.running:
            self._monitor.stop()

        if self.gc.running:
            self.gc.stop()

    def get_windows(self):
        return self.redis.zrange(self.window_key(), 0, -1)

    @inlineCallbacks
    def window_exists(self, window_id):
        score = yield self.redis.zscore(self.window_key(), window_id)
        if score is not None:
            returnValue(True)
        returnValue(False)

    def window_key(self, *keys):
        return ':'.join([self.WINDOW_KEY] + map(unicode, keys))

    def flight_key(self, *keys):
        return self.window_key(self.FLIGHT_KEY, *keys)

    def stats_key(self, *keys):
        return self.window_key(self.FLIGHT_STATS_KEY, *keys)

    def map_key(self, *keys):
        return self.window_key(self.MAP_KEY, *keys)

    def get_clock(self):
        return reactor

    def get_clocktime(self):
        return self.clock.seconds()

    @inlineCallbacks
    def create_window(self, window_id, strict=False):
        if strict and (yield self.window_exists(window_id)):
            raise WindowException('Window already exists: %s' % (window_id,))
        clock_time = self.get_clocktime()
        yield self.redis.zadd(self.WINDOW_KEY, **{
            window_id: clock_time,
            })
        returnValue(clock_time)

    @inlineCallbacks
    def remove_window(self, window_id):
        waiting_list = yield self.count_waiting(window_id)
        if waiting_list:
            raise WindowException('Window not empty')
        yield self.redis.zrem(self.WINDOW_KEY, window_id)

    @inlineCallbacks
    def add(self, window_id, data, key=None):
        key = key or uuid.uuid4().get_hex()
        # The redis.set() has to complete before redis.lpush(),
        # otherwise the key can be popped from the window before the
        # data is available.
        yield self.redis.set(self.window_key(window_id, key),
                             json.dumps(data))
        yield self.redis.lpush(self.window_key(window_id), key)
        returnValue(key)

    @inlineCallbacks
    def get_next_key(self, window_id):

        window_key = self.window_key(window_id)
        inflight_key = self.flight_key(window_id)

        waiting_list = yield self.count_waiting(window_id)
        if waiting_list == 0:
            return

        flight_size = yield self.count_in_flight(window_id)
        room_available = self.window_size - flight_size

        if room_available > 0:
            log.debug('Window %s has space for %s' % (window_key,
                                                        room_available))
            next_key = yield self.redis.rpoplpush(window_key, inflight_key)
            if next_key:
                yield self._set_timestamp(window_id, next_key)
                returnValue(next_key)

    def _set_timestamp(self, window_id, flight_key):
        return self.redis.zadd(self.stats_key(window_id), **{
                flight_key: self.get_clocktime(),
        })

    def _clear_timestamp(self, window_id, flight_key):
        return self.redis.zrem(self.stats_key(window_id), flight_key)

    def count_waiting(self, window_id):
        window_key = self.window_key(window_id)
        return self.redis.llen(window_key)

    def count_in_flight(self, window_id):
        flight_key = self.flight_key(window_id)
        return self.redis.llen(flight_key)

    def get_expired_flight_keys(self, window_id):
        return self.redis.zrangebyscore(self.stats_key(window_id),
            '-inf', self.get_clocktime() - self.flight_lifetime)

    @inlineCallbacks
    def clear_expired_flight_keys(self):
        windows = yield self.get_windows()
        for window_id in windows:
            expired_keys = yield self.get_expired_flight_keys(window_id)
            for key in expired_keys:
                yield self.redis.lrem(self.flight_key(window_id), key, 1)

    @inlineCallbacks
    def get_data(self, window_id, key):
        json_data = yield self.redis.get(self.window_key(window_id, key))
        returnValue(json.loads(json_data))

    @inlineCallbacks
    def remove_key(self, window_id, key):
        yield self.redis.lrem(self.flight_key(window_id), key, 1)
        yield self.redis.delete(self.window_key(window_id, key))
        yield self.redis.delete(self.stats_key(window_id, key))
        yield self.clear_external_id(window_id, key)
        yield self._clear_timestamp(window_id, key)

    @inlineCallbacks
    def set_external_id(self, window_id, flight_key, external_id):
        yield self.redis.set(self.map_key(window_id, 'internal', external_id),
            flight_key)
        yield self.redis.set(self.map_key(window_id, 'external', flight_key),
            external_id)

    def get_internal_id(self, window_id, external_id):
        return self.redis.get(self.map_key(window_id, 'internal', external_id))

    def get_external_id(self, window_id, flight_key):
        return self.redis.get(self.map_key(window_id, 'external', flight_key))

    @inlineCallbacks
    def clear_external_id(self, window_id, flight_key):
        external_id = yield self.get_external_id(window_id, flight_key)
        if external_id:
            yield self.redis.delete(self.map_key(window_id, 'external',
                                                 flight_key))
            yield self.redis.delete(self.map_key(window_id, 'internal',
                                                 external_id))

    def monitor(self, key_callback, interval=10, cleanup=True,
                cleanup_callback=None):

        if self._monitor is not None:
            raise WindowException('Monitor already started')

        self._monitor = LoopingCall(lambda: self._monitor_windows(
            key_callback, cleanup, cleanup_callback))
        self._monitor.clock = self.get_clock()
        self._monitor.start(interval)

    @inlineCallbacks
    def _monitor_windows(self, key_callback, cleanup=True,
                         cleanup_callback=None):
        windows = yield self.get_windows()
        for window_id in windows:
            key = (yield self.get_next_key(window_id))
            while key:
                yield key_callback(window_id, key)
                key = (yield self.get_next_key(window_id))

            # Remove empty windows if required
            if cleanup and not ((yield self.count_waiting(window_id)) or
                                (yield self.count_in_flight(window_id))):
                if cleanup_callback:
                    cleanup_callback(window_id)
                yield self.remove_window(window_id)
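
# A usage sketch for WindowManager (illustrative, not part of the module;
# assumes a configured redis manager and an inlineCallbacks context):
#
#   wm = WindowManager(redis, window_size=10)
#   yield wm.create_window('window-1')
#   yield wm.add('window-1', {'content': 'hello'})
#
#   @inlineCallbacks
#   def deliver(window_id, key):
#       data = yield wm.get_data(window_id, key)
#       # ... send `data` somewhere, then acknowledge the flight:
#       yield wm.remove_key(window_id, key)
#
#   wm.monitor(deliver, interval=10)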
PKh^xGd vumi/components/message_store.py# -*- test-case-name: vumi.components.tests.test_message_store -*-
# -*- coding: utf-8 -*-

"""Message store."""

from calendar import timegm
from collections import defaultdict
from datetime import datetime
from uuid import uuid4
import itertools
import warnings

from twisted.internet.defer import inlineCallbacks, returnValue

from vumi.message import (
    TransportEvent, TransportUserMessage, parse_vumi_date, format_vumi_date)
from vumi.persist.model import Model, Manager
from vumi.persist.fields import (
    VumiMessage, ForeignKey, ManyToMany, ListOf, Tag, Dynamic, Unicode)
from vumi.persist.txriak_manager import TxRiakManager
from vumi import log
from vumi.components.message_store_cache import MessageStoreCache
from vumi.components.message_store_migrators import (
    EventMigrator, InboundMessageMigrator, OutboundMessageMigrator)


def to_reverse_timestamp(vumi_timestamp):
    """
    Turn a vumi_date-formatted string into a string that sorts in reverse order
    and can be turned back into a timestamp later.

    This is done by converting to a unix timestamp and subtracting it from
    0xffffffffff (2**40 - 1) to get a number well outside the range
    representable by the datetime module. The result is returned as a
    hexadecimal string.
    """
    timestamp = timegm(parse_vumi_date(vumi_timestamp).timetuple())
    return "%X" % (0xffffffffff - timestamp)


def from_reverse_timestamp(reverse_timestamp):
    """
    Turn a reverse timestamp string (from `to_reverse_timestamp()`) into a
    vumi_date-formatted string.
    """
    timestamp = 0xffffffffff - int(reverse_timestamp, 16)
    return format_vumi_date(datetime.utcfromtimestamp(timestamp))
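
# A worked example (illustrative; values assume UTC and vumi's standard
# date format):
#
#   to_reverse_timestamp('2015-01-01 00:00:00')
#   # timegm(...) == 1420070400, and
#   # 0xffffffffff - 1420070400 == 0xFFAB5B71FF, so this returns 'FFAB5B71FF'
#
#   from_reverse_timestamp('FFAB5B71FF')
#   # -> '2015-01-01 00:00:00.000000'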


class Batch(Model):
    # key is batch_id
    tags = ListOf(Tag())
    metadata = Dynamic(Unicode())


class CurrentTag(Model):
    # key is flattened tag
    current_batch = ForeignKey(Batch, null=True)
    tag = Tag()
    metadata = Dynamic(Unicode())

    @staticmethod
    def _flatten_tag(tag):
        return "%s:%s" % tag

    @staticmethod
    def _split_key(key):
        return tuple(key.split(':', 1))

    @classmethod
    def _tag_and_key(cls, tag_or_key):
        if isinstance(tag_or_key, tuple):
            # key looks like a tag
            tag, key = tag_or_key, cls._flatten_tag(tag_or_key)
        else:
            tag, key = cls._split_key(tag_or_key), tag_or_key
        return tag, key

    def __init__(self, manager, key, _riak_object=None, **kw):
        tag, key = self._tag_and_key(key)
        if _riak_object is None:
            kw['tag'] = tag
        super(CurrentTag, self).__init__(manager, key,
                                         _riak_object=_riak_object, **kw)

    @classmethod
    def load(cls, manager, key, result=None):
        _tag, key = cls._tag_and_key(key)
        return super(CurrentTag, cls).load(manager, key, result)


class OutboundMessage(Model):
    VERSION = 5
    MIGRATOR = OutboundMessageMigrator

    # key is message_id
    msg = VumiMessage(TransportUserMessage)
    batches = ManyToMany(Batch)

    # Extra fields for compound indexes
    batches_with_addresses = ListOf(Unicode(), index=True)
    batches_with_addresses_reverse = ListOf(Unicode(), index=True)

    def save(self):
        # We override this method to set our index fields before saving.
        self.batches_with_addresses = []
        self.batches_with_addresses_reverse = []
        timestamp = self.msg['timestamp']
        if not isinstance(timestamp, basestring):
            timestamp = format_vumi_date(timestamp)
        reverse_ts = to_reverse_timestamp(timestamp)
        for batch_id in self.batches.keys():
            self.batches_with_addresses.append(
                u"%s$%s$%s" % (batch_id, timestamp, self.msg['to_addr']))
            self.batches_with_addresses_reverse.append(
                u"%s$%s$%s" % (batch_id, reverse_ts, self.msg['to_addr']))
        return super(OutboundMessage, self).save()


class Event(Model):
    VERSION = 2
    MIGRATOR = EventMigrator

    # key is event_id
    event = VumiMessage(TransportEvent)
    message = ForeignKey(OutboundMessage)
    batches = ManyToMany(Batch)

    # Extra fields for compound indexes
    message_with_status = Unicode(index=True, null=True)
    batches_with_statuses_reverse = ListOf(Unicode(), index=True)

    def save(self):
        # We override this method to set our index fields before saving.
        timestamp = self.event['timestamp']
        if not isinstance(timestamp, basestring):
            timestamp = format_vumi_date(timestamp)
        status = self.event.status()
        self.message_with_status = u"%s$%s$%s" % (
            self.message.key, timestamp, status)
        self.batches_with_statuses_reverse = []
        reverse_ts = to_reverse_timestamp(timestamp)
        for batch_id in self.batches.keys():
            self.batches_with_statuses_reverse.append(
                u"%s$%s$%s" % (batch_id, reverse_ts, status))
        return super(Event, self).save()


class InboundMessage(Model):
    VERSION = 5
    MIGRATOR = InboundMessageMigrator

    # key is message_id
    msg = VumiMessage(TransportUserMessage)
    batches = ManyToMany(Batch)

    # Extra fields for compound indexes
    batches_with_addresses = ListOf(Unicode(), index=True)
    batches_with_addresses_reverse = ListOf(Unicode(), index=True)

    def save(self):
        # We override this method to set our index fields before saving.
        self.batches_with_addresses = []
        self.batches_with_addresses_reverse = []
        timestamp = self.msg['timestamp']
        if not isinstance(timestamp, basestring):
            timestamp = format_vumi_date(timestamp)
        reverse_ts = to_reverse_timestamp(timestamp)
        for batch_id in self.batches.keys():
            self.batches_with_addresses.append(
                u"%s$%s$%s" % (batch_id, timestamp, self.msg['from_addr']))
            self.batches_with_addresses_reverse.append(
                u"%s$%s$%s" % (batch_id, reverse_ts, self.msg['from_addr']))
        return super(InboundMessage, self).save()


class ReconKeyManager(object):
    """
    A helper for tracking keys during cache recon.

    Keys are added one at a time from oldest to newest, and a buffer of recent
    keys is kept to allow old and new keys to be handled differently.
    """

    def __init__(self, start_timestamp, key_count):
        self.start_timestamp = start_timestamp
        self.key_count = key_count
        self.cache_keys = []
        self.new_keys = []

    def add_key(self, key, timestamp):
        """
        Add a key and timestamp to the manager.

        If ``timestamp`` is newer than :attr:`start_timestamp`, the pair is
        added to :attr:`new_keys`, otherwise it is added to :attr:`cache_keys`.
        If this causes :attr:`cache_keys` to grow larger than
        :attr:`key_count`, the earliest entry is removed and returned. If not,
        ``None`` is returned.

        It is assumed that keys will be added from oldest to newest.
        """
        if timestamp > self.start_timestamp:
            self.new_keys.append((key, timestamp))
            return None
        self.cache_keys.append((key, timestamp))
        if len(self.cache_keys) > self.key_count:
            return self.cache_keys.pop(0)
        return None

    def __iter__(self):
        return itertools.chain(self.cache_keys, self.new_keys)
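
# A small usage sketch (illustrative, with made-up keys; vumi timestamps
# compare correctly as strings):
#
#   km = ReconKeyManager('2015-01-02 00:00:00', key_count=2)
#   km.add_key('k1', '2015-01-01 00:00:00')  # -> None (buffered)
#   km.add_key('k2', '2015-01-01 01:00:00')  # -> None (buffer now full)
#   km.add_key('k3', '2015-01-01 02:00:00')  # -> ('k1', ...), oldest evicted
#   km.add_key('k4', '2015-01-03 00:00:00')  # -> None (newer than start)
#   list(km)  # -> [('k2', ...), ('k3', ...), ('k4', ...)]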


class MessageStore(object):
    """Vumi message store.

    Message batches, inbound messages, outbound messages, events and
    information about which batch a tag is currently associated with are
    stored in Riak.

    A small amount of information about the state of a batch (i.e. number
    of messages in the batch, messages sent, acknowledgements and delivery
    reports received) is stored in Redis.
    """

    # The Python Riak client defaults to max_results=1000 in places.
    DEFAULT_MAX_RESULTS = 1000

    def __init__(self, manager, redis):
        self.manager = manager
        self.batches = manager.proxy(Batch)
        self.outbound_messages = manager.proxy(OutboundMessage)
        self.events = manager.proxy(Event)
        self.inbound_messages = manager.proxy(InboundMessage)
        self.current_tags = manager.proxy(CurrentTag)
        self.cache = MessageStoreCache(redis)

    @Manager.calls_manager
    def needs_reconciliation(self, batch_id, delta=0.01):
        """
        Check if a batch_id's cache values need to be reconciled with
        what's stored in the MessageStore.

        :param float delta:
            What an acceptable delta is for the cached values. Defaults to
            0.01. If the cached values are off by more than the delta, this
            returns True.
        """
        inbound = float((yield self.batch_inbound_count(batch_id)))
        cached_inbound = yield self.cache.count_inbound_message_keys(
            batch_id)

        if inbound and (abs(cached_inbound - inbound) / inbound) > delta:
            returnValue(True)

        outbound = float((yield self.batch_outbound_count(batch_id)))
        cached_outbound = yield self.cache.count_outbound_message_keys(
            batch_id)

        if outbound and (abs(cached_outbound - outbound) / outbound) > delta:
            returnValue(True)

        returnValue(False)
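
    # For example (illustrative): with 1000 inbound messages in Riak and a
    # cached count of 985, abs(985 - 1000) / 1000.0 == 0.015 > 0.01, so the
    # batch needs reconciliation at the default delta of 0.01.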

    @Manager.calls_manager
    def reconcile_cache(self, batch_id, start_timestamp=None):
        """
        Rebuild the cache for the given batch.

        The ``start_timestamp`` parameter is used for testing only.
        """
        if start_timestamp is None:
            start_timestamp = format_vumi_date(datetime.utcnow())
        yield self.cache.clear_batch(batch_id)
        yield self.cache.batch_start(batch_id)
        yield self.reconcile_outbound_cache(batch_id, start_timestamp)
        yield self.reconcile_inbound_cache(batch_id, start_timestamp)

    @Manager.calls_manager
    def reconcile_inbound_cache(self, batch_id, start_timestamp):
        """
        Rebuild the inbound message cache.
        """
        key_manager = ReconKeyManager(
            start_timestamp, self.cache.TRUNCATE_MESSAGE_KEY_COUNT_AT)
        key_count = 0

        index_page = yield self.batch_inbound_keys_with_addresses(batch_id)
        while index_page is not None:
            for key, timestamp, addr in index_page:
                yield self.cache.add_from_addr(batch_id, addr)
                old_key = key_manager.add_key(key, timestamp)
                if old_key is not None:
                    key_count += 1
            index_page = yield index_page.next_page()

        yield self.cache.add_inbound_message_count(batch_id, key_count)
        for key, timestamp in key_manager:
            try:
                yield self.cache.add_inbound_message_key(
                    batch_id, key, self.cache.get_timestamp(timestamp))
            except:
                log.err()

    @Manager.calls_manager
    def reconcile_outbound_cache(self, batch_id, start_timestamp):
        """
        Rebuild the outbound message cache.
        """
        key_manager = ReconKeyManager(
            start_timestamp, self.cache.TRUNCATE_MESSAGE_KEY_COUNT_AT)
        key_count = 0
        status_counts = defaultdict(int)

        index_page = yield self.batch_outbound_keys_with_addresses(batch_id)
        while index_page is not None:
            for key, timestamp, addr in index_page:
                yield self.cache.add_to_addr(batch_id, addr)
                old_key = key_manager.add_key(key, timestamp)
                if old_key is not None:
                    key_count += 1
                    sc = yield self.get_event_counts(old_key[0])
                    for status, count in sc.iteritems():
                        status_counts[status] += count
            index_page = yield index_page.next_page()

        yield self.cache.add_outbound_message_count(batch_id, key_count)
        for status, count in status_counts.iteritems():
            yield self.cache.add_event_count(batch_id, status, count)
        for key, timestamp in key_manager:
            try:
                yield self.cache.add_outbound_message_key(
                    batch_id, key, self.cache.get_timestamp(timestamp))
                yield self.reconcile_event_cache(batch_id, key)
            except:
                log.err()

    @Manager.calls_manager
    def get_event_counts(self, message_id):
        """
        Get event counts for a particular message.

        This is used for old messages that we want to bulk-update.
        """
        status_counts = defaultdict(int)

        index_page = yield self.message_event_keys_with_statuses(message_id)
        while index_page is not None:
            for key, _timestamp, status in index_page:
                status_counts[status] += 1
                if status.startswith("delivery_report."):
                    status_counts["delivery_report"] += 1
            index_page = yield index_page.next_page()

        returnValue(status_counts)

    @Manager.calls_manager
    def reconcile_event_cache(self, batch_id, message_id):
        """
        Update the event cache for a particular message.
        """
        event_keys = yield self.message_event_keys(message_id)
        for event_key in event_keys:
            event = yield self.get_event(event_key)
            yield self.cache.add_event(batch_id, event)

    @Manager.calls_manager
    def batch_start(self, tags=(), **metadata):
        batch_id = uuid4().get_hex()
        batch = self.batches(batch_id)
        batch.tags.extend(tags)
        for key, value in metadata.iteritems():
            batch.metadata[key] = value
        yield batch.save()

        for tag in tags:
            tag_record = yield self.current_tags.load(tag)
            if tag_record is None:
                tag_record = self.current_tags(tag)
            tag_record.current_batch.set(batch)
            yield tag_record.save()

        yield self.cache.batch_start(batch_id)
        returnValue(batch_id)

    @Manager.calls_manager
    def batch_done(self, batch_id):
        batch = yield self.batches.load(batch_id)
        tag_keys = yield batch.backlinks.currenttags()
        for tags_bunch in self.manager.load_all_bunches(CurrentTag, tag_keys):
            for tag in (yield tags_bunch):
                tag.current_batch.set(None)
                yield tag.save()

    @Manager.calls_manager
    def add_outbound_message(self, msg, tag=None, batch_id=None, batch_ids=()):
        msg_id = msg['message_id']
        msg_record = yield self.outbound_messages.load(msg_id)
        if msg_record is None:
            msg_record = self.outbound_messages(msg_id, msg=msg)
        else:
            msg_record.msg = msg

        if batch_id is None and tag is not None:
            tag_record = yield self.current_tags.load(tag)
            if tag_record is not None:
                batch_id = tag_record.current_batch.key

        batch_ids = list(batch_ids)
        if batch_id is not None:
            batch_ids.append(batch_id)

        for batch_id in batch_ids:
            msg_record.batches.add_key(batch_id)
            yield self.cache.add_outbound_message(batch_id, msg)

        yield msg_record.save()

    @Manager.calls_manager
    def get_outbound_message(self, msg_id):
        msg = yield self.outbound_messages.load(msg_id)
        returnValue(msg.msg if msg is not None else None)

    @Manager.calls_manager
    def _get_batches_from_outbound(self, msg_id):
        msg_record = yield self.outbound_messages.load(msg_id)
        if msg_record is not None:
            batch_ids = msg_record.batches.keys()
        else:
            batch_ids = []
        returnValue(batch_ids)

    @Manager.calls_manager
    def add_event(self, event, batch_ids=None):
        event_id = event['event_id']
        msg_id = event['user_message_id']
        event_record = yield self.events.load(event_id)
        if event_record is None:
            event_record = self.events(event_id, event=event, message=msg_id)
            if batch_ids is None:
                # If we aren't given batch_ids, get them from the outbound
                # message.
                batch_ids = yield self._get_batches_from_outbound(msg_id)
        else:
            event_record.event = event

        if batch_ids is not None:
            for batch_id in batch_ids:
                event_record.batches.add_key(batch_id)
                yield self.cache.add_event(batch_id, event)

        yield event_record.save()

    @Manager.calls_manager
    def get_event(self, event_id):
        event = yield self.events.load(event_id)
        returnValue(event.event if event is not None else None)

    @Manager.calls_manager
    def get_events_for_message(self, message_id):
        events = []
        event_keys = yield self.message_event_keys(message_id)
        for event_id in event_keys:
            event = yield self.get_event(event_id)
            events.append(event)
        returnValue(events)

    @Manager.calls_manager
    def add_inbound_message(self, msg, tag=None, batch_id=None, batch_ids=()):
        msg_id = msg['message_id']
        msg_record = yield self.inbound_messages.load(msg_id)
        if msg_record is None:
            msg_record = self.inbound_messages(msg_id, msg=msg)
        else:
            msg_record.msg = msg

        if batch_id is None and tag is not None:
            tag_record = yield self.current_tags.load(tag)
            if tag_record is not None:
                batch_id = tag_record.current_batch.key

        batch_ids = list(batch_ids)
        if batch_id is not None:
            batch_ids.append(batch_id)

        for batch_id in batch_ids:
            msg_record.batches.add_key(batch_id)
            yield self.cache.add_inbound_message(batch_id, msg)

        yield msg_record.save()

    @Manager.calls_manager
    def get_inbound_message(self, msg_id):
        msg = yield self.inbound_messages.load(msg_id)
        returnValue(msg.msg if msg is not None else None)

    def get_batch(self, batch_id):
        return self.batches.load(batch_id)

    @Manager.calls_manager
    def get_tag_info(self, tag):
        tagmdl = yield self.current_tags.load(tag)
        if tagmdl is None:
            tagmdl = yield self.current_tags(tag)
        returnValue(tagmdl)

    def batch_status(self, batch_id):
        return self.cache.get_event_status(batch_id)

    def batch_outbound_keys(self, batch_id):
        return self.outbound_messages.index_keys('batches', batch_id)

    def batch_outbound_keys_page(self, batch_id, max_results=None,
                                 continuation=None):
        if max_results is None:
            max_results = self.DEFAULT_MAX_RESULTS
        return self.outbound_messages.index_keys_page(
            'batches', batch_id, max_results=max_results,
            continuation=continuation)

    def batch_outbound_keys_matching(self, batch_id, query):
        mr = self.outbound_messages.index_match(query, 'batches', batch_id)
        return mr.get_keys()

    def batch_inbound_keys(self, batch_id):
        return self.inbound_messages.index_keys('batches', batch_id)

    def batch_inbound_keys_page(self, batch_id, max_results=None,
                                continuation=None):
        if max_results is None:
            max_results = self.DEFAULT_MAX_RESULTS
        return self.inbound_messages.index_keys_page(
            'batches', batch_id, max_results=max_results,
            continuation=continuation)

    def batch_inbound_keys_matching(self, batch_id, query):
        mr = self.inbound_messages.index_match(query, 'batches', batch_id)
        return mr.get_keys()

    def batch_event_keys_page(self, batch_id, max_results=None,
                              continuation=None):
        if max_results is None:
            max_results = self.DEFAULT_MAX_RESULTS
        return self.events.index_keys_page(
            'batches', batch_id, max_results=max_results,
            continuation=continuation)

    def message_event_keys(self, msg_id):
        return self.events.index_keys('message', msg_id)

    @Manager.calls_manager
    def batch_inbound_count(self, batch_id):
        keys = yield self.batch_inbound_keys(batch_id)
        returnValue(len(keys))

    @Manager.calls_manager
    def batch_outbound_count(self, batch_id):
        keys = yield self.batch_outbound_keys(batch_id)
        returnValue(len(keys))

    @Manager.calls_manager
    def find_inbound_keys_matching(self, batch_id, query, ttl=None,
                                   wait=False):
        """
        Issue a `batch_inbound_keys_matching()` query and store the
        resulting keys in the cache, ordered by descending timestamp.

        :param str batch_id:
            The batch to search across.
        :param list query:
            A list of dictionaries with query information.
        :param int ttl:
            How long to store the results for.
        :param bool wait:
            Only return the token after the matching, storing & ordering
            of keys has completed. Useful for testing.

        Returns a token with which the results can be fetched.

        NOTE:   This function can only be called from inside Twisted as
                it assumes that the result of `batch_inbound_keys_matching`
                is a Deferred.
        """
        assert isinstance(self.manager, TxRiakManager), (
            "manager is not an instance of TxRiakManager")
        token = yield self.cache.start_query(batch_id, 'inbound', query)
        deferred = self.batch_inbound_keys_matching(batch_id, query)
        deferred.addCallback(
            lambda keys: self.cache.store_query_results(batch_id, token, keys,
                                                        'inbound', ttl))
        if wait:
            yield deferred
        returnValue(token)

    @Manager.calls_manager
    def find_outbound_keys_matching(self, batch_id, query, ttl=None,
                                    wait=False):
        """
        Issue a `batch_outbound_keys_matching()` query and store the
        resulting keys in the cache, ordered by descending timestamp.

        :param str batch_id:
            The batch to search across.
        :param list query:
            A list of dictionaries with query information.
        :param int ttl:
            How long to store the results for.
        :param bool wait:
            Only return the token after the matching, storing & ordering
            of keys has completed. Useful for testing.

        Returns a token with which the results can be fetched.

        NOTE:   This function can only be called from inside Twisted as
                it depends on Deferreds being fired that aren't returned
                by the function itself.
        """
        token = yield self.cache.start_query(batch_id, 'outbound', query)
        deferred = self.batch_outbound_keys_matching(batch_id, query)
        deferred.addCallback(
            lambda keys: self.cache.store_query_results(batch_id, token, keys,
                                                        'outbound', ttl))
        if wait:
            yield deferred
        returnValue(token)

    def get_keys_for_token(self, batch_id, token, start=0, stop=-1, asc=False):
        """
        Returns the resulting keys of a search.

        :param str token:
            The token returned by `find_inbound_keys_matching()` or
            `find_outbound_keys_matching()`.
        """
        return self.cache.get_query_results(batch_id, token, start, stop, asc)

    def count_keys_for_token(self, batch_id, token):
        """
        Count the number of keys in the token's result set.
        """
        return self.cache.count_query_results(batch_id, token)

    def is_query_in_progress(self, batch_id, token):
        """
        Return True or False depending on whether or not the query is
        still running
        """
        return self.cache.is_query_in_progress(batch_id, token)
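
    # Illustrative usage of the token-based search API above (a sketch, not
    # part of the original class); it assumes a TxRiakManager-backed store
    # running under the Twisted reactor:
    #
    #     token = yield store.find_inbound_keys_matching(
    #         batch_id, query, wait=True)
    #     count = yield store.count_keys_for_token(batch_id, token)
    #     keys = yield store.get_keys_for_token(batch_id, token, start=0,
    #                                           stop=24)
    #     # at most 25 keys, ordered newest first (asc=False)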

    def get_inbound_message_keys(self, batch_id, start=0, stop=-1,
                                 with_timestamp=False):
        warnings.warn("get_inbound_message_keys() is deprecated. Use "
                      "get_cached_inbound_message_keys().",
                      category=DeprecationWarning)
        return self.get_cached_inbound_message_keys(batch_id, start, stop,
                                                    with_timestamp)

    def get_cached_inbound_message_keys(self, batch_id, start=0, stop=-1,
                                        with_timestamp=False):
        """
        Return the keys ordered by descending timestamp.

        :param str batch_id:
            The batch_id to fetch keys for
        :param int start:
            Where to start from, defaults to 0 which is the first key.
        :param int stop:
            The index to stop at (inclusive), defaults to -1 which is the
            last key.
        :param bool with_timestamp:
            Whether or not to return a list of (key, timestamp) tuples
            instead of only the list of keys.
        """
        return self.cache.get_inbound_message_keys(
            batch_id, start, stop, with_timestamp=with_timestamp)

    def get_outbound_message_keys(self, batch_id, start=0, stop=-1,
                                  with_timestamp=False):
        warnings.warn("get_outbound_message_keys() is deprecated. Use "
                      "get_cached_outbound_message_keys().",
                      category=DeprecationWarning)
        return self.get_cached_outbound_message_keys(batch_id, start, stop,
                                                     with_timestamp)

    def get_cached_outbound_message_keys(self, batch_id, start=0, stop=-1,
                                         with_timestamp=False):
        """
        Return the keys ordered by descending timestamp.

        :param str batch_id:
            The batch_id to fetch keys for
        :param int start:
            Where to start from, defaults to 0 which is the first key.
        :param int stop:
            The index to stop at (inclusive), defaults to -1 which is the
            last key.
        :param bool with_timestamp:
            Whether or not to return a list of (key, timestamp) tuples
            instead of only the list of keys.
        """
        return self.cache.get_outbound_message_keys(
            batch_id, start, stop, with_timestamp=with_timestamp)

    def _start_end_values(self, batch_id, start, end):
        if start is not None:
            start_value = "%s$%s" % (batch_id, start)
        else:
            start_value = "%s%s" % (batch_id, "#")  # chr(ord('$') - 1)
        if end is not None:
            # We append the "%" to this because we may have another field
            # after the timestamp and we want to include that in the range.
            end_value = "%s$%s%s" % (batch_id, end, "%")  # chr(ord('$') + 1)
        else:
            end_value = "%s%s" % (batch_id, "%")  # chr(ord('$') + 1)
        return start_value, end_value
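
    # Illustrative values for the sentinels above (a sketch): for batch "b1",
    #   _start_end_values("b1", None, None)
    #       == ("b1#", "b1%")
    #   _start_end_values("b1", "2015-01-01 00:00:00", None)
    #       == ("b1$2015-01-01 00:00:00", "b1%")
    # Since "#" < "$" < "%" in ASCII, the pair brackets every index value of
    # the form "<batch_id>$<timestamp>..." in an index range query.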

    @Manager.calls_manager
    def _query_batch_index(self, model_proxy, batch_id, index, max_results,
                           start, end, formatter):
        if max_results is None:
            max_results = self.DEFAULT_MAX_RESULTS
        start_value, end_value = self._start_end_values(batch_id, start, end)
        results = yield model_proxy.index_keys_page(
            index, start_value, end_value, max_results=max_results,
            return_terms=(formatter is not None))
        if formatter is not None:
            results = IndexPageWrapper(formatter, self, batch_id, results)
        returnValue(results)

    def batch_inbound_keys_with_timestamps(self, batch_id, max_results=None,
                                           start=None, end=None,
                                           with_timestamps=True):
        """
        Return all inbound message keys with (and ordered by) timestamps.

        :param str batch_id:
            The batch_id to fetch keys for.

        :param int max_results:
            Number of results per page. Defaults to DEFAULT_MAX_RESULTS

        :param str start:
            Optional start timestamp string matching VUMI_DATE_FORMAT.

        :param str end:
            Optional end timestamp string matching VUMI_DATE_FORMAT.

        :param bool with_timestamps:
            If set to ``False``, only the keys will be returned. The results
            will still be ordered by timestamp, however.

        This method performs a Riak index query.
        """
        formatter = key_with_ts_only_formatter if with_timestamps else None
        return self._query_batch_index(
            self.inbound_messages, batch_id, 'batches_with_addresses',
            max_results, start, end, formatter)

    def batch_outbound_keys_with_timestamps(self, batch_id, max_results=None,
                                            start=None, end=None,
                                            with_timestamps=True):
        """
        Return all outbound message keys with (and ordered by) timestamps.

        :param str batch_id:
            The batch_id to fetch keys for.

        :param int max_results:
            Number of results per page. Defaults to DEFAULT_MAX_RESULTS

        :param str start:
            Optional start timestamp string matching VUMI_DATE_FORMAT.

        :param str end:
            Optional end timestamp string matching VUMI_DATE_FORMAT.

        :param bool with_timestamps:
            If set to ``False``, only the keys will be returned. The results
            will still be ordered by timestamp, however.

        This method performs a Riak index query.
        """
        formatter = key_with_ts_only_formatter if with_timestamps else None
        return self._query_batch_index(
            self.outbound_messages, batch_id, 'batches_with_addresses',
            max_results, start, end, formatter)

    def batch_inbound_keys_with_addresses(self, batch_id, max_results=None,
                                          start=None, end=None):
        """
        Return all inbound message keys with (and ordered by) timestamps and
        addresses.

        :param str batch_id:
            The batch_id to fetch keys for.

        :param int max_results:
            Number of results per page. Defaults to DEFAULT_MAX_RESULTS

        :param str start:
            Optional start timestamp string matching VUMI_DATE_FORMAT.

        :param str end:
            Optional end timestamp string matching VUMI_DATE_FORMAT.

        This method performs a Riak index query.
        """
        return self._query_batch_index(
            self.inbound_messages, batch_id, 'batches_with_addresses',
            max_results, start, end, key_with_ts_and_value_formatter)

    def batch_outbound_keys_with_addresses(self, batch_id, max_results=None,
                                           start=None, end=None):
        """
        Return all outbound message keys with (and ordered by) timestamps and
        addresses.

        :param str batch_id:
            The batch_id to fetch keys for.

        :param int max_results:
            Number of results per page. Defaults to DEFAULT_MAX_RESULTS

        :param str start:
            Optional start timestamp string matching VUMI_DATE_FORMAT.

        :param str end:
            Optional end timestamp string matching VUMI_DATE_FORMAT.

        This method performs a Riak index query.
        """
        return self._query_batch_index(
            self.outbound_messages, batch_id, 'batches_with_addresses',
            max_results, start, end, key_with_ts_and_value_formatter)

    def batch_inbound_keys_with_addresses_reverse(self, batch_id,
                                                  max_results=None,
                                                  start=None, end=None):
        """
        Return all inbound message keys with timestamps and addresses.
        Results are ordered from newest to oldest.

        :param str batch_id:
            The batch_id to fetch keys for.

        :param int max_results:
            Number of results per page. Defaults to DEFAULT_MAX_RESULTS

        :param str start:
            Optional start timestamp string matching VUMI_DATE_FORMAT.

        :param str end:
            Optional end timestamp string matching VUMI_DATE_FORMAT.

        This method performs a Riak index query.
        """
        # We're using reverse timestamps, so swap start and end and convert to
        # reverse timestamps.
        if start is not None:
            start = to_reverse_timestamp(start)
        if end is not None:
            end = to_reverse_timestamp(end)
        start, end = end, start
        return self._query_batch_index(
            self.inbound_messages, batch_id, 'batches_with_addresses_reverse',
            max_results, start, end, key_with_rts_and_value_formatter)

    def batch_outbound_keys_with_addresses_reverse(self, batch_id,
                                                   max_results=None,
                                                   start=None, end=None):
        """
        Return all outbound message keys with timestamps and addresses.
        Results are ordered from newest to oldest.

        :param str batch_id:
            The batch_id to fetch keys for.

        :param int max_results:
            Number of results per page. Defaults to DEFAULT_MAX_RESULTS

        :param str start:
            Optional start timestamp string matching VUMI_DATE_FORMAT.

        :param str end:
            Optional end timestamp string matching VUMI_DATE_FORMAT.

        This method performs a Riak index query.
        """
        # We're using reverse timestamps, so swap start and end and convert to
        # reverse timestamps.
        if start is not None:
            start = to_reverse_timestamp(start)
        if end is not None:
            end = to_reverse_timestamp(end)
        start, end = end, start
        return self._query_batch_index(
            self.outbound_messages, batch_id, 'batches_with_addresses_reverse',
            max_results, start, end, key_with_rts_and_value_formatter)

    def batch_event_keys_with_statuses_reverse(self, batch_id,
                                               max_results=None,
                                               start=None, end=None):
        """
        Return all event keys with timestamps and statuses.
        Results are ordered from newest to oldest.

        :param str batch_id:
            The batch_id to fetch keys for.

        :param int max_results:
            Number of results per page. Defaults to DEFAULT_MAX_RESULTS

        :param str start:
            Optional start timestamp string matching VUMI_DATE_FORMAT.

        :param str end:
            Optional end timestamp string matching VUMI_DATE_FORMAT.

        This method performs a Riak index query.
        """
        # We're using reverse timestamps, so swap start and end and convert to
        # reverse timestamps.
        if start is not None:
            start = to_reverse_timestamp(start)
        if end is not None:
            end = to_reverse_timestamp(end)
        start, end = end, start
        return self._query_batch_index(
            self.events, batch_id, 'batches_with_statuses_reverse',
            max_results, start, end, key_with_rts_and_value_formatter)

    @Manager.calls_manager
    def message_event_keys_with_statuses(self, msg_id, max_results=None):
        """
        Return all event keys with (and ordered by) timestamps and statuses.

        :param str msg_id:
            The message_id to fetch event keys for.

        :param int max_results:
            Number of results per page. Defaults to DEFAULT_MAX_RESULTS

        This method performs a Riak index query. Unlike similar message key
        methods, start and end values are not supported as the number of events
        per message is expected to be small.
        """
        if max_results is None:
            max_results = self.DEFAULT_MAX_RESULTS
        start_value, end_value = self._start_end_values(msg_id, None, None)
        results = yield self.events.index_keys_page(
            'message_with_status', start_value, end_value,
            return_terms=True, max_results=max_results)
        returnValue(IndexPageWrapper(
            key_with_ts_and_value_formatter, self, msg_id, results))

    @Manager.calls_manager
    def batch_inbound_stats(self, batch_id, max_results=None,
                            start=None, end=None):
        """
        Return inbound message stats for the specified time range.

        Currently, message stats include total message count and unique address
        count.

        :param str batch_id:
            The batch_id to fetch keys for.

        :param int max_results:
            Number of results per page. Defaults to DEFAULT_MAX_RESULTS.

        :param str start:
            Optional start timestamp string matching VUMI_DATE_FORMAT.

        :param str end:
            Optional end timestamp string matching VUMI_DATE_FORMAT.

        :returns:
            ``dict`` containing 'total' and 'unique_addresses' entries.

        This method performs multiple Riak index queries.
        """
        total = 0
        unique_addresses = set()

        start_value, end_value = self._start_end_values(batch_id, start, end)
        if max_results is None:
            max_results = self.DEFAULT_MAX_RESULTS
        raw_page = yield self.inbound_messages.index_keys_page(
            'batches_with_addresses', start_value, end_value,
            return_terms=True, max_results=max_results)
        page = IndexPageWrapper(
            key_with_ts_and_value_formatter, self, batch_id, raw_page)

        while page is not None:
            results = list(page)
            total += len(results)
            unique_addresses.update(addr for key, timestamp, addr in results)
            page = yield page.next_page()

        returnValue({
            "total": total,
            "unique_addresses": len(unique_addresses),
        })

    @Manager.calls_manager
    def batch_outbound_stats(self, batch_id, max_results=None,
                             start=None, end=None):
        """
        Return outbound message stats for the specified time range.

        Currently, message stats include total message count and unique address
        count.

        :param str batch_id:
            The batch_id to fetch keys for.

        :param int max_results:
            Number of results per page. Defaults to DEFAULT_MAX_RESULTS.

        :param str start:
            Optional start timestamp string matching VUMI_DATE_FORMAT.

        :param str end:
            Optional end timestamp string matching VUMI_DATE_FORMAT.

        :returns:
            ``dict`` containing 'total' and 'unique_addresses' entries.

        This method performs multiple Riak index queries.
        """
        total = 0
        unique_addresses = set()

        start_value, end_value = self._start_end_values(batch_id, start, end)
        if max_results is None:
            max_results = self.DEFAULT_MAX_RESULTS
        raw_page = yield self.outbound_messages.index_keys_page(
            'batches_with_addresses', start_value, end_value,
            return_terms=True, max_results=max_results)
        page = IndexPageWrapper(
            key_with_ts_and_value_formatter, self, batch_id, raw_page)

        while page is not None:
            results = list(page)
            total += len(results)
            unique_addresses.update(addr for key, timestamp, addr in results)
            page = yield page.next_page()

        returnValue({
            "total": total,
            "unique_addresses": len(unique_addresses),
        })
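
    # Illustrative usage of the stats methods above (a sketch; `store` is a
    # MessageStore instance and timestamps follow VUMI_DATE_FORMAT):
    #
    #     stats = yield store.batch_inbound_stats(
    #         "batch-1", start="2015-06-01 00:00:00",
    #         end="2015-06-02 00:00:00")
    #     # => {"total": 42, "unique_addresses": 17}  (example values)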


class IndexPageWrapper(object):
    """
    Index page wrapper that reformats index values into something easier to
    work with.

    This is a wrapper around the lower-level index page object from Riak and
    proxies a subset of its functionality.
    """
    def __init__(self, formatter, message_store, batch_id, index_page):
        self._formatter = formatter
        self._message_store = message_store
        self.manager = message_store.manager
        self._batch_id = batch_id
        self._index_page = index_page

    def _wrap_index_page(self, index_page):
        """
        Wrap a raw index page object if it is not None.
        """
        if index_page is not None:
            index_page = type(self)(
                self._formatter, self._message_store, self._batch_id,
                index_page)
        return index_page

    @Manager.calls_manager
    def next_page(self):
        """
        Fetch the next page of results.

        :returns:
            A new :class:`IndexPageWrapper` object containing the next page
            of results.
        """
        next_page = yield self._index_page.next_page()
        returnValue(self._wrap_index_page(next_page))

    def has_next_page(self):
        """
        Indicate whether there are more results to follow.

        :returns:
            ``True`` if there are more results, ``False`` if this is the last
            page.
        """
        return self._index_page.has_next_page()

    def __iter__(self):
        return (self._formatter(self._batch_id, r) for r in self._index_page)
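

# A minimal sketch (not part of the original module) of draining an
# IndexPageWrapper, e.g. one returned by batch_inbound_keys_with_addresses().
# It assumes a Deferred-returning (TxRiakManager-backed) store running under
# the Twisted reactor.
@inlineCallbacks
def collect_index_results(page):
    """Walk a chain of index pages and collect all formatted results."""
    results = []
    while page is not None:
        results.extend(page)  # iterating a page yields formatted tuples
        page = yield page.next_page()
    returnValue(results)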


def key_with_ts_and_value_formatter(batch_id, result):
    value, key = result
    prefix = batch_id + "$"
    if not value.startswith(prefix):
        raise ValueError(
            "Index value %r does not begin with expected prefix %r." % (
                value, prefix))
    suffix = value[len(prefix):]
    timestamp, delimiter, address = suffix.partition("$")
    if delimiter != "$":
        raise ValueError(
            "Index value %r does not match expected format." % (value,))
    return (key, timestamp, address)


def key_with_rts_and_value_formatter(batch_id, result):
    key, reverse_ts, value = key_with_ts_and_value_formatter(batch_id, result)
    return (key, from_reverse_timestamp(reverse_ts), value)


def key_with_ts_only_formatter(batch_id, result):
    key, timestamp, value = key_with_ts_and_value_formatter(batch_id, result)
    return (key, timestamp)
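
# Illustrative formatter behaviour (a sketch; index terms are (value, key)
# pairs, as returned by Riak index queries with return_terms=True):
#
#     key_with_ts_and_value_formatter(
#         "b1", ("b1$2015-06-01 10:00:00$+27831234567", "msg-key"))
#     # => ("msg-key", "2015-06-01 10:00:00", "+27831234567")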


@inlineCallbacks
def add_batches_to_event(stored_event):
    """
    Post-migrate function to be used with `vumi_model_migrator` to add batches
    to stored events that don't have any.
    """
    if stored_event.batches.keys():
        # We already have batches, so there's no need to look them up.
        returnValue(False)

    outbound_messages = stored_event.manager.proxy(OutboundMessage)
    msg_record = yield outbound_messages.load(stored_event.message.key)
    if msg_record is not None:
        for batch_id in msg_record.batches.keys():
            stored_event.batches.add_key(batch_id)

    returnValue(True)

vumi/components/schedule_manager.py
# -*- test-case-name: vumi.components.tests.test_schedule_manager -*-

from datetime import datetime, timedelta

from vumi import log


class ScheduleManager(object):
    """Utility for determining whether a scheduled event is due.

    :class:`ScheduleManager` basically answers the question "are we there yet?"
    given a schedule definition, the last time the question was asked and the
    current time. It is designed to be used as part of a larger system that
    periodically checks for scheduled events.

    The schedule definition is a `dict` containing a mandatory `recurring`
    field which specifies the type of recurring schedule and other fields
    depending on the value of the `recurring` field.

    Currently, the following are supported:

     * `daily`
       The `time` field is required and specifies the (approximate) time of day
       the event is scheduled for in "HH:MM:SS" format.

     * `day_of_month`
       The `time` field is required and specifies the (approximate) time of day
       the event is scheduled for in "HH:MM:SS" format.
       The `days` field is required and specifies the days of the month the
       event is scheduled for as a list of comma/whitespace-separated integers.

     * `day_of_week`
       The `time` field is required and specifies the (approximate) time of day
       the event is scheduled for in "HH:MM:SS" format.
       The `days` field is required and specifies the days of the week the
       event is scheduled for as a list of comma/whitespace-separated integers,
       1 for Monday through 7 for Sunday.

     * `never`
       No extra fields are required and the event is never scheduled.
    """

    def __init__(self, schedule_definition):
        self.schedule_definition = schedule_definition

    def is_scheduled(self, then, now):
        now_dt = datetime.utcfromtimestamp(now)
        then_dt = datetime.utcfromtimestamp(then)

        next_dt = self.get_next(then_dt)

        if next_dt is None:
            # We have an invalid schedule definition or nothing scheduled.
            return False

        return (next_dt <= now_dt)

    def get_next(self, since_dt):
        try:
            recurring_type = self.schedule_definition['recurring']
            if recurring_type == 'daily':
                return self.get_next_daily(since_dt)
            elif recurring_type == 'day_of_month':
                return self.get_next_day_of_month(since_dt)
            elif recurring_type == 'day_of_week':
                return self.get_next_day_of_week(since_dt)
            elif recurring_type == 'never':
                return None
            else:
                raise ValueError(
                    "Invalid value for 'recurring': %r" % (recurring_type,))
        except Exception:
            log.error(None, "Error processing schedule.")

    def get_next_daily(self, since_dt):
        timeofday = datetime.strptime(
            self.schedule_definition['time'], '%H:%M:%S').time()

        next_dt = datetime.combine(since_dt.date(), timeofday)
        while next_dt <= since_dt:
            next_dt += timedelta(days=1)

        return next_dt

    def _parse_days(self, minval, maxval):
        dstr = self.schedule_definition.get('days')
        try:
            days = set([int(day) for day in dstr.replace(',', ' ').split()])
            for day in days:
                assert minval <= day <= maxval
            return days
        except Exception:
            raise ValueError("Invalid value for 'days': %r" % (dstr,))

    def get_next_day_of_month(self, since_dt):
        timeofday = datetime.strptime(
            self.schedule_definition['time'], '%H:%M:%S').time()
        days_of_month = self._parse_days(1, 31)

        next_dt = datetime.combine(since_dt.date(), timeofday)
        while (next_dt.day not in days_of_month) or (next_dt <= since_dt):
            next_dt += timedelta(days=1)

        return next_dt

    def get_next_day_of_week(self, since_dt):
        timeofday = datetime.strptime(
            self.schedule_definition['time'], '%H:%M:%S').time()
        days_of_week = self._parse_days(1, 7)

        next_dt = datetime.combine(since_dt.date(), timeofday)
        while ((next_dt.isoweekday() not in days_of_week)
               or (next_dt <= since_dt)):
            next_dt += timedelta(days=1)

        return next_dt
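

# A minimal usage sketch (not part of the original module): check whether a
# daily 08:00 schedule has come due between two POSIX timestamps.
if __name__ == '__main__':
    import time
    sm = ScheduleManager({'recurring': 'daily', 'time': '08:00:00'})
    now = time.time()
    then = now - 3600  # last checked an hour ago
    print "due:", sm.is_scheduled(then, now)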

vumi/components/__init__.py
"""Various useful components."""

vumi/components/tagpool.py
# -*- test-case-name: vumi.components.tests.test_tagpool -*-
# -*- coding: utf-8 -*-

"""Tag pool manager."""

import json
import time

from twisted.internet.defer import returnValue

from vumi.errors import VumiError
from vumi.persist.redis_base import Manager


class TagpoolError(VumiError):
    """An error occurred during an operation on a tag pool."""


class TagpoolManager(object):
    """Manage a set of tag pools.

    :param redis:
        An instance of :class:`vumi.persist.redis_base.Manager`.
    """

    encoding = "UTF-8"

    def __init__(self, redis):
        self.redis = redis
        self.manager = redis  # TODO: This is a bit of a hack to make the
                              #       calls_manager decorator work.

    def _encode(self, unicode_text):
        return unicode_text.encode(self.encoding)

    def _decode(self, binary_data):
        return binary_data.decode(self.encoding)

    @Manager.calls_manager
    def acquire_tag(self, pool, owner=None, reason=None):
        local_tag = yield self._acquire_tag(pool, owner, reason)
        returnValue((pool, local_tag) if local_tag is not None else None)

    @Manager.calls_manager
    def acquire_specific_tag(self, tag, owner=None, reason=None):
        pool, local_tag = tag
        acquired = yield self._acquire_specific_tag(pool, local_tag,
                                                    owner, reason)
        if acquired:
            returnValue(tag)
        returnValue(None)

    @Manager.calls_manager
    def release_tag(self, tag):
        pool, local_tag = tag
        yield self._release_tag(pool, local_tag)

    @Manager.calls_manager
    def declare_tags(self, tags):
        pools = {}
        for pool, local_tag in tags:
            pools.setdefault(pool, []).append(local_tag)
        for pool, local_tags in pools.items():
            yield self._register_pool(pool)
            yield self._declare_tags(pool, local_tags)

    @Manager.calls_manager
    def get_metadata(self, pool):
        metadata_key = self._tag_pool_metadata_key(pool)
        metadata = yield self.redis.hgetall(metadata_key)
        metadata = dict((self._decode(k), json.loads(v))
                        for k, v in metadata.iteritems())
        returnValue(metadata)

    @Manager.calls_manager
    def set_metadata(self, pool, metadata):
        metadata_key = self._tag_pool_metadata_key(pool)
        metadata = dict((self._encode(k), json.dumps(v))
                        for k, v in metadata.iteritems())
        yield self._register_pool(pool)
        yield self.redis.delete(metadata_key)
        yield self.redis.hmset(metadata_key, metadata)

    @Manager.calls_manager
    def purge_pool(self, pool):
        free_list_key, free_set_key, inuse_set_key = self._tag_pool_keys(pool)
        metadata_key = self._tag_pool_metadata_key(pool)
        in_use_count = yield self.redis.scard(inuse_set_key)
        if in_use_count:
            raise TagpoolError('%s tags of pool %s still in use.' % (
                               in_use_count, pool))
        else:
            yield self.redis.delete(free_set_key)
            yield self.redis.delete(free_list_key)
            yield self.redis.delete(inuse_set_key)
            yield self.redis.delete(metadata_key)
            yield self._unregister_pool(pool)

    @Manager.calls_manager
    def list_pools(self):
        pool_list_key = self._pool_list_key()
        pools = yield self.redis.smembers(pool_list_key)
        returnValue(set(self._decode(pool) for pool in pools))

    @Manager.calls_manager
    def free_tags(self, pool):
        _free_list, free_set_key, _inuse_set = self._tag_pool_keys(pool)
        free_tags = yield self.redis.smembers(free_set_key)
        returnValue([(pool, self._decode(local_tag))
                     for local_tag in free_tags])

    @Manager.calls_manager
    def inuse_tags(self, pool):
        _free_list, _free_set, inuse_set_key = self._tag_pool_keys(pool)
        inuse_tags = yield self.redis.smembers(inuse_set_key)
        returnValue([(pool, self._decode(local_tag))
                     for local_tag in inuse_tags])

    @Manager.calls_manager
    def acquired_by(self, tag):
        pool, local_tag = tag
        local_tag = self._encode(local_tag)
        reason_hash_key = self._tag_pool_reason_key(pool)
        raw_reason = yield self.redis.hget(reason_hash_key, local_tag)
        if raw_reason is not None:
            reason = json.loads(raw_reason)
            owner = reason.get('owner')
        else:
            reason, owner = None, None
        returnValue((owner, reason))

    @Manager.calls_manager
    def owned_tags(self, owner):
        owner_tag_list_key = self._owner_tag_list_key(owner)
        owned_tags = yield self.redis.smembers(owner_tag_list_key)
        returnValue([json.loads(raw_tag) for raw_tag in owned_tags])

    def _pool_list_key(self):
        return ":".join(["tagpools", "list"])

    @Manager.calls_manager
    def _register_pool(self, pool):
        """Add a pool to list of pools."""
        pool = self._encode(pool)
        pool_list_key = self._pool_list_key()
        yield self.redis.sadd(pool_list_key, pool)

    @Manager.calls_manager
    def _unregister_pool(self, pool):
        """Remove a pool to list of pools."""
        pool = self._encode(pool)
        pool_list_key = self._pool_list_key()
        yield self.redis.srem(pool_list_key, pool)

    def _tag_pool_keys(self, pool):
        pool = self._encode(pool)
        return tuple(":".join(["tagpools", pool, state])
                     for state in ("free:list", "free:set", "inuse:set"))

    def _tag_pool_metadata_key(self, pool):
        pool = self._encode(pool)
        return ":".join(["tagpools", pool, "metadata"])

    @Manager.calls_manager
    def _acquire_tag(self, pool, owner, reason):
        free_list_key, free_set_key, inuse_set_key = self._tag_pool_keys(pool)
        tag = yield self.redis.lpop(free_list_key)
        if tag is not None:
            yield self.redis.smove(free_set_key, inuse_set_key, tag)
            yield self._store_reason(pool, tag, owner, reason)
        returnValue(self._decode(tag) if tag is not None else None)

    @Manager.calls_manager
    def _acquire_specific_tag(self, pool, local_tag, owner, reason):
        local_tag = self._encode(local_tag)
        free_list_key, free_set_key, inuse_set_key = self._tag_pool_keys(pool)
        moved = yield self.redis.lrem(free_list_key, local_tag, num=1)
        if moved:
            yield self.redis.smove(free_set_key, inuse_set_key, local_tag)
            yield self._store_reason(pool, local_tag, owner, reason)
        returnValue(moved)

    @Manager.calls_manager
    def _release_tag(self, pool, local_tag):
        local_tag = self._encode(local_tag)
        free_list_key, free_set_key, inuse_set_key = self._tag_pool_keys(pool)
        count = yield self.redis.smove(inuse_set_key, free_set_key, local_tag)
        if count == 1:
            yield self.redis.rpush(free_list_key, local_tag)
            yield self._remove_reason(pool, local_tag)

    @Manager.calls_manager
    def _declare_tags(self, pool, local_tags):
        free_list_key, free_set_key, inuse_set_key = self._tag_pool_keys(pool)
        new_tags = set(self._encode(tag) for tag in local_tags)
        old_tags = yield self.redis.sunion(free_set_key, inuse_set_key)
        old_tags = set(old_tags)
        for tag in sorted(new_tags - old_tags):
            yield self.redis.sadd(free_set_key, tag)
            yield self.redis.rpush(free_list_key, tag)

    def _tag_pool_reason_key(self, pool):
        pool = self._encode(pool)
        return ":".join(["tagpools", pool, "reason:hash"])

    def _owner_tag_list_key(self, owner):
        if owner is None:
            return ":".join(["tagpools", "unowned", "tags"])
        owner = self._encode(owner)
        return ":".join(["tagpools", "owners", owner, "tags"])

    @Manager.calls_manager
    def _store_reason(self, pool, local_tag, owner, reason):
        if reason is None:
            reason = {}
        reason['timestamp'] = time.time()
        reason['owner'] = owner
        reason_hash_key = self._tag_pool_reason_key(pool)
        yield self.redis.hset(reason_hash_key, local_tag, json.dumps(reason))
        owner_tag_list_key = self._owner_tag_list_key(owner)
        yield self.redis.sadd(owner_tag_list_key,
                              json.dumps([pool, self._decode(local_tag)]))

    @Manager.calls_manager
    def _remove_reason(self, pool, local_tag):
        reason_hash_key = self._tag_pool_reason_key(pool)
        reason = yield self.redis.hget(reason_hash_key, local_tag)
        if reason is not None:
            reason = json.loads(reason)
            owner = reason.get('owner')
            owner_tag_list_key = self._owner_tag_list_key(owner)
            yield self.redis.srem(owner_tag_list_key,
                                  json.dumps([pool, self._decode(local_tag)]))
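

# A minimal usage sketch (not part of the original module). It assumes a
# synchronous RedisManager; with TxRedisManager each call below returns a
# Deferred and would need to be yielded.
def demo_tagpool(redis):
    tp = TagpoolManager(redis)
    tp.declare_tags([(u"pool1", u"tag1"), (u"pool1", u"tag2")])
    tag = tp.acquire_tag(u"pool1", owner=u"worker-1")
    print "acquired:", tag  # e.g. (u"pool1", u"tag1")
    tp.release_tag(tag)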

vumi/components/session.py
# -*- test-case-name: vumi.components.tests.test_session -*-

"""Session management utilities."""

import time

from twisted.internet.defer import inlineCallbacks, returnValue

from vumi import log


class SessionManager(object):
    """A manager for sessions.

    :param TxRedisManager redis:
        Redis manager object.
    :param int max_session_length:
        Time in seconds before a session expires. Default is None (never
        expire).
    :param float gc_period:
        Deprecated and ignored.
    """

    def __init__(self, redis, max_session_length=None, gc_period=None):
        self.max_session_length = max_session_length
        self.redis = redis
        if gc_period is not None:
            log.warning("SessionManager 'gc_period' parameter is deprecated.")

    @inlineCallbacks
    def stop(self, stop_redis=True):
        if stop_redis:
            yield self.redis._close()

    @classmethod
    def from_redis_config(cls, config, key_prefix=None,
                          max_session_length=None, gc_period=None):
        """Create a `SessionManager` instance using `TxRedisManager`.
        """
        from vumi.persist.txredis_manager import TxRedisManager
        d = TxRedisManager.from_config(config)
        if key_prefix is not None:
            d.addCallback(lambda m: m.sub_manager(key_prefix))
        return d.addCallback(lambda m: cls(m, max_session_length, gc_period))

    @inlineCallbacks
    def active_sessions(self):
        """Return a list of active user_ids and associated sessions.

        Queries redis for keys starting with the session key prefix. This is
        O(n) over the total number of keys in redis, but this is still pretty
        quick even for millions of keys. Try not to hit this too often, though.
        """
        keys = yield self.redis.keys('session:*')
        sessions = []
        for user_id in [key.split(':', 1)[1] for key in keys]:
            sessions.append((user_id, (yield self.load_session(user_id))))

        returnValue(sessions)

    def load_session(self, user_id):
        """
        Load session data from Redis.
        """
        ukey = "%s:%s" % ('session', user_id)
        return self.redis.hgetall(ukey)

    def schedule_session_expiry(self, user_id, timeout):
        """
        Schedule a session to time out.

        Parameters
        ----------
        user_id : str
            The user's id.
        timeout : int
            The number of seconds after which this session should expire
        """
        ukey = "%s:%s" % ('session', user_id)
        return self.redis.expire(ukey, timeout)

    @inlineCallbacks
    def create_session(self, user_id, **kwargs):
        """
        Create a new session using the given user_id
        """
        yield self.clear_session(user_id)
        defaults = {
            'created_at': time.time()
        }
        defaults.update(kwargs)
        yield self.save_session(user_id, defaults)
        if self.max_session_length:
            yield self.schedule_session_expiry(user_id,
                                               int(self.max_session_length))
        returnValue((yield self.load_session(user_id)))

    def clear_session(self, user_id):
        ukey = "%s:%s" % ('session', user_id)
        return self.redis.delete(ukey)

    @inlineCallbacks
    def save_session(self, user_id, session):
        """
        Save a session

        Parameters
        ----------
        user_id : str
            The user's id.
        session : dict
            The session info, nested dictionaries are not supported. Any
            values that are dictionaries are converted to strings by Redis.

        """
        ukey = "%s:%s" % ('session', user_id)
        for s_key, s_value in session.items():
            yield self.redis.hset(ukey, s_key, s_value)
        returnValue(session)
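

# A minimal usage sketch (not part of the original module): create a session
# that expires after an hour, then read it back. Runs under the Twisted
# reactor; `redis_config` is an ordinary redis manager config dict.
@inlineCallbacks
def demo_session(redis_config):
    sm = yield SessionManager.from_redis_config(
        redis_config, key_prefix='demo', max_session_length=3600)
    yield sm.create_session('user-1', to_addr='*120*1#')
    loaded = yield sm.load_session('user-1')
    yield sm.stop()
    returnValue(loaded)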

vumi/components/message_store_api.py
# -*- test-case-name: vumi.components.tests.test_message_store_api -*-
import json
import functools

from twisted.web import resource
from twisted.web.server import NOT_DONE_YET
from twisted.internet.defer import inlineCallbacks

from vumi.service import Worker
from vumi.message import JSONMessageEncoder
from vumi.transports.httprpc import httprpc
from vumi.components.message_store import MessageStore
from vumi.persist.txriak_manager import TxRiakManager
from vumi.persist.txredis_manager import TxRedisManager


class MatchResource(resource.Resource):
    """
    A Resource that accepts a query as JSON via HTTP POST and issues a match
    operation on the MessageStore.
    """

    DEFAULT_RESULT_SIZE = 20

    REQ_TTL_HEADER = 'X-VMS-Match-TTL'
    REQ_WAIT_HEADER = 'X-VMS-Match-Wait'

    RESP_COUNT_HEADER = 'X-VMS-Result-Count'
    RESP_TOKEN_HEADER = 'X-VMS-Result-Token'
    RESP_IN_PROGRESS_HEADER = 'X-VMS-Match-In-Progress'

    def __init__(self, direction, message_store, batch_id):
        """
        :param str direction:
            Either 'inbound' or 'outbound'; this is used to figure out which
            function needs to be called on the MessageStore.
        :param MessageStore message_store:
            Instance of the MessageStore.
        :param str batch_id:
            The batch_id to use to query on.
        """
        resource.Resource.__init__(self)

        self._match_cb = functools.partial({
            'inbound': message_store.find_inbound_keys_matching,
            'outbound': message_store.find_outbound_keys_matching,
        }.get(direction), batch_id)
        self._results_cb = functools.partial(
            message_store.get_keys_for_token, batch_id)
        self._count_cb = functools.partial(
            message_store.count_keys_for_token, batch_id)
        self._in_progress_cb = functools.partial(
            message_store.is_query_in_progress, batch_id)
        self._load_bunches_cb = {
            'inbound': message_store.inbound_messages.load_all_bunches,
            'outbound': message_store.outbound_messages.load_all_bunches,
        }.get(direction)

    def _add_resp_header(self, request, key, value):
        if isinstance(value, unicode):
            value = value.encode('utf-8')
        if not isinstance(value, str):
            raise TypeError("HTTP header values must be bytes.")
        request.responseHeaders.addRawHeader(key, value)

    def _render_token(self, token, request):
        self._add_resp_header(request, self.RESP_TOKEN_HEADER, token)
        request.finish()

    def render_POST(self, request):
        """
        Start a match operation. Expects the query to be POSTed
        as the raw HTTP POST data.

        The query is a list of dictionaries. A dictionary should have the
        structure as defined in `vumi.persist.model.Model.index_match`

        The results of the query are stored for a limited time. The TTL
        defaults to `MessageStoreCache.DEFAULT_SEARCH_RESULT_TTL` but can be
        overridden by specifying the TTL in seconds via the header key given
        in `REQ_TTL_HEADER`.

        If the request sets the `REQ_WAIT_HEADER` value to `1` (int), a
        response is only returned once the keys are actually available for
        collection.
        """
        query = json.loads(request.content.read())
        headers = request.requestHeaders
        ttl = int(headers.getRawHeaders(self.REQ_TTL_HEADER, [0])[0])
        if headers.hasHeader(self.REQ_WAIT_HEADER):
            wait = bool(int(headers.getRawHeaders(self.REQ_WAIT_HEADER)[0]))
        else:
            wait = False
        deferred = self._match_cb(query, ttl=(ttl or None), wait=wait)
        deferred.addCallback(self._render_token, request)
        return NOT_DONE_YET

    @inlineCallbacks
    def _render_results(self, request, token, start, stop, keys_only, asc):
        in_progress = yield self._in_progress_cb(token)
        count = yield self._count_cb(token)
        keys = yield self._results_cb(token, start, stop, asc)
        self._add_resp_header(request, self.RESP_IN_PROGRESS_HEADER,
            str(int(in_progress)))
        self._add_resp_header(request, self.RESP_COUNT_HEADER, str(count))
        if keys_only:
            request.write(json.dumps(keys))
        else:
            messages = []
            for bunch in self._load_bunches_cb(keys):
                # inbound & outbound messages have a `.msg` attribute which
                # is the actual message stored, they share the same message_id
                # as the key.
                messages.extend([msg.msg.payload for msg in (yield bunch)
                                    if msg.msg])

            # sort the results in the order in which the keys were specified
            messages.sort(key=lambda msg: keys.index(msg['message_id']))
            request.write(json.dumps(messages, cls=JSONMessageEncoder))
        request.finish()

    def render_GET(self, request):
        token = request.args['token'][0]
        start = int(request.args['start'][0] if 'start' in request.args else 0)
        stop = int(request.args['stop'][0] if 'stop' in request.args
                    else (start + self.DEFAULT_RESULT_SIZE - 1))
        asc = bool(int(request.args['asc'][0]) if 'asc' in request.args
                    else False)
        keys_only = bool(int(request.args['keys'][0]) if 'keys' in request.args
                            else False)
        self._render_results(request, token, start, stop, keys_only, asc)
        return NOT_DONE_YET

    def getChild(self, name, request):
        return self
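
# Illustrative HTTP flow against MatchResource (a sketch; "/api" stands in
# for whatever `web_path` the worker is configured with):
#
#   POST /api/batch/<batch_id>/inbound/match
#        body: JSON list of query dicts
#        -> X-VMS-Result-Token response header carries the token
#   GET  /api/batch/<batch_id>/inbound/match?token=<token>&start=0&stop=19
#        -> JSON list of matching messages (pass keys=1 for keys only)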


class BatchResource(resource.Resource):

    def __init__(self, message_store, batch_id):
        resource.Resource.__init__(self)
        self.message_store = message_store
        self.batch_id = batch_id

        inbound = resource.Resource()
        inbound.putChild('match',
            MatchResource('inbound', message_store, batch_id))
        self.putChild('inbound', inbound)

        outbound = resource.Resource()
        outbound.putChild('match',
            MatchResource('outbound', message_store, batch_id))
        self.putChild('outbound', outbound)

    def render_GET(self, request):
        return self.batch_id

    def getChild(self, name, request):
        if not name:
            return self


class BatchIndexResource(resource.Resource):

    def __init__(self, message_store):
        resource.Resource.__init__(self)
        self.message_store = message_store

    def render_GET(self, request):
        return ''

    def getChild(self, batch_id, request):
        if batch_id:
            return BatchResource(self.message_store, batch_id)
        return self


class MessageStoreAPI(resource.Resource):

    def __init__(self, message_store):
        resource.Resource.__init__(self)
        self.putChild('batch', BatchIndexResource(message_store))


class MessageStoreAPIWorker(Worker):
    """
    Worker that starts the MessageStoreAPI. It has some ability to connect to
    AMQP but doesn't do anything with it yet.

    :param str web_path:
        What is the base path this API should listen on?
    :param int web_port:
        On what port should it be listening?
    :param str health_path:
        Which path should respond to HAProxy health checks?
    :param dict riak_manager:
        The configuration parameters for TxRiakManager
    :param dict redis_manager:
        The configuration parameters for TxRedisManager
    """
    @inlineCallbacks
    def startWorker(self):
        web_path = self.config['web_path']
        web_port = int(self.config['web_port'])
        health_path = self.config['health_path']

        self._riak = yield TxRiakManager.from_config(
            self.config['riak_manager'])
        redis = yield TxRedisManager.from_config(self.config['redis_manager'])
        self.store = MessageStore(self._riak, redis)

        self.webserver = self.start_web_resources([
            (MessageStoreAPI(self.store), web_path),
            (httprpc.HttpRpcHealthResource(self), health_path),
            ], web_port)

    @inlineCallbacks
    def stopWorker(self):
        yield self.webserver.loseConnection()
        yield self._riak.close_manager()

    def get_health_response(self):
        """Called by the HttpRpcHealthResource"""
        return 'ok'
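
    # An illustrative YAML config for this worker (a sketch; values are
    # examples only):
    #
    #   web_path: /api
    #   web_port: 8010
    #   health_path: /health
    #   riak_manager:
    #     bucket_prefix: vumi.
    #   redis_manager:
    #     key_prefix: vumi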

vumi/components/message_formatters.py
# -*- test-case-name: vumi.components.tests.test_message_formatters -*-

from csv import writer

from zope.interface import Interface, implements


class IMessageFormatter(Interface):
    """ Interface for writing messages to an HTTP request. """

    def add_http_headers(request):
        """
        Add any needed HTTP headers to the request.

        Often used to set the Content-Type header.
        """

    def write_row_header(request):
        """
        Write any header bytes that need to be written to the request before
        messages.
        """

    def write_row(request, message):
        """
        Write a :class:`TransportUserMessage` to the request.
        """


class JsonFormatter(object):
    """ Formatter for writing messages to requests as JSON. """

    implements(IMessageFormatter)

    def add_http_headers(self, request):
        resp_headers = request.responseHeaders
        resp_headers.addRawHeader(
            'Content-Type', 'application/json; charset=utf-8')

    def write_row_header(self, request):
        pass

    def write_row(self, request, message):
        request.write(message.to_json())
        request.write('\n')


class CsvFormatter(object):
    """ Formatter for writing messages to requests as CSV. """

    implements(IMessageFormatter)

    FIELDS = (
        'timestamp',
        'message_id',
        'to_addr',
        'from_addr',
        'in_reply_to',
        'session_event',
        'content',
        'group',
    )

    def add_http_headers(self, request):
        resp_headers = request.responseHeaders
        resp_headers.addRawHeader(
            'Content-Type', 'text/csv; charset=utf-8')

    def write_row_header(self, request):
        writer(request).writerow(self.FIELDS)

    def write_row(self, request, message):
        writer(request).writerow([
            self._format_field(field, message) for field in self.FIELDS])

    def _format_field(self, field, message):
        field_formatter = getattr(self, '_format_field_%s' % (field,), None)
        if field_formatter is not None:
            field_value = field_formatter(message)
        else:
            field_value = self._format_field_default(field, message)
        return field_value.encode('utf-8')

    def _format_field_default(self, field, message):
        return message[field] or u''

    def _format_field_timestamp(self, message):
        return message['timestamp'].isoformat()


class CsvEventFormatter(CsvFormatter):
    """ Formatter for writing messages to requests as CSV. """

    implements(IMessageFormatter)

    FIELDS = (
        'timestamp',
        'event_id',
        'status',
        'user_message_id',
        'nack_reason',
    )

    def _format_field_status(self, message):
        return message.status()

    def _format_field_nack_reason(self, message):
        return message.get('nack_reason', u'') or u''
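

# A minimal usage sketch (not part of the original module): stream messages
# to a twisted.web request using one of the formatters above. `request` is a
# twisted.web request object and `messages` an iterable of messages of the
# kind the chosen formatter expects.
def write_messages(request, messages, formatter=None):
    if formatter is None:
        formatter = CsvFormatter()
    formatter.add_http_headers(request)
    formatter.write_row_header(request)
    for message in messages:
        formatter.write_row(request, message)
    request.finish()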

vumi/components/message_store_migrators.py
# -*- test-case-name: vumi.components.tests.test_message_store_migrators -*-
# -*- coding: utf-8 -*-

from vumi.persist.model import ModelMigrator


class MessageMigratorBase(ModelMigrator):
    def _copy_msg_field(self, msg_field, mdata):
        key_prefix = "%s." % (msg_field,)
        msg_fields = [k for k in mdata.old_data if k.startswith(key_prefix)]
        mdata.copy_values(*msg_fields)

    def _foreign_key_to_many_to_many(self, foreign_key, many_to_many, mdata):
        old_keys = mdata.old_index.get('%s_bin' % (foreign_key,), [])
        mdata.set_value(many_to_many, old_keys)
        many_to_many_index = '%s_bin' % (many_to_many,)
        for old_key in old_keys:
            mdata.add_index(many_to_many_index, old_key)


class EventMigrator(MessageMigratorBase):
    def migrate_from_unversioned(self, mdata):
        mdata.set_value('$VERSION', 1)

        if 'message' not in mdata.old_data:
            # We have an old-style index-only field here, so add the data.
            [message_id] = mdata.old_index['message_bin']
            mdata.old_data['message'] = message_id

        self._copy_msg_field('event', mdata)
        mdata.copy_values('message')
        mdata.copy_indexes('message_bin')

        return mdata

    def reverse_from_1(self, mdata):
        # We only copy existing fields and indexes over. The new fields and
        # indexes are computed at save time.
        # We don't set the version because we're writing unversioned models.
        self._copy_msg_field('event', mdata)
        mdata.copy_values('message')
        mdata.copy_indexes('message_bin')

        return mdata

    def migrate_from_1(self, mdata):
        # If the old data contains a value for the `batches` field, it must be
        # back-migrated from a newer version. If not, we have no way to know
        # what batches the event belongs to, so we leave the field empty. Some
        # external data migration tool will have to populate it.
        mdata.set_value('$VERSION', 2)
        self._copy_msg_field('event', mdata)
        mdata.set_value('batches', mdata.old_data.get('batches', []))
        mdata.copy_values('message')
        mdata.copy_indexes('message_bin')
        mdata.copy_indexes('message_with_status_bin')

        return mdata

    def reverse_from_2(self, mdata):
        # We copy the `batches` field and related indexes even though the older
        # model version doesn't know about them. This lets us migrate
        # v2 -> v1 -> v2 without losing data.
        mdata.set_value('$VERSION', 1)
        self._copy_msg_field('event', mdata)
        mdata.copy_values('message', 'batches')
        mdata.copy_indexes('message_bin')
        mdata.copy_indexes('message_with_status_bin')
        mdata.copy_indexes('batches_bin')
        mdata.copy_indexes('batches_with_statuses_reverse_bin')

        return mdata
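
# Illustrative version flow for EventMigrator (a sketch): the model layer
# dispatches on the stored object's $VERSION, so
#   unversioned data -> migrate_from_unversioned -> version 1
#   version 1 data   -> migrate_from_1           -> version 2
# and reverse_from_2 back-migrates version 2 data for readers that only
# understand version 1.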


class OutboundMessageMigrator(MessageMigratorBase):
    def migrate_from_unversioned(self, mdata):
        mdata.set_value('$VERSION', 1)

        self._copy_msg_field('msg', mdata)
        self._foreign_key_to_many_to_many('batch', 'batches', mdata)

        return mdata

    def migrate_from_1(self, mdata):
        # We only copy existing fields and indexes over. The new fields and
        # indexes are computed at save time.
        mdata.set_value('$VERSION', 2)
        self._copy_msg_field('msg', mdata)
        mdata.copy_values('batches')
        mdata.copy_indexes('batches_bin')

        return mdata

    def migrate_from_2(self, mdata):
        # We only copy existing fields and indexes over. The new fields and
        # indexes are computed at save time.
        mdata.set_value('$VERSION', 3)
        self._copy_msg_field('msg', mdata)
        mdata.copy_values('batches')
        mdata.copy_indexes('batches_bin')
        mdata.copy_indexes('batches_with_timestamps_bin')

        return mdata

    def reverse_from_3(self, mdata):
        # The only difference between v2 and v3 is an index that's computed at
        # save time, so the reverse migration is identical to the forward
        # migration except for the version we set.
        mdata.set_value('$VERSION', 2)
        self._copy_msg_field('msg', mdata)
        mdata.copy_values('batches')
        mdata.copy_indexes('batches_bin')
        mdata.copy_indexes('batches_with_timestamps_bin')

        return mdata

    def migrate_from_3(self, mdata):
        # We only copy existing fields and indexes over. The new fields and
        # indexes are computed at save time.
        mdata.set_value('$VERSION', 4)
        self._copy_msg_field('msg', mdata)
        mdata.copy_values('batches')
        mdata.copy_indexes('batches_bin')
        mdata.copy_indexes('batches_with_timestamps_bin')
        mdata.copy_indexes('batches_with_addresses_bin')

        return mdata

    def reverse_from_4(self, mdata):
        # The only difference between v3 and v4 is an index that's computed at
        # save time, so the reverse migration is identical to the forward
        # migration except for the version we set.
        mdata.set_value('$VERSION', 3)
        self._copy_msg_field('msg', mdata)
        mdata.copy_values('batches')
        mdata.copy_indexes('batches_bin')
        mdata.copy_indexes('batches_with_timestamps_bin')
        mdata.copy_indexes('batches_with_addresses_bin')

        return mdata

    def migrate_from_4(self, mdata):
        # We copy existing fields and indexes over, except for the
        # `batches_with_timestamps` index that v5 drops. We also copy the
        # v5-only `batches_with_addresses_reverse` index so that data
        # back-migrated from v5 survives a v5 -> v4 -> v5 round trip.
        mdata.set_value('$VERSION', 5)
        self._copy_msg_field('msg', mdata)
        mdata.copy_values('batches')
        mdata.copy_indexes('batches_bin')
        mdata.copy_indexes('batches_with_addresses_bin')
        mdata.copy_indexes('batches_with_addresses_reverse_bin')

        return mdata

    def reverse_from_5(self, mdata):
        # v5 drops the `batches_with_timestamps` index and adds a
        # `batches_with_addresses_reverse` index that's computed at save
        # time, so the reverse migration is identical to the forward
        # migration except for the version we set.
        mdata.set_value('$VERSION', 4)
        self._copy_msg_field('msg', mdata)
        mdata.copy_values('batches')
        mdata.copy_indexes('batches_bin')
        mdata.copy_indexes('batches_with_addresses_bin')
        mdata.copy_indexes('batches_with_addresses_reverse_bin')

        return mdata


class InboundMessageMigrator(MessageMigratorBase):
    def migrate_from_unversioned(self, mdata):
        mdata.set_value('$VERSION', 1)

        self._copy_msg_field('msg', mdata)
        self._foreign_key_to_many_to_many('batch', 'batches', mdata)

        return mdata

    def migrate_from_1(self, mdata):
        # We only copy existing fields and indexes over. The new fields and
        # indexes are computed at save time.
        mdata.set_value('$VERSION', 2)
        self._copy_msg_field('msg', mdata)
        mdata.copy_values('batches')
        mdata.copy_indexes('batches_bin')

        return mdata

    def migrate_from_2(self, mdata):
        # We only copy existing fields and indexes over. The new fields and
        # indexes are computed at save time.
        mdata.set_value('$VERSION', 3)
        self._copy_msg_field('msg', mdata)
        mdata.copy_values('batches')
        mdata.copy_indexes('batches_bin')
        mdata.copy_indexes('batches_with_timestamps_bin')

        return mdata

    def reverse_from_3(self, mdata):
        # The only difference between v2 and v3 is an index that's computed at
        # save time, so the reverse migration is identical to the forward
        # migration except for the version we set.
        mdata.set_value('$VERSION', 2)
        self._copy_msg_field('msg', mdata)
        mdata.copy_values('batches')
        mdata.copy_indexes('batches_bin')
        mdata.copy_indexes('batches_with_timestamps_bin')

        return mdata

    def migrate_from_3(self, mdata):
        # We only copy existing fields and indexes over. The new fields and
        # indexes are computed at save time.
        mdata.set_value('$VERSION', 4)
        self._copy_msg_field('msg', mdata)
        mdata.copy_values('batches')
        mdata.copy_indexes('batches_bin')
        mdata.copy_indexes('batches_with_timestamps_bin')
        mdata.copy_indexes('batches_with_addresses_bin')

        return mdata

    def reverse_from_4(self, mdata):
        # The only difference between v3 and v4 is an index that's computed at
        # save time, so the reverse migration is identical to the forward
        # migration except for the version we set.
        mdata.set_value('$VERSION', 3)
        self._copy_msg_field('msg', mdata)
        mdata.copy_values('batches')
        mdata.copy_indexes('batches_bin')
        mdata.copy_indexes('batches_with_timestamps_bin')
        mdata.copy_indexes('batches_with_addresses_bin')

        return mdata

    def migrate_from_4(self, mdata):
        # We copy existing fields and indexes over, except for the
        # `batches_with_timestamps` index that v5 drops. We also copy the
        # v5-only `batches_with_addresses_reverse` index so that data
        # back-migrated from v5 survives a v5 -> v4 -> v5 round trip.
        mdata.set_value('$VERSION', 5)
        self._copy_msg_field('msg', mdata)
        mdata.copy_values('batches')
        mdata.copy_indexes('batches_bin')
        mdata.copy_indexes('batches_with_addresses_bin')
        mdata.copy_indexes('batches_with_addresses_reverse_bin')

        return mdata

    def reverse_from_5(self, mdata):
        # v5 drops the `batches_with_timestamps` index and adds a
        # `batches_with_addresses_reverse` index that's computed at save
        # time, so the reverse migration is identical to the forward
        # migration except for the version we set.
        mdata.set_value('$VERSION', 4)
        self._copy_msg_field('msg', mdata)
        mdata.copy_values('batches')
        mdata.copy_indexes('batches_bin')
        mdata.copy_indexes('batches_with_addresses_bin')
        mdata.copy_indexes('batches_with_addresses_reverse_bin')

        return mdata

vumi/components/tests/test_session.py
"""Tests for vumi.components.session."""

import time

from twisted.internet.defer import inlineCallbacks

from vumi.components.session import SessionManager
from vumi.tests.helpers import VumiTestCase, PersistenceHelper


class TestSessionManager(VumiTestCase):
    @inlineCallbacks
    def setUp(self):
        self.persistence_helper = self.add_helper(PersistenceHelper())
        self.manager = yield self.persistence_helper.get_redis_manager()
        yield self.manager._purge_all()  # Just in case
        self.sm = SessionManager(self.manager)
        self.add_cleanup(self.sm.stop)

    @inlineCallbacks
    def test_active_sessions(self):
        def get_sessions():
            return self.sm.active_sessions().addCallback(lambda s: sorted(s))

        def ids():
            return get_sessions().addCallback(lambda s: [x[0] for x in s])

        self.assertEqual((yield ids()), [])
        yield self.sm.create_session("u1")
        self.assertEqual((yield ids()), ["u1"])
        # 10 seconds later
        yield self.sm.create_session("u2", created_at=time.time() + 10)
        self.assertEqual((yield ids()), ["u1", "u2"])

        s1, s2 = yield get_sessions()
        self.assertTrue(s1[1]['created_at'] < s2[1]['created_at'])

    @inlineCallbacks
    def test_schedule_session_expiry(self):
        self.sm.max_session_length = 60.0
        yield self.sm.create_session("u1")
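        # No assertion needed here: this only checks that creating a session
        # while max_session_length is set (which schedules session expiry)
        # does not raise.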

    @inlineCallbacks
    def test_create_and_retrieve_session(self):
        session = yield self.sm.create_session("u1")
        self.assertEqual(sorted(session.keys()), ['created_at'])
        self.assertTrue(time.time() - float(session['created_at']) < 10.0)
        loaded = yield self.sm.load_session("u1")
        self.assertEqual(loaded, session)

    @inlineCallbacks
    def test_create_clears_existing_session(self):
        session = yield self.sm.create_session("u1", foo="bar")
        self.assertEqual(sorted(session.keys()), ['created_at', 'foo'])
        loaded = yield self.sm.load_session("u1")
        self.assertEqual(loaded, session)

        session = yield self.sm.create_session("u1", bar="baz")
        self.assertEqual(sorted(session.keys()), ['bar', 'created_at'])
        loaded = yield self.sm.load_session("u1")
        self.assertEqual(loaded, session)

    @inlineCallbacks
    def test_save_session(self):
        test_session = {"foo": 5, "bar": "baz"}
        yield self.sm.create_session("u1")
        yield self.sm.save_session("u1", test_session)
        session = yield self.sm.load_session("u1")
        self.assertTrue(session.pop('created_at') is not None)
        # Redis saves & returns all session values as strings
        self.assertEqual(session, dict([map(str, kvs) for kvs
                                        in test_session.items()]))
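
    @inlineCallbacks
    def test_save_session_coerces_values_to_strings(self):
        # A minimal sketch reiterating the coercion behaviour noted above:
        # non-string values round-trip through Redis as strings.
        yield self.sm.create_session("u1")
        yield self.sm.save_session("u1", {"count": 7})
        session = yield self.sm.load_session("u1")
        self.assertEqual(session.get("count"), "7")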

vumi/components/tests/test_message_formatters.py
# -*- coding: utf-8 -*-

from twisted.web.test.test_web import DummyRequest

from vumi.components.message_formatters import (
    IMessageFormatter, JsonFormatter, CsvFormatter, CsvEventFormatter)

from vumi.tests.helpers import VumiTestCase, MessageHelper


class TestJsonFormatter(VumiTestCase):
    def setUp(self):
        self.msg_helper = self.add_helper(MessageHelper())
        self.request = DummyRequest([''])
        self.formatter = JsonFormatter()

    def test_implements_IMessageFormatter(self):
        self.assertTrue(IMessageFormatter.providedBy(self.formatter))

    def test_add_http_headers(self):
        self.formatter.add_http_headers(self.request)
        self.assertEqual(
            self.request.responseHeaders.getRawHeaders('Content-Type'),
            ['application/json; charset=utf-8'])

    def test_write_row_header(self):
        self.formatter.write_row_header(self.request)
        self.assertEqual(self.request.written, [])

    def test_write_row(self):
        msg = self.msg_helper.make_inbound("foo")
        self.formatter.write_row(self.request, msg)
        self.assertEqual(self.request.written, [
            msg.to_json(), "\n",
        ])
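
    # i.e. the JSON formatter emits newline-delimited JSON, one message per
    # line, so consumers can split a response body on '\n'.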


class TestCsvFormatter(VumiTestCase):
    def setUp(self):
        self.msg_helper = self.add_helper(MessageHelper())
        self.request = DummyRequest([''])
        self.formatter = CsvFormatter()

    def test_implements_IMessageFormatter(self):
        self.assertTrue(IMessageFormatter.providedBy(self.formatter))

    def test_add_http_headers(self):
        self.formatter.add_http_headers(self.request)
        self.assertEqual(
            self.request.responseHeaders.getRawHeaders('Content-Type'),
            ['text/csv; charset=utf-8'])

    def test_write_row_header(self):
        self.formatter.write_row_header(self.request)
        self.assertEqual(self.request.written, [
            "timestamp,message_id,to_addr,from_addr,in_reply_to,session_event,"
            "content,group\r\n"
        ])

    def assert_row_written(self, row, row_template, msg):
        self.assertEqual(row, [row_template % {
            'ts': msg['timestamp'].isoformat(),
            'id': msg['message_id'],
        }])

    def test_write_row(self):
        msg = self.msg_helper.make_inbound("foo")
        self.formatter.write_row(self.request, msg)
        self.assert_row_written(
            self.request.written,
            "%(ts)s,%(id)s,9292,+41791234567,,,foo,\r\n", msg)

    def test_write_row_with_in_reply_to(self):
        msg = self.msg_helper.make_inbound("foo", in_reply_to="msg-2")
        self.formatter.write_row(self.request, msg)
        self.assert_row_written(
            self.request.written,
            "%(ts)s,%(id)s,9292,+41791234567,msg-2,,foo,\r\n", msg)

    def test_write_row_with_session_event(self):
        msg = self.msg_helper.make_inbound("foo", session_event="new")
        self.formatter.write_row(self.request, msg)
        self.assert_row_written(
            self.request.written,
            "%(ts)s,%(id)s,9292,+41791234567,,new,foo,\r\n", msg)

    def test_write_row_with_group(self):
        msg = self.msg_helper.make_inbound("foo", group="#channel")
        self.formatter.write_row(self.request, msg)
        self.assert_row_written(
            self.request.written,
            "%(ts)s,%(id)s,9292,+41791234567,,,foo,#channel\r\n", msg)

    def test_write_row_with_unicode_content(self):
        msg = self.msg_helper.make_inbound(u"føø", group="#channel")
        self.formatter.write_row(self.request, msg)
        self.assert_row_written(
            self.request.written,
            u"%(ts)s,%(id)s,9292,+41791234567,,,føø,#channel\r\n".encode(
                "utf-8"),
            msg)


class TestCsvEventFormatter(VumiTestCase):
    def setUp(self):
        self.msg_helper = self.add_helper(MessageHelper())
        self.request = DummyRequest([''])
        self.formatter = CsvEventFormatter()

    def test_implements_IMessageFormatter(self):
        self.assertTrue(IMessageFormatter.providedBy(self.formatter))

    def test_add_http_headers(self):
        self.formatter.add_http_headers(self.request)
        self.assertEqual(
            self.request.responseHeaders.getRawHeaders('Content-Type'),
            ['text/csv; charset=utf-8'])

    def test_write_row_header(self):
        self.formatter.write_row_header(self.request)
        self.assertEqual(self.request.written, [
            "timestamp,event_id,status,user_message_id,nack_reason\r\n"
        ])

    def assert_row_written(self, row, row_template, event):
        self.assertEqual(row, [row_template % {
            'ts': event['timestamp'].isoformat(),
            'id': event['event_id'],
            'msg_id': event['user_message_id'],
        }])

    def test_write_row_ack(self):
        event = self.msg_helper.make_ack()
        self.formatter.write_row(self.request, event)
        self.assert_row_written(
            self.request.written,
            "%(ts)s,%(id)s,ack,%(msg_id)s,\r\n", event)

    def test_write_row_nack(self):
        event = self.msg_helper.make_nack(nack_reason="raisins")
        self.formatter.write_row(self.request, event)
        self.assert_row_written(
            self.request.written,
            "%(ts)s,%(id)s,nack,%(msg_id)s,raisins\r\n", event)

    def test_write_row_delivery_report(self):
        event = self.msg_helper.make_delivery_report()
        self.formatter.write_row(self.request, event)
        self.assert_row_written(
            self.request.written,
            "%(ts)s,%(id)s,delivery_report.delivered,%(msg_id)s,\r\n", event)

    def test_write_row_with_unicode_content(self):
        event = self.msg_helper.make_nack(nack_reason=u"føø")
        self.formatter.write_row(self.request, event)
        self.assert_row_written(
            self.request.written,
            "%(ts)s,%(id)s,nack,%(msg_id)s,føø\r\n", event)

vumi/components/tests/test_message_store_resource.py
# -*- coding: utf-8 -*-

import json
from datetime import datetime
from urllib import urlencode

from twisted.internet import reactor
from twisted.internet.defer import (
    inlineCallbacks, Deferred, succeed, gatherResults)
from twisted.web.server import Site

from vumi.components.message_formatters import JsonFormatter

from vumi.utils import http_request_full

from vumi.tests.helpers import (
    VumiTestCase, MessageHelper, PersistenceHelper, import_skip,
    WorkerHelper)


class TestMessageStoreResource(VumiTestCase):

    def setUp(self):
        self.persistence_helper = self.add_helper(
            PersistenceHelper(use_riak=True))
        self.worker_helper = self.add_helper(WorkerHelper())
        self.msg_helper = self.add_helper(MessageHelper())

    @inlineCallbacks
    def start_server(self):
        try:
            from vumi.components.message_store_resource import (
                MessageStoreResourceWorker)
        except ImportError, e:
            import_skip(e, 'riak')

        config = self.persistence_helper.mk_config({
            'twisted_endpoint': 'tcp:0',
            'web_path': '/resource_path/',
        })

        worker = yield self.worker_helper.get_worker(
            MessageStoreResourceWorker, config)
        yield worker.startService()
        port = yield worker.services[0]._waitingForPort
        addr = port.getHost()

        self.url = 'http://%s:%s' % (addr.host, addr.port)
        self.store = worker.store
        self.addCleanup(self.stop_server, port)

    def stop_server(self, port):
        d = port.stopListening()
        d.addCallback(lambda _: port.loseConnection())
        return d

    def make_batch(self, tag):
        return self.store.batch_start([tag])

    def make_outbound(self, batch_id, content, timestamp=None):
        if timestamp is None:
            timestamp = datetime.utcnow()
        msg = self.msg_helper.make_outbound(content, timestamp=timestamp)
        d = self.store.add_outbound_message(msg, batch_id=batch_id)
        d.addCallback(lambda _: msg)
        return d

    def make_inbound(self, batch_id, content, timestamp=None):
        if timestamp is None:
            timestamp = datetime.utcnow()
        msg = self.msg_helper.make_inbound(content, timestamp=timestamp)
        d = self.store.add_inbound_message(msg, batch_id=batch_id)
        d.addCallback(lambda _: msg)
        return d

    def make_ack(self, batch_id, timestamp=None):
        if timestamp is None:
            timestamp = datetime.utcnow()
        ack = self.msg_helper.make_ack(timestamp=timestamp)
        d = self.store.add_event(ack, batch_ids=[batch_id])
        d.addCallback(lambda _: ack)
        return d

    def make_request(self, method, batch_id, leaf, **params):
        url = '%s/%s/%s/%s' % (self.url, 'resource_path', batch_id, leaf)
        if params:
            url = '%s?%s' % (url, urlencode(params))
        return http_request_full(method=method, url=url)

    def get_batch_resource(self, batch_id):
        return self.store_resource.getChild(batch_id, None)

    def assert_csv_rows(self, rows, expected):
        self.assertEqual(sorted(rows), sorted([
            row_template % {
                'id': msg['message_id'],
                'ts': msg['timestamp'].isoformat(),
            } for row_template, msg in expected
        ]))

    def assert_csv_event_rows(self, rows, expected):
        self.assertEqual(sorted(rows), sorted([
            row_template % {
                'id': ev['event_id'],
                'ts': ev['timestamp'].isoformat(),
                'msg_id': ev['user_message_id'],
            } for row_template, ev in expected
        ]))
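
    # The requests below exercise URLs of the form (host and port are
    # illustrative):
    #
    #   GET http://<host>:<port>/resource_path/<batch_id>/<leaf>
    #
    # where <leaf> is one of inbound/outbound/events with a .json or .csv
    # suffix, plus optional `start`/`end` timestamp query parameters.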

    @inlineCallbacks
    def test_get_inbound(self):
        yield self.start_server()
        batch_id = yield self.make_batch(('foo', 'bar'))
        msg1 = yield self.make_inbound(batch_id, 'føø')
        msg2 = yield self.make_inbound(batch_id, 'føø')
        resp = yield self.make_request('GET', batch_id, 'inbound.json')
        messages = map(
            json.loads, filter(None, resp.delivered_body.split('\n')))
        self.assertEqual(
            set([msg['message_id'] for msg in messages]),
            set([msg1['message_id'], msg2['message_id']]))

    @inlineCallbacks
    def test_get_inbound_csv(self):
        yield self.start_server()
        batch_id = yield self.make_batch(('foo', 'bar'))
        msg1 = yield self.make_inbound(batch_id, 'føø')
        msg2 = yield self.make_inbound(batch_id, 'føø')
        resp = yield self.make_request('GET', batch_id, 'inbound.csv')
        rows = resp.delivered_body.split('\r\n')
        header, rows = rows[0], rows[1:-1]
        self.assertEqual(header, (
            "timestamp,message_id,to_addr,from_addr,in_reply_to,session_event,"
            "content,group"))
        self.assert_csv_rows(rows, [
            ("%(ts)s,%(id)s,9292,+41791234567,,,føø,", msg1),
            ("%(ts)s,%(id)s,9292,+41791234567,,,føø,", msg2),
        ])

    @inlineCallbacks
    def test_get_outbound(self):
        yield self.start_server()
        batch_id = yield self.make_batch(('foo', 'bar'))
        msg1 = yield self.make_outbound(batch_id, 'føø')
        msg2 = yield self.make_outbound(batch_id, 'føø')
        resp = yield self.make_request('GET', batch_id, 'outbound.json')
        messages = map(
            json.loads, filter(None, resp.delivered_body.split('\n')))
        self.assertEqual(
            set([msg['message_id'] for msg in messages]),
            set([msg1['message_id'], msg2['message_id']]))

    @inlineCallbacks
    def test_get_outbound_csv(self):
        yield self.start_server()
        batch_id = yield self.make_batch(('foo', 'bar'))
        msg1 = yield self.make_outbound(batch_id, 'føø')
        msg2 = yield self.make_outbound(batch_id, 'føø')
        resp = yield self.make_request('GET', batch_id, 'outbound.csv')
        rows = resp.delivered_body.split('\r\n')
        header, rows = rows[0], rows[1:-1]
        self.assertEqual(header, (
            "timestamp,message_id,to_addr,from_addr,in_reply_to,session_event,"
            "content,group"))
        self.assert_csv_rows(rows, [
            ("%(ts)s,%(id)s,+41791234567,9292,,,føø,", msg1),
            ("%(ts)s,%(id)s,+41791234567,9292,,,føø,", msg2),
        ])

    @inlineCallbacks
    def test_get_events(self):
        yield self.start_server()
        batch_id = yield self.make_batch(('foo', 'bar'))
        ack1 = yield self.make_ack(batch_id)
        ack2 = yield self.make_ack(batch_id)
        resp = yield self.make_request('GET', batch_id, 'events.json')
        events = map(
            json.loads, filter(None, resp.delivered_body.split('\n')))
        self.assertEqual(
            set([ev['event_id'] for ev in events]),
            set([ack1['event_id'], ack2['event_id']]))

    @inlineCallbacks
    def test_get_events_csv(self):
        yield self.start_server()
        batch_id = yield self.make_batch(('foo', 'bar'))
        ack1 = yield self.make_ack(batch_id)
        ack2 = yield self.make_ack(batch_id)
        resp = yield self.make_request('GET', batch_id, 'events.csv')
        rows = resp.delivered_body.split('\r\n')
        header, rows = rows[0], rows[1:-1]
        self.assertEqual(header, (
            "timestamp,event_id,status,user_message_id,"
            "nack_reason"))
        self.assert_csv_event_rows(rows, [
            ("%(ts)s,%(id)s,ack,%(msg_id)s,", ack1),
            ("%(ts)s,%(id)s,ack,%(msg_id)s,", ack2),
        ])

    @inlineCallbacks
    def test_get_inbound_multiple_pages(self):
        yield self.start_server()
        self.store.DEFAULT_MAX_RESULTS = 1
        batch_id = yield self.make_batch(('foo', 'bar'))
        msg1 = yield self.make_inbound(batch_id, 'føø')
        msg2 = yield self.make_inbound(batch_id, 'føø')
        resp = yield self.make_request('GET', batch_id, 'inbound.json')
        messages = map(
            json.loads, filter(None, resp.delivered_body.split('\n')))
        self.assertEqual(
            set([msg['message_id'] for msg in messages]),
            set([msg1['message_id'], msg2['message_id']]))

    @inlineCallbacks
    def test_disconnect_kills_server(self):
        """
        If the client connection is lost, we stop processing the request.

        This test is a bit hacky, because it has to muck about inside the
        resource in order to pause and resume at appropriate places.
        """
        yield self.start_server()

        from vumi.components.message_store_resource import InboundResource

        batch_id = yield self.make_batch(('foo', 'bar'))
        msgs = [(yield self.make_inbound(batch_id, 'føø'))
                for _ in range(6)]

        class PausingInboundResource(InboundResource):
            def __init__(self, *args, **kw):
                InboundResource.__init__(self, *args, **kw)
                self.pause_after = 3
                self.pause_d = Deferred()
                self.resume_d = Deferred()
                self.fetch = {}

            def _finish_fetching(self, msg):
                self.fetch[msg['message_id']].callback(msg['message_id'])
                return msg

            def get_message(self, message_store, message_id):
                self.fetch[message_id] = Deferred()
                d = succeed(None)
                if self.pause_after > 0:
                    self.pause_after -= 1
                else:
                    if not self.pause_d.called:
                        self.pause_d.callback(None)
                    d.addCallback(lambda _: self.resume_d)
                d.addCallback(lambda _: InboundResource.get_message(
                    self, message_store, message_id))
                d.addCallback(self._finish_fetching)
                return d

        res = PausingInboundResource(self.store, batch_id, JsonFormatter())
        site = Site(res)
        server = yield reactor.listenTCP(0, site, interface='127.0.0.1')
        self.add_cleanup(server.loseConnection)
        addr = server.getHost()
        url = 'http://%s:%s?concurrency=2' % (addr.host, addr.port)

        resp_d = http_request_full(method='GET', url=url)
        # Wait until we've processed some messages.
        yield res.pause_d
        # Kill the client connection.
        yield resp_d.cancel()
        # Continue processing messages.
        res.resume_d.callback(None)

        # This will fail because we've cancelled the request. We don't care
        # about the exception, so we swallow it and move on.
        yield resp_d.addErrback(lambda _: None)

        # Wait for all the in-progress loads to finish.
        fetched_msg_ids = yield gatherResults(res.fetch.values())

        sorted_message_ids = sorted(msg['message_id'] for msg in msgs)
        self.assertEqual(set(fetched_msg_ids), set(sorted_message_ids[:4]))

    @inlineCallbacks
    def test_get_inbound_for_time_range(self):
        yield self.start_server()
        batch_id = yield self.make_batch(('foo', 'bar'))
        mktime = lambda day: datetime(2014, 11, day, 12, 0, 0)
        yield self.make_inbound(batch_id, 'føø', timestamp=mktime(1))
        msg2 = yield self.make_inbound(batch_id, 'føø', timestamp=mktime(2))
        msg3 = yield self.make_inbound(batch_id, 'føø', timestamp=mktime(3))
        yield self.make_inbound(batch_id, 'føø', timestamp=mktime(4))
        resp = yield self.make_request(
            'GET', batch_id, 'inbound.json', start='2014-11-02 00:00:00',
            end='2014-11-04 00:00:00')
        messages = map(
            json.loads, filter(None, resp.delivered_body.split('\n')))
        self.assertEqual(
            set([msg['message_id'] for msg in messages]),
            set([msg2['message_id'], msg3['message_id']]))

    @inlineCallbacks
    def test_get_inbound_for_time_range_bad_args(self):
        yield self.start_server()
        batch_id = yield self.make_batch(('foo', 'bar'))

        resp = yield self.make_request(
            'GET', batch_id, 'inbound.json', start='foo')
        self.assertEqual(resp.code, 400)
        self.assertEqual(
            resp.delivered_body,
            "Invalid 'start' parameter: Unable to parse date string 'foo'")

        resp = yield self.make_request(
            'GET', batch_id, 'inbound.json', end='bar')
        self.assertEqual(resp.code, 400)
        self.assertEqual(
            resp.delivered_body,
            "Invalid 'end' parameter: Unable to parse date string 'bar'")

        url = '%s/%s/%s/%s?start=foo&start=bar' % (
            self.url, 'resource_path', batch_id, 'inbound.json')
        resp = yield http_request_full(method='GET', url=url)
        self.assertEqual(resp.code, 400)
        self.assertEqual(
            resp.delivered_body,
            "Invalid 'start' parameter: Too many values")

    @inlineCallbacks
    def test_get_inbound_for_time_range_no_start(self):
        yield self.start_server()
        batch_id = yield self.make_batch(('foo', 'bar'))
        mktime = lambda day: datetime(2014, 11, day, 12, 0, 0)
        msg1 = yield self.make_inbound(batch_id, 'føø', timestamp=mktime(1))
        msg2 = yield self.make_inbound(batch_id, 'føø', timestamp=mktime(2))
        msg3 = yield self.make_inbound(batch_id, 'føø', timestamp=mktime(3))
        yield self.make_inbound(batch_id, 'føø', timestamp=mktime(4))
        resp = yield self.make_request(
            'GET', batch_id, 'inbound.json', end='2014-11-04 00:00:00')
        messages = map(
            json.loads, filter(None, resp.delivered_body.split('\n')))
        self.assertEqual(
            set([msg['message_id'] for msg in messages]),
            set([msg1['message_id'], msg2['message_id'], msg3['message_id']]))

    @inlineCallbacks
    def test_get_inbound_for_time_range_no_end(self):
        yield self.start_server()
        batch_id = yield self.make_batch(('foo', 'bar'))
        mktime = lambda day: datetime(2014, 11, day, 12, 0, 0)
        yield self.make_inbound(batch_id, 'føø', timestamp=mktime(1))
        msg2 = yield self.make_inbound(batch_id, 'føø', timestamp=mktime(2))
        msg3 = yield self.make_inbound(batch_id, 'føø', timestamp=mktime(3))
        msg4 = yield self.make_inbound(batch_id, 'føø', timestamp=mktime(4))
        resp = yield self.make_request(
            'GET', batch_id, 'inbound.json', start='2014-11-02 00:00:00')
        messages = map(
            json.loads, filter(None, resp.delivered_body.split('\n')))
        self.assertEqual(
            set([msg['message_id'] for msg in messages]),
            set([msg2['message_id'], msg3['message_id'], msg4['message_id']]))

    @inlineCallbacks
    def test_get_inbound_csv_for_time_range(self):
        yield self.start_server()
        batch_id = yield self.make_batch(('foo', 'bar'))
        mktime = lambda day: datetime(2014, 11, day, 12, 0, 0)
        yield self.make_inbound(batch_id, 'føø', timestamp=mktime(1))
        msg2 = yield self.make_inbound(batch_id, 'føø', timestamp=mktime(2))
        msg3 = yield self.make_inbound(batch_id, 'føø', timestamp=mktime(3))
        yield self.make_inbound(batch_id, 'føø', timestamp=mktime(4))
        resp = yield self.make_request(
            'GET', batch_id, 'inbound.csv', start='2014-11-02 00:00:00',
            end='2014-11-04 00:00:00')
        rows = resp.delivered_body.split('\r\n')
        header, rows = rows[0], rows[1:-1]
        self.assertEqual(header, (
            "timestamp,message_id,to_addr,from_addr,in_reply_to,session_event,"
            "content,group"))
        self.assert_csv_rows(rows, [
            ("%(ts)s,%(id)s,9292,+41791234567,,,føø,", msg2),
            ("%(ts)s,%(id)s,9292,+41791234567,,,føø,", msg3),
        ])

    @inlineCallbacks
    def test_get_outbound_for_time_range(self):
        yield self.start_server()
        batch_id = yield self.make_batch(('foo', 'bar'))
        mktime = lambda day: datetime(2014, 11, day, 12, 0, 0)
        yield self.make_outbound(batch_id, 'føø', timestamp=mktime(1))
        msg2 = yield self.make_outbound(batch_id, 'føø', timestamp=mktime(2))
        msg3 = yield self.make_outbound(batch_id, 'føø', timestamp=mktime(3))
        yield self.make_outbound(batch_id, 'føø', timestamp=mktime(4))
        resp = yield self.make_request(
            'GET', batch_id, 'outbound.json', start='2014-11-02 00:00:00',
            end='2014-11-04 00:00:00')
        messages = map(
            json.loads, filter(None, resp.delivered_body.split('\n')))
        self.assertEqual(
            set([msg['message_id'] for msg in messages]),
            set([msg2['message_id'], msg3['message_id']]))

    @inlineCallbacks
    def test_get_outbound_for_time_range_bad_args(self):
        yield self.start_server()
        batch_id = yield self.make_batch(('foo', 'bar'))

        resp = yield self.make_request(
            'GET', batch_id, 'outbound.json', start='foo')
        self.assertEqual(resp.code, 400)
        self.assertEqual(
            resp.delivered_body,
            "Invalid 'start' parameter: Unable to parse date string 'foo'")

        resp = yield self.make_request(
            'GET', batch_id, 'outbound.json', end='bar')
        self.assertEqual(resp.code, 400)
        self.assertEqual(
            resp.delivered_body,
            "Invalid 'end' parameter: Unable to parse date string 'bar'")

        url = '%s/%s/%s/%s?start=foo&start=bar' % (
            self.url, 'resource_path', batch_id, 'outbound.json')
        resp = yield http_request_full(method='GET', url=url)
        self.assertEqual(resp.code, 400)
        self.assertEqual(
            resp.delivered_body,
            "Invalid 'start' parameter: Too many values")

    @inlineCallbacks
    def test_get_outbound_for_time_range_no_start(self):
        yield self.start_server()
        batch_id = yield self.make_batch(('foo', 'bar'))
        mktime = lambda day: datetime(2014, 11, day, 12, 0, 0)
        msg1 = yield self.make_outbound(batch_id, 'føø', timestamp=mktime(1))
        msg2 = yield self.make_outbound(batch_id, 'føø', timestamp=mktime(2))
        msg3 = yield self.make_outbound(batch_id, 'føø', timestamp=mktime(3))
        yield self.make_outbound(batch_id, 'føø', timestamp=mktime(4))
        resp = yield self.make_request(
            'GET', batch_id, 'outbound.json', end='2014-11-04 00:00:00')
        messages = map(
            json.loads, filter(None, resp.delivered_body.split('\n')))
        self.assertEqual(
            set([msg['message_id'] for msg in messages]),
            set([msg1['message_id'], msg2['message_id'], msg3['message_id']]))

    @inlineCallbacks
    def test_get_outbound_for_time_range_no_end(self):
        yield self.start_server()
        batch_id = yield self.make_batch(('foo', 'bar'))
        mktime = lambda day: datetime(2014, 11, day, 12, 0, 0)
        yield self.make_outbound(batch_id, 'føø', timestamp=mktime(1))
        msg2 = yield self.make_outbound(batch_id, 'føø', timestamp=mktime(2))
        msg3 = yield self.make_outbound(batch_id, 'føø', timestamp=mktime(3))
        msg4 = yield self.make_outbound(batch_id, 'føø', timestamp=mktime(4))
        resp = yield self.make_request(
            'GET', batch_id, 'outbound.json', start='2014-11-02 00:00:00')
        messages = map(
            json.loads, filter(None, resp.delivered_body.split('\n')))
        self.assertEqual(
            set([msg['message_id'] for msg in messages]),
            set([msg2['message_id'], msg3['message_id'], msg4['message_id']]))

    @inlineCallbacks
    def test_get_outbound_csv_for_time_range(self):
        yield self.start_server()
        batch_id = yield self.make_batch(('foo', 'bar'))
        mktime = lambda day: datetime(2014, 11, day, 12, 0, 0)
        yield self.make_outbound(batch_id, 'føø', timestamp=mktime(1))
        msg2 = yield self.make_outbound(batch_id, 'føø', timestamp=mktime(2))
        msg3 = yield self.make_outbound(batch_id, 'føø', timestamp=mktime(3))
        yield self.make_outbound(batch_id, 'føø', timestamp=mktime(4))
        resp = yield self.make_request(
            'GET', batch_id, 'outbound.csv', start='2014-11-02 00:00:00',
            end='2014-11-04 00:00:00')
        rows = resp.delivered_body.split('\r\n')
        header, rows = rows[0], rows[1:-1]
        self.assertEqual(header, (
            "timestamp,message_id,to_addr,from_addr,in_reply_to,session_event,"
            "content,group"))
        self.assert_csv_rows(rows, [
            ("%(ts)s,%(id)s,+41791234567,9292,,,føø,", msg2),
            ("%(ts)s,%(id)s,+41791234567,9292,,,føø,", msg3),
        ])

    @inlineCallbacks
    def test_get_events_for_time_range(self):
        yield self.start_server()
        batch_id = yield self.make_batch(('foo', 'bar'))
        mktime = lambda day: datetime(2014, 11, day, 12, 0, 0)
        yield self.make_ack(batch_id, timestamp=mktime(1))
        ack2 = yield self.make_ack(batch_id, timestamp=mktime(2))
        ack3 = yield self.make_ack(batch_id, timestamp=mktime(3))
        yield self.make_ack(batch_id, timestamp=mktime(4))
        resp = yield self.make_request(
            'GET', batch_id, 'events.json', start='2014-11-02 00:00:00',
            end='2014-11-04 00:00:00')
        events = map(
            json.loads, filter(None, resp.delivered_body.split('\n')))
        self.assertEqual(
            set([ev['event_id'] for ev in events]),
            set([ack2['event_id'], ack3['event_id']]))

vumi/components/tests/test_tagpool.py
# -*- coding: utf-8 -*-

"""Tests for vumi.components.tagpool."""

import json

from twisted.internet.defer import inlineCallbacks

from vumi.components.tagpool import TagpoolManager, TagpoolError
from vumi.tests.helpers import VumiTestCase, PersistenceHelper


class TestTxTagpoolManager(VumiTestCase):

    @inlineCallbacks
    def setUp(self):
        self.persistence_helper = self.add_helper(PersistenceHelper())
        self.redis = yield self.persistence_helper.get_redis_manager()
        yield self.redis._purge_all()  # Just in case
        self.tpm = TagpoolManager(self.redis)

    def pool_key_generator(self, pool):
        def tkey(x):
            return "tagpools:%s:%s" % (pool, x)
        return tkey
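
    # Redis key layout exercised by these tests (derived from the assertions
    # below):
    #
    #   tagpools:<pool>:free:list  - queue of free local tag names
    #   tagpools:<pool>:free:set   - set mirror of the free list
    #   tagpools:<pool>:inuse:set  - local tag names currently acquired
    #   tagpools:<pool>:metadata   - hash of JSON-encoded metadata fields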

    @inlineCallbacks
    def test_declare_tags(self):
        tag1, tag2 = ("poolA", "tag1"), ("poolA", "tag2")
        yield self.tpm.declare_tags([tag1, tag2])
        self.assertEqual((yield self.tpm.acquire_tag("poolA")), tag1)
        self.assertEqual((yield self.tpm.acquire_tag("poolA")), tag2)
        self.assertEqual((yield self.tpm.acquire_tag("poolA")), None)
        tag3 = ("poolA", "tag3")
        yield self.tpm.declare_tags([tag2, tag3])
        self.assertEqual((yield self.tpm.acquire_tag("poolA")), tag3)

    @inlineCallbacks
    def test_declare_unicode_tag(self):
        tag = (u"poöl", u"tág")
        yield self.tpm.declare_tags([tag])
        self.assertEqual((yield self.tpm.acquire_tag(tag[0])), tag)

    @inlineCallbacks
    def test_purge_pool(self):
        tag1, tag2 = ("poolA", "tag1"), ("poolA", "tag2")
        yield self.tpm.declare_tags([tag1, tag2])
        yield self.tpm.purge_pool('poolA')
        self.assertEqual((yield self.tpm.acquire_tag('poolA')), None)

    @inlineCallbacks
    def test_purge_unicode_pool(self):
        tag = (u"poöl", u"tág")
        yield self.tpm.declare_tags([tag])
        yield self.tpm.purge_pool(tag[0])
        self.assertEqual((yield self.tpm.acquire_tag(tag[0])), None)

    @inlineCallbacks
    def test_purge_inuse_pool(self):
        tag1, tag2 = ("poolA", "tag1"), ("poolA", "tag2")
        yield self.tpm.declare_tags([tag1, tag2])
        self.assertEqual((yield self.tpm.acquire_tag('poolA')), tag1)
        try:
            yield self.tpm.purge_pool('poolA')
        except TagpoolError:
            pass
        else:
            self.fail("Expected TagpoolError to be raised.")

    @inlineCallbacks
    def test_list_pools(self):
        tag1, tag2 = ("poolA", "tag1"), ("poolB", "tag2")
        yield self.tpm.declare_tags([tag1, tag2])
        self.assertEqual((yield self.tpm.list_pools()),
                         set(['poolA', 'poolB']))

    @inlineCallbacks
    def test_list_unicode_pool(self):
        tag = (u"poöl", u"tág")
        yield self.tpm.declare_tags([tag])
        self.assertEqual((yield self.tpm.list_pools()),
                         set([tag[0]]))

    @inlineCallbacks
    def test_acquire_tag(self):
        tkey = self.pool_key_generator("poolA")
        tag1, tag2 = ("poolA", "tag1"), ("poolA", "tag2")
        yield self.tpm.declare_tags([tag1, tag2])
        self.assertEqual((yield self.tpm.acquire_tag("poolA")), tag1)
        self.assertEqual((yield self.tpm.acquire_tag("poolB")), None)
        redis = self.redis
        self.assertEqual((yield redis.lrange(tkey("free:list"), 0, -1)),
                         ["tag2"])
        self.assertEqual((yield redis.smembers(tkey("free:set"))),
                         set(["tag2"]))
        self.assertEqual((yield redis.smembers(tkey("inuse:set"))),
                         set(["tag1"]))

    @inlineCallbacks
    def test_acquire_unicode_tag(self):
        tag = (u"poöl", u"tág")
        yield self.tpm.declare_tags([tag])
        self.assertEqual((yield self.tpm.acquire_tag(tag[0])), tag)
        self.assertEqual((yield self.tpm.acquire_tag(tag[0])), None)

    @inlineCallbacks
    def test_acquire_specific_tag(self):
        tkey = self.pool_key_generator("poolA")
        tags = [("poolA", "tag%d" % i) for i in range(10)]
        tag5 = tags[5]
        yield self.tpm.declare_tags(tags)
        self.assertEqual((yield self.tpm.acquire_specific_tag(tag5)), tag5)
        self.assertEqual((yield self.tpm.acquire_specific_tag(tag5)), None)
        free_local_tags = [t[1] for t in tags]
        free_local_tags.remove("tag5")
        redis = self.redis
        self.assertEqual((yield redis.lrange(tkey("free:list"), 0, -1)),
                         free_local_tags)
        self.assertEqual((yield redis.smembers(tkey("free:set"))),
                         set(free_local_tags))
        self.assertEqual((yield redis.smembers(tkey("inuse:set"))),
                         set(["tag5"]))

    @inlineCallbacks
    def test_acquire_specific_unicode_tag(self):
        tag = (u"poöl", u"tág")
        yield self.tpm.declare_tags([tag])
        self.assertEqual((yield self.tpm.acquire_specific_tag(tag)), tag)
        self.assertEqual((yield self.tpm.acquire_specific_tag(tag)), None)

    @inlineCallbacks
    def test_release_tag(self):
        tkey = self.pool_key_generator("poolA")
        tag1, tag2, tag3 = [("poolA", "tag%d" % i) for i in (1, 2, 3)]
        yield self.tpm.declare_tags([tag1, tag2, tag3])
        yield self.tpm.acquire_tag("poolA")
        yield self.tpm.acquire_tag("poolA")
        yield self.tpm.release_tag(tag1)
        redis = self.redis
        self.assertEqual((yield redis.lrange(tkey("free:list"), 0, -1)),
                         ["tag3", "tag1"])
        self.assertEqual((yield redis.smembers(tkey("free:set"))),
                         set(["tag1", "tag3"]))
        self.assertEqual((yield redis.smembers(tkey("inuse:set"))),
                         set(["tag2"]))

    @inlineCallbacks
    def test_release_unicode_tag(self):
        tag = (u"poöl", u"tág")
        yield self.tpm.declare_tags([tag])
        yield self.tpm.acquire_tag(tag[0])
        yield self.tpm.release_tag(tag)
        self.assertEqual((yield self.tpm.acquire_tag(tag[0])), tag)

    @inlineCallbacks
    def test_metadata(self):
        mkey = self.pool_key_generator("poolA")("metadata")
        metadata = {
            "transport_type": "sms",
            "default_msg_fields": {
                "transport_name": "sphex",
                "helper_metadata": {
                    "even_more_nested": "foo",
                },
            },
        }
        yield self.tpm.set_metadata("poolA", metadata)
        self.assertEqual((yield self.tpm.get_metadata("poolA")), metadata)
        tt_json = yield self.redis.hget(mkey, "transport_type")
        transport_type = json.loads(tt_json)
        self.assertEqual(transport_type, "sms")

        short_md = {"foo": "bar"}
        yield self.tpm.set_metadata("poolA", short_md)
        self.assertEqual((yield self.tpm.get_metadata("poolA")), short_md)

    @inlineCallbacks
    def test_metadata_for_unicode_pool_name(self):
        pool = u"poöl"
        metadata = {"foo": "bar"}
        yield self.tpm.set_metadata(pool, metadata)
        self.assertEqual((yield self.tpm.get_metadata(pool)),
                         metadata)

    @inlineCallbacks
    def test_unicode_metadata(self):
        metadata = {u"föo": u"báz"}
        yield self.tpm.set_metadata("pool", metadata)
        self.assertEqual((yield self.tpm.get_metadata("pool")),
                         metadata)

    def _check_reason(self, expected_owner, owner, reason, expected_data):
        self.assertEqual(expected_owner, owner)
        self.assertEqual(expected_owner, reason.pop('owner'))
        timestamp = reason.pop('timestamp')
        self.assertTrue(isinstance(timestamp, float))
        self.assertEqual(reason, expected_data)
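
    # The stored acquisition reason (exercised below) is the caller-supplied
    # data dict augmented with an 'owner' key and a float 'timestamp'.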

    @inlineCallbacks
    def test_acquired_by(self):
        tag = ["pool", "tag"]
        yield self.tpm.declare_tags([tag])
        yield self.tpm.acquire_tag(tag[0], "me", {"foo": "bar"})
        owner, reason = yield self.tpm.acquired_by(tag)
        self._check_reason("me", owner, reason, {"foo": "bar"})

    @inlineCallbacks
    def test_acquired_by_undeclared_tags(self):
        tag = ["pool", "tag"]
        owner, reason = yield self.tpm.acquired_by(tag)
        self.assertEqual(owner, None)
        self.assertEqual(reason, None)

    @inlineCallbacks
    def test_acquired_by_no_owner(self):
        tag = ["pool", "tag"]
        yield self.tpm.declare_tags([tag])
        yield self.tpm.acquire_tag(tag[0])
        owner, reason = yield self.tpm.acquired_by(tag)
        self._check_reason(None, owner, reason, {})

    @inlineCallbacks
    def test_acquired_by_unicode_owner(self):
        tag = ["pool", "tag"]
        yield self.tpm.declare_tags([tag])
        yield self.tpm.acquire_tag(tag[0], u"mé")
        owner, reason = yield self.tpm.acquired_by(tag)
        self._check_reason(u"mé", owner, reason, {})

    @inlineCallbacks
    def test_acquired_by_from_unicode_tag(self):
        tag = [u"poöl", u"tág"]
        yield self.tpm.declare_tags([tag])
        yield self.tpm.acquire_tag(tag[0], "me")
        owner, reason = yield self.tpm.acquired_by(tag)
        self._check_reason(u"me", owner, reason, {})

    @inlineCallbacks
    def test_acquired_by_after_using_specific_tag(self):
        tag = ["pool", "tag"]
        yield self.tpm.declare_tags([tag])
        yield self.tpm.acquire_specific_tag(tag, "me", {"foo": "bar"})
        owner, reason = yield self.tpm.acquired_by(tag)
        self._check_reason("me", owner, reason, {"foo": "bar"})

    @inlineCallbacks
    def test_owned_tags(self):
        tags = [["pool1", "tag1"], ["pool2", "tag2"]]
        yield self.tpm.declare_tags(tags)
        yield self.tpm.acquire_tag(tags[0][0], owner="me")
        my_tags = yield self.tpm.owned_tags("me")
        self.assertEqual(my_tags, [tags[0]])

    @inlineCallbacks
    def test_owned_tags_no_owner(self):
        tags = [["pool1", "tag1"], ["pool2", "tag2"]]
        yield self.tpm.declare_tags(tags)
        yield self.tpm.acquire_tag(tags[0][0])
        my_tags = yield self.tpm.owned_tags(None)
        self.assertEqual(my_tags, [tags[0]])

    @inlineCallbacks
    def test_owned_tags_unicode_owner(self):
        tags = [["pool1", "tag1"], ["pool2", "tag2"]]
        yield self.tpm.declare_tags(tags)
        yield self.tpm.acquire_tag(tags[0][0], owner=u"mé")
        my_tags = yield self.tpm.owned_tags(u"mé")
        self.assertEqual(my_tags, [tags[0]])

    @inlineCallbacks
    def test_owned_tags_unicode_tags(self):
        tags = [[u"poöl1", u"tág1"], [u"poöl2", u"tág2"]]
        yield self.tpm.declare_tags(tags)
        yield self.tpm.acquire_tag(tags[0][0], owner="me")
        my_tags = yield self.tpm.owned_tags(u"me")
        self.assertEqual(my_tags, [tags[0]])


class TestTagpoolManager(TestTxTagpoolManager):
    sync_persistence = True

vumi/components/tests/test_message_store_migrators.py
"""Tests for vumi.components.message_store_migrators."""

from twisted.internet.defer import inlineCallbacks

from vumi.message import format_vumi_date
from vumi.tests.helpers import (
    VumiTestCase, MessageHelper, PersistenceHelper, import_skip)

try:
    from vumi.components.tests.message_store_old_models import (
        OutboundMessageVNone, InboundMessageVNone, EventVNone, BatchVNone,
        OutboundMessageV1, InboundMessageV1, OutboundMessageV2,
        InboundMessageV2, OutboundMessageV3, InboundMessageV3, EventV1,
        OutboundMessageV4, InboundMessageV4)
    from vumi.components.message_store import (
        to_reverse_timestamp,
        OutboundMessage as OutboundMessageV5,
        InboundMessage as InboundMessageV5,
        Event as EventV2)
    riak_import_error = None
except ImportError, e:
    riak_import_error = e


def mws_value(msg_id, event, status):
    return "%s$%s$%s" % (msg_id, format_vumi_date(event['timestamp']), status)


def bwsr_value(batch_id, event, status):
    reverse_ts = to_reverse_timestamp(format_vumi_date(event['timestamp']))
    return "%s$%s$%s" % (batch_id, reverse_ts, status)


def bwt_value(batch_id, msg):
    return "%s$%s" % (batch_id, format_vumi_date(msg['timestamp']))


def bwa_in_value(batch_id, msg):
    return "%s$%s$%s" % (
        batch_id, format_vumi_date(msg['timestamp']), msg['from_addr'])


def bwa_out_value(batch_id, msg):
    return "%s$%s$%s" % (
        batch_id, format_vumi_date(msg['timestamp']), msg['to_addr'])


def bwar_in_value(batch_id, msg):
    reverse_ts = to_reverse_timestamp(format_vumi_date(msg['timestamp']))
    return "%s$%s$%s" % (batch_id, reverse_ts, msg['from_addr'])


def bwar_out_value(batch_id, msg):
    reverse_ts = to_reverse_timestamp(format_vumi_date(msg['timestamp']))
    return "%s$%s$%s" % (batch_id, reverse_ts, msg['to_addr'])


def batch_index(value):
    return ("batches_bin", value)


def bwt_index(value):
    return ("batches_with_timestamps_bin", value)


def bwa_index(value):
    return ("batches_with_addresses_bin", value)


def bwar_index(value):
    return ("batches_with_addresses_reverse_bin", value)


class TestMigratorBase(VumiTestCase):
    def setUp(self):
        self.persistence_helper = self.add_helper(
            PersistenceHelper(use_riak=True))
        if riak_import_error is not None:
            import_skip(riak_import_error, 'riak')
        self.manager = self.persistence_helper.get_riak_manager()
        self.add_cleanup(self.manager.close_manager)
        self.msg_helper = self.add_helper(MessageHelper())


class TestEventMigrator(TestMigratorBase):
    @inlineCallbacks
    def setUp(self):
        yield super(TestEventMigrator, self).setUp()
        self.event_vnone = self.manager.proxy(EventVNone)
        self.event_v1 = self.manager.proxy(EventV1)
        self.event_v2 = self.manager.proxy(EventV2)

    @inlineCallbacks
    def test_migrate_vnone_to_v1(self):
        """
        A vNone model can be migrated to v1.
        """
        msg = self.msg_helper.make_outbound("outbound")
        msg_id = msg["message_id"]
        event = self.msg_helper.make_ack(msg)
        old_record = self.event_vnone(
            event["event_id"], event=event, message=msg_id)
        yield old_record.save()

        new_record = yield self.event_v1.load(old_record.key)
        self.assertEqual(new_record.event, event)
        self.assertEqual(new_record.message.key, msg_id)

        # The migration doesn't set the new fields and indexes; that only
        # happens at save time.

        self.assertEqual(new_record.message_with_status, None)
        self.assertEqual(new_record._riak_object.get_indexes(), set([
            ("message_bin", msg_id),
        ]))

        yield new_record.save()
        self.assertEqual(
            new_record.message_with_status, mws_value(msg_id, event, "ack"))
        self.assertEqual(new_record._riak_object.get_indexes(), set([
            ("message_bin", msg_id),
            ("message_with_status_bin", mws_value(msg_id, event, "ack")),
        ]))

    @inlineCallbacks
    def test_reverse_migrate_v1_vnone(self):
        """
        A v1 model can be stored in a vNone-compatible way.
        """
        # Configure the manager to save the older message version.
        modelcls = self.event_v1._modelcls
        model_name = "%s.%s" % (modelcls.__module__, modelcls.__name__)
        self.manager.store_versions[model_name] = None

        msg = self.msg_helper.make_outbound("outbound")
        msg_id = msg["message_id"]
        event = self.msg_helper.make_ack(msg)
        new_record = self.event_v1(
            event["event_id"], event=event, message=msg_id)
        yield new_record.save()

        old_record = yield self.event_vnone.load(new_record.key)
        self.assertEqual(old_record.event, event)
        self.assertEqual(old_record.message.key, msg_id)

    @inlineCallbacks
    def test_migrate_vnone_to_v1_index_only_foreign_key(self):
        """
        A vNone model can be migrated to v1 even if it's old enough to still
        have index-only foreign keys.
        """
        msg = self.msg_helper.make_outbound("outbound")
        msg_id = msg["message_id"]
        event = self.msg_helper.make_ack(msg)
        old_record = self.event_vnone(
            event["event_id"], event=event, message=msg_id)

        # Remove the foreign key field from the data before saving it.
        old_record._riak_object.delete_data_field("message")
        yield old_record.save()

        new_record = yield self.event_v1.load(old_record.key)
        self.assertEqual(new_record.event, event)
        self.assertEqual(new_record.message.key, msg_id)

        # The migration doesn't set the new fields and indexes; that only
        # happens at save time.

        self.assertEqual(new_record.message_with_status, None)
        self.assertEqual(new_record._riak_object.get_indexes(), set([
            ("message_bin", msg_id),
        ]))

        yield new_record.save()
        self.assertEqual(
            new_record.message_with_status, mws_value(msg_id, event, "ack"))
        self.assertEqual(new_record._riak_object.get_indexes(), set([
            ("message_bin", msg_id),
            ("message_with_status_bin", mws_value(msg_id, event, "ack")),
        ]))

    @inlineCallbacks
    def test_migrate_v1_to_v2(self):
        """
        A v1 model can be migrated to v2, but the batches field will be empty.
        """
        msg = self.msg_helper.make_outbound("outbound")
        msg_id = msg["message_id"]
        event = self.msg_helper.make_ack(msg)
        old_record = self.event_v1(
            event["event_id"], event=event, message=msg_id)
        yield old_record.save()

        new_record = yield self.event_v2.load(old_record.key)
        self.assertEqual(new_record.event, event)
        self.assertEqual(new_record.message.key, msg_id)
        self.assertEqual(new_record.batches.keys(), [])
        self.assertEqual(new_record.message_with_status, None)
        self.assertEqual(set(new_record.batches_with_statuses_reverse), set())
        self.assertEqual(new_record._riak_object.get_indexes(), set([
            ("message_bin", msg_id),
            ("message_with_status_bin", mws_value(msg_id, event, "ack")),
        ]))

        # Some indexes are only added at save time.
        yield new_record.save()
        self.assertEqual(
            new_record.message_with_status, mws_value(msg_id, event, "ack"))
        self.assertEqual(set(new_record.batches_with_statuses_reverse), set())
        self.assertEqual(new_record._riak_object.get_indexes(), set([
            ("message_bin", msg_id),
            ("message_with_status_bin", mws_value(msg_id, event, "ack")),
        ]))

    @inlineCallbacks
    def test_reverse_migrate_v2_v1(self):
        """
        A v2 model can be stored in a v1-compatible way, and batch information
        is preserved.
        """
        # Configure the manager to save the older message version.
        modelcls = self.event_v2._modelcls
        model_name = "%s.%s" % (modelcls.__module__, modelcls.__name__)
        self.manager.store_versions[model_name] = 1

        msg = self.msg_helper.make_outbound("outbound")
        msg_id = msg["message_id"]
        event = self.msg_helper.make_ack(msg)
        new_record = self.event_v2(
            event["event_id"], event=event, message=msg_id)
        new_record.batches.add_key(u"batch-1")
        yield new_record.save()

        old_record = yield self.event_v1.load(new_record.key)
        self.assertEqual(old_record.event, event)
        self.assertEqual(old_record.message.key, msg_id)
        self.assertEqual(new_record.message_with_status, None)
        self.assertEqual(new_record._riak_object.get_indexes(), set([
            ("message_bin", msg_id),
            ("batches_bin", "batch-1"),
            ("message_with_status_bin", mws_value(msg_id, event, "ack")),
            ("batches_with_statuses_reverse_bin",
             bwsr_value("batch-1", event, "ack")),
        ]))

        # Some indexes are only added at save time.
        yield old_record.save()
        self.assertEqual(
            old_record.message_with_status, mws_value(msg_id, event, "ack"))
        self.assertEqual(old_record._riak_object.get_indexes(), set([
            ("message_bin", msg_id),
            ("batches_bin", "batch-1"),
            ("message_with_status_bin", mws_value(msg_id, event, "ack")),
            ("batches_with_statuses_reverse_bin",
             bwsr_value("batch-1", event, "ack")),
        ]))

        new2_record = yield self.event_v2.load(old_record.key)
        self.assertEqual(new2_record.event, event)
        self.assertEqual(new2_record.message.key, msg_id)
        self.assertEqual(new2_record.batches.keys(), [u"batch-1"])
        self.assertEqual(new2_record.message_with_status, None)
        self.assertEqual(set(new2_record.batches_with_statuses_reverse), set())


class TestOutboundMessageMigrator(TestMigratorBase):
    @inlineCallbacks
    def setUp(self):
        yield super(TestOutboundMessageMigrator, self).setUp()
        self.outbound_vnone = self.manager.proxy(OutboundMessageVNone)
        self.outbound_v1 = self.manager.proxy(OutboundMessageV1)
        self.outbound_v2 = self.manager.proxy(OutboundMessageV2)
        self.outbound_v3 = self.manager.proxy(OutboundMessageV3)
        self.outbound_v4 = self.manager.proxy(OutboundMessageV4)
        self.outbound_v5 = self.manager.proxy(OutboundMessageV5)
        self.batch_vnone = self.manager.proxy(BatchVNone)

    @inlineCallbacks
    def test_migrate_vnone_to_v1(self):
        msg = self.msg_helper.make_outbound("outbound")
        old_batch = self.batch_vnone(key=u"batch-1")
        old_record = self.outbound_vnone(
            msg["message_id"], msg=msg, batch=old_batch)
        yield old_record.save()
        new_record = yield self.outbound_v1.load(old_record.key)
        self.assertEqual(new_record.msg, msg)
        self.assertEqual(new_record.batches.keys(), [old_batch.key])

    @inlineCallbacks
    def test_migrate_vnone_to_v1_without_batch(self):
        msg = self.msg_helper.make_outbound("outbound")
        old_record = self.outbound_vnone(
            msg["message_id"], msg=msg, batch=None)
        yield old_record.save()
        new_record = yield self.outbound_v1.load(old_record.key)
        self.assertEqual(new_record.msg, msg)
        self.assertEqual(new_record.batches.keys(), [])

    @inlineCallbacks
    def test_migrate_v1_to_v2_no_batches(self):
        msg = self.msg_helper.make_outbound("outbound")
        old_record = self.outbound_v1(msg["message_id"], msg=msg)
        yield old_record.save()
        new_record = yield self.outbound_v2.load(old_record.key)
        self.assertEqual(new_record.msg, msg)
        self.assertEqual(new_record.batches.keys(), [])

        # The migration doesn't set the new fields and indexes; that only
        # happens at save time.
        self.assertEqual(set(new_record.batches_with_timestamps), set([]))
        self.assertEqual(new_record._riak_object.get_indexes(), set([]))

        yield new_record.save()
        self.assertEqual(set(new_record.batches_with_timestamps), set([]))
        self.assertEqual(new_record._riak_object.get_indexes(), set([]))

    @inlineCallbacks
    def test_migrate_v1_to_v2_one_batch(self):
        msg = self.msg_helper.make_outbound("outbound")
        old_batch = self.batch_vnone(key=u"batch-1")
        old_record = self.outbound_v1(msg["message_id"], msg=msg)
        old_record.batches.add_key(old_batch.key)
        yield old_record.save()
        new_record = yield self.outbound_v2.load(old_record.key)
        self.assertEqual(new_record.msg, msg)
        self.assertEqual(new_record.batches.keys(), [old_batch.key])

        # The migration doesn't set the new fields and indexes; that only
        # happens at save time.
        self.assertEqual(set(new_record.batches_with_timestamps), set([]))
        self.assertEqual(
            new_record._riak_object.get_indexes(),
            set([batch_index("batch-1")]))

        yield new_record.save()
        self.assertEqual(
            set(new_record.batches_with_timestamps),
            set([bwt_value("batch-1", msg)]))
        self.assertEqual(new_record._riak_object.get_indexes(), set(
            [batch_index("batch-1"), bwt_index(bwt_value("batch-1", msg))]))

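    # Note on the pattern asserted above: loading a record through a newer
    # model version migrates its data, but the new denormalised fields and
    # their Riak 2i entries are only computed when the record is saved
    # again. A bulk migration therefore has to re-save each record to
    # materialise the new indexes.
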
    @inlineCallbacks
    def test_migrate_v1_to_v2_two_batches(self):
        msg = self.msg_helper.make_outbound("outbound")
        batch_1 = self.batch_vnone(key=u"batch-1")
        batch_2 = self.batch_vnone(key=u"batch-2")
        old_record = self.outbound_v1(msg["message_id"], msg=msg)
        old_record.batches.add_key(batch_1.key)
        old_record.batches.add_key(batch_2.key)
        yield old_record.save()
        new_record = yield self.outbound_v2.load(old_record.key)
        self.assertEqual(new_record.msg, msg)
        self.assertEqual(new_record.batches.keys(), [batch_1.key, batch_2.key])

        # The migration doesn't set the new fields and indexes; that only
        # happens at save time.
        self.assertEqual(new_record._riak_object.get_indexes(), set([
            batch_index("batch-1"), batch_index("batch-2")]))

        yield new_record.save()
        self.assertEqual(new_record._riak_object.get_indexes(), set([
            batch_index("batch-1"),
            batch_index("batch-2"),
            bwt_index(bwt_value("batch-1", msg)),
            bwt_index(bwt_value("batch-2", msg)),
        ]))

    @inlineCallbacks
    def test_migrate_v2_to_v3_no_batches(self):
        msg = self.msg_helper.make_outbound("outbound")
        old_record = self.outbound_v2(msg["message_id"], msg=msg)
        yield old_record.save()
        new_record = yield self.outbound_v3.load(old_record.key)
        self.assertEqual(new_record.msg, msg)
        self.assertEqual(new_record.batches.keys(), [])

        # The migration doesn't set the new fields and indexes; that only
        # happens at save time.
        self.assertEqual(set(new_record.batches_with_timestamps), set([]))
        self.assertEqual(set(new_record.batches_with_addresses), set([]))
        self.assertEqual(new_record._riak_object.get_indexes(), set([]))

        yield new_record.save()
        self.assertEqual(set(new_record.batches_with_timestamps), set([]))
        self.assertEqual(set(new_record.batches_with_addresses), set([]))
        self.assertEqual(new_record._riak_object.get_indexes(), set([]))

    @inlineCallbacks
    def test_migrate_v2_to_v3_one_batch(self):
        msg = self.msg_helper.make_outbound("outbound")
        old_batch = self.batch_vnone(key=u"batch-1")
        old_record = self.outbound_v2(msg["message_id"], msg=msg)
        old_record.batches.add_key(old_batch.key)
        yield old_record.save()
        new_record = yield self.outbound_v3.load(old_record.key)
        self.assertEqual(new_record.msg, msg)
        self.assertEqual(new_record.batches.keys(), [old_batch.key])

        # The migration doesn't set the new fields and indexes; that only
        # happens at save time.
        self.assertEqual(set(new_record.batches_with_addresses), set([]))
        self.assertEqual(new_record._riak_object.get_indexes(), set([
            batch_index("batch-1"),
            bwt_index(bwt_value("batch-1", msg)),
        ]))

        yield new_record.save()
        self.assertEqual(
            set(new_record.batches_with_timestamps),
            set([bwt_value("batch-1", msg)]))
        self.assertEqual(
            set(new_record.batches_with_addresses),
            set([bwa_out_value("batch-1", msg)]))
        self.assertEqual(new_record._riak_object.get_indexes(), set([
            batch_index("batch-1"),
            bwt_index(bwt_value("batch-1", msg)),
            bwa_index(bwa_out_value("batch-1", msg)),
        ]))

    @inlineCallbacks
    def test_migrate_v2_to_v3_two_batches(self):
        msg = self.msg_helper.make_outbound("outbound")
        batch_1 = self.batch_vnone(key=u"batch-1")
        batch_2 = self.batch_vnone(key=u"batch-2")
        old_record = self.outbound_v2(msg["message_id"], msg=msg)
        old_record.batches.add_key(batch_1.key)
        old_record.batches.add_key(batch_2.key)
        yield old_record.save()
        new_record = yield self.outbound_v3.load(old_record.key)
        self.assertEqual(new_record.msg, msg)
        self.assertEqual(new_record.batches.keys(), [batch_1.key, batch_2.key])

        # The migration doesn't set the new fields and indexes; that only
        # happens at save time.
        self.assertEqual(new_record._riak_object.get_indexes(), set([
            batch_index("batch-1"),
            batch_index("batch-2"),
            bwt_index(bwt_value("batch-1", msg)),
            bwt_index(bwt_value("batch-2", msg)),
        ]))

        yield new_record.save()
        self.assertEqual(new_record._riak_object.get_indexes(), set([
            batch_index("batch-1"),
            batch_index("batch-2"),
            bwt_index(bwt_value("batch-1", msg)),
            bwt_index(bwt_value("batch-2", msg)),
            bwa_index(bwa_out_value("batch-1", msg)),
            bwa_index(bwa_out_value("batch-2", msg)),
        ]))

    @inlineCallbacks
    def test_reverse_migrate_v3_to_v2(self):
        """
        A v3 model can be stored in a v2-compatible way.
        """
        # Configure the manager to save the older message version.
        modelcls = self.outbound_v3._modelcls
        model_name = "%s.%s" % (modelcls.__module__, modelcls.__name__)
        self.manager.store_versions[model_name] = 2

        msg = self.msg_helper.make_outbound("outbound")
        batch_1 = self.batch_vnone(key=u"batch-1")
        batch_2 = self.batch_vnone(key=u"batch-2")
        new_record = self.outbound_v3(msg["message_id"], msg=msg)
        new_record.batches.add_key(batch_1.key)
        new_record.batches.add_key(batch_2.key)
        yield new_record.save()

        old_record = yield self.outbound_v2.load(new_record.key)
        self.assertEqual(old_record.msg, msg)
        self.assertEqual(old_record.batches.keys(), [batch_1.key, batch_2.key])

    @inlineCallbacks
    def test_migrate_v3_to_v4_no_batches(self):
        """
        A v3 model with no batches has no extra indexes when migrated to v4.
        """
        msg = self.msg_helper.make_outbound("outbound")
        old_record = self.outbound_v3(msg["message_id"], msg=msg)
        yield old_record.save()
        new_record = yield self.outbound_v4.load(old_record.key)
        self.assertEqual(new_record.msg, msg)
        self.assertEqual(new_record.batches.keys(), [])

        # The migration doesn't set the new fields and indexes; that only
        # happens at save time.
        self.assertEqual(set(new_record.batches_with_timestamps), set([]))
        self.assertEqual(set(new_record.batches_with_addresses), set([]))
        self.assertEqual(
            set(new_record.batches_with_addresses_reverse), set([]))
        self.assertEqual(new_record._riak_object.get_indexes(), set([]))

        yield new_record.save()
        self.assertEqual(set(new_record.batches_with_timestamps), set([]))
        self.assertEqual(set(new_record.batches_with_addresses), set([]))
        self.assertEqual(
            set(new_record.batches_with_addresses_reverse), set([]))
        self.assertEqual(new_record._riak_object.get_indexes(), set([]))

    @inlineCallbacks
    def test_migrate_v3_to_v4_one_batch(self):
        """
        A v3 model with one batch gets one extra index when migrated to v4.
        """
        msg = self.msg_helper.make_outbound("outbound")
        old_batch = self.batch_vnone(key=u"batch-1")
        old_record = self.outbound_v3(msg["message_id"], msg=msg)
        old_record.batches.add_key(old_batch.key)
        yield old_record.save()
        new_record = yield self.outbound_v4.load(old_record.key)
        self.assertEqual(new_record.msg, msg)
        self.assertEqual(new_record.batches.keys(), [old_batch.key])

        # The migration doesn't set the new fields and indexes; that only
        # happens at save time.
        self.assertEqual(
            set(new_record.batches_with_addresses_reverse), set([]))
        self.assertEqual(new_record._riak_object.get_indexes(), set([
            batch_index("batch-1"),
            bwt_index(bwt_value("batch-1", msg)),
            bwa_index(bwa_out_value("batch-1", msg)),
        ]))

        yield new_record.save()
        self.assertEqual(
            set(new_record.batches_with_timestamps),
            set([bwt_value("batch-1", msg)]))
        self.assertEqual(
            set(new_record.batches_with_addresses),
            set([bwa_out_value("batch-1", msg)]))
        self.assertEqual(
            set(new_record.batches_with_addresses_reverse),
            set([bwar_out_value("batch-1", msg)]))
        self.assertEqual(new_record._riak_object.get_indexes(), set([
            batch_index("batch-1"),
            bwt_index(bwt_value("batch-1", msg)),
            bwa_index(bwa_out_value("batch-1", msg)),
            bwar_index(bwar_out_value("batch-1", msg)),
        ]))

    @inlineCallbacks
    def test_migrate_v3_to_v4_two_batches(self):
        """
        A v3 model with two batches gets two extra indexes when migrated to v4.
        """
        msg = self.msg_helper.make_outbound("outbound")
        batch_1 = self.batch_vnone(key=u"batch-1")
        batch_2 = self.batch_vnone(key=u"batch-2")
        old_record = self.outbound_v3(msg["message_id"], msg=msg)
        old_record.batches.add_key(batch_1.key)
        old_record.batches.add_key(batch_2.key)
        yield old_record.save()
        new_record = yield self.outbound_v4.load(old_record.key)
        self.assertEqual(new_record.msg, msg)
        self.assertEqual(new_record.batches.keys(), [batch_1.key, batch_2.key])

        # The migration doesn't set the new fields and indexes; that only
        # happens at save time.
        self.assertEqual(new_record._riak_object.get_indexes(), set([
            batch_index("batch-1"),
            batch_index("batch-2"),
            bwt_index(bwt_value("batch-1", msg)),
            bwt_index(bwt_value("batch-2", msg)),
            bwa_index(bwa_out_value("batch-1", msg)),
            bwa_index(bwa_out_value("batch-2", msg)),
        ]))

        yield new_record.save()
        self.assertEqual(new_record._riak_object.get_indexes(), set([
            batch_index("batch-1"),
            batch_index("batch-2"),
            bwt_index(bwt_value("batch-1", msg)),
            bwt_index(bwt_value("batch-2", msg)),
            bwa_index(bwa_out_value("batch-1", msg)),
            bwa_index(bwa_out_value("batch-2", msg)),
            bwar_index(bwar_out_value("batch-1", msg)),
            bwar_index(bwar_out_value("batch-2", msg)),
        ]))

    @inlineCallbacks
    def test_reverse_migrate_v4_to_v3(self):
        """
        A v4 model can be stored in a v3-compatible way.
        """
        # Configure the manager to save the older message version.
        modelcls = self.outbound_v4._modelcls
        model_name = "%s.%s" % (modelcls.__module__, modelcls.__name__)
        self.manager.store_versions[model_name] = 3

        msg = self.msg_helper.make_outbound("outbound")
        batch_1 = self.batch_vnone(key=u"batch-1")
        batch_2 = self.batch_vnone(key=u"batch-2")
        new_record = self.outbound_v4(msg["message_id"], msg=msg)
        new_record.batches.add_key(batch_1.key)
        new_record.batches.add_key(batch_2.key)
        yield new_record.save()

        old_record = yield self.outbound_v3.load(new_record.key)
        self.assertEqual(old_record.msg, msg)
        self.assertEqual(old_record.batches.keys(), [batch_1.key, batch_2.key])

    @inlineCallbacks
    def test_migrate_v4_to_v5_no_batches(self):
        """
        A v4 model with no batches has no fewer indexes when migrated to v5.
        """
        msg = self.msg_helper.make_outbound("outbound")
        old_record = self.outbound_v4(msg["message_id"], msg=msg)
        yield old_record.save()
        new_record = yield self.outbound_v5.load(old_record.key)
        self.assertEqual(new_record.msg, msg)
        self.assertEqual(new_record.batches.keys(), [])

        # The migration doesn't set the new fields and indexes; that only
        # happens at save time.
        self.assertEqual(set(new_record.batches_with_addresses), set([]))
        self.assertEqual(
            set(new_record.batches_with_addresses_reverse), set([]))
        self.assertEqual(new_record._riak_object.get_indexes(), set([]))

        yield new_record.save()
        self.assertEqual(set(new_record.batches_with_addresses), set([]))
        self.assertEqual(
            set(new_record.batches_with_addresses_reverse), set([]))
        self.assertEqual(new_record._riak_object.get_indexes(), set([]))

    @inlineCallbacks
    def test_migrate_v4_to_v5_one_batch(self):
        """
        A v4 model with one batch gets one fewer index when migrated to v5.
        """
        msg = self.msg_helper.make_outbound("outbound")
        old_batch = self.batch_vnone(key=u"batch-1")
        old_record = self.outbound_v4(msg["message_id"], msg=msg)
        old_record.batches.add_key(old_batch.key)
        yield old_record.save()
        new_record = yield self.outbound_v5.load(old_record.key)
        self.assertEqual(new_record.msg, msg)
        self.assertEqual(new_record.batches.keys(), [old_batch.key])

        # The migration doesn't set the new fields and indexes; that only
        # happens at save time.
        self.assertEqual(
            set(new_record.batches_with_addresses_reverse), set([]))
        self.assertEqual(new_record._riak_object.get_indexes(), set([
            batch_index("batch-1"),
            bwa_index(bwa_out_value("batch-1", msg)),
            bwar_index(bwar_out_value("batch-1", msg)),
        ]))

        yield new_record.save()
        self.assertEqual(
            set(new_record.batches_with_addresses),
            set([bwa_out_value("batch-1", msg)]))
        self.assertEqual(
            set(new_record.batches_with_addresses_reverse),
            set([bwar_out_value("batch-1", msg)]))
        self.assertEqual(new_record._riak_object.get_indexes(), set([
            batch_index("batch-1"),
            bwa_index(bwa_out_value("batch-1", msg)),
            bwar_index(bwar_out_value("batch-1", msg)),
        ]))

    @inlineCallbacks
    def test_migrate_v4_to_v5_two_batches(self):
        """
        A v4 model with two batches gets two fewer indexes when migrated to v5.
        """
        msg = self.msg_helper.make_outbound("outbound")
        batch_1 = self.batch_vnone(key=u"batch-1")
        batch_2 = self.batch_vnone(key=u"batch-2")
        old_record = self.outbound_v4(msg["message_id"], msg=msg)
        old_record.batches.add_key(batch_1.key)
        old_record.batches.add_key(batch_2.key)
        yield old_record.save()
        new_record = yield self.outbound_v5.load(old_record.key)
        self.assertEqual(new_record.msg, msg)
        self.assertEqual(new_record.batches.keys(), [batch_1.key, batch_2.key])

        # The migration doesn't set the new fields and indexes; that only
        # happens at save time.
        self.assertEqual(new_record._riak_object.get_indexes(), set([
            batch_index("batch-1"),
            batch_index("batch-2"),
            bwa_index(bwa_out_value("batch-1", msg)),
            bwa_index(bwa_out_value("batch-2", msg)),
            bwar_index(bwar_out_value("batch-1", msg)),
            bwar_index(bwar_out_value("batch-2", msg)),
        ]))

        yield new_record.save()
        self.assertEqual(new_record._riak_object.get_indexes(), set([
            batch_index("batch-1"),
            batch_index("batch-2"),
            bwa_index(bwa_out_value("batch-1", msg)),
            bwa_index(bwa_out_value("batch-2", msg)),
            bwar_index(bwar_out_value("batch-1", msg)),
            bwar_index(bwar_out_value("batch-2", msg)),
        ]))

    @inlineCallbacks
    def test_reverse_migrate_v5_to_v4(self):
        """
        A v5 model can be stored in a v4-compatible way.
        """
        # Configure the manager to save the older message version.
        modelcls = self.outbound_v5._modelcls
        model_name = "%s.%s" % (modelcls.__module__, modelcls.__name__)
        self.manager.store_versions[model_name] = 4

        msg = self.msg_helper.make_outbound("outbound")
        batch_1 = self.batch_vnone(key=u"batch-1")
        batch_2 = self.batch_vnone(key=u"batch-2")
        new_record = self.outbound_v5(msg["message_id"], msg=msg)
        new_record.batches.add_key(batch_1.key)
        new_record.batches.add_key(batch_2.key)
        yield new_record.save()

        old_record = yield self.outbound_v4.load(new_record.key)
        self.assertEqual(old_record.msg, msg)
        self.assertEqual(old_record.batches.keys(), [batch_1.key, batch_2.key])

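# The inbound migrator tests below mirror the outbound tests above,
# differing only in building messages with make_inbound() and in using
# the bwa_in_value()/bwar_in_value() helpers for address-based index
# values instead of their *_out_* counterparts.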

class TestInboundMessageMigrator(TestMigratorBase):

    @inlineCallbacks
    def setUp(self):
        yield super(TestInboundMessageMigrator, self).setUp()
        self.inbound_vnone = self.manager.proxy(InboundMessageVNone)
        self.inbound_v1 = self.manager.proxy(InboundMessageV1)
        self.inbound_v2 = self.manager.proxy(InboundMessageV2)
        self.inbound_v3 = self.manager.proxy(InboundMessageV3)
        self.inbound_v4 = self.manager.proxy(InboundMessageV4)
        self.inbound_v5 = self.manager.proxy(InboundMessageV5)
        self.batch_vnone = self.manager.proxy(BatchVNone)

    @inlineCallbacks
    def test_migrate_vnone_to_v1(self):
        msg = self.msg_helper.make_inbound("inbound")
        old_batch = self.batch_vnone(key=u"batch-1")
        old_record = self.inbound_vnone(
            msg["message_id"], msg=msg, batch=old_batch)
        yield old_record.save()
        new_record = yield self.inbound_v1.load(old_record.key)
        self.assertEqual(new_record.msg, msg)
        self.assertEqual(new_record.batches.keys(), [old_batch.key])

    @inlineCallbacks
    def test_migrate_vnone_to_v1_without_batch(self):
        msg = self.msg_helper.make_inbound("inbound")
        old_record = self.inbound_vnone(
            msg["message_id"], msg=msg, batch=None)
        yield old_record.save()
        new_record = yield self.inbound_v1.load(old_record.key)
        self.assertEqual(new_record.msg, msg)
        self.assertEqual(new_record.batches.keys(), [])

    @inlineCallbacks
    def test_migrate_v1_to_v2_no_batches(self):
        msg = self.msg_helper.make_inbound("inbound")
        old_record = self.inbound_v1(msg["message_id"], msg=msg)
        yield old_record.save()
        new_record = yield self.inbound_v2.load(old_record.key)
        self.assertEqual(new_record.msg, msg)
        self.assertEqual(new_record.batches.keys(), [])

        # The migration doesn't set the new fields and indexes; that only
        # happens at save time.
        self.assertEqual(set(new_record.batches_with_timestamps), set([]))
        self.assertEqual(new_record._riak_object.get_indexes(), set([]))

        yield new_record.save()
        self.assertEqual(set(new_record.batches_with_timestamps), set([]))
        self.assertEqual(new_record._riak_object.get_indexes(), set([]))

    @inlineCallbacks
    def test_migrate_v1_to_v2_one_batch(self):
        msg = self.msg_helper.make_inbound("inbound")
        old_batch = self.batch_vnone(key=u"batch-1")
        old_record = self.inbound_v1(msg["message_id"], msg=msg)
        old_record.batches.add_key(old_batch.key)
        yield old_record.save()
        new_record = yield self.inbound_v2.load(old_record.key)
        self.assertEqual(new_record.msg, msg)
        self.assertEqual(new_record.batches.keys(), [old_batch.key])

        # The migration doesn't set the new fields and indexes; that only
        # happens at save time.
        self.assertEqual(set(new_record.batches_with_timestamps), set([]))
        self.assertEqual(
            new_record._riak_object.get_indexes(),
            set([batch_index("batch-1")]))

        yield new_record.save()
        self.assertEqual(
            set(new_record.batches_with_timestamps),
            set([bwt_value("batch-1", msg)]))
        self.assertEqual(new_record._riak_object.get_indexes(), set(
            [batch_index("batch-1"), bwt_index(bwt_value("batch-1", msg))]))

    @inlineCallbacks
    def test_migrate_v1_to_v2_two_batches(self):
        msg = self.msg_helper.make_inbound("inbound")
        batch_1 = self.batch_vnone(key=u"batch-1")
        batch_2 = self.batch_vnone(key=u"batch-2")
        old_record = self.inbound_v1(msg["message_id"], msg=msg)
        old_record.batches.add_key(batch_1.key)
        old_record.batches.add_key(batch_2.key)
        yield old_record.save()
        new_record = yield self.inbound_v2.load(old_record.key)
        self.assertEqual(new_record.msg, msg)
        self.assertEqual(new_record.batches.keys(), [batch_1.key, batch_2.key])

        # The migration doesn't set the new fields and indexes; that only
        # happens at save time.
        self.assertEqual(new_record._riak_object.get_indexes(), set([
            batch_index("batch-1"), batch_index("batch-2")]))

        yield new_record.save()
        self.assertEqual(new_record._riak_object.get_indexes(), set([
            batch_index("batch-1"),
            batch_index("batch-2"),
            bwt_index(bwt_value("batch-1", msg)),
            bwt_index(bwt_value("batch-2", msg)),
        ]))

    @inlineCallbacks
    def test_migrate_v2_to_v3_no_batches(self):
        msg = self.msg_helper.make_inbound("inbound")
        old_record = self.inbound_v2(msg["message_id"], msg=msg)
        yield old_record.save()
        new_record = yield self.inbound_v3.load(old_record.key)
        self.assertEqual(new_record.msg, msg)
        self.assertEqual(new_record.batches.keys(), [])

        # The migration doesn't set the new fields and indexes; that only
        # happens at save time.
        self.assertEqual(set(new_record.batches_with_timestamps), set([]))
        self.assertEqual(set(new_record.batches_with_addresses), set([]))
        self.assertEqual(new_record._riak_object.get_indexes(), set([]))

        yield new_record.save()
        self.assertEqual(set(new_record.batches_with_timestamps), set([]))
        self.assertEqual(set(new_record.batches_with_addresses), set([]))
        self.assertEqual(new_record._riak_object.get_indexes(), set([]))

    @inlineCallbacks
    def test_migrate_v2_to_v3_one_batch(self):
        msg = self.msg_helper.make_inbound("inbound")
        old_batch = self.batch_vnone(key=u"batch-1")
        old_record = self.inbound_v2(msg["message_id"], msg=msg)
        old_record.batches.add_key(old_batch.key)
        yield old_record.save()
        new_record = yield self.inbound_v3.load(old_record.key)
        self.assertEqual(new_record.msg, msg)
        self.assertEqual(new_record.batches.keys(), [old_batch.key])

        # The migration doesn't set the new fields and indexes; that only
        # happens at save time.
        self.assertEqual(set(new_record.batches_with_addresses), set([]))
        self.assertEqual(new_record._riak_object.get_indexes(), set([
            batch_index("batch-1"),
            bwt_index(bwt_value("batch-1", msg)),
        ]))

        yield new_record.save()
        self.assertEqual(
            set(new_record.batches_with_timestamps),
            set([bwt_value("batch-1", msg)]))
        self.assertEqual(
            set(new_record.batches_with_addresses),
            set([bwa_in_value("batch-1", msg)]))
        self.assertEqual(new_record._riak_object.get_indexes(), set([
            batch_index("batch-1"),
            bwt_index(bwt_value("batch-1", msg)),
            bwa_index(bwa_in_value("batch-1", msg)),
        ]))

    @inlineCallbacks
    def test_migrate_v2_to_v3_two_batches(self):
        msg = self.msg_helper.make_inbound("inbound")
        batch_1 = self.batch_vnone(key=u"batch-1")
        batch_2 = self.batch_vnone(key=u"batch-2")
        old_record = self.inbound_v2(msg["message_id"], msg=msg)
        old_record.batches.add_key(batch_1.key)
        old_record.batches.add_key(batch_2.key)
        yield old_record.save()
        new_record = yield self.inbound_v3.load(old_record.key)
        self.assertEqual(new_record.msg, msg)
        self.assertEqual(new_record.batches.keys(), [batch_1.key, batch_2.key])

        # The migration doesn't set the new fields and indexes; that only
        # happens at save time.
        self.assertEqual(new_record._riak_object.get_indexes(), set([
            batch_index("batch-1"),
            batch_index("batch-2"),
            bwt_index(bwt_value("batch-1", msg)),
            bwt_index(bwt_value("batch-2", msg)),
        ]))

        yield new_record.save()
        self.assertEqual(new_record._riak_object.get_indexes(), set([
            batch_index("batch-1"),
            batch_index("batch-2"),
            bwt_index(bwt_value("batch-1", msg)),
            bwt_index(bwt_value("batch-2", msg)),
            bwa_index(bwa_in_value("batch-1", msg)),
            bwa_index(bwa_in_value("batch-2", msg)),
        ]))

    @inlineCallbacks
    def test_reverse_migrate_v3_to_v2(self):
        """
        A v3 model can be stored in a v2-compatible way.
        """
        # Configure the manager to save the older message version.
        modelcls = self.inbound_v3._modelcls
        model_name = "%s.%s" % (modelcls.__module__, modelcls.__name__)
        self.manager.store_versions[model_name] = 2

        msg = self.msg_helper.make_inbound("inbound")
        batch_1 = self.batch_vnone(key=u"batch-1")
        batch_2 = self.batch_vnone(key=u"batch-2")
        new_record = self.inbound_v3(msg["message_id"], msg=msg)
        new_record.batches.add_key(batch_1.key)
        new_record.batches.add_key(batch_2.key)
        yield new_record.save()

        old_record = yield self.inbound_v2.load(new_record.key)
        self.assertEqual(old_record.msg, msg)
        self.assertEqual(old_record.batches.keys(), [batch_1.key, batch_2.key])

    @inlineCallbacks
    def test_migrate_v3_to_v4_no_batches(self):
        """
        A v3 model with no batches gets no extra indexes when migrated to v4.
        """
        msg = self.msg_helper.make_inbound("inbound")
        old_record = self.inbound_v3(msg["message_id"], msg=msg)
        yield old_record.save()
        new_record = yield self.inbound_v4.load(old_record.key)
        self.assertEqual(new_record.msg, msg)
        self.assertEqual(new_record.batches.keys(), [])

        # The migration doesn't set the new fields and indexes; that only
        # happens at save time.
        self.assertEqual(set(new_record.batches_with_timestamps), set([]))
        self.assertEqual(set(new_record.batches_with_addresses), set([]))
        self.assertEqual(
            set(new_record.batches_with_addresses_reverse), set([]))
        self.assertEqual(new_record._riak_object.get_indexes(), set([]))

        yield new_record.save()
        self.assertEqual(set(new_record.batches_with_timestamps), set([]))
        self.assertEqual(set(new_record.batches_with_addresses), set([]))
        self.assertEqual(
            set(new_record.batches_with_addresses_reverse), set([]))
        self.assertEqual(new_record._riak_object.get_indexes(), set([]))

    @inlineCallbacks
    def test_migrate_v3_to_v4_one_batch(self):
        """
        A v3 model with one batch gets one extra index when migrated to v4.
        """
        msg = self.msg_helper.make_inbound("inbound")
        old_batch = self.batch_vnone(key=u"batch-1")
        old_record = self.inbound_v3(msg["message_id"], msg=msg)
        old_record.batches.add_key(old_batch.key)
        yield old_record.save()
        new_record = yield self.inbound_v4.load(old_record.key)
        self.assertEqual(new_record.msg, msg)
        self.assertEqual(new_record.batches.keys(), [old_batch.key])

        # The migration doesn't set the new fields and indexes; that only
        # happens at save time.
        self.assertEqual(
            set(new_record.batches_with_addresses_reverse), set([]))
        self.assertEqual(new_record._riak_object.get_indexes(), set([
            batch_index("batch-1"),
            bwt_index(bwt_value("batch-1", msg)),
            bwa_index(bwa_in_value("batch-1", msg)),
        ]))

        yield new_record.save()
        self.assertEqual(
            set(new_record.batches_with_timestamps),
            set([bwt_value("batch-1", msg)]))
        self.assertEqual(
            set(new_record.batches_with_addresses),
            set([bwa_in_value("batch-1", msg)]))
        self.assertEqual(
            set(new_record.batches_with_addresses_reverse),
            set([bwar_in_value("batch-1", msg)]))
        self.assertEqual(new_record._riak_object.get_indexes(), set([
            batch_index("batch-1"),
            bwt_index(bwt_value("batch-1", msg)),
            bwa_index(bwa_in_value("batch-1", msg)),
            bwar_index(bwar_in_value("batch-1", msg)),
        ]))

    @inlineCallbacks
    def test_migrate_v3_to_v4_two_batches(self):
        """
        A v3 model with two batches gets two extra indexes when migrated to v4.
        """
        msg = self.msg_helper.make_inbound("inbound")
        batch_1 = self.batch_vnone(key=u"batch-1")
        batch_2 = self.batch_vnone(key=u"batch-2")
        old_record = self.inbound_v3(msg["message_id"], msg=msg)
        old_record.batches.add_key(batch_1.key)
        old_record.batches.add_key(batch_2.key)
        yield old_record.save()
        new_record = yield self.inbound_v4.load(old_record.key)
        self.assertEqual(new_record.msg, msg)
        self.assertEqual(new_record.batches.keys(), [batch_1.key, batch_2.key])

        # The migration doesn't set the new fields and indexes; that only
        # happens at save time.
        self.assertEqual(new_record._riak_object.get_indexes(), set([
            batch_index("batch-1"),
            batch_index("batch-2"),
            bwt_index(bwt_value("batch-1", msg)),
            bwt_index(bwt_value("batch-2", msg)),
            bwa_index(bwa_in_value("batch-1", msg)),
            bwa_index(bwa_in_value("batch-2", msg)),
        ]))

        yield new_record.save()
        self.assertEqual(new_record._riak_object.get_indexes(), set([
            batch_index("batch-1"),
            batch_index("batch-2"),
            bwt_index(bwt_value("batch-1", msg)),
            bwt_index(bwt_value("batch-2", msg)),
            bwa_index(bwa_in_value("batch-1", msg)),
            bwa_index(bwa_in_value("batch-2", msg)),
            bwar_index(bwar_in_value("batch-1", msg)),
            bwar_index(bwar_in_value("batch-2", msg)),
        ]))

    @inlineCallbacks
    def test_reverse_migrate_v4_to_v3(self):
        """
        A v4 model can be stored in a v3-compatible way.
        """
        # Configure the manager to save the older message version.
        modelcls = self.inbound_v4._modelcls
        model_name = "%s.%s" % (modelcls.__module__, modelcls.__name__)
        self.manager.store_versions[model_name] = 3

        msg = self.msg_helper.make_inbound("inbound")
        batch_1 = self.batch_vnone(key=u"batch-1")
        batch_2 = self.batch_vnone(key=u"batch-2")
        new_record = self.inbound_v4(msg["message_id"], msg=msg)
        new_record.batches.add_key(batch_1.key)
        new_record.batches.add_key(batch_2.key)
        yield new_record.save()

        old_record = yield self.inbound_v3.load(new_record.key)
        self.assertEqual(old_record.msg, msg)
        self.assertEqual(old_record.batches.keys(), [batch_1.key, batch_2.key])

    @inlineCallbacks
    def test_migrate_v4_to_v5_no_batches(self):
        """
        A v4 model with no batches gets no fewer indexes when migrated to v5.
        """
        msg = self.msg_helper.make_inbound("inbound")
        old_record = self.inbound_v4(msg["message_id"], msg=msg)
        yield old_record.save()
        new_record = yield self.inbound_v5.load(old_record.key)
        self.assertEqual(new_record.msg, msg)
        self.assertEqual(new_record.batches.keys(), [])

        # The migration doesn't set the new fields and indexes; that only
        # happens at save time.
        self.assertEqual(set(new_record.batches_with_addresses), set([]))
        self.assertEqual(
            set(new_record.batches_with_addresses_reverse), set([]))
        self.assertEqual(new_record._riak_object.get_indexes(), set([]))

        yield new_record.save()
        self.assertEqual(set(new_record.batches_with_addresses), set([]))
        self.assertEqual(
            set(new_record.batches_with_addresses_reverse), set([]))
        self.assertEqual(new_record._riak_object.get_indexes(), set([]))

    @inlineCallbacks
    def test_migrate_v4_to_v5_one_batch(self):
        """
        A v4 model with one batch gets one fewer index when migrated to v5.
        """
        msg = self.msg_helper.make_inbound("inbound")
        old_batch = self.batch_vnone(key=u"batch-1")
        old_record = self.inbound_v4(msg["message_id"], msg=msg)
        old_record.batches.add_key(old_batch.key)
        yield old_record.save()
        new_record = yield self.inbound_v5.load(old_record.key)
        self.assertEqual(new_record.msg, msg)
        self.assertEqual(new_record.batches.keys(), [old_batch.key])

        # The migration doesn't set the new fields and indexes; that only
        # happens at save time.
        self.assertEqual(
            set(new_record.batches_with_addresses_reverse), set([]))
        self.assertEqual(new_record._riak_object.get_indexes(), set([
            batch_index("batch-1"),
            bwa_index(bwa_in_value("batch-1", msg)),
            bwar_index(bwar_in_value("batch-1", msg)),
        ]))

        yield new_record.save()
        self.assertEqual(
            set(new_record.batches_with_addresses),
            set([bwa_in_value("batch-1", msg)]))
        self.assertEqual(
            set(new_record.batches_with_addresses_reverse),
            set([bwar_in_value("batch-1", msg)]))
        self.assertEqual(new_record._riak_object.get_indexes(), set([
            batch_index("batch-1"),
            bwa_index(bwa_in_value("batch-1", msg)),
            bwar_index(bwar_in_value("batch-1", msg)),
        ]))

    @inlineCallbacks
    def test_migrate_v4_to_v5_two_batches(self):
        """
        A v4 model with two batches gets two fewer indexes when migrated to v5.
        """
        msg = self.msg_helper.make_inbound("inbound")
        batch_1 = self.batch_vnone(key=u"batch-1")
        batch_2 = self.batch_vnone(key=u"batch-2")
        old_record = self.inbound_v4(msg["message_id"], msg=msg)
        old_record.batches.add_key(batch_1.key)
        old_record.batches.add_key(batch_2.key)
        yield old_record.save()
        new_record = yield self.inbound_v5.load(old_record.key)
        self.assertEqual(new_record.msg, msg)
        self.assertEqual(new_record.batches.keys(), [batch_1.key, batch_2.key])

        # The migration doesn't set the new fields and indexes; that only
        # happens at save time.
        self.assertEqual(new_record._riak_object.get_indexes(), set([
            batch_index("batch-1"),
            batch_index("batch-2"),
            bwa_index(bwa_in_value("batch-1", msg)),
            bwa_index(bwa_in_value("batch-2", msg)),
            bwar_index(bwar_in_value("batch-1", msg)),
            bwar_index(bwar_in_value("batch-2", msg)),
        ]))

        yield new_record.save()
        self.assertEqual(new_record._riak_object.get_indexes(), set([
            batch_index("batch-1"),
            batch_index("batch-2"),
            bwa_index(bwa_in_value("batch-1", msg)),
            bwa_index(bwa_in_value("batch-2", msg)),
            bwar_index(bwar_in_value("batch-1", msg)),
            bwar_index(bwar_in_value("batch-2", msg)),
        ]))

    @inlineCallbacks
    def test_reverse_migrate_v5_to_v4(self):
        """
        A v5 model can be stored in a v4-compatible way.
        """
        # Configure the manager to save the older message version.
        modelcls = self.inbound_v5._modelcls
        model_name = "%s.%s" % (modelcls.__module__, modelcls.__name__)
        self.manager.store_versions[model_name] = 4

        msg = self.msg_helper.make_inbound("inbound")
        batch_1 = self.batch_vnone(key=u"batch-1")
        batch_2 = self.batch_vnone(key=u"batch-2")
        new_record = self.inbound_v5(msg["message_id"], msg=msg)
        new_record.batches.add_key(batch_1.key)
        new_record.batches.add_key(batch_2.key)
        yield new_record.save()

        old_record = yield self.inbound_v4.load(new_record.key)
        self.assertEqual(old_record.msg, msg)
        self.assertEqual(old_record.batches.keys(), [batch_1.key, batch_2.key])
PK=JG!vumi/components/tests/__init__.pyPK=JGl%%,vumi/components/tests/test_window_manager.py
from twisted.internet.defer import inlineCallbacks
from twisted.internet.task import Clock

from vumi.components.window_manager import WindowManager, WindowException
from vumi.tests.helpers import VumiTestCase, PersistenceHelper


class TestWindowManager(VumiTestCase):

    @inlineCallbacks
    def setUp(self):
        self.persistence_helper = self.add_helper(PersistenceHelper())
        redis = yield self.persistence_helper.get_redis_manager()
        self.window_id = 'window_id'

        # Patch the clock so we can control time
        self.clock = Clock()
        self.patch(WindowManager, 'get_clock', lambda _: self.clock)

        self.wm = WindowManager(redis, window_size=10, flight_lifetime=10)
        self.add_cleanup(self.wm.stop)
        yield self.wm.create_window(self.window_id)
        self.redis = self.wm.redis

    @inlineCallbacks
    def test_windows(self):
        windows = yield self.wm.get_windows()
        self.assertTrue(self.window_id in windows)

    def test_strict_window_recreation(self):
        return self.assertFailure(
            self.wm.create_window(self.window_id, strict=True),
            WindowException)

    @inlineCallbacks
    def test_window_recreation(self):
        orig_clock_time = self.clock.seconds()
        clock_time = yield self.wm.create_window(self.window_id)
        self.assertEqual(clock_time, orig_clock_time)

    @inlineCallbacks
    def test_window_removal(self):
        yield self.wm.add(self.window_id, 1)
        yield self.assertFailure(self.wm.remove_window(self.window_id),
            WindowException)
        key = yield self.wm.get_next_key(self.window_id)
        item = yield self.wm.get_data(self.window_id, key)
        self.assertEqual(item, 1)
        self.assertEqual((yield self.wm.remove_window(self.window_id)), None)

    @inlineCallbacks
    def test_adding_to_window(self):
        for i in range(10):
            yield self.wm.add(self.window_id, i)
        window_key = self.wm.window_key(self.window_id)
        window_members = yield self.redis.llen(window_key)
        self.assertEqual(window_members, 10)

    @inlineCallbacks
    def test_fetching_from_window(self):
        for i in range(12):
            yield self.wm.add(self.window_id, i)

        flight_keys = []
        for i in range(10):
            flight_key = yield self.wm.get_next_key(self.window_id)
            self.assertTrue(flight_key)
            flight_keys.append(flight_key)

        out_of_window_flight = yield self.wm.get_next_key(self.window_id)
        self.assertEqual(out_of_window_flight, None)

        # We should get data out in the order we put it in
        for i, flight_key in enumerate(flight_keys):
            data = yield self.wm.get_data(self.window_id, flight_key)
            self.assertEqual(data, i)

        # Removing one should allow for space for the next to fill up
        yield self.wm.remove_key(self.window_id, flight_keys[0])
        next_flight_key = yield self.wm.get_next_key(self.window_id)
        self.assertTrue(next_flight_key)

    @inlineCallbacks
    def test_set_and_external_id(self):
        yield self.wm.set_external_id(self.window_id, "flight_key",
                                      "external_id")
        self.assertEqual(
            (yield self.wm.get_external_id(self.window_id, "flight_key")),
            "external_id")
        self.assertEqual(
            (yield self.wm.get_internal_id(self.window_id, "external_id")),
            "flight_key")

    @inlineCallbacks
    def test_remove_key_removes_external_and_internal_id(self):
        yield self.wm.set_external_id(self.window_id, "flight_key",
                                      "external_id")
        yield self.wm.remove_key(self.window_id, "flight_key")
        self.assertEqual(
            (yield self.wm.get_external_id(self.window_id, "flight_key")),
            None)
        self.assertEqual(
            (yield self.wm.get_internal_id(self.window_id, "external_id")),
            None)

    @inlineCallbacks
    def assert_count_waiting(self, window_id, amount):
        self.assertEqual((yield self.wm.count_waiting(window_id)), amount)

    @inlineCallbacks
    def assert_expired_keys(self, window_id, amount):
        # Stuff has taken too long, so we should get `amount` expired keys
        expired_keys = yield self.wm.get_expired_flight_keys(window_id)
        self.assertEqual(len(expired_keys), amount)

    @inlineCallbacks
    def assert_in_flight(self, window_id, amount):
        self.assertEqual((yield self.wm.count_in_flight(window_id)),
            amount)

    @inlineCallbacks
    def slide_window(self, limit=10):
        for i in range(limit):
            yield self.wm.get_next_key(self.window_id)

    @inlineCallbacks
    def test_expiry_of_acks(self):

        def mock_clock_time(self):
            return self._clocktime

        self.patch(WindowManager, 'get_clocktime', mock_clock_time)
        self.wm._clocktime = 0

        for i in range(30):
            yield self.wm.add(self.window_id, i)

        # We're manually setting the clock instead of using clock.advance()
        # so that we can wait for the deferreds to finish before continuing
        # to the next clear_expired_flight_keys run, since LoopingCall()
        # will only fire again once the previous run has completed.
        yield self.slide_window()
        self.wm._clocktime = 10
        yield self.wm.clear_expired_flight_keys()
        yield self.assert_expired_keys(self.window_id, 10)

        yield self.slide_window()
        self.wm._clocktime = 20
        yield self.wm.clear_expired_flight_keys()
        yield self.assert_expired_keys(self.window_id, 20)

        yield self.slide_window()
        self.wm._clocktime = 30
        yield self.wm.clear_expired_flight_keys()
        yield self.assert_expired_keys(self.window_id, 30)

        yield self.assert_in_flight(self.window_id, 0)
        yield self.assert_count_waiting(self.window_id, 0)

    @inlineCallbacks
    def test_monitor_windows(self):
        yield self.wm.remove_window(self.window_id)

        window_ids = ['window_id_1', 'window_id_2']
        for window_id in window_ids:
            yield self.wm.create_window(window_id)
            for i in range(20):
                yield self.wm.add(window_id, i)

        key_callbacks = {}

        def callback(window_id, key):
            key_callbacks.setdefault(window_id, []).append(key)

        cleanup_callbacks = []

        def cleanup_callback(window_id):
            cleanup_callbacks.append(window_id)

        yield self.wm._monitor_windows(callback, False)

        self.assertEqual(set(key_callbacks.keys()), set(window_ids))
        self.assertEqual(len(key_callbacks.values()[0]), 10)
        self.assertEqual(len(key_callbacks.values()[1]), 10)

        yield self.wm._monitor_windows(callback, False)

        # Nothing should've changed since we haven't removed anything.
        self.assertEqual(len(key_callbacks.values()[0]), 10)
        self.assertEqual(len(key_callbacks.values()[1]), 10)

        for window_id, keys in key_callbacks.items():
            for key in keys:
                yield self.wm.remove_key(window_id, key)

        yield self.wm._monitor_windows(callback, False)
        # Everything should've been processed now
        self.assertEqual(len(key_callbacks.values()[0]), 20)
        self.assertEqual(len(key_callbacks.values()[1]), 20)

        # Now run again but cleanup the empty windows
        self.assertEqual(set((yield self.wm.get_windows())), set(window_ids))
        for window_id, keys in key_callbacks.items():
            for key in keys:
                yield self.wm.remove_key(window_id, key)

        yield self.wm._monitor_windows(callback, True, cleanup_callback)
        self.assertEqual(len(key_callbacks.values()[0]), 20)
        self.assertEqual(len(key_callbacks.values()[1]), 20)
        self.assertEqual((yield self.wm.get_windows()), [])
        self.assertEqual(set(cleanup_callbacks), set(window_ids))

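# Putting the WindowManager calls exercised above together: a minimal,
# illustrative consumer loop (a sketch, not part of the WindowManager
# API) that claims keys, processes their data and releases them:

@inlineCallbacks
def _drain_window(wm, window_id):
    while True:
        flight_key = yield wm.get_next_key(window_id)
        if flight_key is None:
            break  # window is at capacity, or nothing is left waiting
        data = yield wm.get_data(window_id, flight_key)
        # ... send or process `data` here ...
        yield wm.remove_key(window_id, flight_key)
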

class TestConcurrentWindowManager(VumiTestCase):

    @inlineCallbacks
    def setUp(self):
        self.persistence_helper = self.add_helper(PersistenceHelper())
        redis = yield self.persistence_helper.get_redis_manager()
        self.window_id = 'window_id'

        # Patch the count_waiting so we can fake the race condition
        self.clock = Clock()
        self.patch(WindowManager, 'count_waiting', lambda _, window_id: 100)

        self.wm = WindowManager(redis, window_size=10, flight_lifetime=10)
        self.add_cleanup(self.wm.stop)
        yield self.wm.create_window(self.window_id)
        self.redis = self.wm.redis

    @inlineCallbacks
    def test_race_condition(self):
        """
        A race condition can occur when multiple window managers try and
        access the same window at the same time.

        A LoopingCall loops over the available windows, for those windows
        it tries to get a next key. It does that by checking how many are
        waiting to be sent out and adding however many it can still carry
        to its own flight.

        Since there are concurrent workers, between the time of checking how
        many are available and how much room it has available, a different
        window manager may have already beaten it to it.

        If this happens Redis' `rpoplpush` method will return None since
        there are no more available keys for the given window.
        """
        yield self.wm.add(self.window_id, 1)
        yield self.wm.add(self.window_id, 2)
        yield self.wm._monitor_windows(lambda *a: True, True)
        self.assertEqual((yield self.wm.get_next_key(self.window_id)), None)
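
# To make the race described in the docstring above concrete: between
# checking how many keys are waiting and claiming one, a concurrent
# window manager may consume the remainder, so a None from rpoplpush is
# tolerated rather than treated as an error. A minimal illustrative
# sketch (assuming a synchronous redis-like client with llen and
# rpoplpush; not part of WindowManager):

def _try_claim(redis, waiting_list, flight_list):
    if redis.llen(waiting_list) > 0:
        # By the time we get here another manager may have emptied the
        # list, so rpoplpush can legitimately return None.
        return redis.rpoplpush(waiting_list, flight_list)
    return None
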
PK=JGtk=,,/vumi/components/tests/test_message_store_api.py
import json
from datetime import datetime, timedelta

from twisted.internet import reactor
from twisted.internet.defer import inlineCallbacks, returnValue, Deferred

from vumi.utils import http_request_full
from vumi.message import TransportUserMessage

from vumi.tests.helpers import (
    VumiTestCase, MessageHelper, WorkerHelper, PersistenceHelper, import_skip,
)


class TestMessageStoreAPI(VumiTestCase):
    @inlineCallbacks
    def setUp(self):
        self.persistence_helper = self.add_helper(
            PersistenceHelper(use_riak=True))
        try:
            from vumi.components.message_store_api import (
                MatchResource, MessageStoreAPIWorker)
        except ImportError, e:
            import_skip(e, 'riak')

        self.msg_helper = self.add_helper(MessageHelper())
        self.worker_helper = self.add_helper(WorkerHelper())

        self.match_resource = MatchResource
        self.base_path = '/api/v1/'
        self.worker = yield self.worker_helper.get_worker(
            MessageStoreAPIWorker, self.persistence_helper.mk_config({
                'web_path': self.base_path,
                'web_port': 0,
                'health_path': '/health/',
            }))
        self.store = self.worker.store
        self.addr = self.worker.webserver.getHost()
        self.url = 'http://%s:%s%s' % (self.addr.host, self.addr.port,
                                        self.base_path)

        self.tag = ("pool", "tag")
        self.batch_id = yield self.store.batch_start([self.tag])

    @inlineCallbacks
    def create_inbound(self, batch_id, count, content_template):
        messages = []
        now = datetime.now()
        for i in range(count):
            msg = self.msg_helper.make_inbound(
                content_template.format(i),
                timestamp=(now - timedelta(i * 10)))
            yield self.store.add_inbound_message(msg, batch_id=batch_id)
            messages.append(msg)
        returnValue(messages)

    @inlineCallbacks
    def create_outbound(self, batch_id, count, content_template):
        messages = []
        now = datetime.now()
        for i in range(count):
            msg = self.msg_helper.make_outbound(
                content_template.format(i),
                timestamp=(now - timedelta(i * 10)))
            yield self.store.add_outbound_message(msg, batch_id=batch_id)
            messages.append(msg)
        returnValue(messages)

    def do_get(self, path, headers={}):
        url = '%s%s' % (self.url, path)
        return http_request_full(url, headers=headers, method='GET')

    def do_post(self, path, data, headers={}):
        url = '%s%s' % (self.url, path)
        default_headers = {
            'Content-Type': 'application/json; charset=utf-8',
        }
        default_headers.update(headers)
        return http_request_full(url, data=json.dumps(data),
            headers=default_headers, method='POST')

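    # Poll the match resource until its "in progress" response header
    # reports '0', then fire the returned Deferred with the final
    # response. reactor.callLater(0, ...) reschedules the check on the
    # next reactor iteration without blocking.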
    def wait_for_results(self, direction, batch_id, token):
        url = '%sbatch/%s/%s/match/?token=%s' % (
            self.url, batch_id, direction, token)

        @inlineCallbacks
        def check(d):
            response = yield http_request_full(url, method='GET')
            [progress_status] = response.headers.getRawHeaders(
                self.match_resource.RESP_IN_PROGRESS_HEADER)
            if progress_status == '0':
                d.callback(response)
            else:
                reactor.callLater(0, check, d)

        done = Deferred()
        reactor.callLater(0, check, done)
        return done

    @inlineCallbacks
    def do_query(self, direction, batch_id, pattern, key='msg.content',
                    flags='i', wait=False):
        query = [{
            'key': key,
            'pattern': pattern,
            'flags': flags,
        }]
        if wait:
            headers = {self.match_resource.REQ_WAIT_HEADER: '1'}
        else:
            headers = {}

        expected_token = self.store.cache.get_query_token(direction, query)
        response = yield self.do_post('batch/%s/%s/match/' % (
            self.batch_id, direction), query, headers=headers)
        [token] = response.headers.getRawHeaders(
            self.match_resource.RESP_TOKEN_HEADER)
        self.assertEqual(token, expected_token)
        self.assertEqual(response.code, 200)
        returnValue(token)

    def assertResultCount(self, response, count):
        in_progress = response.headers.getRawHeaders(
            self.match_resource.RESP_IN_PROGRESS_HEADER)[0]
        assert in_progress == "0", "Query still in progress."
        self.assertEqual(
            response.headers.getRawHeaders(
                self.match_resource.RESP_COUNT_HEADER),
            [str(count)])

    def assertJSONResultEqual(self, json_blob, messages):
        """
        Asserts that the JSON response we're getting back is the same as
        the list of messages provided.

        There are easier ways to do this by comparing bigger JSON blobs,
        but then debugging the huge strings would be a pain.
        """
        dictionaries = json.loads(json_blob)
        self.assertEqual(len(dictionaries), len(messages),
            'Unequal amount of dictionaries and messages')
        for dictionary, message in zip(dictionaries, messages):
            # The json dumping & reloading happening here is required to have
            # the timestamp fields be parsed properly. This is an unfortunate
            # side effect of how timestamps are currently stored as
            # datetime() instances in the payload instead of plain strings.
            self.assertEqual(
                TransportUserMessage(_process_fields=False, **message.payload),
                TransportUserMessage.from_json(json.dumps(dictionary)))

    @inlineCallbacks
    def test_batch_index_resource(self):
        response = yield self.do_get('batch/')
        self.assertEqual(response.delivered_body, '')
        self.assertEqual(response.code, 200)

    @inlineCallbacks
    def test_batch_resource(self):
        response = yield self.do_get('batch/%s/' % (self.batch_id))
        self.assertEqual(response.delivered_body, self.batch_id)
        self.assertEqual(response.code, 200)

    @inlineCallbacks
    def test_waiting_inbound_match_resource(self):
        messages = yield self.create_inbound(self.batch_id, 22,
                                                'hello world {0}')
        token = yield self.do_query('inbound', self.batch_id, '.*',
                                                wait=True)
        response = yield self.do_get('batch/%s/inbound/match/?token=%s' % (
            self.batch_id, token))
        self.assertResultCount(response, 22)
        current_page = messages[:self.match_resource.DEFAULT_RESULT_SIZE]
        self.assertJSONResultEqual(response.delivered_body, current_page)
        self.assertEqual(response.code, 200)

    @inlineCallbacks
    def test_keys_inbound_match_resource(self):
        messages = yield self.create_inbound(self.batch_id, 22,
                                                'hello world {0}')
        token = yield self.do_query('inbound', self.batch_id, '.*',
                                                wait=True)
        response = yield self.do_get(
            'batch/%s/inbound/match/?token=%s&keys=1' % (
                self.batch_id, token))
        self.assertResultCount(response, 22)
        current_page = messages[:self.match_resource.DEFAULT_RESULT_SIZE]
        self.assertEqual(json.loads(response.delivered_body),
            [msg['message_id'] for msg in current_page])
        self.assertEqual(response.code, 200)

    @inlineCallbacks
    def test_polling_inbound_match_resource(self):
        messages = yield self.create_inbound(self.batch_id, 22,
                                                'hello world {0}')
        token = yield self.do_query('inbound', self.batch_id, '.*',
                                                wait=False)
        response = yield self.wait_for_results('inbound', self.batch_id, token)
        self.assertResultCount(response, 22)
        page = messages[:20]
        self.assertJSONResultEqual(response.delivered_body, page)
        self.assertEqual(response.code, 200)

    @inlineCallbacks
    def test_empty_inbound_match_resource(self):
        expected_token = yield self.do_query(
            'inbound', self.batch_id, '.*', wait=True)
        response = yield self.do_get('batch/%s/inbound/match/?token=%s' % (
            self.batch_id, expected_token))
        self.assertResultCount(response, 0)
        self.assertEqual(json.loads(response.delivered_body), [])
        self.assertEqual(response.code, 200)

    @inlineCallbacks
    def test_waiting_outbound_match_resource(self):
        messages = yield self.create_outbound(self.batch_id, 22,
                                                'hello world {0}')
        token = yield self.do_query('outbound', self.batch_id, '.*',
                                                wait=True)
        response = yield self.do_get('batch/%s/outbound/match/?token=%s' % (
            self.batch_id, token))
        self.assertResultCount(response, 22)
        current_page = messages[:self.match_resource.DEFAULT_RESULT_SIZE]
        self.assertJSONResultEqual(response.delivered_body, current_page)
        self.assertEqual(response.code, 200)

    @inlineCallbacks
    def test_keys_outbound_match_resource(self):
        messages = yield self.create_outbound(self.batch_id, 22,
                                                'hello world {0}')
        token = yield self.do_query('outbound', self.batch_id, '.*',
                                                wait=True)
        response = yield self.do_get(
            'batch/%s/outbound/match/?token=%s&keys=1' % (
                self.batch_id, token))
        self.assertResultCount(response, 22)
        current_page = messages[:self.match_resource.DEFAULT_RESULT_SIZE]
        self.assertEqual(json.loads(response.delivered_body),
            [msg['message_id'] for msg in current_page])
        self.assertEqual(response.code, 200)

    @inlineCallbacks
    def test_polling_outbound_match_resource(self):
        messages = yield self.create_outbound(self.batch_id, 22,
                                                'hello world {0}')
        token = yield self.do_query('outbound', self.batch_id, '.*',
                                                wait=False)
        response = yield self.wait_for_results('outbound', self.batch_id,
                                                token)
        self.assertResultCount(response, 22)
        page = messages[:20]
        self.assertJSONResultEqual(response.delivered_body, page)
        self.assertEqual(response.code, 200)

    @inlineCallbacks
    def test_empty_outbound_match_resource(self):
        expected_token = yield self.do_query(
            'outbound', self.batch_id, '.*', wait=True)
        response = yield self.do_get('batch/%s/outbound/match/?token=%s' % (
            self.batch_id, expected_token))
        self.assertResultCount(response, 0)
        self.assertEqual(json.loads(response.delivered_body), [])
        self.assertEqual(response.code, 200)
vumi/components/tests/test_message_store_cache.py
# -*- coding: utf-8 -*-

"""Tests for vumi.components.message_store_cache."""

from datetime import datetime, timedelta

from twisted.internet.defer import inlineCallbacks, returnValue

from vumi.tests.helpers import (
    VumiTestCase, MessageHelper, PersistenceHelper, import_skip,
)


class MessageStoreCacheTestCase(VumiTestCase):

    start_batch = True

    @inlineCallbacks
    def setUp(self):
        self.persistence_helper = self.add_helper(
            PersistenceHelper(use_riak=True))
        try:
            from vumi.components.message_store import MessageStore
        except ImportError, e:
            import_skip(e, 'riak')
        self.redis = yield self.persistence_helper.get_redis_manager()
        self.manager = yield self.persistence_helper.get_riak_manager()
        self.add_cleanup(self.manager.close_manager)
        self.store = MessageStore(self.manager, self.redis)
        self.cache = self.store.cache
        self.batch_id = 'a-batch-id'
        if self.start_batch:
            yield self.cache.batch_start(self.batch_id)
        self.msg_helper = self.add_helper(MessageHelper())

    @inlineCallbacks
    def add_messages(self, batch_id, callback, now=None, count=10):
        messages = []
        now = (datetime.now() if now is None else now)
        for i in range(count):
            msg = self.msg_helper.make_inbound(
                "inbound",
                from_addr='from-%s' % (i,),
                to_addr='to-%s' % (i,))
            msg['timestamp'] = now - timedelta(seconds=i)
            yield callback(batch_id, msg)
            messages.append(msg)
        returnValue(messages)

    @inlineCallbacks
    def add_event_pairs(self, batch_id, now=None, count=10):
        messages = []
        now = (datetime.now() if now is None else now)
        for i in range(count):
            msg = self.msg_helper.make_inbound(
                "inbound",
                from_addr='from-%s' % (i,),
                to_addr='to-%s' % (i,))
            msg['timestamp'] = now - timedelta(seconds=i)
            yield self.cache.add_outbound_message(batch_id, msg)
            ack = self.msg_helper.make_ack(msg)
            delivery = self.msg_helper.make_delivery_report(msg)
            yield self.cache.add_event(batch_id, ack)
            yield self.cache.add_event(batch_id, delivery)
            messages.extend((ack, delivery))
        returnValue(messages)


class TestMessageStoreCache(MessageStoreCacheTestCase):

    @inlineCallbacks
    def test_add_outbound_message(self):
        msg = self.msg_helper.make_outbound("outbound")
        yield self.cache.add_outbound_message(self.batch_id, msg)
        [msg_key] = yield self.cache.get_outbound_message_keys(self.batch_id)
        self.assertEqual(msg_key, msg['message_id'])

    @inlineCallbacks
    def test_get_outbound_message_keys(self):
        messages = yield self.add_messages(
            self.batch_id, self.cache.add_outbound_message)
        # Keys come back ordered by timestamp, newest first. Since each
        # message we added is a second older than the one before it, that
        # matches the order we added them in.
        keys = yield self.cache.get_outbound_message_keys(self.batch_id)
        self.assertEqual(len(keys), 10)
        self.assertEqual(keys, [m['message_id'] for m in messages])

    @inlineCallbacks
    def test_count_outbound_message_keys(self):
        yield self.add_messages(
            self.batch_id, self.cache.add_outbound_message)
        count = yield self.cache.count_outbound_message_keys(self.batch_id)
        self.assertEqual(count, 10)

    @inlineCallbacks
    def test_paged_get_outbound_message_keys(self):
        messages = yield self.add_messages(
            self.batch_id, self.cache.add_outbound_message)
        # Keys come back ordered by timestamp, newest first (see
        # test_get_outbound_message_keys above); here we fetch only the
        # first page of five (the 0-4 range is inclusive).
        keys = yield self.cache.get_outbound_message_keys(self.batch_id, 0, 4)
        self.assertEqual(len(keys), 5)
        self.assertEqual(keys, [m['message_id'] for m in messages][:5])

    @inlineCallbacks
    def test_get_batch_ids(self):
        yield self.cache.batch_start('batch-1')
        yield self.cache.add_outbound_message(
            'batch-1', self.msg_helper.make_outbound("outbound"))
        yield self.cache.batch_start('batch-2')
        yield self.cache.add_outbound_message(
            'batch-2', self.msg_helper.make_outbound("outbound"))
        self.assertEqual((yield self.cache.get_batch_ids()), set([
            self.batch_id, 'batch-1', 'batch-2']))

    @inlineCallbacks
    def test_add_inbound_message(self):
        msg = self.msg_helper.make_inbound("inbound")
        yield self.cache.add_inbound_message(self.batch_id, msg)
        [msg_key] = yield self.cache.get_inbound_message_keys(self.batch_id)
        self.assertEqual(msg_key, msg['message_id'])

    @inlineCallbacks
    def test_get_inbound_message_keys(self):
        messages = yield self.add_messages(
            self.batch_id, self.cache.add_inbound_message)
        keys = yield self.cache.get_inbound_message_keys(self.batch_id)
        self.assertEqual(len(keys), 10)
        self.assertEqual(keys, [m['message_id'] for m in messages])

    @inlineCallbacks
    def test_count_inbound_message_keys(self):
        yield self.add_messages(
            self.batch_id, self.cache.add_inbound_message)
        count = yield self.cache.count_inbound_message_keys(self.batch_id)
        self.assertEqual(count, 10)

    @inlineCallbacks
    def test_paged_get_inbound_message_keys(self):
        messages = yield self.add_messages(
            self.batch_id, self.cache.add_inbound_message)
        keys = yield self.cache.get_inbound_message_keys(self.batch_id, 0, 4)
        self.assertEqual(len(keys), 5)
        self.assertEqual(keys, [m['message_id'] for m in messages][:5])

    @inlineCallbacks
    def test_get_from_addrs(self):
        yield self.add_messages(
            self.batch_id, self.cache.add_inbound_message)
        from_addrs = yield self.cache.get_from_addrs(self.batch_id)
        # NOTE: This functionality is disabled for now.
        # self.assertEqual(from_addrs, ['from-%s' % i for i in range(10)])
        self.assertEqual(from_addrs, [])

    @inlineCallbacks
    def test_count_from_addrs(self):
        yield self.add_messages(
            self.batch_id, self.cache.add_inbound_message)
        count = yield self.cache.count_from_addrs(self.batch_id)
        self.assertEqual(count, 10)

    @inlineCallbacks
    def test_get_to_addrs(self):
        yield self.add_messages(
            self.batch_id, self.cache.add_outbound_message)
        to_addrs = yield self.cache.get_to_addrs(self.batch_id)
        # NOTE: This functionality is disabled for now.
        # self.assertEqual(to_addrs, ['to-%s' % i for i in range(10)])
        self.assertEqual(to_addrs, [])

    @inlineCallbacks
    def test_count_to_addrs(self):
        yield self.add_messages(
            self.batch_id, self.cache.add_outbound_message)
        count = yield self.cache.count_to_addrs(self.batch_id)
        self.assertEqual(count, 10)

    @inlineCallbacks
    def test_add_event(self):
        msg = self.msg_helper.make_outbound("outbound")
        yield self.cache.add_outbound_message(self.batch_id, msg)
        ack = self.msg_helper.make_ack(msg)
        delivery = self.msg_helper.make_delivery_report(msg)
        yield self.cache.add_event(self.batch_id, ack)
        yield self.cache.add_event(self.batch_id, delivery)
        event_count = yield self.cache.count_event_keys(self.batch_id)
        self.assertEqual(event_count, 2)
        status = yield self.cache.get_event_status(self.batch_id)
        self.assertEqual(status, {
            'delivery_report': 1,
            'delivery_report.delivered': 1,
            'delivery_report.failed': 0,
            'delivery_report.pending': 0,
            'ack': 1,
            'nack': 0,
            'sent': 1,
        })

    @inlineCallbacks
    def test_add_event_idempotence(self):
        msg = self.msg_helper.make_outbound("outbound")
        yield self.cache.add_outbound_message(self.batch_id, msg)
        acks = [self.msg_helper.make_ack(msg) for i in range(10)]
        for ack in acks:
            # send exact same event multiple times
            ack['event_id'] = 'identical'
            yield self.cache.add_event(self.batch_id, ack)
        event_count = yield self.cache.count_event_keys(self.batch_id)
        self.assertEqual(event_count, 1)
        status = yield self.cache.get_event_status(self.batch_id)
        self.assertEqual(status, {
            'delivery_report': 0,
            'delivery_report.delivered': 0,
            'delivery_report.failed': 0,
            'delivery_report.pending': 0,
            'ack': 1,
            'nack': 0,
            'sent': 1,
        })

    @inlineCallbacks
    def test_add_outbound_message_idempotence(self):
        for i in range(10):
            msg = self.msg_helper.make_outbound("outbound")
            msg['message_id'] = 'the-same-thing'
            yield self.cache.add_outbound_message(self.batch_id, msg)
        outbound_count = yield self.cache.count_outbound_message_keys(
            self.batch_id)
        self.assertEqual(outbound_count, 1)
        status = yield self.cache.get_event_status(self.batch_id)
        self.assertEqual(status['sent'], 1)
        self.assertEqual(
            (yield self.cache.get_outbound_message_keys(self.batch_id)),
            ['the-same-thing'])

    @inlineCallbacks
    def test_add_inbound_message_idempotence(self):
        for i in range(10):
            msg = self.msg_helper.make_inbound("inbound")
            msg['message_id'] = 'the-same-thing'
            yield self.cache.add_inbound_message(self.batch_id, msg)
        inbound_count = yield self.cache.count_inbound_message_keys(
            self.batch_id)
        self.assertEqual(inbound_count, 1)
        self.assertEqual(
            (yield self.cache.get_inbound_message_keys(self.batch_id)),
            ['the-same-thing'])

    @inlineCallbacks
    def test_clear_batch(self):
        msg_in = self.msg_helper.make_inbound("inbound")
        msg_out = self.msg_helper.make_outbound("outbound")
        ack = self.msg_helper.make_ack(msg_out)
        dr = self.msg_helper.make_delivery_report(
            msg_out, delivery_status='delivered')
        yield self.cache.add_inbound_message(self.batch_id, msg_in)
        yield self.cache.add_outbound_message(self.batch_id, msg_out)
        yield self.cache.add_event(self.batch_id, ack)
        yield self.cache.add_event(self.batch_id, dr)

        self.assertEqual(
            (yield self.cache.get_event_status(self.batch_id)),
            {
                'ack': 1,
                'delivery_report': 1,
                'delivery_report.delivered': 1,
                'delivery_report.failed': 0,
                'delivery_report.pending': 0,
                'nack': 0,
                'sent': 1,
            })
        yield self.cache.clear_batch(self.batch_id)
        yield self.cache.batch_start(self.batch_id)
        self.assertEqual(
            (yield self.cache.get_event_status(self.batch_id)),
            {
                'ack': 0,
                'delivery_report': 0,
                'delivery_report.delivered': 0,
                'delivery_report.failed': 0,
                'delivery_report.pending': 0,
                'nack': 0,
                'sent': 0,
            })
        self.assertEqual(
            (yield self.cache.count_from_addrs(self.batch_id)), 0)
        self.assertEqual(
            (yield self.cache.count_to_addrs(self.batch_id)), 0)
        self.assertEqual(
            (yield self.cache.count_inbound_message_keys(self.batch_id)), 0)
        self.assertEqual(
            (yield self.cache.count_outbound_message_keys(self.batch_id)), 0)

    @inlineCallbacks
    def test_count_inbound_throughput(self):
        # test for empty batches.
        self.assertEqual(
            (yield self.cache.count_inbound_throughput(self.batch_id)), 0)

        now = datetime.now()
        for i in range(10):
            msg_in = self.msg_helper.make_inbound("inbound")
            msg_in['timestamp'] = now - timedelta(seconds=i * 10)
            yield self.cache.add_inbound_message(self.batch_id, msg_in)
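        # The messages are spaced 10 seconds apart, so a 1-second sample
        # window should see only the newest message and a 10-second window
        # should see two.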

        self.assertEqual(
            (yield self.cache.count_inbound_throughput(self.batch_id)), 10)
        self.assertEqual(
            (yield self.cache.count_inbound_throughput(
                self.batch_id, sample_time=1)), 1)
        self.assertEqual(
            (yield self.cache.count_inbound_throughput(
                self.batch_id, sample_time=10)), 2)

    @inlineCallbacks
    def test_count_outbound_throughput(self):
        # test for empty batches.
        self.assertEqual(
            (yield self.cache.count_outbound_throughput(self.batch_id)), 0)

        now = datetime.now()
        for i in range(10):
            msg_out = self.msg_helper.make_outbound("outbound")
            msg_out['timestamp'] = now - timedelta(seconds=i * 10)
            yield self.cache.add_outbound_message(self.batch_id, msg_out)

        self.assertEqual(
            (yield self.cache.count_outbound_throughput(self.batch_id)), 10)
        self.assertEqual(
            (yield self.cache.count_outbound_throughput(
                self.batch_id, sample_time=1)), 1)
        self.assertEqual(
            (yield self.cache.count_outbound_throughput(
                self.batch_id, sample_time=10)), 2)

    def test_get_query_token(self):
        # different ordering in the dict should result in the same token.
        token1 = self.cache.get_query_token('inbound', [{'a': 'b', 'c': 'd'}])
        token2 = self.cache.get_query_token('inbound', [{'c': 'd', 'a': 'b'}])
        self.assertEqual(token1, token2)
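    # (These tests don't rely on the token's exact construction;
    # presumably it hashes an order-insensitive serialisation of the
    # query, but that is an implementation detail of the cache.)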

    @inlineCallbacks
    def test_start_query(self):
        token = yield self.cache.start_query(self.batch_id, 'inbound', [
            {'key': 'key', 'pattern': 'pattern', 'flags': 'flags'}])
        self.assertTrue(
            (yield self.cache.is_query_in_progress(self.batch_id, token)))

    @inlineCallbacks
    def test_store_query_results(self):
        now = datetime.now()
        message_ids = []
        # NOTE: we're writing these messages oldest first so we can check
        #       that the cache is returning them in the correct order
        #       when we pull out the search results.
        for i in range(10):
            msg_in = self.msg_helper.make_inbound('hello-%s' % (i,))
            msg_in['timestamp'] = now + timedelta(seconds=i * 10)
            yield self.cache.add_inbound_message(self.batch_id, msg_in)
            message_ids.append(msg_in['message_id'])

        token = yield self.cache.start_query(self.batch_id, 'inbound', [
            {'key': 'msg.content', 'pattern': 'hello', 'flags': ''}])
        yield self.cache.store_query_results(
            self.batch_id, token, message_ids, 'inbound', 120)
        self.assertFalse(
            (yield self.cache.is_query_in_progress(self.batch_id, token)))
        self.assertEqual(
            (yield self.cache.get_query_results(self.batch_id, token)),
            list(reversed(message_ids)))
        self.assertEqual(
            (yield self.cache.count_query_results(self.batch_id, token)),
            10)


class TestMessageStoreCacheWithCounters(MessageStoreCacheTestCase):

    start_batch = False

    @inlineCallbacks
    def test_switching_to_counters_outbound(self):
        self.cache.TRUNCATE_MESSAGE_KEY_COUNT_AT = 7

        for i in range(10):
            msg = self.msg_helper.make_outbound("outbound")
            yield self.cache.add_outbound_message(self.batch_id, msg)

        self.assertFalse((yield self.cache.uses_counters(self.batch_id)))
        self.assertEqual(
            (yield self.cache.count_outbound_message_keys(self.batch_id)),
            10)
        yield self.cache.switch_to_counters(self.batch_id)
        self.assertTrue((yield self.cache.uses_counters(self.batch_id)))
        self.assertEqual(
            (yield self.cache.count_outbound_message_keys(self.batch_id)),
            10)

        outbound = yield self.cache.get_outbound_message_keys(self.batch_id)
        self.assertEqual(len(outbound), 7)

    @inlineCallbacks
    def test_switching_to_counters_inbound(self):
        self.cache.TRUNCATE_MESSAGE_KEY_COUNT_AT = 7

        for i in range(10):
            msg = self.msg_helper.make_inbound("inbound")
            yield self.cache.add_inbound_message(self.batch_id, msg)

        self.assertFalse((yield self.cache.uses_counters(self.batch_id)))
        self.assertEqual(
            (yield self.cache.count_inbound_message_keys(self.batch_id)),
            10)
        yield self.cache.switch_to_counters(self.batch_id)
        self.assertTrue((yield self.cache.uses_counters(self.batch_id)))
        self.assertEqual(
            (yield self.cache.count_inbound_message_keys(self.batch_id)),
            10)

        inbound = yield self.cache.get_inbound_message_keys(self.batch_id)
        self.assertEqual(len(inbound), 7)

    @inlineCallbacks
    def test_inbound_truncate_at_within_limits(self):
        yield self.add_messages(
            self.batch_id, self.cache.add_inbound_message, count=10)
        count = yield self.cache.truncate_inbound_message_keys(
            self.batch_id, truncate_at=11)
        self.assertEqual(count, 0)

    @inlineCallbacks
    def test_inbound_truncate_at_over_limits(self):
        yield self.add_messages(
            self.batch_id, self.cache.add_inbound_message, count=10)
        count = yield self.cache.truncate_inbound_message_keys(
            self.batch_id, truncate_at=7)
        self.assertEqual(count, 3)
        self.assertEqual((yield self.cache.count_inbound_message_keys(
            self.batch_id)), 7)

    @inlineCallbacks
    def test_outbound_truncate_at_within_limits(self):
        yield self.add_messages(
            self.batch_id, self.cache.add_outbound_message, count=10)
        count = yield self.cache.truncate_outbound_message_keys(
            self.batch_id, truncate_at=11)
        self.assertEqual(count, 0)

    @inlineCallbacks
    def test_outbound_truncate_at_over_limits(self):
        yield self.add_messages(
            self.batch_id, self.cache.add_outbound_message, count=10)
        count = yield self.cache.truncate_outbound_message_keys(
            self.batch_id, truncate_at=7)
        self.assertEqual(count, 3)
        self.assertEqual((yield self.cache.count_outbound_message_keys(
            self.batch_id)), 7)

    @inlineCallbacks
    def test_event_truncate_at_within_limits(self):
        # We need to use counters here, but have a default truncation limit
        # higher than what we're testing.
        self.cache.TRUNCATE_MESSAGE_KEY_COUNT_AT = 100
        yield self.cache.batch_start(self.batch_id, use_counters=True)

        yield self.add_event_pairs(self.batch_id, count=5)
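        # (count=5 creates five ack/delivery-report pairs, i.e. ten event
        # keys in the sorted set.)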
        removed_count = yield self.cache.truncate_event_keys(
            self.batch_id, truncate_at=11)
        self.assertEqual(removed_count, 0)
        key_count = yield self.cache.redis.zcard(
            self.cache.event_key(self.batch_id))
        self.assertEqual(key_count, 10)

    @inlineCallbacks
    def test_event_truncate_at_over_limits(self):
        # We need to use counters here, but have a default truncation limit
        # higher than what we're testing.
        self.cache.TRUNCATE_MESSAGE_KEY_COUNT_AT = 100
        yield self.cache.batch_start(self.batch_id, use_counters=True)

        yield self.add_event_pairs(self.batch_id, count=5)
        removed_count = yield self.cache.truncate_event_keys(
            self.batch_id, truncate_at=7)
        self.assertEqual(removed_count, 3)
        key_count = yield self.cache.redis.zcard(
            self.cache.event_key(self.batch_id))
        self.assertEqual(key_count, 7)

    @inlineCallbacks
    def test_truncation_after_hitting_limit(self):
        truncate_at = 10
        # Make sure the batch is actually using counters, so that message
        # keys can be truncated without losing the counts.
        self.assertTrue((yield self.cache.uses_counters(self.batch_id)))

        start = datetime.now()
        received_messages = []
        for i in range(20):
            now = start + timedelta(seconds=i)
            # populate in ascending timestamp
            [msg] = yield self.add_messages(
                self.batch_id, self.cache.add_inbound_message,
                now=now, count=1)
            received_messages.append(msg)
            # Manually truncate
            yield self.cache.truncate_inbound_message_keys(
                self.batch_id, truncate_at=truncate_at)

        # Get latest 20 messages from cache (there should be 10)
        cached_message_keys = yield self.cache.get_inbound_message_keys(
            self.batch_id, 0, 19)
        # Make sure we're not storing more than we expect to be
        self.assertEqual(len(cached_message_keys), truncate_at)
        # Make sure we're storing the most recent ones
        self.assertEqual(
            set(cached_message_keys),
            set([m['message_id'] for m in received_messages[-truncate_at:]]))
vumi/components/tests/test_message_store.py
# -*- coding: utf-8 -*-

"""Tests for vumi.components.message_store."""
import time
from datetime import datetime, timedelta

from twisted.internet.defer import inlineCallbacks, returnValue

from vumi.message import TransportEvent, format_vumi_date
from vumi.tests.helpers import (
    VumiTestCase, MessageHelper, PersistenceHelper, import_skip)

try:
    from vumi.components.message_store import (
        MessageStore, to_reverse_timestamp, from_reverse_timestamp,
        add_batches_to_event)
except ImportError, e:
    import_skip(e, 'riak')


def zero_ms(timestamp):
    """
    Zero the microsecond digits of a formatted vumi date while preserving
    their width, e.g. "2015-04-01 12:13:14.123456" becomes
    "2015-04-01 12:13:14.000000".
    """
    dt, dot, ms = format_vumi_date(timestamp).partition(".")
    return dot.join([dt, "0" * len(ms)])


class TestReverseTimestampUtils(VumiTestCase):

    def test_to_reverse_timestamp(self):
        """
        to_reverse_timestamp() turns a vumi_date-formatted string into a
        reverse timestamp.
        """
        self.assertEqual(
            "FFAAE41F25", to_reverse_timestamp("2015-04-01 12:13:14"))
        self.assertEqual(
            "FFAAE41F25", to_reverse_timestamp("2015-04-01 12:13:14.000000"))
        self.assertEqual(
            "FFAAE41F25", to_reverse_timestamp("2015-04-01 12:13:14.999999"))
        self.assertEqual(
            "FFAAE41F24", to_reverse_timestamp("2015-04-01 12:13:15"))
        self.assertEqual(
            "F0F9025FA5", to_reverse_timestamp("4015-04-01 12:13:14"))

    def test_from_reverse_timestamp(self):
        """
        from_reverse_timestamp() is the inverse of to_reverse_timestamp().
        """
        self.assertEqual(
            "2015-04-01 12:13:14.000000", from_reverse_timestamp("FFAAE41F25"))
        self.assertEqual(
            "2015-04-01 12:13:13.000000", from_reverse_timestamp("FFAAE41F26"))
        self.assertEqual(
            "4015-04-01 12:13:14.000000", from_reverse_timestamp("F0F9025FA5"))

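# The tests above pin the encoding down without spelling it out. As an
# illustration only (a sketch, not the store's actual implementation), a
# compatible encoder subtracts the POSIX timestamp from 0xFFFFFFFFFF so
# that later dates yield lexicographically smaller fixed-width hex strings:
import calendar


def to_reverse_timestamp_sketch(vumi_date):
    """Illustrative re-implementation of the encoding tested above."""
    dt = datetime.strptime(vumi_date.split(".")[0], "%Y-%m-%d %H:%M:%S")
    return "%010X" % (0xFFFFFFFFFF - calendar.timegm(dt.timetuple()))

# e.g. to_reverse_timestamp_sketch("2015-04-01 12:13:14") == "FFAAE41F25"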

class TestMessageStoreBase(VumiTestCase):

    @inlineCallbacks
    def setUp(self):
        self.persistence_helper = self.add_helper(
            PersistenceHelper(use_riak=True))
        self.redis = yield self.persistence_helper.get_redis_manager()
        self.manager = self.persistence_helper.get_riak_manager()
        self.add_cleanup(self.manager.close_manager)
        self.store = MessageStore(self.manager, self.redis)
        self.msg_helper = self.add_helper(MessageHelper())

    @inlineCallbacks
    def _maybe_batch(self, tag, by_batch):
        add_kw, batch_id = {}, None
        if tag is not None:
            batch_id = yield self.store.batch_start([tag])
            if by_batch:
                add_kw['batch_id'] = batch_id
            else:
                add_kw['tag'] = tag
        returnValue((add_kw, batch_id))

    @inlineCallbacks
    def _create_outbound(self, tag=("pool", "tag"), by_batch=False,
                         content='outbound foo'):
        """Create and store an outbound message."""
        add_kw, batch_id = yield self._maybe_batch(tag, by_batch)
        msg = self.msg_helper.make_outbound(content)
        msg_id = msg['message_id']
        yield self.store.add_outbound_message(msg, **add_kw)
        returnValue((msg_id, msg, batch_id))

    @inlineCallbacks
    def _create_inbound(self, tag=("pool", "tag"), by_batch=False,
                        content='inbound foo'):
        """Create and store an inbound message."""
        add_kw, batch_id = yield self._maybe_batch(tag, by_batch)
        msg = self.msg_helper.make_inbound(
            content, to_addr="+1234567810001", transport_type="sms")
        msg_id = msg['message_id']
        yield self.store.add_inbound_message(msg, **add_kw)
        returnValue((msg_id, msg, batch_id))

    @inlineCallbacks
    def create_outbound_messages(self, batch_id, count, start_timestamp=None,
                                 time_multiplier=10, to_addr=None):
        # Store via message_store
        now = start_timestamp or datetime.now()
        messages = []
        for i in range(count):
            msg = self.msg_helper.make_outbound(
                "foo", timestamp=(now - timedelta(i * time_multiplier)))
            if to_addr is not None:
                msg['to_addr'] = to_addr
            yield self.store.add_outbound_message(msg, batch_id=batch_id)
            messages.append(msg)
        returnValue(messages)

    def _create_event(self, event_type, timestamp):
        maker = {
            'ack': self.msg_helper.make_ack,
            'nack': self.msg_helper.make_nack,
            'delivery_report': self.msg_helper.make_delivery_report,
        }[event_type]
        return maker(timestamp=timestamp)

    @inlineCallbacks
    def create_events(self, batch_id, count, start_timestamp=None,
                      time_multiplier=10, event_mix=None):
        # Store via message_store
        now = start_timestamp or datetime.now()
        events = []
        if event_mix is None:
            event_mix = ['ack', 'nack', 'delivery_report']
        event_types = (event_mix * count)[:count]
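        # e.g. the default mix with count=5 cycles to:
        #     ['ack', 'nack', 'delivery_report', 'ack', 'nack']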
        for i, event_type in enumerate(event_types):
            ev = self._create_event(
                event_type, timestamp=(now - timedelta(i * time_multiplier)))
            yield self.store.add_event(ev, batch_ids=[batch_id])
            events.append(ev)
        returnValue(events)

    @inlineCallbacks
    def create_inbound_messages(self, batch_id, count, start_timestamp=None,
                                time_multiplier=10, from_addr=None):
        # Store via message_store
        now = start_timestamp or datetime.now()
        messages = []
        for i in range(count):
            msg = self.msg_helper.make_inbound(
                "foo", timestamp=(now - timedelta(i * time_multiplier)))
            if from_addr is not None:
                msg['from_addr'] = from_addr
            yield self.store.add_inbound_message(msg, batch_id=batch_id)
            messages.append(msg)
        returnValue(messages)

    def _batch_status(self, ack=0, nack=0, delivered=0, failed=0, pending=0,
                      sent=0):
        return {
            'ack': ack, 'nack': nack, 'sent': sent,
            'delivery_report': sum([delivered, failed, pending]),
            'delivery_report.delivered': delivered,
            'delivery_report.failed': failed,
            'delivery_report.pending': pending,
            }


class TestMessageStore(TestMessageStoreBase):

    @inlineCallbacks
    def test_batch_start(self):
        tag1 = ("poolA", "tag1")
        batch_id = yield self.store.batch_start([tag1])
        batch = yield self.store.get_batch(batch_id)
        tag_info = yield self.store.get_tag_info(tag1)
        outbound_keys = yield self.store.batch_outbound_keys(batch_id)
        batch_status = yield self.store.batch_status(batch_id)
        self.assertEqual(outbound_keys, [])
        self.assertEqual(list(batch.tags), [tag1])
        self.assertEqual(tag_info.current_batch.key, batch_id)
        self.assertEqual(batch_status, self._batch_status())

    @inlineCallbacks
    def test_batch_start_with_metadata(self):
        batch_id = yield self.store.batch_start([], key1=u"foo", key2=u"bar")
        batch = yield self.store.get_batch(batch_id)
        self.assertEqual(batch.metadata['key1'], "foo")
        self.assertEqual(batch.metadata['key2'], "bar")

    @inlineCallbacks
    def test_batch_done(self):
        tag1 = ("poolA", "tag1")
        batch_id = yield self.store.batch_start([tag1])
        yield self.store.batch_done(batch_id)
        batch = yield self.store.get_batch(batch_id)
        tag_info = yield self.store.get_tag_info(tag1)
        self.assertEqual(list(batch.tags), [tag1])
        self.assertEqual(tag_info.current_batch.key, None)

    @inlineCallbacks
    def test_add_outbound_message(self):
        msg_id, msg, _batch_id = yield self._create_outbound(tag=None)

        stored_msg = yield self.store.get_outbound_message(msg_id)
        self.assertEqual(stored_msg, msg)
        event_keys = yield self.store.message_event_keys(msg_id)
        self.assertEqual(event_keys, [])

    @inlineCallbacks
    def test_add_outbound_message_again(self):
        msg_id, msg, _batch_id = yield self._create_outbound(tag=None)

        old_stored_msg = yield self.store.get_outbound_message(msg_id)
        self.assertEqual(old_stored_msg, msg)

        msg['helper_metadata']['foo'] = {'bar': 'baz'}
        yield self.store.add_outbound_message(msg)
        new_stored_msg = yield self.store.get_outbound_message(msg_id)
        self.assertEqual(new_stored_msg, msg)
        self.assertNotEqual(old_stored_msg, new_stored_msg)

    @inlineCallbacks
    def test_add_outbound_message_with_batch_id(self):
        msg_id, msg, batch_id = yield self._create_outbound(by_batch=True)

        stored_msg = yield self.store.get_outbound_message(msg_id)
        outbound_keys = yield self.store.batch_outbound_keys(batch_id)
        event_keys = yield self.store.message_event_keys(msg_id)
        batch_status = yield self.store.batch_status(batch_id)

        self.assertEqual(stored_msg, msg)
        self.assertEqual(outbound_keys, [msg_id])
        self.assertEqual(event_keys, [])
        self.assertEqual(batch_status, self._batch_status(sent=1))

    @inlineCallbacks
    def test_add_outbound_message_with_tag(self):
        msg_id, msg, batch_id = yield self._create_outbound()

        stored_msg = yield self.store.get_outbound_message(msg_id)
        outbound_keys = yield self.store.batch_outbound_keys(batch_id)
        event_keys = yield self.store.message_event_keys(msg_id)
        batch_status = yield self.store.batch_status(batch_id)

        self.assertEqual(stored_msg, msg)
        self.assertEqual(outbound_keys, [msg_id])
        self.assertEqual(event_keys, [])
        self.assertEqual(batch_status, self._batch_status(sent=1))

    @inlineCallbacks
    def test_add_outbound_message_to_multiple_batches(self):
        msg_id, msg, batch_id_1 = yield self._create_outbound()
        batch_id_2 = yield self.store.batch_start()
        yield self.store.add_outbound_message(msg, batch_id=batch_id_2)

        self.assertEqual(
            (yield self.store.batch_outbound_keys(batch_id_1)), [msg_id])
        self.assertEqual(
            (yield self.store.batch_outbound_keys(batch_id_2)), [msg_id])
        # Make sure we're writing the right indexes.
        stored_msg = yield self.store.outbound_messages.load(msg_id)
        timestamp = format_vumi_date(msg['timestamp'])
        reverse_ts = to_reverse_timestamp(timestamp)
        self.assertEqual(stored_msg._riak_object.get_indexes(), set([
            ('batches_bin', batch_id_1),
            ('batches_bin', batch_id_2),
            ('batches_with_addresses_bin',
             "%s$%s$%s" % (batch_id_1, timestamp, msg['to_addr'])),
            ('batches_with_addresses_bin',
             "%s$%s$%s" % (batch_id_2, timestamp, msg['to_addr'])),
            ('batches_with_addresses_reverse_bin',
             "%s$%s$%s" % (batch_id_1, reverse_ts, msg['to_addr'])),
            ('batches_with_addresses_reverse_bin',
             "%s$%s$%s" % (batch_id_2, reverse_ts, msg['to_addr'])),
        ]))
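        # Each index value is "<batch_id>$<timestamp>$<to_addr>"; range
        # queries over a batch's prefix therefore return keys in timestamp
        # order, and the *_reverse_bin variant uses the reverse timestamp
        # so that the same trick yields newest-first ordering.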

    @inlineCallbacks
    def test_get_events_for_message(self):
        msg_id, msg, batch_id = yield self._create_outbound()
        ack = self.msg_helper.make_ack(msg)
        ack_id = ack['event_id']
        yield self.store.add_event(ack)

        dr = self.msg_helper.make_delivery_report(msg)
        dr_id = dr['event_id']
        yield self.store.add_event(dr)

        stored_ack = yield self.store.get_event(ack_id)
        stored_dr = yield self.store.get_event(dr_id)

        events = yield self.store.get_events_for_message(msg_id)

        self.assertEqual(len(events), 2)
        self.assertTrue(
            all(isinstance(event, TransportEvent) for event in events))
        self.assertTrue(stored_ack in events)
        self.assertTrue(stored_dr in events)

    @inlineCallbacks
    def test_add_ack_event_batch_ids_from_outbound(self):
        """
        If the `batch_ids` param is not given, and the event doesn't exist,
        batch ids are looked up on the outbound message.
        """
        msg_id, msg, batch_id = yield self._create_outbound()
        ack = self.msg_helper.make_ack(msg)
        ack_id = ack['event_id']
        yield self.store.add_event(ack)

        stored_ack = yield self.store.get_event(ack_id)
        event_keys = yield self.store.message_event_keys(msg_id)
        batch_status = yield self.store.batch_status(batch_id)

        self.assertEqual(stored_ack, ack)
        self.assertEqual(event_keys, [ack_id])
        self.assertEqual(batch_status, self._batch_status(sent=1, ack=1))

        event = yield self.store.events.load(ack_id)
        self.assertEqual(event.batches.keys(), [batch_id])
        timestamp = format_vumi_date(ack["timestamp"])
        self.assertEqual(event.message_with_status, "%s$%s$ack" % (
            msg_id, timestamp))
        self.assertEqual(set(event.batches_with_statuses_reverse), set([
            "%s$%s$ack" % (batch_id, to_reverse_timestamp(timestamp)),
        ]))

    @inlineCallbacks
    def test_add_ack_event_uses_existing_batches(self):
        """
        If the `batch_ids` param is not given, and the event already
        exists, batch ids should not be looked up on the outbound message.
        """
        # create a message but don't store it
        msg = self.msg_helper.make_outbound('outbound text')
        msg_id = msg['message_id']
        batch_id = yield self.store.batch_start([('pool', 'tag')])

        # create an event and store it
        ack = self.msg_helper.make_ack(msg)
        ack_id = ack['event_id']
        yield self.store.add_event(ack, [batch_id])

        # now store the event again without specifying batches
        yield self.store.add_event(ack)

        stored_ack = yield self.store.get_event(ack_id)
        event_keys = yield self.store.message_event_keys(msg_id)
        batch_status = yield self.store.batch_status(batch_id)

        self.assertEqual(stored_ack, ack)
        self.assertEqual(event_keys, [ack_id])
        self.assertEqual(batch_status, self._batch_status(sent=0, ack=1))

        event = yield self.store.events.load(ack_id)
        self.assertEqual(event.batches.keys(), [batch_id])
        timestamp = format_vumi_date(ack["timestamp"])
        self.assertEqual(event.message_with_status, "%s$%s$ack" % (
            msg_id, timestamp))
        self.assertEqual(set(event.batches_with_statuses_reverse), set([
            "%s$%s$ack" % (batch_id, to_reverse_timestamp(timestamp)),
        ]))

    @inlineCallbacks
    def test_add_ack_event_with_batch_ids(self):
        """
        If an event is added with batch_ids provided, those batch_ids are used.
        """
        msg_id, msg, batch_id = yield self._create_outbound()
        batch_1 = yield self.store.batch_start([])
        batch_2 = yield self.store.batch_start([])

        ack = self.msg_helper.make_ack(msg)
        ack_id = ack['event_id']
        yield self.store.add_event(ack, batch_ids=[batch_1, batch_2])

        stored_ack = yield self.store.get_event(ack_id)
        event_keys = yield self.store.message_event_keys(msg_id)
        batch_status = yield self.store.batch_status(batch_id)
        batch_1_status = yield self.store.batch_status(batch_1)
        batch_2_status = yield self.store.batch_status(batch_2)

        self.assertEqual(stored_ack, ack)
        self.assertEqual(event_keys, [ack_id])
        self.assertEqual(batch_status, self._batch_status(sent=1))
        self.assertEqual(batch_1_status, self._batch_status(ack=1))
        self.assertEqual(batch_2_status, self._batch_status(ack=1))

        event = yield self.store.events.load(ack_id)
        timestamp = format_vumi_date(ack["timestamp"])
        self.assertEqual(event.message_with_status, "%s$%s$ack" % (
            msg_id, timestamp))
        self.assertEqual(set(event.batches_with_statuses_reverse), set([
            "%s$%s$ack" % (batch_1, to_reverse_timestamp(timestamp)),
            "%s$%s$ack" % (batch_2, to_reverse_timestamp(timestamp)),
        ]))

    @inlineCallbacks
    def test_add_ack_event_without_batch_ids_no_outbound(self):
        """
        If an event is added without batch_ids and no outbound message is
        found, no batch_ids will be used.
        """
        msg_id, msg, batch_id = yield self._create_outbound()

        ack = self.msg_helper.make_ack(msg)
        ack_id = ack['event_id']
        ack['user_message_id'] = "no-message"
        yield self.store.add_event(ack)

        stored_ack = yield self.store.get_event(ack_id)
        event_keys = yield self.store.message_event_keys("no-message")
        batch_status = yield self.store.batch_status(batch_id)

        self.assertEqual(stored_ack, ack)
        self.assertEqual(event_keys, [ack_id])
        self.assertEqual(batch_status, self._batch_status(sent=1))

        event = yield self.store.events.load(ack_id)
        timestamp = format_vumi_date(ack["timestamp"])
        self.assertEqual(event.message_with_status, "%s$%s$ack" % (
            "no-message", timestamp))
        self.assertEqual(set(event.batches_with_statuses_reverse), set())

    @inlineCallbacks
    def test_add_ack_event_with_empty_batch_ids(self):
        """
        If an event is added with an empty list of batch_ids, no batch_ids will
        be used.
        """
        msg_id, msg, batch_id = yield self._create_outbound()
        ack = self.msg_helper.make_ack(msg)
        ack_id = ack['event_id']
        yield self.store.add_event(ack, batch_ids=[])

        stored_ack = yield self.store.get_event(ack_id)
        event_keys = yield self.store.message_event_keys(msg_id)
        batch_status = yield self.store.batch_status(batch_id)

        self.assertEqual(stored_ack, ack)
        self.assertEqual(event_keys, [ack_id])
        self.assertEqual(batch_status, self._batch_status(sent=1))

        event = yield self.store.events.load(ack_id)
        timestamp = format_vumi_date(ack["timestamp"])
        self.assertEqual(event.message_with_status, "%s$%s$ack" % (
            msg_id, timestamp))
        self.assertEqual(set(event.batches_with_statuses_reverse), set())

    @inlineCallbacks
    def test_add_ack_event_again(self):
        msg_id, msg, batch_id = yield self._create_outbound()
        ack = self.msg_helper.make_ack(msg)
        ack_id = ack['event_id']
        yield self.store.add_event(ack)
        old_stored_ack = yield self.store.get_event(ack_id)
        self.assertEqual(old_stored_ack, ack)

        ack['helper_metadata']['foo'] = {'bar': 'baz'}
        yield self.store.add_event(ack)
        new_stored_ack = yield self.store.get_event(ack_id)
        self.assertEqual(new_stored_ack, ack)
        self.assertNotEqual(old_stored_ack, new_stored_ack)

        event_keys = yield self.store.message_event_keys(msg_id)
        batch_status = yield self.store.batch_status(batch_id)

        self.assertEqual(event_keys, [ack_id])
        self.assertEqual(batch_status, self._batch_status(sent=1, ack=1))

    @inlineCallbacks
    def test_add_nack_event(self):
        msg_id, msg, batch_id = yield self._create_outbound()
        nack = self.msg_helper.make_nack(msg)
        nack_id = nack['event_id']
        yield self.store.add_event(nack)

        stored_nack = yield self.store.get_event(nack_id)
        event_keys = yield self.store.message_event_keys(msg_id)
        batch_status = yield self.store.batch_status(batch_id)

        self.assertEqual(stored_nack, nack)
        self.assertEqual(event_keys, [nack_id])
        self.assertEqual(batch_status, self._batch_status(sent=1, nack=1))

        event = yield self.store.events.load(nack_id)
        self.assertEqual(event.message_with_status, "%s$%s$nack" % (
            msg_id, nack["timestamp"]))

    @inlineCallbacks
    def test_add_ack_event_without_batch(self):
        msg_id, msg, _batch_id = yield self._create_outbound(tag=None)
        ack = self.msg_helper.make_ack(msg)
        ack_id = ack['event_id']
        yield self.store.add_event(ack)

        stored_ack = yield self.store.get_event(ack_id)
        event_keys = yield self.store.message_event_keys(msg_id)

        self.assertEqual(stored_ack, ack)
        self.assertEqual(event_keys, [ack_id])

    @inlineCallbacks
    def test_add_nack_event_without_batch(self):
        msg_id, msg, _batch_id = yield self._create_outbound(tag=None)
        nack = self.msg_helper.make_nack(msg)
        nack_id = nack['event_id']
        yield self.store.add_event(nack)

        stored_nack = yield self.store.get_event(nack_id)
        event_keys = yield self.store.message_event_keys(msg_id)

        self.assertEqual(stored_nack, nack)
        self.assertEqual(event_keys, [nack_id])

    @inlineCallbacks
    def test_add_delivery_report_events(self):
        msg_id, msg, batch_id = yield self._create_outbound()

        dr_ids = []
        for status in TransportEvent.DELIVERY_STATUSES:
            dr = self.msg_helper.make_delivery_report(
                msg, delivery_status=status)
            dr_id = dr['event_id']
            dr_ids.append(dr_id)
            yield self.store.add_event(dr)
            stored_dr = yield self.store.get_event(dr_id)
            self.assertEqual(stored_dr, dr)

            event = yield self.store.events.load(dr_id)
            self.assertEqual(event.message_with_status, "%s$%s$%s" % (
                msg_id, dr["timestamp"], "delivery_report.%s" % (status,)))

        event_keys = yield self.store.message_event_keys(msg_id)
        self.assertEqual(sorted(event_keys), sorted(dr_ids))
        dr_counts = dict((status, 1)
                         for status in TransportEvent.DELIVERY_STATUSES)
        batch_status = yield self.store.batch_status(batch_id)
        self.assertEqual(batch_status, self._batch_status(sent=1, **dr_counts))

    @inlineCallbacks
    def test_add_inbound_message(self):
        msg_id, msg, _batch_id = yield self._create_inbound(tag=None)
        stored_msg = yield self.store.get_inbound_message(msg_id)
        self.assertEqual(stored_msg, msg)

    @inlineCallbacks
    def test_add_inbound_message_again(self):
        msg_id, msg, _batch_id = yield self._create_inbound(tag=None)

        old_stored_msg = yield self.store.get_inbound_message(msg_id)
        self.assertEqual(old_stored_msg, msg)

        msg['helper_metadata']['foo'] = {'bar': 'baz'}
        yield self.store.add_inbound_message(msg)
        new_stored_msg = yield self.store.get_inbound_message(msg_id)
        self.assertEqual(new_stored_msg, msg)
        self.assertNotEqual(old_stored_msg, new_stored_msg)

    @inlineCallbacks
    def test_add_inbound_message_with_batch_id(self):
        msg_id, msg, batch_id = yield self._create_inbound(by_batch=True)

        stored_msg = yield self.store.get_inbound_message(msg_id)
        inbound_keys = yield self.store.batch_inbound_keys(batch_id)

        self.assertEqual(stored_msg, msg)
        self.assertEqual(inbound_keys, [msg_id])

    @inlineCallbacks
    def test_add_inbound_message_with_tag(self):
        msg_id, msg, batch_id = yield self._create_inbound()

        stored_msg = yield self.store.get_inbound_message(msg_id)
        inbound_keys = yield self.store.batch_inbound_keys(batch_id)

        self.assertEqual(stored_msg, msg)
        self.assertEqual(inbound_keys, [msg_id])

    @inlineCallbacks
    def test_add_inbound_message_to_multiple_batches(self):
        msg_id, msg, batch_id_1 = yield self._create_inbound()
        batch_id_2 = yield self.store.batch_start()
        yield self.store.add_inbound_message(msg, batch_id=batch_id_2)

        self.assertEqual((yield self.store.batch_inbound_keys(batch_id_1)),
                         [msg_id])
        self.assertEqual((yield self.store.batch_inbound_keys(batch_id_2)),
                         [msg_id])
        # Make sure we're writing the right indexes.
        stored_msg = yield self.store.inbound_messages.load(msg_id)
        timestamp = format_vumi_date(msg['timestamp'])
        reverse_ts = to_reverse_timestamp(timestamp)
        self.assertEqual(stored_msg._riak_object.get_indexes(), set([
            ('batches_bin', batch_id_1),
            ('batches_bin', batch_id_2),
            ('batches_with_addresses_bin',
             "%s$%s$%s" % (batch_id_1, timestamp, msg['from_addr'])),
            ('batches_with_addresses_bin',
             "%s$%s$%s" % (batch_id_2, timestamp, msg['from_addr'])),
            ('batches_with_addresses_reverse_bin',
             "%s$%s$%s" % (batch_id_1, reverse_ts, msg['from_addr'])),
            ('batches_with_addresses_reverse_bin',
             "%s$%s$%s" % (batch_id_2, reverse_ts, msg['from_addr'])),
        ]))

    @inlineCallbacks
    def test_inbound_counts(self):
        _msg_id, _msg, batch_id = yield self._create_inbound(by_batch=True)
        self.assertEqual(1, (yield self.store.batch_inbound_count(batch_id)))
        yield self.store.add_inbound_message(
            self.msg_helper.make_inbound("foo"), batch_id=batch_id)
        self.assertEqual(2, (yield self.store.batch_inbound_count(batch_id)))

    @inlineCallbacks
    def test_outbound_counts(self):
        _msg_id, _msg, batch_id = yield self._create_outbound(by_batch=True)
        self.assertEqual(1, (yield self.store.batch_outbound_count(batch_id)))
        yield self.store.add_outbound_message(
            self.msg_helper.make_outbound("foo"), batch_id=batch_id)
        self.assertEqual(2, (yield self.store.batch_outbound_count(batch_id)))

    @inlineCallbacks
    def test_inbound_keys_matching(self):
        msg_id, msg, batch_id = yield self._create_inbound(content='hello')
        self.assertEqual(
            [msg_id],
            (yield self.store.batch_inbound_keys_matching(batch_id, query=[{
                'key': 'msg.content',
                'pattern': 'hell.+',
                'flags': 'i',
            }])))
        # test case sensitivity
        self.assertEqual(
            [],
            (yield self.store.batch_inbound_keys_matching(batch_id, query=[{
                'key': 'msg.content',
                'pattern': 'HELLO',
                'flags': '',
            }])))
        # the inbound from_addr has a leading +, it needs to be escaped
        self.assertEqual(
            [msg_id],
            (yield self.store.batch_inbound_keys_matching(batch_id, query=[{
                'key': 'msg.from_addr',
                'pattern': "\%s" % (msg.payload['from_addr'],),
                'flags': 'i',
            }])))
        # the inbound to_addr also has a leading +, it needs to be escaped
        self.assertEqual(
            [msg_id],
            (yield self.store.batch_inbound_keys_matching(batch_id, query=[{
                'key': 'msg.to_addr',
                'pattern': "\%s" % (msg.payload['to_addr'],),
                'flags': 'i',
            }])))

    @inlineCallbacks
    def test_outbound_keys_matching(self):
        msg_id, msg, batch_id = yield self._create_outbound(content='hello')
        self.assertEqual(
            [msg_id],
            (yield self.store.batch_outbound_keys_matching(batch_id, query=[{
                'key': 'msg.content',
                'pattern': 'hell.+',
                'flags': 'i',
            }])))
        # test case sensitivity
        self.assertEqual(
            [],
            (yield self.store.batch_outbound_keys_matching(batch_id, query=[{
                'key': 'msg.content',
                'pattern': 'HELLO',
                'flags': '',
            }])))
        self.assertEqual(
            [msg_id],
            (yield self.store.batch_outbound_keys_matching(batch_id, query=[{
                'key': 'msg.from_addr',
                'pattern': msg.payload['from_addr'],
                'flags': 'i',
            }])))
        # the outbound to_addr has a leading +, it needs to be escaped
        self.assertEqual(
            [msg_id],
            (yield self.store.batch_outbound_keys_matching(batch_id, query=[{
                'key': 'msg.to_addr',
                'pattern': "\%s" % (msg.payload['to_addr'],),
                'flags': 'i',
            }])))

    @inlineCallbacks
    def test_add_inbound_message_with_batch_ids(self):
        batch_id1 = yield self.store.batch_start([])
        batch_id2 = yield self.store.batch_start([])
        msg = self.msg_helper.make_inbound("hi")

        yield self.store.add_inbound_message(
            msg, batch_ids=[batch_id1, batch_id2])

        stored_msg = yield self.store.get_inbound_message(msg['message_id'])
        inbound_keys1 = yield self.store.batch_inbound_keys(batch_id1)
        inbound_keys2 = yield self.store.batch_inbound_keys(batch_id2)

        self.assertEqual(stored_msg, msg)
        self.assertEqual(inbound_keys1, [msg['message_id']])
        self.assertEqual(inbound_keys2, [msg['message_id']])

    @inlineCallbacks
    def test_add_inbound_message_with_batch_id_and_batch_ids(self):
        batch_id1 = yield self.store.batch_start([])
        batch_id2 = yield self.store.batch_start([])
        msg = self.msg_helper.make_inbound("hi")

        yield self.store.add_inbound_message(
            msg, batch_id=batch_id1, batch_ids=[batch_id2])

        stored_msg = yield self.store.get_inbound_message(msg['message_id'])
        inbound_keys1 = yield self.store.batch_inbound_keys(batch_id1)
        inbound_keys2 = yield self.store.batch_inbound_keys(batch_id2)

        self.assertEqual(stored_msg, msg)
        self.assertEqual(inbound_keys1, [msg['message_id']])
        self.assertEqual(inbound_keys2, [msg['message_id']])

    @inlineCallbacks
    def test_add_outbound_message_with_batch_ids(self):
        batch_id1 = yield self.store.batch_start([])
        batch_id2 = yield self.store.batch_start([])
        msg = self.msg_helper.make_outbound("hi")

        yield self.store.add_outbound_message(
            msg, batch_ids=[batch_id1, batch_id2])

        stored_msg = yield self.store.get_outbound_message(msg['message_id'])
        outbound_keys1 = yield self.store.batch_outbound_keys(batch_id1)
        outbound_keys2 = yield self.store.batch_outbound_keys(batch_id2)

        self.assertEqual(stored_msg, msg)
        self.assertEqual(outbound_keys1, [msg['message_id']])
        self.assertEqual(outbound_keys2, [msg['message_id']])

    @inlineCallbacks
    def test_add_outbound_message_with_batch_id_and_batch_ids(self):
        batch_id1 = yield self.store.batch_start([])
        batch_id2 = yield self.store.batch_start([])
        msg = self.msg_helper.make_outbound("hi")

        yield self.store.add_outbound_message(
            msg, batch_id=batch_id1, batch_ids=[batch_id2])

        stored_msg = yield self.store.get_outbound_message(msg['message_id'])
        outbound_keys1 = yield self.store.batch_outbound_keys(batch_id1)
        outbound_keys2 = yield self.store.batch_outbound_keys(batch_id2)

        self.assertEqual(stored_msg, msg)
        self.assertEqual(outbound_keys1, [msg['message_id']])
        self.assertEqual(outbound_keys2, [msg['message_id']])

    @inlineCallbacks
    def test_batch_inbound_keys_page(self):
        batch_id = yield self.store.batch_start([('pool', 'tag')])
        messages = yield self.create_inbound_messages(batch_id, 10)
        all_keys = sorted(msg['message_id'] for msg in messages)

        keys_p1 = yield self.store.batch_inbound_keys_page(batch_id, 6)
        # Paginated results are sorted by key.
        self.assertEqual(sorted(keys_p1), all_keys[:6])

        keys_p2 = yield keys_p1.next_page()
        self.assertEqual(sorted(keys_p2), all_keys[6:])

    @inlineCallbacks
    def test_batch_outbound_keys_page(self):
        batch_id = yield self.store.batch_start([('pool', 'tag')])
        messages = yield self.create_outbound_messages(batch_id, 10)
        all_keys = sorted(msg['message_id'] for msg in messages)

        keys_p1 = yield self.store.batch_outbound_keys_page(batch_id, 6)
        # Paginated results are sorted by key.
        self.assertEqual(sorted(keys_p1), all_keys[:6])

        keys_p2 = yield keys_p1.next_page()
        self.assertEqual(sorted(keys_p2), all_keys[6:])

    @inlineCallbacks
    def test_batch_inbound_keys_with_timestamp(self):
        batch_id = yield self.store.batch_start([('pool', 'tag')])
        messages = yield self.create_inbound_messages(batch_id, 10)
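        # The index returns (key, formatted timestamp) pairs ordered by
        # timestamp, so build the expected list by sorting on
        # (timestamp, key) first and then reshaping each entry.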
        sorted_keys = sorted((msg['timestamp'], msg['message_id'])
                             for msg in messages)
        all_keys = [(key, format_vumi_date(timestamp))
                    for (timestamp, key) in sorted_keys]

        first_page = yield self.store.batch_inbound_keys_with_timestamps(
            batch_id, max_results=6)

        results = list(first_page)
        self.assertEqual(len(results), 6)
        self.assertEqual(first_page.has_next_page(), True)

        next_page = yield first_page.next_page()
        results.extend(next_page)
        self.assertEqual(len(results), 10)
        self.assertEqual(next_page.has_next_page(), False)

        self.assertEqual(results, all_keys)

    @inlineCallbacks
    def test_batch_inbound_keys_with_timestamp_start(self):
        batch_id = yield self.store.batch_start([('pool', 'tag')])
        messages = yield self.create_inbound_messages(batch_id, 5)
        sorted_keys = sorted((msg['timestamp'], msg['message_id'])
                             for msg in messages)
        all_keys = [(key, format_vumi_date(timestamp))
                    for (timestamp, key) in sorted_keys]

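        # The start bound is inclusive: starting at the second key's
        # timestamp should return everything from that key onward.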
        index_page = yield self.store.batch_inbound_keys_with_timestamps(
            batch_id, max_results=6, start=all_keys[1][1])
        self.assertEqual(list(index_page), all_keys[1:])

    @inlineCallbacks
    def test_batch_inbound_keys_with_timestamp_without_timestamps(self):
        batch_id = yield self.store.batch_start([('pool', 'tag')])
        messages = yield self.create_inbound_messages(batch_id, 5)
        sorted_keys = sorted((msg['timestamp'], msg['message_id'])
                             for msg in messages)
        all_keys = [(key, format_vumi_date(timestamp))
                    for (timestamp, key) in sorted_keys]

        index_page = yield self.store.batch_inbound_keys_with_timestamps(
            batch_id, with_timestamps=False)
        self.assertEqual(list(index_page), [k for k, _ in all_keys])

    @inlineCallbacks
    def test_batch_inbound_keys_with_timestamp_end(self):
        batch_id = yield self.store.batch_start([('pool', 'tag')])
        messages = yield self.create_inbound_messages(batch_id, 5)
        sorted_keys = sorted((msg['timestamp'], msg['message_id'])
                             for msg in messages)
        all_keys = [(key, format_vumi_date(timestamp))
                    for (timestamp, key) in sorted_keys]

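        # The end bound is inclusive too: ending at the second-last key's
        # timestamp should return everything except the final key.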
        index_page = yield self.store.batch_inbound_keys_with_timestamps(
            batch_id, max_results=6, end=all_keys[-2][1])
        self.assertEqual(list(index_page), all_keys[:-1])

    @inlineCallbacks
    def test_batch_inbound_keys_with_timestamp_range(self):
        batch_id = yield self.store.batch_start([('pool', 'tag')])
        messages = yield self.create_inbound_messages(batch_id, 5)
        sorted_keys = sorted((msg['timestamp'], msg['message_id'])
                             for msg in messages)
        all_keys = [(key, format_vumi_date(timestamp))
                    for (timestamp, key) in sorted_keys]

        index_page = yield self.store.batch_inbound_keys_with_timestamps(
            batch_id, max_results=6, start=all_keys[1][1], end=all_keys[-2][1])
        self.assertEqual(list(index_page), all_keys[1:-1])

    @inlineCallbacks
    def test_batch_outbound_keys_with_timestamp(self):
        batch_id = yield self.store.batch_start([('pool', 'tag')])
        messages = yield self.create_outbound_messages(batch_id, 10)
        sorted_keys = sorted((msg['timestamp'], msg['message_id'])
                             for msg in messages)
        all_keys = [(key, format_vumi_date(timestamp))
                    for (timestamp, key) in sorted_keys]

        first_page = yield self.store.batch_outbound_keys_with_timestamps(
            batch_id, max_results=6)

        results = list(first_page)
        self.assertEqual(len(results), 6)
        self.assertEqual(first_page.has_next_page(), True)

        next_page = yield first_page.next_page()
        results.extend(next_page)
        self.assertEqual(len(results), 10)
        self.assertEqual(next_page.has_next_page(), False)

        self.assertEqual(results, all_keys)

    @inlineCallbacks
    def test_batch_outbound_keys_with_timestamp_without_timestamps(self):
        batch_id = yield self.store.batch_start([('pool', 'tag')])
        messages = yield self.create_outbound_messages(batch_id, 5)
        sorted_keys = sorted((msg['timestamp'], msg['message_id'])
                             for msg in messages)
        all_keys = [(key, format_vumi_date(timestamp))
                    for (timestamp, key) in sorted_keys]

        index_page = yield self.store.batch_outbound_keys_with_timestamps(
            batch_id, with_timestamps=False)
        self.assertEqual(list(index_page), [k for k, _ in all_keys])

    @inlineCallbacks
    def test_batch_outbound_keys_with_timestamp_start(self):
        batch_id = yield self.store.batch_start([('pool', 'tag')])
        messages = yield self.create_outbound_messages(batch_id, 5)
        sorted_keys = sorted((msg['timestamp'], msg['message_id'])
                             for msg in messages)
        all_keys = [(key, format_vumi_date(timestamp))
                    for (timestamp, key) in sorted_keys]

        index_page = yield self.store.batch_outbound_keys_with_timestamps(
            batch_id, max_results=6, start=all_keys[1][1])
        self.assertEqual(list(index_page), all_keys[1:])

    @inlineCallbacks
    def test_batch_outbound_keys_with_timestamp_end(self):
        batch_id = yield self.store.batch_start([('pool', 'tag')])
        messages = yield self.create_outbound_messages(batch_id, 5)
        sorted_keys = sorted((msg['timestamp'], msg['message_id'])
                             for msg in messages)
        all_keys = [(key, format_vumi_date(timestamp))
                    for (timestamp, key) in sorted_keys]

        index_page = yield self.store.batch_outbound_keys_with_timestamps(
            batch_id, max_results=6, end=all_keys[-2][1])
        self.assertEqual(list(index_page), all_keys[:-1])

    @inlineCallbacks
    def test_batch_outbound_keys_with_timestamp_range(self):
        batch_id = yield self.store.batch_start([('pool', 'tag')])
        messages = yield self.create_outbound_messages(batch_id, 5)
        sorted_keys = sorted((msg['timestamp'], msg['message_id'])
                             for msg in messages)
        all_keys = [(key, format_vumi_date(timestamp))
                    for (timestamp, key) in sorted_keys]

        index_page = yield self.store.batch_outbound_keys_with_timestamps(
            batch_id, max_results=6, start=all_keys[1][1], end=all_keys[-2][1])
        self.assertEqual(list(index_page), all_keys[1:-1])

    @inlineCallbacks
    def test_batch_inbound_keys_with_addresses(self):
        batch_id = yield self.store.batch_start([('pool', 'tag')])
        messages = yield self.create_inbound_messages(batch_id, 10)
        sorted_keys = sorted(
            (msg['timestamp'], msg['from_addr'], msg['message_id'])
            for msg in messages)
        all_keys = [(key, format_vumi_date(timestamp), addr)
                    for (timestamp, addr, key) in sorted_keys]

        first_page = yield self.store.batch_inbound_keys_with_addresses(
            batch_id, max_results=6)

        results = list(first_page)
        self.assertEqual(len(results), 6)
        self.assertEqual(first_page.has_next_page(), True)

        next_page = yield first_page.next_page()
        results.extend(next_page)
        self.assertEqual(len(results), 10)
        self.assertEqual(next_page.has_next_page(), False)

        self.assertEqual(results, all_keys)

    @inlineCallbacks
    def test_batch_inbound_keys_with_addresses_start(self):
        batch_id = yield self.store.batch_start([('pool', 'tag')])
        messages = yield self.create_inbound_messages(batch_id, 5)
        sorted_keys = sorted(
            (msg['timestamp'], msg['from_addr'], msg['message_id'])
            for msg in messages)
        all_keys = [(key, format_vumi_date(timestamp), addr)
                    for (timestamp, addr, key) in sorted_keys]

        index_page = yield self.store.batch_inbound_keys_with_addresses(
            batch_id, max_results=6, start=all_keys[1][1])
        self.assertEqual(list(index_page), all_keys[1:])

    @inlineCallbacks
    def test_batch_inbound_keys_with_addresses_end(self):
        batch_id = yield self.store.batch_start([('pool', 'tag')])
        messages = yield self.create_inbound_messages(batch_id, 5)
        sorted_keys = sorted(
            (msg['timestamp'], msg['from_addr'], msg['message_id'])
            for msg in messages)
        all_keys = [(key, format_vumi_date(timestamp), addr)
                    for (timestamp, addr, key) in sorted_keys]

        index_page = yield self.store.batch_inbound_keys_with_addresses(
            batch_id, max_results=6, end=all_keys[-2][1])
        self.assertEqual(list(index_page), all_keys[:-1])

    @inlineCallbacks
    def test_batch_inbound_keys_with_addresses_range(self):
        batch_id = yield self.store.batch_start([('pool', 'tag')])
        messages = yield self.create_inbound_messages(batch_id, 5)
        sorted_keys = sorted(
            (msg['timestamp'], msg['from_addr'], msg['message_id'])
            for msg in messages)
        all_keys = [(key, format_vumi_date(timestamp), addr)
                    for (timestamp, addr, key) in sorted_keys]

        index_page = yield self.store.batch_inbound_keys_with_addresses(
            batch_id, max_results=6, start=all_keys[1][1], end=all_keys[-2][1])
        self.assertEqual(list(index_page), all_keys[1:-1])

    @inlineCallbacks
    def test_batch_outbound_keys_with_addresses(self):
        batch_id = yield self.store.batch_start([('pool', 'tag')])
        messages = yield self.create_outbound_messages(batch_id, 10)
        sorted_keys = sorted(
            (msg['timestamp'], msg['to_addr'], msg['message_id'])
            for msg in messages)
        all_keys = [(key, format_vumi_date(timestamp), addr)
                    for (timestamp, addr, key) in sorted_keys]

        first_page = yield self.store.batch_outbound_keys_with_addresses(
            batch_id, max_results=6)

        results = list(first_page)
        self.assertEqual(len(results), 6)
        self.assertEqual(first_page.has_next_page(), True)

        next_page = yield first_page.next_page()
        results.extend(next_page)
        self.assertEqual(len(results), 10)
        self.assertEqual(next_page.has_next_page(), False)

        self.assertEqual(results, all_keys)

    @inlineCallbacks
    def test_batch_outbound_keys_with_addresses_start(self):
        batch_id = yield self.store.batch_start([('pool', 'tag')])
        messages = yield self.create_outbound_messages(batch_id, 5)
        sorted_keys = sorted(
            (msg['timestamp'], msg['to_addr'], msg['message_id'])
            for msg in messages)
        all_keys = [(key, format_vumi_date(timestamp), addr)
                    for (timestamp, addr, key) in sorted_keys]

        index_page = yield self.store.batch_outbound_keys_with_addresses(
            batch_id, max_results=6, start=all_keys[1][1])
        self.assertEqual(list(index_page), all_keys[1:])

    @inlineCallbacks
    def test_batch_outbound_keys_with_addresses_end(self):
        batch_id = yield self.store.batch_start([('pool', 'tag')])
        messages = yield self.create_outbound_messages(batch_id, 5)
        sorted_keys = sorted(
            (msg['timestamp'], msg['to_addr'], msg['message_id'])
            for msg in messages)
        all_keys = [(key, format_vumi_date(timestamp), addr)
                    for (timestamp, addr, key) in sorted_keys]

        index_page = yield self.store.batch_outbound_keys_with_addresses(
            batch_id, max_results=6, end=all_keys[-2][1])
        self.assertEqual(list(index_page), all_keys[:-1])

    @inlineCallbacks
    def test_batch_outbound_keys_with_addresses_range(self):
        batch_id = yield self.store.batch_start([('pool', 'tag')])
        messages = yield self.create_outbound_messages(batch_id, 5)
        sorted_keys = sorted(
            (msg['timestamp'], msg['to_addr'], msg['message_id'])
            for msg in messages)
        all_keys = [(key, format_vumi_date(timestamp), addr)
                    for (timestamp, addr, key) in sorted_keys]

        index_page = yield self.store.batch_outbound_keys_with_addresses(
            batch_id, max_results=6, start=all_keys[1][1], end=all_keys[-2][1])
        self.assertEqual(list(index_page), all_keys[1:-1])

    @inlineCallbacks
    def test_batch_inbound_keys_with_addresses_reverse(self):
        batch_id = yield self.store.batch_start([('pool', 'tag')])
        messages = yield self.create_inbound_messages(batch_id, 10)
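        # zero_ms truncates sub-second precision; the reverse index appears
        # to store timestamps at second resolution, so the expected ordering
        # must be computed on the truncated values.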
        sorted_keys = sorted(
            [(zero_ms(msg['timestamp']), msg['from_addr'], msg['message_id'])
             for msg in messages], reverse=True)
        all_keys = [(key, timestamp, addr)
                    for (timestamp, addr, key) in sorted_keys]

        page = yield self.store.batch_inbound_keys_with_addresses_reverse(
            batch_id, max_results=6)

        results = list(page)
        self.assertEqual(len(results), 6)
        self.assertEqual(page.has_next_page(), True)

        next_page = yield page.next_page()
        results.extend(next_page)
        self.assertEqual(len(results), 10)
        self.assertEqual(next_page.has_next_page(), False)

        self.assertEqual(results, all_keys)

    @inlineCallbacks
    def test_batch_inbound_keys_with_addresses_reverse_start(self):
        batch_id = yield self.store.batch_start([('pool', 'tag')])
        messages = yield self.create_inbound_messages(batch_id, 5)
        sorted_keys = sorted(
            [(zero_ms(msg['timestamp']), msg['from_addr'], msg['message_id'])
             for msg in messages], reverse=True)
        all_keys = [(key, timestamp, addr)
                    for (timestamp, addr, key) in sorted_keys]

        page = yield self.store.batch_inbound_keys_with_addresses_reverse(
            batch_id, max_results=6, start=all_keys[-2][1])
        self.assertEqual(list(page), all_keys[:-1])

    @inlineCallbacks
    def test_batch_inbound_keys_with_addresses_reverse_end(self):
        batch_id = yield self.store.batch_start([('pool', 'tag')])
        messages = yield self.create_inbound_messages(batch_id, 5)
        sorted_keys = sorted(
            [(zero_ms(msg['timestamp']), msg['from_addr'], msg['message_id'])
             for msg in messages], reverse=True)
        all_keys = [(key, timestamp, addr)
                    for (timestamp, addr, key) in sorted_keys]

        page = yield self.store.batch_inbound_keys_with_addresses_reverse(
            batch_id, max_results=6, end=all_keys[1][1])
        self.assertEqual(list(page), all_keys[1:])

    @inlineCallbacks
    def test_batch_inbound_keys_with_addresses_reverse_range(self):
        batch_id = yield self.store.batch_start([('pool', 'tag')])
        messages = yield self.create_inbound_messages(batch_id, 5)
        sorted_keys = sorted(
            [(zero_ms(msg['timestamp']), msg['from_addr'], msg['message_id'])
             for msg in messages], reverse=True)
        all_keys = [(key, timestamp, addr)
                    for (timestamp, addr, key) in sorted_keys]

        page = yield self.store.batch_inbound_keys_with_addresses_reverse(
            batch_id, max_results=6, start=all_keys[-2][1], end=all_keys[1][1])
        self.assertEqual(list(page), all_keys[1:-1])

    @inlineCallbacks
    def test_batch_outbound_keys_with_addresses_reverse(self):
        batch_id = yield self.store.batch_start([('pool', 'tag')])
        messages = yield self.create_outbound_messages(batch_id, 10)
        sorted_keys = sorted(
            [(zero_ms(msg['timestamp']), msg['to_addr'], msg['message_id'])
             for msg in messages], reverse=True)
        all_keys = [(key, timestamp, addr)
                    for (timestamp, addr, key) in sorted_keys]

        page = yield self.store.batch_outbound_keys_with_addresses_reverse(
            batch_id, max_results=6)

        results = list(page)
        self.assertEqual(len(results), 6)
        self.assertEqual(page.has_next_page(), True)

        next_page = yield page.next_page()
        results.extend(next_page)
        self.assertEqual(len(results), 10)
        self.assertEqual(next_page.has_next_page(), False)

        self.assertEqual(results, all_keys)

    @inlineCallbacks
    def test_batch_outbound_keys_with_addresses_reverse_start(self):
        batch_id = yield self.store.batch_start([('pool', 'tag')])
        messages = yield self.create_outbound_messages(batch_id, 5)
        sorted_keys = sorted(
            [(zero_ms(msg['timestamp']), msg['to_addr'], msg['message_id'])
             for msg in messages], reverse=True)
        all_keys = [(key, timestamp, addr)
                    for (timestamp, addr, key) in sorted_keys]

        page = yield self.store.batch_outbound_keys_with_addresses_reverse(
            batch_id, max_results=6, start=all_keys[-2][1])
        self.assertEqual(list(page), all_keys[:-1])

    @inlineCallbacks
    def test_batch_outbound_keys_with_addresses_reverse_end(self):
        batch_id = yield self.store.batch_start([('pool', 'tag')])
        messages = yield self.create_outbound_messages(batch_id, 5)
        sorted_keys = sorted(
            [(zero_ms(msg['timestamp']), msg['to_addr'], msg['message_id'])
             for msg in messages], reverse=True)
        all_keys = [(key, timestamp, addr)
                    for (timestamp, addr, key) in sorted_keys]

        page = yield self.store.batch_outbound_keys_with_addresses_reverse(
            batch_id, max_results=6, end=all_keys[1][1])
        self.assertEqual(list(page), all_keys[1:])

    @inlineCallbacks
    def test_batch_outbound_keys_with_addresses_reverse_range(self):
        batch_id = yield self.store.batch_start([('pool', 'tag')])
        messages = yield self.create_outbound_messages(batch_id, 5)
        sorted_keys = sorted(
            [(zero_ms(msg['timestamp']), msg['to_addr'], msg['message_id'])
             for msg in messages], reverse=True)
        all_keys = [(key, timestamp, addr)
                    for (timestamp, addr, key) in sorted_keys]

        page = yield self.store.batch_outbound_keys_with_addresses_reverse(
            batch_id, max_results=6, start=all_keys[-2][1], end=all_keys[1][1])
        self.assertEqual(list(page), all_keys[1:-1])

    @inlineCallbacks
    def test_batch_event_keys_with_statuses_reverse(self):
        batch_id = yield self.store.batch_start([('pool', 'tag')])
        events = yield self.create_events(batch_id, 10)
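        # ev.status() is assumed to compose the event type with any delivery
        # status (e.g. 'delivery_report.delivered'), matching the index
        # entries returned below.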
        sorted_keys = sorted(
            [(zero_ms(ev['timestamp']), ev.status(), ev['event_id'])
             for ev in events], reverse=True)
        all_keys = [(key, timestamp, status)
                    for (timestamp, status, key) in sorted_keys]

        page = yield self.store.batch_event_keys_with_statuses_reverse(
            batch_id, max_results=6)

        results = list(page)
        self.assertEqual(len(results), 6)
        self.assertEqual(page.has_next_page(), True)

        next_page = yield page.next_page()
        results.extend(next_page)
        self.assertEqual(len(results), 10)
        self.assertEqual(next_page.has_next_page(), False)

        self.assertEqual(results, all_keys)

    @inlineCallbacks
    def test_batch_event_keys_with_statuses_reverse_start(self):
        batch_id = yield self.store.batch_start([('pool', 'tag')])
        events = yield self.create_events(batch_id, 5)
        sorted_keys = sorted(
            [(zero_ms(ev['timestamp']), ev.status(), ev['event_id'])
             for ev in events], reverse=True)
        all_keys = [(key, timestamp, status)
                    for (timestamp, status, key) in sorted_keys]

        page = yield self.store.batch_event_keys_with_statuses_reverse(
            batch_id, max_results=6, start=all_keys[-2][1])
        self.assertEqual(list(page), all_keys[:-1])

    @inlineCallbacks
    def test_batch_event_keys_with_statuses_reverse_end(self):
        batch_id = yield self.store.batch_start([('pool', 'tag')])
        events = yield self.create_events(batch_id, 5)
        sorted_keys = sorted(
            [(zero_ms(ev['timestamp']), ev.status(), ev['event_id'])
             for ev in events], reverse=True)
        all_keys = [(key, timestamp, status)
                    for (timestamp, status, key) in sorted_keys]

        page = yield self.store.batch_event_keys_with_statuses_reverse(
            batch_id, max_results=6, end=all_keys[1][1])
        self.assertEqual(list(page), all_keys[1:])

    @inlineCallbacks
    def test_batch_event_keys_with_statuses_reverse_range(self):
        batch_id = yield self.store.batch_start([('pool', 'tag')])
        events = yield self.create_events(batch_id, 5)
        sorted_keys = sorted(
            [(zero_ms(ev['timestamp']), ev.status(), ev['event_id'])
             for ev in events], reverse=True)
        all_keys = [(key, timestamp, status)
                    for (timestamp, status, key) in sorted_keys]

        page = yield self.store.batch_event_keys_with_statuses_reverse(
            batch_id, max_results=6, start=all_keys[-2][1], end=all_keys[1][1])
        self.assertEqual(list(page), all_keys[1:-1])

    @inlineCallbacks
    def test_message_event_keys_with_statuses(self):
        """
        Event keys and statuses for a message can be retrieved by index.
        """
        msg_id, msg, batch_id = yield self._create_outbound()

        ack = self.msg_helper.make_ack(msg)
        yield self.store.add_event(ack)
        drs = []
        for status in TransportEvent.DELIVERY_STATUSES:
            dr = self.msg_helper.make_delivery_report(
                msg, delivery_status=status)
            drs.append(dr)
            yield self.store.add_event(dr)

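        # Each index entry is (event_id, timestamp, status), with delivery
        # report statuses qualified as 'delivery_report.<delivery_status>'.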
        def mk_tuple(e, status):
            return e["event_id"], format_vumi_date(e["timestamp"]), status

        all_keys = [mk_tuple(ack, "ack")] + [
            mk_tuple(e, "delivery_report.%s" % (e["delivery_status"],))
            for e in drs]

        first_page = yield self.store.message_event_keys_with_statuses(
            msg_id, max_results=3)

        results = list(first_page)
        self.assertEqual(len(results), 3)
        self.assertEqual(first_page.has_next_page(), True)

        next_page = yield first_page.next_page()
        results.extend(next_page)
        self.assertEqual(len(results), 4)
        self.assertEqual(next_page.has_next_page(), False)

        self.assertEqual(results, all_keys)

    @inlineCallbacks
    def test_batch_inbound_stats(self):
        """
        batch_inbound_stats returns total and unique address counts for the
        whole batch if no time range is specified.
        """
        batch_id = yield self.store.batch_start([('pool', 'tag')])

        now = datetime.now()
        start_3 = now - timedelta(5)
        start_2 = now - timedelta(35)
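        # 5 + 3 + 2 = 10 messages spread across three distinct from_addrs.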
        yield self.create_inbound_messages(
            batch_id, 5, start_timestamp=now, from_addr=u'00005')
        yield self.create_inbound_messages(
            batch_id, 3, start_timestamp=start_3, from_addr=u'00003')
        yield self.create_inbound_messages(
            batch_id, 2, start_timestamp=start_2, from_addr=u'00002')

        inbound_stats = yield self.store.batch_inbound_stats(
            batch_id, max_results=6)
        self.assertEqual(inbound_stats, {"total": 10, "unique_addresses": 3})

    @inlineCallbacks
    def test_batch_inbound_stats_start(self):
        """
        batch_inbound_stats returns total and unique address counts for all
        messages newer than the start date if only the start date is specified.
        """
        batch_id = yield self.store.batch_start([('pool', 'tag')])

        now = datetime.now()
        start_3 = now - timedelta(5)
        start_2 = now - timedelta(35)
        messages_5 = yield self.create_inbound_messages(
            batch_id, 5, start_timestamp=now, from_addr=u'00005')
        messages_3 = yield self.create_inbound_messages(
            batch_id, 3, start_timestamp=start_3, from_addr=u'00003')
        messages_2 = yield self.create_inbound_messages(
            batch_id, 2, start_timestamp=start_2, from_addr=u'00002')
        messages = messages_5 + messages_3 + messages_2

        sorted_keys = sorted(
            (msg['timestamp'], msg['from_addr'], msg['message_id'])
            for msg in messages)
        all_keys = [(key, format_vumi_date(timestamp), addr)
                    for (timestamp, addr, key) in sorted_keys]

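        # all_keys is ordered oldest-first: 2 messages from '00002', then
        # 3 from '00003', then 5 from '00005'. Starting at index 2 skips
        # the two oldest messages.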
        inbound_stats_1 = yield self.store.batch_inbound_stats(
            batch_id, start=all_keys[2][1])

        self.assertEqual(inbound_stats_1, {"total": 8, "unique_addresses": 3})

        inbound_stats_2 = yield self.store.batch_inbound_stats(
            batch_id, start=all_keys[6][1])

        self.assertEqual(inbound_stats_2, {"total": 4, "unique_addresses": 2})

    @inlineCallbacks
    def test_batch_inbound_stats_end(self):
        """
        batch_inbound_stats returns total and unique address counts for all
        messages older than the end date if only the end date is specified.
        """
        batch_id = yield self.store.batch_start([('pool', 'tag')])

        now = datetime.now()
        start_3 = now - timedelta(5)
        start_2 = now - timedelta(35)
        messages_5 = yield self.create_inbound_messages(
            batch_id, 5, start_timestamp=now, from_addr=u'00005')
        messages_3 = yield self.create_inbound_messages(
            batch_id, 3, start_timestamp=start_3, from_addr=u'00003')
        messages_2 = yield self.create_inbound_messages(
            batch_id, 2, start_timestamp=start_2, from_addr=u'00002')
        messages = messages_5 + messages_3 + messages_2

        sorted_keys = sorted(
            (msg['timestamp'], msg['from_addr'], msg['message_id'])
            for msg in messages)
        all_keys = [(key, format_vumi_date(timestamp), addr)
                    for (timestamp, addr, key) in sorted_keys]

        inbound_stats_1 = yield self.store.batch_inbound_stats(
            batch_id, end=all_keys[-3][1])

        self.assertEqual(inbound_stats_1, {"total": 8, "unique_addresses": 3})

        inbound_stats_2 = yield self.store.batch_inbound_stats(
            batch_id, end=all_keys[-7][1])

        self.assertEqual(inbound_stats_2, {"total": 4, "unique_addresses": 2})

    @inlineCallbacks
    def test_batch_inbound_stats_range(self):
        """
        batch_inbound_stats returns total and unique address counts for all
        messages newer than the start date and older than the end date if both
        are specified.
        """
        batch_id = yield self.store.batch_start([('pool', 'tag')])

        now = datetime.now()
        start_3 = now - timedelta(5)
        start_2 = now - timedelta(35)
        messages_5 = yield self.create_inbound_messages(
            batch_id, 5, start_timestamp=now, from_addr=u'00005')
        messages_3 = yield self.create_inbound_messages(
            batch_id, 3, start_timestamp=start_3, from_addr=u'00003')
        messages_2 = yield self.create_inbound_messages(
            batch_id, 2, start_timestamp=start_2, from_addr=u'00002')
        messages = messages_5 + messages_3 + messages_2

        sorted_keys = sorted(
            (msg['timestamp'], msg['from_addr'], msg['message_id'])
            for msg in messages)
        all_keys = [(key, format_vumi_date(timestamp), addr)
                    for (timestamp, addr, key) in sorted_keys]

        inbound_stats_1 = yield self.store.batch_inbound_stats(
            batch_id, start=all_keys[2][1], end=all_keys[-3][1])

        self.assertEqual(inbound_stats_1, {"total": 6, "unique_addresses": 3})

        inbound_stats_2 = yield self.store.batch_inbound_stats(
            batch_id, start=all_keys[2][1], end=all_keys[-7][1])

        self.assertEqual(inbound_stats_2, {"total": 2, "unique_addresses": 2})

    @inlineCallbacks
    def test_batch_outbound_stats(self):
        """
        batch_outbound_stats returns total and unique address counts for the
        whole batch if no time range is specified.
        """
        batch_id = yield self.store.batch_start([('pool', 'tag')])

        now = datetime.now()
        start_3 = now - timedelta(5)
        start_2 = now - timedelta(35)
        yield self.create_outbound_messages(
            batch_id, 5, start_timestamp=now, to_addr=u'00005')
        yield self.create_outbound_messages(
            batch_id, 3, start_timestamp=start_3, to_addr=u'00003')
        yield self.create_outbound_messages(
            batch_id, 2, start_timestamp=start_2, to_addr=u'00002')

        outbound_stats = yield self.store.batch_outbound_stats(
            batch_id, max_results=6)
        self.assertEqual(outbound_stats, {"total": 10, "unique_addresses": 3})

    @inlineCallbacks
    def test_batch_outbound_stats_start(self):
        """
        batch_outbound_stats returns total and unique address counts for all
        messages newer than the start date if only the start date is specified.
        """
        batch_id = yield self.store.batch_start([('pool', 'tag')])

        now = datetime.now()
        start_3 = now - timedelta(5)
        start_2 = now - timedelta(35)
        messages_5 = yield self.create_outbound_messages(
            batch_id, 5, start_timestamp=now, to_addr=u'00005')
        messages_3 = yield self.create_outbound_messages(
            batch_id, 3, start_timestamp=start_3, to_addr=u'00003')
        messages_2 = yield self.create_outbound_messages(
            batch_id, 2, start_timestamp=start_2, to_addr=u'00002')
        messages = messages_5 + messages_3 + messages_2

        sorted_keys = sorted(
            (msg['timestamp'], msg['to_addr'], msg['message_id'])
            for msg in messages)
        all_keys = [(key, format_vumi_date(timestamp), addr)
                    for (timestamp, addr, key) in sorted_keys]

        outbound_stats_1 = yield self.store.batch_outbound_stats(
            batch_id, start=all_keys[2][1])

        self.assertEqual(outbound_stats_1, {"total": 8, "unique_addresses": 3})

        outbound_stats_2 = yield self.store.batch_outbound_stats(
            batch_id, start=all_keys[6][1])

        self.assertEqual(outbound_stats_2, {"total": 4, "unique_addresses": 2})

    @inlineCallbacks
    def test_batch_outbound_stats_end(self):
        """
        batch_outbound_stats returns total and unique address counts for all
        messages older than the end date if only the end date is specified.
        """
        batch_id = yield self.store.batch_start([('pool', 'tag')])

        now = datetime.now()
        start_3 = now - timedelta(5)
        start_2 = now - timedelta(35)
        messages_5 = yield self.create_outbound_messages(
            batch_id, 5, start_timestamp=now, to_addr=u'00005')
        messages_3 = yield self.create_outbound_messages(
            batch_id, 3, start_timestamp=start_3, to_addr=u'00003')
        messages_2 = yield self.create_outbound_messages(
            batch_id, 2, start_timestamp=start_2, to_addr=u'00002')
        messages = messages_5 + messages_3 + messages_2

        sorted_keys = sorted(
            (msg['timestamp'], msg['to_addr'], msg['message_id'])
            for msg in messages)
        all_keys = [(key, format_vumi_date(timestamp), addr)
                    for (timestamp, addr, key) in sorted_keys]

        outbound_stats_1 = yield self.store.batch_outbound_stats(
            batch_id, end=all_keys[-3][1])

        self.assertEqual(outbound_stats_1, {"total": 8, "unique_addresses": 3})

        outbound_stats_2 = yield self.store.batch_outbound_stats(
            batch_id, end=all_keys[-7][1])

        self.assertEqual(outbound_stats_2, {"total": 4, "unique_addresses": 2})

    @inlineCallbacks
    def test_batch_outbound_stats_range(self):
        """
        batch_outbound_stats returns total and unique address counts for all
        messages newer than the start date and older than the end date if both
        are specified.
        """
        batch_id = yield self.store.batch_start([('pool', 'tag')])

        now = datetime.now()
        start_3 = now - timedelta(5)
        start_2 = now - timedelta(35)
        messages_5 = yield self.create_outbound_messages(
            batch_id, 5, start_timestamp=now, to_addr=u'00005')
        messages_3 = yield self.create_outbound_messages(
            batch_id, 3, start_timestamp=start_3, to_addr=u'00003')
        messages_2 = yield self.create_outbound_messages(
            batch_id, 2, start_timestamp=start_2, to_addr=u'00002')
        messages = messages_5 + messages_3 + messages_2

        sorted_keys = sorted(
            (msg['timestamp'], msg['to_addr'], msg['message_id'])
            for msg in messages)
        all_keys = [(key, format_vumi_date(timestamp), addr)
                    for (timestamp, addr, key) in sorted_keys]

        outbound_stats_1 = yield self.store.batch_outbound_stats(
            batch_id, start=all_keys[2][1], end=all_keys[-3][1])

        self.assertEqual(outbound_stats_1, {"total": 6, "unique_addresses": 3})

        outbound_stats_2 = yield self.store.batch_outbound_stats(
            batch_id, start=all_keys[2][1], end=all_keys[-7][1])

        self.assertEqual(outbound_stats_2, {"total": 2, "unique_addresses": 2})


class TestMessageStoreCache(TestMessageStoreBase):

    def clear_cache(self, message_store):
        # FakeRedis provides a flushdb() method, but TxRedisManager doesn't,
        # and the intended behaviour of flushdb() on a submanager is unclear,
        # so we purge the keys directly instead.
        return message_store.cache.redis._purge_all()

    @inlineCallbacks
    def test_cache_batch_start(self):
        batch_id = yield self.store.batch_start([("poolA", "tag1")])
        self.assertTrue((yield self.store.cache.batch_exists(batch_id)))
        self.assertTrue(batch_id in (yield self.store.cache.get_batch_ids()))

    @inlineCallbacks
    def test_cache_add_outbound_message(self):
        msg_id, msg, batch_id = yield self._create_outbound()
        [cached_msg_id] = (
            yield self.store.cache.get_outbound_message_keys(batch_id))
        cached_to_addrs = yield self.store.cache.get_to_addrs(batch_id)
        self.assertEqual(msg_id, cached_msg_id)
        # NOTE: This functionality is disabled for now.
        # self.assertEqual([msg['to_addr']], cached_to_addrs)
        self.assertEqual([], cached_to_addrs)

    @inlineCallbacks
    def test_cache_add_inbound_message(self):
        msg_id, msg, batch_id = yield self._create_inbound()
        [cached_msg_id] = (
            yield self.store.cache.get_inbound_message_keys(batch_id))
        cached_from_addrs = yield self.store.cache.get_from_addrs(batch_id)
        self.assertEqual(msg_id, cached_msg_id)
        # NOTE: This functionality is disabled for now.
        # self.assertEqual([msg['from_addr']], cached_from_addrs)
        self.assertEqual([], cached_from_addrs)

    @inlineCallbacks
    def test_cache_add_event(self):
        msg_id, msg, batch_id = yield self._create_outbound()
        ack = TransportEvent(user_message_id=msg_id, event_type='ack',
                             sent_message_id='xyz')
        yield self.store.add_event(ack)
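        # 'sent' counts the outbound message stored above; 'ack' counts the
        # event we just added.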
        self.assertEqual((yield self.store.cache.get_event_status(batch_id)), {
            'delivery_report': 0,
            'delivery_report.delivered': 0,
            'delivery_report.failed': 0,
            'delivery_report.pending': 0,
            'ack': 1,
            'nack': 0,
            'sent': 1,
        })

    @inlineCallbacks
    def test_needs_reconciliation(self):
        msg_id, msg, batch_id = yield self._create_outbound()
        self.assertFalse((yield self.store.needs_reconciliation(batch_id)))

        msg_id, msg, batch_id = yield self._create_outbound()

        # Store via message_store
        yield self.create_outbound_messages(batch_id, 10)

        # Store one extra message in the cache so the counts differ by more
        # than the allowed threshold delta.
        recon_msg = self.msg_helper.make_outbound("foo")
        yield self.store.cache.add_outbound_message(batch_id, recon_msg)

        # Default reconciliation delta should return True
        self.assertTrue((yield self.store.needs_reconciliation(batch_id)))
        # More liberal reconciliation delta should return False
        self.assertFalse((
            yield self.store.needs_reconciliation(batch_id, delta=0.1)))

    @inlineCallbacks
    def test_reconcile_cache(self):
        cache = self.store.cache
        batch_id = yield self.store.batch_start([("pool", "tag")])

        # Store via message_store
        yield self.create_inbound_messages(batch_id, 1, from_addr='from1')
        yield self.create_inbound_messages(batch_id, 2, from_addr='from2')
        yield self.create_inbound_messages(batch_id, 3, from_addr='from3')

        outbound_messages = []
        outbound_messages.extend((yield self.create_outbound_messages(
            batch_id, 4, to_addr='to1')))
        outbound_messages.extend((yield self.create_outbound_messages(
            batch_id, 6, to_addr='to2')))

        for msg in outbound_messages:
            ack = self.msg_helper.make_ack(msg)
            yield self.store.add_event(ack)

        yield self.clear_cache(self.store)
        batch_status = yield self.store.batch_status(batch_id)
        self.assertEqual(batch_status, {})
        # Default reconciliation delta should return True
        self.assertTrue((yield self.store.needs_reconciliation(batch_id)))
        yield self.store.reconcile_cache(batch_id)
        # Reconciliation check should return False after recon.
        self.assertFalse((yield self.store.needs_reconciliation(batch_id)))
        self.assertFalse(
            (yield self.store.needs_reconciliation(batch_id, delta=0)))

        inbound_count = yield cache.count_inbound_message_keys(batch_id)
        self.assertEqual(inbound_count, 6)
        outbound_count = yield cache.count_outbound_message_keys(batch_id)
        self.assertEqual(outbound_count, 10)

        inbound_uniques = yield cache.count_from_addrs(batch_id)
        self.assertEqual(inbound_uniques, 3)
        outbound_uniques = yield cache.count_to_addrs(batch_id)
        self.assertEqual(outbound_uniques, 2)

        batch_status = yield self.store.batch_status(batch_id)
        self.assertEqual(batch_status['ack'], 10)
        self.assertEqual(batch_status['sent'], 10)

    @inlineCallbacks
    def test_reconcile_cache_with_old_and_new_messages(self):
        """
        If we're reconciling a batch that contains messages older than the
        truncation threshold and newer than the start of the recon, we still
        end up with the correct numbers.
        """
        cache = self.store.cache
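        # Lower the truncation threshold so that some of the stored messages
        # fall outside the cached key lists.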
        cache.TRUNCATE_MESSAGE_KEY_COUNT_AT = 5
        batch_id = yield self.store.batch_start([("pool", "tag")])

        # Store via message_store
        inbound_messages = []
        inbound_messages.extend((yield self.create_inbound_messages(
            batch_id, 1, from_addr='from1')))
        inbound_messages.extend((yield self.create_inbound_messages(
            batch_id, 2, from_addr='from2')))
        inbound_messages.extend((yield self.create_inbound_messages(
            batch_id, 3, from_addr='from3')))

        outbound_messages = []
        outbound_messages.extend((yield self.create_outbound_messages(
            batch_id, 4, to_addr='to1')))
        outbound_messages.extend((yield self.create_outbound_messages(
            batch_id, 6, to_addr='to2')))

        for msg in outbound_messages:
            ack = self.msg_helper.make_ack(msg)
            yield self.store.add_event(ack)
            dr = self.msg_helper.make_delivery_report(
                msg, delivery_status="delivered")
            yield self.store.add_event(dr)

        # We want one message newer than the start of the recon, and they're
        # ordered from newest to oldest.
        start_timestamp = format_vumi_date(inbound_messages[1]["timestamp"])

        yield self.store.reconcile_cache(batch_id, start_timestamp)

        inbound_count = yield cache.count_inbound_message_keys(batch_id)
        self.assertEqual(inbound_count, 6)
        outbound_count = yield cache.count_outbound_message_keys(batch_id)
        self.assertEqual(outbound_count, 10)

        inbound_uniques = yield self.store.cache.count_from_addrs(batch_id)
        self.assertEqual(inbound_uniques, 3)
        outbound_uniques = yield self.store.cache.count_to_addrs(batch_id)
        self.assertEqual(outbound_uniques, 2)

        batch_status = yield self.store.batch_status(batch_id)
        self.assertEqual(batch_status["sent"], 10)
        self.assertEqual(batch_status["ack"], 10)
        self.assertEqual(batch_status["delivery_report"], 10)
        self.assertEqual(batch_status["delivery_report.delivered"], 10)

    @inlineCallbacks
    def test_reconcile_cache_and_switch_to_counters(self):
        batch_id = yield self.store.batch_start([("pool", "tag")])
        cache = self.store.cache

        # Clear the cache and restart the batch without counters.
        yield cache.clear_batch(batch_id)
        yield cache.batch_start(batch_id, use_counters=False)

        # Store via message_store
        yield self.create_inbound_messages(batch_id, 1, from_addr='from1')
        yield self.create_inbound_messages(batch_id, 2, from_addr='from2')
        yield self.create_inbound_messages(batch_id, 3, from_addr='from3')

        outbound_messages = []
        outbound_messages.extend((yield self.create_outbound_messages(
            batch_id, 4, to_addr='to1')))
        outbound_messages.extend((yield self.create_outbound_messages(
            batch_id, 6, to_addr='to2')))

        for msg in outbound_messages:
            ack = self.msg_helper.make_ack(msg)
            yield self.store.add_event(ack)

        # This will fail if we're using counter-based events with a ZSET.
        events_scard = yield cache.redis.scard(cache.event_key(batch_id))
        # HACK: We're not tracking these in the SET anymore.
        #       See HACK comment in message_store_cache.py.
        # self.assertEqual(events_scard, 10)
        self.assertEqual(events_scard, 0)

        yield self.clear_cache(self.store)
        batch_status = yield self.store.batch_status(batch_id)
        self.assertEqual(batch_status, {})
        # Default reconciliation delta should return True
        self.assertTrue((yield self.store.needs_reconciliation(batch_id)))
        yield self.store.reconcile_cache(batch_id)
        # Reconciliation check should return False after recon.
        self.assertFalse((yield self.store.needs_reconciliation(batch_id)))
        self.assertFalse(
            (yield self.store.needs_reconciliation(batch_id, delta=0)))

        inbound_count = yield cache.count_inbound_message_keys(batch_id)
        self.assertEqual(inbound_count, 6)
        outbound_count = yield cache.count_outbound_message_keys(batch_id)
        self.assertEqual(outbound_count, 10)

        inbound_uniques = yield self.store.cache.count_from_addrs(batch_id)
        self.assertEqual(inbound_uniques, 3)
        outbound_uniques = yield self.store.cache.count_to_addrs(batch_id)
        self.assertEqual(outbound_uniques, 2)

        batch_status = yield self.store.batch_status(batch_id)
        self.assertEqual(batch_status['ack'], 10)
        self.assertEqual(batch_status['sent'], 10)

        # This will fail if we're using old-style events with a SET.
        events_zcard = yield cache.redis.zcard(cache.event_key(batch_id))
        self.assertEqual(events_zcard, 10)

    @inlineCallbacks
    def test_find_inbound_keys_matching(self):
        batch_id = yield self.store.batch_start([("pool", "tag")])

        # Store via message_store
        messages = yield self.create_inbound_messages(batch_id, 10)

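        # wait=True blocks until the background key search finishes, so the
        # query should no longer be in progress afterwards.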
        token = yield self.store.find_inbound_keys_matching(batch_id, [{
            'key': 'msg.content',
            'pattern': '.*',
            'flags': 'i',
        }], wait=True)

        keys = yield self.store.get_keys_for_token(batch_id, token)
        in_progress = yield self.store.cache.is_query_in_progress(
            batch_id, token)
        self.assertEqual(len(keys), 10)
        self.assertEqual(
            10, (yield self.store.count_keys_for_token(batch_id, token)))
        self.assertEqual(keys, [msg['message_id'] for msg in messages])
        self.assertFalse(in_progress)

    @inlineCallbacks
    def test_find_outbound_keys_matching(self):
        batch_id = yield self.store.batch_start([("pool", "tag")])

        # Store via message_store
        messages = yield self.create_outbound_messages(batch_id, 10)

        token = yield self.store.find_outbound_keys_matching(batch_id, [{
            'key': 'msg.content',
            'pattern': '.*',
            'flags': 'i',
        }], wait=True)

        keys = yield self.store.get_keys_for_token(batch_id, token)
        in_progress = yield self.store.cache.is_query_in_progress(
            batch_id, token)
        self.assertEqual(len(keys), 10)
        self.assertEqual(
            10, (yield self.store.count_keys_for_token(batch_id, token)))
        self.assertEqual(keys, [msg['message_id'] for msg in messages])
        self.assertFalse(in_progress)

    @inlineCallbacks
    def test_get_inbound_message_keys(self):
        batch_id = yield self.store.batch_start([('pool', 'tag')])
        messages = yield self.create_inbound_messages(batch_id, 10)

        keys = yield self.store.get_inbound_message_keys(batch_id)
        self.assertEqual(keys, [msg['message_id'] for msg in messages])

    @inlineCallbacks
    def test_get_inbound_message_keys_with_timestamp(self):
        batch_id = yield self.store.batch_start([('pool', 'tag')])
        messages = yield self.create_inbound_messages(batch_id, 10)

        results = dict((yield self.store.get_inbound_message_keys(
            batch_id, with_timestamp=True)))
        for msg in messages:
            found = results[msg['message_id']]
            expected = time.mktime(msg['timestamp'].timetuple())
            self.assertAlmostEqual(found, expected)

    @inlineCallbacks
    def test_get_outbound_message_keys(self):
        batch_id = yield self.store.batch_start([('pool', 'tag')])
        messages = yield self.create_outbound_messages(batch_id, 10)

        keys = yield self.store.get_outbound_message_keys(batch_id)
        self.assertEqual(keys, [msg['message_id'] for msg in messages])

    @inlineCallbacks
    def test_get_outbound_message_keys_with_timestamp(self):
        batch_id = yield self.store.batch_start([('pool', 'tag')])
        messages = yield self.create_outbound_messages(batch_id, 10)

        results = dict((yield self.store.get_outbound_message_keys(
            batch_id, with_timestamp=True)))
        for msg in messages:
            found = results[msg['message_id']]
            expected = time.mktime(msg['timestamp'].timetuple())
            self.assertAlmostEqual(found, expected)


class TestMigrationFunctions(TestMessageStoreBase):

    @inlineCallbacks
    def test_add_batches_to_event_no_batches(self):
        """
        If the stored event has no batches, they're looked up from the outbound
        message and added to the event.
        """
        msg_id, msg, batch_id = yield self._create_outbound()

        ack = self.msg_helper.make_ack(msg)
        ack_id = ack['event_id']
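        # Passing batch_ids=[] stores the event without any batch references.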
        yield self.store.add_event(ack, batch_ids=[])

        event = yield self.store.events.load(ack_id)
        self.assertEqual(event.batches.keys(), [])

        updated = yield add_batches_to_event(event)
        self.assertEqual(updated, True)
        self.assertEqual(event.batches.keys(), [batch_id])

    @inlineCallbacks
    def test_add_batches_to_event_with_batches(self):
        """
        If the stored event already has batches, we do nothing.
        """
        msg_id, msg, batch_id = yield self._create_outbound()

        ack = self.msg_helper.make_ack(msg)
        ack_id = ack['event_id']
        yield self.store.add_event(ack, batch_ids=[batch_id])

        event = yield self.store.events.load(ack_id)
        self.assertEqual(event.batches.keys(), [batch_id])

        updated = yield add_batches_to_event(event)
        self.assertEqual(updated, False)
        self.assertEqual(event.batches.keys(), [batch_id])
vumi/components/tests/test_schedule_manager.py
"""Tests for vumi.components.schedule_manager."""

from datetime import datetime

from vumi.components.schedule_manager import ScheduleManager
from vumi.tests.utils import LogCatcher
from vumi.tests.helpers import VumiTestCase


class TestScheduleManager(VumiTestCase):
    def assert_schedule_next(self, config, since_dt, expected_next_dt):
        sm = ScheduleManager(config)
        self.assertEqual(sm.get_next(since_dt), expected_next_dt)

    def assert_config_error(self, config, errmsg):
        sm = ScheduleManager(config)
        with LogCatcher() as logger:
            self.assertEqual(None, sm.get_next(None))
            [err] = logger.errors
            self.assertEqual(err['why'], 'Error processing schedule.')
            self.assertEqual(err['failure'].value.args[0], errmsg)
        [f] = self.flushLoggedErrors(ValueError)
        self.assertEqual(f, err['failure'])

    def test_invalid_recurring(self):
        self.assert_config_error(
            {'recurring': 'No, iterate.'},
            "Invalid value for 'recurring': 'No, iterate.'")

    def test_daily_schedule_same_day(self):
        self.assert_schedule_next(
            {'recurring': 'daily', 'time': '12:00:00'},
            datetime(2012, 11, 20, 11, 0, 0),
            datetime(2012, 11, 20, 12, 0, 0))

    def test_daily_schedule_next_day(self):
        self.assert_schedule_next(
            {'recurring': 'daily', 'time': '12:00:00'},
            datetime(2012, 11, 20, 13, 0, 0),
            datetime(2012, 11, 21, 12, 0, 0))

    def test_daily_invalid_time(self):
        self.assert_config_error(
            {'recurring': 'daily', 'time': 'lunch time'},
            "time data 'lunch time' does not match format '%H:%M:%S'")

    def test_day_of_month_same_day(self):
        self.assert_schedule_next(
            {'recurring': 'day_of_month', 'time': '12:00:00', 'days': '20 25'},
            datetime(2012, 11, 20, 11, 0, 0),
            datetime(2012, 11, 20, 12, 0, 0))
        self.assert_schedule_next(
            {'recurring': 'day_of_month', 'time': '12:00:00', 'days': '15 20'},
            datetime(2012, 11, 20, 11, 0, 0),
            datetime(2012, 11, 20, 12, 0, 0))

    def test_day_of_month_same_month(self):
        self.assert_schedule_next(
            {'recurring': 'day_of_month', 'time': '12:00:00', 'days': '20 25'},
            datetime(2012, 11, 20, 13, 0, 0),
            datetime(2012, 11, 25, 12, 0, 0))
        self.assert_schedule_next(
            {'recurring': 'day_of_month', 'time': '12:00:00', 'days': '15 25'},
            datetime(2012, 11, 20, 13, 0, 0),
            datetime(2012, 11, 25, 12, 0, 0))

    def test_day_of_month_next_month(self):
        self.assert_schedule_next(
            {'recurring': 'day_of_month', 'time': '12:00:00', 'days': '15 20'},
            datetime(2012, 11, 20, 13, 0, 0),
            datetime(2012, 12, 15, 12, 0, 0))
        self.assert_schedule_next(
            {'recurring': 'day_of_month', 'time': '12:00:00', 'days': '1 15'},
            datetime(2012, 12, 20, 13, 0, 0),
            datetime(2013, 1, 1, 12, 0, 0))

    def test_day_of_month_invalid_time(self):
        self.assert_config_error(
            {'recurring': 'day_of_month', 'time': 'lunch time'},
            "time data 'lunch time' does not match format '%H:%M:%S'")

    def test_day_of_month_invalid_days(self):
        self.assert_config_error(
            {'recurring': 'day_of_month', 'time': '12:00:00', 'days': 'x'},
            "Invalid value for 'days': 'x'")
        self.assert_config_error(
            {'recurring': 'day_of_month', 'time': '12:00:00', 'days': '0 1'},
            "Invalid value for 'days': '0 1'")
        self.assert_config_error(
            {'recurring': 'day_of_month', 'time': '12:00:00', 'days': '32'},
            "Invalid value for 'days': '32'")

    def test_day_of_week_same_day(self):
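        # Day numbers follow the ISO weekday convention (1=Monday .. 7=Sunday);
        # 20 November 2012 is a Tuesday, i.e. day 2.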
        self.assert_schedule_next(
            {'recurring': 'day_of_week', 'time': '12:00:00', 'days': '2 4'},
            datetime(2012, 11, 20, 11, 0, 0),
            datetime(2012, 11, 20, 12, 0, 0))
        self.assert_schedule_next(
            {'recurring': 'day_of_week', 'time': '12:00:00', 'days': '1 2'},
            datetime(2012, 11, 20, 11, 0, 0),
            datetime(2012, 11, 20, 12, 0, 0))

    def test_day_of_week_same_week(self):
        self.assert_schedule_next(
            {'recurring': 'day_of_week', 'time': '12:00:00', 'days': '2 4'},
            datetime(2012, 11, 20, 13, 0, 0),
            datetime(2012, 11, 22, 12, 0, 0))
        self.assert_schedule_next(
            {'recurring': 'day_of_week', 'time': '12:00:00', 'days': '6 7'},
            datetime(2012, 11, 20, 13, 0, 0),
            datetime(2012, 11, 24, 12, 0, 0))

    def test_day_of_week_next_week(self):
        self.assert_schedule_next(
            {'recurring': 'day_of_week', 'time': '12:00:00', 'days': '1'},
            datetime(2012, 11, 20, 13, 0, 0),
            datetime(2012, 11, 26, 12, 0, 0))
        self.assert_schedule_next(
            {'recurring': 'day_of_week', 'time': '12:00:00', 'days': '1'},
            datetime(2012, 12, 20, 13, 0, 0),
            datetime(2012, 12, 24, 12, 0, 0))

    def test_day_of_week_invalid_time(self):
        self.assert_config_error(
            {'recurring': 'day_of_week', 'time': 'lunch time'},
            "time data 'lunch time' does not match format '%H:%M:%S'")

    def test_day_of_week_invalid_days(self):
        self.assert_config_error(
            {'recurring': 'day_of_week', 'time': '12:00:00', 'days': 'x'},
            "Invalid value for 'days': 'x'")
        self.assert_config_error(
            {'recurring': 'day_of_week', 'time': '12:00:00', 'days': '0 1'},
            "Invalid value for 'days': '0 1'")
        self.assert_config_error(
            {'recurring': 'day_of_week', 'time': '12:00:00', 'days': '8'},
            "Invalid value for 'days': '8'")

    def test_never(self):
        self.assert_schedule_next(
            {'recurring': 'never'},
            datetime(2012, 11, 20, 13, 0, 0),
            None)
vumi/components/tests/message_store_old_models.py
"""Previous versions of message store models."""

from calendar import timegm
from datetime import datetime

from vumi.message import (
    TransportUserMessage, TransportEvent, format_vumi_date, parse_vumi_date)
from vumi.persist.model import Model
from vumi.persist.fields import (
    VumiMessage, ForeignKey, ListOf, Dynamic, Tag, Unicode, ManyToMany)
from vumi.components.message_store_migrators import (
    InboundMessageMigrator, OutboundMessageMigrator, EventMigrator)


class BatchVNone(Model):
    bucket = 'batch'

    # key is batch_id
    tags = ListOf(Tag())
    metadata = Dynamic(Unicode())


class OutboundMessageVNone(Model):
    bucket = 'outboundmessage'

    # key is message_id
    msg = VumiMessage(TransportUserMessage)
    batch = ForeignKey(BatchVNone, null=True)


class EventVNone(Model):
    bucket = 'event'

    # key is event_id
    event = VumiMessage(TransportEvent)
    message = ForeignKey(OutboundMessageVNone)


class InboundMessageVNone(Model):
    bucket = 'inboundmessage'

    # key is message_id
    msg = VumiMessage(TransportUserMessage)
    batch = ForeignKey(BatchVNone, null=True)


class OutboundMessageV1(Model):
    bucket = 'outboundmessage'

    VERSION = 1
    MIGRATOR = OutboundMessageMigrator

    # key is message_id
    msg = VumiMessage(TransportUserMessage)
    batches = ManyToMany(BatchVNone)


class InboundMessageV1(Model):
    bucket = 'inboundmessage'

    VERSION = 1
    MIGRATOR = InboundMessageMigrator

    # key is message_id
    msg = VumiMessage(TransportUserMessage)
    batches = ManyToMany(BatchVNone)


class OutboundMessageV2(Model):
    bucket = 'outboundmessage'

    VERSION = 2
    MIGRATOR = OutboundMessageMigrator

    # key is message_id
    msg = VumiMessage(TransportUserMessage)
    batches = ManyToMany(BatchVNone)

    # Extra fields for compound indexes
    batches_with_timestamps = ListOf(Unicode(), index=True)

    def save(self):
        # We override this method to set our index fields before saving.
        batches_with_timestamps = []
        timestamp = self.msg['timestamp']
        for batch_id in self.batches.keys():
            batches_with_timestamps.append(u"%s$%s" % (batch_id, timestamp))
        self.batches_with_timestamps = batches_with_timestamps
        return super(OutboundMessageV2, self).save()


class InboundMessageV2(Model):
    bucket = 'inboundmessage'

    VERSION = 2
    MIGRATOR = InboundMessageMigrator

    # key is message_id
    msg = VumiMessage(TransportUserMessage)
    batches = ManyToMany(BatchVNone)

    # Extra fields for compound indexes
    batches_with_timestamps = ListOf(Unicode(), index=True)

    def save(self):
        # We override this method to set our index fields before saving.
        batches_with_timestamps = []
        timestamp = self.msg['timestamp']
        for batch_id in self.batches.keys():
            batches_with_timestamps.append(u"%s$%s" % (batch_id, timestamp))
        self.batches_with_timestamps = batches_with_timestamps
        return super(InboundMessageV2, self).save()


class OutboundMessageV3(Model):
    bucket = 'outboundmessage'

    VERSION = 3
    MIGRATOR = OutboundMessageMigrator

    # key is message_id
    msg = VumiMessage(TransportUserMessage)
    batches = ManyToMany(BatchVNone)

    # Extra fields for compound indexes
    batches_with_timestamps = ListOf(Unicode(), index=True)
    batches_with_addresses = ListOf(Unicode(), index=True)

    def save(self):
        # We override this method to set our index fields before saving.
        batches_with_timestamps = []
        batches_with_addresses = []
        timestamp = format_vumi_date(self.msg['timestamp'])
        for batch_id in self.batches.keys():
            batches_with_timestamps.append(u"%s$%s" % (batch_id, timestamp))
            batches_with_addresses.append(
                u"%s$%s$%s" % (batch_id, timestamp, self.msg['to_addr']))
        self.batches_with_timestamps = batches_with_timestamps
        self.batches_with_addresses = batches_with_addresses
        return super(OutboundMessageV3, self).save()


class InboundMessageV3(Model):
    bucket = 'inboundmessage'

    VERSION = 3
    MIGRATOR = InboundMessageMigrator

    # key is message_id
    msg = VumiMessage(TransportUserMessage)
    batches = ManyToMany(BatchVNone)

    # Extra fields for compound indexes
    batches_with_timestamps = ListOf(Unicode(), index=True)
    batches_with_addresses = ListOf(Unicode(), index=True)

    def save(self):
        # We override this method to set our index fields before saving.
        batches_with_timestamps = []
        batches_with_addresses = []
        timestamp = self.msg['timestamp']
        for batch_id in self.batches.keys():
            batches_with_timestamps.append(u"%s$%s" % (batch_id, timestamp))
            batches_with_addresses.append(
                u"%s$%s$%s" % (batch_id, timestamp, self.msg['from_addr']))
        self.batches_with_timestamps = batches_with_timestamps
        self.batches_with_addresses = batches_with_addresses
        return super(InboundMessageV3, self).save()


class EventV1(Model):
    bucket = 'event'

    VERSION = 1
    MIGRATOR = EventMigrator

    # key is event_id
    event = VumiMessage(TransportEvent)
    message = ForeignKey(OutboundMessageV3)

    # Extra fields for compound indexes
    message_with_status = Unicode(index=True, null=True)

    def save(self):
        # We override this method to set our index fields before saving.
        timestamp = self.event['timestamp']
        status = self.event['event_type']
        if status == "delivery_report":
            status = "%s.%s" % (status, self.event['delivery_status'])
        self.message_with_status = u"%s$%s$%s" % (
            self.message.key, timestamp, status)
        return super(EventV1, self).save()
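
# Illustrative sketch (not part of the original module) of the compound
# index value EventV1.save() computes; the message key and timestamp below
# are hypothetical:
#
#     'message-key-1$2012-11-20 12:00:00$delivery_report.delivered'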


def to_reverse_timestamp(vumi_timestamp):
    """
    Turn a vumi_date-formatted string into a string that sorts in reverse order
    and can be turned back into a timestamp later.

    This is done by converting to a unix timestamp and subtracting it from
    0xffffffffff (2**40 - 1) to get a number well outside the range
    representable by the datetime module. The result is returned as a
    hexadecimal string.
    """
    timestamp = timegm(parse_vumi_date(vumi_timestamp).timetuple())
    return "%X" % (0xffffffffff - timestamp)


def from_reverse_timestamp(reverse_timestamp):
    """
    Turn a reverse timestamp string (from `to_reverse_timestamp()`) into a
    vumi_date-formatted string.
    """
    timestamp = 0xffffffffff - int(reverse_timestamp, 16)
    return format_vumi_date(datetime.utcfromtimestamp(timestamp))
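
# Illustrative round-trip sketch (not part of the original module); the
# example date is arbitrary:
#
#     >>> ts = format_vumi_date(datetime(2012, 11, 20, 12, 0, 0))
#     >>> from_reverse_timestamp(to_reverse_timestamp(ts)) == ts
#     True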


class OutboundMessageV4(Model):
    bucket = 'outboundmessage'

    VERSION = 4
    MIGRATOR = OutboundMessageMigrator

    # key is message_id
    msg = VumiMessage(TransportUserMessage)
    batches = ManyToMany(BatchVNone)

    # Extra fields for compound indexes
    batches_with_timestamps = ListOf(Unicode(), index=True)
    batches_with_addresses = ListOf(Unicode(), index=True)
    batches_with_addresses_reverse = ListOf(Unicode(), index=True)

    def save(self):
        # We override this method to set our index fields before saving.
        self.batches_with_timestamps = []
        self.batches_with_addresses = []
        self.batches_with_addresses_reverse = []
        timestamp = self.msg['timestamp']
        if not isinstance(timestamp, basestring):
            timestamp = format_vumi_date(timestamp)
        reverse_ts = to_reverse_timestamp(timestamp)
        for batch_id in self.batches.keys():
            self.batches_with_timestamps.append(
                u"%s$%s" % (batch_id, timestamp))
            self.batches_with_addresses.append(
                u"%s$%s$%s" % (batch_id, timestamp, self.msg['to_addr']))
            self.batches_with_addresses_reverse.append(
                u"%s$%s$%s" % (batch_id, reverse_ts, self.msg['to_addr']))
        return super(OutboundMessageV4, self).save()


class InboundMessageV4(Model):
    bucket = 'inboundmessage'

    VERSION = 4
    MIGRATOR = InboundMessageMigrator

    # key is message_id
    msg = VumiMessage(TransportUserMessage)
    batches = ManyToMany(BatchVNone)

    # Extra fields for compound indexes
    batches_with_timestamps = ListOf(Unicode(), index=True)
    batches_with_addresses = ListOf(Unicode(), index=True)
    batches_with_addresses_reverse = ListOf(Unicode(), index=True)

    def save(self):
        # We override this method to set our index fields before saving.
        self.batches_with_timestamps = []
        self.batches_with_addresses = []
        self.batches_with_addresses_reverse = []
        timestamp = self.msg['timestamp']
        if not isinstance(timestamp, basestring):
            timestamp = format_vumi_date(timestamp)
        reverse_ts = to_reverse_timestamp(timestamp)
        for batch_id in self.batches.keys():
            self.batches_with_timestamps.append(
                u"%s$%s" % (batch_id, timestamp))
            self.batches_with_addresses.append(
                u"%s$%s$%s" % (batch_id, timestamp, self.msg['from_addr']))
            self.batches_with_addresses_reverse.append(
                u"%s$%s$%s" % (batch_id, reverse_ts, self.msg['from_addr']))
        return super(InboundMessageV4, self).save()
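
# Illustrative sketch (not part of the original module) of the index values
# the V4 save() overrides produce; the batch id, address and reverse
# timestamp hex are hypothetical:
#
#     batches_with_timestamps:
#         [u'batch-1$2012-11-20 12:00:00.000000']
#     batches_with_addresses:
#         [u'batch-1$2012-11-20 12:00:00.000000$+27831234567']
#     batches_with_addresses_reverse:
#         [u'batch-1$<reverse-hex-timestamp>$+27831234567']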

vumi/components/tests/test_tagpool_api.py
# -*- coding: utf-8 -*-

"""Tests for vumi.components.tagpool_api."""

from txjsonrpc.web.jsonrpc import Proxy
from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.internet import reactor
from twisted.web.server import Site
from twisted.python import log

from vumi.components.tagpool_api import TagpoolApiServer, TagpoolApiWorker
from vumi.components.tagpool import TagpoolManager
from vumi.utils import http_request
from vumi.tests.helpers import VumiTestCase, WorkerHelper, PersistenceHelper


class TestTagpoolApiServer(VumiTestCase):

    @inlineCallbacks
    def setUp(self):
        self.persistence_helper = self.add_helper(PersistenceHelper())
        self.redis = yield self.persistence_helper.get_redis_manager()
        self.tagpool = TagpoolManager(self.redis)
        site = Site(TagpoolApiServer(self.tagpool))
        self.server = yield reactor.listenTCP(0, site, interface='127.0.0.1')
        self.add_cleanup(self.server.loseConnection)
        addr = self.server.getHost()
        self.proxy = Proxy("http://%s:%d/" % (addr.host, addr.port))
        yield self.setup_tags()

    @inlineCallbacks
    def setup_tags(self):
        # pool1 has two tags which are free
        yield self.tagpool.declare_tags([
            ("pool1", "tag1"), ("pool1", "tag2")])
        # pool2 has two tags which are used
        yield self.tagpool.declare_tags([
            ("pool2", "tag1"), ("pool2", "tag2")])
        yield self.tagpool.acquire_specific_tag(["pool2", "tag1"])
        yield self.tagpool.acquire_specific_tag(["pool2", "tag2"])
        # pool3 is empty but has metadata
        yield self.tagpool.set_metadata("pool3", {"meta": "data"})

    def _check_reason(self, result, expected_owner, expected_reason):
        owner, reason = result
        self.assertEqual(owner, expected_owner)
        self.assertEqual(reason.pop('owner'), expected_owner)
        self.assertTrue(isinstance(reason.pop('timestamp'), float))
        self.assertEqual(reason, expected_reason)

    @inlineCallbacks
    def test_acquire_tag(self):
        result = yield self.proxy.callRemote("acquire_tag", "pool1")
        self.assertEqual(result, ["pool1", "tag1"])
        self.assertEqual((yield self.tagpool.inuse_tags("pool1")),
                         [("pool1", "tag1")])
        result = yield self.proxy.callRemote("acquire_tag", "pool2")
        self.assertEqual(result, None)

    @inlineCallbacks
    def test_acquire_tag_with_owner_and_reason(self):
        result = yield self.proxy.callRemote(
            "acquire_tag", "pool1", "me", {"foo": "bar"})
        self.assertEqual(result, ["pool1", "tag1"])
        result = yield self.tagpool.acquired_by(["pool1", "tag1"])
        self._check_reason(result, "me", {"foo": "bar"})

    @inlineCallbacks
    def test_acquire_specific_tag(self):
        result = yield self.proxy.callRemote("acquire_specific_tag",
                                             ["pool1", "tag1"])
        self.assertEqual(result, ["pool1", "tag1"])
        self.assertEqual((yield self.tagpool.inuse_tags("pool1")),
                         [("pool1", "tag1")])
        result = yield self.proxy.callRemote("acquire_specific_tag",
                                             ["pool2", "tag1"])
        self.assertEqual(result, None)

    @inlineCallbacks
    def test_acquire_specific_tag_with_owner_and_reason(self):
        result = yield self.proxy.callRemote(
            "acquire_specific_tag", ["pool1", "tag1"], "me", {"foo": "bar"})
        self.assertEqual(result, ["pool1", "tag1"])
        result = yield self.tagpool.acquired_by(["pool1", "tag1"])
        self._check_reason(result, "me", {"foo": "bar"})

    @inlineCallbacks
    def test_release_tag(self):
        result = yield self.proxy.callRemote("release_tag",
                                             ["pool1", "tag1"])
        self.assertEqual(result, None)
        result = yield self.proxy.callRemote("release_tag",
                                             ["pool2", "tag1"])
        self.assertEqual(result, None)
        self.assertEqual((yield self.tagpool.inuse_tags("pool2")),
                         [("pool2", "tag2")])

    @inlineCallbacks
    def test_declare_tags(self):
        tags = [("newpool", "tag1"), ("newpool", "tag2")]
        result = yield self.proxy.callRemote("declare_tags", tags)
        self.assertEqual(result, None)
        free_tags = yield self.tagpool.free_tags("newpool")
        self.assertEqual(sorted(free_tags), sorted(tags))

    @inlineCallbacks
    def test_get_metadata(self):
        result = yield self.proxy.callRemote("get_metadata", "pool3")
        self.assertEqual(result, {"meta": "data"})
        result = yield self.proxy.callRemote("get_metadata", "pool1")
        self.assertEqual(result, {})

    @inlineCallbacks
    def test_set_metadata(self):
        result = yield self.proxy.callRemote("set_metadata", "newpool",
                                             {"my": "data"})
        self.assertEqual(result, None)
        self.assertEqual((yield self.tagpool.get_metadata("newpool")),
                         {"my": "data"})

    @inlineCallbacks
    def test_purge_pool(self):
        result = yield self.proxy.callRemote("purge_pool", "pool1")
        self.assertEqual(result, None)
        self.assertEqual((yield self.tagpool.free_tags("pool1")), [])

    @inlineCallbacks
    def test_purge_pool_with_keys_in_use(self):
        d = self.proxy.callRemote("purge_pool", "pool2")
        yield d.addErrback(lambda f: log.err(f))
        errors = self.flushLoggedErrors('txjsonrpc.jsonrpclib.Fault')
        self.assertEqual(len(errors), 1)
        server_errors = self.flushLoggedErrors(
            'vumi.components.tagpool.TagpoolError')
        self.assertEqual(len(server_errors), 1)

    @inlineCallbacks
    def test_list_pools(self):
        result = yield self.proxy.callRemote("list_pools")
        self.assertEqual(sorted(result), ["pool1", "pool2", "pool3"])

    @inlineCallbacks
    def test_free_tags(self):
        result = yield self.proxy.callRemote("free_tags", "pool1")
        self.assertEqual(
            sorted(result), [["pool1", "tag1"], ["pool1", "tag2"]])
        result = yield self.proxy.callRemote("free_tags", "pool2")
        self.assertEqual(result, [])
        result = yield self.proxy.callRemote("free_tags", "pool3")
        self.assertEqual(result, [])

    @inlineCallbacks
    def test_inuse_tags(self):
        result = yield self.proxy.callRemote("inuse_tags", "pool1")
        self.assertEqual(result, [])
        result = yield self.proxy.callRemote("inuse_tags", "pool2")
        self.assertEqual(
            sorted(result), [["pool2", "tag1"], ["pool2", "tag2"]])
        result = yield self.proxy.callRemote("inuse_tags", "pool3")
        self.assertEqual(result, [])

    @inlineCallbacks
    def test_acquired_by(self):
        result = yield self.proxy.callRemote("acquired_by", ["pool1", "tag1"])
        self.assertEqual(result, [None, None])
        result = yield self.proxy.callRemote("acquired_by", ["pool2", "tag1"])
        self._check_reason(result, None, {})
        yield self.tagpool.acquire_tag("pool1", owner="me",
                                       reason={"foo": "bar"})
        result = yield self.proxy.callRemote("acquired_by", ["pool1", "tag1"])
        self._check_reason(result, "me", {"foo": "bar"})

    @inlineCallbacks
    def test_owned_tags(self):
        result = yield self.proxy.callRemote("owned_tags", None)
        self.assertEqual(sorted(result),
                         [[u'pool2', u'tag1'], [u'pool2', u'tag2']])
        yield self.tagpool.acquire_tag("pool1", owner="me",
                                       reason={"foo": "bar"})
        result = yield self.proxy.callRemote("owned_tags", "me")
        self.assertEqual(result, [["pool1", "tag1"]])


class TestTagpoolApiWorker(VumiTestCase):

    def setUp(self):
        self.persistence_helper = self.add_helper(PersistenceHelper())
        self.worker_helper = self.add_helper(WorkerHelper())

    @inlineCallbacks
    def cleanup_worker(self, worker):
        if worker.running:
            yield worker.redis_manager._purge_all()
            yield worker.redis_manager.close_manager()
            yield worker.stopService()

    @inlineCallbacks
    def get_api_worker(self, config=None, start=True):
        config = {} if config is None else config
        config.setdefault('worker_name', 'test_api_worker')
        config.setdefault('twisted_endpoint', 'tcp:0')
        config.setdefault('web_path', 'api')
        config.setdefault('health_path', 'health')
        config = self.persistence_helper.mk_config(config)
        worker = yield self.worker_helper.get_worker(
            TagpoolApiWorker, config, start)
        self.add_cleanup(self.cleanup_worker, worker)
        if not start:
            returnValue(worker)
        yield worker.startService()
        port = worker.services[0]._waitingForPort.result
        addr = port.getHost()
        proxy = Proxy("http://%s:%d/api/" % (addr.host, addr.port))
        returnValue((worker, proxy))

    @inlineCallbacks
    def test_list_methods(self):
        worker, proxy = yield self.get_api_worker()
        result = yield proxy.callRemote('system.listMethods')
        self.assertTrue(u'acquire_tag' in result)

    @inlineCallbacks
    def test_method_help(self):
        worker, proxy = yield self.get_api_worker()
        result = yield proxy.callRemote('system.methodHelp', 'acquire_tag')
        self.assertEqual(result, u"\n".join([
            "Acquire a tag from the pool (returns None if"
            " no tags are avaliable).",
            "",
            ":param Unicode pool:",
            "    Name of pool to acquire tag from.",
            ":param Unicode owner:",
            "    Owner acquiring tag (or None). May be null. Default: None.",
            ":param Dict reason:",
            "    Metadata on why tag is being acquired (or None)."
            " May be null.",
            "    Default: None.",
            ":rtype Tag:",
            "    Tag acquired (or None).",
        ]))

    @inlineCallbacks
    def test_method_signature(self):
        worker, proxy = yield self.get_api_worker()
        result = yield proxy.callRemote('system.methodSignature',
                                        'acquire_tag')
        self.assertEqual(result, [[u'array', u'string', u'string', u'struct']])

    @inlineCallbacks
    def test_health_resource(self):
        worker, proxy = yield self.get_api_worker()
        result = yield http_request(
            "http://%s:%s/health" % (proxy.host, proxy.port),
            data=None, method='GET')
        self.assertEqual(result, "OK")

vumi/codecs/ivumi_codecs.py
from zope.interface import Interface


class IVumiCodec(Interface):

    def encode(unicode_string, encoding, errors):
        """
        Encode a unicode_string in a specific encoding and return
        the byte string.
        """

    def decode(byte_string, encoding, errors):
        """
        Decode a bytestring in a specific encoding and return the
        unicode string
        """

vumi/codecs/vumi_codecs.py
# -*- test-case-name: vumi.codecs.tests.test_vumi_codecs -*-
# -*- coding: utf-8 -*-
import codecs
import sys

from vumi.codecs.ivumi_codecs import IVumiCodec

from zope.interface import implements


class VumiCodecException(Exception):
    pass


class GSM7BitCodec(codecs.Codec):
    """
    This has largely been copied from:
    http://stackoverflow.com/questions/13130935/decode-7-bit-gsm
    """

    gsm_basic_charset = (
        u"@£$¥èéùìòÇ\nØø\rÅåΔ_ΦΓΛΩΠΨΣΘΞ\x1bÆæßÉ !\"#¤%&'()*+,-./0123456789:;"
        u"<=>?¡ABCDEFGHIJKLMNOPQRSTUVWXYZÄÖÑÜ`¿abcdefghijklmnopqrstuvwxyzäö"
        u"ñüà")

    gsm_basic_charset_map = dict(
        (l, i) for i, l in enumerate(gsm_basic_charset))

    gsm_extension = (
        u"````````````````````^```````````````````{}`````\\````````````[~]`"
        u"|````````````````````````````````````€``````````````````````````")

    gsm_extension_map = dict((l, i) for i, l in enumerate(gsm_extension))

    def encode(self, unicode_string, errors='strict'):
        result = []
        for position, c in enumerate(unicode_string):
            idx = self.gsm_basic_charset_map.get(c)
            if idx is not None:
                result.append(chr(idx))
                continue
            idx = self.gsm_extension_map.get(c)
            if idx is not None:
                result.append(chr(27) + chr(idx))
            else:
                result.append(
                    self.handle_encode_error(
                        c, errors, position, unicode_string))

        obj = ''.join(result)
        return (obj, len(obj))

    def handle_encode_error(self, char, handler_type, position, obj):
        handler = getattr(
            self, 'handle_encode_%s_error' % (handler_type,), None)
        if handler is None:
            raise VumiCodecException(
                'Invalid errors type %s for GSM7BitCodec' % (handler_type,))
        return handler(char, position, obj)

    def handle_encode_strict_error(self, char, position, obj):
        # Pass the full input string as the error's object and name the
        # offending character in the reason.
        raise UnicodeEncodeError(
            'gsm0338', obj, position, position + 1,
            'character %r not in GSM 03.38' % (char,))

    def handle_encode_ignore_error(self, char, position, obj):
        return ''

    def handle_encode_replace_error(self, char, position, obj):
        return chr(self.gsm_basic_charset_map.get('?'))

    def decode(self, byte_string, errors='strict'):
        res = iter(byte_string)
        result = []
        for position, c in enumerate(res):
            try:
                if c == chr(27):
                    c = next(res)
                    result.append(self.gsm_extension[ord(c)])
                else:
                    result.append(self.gsm_basic_charset[ord(c)])
            except IndexError:
                result.append(
                    self.handle_decode_error(c, errors, position, byte_string))

        obj = u''.join(result)
        return (obj, len(obj))

    def handle_decode_error(self, char, handler_type, position, obj):
        handler = getattr(
            self, 'handle_decode_%s_error' % (handler_type,), None)
        if handler is None:
            raise VumiCodecException(
                'Invalid errors type %s for GSM7BitCodec' % (handler_type,))
        return handler(char, position, obj)

    def handle_decode_strict_error(self, char, position, obj):
        # Pass the full input bytestring as the error's object and name the
        # offending byte in the reason.
        raise UnicodeDecodeError(
            'gsm0338', obj, position, position + 1,
            'byte %r not in GSM 03.38' % (char,))

    def handle_decode_ignore_error(self, char, position, obj):
        return u''

    def handle_decode_replace_error(self, char, position, obj):
        return u'?'


class UCS2Codec(codecs.Codec):
    """
    UCS2 is, for all intents and purposes, assumed to be the same as
    big-endian UTF-16.
    """
    def encode(self, input, errors='strict'):
        return codecs.utf_16_be_encode(input, errors)

    def decode(self, input, errors='strict'):
        return codecs.utf_16_be_decode(input, errors)


class VumiCodec(object):
    implements(IVumiCodec)

    custom_codecs = {
        'gsm0338': GSM7BitCodec(),
        'ucs2': UCS2Codec()
    }

    def encode(self, unicode_string, encoding=None, errors='strict'):
        if not isinstance(unicode_string, unicode):
            raise VumiCodecException(
                'Only Unicode strings accepted for encoding.')
        encoding = encoding or sys.getdefaultencoding()
        if encoding in self.custom_codecs:
            encoder = self.custom_codecs[encoding].encode
        else:
            encoder = codecs.getencoder(encoding)
        obj, length = encoder(unicode_string, errors)
        return obj

    def decode(self, byte_string, encoding=None, errors='strict'):
        if not isinstance(byte_string, str):
            raise VumiCodecException(
                'Only bytestrings accepted for decoding.')
        encoding = encoding or sys.getdefaultencoding()
        if encoding in self.custom_codecs:
            decoder = self.custom_codecs[encoding].decode
        else:
            decoder = codecs.getdecoder(encoding)
        obj, length = decoder(byte_string, errors)
        return obj
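
# Illustrative usage sketch (not part of the original module):
#
#     >>> codec = VumiCodec()
#     >>> codec.encode(u"HELLO", "gsm0338")
#     'HELLO'
#     >>> codec.decode('\x00Z\x00o\x00\xeb', "ucs2")
#     u'Zo\xeb'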

vumi/codecs/__init__.py
from vumi.codecs.vumi_codecs import VumiCodec

__all__ = ['VumiCodec']

vumi/codecs/tests/test_vumi_codecs.py
# -*- coding: utf-8 -*-
from vumi.codecs.ivumi_codecs import IVumiCodec
from vumi.codecs.vumi_codecs import VumiCodec, VumiCodecException

from twisted.trial.unittest import TestCase


class TestVumiCodec(TestCase):

    def setUp(self):
        self.codec = VumiCodec()

    def test_implements(self):
        self.assertTrue(IVumiCodec.implementedBy(VumiCodec))
        self.assertTrue(IVumiCodec.providedBy(self.codec))

    def test_unicode_encode_guard(self):
        self.assertRaises(
            VumiCodecException, self.codec.encode, "byte string")

    def test_bytestring_decode_guard(self):
        self.assertRaises(
            VumiCodecException, self.codec.decode, u"unicode")

    def test_default_encoding(self):
        self.assertEqual(self.codec.encode(u"a"), "a")
        self.assertRaises(
            UnicodeEncodeError, self.codec.encode, u"ë")

    def test_default_decoding(self):
        self.assertEqual(self.codec.decode("a"), u"a")
        self.assertRaises(
            UnicodeDecodeError, self.codec.decode, '\xc3\xab')  # e-umlaut

    def test_encode_utf8(self):
        self.assertEqual(self.codec.encode(u"Zoë", "utf-8"), 'Zo\xc3\xab')

    def test_decode_utf8(self):
        self.assertEqual(self.codec.decode('Zo\xc3\xab', "utf-8"), u"Zoë")

    def test_encode_utf16be(self):
        self.assertEqual(
            self.codec.encode(u"Zoë", "utf-16be"), '\x00Z\x00o\x00\xeb')

    def test_decode_utf16be(self):
        self.assertEqual(
            self.codec.decode('\x00Z\x00o\x00\xeb', "utf-16be"), u"Zoë")

    def test_encode_ucs2(self):
        self.assertEqual(
            self.codec.encode(u"Zoë", "ucs2"), '\x00Z\x00o\x00\xeb')

    def test_decode_ucs2(self):
        self.assertEqual(
            self.codec.decode('\x00Z\x00o\x00\xeb', "ucs2"), u"Zoë")

    def test_encode_gsm0338(self):
        self.assertEqual(
            self.codec.encode(u"HÜLK", "gsm0338"),
            ''.join([chr(code) for code in [72, 94, 76, 75]]))

    def test_encode_gsm0338_extended(self):
        self.assertEqual(
            self.codec.encode(u"foo €", "gsm0338"),
            ''.join([chr(code) for code in [102, 111, 111, 32, 27, 101]]))

    def test_decode_gsm0338_extended(self):
        self.assertEqual(
            self.codec.decode(
                ''.join([chr(code) for code in [102, 111, 111, 32, 27, 101]]),
                'gsm0338'),
            u"foo €")

    def test_encode_gsm0338_strict(self):
        self.assertRaises(
            UnicodeEncodeError, self.codec.encode, u'Zoë', 'gsm0338')

    def test_encode_gsm0338_ignore(self):
        self.assertEqual(
            self.codec.encode(u"Zoë", "gsm0338", 'ignore'), 'Zo')

    def test_encode_gsm0338_replace(self):
        self.assertEqual(
            self.codec.encode(u"Zoë", "gsm0338", 'replace'), 'Zo?')

    def test_decode_gsm0338_strict(self):
        self.assertRaises(
            UnicodeDecodeError, self.codec.decode,
            u'Zoë'.encode('utf-8'), 'gsm0338')

    def test_decode_gsm0338_ignore(self):
        self.assertEqual(
            self.codec.decode(
                u'Zoë'.encode('utf-8'), "gsm0338", 'ignore'), u'Zo')

    def test_decode_gsm0338_replace(self):
        self.assertEqual(
            self.codec.decode(
                u'Zoë'.encode('utf-8'), "gsm0338", 'replace'), u'Zo??')

vumi/codecs/tests/__init__.py

vumi/dispatchers/base.py
# -*- test-case-name: vumi.dispatchers.tests.test_base -*-

"""Basic tools for building dispatchers."""

import re
import functools

from twisted.internet.defer import inlineCallbacks, returnValue, maybeDeferred

from vumi.service import Worker
from vumi.errors import ConfigError, DispatcherError
from vumi.message import TransportUserMessage, TransportEvent
from vumi.utils import load_class_by_string, get_first_word
from vumi.middleware import MiddlewareStack, setup_middlewares_from_config
from vumi import log
from vumi.components.session import SessionManager
from vumi.persist.txredis_manager import TxRedisManager


class BaseDispatchWorker(Worker):
    """Base class for a dispatch worker.

    """

    @inlineCallbacks
    def startWorker(self):
        log.msg('Starting a %s dispatcher with config: %s'
                % (self.__class__.__name__, self.config))

        self.amqp_prefetch_count = self.config.get('amqp_prefetch_count', 20)
        yield self.setup_endpoints()
        yield self.setup_middleware()
        yield self.setup_router()
        yield self.setup_transport_publishers()
        yield self.setup_exposed_publishers()
        yield self.setup_transport_consumers()
        yield self.setup_exposed_consumers()

        consumers = (self.exposed_consumer.values() +
                     self.transport_consumer.values() +
                     self.transport_event_consumer.values())
        for consumer in consumers:
            consumer.unpause()

    @inlineCallbacks
    def stopWorker(self):
        yield self.teardown_router()
        yield self.teardown_middleware()

    def setup_endpoints(self):
        self.transport_names = self.config.get('transport_names', [])
        self.exposed_names = self.config.get('exposed_names', [])

    @inlineCallbacks
    def setup_middleware(self):
        middlewares = yield setup_middlewares_from_config(self, self.config)
        self._middlewares = MiddlewareStack(middlewares)

    def teardown_middleware(self):
        return self._middlewares.teardown()

    def setup_router(self):
        router_cls = load_class_by_string(self.config['router_class'])
        self._router = router_cls(self, self.config)
        return maybeDeferred(self._router.setup_routing)

    def teardown_router(self):
        return maybeDeferred(self._router.teardown_routing)

    @inlineCallbacks
    def setup_transport_publishers(self):
        self.transport_publisher = {}
        for transport_name in self.transport_names:
            self.transport_publisher[transport_name] = yield self.publish_to(
                '%s.outbound' % (transport_name,))

    @inlineCallbacks
    def setup_transport_consumers(self):
        self.transport_consumer = {}
        self.transport_event_consumer = {}
        for transport_name in self.transport_names:
            self.transport_consumer[transport_name] = yield self.consume(
                '%s.inbound' % (transport_name,),
                functools.partial(self.dispatch_inbound_message,
                                  transport_name),
                message_class=TransportUserMessage, paused=True,
                prefetch_count=self.amqp_prefetch_count)
        for transport_name in self.transport_names:
            self.transport_event_consumer[transport_name] = yield self.consume(
                '%s.event' % (transport_name,),
                functools.partial(self.dispatch_inbound_event, transport_name),
                message_class=TransportEvent, paused=True,
                prefetch_count=self.amqp_prefetch_count)

    @inlineCallbacks
    def setup_exposed_publishers(self):
        self.exposed_publisher = {}
        self.exposed_event_publisher = {}
        for exposed_name in self.exposed_names:
            self.exposed_publisher[exposed_name] = yield self.publish_to(
                '%s.inbound' % (exposed_name,))
        for exposed_name in self.exposed_names:
            self.exposed_event_publisher[exposed_name] = yield self.publish_to(
                '%s.event' % (exposed_name,))

    @inlineCallbacks
    def setup_exposed_consumers(self):
        self.exposed_consumer = {}
        for exposed_name in self.exposed_names:
            self.exposed_consumer[exposed_name] = yield self.consume(
                '%s.outbound' % (exposed_name,),
                functools.partial(self.dispatch_outbound_message,
                                  exposed_name),
                message_class=TransportUserMessage, paused=True,
                prefetch_count=self.amqp_prefetch_count)

    def dispatch_inbound_message(self, endpoint, msg):
        d = self._middlewares.apply_consume("inbound", msg, endpoint)
        d.addCallback(self._router.dispatch_inbound_message)
        return d

    def dispatch_inbound_event(self, endpoint, msg):
        d = self._middlewares.apply_consume("event", msg, endpoint)
        d.addCallback(self._router.dispatch_inbound_event)
        return d

    def dispatch_outbound_message(self, endpoint, msg):
        d = self._middlewares.apply_consume("outbound", msg, endpoint)
        d.addCallback(self._router.dispatch_outbound_message)
        return d

    def publish_inbound_message(self, endpoint, msg):
        d = self._middlewares.apply_publish("inbound", msg, endpoint)
        d.addCallback(self.exposed_publisher[endpoint].publish_message)
        return d

    def publish_inbound_event(self, endpoint, msg):
        d = self._middlewares.apply_publish("event", msg, endpoint)
        d.addCallback(self.exposed_event_publisher[endpoint].publish_message)
        return d

    def publish_outbound_message(self, endpoint, msg):
        d = self._middlewares.apply_publish("outbound", msg, endpoint)
        d.addCallback(self.transport_publisher[endpoint].publish_message)
        return d


class BaseDispatchRouter(object):
    """Base class for dispatch routing logic.

    This defines the interface for router classes and provides a
    convenient set of common functionality. You need not subclass it,
    but you should not instantiate it directly.

    The :meth:`__init__` method should take exactly the following
    options so that your class can be instantiated from configuration
    in a standard way:

    :param vumi.dispatchers.BaseDispatchWorker dispatcher:
        The dispatcher this routing class is part of.
    :param dict config:
        The configuration options passed to the dispatcher.

    If you are subclassing this class, you should not override
    :meth:`__init__`. Custom setup should be done in
    :meth:`setup_routing` instead.
    """

    def __init__(self, dispatcher, config):
        self.dispatcher = dispatcher
        self.config = config

    def setup_routing(self):
        """Perform setup required for router.

        :rtype: Deferred or None
        :returns: May return a Deferred that is called when setup is
                    complete
        """
        pass

    def teardown_routing(self):
        """Perform teardown required for router.

        :rtype: Deferred or None
        :returns: May return a Deferred that is called when teardown is
                    complete
        """
        pass

    def dispatch_inbound_message(self, msg):
        """Dispatch an inbound user message to a publisher.

        :param vumi.message.TransportUserMessage msg:
            Message to dispatch.
        """
        raise NotImplementedError()

    def dispatch_inbound_event(self, msg):
        """Dispatch an event to a publisher.

        :param vumi.message.TransportEvent msg:
            Message to dispatch.
        """
        raise NotImplementedError()

    def dispatch_outbound_message(self, msg):
        """Dispatch an outbound user message to a publisher.

        :param vumi.message.TransportUserMessage msg:
            Message to dispatch.
        """
        raise NotImplementedError()


class SimpleDispatchRouter(BaseDispatchRouter):
    """Simple dispatch router that maps transports to apps.

    Configuration options:

    :param dict route_mappings:
        A map of *transport_names* to *exposed_names*. Inbound
        messages and events received from a given transport are
        dispatched to the application attached to the corresponding
        exposed name.

    :param dict transport_mappings: An optional re-mapping of
        *transport_names* to *transport_names*.  By default, outbound
        messages are dispatched to the transport attached to the
        *endpoint* with the same name as the transport name given in
        the message. If a transport name is present in this
        dictionary, the message is instead dispatched to the new
        transport name given by the re-mapping.
    """

    def dispatch_inbound_message(self, msg):
        names = self.config['route_mappings'][msg['transport_name']]
        for name in names:
            # copy message so that the middleware doesn't see a particular
            # message instance multiple times
            self.dispatcher.publish_inbound_message(name, msg.copy())

    def dispatch_inbound_event(self, msg):
        names = self.config['route_mappings'][msg['transport_name']]
        for name in names:
            # copy message so that the middleware doesn't see a particular
            # message instance multiple times
            self.dispatcher.publish_inbound_event(name, msg.copy())

    def dispatch_outbound_message(self, msg):
        name = msg['transport_name']
        name = self.config.get('transport_mappings', {}).get(name, name)
        if name in self.dispatcher.transport_publisher:
            self.dispatcher.publish_outbound_message(name, msg)
        else:
            log.error(DispatcherError(
                'Unknown transport_name: %s, discarding %r' % (
                    name, msg.payload)))
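
# Illustrative SimpleDispatchRouter configuration sketch (names are
# hypothetical):
#
#     router_class: vumi.dispatchers.SimpleDispatchRouter
#     transport_names: [sms_transport]
#     exposed_names: [my_app]
#     route_mappings:
#       sms_transport: [my_app]
#     transport_mappings:
#       upstream_sms: sms_transport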


class TransportToTransportRouter(BaseDispatchRouter):
    """Simple dispatch router that connects transports to other
    transports.

    .. note::

       Connecting transports to one another results in event messages
       being discarded, since transports cannot receive events. Outbound
       messages never need to be dispatched because transports only
       send inbound messages.

    Configuration options:

    :param dict route_mappings:
        A map of *transport_names* to *transport_names*. Inbound
        messages received from a transport are sent as outbound
        messages to the associated transport.
    """

    def dispatch_inbound_message(self, msg):
        names = self.config['route_mappings'][msg['transport_name']]
        for name in names:
            self.dispatcher.publish_outbound_message(name, msg.copy())

    def dispatch_inbound_event(self, msg):
        """
        Explicitly throw away events, because transports can't receive them.
        """
        pass

    def dispatch_outbound_message(self, msg):
        """
        If we're only hooking transports up to each other, there are no
        outbound messages.
        """
        pass


class ToAddrRouter(SimpleDispatchRouter):
    """Router that dispatches based on msg to_addr.

    :type toaddr_mappings: dict
    :param toaddr_mappings:
        Mapping from application transport names to regular
        expressions. If a message's to_addr matches the given
        regular expression the message is sent to the applications
        listening on the given transport name.
    """

    def setup_routing(self):
        self.mappings = []
        for name, toaddr_pattern in self.config['toaddr_mappings'].items():
            self.mappings.append((name, re.compile(toaddr_pattern)))
            # TODO: assert that name is in list of publishers.

    def dispatch_inbound_message(self, msg):
        toaddr = msg['to_addr']
        for name, regex in self.mappings:
            if regex.match(toaddr):
                # copy message so that the middleware doesn't see a particular
                # message instance multiple times
                self.dispatcher.publish_inbound_message(name, msg.copy())

    def dispatch_inbound_event(self, msg):
        pass
        # TODO:
        #   Use msg['user_message_id'] to look up where original message
        #   was dispatched to and dispatch this message there
        #   Perhaps there should be a message on the base class to support
        #   this.
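
# Illustrative ToAddrRouter configuration sketch (names and patterns are
# hypothetical); each application name maps to a regular expression that
# is matched against the message's to_addr:
#
#     router_class: vumi.dispatchers.ToAddrRouter
#     toaddr_mappings:
#       app1: '^1234$'
#       app2: '^555'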


class FromAddrMultiplexRouter(BaseDispatchRouter):
    """Router that multiplexes multiple transports based on msg from_addr.

    This router is intended to be used to multiplex a pool of transports that
    each only supports a single external address, and present them to
    applications (or downstream dispatchers) as a single transport that
    supports multiple external addresses. This is useful for multiplexing
    :class:`vumi.transports.xmpp.XMPPTransport` instances, for example.

    .. note::

       This router rewrites `transport_name` in both directions. Also, only
       one exposed name is supported.

    Configuration options:

    :param dict fromaddr_mappings:
        Mapping from message `from_addr` to `transport_name`.
    """

    def setup_routing(self):
        if len(self.dispatcher.exposed_names) != 1:
            raise ConfigError("Only one exposed name allowed for %s." % (
                    type(self).__name__,))
        [self.exposed_name] = self.dispatcher.exposed_names

    def dispatch_inbound_message(self, msg):
        msg['transport_name'] = self.exposed_name
        self.dispatcher.publish_inbound_message(self.exposed_name, msg)

    def dispatch_inbound_event(self, msg):
        msg['transport_name'] = self.exposed_name
        self.dispatcher.publish_inbound_event(self.exposed_name, msg)

    def dispatch_outbound_message(self, msg):
        name = self.config['fromaddr_mappings'][msg['from_addr']]
        msg['transport_name'] = name
        self.dispatcher.publish_outbound_message(name, msg)
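
# Illustrative FromAddrMultiplexRouter configuration sketch (addresses and
# names are hypothetical):
#
#     router_class: vumi.dispatchers.FromAddrMultiplexRouter
#     exposed_names: [muxed_xmpp]
#     fromaddr_mappings:
#       bot1@example.com: xmpp_transport_1
#       bot2@example.com: xmpp_transport_2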


class UserGroupingRouter(SimpleDispatchRouter):
    """
    Router that dispatches based on msg `from_addr`. Each unique
    `from_addr` is round-robin assigned to one of the defined
    groups in `group_mappings`. All messages from that
    `from_addr` are then routed to the `app` assigned to that group.

    Useful for A/B testing.

    Configuration options:

    :param dict group_mappings:
        Mapping of group names to transport_names.
        If a user is assigned to a given group the
        message is sent to the application listening
        on the given transport_name.

    :param str dispatcher_name:
        The name of the dispatcher, used internally as
        the prefix for Redis keys.
    """

    def setup_routing(self):
        r_config = self.config.get('redis_manager', {})
        r_prefix = self.config['dispatcher_name']
        # FIXME: The following is a hack to deal with sync-only setup.
        self._redis_d = TxRedisManager.from_config(r_config)
        self._redis_d.addCallback(lambda m: m.sub_manager(r_prefix))
        self._redis_d.addCallback(self._setup_redis)

        self.groups = self.config['group_mappings']
        self.nr_of_groups = len(self.groups)

    def _setup_redis(self, redis):
        self.redis = redis

    @inlineCallbacks
    def get_next_group(self):
        counter = (yield self.redis.incr('round-robin')) - 1
        current_group_id = counter % self.nr_of_groups
        sorted_groups = sorted(self.groups.items())
        group = sorted_groups[current_group_id]
        returnValue(group)

    @inlineCallbacks
    def get_group_for_user(self, user_id):
        user_key = "user:%s" % (user_id,)
        group = yield self.redis.get(user_key)
        if not group:
            group, transport_name = yield self.get_next_group()
            yield self.redis.set(user_key, group)
        returnValue(group)

    @inlineCallbacks
    def dispatch_inbound_message(self, msg):
        yield self._redis_d  # Horrible hack to ensure we have it setup.
        group = yield self.get_group_for_user(msg.user().encode('utf8'))
        app = self.groups[group]
        self.dispatcher.publish_inbound_message(app, msg)
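
# Illustrative UserGroupingRouter configuration sketch (names are
# hypothetical); users are round-robin assigned to group_a or group_b:
#
#     router_class: vumi.dispatchers.UserGroupingRouter
#     dispatcher_name: ab_test_dispatcher
#     group_mappings:
#       group_a: app_a
#       group_b: app_b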


class ContentKeywordRouter(SimpleDispatchRouter):
    """Router that dispatches based on the first word of the message
    content. In the context of SMSes the first word is sometimes called
    the 'keyword'.

    :param dict keyword_mappings:
        Mapping from application transport names to simple keywords.
        This is purely a convenience for constructing simple routing
        rules. The rules generated from this option are appended to
        the list of rules supplied via the *rules* option.

    :param list rules:
        A list of routing rules. A routing rule is a dictionary. It
        must have `app` and `keyword` keys and may contain `to_addr`
        and `prefix` keys. If a message's first word matches a given
        keyword, the message is sent to the application listening on
        the transport name given by the value of `app`. If a 'to_addr'
        key is supplied, the message `to_addr` must also match the
        value of the 'to_addr' key. If a 'prefix' is supplied, the
        message `from_addr` must *start with* the value of the
        'prefix' key.

    :param str fallback_application:
        Optional application transport name to forward inbound messages
        that match no rule to. If omitted, unrouted inbound messages
        are just logged.

    :param dict transport_mappings:
        Mapping from message `from_addr` values to transports names.
        If a message's from_addr matches a given from_addr, the
        message is sent to the associated transport.

    :param int expire_routing_memory:
        Time in seconds before outbound message ids are expired from
        the redis routing store. Outbound message ids are stored along
        with the transport_name the message came in on and are used to
        route events such as acknowledgements and delivery reports
        back to the application that sent the outgoing
        message. Default is seven days.
    """

    DEFAULT_ROUTING_TIMEOUT = 60 * 60 * 24 * 7  # 7 days

    def setup_routing(self):
        self.r_config = self.config.get('redis_manager', {})
        self.r_prefix = self.config['dispatcher_name']

        self.rules = []
        for rule in self.config.get('rules', []):
            if 'keyword' not in rule or 'app' not in rule:
                raise ConfigError("Rule definition %r must contain values for"
                                  " both 'app' and 'keyword'" % rule)
            rule = rule.copy()
            rule['keyword'] = rule['keyword'].lower()
            self.rules.append(rule)
        keyword_mappings = self.config.get('keyword_mappings', {})
        for transport_name, keyword in keyword_mappings.items():
            self.rules.append({'app': transport_name,
                               'keyword': keyword.lower()})
        self.fallback_application = self.config.get('fallback_application')
        self.transport_mappings = self.config['transport_mappings']
        self.expire_routing_timeout = int(self.config.get(
            'expire_routing_memory', self.DEFAULT_ROUTING_TIMEOUT))

        # FIXME: The following is a hack to deal with sync-only setup.
        self._redis_d = TxRedisManager.from_config(self.r_config)
        self._redis_d.addCallback(lambda m: m.sub_manager(self.r_prefix))
        self._redis_d.addCallback(self._setup_redis)

    def _setup_redis(self, redis):
        self.redis = redis
        self.session_manager = SessionManager(
            self.redis, self.expire_routing_timeout)

    def get_message_key(self, message):
        return 'message:%s' % (message,)

    def publish_transport(self, name, msg):
        self.dispatcher.publish_outbound_message(name, msg)

    def publish_exposed_inbound(self, name, msg):
        self.dispatcher.publish_inbound_message(name, msg)

    def publish_exposed_event(self, name, msg):
        self.dispatcher.publish_inbound_event(name, msg)

    def is_msg_matching_routing_rules(self, keyword, msg, rule):
        return all([keyword == rule['keyword'],
                    ('to_addr' not in rule) or
                    (msg['to_addr'] == rule['to_addr']),
                    ('prefix' not in rule) or
                    (msg['from_addr'].startswith(rule['prefix']))])

    def dispatch_inbound_message(self, msg):
        keyword = get_first_word(msg['content']).lower()
        matched = False
        for rule in self.rules:
            if self.is_msg_matching_routing_rules(keyword, msg, rule):
                matched = True
                # copy message so that the middleware doesn't see a particular
                # message instance multiple times
                self.publish_exposed_inbound(rule['app'], msg.copy())
        if not matched:
            if self.fallback_application is not None:
                self.publish_exposed_inbound(self.fallback_application, msg)
            else:
                log.error(DispatcherError(
                    'Message could not be routed: %r' % (msg,)))

    @inlineCallbacks
    def dispatch_inbound_event(self, msg):
        yield self._redis_d  # Horrible hack to ensure we have it setup.
        message_key = self.get_message_key(msg['user_message_id'])
        session = yield self.session_manager.load_session(message_key)
        name = session.get('name')
        if not name:
            log.error(DispatcherError(
                "No transport_name for return route found in Redis"
                " while dispatching transport event for message %s"
                % (msg['user_message_id'],)))
        try:
            self.publish_exposed_event(name, msg)
        except:
            log.error(DispatcherError("No publishing route for %s" % (name,)))

    @inlineCallbacks
    def dispatch_outbound_message(self, msg):
        yield self._redis_d  # Horrible hack to ensure we have it setup.
        transport_name = self.transport_mappings.get(msg['from_addr'])
        if transport_name is not None:
            self.publish_transport(transport_name, msg)
            message_key = self.get_message_key(msg['message_id'])
            yield self.session_manager.create_session(
                message_key, name=msg['transport_name'])
        else:
            log.error(DispatcherError(
                "No transport for %s" % (msg['from_addr'],)))


class RedirectRouter(BaseDispatchRouter):
    """Router that dispatches outbound messages to a different transport.

    :param dict redirect_outbound:
        A dictionary where the key is the name of an exposed_name and
        the value is the name of a transport_name.
    :param dict redirect_inbound:
        A dictionary where the key is the value of a transport_name and
        the value is the value of an exposed_name.
    """

    def setup_routing(self):
        self.outbound_mappings = self.config.get('redirect_outbound', {})
        self.inbound_mappings = self.config.get('redirect_inbound', {})

    def _dispatch_inbound(self, publish_function, vumi_message):
        transport_name = vumi_message['transport_name']
        redirect_to = self.inbound_mappings.get(transport_name)
        if not redirect_to:
            raise ConfigError(
                "No exposed name available for %s's inbound message: %s" % (
                transport_name, vumi_message))

        msg_copy = vumi_message.copy()
        msg_copy['transport_name'] = redirect_to
        publish_function(redirect_to, msg_copy)

    def dispatch_inbound_event(self, event):
        self._dispatch_inbound(self.dispatcher.publish_inbound_event, event)

    def dispatch_inbound_message(self, msg):
        self._dispatch_inbound(self.dispatcher.publish_inbound_message, msg)

    def dispatch_outbound_message(self, msg):
        transport_name = msg['transport_name']
        redirect_to = self.outbound_mappings.get(transport_name)
        if redirect_to:
            self.dispatcher.publish_outbound_message(redirect_to, msg)
        else:
            log.error(DispatcherError(
                'No redirect_outbound specified for %s' % (
                    transport_name,)))
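
# Illustrative RedirectRouter configuration sketch (names are hypothetical):
#
#     router_class: vumi.dispatchers.RedirectRouter
#     redirect_inbound:
#       transport1: app1
#     redirect_outbound:
#       app1: transport1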


class RedirectOutboundRouter(RedirectRouter):
    """
    Deprecated in favour of `RedirectRouter`.

    RedirectRouter provides the same features while also allowing
    inbound redirection to take place, which `RedirectOutboundRouter`
    conveniently ignores.
    """
    def setup_routing(self, *args, **kwargs):
        log.warning('RedirectOutboundRouter is deprecated, please use '
                    '`RedirectRouter` instead.')
        return super(RedirectOutboundRouter, self).setup_routing(
            *args, **kwargs)

vumi/dispatchers/endpoint_dispatchers.py
# -*- test-case-name: vumi.dispatchers.tests.test_endpoint_dispatchers -*-

"""Basic tools for building dispatchers."""

from twisted.internet.defer import gatherResults, maybeDeferred

from vumi.worker import BaseWorker
from vumi.config import ConfigDict, ConfigList
from vumi import log


class DispatcherConfig(BaseWorker.CONFIG_CLASS):
    receive_inbound_connectors = ConfigList(
        "List of connectors that will receive inbound messages and events.",
        required=True, static=True)
    receive_outbound_connectors = ConfigList(
        "List of connectors that will receive outbound messages.",
        required=True, static=True)


class Dispatcher(BaseWorker):
    """Base class for a dispatcher."""

    CONFIG_CLASS = DispatcherConfig

    def setup_worker(self):
        d = maybeDeferred(self.setup_dispatcher)
        d.addCallback(lambda r: self.unpause_connectors())
        return d

    def teardown_worker(self):
        d = self.pause_connectors()
        d.addCallback(lambda r: self.teardown_dispatcher())
        return d

    def setup_dispatcher(self):
        """
        All dispatcher specific setup should happen in here.

        Subclasses should override this method to perform extra setup.
        """
        pass

    def teardown_dispatcher(self):
        """
        Clean-up of setup done in setup_dispatcher should happen here.
        """
        pass

    def get_configured_ri_connectors(self):
        return self.get_static_config().receive_inbound_connectors

    def get_configured_ro_connectors(self):
        return self.get_static_config().receive_outbound_connectors

    def default_errback(self, f, msg, connector_name):
        log.error(f, "Error routing message for %s" % (connector_name,))

    def process_inbound(self, config, msg, connector_name):
        raise NotImplementedError()

    def errback_inbound(self, f, msg, connector_name):
        return f

    def process_outbound(self, config, msg, connector_name):
        raise NotImplementedError()

    def errback_outbound(self, f, msg, connector_name):
        return f

    def process_event(self, config, event, connector_name):
        raise NotImplementedError()

    def errback_event(self, f, event, connector_name):
        return f

    def _mkhandler(self, handler_func, errback_func, connector_name):
        def handler(msg):
            d = maybeDeferred(self.get_config, msg)
            d.addCallback(handler_func, msg, connector_name)
            d.addErrback(errback_func, msg, connector_name)
            d.addErrback(self.default_errback, msg, connector_name)
            return d
        return handler

    def setup_connectors(self):
        def add_ri_handlers(connector, connector_name):
            connector.set_default_inbound_handler(
                self._mkhandler(
                    self.process_inbound, self.errback_inbound,
                    connector_name))
            connector.set_default_event_handler(
                self._mkhandler(
                    self.process_event, self.errback_event, connector_name))
            return connector

        def add_ro_handlers(connector, connector_name):
            connector.set_default_outbound_handler(
                self._mkhandler(
                    self.process_outbound, self.errback_outbound,
                    connector_name))
            return connector

        deferreds = []
        for connector_name in self.get_configured_ri_connectors():
            d = self.setup_ri_connector(connector_name)
            d.addCallback(add_ri_handlers, connector_name)
            deferreds.append(d)

        for connector_name in self.get_configured_ro_connectors():
            d = self.setup_ro_connector(connector_name)
            d.addCallback(add_ro_handlers, connector_name)
            deferreds.append(d)

        return gatherResults(deferreds)

    def publish_inbound(self, msg, connector_name, endpoint):
        return self.connectors[connector_name].publish_inbound(msg, endpoint)

    def publish_outbound(self, msg, connector_name, endpoint):
        return self.connectors[connector_name].publish_outbound(msg, endpoint)

    def publish_event(self, event, connector_name, endpoint):
        return self.connectors[connector_name].publish_event(event, endpoint)
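
# Minimal sketch of a Dispatcher subclass (illustrative only; the connector
# names 'transport1' and 'app1' are hypothetical):
#
#     class PassThroughDispatcher(Dispatcher):
#         def process_inbound(self, config, msg, connector_name):
#             return self.publish_inbound(msg, 'app1', 'default')
#
#         def process_event(self, config, event, connector_name):
#             return self.publish_event(event, 'app1', 'default')
#
#         def process_outbound(self, config, msg, connector_name):
#             return self.publish_outbound(msg, 'transport1', 'default')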


class RoutingTableDispatcherConfig(Dispatcher.CONFIG_CLASS):
    routing_table = ConfigDict(
        "Routing table. Keys are connector names, values are dicts mapping "
        "endpoint names to [connector, endpoint] pairs.", required=True)


class RoutingTableDispatcher(Dispatcher):
    CONFIG_CLASS = RoutingTableDispatcherConfig

    def find_target(self, config, msg, connector_name):
        endpoint_name = msg.get_routing_endpoint()
        endpoint_routing = config.routing_table.get(connector_name)
        if endpoint_routing is None:
            log.warning("No routing information for connector '%s'" % (
                    connector_name,))
            return None
        target = endpoint_routing.get(endpoint_name)
        if target is None:
            log.warning("No routing information for endpoint '%s' on '%s'" % (
                    endpoint_name, connector_name,))
            return None
        return target

    def process_inbound(self, config, msg, connector_name):
        target = self.find_target(config, msg, connector_name)
        if target is None:
            return
        return self.publish_inbound(msg, target[0], target[1])

    def process_outbound(self, config, msg, connector_name):
        target = self.find_target(config, msg, connector_name)
        if target is None:
            return
        return self.publish_outbound(msg, target[0], target[1])

    def process_event(self, config, event, connector_name):
        target = self.find_target(config, event, connector_name)
        if target is None:
            return
        return self.publish_event(event, target[0], target[1])
# ---- vumi/dispatchers/__init__.py ----
"""The vumi.dispatchers API."""

__all__ = ["BaseDispatchWorker", "BaseDispatchRouter", "SimpleDispatchRouter",
           "TransportToTransportRouter", "ToAddrRouter",
           "FromAddrMultiplexRouter", "UserGroupingRouter",
           "ContentKeywordRouter"]

from vumi.dispatchers.base import (BaseDispatchWorker, BaseDispatchRouter,
                                   SimpleDispatchRouter,
                                   TransportToTransportRouter, ToAddrRouter,
                                   FromAddrMultiplexRouter,
                                   UserGroupingRouter, ContentKeywordRouter)
# ---- vumi/dispatchers/load_balancer.py ----
# -*- test-case-name: vumi.dispatchers.tests.test_load_balancer -*-

"""Router for round-robin load balancing between two transports."""

import itertools

from vumi import log
from vumi.errors import ConfigError
from vumi.dispatchers.base import BaseDispatchRouter


class LoadBalancingRouter(BaseDispatchRouter):
    """Router that does round-robin dispatching to transports.

    Supports only one exposed name and requires at least one transport
    name.

    Configuration options:

    :param bool reply_affinity:
        If set to true, replies are sent back to the same transport
        they were sent from. If false, replies are round-robinned in
        the same way other outbound messages are. Default: true.
    :param bool rewrite_transport_names:
        If set to true, rewrites each message's `transport_name` in both
        directions. Default: true.
    """

    def setup_routing(self):
        self.reply_affinity = self.config.get('reply_affinity', True)
        self.rewrite_transport_names = self.config.get(
            'rewrite_transport_names', True)
        if len(self.dispatcher.exposed_names) != 1:
            raise ConfigError("Only one exposed name allowed for %s." %
                              (type(self).__name__,))
        [self.exposed_name] = self.dispatcher.exposed_names
        if not self.dispatcher.transport_names:
            raise ConfigError("At least one transport name is needed for %s" %
                              (type(self).__name__,))
        self.transport_name_cycle = itertools.cycle(
            self.dispatcher.transport_names)
        self.transport_name_set = set(self.dispatcher.transport_names)

    def push_transport_name(self, msg, transport_name):
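        # Maintain a stack of transport names in the message's
        # helper_metadata so a reply can later be routed back to the
        # transport the original message arrived on.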
        hm = msg['helper_metadata']
        lm = hm.setdefault('load_balancer', {})
        transport_names = lm.setdefault('transport_names', [])
        transport_names.append(transport_name)

    def pop_transport_name(self, msg):
        hm = msg['helper_metadata']
        lm = hm.get('load_balancer', {})
        transport_names = lm.get('transport_names', [])
        if not transport_names:
            return None
        return transport_names.pop()

    def dispatch_inbound_message(self, msg):
        if self.reply_affinity:
            # TODO: we should really be pushing the endpoint name
            #       but it isn't available here
            self.push_transport_name(msg, msg['transport_name'])
        if self.rewrite_transport_names:
            msg['transport_name'] = self.exposed_name
        self.dispatcher.publish_inbound_message(self.exposed_name, msg)

    def dispatch_inbound_event(self, msg):
        if self.rewrite_transport_names:
            msg['transport_name'] = self.exposed_name
        self.dispatcher.publish_inbound_event(self.exposed_name, msg)

    def dispatch_outbound_message(self, msg):
        if self.reply_affinity and msg['in_reply_to']:
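            # Pop the transport this reply's conversation arrived on; if
            # it is unknown or no longer configured, fall back to
            # round-robin.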
            transport_name = self.pop_transport_name(msg)
            if transport_name not in self.transport_name_set:
                log.warning("LoadBalancer is configured for reply affinity but"
                            " reply for unknown load balancer endpoint %r was"
                            " was received. Using round-robin routing instead."
                            % (transport_name,))
                transport_name = self.transport_name_cycle.next()
        else:
            transport_name = self.transport_name_cycle.next()
        if self.rewrite_transport_names:
            msg['transport_name'] = transport_name
        self.dispatcher.publish_outbound_message(transport_name, msg)
# ---- vumi/dispatchers/tests/utils.py ----
from twisted.internet.defer import inlineCallbacks

from vumi.tests.utils import VumiWorkerTestCase, PersistenceMixin

# For backcompat
from .helpers import DummyDispatcher
DummyDispatcher  # To keep pyflakes happy.


class DispatcherTestCase(VumiWorkerTestCase, PersistenceMixin):

    """
    This is a base class for testing dispatcher workers.

    """

    transport_name = None
    dispatcher_name = "sphex_dispatcher"
    dispatcher_class = None

    def setUp(self):
        self._persist_setUp()
        super(DispatcherTestCase, self).setUp()

    @inlineCallbacks
    def tearDown(self):
        yield super(DispatcherTestCase, self).tearDown()
        yield self._persist_tearDown()

    def get_dispatcher(self, config, cls=None, start=True):
        """
        Get an instance of a dispatcher class.

        :param config: Config dict.
        :param cls: The Dispatcher class to instantiate.
                    Defaults to :attr:`dispatcher_class`
        :param start: True to start the dispatcher (default), False otherwise.

        Some default config values are helpfully provided in the
        interests of reducing boilerplate:

        * ``dispatcher_name`` defaults to :attr:`self.dispatcher_name`
        """

        if cls is None:
            cls = self.dispatcher_class
        config = self.mk_config(config)
        config.setdefault('dispatcher_name', self.dispatcher_name)
        return self.get_worker(config, cls, start)

    def rkey(self, name):
        # We don't want the default behaviour for dispatchers.
        return name

    def get_dispatched_messages(self, transport_name, direction='outbound'):
        return self._get_dispatched(
            '%s.%s' % (transport_name, direction))

    def wait_for_dispatched_messages(self, transport_name, amount,
                                     direction='outbound'):
        return self._wait_for_dispatched(
            '%s.%s' % (transport_name, direction), amount)

    def dispatch(self, message, transport_name, direction='inbound',
                 exchange='vumi'):
        return self._dispatch(
            message, '%s.%s' % (transport_name, direction), exchange)
# ---- vumi/dispatchers/tests/test_base.py ----
from twisted.internet.defer import inlineCallbacks, returnValue

from vumi.dispatchers.base import (
    BaseDispatchWorker, ToAddrRouter, FromAddrMultiplexRouter)
from vumi.dispatchers.tests.helpers import DispatcherHelper, DummyDispatcher
from vumi.errors import DispatcherError
from vumi.tests.utils import LogCatcher
from vumi.tests.helpers import VumiTestCase, MessageHelper


class TestBaseDispatchWorker(VumiTestCase):
    def setUp(self):
        self.disp_helper = self.add_helper(
            DispatcherHelper(BaseDispatchWorker))

    def get_dispatcher(self, **config_extras):
        config = {
            "transport_names": [
                "transport1",
                "transport2",
                "transport3",
                ],
            "exposed_names": [
                "app1",
                "app2",
                "app3",
                ],
            "router_class": "vumi.dispatchers.base.SimpleDispatchRouter",
            "route_mappings": {
                "transport1": ["app1"],
                "transport2": ["app2"],
                "transport3": ["app1", "app3"]
                },
            "middleware": [
                {"mw1": "vumi.middleware.tests.utils.RecordingMiddleware"},
                {"mw2": "vumi.middleware.tests.utils.RecordingMiddleware"},
                ],
            }
        config.update(config_extras)
        return self.disp_helper.get_dispatcher(config)

    def ch(self, connector_name):
        return self.disp_helper.get_connector_helper(connector_name)

    def mk_middleware_records(self, rkey_in, rkey_out):
        records = []
        for rkey, direction in [(rkey_in, False), (rkey_out, True)]:
            endpoint, method = rkey.split('.', 1)
            mw = [[name, method, endpoint] for name in ("mw1", "mw2")]
            if direction:
                mw.reverse()
            records.extend(mw)
        return records

    def assert_inbound(self, dst_conn, src_conn, msg):
        [dst_msg] = self.disp_helper.get_dispatched_inbound(dst_conn)
        middleware_records = self.mk_middleware_records(
            src_conn + '.inbound', dst_conn + '.inbound')
        self.assertEqual(dst_msg.payload.pop('record'), middleware_records)
        self.assertEqual(msg, dst_msg)

    def assert_event(self, dst_conn, src_conn, msg):
        [dst_msg] = self.disp_helper.get_dispatched_events(dst_conn)
        middleware_records = self.mk_middleware_records(
            src_conn + '.event', dst_conn + '.event')
        self.assertEqual(dst_msg.payload.pop('record'), middleware_records)
        self.assertEqual(msg, dst_msg)

    def assert_outbound(self, dst_conn, src_conn_msg_pairs):
        dst_msgs = self.disp_helper.get_dispatched_outbound(dst_conn)
        for src_conn, msg in src_conn_msg_pairs:
            dst_msg = dst_msgs.pop(0)
            middleware_records = self.mk_middleware_records(
                src_conn + '.outbound', dst_conn + '.outbound')
            self.assertEqual(dst_msg.payload.pop('record'), middleware_records)
            self.assertEqual(msg, dst_msg)
        self.assertEqual([], dst_msgs)

    def assert_no_inbound(self, *conns):
        for conn in conns:
            self.assertEqual([], self.disp_helper.get_dispatched_inbound(conn))

    def assert_no_outbound(self, *conns):
        for conn in conns:
            self.assertEqual(
                [], self.disp_helper.get_dispatched_outbound(conn))

    def assert_no_events(self, *conns):
        for conn in conns:
            self.assertEqual([], self.disp_helper.get_dispatched_events(conn))

    @inlineCallbacks
    def test_inbound_message_routing(self):
        yield self.get_dispatcher()
        msg = yield self.ch('transport1').make_dispatch_inbound(
            "foo", transport_name='transport1')
        self.assert_inbound('app1', 'transport1', msg)
        self.assert_no_inbound('app2', 'app3')
        self.assert_no_events('app1', 'app2', 'app3')

        self.disp_helper.clear_all_dispatched()
        msg = yield self.ch('transport2').make_dispatch_inbound(
            "foo", transport_name='transport2')
        self.assert_inbound('app2', 'transport2', msg)
        self.assert_no_inbound('app1', 'app3')
        self.assert_no_events('app1', 'app2', 'app3')

        self.disp_helper.clear_all_dispatched()
        msg = yield self.ch('transport3').make_dispatch_inbound(
            "foo", transport_name='transport3')
        self.assert_inbound('app1', 'transport3', msg)
        self.assert_inbound('app3', 'transport3', msg)
        self.assert_no_inbound('app2')
        self.assert_no_events('app1', 'app2', 'app3')

    @inlineCallbacks
    def test_inbound_ack_routing(self):
        yield self.get_dispatcher()
        msg = yield self.ch('transport1').make_dispatch_ack(
            transport_name='transport1')
        self.assert_event('app1', 'transport1', msg)
        self.assert_no_inbound('app1', 'app2', 'app3')
        self.assert_no_events('app2', 'app3')

        self.disp_helper.clear_all_dispatched()
        msg = yield self.ch('transport2').make_dispatch_ack(
            transport_name='transport2')
        self.assert_event('app2', 'transport2', msg)
        self.assert_no_inbound('app1', 'app2', 'app3')
        self.assert_no_events('app1', 'app3')

        self.disp_helper.clear_all_dispatched()
        msg = yield self.ch('transport3').make_dispatch_ack(
            transport_name='transport3')
        self.assert_event('app1', 'transport3', msg)
        self.assert_event('app3', 'transport3', msg)
        self.assert_no_inbound('app1', 'app2', 'app3')
        self.assert_no_events('app2')

    @inlineCallbacks
    def test_outbound_message_routing(self):
        yield self.get_dispatcher()

        @inlineCallbacks
        def dispatch_for_transport(transport):
            app_msg_pairs = []
            for app in ['app1', 'app2', 'app3']:
                msg = yield self.ch(app).make_dispatch_outbound(
                    app, transport_name=transport)
                app_msg_pairs.append((app, msg))
            returnValue(app_msg_pairs)

        app_msg_pairs = yield dispatch_for_transport('transport1')
        self.assert_outbound('transport1', app_msg_pairs)
        self.assert_no_outbound('transport2', 'transport3')

        self.disp_helper.clear_all_dispatched()
        app_msg_pairs = yield dispatch_for_transport('transport2')
        self.assert_outbound('transport2', app_msg_pairs)
        self.assert_no_outbound('transport1', 'transport3')

        self.disp_helper.clear_all_dispatched()
        app_msg_pairs = yield dispatch_for_transport('transport3')
        self.assert_outbound('transport3', app_msg_pairs)
        self.assert_no_outbound('transport1', 'transport2')

    @inlineCallbacks
    def test_unroutable_outbound_error(self):
        dispatcher = yield self.get_dispatcher()
        router = dispatcher._router
        msg = self.disp_helper.make_outbound("out", transport_name='foo')
        with LogCatcher() as log:
            yield router.dispatch_outbound_message(msg)
            [error] = log.errors
            self.assertTrue('Unknown transport_name: foo' in
                            str(error['failure'].value))
        [f] = self.flushLoggedErrors(DispatcherError)
        self.assertEqual(f, error['failure'])

    @inlineCallbacks
    def test_outbound_message_routing_transport_mapping(self):
        """
        Test that transport mappings are applied for outbound messages.
        """
        yield self.get_dispatcher(
            transport_mappings={'upstream1': 'transport1'},
            transport_names=[
                'transport1',
                'transport2',
                'transport3',
                'upstream1',
            ])

        @inlineCallbacks
        def dispatch_for_transport(transport):
            app_msg_pairs = []
            for app in ['app1', 'app2', 'app3']:
                msg = yield self.ch(app).make_dispatch_outbound(
                    app, transport_name=transport)
                app_msg_pairs.append((app, msg))
            returnValue(app_msg_pairs)

        app_msg_pairs = yield dispatch_for_transport('upstream1')
        self.assert_outbound('transport1', app_msg_pairs)
        self.assert_no_outbound('transport2', 'transport3', 'upstream1')

        self.disp_helper.clear_all_dispatched()
        app_msg_pairs = yield dispatch_for_transport('transport2')
        self.assert_outbound('transport2', app_msg_pairs)
        self.assert_no_outbound('transport1', 'transport3', 'upstream1')

    def get_dispatcher_consumers(self, dispatcher):
        return (dispatcher.transport_consumer.values() +
                dispatcher.transport_event_consumer.values() +
                dispatcher.exposed_consumer.values())

    @inlineCallbacks
    def test_consumer_prefetch_count_default(self):
        dp = yield self.get_dispatcher()
        consumers = self.get_dispatcher_consumers(dp)
        for consumer in consumers:
            fake_channel = consumer.channel._fake_channel
            self.assertEqual(fake_channel.qos_prefetch_count, 20)

    @inlineCallbacks
    def test_consumer_prefetch_count_custom(self):
        dp = yield self.get_dispatcher(amqp_prefetch_count=10)
        consumers = self.get_dispatcher_consumers(dp)
        for consumer in consumers:
            fake_channel = consumer.channel._fake_channel
            self.assertEqual(fake_channel.qos_prefetch_count, 10)

    @inlineCallbacks
    def test_consumer_prefetch_count_none(self):
        dp = yield self.get_dispatcher(amqp_prefetch_count=None)
        consumers = self.get_dispatcher_consumers(dp)
        for consumer in consumers:
            fake_channel = consumer.channel._fake_channel
            self.assertEqual(fake_channel.qos_prefetch_count, 0)


class TestToAddrRouter(VumiTestCase):

    @inlineCallbacks
    def setUp(self):
        self.config = {
            'transport_names': ['transport1'],
            'exposed_names': ['app1', 'app2'],
            'toaddr_mappings': {
                'app1': 'to:.*:1',
                'app2': 'to:app2',
                },
            }
        self.dispatcher = DummyDispatcher(self.config)
        self.router = ToAddrRouter(self.dispatcher, self.config)
        yield self.router.setup_routing()
        self.msg_helper = self.add_helper(MessageHelper())

    def test_dispatch_inbound_message(self):
        msg = self.msg_helper.make_inbound(
            "1", to_addr='to:foo:1', transport_name='transport1')
        self.router.dispatch_inbound_message(msg)
        publishers = self.dispatcher.exposed_publisher
        self.assertEqual(publishers['app1'].msgs, [msg])
        self.assertEqual(publishers['app2'].msgs, [])

    def test_dispatch_outbound_message(self):
        msg = self.msg_helper.make_outbound("out", transport_name='transport1')
        self.router.dispatch_outbound_message(msg)
        publishers = self.dispatcher.transport_publisher
        self.assertEqual(publishers['transport1'].msgs, [msg])

        self.dispatcher.transport_publisher['transport1'].clear()
        self.config['transport_mappings'] = {
            'upstream1': 'transport1',
            }

        msg = self.msg_helper.make_outbound("out", transport_name='upstream1')
        self.router.dispatch_outbound_message(msg)
        publishers = self.dispatcher.transport_publisher
        self.assertEqual(publishers['transport1'].msgs, [msg])


class TestTransportToTransportRouter(VumiTestCase):

    @inlineCallbacks
    def setUp(self):
        self.disp_helper = self.add_helper(
            DispatcherHelper(BaseDispatchWorker))
        self.worker = yield self.disp_helper.get_worker(BaseDispatchWorker, {
            "transport_names": [
                "transport1",
                "transport2",
            ],
            "exposed_names": [],
            "router_class": "vumi.dispatchers.base.TransportToTransportRouter",
            "route_mappings": {
                "transport1": ["transport2"],
            },
        })

    @inlineCallbacks
    def test_inbound_message_routing(self):
        tx1_helper = self.disp_helper.get_connector_helper('transport1')
        tx2_helper = self.disp_helper.get_connector_helper('transport2')
        msg = yield tx1_helper.make_dispatch_inbound(
            "foo", transport_name='transport1')
        self.assertEqual([msg], tx1_helper.get_dispatched_inbound())
        self.assertEqual([msg], tx2_helper.get_dispatched_outbound())
        self.assertEqual([], tx1_helper.get_dispatched_outbound())
        self.assertEqual([], tx2_helper.get_dispatched_inbound())


class TestFromAddrMultiplexRouter(VumiTestCase):

    @inlineCallbacks
    def setUp(self):
        config = {
            "transport_names": [
                "transport_1",
                "transport_2",
                "transport_3",
                ],
            "exposed_names": ["muxed"],
            "router_class": "vumi.dispatchers.base.FromAddrMultiplexRouter",
            "fromaddr_mappings": {
                "thing1@muxme": "transport_1",
                "thing2@muxme": "transport_2",
                "thing3@muxme": "transport_3",
                },
            }
        self.dispatcher = DummyDispatcher(config)
        self.router = FromAddrMultiplexRouter(self.dispatcher, config)
        self.add_cleanup(self.router.teardown_routing)
        yield self.router.setup_routing()
        self.msg_helper = self.add_helper(MessageHelper())

    def make_inbound_mux(self, content, from_addr, transport_name):
        return self.msg_helper.make_inbound(
            content, transport_name=transport_name, from_addr=from_addr)

    def make_ack_mux(self, from_addr, transport_name):
        return self.msg_helper.make_ack(
            transport_name=transport_name, from_addr=from_addr)

    def make_outbound_mux(self, content, from_addr):
        return self.msg_helper.make_outbound(
            content, transport_name='muxed', from_addr=from_addr)

    def test_inbound_message_routing(self):
        msg1 = self.make_inbound_mux('mux 1', 'thing1@muxme', 'transport_1')
        self.router.dispatch_inbound_message(msg1)
        msg2 = self.make_inbound_mux('mux 2', 'thing2@muxme', 'transport_2')
        self.router.dispatch_inbound_message(msg2)
        publishers = self.dispatcher.exposed_publisher
        self.assertEqual(publishers['muxed'].msgs, [msg1, msg2])

    def test_inbound_event_routing(self):
        msg1 = self.make_ack_mux('thing1@muxme', 'transport_1')
        self.router.dispatch_inbound_event(msg1)
        msg2 = self.make_ack_mux('thing2@muxme', 'transport_2')
        self.router.dispatch_inbound_event(msg2)
        publishers = self.dispatcher.exposed_event_publisher
        self.assertEqual(publishers['muxed'].msgs, [msg1, msg2])

    def test_outbound_message_routing(self):
        msg1 = self.make_outbound_mux('mux 1', 'thing1@muxme')
        self.router.dispatch_outbound_message(msg1)
        msg2 = self.make_outbound_mux('mux 2', 'thing2@muxme')
        self.router.dispatch_outbound_message(msg2)
        publishers = self.dispatcher.transport_publisher
        self.assertEqual(publishers['transport_1'].msgs, [msg1])
        self.assertEqual(publishers['transport_2'].msgs, [msg2])


class TestUserGroupingRouter(VumiTestCase):

    @inlineCallbacks
    def setUp(self):
        self.disp_helper = self.add_helper(
            DispatcherHelper(BaseDispatchWorker))
        self.dispatcher = yield self.disp_helper.get_dispatcher({
            'dispatcher_name': 'user_group_dispatcher',
            'router_class': 'vumi.dispatchers.base.UserGroupingRouter',
            'transport_names': [
                'transport1',
            ],
            'exposed_names': [
                'app1',
                'app2',
            ],
            'group_mappings': {
                'group1': 'app1',
                'group2': 'app2',
            },
            'transport_mappings': {
                'upstream1': 'transport1',
            },
        })
        self.router = self.dispatcher._router
        yield self.router._redis_d
        self.redis = self.router.redis
        yield self.redis._purge_all()  # just in case

    @inlineCallbacks
    def test_group_assignment(self):
        msg = self.disp_helper.make_inbound("foo")
        selected_group = yield self.router.get_group_for_user(msg.user())
        self.assertTrue(selected_group)
        for i in range(0, 10):
            group = yield self.router.get_group_for_user(msg.user())
            self.assertEqual(group, selected_group)

    @inlineCallbacks
    def test_round_robin_group_assignment(self):
        messages = [
            self.disp_helper.make_inbound(str(i), from_addr='from_%s' % (i,))
            for i in range(0, 4)]
        groups = [(yield self.router.get_group_for_user(message.user()))
                  for message in messages]
        self.assertEqual(groups, [
            'group1',
            'group2',
            'group1',
            'group2',
        ])

    def make_inbound_from(self, from_addr):
        return self.disp_helper.make_inbound("foo", from_addr=from_addr)

    @inlineCallbacks
    def test_routing_to_application(self):
        # generate 4 messages, one from each of 4 users
        msg1 = self.make_inbound_from('from_1')
        msg2 = self.make_inbound_from('from_2')
        msg3 = self.make_inbound_from('from_3')
        msg4 = self.make_inbound_from('from_4')
        # send them through to the dispatcher
        messages = [msg1, msg2, msg3, msg4]
        for message in messages:
            yield self.disp_helper.dispatch_inbound(message, 'transport1')

        app1_msgs = self.disp_helper.get_dispatched_inbound('app1')
        app2_msgs = self.disp_helper.get_dispatched_inbound('app2')
        self.assertEqual(app1_msgs, [msg1, msg3])
        self.assertEqual(app2_msgs, [msg2, msg4])

    @inlineCallbacks
    def test_routing_to_transport(self):
        app_msg = self.disp_helper.make_outbound(
            'foo', transport_name='transport1')
        yield self.disp_helper.dispatch_outbound(app_msg, 'app1')
        [tx_msg] = self.disp_helper.get_dispatched_outbound('transport1')
        self.assertEqual(app_msg, tx_msg)

    @inlineCallbacks
    def test_routing_to_transport_mapped(self):
        app_msg = self.disp_helper.make_outbound(
            'foo', transport_name='upstream1')
        yield self.disp_helper.dispatch_outbound(app_msg, 'app1')
        [tx_msg] = self.disp_helper.get_dispatched_outbound('transport1')
        self.assertEqual(app_msg, tx_msg)


class TestContentKeywordRouter(VumiTestCase):

    @inlineCallbacks
    def setUp(self):
        self.disp_helper = self.add_helper(
            DispatcherHelper(BaseDispatchWorker))
        self.dispatcher = yield self.disp_helper.get_dispatcher({
            'dispatcher_name': 'keyword_dispatcher',
            'router_class': 'vumi.dispatchers.base.ContentKeywordRouter',
            'transport_names': ['transport1', 'transport2'],
            'transport_mappings': {
                'shortcode1': 'transport1',
                'shortcode2': 'transport2',
                },
            'exposed_names': ['app1', 'app2', 'app3', 'fallback_app'],
            'rules': [{'app': 'app1',
                       'keyword': 'KEYWORD1',
                       'to_addr': '8181',
                       'prefix': '+256',
                       },
                      {'app': 'app2',
                       'keyword': 'KEYWORD2',
                       }],
            'keyword_mappings': {
                'app2': 'KEYWORD3',
                'app3': 'KEYWORD1',
                },
            'fallback_application': 'fallback_app',
            'expire_routing_memory': '3',
        })
        self.router = self.dispatcher._router
        yield self.router._redis_d
        self.add_cleanup(self.router.session_manager.stop)
        self.redis = self.router.redis
        yield self.redis._purge_all()  # just in case

    def ch(self, connector_name):
        return self.disp_helper.get_connector_helper(connector_name)

    def send_inbound(self, content, **kw):
        return self.ch('transport1').make_dispatch_inbound(content, **kw)

    def assert_dispatched(self, connector_name, msgs):
        self.assertEqual(
            msgs, self.disp_helper.get_dispatched_inbound(connector_name))

    @inlineCallbacks
    def test_inbound_message_routing(self):
        msg1 = yield self.send_inbound(
            'KEYWORD1 rest of msg', to_addr='8181', from_addr='+256788601462')
        msg2 = yield self.send_inbound(
            'KEYWORD2 rest of msg', to_addr='8181', from_addr='+256788601462')
        msg3 = yield self.send_inbound(
            'KEYWORD3 rest of msg', to_addr='8181', from_addr='+256788601462')

        self.assert_dispatched('app1', [msg1])
        self.assert_dispatched('app2', [msg2, msg3])
        self.assert_dispatched('app3', [msg1])

    @inlineCallbacks
    def test_inbound_message_routing_empty_message_content(self):
        msg = yield self.send_inbound(None)
        self.assert_dispatched('app1', [])
        self.assert_dispatched('app2', [])
        self.assert_dispatched('fallback_app', [msg])

    @inlineCallbacks
    def test_inbound_message_routing_not_casesensitive(self):
        msg = yield self.send_inbound(
            'keyword1 rest of msg', to_addr='8181', from_addr='+256788601462')
        self.assert_dispatched('app1', [msg])

    @inlineCallbacks
    def test_inbound_event_routing_ok(self):
        yield self.router.session_manager.create_session(
            'message:1', name='app2')
        ack = yield self.ch('transport1').make_dispatch_ack(
            self.disp_helper.make_outbound("foo", message_id='1'),
            transport_name='transport1')

        self.assertEqual([], self.disp_helper.get_dispatched_events('app1'))
        self.assertEqual([ack], self.disp_helper.get_dispatched_events('app2'))

    @inlineCallbacks
    def test_inbound_event_routing_failing_no_routing_back_in_redis(self):
        ack = yield self.ch('transport1').make_dispatch_ack(
            transport_name='transport1')

        self.assertEqual([], self.disp_helper.get_dispatched_events('app1'))
        self.assertEqual([], self.disp_helper.get_dispatched_events('app2'))

        [redis_lookup_fail, no_route_fail] = self.flushLoggedErrors(
            DispatcherError)
        self.assertEqual(str(redis_lookup_fail.value), (
            'No transport_name for return route found in Redis while'
            ' dispatching transport event for message %s'
            % ack['user_message_id']))
        self.assertEqual(str(no_route_fail.value),
                         'No publishing route for None')

    @inlineCallbacks
    def test_outbound_message_routing(self):
        msg = yield self.ch('app2').make_dispatch_outbound(
            "KEYWORD1 rest of msg", from_addr='shortcode1',
            transport_name='app2', message_id='1')

        self.assertEqual(
            [msg], self.disp_helper.get_dispatched_outbound('transport1'))
        self.assertEqual(
            [], self.disp_helper.get_dispatched_outbound('transport2'))

        session = yield self.router.session_manager.load_session('message:1')
        self.assertEqual(session['name'], 'app2')


class TestRedirectOutboundRouterForSMPP(VumiTestCase):
    """
    This is a test to cover our use case when using SMPP 3.4 with
    split Tx and Rx binds. The outbound traffic needs to go to the Tx, while
    the Rx just should go through. Upstream everything should be seen
    as arriving from the dispatcher and so the `transport_name` should be
    overwritten.
    """

    @inlineCallbacks
    def setUp(self):
        self.disp_helper = self.add_helper(
            DispatcherHelper(BaseDispatchWorker))
        self.dispatcher = yield self.disp_helper.get_dispatcher({
            'dispatcher_name': 'redirect_outbound_dispatcher',
            'router_class': 'vumi.dispatchers.base.RedirectOutboundRouter',
            'transport_names': ['smpp_rx_transport', 'smpp_tx_transport'],
            'exposed_names': ['upstream'],
            'redirect_outbound': {
                'upstream': 'smpp_tx_transport',
            },
            'redirect_inbound': {
                'smpp_tx_transport': 'upstream',
                'smpp_rx_transport': 'upstream',
            },
        })
        self.router = self.dispatcher._router

    def ch(self, connector_name):
        return self.disp_helper.get_connector_helper(connector_name)

    @inlineCallbacks
    def test_outbound_message_via_tx(self):
        msg = yield self.ch('upstream').make_dispatch_outbound(
            "foo", transport_name='upstream')
        [out] = self.disp_helper.get_dispatched_outbound('smpp_tx_transport')
        self.assertEqual(out['message_id'], msg['message_id'])

    @inlineCallbacks
    def test_inbound_event_tx(self):
        ack = yield self.ch('smpp_tx_transport').make_dispatch_ack(
            transport_name='smpp_tx_transport')
        [event] = self.disp_helper.get_dispatched_events('upstream')
        self.assertEqual(event['transport_name'], 'upstream')
        self.assertEqual(event['event_id'], ack['event_id'])

    @inlineCallbacks
    def test_inbound_event_rx(self):
        ack = yield self.ch('smpp_rx_transport').make_dispatch_ack(
            transport_name='smpp_rx_transport')
        [event] = self.disp_helper.get_dispatched_events('upstream')
        self.assertEqual(event['transport_name'], 'upstream')
        self.assertEqual(event['event_id'], ack['event_id'])

    @inlineCallbacks
    def test_inbound_message_via_rx(self):
        msg = yield self.ch('smpp_rx_transport').make_dispatch_inbound(
            "foo", transport_name='smpp_rx_transport')
        [app_msg] = self.disp_helper.get_dispatched_inbound('upstream')
        self.assertEqual(app_msg['transport_name'], 'upstream')
        self.assertEqual(app_msg['message_id'], msg['message_id'])

    @inlineCallbacks
    def test_error_logging_for_bad_app(self):
        msgt1 = self.disp_helper.make_outbound(
            "foo", transport_name='foo')  # Does not exist
        with LogCatcher() as log:
            yield self.disp_helper.dispatch_outbound(msgt1, 'upstream')
            [err] = log.errors
            self.assertTrue('No redirect_outbound specified for foo' in
                            str(err['failure'].value))
        [f] = self.flushLoggedErrors(DispatcherError)
        self.assertEqual(f, err['failure'])


class TestRedirectOutboundRouter(VumiTestCase):

    @inlineCallbacks
    def setUp(self):
        self.disp_helper = self.add_helper(
            DispatcherHelper(BaseDispatchWorker))
        self.dispatcher = yield self.disp_helper.get_dispatcher({
            'dispatcher_name': 'redirect_outbound_dispatcher',
            'router_class': 'vumi.dispatchers.base.RedirectOutboundRouter',
            'transport_names': ['transport1', 'transport2'],
            'exposed_names': ['app1', 'app2'],
            'redirect_outbound': {
                'app1': 'transport1',
                'app2': 'transport2',
            },
            'redirect_inbound': {
                'transport1': 'app1',
                'transport2': 'app2',
            }
        })
        self.router = self.dispatcher._router

    def ch(self, connector_name):
        return self.disp_helper.get_connector_helper(connector_name)

    @inlineCallbacks
    def test_outbound_redirect(self):
        msgt1 = yield self.ch('app1').make_dispatch_outbound(
            "t1", transport_name='app1')
        msgt2 = yield self.ch('app2').make_dispatch_outbound(
            "t2", transport_name='app2')
        self.assertEqual(
            [msgt1], self.disp_helper.get_dispatched_outbound('transport1'))
        self.assertEqual(
            [msgt2], self.disp_helper.get_dispatched_outbound('transport2'))

    @inlineCallbacks
    def test_inbound_event(self):
        ack = yield self.ch('transport1').make_dispatch_ack(
            transport_name='transport1')
        [event] = self.disp_helper.get_dispatched_events('app1')
        self.assertEqual(event['transport_name'], 'app1')
        self.assertEqual(event['event_id'], ack['event_id'])

    @inlineCallbacks
    def test_inbound_message(self):
        msg = yield self.ch('transport1').make_dispatch_inbound(
            "foo", transport_name='transport1')
        [app_msg] = self.disp_helper.get_dispatched_inbound('app1')
        self.assertEqual(app_msg['transport_name'], 'app1')
        self.assertEqual(app_msg['message_id'], msg['message_id'])

    @inlineCallbacks
    def test_error_logging_for_bad_app(self):
        msgt1 = self.disp_helper.make_outbound(
            "foo", transport_name='app3')  # Does not exist
        with LogCatcher() as log:
            yield self.disp_helper.dispatch_outbound(msgt1, 'app2')
            [err] = log.errors
            self.assertTrue('No redirect_outbound specified for app3' in
                            str(err['failure'].value))
        [f] = self.flushLoggedErrors(DispatcherError)
        self.assertEqual(f, err['failure'])
# ---- vumi/dispatchers/tests/test_load_balancer.py ----
"""Tests for vumi.dispatchers.load_balancer."""

from twisted.internet.defer import inlineCallbacks

from vumi.dispatchers.load_balancer import LoadBalancingRouter
from vumi.dispatchers.tests.helpers import DummyDispatcher
from vumi.tests.helpers import VumiTestCase, MessageHelper
from vumi.tests.utils import LogCatcher


class BaseLoadBalancingTestCase(VumiTestCase):

    reply_affinity = None
    rewrite_transport_names = None

    @inlineCallbacks
    def setUp(self):
        config = {
            "transport_names": [
                "transport_1",
                "transport_2",
            ],
            "exposed_names": ["round_robin"],
            "router_class": ("vumi.dispatchers.load_balancer."
                             "LoadBalancingRouter"),
        }
        if self.reply_affinity is not None:
            config['reply_affinity'] = self.reply_affinity
        if self.rewrite_transport_names is not None:
            config['rewrite_transport_names'] = self.rewrite_transport_names
        self.dispatcher = DummyDispatcher(config)
        self.router = LoadBalancingRouter(self.dispatcher, config)
        self.add_cleanup(self.router.teardown_routing)
        yield self.router.setup_routing()
        self.msg_helper = self.add_helper(MessageHelper())


class TestLoadBalancingWithoutReplyAffinity(BaseLoadBalancingTestCase):

    reply_affinity = False

    def test_inbound_message_routing(self):
        msg1 = self.msg_helper.make_inbound(
            'msg 1', transport_name='transport_1')
        self.router.dispatch_inbound_message(msg1)
        msg2 = self.msg_helper.make_inbound(
            'msg 2', transport_name='transport_2')
        self.router.dispatch_inbound_message(msg2)
        publishers = self.dispatcher.exposed_publisher
        self.assertEqual(publishers['round_robin'].msgs, [msg1, msg2])

    def test_inbound_event_routing(self):
        msg1 = self.msg_helper.make_ack(transport_name='transport_1')
        self.router.dispatch_inbound_event(msg1)
        msg2 = self.msg_helper.make_ack(transport_name='transport_2')
        self.router.dispatch_inbound_event(msg2)
        publishers = self.dispatcher.exposed_event_publisher
        self.assertEqual(publishers['round_robin'].msgs, [msg1, msg2])

    def test_outbound_message_routing(self):
        msg1 = self.msg_helper.make_outbound('msg 1')
        self.router.dispatch_outbound_message(msg1)
        msg2 = self.msg_helper.make_outbound('msg 2')
        self.router.dispatch_outbound_message(msg2)
        msg3 = self.msg_helper.make_outbound('msg 3')
        self.router.dispatch_outbound_message(msg3)
        publishers = self.dispatcher.transport_publisher
        self.assertEqual(publishers['transport_1'].msgs, [msg1, msg3])
        self.assertEqual(publishers['transport_2'].msgs, [msg2])


class TestLoadBalancingWithReplyAffinity(BaseLoadBalancingTestCase):

    reply_affinity = True

    def test_inbound_message_routing(self):
        msg1 = self.msg_helper.make_inbound(
            'msg 1', transport_name='transport_1')
        self.router.dispatch_inbound_message(msg1)
        msg2 = self.msg_helper.make_inbound(
            'msg 2', transport_name='transport_2')
        self.router.dispatch_inbound_message(msg2)
        publishers = self.dispatcher.exposed_publisher
        self.assertEqual(publishers['round_robin'].msgs, [msg1, msg2])

    def test_inbound_event_routing(self):
        msg1 = self.msg_helper.make_ack(transport_name='transport_1')
        self.router.dispatch_inbound_event(msg1)
        msg2 = self.msg_helper.make_ack(transport_name='transport_2')
        self.router.dispatch_inbound_event(msg2)
        publishers = self.dispatcher.exposed_event_publisher
        self.assertEqual(publishers['round_robin'].msgs, [msg1, msg2])

    def test_outbound_message_routing(self):
        msg1 = self.msg_helper.make_outbound('msg 1', in_reply_to='msg X')
        self.router.push_transport_name(msg1, 'transport_1')
        self.router.dispatch_outbound_message(msg1)
        msg2 = self.msg_helper.make_outbound('msg 2', in_reply_to='msg X')
        self.router.push_transport_name(msg2, 'transport_1')
        self.router.dispatch_outbound_message(msg2)
        msg3 = self.msg_helper.make_outbound('msg 3', in_reply_to='msg X')
        self.router.push_transport_name(msg3, 'transport_2')
        self.router.dispatch_outbound_message(msg3)
        publishers = self.dispatcher.transport_publisher
        self.assertEqual(publishers['transport_1'].msgs, [msg1, msg2])
        self.assertEqual(publishers['transport_2'].msgs, [msg3])

    def test_outbound_message_with_unknown_transport_name(self):
        # we expect unknown outbound transport_names to be
        # round-robinned and logged.
        msg1 = self.msg_helper.make_outbound('msg 1', in_reply_to='msg X')
        self.router.push_transport_name(msg1, 'transport_unknown')
        with LogCatcher() as lc:
            self.router.dispatch_outbound_message(msg1)
            [errmsg] = lc.messages()
            self.assertTrue("unknown load balancer endpoint "
                            "'transport_unknown' was was received" in errmsg)
        publishers = self.dispatcher.transport_publisher
        self.assertEqual(publishers['transport_1'].msgs, [msg1])


class TestLoadBalancingWithRewriteTransportNames(BaseLoadBalancingTestCase):

    rewrite_transport_names = True

    def test_inbound_message_routing(self):
        msg = self.msg_helper.make_inbound(
            'msg 1', transport_name='transport_1')
        self.router.dispatch_inbound_message(msg)
        [new_msg] = self.dispatcher.exposed_publisher['round_robin'].msgs
        self.assertEqual(new_msg['transport_name'], 'round_robin')

    def test_inbound_event_routing(self):
        msg = self.msg_helper.make_ack(transport_name='transport_1')
        self.router.dispatch_inbound_event(msg)
        [new_msg] = self.dispatcher.exposed_event_publisher['round_robin'].msgs
        self.assertEqual(new_msg['transport_name'], 'round_robin')

    def test_outbound_message_routing(self):
        msg1 = self.msg_helper.make_outbound(
            'msg 1', transport_name='round_robin')
        self.router.dispatch_outbound_message(msg1)
        [new_msg] = self.dispatcher.transport_publisher['transport_1'].msgs
        self.assertEqual(new_msg['transport_name'], 'transport_1')


class TestLoadBalancingWithoutRewriteTransportNames(BaseLoadBalancingTestCase):

    rewrite_transport_names = False

    def test_inbound_message_routing(self):
        msg = self.msg_helper.make_inbound(
            'msg 1', transport_name='transport_1')
        self.router.dispatch_inbound_message(msg)
        [new_msg] = self.dispatcher.exposed_publisher['round_robin'].msgs
        self.assertEqual(new_msg['transport_name'], 'transport_1')

    def test_inbound_event_routing(self):
        msg = self.msg_helper.make_ack(transport_name='transport_1')
        self.router.dispatch_inbound_event(msg)
        [new_msg] = self.dispatcher.exposed_event_publisher['round_robin'].msgs
        self.assertEqual(new_msg['transport_name'], 'transport_1')

    def test_outbound_message_routing(self):
        msg1 = self.msg_helper.make_outbound(
            'msg 1', transport_name='round_robin')
        self.router.dispatch_outbound_message(msg1)
        [new_msg] = self.dispatcher.transport_publisher['transport_1'].msgs
        self.assertEqual(new_msg['transport_name'], 'round_robin')
# ---- vumi/dispatchers/tests/test_endpoint_dispatchers.py ----
from twisted.python.failure import Failure
from twisted.internet.defer import inlineCallbacks

from vumi.dispatchers.endpoint_dispatchers import (
    Dispatcher, RoutingTableDispatcher)
from vumi.dispatchers.tests.helpers import DispatcherHelper
from vumi.tests.utils import LogCatcher
from vumi.tests.helpers import VumiTestCase


class DummyError(Exception):
    """Custom exception to use in test cases."""


class TestDispatcher(VumiTestCase):

    def setUp(self):
        self.disp_helper = self.add_helper(DispatcherHelper(Dispatcher))

    def get_dispatcher(self, **config_extras):
        config = {
            "receive_inbound_connectors": ["transport1", "transport2"],
            "receive_outbound_connectors": ["app1", "app2"],
            }
        config.update(config_extras)
        return self.disp_helper.get_dispatcher(config)

    def ch(self, connector_name):
        return self.disp_helper.get_connector_helper(connector_name)

    @inlineCallbacks
    def test_default_errback(self):
        disp = yield self.get_dispatcher()
        msg = self.disp_helper.make_inbound('bad')
        f = Failure(DummyError("worse"))
        with LogCatcher() as lc:
            yield disp.default_errback(f, msg, "app1")
            [err1] = lc.errors
        self.assertEqual(
            err1['why'], "Error routing message for app1")
        self.assertEqual(
            err1['failure'], f)
        self.flushLoggedErrors(DummyError)

    @inlineCallbacks
    def check_errback_called(self, method_to_raise, errback_method, direction):
        dummy_error = DummyError("eep")

        def raiser(*args):
            raise dummy_error

        errors = []

        def record_error(self, f, msg, connector_name):
            errors.append((f, msg, connector_name))

        raiser_patch = self.patch(Dispatcher, method_to_raise, raiser)
        recorder_patch = self.patch(Dispatcher, errback_method, record_error)
        disp = yield self.get_dispatcher()

        if direction == 'inbound':
            connector_name = 'transport1'
            msg = yield self.ch('transport1').make_dispatch_inbound("inbound")
        elif direction == 'outbound':
            connector_name = 'app1'
            msg = yield self.ch('app1').make_dispatch_outbound("outbound")
        elif direction == 'event':
            connector_name = 'transport1'
            msg = yield self.ch('transport1').make_dispatch_ack()
        else:
            raise ValueError(
                "Unexcepted value %r for direction" % (direction,))

        [(err_f, err_msg, err_connector)] = errors
        self.assertEqual(err_f.value, dummy_error)
        self.assertEqual(err_msg, msg)
        self.assertEqual(err_connector, connector_name)

        yield self.disp_helper.cleanup_worker(disp)
        raiser_patch.restore()
        recorder_patch.restore()

    @inlineCallbacks
    def test_inbound_errbacks(self):
        for err_method in ('process_inbound', 'get_config'):
            yield self.check_errback_called(
                err_method, 'default_errback', 'inbound')
            yield self.check_errback_called(
                err_method, 'errback_inbound', 'inbound')

    @inlineCallbacks
    def test_outbound_errbacks(self):
        for err_method in ('process_outbound', 'get_config'):
            yield self.check_errback_called(
                err_method, 'default_errback', 'outbound')
            yield self.check_errback_called(
                err_method, 'errback_outbound', 'outbound')

    @inlineCallbacks
    def test_event_errbacks(self):
        for err_method in ('process_event', 'get_config'):
            yield self.check_errback_called(
                err_method, 'default_errback', 'event')
            yield self.check_errback_called(
                err_method, 'errback_event', 'event')


class TestRoutingTableDispatcher(VumiTestCase):

    def setUp(self):
        self.disp_helper = self.add_helper(
            DispatcherHelper(RoutingTableDispatcher))

    def get_dispatcher(self, **config_extras):
        config = {
            "receive_inbound_connectors": ["transport1", "transport2"],
            "receive_outbound_connectors": ["app1", "app2"],
            "routing_table": {
                "transport1": {
                    "default": ["app1", "default"],
                    },
                "transport2": {
                    "default": ["app2", "default"],
                    "ep1": ["app1", "ep1"],
                    },
                "app1": {
                    "default": ["transport1", "default"],
                    "ep2": ["transport2", "default"],
                    },
                "app2": {
                    "default": ["transport2", "default"],
                    },
                },
            }
        config.update(config_extras)
        return self.disp_helper.get_dispatcher(config)

    def ch(self, connector_name):
        return self.disp_helper.get_connector_helper(connector_name)

    def assert_rkeys_used(self, *rkeys):
        broker = self.disp_helper.worker_helper.broker
        self.assertEqual(set(rkeys), set(broker.dispatched['vumi'].keys()))

    def assert_dispatched_endpoint(self, msg, endpoint, dispatched_msgs):
        msg.set_routing_endpoint(endpoint)
        self.assertEqual([msg], dispatched_msgs)

    @inlineCallbacks
    def test_inbound_message_routing(self):
        yield self.get_dispatcher()
        msg = yield self.ch("transport1").make_dispatch_inbound("inbound")
        self.assert_rkeys_used('transport1.inbound', 'app1.inbound')
        self.assert_dispatched_endpoint(
            msg, 'default', self.ch('app1').get_dispatched_inbound())

        self.disp_helper.clear_all_dispatched()
        msg = yield self.ch("transport2").make_dispatch_inbound("inbound")
        self.assert_rkeys_used('transport2.inbound', 'app2.inbound')
        self.assert_dispatched_endpoint(
            msg, 'default', self.ch('app2').get_dispatched_inbound())

        self.disp_helper.clear_all_dispatched()
        msg = yield self.ch("transport2").make_dispatch_inbound(
            "inbound", endpoint='ep1')
        self.assert_rkeys_used('transport2.inbound', 'app1.inbound')
        self.assert_dispatched_endpoint(
            msg, 'ep1', self.ch('app1').get_dispatched_inbound())

    @inlineCallbacks
    def test_outbound_message_routing(self):
        yield self.get_dispatcher()
        msg = yield self.ch('app1').make_dispatch_outbound("outbound")
        self.assert_rkeys_used('app1.outbound', 'transport1.outbound')
        self.assert_dispatched_endpoint(
            msg, 'default', self.ch('transport1').get_dispatched_outbound())

        self.disp_helper.clear_all_dispatched()
        msg = yield self.ch('app2').make_dispatch_outbound("outbound")
        self.assert_rkeys_used('app2.outbound', 'transport2.outbound')
        self.assert_dispatched_endpoint(
            msg, 'default', self.ch('transport2').get_dispatched_outbound())

        self.disp_helper.clear_all_dispatched()
        msg = yield self.ch('app1').make_dispatch_outbound(
            "outbound", endpoint='ep2')
        self.assert_rkeys_used('app1.outbound', 'transport2.outbound')
        self.assert_dispatched_endpoint(
            msg, 'default', self.ch('transport2').get_dispatched_outbound())

    @inlineCallbacks
    def test_inbound_event_routing(self):
        yield self.get_dispatcher()
        msg = yield self.ch('transport1').make_dispatch_ack()
        self.assert_rkeys_used('transport1.event', 'app1.event')
        self.assert_dispatched_endpoint(
            msg, 'default', self.ch('app1').get_dispatched_events())

        self.disp_helper.clear_all_dispatched()
        msg = yield self.ch('transport2').make_dispatch_ack()
        self.assert_rkeys_used('transport2.event', 'app2.event')
        self.assert_dispatched_endpoint(
            msg, 'default', self.ch('app2').get_dispatched_events())

        self.disp_helper.clear_all_dispatched()
        msg = yield self.ch('transport2').make_dispatch_ack(endpoint='ep1')
        self.assert_rkeys_used('transport2.event', 'app1.event')
        self.assert_dispatched_endpoint(
            msg, 'ep1', self.ch('app1').get_dispatched_events())

    def get_dispatcher_consumers(self, dispatcher):
        consumers = []
        for conn in dispatcher.connectors.values():
            consumers.extend(conn._consumers.values())
        return consumers

    @inlineCallbacks
    def test_consumer_prefetch_count_default(self):
        dp = yield self.get_dispatcher()
        consumers = self.get_dispatcher_consumers(dp)
        for consumer in consumers:
            fake_channel = consumer.channel._fake_channel
            self.assertEqual(fake_channel.qos_prefetch_count, 20)

    @inlineCallbacks
    def test_consumer_prefetch_count_custom(self):
        dp = yield self.get_dispatcher(amqp_prefetch_count=10)
        consumers = self.get_dispatcher_consumers(dp)
        for consumer in consumers:
            fake_channel = consumer.channel._fake_channel
            self.assertEqual(fake_channel.qos_prefetch_count, 10)

    @inlineCallbacks
    def test_consumer_prefetch_count_none(self):
        dp = yield self.get_dispatcher(amqp_prefetch_count=None)
        consumers = self.get_dispatcher_consumers(dp)
        for consumer in consumers:
            fake_channel = consumer.channel._fake_channel
            self.assertEqual(fake_channel.qos_prefetch_count, 0)
# ---- vumi/dispatchers/tests/test_test_helpers.py ----
from twisted.internet.defer import inlineCallbacks

from vumi.dispatchers.endpoint_dispatchers import Dispatcher
from vumi.dispatchers.tests.helpers import DummyDispatcher, DispatcherHelper
from vumi.tests.helpers import (
    VumiTestCase, IHelper, PersistenceHelper, MessageHelper, WorkerHelper,
    MessageDispatchHelper, success_result_of)


class TestDummyDispatcher(VumiTestCase):
    def test_publish_inbound(self):
        """
        DummyDispatcher should have a fake inbound publisher that remembers
        messages.
        """
        dispatcher = DummyDispatcher({
            'transport_names': ['ri_conn'],
            'exposed_names': ['ro_conn'],
        })
        self.assertEqual(dispatcher.exposed_publisher['ro_conn'].msgs, [])
        dispatcher.publish_inbound_message('ro_conn', 'fake inbound')
        self.assertEqual(
            dispatcher.exposed_publisher['ro_conn'].msgs, ['fake inbound'])
        dispatcher.exposed_publisher['ro_conn'].clear()
        self.assertEqual(dispatcher.exposed_publisher['ro_conn'].msgs, [])

    def test_publish_outbound(self):
        """
        DummyDispatcher should have a fake outbound publisher that remembers
        messages.
        """
        dispatcher = DummyDispatcher({
            'transport_names': ['ri_conn'],
            'exposed_names': ['ro_conn'],
        })
        self.assertEqual(dispatcher.transport_publisher['ri_conn'].msgs, [])
        dispatcher.publish_outbound_message('ri_conn', 'fake outbound')
        self.assertEqual(
            dispatcher.transport_publisher['ri_conn'].msgs, ['fake outbound'])
        dispatcher.transport_publisher['ri_conn'].clear()
        self.assertEqual(dispatcher.transport_publisher['ri_conn'].msgs, [])

    def test_publish_event(self):
        """
        DummyDispatcher should have a fake event publisher that remembers
        messages.
        """
        dispatcher = DummyDispatcher({
            'transport_names': ['ri_conn'],
            'exposed_names': ['ro_conn'],
        })
        self.assertEqual(
            dispatcher.exposed_event_publisher['ro_conn'].msgs, [])
        dispatcher.publish_inbound_event('ro_conn', 'fake event')
        self.assertEqual(
            dispatcher.exposed_event_publisher['ro_conn'].msgs, ['fake event'])
        dispatcher.exposed_event_publisher['ro_conn'].clear()
        self.assertEqual(
            dispatcher.exposed_event_publisher['ro_conn'].msgs, [])


class RunningCheckDispatcher(Dispatcher):
    disp_worker_running = False

    def setup_dispatcher(self):
        self.disp_worker_running = True

    def teardown_dispatcher(self):
        self.disp_worker_running = False


class FakeCleanupCheckHelper(object):
    cleaned_up = False

    def cleanup(self):
        self.cleaned_up = True


class TestDispatcherHelper(VumiTestCase):
    def test_implements_IHelper(self):
        """
        DispatcherHelper instances should provide the IHelper interface.
        """
        self.assertTrue(IHelper.providedBy(DispatcherHelper(None)))

    def test_defaults(self):
        """
        DispatcherHelper instances should have the expected parameter defaults.
        """
        fake_disp_class = object()
        disp_helper = DispatcherHelper(fake_disp_class)
        self.assertEqual(disp_helper.dispatcher_class, fake_disp_class)
        self.assertIsInstance(
            disp_helper.persistence_helper, PersistenceHelper)
        self.assertIsInstance(disp_helper.msg_helper, MessageHelper)
        self.assertIsInstance(disp_helper.worker_helper, WorkerHelper)
        dispatch_helper = disp_helper.dispatch_helper
        self.assertIsInstance(dispatch_helper, MessageDispatchHelper)
        self.assertEqual(dispatch_helper.msg_helper, disp_helper.msg_helper)
        self.assertEqual(
            dispatch_helper.worker_helper, disp_helper.worker_helper)
        self.assertEqual(disp_helper.persistence_helper.use_riak, False)

    def test_all_params(self):
        """
        DispatcherHelper should pass use_riak to its PersistenceHelper and all
        other params to its MessageHelper.
        """
        fake_disp_class = object()
        disp_helper = DispatcherHelper(
            fake_disp_class, use_riak=True, transport_addr='Obs station')
        self.assertEqual(disp_helper.persistence_helper.use_riak, True)
        self.assertEqual(disp_helper.msg_helper.transport_addr, 'Obs station')

    def test_setup_sync(self):
        """
        DispatcherHelper.setup() should return ``None``, not a Deferred.
        """
        disp_helper = DispatcherHelper(None)
        self.add_cleanup(disp_helper.cleanup)
        self.assertEqual(disp_helper.setup(), None)

    def test_cleanup(self):
        """
        DispatcherHelper.cleanup() should call .cleanup() on its
        PersistenceHelper and WorkerHelper.
        """
        disp_helper = DispatcherHelper(None)
        disp_helper.persistence_helper = FakeCleanupCheckHelper()
        disp_helper.worker_helper = FakeCleanupCheckHelper()
        self.assertEqual(disp_helper.persistence_helper.cleaned_up, False)
        self.assertEqual(disp_helper.worker_helper.cleaned_up, False)
        success_result_of(disp_helper.cleanup())
        self.assertEqual(disp_helper.persistence_helper.cleaned_up, True)
        self.assertEqual(disp_helper.worker_helper.cleaned_up, True)

    @inlineCallbacks
    def test_get_dispatcher_defaults(self):
        """
        .get_dispatcher() should return a started dispatcher.
        """
        disp_helper = self.add_helper(DispatcherHelper(RunningCheckDispatcher))
        app = yield disp_helper.get_dispatcher({
            'receive_inbound_connectors': [],
            'receive_outbound_connectors': [],
        })
        self.assertIsInstance(app, RunningCheckDispatcher)
        self.assertEqual(app.disp_worker_running, True)

    @inlineCallbacks
    def test_get_dispatcher_no_start(self):
        """
        .get_dispatcher() should return an unstarted dispatcher if passed
        ``start=False``.
        """
        disp_helper = self.add_helper(DispatcherHelper(RunningCheckDispatcher))
        app = yield disp_helper.get_dispatcher({
            'receive_inbound_connectors': [],
            'receive_outbound_connectors': [],
        }, start=False)
        self.assertIsInstance(app, RunningCheckDispatcher)
        self.assertEqual(app.disp_worker_running, False)

    @inlineCallbacks
    def test_get_dispatcher_different_class(self):
        """
        .get_dispatcher() should return an instance of the specified worker
        class if one is provided.
        """
        disp_helper = self.add_helper(DispatcherHelper(Dispatcher))
        app = yield disp_helper.get_dispatcher({
            'receive_inbound_connectors': [],
            'receive_outbound_connectors': [],
        }, cls=RunningCheckDispatcher)
        self.assertIsInstance(app, RunningCheckDispatcher)

    def test_get_connector_helper(self):
        """
        .get_connector_helper() should return a DispatcherConnectorHelper
        instance for the provided connector name.
        """
        disp_helper = DispatcherHelper(None)
        dc_helper = disp_helper.get_connector_helper('barconn')
        self.assertEqual(dc_helper.msg_helper, disp_helper.msg_helper)
        self.assertEqual(dc_helper.worker_helper._connector_name, 'barconn')
        self.assertEqual(
            dc_helper.worker_helper.broker, disp_helper.worker_helper.broker)
# vumi/dispatchers/tests/helpers.py
from twisted.internet.defer import inlineCallbacks

from zope.interface import implements

from vumi.dispatchers.base import BaseDispatchWorker
from vumi.middleware import MiddlewareStack
from vumi.tests.helpers import (
    MessageHelper, PersistenceHelper, WorkerHelper, MessageDispatchHelper,
    generate_proxies, IHelper,
)


class DummyDispatcher(BaseDispatchWorker):

    class DummyPublisher(object):
        def __init__(self):
            self.msgs = []

        def publish_message(self, msg):
            self.msgs.append(msg)

        def clear(self):
            self.msgs[:] = []

    def __init__(self, config):
        self.transport_publisher = {}
        self.transport_names = config.get('transport_names', [])
        for transport in self.transport_names:
            self.transport_publisher[transport] = self.DummyPublisher()
        self.exposed_publisher = {}
        self.exposed_event_publisher = {}
        self.exposed_names = config.get('exposed_names', [])
        for exposed in self.exposed_names:
            self.exposed_publisher[exposed] = self.DummyPublisher()
            self.exposed_event_publisher[exposed] = self.DummyPublisher()
        self._middlewares = MiddlewareStack([])


class DispatcherHelper(object):
    """
    Test helper for dispatcher workers.

    This helper constructs and wraps several lower-level helpers and provides
    higher-level functionality for dispatcher tests.

    :param dispatcher_class:
        The worker class for the dispatcher being tested.

    :param bool use_riak:
        Set to ``True`` if the test requires Riak. This is passed to the
        underlying :class:`~vumi.tests.helpers.PersistenceHelper`.

    :param \**msg_helper_args:
        All other keyword params are passed to the underlying
        :class:`~vumi.tests.helpers.MessageHelper`.
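
    A minimal usage sketch (``MyDispatcher`` is a hypothetical dispatcher
    worker class under test)::

        class TestMyDispatcher(VumiTestCase):
            def setUp(self):
                self.disp_helper = self.add_helper(
                    DispatcherHelper(MyDispatcher))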
    """

    implements(IHelper)

    def __init__(self, dispatcher_class, use_riak=False, **msg_helper_args):
        self.dispatcher_class = dispatcher_class
        self.worker_helper = WorkerHelper()
        self.persistence_helper = PersistenceHelper(use_riak=use_riak)
        self.msg_helper = MessageHelper(**msg_helper_args)
        self.dispatch_helper = MessageDispatchHelper(
            self.msg_helper, self.worker_helper)

        # Proxy methods from our helpers.
        generate_proxies(self, self.msg_helper)
        generate_proxies(self, self.worker_helper)
        generate_proxies(self, self.dispatch_helper)

    def setup(self):
        self.persistence_helper.setup()
        self.worker_helper.setup()

    @inlineCallbacks
    def cleanup(self):
        yield self.worker_helper.cleanup()
        yield self.persistence_helper.cleanup()

    def get_dispatcher(self, config, cls=None, start=True):
        """
        Get an instance of a dispatcher class.

        :param dict config: Config dict.
        :param cls:
            The dispatcher class to instantiate. Defaults to
            :attr:`dispatcher_class`.
        :param bool start:
            ``True`` to start the dispatcher (default), ``False`` otherwise.
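
        For example, in an ``@inlineCallbacks`` test method (a sketch; the
        connector names are illustrative)::

            disp = yield disp_helper.get_dispatcher({
                'receive_inbound_connectors': ['transport1'],
                'receive_outbound_connectors': ['app1'],
            })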
        """
        if cls is None:
            cls = self.dispatcher_class
        config = self.persistence_helper.mk_config(config)
        return self.get_worker(cls, config, start)

    def get_connector_helper(self, connector_name):
        """
        Construct a :class:`~DispatcherConnectorHelper` for the provided
        ``connector_name``.
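
        For example (a sketch, assuming ``disp_helper`` has been set up and
        ``msg`` is a message to deliver)::

            conn_helper = disp_helper.get_connector_helper('transport1')
            yield conn_helper.dispatch_inbound(msg)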
        """
        return DispatcherConnectorHelper(self, connector_name)


class DispatcherConnectorHelper(object):
    """
    Subset of :class:`~vumi.tests.helpers.WorkerHelper` and
    :class:`~vumi.tests.helpers.MessageDispatchHelper` functionality for a
    specific connector. This should only be created with
    :meth:`DispatcherHelper.get_connector_helper`.
    """
    def __init__(self, dispatcher_helper, connector_name):
        self.msg_helper = dispatcher_helper.msg_helper
        self.worker_helper = WorkerHelper(
            connector_name, dispatcher_helper.worker_helper.broker)
        self.dispatch_helper = MessageDispatchHelper(
            self.msg_helper, self.worker_helper)

        generate_proxies(self, self.worker_helper)
        generate_proxies(self, self.dispatch_helper)

        # We don't want to be able to make workers with this helper.
        del self.get_worker
        del self.cleanup_worker
# vumi/dispatchers/tests/__init__.py (empty)

# vumi/tests/test_connectors.py
from twisted.internet.defer import inlineCallbacks, returnValue

from vumi.connectors import (
    BaseConnector, ReceiveInboundConnector, ReceiveOutboundConnector,
    PublishStatusConnector, ReceiveStatusConnector, IgnoreMessage)
from vumi.tests.utils import LogCatcher
from vumi.worker import BaseWorker
from vumi.message import TransportUserMessage
from vumi.middleware.tests.utils import RecordingMiddleware
from vumi.tests.helpers import VumiTestCase, MessageHelper, WorkerHelper


class DummyWorker(BaseWorker):
    def setup_connectors(self):
        pass

    def setup_worker(self):
        pass

    def teardown_worker(self):
        pass


class BaseConnectorTestCase(VumiTestCase):

    connector_class = None

    def setUp(self):
        self.msg_helper = self.add_helper(MessageHelper())
        self.worker_helper = self.add_helper(WorkerHelper())

    @inlineCallbacks
    def mk_connector(self, worker=None, connector_name=None,
                     prefetch_count=None, middlewares=None, setup=False):
        if worker is None:
            worker = yield self.worker_helper.get_worker(DummyWorker, {})
        if connector_name is None:
            connector_name = "dummy_connector"
        connector = self.connector_class(worker, connector_name,
                                         prefetch_count=prefetch_count,
                                         middlewares=middlewares)
        if setup:
            yield connector.setup()
        returnValue(connector)

    @inlineCallbacks
    def mk_consumer(self, *args, **kwargs):
        conn = yield self.mk_connector(*args, **kwargs)
        consumer = yield conn._setup_consumer('inbound', TransportUserMessage,
                                              lambda msg: None)
        returnValue((conn, consumer))


class TestBaseConnector(BaseConnectorTestCase):

    connector_class = BaseConnector

    @inlineCallbacks
    def test_creation(self):
        conn = yield self.mk_connector(connector_name="foo")
        self.assertEqual(conn.name, "foo")
        self.assertTrue(isinstance(conn.worker, BaseWorker))

    @inlineCallbacks
    def test_middlewares_consume(self):
        worker = yield self.worker_helper.get_worker(DummyWorker, {})
        middlewares = [RecordingMiddleware(
            str(i), {'consume_priority': 0, 'publish_priority': 0}, worker)
            for i in range(3)]
        conn, consumer = yield self.mk_consumer(
            worker=worker, connector_name='foo', middlewares=middlewares)
        consumer.unpause()
        msgs = []
        conn._set_default_endpoint_handler('inbound', msgs.append)
        msg = self.msg_helper.make_inbound("inbound")
        yield self.worker_helper.dispatch_inbound(msg, 'foo')
        record = msgs[0].payload.pop('record')
        self.assertEqual(record,
                         [(str(i), 'inbound', 'foo')
                          for i in range(3)])

    @inlineCallbacks
    def test_middlewares_publish(self):
        worker = yield self.worker_helper.get_worker(DummyWorker, {})
        middlewares = [RecordingMiddleware(
            str(i), {'consume_priority': 0, 'publish_priority': 0}, worker)
            for i in range(3)]
        conn = yield self.mk_connector(
            worker=worker, connector_name='foo', middlewares=middlewares)
        yield conn._setup_publisher('outbound')
        msg = self.msg_helper.make_outbound("outbound")
        yield conn._publish_message('outbound', msg, 'dummy_endpoint')
        msgs = self.worker_helper.get_dispatched_outbound('foo')
        record = msgs[0].payload.pop('record')
        self.assertEqual(record,
                         [[str(i), 'outbound', 'foo']
                          for i in range(2, -1, -1)])

    @inlineCallbacks
    def test_prefetch_count(self):
        conn, consumer = yield self.mk_consumer(prefetch_count=10)
        fake_channel = consumer.channel._fake_channel
        self.assertEqual(fake_channel.qos_prefetch_count, 10)

    @inlineCallbacks
    def test_setup_raises(self):
        conn = yield self.mk_connector()
        self.assertRaises(NotImplementedError, conn.setup)

    @inlineCallbacks
    def test_teardown(self):
        conn, consumer = yield self.mk_consumer()
        self.assertTrue(consumer.keep_consuming)
        yield conn.teardown()
        self.assertFalse(consumer.keep_consuming)

    @inlineCallbacks
    def test_paused(self):
        conn, consumer = yield self.mk_consumer()
        consumer.pause()
        self.assertTrue(conn.paused)
        consumer.unpause()
        self.assertFalse(conn.paused)

    @inlineCallbacks
    def test_pause(self):
        conn, consumer = yield self.mk_consumer()
        consumer.unpause()
        self.assertFalse(consumer.paused)
        conn.pause()
        self.assertTrue(consumer.paused)

    @inlineCallbacks
    def test_unpause(self):
        conn, consumer = yield self.mk_consumer()
        consumer.pause()
        self.assertTrue(consumer.paused)
        conn.unpause()
        self.assertFalse(consumer.paused)

    @inlineCallbacks
    def test_setup_publisher(self):
        conn = yield self.mk_connector(connector_name='foo')
        publisher = yield conn._setup_publisher('outbound')
        self.assertEqual(publisher.routing_key, 'foo.outbound')

    @inlineCallbacks
    def test_setup_consumer(self):
        conn, consumer = yield self.mk_consumer(connector_name='foo')
        self.assertTrue(consumer.paused)
        self.assertEqual(consumer.routing_key, 'foo.inbound')
        self.assertEqual(consumer.message_class, TransportUserMessage)

    @inlineCallbacks
    def test_set_endpoint_handler(self):
        conn, consumer = yield self.mk_consumer(connector_name='foo')
        consumer.unpause()
        msgs = []
        conn._set_endpoint_handler('inbound', msgs.append, 'dummy_endpoint')
        msg = self.msg_helper.make_inbound("inbound")
        msg.set_routing_endpoint('dummy_endpoint')
        yield self.worker_helper.dispatch_inbound(msg, 'foo')
        self.assertEqual(msgs, [msg])

    @inlineCallbacks
    def test_set_none_endpoint_handler(self):
        conn, consumer = yield self.mk_consumer(connector_name='foo')
        consumer.unpause()
        msgs = []
        conn._set_endpoint_handler('inbound', msgs.append, None)
        msg = self.msg_helper.make_inbound("inbound")
        yield self.worker_helper.dispatch_inbound(msg, 'foo')
        self.assertEqual(msgs, [msg])

    @inlineCallbacks
    def test_set_default_endpoint_handler(self):
        conn, consumer = yield self.mk_consumer(connector_name='foo')
        consumer.unpause()
        msgs = []
        conn._set_default_endpoint_handler('inbound', msgs.append)
        msg = self.msg_helper.make_inbound("inbound")
        yield self.worker_helper.dispatch_inbound(msg, 'foo')
        self.assertEqual(msgs, [msg])

    @inlineCallbacks
    def test_publish_message_with_endpoint(self):
        conn = yield self.mk_connector(connector_name='foo')
        yield conn._setup_publisher('outbound')
        msg = self.msg_helper.make_outbound("outbound")
        yield conn._publish_message('outbound', msg, 'dummy_endpoint')
        msgs = self.worker_helper.get_dispatched_outbound('foo')
        self.assertEqual(msgs, [msg])


class TestReceiveInboundConnector(BaseConnectorTestCase):

    connector_class = ReceiveInboundConnector

    @inlineCallbacks
    def test_setup(self):
        conn = yield self.mk_connector(connector_name='foo')
        yield conn.setup()
        conn.unpause()

        with LogCatcher() as lc:
            msg = self.msg_helper.make_inbound("inbound")
            yield self.worker_helper.dispatch_inbound(msg, 'foo')
            [msg_log] = lc.messages()
            self.assertTrue(msg_log.startswith("No inbound handler for 'foo'"))

        with LogCatcher() as lc:
            event = self.msg_helper.make_ack()
            yield self.worker_helper.dispatch_event(event, 'foo')
            [event_log] = lc.messages()
            self.assertTrue(event_log.startswith("No event handler for 'foo'"))

        msg = self.msg_helper.make_outbound("outbound")
        yield conn.publish_outbound(msg)
        msgs = self.worker_helper.get_dispatched_outbound('foo')
        self.assertEqual(msgs, [msg])

    @inlineCallbacks
    def test_default_inbound_handler(self):
        conn = yield self.mk_connector(connector_name='foo', setup=True)
        with LogCatcher() as lc:
            conn.default_inbound_handler(
                self.msg_helper.make_inbound("inbound"))
            [log] = lc.messages()
            self.assertTrue(log.startswith("No inbound handler for 'foo'"))

    @inlineCallbacks
    def test_default_event_handler(self):
        conn = yield self.mk_connector(connector_name='foo', setup=True)
        with LogCatcher() as lc:
            conn.default_event_handler(self.msg_helper.make_ack())
            [log] = lc.messages()
            self.assertTrue(log.startswith("No event handler for 'foo'"))

    @inlineCallbacks
    def test_set_inbound_handler(self):
        msgs = []
        conn = yield self.mk_connector(connector_name='foo', setup=True)
        conn.unpause()
        conn.set_inbound_handler(msgs.append)
        msg = self.msg_helper.make_inbound("inbound")
        yield self.worker_helper.dispatch_inbound(msg, 'foo')
        self.assertEqual(msgs, [msg])

    @inlineCallbacks
    def test_set_default_inbound_handler(self):
        msgs = []
        conn = yield self.mk_connector(connector_name='foo', setup=True)
        conn.unpause()
        conn.set_default_inbound_handler(msgs.append)
        msg = self.msg_helper.make_inbound("inbound")
        yield self.worker_helper.dispatch_inbound(msg, 'foo')
        self.assertEqual(msgs, [msg])

    @inlineCallbacks
    def test_set_event_handler(self):
        msgs = []
        conn = yield self.mk_connector(connector_name='foo', setup=True)
        conn.unpause()
        conn.set_event_handler(msgs.append)
        msg = self.msg_helper.make_ack()
        yield self.worker_helper.dispatch_event(msg, 'foo')
        self.assertEqual(msgs, [msg])

    @inlineCallbacks
    def test_set_default_event_handler(self):
        msgs = []
        conn = yield self.mk_connector(connector_name='foo', setup=True)
        conn.unpause()
        conn.set_default_event_handler(msgs.append)
        msg = self.msg_helper.make_ack()
        yield self.worker_helper.dispatch_event(msg, 'foo')
        self.assertEqual(msgs, [msg])

    @inlineCallbacks
    def test_publish_outbound(self):
        conn = yield self.mk_connector(connector_name='foo', setup=True)
        msg = self.msg_helper.make_outbound("outbound")
        yield conn.publish_outbound(msg)
        msgs = self.worker_helper.get_dispatched_outbound('foo')
        self.assertEqual(msgs, [msg])

    @inlineCallbacks
    def test_inbound_handler_ignore_message(self):
        def im_handler(msg):
            raise IgnoreMessage()

        conn = yield self.mk_connector(connector_name='foo', setup=True)
        conn.unpause()
        conn.set_default_inbound_handler(im_handler)
        msg = self.msg_helper.make_inbound("inbound")
        with LogCatcher() as lc:
            yield self.worker_helper.dispatch_inbound(msg, 'foo')
            [log] = lc.messages()
            self.assertTrue(log.startswith(
                "Ignoring msg due to IgnoreMessage(): ``.
    """

    def __repr__(self):
        return 'DEFAULT'

DEFAULT = _Default()


class IHelper(Interface):
    """
    Interface for test helpers.

    This specifies a standard setup and cleanup mechanism used by test cases
    that implement the :class:`IHelperEnabledTestCase` interface.

    There are no interface restrictions on the constructor of a helper.
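
    A minimal sketch of a conforming helper::

        class NoopHelper(object):
            implements(IHelper)

            def setup(self):
                pass

            def cleanup(self):
                pass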
    """

    def setup(*args, **kwargs):
        """
        Perform potentially async helper setup.

        This may return a deferred for async setup or block for sync setup. All
        helpers must implement this even if it does nothing.

        If the setup is optional but commonly used, this method can take flags
        to perform or suppress all or part of it as required.
        """

    def cleanup():
        """
        Clean up any resources created by this helper.

        This may return a deferred for async cleanup or block for sync cleanup.
        All helpers must implement this even if it does nothing.
        """


class IHelperEnabledTestCase(Interface):
    """
    Interface for test cases that use helpers.

    This specifies a standard mechanism for managing setup and cleanup of
    helper classes that implement the :class:`IHelper` interface.
    """

    def add_helper(helper_object, *args, **kwargs):
        """
        Register cleanup and perform setup for a helper object.

        This should call ``helper_object.setup(*args, **kwargs)`` and
        ``self.add_cleanup(helper_object.cleanup)`` or an equivalent.

        Returns the ``helper_object`` passed in or a :class:`Deferred` if
        setup is async.
        """


def proxyable(func):
    """
    Mark a method as being suitable for automatic proxy generation.

    See :func:`generate_proxies` for usage.
    """
    func.proxyable = True
    return func


def generate_proxies(target, source):
    """
    Generate proxies on ``target`` for proxyable methods on ``source``.

    This is useful for wrapping helper objects in higher-level helpers or
    extending a helper to provide extra functionality without having to resort
    to subclassing.

    The "proxying" is actually just copying the proxyable attribute onto the
    target.

    >>> class AddHelper(object):
    ...     def __init__(self, number):
    ...         self._number = number
    ...
    ...     @proxyable
    ...     def add_number(self, number):
    ...         return self._number + number

    >>> class OtherHelper(object):
    ...     def __init__(self, number):
    ...         self._adder = AddHelper(number)
    ...         generate_proxies(self, self._adder)
    ...
    ...     @proxyable
    ...     def say_hello(self):
    ...         return "hello"

    >>> other_helper = OtherHelper(3)
    >>> other_helper.say_hello()
    'hello'
    >>> other_helper.add_number(2)
    5
    """

    for name in dir(source):
        attribute = getattr(source, name)
        if not getattr(attribute, 'proxyable', False):
            continue

        if hasattr(target, name):
            raise Exception(
                'Attribute already exists: %s' % (name,))
        setattr(target, name, attribute)


def success_result_of(d):
    """
    We can't necessarily use TestCase.successResultOf because our Twisted might
    not be new enough. This is a standalone copy with some minor message
    differences.
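
    For example:

    >>> from twisted.internet.defer import succeed
    >>> success_result_of(succeed(42))
    42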
    """
    results = []
    d.addBoth(results.append)
    if not results:
        raise FailTest("No result available for deferred: %r" % (d,))
    if isinstance(results[0], Failure):
        raise FailTest("Expected success from deferred %r, got failure: %r" % (
            d, results[0]))
    return results[0]


def get_timeout():
    """
    Look up the test timeout in the ``VUMI_TEST_TIMEOUT`` environment variable.

    A default of 5 seconds is used if there isn't one there.
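
    For example, to run tests under Twisted's ``trial`` with a longer
    timeout::

        VUMI_TEST_TIMEOUT=30 trial vumi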
    """
    timeout_str = os.environ.get('VUMI_TEST_TIMEOUT', '5')
    return float(timeout_str)


def get_stack_trace(exclude_last=0):
    """
    Get a stack trace that can be stored and referred to later.

    The inside of this function is excluded from the stack trace, because it's
    not relevant. Additionally, all entries prior to the first occurrence of
    "twisted/trial/_asynctest.py" or "django/test/testcases.py" are removed to
    avoid unnecessary test runner noise.

    :param int exclude_last:
        Number of entries to remove from the end of the stack trace. Use this
        to get rid of wrapper functions or implementation details irrelevant to
        the purpose of the stack trace.

    :return:
        A list of strings, each representing a stack frame, in the same format
        as ``traceback.format_stack()``.
    """
    stack = traceback.format_stack()

    def is_boring(entry):
        return all(pathfrag not in entry for pathfrag in [
            "twisted/trial/_asynctest.py",
            "django/test/testcases.py",
        ])

    filtered_stack = list(dropwhile(is_boring, stack))
    if filtered_stack:
        # We haven't accidentally devoured everything.
        lines_removed = len(stack) - len(filtered_stack)
        stack = filtered_stack
        stack.insert(0, "%s test runner lines removed.\n" % lines_removed)
    # Remove the current stack frame and any extra stack frames we've been
    # asked to remove.
    return stack[:-(exclude_last + 1)]


class VumiTestCase(TestCase):
    """
    Base test case class for all things vumi-related.

    This is a subclass of :class:`twisted.trial.unittest.TestCase` with a small
    number of additional features:

    * It implements :class:`IHelperEnabledTestCase` to make using helpers
      easier. (See :meth:`add_helper`.)

    * :attr:`timeout` is set to a default value of ``5`` and can be overridden
      by setting the ``VUMI_TEST_TIMEOUT`` environment variable. (Longer
      timeouts are more reliable for continuous integration builds, shorter
      ones are less painful for local development.)

    * :meth:`add_cleanup` provides an alternative mechanism for specifying
      cleanup in the same place as the creation of the thing that needs to be
      cleaned up.

    .. note::

       While this class does not have a :meth:`setUp` method (thus avoiding the
       need for subclasses to call it), it *does* have a :meth:`tearDown`
       method. :meth:`add_cleanup` should be used in subclasses instead of
       overriding :meth:`tearDown`.
    """

    implements(IHelperEnabledTestCase)

    timeout = get_timeout()
    reactor_check_interval = 0.01  # 10ms, no science behind this number.
    reactor_check_iterations = 100  # No science behind this number either.

    _cleanup_funcs = None

    @inlineCallbacks
    def tearDown(self):
        """
        Run any cleanup functions registered with :meth:`add_cleanup`.
        """
        # Run any cleanup code we've registered with .add_cleanup().
        # We do this ourselves instead of using trial's .addCleanup() because
        # that doesn't have timeouts applied to it.
        if self._cleanup_funcs is not None:
            for cleanup, args, kw in reversed(self._cleanup_funcs):
                yield cleanup(*args, **kw)
        yield self._check_reactor_things()

    @inlineCallbacks
    def _check_reactor_things(self):
        """
        Poll the reactor for unclosed connections and wait for them to close.

        Properly waiting for all connections to finish closing requires hooking
        into :meth:`Protocol.connectionLost` in both client and server. Since
        this isn't practical in all cases, we check the reactor for any open
        connections and wait a bit for them to finish closing if we find any.

        NOTE: This will only wait for connections that close on their own. Any
              connections that have been left open will stay open (unless they
              time out or something) and will leave the reactor dirty after we
              stop waiting.
        """
        from twisted.internet import reactor
        # Give the reactor a chance to get clean.
        yield deferLater(reactor, 0, lambda: None)

        for i in range(self.reactor_check_iterations):
            # There are some internal readers that we want to ignore.
            # Unfortunately they're private.
            internal_readers = getattr(reactor, '_internalReaders', set())
            selectables = set(reactor.getReaders() + reactor.getWriters())
            if not (selectables - internal_readers):
                # The reactor's clean, let's go home.
                return

            # We haven't gone home, so wait a bit for selectables to go away.
            yield deferLater(
                reactor, self.reactor_check_interval, lambda: None)

    def add_cleanup(self, func, *args, **kw):
        """
        Register a cleanup function to be called at teardown time.

        :param callable func:
            The callable object to call at cleanup time. This callable may
            return a :class:`Deferred`, in which case cleanup will continue
            after it fires.
        :param \*args: Passed to ``func`` when it is called.
        :param \**kw: Passed to ``func`` when it is called.

        .. note::
           This method should be used in place of the inherited
           :meth:`addCleanup` method, because the latter doesn't apply timeouts
           to cleanup functions.
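
        For example (``make_thing`` is a hypothetical factory for a resource
        that needs explicit cleanup)::

            def test_something(self):
                thing = make_thing()
                self.add_cleanup(thing.close)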
        """
        if self._cleanup_funcs is None:
            self._cleanup_funcs = []
        self._cleanup_funcs.append((func, args, kw))

    def add_helper(self, helper_object, *args, **kw):
        """
        Perform setup and register cleanup for the given helper object.

        :param helper_object:
            Helper object to add. ``helper_object`` must provide the
            :class:`IHelper` interface.
        :param \*args: Passed to :meth:`helper_object.setup` when it is called.
        :param \**kw: Passed to :meth:`helper_object.setup` when it is called.

        :returns:
            Either ``helper_object`` or a :class:`Deferred` that fires with it.

        If :meth:`helper_object.setup` returns a :class:`Deferred`, this method
        also returns a :class:`Deferred`.

        Example usage assuming ``@inlineCallbacks``:

        >>> @inlineCallbacks
        ... def test_foo(self):
        ...     msg_helper = yield self.add_helper(MessageHelper())
        ...     msg_helper.make_inbound("foo")

        Example usage assuming non-async setup:

        >>> def test_bar(self):
        ...     msg_helper = self.add_helper(MessageHelper())
        ...     msg_helper.make_inbound("bar")

        """

        if not IHelper.providedBy(helper_object):
            raise ValueError(
                "Helper object does not provide the IHelper interface: %s" % (
                    helper_object,))
        self.add_cleanup(helper_object.cleanup)
        return maybe_async_return(
            helper_object, helper_object.setup(*args, **kw))

    def _runFixturesAndTest(self, result):
        """
        Override trial's ``_runFixturesAndTest()`` method to detect test
        methods that are generator functions, indicating a missing
        ``@inlineCallbacks`` decorator.

        NOTE: This should probably be removed when
              https://twistedmatrix.com/trac/ticket/3917 is merged and the next
              Twisted version (probably 14.0) is released.
        """
        method = getattr(self, self._testMethodName)
        if method.func_code.co_flags & CO_GENERATOR:
            # We have a generator that isn't wrapped in @inlineCallbacks
            e = ValueError(
                "Test method is a generator. Missing @inlineCallbacks?")
            result.addError(self, Failure(e))
            return
        return super(VumiTestCase, self)._runFixturesAndTest(result)


class MessageHelper(object):
    """
    Test helper for constructing various messages.

    This helper does no setup or cleanup. It takes the following parameters,
    which are used as defaults for message fields:

    :param str transport_name:
        Default value for ``transport_name`` on all messages.

    :param str transport_type:
        Default value for ``transport_type`` on all messages.

    :param str mobile_addr:
        Default value for ``from_addr`` on inbound messages and ``to_addr`` on
        outbound messages.

    :param str transport_addr:
        Default value for ``to_addr`` on inbound messages and ``from_addr`` on
        outbound messages.
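
    For example, in a :class:`VumiTestCase`::

        msg_helper = self.add_helper(MessageHelper(transport_name='sphex'))
        msg = msg_helper.make_inbound("hello")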
    """

    implements(IHelper)

    def __init__(self, transport_name='sphex', transport_type='sms',
                 mobile_addr='+41791234567', transport_addr='9292'):
        self.transport_name = transport_name
        self.transport_type = transport_type
        self.mobile_addr = mobile_addr
        self.transport_addr = transport_addr

    def setup(self):
        pass

    def cleanup(self):
        pass

    @proxyable
    def make_inbound(self, content, from_addr=DEFAULT, to_addr=DEFAULT, **kw):
        """
        Construct an inbound :class:`~vumi.message.TransportUserMessage`.

        This is a convenience wrapper around :meth:`make_user_message` and just
        sets ``to_addr`` and ``from_addr`` appropriately for an inbound
        message.
        """
        if from_addr is DEFAULT:
            from_addr = self.mobile_addr
        if to_addr is DEFAULT:
            to_addr = self.transport_addr
        return self.make_user_message(content, from_addr, to_addr, **kw)

    @proxyable
    def make_outbound(self, content, from_addr=DEFAULT, to_addr=DEFAULT, **kw):
        """
        Construct an outbound :class:`~vumi.message.TransportUserMessage`.

        This is a convenience wrapper around :meth:`make_user_message` and just
        sets ``to_addr`` and ``from_addr`` appropriately for an outbound
        message.
        """
        if from_addr is DEFAULT:
            from_addr = self.transport_addr
        if to_addr is DEFAULT:
            to_addr = self.mobile_addr
        return self.make_user_message(content, from_addr, to_addr, **kw)

    @proxyable
    def make_user_message(self, content, from_addr, to_addr, group=None,
                          session_event=None, transport_type=DEFAULT,
                          transport_name=DEFAULT, transport_metadata=DEFAULT,
                          helper_metadata=DEFAULT, endpoint=DEFAULT, **kw):
        """
        Construct a :class:`~vumi.message.TransportUserMessage`.

        This method is the underlying implementation for :meth:`make_inbound`
        and :meth:`make_outbound` and those should be used instead where they
        apply.

        The only real difference between using this method and constructing a
        message object directly is that this method provides sensible defaults
        for most fields and sets the routing endpoint (if provided) in a more
        convenient way.

        The following parameters are mandatory:

        :param str content: Message ``content`` field.
        :param str from_addr: Message ``from_addr`` field.
        :param str to_addr: Message ``to_addr`` field.

        The following parameters override default values for the message fields
        of the same name:

        :param str group: Default ``None``.
        :param str session_event: Default ``None``.
        :param str transport_type: Default :attr:`transport_type`.
        :param str transport_name: Default :attr:`transport_name`.
        :param dict transport_metadata: Default ``{}``.
        :param dict helper_metadata: Default ``{}``.

        The following parameter is special:

        :param str endpoint:
            If specified, the routing endpoint on the message is set by calling
            :meth:`TransportUserMessage.set_routing_endpoint`.

        All other keyword args are passed to the
        :class:`~vumi.message.TransportUserMessage` constructor.
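
        For example (a sketch; the session event constant is assumed to be
        defined on :class:`~vumi.message.TransportUserMessage`)::

            msg = msg_helper.make_user_message(
                "hello", from_addr='+27831234567', to_addr='1234',
                session_event=TransportUserMessage.SESSION_NEW)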
        """
        if transport_type is DEFAULT:
            transport_type = self.transport_type
        if helper_metadata is DEFAULT:
            helper_metadata = {}
        if transport_metadata is DEFAULT:
            transport_metadata = {}
        if transport_name is DEFAULT:
            transport_name = self.transport_name
        msg = TransportUserMessage(
            from_addr=from_addr,
            to_addr=to_addr,
            group=group,
            transport_name=transport_name,
            transport_type=transport_type,
            transport_metadata=transport_metadata,
            helper_metadata=helper_metadata,
            content=content,
            session_event=session_event,
            **kw)
        if endpoint is not DEFAULT:
            msg.set_routing_endpoint(endpoint)
        return msg

    @proxyable
    def make_event(self, event_type, user_message_id, transport_type=DEFAULT,
                   transport_name=DEFAULT, transport_metadata=DEFAULT,
                   endpoint=DEFAULT, **kw):
        """
        Construct a :class:`~vumi.message.TransportEvent`.

        This method is the underlying implementation for :meth:`make_ack`,
        :meth:`make_nack` and :meth:`make_delivery_report`. Those should
        be used instead where they apply.

        The only real difference between using this method and constructing an
        event object directly is that this method provides sensible defaults
        for most fields and sets the routing endpoint (if provided) in a more
        convenient way.

        The following parameters are mandatory:

        :param str event_type: Event ``event_type`` field.
        :param str user_message_id: Event ``user_message_id`` field.

        Any fields required by a particular event type (such as
        ``sent_message_id`` for ``ack`` events) are also mandatory.

        The following parameters override default values for the event fields
        of the same name:

        :param str transport_type: Default :attr:`transport_type`.
        :param str transport_name: Default :attr:`transport_name`.
        :param dict transport_metadata: Default ``{}``.

        The following parameter is special:

        :param str endpoint:
            If specified, the routing endpoint on the event is set by calling
            :meth:`TransportUserMessage.set_routing_endpoint`.

        All other keyword args are passed to the
        :class:`~vumi.message.TransportEvent` constructor.
        """
        if transport_type is DEFAULT:
            transport_type = self.transport_type
        if transport_name is DEFAULT:
            transport_name = self.transport_name
        if transport_metadata is DEFAULT:
            transport_metadata = {}
        msg = TransportEvent(
            event_type=event_type,
            user_message_id=user_message_id,
            transport_name=transport_name,
            transport_type=transport_type,
            transport_metadata=transport_metadata,
            **kw)
        if endpoint is not DEFAULT:
            msg.set_routing_endpoint(endpoint)
        return msg

    @proxyable
    def make_ack(self, msg=None, sent_message_id=DEFAULT, **kw):
        """
        Construct an 'ack' :class:`~vumi.message.TransportEvent`.

        :param msg:
            :class:`~vumi.message.TransportUserMessage` instance the event is
            for. If ``None``, this method will call :meth:`make_outbound` to
            get one.
        :param str sent_message_id:
            If this isn't provided, ``msg['message_id']`` will be used.

        All remaining keyword params are passed to :meth:`make_event`.
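
        For example::

            msg = msg_helper.make_outbound("hi")
            ack = msg_helper.make_ack(msg)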
        """
        if msg is None:
            msg = self.make_outbound("for ack")
        user_message_id = msg['message_id']
        if sent_message_id is DEFAULT:
            sent_message_id = user_message_id
        return self.make_event(
            'ack', user_message_id, sent_message_id=sent_message_id, **kw)

    @proxyable
    def make_nack(self, msg=None, nack_reason=DEFAULT, **kw):
        """
        Construct a 'nack' :class:`~vumi.message.TransportEvent`.

        :param msg:
            :class:`~vumi.message.TransportUserMessage` instance the event is
            for. If ``None``, this method will call :meth:`make_outbound` to
            get one.
        :param str nack_reason:
            If this isn't provided, a suitable excuse will be used.

        All remaining keyword params are passed to :meth:`make_event`.
        """
        if msg is None:
            msg = self.make_outbound("for nack")
        user_message_id = msg['message_id']
        if nack_reason is DEFAULT:
            nack_reason = "sunspots"
        return self.make_event(
            'nack', user_message_id, nack_reason=nack_reason, **kw)

    @proxyable
    def make_delivery_report(self, msg=None, delivery_status=DEFAULT, **kw):
        """
        Construct a 'delivery_report' :class:`~vumi.message.TransportEvent`.

        :param msg:
            :class:`~vumi.message.TransportUserMessage` instance the event is
            for. If ``None``, this method will call :meth:`make_outbound` to
            get one.
        :param str delivery_status:
            If this isn't provided, ``"delivered"`` will be used.

        All remaining keyword params are passed to :meth:`make_event`.
        """
        if msg is None:
            msg = self.make_outbound("for delivery_report")
        user_message_id = msg['message_id']
        if delivery_status is DEFAULT:
            delivery_status = "delivered"
        return self.make_event(
            'delivery_report', user_message_id,
            delivery_status=delivery_status, **kw)

    @proxyable
    def make_reply(self, msg, content, **kw):
        """
        Construct a reply :class:`~vumi.message.TransportUserMessage`.

        This literally just calls ``msg.reply(content, **kw)``. It is included
        for completeness and symmetry with
        :meth:`MessageDispatchHelper.make_dispatch_reply`.
        """
        return msg.reply(content, **kw)

    @proxyable
    def make_status(self, **kw):
        """
        Construct a :class:`~vumi.message.TransportStatus`.
        """
        return TransportStatus(**kw)


def _start_and_return_worker(worker):
    return worker.startWorker().addCallback(lambda r: worker)


class WorkerHelper(object):
    """
    Test helper for creating workers and dispatching messages.

    This helper does no setup, but it waits for pending message deliveries and
    then stops all workers it knows about during cleanup. It takes the following
    parameters:

    :param str connector_name:
        Default value for ``connector_name`` on all message broker operations.
        If ``None``, the connector name must be provided for each operation.

    :param broker:
        The message broker to use internally. This should be an instance of
        :class:`~vumi.tests.fake_amqp.FakeAMQPBroker` if it is provided, but
        most of the time the default of ``None`` should be used to have the
        helper create its own broker.
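
    A minimal usage sketch in an ``@inlineCallbacks`` test (``MyWorker`` is a
    hypothetical worker class)::

        worker_helper = self.add_helper(WorkerHelper('my_connector'))
        worker = yield worker_helper.get_worker(MyWorker, {})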
    """

    implements(IHelper)

    def __init__(self, connector_name=None, broker=None,
                 status_connector_name=None):
        self._connector_name = connector_name
        self._status_connector_name = status_connector_name
        self.broker = broker if broker is not None else FakeAMQPBroker()
        self._workers = []

    def setup(self):
        pass

    @inlineCallbacks
    def cleanup(self):
        """
        Wait for any pending message deliveries and stop all workers.
        """
        yield self.broker.wait_delivery()
        for worker in self._workers:
            yield worker.stopWorker()

    @proxyable
    def cleanup_worker(self, worker):
        """
        Clean up a particular worker manually and remove it from the helper's
        cleanup list. This should only be called with workers that are already
        in the helper's cleanup list.
        """
        self._workers.remove(worker)
        return worker.stopWorker()

    @classmethod
    def get_fake_amqp_client(cls, broker):
        """
        Wrap a fake broker in a fake client.

        The broker parameter is mandatory because it's important that cleanup
        happen. If ``None`` is passed in explicitly, a new broker object will
        be created.
        """
        spec = get_spec(vumi_resource_path("amqp-spec-0-8.xml"))
        return FakeAMQClient(spec, {}, broker)

    @classmethod
    def get_worker_raw(cls, worker_class, config, broker=None):
        """
        Create and return an instance of a vumi worker.

        This doesn't start the worker and it doesn't add it to any cleanup
        machinery. In most cases, you want :meth:`get_worker` instead.
        """

        # When possible, enable heartbeat setup in tests by making sure
        # worker_name is set.
        if (config is not None) and ('worker_name' not in config):
            config['worker_name'] = "unnamed"

        worker = worker_class({}, config)
        worker._amqp_client = cls.get_fake_amqp_client(broker)
        return worker

    @proxyable
    def get_worker(self, worker_class, config, start=True):
        """
        Create and return an instance of a vumi worker.

        :param worker_class: The worker class to instantiate.
        :param config: Config dict.
        :param start:
            ``True`` to start the worker (default), ``False`` otherwise.
        """
        worker = self.get_worker_raw(worker_class, config, self.broker)

        self._workers.append(worker)
        d = succeed(worker)
        if start:
            d.addCallback(_start_and_return_worker)
        return d

    def _rkey(self, connector_name, name):
        if connector_name is None:
            connector_name = self._connector_name
        return '.'.join((connector_name, name))

    @proxyable
    def get_dispatched(self, connector_name, name, message_class):
        """
        Get messages dispatched to a routing key.

        The more specific :meth:`get_dispatched_events`,
        :meth:`get_dispatched_inbound`, and :meth:`get_dispatched_outbound`
        wrapper methods should be used instead where they apply.

        :param str connector_name:
            The connector name, which is used as the routing key prefix.

        :param str name:
            The routing key suffix, generally ``"event"``, ``"inbound"``, or
            ``"outbound"``.

        :param message_class:
            The message class to wrap the raw message data in. This should
            probably be :class:`~vumi.message.TransportUserMessage` or
            :class:`~vumi.message.TransportEvent`.
        """
        rkey = self._rkey(connector_name, name)
        msgs = self.broker.get_dispatched('vumi', rkey)
        return [message_class.from_json(msg.body) for msg in msgs]

    def _wait_for_dispatched(self, connector_name, name, amount):
        rkey = self._rkey(connector_name, name)
        if amount is not None:
            # The broker knows how to wait for a specific number of messages.
            return self.broker.wait_messages('vumi', rkey, amount)
        # Wait for delivery to finish, then return whatever we have.
        return self.broker.wait_delivery().addCallback(
            lambda _: self.broker.get_messages('vumi', rkey))

    @proxyable
    def clear_all_dispatched(self):
        """
        Clear all dispatched messages from the broker.
        """
        self.broker.clear_messages('vumi')
        self.broker.clear_messages('vumi.metrics')

    def _clear_dispatched(self, connector_name, name):
        rkey = self._rkey(connector_name, name)
        return self.broker.clear_messages('vumi', rkey)

    @proxyable
    def get_dispatched_events(self, connector_name=None):
        """
        Get events dispatched to a connector.

        :param str connector_name:
            Connector name. If ``None``, the default connector name for the
            helper instance will be used.

        :returns: A list of :class:`~vumi.message.TransportEvent` instances.
        """
        return self.get_dispatched(connector_name, 'event', TransportEvent)

    @proxyable
    def get_dispatched_inbound(self, connector_name=None):
        """
        Get inbound messages dispatched to a connector.

        :param str connector_name:
            Connector name. If ``None``, the default connector name for the
            helper instance will be used.

        :returns:
            A list of :class:`~vumi.message.TransportUserMessage` instances.
        """
        return self.get_dispatched(
            connector_name, 'inbound', TransportUserMessage)

    @proxyable
    def get_dispatched_outbound(self, connector_name=None):
        """
        Get outbound messages dispatched to a connector.

        :param str connector_name:
            Connector name. If ``None``, the default connector name for the
            helper instance will be used.

        :returns:
            A list of :class:`~vumi.message.TransportUserMessage` instances.
        """
        return self.get_dispatched(
            connector_name, 'outbound', TransportUserMessage)

    @proxyable
    def get_dispatched_statuses(self, connector_name=None):
        """
        Get statuses dispatched to a connector.

        :param str connector_name:
            Connector name. If ``None``, the default status connector name for
            the helper instance will be used.

        :returns:
            A list of :class:`~vumi.message.TransportStatus` instances.
        """
        if connector_name is None:
            connector_name = self._status_connector_name

        return self.get_dispatched(
            connector_name, 'status', TransportStatus)

    @proxyable
    def wait_for_dispatched_events(self, amount=None, connector_name=None):
        """
        Wait for events dispatched to a connector.

        :param int amount:
            Number of messages to wait for. If ``None``, this will wait for the
            end of the current delivery run instead of a specific number of
            messages.

        :param str connector_name:
            Connector name. If ``None``, the default connector name for the
            helper instance will be used.

        :returns:
            A :class:`Deferred` that fires with a list of
            :class:`~vumi.message.TransportEvent` instances.
        """
        d = self._wait_for_dispatched(connector_name, 'event', amount)
        d.addCallback(lambda msgs: [
            TransportEvent(**msg.payload) for msg in msgs])
        return d

    @proxyable
    def wait_for_dispatched_inbound(self, amount=None, connector_name=None):
        """
        Wait for inbound messages dispatched to a connector.

        :param int amount:
            Number of messages to wait for. If ``None``, this will wait for the
            end of the current delivery run instead of a specific number of
            messages.

        :param str connector_name:
            Connector name. If ``None``, the default connector name for the
            helper instance will be used.

        :returns:
            A :class:`Deferred` that fires with a list of
            :class:`~vumi.message.TransportUserMessage` instances.
        """
        d = self._wait_for_dispatched(connector_name, 'inbound', amount)
        d.addCallback(lambda msgs: [
            TransportUserMessage(**msg.payload) for msg in msgs])
        return d

    @proxyable
    def wait_for_dispatched_outbound(self, amount=None, connector_name=None):
        """
        Wait for outbound messages dispatched to a connector.

        :param int amount:
            Number of messages to wait for. If ``None``, this will wait for the
            end of the current delivery run instead of a specific number of
            messages.

        :param str connector_name:
            Connector name. If ``None``, the default connector name for the
            helper instance will be used.

        :returns:
            A :class:`Deferred` that fires with a list of
            :class:`~vumi.message.TransportUserMessage` instances.
        """
        d = self._wait_for_dispatched(connector_name, 'outbound', amount)
        d.addCallback(lambda msgs: [
            TransportUserMessage(**msg.payload) for msg in msgs])
        return d

    @proxyable
    def wait_for_dispatched_statuses(self, amount=None, connector_name=None):
        """
        Wait for statuses dispatched to a connector.

        :param int amount:
            Number of messages to wait for. If ``None``, this will wait for the
            end of the current delivery run instead of a specific number of
            messages.

        :param str connector_name:
            Connector name. If ``None``, the default status connector name for
            the helper instance will be used.

        :returns:
            A :class:`Deferred` that fires with a list of
            :class:`~vumi.message.TransportStatus` instances.
        """
        if connector_name is None:
            connector_name = self._status_connector_name

        d = self._wait_for_dispatched(connector_name, 'status', amount)
        d.addCallback(lambda msgs: [
            TransportStatus(**msg.payload) for msg in msgs])
        return d

    @proxyable
    def clear_dispatched_events(self, connector_name=None):
        """
        Clear dispatched events for a connector.

        :param str connector_name:
            Connector name. If ``None``, the default connector name for the
            helper instance will be used.
        """
        return self._clear_dispatched(connector_name, 'event')

    @proxyable
    def clear_dispatched_inbound(self, connector_name=None):
        """
        Clear dispatched inbound messages for a connector.

        :param str connector_name:
            Connector name. If ``None``, the default connector name for the
            helper instance will be used.
        """
        return self._clear_dispatched(connector_name, 'inbound')

    @proxyable
    def clear_dispatched_outbound(self, connector_name=None):
        """
        Clear dispatched outbound messages for a connector.

        :param str connector_name:
            Connector name. If ``None``, the default connector name for the
            helper instance will be used.
        """
        return self._clear_dispatched(connector_name, 'outbound')

    @proxyable
    def clear_dispatched_statuses(self, connector_name=None):
        """
        Clear dispatched statuses for a connector.

        :param str connector_name:
            Connector name. If ``None``, the default status connector name for
            the helper instance will be used.
        """
        if connector_name is None:
            connector_name = self._status_connector_name

        return self._clear_dispatched(connector_name, 'status')

    @proxyable
    def dispatch_raw(self, routing_key, message, exchange='vumi'):
        """
        Dispatch a message to the specified routing key.

        The more specific :meth:`dispatch_inbound`, :meth:`dispatch_outbound`,
        and :meth:`dispatch_event` wrapper methods should be used instead where
        they apply.

        :param str routing_key:
            Routing key to dispatch the message to.

        :param message:
            Message to dispatch.

        :param str exchange:
            AMQP exchange to dispatch the message to. Defaults to ``"vumi"``.

        :returns:
            A :class:`Deferred` that fires when all messages have been
            delivered.
        """
        self.broker.publish_message(exchange, routing_key, message)
        return self.kick_delivery()

    @proxyable
    def dispatch_inbound(self, message, connector_name=None):
        """
        Dispatch an inbound message.

        :param message:
            Message to dispatch. Should be a
            :class:`~vumi.message.TransportUserMessage` instance.

        :param str connector_name:
            Connector name. If ``None``, the default connector name for the
            helper instance will be used.

        :returns:
            A :class:`Deferred` that fires when all messages have been
            delivered.
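
        For example (assuming a ``msg_helper`` for building the message)::

            msg = msg_helper.make_inbound("hello")
            yield worker_helper.dispatch_inbound(msg, 'my_connector')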
        """
        return self.dispatch_raw(
            self._rkey(connector_name, 'inbound'), message)

    @proxyable
    def dispatch_outbound(self, message, connector_name=None):
        """
        Dispatch an outbound message.

        :param message:
            Message to dispatch. Should be a
            :class:`~vumi.message.TransportUserMessage` instance.

        :param str connector_name:
            Connector name. If ``None``, the default connector name for the
            helper instance will be used.

        :returns:
            A :class:`Deferred` that fires when all messages have been
            delivered.
        """
        return self.dispatch_raw(
            self._rkey(connector_name, 'outbound'), message)

    @proxyable
    def dispatch_event(self, message, connector_name=None):
        """
        Dispatch an event.

        :param message:
            Message to dispatch. Should be a
            :class:`~vumi.message.TransportEvent` instance.

        :param str connector_name:
            Connector name. If ``None``, the default connector name for the
            helper instance will be used.

        :returns:
            A :class:`Deferred` that fires when all messages have been
            delivered.
        """
        return self.dispatch_raw(
            self._rkey(connector_name, 'event'), message)

    @proxyable
    def dispatch_status(self, message, connector_name=None):
        """
        Dispatch a status.

        :param message:
            Message to dispatch. Should be a
            :class:`~vumi.message.TransportStatus` instance.

        :param str connector_name:
            Connector name. If ``None``, the default status connector name for
            the helper instance will be used.

        :returns:
            A :class:`Deferred` that fires when all messages have been
            delivered.
        """
        if connector_name is None:
            connector_name = self._status_connector_name

        return self.dispatch_raw(
            self._rkey(connector_name, 'status'), message)

    @proxyable
    def kick_delivery(self):
        """
        Trigger delivery of messages by the broker.

        This is generally called internally by anything that sends a message.

        :returns:
            A :class:`Deferred` that fires when all messages have been
            delivered.
        """
        return self.broker.kick_delivery()

    @proxyable
    def get_dispatched_metrics(self):
        """
        Get dispatched metrics.

        The list of datapoints from each dispatched metrics message is
        returned.
        """
        msgs = self.broker.get_dispatched('vumi.metrics', 'vumi.metrics')
        return [json.loads(msg.body)['datapoints'] for msg in msgs]

    @proxyable
    def wait_for_dispatched_metrics(self):
        """
        Get dispatched metrics after waiting for any pending deliveries.

        The list of datapoints from each dispatched metrics message is
        returned.
        """
        return self.broker.wait_delivery().addCallback(
            lambda _: self.get_dispatched_metrics())

    @proxyable
    def clear_dispatched_metrics(self):
        """
        Clear dispatched metrics messages from the broker.
        """
        self.broker.clear_messages('vumi.metrics')


class MessageDispatchHelper(object):
    """
    Helper for creating and immediately dispatching messages.

    This builds on top of :class:`MessageHelper` and :class:`WorkerHelper`.

    It does not allow dispatching to nonstandard connectors. If you need to do
    that, either use :class:`MessageHelper` and :class:`WorkerHelper` directly
    or build a second :class:`MessageDispatchHelper` with a second
    :class:`WorkerHelper`.

    :param msg_helper: A :class:`MessageHelper` instance.
    :param worker_helper: A :class:`WorkerHelper` instance.
    """

    implements(IHelper)

    def __init__(self, msg_helper, worker_helper):
        self.msg_helper = msg_helper
        self.worker_helper = worker_helper

    def setup(self):
        pass

    def cleanup(self):
        pass

    @proxyable
    def make_dispatch_inbound(self, *args, **kw):
        """
        Construct and dispatch an inbound message.

        This is a wrapper around :meth:`MessageHelper.make_inbound` (to which
        all parameters are passed) and :meth:`WorkerHelper.dispatch_inbound`.

        :returns:
            A :class:`Deferred` that fires with the constructed message once it
            has been dispatched.
        """
        msg = self.msg_helper.make_inbound(*args, **kw)
        d = self.worker_helper.dispatch_inbound(msg)
        return d.addCallback(lambda r: msg)

    @proxyable
    def make_dispatch_outbound(self, *args, **kw):
        """
        Construct and dispatch an outbound message.

        This is a wrapper around :meth:`MessageHelper.make_outbound` (to which
        all parameters are passed) and :meth:`WorkerHelper.dispatch_outbound`.

        :returns:
            A :class:`Deferred` that fires with the constructed message once it
            has been dispatched.
        """
        msg = self.msg_helper.make_outbound(*args, **kw)
        d = self.worker_helper.dispatch_outbound(msg)
        return d.addCallback(lambda r: msg)

    @proxyable
    def make_dispatch_ack(self, *args, **kw):
        """
        Construct and dispatch an ack event.

        This is a wrapper around :meth:`MessageHelper.make_ack` (to which all
        parameters are passed) and :meth:`WorkerHelper.dispatch_event`.

        :returns:
            A :class:`Deferred` that fires with the constructed event once it
            has been dispatched.
        """
        msg = self.msg_helper.make_ack(*args, **kw)
        d = self.worker_helper.dispatch_event(msg)
        return d.addCallback(lambda r: msg)

    @proxyable
    def make_dispatch_nack(self, *args, **kw):
        """
        Construct and dispatch a nack event.

        This is a wrapper around :meth:`MessageHelper.make_nack` (to which all
        parameters are passed) and :meth:`WorkerHelper.dispatch_event`.

        :returns:
            A :class:`Deferred` that fires with the constructed event once it
            has been dispatched.
        """
        msg = self.msg_helper.make_nack(*args, **kw)
        d = self.worker_helper.dispatch_event(msg)
        return d.addCallback(lambda r: msg)

    @proxyable
    def make_dispatch_delivery_report(self, *args, **kw):
        """
        Construct and dispatch a delivery report event.

        This is a wrapper around :meth:`MessageHelper.make_delivery_report` (to
        which all parameters are passed) and
        :meth:`WorkerHelper.dispatch_event`.

        :returns:
            A :class:`Deferred` that fires with the constructed event once it
            has been dispatched.
        """
        msg = self.msg_helper.make_delivery_report(*args, **kw)
        d = self.worker_helper.dispatch_event(msg)
        return d.addCallback(lambda r: msg)

    @proxyable
    def make_dispatch_reply(self, *args, **kw):
        """
        Construct and dispatch a reply message.

        This is a wrapper around :meth:`MessageHelper.make_reply` (to which all
        parameters are passed) and :meth:`WorkerHelper.dispatch_outbound`.

        :returns:
            A :class:`Deferred` that fires with the constructed message once it
            has been dispatched.
        """
        msg = self.msg_helper.make_reply(*args, **kw)
        d = self.worker_helper.dispatch_outbound(msg)
        return d.addCallback(lambda r: msg)

    @proxyable
    def make_dispatch_status(self, *args, **kw):
        """
        Construct and dispatch a status.

        This is a wrapper around :meth:`MessageHelper.make_status` (to
        which all parameters are passed) and
        :meth:`WorkerHelper.dispatch_status`.

        :returns:
            A :class:`Deferred` that fires with the constructed message once it
            has been dispatched.
        """
        msg = self.msg_helper.make_status(*args, **kw)
        d = self.worker_helper.dispatch_status(msg)
        return d.addCallback(lambda r: msg)
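
# Usage sketch (illustrative, not part of the original module): wiring a
# MessageDispatchHelper together from its two collaborators. The connector
# name and message content are assumptions for the example.
def _example_message_dispatch():
    msg_helper = MessageHelper(transport_name='test_connector')
    worker_helper = WorkerHelper(connector_name='test_connector')
    md_helper = MessageDispatchHelper(msg_helper, worker_helper)
    # Build and dispatch in one step; the deferred fires with the message
    # once the fake broker has delivered it.
    d = md_helper.make_dispatch_inbound("hello")
    return d.addCallback(lambda msg: msg['content'])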


class RiakDisabledForTest(object):
    """
    Placeholder object for a disabled riak config.

    This class exists to throw a meaningful error when trying to use Riak in
    a test that disallows it. We can't do this from inside the Riak setup
    infrastructure, because that would be very invasive for something that
    only really matters for tests.
    """
    def __getattr__(self, name):
        raise RuntimeError(
            "Use of Riak has been disabled for this test. Please set "
            "'use_riak = True' on the test class to enable it.")

    def __deepcopy__(self, memo):
        """
        We have no state, but ``deepcopy()`` triggers our :meth:`__getattr__`.
        We return ``self`` so the copy compares equal.
        """
        return self


def import_filter(exc, *expected):
    """
    Extract the name of the missing module from an :class:`ImportError`.

    If ``expected`` module names are given and the missing module is not one
    of them, the original :class:`ImportError` is reraised.
    """
    msg = exc.args[0]
    module = msg.split()[-1]
    if expected and (module not in expected):
        raise
    return module


def import_skip(exc, *expected):
    """
    Raise :class:`SkipTest` if the provided :class:`ImportError` matches a
    module name in ``expected``, otherwise reraise the :class:`ImportError`.

    This is useful for skipping tests that require optional dependencies which
    might not be present.
    """
    module = import_filter(exc, *expected)
    raise SkipTest("Failed to import '%s'." % (module,))


def skiptest(reason):
    """
    Decorate a test that should be skipped with a reason.

    NOTE: Don't import this as `skip`, because that will cause trial to skip
          the entire module that imports it.
    """
    def skipdeco(func):
        func.skip = reason
        return func
    return skipdeco
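
# Sketch (illustrative): skipping a single test method without skipping the
# whole module. The test class and reason are assumptions for the example.
#
#     class ExampleTests(VumiTestCase):
#         @skiptest("Requires a live AMQP broker.")
#         def test_against_real_broker(self):
#             pass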


def maybe_async(sync_attr):
    """
    Decorate a method that may be sync or async.

    This redecorates the method with either ``@inlineCallbacks`` or
    ``@flatten_generator``, depending on the value of the attribute named by
    ``sync_attr``.
    """
    if callable(sync_attr):
        # If we don't get a sync attribute name, default to 'is_sync'.
        return maybe_async('is_sync')(sync_attr)

    def redecorate(func):
        @wraps(func)
        def wrapper(self, *args, **kw):
            if getattr(self, sync_attr):
                return flatten_generator(func)(self, *args, **kw)
            return inlineCallbacks(func)(self, *args, **kw)
        return wrapper

    return redecorate
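
# Sketch (illustrative, not part of the module): a class whose generator
# method runs synchronously or asynchronously depending on its `is_sync`
# attribute. The `manager` collaborator is an assumption for the example.
class _ExampleMaybeAsync(object):
    def __init__(self, manager, is_sync):
        self.manager = manager
        self.is_sync = is_sync

    @maybe_async
    def do_work(self):
        # With is_sync=True the yields are flattened and the result comes
        # back directly; with is_sync=False this returns a Deferred.
        yield self.manager.load()
        yield self.manager.store()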


def maybe_async_return(value, maybe_deferred):
    """
    Return ``value`` or a deferred that fires with it.

    This is useful in cases where we're performing a potentially async
    operation but don't necessarily have enough information to use
    `maybe_async`.
    """
    if isinstance(maybe_deferred, Deferred):
        return maybe_deferred.addCallback(lambda r: value)
    return value
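
# Sketch (illustrative, not part of the module): returning a key to the
# caller after a backend call that may be sync or async. `redis` is an
# assumed manager object.
def _example_touch_key(redis, key):
    # expire() returns an int from a sync manager or a Deferred from an
    # async one; either way the caller gets `key` back once it's done.
    d = redis.expire(key, 60)
    return maybe_async_return(key, d)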


class PersistenceHelperError(Exception):
    """
    Exception thrown by a PersistenceHelper when it sees something wrong.
    """


class PersistenceHelper(object):
    """
    Test helper for managing persistent storage.

    This helper manages Riak and Redis clients and configs and cleans up after
    them. It does no setup, but its cleanup may take a while if there's a lot
    in Riak.

    All configs for objects that build Riak or Redis clients must be passed
    through :meth:`mk_config`.

    :param bool use_riak:
        Pass ``True`` if Riak is desired, otherwise it will be disabled in the
        generated config parameters.

    :param bool is_sync:
        Pass ``True`` if synchronous Riak and Redis clients are desired,
        otherwise asynchronous ones will be built. This only applies to clients
        built by this helper, not those built by other objects using configs
        from this helper.
    """

    implements(IHelper)

    _patches_applied = False

    def __init__(self, use_riak=False, is_sync=False, assert_closed=False):
        self._assert_closed = assert_closed
        if os.environ.get('VUMI_TEST_ASSERT_CLOSED', ''):
            # Override from environment
            self._assert_closed = True
        self.use_riak = use_riak
        self.is_sync = is_sync
        self._patches = []
        self._riak_managers = []
        self._redis_managers = []
        self._test_prefix = 'vumitest'
        self._config_overrides = {
            'redis_manager': {
                'FAKE_REDIS': 'yes',
                'key_prefix': self._test_prefix,
            },
            'riak_manager': {
                'bucket_prefix': self._test_prefix,
            },
        }
        if not self.use_riak:
            self._config_overrides['riak_manager'] = RiakDisabledForTest()

        self._riak_stacks = {}

    def setup(self):
        self._patch_riak()
        self._patch_txriak()
        self._patch_redis()
        self._patch_txredis()
        self._patches_applied = True

    @maybe_async
    def cleanup(self):
        unclosed_managers = []

        for purge, manager in self._get_riak_managers_for_cleanup():
            if manager._is_unclosed():
                unclosed_managers.append(manager)
            if purge:
                try:
                    yield self._purge_riak(manager)
                except ConnectionRefusedError:
                    pass
            yield manager.close_manager()

        for purge, manager in self._get_redis_managers_for_cleanup():
            if purge:
                yield self._purge_redis(manager)
            yield manager.close_manager()

        self._unpatch()

        if unclosed_managers and self._assert_closed:
            # We have unclosed managers and we've been asked to assert that we
            # don't.
            for manager in unclosed_managers:
                stack = self._riak_stacks.get(
                    manager, ["No stack trace found.\n"])
                print "========= %r =========" % manager
                print "".join(stack)
            print "Unclosed Riak managers:", len(unclosed_managers)
            raise PersistenceHelperError(
                "Unclosed Riak managers found during cleanup: %s %s" % (
                    len(unclosed_managers), unclosed_managers))

    def _get_riak_managers_for_cleanup(self):
        """
        Get a list of Riak managers and whether they should be purged.

        The return value is a sequence of (`bool`, `Manager`) pairs. If the first
        item is `True`, the manager should be purged. It's safe to purge
        managers even if the first item is `False`, but it adds extra cleanup
        time.
        """
        # NOTE: Assumes we're only ever connecting to one Riak cluster.
        seen_bucket_prefixes = set()
        managers = []
        for manager in self._riak_managers:
            if manager.bucket_prefix in seen_bucket_prefixes:
                managers.append((False, manager))
            else:
                seen_bucket_prefixes.add(manager.bucket_prefix)
                managers.append((True, manager))
        # Return in reverse order in case something overrides cleanup and
        # cares about ordering.
        return reversed(managers)

    def _get_redis_managers_for_cleanup(self):
        """
        Get a list of Redis managers and whether they should be purged.

        The return value is a sequence of (`bool`, `Manager`) pairs. If the first
        item is `True`, the manager should be purged. It's safe to purge
        managers even if the first item is `False`, but it adds extra cleanup
        time.
        """
        # NOTE: Assumes we're only ever connecting to one Redis db.
        seen_key_prefixes = set()
        managers = []
        for manager in self._redis_managers:
            if manager._key_prefix in seen_key_prefixes:
                managers.append((False, manager))
            else:
                seen_key_prefixes.add(manager._key_prefix)
                managers.append((True, manager))
        # Return in reverse order in case something overrides teardown and
        # cares about ordering.
        return reversed(managers)

    def _patch(self, obj, attribute, value):
        monkey_patch = MonkeyPatcher((obj, attribute, value))
        self._patches.append(monkey_patch)
        monkey_patch.patch()
        return monkey_patch

    def _unpatch(self):
        for patch in reversed(self._patches):
            patch.restore()
        self._patches_applied = False

    def _patch_riak(self):
        try:
            from vumi.persist.riak_manager import RiakManager
        except ImportError, e:
            import_filter(e, 'riak')
            return

        orig_init = RiakManager.__init__

        def wrapper(obj, *args, **kw):
            orig_init(obj, *args, **kw)
            self._collect_riak_manager(obj)

        self._patch(RiakManager, '__init__', wrapper)

    def _patch_txriak(self):
        try:
            from vumi.persist.txriak_manager import TxRiakManager
        except ImportError, e:
            import_filter(e, 'riak')
            return

        orig_init = TxRiakManager.__init__

        def wrapper(obj, *args, **kw):
            orig_init(obj, *args, **kw)
            self._collect_riak_manager(obj)

        self._patch(TxRiakManager, '__init__', wrapper)

    def _collect_riak_manager(self, manager):
        self._riak_managers.append(manager)
        if self._assert_closed:
            self._riak_stacks[manager] = get_stack_trace(2)

    def _patch_redis(self):
        try:
            from vumi.persist.redis_manager import RedisManager
        except ImportError, e:
            import_filter(e, 'redis')
            return

        orig_init = RedisManager.__init__

        def wrapper(obj, *args, **kw):
            orig_init(obj, *args, **kw)
            self._redis_managers.append(obj)

        self._patch(RedisManager, '__init__', wrapper)

    def _patch_txredis(self):
        from vumi.persist.txredis_manager import TxRedisManager

        orig_init = TxRedisManager.__init__

        def wrapper(obj, *args, **kw):
            orig_init(obj, *args, **kw)
            self._redis_managers.append(obj)

        self._patch(TxRedisManager, '__init__', wrapper)

    def _purge_riak(self, manager):
        "This is a separate method to allow easy overriding."
        return manager.purge_all()

    @maybe_async
    def _purge_redis(self, manager):
        "This is a separate method to allow easy overriding."
        try:
            yield manager._purge_all()
        except RuntimeError, e:
            # Ignore managers that are already closed.
            if e.args[0] != 'Not connected':
                raise

    def _check_patches_applied(self):
        if not self._patches_applied:
            raise PersistenceHelperError(
                "setup() must be called before performing this operation.")

    @proxyable
    def get_riak_manager(self, config=None):
        """
        Build and return a Riak manager.

        :param dict config:
            Riak manager config. (Not a complete worker config.) If ``None``,
            the one used by :meth:`mk_config` will be used.

        :returns:
            A :class:`~vumi.persist.riak_manager.RiakManager` or
            :class:`~vumi.persist.txriak_manager.TxRiakManager`, depending on
            the value of :attr:`is_sync`.
        """
        self._check_patches_applied()
        if config is None:
            config = self._config_overrides["riak_manager"].copy()
        else:
            config = config.copy()
            config["bucket_prefix"] = "%s%s" % (
                self._test_prefix, config["bucket_prefix"])

        if self.is_sync:
            return self._get_sync_riak_manager(config)
        return self._get_async_riak_manager(config)

    def _get_async_riak_manager(self, config):
        try:
            from vumi.persist.txriak_manager import TxRiakManager
        except ImportError, e:
            import_skip(e, 'riak')

        return TxRiakManager.from_config(config)

    def _get_sync_riak_manager(self, config):
        try:
            from vumi.persist.riak_manager import RiakManager
        except ImportError, e:
            import_skip(e, 'riak')

        return RiakManager.from_config(config)

    def record_load_and_store(self, riak_manager, loads, stores):
        """
        Patch a Riak manager to capture load and store operations.

        :param riak_manager: The manager object to patch.
        :param list loads: A list to append the keys of loaded objects to.
        :param list stores: A list to append the keys of stored objects to.
        """
        orig_load = riak_manager.load
        orig_store = riak_manager.store

        def record_load(modelcls, key, result=None):
            loads.append(key)
            return orig_load(modelcls, key, result=result)

        def record_store(obj):
            stores.append(obj.key)
            return orig_store(obj)

        self._patch(riak_manager, "load", record_load)
        self._patch(riak_manager, "store", record_store)
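
    # Sketch (illustrative): capturing persistence traffic in a test. The
    # manager, helper, and expected key are assumptions for the example.
    #
    #     loads, stores = [], []
    #     persistence_helper.record_load_and_store(manager, loads, stores)
    #     # ... exercise code that loads and stores models ...
    #     self.assertEqual(stores, ['expected-key'])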

    @proxyable
    def get_redis_manager(self, config=None):
        """
        Build and return a Redis manager.

        This will be backed by an in-memory fake unless the
        ``VUMITEST_REDIS_DB`` environment variable is set.

        :param dict config:
            Redis manager config. (Not a complete worker config.) If ``None``,
            the one used by :meth:`mk_config` will be used.

        :returns:
            A :class:`~vumi.persist.redis_manager.RedisManager` or
            :class:`~vumi.persist.txredis_manager.TxRedisManager`, depending
            on the value of :attr:`is_sync`.
        """
        self._check_patches_applied()
        if config is None:
            config = self._config_overrides['redis_manager'].copy()

        if self.is_sync:
            return self._get_sync_redis_manager(config)
        return self._get_async_redis_manager(config)

    def _get_async_redis_manager(self, config):
        from vumi.persist.txredis_manager import TxRedisManager

        return TxRedisManager.from_config(config)

    def _get_sync_redis_manager(self, config):
        try:
            from vumi.persist.redis_manager import RedisManager
        except ImportError, e:
            import_skip(e, 'redis')

        return RedisManager.from_config(config)

    @proxyable
    def mk_config(self, config):
        """
        Return a copy of ``config`` with the ``riak_manager`` and
        ``redis_manager`` fields overridden.

        All configs for things that create Riak or Redis clients should be
        passed through this method.
        """
        self._check_patches_applied()
        config = config.copy()
        config.update(self._config_overrides)
        return config
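
# Usage sketch (illustrative, not part of the module): passing a worker
# config through mk_config() so Riak/Redis clients built from it are
# captured and cleaned up. The config key is an assumption for the example.
def _example_persistence_config():
    persistence_helper = PersistenceHelper(use_riak=True)
    persistence_helper.setup()
    # Configs for anything that builds Riak or Redis clients must pass
    # through mk_config() so the helper can substitute test settings.
    return persistence_helper.mk_config({'worker_name': 'example_worker'})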


# vumi/tests/__init__.py (empty)


# vumi/tests/test_fake_amqp.py
from twisted.internet.defer import inlineCallbacks, returnValue, DeferredQueue

from vumi.service import get_spec, Worker
from vumi.utils import vumi_resource_path
from vumi.tests import fake_amqp
from vumi.tests.helpers import VumiTestCase


def mkmsg(body):
    return fake_amqp.Thing("Message", body=body)


class ToyWorker(Worker):
    @inlineCallbacks
    def startWorker(self):
        paused = self.config.get('paused', False)
        self.msgs = []
        self.pub = yield self.publish_to('test.pub')
        self.conpub = yield self.publish_to('test.con')
        self.con = yield self.consume(
            'test.con', self.consume_msg, paused=paused)

    def consume_msg(self, msg):
        self.msgs.append(msg)


class ToyAMQClient(object):
    """
    A fake fake client object for building fake channel objects.
    """
    def __init__(self, broker, delegate):
        self.broker = broker
        self.delegate = delegate


class TestFakeAMQP(VumiTestCase):
    def setUp(self):
        self.broker = fake_amqp.FakeAMQPBroker()
        self.add_cleanup(self.broker.wait_delivery)

    def make_exchange(self, exchange, exchange_type):
        self.broker.exchange_declare(exchange, exchange_type, durable=True)
        return self.broker.exchanges[exchange]

    def make_queue(self, queue):
        self.broker.queue_declare(queue)
        return self.broker.queues[queue]

    def make_channel(self, channel_id, delegate=None):
        channel = fake_amqp.FakeAMQPChannel(
            channel_id, ToyAMQClient(self.broker, delegate))
        channel.channel_open()
        return channel

    def set_up_broker(self):
        self.chan1 = self.make_channel(1)
        self.chan2 = self.make_channel(2)
        self.ex_direct = self.make_exchange('direct', 'direct')
        self.ex_topic = self.make_exchange('topic', 'topic')
        self.q1 = self.make_queue('q1')
        self.q2 = self.make_queue('q2')
        self.q3 = self.make_queue('q3')

    @inlineCallbacks
    def get_worker(self, **config):
        spec = get_spec(vumi_resource_path("amqp-spec-0-8.xml"))
        amq_client = fake_amqp.FakeAMQClient(spec, {}, self.broker)

        worker = ToyWorker({}, config)
        worker._amqp_client = amq_client
        yield worker.startWorker()
        returnValue(worker)

    def test_misc(self):
        str(fake_amqp.Thing('kind', foo='bar'))
        msg = fake_amqp.Message(None, [('foo', 'bar')])
        self.assertEqual('bar', msg.foo)
        self.assertRaises(AttributeError, lambda: msg.bar)

    def test_channel_open(self):
        channel = fake_amqp.FakeAMQPChannel(0, ToyAMQClient(self.broker, None))
        self.assertEqual([], self.broker.channels)
        channel.channel_open()
        self.assertEqual([channel], self.broker.channels)

    def test_exchange_declare(self):
        channel = self.make_channel(0)
        self.assertEqual({}, self.broker.exchanges)
        channel.exchange_declare('foo', 'direct', durable=True)
        self.assertEqual(['foo'], self.broker.exchanges.keys())
        self.assertEqual('direct', self.broker.exchanges['foo'].exchange_type)
        channel.exchange_declare('bar', 'topic', durable=True)
        self.assertEqual(['bar', 'foo'], sorted(self.broker.exchanges.keys()))
        self.assertEqual('topic', self.broker.exchanges['bar'].exchange_type)

    def test_declare_and_queue_bind(self):
        channel = self.make_channel(0)
        self.assertEqual({}, self.broker.queues)
        channel.queue_declare('foo')
        channel.queue_declare('foo')
        self.assertEqual(['foo'], self.broker.queues.keys())
        exch = self.make_exchange('exch', 'direct')
        self.assertEqual({}, exch.binds)
        channel.queue_bind('foo', 'exch', 'routing.key')
        self.assertEqual(['routing.key'], exch.binds.keys())

        n = len(self.broker.queues)
        channel.queue_declare('')
        self.assertEqual(n + 1, len(self.broker.queues))

    def test_publish_direct(self):
        self.set_up_broker()
        self.chan1.queue_bind('q1', 'direct', 'routing.key.one')
        self.chan1.queue_bind('q1', 'direct', 'routing.key.two')
        self.chan1.queue_bind('q2', 'direct', 'routing.key.two')
        delivered = []

        def fake_put(*args):
            delivered.append(args)
        self.q1.put = fake_put
        self.q2.put = fake_put
        self.q3.put = fake_put

        self.chan1.basic_publish('direct', 'routing.key.none', 'blah')
        self.assertEqual([], delivered)

        self.chan1.basic_publish('direct', 'routing.key.*', 'blah')
        self.assertEqual([], delivered)

        self.chan1.basic_publish('direct', 'routing.key.#', 'blah')
        self.assertEqual([], delivered)

        self.chan1.basic_publish('direct', 'routing.key.one', 'blah')
        self.assertEqual([('direct', 'routing.key.one', 'blah')], delivered)

        delivered[:] = []  # Clear without reassigning
        self.chan1.basic_publish('direct', 'routing.key.two', 'blah')
        self.assertEqual([('direct', 'routing.key.two', 'blah')] * 2,
                         delivered)

    def test_publish_topic(self):
        self.set_up_broker()
        self.chan1.queue_bind('q1', 'topic', 'routing.key.*.foo.#')
        self.chan1.queue_bind('q2', 'topic', 'routing.key.#.foo')
        self.chan1.queue_bind('q3', 'topic', 'routing.key.*.foo.*')
        delivered = []

        def mfp(q):
            def fake_put(*args):
                delivered.append((q,) + args)
            return fake_put
        self.q1.put = mfp('q1')
        self.q2.put = mfp('q2')
        self.q3.put = mfp('q3')

        self.chan1.basic_publish('topic', 'routing.key.none', 'blah')
        self.assertEqual([], delivered)

        self.chan1.basic_publish('topic', 'routing.key.foo.one', 'blah')
        self.assertEqual([], delivered)

        self.chan1.basic_publish('topic', 'routing.key.foo', 'blah')
        self.assertEqual([('q2', 'topic', 'routing.key.foo', 'blah')],
                         delivered)

        delivered[:] = []  # Clear without reassigning
        self.chan1.basic_publish('topic', 'routing.key.one.two.foo', 'blah')
        self.assertEqual([('q2', 'topic', 'routing.key.one.two.foo', 'blah')],
                         delivered)

        delivered[:] = []  # Clear without reassigning
        self.chan1.basic_publish('topic', 'routing.key.one.foo', 'blah')
        self.assertEqual([('q1', 'topic', 'routing.key.one.foo', 'blah'),
                          ('q2', 'topic', 'routing.key.one.foo', 'blah'),
                          ], sorted(delivered))

        delivered[:] = []  # Clear without reassigning
        self.chan1.basic_publish('topic', 'routing.key.one.foo.two', 'blah')
        self.assertEqual([('q1', 'topic', 'routing.key.one.foo.two', 'blah'),
                          ('q3', 'topic', 'routing.key.one.foo.two', 'blah'),
                          ], sorted(delivered))

    def test_basic_get(self):
        self.set_up_broker()
        self.assertEqual('get-empty', self.chan1.basic_get('q1').method.name)
        self.q1.put('foo', 'rkey.foo', mkmsg('blah'))
        self.assertEqual('blah', self.chan1.basic_get('q1').content.body)
        self.assertEqual('get-empty', self.chan1.basic_get('q1').method.name)

    def test_consumer_wrangling(self):
        self.set_up_broker()
        self.chan1.queue_bind('q1', 'direct', 'foo')
        self.assertEqual(set(), self.q1.consumers)
        self.chan1.basic_consume('q1', 'tag1')
        self.assertEqual(set(['tag1']), self.q1.consumers)
        self.chan1.basic_consume('q1', 'tag2')
        self.assertEqual(set(['tag1', 'tag2']), self.q1.consumers)
        self.chan1.basic_cancel('tag2')
        self.assertEqual(set(['tag1']), self.q1.consumers)
        self.chan1.basic_cancel('tag2')
        self.assertEqual(set(['tag1']), self.q1.consumers)

    def test_basic_qos_global_unsupported(self):
        """
        basic_qos() is unsupported with global=True.
        """
        channel = self.make_channel(0)
        self.assertRaises(NotImplementedError, channel.basic_qos, 0, 1, True)

    def test_basic_qos_per_consumer(self):
        """
        basic_qos() only applies to consumers started after the call.
        """
        channel = self.make_channel(0)
        channel.queue_declare('q1')
        channel.queue_declare('q2')
        self.assertEqual(channel.qos_prefetch_count, 0)

        channel.basic_consume('q1', 'tag1')
        self.assertEqual(channel._get_consumer_prefetch('tag1'), 0)

        channel.basic_qos(0, 1, False)
        channel.basic_consume('q2', 'tag2')
        self.assertEqual(channel._get_consumer_prefetch('tag1'), 0)
        self.assertEqual(channel._get_consumer_prefetch('tag2'), 1)

    @inlineCallbacks
    def test_basic_ack(self):
        """
        basic_ack() should acknowledge a message.
        """
        class ToyDelegate(object):
            def __init__(self):
                self.queue = DeferredQueue()

            def basic_deliver(self, channel, msg):
                self.queue.put(msg)

        delegate = ToyDelegate()
        channel = self.make_channel(0, delegate)
        channel.exchange_declare('e1', 'direct', durable=True)
        channel.queue_declare('q1')
        channel.queue_bind('q1', 'e1', 'rkey')
        channel.basic_consume('q1', 'tag1')

        self.assertEqual(len(channel.unacked), 0)
        channel.basic_publish('e1', 'rkey', fake_amqp.mkContent('foo'))
        msg = yield delegate.queue.get()
        dtag = msg.delivery_tag
        self.assertEqual(len(channel.unacked), 1)
        channel.basic_ack(dtag, False)
        self.assertEqual(len(channel.unacked), 0)

        # Clean up.
        channel.message_processed()
        yield channel.broker.wait_delivery()

    @inlineCallbacks
    def test_basic_ack_consumer_canceled(self):
        """
        basic_ack() should fail if the consumer has been canceled.
        """
        class ToyDelegate(object):
            def __init__(self):
                self.queue = DeferredQueue()

            def basic_deliver(self, channel, msg):
                self.queue.put(msg)

        delegate = ToyDelegate()
        channel = self.make_channel(0, delegate)
        channel.exchange_declare('e1', 'direct', durable=True)
        channel.queue_declare('q1')
        channel.queue_bind('q1', 'e1', 'rkey')
        channel.basic_consume('q1', 'tag1')

        self.assertEqual(len(channel.unacked), 0)
        channel.basic_publish('e1', 'rkey', fake_amqp.mkContent('foo'))
        msg = yield delegate.queue.get()
        dtag = msg.delivery_tag
        self.assertEqual(len(channel.unacked), 1)

        channel.basic_cancel('tag1')
        self.assertRaises(Exception, channel.basic_ack, dtag, False)
        self.assertEqual(len(channel.unacked), 0)

        # Clean up.
        channel.message_processed()
        yield channel.broker.wait_delivery()

    @inlineCallbacks
    def test_fake_amqclient(self):
        worker = yield self.get_worker()
        yield worker.pub.publish_json({'message': 'foo'})
        yield worker.conpub.publish_json({'message': 'bar'})
        yield self.broker.wait_delivery()
        self.assertEqual({'message': 'bar'}, worker.msgs[0].payload)

    @inlineCallbacks
    def test_fake_amqclient_qos(self):
        """
        Even if we set QOS, all messages should get delivered.
        """
        worker = yield self.get_worker()

        yield worker.con.channel.basic_qos(0, 1, False)
        yield worker.conpub.publish_json({'message': 'foo'})
        yield worker.conpub.publish_json({'message': 'bar'})
        yield self.broker.wait_delivery()
        self.assertEqual(2, len(worker.msgs))

    @inlineCallbacks
    def test_fake_amqclient_pause(self):
        """
        Pausing and unpausing channels should work as expected.
        """
        worker = yield self.get_worker(paused=True)

        yield worker.conpub.publish_json({'message': 'foo'})
        yield self.broker.wait_delivery()
        self.assertEqual([], worker.msgs)

        yield worker.con.unpause()
        yield self.broker.wait_delivery()
        self.assertEqual(1, len(worker.msgs))
        self.assertEqual({'message': 'foo'}, worker.msgs[0].payload)
        worker.msgs = []

        yield self.broker.wait_delivery()
        yield worker.con.pause()
        yield worker.con.pause()
        yield self.broker.wait_delivery()
        yield worker.conpub.publish_json({'message': 'bar'})
        self.assertEqual([], worker.msgs)

        yield worker.con.unpause()
        yield worker.conpub.publish_json({'message': 'baz'})
        yield self.broker.wait_delivery()
        self.assertEqual(2, len(worker.msgs))
        yield worker.con.unpause()

    # This is a test which actually connects to the AMQP broker.
    #
    # It originally existed purely as a mechanism for discovering what
    # the real client/broker's behaviour is in order to duplicate it
    # in the fake one. I've left it in here for now in case we need to
    # do further investigation later, but we *really* don't want to
    # run it as part of the test suite.

    # @inlineCallbacks
    # def test_zzz_real_amqclient(self):
    #     print ""
    #     from vumi.service import WorkerCreator
    #     options = {
    #         "hostname": "127.0.0.1",
    #         "port": 5672,
    #         "username": "vumi",
    #         "password": "vumi",
    #         "vhost": "/develop",
    #         "specfile": "amqp-spec-0-8.xml",
    #         }
    #     wc = WorkerCreator(options)
    #     d = Deferred()

    #     class ToyWorker(Worker):
    #         @inlineCallbacks
    #         def startWorker(self):
    #             self.pub = yield self.publish_to('test.pub')
    #             self.pub.routing_key_is_bound = lambda _: True
    #             self.conpub = yield self.publish_to('test.con')
    #             self.con = yield self.consume('test.con', self.consume_msg,
    #                                           paused=True)
    #             d.callback(None)

    #         def consume_msg(self, msg):
    #             print "CONSUMED!", msg
    #             return True

    #     worker = wc.create_worker_by_class(ToyWorker, {})
    #     worker.startService()
    #     yield d
    #     print "foo"
    #     yield worker.pub.publish_json({"foo": "bar"})
    #     yield worker.conpub.publish_json({"bar": "baz"})
    #     yield worker.con.unpause()
    #     yield worker.con.pause()
    #     yield worker.con.pause()
    #     print "bar"
    #     yield worker.pub.channel.queue_declare(queue='test.foo')
    #     yield worker.pub.channel.queue_bind(queue='test.foo',
    #                                         exchange='vumi',
    #                                         routing_key='test.pub')
    #     yield worker.pub.publish_json({"foo": "bar"})
    #     print "getting..."
    #     foo = yield worker.pub.channel.basic_get(queue='test.foo')
    #     print "got:", foo
    #     yield worker.stopWorker()
    #     yield worker.stopService()


# vumi/tests/fake_amqp.py
# -*- test-case-name: vumi.tests.test_fake_amqp -*-

from uuid import uuid4
import re

from twisted.internet import reactor
from twisted.internet.defer import inlineCallbacks, returnValue, Deferred
from txamqp.client import TwistedDelegate
from txamqp.content import Content

from vumi.service import WorkerAMQClient
from vumi.message import Message as VumiMessage


def gen_id(prefix=''):
    return ''.join([prefix, uuid4().get_hex()])


def gen_longlong():
    return uuid4().int & 0xffffffffffffffff


class Thing(object):
    """
    A generic thing to reply with.
    """
    def __init__(self, kind, **kw):
        self._kind = kind
        self._kwfields = kw.keys()
        for k, v in kw.items():
            setattr(self, k, v)

    def __str__(self):
        return "<%s %s>" % (self._kind,
                            ['[%s: %s]' % (f, getattr(self, f))
                             for f in self._kwfields])


class Message(object):
    """
    A message is more complicated than a Thing.
    """
    def __init__(self, method, fields=(), content=None):
        self.method = method
        self._fields = fields
        self.content = content

    def __getattr__(self, key):
        for k, v in self._fields:
            if k == key:
                return v
        raise AttributeError(key)
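
# Sketch (illustrative): attribute access on a fake Message is resolved by
# __getattr__ over the (name, value) pairs it was built with:
#
#     msg = Message(mkMethod('deliver', 60), [('consumer_tag', 'tag1')])
#     msg.consumer_tag  # -> 'tag1'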


def mkMethod(name, index=-1):
    """
    Create a "Method" object, suitable for a ``txamqp`` message.

    :param name: The name of the AMQP method, per the XML spec.
    :param index: The index of the AMQP method, per the XML spec.
    """
    return Thing("Method", name=name, id=index)


def mkContent(body, children=None, properties=None):
    return Thing("Content", body=body, children=children,
                 properties=properties)


def mk_deliver(body, exchange, routing_key, ctag, dtag):
    return Message(mkMethod('deliver', 60), [
            ('consumer_tag', ctag),
            ('delivery_tag', dtag),
            ('redelivered', False),
            ('exchange', exchange),
            ('routing_key', routing_key),
            ], mkContent(body))


def mk_get_ok(body, exchange, routing_key, dtag):
    return Message(mkMethod('get-ok', 71), [
            ('delivery_tag', dtag),
            ('redelivered', False),
            ('exchange', exchange),
            ('routing_key', routing_key),
            ], mkContent(body))


class FakeAMQPBroker(object):
    def __init__(self):
        self.queues = {}
        self.exchanges = {}
        self.channels = []
        self.dispatched = {}
        self._delivering = None

    def _get_queue(self, queue):
        assert queue in self.queues
        return self.queues[queue]

    def _get_exchange(self, exchange):
        assert exchange in self.exchanges
        return self.exchanges[exchange]

    def channel_open(self, channel):
        assert channel not in self.channels
        self.channels.append(channel)
        return Message(mkMethod("open-ok", 11))

    def channel_close(self, channel):
        if channel in self.channels:
            self.channels.remove(channel)
        return Message(mkMethod("close-ok", 41))

    def exchange_declare(self, exchange, exchange_type, durable):
        exchange_class = None
        if exchange_type == 'direct':
            exchange_class = FakeAMQPExchangeDirect
        elif exchange_type == 'topic':
            exchange_class = FakeAMQPExchangeTopic
        assert exchange_class is not None
        self.exchanges.setdefault(exchange, exchange_class(exchange, durable))
        assert exchange_type == self.exchanges[exchange].exchange_type
        assert durable == self.exchanges[exchange].durable
        return Message(mkMethod("declare-ok", 11))

    def queue_declare(self, queue):
        if not queue:
            queue = gen_id('queue.')
        self.queues.setdefault(queue, FakeAMQPQueue(queue))
        queue_obj = self._get_queue(queue)
        return Message(mkMethod("declare-ok", 11), [
                ('queue', queue),
                ('message_count', queue_obj.message_count()),
                ('consumer_count', queue_obj.consumer_count()),
                ])

    def queue_bind(self, queue, exchange, routing_key):
        self._get_exchange(exchange).queue_bind(routing_key,
                                            self._get_queue(queue))
        return Message(mkMethod("bind-ok", 21))

    def basic_consume(self, queue, tag):
        self._get_queue(queue).add_consumer(tag)
        self.kick_delivery()
        return Message(mkMethod("consume-ok", 21), [("consumer_tag", tag)])

    def basic_cancel(self, tag, queue):
        if queue in self.queues:
            self.queues[queue].remove_consumer(tag)
        return Message(mkMethod("cancel-ok", 31), [("consumer_tag", tag)])

    def basic_publish(self, exchange, routing_key, content):
        exc = self.dispatched.setdefault(exchange, {})
        exc.setdefault(routing_key, []).append(content)
        if exchange not in self.exchanges:
            # This is a fake broker for tests, so we don't care about
            # missing exchanges.
            return None
        self._get_exchange(exchange).basic_publish(routing_key, content)
        self.kick_delivery()
        return None

    def basic_get(self, queue):
        return self._get_queue(queue).get_message()

    def basic_ack(self, queue, delivery_tag):
        self._get_queue(queue).ack(delivery_tag)
        return None

    def deliver_to_channels(self):
        # Since all delivery goes through kick_delivery(), _delivering can
        # only be None here if message_processed() was called too many times.
        assert self._delivering is not None

        for channel in self.channels:
            self.try_deliver_to_channel(channel)

        # Process the sentinel "message" we added in kick_delivery().
        self.message_processed()

    def try_deliver_to_channel(self, channel):
        delivered = False
        for ctag, queue in channel.consumers.items():
            while channel.deliverable(ctag):
                dtag, msg = self._get_queue(queue).get_message()
                if dtag is None:
                    break
                dmsg = mk_deliver(msg['content'], msg['exchange'],
                                  msg['routing_key'], ctag, dtag)
                self._delivering['count'] += 1
                channel.deliver_message(dmsg, ctag)
                delivered = True
        return delivered

    def kick_delivery(self):
        """
        Schedule a message delivery run.

        Returns a deferred that will fire when all deliverable
        messages have been delivered and processed by their consumers.
        This is useful for manually triggering a delivery run from
        inside a test.
        """
        if self._delivering is None:
            self._delivering = {
                'deferred': Deferred(),
                'count': 0,
                }
        # Add a sentinel "message" that gets processed after this
        # delivery run, making the delivery process re-entrant. This
        # is important, because delivered messages can trigger more
        # messages to be published, which kicks delivery again.
        self._delivering['count'] += 1
        # Schedule this for later, so that we don't block whatever it
        # is we're currently doing.
        reactor.callLater(0, self.deliver_to_channels)
        return self.wait_delivery()

    def wait_delivery(self):
        """
        Wait for the current message delivery run (if any) to finish.

        Returns a deferred that will fire when the broker is finished
        delivering any messages from the current run. This should not
        leave any messages undelivered, because basic_publish() kicks
        off a delivery run.

        Each call returns a new deferred to avoid callback chain ordering
        issues when several things want to wait for delivery.

        NOTE: This method should be called during test teardown to make
        sure there are no pending delivery cleanups that will cause a
        dirty reactor race.
        """
        d = Deferred()
        if self._delivering is None:
            d.callback(None)
        else:
            self._delivering['deferred'].chainDeferred(d)
        return d
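
    # Usage sketch (illustrative): publish, wait for the delivery run to
    # settle, then inspect what was dispatched (assuming the 'vumi' exchange
    # has already been declared):
    #
    #     broker.publish_raw('vumi', 'test.key', '{"foo": "bar"}')
    #     d = broker.wait_delivery()
    #     d.addCallback(lambda _: broker.get_dispatched('vumi', 'test.key'))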

    def wait_messages(self, exchange, rkey, n):
        def check(d):
            msgs = self.get_messages(exchange, rkey)
            if len(msgs) >= n:
                d.callback(msgs)
            else:
                reactor.callLater(0, check, d)

        done = Deferred()
        reactor.callLater(0, check, done)
        return done

    def clear_messages(self, exchange, rkey=None):
        if exchange not in self.dispatched:
            return
        if rkey:
            del self.dispatched[exchange][rkey][:]
        else:
            self.dispatched[exchange].clear()

    def get_dispatched(self, exchange, rkey):
        return self.dispatched.get(exchange, {}).get(rkey, [])

    def get_messages(self, exchange, rkey):
        contents = self.get_dispatched(exchange, rkey)
        messages = [VumiMessage.from_json(content.body)
                    for content in contents]
        return messages

    def publish_message(self, exchange, routing_key, message):
        return self.publish_raw(exchange, routing_key, message.to_json())

    def publish_raw(self, exchange, routing_key, data):
        assert exchange in self.exchanges
        amq_message = Content(data)
        return self.basic_publish(exchange, routing_key, amq_message)

    def message_processed(self):
        assert self._delivering is not None
        self._delivering['count'] -= 1
        if self._delivering['count'] <= 0:
            d = self._delivering['deferred']
            self._delivering = None
            d.callback(None)


class FakeAMQPChannel(object):
    def __init__(self, channel_id, client):
        self.channel_id = channel_id
        self.client = client
        self.broker = client.broker
        self.qos_prefetch_count = 0
        self.consumers = {}
        self.delegate = client.delegate
        self.unacked = []
        self._consumer_prefetch = {}

    def __repr__(self):
        return '<FakeAMQPChannel: %s>' % (self.channel_id,)

    def channel_open(self):
        return self.broker.channel_open(self)

    def channel_close(self):
        return self.broker.channel_close(self)

    def channel_flow(self, active):
        raise NotImplementedError(
            "channel.flow() is no longer supported in RabbitMQ 3.3.0.")

    def close(self, _reason):
        pass

    def basic_qos(self, _prefetch_size, prefetch_count, is_global):
        if is_global:
            raise NotImplementedError("global prefetch limits not supported.")
        self.qos_prefetch_count = prefetch_count

    def exchange_declare(self, exchange, type, durable=None):
        return self.broker.exchange_declare(exchange, type, durable)

    def queue_declare(self, queue, durable=None):
        return self.broker.queue_declare(queue)

    def queue_bind(self, queue, exchange, routing_key):
        return self.broker.queue_bind(queue, exchange, routing_key)

    def basic_consume(self, queue, tag=None):
        if not tag:
            tag = gen_id('consumer.')
        assert tag not in self.consumers
        self._consumer_prefetch[tag] = self.qos_prefetch_count
        self.consumers[tag] = queue
        return self.broker.basic_consume(queue, tag)

    def basic_cancel(self, tag):
        queue = self.consumers.pop(tag, None)
        if queue:
            self.broker.basic_cancel(tag, queue)
        self._consumer_prefetch.pop(tag, None)
        return Message(mkMethod("cancel-ok", 31))

    def basic_publish(self, exchange, routing_key, content):
        return self.broker.basic_publish(exchange, routing_key, content)

    def basic_ack(self, delivery_tag, multiple):
        assert delivery_tag in [dtag for dtag, _ctag, _queue in self.unacked]
        for dtag, ctag, queue in self.unacked[:]:
            if multiple or (dtag == delivery_tag):
                self.unacked.remove((dtag, ctag, queue))
                if ctag is not None and ctag not in self.consumers:
                    raise Exception("Invalid consumer tag: %s" % (ctag,))
                resp = self.broker.basic_ack(queue, dtag)
                if (dtag == delivery_tag):
                    return resp

    def _get_consumer_prefetch(self, consumer_tag):
        return self._consumer_prefetch[consumer_tag]

    def deliverable(self, consumer_tag):
        if consumer_tag not in self.consumers:
            return False
        prefetch = self._get_consumer_prefetch(consumer_tag)
        if prefetch < 1:
            return True
        return len(self.unacked) < prefetch

    def deliver_message(self, msg, consumer_tag):
        self.unacked.append(
            (msg.delivery_tag, consumer_tag, self.consumers[consumer_tag]))
        self.delegate.basic_deliver(self, msg)

    def basic_get(self, queue):
        dtag, msg = self.broker.basic_get(queue)
        if msg:
            self.unacked.append((dtag, None, queue))
            return mk_get_ok(msg['content'], msg['exchange'],
                             msg['routing_key'], dtag)
        return Message(mkMethod("get-empty", 72))

    def message_processed(self):
        """
        Notify the broker that a message has been processed, in order
        to make delivery sane.
        """
        self.broker.message_processed()


class FakeAMQPExchange(object):
    def __init__(self, name, durable):
        self.name = name
        self.binds = {}
        self.durable = durable

    def queue_bind(self, routing_key, queue):
        binds = self.binds.setdefault(routing_key, set())
        binds.add(queue)

    def basic_publish(self, routing_key, content):
        raise NotImplementedError()


class FakeAMQPExchangeDirect(FakeAMQPExchange):
    exchange_type = 'direct'

    def basic_publish(self, routing_key, content):
        for queue in self.binds.get(routing_key, set()):
            queue.put(self.name, routing_key, content)


class FakeAMQPExchangeTopic(FakeAMQPExchange):
    exchange_type = 'topic'

    def _bind_regex(self, bind):
        for k, v in [('.', r'\.'),
                     ('*', r'[^.]+'),
                     (r'\.#\.', r'\.([^.]+\.)*'),
                     (r'#\.', r'([^.]+\.)*'),
                     (r'\.#', r'(\.[^.]+)*')]:
            bind = bind.replace(k, v)
        return re.compile('^%s$' % (bind,))

    def match_rkey(self, bind, rkey):
        return (self._bind_regex(bind).match(rkey) is not None)

    def basic_publish(self, routing_key, content):
        for bind, queues in self.binds.items():
            if self.match_rkey(bind, routing_key):
                for queue in queues:
                    queue.put(self.name, routing_key, content)
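
# Sketch (illustrative): how the AMQP-style bind patterns behave after the
# translation above. '*' matches exactly one dot-separated word; '#' matches
# zero or more:
#
#     ex = FakeAMQPExchangeTopic('topic', durable=True)
#     ex.match_rkey('routing.key.*', 'routing.key.one')       # True
#     ex.match_rkey('routing.key.*', 'routing.key.one.two')   # False
#     ex.match_rkey('routing.key.#', 'routing.key')           # True
#     ex.match_rkey('routing.key.#', 'routing.key.a.b')       # True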


class FakeAMQPQueue(object):
    def __init__(self, name):
        self.name = name
        self.messages = []
        self.consumers = set()
        self.unacked_messages = {}

    def __eq__(self, other):
        return self.name == other.name

    def __hash__(self):
        return hash(self.name)

    def add_consumer(self, consumer_tag):
        if consumer_tag not in self.consumers:
            self.consumers.add(consumer_tag)

    def remove_consumer(self, consumer_tag):
        if consumer_tag in self.consumers:
            self.consumers.remove(consumer_tag)

    def message_count(self):
        return len(self.messages)

    def consumer_count(self):
        return len(self.consumers)

    def put(self, exchange, routing_key, content):
        self.messages.append({
                'exchange': exchange,
                'routing_key': routing_key,
                'content': content.body,
                })

    def ack(self, delivery_tag):
        self.unacked_messages.pop(delivery_tag)

    def get_message(self):
        try:
            msg = self.messages.pop(0)
        except IndexError:
            return (None, None)
        dtag = gen_longlong()
        self.unacked_messages[dtag] = msg
        return (dtag, msg)


class FakeAMQClient(WorkerAMQClient):
    def __init__(self, spec, vumi_options=None, broker=None):
        WorkerAMQClient.__init__(self, TwistedDelegate(), '', spec)
        if vumi_options is not None:
            self.vumi_options = vumi_options
        if broker is None:
            broker = FakeAMQPBroker()
        self.broker = broker

    @inlineCallbacks
    def channel(self, id):
        yield self.channelLock.acquire()
        try:
            try:
                ch = self.channels[id]
            except KeyError:
                ch = FakeAMQPChannelWrapper(id, self)
                self.channels[id] = ch
        finally:
            self.channelLock.release()
        returnValue(ch)


class FakeAMQPChannelWrapper(object):
    """
    Wrapper around a FakeAMQPChannel to make it look more like a real channel
    object.
    """

    def __init__(self, id, client):
        self._fake_channel = FakeAMQPChannel(id, client)
        self.client = client

    def __repr__(self):
        return '<FakeAMQPChannelWrapper: %s>' % (
            self._fake_channel.channel_id,)

    def channel_open(self):
        return self._fake_channel.channel_open()

    def channel_close(self):
        return self._fake_channel.channel_close()

    def channel_flow(self, active):
        return self._fake_channel.channel_flow(active)

    def close(self, _reason):
        pass

    def basic_qos(self, prefetch_size, prefetch_count, is_global):
        return self._fake_channel.basic_qos(
            prefetch_size, prefetch_count, is_global)

    def exchange_declare(self, exchange, type, durable=None):
        return self._fake_channel.exchange_declare(exchange, type, durable)

    def queue_declare(self, queue, durable=None):
        return self._fake_channel.queue_declare(queue, durable)

    def queue_bind(self, queue, exchange, routing_key):
        return self._fake_channel.queue_bind(queue, exchange, routing_key)

    def basic_consume(self, queue, tag=None):
        return self._fake_channel.basic_consume(queue, tag)

    def basic_cancel(self, tag):
        return self._fake_channel.basic_cancel(tag)

    def basic_publish(self, exchange, routing_key, content):
        return self._fake_channel.basic_publish(exchange, routing_key, content)

    def basic_ack(self, delivery_tag, multiple):
        return self._fake_channel.basic_ack(delivery_tag, multiple)

    def basic_get(self, queue):
        return self._fake_channel.basic_get(queue)


# vumi/tests/test_multiworker.py
from twisted.internet.defer import (Deferred, DeferredList, inlineCallbacks,
                                    returnValue)

from vumi.tests.utils import StubbedWorkerCreator
from vumi.service import Worker
from vumi.message import TransportUserMessage
from vumi.multiworker import MultiWorker
from vumi.tests.helpers import VumiTestCase, MessageHelper, WorkerHelper


class ToyWorker(Worker):
    events = []

    def startService(self):
        self._d = Deferred()
        return super(ToyWorker, self).startService()

    @inlineCallbacks
    def startWorker(self):
        self.events.append("START: %s" % self.name)
        self.pub = yield self.publish_to("%s.outbound" % self.name)
        yield self.consume("%s.inbound" % self.name, self.process_message,
                           message_class=TransportUserMessage)
        self._d.callback(None)

    def stopWorker(self):
        self.events.append("STOP: %s" % self.name)

    def process_message(self, message):
        return self.pub.publish_message(
            message.reply(''.join(reversed(message['content']))))


class StubbedMultiWorker(MultiWorker):
    def WORKER_CREATOR(self, options):
        worker_creator = StubbedWorkerCreator(options)
        worker_creator.broker = self._amqp_client.broker
        return worker_creator

    def wait_for_workers(self):
        return DeferredList([w._d for w in self.workers])


class TestMultiWorker(VumiTestCase):

    base_config = {
        'workers': {
            'worker1': "%s.ToyWorker" % (__name__,),
            'worker2': "%s.ToyWorker" % (__name__,),
            'worker3': "%s.ToyWorker" % (__name__,),
            },
        'worker1': {
            'foo': 'bar',
            },
        }

    def setUp(self):
        self.msg_helper = self.add_helper(MessageHelper())
        self.worker_helper = self.add_helper(WorkerHelper())
        self.clear_events()
        self.add_cleanup(self.clear_events)

    def clear_events(self):
        ToyWorker.events[:] = []

    def dispatch(self, msg, connector_name):
        return self.worker_helper.dispatch_inbound(msg, connector_name)

    def get_replies(self, connector_name):
        msgs = self.worker_helper.get_dispatched_outbound(connector_name)
        return [msg['content'] for msg in msgs]

    @inlineCallbacks
    def get_multiworker(self, config):
        self.worker = yield self.worker_helper.get_worker(
            StubbedMultiWorker, config, start=False)
        yield self.worker.startService()
        yield self.worker.wait_for_workers()
        returnValue(self.worker)

    @inlineCallbacks
    def test_start_stop_workers(self):
        self.assertEqual([], ToyWorker.events)
        worker = yield self.get_multiworker(self.base_config)
        self.assertEqual(['START: worker%s' % (i + 1) for i in range(3)],
                         sorted(ToyWorker.events))
        ToyWorker.events[:] = []
        yield worker.stopService()
        self.assertEqual(['STOP: worker%s' % (i + 1) for i in range(3)],
                         sorted(ToyWorker.events))

    @inlineCallbacks
    def test_message_flow(self):
        yield self.get_multiworker(self.base_config)
        yield self.dispatch(self.msg_helper.make_inbound("foo"), "worker1")
        self.assertEqual(['oof'], self.get_replies("worker1"))
        yield self.dispatch(self.msg_helper.make_inbound("bar"), "worker2")
        yield self.dispatch(self.msg_helper.make_inbound("baz"), "worker3")
        self.assertEqual(['rab'], self.get_replies("worker2"))
        self.assertEqual(['zab'], self.get_replies("worker3"))

    @inlineCallbacks
    def test_config(self):
        worker = yield self.get_multiworker(self.base_config)
        worker1 = worker.getServiceNamed("worker1")
        worker2 = worker.getServiceNamed("worker2")
        self.assertEqual({'foo': 'bar'}, worker1.config)
        self.assertEqual({}, worker2.config)

    @inlineCallbacks
    def test_default_config(self):
        cfg = {'defaults': {'foo': 'baz'}}
        cfg.update(self.base_config)
        worker = yield self.get_multiworker(cfg)
        worker1 = worker.getServiceNamed("worker1")
        worker2 = worker.getServiceNamed("worker2")
        self.assertEqual({'foo': 'bar'}, worker1.config)
        self.assertEqual({'foo': 'baz'}, worker2.config)

# ---- vumi/tests/test_utils.py ----
import os.path

from twisted.internet import reactor
from twisted.internet.defer import Deferred, inlineCallbacks, returnValue
from twisted.internet.error import ConnectionDone
from twisted.internet.protocol import Protocol, Factory
from twisted.internet.task import Clock
from twisted.web.server import Site, NOT_DONE_YET
from twisted.web.resource import Resource
from twisted.web import http
from twisted.web.client import WebClientContextFactory, ResponseFailed

from vumi.utils import (
    normalize_msisdn, vumi_resource_path, cleanup_msisdn, get_operator_name,
    http_request, http_request_full, get_first_word, redis_from_config,
    build_web_site, LogFilterSite, PkgResources, HttpTimeoutError,
    StatusEdgeDetector)
from vumi.message import TransportStatus
from vumi.persist.fake_redis import FakeRedis
from vumi.tests.fake_connection import (
    FakeServer, FakeHttpServer, ProxyAgentWithContext, wait0)
from vumi.tests.helpers import VumiTestCase, import_skip


class DummyRequest(object):
    def __init__(self, postpath, prepath):
        self.postpath = postpath
        self.prepath = prepath


class TestNormalizeMsisdn(VumiTestCase):
    def test_leading_zero(self):
        self.assertEqual(normalize_msisdn('0761234567', '27'),
                         '+27761234567')

    def test_double_leading_zero(self):
        self.assertEqual(normalize_msisdn('0027761234567', '27'),
                         '+27761234567')

    def test_leading_plus(self):
        self.assertEqual(normalize_msisdn('+27761234567', '27'),
                         '+27761234567')

    def test_no_leading_plus_or_zero(self):
        self.assertEqual(normalize_msisdn('27761234567', '27'),
                         '+27761234567')

    def test_short_address(self):
        self.assertEqual(normalize_msisdn('1234'), '1234')
        self.assertEqual(normalize_msisdn('12345'), '12345')

    def test_short_address_with_leading_plus(self):
        self.assertEqual(normalize_msisdn('+12345'), '+12345')

    def test_unicode_addr_remains_unicode(self):
        addr = normalize_msisdn(u'0761234567', '27')
        self.assertEqual(addr, u'+27761234567')
        self.assertTrue(isinstance(addr, unicode))

    def test_str_addr_remains_str(self):
        addr = normalize_msisdn('0761234567', '27')
        self.assertEqual(addr, '+27761234567')
        self.assertTrue(isinstance(addr, str))


class TestUtils(VumiTestCase):
    def test_make_campaign_path_abs(self):
        vumi_tests_path = os.path.dirname(__file__)
        vumi_path = os.path.dirname(os.path.dirname(vumi_tests_path))
        self.assertEqual('/foo/bar', vumi_resource_path('/foo/bar'))
        self.assertEqual(os.path.join(vumi_path, 'vumi/resources/foo/bar'),
                         vumi_resource_path('foo/bar'))

    def test_cleanup_msisdn(self):
        self.assertEqual('27761234567', cleanup_msisdn('27761234567', '27'))
        self.assertEqual('27761234567', cleanup_msisdn('+27761234567', '27'))
        self.assertEqual('27761234567', cleanup_msisdn('0761234567', '27'))

    def test_get_operator_name(self):
        mapping = {'27': {'2782': 'VODACOM', '2783': 'MTN'}}
        self.assertEqual('MTN', get_operator_name('27831234567', mapping))
        self.assertEqual('VODACOM', get_operator_name('27821234567', mapping))
        self.assertEqual('UNKNOWN', get_operator_name('27801234567', mapping))

    def test_get_first_word(self):
        self.assertEqual('KEYWORD',
                         get_first_word('KEYWORD rest of the message'))
        self.assertEqual('', get_first_word(''))
        self.assertEqual('', get_first_word(None))

    def test_redis_from_config_str(self):
        try:
            fake_redis = redis_from_config("FAKE_REDIS")
        except ImportError, e:
            import_skip(e, 'redis')
        self.assertTrue(isinstance(fake_redis, FakeRedis))

    def test_redis_from_config_fake_redis(self):
        fake_redis = FakeRedis()
        try:
            self.assertEqual(redis_from_config(fake_redis), fake_redis)
        except ImportError, e:
            import_skip(e, 'redis')

    def get_resource(self, path, site):
        request = DummyRequest(postpath=path.split('/'), prepath=[])
        return site.getResourceFor(request)

    def test_build_web_site(self):
        resource_a = Resource()
        resource_b = Resource()
        site = build_web_site({
            'foo/a': resource_a,
            'bar/b': resource_b,
        })
        self.assertEqual(self.get_resource('foo/a', site), resource_a)
        self.assertEqual(self.get_resource('bar/b', site), resource_b)
        self.assertTrue(isinstance(site, LogFilterSite))

    def test_build_web_site_with_overlapping_paths(self):
        resource_a = Resource()
        resource_b = Resource()
        site = build_web_site({
            'foo/a': resource_a,
            'foo/b': resource_b,
        })
        self.assertEqual(self.get_resource('foo/a', site), resource_a)
        self.assertEqual(self.get_resource('foo/b', site), resource_b)
        self.assertTrue(isinstance(site, LogFilterSite))

    def test_build_web_site_with_custom_site_class(self):
        site = build_web_site({}, site_class=Site)
        self.assertTrue(isinstance(site, Site))
        self.assertFalse(isinstance(site, LogFilterSite))


class FakeHTTP10(Protocol):
    def dataReceived(self, data):
        self.transport.write(self.factory.response_body)
        self.transport.loseConnection()


class TestHttpUtils(VumiTestCase):
    def setUp(self):
        self.fake_http = FakeHttpServer(lambda r: self._render_request(r))
        self.url = "http://example.com:9980/"

    def set_render(self, f):
        def render(request):
            request.setHeader('Content-Type', 'text/plain')
            try:
                data = f(request)
                request.setResponseCode(http.OK)
            except Exception, err:
                data = str(err)
                request.setResponseCode(http.INTERNAL_SERVER_ERROR)
            return data
        self._render_request = render

    def set_async_render(self):
        def render_interrupt(request):
            reactor.callLater(0, d.callback, request)
            return NOT_DONE_YET
        d = Deferred()
        self.set_render(render_interrupt)
        return d

    @inlineCallbacks
    def make_real_webserver(self):
        """
        Construct a real webserver to test actual connectivity.
        """
        root = Resource()
        root.isLeaf = True
        root.render = lambda r: self._render_request(r)
        site_factory = Site(root)
        webserver = yield reactor.listenTCP(
            0, site_factory, interface='127.0.0.1')
        self.add_cleanup(webserver.loseConnection)
        addr = webserver.getHost()
        url = "http://%s:%s/" % (addr.host, addr.port)
        returnValue(url)

    def with_agent(self, f, *args, **kw):
        """
        Wrapper around http_request_full and friends that injects our fake
        connection's agent.
        """
        kw.setdefault('agent_class', self.fake_http.get_agent)
        return f(*args, **kw)

    @inlineCallbacks
    def test_http_request_to_localhost(self):
        """
        Make a request over the network (localhost) to check that we're getting
        a real agent by default.
        """
        url = yield self.make_real_webserver()
        self.set_render(lambda r: "Yay")
        data = yield http_request(url, '')
        self.assertEqual(data, "Yay")

    @inlineCallbacks
    def test_http_request_ok(self):
        self.set_render(lambda r: "Yay")
        data = yield self.with_agent(http_request, self.url, '')
        self.assertEqual(data, "Yay")

    @inlineCallbacks
    def test_http_request_err(self):
        def err(r):
            raise ValueError("Bad")
        self.set_render(err)
        data = yield self.with_agent(http_request, self.url, '')
        self.assertEqual(data, "Bad")

    @inlineCallbacks
    def test_http_request_full_to_localhost(self):
        """
        Make a request over the network (localhost) to check that we're getting
        a real agent by default.
        """
        url = yield self.make_real_webserver()
        self.set_render(lambda r: "Yay")
        request = yield http_request_full(url, '')
        self.assertEqual(request.delivered_body, "Yay")
        self.assertEqual(request.code, http.OK)

    @inlineCallbacks
    def test_http_request_with_custom_context_factory(self):
        self.set_render(lambda r: "Yay")
        agents = []

        ctxt = WebClientContextFactory()

        def stashing_factory(reactor, contextFactory=None):
            agent = self.fake_http.get_agent(
                reactor, contextFactory=contextFactory)
            agents.append(agent)
            return agent

        request = yield http_request_full(
            self.url, '', context_factory=ctxt, agent_class=stashing_factory)
        self.assertEqual(request.delivered_body, "Yay")
        self.assertEqual(request.code, http.OK)
        [agent] = agents
        self.assertEqual(agent.contextFactory, ctxt)

    @inlineCallbacks
    def test_http_request_full_drop(self):
        """
        If a connection drops, we get an appropriate exception.
        """
        got_request = self.set_async_render()
        got_data = self.with_agent(http_request_full, self.url, '')
        request = yield got_request
        request.setResponseCode(http.OK)
        request.write("Foo!")
        request.transport.loseConnection()

        yield self.assertFailure(got_data, ResponseFailed)

    @inlineCallbacks
    def test_http_request_full_ok(self):
        self.set_render(lambda r: "Yay")
        request = yield self.with_agent(http_request_full, self.url, '')
        self.assertEqual(request.delivered_body, "Yay")
        self.assertEqual(request.code, http.OK)

    @inlineCallbacks
    def test_http_request_full_headers(self):
        def check_ua(request):
            self.assertEqual('blah', request.getHeader('user-agent'))
            return "Yay"
        self.set_render(check_ua)

        request = yield self.with_agent(
            http_request_full, self.url, '', {'User-Agent': ['blah']})
        self.assertEqual(request.delivered_body, "Yay")
        self.assertEqual(request.code, http.OK)

        request = yield self.with_agent(
            http_request_full, self.url, '', {'User-Agent': 'blah'})
        self.assertEqual(request.delivered_body, "Yay")
        self.assertEqual(request.code, http.OK)

    @inlineCallbacks
    def test_http_request_full_err(self):
        def err(r):
            raise ValueError("Bad")
        self.set_render(err)
        request = yield self.with_agent(http_request_full, self.url, '')
        self.assertEqual(request.delivered_body, "Bad")
        self.assertEqual(request.code, http.INTERNAL_SERVER_ERROR)

    @inlineCallbacks
    def test_http_request_potential_data_loss(self):
        """
        In the absence of a Content-Length header or chunked transfer encoding,
        we need to swallow a PotentialDataLoss exception.
        """
        # We can't use Twisted's HTTP server, because that always does the
        # sensible thing. We also pretend to be HTTP 1.0 for simplicity.
        factory = Factory()
        factory.protocol = FakeHTTP10
        factory.response_body = (
            "HTTP/1.0 201 CREATED\r\n"
            "Date: Mon, 23 Jan 2012 15:08:47 GMT\r\n"
            "Server: Fake HTTP 1.0\r\n"
            "Content-Type: text/html; charset=utf-8\r\n"
            "\r\n"
            "Yay")
        fake_server = FakeServer(factory)
        agent_factory = lambda *a, **kw: ProxyAgentWithContext(
            fake_server.endpoint, *a, **kw)

        data = yield http_request(self.url, '', agent_class=agent_factory)
        self.assertEqual(data, "Yay")

    @inlineCallbacks
    def test_http_request_full_data_limit(self):
        self.set_render(lambda r: "Four")

        d = self.with_agent(http_request_full, self.url, '', data_limit=3)

        def check_response(reason):
            self.assertTrue(reason.check('vumi.utils.HttpDataLimitError'))
            self.assertEqual(reason.getErrorMessage(),
                             "More than 3 bytes received")

        d.addBoth(check_response)
        yield d

    @inlineCallbacks
    def test_http_request_full_ok_with_timeout_set(self):
        """
        If a request completes within the timeout, everything is happy.
        """
        clock = Clock()
        self.set_render(lambda r: "Yay")
        response = yield self.with_agent(
            http_request_full, self.url, '', timeout=30, reactor=clock)
        self.assertEqual(response.delivered_body, "Yay")
        self.assertEqual(response.code, http.OK)
        # Advance the clock past the timeout limit.
        clock.advance(30)

    @inlineCallbacks
    def test_http_request_full_drop_with_timeout_set(self):
        """
        If a request fails within the timeout, everything is happy(ish).
        """
        clock = Clock()
        d = self.set_async_render()
        got_data = self.with_agent(
            http_request_full, self.url, '', timeout=30, reactor=clock)
        request = yield d
        request.setResponseCode(http.OK)
        request.write("Foo!")
        request.transport.loseConnection()

        yield self.assertFailure(got_data, ResponseFailed)
        # Advance the clock past the timeout limit.
        clock.advance(30)

    def test_http_request_full_timeout_before_connect(self):
        """
        A request can time out before a connection is made.
        """
        clock = Clock()
        # Instead of setting a render function, we tell the server not to
        # accept connections.
        self.fake_http.fake_server.auto_accept = False
        d = self.with_agent(
            http_request_full, self.url, '', timeout=30, reactor=clock)
        self.assertNoResult(d)
        clock.advance(29)
        self.assertNoResult(d)
        clock.advance(1)
        self.failureResultOf(d, HttpTimeoutError)

    @inlineCallbacks
    def test_http_request_full_timeout_after_connect(self):
        """
        The client disconnects after the timeout if no data has been received
        from the server.
        """
        clock = Clock()
        request_started = self.set_async_render()
        client_done = self.with_agent(
            http_request_full, self.url, '', timeout=30, reactor=clock)
        yield request_started

        self.assertNoResult(client_done)
        clock.advance(29)
        self.assertNoResult(client_done)
        clock.advance(1)
        failure = self.failureResultOf(client_done, HttpTimeoutError)
        self.assertEqual(
            failure.getErrorMessage(), "Timeout while connecting")

    @inlineCallbacks
    def test_http_request_full_timeout_after_first_receive(self):
        """
        The client disconnects after the timeout even if some data has already
        been received.
        """
        clock = Clock()
        request_started = self.set_async_render()
        client_done = self.with_agent(
            http_request_full, self.url, '', timeout=30, reactor=clock)
        request = yield request_started
        request_done = request.notifyFinish()

        request.write("some data")
        clock.advance(1)
        yield wait0()
        self.assertNoResult(client_done)
        self.assertNoResult(request_done)

        clock.advance(28)
        self.assertNoResult(client_done)
        self.assertNoResult(request_done)
        clock.advance(1)
        failure = self.failureResultOf(client_done, HttpTimeoutError)
        self.assertEqual(
            failure.getErrorMessage(), "Timeout while receiving data")
        yield self.assertFailure(request_done, ConnectionDone)


class TestPkgResources(VumiTestCase):

    vumi_tests_path = os.path.dirname(__file__)

    def test_absolute_path(self):
        pkg = PkgResources("vumi.tests")
        self.assertEqual('/foo/bar', pkg.path('/foo/bar'))

    def test_relative_path(self):
        pkg = PkgResources("vumi.tests")
        self.assertEqual(os.path.join(self.vumi_tests_path, 'foo/bar'),
                         pkg.path('foo/bar'))


class TestStatusEdgeDetector(VumiTestCase):
    def test_status_not_change(self):
        '''If the status doesn't change, None should be returned.'''
        sed = StatusEdgeDetector()
        status1 = {
            'component': 'foo',
            'status': 'ok',
            'type': 'bar',
            'message': 'test'}
        self.assertEqual(sed.check_status(**status1), status1)

        status2 = {
            'component': 'foo',
            'status': 'ok',
            'type': 'bar',
            'message': 'another test'}
        self.assertEqual(sed.check_status(**status2), None)

    def test_status_change(self):
        '''If the status does change, the status should be returned.'''
        sed = StatusEdgeDetector()
        status1 = {
            'component': 'foo',
            'status': 'ok',
            'type': 'bar',
            'message': 'test'}
        self.assertEqual(sed.check_status(**status1), status1)

        status2 = {
            'component': 'foo',
            'status': 'degraded',
            'type': 'bar',
            'message': 'another test'}
        self.assertEqual(sed.check_status(**status2), status2)

    def test_components_separate(self):
        '''A state change in one component should not affect other
        components.'''
        sed = StatusEdgeDetector()
        comp1_status1 = {
            'component': 'foo',
            'status': 'ok',
            'type': 'bar',
            'message': 'test'}
        self.assertEqual(sed.check_status(**comp1_status1), comp1_status1)

        comp2_status1 = {
            'component': 'bar',
            'status': 'ok',
            'type': 'bar',
            'message': 'another test'}
        self.assertEqual(sed.check_status(**comp2_status1), comp2_status1)

        comp2_status2 = {
            'component': 'bar',
            'status': 'degraded',
            'type': 'bar',
            'message': 'another test'}
        self.assertEqual(sed.check_status(**comp2_status2), comp2_status2)

        comp1_status2 = {
            'component': 'foo',
            'status': 'ok',
            'type': 'bar',
            'message': 'test'}
        self.assertEqual(sed.check_status(**comp1_status2), None)

    def test_type_change(self):
        '''A change in status type should result in the status being
        returned.'''
        sed = StatusEdgeDetector()
        status1 = {
            'component': 'foo',
            'status': 'ok',
            'type': 'bar',
            'message': 'test'}
        self.assertEqual(sed.check_status(**status1), status1)

        status2 = {
            'component': 'foo',
            'status': 'ok',
            'type': 'baz',
            'message': 'test'}
        self.assertEqual(sed.check_status(**status2), status2)

# ---- vumi/tests/fake_connection.py ----
from twisted.internet.defer import Deferred, DeferredQueue
from twisted.internet.error import (
    ConnectionRefusedError, ConnectionDone, ConnectionAborted)
from twisted.internet.interfaces import IStreamClientEndpoint
from twisted.internet.protocol import Protocol, ServerFactory
from twisted.internet.task import deferLater
from twisted.protocols.loopback import loopbackAsync
from twisted.python.failure import Failure
from twisted.web.client import ProxyAgent
from twisted.web.resource import Resource
from twisted.web.server import Site
from zope.interface import implementer


def wait0(r=None):
    """
    Wait zero seconds to give the reactor a chance to work.

    Returns its (optional) argument, so it's useful as a callback.
    """
    from twisted.internet import reactor
    return deferLater(reactor, 0, lambda: r)
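
# Illustrative usage (not part of the original module): wait0() is typically
# chained onto a Deferred so that assertions only run after the reactor has
# had a tick to process pending I/O. The names here are hypothetical:
#
#     d = do_something_asynchronous()
#     d.addCallback(wait0)
#     d.addCallback(lambda result: assert_side_effects_visible(result))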


class ProtocolDouble(Protocol):
    """
    A stand-in protocol for manually driving one side of a connection.
    """

    def __init__(self):
        self.received = b""
        self.disconnected_reason = None

    def dataReceived(self, data):
        self.received += data

    def connectionLost(self, reason):
        self.connected = False
        self.disconnected_reason = reason

    def write(self, data):
        """
        Write some bytes and allow the reactor to send them.
        """
        self.transport.write(data)
        return wait0()


class FakeServer(object):
    """
    Fake server container for testing client/server interactions.
    """

    def __init__(self, server_factory, auto_accept=True, on_connect=None):
        self.server_factory = server_factory
        self.auto_accept = auto_accept
        self.connection_queue = DeferredQueue()
        self.on_connect = on_connect

    # Public API.

    @classmethod
    def for_protocol(cls, protocol, *args, **kw):
        factory = ServerFactory.forProtocol(protocol)
        return cls(factory, *args, **kw)

    @property
    def endpoint(self):
        """
        Get an endpoint that connects clients to this server.
        """
        return FakeServerEndpoint(self)

    def await_connection(self):
        """
        Wait for a client to start connecting, and then return a
        :class:`FakeConnection` object.
        """
        return self.connection_queue.get()

    # Internal stuff.

    def _handle_connection(self):
        conn = FakeConnection(self)
        if self.on_connect is not None:
            conn._connected_d.addCallback(lambda _: self.on_connect(conn))
        self.connection_queue.put(conn)
        if self.auto_accept:
            conn.accept_connection()
        return conn._accept_d
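
# Illustrative sketch (not part of the original module): with auto_accept
# disabled, a test can hold a client in the connecting state and then accept
# (or reject) it explicitly. ``server_factory`` and ``client_factory`` are
# assumed to be ordinary Twisted factories:
#
#     fake_server = FakeServer(server_factory, auto_accept=False)
#     client_d = fake_server.endpoint.connect(client_factory)
#     conn = yield fake_server.await_connection()
#     yield conn.accept_connection()  # or conn.reject_connection()
#     client = yield client_d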


class FakeConnection(object):
    """
    Fake server connection.
    """

    def __init__(self, server):
        self.server = server
        self.client_protocol = None
        self.server_protocol = None

        self._accept_d = Deferred()
        self._connected_d = Deferred()
        self._finished_d = Deferred()

    @property
    def connected(self):
        return self._connected_d.called and not self._finished_d.called

    @property
    def pending(self):
        return not self._accept_d.called

    def await_connected(self):
        """
        Wait for a client to finish connecting.
        """
        return self._connected_d

    def accept_connection(self):
        """
        Accept a pending connection.
        """
        assert self.pending, "Connection is not pending."
        self.server_protocol = self.server.server_factory.buildProtocol(None)
        self._accept_d.callback(
            FakeServerProtocolWrapper(self, self.server_protocol))
        return self.await_connected()

    def reject_connection(self, reason=None):
        """
        Reject a pending connection.
        """
        assert self.pending, "Connection is not pending."
        if reason is None:
            reason = ConnectionRefusedError()
        self._accept_d.errback(reason)

    def await_finished(self):
        """
        Wait for both sides of the connection to close.
        """
        return self._finished_d

    # Internal stuff.

    def _finish_connecting(self, client_protocol, finished_d):
        self.client_protocol = client_protocol
        finished_d.chainDeferred(self._finished_d)
        self._connected_d.callback(None)


@implementer(IStreamClientEndpoint)
class FakeServerEndpoint(object):
    """
    This endpoint connects a client directly to a FakeServer.
    """
    def __init__(self, server):
        self.server = server

    def connect(self, protocolFactory):
        d = self.server._handle_connection()
        return d.addCallback(self._make_connection, protocolFactory)

    def _make_connection(self, server, protocolFactory):
        client = protocolFactory.buildProtocol(None)
        patch_protocol_for_agent(client)
        finished_d = loopbackAsync(server, client)
        server.finish_connecting(client, finished_d)
        return client


def patch_protocol_for_agent(protocol):
    """
    Patch the protocol's makeConnection and connectionLost methods to make the
    protocol and its transport behave more like what `Agent` expects.

    While `Agent` is the driving force behind this, other clients and servers
    will no doubt have similar requirements.
    """
    old_makeConnection = protocol.makeConnection
    old_connectionLost = protocol.connectionLost

    def new_makeConnection(transport):
        patch_transport_fake_push_producer(transport)
        patch_transport_abortConnection(transport, protocol)
        return old_makeConnection(transport)

    def new_connectionLost(reason):
        # Replace ConnectionDone with ConnectionAborted if we aborted.
        if protocol._fake_connection_aborted and reason.check(ConnectionDone):
            reason = Failure(ConnectionAborted())
        return old_connectionLost(reason)

    protocol.makeConnection = new_makeConnection
    protocol.connectionLost = new_connectionLost
    protocol._fake_connection_aborted = False


def patch_if_missing(obj, name, method):
    """
    Patch a method onto an object if it isn't already there.
    """
    setattr(obj, name, getattr(obj, name, method))


def patch_transport_fake_push_producer(transport):
    """
    Patch the three methods belonging to IPushProducer onto the transport if it
    doesn't already have them. (`Agent` assumes its transport has these.)
    """
    patch_if_missing(transport, 'pauseProducing', lambda: None)
    patch_if_missing(transport, 'resumeProducing', lambda: None)
    patch_if_missing(transport, 'stopProducing', transport.loseConnection)


def patch_transport_abortConnection(transport, protocol):
    """
    Patch abortConnection() on the transport or add it if it doesn't already
    exist (`Agent` assumes its transport has this).

    The patched method sets an internal flag recording the abort and then calls
    the original method (if it existed) or transport.loseConnection (if it
    didn't).
    """
    _old_abortConnection = getattr(
        transport, 'abortConnection', transport.loseConnection)

    def abortConnection():
        protocol._fake_connection_aborted = True
        _old_abortConnection()

    transport.abortConnection = abortConnection


class FakeServerProtocolWrapper(Protocol):
    """
    Wrapper around a server protocol to track connection state.
    """

    def __init__(self, connection, protocol):
        self.connection = connection
        patch_protocol_for_agent(protocol)
        self.protocol = protocol

    def makeConnection(self, transport):
        Protocol.makeConnection(self, transport)
        return self.protocol.makeConnection(transport)

    def connectionLost(self, reason):
        return self.protocol.connectionLost(reason)

    def dataReceived(self, data):
        return self.protocol.dataReceived(data)

    def finish_connecting(self, client, finished_d):
        self.connection._finish_connecting(client, finished_d)


class FakeHttpServer(object):
    """
    HTTP server built on top of FakeServer.
    """

    def __init__(self, handler):
        site_factory = Site(HandlerResource(handler))
        self.fake_server = FakeServer(site_factory)

    @property
    def endpoint(self):
        return self.fake_server.endpoint

    def get_agent(self, reactor=None, contextFactory=None):
        """
        Returns an IAgent that makes requests to this fake server.
        """
        return ProxyAgentWithContext(
            self.endpoint, reactor=reactor, contextFactory=contextFactory)
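
# Illustrative usage (not part of the original module): tests can route HTTP
# requests to the handler through the agent without touching the network:
#
#     fake_http = FakeHttpServer(lambda request: "hello")
#     agent = fake_http.get_agent()
#     response = yield agent.request("GET", "http://example.com/")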


class HandlerResource(Resource):
    isLeaf = True

    def __init__(self, handler):
        Resource.__init__(self)
        self.handler = handler

    def render_GET(self, request):
        return self.handler(request)

    def render_POST(self, request):
        return self.handler(request)

    def render_PUT(self, request):
        return self.handler(request)


class ProxyAgentWithContext(ProxyAgent):
    def __init__(self, endpoint, reactor=None, contextFactory=None):
        self.contextFactory = contextFactory  # To assert on in tests.
        super(ProxyAgentWithContext, self).__init__(endpoint, reactor=reactor)

# ---- vumi/tests/test_fake_connection.py ----
from twisted.internet.defer import inlineCallbacks
from twisted.internet.error import (
    ConnectionRefusedError, UnknownHostError, ConnectionDone,
    ConnectionAborted)
from twisted.internet.protocol import Protocol, ClientFactory, ServerFactory
from twisted.trial.unittest import TestCase
from twisted.web.client import readBody
from twisted.web.iweb import IAgent
from twisted.web.server import NOT_DONE_YET

from vumi.tests.fake_connection import (
    FakeServer, wait0, ProtocolDouble, FakeHttpServer)


class DummyServerProtocol(ProtocolDouble):
    side = "server"


class DummyClientProtocol(ProtocolDouble):
    side = "client"


class TestFakeConnection(TestCase):
    """
    Tests for FakeConnection and friends.
    """

    def setUp(self):
        self.client_factory = ClientFactory.forProtocol(DummyClientProtocol)
        self.server_factory = ServerFactory.forProtocol(DummyServerProtocol)

    def connect_client(self, fake_server):
        """
        Create a client connection to a fake server.

        :returns: (connection, client)
        """
        conn_d = fake_server.await_connection()
        self.assertNoResult(conn_d)  # We don't want an existing connection.
        client_d = fake_server.endpoint.connect(self.client_factory)
        client = self.successResultOf(client_d)
        conn = self.successResultOf(conn_d)
        self.assert_connected(conn, client)
        return (conn, client)

    def assert_pending(self, conn):
        """
        Assert that a connection is not yet connected.
        """
        self.assertEqual(conn.client_protocol, None)
        self.assertEqual(conn.server_protocol, None)
        self.assertEqual(conn.connected, False)
        self.assertEqual(conn.pending, True)
        self.assertNoResult(conn._accept_d)
        self.assertNoResult(conn._connected_d)
        self.assertNoResult(conn._finished_d)

    def assert_connected(self, conn, client):
        """
        Assert that a connection is connected to a client.
        """
        self.assertIsInstance(conn.client_protocol, DummyClientProtocol)
        self.assertEqual(conn.client_protocol.side, "client")
        self.assertEqual(conn.client_protocol.connected, True)
        self.assertEqual(conn.client_protocol.disconnected_reason, None)
        self.assertIsInstance(conn.server_protocol, DummyServerProtocol)
        self.assertEqual(conn.server_protocol.side, "server")
        self.assertEqual(conn.server_protocol.connected, True)
        self.assertEqual(conn.server_protocol.disconnected_reason, None)
        self.assertEqual(conn.connected, True)
        self.assertEqual(conn.pending, False)
        self.successResultOf(conn._accept_d)
        self.successResultOf(conn._connected_d)
        self.assertNoResult(conn._finished_d)
        self.assertEqual(conn.client_protocol, client)

    def assert_connection_rejected(self, conn):
        self.assertEqual(conn.client_protocol, None)
        self.assertEqual(conn.server_protocol, None)
        self.assertEqual(conn.connected, False)
        self.assertEqual(conn.pending, False)
        self.assertNoResult(conn._connected_d)
        self.assertNoResult(conn._finished_d)

    def assert_disconnected(self, conn, client_reason=ConnectionDone,
                            server_reason=ConnectionDone):
        self.assertEqual(conn.client_protocol.connected, False)
        self.assertEqual(conn.server_protocol.connected, False)
        client_reason_f = conn.client_protocol.disconnected_reason
        server_reason_f = conn.server_protocol.disconnected_reason
        self.assertEqual(client_reason_f.check(client_reason), client_reason)
        self.assertEqual(server_reason_f.check(server_reason), server_reason)

    def test_client_connect_auto(self):
        """
        A server will automatically accept client connections by default.
        """
        fake_server = FakeServer(self.server_factory)
        conn_d = fake_server.await_connection()
        self.assertNoResult(conn_d)

        client_d = fake_server.endpoint.connect(self.client_factory)
        client = self.successResultOf(client_d)
        conn = self.successResultOf(conn_d)
        self.assert_connected(conn, client)

    def test_accept_connection(self):
        """
        Connections can be accepted manually if desired.
        """
        fake_server = FakeServer(self.server_factory, auto_accept=False)
        conn_d = fake_server.await_connection()
        self.assertNoResult(conn_d)

        client_d = fake_server.endpoint.connect(self.client_factory)
        self.assertNoResult(client_d)
        conn = self.successResultOf(conn_d)
        self.assert_pending(conn)

        connected_d = conn.await_connected()
        self.assertNoResult(connected_d)

        accepted_d = conn.accept_connection()
        self.successResultOf(accepted_d)
        self.successResultOf(connected_d)
        client = self.successResultOf(client_d)
        self.assert_connected(conn, client)

    def test_client_connect_hook(self):
        """
        An on_connect function can be passed to the server to be called
        whenever a connection is made.
        """
        def on_connect(conn):
            conn.hook_was_called = True
            conn.client_id_from_hook = id(conn.client_protocol)
            conn.server_id_from_hook = id(conn.server_protocol)

        fake_server = FakeServer(self.server_factory, on_connect=on_connect)
        conn_d = fake_server.await_connection()
        self.assertNoResult(conn_d)

        client_d = fake_server.endpoint.connect(self.client_factory)
        client = self.successResultOf(client_d)
        conn = self.successResultOf(conn_d)
        self.assert_connected(conn, client)
        self.assertEqual(conn.hook_was_called, True)
        self.assertEqual(conn.client_id_from_hook, id(client))
        self.assertEqual(conn.server_id_from_hook, id(conn.server_protocol))

    def test_reject_connection(self):
        """
        Connections can be rejected manually if desired.
        """
        fake_server = FakeServer(self.server_factory, auto_accept=False)
        conn_d = fake_server.await_connection()
        self.assertNoResult(conn_d)

        client_d = fake_server.endpoint.connect(self.client_factory)
        self.assertNoResult(client_d)
        conn = self.successResultOf(conn_d)
        self.assert_pending(conn)

        connected_d = conn.await_connected()
        self.assertNoResult(connected_d)

        conn.reject_connection()
        self.assertNoResult(connected_d)
        self.failureResultOf(client_d, ConnectionRefusedError)
        self.assert_connection_rejected(conn)

    def test_reject_connection_with_reason(self):
        """
        Connections can be rejected with a custom reason.
        """
        fake_server = FakeServer(self.server_factory, auto_accept=False)
        conn_d = fake_server.await_connection()
        self.assertNoResult(conn_d)

        client_d = fake_server.endpoint.connect(self.client_factory)
        self.assertNoResult(client_d)
        conn = self.successResultOf(conn_d)
        self.assert_pending(conn)

        connected_d = conn.await_connected()
        self.assertNoResult(connected_d)

        conn.reject_connection(UnknownHostError())
        self.assertNoResult(connected_d)
        self.failureResultOf(client_d, UnknownHostError)
        self.assert_connection_rejected(conn)

    @inlineCallbacks
    def test_client_disconnect(self):
        """
        If the client disconnects, the server's connection is also lost.
        """
        fake_server = FakeServer(self.server_factory)
        conn, client = self.connect_client(fake_server)
        finished_d = conn.await_finished()
        self.assertNoResult(finished_d)

        # The disconnection gets scheduled, but doesn't actually happen until
        # the next reactor tick.
        client.transport.loseConnection()
        self.assert_connected(conn, client)
        self.assertNoResult(finished_d)

        # Allow the reactor to run so the disconnection gets processed.
        yield wait0()
        self.assert_disconnected(conn)
        self.successResultOf(finished_d)

    @inlineCallbacks
    def test_server_disconnect(self):
        """
        If the server disconnects, the client's connection is also lost.
        """
        fake_server = FakeServer(self.server_factory)
        conn, client = self.connect_client(fake_server)
        finished_d = conn.await_finished()
        self.assertNoResult(finished_d)

        # The disconnection gets scheduled, but doesn't actually happen until
        # the next reactor tick.
        conn.server_protocol.transport.loseConnection()
        self.assert_connected(conn, client)
        self.assertNoResult(finished_d)

        # Allow the reactor to run so the disconnection gets processed.
        yield wait0()
        self.assert_disconnected(conn)
        self.successResultOf(finished_d)

    @inlineCallbacks
    def test_client_abort(self):
        """
        If the client aborts, the server's connection is also lost.
        """
        fake_server = FakeServer(self.server_factory)
        conn, client = self.connect_client(fake_server)
        finished_d = conn.await_finished()
        self.assertNoResult(finished_d)

        # The disconnection gets scheduled, but doesn't actually happen until
        # the next reactor tick.
        client.transport.abortConnection()
        self.assert_connected(conn, client)
        self.assertNoResult(finished_d)

        # Allow the reactor to run so the disconnection gets processed.
        yield wait0()
        self.assert_disconnected(conn, client_reason=ConnectionAborted)
        self.successResultOf(finished_d)

    @inlineCallbacks
    def test_server_abort(self):
        """
        If the server aborts, the client's connection is also lost.
        """
        fake_server = FakeServer(self.server_factory)
        conn, client = self.connect_client(fake_server)
        finished_d = conn.await_finished()
        self.assertNoResult(finished_d)

        # The disconnection gets scheduled, but doesn't actually happen until
        # the next reactor tick.
        conn.server_protocol.transport.abortConnection()
        self.assert_connected(conn, client)
        self.assertNoResult(finished_d)

        # Allow the reactor to run so the disconnection gets processed.
        yield wait0()
        self.assert_disconnected(conn, server_reason=ConnectionAborted)
        self.successResultOf(finished_d)

    @inlineCallbacks
    def test_send_client_to_server(self):
        """
        Bytes can be sent from the client to the server.
        """
        fake_server = FakeServer(self.server_factory)
        conn, client = self.connect_client(fake_server)
        server = conn.server_protocol
        self.assertEqual(server.received, b"")

        # Bytes sent, but not received until reactor runs.
        d = client.write(b"foo")
        self.assertEqual(server.received, b"")
        yield d
        self.assertEqual(server.received, b"foo")

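        # Send two sets of bytes at once, waiting for the second.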
        client.write(b"bar")
        d = client.write(b"baz")
        self.assertEqual(server.received, b"foo")
        yield d
        self.assertEqual(server.received, b"foobarbaz")

    @inlineCallbacks
    def test_send_server_to_client(self):
        """
        Bytes can be sent from the server to the client.
        """
        fake_server = FakeServer(self.server_factory)
        conn, client = self.connect_client(fake_server)
        server = conn.server_protocol
        self.assertEqual(client.received, b"")

        # Bytes sent, but not received until reactor runs.
        d = server.write(b"foo")
        self.assertEqual(client.received, b"")
        yield d
        self.assertEqual(client.received, b"foo")

        # Send two sets of bytes at once, waiting for the second.
        server.write(b"bar")
        d = server.write(b"baz")
        self.assertEqual(client.received, b"foo")
        yield d
        self.assertEqual(client.received, b"foobarbaz")

    @inlineCallbacks
    def test_server_for_protocol(self):
        """
        A FakeServer can also be constructed from a protocol class instead
        of a factory.
        """
        class MyProtocol(Protocol):
            pass
        fake_server = FakeServer.for_protocol(MyProtocol)
        self.assertEqual(fake_server.server_factory.protocol, MyProtocol)
        self.assertNotEqual(fake_server.server_factory.protocol, Protocol)

        yield fake_server.endpoint.connect(self.client_factory)
        conn = yield fake_server.await_connection()
        self.assertIsInstance(conn.server_protocol, MyProtocol)


class TestFakeHttpServer(TestCase):
    @inlineCallbacks
    def assert_response(self, response, code, body):
        self.assertEqual(response.code, code)
        response_body = yield readBody(response)
        self.assertEqual(response_body, body)

    @inlineCallbacks
    def test_simple_request(self):
        """
        FakeHttpServer can handle a simple HTTP request using the IAgent
        provider it supplies.
        """
        requests = []
        fake_http = FakeHttpServer(lambda req: requests.append(req) or "hi")
        agent = fake_http.get_agent()
        self.assertTrue(IAgent.providedBy(agent))
        response = yield agent.request("GET", "http://example.com/hello")
        # We got a valid request and returned a valid response.
        [request] = requests
        self.assertEqual(request.method, "GET")
        self.assertEqual(request.path, "http://example.com/hello")
        yield self.assert_response(response, 200, "hi")

    @inlineCallbacks
    def test_async_response(self):
        """
        FakeHttpServer supports asynchronous responses.
        """
        requests = []
        fake_http = FakeHttpServer(
            lambda req: requests.append(req) or NOT_DONE_YET)
        response1_d = fake_http.get_agent().request(
            "GET", "http://example.com/hello/1")
        response2_d = fake_http.get_agent().request(
            "HEAD", "http://example.com/hello/2")

        # Wait for the requests to arrive.
        yield wait0()
        [request1, request2] = requests
        self.assertNoResult(response1_d)
        self.assertNoResult(response2_d)
        self.assertEqual(request1.method, "GET")
        self.assertEqual(request1.path, "http://example.com/hello/1")
        self.assertEqual(request2.method, "HEAD")
        self.assertEqual(request2.path, "http://example.com/hello/2")

        # Send a response to the second request.
        request2.finish()
        response2 = yield response2_d
        self.assertNoResult(response1_d)
        yield self.assert_response(response2, 200, "")

        # Send a response to the first request.
        request1.write("Thank you for waiting.")
        request1.finish()
        response1 = yield response1_d
        yield self.assert_response(response1, 200, "Thank you for waiting.")

    @inlineCallbacks
    def test_POST_request(self):
        """
        FakeHttpServer can handle a POST request.
        """
        requests = []
        fake_http = FakeHttpServer(lambda req: requests.append(req) or "hi")
        agent = fake_http.get_agent()
        self.assertTrue(IAgent.providedBy(agent))
        response = yield agent.request("POST", "http://example.com/hello")
        # We got a valid request and returned a valid response.
        [request] = requests
        self.assertEqual(request.method, "POST")
        self.assertEqual(request.path, "http://example.com/hello")
        yield self.assert_response(response, 200, "hi")

    @inlineCallbacks
    def test_PUT_request(self):
        """
        FakeHttpServer can handle a PUT request.
        """
        requests = []
        fake_http = FakeHttpServer(lambda req: requests.append(req) or "hi")
        agent = fake_http.get_agent()
        self.assertTrue(IAgent.providedBy(agent))
        response = yield agent.request("PUT", "http://example.com/hello")
        # We got a valid request and returned a valid response.
        [request] = requests
        self.assertEqual(request.method, "PUT")
        self.assertEqual(request.path, "http://example.com/hello")
        yield self.assert_response(response, 200, "hi")

# ---- vumi/transports/base.py ----
# -*- test-case-name: vumi.transports.tests.test_base -*-

"""
Common infrastructure for transport workers.

This is likely to get used heavily fast, so try to get your changes in early.
"""

from twisted.internet.defer import maybeDeferred, inlineCallbacks, succeed

from vumi.config import ConfigText, ConfigBool
from vumi.message import TransportUserMessage, TransportEvent, TransportStatus
from vumi.worker import BaseWorker, then_call
from vumi.transports.failures import FailureMessage


class TransportConfig(BaseWorker.CONFIG_CLASS):
    """Base config definition for transports.

    You should subclass this and add transport-specific fields.
    """

    transport_name = ConfigText(
        "The name this transport instance will use to create its queues.",
        required=True, static=True)
    publish_status = ConfigBool(
        "Whether status messages should be published by the transport",
        default=False, static=True)
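
# Illustrative sketch (not part of the original module): in a worker's YAML
# config, the static options above would typically look something like:
#
#     transport_name: my_transport
#     publish_status: true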


class Transport(BaseWorker):
    """
    Base class for transport workers.

    The following attributes are available for subclasses to control behaviour:

    * :attr:`start_message_consumer` -- Set to ``False`` if the message
      consumer should not be started. The subclass is responsible for starting
      it in this case.
    """

    SUPPRESS_FAILURE_EXCEPTIONS = True
    CONFIG_CLASS = TransportConfig

    transport_name = None
    start_message_consumer = True

    @property
    def status_connector_name(self):
        return "%s.status" % (self.transport_name,)

    def _validate_config(self):
        config = self.get_static_config()
        self.transport_name = config.transport_name
        self._should_publish_status = config.publish_status
        self.validate_config()

    @inlineCallbacks
    def setup_connectors(self):
        yield self.setup_publish_status_connector(self.status_connector_name)

        ro_connector = yield self.setup_ro_connector(self.transport_name)
        self.add_outbound_handler(
            self.handle_outbound_message, connector=ro_connector)

    def setup_worker(self):
        """
        Set up basic transport worker stuff.

        You shouldn't have to override this in subclasses.
        """
        d = self.setup_failure_publisher()
        then_call(d, self.setup_transport)
        if self.start_message_consumer:
            then_call(d, self.unpause_connectors)
        return d

    def teardown_worker(self):
        d = self.pause_connectors()
        d.addCallback(lambda r: self.teardown_transport())
        return d

    def setup_transport(self):
        """
        All transport-specific setup should happen in here.

        Subclasses should override this method to perform extra setup.
        """
        pass

    def teardown_transport(self):
        """
        Clean-up of setup done in setup_transport should happen here.
        """
        pass

    def setup_failure_publisher(self):
        d = self.publish_to('%s.failures' % (self.transport_name,))

        def cb(publisher):
            self.failure_publisher = publisher

        return d.addCallback(cb)

    def send_failure(self, message, exception, traceback):
        """Send a failure report."""
        try:
            failure_code = getattr(exception, "failure_code",
                                   FailureMessage.FC_UNSPECIFIED)
            failure_msg = FailureMessage(
                message=message.payload, failure_code=failure_code,
                reason=traceback)
            connector = self.connectors[self.transport_name]
            d = connector._middlewares.apply_publish(
                "failure", failure_msg, self.transport_name)
            d.addCallback(self.failure_publisher.publish_message)
            d.addCallback(lambda _f: self.failure_published())
        except:
            self.log.err(
                "Error publishing failure: %s, %s, %s"
                % (message, exception, traceback))
            raise
        return d

    def failure_published(self):
        pass

    def publish_message(self, **kw):
        """
        Publish a :class:`TransportUserMessage` message.

        Some default parameters are handled, so subclasses don't have
        to provide a lot of boilerplate.
        """
        kw.setdefault('transport_name', self.transport_name)
        kw.setdefault('transport_metadata', {})
        msg = TransportUserMessage(**kw)
        return self.connectors[self.transport_name].publish_inbound(msg)

    def publish_event(self, **kw):
        """
        Publish a :class:`TransportEvent` message.

        Some default parameters are handled, so subclasses don't have
        to provide a lot of boilerplate.
        """
        kw.setdefault('transport_name', self.transport_name)
        kw.setdefault('transport_metadata', {})
        event = TransportEvent(**kw)
        return self.connectors[self.transport_name].publish_event(event)

    def publish_ack(self, user_message_id, sent_message_id, **kw):
        """
        Helper method for publishing an ``ack`` event.
        """
        return self.publish_event(user_message_id=user_message_id,
                                  sent_message_id=sent_message_id,
                                  event_type='ack', **kw)

    def publish_nack(self, user_message_id, reason, **kw):
        """
        Helper method for publishing a ``nack`` event.
        """
        return self.publish_event(user_message_id=user_message_id,
                                  nack_reason=reason, event_type='nack', **kw)

    def publish_delivery_report(self, user_message_id, delivery_status, **kw):
        """
        Helper method for publishing a ``delivery_report`` event.
        """
        return self.publish_event(user_message_id=user_message_id,
                                  delivery_status=delivery_status,
                                  event_type='delivery_report', **kw)

    def publish_status(self, **kw):
        """
        Helper method for publishing a status message.
        """
        msg = TransportStatus(**kw)

        if self._should_publish_status:
            conn = self.connectors[self.status_connector_name]
            return conn.publish_status(msg)
        else:
            self.log.debug(
                'Status publishing disabled for transport %r, ignoring '
                'status %r' % (self.transport_name, msg))
            return succeed(msg)

    def _send_failure_eb(self, f, message):
        self.send_failure(message, f.value, f.getTraceback())
        self.log.err(f)
        if self.SUPPRESS_FAILURE_EXCEPTIONS:
            return None
        return f

    def _make_message_processor(self, handler):
        def processor(message):
            d = maybeDeferred(handler, message)
            d.addErrback(self._send_failure_eb, message)
            return d

        return processor

    def add_outbound_handler(self, handler, endpoint_name=None,
                             connector=None):
        if connector is None:
            connector = self.connectors[self.transport_name]

        processor = self._make_message_processor(handler)
        connector.set_outbound_handler(processor, endpoint_name=endpoint_name)

    def handle_outbound_message(self, message):
        """
        This must be overridden to read outbound messages and do the right
        thing with them.
        """
        raise NotImplementedError()

    @staticmethod
    def generate_message_id():
        """
        Generate a message id.
        """
        return TransportUserMessage.generate_id()
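
# Illustrative sketch (not part of the original module): a minimal subclass
# wiring handle_outbound_message up to the publish_* helpers. The
# immediate-ack echo behaviour is purely for illustration:
#
#     class EchoTransport(Transport):
#         def handle_outbound_message(self, message):
#             d = self.publish_ack(
#                 user_message_id=message['message_id'],
#                 sent_message_id=message['message_id'])
#             d.addCallback(lambda _: self.publish_message(
#                 message_id=self.generate_message_id(),
#                 to_addr=message['from_addr'],
#                 from_addr=message['to_addr'],
#                 content=message['content'],
#                 transport_type=message['transport_type']))
#             return d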

# ---- vumi/transports/scheduler.py ----
# -*- test-case-name: vumi.transports.tests.test_scheduler -*-
import time
import iso8601
import pytz
import json
from datetime import datetime
from uuid import uuid4
import warnings

from twisted.internet.defer import inlineCallbacks
from twisted.internet.task import LoopingCall

from vumi import message


warnings.warn("vumi.transport.scheduler is deprecated. A replacement is coming"
              " soon.", category=DeprecationWarning)


class Scheduler(object):
    """
    Base class for stuff that needs to be published to a given queue
    at a given time.
    """

    def __init__(self, redis, callback, prefix='scheduler',
                    granularity=5, delivery_period=3, json_encoder=None,
                    json_decoder=None):
        self.r_server = redis
        self.r_prefix = prefix
        self.granularity = granularity
        self.delivery_period = delivery_period
        self._scheduled_timestamps_key = self.r_key("scheduled_timestamps")
        self.callback = callback
        self.json_encoder = json_encoder or message.JSONMessageEncoder
        self.json_decoder = json_decoder or message.date_time_decoder
        self.loop = LoopingCall(self.deliver_scheduled)

    @property
    def is_running(self):
        return self.loop.running

    def start(self):
        if not self.loop.running:
            self.loop.start(self.delivery_period, now=True)

    def stop(self):
        if self.loop.running:
            self.loop.stop()

    def r_key(self, key):
        """
        Prefix ``key`` with a worker-specific string.
        """
        return "#".join((self.r_prefix, key))

    def scheduled_key(self):
        """
        Construct a unique scheduled key.
        """
        timestamp = datetime.utcnow()
        unique_id = uuid4().get_hex()
        timestamp = timestamp.isoformat().split('.')[0]
        return self.r_key(".".join(("scheduled", timestamp, unique_id)))

    def get_scheduled(self, scheduled_key):
        return self.r_server.hgetall(scheduled_key)

    def get_next_write_timestamp(self, delta, now):
        now = int(now)
        return self.get_time_bucket(now + delta)

    def get_time_bucket(self, timestamp):
        timestamp += self.granularity - (timestamp % self.granularity)
        return datetime.utcfromtimestamp(timestamp).isoformat().split('.')[0]
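    # Illustrative rounding (values hypothetical): with granularity=5,
    # timestamps 1000..1004 all bucket to 1005, while an exact multiple
    # such as 1005 is bumped to the next bucket (1010), so keys scheduled
    # in the same window share a single bucket.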

    def get_read_timestamp(self, now):
        now = int(now)
        timestamp = datetime.utcfromtimestamp(now).replace(tzinfo=pytz.UTC)
        next_timestamp = self.r_server.zrange(self._scheduled_timestamps_key,
                                                0, 0)
        if next_timestamp:
            if iso8601.parse_date(next_timestamp[0]) <= timestamp:
                return next_timestamp[0]
        return None

    def get_next_read_timestamp(self):
        return self.get_read_timestamp(time.time())

    def get_scheduled_key(self, time):
        timestamp = self.get_time_bucket(time)
        bucket_key = self.r_key("scheduled_keys." + timestamp)
        # key of message to be delivered
        scheduled_key = self.r_server.spop(bucket_key)
        # if the set is empty, remove the timestamp entry from the
        # scheduled timestamps key
        if self.r_server.scard(bucket_key) < 1:
            self.r_server.zrem(self._scheduled_timestamps_key, timestamp)
        return scheduled_key

    def schedule(self, delta, payload, now=None):
        """
        Store the payload in Redis and call `self.callback` with it
        `delta` seconds after `now`.

        :param delta: the number of seconds to wait before delivery
        :param payload: the payload to pass to `self.callback`
        :param now: timestamp (in seconds since the epoch) to count
                    `delta` from

        If ``now`` is ``None`` then it will default to ``time.time()``
        """
        # Encode the payload first so that we blow up before any keys
        # are set should the content not be JSON encodable.
        payload_json = json.dumps(payload, cls=self.json_encoder)
        if not now:
            now = int(time.time())

        key = self.scheduled_key()
        self.add_to_scheduled_set(key)
        bucket_key = self.store_scheduled(key, delta, now)
        self.r_server.hmset(key, {
            'payload': payload_json,
            'scheduled_at': datetime.utcnow().isoformat(),
            'bucket_key': bucket_key,
        })
        return key, bucket_key

    def add_to_scheduled_set(self, key):
        self.r_server.sadd(self.r_key("scheduled_keys"), key)

    def store_scheduled(self, scheduled_key, delta, now):
        timestamp = self.get_next_write_timestamp(delta, now)
        bucket_key = self.r_key("scheduled_keys." + timestamp)
        self.r_server.sadd(bucket_key, scheduled_key)
        self.store_read_timestamp(timestamp)
        return bucket_key

    def store_read_timestamp(self, timestamp):
        score = time.mktime(time.strptime(timestamp, "%Y-%m-%dT%H:%M:%S"))
        self.r_server.zadd(self._scheduled_timestamps_key, **{
            timestamp: score
        })

    def get_all_scheduled_keys(self):
        return self.r_server.smembers(self.r_key("scheduled_keys"))

    @inlineCallbacks
    def deliver_scheduled(self, _time=None):
        _time = _time or int(time.time())
        while True:
            scheduled_key = self.get_scheduled_key(_time - self.granularity)
            if not scheduled_key:
                return
            scheduled_data = self.get_scheduled(scheduled_key)
            scheduled_at = scheduled_data['scheduled_at']
            payload = json.loads(scheduled_data['payload'],
                                    object_hook=self.json_decoder)
            yield self.callback(scheduled_at, payload)
            self.clear_scheduled(scheduled_key)

    def clear_scheduled(self, key):
        self.r_server.srem(self.r_key("scheduled_keys"), key)
        message_data = self.get_scheduled(key)
        bucket_key = message_data['bucket_key']
        self.r_server.srem(bucket_key, key)
        self.r_server.delete(key)
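
# Illustrative usage sketch (not part of this module). Assumes a
# synchronous redis client (e.g. redis.StrictRedis) is available as
# `redis_client`; delivery happens via the callback passed to Scheduler.
def _example_schedule_and_start(redis_client):
    def deliver(scheduled_at, payload):
        # In a real worker this would publish the payload somewhere.
        print "Payload scheduled at %s: %r" % (scheduled_at, payload)

    scheduler = Scheduler(redis_client, deliver)
    # Deliver the payload roughly 60 seconds from now (rounded up to the
    # next `granularity` bucket).
    scheduler.schedule(60, {'content': 'hello'})
    scheduler.start()
    return scheduler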
vumi/transports/__init__.py
"""Assorted core transports.

This is where all transports that are part of core vumi live.

.. note::
   Anything in :mod:`vumi.workers` is deprecated and needs to be migrated.
"""

from vumi.transports.base import Transport
from vumi.transports.failures import FailureWorker

__all__ = ['Transport', 'FailureWorker']
vumi/transports/failures.py
# -*- test-case-name: vumi.transports.tests.test_failures -*-

import time
from datetime import datetime
from uuid import uuid4

from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.internet.task import LoopingCall

from vumi.service import Worker
from vumi.message import TransportMessage, to_json
from vumi.persist.txredis_manager import TxRedisManager


class FailureMessage(TransportMessage):
    MESSAGE_TYPE = 'failure_message'

    FC_UNSPECIFIED, FC_PERMANENT, FC_TEMPORARY = (None, 'permanent',
                                                  'temporary')

    def process_fields(self, fields):
        fields = super(FailureMessage, self).process_fields(fields)
        return fields

    def validate_fields(self):
        super(FailureMessage, self).validate_fields()
        self.assert_field_present(
            'message',
            'failure_code',
            'reason',
            )


class FailureCodeException(Exception):
    """Base class for exceptions encoding failure types."""
    def __init__(self, failure_code, msg):
        super(FailureCodeException, self).__init__(msg)
        self.failure_code = failure_code


class PermanentFailure(FailureCodeException):
    """Raise this failure if re-trying seems unlikely to succeed."""
    def __init__(self, msg):
        super(PermanentFailure, self).__init__(FailureMessage.FC_PERMANENT,
                                               msg)


class TemporaryFailure(FailureCodeException):
    """Raise this failure if re-trying might succeed."""
    def __init__(self, msg):
        super(TemporaryFailure, self).__init__(FailureMessage.FC_TEMPORARY,
                                               msg)


class FailureWorker(Worker):
    """
    Base class for transport failure handlers.

    Subclasses should implement :meth:`handle_failure`.
    """

    GRANULARITY = 5  # seconds
    DELIVERY_PERIOD = 3

    MAX_DELAY = 3600
    INITIAL_DELAY = 1
    DELAY_FACTOR = 3

    @inlineCallbacks
    def startWorker(self):
        self.configure_retries()
        yield self.set_up_redis()
        retry_rkey = self.get_rkey('retry')
        failures_rkey = self.get_rkey('failures')
        self.retry_publisher = yield self.publish_to(retry_rkey)
        self.consumer = yield self.consume(failures_rkey, self.process_message,
                                           message_class=FailureMessage)
        self.start_retry_delivery()

    @inlineCallbacks
    def stopWorker(self):
        if self.delivery_loop and self.delivery_loop.running:
            self.delivery_loop.stop()
            yield self.delivery_done
        yield self.consumer.stop()
        yield self.redis.close_manager()

    def configure_retries(self):
        for param in ['GRANULARITY', 'MAX_DELAY', 'INITIAL_DELAY',
                      'DELAY_FACTOR', 'DELIVERY_PERIOD']:
            setattr(self, param, self.config.get('retry_' + param.lower(),
                                                 getattr(self, param)))

    @inlineCallbacks
    def set_up_redis(self):
        r_config = self.config.get('redis_manager', {})
        redis = yield TxRedisManager.from_config(r_config)
        self.redis = redis.sub_manager("failures:%s" % (
                self.config['transport_name'],))

    def start_retry_delivery(self):
        self.delivery_loop = None
        if self.DELIVERY_PERIOD:
            self.delivery_loop = LoopingCall(self.deliver_retries)
            self.delivery_done = self.delivery_loop.start(self.DELIVERY_PERIOD)

    def get_rkey(self, route_name):
        return self.config['%s_routing_key' % route_name] % self.config

    def failure_key(self):
        """
        Construct a failure key.
        """
        timestamp = datetime.utcnow()
        failure_id = uuid4().get_hex()
        timestamp = timestamp.isoformat().split('.')[0]
        return ".".join(("failure", timestamp, failure_id))

    def add_to_failure_set(self, key):
        return self.redis.sadd("failure_keys", key)

    def get_failure_keys(self):
        return self.redis.smembers("failure_keys")

    @inlineCallbacks
    def store_failure(self, message, reason, retry_delay=None):
        """
        Store this failure in redis, with an optional retry delay.

        :param message: The failed message.
        :param reason: A string containing the failure reason.
        :param retry_delay: The (optional) retry delay in seconds.

        If ``retry_delay`` is not ``None``, a retry will be scheduled
        approximately ``retry_delay`` seconds in the future.
        """
        message_json = message
        if not isinstance(message, basestring):
            # This isn't already JSON-encoded.
            message_json = to_json(message)
        key = self.failure_key()
        if not retry_delay:
            retry_delay = 0
        yield self.redis.hmset(key, {
                "message": message_json,
                "reason": reason,
                "retry_delay": str(retry_delay),
                })
        yield self.add_to_failure_set(key)
        if retry_delay:
            yield self.store_retry(key, retry_delay)
        returnValue(key)

    def get_failure(self, failure_key):
        return self.redis.hgetall(failure_key)

    @inlineCallbacks
    def store_retry(self, failure_key, retry_delay, now=None):
        timestamp = self.get_next_write_timestamp(retry_delay, now=now)
        bucket_key = "retry_keys." + timestamp
        yield self.redis.sadd(bucket_key, failure_key)
        yield self.store_read_timestamp(timestamp)

    def store_read_timestamp(self, timestamp):
        score = time.mktime(time.strptime(timestamp, "%Y-%m-%dT%H:%M:%S"))
        return self.redis.zadd('retry_timestamps', **{timestamp: score})

    def get_next_write_timestamp(self, delta, now=None):
        if now is None:
            now = int(time.time())
        timestamp = now + delta
        timestamp += self.GRANULARITY - (timestamp % self.GRANULARITY)
        return datetime.utcfromtimestamp(timestamp).isoformat().split('.')[0]

    @inlineCallbacks
    def get_next_read_timestamp(self):
        now = int(time.time())
        timestamp = datetime.utcfromtimestamp(now).isoformat().split('.')[0]
        next_timestamp = yield self.redis.zrange('retry_timestamps', 0, 0)
        if next_timestamp and next_timestamp[0] <= timestamp:
            returnValue(next_timestamp[0])
        returnValue(None)

    @inlineCallbacks
    def get_next_retry_key(self):
        timestamp = yield self.get_next_read_timestamp()
        if not timestamp:
            return
        bucket_key = "retry_keys." + timestamp
        failure_key = yield self.redis.spop(bucket_key)
        if (yield self.redis.scard(bucket_key)) < 1:
            yield self.redis.zrem('retry_timestamps', timestamp)
        returnValue(failure_key)

    @inlineCallbacks
    def deliver_retry(self, retry_key, publisher):
        failure = yield self.get_failure(retry_key)
        published = yield publisher.publish_raw(failure['message'])
        returnValue(published)

    @inlineCallbacks
    def deliver_retries(self):
        while True:
            retry_key = yield self.get_next_retry_key()
            if not retry_key:
                return
            yield self.deliver_retry(retry_key, self.retry_publisher)

    def next_retry_delay(self, delay):
        if not delay:
            return self.INITIAL_DELAY
        return min(delay * self.DELAY_FACTOR, self.MAX_DELAY)
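    # Illustrative progression with the class defaults (INITIAL_DELAY=1,
    # DELAY_FACTOR=3, MAX_DELAY=3600): 1, 3, 9, ..., 2187, 3600, 3600, ...
    # i.e. exponential backoff capped at one hour.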

    def update_retry_metadata(self, message):
        rmd = message.get('retry_metadata', {})
        message['retry_metadata'] = {
            'retries': rmd.get('retries', 0) + 1,
            'delay': self.next_retry_delay(rmd.get('delay', 0)),
            }
        return message

    def handle_failure(self, message, failure_code, reason):
        """
        Handle a failed message from a transport.

        :param message: The failed message, as a dict.
        :param failure_code: The failure code.
        :param reason: A string containing the reason for the failure.

        This method should be overridden in subclasses to implement
        transport-specific failure handling if needed.
        """
        if failure_code == FailureMessage.FC_TEMPORARY:
            return self.do_retry(message, reason)
        else:
            return self.store_failure(message, reason)

    def do_retry(self, message, reason):
        message = self.update_retry_metadata(message)
        return self.store_failure(
            message, reason, message['retry_metadata']['delay'])

    def process_message(self, failure_message):
        message = failure_message.payload['message']
        failure_code = failure_message.payload['failure_code']
        reason = failure_message.payload['reason']
        return self.handle_failure(message, failure_code, reason)
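
# Illustrative sketch (not part of vumi): a FailureWorker subclass that
# overrides handle_failure() to retry *every* failure, not just temporary
# ones, using the retry machinery provided by the base class.
class ExampleAlwaysRetryFailureWorker(FailureWorker):
    """Illustrative only: treats every failure code as retryable."""

    def handle_failure(self, message, failure_code, reason):
        # Unlike the base implementation, permanent and unspecified
        # failures are also scheduled for retry.
        return self.do_retry(message, reason)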
vumi/transports/netcore/__init__.py
from vumi.transports.netcore.netcore import NetcoreTransport

__all__ = ['NetcoreTransport']
PK=JGv"vumi/transports/netcore/netcore.py# -*- test-case-name: vumi.transports.netcore.tests.test_netcore -*-

from vumi.config import (
    ConfigServerEndpoint, ConfigText, ConfigBool, ConfigInt,
    ServerEndpointFallback)
from vumi.transports import Transport
from vumi.transports.httprpc.httprpc import HttpRpcHealthResource
from vumi.utils import build_web_site

from twisted.internet.defer import inlineCallbacks
from twisted.web import http
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET


class NetcoreTransportConfig(Transport.CONFIG_CLASS):

    twisted_endpoint = ConfigServerEndpoint(
        'The endpoint to listen on.',
        required=True, static=True,
        fallbacks=[ServerEndpointFallback()])
    web_path = ConfigText(
        "The path to serve this resource on.",
        default='/api/v1/netcore/', static=True)
    health_path = ConfigText(
        "The path to serve the health resource on.",
        default='/health/', static=True)
    reject_none = ConfigBool(
        "Reject messages where the content parameter equals 'None'",
        required=False, default=True, static=True)

    # TODO: Deprecate these fields when confmodel#5 is done.
    host = ConfigText(
        "*DEPRECATED* 'host' and 'port' fields may be used in place of the"
        " 'twisted_endpoint' field.", static=True)
    port = ConfigInt(
        "*DEPRECATED* 'host' and 'port' fields may be used in place of the"
        " 'twisted_endpoint' field.", static=True)


class NetcoreResource(Resource):

    isLeaf = True

    def __init__(self, transport):
        Resource.__init__(self)
        self.transport = transport
        self.config = transport.get_static_config()

    def render_POST(self, request):
        expected_keys = [
            'to_addr',
            'from_addr',
            'content',
            'circle',
            'source',
        ]

        received = set(request.args.keys())
        expected = set(expected_keys)
        if received != expected:
            request.setResponseCode(http.BAD_REQUEST)
            return ('Not all expected parameters received. '
                    'Only allowing: %r, received: %r' % (
                        expected_keys, request.args.keys()))
        param_values = [param[0] for param in request.args.values()]
        if not all(param_values):
            request.setResponseCode(http.BAD_REQUEST)
            return ('Not all parameters have values. '
                    'Received: %r' % (sorted(param_values),))

        content = request.args['content'][0]
        if self.config.reject_none and content == "None":
            request.setResponseCode(http.BAD_REQUEST)
            return ('"None" string literal not allowed for content parameter.')

        self.handle_request(request)
        return NOT_DONE_YET

    def handle_request(self, request):
        to_addr = request.args['to_addr'][0]
        from_addr = request.args['from_addr'][0]
        content = request.args['content'][0]
        circle = request.args['circle'][0]
        source = request.args['source'][0]

        # NOTE: If we have a leading 0 then the normalization middleware
        #       will deal with it.
        if not from_addr.startswith('0'):
            from_addr = '0%s' % (from_addr,)

        d = self.transport.handle_raw_inbound_message(
            to_addr, from_addr, content, circle, source)
        d.addCallback(lambda msg: request.write(msg['message_id']))
        d.addCallback(lambda _: request.finish())
        return d


class NetcoreTransport(Transport):

    CONFIG_CLASS = NetcoreTransportConfig

    @inlineCallbacks
    def setup_transport(self):
        config = self.get_static_config()
        self.endpoint = config.twisted_endpoint
        self.resource = NetcoreResource(self)

        self.factory = build_web_site({
            config.health_path: HttpRpcHealthResource(self),
            config.web_path: self.resource,
        })
        self.server = yield self.endpoint.listen(self.factory)

    def teardown_transport(self):
        return self.server.stopListening()

    def handle_raw_inbound_message(self, to_addr, from_addr, content,
                                   circle, source):
        return self.publish_message(
            content=content,
            from_addr=from_addr,
            to_addr=to_addr,
            transport_type='sms',
            transport_metadata={
                'netcore': {
                    'circle': circle,
                    'source': source,
                }
            })

    def get_health_response(self):
        return 'OK'
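
# Illustrative inbound request (hypothetical values, not part of vumi):
#
#   POST /api/v1/netcore/ HTTP/1.1
#   Content-Type: application/x-www-form-urlencoded
#
#   to_addr=10010&from_addr=8800000000&content=hello&circle=mumbai&source=sms
#
# On success the response body is the vumi message_id; a from_addr without
# a leading '0' is prefixed with one (here: 08800000000) before publishing.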
vumi/transports/netcore/tests/__init__.py
vumi/transports/netcore/tests/test_netcore.py
from urllib import urlencode

from twisted.internet.defer import inlineCallbacks
from twisted.web import http

from vumi.tests.helpers import VumiTestCase
from vumi.transports.netcore import NetcoreTransport
from vumi.transports.tests.helpers import TransportHelper
from vumi.utils import http_request_full


def request(transport, method, params={}, path=None):
    if path is None:
        path = transport.get_static_config().web_path

    addr = transport.server.getHost()
    url = 'http://%s:%s%s' % (addr.host,
                              addr.port,
                              path)
    return http_request_full(
        url, method=method, data=urlencode(params), headers={
            'Content-Type': ['application/x-www-form-urlencoded'],
        })


class NetCoreTestCase(VumiTestCase):

    transport_class = NetcoreTransport

    def setUp(self):
        self.tx_helper = self.add_helper(TransportHelper(self.transport_class))

    def get_transport(self, **config):
        defaults = {
            'twisted_endpoint': 'tcp:0',
        }
        defaults.update(config)
        return self.tx_helper.get_transport(defaults)

    @inlineCallbacks
    def test_inbound_sms_failure(self):
        transport = yield self.get_transport()
        resp = yield request(transport, 'POST', {
            'foo': 'bar'
        })
        self.assertEqual(resp.code, http.BAD_REQUEST)
        self.assertEqual(resp.delivered_body, (
            "Not all expected parameters received. Only allowing: "
            "['to_addr', 'from_addr', 'content', 'circle', 'source'], "
            "received: ['foo']"))
        self.assertEqual(
            [], self.tx_helper.get_dispatched_inbound())

    @inlineCallbacks
    def test_inbound_missing_values(self):
        transport = yield self.get_transport()
        resp = yield request(transport, 'POST', {
            'to_addr': '10010',
            'from_addr': '8800000000',
            'content': '',  # Intentionally empty!
            'source': 'sms',
            'circle': 'of life',
        })
        self.assertEqual(resp.code, http.BAD_REQUEST)
        self.assertEqual(resp.delivered_body, (
            "Not all parameters have values. "
            "Received: %r" % (sorted(
                ['', 'sms', 'of life', '10010', '8800000000']),)))
        self.assertEqual(
            [], self.tx_helper.get_dispatched_inbound())

    @inlineCallbacks
    def test_inbound_content_none_string_literal(self):
        transport = yield self.get_transport()
        resp = yield request(transport, 'POST', {
            'to_addr': '10010',
            'from_addr': '8800000000',
            'content': 'None',  # Python str(None) on netcore's side
            'source': 'sms',
            'circle': 'of life',
        })
        self.assertEqual(resp.code, http.BAD_REQUEST)
        self.assertEqual(resp.delivered_body, (
            '"None" string literal not allowed for content parameter.'))
        self.assertEqual(
            [], self.tx_helper.get_dispatched_inbound())

    @inlineCallbacks
    def test_inbound_sms_success(self):
        transport = yield self.get_transport()
        resp = yield request(transport, 'POST', {
            'to_addr': '10010',
            'from_addr': '8800000000',
            'content': 'foo',
            'source': 'sms',
            'circle': 'of life',
        })
        self.assertEqual(resp.code, http.OK)
        [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1)
        self.assertEqual(msg['to_addr'], '10010')
        self.assertEqual(msg['from_addr'], '08800000000')
        self.assertEqual(msg['content'], 'foo')
        self.assertEqual(msg['transport_metadata'], {
            'netcore': {
                'source': 'sms',
                'circle': 'of life',
            }
        })

    @inlineCallbacks
    def test_health_resource(self):
        transport = yield self.get_transport()
        health_path = transport.get_static_config().health_path
        resp = yield request(transport, 'GET', path=health_path)
        self.assertEqual(resp.delivered_body, 'OK')
        self.assertEqual(resp.code, http.OK)
vumi/transports/twitter/__init__.py
from vumi.transports.twitter.twitter import (
    ConfigTwitterEndpoints, TwitterTransport)

__all__ = ['ConfigTwitterEndpoints', 'TwitterTransport']
vumi/transports/twitter/twitter.py
# -*- test-case-name: vumi.transports.twitter.tests.test_twitter -*-
from twisted.python import log
from twisted.internet.defer import inlineCallbacks
from txtwitter.twitter import TwitterClient
from txtwitter import messagetools

from vumi.transports.base import Transport
from vumi.config import ConfigBool, ConfigText, ConfigList, ConfigDict


class ConfigTwitterEndpoints(ConfigDict):
    field_type = 'twitter_endpoints'

    def clean(self, value):
        endpoints_dict = super(ConfigTwitterEndpoints, self).clean(value)

        if 'dms' not in endpoints_dict and 'tweets' not in endpoints_dict:
            self.raise_config_error(
                "needs configuration for either dms, tweets or both")

        if endpoints_dict.get('dms') == endpoints_dict.get('tweets'):
            self.raise_config_error(
                "has the same endpoint for dms and tweets: '%s'"
                % endpoints_dict['dms'])

        return endpoints_dict


class TwitterTransportConfig(Transport.CONFIG_CLASS):
    screen_name = ConfigText(
        "The screen name for the twitter account",
        required=True, static=True)
    consumer_key = ConfigText(
        "The OAuth consumer key for the twitter account",
        required=True, static=True)
    consumer_secret = ConfigText(
        "The OAuth consumer secret for the twitter account",
        required=True, static=True)
    access_token = ConfigText(
        "The OAuth access token for the twitter account",
        required=True, static=True)
    access_token_secret = ConfigText(
        "The OAuth access token secret for the twitter account",
        required=True, static=True)
    endpoints = ConfigTwitterEndpoints(
        "Which endpoints to use for dms and tweets",
        default={'tweets': 'default'}, static=True)
    terms = ConfigList(
        "A list of terms to be tracked by the transport",
        default=[], static=True)
    autofollow = ConfigBool(
        "Determines whether the transport will follow users that follow the "
        "transport's user",
        default=False, static=True)


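# Illustrative worker config for this transport (hypothetical values, not
# part of vumi), as it might appear in a YAML config file:
#
#   transport_name: twitter_transport
#   screen_name: my_account
#   consumer_key: <oauth consumer key>
#   consumer_secret: <oauth consumer secret>
#   access_token: <oauth access token>
#   access_token_secret: <oauth access token secret>
#   endpoints:
#     tweets: tweet_endpoint
#     dms: dm_endpoint
#   terms: [vumi]
#   autofollow: true
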
class TwitterTransport(Transport):
    """Twitter transport."""

    transport_type = 'twitter'

    CONFIG_CLASS = TwitterTransportConfig
    NO_USER_ADDR = 'NO_USER'

    OUTBOUND_HANDLERS = {
        'tweets': 'handle_outbound_tweet',
        'dms': 'handle_outbound_dm',
    }

    def get_client(self, *a, **kw):
        return TwitterClient(*a, **kw)

    def setup_transport(self):
        config = self.get_static_config()
        self.screen_name = config.screen_name

        self.autofollow = config.autofollow

        self.client = self.get_client(
            config.access_token,
            config.access_token_secret,
            config.consumer_key,
            config.consumer_secret)

        self.endpoints = config.endpoints

        for msg_type, endpoint in self.endpoints.iteritems():
            handler = getattr(self, self.OUTBOUND_HANDLERS[msg_type])
            handler = self.make_outbound_handler(handler)
            self.add_outbound_handler(handler, endpoint_name=endpoint)

        if config.terms:
            self.track_stream = self.client.stream_filter(
                self.handle_track_stream, track=config.terms)
            self.track_stream.startService()
        else:
            self.track_stream = None

        self.user_stream = self.client.userstream_user(
            self.handle_user_stream, with_='user')
        self.user_stream.startService()

    @inlineCallbacks
    def teardown_transport(self):
        if self.track_stream is not None:
            yield self.track_stream.stopService()

        yield self.user_stream.stopService()

    def make_outbound_handler(self, twitter_handler):
        @inlineCallbacks
        def handler(message):
            try:
                twitter_message = yield twitter_handler(message)

                yield self.publish_ack(
                    user_message_id=message['message_id'],
                    sent_message_id=twitter_message['id_str'])
            except Exception, e:
                reason = '%s' % (e,)
                log.err('Outbound twitter message failed: %s' % (reason,))

                yield self.publish_nack(
                    user_message_id=message['message_id'],
                    sent_message_id=message['message_id'],
                    reason=reason)

        return handler

    @classmethod
    def screen_name_as_addr(cls, screen_name):
        return u'@%s' % (screen_name,)

    @classmethod
    def addr_as_screen_name(cls, addr):
        return addr[1:] if addr.startswith('@') else addr

    def is_own_tweet(self, message):
        user = messagetools.tweet_user(message)
        return self.screen_name == messagetools.user_screen_name(user)

    def is_own_dm(self, message):
        sender = messagetools.dm_sender(message)
        return self.screen_name == messagetools.user_screen_name(sender)

    def is_own_follow(self, message):
        source_screen_name = messagetools.user_screen_name(message['source'])
        return source_screen_name == self.screen_name

    @classmethod
    def tweet_to_addr(cls, tweet):
        mentions = messagetools.tweet_user_mentions(tweet)
        to_addr = cls.NO_USER_ADDR

        if mentions:
            mention = mentions[0]
            [start_index, end_index] = mention['indices']

            if start_index == 0:
                to_addr = cls.screen_name_as_addr(mention['screen_name'])

        return to_addr

    @classmethod
    def tweet_from_addr(cls, tweet):
        user = messagetools.tweet_user(tweet)
        return cls.screen_name_as_addr(messagetools.user_screen_name(user))

    @classmethod
    def tweet_content(cls, tweet):
        to_addr = cls.tweet_to_addr(tweet)
        content = messagetools.tweet_text(tweet)

        if to_addr != cls.NO_USER_ADDR and content.startswith(to_addr):
            content = content[len(to_addr):].lstrip()

        return content
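    # Illustrative behaviour (hypothetical tweets, not from the source):
    # '@me hello' with a mention at index 0 gives to_addr '@me' and
    # content 'hello'; 'hello @me!' (mention not at the start) gives
    # to_addr NO_USER_ADDR and the content unchanged.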

    def publish_tweet(self, tweet):
        return self.publish_message(
            content=self.tweet_content(tweet),
            to_addr=self.tweet_to_addr(tweet),
            from_addr=self.tweet_from_addr(tweet),
            transport_type=self.transport_type,
            routing_metadata={
                'endpoint_name': self.endpoints['tweets']
            },
            transport_metadata={
                'twitter': {
                    'status_id': messagetools.tweet_id(tweet)
                }
            },
            helper_metadata={
                'twitter': {
                    'in_reply_to_status_id': (
                        messagetools.tweet_in_reply_to_id(tweet)),
                    'in_reply_to_screen_name': (
                        messagetools.tweet_in_reply_to_screen_name(tweet)),
                    'user_mentions': messagetools.tweet_user_mentions(tweet),
                }
            })

    def publish_dm(self, dm):
        sender = messagetools.dm_sender(dm)
        recipient = messagetools.dm_recipient(dm)

        return self.publish_message(
            content=messagetools.dm_text(dm),
            to_addr=self.screen_name_as_addr(recipient['screen_name']),
            from_addr=self.screen_name_as_addr(sender['screen_name']),
            transport_type=self.transport_type,
            routing_metadata={
                'endpoint_name': self.endpoints['dms']
            },
            helper_metadata={
                'dm_twitter': {
                    'id': messagetools.dm_id(dm),
                    'user_mentions': messagetools.dm_user_mentions(dm),
                }
            })

    def handle_track_stream(self, message):
        if messagetools.is_tweet(message):
            if self.is_own_tweet(message):
                log.msg("Tracked own tweet: %r" % (message,))
            else:
                log.msg("Tracked a tweet: %r" % (message,))
                self.publish_tweet(message)
        else:
            log.msg("Received non-tweet from tracking stream: %r" % message)

    def handle_user_stream(self, message):
        if messagetools.is_tweet(message):
            return self.handle_inbound_tweet(message)
        elif messagetools.is_dm(message.get('direct_message', {})):
            return self.handle_inbound_dm(message['direct_message'])
        elif message.get('event') == 'follow':
            return self.handle_follow(message)

        log.msg(
            "Received a user stream message that we do not handle: %r" %
            message)

    def handle_follow(self, follow):
        if self.is_own_follow(follow):
            log.msg("Received own follow on user stream: %r" % (follow,))
            return

        log.msg("Received follow on user stream: %r" % (follow,))

        if self.autofollow:
            screen_name = messagetools.user_screen_name(follow['source'])
            log.msg("Auto-following '%s'" %
                    (self.screen_name_as_addr(screen_name,)))
            return self.client.friendships_create(screen_name=screen_name)

    def handle_inbound_dm(self, dm):
        if self.is_own_dm(dm):
            log.msg("Received own DM on user stream: %r" % (dm,))
        elif 'dms' not in self.endpoints:
            log.msg(
                "Discarding DM received on user stream, no endpoint "
                "configured for DMs: %r" % (dm,))
        else:
            log.msg("Received DM on user stream: %r" % (dm,))
            self.publish_dm(dm)

    def handle_inbound_tweet(self, tweet):
        if self.is_own_tweet(tweet):
            log.msg("Received own tweet on user stream: %r" % (tweet,))
        elif 'tweets' not in self.endpoints:
            log.msg(
                "Discarding tweet received on user stream, no endpoint "
                "configured for tweets: %r" % (tweet,))
        else:
            log.msg("Received tweet on user stream: %r" % (tweet,))
            self.publish_tweet(tweet)

    def handle_outbound_dm(self, message):
        return self.client.direct_messages_new(
            screen_name=self.addr_as_screen_name(message['to_addr']),
            text=message['content'])

    def handle_outbound_tweet(self, message):
        log.msg("Twitter transport sending tweet %r" % (message,))

        metadata = message['transport_metadata'].get(self.transport_type, {})
        in_reply_to_status_id = metadata.get('status_id')

        content = message['content']
        if message['to_addr'] != self.NO_USER_ADDR:
            content = "%s %s" % (message['to_addr'], content)

        return self.client.statuses_update(
            content, in_reply_to_status_id=in_reply_to_status_id)
vumi/transports/twitter/tests/test_twitter.py
from twisted.internet.defer import inlineCallbacks
from txtwitter.tests.fake_twitter import FakeTwitter

from vumi.tests.utils import LogCatcher
from vumi.tests.helpers import VumiTestCase
from vumi.config import Config
from vumi.errors import ConfigError
from vumi.transports.twitter import (
    ConfigTwitterEndpoints, TwitterTransport)
from vumi.transports.tests.helpers import TransportHelper


class TestTwitterEndpointsConfig(VumiTestCase):
    def test_clean_no_endpoints(self):
        class ToyConfig(Config):
            endpoints = ConfigTwitterEndpoints("test endpoints")

        self.assertRaises(ConfigError, ToyConfig, {'endpoints': {}})

    def test_clean_same_endpoints(self):
        class ToyConfig(Config):
            endpoints = ConfigTwitterEndpoints("test endpoints")

        self.assertRaises(ConfigError, ToyConfig, {'endpoints': {
            'dms': 'default',
            'tweets': 'default'
        }})


class TestTwitterTransport(VumiTestCase):
    @inlineCallbacks
    def setUp(self):
        self.twitter = FakeTwitter()
        self.user = self.twitter.new_user('me', 'me')
        self.client = self.twitter.get_client(self.user.id_str)

        self.patch(
            TwitterTransport, 'get_client', lambda *a, **kw: self.client)

        self.tx_helper = self.add_helper(TransportHelper(TwitterTransport))

        self.config = {
            'screen_name': 'me',
            'consumer_key': 'consumer1',
            'consumer_secret': 'consumersecret1',
            'access_token': 'token1',
            'access_token_secret': 'tokensecret1',
            'terms': ['arnold', 'the', 'term'],
            'endpoints': {
                'tweets': 'tweet_endpoint',
                'dms': 'dm_endpoint'
            }
        }

        self.transport = yield self.tx_helper.get_transport(self.config)

    def test_config_endpoints_default(self):
        del self.config['endpoints']
        self.config['transport_name'] = 'twitter'
        config = TwitterTransport.CONFIG_CLASS(self.config)
        self.assertEqual(config.endpoints, {'tweets': 'default'})

    @inlineCallbacks
    def test_config_no_tracking_stream(self):
        self.config['terms'] = []
        transport = yield self.tx_helper.get_transport(self.config)
        self.assertEqual(transport.track_stream, None)

    @inlineCallbacks
    def test_tracking_tweets(self):
        someone = self.twitter.new_user('someone', 'someone')
        tweet = self.twitter.new_tweet('arnold', someone.id_str)

        [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1)

        self.assertEqual(msg['from_addr'], '@someone')
        self.assertEqual(msg['to_addr'], 'NO_USER')
        self.assertEqual(msg['content'], 'arnold')

        self.assertEqual(
            msg['transport_metadata'],
            {'twitter': {'status_id': tweet.id_str}})

        self.assertEqual(msg['helper_metadata'], {
            'twitter': {
                'in_reply_to_status_id': None,
                'in_reply_to_screen_name': None,
                'user_mentions': []
            }
        })

    @inlineCallbacks
    def test_tracking_reply_tweets(self):
        someone = self.twitter.new_user('someone', 'someone')
        someone_else = self.twitter.new_user('someone_else', 'someone_else')
        tweet1 = self.twitter.new_tweet('@someone_else hello', someone.id_str)
        tweet2 = self.twitter.new_tweet(
            '@someone arnold', someone_else.id_str, reply_to=tweet1.id_str)

        [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1)

        self.assertEqual(msg['from_addr'], '@someone_else')
        self.assertEqual(msg['to_addr'], '@someone')
        self.assertEqual(msg['content'], 'arnold')

        self.assertEqual(
            msg['transport_metadata'],
            {'twitter': {'status_id': tweet2.id_str}})

        self.assertEqual(msg['helper_metadata'], {
            'twitter': {
                'in_reply_to_status_id': tweet1.id_str,
                'in_reply_to_screen_name': 'someone',
                'user_mentions': [{
                    'id_str': someone.id_str,
                    'id': int(someone.id_str),
                    'indices': [0, 8],
                    'screen_name': someone.screen_name,
                    'name': someone.name,
                }]
            }
        })

    def test_tracking_own_messages(self):
        with LogCatcher() as lc:
            tweet = self.twitter.new_tweet('arnold', self.user.id_str)
            tweet = tweet.to_dict(self.twitter)

            self.assertTrue(any(
                "Tracked own tweet:" in msg for msg in lc.messages()))

    @inlineCallbacks
    def test_inbound_tweet(self):
        someone = self.twitter.new_user('someone', 'someone')
        tweet = self.twitter.new_tweet('@me hello', someone.id_str)

        [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1)

        self.assertEqual(msg['from_addr'], '@someone')
        self.assertEqual(msg['to_addr'], '@me')
        self.assertEqual(msg['content'], 'hello')
        self.assertEqual(msg.get_routing_endpoint(), 'tweet_endpoint')

        self.assertEqual(
            msg['transport_metadata'],
            {'twitter': {'status_id': tweet.id_str}})

        self.assertEqual(msg['helper_metadata'], {
            'twitter': {
                'in_reply_to_status_id': None,
                'in_reply_to_screen_name': 'me',
                'user_mentions': [{
                    'id_str': self.user.id_str,
                    'id': int(self.user.id_str),
                    'indices': [0, 3],
                    'screen_name': self.user.screen_name,
                    'name': self.user.name,
                }]
            }
        })

    @inlineCallbacks
    def test_inbound_tweet_reply(self):
        someone = self.twitter.new_user('someone', 'someone')
        tweet1 = self.twitter.new_tweet('@someone hello', self.user.id_str)
        tweet2 = self.twitter.new_tweet(
            '@me goodbye', someone.id_str, reply_to=tweet1.id_str)

        [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1)

        self.assertEqual(msg['from_addr'], '@someone')
        self.assertEqual(msg['to_addr'], '@me')
        self.assertEqual(msg['content'], 'goodbye')

        self.assertEqual(
            msg['transport_metadata'],
            {'twitter': {'status_id': tweet2.id_str}})

        self.assertEqual(msg['helper_metadata'], {
            'twitter': {
                'in_reply_to_status_id': tweet1.id_str,
                'in_reply_to_screen_name': 'me',
                'user_mentions': [{
                    'id_str': self.user.id_str,
                    'id': int(self.user.id_str),
                    'indices': [0, 3],
                    'screen_name': self.user.screen_name,
                    'name': self.user.name,
                }]
            }
        })

    def test_inbound_own_tweet(self):
        with LogCatcher() as lc:
            self.twitter.new_tweet('hello', self.user.id_str)

            self.assertTrue(any(
                "Received own tweet on user stream" in msg
                for msg in lc.messages()))

    @inlineCallbacks
    def test_inbound_tweet_no_endpoint(self):
        self.config['endpoints'] = {'dms': 'default'}
        yield self.tx_helper.get_transport(self.config)
        someone = self.twitter.new_user('someone', 'someone')

        with LogCatcher() as lc:
            self.twitter.new_tweet('@me hello', someone.id_str)

            self.assertTrue(any(
                "Discarding tweet received on user stream, no endpoint "
                "configured for tweets" in msg
                for msg in lc.messages()))

    @inlineCallbacks
    def test_inbound_dm(self):
        someone = self.twitter.new_user('someone', 'someone')
        dm = self.twitter.new_dm('hello @me', someone.id_str, self.user.id_str)

        [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1)

        self.assertEqual(msg['from_addr'], '@someone')
        self.assertEqual(msg['to_addr'], '@me')
        self.assertEqual(msg['content'], 'hello @me')
        self.assertEqual(msg.get_routing_endpoint(), 'dm_endpoint')

        self.assertEqual(msg['helper_metadata'], {
            'dm_twitter': {
                'id': dm.id_str,
                'user_mentions': [{
                    'id_str': self.user.id_str,
                    'id': int(self.user.id_str),
                    'indices': [6, 9],
                    'screen_name': self.user.screen_name,
                    'name': self.user.name,
                }]
            }
        })

    def test_inbound_own_dm(self):
        with LogCatcher() as lc:
            someone = self.twitter.new_user('someone', 'someone')
            self.twitter.new_dm('hello', self.user.id_str, someone.id_str)

            self.assertTrue(any(
                "Received own DM on user stream" in msg
                for msg in lc.messages()))

    @inlineCallbacks
    def test_inbound_dm_no_endpoint(self):
        self.config['endpoints'] = {'tweets': 'default'}
        yield self.tx_helper.get_transport(self.config)
        someone = self.twitter.new_user('someone', 'someone')

        with LogCatcher() as lc:
            self.twitter.new_dm('hello @me', someone.id_str, self.user.id_str)

            self.assertTrue(any(
                "Discarding DM received on user stream, no endpoint "
                "configured for DMs" in msg
                for msg in lc.messages()))

    @inlineCallbacks
    def test_auto_following(self):
        self.config['autofollow'] = True
        yield self.tx_helper.get_transport(self.config)

        with LogCatcher() as lc:
            someone = self.twitter.new_user('someone', 'someone')
            self.twitter.add_follow(someone.id_str, self.user.id_str)

            self.assertTrue(any(
                "Received follow on user stream" in msg
                for msg in lc.messages()))

            self.assertTrue(any(
                "Auto-following '@someone'" in msg
                for msg in lc.messages()))

        follow = self.twitter.get_follow(self.user.id_str, someone.id_str)
        self.assertEqual(follow.source_id, self.user.id_str)
        self.assertEqual(follow.target_id, someone.id_str)

    @inlineCallbacks
    def test_auto_following_disabled(self):
        self.config['autofollow'] = False
        yield self.tx_helper.get_transport(self.config)

        with LogCatcher() as lc:
            someone = self.twitter.new_user('someone', 'someone')
            self.twitter.add_follow(someone.id_str, self.user.id_str)

            self.assertTrue(any(
                "Received follow on user stream" in msg
                for msg in lc.messages()))

        follow = self.twitter.get_follow(self.user.id_str, someone.id_str)
        self.assertTrue(follow is None)

    def test_inbound_own_follow(self):
        with LogCatcher() as lc:
            someone = self.twitter.new_user('someone', 'someone')
            self.twitter.add_follow(self.user.id_str, someone.id_str)

            self.assertTrue(any(
                "Received own follow on user stream" in msg
                for msg in lc.messages()))

    @inlineCallbacks
    def test_tweet_sending(self):
        self.twitter.new_user('someone', 'someone')
        msg = yield self.tx_helper.make_dispatch_outbound(
            'hello', to_addr='@someone', endpoint='tweet_endpoint')
        [ack] = yield self.tx_helper.wait_for_dispatched_events(1)

        self.assertEqual(ack['user_message_id'], msg['message_id'])

        tweet = self.twitter.get_tweet(ack['sent_message_id'])
        self.assertEqual(tweet.text, '@someone hello')
        self.assertEqual(tweet.reply_to, None)

    @inlineCallbacks
    def test_tweet_reply_sending(self):
        tweet1 = self.twitter.new_tweet(
            'hello', self.user.id_str, endpoint='tweet_endpoint')

        inbound_msg = self.tx_helper.make_inbound(
            'hello',
            from_addr='@someone',
            endpoint='tweet_endpoint',
            transport_metadata={
                'twitter': {'status_id': tweet1.id_str}
            })

        msg = yield self.tx_helper.make_dispatch_reply(inbound_msg, "goodbye")
        [ack] = yield self.tx_helper.wait_for_dispatched_events(1)

        self.assertEqual(ack['user_message_id'], msg['message_id'])

        tweet2 = self.twitter.get_tweet(ack['sent_message_id'])
        self.assertEqual(tweet2.text, '@someone goodbye')
        self.assertEqual(tweet2.reply_to, tweet1.id_str)

    @inlineCallbacks
    def test_tweet_sending_failure(self):
        def fail(*a, **kw):
            raise Exception(':(')

        self.patch(self.client, 'statuses_update', fail)

        with LogCatcher() as lc:
            msg = yield self.tx_helper.make_dispatch_outbound(
                'hello', endpoint='tweet_endpoint')

            self.assertEqual(
                [e['message'][0] for e in lc.errors],
                ["'Outbound twitter message failed: :('"])

        [nack] = yield self.tx_helper.wait_for_dispatched_events(1)
        self.assertEqual(nack['user_message_id'], msg['message_id'])
        self.assertEqual(nack['sent_message_id'], msg['message_id'])
        self.assertEqual(nack['nack_reason'], ':(')

    @inlineCallbacks
    def test_dm_sending(self):
        self.twitter.new_user('someone', 'someone')

        msg = yield self.tx_helper.make_dispatch_outbound(
            'hello', to_addr='@someone', endpoint='dm_endpoint')
        [ack] = yield self.tx_helper.wait_for_dispatched_events(1)

        self.assertEqual(ack['user_message_id'], msg['message_id'])

        dm = self.twitter.get_dm(ack['sent_message_id'])
        sender = self.twitter.get_user(dm.sender_id_str)
        recipient = self.twitter.get_user(dm.recipient_id_str)

        self.assertEqual(dm.text, 'hello')
        self.assertEqual(sender.screen_name, 'me')
        self.assertEqual(recipient.screen_name, 'someone')

    @inlineCallbacks
    def test_dm_sending_failure(self):
        def fail(*a, **kw):
            raise Exception(':(')

        self.patch(self.client, 'direct_messages_new', fail)

        with LogCatcher() as lc:
            msg = yield self.tx_helper.make_dispatch_outbound(
                'hello', endpoint='dm_endpoint')

            self.assertEqual(
                [e['message'][0] for e in lc.errors],
                ["'Outbound twitter message failed: :('"])

        [nack] = yield self.tx_helper.wait_for_dispatched_events(1)
        self.assertEqual(nack['user_message_id'], msg['message_id'])
        self.assertEqual(nack['sent_message_id'], msg['message_id'])
        self.assertEqual(nack['nack_reason'], ':(')

    def test_track_stream_for_non_tweet(self):
        with LogCatcher() as lc:
            self.transport.handle_track_stream({'foo': 'bar'})

            self.assertEqual(
                lc.messages(),
                ["Received non-tweet from tracking stream: {'foo': 'bar'}"])

    def test_user_stream_for_unsupported_message(self):
        with LogCatcher() as lc:
            self.transport.handle_user_stream({'foo': 'bar'})

            self.assertEqual(
                lc.messages(),
                ["Received a user stream message that we do not handle: "
                 "{'foo': 'bar'}"])

    def test_tweet_content_with_mention_at_start(self):
        self.assertEqual('hello', self.transport.tweet_content({
            'id_str': '12345',
            'text': '@fakeuser hello',
            'user': {},
            'entities': {
                'user_mentions': [{
                    'id_str': '123',
                    'screen_name': 'fakeuser',
                    'name': 'Fake User',
                    'indices': [0, 8]
                }]
            },
        }))

    def test_tweet_content_with_mention_not_at_start(self):
        self.assertEqual('hello @fakeuser!', self.transport.tweet_content({
            'id_str': '12345',
            'text': 'hello @fakeuser!',
            'user': {},
            'entities': {
                'user_mentions': [{
                    'id_str': '123',
                    'screen_name': 'fakeuser',
                    'name': 'Fake User',
                    'indices': [6, 14]
                }]
            },
        }))

    def test_tweet_content_with_no_mention(self):
        self.assertEqual('hello', self.transport.tweet_content({
            'id_str': '12345',
            'text': 'hello',
            'user': {},
            'entities': {
                'user_mentions': []
            },
        }))

    def test_tweet_content_with_no_user_in_text(self):
        self.assertEqual('NO_USER hello', self.transport.tweet_content({
            'id_str': '12345',
            'text': 'NO_USER hello',
            'user': {},
            'entities': {
                'user_mentions': []
            },
        }))
vumi/transports/twitter/tests/__init__.py
vumi/transports/vodacom_messaging/__init__.py
"""Synchronous HTTP RPC-based message transports."""

from vumi.transports.vodacom_messaging.vodacom_messaging import (
    VodacomMessagingTransport, VodacomMessagingResponse)

__all__ = ['VodacomMessagingTransport', 'VodacomMessagingResponse']
vumi/transports/vodacom_messaging/vodacom_messaging.py
# -*- test-case-name: vumi.transports.vodacom_messaging.tests.test_vodacom_messaging -*-

from vumi.message import TransportUserMessage
from vumi.transports.httprpc import HttpRpcTransport


class VodacomMessagingTransport(HttpRpcTransport):
    """Vodacom Messaging USSD over HTTP transport."""

    ENCODING = 'utf-8'

    def handle_raw_inbound_message(self, msgid, request):
        content = str(request.args.get('request', [None])[0])
        msisdn = str(request.args.get('msisdn', [None])[0])
        ussd_session_id = str(request.args.get('ussdSessionId', [None])[0])
        provider = str(request.args.get('provider', [None])[0])
        if content.startswith(self.config.get('ussd_string_prefix')):
            session_event = TransportUserMessage.SESSION_NEW
            to_addr = content
        else:
            session_event = TransportUserMessage.SESSION_RESUME
            to_addr = ''
        transport_metadata = {'session_id': ussd_session_id}
        self.publish_message(
                message_id=msgid,
                content=content,
                to_addr=to_addr,
                from_addr=msisdn,
                provider=provider,
                session_event=session_event,
                transport_name=self.transport_name,
                transport_type=self.config.get('transport_type'),
                transport_metadata=transport_metadata,
                )

    def handle_outbound_message(self, message):
        missing_fields = self.ensure_message_values(message,
                                ['in_reply_to', 'content'])
        if missing_fields:
            return self.reject_message(message, missing_fields)

        should_close = (message['session_event']
                        == TransportUserMessage.SESSION_CLOSE)
        vmr = VodacomMessagingResponse(self.config['web_host'],
                                        self.config['web_path'])
        vmr.set_headertext(message['content'])
        if not should_close:
            vmr.accept_freetext()
        self.finish_request(message['in_reply_to'],
                            unicode(vmr).encode(self.ENCODING))
        return self.publish_ack(user_message_id=message['message_id'],
            sent_message_id=message['message_id'])


class VodacomMessagingResponse(object):
    def __init__(self, web_host, web_path):
        self.web_host = web_host
        self.web_path = web_path
        self.freetext_option = None
        self.template_freetext_option_string = (
            '<option command="1" order="1"'
            ' callback="http://%(web_host)s%(web_path)s"'
            ' display="False" ></option>')
        self.option_list = []
        self.template_numbered_option_string = (
            '<option command="%(order)s" order="%(order)s"'
            ' callback="http://%(web_host)s%(web_path)s"'
            ' display="True" >%(text)s</option>')

    def set_headertext(self, headertext):
        self.headertext = headertext

    def add_option(self, text, order=None):
        self.freetext_option = None
        dictionary = {'text': text}
        if order:
            dictionary['order'] = int(order)
        else:
            dictionary['order'] = len(self.option_list) + 1
        dictionary.update({
            'web_path': self.web_path,
            'web_host': self.web_host})
        self.option_list.append(dictionary)

    def accept_freetext(self):
        self.option_list = []
        self.freetext_option = self.template_freetext_option_string % {
            'web_path': self.web_path,
            'web_host': self.web_host}

    def __str__(self):
        headertext = '\t<headertext>%s</headertext>\n' % self.headertext
        options = ''
        if self.freetext_option or len(self.option_list) > 0:
            options = '\t<options>\n'
            for o in self.option_list:
                options += ('\t\t' + self.template_numbered_option_string % o
                            + '\n')
            if self.freetext_option:
                options += '\t\t' + self.freetext_option + '\n'
            options += '\t</options>\n'
        response = '<request>\n' + headertext + options + '</request>'
        return response
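
# Illustrative output (matching the reconstruction above): for headertext
# 'OK', accept_freetext(), web_host '127.0.0.1' and web_path '/foo',
# str(vmr) yields (indentation is tab characters in the actual string):
#
#   <request>
#       <headertext>OK</headertext>
#       <options>
#           <option command="1" order="1" callback="http://127.0.0.1/foo" display="False" ></option>
#       </options>
#   </request>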
vumi/transports/vodacom_messaging/tests/test_vodacom_messaging.py
import re
from xml.etree import ElementTree
from urllib import urlencode

from twisted.internet.defer import inlineCallbacks
from vumi.utils import http_request
from vumi.transports.vodacom_messaging import (VodacomMessagingResponse,
    VodacomMessagingTransport)
from vumi.message import TransportUserMessage
from vumi.transports.tests.helpers import TransportHelper
from vumi.tests.helpers import VumiTestCase


class TestVodacomMessagingTransport(VumiTestCase):

    @inlineCallbacks
    def setUp(self):
        self.config = {
            'transport_type': 'ussd',
            'ussd_string_prefix': '*120*666#',
            'web_path': "/foo",
            'web_host': "127.0.0.1",
            'web_port': 0,
            'username': 'testuser',
            'password': 'testpass',
        }
        self.tx_helper = self.add_helper(
            TransportHelper(VodacomMessagingTransport))
        self.transport = yield self.tx_helper.get_transport(self.config)
        self.transport_url = self.transport.get_transport_url().rstrip('/')

    @inlineCallbacks
    def test_inbound_new_continue(self):
        url = "%s%s?%s" % (
            self.transport_url,
            self.config['web_path'],
            urlencode({
                'ussdSessionId': 123,
                'msisdn': 555,
                'provider': 'web',
                'request': '*120*666#',
            }))
        d = http_request(url, '', method='GET')
        msg, = yield self.tx_helper.wait_for_dispatched_inbound(1)
        self.assertEqual(msg['transport_name'], self.tx_helper.transport_name)
        self.assertEqual(msg['transport_type'], "ussd")
        self.assertEqual(msg['transport_metadata'], {
            "session_id": "123"
        })
        self.assertEqual(msg['session_event'],
            TransportUserMessage.SESSION_NEW)
        self.assertEqual(msg['from_addr'], '555')
        self.assertEqual(msg['to_addr'], '*120*666#')
        self.assertEqual(msg['content'], '*120*666#')
        self.tx_helper.make_dispatch_reply(msg, "OK")
        response = yield d
        correct_response = '<request>\n\t<headertext>OK</headertext>\n\t' \
                '<options>\n\t\t<option command="1" order="1" ' \
                'callback="http://127.0.0.1/foo" display="False" ' \
                '></option>\n\t</options>\n</request>'
        self.assertEqual(response, correct_response)

    @inlineCallbacks
    def test_inbound_resume_continue(self):
        url = "%s%s?%s" % (
            self.transport_url,
            self.config['web_path'],
            urlencode({
                'ussdSessionId': 123,
                'msisdn': 555,
                'provider': 'web',
                'request': 1,
            })
        )
        d = http_request(url, '', method='GET')
        msg, = yield self.tx_helper.wait_for_dispatched_inbound(1)
        self.assertEqual(msg['transport_name'], self.tx_helper.transport_name)
        self.assertEqual(msg['transport_type'], "ussd")
        self.assertEqual(msg['transport_metadata'], {"session_id": "123"})
        self.assertEqual(
            msg['session_event'], TransportUserMessage.SESSION_RESUME)
        self.assertEqual(msg['from_addr'], '555')
        self.assertEqual(msg['to_addr'], '')
        self.assertEqual(msg['content'], '1')
        self.tx_helper.make_dispatch_reply(msg, "OK")
        response = yield d
        correct_response = '<request>\n\t<headertext>OK</headertext>\n\t' \
                '<options>\n\t\t<option command="1" order="1" ' \
                'callback="http://127.0.0.1/foo" display="False" ' \
                '></option>\n\t</options>\n</request>'
        self.assertEqual(response, correct_response)

    @inlineCallbacks
    def test_inbound_resume_close(self):
        url = "%s%s?%s" % (
            self.transport_url,
            self.config['web_path'],
            urlencode({
                'ussdSessionId': 123,
                'msisdn': 555,
                'provider': 'web',
                'request': 1,
            })
        )
        d = http_request(url, '', method='GET')
        msg, = yield self.tx_helper.wait_for_dispatched_inbound(1)
        self.assertEqual(msg['transport_name'], self.tx_helper.transport_name)
        self.assertEqual(msg['transport_type'], "ussd")
        self.assertEqual(msg['transport_metadata'], {"session_id": "123"})
        self.assertEqual(
            msg['session_event'], TransportUserMessage.SESSION_RESUME)
        self.assertEqual(msg['from_addr'], '555')
        self.assertEqual(msg['to_addr'], '')
        self.assertEqual(msg['content'], '1')
        self.tx_helper.make_dispatch_reply(msg, "OK", continue_session=False)
        response = yield d
        correct_response = '<request>\n\t<headertext>OK</headertext>' + \
                            '\n</request>'
        self.assertEqual(response, correct_response)

    @inlineCallbacks
    def test_nack(self):
        msg = yield self.tx_helper.make_dispatch_outbound("outbound")
        [nack] = yield self.tx_helper.wait_for_dispatched_events(1)
        self.assertEqual(nack['user_message_id'], msg['message_id'])
        self.assertEqual(nack['sent_message_id'], msg['message_id'])
        self.assertEqual(nack['nack_reason'],
            'Missing fields: in_reply_to')


class VodacomMessagingResponseTest(VumiTestCase):
    '''
    Test the construction of XML replies for Vodacom Messaging
    '''

    def setUp(self):
        self.web_host = 'vumi.p.org'
        self.web_path = '/api/v1/ussd/vmes/'

    def stdXML(self, obj):
        string = ElementTree.tostring(ElementTree.fromstring(str(obj)))
        return re.sub(r'\n\s*', '', string)

    def testMakeEndMessage(self):
        vmr = VodacomMessagingResponse(self.web_host, self.web_path)
        vmr.set_headertext("Goodbye")
        ref = '''
            <request>
                <headertext>Goodbye</headertext>
            </request>
            '''
        self.assertEquals(self.stdXML(vmr), self.stdXML(ref))

    def testMakeFreetextMessage(self):
        vmr = VodacomMessagingResponse(self.web_host, self.web_path)
        vmr.set_headertext("Please enter your name")
        vmr.accept_freetext()
        ref = '''
            <request>
                <headertext>Please enter your name</headertext>
                <options>
                    <option command="1" order="1"
                     callback="http://vumi.p.org/api/v1/ussd/vmes/"
                     display="False" ></option>
                </options>
            </request>
            '''
        self.assertEquals(self.stdXML(vmr), self.stdXML(ref))
        ref = '''
        <request>
            <headertext>Please enter your name</headertext>
            <options>
                <option command="1" order="1"
                 callback="http://vumi.p.org/api/v1/ussd/vmes/"
                 display="False" />
            </options>
        </request>
        '''
        self.assertEquals(self.stdXML(vmr), self.stdXML(ref))

    def testMakeOptionMessage(self):
        vmr = VodacomMessagingResponse(self.web_host, self.web_path)
        vmr.set_headertext("Pick a card")
        vmr.accept_freetext()
        vmr.add_option("Ace of diamonds")
        vmr.add_option("2 of clubs")
        vmr.add_option("3 of hearts")
        ref = '''
            <request>
                <headertext>Pick a card</headertext>
                <options>
                    <option command="1" order="1"
                     callback="http://vumi.p.org/api/v1/ussd/vmes/"
                     display="False" >Ace of diamonds</option>
                    <option command="2" order="2"
                     callback="http://vumi.p.org/api/v1/ussd/vmes/"
                     display="False" >2 of clubs</option>
                    <option command="3" order="3"
                     callback="http://vumi.p.org/api/v1/ussd/vmes/"
                     display="False" >3 of hearts</option>
                </options>
            </request>
            '''
        self.assertEquals(self.stdXML(vmr), self.stdXML(ref))

        vmr.accept_freetext()
        ref = '''
            <request>
                <headertext>Pick a card</headertext>
                <options>
                    <option command="1" order="1"
                     callback="http://vumi.p.org/api/v1/ussd/vmes/"
                     display="False" ></option>
                </options>
            </request>
            '''
        self.assertEquals(self.stdXML(vmr), self.stdXML(ref))

        vmr.add_option("King of spades")
        vmr.add_option("Queen of diamonds")
        vmr.add_option("Joker")
        ref = '''
            <request>
                <headertext>Pick a card</headertext>
                <options>
                    <option command="1" order="1"
                     callback="http://vumi.p.org/api/v1/ussd/vmes/"
                     display="False" >King of spades</option>
                    <option command="2" order="2"
                     callback="http://vumi.p.org/api/v1/ussd/vmes/"
                     display="False" >Queen of diamonds</option>
                    <option command="3" order="3"
                     callback="http://vumi.p.org/api/v1/ussd/vmes/"
                     display="False" >Joker</option>
                </options>
            </request>
            '''
        self.assertEquals(self.stdXML(vmr), self.stdXML(ref))

        ref = '''
        <request>
            <headertext>Pick a card</headertext>
            <options>
                <option command="1" order="1"
                 callback="http://vumi.p.org/api/v1/ussd/vmes/"
                 display="False" >King of spades</option>
                <option command="2" order="2"
                 callback="http://vumi.p.org/api/v1/ussd/vmes/"
                 display="False" >Queen of diamonds</option>
                <option command="3" order="3"
                 callback="http://vumi.p.org/api/v1/ussd/vmes/"
                 display="False" >Joker</option>
            </options>
        </request>
        '''
        self.assertEquals(self.stdXML(vmr), self.stdXML(ref))

    def testMakeOrderedOptionMessage(self):
        vmr = VodacomMessagingResponse(self.web_host, self.web_path)
        vmr.set_headertext("Pick a card")
        vmr.accept_freetext()
        vmr.add_option("3 of hearts", 3)
        vmr.add_option("2 of clubs", 2)
        vmr.add_option("Ace of diamonds", 1)
        ref = '''
            <request>
                <headertext>Pick a card</headertext>
                <options>
                    <option command="3" order="3"
                     callback="http://vumi.p.org/api/v1/ussd/vmes/"
                     display="False" >3 of hearts</option>
                    <option command="2" order="2"
                     callback="http://vumi.p.org/api/v1/ussd/vmes/"
                     display="False" >2 of clubs</option>
                    <option command="1" order="1"
                     callback="http://vumi.p.org/api/v1/ussd/vmes/"
                     display="False" >Ace of diamonds</option>
                </options>
            </request>
            '''
        self.assertEquals(self.stdXML(vmr), self.stdXML(ref))
vumi/transports/vodacom_messaging/tests/__init__.py

vumi/transports/devnull/__init__.py
from vumi.transports.devnull.devnull import DevNullTransport

__all__ = ['DevNullTransport']
vumi/transports/devnull/devnull.py
# -*- test-case-name: vumi.transports.devnull.tests.test_devnull -*-
import random
import uuid

from vumi import log

from vumi.transports.base import Transport
from vumi.message import TransportUserMessage

from twisted.internet.defer import inlineCallbacks


class DevNullTransport(Transport):
    """
    DevNullTransport for messages that need fake delivery to networks.
    Useful for testing.

    Configuration parameters:

    :type transport_type: str
    :param transport_type:
        The transport type to emulate, defaults to sms.
    :type ack_rate: float
    :param ack_rate:
        The fraction of outbound messages that should be ack'd; the
        remainder are nack'd. Float value between 0.0 and 1.0. The
        `failure_rate` and `reply_rate` apply only to the ack'd messages,
        treating them as 100%.
    :type failure_rate: float
    :param failure_rate:
        The fraction of ack'd messages that should be treated as failures.
        Float value between 0.0 and 1.0.
    :type reply_rate: float
    :param reply_rate:
        The fraction of ack'd messages for which a reply is generated.
        Float value between 0.0 and 1.0.
    :type reply_copy: str
    :param reply_copy:
        The copy to send as the reply; defaults to echoing the content
        of the outbound message.
    """

    def validate_config(self):
        self.transport_type = self.config.get('transport_type', 'sms')
        self.ack_rate = float(self.config['ack_rate'])
        self.failure_rate = float(self.config['failure_rate'])
        self.reply_rate = float(self.config['reply_rate'])
        self.reply_copy = self.config.get('reply_copy')

    def setup_transport(self):
        pass

    def teardown_transport(self):
        pass

    @inlineCallbacks
    def handle_outbound_message(self, message):
        if random.random() > self.ack_rate:
            yield self.publish_nack(message['message_id'],
                'Not accepted by network')
            return

        dr = ('failed' if random.random() < self.failure_rate
                else 'delivered')
        log.info('MT %(dr)s: %(from_addr)s -> %(to_addr)s: %(content)s' % {
            'dr': dr,
            'from_addr': message['from_addr'],
            'to_addr': message['to_addr'],
            'content': message['content'],
            })
        yield self.publish_ack(message['message_id'],
            sent_message_id=uuid.uuid4().hex)
        yield self.publish_delivery_report(message['message_id'], dr)
        if random.random() < self.reply_rate:
            reply_copy = self.reply_copy or message['content']
            log.info('MO %(from_addr)s -> %(to_addr)s: %(content)s' % {
                'from_addr': message['to_addr'],
                'to_addr': message['from_addr'],
                'content': reply_copy,
                })
            yield self.publish_message(
                message_id=uuid.uuid4().hex,
                content=reply_copy,
                to_addr=message['from_addr'],
                from_addr=message['to_addr'],
                provider='devnull',
                session_event=TransportUserMessage.SESSION_NONE,
                transport_type=self.transport_type,
                transport_metadata={})
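

# Example config sketch (hypothetical values): with
#
#   {'transport_type': 'sms', 'ack_rate': 1.0,
#    'failure_rate': 0.1, 'reply_rate': 0.5}
#
# every outbound message is ack'd, roughly 10% of delivery reports come back
# 'failed', and about half of the messages trigger an echoed inbound reply.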
vumi/transports/devnull/tests/test_devnull.py
from twisted.internet.defer import inlineCallbacks

from vumi.tests.helpers import VumiTestCase
from vumi.transports.devnull import DevNullTransport
from vumi.tests.utils import LogCatcher
from vumi.transports.tests.helpers import TransportHelper


class TestDevNullTransport(VumiTestCase):

    def setUp(self):
        self.tx_helper = self.add_helper(TransportHelper(DevNullTransport))

    @inlineCallbacks
    def test_outbound_logging(self):
        yield self.tx_helper.get_transport({
            'ack_rate': 1,
            'failure_rate': 0,
            'reply_rate': 1,
        })
        with LogCatcher() as logger:
            msg = yield self.tx_helper.make_dispatch_outbound("outbound")
        log_msg = logger.messages()[0]
        self.assertTrue(msg['to_addr'] in log_msg)
        self.assertTrue(msg['from_addr'] in log_msg)
        self.assertTrue(msg['content'] in log_msg)

    @inlineCallbacks
    def test_ack_publishing(self):
        yield self.tx_helper.get_transport({
            'ack_rate': 1,
            'failure_rate': 0.2,
            'reply_rate': 0.8,
        })
        yield self.tx_helper.make_dispatch_outbound("outbound")
        [ack, dr] = self.tx_helper.get_dispatched_events()
        self.assertEqual(ack['event_type'], 'ack')
        self.assertEqual(dr['event_type'], 'delivery_report')

    @inlineCallbacks
    def test_nack_publishing(self):
        yield self.tx_helper.get_transport({
            'ack_rate': 0,
            'failure_rate': 0.2,
            'reply_rate': 0.8,
        })
        yield self.tx_helper.make_dispatch_outbound("outbound")
        [nack] = self.tx_helper.get_dispatched_events()
        self.assertEqual(nack['event_type'], 'nack')

    @inlineCallbacks
    def test_reply_sending(self):
        yield self.tx_helper.get_transport({
            'ack_rate': 1,
            'failure_rate': 0,
            'reply_rate': 1,
        })

        msg = yield self.tx_helper.make_dispatch_outbound("outbound")
        [reply_msg] = self.tx_helper.get_dispatched_inbound()
        self.assertEqual(msg['content'], reply_msg['content'])
        self.assertEqual(msg['to_addr'], reply_msg['from_addr'])
        self.assertEqual(msg['from_addr'], reply_msg['to_addr'])
vumi/transports/devnull/tests/__init__.py

vumi/transports/apposit/__init__.py
from vumi.transports.apposit.apposit import AppositTransport

__all__ = ['AppositTransport']
vumi/transports/apposit/apposit.py
# -*- test-case-name: vumi.transports.apposit.tests.test_apposit -*-

import json
from urllib import urlencode

from twisted.web import http
from twisted.internet.defer import inlineCallbacks

from vumi import log
from vumi.utils import http_request_full
from vumi.config import ConfigDict, ConfigText
from vumi.transports.httprpc import HttpRpcTransport


class AppositTransportConfig(HttpRpcTransport.CONFIG_CLASS):
    """Apposit transport config."""

    credentials = ConfigDict(
        "A dictionary where the `from_addr` is used for the key lookup and "
        "the returned value should be a dictionary containing the "
        "corresponding username, password and service id.",
        required=True, static=True)
    outbound_url = ConfigText(
        "The URL to send outbound messages to.", required=True, static=True)


class AppositTransport(HttpRpcTransport):
    """
    HTTP transport for Apposit's interconnection services.
    """
    agent_factory = None  # For swapping out the Agent we use in tests.

    ENCODING = 'utf-8'
    CONFIG_CLASS = AppositTransportConfig
    CONTENT_TYPE = 'application/x-www-form-urlencoded'

    # Apposit supports multiple channel types (e.g. sms, ussd, ivr, email).
    # Currently, we only have this working for sms, but theoretically, this
    # transport could support other channel types that have corresponding
    # vumi transport types. However, supporting other channels may require a
    # bit more work if they behave too differently from the sms channel. For
    # example, support for Apposit's ussd channel will probably require
    # session management, which currently isn't included, since the sms
    # channel does not need it.
    CHANNEL_LOOKUP = {'sms': 'SMS'}
    TRANSPORT_TYPE_LOOKUP = dict(
        reversed(i) for i in CHANNEL_LOOKUP.iteritems())
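    # e.g. with CHANNEL_LOOKUP == {'sms': 'SMS'}, the inverted mapping is
    # TRANSPORT_TYPE_LOOKUP == {'SMS': 'sms'}, so inbound 'SMS' requests map
    # back to the 'sms' vumi transport type.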

    EXPECTED_FIELDS = frozenset(['from', 'to', 'channel', 'content', 'isTest'])

    KNOWN_ERROR_RESPONSE_CODES = {
        '102001': "Username Not Set",
        '102002': "Password Not Set",
        '102003': "Username or password is invalid or not authorized",
        '102004': "Service ID Not Set",
        '102005': "Invalid Service Id",
        '102006': "Service Not Found",
        '102007': "Content not set",
        '102008': "To Address Not Set",
        '102009': "From Address Not Set",
        '102010': "Channel Not Set",
        '102011': "Invalid Channel",
        '102012': "The address provided is not subscribed",
        '102013': "The message content id is unregistered or not approved",
        '102014': "Message Content Public ID and Message Content Set",
        '102015': "Message Content Public ID or Message Content Not Set",
        '102022': "One or more messages failed while sending",
        '102024': "Outbound message routing not allowed for service",
        '102025': "Content or Content ID is not Approved",
        '102999': "Other Runtime Error",
    }

    UNKNOWN_RESPONSE_CODE_ERROR = "Response with unknown code received: %s"
    UNSUPPORTED_TRANSPORT_TYPE_ERROR = (
        "No corresponding channel exists for transport type: %s")

    def validate_config(self):
        config = self.get_static_config()
        self.credentials = config.credentials
        self.outbound_url = config.outbound_url
        return super(AppositTransport, self).validate_config()

    @inlineCallbacks
    def handle_raw_inbound_message(self, message_id, request):
        values, errors = self.get_field_values(request, self.EXPECTED_FIELDS)

        channel = values.get('channel')
        if channel is not None and channel not in self.CHANNEL_LOOKUP.values():
            errors['unsupported_channel'] = channel

        if errors:
            log.msg('Unhappy incoming message: %s' % (errors,))
            yield self.finish_request(message_id, json.dumps(errors),
                                      code=http.BAD_REQUEST)
            return

        self.emit("AppositTransport receiving inbound message from "
                  "%(from)s to %(to)s" % values)

        yield self.publish_message(
            transport_name=self.transport_name,
            message_id=message_id,
            content=values['content'],
            from_addr=values['from'],
            to_addr=values['to'],
            provider='apposit',
            transport_type=self.TRANSPORT_TYPE_LOOKUP[channel],
            transport_metadata={'apposit': {'isTest': values['isTest']}})

        yield self.finish_request(
            message_id, json.dumps({'message_id': message_id}))

    @inlineCallbacks
    def handle_outbound_message(self, message):
        channel = self.CHANNEL_LOOKUP.get(message['transport_type'])
        if channel is None:
            reason = (self.UNSUPPORTED_TRANSPORT_TYPE_ERROR
                      % message['transport_type'])
            log.msg(reason)
            yield self.publish_nack(message['message_id'], reason)
            return

        self.emit("Sending outbound message: %s" % (message,))

        # build the params dict and ensure each param is encoded correctly
        credentials = self.credentials.get(message['from_addr'], {})
        params = dict((k, v.encode(self.ENCODING)) for k, v in {
            'username': credentials.get('username', ''),
            'password': credentials.get('password', ''),
            'serviceId': credentials.get('service_id', ''),
            'fromAddress': message['from_addr'],
            'toAddress': message['to_addr'],
            'content': message['content'],
            'channel': channel,
        }.iteritems())

        self.emit("Making HTTP POST request: %s with body %s" %
                  (self.outbound_url, params))

        response = yield http_request_full(
            self.outbound_url,
            data=urlencode(params),
            method='POST',
            headers={'Content-Type': self.CONTENT_TYPE},
            agent_class=self.agent_factory)

        self.emit("Response: (%s) %r" %
                  (response.code, response.delivered_body))

        response_content = response.delivered_body.strip()
        if response.code == http.OK:
            yield self.publish_ack(user_message_id=message['message_id'],
                                   sent_message_id=message['message_id'])
        else:
            error = self.KNOWN_ERROR_RESPONSE_CODES.get(response_content)
            if error is not None:
                reason = "(%s) %s" % (response_content, error)
            else:
                reason = self.UNKNOWN_RESPONSE_CODE_ERROR % response_content
            log.msg(reason)
            yield self.publish_nack(message['message_id'], reason)
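

# Outbound request sketch (hypothetical values): for a message from '8123'
# configured with username 'root', password 'toor' and service id
# 'service-id-1', AppositTransport POSTs an urlencoded body such as
#   username=root&password=toor&serviceId=service-id-1&fromAddress=8123
#   &toAddress=251911223344&content=hello&channel=SMS
# to `outbound_url` with the form-urlencoded content type.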
vumi/transports/apposit/tests/__init__.py

vumi/transports/apposit/tests/test_apposit.py
# -*- encoding: utf-8 -*-

import json
from urllib import urlencode

from twisted.web import http
from twisted.internet.defer import inlineCallbacks, DeferredQueue

from vumi.utils import http_request_full
from vumi.transports.apposit import AppositTransport
from vumi.tests.fake_connection import FakeHttpServer
from vumi.tests.helpers import VumiTestCase
from vumi.transports.tests.helpers import TransportHelper


class TestAppositTransport(VumiTestCase):

    @inlineCallbacks
    def setUp(self):
        self.fake_http = FakeHttpServer(self.handle_inbound_request)
        self.outbound_requests = DeferredQueue()
        self.fake_http_response = ''
        self.fake_http_response_code = http.OK
        self.base_url = "http://apposit.example.com/"

        config = {
            'web_path': 'api/v1/apposit/sms',
            'web_port': 0,
            'credentials': {
                '8123': {
                    'username': 'root',
                    'password': 'toor',
                    'service_id': 'service-id-1',
                },
                '8124': {
                    'username': 'admin',
                    'password': 'nimda',
                    'service_id': 'service-id-2',
                }
            },
            'outbound_url': self.base_url,
        }
        self.tx_helper = self.add_helper(
            TransportHelper(
                AppositTransport, transport_addr='8123',
                mobile_addr='251911223344'))
        self.transport = yield self.tx_helper.get_transport(config)
        self.transport.agent_factory = self.fake_http.get_agent
        self.transport_url = self.transport.get_transport_url()
        self.web_path = config['web_path']

    def send_full_inbound_request(self, **params):
        return http_request_full(
            '%s%s' % (self.transport_url, self.web_path),
            data=urlencode(params),
            method='POST',
            headers={'Content-Type': self.transport.CONTENT_TYPE})

    def send_inbound_request(self, **kwargs):
        params = {
            'from': '251911223344',
            'to': '8123',
            'channel': 'SMS',
            'content': 'never odd or even',
            'isTest': 'true',
        }
        params.update(kwargs)
        return self.send_full_inbound_request(**params)

    def handle_inbound_request(self, request):
        self.outbound_requests.put(request)
        request.setResponseCode(self.fake_http_response_code)
        return self.fake_http_response

    def set_fake_http_response(self, code=http.OK, body=''):
        self.fake_http_response_code = code
        self.fake_http_response = body

    def assert_outbound_request(self, request, **kwargs):
        expected_args = {
            'username': 'root',
            'password': 'toor',
            'serviceId': 'service-id-1',
            'fromAddress': '8123',
            'toAddress': '251911223344',
            'content': 'so many dynamos',
            'channel': 'SMS',
        }
        expected_args.update(kwargs)

        self.assertEqual(request.path, self.base_url)
        self.assertEqual(request.method, 'POST')
        self.assertEqual(dict((k, [v]) for k, v in expected_args.iteritems()),
                         request.args)
        self.assertEqual(request.getHeader('Content-Type'),
                         self.transport.CONTENT_TYPE)

    def assert_message_fields(self, msg, **kwargs):
        fields = {
            'transport_name': self.tx_helper.transport_name,
            'transport_type': 'sms',
            'from_addr': '251911223344',
            'to_addr': '8123',
            'content': 'so many dynamos',
            'provider': 'apposit',
            'transport_metadata': {'apposit': {'isTest': 'true'}},
        }
        fields.update(kwargs)

        for field_name, expected_value in fields.iteritems():
            self.assertEqual(msg[field_name], expected_value)

    def assert_ack(self, ack, msg):
        self.assertEqual(ack.payload['event_type'], 'ack')
        self.assertEqual(ack.payload['user_message_id'], msg['message_id'])
        self.assertEqual(ack.payload['sent_message_id'], msg['message_id'])

    def assert_nack(self, nack, msg, reason):
        self.assertEqual(nack.payload['event_type'], 'nack')
        self.assertEqual(nack.payload['user_message_id'], msg['message_id'])
        self.assertEqual(nack.payload['nack_reason'], reason)

    @inlineCallbacks
    def test_inbound(self):
        response = yield self.send_inbound_request(**{
            'from': '251911223344',
            'to': '8123',
            'content': 'so many dynamos',
            'channel': 'SMS',
            'isTest': 'true',
        })

        [msg] = self.tx_helper.get_dispatched_inbound()
        self.assert_message_fields(
            msg,
            transport_name=self.tx_helper.transport_name,
            transport_type='sms',
            from_addr='251911223344',
            to_addr='8123',
            content='so many dynamos',
            provider='apposit',
            transport_metadata={'apposit': {'isTest': 'true'}})

        self.assertEqual(response.code, http.OK)
        self.assertEqual(json.loads(response.delivered_body),
                         {'message_id': msg['message_id']})

    @inlineCallbacks
    def test_outbound(self):
        msg = yield self.tx_helper.make_dispatch_outbound('racecar')

        request = yield self.outbound_requests.get()
        self.assert_outbound_request(request, **{
            'username': 'root',
            'password': 'toor',
            'serviceId': 'service-id-1',
            'content': 'racecar',
            'fromAddress': '8123',
            'toAddress': '251911223344',
            'channel': 'SMS'
        })

        [ack] = yield self.tx_helper.wait_for_dispatched_events(1)
        self.assert_ack(ack, msg)

    @inlineCallbacks
    def test_inbound_requests_for_non_ascii_content(self):
        response = yield self.send_inbound_request(
            content=u'Hliðskjálf'.encode('UTF-8'))
        [msg] = self.tx_helper.get_dispatched_inbound()
        self.assert_message_fields(msg, content=u'Hliðskjálf')

        self.assertEqual(response.code, http.OK)
        self.assertEqual(json.loads(response.delivered_body),
                         {'message_id': msg['message_id']})

    @inlineCallbacks
    def test_inbound_requests_for_unsupported_channel(self):
        response = yield self.send_full_inbound_request(**{
            'from': '251911223344',
            'to': '8123',
            'channel': 'steven',
            'content': 'never odd or even',
            'isTest': 'false',
        })

        self.assertEqual(response.code, 400)
        self.assertEqual(json.loads(response.delivered_body),
                         {'unsupported_channel': 'steven'})

    @inlineCallbacks
    def test_inbound_requests_for_unexpected_param(self):
        response = yield self.send_full_inbound_request(**{
            'from': '251911223344',
            'to': '8123',
            'channel': 'SMS',
            'steven': 'its a trap',
            'content': 'never odd or even',
            'isTest': 'false',
        })

        self.assertEqual(response.code, 400)
        self.assertEqual(json.loads(response.delivered_body),
                         {'unexpected_parameter': ['steven']})

    @inlineCallbacks
    def test_inbound_requests_for_missing_param(self):
        response = yield self.send_full_inbound_request(**{
            'from': '251911223344',
            'to': '8123',
            'content': 'never odd or even',
            'isTest': 'false',
        })

        self.assertEqual(response.code, 400)
        self.assertEqual(json.loads(response.delivered_body),
                         {'missing_parameter': ['channel']})

    @inlineCallbacks
    def test_outbound_request_credential_selection(self):
        msg1 = yield self.tx_helper.make_dispatch_outbound(
            'so many dynamos', from_addr='8123')
        request1 = yield self.outbound_requests.get()
        self.assert_outbound_request(
            request1,
            fromAddress='8123',
            username='root',
            password='toor',
            serviceId='service-id-1')

        msg2 = yield self.tx_helper.make_dispatch_outbound(
            'so many dynamos', from_addr='8124')
        request2 = yield self.outbound_requests.get()
        self.assert_outbound_request(
            request2,
            fromAddress='8124',
            username='admin',
            password='nimda',
            serviceId='service-id-2')

        [ack1, ack2] = yield self.tx_helper.wait_for_dispatched_events(2)
        self.assert_ack(ack1, msg1)
        self.assert_ack(ack2, msg2)

    @inlineCallbacks
    def test_outbound_requests_for_non_ascii_content(self):
        msg = yield self.tx_helper.make_dispatch_outbound(u'Hliðskjálf')
        request = yield self.outbound_requests.get()
        self.assert_outbound_request(request, content='Hliðskjálf')

        [ack] = yield self.tx_helper.wait_for_dispatched_events(1)
        self.assert_ack(ack, msg)

    @inlineCallbacks
    def test_outbound_requests_for_known_error_responses(self):
        code = '102999'
        self.set_fake_http_response(http.BAD_REQUEST, code)

        msg = yield self.tx_helper.make_dispatch_outbound('racecar')

        [nack] = yield self.tx_helper.wait_for_dispatched_events(1)
        self.assert_nack(nack, msg, "(%s) %s" % (
            code, self.transport.KNOWN_ERROR_RESPONSE_CODES[code]))

    @inlineCallbacks
    def test_outbound_requests_for_unknown_error_responses(self):
        code = '103000'
        self.set_fake_http_response(http.BAD_REQUEST, code)

        msg = yield self.tx_helper.make_dispatch_outbound("so many dynamos")

        [nack] = yield self.tx_helper.wait_for_dispatched_events(1)
        self.assert_nack(
            nack, msg, self.transport.UNKNOWN_RESPONSE_CODE_ERROR % code)

    @inlineCallbacks
    def test_outbound_requests_for_unsupported_transport_types(self):
        transport_type = 'steven'
        msg = yield self.tx_helper.make_dispatch_outbound(
            "so many dynamos", transport_type=transport_type)

        [nack] = yield self.tx_helper.wait_for_dispatched_events(1)
        self.assert_nack(
            nack, msg,
            self.transport.UNSUPPORTED_TRANSPORT_TYPE_ERROR % transport_type)
vumi/transports/integrat/utils.py
# -*- test-case-name: vumi.transports.integrat.tests.test_utils -*-

from xml.etree import ElementTree


def safetext(element):
    return element.text or ''


class HigateXMLParser(object):

    def parse(self, xmlstring):
        element = ElementTree.fromstring(xmlstring)

        messagedict = {}
        try:
            responselist = element.find("Response").items()
            for i in responselist:
                messagedict[i[0]] = i[1]
        except Exception:
            pass
        try:
            requestlist = element.find("Request").items()
            for i in requestlist:
                messagedict[i[0]] = i[1]
        except Exception:
            pass

        ##############  Conditional checks ##########################

        if messagedict.get('Type') == "OnResult":
            resultlist = element.find("Response").find("OnResult").items()
            for i in resultlist:
                messagedict[i[0]] = i[1]

        if messagedict.get('Type') == "SendSMS":
            pass  # TODO

        if messagedict.get('Type') == "OnReceiveSMS":
            receivelist = element.find("Response").find("OnReceiveSMS").items()
            hex = safetext(element.find("Response").find("OnReceiveSMS").find(
                "Content"))
            messagedict['hex'] = hex
            for i in receivelist:
                messagedict[i[0]] = i[1]

        if messagedict.get('Type') == "OnOBSResponse":
            pass  # TODO

        if messagedict.get('Type') == "OnLBSResponse":
            pass  # TODO

        if messagedict.get('Type') == "OnUSSEvent":
            contextlist = element.find("Response").find("OnUSSEvent").find(
                "USSContext").items()
            if element.find("Response").find("OnUSSEvent").find(
                    "USSText") is not None:
                USSText = safetext(element.find("Response").find(
                        "OnUSSEvent").find("USSText"))

                messagedict['USSText'] = USSText

            messagedict['EventType'] = element.find("Response").find(
                "OnUSSEvent").attrib['Type']

            for i in contextlist:
                messagedict[i[0]] = i[1]

        if messagedict.get('Type') == "USSReply":
            UserID = safetext(element.find("Request").find("UserID"))
            Password = safetext(element.find("Request").find("Password"))
            USSText = safetext(element.find("Request").find("USSText"))
            messagedict['UserID'] = UserID
            messagedict['Password'] = Password
            messagedict['USSText'] = USSText

        #############################################################

        return messagedict

    def parse_response(self, xmlstring):
        element = ElementTree.fromstring(xmlstring)
        status_code = int(element.get('status_code'))
        if not status_code:
            return {}

        data = element.find('Data')
        error_elements = element.findall('Data/field')
        messagedict = {
            'status_code': status_code,
            'error': data.get('name'),
            'error_fields': [{f.get('name'): f.get('value')}
                                for f in error_elements],
        }

        return messagedict

    def build(self, messagedict):
        message = ElementTree.Element("Message")
        version = ElementTree.SubElement(message, "Version")
        version.set("Version", "1.0")

        ##############  Conditional checks ##########################

        if messagedict.get("Type") == "USSReply":
            request = ElementTree.SubElement(message, "Request")
            request.set("Type", messagedict.get("Type"))
            request.set("SessionID", messagedict.get("SessionID", ""))
            request.set("Flags", messagedict.get("Flags", "0"))
            userid = ElementTree.SubElement(request, "UserID")
            userid.set("Orientation", "TR")
            userid.text = messagedict.get("UserID", "")
            password = ElementTree.SubElement(request, "Password")
            password.text = messagedict.get("Password", "")
            usstext = ElementTree.SubElement(request, "USSText")
            usstext.set("Type", "TEXT")
            usstext.text = messagedict.get("USSText", "")

        #############################################################

        return ElementTree.tostring(message)
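

# Illustrative sketch (hypothetical values): build a USSReply document with
# HigateXMLParser.build; run this module directly to print the XML.
if __name__ == '__main__':
    _parser = HigateXMLParser()
    print _parser.build({
        'Type': 'USSReply',
        'SessionID': '223665',
        'Flags': '0',
        'UserID': 'LoginName',
        'Password': 'secret',
        'USSText': 'Welcome to this USSD session',
    })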
vumi/transports/integrat/integrat.py
# -*- test-case-name: vumi.transports.integrat.tests.test_integrat -*-

from twisted.internet.defer import inlineCallbacks
from twisted.web import http
from twisted.web.resource import Resource

from vumi.utils import http_request, normalize_msisdn
from vumi.message import TransportUserMessage
from vumi.transports.integrat.utils import HigateXMLParser
from vumi.transports import Transport

hxg = HigateXMLParser()


class IntegratHttpResource(Resource):
    isLeaf = True

    # map events to session event types
    EVENT_TYPE_MAP = {
        'open': TransportUserMessage.SESSION_NEW,
        'close': TransportUserMessage.SESSION_CLOSE,
        'resume': TransportUserMessage.SESSION_RESUME,
        }

    # Integrat sends both 'new' and 'open' events but
    # we only pass 'open' events on ('open' is the more
    # complete and reliable of the two in Integrat's case).
    EVENTS_TO_SKIP = set(['new'])
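    # e.g. a fresh session arrives as a 'new' event followed by an 'open'
    # event: the 'new' event is dropped and the 'open' event becomes
    # SESSION_NEW.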

    def __init__(self, transport_name, transport_type, publish_message):
        self.transport_name = transport_name
        self.transport_type = transport_type
        self.publish_message = publish_message

    def render(self, request):
        request.setResponseCode(http.OK)
        request.setHeader('Content-Type', 'text/plain')
        hxg_msg = hxg.parse(request.content.read())

        if hxg_msg.get('Type') != 'OnUSSEvent':
            # TODO: add support for non-USSD messages
            return ''

        text = hxg_msg.get('USSText', '').strip()

        if hxg_msg['EventType'] == 'Request':
            if text == 'REQ':
                # This indicates a new session event but Integrat
                # also sends a non-request message with type 'open'
                # below (and that is the one we use to trigger Vumi's
                # new session message).
                return ''
            else:
                session_event = TransportUserMessage.SESSION_RESUME
        else:
            event_type = hxg_msg['EventType'].lower()
            if event_type in self.EVENTS_TO_SKIP:
                return ''
            session_event = self.EVENT_TYPE_MAP.get(
                event_type, TransportUserMessage.SESSION_RESUME)

        if session_event != TransportUserMessage.SESSION_RESUME:
            text = None

        transport_metadata = {
            'session_id': hxg_msg['SessionID'],
            }
        self.publish_message(
            from_addr=normalize_msisdn(hxg_msg['MSISDN']),
            to_addr=hxg_msg['ConnStr'],
            session_event=session_event,
            content=text,
            transport_name=self.transport_name,
            transport_type=self.transport_type,
            transport_metadata=transport_metadata,
            )
        return ''


class HealthResource(Resource):
    isLeaf = True

    def render(self, request):
        request.setResponseCode(http.OK)
        request.do_not_log = True
        return 'OK'


class IntegratTransport(Transport):
    """Integrat USSD transport over HTTP."""

    agent_factory = None  # For swapping out the Agent we use in tests.

    def validate_config(self):
        """
        Transport-specific config validation happens in here.
        """
        self.web_path = self.config['web_path']
        self.web_port = int(self.config['web_port'])
        self.integrat_url = self.config['url']
        self.integrat_username = self.config['username']
        self.integrat_password = self.config['password']
        self.transport_type = self.config.get('transport_type', 'ussd')

    @inlineCallbacks
    def setup_transport(self):
        """
        All transport_specific setup should happen in here.
        """
        integrat_resource = IntegratHttpResource(
            self.transport_name, self.transport_type, self.publish_message)
        self.web_resource = yield self.start_web_resources([
            (integrat_resource, self.web_path),
            (HealthResource(), 'health'),
        ], self.web_port)

    @inlineCallbacks
    def teardown_transport(self):
        yield self.web_resource.loseConnection()

    @inlineCallbacks
    def handle_outbound_message(self, message):
        text = message['content']
        if text is None:
            text = ''
        flags = '0'
        if message['session_event'] == message.SESSION_CLOSE:
            flags = '1'
        session_id = message['transport_metadata']['session_id']
        response = yield http_request(self.integrat_url, hxg.build({
            'Flags': flags,
            'SessionID': session_id,
            'Type': 'USSReply',
            'USSText': text,
            'Password': self.integrat_password,
            'UserID': self.integrat_username,
        }), headers={
            'Content-Type': ['text/xml; charset=utf-8']
        }, agent_class=self.agent_factory)
        error = hxg.parse_response(response)
        if not error:
            yield self.publish_ack(user_message_id=message['message_id'],
                                   sent_message_id=message['message_id'])
        else:
            yield self.publish_nack(
                user_message_id=message['message_id'],
                sent_message_id=message['message_id'],
                reason=', '.join([': '.join(ef.items()[0])
                                  for ef in error['error_fields']]))
vumi/transports/integrat/__init__.py
"""
Integrat HTTP USSD API.
"""

from vumi.transports.integrat.integrat import IntegratTransport


__all__ = ['IntegratTransport']
vumi/transports/integrat/tests/test_integrat.py
# -*- encoding: utf-8 -*-

from twisted.internet.defer import inlineCallbacks, DeferredQueue

from vumi.utils import http_request
from vumi.message import TransportUserMessage
from vumi.transports.integrat.integrat import (
    IntegratHttpResource, IntegratTransport)
from vumi.transports.tests.helpers import TransportHelper
from vumi.tests.fake_connection import FakeHttpServer
from vumi.tests.helpers import VumiTestCase


XML_TEMPLATE = '''
<Message>
    <Version Version="1.0"/>
    <Response Type="OnUSSEvent">
        <SystemID>Higate</SystemID>
        <UserID>LoginName</UserID>
        <Service>SERVICECODE</Service>
        <Network ID="1"/>
        <OnUSSEvent Type="%(ussd_type)s">
            <USSContext SessionID="%(sid)s" NetworkSID="%(network_sid)s"
             MSISDN="%(msisdn)s" Script="testscript" ConnStr="%(connstr)s"/>
            <USSText Type="TEXT">%(text)s</USSText>
        </OnUSSEvent>
    </Response>
</Message>
'''


class TestIntegratHttpResource(VumiTestCase):

    DEFAULT_MSG = {
        'from_addr': '+2799053421',
        'to_addr': '*120*44#',
        'transport_metadata': {
            'session_id': 'sess1234',
            },
        }

    def setUp(self):
        self.msgs = []
        resource = IntegratHttpResource("testgrat", "ussd", self._publish)
        self.fake_http = FakeHttpServer(lambda req: resource.render(req))

    def _publish(self, **kws):
        self.msgs.append(kws)

    def send_request(self, xml):
        return http_request(
            "/", xml, method='GET', agent_class=self.fake_http.get_agent)

    @inlineCallbacks
    def check_response(self, xml, responses):
        """Check that sending the given XML results in the given responses."""
        yield self.send_request(xml)
        for msg, expected_override in zip(self.msgs, responses):
            expected = self.DEFAULT_MSG.copy()
            expected.update(expected_override)
            for key, value in expected.items():
                self.assertEqual(msg[key], value)
        self.assertEqual(len(self.msgs), len(responses))
        del self.msgs[:]

    def make_ussd(self, ussd_type, sid=None, network_sid="netsid12345",
                  msisdn=None, connstr=None, text=""):
        sid = (self.DEFAULT_MSG['transport_metadata']['session_id']
               if sid is None else sid)
        msisdn = self.DEFAULT_MSG['from_addr'] if msisdn is None else msisdn
        connstr = self.DEFAULT_MSG['to_addr'] if connstr is None else connstr
        return XML_TEMPLATE % {
            'ussd_type': ussd_type, 'sid': sid, 'network_sid': network_sid,
            'msisdn': msisdn, 'connstr': connstr, 'text': text,
            }

    @inlineCallbacks
    def test_new_session(self):
        # this should not generate a message since we use
        # the 'open' event to start a new session.
        xml = self.make_ussd(ussd_type='New', text="")
        yield self.check_response(xml, [])

    @inlineCallbacks
    def test_new_session_via_request(self):
        # this should not generate a message since we use
        # the 'open' event to start a new session.
        xml = self.make_ussd(ussd_type='Request', text="REQ")
        yield self.check_response(xml, [])

    @inlineCallbacks
    def test_open_session(self):
        xml = self.make_ussd(ussd_type='Open', text="")
        yield self.check_response(xml, [{
            'session_event': TransportUserMessage.SESSION_NEW,
            'content': None,
            }])

    @inlineCallbacks
    def test_resume_session(self):
        xml = self.make_ussd(ussd_type='Request', text="foo")
        yield self.check_response(xml, [{
            'session_event': TransportUserMessage.SESSION_RESUME,
            'content': 'foo',
            }])

    @inlineCallbacks
    def test_non_ussd(self):
        xml = """
        
            
            
            
              06052677F6A565 ...etc
            
            
        
        """
        yield self.check_response(xml, [])


class TestIntegratTransport(VumiTestCase):

    @inlineCallbacks
    def setUp(self):
        self.integrat_calls = DeferredQueue()
        self.fake_http = FakeHttpServer(self.handle_request)
        self.base_url = "http://integrat.example.com/"
        config = {
            'web_path': "foo",
            'web_port': "0",
            'url': self.base_url,
            'username': 'testuser',
            'password': 'testpass',
            }
        self.tx_helper = self.add_helper(TransportHelper(IntegratTransport))
        self.transport = yield self.tx_helper.get_transport(config)
        self.transport.agent_factory = self.fake_http.get_agent
        addr = self.transport.web_resource.getHost()
        self.transport_url = "http://%s:%s/" % (addr.host, addr.port)
        self.higate_response = ''

    def handle_request(self, request):
        # The content attribute will have been set to None by the time the
        # test reads it, so stash the body now.
        request.content_body = request.content.getvalue()
        self.integrat_calls.put(request)
        return self.higate_response

    @inlineCallbacks
    def test_health(self):
        result = yield http_request(self.transport_url + "health", "",
                                    method='GET')
        self.assertEqual(result, "OK")

    @inlineCallbacks
    def test_outbound(self):
        yield self.tx_helper.make_dispatch_outbound("hi", transport_metadata={
            'session_id': "sess123",
        })
        req = yield self.integrat_calls.get()
        self.assertEqual(req.path, self.base_url)
        self.assertEqual(req.method, 'POST')
        self.assertEqual(req.getHeader('content-type'),
                         'text/xml; charset=utf-8')
        self.assertEqual(req.content_body,
                         '<Message>'
                         '<Version Version="1.0" />'
                         '<Request Flags="0" SessionID="sess123" '
                         'Type="USSReply">'
                         '<UserID Orientation="TR">testuser</UserID>'
                         '<Password>testpass</Password>'
                         '<USSText Type="TEXT">hi</USSText>'
                         '</Request></Message>')

    @inlineCallbacks
    def test_outbound_no_content(self):
        yield self.tx_helper.make_dispatch_outbound(None, transport_metadata={
            'session_id': "sess123",
        })
        req = yield self.integrat_calls.get()
        self.assertEqual(req.path, self.base_url)
        self.assertEqual(req.method, 'POST')
        self.assertEqual(req.getHeader('content-type'),
                         'text/xml; charset=utf-8')
        self.assertEqual(req.content_body,
                         '<Message>'
                         '<Version Version="1.0" />'
                         '<Request Flags="0" SessionID="sess123" '
                         'Type="USSReply">'
                         '<UserID Orientation="TR">testuser</UserID>'
                         '<Password>testpass</Password>'
                         '<USSText Type="TEXT" />'
                         '</Request></Message>')

    @inlineCallbacks
    def test_inbound(self):
        xml = XML_TEMPLATE % {
            'ussd_type': 'Request',
            'sid': 'sess1234',
            'network_sid': "netsid12345",
            'msisdn': '27345',
            'connstr': '*120*99#',
            'text': 'foobar',
            }
        yield http_request(self.transport_url + "foo", xml, method='GET')
        [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1)
        self.assertEqual(msg['transport_name'], self.tx_helper.transport_name)
        self.assertEqual(msg['transport_type'], "ussd")
        self.assertEqual(msg['transport_metadata'],
                         {"session_id": "sess1234"})
        self.assertEqual(msg['session_event'],
                         TransportUserMessage.SESSION_RESUME)
        self.assertEqual(msg['from_addr'], '27345')
        self.assertEqual(msg['to_addr'], '*120*99#')
        self.assertEqual(msg['content'], 'foobar')

    @inlineCallbacks
    def test_inbound_non_ascii(self):
        xml = (XML_TEMPLATE % {
            'ussd_type': 'Request',
            'sid': 'sess1234',
            'network_sid': "netsid12345",
            'msisdn': '27345',
            'connstr': '*120*99#',
            'text': u'öæł',
            }).encode("utf-8")
        yield http_request(self.transport_url + "foo", xml, method='GET')
        [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1)
        self.assertEqual(msg['content'], u'öæł')

    @inlineCallbacks
    def test_nack(self):
        self.higate_response = """
            
                
                    
                    
                
            """.strip()

        msg = yield self.tx_helper.make_dispatch_outbound(
            "hi", transport_metadata={'session_id': "sess123"})
        yield self.integrat_calls.get()
        [nack] = yield self.tx_helper.wait_for_dispatched_events(1)
        self.assertEqual(nack['user_message_id'], msg['message_id'])
        self.assertEqual(nack['sent_message_id'], msg['message_id'])
        self.assertEqual(nack['nack_reason'],
                         'error_code: -1, reason: Expecting POST, not GET')
vumi/transports/integrat/tests/__init__.py

vumi/transports/integrat/tests/test_utils.py
import re
from xml.etree import ElementTree

from twisted.python import log

from vumi.transports.integrat.utils import HigateXMLParser
from vumi.tests.helpers import VumiTestCase


class TestHigateXML(VumiTestCase):
    '''
    Tests for the Sample XML found at:
    http://www.integrat.co.za/wiki/index.php/Sample_xml
    '''

    def setUp(self):
        self.dolog = True
        self.hxp = HigateXMLParser()

    def testParseOnResult(self):
        OnResult_xml = '''
        <Message>
            <Version Version="1.0"/>
            <Response Type="OnResult" TOC="SMS" RefNo="2313344"
             SeqNo="8199250" Flags="0">
                <SystemID>Higate</SystemID>
                <UserID>Http001</UserID>
                <Service>HC001</Service>
                <NetworkID>1</NetworkID>
                <OnResult Code="3" SubCode="0" Text="Acknowledged"/>
            </Response>
        </Message>
        '''
        OnResult_dict = {
            'Code': '3',
            'Text': 'Acknowledged',
            'SubCode': '0',
            'RefNo': '2313344',
            'Flags': '0',
            'SeqNo': '8199250',
            'TOC': 'SMS',
            'Type': 'OnResult',
            }
        if self.dolog:
            log.msg("OnResult -> %s" % (repr(self.hxp.parse(OnResult_xml))))
        self.assertEquals(self.hxp.parse(OnResult_xml), OnResult_dict)

    def testParseSendSMS_Linked(self):
        SendSMS_xml = '''
        
            
                George
                xxxxxx
                
                    
                    0
                    
                    
                    Test message from Higate Http client
                    
                
         
        
        '''
        SendSMS_dict = {'RefNo': '1', 'Type': 'SendSMS'}
        if self.dolog:
            log.msg("SendSMS -> %s" % (repr(self.hxp.parse(SendSMS_xml))))
        self.assertEquals(self.hxp.parse(SendSMS_xml), SendSMS_dict)

    def testParseSendSMS(self):
        SendSMS_xml = '''
        
            
                George
                xxxxxx
                
                    
                    0
                    Test message from Higate Http client
                    
                
         
        
        '''
        SendSMS_dict = {'RefNo': '1', 'Type': 'SendSMS'}
        if self.dolog:
            log.msg("SendSMS -> %s" % (repr(self.hxp.parse(SendSMS_xml))))
        self.assertEquals(self.hxp.parse(SendSMS_xml), SendSMS_dict)

    def testParseOnReceiveSMS(self):
        OnReceiveSMS_xml = '''
        <Message>
            <Version Version="1.0"/>
            <Response Type="OnReceiveSMS" SeqNo="576674646" NetworkID="1">
                <SystemID>Higate</SystemID>
                <UserID>Client1</UserID>
                <Service>SRC0123</Service>
                <OnReceiveSMS FromAddr="27829023456" ToAddr="27829020203777"
                 ToTag="777" Value="0" AdultRating="0" EsmClass="128"
                 DataCoding="8" Sent="20100614135709">
                    <Content Type="HEX">06052677F6A565 ...etc</Content>
                </OnReceiveSMS>
            </Response>
        </Message>
        '''
        OnReceiveSMS_dict = {
            'NetworkID': '1',
            'FromAddr': '27829023456',
            'SeqNo': '576674646',
            'AdultRating': '0',
            'hex': '06052677F6A565 ...etc',
            'Value': '0',
            'ToTag': '777',
            'ToAddr': '27829020203777',
            'EsmClass': '128',
            'DataCoding': '8',
            'Type': 'OnReceiveSMS',
            'Sent': '20100614135709',
            }
        if self.dolog:
            log.msg("OnReceiveSMS -> %r" % (self.hxp.parse(OnReceiveSMS_xml),))
        self.assertEquals(self.hxp.parse(OnReceiveSMS_xml), OnReceiveSMS_dict)

    def testParseOnOBSResponse(self):
        OnOBSResponse_xml = '''
        
            
            
                Higate
                LoginName
                SERVICECODE
                2
                32
                6
                An exception occured in : setErrorVaribles :
                  ControlException on control eventChargeValidation[ORA-0
                
                
            
        
        '''
        OnOBSResponse_dict = {
            'RefNo': '123',
            'SeqNo': '1234568',
            'Type': 'OnOBSResponse',
            }
        if self.dolog:
            log.msg("OnOBSResponse -> %r" % (
                    self.hxp.parse(OnOBSResponse_xml),))
        self.assertEquals(self.hxp.parse(OnOBSResponse_xml),
                          OnOBSResponse_dict)

    def testParseOnLBSResponse(self):
        OnLBSResponse_xml = '''
        
            
            
                Higate
                LoginName
                SERVICECODE
                1
                4096
                4
                Receipted
                
                    
                        
                        
                        
                            
                                -25955564
                                28133442
                                High
                                2009-01-27T13:17:28.000Z
                                0
                                Vod:0:0
                                0
                            
                        
                    
                
            
        
        '''
        OnLBSResponse_dict = {
            'RefNo': '123',
            'SeqNo': '548245219',
            'Type': 'OnLBSResponse',
            }
        if self.dolog:
            log.msg("OnLBSResponse -> %r" % (
                    self.hxp.parse(OnLBSResponse_xml),))
        self.assertEquals(self.hxp.parse(OnLBSResponse_xml),
                          OnLBSResponse_dict)

    def testParseOnUSSEventRequest(self):
        OnUSSEvent_xml = '''
        <Message>
            <Version Version="1.0"/>
            <Response Type="OnUSSEvent">
                <SystemID>Higate</SystemID>
                <UserID>LoginName</UserID>
                <Service>SERVICECODE</Service>
                <Network ID="1"/>
                <OnUSSEvent Type="Request">
                    <USSContext SessionID="16502" NetworkSID="310941653"
                     MSISDN="27821234567" Script="testscript"
                     ConnStr="*120*99*123#"/>
                    <USSText Type="TEXT">REQ</USSText>
                </OnUSSEvent>
            </Response>
        </Message>
        '''

        OnUSSEvent_dict = {'ConnStr': '*120*99*123#',
                         'MSISDN': '27821234567',
                         'NetworkSID': '310941653',
                         'Script': 'testscript',
                         'SessionID': '16502',
                         'Type': 'OnUSSEvent',
                         'USSText': 'REQ',
                         'EventType': 'Request'}
        if self.dolog:
            log.msg("OnUSSEvent -> %r" % (self.hxp.parse(OnUSSEvent_xml),))
        self.assertEquals(self.hxp.parse(OnUSSEvent_xml), OnUSSEvent_dict)

    def testParseOnUSSEventRequestOpen(self):
        OnUSSEvent_xml = '''
        <Message>
            <Version Version="1.0"/>
            <Response Type="OnUSSEvent">
                <SystemID>Higate</SystemID>
                <UserID>LoginName</UserID>
                <Service>SERVICECODE</Service>
                <Network ID="1"/>
                <OnUSSEvent Type="Open">
                    <USSContext SessionID="16502" NetworkSID="310941653"
                     MSISDN="27821234567" Script="testscript"
                     ConnStr="*120*99*123#"/>
                </OnUSSEvent>
            </Response>
        </Message>
        '''
        OnUSSEvent_dict = {'ConnStr': '*120*99*123#',
                         'MSISDN': '27821234567',
                         'NetworkSID': '310941653',
                         'Script': 'testscript',
                         'SessionID': '16502',
                         'Type': 'OnUSSEvent',
                         'EventType': 'Open'}
        if self.dolog:
            log.msg("OnUSSEvent -> %r" % (self.hxp.parse(OnUSSEvent_xml),))
        self.assertEquals(self.hxp.parse(OnUSSEvent_xml), OnUSSEvent_dict)

    def testParseOnUSSEventRequestClose(self):
        OnUSSEvent_xml = '''
        <Message>
            <Version Version="1.0"/>
            <Response Type="OnUSSEvent">
                <SystemID>Higate</SystemID>
                <UserID>LoginName</UserID>
                <Service>LoginName</Service>
                <Network ID="1"/>
                <OnUSSEvent Type="Close">
                    <USSContext SessionID="16502" NetworkSID="310941653"
                     MSISDN="27821234567" Script="testscript"
                     ConnStr="*120*99*123#"/>
                </OnUSSEvent>
            </Response>
        </Message>
        '''
        OnUSSEvent_dict = {'ConnStr': '*120*99*123#',
                         'MSISDN': '27821234567',
                         'NetworkSID': '310941653',
                         'Script': 'testscript',
                         'SessionID': '16502',
                         'Type': 'OnUSSEvent',
                         'EventType': 'Close'}
        if self.dolog:
            log.msg("OnUSSEvent -> %r" % (self.hxp.parse(OnUSSEvent_xml),))
        self.assertEquals(self.hxp.parse(OnUSSEvent_xml), OnUSSEvent_dict)

    def testParseUSSReply(self):
        USSReply_xml = '''
        <Message>
        <Version Version="1.0"/>
         <Request Type="USSReply" SessionID="223665" Flags="0">
               <UserID>LoginName</UserID>
               <Password>xxxxxxxx</Password>
               <USSText Type="TEXT">Welcome the this USSD session</USSText>
         </Request>
        </Message>
        '''
        USSReply_dict = {'Flags': '0',
                        'Password': 'xxxxxxxx',
                        'SessionID': '223665',
                        'Type': 'USSReply',
                        'USSText': 'Welcome the this USSD session',
                        'UserID': 'LoginName'}
        if self.dolog:
            log.msg("USSReply -> %s" % (repr(self.hxp.parse(USSReply_xml))))
        self.assertEquals(self.hxp.parse(USSReply_xml), USSReply_dict)

    def testBuildUSSReply(self):
        USSReply_dict = {'Flags': '0',
                        'Password': 'xxxxxxxx',
                        'SessionID': '223665',
                        'Type': 'USSReply',
                        'USSText': 'Welcome the this USSD session',
                        'UserID': 'LoginName'}
        USSReply_xml = '''
        <Message>
        <Version Version="1.0"/>
         <Request Flags="0" SessionID="223665" Type="USSReply">
               <UserID Orientation="TR">LoginName</UserID>
               <Password>xxxxxxxx</Password>
               <USSText Type="TEXT">Welcome the this USSD session</USSText>
         </Request>
        </Message>
        '''
        # make xml string formatting compact & consistent
        USSReply_xml = ElementTree.tostring(
            ElementTree.fromstring(USSReply_xml))
        USSReply_xml = re.sub(r'\n\s*', '', USSReply_xml)
        if self.dolog:
            log.msg("USSReply -> %s" % (repr(self.hxp.build(USSReply_dict))))
        self.assertEquals(self.hxp.build(USSReply_dict), USSReply_xml)
vumi/transports/mtn_nigeria/__init__.py
"""
MTN Nigeria USSD transport.
"""

from vumi.transports.mtn_nigeria.mtn_nigeria_ussd import (
    MtnNigeriaUssdTransport)
from vumi.transports.mtn_nigeria.xml_over_tcp import XmlOverTcpClient


__all__ = ['MtnNigeriaUssdTransport', 'XmlOverTcpClient']
vumi/transports/mtn_nigeria/mtn_nigeria_ussd.py
from twisted.internet import reactor
from twisted.internet.defer import inlineCallbacks
from twisted.internet.protocol import ReconnectingClientFactory

from vumi import log
from vumi.transports.base import Transport
from vumi.message import TransportUserMessage
from vumi.config import ConfigInt, ConfigText, ConfigDict
from vumi.components.session import SessionManager
from vumi.transports.mtn_nigeria.xml_over_tcp import (
    XmlOverTcpError, CodedXmlOverTcpError, XmlOverTcpClient)


class MtnNigeriaUssdTransportConfig(Transport.CONFIG_CLASS):
    """MTN Nigeria USSD transport configuration."""

    server_hostname = ConfigText(
        "Hostname of the server the transport's client should connect to.",
        required=True, static=True)
    server_port = ConfigInt(
        "Port that the server is listening on.",
        required=True, static=True)
    username = ConfigText(
        "The username for this transport.",
        required=True, static=True)
    password = ConfigText(
        "The password for this transport.",
        required=True, static=True)
    application_id = ConfigText(
        "An application ID required by MTN Nigeria for client authentication.",
        required=True, static=True)
    enquire_link_interval = ConfigInt(
        "The interval (in seconds) between enquire links sent to the server "
        "to check whether the connection is alive and well.",
        default=30, static=True)
    timeout_period = ConfigInt(
        "How long (in seconds) after sending an enquire link request the "
        "client should wait for a response before timing out. NOTE: The "
        "timeout period should not be longer than the enquire link interval.",
        default=30, static=True)
    user_termination_response = ConfigText(
        "Response given back to the user if the user terminated the session.",
        default='Session Ended', static=True)
    redis_manager = ConfigDict(
        "Parameters to connect to Redis with",
        default={}, static=True)
    session_timeout_period = ConfigInt(
        "Max length (in seconds) of a USSD session",
        default=600, static=True)
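
# A minimal illustrative YAML config for this transport (values are examples
# only; the keys mirror the config fields above plus the base Transport
# options such as transport_name):
#
#   transport_name: mtn_nigeria_ussd
#   server_hostname: 127.0.0.1
#   server_port: 2775
#   username: root
#   password: toor
#   application_id: '1029384756'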


class MtnNigeriaUssdTransport(Transport):
    """
    USSD transport for MTN Nigeria.

    This transport connects as a TCP client and sends messages using a
    custom protocol whose packets consist of binary headers plus an XML body.
    """

    transport_type = 'ussd'

    CONFIG_CLASS = MtnNigeriaUssdTransportConfig

    # The encoding we use internally
    ENCODING = 'UTF-8'

    REQUIRED_METADATA_FIELDS = set(['session_id', 'clientId'])

    @inlineCallbacks
    def setup_transport(self):
        config = self.get_static_config()
        self.user_termination_response = config.user_termination_response

        r_prefix = "vumi.transports.mtn_nigeria:%s" % self.transport_name
        self.session_manager = yield SessionManager.from_redis_config(
            config.redis_manager, r_prefix,
            config.session_timeout_period)

        self.factory = MtnNigeriaUssdClientFactory(
            vumi_transport=self,
            username=config.username,
            password=config.password,
            application_id=config.application_id,
            enquire_link_interval=config.enquire_link_interval,
            timeout_period=config.timeout_period)
        self.client_connector = reactor.connectTCP(
            config.server_hostname, config.server_port, self.factory)
        log.msg('Connecting')

    def teardown_transport(self):
        if self.client_connector is not None:
            self.factory.stopTrying()
            self.client_connector.disconnect()

        return self.session_manager.stop()

    @staticmethod
    def pop_fields(params, *fields):
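        # Note: this returns a lazy generator, so fields are only actually
        # popped from `params` once the result is unpacked (as done in
        # handle_raw_inbound_message below).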
        return (params.pop(k, None) for k in fields)

    @staticmethod
    def determine_session_event(msg_type, end_of_session):
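        # Observed protocol semantics: msgtype '1' opens a new session,
        # msgtype '4' with EndofSession '0' is a user response within an
        # open session, and anything else closes the session (compare the
        # outbound msgtype values in XmlOverTcpClient.send_data_response).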
        if msg_type == '1':
            return TransportUserMessage.SESSION_NEW
        if end_of_session == '0' and msg_type == '4':
            return TransportUserMessage.SESSION_RESUME
        return TransportUserMessage.SESSION_CLOSE

    @inlineCallbacks
    def handle_raw_inbound_message(self, session_id, params):
        # ensure the params are in the encoding we use internally
        params['session_id'] = session_id
        params = dict((k, v.decode(self.ENCODING))
                      for k, v in params.iteritems())

        session_event = self.determine_session_event(
            *self.pop_fields(params, 'msgtype', 'EndofSession'))

        # For the first message of a session, the 'userdata' field is the
        # USSD code. For subsequent messages, 'userdata' is the user's
        # content. We need to keep track of the USSD code we get in the
        # first session message so we can link the correct `to_addr` to
        # subsequent messages.
        if session_event == TransportUserMessage.SESSION_NEW:
            # Set the content to None if this is the start of the session.
            # Prevents this inbound message being mistaken as a user message.
            content = None

            to_addr = params.pop('userdata')
            session = yield self.session_manager.create_session(
                session_id, ussd_code=to_addr)
        else:
            session = yield self.session_manager.load_session(session_id)
            to_addr = session['ussd_code']
            content = params.pop('userdata')

        # pop the remaining needed field (the rest is left as metadata)
        [from_addr] = self.pop_fields(params, 'msisdn')

        log.msg('MtnNigeriaUssdTransport receiving inbound message from %s '
                'to %s: %s' % (from_addr, to_addr, content))

        if session_event == TransportUserMessage.SESSION_CLOSE:
            self.factory.client.send_data_response(
                session_id=session_id,
                request_id=params['requestId'],
                star_code=params['starCode'],
                client_id=params['clientId'],
                msisdn=from_addr,
                user_data=self.user_termination_response,
                end_session=True)

        yield self.publish_message(
            content=content,
            to_addr=to_addr,
            from_addr=from_addr,
            provider='mtn_nigeria',
            session_event=session_event,
            transport_type=self.transport_type,
            transport_metadata={'mtn_nigeria_ussd': params})

    def send_response(self, message_id, **client_args):
        try:
            self.factory.client.send_data_response(**client_args)
        except XmlOverTcpError as e:
            return self.publish_nack(message_id, "Response failed: %s" % e)

        return self.publish_ack(user_message_id=message_id,
                                sent_message_id=message_id)

    def validate_outbound_message(self, message):
        metadata = message['transport_metadata']['mtn_nigeria_ussd']
        missing_fields = (self.REQUIRED_METADATA_FIELDS - set(metadata.keys()))
        if missing_fields:
            raise CodedXmlOverTcpError(
                '208',
                "Required message transport metadata fields missing in "
                "outbound message: %s" % list(missing_fields))

    @inlineCallbacks
    def handle_outbound_message(self, message):
        metadata = message['transport_metadata']['mtn_nigeria_ussd']

        try:
            self.validate_outbound_message(message)
        except CodedXmlOverTcpError as e:
            log.msg(e)
            yield self.publish_nack(message['message_id'], "%s" % e)
            yield self.factory.client.send_error_response(
                metadata.get('session_id'),
                message.payload.get('in_reply_to'),
                e.code)
            return

        log.msg(
            'MtnNigeriaUssdTransport sending outbound message: %s' % message)

        end_session = (
            message['session_event'] == TransportUserMessage.SESSION_CLOSE)
        yield self.send_response(
            message_id=message['message_id'],
            session_id=metadata['session_id'],
            request_id=metadata['requestId'],
            star_code=metadata['starCode'],
            client_id=metadata['clientId'],
            msisdn=message['to_addr'],
            user_data=message['content'].encode(self.ENCODING),
            end_session=end_session)


class MtnNigeriaUssdClient(XmlOverTcpClient):
    def __init__(self, vumi_transport, **kwargs):
        XmlOverTcpClient.__init__(self, **kwargs)
        self.vumi_transport = vumi_transport

    def connectionMade(self):
        XmlOverTcpClient.connectionMade(self)
        self.factory.resetDelay()

    def data_request_received(self, session_id, params):
        return self.vumi_transport.handle_raw_inbound_message(
            session_id, params)


class MtnNigeriaUssdClientFactory(ReconnectingClientFactory):
    protocol = MtnNigeriaUssdClient

    def __init__(self, **client_args):
        self.client_args = client_args
        self.client = None

    def buildProtocol(self, addr):
        client = self.protocol(**self.client_args)
        client.factory = self
        self.client = client
        return client
vumi/transports/mtn_nigeria/xml_over_tcp.py
import random
import struct

from twisted.web import microdom
from twisted.internet import reactor
from twisted.internet.task import LoopingCall
from twisted.internet.protocol import Protocol

from vumi import log


class XmlOverTcpError(Exception):
    """
    Raised when an error occurs while interacting with the XmlOverTcp protocol.
    """


class CodedXmlOverTcpError(XmlOverTcpError):
    """
    Raised when an XmlOverTcpError occurs and an error code is available
    """

    ERRORS = {
        '001': 'Invalid User Name Password',
        '002': 'Buffer Overflow',
        '200': 'No free dialogs',
        '201': 'Invalid Destination (applies for n/w initiated session only)',
        '202': 'Subscriber Not reachable.',
        '203': ('Timer Expiry (session with subscriber terminated due to '
                'TimerExp)'),
        '204': 'Subscriber is Black Listed.',
        '205': ('Service not Configured. (some service is created but no '
                'menu configured for this)'),
        '206': 'Network Error',
        '207': 'Unknown Error',
        '208': 'Invalid Message',
        '209': 'Subscriber terminated Session (subscriber chose exit option)',
        '210': 'Incomplete Menu',
        '211': 'ER not running',
        '212': 'Timeout waiting for response from ER',
    }

    def __init__(self, code, reason=None):
        self.code = code
        self.msg = self.ERRORS.get(code, 'Unknown Code')
        self.reason = reason

    def __str__(self):
        return '(%s) %s%s' % (
            self.code,
            self.msg,
            ': %s' % self.reason if self.reason else '')
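
    # For example (using the error codes above):
    #   str(CodedXmlOverTcpError('208', 'bad field'))
    #   == '(208) Invalid Message: bad field'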


class XmlOverTcpClient(Protocol):
    SESSION_ID_HEADER_SIZE = 16
    LENGTH_HEADER_SIZE = 16
    HEADER_SIZE = SESSION_ID_HEADER_SIZE + LENGTH_HEADER_SIZE
    HEADER_FORMAT = '!%ss%ss' % (SESSION_ID_HEADER_SIZE, LENGTH_HEADER_SIZE)
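
    # Wire format (a sketch derived from serialize_header and
    # deserialize_header below): each packet starts with a 32-byte header --
    # a 16-byte session id and a 16-byte ASCII length, both right-padded
    # with null bytes -- followed by a latin1-encoded XML body. The length
    # field counts the whole packet, header included. For example, session
    # id '23' with a 6-byte body gives the header:
    #   '23' + '\0' * 14 + '38' + '\0' * 14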
    SESSION_ID_CHARACTERS = "0123456789"

    REQUEST_ID_LENGTH = 10

    PACKET_RECEIVED_HANDLERS = {
        'USSDRequest': 'handle_data_request',
        'USSDResponse': 'handle_data_response',
        'AUTHResponse': 'handle_login_response',
        'AUTHError': 'handle_login_error_response',
        'ENQRequest': 'handle_enquire_link_request',
        'ENQResponse': 'handle_enquire_link_response',
        'USSDError': 'handle_error_response',
    }

    # packet types which don't need the client to be authenticated
    IGNORE_AUTH_PACKETS = [
        'AUTHResponse', 'AUTHError', 'AUTHRequest', 'USSDError']

    # received packet fields
    DATA_REQUEST_FIELDS = set([
        'requestId', 'msisdn', 'clientId', 'starCode', 'msgtype', 'phase',
        'dcs', 'userdata'])
    OTHER_DATA_REQUEST_FIELDS = set(['EndofSession'])
    LOGIN_RESPONSE_FIELDS = set(['requestId', 'authMsg'])
    LOGIN_ERROR_FIELDS = set(['requestId', 'authMsg', 'errorCode'])
    OTHER_LOGIN_ERROR_FIELDS = set(['errorMsg'])
    ENQUIRE_LINK_FIELDS = set(['requestId', 'enqCmd'])
    ERROR_FIELDS = set(['requestId', 'errorCode'])
    OTHER_ERROR_FIELDS = set(['errorMsg'])

    # Data requests and responses need to include a 'dcs' (data coding scheme)
    # field. '15' is used for ASCII, and is the default. The documentation
    # does not offer any other codes.
    DATA_CODING_SCHEME = '15'

    # By observation, it appears that latin1 is the protocol's encoding
    ENCODING = 'latin1'

    # Data requests and responses need to include a 'phase' field. The
    # documentation does not provide any information about 'phase', but we are
    # assuming this refers to the USSD phase. This should be set to 2 for
    # interactive two-way communication.
    PHASE = '2'

    def __init__(self, username, password, application_id,
                 enquire_link_interval=30, timeout_period=30):
        self.username = username
        self.password = password
        self.application_id = application_id
        self.enquire_link_interval = enquire_link_interval
        self.timeout_period = timeout_period

        self.clock = reactor
        self.authenticated = False
        self.scheduled_timeout = None
        self.periodic_enquire_link = LoopingCall(
            self.send_enquire_link_request)

        self.reset_buffer()

    def connectionMade(self):
        self.login()

    def connectionLost(self, reason):
        log.msg("Connection lost")
        self.stop_periodic_enquire_link()
        self.cancel_scheduled_timeout()
        self.reset_buffer()

    def reset_buffer(self):
        self._buffer = ''
        self._current_header = None

    def timeout(self):
        log.msg("No enquire link response received after %s seconds, "
                "disconnecting" % self.timeout_period)
        self.disconnect()

    def disconnect(self):
        """For easier test stubbing."""
        self.transport.loseConnection()

    def cancel_scheduled_timeout(self):
        if (self.scheduled_timeout is not None
                and self.scheduled_timeout.active()):
            self.scheduled_timeout.cancel()

    def reset_scheduled_timeout(self):
        self.cancel_scheduled_timeout()

        # cap the timeout period at the enquire link interval
        delay = min(self.timeout_period, self.enquire_link_interval)
        self.scheduled_timeout = self.clock.callLater(delay, self.timeout)

    def start_periodic_enquire_link(self):
        if not self.authenticated:
            log.msg("Heartbeat could not be started, client not authenticated")
            return

        self.periodic_enquire_link.clock = self.clock
        d = self.periodic_enquire_link.start(
            self.enquire_link_interval, now=True)
        log.msg("Heartbeat started")

        return d

    def stop_periodic_enquire_link(self):
        self.cancel_scheduled_timeout()
        if self.periodic_enquire_link.running:
            self.periodic_enquire_link.stop()
        log.msg("Heartbeat stopped")

    def dataReceived(self, data):
        self._buffer += data
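        # Packets may arrive split across reads or concatenated together, so
        # buffer everything and only consume a packet once its full length
        # (taken from its header) is available; peek_buffer/pop_buffer
        # return None while the data is still incomplete.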

        while self._buffer:
            header = self.peek_buffer(self.HEADER_SIZE)

            if not header:
                return

            session_id, length = self.deserialize_header(header)
            packet = self.pop_buffer(length)

            if not packet:
                return

            body = packet[self.HEADER_SIZE:]

            try:
                packet_type, params = self.deserialize_body(body)
            except Exception as e:
                log.err("Error parsing packet (%s): %r" % (e, packet))
                self.disconnect()
                return

            self.packet_received(session_id, packet_type, params)

    def pop_buffer(self, n):
        if n > len(self._buffer):
            return None

        buffer = self._buffer
        self._buffer = buffer[n:]
        return buffer[:n]

    def peek_buffer(self, n):
        if n > len(self._buffer):
            return None

        return self._buffer[:n]

    @classmethod
    def remove_nullbytes(cls, s):
        return s.replace('\0', '')

    @classmethod
    def deserialize_header(cls, header):
        session_id, length = struct.unpack(cls.HEADER_FORMAT, header)

        # The headers appear to be padded with trailing nullbytes, so we need
        # to remove these before doing any other parsing
        return (cls.remove_nullbytes(session_id),
                int(cls.remove_nullbytes(length)))

    @staticmethod
    def _xml_node_text(node):
        result = ''

        for child in node.childNodes:
            if isinstance(child, microdom.CharacterData):
                result += child.value
            elif isinstance(child, microdom.EntityReference):
                result += microdom.unescape(
                    child.toxml(), chars=microdom.XML_ESCAPE_CHARS)

        return result.strip()

    @classmethod
    def deserialize_body(cls, body):
        document = microdom.parseXMLString(body.decode(cls.ENCODING))
        root = document.firstChild()

        params = dict(
            (node.nodeName, cls._xml_node_text(node))
            for node in root.childNodes)

        return root.nodeName, params

    def packet_received(self, session_id, packet_type, params):
        log.debug("Packet of type '%s' with session id '%s' received: %s"
                  % (packet_type, session_id, params))

        # dispatch the packet to the appropriate handler
        handler_name = self.PACKET_RECEIVED_HANDLERS.get(packet_type, None)
        if handler_name is None:
            log.err("Packet of an unknown type received: %s" % packet_type)
            return self.send_error_response(
                session_id, params.get('requestId'), '208')

        if (not self.authenticated and
                packet_type not in self.IGNORE_AUTH_PACKETS):
            log.err("'%s' packet received before client authentication "
                    "was completed" % packet_type)
            return self.send_error_response(
                session_id, params.get('requestId'), '207')

        getattr(self, handler_name)(session_id, params)

    def validate_packet_fields(self, params, mandatory_fields,
                               other_fields=frozenset()):
        packet_fields = set(params.keys())

        all_fields = mandatory_fields | other_fields
        unexpected_fields = packet_fields - all_fields
        if unexpected_fields:
            raise CodedXmlOverTcpError(
                '208',
                "Unexpected fields in received packet: %s"
                % sorted(unexpected_fields))

        missing_mandatory_fields = mandatory_fields - packet_fields
        if missing_mandatory_fields:
            raise CodedXmlOverTcpError(
                '208',
                "Missing mandatory fields in received packet: %s"
                % sorted(missing_mandatory_fields))

    def handle_error(self, session_id, request_id, e):
        log.err(e)
        self.send_error_response(session_id, request_id, e.code)

    def handle_login_response(self, session_id, params):
        try:
            self.validate_packet_fields(params, self.LOGIN_RESPONSE_FIELDS)
        except CodedXmlOverTcpError as e:
            self.disconnect()
            self.handle_error(session_id, params.get('requestId'), e)
            return

        log.msg("Client authentication complete.")
        self.authenticated = True
        self.start_periodic_enquire_link()

    def handle_login_error_response(self, session_id, params):
        try:
            self.validate_packet_fields(
                params, self.LOGIN_ERROR_FIELDS, self.OTHER_LOGIN_ERROR_FIELDS)
        except CodedXmlOverTcpError as e:
            self.handle_error(session_id, params.get('requestId'), e)
            return

        log.err("Login failed, disconnecting")
        self.disconnect()

    def handle_error_response(self, session_id, params):
        try:
            self.validate_packet_fields(
                params, self.ERROR_FIELDS, self.OTHER_ERROR_FIELDS)
        except CodedXmlOverTcpError as e:
            self.handle_error(session_id, params.get('requestId'), e)
            return

        log.err(
            "Server sent error message: %s" %
            CodedXmlOverTcpError(params['errorCode'], params.get('errorMsg')))

    def handle_data_request(self, session_id, params):

        try:
            self.validate_packet_fields(
                params,
                self.DATA_REQUEST_FIELDS,
                self.OTHER_DATA_REQUEST_FIELDS)
        except CodedXmlOverTcpError as e:
            self.handle_error(session_id, params.get('requestId'), e)
            return

        # if EndofSession is not in params, assume the end of session
        params.setdefault('EndofSession', '1')
        self.data_request_received(session_id, params)

    def data_request_received(self, session_id, params):
        raise NotImplementedError("Subclasses should implement.")

    def handle_data_response(self, session_id, params):
        # We seem to get these if we reply to a session that has already been
        # closed.

        try:
            self.validate_packet_fields(
                params,
                self.DATA_REQUEST_FIELDS,
                self.OTHER_DATA_REQUEST_FIELDS)
        except CodedXmlOverTcpError as e:
            self.handle_error(session_id, params.get('requestId'), e)
            return

        # if EndofSession is not in params, assume the end of session
        params.setdefault('EndofSession', '1')
        self.data_response_received(session_id, params)

    def data_response_received(self, session_id, params):
        log.msg("Received spurious USSDResponse message, ignoring.")

    @classmethod
    def serialize_header_field(cls, header, header_size):
        return str(header).ljust(header_size, '\0')

    @classmethod
    def serialize_header(cls, session_id, body):
        length = len(body) + cls.HEADER_SIZE
        return struct.pack(
            cls.HEADER_FORMAT,
            cls.serialize_header_field(session_id, cls.SESSION_ID_HEADER_SIZE),
            cls.serialize_header_field(length, cls.LENGTH_HEADER_SIZE))

    @classmethod
    def serialize_body(cls, packet_type, params):
        root = microdom.Element(packet_type.encode('utf8'), preserveCase=True)

        for name, value in params:
            el = microdom.Element(name.encode('utf8'), preserveCase=True)
            el.appendChild(microdom.Text(value.encode('utf8')))
            root.appendChild(el)

        data = root.toxml()
        return data.decode('utf8').encode(cls.ENCODING, 'xmlcharrefreplace')

    @classmethod
    def serialize_packet(cls, session_id, packet_type, params):
        body = cls.serialize_body(packet_type, params)
        return cls.serialize_header(session_id, body) + body

    def send_packet(self, session_id, packet_type, params):
        if (not self.authenticated
                and packet_type not in self.IGNORE_AUTH_PACKETS):
            raise XmlOverTcpError(
                "'%s' packet could not be sent, client not authenticated"
                % packet_type)

        packet = self.serialize_packet(session_id, packet_type, params)
        log.debug("Sending packet: %s" % packet)
        self.transport.write(packet)

    @classmethod
    def gen_session_id(cls):
        """
        Generates session id. Used for packets needing a dummy session id.
        """
        return "".join(
            random.choice(cls.SESSION_ID_CHARACTERS)
            for i in range(cls.SESSION_ID_HEADER_SIZE))

    @classmethod
    def gen_request_id(cls):
        # NOTE: The protocol requires request ids to be numeric-only ids.
        # With a request id length of 10 digits there are only 10**10
        # possible ids, so generating them with randint makes collisions
        # possible, although (by the birthday bound) they only become likely
        # after roughly 10**5 ids have been issued.
        return str(random.randint(0, (10 ** cls.REQUEST_ID_LENGTH) - 1))

    def login(self):
        params = [
            ('requestId', self.gen_request_id()),
            ('userName', self.username),
            ('passWord', self.password),  # plaintext passwords, yay :/
            ('applicationId', self.application_id),
        ]
        self.send_packet(self.gen_session_id(), 'AUTHRequest', params)
        log.msg('Logging in')

    def send_error_response(self, session_id=None, request_id=None,
                            code='207'):
        params = [
            ('requestId', request_id or self.gen_request_id()),
            ('errorCode', code),
        ]
        self.send_packet(
            session_id or self.gen_session_id(), 'USSDError', params)

    def send_data_response(self, session_id, request_id, client_id, msisdn,
                           user_data, star_code, end_session=True):
        if end_session:
            msg_type = '6'
            end_of_session = '1'
        else:
            msg_type = '2'
            end_of_session = '0'

        # XXX: delivery reports can be given for the delivery of the last
        # message in a session. However, the documentation does not provide any
        # information on how delivery report packets look, so this is currently
        # disabled ('delvrpt' is set to '0' below).

        packet_params = [
            ('requestId', request_id),
            ('msisdn', msisdn),
            ('starCode', star_code),
            ('clientId', client_id),
            ('phase', self.PHASE),
            ('msgtype', msg_type),
            ('dcs', self.DATA_CODING_SCHEME),
            ('userdata', user_data),
            ('EndofSession', end_of_session),
            ('delvrpt', '0'),
        ]

        self.send_packet(session_id, 'USSDResponse', packet_params)

    def handle_enquire_link_request(self, session_id, params):
        try:
            self.validate_packet_fields(params, self.ENQUIRE_LINK_FIELDS)
        except CodedXmlOverTcpError as e:
            self.handle_error(session_id, params.get('requestId'), e)
            return

        log.debug("Enquire link request received, sending response")
        self.send_enquire_link_response(session_id, params['requestId'])

    def send_enquire_link_request(self):
        log.debug("Sending enquire link request")
        self.send_packet(self.gen_session_id(), 'ENQRequest', [
            ('requestId', self.gen_request_id()),
            ('enqCmd', 'ENQUIRELINK')
        ])
        self.reset_scheduled_timeout()

    def handle_enquire_link_response(self, session_id, params):
        try:
            self.validate_packet_fields(params, self.ENQUIRE_LINK_FIELDS)
        except CodedXmlOverTcpError as e:
            self.handle_error(session_id, params.get('requestId'), e)
            return

        log.debug("Enquire link response received, sending next request in %s "
                  "seconds" % self.enquire_link_interval)
        self.cancel_scheduled_timeout()

    def send_enquire_link_response(self, session_id, request_id):
        self.send_packet(session_id, 'ENQResponse', [
            ('requestId', request_id),
            ('enqCmd', 'ENQUIRELINKRSP')
        ])
vumi/transports/mtn_nigeria/tests/test_xml_over_tcp.py
# -*- test-case-name: vumi.transports.mtn_nigeria.tests.test_xml_over_tcp -*-
# -*- coding: utf-8 -*-

import struct
from itertools import count

from twisted.internet.task import Clock
from twisted.internet.defer import inlineCallbacks, DeferredQueue

from vumi import log
from vumi.transports.mtn_nigeria.xml_over_tcp import (
    XmlOverTcpError, CodedXmlOverTcpError, XmlOverTcpClient)
from vumi.transports.mtn_nigeria.tests import utils
from vumi.tests.helpers import VumiTestCase


class ToyXmlOverTcpClient(XmlOverTcpClient):
    _PACKET_RECEIVED_HANDLERS = {'DummyPacket': 'dummy_packet_received'}

    def __init__(self):
        XmlOverTcpClient.__init__(self, 'root', 'toor', '1029384756')
        self.PACKET_RECEIVED_HANDLERS.update(self._PACKET_RECEIVED_HANDLERS)

        self.received_dummy_packets = []
        self.received_data_request_packets = []
        self.disconnected = False

        self.session_id_counter = count()
        self.generated_session_ids = []

        self.request_id_counter = count()
        self.generated_request_ids = []
        self.received_queue = DeferredQueue()

    def wait_for_data(self):
        return self.received_queue.get()

    def connectionMade(self):
        pass

    def dataReceived(self, data):
        XmlOverTcpClient.dataReceived(self, data)
        self.received_queue.put(data)

    def dummy_packet_received(self, session_id, params):
        self.received_dummy_packets.append((session_id, params))

    def data_request_received(self, session_id, params):
        self.received_data_request_packets.append((session_id, params))

    def disconnect(self):
        self.disconnected = True

    @classmethod
    def session_id_from_nr(cls, nr):
        return cls.serialize_header_field(nr, cls.SESSION_ID_HEADER_SIZE)

    def gen_session_id(self):
        return self.session_id_from_nr(next(self.session_id_counter))

    def gen_request_id(self):
        return str(next(self.request_id_counter))


class XmlOverTcpClientServerMixin(utils.MockClientServerMixin):
    client_protocol = ToyXmlOverTcpClient
    server_protocol = utils.MockXmlOverTcpServer


class TestXmlOverTcpClient(VumiTestCase, XmlOverTcpClientServerMixin):
    def setUp(self):
        errors = dict(CodedXmlOverTcpError.ERRORS)
        errors['000'] = 'Dummy error occurred'
        self.patch(CodedXmlOverTcpError, 'ERRORS', errors)

        self.logs = {'msg': [], 'err': [], 'debug': []}
        self.patch(log, 'msg', lambda *a: self.append_to_log('msg', *a))
        self.patch(log, 'err', lambda *a: self.append_to_log('err', *a))
        self.patch(log, 'debug', lambda *a: self.append_to_log('debug', *a))

        self.add_cleanup(self.stop_protocols)
        return self.start_protocols()

    def append_to_log(self, log_name, *args):
        self.logs[log_name].append(' '.join(str(a) for a in args))

    def assert_in_log(self, log_name, substr):
        log = self.logs[log_name]
        if not any(substr in m for m in log):
            self.fail("'%s' not in %s log" % (substr, log_name))

    @staticmethod
    def mk_raw_packet(session_id, length_header, body):
        header = struct.pack(
            XmlOverTcpClient.HEADER_FORMAT, session_id, length_header)
        return header + body

    @inlineCallbacks
    def test_packet_parsing_for_packets_with_weird_bodies(self):
        data = utils.mk_packet('0', "")
        self.client.authenticated = True
        self.server.send_data(data)

        yield self.client.wait_for_data()
        err_msg = self.logs['err'][0]
        self.assertTrue("Error parsing packet" in err_msg)
        self.assertTrue(('%r' % (data,)) in err_msg)
        self.assertTrue(self.client.disconnected)

    def test_packet_header_serializing(self):
        self.assertEqual(
            XmlOverTcpClient.serialize_header('23', 'abcdef'),
            '23\0\0\0\0\0\0\0\0\0\0\0\0\0\0'
            '38\0\0\0\0\0\0\0\0\0\0\0\0\0\0')

    def test_packet_header_deserializing(self):
        session_id, length = XmlOverTcpClient.deserialize_header(
            '23\0\0\0\0\0\0\0\0\0\0\0\0\0\0'
            '38\0\0\0\0\0\0\0\0\0\0\0\0\0\0')

        self.assertEqual(session_id, '23')
        self.assertEqual(length, 38)

    def test_packet_body_serializing(self):
        body = XmlOverTcpClient.serialize_body(
            'DummyPacket',
            [('requestId', '123456789abcdefg')])
        expected_body = (
            "<DummyPacket>"
            "<requestId>123456789abcdefg</requestId>"
            "</DummyPacket>")
        self.assertEqual(body, expected_body)

    def test_packet_body_serializing_for_non_latin1_chars(self):
        body = XmlOverTcpClient.serialize_body(
            'DummyPacket',
            [('requestId', '123456789abcdefg'),
             ('userdata', u'Erdős')])
        expected_body = (
            "<DummyPacket>"
            "<requestId>123456789abcdefg</requestId>"
            "<userdata>Erd&#337;s</userdata>"
            "</DummyPacket>")
        self.assertEqual(body, expected_body)

    def test_packet_body_deserializing(self):
        body = '\n'.join([
            "<USSDRequest>",
            "\t<requestId>",
            "\t\t123456789abcdefg",
            "\t</requestId>",
            "\t<msisdn>",
            "\t\t2347067123456",
            "\t</msisdn>",
            "\t<starCode>",
            "\t\t759",
            "\t</starCode>",
            "\t<clientId>",
            "\t\t441",
            "\t</clientId>",
            "\t<phase>",
            "\t\t2",
            "\t</phase>",
            "\t<dcs>",
            "\t\t15",
            "\t</dcs>",
            "\t<userdata>",
            "\t\t\xa4",
            "\t</userdata>",
            "\t<msgtype>",
            "\t\t4",
            "\t</msgtype>",
            "\t<EndofSession>",
            "\t\t0",
            "\t</EndofSession>",
            "</USSDRequest>\n"
        ])
        packet_type, params = XmlOverTcpClient.deserialize_body(body)

        self.assertEqual(packet_type, 'USSDRequest')
        self.assertEqual(params, {
            'requestId': '123456789abcdefg',
            'msisdn': '2347067123456',
            'userdata': u'\xa4',
            'clientId': '441',
            'dcs': '15',
            'msgtype': '4',
            'phase': '2',
            'starCode': '759',
            'EndofSession': '0',
        })

    def test_packet_body_deserializing_for_invalid_xml_chars(self):
        body = '\n'.join([
            '<USSDRequest>',
            '\t<requestId>',
            '\t\t123456789abcdefg',
            '\t</requestId>',
            '\t<msisdn>',
            '\t\t2341234567890',
            '\t</msisdn>',
            '\t<starCode>',
            '\t\t759',
            '\t</starCode>',
            '\t<clientId>',
            '\t\t441',
            '\t</clientId>',
            '\t<phase>',
            '\t\t2',
            '\t</phase>',
            '\t<dcs>',
            '\t\t229',
            '\t</dcs>',
            '\t<userdata>',
            '\t\t\x18',
            '\t</userdata>',
            '\t<msgtype>',
            '\t\t4',
            '\t</msgtype>',
            '\t<EndofSession>',
            '\t\t0',
            '\t</EndofSession>',
            '</USSDRequest>',
        ])
        packet_type, params = XmlOverTcpClient.deserialize_body(body)

        self.assertEqual(packet_type, 'USSDRequest')
        self.assertEqual(params, {
            'EndofSession': '0',
            'clientId': '441',
            'dcs': '229',
            'msgtype': '4',
            'msisdn': '2341234567890',
            'phase': '2',
            'requestId': '123456789abcdefg',
            'starCode': '759',
            'userdata': u'\x18',
        })

    def test_packet_body_deserializing_for_entity_references(self):
        body = '\n'.join([
            '<USSDRequest>',
            '\t<requestId>',
            '\t\t123456789abcdefg',
            '\t</requestId>',
            '\t<msisdn>',
            '\t\t2341234567890',
            '\t</msisdn>',
            '\t<starCode>',
            '\t\t759',
            '\t</starCode>',
            '\t<clientId>',
            '\t\t441',
            '\t</clientId>',
            '\t<phase>',
            '\t\t2',
            '\t</phase>',
            '\t<dcs>',
            '\t\t15',
            '\t</dcs>',
            '\t<userdata>',
            '\t\tTeam&apos;s rank',
            '\t</userdata>',
            '\t<msgtype>',
            '\t\t4',
            '\t</msgtype>',
            '\t<EndofSession>',
            '\t\t0',
            '\t</EndofSession>',
            '</USSDRequest>',
        ])
        packet_type, params = XmlOverTcpClient.deserialize_body(body)

        self.assertEqual(packet_type, 'USSDRequest')
        self.assertEqual(params, {
            'EndofSession': u'0',
            'clientId': u'441',
            'dcs': u'15',
            'msgtype': u'4',
            'msisdn': u'2341234567890',
            'phase': u'2',
            'requestId': u'123456789abcdefg',
            'starCode': u'759',
            'userdata': u"Team's rank"
        })

    @inlineCallbacks
    def test_contiguous_packets_received(self):
        body_a = "123"
        body_b = "456"

        data = utils.mk_packet('0', body_a)
        data += utils.mk_packet('1', body_b)
        self.client.authenticated = True
        self.server.send_data(data)

        yield self.client.wait_for_data()
        self.assertEqual(
            self.client.received_dummy_packets, [
                ('0', {'someParam': '123'}),
                ('1', {'someParam': '456'}),
            ])

    @inlineCallbacks
    def test_packets_split_over_socket_reads(self):
        body = "123"
        data = utils.mk_packet('0', body)
        split_position = int(len(data) / 2)

        self.client.authenticated = True

        self.server.send_data(data[:split_position])
        yield self.client.wait_for_data()

        self.server.send_data(data[split_position:])
        yield self.client.wait_for_data()

        self.assertEqual(
            self.client.received_dummy_packets,
            [('0', {'someParam': '123'})])

    @inlineCallbacks
    def test_partial_data_received(self):
        body_a = "123"
        body_b = "456"

        # add a full first packet, then concatenate a sliced version of a
        # second packet
        data = utils.mk_packet('0', body_a)
        data += utils.mk_packet('1', body_b)[:12]
        self.client.authenticated = True
        self.server.send_data(data)

        yield self.client.wait_for_data()
        self.assertEqual(
            self.client.received_dummy_packets,
            [('0', {'someParam': '123'})])

    @inlineCallbacks
    def test_authentication(self):
        request_body = (
            "<AUTHRequest>"
            "<requestId>0</requestId>"
            "<userName>root</userName>"
            "<passWord>toor</passWord>"
            "<applicationId>1029384756</applicationId>"
            "</AUTHRequest>")
        expected_request_packet = utils.mk_packet('0', request_body)

        response_body = (
            "<AUTHResponse>"
            "<requestId>0</requestId>"
            "<authMsg>SUCCESS</authMsg>"
            "</AUTHResponse>")
        response_packet = utils.mk_packet('a', response_body)
        self.server.responses[expected_request_packet] = response_packet

        self.client.login()
        yield self.client.wait_for_data()
        self.assertTrue(self.client.authenticated)

    @inlineCallbacks
    def test_authentication_error_handling(self):
        request_body = (
            "<AUTHRequest>"
            "<requestId>0</requestId>"
            "<userName>root</userName>"
            "<passWord>toor</passWord>"
            "<applicationId>1029384756</applicationId>"
            "</AUTHRequest>")
        expected_request_packet = utils.mk_packet('0', request_body)

        response_body = (
            "<AUTHError>"
            "<requestId>0</requestId>"
            "<authMsg>FAILURE</authMsg>"
            "<errorCode>001</errorCode>"
            "</AUTHError>")
        response_packet = utils.mk_packet('0', response_body)
        self.server.responses[expected_request_packet] = response_packet

        self.client.login()
        yield self.client.wait_for_data()
        self.assertFalse(self.client.authenticated)
        self.assertTrue(self.client.disconnected)
        self.assert_in_log('err', 'Login failed, disconnecting')

    @inlineCallbacks
    def test_unknown_packet_handling(self):
        request_body = (
            "<UnknownPacket><requestId>0</requestId></UnknownPacket>")
        request_packet = utils.mk_packet('0', request_body)

        response_body = (
            "<USSDError>"
            "<requestId>0</requestId>"
            "<errorCode>208</errorCode>"
            "</USSDError>")
        expected_response_packet = utils.mk_packet('0', response_body)

        self.server.send_data(request_packet)
        yield self.client.wait_for_data()

        response_packet = yield self.server.wait_for_data()
        self.assertEqual(expected_response_packet, response_packet)
        self.assert_in_log(
            'err', "Packet of an unknown type received: UnknownPacket")

    @inlineCallbacks
    def test_packet_received_before_auth(self):
        request_body = (
            "<DummyPacket>"
            "<requestId>0</requestId>"
            "</DummyPacket>")
        request_packet = utils.mk_packet('0', request_body)

        response_body = (
            "<USSDError>"
            "<requestId>0</requestId>"
            "<errorCode>207</errorCode>"
            "</USSDError>")
        expected_response_packet = utils.mk_packet('0', response_body)

        self.server.send_data(request_packet)
        yield self.client.wait_for_data()

        response_packet = yield self.server.wait_for_data()
        self.assertEqual(expected_response_packet, response_packet)
        self.assert_in_log(
            'err',
            "'DummyPacket' packet received before client authentication was "
            "completed")

    def test_packet_send_before_auth(self):
        self.assertRaises(
            XmlOverTcpError,
            self.client.send_packet, '0', 'DummyPacket', [])

    @inlineCallbacks
    def test_data_request_handling(self):
        body = (
            "<USSDRequest>"
            "<requestId>1291850641</requestId>"
            "<msisdn>27845335367</msisdn>"
            "<starCode>123</starCode>"
            "<clientId>123</clientId>"
            "<phase>2</phase>"
            "<dcs>15</dcs>"
            "<userdata>*123#</userdata>"
            "<msgtype>1</msgtype>"
            "<EndofSession>0</EndofSession>"
            "</USSDRequest>"
        )
        packet = utils.mk_packet('0', body)
        self.client.authenticated = True
        self.server.send_data(packet)

        yield self.client.wait_for_data()
        expected_params = {
            'requestId': '1291850641',
            'msisdn': '27845335367',
            'starCode': '123',
            'clientId': '123',
            'phase': '2',
            'dcs': '15',
            'userdata': '*123#',
            'msgtype': '1',
            'EndofSession': '0',
        }
        self.assertEqual(
            self.client.received_data_request_packets,
            [('0', expected_params)])

    @inlineCallbacks
    def test_data_response_handling(self):
        body = (
            "<USSDResponse>"
            "<requestId>1291850641</requestId>"
            "<msisdn>27845335367</msisdn>"
            "<starCode>123</starCode>"
            "<clientId>123</clientId>"
            "<phase>2</phase>"
            "<dcs>15</dcs>"
            "<userdata>Session closed due to cause 0</userdata>"
            "<msgtype>1</msgtype>"
            "<EndofSession>1</EndofSession>"
            "</USSDResponse>"
        )
        packet = utils.mk_packet('0', body)
        self.client.authenticated = True
        self.server.send_data(packet)

        yield self.client.wait_for_data()
        self.assert_in_log(
            'msg', "Received spurious USSDResponse message, ignoring.")

    def test_field_validation_for_valid_cases(self):
        self.client.validate_packet_fields(
            {'a': '1', 'b': '2'},
            set(['a', 'b']),
            set(['b', 'c']))

        self.client.validate_packet_fields(
            {'a': '1', 'b': '2'},
            set(['a', 'b']))

    def test_field_validation_for_missing_mandatory_fields(self):
        self.assertRaises(
            XmlOverTcpError,
            self.client.validate_packet_fields,
            {'requestId': '1291850641', 'a': '1', 'b': '2'},
            set(['requestId', 'a', 'b', 'c']))

    def test_field_validation_for_unexpected_fields(self):
        self.assertRaises(
            XmlOverTcpError,
            self.client.validate_packet_fields,
            {'requestId': '1291850641', 'a': '1', 'b': '2', 'd': '3'},
            set(['requestId', 'a', 'b']))

    @inlineCallbacks
    def test_login_response_validation(self):
        body = "0"
        bad_packet = utils.mk_packet('0', body)

        self.server.send_data(bad_packet)
        yield self.client.wait_for_data()
        self.assert_in_log(
            'err',
            "(208) Invalid Message: Missing mandatory fields in received "
            "packet: %s" % ['authMsg'])

        received_packet = yield self.server.wait_for_data()
        self.assertEqual(received_packet, utils.mk_packet(
            '0',
            "<USSDError>"
            "<requestId>0</requestId>"
            "<errorCode>208</errorCode>"
            "</USSDError>"))

    @inlineCallbacks
    def test_login_error_response_validation(self):
        bad_packet = utils.mk_packet(
            '0', "<AUTHError><requestId>0</requestId></AUTHError>")

        self.server.send_data(bad_packet)
        yield self.client.wait_for_data()
        self.assert_in_log(
            'err',
            "(208) Invalid Message: Missing mandatory fields in received "
            "packet: %s" % ['authMsg', 'errorCode'])

        received_packet = yield self.server.wait_for_data()
        self.assertEqual(received_packet, utils.mk_packet(
            '0',
            "<USSDError>"
            "<requestId>0</requestId>"
            "<errorCode>208</errorCode>"
            "</USSDError>"))

    @inlineCallbacks
    def test_error_response_validation(self):
        bad_packet = utils.mk_packet(
            '0', "<USSDError><requestId>0</requestId></USSDError>")

        self.server.send_data(bad_packet)
        yield self.client.wait_for_data()
        self.assert_in_log(
            'err',
            "(208) Invalid Message: Missing mandatory fields in received "
            "packet: %s" % ['errorCode'])

        received_packet = yield self.server.wait_for_data()
        self.assertEqual(received_packet, utils.mk_packet(
            '0',
            "<USSDError>"
            "<requestId>0</requestId>"
            "<errorCode>208</errorCode>"
            "</USSDError>"))

    @inlineCallbacks
    def test_data_request_validation(self):
        bad_packet = utils.mk_packet(
            '0', "<USSDRequest><requestId>0</requestId></USSDRequest>")

        self.client.authenticated = True
        self.server.send_data(bad_packet)
        yield self.client.wait_for_data()

        missing_fields = ['userdata', 'msisdn', 'clientId', 'starCode',
                          'msgtype', 'phase', 'dcs']
        self.assert_in_log(
            'err',
            "(208) Invalid Message: Missing mandatory fields in received "
            "packet: %s" % sorted(missing_fields))

        received_packet = yield self.server.wait_for_data()
        self.assertEqual(received_packet, utils.mk_packet(
            '0',
            "<USSDError>"
            "<requestId>0</requestId>"
            "<errorCode>208</errorCode>"
            "</USSDError>"))

    @inlineCallbacks
    def test_enquire_link_request_validation(self):
        bad_packet = utils.mk_packet(
            '0', "<ENQRequest><requestId>0</requestId></ENQRequest>")

        self.client.authenticated = True
        self.server.send_data(bad_packet)
        yield self.client.wait_for_data()
        self.assert_in_log(
            'err',
            "(208) Invalid Message: Missing mandatory fields in received "
            "packet: %s" % ['enqCmd'])

        received_packet = yield self.server.wait_for_data()
        self.assertEqual(received_packet, utils.mk_packet(
            '0',
            "<USSDError>"
            "<requestId>0</requestId>"
            "<errorCode>208</errorCode>"
            "</USSDError>"))

    @inlineCallbacks
    def test_enquire_link_response_validation(self):
        bad_packet = utils.mk_packet(
            '0', "<ENQResponse><requestId>0</requestId></ENQResponse>")

        self.client.authenticated = True
        self.server.send_data(bad_packet)
        yield self.client.wait_for_data()
        self.assert_in_log(
            'err',
            "(208) Invalid Message: Missing mandatory fields in received "
            "packet: %s" % ['enqCmd'])

        received_packet = yield self.server.wait_for_data()
        self.assertEqual(received_packet, utils.mk_packet(
            '0',
            "<USSDError>"
            "<requestId>0</requestId>"
            "<errorCode>208</errorCode>"
            "</USSDError>"))

    @inlineCallbacks
    def test_continuing_session_data_response(self):
        body = (
            "<USSDResponse>"
            "<requestId>1291850641</requestId>"
            "<msisdn>27845335367</msisdn>"
            "<starCode>123</starCode>"
            "<clientId>123</clientId>"
            "<phase>2</phase>"
            "<msgtype>2</msgtype>"
            "<dcs>15</dcs>"
            "<userdata>*123#</userdata>"
            "<EndofSession>0</EndofSession>"
            "<delvrpt>0</delvrpt>"
            "</USSDResponse>"
        )
        expected_packet = utils.mk_packet('0', body)

        self.client.authenticated = True
        self.client.send_data_response(
            session_id='0',
            request_id='1291850641',
            star_code='123',
            client_id='123',
            msisdn='27845335367',
            user_data='*123#',
            end_session=False)

        received_packet = yield self.server.wait_for_data()
        self.assertEqual(expected_packet, received_packet)

    @inlineCallbacks
    def test_ending_session_data_response(self):
        body = (
            "<USSDResponse>"
            "<requestId>1291850641</requestId>"
            "<msisdn>27845335367</msisdn>"
            "<starCode>123</starCode>"
            "<clientId>123</clientId>"
            "<phase>2</phase>"
            "<msgtype>6</msgtype>"
            "<dcs>15</dcs>"
            "<userdata>*123#</userdata>"
            "<EndofSession>1</EndofSession>"
            "<delvrpt>0</delvrpt>"
            "</USSDResponse>"
        )
        expected_packet = utils.mk_packet('0', body)

        self.client.authenticated = True
        self.client.send_data_response(
            session_id='0',
            request_id='1291850641',
            star_code='123',
            client_id='123',
            msisdn='27845335367',
            user_data='*123#',
            end_session=True)

        received_packet = yield self.server.wait_for_data()
        self.assertEqual(expected_packet, received_packet)

    def assert_next_timeout(self, t):
        return self.assertAlmostEqual(
            self.client.scheduled_timeout.getTime(), t, 1)

    def assert_timeout_cancelled(self):
        self.assertFalse(self.client.scheduled_timeout.active())

    @inlineCallbacks
    def test_periodic_client_enquire_link(self):
        request_body_a = (
            "<ENQRequest>"
            "<requestId>0</requestId>"
            "<enqCmd>ENQUIRELINK</enqCmd>"
            "</ENQRequest>")
        expected_request_packet_a = utils.mk_packet('0', request_body_a)

        response_body_a = (
            "<ENQResponse>"
            "<requestId>0</requestId>"
            "<enqCmd>ENQUIRELINKRSP</enqCmd>"
            "</ENQResponse>")
        response_packet_a = utils.mk_packet('0', response_body_a)
        self.server.responses[expected_request_packet_a] = response_packet_a

        request_body_b = (
            "<ENQRequest>"
            "<requestId>1</requestId>"
            "<enqCmd>ENQUIRELINK</enqCmd>"
            "</ENQRequest>")
        expected_request_packet_b = utils.mk_packet('1', request_body_b)

        response_body_b = (
            "<ENQResponse>"
            "<requestId>1</requestId>"
            "<enqCmd>ENQUIRELINKRSP</enqCmd>"
            "</ENQResponse>")
        response_packet_b = utils.mk_packet('1', response_body_b)
        self.server.responses[expected_request_packet_b] = response_packet_b

        clock = Clock()
        t0 = clock.seconds()
        self.client.clock = clock
        self.client.enquire_link_interval = 120
        self.client.timeout_period = 20
        self.client.authenticated = True
        self.client.start_periodic_enquire_link()

        # advance to just after the first enquire link request
        clock.advance(0.01)
        self.assert_next_timeout(t0 + 20)

        # wait for the first enquire link response
        yield self.client.wait_for_data()
        self.assert_timeout_cancelled()

        # advance to just after the second enquire link request
        clock.advance(120.01)
        self.assert_next_timeout(t0 + 140)

        # wait for the second enquire link response
        yield self.client.wait_for_data()
        self.assert_timeout_cancelled()

    @inlineCallbacks
    def test_timeout(self):
        request_body = (
            "<ENQRequest>"
            "<requestId>0</requestId>"
            "<enqCmd>ENQUIRELINK</enqCmd>"
            "</ENQRequest>")
        expected_request_packet = utils.mk_packet('0', request_body)

        clock = Clock()
        self.client.clock = clock
        self.client.enquire_link_interval = 120
        self.client.timeout_period = 20
        self.client.authenticated = True
        self.client.start_periodic_enquire_link()

        # wait for the first enquire link request
        received_request_packet = yield self.server.wait_for_data()
        self.assertEqual(expected_request_packet, received_request_packet)

        # advance to just before the timeout should occur
        clock.advance(19.9)
        self.assertFalse(self.client.disconnected)

        # advance to just after the timeout should occur
        clock.advance(0.1)
        self.assertTrue(self.client.disconnected)
        self.assert_in_log(
            'msg',
            "No enquire link response received after 20 seconds, "
            "disconnecting")

    @inlineCallbacks
    def test_server_enquire_link(self):
        request_body = (
            "<ENQRequest>"
            "<requestId>0</requestId>"
            "<enqCmd>ENQUIRELINK</enqCmd>"
            "</ENQRequest>")
        request_packet = utils.mk_packet('0', request_body)

        response_body = (
            "<ENQResponse>"
            "<requestId>0</requestId>"
            "<enqCmd>ENQUIRELINKRSP</enqCmd>"
            "</ENQResponse>")
        expected_response_packet = utils.mk_packet('0', response_body)

        self.client.authenticated = True
        self.server.send_data(request_packet)
        response_packet = yield self.server.wait_for_data()
        self.assertEqual(expected_response_packet, response_packet)

    @inlineCallbacks
    def test_error_response_handling_for_known_codes(self):
        body = (
            "<USSDError>"
            "<requestId>0</requestId>"
            "<errorCode>000</errorCode>"
            "<errorMsg>Some Reason</errorMsg>"
            "</USSDError>"
        )
        error_packet = utils.mk_packet('0', body)

        self.server.send_data(error_packet)
        yield self.client.wait_for_data()
        self.assert_in_log(
            'err',
            "Server sent error message: (000) Dummy error occurred: "
            "Some Reason")

    @inlineCallbacks
    def test_error_response_handling_for_unknown_codes(self):
        body = (
            "<USSDError>"
            "<requestId>0</requestId>"
            "<errorCode>1337</errorCode>"
            "<errorMsg>Some Reason</errorMsg>"
            "</USSDError>"
        )
        error_packet = utils.mk_packet('0', body)

        self.server.send_data(error_packet)
        yield self.client.wait_for_data()
        self.assert_in_log(
            'err',
            "Server sent error message: (1337) Unknown Code: "
            "Some Reason")

    def test_gen_session_id(self):
        sessid = XmlOverTcpClient.gen_session_id()
        self.assertEqual(len(sessid), XmlOverTcpClient.SESSION_ID_HEADER_SIZE)
        self.assertTrue(
            all(c in XmlOverTcpClient.SESSION_ID_CHARACTERS for c in sessid))
PKqGG\^*vumi/transports/mtn_nigeria/tests/utils.pyfrom twisted.internet.defer import (
    Deferred, inlineCallbacks, gatherResults, maybeDeferred, DeferredQueue)
from twisted.internet import reactor
from twisted.internet.protocol import Protocol
from twisted.internet.protocol import Factory, ClientCreator

from vumi.transports.mtn_nigeria.xml_over_tcp import XmlOverTcpClient


def mk_packet(session_id, body):
    return XmlOverTcpClient.serialize_header(session_id, body) + body
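
# For example, mk_packet('0', '<DummyPacket/>') prepends a 32-byte header:
# '0' padded to 16 bytes with nulls, then the total length '46' (14-byte
# body + 32-byte header) similarly padded.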


class MockServerFactory(Factory):
    def __init__(self):
        self.deferred_server = Deferred()


class MockServer(Protocol):
    def connectionMade(self):
        self.factory.deferred_server.callback(self)

    def connectionLost(self, reason):
        self.factory.on_connection_lost.callback(None)


class MockServerMixin(object):
    server_protocol = None

    @inlineCallbacks
    def start_server(self):
        self.server_disconnected = Deferred()
        factory = MockServerFactory()
        factory.on_connection_lost = self.server_disconnected
        factory.protocol = self.server_protocol
        self.server_port = reactor.listenTCP(0, factory, interface='127.0.0.1')
        self.server = yield factory.deferred_server

    def stop_server(self):
        # Turns out stopping these things is tricky.
        # See http://mumak.net/stuff/twisted-disconnect.html
        return gatherResults([
            maybeDeferred(self.server_port.loseConnection),
            self.server_disconnected])

    def get_server_port(self):
        return self.server_port.getHost().port


class MockXmlOverTcpServer(MockServer):
    def __init__(self):
        self.responses = {}
        self.received_queue = DeferredQueue()

    def wait_for_data(self):
        return self.received_queue.get()

    def send_data(self, data):
        self.transport.write(data)

    def dataReceived(self, data):
        response = self.responses.get(data)
        if response is not None:
            self.transport.write(response)
        self.received_queue.put(data)


class MockXmlOverTcpServerMixin(MockServerMixin):
    server_protocol = MockXmlOverTcpServer


class MockClientMixin(object):
    client_protocol = None

    @inlineCallbacks
    def start_client(self, port):
        self.client_disconnected = Deferred()
        self.client_creator = ClientCreator(reactor, self.client_protocol)
        self.client = yield self.client_creator.connectTCP('127.0.0.1', port)
        conn_lost = self.client.connectionLost

        def connectionLost_wrapper(reason):
            d = maybeDeferred(conn_lost, reason)
            d.chainDeferred(self.client_disconnected)
            return d
        self.client.connectionLost = connectionLost_wrapper

    def stop_client(self):
        self.client.transport.loseConnection()
        return self.client_disconnected


class MockClientServerMixin(MockClientMixin, MockServerMixin):
    @inlineCallbacks
    def start_protocols(self):
        deferred_server = self.start_server()
        yield self.start_client(self.get_server_port())
        yield deferred_server  # we need to wait for the client to connect

    @inlineCallbacks
    def stop_protocols(self):
        yield self.stop_client()
        yield self.stop_server()
vumi/transports/mtn_nigeria/tests/test_mtn_nigeria.py
# -*- test-case-name: vumi.transports.mtn_nigeria.tests.test_mtn_nigeria -*-

from twisted.internet.defer import Deferred, inlineCallbacks

from vumi.message import TransportUserMessage
from vumi.transports.mtn_nigeria.tests import utils
from vumi.tests.helpers import VumiTestCase
from vumi.transports.mtn_nigeria import MtnNigeriaUssdTransport
from vumi.transports.mtn_nigeria import mtn_nigeria_ussd
from vumi.transports.mtn_nigeria.tests.utils import MockXmlOverTcpServerMixin
from vumi.transports.mtn_nigeria.xml_over_tcp import (
    XmlOverTcpError, CodedXmlOverTcpError)
from vumi.transports.tests.helpers import TransportHelper


class TestMtnNigeriaUssdTransport(VumiTestCase, MockXmlOverTcpServerMixin):

    REQUEST_PARAMS = {
        'request_id': '1291850641',
        'msisdn': '27845335367',
        'star_code': '123',
        'client_id': '0123',
        'phase': '2',
        'dcs': '15',
        'user_data': '*123#',
        'msg_type': '1',
        'end_of_session': '0',
    }
    REQUEST_BODY = (
        "<USSDRequest>"
        "<requestId>%(request_id)s</requestId>"
        "<msisdn>%(msisdn)s</msisdn>"
        "<starCode>%(star_code)s</starCode>"
        "<clientId>%(client_id)s</clientId>"
        "<phase>%(phase)s</phase>"
        "<msgtype>%(msg_type)s</msgtype>"
        "<dcs>%(dcs)s</dcs>"
        "<userdata>%(user_data)s</userdata>"
        "<EndofSession>%(end_of_session)s</EndofSession>"
        "</USSDRequest>"
    )

    RESPONSE_PARAMS = {
        'request_id': '1291850641',
        'msisdn': '27845335367',
        'star_code': '123',
        'client_id': '0123',
        'phase': '2',
        'dcs': '15',
        'user_data': '',
        'msg_type': '2',
        'end_of_session': '0',
        'delivery_report': '0',
    }
    RESPONSE_BODY = (
        "<USSDResponse>"
        "<requestId>%(request_id)s</requestId>"
        "<msisdn>%(msisdn)s</msisdn>"
        "<starCode>%(star_code)s</starCode>"
        "<clientId>%(client_id)s</clientId>"
        "<phase>%(phase)s</phase>"
        "<msgtype>%(msg_type)s</msgtype>"
        "<dcs>%(dcs)s</dcs>"
        "<userdata>%(user_data)s</userdata>"
        "<EndofSession>%(end_of_session)s</EndofSession>"
        "<delvrpt>%(delivery_report)s</delvrpt>"
        "</USSDResponse>"
    )

    EXPECTED_TRANSPORT_METADATA = {
        'mtn_nigeria_ussd': {
            'session_id': '0',
            'clientId': '0123',
            'phase': '2',
            'dcs': '15',
            'starCode': '123',
            'requestId': '1291850641',
        },
    }

    @inlineCallbacks
    def setUp(self):
        self.tx_helper = self.add_helper(
            TransportHelper(MtnNigeriaUssdTransport))
        deferred_login = self.fake_login(
            mtn_nigeria_ussd.MtnNigeriaUssdClientFactory.protocol)
        deferred_server = self.start_server()
        self.add_cleanup(self.stop_server)

        self.transport = yield self.tx_helper.get_transport({
            'server_hostname': '127.0.0.1',
            'server_port': self.get_server_port(),
            'username': 'root',
            'password': 'toor',
            'application_id': '1029384756',
            'enquire_link_interval': 240,
            'timeout_period': 120,
            'user_termination_response': 'Bye',
        })
        # We need to tear the transport down before stopping the server.
        self.add_cleanup(self.transport.stopWorker)
        yield deferred_server

        self.session_manager = self.transport.session_manager
        yield self.session_manager.redis._purge_all()

        yield deferred_login
        self.client = self.transport.factory.client

    def fake_login(self, protocol_cls):
        d = Deferred()

        def stubbed_login(self):
            self.authenticated = True
            if not d.called:
                d.callback(None)
        self.patch(protocol_cls, 'login', stubbed_login)
        return d

    @inlineCallbacks
    def mk_session(self, session_id, ussd_code):
        # Pre-populate the redis datastore to simulate an existing session
        # that the incoming request will resume.
        yield self.session_manager.create_session(
            session_id, ussd_code=ussd_code)

    def mk_data_request(self, session_id, **kw):
        params = self.REQUEST_PARAMS.copy()
        params.update(kw)
        return utils.mk_packet(session_id, self.REQUEST_BODY % params)

    def mk_data_response(self, session_id, **kw):
        params = self.RESPONSE_PARAMS.copy()
        params.update(kw)
        return utils.mk_packet(session_id, self.RESPONSE_BODY % params)

    def mk_error_response(self, session_id, request_id, error_code):
        body = (
            "<USSDError>"
            "<requestId>%s</requestId>"
            "<errorCode>%s</errorCode>"
            "</USSDError>" % (request_id, error_code))
        return utils.mk_packet(session_id, body)

    def send_request(self, session_id, **params):
        packet = self.mk_data_request(session_id, **params)
        self.server.send_data(packet)

    def assert_inbound_message(self, msg, **field_values):
        expected_payload = {
            'content': '',
            'from_addr': '27845335367',
            'to_addr': '*123#',
            'session_event': TransportUserMessage.SESSION_RESUME,
            'transport_name': self.tx_helper.transport_name,
            'transport_type': 'ussd',
            'transport_metadata': self.EXPECTED_TRANSPORT_METADATA,
        }
        expected_payload.update(field_values)

        for field, expected_value in expected_payload.iteritems():
            self.assertEqual(msg[field], expected_value)

    def assert_ack(self, ack, reply):
        self.assertEqual(ack.payload['event_type'], 'ack')
        self.assertEqual(ack.payload['user_message_id'], reply['message_id'])
        self.assertEqual(ack.payload['sent_message_id'], reply['message_id'])

    def assert_nack(self, nack, reply, reason):
        self.assertEqual(nack.payload['event_type'], 'nack')
        self.assertEqual(nack.payload['user_message_id'], reply['message_id'])
        self.assertEqual(nack.payload['nack_reason'], reason)

    @inlineCallbacks
    def test_inbound_begin(self):
        self.send_request('0', user_data='*123#')
        [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1)
        self.assert_inbound_message(
            msg,
            session_event=TransportUserMessage.SESSION_NEW,
            from_addr='27845335367',
            to_addr='*123#',
            content=None)

        reply_d = self.tx_helper.make_dispatch_reply(
            msg, "We are the Knights Who Say ... Ni!")

        response = yield self.server.wait_for_data()
        expected_response = self.mk_data_response(
            '0',
            user_data="We are the Knights Who Say ... Ni!")
        self.assertEqual(response, expected_response)

        [ack] = yield self.tx_helper.wait_for_dispatched_events(1)
        reply = yield reply_d
        self.assert_ack(ack, reply)

    @inlineCallbacks
    def test_inbound_resume_and_reply_with_end(self):
        yield self.mk_session('0', '*123#')

        self.send_request(
            '0',
            user_data="Well, what is it you want?",
            msg_type=4,
            end_of_session=0)
        [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1)
        self.assert_inbound_message(
            msg,
            session_event=TransportUserMessage.SESSION_RESUME,
            content="Well, what is it you want?")

        reply_d = self.tx_helper.make_dispatch_reply(
            msg, "We want ... a shrubbery!", continue_session=False)

        response_packet = yield self.server.wait_for_data()
        expected_response_packet = self.mk_data_response(
            '0',
            user_data="We want ... a shrubbery!",
            msg_type=6,
            end_of_session=1)
        self.assertEqual(response_packet, expected_response_packet)

        [ack] = yield self.tx_helper.wait_for_dispatched_events(1)
        reply = yield reply_d
        self.assert_ack(ack, reply)

    @inlineCallbacks
    def test_inbound_resume_and_reply_with_resume(self):
        yield self.mk_session('0', '*123#')

        self.send_request(
            '0',
            user_data="Well, what is it you want?",
            msg_type=4,
            end_of_session=0)
        [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1)
        self.assert_inbound_message(
            msg,
            session_event=TransportUserMessage.SESSION_RESUME,
            content="Well, what is it you want?")

        reply_d = self.tx_helper.make_dispatch_reply(
            msg, "We want ... a shrubbery!", continue_session=True)

        response_packet = yield self.server.wait_for_data()
        expected_response_packet = self.mk_data_response(
            '0',
            user_data="We want ... a shrubbery!",
            msg_type=2,
            end_of_session=0)
        self.assertEqual(response_packet, expected_response_packet)

        [ack] = yield self.tx_helper.wait_for_dispatched_events(1)
        reply = yield reply_d
        self.assert_ack(ack, reply)

    @inlineCallbacks
    def test_user_terminated_session(self):
        yield self.mk_session('0', '*123#')

        self.send_request(
            '0',
            user_data="I'm leaving now",
            msg_type=4,
            end_of_session=1)
        [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1)
        self.assert_inbound_message(
            msg,
            session_event=TransportUserMessage.SESSION_CLOSE,
            content="I'm leaving now")

        response_packet = yield self.server.wait_for_data()
        expected_response_packet = self.mk_data_response(
            '0',
            user_data='Bye',
            msg_type=6,
            end_of_session=1)
        self.assertEqual(response_packet, expected_response_packet)

    @inlineCallbacks
    def test_outbound_response_failure(self):
        # stub the client to fake a response failure
        def stubbed_send_data_response(*a, **kw):
            raise XmlOverTcpError("Something bad happened")

        self.patch(
            self.client,
            'send_data_response',
            stubbed_send_data_response)

        tm = self.EXPECTED_TRANSPORT_METADATA.copy()
        msg = self.tx_helper.make_inbound("foo", transport_metadata=tm)
        reply = yield self.tx_helper.make_dispatch_reply(msg, "It's a trap!")

        [nack] = yield self.tx_helper.wait_for_dispatched_events(1)
        self.assert_nack(
            nack, reply, "Response failed: Something bad happened")

    @inlineCallbacks
    def test_outbound_metadata_fields_missing(self):
        msg = self.tx_helper.make_inbound(
            "foo", transport_metadata={
                'mtn_nigeria_ussd': {'session_id': '123'},
            })
        reply = yield self.tx_helper.make_dispatch_reply(msg, "It's a trap!")

        [nack] = yield self.tx_helper.wait_for_dispatched_events(1)
        reason = "%s" % CodedXmlOverTcpError(
            '208',
            "Required message transport metadata fields missing in "
            "outbound message: %s" % ['clientId'])
        self.assert_nack(nack, reply, reason)


# ---- vumi/transports/mtn_nigeria/tests/__init__.py ---- (empty)


# ---- vumi/transports/mediafonemc/__init__.py ----
"""
Mediafone Cameroun HTTP SMS API.
"""

from vumi.transports.mediafonemc.mediafonemc import MediafoneTransport


__all__ = ['MediafoneTransport']


# ---- vumi/transports/mediafonemc/mediafonemc.py ----
# -*- test-case-name: vumi.transports.mediafonemc.tests.test_mediafonemc -*-

import json
from urllib import urlencode

from twisted.python import log
from twisted.web import http
from twisted.internet.defer import inlineCallbacks

from vumi.utils import http_request_full
from vumi.transports.httprpc import HttpRpcTransport


class MediafoneTransport(HttpRpcTransport):
    """
    HTTP transport for Mediafone Cameroun.

    :param str web_path:
        The HTTP path to listen on.
    :param int web_port:
        The HTTP port
    :param str transport_name:
        The name this transport instance will use to create its queues
    :param str username:
        Mediafone account username.
    :param str password:
        Mediafone account password.
    :param str outbound_url:
        The URL to send outbound messages to.

    """

    transport_type = 'sms'
    agent_factory = None  # For swapping out the Agent we use in tests.

    ENCODING = 'utf-8'
    EXPECTED_FIELDS = set(['to', 'from', 'sms'])

    def setup_transport(self):
        self._username = self.config['username']
        self._password = self.config['password']
        self._outbound_url = self.config['outbound_url']
        return super(MediafoneTransport, self).setup_transport()

    @inlineCallbacks
    def handle_outbound_message(self, message):
        params = {
            'username': self._username,
            'password': self._password,
            'phone': message['to_addr'],
            'msg': message['content'],
            }
        log.msg("Sending outbound message: %s" % (message,))
        url = '%s?%s' % (self._outbound_url, urlencode(params))
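        # e.g. url might look like (values illustrative, taken from the
        # tests; parameter order may vary):
        #   http://mediafone.example.com/?username=user&password=pass&phone=2371234567&msg=hello+world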
        log.msg("Making HTTP request: %s" % (url,))
        response = yield http_request_full(
            url, '', method='GET', agent_class=self.agent_factory)
        log.msg("Response: (%s) %r" % (response.code, response.delivered_body))
        if response.code == http.OK:
            yield self.publish_ack(
                user_message_id=message['message_id'],
                sent_message_id=message['message_id'])
        else:
            yield self.publish_nack(
                user_message_id=message['message_id'],
                sent_message_id=message['message_id'],
                reason='Unexpected response code: %s' % (response.code,))

    @inlineCallbacks
    def handle_raw_inbound_message(self, message_id, request):
        values, errors = self.get_field_values(request, self.EXPECTED_FIELDS)
        if errors:
            log.msg('Unhappy incoming message: %s' % (errors,))
            yield self.finish_request(message_id, json.dumps(errors), code=400)
            return
        log.msg(('MediafoneTransport sending from %(from)s to %(to)s '
                 'message "%(sms)s"') % values)
        yield self.publish_message(
            message_id=message_id,
            content=values['sms'],
            to_addr=values['to'],
            from_addr=values['from'],
            provider='vumi',
            transport_type=self.transport_type,
        )
        yield self.finish_request(
            message_id, json.dumps({'message_id': message_id}))


# ---- vumi/transports/mediafonemc/tests/test_mediafonemc.py ----
# -*- encoding: utf-8 -*-

import json
from urllib import urlencode

from twisted.internet.defer import inlineCallbacks, DeferredQueue
from twisted.web import http

from vumi.utils import http_request, http_request_full
from vumi.tests.fake_connection import FakeHttpServer
from vumi.tests.helpers import VumiTestCase
from vumi.transports.mediafonemc import MediafoneTransport
from vumi.transports.tests.helpers import TransportHelper


class TestMediafoneTransport(VumiTestCase):

    @inlineCallbacks
    def setUp(self):
        self.mediafone_calls = DeferredQueue()
        self.fake_http = FakeHttpServer(self.handle_request)
        self.base_url = "http://mediafone.example.com/"

        self.config = {
            'web_path': "foo",
            'web_port': 0,
            'username': 'user',
            'password': 'pass',
            'outbound_url': self.base_url,
        }
        self.tx_helper = self.add_helper(TransportHelper(MediafoneTransport))
        self.transport = yield self.tx_helper.get_transport(self.config)
        self.transport.agent_factory = self.fake_http.get_agent
        self.transport_url = self.transport.get_transport_url()
        self.mediafonemc_response = ''
        self.mediafonemc_response_code = http.OK

    def handle_request(self, request):
        self.mediafone_calls.put(request)
        request.setResponseCode(self.mediafonemc_response_code)
        return self.mediafonemc_response

    def mkurl(self, content, from_addr="2371234567", **kw):
        params = {
            'to': '12345',
            'from': from_addr,
            'sms': content,
            }
        params.update(kw)
        return self.mkurl_raw(**params)

    def mkurl_raw(self, **params):
        return '%s%s?%s' % (
            self.transport_url,
            self.config['web_path'],
            urlencode(params)
        )

    @inlineCallbacks
    def test_health(self):
        result = yield http_request(
            self.transport_url + "health", "", method='GET')
        self.assertEqual(json.loads(result), {'pending_requests': 0})

    @inlineCallbacks
    def test_inbound(self):
        url = self.mkurl('hello')
        response = yield http_request(url, '', method='GET')
        [msg] = self.tx_helper.get_dispatched_inbound()
        self.assertEqual(msg['transport_name'], self.tx_helper.transport_name)
        self.assertEqual(msg['to_addr'], "12345")
        self.assertEqual(msg['from_addr'], "2371234567")
        self.assertEqual(msg['content'], "hello")
        self.assertEqual(json.loads(response),
                         {'message_id': msg['message_id']})

    @inlineCallbacks
    def test_outbound(self):
        yield self.tx_helper.make_dispatch_outbound(
            "hello world", to_addr="2371234567")
        req = yield self.mediafone_calls.get()
        self.assertEqual(req.path, self.base_url)
        self.assertEqual(req.method, 'GET')
        self.assertEqual({
            'username': ['user'],
            'phone': ['2371234567'],
            'password': ['pass'],
            'msg': ['hello world'],
        }, req.args)

    @inlineCallbacks
    def test_nack(self):
        self.mediafonemc_response_code = http.NOT_FOUND
        self.mediafonemc_response = 'Not Found'

        msg = yield self.tx_helper.make_dispatch_outbound(
            "outbound", to_addr="2371234567")

        yield self.mediafone_calls.get()
        [nack] = yield self.tx_helper.wait_for_dispatched_events(1)
        self.assertEqual(nack['user_message_id'], msg['message_id'])
        self.assertEqual(nack['sent_message_id'], msg['message_id'])
        self.assertEqual(nack['nack_reason'], 'Unexpected response code: 404')

    @inlineCallbacks
    def test_handle_non_ascii_input(self):
        url = self.mkurl(u"öæł".encode("utf-8"))
        response = yield http_request(url, '', method='GET')
        [msg] = self.tx_helper.get_dispatched_inbound()
        self.assertEqual(msg['transport_name'], self.tx_helper.transport_name)
        self.assertEqual(msg['to_addr'], "12345")
        self.assertEqual(msg['from_addr'], "2371234567")
        self.assertEqual(msg['content'], u"öæł")
        self.assertEqual(json.loads(response),
                         {'message_id': msg['message_id']})

    @inlineCallbacks
    def test_bad_parameter(self):
        url = self.mkurl('hello', foo='bar')
        response = yield http_request_full(url, '', method='GET')
        self.assertEqual(400, response.code)
        self.assertEqual(json.loads(response.delivered_body),
                         {'unexpected_parameter': ['foo']})

    @inlineCallbacks
    def test_missing_parameters(self):
        url = self.mkurl_raw(to='12345', sms='hello')
        response = yield http_request_full(url, '', method='GET')
        self.assertEqual(400, response.code)
        self.assertEqual(json.loads(response.delivered_body),
                         {'missing_parameter': ['from']})


# ---- vumi/transports/mediafonemc/tests/__init__.py ---- (empty)


# ---- vumi/transports/wechat/test_message_types.py ----
import json

from twisted.trial.unittest import TestCase

from vumi.transports.wechat.message_types import (
    WeChatXMLParser, TextMessage, NewsMessage)
from vumi.transports.wechat.errors import WeChatParserException


class TestWeChatXMLParser(TestCase):

    def test_missing_msg_type(self):
        self.assertRaises(
            WeChatParserException, WeChatXMLParser.parse, '<xml></xml>')

    def test_multiple_msg_types(self):
        self.assertRaises(
            WeChatParserException, WeChatXMLParser.parse,
            '<xml><MsgType>Foo</MsgType><MsgType>Bar</MsgType></xml>')

    def test_text_message_parse(self):
        msg = WeChatXMLParser.parse(
            """
            
            
            
            1348831860
            
            
            1234567890123456
            
            """)

        self.assertEqual(msg.to_user_name, 'toUser')
        self.assertEqual(msg.from_user_name, 'fromUser')
        self.assertEqual(msg.create_time, '1348831860')
        self.assertEqual(msg.msg_id, '1234567890123456')
        self.assertEqual(msg.content, 'this is a test')
        self.assertTrue(isinstance(msg, TextMessage))

    def test_text_message_to_xml(self):
        msg = TextMessage(
            'to_addr', 'from_addr', '1348831860', 'this is a test')
        self.assertEqual(
            msg.to_xml(),
            "".join([
                "",
                "to_addr",
                "from_addr",
                "1348831860",
                "text",
                "this is a test",
                "",
            ]))

    def test_text_message_to_json(self):
        msg = TextMessage(
            'to_addr', 'from_addr', '1348831860', 'this is a test')
        self.assertEqual(
            json.loads(msg.to_json()),
            {
                'touser': 'to_addr',
                'msgtype': 'text',
                'text': {
                    'content': 'this is a test'
                }
            })

    def test_news_message_to_xml(self):
        msg = NewsMessage(
            'to_addr', 'from_addr', '1348831860', [{
                'title': 'title1',
                'description': 'description1',
            }, {
                'picurl': 'picurl',
                'url': 'url',
            }])
        self.assertEqual(
            msg.to_xml(),
            ''.join([
                "<xml>",
                "<ToUserName>to_addr</ToUserName>",
                "<FromUserName>from_addr</FromUserName>",
                "<CreateTime>1348831860</CreateTime>",
                "<MsgType>news</MsgType>",
                "<ArticleCount>2</ArticleCount>",
                "<Articles>",
                "<item>",
                "<Title>title1</Title>",
                "<Description>description1</Description>",
                "</item>",
                "<item>",
                "<PicUrl>picurl</PicUrl>",
                "<Url>url</Url>",
                "</item>",
                "</Articles>",
                "</xml>",
            ]))

    def test_news_message_to_json(self):
        msg = NewsMessage(
            'to_addr', 'from_addr', '1348831860', [{
                'title': 'title1',
                'description': 'description1',
            }, {
                'picurl': 'picurl',
                'url': 'url',
            }])
        self.assertEqual(
            json.loads(msg.to_json()),
            {
                'touser': 'to_addr',
                'msgtype': 'news',
                'news': {
                    'articles': [{
                        'title': 'title1',
                        'description': 'description1'
                    }, {
                        'picurl': 'picurl',
                        'url': 'url'
                    }]
                }
            })

    def test_event_message_parse(self):
        msg = WeChatXMLParser.parse(
            """
            
                
                    
                
                
                    
                
                1395130515
                
                    
                
                
                    
                
                
                    
                
            
            """)
        self.assertEqual(msg.to_user_name, 'toUser')
        self.assertEqual(msg.from_user_name, 'fromUser')
        self.assertEqual(msg.create_time, '1395130515')
        self.assertEqual(msg.event, 'subscribe')
        self.assertEqual(msg.event_key, '')


# ---- vumi/transports/wechat/__init__.py ----
from vumi.transports.wechat.wechat import WeChatTransport

__all__ = ['WeChatTransport']


# ---- vumi/transports/wechat/errors.py ----
class WeChatException(Exception):
    pass


class WeChatApiException(WeChatException):
    pass


class WeChatParserException(WeChatException):
    pass


# ---- vumi/transports/wechat/wechat.py ----
# -*- test-case-name: vumi.transports.wechat.tests.test_wechat -*-

import hashlib
import urllib
import json
from datetime import datetime
from functools import partial

from twisted.internet import reactor
from twisted.internet.defer import inlineCallbacks, Deferred, returnValue
from twisted.web.resource import Resource
from twisted.web import http
from twisted.web.server import NOT_DONE_YET

from vumi import log
from vumi.config import (
    ConfigText, ConfigServerEndpoint, ConfigDict, ConfigInt, ConfigBool,
    ServerEndpointFallback)
from vumi.transports import Transport
from vumi.transports.httprpc.httprpc import HttpRpcHealthResource
from vumi.transports.wechat.errors import WeChatException, WeChatApiException
from vumi.transports.wechat.message_types import (
    TextMessage, EventMessage, NewsMessage, WeChatXMLParser)
from vumi.utils import build_web_site, http_request_full, StatusEdgeDetector

from vumi.message import TransportUserMessage
from vumi.persist.txredis_manager import TxRedisManager


def is_verifiable(request):
    params = ['signature', 'timestamp', 'nonce']
    return all([(key in request.args) for key in params])


def http_ok(request):
    return 200 <= request.code < 300


def verify(token, request):
    signature = request.args['signature'][0]
    timestamp = request.args['timestamp'][0]
    nonce = request.args['nonce'][0]

    hash_ = hashlib.sha1(''.join(sorted([timestamp, nonce, token])))

    return hash_.hexdigest() == signature
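
# For example (hypothetical values): with token='token', timestamp='123'
# and nonce='456', sorted(['123', '456', 'token']) joins to '123456token'
# and the expected signature is hashlib.sha1('123456token').hexdigest().
# The request() helper in tests/test_wechat.py builds signatures this way.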


class WeChatConfig(Transport.CONFIG_CLASS):

    api_url = ConfigText(
        'The URL the WeChat API is accessible at.',
        default='https://api.wechat.com/cgi-bin/',
        required=False, static=True)
    auth_token = ConfigText(
        'This WeChat app\'s auth token. '
        'Used for initial message authentication.',
        required=True, static=True)
    twisted_endpoint = ConfigServerEndpoint(
        'The endpoint to listen on.',
        required=True, static=True, fallbacks=[ServerEndpointFallback()])
    web_path = ConfigText(
        "The path to serve this resource on.",
        default='/api/v1/wechat/', static=True)
    health_path = ConfigText(
        "The path to serve the health resource on.",
        default='/health/', static=True)
    redis_manager = ConfigDict('Parameters to connect to Redis with.',
                               default={}, required=False, static=True)
    wechat_appid = ConfigText(
        'The WeChat app_id. Issued by WeChat for developer accounts '
        'to allow push API access.', required=True, static=True)
    wechat_secret = ConfigText(
        'The WeChat secret. Issued by WeChat for developer accounts '
        'to allow push API access.', required=True, static=True)
    wechat_menu = ConfigDict(
        'The menu structure to create at boot.', required=False, static=True)
    wechat_mask_lifetime = ConfigInt(
        'How long, in seconds, to maintain an address mask for. '
        '(default 1 hour)', default=60 * 60 * 1, static=True)
    embed_user_profile = ConfigBool(
        'Whether or not to embed the WeChat User Profile info in '
        'messages received.', required=True, default=False, static=True)
    embed_user_profile_lang = ConfigText(
        'What language to request User Profile as.', required=False,
        default='en', static=True)
    embed_user_profile_lifetime = ConfigInt(
        'How long to cache User Profiles for.', default=(60 * 60),
        required=False, static=True)
    double_delivery_lifetime = ConfigInt(
        'How long to keep track of Message IDs and responses for double '
        'delivery tracking.', default=(60 * 60), required=False, static=True)

    # TODO: Deprecate these fields when confmodel#5 is done.
    host = ConfigText(
        "*DEPRECATED* 'host' and 'port' fields may be used in place of the"
        " 'twisted_endpoint' field.", static=True)
    port = ConfigInt(
        "*DEPRECATED* 'host' and 'port' fields may be used in place of the"
        " 'twisted_endpoint' field.", static=True)


class WeChatResource(Resource):

    isLeaf = True

    def __init__(self, transport):
        Resource.__init__(self)
        self.transport = transport
        self.config = transport.get_static_config()

    def render_GET(self, request):
        if is_verifiable(request) and verify(self.config.auth_token, request):
            return request.args['echostr'][0]
        request.setResponseCode(http.BAD_REQUEST)
        return ''

    @inlineCallbacks
    def validate_request(self, request):
        if not (is_verifiable(request)
                and verify(self.config.auth_token, request)):
            raise WeChatException('Bad request for incoming message')
        yield self.transport.add_status_good_req()
        returnValue(request)

    def render_POST(self, request):
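        # Handle the request asynchronously: verify the signature, parse
        # and publish the inbound message, then keep the HTTP request open
        # (NOT_DONE_YET) so the reply can be written to it later; bad
        # requests are turned into HTTP 400 responses by handle_error.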
        d = Deferred()
        d.addCallback(self.validate_request)
        d.addCallback(self.handle_request)
        d.addCallback(self.transport.queue_request, request)
        d.addErrback(self.handle_error, request)
        reactor.callLater(0, d.callback, request)
        return NOT_DONE_YET

    @inlineCallbacks
    def handle_error(self, failure, request):
        # trap() re-raises the failure unless it is a WeChatException.
        failure.trap(WeChatException)

        yield self.transport.add_status_bad_req()
        request.setResponseCode(http.BAD_REQUEST)
        request.write(failure.getErrorMessage())
        request.finish()

    def handle_request(self, request):
        d = request.notifyFinish()
        d.addBoth(
            lambda _: self.transport.handle_finished_request(request))

        wc_msg = WeChatXMLParser.parse(request.content.read())
        return self.transport.handle_raw_inbound_message(request, wc_msg)


class WeChatTransport(Transport):
    """

    A Transport for the WeChat API.

    API documentation
    ~~~~~~~~~~~~~~~~~

    http://admin.wechat.com/wiki/index.php?title=Main_Page


    Inbound Messaging
    ~~~~~~~~~~~~~~~~~

    Supported Common Message types:

        - Text Message

    Supported Event Message types:

        - Following / subscribe
        - Unfollowing / unsubscribe
        - Text Message (in response to Menu keypress events)


    Outbound Messaging
    ~~~~~~~~~~~~~~~~~~

    Supported Callback Message types:

        - Text Message
        - News Message

    Supported Customer Service Message types:

        - Text Message
        - News Message

    How it works
    ~~~~~~~~~~~~

    1) When a user subscribes to the Vumi account and opens the
       contact for the first time, the contact sends the first
       message.
    2) After that session ends, the contact only responds once the
       user sends it a text; the content of that text does not
       matter. This differs from the initial subscription case
       described in 1).
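
    Example static config (values are illustrative):

        transport_name: wechat_transport
        auth_token: my-auth-token
        twisted_endpoint: tcp:8080
        wechat_appid: my-appid
        wechat_secret: my-secret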

    """

    CONFIG_CLASS = WeChatConfig
    DEFAULT_MASK = 'default'
    MESSAGE_TYPES = [
        NewsMessage,
    ]
    DEFAULT_MESSAGE_TYPE = TextMessage
    # What key to store the `access_token` under in Redis
    ACCESS_TOKEN_KEY = 'access_token'
    # What key to store the `addr_mask` under in Redis
    ADDR_MASK_KEY = 'addr_mask'
    # What key to use when constructing the User Profile key
    USER_PROFILE_KEY = 'user_profile'
    # What key to use when constructing the cached reply key
    CACHED_REPLY_KEY = 'cached_reply'

    transport_type = 'wechat'
    agent_factory = None  # For swapping out the Agent we use in tests.

    def add_status_bad_req(self):
        return self.add_status(
            status='down', component='inbound', type='bad_request',
            message='Bad request received')

    def add_status_good_req(self):
        return self.add_status(
            status='ok', component='inbound', type='good_request',
            message='Good request received')

    @inlineCallbacks
    def setup_transport(self):
        config = self.get_static_config()
        self.request_dict = {}
        self.endpoint = config.twisted_endpoint
        self.resource = WeChatResource(self)
        self.factory = build_web_site({
            config.health_path: HttpRpcHealthResource(self),
            config.web_path: self.resource,
        })

        self.redis = yield TxRedisManager.from_config(config.redis_manager)
        self.server = yield self.endpoint.listen(self.factory)
        self.status_detect = StatusEdgeDetector()

        if config.wechat_menu:
            # not yielding because this shouldn't block startup
            d = self.get_access_token()
            d.addCallback(self.create_wechat_menu, config.wechat_menu)

    @inlineCallbacks
    def add_status(self, **kw):
        '''Publishes a status if it is not a repeat of the previously
        published status.'''
        if self.status_detect.check_status(**kw):
            yield self.publish_status(**kw)

    def http_request_full(self, *args, **kw):
        kw['agent_class'] = self.agent_factory
        return http_request_full(*args, **kw)

    @inlineCallbacks
    def create_wechat_menu(self, access_token, menu_structure):
        url = self.make_url('menu/create', {'access_token': access_token})
        response = yield self.http_request_full(
            url, method='POST', data=json.dumps(menu_structure),
            headers={'Content-Type': ['application/json']})
        if not http_ok(response):
            raise WeChatApiException(
                'Received HTTP code: %r when creating the menu.' % (
                    response.code,))
        data = json.loads(response.delivered_body)
        if data['errcode'] != 0:
            raise WeChatApiException(
                'Received errcode: %(errcode)s, errmsg: %(errmsg)s '
                'when creating WeChat Menu.' % data)
        log.info('WeChat Menu created successfully.')

    def user_profile_key(self, open_id):
        return '@'.join([
            self.USER_PROFILE_KEY,
            open_id,
        ])

    def mask_key(self, user):
        return '@'.join([
            self.ADDR_MASK_KEY,
            user,
        ])

    def cached_reply_key(self, *parts):
        key_parts = [self.CACHED_REPLY_KEY]
        key_parts.extend(parts)
        return '@'.join(key_parts)

    def mask_addr(self, to_addr, mask):
        return '@'.join([to_addr, mask])
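
    # For example, mask_addr('toUser', 'EVENTKEY') returns
    # 'toUser@EVENTKEY'; the menu-click tests assert this value as the
    # inbound message's to_addr.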

    def cache_addr_mask(self, user, mask):
        config = self.get_static_config()
        d = self.redis.setex(
            self.mask_key(user), config.wechat_mask_lifetime, mask)
        d.addCallback(lambda *a: mask)
        return d

    def get_addr_mask(self, user):
        d = self.redis.get(self.mask_key(user))
        d.addCallback(lambda mask: mask or self.DEFAULT_MASK)
        return d

    def clear_addr_mask(self, user):
        return self.redis.delete(self.mask_key(user))

    def handle_raw_inbound_message(self, request, wc_msg):
        return {
            TextMessage: self.handle_inbound_text_message,
            EventMessage: self.handle_inbound_event_message,
        }.get(wc_msg.__class__)(request, wc_msg)

    def wrap_expire(self, result, key, ttl):
        d = self.redis.expire(key, ttl)
        d.addCallback(lambda _: result)
        return d

    def mark_as_seen_recently(self, wc_msg_id):
        config = self.get_static_config()
        key = self.cached_reply_key(wc_msg_id)
        d = self.redis.setnx(key, 1)
        d.addCallback(
            lambda result: (
                self.wrap_expire(result, key, config.double_delivery_lifetime)
                if result else False))
        return d

    def was_seen_recently(self, wc_msg_id):
        return self.redis.exists(self.cached_reply_key(wc_msg_id))

    def get_cached_reply(self, wc_msg_id):
        return self.redis.get(self.cached_reply_key(wc_msg_id, 'reply'))

    def set_cached_reply(self, wc_msg_id, reply):
        config = self.get_static_config()
        return self.redis.setex(
            self.cached_reply_key(wc_msg_id, 'reply'),
            config.double_delivery_lifetime, reply)
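
    # Sketch of the redis keys used for double-delivery tracking (parts
    # are '@'-joined by cached_reply_key):
    #   cached_reply@<MsgId>        -- setnx lock marking the msg as seen
    #   cached_reply@<MsgId>@reply  -- serialized reply, replayed on a
    #                                  duplicate delivery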

    @inlineCallbacks
    def check_for_double_delivery(self, request, wc_msg_id):
        seen_recently = yield self.was_seen_recently(wc_msg_id)
        if not seen_recently:
            returnValue(False)

        cached_reply = yield self.get_cached_reply(wc_msg_id)
        if cached_reply:
            # we've got a reply still lying around, just parrot that instead.
            request.write(cached_reply)

        request.finish()
        returnValue(True)

    @inlineCallbacks
    def handle_inbound_text_message(self, request, wc_msg):
        double_delivery = yield self.check_for_double_delivery(
            request, wc_msg.msg_id)
        if double_delivery:
            log.msg('WeChat double delivery of message: %s' % (wc_msg.msg_id,))
            return

        lock = yield self.mark_as_seen_recently(wc_msg.msg_id)
        if not lock:
            log.msg('Unable to get lock for message id: %s' % (wc_msg.msg_id,))
            return

        config = self.get_static_config()
        if config.embed_user_profile:
            user_profile = yield self.get_user_profile(wc_msg.from_user_name)
        else:
            user_profile = {}

        mask = yield self.get_addr_mask(wc_msg.from_user_name)
        msg = yield self.publish_message(
            content=wc_msg.content,
            from_addr=wc_msg.from_user_name,
            to_addr=self.mask_addr(wc_msg.to_user_name, mask),
            timestamp=datetime.fromtimestamp(int(wc_msg.create_time)),
            transport_type=self.transport_type,
            transport_metadata={
                'wechat': {
                    'FromUserName': wc_msg.from_user_name,
                    'ToUserName': wc_msg.to_user_name,
                    'MsgType': 'text',
                    'MsgId': wc_msg.msg_id,
                    'UserProfile': user_profile,
                }
            })
        returnValue(msg)

    @inlineCallbacks
    def handle_inbound_event_message(self, request, wc_msg):
        if wc_msg.event.lower() in ('view', 'unsubscribe'):
            log.msg("%s clicked on %s" % (
                wc_msg.from_user_name, wc_msg.event_key))
            request.finish()
            yield self.clear_addr_mask(wc_msg.from_user_name)
            return

        if wc_msg.event_key:
            mask = yield self.cache_addr_mask(
                wc_msg.from_user_name, wc_msg.event_key)
        else:
            mask = yield self.get_addr_mask(wc_msg.from_user_name)

        if wc_msg.event.lower() in ('subscribe', 'click'):
            session_event = TransportUserMessage.SESSION_NEW
        else:
            session_event = TransportUserMessage.SESSION_NONE

        msg = yield self.publish_message(
            content=None,
            from_addr=wc_msg.from_user_name,
            to_addr=self.mask_addr(wc_msg.to_user_name, mask),
            timestamp=datetime.fromtimestamp(int(wc_msg.create_time)),
            transport_type=self.transport_type,
            session_event=session_event,
            transport_metadata={
                'wechat': {
                    'FromUserName': wc_msg.from_user_name,
                    'ToUserName': wc_msg.to_user_name,
                    'MsgType': 'event',
                    'Event': wc_msg.event,
                    'EventKey': wc_msg.event_key
                }
            })
        # Close the request to ensure we fire a push message on reply.
        request.finish()
        returnValue(msg)

    def force_close(self, message):
        request = self.get_request(message['message_id'])
        request.setResponseCode(http.INTERNAL_SERVER_ERROR)
        request.finish()

    def handle_finished_request(self, request):
        for message_id, request_ in self.request_dict.items():
            if request_ == request:
                self.request_dict.pop(message_id)

    def queue_request(self, message, request):
        if message is not None:
            self.request_dict[message['message_id']] = request

    def get_request(self, message_id):
        return self.request_dict.get(message_id, None)

    def infer_message_type(self, message):
        for message_type in self.MESSAGE_TYPES:
            result = message_type.accepts(message)
            if result is not None:
                return partial(message_type.from_vumi_message, result)
        return self.DEFAULT_MESSAGE_TYPE.from_vumi_message
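
    # For example, a reply whose content contains a URL (matched by
    # NewsMessage.URLISH) is rendered as a NewsMessage; anything else
    # falls back to the default TextMessage.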

    def handle_outbound_message(self, message):
        """
        Handle an outbound message: if the originating HTTP request is
        still open, reply on it directly; otherwise fall back to WeChat's
        customer service push API.
        """
        request_id = message['in_reply_to']
        request = self.get_request(request_id)

        builder = self.infer_message_type(message)
        wc_msg = builder(message)

        if request is None or request.finished:
            # There's no pending request object for this message which
            # means we need to treat this as a customer service message
            # and hit WeChat's Push API (window available for 24hrs)
            return self.push_message(wc_msg, message)

        request.write(wc_msg.to_xml())
        request.finish()

        d = self.publish_ack(user_message_id=message['message_id'],
                             sent_message_id=message['message_id'])
        wc_metadata = message["transport_metadata"].get('wechat', {})
        if wc_metadata:
            d.addCallback(lambda _: self.set_cached_reply(
                wc_metadata['MsgId'], wc_msg.to_xml()))

        if message['session_event'] == TransportUserMessage.SESSION_CLOSE:
            d.addCallback(
                lambda _: self.clear_addr_mask(wc_msg.to_user_name))
        return d

    def push_message(self, wc_message, vumi_message):
        d = self.get_access_token()
        d.addCallback(
            lambda access_token: self.make_url('message/custom/send', {
                'access_token': access_token
            }))
        d.addCallback(
            lambda url: self.http_request_full(
                url, method='POST', data=wc_message.to_json(), headers={
                    'Content-Type': ['application/json']
                }))
        d.addCallback(self.handle_api_response, vumi_message)
        if vumi_message['session_event'] == TransportUserMessage.SESSION_CLOSE:
            d.addCallback(
                lambda ack: self.clear_addr_mask(wc_message.from_user_name))
        return d

    @inlineCallbacks
    def handle_api_response(self, response, message):
        if http_ok(response):
            ack = yield self.publish_ack(user_message_id=message['message_id'],
                                         sent_message_id=message['message_id'])
            returnValue(ack)
        nack = yield self.publish_nack(
            message['message_id'],
            reason='Received status code: %s' % (response.code,))
        returnValue(nack)

    @inlineCallbacks
    def get_access_token(self):
        access_token = yield self.redis.get(self.ACCESS_TOKEN_KEY)
        if access_token is None:
            access_token = yield self.request_new_access_token()
        returnValue(access_token)

    @inlineCallbacks
    def get_user_profile(self, open_id):
        config = self.get_static_config()
        up_key = self.user_profile_key(open_id)
        cached_up = yield self.redis.get(up_key)
        if cached_up:
            returnValue(json.loads(cached_up))

        access_token = yield self.get_access_token()
        response = yield self.http_request_full(self.make_url('user/info', {
            'access_token': access_token,
            'openid': open_id,
            'lang': config.embed_user_profile_lang,
        }), method='GET')
        user_profile = response.delivered_body
        yield self.redis.setex(up_key, config.embed_user_profile_lifetime,
                               user_profile)
        returnValue(json.loads(user_profile))

    @inlineCallbacks
    def request_new_access_token(self):
        config = self.get_static_config()
        response = yield self.http_request_full(self.make_url('token', {
            'grant_type': 'client_credential',
            'appid': config.wechat_appid,
            'secret': config.wechat_secret,
        }), method='GET')
        if not http_ok(response):
            raise WeChatApiException(
                ('Received HTTP status code %r when '
                 'requesting access token.') % (response.code,))

        data = json.loads(response.delivered_body)
        if 'errcode' in data:
            raise WeChatApiException(
                'Error when requesting access token. '
                'Errcode: %(errcode)s, Errmsg: %(errmsg)s.' % data)

        # make sure we're always ahead of the WeChat expiry
        access_token = data['access_token']
        expiry = int(data['expires_in']) * 0.90
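        # e.g. an expires_in of 7200 seconds yields a cache TTL of 6480
        # seconds, so we refresh before WeChat invalidates the token.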
        yield self.redis.setex(
            self.ACCESS_TOKEN_KEY, int(expiry), access_token)
        returnValue(access_token)

    def make_url(self, path, params):
        config = self.get_static_config()
        return '%s%s?%s' % (
            config.api_url, path, urllib.urlencode(params))

    def teardown_transport(self):
        return self.server.stopListening()

    def get_health_response(self):
        return "OK"


# ---- vumi/transports/wechat/message_types.py ----
# -*- test-case-name: vumi.transports.wechat.tests.test_message_types -*-

import re
import json
from xml.etree.ElementTree import Element, SubElement, tostring, fromstring

from vumi.transports.wechat.errors import (
    WeChatParserException, WeChatException)


def get_child_value(node, name):
    [child] = node.findall(name)
    return (child.text.strip() if child.text is not None else '')


def append(node, tag, value):
    el = SubElement(node, tag)
    el.text = value


class WeChatMessage(object):

    mandatory_fields = ()
    optional_fields = ()

    @classmethod
    def from_xml(cls, doc):
        params = [get_child_value(doc, name)
                  for name in cls.mandatory_fields]

        for field in cls.optional_fields:
            try:
                params.append(get_child_value(doc, field))
            except ValueError:
                # element not present
                continue
        return cls(*params)


class TextMessage(WeChatMessage):

    mandatory_fields = (
        'ToUserName',
        'FromUserName',
        'CreateTime',
        'Content',
    )

    optional_fields = (
        'MsgId',
    )

    def __init__(self, to_user_name, from_user_name, create_time, content,
                 msg_id=None):
        self.to_user_name = to_user_name
        self.from_user_name = from_user_name
        self.create_time = create_time
        self.content = content
        self.msg_id = msg_id

    @classmethod
    def from_vumi_message(cls, message):
        md = message['transport_metadata'].get('wechat', {})
        from_addr = md.get('ToUserName', message['from_addr'])
        return cls(message['to_addr'], from_addr,
                   message['timestamp'].strftime('%s'),
                   message['content'])

    def to_xml(self):
        xml = Element('xml')
        append(xml, 'ToUserName', self.to_user_name)
        append(xml, 'FromUserName', self.from_user_name)
        append(xml, 'CreateTime', self.create_time)
        append(xml, 'MsgType', 'text')
        append(xml, 'Content', self.content)
        return tostring(xml)

    def to_json(self):
        return json.dumps({
            'touser': self.to_user_name,
            'msgtype': 'text',
            'text': {
                'content': self.content,
            }
        })


class NewsMessage(WeChatMessage):

    # Has something URL-ish in it
    URLISH = re.compile(
        r'(?P<before>.*)'
        r'(?P<url>http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+)'
        r'(?P<after>.*?)')
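
    # For example, matching 'Visit http://example.com' yields
    # before='Visit ' and url='http://example.com'; from_vumi_message
    # uses these groups for the news item's title and url.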

    def __init__(self, to_user_name, from_user_name, create_time,
                 items=None):
        self.to_user_name = to_user_name
        self.from_user_name = from_user_name
        self.create_time = create_time
        self.items = ([] if items is None else items)

    @classmethod
    def accepts(cls, vumi_message):
        return cls.URLISH.match(vumi_message['content'])

    @classmethod
    def from_vumi_message(cls, match, vumi_message):
        md = vumi_message['transport_metadata'].get('wechat', {})
        from_addr = md.get('ToUserName', vumi_message['from_addr'])
        url_data = match.groupdict()
        return cls(
            vumi_message['to_addr'],
            from_addr,
            vumi_message['timestamp'].strftime('%s'),
            [{
                'title': '%(before)s' % url_data,
                'url': '%(url)s' % url_data,
                'description': vumi_message['content']
            }])

    def to_xml(self):
        xml = Element('xml')
        append(xml, 'ToUserName', self.to_user_name)
        append(xml, 'FromUserName', self.from_user_name)
        append(xml, 'CreateTime', self.create_time)
        append(xml, 'MsgType', 'news')
        append(xml, 'ArticleCount', str(len(self.items)))
        articles = SubElement(xml, 'Articles')
        for item in self.items:
            if not any(item.values()):
                raise WeChatException(
                    'News items must have some values.')

            item_element = SubElement(articles, 'item')
            if 'title' in item:
                append(item_element, 'Title', item['title'])
            if 'description' in item:
                append(item_element, 'Description', item['description'])
            if 'picurl' in item:
                append(item_element, 'PicUrl', item['picurl'])
            if 'url' in item:
                append(item_element, 'Url', item['url'])
        return tostring(xml)

    def to_json(self):
        return json.dumps({
            'touser': self.to_user_name,
            'msgtype': 'news',
            'news': {
                'articles': self.items
            }
        })


class EventMessage(WeChatMessage):

    mandatory_fields = (
        'ToUserName',
        'FromUserName',
        'CreateTime',
        'Event',
    )

    optional_fields = (
        'MsgId',
        'EventKey',
    )

    def __init__(self, to_user_name, from_user_name, create_time, event,
                 event_key=None):
        self.to_user_name = to_user_name
        self.from_user_name = from_user_name
        self.create_time = create_time
        self.event = event
        self.event_key = event_key


class WeChatXMLParser(object):

    ENCODING = 'utf-8'
    CLASS_MAP = {
        'text': TextMessage,
        'news': NewsMessage,
        'event': EventMessage,
    }

    @classmethod
    def parse(cls, string):
        doc = fromstring(string.decode(cls.ENCODING))
        klass = cls.get_class(doc)
        return klass.from_xml(doc)

    @classmethod
    def get_class(cls, doc):
        msg_types = doc.findall('MsgType')
        if not msg_types:
            raise WeChatParserException('No MsgType found.')

        if len(msg_types) > 1:
            raise WeChatParserException('More than 1 MsgType found.')

        [msg_type_element] = msg_types
        msg_type = msg_type_element.text.strip()
        if msg_type not in cls.CLASS_MAP:
            raise WeChatParserException(
                'Unsupported MsgType: %s' % (msg_type,))

        return cls.CLASS_MAP[msg_type]


# ---- vumi/transports/wechat/tests/test_wechat.py ----
import hashlib
import json
import yaml
from urllib import urlencode

from twisted.internet.defer import (
    inlineCallbacks, DeferredQueue, returnValue, gatherResults)
from twisted.internet import task, reactor
from twisted.web import http
from twisted.web.server import NOT_DONE_YET
from twisted.trial.unittest import SkipTest

from vumi.tests.fake_connection import FakeHttpServer
from vumi.tests.helpers import VumiTestCase
from vumi.tests.utils import LogCatcher
from vumi.transports.tests.helpers import TransportHelper
from vumi.transports.wechat import WeChatTransport
from vumi.transports.wechat.errors import WeChatApiException
from vumi.transports.wechat.message_types import (
    WeChatXMLParser, TextMessage)
from vumi.utils import http_request_full
from vumi.message import TransportUserMessage
from vumi.persist.fake_redis import FakeRedis


def request(transport, *a, **kw):
    nonce = '1234'
    timestamp = '2014-01-01T00:00:00'
    token = transport.get_static_config().auth_token

    good_signature = hashlib.sha1(
        ''.join(sorted([timestamp, nonce, token]))).hexdigest()

    params = {
        'signature': good_signature,
        'timestamp': timestamp,
        'nonce': nonce,
    }

    params.update(kw.get('params', {}))
    kw['params'] = params
    return raw_request(transport, *a, **kw)


def raw_request(transport, method, path='', params=None, data=None):
    if params is None:
        params = {}

    addr = transport.server.getHost()

    path += '?%s' % (urlencode(params),)
    url = 'http://%s:%s%s%s' % (
        addr.host,
        addr.port,
        transport.get_static_config().web_path,
        path)
    return http_request_full(url, method=method, data=data)


class WeChatTestCase(VumiTestCase):

    def setUp(self):
        self.tx_helper = self.add_helper(TransportHelper(WeChatTransport))
        self.request_queue = DeferredQueue()
        self.fake_http = FakeHttpServer(self.handle_api_request)
        self.api_url = 'https://api.wechat.com/cgi-bin/'

    def handle_api_request(self, request):
        self.assertEqual(request.path[:len(self.api_url)], self.api_url)
        self.request_queue.put(request)
        return NOT_DONE_YET

    @inlineCallbacks
    def get_transport(self, **config):
        defaults = {
            'auth_token': 'token',
            'twisted_endpoint': 'tcp:0',
            'wechat_appid': 'appid',
            'wechat_secret': 'secret',
            'embed_user_profile': False,
            'publish_status': True,
        }
        defaults.update(config)
        transport = yield self.tx_helper.get_transport(defaults)
        transport.agent_factory = self.fake_http.get_agent
        returnValue(transport)

    @inlineCallbacks
    def get_transport_with_access_token(self, access_token, **config):
        transport = yield self.get_transport(**config)
        yield transport.redis.set(WeChatTransport.ACCESS_TOKEN_KEY,
                                  access_token)
        returnValue(transport)


class TestWeChatInboundMessaging(WeChatTestCase):

    @inlineCallbacks
    def test_auth_success(self):
        transport = yield self.get_transport()
        resp = yield request(
            transport, "GET", params={
                'echostr': 'success'
            })
        self.assertEqual(resp.delivered_body, 'success')

    @inlineCallbacks
    def test_auth_fail(self):
        transport = yield self.get_transport_with_access_token('foo')
        resp = yield request(
            transport, "GET", params={
                'signature': 'foo',
                'echostr': 'success'
            })
        self.assertNotEqual(resp.delivered_body, 'success')

    @inlineCallbacks
    def test_inbound_text_message(self):
        transport = yield self.get_transport_with_access_token('foo')

        resp_d = request(
            transport, 'POST', data="""
            
            
            
            1348831860
            
            
            1234567890123456
            
            """.strip())

        [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1)
        reply_msg = yield self.tx_helper.make_dispatch_reply(
            msg, 'foo')

        resp = yield resp_d
        reply = WeChatXMLParser.parse(resp.delivered_body)
        self.assertEqual(reply.to_user_name, 'fromUser')
        self.assertEqual(reply.from_user_name, 'toUser')
        self.assertTrue(int(reply.create_time) > 1348831860)
        self.assertTrue(isinstance(reply, TextMessage))

        [ack] = yield self.tx_helper.wait_for_dispatched_events(1)
        self.assertEqual(ack['event_type'], 'ack')
        self.assertEqual(ack['user_message_id'], reply_msg['message_id'])
        self.assertEqual(ack['sent_message_id'], reply_msg['message_id'])

        [status] = yield self.tx_helper.wait_for_dispatched_statuses(1)

        self.assertEquals(status['status'], 'ok')
        self.assertEquals(status['component'], 'inbound')
        self.assertEquals(status['type'], 'good_request')
        self.assertEquals(status['message'], 'Good request received')

    @inlineCallbacks
    def test_inbound_bad_request(self):
        transport = yield self.get_transport_with_access_token('foo')
        yield raw_request(
            transport, 'POST', params={'bad': 'params'}, data="""
            <xml>
            <ToUserName><![CDATA[toUser]]></ToUserName>
            <FromUserName><![CDATA[fromUser]]></FromUserName>
            <CreateTime>1348831860</CreateTime>
            <MsgType><![CDATA[text]]></MsgType>
            <Content><![CDATA[this is a test]]></Content>
            </xml>
            """.strip())
        [status] = yield self.tx_helper.wait_for_dispatched_statuses(1)

        self.assertEquals(status['status'], 'down')
        self.assertEquals(status['component'], 'inbound')
        self.assertEquals(status['type'], 'bad_request')
        self.assertEquals(status['message'], 'Bad request received')

    @inlineCallbacks
    def test_inbound_event_subscribe_message(self):
        transport = yield self.get_transport_with_access_token('foo')

        resp = yield request(
            transport, 'POST', data="""
                
                    
                        
                    
                    
                        
                    
                    1395130515
                    
                        
                    
                    
                        
                    
                    
                
                """)
        self.assertEqual(resp.code, http.OK)

        [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1)
        self.assertEqual(
            msg['session_event'], TransportUserMessage.SESSION_NEW)
        self.assertEqual(msg['transport_metadata'], {
            'wechat': {
                'Event': 'subscribe',
                'EventKey': '',
                'FromUserName': 'fromUser',
                'MsgType': 'event',
                'ToUserName': 'toUser'
            }
        })

    @inlineCallbacks
    def test_inbound_menu_event_click_message(self):
        transport = yield self.get_transport_with_access_token('foo')

        resp = yield request(
            transport, 'POST', data="""
                <xml>
                <ToUserName><![CDATA[toUser]]></ToUserName>
                <FromUserName><![CDATA[fromUser]]></FromUserName>
                <CreateTime>123456789</CreateTime>
                <MsgType><![CDATA[event]]></MsgType>
                <Event><![CDATA[CLICK]]></Event>
                <EventKey><![CDATA[EVENTKEY]]></EventKey>
                </xml>
                """.strip())
        self.assertEqual(resp.code, http.OK)
        [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1)

        self.assertEqual(
            msg['session_event'], TransportUserMessage.SESSION_NEW)
        self.assertEqual(msg['transport_metadata'], {
            'wechat': {
                'Event': 'CLICK',
                'EventKey': 'EVENTKEY',
                'FromUserName': 'fromUser',
                'MsgType': 'event',
                'ToUserName': 'toUser'
            }
        })

        self.assertEqual(msg['to_addr'], 'toUser@EVENTKEY')

    @inlineCallbacks
    def test_inbound_menu_event_view_message(self):
        transport = yield self.get_transport_with_access_token('foo')

        with LogCatcher() as lc:
            resp = yield request(
                transport, 'POST', data="""
                    <xml>
                    <ToUserName><![CDATA[toUser]]></ToUserName>
                    <FromUserName><![CDATA[fromUser]]></FromUserName>
                    <CreateTime>123456789</CreateTime>
                    <MsgType><![CDATA[event]]></MsgType>
                    <Event><![CDATA[VIEW]]></Event>
                    <EventKey><![CDATA[http://www.gotvafrica.com/mobi/home.aspx]]></EventKey>
                    </xml>
                    """.strip())
            self.assertEqual(resp.code, http.OK)
            [] = self.tx_helper.get_dispatched_inbound()
            msg = lc.messages()[0]
            self.assertEqual(
                msg,
                'fromUser clicked on http://www.gotvafrica.com/mobi/home.aspx')

    @inlineCallbacks
    def test_unsupported_message_type(self):
        transport = yield self.get_transport_with_access_token('foo')

        response = yield request(
            transport, 'POST', data="""
            <xml>
            <ToUserName><![CDATA[toUser]]></ToUserName>
            <FromUserName><![CDATA[fromUser]]></FromUserName>
            <CreateTime>1348831860</CreateTime>
            <MsgType><![CDATA[THIS_IS_UNSUPPORTED]]></MsgType>
            <Content><![CDATA[...]]></Content>
            <MsgId>1234567890123456</MsgId>
            </xml>
            """.strip())

        self.assertEqual(
            response.code, http.BAD_REQUEST)
        self.assertEqual(
            response.delivered_body,
            "Unsupported MsgType: THIS_IS_UNSUPPORTED")
        self.assertEqual(
            [],
            self.tx_helper.get_dispatched_inbound())


class TestWeChatOutboundMessaging(WeChatTestCase):

    def dispatch_push_message(self, content, wechat_md, **kwargs):
        helper_metadata = kwargs.get('helper_metadata', {})
        wechat_metadata = helper_metadata.setdefault('wechat', {})
        wechat_metadata.update(wechat_md)
        return self.tx_helper.make_dispatch_outbound(
            content, helper_metadata=helper_metadata, **kwargs)

    @inlineCallbacks
    def test_ack_push_text_message(self):
        yield self.get_transport_with_access_token('foo')

        msg_d = self.dispatch_push_message('foo', {}, to_addr='toaddr')

        request = yield self.request_queue.get()
        self.assertEqual(request.path, self.api_url + 'message/custom/send')
        self.assertEqual(request.args, {
            'access_token': ['foo']
        })
        self.assertEqual(json.load(request.content), {
            'touser': 'toaddr',
            'msgtype': 'text',
            'text': {
                'content': 'foo'
            }
        })
        request.finish()

        [ack] = yield self.tx_helper.wait_for_dispatched_events(1)
        msg = yield msg_d
        self.assertEqual(ack['event_type'], 'ack')
        self.assertEqual(ack['user_message_id'], msg['message_id'])

    @inlineCallbacks
    def test_nack_push_text_message(self):
        yield self.get_transport_with_access_token('foo')
        msg_d = self.dispatch_push_message('foo', {})

        # fail the API request
        request = yield self.request_queue.get()
        request.setResponseCode(http.BAD_REQUEST)
        request.finish()

        msg = yield msg_d
        [nack] = yield self.tx_helper.wait_for_dispatched_events(1)
        self.assertEqual(
            nack['user_message_id'], msg['message_id'])
        self.assertEqual(nack['event_type'], 'nack')
        self.assertEqual(nack['nack_reason'], 'Received status code: 400')

    @inlineCallbacks
    def test_ack_push_inferred_news_message(self):
        yield self.get_transport_with_access_token('foo')
        # news is a collection of URLs, apparently
        content = ('This is an awesome link for you! http://www.wechat.com/ '
                   'Go visit it.')
        msg_d = self.dispatch_push_message(
            content, {}, to_addr='toaddr')

        request = yield self.request_queue.get()
        self.assertEqual(request.path, self.api_url + 'message/custom/send')
        self.assertEqual(request.args, {
            'access_token': ['foo']
        })
        self.assertEqual(json.load(request.content), {
            'touser': 'toaddr',
            'msgtype': 'news',
            'news': {
                'articles': [
                    {
                        'title': 'This is an awesome link for you! ',
                        'url': 'http://www.wechat.com/',
                        'description': content,
                    }
                ]
            }
        })

        request.finish()

        [ack] = yield self.tx_helper.wait_for_dispatched_events(1)
        msg = yield msg_d
        self.assertEqual(ack['event_type'], 'ack')
        self.assertEqual(ack['user_message_id'], msg['message_id'])


class TestWeChatAccessToken(WeChatTestCase):

    @inlineCallbacks
    def test_request_new_access_token(self):
        transport = yield self.get_transport()
        config = transport.get_static_config()

        d = transport.request_new_access_token()

        req = yield self.request_queue.get()
        self.assertEqual(req.path, self.api_url + 'token')
        self.assertEqual(req.args, {
            'grant_type': ['client_credential'],
            'appid': [config.wechat_appid],
            'secret': [config.wechat_secret],
        })
        req.write(json.dumps({
            'access_token': 'the_access_token',
            'expires_in': 7200
        }))
        req.finish()

        access_token = yield d
        self.assertEqual(access_token, 'the_access_token')
        cached_token = yield transport.redis.get(
            WeChatTransport.ACCESS_TOKEN_KEY)
        self.assertEqual(cached_token, 'the_access_token')
        expiry = yield transport.redis.ttl(WeChatTransport.ACCESS_TOKEN_KEY)
        self.assertTrue(int(7200 * 0.8) < expiry <= int(7200 * 0.9))

    @inlineCallbacks
    def test_get_cached_access_token(self):
        transport = yield self.get_transport()
        yield transport.redis.set(WeChatTransport.ACCESS_TOKEN_KEY, 'foo')
        access_token = yield transport.get_access_token()
        self.assertEqual(access_token, 'foo')
        # Empty request queue means no WeChat API calls were made
        self.assertEqual(self.request_queue.size, None)


class TestWeChatAddrMasking(WeChatTestCase):

    @inlineCallbacks
    def test_default_mask(self):
        transport = yield self.get_transport_with_access_token('foo')

        resp_d = request(
            transport, 'POST', data="""
            <xml>
            <ToUserName><![CDATA[toUser]]></ToUserName>
            <FromUserName><![CDATA[fromUser]]></FromUserName>
            <CreateTime>1348831860</CreateTime>
            <MsgType><![CDATA[text]]></MsgType>
            <Content><![CDATA[...]]></Content>
            <MsgId>1234567890123456</MsgId>
            </xml>
            """.strip())

        [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1)
        yield self.tx_helper.make_dispatch_reply(msg, 'foo')

        self.assertEqual(
            (yield transport.get_addr_mask('fromUser')),
            transport.DEFAULT_MASK)
        self.assertEqual(msg['to_addr'], 'toUser@default')
        yield resp_d

    @inlineCallbacks
    def test_mask_switching_on_event_key(self):
        transport = yield self.get_transport_with_access_token('foo')

        resp = yield request(
            transport, 'POST', data="""
                <xml>
                <ToUserName><![CDATA[toUser]]></ToUserName>
                <FromUserName><![CDATA[fromUser]]></FromUserName>
                <CreateTime>123456789</CreateTime>
                <MsgType><![CDATA[event]]></MsgType>
                <Event><![CDATA[CLICK]]></Event>
                <EventKey><![CDATA[EVENTKEY]]></EventKey>
                </xml>
                """.strip())
        self.assertEqual(resp.code, http.OK)

        [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1)
        self.assertEqual(
            msg['session_event'], TransportUserMessage.SESSION_NEW)

        self.assertEqual(
            (yield transport.get_addr_mask('fromUser')), 'EVENTKEY')
        self.assertEqual(msg['to_addr'], 'toUser@EVENTKEY')

    @inlineCallbacks
    def test_mask_caching_on_text_message(self):
        transport = yield self.get_transport_with_access_token('foo')
        yield transport.cache_addr_mask('fromUser', 'foo')

        resp_d = request(
            transport, 'POST', data="""
            <xml>
            <ToUserName><![CDATA[toUser]]></ToUserName>
            <FromUserName><![CDATA[fromUser]]></FromUserName>
            <CreateTime>1348831860</CreateTime>
            <MsgType><![CDATA[text]]></MsgType>
            <Content><![CDATA[...]]></Content>
            <MsgId>1234567890123456</MsgId>
            </xml>
            """.strip())

        [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1)
        yield self.tx_helper.make_dispatch_reply(msg, 'foo')

        self.assertEqual(msg['to_addr'], 'toUser@foo')
        yield resp_d

    @inlineCallbacks
    def test_mask_clearing_on_session_end(self):
        transport = yield self.get_transport_with_access_token('foo')
        yield transport.cache_addr_mask('fromUser', 'foo')

        resp_d = request(
            transport, 'POST', data="""
            <xml>
            <ToUserName><![CDATA[toUser]]></ToUserName>
            <FromUserName><![CDATA[fromUser]]></FromUserName>
            <CreateTime>1348831860</CreateTime>
            <MsgType><![CDATA[text]]></MsgType>
            <Content><![CDATA[...]]></Content>
            <MsgId>1234567890123456</MsgId>
            </xml>
            """.strip())

        [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1)
        yield self.tx_helper.make_dispatch_reply(
            msg, 'foo', session_event=TransportUserMessage.SESSION_CLOSE)

        self.assertEqual(msg['to_addr'], 'toUser@foo')
        self.assertEqual(
            (yield transport.get_addr_mask('fromUser')),
            transport.DEFAULT_MASK)
        yield resp_d

    @inlineCallbacks
    def test_inbound_event_unsubscribe_message(self):
        transport = yield self.get_transport_with_access_token('foo')
        yield transport.cache_addr_mask('fromUser', 'foo')

        resp = yield request(
            transport, 'POST', data="""
                <xml>
                    <ToUserName>
                        <![CDATA[toUser]]>
                    </ToUserName>
                    <FromUserName>
                        <![CDATA[fromUser]]>
                    </FromUserName>
                    <CreateTime>1395130515</CreateTime>
                    <MsgType>
                        <![CDATA[event]]>
                    </MsgType>
                    <Event>
                        <![CDATA[unsubscribe]]>
                    </Event>
                    <EventKey><![CDATA[]]></EventKey>
                </xml>
                """)
        self.assertEqual(resp.code, http.OK)
        self.assertEqual([], self.tx_helper.get_dispatched_inbound())
        self.assertEqual(
            (yield transport.get_addr_mask('fromUser')),
            transport.DEFAULT_MASK)


class TestWeChatMenuCreation(WeChatTestCase):

    MENU_TEMPLATE = """
    button:
      - name: Daily Song
        type: click
        key: V1001_TODAY_MUSIC

      - name: ' Artist Profile'
        type: click
        key: V1001_TODAY_SINGER

      - name: Menu
        sub_button:
          - name: Search
            type: view
            url: 'http://www.soso.com/'
          - name: Video
            type: view
            url: 'http://v.qq.com/'
          - name: Like us
            type: click
            key: V1001_GOOD
    """
    MENU = yaml.safe_load(MENU_TEMPLATE)

    @inlineCallbacks
    def test_create_new_menu_success(self):
        transport = yield self.get_transport_with_access_token('foo')

        d = transport.create_wechat_menu('foo', self.MENU)
        req = yield self.request_queue.get()
        self.assertEqual(req.path, self.api_url + 'menu/create')
        self.assertEqual(req.args, {
            'access_token': ['foo'],
        })

        self.assertEqual(json.load(req.content), self.MENU)
        req.write(json.dumps({'errcode': 0, 'errmsg': 'ok'}))
        req.finish()

        yield d

    @inlineCallbacks
    def test_create_new_menu_failure(self):
        transport = yield self.get_transport_with_access_token('foo')
        d = transport.create_wechat_menu('foo', self.MENU)

        req = yield self.request_queue.get()
        req.write(json.dumps({
            'errcode': 40018,
            'errmsg': 'invalid button name size',
        }))
        req.finish()

        exception = yield self.assertFailure(d, WeChatApiException)
        self.assertEqual(
            exception.message,
            ('Received errcode: 40018, errmsg: invalid button name '
             'size when creating WeChat Menu.'))


class TestWeChatInferMessage(WeChatTestCase):

    @inlineCallbacks
    def test_infer_news_message(self):
        transport = yield self.get_transport_with_access_token('foo')

        resp_d = request(
            transport, 'POST', data="""
            <xml>
            <ToUserName><![CDATA[toUser]]></ToUserName>
            <FromUserName><![CDATA[fromUser]]></FromUserName>
            <CreateTime>1348831860</CreateTime>
            <MsgType><![CDATA[text]]></MsgType>
            <Content><![CDATA[...]]></Content>
            <MsgId>10234567890123456</MsgId>
            </xml>
            """.strip())

        [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1)
        yield self.tx_helper.make_dispatch_reply(
            msg, ('To continue you need to accept the T&Cs available at '
                  'http://tandcurl.com/ . Have you read and do you accept '
                  'the terms and conditions?\n1. Yes\n2. No'))

        resp = yield resp_d
        self.assertTrue(
            'http://tandcurl.com/' in resp.delivered_body)
        self.assertTrue(
            'To continue you need to accept the T&Cs available '
            'at ' in resp.delivered_body)
        self.assertTrue(
            'To continue you need to accept the T&Cs '
            'available at http://tandcurl.com/ . Have you read and do '
            'you accept the terms and conditions?\n1. Yes\n2. No'
            in resp.delivered_body)


class TestWeChatEmbedUserProfile(WeChatTestCase):

    @inlineCallbacks
    def test_embed_user_profile(self):
        # NOTE: From http://admin.wechat.com/wiki/index.php?title=User_Profile
        user_profile = {
            "subscribe": 1,
            "openid": "fromUser",
            "nickname": "Band",
            "sex": 1,
            "language": "zh_CN",
            "city": "Guangzhou",
            "province": "Guangdong",
            "country": "China",
            "headimgurl": (
                "http://wx.qlogo.cn/mmopen/g3MonUZtNHkdmzicIlibx6iaFqAc56v"
                "xLSUfpb6n5WKSYVY0ChQKkiaJSgQ1dZuTOgvLLrhJbERQQ4eMsv84eavH"
                "iaiceqxibJxCfHe/0"),
            "subscribe_time": 1382694957
        }

        transport = yield self.get_transport_with_access_token(
            'foo', embed_user_profile=True)
        resp_d = request(
            transport, 'POST', data="""
            <xml>
            <ToUserName><![CDATA[toUser]]></ToUserName>
            <FromUserName><![CDATA[fromUser]]></FromUserName>
            <CreateTime>1348831860</CreateTime>
            <MsgType><![CDATA[text]]></MsgType>
            <Content><![CDATA[...]]></Content>
            <MsgId>10234567890123456</MsgId>
            </xml>
            """.strip())

        req = yield self.request_queue.get()
        self.assertEqual(req.args, {
            'access_token': ['foo'],
            'lang': ['en'],
            'openid': ['fromUser'],
        })

        req.write(json.dumps(user_profile))
        req.finish()

        [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1)
        yield self.tx_helper.make_dispatch_reply(msg, 'Bye!')

        self.assertEqual(
            msg['transport_metadata']['wechat']['UserProfile'],
            user_profile)

        up_key = transport.user_profile_key('fromUser')
        cached_up = yield transport.redis.get(up_key)
        config = transport.get_static_config()
        self.assertEqual(json.loads(cached_up), user_profile)
        self.assertTrue(0
                        < (yield transport.redis.ttl(up_key))
                        <= config.embed_user_profile_lifetime)
        yield resp_d


class TestWeChatInsanity(WeChatTestCase):

    @inlineCallbacks
    def test_double_delivery_handling(self):
        transport = yield self.get_transport_with_access_token('foo')

        xml = """
        
        
        
        1348831860
        
        
        1234567890123456
        
        """.strip()

        resp1_d = request(transport, 'POST', data=xml)

        [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1)
        yield self.tx_helper.make_dispatch_reply(msg, 'foo')

        resp1 = yield resp1_d
        reply1 = WeChatXMLParser.parse(resp1.delivered_body)
        self.assertTrue(isinstance(reply1, TextMessage))

        # this one should bounce straight away
        resp2 = yield request(transport, 'POST', data=xml)
        self.assertEqual(resp2.code, http.OK)
        reply2 = WeChatXMLParser.parse(resp2.delivered_body)
        self.assertEqual(reply1.to_xml(), reply2.to_xml())
        # Nothing new was added
        self.assertEqual(1, len(self.tx_helper.get_dispatched_inbound()))

    @inlineCallbacks
    def test_close_double_delivery_handling(self):
        transport = yield self.get_transport_with_access_token('foo')

        xml = """
        
        
        
        1348831860
        
        
        1234567890123456
        
        """.strip()

        resp1_d = request(transport, 'POST', data=xml % ('first',))
        resp2_d = task.deferLater(reactor, 0.1, request, transport, 'POST',
                                  data=xml % ('second',))

        [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1)

        # the second request should return first
        resp2 = yield resp2_d
        self.assertEqual(resp2.code, http.OK)
        self.assertEqual(resp2.delivered_body, '')

        yield self.tx_helper.make_dispatch_reply(msg, 'foo')

        resp1 = yield resp1_d
        reply1 = WeChatXMLParser.parse(resp1.delivered_body)
        self.assertTrue(isinstance(reply1, TextMessage))

    @inlineCallbacks
    def test_locking(self):
        transport1 = yield self.get_transport_with_access_token('foo')
        transport2 = yield self.get_transport_with_access_token('foo')
        transport3 = yield self.get_transport_with_access_token('foo')

        if any([isinstance(tx.redis._client, FakeRedis)
                for tx in [transport1, transport2, transport3]]):
            raise SkipTest(
                'FakeRedis setnx is not atomic. '
                'See https://github.com/praekelt/vumi/issues/789')

        locks = yield gatherResults([
            transport1.mark_as_seen_recently('msg-id'),
            transport2.mark_as_seen_recently('msg-id'),
            transport3.mark_as_seen_recently('msg-id'),
        ])
        self.assertEqual(sorted(locks), [0, 0, 1])
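        # The assertion above relies on "set if not exists" semantics: only
        # one of the three concurrent callers may win. A minimal sketch of
        # that pattern (a hypothetical helper, not the transport's actual
        # code), assuming a synchronous redis client:
        #
        #     def mark_as_seen_recently(redis, message_id, ttl=600):
        #         key = 'seen_recently:%s' % (message_id,)
        #         was_set = redis.setnx(key, 1)  # 1 only for the first caller
        #         if was_set:
        #             redis.expire(key, ttl)
        #         return was_set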

vumi/transports/wechat/tests/__init__.py

vumi/transports/parlayx/parlayx.py
# -*- test-case-name: vumi.transports.parlayx.tests.test_parlayx -*-
import uuid

from twisted.internet.defer import inlineCallbacks, returnValue

from vumi import log
from vumi.config import ConfigText, ConfigInt, ConfigBool
from vumi.transports.base import Transport
from vumi.transports.failures import TemporaryFailure, PermanentFailure
from vumi.transports.parlayx.client import (
    ParlayXClient, ServiceException, PolicyException)
from vumi.transports.parlayx.server import SmsNotificationService
from vumi.transports.parlayx.soaputil import SoapFault


class ParlayXTransportConfig(Transport.CONFIG_CLASS):
    web_notification_path = ConfigText(
        'Path to listen for delivery and receipt notifications on',
        static=True)
    web_notification_port = ConfigInt(
        'Port to listen for delivery and receipt notifications on',
        default=0, static=True)
    notification_endpoint_uri = ConfigText(
        'URI of the ParlayX SmsNotificationService in Vumi', static=True)
    short_code = ConfigText(
        'Service activation number or short code to receive deliveries for',
        static=True)
    remote_send_uri = ConfigText(
        'URI of the remote ParlayX SendSmsService', static=True)
    remote_notification_uri = ConfigText(
        'URI of the remote ParlayX SmsNotificationService', static=True)
    start_notifications = ConfigBool(
        'Start (and stop) the ParlayX notification service?', static=True)
    service_provider_service_id = ConfigText(
        'Provisioned service provider service identifier', static=True)
    service_provider_id = ConfigText(
        'Provisioned service provider identifier/username', static=True)
    service_provider_password = ConfigText(
        'Provisioned service provider password', static=True)


class ParlayXTransport(Transport):
    """ParlayX SMS transport.

    ParlayX is a defunct standard web service API for telephone networks.
    See http://en.wikipedia.org/wiki/Parlay_X for an overview.

    .. warning::

       This transport has not been tested against another ParlayX
       implementation. If you use it, please provide feedback to the
       Vumi development team on your experiences.
    """

    CONFIG_CLASS = ParlayXTransportConfig
    transport_type = 'sms'

    def _create_client(self, config):
        """
        Create a `ParlayXClient` instance.
        """
        return ParlayXClient(
            service_provider_service_id=config.service_provider_service_id,
            service_provider_id=config.service_provider_id,
            service_provider_password=config.service_provider_password,
            short_code=config.short_code,
            endpoint=config.notification_endpoint_uri,
            send_uri=config.remote_send_uri,
            notification_uri=config.remote_notification_uri)

    @inlineCallbacks
    def setup_transport(self):
        config = self.get_static_config()
        log.info('Starting ParlayX transport: %s' % (self.transport_name,))
        self.web_resource = yield self.start_web_resources(
            [(SmsNotificationService(self.handle_raw_inbound_message,
                                     self.publish_delivery_report),
              config.web_notification_path)],
            config.web_notification_port)
        self._parlayx_client = self._create_client(config)
        if config.start_notifications:
            yield self._parlayx_client.start_sms_notification()

    @inlineCallbacks
    def teardown_transport(self):
        config = self.get_static_config()
        log.info('Stopping ParlayX transport: %s' % (self.transport_name,))
        yield self.web_resource.loseConnection()
        if config.start_notifications:
            yield self._parlayx_client.stop_sms_notification()

    def handle_outbound_message(self, message):
        """
        Send a text message via the ParlayX client.
        """
        log.info('Sending SMS via ParlayX: %r' % (message.to_json(),))
        transport_metadata = message.get('transport_metadata', {})
        d = self._parlayx_client.send_sms(
            message['to_addr'],
            message['content'],
            unique_correlator(message['message_id']),
            transport_metadata.get('linkid'))
        d.addErrback(self.handle_outbound_message_failure, message)
        d.addCallback(
            lambda requestIdentifier: self.publish_ack(
                message['message_id'], requestIdentifier))
        return d

    @inlineCallbacks
    def handle_outbound_message_failure(self, f, message):
        """
        Handle outbound message failures.

        `ServiceException`, `PolicyException` and client-class SOAP faults
        result in `PermanentFailure` being raised; server-class SOAP fault
        instances result in `TemporaryFailure` being raised; and other
        failures are passed through.
        """
        log.error(f, 'Sending SMS failure on ParlayX: %r' % (
            self.transport_name,))

        if not f.check(ServiceException, PolicyException):
            if f.check(SoapFault):
                # We'll give server-class unknown SOAP faults the benefit of
                # the doubt as far as temporary failures go.
                if f.value.code.endswith('Server'):
                    raise TemporaryFailure(f)

        yield self.publish_nack(message['message_id'], f.getErrorMessage())
        if f.check(SoapFault):
            # We've ruled out unknown SOAP faults, so this must be a permanent
            # failure.
            raise PermanentFailure(f)
        returnValue(f)

    def handle_raw_inbound_message(self, correlator, linkid, inbound_message):
        """
        Handle incoming text messages from `SmsNotificationService` callbacks.
        """
        log.info('Receiving SMS via ParlayX: %r: %r' % (
            correlator, inbound_message,))
        message_id = extract_message_id(correlator)
        return self.publish_message(
            message_id=message_id,
            content=inbound_message.message,
            to_addr=inbound_message.service_activation_number,
            from_addr=inbound_message.sender_address,
            provider='parlayx',
            transport_type=self.transport_type,
            transport_metadata=dict(linkid=linkid))


def unique_correlator(message_id, _uuid=None):
    """
    Construct a unique message identifier from an existing message
    identifier.

    This is necessary for the cases where a ``TransportMessage`` needs to
    be retransmitted, since ParlayX wants unique identifiers for all sent
    messages.
    """
    if _uuid is None:
        _uuid = uuid.uuid4()
    return '%s:%s' % (message_id, _uuid)


def extract_message_id(correlator):
    """
    Extract the Vumi message identifier from a ParlayX correlator.
    """
    return correlator.split(':', 1)[0]
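
# Illustrative round-trip for the two helpers above: the correlator embeds
# the original message id before the first colon, so extraction is a simple
# split.
#
#     >>> correlator = unique_correlator('msg-1')
#     >>> correlator.startswith('msg-1:')
#     True
#     >>> extract_message_id(correlator)
#     'msg-1'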

vumi/transports/parlayx/xmlutil.py
# -*- test-case-name: vumi.transports.parlayx.tests.test_xmlutil -*-
"""
XML convenience types and functions.

============
Introduction
============

In this domain-specific language, building on concepts from ``lxml.builder``,
the main goal is to improve the readability and structure of code that needs to
create XML documents programmatically, in particular XML that makes use of XML
namespaces. There are three main parts to consider in achieving this, starting
from the bottom and working our way up.


-----------
1. Elements
-----------

`ElementMaker`, which `Element` is an instance of, is a basic XML element
factory that produces ElementTree element instances. The name (or tag) of the
element is specified as the first parameter to the `ElementMaker.element`
method. Children can be provided as positional parameters; a child may be:
text; an ElementTree Element instance; a `dict` that will be applied as XML
attributes to the element; or a callable that returns any of the previous
items. Additionally, XML attributes can be passed as Python keyword arguments.

As a convenience, calling an `ElementMaker` instance is the same as invoking
the `ElementMaker.element` method.

    >>> from xml.etree.ElementTree import tostring
    >>> tostring(
    ... Element('parent', {'attr': 'value'},
    ...     Element('child1', 'content', attr2='value2')))
    '<parent attr="value"><child1 attr2="value2">content</child1></parent>'


------------------
2. Qualified names
------------------

`QualifiedName` is a type that fills two roles: a way to represent a qualified
XML element name, either including a namespace or as a name in the local XML
namespace (element names that include a namespace are stored in Clark's
notation, e.g. ``{http://example.com}tag``); and an ElementTree element
factory.

As a convenience, calling a `QualifiedName` instance is the same as invoking
the `QualifiedName.element` method. While this bears some similarity to
`ElementMaker`, it plays an important role for `Namespace`.

    >>> from xml.etree.ElementTree import tostring
    >>> tostring(
    ... QualifiedName('{http://example.com}parent')(
    ...     QualifiedName('child1')('content')))
    '<ns0:parent xmlns:ns0="http://example.com"><child1>content</child1></ns0:parent>'


-------------
3. Namespaces
-------------

Again, `Namespace` fills two roles: a way to represent an XML namespace, with
an optional XML namespace prefix; and a `QualifiedName` factory.

Attribute access on a `Namespace` instance will produce a new `QualifiedName`
instance whose element name will be the name of the accessed attribute
qualified in the `Namespace`'s specified XML namespace. `LocalNamespace` is
a convenience for ``Namespace(None)``, which produces `QualifiedName` instances
in the local XML namespace.

    >>> from xml.etree.ElementTree import tostring
    >>> NS = Namespace('http://example.com', 'ex')
    >>> tostring(
    ... NS.parent({'attr': 'value'},
    ...     NS.child1('content'),
    ...     LocalNamespace.child2('content2')))
    '<ex:parent xmlns:ex="http://example.com" attr="value"><ex:child1>content</ex:child1><child2>content2</child2></ex:parent>'

XML attributes may be qualified too:

    >>> from xml.etree.ElementTree import tostring
    >>> NS = Namespace('http://example.com', 'ex')
    >>> tostring(
    ... NS.parent({NS.attr: 'value'}))
    '<ex:parent xmlns:ex="http://example.com" ex:attr="value" />'
"""
from collections import defaultdict
from xml.etree import ElementTree as etree

try:
    from xml.etree.ElementTree import register_namespace
    register_namespace  # For Pyflakes.
except ImportError:
    # This doesn't exist before Python 2.7, see
    # http://effbot.org/zone/element-namespaces.htm#element-tree-representation

    def register_namespace(prefix, uri):
        etree._namespace_map[uri] = prefix

try:
    from xml.etree.ElementTree import ParseError
    ParseError  # For Pyflakes.
except ImportError:
    from xml.parsers.expat import ExpatError as ParseError
    ParseError  # For Pyflakes.


class Namespace(object):
    """
    XML namespace.

    Attribute access on `Namespace` instances will produce `QualifiedName`
    instances in this XML namespace. If `uri` is `None`, the names will be in
    the XML local namespace.

    :ivar __uri: XML namespace URI, or `None` for the local namespace.
    :ivar __prefix: XML namespace prefix, or `None` for no predefined prefix.
    """
    def __init__(self, uri, prefix=None):
        # We want to avoid polluting the instance dict as much as possible,
        # since attribute access is how we produce QualifiedNames.
        self.__uri = uri
        self.__prefix = prefix
        if self.__prefix is not None:
            register_namespace(self.__prefix, self.__uri)

    def __str__(self):
        return self.__uri

    def __repr__(self):
        return '<%s uri=%r prefix=%r>' % (
            type(self).__name__, self.__uri, self.__prefix)

    def __eq__(self, other):
        if not isinstance(other, Namespace):
            return False
        return other.__uri == self.__uri and other.__prefix == self.__prefix

    def __getattr__(self, tag):
        if self.__uri is None:
            qname = QualifiedName(tag)
        else:
            qname = QualifiedName(self.__uri, tag)
        # Put this into the instance dict, to avoid doing this again for the
        # same result.
        setattr(self, tag, qname)
        return qname
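
    # Note: because __getattr__ caches the QualifiedName in the instance
    # dict above, repeated attribute access returns the same object:
    #
    #     >>> NS = Namespace('http://example.com')
    #     >>> NS.tag is NS.tag
    #     True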


class QualifiedName(etree.QName, object):
    """
    A qualified XML name.

    As a convenience, calling a `QualifiedName` instance is the same as
    invoking `QualifiedName.element` on an instance.

    :ivar text: Qualified name in Clark's notation.
    """
    def __repr__(self):
        xmlns, local = split_qualified(self.text)
        return '<%s xmlns=%r local=%r>' % (type(self).__name__, xmlns, local)

    def __eq__(self, other):
        if not isinstance(other, etree.QName):
            return False
        return other.text == self.text

    def element(self, *children, **attrib):
        """
        Create an ElementTree element.

        The element tag name will be this qualified name's
        `QualifiedName.text` value.

        :param *children: Child content or elements.
        :param **attrib: Element XML attributes.
        :return: ElementTree element.
        """
        return Element(self.text, *children, **attrib)

    __call__ = element


class ElementMaker(object):
    """
    An ElementTree element factory.

    As a convenience, calling an `ElementMaker` instance is the same as
    invoking `ElementMaker.element` on an instance.
    """
    def __init__(self, typemap=None):
        """
        :param typemap:
            Mapping of Python types to callables, taking an ElementTree element
            and some child value. This map will be consulted in
            `ElementMaker.element` for child items.
        """
        self._makeelement = etree.Element
        self._typemap = {
            list: self._add_children,
            dict: self._set_attributes,
            unicode: self._add_text,
            str: self._add_text}
        if typemap is not None:
            self._typemap.update(typemap)

    def _add_children(self, elem, children):
        """
        Add children to an element.
        """
        for child in children:
            self._handle_child(elem, child)

    def _set_attributes(self, elem, attrib):
        """
        Set XML attributes on an element.

        :param elem: Parent ElementTree element.

        :param attrib:
            Mapping of text attribute names, or `xml.etree.ElementTree.QName`
            instances, to attribute values.
        """
        for k, v in attrib.items():
            # XXX: Do something smarter with k and v? lxml does some
            # transformation stuff.
            elem.set(k, v)

    def _add_text(self, elem, text):
        """
        Add text content to an element.

        :param elem: Parent ElementTree element.

        :param text: Text content to add.
        """
        # If the element has any children we need to add the text to the
        # tail.
        if len(elem):
            elem[-1].tail = (elem[-1].tail or '') + text
        else:
            elem.text = (elem.text or '') + text

    def _handle_child(self, parent, child):
        """
        Add a child element to a parent element.

        Child elements can be any of the following:

        * A callable, that will be called with no parameters;
        * An ElementTree element;
        * `str` or `unicode` text content;
        * A `list` containing any of the above.
        """
        if callable(child):
            child = child()
        t = self._typemap.get(type(child))
        if t is None:
            if etree.iselement(child):
                parent.append(child)
                return
            raise TypeError('Unknown child type: %r' % (child,))

        v = t(parent, child)
        if v is not None:
            self._handle_child(parent, v)

    def element(self, tag, *children, **attrib):
        """
        Create an ElementTree element.

        :param tag: Tag name or `QualifiedName` instance.
        :param *children: Child content or elements.
        :param **attrib: Element XML attributes.
        :return: ElementTree element.
        """
        if isinstance(tag, etree.QName):
            tag = tag.text

        elem = self._makeelement(tag)

        if attrib:
            self._set_attributes(elem, attrib)

        for child in children:
            self._handle_child(elem, child)

        return elem

    __call__ = element


def elemfind(elem, path):
    """
    Helper version of `xml.etree.ElementTree.Element.find` that understands
    `xml.etree.ElementTree.QName`.
    """
    return next(iter(elemfindall(elem, path)), None)


def elemfindall(elem, path):
    """
    Helper version of `xml.etree.ElementTree.Element.findall` that understands
    `xml.etree.ElementTree.QName`.
    """
    if isinstance(path, etree.QName):
        path = path.text
    return elem.findall(path)


def split_qualified(fqname):
    """
    Split a fully qualified element name, in Clark's notation, into its URI and
    local name components.

    :param fqname: Fully qualified name in Clark's notation.
    :return: 2-tuple containing the namespace URI and local tag name.
    """
    if fqname and fqname[0] == '{':
        return tuple(fqname[1:].split('}'))
    return None, fqname


def gettext(elem, path, default=None, parse=None):
    """
    Get the text of an `ElementTree` element and optionally transform it.

    If `default` and `parse` are not `None`, `parse` will be called with
    `default`.

    :param elem:
        ElementTree element to find `path` on.

    :type path: unicode
    :param path:
        Path to the sub-element.

    :param default:
        A default value to use if the `text` attribute on the found element
        is `None`, or the element is not found; defaults to `None`.

    :type  parse: callable
    :param parse:
        A callable to transform the found element's text.
    """
    return next(gettextall(elem, path, default, parse), default)


def gettextall(elem, path, default=None, parse=None):
    """
    Get the text of all matching `ElementTree` elements and optionally
    transform them.

    If `default` and `parse` are not `None`, `parse` will be called with
    `default`.

    :param elem:
        ElementTree element to find `path` on.

    :type path: unicode
    :param path:
        Path to the sub-elements.

    :param default:
        A default value to use if the `text` attribute found on an element
        is `None`, or an element is not found; defaults to `None`.

    :type  parse: callable
    :param parse:
        A callable to transform the found element's text.
    """
    es = elemfindall(elem, path)
    for e in es:
        if e is None or e.text is None:
            result = default
        else:
            result = unicode(e.text).strip()

        if result is not None and parse is not None:
            result = parse(result)

        yield result
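
# Illustrative usage of the two helpers above (using this module's
# fromstring):
#
#     >>> e = fromstring('<a><n>1</n><n>2</n></a>')
#     >>> gettext(e, 'n', parse=int)
#     1
#     >>> list(gettextall(e, 'n', parse=int))
#     [1, 2]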


def element_to_dict(root):
    """
    Convert an ElementTree element into a dictionary structure.

    Text content is stored against a special key, ``#text``, unless the element
    contains only text and no attributes.

    Attributes are converted into dictionaries of the attribute name, prefixed
    with ``@``, keyed against the attribute value, which are keyed against the
    root element's name.

    Child elements are recursively turned into dictionaries. Child elements
    with the same name are coalesced into a ``list``.

    :param root: ElementTree element root to convert into a ``dict``.
    :return: ``dict`` representation of `root`.
    """
    d = {root.tag: {} if root.attrib else None}
    children = root.getchildren()
    if children:
        dd = defaultdict(list)
        for child_dict in map(element_to_dict, children):
            for k, v in child_dict.iteritems():
                dd[k].append(v)
        d = {root.tag: dict((k, v[0] if len(v) == 1 else v)
                            for k, v in dd.iteritems())}

    if root.attrib:
        d[root.tag].update(
            ('@' + str(k), v) for k, v in root.attrib.iteritems())

    if root.text:
        text = root.text.strip()
        if children or root.attrib:
            if text:
                d[root.tag]['#text'] = text
        else:
            d[root.tag] = text

    return d
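
# For example (dict key order may vary):
#
#     >>> element_to_dict(fromstring(
#     ...     '<root attr="1"><child>a</child><child>b</child></root>'))
#     {'root': {'@attr': '1', 'child': ['a', 'b']}}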


LocalNamespace = Namespace(None)
Element = ElementMaker()
parse_document = etree.parse
fromstring = etree.fromstring
tostring = etree.tostring


__all__ = [
    'Namespace', 'QualifiedName', 'ElementMaker', 'elemfind', 'elemfindall',
    'split_qualified', 'gettext', 'gettextall', 'LocalNamespace', 'Element',
    'parse_document', 'fromstring', 'tostring', 'element_to_dict']

vumi/transports/parlayx/client.py
# -*- test-case-name: vumi.transports.parlayx.tests.test_client -*-
import hashlib
import uuid
from collections import namedtuple
from datetime import datetime

from vumi.transports.parlayx.soaputil import perform_soap_request, SoapFault
from vumi.transports.parlayx.xmlutil import (
    gettext, gettextall, Namespace, LocalNamespace as L)


PARLAYX_COMMON_NS = Namespace(
    'http://www.csapi.org/schema/parlayx/common/v2_1', 'parlayx_common')
SEND_NS = Namespace(
    'http://www.csapi.org/schema/parlayx/sms/send/v2_2/local', 'send')
NOTIFICATION_MANAGER_NS = Namespace(
    'http://www.csapi.org/schema/parlayx/sms/notification_manager/v2_3/local',
    'nm')
PARLAYX_HEAD_NS = Namespace(
    'http://www.huawei.com.cn/schema/common/v2_1', 'parlayx_head')


def format_address(msisdn):
    """
    Format a normalized MSISDN as a URI that ParlayX will accept.
    """
    if not msisdn.startswith('+'):
        raise ValueError('Only international format addresses are supported')
    return 'tel:' + msisdn[1:]


def format_timestamp(when):
    """
    Format a `datetime` instance timestamp according to ParlayX
    requirements.
    """
    return when.strftime('%Y%m%d%H%M%S')


def make_password(service_provider_id, service_provider_password,
                   timestamp):
    """
    Build a time-sensitive password for a request.
    """
    return hashlib.md5(
        service_provider_id +
        service_provider_password +
        timestamp).hexdigest()
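
# Worked example for the three helpers above (values are illustrative):
#
#     >>> format_address('+27117654321')
#     'tel:27117654321'
#     >>> format_timestamp(datetime(2013, 1, 2, 3, 4, 5))
#     '20130102030405'
#     >>> make_password('sp', 'secret', '20130102030405') == hashlib.md5(
#     ...     'sp' + 'secret' + '20130102030405').hexdigest()
#     True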


class _ParlayXFaultDetail(namedtuple('_ParlayXFaultDetail',
                                   ['message_id', 'text', 'variables'])):
    """
    Generic ParlayX SOAP fault detail.
    """
    tag = None

    @classmethod
    def from_element(cls, element):
        if element.tag != cls.tag.text:
            return None
        return cls(
            message_id=gettext(element, 'messageId'),
            text=gettext(element, 'text'),
            variables=list(gettextall(element, 'variables')))


class ServiceExceptionDetail(_ParlayXFaultDetail):
    """
    ParlayX service exception detail.
    """
    tag = PARLAYX_COMMON_NS.ServiceExceptionDetail


class ServiceException(SoapFault):
    """
    ParlayX service exception.
    """
    detail_type = ServiceExceptionDetail


class PolicyExceptionDetail(_ParlayXFaultDetail):
    """
    ParlayX policy exception detail.
    """
    tag = PARLAYX_COMMON_NS.PolicyExceptionDetail


class PolicyException(SoapFault):
    """
    ParlayX policy exception.
    """
    detail_type = PolicyExceptionDetail


class ParlayXClient(object):
    """
    ParlayX SOAP client.

    :ivar _service_correlator:
        A unique identifier for this service, used when registering and
        deregistering for SMS notifications.
    """
    def __init__(self, service_provider_service_id, service_provider_id,
                 service_provider_password, short_code, endpoint, send_uri,
                 notification_uri, perform_soap_request=perform_soap_request):
        """
        :param service_provider_service_id:
            Provisioned service provider service identifier.
        :param service_provider_id:
            Provisioned service provider identifier/username.
        :param service_provider_password:
            Provisioned service provider password.
        :param short_code:
            SMS shortcode or service activation number.
        :param endpoint:
            URI to which the remote ParlayX service will deliver notification
            messages.
        :param send_uri:
            URI for the ParlayX ``SendSmsService`` SOAP endpoint.
        :param notification_uri:
            URI for the ParlayX ``SmsNotificationService`` SOAP endpoint.
        """
        self.service_provider_service_id = service_provider_service_id
        self.service_provider_id = service_provider_id
        self.service_provider_password = service_provider_password
        self.short_code = short_code
        self.endpoint = endpoint
        self.send_uri = send_uri
        self.notification_uri = notification_uri
        self.perform_soap_request = perform_soap_request
        self._service_correlator = uuid.uuid4().hex

    def _now(self):
        """
        The current date and time.
        """
        return datetime.now()

    def _make_header(self, service_subscription_address=None, linkid=None):
        """
        Create a ``RequestSOAPHeader`` element.

        :param service_subscription_address:
            Service subscription address for the ``OA`` header field; this
            field is omitted if its value is ``None``.
        """
        NS = PARLAYX_HEAD_NS
        other = []
        timestamp = format_timestamp(self._now())
        if service_subscription_address is not None:
            other.append(NS.OA(format_address(service_subscription_address)))
        if linkid is not None:
            other.append(NS.linkid(linkid))
        return NS.RequestSOAPHeader(
            NS.spId(self.service_provider_id),
            NS.spPassword(
                make_password(
                    self.service_provider_id,
                    self.service_provider_password,
                    timestamp)),
            NS.serviceId(self.service_provider_service_id),
            NS.timeStamp(timestamp),
            *other)

    def start_sms_notification(self):
        """
        Register a notification delivery endpoint with the remote ParlayX
        service.
        """
        body = NOTIFICATION_MANAGER_NS.startSmsNotification(
            NOTIFICATION_MANAGER_NS.reference(
                L.endpoint(self.endpoint),
                L.interfaceName('notifySmsReception'),
                L.correlator(self._service_correlator)),
            NOTIFICATION_MANAGER_NS.smsServiceActivationNumber(
                self.short_code))
        header = self._make_header()
        return self.perform_soap_request(
            uri=self.notification_uri,
            action='',
            body=body,
            header=header,
            expected_faults=[ServiceException])

    def stop_sms_notification(self):
        """
        Deregister notification delivery with the remote ParlayX service.
        """
        body = NOTIFICATION_MANAGER_NS.stopSmsNotification(
            L.correlator(self._service_correlator))
        header = self._make_header()
        return self.perform_soap_request(
            uri=self.notification_uri,
            action='',
            body=body,
            header=header,
            expected_faults=[ServiceException])

    def send_sms(self, to_addr, content, message_id, linkid=None):
        """
        Send an SMS.
        """
        def _extractRequestIdentifier((body, header)):
            return gettext(body, './/' + str(SEND_NS.result), default='')

        body = SEND_NS.sendSms(
            SEND_NS.addresses(format_address(to_addr)),
            SEND_NS.message(content),
            SEND_NS.receiptRequest(
                L.endpoint(self.endpoint),
                L.interfaceName(u'SmsNotification'),
                L.correlator(message_id)))
        header = self._make_header(
            service_subscription_address=to_addr,
            linkid=linkid)
        d = self.perform_soap_request(
            uri=self.send_uri,
            action='',
            body=body,
            header=header,
            expected_faults=[PolicyException, ServiceException])
        d.addCallback(_extractRequestIdentifier)
        return d


__all__ = [
    'format_address', 'ServiceExceptionDetail', 'ServiceException',
    'PolicyExceptionDetail', 'PolicyException', 'ParlayXClient']

vumi/transports/parlayx/__init__.py
"""
ParlayX SOAP API.
"""
from vumi.transports.parlayx.parlayx import ParlayXTransport


__all__ = ['ParlayXTransport']

vumi/transports/parlayx/server.py
# -*- test-case-name: vumi.transports.parlayx.tests.test_server -*-
import iso8601
from collections import namedtuple

from twisted.internet.defer import maybeDeferred, fail
from twisted.python import log
from twisted.python.constants import Values, ValueConstant
from twisted.web import http
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET

from vumi.transports.parlayx.client import PARLAYX_COMMON_NS
from vumi.transports.parlayx.soaputil import (
    soap_envelope, unwrap_soap_envelope, soap_fault, SoapFault)
from vumi.transports.parlayx.xmlutil import (
    Namespace, elemfind, gettext, split_qualified, parse_document, tostring)
from vumi.utils import normalize_msisdn


NOTIFICATION_NS = Namespace(
    'http://www.csapi.org/schema/parlayx/sms/notification/v2_2/local', 'loc')


def normalize_address(address):
    """
    Normalize a ParlayX address.
    """
    if address.startswith('tel:'):
        address = address[4:]
    return normalize_msisdn(address)
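
# e.g. normalize_address('tel:27117654321') should yield '+27117654321'
# (delegating to vumi.utils.normalize_msisdn for the leading '+').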


class DeliveryStatus(Values):
    """
    ParlayX `DeliveryStatus` enumeration type.
    """
    DeliveredToNetwork = ValueConstant('delivered')
    DeliveryUncertain = ValueConstant('pending')
    DeliveryImpossible = ValueConstant('failed')
    MessageWaiting = ValueConstant('pending')
    DeliveredToTerminal = ValueConstant('delivered')
    DeliveryNotificationNotSupported = ValueConstant('failed')
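
    # Remote statuses are looked up by name and collapsed to vumi's three
    # delivery states, e.g.:
    #
    #     >>> DeliveryStatus.lookupByName('DeliveredToTerminal').value
    #     'delivered'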


class SmsMessage(namedtuple('SmsMessage',
                            ['message', 'sender_address',
                             'service_activation_number', 'timestamp'])):
    """
    ParlayX `SmsMessage` complex type.
    """
    @classmethod
    def from_element(cls, root):
        """
        Create an `SmsMessage` instance from an ElementTree element.
        """
        return cls(
            message=gettext(root, 'message'),
            sender_address=gettext(
                root, 'senderAddress', parse=normalize_address),
            service_activation_number=gettext(
                root, 'smsServiceActivationNumber', parse=normalize_address),
            timestamp=gettext(root, 'dateTime', parse=iso8601.parse_date))


class DeliveryInformation(namedtuple('DeliveryInformation',
                                     ['address', 'delivery_status'])):
    """
    ParlayX `DeliveryInformation` complex type.
    """
    @classmethod
    def from_element(cls, root):
        """
        Create a `DeliveryInformation` instance from an ElementTree element.
        """
        try:
            delivery_status = gettext(
                root, 'deliveryStatus', parse=DeliveryStatus.lookupByName)
        except ValueError, e:
            raise ValueError(
                'No such delivery status enumeration value: %r' % (str(e),))
        else:
            return cls(
                address=gettext(root, 'address', parse=normalize_address),
                delivery_status=delivery_status)


class SmsNotificationService(Resource):
    """
    Web resource to handle SOAP requests for ParlayX SMS deliveries and
    delivery receipts.
    """
    isLeaf = True

    def __init__(self, callback_message_received, callback_message_delivered):
        self.callback_message_received = callback_message_received
        self.callback_message_delivered = callback_message_delivered
        Resource.__init__(self)

    def render_POST(self, request):
        """
        Process a SOAP request and convert any exceptions into SOAP faults.
        """
        def _writeResponse(response):
            request.setHeader('Content-Type', 'text/xml; charset="utf-8"')
            request.write(tostring(soap_envelope(response)))
            request.finish()

        def _handleSuccess(result):
            request.setResponseCode(http.OK)
            return result

        def _handleError(f):
            # XXX: Perhaps report this back to the transport somehow???
            log.err(f, 'Failure processing SOAP request')
            request.setResponseCode(http.INTERNAL_SERVER_ERROR)
            faultcode = u'soapenv:Server'
            if f.check(SoapFault):
                return f.value.to_element()
            return soap_fault(faultcode, f.getErrorMessage())

        try:
            tree = parse_document(request.content)
            body, header = unwrap_soap_envelope(tree)
        except:
            d = fail()
        else:
            d = maybeDeferred(self.process, request, body, header)
            d.addCallback(_handleSuccess)

        d.addErrback(_handleError)
        d.addCallback(_writeResponse)
        return NOT_DONE_YET

    def process(self, request, body, header=None):
        """
        Process a SOAP request.
        """
        for child in body.getchildren():
            # Since there is no SOAPAction header, and these requests are not
            # made to different endpoints, the only way to handle these is to
            # switch on the root element's name. Yuck.
            localname = split_qualified(child.tag)[1]
            meth = getattr(self, 'process_' + localname, self.process_unknown)
            return meth(child, header, localname)

        raise SoapFault(u'soapenv:Client', u'No actionable items')

    def process_unknown(self, root, header, name):
        """
        Process unknown notification deliverables.
        """
        raise SoapFault(u'soapenv:Server', u'No handler for %s' % (name,))

    def process_notifySmsReception(self, root, header, name):
        """
        Process a received text message.
        """
        linkid = None
        if header is not None:
            linkid = gettext(header, './/' + str(PARLAYX_COMMON_NS.linkid))

        correlator = gettext(root, NOTIFICATION_NS.correlator)
        message = SmsMessage.from_element(
            elemfind(root, NOTIFICATION_NS.message))
        d = maybeDeferred(
            self.callback_message_received, correlator, linkid, message)
        d.addCallback(
            lambda ignored: NOTIFICATION_NS.notifySmsReceptionResponse())
        return d

    def process_notifySmsDeliveryReceipt(self, root, header, name):
        """
        Process a text message delivery receipt.
        """
        correlator = gettext(root, NOTIFICATION_NS.correlator)
        delivery_info = DeliveryInformation.from_element(
            elemfind(root, NOTIFICATION_NS.deliveryStatus))
        d = maybeDeferred(self.callback_message_delivered,
            correlator, delivery_info.delivery_status.value)
        d.addCallback(
            lambda ignored: NOTIFICATION_NS.notifySmsDeliveryReceiptResponse())
        return d


# XXX: Only used for debugging with SoapUI:
# twistd web --class=vumi.transports.parlayx.server.Root --port=9080
class Root(Resource):
    def getChild(self, path, request):
        from twisted.internet.defer import succeed
        noop = lambda *a, **kw: succeed(None)
        if request.postpath == ['services', 'SmsNotification']:
            return SmsNotificationService(noop, noop)
        return None


__all__ = [
    'normalize_address', 'DeliveryStatus', 'SmsMessage', 'DeliveryInformation',
    'SmsNotificationService']

vumi/transports/parlayx/soaputil.py
# -*- test-case-name: vumi.transports.parlayx.tests.test_soaputil -*-
"""
Utility functions for performing and processing SOAP requests, constructing
SOAP responses and parsing and constructing SOAP faults.
"""
from twisted.web import http
from twisted.python import log

from vumi.utils import http_request_full
from vumi.transports.parlayx.xmlutil import (
    Namespace, LocalNamespace, elemfind, gettext, fromstring, tostring)


SOAP_ENV = Namespace('http://schemas.xmlsoap.org/soap/envelope/', 'soapenv')


def perform_soap_request(uri, action, body, header=None,
                         expected_faults=None,
                         http_request_full=http_request_full):
    """
    Perform a SOAP request.

    If the remote server responds with an HTTP 500 status, then it is assumed
    that the body contains a SOAP fault, which is then parsed and a `SoapFault`
    exception raised.

    :param uri: SOAP endpoint URI.
    :param action: SOAP action.
    :param body:
        SOAP body that will appear in an envelope ``Body`` element.
    :param header:
        SOAP header that will appear in an envelope ``Header`` element, or
        ``None`` to omit the header.
    :param expected_faults:
        A `list` of `SoapFault` subclasses to be used to extract fault details
        from SOAP faults.
    :param http_request_full:
        Callable to perform an HTTP request, see
        `vumi.utils.http_request_full`.
    :return:
        `Deferred` that fires with the response, in the case of success, or
        a `SoapFault` in the case of failure.
    """
    def _parse_soap_response(response):
        root = fromstring(response.delivered_body)
        body, header = unwrap_soap_envelope(root)
        if response.code == http.INTERNAL_SERVER_ERROR:
            raise SoapFault.from_element(body, expected_faults)
        return body, header

    envelope = soap_envelope(body, header)
    headers = {
        'SOAPAction': action,
        'Content-Type': 'text/xml; charset="utf-8"'}
    d = http_request_full(uri, tostring(envelope), headers)
    d.addCallback(_parse_soap_response)
    return d


def soap_envelope(body, header=None):
    """
    Wrap an element or text in a SOAP envelope.
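
    For example (a sketch; the body element is illustrative)::

        envelope = soap_envelope(LocalNamespace.hello(u'world'))
        xml = tostring(envelope)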
    """
    parts = [SOAP_ENV.Body(body)]
    if header is not None:
        parts.insert(0, SOAP_ENV.Header(header))
    return SOAP_ENV.Envelope(*parts)


def unwrap_soap_envelope(root):
    """
    Unwrap a SOAP request and return the SOAP body and header elements.
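
    For example (a sketch; ``xml_text`` is assumed to contain a serialized
    SOAP envelope)::

        body, header = unwrap_soap_envelope(fromstring(xml_text))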
    """
    header = elemfind(root, SOAP_ENV.Header)
    body = elemfind(root, SOAP_ENV.Body)
    if body is None:
        raise SoapFault(u'soapenv:Client', u'Malformed SOAP request')
    return body, header


def soap_fault(faultcode, faultstring=None, faultactor=None, detail=None):
    """
    Create a SOAP fault response.
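
    For example (a sketch mirroring the client fault raised elsewhere in this
    module)::

        fault = soap_fault(u'soapenv:Client', u'Malformed SOAP request')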
    """
    def _maybe(f, value):
        if value is not None:
            return f(value)
        return None

    xs = [
        LocalNamespace.faultcode(faultcode),
        _maybe(LocalNamespace.faultstring, faultstring),
        _maybe(LocalNamespace.faultactor, faultactor),
        _maybe(LocalNamespace.detail, detail)]
    # filter(None, xs) doesn't do what we want because of weird implicit
    # truthiness with ElementTree elements.
    return SOAP_ENV.Fault(*[x for x in xs if x is not None])


def _parse_expected_faults(detail, expected_faults):
    """
    Parse expected SOAP faults from a SOAP fault ``detail`` element.

    :param detail:
        ElementTree element containing SOAP fault detail elements to match
        against the expected SOAP faults.
    :param expected_faults:
        A `list` of `SoapFault` subclasses whose ``detail_type`` attribute will
        be used to determine a match against each top-level SOAP fault detail
        element.
    :return:
        A 2-tuple of the matching exception type and an instance of its
        ``detail_type``, or ``None`` if there were no matches.
    """
    if detail is None:
        return None

    for child in detail.getchildren():
        for exc_type in expected_faults:
            try:
                if exc_type.detail_type is not None:
                    det = exc_type.detail_type.from_element(child)
                    if det is not None:
                        return exc_type, det
            except Exception:
                log.err(
                    None, 'Error parsing SOAP fault element (%r) with %r' % (
                        child, exc_type.detail_type))

    return None


def parse_soap_fault(body, expected_faults=None):
    """
    Parse a SOAP fault element and its details.

    :param body: SOAP ``Body`` element.
    :param expected_faults:
        A `list` of `SoapFault` subclasses whose ``detail_type`` attribute will
        be used to determine a match against each top-level SOAP fault detail
        element.
    :return:
        A 2-tuple of: the matching exception type and an instance of its
        ``detail_type``; and SOAP fault information (code, string, actor,
        detail). ``None`` if there is no SOAP fault.
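
    For example (a sketch; ``body`` is an unwrapped SOAP ``Body`` element)::

        faultinfo = parse_soap_fault(body)
        if faultinfo is not None:
            parsed, (code, string, actor, detail) = faultinfo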
    """
    fault = elemfind(body, SOAP_ENV.Fault)
    if fault is None:
        return None
    faultcode = gettext(fault, u'faultcode')
    faultstring = gettext(fault, u'faultstring')
    faultactor = gettext(fault, u'faultactor')
    detail = elemfind(fault, u'detail')

    if expected_faults is None:
        expected_faults = []
    parsed = _parse_expected_faults(detail, expected_faults)
    return parsed, (faultcode, faultstring, faultactor, detail)


class SoapFault(RuntimeError):
    """
    An exception that constitutes a SOAP fault.

    :cvar detail_type:
        A type with a ``from_element`` callable that takes a top-level detail
        element and attempts to parse it, returning an instance of itself if
        successful or ``None`` otherwise.

    :ivar code: SOAP fault code.
    :ivar string: SOAP fault string.
    :ivar actor: SOAP fault actor.
    :ivar detail: SOAP fault detail ElementTree element.
    :ivar parsed_detail: ``detail_type`` instance, or ``None``.
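
    A subclass sketch, assuming a hypothetical ``MyDetail`` type that
    implements ``from_element``::

        class MyFault(SoapFault):
            detail_type = MyDetail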
    """
    detail_type = None

    def __init__(self, code, string, actor=None, detail=None,
                 parsed_detail=None):
        self.code = code
        self.string = string
        self.actor = actor
        self.detail = detail
        self.parsed_detail = parsed_detail
        RuntimeError.__init__(self, string)

    def __repr__(self):
        return '<%s code=%r string=%r actor=%r parsed_detail=%r>' % (
            type(self).__name__,
            self.code,
            self.string,
            self.actor,
            self.parsed_detail)

    @classmethod
    def from_element(cls, root, expected_faults=None):
        """
        Parse a SOAP fault from an ElementTree element.

        :param expected_faults:
            A `list` of `SoapFault` subclasses whose ``detail_type`` attribute
            will be used to determine a match against each top-level SOAP fault
            detail element.
        :return: An instance of `SoapFault` or one of its subclasses.
        """
        faultinfo = parse_soap_fault(root, expected_faults)
        if faultinfo is None:
            raise ValueError(
                'Element (%r) does not contain a SOAP fault' % (root,))

        parsed_fault, faultinfo = faultinfo
        if parsed_fault is None:
            parsed_fault = SoapFault, None
        exc_type, parsed_detail = parsed_fault

        faultcode, faultstring, faultactor, detail = faultinfo
        return exc_type(
            faultcode, faultstring, faultactor, detail, parsed_detail)

    def to_element(self):
        """
        Serialize this SOAP fault to an ElementTree element.
        """
        detail = self.detail
        if detail is not None:
            detail = self.detail.getchildren()
        return soap_fault(
            self.code, self.string, self.actor, detail)


__all__ = [
    'perform_soap_request', 'soap_envelope', 'unwrap_soap_envelope',
    'soap_fault', 'parse_soap_fault', 'SoapFault']
PK=JGM]GG&vumi/transports/parlayx/tests/utils.pyfrom collections import namedtuple

from twisted.python import failure

from vumi.transports.parlayx.client import format_address
from vumi.transports.parlayx.soaputil import soap_envelope
from vumi.transports.parlayx.server import NOTIFICATION_NS, normalize_address
from vumi.transports.parlayx.xmlutil import LocalNamespace as L, tostring


# XXX: This can be deleted as soon as we're using Twisted > 13.0.0.
class _FailureResultOfMixin(object):
    """
    Mixin providing the more recent and useful version of
    `TestCase.failureResultOf`.
    """
    def failureResultOf(self, deferred, *expectedExceptionTypes):
        """
        Return the current failure result of C{deferred} or raise
        C{self.failureException}.

        @param deferred: A L{Deferred} which
            has a failure result.  This means
            L{Deferred.callback} or
            L{Deferred.errback} has
            been called on it and it has reached the end of its callback chain
            and the last callback or errback raised an exception or returned a
            L{failure.Failure}.
        @type deferred: L{Deferred}

        @param expectedExceptionTypes: Exception types to expect - if
            provided, and the exception wrapped by the failure result is
            not one of the types provided, then this test will fail.

        @raise SynchronousTestCase.failureException: If the
            L{Deferred} has no result, has a
            success result, or has an unexpected failure result.

        @return: The failure result of C{deferred}.
        @rtype: L{failure.Failure}
        """
        result = []
        deferred.addBoth(result.append)
        if not result:
            self.fail(
                "Failure result expected on %r, found no result instead" % (
                    deferred,))
        elif not isinstance(result[0], failure.Failure):
            self.fail(
                "Failure result expected on %r, "
                "found success result (%r) instead" % (deferred, result[0]))
        elif (expectedExceptionTypes and
              not result[0].check(*expectedExceptionTypes)):
            expectedString = " or ".join([
                '.'.join((t.__module__, t.__name__)) for t in
                expectedExceptionTypes])

            self.fail(
                "Failure of type (%s) expected on %r, "
                "found type %r instead: %s" % (
                    expectedString, deferred, result[0].type,
                    result[0].getTraceback()))
        else:
            return result[0]


class MockResponse(namedtuple('MockResponse', ['code', 'delivered_body'])):
    """
    Mock response from ``http_request_full``.
    """
    @classmethod
    def build(cls, code, body, header=None):
        """
        Build a `MockResponse` containing a SOAP envelope.
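
        For example (a sketch; the response element is illustrative)::

            response = MockResponse.build(200, L.sendSmsResponse())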
        """
        return cls(
            code=code,
            delivered_body=tostring(soap_envelope(body, header)))


def create_sms_reception_element(correlator, message, sender_address,
                                 service_activation_number):
    """
    Helper for creating a ``notifySmsReception`` element.
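
    For example (argument values are illustrative)::

        elem = create_sms_reception_element(
            '1234', 'message', '+27117654321', '54321')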
    """
    return NOTIFICATION_NS.notifySmsReception(
        NOTIFICATION_NS.correlator(correlator),
        NOTIFICATION_NS.message(
            L.message(message),
            L.senderAddress(format_address(normalize_address(sender_address))),
            L.smsServiceActivationNumber(service_activation_number)))


def create_sms_delivery_receipt(correlator, address, delivery_status):
    """
    Helper for creating a ``notifySmsDeliveryReceipt`` element.
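
    For example (argument values are illustrative; `DeliveryStatus` comes
    from ``vumi.transports.parlayx.server``)::

        elem = create_sms_delivery_receipt(
            '1234', '+27117654321', DeliveryStatus.DeliveredToNetwork)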
    """
    return NOTIFICATION_NS.notifySmsDeliveryReceipt(
        NOTIFICATION_NS.correlator(correlator),
        NOTIFICATION_NS.deliveryStatus(
            L.address(format_address(normalize_address(address))),
            L.deliveryStatus(delivery_status.name)))


__all__ = [
    'MockResponse', 'create_sms_reception_element',
    'create_sms_delivery_receipt']
vumi/transports/parlayx/tests/test_client.py
from datetime import datetime
from functools import partial

from twisted.internet.defer import succeed
from twisted.trial.unittest import TestCase
from twisted.web import http

from vumi.transports.parlayx.client import (
    PARLAYX_COMMON_NS, PARLAYX_HEAD_NS, NOTIFICATION_MANAGER_NS, SEND_NS,
    format_address, ServiceExceptionDetail, ServiceException,
    PolicyExceptionDetail, PolicyException, ParlayXClient, format_timestamp,
    make_password)
from vumi.transports.parlayx.soaputil import (
    perform_soap_request, unwrap_soap_envelope, soap_fault)
from vumi.transports.parlayx.xmlutil import (
    LocalNamespace as L, elemfind, fromstring, element_to_dict)
from vumi.transports.parlayx.tests.utils import (
    MockResponse, _FailureResultOfMixin)


class FormatAddressTests(TestCase):
    """
    Tests for `vumi.transports.parlayx.client.format_address`.
    """
    def test_invalid(self):
        """
        `format_address` raises ``ValueError`` for invalid MSISDNs.
        """
        self.assertRaises(ValueError, format_address, '12345')
        self.assertRaises(ValueError, format_address, 'nope')

    def test_format(self):
        """
        `format_address` formats MSISDNs in a way that ParlayX services will
        accept.
        """
        self.assertEqual(
            'tel:27117654321', format_address('+27117654321'))
        self.assertEqual(
            'tel:264117654321', format_address('+264117654321'))


class FormatTimestampTests(TestCase):
    """
    Tests for `vumi.transports.parlayx.client.format_timestamp`.
    """
    def test_format(self):
        """
        Format a `datetime` instance timestamp according to ParlayX
        requirements.
        """
        self.assertEqual(
            '20130618105933',
            format_timestamp(datetime(2013, 6, 18, 10, 59, 33)))


class MakePasswordTests(TestCase):
    """
    Tests for `vumi.transports.parlayx.client.make_password`.
    """
    def test_make_password(self):
        """
        Build a time-sensitive password for a request.
        """
        timestamp = format_timestamp(datetime(2013, 6, 18, 10, 59, 33))
        self.assertEqual(
            '1f2e67e642b16f6623459fa76dc3894f',
            make_password('user', 'password', timestamp))


class ServiceExceptionDetailTests(TestCase):
    """
    Tests for `vumi.transports.parlayx.client.ServiceExceptionDetail`.
    """
    def test_unmatched(self):
        """
        `ServiceExceptionDetail.from_element` returns ``None`` if the element's
        tag is not a service exception detail.
        """
        elem = L.WhatIsThis(
            L.foo('a'),
            L.bar('b'))
        self.assertIdentical(None, ServiceExceptionDetail.from_element(elem))

    def test_from_element(self):
        """
        `ServiceExceptionDetail.from_element` returns
        a `ServiceExceptionDetail` instance by parsing
        a ``ServiceExceptionDetail`` detail element.
        """
        elem = PARLAYX_COMMON_NS.ServiceExceptionDetail(
            L.messageId('a'),
            L.text('b'),
            L.variables('c'),
            L.variables('d'))
        detail = ServiceExceptionDetail.from_element(elem)
        self.assertEqual(
            ('a', 'b', ['c', 'd']),
            (detail.message_id, detail.text, detail.variables))


class PolicyExceptionDetailTests(TestCase):
    """
    Tests for `vumi.transports.parlayx.client.PolicyExceptionDetail`.
    """
    def test_unmatched(self):
        """
        `PolicyExceptionDetail.from_element` returns ``None`` if the element's
        tag is not a policy exception detail.
        """
        elem = L.WhatIsThis(
            L.foo('a'),
            L.bar('b'))
        self.assertIdentical(None, PolicyExceptionDetail.from_element(elem))

    def test_from_element(self):
        """
        `PolicyExceptionDetail.from_element` returns a `PolicyExceptionDetail`
        instance by parsing a ``PolicyExceptionDetail`` detail element.
        """
        elem = PARLAYX_COMMON_NS.PolicyExceptionDetail(
            L.messageId('a'),
            L.text('b'),
            L.variables('c'),
            L.variables('d'))
        detail = PolicyExceptionDetail.from_element(elem)
        self.assertEqual(
            ('a', 'b', ['c', 'd']),
            (detail.message_id, detail.text, detail.variables))


class ParlayXClientTests(_FailureResultOfMixin, TestCase):
    """
    Tests for `vumi.transports.parlayx.client.ParlayXClient`.
    """
    def setUp(self):
        self.requests = []

    def _http_request_full(self, response, uri, body, headers):
        """
        A mock for `vumi.utils.http_request_full`.

        Store an HTTP request's information and return a canned response.
        """
        self.requests.append((uri, body, headers))
        return succeed(response)

    def _perform_soap_request(self, response, *a, **kw):
        """
        Perform a SOAP request with a canned response.
        """
        return perform_soap_request(
            http_request_full=partial(
                self._http_request_full, response), *a, **kw)

    def _make_client(self, response=''):
        """
        Create a `ParlayXClient` instance that uses a stubbed
        `perform_soap_request` function.
        """
        return ParlayXClient(
            'service_id', 'user', 'password', 'short', 'endpoint', 'send',
            'notification',
            perform_soap_request=partial(self._perform_soap_request, response))

    def test_start_sms_notification(self):
        """
        `ParlayXClient.start_sms_notification` performs a SOAP request to the
        remote ParlayX notification endpoint indicating where delivery and
        receipt notifications for a particular service activation number can be
        delivered.
        """
        client = self._make_client(
            MockResponse.build(
                http.OK, NOTIFICATION_MANAGER_NS.startSmsNotificationResponse))
        client._now = partial(datetime, 2013, 6, 18, 10, 59, 33)
        self.successResultOf(client.start_sms_notification())
        self.assertEqual(1, len(self.requests))
        self.assertEqual('notification', self.requests[0][0])
        body, header = unwrap_soap_envelope(fromstring(self.requests[0][1]))
        self.assertEqual(
            {str(NOTIFICATION_MANAGER_NS.startSmsNotification): {
                str(NOTIFICATION_MANAGER_NS.reference): {
                    'correlator': client._service_correlator,
                    'endpoint': 'endpoint',
                    'interfaceName': 'notifySmsReception'},
                str(NOTIFICATION_MANAGER_NS.smsServiceActivationNumber):
                    'short'}},
            element_to_dict(
                elemfind(body, NOTIFICATION_MANAGER_NS.startSmsNotification)))
        self.assertEqual(
            {str(PARLAYX_HEAD_NS.RequestSOAPHeader): {
                str(PARLAYX_HEAD_NS.serviceId): 'service_id',
                str(PARLAYX_HEAD_NS.spId): 'user',
                str(PARLAYX_HEAD_NS.spPassword):
                    '1f2e67e642b16f6623459fa76dc3894f',
                str(PARLAYX_HEAD_NS.timeStamp): '20130618105933'}},
            element_to_dict(
                elemfind(header, PARLAYX_HEAD_NS.RequestSOAPHeader)))

    def test_start_sms_notification_service_fault(self):
        """
        `ParlayXClient.start_sms_notification` expects `ServiceExceptionDetail`
        fault details in SOAP requests that fail for remote service-related
        reasons.
        """
        detail = PARLAYX_COMMON_NS.ServiceExceptionDetail(
            L.messageId('a'),
            L.text('b'),
            L.variables('c'),
            L.variables('d'))
        client = self._make_client(
            MockResponse.build(
                http.INTERNAL_SERVER_ERROR,
                soap_fault('soapenv:Server', 'Whoops', detail=detail)))
        f = self.failureResultOf(
            client.start_sms_notification(), ServiceException)
        detail = f.value.parsed_detail
        self.assertEqual(
            ('a', 'b', ['c', 'd']),
            (detail.message_id, detail.text, detail.variables))

    def test_stop_sms_notification(self):
        """
        `ParlayXClient.stop_sms_notification` performs a SOAP request to the
        remote ParlayX notification endpoint indicating that delivery and
        receipt notifications for a particular service activation number
        should be deactivated.
        """
        client = self._make_client(
            MockResponse.build(
                http.OK, NOTIFICATION_MANAGER_NS.stopSmsNotificationResponse))
        client._now = partial(datetime, 2013, 6, 18, 10, 59, 33)
        self.successResultOf(client.stop_sms_notification())
        self.assertEqual(1, len(self.requests))
        self.assertEqual('notification', self.requests[0][0])
        body, header = unwrap_soap_envelope(fromstring(self.requests[0][1]))
        self.assertEqual(
            {str(NOTIFICATION_MANAGER_NS.stopSmsNotification): {
                'correlator': client._service_correlator}},
            element_to_dict(
                elemfind(body, NOTIFICATION_MANAGER_NS.stopSmsNotification)))
        self.assertEqual(
            {str(PARLAYX_HEAD_NS.RequestSOAPHeader): {
                str(PARLAYX_HEAD_NS.serviceId): 'service_id',
                str(PARLAYX_HEAD_NS.spId): 'user',
                str(PARLAYX_HEAD_NS.spPassword):
                    '1f2e67e642b16f6623459fa76dc3894f',
                str(PARLAYX_HEAD_NS.timeStamp): '20130618105933'}},
            element_to_dict(
                elemfind(header, PARLAYX_HEAD_NS.RequestSOAPHeader)))

    def test_stop_sms_notification_service_fault(self):
        """
        `ParlayXClient.stop_sms_notification` expects `ServiceExceptionDetail`
        fault details in SOAP requests that fail for remote service-related
        reasons.
        """
        detail = PARLAYX_COMMON_NS.ServiceExceptionDetail(
            L.messageId('a'),
            L.text('b'),
            L.variables('c'),
            L.variables('d'))
        client = self._make_client(
            MockResponse.build(
                http.INTERNAL_SERVER_ERROR,
                soap_fault('soapenv:Server', 'Whoops', detail=detail)))
        f = self.failureResultOf(
            client.stop_sms_notification(), ServiceException)
        detail = f.value.parsed_detail
        self.assertEqual(
            ('a', 'b', ['c', 'd']),
            (detail.message_id, detail.text, detail.variables))

    def test_send_sms(self):
        """
        `ParlayXClient.send_sms` performs a SOAP request to the
        remote ParlayX send endpoint to deliver a message via SMS.
        """
        client = self._make_client(
            MockResponse.build(
                http.OK, SEND_NS.sendSmsResponse(SEND_NS.result('reference'))))
        client._now = partial(datetime, 2013, 6, 18, 10, 59, 33)
        response = self.successResultOf(
            client.send_sms('+27117654321', 'content', 'message_id', 'linkid'))
        self.assertEqual('reference', response)
        self.assertEqual(1, len(self.requests))
        self.assertEqual('send', self.requests[0][0])

        body, header = unwrap_soap_envelope(fromstring(self.requests[0][1]))
        self.assertEqual(
            {str(SEND_NS.sendSms): {
                str(SEND_NS.addresses): 'tel:27117654321',
                str(SEND_NS.message): 'content',
                str(SEND_NS.receiptRequest): {
                    'correlator': 'message_id',
                    'endpoint': 'endpoint',
                    'interfaceName': 'SmsNotification'}}},
            element_to_dict(elemfind(body, SEND_NS.sendSms)))
        self.assertEqual(
            {str(PARLAYX_HEAD_NS.RequestSOAPHeader): {
                str(PARLAYX_HEAD_NS.serviceId): 'service_id',
                str(PARLAYX_HEAD_NS.spId): 'user',
                str(PARLAYX_HEAD_NS.spPassword):
                    '1f2e67e642b16f6623459fa76dc3894f',
                str(PARLAYX_HEAD_NS.timeStamp): '20130618105933',
                str(PARLAYX_HEAD_NS.linkid): 'linkid',
                str(PARLAYX_HEAD_NS.OA): 'tel:27117654321'}},
            element_to_dict(
                elemfind(header, PARLAYX_HEAD_NS.RequestSOAPHeader)))

    def test_send_sms_service_fault(self):
        """
        `ParlayXClient.send_sms` expects `ServiceExceptionDetail` fault details
        in SOAP requests that fail for remote service-related reasons.
        """
        detail = PARLAYX_COMMON_NS.ServiceExceptionDetail(
            L.messageId('a'),
            L.text('b'),
            L.variables('c'),
            L.variables('d'))
        client = self._make_client(
            MockResponse.build(
                http.INTERNAL_SERVER_ERROR,
                soap_fault('soapenv:Server', 'Whoops', detail=detail)))
        f = self.failureResultOf(
            client.send_sms('+27117654321', 'content', 'message_id'),
            ServiceException)
        detail = f.value.parsed_detail
        self.assertEqual(
            ('a', 'b', ['c', 'd']),
            (detail.message_id, detail.text, detail.variables))

    def test_send_sms_policy_fault(self):
        """
        `ParlayXClient.send_sms` expects `PolicyExceptionDetail` fault details
        in SOAP requests that fail for remote policy-related reasons.
        """
        detail = PARLAYX_COMMON_NS.PolicyExceptionDetail(
            L.messageId('a'),
            L.text('b'),
            L.variables('c'),
            L.variables('d'))
        client = self._make_client(
            MockResponse.build(
                http.INTERNAL_SERVER_ERROR,
                soap_fault('soapenv:Server', 'Whoops', detail=detail)))
        f = self.failureResultOf(
            client.send_sms('+27117654321', 'content', 'message_id'),
            PolicyException)
        detail = f.value.parsed_detail
        self.assertEqual(
            ('a', 'b', ['c', 'd']),
            (detail.message_id, detail.text, detail.variables))
vumi/transports/parlayx/tests/test_server.py
import iso8601
from datetime import datetime
from StringIO import StringIO

from twisted.trial.unittest import TestCase
from twisted.web import http
from twisted.web.test.requesthelper import DummyRequest

from vumi.transports.parlayx.server import (
    NOTIFICATION_NS, PARLAYX_COMMON_NS, normalize_address, SmsMessage,
    DeliveryInformation, DeliveryStatus, SmsNotificationService)
from vumi.transports.parlayx.soaputil import SoapFault, SOAP_ENV, soap_envelope
from vumi.transports.parlayx.xmlutil import (
    ParseError, LocalNamespace as L, tostring, fromstring, element_to_dict)
from vumi.transports.parlayx.tests.utils import (
    create_sms_reception_element, create_sms_delivery_receipt)


class NormalizeAddressTests(TestCase):
    """
    Tests for `vumi.transports.parlayx.server.normalize_address`.
    """
    def test_not_prefixed(self):
        """
        `normalize_address` still normalizes addresses that are not prefixed
        with ``tel:``.
        """
        self.assertEqual(
            '+27117654321', normalize_address('27 11 7654321'))
        self.assertEqual(
            '54321', normalize_address('54321'))

    def test_prefixed(self):
        """
        `normalize_address` strips any ``tel:`` prefix and normalizes the
        address.
        """
        self.assertEqual(
            '+27117654321', normalize_address('tel:27 11 7654321'))
        self.assertEqual(
            '+27117654321', normalize_address('tel:27117654321'))
        self.assertEqual(
            '54321', normalize_address('tel:54321'))


class SmsMessageTests(TestCase):
    """
    Tests for `vumi.transports.parlayx.server.SmsMessage`.
    """
    def test_from_element(self):
        """
        `SmsMessage.from_element` parses a ParlayX ``SmsMessage`` complex type,
        with an ISO8601 timestamp, into an `SmsMessage` instance.
        """
        timestamp = datetime(
            2013, 6, 12, 13, 15, 0, tzinfo=iso8601.iso8601.Utc())
        msg = SmsMessage.from_element(
            NOTIFICATION_NS.message(
                L.message('message'),
                L.senderAddress('tel:27117654321'),
                L.smsServiceActivationNumber('54321'),
                L.dateTime('2013-06-12T13:15:00')))
        self.assertEqual(
            ('message', '+27117654321', '54321', timestamp),
            (msg.message, msg.sender_address, msg.service_activation_number,
             msg.timestamp))

    def test_from_element_missing_timestamp(self):
        """
        `SmsMessage.from_element` parses a ParlayX ``SmsMessage`` complex type,
        without a timestamp, into an `SmsMessage` instance.
        """
        msg = SmsMessage.from_element(
            NOTIFICATION_NS.message(
                L.message('message'),
                L.senderAddress('tel:27117654321'),
                L.smsServiceActivationNumber('54321')))
        self.assertEqual(
            ('message', '+27117654321', '54321', None),
            (msg.message, msg.sender_address, msg.service_activation_number,
             msg.timestamp))


class DeliveryInformationTests(TestCase):
    """
    Tests for `vumi.transports.parlayx.server.DeliveryInformation`.
    """
    def test_from_element(self):
        """
        `DeliveryInformation.from_element` parses a ParlayX
        ``DeliveryInformation`` complex type into a `DeliveryInformation`
        instance. Known ``DeliveryStatus`` enumeration values are parsed into
        `DeliveryStatus` attributes.
        """
        info = DeliveryInformation.from_element(
            NOTIFICATION_NS.deliveryStatus(
                L.address('tel:27117654321'),
                L.deliveryStatus('DeliveredToNetwork')))
        self.assertEqual(
            ('+27117654321', DeliveryStatus.DeliveredToNetwork),
            (info.address, info.delivery_status))

    def test_from_element_unknown_status(self):
        """
        `DeliveryInformation.from_element` raises ``ValueError`` if an unknown
        ``DeliveryStatus`` enumeration value is specified.
        """
        e = self.assertRaises(ValueError,
            DeliveryInformation.from_element,
            NOTIFICATION_NS.deliveryStatus(
                L.address('tel:27117654321'),
                L.deliveryStatus('WhatIsThis')))
        self.assertEqual(
            "No such delivery status enumeration value: 'WhatIsThis'", str(e))


class SmsNotificationServiceTests(TestCase):
    """
    Tests for `vumi.transports.parlayx.server.SmsNotificationService`.
    """
    def test_process_empty(self):
        """
        `SmsNotificationService.process` raises `SoapFault` if there are no
        actionable child elements in the request body.
        """
        service = SmsNotificationService(None, None)
        exc = self.assertRaises(SoapFault,
            service.process, None, L.root())
        self.assertEqual(
            ('soapenv:Client', 'No actionable items'),
            (exc.code, str(exc)))

    def test_process_unknown(self):
        """
        `SmsNotificationService.process` invokes
        `SmsNotificationService.process_unknown`, for handling otherwise
        unknown requests, which raises `SoapFault`.
        """
        service = SmsNotificationService(None, None)
        exc = self.assertRaises(SoapFault,
            service.process, None, L.root(L.WhatIsThis))
        self.assertEqual(
            ('soapenv:Server', 'No handler for WhatIsThis'),
            (exc.code, str(exc)))

    def test_process_notifySmsReception(self):
        """
        `SmsNotificationService.process_notifySmsReception` invokes the
        message received callback with the correlator (message identifier),
        the ``linkid`` header value and an `SmsMessage` instance containing
        the details of the received message.
        """
        def callback(*a):
            self.callbacks.append(a)
        self.callbacks = []
        service = SmsNotificationService(callback, None)
        self.successResultOf(service.process(None,
            SOAP_ENV.Body(
                create_sms_reception_element(
                    '1234', 'message', '+27117654321', '54321')),
            SOAP_ENV.Header(
                PARLAYX_COMMON_NS.NotifySOAPHeader(
                    PARLAYX_COMMON_NS.linkid('linkid')))))

        self.assertEqual(1, len(self.callbacks))
        correlator, linkid, msg = self.callbacks[0]
        self.assertEqual(
            ('1234', 'linkid', 'message', '+27117654321', '54321', None),
            (correlator, linkid, msg.message, msg.sender_address,
             msg.service_activation_number, msg.timestamp))

    def test_process_notifySmsDeliveryReceipt(self):
        """
        `SmsNotificationService.process_notifySmsDeliveryReceipt` invokes the
        delivery receipt callback with the correlator (message identifier) and
        the delivery status (translated into a Vumi-compatible value).
        """
        def callback(*a):
            self.callbacks.append(a)
        self.callbacks = []
        service = SmsNotificationService(None, callback)
        self.successResultOf(service.process(None,
            SOAP_ENV.Body(
                create_sms_delivery_receipt(
                    '1234',
                    '+27117654321',
                    DeliveryStatus.DeliveryUncertain))))

        self.assertEqual(1, len(self.callbacks))
        correlator, status = self.callbacks[0]
        self.assertEqual(('1234', 'pending'), (correlator, status))

    def test_render(self):
        """
        `SmsNotificationService.render_POST` parses a SOAP request and
        dispatches it to `SmsNotificationService.process` for processing.
        """
        service = SmsNotificationService(None, None)
        service.process = lambda *a, **kw: L.done()
        request = DummyRequest([])
        request.content = StringIO(tostring(soap_envelope('hello')))
        d = request.notifyFinish()
        service.render_POST(request)
        self.successResultOf(d)
        self.assertEqual(http.OK, request.responseCode)
        self.assertEqual(
            {str(SOAP_ENV.Envelope): {
                str(SOAP_ENV.Body): {
                    'done': None}}},
            element_to_dict(fromstring(''.join(request.written))))

    def test_render_soap_fault(self):
        """
        `SmsNotificationService.render_POST` logs any exceptions that occur
        during processing and writes a SOAP fault back to the request. If the
        logged exception is a `SoapFault` its ``to_element`` method is invoked
        to serialize the fault.
        """
        service = SmsNotificationService(None, None)
        service.process = lambda *a, **kw: L.done()
        request = DummyRequest([])
        request.content = StringIO(tostring(L.hello()))
        d = request.notifyFinish()

        service.render_POST(request)
        self.successResultOf(d)
        self.assertEqual(http.INTERNAL_SERVER_ERROR, request.responseCode)
        failures = self.flushLoggedErrors(SoapFault)
        self.assertEqual(1, len(failures))
        self.assertEqual(
            {str(SOAP_ENV.Envelope): {
                str(SOAP_ENV.Body): {
                    str(SOAP_ENV.Fault): {
                        'faultcode': 'soapenv:Client',
                        'faultstring': 'Malformed SOAP request'}}}},
            element_to_dict(fromstring(''.join(request.written))))

    def test_render_exceptions(self):
        """
        `SmsNotificationService.render_POST` logs any exceptions that occur
        during processing and writes a SOAP fault back to the request.
        """
        def process(*a, **kw):
            raise ValueError('What is this')
        service = SmsNotificationService(None, None)
        service.process = process
        request = DummyRequest([])
        request.content = StringIO(tostring(soap_envelope('hello')))
        d = request.notifyFinish()

        service.render_POST(request)
        self.successResultOf(d)
        self.assertEqual(http.INTERNAL_SERVER_ERROR, request.responseCode)
        failures = self.flushLoggedErrors(ValueError)
        self.assertEqual(1, len(failures))
        self.assertEqual(
            {str(SOAP_ENV.Envelope): {
                str(SOAP_ENV.Body): {
                    str(SOAP_ENV.Fault): {
                        'faultcode': 'soapenv:Server',
                        'faultstring': 'What is this'}}}},
            element_to_dict(fromstring(''.join(request.written))))

    def test_render_invalid_xml(self):
        """
        `SmsNotificationService.render_POST` does not accept invalid XML body
        content.
        """
        service = SmsNotificationService(None, None)
        request = DummyRequest([])
        request.content = StringIO('sup')
        d = request.notifyFinish()

        service.render_POST(request)
        self.successResultOf(d)
        self.assertEqual(http.INTERNAL_SERVER_ERROR, request.responseCode)
        failures = self.flushLoggedErrors(ParseError)
        self.assertEqual(1, len(failures))
vumi/transports/parlayx/tests/__init__.py

vumi/transports/parlayx/tests/test_parlayx.py
from functools import partial

from twisted.internet.defer import inlineCallbacks, succeed, fail
from twisted.trial.unittest import TestCase

from vumi.tests.helpers import VumiTestCase
from vumi.transports.failures import PermanentFailure
from vumi.transports.parlayx import ParlayXTransport
from vumi.transports.parlayx.parlayx import (
    unique_correlator, extract_message_id)
from vumi.transports.parlayx.client import PolicyException, ServiceException
from vumi.transports.parlayx.server import DeliveryStatus
from vumi.transports.parlayx.soaputil import perform_soap_request
from vumi.transports.parlayx.tests.utils import (
    create_sms_reception_element, create_sms_delivery_receipt)
from vumi.transports.tests.helpers import TransportHelper


class MockParlayXClient(object):
    """
    A mock ``ParlayXClient`` that doesn't involve real HTTP requests but
    instead uses canned responses.
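
    For example (a sketch), a client whose ``send_sms`` always fails::

        client = MockParlayXClient(
            send_sms=partial(fail, ValueError('failed')))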
    """
    def __init__(self, start_sms_notification=None, stop_sms_notification=None,
                 send_sms=None):
        if start_sms_notification is None:
            start_sms_notification = partial(succeed, None)
        if stop_sms_notification is None:
            stop_sms_notification = partial(succeed, None)
        if send_sms is None:
            send_sms = partial(succeed, 'request_message_id')

        self.responses = {
            'start_sms_notification': start_sms_notification,
            'stop_sms_notification': stop_sms_notification,
            'send_sms': send_sms}
        self.calls = []

    def _invoke_response(self, name, args):
        """
        Invoke the canned response for the method name ``name`` and log the
        invocation.
        """
        self.calls.append((name, args))
        return self.responses[name]()

    def start_sms_notification(self):
        return self._invoke_response('start_sms_notification', [])

    def stop_sms_notification(self):
        return self._invoke_response('stop_sms_notification', [])

    def send_sms(self, to_addr, content, message_id, linkid):
        return self._invoke_response(
            'send_sms', [to_addr, content, message_id, linkid])


class TestParlayXTransport(VumiTestCase):
    """
    Tests for `vumi.transports.parlayx.ParlayXTransport`.
    """

    @inlineCallbacks
    def setUp(self):
        # TODO: Get rid of this hardcoded port number.
        self.port = 9999
        config = {
            'web_notification_path': '/hello',
            'web_notification_port': self.port,
            'notification_endpoint_uri': 'endpoint_uri',
            'short_code': '54321',
            'remote_send_uri': 'send_uri',
            'remote_notification_uri': 'notification_uri',
        }
        self.tx_helper = self.add_helper(TransportHelper(ParlayXTransport))
        self.uri = 'http://127.0.0.1:%s%s' % (
            self.port, config['web_notification_path'])

        def _create_client(transport, config):
            return MockParlayXClient()
        self.patch(
            self.tx_helper.transport_class, '_create_client',
            _create_client)
        self.transport = yield self.tx_helper.get_transport(
            config, start=False)

    @inlineCallbacks
    def test_ack(self):
        """
        Basic message delivery.
        """
        yield self.transport.startWorker()
        msg = yield self.tx_helper.make_dispatch_outbound("hi")
        [event] = self.tx_helper.get_dispatched_events()
        self.assertEqual(event['event_type'], 'ack')
        self.assertEqual(event['user_message_id'], msg['message_id'])

        client = self.transport._parlayx_client
        self.assertEqual(1, len(client.calls))
        linkid = client.calls[0][1][3]
        self.assertIdentical(None, linkid)

    @inlineCallbacks
    def test_ack_linkid(self):
        """
        Basic message delivery uses stored ``linkid`` from transport metadata
        if available.
        """
        yield self.transport.startWorker()
        msg = yield self.tx_helper.make_dispatch_outbound(
            "hi", transport_metadata={'linkid': 'linkid'})
        [event] = self.tx_helper.get_dispatched_events()
        self.assertEqual(event['event_type'], 'ack')
        self.assertEqual(event['user_message_id'], msg['message_id'])

        client = self.transport._parlayx_client
        self.assertEqual(1, len(client.calls))
        linkid = client.calls[0][1][3]
        self.assertEqual('linkid', linkid)

    @inlineCallbacks
    def test_nack(self):
        """
        Exceptions raised in an outbound message handler result in the message
        delivery failing, and a failure event being logged.
        """
        def _create_client(transport, config):
            return MockParlayXClient(
                send_sms=partial(fail, ValueError('failed')))
        self.patch(
            self.tx_helper.transport_class, '_create_client',
            _create_client)

        yield self.transport.startWorker()
        msg = yield self.tx_helper.make_dispatch_outbound("hi")
        [event] = self.tx_helper.get_dispatched_events()
        self.assertEqual(event['event_type'], 'nack')
        self.assertEqual(event['user_message_id'], msg['message_id'])
        self.assertEqual(event['nack_reason'], 'failed')

        failures = self.flushLoggedErrors(ValueError)
        # Logged once by the transport and once by Twisted for being unhandled.
        self.assertEqual(2, len(failures))

    @inlineCallbacks
    def _test_nack_permanent(self, expected_exception):
        """
        The expected exception, when raised in an outbound message handler,
        results in a `PermanentFailure` and is logged along with the original
        exception.
        """
        def _create_client(transport, config):
            return MockParlayXClient(
                send_sms=partial(
                    fail, expected_exception('soapenv:Client', 'failed')))
        self.patch(
            self.tx_helper.transport_class, '_create_client',
            _create_client)

        yield self.transport.startWorker()
        msg = yield self.tx_helper.make_dispatch_outbound("hi")
        [event] = self.tx_helper.get_dispatched_events()
        self.assertEqual(event['event_type'], 'nack')
        self.assertEqual(event['user_message_id'], msg['message_id'])
        self.assertEqual(event['nack_reason'], 'failed')

        failures = self.flushLoggedErrors(expected_exception, PermanentFailure)
        self.assertEqual(2, len(failures))

    def test_nack_service_exception(self):
        """
        When `ServiceException` is raised in an outbound message handler, it
        results in a `PermanentFailure` exception.
        """
        return self._test_nack_permanent(ServiceException)

    def test_nack_policy_exception(self):
        """
        When `PolicyException` is raised in an outbound message handler, it
        results in a `PermanentFailure` exception.
        """
        return self._test_nack_permanent(PolicyException)

    @inlineCallbacks
    def test_receive_sms(self):
        """
        When a text message is submitted to the Vumi ParlayX
        ``notifySmsReception`` SOAP endpoint, a message is
        published containing the message identifier, message content, from
        address and to address that accurately match what was submitted.
        """
        yield self.transport.startWorker()
        body = create_sms_reception_element(
            '1234', 'message', '+27117654321', '54321')
        yield perform_soap_request(self.uri, '', body)
        [msg] = self.tx_helper.get_dispatched_inbound()
        self.assertEqual(
            ('1234', 'message', '+27117654321', '54321'),
            (msg['message_id'], msg['content'], msg['from_addr'],
             msg['to_addr']))

    @inlineCallbacks
    def test_delivery_receipt(self):
        """
        When a delivery receipt is submitted to the Vumi ParlayX
        ``notifySmsDeliveryReceipt`` SOAP endpoint, an event is
        published containing the message identifier and the delivery status
        that accurately match what was submitted.
        """
        yield self.transport.startWorker()
        body = create_sms_delivery_receipt(
            '1234', '+27117654321', DeliveryStatus.DeliveredToNetwork)
        yield perform_soap_request(self.uri, '', body)
        [event] = self.tx_helper.get_dispatched_events()
        self.assertEqual(
            ('1234', 'delivered'),
            (event['user_message_id'], event['delivery_status']))


class TransportUtilsTests(TestCase):
    """
    Tests for miscellaneous functions in `vumi.transports.parlayx`.
    """
    def test_unique_correlator(self):
        """
        `unique_correlator` combines a Vumi transport message identifier and
        a UUID.
        """
        self.assertEqual(
            'arst:12341234', unique_correlator('arst', '12341234'))

    def test_extract_message_id(self):
        """
        `extract_message_id` splits a ParlayX correlator into a Vumi transport
        message identifier and a UUID.
        """
        self.assertEqual(
            'arst', extract_message_id('arst:12341234'))
vumi/transports/parlayx/tests/test_xmlutil.py
from twisted.python.constants import Names, NamedConstant
from twisted.trial.unittest import TestCase

from vumi.transports.parlayx.xmlutil import (
    Namespace, QualifiedName, ElementMaker, LocalNamespace as L,
    split_qualified, gettext, gettextall, tostring, elemfind, elemfindall,
    element_to_dict)


class NamespaceTests(TestCase):
    """
    Tests for `vumi.transports.parlayx.xmlutil.Namespace`.
    """
    def test_str(self):
        """
        ``str(Namespace)`` produces the Namespace URI.
        """
        uri = 'http://example.com'
        self.assertEqual(uri, str(Namespace(uri)))

    def test_repr(self):
        """
        ``repr(Namespace)`` produces self-explanatory human-readable output.
        """
        self.assertEqual(
            '<Namespace uri=None prefix=None>',
            repr(Namespace(None)))
        self.assertEqual(
            "<Namespace uri='http://example.com' prefix=None>",
            repr(Namespace('http://example.com')))
        self.assertEqual(
            "<Namespace uri='http://example.com' prefix='ex'>",
            repr(Namespace('http://example.com', 'ex')))

    def test_equality(self):
        """
        Two `Namespace` instances created with the same values compare equal to
        one another.
        """
        self.assertEqual(
            Namespace('http://example.com'),
            Namespace('http://example.com'))
        self.assertEqual(
            Namespace('http://example.com', 'ex'),
            Namespace('http://example.com', 'ex'))
        self.assertNotEqual(
            Namespace('http://example.com'),
            Namespace('http://example.com', 'ex'))
        self.assertNotEqual(
            Namespace('http://example.com/'),
            Namespace('http://example.com'))

    def test_qualified_name(self):
        """
        `Namespace.__getattr__` produces qualified `QualifiedName` instances if
        `Namespace.__uri` is not `None`.
        """
        uri = 'http://example.com'
        ns = Namespace(uri)
        self.assertEqual(
            QualifiedName(uri, 'foo'),
            ns.foo)

    def test_local_name(self):
        """
        `Namespace.__getattr__` produces local `QualifiedName` instances if
        `Namespace.__uri` is `None`.
        """
        ns = Namespace(None)
        self.assertEqual(
            QualifiedName('foo'),
            ns.foo)


class QualifiedNameTests(TestCase):
    """
    Tests for `vumi.transports.parlayx.xmlutil.QualifiedName`.
    """
    def test_repr(self):
        """
        ``repr(QualifiedName)`` produces self-explanatory human-readable
        output.
        """
        self.assertEqual(
            "<QualifiedName xmlns=None local='tag'>",
            repr(QualifiedName('tag')))
        self.assertEqual(
            "<QualifiedName xmlns='http://example.com' local='tag'>",
            repr(QualifiedName('http://example.com', 'tag')))

    def test_equality(self):
        """
        Two `QualifiedName` instances created with the same values compare
        equal to one another.
        """
        self.assertEqual(
            QualifiedName('tag'),
            QualifiedName('tag'))
        self.assertEqual(
            QualifiedName('http://example.com', 'tag'),
            QualifiedName('http://example.com', 'tag'))
        # Parameters are internally converted to Clark notation anyway.
        self.assertEqual(
            QualifiedName('http://example.com', 'tag'),
            QualifiedName('{http://example.com}tag'))
        self.assertNotEqual(
            QualifiedName('tag'),
            QualifiedName('http://example.com', 'tag'))
        self.assertNotEqual(
            QualifiedName('http://example.com/', 'tag'),
            QualifiedName('http://example.com', 'tag'))

    def test_element(self):
        """
        `QualifiedName` instances are callable and produce ElementTree
        elements.
        """
        qname = QualifiedName('tag')
        self.assertEqual(
            '<tag />',
            tostring(qname()))
        self.assertEqual(
            '<tag>hello</tag>',
            tostring(qname(u'hello')))
        self.assertEqual(
            '<tag key="value">hello</tag>',
            tostring(qname('hello', key='value')))
        self.assertEqual(
            '<tag key="value">hello</tag>',
            tostring(qname('hello', dict(key='value'))))


class ElementMakerTests(TestCase):
    """
    Tests for `vumi.transports.parlayx.xmlutil.ElementMaker`.
    """
    def setUp(self):
        # ElementTree has a global namespace prefix map. We need to patch it
        # out here to make the tests independent of each other.
        import xml.etree.ElementTree
        self.patch(xml.etree.ElementTree, '_namespace_map', {})

    def test_unknown_child_type(self):
        """
        `ElementMaker` instances raise `TypeError` when called with children of
        unmapped types.
        """
        E = ElementMaker()
        exc = self.assertRaises(TypeError, E, 'tag', None)
        self.assertEqual('Unknown child type: None', str(exc))

    def test_simple(self):
        """
        Calling `ElementMaker` instances produces ElementTree elements;
        children and attributes can be provided too.
        """
        E = ElementMaker()
        self.assertEqual(
            '<tag />',
            tostring(E('tag')))
        self.assertEqual(
            '<tag>hello</tag>',
            tostring(E('tag', 'hello')))
        self.assertEqual(
            '<tag key="value">hello</tag>',
            tostring(E('tag', 'hello', key='value')))
        self.assertEqual(
            '<tag key="value">hello</tag>',
            tostring(E('tag', 'hello', dict(key='value'))))

    def test_callable(self):
        """
        Providing a callable child will result in that child being called, with
        no arguments, to produce the actual child value.
        """
        E = ElementMaker()
        self.assertEqual(
            '<tag><child /></tag>',
            tostring(E('tag', L.child)))
        self.assertEqual(
            '<tag>hello</tag>',
            tostring(E('tag', lambda: 'hello')))

    def test_list(self):
        """
        Providing a list child will result in all the elements of the list
        being added individually.
        """
        E = ElementMaker()
        self.assertEqual(
            '<tag><child1 /><child2 /></tag>',
            tostring(E('tag', [L.child1, L.child2])))
        self.assertEqual(
            '<tag>text1text2</tag>',
            tostring(E('tag', ['text1', 'text2'])))

    def test_nested(self):
        """
        Children can themselves be ElementTree elements, resulting in nested
        elements.
        """
        E = ElementMaker()
        self.assertEqual(
            '<tag><child /></tag>',
            tostring(E('tag', E('child'))))
        self.assertEqual(
            '<tag><child>hello</child></tag>',
            tostring(E('tag', E('child', 'hello'))))
        self.assertEqual(
            '<tag><child key="value">hello</child></tag>',
            tostring(E('tag', E('child', 'hello', key='value'))))
        self.assertEqual(
            '<tag><child key="value">hello</child></tag>',
            tostring(E('tag', E('child', 'hello', dict(key='value')))))

    def test_namespaced(self):
        """
        Tags that are `QualifiedName` instances or use Clark notation produce
        namespaced XML elements.
        """
        E = ElementMaker()
        self.assertEqual(
            '<ns0:tag xmlns:ns0="http://example.com" />',
            tostring(E('{http://example.com}tag')))
        self.assertEqual(
            '<ns0:tag xmlns:ns0="http://example.com" />',
            tostring(QualifiedName('http://example.com', 'tag')()))
        ns = Namespace('http://example.com', 'ex')
        self.assertEqual(
            '<ex:tag xmlns:ex="http://example.com" />',
            tostring(ns.tag()))

    def test_namespaced_attributes(self):
        """
        XML attributes that are `QualifiedName` instances or use Clark notation
        produce namespaced XML element attributes.
        """
        ns = Namespace('http://example.com', 'ex')
        attrib = {ns.key: 'value'}
        self.assertEqual(
            {'{http://example.com}tag': {
                '@{http://example.com}key': 'value'}},
            element_to_dict(ns.tag(attrib)))
        attrib = {'{http://example.com}key': 'value'}
        self.assertEqual(
            {'{http://example.com}tag': {
                '@{http://example.com}key': 'value'}},
            element_to_dict(ns.tag(attrib)))

    def test_typemap(self):
        """
        Providing a type map to `ElementMaker` allows the caller to specify how
        to serialize types other than strings and dictionaries.
        """
        E = ElementMaker(typemap={
            float: lambda e, v: '%0.2f' % (v,),
            int: lambda e, v: L.int(str(v))})
        self.assertEqual(
            '<tag>2.50</tag>',
            tostring(E('tag', 2.5)))
        self.assertEqual(
            '<tag><int>42</int></tag>',
            tostring(E('tag', 42)))


class MetasyntacticVariables(Names):
    """
    Metasyntactic variable names.
    """
    Foo = NamedConstant()


class GetTextTests(TestCase):
    """
    Tests for `vumi.transports.parlayx.xmlutil.gettext`.
    """
    def setUp(self):
        self.root = L.top(
            L.a('hello'),
            L.b('42'),
            L.b('24'),
            L.c('Foo'),
            L.d,
            L.sub(
                L.e('world'),
                L.e('all'),
                L.f))

    def test_simple(self):
        """
        Getting a sub-element with a `text` attribute returns the text as a
        `unicode` object.
        """
        res = gettext(self.root, u'a')
        self.assertIdentical(unicode, type(res))
        self.assertEqual(res, u'hello')

        res = gettext(self.root, u'sub/e')
        self.assertIdentical(unicode, type(res))
        self.assertEqual(res, u'world')

    def test_default(self):
        """
        Getting a sub-element without a `text` attribute, or attempting to get
        a sub-element that does not exist, results in the `default` parameter
        to `gettext` being used, defaulting to `None`.
        """
        self.assertIdentical(gettext(self.root, u'd'), None)
        self.assertEqual(gettext(self.root, u'd', default=42), 42)

        self.assertIdentical(gettext(self.root, u'sub/f'), None)
        res = gettext(self.root, u'sub/f', default='a')
        self.assertIdentical(str, type(res))
        self.assertEqual(res, 'a')

        self.assertIdentical(gettext(self.root, u'haha_what'), None)
        self.assertEqual(gettext(self.root, u'haha_what', default=42), 42)

    def test_parse(self):
        """
        Specifying a `parse` callable results in that being called to transform
        the element text.
        """
        self.assertEqual(
            42,
            gettext(self.root, u'b', parse=int))
        self.assertEqual(
            MetasyntacticVariables.Foo,
            gettext(self.root, u'c',
                    parse=MetasyntacticVariables.lookupByName))
        self.assertRaises(ValueError,
            gettext, self.root, u'c', parse=int)

    def test_parseWithDefault(self):
        """
        If both a default value and a `parse` callable are given, and the
        default value is used, the default value will be passed to the
        callable.
        """
        self.assertEqual(
            42,
            gettext(self.root, u'b', default=3.1415, parse=int))
        self.assertEqual(
            21,
            gettext(self.root, u'd', default=21, parse=int))
        self.assertRaises(ValueError,
            gettext, self.root, u'd', default='foo', parse=int)

    def test_gettextall(self):
        """
        `gettextall` is like `gettext` except it uses `elemfindall` instead of
        `elemfind`, returning a ``list`` of results.
        """
        self.assertEqual(
            [42, 24],
            list(gettextall(self.root, u'b', parse=int)))
        self.assertEqual(
            ['world', 'all'],
            list(gettextall(self.root, u'sub/e')))
        self.assertEqual(
            [],
            list(gettextall(self.root, u'what')))


class SplitQualifiedTests(TestCase):
    """
    Tests for `vumi.transports.parlayx.xmlutil.split_qualified`.
    """
    def test_local(self):
        """
        `split_qualified` splits a local XML name into `None` and the tag name.
        """
        self.assertEqual((None, 'tag'), split_qualified('tag'))

    def test_qualified(self):
        """
        `split_qualified` splits a qualified XML name into a URI and the tag
        name.
        """
        self.assertEqual(
            ('http://example.com', 'tag'),
            split_qualified('{http://example.com}tag'))


class FindTests(TestCase):
    """
    Tests for `vumi.transports.parlayx.xmlutil.elemfind`.
    """
    def setUp(self):
        self.root = L.parent(
            L.child1, L.child2, L.child2, L.child3)

    def test_elemfind(self):
        """
        `elemfind` finds the first sub-element matching the `QualifiedName`
        or path specified.
        """
        self.assertEqual(
            '<child1 />',
            tostring(elemfind(self.root, 'child1')))
        self.assertEqual(
            '<child2 />',
            tostring(elemfind(self.root, L.child2)))

    def test_elemfind_none(self):
        """
        `elemfind` returns ``None`` if the `QualifiedName` or path specified
        cannot be found.
        """
        self.assertIdentical(None, elemfind(self.root, L.what))

    def test_elemfindall(self):
        """
        `elemfindall` finds all sub-elements with the `QualifiedName` or path
        specified.
        """
        self.assertEqual(
            ['<child1 />'],
            map(tostring, elemfindall(self.root, L.child1)))
        self.assertEqual(
            ['<child2 />', '<child2 />'],
            map(tostring, elemfindall(self.root, 'child2')))

    def test_elemfindall_none(self):
        """
        `elemfindall` returns an empty list if the `QualifiedName` or path
        specified cannot be found.
        """
        self.assertEqual([], elemfindall(self.root, L.what))
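

# Illustrative sketch (not part of the original suite): `elemfind` returns
# the first matching sub-element (or ``None``), while `elemfindall` returns
# a list of every match. `_example_elemfind_usage` is a hypothetical helper.
def _example_elemfind_usage():
    root = L.parent(L.child1, L.child2, L.child2)
    assert elemfind(root, 'child1') is not None      # first match found
    assert elemfind(root, L.what) is None            # no match at all
    assert len(elemfindall(root, 'child2')) == 2     # every match returned
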


class ElementToDictTests(TestCase):
    """
    Tests for `vumi.transports.parlayx.xmlutil.element_to_dict`.
    """
    def test_empty(self):
        """
        An empty element produces a ``None`` value keyed against its tag name.
        """
        self.assertEqual(
            {'root': None},
            element_to_dict(L.root()))

    def test_empty_attributes(self):
        """
        An element containing only attributes, and no content, has its
        attributes, prefixed with an ``@``, keyed against its tag name.
        """
        self.assertEqual(
            {'root': {'@attr': 'value'}},
            element_to_dict(L.root(attr='value')))

    def test_text(self):
        """
        An element containing only text content has its text keyed against its
        tag name.
        """
        self.assertEqual(
            {'root': 'hello'},
            element_to_dict(L.root('hello')))

    def test_text_attributes(self):
        """
        An element containing attributes and text content has its
        attributes, prefixed with an ``@``, keyed against its tag name, and
        its text keyed against ``#text``.
        """
        self.assertEqual(
            {'root': {'#text': 'hello', '@attr': 'value'}},
            element_to_dict(L.root('hello', attr='value')))

    def test_children_text(self):
        """
        Child elements are recursively nested.

        An element containing only text content has its text keyed against its
        tag name.
        """
        self.assertEqual(
            {'root': {'child': 'hello'}},
            element_to_dict(
                L.root(L.child('hello'))))

    def test_children_attributes(self):
        """
        Child elements are recursively nested.

        An element containing only attributes, and no content, has its
        attributes, prefixed with an ``@``, keyed against its tag name.
        """
        self.assertEqual(
            {'root': {'child': {'@attr': 'value'}}},
            element_to_dict(
                L.root(L.child(attr='value'))))

    def test_children_text_attributes(self):
        """
        Child elements are recursively nested.

        An element containing attributes and text content has its
        attributes, prefixed with an ``@``, keyed against its tag name, and
        its text keyed against ``#text``.
        """
        self.assertEqual(
            {'root': {'child': {'#text': 'hello', '@attr': 'value'}}},
            element_to_dict(L.root(L.child('hello', attr='value'))))

    def test_children_multiple(self):
        """
        Multiple child elements with the same tag name are coalesced into
        a ``list``.
        """
        self.assertEqual(
            {'root': {'child': [{'@attr': 'value'}, 'hello']}},
            element_to_dict(
                L.root(L.child(attr='value'), L.child('hello'))))

    def test_namespaced(self):
        """
        `element_to_dict` supports namespaced element names and namespaced
        attributes.
        """
        ns = Namespace('http://example.com', 'ex')
        self.assertEqual(
            {str(ns.root): {
                'child': [
                    {'@' + str(ns.attr): 'value'},
                    {'@attr2': 'value2',
                     '#text': 'hello'},
                    'world']}},
            element_to_dict(
                ns.root(
                    L.child({ns.attr: 'value'}),
                    L.child('hello', attr2='value2'),
                    L.child('world'))))
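

# Illustrative sketch (not part of the original suite): `element_to_dict`
# flattens an element tree into nested dicts -- attributes are prefixed with
# ``@``, mixed text is keyed against ``#text``, and repeated tags coalesce
# into a list, as the tests above demonstrate.
def _example_element_to_dict():
    elem = L.root(L.child('hello', attr='value'), L.child('world'))
    assert element_to_dict(elem) == {
        'root': {'child': [{'#text': 'hello', '@attr': 'value'}, 'world']}}
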
PK=JG)).vumi/transports/parlayx/tests/test_soaputil.pyfrom collections import namedtuple
from functools import partial

from twisted.internet.defer import succeed
from twisted.trial.unittest import TestCase
from twisted.web import http

from vumi.transports.parlayx.soaputil import (
    SOAP_ENV, soap_envelope, unwrap_soap_envelope, soap_fault, tostring,
    perform_soap_request, SoapFault)
from vumi.transports.parlayx.xmlutil import (
    ParseError, gettext, Element, LocalNamespace as L, element_to_dict)
from vumi.transports.parlayx.tests.utils import (
    MockResponse, _FailureResultOfMixin)
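

# Illustrative sketch (not part of the original suite): wrapping content in a
# SOAP 1.1 envelope and unwrapping it again. Assumes the imports above;
# `_example_soap_roundtrip` is a hypothetical helper.
def _example_soap_roundtrip():
    envelope = soap_envelope(Element('tag', 'hello'))
    body, header = unwrap_soap_envelope(envelope)
    assert header is None                   # no Header element was supplied
    assert body.tag == SOAP_ENV.Body.text   # namespace-qualified Body tag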


class SoapWrapperTests(TestCase):
    """
    Tests for `vumi.transports.parlayx.soaputil.soap_envelope`,
    `vumi.transports.parlayx.soaputil.unwrap_soap_envelope` and
    `vumi.transports.parlayx.soaputil.soap_fault`.
    """
    def test_soap_envelope(self):
        """
        `soap_envelope` wraps content in a SOAP envelope element.
        """
        self.assertEqual(
            '<soapenv:Envelope xmlns:soapenv='
            '"http://schemas.xmlsoap.org/soap/envelope/">'
            '<soapenv:Body>hello</soapenv:Body></soapenv:Envelope>',
            tostring(soap_envelope('hello')))
        self.assertEqual(
            '<soapenv:Envelope xmlns:soapenv='
            '"http://schemas.xmlsoap.org/soap/envelope/">'
            '<soapenv:Body><tag>hello</tag></soapenv:Body>'
            '</soapenv:Envelope>',
            tostring(soap_envelope(Element('tag', 'hello'))))

    def test_unwrap_soap_envelope(self):
        """
        `unwrap_soap_envelope` unwraps a SOAP envelope element, with no header,
        to a tuple of the SOAP body element and ``None``.
        """
        body, header = unwrap_soap_envelope(
            soap_envelope(Element('tag', 'hello')))
        self.assertIdentical(None, header)
        self.assertEqual(
            '<soapenv:Body xmlns:soapenv='
            '"http://schemas.xmlsoap.org/soap/envelope/">'
            '<tag>hello</tag></soapenv:Body>',
            tostring(body))

    def test_unwrap_soap_envelope_header(self):
        """
        `unwrap_soap_envelope` unwraps a SOAP envelope element, with a header,
        to a tuple of the SOAP body and header elements.
        """
        body, header = unwrap_soap_envelope(
            soap_envelope(
                Element('tag', 'hello'),
                Element('header', 'value')))
        self.assertEqual(
            '<soapenv:Header xmlns:soapenv='
            '"http://schemas.xmlsoap.org/soap/envelope/">'
            '<header>value</header></soapenv:Header>
', tostring(header)) self.assertEqual( '' 'hello', tostring(body)) def test_soap_fault(self): """ `soap_fault` constructs a SOAP fault element from a code and description. """ self.assertEqual( '' 'soapenv:Client' 'Oops.', tostring(soap_fault('soapenv:Client', 'Oops.'))) class ToyFaultDetail(namedtuple('ToyFaultDetail', ['foo', 'bar'])): """ A SOAP fault detail used for tests. """ @classmethod def from_element(cls, elem): if elem.tag != L.ToyFaultDetail.text: return None return cls(gettext(elem, 'foo'), gettext(elem, 'bar')) class ToyFault(SoapFault): """ A SOAP fault used for tests. """ detail_type = ToyFaultDetail def _make_fault(*a, **kw): """ Create a SOAP body containing a SOAP fault. """ return SOAP_ENV.Body(soap_fault(*a, **kw)) class SoapFaultTests(TestCase): """ Tests for `vumi.transports.parlayx.soaputil.SoapFault`. """ def test_missing_fault(self): """ `SoapFault.from_element` raises `ValueError` if the element contains no SOAP fault. """ self.assertRaises( ValueError, SoapFault.from_element, Element('tag')) def test_from_element(self): """ `SoapFault.from_element` creates a `SoapFault` instance from an ElementTree element and parses known SOAP fault details. """ detail = L.ToyFaultDetail(L.foo('a'), L.bar('b')) fault = SoapFault.from_element(_make_fault( 'soapenv:Client', 'message', 'actor', detail=detail)) self.assertEqual( ('soapenv:Client', 'message', 'actor'), (fault.code, fault.string, fault.actor)) self.assertIdentical(None, fault.parsed_detail) def test_to_element(self): """ `SoapFault.to_element` serializes the fault to a SOAP ``Fault`` ElementTree element. """ detail = L.ToyFaultDetail(L.foo('a'), L.bar('b')) fault = SoapFault.from_element(_make_fault( 'soapenv:Client', 'message', 'actor', detail=detail)) self.assertEqual( {str(SOAP_ENV.Fault): { 'faultcode': fault.code, 'faultstring': fault.string, 'faultactor': fault.actor, 'detail': { 'ToyFaultDetail': {'foo': 'a', 'bar': 'b'}}}}, element_to_dict(fault.to_element())) def test_to_element_no_detail(self): """ `SoapFault.to_element` serializes the fault to a SOAP ``Fault`` ElementTree element, omitting the ``detail`` element if `SoapFault.detail` is None. """ fault = SoapFault.from_element(_make_fault( 'soapenv:Client', 'message', 'actor')) self.assertEqual( {str(SOAP_ENV.Fault): { 'faultcode': fault.code, 'faultstring': fault.string, 'faultactor': fault.actor}}, element_to_dict(fault.to_element())) def test_expected_faults(self): """ `SoapFault.from_element` creates an instance of a specified `SoapFault` subclass if a fault detail of a recognised type occurs. """ detail = [ L.WhatIsThis( L.foo('a'), L.bar('b')), L.ToyFaultDetail( L.foo('c'), L.bar('d'))] fault = SoapFault.from_element(_make_fault( 'soapenv:Client', 'message', 'actor', detail=detail), expected_faults=[ToyFault]) self.assertEqual( ('soapenv:Client', 'message', 'actor'), (fault.code, fault.string, fault.actor)) parsed_detail = fault.parsed_detail self.assertEqual( ('c', 'd'), (parsed_detail.foo, parsed_detail.bar)) class PerformSoapRequestTests(_FailureResultOfMixin, TestCase): """ Tests for `vumi.transports.parlayx.soaputil.perform_soap_request`. """ def setUp(self): self.requests = [] def _http_request_full(self, response, uri, body, headers): """ A mock for `vumi.utils.http_request_full`. Store an HTTP request's information and return a canned response. """ self.requests.append((uri, body, headers)) return succeed(response) def _perform_soap_request(self, response, *a, **kw): """ Perform a SOAP request with a canned response. 
""" return perform_soap_request( http_request_full=partial( self._http_request_full, response), *a, **kw) def test_success(self): """ `perform_soap_request` issues a SOAP request, over HTTP, to a URI, sets the ``SOAPAction`` header and parses the response as a SOAP envelope. """ response = MockResponse.build(http.OK, 'response', 'response_header') body, header = self.successResultOf( self._perform_soap_request(response, 'uri', 'action', 'request')) self.assertEqual([ ('uri', tostring(soap_envelope('request')), {'SOAPAction': 'action', 'Content-Type': 'text/xml; charset="utf-8"'})], self.requests) self.assertEqual(SOAP_ENV.Body.text, body.tag) self.assertEqual('response', body.text) self.assertEqual(SOAP_ENV.Header.text, header.tag) self.assertEqual('response_header', header.text) def test_response_not_xml(self): """ `perform_soap_request` raises `xml.etree.ElementTree.ParseError` if the response is not valid XML. """ response = MockResponse(http.OK, 'hello') self.failureResultOf( self._perform_soap_request(response, 'uri', 'action', 'request'), ParseError) def test_response_no_body(self): """ `perform_soap_request` raises `SoapFault` if the response contains no SOAP body element.. """ response = MockResponse(http.OK, tostring(SOAP_ENV.Envelope('hello'))) f = self.failureResultOf( self._perform_soap_request(response, 'uri', 'action', 'request'), SoapFault) self.assertEqual('soapenv:Client', f.value.code) self.assertEqual('Malformed SOAP request', f.getErrorMessage()) def test_fault(self): """ `perform_soap_request` raises `SoapFault`, parsed from the ``Fault`` element in the response, if the response HTTP status is ``500 Internal server error``. """ response = MockResponse.build( http.INTERNAL_SERVER_ERROR, soap_fault('soapenv:Server', 'Whoops')) f = self.failureResultOf( self._perform_soap_request(response, 'uri', 'action', 'request'), SoapFault) self.assertEqual( ('soapenv:Server', 'Whoops'), (f.value.code, f.getErrorMessage())) def test_expected_fault(self): """ `perform_soap_request` raises a `SoapFault` subclass when a SOAP fault detail matches one of the expected fault types. 
""" detail = L.ToyFaultDetail(L.foo('a'), L.bar('b')) response = MockResponse.build( http.INTERNAL_SERVER_ERROR, soap_fault('soapenv:Server', 'Whoops', detail=detail)) f = self.failureResultOf( self._perform_soap_request( response, 'uri', 'action', 'request', expected_faults=[ToyFault]), ToyFault) self.assertEqual( ('soapenv:Server', 'Whoops'), (f.value.code, f.getErrorMessage())) parsed_detail = f.value.parsed_detail self.assertEqual( ('a', 'b'), (parsed_detail.foo, parsed_detail.bar)) PK=JGv^y+aa"vumi/transports/airtel/__init__.pyfrom vumi.transports.airtel.airtel import AirtelUSSDTransport __all__ = ['AirtelUSSDTransport'] PK=JGVnl%% vumi/transports/airtel/airtel.py# -*- test-case-name: vumi.transports.airtel.tests.test_airtel -*- import json import re from twisted.internet.defer import inlineCallbacks from twisted.web import http from vumi.transports.httprpc import HttpRpcTransport from vumi.components.session import SessionManager from vumi.message import TransportUserMessage from vumi import log from vumi.config import ConfigInt, ConfigText, ConfigBool, ConfigDict class AirtelUSSDTransportConfig(HttpRpcTransport.CONFIG_CLASS): airtel_username = ConfigText('The username for this transport', default=None, static=True) airtel_password = ConfigText('The password for this transport', default=None, static=True) airtel_charge = ConfigBool( 'Whether or not to charge for the responses sent.', required=False, default=False, static=True) airtel_charge_amount = ConfigInt('How much to charge', default=0, required=False, static=True) redis_manager = ConfigDict('Parameters to connect to Redis with.', default={}, required=False, static=True) session_key_prefix = ConfigText( 'The prefix to use for session key management. Specify this' 'if you are using more than 1 worker in a load-balanced' 'fashion.', default=None, static=True) ussd_session_timeout = ConfigInt('Max length of a USSD session', default=60 * 10, required=False, static=True) to_addr_pattern = ConfigText( 'A regular expression that to_addr values in messages that start a' ' new USSD session must match. Initial messages with invalid' ' to_addr values are rejected.', default=None, required=False, static=True, ) class AirtelUSSDTransport(HttpRpcTransport): """ Client implementation for the Comviva Flares HTTP Pull API. 
Based on Flares 1.5.0, document version 1.2.0 """ transport_type = 'ussd' content_type = 'text/plain; charset=utf-8' to_addr_re = None ENCODING = 'utf-8' CONFIG_CLASS = AirtelUSSDTransportConfig EXPECTED_AUTH_FIELDS = set(['userid', 'password']) EXPECTED_CLEANUP_FIELDS = set(['SessionID', 'msisdn', 'clean', 'error']) EXPECTED_USSD_FIELDS = set(['SessionID', 'MSISDN', 'MSC', 'input']) @inlineCallbacks def setup_transport(self): super(AirtelUSSDTransport, self).setup_transport() config = self.get_static_config() self.session_manager = yield SessionManager.from_redis_config( config.redis_manager, self.get_session_key_prefix(), config.ussd_session_timeout) if config.to_addr_pattern is not None: self.to_addr_re = re.compile(config.to_addr_pattern) def get_session_key_prefix(self): config = self.get_static_config() default_session_key_prefix = "vumi.transports.airtel:%s" % ( self.transport_name,) return (config.session_key_prefix or default_session_key_prefix) def is_cleanup(self, request): return 'clean' in request.args def requires_auth(self): config = self.get_static_config() return (None not in (config.airtel_username, config.airtel_password)) def is_authenticated(self, request): config = self.get_static_config() if self.EXPECTED_AUTH_FIELDS.issubset(request.args): username = request.args['userid'][0] password = request.args['password'][0] auth = (username == config.airtel_username and password == config.airtel_password) if not auth: log.msg('Invalid authentication credentials: %s:%s' % ( username, password)) return auth def valid_to_addr(self, to_addr): if self.to_addr_re is None: return True return bool(self.to_addr_re.match(to_addr)) def handle_bad_request(self, message_id, request, errors): log.msg('Unhappy incoming message: %s' % (errors,)) return self.finish_request(message_id, json.dumps(errors), code=http.BAD_REQUEST) def handle_raw_inbound_message(self, message_id, request): if self.requires_auth() and not self.is_authenticated(request): self.finish_request(message_id, 'Forbidden', code=http.FORBIDDEN) return if self.is_cleanup(request): return self.handle_cleanup_request(message_id, request) return self.handle_ussd_request(message_id, request) @inlineCallbacks def handle_cleanup_request(self, message_id, request): if self.requires_auth(): fields = self.EXPECTED_CLEANUP_FIELDS.union( self.EXPECTED_AUTH_FIELDS) else: fields = self.EXPECTED_CLEANUP_FIELDS values, errors = self.get_field_values(request, fields) if errors: self.handle_bad_request(message_id, request, errors) return session_id = values['SessionID'] session = yield self.session_manager.load_session(session_id) if not session: log.warning('Received cleanup for unknown airtel session.', session_id=session_id) self.finish_request(message_id, 'Unknown Session', code=http.OK) return from_addr = values['msisdn'] to_addr = session['to_addr'] session_event = TransportUserMessage.SESSION_CLOSE yield self.session_manager.clear_session(session_id) yield self.publish_message( message_id=message_id, content='', to_addr=to_addr, from_addr=from_addr, provider='airtel', session_event=session_event, transport_type=self.transport_type, transport_metadata={ 'airtel': { 'clean': values['clean'], 'error': values['error'], }, }) self.finish_request(message_id, '', code=http.OK) @inlineCallbacks def handle_ussd_request(self, message_id, request): if self.requires_auth(): fields = self.EXPECTED_USSD_FIELDS.union( self.EXPECTED_AUTH_FIELDS) else: fields = self.EXPECTED_USSD_FIELDS values, errors = self.get_field_values(request, fields) if errors: 
self.handle_bad_request(message_id, request, errors) return session_id = values['SessionID'] from_addr = values['MSISDN'] session = yield self.session_manager.load_session(session_id) if session: to_addr = session['to_addr'] yield self.session_manager.save_session(session_id, session) session_event = TransportUserMessage.SESSION_RESUME content = values['input'] else: # Airtel doesn't provide us with the full to_addr, the start * # and ending # are omitted, add those again so we can use it # for internal routing. to_addr = '*%s#' % (values['input'],) if self.valid_to_addr(to_addr): yield self.session_manager.create_session( session_id, from_addr=from_addr, to_addr=to_addr) session_event = TransportUserMessage.SESSION_NEW content = '' else: self.handle_bad_request(message_id, request, { "invalid_session": ( "Session id %r has not been encountered in the last %s" " seconds and the 'input' request parameter value" " %r doesn't look like a valid USSD address." % (session_id, self.session_manager.max_session_length, to_addr) ) }) return yield self.publish_message( message_id=message_id, content=content, to_addr=to_addr, from_addr=from_addr, provider='airtel', session_event=session_event, transport_type=self.transport_type, transport_metadata={ 'airtel': { 'MSC': values['MSC'], }, }) def handle_outbound_message(self, message): config = self.get_static_config() missing_fields = self.ensure_message_values( message, ['in_reply_to', 'content']) if missing_fields: return self.reject_message(message, missing_fields) if message['session_event'] == TransportUserMessage.SESSION_CLOSE: free_flow = 'FB' else: free_flow = 'FC' headers = { 'Freeflow': [free_flow], 'charge': [('Y' if config.airtel_charge else 'N')], 'amount': [str(config.airtel_charge_amount)], } content = message['content'].encode(self.ENCODING).lstrip() if self.noisy: log.debug('in_reply_to: %s' % (message['in_reply_to'],)) log.debug('content: %r' % (content,)) log.debug('Response headers: %r' % (headers,)) self.finish_request( message['in_reply_to'], content, code=http.OK, headers=headers) return self.publish_ack( user_message_id=message['message_id'], sent_message_id=message['message_id']) PK=JG(vumi/transports/airtel/tests/__init__.pyPK=JGDD+vumi/transports/airtel/tests/test_airtel.pyimport json from urllib import urlencode from twisted.internet.defer import inlineCallbacks from twisted.web import http from vumi.tests.helpers import VumiTestCase from vumi.tests.utils import LogCatcher from vumi.transports.airtel import AirtelUSSDTransport from vumi.message import TransportUserMessage from vumi.utils import http_request_full from vumi.transports.tests.helpers import TransportHelper class AirtelUSSDTransportTestCase(VumiTestCase): airtel_username = None airtel_password = None session_id = 'session-id' @inlineCallbacks def setUp(self): self.tx_helper = self.add_helper(TransportHelper(AirtelUSSDTransport)) self.config = self.mk_config() self.transport = yield self.tx_helper.get_transport(self.config) self.session_manager = self.transport.session_manager self.add_cleanup(self.session_manager.stop) self.transport_url = self.transport.get_transport_url( self.config['web_path']) yield self.session_manager.redis._purge_all() # just in case def mk_config(self): return { 'web_port': 0, 'web_path': '/api/v1/airtel/ussd/', 'airtel_username': self.airtel_username, 'airtel_password': self.airtel_password, 'validation_mode': 'permissive', } def mk_full_request(self, **params): return http_request_full('%s?%s' % (self.transport_url, urlencode(params)), 
data='', method='GET') def mk_request(self, **params): defaults = { 'MSISDN': '27761234567', } if all([self.airtel_username, self.airtel_password]): defaults.update({ 'userid': self.airtel_username, 'password': self.airtel_password, }) defaults.update(params) return self.mk_full_request(**defaults) def mk_ussd_request(self, content, **kwargs): defaults = { 'MSC': 'msc', 'input': content, 'SessionID': self.session_id, } defaults.update(kwargs) return self.mk_request(**defaults) def mk_cleanup_request(self, **kwargs): defaults = { 'clean': 'clean-session', 'error': 522, 'SessionID': self.session_id, } defaults.update(kwargs) return self.mk_request(**defaults) class TestAirtelUSSDTransport(AirtelUSSDTransportTestCase): @inlineCallbacks def test_inbound_begin(self): # Second connect is the actual start of the session deferred = self.mk_ussd_request('121') [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(msg['content'], '') self.assertEqual(msg['to_addr'], '*121#') self.assertEqual(msg['from_addr'], '27761234567'), self.assertEqual(msg['session_event'], TransportUserMessage.SESSION_NEW) self.assertEqual(msg['transport_metadata'], { 'airtel': { 'MSC': 'msc', }, }) yield self.tx_helper.make_dispatch_reply(msg, "ussd message") response = yield deferred self.assertEqual(response.delivered_body, 'ussd message') self.assertEqual(response.headers.getRawHeaders('Freeflow'), ['FC']) self.assertEqual(response.headers.getRawHeaders('charge'), ['N']) self.assertEqual(response.headers.getRawHeaders('amount'), ['0']) @inlineCallbacks def test_strip_leading_newlines(self): # Second connect is the actual start of the session deferred = self.mk_ussd_request('121') [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) yield self.tx_helper.make_dispatch_reply(msg, "\nfoo\n") response = yield deferred self.assertEqual(response.delivered_body, 'foo\n') self.assertEqual(response.headers.getRawHeaders('Freeflow'), ['FC']) self.assertEqual(response.headers.getRawHeaders('charge'), ['N']) self.assertEqual(response.headers.getRawHeaders('amount'), ['0']) @inlineCallbacks def test_inbound_resume_and_reply_with_end(self): # first pre-populate the redis datastore to simulate prior BEG message yield self.session_manager.create_session(self.session_id, to_addr='*167*7#', from_addr='27761234567', session_event=TransportUserMessage.SESSION_RESUME) # Safaricom gives us the history of the full session in the USSD_PARAMS # The last submitted bit of content is the last value delimited by '*' deferred = self.mk_ussd_request('c') [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(msg['content'], 'c') self.assertEqual(msg['to_addr'], '*167*7#') self.assertEqual(msg['from_addr'], '27761234567') self.assertEqual(msg['session_event'], TransportUserMessage.SESSION_RESUME) yield self.tx_helper.make_dispatch_reply( msg, "hello world", continue_session=False) response = yield deferred self.assertEqual(response.delivered_body, 'hello world') self.assertEqual(response.headers.getRawHeaders('Freeflow'), ['FB']) @inlineCallbacks def test_inbound_resume_with_failed_to_addr_lookup(self): deferred = self.mk_request(MSISDN='123456', input='7*a', SessionID='foo') response = yield deferred self.assertEqual(json.loads(response.delivered_body), { 'missing_parameter': ['MSC'], }) @inlineCallbacks def test_to_addr_handling(self): d1 = self.mk_ussd_request('167*7*1') [msg1] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(msg1['to_addr'], '*167*7*1#') 
self.assertEqual(msg1['content'], '') self.assertEqual(msg1['session_event'], TransportUserMessage.SESSION_NEW) yield self.tx_helper.make_dispatch_reply(msg1, "hello world") yield d1 # follow up with the user submitting 'a' d2 = self.mk_ussd_request('a') [msg1, msg2] = yield self.tx_helper.wait_for_dispatched_inbound(2) self.assertEqual(msg2['to_addr'], '*167*7*1#') self.assertEqual(msg2['content'], 'a') self.assertEqual(msg2['session_event'], TransportUserMessage.SESSION_RESUME) yield self.tx_helper.make_dispatch_reply( msg2, "hello world", continue_session=False) yield d2 @inlineCallbacks def test_hitting_url_twice_without_content(self): d1 = self.mk_ussd_request('167*7*3') [msg1] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(msg1['to_addr'], '*167*7*3#') self.assertEqual(msg1['content'], '') self.assertEqual(msg1['session_event'], TransportUserMessage.SESSION_NEW) yield self.tx_helper.make_dispatch_reply(msg1, 'Hello') yield d1 # make the exact same request again d2 = self.mk_ussd_request('') [msg1, msg2] = yield self.tx_helper.wait_for_dispatched_inbound(2) self.assertEqual(msg2['to_addr'], '*167*7*3#') self.assertEqual(msg2['content'], '') self.assertEqual(msg2['session_event'], TransportUserMessage.SESSION_RESUME) yield self.tx_helper.make_dispatch_reply(msg2, 'Hello') yield d2 @inlineCallbacks def test_submitting_asterisks_as_values(self): yield self.session_manager.create_session(self.session_id, to_addr='*167*7#', from_addr='27761234567') # we're submitting a bunch of *s deferred = self.mk_ussd_request('****') [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(msg['content'], '****') yield self.tx_helper.make_dispatch_reply(msg, 'Hello') yield deferred @inlineCallbacks def test_submitting_asterisks_as_values_after_asterisks(self): yield self.session_manager.create_session(self.session_id, to_addr='*167*7#', from_addr='27761234567') # we're submitting a bunch of *s deferred = self.mk_ussd_request('**') [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(msg['content'], '**') yield self.tx_helper.make_dispatch_reply(msg, 'Hello') yield deferred @inlineCallbacks def test_submitting_with_base_code_empty_ussd_params(self): d1 = self.mk_ussd_request('167') [msg1] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(msg1['to_addr'], '*167#') self.assertEqual(msg1['content'], '') self.assertEqual(msg1['session_event'], TransportUserMessage.SESSION_NEW) yield self.tx_helper.make_dispatch_reply(msg1, 'Hello') yield d1 # ask for first menu d2 = self.mk_ussd_request('1') [msg1, msg2] = yield self.tx_helper.wait_for_dispatched_inbound(2) self.assertEqual(msg2['to_addr'], '*167#') self.assertEqual(msg2['content'], '1') self.assertEqual(msg2['session_event'], TransportUserMessage.SESSION_RESUME) yield self.tx_helper.make_dispatch_reply(msg2, 'Hello') yield d2 # ask for second menu d3 = self.mk_ussd_request('1') [msg1, msg2, msg3] = ( yield self.tx_helper.wait_for_dispatched_inbound(3)) self.assertEqual(msg3['to_addr'], '*167#') self.assertEqual(msg3['content'], '1') self.assertEqual(msg3['session_event'], TransportUserMessage.SESSION_RESUME) yield self.tx_helper.make_dispatch_reply(msg3, 'Hello') yield d3 @inlineCallbacks def test_cleanup_unknown_session(self): response = yield self.mk_cleanup_request(msisdn='foo') self.assertEqual(response.code, http.OK) self.assertEqual(response.delivered_body, 'Unknown Session') @inlineCallbacks def test_cleanup_session(self): yield 
self.session_manager.create_session(self.session_id, to_addr='*167*7#', from_addr='27761234567') response = yield self.mk_cleanup_request(msisdn='27761234567') self.assertEqual(response.code, http.OK) self.assertEqual(response.delivered_body, '') [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(msg['session_event'], TransportUserMessage.SESSION_CLOSE) self.assertEqual(msg['to_addr'], '*167*7#') self.assertEqual(msg['from_addr'], '27761234567') self.assertEqual(msg['transport_metadata'], { 'airtel': { 'error': '522', 'clean': 'clean-session', } }) @inlineCallbacks def test_cleanup_session_missing_params(self): response = yield self.mk_request(clean='clean-session') self.assertEqual(response.code, http.BAD_REQUEST) json_response = json.loads(response.delivered_body) self.assertEqual(set(json_response['missing_parameter']), set(['msisdn', 'SessionID', 'error'])) @inlineCallbacks def test_cleanup_as_seen_in_production(self): """what's a technical spec between friends?""" yield self.session_manager.create_session('13697502734175597', to_addr='*167*7#', from_addr='254XXXXXXXXX') query_string = ("msisdn=254XXXXXXXXX&clean=cleann&error=523" "&SessionID=13697502734175597&MSC=254XXXXXXXXX" "&=&=en&=9031510005344&=&=&=postpaid" "&=20130528171235405&=200220130528171113956582") response = yield http_request_full( '%s?%s' % (self.transport_url, query_string), data='', method='GET') self.assertEqual(response.code, http.OK) self.assertEqual(response.delivered_body, '') [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(msg['session_event'], TransportUserMessage.SESSION_CLOSE) self.assertEqual(msg['to_addr'], '*167*7#') self.assertEqual(msg['from_addr'], '254XXXXXXXXX') self.assertEqual(msg['transport_metadata'], { 'airtel': { 'clean': 'cleann', 'error': '523', } }) class TestAirtelUSSDTransportWithToAddrValidation(AirtelUSSDTransportTestCase): def mk_config(self): config = super(TestAirtelUSSDTransportWithToAddrValidation, self).mk_config() config['to_addr_pattern'] = '^\*121#$' return config @inlineCallbacks def test_inbound_begin_with_valid_to_addr(self): # Second connect is the actual start of the session deferred = self.mk_ussd_request('121') [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(msg['content'], '') self.assertEqual(msg['to_addr'], '*121#') self.assertEqual(msg['from_addr'], '27761234567'), self.assertEqual(msg['session_event'], TransportUserMessage.SESSION_NEW) self.assertEqual(msg['transport_metadata'], { 'airtel': { 'MSC': 'msc', }, }) yield self.tx_helper.make_dispatch_reply(msg, "ussd message") response = yield deferred self.assertEqual(response.delivered_body, 'ussd message') self.assertEqual(response.headers.getRawHeaders('Freeflow'), ['FC']) self.assertEqual(response.headers.getRawHeaders('charge'), ['N']) self.assertEqual(response.headers.getRawHeaders('amount'), ['0']) @inlineCallbacks def test_inbound_begin_with_invalid_to_addr(self): # Second connect is the actual start of the session with LogCatcher(message='Unhappy') as lc: response = yield self.mk_ussd_request('123') [log_msg] = lc.messages() self.assertEqual(response.code, 400) error_msg = json.loads(response.delivered_body) expected_error = { 'invalid_session': ( "Session id u'session-id' has not been encountered in the" " last 600 seconds and the 'input' request parameter value" " u'*123#' doesn't look like a valid USSD address." 
) } self.assertEqual(error_msg, expected_error) self.assertEqual( log_msg, "Unhappy incoming message: %s" % (expected_error,)) class TestAirtelUSSDTransportWithAuth(TestAirtelUSSDTransport): transport_class = AirtelUSSDTransport airtel_username = 'userid' airtel_password = 'password' @inlineCallbacks def test_cleanup_session_invalid_auth(self): response = yield self.mk_cleanup_request(userid='foo', password='bar') self.assertEqual(response.code, http.FORBIDDEN) self.assertEqual(response.delivered_body, 'Forbidden') @inlineCallbacks def test_cleanup_as_seen_in_production(self): """what's a technical spec between friends?""" yield self.session_manager.create_session('13697502734175597', to_addr='*167*7#', from_addr='254XXXXXXXXX') query_string = ("msisdn=254XXXXXXXXX&clean=cleann&error=523" "&SessionID=13697502734175597&MSC=254XXXXXXXXX" "&=&=en&=9031510005344&=&=&=postpaid" "&=20130528171235405&=200220130528171113956582" "&userid=%s&password=%s" % (self.airtel_username, self.airtel_password)) response = yield http_request_full( '%s?%s' % (self.transport_url, query_string), data='', method='GET') self.assertEqual(response.code, http.OK) self.assertEqual(response.delivered_body, '') [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(msg['session_event'], TransportUserMessage.SESSION_CLOSE) self.assertEqual(msg['to_addr'], '*167*7#') self.assertEqual(msg['from_addr'], '254XXXXXXXXX') self.assertEqual(msg['transport_metadata'], { 'airtel': { 'clean': 'cleann', 'error': '523', } }) class TestLoadBalancedAirtelUSSDTransport(VumiTestCase): def setUp(self): self.default_config = { 'web_port': 0, 'web_path': '/api/v1/airtel/ussd/', 'validation_mode': 'permissive', } self.tx_helper = self.add_helper(TransportHelper(AirtelUSSDTransport)) @inlineCallbacks def test_session_prefixes(self): config1 = self.default_config.copy() config1['transport_name'] = 'transport_1' config1['session_key_prefix'] = 'foo' config2 = self.default_config.copy() config2['transport_name'] = 'transport_2' config2['session_key_prefix'] = 'foo' self.transport1 = yield self.tx_helper.get_transport(config1) self.transport2 = yield self.tx_helper.get_transport(config2) self.transport3 = yield self.tx_helper.get_transport( self.default_config) self.assertEqual(self.transport1.get_session_key_prefix(), 'foo') self.assertEqual(self.transport2.get_session_key_prefix(), 'foo') self.assertEqual(self.transport3.get_session_key_prefix(), 'vumi.transports.airtel:sphex') PKh^xG&4,,vumi/transports/xmpp/xmpp.py# -*- test-case-name: vumi.transports.xmpp.tests.test_xmpp -*- # -*- encoding: utf-8 -*- from twisted.words.protocols.jabber.jid import JID from twisted.words.xish import domish from twisted.words.xish.domish import Element as DomishElement from twisted.internet.task import LoopingCall from twisted.internet.defer import inlineCallbacks from wokkel.client import XMPPClient from wokkel.ping import PingClientProtocol from wokkel.xmppim import (RosterClientProtocol, MessageProtocol, PresenceClientProtocol) from vumi.transports.base import Transport class TransportRosterClientProtocol(RosterClientProtocol): def connectionInitialized(self): # get the roster as soon as the connection's been initialized, this # allows us to see who's online but more importantly, allows us to see # who's added us to their roster. This allows us to auto subscribe to # anyone, automatically adding them to our roster, skips the "user ... # wants to add you to their roster, allow? yes/no" hoopla. 
self.getRoster() class TransportPresenceClientProtocol(PresenceClientProtocol): """ A custom presence protocol to automatically accept any subscription attempt. """ def __init__(self, initialized_callback, *args, **kwargs): super(TransportPresenceClientProtocol, self).__init__(*args, **kwargs) self.initialized_callback = initialized_callback def connectionInitialized(self): super(TransportPresenceClientProtocol, self).connectionInitialized() self.initialized_callback() def subscribeReceived(self, entity): self.subscribe(entity) self.subscribed(entity) def unsubscribeReceived(self, entity): self.unsubscribe(entity) self.unsubscribed(entity) class XMPPTransportProtocol(MessageProtocol, object): def __init__(self, jid, message_callback, connection_callback, connection_lost_callback=None,): super(MessageProtocol, self).__init__() self.jid = jid self.message_callback = message_callback self.connection_callback = connection_callback self.connection_lost_callback = connection_lost_callback def reply(self, jid, content): message = domish.Element((None, "message")) # intentionally leaving from blank, leaving for XMPP server # to figure out message['to'] = jid message['type'] = 'chat' message.addUniqueId() message.addElement((None, 'body'), content=content) self.xmlstream.send(message) def onMessage(self, message): """Messages sent to the bot will arrive here. Command handling routing is done in this function.""" if not isinstance(message.body, DomishElement): return None text = unicode(message.body).encode('utf-8').strip() from_addr, _, _ = message['from'].partition('/') self.message_callback( to_addr=self.jid.userhost(), from_addr=from_addr, content=text, transport_type='xmpp', transport_metadata={ 'xmpp_id': message.getAttribute('id'), }) def connectionMade(self): self.connection_callback() return super(XMPPTransportProtocol, self).connectionMade() def connectionLost(self, reason): if self.connection_lost_callback is not None: self.connection_lost_callback(reason) super(XMPPTransportProtocol, self).connectionLost(reason) class XMPPTransport(Transport): """XMPP transport. Configuration parameters: :type host: str :param host: The host of the XMPP server to connect to. :type port: int :param port: The port on the XMPP host to connect to. :type debug: bool :param debug: Whether or not to show all the XMPP traffic. Defaults to False. :type username: str :param username: The XMPP account username :type password: str :param password: The XMPP account password :type status: str :param status: The XMPP status 'away', 'xa', 'chat' or 'dnd' :type status_message: str :param status_message: The natural language status message for this XMPP transport. :type presence_interval: int :param presence_interval: How often (in seconds) to send a presence update to the roster. :type ping_interval: int :param ping_interval: How often (in seconds) to send a keep-alive ping to the XMPP server to keep the connection alive. Defaults to 60 seconds. 
""" start_message_consumer = False _xmpp_protocol = XMPPTransportProtocol _xmpp_client = XMPPClient def __init__(self, options, config=None): super(XMPPTransport, self).__init__(options, config=config) self.ping_call = LoopingCall(self.send_ping) self.presence_call = LoopingCall(self.send_presence) def validate_config(self): self.host = self.config['host'] self.port = int(self.config['port']) self.debug = self.config.get('debug', False) self.username = self.config['username'] self.password = self.config['password'] self.status = self.config['status'] self.status_message = self.config.get('status_message', '') self.ping_interval = self.config.get('ping_interval', 60) self.presence_interval = self.config.get('presence_interval', 60) def setup_transport(self): self.log.msg("Starting XMPPTransport: %s" % self.transport_name) self.jid = JID(self.username) self.xmpp_client = self._xmpp_client( self.jid, self.password, self.host, self.port) self.xmpp_client.logTraffic = self.debug self.xmpp_client.setServiceParent(self) self.presence = TransportPresenceClientProtocol(self.announce_presence) self.presence.setHandlerParent(self.xmpp_client) self.pinger = PingClientProtocol() self.pinger.setHandlerParent(self.xmpp_client) self.ping_call.start(self.ping_interval, now=False) roster = TransportRosterClientProtocol() roster.setHandlerParent(self.xmpp_client) self.xmpp_protocol = self._xmpp_protocol( self.jid, self.publish_message, self.unpause_connectors, connection_lost_callback=self.connection_lost) self.xmpp_protocol.setHandlerParent(self.xmpp_client) self.log.msg("XMPPTransport %s started." % self.transport_name) def connection_lost(self, reason): self.log.msg("XMPP Connection lost. %s" % reason) def announce_presence(self): if not self.presence_call.running: self.presence_call.start(self.presence_interval) @inlineCallbacks def send_ping(self): if self.xmpp_client.xmlstream: yield self.pinger.ping(self.jid) def send_presence(self): if self.xmpp_client.xmlstream: self.presence.available(statuses={ None: self.status}) def teardown_transport(self): self.log.msg("XMPPTransport %s stopped." % self.transport_name) ping_call = getattr(self, 'ping_call', None) if ping_call and ping_call.running: ping_call.stop() presence_call = getattr(self, 'presence_call', None) if presence_call and presence_call.running: presence_call.stop() def handle_outbound_message(self, message): recipient = message['to_addr'] text = message['content'] jid = JID(recipient).userhost() if not self.xmpp_protocol.xmlstream: self.log.err("Outbound undeliverable, XMPP not initialized yet.") return False else: self.xmpp_protocol.reply(jid, text) return self.publish_ack( user_message_id=message['message_id'], sent_message_id=message['message_id']) PK=JGΐrpp vumi/transports/xmpp/__init__.py""" Vumi XMPP transport. 
""" from vumi.transports.xmpp.xmpp import XMPPTransport __all__ = ['XMPPTransport'] PKh^xG^dV3'vumi/transports/xmpp/tests/test_xmpp.pyfrom twisted.python import log from twisted.internet.defer import inlineCallbacks, returnValue from twisted.internet.task import Clock from twisted.words.xish import domish from vumi.tests.helpers import VumiTestCase from vumi.tests.utils import LogCatcher from vumi.transports.xmpp.xmpp import ( XMPPTransport, XMPPClient, XMPPTransportProtocol) from vumi.transports.tests.helpers import TransportHelper class DummyXMLStream(object): def __init__(self): self.outbox = [] def send(self, message): self.outbox.append(message) def addObserver(self, event, observerfn, *args, **kwargs): """Ignore.""" class DummyXMPPClient(XMPPClient): def __init__(self, *args, **kw): XMPPClient.__init__(self, *args, **kw) self._connection = None def startService(self): pass def stopService(self): pass class DummyXMPPTransportProtocol(XMPPTransportProtocol): def __init__(self, *args, **kwargs): XMPPTransportProtocol.__init__(self, *args, **kwargs) self.xmlstream = DummyXMLStream() class TestXMPPTransport(VumiTestCase): @inlineCallbacks def mk_transport(self): self.tx_helper = self.add_helper(TransportHelper(XMPPTransport)) transport = yield self.tx_helper.get_transport({ 'username': 'user@xmpp.domain.com', 'password': 'testing password', 'status': 'chat', 'status_message': 'XMPP Transport', 'host': 'xmpp.domain.com', 'port': 5222, 'transport_type': 'xmpp', }, start=False) transport._xmpp_protocol = DummyXMPPTransportProtocol transport._xmpp_client = DummyXMPPClient transport.ping_call.clock = Clock() transport.presence_call.clock = Clock() yield transport.startWorker() yield transport.xmpp_protocol.connectionMade() self.jid = transport.jid returnValue(transport) def assert_ack(self, ack, reply): self.assertEqual(ack.payload['event_type'], 'ack') self.assertEqual(ack.payload['user_message_id'], reply['message_id']) self.assertEqual(ack.payload['sent_message_id'], reply['message_id']) @inlineCallbacks def test_outbound_message(self): transport = yield self.mk_transport() msg = yield self.tx_helper.make_dispatch_outbound( "hi", to_addr='user@xmpp.domain.com', from_addr='test@case.com') xmlstream = transport.xmpp_protocol.xmlstream self.assertEqual(len(xmlstream.outbox), 1) message = xmlstream.outbox[0] self.assertEqual(message['to'], 'user@xmpp.domain.com') self.assertTrue(message['id']) self.assertEqual(str(message.children[0]), 'hi') [ack] = yield self.tx_helper.wait_for_dispatched_events(1) self.assert_ack(ack, msg) @inlineCallbacks def test_inbound_message(self): transport = yield self.mk_transport() message = domish.Element((None, "message")) message['to'] = self.jid.userhost() message['from'] = 'test@case.com' message.addUniqueId() message.addElement((None, 'body'), content='hello world') protocol = transport.xmpp_protocol protocol.onMessage(message) [msg] = yield self.tx_helper.wait_for_dispatched_inbound() self.assertEqual(msg['to_addr'], self.jid.userhost()) self.assertEqual(msg['from_addr'], 'test@case.com') self.assertEqual(msg['transport_name'], self.tx_helper.transport_name) self.assertNotEqual(msg['message_id'], message['id']) self.assertEqual(msg['transport_metadata']['xmpp_id'], message['id']) self.assertEqual(msg['content'], 'hello world') @inlineCallbacks def test_message_without_id(self): transport = yield self.mk_transport() message = domish.Element((None, "message")) message['to'] = self.jid.userhost() message['from'] = 'test@case.com' message.addElement((None, 
'body'), content='hello world') self.assertFalse(message.hasAttribute('id')) protocol = transport.xmpp_protocol protocol.onMessage(message) [msg] = yield self.tx_helper.wait_for_dispatched_inbound() self.assertTrue(msg['message_id']) self.assertEqual(msg['transport_metadata']['xmpp_id'], None) @inlineCallbacks def test_pinger(self): """ The transport's pinger should send a ping after the ping_interval. """ transport = yield self.mk_transport() self.assertEqual(transport.ping_interval, 60) # The LoopingCall should be configured and started. self.assertEqual(transport.ping_call.f, transport.send_ping) self.assertEqual(transport.ping_call.a, ()) self.assertEqual(transport.ping_call.kw, {}) self.assertEqual(transport.ping_call.interval, 60) self.assertTrue(transport.ping_call.running) # Stub output stream xmlstream = DummyXMLStream() transport.xmpp_client.xmlstream = xmlstream transport.pinger.xmlstream = xmlstream # Ping transport.ping_call.clock.advance(59) self.assertEqual(xmlstream.outbox, []) transport.ping_call.clock.advance(2) self.assertEqual(len(xmlstream.outbox), 1, repr(xmlstream.outbox)) [message] = xmlstream.outbox self.assertEqual(message['to'], u'user@xmpp.domain.com') self.assertEqual(message['type'], u'get') [child] = message.children self.assertEqual(child.toXml(), u"") @inlineCallbacks def test_presence(self): """ The transport's presence should be announced regularly. """ transport = yield self.mk_transport() self.assertEqual(transport.presence_interval, 60) # The LoopingCall should be configured and started. self.assertFalse(transport.presence_call.running) # Stub output stream xmlstream = DummyXMLStream() transport.xmpp_client.xmlstream = xmlstream transport.xmpp_client._initialized = True transport.presence.xmlstream = xmlstream self.assertEqual(xmlstream.outbox, []) transport.presence.connectionInitialized() transport.presence_call.clock.advance(1) self.assertEqual(len(xmlstream.outbox), 1, repr(xmlstream.outbox)) self.assertEqual(transport.presence_call.f, transport.send_presence) self.assertEqual(transport.presence_call.a, ()) self.assertEqual(transport.presence_call.kw, {}) self.assertEqual(transport.presence_call.interval, 60) self.assertTrue(transport.presence_call.running) [presence] = xmlstream.outbox self.assertEqual( presence.toXml(), u"chat") @inlineCallbacks def test_normalizing_from_addr(self): transport = yield self.mk_transport() message = domish.Element((None, "message")) message['to'] = self.jid.userhost() message['from'] = 'test@case.com/some_xmpp_id' message.addUniqueId() message.addElement((None, 'body'), content='hello world') protocol = transport.xmpp_protocol protocol.onMessage(message) [msg] = yield self.tx_helper.wait_for_dispatched_inbound() self.assertEqual(msg['from_addr'], 'test@case.com') self.assertEqual(msg['transport_metadata']['xmpp_id'], message['id']) @inlineCallbacks def test_xmpp_connection_lost(self): '''When the XMPP connection is lost, the connection callback should be called, which should log that the connection was lost.''' transport = yield self.mk_transport() with LogCatcher() as lc: yield transport.xmpp_protocol.connectionLost('Test connection') [connection_lost_log] = lc.logs self.assertEqual( log.textFromEventDict(connection_lost_log), 'XMPP Connection lost. 
Test connection') PK=JG&vumi/transports/xmpp/tests/__init__.pyPK=JGC55&vumi/transports/tests/test_failures.pyimport time import json from datetime import datetime, timedelta from twisted.internet.defer import inlineCallbacks from vumi.message import Message from vumi.transports.failures import FailureWorker from vumi.tests.helpers import VumiTestCase, PersistenceHelper, WorkerHelper def mktimestamp(delta=0): timestamp = datetime.utcnow() + timedelta(seconds=delta) return timestamp.isoformat().split('.')[0] class TestFailureWorker(VumiTestCase): def setUp(self): self.persistence_helper = self.add_helper(PersistenceHelper()) return self.make_worker() @inlineCallbacks def make_worker(self, retry_delivery_period=0): self.worker_helper = self.add_helper(WorkerHelper('sphex')) config = self.persistence_helper.mk_config({ 'transport_name': 'sphex', 'retry_routing_key': 'sms.outbound.%(transport_name)s', 'failures_routing_key': 'sms.failures.%(transport_name)s', 'retry_delivery_period': retry_delivery_period, }) self.worker = yield self.worker_helper.get_worker( FailureWorker, config) self.redis = self.worker.redis yield self.redis._purge_all() # Just in case def assert_write_timestamp(self, expected, delta, now): self.assertEqual(expected, self.worker.get_next_write_timestamp(delta, now=now)) @inlineCallbacks def assert_zcard(self, expected, key): self.assertEqual(expected, (yield self.redis.zcard(key))) @inlineCallbacks def assert_equal_d(self, expected, value): self.assertEqual((yield expected), (yield value)) @inlineCallbacks def assert_not_equal_d(self, expected, value): self.assertNotEqual((yield expected), (yield value)) @inlineCallbacks def assert_get_retry_key(self, exists=True): retry_key = yield self.worker.get_next_retry_key() if exists: self.assertNotEqual(None, retry_key) else: self.assertEqual(None, retry_key) @inlineCallbacks def assert_stored_timestamps(self, *expected): timestamps = yield self.redis.zrange('retry_timestamps', 0, -1) self.assertEqual(list(expected), timestamps) def assert_published_retries(self, expected): msgs = self.worker_helper.get_dispatched( 'sms.outbound', 'sphex', Message) self.assertEqual(expected, [m.payload for m in msgs]) def store_failure(self, reason=None, message=None): if not reason: reason = "bad stuff happened" if not message: message = {'message': 'foo', 'reason': reason} return self.worker.store_failure(message, reason) @inlineCallbacks def store_retry(self, retry_delay=0, now_delta=0, reason=None, message_json=None): key = yield self.store_failure(reason, message_json) now = time.time() + now_delta yield self.worker.store_retry(key, retry_delay, now=now) @inlineCallbacks def test_redis_access(self): """ Sanity check that we can put stuff in redis (or our fake) and get it out again. """ def r_get(key): return self.worker.redis.get(key) yield self.assert_equal_d(None, r_get("foo")) yield self.assert_equal_d([], self.redis.keys()) yield self.worker.redis.set("foo", "bar") yield self.assert_equal_d("bar", r_get("foo")) yield self.assert_equal_d(['foo'], self.redis.keys()) @inlineCallbacks def test_store_failure(self): """ Store a failure in redis and make sure we can get at it again. 
""" key = yield self.store_failure(reason="reason") yield self.assert_equal_d(set([key]), self.worker.get_failure_keys()) message_json = json.dumps({"message": "foo", "reason": "reason"}) yield self.assert_equal_d({ "message": message_json, "retry_delay": "0", "reason": "reason", }, self.redis.hgetall(key)) # Test a second one, this time with a JSON-encoded message key2 = yield self.store_failure( message=json.dumps({"foo": "bar"}), reason="reason") yield self.assert_equal_d( set([key, key2]), self.worker.get_failure_keys()) message_json = json.dumps({"foo": "bar"}) yield self.assert_equal_d({ "message": message_json, "retry_delay": "0", "reason": "reason", }, self.redis.hgetall(key2)) def test_write_timestamp(self): """ We need granular timestamps. """ start = datetime.utcnow().isoformat() timestamp = self.worker.get_next_write_timestamp(0) end = (datetime.utcnow() + timedelta(seconds=6)).isoformat() self.assertTrue(start < timestamp < end) self.assert_write_timestamp("1970-01-01T00:00:05", 0, 0) self.assert_write_timestamp("1970-01-01T00:00:05", 0, 4) self.assert_write_timestamp("1970-01-01T00:00:10", 0, 5) self.assert_write_timestamp("1970-01-01T00:00:10", 0, 9) self.assert_write_timestamp("1970-01-01T00:00:05", 2, 0) self.assert_write_timestamp("1970-01-01T00:00:10", 2, 4) self.assert_write_timestamp("1970-01-01T00:00:10", 2, 5) self.assert_write_timestamp("1970-01-01T00:00:15", 2, 9) self.assert_write_timestamp("1970-01-01T00:03:25", 101, 100) self.assert_write_timestamp("1970-01-01T00:03:30", 101, 104) self.assert_write_timestamp("1970-01-01T00:03:30", 101, 105) self.assert_write_timestamp("1970-01-01T00:03:35", 101, 109) def test_write_timestamp_granularity(self): """ We need granular timestamps with assorted granularity. """ self.worker.GRANULARITY = 10 self.assert_write_timestamp("1970-01-01T00:00:10", 0, 0) self.assert_write_timestamp("1970-01-01T00:00:10", 0, 4) self.assert_write_timestamp("1970-01-01T00:00:10", 0, 5) self.assert_write_timestamp("1970-01-01T00:00:10", 0, 9) self.assert_write_timestamp("1970-01-01T00:00:20", 0, 11) self.assert_write_timestamp("1970-01-01T00:00:20", 12, 0) self.assert_write_timestamp("1970-01-01T00:00:20", 12, 4) self.assert_write_timestamp("1970-01-01T00:00:20", 12, 5) self.assert_write_timestamp("1970-01-01T00:00:30", 12, 9) self.assert_write_timestamp("1970-01-01T00:00:30", 12, 11) self.worker.GRANULARITY = 3 self.assert_write_timestamp("1970-01-01T00:00:03", 0, 0) self.assert_write_timestamp("1970-01-01T00:00:06", 0, 4) self.assert_write_timestamp("1970-01-01T00:00:06", 0, 5) self.assert_write_timestamp("1970-01-01T00:00:12", 0, 9) self.assert_write_timestamp("1970-01-01T00:00:12", 0, 11) self.assert_write_timestamp("1970-01-01T00:00:15", 12, 0) self.assert_write_timestamp("1970-01-01T00:00:18", 12, 4) self.assert_write_timestamp("1970-01-01T00:00:18", 12, 5) self.assert_write_timestamp("1970-01-01T00:00:24", 12, 9) self.assert_write_timestamp("1970-01-01T00:00:24", 12, 11) @inlineCallbacks def test_store_read_timestamp(self): """ We need to store granular timestamps. 
""" timestamp1 = "1977-07-28T12:34:56" timestamp2 = "1980-07-30T12:34:56" timestamp3 = "1980-09-02T12:34:56" yield self.assert_stored_timestamps() yield self.worker.store_read_timestamp(timestamp2) yield self.assert_stored_timestamps(timestamp2) yield self.worker.store_read_timestamp(timestamp1) yield self.assert_stored_timestamps(timestamp1, timestamp2) yield self.worker.store_read_timestamp(timestamp3) yield self.assert_stored_timestamps(timestamp1, timestamp2, timestamp3) @inlineCallbacks def test_read_timestamp(self): """ We need to read the next timestamp. """ past = mktimestamp(-10) future = mktimestamp(10) yield self.assert_equal_d(None, self.worker.get_next_read_timestamp()) yield self.worker.store_read_timestamp(future) yield self.assert_equal_d(None, self.worker.get_next_read_timestamp()) yield self.worker.store_read_timestamp(past) yield self.assert_equal_d(past, self.worker.get_next_read_timestamp()) @inlineCallbacks def test_store_retry(self): """ Store a retry in redis and make sure we can get at it again. """ timestamp = "1970-01-01T00:00:05" retry_key = "retry_keys." + timestamp key = yield self.store_failure() yield self.assert_zcard(0, 'retry_timestamps') yield self.worker.store_retry(key, 0, now=0) yield self.assert_zcard(1, 'retry_timestamps') yield self.assert_equal_d([timestamp], self.redis.zrange('retry_timestamps', 0, 0)) yield self.assert_equal_d(set([key]), self.redis.smembers(retry_key)) def test_get_retry_key_none(self): """ If there are no stored retries, get None. """ return self.assert_get_retry_key(False) @inlineCallbacks def test_get_retry_key_future(self): """ If there are no retries due, get None. """ yield self.store_retry(10) yield self.assert_zcard(1, 'retry_timestamps') yield self.assert_get_retry_key(False) yield self.assert_zcard(1, 'retry_timestamps') @inlineCallbacks def test_get_retry_key_one_due(self): """ Get a retry from redis when we have one due. """ yield self.store_retry(0, -5) yield self.assert_zcard(1, 'retry_timestamps') yield self.assert_get_retry_key() yield self.assert_zcard(0, 'retry_timestamps') yield self.assert_get_retry_key(False) @inlineCallbacks def test_get_retry_key_two_due(self): """ Get a retry from redis when we have two due. """ yield self.store_retry(0, -5) yield self.store_retry(0, -5) yield self.assert_zcard(1, 'retry_timestamps') yield self.assert_get_retry_key() yield self.assert_zcard(1, 'retry_timestamps') @inlineCallbacks def test_get_retry_key_two_due_different_times(self): """ Get a retry from redis when we have two due at different times. """ yield self.store_retry(0, -5) yield self.store_retry(0, -15) yield self.assert_zcard(2, 'retry_timestamps') yield self.assert_get_retry_key() yield self.assert_zcard(1, 'retry_timestamps') yield self.assert_get_retry_key() yield self.assert_zcard(0, 'retry_timestamps') @inlineCallbacks def test_get_retry_key_one_due_one_future(self): """ Get a retry from redis when we have one due and one in the future. """ yield self.store_retry(0, -5) yield self.worker.store_retry(self.store_failure(), 0) yield self.assert_zcard(2, 'retry_timestamps') yield self.assert_get_retry_key() yield self.assert_zcard(1, 'retry_timestamps') yield self.assert_get_retry_key(False) yield self.assert_zcard(1, 'retry_timestamps') @inlineCallbacks def test_deliver_retries_none(self): """ Delivering no retries should do nothing. 
""" yield self.worker.deliver_retries() self.assert_published_retries([]) @inlineCallbacks def test_deliver_retries_future(self): """ Delivering no current retries should do nothing. """ yield self.worker.store_retry(self.store_failure(), 0) yield self.worker.deliver_retries() self.assert_published_retries([]) @inlineCallbacks def test_deliver_retries_one_due(self): """ Delivering a current retry should deliver one message. """ yield self.store_retry(0, -5) yield self.worker.deliver_retries() self.assert_published_retries([{ 'message': 'foo', 'reason': 'bad stuff happened', }]) @inlineCallbacks def test_deliver_retries_many_due(self): """ Delivering current retries should deliver all messages. """ yield self.store_retry(0, -5) yield self.store_retry(0, -15) yield self.store_retry(0, -5) yield self.worker.deliver_retries() self.assert_published_retries([{ 'message': 'foo', 'reason': 'bad stuff happened', }] * 3) def test_update_retry_metadata(self): """ Retry metadata should be updated as appropriate. """ def mkmsg(retries, delay): return {'retry_metadata': {'retries': retries, 'delay': delay}} def assert_update_rmd(retries, delay, msg): msg = self.worker.update_retry_metadata(msg) self.assertEqual({'retries': retries, 'delay': delay}, msg['retry_metadata']) assert_update_rmd(1, 1, {}) assert_update_rmd(2, 3, mkmsg(1, 1)) assert_update_rmd(3, 9, mkmsg(2, 3)) @inlineCallbacks def test_start_retrying(self): """ The retry publisher should start when configured appropriately. """ self.assertEqual(None, self.worker.delivery_loop) yield self.worker.stopWorker() yield self.make_worker(1) self.assertEqual(self.worker.deliver_retries, self.worker.delivery_loop.f) self.assertTrue(self.worker.delivery_loop.running) PK=JG~[ vumi/transports/tests/utils.pyfrom twisted.internet.defer import inlineCallbacks from vumi.tests.utils import VumiWorkerTestCase, PersistenceMixin class TransportTestCase(VumiWorkerTestCase, PersistenceMixin): """ This is a base class for testing transports. """ transport_class = None def setUp(self): self._persist_setUp() super(TransportTestCase, self).setUp() @inlineCallbacks def tearDown(self): yield super(TransportTestCase, self).tearDown() yield self._persist_tearDown() def get_transport(self, config, cls=None, start=True): """ Get an instance of a transport class. :param config: Config dict. :param cls: The transport class to instantiate. Defaults to :attr:`transport_class` :param start: True to start the transport (default), False otherwise. 
Some default config values are helpfully provided in the interests of reducing boilerplate: * ``transport_name`` defaults to :attr:`self.transport_name` """ if cls is None: cls = self.transport_class config = self.mk_config(config) config.setdefault('transport_name', self.transport_name) return self.get_worker(config, cls, start) def mkmsg_in(self, *args, **kw): msg = super(TransportTestCase, self).mkmsg_in(*args, **kw) return self._make_matcher(msg) def mkmsg_ack(self, *args, **kw): msg = super(TransportTestCase, self).mkmsg_ack(*args, **kw) return self._make_matcher(msg, 'event_id') def mkmsg_delivery(self, *args, **kw): msg = super(TransportTestCase, self).mkmsg_delivery(*args, **kw) return self._make_matcher(msg, 'event_id') def get_dispatched_messages(self): return self.get_dispatched_inbound() def wait_for_dispatched_messages(self, amount): return self.wait_for_dispatched_inbound(amount) def clear_dispatched_messages(self): return self.clear_dispatched_inbound() def dispatch(self, message, rkey=None, exchange='vumi'): if rkey is None: rkey = self.rkey('outbound') return self._dispatch(message, rkey, exchange) PKqGrr"vumi/transports/tests/test_base.pyfrom twisted.internet.defer import inlineCallbacks from vumi.tests.helpers import VumiTestCase from vumi.transports.base import Transport from vumi.transports.tests.helpers import TransportHelper from vumi.tests.utils import LogCatcher class TestBaseTransport(VumiTestCase): TEST_MIDDLEWARE_CONFIG = { "middleware": [ {"mw1": "vumi.middleware.tests.utils.RecordingMiddleware"}, {"mw2": "vumi.middleware.tests.utils.RecordingMiddleware"}, ], } def setUp(self): self.tx_helper = self.add_helper(TransportHelper(Transport)) @inlineCallbacks def test_start_transport(self): tr = yield self.tx_helper.get_transport({}) self.assertEqual(self.tx_helper.transport_name, tr.transport_name) self.assertTrue(len(tr.connectors) >= 1) connector = tr.connectors[tr.transport_name] self.assertEqual(set(connector._consumers.keys()), set(['outbound'])) self.assertEqual(set(connector._publishers.keys()), set(['inbound', 'event'])) self.assertEqual(tr.failure_publisher.routing_key, '%s.failures' % (tr.transport_name,)) @inlineCallbacks def test_middleware_for_inbound_messages(self): transport = yield self.tx_helper.get_transport( self.TEST_MIDDLEWARE_CONFIG) orig_msg = self.tx_helper.make_inbound("inbound") yield transport.publish_message(**orig_msg.payload) [msg] = self.tx_helper.get_dispatched_inbound() self.assertEqual(msg['record'], [ ['mw2', 'inbound', self.tx_helper.transport_name], ['mw1', 'inbound', self.tx_helper.transport_name], ]) @inlineCallbacks def test_middleware_for_events(self): transport = yield self.tx_helper.get_transport( self.TEST_MIDDLEWARE_CONFIG) orig_msg = self.tx_helper.make_ack() yield transport.publish_event(**orig_msg.payload) [msg] = self.tx_helper.get_dispatched_events() self.assertEqual(msg['record'], [ ['mw2', 'event', self.tx_helper.transport_name], ['mw1', 'event', self.tx_helper.transport_name], ]) @inlineCallbacks def test_middleware_for_failures(self): transport = yield self.tx_helper.get_transport( self.TEST_MIDDLEWARE_CONFIG) orig_msg = self.tx_helper.make_outbound("outbound") yield transport.send_failure(orig_msg, ValueError(), "dummy_traceback") [msg] = self.tx_helper.get_dispatched_failures() self.assertEqual(msg['record'], [ ['mw2', 'failure', self.tx_helper.transport_name], ['mw1', 'failure', self.tx_helper.transport_name], ]) @inlineCallbacks def test_middleware_for_outbound_messages(self): msgs = [] transport = yield
self.tx_helper.get_transport( self.TEST_MIDDLEWARE_CONFIG) transport.add_outbound_handler(msgs.append) yield self.tx_helper.make_dispatch_outbound("outbound") [msg] = msgs self.assertEqual(msg['record'], [ ('mw1', 'outbound', self.tx_helper.transport_name), ('mw2', 'outbound', self.tx_helper.transport_name), ]) def get_tx_consumers(self, tx): for connector in tx.connectors.values(): for consumer in connector._consumers.values(): yield consumer @inlineCallbacks def test_transport_prefetch_count_custom(self): transport = yield self.tx_helper.get_transport({ 'amqp_prefetch_count': 1, }) consumers = list(self.get_tx_consumers(transport)) self.assertEqual(1, len(consumers)) for consumer in consumers: fake_channel = consumer.channel._fake_channel self.assertEqual(fake_channel.qos_prefetch_count, 1) @inlineCallbacks def test_transport_prefetch_count_default(self): transport = yield self.tx_helper.get_transport({}) consumers = list(self.get_tx_consumers(transport)) self.assertEqual(1, len(consumers)) for consumer in consumers: fake_channel = consumer.channel._fake_channel self.assertEqual(fake_channel.qos_prefetch_count, 20) @inlineCallbacks def test_add_outbound_handler(self): transport = yield self.tx_helper.get_transport({}) msgs = [] msg = transport.add_outbound_handler(msgs.append, endpoint_name='foo') msg = yield self.tx_helper.make_dispatch_outbound( "outbound", endpoint='foo') self.assertEqual(msgs, [msg]) @inlineCallbacks def test_publish_status(self): transport = yield self.tx_helper.get_transport({ 'transport_name': 'foo', 'publish_status': True }) msg = yield transport.publish_status( status='down', component='foo', type='bar', message='baz') self.assertEqual(msg['status'], 'down') self.assertEqual(msg['component'], 'foo') self.assertEqual(msg['type'], 'bar') self.assertEqual(msg['message'], 'baz') msgs = self.tx_helper.get_dispatched_statuses('foo.status') self.assertEqual(msgs, [msg]) @inlineCallbacks def test_publish_status_disabled(self): transport = yield self.tx_helper.get_transport({ 'transport_name': 'foo', 'worker_name': 'foo', 'publish_status': False }) with LogCatcher() as lc: msg = yield transport.publish_status( status='down', component='foo', type='bar', message='baz') self.assertEqual(msg['status'], 'down') self.assertEqual(msg['component'], 'foo') self.assertEqual(msg['type'], 'bar') self.assertEqual(msg['message'], 'baz') msgs = self.tx_helper.get_dispatched_statuses('foo.status') self.assertEqual(msgs, []) [log] = lc.logs self.assertEqual( log['message'][0], "Status publishing disabled for transport 'foo', " "ignoring status %r" % (msg,)) self.assertEqual(log['system'], 'foo') PKqGɈ*vumi/transports/tests/test_test_helpers.pyfrom twisted.internet.defer import inlineCallbacks from vumi.transports.base import Transport from vumi.transports.failures import FailureMessage from vumi.transports.tests.helpers import TransportHelper from vumi.tests.helpers import ( VumiTestCase, IHelper, PersistenceHelper, MessageHelper, WorkerHelper, MessageDispatchHelper, success_result_of) class RunningCheckTransport(Transport): tx_worker_running = False def setup_transport(self): self.tx_worker_running = True def teardown_transport(self): self.tx_worker_running = False class FakeCleanupCheckHelper(object): cleaned_up = False def cleanup(self): self.cleaned_up = True class TestTransportHelper(VumiTestCase): def test_implements_IHelper(self): """ TransportHelper instances should provide the IHelper interface. 
""" self.assertTrue(IHelper.providedBy(TransportHelper(None))) def test_defaults(self): """ TransportHelper instances should have the expected parameter defaults. """ fake_tx_class = object() tx_helper = TransportHelper(fake_tx_class) self.assertEqual(tx_helper.transport_class, fake_tx_class) self.assertIsInstance(tx_helper.persistence_helper, PersistenceHelper) self.assertIsInstance(tx_helper.msg_helper, MessageHelper) self.assertIsInstance(tx_helper.worker_helper, WorkerHelper) dispatch_helper = tx_helper.dispatch_helper self.assertIsInstance(dispatch_helper, MessageDispatchHelper) self.assertEqual(dispatch_helper.msg_helper, tx_helper.msg_helper) self.assertEqual( dispatch_helper.worker_helper, tx_helper.worker_helper) self.assertEqual(tx_helper.persistence_helper.use_riak, False) def test_all_params(self): """ TransportHelper should pass use_riak to its PersistenceHelper and all other params to its MessageHelper. """ fake_tx_class = object() tx_helper = TransportHelper( fake_tx_class, use_riak=True, transport_addr='Obs station') self.assertEqual(tx_helper.persistence_helper.use_riak, True) self.assertEqual(tx_helper.msg_helper.transport_addr, 'Obs station') def test_setup_sync(self): """ TransportHelper.setup() should return ``None``, not a Deferred. """ msg_helper = TransportHelper(None) self.add_cleanup(msg_helper.cleanup) self.assertEqual(msg_helper.setup(), None) def test_cleanup(self): """ TransportHelper.cleanup() should call .cleanup() on its PersistenceHelper and WorkerHelper. """ tx_helper = TransportHelper(None) tx_helper.persistence_helper = FakeCleanupCheckHelper() tx_helper.worker_helper = FakeCleanupCheckHelper() self.assertEqual(tx_helper.persistence_helper.cleaned_up, False) self.assertEqual(tx_helper.worker_helper.cleaned_up, False) success_result_of(tx_helper.cleanup()) self.assertEqual(tx_helper.persistence_helper.cleaned_up, True) self.assertEqual(tx_helper.worker_helper.cleaned_up, True) @inlineCallbacks def test_get_transport_defaults(self): """ .get_transport() should return a started transport worker. """ tx_helper = self.add_helper(TransportHelper(RunningCheckTransport)) app = yield tx_helper.get_transport({}) self.assertIsInstance(app, RunningCheckTransport) self.assertEqual(app.tx_worker_running, True) @inlineCallbacks def test_get_transport_no_start(self): """ .get_transport() should return an unstarted transport worker if passed ``start=False``. """ tx_helper = self.add_helper(TransportHelper(RunningCheckTransport)) app = yield tx_helper.get_transport({}, start=False) self.assertIsInstance(app, RunningCheckTransport) self.assertEqual(app.tx_worker_running, False) @inlineCallbacks def test_get_application_different_class(self): """ .get_transport() should return an instance of the specified worker class if one is provided. """ tx_helper = self.add_helper(TransportHelper(Transport)) app = yield tx_helper.get_transport({}, cls=RunningCheckTransport) self.assertIsInstance(app, RunningCheckTransport) def _add_to_dispatched(self, broker, rkey, msg): broker.exchange_declare('vumi', 'direct', durable=True) broker.publish_message('vumi', rkey, msg) def test_get_dispatched_failures(self): """ .get_dispatched_failures() should get failures dispatched by the transport. 
""" tx_helper = TransportHelper(Transport) dispatched = tx_helper.get_dispatched_failures('fooconn') self.assertEqual(dispatched, []) msg = FailureMessage( message=tx_helper.msg_helper.make_outbound('foo').payload, failure_code=FailureMessage.FC_UNSPECIFIED, reason='sadness') self._add_to_dispatched( tx_helper.worker_helper.broker, 'fooconn.failures', msg) dispatched = tx_helper.get_dispatched_failures('fooconn') self.assertEqual(dispatched, [msg]) def test_get_dispatched_failures_no_connector(self): """ .get_dispatched_failures() should use the default connector if none is passed in. """ tx_helper = TransportHelper(Transport, transport_name='fooconn') dispatched = tx_helper.get_dispatched_failures() self.assertEqual(dispatched, []) msg = FailureMessage( message=tx_helper.msg_helper.make_outbound('foo').payload, failure_code=FailureMessage.FC_UNSPECIFIED, reason='sadness') self._add_to_dispatched( tx_helper.worker_helper.broker, 'fooconn.failures', msg) dispatched = tx_helper.get_dispatched_failures() self.assertEqual(dispatched, [msg]) PKrgTGԊ vumi/transports/tests/helpers.pyfrom twisted.internet.defer import inlineCallbacks from zope.interface import implements from vumi.transports.failures import FailureMessage from vumi.tests.helpers import ( MessageHelper, WorkerHelper, PersistenceHelper, MessageDispatchHelper, generate_proxies, IHelper, ) class TransportHelper(object): """ Test helper for transport workers. This helper construct and wraps several lower-level helpers and provides higher-level functionality for transport tests. :param transport_class: The worker class for the transport being tested. :param bool use_riak: Set to ``True`` if the test requires Riak. This is passed to the underlying :class:`~vumi.tests.helpers.PersistenceHelper`. :param \**msg_helper_args: All other keyword params are passed to the underlying :class:`~vumi.tests.helpers.MessageHelper`. """ implements(IHelper) def __init__(self, transport_class, use_riak=False, **msg_helper_args): self.transport_class = transport_class self.persistence_helper = PersistenceHelper(use_riak=use_riak) self.msg_helper = MessageHelper(**msg_helper_args) self.transport_name = self.msg_helper.transport_name self.worker_helper = WorkerHelper( connector_name=self.transport_name, status_connector_name="%s.status" % (self.transport_name,)) self.dispatch_helper = MessageDispatchHelper( self.msg_helper, self.worker_helper) # Proxy methods from our helpers. generate_proxies(self, self.msg_helper) generate_proxies(self, self.worker_helper) generate_proxies(self, self.dispatch_helper) generate_proxies(self, self.persistence_helper) def setup(self): self.persistence_helper.setup() self.worker_helper.setup() @inlineCallbacks def cleanup(self): yield self.worker_helper.cleanup() yield self.persistence_helper.cleanup() def get_transport(self, config, cls=None, start=True): """ Get an instance of a transport class. :param config: Config dict. :param cls: The transport class to instantiate. Defaults to :attr:`transport_class` :param start: True to start the transport (default), False otherwise. Some default config values are helpfully provided in the interests of reducing boilerplate: * ``transport_name`` defaults to :attr:`self.transport_name` """ if cls is None: cls = self.transport_class config = self.mk_config(config) config.setdefault('transport_name', self.transport_name) return self.get_worker(cls, config, start) def get_dispatched_failures(self, connector_name=None): """ Get failures dispatched by a transport. 
:param str connector_name: Connector name. If ``None``, the default connector name for the helper instance will be used. :returns: A list of :class:`~vumi.transports.failures.FailureMessage` instances. """ return self.get_dispatched(connector_name, 'failures', FailureMessage) PK=JG!vumi/transports/tests/__init__.pyPK=JG!^'vumi/transports/tests/test_scheduler.pyimport time from datetime import datetime from twisted.internet.defer import inlineCallbacks from vumi.persist.fake_redis import FakeRedis from vumi.transports.scheduler import Scheduler from vumi.message import TransportUserMessage from vumi.utils import to_kwargs from vumi.tests.helpers import VumiTestCase, MessageHelper class TestScheduler(VumiTestCase): def setUp(self): self.r_server = FakeRedis() self.scheduler = Scheduler(self.r_server, self._scheduler_callback) self.add_cleanup(self.stop_scheduler) self._delivery_history = [] self.msg_helper = self.add_helper(MessageHelper()) def stop_scheduler(self): if self.scheduler.is_running: self.scheduler.stop() def _scheduler_callback(self, scheduled_at, message): self._delivery_history.append((scheduled_at, message)) return (scheduled_at, message) def assertDelivered(self, message): delivered_messages = [TransportUserMessage(**to_kwargs(payload)) for _, payload in self._delivery_history] self.assertIn(message['message_id'], [msg['message_id'] for msg in delivered_messages]) def assertNumDelivered(self, number): self.assertEqual(number, len(self._delivery_history)) def get_pending_messages(self): scheduled_timestamps = self.scheduler.r_key('scheduled_timestamps') return self.r_server.zrange(scheduled_timestamps, 0, -1) def test_scheduling(self): msg = self.msg_helper.make_inbound("inbound") now = time.mktime(datetime(2012, 1, 1).timetuple()) delta = 10 # seconds from now key, bucket_key = self.scheduler.schedule(delta, msg.payload, now) self.assertEqual(bucket_key, '%s#%s.%s' % ( self.scheduler.r_prefix, 'scheduled_keys', self.scheduler.get_next_write_timestamp(delta, now) )) scheduled_key = self.scheduler.get_scheduled_key(now) self.assertEqual(scheduled_key, None) scheduled_time = now + delta scheduled_key = self.scheduler.get_scheduled_key(scheduled_time) self.assertTrue(scheduled_key) self.assertEqual(set([scheduled_key]), self.scheduler.get_all_scheduled_keys()) @inlineCallbacks def test_delivery_loop(self): msg = self.msg_helper.make_inbound("inbound") now = time.mktime(datetime(2012, 1, 1).timetuple()) delta = 16 # seconds from now self.scheduler.schedule(delta, msg.payload, now) scheduled_time = now + delta + self.scheduler.granularity yield self.scheduler.deliver_scheduled(scheduled_time) self.assertDelivered(msg) @inlineCallbacks def test_deliver_loop_future(self): now = time.mktime(datetime(2012, 1, 1).timetuple()) for i in range(0, 3): msg = self.msg_helper.make_inbound( "inbound", message_id='message_%s' % (i,)) delta = i * 10 key, _ = self.scheduler.schedule(delta, msg.payload, now) scheduled_time = now + delta + self.scheduler.granularity self.assertEqual(set([key]), self.scheduler.get_all_scheduled_keys()) yield self.scheduler.deliver_scheduled(scheduled_time) self.assertNumDelivered(i + 1) self.assertEqual(set(), self.scheduler.get_all_scheduled_keys()) @inlineCallbacks def test_deliver_ancient_messages(self): # something stuck in the queue since 1912 or scheduler hasn't # been running since 1912 msg = self.msg_helper.make_inbound("inbound") way_back = time.mktime(datetime(1912, 1, 1).timetuple()) scheduled_key, _ = self.scheduler.schedule(0, msg.payload, way_back) 
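# Delivering at the present time is expected to *leave* the ancient
# message pending: deliver_scheduled() only inspects buckets near the
# timestamp it is given, so the message is only delivered once we ask
# for a time within one granularity of when it was scheduled.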
self.assertTrue(scheduled_key) now = time.mktime(datetime.now().timetuple()) yield self.scheduler.deliver_scheduled(now) self.assertEqual(set([scheduled_key]), self.scheduler.get_all_scheduled_keys()) self.assertEqual(len(self.get_pending_messages()), 1) yield self.scheduler.deliver_scheduled( way_back + self.scheduler.granularity) self.assertDelivered(msg) self.assertEqual(self.get_pending_messages(), []) self.assertEqual(set(), self.scheduler.get_all_scheduled_keys()) @inlineCallbacks def test_clear_scheduled_messages(self): msg = self.msg_helper.make_inbound("inbound") now = time.mktime(datetime.now().timetuple()) scheduled_time = now + self.scheduler.granularity key, bucket = self.scheduler.schedule(0, msg.payload, scheduled_time) self.assertEqual(len(self.get_pending_messages()), 1) self.assertEqual(set([key]), self.scheduler.get_all_scheduled_keys()) self.scheduler.clear_scheduled(key) yield self.scheduler.deliver_scheduled() self.assertEqual(self.r_server.hgetall(key), {}) self.assertEqual(self.r_server.smembers(bucket), set()) self.assertNumDelivered(0) PK=JG#vumi/transports/smssync/__init__.py"""SMSSync (http://smssync.ushahidi.com/) transport for Android devices""" from vumi.transports.smssync.smssync import SingleSmsSync, MultiSmsSync __all__ = ['SingleSmsSync', 'MultiSmsSync'] PK=JGP"ӧ11"vumi/transports/smssync/smssync.py# -*- test-case-name: vumi.transports.smssync.tests.test_smssync -*- import json import datetime from twisted.internet.defer import inlineCallbacks from twisted.internet import reactor from vumi import log from vumi.message import TransportUserMessage from vumi.utils import normalize_msisdn from vumi.persist.txredis_manager import TxRedisManager from vumi.transports.failures import PermanentFailure from vumi.transports.httprpc import HttpRpcTransport class SmsSyncMsgInfo(object): """Holder of attributes needed to process an SMSSync message. :param str account_id: An ID for the account this message is being sent to / from. :param str smssync_secret: The shared SMSSync secret for the account this message is being sent to / from. :param str country_code: The default country_code for the account this message is being sent to / from. """ def __init__(self, account_id, smssync_secret, country_code): self.account_id = account_id self.smssync_secret = smssync_secret self.country_code = country_code class BaseSmsSyncTransport(HttpRpcTransport): """ Ushahidi SMSSync Transport for getting messages into vumi. :param str web_path: The path relative to the host where this listens :param int web_port: The port this listens on :param str transport_name: The name this transport instance will use to create its queues :param dict redis_manager: Redis client configuration. :param float reply_delay: The amount of time to wait (in seconds) for a reply message before closing the SMSSync HTTP inbound message request. Replies received within this amount of time will be included in the HTTP response (default: 0.5s).
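Replies that arrive after ``reply_delay`` has elapsed are not lost: they are queued in Redis and handed to the phone the next time it polls for outbound messages (or with the response to a later inbound message).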
""" transport_type = 'sms' # SMSSync True and False constants SMSSYNC_TRUE, SMSSYNC_FALSE = ("true", "false") SMSSYNC_DATE_FORMAT = "%m-%d-%y %H:%M" MILLISECONDS = 1000 callLater = reactor.callLater def validate_config(self): super(BaseSmsSyncTransport, self).validate_config() self._reply_delay = float(self.config.get('reply_delay', '0.5')) @inlineCallbacks def setup_transport(self): r_config = self.config.get('redis_manager', {}) self.redis = yield TxRedisManager.from_config(r_config) yield super(BaseSmsSyncTransport, self).setup_transport() @inlineCallbacks def teardown_transport(self): yield super(BaseSmsSyncTransport, self).teardown_transport() yield self.redis.close_manager() def msginfo_for_request(self, request): """Returns an :class:`SmsSyncMsgInfo` instance for this request. May return a deferred that yields the actual result to its callback. """ raise NotImplementedError("Sub-classes should implement" " msginfo_for_request") def msginfo_for_message(self, msg): """Returns an :class:`SmsSyncMsgInfo` instance for this outbound message. May return a deferred that yields the actual result to its callback. """ raise NotImplementedError("Sub-classes should implement" " msginfo_for_message") def add_msginfo_metadata(self, payload, msginfo): """Update an outbound message's payload's transport_metadata to allow msginfo to be reconstructed from replies.""" raise NotImplementedError("Sub-class should implement" " add_msginfo_metadata") def key_for_account(self, account_id): return "outbound_messages#%s" % (account_id,) @inlineCallbacks def _handle_send(self, message_id, request): msginfo = yield self.msginfo_for_request(request) if msginfo is None: log.warning("Bad account: %r (args: %r)" % (request, request.args)) yield self._send_response(message_id, success=self.SMSSYNC_FALSE) return yield self._respond_with_pending_messages( msginfo, message_id, task='send', secret=msginfo.smssync_secret) def _check_request_args(self, request, expected_keys): expected_keys = set(expected_keys) present_keys = set(request.args.keys()) return expected_keys.issubset(present_keys) def _parse_timestamp(self, request): smssync_timestamp = request.args['sent_timestamp'][0] timestamp = None if timestamp is None: try: timestamp = datetime.datetime.strptime( smssync_timestamp, self.SMSSYNC_DATE_FORMAT) except ValueError: pass if timestamp is None: try: utc_ms = int(request.args['sent_timestamp'][0]) timestamp = datetime.datetime.utcfromtimestamp( utc_ms / self.MILLISECONDS) except ValueError: pass if timestamp is None: log.warning("Bad timestamp format: %r (args: %r)" % (request, request.args)) timestamp = datetime.datetime.utcnow() return timestamp @inlineCallbacks def _handle_receive(self, message_id, request): if not self._check_request_args(request, ['secret', 'sent_timestamp', 'sent_to', 'from', 'message']): log.warning("Bad request: %r (args: %r)" % (request, request.args)) yield self._send_response(message_id, success=self.SMSSYNC_FALSE) return msginfo = yield self.msginfo_for_request(request) supplied_secret = request.args['secret'][0] if msginfo is None or (msginfo.smssync_secret and not msginfo.smssync_secret == supplied_secret): log.warning("Bad secret or account: %r (args: %r)" % (request, request.args)) yield self._send_response(message_id, success=self.SMSSYNC_FALSE) return timestamp = self._parse_timestamp(request) normalize = lambda raw: normalize_msisdn(raw, msginfo.country_code) message = { 'message_id': message_id, 'transport_type': self.transport_type, 'to_addr': 
normalize(request.args['sent_to'][0]), 'from_addr': normalize(request.args['from'][0]), 'content': request.args['message'][0], 'timestamp': timestamp, } self.add_msginfo_metadata(message, msginfo) yield self.publish_message(**message) self.callLater(self._reply_delay, self._respond_with_pending_messages, msginfo, message_id, success=self.SMSSYNC_TRUE) def _send_response(self, message_id, **kw): response = {'payload': kw} return self.finish_request(message_id, json.dumps(response)) @inlineCallbacks def _respond_with_pending_messages(self, msginfo, message_id, **kw): """Gathers pending messages and sends a response including them.""" outbound_ids = [] outbound_messages = [] account_key = self.key_for_account(msginfo.account_id) while True: msg_json = yield self.redis.lpop(account_key) if msg_json is None: break msg = TransportUserMessage.from_json(msg_json) outbound_ids.append(msg['message_id']) outbound_messages.append({'to': msg['to_addr'], 'message': msg['content'] or ''}) yield self._send_response(message_id, messages=outbound_messages, **kw) for outbound_id in outbound_ids: yield self.publish_ack(user_message_id=outbound_id, sent_message_id=outbound_id) def handle_raw_inbound_message(self, message_id, request): # This matches the dispatch logic in Ushahidi's request # handler for SMSSync. # See https://github.com/ushahidi/Ushahidi_Web/blob/ # master/plugins/smssync/controllers/smssync.php tasks = request.args.get('task') task = tasks[0] if tasks else None if task == "send": return self._handle_send(message_id, request) else: return self._handle_receive(message_id, request) @inlineCallbacks def handle_outbound_message(self, message): msginfo = yield self.msginfo_for_message(message) if msginfo is None: err_msg = ("SmsSyncTransport couldn't determine" " secret for outbound message.") yield self.publish_nack(user_message_id=message['message_id'], sent_message_id=message['message_id'], reason=err_msg) raise PermanentFailure(err_msg) else: account_key = self.key_for_account(msginfo.account_id) yield self.redis.rpush(account_key, message.to_json()) class SingleSmsSync(BaseSmsSyncTransport): """ Ushahidi SMSSync Transport for a single phone. Additional configuration options: :param str smssync_secret: Secret of the single phone (default: '', i.e. no secret set) :param str account_id: Account id for storing outbound messages under. Defaults to the `smssync_secret` which is fine unless the secret changes. :param str country_code: Default country code to use when normalizing MSISDNs sent by SMSSync (default is the empty string, which assumes numbers already include the international dialing prefix). """ def validate_config(self): super(SingleSmsSync, self).validate_config() # The secret is the empty string in the case where the single-phone # transport isn't using a secret (this fits with how the Ushahidi # plugin handles the lack of a secret).
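# For illustration only (the values below are hypothetical), a minimal
# YAML config for this transport might look like:
#
#   transport_name: smssync
#   web_path: /smssync
#   web_port: 8080
#   smssync_secret: topsecret
#   country_code: "+256"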
self._smssync_secret = self.config.get('smssync_secret', '') self._account_id = self.config.get('account_id', self._smssync_secret) self._country_code = self.config.get('country_code', '').lstrip('+') def msginfo_for_request(self, request): return SmsSyncMsgInfo(self._account_id, self._smssync_secret, self._country_code) def msginfo_for_message(self, msg): return SmsSyncMsgInfo(self._account_id, self._smssync_secret, self._country_code) def add_msginfo_metadata(self, msg, msginfo): # The single phone SMSSync transport doesn't require any # transport metadata in order to reconstruct msginfo pass class MultiSmsSync(BaseSmsSyncTransport): """ Ushahidi SMSSync Transport for multiple phones. Each phone accesses a URL that has the form `//`. A blank secret should be entered into the SMSSync `secret` field. Additional configuration options: :param dict country_codes: Map from `account_id` to the country code to use when normalizing MSISDNs sent by SMSSync to that API URL. If an `account_id` is not in this map the default is to use an empty country code string. """ def validate_config(self): super(MultiSmsSync, self).validate_config() self._country_codes = self.config.get('country_codes', {}) def _country_code(self, account_id): return self._country_codes.get(account_id, '').lstrip('+') def msginfo_for_request(self, request): pathparts = request.path.rstrip('/').split('/') if not pathparts or not pathparts[-1]: return None account_id = pathparts[-1] return SmsSyncMsgInfo(account_id, '', self._country_code(account_id)) def msginfo_for_message(self, msg): account_id = self.account_from_message(msg) if account_id is None: return None return SmsSyncMsgInfo(account_id, '', self._country_code(account_id)) def add_msginfo_metadata(self, msg, msginfo): # Record the account_id in the payload's transport metadata so # that msginfo can be reconstructed from replies. self.add_account_to_payload(msg, msginfo.account_id) @staticmethod def account_from_message(msg): return msg['transport_metadata'].get('account_id') @classmethod def add_account_to_message(cls, msg, account_id): return cls.add_account_to_payload(msg.payload, account_id) @staticmethod def add_account_to_payload(payload, account_id): transport_metadata = payload.setdefault('transport_metadata', {}) transport_metadata['account_id'] = account_id PK=JG)vumi/transports/smssync/tests/__init__.pyPK=JG""-vumi/transports/smssync/tests/test_smssync.py# -*- encoding: utf-8 -*- """Tests for SMSSync transport.""" import json import datetime from urllib import urlencode from twisted.internet.defer import inlineCallbacks from twisted.internet.task import Clock from vumi.utils import http_request from vumi.tests.helpers import VumiTestCase from vumi.transports.smssync import SingleSmsSync, MultiSmsSync from vumi.transports.smssync.smssync import SmsSyncMsgInfo from vumi.transports.failures import PermanentFailure from vumi.transports.tests.helpers import TransportHelper class TestSingleSmsSync(VumiTestCase): transport_class = SingleSmsSync account_in_url = False @inlineCallbacks def setUp(self): self.clock = Clock() self.reply_delay = 0.5 self.auto_advance_clock = True self.config = { 'web_path': "foo", 'web_port': 0, 'reply_delay': self.reply_delay, } self.add_transport_config() self.tx_helper = self.add_helper(TransportHelper(self.transport_class)) self.transport = yield self.tx_helper.get_transport(self.config) self.transport.callLater = self._dummy_call_later self.transport_url = self.transport.get_transport_url() def _dummy_call_later(self, *args, **kw):
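# Schedule the transport's delayed reply callback on the test Clock
# instead of the real reactor, so tests control the passage of time;
# when auto_advance_clock is set we immediately advance past
# reply_delay so that simple tests don't have to do it themselves.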
self.clock.callLater(*args, **kw) if self.auto_advance_clock: self.clock.advance(self.reply_delay) def add_transport_config(self): self.config["smssync_secret"] = self.smssync_secret = "secretsecret" self.config["country_code"] = self.country_code = "+27" self.config["account_id"] = self.account_id = "test_account" def smssync_inbound(self, content, from_addr='123', to_addr='555', timestamp=None, message_id='1', secret=None): """Emulate an inbound message from SMSSync on an Android phone.""" msginfo = self.default_msginfo() if timestamp is None: timestamp = datetime.datetime.utcnow() if hasattr(timestamp, 'strftime'): timestamp = timestamp.strftime("%m-%d-%y %H:%M") if secret is None: secret = msginfo.smssync_secret # Timestamp format: mm-dd-yy-hh:mm, e.g. 11-27-11-07:11 params = { 'sent_to': to_addr, 'from': from_addr, 'message': content, 'sent_timestamp': timestamp, 'message_id': message_id, 'secret': secret, } return self.smssync_call(params, method='POST') def smssync_poll(self): """Emulate a poll from SMSSync for waiting outbound messages.""" return self.smssync_call({'task': 'send'}, method='GET') def smssync_call(self, params, method): url = self.mkurl(params) d = http_request(url, '', method=method) d.addCallback(json.loads) return d def mkurl(self, params): msginfo = self.default_msginfo() params = dict((k.encode('utf-8'), v.encode('utf-8')) for k, v in params.items()) return '%s%s%s?%s' % ( self.transport_url, self.config['web_path'], ("/%s/" % msginfo.account_id) if self.account_in_url else '', urlencode(params), ) def default_msginfo(self): return SmsSyncMsgInfo(self.account_id, self.smssync_secret, self.country_code) @inlineCallbacks def test_inbound_success(self): now = datetime.datetime.utcnow().replace(second=0, microsecond=0) response = yield self.smssync_inbound(content=u'hællo', timestamp=now) self.assertEqual(response, {"payload": {"success": "true", "messages": []}}) [msg] = self.tx_helper.get_dispatched_inbound() self.assertEqual(msg['transport_name'], self.tx_helper.transport_name) self.assertEqual(msg['to_addr'], "555") self.assertEqual(msg['from_addr'], "123") self.assertEqual(msg['content'], u"hællo") self.assertEqual(msg['timestamp'], now) @inlineCallbacks def test_inbound_millisecond_timestamp(self): smssync_ms = '1377125641000' now = datetime.datetime.utcfromtimestamp(int(smssync_ms) / 1000) response = yield self.smssync_inbound(content=u'hello', timestamp=smssync_ms) self.assertEqual(response, {"payload": {"success": "true", "messages": []}}) [msg] = self.tx_helper.get_dispatched_inbound() self.assertEqual(msg['timestamp'], now) @inlineCallbacks def test_inbound_with_reply(self): self.auto_advance_clock = False now = datetime.datetime.utcnow().replace(second=0, microsecond=0) inbound_d = self.smssync_inbound(content=u'hællo', timestamp=now) [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) reply = yield self.tx_helper.make_dispatch_reply(msg, u'ræply') self.clock.advance(self.reply_delay) response = yield inbound_d self.assertEqual(response, {"payload": {"success": "true", "messages": [{ "to": reply['to_addr'], "message": u"ræply", }], }}) @inlineCallbacks def test_normalize_msisdn(self): yield self.smssync_inbound(content="hi", from_addr="0555-7171", to_addr="0555-7272") [msg] = self.tx_helper.get_dispatched_inbound() self.assertEqual(msg['from_addr'], "+275557171") self.assertEqual(msg['to_addr'], "+275557272") @inlineCallbacks def test_inbound_invalid_secret(self): response = yield self.smssync_inbound(content=u'hello', secret='wrong') if 
self.smssync_secret == '': # blank secrets should not be checked self.assertEqual(response, {"payload": {"success": "true", "messages": []}}) else: self.assertEqual(response, {"payload": {"success": "false"}}) @inlineCallbacks def test_inbound_garbage(self): response = yield self.smssync_call({}, 'GET') self.assertEqual(response, {"payload": {"success": "false"}}) @inlineCallbacks def test_poll_outbound(self): outbound_msg = self.tx_helper.make_outbound(u'hællo') msginfo = self.default_msginfo() self.transport.add_msginfo_metadata(outbound_msg.payload, msginfo) yield self.tx_helper.dispatch_outbound(outbound_msg) response = yield self.smssync_poll() self.assertEqual(response, { "payload": { "task": "send", "secret": self.smssync_secret, "messages": [{ "to": outbound_msg['to_addr'], "message": outbound_msg['content'], }, ], }, }) [event] = yield self.tx_helper.get_dispatched_events() self.assertEqual(event['event_type'], 'ack') self.assertEqual(event['user_message_id'], outbound_msg['message_id']) @inlineCallbacks def test_reply_round_trip(self): # test that calling .reply(...) generates a working reply (this is # non-trivial because the transport metadata needs to be correct for # this to work). yield self.smssync_inbound(content=u'Hi') [msg] = self.tx_helper.get_dispatched_inbound() yield self.tx_helper.make_dispatch_reply(msg, 'Hi back!') response = yield self.smssync_poll() self.assertEqual(response["payload"]["messages"], [{ "to": msg['from_addr'], "message": "Hi back!", }]) class TestMultiSmsSync(TestSingleSmsSync): transport_class = MultiSmsSync account_in_url = True def add_transport_config(self): self.account_id = "default_account_id" self.smssync_secret = "" self.country_code = "+27" self.config["country_codes"] = { self.account_id: self.country_code } @inlineCallbacks def test_nack(self): # we intentionally skip adding the msg info to force the transport # to reply with a nack msg = yield self.tx_helper.make_dispatch_outbound("hello world") [nack] = yield self.tx_helper.wait_for_dispatched_events(1) [twisted_failure] = self.flushLoggedErrors(PermanentFailure) self.assertEqual(nack['event_type'], 'nack') self.assertEqual(nack['user_message_id'], msg['message_id']) self.assertEqual(nack['nack_reason'], "SmsSyncTransport couldn't determine secret for outbound message.") PK=H++*vumi/transports/vumi_bridge/vumi_bridge.py# -*- test-case-name: vumi.transports.vumi_bridge.tests.test_vumi_bridge -*- import base64 import json import os import certifi from twisted.internet import reactor from twisted.internet.defer import inlineCallbacks from twisted.web import http from twisted.web.resource import Resource from twisted.web.server import NOT_DONE_YET from twisted.web.client import Agent from treq.client import HTTPClient from vumi.transports import Transport from vumi.config import ConfigText, ConfigDict, ConfigInt, ConfigFloat from vumi.persist.txredis_manager import TxRedisManager from vumi.message import TransportUserMessage, TransportEvent from vumi.utils import to_kwargs, StatusEdgeDetector from vumi import log class VumiBridgeTransportConfig(Transport.CONFIG_CLASS): account_key = ConfigText( 'The account key to connect with.', static=True, required=True) conversation_key = ConfigText( 'The conversation key to use.', static=True, required=True) access_token = ConfigText( 'The access token for the conversation key.', static=True, required=True) base_url = ConfigText( 'The base URL for the API', static=True, default='https://go.vumi.org/api/v1/go/http_api_nostream/') message_life_time = 
ConfigInt( 'How long to keep message_ids around for.', static=True, default=48 * 60 * 60) # default is 48 hours. redis_manager = ConfigDict( "Redis client configuration.", default={}, static=True) max_reconnect_delay = ConfigInt( 'Maximum number of seconds between connection attempts', default=3600, static=True) max_retries = ConfigInt( 'Maximum number of consecutive unsuccessful connection attempts ' 'after which no further connection attempts will be made. If this is ' 'not explicitly set, no maximum is applied', static=True) initial_delay = ConfigFloat( 'Initial delay for first reconnection attempt', default=0.1, static=True) factor = ConfigFloat( 'A multiplicative factor by which the delay grows', # (math.e) default=2.7182818284590451, static=True) jitter = ConfigFloat( 'Percentage of randomness to introduce into the delay length ' 'to prevent stampeding.', # molar Planck constant times c, joule meter/mole default=0.11962656472, static=True) web_port = ConfigInt( "The port to listen for requests on, defaults to `0`.", default=0, static=True) web_path = ConfigText( "The path to listen for inbound requests on.", required=True, static=True) health_path = ConfigText( "The path to listen for downstream health checks on" " (useful with HAProxy)", default='health', static=True) class GoConversationTransportBase(Transport): @classmethod def agent_factory(cls): """For swapping out the Agent we use in tests.""" return Agent(reactor) def get_url(self, path): config = self.get_static_config() url = '/'.join([ config.base_url.rstrip('/'), config.conversation_key, path]) return url @inlineCallbacks def map_message_id(self, remote_message_id, local_message_id): config = self.get_static_config() yield self.redis.set(remote_message_id, local_message_id) yield self.redis.expire(remote_message_id, config.message_life_time) def get_message_id(self, remote_message_id): return self.redis.get(remote_message_id) def handle_inbound_message(self, message): return self.publish_message(**message.payload) @inlineCallbacks def handle_inbound_event(self, event): remote_message_id = event['user_message_id'] local_message_id = yield self.get_message_id(remote_message_id) event['user_message_id'] = local_message_id event['sent_message_id'] = remote_message_id yield self.publish_event(**event.payload) @inlineCallbacks def handle_outbound_message(self, message): headers = { 'Content-Type': 'application/json; charset=utf-8', } headers.update(self.get_auth_headers()) params = { 'to_addr': message['to_addr'], 'content': message['content'], 'message_id': message['message_id'], 'in_reply_to': message['in_reply_to'], 'session_event': message['session_event'] } if 'helper_metadata' in message: params['helper_metadata'] = message['helper_metadata'] http_client = HTTPClient(self.agent_factory()) resp = yield http_client.put( self.get_url('messages.json'), data=json.dumps(params).encode('utf-8'), headers=headers) resp_body = yield resp.content() if resp.code != http.OK: log.warning('Unexpected status code: %s, body: %s' % ( resp.code, resp_body)) self.update_status( status='down', component='submitted-to-vumi-go', type='bad_request', message='Message submission rejected by Vumi Go') yield self.publish_nack(message['message_id'], reason='Unexpected status code: %s' % ( resp.code,)) return remote_message = json.loads(resp_body) yield self.map_message_id( remote_message['message_id'], message['message_id']) self.update_status( status='ok', component='submitted-to-vumi-go', type='good_request', message='Message accepted by Vumi Go')
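# Ack with both ids: ``user_message_id`` is our local message id and
# ``sent_message_id`` is the id Vumi Go assigned, so that events for
# the remote id can later be mapped back via map_message_id above.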
yield self.publish_ack(user_message_id=message['message_id'], sent_message_id=remote_message['message_id']) def get_auth_headers(self): config = self.get_static_config() return { 'Authorization': ['Basic ' + base64.b64encode('%s:%s' % ( config.account_key, config.access_token))], } @inlineCallbacks def update_status(self, **kw): '''Publishes a status if it is not a repeat of the previously published status.''' if self.status_detect.check_status(**kw): yield self.publish_status(**kw) class GoConversationHealthResource(Resource): # Most of this copied wholesale from vumi.transports.httprpc. isLeaf = True def __init__(self, transport): self.transport = transport Resource.__init__(self) def render_GET(self, request): request.setResponseCode(http.OK) request.do_not_log = True return self.transport.get_health_response() class GoConversationResource(Resource): # Most of this copied wholesale from vumi.transports.httprpc. isLeaf = True def __init__(self, callback): self.callback = callback Resource.__init__(self) def render_(self, request, request_id=None): request.setHeader("content-type", 'application/json; charset=utf-8') self.callback(request) return NOT_DONE_YET def render_PUT(self, request): return self.render_(request) def render_POST(self, request): return self.render_(request) class GoConversationTransport(GoConversationTransportBase): # Most of this copied wholesale from vumi.transports.httprpc. CONFIG_CLASS = VumiBridgeTransportConfig redis = None web_resource = None @inlineCallbacks def setup_transport(self): self.setup_cacerts() config = self.get_static_config() self.redis = yield TxRedisManager.from_config( config.redis_manager) self.web_resource = yield self.start_web_resources([ (GoConversationResource(self.handle_raw_inbound_message), "%s/messages.json" % (config.web_path)), (GoConversationResource(self.handle_raw_inbound_event), "%s/events.json" % (config.web_path)), (GoConversationHealthResource(self), config.health_path), ], config.web_port) self.status_detect = StatusEdgeDetector() @inlineCallbacks def teardown_transport(self): if self.web_resource is not None: yield self.web_resource.loseConnection() if self.redis is not None: self.redis.close_manager() def setup_cacerts(self): # TODO: This installs an older CA certificate chain that allows # some weak CA certificates. We should switch to .where() when # Vumi Go's certificate doesn't rely on older intermediate # certificates. os.environ["SSL_CERT_FILE"] = certifi.old_where() def get_transport_url(self, suffix=''): """ Get the URL for the HTTP resource. Requires the worker to be started. This is mostly useful in tests, and probably shouldn't be used in non-test code, because the API might live behind a load balancer or proxy. 
""" addr = self.web_resource.getHost() return "http://%s:%s/%s/%s" % ( addr.host, addr.port, self.config["web_path"], suffix.lstrip('/')) @inlineCallbacks def handle_raw_inbound_event(self, request): try: data = json.loads(request.content.read()) msg = TransportEvent(_process_fields=True, **to_kwargs(data)) yield self.handle_inbound_event(msg) request.finish() if msg.payload["event_type"] == "ack": self.update_status( status='ok', component='sent-by-vumi-go', type='vumi_go_sent', message='Sent by Vumi Go') elif msg.payload["event_type"] == "nack": self.update_status( status='down', component='sent-by-vumi-go', type='vumi_go_failed', message='Vumi Go failed to send') self.update_status( status='ok', component='vumi-go-event', type='good_request', message='Good event received from Vumi Go') except Exception as e: log.err(e) request.setResponseCode(400) request.finish() self.update_status( status='down', component='vumi-go-event', type='bad_request', message='Bad event received from Vumi Go') @inlineCallbacks def handle_raw_inbound_message(self, request): try: data = json.loads(request.content.read()) msg = TransportUserMessage( _process_fields=True, **to_kwargs(data)) yield self.handle_inbound_message(msg) request.finish() self.update_status( status='ok', component='received-from-vumi-go', type='good_request', message='Good request received') except Exception as e: log.err(e) request.setResponseCode(400) request.finish() self.update_status( status='down', component='received-from-vumi-go', type='bad_request', message='Bad request received') PKqG_W/ %vumi/transports/vumi_bridge/client.py# -*- test-case-name: vumi.transports.vumi_bridge.tests.test_client -*- import json from twisted.internet.defer import Deferred from twisted.internet import reactor from twisted.web.client import Agent, ResponseDone, ResponseFailed from twisted.web import http from twisted.protocols import basic from twisted.python.failure import Failure from vumi.message import Message from vumi.utils import to_kwargs from vumi import log from vumi.errors import VumiError class VumiBridgeError(VumiError): """Raised by errors encountered by VumiBridge.""" class VumiBridgeInvalidJsonError(VumiError): """Raised when invalid JSON is received.""" class VumiMessageReceiver(basic.LineReceiver): delimiter = '\n' message_class = Message def __init__(self, message_class, callback, errback, on_connect=None, on_disconnect=None): self.message_class = message_class self.callback = callback self.errback = errback self._response = None self._wait_for_response = Deferred() self._on_connect = on_connect or (lambda *a: None) self._on_disconnect = on_disconnect or (lambda *a: None) self.disconnecting = False def get_response(self): return self._wait_for_response def handle_response(self, response): self._response = response if self._response.code == http.NO_CONTENT: self._wait_for_response.callback(self._response) else: self._response.deliverBody(self) def lineReceived(self, line): d = Deferred() d.addCallback(self.callback) d.addErrback(self.errback) line = line.strip() try: data = json.loads(line) d.callback(self.message_class( _process_fields=True, **to_kwargs(data))) except ValueError, e: f = Failure(VumiBridgeInvalidJsonError(line)) d.errback(f) except Exception, e: log.err() f = Failure(e) d.errback(f) def connectionMade(self): self._on_connect() def connectionLost(self, reason): # the PotentialDataLoss here is because Twisted didn't receive a # content length header, which is normal because we're streaming. 
if (reason.check(ResponseDone, ResponseFailed, http.PotentialDataLoss) and self._response is not None and not self._wait_for_response.called): self._wait_for_response.callback(self._response) if not self.disconnecting: self._on_disconnect(reason) def disconnect(self): self.disconnecting = True if self.transport and self.transport._producer is not None: self.transport._producer.loseConnection() self.transport._stopProxying() class StreamingClient(object): def __init__(self, agent_factory=None): if agent_factory is None: agent_factory = Agent self.agent = agent_factory(reactor) def stream(self, message_class, callback, errback, url, headers=None, on_connect=None, on_disconnect=None): receiver = VumiMessageReceiver( message_class, callback, errback, on_connect=on_connect, on_disconnect=on_disconnect) d = self.agent.request('GET', url, headers) d.addCallback(lambda response: receiver.handle_response(response)) d.addErrback(log.err) return receiver PK[H^-Mzz'vumi/transports/vumi_bridge/__init__.pyfrom vumi.transports.vumi_bridge.vumi_bridge import GoConversationTransport __all__ = [ 'GoConversationTransport', ] PKqGS 0vumi/transports/vumi_bridge/tests/test_client.pyfrom twisted.internet.defer import inlineCallbacks, DeferredQueue from twisted.web.server import NOT_DONE_YET from twisted.web.client import Agent, ResponseDone from vumi.transports.vumi_bridge.client import ( StreamingClient, VumiBridgeInvalidJsonError) from vumi.message import Message from vumi.tests.fake_connection import FakeHttpServer from vumi.tests.helpers import VumiTestCase class TestStreamingClient(VumiTestCase): def setUp(self): self.fake_http = FakeHttpServer(self.handle_request) self.request_queue = DeferredQueue() self.client = StreamingClient(self.fake_http.get_agent) self.messages_received = DeferredQueue() self.errors_received = DeferredQueue() self.disconnects_received = DeferredQueue() def reason_trapper(reason): if reason.trap(ResponseDone): self.disconnects_received.put(reason.getErrorMessage()) self.receiver = self.client.stream( Message, self.messages_received.put, self.errors_received.put, "http://vumi-go-api.example.com/", on_disconnect=reason_trapper) def handle_request(self, request): self.request_queue.put(request) return NOT_DONE_YET def test_default_agent_factory(self): """ If `None` is passed as the `agent_factory`, `Agent` is used instead. """ self.assertNotIsInstance(self.client.agent, Agent) self.assertIsInstance(StreamingClient(None).agent, Agent) self.assertIsInstance(StreamingClient().agent, Agent) @inlineCallbacks def test_callback_on_disconnect(self): req = yield self.request_queue.get() req.write( '%s\n' % (Message(foo='bar').to_json().encode('utf-8'),)) req.finish() message = yield self.messages_received.get() self.assertEqual(message['foo'], 'bar') reason = yield self.disconnects_received.get() # this is the error message we get when a ResponseDone is raised # which happens when the remote server closes the connection. 
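# ("Response body fully received" is the message twisted.web's
# ResponseDone failure carries; reason_trapper in setUp extracts it
# with getErrorMessage().)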
self.assertEqual(reason, 'Response body fully received') @inlineCallbacks def test_invalid_json(self): req = yield self.request_queue.get() req.write("Hello\n") req.finish() err = yield self.assertFailure( self.errors_received.get(), VumiBridgeInvalidJsonError) self.assertEqual(err.args, ("Hello",)) PK gHGP(P(5vumi/transports/vumi_bridge/tests/test_vumi_bridge.pyimport json import os from twisted.internet.defer import inlineCallbacks, returnValue, DeferredQueue from twisted.internet.task import Clock from twisted.web.client import Agent from twisted.web.server import NOT_DONE_YET import certifi from vumi.message import TransportUserMessage from vumi.tests.fake_connection import FakeHttpServer from vumi.tests.helpers import VumiTestCase from vumi.transports.tests.helpers import TransportHelper from vumi.transports.vumi_bridge import GoConversationTransport from vumi.config import ConfigError from vumi.utils import http_request_full class TestGoConversationTransportBase(VumiTestCase): transport_class = None def setUp(self): self.tx_helper = self.add_helper(TransportHelper(self.transport_class)) self.fake_http = FakeHttpServer(self.handle_inbound_request) self.clock = Clock() self._request_queue = DeferredQueue() self._pending_reqs = [] self.add_cleanup(self.finish_requests) @inlineCallbacks def get_transport(self, start=True, **config): defaults = { 'account_key': 'account-key', 'conversation_key': 'conversation-key', 'access_token': 'access-token', 'publish_status': True, } defaults.update(config) transport = yield self.tx_helper.get_transport(defaults, start=False) transport.agent_factory = self.fake_http.get_agent if start: yield transport.startWorker() returnValue(transport) @inlineCallbacks def finish_requests(self): for req in self._pending_reqs: if not req.finished: yield req.finish() def handle_inbound_request(self, request): self._request_queue.put(request) return NOT_DONE_YET @inlineCallbacks def get_next_request(self): req = yield self._request_queue.get() self._pending_reqs.append(req) returnValue(req) class TestGoConversationTransport(TestGoConversationTransportBase): transport_class = GoConversationTransport def test_server_settings_without_configs(self): return self.assertFailure(self.get_transport(), ConfigError) def get_configured_transport(self, start=True): return self.get_transport(start=start, web_path='test', web_port='0') def post_msg(self, url, msg_json): data = msg_json.encode('utf-8') return http_request_full( url.encode('utf-8'), data=data, headers={ 'Content-Type': 'application/json; charset=utf-8', }) @inlineCallbacks def test_receiving_messages(self): transport = yield self.get_configured_transport() url = transport.get_transport_url('messages.json') msg = self.tx_helper.make_inbound("inbound") resp = yield self.post_msg(url, msg.to_json()) self.assertEqual(resp.code, 200) [received_msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(received_msg['message_id'], msg['message_id']) [status] = yield self.tx_helper.wait_for_dispatched_statuses(1) self.assertEquals(status['status'], 'ok') self.assertEquals(status['component'], 'received-from-vumi-go') self.assertEquals(status['type'], 'good_request') self.assertEquals(status['message'], 'Good request received') @inlineCallbacks def test_receive_bad_message(self): transport = yield self.get_configured_transport() url = transport.get_transport_url('messages.json') resp = yield self.post_msg(url, 'This is not JSON.') self.assertEqual(resp.code, 400) [failure] = self.flushLoggedErrors() 
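# Python 2's json module raises ValueError("No JSON object could be
# decoded") for unparseable input, hence the substring match below.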
self.assertTrue('No JSON object' in str(failure)) [status] = yield self.tx_helper.wait_for_dispatched_statuses(1) self.assertEquals(status['status'], 'down') self.assertEquals(status['component'], 'received-from-vumi-go') self.assertEquals(status['type'], 'bad_request') self.assertEquals(status['message'], 'Bad request received') @inlineCallbacks def test_receiving_ack_events(self): transport = yield self.get_configured_transport() url = transport.get_transport_url('events.json') # prime the mapping yield transport.map_message_id('remote', 'local') ack = self.tx_helper.make_ack(event_id='event-id') ack['user_message_id'] = 'remote' resp = yield self.post_msg(url, ack.to_json()) self.assertEqual(resp.code, 200) [received_ack] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(received_ack['event_id'], ack['event_id']) self.assertEqual(received_ack['user_message_id'], 'local') self.assertEqual(received_ack['sent_message_id'], 'remote') statuses = yield self.tx_helper.wait_for_dispatched_statuses(1) self.assertEqual(len(statuses), 2) self.assertEquals(statuses[0]['status'], 'ok') self.assertEquals(statuses[0]['component'], 'sent-by-vumi-go') self.assertEquals(statuses[0]['type'], 'vumi_go_sent') self.assertEquals(statuses[0]['message'], 'Sent by Vumi Go') self.assertEquals(statuses[1]['status'], 'ok') self.assertEquals(statuses[1]['component'], 'vumi-go-event') self.assertEquals(statuses[1]['type'], 'good_request') self.assertEquals(statuses[1]['message'], 'Good event received from Vumi Go') @inlineCallbacks def test_receiving_nack_events(self): transport = yield self.get_configured_transport() url = transport.get_transport_url('events.json') # prime the mapping yield transport.map_message_id('remote', 'local') nack = self.tx_helper.make_nack(event_id='event-id') nack['user_message_id'] = 'remote' resp = yield self.post_msg(url, nack.to_json()) self.assertEqual(resp.code, 200) [received_nack] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(received_nack['event_id'], nack['event_id']) self.assertEqual(received_nack['user_message_id'], 'local') self.assertEqual(received_nack['sent_message_id'], 'remote') statuses = yield self.tx_helper.wait_for_dispatched_statuses(1) self.assertEqual(len(statuses), 2) self.assertEquals(statuses[0]['status'], 'down') self.assertEquals(statuses[0]['component'], 'sent-by-vumi-go') self.assertEquals(statuses[0]['type'], 'vumi_go_failed') self.assertEquals(statuses[0]['message'], 'Vumi Go failed to send') self.assertEquals(statuses[1]['status'], 'ok') self.assertEquals(statuses[1]['component'], 'vumi-go-event') self.assertEquals(statuses[1]['type'], 'good_request') self.assertEquals(statuses[1]['message'], 'Good event received from Vumi Go') @inlineCallbacks def test_receive_bad_event(self): transport = yield self.get_configured_transport() url = transport.get_transport_url('events.json') resp = yield self.post_msg(url, 'This is not JSON.') self.assertEqual(resp.code, 400) [failure] = self.flushLoggedErrors() self.assertTrue('No JSON object' in str(failure)) [status] = yield self.tx_helper.wait_for_dispatched_statuses(1) self.assertEquals(status['status'], 'down') self.assertEquals(status['component'], 'vumi-go-event') self.assertEquals(status['type'], 'bad_request') self.assertEquals(status['message'], 'Bad event received from Vumi Go') @inlineCallbacks def test_weak_cacerts_installed(self): yield self.get_configured_transport() self.assertEqual(os.environ["SSL_CERT_FILE"], certifi.old_where()) @inlineCallbacks def 
test_sending_messages(self): yield self.get_configured_transport() msg = self.tx_helper.make_outbound( "outbound", session_event=TransportUserMessage.SESSION_CLOSE) d = self.tx_helper.dispatch_outbound(msg) req = yield self.get_next_request() received_msg = json.loads(req.content.read()) self.assertEqual(received_msg, { 'content': msg['content'], 'in_reply_to': None, 'to_addr': msg['to_addr'], 'message_id': msg['message_id'], 'session_event': TransportUserMessage.SESSION_CLOSE, 'helper_metadata': {}, }) remote_id = TransportUserMessage.generate_id() reply = msg.copy() reply['message_id'] = remote_id req.write(reply.to_json().encode('utf-8')) req.finish() yield d [ack] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(ack['user_message_id'], msg['message_id']) self.assertEqual(ack['sent_message_id'], remote_id) [status] = yield self.tx_helper.wait_for_dispatched_statuses(1) self.assertEquals(status['status'], 'ok') self.assertEquals(status['component'], 'submitted-to-vumi-go') self.assertEquals(status['type'], 'good_request') self.assertEquals(status['message'], 'Message accepted by Vumi Go') @inlineCallbacks def test_sending_bad_messages(self): yield self.get_configured_transport() msg = self.tx_helper.make_outbound( "outbound", session_event=TransportUserMessage.SESSION_CLOSE) self.tx_helper.dispatch_outbound(msg) req = yield self.get_next_request() req.setResponseCode(400, "Bad Request") req.finish() [status] = yield self.tx_helper.wait_for_dispatched_statuses(1) self.assertEquals(status['status'], 'down') self.assertEquals(status['component'], 'submitted-to-vumi-go') self.assertEquals(status['type'], 'bad_request') self.assertEquals(status['message'], 'Message submission rejected by Vumi Go') @inlineCallbacks def test_teardown_before_start(self): transport = yield self.get_configured_transport(start=False) yield transport.teardown_transport() def test_agent_factory_default(self): self.assertTrue(isinstance( GoConversationTransport.agent_factory(), Agent)) PK=JG-vumi/transports/vumi_bridge/tests/__init__.pyPK=JG"w!vumi/transports/dmark/__init__.py""" Dmark_ transports. .. _Dmark: http://dmarkmobile.com/. """ from vumi.transports.dmark.dmark_ussd import DmarkUssdTransport __all__ = ['DmarkUssdTransport'] PKqGr}Ƥ&&#vumi/transports/dmark/dmark_ussd.py# -*- test-case-name: vumi.transports.dmark.tests.test_dmark_ussd -*- import json from twisted.internet.defer import inlineCallbacks, returnValue from twisted.web import http from vumi.components.session import SessionManager from vumi.config import ConfigDict, ConfigInt from vumi.message import TransportUserMessage from vumi.transports.httprpc import HttpRpcTransport class DmarkUssdTransportConfig(HttpRpcTransport.CONFIG_CLASS): """Config for Dmark USSD transport.""" ussd_session_timeout = ConfigInt( "Number of seconds before USSD session information stored in Redis" " expires.", default=600, static=True) redis_manager = ConfigDict( "Redis client configuration.", default={}, static=True) class DmarkUssdTransport(HttpRpcTransport): """Dmark USSD transport over HTTP. When a USSD message is received, Dmark will make an HTTP GET request to the transport with the following query parameters: * ``transactionId``: A unique ID for the USSD session (string). * ``msisdn``: The phone number that the message was sent from (string). * ``ussdServiceCode``: The USSD Service code the request was made to (string). * ``transactionTime``: The time the USSD request was received at Dmark, as a Unix timestamp (UTC). 
    * ``ussdRequestString``: The full content of the USSD request (string).

    * ``creationTime``: The time the USSD request was sent, as a Unix
      timestamp (UTC), if available. (This time is given by the mobile
      network, and may not always be reliable.)

    * ``response``: ``"false"`` if this is a new session, ``"true"`` if it
      is not. Currently not used by the transport (it relies on the
      ``transactionId`` being unique instead).

    The transport may respond to this request using either JSON or
    form-encoded data. A successful response must return HTTP status code
    200. Any other response code is treated as a failure. This transport
    responds with JSON encoded data.

    The JSON response contains the following keys:

    * ``responseString``: The content to be returned to the phone number
      that originated the USSD request.

    * ``action``: Either ``end`` or ``request``. ``end`` signifies that no
      further interaction is expected from the user and the USSD session
      should be closed. ``request`` signifies that further interaction is
      expected.

    **Example JSON response**:

    .. sourcecode:: javascript

       {
           "responseString": "Hello from Vumi!",
           "action": "end"
       }
    """

    CONFIG_CLASS = DmarkUssdTransportConfig
    transport_type = 'ussd'
    ENCODING = 'utf-8'

    EXPECTED_FIELDS = frozenset([
        'transactionId', 'msisdn', 'ussdServiceCode', 'transactionTime',
        'ussdRequestString', 'creationTime', 'response',
    ])

    @inlineCallbacks
    def setup_transport(self):
        yield super(DmarkUssdTransport, self).setup_transport()
        config = self.get_static_config()
        r_prefix = "vumi.transports.dmark_ussd:%s" % self.transport_name
        self.session_manager = yield SessionManager.from_redis_config(
            config.redis_manager, r_prefix,
            max_session_length=config.ussd_session_timeout)

    @inlineCallbacks
    def teardown_transport(self):
        yield super(DmarkUssdTransport, self).teardown_transport()
        yield self.session_manager.stop()
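    # Illustrative sketch (made-up path and parameter values): a Dmark
    # inbound callback and the transport's eventual reply. Dmark would
    # issue a GET such as:
    #
    #   /api/v1/dmark/ussd/?transactionId=tx-1&msisdn=%2B256770000000
    #       &ussdServiceCode=%2A160%23&transactionTime=1389971940
    #       &ussdRequestString=1&creationTime=1389971940&response=false
    #
    # and the transport answers the held request with JSON built in
    # handle_outbound_message below, e.g.:
    #
    #   {"responseString": "Hello from Vumi!", "action": "request"}

    @inlineCallbacks
    def session_event_for_transaction(self, transaction_id):
        # XXX: There is currently no way to detect when the user closes
        # the session (i.e.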
TransportUserMessage.SESSION_CLOSE) session_id = transaction_id session = yield self.session_manager.load_session(transaction_id) if session: session_event = TransportUserMessage.SESSION_RESUME yield self.session_manager.save_session(session_id, session) else: session_event = TransportUserMessage.SESSION_NEW yield self.session_manager.create_session( session_id, transaction_id=transaction_id) returnValue(session_event) @inlineCallbacks def handle_raw_inbound_message(self, request_id, request): try: values, errors = self.get_field_values( request, self.EXPECTED_FIELDS) except UnicodeDecodeError: self.log.msg('Bad request encoding: %r' % request) request_dict = { 'uri': request.uri, 'method': request.method, 'path': request.path, 'content': request.content.read(), 'headers': dict(request.requestHeaders.getAllRawHeaders()), } self.finish_request( request_id, json.dumps({'invalid_request': request_dict}), code=http.BAD_REQUEST) yield self.add_status( component='request', status='down', type='invalid_encoding', message='Invalid encoding', details={ 'request': request_dict, }) return if errors: self.log.msg('Unhappy incoming message: %r' % (errors,)) self.finish_request( request_id, json.dumps(errors), code=http.BAD_REQUEST) yield self.add_status( component='request', status='down', type='invalid_inbound_fields', message='Invalid inbound fields', details=errors) return yield self.add_status( component='request', status='ok', type='request_parsed', message='Request parsed',) to_addr = values["ussdServiceCode"] from_addr = values["msisdn"] session_event = yield self.session_event_for_transaction( values["transactionId"]) yield self.publish_message( message_id=request_id, content=values["ussdRequestString"], to_addr=to_addr, from_addr=from_addr, provider='dmark', session_event=session_event, transport_type=self.transport_type, transport_metadata={ 'dmark_ussd': { 'transaction_id': values['transactionId'], 'transaction_time': values['transactionTime'], 'creation_time': values['creationTime'], } }) @inlineCallbacks def handle_outbound_message(self, message): self.emit("DmarkUssdTransport consuming %r" % (message,)) missing_fields = self.ensure_message_values( message, ['in_reply_to', 'content']) if missing_fields: nack = yield self.reject_message(message, missing_fields) returnValue(nack) if message["session_event"] == TransportUserMessage.SESSION_CLOSE: action = "end" else: action = "request" response_data = { "responseString": message["content"], "action": action, } response_id = self.finish_request( message['in_reply_to'], json.dumps(response_data)) if response_id is not None: ack = yield self.publish_ack( user_message_id=message['message_id'], sent_message_id=message['message_id']) returnValue(ack) else: nack = yield self.publish_nack( user_message_id=message['message_id'], sent_message_id=message['message_id'], reason="Could not find original request.") returnValue(nack) def on_down_response_time(self, message_id, time): request = self.get_request(message_id) # We send different status events for error responses if request.code < 200 or request.code >= 300: return return self.add_status( component='response', status='down', type='very_slow_response', message='Very slow response', reasons=[ 'Response took longer than %fs' % ( self.response_time_down,) ], details={ 'response_time': time, }) def on_degraded_response_time(self, message_id, time): request = self.get_request(message_id) # We send different status events for error responses if request.code < 200 or request.code >= 300: return return 
self.add_status( component='response', status='degraded', type='slow_response', message='Slow response', reasons=[ 'Response took longer than %fs' % ( self.response_time_degraded,) ], details={ 'response_time': time, }) def on_good_response_time(self, message_id, time): request = self.get_request(message_id) # We send different status events for error responses if request.code < 200 or request.code >= 400: return return self.add_status( component='response', status='ok', type='response_sent', message='Response sent', details={ 'response_time': time, }) def on_timeout(self, message_id, time): return self.add_status( component='response', status='down', type='timeout', message='Response timed out', reasons=[ 'Response took longer than %fs' % ( self.request_timeout,) ], details={ 'response_time': time, }) PKqG)ܶsDsD.vumi/transports/dmark/tests/test_dmark_ussd.py# -*- coding: utf-8 -*- """Tests for vumi.transports.dmark.dmark_ussd.""" import json import urllib from twisted.internet.defer import inlineCallbacks from twisted.internet.task import Clock from vumi.message import TransportUserMessage from vumi.tests.helpers import VumiTestCase from vumi.transports.dmark import DmarkUssdTransport from vumi.transports.httprpc.tests.helpers import HttpRpcTransportHelper class TestDmarkUssdTransport(VumiTestCase): _transaction_id = u'transaction-123' _to_addr = '*121#' _from_addr = '+00775551122' _request_defaults = { 'transactionId': _transaction_id, 'msisdn': _from_addr, 'ussdServiceCode': _to_addr, 'transactionTime': '1389971940', 'ussdRequestString': _to_addr, 'creationTime': '1389971950', 'response': 'false', } @inlineCallbacks def setUp(self): self.clock = Clock() self.patch(DmarkUssdTransport, 'get_clock', lambda _: self.clock) self.config = { 'web_port': 0, 'web_path': '/api/v1/dmark/ussd/', 'publish_status': True, } self.tx_helper = self.add_helper( HttpRpcTransportHelper(DmarkUssdTransport, request_defaults=self._request_defaults)) self.transport = yield self.tx_helper.get_transport(self.config) self.session_manager = self.transport.session_manager self.transport_url = self.transport.get_transport_url( self.config['web_path']) yield self.session_manager.redis._purge_all() # just in case self.session_timestamps = {} @inlineCallbacks def mk_session(self, transaction_id=_transaction_id): yield self.session_manager.create_session( transaction_id, transaction_id=transaction_id) def assert_inbound_message(self, msg, **field_values): expected_field_values = { 'content': self._request_defaults['ussdRequestString'], 'to_addr': self._to_addr, 'from_addr': self._from_addr, 'session_event': TransportUserMessage.SESSION_NEW, 'transport_metadata': { 'dmark_ussd': { 'transaction_id': self._request_defaults['transactionId'], 'transaction_time': self._request_defaults['transactionTime'], 'creation_time': self._request_defaults['creationTime'], }, } } expected_field_values.update(field_values) for field, expected_value in expected_field_values.iteritems(): self.assertEqual(msg[field], expected_value) def assert_ack(self, ack, reply): self.assertEqual(ack.payload['event_type'], 'ack') self.assertEqual(ack.payload['user_message_id'], reply['message_id']) self.assertEqual(ack.payload['sent_message_id'], reply['message_id']) def assert_nack(self, nack, reply, reason): self.assertEqual(nack.payload['event_type'], 'nack') self.assertEqual(nack.payload['user_message_id'], reply['message_id']) self.assertEqual(nack.payload['nack_reason'], reason) @inlineCallbacks def test_inbound_begin(self): user_content = "Who are 
you?" d = self.tx_helper.mk_request(ussdRequestString=user_content) [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assert_inbound_message( msg, session_event=TransportUserMessage.SESSION_NEW, content=user_content) reply_content = "We are the Knights Who Say ... Ni!" reply = msg.reply(reply_content) self.tx_helper.dispatch_outbound(reply) response = yield d self.assertEqual(json.loads(response.delivered_body), { "responseString": reply_content, "action": "request", }) self.assertEqual(response.code, 200) [ack] = yield self.tx_helper.wait_for_dispatched_events(1) self.assert_ack(ack, reply) @inlineCallbacks def test_inbound_status(self): d = self.tx_helper.mk_request() [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) [status] = yield self.tx_helper.get_dispatched_statuses() self.tx_helper.dispatch_outbound(msg.reply('foo')) yield d self.assertEqual(status['status'], 'ok') self.assertEqual(status['component'], 'request') self.assertEqual(status['type'], 'request_parsed') self.assertEqual(status['message'], 'Request parsed') @inlineCallbacks def test_inbound_cannot_decode(self): '''If the content cannot be decoded, an error shoould be sent back''' user_content = "Who are you?".encode('utf-32') response = yield self.tx_helper.mk_request( ussdRequestString=user_content) self.assertEqual(response.code, 400) body = json.loads(response.delivered_body) request = body['invalid_request'] self.assertEqual(request['content'], '') self.assertEqual(request['path'], self.config['web_path']) self.assertEqual(request['method'], 'GET') self.assertEqual(request['headers']['Connection'], ['close']) encoded_str = urllib.urlencode({'ussdRequestString': user_content}) self.assertTrue(encoded_str in request['uri']) @inlineCallbacks def test_inbound_cannot_decode_status(self): '''If the request cannot be decoded, a status event should be sent''' user_content = "Who are you?".encode('utf-32') yield self.tx_helper.mk_request(ussdRequestString=user_content) [status] = self.tx_helper.get_dispatched_statuses() self.assertEqual(status['component'], 'request') self.assertEqual(status['status'], 'down') self.assertEqual(status['type'], 'invalid_encoding') self.assertEqual(status['message'], 'Invalid encoding') request = status['details']['request'] self.assertEqual(request['content'], '') self.assertEqual(request['path'], self.config['web_path']) self.assertEqual(request['method'], 'GET') self.assertEqual(request['headers']['Connection'], ['close']) encoded_str = urllib.urlencode({'ussdRequestString': user_content}) self.assertTrue(encoded_str in request['uri']) @inlineCallbacks def test_inbound_resume_and_reply_with_end(self): yield self.mk_session(self._transaction_id) user_content = "Well, what is it you want?" d = self.tx_helper.mk_request(ussdRequestString=user_content) [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assert_inbound_message( msg, session_event=TransportUserMessage.SESSION_RESUME, content=user_content) reply_content = "We want ... a shrubbery!" reply = msg.reply(reply_content, continue_session=False) self.tx_helper.dispatch_outbound(reply) response = yield d self.assertEqual(json.loads(response.delivered_body), { "responseString": reply_content, "action": "end", }) self.assertEqual(response.code, 200) [ack] = yield self.tx_helper.wait_for_dispatched_events(1) self.assert_ack(ack, reply) @inlineCallbacks def test_inbound_resume_and_reply_with_resume(self): yield self.mk_session() user_content = "Well, what is it you want?" 
d = self.tx_helper.mk_request(ussdRequestString=user_content) [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assert_inbound_message( msg, session_event=TransportUserMessage.SESSION_RESUME, content=user_content) reply_content = "We want ... a shrubbery!" reply = msg.reply(reply_content, continue_session=True) self.tx_helper.dispatch_outbound(reply) response = yield d self.assertEqual(json.loads(response.delivered_body), { "responseString": reply_content, "action": "request", }) self.assertEqual(response.code, 200) [ack] = yield self.tx_helper.wait_for_dispatched_events(1) self.assert_ack(ack, reply) @inlineCallbacks def test_request_with_missing_parameters(self): response = yield self.tx_helper.mk_request_raw( params={"ussdServiceCode": '', "msisdn": '', "creationTime": ''}) json_resp = json.loads(response.delivered_body) json_resp['missing_parameter'] = sorted(json_resp['missing_parameter']) self.assertEqual(json_resp, { 'missing_parameter': sorted([ "transactionTime", "transactionId", "response", "ussdRequestString", ]), }) self.assertEqual(response.code, 400) @inlineCallbacks def test_status_with_missing_parameters(self): '''A request with missing parameters should send a TransportStatus with the relevant details.''' yield self.tx_helper.mk_request_raw( params={"ussdServiceCode": '', "msisdn": '', "creationTime": ''}) [status] = yield self.tx_helper.get_dispatched_statuses() self.assertEqual(status['status'], 'down') self.assertEqual(status['component'], 'request') self.assertEqual(status['type'], 'invalid_inbound_fields') self.assertEqual(sorted(status['details']['missing_parameter']), [ 'response', 'transactionId', 'transactionTime', 'ussdRequestString']) @inlineCallbacks def test_request_with_unexpected_parameters(self): response = yield self.tx_helper.mk_request( unexpected_p1='', unexpected_p2='') self.assertEqual(response.code, 400) body = json.loads(response.delivered_body) self.assertEqual(set(['unexpected_parameter']), set(body.keys())) self.assertEqual( sorted(body['unexpected_parameter']), ['unexpected_p1', 'unexpected_p2']) @inlineCallbacks def test_status_with_unexpected_parameters(self): '''A request with unexpected parameters should send a TransportStatus with the relevant details.''' yield self.tx_helper.mk_request( unexpected_p1='', unexpected_p2='') [status] = yield self.tx_helper.get_dispatched_statuses() self.assertEqual(status['status'], 'down') self.assertEqual(status['component'], 'request') self.assertEqual(status['type'], 'invalid_inbound_fields') self.assertEqual(sorted(status['details']['unexpected_parameter']), [ 'unexpected_p1', 'unexpected_p2']) @inlineCallbacks def test_nack_insufficient_message_fields(self): reply = self.tx_helper.make_outbound( None, message_id='23', in_reply_to=None) self.tx_helper.dispatch_outbound(reply) [nack] = yield self.tx_helper.wait_for_dispatched_events(1) self.assert_nack(nack, reply, 'Missing fields: in_reply_to, content') @inlineCallbacks def test_nack_http_http_response_failure(self): reply = self.tx_helper.make_outbound( 'There are some who call me ... 
Tim!', message_id='23', in_reply_to='some-number') self.tx_helper.dispatch_outbound(reply) [nack] = yield self.tx_helper.wait_for_dispatched_events(1) self.assert_nack( nack, reply, 'Could not find original request.') @inlineCallbacks def test_status_quick_response(self): '''Ok status event should be sent if the response is quick.''' d = self.tx_helper.mk_request() [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) yield self.tx_helper.clear_dispatched_statuses() self.tx_helper.dispatch_outbound(msg.reply('foo')) yield d [status] = yield self.tx_helper.get_dispatched_statuses() self.assertEqual(status['status'], 'ok') self.assertEqual(status['component'], 'response') self.assertEqual(status['message'], 'Response sent') self.assertEqual(status['type'], 'response_sent') @inlineCallbacks def test_status_degraded_slow_response(self): '''A degraded status event should be sent if the response took longer than 1 second.''' d = self.tx_helper.mk_request() [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) yield self.tx_helper.clear_dispatched_statuses() self.clock.advance(self.transport.response_time_degraded + 0.1) self.tx_helper.dispatch_outbound(msg.reply('foo')) yield d [status] = yield self.tx_helper.get_dispatched_statuses() self.assertEqual(status['status'], 'degraded') self.assertTrue( str(self.transport.response_time_degraded) in status['reasons'][0]) self.assertEqual(status['component'], 'response') self.assertEqual(status['type'], 'slow_response') self.assertEqual(status['message'], 'Slow response') @inlineCallbacks def test_status_down_very_slow_response(self): '''A down status event should be sent if the response took longer than 10 seconds.''' d = self.tx_helper.mk_request() [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) yield self.tx_helper.clear_dispatched_statuses() self.clock.advance(self.transport.response_time_down + 0.1) self.tx_helper.dispatch_outbound(msg.reply('foo')) yield d [status] = yield self.tx_helper.get_dispatched_statuses() self.assertEqual(status['status'], 'down') self.assertTrue( str(self.transport.response_time_down) in status['reasons'][0]) self.assertEqual(status['component'], 'response') self.assertEqual(status['type'], 'very_slow_response') self.assertEqual(status['message'], 'Very slow response') @inlineCallbacks def test_no_response_status_for_message_not_found(self): '''If we cannot find the starting timestamp for a message, no status message should be sent''' reply = self.tx_helper.make_outbound( 'There are some who call me ... 
Tim!', message_id='23', in_reply_to='some-number') self.tx_helper.dispatch_outbound(reply) statuses = yield self.tx_helper.get_dispatched_statuses() self.assertEqual(len(statuses), 0) @inlineCallbacks def test_no_good_status_event_for_bad_responses(self): '''If the http response is not a good (200-399) response, then a status event shouldn't be sent, because we send different status events for those errors.''' d = self.tx_helper.mk_request() [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) yield self.tx_helper.clear_dispatched_statuses() self.transport.finish_request(msg['message_id'], '', code=500) yield d statuses = yield self.tx_helper.get_dispatched_statuses() self.assertEqual(len(statuses), 0) @inlineCallbacks def test_no_degraded_status_event_for_bad_responses(self): '''If the http response is not a good (200-399) response, then a status event shouldn't be sent, because we send different status events for those errors.''' d = self.tx_helper.mk_request() [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) yield self.tx_helper.clear_dispatched_statuses() self.clock.advance(self.transport.response_time_degraded + 0.1) self.transport.finish_request(msg['message_id'], '', code=500) yield d statuses = yield self.tx_helper.get_dispatched_statuses() self.assertEqual(len(statuses), 0) @inlineCallbacks def test_no_down_status_event_for_bad_responses(self): '''If the http response is not a good (200-399) response, then a status event shouldn't be sent, because we send different status events for those errors.''' d = self.tx_helper.mk_request() [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) yield self.tx_helper.clear_dispatched_statuses() self.clock.advance(self.transport.response_time_down + 0.1) self.transport.finish_request(msg['message_id'], '', code=500) yield d statuses = yield self.tx_helper.get_dispatched_statuses() self.assertEqual(len(statuses), 0) @inlineCallbacks def test_status_down_timeout(self): '''A down status event should be sent if the response timed out''' d = self.tx_helper.mk_request() [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) yield self.tx_helper.clear_dispatched_statuses() self.clock.advance(self.transport.request_timeout + 0.1) self.tx_helper.dispatch_outbound(msg.reply('foo')) yield d [status] = yield self.tx_helper.get_dispatched_statuses() self.assertEqual(status['status'], 'down') self.assertTrue( str(self.transport.request_timeout) in status['reasons'][0]) self.assertEqual(status['component'], 'response') self.assertEqual(status['type'], 'timeout') self.assertEqual(status['message'], 'Response timed out') self.assertEqual(status['details'], { 'response_time': self.transport.request_timeout + 0.1, }) PK=JG'vumi/transports/dmark/tests/__init__.pyPK=JG%n]'vumi/transports/mtech_kenya/__init__.pyfrom .mtech_kenya import MTechKenyaTransport, MTechKenyaTransportV2 __all__ = ['MTechKenyaTransport', 'MTechKenyaTransportV2'] PKqGW5WW*vumi/transports/mtech_kenya/mtech_kenya.py# -*- test-case-name: vumi.transports.mtech_kenya.tests.test_mtech_kenya -*- import json from urllib import urlencode from twisted.internet.defer import inlineCallbacks from vumi.utils import http_request_full from vumi import log from vumi.config import ConfigText from vumi.transports.httprpc import HttpRpcTransport class MTechKenyaTransportConfig(HttpRpcTransport.CONFIG_CLASS): outbound_url = ConfigText('The URL to send outbound messages to.', required=True, static=True) mt_username = ConfigText('The username sent with outbound messages', required=True, 
                              static=True)
    mt_password = ConfigText('The password sent with outbound messages',
                             required=True, static=True)


class MTechKenyaTransport(HttpRpcTransport):
    """
    HTTP transport for MTech Kenya SMS.
    """
    transport_type = 'sms'
    agent_factory = None  # For swapping out the Agent we use in tests.

    CONFIG_CLASS = MTechKenyaTransportConfig

    EXPECTED_FIELDS = set(["shortCode", "MSISDN", "MESSAGE", "messageID"])
    OPTIONAL_FIELDS = set(["linkID", "gateway", "message_type"])

    KNOWN_ERROR_RESPONSE_CODES = {
        401: 'Invalid username or password',
        403: 'Invalid mobile number',
    }

    def make_request(self, params):
        config = self.get_static_config()
        url = '%s?%s' % (config.outbound_url, urlencode(params))
        log.msg("Making HTTP request: %s" % (url,))
        return http_request_full(
            url, '', method='POST', agent_class=self.agent_factory)

    @inlineCallbacks
    def handle_outbound_message(self, message):
        config = self.get_static_config()
        params = {
            'user': config.mt_username,
            'pass': config.mt_password,
            'messageID': message['message_id'],
            'shortCode': message['from_addr'],
            'MSISDN': message['to_addr'],
            'MESSAGE': message['content'],
        }
        link_id = message['transport_metadata'].get('linkID')
        if link_id is not None:
            params['linkID'] = link_id
        response = yield self.make_request(params)
        log.msg("Response: (%s) %r" % (response.code,
                                       response.delivered_body))
        if response.code == 200:
            yield self.publish_ack(user_message_id=message['message_id'],
                                   sent_message_id=message['message_id'])
        else:
            error = self.KNOWN_ERROR_RESPONSE_CODES.get(
                response.code, 'Unknown response code: %s' % (response.code,))
            yield self.publish_nack(message['message_id'], error)

    @inlineCallbacks
    def handle_raw_inbound_message(self, message_id, request):
        values, errors = self.get_field_values(
            request, self.EXPECTED_FIELDS, self.OPTIONAL_FIELDS)
        if errors:
            log.msg('Unhappy incoming message: %s' % (errors,))
            yield self.finish_request(message_id, json.dumps(errors),
                                      code=400)
            return
        log.msg(('MTechKenyaTransport sending from %(MSISDN)s to '
                 '%(shortCode)s message "%(MESSAGE)s"') % values)
        transport_metadata = {'transport_message_id': values['messageID']}
        if values.get('linkID') is not None:
            transport_metadata['linkID'] = values['linkID']
        yield self.publish_message(
            message_id=message_id,
            content=values['MESSAGE'],
            to_addr=values['shortCode'],
            from_addr=values['MSISDN'],
            transport_type=self.transport_type,
            transport_metadata=transport_metadata,
        )
        yield self.finish_request(
            message_id, json.dumps({'message_id': message_id}))


class MTechKenyaTransportV2(MTechKenyaTransport):
    headers = {
        'Content-Type': 'application/x-www-form-urlencoded'
    }

    def make_request(self, params):
        log.msg("Making HTTP request: %s" % (repr(params)))
        config = self.get_static_config()
        return http_request_full(
            config.outbound_url, urlencode(params), method='POST',
            headers=self.headers, agent_class=self.agent_factory)
PKqG6e5vumi/transports/mtech_kenya/tests/test_mtech_kenya.py# -*- encoding: utf-8 -*-

import json
from urllib import urlencode

from twisted.internet.defer import inlineCallbacks, DeferredQueue

from vumi.utils import http_request, http_request_full
from vumi.tests.fake_connection import FakeHttpServer
from vumi.tests.helpers import VumiTestCase
from vumi.transports.mtech_kenya import (
    MTechKenyaTransport, MTechKenyaTransportV2)
from vumi.transports.tests.helpers import TransportHelper


class TestMTechKenyaTransport(VumiTestCase):

    transport_class = MTechKenyaTransport

    @inlineCallbacks
    def setUp(self):
        self.cellulant_sms_calls = DeferredQueue()
        self.fake_http = FakeHttpServer(self.handle_request)
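        # Requests the transport makes via http_request_full are routed to
        # handle_request below by swapping agent_factory for the fake
        # server's agent (see the assignment after get_transport).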
self.base_url = "http://mtech-keyna.example.com/" self.valid_creds = { 'mt_username': 'testuser', 'mt_password': 'testpass', } self.config = { 'web_path': "foo", 'web_port': 0, 'outbound_url': self.base_url, } self.config.update(self.valid_creds) self.tx_helper = self.add_helper( TransportHelper(self.transport_class, mobile_addr='2371234567')) self.transport = yield self.tx_helper.get_transport(self.config) self.transport.agent_factory = self.fake_http.get_agent self.transport_url = self.transport.get_transport_url() def handle_request(self, request): if request.args.get('user') != [self.valid_creds['mt_username']]: request.setResponseCode(401) elif request.args.get('MSISDN') != ['2371234567']: request.setResponseCode(403) self.cellulant_sms_calls.put(request) return '' def mkurl(self, content, from_addr="2371234567", **kw): params = { 'shortCode': '12345', 'MSISDN': from_addr, 'MESSAGE': content, 'messageID': '1234567', } params.update(kw) return self.mkurl_raw(**params) def mkurl_raw(self, **params): return '%s%s?%s' % ( self.transport_url, self.config['web_path'], urlencode(params) ) @inlineCallbacks def test_health(self): result = yield http_request( self.transport_url + "health", "", method='GET') self.assertEqual(json.loads(result), {'pending_requests': 0}) @inlineCallbacks def test_inbound(self): url = self.mkurl('hello') response = yield http_request(url, '', method='POST') [msg] = self.tx_helper.get_dispatched_inbound() self.assertEqual(msg['transport_name'], self.tx_helper.transport_name) self.assertEqual(msg['to_addr'], "12345") self.assertEqual(msg['from_addr'], "2371234567") self.assertEqual(msg['content'], "hello") self.assertEqual(json.loads(response), {'message_id': msg['message_id']}) @inlineCallbacks def test_handle_non_ascii_input(self): url = self.mkurl(u"öæł".encode("utf-8")) response = yield http_request(url, '', method='POST') [msg] = self.tx_helper.get_dispatched_inbound() self.assertEqual(msg['transport_name'], self.tx_helper.transport_name) self.assertEqual(msg['to_addr'], "12345") self.assertEqual(msg['from_addr'], "2371234567") self.assertEqual(msg['content'], u"öæł") self.assertEqual(json.loads(response), {'message_id': msg['message_id']}) @inlineCallbacks def test_bad_parameter(self): url = self.mkurl('hello', foo='bar') response = yield http_request_full(url, '', method='POST') self.assertEqual(400, response.code) self.assertEqual(json.loads(response.delivered_body), {'unexpected_parameter': ['foo']}) @inlineCallbacks def test_outbound(self): msg = yield self.tx_helper.make_dispatch_outbound("hi") req = yield self.cellulant_sms_calls.get() self.assertEqual(req.path, self.base_url) self.assertEqual(req.method, 'POST') self.assertEqual({ 'user': ['testuser'], 'pass': ['testpass'], 'messageID': [msg['message_id']], 'shortCode': ['9292'], 'MSISDN': ['2371234567'], 'MESSAGE': ['hi'], }, req.args) [ack] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual('ack', ack['event_type']) @inlineCallbacks def test_outbound_bad_creds(self): self.valid_creds['mt_username'] = 'other_user' msg = yield self.tx_helper.make_dispatch_outbound("hi") req = yield self.cellulant_sms_calls.get() self.assertEqual(req.path, self.base_url) self.assertEqual(req.method, 'POST') self.assertEqual({ 'user': ['testuser'], 'pass': ['testpass'], 'messageID': [msg['message_id']], 'shortCode': ['9292'], 'MSISDN': ['2371234567'], 'MESSAGE': ['hi'], }, req.args) [nack] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual('nack', nack['event_type']) 
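        # The 401 set by handle_request maps to
        # KNOWN_ERROR_RESPONSE_CODES[401] in the transport.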
        self.assertEqual('Invalid username or password',
                         nack['nack_reason'])

    @inlineCallbacks
    def test_outbound_bad_msisdn(self):
        msg = yield self.tx_helper.make_dispatch_outbound(
            "hi", to_addr="4471234567")
        req = yield self.cellulant_sms_calls.get()
        self.assertEqual(req.path, self.base_url)
        self.assertEqual(req.method, 'POST')
        self.assertEqual({
            'user': ['testuser'],
            'pass': ['testpass'],
            'messageID': [msg['message_id']],
            'shortCode': ['9292'],
            'MSISDN': ['4471234567'],
            'MESSAGE': ['hi'],
        }, req.args)
        [nack] = yield self.tx_helper.wait_for_dispatched_events(1)
        self.assertEqual('nack', nack['event_type'])
        self.assertEqual('Invalid mobile number', nack['nack_reason'])

    @inlineCallbacks
    def test_inbound_linkid(self):
        url = self.mkurl('hello', linkID='link123')
        response = yield http_request(url, '', method='POST')
        [msg] = self.tx_helper.get_dispatched_inbound()
        self.assertEqual(msg['transport_name'], self.tx_helper.transport_name)
        self.assertEqual(msg['to_addr'], "12345")
        self.assertEqual(msg['from_addr'], "2371234567")
        self.assertEqual(msg['content'], "hello")
        self.assertEqual(msg['transport_metadata'], {
            'transport_message_id': '1234567',
            'linkID': 'link123',
        })
        self.assertEqual(json.loads(response),
                         {'message_id': msg['message_id']})

    @inlineCallbacks
    def test_outbound_linkid(self):
        msg = yield self.tx_helper.make_dispatch_outbound(
            "hi", transport_metadata={'linkID': 'link123'})
        req = yield self.cellulant_sms_calls.get()
        self.assertEqual(req.path, self.base_url)
        self.assertEqual(req.method, 'POST')
        self.assertEqual({
            'user': ['testuser'],
            'pass': ['testpass'],
            'messageID': [msg['message_id']],
            'shortCode': ['9292'],
            'MSISDN': ['2371234567'],
            'MESSAGE': ['hi'],
            'linkID': ['link123'],
        }, req.args)
        [ack] = yield self.tx_helper.wait_for_dispatched_events(1)
        self.assertEqual('ack', ack['event_type'])


class TestMTechKenyaTransportV2(TestMTechKenyaTransport):
    transport_class = MTechKenyaTransportV2
PK=JG-vumi/transports/mtech_kenya/tests/__init__.pyPKqG4/|"+"+$vumi/transports/vas2nets/vas2nets.py# -*- test-case-name: vumi.transports.vas2nets.tests.test_vas2nets -*-
# -*- encoding: utf-8 -*-

from urllib import urlencode
from datetime import datetime
import string
import warnings
from StringIO import StringIO

from twisted.web import http
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
from twisted.python import log
from twisted.internet.defer import inlineCallbacks
from twisted.internet.protocol import Protocol
from twisted.internet.error import ConnectionRefusedError

from vumi.utils import http_request_full, normalize_msisdn, LogFilterSite
from vumi.transports.base import Transport
from vumi.transports.failures import TemporaryFailure, PermanentFailure
from vumi.errors import VumiError


def iso8601(vas2nets_timestamp):
    if vas2nets_timestamp:
        ts = datetime.strptime(vas2nets_timestamp, '%Y.%m.%d %H:%M:%S')
        return ts.isoformat()
    else:
        return ''


def validate_characters(chars):
    single_byte_set = ''.join([
        string.ascii_lowercase,  # a-z
        string.ascii_uppercase,  # A-Z
        u'0123456789',
        u'äöüÄÖÜàùòìèé§Ññ£$@',
        u' ',
        u'/?!#%&()*+,-:;<=>."\'',
        u'\n\r',
    ])
    double_byte_set = u'|{}[]€\~^'
    superset = single_byte_set + double_byte_set
    for char in chars:
        if char not in superset:
            raise Vas2NetsEncodingError('illegal character %s' % char)
        if char in double_byte_set:
            warnings.warn(''.join(['double byte character %s, max SMS length',
                                   ' is 70 chars as a result']) % char,
                          Vas2NetsEncodingWarning)
    return chars
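
# Illustrative usage sketch (not from the original module), matching the
# behaviour exercised by test_validate_characters in test_vas2nets.py:
#
#   validate_characters(u'hello 123')   # returned unchanged
#   validate_characters(u'{}')          # warns: double byte, 70-char limit
#   validate_characters(u'\u00f8')      # raises Vas2NetsEncodingError


def normalize_outbound_msisdn(msisdn):
    if msisdn.startswith('+'):
        return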
msisdn.replace('+', '00') else: return msisdn class Vas2NetsTransportError(VumiError): pass class Vas2NetsEncodingError(VumiError): pass class Vas2NetsEncodingWarning(VumiError): pass class ReceiveSMSResource(Resource): isLeaf = True def __init__(self, config, publish_func): self.config = config self.publish_func = publish_func self.transport_name = self.config['transport_name'] @inlineCallbacks def do_render(self, request): request.setResponseCode(http.OK) request.setHeader('Content-Type', 'text/plain') try: message_id = '%s.%s' % (self.transport_name, request.args['messageid'][0]) yield self.publish_func( transport_name=self.transport_name, transport_type='sms', message_id=message_id, transport_metadata={ 'original_message_id': message_id, 'timestamp': iso8601(request.args['time'][0]), 'network_id': request.args['provider'][0], 'keyword': request.args['keyword'][0], }, to_addr=normalize_msisdn(request.args['destination'][0]), from_addr=normalize_msisdn(request.args['sender'][0]), content=request.args['text'][0], ) log.msg("Enqueued.") except KeyError, e: request.setResponseCode(http.BAD_REQUEST) msg = ("Need more request keys to complete this request. \n\n" "Missing request key: %s" % (e,)) log.msg('Returning %s: %s' % (http.BAD_REQUEST, msg)) request.write(msg) except ValueError, e: request.setResponseCode(http.BAD_REQUEST) msg = "ValueError: %s" % e log.msg('Returning %s: %s' % (http.BAD_REQUEST, msg)) request.write(msg) except Exception, e: request.setResponseCode(http.INTERNAL_SERVER_ERROR) log.err("Error processing request: %s" % (request,)) request.finish() def render(self, request): self.do_render(request) return NOT_DONE_YET class DeliveryReceiptResource(Resource): isLeaf = True def __init__(self, config, publish_func): self.config = config self.publish_func = publish_func self.transport_name = self.config['transport_name'] @inlineCallbacks def do_render(self, request): log.msg('got hit with %s' % request.args) request.setResponseCode(http.OK) request.setHeader('Content-Type', 'text/plain') try: message_id = '%s.%s' % (self.transport_name, request.args['messageid'][0]) status = int(request.args['status'][0]) delivery_status = 'pending' if status < 0: delivery_status = 'failed' elif status in [2, 14]: delivery_status = 'delivered' yield self.publish_func( user_message_id=message_id, delivery_status=delivery_status, transport_metadata={ 'delivery_status': request.args['status'][0], 'delivery_message': request.args['text'][0], 'timestamp': iso8601(request.args['time'][0]), 'network_id': request.args['provider'][0], }, to_addr=normalize_msisdn(request.args['sender'][0]), ) except KeyError, e: request.setResponseCode(http.BAD_REQUEST) msg = ("Need more request keys to complete this request. 
\n\n" "Missing request key: %s" % (e,)) log.msg('Returning %s: %s' % (http.BAD_REQUEST, msg)) request.write(msg) except ValueError, e: request.setResponseCode(http.BAD_REQUEST) msg = "ValueError: %s" % e log.msg('Returning %s: %s' % (http.BAD_REQUEST, msg)) request.write(msg) except Exception, e: request.setResponseCode(http.INTERNAL_SERVER_ERROR) log.err("Error processing request: %s" % (request,)) request.finish() def render(self, request): self.do_render(request) return NOT_DONE_YET class HealthResource(Resource): isLeaf = True def __init__(self, config, publish_func): pass def render(self, request): request.setResponseCode(http.OK) request.do_not_log = True return 'OK' class HttpResponseHandler(Protocol): def __init__(self, deferred): self.deferred = deferred self.stringio = StringIO() def dataReceived(self, bytes): self.stringio.write(bytes) def connectionLost(self, reason): self.deferred.callback(self.stringio.getvalue()) class Vas2NetsTransport(Transport): agent_factory = None # For swapping out the Agent we use in tests. def mkres(self, cls, publish_func, path_key): resource = cls(self.config, publish_func) self._resources.append(resource) return (resource, self.config['web_%s_path' % (path_key,)]) @inlineCallbacks def setup_transport(self): self._resources = [] self.config.setdefault('web_health_path', 'health') resources = [ self.mkres(ReceiveSMSResource, self.publish_message, 'receive'), self.mkres(DeliveryReceiptResource, self.publish_delivery_report, 'receipt'), self.mkres(HealthResource, None, 'health'), ] self.receipt_resource = yield self.start_web_resources( resources, self.config['web_port'], LogFilterSite) def get_transport_url(self): """ Get the URL for the HTTP resource. Requires the worker to be started. This is mostly useful in tests, and probably shouldn't be used in non-test code, because the API might live behind a load balancer or proxy. 
""" addr = self.receipt_resource.getHost() return "http://%s:%s" % (addr.host, addr.port) @inlineCallbacks def handle_outbound_message(self, message): """ handle messages arriving over AMQP meant for delivery via vas2nets """ params = { 'username': self.config['username'], 'password': self.config['password'], 'owner': self.config['owner'], 'service': self.config['service'], } v2n_message_id = message.get('in_reply_to') if v2n_message_id is not None: if v2n_message_id.startswith(self.transport_name): v2n_message_id = v2n_message_id[len(self.transport_name) + 1:] else: v2n_message_id = message['message_id'] message_params = { 'call-number': normalize_outbound_msisdn(message['to_addr']), 'origin': message['from_addr'], 'messageid': v2n_message_id, 'provider': message['transport_metadata']['network_id'], 'tariff': message['transport_metadata'].get('tariff', 0), 'text': validate_characters(message['content']), 'subservice': self.config.get('subservice', message['transport_metadata'].get( 'keyword', '')), } params.update(message_params) log.msg('Hitting %s with %s' % (self.config['url'], params)) log.msg(urlencode(params)) try: response = yield http_request_full( self.config['url'], urlencode(params), { 'User-Agent': ['Vumi Vas2Net Transport'], 'Content-Type': ['application/x-www-form-urlencoded'], }, 'POST', agent_class=self.agent_factory) except ConnectionRefusedError: log.msg("Connection failed sending message:", message) raise TemporaryFailure('connection refused') log.msg('Headers', list(response.headers.getAllRawHeaders())) header = self.config.get('header', 'X-Nth-Smsid') if response.code != 200: raise PermanentFailure('server error: HTTP %s: %s' % (response.code, response.delivered_body)) if response.headers.hasHeader(header): transport_message_id = response.headers.getRawHeaders(header)[0] yield self.publish_ack( user_message_id=message['message_id'], sent_message_id=transport_message_id, ) else: err_msg = 'No SmsId Header, content: %s' % response.delivered_body yield self.publish_nack( user_message_id=message['message_id'], sent_message_id=message['message_id'], reason=err_msg) raise Vas2NetsTransportError(err_msg) def stopWorker(self): """shutdown""" super(Vas2NetsTransport, self).stopWorker() if hasattr(self, 'receipt_resource'): return self.receipt_resource.stopListening() PK=JGԡ$vumi/transports/vas2nets/__init__.py""" Vas2Nets HTTP SMS API. 
""" from vumi.transports.vas2nets.vas2nets import Vas2NetsTransport __all__ = ['Vas2NetsTransport'] PK=JG{䴯+vumi/transports/vas2nets/transport_stubs.py# -*- test-case-name: vumi.transports.vas2nets.tests.test_vas2nets_stubs -*- import uuid import random from urllib import urlencode from StringIO import StringIO from urlparse import urlparse from datetime import datetime from twisted.python import log from twisted.internet.defer import inlineCallbacks, Deferred, succeed from twisted.internet import reactor from twisted.internet.protocol import Protocol from twisted.web import http from twisted.web.resource import Resource from twisted.web.client import Agent from twisted.web.http_headers import Headers from vumi.utils import StringProducer from vumi.service import Worker class HttpResponseHandler(Protocol): def __init__(self, deferred): self.deferred = deferred self.stringio = StringIO() def dataReceived(self, bytes): self.stringio.write(bytes) def connectionLost(self, reason): self.deferred.callback(self.stringio.getvalue()) @classmethod def handle(cls, response): deferred = Deferred() response.deliverBody(cls(deferred)) return deferred @classmethod def req_POST(cls, url, params, headers=None): agent = Agent(reactor) hdrs = { 'User-Agent': ['Vumi Vas2Net Faker'], 'Content-Type': ['application/x-www-form-urlencoded'], } if headers: hdrs.update(headers) d = agent.request('POST', url, Headers(hdrs), StringProducer(urlencode(params))) return d.addCallback(cls.handle) class FakeVas2NetsHandler(Resource): """ Resource to accept outgoing messages and reply with a delivery report. """ isLeaf = True delay_choices = (0.5, 1, 1.5, 2, 5) deliver_hook = None def __init__(self, receipt_url, delay_choices=None, deliver_hook=None): if delay_choices: self.delay_choices = delay_choices if deliver_hook: self.deliver_hook = deliver_hook self.receipt_url = receipt_url def get_sms_id(self): return uuid.uuid4().get_hex()[-8:] def render_POST(self, request): request.setResponseCode(http.OK) required_fields = [ 'username', 'password', 'call-number', 'origin', 'text', 'messageid', 'provider', 'tariff', 'owner', 'service', 'subservice' ] log.msg('Received sms: %s' % (request.args,)) for key in required_fields: if key not in request.args: request.setResponseCode(http.BAD_REQUEST) sms_id = self.get_sms_id() request.setHeader('X-Nth-Smsid', sms_id) self.schedule_delivery(sms_id, *[request.args.get(f) for f in ['messageid', 'provider', 'sender']]) return "Result_code: 00, Message OK" def schedule_delivery(self, *args): if not self.receipt_url: return succeed(None) return reactor.callLater(random.choice(self.delay_choices), self.deliver_receipt, *args) def deliver_receipt(self, sms_id, message_id, provider, sender): CODES = [ ('2', 'DELIVRD'), ('2', 'DELIVRD'), ('2', 'DELIVRD'), ('2', 'DELIVRD'), ('-28', 'Presumably failed.'), ] code, status = random.choice(CODES) params = { 'smsid': sms_id, 'messageid': message_id, 'status': code, 'text': status, 'time': datetime.utcnow().strftime('%Y.%m.%d %H:%M:%S'), 'provider': provider, 'sender': sender, } log.msg("Sending receipt: %s" % (params,)) d = HttpResponseHandler.req_POST(self.receipt_url, params) if self.deliver_hook: d.addCallback(self.deliver_hook) return d class FakeVas2NetsWorker(Worker): delay_choices = None deliver_hook = None handler = FakeVas2NetsHandler @inlineCallbacks def startWorker(self): url = urlparse(self.config.get('url')) receipt_url = "http://127.0.0.1:%s%s" % ( self.config.get('web_port'), self.config.get('web_receipt_path')) self.receipt_resource = 
yield self.start_web_resources( [(self.handler(receipt_url, self.delay_choices, self.deliver_hook), url.path)], url.port) def stopWorker(self): if hasattr(self, 'receipt_resource'): self.receipt_resource.stopListening() PKqG%I*2*2/vumi/transports/vas2nets/tests/test_vas2nets.py# encoding: utf-8 import string from datetime import datetime from urllib import urlencode from twisted.web import http from twisted.python import log from twisted.internet.defer import inlineCallbacks from vumi.utils import http_request_full from vumi.message import TransportMessage from vumi.transports.failures import TemporaryFailure, PermanentFailure from vumi.transports.base import FailureMessage from vumi.transports.vas2nets.vas2nets import ( Vas2NetsTransport, validate_characters, Vas2NetsTransportError, Vas2NetsEncodingError, normalize_outbound_msisdn) from vumi.tests.helpers import VumiTestCase from vumi.tests.fake_connection import FakeHttpServer from vumi.transports.tests.helpers import TransportHelper class TestVas2NetsTransport(VumiTestCase): transport_type = 'sms' @inlineCallbacks def setUp(self): self.config = { 'url': 'http://vas2nets.example.com/', 'username': 'username', 'password': 'password', 'owner': 'owner', 'service': 'service', 'subservice': 'subservice', 'web_receive_path': '/receive', 'web_receipt_path': '/receipt', 'web_port': 0, } self.tx_helper = self.add_helper( TransportHelper(Vas2NetsTransport, transport_name='vas2nets')) self.transport = yield self.tx_helper.get_transport(self.config) self.transport_url = self.transport.get_transport_url() self.today = datetime.utcnow().date() def _make_handler(self, message_id, message, code, send_id): def handler(request): log.msg(request.content.read()) request.setResponseCode(code) required_fields = [ 'username', 'password', 'call-number', 'origin', 'text', 'messageid', 'provider', 'tariff', 'owner', 'service', 'subservice' ] log.msg('request.args', request.args) for key in required_fields: log.msg('checking for %s' % key) self.assertTrue(key in request.args) if send_id is not None: self.assertEqual(request.args['messageid'], [send_id]) if message_id: request.setHeader('X-Nth-Smsid', message_id) return message return handler def start_fake_http(self, msg_id, msg, code=http.OK, send_id=None): fake_http = FakeHttpServer( self._make_handler(msg_id, msg, code, send_id)) self.transport.agent_factory = fake_http.get_agent return fake_http def make_request(self, path, qparams): """ Builds a request URL with the appropriate params. 
""" args = { 'messageid': TransportMessage.generate_id(), 'time': self.today.strftime('%Y.%m.%d %H:%M:%S'), 'sender': '0041791234567', 'destination': '9292', 'provider': 'provider', 'keyword': '', 'header': '', 'text': '', 'keyword': '', } args.update(qparams) url = self.transport_url + path return http_request_full(url, urlencode(args), { 'Content-Type': ['application/x-www-form-urlencoded'], }) def make_delivery_report(self, status, tr_status, tr_message): transport_metadata = { 'delivery_message': tr_message, 'delivery_status': tr_status, 'network_id': 'provider', 'timestamp': self.today.strftime('%Y-%m-%dT%H:%M:%S'), } return self.tx_helper.make_delivery_report( self.tx_helper.make_outbound("foo", message_id="vas2nets.abc"), to_addr="+41791234567", delivery_status=status, transport_metadata=transport_metadata) def make_dispatch_outbound(self, content, **kw): transport_metadata = { 'original_message_id': 'vas2nets.def', 'keyword': '', 'network_id': 'provider', 'timestamp': self.today.strftime('%Y-%m-%dT%H:%M:%S'), } kw.setdefault('transport_metadata', transport_metadata) return self.tx_helper.make_dispatch_outbound(content, **kw) def assert_events_equal(self, expected, received): to_payload = lambda m: dict( (k, v) for k, v in m.payload.iteritems() if k not in ('event_id', 'timestamp', 'transport_type')) self.assertEqual(to_payload(expected), to_payload(received)) def assert_messages_equal(self, expected, received): to_payload = lambda m: dict( (k, v) for k, v in m.payload.iteritems() if k not in ('message_id', 'timestamp')) self.assertEqual(to_payload(expected), to_payload(received)) @inlineCallbacks def test_health_check(self): response = yield http_request_full(self.transport_url + "/health") self.assertEqual('OK', response.delivered_body) self.assertEqual(response.code, http.OK) @inlineCallbacks def test_receive_sms(self): response = yield self.make_request('/receive', { 'messageid': 'abc', 'text': 'hello world', }) self.assertEqual('', response.delivered_body) self.assertEqual(response.headers.getRawHeaders('content-type'), ['text/plain']) self.assertEqual(response.code, http.OK) [msg] = self.tx_helper.get_dispatched_inbound() expected_msg = self.tx_helper.make_inbound( "hello world", message_id='vas2nets.abc', transport_metadata={ 'original_message_id': 'vas2nets.abc', 'keyword': '', 'network_id': 'provider', 'timestamp': self.today.strftime('%Y-%m-%dT%H:%M:%S'), }) self.assert_messages_equal(expected_msg, msg) @inlineCallbacks def test_delivery_receipt_pending(self): response = yield self.make_request('/receipt', { 'smsid': '1', 'messageid': 'abc', 'sender': '+41791234567', 'status': '1', 'text': 'Message submitted to Provider for delivery.', }) self.assertEqual('', response.delivered_body) self.assertEqual(response.headers.getRawHeaders('content-type'), ['text/plain']) self.assertEqual(response.code, http.OK) msg = self.make_delivery_report( 'pending', '1', 'Message submitted to Provider for delivery.') [dr] = self.tx_helper.get_dispatched_events() self.assert_events_equal(msg, dr) @inlineCallbacks def test_delivery_receipt_failed(self): response = yield self.make_request('/receipt', { 'smsid': '1', 'messageid': 'abc', 'sender': '+41791234567', 'status': '-9', 'text': 'Message could not be delivered.', }) self.assertEqual('', response.delivered_body) self.assertEqual(response.headers.getRawHeaders('content-type'), ['text/plain']) self.assertEqual(response.code, http.OK) msg = self.make_delivery_report( 'failed', '-9', 'Message could not be delivered.') [dr] = 
self.tx_helper.get_dispatched_events() self.assert_events_equal(msg, dr) @inlineCallbacks def test_delivery_receipt_delivered(self): response = yield self.make_request('/receipt', { 'smsid': '1', 'messageid': 'abc', 'sender': '+41791234567', 'status': '2', 'text': 'Message delivered to MSISDN.', }) self.assertEqual('', response.delivered_body) self.assertEqual(response.headers.getRawHeaders('content-type'), ['text/plain']) self.assertEqual(response.code, http.OK) msg = self.make_delivery_report( 'delivered', '2', 'Message delivered to MSISDN.') [dr] = self.tx_helper.get_dispatched_events() self.assert_events_equal(msg, dr) def test_validate_characters(self): self.assertRaises( Vas2NetsEncodingError, validate_characters, u"ïøéå¬∆˚") self.assertTrue(validate_characters(string.ascii_lowercase)) self.assertTrue(validate_characters(string.ascii_uppercase)) self.assertTrue(validate_characters('0123456789')) self.assertTrue(validate_characters(u'äöü ÄÖÜ àùò ìèé §Ññ £$@')) self.assertTrue(validate_characters(u'/?!#%&()*+,-:;<=>.')) self.assertTrue(validate_characters(u'testing\ncarriage\rreturns')) self.assertTrue(validate_characters(u'testing "quotes"')) self.assertTrue(validate_characters(u"testing 'quotes'")) @inlineCallbacks def test_send_sms_success(self): mocked_message_id = TransportMessage.generate_id() mocked_message = "Result_code: 00, Message OK" # open an HTTP resource that mocks the Vas2Nets response for the # duration of this test self.start_fake_http(mocked_message_id, mocked_message) sent_msg = yield self.make_dispatch_outbound("hello") msg = self.tx_helper.make_ack( sent_msg, sent_message_id=mocked_message_id) [ack] = self.tx_helper.get_dispatched_events() self.assert_events_equal(msg, ack) @inlineCallbacks def test_send_sms_reply_success(self): mocked_message_id = TransportMessage.generate_id() reply_to_msgid = TransportMessage.generate_id() mocked_message = "Result_code: 00, Message OK" # open an HTTP resource that mocks the Vas2Nets response for the # duration of this test self.start_fake_http( mocked_message_id, mocked_message, send_id=reply_to_msgid) sent_msg = yield self.make_dispatch_outbound( "hello", in_reply_to=reply_to_msgid) msg = self.tx_helper.make_ack( sent_msg, sent_message_id=mocked_message_id) [ack] = self.tx_helper.get_dispatched_events() self.assert_events_equal(msg, ack) @inlineCallbacks def test_send_sms_fail(self): mocked_message_id = False mocked_message = ("Result_code: 04, Internal system error occurred " "while processing message") self.start_fake_http(mocked_message_id, mocked_message) msg = yield self.make_dispatch_outbound("hello") [twisted_failure] = self.flushLoggedErrors(Vas2NetsTransportError) failure = twisted_failure.value self.assertTrue("No SmsId Header" in str(failure)) [fmsg] = self.tx_helper.get_dispatched_failures() self.assertEqual(msg.payload, fmsg['message']) self.assertTrue( "Vas2NetsTransportError: No SmsId Header" in fmsg['reason']) [nack] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(nack['user_message_id'], msg['message_id']) self.assertEqual(nack['sent_message_id'], msg['message_id']) self.assertTrue("No SmsId Header" in nack['nack_reason']) @inlineCallbacks def test_send_sms_noconn(self): # TODO: Figure out a solution that doesn't require hoping that # nothing's listening on this port. 
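        # Pointing the transport at an (assumed) unused local port makes the
        # outbound POST fail with ConnectionRefusedError, which
        # handle_outbound_message converts into a TemporaryFailure.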
self.transport.config['url'] = 'http://127.0.0.1:9999/' msg = yield self.make_dispatch_outbound("hello") [twisted_failure] = self.flushLoggedErrors(TemporaryFailure) failure = twisted_failure.value self.assertTrue("connection refused" in str(failure)) [fmsg] = self.tx_helper.get_dispatched_failures() self.assertEqual(msg.payload, fmsg['message']) self.assertEqual(fmsg['failure_code'], FailureMessage.FC_TEMPORARY) self.assertTrue(fmsg['reason'].strip().endswith("connection refused")) @inlineCallbacks def test_send_sms_not_OK(self): mocked_message = "Page not found." self.start_fake_http(None, mocked_message, http.NOT_FOUND) msg = yield self.make_dispatch_outbound("hello") [twisted_failure] = self.flushLoggedErrors(PermanentFailure) failure = twisted_failure.value self.assertTrue("server error: HTTP 404:" in str(failure)) [fmsg] = self.tx_helper.get_dispatched_failures() self.assertEqual(msg.payload, fmsg['message']) self.assertEqual(fmsg['failure_code'], FailureMessage.FC_PERMANENT) self.assertTrue(fmsg['reason'].strip() .endswith("server error: HTTP 404: Page not found.")) def test_normalize_outbound_msisdn(self): self.assertEqual( normalize_outbound_msisdn('+27761234567'), '0027761234567') PKqG7j/vumi/transports/vas2nets/tests/test_failures.py# encoding: utf-8 from datetime import datetime from twisted.web import http from twisted.internet.defer import inlineCallbacks, returnValue, Deferred from vumi.transports.failures import ( FailureMessage, FailureWorker, TemporaryFailure) from vumi.transports.vas2nets.vas2nets import ( Vas2NetsTransport, Vas2NetsTransportError) from vumi.tests.fake_connection import FakeHttpServer from vumi.tests.helpers import ( VumiTestCase, MessageHelper, WorkerHelper, PersistenceHelper) class FailureCounter(object): def __init__(self, count): self.count = count self.failures = 0 self.deferred = Deferred() def __call__(self): self.failures += 1 if self.failures >= self.count: self.deferred.callback(None) class TestVas2NetsFailureWorker(VumiTestCase): @inlineCallbacks def setUp(self): self.persistence_helper = self.add_helper(PersistenceHelper()) self.msg_helper = self.add_helper(MessageHelper()) self.worker_helper = self.add_helper( WorkerHelper(self.msg_helper.transport_name)) self.today = datetime.utcnow().date() self.worker = yield self.mk_transport_worker({ 'transport_name': self.msg_helper.transport_name, 'url': 'http://vas2nets.example.com/', 'username': 'username', 'password': 'password', 'owner': 'owner', 'service': 'service', 'subservice': 'subservice', 'web_receive_path': '/receive', 'web_receipt_path': '/receipt', 'web_port': 0, }) self.fail_worker = yield self.mk_failure_worker({ 'transport_name': self.msg_helper.transport_name, 'retry_routing_key': '%(transport_name)s.outbound', 'failures_routing_key': '%(transport_name)s.failures', }) def mk_transport_worker(self, config): config = self.persistence_helper.mk_config(config) return self.worker_helper.get_worker(Vas2NetsTransport, config) @inlineCallbacks def mk_failure_worker(self, config): config = self.persistence_helper.mk_config(config) worker = yield self.worker_helper.get_worker( FailureWorker, config, start=False) worker.retry_publisher = yield self.worker.publish_to("foo") yield worker.startWorker() self.redis = worker.redis returnValue(worker) def mk_fake_http(self, body, headers=None, code=http.OK): if headers is None: headers = {'X-Nth-Smsid': 'message_id'} def handler(request): request.setResponseCode(code) for k, v in headers.items(): request.setHeader(k, v) return body fake_http = 
FakeHttpServer(handler) self.worker.agent_factory = fake_http.get_agent return fake_http def get_dispatched_failures(self): return self.worker_helper.get_dispatched( None, 'failures', FailureMessage) @inlineCallbacks def get_retry_keys(self): timestamps = yield self.redis.zrange('retry_timestamps', 0, 0) retry_keys = set() for timestamp in timestamps: bucket_key = "retry_keys." + timestamp retry_keys.update((yield self.redis.smembers(bucket_key))) returnValue(retry_keys) def make_outbound(self, content, **kw): kw.setdefault('transport_metadata', {'network_id': 'network-id'}) return self.msg_helper.make_outbound(content, **kw) @inlineCallbacks def test_send_sms_success(self): self.mk_fake_http("Result_code: 00, Message OK") msg = self.make_outbound("outbound") yield self.worker_helper.dispatch_outbound(msg) self.assertEqual(1, len(self.worker_helper.get_dispatched_events())) self.assertEqual(0, len(self.get_dispatched_failures())) @inlineCallbacks def test_send_sms_fail(self): """ A 'No SmsId Header' error should not be retried. """ self.worker.failure_published = FailureCounter(1) self.mk_fake_http("Result_code: 04, Internal system error " "occurred while processing message", {}) msg = self.make_outbound("outbound") yield self.worker_helper.dispatch_outbound(msg) yield self.worker.failure_published.deferred yield self.worker_helper.kick_delivery() self.assertEqual(1, len(self.worker_helper.get_dispatched_events())) self.assertEqual(1, len(self.get_dispatched_failures())) [twisted_failure] = self.flushLoggedErrors(Vas2NetsTransportError) failure = twisted_failure.value self.assertTrue("No SmsId Header" in str(failure)) [fmsg] = self.get_dispatched_failures() self.assertTrue( "Vas2NetsTransportError: No SmsId Header" in fmsg['reason']) [nack] = self.worker_helper.get_dispatched_events() self.assertTrue( "No SmsId Header" in nack['nack_reason']) yield self.worker_helper.kick_delivery() [key] = yield self.fail_worker.get_failure_keys() self.assertEqual(set(), (yield self.get_retry_keys())) @inlineCallbacks def test_send_sms_noconn(self): """ A 'connection refused' error should be retried. """ # TODO: Figure out a solution that doesn't require hoping that # nothing's listening on this port. 
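        # (The ephemeral-port approach sketched alongside the same TODO in
        #  test_vas2nets.py applies here too.)
        #
        # The retry bookkeeping asserted at the end of this test appears to
        # work as follows (inferred from get_retry_keys above, so treat the
        # exact key names as assumptions): 'retry_timestamps' is a sorted
        # set of bucket timestamps, and each bucket 'retry_keys.<timestamp>'
        # is a set of failure keys scheduled for retry at that time. A
        # minimal inspection sketch using the same redis manager:
        #
        #     @inlineCallbacks
        #     def dump_retry_buckets(redis):
        #         timestamps = yield redis.zrange('retry_timestamps', 0, -1)
        #         for timestamp in timestamps:
        #             keys = yield redis.smembers('retry_keys.' + timestamp)
        #             log.msg("bucket %s: %r" % (timestamp, sorted(keys)))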
self.worker.config['url'] = 'http://127.0.0.1:9999/' self.worker.failure_published = FailureCounter(1) msg = self.make_outbound("outbound") yield self.worker_helper.dispatch_outbound(msg) yield self.worker.failure_published.deferred self.assertEqual(0, len(self.worker_helper.get_dispatched_events())) self.assertEqual(1, len(self.get_dispatched_failures())) [twisted_failure] = self.flushLoggedErrors(TemporaryFailure) failure = twisted_failure.value self.assertTrue("connection refused" in str(failure)) [fmsg] = self.get_dispatched_failures() self.assertEqual(msg.payload, fmsg['message']) self.assertEqual(FailureMessage.FC_TEMPORARY, fmsg['failure_code']) self.assertTrue(fmsg['reason'].strip().endswith("connection refused")) yield self.worker_helper.kick_delivery() [key] = yield self.fail_worker.get_failure_keys() self.assertEqual(set([key]), (yield self.get_retry_keys())) PK=JG*vumi/transports/vas2nets/tests/__init__.pyPK=JGDy5vumi/transports/vas2nets/tests/test_vas2nets_stubs.py# encoding: utf-8 from datetime import datetime from urllib import urlencode from twisted.web import http from twisted.web.resource import Resource from twisted.web.server import NOT_DONE_YET from twisted.web.client import Agent from twisted.web.http_headers import Headers from twisted.python import log from twisted.internet import reactor from twisted.internet.defer import inlineCallbacks, Deferred from twisted.web.test.test_web import DummyRequest from vumi.service import Worker from vumi.transports.vas2nets.transport_stubs import ( FakeVas2NetsHandler, FakeVas2NetsWorker) from vumi.utils import StringProducer from vumi.tests.helpers import VumiTestCase, WorkerHelper def create_request(params={}, path='/', method='POST'): """ Creates a dummy Vas2Nets request for testing our resources with """ request = DummyRequest(path) request.method = method request.args = params return request class DummyResource(Resource): isLeaf = True def inc_receipts(self): self.receipts += 1 def render_POST(self, request): log.msg(request.content.read()) request.setResponseCode(http.OK) required_fields = ['smsid', 'status', 'text', 'time', 'provider', 'sender', 'messageid'] log.msg('request.args', request.args) for key in required_fields: log.msg('checking for %s' % key) assert key in request.args self.inc_receipts() return '' class DummyWorker(Worker): @inlineCallbacks def startWorker(self): self.test_resource = DummyResource() self.test_resource.receipts = 0 self.resource = yield self.start_web_resources( [ (self.test_resource, self.config['web_receipt_path']), ], self.config['web_port'], ) def stopWorker(self): if hasattr(self, 'resource'): self.resource.stopListening() class StubbedFakeVas2NetsWorker(FakeVas2NetsWorker): delay_choices = (0,) class TestFakeVas2NetsWorker(VumiTestCase): @inlineCallbacks def setUp(self): self.worker_helper = self.add_helper(WorkerHelper()) self.config = { 'web_port': 9999, 'web_receive_path': '/t/receive', 'web_receipt_path': '/t/receipt', 'url': 'http://127.0.0.1:9998/t/send', } self.worker = yield self.worker_helper.get_worker( StubbedFakeVas2NetsWorker, self.config, start=False) self.test_worker = yield self.worker_helper.get_worker( DummyWorker, self.config, start=False) self.today = datetime.utcnow().date() def render_request(self, resource, request): d = request.notifyFinish() response = resource.render(request) if response != NOT_DONE_YET: request.write(response) request.finish() return d @inlineCallbacks def test_receive_sent_sms(self): resource = 
FakeVas2NetsHandler('http://127.0.0.1:9999/t/receipt', (0,))
        resource.schedule_delivery = lambda *a: None
        request = create_request({
            'username': 'user',
            'password': 'pass',
            'owner': 'owner',
            'service': 'service',
            'subservice': 'subservice',
            'call-number': '+27831234567',
            'origin': '12345',
            'messageid': 'message_id',
            'provider': 'provider',
            'tariff': 0,
            'text': 'message content',
        })
        yield self.render_request(resource, request)
        self.assertEquals(''.join(request.written),
                          "Result_code: 00, Message OK")
        self.assertEquals(request.responseCode, http.OK)

    @inlineCallbacks
    def test_deliver_receipt(self):
        resource = FakeVas2NetsHandler(
            'http://127.0.0.1:9999/t/receipt', (0,))
        yield self.test_worker.startWorker()
        yield resource.deliver_receipt('smsid', 'msgid', 'provider', 'sender')

    @inlineCallbacks
    def test_round_trip(self):
        d = Deferred()
        self.worker.deliver_hook = lambda x: d.callback(None)
        self.worker.startWorker()
        self.test_worker.startWorker()
        params = {
            'username': 'user',
            'password': 'pass',
            'owner': 'owner',
            'service': 'service',
            'subservice': 'subservice',
            'call-number': '+27831234567',
            'origin': '12345',
            'messageid': 'message_id',
            'provider': 'provider',
            'tariff': 0,
            'text': 'message content',
        }
        agent = Agent(reactor)
        response = yield agent.request(
            'POST', self.config['url'],
            Headers({
                'User-Agent': ['Vumi Vas2Net Transport'],
                'Content-Type': ['application/x-www-form-urlencoded'],
            }),
            StringProducer(urlencode(params)))
        log.msg('Headers', list(response.headers.getAllRawHeaders()))
        self.assertTrue(response.headers.hasHeader('X-Nth-Smsid'))
        yield d
PK=JGZv1))vumi/transports/irc/irc.py# -*- test-case-name: vumi.transports.irc.tests.test_irc -*-
"""IRC transport."""

from twisted.words.protocols import irc
from twisted.internet import protocol
from twisted.internet.defer import inlineCallbacks
from twisted.python import log

from vumi.config import (
    ConfigClientEndpoint, ConfigText, ConfigList, ConfigInt,
    ClientEndpointFallback)
from vumi.reconnecting_client import ReconnectingClientService
from vumi.transports import Transport
from vumi.transports.failures import TemporaryFailure


class IrcMessage(object):
    """Container for details of a message to or from an IRC user.

    :type sender: str
    :param sender:
        Who sent the message (usually user!ident@hostmask).
    :type recipient: str
    :param recipient:
        User or channel receiving the message.
    :type content: str
    :param content:
        Contents of message.
    :type nickname: str
    :param nickname:
        Nickname used by the client that received the message.
        Optional.
    :type command: str
    :param command:
        IRC command that produced the message.
    """

    def __init__(self, sender, command, recipient, content, nickname=None):
        self.sender = self.canonicalize_recipient(sender)
        self.command = command
        self.recipient = self.canonicalize_recipient(recipient)
        self.content = content
        self.nickname = nickname

    def __eq__(self, other):
        if isinstance(other, IrcMessage):
            return all(getattr(self, name) == getattr(other, name)
                       for name in ("sender", "command", "recipient",
                                    "content", "nickname"))
        return False

    @staticmethod
    def canonicalize_recipient(recipient):
        """Convert a generic IRC address (with possible server parts)
        to a simple lowercase username or channel."""
        return recipient.partition('!')[0].lower()

    def channel(self):
        """Return the channel if the recipient is a channel.

        Otherwise return None.
""" if self.recipient[:1] in ('#', '&', '$'): return self.recipient return None def addressed_to(self, nickname): nickname = self.canonicalize_recipient(nickname) if not self.channel(): return self.recipient == nickname parts = self.content.split(None, 1) maybe_nickname = parts[0].rstrip(':,') if parts else '' maybe_nickname = self.canonicalize_recipient(maybe_nickname) return maybe_nickname == nickname class VumiBotProtocol(irc.IRCClient): """An IRC bot that bridges IRC to Vumi.""" def __init__(self, nickname, channels, irc_transport): self.connected = False self.nickname = nickname self.channels = channels self.irc_transport = irc_transport def publish_message(self, irc_msg): self.irc_transport.handle_inbound_irc_message(irc_msg) def consume_message(self, irc_msg): recipient = irc_msg.recipient.encode('utf8') content = irc_msg.content.encode('utf8') if irc_msg.command == 'ACTION': self.describe(recipient, content) else: self.msg(recipient, content) # connecting and disconnecting from server def connectionMade(self): irc.IRCClient.connectionMade(self) self.connected = True log.msg("Connected (nickname is: %s)" % (self.nickname,)) def connectionLost(self, reason): irc.IRCClient.connectionLost(self, reason) self.connected = False log.msg("Disconnected (nickname was: %s)." % (self.nickname,)) # callbacks for events def signedOn(self): """Called when bot has succesfully signed on to server.""" log.msg("Attempting to join channels: %r" % (self.channels,)) for channel in self.channels: self.join(channel) def joined(self, channel): """This will get called when the bot joins the channel.""" log.msg("Joined %r" % (channel,)) def privmsg(self, sender, recipient, message): """This will get called when the bot receives a message.""" irc_msg = IrcMessage(sender, 'PRIVMSG', recipient, message, self.nickname) self.publish_message(irc_msg) def noticed(self, sender, recipient, message): """This will get called when the bot receives a notice.""" irc_msg = IrcMessage(sender, 'NOTICE', recipient, message, self.nickname) self.publish_message(irc_msg) def action(self, sender, recipient, message): """This will get called when the bot sees someone do an action.""" irc_msg = IrcMessage(sender, 'ACTION', recipient, message, self.nickname) self.publish_message(irc_msg) # irc callbacks def irc_NICK(self, prefix, params): """Called when an IRC user changes their nickname.""" old_nick = prefix.partition('!')[0] new_nick = params[0] log.msg("Nick changed from %r to %r" % (old_nick, new_nick)) # For fun, override the method that determines how a nickname is changed on # collisions. The default method appends an underscore. def alterCollidedNick(self, nickname): """ Generate an altered version of a nickname that caused a collision in an effort to create an unused related name for subsequent registration. """ return nickname + '^' class VumiBotFactory(protocol.ClientFactory): """A factory for :class:`VumiBotClient` instances. A new protocol instance will be created each time we connect to the server. """ # the class of the protocol to build when new connection is made protocol = VumiBotProtocol def __init__(self, vumibot_args): self.vumibot_args = vumibot_args self.irc_server = None self.vumibot = None def format_server_address(self, addr): # getattr is used in case someone connects to an # endpoint that isn't an IPv4 or IPv6 endpoint. 
return "%s:%s" % ( getattr(addr, 'host', 'unknown'), getattr(addr, 'port', 'unknown') ) def buildProtocol(self, addr): self.irc_server = self.format_server_address(addr) self.vumibot = self.protocol(*self.vumibot_args) return self.vumibot class IrcConfig(Transport.CONFIG_CLASS): """ IRC transport config. """ twisted_endpoint = ConfigClientEndpoint( "Endpoint to connect to the IRC server on.", fallbacks=[ClientEndpointFallback('network', 'port')], required=True, static=True) nickname = ConfigText( "IRC nickname for the transport IRC client to use.", required=True, static=True) channels = ConfigList( "List of channels to join.", default=(), static=True) # TODO: Deprecate these fields when confmodel#5 is done. network = ConfigText( "*DEPRECATED* 'network' and 'port' fields may be used in place of the" " 'twisted_endpoint' field.", static=True) port = ConfigInt( "*DEPRECATED* 'network' and 'port' fields may be used in place of the" " 'twisted_endpoint' field.", static=True, default=6667) class IrcTransport(Transport): """ IRC based transport. """ CONFIG_CLASS = IrcConfig factory = None service = None def setup_transport(self): config = self.get_static_config() self.factory = VumiBotFactory((config.nickname, config.channels, self)) self.service = ReconnectingClientService( config.twisted_endpoint, self.factory) self.service.startService() @inlineCallbacks def teardown_transport(self): if self.service is not None: yield self.service.stopService() def handle_inbound_irc_message(self, irc_msg): irc_server = self.factory.irc_server irc_channel = irc_msg.channel() nickname = irc_msg.nickname to_addr = None content = irc_msg.content if irc_channel is None: # This is a direct message, not a channel message. to_addr = irc_msg.recipient elif irc_msg.addressed_to(nickname): # This is a channel message, but we've been mentioned by name. to_addr = nickname # Strip the name prefix, so workers don't have to handle it. content = (content.split(None, 1) + [''])[1] message_dict = { 'to_addr': to_addr, 'from_addr': irc_msg.sender, 'group': irc_channel, 'content': content, 'transport_name': self.transport_name, 'transport_type': self.config.get('transport_type', 'irc'), 'helper_metadata': { 'irc': { 'transport_nickname': nickname, 'addressed_to_transport': irc_msg.addressed_to(nickname), 'irc_server': irc_server, 'irc_channel': irc_channel, 'irc_command': irc_msg.command, }, }, 'transport_metadata': { 'irc_channel': irc_channel, }, } self.publish_message(**message_dict) @inlineCallbacks def handle_outbound_message(self, msg): vumibot = self.factory.vumibot if vumibot is None or not vumibot.connected: raise TemporaryFailure("IrcTransport not connected.") irc_metadata = msg['helper_metadata'].get('irc', {}) transport_metadata = msg['transport_metadata'] irc_command = irc_metadata.get('irc_command', 'PRIVMSG') # Continue to support pre-group-chat hackery. irc_channel = msg.get('group') or transport_metadata.get('irc_channel') recipient = irc_channel if irc_channel is not None else msg['to_addr'] content = msg['content'] if irc_channel and msg['to_addr'] and (irc_command != 'ACTION'): # We have a directed channel message, so prefix with the nick. content = "%s: %s" % (msg['to_addr'], content) irc_msg = IrcMessage(vumibot.nickname, irc_command, recipient, content) vumibot.consume_message(irc_msg) # intentionally duplicate message id in sent_message_id since # IRC doesn't have its own message ids. 
yield self.publish_ack(user_message_id=msg['message_id'], sent_message_id=msg['message_id']) PK=JGdgccvumi/transports/irc/__init__.py"""IRC Transport.""" from vumi.transports.irc.irc import IrcTransport __all__ = ['IrcTransport'] PK=JGܼ/E/E%vumi/transports/irc/tests/test_irc.py"""Tests for vumi.transports.irc.irc.""" from StringIO import StringIO from twisted.internet.defer import (inlineCallbacks, returnValue, DeferredQueue, Deferred) from twisted.internet.protocol import FileWrapper from vumi.tests.utils import LogCatcher from vumi.transports.failures import FailureMessage, TemporaryFailure from vumi.transports.irc.irc import IrcMessage, VumiBotProtocol from vumi.transports.irc import IrcTransport from vumi.transports.tests.helpers import TransportHelper from vumi.tests.helpers import VumiTestCase class TestIrcMessage(VumiTestCase): def test_message(self): msg = IrcMessage('user!userfoo@example.com', 'PRIVMSG', '#bar', 'hello?') self.assertEqual(msg.sender, 'user') self.assertEqual(msg.command, 'PRIVMSG') self.assertEqual(msg.recipient, '#bar') self.assertEqual(msg.content, 'hello?') def test_action(self): msg = IrcMessage('user!userfoo@example.com', 'ACTION', '#bar', 'hello?') self.assertEqual(msg.command, 'ACTION') def test_channel(self): msg = IrcMessage('user!userfoo@example.com', 'PRIVMSG', '#bar', 'hello?') self.assertEqual(msg.channel(), '#bar') msg = IrcMessage('user!userfoo@example.com', 'PRIVMSG', 'user2!user2@example.com', 'hello?') self.assertEqual(msg.channel(), None) def test_nick(self): msg = IrcMessage('user!userfoo@example.com', 'PRIVMSG', '#bar', 'hello?', 'nicktest') self.assertEqual(msg.nickname, 'nicktest') def test_addressed_to(self): msg = IrcMessage('user!userfoo@example.com', 'PRIVMSG', 'otheruser!userfoo@example.com', 'hello?', 'nicktest') self.assertFalse(msg.addressed_to('user')) self.assertTrue(msg.addressed_to('otheruser')) def test_equality(self): msg1 = IrcMessage('user!userfoo@example.com', 'PRIVMSG', '#bar', 'hello?') msg2 = IrcMessage('user!userfoo@example.com', 'PRIVMSG', '#bar', 'hello?') self.assertTrue(msg1 == msg2) def test_inequality(self): msg1 = IrcMessage('user!userfoo@example.com', 'PRIVMSG', '#bar', 'hello?') self.assertFalse(msg1 == object()) def test_canonicalize_recipient(self): canonical = IrcMessage.canonicalize_recipient self.assertEqual(canonical("user!userfoo@example.com"), "user") self.assertEqual(canonical("#channel"), "#channel") self.assertEqual(canonical("userfoo"), "userfoo") class TestVumiBotProtocol(VumiTestCase): nick = "testnick" channel = "#test1" def setUp(self): self.f = StringIO() self.t = FileWrapper(self.f) self.vb = VumiBotProtocol(self.nick, [self.channel], self) self.vb.makeConnection(self.t) self.recvd_messages = [] def handle_inbound_irc_message(self, irc_msg): self.recvd_messages.append(irc_msg) def check(self, lines): connect_lines = [ "NICK %s" % self.nick, # foo and bar are twisted's mis-implementation of RFC 2812 # Compare http://tools.ietf.org/html/rfc2812#section-3.1.3 # and http://twistedmatrix.com/trac/browser/tags/releases/ # twisted-11.0.0/twisted/words/protocols/irc.py#L1552 "USER %s foo bar :None" % self.nick, ] expected_lines = connect_lines + lines self.assertEqual(self.f.getvalue().splitlines(), expected_lines) def test_publish_message(self): msg = IrcMessage('user!userfoo@example.com', 'PRIVMSG', '#bar', 'hello?') self.vb.publish_message(msg) self.check([]) [recvd_msg] = self.recvd_messages self.assertEqual(recvd_msg, msg) def test_consume_message_privmsg(self): 
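        # consume_message() (exercised here and in the ACTION test below)
        # maps commands onto IRCClient calls: PRIVMSG -> self.msg(), and
        # ACTION -> self.describe(); describe() is what produces the CTCP
        # '\x01ACTION ...\x01' framing asserted below.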
self.vb.consume_message(IrcMessage('user!userfoo@example.com', 'PRIVMSG', '#bar', 'hello?')) self.check(["PRIVMSG #bar :hello?"]) def test_consume_message_action(self): self.vb.consume_message(IrcMessage('user!userfoo@example.com', 'ACTION', '#bar', 'hello?')) self.check(["PRIVMSG #bar :\x01ACTION hello?\x01"]) def test_connection_made(self): # just check that the connect messages made it through self.check([]) def test_connection_lost(self): with LogCatcher() as logger: self.vb.connectionLost("test loss of connection") [logmsg] = logger.messages() self.assertEqual(logmsg, 'Disconnected (nickname was: %s).' % self.nick) self.assertEqual(logger.errors, []) def test_signed_on(self): self.vb.signedOn() self.check(['JOIN %s' % self.channel]) def test_joined(self): with LogCatcher() as logger: self.vb.joined(self.channel) [logmsg] = logger.messages() self.assertEqual(logmsg, 'Joined %r' % self.channel) def test_privmsg(self): sender, command, recipient, text = (self.nick, 'PRIVMSG', "#zoo", "Hello zooites") self.vb.privmsg(sender, recipient, text) [recvd_msg] = self.recvd_messages self.assertEqual(recvd_msg, IrcMessage(sender, command, recipient, text, self.vb.nickname)) def test_action(self): sender, command, recipient, text = (self.nick, 'ACTION', "#zoo", "waves at zooites") self.vb.action(sender, recipient, text) [recvd_msg] = self.recvd_messages self.assertEqual(recvd_msg, IrcMessage(sender, command, recipient, text, self.vb.nickname)) def test_irc_nick(self): with LogCatcher() as logger: self.vb.irc_NICK("oldnick!host", ["newnick"]) [logmsg] = logger.messages() self.assertEqual(logmsg, "Nick changed from 'oldnick' to 'newnick'") def test_alter_collided_nick(self): collided_nick = "commonnick" new_nick = self.vb.alterCollidedNick(collided_nick) self.assertEqual(new_nick, collided_nick + '^') from twisted.internet.protocol import ServerFactory from twisted.internet import reactor from twisted.words.protocols.irc import IRC class StubbyIrcServerProtocol(IRC): hostname = '127.0.0.1' def irc_unknown(self, prefix, command, params): self.factory.events.put((prefix, command, params)) def connectionLost(self, reason): IRC.connectionLost(self, reason) self.factory.finished_d.callback(None) class StubbyIrcServer(ServerFactory): protocol = StubbyIrcServerProtocol def startFactory(self): self.server = None self.events = DeferredQueue() self.finished_d = Deferred() def buildProtocol(self, addr): self.server = ServerFactory.buildProtocol(self, addr) self.server.factory = self return self.server @inlineCallbacks def filter_events(self, command_type): while True: ev = yield self.events.get() if ev[1] == command_type: returnValue(ev) class TestIrcTransport(VumiTestCase): nick = 'vumibottest' @inlineCallbacks def setUp(self): self.irc_server = StubbyIrcServer() self.add_cleanup(lambda: self.irc_server.finished_d) self.tx_helper = self.add_helper(TransportHelper(IrcTransport)) self.irc_connector = yield reactor.listenTCP( 0, self.irc_server, interface='127.0.0.1') self.add_cleanup(self.irc_connector.stopListening) addr = self.irc_connector.getHost() self.server_addr = "%s:%s" % (addr.host, addr.port) self.transport = yield self.tx_helper.get_transport({ 'network': addr.host, 'port': addr.port, 'channels': [], 'nickname': self.nick, }) # wait for transport to connect yield self.irc_server.filter_events("NICK") def dispatch_outbound_irc(self, *args, **kw): helper_metadata = kw.setdefault('helper_metadata', {'irc': {}}) irc_command = kw.pop('irc_command', None) if irc_command is not None: 
helper_metadata['irc']['irc_command'] = irc_command return self.tx_helper.make_dispatch_outbound(*args, **kw) def assert_inbound_message(self, msg, to_addr, from_addr, channel, content, addressed_to_transport, irc_command): self.assertEqual(msg['transport_name'], self.tx_helper.transport_name) self.assertEqual(msg['to_addr'], to_addr) self.assertEqual(msg['from_addr'], from_addr) self.assertEqual(msg['group'], channel) self.assertEqual(msg['content'], content) self.assertEqual(msg['helper_metadata'], { 'irc': { 'transport_nickname': self.nick, 'addressed_to_transport': addressed_to_transport, 'irc_server': self.server_addr, 'irc_channel': channel, 'irc_command': irc_command, } }) self.assertEqual(msg['transport_metadata'], { 'irc_channel': channel, }) def assert_ack_for(self, msg, ack): to_payload = lambda m: dict( (k, v) for k, v in m.payload.iteritems() if k not in ('event_id', 'timestamp', 'transport_type')) self.assertEqual(to_payload(self.tx_helper.make_ack(msg)), to_payload(ack)) def send_irc_message(self, content, recipient, sender="user!ident@host"): self.irc_server.server.privmsg(sender, recipient, content) @inlineCallbacks def test_handle_inbound_to_channel(self): text = "Hello gooites" self.send_irc_message(text, "#zoo") [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assert_inbound_message(msg, to_addr=None, from_addr="user", channel="#zoo", content=text, addressed_to_transport=False, irc_command="PRIVMSG") @inlineCallbacks def test_handle_inbound_to_channel_directed(self): self.send_irc_message("%s: Hi" % (self.nick,), "#zoo") [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assert_inbound_message(msg, to_addr=self.nick, from_addr="user", channel="#zoo", content="Hi", addressed_to_transport=True, irc_command="PRIVMSG") @inlineCallbacks def test_handle_inbound_to_user(self): self.send_irc_message("Hi there", "%s!bot@host" % (self.nick,)) [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assert_inbound_message(msg, to_addr=self.nick, from_addr="user", channel=None, content="Hi there", addressed_to_transport=True, irc_command="PRIVMSG") @inlineCallbacks def test_handle_inbound_channel_notice(self): sender, recipient, text = "user!ident@host", "#zoo", "Hello gooites" self.irc_server.server.notice(sender, recipient, text) [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(msg['transport_name'], self.tx_helper.transport_name) self.assertEqual(msg['to_addr'], None) self.assertEqual(msg['from_addr'], "user") self.assertEqual(msg['group'], "#zoo") self.assertEqual(msg['content'], text) self.assertEqual(msg['helper_metadata'], { 'irc': { 'transport_nickname': self.nick, 'addressed_to_transport': False, 'irc_server': self.server_addr, 'irc_channel': '#zoo', 'irc_command': 'NOTICE', } }) self.assertEqual(msg['transport_metadata'], { 'irc_channel': '#zoo', }) @inlineCallbacks def test_handle_inbound_user_notice(self): sender, recipient, text = "user!ident@host", "bot", "Hello gooites" self.irc_server.server.notice(sender, recipient, text) [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(msg['transport_name'], self.tx_helper.transport_name) self.assertEqual(msg['to_addr'], "bot") self.assertEqual(msg['from_addr'], "user") self.assertEqual(msg['group'], None) self.assertEqual(msg['content'], text) self.assertEqual(msg['helper_metadata'], { 'irc': { 'transport_nickname': self.nick, 'addressed_to_transport': False, 'irc_server': self.server_addr, 'irc_channel': None, 'irc_command': 'NOTICE', } 
}) self.assertEqual(msg['transport_metadata'], { 'irc_channel': None, }) @inlineCallbacks def test_handle_outbound_message_while_disconnected(self): yield self.irc_connector.stopListening() self.transport.factory.vumibot.connectionLost("testing disconnect") expected_error = "IrcTransport not connected." yield self.dispatch_outbound_irc("outbound") [error] = self.tx_helper.get_dispatched_failures() self.assertTrue(error['reason'].strip().endswith(expected_error)) [error] = self.flushLoggedErrors(TemporaryFailure) failure = error.value self.assertEqual(failure.failure_code, FailureMessage.FC_TEMPORARY) self.assertEqual(str(failure), expected_error) @inlineCallbacks def test_handle_outbound_to_channel_old(self): msg = yield self.dispatch_outbound_irc( "hello world", to_addr="#vumitest") event = yield self.irc_server.filter_events('PRIVMSG') self.assertEqual(event, ('', 'PRIVMSG', ['#vumitest', 'hello world'])) [smsg] = self.tx_helper.get_dispatched_events() self.assert_ack_for(msg, smsg) @inlineCallbacks def test_handle_outbound_to_channel(self): msg = yield self.dispatch_outbound_irc( "hello world", to_addr=None, group="#vumitest") event = yield self.irc_server.filter_events('PRIVMSG') self.assertEqual(event, ('', 'PRIVMSG', ['#vumitest', 'hello world'])) [smsg] = self.tx_helper.get_dispatched_events() self.assert_ack_for(msg, smsg) @inlineCallbacks def test_handle_outbound_to_channel_directed(self): msg = yield self.dispatch_outbound_irc( "hello world", to_addr="user", group="#vumitest") event = yield self.irc_server.filter_events('PRIVMSG') self.assertEqual(event, ('', 'PRIVMSG', ['#vumitest', 'user: hello world'])) [smsg] = self.tx_helper.get_dispatched_events() self.assert_ack_for(msg, smsg) @inlineCallbacks def test_handle_outbound_to_user(self): msg = yield self.dispatch_outbound_irc( "hello world", to_addr="user", group=None) event = yield self.irc_server.filter_events('PRIVMSG') self.assertEqual(event, ('', 'PRIVMSG', ['user', 'hello world'])) [smsg] = self.tx_helper.get_dispatched_events() self.assert_ack_for(msg, smsg) @inlineCallbacks def test_handle_outbound_action_to_channel(self): msg = yield self.dispatch_outbound_irc( "waves", to_addr=None, group="#vumitest", irc_command="ACTION") event = yield self.irc_server.filter_events('PRIVMSG') self.assertEqual(event, ('', 'PRIVMSG', ['#vumitest', '\x01ACTION waves\x01'])) [smsg] = self.tx_helper.get_dispatched_events() self.assert_ack_for(msg, smsg) @inlineCallbacks def test_handle_outbound_action_to_channel_directed(self): msg = yield self.dispatch_outbound_irc( "waves", to_addr="user", group="#vumitest", irc_command='ACTION') event = yield self.irc_server.filter_events('PRIVMSG') self.assertEqual(event, ('', 'PRIVMSG', ['#vumitest', '\x01ACTION waves\x01'])) [smsg] = self.tx_helper.get_dispatched_events() self.assert_ack_for(msg, smsg) @inlineCallbacks def test_handle_outbound_action_to_user(self): msg = yield self.dispatch_outbound_irc( "waves", to_addr="user", group=None, irc_command='ACTION') event = yield self.irc_server.filter_events('PRIVMSG') self.assertEqual(event, ('', 'PRIVMSG', ['user', '\x01ACTION waves\x01'])) [smsg] = self.tx_helper.get_dispatched_events() self.assert_ack_for(msg, smsg) PK=JG%vumi/transports/irc/tests/__init__.pyPKqG)kpp vumi/transports/telnet/telnet.py# -*- test-case-name: vumi.transports.telnet.tests.test_telnet -*- """Transport that sends and receives to telnet clients.""" from twisted.internet.protocol import ServerFactory from twisted.internet.defer import inlineCallbacks, Deferred, gatherResults 
from twisted.conch.telnet import ( TelnetTransport, TelnetProtocol, StatefulTelnetProtocol) from vumi.config import ( ConfigServerEndpoint, ConfigText, ConfigInt, ServerEndpointFallback) from vumi.transports import Transport from vumi.message import TransportUserMessage class TelnetTransportProtocol(TelnetProtocol): """Extends Twisted's TelnetProtocol for the Telnet transport.""" def __init__(self, vumi_transport): self.vumi_transport = vumi_transport def getAddress(self): return self.vumi_transport._format_addr(self.transport.getPeer()) def connectionMade(self): self.vumi_transport.register_client(self) def connectionLost(self, reason): self.vumi_transport.deregister_client(self) def dataReceived(self, data): data = data.rstrip('\r\n') if data.lower() == '/quit': self.loseConnection() else: self.vumi_transport.handle_input(self, data) class AddressedTelnetTransportProtocol(StatefulTelnetProtocol): state = "ToAddr" def __init__(self, vumi_transport): self.vumi_transport = vumi_transport self.to_addr = None self.from_addr = None def connectionMade(self): self.transport.write('Please provide "to_addr":\n') def telnet_ToAddr(self, line): if not line: return "ToAddr" self.to_addr = line self.transport.write('Please provide "from_addr":\n') return "FromAddr" def telnet_FromAddr(self, line): if not line: return "FromAddr" if self.from_addr is None: self.from_addr = line summary = "[Sending all messages to: %s and from: %s]\n" % ( self.to_addr, self.from_addr) self.transport.write(summary) self.vumi_transport._to_addr = self.to_addr self.vumi_transport.register_client(self) return "SetupDone" def telnet_SetupDone(self, line): self.vumi_transport.handle_input(self, line.rstrip('\r\n')) def getAddress(self): return self.from_addr def connectionLost(self, reason): StatefulTelnetProtocol.connectionLost(self, reason) if self.from_addr is not None: self.vumi_transport.deregister_client(self) fallback_format_str = "tcp:interface={telnet_host}:port={telnet_port}" class TelnetServerConfig(Transport.CONFIG_CLASS): """ Telnet transport configuration. """ twisted_endpoint = ConfigServerEndpoint( "The endpoint the Telnet server will listen on.", fallbacks=[ServerEndpointFallback('telnet_host', 'telnet_port')], required=True, static=True) to_addr = ConfigText( "The to_addr to use for inbound messages. The default is to use" " the host:port of the telnet server.", default=None, static=True) transport_type = ConfigText( "The transport_type to use for inbound messages.", default='telnet', static=True) # TODO: Deprecate these fields when confmodel#5 is done. telnet_host = ConfigText( "*DEPRECATED* 'telnet_host' and 'telnet_port' fields may be used in" "place of the 'twisted_endpoint' field.", static=True) telnet_port = ConfigInt( "*DEPRECATED* 'telnet_host' and 'telnet_port' fields may be used in" " place of the 'twisted_endpoint' field.", static=True) class TelnetServerTransport(Transport): """Telnet based transport. This transport listens on a specified port for telnet clients and routes lines to and from connected clients. 
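
    For example (a sketch; the address depends on the configured endpoint),
    after starting a worker with
    ``{'transport_name': 'telnet_transport', 'twisted_endpoint':
    'tcp:port=9010'}`` a session looks like::

        $ telnet localhost 9010
        hello                   # -> inbound message (SESSION_RESUME)
        reply from the worker   # <- outbound messages are written back
        /quit                   # closes the connection (SESSION_CLOSE)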
""" CONFIG_CLASS = TelnetServerConfig protocol = TelnetTransportProtocol telnet_server = None @inlineCallbacks def setup_transport(self): config = self.get_static_config() self._clients = {} def protocol(): return TelnetTransport(self.protocol, self) factory = ServerFactory() factory.protocol = protocol self.telnet_server = yield config.twisted_endpoint.listen(factory) self._transport_type = config.transport_type self._to_addr = config.to_addr if self._to_addr is None: self._to_addr = self._format_addr(self.telnet_server.getHost()) @inlineCallbacks def teardown_transport(self): if hasattr(self, 'telnet_server'): # We need to wait for all the client connections to be closed (and # their deregistration messages sent) before tearing down the rest # of the transport. wait_for_closed = gatherResults([ client.registration_d for client in self._clients.values()]) if self.telnet_server is not None: self.telnet_server.loseConnection() yield wait_for_closed def _format_addr(self, addr): return "%s:%s" % (addr.host, addr.port) def register_client(self, client): # We add our own Deferred to the client here because we only want to # fire it after we're finished with our own deregistration process. client.registration_d = Deferred() client_addr = client.getAddress() self.log.msg("Registering client connected from %r" % client_addr) self._clients[client_addr] = client self.send_inbound_message(client, None, TransportUserMessage.SESSION_NEW) def deregister_client(self, client): self.log.msg("Deregistering client.") self.send_inbound_message( client, None, TransportUserMessage.SESSION_CLOSE) del self._clients[client.getAddress()] client.registration_d.callback(None) def handle_input(self, client, text): self.send_inbound_message(client, text, TransportUserMessage.SESSION_RESUME) def send_inbound_message(self, client, text, session_event): self.publish_message( from_addr=client.getAddress(), to_addr=self._to_addr, session_event=session_event, content=text, transport_name=self.transport_name, transport_type=self._transport_type, ) def handle_outbound_message(self, message): failed = False text = message['content'] if text is None: text = u'' text = u"\n".join(text.splitlines()) client_addr = message['to_addr'] client = self._clients.get(client_addr) if client is None: # unknown addr, deliver to all clients = self._clients.values() text = u"UNKNOWN ADDR [%s]: %s" % (client_addr, text) failed = True else: clients = [client] text = text.encode('utf-8') for client in clients: client.transport.write("%s\n" % text) if message['session_event'] == TransportUserMessage.SESSION_CLOSE: client.transport.loseConnection() if failed: self.publish_nack(message['message_id'], u"Unknown address.") else: self.publish_ack(message['message_id'], message['message_id']) class AddressedTelnetServerTransport(TelnetServerTransport): protocol = AddressedTelnetTransportProtocol PK=JG'"vumi/transports/telnet/__init__.py"""Telnet server transport.""" from vumi.transports.telnet.telnet import (TelnetServerTransport, AddressedTelnetServerTransport) __all__ = ['TelnetServerTransport', 'AddressedTelnetServerTransport'] PK=JG(vumi/transports/telnet/tests/__init__.pyPK=JG7 `/!/!+vumi/transports/telnet/tests/test_telnet.py# coding: utf-8 """Tests for vumi.transports.telnet.transport.""" from twisted.internet.defer import ( inlineCallbacks, DeferredQueue, returnValue, Deferred) from twisted.protocols.basic import LineReceiver from twisted.internet import reactor, protocol from vumi.message import TransportUserMessage from vumi.tests.helpers 
import VumiTestCase from vumi.transports.telnet import ( TelnetServerTransport, AddressedTelnetServerTransport) from vumi.transports.tests.helpers import TransportHelper NON_ASCII = u"öæł" class ClientProtocol(LineReceiver): def __init__(self): self.queue = DeferredQueue() self.connect_d = Deferred() self.disconnect_d = Deferred() def connectionMade(self): self.connect_d.callback(None) def lineReceived(self, line): self.queue.put(line) def connectionLost(self, reason): self.queue.put("DONE") self.disconnect_d.callback(None) class BaseTelnetServerTransortTestCase(VumiTestCase): transport_class = TelnetServerTransport transport_type = 'telnet' @inlineCallbacks def setUp(self): self.tx_helper = self.add_helper(TransportHelper(self.transport_class)) self.worker = yield self.tx_helper.get_transport({'telnet_port': 0}) self.client = yield self.make_client() self.add_cleanup(self.wait_for_client_deregistration) yield self.wait_for_client_start() @inlineCallbacks def wait_for_client_deregistration(self): if self.client.transport.connected: self.client.transport.loseConnection() yield self.client.disconnect_d # Kick off the delivery of the deregistration message. yield self.tx_helper.kick_delivery() def wait_for_client_start(self): return self.client.connect_d @inlineCallbacks def make_client(self): addr = self.worker.telnet_server.getHost() cc = protocol.ClientCreator(reactor, ClientProtocol) client = yield cc.connectTCP("127.0.0.1", addr.port) returnValue(client) class TestTelnetServerTransport(BaseTelnetServerTransortTestCase): @inlineCallbacks def test_client_register(self): [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(msg['content'], None) self.assertEqual(msg['session_event'], TransportUserMessage.SESSION_NEW) @inlineCallbacks def test_client_deregister(self): self.client.transport.loseConnection() [reg, msg] = yield self.tx_helper.wait_for_dispatched_inbound(2) self.assertEqual(msg['content'], None) self.assertEqual(msg['session_event'], TransportUserMessage.SESSION_CLOSE) @inlineCallbacks def test_handle_input(self): self.client.transport.write("foo\n") [reg, msg] = yield self.tx_helper.wait_for_dispatched_inbound(2) self.assertEqual(msg['content'], "foo") self.assertEqual(msg['session_event'], TransportUserMessage.SESSION_RESUME) @inlineCallbacks def test_handle_non_ascii_input(self): self.client.transport.write(NON_ASCII.encode("utf-8")) [reg, msg] = yield self.tx_helper.wait_for_dispatched_inbound(2) self.assertEqual(msg['content'], NON_ASCII) self.assertEqual(msg['session_event'], TransportUserMessage.SESSION_RESUME) @inlineCallbacks def test_outbound_reply(self): [reg] = yield self.tx_helper.wait_for_dispatched_inbound(1) yield self.tx_helper.make_dispatch_reply(reg, "reply_foo") line = yield self.client.queue.get() self.assertEqual(line, "reply_foo") self.assertTrue(self.client.transport.connected) [event] = self.tx_helper.get_dispatched_events() self.assertEqual(event['event_type'], 'ack') @inlineCallbacks def test_non_ascii_outbound_reply(self): [reg] = yield self.tx_helper.wait_for_dispatched_inbound(1) yield self.tx_helper.make_dispatch_reply(reg, NON_ASCII) line = yield self.client.queue.get() self.assertEqual(line, NON_ASCII.encode('utf-8')) self.assertTrue(self.client.transport.connected) @inlineCallbacks def test_non_ascii_outbound_unknown_address(self): [reg] = yield self.tx_helper.wait_for_dispatched_inbound(1) reg['from_addr'] = 'nowhere' yield self.tx_helper.make_dispatch_reply(reg, NON_ASCII) line = yield self.client.queue.get() 
self.assertEqual( line, (u"UNKNOWN ADDR [nowhere]: %s" % (NON_ASCII,)).encode('utf-8')) self.assertTrue(self.client.transport.connected) [event] = self.tx_helper.get_dispatched_events() self.assertEqual(event['event_type'], 'nack') self.assertEqual(event['nack_reason'], u'Unknown address.') @inlineCallbacks def test_outbound_close_event(self): [reg] = yield self.tx_helper.wait_for_dispatched_inbound(1) yield self.tx_helper.make_dispatch_reply( reg, "reply_done", continue_session=False) line = yield self.client.queue.get() self.assertEqual(line, "reply_done") line = yield self.client.queue.get() self.assertEqual(line, "DONE") self.assertFalse(self.client.transport.connected) @inlineCallbacks def test_outbound_send(self): [reg] = yield self.tx_helper.wait_for_dispatched_inbound(1) yield self.tx_helper.make_dispatch_outbound( "send_foo", to_addr=reg['from_addr']) line = yield self.client.queue.get() self.assertEqual(line, "send_foo") self.assertTrue(self.client.transport.connected) [event] = self.tx_helper.get_dispatched_events() self.assertEqual(event['event_type'], 'ack') @inlineCallbacks def test_to_addr_override(self): old_worker = self.worker self.assertEqual( old_worker._to_addr, old_worker._format_addr(old_worker.telnet_server.getHost())) worker = yield self.tx_helper.get_transport({ 'telnet_port': 0, 'to_addr': 'foo' }) self.assertEqual(worker._to_addr, 'foo') yield worker.stopWorker() @inlineCallbacks def test_transport_type_override(self): self.assertEqual(self.worker._transport_type, 'telnet') # Clean up existing unused client. self.client.transport.loseConnection() [m_new, m_close] = yield self.tx_helper.wait_for_dispatched_inbound(2) self.assertEqual(m_new['transport_type'], 'telnet') self.assertEqual(m_close['transport_type'], 'telnet') self.tx_helper.clear_dispatched_inbound() self.worker = yield self.tx_helper.get_transport({ 'telnet_port': 0, 'transport_type': 'foo', }) self.assertEqual(self.worker._transport_type, 'foo') self.client = yield self.make_client() yield self.wait_for_client_start() self.client.transport.write("foo\n") [m_new, msg] = yield self.tx_helper.wait_for_dispatched_inbound(2) self.assertEqual(m_new['transport_type'], 'foo') self.assertEqual(msg['transport_type'], 'foo') class TestAddressedTelnetServerTransport(BaseTelnetServerTransortTestCase): transport_class = AddressedTelnetServerTransport def wait_for_server(self): """Wait for first message from client to be ready.""" return self.client.queue.get() @inlineCallbacks def test_handle_input(self): to_addr_prompt = yield self.wait_for_server() self.assertEqual('Please provide "to_addr":', to_addr_prompt) self.client.transport.write('to_addr\n') from_addr_prompt = yield self.wait_for_server() self.assertEqual('Please provide "from_addr":', from_addr_prompt) self.client.transport.write('from_addr\n') summary = yield self.wait_for_server() self.assertEqual( summary, "[Sending all messages to: to_addr and from: from_addr]") self.client.transport.write('foo!\n') [reg, msg] = yield self.tx_helper.wait_for_dispatched_inbound(2) self.assertEqual(reg['from_addr'], 'from_addr') self.assertEqual(reg['to_addr'], 'to_addr') self.assertEqual(msg['from_addr'], 'from_addr') self.assertEqual(msg['to_addr'], 'to_addr') self.assertEqual(msg['content'], 'foo!') PK=JG,A%vumi/transports/imimobile/__init__.py""" ImiMobile HTTP USSD API. 
""" from vumi.transports.imimobile.imimobile_ussd import ImiMobileUssdTransport __all__ = ['ImiMobileUssdTransport'] PK=JGmV8!!+vumi/transports/imimobile/imimobile_ussd.py# -*- test-case-name: vumi.transports.imimobile.tests.test_imimobile_ussd -*- import re import json from datetime import datetime, timedelta from twisted.python import log from twisted.web import http from twisted.internet.defer import inlineCallbacks from vumi.components.session import SessionManager from vumi.message import TransportUserMessage from vumi.transports.httprpc import HttpRpcTransport class ImiMobileUssdTransport(HttpRpcTransport): """ HTTP transport for USSD with IMImobile in India. Configuration parameters: :param str transport_name: The name this transport instance will use to create its queues :param str web_path: The HTTP path to listen on. :param int web_port: The HTTP port to listen on. :param dict suffix_to_addrs: Mappings between url suffixes and to addresses. :param str user_terminated_session_message: A regex used to identify user terminated session messages. Default is '^Map Dialog User Abort User Reason'. :param str user_terminated_session_response: Response given back to the user if the user terminated the session. Default is 'Session Ended'. :param dict redis_manager: The configuration parameters for connecting to Redis. :param int ussd_session_timeout: Number of seconds before USSD session information stored in Redis expires. Default is 600s. """ transport_type = 'ussd' ENCODING = 'utf-8' EXPECTED_FIELDS = set(['msisdn', 'msg', 'code', 'tid', 'dcs']) # errors RESPONSE_FAILURE_ERROR = "Response to http request failed." INSUFFICIENT_MSG_FIELDS_ERROR = "Insufficiant message fields provided." def validate_config(self): super(ImiMobileUssdTransport, self).validate_config() # Mappings between url suffixes and the tags used as the to_addr for # inbound messages (e.g. shortcodes or longcodes). This is necessary # since the requests from ImiMobile do not provided us with this. self.suffix_to_addrs = self.config['suffix_to_addrs'] # IMImobile do not provide a parameter or header to signal termination # of the session by the user, other than sending "Map Dialog User Abort # User Reason: User specific reason" as the request's message content. self.user_terminated_session_re = re.compile( self.config.get('user_terminated_session_message', '^Map Dialog User Abort User Reason')) self.user_terminated_session_response = self.config.get( 'user_terminated_session_response', 'Session Ended') @inlineCallbacks def setup_transport(self): super(ImiMobileUssdTransport, self).setup_transport() # configure session manager r_config = self.config.get('redis_manager', {}) r_prefix = "vumi.transports.imimobile_ussd:%s" % self.transport_name session_timeout = int(self.config.get("ussd_session_timeout", 600)) self.session_manager = yield SessionManager.from_redis_config( r_config, r_prefix, max_session_length=session_timeout) @inlineCallbacks def teardown_transport(self): yield super(ImiMobileUssdTransport, self).teardown_transport() yield self.session_manager.stop() def get_to_addr(self, request): """ Extracts the request url path's suffix and uses it to obtain the tag associated with the suffix. Returns a tuple consisting of the tag and a dict of errors encountered. 
""" errors = {} [suffix] = request.postpath tag = self.suffix_to_addrs.get(suffix, None) if tag is None: errors['unknown_suffix'] = suffix return tag, errors @classmethod def ist_to_utc(cls, timestamp): """ Accepts a timestamp in the format `[M]M/[D]D/YYYY HH:MM:SS (am|pm)` and in India Standard Time, and returns a datetime object normalized to UTC time. """ return (datetime.strptime(timestamp, '%m/%d/%Y %I:%M:%S %p') - timedelta(hours=5, minutes=30)) def user_has_terminated_session(self, content): return self.user_terminated_session_re.match(content) is not None @inlineCallbacks def handle_raw_inbound_message(self, message_id, request): errors = {} to_addr, to_addr_errors = self.get_to_addr(request) errors.update(to_addr_errors) values, field_value_errors = self.get_field_values(request, self.EXPECTED_FIELDS) errors.update(field_value_errors) if errors: log.msg('Unhappy incoming message: %s' % (errors,)) yield self.finish_request( message_id, json.dumps(errors), code=http.BAD_REQUEST) return from_addr = values['msisdn'] log.msg('ImiMobileTransport receiving inbound message from %s to %s.' % (from_addr, to_addr)) content = values['msg'] if self.user_has_terminated_session(content): yield self.session_manager.clear_session(from_addr) session_event = TransportUserMessage.SESSION_CLOSE # IMImobile use 0 for termination of a session self.finish_request( message_id, self.user_terminated_session_response, headers={'X-USSD-SESSION': ['0']}) else: # We use the msisdn (from_addr) to make a guess about the # whether the session is new or not. session = yield self.session_manager.load_session(from_addr) if session: session_event = TransportUserMessage.SESSION_RESUME yield self.session_manager.save_session(from_addr, session) else: session_event = TransportUserMessage.SESSION_NEW yield self.session_manager.create_session( from_addr, from_addr=from_addr, to_addr=to_addr) yield self.publish_message( message_id=message_id, content=content, to_addr=to_addr, from_addr=from_addr, provider='imimobile', session_event=session_event, transport_type=self.transport_type, transport_metadata={ 'imimobile_ussd': { 'tid': values['tid'], 'code': values['code'], 'dcs': values['dcs'], } }) @inlineCallbacks def handle_outbound_message(self, message): error = None message_id = message['message_id'] if message.payload.get('in_reply_to') and 'content' in message.payload: # IMImobile use 1 for resume and 0 for termination of a session session_header_value = '1' if message['session_event'] == TransportUserMessage.SESSION_CLOSE: yield self.session_manager.clear_session(message['to_addr']) session_header_value = '0' response_id = self.finish_request( message['in_reply_to'], message['content'].encode(self.ENCODING), headers={'X-USSD-SESSION': [session_header_value]}) if response_id is None: error = self.RESPONSE_FAILURE_ERROR else: error = self.INSUFFICIENT_MSG_FIELDS_ERROR if error is not None: yield self.publish_nack(message_id, error) return yield self.publish_ack(user_message_id=message_id, sent_message_id=message_id) PK=JGEF&F&6vumi/transports/imimobile/tests/test_imimobile_ussd.pyimport json from datetime import datetime from twisted.internet.defer import inlineCallbacks from vumi.message import TransportUserMessage from vumi.tests.helpers import VumiTestCase from vumi.transports.imimobile import ImiMobileUssdTransport from vumi.transports.httprpc.tests.helpers import HttpRpcTransportHelper class TestImiMobileUssdTransport(VumiTestCase): _from_addr = '9221234567' _to_addr = '56263' _request_defaults = { 'msisdn': 
_from_addr, 'msg': 'Spam Spam Spam Spam Spammity Spam', 'tid': '1', 'dcs': 'no-idea-what-this-is', 'code': 'VUMI', } @inlineCallbacks def setUp(self): self.config = { 'web_port': 0, 'web_path': '/api/v1/imimobile/ussd/', 'user_terminated_session_message': "^Farewell", 'user_terminated_session_response': "You have ended the session", 'suffix_to_addrs': { 'some-suffix': self._to_addr, 'some-other-suffix': '56264', } } self.tx_helper = self.add_helper( HttpRpcTransportHelper(ImiMobileUssdTransport, request_defaults=self._request_defaults)) self.transport = yield self.tx_helper.get_transport(self.config) self.session_manager = self.transport.session_manager self.transport_url = self.transport.get_transport_url( self.config['web_path']) yield self.session_manager.redis._purge_all() # just in case @inlineCallbacks def mk_session(self, from_addr=_from_addr, to_addr=_to_addr): # first pre-populate the redis datastore to simulate session resume # note: imimobile do not provide a session id, so instead we use the # msisdn as the session id yield self.session_manager.create_session( from_addr, to_addr=to_addr, from_addr=from_addr) def assert_message(self, msg, expected_field_values): for field, expected_value in expected_field_values.iteritems(): self.assertEqual(msg[field], expected_value) def assert_inbound_message(self, msg, **field_values): expected_field_values = { 'content': self._request_defaults['msg'], 'to_addr': '56263', 'from_addr': self._request_defaults['msisdn'], 'session_event': TransportUserMessage.SESSION_NEW, 'transport_metadata': { 'imimobile_ussd': { 'tid': self._request_defaults['tid'], 'dcs': self._request_defaults['dcs'], 'code': self._request_defaults['code'], }, } } expected_field_values.update(field_values) for field, expected_value in expected_field_values.iteritems(): self.assertEqual(msg[field], expected_value) def assert_ack(self, ack, reply): self.assertEqual(ack.payload['event_type'], 'ack') self.assertEqual(ack.payload['user_message_id'], reply['message_id']) self.assertEqual(ack.payload['sent_message_id'], reply['message_id']) def assert_nack(self, nack, reply, reason): self.assertEqual(nack.payload['event_type'], 'nack') self.assertEqual(nack.payload['user_message_id'], reply['message_id']) self.assertEqual(nack.payload['nack_reason'], reason) @inlineCallbacks def test_inbound_begin(self): # Second connect is the actual start of the session user_content = "Who are you?" d = self.tx_helper.mk_request('some-suffix', msg=user_content) [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assert_inbound_message(msg, session_event=TransportUserMessage.SESSION_NEW, content=user_content) reply_content = "We are the Knights Who Say ... Ni!" reply = msg.reply(reply_content) self.tx_helper.dispatch_outbound(reply) response = yield d self.assertEqual(response.delivered_body, reply_content) self.assertEqual( response.headers.getRawHeaders('X-USSD-SESSION'), ['1']) [ack] = yield self.tx_helper.wait_for_dispatched_events(1) self.assert_ack(ack, reply) @inlineCallbacks def test_inbound_resume_and_reply_with_end(self): from_addr = '9221234567' yield self.mk_session(from_addr) user_content = "Well, what is it you want?" d = self.tx_helper.mk_request('some-suffix', msg=user_content) [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assert_inbound_message(msg, session_event=TransportUserMessage.SESSION_RESUME, content=user_content) reply_content = "We want ... a shrubbery!" 
reply = msg.reply(reply_content, continue_session=False) self.tx_helper.dispatch_outbound(reply) response = yield d self.assertEqual(response.delivered_body, reply_content) self.assertEqual( response.headers.getRawHeaders('X-USSD-SESSION'), ['0']) # Assert that the session was removed from the session manager session = yield self.session_manager.load_session(from_addr) self.assertEqual(session, {}) [ack] = yield self.tx_helper.wait_for_dispatched_events(1) self.assert_ack(ack, reply) @inlineCallbacks def test_inbound_resume_and_reply_with_resume(self): yield self.mk_session() user_content = "Well, what is it you want?" d = self.tx_helper.mk_request('some-suffix', msg=user_content) [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assert_inbound_message(msg, session_event=TransportUserMessage.SESSION_RESUME, content=user_content) reply_content = "We want ... a shrubbery!" reply = msg.reply(reply_content, continue_session=True) self.tx_helper.dispatch_outbound(reply) response = yield d self.assertEqual(response.delivered_body, reply_content) self.assertEqual( response.headers.getRawHeaders('X-USSD-SESSION'), ['1']) [ack] = yield self.tx_helper.wait_for_dispatched_events(1) self.assert_ack(ack, reply) @inlineCallbacks def test_inbound_close_and_reply(self): from_addr = '9221234567' yield self.mk_session(from_addr=from_addr) user_content = "Farewell, sweet Concorde!" d = self.tx_helper.mk_request('some-suffix', msg=user_content) [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assert_inbound_message(msg, session_event=TransportUserMessage.SESSION_CLOSE, content=user_content) # Assert that the session was removed from the session manager session = yield self.session_manager.load_session(from_addr) self.assertEqual(session, {}) response = yield d self.assertEqual(response.delivered_body, "You have ended the session") self.assertEqual( response.headers.getRawHeaders('X-USSD-SESSION'), ['0']) @inlineCallbacks def test_request_with_unknown_suffix(self): response = yield self.tx_helper.mk_request('unk-suffix') self.assertEqual( response.delivered_body, json.dumps({'unknown_suffix': 'unk-suffix'})) self.assertEqual(response.code, 400) @inlineCallbacks def test_request_with_missing_parameters(self): response = yield self.tx_helper.mk_request_raw( 'some-suffix', params={"msg": '', "code": '', "dcs": ''}) self.assertEqual( response.delivered_body, json.dumps({'missing_parameter': ['msisdn', 'tid']})) self.assertEqual(response.code, 400) @inlineCallbacks def test_request_with_unexpected_parameters(self): response = yield self.tx_helper.mk_request( 'some-suffix', unexpected_p1='', unexpected_p2='') self.assertEqual(response.code, 400) body = json.loads(response.delivered_body) self.assertEqual(set(['unexpected_parameter']), set(body.keys())) self.assertEqual( sorted(body['unexpected_parameter']), ['unexpected_p1', 'unexpected_p2']) @inlineCallbacks def test_nack_insufficient_message_fields(self): reply = self.tx_helper.make_outbound( None, message_id='23', in_reply_to=None) self.tx_helper.dispatch_outbound(reply) [nack] = yield self.tx_helper.wait_for_dispatched_events(1) self.assert_nack( nack, reply, self.transport.INSUFFICIENT_MSG_FIELDS_ERROR) @inlineCallbacks def test_nack_http_http_response_failure(self): self.patch(self.transport, 'finish_request', lambda *a, **kw: None) reply = self.tx_helper.make_outbound( 'There are some who call me ... 
Tim!', message_id='23', in_reply_to='some-number') self.tx_helper.dispatch_outbound(reply) [nack] = yield self.tx_helper.wait_for_dispatched_events(1) self.assert_nack( nack, reply, self.transport.RESPONSE_FAILURE_ERROR) def test_ist_to_utc(self): self.assertEqual( ImiMobileUssdTransport.ist_to_utc("1/26/2013 03:30:00 pm"), datetime(2013, 1, 26, 10, 0, 0)) self.assertEqual( ImiMobileUssdTransport.ist_to_utc("01/29/2013 04:53:59 am"), datetime(2013, 1, 28, 23, 23, 59)) self.assertEqual( ImiMobileUssdTransport.ist_to_utc("01/31/2013 07:20:00 pm"), datetime(2013, 1, 31, 13, 50, 0)) self.assertEqual( ImiMobileUssdTransport.ist_to_utc("3/8/2013 8:5:5 am"), datetime(2013, 3, 8, 2, 35, 5)) PK=JG+vumi/transports/imimobile/tests/__init__.pyPK=JGw͆88 vumi/transports/truteq/truteq.py# -*- test-case-name: vumi.transports.truteq.tests.test_truteq -*- # -*- coding: utf-8 -*- """TruTeq USSD transport.""" from twisted.internet.defer import inlineCallbacks, maybeDeferred from twisted.internet.protocol import Factory from txssmi.protocol import SSMIProtocol from txssmi import constants from vumi import log from vumi.components.session import SessionManager from vumi.config import ( ConfigText, ConfigInt, ConfigClientEndpoint, ConfigBool, ConfigDict, ClientEndpointFallback) from vumi.message import TransportUserMessage from vumi.reconnecting_client import ReconnectingClientService from vumi.transports.base import Transport from vumi.utils import normalize_msisdn class TruteqTransportConfig(Transport.CONFIG_CLASS): username = ConfigText( 'Username of the TruTeq account to connect to.', static=True) password = ConfigText( 'Password for the TruTeq account.', static=True) twisted_endpoint = ConfigClientEndpoint( 'The endpoint to connect to.', default='tcp:host=sms.truteq.com:port=50008', static=True, fallbacks=[ClientEndpointFallback()]) link_check_period = ConfigInt( 'Number of seconds between link checks sent to the server.', default=60, static=True) ussd_session_lifetime = ConfigInt( 'Maximum number of seconds to retain USSD session information.', default=300, static=True) debug = ConfigBool( 'Print verbose log output.', default=False, static=True) redis_manager = ConfigDict( 'How to connect to Redis.', default={}, static=True) # TODO: Deprecate these fields when confmodel#5 is done. 
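    # A minimal YAML config sketch for this transport (illustrative only;
    # the values below are assumptions, while the field names come from
    # TruteqTransportConfig above):
    #
    #   transport_name: truteq_ussd
    #   username: my-truteq-username
    #   password: my-truteq-password
    #   twisted_endpoint: tcp:host=sms.truteq.com:port=50008
    #   redis_manager: {}
    #
    # The deprecated 'host' and 'port' fields below may be given in place
    # of 'twisted_endpoint':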
host = ConfigText( "*DEPRECATED* 'host' and 'port' fields may be used in place of the" " 'twisted_endpoint' field.", static=True) port = ConfigInt( "*DEPRECATED* 'host' and 'port' fields may be used in place of the" " 'twisted_endpoint' field.", static=True) class TruteqTransportProtocol(SSMIProtocol): def connectionMade(self): config = self.factory.vumi_transport.get_static_config() self.factory.protocol_instance = self self.noisy = config.debug d = self.authenticate(config.username, config.password) d.addCallback( lambda success: ( self.link_check.start(config.link_check_period) if success else self.loseConnection())) return d def connectionLost(self, reason): if self.link_check.running: self.link_check.stop() SSMIProtocol.connectionLost(self, reason) def handle_MO(self, mo): return self.factory.vumi_transport.handle_unhandled_message(mo) def handle_BINARY_MO(self, mo): return self.factory.vumi_transport.handle_unhandled_message(mo) def handle_PREMIUM_MO(self, mo): return self.factory.vumi_transport.handle_unhandled_message(mo) def handle_PREMIUM_BINARY_MO(self, mo): return self.factory.vumi_transport.handle_unhandled_message(mo) def handle_USSD_MESSAGE(self, um): return self.factory.vumi_transport.handle_raw_inbound_message(um) def handle_EXTENDED_USSD_MESSAGE(self, um): return self.factory.vumi_transport.handle_raw_inbound_message(um) def handle_LOGOUT(self, msg): return self.factory.vumi_transport.handle_remote_logout(msg) class TruteqTransport(Transport): """ A transport for TruTeq. Currently only USSD messages are supported. """ CONFIG_CLASS = TruteqTransportConfig service_class = ReconnectingClientService protocol_class = TruteqTransportProtocol encoding = 'iso-8859-1' SSMI_TO_VUMI_EVENT = { constants.USSD_NEW: TransportUserMessage.SESSION_NEW, constants.USSD_RESPONSE: TransportUserMessage.SESSION_RESUME, constants.USSD_END: TransportUserMessage.SESSION_CLOSE, constants.USSD_TIMEOUT: TransportUserMessage.SESSION_CLOSE, } VUMI_TO_SSMI_EVENT = { TransportUserMessage.SESSION_NONE: constants.USSD_RESPONSE, TransportUserMessage.SESSION_NEW: constants.USSD_NEW, TransportUserMessage.SESSION_RESUME: constants.USSD_RESPONSE, TransportUserMessage.SESSION_CLOSE: constants.USSD_END, } @inlineCallbacks def setup_transport(self): config = self.get_static_config() self.client_factory = Factory.forProtocol(self.protocol_class) self.client_factory.vumi_transport = self prefix = "%s:ussd_codes" % (config.transport_name,) self.session_manager = yield SessionManager.from_redis_config( config.redis_manager, prefix, config.ussd_session_lifetime) self.client_service = self.get_service( config.twisted_endpoint, self.client_factory) def get_service(self, endpoint, factory): client_service = self.service_class(endpoint, factory) client_service.startService() return client_service def teardown_transport(self): d = maybeDeferred(self.client_service.stopService) d.addCallback(lambda _: self.session_manager.stop()) return d @inlineCallbacks def handle_raw_inbound_message(self, ussd_message): if ussd_message.command_name == 'EXTENDED_USSD_MESSAGE': genfields = { 'IMSI': '', 'Subscriber Type': '', 'OperatorID': '', 'SessionID': '', 'ValiPort': '', } genfield_values = ussd_message.genfields.split(':') genfields.update( dict(zip(genfields.keys(), genfield_values))) else: genfields = {} session_event = self.SSMI_TO_VUMI_EVENT[ussd_message.type] msisdn = normalize_msisdn(ussd_message.msisdn) message = ussd_message.message.decode(self.encoding) if session_event == TransportUserMessage.SESSION_NEW: # If it's a new 
session then store the message as the USSD code if not message.endswith('#'): message = '%s#' % (message,) session = yield self.session_manager.create_session( msisdn, ussd_code=message) text = None else: session = yield self.session_manager.load_session(msisdn) text = message if session_event == TransportUserMessage.SESSION_CLOSE: yield self.session_manager.clear_session(msisdn) yield self.publish_message( from_addr=msisdn, to_addr=session['ussd_code'], session_event=session_event, content=text, transport_type='ussd', transport_metadata={}, helper_metadata={ 'truteq': { 'genfields': genfields, } }) def handle_outbound_message(self, message): protocol = self.client_factory.protocol_instance text = message.get('content') or '' # Truteq uses \r as a message delimiter in the protocol. # Make sure we're only sending \n for new lines. text = '\n'.join(text.splitlines()).encode(self.encoding) ssmi_session_type = self.VUMI_TO_SSMI_EVENT[message['session_event']] # We need to send unicode data to ssmi_client, but bytes for msisdn. msisdn = message['to_addr'].strip('+').encode(self.encoding) return protocol.send_ussd_message(msisdn, text, ssmi_session_type) def handle_remote_logout(self, msg): log.warning('Received remote logout command: %r' % ( msg,)) def handle_unhandled_message(self, mo): log.warning('Received unsupported message, dropping: %r.' % ( mo,)) PK=JG'rr"vumi/transports/truteq/__init__.py"""TruTeq transport.""" from vumi.transports.truteq.truteq import TruteqTransport __all__ = ['TruteqTransport'] PKqGw~ + ++vumi/transports/truteq/tests/test_truteq.py# -*- coding: utf-8 -*- """Test for vumi.transport.truteq.truteq.""" from twisted.internet.defer import inlineCallbacks, returnValue, DeferredQueue from twisted.internet.protocol import Protocol from txssmi import constants as c from txssmi.builder import SSMIRequest from txssmi.commands import ( Ack, USSDMessage, ExtendedUSSDMessage, SendUSSDMessage, MoMessage, ServerLogout) from vumi.message import TransportUserMessage from vumi.tests.fake_connection import FakeServer, wait0 from vumi.tests.helpers import VumiTestCase from vumi.tests.utils import LogCatcher from vumi.transports.tests.helpers import TransportHelper from vumi.transports.truteq import TruteqTransport from vumi.transports.truteq.truteq import TruteqTransportProtocol # To reduce verbosity. 
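# (The SESSION_* aliases just below shorten the TransportUserMessage
# session-event constants.)
#
# Aside: handle_outbound_message in the transport above normalises line
# endings before encoding, because TruTeq uses '\r' as its protocol
# delimiter. A rough sketch of that wrangling ('text' here is a stand-in
# value, not part of this module):
#
#     text = u'hello\r\nwindows\r\nworld'
#     text = '\n'.join(text.splitlines()).encode('iso-8859-1')
#     # text is now 'hello\nwindows\nworld', with no carriage returns left
#     # to collide with the protocol delimiter.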
SESSION_NEW = TransportUserMessage.SESSION_NEW SESSION_RESUME = TransportUserMessage.SESSION_RESUME SESSION_CLOSE = TransportUserMessage.SESSION_CLOSE SESSION_NONE = TransportUserMessage.SESSION_NONE class SSMIServerProtocol(Protocol): delimiter = TruteqTransportProtocol.delimiter def __init__(self): self.receive_queue = DeferredQueue() self._buf = b"" def dataReceived(self, data): self._buf += data self.parse_commands() def parse_commands(self): while self.delimiter in self._buf: line, _, self._buf = self._buf.partition(self.delimiter) if line: self.receive_queue.put(SSMIRequest.parse(line)) def send(self, command): self.transport.write(str(command)) self.transport.write(self.delimiter) return wait0() def receive(self): return self.receive_queue.get() def disconnect(self): self.transport.loseConnection() class TestTruteqTransport(VumiTestCase): @inlineCallbacks def setUp(self): self.tx_helper = self.add_helper(TransportHelper(TruteqTransport)) self.fake_server = FakeServer.for_protocol(SSMIServerProtocol) self.config = { 'username': 'username', 'password': 'password', 'twisted_endpoint': self.fake_server.endpoint, } self.transport = yield self.tx_helper.get_transport(self.config) self.conn = yield self.fake_server.await_connection() yield self.conn.await_connected() self.server = self.conn.server_protocol yield self.process_login_commands(self.server, 'username', 'password') @inlineCallbacks def process_login_commands(self, server, username, password): cmd = yield server.receive() self.assertEqual(cmd.command_name, 'LOGIN') self.assertEqual(cmd.username, username) self.assertEqual(cmd.password, password) server.send(Ack(ack_type='1')) link_check = yield server.receive() self.assertEqual(link_check.command_name, 'LINK_CHECK') returnValue(True) def incoming_ussd(self, msisdn="12345678", ussd_type=c.USSD_RESPONSE, phase="ignored", message="Hello"): self.server.send(USSDMessage( msisdn=msisdn, type=ussd_type, phase=c.USSD_PHASE_UNKNOWN, message=message)) @inlineCallbacks def start_ussd(self, message="*678#", **kw): kw.setdefault("msisdn", "12345678") kw.setdefault("phase", c.USSD_PHASE_UNKNOWN) yield self.transport.handle_raw_inbound_message( USSDMessage(type=c.USSD_NEW, message=message, **kw)) self.tx_helper.clear_dispatched_inbound() @inlineCallbacks def check_msg(self, from_addr="+12345678", to_addr="*678#", content=None, session_event=None, helper_metadata=None): default_hmd = {'truteq': {'genfields': {}}} [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(msg['transport_name'], self.tx_helper.transport_name) self.assertEqual(msg['transport_type'], 'ussd') self.assertEqual(msg['transport_metadata'], {}) self.assertEqual( msg['helper_metadata'], helper_metadata or default_hmd) self.assertEqual(msg['from_addr'], from_addr) self.assertEqual(msg['to_addr'], to_addr) self.assertEqual(msg['content'], content) self.assertEqual(msg['session_event'], session_event) self.tx_helper.clear_dispatched_inbound() @inlineCallbacks def test_handle_inbound_ussd_new(self): yield self.server.send(USSDMessage( msisdn='27000000000', type=c.USSD_NEW, message='*678#', phase=c.USSD_PHASE_1)) [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(msg['to_addr'], '*678#') self.assertEqual(msg['session_event'], SESSION_NEW) self.assertEqual(msg['transport_type'], 'ussd') @inlineCallbacks def test_handle_inbound_extended_ussd_new(self): yield self.server.send(ExtendedUSSDMessage( msisdn='27000000000', type=c.USSD_NEW, message='*678#', genfields='::3', 
phase=c.USSD_PHASE_1)) [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(msg['from_addr'], '+27000000000') self.assertEqual(msg['to_addr'], '*678#') self.assertEqual(msg['helper_metadata'], { 'truteq': { 'genfields': { 'IMSI': '', 'OperatorID': '3', 'SessionID': '', 'Subscriber Type': '', 'ValiPort': '', } } }) @inlineCallbacks def test_handle_remote_logout(self): cmd = ServerLogout(ip='127.0.0.1') with LogCatcher() as logger: yield self.server.send(cmd) [warning] = logger.messages() self.assertEqual( warning, "Received remote logout command: %r" % (cmd,)) @inlineCallbacks def test_handle_inbound_ussd_resume(self): yield self.start_ussd() self.incoming_ussd(ussd_type=c.USSD_RESPONSE, message="Hello") yield self.check_msg(content="Hello", session_event=SESSION_RESUME) @inlineCallbacks def test_handle_inbound_ussd_close(self): yield self.start_ussd() self.incoming_ussd(ussd_type=c.USSD_END, message="Done") yield self.check_msg(content="Done", session_event=SESSION_CLOSE) @inlineCallbacks def test_handle_inbound_ussd_timeout(self): yield self.start_ussd() self.incoming_ussd(ussd_type=c.USSD_TIMEOUT, message="Timeout") yield self.check_msg(content="Timeout", session_event=SESSION_CLOSE) @inlineCallbacks def test_handle_inbound_ussd_non_ascii(self): yield self.start_ussd() self.incoming_ussd( ussd_type=c.USSD_TIMEOUT, message=u"föóbær".encode("iso-8859-1")) yield self.check_msg(content=u"föóbær", session_event=SESSION_CLOSE) @inlineCallbacks def test_handle_inbound_ussd_with_comma_in_content(self): yield self.start_ussd() self.incoming_ussd(ussd_type=c.USSD_TIMEOUT, message=u"foo, bar") yield self.check_msg(content=u"foo, bar", session_event=SESSION_CLOSE) @inlineCallbacks def _test_outbound_ussd(self, vumi_session_type, ssmi_session_type, content="Test", encoding="utf-8"): yield self.tx_helper.make_dispatch_outbound( content, to_addr=u"+1234", session_event=vumi_session_type) ussd_call = yield self.server.receive() data = content.encode(encoding) if content else "" self.assertEqual(ussd_call.message, data) self.assertTrue(isinstance(ussd_call.message, str)) self.assertEqual(ussd_call.msisdn, '1234') self.assertEqual(ussd_call.type, ssmi_session_type) def test_handle_outbound_ussd_no_session(self): return self._test_outbound_ussd(SESSION_NONE, c.USSD_RESPONSE) def test_handle_outbound_ussd_null_content(self): return self._test_outbound_ussd(SESSION_NONE, c.USSD_RESPONSE, content=None) def test_handle_outbound_ussd_resume(self): return self._test_outbound_ussd(SESSION_RESUME, c.USSD_RESPONSE) def test_handle_outbound_ussd_close(self): return self._test_outbound_ussd(SESSION_CLOSE, c.USSD_END) def test_handle_outbound_ussd_non_ascii(self): return self._test_outbound_ussd( SESSION_NONE, c.USSD_RESPONSE, content=u"föóbær", encoding='iso-8859-1') @inlineCallbacks def _test_content_wrangling(self, submitted, expected): yield self.tx_helper.make_dispatch_outbound( submitted, to_addr=u"+1234", session_event=SESSION_NONE) # Grab what was sent to Truteq ussd_call = yield self.server.receive() expected_msg = SendUSSDMessage(msisdn='1234', message=expected, type=c.USSD_RESPONSE) self.assertEqual(ussd_call, expected_msg) def test_handle_outbound_ussd_with_comma_in_content(self): return self._test_content_wrangling( 'hello world, universe', 'hello world, universe') def test_handle_outbound_ussd_with_crln_in_content(self): return self._test_content_wrangling( 'hello\r\nwindows\r\nworld', 'hello\nwindows\nworld') def test_handle_outbound_ussd_with_cr_in_content(self): return 
self._test_content_wrangling( 'hello\rold mac os\rworld', 'hello\nold mac os\nworld') @inlineCallbacks def test_ussd_addr_retains_asterisks_and_hashes(self): self.incoming_ussd(ussd_type=c.USSD_NEW, message="*6*7*8#") yield self.check_msg(to_addr="*6*7*8#", session_event=SESSION_NEW) @inlineCallbacks def test_ussd_addr_appends_hashes_if_missing(self): self.incoming_ussd(ussd_type=c.USSD_NEW, message="*6*7*8") yield self.check_msg(to_addr="*6*7*8#", session_event=SESSION_NEW) @inlineCallbacks def test_handle_inbound_sms(self): cmd = MoMessage(msisdn='foo', message='bar', sequence='1') with LogCatcher() as logger: yield self.server.send(cmd) [warning] = logger.messages() self.assertEqual( warning[:59], "Received unsupported message, dropping: ') ET.SubElement(page, "session_id").text = self.session_id if self.title is not None: ET.SubElement(page, "title").text = self.title for text in self.text: lines = text.split('\n') div = ET.SubElement(page, "div") div.text = lines.pop(0) for line in lines: ET.SubElement(div, "br").tail = line if self.nav: nav = ET.SubElement(page, "navigation") for link in self.nav: ET.SubElement( nav, "link", pageId=link['pageId'], accesskey=link['accesskey']).text = link['text'] # We can't have "\n" in the output at all, it seems. return ET.tostring(page, encoding="UTF-8").replace("\n", "") def __str__(self): return self.to_xml() PK=JGkX_&vumi/transports/mtech_ussd/__init__.py"""Mtech USSD transport.""" from vumi.transports.mtech_ussd.mtech_ussd import MtechUssdTransport __all__ = ['MtechUssdTransport'] PK=JG,vumi/transports/mtech_ussd/tests/__init__.pyPK=JG&W..3vumi/transports/mtech_ussd/tests/test_mtech_ussd.pyfrom twisted.internet.defer import inlineCallbacks, returnValue from vumi.utils import http_request_full from vumi.message import TransportUserMessage from vumi.transports.mtech_ussd import MtechUssdTransport from vumi.transports.mtech_ussd.mtech_ussd import MtechUssdResponse from vumi.transports.tests.helpers import TransportHelper from vumi.tests.helpers import VumiTestCase class TestMtechUssdTransport(VumiTestCase): @inlineCallbacks def setUp(self): self.config = { 'transport_type': 'ussd', 'ussd_string_prefix': '*120*666#', 'web_path': "/foo", 'web_host': "127.0.0.1", 'web_port': 0, 'username': 'testuser', 'password': 'testpass', } self.tx_helper = self.add_helper(TransportHelper(MtechUssdTransport)) self.transport = yield self.tx_helper.get_transport(self.config) self.transport_url = self.transport.get_transport_url().rstrip('/') self.url = "%s%s" % (self.transport_url, self.config['web_path']) yield self.transport.session_manager.redis._purge_all() # just in case def make_ussd_request_full(self, session_id, **kwargs): lines = [ '', '', ' %s' % (session_id,), ] for k, v in kwargs.items(): lines.append(' <%s>%s' % (k, v, k)) lines.append('') data = '\n'.join(lines) return http_request_full(self.url, data, method='POST') def make_ussd_request(self, session_id, **kwargs): return self.make_ussd_request_full(session_id, **kwargs).addCallback( lambda r: r.delivered_body) @inlineCallbacks def reply_to_message(self, content, **kw): [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) yield self.tx_helper.make_dispatch_reply(msg, content, **kw) returnValue(msg) @inlineCallbacks def test_empty_request(self): response = yield http_request_full(self.url, "", method='POST') self.assertEqual(response.code, 400) @inlineCallbacks def test_bad_request(self): response = yield http_request_full(self.url, "blah", method='POST') self.assertEqual(response.code, 400) 
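    # The expected responses asserted on in the tests below are produced by
    # MtechUssdResponse.to_xml above. From its ElementTree calls, the
    # serialised document has this shape (the root <page> element's
    # attributes are not visible above, so they are omitted here):
    #
    #   <page>
    #     <session_id>...</session_id>
    #     <title>...</title>          (only when a title is set)
    #     <div>first line<br />subsequent lines</div>
    #     <navigation>
    #       <link pageId="..." accesskey="...">link text</link>
    #     </navigation>
    #   </page>
    #
    # with every newline stripped from the final output.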
@inlineCallbacks def test_inbound_new_continue(self): sid = 'a41739890287485d968ea66e8b44bfd3' response_d = self.make_ussd_request( sid, mobile_number='2348085832481', page_id='0', data='testmenu', gate='gateid') msg = yield self.reply_to_message("OK\n1 < 2") self.assertEqual(msg['transport_name'], self.tx_helper.transport_name) self.assertEqual(msg['transport_type'], "ussd") self.assertEqual(msg['transport_metadata'], {"session_id": sid}) self.assertEqual(msg['session_event'], TransportUserMessage.SESSION_NEW) self.assertEqual(msg['from_addr'], '2348085832481') # self.assertEqual(msg['to_addr'], '*120*666#') self.assertEqual(msg['content'], 'testmenu') response = yield response_d correct_response = ''.join([ "", '', 'a41739890287485d968ea66e8b44bfd3', '
OK
1 < 2
', '', '', '', '
', ]) self.assertEqual(response, correct_response) @inlineCallbacks def test_inbound_resume_continue(self): sid = 'a41739890287485d968ea66e8b44bfd3' yield self.transport.save_session(sid, '2348085832481', '*120*666#') response_d = self.make_ussd_request(sid, page_id="indexX", data="foo") msg = yield self.reply_to_message("OK") self.assertEqual(msg['transport_name'], self.tx_helper.transport_name) self.assertEqual(msg['transport_type'], "ussd") self.assertEqual(msg['transport_metadata'], {"session_id": sid}) self.assertEqual(msg['session_event'], TransportUserMessage.SESSION_RESUME) self.assertEqual(msg['from_addr'], '2348085832481') self.assertEqual(msg['to_addr'], '*120*666#') self.assertEqual(msg['content'], 'foo') response = yield response_d correct_response = ''.join([ "", '', 'a41739890287485d968ea66e8b44bfd3', '
OK
', '', '', '', '
', ]) self.assertEqual(response, correct_response) @inlineCallbacks def test_nack(self): msg = yield self.tx_helper.make_dispatch_outbound("outbound") [nack] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(nack['user_message_id'], msg['message_id']) self.assertEqual(nack['sent_message_id'], msg['message_id']) self.assertEqual(nack['nack_reason'], 'Missing in_reply_to, content or session_id') @inlineCallbacks def test_inbound_missing_session(self): sid = 'a41739890287485d968ea66e8b44bfd3' response = yield self.make_ussd_request_full( sid, page_id="indexX", data="foo") self.assertEqual(400, response.code) self.assertEqual('', response.delivered_body) @inlineCallbacks def test_inbound_new_and_resume(self): sid = 'a41739890287485d968ea66e8b44bfd3' response_d = self.make_ussd_request( sid, mobile_number='2348085832481', page_id='0', data='testmenu', gate='gateid') msg = yield self.reply_to_message("OK\n1 < 2") self.assertEqual(msg['transport_name'], self.tx_helper.transport_name) self.assertEqual(msg['transport_type'], "ussd") self.assertEqual(msg['transport_metadata'], {"session_id": sid}) self.assertEqual(msg['session_event'], TransportUserMessage.SESSION_NEW) self.assertEqual(msg['from_addr'], '2348085832481') # self.assertEqual(msg['to_addr'], '*120*666#') self.assertEqual(msg['content'], 'testmenu') response = yield response_d correct_response = ''.join([ "", '', 'a41739890287485d968ea66e8b44bfd3', '
OK
1 < 2
', '', '', '', '
', ]) self.assertEqual(response, correct_response) self.tx_helper.clear_all_dispatched() response_d = self.make_ussd_request(sid, page_id="indexX", data="foo") msg = yield self.reply_to_message("OK") self.assertEqual(msg['transport_name'], self.tx_helper.transport_name) self.assertEqual(msg['transport_type'], "ussd") self.assertEqual(msg['transport_metadata'], {"session_id": sid}) self.assertEqual(msg['session_event'], TransportUserMessage.SESSION_RESUME) self.assertEqual(msg['from_addr'], '2348085832481') self.assertEqual(msg['to_addr'], 'gateid') self.assertEqual(msg['content'], 'foo') response = yield response_d correct_response = ''.join([ "", '', 'a41739890287485d968ea66e8b44bfd3', '
OK
', '', '', '', '
', ]) self.assertEqual(response, correct_response) @inlineCallbacks def test_inbound_resume_close(self): sid = 'a41739890287485d968ea66e8b44bfd3' yield self.transport.save_session(sid, '2348085832481', '*120*666#') response_d = self.make_ussd_request(sid, page_id="indexX", data="foo") msg = yield self.reply_to_message("OK", continue_session=False) self.assertEqual(msg['transport_name'], self.tx_helper.transport_name) self.assertEqual(msg['transport_type'], "ussd") self.assertEqual(msg['transport_metadata'], {"session_id": sid}) self.assertEqual(msg['session_event'], TransportUserMessage.SESSION_RESUME) self.assertEqual(msg['from_addr'], '2348085832481') self.assertEqual(msg['to_addr'], '*120*666#') self.assertEqual(msg['content'], 'foo') response = yield response_d correct_response = ''.join([ "", '', 'a41739890287485d968ea66e8b44bfd3', '
OK
', '
', ]) self.assertEqual(response, correct_response) @inlineCallbacks def test_inbound_cancel(self): sid = 'a41739890287485d968ea66e8b44bfd3' yield self.transport.save_session(sid, '2348085832481', '*120*666#') response = yield self.make_ussd_request(sid, status="1") correct_response = ''.join([ "", '', 'a41739890287485d968ea66e8b44bfd3', '', ]) self.assertEqual(response, correct_response) class TestMtechUssdResponse(VumiTestCase): def setUp(self): self.mur = MtechUssdResponse("sid123") def assert_message_xml(self, *lines): xml_str = ''.join( [""] + list(lines)) self.assertEqual(self.mur.to_xml(), xml_str) def test_empty_response(self): self.assert_message_xml( '', 'sid123', '') def test_free_text(self): self.mur.add_text("Please enter your name") self.mur.add_freetext_option() self.assert_message_xml( '', 'sid123', '
Please enter your name
', '', '
') def test_menu_options(self): self.mur.add_text("Please choose:") self.mur.add_menu_item('chicken', '1') self.mur.add_menu_item('beef', '2') self.assert_message_xml( '', 'sid123', '
Please choose:
', '', 'chicken', 'beef', '', '
') def test_menu_options_title(self): self.mur.add_title("LUNCH") self.mur.add_text("Please choose:") self.mur.add_menu_item('chicken', '1') self.mur.add_menu_item('beef', '2') self.assert_message_xml( '', 'sid123', 'LUNCH', '
Please choose:
', '', 'chicken', 'beef', '', '
') PKqG>"vumi/transports/httprpc/httprpc.py# -*- test-case-name: vumi.transports.httprpc.tests.test_httprpc -*- import json from twisted.cred.portal import Portal from twisted.internet.defer import inlineCallbacks, succeed from twisted.internet import reactor from twisted.internet.task import LoopingCall from twisted.web import http from twisted.web.guard import BasicCredentialFactory, HTTPAuthSessionWrapper from twisted.web.resource import Resource from twisted.web.server import NOT_DONE_YET from vumi import log from vumi.config import ( ConfigText, ConfigInt, ConfigBool, ConfigError, ConfigFloat) from vumi.message import TransportStatus from vumi.transports.base import Transport from vumi.transports.httprpc.auth import HttpRpcRealm, StaticAuthChecker from vumi.utils import StatusEdgeDetector class HttpRpcTransportConfig(Transport.CONFIG_CLASS): """Base config definition for transports. You should subclass this and add transport-specific fields. """ web_path = ConfigText("The path to listen for requests on.", static=True) web_port = ConfigInt( "The port to listen for requests on, defaults to `0`.", default=0, static=True) web_username = ConfigText( "The username to require callers to authenticate with. If ``None``" " then no authentication is required. Currently only HTTP Basic" " authentication is supported.", default=None, static=True) web_password = ConfigText( "The password to go with ``web_username``. Must be ``None`` if and" " only if ``web_username`` is ``None``.", default=None, static=True) web_auth_domain = ConfigText( "The name of authentication domain.", default="Vumi HTTP RPC transport", static=True) health_path = ConfigText( "The path to listen for downstream health checks on" " (useful with HAProxy)", default='health', static=True) request_cleanup_interval = ConfigInt( "How often should we actively look for old connections that should" " manually be timed out. Anything less than `1` disables the request" " cleanup meaning that all request objects will be kept in memory" " until the server is restarted, regardless if the remote side has" " dropped the connection or not. Defaults to 5 seconds.", default=5, static=True) request_timeout = ConfigInt( "How long should we wait for the remote side generating the response" " for this synchronous operation to come back. Any connection that has" " waited longer than `request_timeout` seconds will manually be" " closed. Defaults to 4 minutes.", default=(4 * 60), static=True) request_timeout_status_code = ConfigInt( "What HTTP status code should be generated when a timeout occurs." " Defaults to `504 Gateway Timeout`.", default=504, static=True) request_timeout_body = ConfigText( "What HTTP body should be returned when a timeout occurs." " Defaults to ''.", default='', static=True) noisy = ConfigBool( "Defaults to `False` set to `True` to make this transport log" " verbosely.", default=False, static=True) validation_mode = ConfigText( "The mode to operate in. Can be 'strict' or 'permissive'. If 'strict'" " then any parameter received that is not listed in EXPECTED_FIELDS" " nor in IGNORED_FIELDS will raise an error. 
If 'permissive' then no" " error is raised as long as all the EXPECTED_FIELDS are present.", default='strict', static=True) response_time_down = ConfigFloat( "The maximum time allowed for a response before the service is " "considered `down`", default=10.0, static=True) response_time_degraded = ConfigFloat( "The maximum time allowed for a response before the service is " "considered `degraded`", default=1.0, static=True) def post_validate(self): auth_supplied = (self.web_username is None, self.web_password is None) if any(auth_supplied) and not all(auth_supplied): raise ConfigError("If either web_username or web_password is" " specified, both must be specified") class HttpRpcHealthResource(Resource): isLeaf = True def __init__(self, transport): self.transport = transport Resource.__init__(self) def render_GET(self, request): request.setResponseCode(http.OK) request.do_not_log = True return self.transport.get_health_response() class HttpRpcResource(Resource): isLeaf = True def __init__(self, transport): self.transport = transport Resource.__init__(self) def render_(self, request, request_id=None): request_id = request_id or Transport.generate_message_id() request.setHeader("content-type", self.transport.content_type) self.transport.set_request(request_id, request) self.transport.handle_raw_inbound_message(request_id, request) return NOT_DONE_YET def render_PUT(self, request): return self.render_(request) def render_GET(self, request): return self.render_(request) def render_POST(self, request): return self.render_(request) class HttpRpcTransport(Transport): """Base class for synchronous HTTP transports. Because a reply from an application worker is needed before the HTTP response can be completed, a reply needs to be returned to the same transport worker that generated the inbound message. This means that currently there many only be one transport worker for each instance of this transport of a given name. """ content_type = 'text/plain' CONFIG_CLASS = HttpRpcTransportConfig ENCODING = 'UTF-8' STRICT_MODE = 'strict' PERMISSIVE_MODE = 'permissive' DEFAULT_VALIDATION_MODE = STRICT_MODE KNOWN_VALIDATION_MODES = [STRICT_MODE, PERMISSIVE_MODE] def validate_config(self): config = self.get_static_config() self.web_path = config.web_path self.web_port = config.web_port self.web_username = config.web_username self.web_password = config.web_password self.web_auth_domain = config.web_auth_domain self.health_path = config.health_path.lstrip('/') self.request_timeout = config.request_timeout self.request_timeout_status_code = config.request_timeout_status_code self.noisy = config.noisy self.request_timeout_body = config.request_timeout_body self.gc_requests_interval = config.request_cleanup_interval self._validation_mode = config.validation_mode self.response_time_down = config.response_time_down self.response_time_degraded = config.response_time_degraded if self._validation_mode not in self.KNOWN_VALIDATION_MODES: raise ConfigError('Invalid validation mode: %s' % ( self._validation_mode,)) def get_transport_url(self, suffix=''): """ Get the URL for the HTTP resource. Requires the worker to be started. This is mostly useful in tests, and probably shouldn't be used in non-test code, because the API might live behind a load balancer or proxy. 
""" addr = self.web_resource.getHost() return "http://%s:%s/%s" % (addr.host, addr.port, suffix.lstrip('/')) def get_authenticated_resource(self, resource): if not self.web_username: return resource realm = HttpRpcRealm(resource) checkers = [ StaticAuthChecker(self.web_username, self.web_password), ] portal = Portal(realm, checkers) cred_factories = [ BasicCredentialFactory(self.web_auth_domain), ] return HTTPAuthSessionWrapper(portal, cred_factories) @inlineCallbacks def setup_transport(self): self._requests = {} self.request_gc = LoopingCall(self.manually_close_requests) self.clock = self.get_clock() self.request_gc.clock = self.clock self.request_gc.start(self.gc_requests_interval) rpc_resource = HttpRpcResource(self) rpc_resource = self.get_authenticated_resource(rpc_resource) # start receipt web resource self.web_resource = yield self.start_web_resources( [ (rpc_resource, self.web_path), (HttpRpcHealthResource(self), self.health_path), ], self.web_port) self.status_detect = StatusEdgeDetector() def add_status(self, **kw): '''Publishes a status if it is not a repeat of the previously published status.''' if self.status_detect.check_status(**kw): return self.publish_status(**kw) return succeed(None) @inlineCallbacks def teardown_transport(self): yield self.web_resource.loseConnection() if self.request_gc.running: self.request_gc.stop() def get_clock(self): """ For easier stubbing in tests """ return reactor def get_field_values(self, request, expected_fields, ignored_fields=frozenset()): values = {} errors = {} for field in request.args: if field not in (expected_fields | ignored_fields): if self._validation_mode == self.STRICT_MODE: errors.setdefault('unexpected_parameter', []).append(field) else: values[field] = ( request.args.get(field)[0].decode(self.ENCODING)) for field in expected_fields: if field not in values: errors.setdefault('missing_parameter', []).append(field) return values, errors def ensure_message_values(self, message, expected_fields): missing_fields = [] for field in expected_fields: if not message[field]: missing_fields.append(field) return missing_fields def manually_close_requests(self): for request_id, request_data in self._requests.items(): timestamp = request_data['timestamp'] response_time = self.clock.seconds() - timestamp if response_time > self.request_timeout: self.on_timeout(request_id, response_time) self.close_request(request_id) def close_request(self, request_id): log.warning('Timing out %s' % (self.get_request_to_addr(request_id),)) self.finish_request(request_id, self.request_timeout_body, self.request_timeout_status_code) def get_health_response(self): return json.dumps({ 'pending_requests': len(self._requests) }) def set_request(self, request_id, request_object, timestamp=None): if timestamp is None: timestamp = self.clock.seconds() self._requests[request_id] = { 'timestamp': timestamp, 'request': request_object, } def get_request(self, request_id): if request_id in self._requests: return self._requests[request_id]['request'] def remove_request(self, request_id): del self._requests[request_id] def emit(self, msg): if self.noisy: log.debug(msg) def handle_outbound_message(self, message): self.emit("HttpRpcTransport consuming %s" % (message)) missing_fields = self.ensure_message_values(message, ['in_reply_to', 'content']) if missing_fields: return self.reject_message(message, missing_fields) else: self.finish_request( message.payload['in_reply_to'], message.payload['content'].encode('utf-8')) return 
self.publish_ack(user_message_id=message['message_id'],
                                sent_message_id=message['message_id'])

    def reject_message(self, message, missing_fields):
        return self.publish_nack(
            user_message_id=message['message_id'],
            sent_message_id=message['message_id'],
            reason='Missing fields: %s' % ', '.join(missing_fields))

    def handle_raw_inbound_message(self, msgid, request):
        raise NotImplementedError("Sub-classes should implement"
                                  " handle_raw_inbound_message.")

    def finish_request(self, request_id, data, code=200, headers={}):
        self.emit("HttpRpcTransport.finish_request with data: %s" % (
            repr(data),))
        request = self.get_request(request_id)
        if request:
            for h_name, h_values in headers.iteritems():
                request.responseHeaders.setRawHeaders(h_name, h_values)
            request.setResponseCode(code)
            request.write(data)
            request.finish()
            self.set_request_end(request_id)
            self.remove_request(request_id)
            response_id = "%s:%s:%s" % (request.client.host,
                                        request.client.port,
                                        Transport.generate_message_id())
            return response_id

    # NOTE: This hackery is required so that we know what to_addr a message
    #       was received on. This is useful so we can log more useful debug
    #       information when something goes wrong, like a timeout for
    #       example.
    #
    #       Since all the different transports that subclass this
    #       base class have different implementations for retrieving the
    #       to_addr it's impossible to grab this information higher up
    #       in a consistent manner.
    def publish_message(self, **kwargs):
        self.set_request_to_addr(kwargs['message_id'], kwargs['to_addr'])
        return super(HttpRpcTransport, self).publish_message(**kwargs)

    def get_request_to_addr(self, request_id):
        return self._requests[request_id].get('to_addr', 'Unknown')

    def set_request_to_addr(self, request_id, to_addr):
        if request_id in self._requests:
            self._requests[request_id]['to_addr'] = to_addr

    def set_request_end(self, message_id):
        '''Checks the saved timestamp to see the response time. If the
        starting timestamp for the message cannot be found, nothing is done.
        If the time is more than `response_time_down`, a `down` status event
        is sent. If the time is more than `response_time_degraded`, a
        `degraded` status event is sent. If the time is less than
        `response_time_degraded`, an `ok` status event is sent.
        '''
        request = self._requests.get(message_id, None)
        if request is not None:
            response_time = self.clock.seconds() - request['timestamp']
            if response_time > self.response_time_down:
                return self.on_down_response_time(message_id, response_time)
            elif response_time > self.response_time_degraded:
                return self.on_degraded_response_time(
                    message_id, response_time)
            else:
                return self.on_good_response_time(message_id, response_time)

    def on_down_response_time(self, message_id, time):
        '''Can be overridden by subclasses to do something when the response
        time is high enough for the transport to be considered
        non-functioning.'''
        pass

    def on_degraded_response_time(self, message_id, time):
        '''Can be overridden by subclasses to do something when the response
        time is high enough for the transport to be considered running in a
        degraded state.'''
        pass

    def on_good_response_time(self, message_id, time):
        '''Can be overridden by subclasses to do something when the response
        time is low enough for the transport to be considered running
        normally.'''
        pass

    def on_timeout(self, message_id, time):
        '''Can be overridden by subclasses to do something when the response
        times out.'''
        pass
PK=JGlvumi/transports/httprpc/auth.py# -*- coding: utf-8 -*-
# -*- test-case-name: vumi.transports.httprpc.tests.test_auth -*-

from zope.interface import implements

from twisted.cred import portal, checkers, credentials, error
from twisted.web import resource


class HttpRpcRealm(object):
    implements(portal.IRealm)

    def __init__(self, resource):
        self._resource = resource

    def requestAvatar(self, user, mind, *interfaces):
        if resource.IResource in interfaces:
            return (resource.IResource, self._resource, lambda: None)
        raise NotImplementedError()


class StaticAuthChecker(object):
    """Checks that a username and password match given static values.
""" implements(checkers.ICredentialsChecker) credentialInterfaces = (credentials.IUsernamePassword,) def __init__(self, username, password): self._username = username self._password = password def requestAvatarId(self, credentials): authorized = all((credentials.username == self._username, credentials.password == self._password)) if not authorized: raise error.UnauthorizedLogin() return self._username PK=JG˘#vumi/transports/httprpc/__init__.py"""Synchronous HTTP RPC-based message transports.""" from vumi.transports.httprpc.httprpc import HttpRpcTransport __all__ = ['HttpRpcTransport'] PKqGKJ%%-vumi/transports/httprpc/tests/test_httprpc.pyimport json from twisted.internet.defer import inlineCallbacks from twisted.internet.task import Clock from vumi.utils import http_request, http_request_full, basic_auth_string from vumi.tests.helpers import VumiTestCase from vumi.tests.utils import LogCatcher from vumi.transports.httprpc import HttpRpcTransport from vumi.message import TransportUserMessage from vumi.transports.tests.helpers import TransportHelper class OkTransport(HttpRpcTransport): def handle_raw_inbound_message(self, msgid, request): self.publish_message( message_id=msgid, content='', to_addr='to_addr', from_addr='', provider='', session_event=TransportUserMessage.SESSION_NEW, transport_name=self.transport_name, transport_type=self.config.get('transport_type'), transport_metadata={}, ) class TestTransport(VumiTestCase): @inlineCallbacks def setUp(self): self.clock = Clock() self.patch(OkTransport, 'get_clock', lambda _: self.clock) config = { 'web_path': "foo", 'web_port': 0, 'request_timeout': 10, 'request_timeout_status_code': 418, 'request_timeout_body': 'I am a teapot', 'publish_status': True, } self.tx_helper = self.add_helper(TransportHelper(OkTransport)) self.transport = yield self.tx_helper.get_transport(config) self.transport_url = self.transport.get_transport_url() @inlineCallbacks def test_health(self): result = yield http_request(self.transport_url + "health", "", method='GET') self.assertEqual(json.loads(result), { 'pending_requests': 0 }) @inlineCallbacks def test_inbound(self): d = http_request(self.transport_url + "foo", '', method='GET') [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) rep = yield self.tx_helper.make_dispatch_reply(msg, "OK") response = yield d self.assertEqual(response, 'OK') [ack] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(ack['user_message_id'], rep['message_id']) self.assertEqual(ack['sent_message_id'], rep['message_id']) @inlineCallbacks def test_nack(self): msg = yield self.tx_helper.make_dispatch_outbound("outbound") [nack] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(nack['user_message_id'], msg['message_id']) self.assertEqual(nack['sent_message_id'], msg['message_id']) self.assertEqual(nack['nack_reason'], 'Missing fields: in_reply_to') @inlineCallbacks def test_timeout(self): d = http_request_full(self.transport_url + "foo", '', method='GET') [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) with LogCatcher(message='Timing') as lc: self.clock.advance(10.1) # .1 second after timeout response = yield d [warning] = lc.messages() self.assertEqual(warning, 'Timing out to_addr') self.assertEqual(response.delivered_body, 'I am a teapot') self.assertEqual(response.code, 418) @inlineCallbacks def test_publish_health_status_repeated(self): '''Repeated statuses should not be published, new ones should be.''' yield self.transport.add_status( component='foo', status='ok', type='bar', 
message='test') yield self.transport.add_status( component='foo', status='ok', type='bar', message='another test') [status] = yield self.tx_helper.get_dispatched_statuses() self.assertEqual(status['status'], 'ok') yield self.tx_helper.clear_dispatched_statuses() yield self.transport.add_status( component='foo', status='degraded', type='bar', message='another test') [status] = yield self.tx_helper.get_dispatched_statuses() self.assertEqual(status['status'], 'degraded') class TestTransportWithAuthentication(VumiTestCase): @inlineCallbacks def setUp(self): self.clock = Clock() self.patch(OkTransport, 'get_clock', lambda _: self.clock) config = { 'web_path': "foo", 'web_port': 0, 'web_username': 'user-1', 'web_password': 'pass-secret', 'web_auth_domain': 'Mordor', 'request_timeout': 10, 'request_timeout_status_code': 418, 'request_timeout_body': 'I am a teapot', } self.tx_helper = self.add_helper(TransportHelper(OkTransport)) self.transport = yield self.tx_helper.get_transport(config) self.transport_url = self.transport.get_transport_url() @inlineCallbacks def test_health_doesnt_require_auth(self): result = yield http_request(self.transport_url + "health", "", method='GET') self.assertEqual(json.loads(result), { 'pending_requests': 0 }) @inlineCallbacks def test_inbound_with_successful_auth(self): headers = { 'Authorization': basic_auth_string("user-1", "pass-secret") } d = http_request(self.transport_url + "foo", '', headers=headers, method='GET') [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) rep = yield self.tx_helper.make_dispatch_reply(msg, "OK") response = yield d self.assertEqual(response, 'OK') [ack] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(ack['user_message_id'], rep['message_id']) self.assertEqual(ack['sent_message_id'], rep['message_id']) @inlineCallbacks def test_inbound_with_failed_auth(self): headers = { 'Authorization': basic_auth_string("user-1", "bad-pass") } d = http_request(self.transport_url + "foo", '', headers=headers, method='GET') response = yield d self.assertEqual(response, 'Unauthorized') @inlineCallbacks def test_inbound_without_auth(self): d = http_request(self.transport_url + "foo", '', method='GET') response = yield d self.assertEqual(response, 'Unauthorized') class JSONTransport(HttpRpcTransport): def handle_raw_inbound_message(self, msgid, request): request_content = json.loads(request.content.read()) self.publish_message( message_id=msgid, content=request_content['content'], to_addr=request_content['to_addr'], from_addr=request_content['from_addr'], provider='', session_event=TransportUserMessage.SESSION_NEW, transport_name=self.transport_name, transport_type=self.config.get('transport_type'), transport_metadata={}, ) class TestJSONTransport(VumiTestCase): @inlineCallbacks def setUp(self): config = { 'web_path': "foo", 'web_port': 0, } self.tx_helper = self.add_helper(TransportHelper(JSONTransport)) self.transport = yield self.tx_helper.get_transport(config) self.transport_url = self.transport.get_transport_url() @inlineCallbacks def test_inbound(self): d = http_request(self.transport_url + "foo", '{"content": "hello",' ' "to_addr": "the_app",' ' "from_addr": "some_msisdn"' '}', method='POST') [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(msg['content'], 'hello') self.assertEqual(msg['to_addr'], 'the_app') self.assertEqual(msg['from_addr'], 'some_msisdn') yield self.tx_helper.make_dispatch_reply(msg, '{"content": "bye"}') response = yield d self.assertEqual(response, '{"content": "bye"}') 
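# A sketch of how a subclass might wire HttpRpcTransport's response-time
# hooks into status events via add_status(). The component/type strings
# used here are illustrative assumptions, not values defined by the
# transport itself.
class StatusReportingTransport(OkTransport):

    def on_down_response_time(self, message_id, time):
        # Response took longer than response_time_down.
        return self.add_status(
            component='response', status='down', type='very_slow_response',
            message='Response took %.2f seconds' % (time,))

    def on_degraded_response_time(self, message_id, time):
        # Response took longer than response_time_degraded.
        return self.add_status(
            component='response', status='degraded', type='slow_response',
            message='Response took %.2f seconds' % (time,))

    def on_good_response_time(self, message_id, time):
        # Response completed within the degraded threshold.
        return self.add_status(
            component='response', status='ok', type='response_sent',
            message='Response took %.2f seconds' % (time,))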
class CustomOutboundTransport(OkTransport):
    RESPONSE_HEADERS = {
        'Darth-Vader': ["Anakin Skywalker"],
        'Admiral-Ackbar': ["It's a trap!", "Shark"]
    }

    def handle_outbound_message(self, message):
        self.finish_request(
            message.payload['in_reply_to'],
            message.payload['content'].encode('utf-8'),
            headers=self.RESPONSE_HEADERS)


class TestCustomOutboundTransport(VumiTestCase):

    @inlineCallbacks
    def setUp(self):
        config = {
            'web_path': "foo",
            'web_port': 0,
            'username': 'testuser',
            'password': 'testpass',
        }
        self.tx_helper = self.add_helper(
            TransportHelper(CustomOutboundTransport))
        self.transport = yield self.tx_helper.get_transport(config)
        self.transport_url = self.transport.get_transport_url()

    @inlineCallbacks
    def test_optional_headers(self):
        d = http_request_full(self.transport_url + "foo", '', method='GET')
        [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1)
        yield self.tx_helper.make_dispatch_reply(msg, "OK")
        response = yield d
        self.assertEqual(
            response.headers.getRawHeaders('Darth-Vader'),
            ["Anakin Skywalker"])
        self.assertEqual(
            response.headers.getRawHeaders('Admiral-Ackbar'),
            ["It's a trap!", "Shark"])
PK=JG,G66(vumi/transports/httprpc/tests/helpers.pyfrom urllib import urlencode

from twisted.internet.defer import inlineCallbacks, returnValue

from zope.interface import implements

from vumi.errors import VumiError
from vumi.tests.helpers import IHelper, generate_proxies
from vumi.transports.tests.helpers import TransportHelper
from vumi.utils import http_request_full


class HttpRpcTransportHelperError(VumiError):
    """Error raised when the HttpRpcTransportHelper encounters an error."""


class HttpRpcTransportHelper(object):
    """
    Test helper for subclasses of
    :class:`~vumi.transports.httprpc.HttpRpcTransport`.

    Adds support for making HTTP requests to the HTTP RPC transport to the
    base :class:`~vumi.transports.tests.helpers.TransportHelper`.

    :param dict request_defaults: Default URL parameters for HTTP requests.

    Other parameters are the same as for
    :class:`~vumi.transports.tests.helpers.TransportHelper`.
    """

    implements(IHelper)

    def __init__(self, transport_class, use_riak=False, request_defaults=None,
                 **msg_helper_args):
        self._transport_helper = TransportHelper(
            transport_class, use_riak=use_riak, **msg_helper_args)
        if request_defaults is None:
            request_defaults = {}
        self.request_defaults = request_defaults
        self.transport_url = None
        generate_proxies(self, self._transport_helper)

    def setup(self, *args, **kw):
        return self._transport_helper.setup(*args, **kw)

    def cleanup(self):
        return self._transport_helper.cleanup()

    @inlineCallbacks
    def get_transport(self, config, cls=None, start=True):
        transport = yield self._transport_helper.get_transport(
            config, cls=cls, start=start)
        self.transport_url = transport.get_transport_url(config['web_path'])
        returnValue(transport)

    def mk_request_raw(self, suffix='', params=None, data=None, method='GET'):
        """
        Make an HTTP request, ignoring this helper's ``request_defaults``.

        :param str suffix: Suffix to add to the transport's URL.
        :param dict params:
            A dictionary of URL parameters to append to the URL as a query
            string or None for no URL parameters.
        :param str data: Request body or None for no request body.
        :param str method: HTTP method to use for the request.
        :raises HttpRpcTransportHelperError:
            When invoked before calling :meth:`get_transport`.
""" if self.transport_url is None: raise HttpRpcTransportHelperError( "call .get_transport() before making HTTP requests.") url = self.transport_url + suffix if params is not None: url += '?%s' % urlencode(params) return http_request_full(url, data=data, method=method) def mk_request(self, _suffix='', _data=None, _method='GET', **kw): """ Make an HTTP request. :param str _suffix: Suffix to add to the transport's URL. :param str _data: Request body or None for no request body. :param str _method: HTTP method to use for the request. :param \*\*kw: URL query string parameters. :raises HttpRpcTransportHelperError: When invoked before calling :meth:`get_transport`. The ``_`` prefixes on the function parameter names are to make accidental clashes with URL query parameter names less likely. """ params = self.request_defaults.copy() params.update(kw) return self.mk_request_raw( suffix=_suffix, params=params, data=_data, method=_method) PK=JG)vumi/transports/httprpc/tests/__init__.pyPK=JG P//*vumi/transports/httprpc/tests/test_auth.py# -*- coding: utf-8 -*- """Tests for vumi.transports.httprpc.auth.""" from twisted.web.resource import IResource from twisted.cred.credentials import UsernamePassword from twisted.cred.error import UnauthorizedLogin from vumi.tests.helpers import VumiTestCase from vumi.transports.httprpc.auth import HttpRpcRealm, StaticAuthChecker class TestHttpRpcRealm(VumiTestCase): def mk_realm(self): resource = object() return resource, HttpRpcRealm(resource) def test_resource_interface(self): user, mind = object(), object() expected_resource, realm = self.mk_realm() interface, resource, cleanup = realm.requestAvatar( user, mind, IResource) self.assertEqual(interface, IResource) self.assertEqual(resource, expected_resource) self.assertEqual(cleanup(), None) def test_unknown_interface(self): user, mind = object(), object() expected_resource, realm = self.mk_realm() self.assertRaises(NotImplementedError, realm.requestAvatar, user, mind, *[]) class TestStaticAuthChecker(VumiTestCase): def test_valid_credentials(self): checker = StaticAuthChecker("user", "pass") creds = UsernamePassword("user", "pass") self.assertEqual(checker.requestAvatarId(creds), "user") def test_invalid_credentials(self): checker = StaticAuthChecker("user", "pass") creds = UsernamePassword("user", "bad-pass") self.assertRaises(UnauthorizedLogin, checker.requestAvatarId, creds) PK=JG&vumi/transports/mtn_rwanda/__init__.pyPK=JGk  -vumi/transports/mtn_rwanda/mtn_rwanda_ussd.py# -*- test-case-name: vumi.transports.mtn_rwanda.tests.test_mtn_rwanda_ussd -*- from datetime import datetime from twisted.internet import reactor from twisted.web import xmlrpc from twisted.internet.defer import inlineCallbacks, Deferred, returnValue from vumi.message import TransportUserMessage from vumi.transports.base import Transport from vumi.config import ( ConfigServerEndpoint, ConfigInt, ConfigDict, ConfigText, ServerEndpointFallback) from vumi.components.session import SessionManager from vumi.transports.httprpc.httprpc import HttpRpcHealthResource from vumi.utils import build_web_site class MTNRwandaUSSDTransportConfig(Transport.CONFIG_CLASS): """ MTN Rwanda USSD transport configuration. """ twisted_endpoint = ConfigServerEndpoint( "The listening endpoint that the remote client will connect to.", required=True, static=True, fallbacks=[ServerEndpointFallback()]) timeout = ConfigInt( "No. 
of seconds to wait before removing a request that hasn't " "received a response yet.", default=30, static=True) redis_manager = ConfigDict( "Parameters to connect to redis with", default={}, static=True) session_timeout_period = ConfigInt( "Maximum length of a USSD session", default=600, static=True) web_path = ConfigText( "The path to serve this resource on.", required=True, static=True) health_path = ConfigText( "The path to serve the health resource on.", default='/health/', static=True) # TODO: Deprecate these fields when confmodel#5 is done. host = ConfigText( "*DEPRECATED* 'host' and 'port' fields may be used in place of the" " 'twisted_endpoint' field.", static=True) port = ConfigInt( "*DEPRECATED* 'host' and 'port' fields may be used in place of the" " 'twisted_endpoint' field.", static=True) class RequestTimedOutError(Exception): pass class InvalidRequest(Exception): pass class MTNRwandaUSSDTransport(Transport): transport_type = 'ussd' xmlrpc_server = None CONFIG_CLASS = MTNRwandaUSSDTransportConfig ENCODING = 'UTF-8' @inlineCallbacks def setup_transport(self): """ Transport specific setup - it sets up a connection. """ self._requests = {} self._requests_deferreds = {} self.callLater = reactor.callLater config = self.get_static_config() self.endpoint = config.twisted_endpoint self.timeout = config.timeout r_prefix = "vumi.transports.mtn_rwanda:%s" % self.transport_name self.session_manager = yield SessionManager.from_redis_config( config.redis_manager, r_prefix, config.session_timeout_period) self.factory = build_web_site({ config.health_path: HttpRpcHealthResource(self), config.web_path: MTNRwandaXMLRPCResource(self), }) self.xmlrpc_server = yield self.endpoint.listen(self.factory) @inlineCallbacks def teardown_transport(self): """ Clean-up of setup done in setup_transport. """ self.session_manager.stop() if self.xmlrpc_server is not None: yield self.xmlrpc_server.stopListening() def get_health_response(self): return "OK" def set_request(self, request_id, request_args): self._requests[request_id] = request_args return request_args def get_request(self, request_id): if request_id in self._requests: request = self._requests[request_id] return request def remove_request(self, request_id): del self._requests[request_id] def timed_out(self, request_id): d = self._requests_deferreds[request_id] self.remove_request(request_id) d.errback(RequestTimedOutError( "Request %r timed out." % (request_id,))) REQUIRED_INBOUND_MESSAGE_FIELDS = set([ 'TransactionId', 'TransactionTime', 'MSISDN', 'USSDServiceCode', 'USSDRequestString']) def validate_inbound_data(self, msg_params): missing_fields = ( self.REQUIRED_INBOUND_MESSAGE_FIELDS - set(msg_params)) if missing_fields: return False else: return True @inlineCallbacks def handle_raw_inbound_request(self, message_id, values, d): """ Called by the XML-RPC server when it receives a payload that needs processing. 
""" self.timeout_request = self.callLater(self.timeout, self.timed_out, message_id) self._requests[message_id] = values self._requests_deferreds[message_id] = d if not self.validate_inbound_data(values.keys()): self.timeout_request.cancel() self.remove_request(message_id) d.errback(InvalidRequest("4001: Missing Parameters")) else: session_id = values['TransactionId'] session = yield self.session_manager.load_session(session_id) if session: session_event = TransportUserMessage.SESSION_RESUME content = values['USSDRequestString'] else: yield self.session_manager.create_session( session_id, from_addr=values['MSISDN'], to_addr=values['USSDServiceCode']) session_event = TransportUserMessage.SESSION_NEW content = None metadata = { 'transaction_id': values['TransactionId'], 'transaction_time': values['TransactionTime'], } res = yield self.publish_message( message_id=message_id, content=content, from_addr=values['MSISDN'], to_addr=values['USSDServiceCode'], session_event=session_event, transport_type=self.transport_type, transport_metadata={'mtn_rwanda_ussd': metadata} ) returnValue(res) @inlineCallbacks def finish_request(self, request_id, data, session_event): request = self.get_request(request_id) del request['USSDRequestString'] request['USSDResponseString'] = data request['TransactionTime'] = datetime.now().isoformat() if session_event == TransportUserMessage.SESSION_NEW: request['action'] = 'request' elif session_event == TransportUserMessage.SESSION_CLOSE: request['action'] = 'end' yield self.session_manager.clear_session(request['TransactionId']) elif session_event == TransportUserMessage.SESSION_RESUME: request['action'] = 'notify' self.set_request(request_id, request) d = self._requests_deferreds[request_id] self.remove_request(request_id) d.callback(request) def handle_outbound_message(self, message): """ Read outbound message and do what needs to be done with them. """ request_id = message['in_reply_to'] if self.get_request(request_id) is None: return self.publish_nack(user_message_id=message['message_id'], sent_message_id=message['message_id'], reason='Request not found') self.timeout_request.cancel() self.finish_request(request_id, message.payload['content'].encode('utf-8'), message['session_event']) return self.publish_ack(user_message_id=request_id, sent_message_id=request_id) class MTNRwandaXMLRPCResource(xmlrpc.XMLRPC): """ A Resource object implementing XML-RPC, can be published using twisted.web.server.Site. """ def __init__(self, transport): self.transport = transport xmlrpc.XMLRPC.__init__(self, allowNone=True) def xmlrpc_handleUSSD(self, request_data): request_id = Transport.generate_message_id() d = Deferred() self.transport.handle_raw_inbound_request(request_id, request_data, d) return d PK=JGi'558vumi/transports/mtn_rwanda/tests/test_mtn_rwanda_ussd.pyimport xmlrpclib from datetime import datetime from twisted.internet.defer import inlineCallbacks from twisted.internet import endpoints, tcp from twisted.internet.task import Clock from twisted.web.xmlrpc import Proxy from vumi.message import TransportUserMessage from vumi.transports.mtn_rwanda.mtn_rwanda_ussd import ( MTNRwandaUSSDTransport, RequestTimedOutError, InvalidRequest) from vumi.tests.helpers import VumiTestCase from vumi.transports.tests.helpers import TransportHelper class TestMTNRwandaUSSDTransport(VumiTestCase): session_id = 'session_id' @inlineCallbacks def setUp(self): """ Create the server (i.e. 
vumi transport instance) """ self.clock = Clock() self.tx_helper = self.add_helper( TransportHelper(MTNRwandaUSSDTransport)) self.transport = yield self.tx_helper.get_transport({ 'twisted_endpoint': 'tcp:port=0', 'timeout': '30', 'web_path': '/foo/', }) self.transport.callLater = self.clock.callLater self.session_manager = self.transport.session_manager def test_transport_creation(self): self.assertIsInstance(self.transport, MTNRwandaUSSDTransport) self.assertIsInstance(self.transport.endpoint, endpoints.TCP4ServerEndpoint) self.assertIsInstance(self.transport.xmlrpc_server, tcp.Port) def test_transport_teardown(self): d = self.transport.teardown_transport() self.assertTrue(self.transport.xmlrpc_server.disconnecting) return d def assert_inbound_message(self, expected_payload, msg, **field_values): field_values['message_id'] = msg['message_id'] expected_payload.update(field_values) for field, expected_value in expected_payload.iteritems(): self.assertEqual(msg[field], expected_value) @inlineCallbacks def test_inbound_request_and_reply(self): address = self.transport.xmlrpc_server.getHost() url = 'http://' + address.host + ':' + str(address.port) + '/foo/' proxy = Proxy(url) x = proxy.callRemote('handleUSSD', { 'TransactionId': '0001', 'USSDServiceCode': '543', 'USSDRequestString': '14321*1000#', 'MSISDN': '275551234', 'USSDEncoding': 'GSM0338', # Optional 'TransactionTime': '2013-07-05T22:58:47.565596' }) [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) expected_inbound_payload = { 'message_id': '', 'content': None, 'from_addr': '', # msisdn 'to_addr': '', # service code 'session_event': TransportUserMessage.SESSION_RESUME, 'transport_name': self.tx_helper.transport_name, 'transport_type': 'ussd', 'transport_metadata': { 'mtn_rwanda_ussd': { 'transaction_id': '0001', 'transaction_time': '2013-07-05T22:58:47.565596', }, }, } yield self.assert_inbound_message( expected_inbound_payload, msg, from_addr='275551234', to_addr='543', session_event=TransportUserMessage.SESSION_NEW) expected_reply = {'MSISDN': '275551234', 'TransactionId': '0001', 'TransactionTime': datetime.now().isoformat(), 'USSDEncoding': 'GSM0338', 'USSDResponseString': 'Test message', 'USSDServiceCode': '543', 'action': 'end'} self.tx_helper.make_dispatch_reply( msg, expected_reply['USSDResponseString'], continue_session=False) received_text = yield x for key in received_text.keys(): if key == 'TransactionTime': self.assertEqual(len(received_text[key]), len(expected_reply[key])) else: self.assertEqual(expected_reply[key], received_text[key]) @inlineCallbacks def test_nack(self): msg = yield self.tx_helper.make_dispatch_outbound("outbound") [nack] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(nack['user_message_id'], msg['message_id']) self.assertEqual(nack['sent_message_id'], msg['message_id']) self.assertEqual(nack['nack_reason'], 'Request not found') @inlineCallbacks def test_inbound_faulty_request(self): address = self.transport.xmlrpc_server.getHost() url = 'http://' + address.host + ':' + str(address.port) + '/foo/' proxy = Proxy(url) try: yield proxy.callRemote('handleUSSD', { 'TransactionId': '0001', 'USSDServiceCode': '543', 'USSDRequestString': '14321*1000#', 'MSISDN': '275551234', 'USSDEncoding': 'GSM0338', }) [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) except xmlrpclib.Fault, e: self.assertEqual(e.faultCode, 8002) self.assertEqual(e.faultString, 'error') else: self.fail('We expected an invalid request error.') [failure] = self.flushLoggedErrors(InvalidRequest) err = 
failure.value self.assertEqual(str(err), '4001: Missing Parameters') @inlineCallbacks def test_timeout(self): address = self.transport.xmlrpc_server.getHost() url = 'http://' + address.host + ':' + str(address.port) + '/foo/' proxy = Proxy(url) x = proxy.callRemote('handleUSSD', { 'TransactionId': '0001', 'USSDServiceCode': '543', 'USSDRequestString': '14321*1000#', 'MSISDN': '275551234', 'TransactionTime': '2013-07-05T22:58:47.565596' }) [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.clock.advance(30) try: yield x except xmlrpclib.Fault, e: self.assertEqual(e.faultCode, 8002) self.assertEqual(e.faultString, 'error') else: self.fail('We expected a timeout error.') [failure] = self.flushLoggedErrors(RequestTimedOutError) err = failure.value self.assertTrue(str(err).endswith('timed out.')) PK=JG,vumi/transports/mtn_rwanda/tests/__init__.pyPKqG{f7!7!vumi/transports/mxit/mxit.py# -*- test-case-name: vumi.transports.mxit.tests.test_mxit -*- import json import base64 from urllib import urlencode, unquote_plus from HTMLParser import HTMLParser from twisted.web import http from twisted.internet.defer import inlineCallbacks, returnValue from vumi.config import ConfigText, ConfigInt, ConfigDict, ConfigList from vumi.persist.txredis_manager import TxRedisManager from vumi.transports.httprpc import HttpRpcTransport from vumi.transports.mxit.responses import MxitResponse from vumi.utils import http_request_full class MxitTransportException(Exception): """Raised when the Mxit API returns an error""" class MxitTransportConfig(HttpRpcTransport.CONFIG_CLASS): client_id = ConfigText( 'The OAuth2 ClientID assigned to this transport.', required=True, static=True) client_secret = ConfigText( 'The OAuth2 ClientSecret assigned to this transport.', required=True, static=True) timeout = ConfigInt( 'Timeout for outbound Mxit HTTP API calls.', required=False, default=30, static=True) redis_manager = ConfigDict( 'How to connect to Redis', required=True, static=True) api_send_url = ConfigText( 'The URL for the Mxit message sending API.', required=False, default="https://api.mxit.com/message/send/", static=True) api_auth_url = ConfigText( 'The URL for the Mxit authentication API.', required=False, default='https://auth.mxit.com', static=True) api_auth_scopes = ConfigList( 'The list of scopes to request access to.', required=False, static=True, default=['message/send']) class MxitTransport(HttpRpcTransport): """ HTTP Transport for MXit, implemented using the MXit Mobi Portal (for inbound messages and replies) and the Messaging API (for sends that aren't replies). * Mobi Portal API specification: http://dev.mxit.com/docs/mobi-portal-api * Message API specification: https://dev.mxit.com/docs/restapi/messaging/post-message-send """ CONFIG_CLASS = MxitTransportConfig content_type = 'text/html; charset=utf-8' transport_type = 'mxit' access_token_key = 'access_token' access_token_auto_decay = 0.95 agent_factory = None # For swapping out the Agent we use in tests. 
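# A worked example of the auto-decay above (see get_access_token() below):
# if the auth API reports expires_in=3600, the token is cached with
# setex(key, int(3600 * 0.95)) == 3420 seconds, so the cached copy always
# expires slightly before the token itself does and we never hand out a
# stale token.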
@inlineCallbacks def setup_transport(self): yield super(MxitTransport, self).setup_transport() config = self.get_static_config() self.redis = yield TxRedisManager.from_config(config.redis_manager) def is_mxit_request(self, request): return request.requestHeaders.hasHeader('X-Mxit-Contact') def noop(self, key): return key def parse_location(self, location): return dict(zip([ 'country_code', 'country_name', 'subdivision_code', 'subdivision_name', 'city_code', 'city', 'network_operator_id', 'client_features_bitset', 'cell_id' ], location.split(','))) def parse_profile(self, profile): return dict(zip([ 'language_code', 'country_code', 'date_of_birth', 'gender', 'tariff_plan', ], profile.split(','))) def html_decode(self, html): """ Turns '<b>foo</b>' into u'foo' """ return HTMLParser().unescape(html) def get_request_data(self, request): headers = request.requestHeaders header_ops = [ ('X-Device-User-Agent', self.noop), ('X-Mxit-Contact', self.noop), ('X-Mxit-USERID-R', self.noop), ('X-Mxit-Nick', self.noop), ('X-Mxit-Location', self.parse_location), ('X-Mxit-Profile', self.parse_profile), ('X-Mxit-User-Input', self.html_decode), ] data = {} for header, proc in header_ops: if headers.hasHeader(header): [value] = headers.getRawHeaders(header) data[header] = proc(value) return data def get_request_content(self, request): headers = request.requestHeaders [content] = headers.getRawHeaders('X-Mxit-User-Input', [None]) if content: return unquote_plus(content) if request.args and 'input' in request.args: [content] = request.args['input'] return content return None def handle_raw_inbound_message(self, msg_id, request): if not self.is_mxit_request(request): return self.finish_request( msg_id, data=http.RESPONSES[http.BAD_REQUEST], code=http.BAD_REQUEST) data = self.get_request_data(request) content = self.get_request_content(request) return self.publish_message( message_id=msg_id, content=content, to_addr=data['X-Mxit-Contact'], from_addr=data['X-Mxit-USERID-R'], provider='mxit', transport_type=self.transport_type, helper_metadata={ 'mxit_info': data, }) def handle_outbound_message(self, message): self.emit("MxitTransport consuming %s" % (message)) if message["in_reply_to"] is None: return self.handle_outbound_send(message) else: return self.handle_outbound_reply(message) @inlineCallbacks def handle_outbound_reply(self, message): missing_fields = self.ensure_message_values( message, ['in_reply_to']) if missing_fields: yield self.reject_message(message, missing_fields) else: yield self.render_response(message) yield self.publish_ack( user_message_id=message['message_id'], sent_message_id=message['message_id']) @inlineCallbacks def get_access_token(self): access_token = yield self.redis.get(self.access_token_key) if access_token is None: access_token, expiry = yield self.request_new_access_token() # always make sure we expire before the token actually does safe_expiry = expiry * self.access_token_auto_decay yield self.redis.setex( self.access_token_key, int(safe_expiry), access_token) returnValue(access_token) @inlineCallbacks def request_new_access_token(self): config = self.get_static_config() url = '%s/token' % (config.api_auth_url) auth = base64.b64encode( '%s:%s' % (config.client_id, config.client_secret)) headers = { 'Content-Type': 'application/x-www-form-urlencoded', 'Authorization': 'Basic %s' % (auth,) } data = urlencode({ 'grant_type': 'client_credentials', 'scope': ' '.join(config.api_auth_scopes) }) response = yield http_request_full( url=url, method='POST', headers=headers, data=data, 
agent_class=self.agent_factory) data = json.loads(response.delivered_body) if 'error' in data: raise MxitTransportException( '%(error)s: %(error_description)s.' % data) returnValue( (data['access_token'].encode('utf8'), int(data['expires_in']))) @inlineCallbacks def handle_outbound_send(self, message): config = self.get_static_config() body = message['content'] access_token = yield self.get_access_token() headers = { "Content-Type": "application/json", "Authorization": "Bearer %s" % (access_token,) } data = { "Body": body, "ContainsMarkup": "true", "From": message["from_addr"], "To": message["to_addr"], "Spool": "true", } context_factory = None yield http_request_full( config.api_send_url, data=json.dumps(data), headers=headers, method="POST", timeout=config.timeout, context_factory=context_factory, agent_class=self.agent_factory) @inlineCallbacks def render_response(self, message): msg_id = message['in_reply_to'] request = self.get_request(msg_id) if request: data = yield MxitResponse(message).flatten() super(MxitTransport, self).finish_request( msg_id, data, code=http.OK) PK=JGPQQ vumi/transports/mxit/__init__.pyfrom vumi.transports.mxit.mxit import MxitTransport __all__ = ['MxitTransport'] PK=JGQC!vumi/transports/mxit/responses.pyimport re from twisted.web.template import Element, renderer, XMLFile, flattenString from twisted.python.filepath import FilePath from vumi.utils import PkgResources MXIT_RESOURCES = PkgResources(__name__) class ResponseParser(object): HEADER_PATTERN = r'^(.*)[\r\n]{1,2}\d?' ITEM_PATTERN = r'^(\d+)\. (.+)$' def __init__(self, content): header_match = re.match(self.HEADER_PATTERN, content) if header_match: [self.header] = header_match.groups() self.items = re.findall(self.ITEM_PATTERN, content, re.MULTILINE) else: self.header = content self.items = [] @classmethod def parse(cls, content): p = cls(content) return p.header, p.items class MxitResponse(Element): loader = XMLFile(FilePath(MXIT_RESOURCES.path('templates/response.xml'))) def __init__(self, message, loader=None): self.header, self.items = ResponseParser.parse( message['content'] or u'') super(MxitResponse, self).__init__(loader or self.loader) @renderer def render_header(self, request, tag): return tag(self.header) @renderer def render_body(self, request, tag): if not self.items: return '' return tag @renderer def render_item(self, request, tag): for index, text in self.items: yield tag.clone().fillSlots(index=str(index), text=text) def flatten(self): return flattenString(None, self) PK=JGEwBB+vumi/transports/mxit/templates/response.xml

PK=H .&&'vumi/transports/mxit/tests/test_mxit.pyimport json import base64 from twisted.internet.defer import inlineCallbacks, DeferredQueue from twisted.web.http import BAD_REQUEST from twisted.web.server import NOT_DONE_YET from twisted.web.http_headers import Headers from vumi.transports.mxit import MxitTransport from vumi.transports.mxit.responses import ResponseParser from vumi.utils import http_request_full from vumi.tests.helpers import VumiTestCase from vumi.tests.fake_connection import FakeHttpServer from twisted.web.test.requesthelper import DummyRequest as TwistedDummyRequest from vumi.transports.tests.helpers import TransportHelper class DummyRequest(TwistedDummyRequest): def __init__(self, *args, **kw): # Twisted 13.2.0 doesn't have .requestHeaders on DummyRequest TwistedDummyRequest.__init__(self, *args, **kw) if not hasattr(self, 'requestHeaders'): self.requestHeaders = Headers() class TestMxitTransport(VumiTestCase): @inlineCallbacks def setUp(self): self.fake_http = FakeHttpServer(self.handle_request) self.http_request_queue = DeferredQueue() config = { 'web_port': 0, 'web_path': '/api/v1/mxit/mobiportal/', 'client_id': 'client_id', 'client_secret': 'client_secret', } self.sample_loc_str = 'cc,cn,sc,sn,cc,c,noi,cfb,ci' self.sample_profile_str = 'lc,cc,dob,gender,tariff' self.sample_html_str = '<&>' self.sample_req_headers = { 'X-Device-User-Agent': 'ua', 'X-Mxit-Contact': 'contact', 'X-Mxit-USERID-R': 'user-id', 'X-Mxit-Nick': 'nick', 'X-Mxit-Location': self.sample_loc_str, 'X-Mxit-Profile': self.sample_profile_str, 'X-Mxit-User-Input': self.sample_html_str, } self.sample_menu_resp = "\n".join([ "Hello!", "1. option 1", "2. option 2", "3. option 3", ]) # same as above but the o's are replaced with # http://www.fileformat.info/info/unicode/char/f8/index.htm slashed_o = '\xc3\xb8' self.sample_unicode_menu_resp = unicode( self.sample_menu_resp.replace('o', slashed_o), 'utf-8') self.tx_helper = self.add_helper(TransportHelper(MxitTransport)) self.transport = yield self.tx_helper.get_transport(config) self.transport.agent_factory = self.fake_http.get_agent # NOTE: priming redis with an access token self.transport.redis.set(self.transport.access_token_key, 'foo') self.url = self.transport.get_transport_url(config['web_path']) def handle_request(self, request): self.http_request_queue.put(request) return NOT_DONE_YET def test_is_mxit_request(self): req = DummyRequest([]) self.assertFalse(self.transport.is_mxit_request(req)) req.requestHeaders.addRawHeader('X-Mxit-Contact', 'foo') self.assertTrue(self.transport.is_mxit_request(req)) def test_noop(self): self.assertEqual(self.transport.noop('foo'), 'foo') def test_parse_location(self): self.assertEqual(self.transport.parse_location(self.sample_loc_str), { 'country_code': 'cc', 'country_name': 'cn', 'subdivision_code': 'sc', 'subdivision_name': 'sn', 'city_code': 'cc', 'city': 'c', 'network_operator_id': 'noi', 'client_features_bitset': 'cfb', 'cell_id': 'ci', }) def test_parse_profile(self): self.assertEqual( self.transport.parse_profile(self.sample_profile_str), { 'country_code': 'cc', 'date_of_birth': 'dob', 'gender': 'gender', 'language_code': 'lc', 'tariff_plan': 'tariff', }) def test_html_decode(self): self.assertEqual( self.transport.html_decode(self.sample_html_str), '<&>') def test_get_request_data(self): req = DummyRequest([]) headers = req.requestHeaders for key, value in self.sample_req_headers.items(): headers.addRawHeader(key, value) data = self.transport.get_request_data(req) self.assertEqual(data, { 
'X-Device-User-Agent': 'ua', 'X-Mxit-Contact': 'contact', 'X-Mxit-Location': { 'cell_id': 'ci', 'city': 'c', 'city_code': 'cc', 'client_features_bitset': 'cfb', 'country_code': 'cc', 'country_name': 'cn', 'network_operator_id': 'noi', 'subdivision_code': 'sc', 'subdivision_name': 'sn', }, 'X-Mxit-Nick': 'nick', 'X-Mxit-Profile': { 'country_code': 'cc', 'date_of_birth': 'dob', 'gender': 'gender', 'language_code': 'lc', 'tariff_plan': 'tariff', }, 'X-Mxit-USERID-R': 'user-id', 'X-Mxit-User-Input': u'<&>', }) def test_get_request_content_from_header(self): req = DummyRequest([]) req.requestHeaders.addRawHeader('X-Mxit-User-Input', 'foo') self.assertEqual(self.transport.get_request_content(req), 'foo') def test_get_quote_plus_request_content_from_header(self): req = DummyRequest([]) req.requestHeaders.addRawHeader('X-Mxit-User-Input', 'foo+bar') self.assertEqual( self.transport.get_request_content(req), 'foo bar') def test_get_quoted_request_content_from_header(self): req = DummyRequest([]) req.requestHeaders.addRawHeader('X-Mxit-User-Input', 'foo%20bar') self.assertEqual( self.transport.get_request_content(req), 'foo bar') def test_get_request_content_from_args(self): req = DummyRequest([]) req.args = {'input': ['bar']} self.assertEqual(self.transport.get_request_content(req), 'bar') def test_get_request_content_when_missing(self): req = DummyRequest([]) self.assertEqual(self.transport.get_request_content(req), None) @inlineCallbacks def test_invalid_request(self): resp = yield http_request_full(self.url) self.assertEqual(resp.code, BAD_REQUEST) @inlineCallbacks def test_request(self): resp_d = http_request_full( self.url, headers=self.sample_req_headers) [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.tx_helper.make_dispatch_reply(msg, self.sample_menu_resp) resp = yield resp_d self.assertTrue('1. option 1' in resp.delivered_body) self.assertTrue('2. option 2' in resp.delivered_body) self.assertTrue('3. 
option 3' in resp.delivered_body) self.assertTrue('?input=1' in resp.delivered_body) self.assertTrue('?input=2' in resp.delivered_body) self.assertTrue('?input=3' in resp.delivered_body) def test_response_parser(self): header, items = ResponseParser.parse(self.sample_menu_resp) self.assertEqual(header, 'Hello!') self.assertEqual(items, [ ('1', 'option 1'), ('2', 'option 2'), ('3', 'option 3'), ]) header, items = ResponseParser.parse('foo!') self.assertEqual(header, 'foo!') self.assertEqual(items, []) @inlineCallbacks def test_unicode_rendering(self): resp_d = http_request_full( self.url, headers=self.sample_req_headers) [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.tx_helper.make_dispatch_reply(msg, self.sample_unicode_menu_resp) resp = yield resp_d self.assertTrue( 'Hell\xc3\xb8' in resp.delivered_body) self.assertTrue( '\xc3\xb8pti\xc3\xb8n 1' in resp.delivered_body) @inlineCallbacks def test_outbound_that_is_not_a_reply(self): d = self.tx_helper.make_dispatch_outbound( content="Send!", to_addr="mxit-1", from_addr="mxit-2") req = yield self.http_request_queue.get() body = json.load(req.content) self.assertEqual(body, { 'Body': 'Send!', 'To': 'mxit-1', 'From': 'mxit-2', 'ContainsMarkup': 'true', 'Spool': 'true', }) [auth] = req.requestHeaders.getRawHeaders('Authorization') # primed access token self.assertEqual(auth, 'Bearer foo') req.finish() yield d @inlineCallbacks def test_getting_access_token(self): transport = self.transport redis = transport.redis # clear primed value yield redis.delete(transport.access_token_key) d = transport.get_access_token() req = yield self.http_request_queue.get() [auth] = req.requestHeaders.getRawHeaders('Authorization') self.assertEqual( auth, 'Basic %s' % ( base64.b64encode('client_id:client_secret'))) self.assertEqual( ['grant_type=client_credentials', 'scope=message%2Fsend'], sorted(req.content.read().split('&'))) req.write(json.dumps({ 'access_token': 'access_token', 'expires_in': '10' })) req.finish() access_token = yield d self.assertEqual(access_token, 'access_token') self.assertFalse(isinstance(access_token, unicode)) ttl = yield redis.ttl(transport.access_token_key) self.assertTrue( 0 < ttl <= (transport.access_token_auto_decay * 10)) PK=JG&vumi/transports/mxit/tests/__init__.pyPK=JG[[vumi/transports/api/oldapi.py# -*- test-case-name: vumi.transports.api.tests.test_oldapi -*- import json import re from base64 import b64decode from twisted.python import log from twisted.web import http from vumi.transports.httprpc import HttpRpcTransport class OldSimpleHttpTransport(HttpRpcTransport): """ Maintains the API used by the old Django based method of loading SMS's into VUMI over HTTP Configuration options: web_path : str The path relative to the host where this listens web_port : int The port this listens on transport_name : str The name this transport instance will use to create it's queues identities : dictionary user : str password : str default_transport : str """ def validate_config(self): super(OldSimpleHttpTransport, self).validate_config() self.identities = self.config.get('identities', {}) def get_health_response(self): return json.dumps({}) def get_credentials(self, request): auth_header = 'Authorization' headers = request.requestHeaders if headers.hasHeader(auth_header): auth = headers.getRawHeaders(auth_header)[0] creds = b64decode(auth.split(' ')[-1]) return creds.split(':') else: return '', '' def is_authorized(self, username, password): return self.identities.get(username) == password def handle_outbound_message(self, 
message): log.msg("OldSimpleHttpTransport consuming %s" % (message)) return self.publish_ack(user_message_id=message['message_id'], sent_message_id=message['message_id']) def check_authorization(self, request): username, password = self.get_credentials(request) if self.identities and not self.is_authorized(username, password): return False, username return True, username def handle_raw_inbound_message(self, request_id, request): authorized, username = self.check_authorization(request) if not authorized: return self.finish_request(request_id, 'Not Authorized', code=http.UNAUTHORIZED) message = request.args.get('message', [None])[0] to_msisdns = request.args.get('to_msisdn', []) from_msisdn = request.args.get('from_msisdn', [None])[0] return_list = [] for to_msisdn in to_msisdns: message_id = self.generate_message_id() content = message to_addr = to_msisdn from_addr = from_msisdn log.msg( 'OldSimpleHttpTransport sending from %s to %s message "%s"' % ( from_addr, to_addr, content)) self.publish_message( message_id=message_id, content=content, to_addr=to_addr, from_addr=from_addr, provider='vumi', transport_type='old_simple_http', transport_metadata={ 'http_user': username, } ) return_list.append({ "message": message, "to_msisdn": to_msisdn, "from_msisdn": from_msisdn, "id": message_id, }) return self.finish_request(request_id, json.dumps(return_list)) class OldTemplateHttpTransport(OldSimpleHttpTransport): def handle_outbound_message(self, message): log.msg("OldTemplateHttpTransport consuming %s" % (message)) def extract_template_args(self, args, length): template_args = [] for i in range(length): template_args.append({}) for k, v in args.items(): if k.startswith("template_"): for i, x in enumerate(v): template_args[i][k] = x return template_args def handle_raw_inbound_message(self, request_id, request): authorized, username = self.check_authorization(request) if not authorized: return self.finish_request(request_id, 'Not Authorized', code=http.UNAUTHORIZED) opener = re.compile('{{ *') closer = re.compile(' *}}') template = request.args.get('template', [None])[0] template = opener.sub('%(template_', template) template = closer.sub(')s', template) to_msisdns = request.args.get('to_msisdn', []) from_msisdn = request.args.get('from_msisdn', [None])[0] template_args = self.extract_template_args(request.args, len(to_msisdns)) return_list = [] for i, to_msisdn in enumerate(to_msisdns): message_id = self.generate_message_id() message = content = template % template_args[i] to_addr = to_msisdn from_addr = from_msisdn log.msg(('OldTemplateHttpTransport sending from %s to %s ' 'message "%s"') % (from_addr, to_addr, content)) self.publish_message( message_id=message_id, content=content, to_addr=to_addr, from_addr=from_addr, provider='vumi', transport_type='old_template_http', transport_metadata={ 'http_user': username, } ) return_list.append({ "message": message, "to_msisdn": to_msisdn, "from_msisdn": from_msisdn, "id": message_id, }) return self.finish_request(request_id, json.dumps(return_list)) PK=JG8 vumi/transports/api/api.py# -*- test-case-name: vumi.transports.api.tests.test_api -*- import json from twisted.python import log from twisted.internet.defer import inlineCallbacks from vumi.transports.httprpc import HttpRpcTransport from vumi.config import ConfigBool, ConfigList, ConfigDict class HttpApiConfig(HttpRpcTransport.CONFIG_CLASS): "HTTP API configuration." 
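# For illustration, a minimal static config for this transport might look
# like this (web_path/web_port come from the HttpRpcTransport base config;
# the values shown are hypothetical, with field_defaults mirroring the
# tests below):
#
#     web_path: /api/v1/send
#     web_port: 8080
#     reply_expected: false
#     field_defaults:
#       to_addr: "555"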
reply_expected = ConfigBool( "True if a reply message is expected.", default=False, static=True) allowed_fields = ConfigList( "The list of fields a request is allowed to contain. Defaults to the" " DEFAULT_ALLOWED_FIELDS class attribute.", static=True) field_defaults = ConfigDict( "Default values for fields not sent by the client.", default={}, static=True) class HttpApiTransport(HttpRpcTransport): """ Native HTTP API for getting messages into vumi. NOTE: This has no security. Put it behind a firewall or something. If reply_expected is True, the transport will wait for a reply message and will return the reply's content as the HTTP response body. If False, the message_id of the dispatched incoming message will be returned. """ transport_type = 'http_api' ENCODING = 'utf-8' CONFIG_CLASS = HttpApiConfig DEFAULT_ALLOWED_FIELDS = ( 'content', 'to_addr', 'from_addr', 'group', 'session_event', ) def setup_transport(self): config = self.get_static_config() self.reply_expected = config.reply_expected allowed_fields = config.allowed_fields if allowed_fields is None: allowed_fields = self.DEFAULT_ALLOWED_FIELDS self.allowed_fields = set(allowed_fields) self.field_defaults = config.field_defaults return super(HttpApiTransport, self).setup_transport() def handle_outbound_message(self, message): if self.reply_expected: return super(HttpApiTransport, self).handle_outbound_message( message) log.msg("HttpApiTransport dropping outbound message: %s" % (message)) def get_api_field_values(self, request, required_fields): values = self.field_defaults.copy() errors = {} for field in request.args: if field not in self.allowed_fields: errors.setdefault('unexpected_parameter', []).append(field) else: values[field] = ( request.args.get(field)[0].decode(self.ENCODING)) for field in required_fields: if field not in values and field in self.allowed_fields: errors.setdefault('missing_parameter', []).append(field) return values, errors @inlineCallbacks def handle_raw_inbound_message(self, message_id, request): values, errors = self.get_api_field_values(request, ['content', 'to_addr', 'from_addr']) if errors: yield self.finish_request(message_id, json.dumps(errors), code=400) return log.msg(('HttpApiTransport sending from %(from_addr)s to %(to_addr)s ' 'message "%(content)s"') % values) payload = { 'message_id': message_id, 'transport_type': self.transport_type, } payload.update(values) yield self.publish_message(**payload) if not self.reply_expected: yield self.finish_request(message_id, json.dumps({'message_id': message_id})) PK=JGEyYYvumi/transports/api/__init__.py"""API transports to inject messages into VUMI.""" from vumi.transports.api.api import HttpApiTransport from vumi.transports.api.oldapi import (OldSimpleHttpTransport, OldTemplateHttpTransport) __all__ = ['HttpApiTransport', 'OldSimpleHttpTransport', 'OldTemplateHttpTransport'] PK=JG)@@%vumi/transports/api/tests/test_api.py# -*- encoding: utf-8 -*- import json from urllib import urlencode from twisted.internet.defer import inlineCallbacks from vumi.utils import http_request, http_request_full from vumi.tests.helpers import VumiTestCase from vumi.transports.api import HttpApiTransport from vumi.transports.tests.helpers import TransportHelper def config_override(**config): def deco(fun): fun.config_override = config return fun return deco class TestHttpApiTransport(VumiTestCase): @inlineCallbacks def setUp(self): self.config = { 'web_path': "foo", 'web_port': 0, } test_method = getattr(self, self._testMethodName) config_override = getattr(test_method, 
'config_override', {}) self.config.update(config_override) self.tx_helper = self.add_helper(TransportHelper(HttpApiTransport)) self.transport = yield self.tx_helper.get_transport(self.config) self.transport_url = self.transport.get_transport_url() def mkurl(self, content, from_addr=123, to_addr=555, **kw): params = { 'to_addr': to_addr, 'from_addr': from_addr, 'content': content, } params.update(kw) return self.mkurl_raw(**params) def mkurl_raw(self, **params): return '%s%s?%s' % ( self.transport_url, self.config['web_path'], urlencode(params) ) @inlineCallbacks def test_health(self): result = yield http_request( self.transport_url + "health", "", method='GET') self.assertEqual(json.loads(result), {'pending_requests': 0}) @inlineCallbacks def test_inbound(self): url = self.mkurl('hello') response = yield http_request(url, '', method='GET') [msg] = self.tx_helper.get_dispatched_inbound() self.assertEqual(msg['transport_name'], self.tx_helper.transport_name) self.assertEqual(msg['to_addr'], "555") self.assertEqual(msg['from_addr'], "123") self.assertEqual(msg['content'], "hello") self.assertEqual(json.loads(response), {'message_id': msg['message_id']}) @inlineCallbacks def test_handle_non_ascii_input(self): url = self.mkurl(u"öæł".encode("utf-8")) response = yield http_request(url, '', method='GET') [msg] = self.tx_helper.get_dispatched_inbound() self.assertEqual(msg['transport_name'], self.tx_helper.transport_name) self.assertEqual(msg['to_addr'], "555") self.assertEqual(msg['from_addr'], "123") self.assertEqual(msg['content'], u"öæł") self.assertEqual(json.loads(response), {'message_id': msg['message_id']}) @inlineCallbacks @config_override(reply_expected=True) def test_inbound_with_reply(self): d = http_request(self.mkurl('hello'), '', method='GET') [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) yield self.tx_helper.make_dispatch_reply(msg, "OK") response = yield d self.assertEqual(response, 'OK') @inlineCallbacks def test_good_optional_parameter(self): url = self.mkurl('hello', group='#channel') response = yield http_request(url, '', method='GET') [msg] = self.tx_helper.get_dispatched_inbound() self.assertEqual(msg['group'], '#channel') self.assertEqual(json.loads(response), {'message_id': msg['message_id']}) @inlineCallbacks def test_bad_parameter(self): url = self.mkurl('hello', foo='bar') response = yield http_request_full(url, '', method='GET') self.assertEqual(400, response.code) self.assertEqual(json.loads(response.delivered_body), {'unexpected_parameter': ['foo']}) @inlineCallbacks def test_missing_parameters(self): url = self.mkurl_raw(content='hello') response = yield http_request_full(url, '', method='GET') self.assertEqual(400, response.code) self.assertEqual(json.loads(response.delivered_body), {'missing_parameter': ['to_addr', 'from_addr']}) @inlineCallbacks @config_override(field_defaults={'to_addr': '555'}) def test_default_parameters(self): url = self.mkurl_raw(content='hello', from_addr='123') response = yield http_request(url, '', method='GET') [msg] = self.tx_helper.get_dispatched_inbound() self.assertEqual(msg['transport_name'], self.tx_helper.transport_name) self.assertEqual(msg['to_addr'], "555") self.assertEqual(msg['from_addr'], "123") self.assertEqual(msg['content'], "hello") self.assertEqual(json.loads(response), {'message_id': msg['message_id']}) @inlineCallbacks @config_override(field_defaults={'to_addr': '555'}, allowed_fields=['content', 'from_addr']) def test_disallowed_default_parameters(self): url = self.mkurl_raw(content='hello', 
from_addr='123') response = yield http_request(url, '', method='GET') [msg] = self.tx_helper.get_dispatched_inbound() self.assertEqual(msg['transport_name'], self.tx_helper.transport_name) self.assertEqual(msg['to_addr'], "555") self.assertEqual(msg['from_addr'], "123") self.assertEqual(msg['content'], "hello") self.assertEqual(json.loads(response), {'message_id': msg['message_id']}) @inlineCallbacks @config_override(allowed_fields=['content', 'from_addr']) def test_disallowed_parameters(self): url = self.mkurl('hello') response = yield http_request_full(url, '', method='GET') self.assertEqual(400, response.code) self.assertEqual(json.loads(response.delivered_body), {'unexpected_parameter': ['to_addr']}) PK=JG%vumi/transports/api/tests/__init__.pyPK=JG_B(vumi/transports/api/tests/test_oldapi.pyfrom base64 import b64encode import json from urllib import urlencode from twisted.internet.defer import inlineCallbacks from twisted.web import http from vumi.utils import http_request, http_request_full from vumi.tests.helpers import VumiTestCase from vumi.transports.api import ( OldSimpleHttpTransport, OldTemplateHttpTransport) from vumi.transports.tests.helpers import TransportHelper class TestOldSimpleHttpTransport(VumiTestCase): @inlineCallbacks def setUp(self): self.config = { 'web_path': "foo", 'web_port': 0, } self.tx_helper = self.add_helper( TransportHelper(OldSimpleHttpTransport)) self.transport = yield self.tx_helper.get_transport(self.config) addr = self.transport.web_resource.getHost() self.transport_url = "http://%s:%s/" % (addr.host, addr.port) @inlineCallbacks def test_health(self): result = yield http_request(self.transport_url + "health", "", method='GET') self.assertEqual(json.loads(result), {}) @inlineCallbacks def test_inbound(self): url = '%s%s?%s' % ( self.transport_url, self.config['web_path'], urlencode([ ('to_msisdn', 555), ('to_msisdn', 556), ('from_msisdn', 123), ('message', 'hello'), ]) ) response = yield http_request(url, '', method='GET') [msg1, msg2] = self.tx_helper.get_dispatched_inbound() self.assertEqual(msg1['transport_name'], self.tx_helper.transport_name) self.assertEqual(msg1['to_addr'], "555") self.assertEqual(msg2['to_addr'], "556") self.assertEqual(msg1['from_addr'], "123") self.assertEqual(msg1['content'], "hello") self.assertEqual(json.loads(response), [ { 'id': msg1['message_id'], 'message': msg1['content'], 'from_msisdn': msg1['from_addr'], 'to_msisdn': msg1['to_addr'], }, { 'id': msg2['message_id'], 'message': msg2['content'], 'from_msisdn': msg2['from_addr'], 'to_msisdn': msg2['to_addr'], }, ]) @inlineCallbacks def test_http_basic_auth(self): http_auth_config = self.config.copy() http_auth_config.update({ 'identities': { 'username': 'password', } }) transport = yield self.tx_helper.get_transport(http_auth_config) url = '%s%s?%s' % ( transport.get_transport_url(), self.config['web_path'], urlencode({ 'to_msisdn': '123', 'from_msisdn': '456', 'message': 'hello', })) response = yield http_request_full(url, '', method='GET') self.assertEqual(response.code, http.UNAUTHORIZED) self.assertEqual([], self.tx_helper.get_dispatched_inbound()) response = yield http_request_full(url, '', headers={ 'Authorization': ['Basic %s' % b64encode('username:password')] }, method='GET') self.assertEqual(response.code, http.OK) [msg] = self.tx_helper.get_dispatched_inbound() self.assertEqual(msg['content'], 'hello') self.assertEqual(msg['transport_metadata'], { 'http_user': 'username', }) class TestOldTemplateHttpTransport(VumiTestCase): @inlineCallbacks def setUp(self): 
self.config = { 'web_path': "foo", 'web_port': 0, } self.tx_helper = self.add_helper( TransportHelper(OldTemplateHttpTransport)) self.transport = yield self.tx_helper.get_transport(self.config) self.transport_url = self.transport.get_transport_url() @inlineCallbacks def test_inbound(self): url = '%s%s?%s' % ( self.transport_url, self.config['web_path'], urlencode([ ('to_msisdn', 555), ('to_msisdn', 556), ('template_name', "Joe"), ('template_name', "Foo"), ('template_surname', "Smith"), ('template_surname', "Bar"), ('from_msisdn', 123), ('template', 'hello {{ name }} {{surname}}'), ]) ) response = yield http_request(url, '', method='GET') [msg1, msg2] = self.tx_helper.get_dispatched_inbound() self.assertEqual(msg1['transport_name'], self.tx_helper.transport_name) self.assertEqual(msg1['to_addr'], "555") self.assertEqual(msg1['from_addr'], "123") self.assertEqual(msg1['content'], "hello Joe Smith") self.assertEqual(msg2['content'], "hello Foo Bar") self.assertEqual(json.loads(response), [ { 'id': msg1['message_id'], 'message': msg1['content'], 'from_msisdn': msg1['from_addr'], 'to_msisdn': msg1['to_addr'], }, { 'id': msg2['message_id'], 'message': msg2['content'], 'from_msisdn': msg2['from_addr'], 'to_msisdn': msg2['to_addr'], }, ]) PKqGַWs==$vumi/transports/smpp/smpp_service.pyfrom twisted.internet.defer import inlineCallbacks, returnValue, succeed from twisted.internet.task import LoopingCall from vumi.reconnecting_client import ReconnectingClientService from vumi.transports.smpp.protocol import ( EsmeProtocol, EsmeProtocolFactory, EsmeProtocolError) from vumi.transports.smpp.sequence import RedisSequence GSM_MAX_SMS_BYTES = 140 GSM_MAX_SMS_7BIT_CHARS = 160 class SmppService(ReconnectingClientService): throttle_statuses = ('ESME_RTHROTTLED', 'ESME_RMSGQFUL') def __init__(self, endpoint, bind_type, transport): self.transport = transport self.transport_name = transport.transport_name self.log = transport.log self.message_stash = self.transport.message_stash self.deliver_sm_processor = self.transport.deliver_sm_processor self.dr_processor = self.transport.dr_processor self.sequence_generator = RedisSequence(transport.redis) # Throttling setup. self.throttled = False self._throttled_pdus = [] self._unthrottle_delayedCall = None self.tps_counter = 0 self.tps_limit = self.get_config().mt_tps if self.tps_limit > 0: self.mt_tps_lc = LoopingCall(self.reset_mt_tps) else: self.mt_tps_lc = None # Connection setup. factory = EsmeProtocolFactory(self, bind_type) ReconnectingClientService.__init__(self, endpoint, factory) def get_protocol(self): return self._protocol def get_bind_state(self): if self._protocol is None: return EsmeProtocol.CLOSED_STATE return self._protocol.state def is_bound(self): if self._protocol is not None: return self._protocol.is_bound() return False def startService(self): if self.mt_tps_lc is not None: self.mt_tps_lc.clock = self.clock self.mt_tps_lc.start(1, now=True) return ReconnectingClientService.startService(self) def stopService(self): if self.mt_tps_lc and self.mt_tps_lc.running: self.mt_tps_lc.stop() d = succeed(None) if self._protocol is not None: d.addCallback(lambda _: self._protocol.disconnect()) d.addCallback(lambda _: ReconnectingClientService.stopService(self)) return d def get_config(self): return self.transport.get_static_config() @inlineCallbacks def reset_mt_tps(self): if self.throttled and self.need_mt_throttling(): if not self.is_bound(): # We don't have a bound SMPP connection, so try again later. 
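# (How the mt_tps throttle fits together: check_mt_throttling() bumps
# tps_counter on every submit_sm and calls start_throttling() once the
# counter reaches tps_limit; this method runs once a second via the
# mt_tps_lc LoopingCall and undoes that, resetting the counter and
# stopping throttling, but only while the SMPP connection is bound.)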
self.log.msg( "Can't stop throttling while unbound, trying later.") return self.reset_mt_throttle_counter() yield self.stop_throttling() def reset_mt_throttle_counter(self): self.tps_counter = 0 def incr_mt_throttle_counter(self): self.tps_counter += 1 def need_mt_throttling(self): return self.tps_counter >= self.tps_limit def check_mt_throttling(self): if self.get_config().mt_tps > 0: self.incr_mt_throttle_counter() if self.need_mt_throttling(): # We can't yield here, because we need the current message to # finish sending before it will return. self.start_throttling() def _append_throttle_retry(self, seq_no): if seq_no not in self._throttled_pdus: self._throttled_pdus.append(seq_no) def check_stop_throttling(self, delay=None): if self._unthrottle_delayedCall is not None: # We already have one of these scheduled. return if delay is None: delay = self.get_config().throttle_delay self._unthrottle_delayedCall = self.clock.callLater( delay, self._check_stop_throttling) def check_stop_throttling_cb(self, ignored_result, delay=None): self.check_stop_throttling(delay) @inlineCallbacks def _check_stop_throttling(self): """ Check if we should stop throttling, and stop throttling if we should. At a high level, we try each throttled message in our list until all of them have been accepted by the SMSC, at which point we stop throttling. In more detail: We recursively process our list of throttled PDU sequence numbers until either we have none left (at which point we stop throttling) or we find one we can successfully look up in our cache. When we find a message we can retry, we retry it and return. We remain throttled until the SMSC responds. If we're still throttled, its sequence number gets appended to our list and another check is scheduled for later. If we're no longer throttled, this method gets called again immediately. When there are no more throttled sequence numbers in our list, we stop throttling. """ self._unthrottle_delayedCall = None if not self.is_bound(): # We don't have a bound SMPP connection, so try again later. self.log.msg("Can't check throttling while unbound, trying later.") self.check_stop_throttling() return if not self._throttled_pdus: # We have no throttled messages waiting, so stop throttling. self.log.msg("No more throttled messages to retry.") yield self.stop_throttling() return seq_no = self._throttled_pdus.pop(0) pdu_data = yield self.message_stash.get_cached_pdu(seq_no) yield self.retry_throttled_pdu(pdu_data, seq_no) @inlineCallbacks def retry_throttled_pdu(self, pdu_data, seq_no): if pdu_data is None: # We can't find this pdu, so log it and start again. self.log.warning( "Could not retrieve throttled pdu: %s" % (seq_no,)) self.check_stop_throttling(0) else: # Try to handle this message again and leave the rest to our # submit_sm_resp handlers. self.log.msg("Retrying throttled pdu for message: %s" % ( pdu_data.vumi_message_id,)) # This is a new PDU, so it needs a new sequence number.
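# (Sequence numbers identify individual in-flight PDUs, and the old one
# is also the Redis key the PDU was cached under, so we mint a fresh
# number for the retry and delete the stale cache entry afterwards.)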
new_seq_no = yield self.sequence_generator.next() pdu_data.pdu.obj['header']['sequence_number'] = new_seq_no yield self._protocol.send_submit_sm( pdu_data.vumi_message_id, pdu_data.pdu) yield self.message_stash.delete_cached_pdu(seq_no) @inlineCallbacks def start_throttling(self): if self.throttled: return self.log.msg("Throttling outbound messages.") self.throttled = True yield self.transport.pause_connectors() yield self.transport.on_throttled() @inlineCallbacks def stop_throttling(self): if not self.throttled: return self.log.msg("No longer throttling outbound messages.") self.throttled = False self.transport.unpause_connectors() yield self.transport.on_throttled_end() @inlineCallbacks def on_smpp_bind(self): self.transport.unpause_connectors() yield self.transport.on_smpp_bind() @inlineCallbacks def on_smpp_binding(self): yield self.transport.on_smpp_binding() @inlineCallbacks def on_smpp_unbinding(self): yield self.transport.on_smpp_unbinding() @inlineCallbacks def on_smpp_bind_timeout(self): yield self.transport.on_smpp_bind_timeout() @inlineCallbacks def on_connection_lost(self, reason): yield self.transport.pause_connectors() yield self.transport.on_connection_lost(reason) def handle_submit_sm_resp(self, message_id, smpp_id, pdu_status, seq_no): if pdu_status in self.throttle_statuses: return self.handle_submit_sm_throttled(seq_no) func = self.transport.handle_submit_sm_failure if pdu_status == 'ESME_ROK': func = self.transport.handle_submit_sm_success ms = self.message_stash d = func(message_id, smpp_id, pdu_status) d.addCallback(lambda _: ms.delete_cached_pdu(seq_no)) d.addCallback(lambda _: ms.delete_sequence_number_message_id(seq_no)) return d.addCallback(self.check_stop_throttling_cb, 0) def handle_submit_sm_throttled(self, message_id): self._append_throttle_retry(message_id) d = self.start_throttling() return d.addCallback(self.check_stop_throttling_cb) def submit_sm(self, *args, **kw): """ See :meth:`EsmeProtocol.submit_sm`. """ protocol = self.get_protocol() if protocol is None: raise EsmeProtocolError('submit_sm called while not connected.') self.check_mt_throttling() return protocol.submit_sm(*args, **kw) def submit_sm_long(self, vumi_message_id, destination_addr, long_message, **pdu_params): """ Send a `submit_sm` command with the message encoded in the ``message_payload`` optional parameter. Same parameters apply as for ``submit_sm`` with the exception that the ``short_message`` keyword argument is disallowed because it conflicts with the ``long_message`` field. :returns: list of 1 sequence number, int. :rtype: list """ if 'short_message' in pdu_params: raise EsmeProtocolError( 'short_message not allowed when sending a long message' 'in the message_payload') optional_parameters = pdu_params.pop('optional_parameters', {}).copy() optional_parameters.update({ 'message_payload': ( ''.join('%02x' % ord(c) for c in long_message)) }) return self.submit_sm( vumi_message_id, destination_addr, short_message='', sm_length=0, optional_parameters=optional_parameters, **pdu_params) def _fits_in_one_message(self, message): if len(message) <= GSM_MAX_SMS_BYTES: return True # NOTE: We already have byte strings here, so we assume that printable # ASCII characters are all the same as single-width GSM 03.38 # characters. if len(message) <= GSM_MAX_SMS_7BIT_CHARS: # TODO: We need better character handling and counting stuff. 
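# (Capacity recap: anything up to GSM_MAX_SMS_BYTES (140) bytes always
# fits; a 141-160 byte message only fits if every byte is printable
# ASCII, i.e. counts as one GSM 03.38 septet out of the
# GSM_MAX_SMS_7BIT_CHARS (160) available.)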
return all(0x20 <= ord(ch) <= 0x7f for ch in message) return False def csm_split_message(self, message): """ Chop the message into 130 byte chunks to leave 10 bytes for the user data header the SMSC is presumably going to add for us. This is a guess based mostly on optimism and the hope that we'll never have to deal with this stuff in production. NOTE: If we have utf-8 encoded data, we might break in the middle of a multibyte character. This should be ok since the message is only decoded after re-assembly of all individual segments. :param str message: The message to split :returns: list of strings :rtype: list """ if self._fits_in_one_message(message): return [message] payload_length = GSM_MAX_SMS_BYTES - 10 split_msg = [] while message: split_msg.append(message[:payload_length]) message = message[payload_length:] return split_msg @inlineCallbacks def submit_csm_sar(self, vumi_message_id, destination_addr, **pdu_params): """ Submit a concatenated SMS to the SMSC using the optional SAR parameter names in the various PDUS. :returns: List of sequence numbers (int) for each of the segments. :rtype: list """ split_msg = self.csm_split_message(pdu_params.pop('short_message')) if len(split_msg) == 1: # There is only one part, so send it without SAR stuff. sequence_numbers = yield self.submit_sm( vumi_message_id, destination_addr, short_message=split_msg[0], **pdu_params) returnValue(sequence_numbers) optional_parameters = pdu_params.pop('optional_parameters', {}).copy() ref_num = yield self.sequence_generator.next() sequence_numbers = [] yield self.message_stash.init_multipart_info( vumi_message_id, len(split_msg)) for i, msg in enumerate(split_msg): pdu_params = pdu_params.copy() optional_parameters.update({ # Reference number must be between 00 & FFFF 'sar_msg_ref_num': (ref_num % 0xFFFF), 'sar_total_segments': len(split_msg), 'sar_segment_seqnum': i + 1, }) sequence_number = yield self.submit_sm( vumi_message_id, destination_addr, short_message=msg, optional_parameters=optional_parameters, **pdu_params) sequence_numbers.extend(sequence_number) returnValue(sequence_numbers) @inlineCallbacks def submit_csm_udh(self, vumi_message_id, destination_addr, **pdu_params): """ Submit a concatenated SMS to the SMSC using user data headers (UDH) in the message content. Same parameters apply as for ``submit_sm`` with the exception that the ``esm_class`` keyword argument is disallowed because the SMPP spec mandates a value that is to be set for UDH. :returns: List of sequence numbers (int) for each of the segments. :rtype: list """ if 'esm_class' in pdu_params: raise EsmeProtocolError( 'Cannot specify esm_class, GSM spec sets this at 0x40 ' 'for concatenated messages using UDH.') pdu_params = pdu_params.copy() split_msg = self.csm_split_message(pdu_params.pop('short_message')) if len(split_msg) == 1: # There is only one part, so send it without UDH stuff. sequence_numbers = yield self.submit_sm( vumi_message_id, destination_addr, short_message=split_msg[0], **pdu_params) returnValue(sequence_numbers) ref_num = yield self.sequence_generator.next() sequence_numbers = [] yield self.message_stash.init_multipart_info( vumi_message_id, len(split_msg)) for i, msg in enumerate(split_msg): # 0x40 is the UDHI flag indicating that this payload contains a # user data header. # NOTE: Looking at the SMPP specs I can find no requirement # for this anywhere. 
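# (For what it's worth: in SMPP 3.4 this bit of esm_class is the UDHI
# indicator, mirroring GSM 03.40's TP-UDHI flag. It tells the SMSC that
# short_message starts with a User Data Header, which is what makes the
# concatenation below work.)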
pdu_params['esm_class'] = 0x40 # See http://en.wikipedia.org/wiki/User_Data_Header and # http://en.wikipedia.org/wiki/Concatenated_SMS for an # explanation of the magic numbers below. We should probably # abstract this out into a class that makes it less magic and # opaque. udh = ''.join([ '\05', # Full UDH header length '\00', # Information Element Identifier for Concatenated SMS '\03', # header length # Reference number must be between 00 & FF chr(ref_num % 0xFF), chr(len(split_msg)), chr(i + 1), ]) short_message = udh + msg sequence_number = yield self.submit_sm( vumi_message_id, destination_addr, short_message=short_message, **pdu_params) sequence_numbers.extend(sequence_number) returnValue(sequence_numbers) PK=HٙYPYP&vumi/transports/smpp/smpp_transport.py# -*- test-case-name: vumi.transports.smpp.tests.test_smpp_transport -*- import json import warnings from uuid import uuid4 from twisted.internet import reactor from twisted.internet.defer import inlineCallbacks, returnValue, succeed from smpp.pdu import decode_pdu from smpp.pdu_builder import PDU from vumi.message import TransportUserMessage from vumi.persist.txredis_manager import TxRedisManager from vumi.transports.base import Transport from vumi.transports.smpp.config import SmppTransportConfig from vumi.transports.smpp.deprecated.transport import ( SmppTransportConfig as OldSmppTransportConfig) from vumi.transports.smpp.deprecated.utils import convert_to_new_config from vumi.transports.smpp.smpp_service import SmppService from vumi.transports.failures import FailureMessage def sequence_number_key(seq_no): return 'sequence_number:%s' % (seq_no,) def multipart_info_key(seq_no): return 'multipart_info:%s' % (seq_no,) def message_key(message_id): return 'message:%s' % (message_id,) def pdu_key(seq_no): return 'pdu:%s' % (seq_no,) def remote_message_key(message_id): return 'remote_message:%s' % (message_id,) class CachedPDU(object): """ A cached PDU with its associated vumi message_id. """ def __init__(self, vumi_message_id, pdu): self.vumi_message_id = vumi_message_id self.pdu = pdu self.seq_no = pdu.obj['header']['sequence_number'] @classmethod def from_json(cls, pdu_json): if pdu_json is None: return None pdu_data = json.loads(pdu_json) pdu = PDU(None, None, None) pdu.obj = decode_pdu(pdu_data['pdu']) return cls(pdu_data['vumi_message_id'], pdu) def to_json(self): return json.dumps({ 'vumi_message_id': self.vumi_message_id, # We store the PDU in wire format to avoid json encoding troubles. 'pdu': self.pdu.get_hex(), }) class SmppMessageDataStash(object): """ Stash message data in Redis. """ def __init__(self, redis, config): self.redis = redis self.config = config def init_multipart_info(self, message_id, part_count): key = multipart_info_key(message_id) expiry = self.config.submit_sm_expiry d = self.redis.hmset(key, { 'parts': part_count, }) d.addCallback(lambda _: self.redis.expire(key, expiry)) return d def get_multipart_info(self, message_id): key = multipart_info_key(message_id) return self.redis.hgetall(key) def _update_multipart_info_success_cb(self, mp_info, key, remote_id): if not mp_info: # No multipart data, so do nothing. 
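# (Layout of the multipart_info:<message_id> hash: 'parts' holds the
# segment count set in init_multipart_info; each submit_sm_resp adds a
# 'part:<remote_id>' field with value 'ack' or 'fail'; 'event_result'
# and 'event_counter' are bookkeeping used below to pick the single
# event to publish.)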
return part_key = 'part:%s' % (remote_id,) mp_info[part_key] = 'ack' d = self.redis.hset(key, part_key, 'ack') d.addCallback(lambda _: mp_info) return d def update_multipart_info_success(self, message_id, remote_id): key = multipart_info_key(message_id) d = self.get_multipart_info(message_id) d.addCallback(self._update_multipart_info_success_cb, key, remote_id) return d def _update_multipart_info_failure_cb(self, mp_info, key, remote_id): if not mp_info: # No multipart data, so do nothing. return part_key = 'part:%s' % (remote_id,) mp_info[part_key] = 'fail' d = self.redis.hset(key, part_key, 'fail') d.addCallback(lambda _: self.redis.hset(key, 'event_result', 'fail')) d.addCallback(lambda _: mp_info) return d def update_multipart_info_failure(self, message_id, remote_id): key = multipart_info_key(message_id) d = self.get_multipart_info(message_id) d.addCallback(self._update_multipart_info_failure_cb, key, remote_id) return d def _determine_multipart_event_cb(self, mp_info, message_id, event_type, remote_id): if not mp_info: # We don't seem to have a multipart message, so just return the # single-message data. return (True, event_type, remote_id) part_status_dict = dict( (k[5:], v) for k, v in mp_info.items() if k.startswith('part:')) remote_id = ','.join(sorted(part_status_dict.keys())) event_result = mp_info.get('event_result', None) if event_result is not None: # We already have a result, even if we don't have all the parts. event_type = event_result elif len(part_status_dict) >= int(mp_info['parts']): # We have all the parts, so we can determine the event type. if all(pv == 'ack' for pv in part_status_dict.values()): # All parts happy. event_type = 'ack' else: # At least one part failed. event_type = 'fail' else: # We don't have all the parts yet. return (False, None, None) # There's a race condition when we process multiple submit_sm_resps for # parts of the same messages concurrently. We only want to send one # event, so we do an atomic increment and ignore the event if we're # not the first to succeed. d = self.redis.hincrby( multipart_info_key(message_id), 'event_counter', 1) def confirm_multipart_event_cb(counter_value): if int(counter_value) == 1: return (True, event_type, remote_id) else: return (False, None, None) d.addCallback(confirm_multipart_event_cb) return d def get_multipart_event_info(self, message_id, event_type, remote_id): d = self.get_multipart_info(message_id) d.addCallback( self._determine_multipart_event_cb, message_id, event_type, remote_id) return d def expire_multipart_info(self, message_id): """ Set the TTL on multipart info hash to something small. We don't delete this in case there's still an in-flight operation that will recreate it without a TTL. 
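(A short TTL rather than an outright delete keeps the hash readable by any submit_sm_resp handler that is still in flight, while still guaranteeing eventual cleanup once everything settles.)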
""" expiry = self.config.completed_multipart_info_expiry return self.redis.expire(multipart_info_key(message_id), expiry) def set_sequence_number_message_id(self, sequence_number, message_id): key = sequence_number_key(sequence_number) expiry = self.config.submit_sm_expiry return self.redis.setex(key, expiry, message_id) def get_sequence_number_message_id(self, sequence_number): return self.redis.get(sequence_number_key(sequence_number)) def delete_sequence_number_message_id(self, sequence_number): return self.redis.delete(sequence_number_key(sequence_number)) def cache_message(self, message): key = message_key(message['message_id']) expiry = self.config.submit_sm_expiry return self.redis.setex(key, expiry, message.to_json()) def get_cached_message(self, message_id): d = self.redis.get(message_key(message_id)) d.addCallback(lambda json_data: ( TransportUserMessage.from_json(json_data) if json_data else None)) return d def delete_cached_message(self, message_id): return self.redis.delete(message_key(message_id)) def cache_pdu(self, vumi_message_id, pdu): cached_pdu = CachedPDU(vumi_message_id, pdu) key = pdu_key(cached_pdu.seq_no) expiry = self.config.submit_sm_expiry return self.redis.setex(key, expiry, cached_pdu.to_json()) def get_cached_pdu(self, seq_no): d = self.redis.get(pdu_key(seq_no)) return d.addCallback(CachedPDU.from_json) def delete_cached_pdu(self, seq_no): return self.redis.delete(pdu_key(seq_no)) def set_remote_message_id(self, message_id, smpp_message_id): if message_id is None: # If we store None, we end up with the string "None" in Redis. This # confuses later lookups (which treat any non-None value as a valid # identifier) and results in broken delivery reports. return succeed(None) key = remote_message_key(smpp_message_id) expire = self.config.third_party_id_expiry d = self.redis.setex(key, expire, message_id) d.addCallback(lambda _: message_id) return d def get_internal_message_id(self, smpp_message_id): return self.redis.get(remote_message_key(smpp_message_id)) def delete_remote_message_id(self, smpp_message_id): key = remote_message_key(smpp_message_id) return self.redis.delete(key) def expire_remote_message_id(self, smpp_message_id): key = remote_message_key(smpp_message_id) expire = self.config.final_dr_third_party_id_expiry return self.redis.expire(key, expire) class SmppTransceiverTransport(Transport): CONFIG_CLASS = SmppTransportConfig bind_type = 'TRX' clock = reactor start_message_consumer = False @property def throttled(self): return self.service.throttled @inlineCallbacks def setup_transport(self): yield self.publish_status_starting() config = self.get_static_config() self.log.msg( 'Starting SMPP Transport for: %s' % (config.twisted_endpoint,)) default_prefix = '%s@%s' % (config.system_id, config.transport_name) redis_prefix = config.split_bind_prefix or default_prefix self.redis = (yield TxRedisManager.from_config( config.redis_manager)).sub_manager(redis_prefix) self.dr_processor = config.delivery_report_processor( self, config.delivery_report_processor_config) self.deliver_sm_processor = config.deliver_short_message_processor( self, config.deliver_short_message_processor_config) self.submit_sm_processor = config.submit_short_message_processor( self, config.submit_short_message_processor_config) self.disable_ack = config.disable_ack self.disable_delivery_report = config.disable_delivery_report self.message_stash = SmppMessageDataStash(self.redis, config) self.service = self.start_service() def start_service(self): config = self.get_static_config() service 
= SmppService(config.twisted_endpoint, self.bind_type, self) service.clock = self.clock service.startService() return service @inlineCallbacks def teardown_transport(self): if self.service: yield self.service.stopService() yield self.redis._close() def _check_address_valid(self, message, field): try: message[field].encode('ascii') except UnicodeError: return False return True def _reject_for_invalid_address(self, message, field): return self.publish_nack( message['message_id'], u'Invalid %s: %s' % (field, message[field])) @inlineCallbacks def on_smpp_binding(self): yield self.publish_status_binding() @inlineCallbacks def on_smpp_unbinding(self): yield self.publish_status_unbinding() @inlineCallbacks def on_smpp_bind(self): yield self.publish_status_bound() if self.throttled: yield self.publish_throttled() @inlineCallbacks def on_throttled(self): yield self.publish_throttled() @inlineCallbacks def on_throttled_resume(self): yield self.publish_throttled() @inlineCallbacks def on_throttled_end(self): yield self.publish_throttled_end() @inlineCallbacks def on_smpp_bind_timeout(self): yield self.publish_status_bind_timeout() @inlineCallbacks def on_connection_lost(self, reason): yield self.publish_status_connection_lost(reason) def publish_status_starting(self): return self.publish_status( status='down', component='smpp', type='starting', message='Starting') def publish_status_binding(self): return self.publish_status( status='down', component='smpp', type='binding', message='Binding') def publish_status_unbinding(self): return self.publish_status( status='down', component='smpp', type='unbinding', message='Unbinding') def publish_status_bound(self): return self.publish_status( status='ok', component='smpp', type='bound', message='Bound') def publish_throttled(self): return self.publish_status( status='degraded', component='smpp', type='throttled', message='Throttled') def publish_throttled_end(self): return self.publish_status( status='ok', component='smpp', type='throttled_end', message='No longer throttled') def publish_status_bind_timeout(self): return self.publish_status( status='down', component='smpp', type='bind_timeout', message='Timed out awaiting bind') def publish_status_connection_lost(self, reason): return self.publish_status( status='down', component='smpp', type='connection_lost', message=str(reason.value)) @inlineCallbacks def handle_outbound_message(self, message): if not self._check_address_valid(message, 'to_addr'): yield self._reject_for_invalid_address(message, 'to_addr') return if not self._check_address_valid(message, 'from_addr'): yield self._reject_for_invalid_address(message, 'from_addr') return yield self.message_stash.cache_message(message) yield self.submit_sm_processor.handle_outbound_message( message, self.service) @inlineCallbacks def process_submit_sm_event(self, message_id, event_type, remote_id, command_status): if event_type == 'ack': yield self.message_stash.delete_cached_message(message_id) yield self.message_stash.expire_multipart_info(message_id) if not self.disable_ack: yield self.publish_ack(message_id, remote_id) else: if event_type != 'fail': self.log.warning( "Unexpected multipart event type %r, assuming 'fail'" % ( event_type,)) err_msg = yield self.message_stash.get_cached_message(message_id) command_status = command_status or 'Unspecified' if err_msg is None: self.log.warning( "Could not retrieve failed message: %s" % (message_id,)) else: yield self.message_stash.delete_cached_message(message_id) yield 
self.message_stash.expire_multipart_info(message_id) yield self.publish_nack(message_id, command_status) yield self.failure_publisher.publish_message( FailureMessage(message=err_msg.payload, failure_code=None, reason=command_status)) @inlineCallbacks def handle_submit_sm_success(self, message_id, smpp_message_id, command_status): yield self.message_stash.update_multipart_info_success( message_id, smpp_message_id) event_info = yield self.message_stash.get_multipart_event_info( message_id, 'ack', smpp_message_id) event_required, event_type, remote_id = event_info if event_required: yield self.process_submit_sm_event( message_id, event_type, remote_id, command_status) @inlineCallbacks def handle_submit_sm_failure(self, message_id, smpp_message_id, command_status): yield self.message_stash.update_multipart_info_failure( message_id, smpp_message_id) event_info = yield self.message_stash.get_multipart_event_info( message_id, 'fail', smpp_message_id) event_required, event_type, remote_id = event_info if event_required: yield self.process_submit_sm_event( message_id, event_type, remote_id, command_status) def handle_raw_inbound_message(self, **kwargs): # TODO: drop the kwargs, list the allowed key word arguments # explicitly with sensible defaults. message_type = kwargs.get('message_type', 'sms') message = { 'message_id': uuid4().hex, 'to_addr': kwargs['destination_addr'], 'from_addr': kwargs['source_addr'], 'content': kwargs['short_message'], 'transport_type': message_type, 'transport_metadata': {}, } if message_type == 'ussd': session_event = { 'new': TransportUserMessage.SESSION_NEW, 'continue': TransportUserMessage.SESSION_RESUME, 'close': TransportUserMessage.SESSION_CLOSE, }[kwargs['session_event']] message['session_event'] = session_event session_info = kwargs.get('session_info') message['transport_metadata']['session_info'] = session_info # TODO: This logs messages that fail to serialize to JSON # Usually this happens when an SMPP message has content # we can't decode (e.g. data_coding == 4). We should # remove the try-except once we handle such messages # better. return self.publish_message(**message).addErrback(self.log.err) @inlineCallbacks def handle_delivery_report( self, receipted_message_id, delivery_status, smpp_delivery_status): message_id = yield self.message_stash.get_internal_message_id( receipted_message_id) if message_id is None: self.log.warning( "Failed to retrieve message id for delivery report." " Delivery report from %s discarded." % self.transport_name) return if self.disable_delivery_report: dr = None else: dr = yield self.publish_delivery_report( user_message_id=message_id, delivery_status=delivery_status, transport_metadata={ 'smpp_delivery_status': smpp_delivery_status, }) if delivery_status in ('delivered', 'failed'): yield self.message_stash.expire_remote_message_id( receipted_message_id) returnValue(dr) class SmppReceiverTransport(SmppTransceiverTransport): bind_type = 'RX' class SmppTransmitterTransport(SmppTransceiverTransport): bind_type = 'TX' class SmppTransceiverTransportWithOldConfig(SmppTransceiverTransport): CONFIG_CLASS = OldSmppTransportConfig NEW_CONFIG_CLASS = SmppTransportConfig def __init__(self, *args, **kwargs): super(SmppTransceiverTransportWithOldConfig, self).__init__(*args, **kwargs) warnings.warn( 'This is a transport using a deprecated config file. 
' 'Please use the new SmppTransceiverTransport, ' 'SmppTransmitterTransport or SmppReceiverTransport ' 'with the new processor aware SmppTransportConfig.', category=PendingDeprecationWarning) def get_static_config(self): # return if cached if hasattr(self, '_converted_static_config'): return self._converted_static_config cfg = super( SmppTransceiverTransportWithOldConfig, self).get_static_config() original = cfg._config_data.copy() config = convert_to_new_config( original, 'vumi.transports.smpp.processors.DeliveryReportProcessor', 'vumi.transports.smpp.processors.SubmitShortMessageProcessor', 'vumi.transports.smpp.processors.DeliverShortMessageProcessor' ) self._converted_static_config = self.NEW_CONFIG_CLASS( config, static=True) return self._converted_static_config PKH`zmWmW vumi/transports/smpp/protocol.py# -*- test-case-name: vumi.transports.smpp.tests.test_protocol -*- from functools import wraps from twisted.internet.protocol import Protocol, ClientFactory from twisted.internet.task import LoopingCall from twisted.internet.defer import ( inlineCallbacks, returnValue, maybeDeferred, DeferredQueue, succeed) from smpp.pdu import unpack_pdu from smpp.pdu_builder import ( BindTransceiver, BindReceiver, BindTransmitter, UnbindResp, Unbind, DeliverSMResp, EnquireLink, EnquireLinkResp, SubmitSM, QuerySM) from vumi.transports.smpp.pdu_utils import ( pdu_ok, seq_no, command_status, command_id, message_id, chop_pdu_stream) def require_bind(func): @wraps(func) def wrapper(self, *args, **kwargs): if not self.is_bound(): raise EsmeProtocolError('%s called in unbound state.' % (func,)) return func(self, *args, **kwargs) return wrapper class EsmeProtocolError(Exception): pass class EsmeProtocol(Protocol): noisy = True unbind_timeout = 2 OPEN_STATE = 'OPEN' CLOSED_STATE = 'CLOSED' BOUND_STATE_TRX = 'BOUND_TRX' BOUND_STATE_TX = 'BOUND_TX' BOUND_STATE_RX = 'BOUND_RX' BOUND_STATES = set([ BOUND_STATE_RX, BOUND_STATE_TX, BOUND_STATE_TRX, ]) _BIND_PDU = { 'TX': BindTransmitter, 'RX': BindReceiver, 'TRX': BindTransceiver, } def __init__(self, service, bind_type): """ An SMPP 3.4 client suitable for use by a Vumi Transport. :param SmppService service: The SMPP service that is using this protocol to communicate with an SMSC. """ self.service = service self.log = service.log self.bind_pdu = self._BIND_PDU[bind_type] self.clock = service.clock self.config = self.service.get_config() self.buffer = b'' self.state = self.CLOSED_STATE self.deliver_sm_processor = self.service.deliver_sm_processor self.dr_processor = self.service.dr_processor self.sequence_generator = self.service.sequence_generator self.enquire_link_call = LoopingCall(self.enquire_link) self.drop_link_call = None self.idle_timeout = self.config.smpp_enquire_link_interval * 2 self.disconnect_call = None self.unbind_resp_queue = DeferredQueue() def emit(self, msg): if self.noisy: self.log.debug(msg) @inlineCallbacks def connectionMade(self): self.state = self.OPEN_STATE self.log.msg('Connection made, current state: %s' % (self.state,)) self.bind( system_id=self.config.system_id, password=self.config.password, system_type=self.config.system_type, interface_version=self.config.interface_version, address_range=self.config.address_range) yield self.service.on_smpp_binding() @inlineCallbacks def bind(self, system_id, password, system_type, interface_version='34', addr_ton='', addr_npi='', address_range=''): """ Send the `bind_transmitter`, `bind_transceiver` or `bind_receiver` PDU to the SMSC in order to establish the connection. 
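A sketch of the resulting exchange for a transceiver bind (field values illustrative only, borrowed from this package's tests; the response is handled by ``handle_bind_transceiver_resp`` below)::

    ESME -> SMSC:  bind_transceiver(system_id='system_id',
                                    password='password',
                                    interface_version='34')
    SMSC -> ESME:  bind_transceiver_resp(command_status='ESME_ROK')

On a successful response the protocol enters the matching bound state and starts the periodic ``enquire_link`` loop.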
:param str system_id: Identifies the ESME system requesting to bind. :param str password: The password may be used by the SMSC to authenticate the ESME requesting to bind. If this is longer than 8 characters, it will be truncated and a warning will be logged. :param str system_type: Identifies the type of ESME system requesting to bind with the SMSC. :param str interface_version: Indicates the version of the SMPP protocol supported by the ESME. :param str addr_ton: Indicates Type of Number of the ESME address. :param str addr_npi: Numbering Plan Indicator for ESME address. :param str address_range: The ESME address. """ # Overly long passwords should be truncated. if len(password) > 8: password = password[:8] self.log.warning("Password longer than 8 characters, truncating.") sequence_number = yield self.sequence_generator.next() pdu = self.bind_pdu( sequence_number, system_id=system_id, password=password, system_type=system_type, interface_version=interface_version, addr_ton=addr_ton, addr_npi=addr_npi, address_range=address_range) self.send_pdu(pdu) self.drop_link_call = self.clock.callLater( self.config.smpp_bind_timeout, self.drop_link) @inlineCallbacks def drop_link(self): """ Called if the SMPP connection is not bound within ``smpp_bind_timeout`` seconds. """ if self.is_bound(): return yield self.service.on_smpp_bind_timeout() yield self.disconnect( 'Dropping link due to binding delay. Current state: %s' % ( self.state)) def disconnect(self, log_msg=None): """ Forcibly close the connection, logging ``log_msg`` if provided. :param str log_msg: The entry to write to the log file. """ if log_msg is not None: self.log.warning(log_msg) if not self.connected: return succeed(self.transport.loseConnection()) d = self.unbind() d.addCallback(lambda _: self.unbind_resp_queue.get()) d.addBoth(lambda *a: self.transport.loseConnection()) # Give the SMSC a few seconds to respond with an unbind_resp. self.clock.callLater(self.unbind_timeout, d.cancel) return d def connectionLost(self, reason): """ :param Exception reason: The reason the connection was closed, generally a ``ConnectionDone``. """ self.state = self.CLOSED_STATE if self.enquire_link_call.running: self.enquire_link_call.stop() if self.drop_link_call is not None and self.drop_link_call.active(): self.drop_link_call.cancel() if self.disconnect_call is not None and self.disconnect_call.active(): self.disconnect_call.cancel() return self.service.on_connection_lost(reason) def is_bound(self): """ Returns ``True`` if the connection is in one of the bound states listed in ``self.BOUND_STATES``. """ return self.state in self.BOUND_STATES @require_bind @inlineCallbacks def enquire_link(self): """ Ping the SMSC to see if they're still around. """ sequence_number = yield self.sequence_generator.next() self.send_pdu(EnquireLink(sequence_number)) returnValue(sequence_number) def send_pdu(self, pdu): """ Send a PDU to the SMSC. :param smpp.pdu_builder.PDU pdu: The PDU object to send. """ self.emit('OUTGOING >> %r' % (pdu.get_obj(),)) return self.transport.write(pdu.get_bin()) def dataReceived(self, data): self.buffer += data data = self.handle_buffer() while data is not None: self.on_pdu(unpack_pdu(data)) data = self.handle_buffer() def handle_buffer(self): pdu_found = chop_pdu_stream(self.buffer) if pdu_found is None: return data, self.buffer = pdu_found return data def on_pdu(self, pdu): """ Handle a PDU that was received & decoded. 
:param dict pdu: The dict result one gets when calling ``smpp.pdu.unpack_pdu()`` on the received PDU """ self.emit('INCOMING << %r' % (pdu,)) handler = getattr(self, 'handle_%s' % (command_id(pdu),), self.on_unsupported_command_id) return maybeDeferred(handler, pdu) def on_unsupported_command_id(self, pdu): """ Called when an SMPP PDU is received for which no handler function has been defined. :param dict pdu: The dict result one gets when calling ``smpp.pdu.unpack_pdu()`` on the received PDU """ self.log.warning( 'Received unsupported SMPP command_id: %r' % (command_id(pdu),)) def handle_bind_transceiver_resp(self, pdu): if not pdu_ok(pdu): self.log.warning('Unable to bind: %r' % (command_status(pdu),)) self.transport.loseConnection() return self.state = self.BOUND_STATE_TRX return self.on_smpp_bind(seq_no(pdu)) def handle_bind_transmitter_resp(self, pdu): if not pdu_ok(pdu): self.log.warning('Unable to bind: %r' % (command_status(pdu),)) self.transport.loseConnection() return self.state = self.BOUND_STATE_TX return self.on_smpp_bind(seq_no(pdu)) def handle_bind_receiver_resp(self, pdu): if not pdu_ok(pdu): self.log.warning('Unable to bind: %r' % (command_status(pdu),)) self.transport.loseConnection() return self.state = self.BOUND_STATE_RX return self.on_smpp_bind(seq_no(pdu)) def on_smpp_bind(self, sequence_number): """Called when the bind has been set up.""" self.drop_link_call.cancel() self.disconnect_call = self.clock.callLater( self.idle_timeout, self.disconnect, 'Disconnecting, no response from SMSC for longer ' 'than %s seconds' % (self.idle_timeout,)) self.enquire_link_call.clock = self.clock self.enquire_link_call.start(self.config.smpp_enquire_link_interval) return self.service.on_smpp_bind() def handle_unbind(self, pdu): return self.send_pdu(UnbindResp(seq_no(pdu))) def handle_submit_sm_resp(self, pdu): return self.on_submit_sm_resp( seq_no(pdu), message_id(pdu), command_status(pdu)) def on_submit_sm_resp(self, sequence_number, smpp_message_id, command_status): """ Called when a ``submit_sm_resp`` command was received. :param int sequence_number: The sequence_number of the command, which should correlate with the sequence_number of the ``submit_sm`` command that this is a response to. :param str smpp_message_id: The message id that the SMSC is using for this message. This will be referred to in the delivery reports (if any). :param str command_status: The SMPP command_status for this command. Will determine if the ``submit_sm`` command was successful or not. Refer to the SMPP specification for the full list of options. """ message_stash = self.service.message_stash d = message_stash.get_sequence_number_message_id(sequence_number) # Only set the remote message id if the submission was successful; we # use remote message ids for delivery reports, so we won't need remote # message ids for failed submissions. if command_status == 'ESME_ROK': d.addCallback( message_stash.set_remote_message_id, smpp_message_id) d.addCallback( self._handle_submit_sm_resp_callback, smpp_message_id, command_status, sequence_number) return d def _handle_submit_sm_resp_callback(self, message_id, smpp_message_id, command_status, sequence_number): if message_id is None: # We have no message_id, so log a warning instead of calling the # callback. self.log.warning( "Failed to retrieve message id for submit_sm_resp." " ack/nack from %s discarded." 
% self.service.transport_name) else: return self.service.handle_submit_sm_resp( message_id, smpp_message_id, command_status, sequence_number) @inlineCallbacks def handle_deliver_sm(self, pdu): # These operate before the PDUs ``short_message`` or # ``message_payload`` fields have been string decoded. # NOTE: order is important! pdu_handler_chain = [ self.dr_processor.handle_delivery_report_pdu, self.deliver_sm_processor.handle_multipart_pdu, self.deliver_sm_processor.handle_ussd_pdu, ] for handler in pdu_handler_chain: handled = yield handler(pdu) if handled: self.send_pdu(DeliverSMResp(seq_no(pdu), command_status='ESME_ROK')) return # At this point we either have a DR in the message payload # or have a normal SMS that needs to be decoded and handled. content_parts = self.deliver_sm_processor.decode_pdus([pdu]) if not all([isinstance(part, unicode) for part in content_parts]): command_status = self.config.deliver_sm_decoding_error self.log.msg( 'Not all parts of the PDU were able to be decoded. ' 'Responding with %s.' % (command_status,), parts=content_parts) self.send_pdu(DeliverSMResp(seq_no(pdu), command_status=command_status)) return content = u''.join(content_parts) was_cdr = yield self.dr_processor.handle_delivery_report_content( content) if was_cdr: self.send_pdu(DeliverSMResp(seq_no(pdu), command_status='ESME_ROK')) return handled = yield self.deliver_sm_processor.handle_short_message_pdu(pdu) if handled: self.send_pdu(DeliverSMResp(seq_no(pdu), command_status="ESME_ROK")) return command_status = self.config.deliver_sm_decoding_error self.log.warning( 'Unable to process message. ' 'Responding with %s.' % (command_status,), content=content, pdu=pdu.get_obj()) self.send_pdu(DeliverSMResp(seq_no(pdu), command_status=command_status)) def handle_enquire_link(self, pdu): return self.send_pdu(EnquireLinkResp(seq_no(pdu))) def handle_enquire_link_resp(self, pdu): self.disconnect_call.reset(self.idle_timeout) @require_bind @inlineCallbacks def submit_sm(self, vumi_message_id, destination_addr, source_addr='', esm_class=0, protocol_id=0, priority_flag=0, schedule_delivery_time='', validity_period='', replace_if_present=0, data_coding=0, sm_default_msg_id=0, sm_length=0, short_message='', optional_parameters=None, **configured_parameters ): """ Put a `submit_sm` command on the wire. :param str source_addr: Address of SME which originated this message. If unknown leave blank. :param str destination_addr: Destination address of this short message. For mobile terminated messages, this is the directory number of the recipient MS. :param str service_type: The service_type parameter can be used to indicate the SMS Application service associated with the message. If unknown leave blank. :param int source_addr_ton: Type of Number for source address. :param int source_addr_npi: Numbering Plan Indicator for source address. :param int dest_addr_ton: Type of Number for destination. :param int dest_addr_npi: Numbering Plan Indicator for destination. :param int esm_class: Indicates Message Mode & Message Type. :param int protocol_id: Protocol Identifier. Network specific field. :param int priority_flag: Designates the priority level of the message. :param str schedule_delivery_time: The short message is to be scheduled by the SMSC for delivery. Leave blank for immediate delivery. :param str validity_period: The validity period of this message. Leave blank for SMSC default. :param int registered_delivery: Indicator to signify if an SMSC delivery receipt or an SME acknowledgement is required. 
:param int replace_if_present: Flag indicating if submitted message should replace an existing message. :param int data_coding: Defines the encoding scheme of the short message user data. :param int sm_default_msg_id: Indicates the short message to send from a list of pre-defined ('canned') short messages stored on the SMSC. Leave blank if not using an SMSC canned message. :param int sm_length: Length in octets of the short_message user data. This is automatically calculated and set during PDU encoding; no need to specify. :param str short_message: Up to 254 octets of short message user data. The exact physical limit for short_message size may vary according to the underlying network. Applications which need to send messages longer than 254 octets should use the message_payload parameter. In this case the sm_length field should be set to zero. :param dict optional_parameters: Keys and values to be embedded in the PDU as tag-length-values. Refer to the SMPP specification and your SMSC's instructions on what valid and suitable keys and values are. :returns: list of 1 sequence number (int) for consistency with other submit_sm calls. :rtype: list """ configured_param_values = { 'service_type': self.config.service_type, 'source_addr_ton': self.config.source_addr_ton, 'source_addr_npi': self.config.source_addr_npi, 'dest_addr_ton': self.config.dest_addr_ton, 'dest_addr_npi': self.config.dest_addr_npi, 'registered_delivery': self.config.registered_delivery, } configured_param_values.update(configured_parameters) sequence_number = yield self.sequence_generator.next() pdu = SubmitSM( sequence_number=sequence_number, source_addr=source_addr, destination_addr=destination_addr, esm_class=esm_class, protocol_id=protocol_id, priority_flag=priority_flag, schedule_delivery_time=schedule_delivery_time, validity_period=validity_period, replace_if_present=replace_if_present, data_coding=data_coding, sm_default_msg_id=sm_default_msg_id, sm_length=sm_length, short_message=short_message, **configured_param_values) if optional_parameters: for key, value in optional_parameters.items(): pdu.add_optional_parameter(key, value) yield self.send_submit_sm(vumi_message_id, pdu) returnValue([sequence_number]) @inlineCallbacks def send_submit_sm(self, vumi_message_id, pdu): yield self.service.message_stash.cache_pdu(vumi_message_id, pdu) yield self.service.message_stash.set_sequence_number_message_id( seq_no(pdu.obj), vumi_message_id) self.send_pdu(pdu) @require_bind @inlineCallbacks def query_sm(self, message_id, source_addr_ton=0, source_addr_npi=0, source_addr='' ): """ Query the SMSC for the status of a previously sent message. :param str message_id: Message ID of the message whose state is to be queried. This must be the SMSC-assigned Message ID allocated to the original short message when submitted to the SMSC by the submit_sm, data_sm or submit_multi command, and returned in the response PDU by the SMSC. :param int source_addr_ton: Type of Number of message originator. This is used for verification purposes, and must match that supplied in the original request PDU (e.g. submit_sm). :param int source_addr_npi: Numbering Plan Indicator of message originator. This is used for verification purposes, and must match that supplied in the original request PDU (e.g. submit_sm). :param str source_addr: Address of message originator. This is used for verification purposes, and must match that supplied in the original request PDU (e.g. submit_sm). 
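A minimal usage sketch (identifiers illustrative, mirroring the call made in this package's tests; assumes a bound protocol instance)::

    seq_nums = yield protocol.query_sm('abc123', source_addr='12345')

Note that no ``handle_query_sm_resp`` is defined on this protocol, so as the code stands the SMSC's response is simply logged by ``on_unsupported_command_id``.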
""" sequence_number = yield self.sequence_generator.next() pdu = QuerySM( sequence_number=sequence_number, message_id=message_id, source_addr=source_addr, source_addr_npi=source_addr_npi, source_addr_ton=source_addr_ton) self.send_pdu(pdu) returnValue([sequence_number]) @inlineCallbacks def unbind(self): sequence_number = yield self.sequence_generator.next() self.send_pdu(Unbind(sequence_number)) yield self.service.on_smpp_unbinding() returnValue([sequence_number]) def handle_unbind_resp(self, pdu): self.unbind_resp_queue.put(pdu) class EsmeProtocolFactory(ClientFactory): protocol = EsmeProtocol def __init__(self, service, bind_type): self.service = service self.bind_type = bind_type def buildProtocol(self, addr): proto = self.protocol(self.service, self.bind_type) proto.factory = self return proto PK=JGdL)"vumi/transports/smpp/smpp_utils.pydef unpacked_pdu_opts(unpacked_pdu): pdu_opts = {} for opt in unpacked_pdu['body'].get('optional_parameters', []): pdu_opts[opt['tag']] = opt['value'] return pdu_opts def detect_ussd(pdu_opts): # TODO: Push this back to python-smpp? return ('ussd_service_op' in pdu_opts) def update_ussd_pdu(sm_pdu, continue_session, session_info=None): if session_info is None: session_info = '0000' session_info = "%04x" % (int(session_info, 16) + int(not continue_session)) sm_pdu.add_optional_parameter('ussd_service_op', '02') sm_pdu.add_optional_parameter('its_session_info', session_info) return sm_pdu PK=JG+[[ vumi/transports/smpp/__init__.py""" SMPP transport API. """ from vumi.transports.smpp.smpp_transport import ( SmppTransceiverTransport, SmppTransmitterTransport, SmppReceiverTransport, SmppTransceiverTransportWithOldConfig as SmppTransport) __all__ = [ 'SmppTransport', 'SmppTransceiverTransport', 'SmppTransmitterTransport', 'SmppReceiverTransport', ] PK=JG(!vumi/transports/smpp/pdu_utils.pyimport binascii from vumi.transports.smpp.smpp_utils import unpacked_pdu_opts def pdu_ok(pdu): return command_status(pdu) == 'ESME_ROK' def seq_no(pdu): return pdu['header']['sequence_number'] def command_status(pdu): return pdu['header']['command_status'] def command_id(pdu): return pdu['header']['command_id'] def message_id(pdu): # If we have an unsuccessful response, we may not get a message_id. if 'body' not in pdu and not pdu_ok(pdu): return None return pdu['body']['mandatory_parameters']['message_id'] def short_message(pdu): return pdu['body']['mandatory_parameters']['short_message'] def pdu_tlv(pdu, tag): return unpacked_pdu_opts(pdu)[tag] def chop_pdu_stream(data): if len(data) < 16: return bytes = binascii.b2a_hex(data[0:4]) cmd_length = int(bytes, 16) if len(data) < cmd_length: return pdu, data = (data[0:cmd_length], data[cmd_length:]) return pdu, data PK=JG #vumi/transports/smpp/iprocessors.pyfrom zope.interface import Interface class IDeliveryReportProcessor(Interface): def handle_delivery_report_pdu(pdu_data): """ Handle a delivery report PDU from the networks. This should always return a Deferred that fires with a ``True`` if a delivery report was found and handled and ``False`` if that was not the case. All processors should implement this even if it does nothing. """ def handle_delivery_report_content(pdu_data): """ Handle an unpacked delivery report from the networks. This can happen with certain SMSCs that don't set the necessary delivery report flags on a PDU. As a result we only detect the DR by matching a received SM against a predefined regex. 
""" class IDeliverShortMessageProcessor(Interface): def handle_short_message_pdu(pdu): """ Handle a short message PDU from the networks after it has been re-assembled and decoded. This should always return a Deferred that fires ``True`` or ``False`` depending on whether the PDU was handled succesfully. All processors should implement this even if it does nothing. """ def handle_multipart_pdu(pdu): """ Handle a part of a multipart PDU. This should always return a Deferred that fires ``True`` or ``False`` depending on whether the PDU was a multipart-part. All processors should implement this even if it does nothing. """ def handle_ussd_pdu(pdu): """ Handle a USSD pdu. This should always return a Deferred that fires ``True`` or ``False`` depending on whether the PDU had a PDU payload. NOTE: It is likely that the USSD bits of this Interface will move to its own Interface implementation once work starts on an USSD over SMPP implementation. All processors should implement this even if it does nothing. """ def decode_pdus(pdus): """ Decode a list of PDUs and return the contents for each PDU's ``short_message`` field. """ def dcs_decode(obj, data_coding): """ Decode a byte string and return the unicode string for it according to the specified data coding. """ class ISubmitShortMessageProcessor(Interface): def handle_raw_outbound_message(vumi_message, smpp_service): """ Handle an outbound message from Vumi by calling the appropriate methods on the service with the appropriate parameters. These parameters and values can differ per MNO. Should return a Deferred that fires with a the list of sequence_numbers returning from the submit_sm calls. """ PK=H,~vumi/transports/smpp/config.pyfrom vumi.config import ( ConfigText, ConfigInt, ConfigBool, ConfigClientEndpoint, ConfigDict, ConfigFloat, ConfigClassName, ClientEndpointFallback) from vumi.transports.smpp.iprocessors import ( IDeliveryReportProcessor, IDeliverShortMessageProcessor, ISubmitShortMessageProcessor) from vumi.codecs.ivumi_codecs import IVumiCodec from vumi.transports.base import Transport class SmppTransportConfig(Transport.CONFIG_CLASS): twisted_endpoint = ConfigClientEndpoint( 'The SMPP endpoint to connect to.', required=True, static=True, fallbacks=[ClientEndpointFallback()]) initial_reconnect_delay = ConfigInt( 'How long (in seconds) to wait between reconnecting attempts. ' 'Defaults to 5 seconds.', default=5, static=True) throttle_delay = ConfigFloat( "Delay (in seconds) before retrying a message after receiving " "`ESME_RTHROTTLED` or `ESME_RMSGQFUL`.", default=0.1, static=True) deliver_sm_decoding_error = ConfigText( 'The error to respond with when we were unable to decode all parts ' 'of a PDU.', default='ESME_RDELIVERYFAILURE', static=True) submit_sm_expiry = ConfigInt( 'How long (in seconds) to wait for the SMSC to return with a ' '`submit_sm_resp`. Defaults to 24 hours.', default=(60 * 60 * 24), static=True) disable_ack = ConfigBool( 'Disable publishing of `ack` events. In some cases this event ' 'causes more noise than signal. It can optionally be turned off. ' 'Defaults to False.', default=False, static=True) disable_delivery_report = ConfigBool( 'Disable publishing of `delivery_report` events. In some cases this ' 'event causes more noise than signal. It can optionally be turned ' 'off. Note that failed or successful delivery reports will still be ' 'used to track which SMPP message ids can be removed from temporary ' 'caches. 
Defaults to False.', default=False, static=True) third_party_id_expiry = ConfigInt( 'How long (in seconds) to keep 3rd party message IDs around to allow ' 'for matching submit_sm_resp and delivery report messages. Defaults ' 'to 1 week.', default=(60 * 60 * 24 * 7), static=True) final_dr_third_party_id_expiry = ConfigInt( 'How long (in seconds) to keep 3rd party message IDs around after ' 'receiving a success or failure delivery report for the message.', default=(60 * 60), static=True) completed_multipart_info_expiry = ConfigInt( 'How long (in seconds) to keep multipart message info for completed ' 'multipart messages around to avoid pending operations accidentally ' 'recreating them without an expiry time. Defaults to 1 hour.', default=(60 * 60), static=True) redis_manager = ConfigDict( 'How to connect to Redis.', default={}, static=True) split_bind_prefix = ConfigText( "This is the Redis prefix to use for storing things like sequence " "numbers and message ids for delivery report handling. It defaults " "to `<system_id>@<transport_name>`. " "This should *ONLY* be done for TX & RX since messages sent via " "the TX bind are handled by the RX bind and they need to share the " "same prefix for the lookup for message ids in delivery reports to " "work.", default='', static=True) codec_class = ConfigClassName( 'Which class should be used to handle character encoding/decoding. ' 'MUST implement `IVumiCodec`.', default='vumi.codecs.VumiCodec', static=True, implements=IVumiCodec) delivery_report_processor = ConfigClassName( 'Which delivery report processor to use. ' 'MUST implement `IDeliveryReportProcessor`.', default=('vumi.transports.smpp.processors.' 'DeliveryReportProcessor'), static=True, implements=IDeliveryReportProcessor) delivery_report_processor_config = ConfigDict( 'The configuration for the `delivery_report_processor`.', default={}, static=True) deliver_short_message_processor = ConfigClassName( 'Which deliver short message processor to use. ' 'MUST implement `IDeliverShortMessageProcessor`.', default='vumi.transports.smpp.processors.DeliverShortMessageProcessor', static=True, implements=IDeliverShortMessageProcessor) deliver_short_message_processor_config = ConfigDict( 'The configuration for the `deliver_short_message_processor`.', default={}, static=True) submit_short_message_processor = ConfigClassName( 'Which submit short message processor to use. ' 'MUST implement `ISubmitShortMessageProcessor`.', default='vumi.transports.smpp.processors.SubmitShortMessageProcessor', static=True, implements=ISubmitShortMessageProcessor) submit_short_message_processor_config = ConfigDict( 'The configuration for the `submit_short_message_processor`.', default={}, static=True) system_id = ConfigText( 'User id used to connect to the SMPP server.', required=True, static=True) password = ConfigText( 'Password for the system id.', required=True, static=True) system_type = ConfigText( "Additional system metadata that is passed through to the SMPP " "server on connect.", default="", static=True) interface_version = ConfigText( "SMPP protocol version. Default is '34' (i.e. version 3.4).", default="34", static=True) service_type = ConfigText( 'The SMPP service type.', default="", static=True) address_range = ConfigText( "Address range to receive. (SMSC-specific format, default empty.)", default="", static=True) dest_addr_ton = ConfigInt( 'Destination TON (type of number).', default=0, static=True) dest_addr_npi = ConfigInt( 'Destination NPI (number plan identifier). 
' 'Default 1 (ISDN/E.164/E.163).', default=1, static=True) source_addr_ton = ConfigInt( 'Source TON (type of number).', default=0, static=True) source_addr_npi = ConfigInt( 'Source NPI (number plan identifier).', default=0, static=True) registered_delivery = ConfigBool( 'Whether or not to request delivery reports. Default True.', default=True, static=True) smpp_bind_timeout = ConfigInt( 'How long (in seconds) to wait for a successful bind. Default 30.', default=30, static=True) smpp_enquire_link_interval = ConfigInt( "How long (in seconds) to wait between `enquire_link` PDUs sent to " "check that the connection to the SMSC is still alive. If no " "response arrives within twice this interval, the connection is " "dropped. Default 55.", default=55, static=True) mt_tps = ConfigInt( 'Mobile Terminated Transactions per Second. The maximum number of ' 'Vumi messages per second to attempt to put on the wire. ' 'Defaults to 0, which means no throttling is applied. ' '(NOTE: 1 Vumi message may result in multiple PDUs.)', default=0, static=True, required=False) # TODO: Deprecate these fields when confmodel#5 is done. host = ConfigText( "*DEPRECATED* 'host' and 'port' fields may be used in place of the" " 'twisted_endpoint' field.", static=True) port = ConfigInt( "*DEPRECATED* 'host' and 'port' fields may be used in place of the" " 'twisted_endpoint' field.", static=True) PK=JG+V vumi/transports/smpp/sequence.py# -*- test-case-name: vumi.transports.smpp.tests.test_sequence -*- from twisted.internet.defer import inlineCallbacks, returnValue class RedisSequence(object): """ Generate a sequence of incrementing numbers that rolls over at a given limit. This is backed by Redis' atomicity and is safe to use in a distributed system. """ def __init__(self, redis, rollover_at=0xFFFF0000): self.redis = redis self.rollover_at = rollover_at def __iter__(self): return self def next(self): return self.get_next_seq() @inlineCallbacks def get_next_seq(self): """Get the next available SMPP sequence number. The valid range of sequence numbers is 0x00000001 to 0xFFFFFFFF. We start trying to wrap at 0xFFFF0000 so we can keep returning values (up to 0xFFFF of them) even while someone else is in the middle of resetting the counter. """ seq = yield self.redis.incr('smpp_last_sequence_number') if seq >= self.rollover_at: # We're close to the upper limit, so try to reset. It doesn't # matter if we actually succeed or not, since we're going to return # `seq` anyway. yield self._reset_seq_counter() returnValue(seq) @inlineCallbacks def _reset_seq_counter(self): """Reset the sequence counter in a safe manner. NOTE: There is a potential race condition in this implementation. If we acquire the lock and it expires while we still think we hold it, it's possible for the sequence number to be reset by someone else between the final value check and the reset call. This seems like a very unlikely situation, so we'll leave it like that for now. A better solution is to replace this whole method with a Lua script that we send to Redis, but scripting support is still very new at the time of writing. """ # SETNX can be used as a lock. locked = yield self.redis.setnx('smpp_last_sequence_number_wrap', 1) # If someone crashed in exactly the wrong place, the lock may be # held by someone else but have no expire time. A race condition # here may set the TTL multiple times, but that's fine. if (yield self.redis.ttl('smpp_last_sequence_number_wrap')) < 0: # The TTL only gets set if the lock exists and recently had no TTL. 
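# (Redis reports a negative TTL both for a key with no expiry (-1)
# and, on newer Redis versions, for a key that no longer exists (-2),
# so the `< 0` check above covers both cases.)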
yield self.redis.expire('smpp_last_sequence_number_wrap', 10) if not locked: # We didn't actually get the lock, so our job is done. return if (yield self.redis.get('smpp_last_sequence_number')) < 0xFFFF0000: # Our stored sequence number is no longer outside the allowed # range, so someone else must have reset it before we got the lock. return # We reset the counter by deleting the key. The next INCR will recreate # it for us. yield self.redis.delete('smpp_last_sequence_number') PKqGoNL 8 8+vumi/transports/smpp/tests/test_protocol.py# -*- coding: utf-8 -*- from twisted.internet.defer import inlineCallbacks, gatherResults, succeed from twisted.internet.task import Clock from smpp.pdu_builder import ( Unbind, UnbindResp, SubmitSMResp, DeliverSM, EnquireLink) from vumi.log import WrappingLogger from vumi.tests.helpers import VumiTestCase, PersistenceHelper from vumi.transports.smpp.smpp_transport import ( SmppTransceiverTransport, SmppMessageDataStash) from vumi.transports.smpp.protocol import ( EsmeProtocol, EsmeProtocolFactory) from vumi.transports.smpp.pdu_utils import ( seq_no, command_status, command_id, short_message) from vumi.transports.smpp.sequence import RedisSequence from vumi.transports.smpp.tests.fake_smsc import FakeSMSC class DummySmppService(object): def __init__(self, clock, redis, config): self.log = WrappingLogger(system=config.get('worker_name')) self.clock = clock self.redis = redis self._config = config self._static_config = SmppTransceiverTransport.CONFIG_CLASS( self._config, static=True) config = self.get_config() self.dr_processor = config.delivery_report_processor( self, config.delivery_report_processor_config) self.deliver_sm_processor = config.deliver_short_message_processor( self, config.deliver_short_message_processor_config) self.sequence_generator = RedisSequence(self.redis) self.message_stash = SmppMessageDataStash(self.redis, config) self.paused = True def get_static_config(self): return self._static_config def get_config(self): return self._static_config def on_connection_lost(self, reason): self.paused = True def on_smpp_binding(self): pass def on_smpp_unbinding(self): pass def on_smpp_bind(self): self.paused = False def on_smpp_bind_timeout(self): pass class TestEsmeProtocol(VumiTestCase): @inlineCallbacks def setUp(self): self.clock = Clock() self.persistence_helper = self.add_helper(PersistenceHelper()) self.redis = yield self.persistence_helper.get_redis_manager() self.fake_smsc = FakeSMSC(auto_accept=False) def get_protocol(self, config={}, bind_type='TRX', accept_connection=True): cfg = { 'transport_name': 'sphex_transport', 'twisted_endpoint': 'tcp:host=127.0.0.1:port=0', 'system_id': 'system_id', 'password': 'password', 'smpp_bind_timeout': 30, } cfg.update(config) dummy_service = DummySmppService(self.clock, self.redis, cfg) factory = EsmeProtocolFactory(dummy_service, bind_type) proto_d = self.fake_smsc.endpoint.connect(factory) if accept_connection: self.fake_smsc.accept_connection() return proto_d def assertCommand(self, pdu, cmd_id, sequence_number=None, status=None, params={}): self.assertEqual(command_id(pdu), cmd_id) if sequence_number is not None: self.assertEqual(seq_no(pdu), sequence_number) if status is not None: self.assertEqual(command_status(pdu), status) pdu_params = {} if params: if 'body' not in pdu: raise Exception('Body does not have parameters.') mandatory_parameters = pdu['body']['mandatory_parameters'] for key in params: if key in mandatory_parameters: pdu_params[key] = mandatory_parameters[key] self.assertEqual(params, 
pdu_params) def lookup_message_ids(self, protocol, seq_nums): message_stash = protocol.service.message_stash lookup_func = message_stash.get_sequence_number_message_id return gatherResults([lookup_func(seq_num) for seq_num in seq_nums]) @inlineCallbacks def test_on_connection_made(self): connect_d = self.get_protocol(accept_connection=False) protocol = yield self.fake_smsc.await_connecting() self.assertEqual(protocol.state, EsmeProtocol.CLOSED_STATE) self.fake_smsc.accept_connection() protocol = yield connect_d # Same protocol. self.assertEqual(protocol.state, EsmeProtocol.OPEN_STATE) bind_pdu = yield self.fake_smsc.await_pdu() self.assertCommand( bind_pdu, 'bind_transceiver', sequence_number=1, params={ 'system_id': 'system_id', 'password': 'password', }) @inlineCallbacks def test_drop_link(self): protocol = yield self.get_protocol() bind_pdu = yield self.fake_smsc.await_pdu() self.assertCommand(bind_pdu, 'bind_transceiver') self.assertFalse(protocol.is_bound()) self.assertEqual(protocol.state, EsmeProtocol.OPEN_STATE) self.clock.advance(protocol.config.smpp_bind_timeout + 1) unbind_pdu = yield self.fake_smsc.await_pdu() self.assertCommand(unbind_pdu, 'unbind') yield self.fake_smsc.send_pdu(UnbindResp(seq_no(unbind_pdu))) yield self.fake_smsc.await_disconnect() @inlineCallbacks def test_on_smpp_bind(self): protocol = yield self.get_protocol() yield self.fake_smsc.bind() self.assertEqual(protocol.state, EsmeProtocol.BOUND_STATE_TRX) self.assertTrue(protocol.is_bound()) self.assertTrue(protocol.enquire_link_call.running) @inlineCallbacks def test_handle_unbind(self): protocol = yield self.get_protocol() yield self.fake_smsc.bind() self.assertEqual(protocol.state, EsmeProtocol.BOUND_STATE_TRX) self.fake_smsc.send_pdu(Unbind(0)) pdu = yield self.fake_smsc.await_pdu() self.assertCommand( pdu, 'unbind_resp', sequence_number=0, status='ESME_ROK') # We don't change state here. 
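# The protocol only acknowledges the unbind; actually dropping the
# connection (and hence leaving the bound state) is left to the peer.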
self.assertEqual(protocol.state, EsmeProtocol.BOUND_STATE_TRX) @inlineCallbacks def test_on_submit_sm_resp(self): protocol = yield self.get_protocol() yield self.fake_smsc.bind() calls = [] protocol.on_submit_sm_resp = lambda *a: calls.append(a) yield self.fake_smsc.send_pdu(SubmitSMResp(0, message_id='foo')) self.assertEqual(calls, [(0, 'foo', 'ESME_ROK')]) @inlineCallbacks def test_deliver_sm(self): calls = [] protocol = yield self.get_protocol() protocol.handle_deliver_sm = lambda pdu: succeed(calls.append(pdu)) yield self.fake_smsc.bind() yield self.fake_smsc.send_pdu( DeliverSM(0, message_id='foo', short_message='bar')) [deliver_sm] = calls self.assertCommand(deliver_sm, 'deliver_sm', sequence_number=0) @inlineCallbacks def test_deliver_sm_fail(self): yield self.get_protocol() yield self.fake_smsc.bind() yield self.fake_smsc.send_pdu(DeliverSM( sequence_number=0, message_id='foo', data_coding=4, short_message='string with unknown data coding')) deliver_sm_resp = yield self.fake_smsc.await_pdu() self.assertCommand( deliver_sm_resp, 'deliver_sm_resp', sequence_number=0, status='ESME_RDELIVERYFAILURE') @inlineCallbacks def test_deliver_sm_fail_with_custom_error(self): yield self.get_protocol({ "deliver_sm_decoding_error": "ESME_RSYSERR" }) yield self.fake_smsc.bind() yield self.fake_smsc.send_pdu(DeliverSM( sequence_number=0, message_id='foo', data_coding=4, short_message='string with unknown data coding')) deliver_sm_resp = yield self.fake_smsc.await_pdu() self.assertCommand( deliver_sm_resp, 'deliver_sm_resp', sequence_number=0, status='ESME_RSYSERR') @inlineCallbacks def test_on_enquire_link(self): protocol = yield self.get_protocol() yield self.fake_smsc.bind() pdu = EnquireLink(0) protocol.dataReceived(pdu.get_bin()) enquire_link_resp = yield self.fake_smsc.await_pdu() self.assertCommand( enquire_link_resp, 'enquire_link_resp', sequence_number=0, status='ESME_ROK') @inlineCallbacks def test_on_enquire_link_resp(self): protocol = yield self.get_protocol() calls = [] protocol.handle_enquire_link_resp = calls.append yield self.fake_smsc.bind() [pdu] = calls # bind_transceiver is sequence_number 1 self.assertEqual(seq_no(pdu), 2) self.assertEqual(command_id(pdu), 'enquire_link_resp') @inlineCallbacks def test_enquire_link_no_response(self): self.fake_smsc.auto_unbind = False protocol = yield self.get_protocol() yield self.fake_smsc.bind() self.assertEqual(self.fake_smsc.connected, True) self.clock.advance(protocol.idle_timeout) [enquire_link_pdu, unbind_pdu] = yield self.fake_smsc.await_pdus(2) self.assertCommand(enquire_link_pdu, 'enquire_link') self.assertCommand(unbind_pdu, 'unbind') self.assertEqual(self.fake_smsc.connected, True) self.clock.advance(protocol.unbind_timeout) yield self.fake_smsc.await_disconnect() @inlineCallbacks def test_enquire_link_looping(self): self.fake_smsc.auto_unbind = False protocol = yield self.get_protocol() yield self.fake_smsc.bind() self.assertEqual(self.fake_smsc.connected, True) # Respond to a few enquire_link cycles. for i in range(5): self.clock.advance(protocol.idle_timeout - 1) pdu = yield self.fake_smsc.await_pdu() self.assertCommand(pdu, 'enquire_link') yield self.fake_smsc.respond_to_enquire_link(pdu) # Fail to respond, so we disconnect. 
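# Advancing to just before the idle deadline produces one more
# enquire_link; swallowing it and letting the final second elapse
# pushes us past idle_timeout without a response.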
self.clock.advance(protocol.idle_timeout - 1) pdu = yield self.fake_smsc.await_pdu() self.assertCommand(pdu, 'enquire_link') self.clock.advance(1) unbind_pdu = yield self.fake_smsc.await_pdu() self.assertCommand(unbind_pdu, 'unbind') yield self.fake_smsc.send_pdu( UnbindResp(seq_no(unbind_pdu))) yield self.fake_smsc.await_disconnect() @inlineCallbacks def test_submit_sm(self): protocol = yield self.get_protocol() yield self.fake_smsc.bind() seq_nums = yield protocol.submit_sm( 'abc123', 'dest_addr', short_message='foo') submit_sm = yield self.fake_smsc.await_pdu() self.assertCommand(submit_sm, 'submit_sm', params={ 'short_message': 'foo', }) stored_ids = yield self.lookup_message_ids(protocol, seq_nums) self.assertEqual(['abc123'], stored_ids) @inlineCallbacks def test_submit_sm_configured_parameters(self): protocol = yield self.get_protocol({ 'service_type': 'stype', 'source_addr_ton': 2, 'source_addr_npi': 2, 'dest_addr_ton': 2, 'dest_addr_npi': 2, 'registered_delivery': 0, }) yield self.fake_smsc.bind() seq_nums = yield protocol.submit_sm( 'abc123', 'dest_addr', short_message='foo') submit_sm = yield self.fake_smsc.await_pdu() self.assertCommand(submit_sm, 'submit_sm', params={ 'short_message': 'foo', 'service_type': 'stype', 'source_addr_ton': 'national', # replaced by unpack_pdu() 'source_addr_npi': 2, 'dest_addr_ton': 'national', # replaced by unpack_pdu() 'dest_addr_npi': 2, 'registered_delivery': 0, }) stored_ids = yield self.lookup_message_ids(protocol, seq_nums) self.assertEqual(['abc123'], stored_ids) @inlineCallbacks def test_query_sm(self): protocol = yield self.get_protocol() yield self.fake_smsc.bind() yield protocol.query_sm('foo', source_addr='bar') query_sm = yield self.fake_smsc.await_pdu() self.assertCommand(query_sm, 'query_sm', params={ 'message_id': 'foo', 'source_addr': 'bar', }) @inlineCallbacks def test_unbind(self): protocol = yield self.get_protocol() calls = [] protocol.handle_unbind_resp = calls.append yield self.fake_smsc.bind() yield protocol.unbind() unbind_pdu = yield self.fake_smsc.await_pdu() protocol.dataReceived(UnbindResp(seq_no(unbind_pdu)).get_bin()) [unbind_resp_pdu] = calls self.assertEqual(seq_no(unbind_resp_pdu), seq_no(unbind_pdu)) @inlineCallbacks def test_bind_transmitter(self): protocol = yield self.get_protocol(bind_type='TX') yield self.fake_smsc.bind() self.assertTrue(protocol.is_bound()) self.assertEqual(protocol.state, protocol.BOUND_STATE_TX) @inlineCallbacks def test_bind_receiver(self): protocol = yield self.get_protocol(bind_type='RX') yield self.fake_smsc.bind() self.assertTrue(protocol.is_bound()) self.assertEqual(protocol.state, protocol.BOUND_STATE_RX) @inlineCallbacks def test_partial_pdu_data_received(self): protocol = yield self.get_protocol() calls = [] protocol.handle_deliver_sm = calls.append yield self.fake_smsc.bind() deliver_sm = DeliverSM(1, short_message='foo') pdu = deliver_sm.get_bin() half = len(pdu) / 2 pdu_part1, pdu_part2 = pdu[:half], pdu[half:] yield self.fake_smsc.send_bytes(pdu_part1) self.assertEqual([], calls) yield self.fake_smsc.send_bytes(pdu_part2) [handled_pdu] = calls self.assertEqual(command_id(handled_pdu), 'deliver_sm') self.assertEqual(seq_no(handled_pdu), 1) self.assertEqual(short_message(handled_pdu), 'foo') @inlineCallbacks def test_unsupported_command_id(self): protocol = yield self.get_protocol() calls = [] protocol.on_unsupported_command_id = calls.append invalid_pdu = { 'header': { 'command_id': 'foo', } } protocol.on_pdu(invalid_pdu) self.assertEqual(calls, [invalid_pdu]) 
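# A condensed sketch of the test pattern used throughout this class
# (hypothetical test, not part of the original suite; it uses only
# helpers defined above): connect, bind, exchange a PDU, then assert
# on what the fake SMSC received.
#
#     @inlineCallbacks
#     def test_enquire_link_roundtrip(self):
#         protocol = yield self.get_protocol()
#         yield self.fake_smsc.bind()
#         # Send a manual enquire_link and inspect what reached the SMSC.
#         sequence_number = yield protocol.enquire_link()
#         pdu = yield self.fake_smsc.await_pdu()
#         self.assertCommand(
#             pdu, 'enquire_link', sequence_number=sequence_number)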
PK=JG&vumi/transports/smpp/tests/__init__.pyPK=JGs ) )'vumi/transports/smpp/tests/fake_smsc.py# -*- test-case-name: vumi.transports.smpp.tests.test_fake_smsc -*- from twisted.internet.defer import ( Deferred, succeed, DeferredQueue, gatherResults) from twisted.internet.error import ConnectionRefusedError from twisted.internet.interfaces import IStreamClientEndpoint from twisted.internet.protocol import Protocol from twisted.internet.task import deferLater from twisted.protocols.loopback import loopbackAsync from zope.interface import implementer from smpp.pdu import unpack_pdu from smpp.pdu_builder import ( BindTransceiverResp, BindTransmitterResp, BindReceiverResp, EnquireLinkResp, UnbindResp, DeliverSM, SubmitSMResp) from vumi.transports.smpp.pdu_utils import seq_no, chop_pdu_stream, command_id def wait0(r=None): """ Wait zero seconds to give the reactor a chance to work. Returns its (optional) argument, so it's useful as a callback. """ from twisted.internet import reactor return deferLater(reactor, 0, lambda: r) class FakeSMSC(object): """ Fake SMSC for testing. By default, it accepts incoming connections and automatically responds to unbind commands. Only one client connection at a time is allowed. """ def __init__(self, auto_accept=True, auto_unbind=True): self.auto_accept = auto_accept self.auto_unbind = auto_unbind self.pdu_queue = DeferredQueue() self.endpoint = FakeSMSCEndpoint(self) self.connected = False self._reset_connection_ds() # Public API. def await_connecting(self): """ Wait for a client to start connecting, and then return the client protocol. This is useful if auto-accept is disabled, otherwise use :meth:`await_connected` instead. """ return self._listen_d def await_connected(self): """ Wait for a client to finish connecting. """ return self._connected_d def accept_connection(self): """ Accept a pending connection. This is only useful if auto-accept is disabled. """ assert self.has_pending_connection(), "No pending connection." self._accept_d.callback(self.protocol) return self.await_connected() def reject_connection(self): """ Reject a pending connection. This is only useful if auto-accept is disabled. The deferred returned by waiting for `await_connected()` for this connection will never fire. """ assert self.has_pending_connection(), "No pending connection." self._accept_d.errback(ConnectionRefusedError()) self._reset_connection_ds() def has_pending_connection(self): """ Returns `True` if there is a pending connection, `False` otherwise. """ return self._accept_d is not None and not self._accept_d.called def send_bytes(self, bytes): """ Put some bytes on the wire. This also waits zero seconds to allow the bytes to be delivered. """ self.protocol.transport.write(bytes) return wait0() def send_pdu(self, pdu): """ Send a PDU to the connected ESME. This also waits zero seconds to allow the PDU to be delivered. """ self.protocol.send_pdu(pdu) return wait0() def handle_pdu(self, pdu): """ Bypass the wire connection and call `on_pdu` directly. This allows the caller to wait until the PDU processing has finished. It also allows invalid PDUs to be sent. """ return self._client_protocol.on_pdu(pdu.obj) def bind(self, bind_pdu=None): """ Respond to a bind command. :param bind_pdu: The bind PDU to respond to. If `None`, the next PDU on the receive queue will be used. """ bind_d = self._given_or_next_pdu(bind_pdu) return bind_d.addCallback(self._bind_resp) def respond_to_enquire_link(self, enquire_link_pdu=None): """ Respond to an enquire_link command. 
:param enquire_link_pdu: The enquire_link PDU to respond to. If `None`, the next PDU on the receive queue will be used. """ enquire_link_d = self._given_or_next_pdu(enquire_link_pdu) return enquire_link_d.addCallback(self._enquire_link_resp) def await_pdu(self): """ Wait for the next PDU from the receive queue. """ return self.pdu_queue.get() def await_pdus(self, count): """ Wait for the next `count` PDUs from the receive queue. """ return gatherResults([self.pdu_queue.get() for _ in range(count)]) def waiting_pdu_count(self): """ Returns the number of PDUs in the receive queue. """ return len(self.pdu_queue.pending) def send_mo(self, sequence_number, short_message, data_coding=1, **kwargs): """ Send a DeliverSM PDU. """ return self.send_pdu( DeliverSM( sequence_number, short_message=short_message, data_coding=data_coding, **kwargs)) def submit_sm_resp(self, submit_sm_pdu=None, message_id=None, **kw): """ Respond to a submit_sm command. NOTE: This uses :meth:`handle_pdu` instead of :meth:`send_pdu` because there's a lot of async stuff going on. :param submit_sm_pdu: The submit_sm PDU to respond to. If `None`, the next PDU on the receive queue will be used. :param message_id: The message_id to put in the response. If `None`, one will be generated from the sequence number. """ submit_sm_d = self._given_or_next_pdu(submit_sm_pdu) return submit_sm_d.addCallback(self._submit_sm_resp, message_id, **kw) def disconnect(self): """ Disconnect. """ self.protocol.transport.loseConnection() return self.await_disconnect() def await_disconnect(self): """ Wait for the client to disconnect. """ return self._finished_d # Internal stuff. def _reset_connection_ds(self): # self._finished_d is special, because we need that after the # connection gets closed. self._listen_d = Deferred() self._accept_d = None self._connected_d = Deferred() self._bound_d = Deferred() self._client_protocol = None self.protocol = None def handle_connection(self, client_protocol): assert self.protocol is None, "Already connected." self._client_protocol = client_protocol self.protocol = FakeSMSCProtocol(self) self._accept_d = Deferred() self._listen_d.callback(client_protocol) if self.auto_accept: self.accept_connection() return self._accept_d def connection_made(self): self.connected = True self._connected_d.callback(None) def connection_lost(self): self.connected = False self.protocol.transport.loseConnection() self._reset_connection_ds() def pdu_received(self, pdu): self.pdu_queue.put(pdu) if self.auto_unbind and command_id(pdu) == 'unbind': self.send_pdu(UnbindResp(seq_no(pdu))) def set_finished(self, finished_d): self._finished_d = Deferred() finished_d.addCallback(self._finished_d.callback) def _given_or_next_pdu(self, pdu): if pdu is not None: return succeed(pdu) return self.pdu_queue.get() def assert_command_id(self, pdu, *command_ids): if command_id(pdu) not in command_ids: raise ValueError( "Expected PDU with command_id in [%s], got %s." 
% ( ", ".join(command_ids), command_id(pdu))) def _bind_resp(self, bind_pdu): resp_pdu_classes = { 'bind_transceiver': BindTransceiverResp, 'bind_receiver': BindReceiverResp, 'bind_transmitter': BindTransmitterResp, } self.assert_command_id(bind_pdu, *resp_pdu_classes) resp_pdu_class = resp_pdu_classes.get(command_id(bind_pdu)) self.send_pdu(resp_pdu_class(seq_no(bind_pdu))) eq_d = self.respond_to_enquire_link() return eq_d.addCallback(self._bound_d.callback) def _enquire_link_resp(self, enquire_link_pdu): self.assert_command_id(enquire_link_pdu, 'enquire_link') return self.send_pdu(EnquireLinkResp(seq_no(enquire_link_pdu))) def _submit_sm_resp(self, submit_sm_pdu, message_id, **kw): self.assert_command_id(submit_sm_pdu, 'submit_sm') sequence_number = seq_no(submit_sm_pdu) if message_id is None: message_id = "id%s" % (sequence_number,) # We use handle_pdu here to avoid complications with all the async. return self.handle_pdu(SubmitSMResp(sequence_number, message_id, **kw)) @implementer(IStreamClientEndpoint) class FakeSMSCEndpoint(object): """ This endpoint connects a client directly to a FakeSMSC. """ def __init__(self, fake_smsc): self.fake_smsc = fake_smsc def connect(self, protocolFactory): client = protocolFactory.buildProtocol(None) d = self.fake_smsc.handle_connection(client) return d.addCallback(self._make_connection, client) def _make_connection(self, server, client): finished_d = loopbackAsync(server, client) self.fake_smsc.set_finished(finished_d) return client class FakeSMSCProtocol(Protocol): """ Very simple protocol for pretending to be an SMSC. """ def __init__(self, fake_smsc): self.fake_smsc = fake_smsc self._buf = b"" def connectionMade(self): self.fake_smsc.connection_made() def connectionLost(self, reason): self.fake_smsc.connection_lost() def dataReceived(self, data): self._buf += data data = self.handle_buffer() while data is not None: self.pdu_received(unpack_pdu(data)) data = self.handle_buffer() def handle_buffer(self): pdu_found = chop_pdu_stream(self._buf) if pdu_found is None: return data, self._buf = pdu_found return data def pdu_received(self, pdu): self.fake_smsc.pdu_received(pdu) def send_pdu(self, pdu): self.transport.write(pdu.get_bin()) PK=JG/2n`n`,vumi/transports/smpp/tests/test_fake_smsc.pyfrom twisted.internet.defer import Deferred, inlineCallbacks from twisted.internet.error import ConnectionRefusedError from twisted.internet.protocol import Protocol, ClientFactory from twisted.internet.task import Clock, deferLater from smpp.pdu import unpack_pdu from smpp.pdu_builder import ( BindTransceiver, BindTransceiverResp, BindTransmitter, BindTransmitterResp, BindReceiver, BindReceiverResp, EnquireLink, EnquireLinkResp, DeliverSM, Unbind, UnbindResp, SubmitSM, SubmitSMResp) from vumi.tests.helpers import VumiTestCase from vumi.transports.smpp.tests.fake_smsc import FakeSMSC def wait0(): from twisted.internet import reactor return deferLater(reactor, 0, lambda: None) class FakeESME(Protocol): def __init__(self): self.received = b"" self.pdus_handled = [] self.handle_pdu_d = None def dataReceived(self, data): self.received += data def connectionLost(self, reason): self.connected = False def on_pdu(self, pdu): self.handle_pdu_d = Deferred() self.handle_pdu_d.addCallback(self._handle_pdu, pdu) return self.handle_pdu_d def _handle_pdu(self, r, pdu): self.pdus_handled.append(pdu) self.handle_pdu_d = None def write(self, data): self.transport.write(data) return wait0() class FakeESMEFactory(ClientFactory): protocol = FakeESME def __init__(self): self.proto = None 
def buildProtocol(self, addr): self.proto = ClientFactory.buildProtocol(self, addr) return self.proto class TestFakeSMSC(VumiTestCase): """ Tests for FakeSMSC. """ def setUp(self): self.clock = Clock() self.client_factory = FakeESMEFactory() def connect(self, fake_smsc): return fake_smsc.endpoint.connect(self.client_factory) def test_await_connecting(self): """ The caller can wait for a client connection attempt. """ fake_smsc = FakeSMSC(auto_accept=False) await_connecting_d = fake_smsc.await_connecting() self.assertNoResult(await_connecting_d) self.assertEqual(self.client_factory.proto, None) self.assertEqual(fake_smsc._client_protocol, None) connect_d = self.connect(fake_smsc) # The client connection has started ... self.successResultOf(await_connecting_d) client = self.client_factory.proto self.assertNotEqual(client, None) self.assertEqual(fake_smsc._client_protocol, client) # ... but has not yet been accepted. self.assertNoResult(connect_d) self.assertEqual(client.connected, False) def test_await_connected(self): """ The caller can wait for a client to connect. """ fake_smsc = FakeSMSC(auto_accept=True) await_connected_d = fake_smsc.await_connected() self.assertNoResult(await_connected_d) self.assertEqual(self.client_factory.proto, None) self.assertEqual(fake_smsc._client_protocol, None) self.connect(fake_smsc) # The client has connected. self.successResultOf(await_connected_d) client = self.client_factory.proto self.assertNotEqual(client, None) self.assertEqual(fake_smsc._client_protocol, client) self.assertEqual(client.connected, True) def test_accept_connection(self): """ With auto-accept disabled, a connection must be manually accepted. """ fake_smsc = FakeSMSC(auto_accept=False) await_connecting_d = fake_smsc.await_connecting() await_connected_d = fake_smsc.await_connected() self.assertNoResult(await_connecting_d) self.assertNoResult(await_connected_d) connect_d = self.connect(fake_smsc) # The client connection is pending. self.successResultOf(await_connecting_d) self.assertNoResult(await_connected_d) self.assertNoResult(connect_d) client = self.client_factory.proto self.assertEqual(client.connected, False) accept_d = fake_smsc.accept_connection() # The client is connected. self.successResultOf(await_connected_d) self.successResultOf(accept_d) self.assertEqual(client.connected, True) self.assertEqual(self.successResultOf(connect_d), client) def test_accept_connection_no_pending(self): """ There must be a pending connection to accept. """ fake_smsc = FakeSMSC(auto_accept=False) self.assertRaises(Exception, fake_smsc.accept_connection) def test_reject_connection(self): """ With auto-accept disabled, a connection may be rejected. """ fake_smsc = FakeSMSC(auto_accept=False) await_connecting_d = fake_smsc.await_connecting() await_connected_d = fake_smsc.await_connected() self.assertNoResult(await_connecting_d) self.assertNoResult(await_connected_d) connect_d = self.connect(fake_smsc) # The client connection is pending. self.successResultOf(await_connecting_d) self.assertNoResult(await_connected_d) self.assertNoResult(connect_d) client = self.client_factory.proto self.assertEqual(client.connected, False) fake_smsc.reject_connection() # The client is not connected. self.failureResultOf(connect_d, ConnectionRefusedError) self.assertNoResult(await_connected_d) self.assertEqual(client.connected, False) def test_reject_connection_no_pending(self): """ There must be a pending connection to reject. 
""" fake_smsc = FakeSMSC(auto_accept=False) self.assertRaises(Exception, fake_smsc.reject_connection) def test_has_pending_connection(self): """ FakeSMSC knows if there's a pending connection. """ fake_smsc = FakeSMSC(auto_accept=False) self.assertEqual(fake_smsc.has_pending_connection(), False) # Pending connection we reject. connect_d = self.connect(fake_smsc) self.assertEqual(fake_smsc.has_pending_connection(), True) fake_smsc.reject_connection() self.assertEqual(fake_smsc.has_pending_connection(), False) self.failureResultOf(connect_d) # Pending connection we accept. connected_d = self.connect(fake_smsc) self.assertEqual(fake_smsc.has_pending_connection(), True) fake_smsc.accept_connection() self.assertEqual(fake_smsc.has_pending_connection(), False) self.successResultOf(connected_d) @inlineCallbacks def test_send_bytes(self): """ Bytes can be sent to the client. """ fake_smsc = FakeSMSC() client = self.successResultOf(self.connect(fake_smsc)) self.assertEqual(client.received, b"") send_d = fake_smsc.send_bytes(b"abc") # Bytes sent, not yet received. self.assertNoResult(send_d) self.assertEqual(client.received, b"") yield send_d # Bytes received. self.assertEqual(client.received, b"abc") @inlineCallbacks def test_send_pdu(self): """ A PDU can be sent to the client over the wire. """ fake_smsc = FakeSMSC() client = self.successResultOf(self.connect(fake_smsc)) self.assertEqual(client.received, b"") self.assertEqual(client.pdus_handled, []) pdu = DeliverSM(0) send_d = fake_smsc.send_pdu(pdu) # PDU sent, not yet received. self.assertNoResult(send_d) self.assertNotEqual(client.received, pdu.get_bin()) yield send_d # PDU received. self.assertEqual(client.received, pdu.get_bin()) self.assertEqual(client.pdus_handled, []) def test_handle_pdu(self): """ A PDU can be sent to the client for direct processing. """ fake_smsc = FakeSMSC() client = self.successResultOf(self.connect(fake_smsc)) self.assertEqual(client.received, b"") self.assertEqual(client.pdus_handled, []) pdu = DeliverSM(0) handle_d = fake_smsc.handle_pdu(pdu) # PDU sent, not yet processed. self.assertNoResult(handle_d) self.assertEqual(client.pdus_handled, []) client.handle_pdu_d.callback(None) # PDU processed. self.successResultOf(handle_d) self.assertEqual(client.received, b"") self.assertEqual(client.pdus_handled, [pdu.obj]) @inlineCallbacks def test_bind(self): """ FakeSMSC can accept a bind request and respond to the first enquire_link. """ fake_smsc = FakeSMSC() client = self.successResultOf(self.connect(fake_smsc)) self.assertEqual(client.received, b"") bind_d = fake_smsc.bind() yield client.write(BindTransceiver(0).get_bin()) # Bind response received. self.assertNoResult(bind_d) self.assertEqual(client.received, BindTransceiverResp(0).get_bin()) client.received = b"" yield client.write(EnquireLink(1).get_bin()) # enquire_link response received. self.assertNoResult(bind_d) self.assertEqual(client.received, EnquireLinkResp(1).get_bin()) yield wait0() # Bind complete. self.successResultOf(bind_d) @inlineCallbacks def test_bind_mode_TRX(self): """ FakeSMSC can accept tranceiver bind requests. 
""" fake_smsc = FakeSMSC() client = self.successResultOf(self.connect(fake_smsc)) self.assertEqual(client.received, b"") bind_d = fake_smsc.bind() yield client.write(BindTransceiver(0).get_bin()) yield client.write(EnquireLink(1).get_bin()) self.assertEqual(client.received, b"".join([ BindTransceiverResp(0).get_bin(), EnquireLinkResp(1).get_bin()])) yield wait0() self.successResultOf(bind_d) @inlineCallbacks def test_bind_mode_TX(self): """ FakeSMSC can accept transmitter bind requests. """ fake_smsc = FakeSMSC() client = self.successResultOf(self.connect(fake_smsc)) self.assertEqual(client.received, b"") bind_d = fake_smsc.bind() yield client.write(BindTransmitter(0).get_bin()) yield client.write(EnquireLink(1).get_bin()) self.assertEqual(client.received, b"".join([ BindTransmitterResp(0).get_bin(), EnquireLinkResp(1).get_bin()])) yield wait0() self.successResultOf(bind_d) @inlineCallbacks def test_bind_mode_RX(self): """ FakeSMSC can accept receiver bind requests. """ fake_smsc = FakeSMSC() client = self.successResultOf(self.connect(fake_smsc)) self.assertEqual(client.received, b"") bind_d = fake_smsc.bind() yield client.write(BindReceiver(0).get_bin()) yield client.write(EnquireLink(1).get_bin()) self.assertEqual(client.received, b"".join([ BindReceiverResp(0).get_bin(), EnquireLinkResp(1).get_bin()])) yield wait0() self.successResultOf(bind_d) @inlineCallbacks def test_bind_explicit(self): """ FakeSMSC can bind using a PDU explicitly passed in. """ fake_smsc = FakeSMSC() client = self.successResultOf(self.connect(fake_smsc)) self.assertEqual(client.received, b"") bind_d = fake_smsc.bind(BindTransceiver(0).obj) yield wait0() # Bind response received. self.assertNoResult(bind_d) self.assertEqual(client.received, BindTransceiverResp(0).get_bin()) client.received = b"" yield client.write(EnquireLink(1).get_bin()) # enquire_link response received. self.assertNoResult(bind_d) self.assertEqual(client.received, EnquireLinkResp(1).get_bin()) yield wait0() # Bind complete. self.successResultOf(bind_d) @inlineCallbacks def test_bind_wrong_pdu(self): """ FakeSMSC will raise an exception if asked to bind with a non-bind PDU. """ fake_smsc = FakeSMSC() client = self.successResultOf(self.connect(fake_smsc)) bind_d = fake_smsc.bind() yield client.write(EnquireLink(0).get_bin()) self.failureResultOf(bind_d, ValueError) @inlineCallbacks def test_respond_to_enquire_link(self): """ FakeSMSC can respond to an enquire_link. """ fake_smsc = FakeSMSC() client = self.successResultOf(self.connect(fake_smsc)) self.assertEqual(client.received, b"") rtel_d = fake_smsc.respond_to_enquire_link() yield client.write(EnquireLink(2).get_bin()) # enquire_link response received. self.assertNoResult(rtel_d) self.assertEqual(client.received, EnquireLinkResp(2).get_bin()) yield wait0() self.successResultOf(rtel_d) @inlineCallbacks def test_respond_to_enquire_link_explicit(self): """ FakeSMSC can respond to an enquire_link PDU explicitly passed in. """ fake_smsc = FakeSMSC() client = self.successResultOf(self.connect(fake_smsc)) self.assertEqual(client.received, b"") rtel_d = fake_smsc.respond_to_enquire_link(EnquireLink(2).obj) yield wait0() # enquire_link response received. self.successResultOf(rtel_d) self.assertEqual(client.received, EnquireLinkResp(2).get_bin()) @inlineCallbacks def test_respond_to_enquire_link_wrong_pdu(self): """ FakeSMSC will raise an exception if asked to respond to an enquire_link that isn't an enquire_link. 
""" fake_smsc = FakeSMSC() client = self.successResultOf(self.connect(fake_smsc)) rtel_d = fake_smsc.respond_to_enquire_link() yield client.write(DeliverSM(0).get_bin()) self.failureResultOf(rtel_d, ValueError) @inlineCallbacks def test_await_pdu(self): """ The caller can wait for a PDU to arrive. """ fake_smsc = FakeSMSC() client = self.successResultOf(self.connect(fake_smsc)) pdu_d = fake_smsc.await_pdu() # No PDU yet. self.assertNoResult(pdu_d) client.write(EnquireLink(1).get_bin()) # No yield. # PDU sent, not yet received. self.assertNoResult(pdu_d) yield wait0() # PDU received. self.assertEqual( self.successResultOf(pdu_d), unpack_pdu(EnquireLink(1).get_bin())) @inlineCallbacks def test_await_pdu_arrived(self): """ The caller can wait for a PDU that has already arrived. """ fake_smsc = FakeSMSC() client = self.successResultOf(self.connect(fake_smsc)) yield client.write(EnquireLink(1).get_bin()) yield client.write(EnquireLink(2).get_bin()) self.assertEqual( self.successResultOf(fake_smsc.await_pdu()), unpack_pdu(EnquireLink(1).get_bin())) self.assertEqual( self.successResultOf(fake_smsc.await_pdu()), unpack_pdu(EnquireLink(2).get_bin())) @inlineCallbacks def test_await_pdus(self): """ The caller can wait for multiple PDUs to arrive. """ fake_smsc = FakeSMSC() client = self.successResultOf(self.connect(fake_smsc)) pdus_d = fake_smsc.await_pdus(2) # No PDUs yet. self.assertNoResult(pdus_d) yield client.write(EnquireLink(1).get_bin()) # One PDU received, no result. self.assertNoResult(pdus_d) yield client.write(EnquireLink(2).get_bin()) # Both PDUs received. self.assertEqual(self.successResultOf(pdus_d), [ unpack_pdu(EnquireLink(1).get_bin()), unpack_pdu(EnquireLink(2).get_bin())]) @inlineCallbacks def test_await_pdus_arrived(self): """ The caller can wait for multiple PDU that have already arrived. """ fake_smsc = FakeSMSC() client = self.successResultOf(self.connect(fake_smsc)) yield client.write(EnquireLink(1).get_bin()) yield client.write(EnquireLink(2).get_bin()) self.assertEqual(self.successResultOf(fake_smsc.await_pdus(2)), [ unpack_pdu(EnquireLink(1).get_bin()), unpack_pdu(EnquireLink(2).get_bin())]) @inlineCallbacks def test_waiting_pdu_count(self): """ FakeSMSC knows how many received PDUs are waiting. """ fake_smsc = FakeSMSC() client = self.successResultOf(self.connect(fake_smsc)) # Nothing received yet. self.assertEqual(fake_smsc.waiting_pdu_count(), 0) # Some PDUs received. yield client.write(EnquireLink(1).get_bin()) self.assertEqual(fake_smsc.waiting_pdu_count(), 1) yield client.write(EnquireLink(2).get_bin()) self.assertEqual(fake_smsc.waiting_pdu_count(), 2) # Some PDUs returned. self.successResultOf(fake_smsc.await_pdu()) self.assertEqual(fake_smsc.waiting_pdu_count(), 1) self.successResultOf(fake_smsc.await_pdu()) self.assertEqual(fake_smsc.waiting_pdu_count(), 0) # Wait for a PDU that arrives later. pdu_d = fake_smsc.await_pdu() self.assertNoResult(pdu_d) self.assertEqual(fake_smsc.waiting_pdu_count(), 0) yield client.write(EnquireLink(3).get_bin()) self.assertEqual(fake_smsc.waiting_pdu_count(), 0) self.successResultOf(pdu_d) @inlineCallbacks def test_send_mo(self): """ FakeSMSC can send a DeliverSM PDU. """ fake_smsc = FakeSMSC() client = self.successResultOf(self.connect(fake_smsc)) self.assertEqual(client.received, b"") yield fake_smsc.send_mo(5, "hello") # First MO received. 
self.assertEqual(client.received, DeliverSM( 5, short_message="hello", data_coding=1).get_bin()) client.received = b"" yield fake_smsc.send_mo(6, "hello again", 8, destination_addr="123") # Second MO received. self.assertEqual(client.received, DeliverSM( 6, short_message="hello again", data_coding=8, destination_addr="123").get_bin()) @inlineCallbacks def test_disconnect(self): """ FakeSMSC can disconnect from the client. """ fake_smsc = FakeSMSC() client = self.successResultOf(self.connect(fake_smsc)) self.assertEqual(fake_smsc.connected, True) self.assertEqual(client.connected, True) disconnect_d = fake_smsc.disconnect() # Disconnect triggered, but not completed. self.assertNoResult(disconnect_d) self.assertEqual(client.connected, True) self.assertEqual(fake_smsc.connected, True) yield wait0() # Disconnect completed. self.successResultOf(disconnect_d) self.assertEqual(client.connected, False) self.assertEqual(fake_smsc.connected, False) self.assertEqual(fake_smsc.protocol, None) self.assertEqual(fake_smsc._client_protocol, None) self.assertNoResult(fake_smsc._listen_d) self.assertNoResult(fake_smsc._connected_d) @inlineCallbacks def test_await_disconnect(self): """ FakeSMSC can wait for the connection to close. """ fake_smsc = FakeSMSC() client = self.successResultOf(self.connect(fake_smsc)) disconnect_d = fake_smsc.await_disconnect() yield wait0() self.assertNoResult(disconnect_d) client.transport.loseConnection() # Disconnect triggered, but not completed. self.assertNoResult(disconnect_d) yield wait0() # Disconnect completed. self.successResultOf(disconnect_d) @inlineCallbacks def test_auto_unbind(self): """ FakeSMSC will automatically respond to an unbind request by default. The unbind PDU remains in the queue. """ fake_smsc = FakeSMSC() client = self.successResultOf(self.connect(fake_smsc)) self.assertEqual(client.received, b"") self.assertEqual(fake_smsc.waiting_pdu_count(), 0) yield client.write(Unbind(7).get_bin()) self.assertEqual(client.received, UnbindResp(7).get_bin()) self.assertEqual(fake_smsc.waiting_pdu_count(), 1) @inlineCallbacks def test_submit_sm_resp(self): """ FakeSMSC can respond to a SubmitSM PDU and wait for it to finish being processed. """ fake_smsc = FakeSMSC() client = self.successResultOf(self.connect(fake_smsc)) self.assertEqual(client.received, b"") # No params. submit_sm_resp_d = fake_smsc.submit_sm_resp() yield client.write(SubmitSM(123).get_bin()) self.assertNoResult(submit_sm_resp_d) client.handle_pdu_d.callback(None) self.successResultOf(submit_sm_resp_d) resp = SubmitSMResp(123, message_id="id123", command_status="ESME_ROK") self.assertEqual(client.pdus_handled, [resp.obj]) client.pdus_handled[:] = [] # Explicit message_id. submit_sm_resp_d = fake_smsc.submit_sm_resp(message_id="foo") yield client.write(SubmitSM(124).get_bin()) self.assertNoResult(submit_sm_resp_d) client.handle_pdu_d.callback(None) self.successResultOf(submit_sm_resp_d) resp = SubmitSMResp(124, message_id="foo", command_status="ESME_ROK") self.assertEqual(client.pdus_handled, [resp.obj]) @inlineCallbacks def test_submit_sm_resp_with_failure(self): """ FakeSMSC can respond to a SubmitSM PDU with a failure. """ fake_smsc = FakeSMSC() client = self.successResultOf(self.connect(fake_smsc)) self.assertEqual(client.received, b"") # No params. 
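# When no message_id is given, submit_sm_resp derives one from the sequence number, e.g. "id123" for sequence number 123.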
submit_sm_resp_d = fake_smsc.submit_sm_resp( command_status="ESME_RSUBMITFAIL") yield client.write(SubmitSM(123).get_bin()) self.assertNoResult(submit_sm_resp_d) client.handle_pdu_d.callback(None) self.successResultOf(submit_sm_resp_d) resp = SubmitSMResp( 123, message_id="id123", command_status="ESME_RSUBMITFAIL") self.assertEqual(client.pdus_handled, [resp.obj]) client.pdus_handled[:] = [] # Explicit message_id. submit_sm_resp_d = fake_smsc.submit_sm_resp( command_status="ESME_RSUBMITFAIL", message_id="foo") yield client.write(SubmitSM(124).get_bin()) self.assertNoResult(submit_sm_resp_d) client.handle_pdu_d.callback(None) self.successResultOf(submit_sm_resp_d) resp = SubmitSMResp( 124, message_id="foo", command_status="ESME_RSUBMITFAIL") self.assertEqual(client.pdus_handled, [resp.obj]) @inlineCallbacks def test_submit_sm_resp_explicit(self): """ FakeSMSC can respond to a SubmitSM PDU that is explicitly passed in. """ fake_smsc = FakeSMSC() client = self.successResultOf(self.connect(fake_smsc)) self.assertEqual(client.received, b"") # No params. submit_sm_resp_d = fake_smsc.submit_sm_resp(SubmitSM(123).obj) self.assertNoResult(submit_sm_resp_d) client.handle_pdu_d.callback(None) self.successResultOf(submit_sm_resp_d) resp = SubmitSMResp(123, message_id="id123", command_status="ESME_ROK") self.assertEqual(client.pdus_handled, [resp.obj]) client.pdus_handled[:] = [] # Explicit message_id. submit_sm_resp_d = fake_smsc.submit_sm_resp( SubmitSM(124).obj, message_id="foo") yield client.write(SubmitSM(124).get_bin()) self.assertNoResult(submit_sm_resp_d) client.handle_pdu_d.callback(None) self.successResultOf(submit_sm_resp_d) resp = SubmitSMResp(124, message_id="foo", command_status="ESME_ROK") self.assertEqual(client.pdus_handled, [resp.obj]) @inlineCallbacks def test_submit_sm_resp_wrong_pdu(self): """ FakeSMSC will raise an exception if asked to respond to a submit_sm with a non-submit_sm PDU. """ fake_smsc = FakeSMSC() client = self.successResultOf(self.connect(fake_smsc)) submit_sm_resp_d = fake_smsc.submit_sm_resp() yield client.write(EnquireLink(0).get_bin()) self.failureResultOf(submit_sm_resp_d, ValueError) PK=H``1vumi/transports/smpp/tests/test_smpp_transport.py# -*- coding: utf-8 -*- import logging from twisted.internet.defer import inlineCallbacks, returnValue from twisted.internet.task import Clock from smpp.pdu_builder import DeliverSM, SubmitSMResp from vumi.config import ConfigError from vumi.message import TransportUserMessage from vumi.tests.helpers import VumiTestCase from vumi.tests.utils import LogCatcher from vumi.transports.smpp.smpp_transport import ( message_key, remote_message_key, multipart_info_key, sequence_number_key, SmppTransceiverTransport, SmppTransmitterTransport, SmppReceiverTransport, SmppTransceiverTransportWithOldConfig) from vumi.transports.smpp.pdu_utils import ( pdu_ok, short_message, command_id, seq_no, pdu_tlv, unpacked_pdu_opts) from vumi.transports.smpp.processors import SubmitShortMessageProcessor from vumi.transports.smpp.tests.fake_smsc import FakeSMSC from vumi.transports.tests.helpers import TransportHelper class TestSmppTransportConfig(VumiTestCase): def test_host_port_fallback(self): """ Old-style 'host' and 'port' fields are still supported in configs. """ def parse_config(extra_config): config = { 'transport_name': 'name', 'system_id': 'foo', 'password': 'bar', } config.update(extra_config) return SmppTransceiverTransport.CONFIG_CLASS(config, static=True) # If we don't provide an endpoint config, we get an error.
self.assertRaises(ConfigError, parse_config, {}) # If we do provide an endpoint config, we get an endpoint. cfg = {'twisted_endpoint': 'tcp:host=example.com:port=1337'} self.assertNotEqual(parse_config(cfg).twisted_endpoint.connect, None) # If we provide host and port configs, we get an endpoint. cfg = {'host': 'example.com', 'port': 1337} self.assertNotEqual(parse_config(cfg).twisted_endpoint.connect, None) class SmppTransportTestCase(VumiTestCase): DR_TEMPLATE = ("id:%s sub:... dlvrd:... submit date:200101010030" " done date:200101020030 stat:DELIVRD err:... text:Meep") DR_MINIMAL_TEMPLATE = "id:%s stat:DELIVRD text:Meep" transport_class = None def setUp(self): self.clock = Clock() self.fake_smsc = FakeSMSC() self.tx_helper = self.add_helper(TransportHelper(self.transport_class)) self.default_config = { 'transport_name': self.tx_helper.transport_name, 'worker_name': self.tx_helper.transport_name, 'twisted_endpoint': self.fake_smsc.endpoint, 'delivery_report_processor': 'vumi.transports.smpp.processors.' 'DeliveryReportProcessor', 'deliver_short_message_processor': ( 'vumi.transports.smpp.processors.' 'DeliverShortMessageProcessor'), 'system_id': 'foo', 'password': 'bar', 'deliver_short_message_processor_config': { 'data_coding_overrides': { 0: 'utf-8', } } } def _get_transport_config(self, config): """ This is overridden in a subclass. """ cfg = self.default_config.copy() cfg.update(config) return cfg @inlineCallbacks def get_transport(self, config={}, bind=True): cfg = self._get_transport_config(config) transport = yield self.tx_helper.get_transport(cfg, start=False) transport.clock = self.clock yield transport.startWorker() self.clock.advance(0) if bind: yield self.fake_smsc.bind() returnValue(transport) class SmppTransceiverTransportTestCase(SmppTransportTestCase): transport_class = SmppTransceiverTransport @inlineCallbacks def test_setup_transport(self): transport = yield self.get_transport(bind=False) protocol = yield transport.service.get_protocol() self.assertEqual(protocol.is_bound(), False) yield self.fake_smsc.bind() self.assertEqual(protocol.is_bound(), True) @inlineCallbacks def test_mo_sms(self): yield self.get_transport() self.fake_smsc.send_mo( sequence_number=1, short_message='foo', source_addr='123', destination_addr='456') deliver_sm_resp = yield self.fake_smsc.await_pdu() self.assertTrue(pdu_ok(deliver_sm_resp)) [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(msg['content'], 'foo') self.assertEqual(msg['from_addr'], '123') self.assertEqual(msg['to_addr'], '456') self.assertEqual(msg['transport_type'], 'sms') @inlineCallbacks def test_mo_sms_empty_sms_allowed(self): yield self.get_transport({ 'deliver_short_message_processor_config': { 'allow_empty_messages': True, } }) self.fake_smsc.send_mo( sequence_number=1, short_message='', source_addr='123', destination_addr='456') deliver_sm_resp = yield self.fake_smsc.await_pdu() self.assertTrue(pdu_ok(deliver_sm_resp)) [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(msg['content'], '') @inlineCallbacks def test_mo_sms_empty_sms_disallowed(self): yield self.get_transport() with LogCatcher(message=r"^(Not all parts|WARNING)") as lc: self.fake_smsc.send_mo( sequence_number=1, short_message='', source_addr='123', destination_addr='456') deliver_sm_resp = yield self.fake_smsc.await_pdu() self.assertFalse(pdu_ok(deliver_sm_resp)) # check that failure to process delivery report was logged self.assertEqual(lc.messages(), [ "WARNING: Not decoding `None` message with 
data_coding=1", "Not all parts of the PDU were able to be decoded. " "Responding with ESME_RDELIVERYFAILURE.", ]) for l in lc.logs: self.assertEqual(l['system'], 'sphex') inbound = self.tx_helper.get_dispatched_inbound() self.assertEqual(inbound, []) events = self.tx_helper.get_dispatched_events() self.assertEqual(events, []) @inlineCallbacks def test_mo_delivery_report_pdu_opt_params(self): """ We always treat a message with the optional PDU params set as a delivery report. """ transport = yield self.get_transport() yield transport.message_stash.set_remote_message_id('bar', 'foo') pdu = DeliverSM(sequence_number=1, esm_class=4) pdu.add_optional_parameter('receipted_message_id', 'foo') pdu.add_optional_parameter('message_state', 2) yield self.fake_smsc.handle_pdu(pdu) [event] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(event['event_type'], 'delivery_report') self.assertEqual(event['delivery_status'], 'delivered') self.assertEqual(event['user_message_id'], 'bar') @inlineCallbacks def test_mo_delivery_report_pdu_opt_params_esm_class_not_set(self): """ We always treat a message with the optional PDU params set as a delivery report, even if ``esm_class`` is not set. """ transport = yield self.get_transport() yield transport.message_stash.set_remote_message_id('bar', 'foo') pdu = DeliverSM(sequence_number=1) pdu.add_optional_parameter('receipted_message_id', 'foo') pdu.add_optional_parameter('message_state', 2) yield self.fake_smsc.handle_pdu(pdu) [event] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(event['event_type'], 'delivery_report') self.assertEqual(event['delivery_status'], 'delivered') self.assertEqual(event['user_message_id'], 'bar') @inlineCallbacks def test_mo_delivery_report_pdu_esm_class_not_set(self): """ We treat a content-based DR as a normal message if the ``esm_class`` flags are not set. """ transport = yield self.get_transport() yield transport.message_stash.set_remote_message_id('bar', 'foo') self.fake_smsc.send_mo( sequence_number=1, short_message=self.DR_TEMPLATE % ('foo',), source_addr='123', destination_addr='456') [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(msg['content'], self.DR_TEMPLATE % ('foo',)) self.assertEqual(msg['from_addr'], '123') self.assertEqual(msg['to_addr'], '456') self.assertEqual(msg['transport_type'], 'sms') events = yield self.tx_helper.get_dispatched_events() self.assertEqual(events, []) @inlineCallbacks def test_mo_delivery_report_esm_class_with_full_content(self): """ If ``esm_class`` and content are both set appropriately, we process the DR. """ transport = yield self.get_transport() yield transport.message_stash.set_remote_message_id('bar', 'foo') self.fake_smsc.send_mo( sequence_number=1, short_message=self.DR_TEMPLATE % ('foo',), source_addr='123', destination_addr='456', esm_class=4) [event] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(event['event_type'], 'delivery_report') self.assertEqual(event['delivery_status'], 'delivered') self.assertEqual(event['user_message_id'], 'bar') @inlineCallbacks def test_mo_delivery_report_esm_class_with_short_status(self): """ If the delivery report has a shorter status field, the default regex still matches. """ transport = yield self.get_transport() yield transport.message_stash.set_remote_message_id('bar', 'foo') short_message = ( "id:foo sub:... dlvrd:... 
submit date:200101010030" " done date:200101020030 stat:FAILED err:042 text:Meep") self.fake_smsc.send_mo( sequence_number=1, short_message=short_message, source_addr='123', destination_addr='456', esm_class=4) [event] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(event['event_type'], 'delivery_report') self.assertEqual(event['delivery_status'], 'failed') self.assertEqual(event['user_message_id'], 'bar') @inlineCallbacks def test_mo_delivery_report_esm_class_with_minimal_content(self): """ If ``esm_class`` and content are both set appropriately, we process the DR even if the minimal subset of the content regex matches. """ transport = yield self.get_transport() yield transport.message_stash.set_remote_message_id('bar', 'foo') self.fake_smsc.send_mo( sequence_number=1, source_addr='123', destination_addr='456', short_message=self.DR_MINIMAL_TEMPLATE % ('foo',), esm_class=4) [event] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(event['event_type'], 'delivery_report') self.assertEqual(event['delivery_status'], 'delivered') self.assertEqual(event['user_message_id'], 'bar') @inlineCallbacks def test_mo_delivery_report_content_with_nulls(self): """ If ``esm_class`` and content are both set appropriately, we process the DR even if some content fields contain null values. """ transport = yield self.get_transport() yield transport.message_stash.set_remote_message_id('bar', 'foo') content = ( "id:%s sub:null dlvrd:null submit date:200101010030" " done date:200101020030 stat:DELIVRD err:null text:Meep") self.fake_smsc.send_mo( sequence_number=1, short_message=content % ("foo",), source_addr='123', destination_addr='456', esm_class=4) [event] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(event['event_type'], 'delivery_report') self.assertEqual(event['delivery_status'], 'delivered') self.assertEqual(event['user_message_id'], 'bar') @inlineCallbacks def test_mo_delivery_report_esm_class_with_bad_content(self): """ If ``esm_class`` indicates a DR but the regex fails to match, we log a warning and do nothing. """ transport = yield self.get_transport() yield transport.message_stash.set_remote_message_id('bar', 'foo') lc = LogCatcher(message="esm_class 4 indicates") with lc: self.fake_smsc.send_mo( sequence_number=1, source_addr='123', destination_addr='456', short_message="foo", esm_class=4) yield self.fake_smsc.await_pdu() # check that failure to process delivery report was logged [warning] = lc.logs self.assertEqual( warning["message"][0], "esm_class 4 indicates delivery report, but content does not" " match regex: 'foo'") for l in lc.logs: self.assertEqual(l['system'], 'sphex') inbound = self.tx_helper.get_dispatched_inbound() self.assertEqual(inbound, []) events = self.tx_helper.get_dispatched_events() self.assertEqual(events, []) @inlineCallbacks def test_mo_delivery_report_esm_class_with_no_content(self): """ If ``esm_class`` indicates a DR but the content is empty, we log a warning and do nothing. 
""" transport = yield self.get_transport() yield transport.message_stash.set_remote_message_id('bar', 'foo') lc = LogCatcher(message="esm_class 4 indicates") with lc: self.fake_smsc.send_mo( sequence_number=1, source_addr='123', destination_addr='456', short_message=None, esm_class=4) yield self.fake_smsc.await_pdu() # check that failure to process delivery report was logged [warning] = lc.logs self.assertEqual( warning["message"][0], "esm_class 4 indicates delivery report, but content does not" " match regex: None") for l in lc.logs: self.assertEqual(l['system'], 'sphex') inbound = self.tx_helper.get_dispatched_inbound() self.assertEqual(inbound, []) events = self.tx_helper.get_dispatched_events() self.assertEqual(events, []) @inlineCallbacks def test_mo_delivery_report_esm_disabled_with_full_content(self): """ If ``esm_class`` checking is disabled and the content is set appropriately, we process the DR. """ transport = yield self.get_transport({ "delivery_report_processor_config": { "delivery_report_use_esm_class": False, } }) yield transport.message_stash.set_remote_message_id('bar', 'foo') self.fake_smsc.send_mo( sequence_number=1, short_message=self.DR_TEMPLATE % ('foo',), source_addr='123', destination_addr='456', esm_class=0) [event] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(event['event_type'], 'delivery_report') self.assertEqual(event['delivery_status'], 'delivered') self.assertEqual(event['user_message_id'], 'bar') @inlineCallbacks def test_mo_delivery_report_esm_disabled_with_minimal_content(self): """ If ``esm_class`` checking is disabled and the content is set appropriately, we process the DR even if the minimal subset of the content regex matches. """ transport = yield self.get_transport({ "delivery_report_processor_config": { "delivery_report_use_esm_class": False, } }) yield transport.message_stash.set_remote_message_id('bar', 'foo') self.fake_smsc.send_mo( sequence_number=1, source_addr='123', destination_addr='456', short_message=self.DR_MINIMAL_TEMPLATE % ('foo',), esm_class=0) [event] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(event['event_type'], 'delivery_report') self.assertEqual(event['delivery_status'], 'delivered') self.assertEqual(event['user_message_id'], 'bar') @inlineCallbacks def test_mo_delivery_report_esm_disabled_content_with_nulls(self): """ If ``esm_class`` checking is disabled and the content is set appropriately, we process the DR even if some content fields contain null values. 
""" transport = yield self.get_transport({ "delivery_report_processor_config": { "delivery_report_use_esm_class": False, } }) yield transport.message_stash.set_remote_message_id('bar', 'foo') content = ( "id:%s sub:null dlvrd:null submit date:200101010030" " done date:200101020030 stat:DELIVRD err:null text:Meep") self.fake_smsc.send_mo( sequence_number=1, short_message=content % ("foo",), source_addr='123', destination_addr='456', esm_class=0) [event] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(event['event_type'], 'delivery_report') self.assertEqual(event['delivery_status'], 'delivered') self.assertEqual(event['user_message_id'], 'bar') @inlineCallbacks def test_mo_sms_unicode(self): yield self.get_transport() self.fake_smsc.send_mo( sequence_number=1, short_message='Zo\xc3\xab', data_coding=0) deliver_sm_resp = yield self.fake_smsc.await_pdu() self.assertTrue(pdu_ok(deliver_sm_resp)) [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(msg['content'], u'Zoë') @inlineCallbacks def test_mo_sms_multipart_long(self): yield self.get_transport() content = '1' * 255 pdu = DeliverSM(sequence_number=1) pdu.add_optional_parameter('message_payload', content.encode('hex')) self.fake_smsc.send_pdu(pdu) deliver_sm_resp = yield self.fake_smsc.await_pdu() self.assertEqual(1, seq_no(deliver_sm_resp)) self.assertTrue(pdu_ok(deliver_sm_resp)) [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(msg['content'], content) @inlineCallbacks def test_mo_sms_multipart_udh(self): yield self.get_transport() deliver_sm_resps = [] self.fake_smsc.send_mo( sequence_number=1, short_message="\x05\x00\x03\xff\x03\x01back") deliver_sm_resps.append((yield self.fake_smsc.await_pdu())) self.fake_smsc.send_mo( sequence_number=2, short_message="\x05\x00\x03\xff\x03\x02 at") deliver_sm_resps.append((yield self.fake_smsc.await_pdu())) self.fake_smsc.send_mo( sequence_number=3, short_message="\x05\x00\x03\xff\x03\x03 you") deliver_sm_resps.append((yield self.fake_smsc.await_pdu())) self.assertEqual([1, 2, 3], map(seq_no, deliver_sm_resps)) self.assertTrue(all(map(pdu_ok, deliver_sm_resps))) [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(msg['content'], u'back at you') @inlineCallbacks def test_mo_sms_multipart_udh_out_of_order(self): yield self.get_transport() deliver_sm_resps = [] self.fake_smsc.send_mo( sequence_number=1, short_message="\x05\x00\x03\xff\x03\x01back") deliver_sm_resps.append((yield self.fake_smsc.await_pdu())) self.fake_smsc.send_mo( sequence_number=3, short_message="\x05\x00\x03\xff\x03\x03 you") deliver_sm_resps.append((yield self.fake_smsc.await_pdu())) self.fake_smsc.send_mo( sequence_number=2, short_message="\x05\x00\x03\xff\x03\x02 at") deliver_sm_resps.append((yield self.fake_smsc.await_pdu())) self.assertEqual([1, 3, 2], map(seq_no, deliver_sm_resps)) self.assertTrue(all(map(pdu_ok, deliver_sm_resps))) [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(msg['content'], u'back at you') @inlineCallbacks def test_mo_sms_multipart_sar(self): yield self.get_transport() deliver_sm_resps = [] pdu1 = DeliverSM(sequence_number=1, short_message='back') pdu1.add_optional_parameter('sar_msg_ref_num', 1) pdu1.add_optional_parameter('sar_total_segments', 3) pdu1.add_optional_parameter('sar_segment_seqnum', 1) self.fake_smsc.send_pdu(pdu1) deliver_sm_resps.append((yield self.fake_smsc.await_pdu())) pdu2 = DeliverSM(sequence_number=2, short_message=' at') 
pdu2.add_optional_parameter('sar_msg_ref_num', 1) pdu2.add_optional_parameter('sar_total_segments', 3) pdu2.add_optional_parameter('sar_segment_seqnum', 2) self.fake_smsc.send_pdu(pdu2) deliver_sm_resps.append((yield self.fake_smsc.await_pdu())) pdu3 = DeliverSM(sequence_number=3, short_message=' you') pdu3.add_optional_parameter('sar_msg_ref_num', 1) pdu3.add_optional_parameter('sar_total_segments', 3) pdu3.add_optional_parameter('sar_segment_seqnum', 3) self.fake_smsc.send_pdu(pdu3) deliver_sm_resps.append((yield self.fake_smsc.await_pdu())) self.assertEqual([1, 2, 3], map(seq_no, deliver_sm_resps)) self.assertTrue(all(map(pdu_ok, deliver_sm_resps))) [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(msg['content'], u'back at you') @inlineCallbacks def test_mo_bad_encoding(self): yield self.get_transport() bad_pdu = DeliverSM(555, short_message="SMS from server containing \xa7", destination_addr="2772222222", source_addr="2772000000", data_coding=1) good_pdu = DeliverSM(555, short_message="Next message", destination_addr="2772222222", source_addr="2772000000", data_coding=1) yield self.fake_smsc.handle_pdu(bad_pdu) yield self.fake_smsc.handle_pdu(good_pdu) [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(msg['message_type'], 'user_message') self.assertEqual(msg['transport_name'], self.tx_helper.transport_name) self.assertEqual(msg['content'], "Next message") dispatched_failures = self.tx_helper.get_dispatched_failures() self.assertEqual(dispatched_failures, []) [failure] = self.flushLoggedErrors(UnicodeDecodeError) message = failure.getErrorMessage() codec, rest = message.split(' ', 1) self.assertEqual(codec, "'ascii'") self.assertTrue( rest.startswith("codec can't decode byte 0xa7 in position 27")) @inlineCallbacks def test_mo_sms_failed_remote_id_lookup(self): yield self.get_transport() lc = LogCatcher(message="Failed to retrieve message id") with lc: yield self.fake_smsc.handle_pdu( DeliverSM(sequence_number=1, esm_class=4, short_message=self.DR_TEMPLATE % ('foo',))) # check that failure to send delivery report was logged [warning] = lc.logs expected_msg = ( "Failed to retrieve message id for delivery report. 
Delivery" " report from %s discarded.") % (self.tx_helper.transport_name,) self.assertEqual(warning['message'], (expected_msg,)) for l in lc.logs: self.assertEqual(l['system'], 'sphex') @inlineCallbacks def test_mt_sms(self): yield self.get_transport() msg = self.tx_helper.make_outbound('hello world') yield self.tx_helper.dispatch_outbound(msg) pdu = yield self.fake_smsc.await_pdu() self.assertEqual(command_id(pdu), 'submit_sm') self.assertEqual(short_message(pdu), 'hello world') @inlineCallbacks def test_mt_sms_bad_to_addr(self): yield self.get_transport() msg = yield self.tx_helper.make_dispatch_outbound( 'hello world', to_addr=u'+\u2000') [event] = self.tx_helper.get_dispatched_events() self.assertEqual(event['event_type'], 'nack') self.assertEqual(event['user_message_id'], msg['message_id']) self.assertEqual(event['nack_reason'], u'Invalid to_addr: +\u2000') @inlineCallbacks def test_mt_sms_bad_from_addr(self): yield self.get_transport() msg = yield self.tx_helper.make_dispatch_outbound( 'hello world', from_addr=u'+\u2000') [event] = self.tx_helper.get_dispatched_events() self.assertEqual(event['event_type'], 'nack') self.assertEqual(event['user_message_id'], msg['message_id']) self.assertEqual(event['nack_reason'], u'Invalid from_addr: +\u2000') @inlineCallbacks def test_mt_sms_submit_sm_encoding(self): yield self.get_transport({ 'submit_short_message_processor_config': { 'submit_sm_encoding': 'latin1', } }) yield self.tx_helper.make_dispatch_outbound(u'Zoë destroyer of Ascii!') submit_sm_pdu = yield self.fake_smsc.await_pdu() self.assertEqual( short_message(submit_sm_pdu), u'Zoë destroyer of Ascii!'.encode('latin-1')) @inlineCallbacks def test_mt_sms_submit_sm_null_message(self): """ We can successfully send a message with null content. """ yield self.get_transport() msg = self.tx_helper.make_outbound(None) yield self.tx_helper.dispatch_outbound(msg) pdu = yield self.fake_smsc.await_pdu() self.assertEqual(command_id(pdu), 'submit_sm') self.assertEqual(short_message(pdu), None) @inlineCallbacks def test_submit_sm_data_coding(self): yield self.get_transport({ 'submit_short_message_processor_config': { 'submit_sm_data_coding': 8 } }) yield self.tx_helper.make_dispatch_outbound("hello world") submit_sm_pdu = yield self.fake_smsc.await_pdu() params = submit_sm_pdu['body']['mandatory_parameters'] self.assertEqual(params['data_coding'], 8) @inlineCallbacks def test_mt_sms_ack(self): yield self.get_transport() msg = self.tx_helper.make_outbound('hello world') yield self.tx_helper.dispatch_outbound(msg) submit_sm_pdu = yield self.fake_smsc.await_pdu() self.fake_smsc.send_pdu( SubmitSMResp(sequence_number=seq_no(submit_sm_pdu), message_id='foo')) [event] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(event['event_type'], 'ack') self.assertEqual(event['user_message_id'], msg['message_id']) self.assertEqual(event['sent_message_id'], 'foo') @inlineCallbacks def assert_no_events(self): # NOTE: We can't test for the absence of an event in isolation but we # can test that for the presence of a second event only. 
fail_msg = self.tx_helper.make_outbound('hello fail') yield self.tx_helper.dispatch_outbound(fail_msg) submit_sm_fail_pdu = yield self.fake_smsc.await_pdu() self.fake_smsc.send_pdu( SubmitSMResp(sequence_number=seq_no(submit_sm_fail_pdu), message_id='__assert_no_events__', command_status='ESME_RINVDSTADR')) [fail] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(fail['event_type'], 'nack') @inlineCallbacks def test_mt_sms_disabled_ack(self): yield self.get_transport({'disable_ack': True}) msg = self.tx_helper.make_outbound('hello world') yield self.tx_helper.dispatch_outbound(msg) submit_sm_pdu = yield self.fake_smsc.await_pdu() self.fake_smsc.send_pdu( SubmitSMResp(sequence_number=seq_no(submit_sm_pdu), message_id='foo')) yield self.assert_no_events() @inlineCallbacks def test_mt_sms_nack(self): yield self.get_transport() msg = self.tx_helper.make_outbound('hello world') yield self.tx_helper.dispatch_outbound(msg) submit_sm_pdu = yield self.fake_smsc.await_pdu() self.fake_smsc.send_pdu( SubmitSMResp(sequence_number=seq_no(submit_sm_pdu), message_id='foo', command_status='ESME_RINVDSTADR')) [event] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(event['event_type'], 'nack') self.assertEqual(event['user_message_id'], msg['message_id']) self.assertEqual(event['nack_reason'], 'ESME_RINVDSTADR') @inlineCallbacks def test_mt_sms_failure(self): yield self.get_transport() message = yield self.tx_helper.make_dispatch_outbound( "message", message_id='446') submit_sm = yield self.fake_smsc.await_pdu() response = SubmitSMResp(seq_no(submit_sm), "3rd_party_id_3", command_status="ESME_RSUBMITFAIL") # A failure PDU might not have a body. response.obj.pop('body') self.fake_smsc.send_pdu(response) # There should be a nack [nack] = yield self.tx_helper.wait_for_dispatched_events(1) [failure] = yield self.tx_helper.get_dispatched_failures() self.assertEqual(failure['reason'], 'ESME_RSUBMITFAIL') self.assertEqual(failure['message'], message.payload) @inlineCallbacks def test_mt_sms_failure_with_no_reason(self): yield self.get_transport() message = yield self.tx_helper.make_dispatch_outbound( "message", message_id='446') submit_sm = yield self.fake_smsc.await_pdu() yield self.fake_smsc.handle_pdu( SubmitSMResp(sequence_number=seq_no(submit_sm), message_id='foo', command_status=None)) # There should be a nack [nack] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(nack['user_message_id'], message['message_id']) self.assertEqual(nack['nack_reason'], 'Unspecified') [failure] = yield self.tx_helper.get_dispatched_failures() self.assertEqual(failure['reason'], 'Unspecified') @inlineCallbacks def test_mt_sms_seq_num_lookup_failure(self): transport = yield self.get_transport() lc = LogCatcher(message="Failed to retrieve message id") with lc: yield self.fake_smsc.handle_pdu( SubmitSMResp(sequence_number=0xbad, message_id='bad')) # Make sure we didn't store 'None' in redis. message_stash = transport.message_stash message_id = yield message_stash.get_internal_message_id('bad') self.assertEqual(message_id, None) # check that failure to send ack/nack was logged [warning] = lc.logs expected_msg = ( "Failed to retrieve message id for deliver_sm_resp. 
ack/nack" " from %s discarded.") % (self.tx_helper.transport_name,) self.assertEqual(warning['message'], (expected_msg,)) for l in lc.logs: self.assertEqual(l['system'], 'sphex') @inlineCallbacks def test_mt_sms_throttled(self): transport = yield self.get_transport() transport_config = transport.get_static_config() msg = self.tx_helper.make_outbound('hello world') yield self.tx_helper.dispatch_outbound(msg) submit_sm_pdu = yield self.fake_smsc.await_pdu() with LogCatcher(message="Throttling outbound messages.") as lc: yield self.fake_smsc.handle_pdu( SubmitSMResp(sequence_number=seq_no(submit_sm_pdu), message_id='foo', command_status='ESME_RTHROTTLED')) [logmsg] = lc.logs self.assertEqual(logmsg['logLevel'], logging.INFO) for l in lc.logs: self.assertEqual(l['system'], 'sphex') self.clock.advance(transport_config.throttle_delay) submit_sm_pdu_retry = yield self.fake_smsc.await_pdu() yield self.fake_smsc.handle_pdu( SubmitSMResp(sequence_number=seq_no(submit_sm_pdu_retry), message_id='bar', command_status='ESME_ROK')) self.assertTrue(seq_no(submit_sm_pdu_retry) > seq_no(submit_sm_pdu)) self.assertEqual(short_message(submit_sm_pdu), 'hello world') self.assertEqual(short_message(submit_sm_pdu_retry), 'hello world') [event] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(event['event_type'], 'ack') self.assertEqual(event['user_message_id'], msg['message_id']) # We're still throttled until our next attempt to unthrottle finds no # messages to retry. After a non-throttle submit_sm_resp, that happens # with no delay. with LogCatcher(message="No longer throttling outbound") as lc: self.clock.advance(0) [logmsg] = lc.logs self.assertEqual(logmsg['logLevel'], logging.INFO) for l in lc.logs: self.assertEqual(l['system'], 'sphex') @inlineCallbacks def test_mt_sms_multipart_throttled(self): """ When parts of a multipart message are throttled, we retry only those PDUs. """ transport = yield self.get_transport({ 'submit_short_message_processor_config': { 'send_multipart_udh': True, } }) transport_config = transport.get_static_config() msg = self.tx_helper.make_outbound('a' * 350) # Three parts. yield self.tx_helper.dispatch_outbound(msg) [pdu1, pdu2, pdu3] = yield self.fake_smsc.await_pdus(3) self.assertEqual(short_message(pdu1)[4:6], "\x03\x01") self.assertEqual(short_message(pdu2)[4:6], "\x03\x02") self.assertEqual(short_message(pdu3)[4:6], "\x03\x03") # Let two parts through. yield self.fake_smsc.submit_sm_resp(pdu1) yield self.fake_smsc.submit_sm_resp(pdu2) self.assertEqual(transport.throttled, False) # Throttle the third part. yield self.fake_smsc.submit_sm_resp( pdu3, command_status='ESME_RTHROTTLED') self.assertEqual(transport.throttled, True) self.clock.advance(transport_config.throttle_delay) retry_pdu = yield self.fake_smsc.await_pdu() # Assume nothing else is incrementing seuqnce numbers. self.assertEqual(seq_no(retry_pdu), seq_no(pdu3) + 1) # The retry should be identical to pdu3 except for the sequence number. pdu3_retry = dict((k, v.copy()) for k, v in pdu3.iteritems()) pdu3_retry['header']['sequence_number'] = seq_no(retry_pdu) self.assertEqual(retry_pdu, pdu3_retry) # Let the retry through. yield self.fake_smsc.submit_sm_resp(retry_pdu) [event] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(event['event_type'], 'ack') self.assertEqual(event['user_message_id'], msg['message_id']) self.assertEqual(transport.throttled, True) # Prod the clock to notice there are no more retries and unthrottle. 
self.clock.advance(0) self.assertEqual(transport.throttled, False) @inlineCallbacks def test_mt_sms_throttle_while_throttled(self): transport = yield self.get_transport() transport_config = transport.get_static_config() msg1 = self.tx_helper.make_outbound('hello world 1') msg2 = self.tx_helper.make_outbound('hello world 2') yield self.tx_helper.dispatch_outbound(msg1) yield self.tx_helper.dispatch_outbound(msg2) [ssm_pdu1, ssm_pdu2] = yield self.fake_smsc.await_pdus(2) yield self.fake_smsc.handle_pdu( SubmitSMResp(sequence_number=seq_no(ssm_pdu1), message_id='foo1', command_status='ESME_RTHROTTLED')) yield self.fake_smsc.handle_pdu( SubmitSMResp(sequence_number=seq_no(ssm_pdu2), message_id='foo2', command_status='ESME_RTHROTTLED')) # Advance clock, still throttled. self.clock.advance(transport_config.throttle_delay) ssm_pdu1_retry1 = yield self.fake_smsc.await_pdu() yield self.fake_smsc.handle_pdu( SubmitSMResp(sequence_number=seq_no(ssm_pdu1_retry1), message_id='bar1', command_status='ESME_RTHROTTLED')) # Advance clock, message no longer throttled. self.clock.advance(transport_config.throttle_delay) ssm_pdu2_retry1 = yield self.fake_smsc.await_pdu() yield self.fake_smsc.handle_pdu( SubmitSMResp(sequence_number=seq_no(ssm_pdu2_retry1), message_id='bar2', command_status='ESME_ROK')) # Prod clock, message no longer throttled. self.clock.advance(0) ssm_pdu1_retry2 = yield self.fake_smsc.await_pdu() yield self.fake_smsc.handle_pdu( SubmitSMResp(sequence_number=seq_no(ssm_pdu1_retry2), message_id='baz1', command_status='ESME_ROK')) self.assertEqual(short_message(ssm_pdu1), 'hello world 1') self.assertEqual(short_message(ssm_pdu2), 'hello world 2') self.assertEqual(short_message(ssm_pdu1_retry1), 'hello world 1') self.assertEqual(short_message(ssm_pdu2_retry1), 'hello world 2') self.assertEqual(short_message(ssm_pdu1_retry2), 'hello world 1') [event2, event1] = yield self.tx_helper.wait_for_dispatched_events(2) self.assertEqual(event1['event_type'], 'ack') self.assertEqual(event1['user_message_id'], msg1['message_id']) self.assertEqual(event2['event_type'], 'ack') self.assertEqual(event2['user_message_id'], msg2['message_id']) @inlineCallbacks def test_mt_sms_reconnect_while_throttled(self): """ If we reconnect while throttled, we don't try to unthrottle before the connection is in a suitable state. """ transport = yield self.get_transport(bind=False) yield self.fake_smsc.bind() transport_config = transport.get_static_config() msg = self.tx_helper.make_outbound('hello world') yield self.tx_helper.dispatch_outbound(msg) ssm_pdu = yield self.fake_smsc.await_pdu() yield self.fake_smsc.handle_pdu( SubmitSMResp(sequence_number=seq_no(ssm_pdu), message_id='foo1', command_status='ESME_RTHROTTLED')) # Drop SMPP connection and check throttling. yield self.fake_smsc.disconnect() with LogCatcher(message="Can't check throttling while unbound") as lc: self.clock.advance(transport_config.throttle_delay) [logmsg] = lc.logs self.assertEqual( logmsg["message"][0], "Can't check throttling while unbound, trying later.") for l in lc.logs: self.assertEqual(l['system'], 'sphex') # Fast-forward to reconnect (but don't bind) and check throttling. 
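# (transport.service.delay is taken here to be the client service's current reconnect delay, so advancing the clock by it should trigger the reconnect attempt.)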
self.clock.advance(transport.service.delay) bind_pdu = yield self.fake_smsc.await_pdu() self.assertTrue( bind_pdu["header"]["command_id"].startswith("bind_")) with LogCatcher(message="Can't check throttling while unbound") as lc: self.clock.advance(transport_config.throttle_delay) [logmsg] = lc.logs self.assertEqual( logmsg["message"][0], "Can't check throttling while unbound, trying later.") for l in lc.logs: self.assertEqual(l['system'], 'sphex') # Bind and check throttling. yield self.fake_smsc.bind(bind_pdu) with LogCatcher(message="Can't check throttling while unbound") as lc: self.clock.advance(transport_config.throttle_delay) self.assertEqual(lc.logs, []) ssm_pdu_retry = yield self.fake_smsc.await_pdu() yield self.fake_smsc.handle_pdu( SubmitSMResp(sequence_number=seq_no(ssm_pdu_retry), message_id='foo', command_status='ESME_ROK')) self.assertEqual(short_message(ssm_pdu), 'hello world') self.assertEqual(short_message(ssm_pdu_retry), 'hello world') [event] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(event['event_type'], 'ack') self.assertEqual(event['user_message_id'], msg['message_id']) @inlineCallbacks def test_mt_sms_tps_limits(self): transport = yield self.get_transport({'mt_tps': 2}) with LogCatcher(message="Throttling outbound messages.") as lc: yield self.tx_helper.make_dispatch_outbound('hello world 1') yield self.tx_helper.make_dispatch_outbound('hello world 2') msg3_d = self.tx_helper.make_dispatch_outbound('hello world 3') [logmsg] = lc.logs self.assertEqual(logmsg['logLevel'], logging.INFO) for l in lc.logs: self.assertEqual(l['system'], 'sphex') self.assertTrue(transport.throttled) [submit_sm_pdu1, submit_sm_pdu2] = yield self.fake_smsc.await_pdus(2) self.assertEqual(short_message(submit_sm_pdu1), 'hello world 1') self.assertEqual(short_message(submit_sm_pdu2), 'hello world 2') self.assertNoResult(msg3_d) with LogCatcher(message="No longer throttling outbound") as lc: self.clock.advance(1) [logmsg] = lc.logs self.assertEqual(logmsg['logLevel'], logging.INFO) for l in lc.logs: self.assertEqual(l['system'], 'sphex') self.assertFalse(transport.throttled) yield msg3_d submit_sm_pdu3 = yield self.fake_smsc.await_pdu() self.assertEqual(short_message(submit_sm_pdu3), 'hello world 3') @inlineCallbacks def test_mt_sms_tps_limits_multipart(self): """ TPS throttling counts PDUs, but finishes sending the current message. 
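With mt_tps=3 and the two-part messages below, the second message's final PDU takes the count past the limit but is still sent; only the third message is deferred to the next second.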
""" transport = yield self.get_transport({ 'mt_tps': 3, 'submit_short_message_processor_config': { 'send_multipart_udh': True, }, }) self.assertEqual(transport.throttled, False) with LogCatcher(message="Throttling outbound messages.") as lc: yield self.tx_helper.make_dispatch_outbound('1' * 200 + 'a') yield self.tx_helper.make_dispatch_outbound('2' * 200 + 'b') msg3_d = self.tx_helper.make_dispatch_outbound('3' * 200 + 'c') [logmsg] = lc.logs self.assertEqual(logmsg['logLevel'], logging.INFO) for l in lc.logs: self.assertEqual(l['system'], 'sphex') self.assertEqual(transport.throttled, True) [pdu1_1, pdu1_2, pdu2_1, pdu2_2] = yield self.fake_smsc.await_pdus(4) self.assertEqual(short_message(pdu1_1)[-5:], '11111') self.assertEqual(short_message(pdu1_2)[-5:], '1111a') self.assertEqual(short_message(pdu2_1)[-5:], '22222') self.assertEqual(short_message(pdu2_2)[-5:], '2222b') self.assertNoResult(msg3_d) with LogCatcher(message="No longer throttling outbound") as lc: self.clock.advance(1) [logmsg] = lc.logs self.assertEqual(logmsg['logLevel'], logging.INFO) self.assertEqual(transport.throttled, False) for l in lc.logs: self.assertEqual(l['system'], 'sphex') yield msg3_d [pdu3_1, pdu3_2] = yield self.fake_smsc.await_pdus(2) self.assertEqual(short_message(pdu3_1)[-5:], '33333') self.assertEqual(short_message(pdu3_2)[-5:], '3333c') @inlineCallbacks def test_mt_sms_reconnect_while_tps_throttled(self): """ If we reconnect while throttled due to the tps limit, we don't try to unthrottle before the connection is in a suitable state. """ transport = yield self.get_transport({'mt_tps': 2}) with LogCatcher(message="Throttling outbound messages.") as lc: yield self.tx_helper.make_dispatch_outbound('hello world 1') yield self.tx_helper.make_dispatch_outbound('hello world 2') msg3_d = self.tx_helper.make_dispatch_outbound('hello world 3') [logmsg] = lc.logs self.assertEqual(logmsg['logLevel'], logging.INFO) for l in lc.logs: self.assertEqual(l['system'], 'sphex') self.assertTrue(transport.throttled) [submit_sm_pdu1, submit_sm_pdu2] = yield self.fake_smsc.await_pdus(2) self.assertEqual(short_message(submit_sm_pdu1), 'hello world 1') self.assertEqual(short_message(submit_sm_pdu2), 'hello world 2') self.assertNoResult(msg3_d) # Drop SMPP connection and check throttling. yield self.fake_smsc.disconnect() with LogCatcher(message="Can't stop throttling while unbound") as lc: self.clock.advance(1) [logmsg] = lc.logs self.assertEqual(logmsg['logLevel'], logging.INFO) self.assertTrue(transport.throttled) for l in lc.logs: self.assertEqual(l['system'], 'sphex') # Fast-forward to reconnect (but don't bind) and check throttling. self.clock.advance(transport.service.delay) bind_pdu = yield self.fake_smsc.await_pdu() self.assertTrue( bind_pdu["header"]["command_id"].startswith("bind_")) with LogCatcher(message="Can't stop throttling while unbound") as lc: self.clock.advance(1) [logmsg] = lc.logs self.assertEqual(logmsg['logLevel'], logging.INFO) self.assertTrue(transport.throttled) for l in lc.logs: self.assertEqual(l['system'], 'sphex') # Bind and check throttling. 
yield self.fake_smsc.bind(bind_pdu) with LogCatcher(message="No longer throttling outbound") as lc: self.clock.advance(1) [logmsg] = lc.logs self.assertEqual(logmsg['logLevel'], logging.INFO) for l in lc.logs: self.assertEqual(l['system'], 'sphex') self.assertFalse(transport.throttled) submit_sm_pdu3 = yield self.fake_smsc.await_pdu() self.assertEqual(short_message(submit_sm_pdu3), 'hello world 3') @inlineCallbacks def test_mt_sms_queue_full(self): transport = yield self.get_transport() transport_config = transport.get_static_config() msg = self.tx_helper.make_outbound('hello world') yield self.tx_helper.dispatch_outbound(msg) submit_sm_pdu = yield self.fake_smsc.await_pdu() yield self.fake_smsc.handle_pdu( SubmitSMResp(sequence_number=seq_no(submit_sm_pdu), message_id='foo', command_status='ESME_RMSGQFUL')) self.clock.advance(transport_config.throttle_delay) submit_sm_pdu_retry = yield self.fake_smsc.await_pdu() yield self.fake_smsc.handle_pdu( SubmitSMResp(sequence_number=seq_no(submit_sm_pdu_retry), message_id='bar', command_status='ESME_ROK')) self.assertTrue(seq_no(submit_sm_pdu_retry) > seq_no(submit_sm_pdu)) self.assertEqual(short_message(submit_sm_pdu), 'hello world') self.assertEqual(short_message(submit_sm_pdu_retry), 'hello world') [event] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(event['event_type'], 'ack') self.assertEqual(event['user_message_id'], msg['message_id']) @inlineCallbacks def test_mt_sms_remote_id_stored_only_on_rok(self): transport = yield self.get_transport() yield self.tx_helper.make_dispatch_outbound("msg1") submit_sm1 = yield self.fake_smsc.await_pdu() response = SubmitSMResp( seq_no(submit_sm1), "remote_1", command_status="ESME_RSUBMITFAIL") self.fake_smsc.send_pdu(response) yield self.tx_helper.make_dispatch_outbound("msg2") submit_sm2 = yield self.fake_smsc.await_pdu() response = SubmitSMResp( seq_no(submit_sm2), "remote_2", command_status="ESME_ROK") self.fake_smsc.send_pdu(response) yield self.tx_helper.wait_for_dispatched_events(2) self.assertFalse( (yield transport.redis.exists(remote_message_key('remote_1')))) self.assertTrue( (yield transport.redis.exists(remote_message_key('remote_2')))) @inlineCallbacks def test_mt_sms_unicode(self): yield self.get_transport() msg = self.tx_helper.make_outbound(u'Zoë') yield self.tx_helper.dispatch_outbound(msg) pdu = yield self.fake_smsc.await_pdu() self.assertEqual(command_id(pdu), 'submit_sm') self.assertEqual(short_message(pdu), 'Zo\xc3\xab') @inlineCallbacks def test_mt_sms_multipart_long(self): yield self.get_transport({ 'submit_short_message_processor_config': { 'send_long_messages': True, } }) # SMPP specifies that messages longer than 254 bytes should # be put in the message_payload field using TLVs content = '1' * 255 msg = self.tx_helper.make_outbound(content) yield self.tx_helper.dispatch_outbound(msg) submit_sm = yield self.fake_smsc.await_pdu() self.assertEqual(pdu_tlv(submit_sm, 'message_payload').decode('hex'), content) @inlineCallbacks def test_mt_sms_multipart_udh(self): """ Sufficiently long messages are split into multiple PDUs with a UDH at the front of each.
""" transport = yield self.get_transport({ 'submit_short_message_processor_config': { 'send_multipart_udh': True, } }) content = '1' * 161 msg = self.tx_helper.make_outbound(content) yield self.tx_helper.dispatch_outbound(msg) [submit_sm1, submit_sm2] = yield self.fake_smsc.await_pdus(2) self.assertEqual( submit_sm1["body"]["mandatory_parameters"]["esm_class"], 0x40) self.assertEqual( submit_sm2["body"]["mandatory_parameters"]["esm_class"], 0x40) udh_hlen, udh_tag, udh_len, udh_ref, udh_tot, udh_seq = [ ord(octet) for octet in short_message(submit_sm1)[:6]] self.assertEqual(5, udh_hlen) self.assertEqual(0, udh_tag) self.assertEqual(3, udh_len) self.assertEqual(udh_tot, 2) self.assertEqual(udh_seq, 1) _, _, _, ref_to_udh_ref, _, udh_seq = [ ord(octet) for octet in short_message(submit_sm2)[:6]] self.assertEqual(ref_to_udh_ref, udh_ref) self.assertEqual(udh_seq, 2) # Our multipart_info Redis hash should contain the number of parts and # have an appropriate TTL. mstash = transport.message_stash multipart_info = yield mstash.get_multipart_info(msg['message_id']) self.assertEqual(multipart_info, {"parts": "2"}) mpi_ttl = yield mstash.redis.ttl(multipart_info_key(msg['message_id'])) self.assertTrue( mpi_ttl <= mstash.config.submit_sm_expiry, "mpi_ttl (%s) > submit_sm_expiry (%s)" % ( mpi_ttl, mstash.config.submit_sm_expiry)) @inlineCallbacks def test_mt_sms_multipart_udh_one_part(self): """ Messages that fit in a single part should not have a UDH. """ yield self.get_transport({ 'submit_short_message_processor_config': { 'send_multipart_udh': True, } }) content = "1" * 158 msg = self.tx_helper.make_outbound(content) yield self.tx_helper.dispatch_outbound(msg) submit_sm = yield self.fake_smsc.await_pdu() self.assertEqual( submit_sm["body"]["mandatory_parameters"]["esm_class"], 0) self.assertEqual(short_message(submit_sm), "1" * 158) @inlineCallbacks def test_mt_sms_multipart_sar(self): yield self.get_transport({ 'submit_short_message_processor_config': { 'send_multipart_sar': True, } }) content = '1' * 161 msg = self.tx_helper.make_outbound(content) yield self.tx_helper.dispatch_outbound(msg) [submit_sm1, submit_sm2] = yield self.fake_smsc.await_pdus(2) ref_num = pdu_tlv(submit_sm1, 'sar_msg_ref_num') self.assertEqual(pdu_tlv(submit_sm1, 'sar_total_segments'), 2) self.assertEqual(pdu_tlv(submit_sm1, 'sar_segment_seqnum'), 1) self.assertEqual(pdu_tlv(submit_sm2, 'sar_msg_ref_num'), ref_num) self.assertEqual(pdu_tlv(submit_sm2, 'sar_total_segments'), 2) self.assertEqual(pdu_tlv(submit_sm2, 'sar_segment_seqnum'), 2) @inlineCallbacks def test_mt_sms_multipart_sar_one_part(self): """ Messages that fit in a single part should not have SAR params set. """ yield self.get_transport({ 'submit_short_message_processor_config': { 'send_multipart_sar': True, } }) content = '1' * 158 msg = self.tx_helper.make_outbound(content) yield self.tx_helper.dispatch_outbound(msg) submit_sm = yield self.fake_smsc.await_pdu() self.assertEqual(unpacked_pdu_opts(submit_sm), {}) self.assertEqual(short_message(submit_sm), "1" * 158) @inlineCallbacks def test_mt_sms_multipart_ack(self): """ When all PDUs of a multipart message have been successfully acknowledged, we clean up the relevant transient state and send an ack. 
""" transport = yield self.get_transport({ 'submit_short_message_processor_config': { 'send_multipart_udh': True, } }) content = '1' * 161 msg = self.tx_helper.make_outbound(content) yield self.tx_helper.dispatch_outbound(msg) [submit_sm1, submit_sm2] = yield self.fake_smsc.await_pdus(2) # Our multipart_info Redis hash should contain the number of parts and # have an appropriate TTL. mstash = transport.message_stash multipart_info = yield mstash.get_multipart_info(msg['message_id']) self.assertEqual(multipart_info, {"parts": "2"}) mpi_ttl = yield mstash.redis.ttl(multipart_info_key(msg['message_id'])) self.assertTrue( mpi_ttl <= mstash.config.submit_sm_expiry, "mpi_ttl (%s) > submit_sm_expiry (%s)" % ( mpi_ttl, mstash.config.submit_sm_expiry)) # We get one response per PDU, so we only send the ack after receiving # both responses. self.fake_smsc.send_pdu( SubmitSMResp(sequence_number=seq_no(submit_sm1), message_id='foo')) self.fake_smsc.send_pdu( SubmitSMResp(sequence_number=seq_no(submit_sm2), message_id='bar')) [event] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(event['event_type'], 'ack') self.assertEqual(event['user_message_id'], msg['message_id']) self.assertEqual(event['sent_message_id'], 'bar,foo') # After all parts are acknowledged, our multipart_info hash should have # the details of the responses and a much shorter TTL. mstash = transport.message_stash multipart_info = yield mstash.get_multipart_info(msg['message_id']) self.assertEqual(multipart_info, { "parts": "2", "event_counter": "2", "part:foo": "ack", "part:bar": "ack", }) mpi_ttl = yield mstash.redis.ttl(multipart_info_key(msg['message_id'])) self.assertTrue( mpi_ttl <= mstash.config.completed_multipart_info_expiry, "mpi_ttl (%s) > completed_multipart_info_expiry (%s)" % ( mpi_ttl, mstash.config.completed_multipart_info_expiry)) @inlineCallbacks def test_mt_sms_multipart_fail_first_part(self): """ When all PDUs of a multipart message have been acknowledged and at least one of them failed, we clean up the relevant transient state and send a nack. """ transport = yield self.get_transport({ 'submit_short_message_processor_config': { 'send_multipart_udh': True, } }) content = '1' * 161 msg = self.tx_helper.make_outbound(content) yield self.tx_helper.dispatch_outbound(msg) [submit_sm1, submit_sm2] = yield self.fake_smsc.await_pdus(2) # Our multipart_info Redis hash should contain the number of parts and # have an appropriate TTL. mstash = transport.message_stash multipart_info = yield mstash.get_multipart_info(msg['message_id']) self.assertEqual(multipart_info, {"parts": "2"}) mpi_ttl = yield mstash.redis.ttl(multipart_info_key(msg['message_id'])) self.assertTrue( mpi_ttl <= mstash.config.submit_sm_expiry, "mpi_ttl (%s) > submit_sm_expiry (%s)" % ( mpi_ttl, mstash.config.submit_sm_expiry)) # We get one response per PDU, so we only send the nack after receiving # both responses. self.fake_smsc.send_pdu( SubmitSMResp(sequence_number=seq_no(submit_sm1), message_id='foo', command_status='ESME_RSUBMITFAIL')) self.fake_smsc.send_pdu( SubmitSMResp(sequence_number=seq_no(submit_sm2), message_id='bar')) [event] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(event['event_type'], 'nack') self.assertEqual(event['user_message_id'], msg['message_id']) # After all parts are acknowledged, our multipart_info hash should have # the details of the responses and a much shorter TTL. 
mstash = transport.message_stash multipart_info = yield mstash.get_multipart_info(msg['message_id']) self.assertEqual(multipart_info, { "parts": "2", "event_counter": "2", "part:foo": "fail", "part:bar": "ack", "event_result": "fail", }) mpi_ttl = yield mstash.redis.ttl(multipart_info_key(msg['message_id'])) self.assertTrue( mpi_ttl <= mstash.config.completed_multipart_info_expiry, "mpi_ttl (%s) > completed_multipart_info_expiry (%s)" % ( mpi_ttl, mstash.config.completed_multipart_info_expiry)) @inlineCallbacks def test_mt_sms_multipart_fail_second_part(self): yield self.get_transport({ 'submit_short_message_processor_config': { 'send_multipart_udh': True, } }) content = '1' * 161 msg = self.tx_helper.make_outbound(content) yield self.tx_helper.dispatch_outbound(msg) [submit_sm1, submit_sm2] = yield self.fake_smsc.await_pdus(2) self.fake_smsc.send_pdu( SubmitSMResp(sequence_number=seq_no(submit_sm1), message_id='foo')) self.fake_smsc.send_pdu( SubmitSMResp(sequence_number=seq_no(submit_sm2), message_id='bar', command_status='ESME_RSUBMITFAIL')) [event] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(event['event_type'], 'nack') self.assertEqual(event['user_message_id'], msg['message_id']) @inlineCallbacks def test_mt_sms_multipart_fail_no_remote_id(self): yield self.get_transport({ 'submit_short_message_processor_config': { 'send_multipart_udh': True, } }) content = '1' * 161 msg = self.tx_helper.make_outbound(content) yield self.tx_helper.dispatch_outbound(msg) [submit_sm1, submit_sm2] = yield self.fake_smsc.await_pdus(2) self.fake_smsc.send_pdu( SubmitSMResp(sequence_number=seq_no(submit_sm1), message_id='', command_status='ESME_RINVDSTADR')) self.fake_smsc.send_pdu( SubmitSMResp(sequence_number=seq_no(submit_sm2), message_id='', command_status='ESME_RINVDSTADR')) [event] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(event['event_type'], 'nack') self.assertEqual(event['user_message_id'], msg['message_id']) @inlineCallbacks def test_message_persistence(self): transport = yield self.get_transport() message_stash = transport.message_stash config = transport.get_static_config() msg = self.tx_helper.make_outbound("hello world") yield message_stash.cache_message(msg) ttl = yield transport.redis.ttl(message_key(msg['message_id'])) self.assertTrue(0 < ttl <= config.submit_sm_expiry) retrieved_msg = yield message_stash.get_cached_message( msg['message_id']) self.assertEqual(msg, retrieved_msg) yield message_stash.delete_cached_message(msg['message_id']) self.assertEqual( (yield message_stash.get_cached_message(msg['message_id'])), None) @inlineCallbacks def test_message_clearing(self): transport = yield self.get_transport() message_stash = transport.message_stash msg = self.tx_helper.make_outbound('hello world') yield message_stash.set_sequence_number_message_id( 3, msg['message_id']) yield message_stash.cache_message(msg) yield self.fake_smsc.handle_pdu(SubmitSMResp( sequence_number=3, message_id='foo', command_status='ESME_ROK')) self.assertEqual( None, (yield message_stash.get_cached_message(msg['message_id']))) @inlineCallbacks def test_sequence_number_persistence(self): """ We create sequence_number to message_id mappings with an appropriate TTL and can delete them when we're done. 
""" transport = yield self.get_transport() message_stash = transport.message_stash config = transport.get_static_config() yield message_stash.set_sequence_number_message_id(12, "abc") ttl = yield transport.redis.ttl(sequence_number_key(12)) self.assertTrue(0 < ttl <= config.submit_sm_expiry) message_id = yield message_stash.get_sequence_number_message_id(12) self.assertEqual(message_id, "abc") yield message_stash.delete_sequence_number_message_id(12) message_id = yield message_stash.get_sequence_number_message_id(12) self.assertEqual(message_id, None) @inlineCallbacks def test_sequence_number_clearing(self): """ When we finish processing a PDU response, the mapping gets deleted. """ transport = yield self.get_transport() message_stash = transport.message_stash yield message_stash.set_sequence_number_message_id(37, "def") message_id = yield message_stash.get_sequence_number_message_id(37) self.assertEqual(message_id, "def") yield self.fake_smsc.handle_pdu(SubmitSMResp( sequence_number=37, message_id='foo', command_status='ESME_ROK')) message_id = yield message_stash.get_sequence_number_message_id(37) self.assertEqual(message_id, None) @inlineCallbacks def test_link_remote_message_id(self): transport = yield self.get_transport() config = transport.get_static_config() msg = self.tx_helper.make_outbound('hello world') yield self.tx_helper.dispatch_outbound(msg) pdu = yield self.fake_smsc.await_pdu() yield self.fake_smsc.handle_pdu( SubmitSMResp(sequence_number=seq_no(pdu), message_id='foo', command_status='ESME_ROK')) self.assertEqual( msg['message_id'], (yield transport.message_stash.get_internal_message_id('foo'))) ttl = yield transport.redis.ttl(remote_message_key('foo')) self.assertTrue(0 < ttl <= config.third_party_id_expiry) @inlineCallbacks def test_out_of_order_responses(self): yield self.get_transport() yield self.tx_helper.make_dispatch_outbound("msg 1", message_id='444') submit_sm1 = yield self.fake_smsc.await_pdu() response1 = SubmitSMResp(seq_no(submit_sm1), "3rd_party_id_1") yield self.tx_helper.make_dispatch_outbound("msg 2", message_id='445') submit_sm2 = yield self.fake_smsc.await_pdu() response2 = SubmitSMResp(seq_no(submit_sm2), "3rd_party_id_2") # respond out of order - just to keep things interesting yield self.fake_smsc.handle_pdu(response2) yield self.fake_smsc.handle_pdu(response1) [ack1, ack2] = yield self.tx_helper.wait_for_dispatched_events(2) self.assertEqual(ack1['user_message_id'], '445') self.assertEqual(ack1['sent_message_id'], '3rd_party_id_2') self.assertEqual(ack2['user_message_id'], '444') self.assertEqual(ack2['sent_message_id'], '3rd_party_id_1') @inlineCallbacks def test_delivery_report_for_unknown_message(self): dr = self.DR_TEMPLATE % ('foo',) deliver = DeliverSM(1, short_message=dr, esm_class=4) yield self.get_transport() with LogCatcher(message="Failed to retrieve message id") as lc: yield self.fake_smsc.handle_pdu(deliver) [warning] = lc.logs self.assertEqual(warning['message'], ("Failed to retrieve message id for delivery " "report. Delivery report from %s " "discarded." 
% self.tx_helper.transport_name,)) @inlineCallbacks def test_delivery_report_delivered_delete_stored_remote_id(self): transport = yield self.get_transport({ 'final_dr_third_party_id_expiry': 23, }) yield transport.message_stash.set_remote_message_id('bar', 'foo') remote_id_ttl = yield transport.redis.ttl(remote_message_key('foo')) self.assertTrue( remote_id_ttl > 23, "remote_id_ttl (%s) <= final_dr_third_party_id_expiry (23)" % (remote_id_ttl,)) pdu = DeliverSM(sequence_number=1, esm_class=4) pdu.add_optional_parameter('receipted_message_id', 'foo') pdu.add_optional_parameter('message_state', 2) yield self.fake_smsc.handle_pdu(pdu) [dr] = yield self.tx_helper.wait_for_dispatched_events(1) remote_id_ttl = yield transport.redis.ttl(remote_message_key('foo')) self.assertTrue( remote_id_ttl <= 23, "remote_id_ttl (%s) > final_dr_third_party_id_expiry (23)" % (remote_id_ttl,)) self.assertEqual(dr['event_type'], u'delivery_report') self.assertEqual(dr['delivery_status'], u'delivered') self.assertEqual(dr['transport_metadata'], { u'smpp_delivery_status': u'DELIVERED', }) @inlineCallbacks def test_delivery_report_failed_delete_stored_remote_id(self): transport = yield self.get_transport({ 'final_dr_third_party_id_expiry': 23, }) yield transport.message_stash.set_remote_message_id('bar', 'foo') remote_id_ttl = yield transport.redis.ttl(remote_message_key('foo')) self.assertTrue( remote_id_ttl > 23, "remote_id_ttl (%s) <= final_dr_third_party_id_expiry (23)" % (remote_id_ttl,)) pdu = DeliverSM(sequence_number=1, esm_class=4) pdu.add_optional_parameter('receipted_message_id', 'foo') pdu.add_optional_parameter('message_state', 8) yield self.fake_smsc.handle_pdu(pdu) [dr] = yield self.tx_helper.wait_for_dispatched_events(1) remote_id_ttl = yield transport.redis.ttl(remote_message_key('foo')) self.assertTrue( remote_id_ttl <= 23, "remote_id_ttl (%s) > final_dr_third_party_id_expiry (23)" % (remote_id_ttl,)) self.assertEqual(dr['event_type'], u'delivery_report') self.assertEqual(dr['delivery_status'], u'failed') self.assertEqual(dr['transport_metadata'], { u'smpp_delivery_status': u'REJECTED', }) @inlineCallbacks def test_delivery_report_pending_keep_stored_remote_id(self): transport = yield self.get_transport({ 'final_dr_third_party_id_expiry': 23, }) yield transport.message_stash.set_remote_message_id('bar', 'foo') remote_id_ttl = yield transport.redis.ttl(remote_message_key('foo')) self.assertTrue( remote_id_ttl > 23, "remote_id_ttl (%s) <= final_dr_third_party_id_expiry (23)" % (remote_id_ttl,)) pdu = DeliverSM(sequence_number=1, esm_class=4) pdu.add_optional_parameter('receipted_message_id', 'foo') pdu.add_optional_parameter('message_state', 1) yield self.fake_smsc.handle_pdu(pdu) [dr] = yield self.tx_helper.wait_for_dispatched_events(1) remote_id_ttl = yield transport.redis.ttl(remote_message_key('foo')) self.assertTrue( remote_id_ttl > 23, "remote_id_ttl (%s) <= final_dr_third_party_id_expiry (23)" % (remote_id_ttl,)) self.assertEqual(dr['event_type'], u'delivery_report') self.assertEqual(dr['delivery_status'], u'pending') self.assertEqual(dr['transport_metadata'], { u'smpp_delivery_status': u'ENROUTE', }) @inlineCallbacks def test_disable_delivery_report_delivered_delete_stored_remote_id(self): transport = yield self.get_transport({ 'final_dr_third_party_id_expiry': 23, 'disable_delivery_report': True, }) yield transport.message_stash.set_remote_message_id('bar', 'foo') remote_id_ttl = yield transport.redis.ttl(remote_message_key('foo')) self.assertTrue( remote_id_ttl > 23, "remote_id_ttl (%s) <= 
final_dr_third_party_id_expiry (23)" % (remote_id_ttl,)) pdu = DeliverSM(sequence_number=1, esm_class=4) pdu.add_optional_parameter('receipted_message_id', 'foo') pdu.add_optional_parameter('message_state', 2) yield self.fake_smsc.handle_pdu(pdu) yield self.fake_smsc.await_pdu() yield self.assert_no_events() remote_id_ttl = yield transport.redis.ttl(remote_message_key('foo')) self.assertTrue( remote_id_ttl <= 23, "remote_id_ttl (%s) > final_dr_third_party_id_expiry (23)" % (remote_id_ttl,)) @inlineCallbacks def test_reconnect(self): transport = yield self.get_transport(bind=False) connector = transport.connectors[transport.transport_name] # Unbound and disconnected. self.assertEqual(connector._consumers['outbound'].paused, True) # Connect and bind. yield self.fake_smsc.bind() self.assertEqual(connector._consumers['outbound'].paused, False) # Disconnect. yield self.fake_smsc.disconnect() self.assertEqual(connector._consumers['outbound'].paused, True) # Wait for reconnect, but don't bind. self.clock.advance(transport.service.delay) yield self.fake_smsc.await_connected() self.assertEqual(connector._consumers['outbound'].paused, True) # Bind. yield self.fake_smsc.bind() self.assertEqual(connector._consumers['outbound'].paused, False) @inlineCallbacks def test_bind_params(self): yield self.get_transport({ 'system_id': 'myusername', 'password': 'mypasswd', 'system_type': 'SMPP', 'interface_version': '33', 'address_range': '*12345', }, bind=False) bind_pdu = yield self.fake_smsc.await_pdu() # This test runs for multiple bind types, so we only assert on the # common prefix of the command. self.assertEqual(bind_pdu['header']['command_id'][:5], 'bind_') self.assertEqual(bind_pdu['body'], {'mandatory_parameters': { 'system_id': 'myusername', 'password': 'mypasswd', 'system_type': 'SMPP', 'interface_version': '33', 'address_range': '*12345', 'addr_ton': 'unknown', 'addr_npi': 'unknown', }}) @inlineCallbacks def test_bind_params_long_password(self): lc = LogCatcher(message="Password longer than 8 characters,") with lc: yield self.get_transport({ 'worker_name': 'sphex', 'system_id': 'myusername', 'password': 'mypass789', 'system_type': 'SMPP', 'interface_version': '33', 'address_range': '*12345', }, bind=False) bind_pdu = yield self.fake_smsc.await_pdu() # This test runs for multiple bind types, so we only assert on the # common prefix of the command. self.assertEqual(bind_pdu['header']['command_id'][:5], 'bind_') self.assertEqual(bind_pdu['body'], {'mandatory_parameters': { 'system_id': 'myusername', 'password': 'mypass78', 'system_type': 'SMPP', 'interface_version': '33', 'address_range': '*12345', 'addr_ton': 'unknown', 'addr_npi': 'unknown', }}) # Check that the truncation was logged. [warning] = lc.logs expected_msg = "Password longer than 8 characters, truncating." self.assertEqual(warning['message'], (expected_msg,)) for l in lc.logs: self.assertEqual(l['system'], 'sphex') @inlineCallbacks def test_default_bind_params(self): yield self.get_transport(bind=False) bind_pdu = yield self.fake_smsc.await_pdu() # This test runs for multiple bind types, so we only assert on the # common prefix of the command. self.assertEqual(bind_pdu['header']['command_id'][:5], 'bind_') self.assertEqual(bind_pdu['body'], {'mandatory_parameters': { 'system_id': 'foo', # Mandatory param, defaulted by helper. 'password': 'bar', # Mandatory param, defaulted by helper. 
'system_type': '', 'interface_version': '34', 'address_range': '', 'addr_ton': 'unknown', 'addr_npi': 'unknown', }}) @inlineCallbacks def test_startup_with_backlog(self): yield self.get_transport(bind=False) # Disconnected. for i in range(2): msg = self.tx_helper.make_outbound('hello world %s' % (i,)) yield self.tx_helper.dispatch_outbound(msg) # Connect and bind. yield self.fake_smsc.bind() [submit_sm1, submit_sm2] = yield self.fake_smsc.await_pdus(2) self.assertEqual(short_message(submit_sm1), 'hello world 0') self.assertEqual(short_message(submit_sm2), 'hello world 1') @inlineCallbacks def test_starting_status(self): """ The SMPP bind process emits three status events. """ yield self.get_transport({'publish_status': True}) msgs = yield self.tx_helper.wait_for_dispatched_statuses() [msg_starting, msg_binding, msg_bound] = msgs self.assertEqual(msg_starting['status'], 'down') self.assertEqual(msg_starting['component'], 'smpp') self.assertEqual(msg_starting['type'], 'starting') self.assertEqual(msg_starting['message'], 'Starting') self.assertEqual(msg_binding['status'], 'down') self.assertEqual(msg_binding['component'], 'smpp') self.assertEqual(msg_binding['type'], 'binding') self.assertEqual(msg_binding['message'], 'Binding') self.assertEqual(msg_bound['status'], 'ok') self.assertEqual(msg_bound['component'], 'smpp') self.assertEqual(msg_bound['type'], 'bound') self.assertEqual(msg_bound['message'], 'Bound') @inlineCallbacks def test_connect_status(self): transport = yield self.get_transport( {'publish_status': True}, bind=False) # disconnect yield self.fake_smsc.disconnect() self.tx_helper.clear_dispatched_statuses() # reconnect self.clock.advance(transport.service.delay) yield self.fake_smsc.await_connected() [msg] = yield self.tx_helper.wait_for_dispatched_statuses() self.assertEqual(msg['status'], 'down') self.assertEqual(msg['component'], 'smpp') self.assertEqual(msg['type'], 'binding') self.assertEqual(msg['message'], 'Binding') @inlineCallbacks def test_unbinding_status(self): transport = yield self.get_transport({'publish_status': True}) self.tx_helper.clear_dispatched_statuses() yield transport.service.get_protocol().unbind() [msg] = yield self.tx_helper.wait_for_dispatched_statuses() self.assertEqual(msg['status'], 'down') self.assertEqual(msg['component'], 'smpp') self.assertEqual(msg['type'], 'unbinding') self.assertEqual(msg['message'], 'Unbinding') @inlineCallbacks def test_bind_status(self): yield self.get_transport({'publish_status': True}, bind=False) self.tx_helper.clear_dispatched_statuses() yield self.fake_smsc.bind() [msg] = yield self.tx_helper.wait_for_dispatched_statuses() self.assertEqual(msg['status'], 'ok') self.assertEqual(msg['component'], 'smpp') self.assertEqual(msg['type'], 'bound') self.assertEqual(msg['message'], 'Bound') @inlineCallbacks def test_bind_timeout_status(self): yield self.get_transport({ 'publish_status': True, 'smpp_bind_timeout': 3, }, bind=False) # wait for bind pdu yield self.fake_smsc.await_pdu() self.tx_helper.clear_dispatched_statuses() self.clock.advance(3) [msg] = yield self.tx_helper.wait_for_dispatched_statuses() self.assertEqual(msg['status'], 'down') self.assertEqual(msg['component'], 'smpp') self.assertEqual(msg['type'], 'bind_timeout') self.assertEqual(msg['message'], 'Timed out awaiting bind') yield self.fake_smsc.disconnect() @inlineCallbacks def test_connection_lost_status(self): yield self.get_transport({'publish_status': True}) self.tx_helper.clear_dispatched_statuses() yield self.fake_smsc.disconnect() [msg] = yield 
self.tx_helper.wait_for_dispatched_statuses() self.assertEqual(msg['status'], 'down') self.assertEqual(msg['component'], 'smpp') self.assertEqual(msg['type'], 'connection_lost') self.assertEqual( msg['message'], 'Connection was closed cleanly: Connection done.') @inlineCallbacks def test_smsc_throttle_status(self): yield self.get_transport({ 'publish_status': True, 'throttle_delay': 3 }) self.tx_helper.clear_dispatched_statuses() msg = self.tx_helper.make_outbound("throttle me") yield self.tx_helper.dispatch_outbound(msg) submit_sm_pdu = yield self.fake_smsc.await_pdu() yield self.fake_smsc.handle_pdu( SubmitSMResp(sequence_number=seq_no(submit_sm_pdu), message_id='foo', command_status='ESME_RTHROTTLED')) [msg] = yield self.tx_helper.wait_for_dispatched_statuses() self.assertEqual(msg['status'], 'degraded') self.assertEqual(msg['component'], 'smpp') self.assertEqual(msg['type'], 'throttled') self.assertEqual(msg['message'], 'Throttled') self.tx_helper.clear_dispatched_statuses() self.clock.advance(3) submit_sm_pdu_retry = yield self.fake_smsc.await_pdu() yield self.fake_smsc.handle_pdu( SubmitSMResp(sequence_number=seq_no(submit_sm_pdu_retry), message_id='bar', command_status='ESME_ROK')) self.clock.advance(0) [msg] = yield self.tx_helper.wait_for_dispatched_statuses() self.assertEqual(msg['status'], 'ok') self.assertEqual(msg['component'], 'smpp') self.assertEqual(msg['type'], 'throttled_end') self.assertEqual(msg['message'], 'No longer throttled') @inlineCallbacks def test_smsc_throttle_reconnect_status(self): transport = yield self.get_transport({ 'publish_status': True, }) self.tx_helper.clear_dispatched_statuses() msg = self.tx_helper.make_outbound("throttle me") yield self.tx_helper.dispatch_outbound(msg) submit_sm_pdu = yield self.fake_smsc.await_pdu() yield self.fake_smsc.handle_pdu( SubmitSMResp(sequence_number=seq_no(submit_sm_pdu), message_id='foo', command_status='ESME_RTHROTTLED')) yield self.fake_smsc.disconnect() self.tx_helper.clear_dispatched_statuses() self.clock.advance(transport.service.delay) yield self.fake_smsc.bind() msgs = yield self.tx_helper.wait_for_dispatched_statuses() [msg1, msg2, msg3] = msgs self.assertEqual(msg1['type'], 'binding') self.assertEqual(msg2['type'], 'bound') self.assertEqual(msg3['status'], 'degraded') self.assertEqual(msg3['component'], 'smpp') self.assertEqual(msg3['type'], 'throttled') self.assertEqual(msg3['message'], 'Throttled') @inlineCallbacks def test_tps_throttle_status(self): yield self.get_transport({ 'publish_status': True, 'mt_tps': 2 }) self.tx_helper.clear_dispatched_statuses() yield self.tx_helper.make_dispatch_outbound('hello world 1') yield self.tx_helper.make_dispatch_outbound('hello world 2') self.tx_helper.make_dispatch_outbound('hello world 3') yield self.fake_smsc.await_pdus(2) # We can't wait here because that requires throttling to end.
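# ---------------------------------------------------------------------
# Editor's sketch (not vumi's implementation): every status test in this
# class asserts the same four fields. A hypothetical helper showing the
# payload shape for the two throttle transitions asserted above:

def throttle_status(throttled):
    """Build the status payload for entering/leaving the throttled state."""
    if throttled:
        return {'status': 'degraded', 'component': 'smpp',
                'type': 'throttled', 'message': 'Throttled'}
    return {'status': 'ok', 'component': 'smpp',
            'type': 'throttled_end', 'message': 'No longer throttled'}

# throttle_status(True)['status'] -> 'degraded'
# throttle_status(False)['type'] -> 'throttled_end'
# ---------------------------------------------------------------------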
[msg] = self.tx_helper.get_dispatched_statuses() self.assertEqual(msg['status'], 'degraded') self.assertEqual(msg['component'], 'smpp') self.assertEqual(msg['type'], 'throttled') self.assertEqual(msg['message'], 'Throttled') self.tx_helper.clear_dispatched_statuses() self.clock.advance(1) [msg] = yield self.tx_helper.wait_for_dispatched_statuses() self.assertEqual(msg['status'], 'ok') self.assertEqual(msg['component'], 'smpp') self.assertEqual(msg['type'], 'throttled_end') self.assertEqual(msg['message'], 'No longer throttled') @inlineCallbacks def test_tps_throttle_reconnect_status(self): transport = yield self.get_transport({ 'publish_status': True, 'mt_tps': 2 }) self.tx_helper.clear_dispatched_statuses() yield self.tx_helper.make_dispatch_outbound('hello world 1') yield self.tx_helper.make_dispatch_outbound('hello world 2') self.tx_helper.make_dispatch_outbound('hello world 3') yield self.fake_smsc.await_pdus(2) yield self.fake_smsc.disconnect() self.tx_helper.clear_dispatched_statuses() self.clock.advance(transport.service.delay) yield self.fake_smsc.bind() msgs = yield self.tx_helper.wait_for_dispatched_statuses() [msg1, msg2, msg3] = msgs self.assertEqual(msg1['type'], 'binding') self.assertEqual(msg2['type'], 'bound') self.assertEqual(msg3['status'], 'degraded') self.assertEqual(msg3['component'], 'smpp') self.assertEqual(msg3['type'], 'throttled') self.assertEqual(msg3['message'], 'Throttled') class SmppTransmitterTransportTestCase(SmppTransceiverTransportTestCase): transport_class = SmppTransmitterTransport class SmppReceiverTransportTestCase(SmppTransceiverTransportTestCase): transport_class = SmppReceiverTransport class SmppTransceiverTransportWithOldConfigTestCase( SmppTransceiverTransportTestCase): transport_class = SmppTransceiverTransportWithOldConfig def setUp(self): self.clock = Clock() self.fake_smsc = FakeSMSC() self.tx_helper = self.add_helper(TransportHelper(self.transport_class)) self.default_config = { 'transport_name': self.tx_helper.transport_name, 'worker_name': self.tx_helper.transport_name, 'twisted_endpoint': self.fake_smsc.endpoint, 'system_id': 'foo', 'password': 'bar', 'data_coding_overrides': { 0: 'utf-8', } } def _get_transport_config(self, config): """ The test cases assume the new config, this flattens the config key word arguments value to match an old config layout without the processor configs. 
""" cfg = self.default_config.copy() processor_config_keys = [ 'submit_short_message_processor_config', 'deliver_short_message_processor_config', 'delivery_report_processor_config', ] for config_key in processor_config_keys: processor_config = config.pop(config_key, {}) for name, value in processor_config.items(): cfg[name] = value # Update with all remaining (non-processor) config values cfg.update(config) return cfg class TataUssdSmppTransportTestCase(SmppTransportTestCase): transport_class = SmppTransceiverTransport @inlineCallbacks def test_submit_and_deliver_ussd_continue(self): yield self.get_transport() yield self.tx_helper.make_dispatch_outbound( "hello world", transport_type="ussd") submit_sm_pdu = yield self.fake_smsc.await_pdu() self.assertEqual(command_id(submit_sm_pdu), 'submit_sm') self.assertEqual(pdu_tlv(submit_sm_pdu, 'ussd_service_op'), '02') self.assertEqual(pdu_tlv(submit_sm_pdu, 'its_session_info'), '0000') # Server delivers a USSD message to the Client pdu = DeliverSM(seq_no(submit_sm_pdu) + 1, short_message="reply!") pdu.add_optional_parameter('ussd_service_op', '02') pdu.add_optional_parameter('its_session_info', '0000') yield self.fake_smsc.handle_pdu(pdu) [mess] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(mess['content'], "reply!") self.assertEqual(mess['transport_type'], "ussd") self.assertEqual(mess['session_event'], TransportUserMessage.SESSION_RESUME) @inlineCallbacks def test_submit_and_deliver_ussd_close(self): yield self.get_transport() yield self.tx_helper.make_dispatch_outbound( "hello world", transport_type="ussd", session_event=TransportUserMessage.SESSION_CLOSE) submit_sm_pdu = yield self.fake_smsc.await_pdu() self.assertEqual(command_id(submit_sm_pdu), 'submit_sm') self.assertEqual(pdu_tlv(submit_sm_pdu, 'ussd_service_op'), '02') self.assertEqual(pdu_tlv(submit_sm_pdu, 'its_session_info'), '0001') # Server delivers a USSD message to the Client pdu = DeliverSM(seq_no(submit_sm_pdu) + 1, short_message="reply!") pdu.add_optional_parameter('ussd_service_op', '02') pdu.add_optional_parameter('its_session_info', '0001') yield self.fake_smsc.handle_pdu(pdu) [mess] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(mess['content'], "reply!") self.assertEqual(mess['transport_type'], "ussd") self.assertEqual(mess['session_event'], TransportUserMessage.SESSION_CLOSE) class TestSubmitShortMessageProcessorConfig(VumiTestCase): def get_config(self, config_dict): return SubmitShortMessageProcessor.CONFIG_CLASS(config_dict) def assert_config_error(self, config_dict): try: self.get_config(config_dict) self.fail("ConfigError not raised.") except ConfigError as err: return err.args[0] def test_long_message_params(self): self.get_config({}) self.get_config({'send_long_messages': True}) self.get_config({'send_multipart_sar': True}) self.get_config({'send_multipart_udh': True}) errmsg = self.assert_config_error({ 'send_long_messages': True, 'send_multipart_sar': True, }) self.assertEqual(errmsg, ( "The following parameters are mutually exclusive: " "send_long_messages, send_multipart_sar")) errmsg = self.assert_config_error({ 'send_long_messages': True, 'send_multipart_sar': True, 'send_multipart_udh': True, }) self.assertEqual(errmsg, ( "The following parameters are mutually exclusive: " "send_long_messages, send_multipart_sar, send_multipart_udh")) PKqGg>H>H/vumi/transports/smpp/tests/test_smpp_service.py# -*- coding: utf-8 -*- from twisted.internet.defer import inlineCallbacks, succeed, gatherResults from 
twisted.internet.task import Clock from smpp.pdu_builder import Unbind, SubmitSM from vumi.log import WrappingLogger from vumi.tests.helpers import VumiTestCase, PersistenceHelper, skiptest from vumi.transports.smpp.smpp_transport import ( SmppTransceiverTransport, SmppMessageDataStash, pdu_key) from vumi.transports.smpp.protocol import EsmeProtocol, EsmeProtocolError from vumi.transports.smpp.smpp_service import SmppService from vumi.transports.smpp.pdu_utils import ( command_id, unpacked_pdu_opts, short_message) from vumi.transports.smpp.sequence import RedisSequence from vumi.transports.smpp.tests.fake_smsc import FakeSMSC class DummySmppTransport(object): def __init__(self, clock, redis, config): self.log = WrappingLogger(system=config.get('worker_name')) self.clock = clock self.redis = redis self._config = config self._static_config = SmppTransceiverTransport.CONFIG_CLASS( self._config, static=True) config = self.get_static_config() self.transport_name = config.transport_name self.dr_processor = config.delivery_report_processor( self, config.delivery_report_processor_config) self.deliver_sm_processor = config.deliver_short_message_processor( self, config.deliver_short_message_processor_config) self.submit_sm_processor = config.submit_short_message_processor( self, config.submit_short_message_processor_config) self.sequence_generator = RedisSequence(self.redis) self.message_stash = SmppMessageDataStash(self.redis, config) self.paused = True def get_static_config(self): return self._static_config def pause_connectors(self): self.paused = True def unpause_connectors(self): self.paused = False def on_smpp_binding(self): pass def on_smpp_unbinding(self): pass def on_smpp_bind(self): pass def on_smpp_bind_timeout(self): pass def on_throttled(self): pass def on_throttled_end(self): pass class TestSmppService(VumiTestCase): @inlineCallbacks def setUp(self): self.clock = Clock() self.persistence_helper = self.add_helper(PersistenceHelper()) self.redis = yield self.persistence_helper.get_redis_manager() self.fake_smsc = FakeSMSC(auto_accept=False) self.default_config = { 'transport_name': 'sphex_transport', 'twisted_endpoint': self.fake_smsc.endpoint, 'system_id': 'system_id', 'password': 'password', } def get_service(self, config={}, bind_type='TRX', start=True): """ Create and optionally start a new service object. """ cfg = self.default_config.copy() cfg.update(config) dummy_transport = DummySmppTransport(self.clock, self.redis, cfg) service = SmppService( self.fake_smsc.endpoint, bind_type, dummy_transport) service.clock = self.clock d = succeed(service) if start: d.addCallback(self.start_service) return d def start_service(self, service, accept_connection=True): """ Start the given service. """ service.startService() self.clock.advance(0) d = self.fake_smsc.await_connecting() if accept_connection: d.addCallback(lambda _: self.fake_smsc.accept_connection()) return d.addCallback(lambda _: service) def lookup_message_ids(self, service, seq_nums): """ Find vumi message ids associated with SMPP sequence numbers. """ lookup_func = service.message_stash.get_sequence_number_message_id return gatherResults([lookup_func(seq_num) for seq_num in seq_nums]) def set_sequence_number(self, service, seq_nr): return service.sequence_generator.redis.set( 'smpp_last_sequence_number', seq_nr) @inlineCallbacks def test_start_sequence(self): """ The service goes through several states while starting. """ # New service, never started. 
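# ---------------------------------------------------------------------
# Editor's sketch (not vumi's implementation): test_start_sequence below
# walks the service through the protocol states this file asserts on:
# CLOSED (no TCP connection) -> OPEN (connected, bind sent, not yet
# acknowledged) -> BOUND_STATE_TRX (bind_transceiver_resp received). A
# hypothetical toy state machine with the same transitions:

class BindLifecycle(object):
    CLOSED, OPEN, BOUND_TRX = 'CLOSED', 'OPEN', 'BOUND_TRX'

    def __init__(self):
        self.state = self.CLOSED

    def connection_made(self):
        assert self.state == self.CLOSED
        self.state = self.OPEN  # TCP is up; bind_transceiver goes out.

    def bind_resp_received(self):
        assert self.state == self.OPEN
        self.state = self.BOUND_TRX  # SMSC accepted the bind.

# lc = BindLifecycle(); lc.connection_made(); lc.bind_resp_received()
# lc.state -> 'BOUND_TRX'
# ---------------------------------------------------------------------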
service = yield self.get_service(start=False) self.assertEqual(service.running, False) self.assertEqual(service.get_bind_state(), EsmeProtocol.CLOSED_STATE) # Start, but don't connect. yield self.start_service(service, accept_connection=False) self.assertEqual(service.running, True) self.assertEqual(service.get_bind_state(), EsmeProtocol.CLOSED_STATE) # Connect, but don't bind. yield self.fake_smsc.accept_connection() self.assertEqual(service.running, True) self.assertEqual(service.get_bind_state(), EsmeProtocol.OPEN_STATE) bind_pdu = yield self.fake_smsc.await_pdu() self.assertEqual(command_id(bind_pdu), 'bind_transceiver') # Bind. yield self.fake_smsc.bind(bind_pdu) self.assertEqual(service.running, True) self.assertEqual( service.get_bind_state(), EsmeProtocol.BOUND_STATE_TRX) @inlineCallbacks def test_connect_retries(self): """ If we fail to connect, we retry. """ service = yield self.get_service(start=False) self.assertEqual(self.fake_smsc.has_pending_connection(), False) # Start, but don't connect. yield self.start_service(service, accept_connection=False) self.assertEqual(self.fake_smsc.has_pending_connection(), True) self.assertEqual(service._protocol, None) self.assertEqual(service.retries, 1) # Reject the connection. yield self.fake_smsc.reject_connection() self.assertEqual(service._protocol, None) self.assertEqual(service.retries, 2) # Advance to the next connection attempt. self.clock.advance(service.delay) self.assertEqual(self.fake_smsc.has_pending_connection(), True) self.assertEqual(service._protocol, None) self.assertEqual(service.retries, 2) # Accept the connection. yield self.fake_smsc.accept_connection() self.assertEqual(service.running, True) self.assertNotEqual(service._protocol, None) @inlineCallbacks def test_submit_sm(self): """ When bound, we can send a message. """ service = yield self.get_service() yield self.fake_smsc.bind() seq_nums = yield service.submit_sm( 'abc123', 'dest_addr', short_message='foo') submit_sm = yield self.fake_smsc.await_pdu() self.assertEqual(command_id(submit_sm), 'submit_sm') stored_ids = yield self.lookup_message_ids(service, seq_nums) self.assertEqual(['abc123'], stored_ids) @inlineCallbacks def test_submit_sm_unbound(self): """ When unbound, we can't send a message. """ service = yield self.get_service() self.assertRaises( EsmeProtocolError, service.submit_sm, 'abc123', 'dest_addr', short_message='foo') @inlineCallbacks def test_submit_sm_not_connected(self): """ When not connected, we can't send a message. """ service = yield self.get_service(start=False) yield self.start_service(service, accept_connection=False) self.assertRaises( EsmeProtocolError, service.submit_sm, 'abc123', 'dest_addr', short_message='foo') @skiptest("FIXME: We don't actually unbind and disconnect yet.") @inlineCallbacks def test_handle_unbind(self): """ If the SMSC sends an unbind command, we respond and disconnect. """ service = yield self.get_service() yield self.fake_smsc.bind() self.assertEqual(service.is_bound(), True) self.fake_smsc.send_pdu(Unbind(7)) unbind_resp_pdu = yield self.fake_smsc.await_pdu() self.assertEqual(command_id(unbind_resp_pdu), 'unbind_resp') self.assertEqual(service.is_bound(), False) @inlineCallbacks def test_csm_split_message(self): """ A multipart message is split into chunks such that the smallest number of message parts are required. 
""" service = yield self.get_service() split = lambda msg: service.csm_split_message(msg.encode('utf-8')) # these are fine because they're in the 7-bit character set self.assertEqual(1, len(split(u'&' * 140))) self.assertEqual(1, len(split(u'&' * 160))) # ± is not in the 7-bit character set so it should utf-8 encode it # which bumps it over the 140 bytes self.assertEqual(2, len(split(u'±' + u'1' * 139))) @inlineCallbacks def test_submit_sm_long(self): """ A long message can be sent in a single PDU using the optional `message_payload` PDU field. """ service = yield self.get_service() yield self.fake_smsc.bind() long_message = 'This is a long message.' * 20 seq_nums = yield service.submit_sm_long( 'abc123', 'dest_addr', long_message) submit_sm = yield self.fake_smsc.await_pdu() pdu_opts = unpacked_pdu_opts(submit_sm) self.assertEqual('submit_sm', submit_sm['header']['command_id']) self.assertEqual( None, submit_sm['body']['mandatory_parameters']['short_message']) self.assertEqual(''.join('%02x' % ord(c) for c in long_message), pdu_opts['message_payload']) stored_ids = yield self.lookup_message_ids(service, seq_nums) self.assertEqual(['abc123'], stored_ids) @inlineCallbacks def test_submit_csm_sar(self): """ A long message can be sent in multiple PDUs with SAR fields set to instruct the SMSC to build user data headers. """ service = yield self.get_service({'send_multipart_sar': True}) yield self.fake_smsc.bind() long_message = 'This is a long message.' * 20 seq_nums = yield service.submit_csm_sar( 'abc123', 'dest_addr', short_message=long_message) pdus = yield self.fake_smsc.await_pdus(4) # seq no 1 == bind_transceiver, 2 == enquire_link, 3 == sar_msg_ref_num self.assertEqual([4, 5, 6, 7], seq_nums) msg_parts = [] msg_refs = [] for i, sm in enumerate(pdus): pdu_opts = unpacked_pdu_opts(sm) mandatory_parameters = sm['body']['mandatory_parameters'] self.assertEqual('submit_sm', sm['header']['command_id']) msg_parts.append(mandatory_parameters['short_message']) self.assertTrue(len(mandatory_parameters['short_message']) <= 130) msg_refs.append(pdu_opts['sar_msg_ref_num']) self.assertEqual(i + 1, pdu_opts['sar_segment_seqnum']) self.assertEqual(4, pdu_opts['sar_total_segments']) self.assertEqual(long_message, ''.join(msg_parts)) self.assertEqual([3, 3, 3, 3], msg_refs) stored_ids = yield self.lookup_message_ids(service, seq_nums) self.assertEqual(['abc123'] * len(seq_nums), stored_ids) @inlineCallbacks def test_submit_csm_sar_ref_num_limit(self): """ The SAR reference number is set correctly when the generated reference number is larger than 0xFFFF. """ service = yield self.get_service({'send_multipart_sar': True}) yield self.fake_smsc.bind() # forward until we go past 0xFFFF yield self.set_sequence_number(service, 0x10000) long_message = 'This is a long message.' 
* 20 seq_nums = yield service.submit_csm_sar( 'abc123', 'dest_addr', short_message=long_message) pdus = yield self.fake_smsc.await_pdus(4) msg_parts = [] msg_refs = [] for i, sm in enumerate(pdus): pdu_opts = unpacked_pdu_opts(sm) mandatory_parameters = sm['body']['mandatory_parameters'] self.assertEqual('submit_sm', sm['header']['command_id']) msg_parts.append(mandatory_parameters['short_message']) self.assertTrue(len(mandatory_parameters['short_message']) <= 130) msg_refs.append(pdu_opts['sar_msg_ref_num']) self.assertEqual(i + 1, pdu_opts['sar_segment_seqnum']) self.assertEqual(4, pdu_opts['sar_total_segments']) self.assertEqual(long_message, ''.join(msg_parts)) self.assertEqual([2, 2, 2, 2], msg_refs) stored_ids = yield self.lookup_message_ids(service, seq_nums) self.assertEqual(['abc123'] * len(seq_nums), stored_ids) @inlineCallbacks def test_submit_csm_sar_single_part(self): """ If the content fits in a single message, all the multipart madness is avoided. """ service = yield self.get_service({'send_multipart_sar': True}) yield self.fake_smsc.bind() content = 'a' * 160 seq_numbers = yield service.submit_csm_sar( 'abc123', 'dest_addr', short_message=content) self.assertEqual(len(seq_numbers), 1) submit_sm_pdu = yield self.fake_smsc.await_pdu() self.assertEqual(command_id(submit_sm_pdu), 'submit_sm') self.assertEqual(short_message(submit_sm_pdu), content) self.assertEqual(unpacked_pdu_opts(submit_sm_pdu), {}) @inlineCallbacks def test_submit_csm_udh(self): """ A long message can be sent in multiple PDUs with carefully handcrafted user data headers. """ service = yield self.get_service({'send_multipart_udh': True}) yield self.fake_smsc.bind() long_message = 'This is a long message.' * 20 seq_numbers = yield service.submit_csm_udh( 'abc123', 'dest_addr', short_message=long_message) pdus = yield self.fake_smsc.await_pdus(4) self.assertEqual(len(seq_numbers), 4) msg_parts = [] msg_refs = [] for i, sm in enumerate(pdus): mandatory_parameters = sm['body']['mandatory_parameters'] self.assertEqual('submit_sm', sm['header']['command_id']) msg = mandatory_parameters['short_message'] udh_hlen, udh_tag, udh_len, udh_ref, udh_tot, udh_seq = [ ord(octet) for octet in msg[:6]] self.assertEqual(5, udh_hlen) self.assertEqual(0, udh_tag) self.assertEqual(3, udh_len) msg_refs.append(udh_ref) self.assertEqual(4, udh_tot) self.assertEqual(i + 1, udh_seq) self.assertTrue(len(msg) <= 136) msg_parts.append(msg[6:]) self.assertEqual(0x40, mandatory_parameters['esm_class']) self.assertEqual(long_message, ''.join(msg_parts)) self.assertEqual(1, len(set(msg_refs))) stored_ids = yield self.lookup_message_ids(service, seq_numbers) self.assertEqual(['abc123'] * len(seq_numbers), stored_ids) @inlineCallbacks def test_submit_csm_udh_ref_num_limit(self): """ User data headers are crafted correctly when the generated reference number is larger than 0xFF. """ service = yield self.get_service({'send_multipart_udh': True}) yield self.fake_smsc.bind() # forward until we go past 0xFF yield self.set_sequence_number(service, 0x100) long_message = 'This is a long message.' 
* 20 seq_numbers = yield service.submit_csm_udh( 'abc123', 'dest_addr', short_message=long_message) pdus = yield self.fake_smsc.await_pdus(4) self.assertEqual(len(seq_numbers), 4) msg_parts = [] msg_refs = [] for i, sm in enumerate(pdus): mandatory_parameters = sm['body']['mandatory_parameters'] self.assertEqual('submit_sm', sm['header']['command_id']) msg = mandatory_parameters['short_message'] udh_hlen, udh_tag, udh_len, udh_ref, udh_tot, udh_seq = [ ord(octet) for octet in msg[:6]] self.assertEqual(5, udh_hlen) self.assertEqual(0, udh_tag) self.assertEqual(3, udh_len) msg_refs.append(udh_ref) self.assertEqual(4, udh_tot) self.assertEqual(i + 1, udh_seq) self.assertTrue(len(msg) <= 136) msg_parts.append(msg[6:]) self.assertEqual(0x40, mandatory_parameters['esm_class']) self.assertEqual(long_message, ''.join(msg_parts)) self.assertEqual(1, len(set(msg_refs))) stored_ids = yield self.lookup_message_ids(service, seq_numbers) self.assertEqual(['abc123'] * len(seq_numbers), stored_ids) @inlineCallbacks def test_submit_csm_udh_single_part(self): """ If the content fits in a single message, all the multipart madness is avoided. """ service = yield self.get_service({'send_multipart_udh': True}) yield self.fake_smsc.bind() content = 'a' * 160 seq_numbers = yield service.submit_csm_udh( 'abc123', 'dest_addr', short_message=content) self.assertEqual(len(seq_numbers), 1) submit_sm_pdu = yield self.fake_smsc.await_pdu() self.assertEqual(command_id(submit_sm_pdu), 'submit_sm') self.assertEqual(short_message(submit_sm_pdu), content) self.assertEqual( submit_sm_pdu['body']['mandatory_parameters']['esm_class'], 0) @inlineCallbacks def test_pdu_cache_persistence(self): """ A cached PDU has an appropriate TTL and can be deleted. """ service = yield self.get_service() message_stash = service.message_stash config = service.get_config() pdu = SubmitSM(1337, short_message="foo") yield message_stash.cache_pdu("vumi0", pdu) ttl = yield message_stash.redis.ttl(pdu_key(1337)) self.assertTrue(0 < ttl <= config.submit_sm_expiry) pdu_data = yield message_stash.get_cached_pdu(1337) self.assertEqual(pdu_data.vumi_message_id, "vumi0") self.assertEqual(pdu_data.pdu.get_hex(), pdu.get_hex()) yield message_stash.delete_cached_pdu(1337) deleted_pdu_data = yield message_stash.get_cached_pdu(1337) self.assertEqual(deleted_pdu_data, None) PK=JG{z+vumi/transports/smpp/tests/test_sequence.pyfrom twisted.internet.defer import inlineCallbacks from vumi.tests.helpers import VumiTestCase, PersistenceHelper from vumi.transports.smpp.sequence import RedisSequence class EsmeTestCase(VumiTestCase): @inlineCallbacks def setUp(self): self.persistence_helper = self.add_helper(PersistenceHelper()) self.redis = yield self.persistence_helper.get_redis_manager() @inlineCallbacks def test_iter(self): sequence_generator = RedisSequence(self.redis) iterator = iter(sequence_generator) self.assertEqual((yield iterator.next()), 1) iterator = iter(sequence_generator) self.assertEqual((yield iterator.next()), 2) @inlineCallbacks def test_next(self): sequence_generator = RedisSequence(self.redis) self.assertEqual((yield sequence_generator.next()), 1) @inlineCallbacks def test_get_next_sequence(self): sequence_generator = RedisSequence(self.redis) self.assertEqual((yield sequence_generator.get_next_seq()), 1) @inlineCallbacks def test_rollover(self): sequence_generator = RedisSequence(self.redis, rollover_at=3) self.assertEqual((yield sequence_generator.next()), 1) self.assertEqual((yield sequence_generator.next()), 2) self.assertEqual((yield 
sequence_generator.next()), 3) self.assertEqual((yield sequence_generator.next()), 1) PK=JG=(vumi/transports/smpp/deprecated/utils.pyfrom vumi.transports.smpp.processors import ( DeliveryReportProcessorConfig, SubmitShortMessageProcessorConfig, DeliverShortMessageProcessorConfig) def convert_to_new_config(config, dr_processor, submit_sm_processor, deliver_sm_processor): dr_config = dict( (field.name, config.pop(field.name)) for field in DeliveryReportProcessorConfig._get_fields() if field.name in config) submit_sm_config = dict( (field.name, config.pop(field.name)) for field in SubmitShortMessageProcessorConfig._get_fields() if field.name in config) deliver_sm_config = dict( (field.name, config.pop(field.name)) for field in DeliverShortMessageProcessorConfig._get_fields() if field.name in config) config.update({ 'delivery_report_processor': dr_processor, 'delivery_report_processor_config': dr_config, 'submit_short_message_processor': submit_sm_processor, 'submit_short_message_processor_config': submit_sm_config, 'deliver_short_message_processor': deliver_sm_processor, 'deliver_short_message_processor_config': deliver_sm_config, }) return config PK=JG+vumi/transports/smpp/deprecated/__init__.pyPK=JG=fQ*vumi/transports/smpp/deprecated/service.py from twisted.python import log from twisted.internet import defer from vumi.worker import BaseWorker from vumi.transports.smpp.deprecated.clientserver.server import ( SmscServerFactory) from vumi.transports.smpp.deprecated.transport import SmppTransportConfig from vumi.config import ConfigServerEndpoint class SmppServiceConfig(SmppTransportConfig): twisted_endpoint = ConfigServerEndpoint( 'Server endpoint description', required=True, static=True) class SmppService(BaseWorker): """ The SmppService """ CONFIG_CLASS = SmppServiceConfig def setup_connectors(self): pass @defer.inlineCallbacks def setup_worker(self): log.msg("Starting the SmppService") config = self.get_static_config() delivery_report_string = self.config.get('smsc_delivery_report_string') self.factory = SmscServerFactory( delivery_report_string=delivery_report_string) self.listening = yield config.twisted_endpoint.listen(self.factory) PK=JGU}W}W,vumi/transports/smpp/deprecated/transport.py# -*- test-case-name: vumi.transports.smpp.deprecated.tests.test_smpp -*- import warnings from twisted.internet import reactor from twisted.internet.defer import inlineCallbacks, returnValue from vumi import log from vumi.reconnecting_client import ReconnectingClientService from vumi.utils import get_operator_number from vumi.transports.base import Transport from vumi.transports.smpp.deprecated.clientserver.client import ( EsmeTransceiverFactory, EsmeTransmitterFactory, EsmeReceiverFactory, EsmeCallbacks) from vumi.transports.failures import FailureMessage from vumi.message import Message, TransportUserMessage from vumi.persist.txredis_manager import TxRedisManager from vumi.config import ( ConfigText, ConfigInt, ConfigBool, ConfigDict, ConfigFloat, ConfigRegex, ConfigClientEndpoint, ClientEndpointFallback) class SmppTransportConfig(Transport.CONFIG_CLASS): DELIVERY_REPORT_REGEX = ( 'id:(?P\S{,65})' ' +sub:(?P...)' ' +dlvrd:(?P...)' ' +submit date:(?P\d*)' ' +done date:(?P\d*)' ' +stat:(?P[A-Z]{7})' ' +err:(?P...)' ' +[Tt]ext:(?P.{,20})' '.*' ) DELIVERY_REPORT_STATUS_MAPPING = { # Output values should map to themselves: 'delivered': 'delivered', 'failed': 'failed', 'pending': 'pending', # SMPP `message_state` values: 'ENROUTE': 'pending', 'DELIVERED': 'delivered', 'EXPIRED': 'failed', 'DELETED': 
'failed', 'UNDELIVERABLE': 'failed', 'ACCEPTED': 'delivered', 'UNKNOWN': 'pending', 'REJECTED': 'failed', # From the most common regex-extracted format: 'DELIVRD': 'delivered', 'REJECTD': 'failed', # Currently we will accept this for Yo! TODO: investigate '0': 'delivered', } twisted_endpoint = ConfigClientEndpoint( 'The SMPP endpoint to connect to.', required=True, static=True, fallbacks=[ClientEndpointFallback()]) system_id = ConfigText( 'User id used to connect to the SMPP server.', required=True, static=True) password = ConfigText( 'Password for the system id.', required=True, static=True) system_type = ConfigText( "Additional system metadata that is passed through to the SMPP " "server on connect.", default="", static=True) interface_version = ConfigText( "SMPP protocol version. Default is '34' (i.e. version 3.4).", default="34", static=True) service_type = ConfigText( 'The SMPP service type', default="", static=True) dest_addr_ton = ConfigInt( 'Destination TON (type of number)', default=0, static=True) dest_addr_npi = ConfigInt( 'Destination NPI (number plan identifier). ' 'Default 1 (ISDN/E.164/E.163)', default=1, static=True) source_addr_ton = ConfigInt( 'Source TON (type of number)', default=0, static=True) source_addr_npi = ConfigInt( 'Source NPI (number plan identifier)', default=0, static=True) registered_delivery = ConfigBool( 'Whether or not to request delivery reports', default=True, static=True) smpp_bind_timeout = ConfigInt( 'How long to wait for a succesful bind', default=30, static=True) smpp_enquire_link_interval = ConfigInt( "Number of seconds to delay before reconnecting to the server after " "being disconnected. Default is 5s. Some WASPs, e.g. Clickatell " "require a 30s delay before reconnecting. In these cases a 45s " "initial_reconnect_delay is recommended.", default=55, static=True) initial_reconnect_delay = ConfigInt( 'How long to wait between reconnecting attempts', default=5, static=True) third_party_id_expiry = ConfigInt( 'How long (seconds) to keep 3rd party message IDs around to allow for ' 'matching submit_sm_resp and delivery report messages. Defaults to ' '1 week', default=(60 * 60 * 24 * 7), static=True) delivery_report_regex = ConfigRegex( 'What regex to use for matching delivery reports', default=DELIVERY_REPORT_REGEX, static=True) delivery_report_status_mapping = ConfigDict( "Mapping from delivery report message state to " "(`delivered`, `failed`, `pending`)", static=True, default=DELIVERY_REPORT_STATUS_MAPPING) submit_sm_expiry = ConfigInt( 'How long (seconds) to wait for the SMSC to return with a ' '`submit_sm_resp`. Defaults to 24 hours', default=(60 * 60 * 24), static=True) submit_sm_encoding = ConfigText( 'How to encode the SMS before putting on the wire', static=True, default='utf-8') submit_sm_data_coding = ConfigInt( 'What data_coding value to tell the SMSC we\'re using when putting' 'an SMS on the wire', static=True, default=0) data_coding_overrides = ConfigDict( "Overrides for data_coding character set mapping. This is useful for " "setting the default encoding (0), adding additional undefined " "encodings (such as 4 or 8) or overriding encodings in cases where " "the SMSC is violating the spec (which happens a lot). Keys should " "be integers, values should be strings containing valid Python " "character encoding names.", default={}, static=True) send_long_messages = ConfigBool( "If `True`, messages longer than 254 characters will be sent in the " "`message_payload` optional field instead of the `short_message` " "field. 
Default is `False`, simply because that maintains previous " "behaviour.", default=False, static=True) send_multipart_sar = ConfigBool( "If `True`, messages longer than 140 bytes will be sent as a series " "of smaller messages with the sar_* parameters set. Default is " "`False`.", default=False, static=True) send_multipart_udh = ConfigBool( "If `True`, messages longer than 140 bytes will be sent as a series " "of smaller messages with the user data headers. Default is `False`.", default=False, static=True) split_bind_prefix = ConfigText( "This is the Redis prefix to use for storing things like sequence " "numbers and message ids for delivery report handling. It defaults " "to `<system_id>@<transport_name>`. " "If (and *only* if) the connection is split into two separate binds " "for RX and TX, make sure this is the same value for both binds. " "This _only_ needs to be done for TX & RX since messages sent via " "the TX bind are handled by the RX bind and they need to share the " "same prefix for the lookup for message ids in delivery reports to " "work.", default='', static=True) throttle_delay = ConfigFloat( "Delay (in seconds) before retrying a message after receiving " "`ESME_RTHROTTLED` or `ESME_RMSGQFUL`.", default=0.1, static=True) COUNTRY_CODE = ConfigText( "Used to translate a leading zero in a destination MSISDN into a " "country code. Default ''", default="", static=True) OPERATOR_PREFIX = ConfigDict( "Nested dictionary of prefix to network name mappings. Default {} " "(set network to 'UNKNOWN'). E.g. { '27': { '27761': 'NETWORK1' }} ", default={}, static=True) OPERATOR_NUMBER = ConfigDict( "Dictionary of source MSISDN to use for each network listed in " "OPERATOR_PREFIX. If a network is not listed, the source MSISDN " "specified by the message sender is used. Default {} (always use the " "from address specified by the message sender). " "E.g. { 'NETWORK1': '27761234567'}", default={}, static=True) redis_manager = ConfigDict( 'How to connect to Redis', default={}, static=True) # TODO: Deprecate these fields when confmodel#5 is done. host = ConfigText( "*DEPRECATED* 'host' and 'port' fields may be used in place of the" " 'twisted_endpoint' field.", static=True) port = ConfigInt( "*DEPRECATED* 'host' and 'port' fields may be used in place of the" " 'twisted_endpoint' field.", static=True) def post_validate(self): long_message_params = ( 'send_long_messages', 'send_multipart_sar', 'send_multipart_udh') set_params = [p for p in long_message_params if getattr(self, p)] if len(set_params) > 1: params = ', '.join(set_params) self.raise_config_error( "The following parameters are mutually exclusive: %s" % params) class SmppTransport(Transport): """ An SMPP Transceiver Transport. """ CONFIG_CLASS = SmppTransportConfig # Which of the keys in SmppTransportConfig are keys that are to # be passed on to the ESMETransceiver base class to create a bind with. SMPP_BIND_CONFIG_KEYS = [ 'system_id', 'password', 'system_type', 'interface_version', 'service_type', 'dest_addr_ton', 'dest_addr_npi', 'source_addr_ton', 'source_addr_npi', 'registered_delivery', ] # We only want to start this after we finish connecting to SMPP. start_message_consumer = False callLater = reactor.callLater @inlineCallbacks def setup_transport(self): warnings.warn( 'This SMPP implementation is deprecated. Please use the ' 'implementations available in vumi.transports.smpp.'
'smpp_transport instead.', category=DeprecationWarning) config = self.get_static_config() log.msg("Starting the SmppTransport for %s" % ( config.twisted_endpoint)) self.submit_sm_encoding = config.submit_sm_encoding self.submit_sm_data_coding = config.submit_sm_data_coding default_prefix = "%s@%s" % (config.system_id, config.transport_name) r_config = config.redis_manager r_prefix = config.split_bind_prefix or default_prefix redis = yield TxRedisManager.from_config(r_config) self.redis = redis.sub_manager(r_prefix) self.r_message_prefix = "message_json" self.throttled = False self.esme_callbacks = EsmeCallbacks( connect=self.esme_connected, disconnect=self.esme_disconnected, submit_sm_resp=self.submit_sm_resp, delivery_report=self.delivery_report, deliver_sm=self.deliver_sm) self._reconn_service = None if not hasattr(self, 'esme_client'): # start the Smpp transport (if we don't have one) self.factory = self.make_factory() self._reconn_service = ReconnectingClientService( config.twisted_endpoint, self.factory) self._reconn_service.startService() @inlineCallbacks def teardown_transport(self): if self._reconn_service is not None: yield self._reconn_service.stopService() yield self.redis._close() def get_smpp_bind_params(self): """Inspect the SmppTransportConfig and return a dictionary that can be passed to an EsmeTransceiver (or a subclass thereof) to create a bind with.""" config = self.get_static_config() return dict([(key, getattr(config, key)) for key in self.SMPP_BIND_CONFIG_KEYS]) def make_factory(self): return EsmeTransceiverFactory( self.get_static_config(), self.get_smpp_bind_params(), self.redis, self.esme_callbacks) def esme_connected(self, client): log.msg("ESME Connected, adding handlers") self.esme_client = client # Start the consumer self.unpause_connectors() @inlineCallbacks def handle_outbound_message(self, message): log.debug("Consumed outgoing message %r" % (message,)) log.debug("Unacknowledged message count: %s" % ( (yield self.esme_client.get_unacked_count()),)) yield self.r_set_message(message) yield self._submit_outbound_message(message) @inlineCallbacks def _submit_outbound_message(self, message): sequence_numbers = yield self.send_smpp(message) # TODO: Handle multiple acks for a single message that we split up.
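# As an illustration (hypothetical values, not from the original code):
# a message split into three parts might come back from send_smpp()
# with sequence_numbers == [5, 6, 7]; the loop below then points the
# Redis keys "5", "6" and "7" at the same vumi message_id, so that each
# later submit_sm_resp can be correlated with that one message.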
for sequence_number in sequence_numbers: yield self.r_set_id_for_sequence( sequence_number, message.payload.get("message_id")) def esme_disconnected(self): log.msg("ESME Disconnected") return self.pause_connectors() # Redis message storing methods def r_message_key(self, message_id): return "%s#%s" % (self.r_message_prefix, message_id) def r_set_message(self, message): config = self.get_static_config() message_id = message.payload['message_id'] message_key = self.r_message_key(message_id) d = self.redis.set(message_key, message.to_json()) d.addCallback(lambda _: self.redis.expire(message_key, config.submit_sm_expiry)) return d def r_get_message_json(self, message_id): return self.redis.get(self.r_message_key(message_id)) @inlineCallbacks def r_get_message(self, message_id): json_string = yield self.r_get_message_json(message_id) if json_string: returnValue(Message.from_json(json_string)) else: returnValue(None) def r_delete_message(self, message_id): return self.redis.delete(self.r_message_key(message_id)) # Redis sequence number storing methods def r_get_id_for_sequence(self, sequence_number): return self.redis.get(str(sequence_number)) def r_delete_for_sequence(self, sequence_number): return self.redis.delete(str(sequence_number)) def r_set_id_for_sequence(self, sequence_number, id): return self.redis.set(str(sequence_number), id) # Redis 3rd party id to vumi id mapping def r_third_party_id_key(self, third_party_id): return "3rd_party_id#%s" % (third_party_id,) def r_get_id_for_third_party_id(self, third_party_id): return self.redis.get(self.r_third_party_id_key(third_party_id)) def r_delete_for_third_party_id(self, third_party_id): return self.redis.delete( self.r_third_party_id_key(third_party_id)) @inlineCallbacks def r_set_id_for_third_party_id(self, third_party_id, id): config = self.get_static_config() rkey = self.r_third_party_id_key(third_party_id) yield self.redis.set(rkey, id) yield self.redis.expire(rkey, config.third_party_id_expiry) def _start_throttling(self): if self.throttled: return log.err("Throttling outbound messages.") self.throttled = True self.pause_connectors() def _stop_throttling(self): if not self.throttled: return log.err("No longer throttling outbound messages.") self.throttled = False self.unpause_connectors() @inlineCallbacks def submit_sm_resp(self, *args, **kwargs): transport_msg_id = kwargs['message_id'] sent_sms_id = ( yield self.r_get_id_for_sequence(kwargs['sequence_number'])) if sent_sms_id is None: log.err("Sequence number lookup failed for:%s" % ( kwargs['sequence_number'],)) else: yield self.r_set_id_for_third_party_id( transport_msg_id, sent_sms_id) yield self.r_delete_for_sequence(kwargs['sequence_number']) status = kwargs['command_status'] if status == 'ESME_ROK': # The sms was submitted ok yield self.submit_sm_success(sent_sms_id, transport_msg_id) yield self._stop_throttling() elif status in ('ESME_RTHROTTLED', 'ESME_RMSGQFUL'): yield self._start_throttling() yield self.submit_sm_throttled(sent_sms_id) else: # We have an error yield self.submit_sm_failure(sent_sms_id, status or 'Unspecified') yield self._stop_throttling() @inlineCallbacks def submit_sm_success(self, sent_sms_id, transport_msg_id): yield self.r_delete_message(sent_sms_id) log.debug("Mapping transport_msg_id=%s to sent_sms_id=%s" % ( transport_msg_id, sent_sms_id)) log.debug("PUBLISHING ACK: (%s -> %s)" % ( sent_sms_id, transport_msg_id)) self.publish_ack( user_message_id=sent_sms_id, sent_message_id=transport_msg_id) @inlineCallbacks def submit_sm_failure(self, sent_sms_id, 
reason, failure_code=None): error_message = yield self.r_get_message(sent_sms_id) if error_message is None: log.err("Could not retrieve failed message: %s" % ( sent_sms_id,)) else: yield self.r_delete_message(sent_sms_id) yield self.publish_nack(sent_sms_id, reason) yield self.failure_publisher.publish_message(FailureMessage( message=error_message.payload, failure_code=failure_code, reason=reason)) @inlineCallbacks def submit_sm_throttled(self, sent_sms_id): message = yield self.r_get_message(sent_sms_id) if message is None: log.err("Could not retrieve throttled message: %s" % ( sent_sms_id,)) else: config = self.get_static_config() self.callLater(config.throttle_delay, self._submit_outbound_message, message) def delivery_status(self, state): config = self.get_static_config() return config.delivery_report_status_mapping.get(state, 'pending') @inlineCallbacks def delivery_report(self, message_id, message_state): delivery_status = self.delivery_status(message_state) message_id = yield self.r_get_id_for_third_party_id(message_id) if message_id is None: log.warning("Failed to retrieve message id for delivery report." " Delivery report from %s discarded." % self.transport_name) return log.msg("PUBLISHING DELIV REPORT: %s %s" % (message_id, delivery_status)) returnValue((yield self.publish_delivery_report( user_message_id=message_id, delivery_status=delivery_status))) def deliver_sm(self, *args, **kwargs): message_type = kwargs.get('message_type', 'sms') message = { 'message_id': kwargs['message_id'], 'to_addr': kwargs['destination_addr'], 'from_addr': kwargs['source_addr'], 'content': kwargs['short_message'], 'transport_type': message_type, 'transport_metadata': {}, } if message_type == 'ussd': session_event = { 'new': TransportUserMessage.SESSION_NEW, 'continue': TransportUserMessage.SESSION_RESUME, 'close': TransportUserMessage.SESSION_CLOSE, }[kwargs['session_event']] message['session_event'] = session_event session_info = kwargs.get('session_info') message['transport_metadata']['session_info'] = session_info log.msg("PUBLISHING INBOUND: %s" % (message,)) # TODO: This logs messages that fail to serialize to JSON # Usually this happens when an SMPP message has content # we can't decode (e.g. data_coding == 4). We should # remove the try-except once we handle such messages # better.
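# For example (an assumed failure mode, per the TODO above): a
# deliver_sm that arrived with data_coding == 4 may reach this point
# with undecodable byte content; serialising such a message to JSON
# raises, and the errback below logs the error instead of crashing
# the transport.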
return self.publish_message(**message).addErrback(log.err) def send_smpp(self, message): log.debug("Sending SMPP message: %s" % (message)) # first do a lookup in our YAML to see if we've got a source_addr # defined for the given MT number, if not, trust the from_addr # in the message to_addr = message['to_addr'] from_addr = message['from_addr'] text = message['content'] continue_session = ( message['session_event'] != TransportUserMessage.SESSION_CLOSE) config = self.get_static_config() route = get_operator_number(to_addr, config.COUNTRY_CODE, config.OPERATOR_PREFIX, config.OPERATOR_NUMBER) source_addr = route or from_addr session_info = message['transport_metadata'].get('session_info') return self.esme_client.submit_sm( # these end up in the PDU short_message=text.encode(self.submit_sm_encoding), data_coding=self.submit_sm_data_coding, destination_addr=to_addr.encode('ascii'), source_addr=source_addr.encode('ascii'), session_info=session_info.encode('ascii') if session_info is not None else None, # these don't end up in the PDU message_type=message['transport_type'], continue_session=continue_session, ) def stopWorker(self): log.msg("Stopping the SMPPTransport") return super(SmppTransport, self).stopWorker() def send_failure(self, message, exception, reason): """Send a failure report.""" log.msg("Failed to send: %s reason: %s" % (message, reason)) return super(SmppTransport, self).send_failure(message, exception, reason) class SmppTxTransport(SmppTransport): """An Smpp Transmitter Transport""" def make_factory(self): return EsmeTransmitterFactory( self.get_static_config(), self.get_smpp_bind_params(), self.redis, self.esme_callbacks) class SmppRxTransport(SmppTransport): """An Smpp Receiver Transport""" def make_factory(self): return EsmeReceiverFactory( self.get_static_config(), self.get_smpp_bind_params(), self.redis, self.esme_callbacks) PK=JG57p7p6vumi/transports/smpp/deprecated/clientserver/client.py# -*- test-case-name: vumi.transports.smpp.deprecated.clientserver.tests.test_client -*- import json import uuid from random import randint from twisted.internet import reactor from twisted.internet.protocol import Protocol, ClientFactory from twisted.internet.task import LoopingCall from twisted.internet.defer import inlineCallbacks, returnValue, DeferredQueue import binascii from smpp.pdu import unpack_pdu from smpp.pdu_builder import ( BindTransceiver, BindTransmitter, BindReceiver, DeliverSMResp, SubmitSM, EnquireLink, EnquireLinkResp, QuerySM, UnbindResp) from smpp.pdu_inspector import ( MultipartMessage, detect_multipart, multipart_key) from vumi import log GSM_MAX_SMS_BYTES = 140 def unpacked_pdu_opts(unpacked_pdu): pdu_opts = {} for opt in unpacked_pdu['body'].get('optional_parameters', []): pdu_opts[opt['tag']] = opt['value'] return pdu_opts def detect_ussd(pdu_opts): # TODO: Push this back to python-smpp? 
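# Sketch of the inputs involved (example values, not guaranteed): for
# a USSD deliver_sm, unpacked_pdu_opts(pdu) typically contains entries
# like {'ussd_service_op': '02', 'its_session_info': '0100'}, so the
# membership test below returns True; a plain SMS PDU carries no
# 'ussd_service_op' entry and returns False.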
return ('ussd_service_op' in pdu_opts) def update_ussd_pdu(sm_pdu, continue_session, session_info=None): if session_info is None: session_info = '0000' session_info = "%04x" % (int(session_info, 16) + int(not continue_session)) sm_pdu._PDU__add_optional_parameter('ussd_service_op', '02') sm_pdu._PDU__add_optional_parameter('its_session_info', session_info) return sm_pdu class EsmeTransceiver(Protocol): BIND_PDU = BindTransceiver CONNECTED_STATE = 'BOUND_TRX' callLater = reactor.callLater def __init__(self, config, bind_params, redis, esme_callbacks): self.config = config self.bind_params = bind_params self.esme_callbacks = esme_callbacks self.state = 'CLOSED' log.msg('STATE: %s' % (self.state,)) self.smpp_bind_timeout = self.config.smpp_bind_timeout self.smpp_enquire_link_interval = \ self.config.smpp_enquire_link_interval self.datastream = '' self.redis = redis self._lose_conn = None # The PDU queue ensures that PDUs are processed in the order # they arrive. `self._process_pdu_queue()` loops forever # pulling PDUs off the queue and handling each before grabbing # the next. self._pdu_queue = DeferredQueue() self._process_pdu_queue() # intentionally throw away deferred @inlineCallbacks def get_next_seq(self): """Get the next available SMPP sequence number. The valid range of sequence numbers is 0x00000001 to 0xFFFFFFFF. We start trying to wrap at 0xFFFF0000 so we can keep returning values (up to 0xFFFF of them) even while someone else is in the middle of resetting the counter. """ seq = yield self.redis.incr('smpp_last_sequence_number') if seq >= 0xFFFF0000: # We're close to the upper limit, so try to reset. It doesn't # matter if we actually succeed or not, since we're going to return # `seq` anyway. yield self._reset_seq_counter() returnValue(seq) @inlineCallbacks def _reset_seq_counter(self): """Reset the sequence counter in a safe manner. NOTE: There is a potential race condition in this implementation. If we acquire the lock and it expires while we still think we hold it, it's possible for the sequence number to be reset by someone else between the final value check and the reset call. This seems like a very unlikely situation, so we'll leave it like that for now. A better solution is to replace this whole method with a lua script that we send to redis, but scripting support is still very new at the time of writing. """ # SETNX can be used as a lock. locked = yield self.redis.setnx('smpp_last_sequence_number_wrap', 1) # If someone crashed in exactly the wrong place, the lock may be # held by someone else but have no expire time. A race condition # here may set the TTL multiple times, but that's fine. if (yield self.redis.ttl('smpp_last_sequence_number_wrap')) < 0: # The TTL only gets set if the lock exists and recently had no TTL. yield self.redis.expire('smpp_last_sequence_number_wrap', 10) if not locked: # We didn't actually get the lock, so our job is done. return if (yield self.redis.get('smpp_last_sequence_number')) < 0xFFFF0000: # Our stored sequence number is no longer outside the allowed # range, so someone else must have reset it before we got the lock. return # We reset the counter by deleting the key. The next INCR will recreate # it for us.
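# Illustrative timeline (assumed values): with the counter at, say,
# 0xFFFF0010, other workers' INCRs keep handing out 0xFFFF0011,
# 0xFFFF0012, ... while the lock holder deletes the key below; the
# first INCR after the delete recreates the key and starts again at 1,
# long before the hard 0xFFFFFFFF ceiling is reached.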
yield self.redis.delete('smpp_last_sequence_number') def pop_data(self): data = None if(len(self.datastream) >= 16): command_length = int(binascii.b2a_hex(self.datastream[0:4]), 16) if(len(self.datastream) >= command_length): data = self.datastream[0:command_length] self.datastream = self.datastream[command_length:] return data @inlineCallbacks def handle_data(self, data): pdu = unpack_pdu(data) command_id = pdu['header']['command_id'] if command_id not in ('enquire_link', 'enquire_link_resp'): log.debug('INCOMING <<<< %s' % binascii.b2a_hex(data)) log.debug('INCOMING <<<< %s' % pdu) handler = getattr(self, 'handle_%s' % (command_id,), self._command_handler_not_found) yield handler(pdu) @inlineCallbacks def _process_pdu_queue(self): data = yield self._pdu_queue.get() while data is not None: yield self.handle_data(data) data = yield self._pdu_queue.get() def _command_handler_not_found(self, pdu): log.err('No command handler available for %s' % (pdu,)) @inlineCallbacks def connectionMade(self): self.state = 'OPEN' log.msg('STATE: %s' % (self.state)) seq = yield self.get_next_seq() pdu = self.BIND_PDU(seq, **self.bind_params) log.msg(pdu.get_obj()) self.send_pdu(pdu) self.schedule_lose_connection(self.CONNECTED_STATE) def schedule_lose_connection(self, expected_status): self._lose_conn = self.callLater(self.smpp_bind_timeout, self.lose_unbound_connection, expected_status) def lose_unbound_connection(self, required_state): if self.state != required_state: log.msg('Breaking connection due to binding delay, %s != %s\n' % ( self.state, required_state)) self._lose_conn = None self.transport.loseConnection() else: log.msg('Successful bind: %s, cancelling bind timeout' % ( self.state)) def connectionLost(self, *args, **kwargs): self.state = 'CLOSED' self.stop_enquire_link() self.cancel_drop_connection_call() log.msg('STATE: %s' % (self.state)) self.esme_callbacks.disconnect() def dataReceived(self, data): self.datastream += data data = self.pop_data() while data is not None: self._pdu_queue.put(data) data = self.pop_data() def send_pdu(self, pdu): data = pdu.get_bin() unpacked = unpack_pdu(data) command_id = unpacked['header']['command_id'] if command_id not in ('enquire_link', 'enquire_link_resp'): log.debug('OUTGOING >>>> %s' % unpacked) self.transport.write(data) @inlineCallbacks def start_enquire_link(self): self.lc_enquire = LoopingCall(self.enquire_link) self.lc_enquire.start(self.smpp_enquire_link_interval) self.cancel_drop_connection_call() yield self.esme_callbacks.connect(self) @inlineCallbacks def stop_enquire_link(self): lc_enquire = getattr(self, 'lc_enquire', None) if lc_enquire and lc_enquire.running: lc_enquire.stop() log.msg('Stopped enquire link looping call') yield lc_enquire.deferred def cancel_drop_connection_call(self): if self._lose_conn is not None: self._lose_conn.cancel() self._lose_conn = None @inlineCallbacks def handle_unbind(self, pdu): yield self.send_pdu(UnbindResp( sequence_number=pdu['header']['sequence_number'])) self.transport.loseConnection() @inlineCallbacks def handle_bind_transceiver_resp(self, pdu): if pdu['header']['command_status'] == 'ESME_ROK': self.state = 'BOUND_TRX' yield self.start_enquire_link() log.msg('STATE: %s' % (self.state)) @inlineCallbacks def handle_submit_sm_resp(self, pdu): yield self.pop_unacked() message_id = pdu.get('body', {}).get( 'mandatory_parameters', {}).get('message_id') yield self.esme_callbacks.submit_sm_resp( sequence_number=pdu['header']['sequence_number'], command_status=pdu['header']['command_status'], 
command_id=pdu['header']['command_id'], message_id=message_id) def _decode_message(self, message, data_coding): """ Messages can arrive with one of a number of specified encodings. We only handle a subset of these. From the SMPP spec: 00000000 (0) SMSC Default Alphabet 00000001 (1) IA5 (CCITT T.50) / ASCII (ANSI X3.4) 00000010 (2) Octet unspecified (8-bit binary) 00000011 (3) Latin 1 (ISO-8859-1) 00000100 (4) Octet unspecified (8-bit binary) 00000101 (5) JIS (X 0208-1990) 00000110 (6) Cyrillic (ISO-8859-5) 00000111 (7) Latin/Hebrew (ISO-8859-8) 00001000 (8) UCS2 (ISO/IEC-10646) 00001001 (9) Pictogram Encoding 00001010 (10) ISO-2022-JP (Music Codes) 00001011 (11) reserved 00001100 (12) reserved 00001101 (13) Extended Kanji JIS (X 0212-1990) 00001110 (14) KSC5601 00001111 (15) reserved Particularly problematic are the "Octet unspecified" encodings. """ codecs = { 1: 'ascii', 3: 'latin1', 8: 'utf-16be', # Actually UCS-2, but close enough. } codecs.update(self.config.data_coding_overrides) codec = codecs.get(data_coding, None) if codec is None or message is None: log.msg("WARNING: Not decoding message with data_coding=%s" % ( data_coding,)) else: try: return message.decode(codec) except Exception as e: log.msg("Error decoding message with data_coding=%s" % ( data_coding,)) log.err(e) return message @inlineCallbacks def handle_deliver_sm(self, pdu): if self.state not in ['BOUND_RX', 'BOUND_TRX']: log.err('WARNING: Received deliver_sm in wrong state: %s' % ( self.state)) return if pdu['header']['command_status'] != 'ESME_ROK': return # TODO: Only ACK messages once we've processed them? sequence_number = pdu['header']['sequence_number'] pdu_resp = DeliverSMResp(sequence_number, **self.bind_params) yield self.send_pdu(pdu_resp) pdu_params = pdu['body']['mandatory_parameters'] pdu_opts = unpacked_pdu_opts(pdu) # This might be a delivery receipt with PDU parameters. If we get a # delivery receipt without these parameters we'll try a regex match # later once we've decoded the message properly. receipted_message_id = pdu_opts.get('receipted_message_id', None) message_state = pdu_opts.get('message_state', None) if receipted_message_id is not None and message_state is not None: yield self.esme_callbacks.delivery_report( message_id=receipted_message_id, message_state={ 1: 'ENROUTE', 2: 'DELIVERED', 3: 'EXPIRED', 4: 'DELETED', 5: 'UNDELIVERABLE', 6: 'ACCEPTED', 7: 'UNKNOWN', 8: 'REJECTED', }.get(message_state, 'UNKNOWN'), ) # We might have a `message_payload` optional field to worry about. message_payload = pdu_opts.get('message_payload', None) if message_payload is not None: pdu_params['short_message'] = message_payload.decode('hex') if detect_ussd(pdu_opts): # We have a USSD message. yield self._handle_deliver_sm_ussd(pdu, pdu_params, pdu_opts) elif detect_multipart(pdu): # We have a multipart SMS. yield self._handle_deliver_sm_multipart(pdu, pdu_params) else: # We have a standard SMS. yield self._handle_deliver_sm_sms(pdu_params) def _deliver_sm(self, source_addr, destination_addr, short_message, **kw): delivery_report = self.config.delivery_report_regex.search( short_message or '') if delivery_report: # We have a delivery report.
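# For example, with the default regex a short_message such as
# "id:ab12 sub:001 dlvrd:001 submit date:120726132548 "
# "done date:120726132548 stat:DELIVRD err:000 text:" matches, and
# groupdict() below yields 'ab12' for 'id' and 'DELIVRD' for 'stat'.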
fields = delivery_report.groupdict() return self.esme_callbacks.delivery_report( message_id=fields['id'], message_state=fields['stat']) message_id = str(uuid.uuid4()) return self.esme_callbacks.deliver_sm( source_addr=source_addr, destination_addr=destination_addr, short_message=short_message, message_id=message_id, **kw) def _handle_deliver_sm_ussd(self, pdu, pdu_params, pdu_opts): # Some of this stuff might be specific to Tata's setup. service_op = pdu_opts['ussd_service_op'] session_event = 'close' if service_op == '01': # PSSR request. Let's assume it means a new session. session_event = 'new' elif service_op == '11': # PSSR response. This means session end. session_event = 'close' elif service_op in ('02', '12'): # USSR request or response. I *think* we only get the latter. session_event = 'continue' # According to the spec, the first octet is the session id and the # second is the client dialog id (first 7 bits) and end session flag # (last bit). # Since we don't use the client dialog id and the spec says it's # ESME-defined, treat the whole thing as opaque "session info" that # gets passed back in reply messages. its_session_number = int(pdu_opts['its_session_info'], 16) end_session = bool(its_session_number % 2) session_info = "%04x" % (its_session_number & 0xfffe) if end_session: # We have an explicit "end session" flag. session_event = 'close' decoded_msg = self._decode_message(pdu_params['short_message'], pdu_params['data_coding']) return self._deliver_sm( source_addr=pdu_params['source_addr'], destination_addr=pdu_params['destination_addr'], short_message=decoded_msg, message_type='ussd', session_event=session_event, session_info=session_info) def _handle_deliver_sm_sms(self, pdu_params): decoded_msg = self._decode_message(pdu_params['short_message'], pdu_params['data_coding']) return self._deliver_sm( source_addr=pdu_params['source_addr'], destination_addr=pdu_params['destination_addr'], short_message=decoded_msg) @inlineCallbacks def load_multipart_message(self, redis_key): value = yield self.redis.get(redis_key) value = json.loads(value) if value else {} log.debug("Retrieved value: %s" % (repr(value))) returnValue(MultipartMessage(self._unhex_from_redis(value))) def save_multipart_message(self, redis_key, multipart_message): data_dict = self._hex_for_redis(multipart_message.get_array()) return self.redis.set(redis_key, json.dumps(data_dict)) def _hex_for_redis(self, data_dict): for index, part in data_dict.items(): part['part_message'] = part['part_message'].encode('hex') return data_dict def _unhex_from_redis(self, data_dict): for index, part in data_dict.items(): part['part_message'] = part['part_message'].decode('hex') return data_dict @inlineCallbacks def _handle_deliver_sm_multipart(self, pdu, pdu_params): redis_key = "multi_%s" % (multipart_key(detect_multipart(pdu)),) log.debug("Redis multipart key: %s" % (redis_key)) multi = yield self.load_multipart_message(redis_key) multi.add_pdu(pdu) completed = multi.get_completed() if completed: yield self.redis.delete(redis_key) log.msg("Reassembled Message: %s" % (completed['message'])) # We assume that all parts have the same data_coding here, because # otherwise there's nothing sensible we can do. 
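# For instance: two deliver_sm parts sharing UDH reference 0xff and
# numbered 1/2 and 2/2 are parked in Redis under "multi_<key>" until
# both arrive; completed['message'] is then their payloads joined in
# part order.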
decoded_msg = self._decode_message(completed['message'], pdu_params['data_coding']) # and we can finally pass the whole message on yield self._deliver_sm( source_addr=completed['from_msisdn'], destination_addr=completed['to_msisdn'], short_message=decoded_msg) else: yield self.save_multipart_message(redis_key, multi) def handle_enquire_link(self, pdu): if pdu['header']['command_status'] == 'ESME_ROK': log.msg("enquire_link OK") sequence_number = pdu['header']['sequence_number'] pdu_resp = EnquireLinkResp(sequence_number) self.send_pdu(pdu_resp) else: log.msg("enquire_link NOT OK: %r" % (pdu,)) def handle_enquire_link_resp(self, pdu): if pdu['header']['command_status'] == 'ESME_ROK': log.msg("enquire_link_resp OK") else: log.msg("enquire_link_resp NOT OK: %r" % (pdu,)) def get_unacked_count(self): return self.redis.llen("unacked").addCallback(int) @inlineCallbacks def push_unacked(self, sequence_number=-1): yield self.redis.lpush("unacked", sequence_number) log.msg("unacked pushed to: %s" % ((yield self.get_unacked_count()))) @inlineCallbacks def pop_unacked(self): yield self.redis.lpop("unacked") log.msg("unacked popped to: %s" % ((yield self.get_unacked_count()))) @inlineCallbacks def submit_sm(self, **kwargs): if self.state not in ['BOUND_TX', 'BOUND_TRX']: log.err(('WARNING: submit_sm in wrong state: %s, ' 'dropping message: %s' % (self.state, kwargs))) returnValue(0) pdu_params = self.bind_params.copy() pdu_params.update(kwargs) message = pdu_params['short_message'] # We use GSM_MAX_SMS_BYTES here because we may have already-encoded # UCS-2 data to send and therefore can't use the 160 (7-bit) character # limit everyone knows and loves. If we have some other encoding # instead, this may result in unnecessarily short message parts. The # SMSC is probably going to treat whatever we send it as whatever # encoding it likes best and then encode (or mangle) it into a form it # thinks should be in the GSM message payload. Basically, when we have # to split messages up ourselves here we've already lost and the best # we can hope for is not getting hurt too badly by the inevitable # breakages. 
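# Concretely (figures taken from the splitting code below): a plain
# message of <= 140 bytes goes out as a single submit_sm, while a
# 460-byte message with send_multipart_sar or send_multipart_udh
# enabled is chopped into ceil(460 / 130) == 4 parts, each carrying
# at most 130 bytes of payload plus its headers.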
if len(message) > GSM_MAX_SMS_BYTES: if self.config.send_multipart_sar: sequence_numbers = yield self._submit_multipart_sar( **pdu_params) returnValue(sequence_numbers) elif self.config.send_multipart_udh: sequence_numbers = yield self._submit_multipart_udh( **pdu_params) returnValue(sequence_numbers) sequence_number = yield self._submit_sm(**pdu_params) returnValue([sequence_number]) @inlineCallbacks def _submit_sm(self, **pdu_params): sequence_number = yield self.get_next_seq() message = pdu_params['short_message'] sar_params = pdu_params.pop('sar_params', None) message_type = pdu_params.pop('message_type', 'sms') continue_session = pdu_params.pop('continue_session', True) session_info = pdu_params.pop('session_info', None) pdu = SubmitSM(sequence_number, **pdu_params) if message_type == 'ussd': update_ussd_pdu(pdu, continue_session, session_info) if self.config.send_long_messages and len(message) > 254: pdu.add_message_payload(''.join('%02x' % ord(c) for c in message)) if sar_params: pdu.set_sar_msg_ref_num(sar_params['msg_ref_num']) pdu.set_sar_total_segments(sar_params['total_segments']) pdu.set_sar_segment_seqnum(sar_params['segment_seqnum']) self.send_pdu(pdu) yield self.push_unacked(sequence_number) returnValue(sequence_number) @inlineCallbacks def _submit_multipart_sar(self, **pdu_params): message = pdu_params['short_message'] split_msg = [] # We chop the message into 130 byte chunks to leave 10 bytes for the # user data header the SMSC is presumably going to add for us. This is # a guess based mostly on optimism and the hope that we'll never have # to deal with this stuff in production. # FIXME: If we have utf-8 encoded data, we might break in the # middle of a multibyte character. payload_length = GSM_MAX_SMS_BYTES - 10 while message: split_msg.append(message[:payload_length]) message = message[payload_length:] ref_num = randint(1, 255) sequence_numbers = [] for i, msg in enumerate(split_msg): params = pdu_params.copy() params['short_message'] = msg params['sar_params'] = { 'msg_ref_num': ref_num, 'total_segments': len(split_msg), 'segment_seqnum': i + 1, } sequence_number = yield self._submit_sm(**params) sequence_numbers.append(sequence_number) returnValue(sequence_numbers) @inlineCallbacks def _submit_multipart_udh(self, **pdu_params): message = pdu_params['short_message'] split_msg = [] # We chop the message into 130 byte chunks to leave 10 bytes for the # 6-byte user data header we add and a little extra space in case the # SMSC does unexpected things with our message. # FIXME: If we have utf-8 encoded data, we might break in the # middle of a multibyte character. payload_length = GSM_MAX_SMS_BYTES - 10 while message: split_msg.append(message[:payload_length]) message = message[payload_length:] ref_num = randint(1, 255) sequence_numbers = [] for i, msg in enumerate(split_msg): params = pdu_params.copy() # 0x40 is the UDHI flag indicating that this payload contains a # user data header. params['esm_class'] = 0x40 # See http://en.wikipedia.org/wiki/User_Data_Header for an # explanation of the magic numbers below. We should probably # abstract this out into a class that makes it less magic and # opaque. 
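# Decoded, the header built below reads: 0x05 = UDH length (five more
# octets follow), 0x00 = IEI for 8-bit concatenated short messages,
# 0x03 = IE data length, then one octet each for the message reference
# number, the total part count and this part's 1-based sequence number.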
udh = '\05\00\03%s%s%s' % ( chr(ref_num), chr(len(split_msg)), chr(i + 1)) params['short_message'] = udh + msg sequence_number = yield self._submit_sm(**params) sequence_numbers.append(sequence_number) returnValue(sequence_numbers) @inlineCallbacks def enquire_link(self, **kwargs): if self.state in ['BOUND_TX', 'BOUND_RX', 'BOUND_TRX']: sequence_number = yield self.get_next_seq() pdu = EnquireLink( sequence_number, **dict(self.bind_params, **kwargs)) self.send_pdu(pdu) returnValue(sequence_number) returnValue(0) @inlineCallbacks def query_sm(self, message_id, source_addr, **kwargs): if self.state in ['BOUND_TX', 'BOUND_TRX']: sequence_number = yield self.get_next_seq() pdu = QuerySM(sequence_number, message_id=message_id, source_addr=source_addr, **dict(self.bind_params, **kwargs)) self.send_pdu(pdu) returnValue(sequence_number) returnValue(0) class EsmeTransmitter(EsmeTransceiver): BIND_PDU = BindTransmitter CONNECTED_STATE = 'BOUND_TX' @inlineCallbacks def handle_bind_transmitter_resp(self, pdu): if pdu['header']['command_status'] == 'ESME_ROK': self.state = 'BOUND_TX' yield self.start_enquire_link() log.msg('STATE: %s' % (self.state)) class EsmeReceiver(EsmeTransceiver): BIND_PDU = BindReceiver CONNECTED_STATE = 'BOUND_RX' @inlineCallbacks def handle_bind_receiver_resp(self, pdu): if pdu['header']['command_status'] == 'ESME_ROK': self.state = 'BOUND_RX' yield self.start_enquire_link() log.msg('STATE: %s' % (self.state)) class EsmeTransceiverFactory(ClientFactory): def __init__(self, config, bind_params, redis, esme_callbacks): self.config = config self.bind_params = bind_params self.redis = redis self.esme = None self.esme_callbacks = esme_callbacks self.initialDelay = self.config.initial_reconnect_delay self.maxDelay = max(45, self.initialDelay) def startedConnecting(self, connector): log.msg('Started to connect.') def buildProtocol(self, addr): log.msg('Connected') self.esme = EsmeTransceiver( self.config, self.bind_params, self.redis, self.esme_callbacks) return self.esme @inlineCallbacks def clientConnectionLost(self, connector, reason): log.msg('Lost connection. 
Reason:', reason) ClientFactory.clientConnectionLost(self, connector, reason) def clientConnectionFailed(self, connector, reason): log.err(reason, 'Connection failed') ClientFactory.clientConnectionFailed(self, connector, reason) class EsmeTransmitterFactory(EsmeTransceiverFactory): def buildProtocol(self, addr): log.msg('Connected') self.esme = EsmeTransmitter( self.config, self.bind_params, self.redis, self.esme_callbacks) return self.esme class EsmeReceiverFactory(EsmeTransceiverFactory): def buildProtocol(self, addr): log.msg('Connected') self.esme = EsmeReceiver( self.config, self.bind_params, self.redis, self.esme_callbacks) return self.esme class EsmeCallbacks(object): """Callbacks for ESME factory and protocol.""" def __init__(self, connect=None, disconnect=None, submit_sm_resp=None, delivery_report=None, deliver_sm=None): self.connect = connect or self.fallback self.disconnect = disconnect or self.fallback self.submit_sm_resp = submit_sm_resp or self.fallback self.delivery_report = delivery_report or self.fallback self.deliver_sm = deliver_sm or self.fallback def fallback(self, *args, **kwargs): pass class ESME(object): """ The top-level 'Client' object. This should eventually be able to bind as a Transceiver, a Transmitter and/or a Receiver, but currently only Transceiver is implemented. """ def __init__(self, config, bind_params, redis, esme_callbacks): self.config = config self.bind_params = bind_params self.redis = redis self.esme_callbacks = esme_callbacks def bindTransciever(self): self.factory = EsmeTransceiverFactory( self.config, self.bind_params, self.redis, self.esme_callbacks) PK=JG8vumi/transports/smpp/deprecated/clientserver/__init__.pyPK=JGl6vumi/transports/smpp/deprecated/clientserver/server.py# -*- test-case-name: vumi.transports.smpp.deprecated.clientserver.tests.test_server -*- import uuid from datetime import datetime from twisted.python import log from twisted.internet import reactor from twisted.internet.protocol import Protocol, ServerFactory from smpp.pdu_builder import (BindTransceiverResp, BindTransmitterResp, BindReceiverResp, EnquireLinkResp, SubmitSMResp, DeliverSM) from smpp.pdu_inspector import binascii, unpack_pdu class SmscServer(Protocol): def __init__(self, delivery_report_string=None): log.msg('__init__', 'SmscServer') self.delivery_report_string = delivery_report_string if self.delivery_report_string is None: self.delivery_report_string = 'id:%' \ 's sub:001 dlvrd:001 submit date:%' \ 's done date:%' \ 's stat:DELIVRD err:000 text:' self.datastream = '' def pop_data(self): data = None if(len(self.datastream) >= 16): command_length = int(binascii.b2a_hex(self.datastream[0:4]), 16) if(len(self.datastream) >= command_length): data = self.datastream[0:command_length] self.datastream = self.datastream[command_length:] return data def handle_data(self, data): pdu = unpack_pdu(data) log.msg('INCOMING <<<< %r' % (pdu,)) if pdu['header']['command_id'] == 'bind_transceiver': self.handle_bind_transceiver(pdu) if pdu['header']['command_id'] == 'bind_transmitter': self.handle_bind_transmitter(pdu) if pdu['header']['command_id'] == 'bind_receiver': self.handle_bind_receiver(pdu) if pdu['header']['command_id'] == 'submit_sm': self.handle_submit_sm(pdu) if pdu['header']['command_id'] == 'enquire_link': self.handle_enquire_link(pdu) def handle_bind_transceiver(self, pdu): if pdu['header']['command_status'] == 'ESME_ROK': sequence_number = pdu['header']['sequence_number'] system_id = pdu['body']['mandatory_parameters']['system_id'] pdu_resp =
BindTransceiverResp(sequence_number, system_id=system_id) self.send_pdu(pdu_resp) def handle_bind_transmitter(self, pdu): if pdu['header']['command_status'] == 'ESME_ROK': sequence_number = pdu['header']['sequence_number'] system_id = pdu['body']['mandatory_parameters']['system_id'] pdu_resp = BindTransmitterResp(sequence_number, system_id=system_id) self.send_pdu(pdu_resp) def handle_bind_receiver(self, pdu): if pdu['header']['command_status'] == 'ESME_ROK': sequence_number = pdu['header']['sequence_number'] system_id = pdu['body']['mandatory_parameters']['system_id'] pdu_resp = BindReceiverResp(sequence_number, system_id=system_id) self.send_pdu(pdu_resp) def handle_enquire_link(self, pdu): if pdu['header']['command_status'] == 'ESME_ROK': sequence_number = pdu['header']['sequence_number'] pdu_resp = EnquireLinkResp(sequence_number) self.send_pdu(pdu_resp) def command_status(self, pdu): if pdu['body']['mandatory_parameters']['short_message'][:5] == "ESME_": return pdu['body']['mandatory_parameters']['short_message'].split( ' ')[0] else: return 'ESME_ROK' def handle_submit_sm(self, pdu): if pdu['header']['command_status'] == 'ESME_ROK': sequence_number = pdu['header']['sequence_number'] message_id = str(uuid.uuid4()) command_status = self.command_status(pdu) pdu_resp = SubmitSMResp( sequence_number, message_id, command_status) self.send_pdu(pdu_resp) reactor.callLater(0, self.delivery_report, message_id) def delivery_report(self, message_id): sequence_number = 1 short_message = (self.delivery_report_string % ( message_id, datetime.now().strftime("%y%m%d%H%M%S"), datetime.now().strftime("%y%m%d%H%M%S"))) pdu = DeliverSM(sequence_number, short_message=short_message) self.send_pdu(pdu) def dataReceived(self, data): self.datastream += data data = self.pop_data() while data is not None: self.handle_data(data) data = self.pop_data() def send_pdu(self, pdu): data = pdu.get_bin() log.msg('OUTGOING >>>> %r' % (unpack_pdu(data),)) self.transport.write(data) class SmscServerFactory(ServerFactory): protocol = SmscServer def __init__(self, delivery_report_string=None): self.delivery_report_string = delivery_report_string def buildProtocol(self, addr): self.smsc = self.protocol(self.delivery_report_string) return self.smsc PK=JG;vumi/transports/smpp/deprecated/clientserver/tests/utils.py""" Some utilities and things for testing various bits of SMPP. """ from twisted.internet.defer import DeferredQueue from smpp.pdu_inspector import unpack_pdu from vumi.transports.smpp.deprecated.clientserver.server import SmscServer class SmscTestServer(SmscServer): """ SMSC subclass that records inbound and outbound PDUs for later assertion. 
""" def __init__(self, delivery_report_string=None): self.pdu_queue = DeferredQueue() SmscServer.__init__(self, delivery_report_string) def handle_data(self, data): self.pdu_queue.put({ 'direction': 'inbound', 'pdu': unpack_pdu(data), }) return SmscServer.handle_data(self, data) def send_pdu(self, pdu): self.pdu_queue.put({ 'direction': 'outbound', 'pdu': pdu.get_obj(), }) return SmscServer.send_pdu(self, pdu) PK=JGTTAvumi/transports/smpp/deprecated/clientserver/tests/test_client.pyfrom twisted.internet.task import Clock from twisted.internet.defer import inlineCallbacks, returnValue from smpp.pdu_builder import DeliverSM, BindTransceiverResp, Unbind from smpp.pdu import unpack_pdu from vumi.tests.utils import LogCatcher from vumi.transports.smpp.deprecated.clientserver.client import ( EsmeTransceiver, EsmeReceiver, EsmeTransmitter, EsmeCallbacks, ESME, unpacked_pdu_opts) from vumi.transports.smpp.deprecated.transport import SmppTransportConfig from vumi.tests.helpers import VumiTestCase, PersistenceHelper class FakeTransport(object): def __init__(self, protocol): self.connected = True self.protocol = protocol def loseConnection(self): self.connected = False self.protocol.connectionLost() class FakeEsmeMixin(object): def setup_fake(self): self.transport = FakeTransport(self) self.clock = Clock() self.callLater = self.clock.callLater self.fake_sent_pdus = [] def fake_send_pdu(self, pdu): self.fake_sent_pdus.append(pdu) class FakeEsmeTransceiver(EsmeTransceiver, FakeEsmeMixin): def __init__(self, *args, **kwargs): EsmeTransceiver.__init__(self, *args, **kwargs) self.setup_fake() def send_pdu(self, pdu): return self.fake_send_pdu(pdu) class FakeEsmeReceiver(EsmeReceiver, FakeEsmeMixin): def __init__(self, *args, **kwargs): EsmeReceiver.__init__(self, *args, **kwargs) self.setup_fake() def send_pdu(self, pdu): return self.fake_send_pdu(pdu) class FakeEsmeTransmitter(EsmeTransmitter, FakeEsmeMixin): def __init__(self, *args, **kwargs): EsmeTransmitter.__init__(self, *args, **kwargs) self.setup_fake() def send_pdu(self, pdu): return self.fake_send_pdu(pdu) class EsmeTestCaseBase(VumiTestCase): ESME_CLASS = None def setUp(self): self.persistence_helper = self.add_helper(PersistenceHelper()) self._expected_callbacks = [] self.add_cleanup( self.assertEqual, self._expected_callbacks, [], "Uncalled callbacks.") def get_unbound_esme(self, host="127.0.0.1", port="0", system_id="1234", password="password", callbacks={}, extra_config={}): config_data = { "transport_name": "transport_name", "host": host, "port": port, "system_id": system_id, "password": password, } config_data.update(extra_config) config = SmppTransportConfig(config_data) esme_callbacks = EsmeCallbacks(**callbacks) def purge_manager(redis_manager): d = redis_manager._purge_all() # just in case d.addCallback(lambda result: redis_manager) return d redis_d = self.persistence_helper.get_redis_manager() redis_d.addCallback(purge_manager) return redis_d.addCallback( lambda r: self.ESME_CLASS(config, { 'system_id': system_id, 'password': password }, r, esme_callbacks)) @inlineCallbacks def get_esme(self, config={}, **callbacks): esme = yield self.get_unbound_esme(extra_config=config, callbacks=callbacks) yield esme.connectionMade() esme.fake_sent_pdus.pop() # Clear bind PDU. 
esme.state = esme.CONNECTED_STATE returnValue(esme) def get_sm(self, msg, data_coding=3): sm = DeliverSM(1, short_message=msg, data_coding=data_coding) return unpack_pdu(sm.get_bin()) def make_cb(self, fun): cb_id = len(self._expected_callbacks) self._expected_callbacks.append(cb_id) def cb(**value): self._expected_callbacks.remove(cb_id) return fun(value) return cb def assertion_cb(self, expected, *message_path): def fun(value): for k in message_path: value = value[k] self.assertEqual(expected, value) return self.make_cb(fun) class EsmeGenericMixin(object): """Generic tests.""" @inlineCallbacks def test_bind_timeout(self): callbacks_called = [] esme = yield self.get_unbound_esme(callbacks={ 'connect': lambda client: callbacks_called.append('connect'), 'disconnect': lambda: callbacks_called.append('disconnect'), }) yield esme.connectionMade() self.assertEqual([], callbacks_called) self.assertEqual(True, esme.transport.connected) self.assertNotEqual(None, esme._lose_conn) esme.clock.advance(esme.smpp_bind_timeout) self.assertEqual(['disconnect'], callbacks_called) self.assertEqual(False, esme.transport.connected) self.assertEqual(None, esme._lose_conn) @inlineCallbacks def test_bind_no_timeout(self): callbacks_called = [] esme = yield self.get_unbound_esme(callbacks={ 'connect': lambda client: callbacks_called.append('connect'), 'disconnect': lambda: callbacks_called.append('disconnect'), }) yield esme.connectionMade() self.assertEqual([], callbacks_called) self.assertEqual(True, esme.transport.connected) self.assertNotEqual(None, esme._lose_conn) esme.handle_bind_transceiver_resp(unpack_pdu( BindTransceiverResp(1).get_bin())) self.assertEqual(['connect'], callbacks_called) self.assertEqual(True, esme.transport.connected) self.assertEqual(None, esme._lose_conn) esme.lc_enquire.stop() yield esme.lc_enquire.deferred @inlineCallbacks def test_bind_and_disconnect(self): callbacks_called = [] esme = yield self.get_unbound_esme(callbacks={ 'connect': lambda client: callbacks_called.append('connect'), 'disconnect': lambda: callbacks_called.append('disconnect'), }) yield esme.connectionMade() esme.handle_bind_transceiver_resp(unpack_pdu( BindTransceiverResp(1).get_bin())) self.assertEqual(['connect'], callbacks_called) esme.lc_enquire.stop() yield esme.lc_enquire.deferred yield esme.transport.loseConnection() self.assertEqual(['connect', 'disconnect'], callbacks_called) self.assertEqual(False, esme.transport.connected) @inlineCallbacks def test_sequence_rollover(self): esme = yield self.get_unbound_esme() self.assertEqual(1, (yield esme.get_next_seq())) self.assertEqual(2, (yield esme.get_next_seq())) yield esme.redis.set('smpp_last_sequence_number', 0xFFFF0000) self.assertEqual(0xFFFF0001, (yield esme.get_next_seq())) self.assertEqual(1, (yield esme.get_next_seq())) @inlineCallbacks def test_unbind(self): esme = yield self.get_esme() yield esme.handle_data(Unbind(1).get_bin()) self.assertEqual(False, esme.transport.connected) class EsmeTransmitterMixin(EsmeGenericMixin): """Transmitter-side tests.""" @inlineCallbacks def test_submit_sm_sms(self): """Submit a simple SMS message.""" esme = yield self.get_esme() yield esme.submit_sm(short_message='hello') [sm_pdu] = esme.fake_sent_pdus sm = unpack_pdu(sm_pdu.get_bin()) self.assertEqual('submit_sm', sm['header']['command_id']) self.assertEqual( 'hello', sm['body']['mandatory_parameters']['short_message']) self.assertEqual([], sm['body'].get('optional_parameters', [])) @inlineCallbacks def test_submit_sm_sms_long(self): """Submit
a long SMS message using the `message_payload` optional field.""" esme = yield self.get_esme(config={ 'send_long_messages': True, }) long_message = 'This is a long message.' * 20 yield esme.submit_sm(short_message=long_message) [sm_pdu] = esme.fake_sent_pdus sm = unpack_pdu(sm_pdu.get_bin()) pdu_opts = unpacked_pdu_opts(sm) self.assertEqual('submit_sm', sm['header']['command_id']) self.assertEqual( None, sm['body']['mandatory_parameters']['short_message']) self.assertEqual(''.join('%02x' % ord(c) for c in long_message), pdu_opts['message_payload']) @inlineCallbacks def test_submit_sm_sms_multipart_sar(self): """Submit a long SMS message using multipart sar fields.""" esme = yield self.get_esme(config={ 'send_multipart_sar': True, }) long_message = 'This is a long message.' * 20 seq_nums = yield esme.submit_sm(short_message=long_message) self.assertEqual([2, 3, 4, 5], seq_nums) self.assertEqual(4, len(esme.fake_sent_pdus)) msg_parts = [] msg_refs = [] for i, sm_pdu in enumerate(esme.fake_sent_pdus): sm = unpack_pdu(sm_pdu.get_bin()) pdu_opts = unpacked_pdu_opts(sm) mandatory_parameters = sm['body']['mandatory_parameters'] self.assertEqual('submit_sm', sm['header']['command_id']) msg_parts.append(mandatory_parameters['short_message']) self.assertTrue(len(mandatory_parameters['short_message']) <= 130) msg_refs.append(pdu_opts['sar_msg_ref_num']) self.assertEqual(i + 1, pdu_opts['sar_segment_seqnum']) self.assertEqual(4, pdu_opts['sar_total_segments']) self.assertEqual(long_message, ''.join(msg_parts)) self.assertEqual(1, len(set(msg_refs))) @inlineCallbacks def test_submit_sm_sms_multipart_udh(self): """Submit a long SMS message using multipart user data headers.""" esme = yield self.get_esme(config={ 'send_multipart_udh': True, }) long_message = 'This is a long message.'
* 20 seq_nums = yield esme.submit_sm(short_message=long_message) self.assertEqual([2, 3, 4, 5], seq_nums) self.assertEqual(4, len(esme.fake_sent_pdus)) msg_parts = [] msg_refs = [] for i, sm_pdu in enumerate(esme.fake_sent_pdus): sm = unpack_pdu(sm_pdu.get_bin()) mandatory_parameters = sm['body']['mandatory_parameters'] self.assertEqual('submit_sm', sm['header']['command_id']) msg = mandatory_parameters['short_message'] udh_hlen, udh_tag, udh_len, udh_ref, udh_tot, udh_seq = [ ord(octet) for octet in msg[:6]] self.assertEqual(5, udh_hlen) self.assertEqual(0, udh_tag) self.assertEqual(3, udh_len) msg_refs.append(udh_ref) self.assertEqual(4, udh_tot) self.assertEqual(i + 1, udh_seq) self.assertTrue(len(msg) <= 136) msg_parts.append(msg[6:]) self.assertEqual(0x40, mandatory_parameters['esm_class']) self.assertEqual(long_message, ''.join(msg_parts)) self.assertEqual(1, len(set(msg_refs))) @inlineCallbacks def test_submit_sm_ussd_continue(self): """Submit a USSD message with a session continue flag.""" esme = yield self.get_esme() yield esme.submit_sm( short_message='hello', message_type='ussd', continue_session=True, session_info='0100') [sm_pdu] = esme.fake_sent_pdus sm = unpack_pdu(sm_pdu.get_bin()) pdu_opts = unpacked_pdu_opts(sm) self.assertEqual('submit_sm', sm['header']['command_id']) self.assertEqual( 'hello', sm['body']['mandatory_parameters']['short_message']) self.assertEqual('02', pdu_opts['ussd_service_op']) self.assertEqual('0100', pdu_opts['its_session_info']) @inlineCallbacks def test_submit_sm_ussd_close(self): """Submit a USSD message with a session close flag.""" esme = yield self.get_esme() yield esme.submit_sm( short_message='hello', message_type='ussd', continue_session=False) [sm_pdu] = esme.fake_sent_pdus sm = unpack_pdu(sm_pdu.get_bin()) pdu_opts = unpacked_pdu_opts(sm) self.assertEqual('submit_sm', sm['header']['command_id']) self.assertEqual( 'hello', sm['body']['mandatory_parameters']['short_message']) self.assertEqual('02', pdu_opts['ussd_service_op']) self.assertEqual('0001', pdu_opts['its_session_info']) class EsmeReceiverMixin(EsmeGenericMixin): """Receiver-side tests.""" @inlineCallbacks def test_deliver_sm_simple(self): """A simple message should be delivered.""" esme = yield self.get_esme( deliver_sm=self.assertion_cb(u'hello', 'short_message')) yield esme.handle_deliver_sm(self.get_sm('hello')) @inlineCallbacks def test_deliver_sm_message_payload(self): """A message in the `message_payload` field should be delivered.""" esme = yield self.get_esme( deliver_sm=self.assertion_cb(u'hello', 'short_message')) sm = DeliverSM(1, short_message='') sm.add_message_payload(''.join('%02x' % ord(c) for c in 'hello')) yield esme.handle_deliver_sm(unpack_pdu(sm.get_bin())) @inlineCallbacks def test_deliver_sm_data_coding_override(self): """A simple message should be delivered.""" esme = yield self.get_esme(config={ 'data_coding_overrides': { 0: 'utf-16be' } }, deliver_sm=self.assertion_cb(u'hello', 'short_message')) yield esme.handle_deliver_sm( self.get_sm('\x00h\x00e\x00l\x00l\x00o', 0)) esme = yield self.get_esme(config={ 'data_coding_overrides': { 0: 'ascii' } }, deliver_sm=self.assertion_cb(u'hello', 'short_message')) yield esme.handle_deliver_sm( self.get_sm('hello', 0)) @inlineCallbacks def test_deliver_sm_ucs2(self): """A UCS-2 message should be delivered.""" esme = yield self.get_esme( deliver_sm=self.assertion_cb(u'hello', 'short_message')) yield esme.handle_deliver_sm( self.get_sm('\x00h\x00e\x00l\x00l\x00o', 8)) @inlineCallbacks def test_bad_sm_ucs2(self): """An 
invalid UCS-2 message should be passed through undecoded (the decode error is logged).""" bad_msg = '\n\x00h\x00e\x00l\x00l\x00o' esme = yield self.get_esme( deliver_sm=self.assertion_cb(bad_msg, 'short_message')) yield esme.handle_deliver_sm(self.get_sm(bad_msg, 8)) self.flushLoggedErrors() @inlineCallbacks def test_deliver_sm_delivery_report_delivered(self): esme = yield self.get_esme(delivery_report=self.assertion_cb({ 'message_id': '1b1720be-5f48-41c4-b3f8-6e59dbf45366', 'message_state': 'DELIVERED', })) sm = DeliverSM(1, short_message='delivery report') sm._PDU__add_optional_parameter( 'receipted_message_id', '1b1720be-5f48-41c4-b3f8-6e59dbf45366') sm._PDU__add_optional_parameter('message_state', 2) yield esme.handle_deliver_sm(unpack_pdu(sm.get_bin())) @inlineCallbacks def test_deliver_sm_delivery_report_rejected(self): esme = yield self.get_esme(delivery_report=self.assertion_cb({ 'message_id': '1b1720be-5f48-41c4-b3f8-6e59dbf45366', 'message_state': 'REJECTED', })) sm = DeliverSM(1, short_message='delivery report') sm._PDU__add_optional_parameter( 'receipted_message_id', '1b1720be-5f48-41c4-b3f8-6e59dbf45366') sm._PDU__add_optional_parameter('message_state', 8) yield esme.handle_deliver_sm(unpack_pdu(sm.get_bin())) @inlineCallbacks def test_deliver_sm_delivery_report_regex_fallback(self): esme = yield self.get_esme(delivery_report=self.assertion_cb({ 'message_id': '1b1720be-5f48-41c4-b3f8-6e59dbf45366', 'message_state': 'DELIVRD', })) yield esme.handle_deliver_sm(self.get_sm( 'id:1b1720be-5f48-41c4-b3f8-6e59dbf45366 sub:001 dlvrd:001 ' 'submit date:120726132548 done date:120726132548 stat:DELIVRD ' 'err:000 text:')) @inlineCallbacks def test_deliver_sm_delivery_report_regex_fallback_ucs2(self): esme = yield self.get_esme(delivery_report=self.assertion_cb({ 'message_id': '1b1720be-5f48', 'message_state': 'DELIVRD', })) dr_text = ( u'id:1b1720be-5f48 sub:001 dlvrd:001 ' u'submit date:120726132548 done date:120726132548 stat:DELIVRD ' u'err:000 text:').encode('utf-16be') yield esme.handle_deliver_sm(self.get_sm(dr_text, 8)) @inlineCallbacks def test_deliver_sm_delivery_report_regex_fallback_ucs2_long(self): esme = yield self.get_esme(delivery_report=self.assertion_cb({ 'message_id': '1b1720be-5f48-41c4-b3f8-6e59dbf45366', 'message_state': 'DELIVRD', })) dr_text = ( u'id:1b1720be-5f48-41c4-b3f8-6e59dbf45366 sub:001 dlvrd:001 ' u'submit date:120726132548 done date:120726132548 stat:DELIVRD ' u'err:000 text:').encode('utf-16be') sm = DeliverSM(1, short_message='', data_coding=8) sm.add_message_payload(dr_text.encode('hex')) yield esme.handle_deliver_sm(unpack_pdu(sm.get_bin())) @inlineCallbacks def test_deliver_sm_multipart(self): esme = yield self.get_esme( deliver_sm=self.assertion_cb(u'hello world', 'short_message')) yield esme.handle_deliver_sm(self.get_sm( "\x05\x00\x03\xff\x02\x02 world")) yield esme.handle_deliver_sm(self.get_sm( "\x05\x00\x03\xff\x02\x01hello")) @inlineCallbacks def test_deliver_sm_multipart_weird_coding(self): esme = yield self.get_esme( deliver_sm=self.assertion_cb(u'hello', 'short_message')) yield esme.handle_deliver_sm(self.get_sm( "\x05\x00\x03\xff\x02\x02l\x00l\x00o", 8)) yield esme.handle_deliver_sm(self.get_sm( "\x05\x00\x03\xff\x02\x01\x00h\x00e\x00", 8)) @inlineCallbacks def test_deliver_sm_multipart_arabic_ucs2(self): esme = yield self.get_esme( deliver_sm=self.assertion_cb( ('\xd8\xa7\xd9\x84\xd9\x84\xd9\x87 ' '\xd9\x85\xd8\xb9\xd9\x83').decode('utf-8'), 'short_message'), config={ 'data_coding_overrides': { 8: 'utf-8' } }) yield esme.handle_deliver_sm(self.get_sm(
"\x05\x00\x03\xff\x02\x01\xd8\xa7\xd9\x84\xd9\x84\xd9\x87 ", 8)) yield esme.handle_deliver_sm(self.get_sm( "\x05\x00\x03\xff\x02\x02\xd9\x85\xd8\xb9\xd9\x83", 8)) @inlineCallbacks def test_deliver_sm_ussd_start(self): def assert_ussd(value): self.assertEqual('ussd', value['message_type']) self.assertEqual('new', value['session_event']) self.assertEqual(None, value['short_message']) esme = yield self.get_esme(deliver_sm=self.make_cb(assert_ussd)) sm = DeliverSM(1) sm._PDU__add_optional_parameter('ussd_service_op', '01') sm._PDU__add_optional_parameter('its_session_info', '0000') yield esme.handle_deliver_sm(unpack_pdu(sm.get_bin())) class TestEsmeTransceiver(EsmeTestCaseBase, EsmeReceiverMixin, EsmeTransmitterMixin): ESME_CLASS = FakeEsmeTransceiver class TestEsmeTransmitter(EsmeTestCaseBase, EsmeTransmitterMixin): ESME_CLASS = FakeEsmeTransmitter @inlineCallbacks def test_deliver_sm_simple(self): """A message delivery should log an error since we're supposed to be a transmitter only.""" def cb(**kw): self.assertEqual(u'hello', kw['short_message']) with LogCatcher() as log: esme = yield self.get_esme(deliver_sm=cb) esme.state = 'BOUND_TX' # Assume we've bound correctly as a TX esme.handle_deliver_sm(self.get_sm('hello')) [error] = log.errors self.assertTrue('deliver_sm in wrong state' in error['message'][0]) class TestEsmeReceiver(EsmeTestCaseBase, EsmeReceiverMixin): ESME_CLASS = FakeEsmeReceiver @inlineCallbacks def test_submit_sm_simple(self): """A simple message log an error when trying to send over a receiver.""" with LogCatcher() as log: esme = yield self.get_esme() esme.state = 'BOUND_RX' # Fake RX bind yield esme.submit_sm(short_message='hello') [error] = log.errors self.assertTrue(('submit_sm in wrong state' in error['message'][0])) class TestESME(VumiTestCase): def setUp(self): config = SmppTransportConfig({ "transport_name": "transport_name", "host": '127.0.0.1', "port": 2775, "system_id": 'test_system', "password": 'password', }) self.kvs = None self.esme_callbacks = None self.esme = ESME(config, { 'system_id': 'test_system', 'password': 'password', }, self.kvs, self.esme_callbacks) def test_bind_as_transceiver(self): return self.esme.bindTransciever() PK=JGEEAvumi/transports/smpp/deprecated/clientserver/tests/test_server.py# No tests yet, the server is currently only used to test the client PK=JG>vumi/transports/smpp/deprecated/clientserver/tests/__init__.pyPK=JG^/2vumi/transports/smpp/deprecated/tests/test_smpp.py# -*- coding: utf-8 -*- import binascii from twisted.internet.defer import Deferred, inlineCallbacks, succeed from twisted.internet.task import Clock from smpp.pdu_builder import SubmitSMResp, DeliverSM from vumi.config import ConfigError from vumi.message import TransportUserMessage from vumi.transports.smpp.deprecated.clientserver.client import ( EsmeTransceiver, EsmeCallbacks) from vumi.transports.smpp.deprecated.transport import ( SmppTransport, SmppTxTransport, SmppRxTransport) from vumi.transports.smpp.deprecated.service import SmppService from vumi.transports.smpp.deprecated.clientserver.client import ( unpacked_pdu_opts) from vumi.transports.smpp.deprecated.clientserver.tests.utils import ( SmscTestServer) from vumi.tests.utils import LogCatcher from vumi.transports.tests.helpers import TransportHelper from vumi.tests.helpers import VumiTestCase class TestSmppTransportConfig(VumiTestCase): def required_config(self, config_params): config = { "transport_name": "my_transport", "twisted_endpoint": "tcp:host=127.0.0.1:port=0", "system_id": 
"vumitest-vumitest-vumitest", "password": "password", } config.update(config_params) return config def get_config(self, config_dict): return SmppTransport.CONFIG_CLASS(config_dict) def assert_config_error(self, config_dict): try: self.get_config(config_dict) self.fail("ConfigError not raised.") except ConfigError as err: return err.args[0] def test_long_message_params(self): self.get_config(self.required_config({})) self.get_config(self.required_config({'send_long_messages': True})) self.get_config(self.required_config({'send_multipart_sar': True})) self.get_config(self.required_config({'send_multipart_udh': True})) errmsg = self.assert_config_error(self.required_config({ 'send_long_messages': True, 'send_multipart_sar': True, })) self.assertEqual(errmsg, ( "The following parameters are mutually exclusive: " "send_long_messages, send_multipart_sar")) errmsg = self.assert_config_error(self.required_config({ 'send_long_messages': True, 'send_multipart_sar': True, 'send_multipart_udh': True, })) self.assertEqual(errmsg, ( "The following parameters are mutually exclusive: " "send_long_messages, send_multipart_sar, send_multipart_udh")) class TestSmppTransport(VumiTestCase): @inlineCallbacks def setUp(self): config = { "system_id": "vumitest-vumitest-vumitest", "twisted_endpoint": "tcp:host=127.0.0.1:port=0", "password": "password", "smpp_bind_timeout": 12, "smpp_enquire_link_interval": 123, "third_party_id_expiry": 3600, # just 1 hour } # hack a lot of transport setup self.tx_helper = self.add_helper(TransportHelper(SmppTransport)) self.transport = yield self.tx_helper.get_transport( config, start=False) self.transport.esme_client = None yield self.transport.startWorker() self._make_esme() self.transport.esme_client = self.esme self.transport.esme_connected(self.esme) def _make_esme(self): self.esme_callbacks = EsmeCallbacks( connect=lambda: None, disconnect=lambda: None, submit_sm_resp=self.transport.submit_sm_resp, delivery_report=self.transport.delivery_report, deliver_sm=lambda: None) self.esme = EsmeTransceiver( self.transport.get_static_config(), self.transport.get_smpp_bind_params(), self.transport.redis, self.esme_callbacks) self.esme.sent_pdus = [] self.esme.send_pdu = self.esme.sent_pdus.append self.esme.state = 'BOUND_TRX' def assert_sent_contents(self, expected): pdu_contents = [p.obj['body']['mandatory_parameters']['short_message'] for p in self.esme.sent_pdus] self.assertEqual(expected, pdu_contents) @inlineCallbacks def test_message_persistence(self): # A simple test of set -> get -> delete for redis message persistence message1 = self.tx_helper.make_outbound("hello world") original_json = message1.to_json() yield self.transport.r_set_message(message1) retrieved_json = yield self.transport.r_get_message_json( message1['message_id']) self.assertEqual(original_json, retrieved_json) retrieved_message = yield self.transport.r_get_message( message1['message_id']) self.assertEqual(retrieved_message, message1) self.assertTrue((yield self.transport.r_delete_message( message1['message_id']))) self.assertEqual((yield self.transport.r_get_message_json( message1['message_id'])), None) self.assertEqual((yield self.transport.r_get_message( message1['message_id'])), None) @inlineCallbacks def test_message_persistence_expiry(self): message = self.tx_helper.make_outbound("hello world") yield self.transport.r_set_message(message) # check that the expiry is set message_key = self.transport.r_message_key(message['message_id']) config = self.transport.get_static_config() ttl = yield 
self.transport.redis.ttl(message_key) self.assertTrue(0 < ttl <= config.submit_sm_expiry) @inlineCallbacks def test_redis_third_party_id_persistence(self): # Testing: set -> get -> delete, for redis third party id mapping self.assertEqual( self.transport.get_static_config().third_party_id_expiry, 3600) our_id = "blergh34534545433454354" their_id = "omghesvomitingnumbers" yield self.transport.r_set_id_for_third_party_id(their_id, our_id) retrieved_our_id = ( yield self.transport.r_get_id_for_third_party_id(their_id)) self.assertEqual(our_id, retrieved_our_id) self.assertTrue(( yield self.transport.r_delete_for_third_party_id(their_id))) self.assertEqual(None, ( yield self.transport.r_get_id_for_third_party_id(their_id))) @inlineCallbacks def test_out_of_order_responses(self): # Sequence numbers are hardcoded, assuming we start fresh from 0. yield self.tx_helper.make_dispatch_outbound("msg 1", message_id='444') response1 = SubmitSMResp(1, "3rd_party_id_1") yield self.tx_helper.make_dispatch_outbound("msg 2", message_id='445') response2 = SubmitSMResp(2, "3rd_party_id_2") self.assert_sent_contents(["msg 1", "msg 2"]) # respond out of order - just to keep things interesting yield self.esme.handle_data(response2.get_bin()) yield self.esme.handle_data(response1.get_bin()) [ack1, ack2] = self.tx_helper.get_dispatched_events() self.assertEqual(ack1['user_message_id'], '445') self.assertEqual(ack1['sent_message_id'], '3rd_party_id_2') self.assertEqual(ack2['user_message_id'], '444') self.assertEqual(ack2['sent_message_id'], '3rd_party_id_1') @inlineCallbacks def test_failed_submit(self): message = yield self.tx_helper.make_dispatch_outbound( "message", message_id='446') response = SubmitSMResp( 1, "3rd_party_id_3", command_status="ESME_RSUBMITFAIL") yield self.esme.handle_data(response.get_bin()) self.assert_sent_contents(["message"]) # There should be a nack [nack] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(nack['user_message_id'], message['message_id']) self.assertEqual(nack['nack_reason'], 'ESME_RSUBMITFAIL') [failure] = yield self.tx_helper.get_dispatched_failures() self.assertEqual(failure['reason'], 'ESME_RSUBMITFAIL') @inlineCallbacks def test_failed_submit_with_no_reason(self): message = yield self.tx_helper.make_dispatch_outbound( "message", message_id='446') # Equivalent of SubmitSMResp(1, "3rd_party_id_3", command_status='XXX') # but with a bad command_status (pdu_builder can't produce binary with # command_statuses it doesn't understand). Use # smpp.pdu.unpack(response_bin) to get a PDU object: response_hex = ("0000001f80000004" "0000ffff" # unknown command status "000000013372645f70617274795f69645f3300") yield self.esme.handle_data(binascii.a2b_hex(response_hex)) self.assert_sent_contents(["message"]) # There should be a nack [nack] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(nack['user_message_id'], message['message_id']) self.assertEqual(nack['nack_reason'], 'Unspecified') [failure] = yield self.tx_helper.get_dispatched_failures() self.assertEqual(failure['reason'], 'Unspecified') @inlineCallbacks def test_delivery_report_for_unknown_message(self): dr = ("id:123 sub:... dlvrd:... submit date:200101010030" " done date:200101020030 stat:DELIVRD err:... text:Meep") deliver = DeliverSM(1, short_message=dr) with LogCatcher(message="Failed to retrieve message id") as lc: yield self.esme.handle_data(deliver.get_bin()) [warning] = lc.logs self.assertEqual(warning['message'], ("Failed to retrieve message id for delivery " "report. 
Delivery report from %s " "discarded." % self.tx_helper.transport_name,)) @inlineCallbacks def test_throttled_submit_ESME_RTHROTTLED(self): clock = Clock() self.transport.callLater = clock.callLater def assert_throttled_status(throttled, messages, acks): self.assertEqual(self.transport.throttled, throttled) self.assert_sent_contents(messages) self.assertEqual(acks, [ (m['user_message_id'], m['sent_message_id']) for m in self.tx_helper.get_dispatched_events()]) self.assertEqual([], self.tx_helper.get_dispatched_failures()) assert_throttled_status(False, [], []) yield self.tx_helper.make_dispatch_outbound( "Heimlich", message_id="447") response = SubmitSMResp(1, "3rd_party_id_4", command_status="ESME_RTHROTTLED") yield self.esme.handle_data(response.get_bin()) assert_throttled_status(True, ["Heimlich"], []) # Still waiting to resend clock.advance(0.05) yield self.transport.redis.exists('wait for redis') assert_throttled_status(True, ["Heimlich"], []) # Don't wait for this, because it won't be processed until later. self.tx_helper.make_dispatch_outbound("Other", message_id="448") assert_throttled_status(True, ["Heimlich"], []) # Resent clock.advance(0.05) yield self.transport.redis.exists('wait for redis') assert_throttled_status(True, ["Heimlich", "Heimlich"], []) # And acknowledged by the other side yield self.esme.handle_data(SubmitSMResp(2, "3rd_party_5").get_bin()) yield self.tx_helper.kick_delivery() yield self.esme.handle_data(SubmitSMResp(3, "3rd_party_6").get_bin()) assert_throttled_status( False, ["Heimlich", "Heimlich", "Other"], [('447', '3rd_party_5'), ('448', '3rd_party_6')]) @inlineCallbacks def test_throttled_submit_ESME_RMSGQFUL(self): clock = Clock() self.transport.callLater = clock.callLater def assert_throttled_status(throttled, messages, acks): self.assertEqual(self.transport.throttled, throttled) self.assert_sent_contents(messages) self.assertEqual(acks, [ (m['user_message_id'], m['sent_message_id']) for m in self.tx_helper.get_dispatched_events()]) self.assertEqual([], self.tx_helper.get_dispatched_failures()) assert_throttled_status(False, [], []) yield self.tx_helper.make_dispatch_outbound( "Heimlich", message_id="447") response = SubmitSMResp(1, "3rd_party_id_4", command_status="ESME_RMSGQFUL") yield self.esme.handle_data(response.get_bin()) assert_throttled_status(True, ["Heimlich"], []) # Still waiting to resend clock.advance(0.05) yield self.transport.redis.exists('wait for redis') assert_throttled_status(True, ["Heimlich"], []) # Don't wait for this, because it won't be processed until later. 
self.tx_helper.make_dispatch_outbound("Other", message_id="448") assert_throttled_status(True, ["Heimlich"], []) # Resent clock.advance(0.05) yield self.transport.redis.exists('wait for redis') assert_throttled_status(True, ["Heimlich", "Heimlich"], []) # And acknowledged by the other side yield self.esme.handle_data(SubmitSMResp(2, "3rd_party_5").get_bin()) yield self.tx_helper.kick_delivery() yield self.esme.handle_data(SubmitSMResp(3, "3rd_party_6").get_bin()) assert_throttled_status( False, ["Heimlich", "Heimlich", "Other"], [('447', '3rd_party_5'), ('448', '3rd_party_6')]) @inlineCallbacks def test_reconnect(self): connector = self.transport.connectors[self.transport.transport_name] self.assertFalse(connector._consumers['outbound'].paused) yield self.transport.esme_disconnected() self.assertTrue(connector._consumers['outbound'].paused) yield self.transport.esme_disconnected() self.assertTrue(connector._consumers['outbound'].paused) yield self.transport.esme_connected(self.esme) self.assertFalse(connector._consumers['outbound'].paused) yield self.transport.esme_connected(self.esme) self.assertFalse(connector._consumers['outbound'].paused) class MockSmppTransport(SmppTransport): @inlineCallbacks def esme_connected(self, client): yield super(MockSmppTransport, self).esme_connected(client) self._block_till_bind.callback(None) class MockSmppTxTransport(SmppTxTransport): @inlineCallbacks def esme_connected(self, client): yield super(MockSmppTxTransport, self).esme_connected(client) self._block_till_bind.callback(None) class MockSmppRxTransport(SmppRxTransport): @inlineCallbacks def esme_connected(self, client): yield super(MockSmppRxTransport, self).esme_connected(client) self._block_till_bind.callback(None) def mk_expected_pdu(direction, sequence_number, command_id, **extras): headers = { 'command_status': 'ESME_ROK', 'sequence_number': sequence_number, 'command_id': command_id, } headers.update(extras) return {"direction": direction, "pdu": {"header": headers}} class EsmeToSmscTestCase(VumiTestCase): CONFIG_OVERRIDE = {} @inlineCallbacks def setUp(self): self.tx_helper = self.add_helper(TransportHelper(MockSmppTransport)) server_config = { "transport_name": self.tx_helper.transport_name, "system_id": "VumiTestSMSC", "password": "password", "twisted_endpoint": "tcp:0", "transport_type": "smpp", } server_config.update(self.CONFIG_OVERRIDE) self.service = SmppService(None, config=server_config) self.add_cleanup(self.cleanup_service) yield self.service.startWorker() self.service.factory.protocol = SmscTestServer host = self.service.listening.getHost() client_config = server_config.copy() client_config['twisted_endpoint'] = 'tcp:host=%s:port=%s' % ( host.host, host.port) self.transport = yield self.tx_helper.get_transport( client_config, start=False) self.expected_delivery_status = 'delivered' @inlineCallbacks def cleanup_service(self): yield self.service.listening.stopListening() yield self.service.listening.loseConnection() def assert_pdu_header(self, expected, actual, field): self.assertEqual(expected['pdu']['header'][field], actual['pdu']['header'][field]) def assert_server_pdu(self, expected, actual): self.assertEqual(expected['direction'], actual['direction']) self.assert_pdu_header(expected, actual, 'sequence_number') self.assert_pdu_header(expected, actual, 'command_status') self.assert_pdu_header(expected, actual, 'command_id') @inlineCallbacks def clear_link_pdus(self): for expected in [ mk_expected_pdu("inbound", 1, "bind_transceiver"), mk_expected_pdu("outbound", 1, 
"bind_transceiver_resp"), mk_expected_pdu("inbound", 2, "enquire_link"), mk_expected_pdu("outbound", 2, "enquire_link_resp")]: pdu = yield self.service.factory.smsc.pdu_queue.get() self.assert_server_pdu(expected, pdu) @inlineCallbacks def startTransport(self): self.transport._block_till_bind = Deferred() yield self.transport.startWorker() @inlineCallbacks def test_handshake_submit_and_deliver(self): # 1111111111111111111111111111111111111111111111111 expected_pdus_1 = [ mk_expected_pdu("inbound", 1, "bind_transceiver"), mk_expected_pdu("outbound", 1, "bind_transceiver_resp"), mk_expected_pdu("inbound", 2, "enquire_link"), mk_expected_pdu("outbound", 2, "enquire_link_resp"), ] # 2222222222222222222222222222222222222222222222222 expected_pdus_2 = [ mk_expected_pdu("inbound", 3, "submit_sm"), mk_expected_pdu("outbound", 3, "submit_sm_resp"), # the delivery report mk_expected_pdu("outbound", 1, "deliver_sm"), mk_expected_pdu("inbound", 1, "deliver_sm_resp"), ] # 3333333333333333333333333333333333333333333333333 expected_pdus_3 = [ # a sms delivered by the smsc mk_expected_pdu("outbound", 555, "deliver_sm"), mk_expected_pdu("inbound", 555, "deliver_sm_resp"), ] ## Startup yield self.startTransport() yield self.transport._block_till_bind # First we make sure the Client binds to the Server # and enquire_link pdu's are exchanged as expected pdu_queue = self.service.factory.smsc.pdu_queue for expected_message in expected_pdus_1: actual_message = yield pdu_queue.get() self.assert_server_pdu(expected_message, actual_message) # Next the Client submits a SMS to the Server # and recieves an ack and a delivery_report msg = yield self.tx_helper.make_dispatch_outbound("hello world") for expected_message in expected_pdus_2: actual_message = yield pdu_queue.get() self.assert_server_pdu(expected_message, actual_message) # We need the user_message_id to check the ack user_message_id = msg["message_id"] [ack, delv] = yield self.tx_helper.wait_for_dispatched_events(2) self.assertEqual(ack['message_type'], 'event') self.assertEqual(ack['event_type'], 'ack') self.assertEqual(ack['transport_name'], self.tx_helper.transport_name) self.assertEqual(ack['user_message_id'], user_message_id) self.assertEqual(delv['message_type'], 'event') self.assertEqual(delv['event_type'], 'delivery_report') self.assertEqual(delv['transport_name'], self.tx_helper.transport_name) self.assertEqual(delv['user_message_id'], user_message_id) self.assertEqual(delv['delivery_status'], self.expected_delivery_status) # Finally the Server delivers a SMS to the Client pdu = DeliverSM(555, short_message="SMS from server", destination_addr="2772222222", source_addr="2772000000") self.service.factory.smsc.send_pdu(pdu) for expected_message in expected_pdus_3: actual_message = yield pdu_queue.get() self.assert_server_pdu(expected_message, actual_message) [mess] = self.tx_helper.get_dispatched_inbound() self.assertEqual(mess['message_type'], 'user_message') self.assertEqual(mess['transport_name'], self.tx_helper.transport_name) self.assertEqual(mess['content'], "SMS from server") dispatched_failures = self.tx_helper.get_dispatched_failures() self.assertEqual(dispatched_failures, []) def send_out_of_order_multipart(self, smsc, to_addr, from_addr): destination_addr = to_addr source_addr = from_addr sequence_number = 1 short_message1 = "\x05\x00\x03\xff\x03\x01back" pdu1 = DeliverSM(sequence_number, short_message=short_message1, destination_addr=destination_addr, source_addr=source_addr) sequence_number = 2 short_message2 = "\x05\x00\x03\xff\x03\x02 at" 
def send_out_of_order_multipart(self, smsc, to_addr, from_addr): destination_addr = to_addr source_addr = from_addr sequence_number = 1 short_message1 = "\x05\x00\x03\xff\x03\x01back" pdu1 = DeliverSM(sequence_number, short_message=short_message1, destination_addr=destination_addr, source_addr=source_addr) sequence_number = 2 short_message2 = "\x05\x00\x03\xff\x03\x02 at" pdu2 = DeliverSM(sequence_number, short_message=short_message2, destination_addr=destination_addr, source_addr=source_addr) sequence_number = 3 short_message3 = "\x05\x00\x03\xff\x03\x03 you" pdu3 = DeliverSM(sequence_number, short_message=short_message3, destination_addr=destination_addr, source_addr=source_addr) smsc.send_pdu(pdu2) smsc.send_pdu(pdu3) smsc.send_pdu(pdu1) @inlineCallbacks def test_submit_and_deliver(self): # Startup yield self.startTransport() yield self.transport._block_till_bind # Next the Client submits a SMS to the Server # and receives an ack and a delivery_report msg = yield self.tx_helper.make_dispatch_outbound("hello world") # We need the user_message_id to check the ack user_message_id = msg["message_id"] [ack, delv] = yield self.tx_helper.wait_for_dispatched_events(2) self.assertEqual(ack['message_type'], 'event') self.assertEqual(ack['event_type'], 'ack') self.assertEqual(ack['transport_name'], self.tx_helper.transport_name) self.assertEqual(ack['user_message_id'], user_message_id) self.assertEqual(delv['message_type'], 'event') self.assertEqual(delv['event_type'], 'delivery_report') self.assertEqual(delv['transport_name'], self.tx_helper.transport_name) self.assertEqual(delv['user_message_id'], user_message_id) self.assertEqual(delv['delivery_status'], self.expected_delivery_status) # Finally the Server delivers a SMS to the Client pdu = DeliverSM(555, short_message="SMS from server", destination_addr="2772222222", source_addr="2772000000") self.service.factory.smsc.send_pdu(pdu) # Have the server fire off an out-of-order multipart sms self.send_out_of_order_multipart(self.service.factory.smsc, to_addr="2772222222", from_addr="2772000000") [mess, multipart] = yield self.tx_helper.wait_for_dispatched_inbound(2) self.assertEqual(mess['message_type'], 'user_message') self.assertEqual(mess['transport_name'], self.tx_helper.transport_name) self.assertEqual(mess['content'], "SMS from server") # Check the incoming multipart is re-assembled correctly self.assertEqual(multipart['message_type'], 'user_message') self.assertEqual( multipart['transport_name'], self.tx_helper.transport_name) self.assertEqual(multipart['content'], "back at you") dispatched_failures = self.tx_helper.get_dispatched_failures() self.assertEqual(dispatched_failures, []) @inlineCallbacks def test_submit_sm_encoding(self): # Startup yield self.startTransport() self.transport.submit_sm_encoding = 'latin-1' yield self.transport._block_till_bind yield self.clear_link_pdus() yield self.tx_helper.make_dispatch_outbound(u'Zoë destroyer of Ascii!') pdu_queue = self.service.factory.smsc.pdu_queue submit_sm_pdu = yield pdu_queue.get() sms = submit_sm_pdu['pdu']['body']['mandatory_parameters'] self.assertEqual( sms['short_message'], u'Zoë destroyer of Ascii!'.encode('latin-1')) # clear ack and nack yield self.tx_helper.wait_for_dispatched_events(2) @inlineCallbacks def test_submit_sm_data_coding(self): # Startup yield self.startTransport() self.transport.submit_sm_data_coding = 8 yield self.transport._block_till_bind yield self.clear_link_pdus() yield self.tx_helper.make_dispatch_outbound("hello world") pdu_queue = self.service.factory.smsc.pdu_queue submit_sm_pdu = yield pdu_queue.get() sms = submit_sm_pdu['pdu']['body']['mandatory_parameters'] self.assertEqual(sms['data_coding'], 8) # clear ack and nack yield self.tx_helper.wait_for_dispatched_events(2) @inlineCallbacks def test_submit_and_deliver_ussd_continue(self): # Startup yield self.startTransport() yield self.transport._block_till_bind yield self.clear_link_pdus()
# Next the Client submits a USSD message to the Server # and receives an ack msg = yield self.tx_helper.make_dispatch_outbound( "hello world", transport_type="ussd") # First we make sure the Client binds to the Server # and enquire_link pdu's are exchanged as expected pdu_queue = self.service.factory.smsc.pdu_queue submit_sm_pdu = yield pdu_queue.get() self.assert_server_pdu( mk_expected_pdu('inbound', 3, 'submit_sm'), submit_sm_pdu) pdu_opts = unpacked_pdu_opts(submit_sm_pdu['pdu']) self.assertEqual('02', pdu_opts['ussd_service_op']) self.assertEqual('0000', pdu_opts['its_session_info']) # We need the user_message_id to check the ack user_message_id = msg.payload["message_id"] [ack, delv] = yield self.tx_helper.wait_for_dispatched_events(2) self.assertEqual(ack['message_type'], 'event') self.assertEqual(ack['event_type'], 'ack') self.assertEqual(ack['transport_name'], self.tx_helper.transport_name) self.assertEqual(ack['user_message_id'], user_message_id) self.assertEqual(delv['message_type'], 'event') self.assertEqual(delv['event_type'], 'delivery_report') self.assertEqual(delv['transport_name'], self.tx_helper.transport_name) self.assertEqual(delv['user_message_id'], user_message_id) self.assertEqual(delv['delivery_status'], self.expected_delivery_status) # Finally the Server delivers a USSD message to the Client pdu = DeliverSM(555, short_message="reply!", destination_addr="2772222222", source_addr="2772000000") pdu._PDU__add_optional_parameter('ussd_service_op', '02') pdu._PDU__add_optional_parameter('its_session_info', '0000') self.service.factory.smsc.send_pdu(pdu) [mess] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(mess['message_type'], 'user_message') self.assertEqual(mess['transport_name'], self.tx_helper.transport_name) self.assertEqual(mess['content'], "reply!") self.assertEqual(mess['transport_type'], "ussd") self.assertEqual(mess['session_event'], TransportUserMessage.SESSION_RESUME) self.assertEqual([], self.tx_helper.get_dispatched_failures()) @inlineCallbacks def test_submit_and_deliver_ussd_close(self): # Startup yield self.startTransport() yield self.transport._block_till_bind yield self.clear_link_pdus() # Next the Client submits a USSD message to the Server # and receives an ack msg = yield self.tx_helper.make_dispatch_outbound( "hello world", transport_type="ussd", session_event=TransportUserMessage.SESSION_CLOSE) # First we make sure the Client binds to the Server # and enquire_link pdu's are exchanged as expected pdu_queue = self.service.factory.smsc.pdu_queue submit_sm_pdu = yield pdu_queue.get() self.assert_server_pdu( mk_expected_pdu('inbound', 3, 'submit_sm'), submit_sm_pdu) pdu_opts = unpacked_pdu_opts(submit_sm_pdu['pdu']) self.assertEqual('02', pdu_opts['ussd_service_op']) self.assertEqual('0001', pdu_opts['its_session_info']) # We need the user_message_id to check the ack user_message_id = msg.payload["message_id"] [ack, delv] = yield self.tx_helper.wait_for_dispatched_events(2) self.assertEqual(ack['message_type'], 'event') self.assertEqual(ack['event_type'], 'ack') self.assertEqual(ack['transport_name'], self.tx_helper.transport_name) self.assertEqual(ack['user_message_id'], user_message_id) self.assertEqual(delv['message_type'], 'event') self.assertEqual(delv['event_type'], 'delivery_report') self.assertEqual(delv['transport_name'], self.tx_helper.transport_name) self.assertEqual(delv['user_message_id'], user_message_id) self.assertEqual(delv['delivery_status'], self.expected_delivery_status)
# Finally the Server delivers a USSD message to the Client pdu = DeliverSM(555, short_message="reply!", destination_addr="2772222222", source_addr="2772000000") pdu._PDU__add_optional_parameter('ussd_service_op', '02') pdu._PDU__add_optional_parameter('its_session_info', '0001') self.service.factory.smsc.send_pdu(pdu) [mess] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(mess['message_type'], 'user_message') self.assertEqual(mess['transport_name'], self.tx_helper.transport_name) self.assertEqual(mess['content'], "reply!") self.assertEqual(mess['transport_type'], "ussd") self.assertEqual(mess['session_event'], TransportUserMessage.SESSION_CLOSE) self.assertEqual([], self.tx_helper.get_dispatched_failures()) @inlineCallbacks def test_submit_and_deliver_with_missing_id_lookup(self): def r_failing_get(third_party_id): return succeed(None) self.transport.r_get_id_for_third_party_id = r_failing_get # Startup yield self.startTransport() yield self.transport._block_till_bind # Next the Client submits a SMS to the Server # and receives an ack and a delivery_report lc = LogCatcher(message="Failed to retrieve message id") with lc: msg = yield self.tx_helper.make_dispatch_outbound("hello world") # We need the user_message_id to check the ack user_message_id = msg["message_id"] [ack] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(ack['message_type'], 'event') self.assertEqual(ack['event_type'], 'ack') self.assertEqual(ack['transport_name'], self.tx_helper.transport_name) self.assertEqual(ack['user_message_id'], user_message_id) # check that failure to send delivery report was logged [warning] = lc.logs expected_msg = ( "Failed to retrieve message id for delivery report. Delivery" " report from %s discarded.") % (self.tx_helper.transport_name,) self.assertEqual(warning['message'], (expected_msg,)) class TestDeliveryYo(EsmeToSmscTestCase): # This tests a slightly non-standard delivery report format for Yo! # the following delivery_report_regex is required as a config option # "id:(?P<id>\S{,65}) +sub:(?P<sub>.{1,3}) +dlvrd:(?P<dlvrd>.{1,3})" # " +submit date:(?P<submit_date>\d*) +done date:(?P<done_date>\d*)" # " +stat:(?P<stat>[0-9,A-Z]{1,7}) +err:(?P<err>.{1,3})" # " +[Tt]ext:(?P<text>.{,20}).*" DELIVERY_REPORT_REGEX = ( "id:(?P<id>\S{,65})" " +sub:(?P<sub>.{1,3})" " +dlvrd:(?P<dlvrd>.{1,3})" " +submit date:(?P<submit_date>\d*)" " +done date:(?P<done_date>\d*)" " +stat:(?P<stat>[0-9,A-Z]{1,7})" " +err:(?P<err>.{1,3})" " +[Tt]ext:(?P<text>.{,20}).*") CONFIG_OVERRIDE = { "delivery_report_regex": DELIVERY_REPORT_REGEX, "smsc_delivery_report_string": ( 'id:%s sub:1 dlvrd:1 submit date:%s done date:%s ' 'stat:0 err:0 text:If a general electio'), }
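# An illustrative check of the Yo!-style regex above (the id and dates are
# hypothetical sample values; `text` is capped at 20 characters, which is
# why the configured sample string ends mid-word):
import re
_sample_dr = ('id:1b1720be-5f48 sub:1 dlvrd:1 submit date:120726132548'
              ' done date:120726132548 stat:0 err:0 text:If a general electio')
_match = re.match(TestDeliveryYo.DELIVERY_REPORT_REGEX, _sample_dr)
assert _match.group('id') == '1b1720be-5f48'
assert _match.group('stat') == '0'
assert _match.group('text') == 'If a general electio'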
class TestDeliveryOverrideMapping(EsmeToSmscTestCase): # This tests a non-standard delivery report status mapping. CONFIG_OVERRIDE = { "delivery_report_regex": "id:(?P<id>\S+) stat:(?P<stat>\S+) .*", "delivery_report_status_mapping": {"foo": "delivered"}, "smsc_delivery_report_string": ( 'id:%s stat:foo submit date:%s done date:%s'), } class TestEsmeToSmscTx(VumiTestCase): def assert_pdu_header(self, expected, actual, field): self.assertEqual(expected['pdu']['header'][field], actual['pdu']['header'][field]) def assert_server_pdu(self, expected, actual): self.assertEqual(expected['direction'], actual['direction']) self.assert_pdu_header(expected, actual, 'sequence_number') self.assert_pdu_header(expected, actual, 'command_status') self.assert_pdu_header(expected, actual, 'command_id') @inlineCallbacks def setUp(self): self.tx_helper = self.add_helper(TransportHelper(MockSmppTxTransport)) self.config = { "transport_name": self.tx_helper.transport_name, "system_id": "VumiTestSMSC", "password": "password", "host": "127.0.0.1", "transport_type": "smpp", } server_config = self.config.copy() server_config["twisted_endpoint"] = "tcp:0" self.service = SmppService(None, config=server_config) self.add_cleanup(self.cleanup_service) yield self.service.startWorker() self.service.factory.protocol = SmscTestServer self.config['port'] = self.service.listening.getHost().port self.transport = yield self.tx_helper.get_transport( self.config, start=False) self.expected_delivery_status = 'delivered' @inlineCallbacks def cleanup_service(self): yield self.service.listening.stopListening() yield self.service.listening.loseConnection() @inlineCallbacks def startTransport(self): self.transport._block_till_bind = Deferred() yield self.transport.startWorker() @inlineCallbacks def test_submit(self): # Startup yield self.startTransport() yield self.transport._block_till_bind # Next the Client submits a SMS to the Server # and receives an ack msg = yield self.tx_helper.make_dispatch_outbound("hello world") # We need the user_message_id to check the ack user_message_id = msg["message_id"] [ack] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(ack['message_type'], 'event') self.assertEqual(ack['event_type'], 'ack') self.assertEqual(ack['user_message_id'], user_message_id) dispatched_failures = self.tx_helper.get_dispatched_failures() self.assertEqual(dispatched_failures, []) class TestEsmeToSmscRx(VumiTestCase): def assert_pdu_header(self, expected, actual, field): self.assertEqual(expected['pdu']['header'][field], actual['pdu']['header'][field]) def assert_server_pdu(self, expected, actual): self.assertEqual(expected['direction'], actual['direction']) self.assert_pdu_header(expected, actual, 'sequence_number') self.assert_pdu_header(expected, actual, 'command_status') self.assert_pdu_header(expected, actual, 'command_id') @inlineCallbacks def setUp(self): from twisted.internet.base import DelayedCall DelayedCall.debug = True self.tx_helper = self.add_helper(TransportHelper(MockSmppRxTransport)) self.config = { "transport_name": self.tx_helper.transport_name, "system_id": "VumiTestSMSC", "password": "password", "host": "127.0.0.1", "transport_type": "smpp", } server_config = self.config.copy() server_config['twisted_endpoint'] = "tcp:0" self.service = SmppService(None, config=server_config) self.add_cleanup(self.cleanup_service) yield self.service.startWorker() self.service.factory.protocol = SmscTestServer self.config['port'] = self.service.listening.getHost().port self.transport = yield self.tx_helper.get_transport( self.config, start=False) self.expected_delivery_status = 'delivered'
@inlineCallbacks def startTransport(self): self.transport._block_till_bind = Deferred() yield self.transport.startWorker() @inlineCallbacks def cleanup_service(self): yield self.service.listening.stopListening() yield self.service.listening.loseConnection() @inlineCallbacks def test_deliver(self): # Startup yield self.startTransport() yield self.transport._block_till_bind # The Server delivers a SMS to the Client pdu = DeliverSM(555, short_message="SMS from server", destination_addr="2772222222", source_addr="2772000000") self.service.factory.smsc.send_pdu(pdu) [mess] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(mess['message_type'], 'user_message') self.assertEqual(mess['transport_name'], self.tx_helper.transport_name) self.assertEqual(mess['content'], "SMS from server") dispatched_failures = self.tx_helper.get_dispatched_failures() self.assertEqual(dispatched_failures, []) @inlineCallbacks def test_deliver_bad_encoding(self): # Startup yield self.startTransport() yield self.transport._block_till_bind # The Server delivers a SMS to the Client bad_pdu = DeliverSM(555, short_message="SMS from server containing \xa7", destination_addr="2772222222", source_addr="2772000000") good_pdu = DeliverSM(555, short_message="Next message", destination_addr="2772222222", source_addr="2772000000") self.service.factory.smsc.send_pdu(bad_pdu) self.service.factory.smsc.send_pdu(good_pdu) [mess] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(mess['message_type'], 'user_message') self.assertEqual(mess['transport_name'], self.tx_helper.transport_name) self.assertEqual(mess['content'], "Next message") dispatched_failures = self.tx_helper.get_dispatched_failures() self.assertEqual(dispatched_failures, []) [failure] = self.flushLoggedErrors(UnicodeDecodeError) message = failure.getErrorMessage() codec, rest = message.split(' ', 1) self.assertTrue(codec in ("'utf8'", "'utf-8'")) self.assertTrue(rest.startswith( "codec can't decode byte 0xa7 in position 27")) @inlineCallbacks def test_deliver_ussd_start(self): # Startup yield self.startTransport() yield self.transport._block_till_bind # The Server delivers a SMS to the Client pdu = DeliverSM( 555, destination_addr="2772222222", source_addr="2772000000") pdu._PDU__add_optional_parameter('ussd_service_op', '01') pdu._PDU__add_optional_parameter('its_session_info', '0000') self.service.factory.smsc.send_pdu(pdu) [mess] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(mess['transport_type'], 'ussd') self.assertEqual(mess['transport_name'], self.tx_helper.transport_name) self.assertEqual(mess['content'], None) self.assertEqual(mess['session_event'], TransportUserMessage.SESSION_NEW) dispatched_failures = self.tx_helper.get_dispatched_failures() self.assertEqual(dispatched_failures, []) PK=JG1vumi/transports/smpp/deprecated/tests/__init__.py
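# A minimal sketch of the PDU round-trip the tests above rely on, assuming
# python-smpp's pdu_builder/unpack_pdu behaviour as exercised there (module
# paths per that library; treat this as illustrative):
from smpp.pdu_builder import DeliverSM
from smpp.pdu import unpack_pdu
from vumi.transports.smpp.smpp_utils import unpacked_pdu_opts

pdu = DeliverSM(1, short_message='')
pdu._PDU__add_optional_parameter('ussd_service_op', '01')   # PSSR: new session
pdu._PDU__add_optional_parameter('its_session_info', '0000')
unpacked = unpack_pdu(pdu.get_bin())
assert unpacked['header']['command_id'] == 'deliver_sm'
assert unpacked_pdu_opts(unpacked)['ussd_service_op'] == '01'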
PK=H 'RR*vumi/transports/smpp/processors/default.pyimport json from smpp.pdu_inspector import ( detect_multipart, multipart_key, MultipartMessage) from twisted.internet.defer import inlineCallbacks, returnValue, succeed from zope.interface import implements from vumi.config import ( Config, ConfigDict, ConfigRegex, ConfigText, ConfigInt, ConfigBool) from vumi.message import TransportUserMessage from vumi.transports.smpp.iprocessors import ( IDeliveryReportProcessor, IDeliverShortMessageProcessor, ISubmitShortMessageProcessor) from vumi.transports.smpp.smpp_utils import unpacked_pdu_opts, detect_ussd class DeliveryReportProcessorConfig(Config): DELIVERY_REPORT_REGEX = ( 'id:(?P<id>[^ ]{,65})' '(?: +sub:(?P<sub>[^ ]+))?' '(?: +dlvrd:(?P<dlvrd>[^ ]+))?' '(?: +submit date:(?P<submit_date>\d*))?' '(?: +done date:(?P<done_date>\d*))?' ' +stat:(?P<stat>[A-Z]{5,7})' '(?: +err:(?P<err>[^ ]+))?' ' +[Tt]ext:(?P<text>.{,20})' '.*' ) DELIVERY_REPORT_STATUS_MAPPING = { # Output values should map to themselves: 'delivered': 'delivered', 'failed': 'failed', 'pending': 'pending', # SMPP `message_state` values: 'ENROUTE': 'pending', 'DELIVERED': 'delivered', 'EXPIRED': 'failed', 'DELETED': 'failed', 'UNDELIVERABLE': 'failed', 'ACCEPTED': 'delivered', 'UNKNOWN': 'pending', 'REJECTED': 'failed', # From the most common regex-extracted format: 'DELIVRD': 'delivered', 'REJECTD': 'failed', 'FAILED': 'failed', # Currently we will accept this for Yo! TODO: investigate '0': 'delivered', } delivery_report_regex = ConfigRegex( 'Regex to use for matching delivery reports', default=DELIVERY_REPORT_REGEX, static=True) delivery_report_status_mapping = ConfigDict( "Mapping from delivery report message state to" " (`delivered`, `failed`, `pending`)", default=DELIVERY_REPORT_STATUS_MAPPING, static=True) delivery_report_use_esm_class = ConfigBool( "Use `esm_class` PDU parameter to determine whether a message is a" " delivery report.", default=True, static=True) class DeliveryReportProcessor(object): implements(IDeliveryReportProcessor) CONFIG_CLASS = DeliveryReportProcessorConfig STATUS_MAP = { 1: 'ENROUTE', 2: 'DELIVERED', 3: 'EXPIRED', 4: 'DELETED', 5: 'UNDELIVERABLE', 6: 'ACCEPTED', 7: 'UNKNOWN', 8: 'REJECTED', } ESM_CLASS_MASK = 0b00111100 # If any of bits 5-2 are set, assume DR. def __init__(self, transport, config): self.transport = transport self.log = transport.log self.config = self.CONFIG_CLASS(config, static=True) def _handle_delivery_report_optional_params(self, pdu): """ Check if this might be a delivery report with optional PDU parameters. If so, handle it and return a deferred ``True``, otherwise return a deferred ``False``. """ pdu_opts = unpacked_pdu_opts(pdu) receipted_message_id = pdu_opts.get('receipted_message_id', None) message_state = pdu_opts.get('message_state', None) if receipted_message_id is None or message_state is None: return succeed(False) status = self.STATUS_MAP.get(message_state, 'UNKNOWN') d = self.transport.handle_delivery_report( receipted_message_id=receipted_message_id, delivery_status=self.delivery_status(status), smpp_delivery_status=status) d.addCallback(lambda _: True) return d
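# An illustrative aside on how the pieces above combine: the config regex
# pulls the named fields out of receipt text, and the status mapping turns
# the extracted state into a vumi delivery status (the sample text below is
# hypothetical, modelled on the standard receipt format):
import re
_rx = re.compile(DeliveryReportProcessorConfig.DELIVERY_REPORT_REGEX)
_fields = _rx.search('id:abc123 sub:001 dlvrd:001 submit date:120726132548'
                     ' done date:120726132548 stat:DELIVRD err:000'
                     ' text:').groupdict()
assert _fields['id'] == 'abc123'
assert (DeliveryReportProcessorConfig.DELIVERY_REPORT_STATUS_MAPPING[
    _fields['stat']] == 'delivered')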
def _process_delivery_report_content_fields(self, content_fields): """ Construct and dispatch a delivery report based on content fields as matched by our regex. """ receipted_message_id = content_fields['id'] message_state = content_fields['stat'] return self.transport.handle_delivery_report( receipted_message_id=receipted_message_id, delivery_status=self.delivery_status(message_state), smpp_delivery_status=message_state) def _handle_delivery_report_esm_class(self, pdu): """ Check if the ``esm_class`` indicates that this is a delivery report. If so, handle it and return a deferred ``True``, otherwise return a deferred ``False``. NOTE: We assume the message content is a string that matches our regex. We can't use the usual decoding process here because it lives elsewhere and the content should be plain ASCII generated by the SMSC anyway. """ if not self.config.delivery_report_use_esm_class: # We're not configured to check the ``esm_class``, so do nothing. return succeed(False) esm_class = pdu["body"]["mandatory_parameters"]["esm_class"] if not (esm_class & self.ESM_CLASS_MASK): # Delivery report flags in esm_class are not set. return succeed(False) content = pdu["body"]["mandatory_parameters"]["short_message"] match = self.config.delivery_report_regex.search(content or '') if not match: self.log.warning( ("esm_class %s indicates delivery report, but content" " does not match regex: %r") % (esm_class, content)) # Even though this doesn't match the regex, the esm_class indicates # that it's a DR and we therefore don't want to treat it as a # normal message. return succeed(True) fields = match.groupdict() d = self._process_delivery_report_content_fields(fields) d.addCallback(lambda _: True) return d @inlineCallbacks def handle_delivery_report_pdu(self, pdu): """ Check PDU optional params and ``esm_class`` to detect and handle delivery reports. Return a deferred ``True`` if a delivery report was detected and handled, otherwise return a deferred ``False``. In the latter case, the content may be examined in ``handle_delivery_report_content`` later. """ # Check for optional params indicating a DR. if (yield self._handle_delivery_report_optional_params(pdu)): returnValue(True) if (yield self._handle_delivery_report_esm_class(pdu)): returnValue(True) returnValue(False) def handle_delivery_report_content(self, content): """ Check the content against our delivery report regex and treat it as a delivery report if it matches. If we are configured to check the PDU ``esm_class``, we skip this check because any delivery reports will already have been handled by ``handle_delivery_report_pdu``. """ if self.config.delivery_report_use_esm_class: # We're configured to check ``esm_class``, so we don't check # content here. return succeed(False) match = self.config.delivery_report_regex.search(content or '') if not match: return succeed(False) # We have a delivery report. d = self._process_delivery_report_content_fields(match.groupdict()) d.addCallback(lambda _: True) return d def delivery_status(self, state): return self.config.delivery_report_status_mapping.get(state, 'pending') class DeliverShortMessageProcessorConfig(Config): data_coding_overrides = ConfigDict( "Overrides for data_coding character set mapping. This is useful for " "setting the default encoding (0), adding additional undefined " "encodings (such as 4 or 8) or overriding encodings in cases where " "the SMSC is violating the spec (which happens a lot). Keys should " "be integers, values should be strings containing valid Python " "character encoding names.", default={}, static=True) allow_empty_messages = ConfigBool( "If True, send on empty messages as an empty unicode string. " "If False, reject empty messages as invalid.", default=False, static=True) class DeliverShortMessageProcessor(object): """ Messages can arrive with one of a number of specified encodings. We only handle a subset of these. From the SMPP spec: 00000000 (0) SMSC Default Alphabet 00000001 (1) IA5(CCITT T.50)/ASCII(ANSI X3.4) 00000010 (2) Octet unspecified (8-bit binary) 00000011 (3) Latin1(ISO-8859-1) 00000100 (4) Octet unspecified (8-bit binary) 00000101 (5) JIS(X0208-1990) 00000110 (6) Cyrillic(ISO-8859-5) 00000111 (7) Latin/Hebrew (ISO-8859-8) 00001000 (8) UCS2(ISO/IEC-10646) 00001001 (9) PictogramEncoding 00001010 (10) ISO-2022-JP(MusicCodes) 00001011 (11) reserved 00001100 (12) reserved 00001101 (13) Extended Kanji JIS(X 0212-1990) 00001110 (14) KSC5601 00001111 (15) reserved Particularly problematic are the "Octet unspecified" encodings.
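For example (an illustrative config fragment): an SMSC that actually sends
Latin-1 with data_coding=0 could be accommodated via::

    data_coding_overrides:
        0: latin1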
""" implements(IDeliverShortMessageProcessor) CONFIG_CLASS = DeliverShortMessageProcessorConfig def __init__(self, transport, config): self.transport = transport self.log = transport.log self.redis = transport.redis self.codec = transport.get_static_config().codec_class() self.config = self.CONFIG_CLASS(config, static=True) self.data_coding_map = { 1: 'ascii', 3: 'latin1', # http://www.herongyang.com/Unicode/JIS-ISO-2022-JP-Encoding.html 5: 'iso2022_jp', 6: 'iso8859_5', 7: 'iso8859_8', # Actually UCS-2, but close enough. 8: 'utf-16be', # http://en.wikipedia.org/wiki/Short_Message_Peer-to-Peer 9: 'shift_jis', 10: 'iso2022_jp' } self.data_coding_map.update(self.config.data_coding_overrides) self.allow_empty_messages = self.config.allow_empty_messages def dcs_decode(self, obj, data_coding): codec_name = self.data_coding_map.get(data_coding, None) if codec_name is None: self.log.msg( "WARNING: Not decoding message with data_coding=%s" % ( data_coding,)) return obj elif obj is None: if self.allow_empty_messages: return u'' self.log.msg( "WARNING: Not decoding `None` message with data_coding=%s" % ( data_coding,)) return obj try: return self.codec.decode(obj, codec_name) except UnicodeDecodeError, e: self.log.msg("Error decoding message with data_coding=%s" % ( data_coding,)) self.log.err(e) return obj def decode_pdus(self, pdus): content = [] for pdu in pdus: pdu_params = pdu['body']['mandatory_parameters'] pdu_opts = unpacked_pdu_opts(pdu) # We might have a `message_payload` optional field to worry about. message_payload = pdu_opts.get('message_payload', None) if message_payload is not None: short_message = message_payload.decode('hex') else: short_message = pdu_params['short_message'] content.append( self.dcs_decode(short_message, pdu_params['data_coding'])) return content def handle_short_message_content(self, source_addr, destination_addr, short_message, **kw): return self.transport.handle_raw_inbound_message( source_addr=source_addr, destination_addr=destination_addr, short_message=short_message, **kw) def handle_short_message_pdu(self, pdu): pdu_params = pdu['body']['mandatory_parameters'] content_parts = self.decode_pdus([pdu]) if content_parts is not None: content = u''.join(content_parts) else: content = None d = self.handle_short_message_content( source_addr=pdu_params['source_addr'], destination_addr=pdu_params['destination_addr'], short_message=content) d.addCallback(lambda _: True) return d def handle_multipart_pdu(self, pdu): if not detect_multipart(pdu): return succeed(False) # We have a multipart SMS. pdu_params = pdu['body']['mandatory_parameters'] d = self.handle_deliver_sm_multipart(pdu, pdu_params) d.addCallback(lambda _: True) return d @inlineCallbacks def handle_deliver_sm_multipart(self, pdu, pdu_params): redis_key = "multi_%s" % (multipart_key(detect_multipart(pdu)),) self.log.debug("Redis multipart key: %s" % (redis_key)) multi = yield self.load_multipart_message(redis_key) multi.add_pdu(pdu) completed = multi.get_completed() if completed: yield self.redis.delete(redis_key) self.log.msg("Reassembled Message: %s" % (completed['message'])) # We assume that all parts have the same data_coding here, because # otherwise there's nothing sensible we can do. 
def handle_short_message_content(self, source_addr, destination_addr, short_message, **kw): return self.transport.handle_raw_inbound_message( source_addr=source_addr, destination_addr=destination_addr, short_message=short_message, **kw) def handle_short_message_pdu(self, pdu): pdu_params = pdu['body']['mandatory_parameters'] content_parts = self.decode_pdus([pdu]) if content_parts is not None: content = u''.join(content_parts) else: content = None d = self.handle_short_message_content( source_addr=pdu_params['source_addr'], destination_addr=pdu_params['destination_addr'], short_message=content) d.addCallback(lambda _: True) return d def handle_multipart_pdu(self, pdu): if not detect_multipart(pdu): return succeed(False) # We have a multipart SMS. pdu_params = pdu['body']['mandatory_parameters'] d = self.handle_deliver_sm_multipart(pdu, pdu_params) d.addCallback(lambda _: True) return d @inlineCallbacks def handle_deliver_sm_multipart(self, pdu, pdu_params): redis_key = "multi_%s" % (multipart_key(detect_multipart(pdu)),) self.log.debug("Redis multipart key: %s" % (redis_key)) multi = yield self.load_multipart_message(redis_key) multi.add_pdu(pdu) completed = multi.get_completed() if completed: yield self.redis.delete(redis_key) self.log.msg("Reassembled Message: %s" % (completed['message'])) # We assume that all parts have the same data_coding here, because # otherwise there's nothing sensible we can do. decoded_msg = self.dcs_decode(completed['message'], pdu_params['data_coding']) # and we can finally pass the whole message on yield self.handle_short_message_content( source_addr=completed['from_msisdn'], destination_addr=completed['to_msisdn'], short_message=decoded_msg) else: yield self.save_multipart_message(redis_key, multi) def handle_ussd_pdu(self, pdu): pdu_params = pdu['body']['mandatory_parameters'] pdu_opts = unpacked_pdu_opts(pdu) if not detect_ussd(pdu_opts): return succeed(False) # We have a USSD message. d = self.handle_deliver_sm_ussd(pdu, pdu_params, pdu_opts) d.addCallback(lambda _: True) return d def handle_deliver_sm_ussd(self, pdu, pdu_params, pdu_opts): # Some of this stuff might be specific to Tata's setup. service_op = pdu_opts['ussd_service_op'] session_event = 'close' if service_op == '01': # PSSR request. Let's assume it means a new session. session_event = 'new' elif service_op == '11': # PSSR response. This means session end. session_event = 'close' elif service_op in ('02', '12'): # USSR request or response. I *think* we only get the latter. session_event = 'continue' # According to the spec, the first octet is the session id and the # second is the client dialog id (first 7 bits) and end session flag # (last bit). # Since we don't use the client dialog id and the spec says it's # ESME-defined, treat the whole thing as opaque "session info" that # gets passed back in reply messages. its_session_number = int(pdu_opts['its_session_info'], 16) end_session = bool(its_session_number % 2) session_info = "%04x" % (its_session_number & 0xfffe) if end_session: # We have an explicit "end session" flag. session_event = 'close' decoded_msg = self.dcs_decode(pdu_params['short_message'], pdu_params['data_coding']) return self.handle_short_message_content( source_addr=pdu_params['source_addr'], destination_addr=pdu_params['destination_addr'], short_message=decoded_msg, message_type='ussd', session_event=session_event, session_info=session_info) def _hex_for_redis(self, data_dict): for index, part in data_dict.items(): part['part_message'] = part['part_message'].encode('hex') return data_dict def _unhex_from_redis(self, data_dict): for index, part in data_dict.items(): part['part_message'] = part['part_message'].decode('hex') return data_dict @inlineCallbacks def load_multipart_message(self, redis_key): value = yield self.redis.get(redis_key) value = json.loads(value) if value else {} self.log.debug("Retrieved value: %s" % (repr(value))) returnValue(MultipartMessage(self._unhex_from_redis(value))) def save_multipart_message(self, redis_key, multipart_message): data_dict = self._hex_for_redis(multipart_message.get_array()) return self.redis.set(redis_key, json.dumps(data_dict)) class SubmitShortMessageProcessorConfig(Config): submit_sm_encoding = ConfigText( 'How to encode the SMS before putting on the wire', static=True, default='utf-8') submit_sm_data_coding = ConfigInt( 'What data_coding value to tell the SMSC we\'re using when putting' ' an SMS on the wire', static=True, default=0) send_long_messages = ConfigBool( "If `True`, messages longer than 254 characters will be sent in the " "`message_payload` optional field instead of the `short_message` " "field. Default is `False`, simply because that maintains previous " "behaviour.", default=False, static=True)
Default is " "`False`.", default=False, static=True) send_multipart_udh = ConfigBool( "If `True`, messages longer than 140 bytes will be sent as a series " "of smaller messages with the user data headers. Default is `False`.", default=False, static=True) def post_validate(self): long_message_params = ( 'send_long_messages', 'send_multipart_sar', 'send_multipart_udh') set_params = [p for p in long_message_params if getattr(self, p)] if len(set_params) > 1: params = ', '.join(set_params) self.raise_config_error( "The following parameters are mutually exclusive: %s" % params) class SubmitShortMessageProcessor(object): implements(ISubmitShortMessageProcessor) CONFIG_CLASS = SubmitShortMessageProcessorConfig def __init__(self, transport, config): self.transport = transport self.log = transport.log self.config = self.CONFIG_CLASS(config, static=True) def handle_outbound_message(self, message, service): to_addr = message['to_addr'] from_addr = message['from_addr'] text = message['content'] if text is None: text = u"" vumi_message_id = message['message_id'] # TODO: this should probably be handled by a processor as these # USSD fields & params are TATA (India) specific session_event = message['session_event'] transport_type = message['transport_type'] optional_parameters = {} if transport_type == 'ussd': continue_session = ( session_event != TransportUserMessage.SESSION_CLOSE) session_info = message['transport_metadata'].get( 'session_info', '0000') optional_parameters.update({ 'ussd_service_op': '02', 'its_session_info': "%04x" % ( int(session_info, 16) + int(not continue_session)) }) return self.send_short_message( service, vumi_message_id, to_addr.encode('ascii'), text.encode(self.config.submit_sm_encoding), data_coding=self.config.submit_sm_data_coding, source_addr=from_addr.encode('ascii'), optional_parameters=optional_parameters) def send_short_message(self, service, vumi_message_id, destination_addr, content, data_coding=0, source_addr='', optional_parameters=None): """ Call the appropriate `submit_*` method depending on config. """ kwargs = dict( vumi_message_id=vumi_message_id, destination_addr=destination_addr, short_message=content, data_coding=data_coding, source_addr=source_addr, optional_parameters=optional_parameters) if self.config.send_long_messages: kwargs['long_message'] = kwargs.pop('short_message') return service.submit_sm_long(**kwargs) elif self.config.send_multipart_sar: return service.submit_csm_sar(**kwargs) elif self.config.send_multipart_udh: return service.submit_csm_udh(**kwargs) return service.submit_sm(**kwargs) PK=H)SS)vumi/transports/smpp/processors/sixdee.py# -*- test-case-name: vumi.transports.smpp.tests.test_sixdee -*- from vumi.config import ConfigInt, ConfigText from vumi.components.session import SessionManager from vumi.message import TransportUserMessage from vumi.transports.smpp.processors import default from twisted.internet.defer import inlineCallbacks, returnValue def make_vumi_session_identifier(msisdn, sixdee_session_identifier): return '%s+%s' % (msisdn, sixdee_session_identifier) class DeliverShortMessageProcessorConfig( default.DeliverShortMessageProcessorConfig): max_session_length = ConfigInt( 'Maximum length a USSD sessions data is to be kept for in seconds.', default=60 * 3, static=True) ussd_code_pdu_field = ConfigText( 'PDU field to read the message `to_addr` (USSD code) from. 
class DeliverShortMessageProcessor(default.DeliverShortMessageProcessor): CONFIG_CLASS = DeliverShortMessageProcessorConfig # NOTE: these keys are hexadecimal because of python-smpp encoding # quirkiness ussd_service_op_map = { '01': 'new', '12': 'continue', '81': 'close', # user abort } def __init__(self, transport, config): super(DeliverShortMessageProcessor, self).__init__(transport, config) self.transport = transport self.log = transport.log self.redis = transport.redis self.config = self.CONFIG_CLASS(config, static=True) self.session_manager = SessionManager( self.redis, max_session_length=self.config.max_session_length) @inlineCallbacks def handle_deliver_sm_ussd(self, pdu, pdu_params, pdu_opts): service_op = pdu_opts['ussd_service_op'] # 6D uses its_session_info as follows: # # * First 15 bit: dialog id (i.e. session id) # * Last bit: end session (1 to end, 0 to continue) its_session_number = int(pdu_opts['its_session_info'], 16) end_session = bool(its_session_number % 2) sixdee_session_identifier = "%04x" % (its_session_number & 0xfffe) vumi_session_identifier = make_vumi_session_identifier( pdu_params['source_addr'], sixdee_session_identifier) if end_session: session_event = 'close' else: session_event = self.ussd_service_op_map.get(service_op) if session_event == 'new': # PSSR request. Let's assume it means a new session. ussd_code = pdu_params[self.config.ussd_code_pdu_field] content = None yield self.session_manager.create_session( vumi_session_identifier, ussd_code=ussd_code) elif session_event == 'close': session = yield self.session_manager.load_session( vumi_session_identifier) ussd_code = session['ussd_code'] content = None yield self.session_manager.clear_session(vumi_session_identifier) else: if session_event != 'continue': self.log.warning(( 'Received unknown %r ussd_service_op, assuming continue.') % (service_op,)) session_event = 'continue' session = yield self.session_manager.load_session( vumi_session_identifier) ussd_code = session['ussd_code'] content = self.dcs_decode( pdu_params['short_message'], pdu_params['data_coding']) # This is stashed on the message and available when replying # with a `submit_sm` session_info = { 'ussd_service_op': service_op, 'session_identifier': sixdee_session_identifier, } result = yield self.handle_short_message_content( source_addr=pdu_params['source_addr'], destination_addr=ussd_code, short_message=content, message_type='ussd', session_event=session_event, session_info=session_info) returnValue(result)
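# Illustrative bit-twiddling for its_session_info as handled above and in
# handle_outbound_message below (0x038a is a hypothetical session value):
# the low bit of the 16-bit field flags "end of session", the remainder is
# the dialog id that must be echoed back.
session = int('038a', 16)
assert not bool(session % 2)                  # inbound: session continues
assert "%04x" % (session & 0xfffe) == '038a'  # dialog id kept for the reply
assert "%04x" % (session | 1) == '038b'       # reply that closes the session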
class SubmitShortMessageProcessorConfig( default.SubmitShortMessageProcessorConfig): max_session_length = ConfigInt( 'Maximum length a USSD session\'s data is to be kept for, in seconds.', default=60 * 3, static=True) class SubmitShortMessageProcessor(default.SubmitShortMessageProcessor): CONFIG_CLASS = SubmitShortMessageProcessorConfig # NOTE: these values are hexadecimal because of python-smpp encoding # quirkiness ussd_service_op_map = { 'continue': '02', 'close': '17', # end } def __init__(self, transport, config): super(SubmitShortMessageProcessor, self).__init__(transport, config) self.transport = transport self.redis = transport.redis self.config = self.CONFIG_CLASS(config, static=True) self.session_manager = SessionManager( self.redis, max_session_length=self.config.max_session_length) @inlineCallbacks def handle_outbound_message(self, message, service): to_addr = message['to_addr'] from_addr = message['from_addr'] text = message['content'] if text is None: text = u"" vumi_message_id = message['message_id'] session_event = message['session_event'] transport_type = message['transport_type'] optional_parameters = {} if transport_type == 'ussd': continue_session = ( session_event != TransportUserMessage.SESSION_CLOSE) session_info = message['transport_metadata'].get( 'session_info', {}) sixdee_session_identifier = session_info.get( 'session_identifier', '') vumi_session_identifier = make_vumi_session_identifier( to_addr, sixdee_session_identifier) its_session_info = ( int(sixdee_session_identifier, 16) | int(not continue_session)) service_op = self.ussd_service_op_map[('continue' if continue_session else 'close')] optional_parameters.update({ 'ussd_service_op': service_op, 'its_session_info': "%04x" % (its_session_info,) }) if not continue_session: yield self.session_manager.clear_session( vumi_session_identifier) resp = yield self.send_short_message( service, vumi_message_id, to_addr.encode('ascii'), text.encode(self.config.submit_sm_encoding), data_coding=self.config.submit_sm_data_coding, source_addr=from_addr.encode('ascii'), optional_parameters=optional_parameters) returnValue(resp) PKqGv\gg'vumi/transports/smpp/processors/mica.py# -*- test-case-name: vumi.transports.smpp.tests.test_mica -*- from vumi.config import ConfigInt from vumi.components.session import SessionManager from vumi.message import TransportUserMessage from vumi.transports.smpp.processors import default from twisted.internet.defer import inlineCallbacks, returnValue def make_vumi_session_identifier(msisdn, mica_session_identifier): return '%s+%s' % (msisdn, mica_session_identifier) class DeliverShortMessageProcessorConfig( default.DeliverShortMessageProcessorConfig): max_session_length = ConfigInt( 'Maximum length a USSD session\'s data is to be kept for, in seconds.', default=60 * 3, static=True) class DeliverShortMessageProcessor(default.DeliverShortMessageProcessor): CONFIG_CLASS = DeliverShortMessageProcessorConfig # NOTE: these keys are hexadecimal because of python-smpp encoding # quirkiness ussd_service_op_map = { '01': 'new', '12': 'continue', '81': 'close', # user abort } def __init__(self, transport, config): super(DeliverShortMessageProcessor, self).__init__(transport, config) self.transport = transport self.log = transport.log self.redis = transport.redis self.config = self.CONFIG_CLASS(config, static=True) self.session_manager = SessionManager( self.redis, max_session_length=self.config.max_session_length) @inlineCallbacks def handle_deliver_sm_ussd(self, pdu, pdu_params, pdu_opts): service_op = pdu_opts['ussd_service_op'] mica_session_identifier = pdu_opts['user_message_reference'] vumi_session_identifier = make_vumi_session_identifier( pdu_params['source_addr'], mica_session_identifier) session_event = self.ussd_service_op_map.get(service_op) if session_event == 'new': # PSSR request. Let's assume it means a new session.
vumi/transports/smpp/processors/mica.py

# -*- test-case-name: vumi.transports.smpp.tests.test_mica -*-

from vumi.config import ConfigInt
from vumi.components.session import SessionManager
from vumi.message import TransportUserMessage
from vumi.transports.smpp.processors import default

from twisted.internet.defer import inlineCallbacks, returnValue


def make_vumi_session_identifier(msisdn, mica_session_identifier):
    return '%s+%s' % (msisdn, mica_session_identifier)


class DeliverShortMessageProcessorConfig(
        default.DeliverShortMessageProcessorConfig):
    max_session_length = ConfigInt(
        'Maximum time in seconds for which USSD session data is kept.',
        default=60 * 3, static=True)


class DeliverShortMessageProcessor(default.DeliverShortMessageProcessor):
    CONFIG_CLASS = DeliverShortMessageProcessorConfig

    # NOTE: these keys are hexadecimal because of python-smpp encoding
    # quirkiness
    ussd_service_op_map = {
        '01': 'new',
        '12': 'continue',
        '81': 'close',  # user abort
    }

    def __init__(self, transport, config):
        super(DeliverShortMessageProcessor, self).__init__(transport, config)
        self.transport = transport
        self.log = transport.log
        self.redis = transport.redis
        self.config = self.CONFIG_CLASS(config, static=True)
        self.session_manager = SessionManager(
            self.redis,
            max_session_length=self.config.max_session_length)

    @inlineCallbacks
    def handle_deliver_sm_ussd(self, pdu, pdu_params, pdu_opts):
        service_op = pdu_opts['ussd_service_op']

        mica_session_identifier = pdu_opts['user_message_reference']
        vumi_session_identifier = make_vumi_session_identifier(
            pdu_params['source_addr'], mica_session_identifier)

        session_event = self.ussd_service_op_map.get(service_op)

        if session_event == 'new':
            # PSSR request. Let's assume it means a new session.
            ussd_code = pdu_params['short_message']
            content = None

            yield self.session_manager.create_session(
                vumi_session_identifier, ussd_code=ussd_code)

        elif session_event == 'close':
            session = yield self.session_manager.load_session(
                vumi_session_identifier)
            ussd_code = session['ussd_code']
            content = None

            yield self.session_manager.clear_session(vumi_session_identifier)

        else:
            if session_event != 'continue':
                self.log.warning(
                    'Received unknown %r ussd_service_op, assuming'
                    ' continue.' % (service_op,))
                session_event = 'continue'

            session = yield self.session_manager.load_session(
                vumi_session_identifier)
            ussd_code = session['ussd_code']
            content = self.dcs_decode(
                pdu_params['short_message'], pdu_params['data_coding'])

        # This is stashed on the message and available when replying
        # with a `submit_sm`
        session_info = {
            'ussd_service_op': service_op,
            'session_identifier': mica_session_identifier,
        }

        result = yield self.handle_short_message_content(
            source_addr=pdu_params['source_addr'],
            destination_addr=ussd_code,
            short_message=content,
            message_type='ussd',
            session_event=session_event,
            session_info=session_info)
        returnValue(result)


class SubmitShortMessageProcessorConfig(
        default.SubmitShortMessageProcessorConfig):
    max_session_length = ConfigInt(
        'Maximum time in seconds for which USSD session data is kept.',
        default=60 * 3, static=True)


class SubmitShortMessageProcessor(default.SubmitShortMessageProcessor):
    CONFIG_CLASS = SubmitShortMessageProcessorConfig

    # NOTE: these values are hexadecimal because of python-smpp encoding
    # quirkiness
    ussd_service_op_map = {
        'continue': '02',
        'close': '17',  # end
    }

    def __init__(self, transport, config):
        super(SubmitShortMessageProcessor, self).__init__(transport, config)
        self.transport = transport
        self.redis = transport.redis
        self.config = self.CONFIG_CLASS(config, static=True)
        self.session_manager = SessionManager(
            self.redis,
            max_session_length=self.config.max_session_length)

    @inlineCallbacks
    def handle_outbound_message(self, message, service):
        to_addr = message['to_addr']
        from_addr = message['from_addr']
        text = message['content']
        if text is None:
            text = u""
        vumi_message_id = message['message_id']
        session_event = message['session_event']
        transport_type = message['transport_type']
        optional_parameters = {}

        if transport_type == 'ussd':
            continue_session = (
                session_event != TransportUserMessage.SESSION_CLOSE)
            session_info = message['transport_metadata'].get(
                'session_info', {})
            mica_session_identifier = session_info.get(
                'session_identifier', '')
            vumi_session_identifier = make_vumi_session_identifier(
                to_addr, mica_session_identifier)
            service_op = self.ussd_service_op_map[
                'continue' if continue_session else 'close']

            optional_parameters.update({
                'ussd_service_op': service_op,
                'user_message_reference': (
                    str(mica_session_identifier).zfill(2)),
            })

            if not continue_session:
                yield self.session_manager.clear_session(
                    vumi_session_identifier)

        resp = yield self.send_short_message(
            service,
            vumi_message_id,
            to_addr.encode('ascii'),
            text.encode(self.config.submit_sm_encoding),
            data_coding=self.config.submit_sm_data_coding,
            source_addr=from_addr.encode('ascii'),
            optional_parameters=optional_parameters)
        returnValue(resp)
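# Illustrative note, not part of the original module: the Mica processors
# key Redis sessions on the MSISDN plus the network's user_message_reference,
# e.g. make_vumi_session_identifier('27761234567', 12345) yields
# '27761234567+12345'. On the outbound leg the same reference is echoed back
# as a TLV, zero-padded to at least two digits:
# str(12345).zfill(2) == '12345', str(7).zfill(2) == '07'.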
vumi/transports/smpp/processors/__init__.py

from vumi.transports.smpp.processors.default import (
    DeliveryReportProcessor, DeliveryReportProcessorConfig,
    DeliverShortMessageProcessor, DeliverShortMessageProcessorConfig,
    SubmitShortMessageProcessor, SubmitShortMessageProcessorConfig)

__all__ = [
    'DeliveryReportProcessor',
    'DeliveryReportProcessorConfig',
    'DeliverShortMessageProcessor',
    'DeliverShortMessageProcessorConfig',
    'SubmitShortMessageProcessor',
    'SubmitShortMessageProcessorConfig',
]

vumi/transports/smpp/processors/tests/test_mica.py

from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.internet.task import Clock

from smpp.pdu_builder import DeliverSM

from vumi.message import TransportUserMessage
from vumi.tests.helpers import VumiTestCase
from vumi.transports.tests.helpers import TransportHelper
from vumi.transports.smpp.pdu_utils import (
    command_id, seq_no, pdu_tlv, short_message)
from vumi.transports.smpp.smpp_transport import SmppTransceiverTransport
from vumi.transports.smpp.processors.mica import make_vumi_session_identifier
from vumi.transports.smpp.tests.fake_smsc import FakeSMSC


class MicaProcessorTestCase(VumiTestCase):

    def setUp(self):
        self.clock = Clock()
        self.fake_smsc = FakeSMSC()
        self.tx_helper = self.add_helper(
            TransportHelper(SmppTransceiverTransport))
        self.default_config = {
            'transport_name': self.tx_helper.transport_name,
            'twisted_endpoint': self.fake_smsc.endpoint,
            'deliver_short_message_processor': (
                'vumi.transports.smpp.processors.mica.'
                'DeliverShortMessageProcessor'),
            'submit_short_message_processor': (
                'vumi.transports.smpp.processors.mica.'
                'SubmitShortMessageProcessor'),
            'system_id': 'foo',
            'password': 'bar',
            'deliver_short_message_processor_config': {
                'data_coding_overrides': {
                    0: 'utf-8',
                }
            },
            'submit_short_message_processor_config': {
                'submit_sm_encoding': 'utf-16be',
                'submit_sm_data_coding': 8,
                'send_multipart_udh': True,
            }
        }

    @inlineCallbacks
    def get_transport(self, config={}, bind=True):
        cfg = self.default_config.copy()
        transport = yield self.tx_helper.get_transport(cfg, start=False)
        transport.clock = self.clock
        yield transport.startWorker()
        self.clock.advance(0)
        if bind:
            yield self.fake_smsc.bind()
        returnValue(transport)
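    # Illustrative note, not part of the original tests: the expected UDH
    # prefix built in assert_udh_parts below is a standard GSM concatenated
    # short message header:
    #
    #   '\x05'  UDH length (five octets follow)
    #   '\x00'  IEI 0x00: concatenated SM, 8-bit reference number
    #   '\x03'  IE data length (three octets follow)
    #   '\x03'  concatenation reference number
    #   '\x07'  total number of parts
    #   chr(i)  this part's 1-based sequence number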
Cups have be"), ("en used for thousands of years for the ...Reply 1 for more"), ], encoding='utf-16be') # utf-16be is close enough to UCS2 for pdu in pdus: self.assertTrue(len(short_message(pdu)) < 140) @inlineCallbacks def test_submit_and_deliver_ussd_new(self): session_identifier = 12345 yield self.get_transport() # Server delivers a USSD message to the Client pdu = DeliverSM(1, short_message="*123#") pdu.add_optional_parameter('ussd_service_op', '01') pdu.add_optional_parameter('user_message_reference', session_identifier) yield self.fake_smsc.handle_pdu(pdu) [mess] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(mess['content'], None) self.assertEqual(mess['to_addr'], '*123#') self.assertEqual(mess['transport_type'], "ussd") self.assertEqual(mess['session_event'], TransportUserMessage.SESSION_NEW) self.assertEqual( mess['transport_metadata'], { 'session_info': { 'session_identifier': 12345, 'ussd_service_op': '01', } }) @inlineCallbacks def test_deliver_sm_op_codes_new(self): session_identifier = 12345 yield self.get_transport() pdu = DeliverSM(1, short_message="*123#") pdu.add_optional_parameter('ussd_service_op', '01') pdu.add_optional_parameter('user_message_reference', session_identifier) yield self.fake_smsc.handle_pdu(pdu) [start] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(start['session_event'], TransportUserMessage.SESSION_NEW) @inlineCallbacks def test_deliver_sm_op_codes_resume(self): source_addr = 'msisdn' session_identifier = 12345 vumi_session_identifier = make_vumi_session_identifier( source_addr, session_identifier) transport = yield self.get_transport() deliver_sm_processor = transport.deliver_sm_processor session_manager = deliver_sm_processor.session_manager yield session_manager.create_session( vumi_session_identifier, ussd_code='*123#') pdu = DeliverSM(1, short_message="", source_addr=source_addr) pdu.add_optional_parameter('ussd_service_op', '12') pdu.add_optional_parameter('user_message_reference', session_identifier) yield self.fake_smsc.handle_pdu(pdu) [resume] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(resume['session_event'], TransportUserMessage.SESSION_RESUME) @inlineCallbacks def test_deliver_sm_op_codes_end(self): source_addr = 'msisdn' session_identifier = 12345 vumi_session_identifier = make_vumi_session_identifier( source_addr, session_identifier) transport = yield self.get_transport() deliver_sm_processor = transport.deliver_sm_processor session_manager = deliver_sm_processor.session_manager yield session_manager.create_session( vumi_session_identifier, ussd_code='*123#') pdu = DeliverSM(1, short_message="", source_addr=source_addr) pdu.add_optional_parameter('ussd_service_op', '81') pdu.add_optional_parameter('user_message_reference', session_identifier) yield self.fake_smsc.handle_pdu(pdu) [end] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(end['session_event'], TransportUserMessage.SESSION_CLOSE) @inlineCallbacks def test_deliver_sm_unknown_op_code(self): session_identifier = 12345 yield self.get_transport() pdu = DeliverSM(1, short_message="*123#") pdu.add_optional_parameter('ussd_service_op', '01') pdu.add_optional_parameter('user_message_reference', session_identifier) yield self.fake_smsc.handle_pdu(pdu) pdu = DeliverSM(1, short_message="*123#") pdu.add_optional_parameter('ussd_service_op', '99') pdu.add_optional_parameter('user_message_reference', session_identifier) yield self.fake_smsc.handle_pdu(pdu) [start, unknown] = yield 
    @inlineCallbacks
    def test_deliver_sm_unknown_op_code(self):
        session_identifier = 12345
        yield self.get_transport()
        pdu = DeliverSM(1, short_message="*123#")
        pdu.add_optional_parameter('ussd_service_op', '01')
        pdu.add_optional_parameter('user_message_reference',
                                   session_identifier)
        yield self.fake_smsc.handle_pdu(pdu)

        pdu = DeliverSM(1, short_message="*123#")
        pdu.add_optional_parameter('ussd_service_op', '99')
        pdu.add_optional_parameter('user_message_reference',
                                   session_identifier)
        yield self.fake_smsc.handle_pdu(pdu)

        [start, unknown] = yield self.tx_helper.wait_for_dispatched_inbound(1)
        self.assertEqual(unknown['session_event'],
                         TransportUserMessage.SESSION_RESUME)

    @inlineCallbacks
    def test_submit_sm_op_codes_resume(self):
        user_msisdn = 'msisdn'
        session_identifier = 12345
        yield self.get_transport()
        yield self.tx_helper.make_dispatch_outbound(
            "hello world", transport_type="ussd",
            session_event=TransportUserMessage.SESSION_RESUME,
            transport_metadata={
                'session_info': {
                    'session_identifier': session_identifier
                }
            },
            to_addr=user_msisdn)
        resume = yield self.fake_smsc.await_pdu()
        self.assertEqual(pdu_tlv(resume, 'ussd_service_op'), '02')

    @inlineCallbacks
    def test_submit_sm_op_codes_close(self):
        user_msisdn = 'msisdn'
        session_identifier = 12345
        yield self.get_transport()
        yield self.tx_helper.make_dispatch_outbound(
            "hello world", transport_type="ussd",
            session_event=TransportUserMessage.SESSION_CLOSE,
            transport_metadata={
                'session_info': {
                    'session_identifier': session_identifier
                }
            },
            to_addr=user_msisdn)
        close = yield self.fake_smsc.await_pdu()
        self.assertEqual(pdu_tlv(close, 'ussd_service_op'), '17')

    @inlineCallbacks
    def test_submit_and_deliver_ussd_continue(self):
        user_msisdn = 'msisdn'
        session_identifier = 12345
        vumi_session_identifier = make_vumi_session_identifier(
            user_msisdn, session_identifier)

        transport = yield self.get_transport()
        deliver_sm_processor = transport.deliver_sm_processor
        session_manager = deliver_sm_processor.session_manager
        yield session_manager.create_session(
            vumi_session_identifier, ussd_code='*123#')

        yield self.tx_helper.make_dispatch_outbound(
            "hello world", transport_type="ussd",
            transport_metadata={
                'session_info': {
                    'session_identifier': session_identifier
                }
            },
            to_addr=user_msisdn)
        submit_sm_pdu = yield self.fake_smsc.await_pdu()
        self.assertEqual(command_id(submit_sm_pdu), 'submit_sm')
        self.assertEqual(pdu_tlv(submit_sm_pdu, 'ussd_service_op'), '02')
        self.assertEqual(
            pdu_tlv(submit_sm_pdu, 'user_message_reference'),
            session_identifier)

        # Server delivers a USSD message to the Client
        pdu = DeliverSM(seq_no(submit_sm_pdu) + 1, short_message="reply!",
                        source_addr=user_msisdn)
        # 0x12 is 'continue'
        pdu.add_optional_parameter('ussd_service_op', '12')
        pdu.add_optional_parameter('user_message_reference',
                                   session_identifier)

        yield self.fake_smsc.handle_pdu(pdu)

        [mess] = yield self.tx_helper.wait_for_dispatched_inbound(1)

        self.assertEqual(mess['content'], "reply!")
        self.assertEqual(mess['transport_type'], "ussd")
        self.assertEqual(mess['to_addr'], '*123#')
        self.assertEqual(mess['session_event'],
                         TransportUserMessage.SESSION_RESUME)

    @inlineCallbacks
    def test_submit_and_deliver_ussd_close(self):
        yield self.get_transport()
        session_identifier = 12345

        yield self.tx_helper.make_dispatch_outbound(
            "hello world", transport_type="ussd",
            session_event=TransportUserMessage.SESSION_CLOSE,
            transport_metadata={
                'session_info': {
                    'session_identifier': session_identifier
                }
            })

        submit_sm_pdu = yield self.fake_smsc.await_pdu()
        self.assertEqual(command_id(submit_sm_pdu), 'submit_sm')
        self.assertEqual(pdu_tlv(submit_sm_pdu, 'ussd_service_op'), '17')
        self.assertEqual(
            pdu_tlv(submit_sm_pdu, 'user_message_reference'),
            session_identifier)
""" user_msisdn = 'msisdn' session_identifier = 12345 yield self.get_transport() yield self.tx_helper.make_dispatch_outbound( None, transport_type="ussd", session_event=TransportUserMessage.SESSION_RESUME, transport_metadata={ 'session_info': { 'session_identifier': session_identifier } }, to_addr=user_msisdn) resume = yield self.fake_smsc.await_pdu() self.assertEqual(pdu_tlv(resume, 'ussd_service_op'), '02') PK=JG1vumi/transports/smpp/processors/tests/__init__.pyPK=HFؠ::4vumi/transports/smpp/processors/tests/test_sixdee.pyfrom twisted.internet.defer import inlineCallbacks, returnValue from twisted.internet.task import Clock from smpp.pdu_builder import DeliverSM from vumi.message import TransportUserMessage from vumi.tests.helpers import VumiTestCase from vumi.transports.tests.helpers import TransportHelper from vumi.transports.smpp.pdu_utils import ( command_id, seq_no, pdu_tlv, short_message) from vumi.transports.smpp.smpp_transport import SmppTransceiverTransport from vumi.transports.smpp.processors.sixdee import make_vumi_session_identifier from vumi.transports.smpp.tests.fake_smsc import FakeSMSC class SessionInfo(object): """Helper for holding session ids.""" def __init__(self, session_id=5678, addr='1234', continue_session=True): # all 6D session IDs are even by construction assert session_id % 2 == 0 self.session_id = session_id self.addr = addr self.continue_session = continue_session @property def its_info(self): return "%04x" % (self.session_id | int(not self.continue_session)) @property def vumi_id(self): return make_vumi_session_identifier(self.addr, self.sixdee_id) @property def sixdee_id(self): return "%04x" % self.session_id class SixDeeProcessorTestCase(VumiTestCase): transport_class = SmppTransceiverTransport def setUp(self): self.clock = Clock() self.fake_smsc = FakeSMSC() self.tx_helper = self.add_helper( TransportHelper(SmppTransceiverTransport)) self.default_config = { 'transport_name': self.tx_helper.transport_name, 'twisted_endpoint': self.fake_smsc.endpoint, 'deliver_short_message_processor': ( 'vumi.transports.smpp.processors.sixdee.' 'DeliverShortMessageProcessor'), 'submit_short_message_processor': ( 'vumi.transports.smpp.processors.sixdee.' 'SubmitShortMessageProcessor'), 'system_id': 'foo', 'password': 'bar', 'deliver_short_message_processor_config': { 'data_coding_overrides': { 0: 'utf-8', } }, 'submit_short_message_processor_config': { 'submit_sm_encoding': 'utf-16be', 'submit_sm_data_coding': 8, 'send_multipart_udh': True, } } @inlineCallbacks def get_transport(self, deliver_config={}, submit_config={}, bind=True): cfg = self.default_config.copy() cfg['deliver_short_message_processor_config'].update(deliver_config) cfg['submit_short_message_processor_config'].update(submit_config) transport = yield self.tx_helper.get_transport(cfg, start=False) transport.clock = self.clock yield transport.startWorker() self.clock.advance(0) if bind: yield self.fake_smsc.bind() returnValue(transport) def assert_udh_parts(self, pdus, texts, encoding): def pdu_header(pdu): return short_message(pdu)[:6] def pdu_text(pdu): return short_message(pdu)[6:].decode(encoding) def udh_header(i): return '\x05\x00\x03\x03\x07' + chr(i) self.assertEqual( [(pdu_header(pdu), pdu_text(pdu)) for pdu in pdus], [(udh_header(i + 1), text) for i, text in enumerate(texts)]) @inlineCallbacks def test_submit_sm_multipart_udh_ucs2(self): message = ( "A cup is a small, open container used for carrying and " "drinking drinks. 
class SixDeeProcessorTestCase(VumiTestCase):

    transport_class = SmppTransceiverTransport

    def setUp(self):
        self.clock = Clock()
        self.fake_smsc = FakeSMSC()
        self.tx_helper = self.add_helper(
            TransportHelper(SmppTransceiverTransport))
        self.default_config = {
            'transport_name': self.tx_helper.transport_name,
            'twisted_endpoint': self.fake_smsc.endpoint,
            'deliver_short_message_processor': (
                'vumi.transports.smpp.processors.sixdee.'
                'DeliverShortMessageProcessor'),
            'submit_short_message_processor': (
                'vumi.transports.smpp.processors.sixdee.'
                'SubmitShortMessageProcessor'),
            'system_id': 'foo',
            'password': 'bar',
            'deliver_short_message_processor_config': {
                'data_coding_overrides': {
                    0: 'utf-8',
                }
            },
            'submit_short_message_processor_config': {
                'submit_sm_encoding': 'utf-16be',
                'submit_sm_data_coding': 8,
                'send_multipart_udh': True,
            }
        }

    @inlineCallbacks
    def get_transport(self, deliver_config={}, submit_config={}, bind=True):
        cfg = self.default_config.copy()
        cfg['deliver_short_message_processor_config'].update(deliver_config)
        cfg['submit_short_message_processor_config'].update(submit_config)
        transport = yield self.tx_helper.get_transport(cfg, start=False)
        transport.clock = self.clock
        yield transport.startWorker()
        self.clock.advance(0)
        if bind:
            yield self.fake_smsc.bind()
        returnValue(transport)

    def assert_udh_parts(self, pdus, texts, encoding):
        def pdu_header(pdu):
            return short_message(pdu)[:6]

        def pdu_text(pdu):
            return short_message(pdu)[6:].decode(encoding)

        def udh_header(i):
            return '\x05\x00\x03\x03\x07' + chr(i)

        self.assertEqual(
            [(pdu_header(pdu), pdu_text(pdu)) for pdu in pdus],
            [(udh_header(i + 1), text) for i, text in enumerate(texts)])

    @inlineCallbacks
    def test_submit_sm_multipart_udh_ucs2(self):
        message = (
            "A cup is a small, open container used for carrying and "
            "drinking drinks. It may be made of wood, plastic, glass, "
            "clay, metal, stone, china or other materials, and may have "
            "a stem, handles or other adornments. Cups are used for "
            "drinking across a wide range of cultures and social classes, "
            "and different styles of cups may be used for different liquids "
            "or in different situations. Cups have been used for thousands "
            "of years for the ...Reply 1 for more")
        yield self.get_transport()
        yield self.tx_helper.make_dispatch_outbound(message, to_addr='msisdn')
        pdus = yield self.fake_smsc.await_pdus(7)
        self.assert_udh_parts(pdus, [
            ("A cup is a small, open container used"
             " for carrying and drinking d"),
            ("rinks. It may be made of wood, plastic,"
             " glass, clay, metal, stone"),
            (", china or other materials, and may have"
             " a stem, handles or other"),
            (" adornments. Cups are used for drinking"
             " across a wide range of cu"),
            ("ltures and social classes, and different"
             " styles of cups may be us"),
            ("ed for different liquids or in different"
             " situations. Cups have be"),
            ("en used for thousands of years for the ...Reply 1 for more"),
        ], encoding='utf-16be')  # utf-16be is close enough to UCS2
        for pdu in pdus:
            self.assertTrue(len(short_message(pdu)) < 140)

    @inlineCallbacks
    def test_submit_and_deliver_ussd_new(self):
        session = SessionInfo()
        yield self.get_transport()

        # Server delivers a USSD message to the Client
        pdu = DeliverSM(1, short_message="*123#")
        pdu.add_optional_parameter('ussd_service_op', '01')
        pdu.add_optional_parameter('its_session_info', session.its_info)

        yield self.fake_smsc.handle_pdu(pdu)

        [mess] = yield self.tx_helper.wait_for_dispatched_inbound(1)

        self.assertEqual(mess['content'], None)
        self.assertEqual(mess['to_addr'], '*123#')
        self.assertEqual(mess['transport_type'], "ussd")
        self.assertEqual(mess['session_event'],
                         TransportUserMessage.SESSION_NEW)
        self.assertEqual(
            mess['transport_metadata'],
            {
                'session_info': {
                    'session_identifier': session.sixdee_id,
                    'ussd_service_op': '01',
                }
            })

    @inlineCallbacks
    def test_submit_and_deliver_ussd_new_custom_ussd_code_field(self):
        session = SessionInfo()
        yield self.get_transport(deliver_config={
            'ussd_code_pdu_field': 'destination_addr',
        })

        # Server delivers a USSD message to the Client
        pdu = DeliverSM(1, short_message="*IGNORE#", destination_addr="*123#")
        pdu.add_optional_parameter('ussd_service_op', '01')
        pdu.add_optional_parameter('its_session_info', session.its_info)

        yield self.fake_smsc.handle_pdu(pdu)

        [mess] = yield self.tx_helper.wait_for_dispatched_inbound(1)

        self.assertEqual(mess['content'], None)
        self.assertEqual(mess['to_addr'], '*123#')
        self.assertEqual(mess['transport_type'], "ussd")
        self.assertEqual(mess['session_event'],
                         TransportUserMessage.SESSION_NEW)
        self.assertEqual(
            mess['transport_metadata'],
            {
                'session_info': {
                    'session_identifier': session.sixdee_id,
                    'ussd_service_op': '01',
                }
            })

    @inlineCallbacks
    def test_deliver_sm_op_codes_new(self):
        session = SessionInfo()
        yield self.get_transport()
        pdu = DeliverSM(1, short_message="*123#")
        pdu.add_optional_parameter('ussd_service_op', '01')
        pdu.add_optional_parameter('its_session_info', session.its_info)
        yield self.fake_smsc.handle_pdu(pdu)
        [start] = yield self.tx_helper.wait_for_dispatched_inbound(1)
        self.assertEqual(start['session_event'],
                         TransportUserMessage.SESSION_NEW)
    @inlineCallbacks
    def test_deliver_sm_op_codes_resume(self):
        session = SessionInfo()
        transport = yield self.get_transport()
        deliver_sm_processor = transport.deliver_sm_processor
        session_manager = deliver_sm_processor.session_manager
        yield session_manager.create_session(
            session.vumi_id, ussd_code='*123#')

        pdu = DeliverSM(1, short_message="", source_addr=session.addr)
        pdu.add_optional_parameter('ussd_service_op', '12')
        pdu.add_optional_parameter('its_session_info', session.its_info)
        yield self.fake_smsc.handle_pdu(pdu)
        [resume] = yield self.tx_helper.wait_for_dispatched_inbound(1)
        self.assertEqual(resume['session_event'],
                         TransportUserMessage.SESSION_RESUME)

    @inlineCallbacks
    def test_deliver_sm_op_codes_end(self):
        session = SessionInfo()
        transport = yield self.get_transport()
        deliver_sm_processor = transport.deliver_sm_processor
        session_manager = deliver_sm_processor.session_manager
        yield session_manager.create_session(
            session.vumi_id, ussd_code='*123#')

        pdu = DeliverSM(1, short_message="", source_addr=session.addr)
        pdu.add_optional_parameter('ussd_service_op', '81')
        pdu.add_optional_parameter('its_session_info', session.its_info)
        yield self.fake_smsc.handle_pdu(pdu)
        [end] = yield self.tx_helper.wait_for_dispatched_inbound(1)
        self.assertEqual(end['session_event'],
                         TransportUserMessage.SESSION_CLOSE)

    @inlineCallbacks
    def test_deliver_sm_unknown_op_code(self):
        session = SessionInfo()
        yield self.get_transport()
        pdu = DeliverSM(1, short_message="*123#")
        pdu.add_optional_parameter('ussd_service_op', '01')
        pdu.add_optional_parameter('its_session_info', session.its_info)
        yield self.fake_smsc.handle_pdu(pdu)

        pdu = DeliverSM(1, short_message="*123#")
        pdu.add_optional_parameter('ussd_service_op', '99')
        pdu.add_optional_parameter('its_session_info', session.its_info)
        yield self.fake_smsc.handle_pdu(pdu)

        [start, unknown] = yield self.tx_helper.wait_for_dispatched_inbound(1)
        self.assertEqual(unknown['session_event'],
                         TransportUserMessage.SESSION_RESUME)

    @inlineCallbacks
    def test_submit_sm_op_codes_resume(self):
        session = SessionInfo()
        yield self.get_transport()
        yield self.tx_helper.make_dispatch_outbound(
            "hello world", transport_type="ussd",
            session_event=TransportUserMessage.SESSION_RESUME,
            transport_metadata={
                'session_info': {
                    'session_identifier': session.sixdee_id,
                }
            },
            to_addr=session.addr)
        resume = yield self.fake_smsc.await_pdu()
        self.assertEqual(pdu_tlv(resume, 'ussd_service_op'), '02')
        self.assertEqual(pdu_tlv(resume, 'its_session_info'),
                         session.its_info)

    @inlineCallbacks
    def test_submit_sm_op_codes_close(self):
        session = SessionInfo(continue_session=False)
        yield self.get_transport()
        yield self.tx_helper.make_dispatch_outbound(
            "hello world", transport_type="ussd",
            session_event=TransportUserMessage.SESSION_CLOSE,
            transport_metadata={
                'session_info': {
                    'session_identifier': session.sixdee_id,
                }
            },
            to_addr=session.addr)
        close = yield self.fake_smsc.await_pdu()
        self.assertEqual(pdu_tlv(close, 'ussd_service_op'), '17')
        self.assertEqual(pdu_tlv(close, 'its_session_info'),
                         session.its_info)
    @inlineCallbacks
    def test_submit_and_deliver_ussd_continue(self):
        session = SessionInfo()
        transport = yield self.get_transport()
        deliver_sm_processor = transport.deliver_sm_processor
        session_manager = deliver_sm_processor.session_manager
        yield session_manager.create_session(
            session.vumi_id, ussd_code='*123#')

        yield self.tx_helper.make_dispatch_outbound(
            "hello world", transport_type="ussd",
            transport_metadata={
                'session_info': {
                    'session_identifier': session.sixdee_id,
                }
            },
            to_addr=session.addr)
        submit_sm_pdu = yield self.fake_smsc.await_pdu()
        self.assertEqual(command_id(submit_sm_pdu), 'submit_sm')
        self.assertEqual(pdu_tlv(submit_sm_pdu, 'ussd_service_op'), '02')
        self.assertEqual(pdu_tlv(submit_sm_pdu, 'its_session_info'),
                         session.its_info)

        # Server delivers a USSD message to the Client
        pdu = DeliverSM(seq_no(submit_sm_pdu) + 1, short_message="reply!",
                        source_addr=session.addr)
        # 0x12 is 'continue'
        pdu.add_optional_parameter('ussd_service_op', '12')
        pdu.add_optional_parameter('its_session_info', session.its_info)

        yield self.fake_smsc.handle_pdu(pdu)

        [mess] = yield self.tx_helper.wait_for_dispatched_inbound(1)

        self.assertEqual(mess['content'], "reply!")
        self.assertEqual(mess['transport_type'], "ussd")
        self.assertEqual(mess['to_addr'], '*123#')
        self.assertEqual(mess['session_event'],
                         TransportUserMessage.SESSION_RESUME)

    @inlineCallbacks
    def test_submit_and_deliver_ussd_close(self):
        session = SessionInfo(continue_session=False)
        yield self.get_transport()

        yield self.tx_helper.make_dispatch_outbound(
            "hello world", transport_type="ussd",
            session_event=TransportUserMessage.SESSION_CLOSE,
            transport_metadata={
                'session_info': {
                    'session_identifier': session.sixdee_id,
                }
            })

        submit_sm_pdu = yield self.fake_smsc.await_pdu()
        self.assertEqual(command_id(submit_sm_pdu), 'submit_sm')
        self.assertEqual(pdu_tlv(submit_sm_pdu, 'ussd_service_op'), '17')
        self.assertEqual(pdu_tlv(submit_sm_pdu, 'its_session_info'),
                         session.its_info)

    @inlineCallbacks
    def test_submit_sm_null_message(self):
        """
        We can successfully send a message with null content.
        """
        session = SessionInfo()
        yield self.get_transport()
        yield self.tx_helper.make_dispatch_outbound(
            None, transport_type="ussd",
            session_event=TransportUserMessage.SESSION_RESUME,
            transport_metadata={
                'session_info': {
                    'session_identifier': session.sixdee_id,
                }
            },
            to_addr=session.addr)
        resume = yield self.fake_smsc.await_pdu()
        self.assertEqual(pdu_tlv(resume, 'ussd_service_op'), '02')
        self.assertEqual(pdu_tlv(resume, 'its_session_info'),
                         session.its_info)

vumi/transports/cellulant/cellulant_sms.py

# -*- test-case-name: vumi.transports.cellulant.tests.test_cellulant_sms -*-

import json
from urllib import urlencode

from twisted.internet.defer import inlineCallbacks

from vumi.utils import http_request_full
from vumi import log
from vumi.config import ConfigDict, ConfigText
from vumi.transports.httprpc import HttpRpcTransport


class CellulantSmsTransportConfig(HttpRpcTransport.CONFIG_CLASS):
    """Cellulant transport config."""

    credentials = ConfigDict(
        "A dictionary where the `from_addr` is used for the key lookup and the"
        " returned value should be a dictionary containing the username and"
        " password.", required=True, static=True)
    outbound_url = ConfigText(
        "The URL to send outbound messages to.", required=True, static=True)
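# Illustrative `credentials` value for the config above, not from the
# original source (the shape matches the tests below; the values are
# placeholders):
#
#   credentials:
#     '2371234567':
#       username: user
#       password: pass
#     '9292':
#       username: other-user
#       password: other-pass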
class CellulantSmsTransport(HttpRpcTransport):
    """
    HTTP transport for Cellulant SMS.
    """

    transport_type = 'sms'
    agent_factory = None  # For swapping out the Agent we use in tests.

    CONFIG_CLASS = CellulantSmsTransportConfig

    EXPECTED_FIELDS = set(["SOURCEADDR", "DESTADDR", "MESSAGE", "ID"])
    IGNORED_FIELDS = set(["channelID", "keyword", "CHANNELID", "serviceID",
                          "SERVICEID", "unsub", "transactionID"])

    KNOWN_ERROR_RESPONSE_CODES = {
        'E0': 'Insufficient HTTP Params passed',
        'E1': 'Invalid username or password',
        'E2': 'Credits have expired or run out',
        '1005': 'Suspect source address',
    }

    def validate_config(self):
        config = self.get_static_config()
        self._credentials = config.credentials
        self._outbound_url = config.outbound_url
        return super(CellulantSmsTransport, self).validate_config()

    @inlineCallbacks
    def handle_outbound_message(self, message):
        creds = self._credentials.get(message['from_addr'], {})
        username = creds.get('username', '')
        password = creds.get('password', '')
        params = {
            'username': username,
            'password': password,
            'source': message['from_addr'],
            'destination': message['to_addr'],
            'message': message['content'],
        }
        log.msg("Sending outbound message: %s" % (message,))
        url = '%s?%s' % (self._outbound_url, urlencode(params))
        log.msg("Making HTTP request: %s" % (url,))
        response = yield http_request_full(
            url, '', method='GET', agent_class=self.agent_factory)
        log.msg("Response: (%s) %r" % (response.code,
                                       response.delivered_body))
        content = response.delivered_body.strip()
        # we'll only send 1 message at a time and so the API can only
        # return this on a valid ack
        if content == '1':
            yield self.publish_ack(user_message_id=message['message_id'],
                                   sent_message_id=message['message_id'])
        else:
            error = self.KNOWN_ERROR_RESPONSE_CODES.get(
                content, 'Unknown response code: %s' % (content,))
            yield self.publish_nack(message['message_id'], error)

    @inlineCallbacks
    def handle_raw_inbound_message(self, message_id, request):
        values, errors = self.get_field_values(
            request, self.EXPECTED_FIELDS, self.IGNORED_FIELDS)
        if errors:
            log.msg('Unhappy incoming message: %s' % (errors,))
            yield self.finish_request(message_id, json.dumps(errors),
                                      code=400)
            return
        log.msg(('CellulantSmsTransport sending from %(SOURCEADDR)s to '
                 '%(DESTADDR)s message "%(MESSAGE)s"') % values)
        yield self.publish_message(
            message_id=message_id,
            content=values['MESSAGE'],
            to_addr=values['DESTADDR'],
            from_addr=values['SOURCEADDR'],
            provider='vumi',
            transport_type=self.transport_type,
            transport_metadata={'transport_message_id': values['ID']},
        )
        yield self.finish_request(
            message_id, json.dumps({'message_id': message_id}))

vumi/transports/cellulant/__init__.py

from vumi.transports.cellulant.cellulant import (
    CellulantTransport, CellulantError)
from vumi.transports.cellulant.cellulant_sms import CellulantSmsTransport

__all__ = ['CellulantTransport', 'CellulantSmsTransport', 'CellulantError']

vumi/transports/cellulant/cellulant.py

# -*- test-case-name: vumi.transports.cellulant.tests.test_cellulant -*-

from twisted.internet.defer import inlineCallbacks

from vumi.components.session import SessionManager
from vumi.errors import VumiError
from vumi.transports.httprpc import HttpRpcTransport
from vumi.message import TransportUserMessage
from vumi import log


class CellulantError(VumiError):
    """Used to log errors specific to the Cellulant transport."""


def pack_ussd_message(message):
    next_level = 1  # Ignoring the menu levels
    content = message['content']
    value_of_selection = 'null'
    service_id = 'null'
    if message['session_event'] == TransportUserMessage.SESSION_CLOSE:
        status = 'end'
    else:
        status = 'null'
    extra = 'null'
    return "%s|%s|%s|%s|%s|%s" % (
        next_level, content, value_of_selection,
        service_id, status, extra)
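# Illustrative examples of the pipe-delimited frame pack_ussd_message above
# builds, not part of the original module. Only the content and status slots
# vary; the remaining slots are always the literal string 'null':
#
#   session continues: "1|Pick an option|null|null|null|null"
#   session closes:    "1|Goodbye|null|null|end|null"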
class CellulantTransport(HttpRpcTransport):
    """Cellulant USSD (via HTTP) transport."""

    ENCODING = 'utf-8'

    EVENT_MAP = {
        'BEG': TransportUserMessage.SESSION_NEW,
        'ABO': TransportUserMessage.SESSION_CLOSE,
    }

    def validate_config(self):
        super(CellulantTransport, self).validate_config()
        self.transport_type = self.config.get('transport_type', 'ussd')

    @inlineCallbacks
    def setup_transport(self):
        super(CellulantTransport, self).setup_transport()
        r_config = self.config.get('redis_manager', {})
        r_prefix = "vumi.transports.cellulant:%s" % self.transport_name
        session_timeout = int(self.config.get("ussd_session_timeout", 600))
        self.session_manager = yield SessionManager.from_redis_config(
            r_config, r_prefix, session_timeout)

    @inlineCallbacks
    def teardown_transport(self):
        yield self.session_manager.stop()
        yield super(CellulantTransport, self).teardown_transport()

    def set_ussd_for_msisdn_session(self, msisdn, session, ussd):
        return self.session_manager.create_session(
            "%s:%s" % (msisdn, session), ussd=ussd)

    def get_ussd_for_msisdn_session(self, msisdn, session):
        d = self.session_manager.load_session("%s:%s" % (msisdn, session))
        return d.addCallback(lambda s: s.get('ussd', None))

    @inlineCallbacks
    def handle_raw_inbound_message(self, message_id, request):
        op_code = request.args.get('opCode')[0]
        to_addr = None
        if op_code == "BEG":
            to_addr = request.args.get('INPUT')[0]
            content = None
            yield self.set_ussd_for_msisdn_session(
                request.args.get('MSISDN')[0],
                request.args.get('sessionID')[0],
                to_addr)
        else:
            to_addr = yield self.get_ussd_for_msisdn_session(
                request.args.get('MSISDN')[0],
                request.args.get('sessionID')[0])
            content = request.args.get('INPUT')[0]

        if ((request.args.get('ABORT')[0] not in ('0', 'null'))
                or (op_code == 'ABO')):
            # respond to phones aborting a session
            self.finish_request(message_id, '')
            event = TransportUserMessage.SESSION_CLOSE
        else:
            event = self.EVENT_MAP.get(op_code,
                                       TransportUserMessage.SESSION_RESUME)

        if to_addr is None:
            # we can't continue so finish request and log error
            self.finish_request(message_id, '')
            log.error(CellulantError(
                "Failed redis USSD to_addr lookup for %s" % request.args))
        else:
            transport_metadata = {
                'session_id': request.args.get('sessionID')[0],
            }
            self.publish_message(
                message_id=message_id,
                content=content,
                to_addr=to_addr,
                from_addr=request.args.get('MSISDN')[0],
                session_event=event,
                transport_name=self.transport_name,
                transport_type=self.transport_type,
                transport_metadata=transport_metadata,
            )

    def handle_outbound_message(self, message):
        missing_fields = self.ensure_message_values(
            message, ['in_reply_to', 'content'])
        if missing_fields:
            return self.reject_message(message, missing_fields)
        self.finish_request(
            message['in_reply_to'],
            pack_ussd_message(message).encode(self.ENCODING))
        return self.publish_ack(user_message_id=message['message_id'],
                                sent_message_id=message['message_id'])

vumi/transports/cellulant/tests/test_cellulant_sms.py

# -*- encoding: utf-8 -*-

import json
from urllib import urlencode

from twisted.internet.defer import (
    inlineCallbacks, DeferredQueue, returnValue)

from vumi.utils import http_request, http_request_full
from vumi.tests.fake_connection import FakeHttpServer
from vumi.tests.helpers import VumiTestCase
from vumi.transports.cellulant import CellulantSmsTransport
from vumi.transports.tests.helpers import TransportHelper


class FakeCellulant(object):

    def __init__(self):
        self.cellulant_sms_calls = DeferredQueue()
        self.fake_http = FakeHttpServer(self.handle_request)
        self.response = ''
        self.get_agent = self.fake_http.get_agent
        self.get = self.cellulant_sms_calls.get

    def handle_request(self, request):
        self.cellulant_sms_calls.put(request)
        return self.response
class TestCellulantSmsTransport(VumiTestCase):

    @inlineCallbacks
    def setUp(self):
        self.fake_cellulant = FakeCellulant()
        self.base_url = "http://cellulant.example.com/"
        self.config = {
            'web_path': "foo",
            'web_port': 0,
            'credentials': {
                '2371234567': {
                    'username': 'user',
                    'password': 'pass',
                },
                '9292': {
                    'username': 'other-user',
                    'password': 'other-pass',
                }
            },
            'outbound_url': self.base_url,
        }
        self.tx_helper = self.add_helper(
            TransportHelper(CellulantSmsTransport))
        self.transport = yield self.tx_helper.get_transport(self.config)
        self.transport.agent_factory = self.fake_cellulant.get_agent
        self.transport_url = self.transport.get_transport_url()

    def mkurl(self, content, from_addr="2371234567", **kw):
        params = {
            'SOURCEADDR': from_addr,
            'DESTADDR': '12345',
            'MESSAGE': content,
            'ID': '1234567',
        }
        params.update(kw)
        return self.mkurl_raw(**params)

    def mkurl_raw(self, **params):
        return '%s%s?%s' % (
            self.transport_url,
            self.config['web_path'],
            urlencode(params)
        )

    @inlineCallbacks
    def test_health(self):
        result = yield http_request(
            self.transport_url + "health", "", method='GET')
        self.assertEqual(json.loads(result), {'pending_requests': 0})

    @inlineCallbacks
    def test_inbound(self):
        url = self.mkurl('hello')
        response = yield http_request(url, '', method='GET')
        [msg] = self.tx_helper.get_dispatched_inbound()
        self.assertEqual(msg['transport_name'],
                         self.tx_helper.transport_name)
        self.assertEqual(msg['to_addr'], "12345")
        self.assertEqual(msg['from_addr'], "2371234567")
        self.assertEqual(msg['content'], "hello")
        self.assertEqual(json.loads(response),
                         {'message_id': msg['message_id']})

    @inlineCallbacks
    def test_outbound(self):
        yield self.tx_helper.make_dispatch_outbound(
            "hello world", to_addr="2371234567")
        req = yield self.fake_cellulant.get()
        self.assertEqual(req.path, self.base_url)
        self.assertEqual(req.method, 'GET')
        self.assertEqual({
            'username': ['other-user'],
            'password': ['other-pass'],
            'source': ['9292'],
            'destination': ['2371234567'],
            'message': ['hello world'],
        }, req.args)

    @inlineCallbacks
    def test_outbound_creds_selection(self):
        yield self.tx_helper.make_dispatch_outbound(
            "hello world", to_addr="2371234567", from_addr='2371234567')
        req = yield self.fake_cellulant.get()
        self.assertEqual(req.path, self.base_url)
        self.assertEqual(req.method, 'GET')
        self.assertEqual({
            'username': ['user'],
            'password': ['pass'],
            'source': ['2371234567'],
            'destination': ['2371234567'],
            'message': ['hello world'],
        }, req.args)

        yield self.tx_helper.make_dispatch_outbound(
            "hello world", to_addr="2371234567", from_addr='9292')
        req = yield self.fake_cellulant.get()
        self.assertEqual(req.path, self.base_url)
        self.assertEqual(req.method, 'GET')
        self.assertEqual({
            'username': ['other-user'],
            'password': ['other-pass'],
            'source': ['9292'],
            'destination': ['2371234567'],
            'message': ['hello world'],
        }, req.args)

    @inlineCallbacks
    def test_handle_non_ascii_input(self):
        url = self.mkurl(u"öæł".encode("utf-8"))
        response = yield http_request(url, '', method='GET')
        [msg] = self.tx_helper.get_dispatched_inbound()
        self.assertEqual(msg['transport_name'],
                         self.tx_helper.transport_name)
        self.assertEqual(msg['to_addr'], "12345")
        self.assertEqual(msg['from_addr'], "2371234567")
        self.assertEqual(msg['content'], u"öæł")
        self.assertEqual(json.loads(response),
                         {'message_id': msg['message_id']})
    @inlineCallbacks
    def test_bad_parameter(self):
        url = self.mkurl('hello', foo='bar')
        response = yield http_request_full(url, '', method='GET')
        self.assertEqual(400, response.code)
        self.assertEqual(json.loads(response.delivered_body),
                         {'unexpected_parameter': ['foo']})

    @inlineCallbacks
    def test_missing_parameters(self):
        url = self.mkurl_raw(ID='12345678', DESTADDR='12345',
                             MESSAGE='hello')
        response = yield http_request_full(url, '', method='GET')
        self.assertEqual(400, response.code)
        self.assertEqual(json.loads(response.delivered_body),
                         {'missing_parameter': ['SOURCEADDR']})

    @inlineCallbacks
    def test_ignored_parameters(self):
        url = self.mkurl('hello', channelID='a', keyword='b', CHANNELID='c',
                         serviceID='d', SERVICEID='e', unsub='f')
        response = yield http_request(url, '', method='GET')
        [msg] = self.tx_helper.get_dispatched_inbound()
        self.assertEqual(msg['content'], "hello")
        self.assertEqual(json.loads(response),
                         {'message_id': msg['message_id']})


class TestAcksCellulantSmsTransport(VumiTestCase):

    @inlineCallbacks
    def setUp(self):
        self.fake_cellulant = FakeCellulant()
        self.base_url = "http://cellulant.example.com/"
        self.config = {
            'web_path': "foo",
            'web_port': 0,
            'credentials': {
                '2371234567': {
                    'username': 'user',
                    'password': 'pass',
                },
                '9292': {
                    'username': 'other-user',
                    'password': 'other-pass',
                }
            },
            'outbound_url': self.base_url,
            'validation_mode': 'permissive',
        }
        self.tx_helper = self.add_helper(
            TransportHelper(CellulantSmsTransport))
        self.transport = yield self.tx_helper.get_transport(self.config)
        self.transport.agent_factory = self.fake_cellulant.get_agent
        self.transport_url = self.transport.get_transport_url()

    def mock_response(self, response):
        self.fake_cellulant.response = response

    @inlineCallbacks
    def mock_event(self, msg, nr_events):
        self.mock_response(msg)
        yield self.tx_helper.make_dispatch_outbound(
            "foo", to_addr='2371234567', message_id='id_%s' % (msg,))
        yield self.fake_cellulant.get()
        events = yield self.tx_helper.wait_for_dispatched_events(nr_events)
        returnValue(events)

    @inlineCallbacks
    def test_nack_param_error_E0(self):
        [nack] = yield self.mock_event('E0', 1)
        self.assertEqual(nack['event_type'], 'nack')
        self.assertEqual(nack['user_message_id'], 'id_E0')
        self.assertEqual(nack['nack_reason'],
                         self.transport.KNOWN_ERROR_RESPONSE_CODES['E0'])

    @inlineCallbacks
    def test_nack_login_error_E1(self):
        [nack] = yield self.mock_event('E1', 1)
        self.assertEqual(nack['event_type'], 'nack')
        self.assertEqual(nack['user_message_id'], 'id_E1')
        self.assertEqual(nack['nack_reason'],
                         self.transport.KNOWN_ERROR_RESPONSE_CODES['E1'])

    @inlineCallbacks
    def test_nack_credits_error_E2(self):
        [nack] = yield self.mock_event('E2', 1)
        self.assertEqual(nack['event_type'], 'nack')
        self.assertEqual(nack['user_message_id'], 'id_E2')
        self.assertEqual(nack['nack_reason'],
                         self.transport.KNOWN_ERROR_RESPONSE_CODES['E2'])

    @inlineCallbacks
    def test_nack_delivery_failed_1005(self):
        [nack] = yield self.mock_event('1005', 1)
        self.assertEqual(nack['event_type'], 'nack')
        self.assertEqual(nack['user_message_id'], 'id_1005')
        self.assertEqual(nack['nack_reason'],
                         self.transport.KNOWN_ERROR_RESPONSE_CODES['1005'])

    @inlineCallbacks
    def test_unknown_response(self):
        [nack] = yield self.mock_event('something_unexpected', 1)
        self.assertEqual(nack['event_type'], 'nack')
        self.assertEqual(nack['user_message_id'], 'id_something_unexpected')
        self.assertEqual(nack['nack_reason'],
                         'Unknown response code: something_unexpected')

    @inlineCallbacks
    def test_ack_success(self):
        [event] = yield self.mock_event('1', 1)
        self.assertEqual(event['event_type'], 'ack')
        self.assertEqual(event['user_message_id'], 'id_1')
class TestPermissiveCellulantSmsTransport(VumiTestCase):

    @inlineCallbacks
    def setUp(self):
        self.fake_cellulant = FakeCellulant()
        self.base_url = "http://cellulant.example.com/"
        self.config = {
            'web_path': "foo",
            'web_port': 0,
            'credentials': {
                '2371234567': {
                    'username': 'user',
                    'password': 'pass',
                },
                '9292': {
                    'username': 'other-user',
                    'password': 'other-pass',
                }
            },
            'outbound_url': self.base_url,
            'validation_mode': 'permissive',
        }
        self.tx_helper = self.add_helper(
            TransportHelper(CellulantSmsTransport))
        self.transport = yield self.tx_helper.get_transport(self.config)
        self.transport.agent_factory = self.fake_cellulant.get_agent
        self.transport_url = self.transport.get_transport_url()

    def mkurl(self, content, from_addr="2371234567", **kw):
        params = {
            'SOURCEADDR': from_addr,
            'DESTADDR': '12345',
            'MESSAGE': content,
            'ID': '1234567',
        }
        params.update(kw)
        return self.mkurl_raw(**params)

    def mkurl_raw(self, **params):
        return '%s%s?%s' % (
            self.transport_url,
            self.config['web_path'],
            urlencode(params)
        )

    @inlineCallbacks
    def test_bad_parameter_in_permissive_mode(self):
        url = self.mkurl('hello', foo='bar')
        response = yield http_request_full(url, '', method='GET')
        [msg] = self.tx_helper.get_dispatched_inbound()
        self.assertEqual(200, response.code)
        self.assertEqual(json.loads(response.delivered_body),
                         {'message_id': msg['message_id']})

    @inlineCallbacks
    def test_missing_parameters(self):
        url = self.mkurl_raw(ID='12345678', DESTADDR='12345',
                             MESSAGE='hello')
        response = yield http_request_full(url, '', method='GET')
        self.assertEqual(400, response.code)
        self.assertEqual(json.loads(response.delivered_body),
                         {'missing_parameter': ['SOURCEADDR']})

    @inlineCallbacks
    def test_ignored_parameters(self):
        url = self.mkurl('hello', channelID='a', keyword='b', CHANNELID='c',
                         serviceID='d', SERVICEID='e', unsub='f')
        response = yield http_request(url, '', method='GET')
        [msg] = self.tx_helper.get_dispatched_inbound()
        self.assertEqual(msg['content'], "hello")
        self.assertEqual(json.loads(response),
                         {'message_id': msg['message_id']})

vumi/transports/cellulant/tests/test_cellulant.py

from urllib import urlencode

from twisted.internet.defer import inlineCallbacks

from vumi.tests.helpers import VumiTestCase
from vumi.transports.cellulant import CellulantTransport, CellulantError
from vumi.message import TransportUserMessage
from vumi.utils import http_request
from vumi.transports.tests.helpers import TransportHelper


class TestCellulantTransport(VumiTestCase):

    @inlineCallbacks
    def setUp(self):
        self.config = {
            'web_port': 0,
            'web_path': '/api/v1/ussd/cellulant/',
            'ussd_session_timeout': 60,
        }
        self.tx_helper = self.add_helper(
            TransportHelper(CellulantTransport))
        self.transport = yield self.tx_helper.get_transport(self.config)
        self.transport_url = self.transport.get_transport_url(
            self.config['web_path'])
        yield self.transport.session_manager.redis._purge_all()
        # just in case

    def mk_request(self, **params):
        defaults = {
            'MSISDN': '27761234567',
            'INPUT': '',
            'opCode': 'BEG',
            'ABORT': '0',
            'sessionID': '1',
        }
        defaults.update(params)
        return http_request(
            '%s?%s' % (self.transport_url, urlencode(defaults)),
            data='', method='GET')
self.assertEqual("*bar#", val) @inlineCallbacks def test_inbound_begin(self): deferred = self.mk_request(INPUT="*120*1#") [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(msg['content'], None) self.assertEqual(msg['to_addr'], '*120*1#') self.assertEqual(msg['from_addr'], '27761234567'), self.assertEqual(msg['session_event'], TransportUserMessage.SESSION_NEW) self.assertEqual(msg['transport_metadata'], { 'session_id': '1', }) yield self.tx_helper.make_dispatch_reply(msg, "ussd message") response = yield deferred self.assertEqual(response, '1|ussd message|null|null|null|null') @inlineCallbacks def test_inbound_resume_and_reply_with_end(self): # first pre-populate the redis datastore to simulate prior BEG message yield self.transport.set_ussd_for_msisdn_session( '27761234567', '1', '*120*VERY_FAKE_CODE#', ) deferred = self.mk_request(INPUT='hi', opCode='') [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(msg['content'], 'hi') self.assertEqual(msg['to_addr'], '*120*VERY_FAKE_CODE#') self.assertEqual(msg['from_addr'], '27761234567') self.assertEqual(msg['session_event'], TransportUserMessage.SESSION_RESUME) self.assertEqual(msg['transport_metadata'], { 'session_id': '1', }) yield self.tx_helper.make_dispatch_reply( msg, "hello world", continue_session=False) response = yield deferred self.assertEqual(response, '1|hello world|null|null|end|null') @inlineCallbacks def test_inbound_resume_with_failed_to_addr_lookup(self): deferred = self.mk_request(MSISDN='123456', INPUT='hi', opCode='') response = yield deferred self.assertEqual(response, '') [f] = self.flushLoggedErrors(CellulantError) self.assertTrue(str(f.value).startswith( "Failed redis USSD to_addr lookup for {")) @inlineCallbacks def test_inbound_abort_opcode(self): # first pre-populate the redis datastore to simulate prior BEG message yield self.transport.set_ussd_for_msisdn_session( '27761234567', '1', '*120*VERY_FAKE_CODE#', ) # this one should return immediately with a blank # as there isn't going to be a sensible response resp = yield self.mk_request(opCode='ABO') self.assertEqual(resp, '') [msg] = yield self.tx_helper.get_dispatched_inbound() self.assertEqual(msg['session_event'], TransportUserMessage.SESSION_CLOSE) @inlineCallbacks def test_inbound_abort_field(self): # should also return immediately resp = yield self.mk_request(ABORT=1) self.assertEqual(resp, '') [msg] = yield self.tx_helper.get_dispatched_inbound() self.assertEqual(msg['session_event'], TransportUserMessage.SESSION_CLOSE) @inlineCallbacks def test_nack(self): msg = yield self.tx_helper.make_dispatch_outbound("foo") [nack] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(nack['user_message_id'], msg['message_id']) self.assertEqual(nack['sent_message_id'], msg['message_id']) self.assertEqual(nack['nack_reason'], 'Missing fields: in_reply_to') PK=JG+vumi/transports/cellulant/tests/__init__.pyPK=JGC&vumi/transports/safaricom/safaricom.py# -*- test-case-name: vumi.transports.safaricom.tests.test_safaricom -*- import json from twisted.internet.defer import inlineCallbacks from vumi.transports.httprpc import HttpRpcTransport from vumi.message import TransportUserMessage from vumi.components.session import SessionManager from vumi import log class SafaricomTransport(HttpRpcTransport): """ HTTP transport for USSD with Safaricom in Kenya. :param str web_path: The HTTP path to listen on. 
vumi/transports/safaricom/safaricom.py

# -*- test-case-name: vumi.transports.safaricom.tests.test_safaricom -*-

import json

from twisted.internet.defer import inlineCallbacks

from vumi.transports.httprpc import HttpRpcTransport
from vumi.message import TransportUserMessage
from vumi.components.session import SessionManager
from vumi import log


class SafaricomTransport(HttpRpcTransport):
    """
    HTTP transport for USSD with Safaricom in Kenya.

    :param str web_path:
        The HTTP path to listen on.
    :param int web_port:
        The HTTP port.
    :param str transport_name:
        The name this transport instance will use to create its queues.
    :param dict redis:
        The configuration parameters for connecting to Redis.
    :param int ussd_session_timeout:
        The number of seconds after which a timeout is forced on a
        transport level.
    """

    transport_type = 'ussd'

    ENCODING = 'utf-8'

    EXPECTED_FIELDS = set(['ORIG', 'DEST', 'SESSION_ID', 'USSD_PARAMS'])

    def validate_config(self):
        super(SafaricomTransport, self).validate_config()
        self.transport_type = self.config.get('transport_type', 'ussd')
        self.redis_config = self.config.get('redis_manager', {})
        self.r_prefix = "vumi.transports.safaricom:%s" % self.transport_name
        self.r_session_timeout = int(
            self.config.get("ussd_session_timeout", 600))

    @inlineCallbacks
    def setup_transport(self):
        super(SafaricomTransport, self).setup_transport()
        self.session_manager = yield SessionManager.from_redis_config(
            self.redis_config, self.r_prefix, self.r_session_timeout)

    @inlineCallbacks
    def teardown_transport(self):
        yield self.session_manager.stop()
        yield super(SafaricomTransport, self).teardown_transport()

    @inlineCallbacks
    def handle_raw_inbound_message(self, message_id, request):
        values, errors = self.get_field_values(request,
                                               self.EXPECTED_FIELDS)
        if errors:
            log.err('Unhappy incoming message: %s' % (errors,))
            yield self.finish_request(message_id, json.dumps(errors),
                                      code=400)
            return
        self.emit(('SafaricomTransport sending from %s to %s '
                   'for %s message "%s" (%s still pending)') % (
                       values['ORIG'], values['DEST'],
                       values['SESSION_ID'], values['USSD_PARAMS'],
                       len(self._requests),
                   ))
        session_id = values['SESSION_ID']
        from_addr = values['ORIG']
        dest = values['DEST']
        ussd_params = values['USSD_PARAMS']

        session = yield self.session_manager.load_session(session_id)
        if session:
            to_addr = session['to_addr']
            last_ussd_params = session['last_ussd_params']
            new_params = ussd_params[len(last_ussd_params):]
            if new_params:
                if last_ussd_params:
                    content = new_params[1:]
                else:
                    content = new_params
            else:
                content = ''
            session['last_ussd_params'] = ussd_params
            yield self.session_manager.save_session(session_id, session)
            session_event = TransportUserMessage.SESSION_RESUME
        else:
            if ussd_params:
                to_addr = '*%s*%s#' % (dest, ussd_params)
            else:
                to_addr = '*%s#' % (dest,)
            yield self.session_manager.create_session(
                session_id, from_addr=from_addr, to_addr=to_addr,
                last_ussd_params=ussd_params)
            session_event = TransportUserMessage.SESSION_NEW
            content = ''

        yield self.publish_message(
            message_id=message_id,
            content=content,
            to_addr=to_addr,
            from_addr=from_addr,
            provider='safaricom',
            session_event=session_event,
            transport_type=self.transport_type,
            transport_metadata={
                'safaricom': {
                    'session_id': session_id,
                }
            }
        )

    def handle_outbound_message(self, message):
        missing_fields = self.ensure_message_values(
            message, ['in_reply_to', 'content'])
        if missing_fields:
            return self.reject_message(message, missing_fields)
        if message['session_event'] == TransportUserMessage.SESSION_CLOSE:
            command = 'END'
        else:
            command = 'CON'
        self.finish_request(
            message['in_reply_to'],
            ('%s %s' % (command, message['content'])).encode(self.ENCODING))
        return self.publish_ack(user_message_id=message['message_id'],
                                sent_message_id=message['message_id'])
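# Illustrative walk-through of how handle_raw_inbound_message above recovers
# the user's latest input from Safaricom's cumulative USSD_PARAMS history;
# not part of the original module:
#
#   session exists, last_ussd_params='7*a*b', USSD_PARAMS='7*a*b*c'
#       -> new_params = '*c' -> content = 'c'
#   session exists, last_ussd_params='7*a*b', USSD_PARAMS='7*a*b' (repeat)
#       -> new_params = ''   -> content = ''
#   no session, DEST='167', USSD_PARAMS='7' (new dial-in)
#       -> to_addr = '*167*7#', content = ''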
""" from vumi.transports.safaricom.safaricom import SafaricomTransport __all__ = ['SafaricomTransport'] PK=JG+vumi/transports/safaricom/tests/__init__.pyPK=JGY[)""1vumi/transports/safaricom/tests/test_safaricom.pyimport json from urllib import urlencode from twisted.internet.defer import inlineCallbacks from vumi.message import TransportUserMessage from vumi.tests.helpers import VumiTestCase from vumi.transports.safaricom import SafaricomTransport from vumi.transports.tests.helpers import TransportHelper from vumi.utils import http_request class TestSafaricomTransport(VumiTestCase): @inlineCallbacks def setUp(self): config = { 'web_port': 0, 'web_path': '/api/v1/safaricom/ussd/', } self.tx_helper = self.add_helper(TransportHelper(SafaricomTransport)) self.transport = yield self.tx_helper.get_transport(config) self.session_manager = self.transport.session_manager self.transport_url = self.transport.get_transport_url( config['web_path']) yield self.session_manager.redis._purge_all() # just in case def mk_full_request(self, **params): return http_request('%s?%s' % (self.transport_url, urlencode(params)), data='', method='GET') def mk_request(self, **params): defaults = { 'ORIG': '27761234567', 'DEST': '167', 'SESSION_ID': 'session-id', 'USSD_PARAMS': '', } defaults.update(params) return self.mk_full_request(**defaults) @inlineCallbacks def test_inbound_begin(self): # Second connect is the actual start of the session deferred = self.mk_request(USSD_PARAMS='7') [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(msg['content'], '') self.assertEqual(msg['to_addr'], '*167*7#') self.assertEqual(msg['from_addr'], '27761234567'), self.assertEqual(msg['session_event'], TransportUserMessage.SESSION_NEW) self.assertEqual(msg['transport_metadata'], { 'safaricom': { 'session_id': 'session-id', }, }) yield self.tx_helper.make_dispatch_reply(msg, "ussd message") response = yield deferred self.assertEqual(response, 'CON ussd message') @inlineCallbacks def test_inbound_resume_and_reply_with_end(self): # first pre-populate the redis datastore to simulate prior BEG message yield self.session_manager.create_session('session-id', to_addr='*167*7#', from_addr='27761234567', last_ussd_params='7*a*b', session_event=TransportUserMessage.SESSION_RESUME) # Safaricom gives us the history of the full session in the USSD_PARAMS # The last submitted bit of content is the last value delimited by '*' deferred = self.mk_request(USSD_PARAMS='7*a*b*c') [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(msg['content'], 'c') self.assertEqual(msg['to_addr'], '*167*7#') self.assertEqual(msg['from_addr'], '27761234567') self.assertEqual(msg['session_event'], TransportUserMessage.SESSION_RESUME) self.assertEqual(msg['transport_metadata'], { 'safaricom': { 'session_id': 'session-id', }, }) yield self.tx_helper.make_dispatch_reply( msg, "hello world", continue_session=False) response = yield deferred self.assertEqual(response, 'END hello world') @inlineCallbacks def test_inbound_resume_with_failed_to_addr_lookup(self): deferred = self.mk_full_request(ORIG='123456', USSD_PARAMS='7*a', SESSION_ID='session-id') response = yield deferred self.assertEqual(json.loads(response), { 'missing_parameter': ['DEST'], }) @inlineCallbacks def test_to_addr_handling(self): defaults = { 'DEST': '167', 'ORIG': '12345', 'SESSION_ID': 'session-id', } d1 = self.mk_full_request(USSD_PARAMS='7*1', **defaults) [msg1] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(msg1['to_addr'], 
    @inlineCallbacks
    def test_inbound_resume_and_reply_with_end(self):
        # first pre-populate the redis datastore to simulate prior BEG message
        yield self.session_manager.create_session(
            'session-id', to_addr='*167*7#', from_addr='27761234567',
            last_ussd_params='7*a*b',
            session_event=TransportUserMessage.SESSION_RESUME)

        # Safaricom gives us the history of the full session in the
        # USSD_PARAMS. The last submitted bit of content is the last value
        # delimited by '*'
        deferred = self.mk_request(USSD_PARAMS='7*a*b*c')

        [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1)
        self.assertEqual(msg['content'], 'c')
        self.assertEqual(msg['to_addr'], '*167*7#')
        self.assertEqual(msg['from_addr'], '27761234567')
        self.assertEqual(msg['session_event'],
                         TransportUserMessage.SESSION_RESUME)
        self.assertEqual(msg['transport_metadata'], {
            'safaricom': {
                'session_id': 'session-id',
            },
        })

        yield self.tx_helper.make_dispatch_reply(
            msg, "hello world", continue_session=False)
        response = yield deferred
        self.assertEqual(response, 'END hello world')

    @inlineCallbacks
    def test_inbound_resume_with_failed_to_addr_lookup(self):
        deferred = self.mk_full_request(ORIG='123456', USSD_PARAMS='7*a',
                                        SESSION_ID='session-id')
        response = yield deferred
        self.assertEqual(json.loads(response), {
            'missing_parameter': ['DEST'],
        })

    @inlineCallbacks
    def test_to_addr_handling(self):
        defaults = {
            'DEST': '167',
            'ORIG': '12345',
            'SESSION_ID': 'session-id',
        }

        d1 = self.mk_full_request(USSD_PARAMS='7*1', **defaults)
        [msg1] = yield self.tx_helper.wait_for_dispatched_inbound(1)
        self.assertEqual(msg1['to_addr'], '*167*7*1#')
        self.assertEqual(msg1['content'], '')
        self.assertEqual(msg1['session_event'],
                         TransportUserMessage.SESSION_NEW)
        yield self.tx_helper.make_dispatch_reply(msg1, "hello world")
        yield d1

        # follow up with the user submitting 'a'
        d2 = self.mk_full_request(USSD_PARAMS='7*1*a', **defaults)
        [msg1, msg2] = yield self.tx_helper.wait_for_dispatched_inbound(2)
        self.assertEqual(msg2['to_addr'], '*167*7*1#')
        self.assertEqual(msg2['content'], 'a')
        self.assertEqual(msg2['session_event'],
                         TransportUserMessage.SESSION_RESUME)
        yield self.tx_helper.make_dispatch_reply(
            msg2, "hello world", continue_session=False)
        yield d2

    @inlineCallbacks
    def test_hitting_url_twice_without_content(self):
        d1 = self.mk_request(USSD_PARAMS='7*3')
        [msg1] = yield self.tx_helper.wait_for_dispatched_inbound(1)
        self.assertEqual(msg1['to_addr'], '*167*7*3#')
        self.assertEqual(msg1['content'], '')
        self.assertEqual(msg1['session_event'],
                         TransportUserMessage.SESSION_NEW)
        yield self.tx_helper.make_dispatch_reply(msg1, "Hello")
        yield d1

        # make the exact same request again
        d2 = self.mk_request(USSD_PARAMS='7*3')
        [msg1, msg2] = yield self.tx_helper.wait_for_dispatched_inbound(2)
        self.assertEqual(msg2['to_addr'], '*167*7*3#')
        self.assertEqual(msg2['content'], '')
        self.assertEqual(msg2['session_event'],
                         TransportUserMessage.SESSION_RESUME)
        yield self.tx_helper.make_dispatch_reply(msg2, "Hello")
        yield d2

    @inlineCallbacks
    def test_submitting_asterisks_as_values(self):
        yield self.session_manager.create_session(
            'session-id', to_addr='*167*7#', from_addr='27761234567',
            last_ussd_params='7*a*b')

        # we're submitting a bunch of *s
        deferred = self.mk_request(USSD_PARAMS='7*a*b*****')

        [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1)
        self.assertEqual(msg['content'], '****')
        yield self.tx_helper.make_dispatch_reply(msg, "Hello")
        yield deferred

        session = yield self.session_manager.load_session('session-id')
        self.assertEqual(session['last_ussd_params'], '7*a*b*****')

    @inlineCallbacks
    def test_submitting_asterisks_as_values_after_asterisks(self):
        yield self.session_manager.create_session(
            'session-id', to_addr='*167*7#', from_addr='27761234567',
            last_ussd_params='7*a*b**')

        # we're submitting a bunch of *s
        deferred = self.mk_request(USSD_PARAMS='7*a*b*****')

        [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1)
        self.assertEqual(msg['content'], '**')
        yield self.tx_helper.make_dispatch_reply(msg, "Hello")
        yield deferred

        session = yield self.session_manager.load_session('session-id')
        self.assertEqual(session['last_ussd_params'], '7*a*b*****')
TransportUserMessage.SESSION_RESUME) yield self.tx_helper.make_dispatch_reply(msg3, "Hello") yield d3 @inlineCallbacks def test_nack(self): msg = yield self.tx_helper.make_dispatch_outbound("outbound") [nack] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(nack['user_message_id'], msg['message_id']) self.assertEqual(nack['sent_message_id'], msg['message_id']) self.assertEqual(nack['nack_reason'], 'Missing fields: in_reply_to') PK=JGA@g**"vumi/transports/infobip/infobip.py# -*- test-case-name: vumi.transports.infobip.tests.test_infobip -*- # -*- coding: utf-8 -*- """Infobip USSD transport.""" import json from twisted.internet.defer import inlineCallbacks, returnValue from vumi import log from vumi.errors import VumiError from vumi.message import TransportUserMessage from vumi.transports.httprpc import HttpRpcTransport from vumi.components.session import SessionManager class InfobipError(VumiError): """Used to log errors specific to the Infobip transport.""" class InfobipTransport(HttpRpcTransport): """Infobip transport. Currently only supports the Infobip USSD interface. Configuration parameters: :type ussd_session_timeout: int :param ussd_session_timeout: Number of seconds before USSD session information stored in Redis expires. Default is 600s. Excerpt from :title-reference:`INFOBIP USSD Gateway to Third-party Application HTTP/REST/JSON Web Service API`: Third party application provides four methods for session management. Their parameters are as follows: * sessionActive (type Boolean) – true if the session is active, false otherwise. The parameter is mandatory. * sessionId (type String) – is generated for each started session and. The parameter is mandatory. exitCode (type Integer) – defined the status of the session that is complete. All the exit codes can be found in Table 1. The parameter is mandatory. * reason (type String) – in case a third-party applications releases the session before its completion it will contain the reason for the release. The parameter is used for logging purposes and is mandatory. msisdn (type String) – of the user that sent the response to the menu request. The parameter is mandatory. * imsi (type String) – of the user that sent the response to the menu request. The parameter is optional. * text (type String) – text the user entered in the response. The parameter is mandatory. shortCode – Short code entered in the mobile initiated session or by calling start method. The parameter is optional. * language (type String)– defines which language will be used for message text. Used in applications that support internationalization. The parameter should be set to null if not used. The parameter is optional. * optional (type String)– left for future usage scenarios. The parameter is optional. ussdGwId (type String)– id of the USSD GW calling the third-party application. This parameter is optional. Responses to requests sent from the third-party-applications have the following parameters: * ussdMenu (type String)– menu to send as text to the user. The parameter is mandatory. * shouldClose (type boolean)– set to true if this is the last message in this session sent to the user, false if there will be more. The parameter is mandatory. Please note that the first message in the session will always be sent as a menu (i.e. shouldClose will be set to false). * thirdPartyId (type String)– Id of the third-party application or server. This parameter is optional. * responseExitCode (type Integer) – request processing exit code. Parameter is mandatory. 
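    An illustrative ``start`` exchange, using only fields named in the
    lists above (all values are invented for illustration)::

        request:  {"msisdn": "385551234567", "text": "*123#",
                   "shortCode": "*123#"}
        response: {"ussdMenu": "Welcome. Reply 1 for balance.",
                   "shouldClose": false, "responseExitCode": 200,
                   "responseMessage": ""}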
""" # method_name -> session_event, handler_name, sends_json METHOD_TO_HANDLER = { "status": (TransportUserMessage.SESSION_NONE, "handle_infobip_status", False), "start": (TransportUserMessage.SESSION_NEW, "handle_infobip_start", True), "response": (TransportUserMessage.SESSION_RESUME, "handle_infobip_response", True), "end": (TransportUserMessage.SESSION_CLOSE, "handle_infobip_end", True), } def validate_config(self): super(InfobipTransport, self).validate_config() self.r_config = self.config.get('redis_manager', {}) @inlineCallbacks def setup_transport(self): yield super(InfobipTransport, self).setup_transport() r_prefix = "infobip:%s" % (self.transport_name,) session_timeout = int(self.config.get("ussd_session_timeout", 600)) self.session_manager = yield SessionManager.from_redis_config( self.r_config, r_prefix, session_timeout) @inlineCallbacks def teardown_transport(self): yield self.session_manager.stop() yield super(InfobipTransport, self).teardown_transport() def save_ussd_params(self, session_id, params): return self.session_manager.create_session(session_id, **params) def get_ussd_params(self, session_id): return self.session_manager.load_session(session_id) def clear_ussd_params(self, session_id): return self.session_manager.clear_session(session_id) def send_error(self, msgid, reason, code=400): response_data = { "responseExitCode": code, "responseMessage": reason, } self.finish_request(msgid, json.dumps(response_data)) @inlineCallbacks def handle_infobip_status(self, msgid, session_id, eq_data): params = yield self.get_ussd_params(session_id) response_data = { "sessionActive": bool(params), "responseExitCode": 200, "responseMessage": "", } yield self.finish_request(msgid, json.dumps(response_data)) @inlineCallbacks def handle_infobip_start(self, msgid, session_id, req_data): message_dict = yield self.get_ussd_params(session_id) if message_dict: self.send_error( msgid, "USSD session %r already started" % (session_id,)) return try: from_addr = req_data["msisdn"] content = req_data["text"] except KeyError, e: self.send_error(msgid, "Missing required JSON field: %r" % (e,)) return message_dict = { "from_addr": from_addr, # unfortunately shortCode is not as mandatory as the # Infobip documentation claims "to_addr": req_data.get("shortCode") or "", # ussdGwId isn't documented but it does get sent and # contains values like "live2". 
"provider": req_data.get("ussdGwId", ""), } yield self.save_ussd_params(session_id, message_dict) message_dict["content"] = content returnValue(message_dict) @inlineCallbacks def handle_infobip_response(self, msgid, session_id, req_data): message_dict = yield self.get_ussd_params(session_id) if not message_dict: self.send_error(msgid, "Invalid USSD session %r" % (session_id,)) return try: content = req_data["text"] except KeyError, e: self.send_error(msgid, "Missing required JSON field: %r" % (e,)) return message_dict["content"] = content returnValue(message_dict) @inlineCallbacks def handle_infobip_end(self, msgid, session_id, req_data): message_dict = yield self.get_ussd_params(session_id) if not message_dict: self.send_error(msgid, "Invalid USSD session %r" % (session_id,)) return yield self.clear_ussd_params(session_id) response_data = {"responseExitCode": 200, "responseMessage": ""} self.finish_request(msgid, json.dumps(response_data)) message_dict["content"] = None returnValue(message_dict) def handle_infobip_error(self, msgid, session_id, req_data): self.send_error(msgid, req_data.get("error", "Invalid request")) @inlineCallbacks def handle_raw_inbound_message(self, msgid, request): parts = request.path.split('/') session_id = parts[-2] session_method = parts[-1] session_event, session_handler_name, sends_json = ( self.METHOD_TO_HANDLER.get(session_method, (TransportUserMessage.SESSION_NONE, "handle_infobip_error", False))) session_handler = getattr(self, session_handler_name) req_content = request.content.read() log.msg("Incoming message: %r" % (req_content,)) if sends_json: try: req_data = json.loads(req_content) except ValueError: # send bad JSON to error handler session_handler = self.handle_infobip_error req_data = {"error": "Invalid JSON"} else: req_data = {} message_dict = yield session_handler(msgid, session_id, req_data) if message_dict is not None: transport_metadata = {'session_id': session_id} message_dict.setdefault("message_id", msgid) message_dict.setdefault("session_event", session_event) message_dict.setdefault("content", None) message_dict["transport_name"] = self.transport_name message_dict["transport_type"] = self.config.get('transport_type', 'ussd') message_dict["transport_metadata"] = transport_metadata self.publish_message(**message_dict) def handle_outbound_message(self, message): if message.payload.get('in_reply_to'): should_close = (message['session_event'] == TransportUserMessage.SESSION_CLOSE) response_data = { "shouldClose": should_close, "ussdMenu": message.get('content'), "responseExitCode": 200, "responseMessage": "", } response_id = self.finish_request(message['in_reply_to'], json.dumps(response_data)) if response_id is None: err_msg = ("Infobip transport could not find original request" " when attempting to reply.") log.error(InfobipError(err_msg)) return self.publish_nack(user_message_id=message['message_id'], reason=err_msg) else: return self.publish_ack(message['message_id'], sent_message_id=response_id) else: err_msg = ("Infobip transport cannot process outbound message that" " is not a reply: %s" % (message['message_id'],)) log.error(InfobipError(err_msg)) return self.publish_nack(user_message_id=message['message_id'], reason=err_msg) PK=JGX#vumi/transports/infobip/__init__.py"""Infobip transport.""" from vumi.transports.infobip.infobip import InfobipTransport, InfobipError __all__ = ['InfobipTransport', 'InfobipError'] PK=JG)vumi/transports/infobip/tests/__init__.pyPK=JGY+Y+-vumi/transports/infobip/tests/test_infobip.py"""Test for 
vumi.transport.infobip.infobip.""" import json from twisted.internet.defer import inlineCallbacks, returnValue from vumi.tests.helpers import VumiTestCase from vumi.utils import http_request from vumi.transports.infobip.infobip import InfobipTransport, InfobipError from vumi.message import TransportUserMessage from vumi.tests.utils import LogCatcher from vumi.transports.tests.helpers import TransportHelper class TestInfobipUssdTransport(VumiTestCase): @inlineCallbacks def setUp(self): self.tx_helper = self.add_helper(TransportHelper(InfobipTransport)) self.transport = yield self.tx_helper.get_transport({ 'transport_type': 'ussd', 'web_path': "/session/", 'web_port': 0, }) self.transport_url = self.transport.get_transport_url() yield self.transport.session_manager.redis._purge_all() # just in case DEFAULT_START_DATA = { "msisdn": "385955363443", "imsi": "429011234567890", "shortCode": "*123#1#", "optional": "o=1", "ussdGwId": "11", "language": None, } DEFAULT_SESSION_DATA = { "start": DEFAULT_START_DATA, "response": DEFAULT_START_DATA, "end": {"reason": "ok", "exitCode": 0}, "status": None, } SESSION_HTTP_METHOD = { "end": "PUT", } @inlineCallbacks def make_request(self, session_type, session_id, reply=None, continue_session=True, expect_msg=True, defer_response=False, **kw): url_suffix = "session/%s/%s" % (session_id, session_type) method = self.SESSION_HTTP_METHOD.get(session_type, "POST") request_data = self.DEFAULT_SESSION_DATA[session_type].copy() request_data.update(kw) deferred_req = http_request(self.transport_url + url_suffix, json.dumps(request_data), method=method) if not expect_msg: msg = None else: [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.tx_helper.clear_all_dispatched() if reply is not None: yield self.tx_helper.make_dispatch_reply( msg, reply, continue_session=continue_session) if defer_response: response = deferred_req # We need to make sure we wait for the response so we don't leave # the reactor dirty if the test runner wins the race with the HTTP # client. 
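        # (add_cleanup() waits on whatever the cleanup callable returns,
        # so returning deferred_req from the lambda below makes teardown
        # block until the HTTP response has actually arrived.)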
self.add_cleanup(lambda: deferred_req) else: response = yield deferred_req returnValue((msg, response)) @inlineCallbacks def test_start(self): msg, response = yield self.make_request("start", 1, text="hello there", reply="hello yourself") self.assertEqual(msg['content'], 'hello there') self.assertEqual(msg['session_event'], TransportUserMessage.SESSION_NEW) correct_response = { "shouldClose": False, "responseExitCode": 200, "ussdMenu": "hello yourself", "responseMessage": "", } self.assertEqual(json.loads(response), correct_response) @inlineCallbacks def test_start_twice(self): msg, response = yield self.make_request("start", 1, text="hello there", reply="hello yourself") msg, response = yield self.make_request("start", 1, text="hello again", expect_msg=False) correct_response = { 'responseExitCode': 400, 'responseMessage': "USSD session '1' already started", } self.assertEqual(json.loads(response), correct_response) @inlineCallbacks def test_response_with_close(self): msg, response = yield self.make_request("start", 1, text="Hi", reply="Hi!") msg, response = yield self.make_request("response", 1, text="More?", reply="No thanks.", continue_session=False) self.assertEqual(msg['content'], 'More?') self.assertEqual(msg['session_event'], TransportUserMessage.SESSION_RESUME) correct_response = { "shouldClose": True, "responseExitCode": 200, "ussdMenu": "No thanks.", "responseMessage": "", } self.assertEqual(json.loads(response), correct_response) @inlineCallbacks def test_response_for_invalid_session(self): msg, response = yield self.make_request("response", 1, expect_msg=False) correct_response = { 'responseExitCode': 400, 'responseMessage': "Invalid USSD session '1'", } self.assertEqual(json.loads(response), correct_response) @inlineCallbacks def test_end(self): msg, response = yield self.make_request("start", 1, text='Bye!', reply="Barp") msg, response = yield self.make_request("end", 1) self.assertEqual(msg['content'], None) self.assertEqual(msg['session_event'], TransportUserMessage.SESSION_CLOSE) correct_response = { "responseExitCode": 200, "responseMessage": "", } self.assertEqual(json.loads(response), correct_response) @inlineCallbacks def test_end_for_invalid_session(self): msg, response = yield self.make_request("end", 1, expect_msg=False) correct_response = { 'responseExitCode': 400, 'responseMessage': "Invalid USSD session '1'", } self.assertEqual(json.loads(response), correct_response) @inlineCallbacks def test_status_for_active_session(self): msg, response = yield self.make_request("start", 1, text="Hi", reply="Boop") response = yield http_request( self.transport_url + "session/1/status", "", method="GET") correct_response = { 'responseExitCode': 200, 'responseMessage': '', 'sessionActive': True, } self.assertEqual(json.loads(response), correct_response) @inlineCallbacks def test_status_for_inactive_session(self): response = yield http_request( self.transport_url + "session/1/status", "", method="GET") correct_response = { 'responseExitCode': 200, 'responseMessage': '', 'sessionActive': False, } self.assertEqual(json.loads(response), correct_response) @inlineCallbacks def test_non_json_content(self): response = yield http_request(self.transport_url + "session/1/start", "not json at all", method="POST") correct_response = { 'responseExitCode': 400, 'responseMessage': 'Invalid JSON', } self.assertEqual(json.loads(response), correct_response) @inlineCallbacks def test_start_without_text(self): msg, response = yield self.make_request("start", 1, expect_msg=False) correct_response = { 
'responseExitCode': 400, 'responseMessage': "Missing required JSON field:" " KeyError('text',)", } self.assertEqual(json.loads(response), correct_response) @inlineCallbacks def test_response_without_text(self): msg, response = yield self.make_request("start", 1, text="Hi!", reply="Moo") msg, response = yield self.make_request("response", 1, expect_msg=False) correct_response = { 'responseExitCode': 400, 'responseMessage': "Missing required JSON field:" " KeyError('text',)", } self.assertEqual(json.loads(response), correct_response) @inlineCallbacks def test_start_without_msisdn(self): json_dict = { 'text': 'Oops. No msisdn.', } response = yield http_request(self.transport_url + "session/1/start", json.dumps(json_dict), method='POST') correct_response = { 'responseExitCode': 400, 'responseMessage': "Missing required JSON field:" " KeyError('msisdn',)", } self.assertEqual(json.loads(response), correct_response) @inlineCallbacks def test_outbound_non_reply_logs_error(self): with LogCatcher() as logger: msg = yield self.tx_helper.make_dispatch_outbound("hi") [error] = logger.errors expected_error = ("Infobip transport cannot process outbound message" " that is not a reply: %s" % (msg['message_id'],)) self.assertEqual(str(error['failure'].value), expected_error) [f] = self.flushLoggedErrors(InfobipError) self.assertEqual(f, error['failure']) [nack] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(nack['user_message_id'], msg['message_id']) self.assertEqual(nack['nack_reason'], expected_error) @inlineCallbacks def test_ack(self): msg, response = yield self.make_request( "start", 1, text="Hi!", reply="Moo") [event] = yield self.tx_helper.wait_for_dispatched_events(1) [reply] = yield self.tx_helper.wait_for_dispatched_outbound(1) self.assertEqual(event["event_type"], "ack") self.assertEqual(event["user_message_id"], reply["message_id"]) @inlineCallbacks def test_reply_failure(self): msg, deferred_req = yield self.make_request("start", 1, text="Hi!", defer_response=True) # finish message so reply will fail self.transport.finish_request(msg['message_id'], "Done") with LogCatcher() as logger: reply = yield self.tx_helper.make_dispatch_reply(msg, "Ping") [error] = logger.errors expected_error = ("Infobip transport could not find original request" " when attempting to reply.") self.assertEqual(str(error['failure'].value), expected_error) [f] = self.flushLoggedErrors(InfobipError) self.assertEqual(f, error['failure']) [nack] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(nack['user_message_id'], reply['message_id']) self.assertEqual(nack['nack_reason'], expected_error) PK=JG'vumi/transports/trueafrican/__init__.pyPK=Hc-c-(vumi/transports/trueafrican/transport.py# -*- test-case-name: vumi.transports.trueafrican.tests.test_transport -*- """ USSD Transport for TrueAfrican (Uganda) """ import collections from twisted.internet.defer import Deferred, returnValue, inlineCallbacks, fail from twisted.internet.task import LoopingCall from twisted.internet import reactor from twisted.python.failure import Failure from twisted.web import xmlrpc, server from vumi import log from vumi.errors import VumiError from vumi.message import TransportUserMessage from vumi.transports.base import Transport from vumi.components.session import SessionManager from vumi.config import ConfigText, ConfigInt, ConfigDict class TrueAfricanError(VumiError): """Raised by errors in the TrueAfrican transport.""" class TrueAfricanUssdTransportConfig(Transport.CONFIG_CLASS): """TrueAfrican USSD 
transport configuration.""" port = ConfigInt( "Bind to this port", required=True, static=True) interface = ConfigText( "Bind to this interface", default='', static=True) redis_manager = ConfigDict( "Parameters to connect to Redis with", default={}, static=True) session_timeout = ConfigInt( "Number of seconds before USSD session information stored in" " Redis expires.", default=600, static=True) request_timeout = ConfigInt( "How long should we wait for the remote side generating the response" " for this synchronous operation to come back. Any connection that has" " waited longer than `request_timeout` seconds will manually be" " closed.", default=(4 * 60), static=True) class TrueAfricanUssdTransport(Transport): CONFIG_CLASS = TrueAfricanUssdTransportConfig TRANSPORT_TYPE = 'ussd' SESSION_STATE_MAP = { TransportUserMessage.SESSION_NONE: 'cont', TransportUserMessage.SESSION_RESUME: 'cont', TransportUserMessage.SESSION_CLOSE: 'end', } TIMEOUT_TASK_INTERVAL = 10 @inlineCallbacks def setup_transport(self): super(TrueAfricanUssdTransport, self).setup_transport() config = self.get_static_config() # Session handling key_prefix = "trueafrican:%s" % self.transport_name self.session_manager = yield SessionManager.from_redis_config( config.redis_manager, key_prefix, config.session_timeout ) # XMLRPC Resource self.web_resource = reactor.listenTCP( config.port, server.Site(XmlRpcResource(self)), interface=config.interface ) # request tracking self.clock = self.get_clock() self._requests = {} self.request_timeout = config.request_timeout self.timeout_task = LoopingCall(self.request_timeout_cb) self.timeout_task.clock = self.clock self.timeout_task_d = self.timeout_task.start( self.TIMEOUT_TASK_INTERVAL, now=False ) self.timeout_task_d.addErrback( log.err, "Request timeout handler failed" ) @inlineCallbacks def teardown_transport(self): yield self.web_resource.loseConnection() if self.timeout_task.running: self.timeout_task.stop() yield self.timeout_task_d yield self.session_manager.stop() yield super(TrueAfricanUssdTransport, self).teardown_transport() def get_clock(self): """ For easier stubbing in tests """ return reactor def request_timeout_cb(self): for request_id, request in self._requests.items(): timestamp = request.timestamp if timestamp < self.clock.seconds() - self.request_timeout: self.finish_expired_request(request_id, request) def track_request(self, request_id, http_request, session): d = Deferred() self._requests[request_id] = Request(d, http_request, session, self.clock.seconds()) return d def _send_inbound(self, session_id, session, session_event, content): transport_metadata = {'session_id': session_id} request_id = self.generate_message_id() self.publish_message( message_id=request_id, content=content, to_addr=session['to_addr'], from_addr=session['from_addr'], session_event=session_event, transport_name=self.transport_name, transport_type=self.TRANSPORT_TYPE, transport_metadata=transport_metadata, ) return request_id @inlineCallbacks def handle_session_new(self, request, session_id, msisdn, to_addr): session = yield self.session_manager.create_session( session_id, from_addr=msisdn, to_addr=to_addr ) session_event = TransportUserMessage.SESSION_NEW request_id = self._send_inbound( session_id, session, session_event, None) r = yield self.track_request(request_id, request, session) returnValue(r) @inlineCallbacks def handle_session_resume(self, request, session_id, content): # This is an existing session. 
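        # If the session is gone from Redis (expired or purged), answer
        # with the generic error response below rather than publishing an
        # inbound message for a session we know nothing about.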
session = yield self.session_manager.load_session(session_id) if not session: returnValue(self.response_for_error()) session_event = TransportUserMessage.SESSION_RESUME request_id = self._send_inbound( session_id, session, session_event, content) r = yield self.track_request(request_id, request, session) returnValue(r) @inlineCallbacks def handle_session_end(self, request, session_id): session = yield self.session_manager.load_session(session_id) if not session: returnValue(self.response_for_error()) session_event = TransportUserMessage.SESSION_CLOSE # send a response immediately, and don't (n)ack # since this is not application-initiated self._send_inbound(session_id, session, session_event, None) response = {} returnValue(response) def handle_outbound_message(self, message): in_reply_to = message['in_reply_to'] session_id = message['transport_metadata'].get('session_id') content = message['content'] if not (in_reply_to and session_id and content): return self.publish_nack( user_message_id=message['message_id'], sent_message_id=message['message_id'], reason="Missing in_reply_to, content or session_id fields" ) response = { 'session': session_id, 'type': self.SESSION_STATE_MAP[message['session_event']], 'message': content } log.msg("Sending outbound message %s: %s" % ( message['message_id'], response) ) self.finish_request(in_reply_to, message['message_id'], response) def response_for_error(self): """ Generic response for abnormal server side errors. """ response = { 'message': 'We encountered an error while processing your message', 'type': 'end' } return response def finish_request(self, request_id, message_id, response): request = self._requests.get(request_id) if request is None: # send a nack back, indicating that the original request had # timed out before the outbound message reached us. self.publish_nack( user_message_id=message_id, sent_message_id=message_id, reason='Exceeded request timeout' ) else: del self._requests[request_id] # (n)ack publishing. # # Add a callback and errback, either of which will be invoked # depending on whether the response was written to the client # successfully or not if request.http_request.content.closed: request_done = fail(Failure(TrueAfricanError( "HTTP client closed connection"))) else: request_done = request.http_request.notifyFinish() request_done.addCallbacks( lambda _: self._finish_success_cb(message_id), lambda f: self._finish_failure_cb(f, message_id) ) request.deferred.callback(response) def finish_expired_request(self, request_id, request): """ Called on requests that timed out. """ del self._requests[request_id] log.msg('Timing out on response for %s' % request.session['from_addr']) request.deferred.callback(self.response_for_error()) def _finish_success_cb(self, message_id): self.publish_ack(message_id, message_id) def _finish_failure_cb(self, failure, message_id): self.publish_nack( user_message_id=message_id, sent_message_id=message_id, reason=str(failure) ) # The transport keeps tracks of requests which are still waiting on a response # from an application worker. These requests are stored in a dict and are keyed # by the message_id of the transport message dispatched by the # transport in response to the request. # # For each logical request, we keep track of the following: # # deferred: When the response is available, we fire this deferred # http_request: The Twisted HTTP request associated with the XML RPC request # session: Reference to session data # timestamp: The time the request was received. Used for timeouts. 
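# A rough sketch of that flow, using only names defined in this module
# (a reading aid, not executable as-is):
#
#   d = transport.track_request(request_id, http_request, session)
#   # ... an application worker replies; handle_outbound_message() then
#   # calls transport.finish_request(request_id, message_id, response),
#   # which fires d, and the XML-RPC handler returns the response dict.
#
# If finish_request() is not called within `request_timeout` seconds,
# the timeout task calls finish_expired_request() instead, which fires
# the deferred with response_for_error(); a reply that arrives after
# that is nacked with reason 'Exceeded request timeout'.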
Request = collections.namedtuple('Request', ['deferred', 'http_request', 'session', 'timestamp']) class XmlRpcResource(xmlrpc.XMLRPC): def __init__(self, transport): xmlrpc.XMLRPC.__init__(self, allowNone=True, useDateTime=False) self.putSubHandler("USSD", USSDXmlRpcResource(transport)) class USSDXmlRpcResource(xmlrpc.XMLRPC): def __init__(self, transport): xmlrpc.XMLRPC.__init__(self, allowNone=True, useDateTime=False) self.transport = transport @xmlrpc.withRequest def xmlrpc_INIT(self, request, session_data): """ handler for USSD.INIT """ msisdn = session_data['msisdn'] to_addr = session_data['shortcode'] session_id = session_data['session'] return self.transport.handle_session_new(request, session_id, msisdn, to_addr) @xmlrpc.withRequest def xmlrpc_CONT(self, request, session_data): """ handler for USSD.CONT """ session_id = session_data['session'] content = session_data['response'] return self.transport.handle_session_resume(request, session_id, content) @xmlrpc.withRequest def xmlrpc_END(self, request, session_data): """ handler for USSD.END """ session_id = session_data['session'] return self.transport.handle_session_end(request, session_id) PK=JG-vumi/transports/trueafrican/tests/__init__.pyPK=HX.r9+9+3vumi/transports/trueafrican/tests/test_transport.pyfrom twisted.internet.defer import inlineCallbacks from twisted.internet.task import Clock from twisted.internet.error import ConnectionLost from twisted.web.xmlrpc import Proxy from vumi.message import TransportUserMessage from vumi.tests.helpers import VumiTestCase from vumi.tests.utils import LogCatcher from vumi.transports.tests.helpers import TransportHelper from vumi.transports.trueafrican.transport import TrueAfricanUssdTransport class TestTrueAfricanUssdTransport(VumiTestCase): SESSION_INIT_BODY = { 'session': '1', 'msisdn': '+27724385170', 'shortcode': '*23#' } @inlineCallbacks def setUp(self): self.tx_helper = self.add_helper( TransportHelper(TrueAfricanUssdTransport)) self.clock = Clock() self.patch(TrueAfricanUssdTransport, 'get_clock', lambda _: self.clock) self.transport = yield self.tx_helper.get_transport({ 'interface': '127.0.0.1', 'port': 0, 'request_timeout': 10, }) self.service_url = self.get_service_url(self.transport) def get_service_url(self, transport): """ Get the URL for the HTTP resource. Requires the worker to be started. 
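        These tests bind with ``'port': 0``, so the OS picks an arbitrary
        free port and the URL (e.g. ``http://127.0.0.1:49152/``; port
        illustrative) can only be computed once the listener exists.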
""" addr = transport.web_resource.getHost() return "http://%s:%s/" % (addr.host, addr.port) def web_client(self): return Proxy(self.service_url) @inlineCallbacks def test_session_new(self): client = self.web_client() resp_d = client.callRemote('USSD.INIT', self.SESSION_INIT_BODY) [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.tx_helper.make_dispatch_reply(msg, "Oh Hai!") # verify the transport -> application message self.assertEqual(msg['transport_name'], self.tx_helper.transport_name) self.assertEqual(msg['transport_type'], "ussd") self.assertEqual(msg['session_event'], TransportUserMessage.SESSION_NEW) self.assertEqual(msg['from_addr'], '+27724385170') self.assertEqual(msg['to_addr'], '*23#') self.assertEqual(msg['content'], None) resp = yield resp_d self.assertEqual( resp, { 'message': 'Oh Hai!', 'session': '1', 'type': 'cont' } ) @inlineCallbacks def test_session_resume(self): client = self.web_client() # initiate session resp_d = client.callRemote('USSD.INIT', self.SESSION_INIT_BODY) [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.tx_helper.make_dispatch_reply(msg, "pong") yield resp_d yield self.tx_helper.clear_dispatched_inbound() # resume session resp_d = client.callRemote( 'USSD.CONT', {'session': '1', 'response': 'pong'} ) [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.tx_helper.make_dispatch_reply(msg, "ping") # verify the dispatched inbound message self.assertEqual(msg['transport_name'], self.tx_helper.transport_name) self.assertEqual(msg['transport_type'], "ussd") self.assertEqual(msg['session_event'], TransportUserMessage.SESSION_RESUME) self.assertEqual(msg['from_addr'], '+27724385170') self.assertEqual(msg['to_addr'], '*23#') self.assertEqual(msg['content'], 'pong') resp = yield resp_d self.assertEqual( resp, { 'message': 'ping', 'session': '1', 'type': 'cont' } ) @inlineCallbacks def test_session_end_user_initiated(self): client = self.web_client() # initiate session resp_d = client.callRemote('USSD.INIT', self.SESSION_INIT_BODY) [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.tx_helper.make_dispatch_reply(msg, "ping") yield resp_d yield self.tx_helper.clear_dispatched_inbound() # user initiated session termination resp_d = client.callRemote( 'USSD.END', {'session': '1'} ) [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(msg['transport_name'], self.tx_helper.transport_name) self.assertEqual(msg['transport_type'], "ussd") self.assertEqual(msg['session_event'], TransportUserMessage.SESSION_CLOSE) self.assertEqual(msg['from_addr'], '+27724385170') self.assertEqual(msg['to_addr'], '*23#') self.assertEqual(msg['content'], None) resp = yield resp_d self.assertEqual(resp, {}) @inlineCallbacks def test_session_end_application_initiated(self): client = self.web_client() # initiate session resp_d = client.callRemote('USSD.INIT', self.SESSION_INIT_BODY) [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.tx_helper.make_dispatch_reply(msg, "ping") yield resp_d yield self.tx_helper.clear_dispatched_inbound() # end session resp_d = client.callRemote( 'USSD.CONT', {'session': '1', 'response': 'o rly?'} ) [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.tx_helper.make_dispatch_reply( msg, "kthxbye", continue_session=False) resp = yield resp_d self.assertEqual( resp, { 'message': 'kthxbye', 'session': '1', 'type': 'end' } ) @inlineCallbacks def test_ack_for_outbound_message(self): client = self.web_client() # initiate session resp_d = client.callRemote('USSD.INIT', 
self.SESSION_INIT_BODY) # send response [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) rep = yield self.tx_helper.make_dispatch_reply(msg, "ping") yield resp_d [ack] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(ack['event_type'], 'ack') self.assertEqual(ack['user_message_id'], rep['message_id']) self.assertEqual(ack['sent_message_id'], rep['message_id']) @inlineCallbacks def test_nack_for_outbound_message(self): client = self.web_client() # initiate session resp_d = client.callRemote('USSD.INIT', self.SESSION_INIT_BODY) [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) # cancel the request and mute the resulting error. request = self.transport._requests[msg['message_id']] request.http_request.connectionLost(ConnectionLost()) resp_d.cancel() resp_d.addErrback(lambda f: None) # send response rep = yield self.tx_helper.make_dispatch_reply(msg, "ping") [nack] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(nack['event_type'], 'nack') self.assertEqual(nack['user_message_id'], rep['message_id']) self.assertEqual(nack['sent_message_id'], rep['message_id']) self.assertTrue('HTTP client closed connection' in nack['nack_reason']) @inlineCallbacks def test_nack_for_request_timeout(self): client = self.web_client() # initiate session resp_d = client.callRemote('USSD.INIT', self.SESSION_INIT_BODY) [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.clock.advance(10.1) # .1 second after timeout rep = yield self.tx_helper.make_dispatch_reply(msg, "ping") yield resp_d [nack] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(nack['event_type'], 'nack') self.assertEqual(nack['user_message_id'], rep['message_id']) self.assertEqual(nack['sent_message_id'], rep['message_id']) self.assertEqual(nack['nack_reason'], 'Exceeded request timeout') @inlineCallbacks def test_nack_for_invalid_outbound_message(self): msg = yield self.tx_helper.make_dispatch_outbound("outbound") [nack] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(nack['user_message_id'], msg['message_id']) self.assertEqual(nack['sent_message_id'], msg['message_id']) self.assertEqual(nack['nack_reason'], 'Missing in_reply_to, content or session_id fields') @inlineCallbacks def test_timeout(self): client = self.web_client() # initiate session resp_d = client.callRemote('USSD.INIT', self.SESSION_INIT_BODY) [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) with LogCatcher(message='Timing out') as lc: self.assertTrue(msg['message_id'] in self.transport._requests) self.clock.advance(10.1) # .1 second after timeout self.assertFalse(msg['message_id'] in self.transport._requests) [warning] = lc.messages() self.assertEqual(warning, 'Timing out on response for +27724385170') resp = yield resp_d self.assertEqual( resp, { 'message': ('We encountered an error while processing' ' your message'), 'type': 'end' } ) @inlineCallbacks def test_request_tracking(self): """ Verify that the transport cleans up after finishing a request """ client = self.web_client() resp_d = client.callRemote('USSD.INIT', self.SESSION_INIT_BODY) [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.tx_helper.make_dispatch_reply(msg, "pong") self.assertTrue(msg['message_id'] in self.transport._requests) yield resp_d self.assertFalse(msg['message_id'] in self.transport._requests) @inlineCallbacks def test_missing_session(self): """ Verify that the transport handles missing session data in a graceful manner """ client = self.web_client() resp_d = 
client.callRemote('USSD.INIT', self.SESSION_INIT_BODY) [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.tx_helper.make_dispatch_reply(msg, "pong") yield resp_d yield self.tx_helper.clear_dispatched_inbound() # simulate Redis falling over yield self.transport.session_manager.redis._purge_all() # resume resp_d = client.callRemote( 'USSD.CONT', {'session': '1', 'response': 'o rly?'} ) resp = yield resp_d self.assertEqual( resp, { 'message': ('We encountered an error while processing' ' your message'), 'type': 'end' } ) PK=JGHrr(vumi/transports/mediaedgegsm/__init__.pyfrom vumi.transports.mediaedgegsm.mediaedgegsm import MediaEdgeGSMTransport __all__ = ['MediaEdgeGSMTransport'] PKqGRڵBB,vumi/transports/mediaedgegsm/mediaedgegsm.py# -*- test-case-name: vumi.transports.mediaedgegsm.tests.test_mediaedgegsm -*- import json from urllib import urlencode from twisted.python import log from twisted.web import http from twisted.internet.defer import inlineCallbacks from vumi.transports.httprpc import HttpRpcTransport from vumi.utils import http_request_full, get_operator_name class MediaEdgeGSMTransport(HttpRpcTransport): """ HTTP transport for MediaEdgeGSM in Ghana. :param str web_path: The HTTP path to listen on. :param int web_port: The HTTP port :param str transport_name: The name this transport instance will use to create its queues :param str username: MediaEdgeGSM account username. :param str password: MediaEdgeGSM account password. :param str outbound_url: The URL to hit for outbound messages that aren't replies. :param str outbound_username: The username for outbound non-reply messages. :param str outbound_password: The username for outbound non-reply messages. :param dict operator_mappings: A nested dictionary mapping MSISDN prefixes to operator names """ transport_type = 'sms' content_type = 'text/plain; charset=utf-8' agent_factory = None # For swapping out the Agent we use in tests. 
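    # A minimal example config for this transport (YAML; all values are
    # hypothetical, keys taken from the parameters documented above):
    #
    #   transport_name: mediaedgegsm
    #   web_path: /api/v1/mediaedgegsm/
    #   web_port: 8080
    #   username: user
    #   password: pass
    #   outbound_url: http://mediaedgegsm.example.com/
    #   outbound_username: out_user
    #   outbound_password: out_pass
    #   operator_mappings:
    #     '233':
    #       '23324': MTN
    #       '23320': VODAFONE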
ENCODING = 'utf-8' EXPECTED_FIELDS = set(['USN', 'PWD', 'PhoneNumber', 'ServiceNumber', 'Operator', 'SMSBODY']) def setup_transport(self): self._username = self.config.get('username') self._password = self.config.get('password') self._outbound_url = self.config.get('outbound_url') self._outbound_url_username = self.config.get('outbound_username', '') self._outbound_url_password = self.config.get('outbound_password', '') self._operator_mappings = self.config.get('operator_mappings', {}) return super(MediaEdgeGSMTransport, self).setup_transport() @inlineCallbacks def handle_outbound_message(self, message): if message.payload.get('in_reply_to') and 'content' in message.payload: super(MediaEdgeGSMTransport, self).handle_outbound_message(message) else: msisdn = message['to_addr'].lstrip('+') params = { "USN": self._outbound_url_username, "PWD": self._outbound_url_password, "SmsID": message['message_id'], "PhoneNumber": msisdn, "Operator": get_operator_name(msisdn, self._operator_mappings), "SmsBody": message['content'], } url = '%s?%s' % (self._outbound_url, urlencode(params)) response = yield http_request_full( url, '', method='GET', agent_class=self.agent_factory) log.msg("Response: (%s) %r" % ( response.code, response.delivered_body)) if response.code == http.OK: yield self.publish_ack( user_message_id=message['message_id'], sent_message_id=message['message_id']) else: yield self.publish_nack( user_message_id=message['message_id'], sent_message_id=message['message_id'], reason='Unexpected response code: %s' % (response.code,)) @inlineCallbacks def handle_raw_inbound_message(self, message_id, request): values, errors = self.get_field_values(request, self.EXPECTED_FIELDS) if self._username and (values.get('USN') != self._username): errors['credentials'] = 'invalid' if self._password and (values.get('PWD') != self._password): errors['credentials'] = 'invalid' if errors: log.msg('Unhappy incoming message: %s' % (errors,)) yield self.finish_request(message_id, json.dumps(errors), code=400) return log.msg(('MediaEdgeGSMTransport sending from %(PhoneNumber)s to ' '%(ServiceNumber)s on %(Operator)s message ' '"%(SMSBODY)s"') % values) yield self.publish_message( message_id=message_id, content=values['SMSBODY'], to_addr=values['ServiceNumber'], from_addr=values['PhoneNumber'], provider=values['Operator'], transport_type=self.transport_type, ) PK=JG.vumi/transports/mediaedgegsm/tests/__init__.pyPKqG"7vumi/transports/mediaedgegsm/tests/test_mediaedgegsm.py# -*- encoding: utf-8 -*- import json from urllib import urlencode from twisted.internet.defer import inlineCallbacks, DeferredQueue from twisted.web import http from vumi.utils import http_request, http_request_full from vumi.tests.fake_connection import FakeHttpServer from vumi.tests.helpers import VumiTestCase from vumi.transports.mediaedgegsm import MediaEdgeGSMTransport from vumi.transports.tests.helpers import TransportHelper class TestMediaEdgeGSMTransport(VumiTestCase): @inlineCallbacks def setUp(self): self.mediaedgegsm_calls = DeferredQueue() self.fake_http = FakeHttpServer(self.handle_request) self.base_url = "http://mediaedgegsm.example.com/" self.config = { 'web_path': "foo", 'web_port': 0, 'username': 'user', 'password': 'pass', 'outbound_url': self.base_url, 'outbound_username': 'username', 'outbound_password': 'password', 'operator_mappings': { '417': { '417912': 'VODA', '417913': 'TIGO', '417914': 'UNKNOWN', } } } self.tx_helper = self.add_helper( TransportHelper(MediaEdgeGSMTransport)) self.transport = yield 
self.tx_helper.get_transport(self.config) self.transport.agent_factory = self.fake_http.get_agent self.transport_url = self.transport.get_transport_url() self.mediaedgegsm_response = '' self.mediaedgegsm_response_code = http.OK def handle_request(self, request): self.mediaedgegsm_calls.put(request) request.setResponseCode(self.mediaedgegsm_response_code) return self.mediaedgegsm_response def mkurl(self, content, from_addr="2371234567", **kw): params = { 'ServiceNumber': '12345', 'PhoneNumber': from_addr, 'SMSBODY': content, 'USN': 'user', 'PWD': 'pass', 'Operator': 'foo', } params.update(kw) return self.mkurl_raw(**params) def mkurl_raw(self, **params): return '%s%s?%s' % ( self.transport_url, self.config['web_path'], urlencode(params) ) @inlineCallbacks def test_health(self): result = yield http_request( self.transport_url + "health", "", method='GET') self.assertEqual(json.loads(result), {'pending_requests': 0}) @inlineCallbacks def test_inbound(self): url = self.mkurl('hello') deferred = http_request(url, '', method='GET') [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(msg['transport_name'], self.tx_helper.transport_name) self.assertEqual(msg['to_addr'], "12345") self.assertEqual(msg['from_addr'], "2371234567") self.assertEqual(msg['content'], "hello") yield self.tx_helper.make_dispatch_reply(msg, 'message received') response = yield deferred self.assertEqual(response, 'message received') @inlineCallbacks def test_outbound(self): msisdns = ['+41791200000', '+41791300000', '+41791400000'] operators = ['VODA', 'TIGO', 'UNKNOWN'] sent_messages = [] for msisdn in msisdns: msg = yield self.tx_helper.make_dispatch_outbound( "outbound", to_addr=msisdn) sent_messages.append(msg) req1 = yield self.mediaedgegsm_calls.get() req2 = yield self.mediaedgegsm_calls.get() req3 = yield self.mediaedgegsm_calls.get() requests = [req1, req2, req3] for req in requests: self.assertEqual(req.path, self.base_url) self.assertEqual(req.method, 'GET') collections = zip(msisdns, operators, sent_messages, requests) for msisdn, operator, msg, req in collections: self.assertEqual({ 'USN': ['username'], 'PWD': ['password'], 'SmsID': [msg['message_id']], 'PhoneNumber': [msisdn.lstrip('+')], 'Operator': [operator], 'SmsBody': [msg['content']], }, req.args) @inlineCallbacks def test_nack(self): self.mediaedgegsm_response_code = http.NOT_FOUND self.mediaedgegsm_response = 'Not Found' msg = yield self.tx_helper.make_dispatch_outbound( "outbound", to_addr='+41791200000') yield self.mediaedgegsm_calls.get() [nack] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(nack['user_message_id'], msg['message_id']) self.assertEqual(nack['sent_message_id'], msg['message_id']) self.assertEqual(nack['nack_reason'], 'Unexpected response code: 404') @inlineCallbacks def test_bad_parameter(self): url = self.mkurl('hello', foo='bar') response = yield http_request_full(url, '', method='GET') self.assertEqual(400, response.code) self.assertEqual(json.loads(response.delivered_body), {'unexpected_parameter': ['foo']}) @inlineCallbacks def test_missing_parameters(self): url = self.mkurl_raw(ServiceNumber='12345', SMSBODY='hello', USN='user', PWD='pass', Operator='foo') response = yield http_request_full(url, '', method='GET') self.assertEqual(400, response.code) self.assertEqual(json.loads(response.delivered_body), {'missing_parameter': ['PhoneNumber']}) @inlineCallbacks def test_invalid_credentials(self): url = self.mkurl_raw( ServiceNumber='12345', SMSBODY='hello', USN='something', PWD='wrong', 
Operator='foo', PhoneNumber='1234') response = yield http_request_full(url, '', method='GET') self.assertEqual(400, response.code) self.assertEqual(json.loads(response.delivered_body), {'credentials': 'invalid'}) @inlineCallbacks def test_handle_non_ascii_input(self): url = self.mkurl(u"öæł".encode("utf-8")) deferred = http_request_full(url, '', method='GET') [msg] = yield self.tx_helper.wait_for_dispatched_inbound(1) self.assertEqual(msg['transport_name'], self.tx_helper.transport_name) self.assertEqual(msg['to_addr'], "12345") self.assertEqual(msg['from_addr'], "2371234567") self.assertEqual(msg['content'], u"öæł") yield self.tx_helper.make_dispatch_reply(msg, u'Zoë says hi') response = yield deferred self.assertEqual(response.headers.getRawHeaders('Content-Type'), ['text/plain; charset=utf-8']) self.assertEqual(response.delivered_body, u'Zoë says hi'.encode('utf-8')) PK=JG5Hvumi/transports/opera/utils.pyfrom collections import namedtuple import xml.etree.ElementTree as ET OPERA_TIMESTAMP_FORMAT = "%Y%m%dT%H:%M:%S" def parse_receipts_xml(receipt_xml_data): tree = ET.fromstring(receipt_xml_data) return map(receipt_to_namedtuple, tree.findall('receipt')) def receipt_element_to_dict(element): """ Turn an ElementTree element '<data><el>1</el></data>' into {'el': '1'}. Not recursive! >>> data = ET.fromstring("<data><el>1</el></data>") >>> receipt_element_to_dict(data) {'el': '1'} >>> """ return dict([(child.tag, child.text) for child in element.getchildren()]) def receipt_to_namedtuple(element): """ Turn an ElementTree element into an object with named params. Not recursive! >>> data = ET.fromstring("<data><el>1</el></data>") >>> receipt_to_namedtuple(data) data(el='1') """ d = receipt_element_to_dict(element) klass = namedtuple(element.tag, d.keys()) return klass._make(d.values()) def parse_post_event_xml(post_event_xml_data): tree = ET.fromstring(post_event_xml_data) fields = tree.findall('field') return dict([(field.attrib['name'], field.text) for field in fields]) PK=JG if any(ord(c) > 127 for c in content): content = xmlrpc.Binary(content.encode('utf-8')) xmlrpc_payload['Numbers'] = message['to_addr'] xmlrpc_payload['SMSText'] = content xmlrpc_payload['Delivery'] = delivery xmlrpc_payload['Expiry'] = expiry xmlrpc_payload['Priority'] = priority xmlrpc_payload['Receipt'] = receipt xmlrpc_payload['MaxSegments'] = self.max_segments log.msg("Sending SMS via Opera: %s" % xmlrpc_payload) d = self.proxy.callRemote('EAPIGateway.SendSMS', xmlrpc_payload) d.addErrback(self.handle_outbound_message_failure, message) proxy_response = yield d log.msg("Proxy response: %s" % proxy_response) transport_message_id = proxy_response['Identifier'] yield self.set_message_id_for_identifier( transport_message_id, message['message_id']) yield self.publish_ack( user_message_id=message['message_id'], sent_message_id=transport_message_id) @inlineCallbacks def handle_outbound_message_failure(self, failure, message): """ Decide what to do on certain failure cases.
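        The rules applied below: an ``xmlrpc.Fault`` (the XML-RPC service
        misbehaving) is re-raised as a :class:`TemporaryFailure` so
        delivery can be retried later; a ``ValueError`` (an HTTP response
        other than 200) publishes a nack and raises
        :class:`PermanentFailure`; anything else publishes a nack and
        re-raises the original failure.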
""" if failure.check(xmlrpc.Fault): # If the XML-RPC service isn't behaving properly raise TemporaryFailure(failure) elif failure.check(ValueError): # If the HTTP protocol returns something other than 200 yield self.publish_nack(message['message_id'], str(failure.value)) raise PermanentFailure(failure) else: # Unspecified yield self.publish_nack(message['message_id'], str(failure.value)) raise failure @inlineCallbacks def teardown_transport(self): log.msg("Stopping the OperaOutboundTransport: %s" % self.transport_name) yield self.web_resource.loseConnection() yield self.session_manager.stop() PK=JG;ؤKK)vumi/transports/opera/tests/test_opera.py# -*- coding: utf-8 -*- from datetime import datetime, timedelta from urlparse import parse_qs from twisted.internet import defer from twisted.internet.defer import inlineCallbacks, maybeDeferred from twisted.web import xmlrpc from vumi.utils import http_request, http_request_full from vumi.transports.failures import PermanentFailure, TemporaryFailure from vumi.transports.opera import OperaTransport from vumi.transports.tests.helpers import TransportHelper from vumi.tests.helpers import VumiTestCase class FakeXMLRPCService(object): def __init__(self, callback): self.callback = callback def callRemote(self, *args, **kwargs): return maybeDeferred(self.callback, *args, **kwargs) class TestOperaTransport(VumiTestCase): @inlineCallbacks def setUp(self): self.tx_helper = self.add_helper( TransportHelper(OperaTransport, mobile_addr='27761234567')) self.transport = yield self.tx_helper.get_transport({ 'url': 'http://testing.domain', 'channel': 'channel', 'service': 'service', 'password': 'password', 'web_receipt_path': '/receipt.xml', 'web_receive_path': '/receive.xml', 'web_port': 0, }) @inlineCallbacks def test_receipt_processing(self): """it should be able to process an incoming XML receipt via HTTP""" identifier = '001efc31' message_id = '123456' # prime redis to match the incoming identifier to an # internal message id yield self.transport.set_message_id_for_identifier( identifier, message_id) xml_data = """ 26567958 %s +27123456789 D 20080831T15:59:24 NO """.strip() % identifier yield http_request( self.transport.get_transport_url('receipt.xml'), xml_data) self.assertEqual([], self.tx_helper.get_dispatched_failures()) self.assertEqual([], self.tx_helper.get_dispatched_inbound()) [event] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(event['delivery_status'], 'delivered') self.assertEqual(event['message_type'], 'event') self.assertEqual(event['event_type'], 'delivery_report') self.assertEqual(event['user_message_id'], message_id) @inlineCallbacks def test_incoming_sms_processing_urlencoded(self): """ it should be able to process in incoming sms as XML delivered via HTTP """ xml_data = ( 'XmlMsg=%3C%3Fxml%20version%3D%221.0%22%3F%3E%0A%3C!DOCTYPE%20bspo' 'stevent%3E%0A%3Cbspostevent%3E%0A%20%20%3Cfield%20name%3D%22MORef' 'erence%22%20type%20%3D%20%22string%22%3E478535078%3C/field%3E%0A%' '20%20%3Cfield%20name%3D%22RemoteNetwork%22%20type%20%3D%20%22stri' 'ng%22%3Emtn-za%3C/field%3E%0A%20%20%3Cfield%20name%3D%22BSDate-to' 'morrow%22%20type%20%3D%20%22string%22%3E20120317%3C/field%3E%0A%2' '0%20%3Cfield%20name%3D%22BSDate-today%22%20type%20%3D%20%22string' '%22%3E20120316%3C/field%3E%0A%20%20%3Cfield%20name%3D%22ReceiveDa' 'te%22%20type%20%3D%20%22date%22%3E2012-03-16%2011:50:04%20%2B0000' '%3C/field%3E%0A%20%20%3Cfield%20name%3D%22Local%22%20type%20%3D%2' '0%22string%22%3E*32323%3C/field%3E%0A%20%20%3Cfield%20name%3D%22C' 
'lientID%22%20type%20%3D%20%22string%22%3E4%3C/field%3E%0A%20%20%3' 'Cfield%20name%3D%22ChannelID%22%20type%20%3D%20%22string%22%3E176' '%3C/field%3E%0A%20%20%3Cfield%20name%3D%22MessageID%22%20type%20%' '3D%20%22string%22%3E1487577162%3C/field%3E%0A%20%20%3Cfield%20nam' 'e%3D%22Prefix%22%20type%20%3D%20%22string%22%3E%3C/field%3E%0A%20' '%20%3Cfield%20name%3D%22ClientName%22%20type%20%3D%20%22string%22' '%3EPraekelt%3C/field%3E%0A%20%20%3Cfield%20name%3D%22MobileDevice' '%22%20type%20%3D%20%22string%22%3E%3C/field%3E%0A%20%20%3Cfield%2' '0name%3D%22BSDate-yesterday%22%20type%20%3D%20%22string%22%3E2012' '0315%3C/field%3E%0A%20%20%3Cfield%20name%3D%22Remote%22%20type%20' '%3D%20%22string%22%3E%2B27831234567%3C/field%3E%0A%20%20%3Cfield%' '20name%3D%22MobileNetwork%22%20type%20%3D%20%22string%22%3Emtn-za' '%3C/field%3E%0A%20%20%3Cfield%20name%3D%22State%22%20type%20%3D%2' '0%22string%22%3E9%3C/field%3E%0A%20%20%3Cfield%20name%3D%22Mobile' 'Number%22%20type%20%3D%20%22string%22%3E%2B27831234567%3C/field%3' 'E%0A%20%20%3Cfield%20name%3D%22Text%22%20type%20%3D%20%22string%2' '2%3EHerb01%20spice01%3C/field%3E%0A%20%20%3Cfield%20name%3D%22Ser' 'viceID%22%20type%20%3D%20%22string%22%3E30756%3C/field%3E%0A%20%2' '0%3Cfield%20name%3D%22RegType%22%20type%20%3D%20%22string%22%3ESM' 'S%3C/field%3E%0A%20%20%3Cfield%20name%3D%22NewSubscriber%22%20typ' 'e%20%3D%20%22string%22%3ENO%3C/field%3E%0A%20%20%3Cfield%20name%3' 'D%22Subscriber%22%20type%20%3D%20%22string%22%3E%2B27831234567%3C' '/field%3E%0A%20%20%3Cfield%20name%3D%22id%22%20type%20%3D%20%22st' 'ring%22%3E3361920%3C/field%3E%0A%20%20%3Cfield%20name%3D%22Parsed' '%22%20type%20%3D%20%22string%22%3E%3C/field%3E%0A%20%20%3Cfield%2' '0name%3D%22ServiceName%22%20type%20%3D%20%22string%22%3ERobertson' '%26%238217%3Bs%20Herb%20%26amp%3B%20Spices%20Promo%3C/field%3E%0A' '%20%20%3Cfield%20name%3D%22BSDate-thisweek%22%20type%20%3D%20%22s' 'tring%22%3E20120312%3C/field%3E%0A%20%20%3Cfield%20name%3D%22Serv' 'iceEndDate%22%20type%20%3D%20%22string%22%3E2012-12-31%2003:06:00' '%20%2B0200%3C/field%3E%0A%20%20%3Cfield%20name%3D%22Now%22%20type' '%20%3D%20%22date%22%3E2012-03-16%2011:50:05%20%2B0000%3C/field%3E' '%0A%3C/bspostevent%3E%0A') resp = yield http_request( self.transport.get_transport_url('receive.xml'), xml_data) self.assertEqual([], self.tx_helper.get_dispatched_failures()) self.assertEqual([], self.tx_helper.get_dispatched_events()) [msg] = self.tx_helper.get_dispatched_inbound() self.assertEqual(msg['message_id'], '1487577162') self.assertEqual(msg['to_addr'], '32323') self.assertEqual(msg['from_addr'], '+27831234567') self.assertEqual(msg['content'], 'Herb01 spice01') self.assertEqual(msg['transport_metadata'], { 'provider': 'mtn-za' }) self.assertEqual(resp, parse_qs(xml_data)['XmlMsg'][0]) @inlineCallbacks def test_incoming_sms_processing(self): """ it should be able to process in incoming sms as XML delivered via HTTP """ xml_data = """ 282341913 NO mtn-za 20100605 20100604 2010-06-04 15:51:25 +0000 *32323 4 111 373736741 Praekelt 20100603 +27831234567 5 mtn-za +27831234567 Hello World 20222 1 NO +27831234567 Prktl Vumi 20100531 2010-06-30 07:47:00 +0200 2010-06-04 15:51:27 +0000 """.strip() resp = yield http_request( self.transport.get_transport_url('receive.xml'), xml_data) self.assertEqual([], self.tx_helper.get_dispatched_failures()) self.assertEqual([], self.tx_helper.get_dispatched_events()) [msg] = self.tx_helper.get_dispatched_inbound() self.assertEqual(msg['message_id'], '373736741') self.assertEqual(msg['to_addr'], '32323') 
self.assertEqual(msg['from_addr'], '+27831234567') self.assertEqual(msg['content'], 'Hello World') self.assertEqual(msg['transport_metadata'], { 'provider': 'mtn-za' }) self.assertEqual(resp, xml_data) @inlineCallbacks def test_incoming_sms_no_data(self): resp = yield http_request_full( self.transport.get_transport_url('receive.xml'), None) self.assertEqual([], self.tx_helper.get_dispatched_failures()) self.assertEqual([], self.tx_helper.get_dispatched_events()) self.assertEqual([], self.tx_helper.get_dispatched_inbound()) self.assertEqual(resp.code, 400) self.assertEqual(resp.delivered_body, "XmlMsg missing.") @inlineCallbacks def test_incoming_sms_partial_data(self): xml_data = """ 282341913 NO mtn-za 20100605 20100604 2010-06-04 15:51:25 +0000 4 111 373736741 Praekelt 20100603 +27831234567 5 mtn-za +27831234567 Hello World 20222 1 NO +27831234567 Prktl Vumi 20100531 2010-06-30 07:47:00 +0200 2010-06-04 15:51:27 +0000 """.strip() resp = yield http_request_full( self.transport.get_transport_url('receive.xml'), xml_data) self.assertEqual([], self.tx_helper.get_dispatched_failures()) self.assertEqual([], self.tx_helper.get_dispatched_events()) self.assertEqual([], self.tx_helper.get_dispatched_inbound()) self.assertEqual(resp.code, 400) self.assertEqual(resp.delivered_body, "Missing field: Local") @inlineCallbacks def test_outbound_ok(self): """ Outbound message we send should hit the XML-RPC service with the correct parameters """ def _cb(method_called, xmlrpc_payload): self.assertEqual(method_called, 'EAPIGateway.SendSMS') self.assertEqual(xmlrpc_payload['Priority'], 'standard') self.assertEqual(xmlrpc_payload['SMSText'], 'hello world') self.assertEqual(xmlrpc_payload['Service'], 'service') self.assertEqual(xmlrpc_payload['Receipt'], 'Y') self.assertEqual(xmlrpc_payload['MaxSegments'], 9) self.assertEqual(xmlrpc_payload['Numbers'], '27761234567') self.assertEqual(xmlrpc_payload['Password'], 'password') self.assertEqual(xmlrpc_payload['Channel'], 'channel') now = datetime.utcnow() tomorrow = now + timedelta(days=1) self.assertEqual(xmlrpc_payload['Expiry'].hour, tomorrow.hour) self.assertEqual(xmlrpc_payload['Expiry'].minute, tomorrow.minute) self.assertEqual(xmlrpc_payload['Expiry'].date(), tomorrow.date()) self.assertEqual(xmlrpc_payload['Delivery'].hour, now.hour) self.assertEqual(xmlrpc_payload['Delivery'].minute, now.minute) self.assertEqual(xmlrpc_payload['Delivery'].date(), now.date()) return { 'Identifier': 'abc123' } self.transport.proxy = FakeXMLRPCService(_cb) msg = yield self.tx_helper.make_dispatch_outbound('hello world') self.assertEqual(self.tx_helper.get_dispatched_failures(), []) self.assertEqual(self.tx_helper.get_dispatched_inbound(), []) [event_msg] = self.tx_helper.get_dispatched_events() self.assertEqual(event_msg['message_type'], 'event') self.assertEqual(event_msg['event_type'], 'ack') self.assertEqual(event_msg['sent_message_id'], 'abc123') # test that we've properly linked the identifier to our # internal id of the given message self.assertEqual( (yield self.transport.get_message_id_for_identifier('abc123')), msg['message_id']) @inlineCallbacks def test_outbound_ok_with_metadata(self): """ Outbound message we send should hit the XML-RPC service with the correct parameters """ fixed_date = datetime(2011, 1, 1, 0, 0, 0) def _cb(method_called, xmlrpc_payload): self.assertEqual(xmlrpc_payload['Delivery'], fixed_date) self.assertEqual(xmlrpc_payload['Expiry'], fixed_date + timedelta(hours=1)) self.assertEqual(xmlrpc_payload['Priority'], 'high') 
self.assertEqual(xmlrpc_payload['Receipt'], 'N') return { 'Identifier': 'abc123' } self.transport.proxy = FakeXMLRPCService(_cb) yield self.tx_helper.make_dispatch_outbound("hi", transport_metadata={ 'deliver_at': fixed_date, 'expire_at': fixed_date + timedelta(hours=1), 'priority': 'high', 'receipt': 'N', }) @inlineCallbacks def test_outbound_temporary_failure(self): """ if for some reason the delivery of the SMS to opera crashes it shouldn't ACK the message over AMQ but leave it for a retry later """ def _cb(*args, **kwargs): """ Callback handler that raises an error when called """ return defer.fail(xmlrpc.Fault(503, 'oh noes!')) # monkey patch so we can mock errors happening remotely self.transport.proxy = FakeXMLRPCService(_cb) # send a message to the transport which'll hit the FakeXMLRPCService # and as a result raise an error yield self.tx_helper.make_dispatch_outbound("hello world") [twisted_failure] = self.flushLoggedErrors(TemporaryFailure) logged_failure = twisted_failure.value self.assertEqual(logged_failure.failure_code, 'temporary') self.assertEqual(self.tx_helper.get_dispatched_events(), []) self.assertEqual(self.tx_helper.get_dispatched_inbound(), []) [failure] = self.tx_helper.get_dispatched_failures() self.assertEqual(failure['failure_code'], 'temporary') original_msg = failure['message'] self.assertEqual(original_msg['to_addr'], '27761234567') self.assertEqual(original_msg['from_addr'], '9292') self.assertEqual(original_msg['content'], 'hello world') @inlineCallbacks def test_outbound_permanent_failure(self): """ if for some reason the Opera XML-RPC service gives us something other than a 200 response it should consider it a permanent failure """ def _cb(*args, **kwargs): """ Callback handler that raises an error when called """ return defer.fail(ValueError(402, 'Payment Required')) # monkey patch so we can mock errors happening remotely self.transport.proxy = FakeXMLRPCService(_cb) # send a message to the transport which'll hit the FakeXMLRPCService # and as a result raise an error msg = yield self.tx_helper.make_dispatch_outbound("hi") [twisted_failure] = self.flushLoggedErrors(PermanentFailure) logged_failure = twisted_failure.value self.assertEqual(logged_failure.failure_code, 'permanent') [failure] = self.tx_helper.get_dispatched_failures() [nack] = yield self.tx_helper.wait_for_dispatched_events(1) self.assertEqual(failure['failure_code'], 'permanent') self.assertEqual(nack['user_message_id'], msg['message_id']) @inlineCallbacks def test_outbound_unicode_encoding(self): """ Opera supports unicode encoded SMS messages as long as they encoded as xmlrpc.Binary, test that. 
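The expected payload is the UTF-8 encoding of the message content
wrapped in xmlrpc.Binary.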
""" content = u'üïéßø' def _cb(method_called, xmlrpc_payload): self.assertEqual(xmlrpc_payload['SMSText'], xmlrpc.Binary(content.encode('utf-8'))) return {'Identifier': '1'} self.transport.proxy = FakeXMLRPCService(_cb) yield self.tx_helper.make_dispatch_outbound(content) PK=JG'vumi/transports/opera/tests/__init__.pyPK=JGqD::::%vumi/blinkenlights/metrics_workers.py# -*- test-case-name: vumi.blinkenlights.tests.test_metrics_workers -*- import time import random import hashlib from datetime import datetime from twisted.python import log from twisted.internet.defer import inlineCallbacks, Deferred from twisted.internet import reactor from twisted.internet.task import LoopingCall from twisted.internet.protocol import DatagramProtocol from vumi.service import Consumer, Publisher, Worker from vumi.blinkenlights.metrics import (MetricsConsumer, MetricManager, Count, Metric, Timer, Aggregator) from vumi.blinkenlights.message20110818 import MetricMessage class AggregatedMetricConsumer(Consumer): """Consumer for aggregate metrics. Parameters ---------- callback : function (metric_name, values) Called for each metric datapoint as it arrives. The parameters are metric_name (str) and values (a list of timestamp and value pairs). """ exchange_name = "vumi.metrics.aggregates" exchange_type = "direct" durable = True routing_key = "vumi.metrics.aggregates" def __init__(self, channel, callback): self.queue_name = self.routing_key super(AggregatedMetricConsumer, self).__init__(channel) self.callback = callback def consume_message(self, vumi_message): msg = MetricMessage.from_dict(vumi_message.payload) for metric_name, _aggregators, values in msg.datapoints(): self.callback(metric_name, values) class AggregatedMetricPublisher(Publisher): """Publishes aggregated metrics. """ exchange_name = "vumi.metrics.aggregates" exchange_type = "direct" durable = True routing_key = "vumi.metrics.aggregates" def publish_aggregate(self, metric_name, timestamp, value): # TODO: perhaps change interface to publish multiple metrics? msg = MetricMessage() msg.append((metric_name, (), [(timestamp, value)])) self.publish_message(msg) class TimeBucketConsumer(Consumer): """Consume time bucketed metric messages. Parameters ---------- bucket : int Bucket to consume time buckets from. callback : function, f(metric_name, aggregators, values) Called for each metric datapoint as it arrives. The parameters are metric_name (str), aggregator (list of aggregator names) and values (a list of timestamp and value pairs). """ exchange_name = "vumi.metrics.buckets" exchange_type = "direct" durable = True ROUTING_KEY_TEMPLATE = "bucket.%d" def __init__(self, channel, bucket, callback): self.queue_name = self.ROUTING_KEY_TEMPLATE % bucket self.routing_key = self.queue_name super(TimeBucketConsumer, self).__init__(channel) self.callback = callback def consume_message(self, vumi_message): msg = MetricMessage.from_dict(vumi_message.payload) for metric_name, aggregators, values in msg.datapoints(): self.callback(metric_name, aggregators, values) class TimeBucketPublisher(Publisher): """Publish time bucketed metric messages. Parameters ---------- buckets : int Total number of buckets messages are being distributed to. bucket_size : int, in seconds Size of each time bucket in seconds. 
""" exchange_name = "vumi.metrics.buckets" exchange_type = "direct" durable = True ROUTING_KEY_TEMPLATE = "bucket.%d" def __init__(self, buckets, bucket_size): self.buckets = buckets self.bucket_size = bucket_size def find_bucket(self, metric_name, ts_key): md5 = hashlib.md5("%s:%d" % (metric_name, ts_key)) return int(md5.hexdigest(), 16) % self.buckets def publish_metric(self, metric_name, aggregates, values): timestamp_buckets = {} for timestamp, value in values: ts_key = int(timestamp) / self.bucket_size ts_bucket = timestamp_buckets.get(ts_key) if ts_bucket is None: ts_bucket = timestamp_buckets[ts_key] = [] ts_bucket.append((timestamp, value)) for ts_key, ts_bucket in timestamp_buckets.iteritems(): bucket = self.find_bucket(metric_name, ts_key) routing_key = self.ROUTING_KEY_TEMPLATE % bucket msg = MetricMessage() msg.append((metric_name, aggregates, ts_bucket)) self.publish_message(msg, routing_key=routing_key) class MetricTimeBucket(Worker): """Gathers metrics messages and redistributes them to aggregators. :class:`MetricTimeBuckets` take metrics from the vumi.metrics exchange and redistribute them to one of N :class:`MetricAggregator` workers. There can be any number of :class:`MetricTimeBucket` workers. Configuration Values -------------------- buckets : int (N) The total number of aggregator workers. :class:`MetricAggregator` workers must be started with bucket numbers 0 to N-1 otherwise metric data will go missing (or at best be stuck in a queue somewhere). bucket_size : int, in seconds The amount of time each time bucket represents. """ @inlineCallbacks def startWorker(self): log.msg("Starting a MetricTimeBucket with config: %s" % self.config) buckets = int(self.config.get("buckets")) log.msg("Total number of buckets %d" % buckets) bucket_size = int(self.config.get("bucket_size")) log.msg("Bucket size is %d seconds" % bucket_size) self.publisher = yield self.start_publisher(TimeBucketPublisher, buckets, bucket_size) self.consumer = yield self.start_consumer(MetricsConsumer, self.publisher.publish_metric) class DiscardedMetricError(Exception): pass class MetricAggregator(Worker): """Gathers a subset of metrics and aggregates them. :class:`MetricAggregators` work in sets of N. Configuration Values -------------------- bucket : int, 0 to N-1 An aggregator needs to know which number out of N it is. This is its bucket number. bucket_size : int, in seconds The amount of time each time bucket represents. lag : int, seconds, optional The number of seconds after a bucket's time ends to wait before processing the bucket. Default is 5s. 
""" _time = time.time # hook for faking time in tests def _ts_key(self, time): return int(time) / self.bucket_size @inlineCallbacks def startWorker(self): log.msg("Starting a MetricAggregator with config: %s" % self.config) bucket = int(self.config.get("bucket")) log.msg("MetricAggregator bucket %d" % bucket) self.bucket_size = int(self.config.get("bucket_size")) log.msg("Bucket size is %d seconds" % self.bucket_size) self.lag = float(self.config.get("lag", 5.0)) # ts_key -> { metric_name -> (aggregate_set, values) } # values is a list of (timestamp, value) pairs self.buckets = {} # initialize last processed bucket self._last_ts_key = self._ts_key(self._time() - self.lag) - 2 self.publisher = yield self.start_publisher(AggregatedMetricPublisher) self.consumer = yield self.start_consumer(TimeBucketConsumer, bucket, self.consume_metric) self._task = LoopingCall(self.check_buckets) done = self._task.start(self.bucket_size, False) done.addErrback(lambda failure: log.err(failure, "MetricAggregator bucket checking task died")) def check_buckets(self): """Periodically clean out old buckets and calculate aggregates.""" # key for previous bucket current_ts_key = self._ts_key(self._time() - self.lag) - 1 for ts_key in self.buckets.keys(): if ts_key <= self._last_ts_key: log.err(DiscardedMetricError("Throwing way old metric data: %r" % self.buckets[ts_key])) del self.buckets[ts_key] elif ts_key <= current_ts_key: aggregates = [] ts = ts_key * self.bucket_size items = self.buckets[ts_key].iteritems() for metric_name, (agg_set, values) in items: values = [v for t, v in sorted(values)] for agg_name in agg_set: agg_metric = "%s.%s" % (metric_name, agg_name) agg_func = Aggregator.from_name(agg_name) agg_value = agg_func(values) aggregates.append((agg_metric, agg_value)) for agg_metric, agg_value in aggregates: self.publisher.publish_aggregate(agg_metric, ts, agg_value) del self.buckets[ts_key] self._last_ts_key = current_ts_key def consume_metric(self, metric_name, aggregates, values): if not values: return ts_key = self._ts_key(values[0][0]) metrics = self.buckets.get(ts_key, None) if metrics is None: metrics = self.buckets[ts_key] = {} metric = metrics.get(metric_name) if metric is None: metric = metrics[metric_name] = (set(), []) existing_aggregates, existing_values = metric existing_aggregates.update(aggregates) existing_values.extend(values) def stopWorker(self): self._task.stop() self.check_buckets() class MetricsCollectorWorker(Worker): @inlineCallbacks def startWorker(self): log.msg("Starting %s with config: %s" % ( type(self).__name__, self.config)) yield self.setup_worker() self.consumer = yield self.start_consumer( AggregatedMetricConsumer, self.consume_metrics) def stopWorker(self): log.msg("Stopping %s" % (type(self).__name__,)) return self.teardown_worker() def setup_worker(self): pass def teardown_worker(self): pass def consume_metrics(self, metric_name, values): raise NotImplementedError() class GraphitePublisher(Publisher): """Publisher for sending messages to Graphite.""" exchange_name = "graphite" exchange_type = "topic" durable = True auto_delete = False delivery_mode = 2 def publish_metric(self, metric, value, timestamp): self.publish_raw("%f %d" % (value, timestamp), routing_key=metric) class GraphiteMetricsCollector(MetricsCollectorWorker): """Worker that collects Vumi metrics and publishes them to Graphite.""" @inlineCallbacks def setup_worker(self): self.graphite_publisher = yield self.start_publisher(GraphitePublisher) def consume_metrics(self, metric_name, values): for 
timestamp, value in values: self.graphite_publisher.publish_metric( metric_name, value, timestamp) class UDPMetricsProtocol(DatagramProtocol): def __init__(self, ip, port): # NOTE: `host` must be an IP, not a hostname. self._ip = ip self._port = port def startProtocol(self): self.transport.connect(self._ip, self._port) def send_metric(self, metric_string): return self.transport.write(metric_string) class UDPMetricsCollector(MetricsCollectorWorker): """Worker that collects Vumi metrics and publishes them over UDP.""" DEFAULT_FORMAT_STRING = '%(timestamp)s %(metric_name)s %(value)s\n' DEFAULT_TIMESTAMP_FORMAT = '%Y-%m-%d %H:%M:%S%z' @inlineCallbacks def setup_worker(self): self.format_string = self.config.get( 'format_string', self.DEFAULT_FORMAT_STRING) self.timestamp_format = self.config.get( 'timestamp_format', self.DEFAULT_TIMESTAMP_FORMAT) self.metrics_ip = yield reactor.resolve(self.config['metrics_host']) self.metrics_port = int(self.config['metrics_port']) self.metrics_protocol = UDPMetricsProtocol( self.metrics_ip, self.metrics_port) self.listener = yield reactor.listenUDP(0, self.metrics_protocol) def teardown_worker(self): return self.listener.stopListening() def consume_metrics(self, metric_name, values): for timestamp, value in values: timestamp = datetime.utcfromtimestamp(timestamp) metric_string = self.format_string % { 'timestamp': timestamp.strftime(self.timestamp_format), 'metric_name': metric_name, 'value': value, } self.metrics_protocol.send_metric(metric_string) class RandomMetricsGenerator(Worker): """Worker that publishes a set of random metrics. Useful for tests and demonstrations. Configuration Values -------------------- manager_period : float in seconds, optional How often to have the internal metric manager send metrics messages. Default is 5s. generator_period: float in seconds, optional How often the random metric loop should send values to the metric manager. Default is 1s. """ # callback for tests, f(worker) # (or anyone else that wants to be notified when metrics are generated) on_run = None @inlineCallbacks def startWorker(self): log.msg("Starting the MetricsGenerator with config: %s" % self.config) manager_period = float(self.config.get("manager_period", 5.0)) log.msg("MetricManager will sent metrics every %s seconds" % manager_period) generator_period = float(self.config.get("generator_period", 1.0)) log.msg("Random metrics values will be generated every %s seconds" % generator_period) self.mm = yield self.start_publisher(MetricManager, "vumi.random.", manager_period) self.counter = self.mm.register(Count("count")) self.value = self.mm.register(Metric("value")) self.timer = self.mm.register(Timer("timer")) self.next = Deferred() self.task = LoopingCall(self.run) self.task.start(generator_period) @inlineCallbacks def run(self): if random.choice([True, False]): self.counter.inc() self.value.set(random.normalvariate(2.0, 0.1)) with self.timer.timeit(): d = Deferred() wait = random.uniform(0.0, 0.1) reactor.callLater(wait, lambda: d.callback(None)) yield d if self.on_run is not None: self.on_run(self) def stopWorker(self): self.mm.stop() self.task.stop() log.msg("Stopping the MetricsGenerator") PKqGE 66vumi/blinkenlights/metrics.py# -*- test-case-name: vumi.blinkenlights.tests.test_metrics -*- """Basic set of functionality for working with blinkenlights metrics. Includes a publisher, a consumer and a set of simple metrics. 
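The usual entry point is :class:`MetricManager`, which polls
registered metrics and publishes them as :class:`MetricMessage`
datapoints on the ``vumi.metrics`` exchange.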
""" import time import warnings from twisted.internet.task import LoopingCall from twisted.python import log from zope.interface import Interface, implementer from vumi.service import Publisher, Consumer from vumi.blinkenlights.message20110818 import MetricMessage class IMetricPublisher(Interface): def publish_message(msg): """ Publish a :class:`MetricMessage`. """ @implementer(IMetricPublisher) class MetricPublisher(Publisher): """ Publisher for metrics messages. """ exchange_name = "vumi.metrics" exchange_type = "direct" routing_key = "vumi.metrics" durable = True auto_delete = False delivery_mode = 2 class MetricManager(object): """Utility for creating and monitoring a set of metrics. :type prefix: str :param prefix: Prefix for the name of all metrics registered with this manager. :type publish_interval: int in seconds :param publish_interval: How often to publish the set of metrics. :type on_publish: f(metric_manager) :param on_publish: Function to call immediately after metrics after published. """ def __init__(self, prefix, publish_interval=5, on_publish=None, publisher=None): self.prefix = prefix self._metrics = [] # list of metrics to poll self._oneshot_msgs = [] # list of oneshot messages since last publish self._metrics_lookup = {} # metric name -> metric self._publish_interval = publish_interval self._task = None # created in .start() self._on_publish = on_publish self._publisher = publisher def start_polling(self): """ Start the metric polling and publishing task. """ self._task = LoopingCall(self.publish_metrics) done = self._task.start(self._publish_interval, now=False) done.addErrback(lambda failure: log.err(failure, "MetricManager polling task died")) def stop_polling(self): """ Stop the metric polling and publishing task. """ if self._task: if self._task.running: self._task.stop() self._task = None def publish_metrics(self): """ Publish all waiting metrics. """ msg = MetricMessage() self._collect_oneshot_metrics(msg) self._collect_polled_metrics(msg) self.publish_message(msg) if self._on_publish is not None: self._on_publish(self) def publish_message(self, msg): if self._publisher is None: raise ValueError("No publisher available.") IMetricPublisher(self._publisher).publish_message(msg) def _collect_oneshot_metrics(self, msg): oneshots, self._oneshot_msgs = self._oneshot_msgs, [] for metric, values in oneshots: msg.append((self.prefix + metric.name, metric.aggs, values)) def _collect_polled_metrics(self, msg): for metric in self._metrics: msg.append((self.prefix + metric.name, metric.aggs, metric.poll())) def oneshot(self, metric, value): """Publish a single value for the given metric. :type metric: :class:`Metric` :param metric: Metric object to register. Will have the manager's prefix added to its name. :type value: float :param value: The value to publish for the metric. """ self._oneshot_msgs.append( (metric, [(int(time.time()), value)])) def register(self, metric): """Register a new metric object to be managed by this metric set. A metric can be registered with only one metric set. :type metric: :class:`Metric` :param metric: Metric object to register. The metric will have its `.manage()` method called with this manager as the manager. :rtype: For convenience, returns the metric passed in. 
""" metric.manage(self) self._metrics.append(metric) if metric.name in self._metrics_lookup: raise MetricRegistrationError("Duplicate metric name %s" % metric.name) self._metrics_lookup[metric.name] = metric return metric def __getitem__(self, suffix): return self._metrics_lookup[suffix] def __contains__(self, suffix): return suffix in self._metrics_lookup # Everything from this point onward is to allow MetricManager to pretend to # be a publisher and avoid breaking existing code that treats it as one. exchange_name = MetricPublisher.exchange_name exchange_type = MetricPublisher.exchange_type durable = MetricPublisher.durable _publish_metrics = publish_metrics # For old tests that poke this. def start(self, channel): """Start publishing metrics in a loop.""" if self._publisher is not None: raise RuntimeError("Publisher already present.") self._publisher = MetricPublisher() self._publisher.start(channel) self.start_polling() def stop(self): """Stop publishing metrics.""" self.stop_polling() class AggregatorAlreadyDefinedError(Exception): pass class Aggregator(object): """Registry of aggregate functions for metrics. :type name: str :param name: Short name for the aggregator. :type func: f(list of values) -> float :param func: The aggregation function. Should return a default value if the list of values is empty (usually this default is 0.0). """ REGISTRY = {} def __init__(self, name, func): if name in self.REGISTRY: raise AggregatorAlreadyDefinedError(name) self.name = name self.func = func self.REGISTRY[name] = self @classmethod def from_name(cls, name): return cls.REGISTRY[name] def __call__(self, values): return self.func(values) SUM = Aggregator("sum", sum) AVG = Aggregator("avg", lambda values: sum(values) / len(values) if values else 0.0) MAX = Aggregator("max", lambda values: max(values) if values else 0.0) MIN = Aggregator("min", lambda values: min(values) if values else 0.0) LAST = Aggregator("last", lambda values: values[-1] if values else 0.0) class MetricRegistrationError(Exception): pass class Metric(object): """Simple metric. Values set are collected and polled periodically by the metric manager. :type name: str :param name: Name of this metric. Will be appened to the :class:`MetricManager` prefix when this metric is published. :type aggregators: list of aggregators, optional :param aggregators: List of aggregation functions to request eventually be applied to this metric. The default is to average the value. Examples: >>> mm = MetricManager('vumi.worker0.') >>> my_val = mm.register(Metric('my.value')) >>> my_val.set(1.5) >>> my_val.name 'my.value' """ #: Default aggregators are [:data:`AVG`] DEFAULT_AGGREGATORS = [AVG] def __init__(self, name, aggregators=None): if aggregators is None: aggregators = self.DEFAULT_AGGREGATORS self.name = name self.aggs = tuple(sorted(agg.name for agg in aggregators)) self._manager = None self._values = [] # list of unpolled values @property def managed(self): return self._manager is not None def manage(self, manager): """Called by :class:`MetricManager` when this metric is registered.""" if self._manager is not None: raise MetricRegistrationError( "Metric %s already registered with MetricManager with" " prefix %s." 
% (self.name, self._manager.prefix)) self._manager = manager def set(self, value): """Append a value for later polling.""" self._values.append((int(time.time()), value)) def poll(self): """Called periodically by the :class:`MetricManager`.""" values, self._values = self._values, [] return values class Count(Metric): """A simple counter. Examples: >>> mm = MetricManager('vumi.worker0.') >>> my_count = mm.register(Count('my.count')) >>> my_count.inc() """ #: Default aggregators are [:data:`SUM`] DEFAULT_AGGREGATORS = [SUM] def inc(self): """Increment the count by 1.""" self.set(1.0) class TimerError(Exception): """Raised when an error occurs in a call to an EventTimer method.""" class TimerAlreadyStartedError(TimerError): """Raised when attempting to start an EventTimer that is already started. """ class TimerNotStartedError(TimerError): """Raised when attempting to stop an EventTimer that was not started. """ class TimerAlreadyStoppedError(TimerError): """Raised when attempting to stop an EventTimer that is already stopped. """ class EventTimer(object): def __init__(self, timer, start=False): self._timer = timer self._start_time = None self._stop_time = None if start: self.start() def __enter__(self): self.start() return self def __exit__(self, exc_type, exc_val, exc_tb): self.stop() return False def start(self): if self._start_time is not None: raise TimerAlreadyStartedError("Attempt to start timer %r that" " was already started" % (self._timer.name,)) self._start_time = time.time() def stop(self): if self._start_time is None: raise TimerNotStartedError("Attempt to stop timer %r that" " has not been started" % (self._timer.name,)) if self._stop_time is not None: raise TimerAlreadyStoppedError("Attempt to stop timer %r that" " has already been stopped" % (self._timer.name,)) self._stop_time = time.time() self._timer.set(self._stop_time - self._start_time) class Timer(Metric): """A metric that records time spent on operations. Examples: >>> mm = MetricManager('vumi.worker0.') >>> my_timer = mm.register(Timer('hard.work')) Using the timer as a context manager: >>> with my_timer.timeit(): >>> process_data() Using the timer without a context manager: >>> event_timer = my_timer.timeit() >>> event_timer.start() >>> d = process_other_data() >>> d.addCallback(lambda r: event_timer.stop()) Note that timers returned by `timeit` may only have `start` and `stop` called on them once (and only in that order). .. note:: Using ``.start()`` or ``.stop()`` directly or via using the :class:`Timer` instance itself as a context manager is deprecated because they are not re-entrant and it's easy to accidentally overlap multiple calls to ``.start()`` and ``.stop()`` on the same :class:`Timer` instance (e.g. by letting the reactor run in between). All applications should be updated to use ``.timeit()``. Deprecated use of ``.start()`` and ``.stop()``: >>> my_timer.start() >>> try: >>> process_other_data() >>> finally: >>> my_timer.stop() Deprecated use of ``.start()`` and ``.stop()`` via using the :class:`Timer` itself as a context manager: >>> with my_timer: >>> process_more_data() """ #: Default aggregators are [:data:`AVG`] DEFAULT_AGGREGATORS = [AVG] def __init__(self, *args, **kws): super(Timer, self).__init__(*args, **kws) self._event_timer = EventTimer(self) def __enter__(self): warnings.warn( "Use of Timer directly as a context manager is deprecated." 
" Please use Timer.timeit() instead.", DeprecationWarning) return self._event_timer.__enter__() def __exit__(self, exc_type, exc_val, exc_tb): result = self._event_timer.__exit__(exc_type, exc_val, exc_tb) self._event_timer = EventTimer(self) return result def timeit(self, start=False): return EventTimer(self, start=start) def start(self): warnings.warn( "Use of Timer.start() is deprecated." " Please use Timer.timeit() instead.", DeprecationWarning) return self._event_timer.start() def stop(self): result = self._event_timer.stop() self._event_timer = EventTimer(self) return result class MetricsConsumer(Consumer): """Utility for consuming metrics published by :class:`MetricManager`s. :type callback: f(metric_name, aggregators, values) :param callback: Called for each metric datapoint as it arrives. The parameters are metric_name (str), aggregator (list of aggregator names) and values (a list of timestamp and value paits). """ exchange_name = "vumi.metrics" exchange_type = "direct" routing_key = "vumi.metrics" durable = True def __init__(self, channel, callback): self.queue_name = self.routing_key super(MetricsConsumer, self).__init__(channel) self.callback = callback def consume_message(self, vumi_message): msg = MetricMessage.from_dict(vumi_message.payload) for metric_name, aggregators, values in msg.datapoints(): self.callback(metric_name, aggregators, values) PK=JG'LLvumi/blinkenlights/__init__.py"""Vumi monitoring and control framework.""" from vumi.blinkenlights.metrics_workers import (MetricTimeBucket, MetricAggregator, GraphiteMetricsCollector) __all__ = ["MetricTimeBucket", "MetricAggregator", "GraphiteMetricsCollector"] PK=JG5  %vumi/blinkenlights/message20110707.py# -*- test-case-name: vumi.blinkenlights.tests.test_message20110707 -*- from datetime import datetime from vumi import message as vumi_message class Message(object): """ Blinkenlights message object. This sits inside a Vumi message, and works with decoded JSON data. """ VERSION = "20110707" MESSAGE_TYPE = None REQUIRED_FIELDS = ( # Excludes message_version, which is handled differently. "message_type", "source_name", "source_id", "payload", "timestamp", ) def __init__(self, message_type, source_name, source_id, payload, timestamp=None): self.source_name = source_name self.source_id = source_id self.message_type = message_type self.payload = payload if timestamp is None: timestamp = datetime.utcnow() if not isinstance(timestamp, datetime): # Assume it's a list or tuple here timestamp = datetime(*timestamp) self.timestamp = timestamp if self.MESSAGE_TYPE and self.MESSAGE_TYPE != self.message_type: raise ValueError("Incorrect message type. Expected '%s', got" " '%s'." % (self.MESSAGE_TYPE, self.message_type)) self.process_payload() def process_payload(self): pass def to_dict(self): message = {'message_version': self.VERSION} message.update(dict((field, getattr(self, field)) for field in self.REQUIRED_FIELDS)) # Massage the timestamp into the serialised list we use message['timestamp'] = list(self.timestamp.timetuple()[:6]) return message def to_vumi_message(self): return vumi_message.Message(**self.to_dict()) @classmethod def from_dict(cls, message): message = message.copy() # So we can modify it safely version = message.pop('message_version') if version != cls.VERSION: raise ValueError("Incorrect message version. Expected '%s', got" " '%s'." % (cls.VERSION, version)) for field in cls.REQUIRED_FIELDS: if field not in message: raise ValueError("Missing mandatory field '%s'." 
% (field,)) for field in message: if field not in cls.REQUIRED_FIELDS: raise ValueError("Found unexpected field '%s'." % (field,)) if not message['timestamp']: raise ValueError("Missing timestamp in field 'timestamp'.") return cls(**message) def __str__(self): return u"" % ( self.VERSION, self.message_type, self.timestamp, self.source_name, self.source_id, repr(self.payload)) def __eq__(self, other): if self.VERSION != other.VERSION: return False if self.REQUIRED_FIELDS != other.REQUIRED_FIELDS: return False for field in self.REQUIRED_FIELDS: if getattr(self, field) != getattr(other, field): return False return True class MetricsMessage(Message): MESSAGE_TYPE = "metrics" def process_payload(self): self.metrics = {} for metric in self.payload: name = metric['name'] count = metric['count'] time = metric.get('time', None) tags = dict(i for i in metric.items() if i[0] not in ('name', 'count', 'time')) self.metrics.setdefault(name, []).append((count, time, tags)) PK=JGb@@%vumi/blinkenlights/message20110818.py# -*- test-case-name: vumi.blinkenlights.tests.test_message20110818 -*- from vumi.message import Message class MetricMessage(Message): """Class representing Vumi metrics messages. A metrics message is a list of (metric_name, timestamp, float value) data points and a small amount of metadata: * `metric_name` is a dotted byte string, e.g. 'vumi.w1.my_metric'. * `timestamp` is a float giving seconds since the POSIX Epoch, e.g. time.time(). * `value` is any float. """ def __init__(self): self._datapoints = [] super(MetricMessage, self).__init__(datapoints=self._datapoints) def append(self, datapoint): self._datapoints.append(datapoint) def extend(self, datapoints): self._datapoints.extend(datapoints) def datapoints(self): return self._datapoints def to_dict(self): return { 'datapoints': self._datapoints, } @classmethod def from_dict(cls, msgdict): msg = cls() msg.extend(msgdict['datapoints']) return msg PK=JG0%%'vumi/blinkenlights/heartbeat/monitor.py# -*- test-case-name: vumi.blinkenlights.heartbeat.tests.test_monitor -*- import time import collections import json from twisted.internet.defer import inlineCallbacks from twisted.internet.task import LoopingCall from vumi.worker import BaseWorker from vumi.config import ConfigDict, ConfigInt from vumi.blinkenlights.heartbeat.publisher import HeartBeatMessage from vumi.blinkenlights.heartbeat.storage import Storage from vumi.persist.txredis_manager import TxRedisManager from vumi.utils import generate_worker_id from vumi.errors import ConfigError from vumi import log WorkerIssue = collections.namedtuple('WorkerIssue', ['issue_type', 'start_time', 'procs_count']) def assert_field(cfg, key): """ helper to check whether a config key is defined. Only used for verifying dict fields in the new-style configs """ if key not in cfg: raise ConfigError("Expected '%s' field in config" % key) class WorkerInstance(object): """Represents a worker instance. A hostname, process id pair uniquely identify a worker instance. 
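Instances are hashable and compare equal exactly when both hostname
and pid match, which lets them be collected in sets without
duplicates.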
""" def __init__(self, hostname, pid): self.hostname = hostname self.pid = pid def __eq__(self, obj): if not isinstance(obj, WorkerInstance): return NotImplemented return (self.hostname == obj.hostname and self.pid == obj.pid) def __hash__(self): return hash((self.hostname, self.pid)) class Worker(object): def __init__(self, system_id, worker_name, min_procs): self.system_id = system_id self.name = worker_name self.min_procs = min_procs self.worker_id = generate_worker_id(system_id, worker_name) self._instances = set() self._instances_active = set() self.procs_count = 0 def to_dict(self): """Serializes information into basic dicts""" counts = self._compute_host_info(self._instances) hosts = [] for host, count in counts.iteritems(): hosts.append({ 'host': host, 'proc_count': count, }) obj = { 'id': self.worker_id, 'name': self.name, 'system_id': self.system_id, 'min_procs': self.min_procs, 'hosts': hosts, } return obj def _compute_host_info(self, instances): """Compute the number of worker instances running on each host.""" counts = {} # initialize per-host counters for ins in instances: counts[ins.hostname] = 0 # update counters for each instance for ins in instances: counts[ins.hostname] = counts[ins.hostname] + 1 return counts @inlineCallbacks def audit(self, storage): """ Verify whether enough workers checked in. Make sure to call snapshot() before running this method """ count = len(self._instances) # if there was previously a min-procs-fail, but now enough # instances checked in, then clear the worker issue if (count >= self.min_procs) and (self.procs_count < self.min_procs): yield storage.delete_worker_issue(self.worker_id) if count < self.min_procs: issue = WorkerIssue("min-procs-fail", time.time(), count) yield storage.open_or_update_issue(self.worker_id, issue) self.procs_count = count def snapshot(self): """ This method must be run before any diagnostic audit and analyses What it does is clear the instances_active set in preparation for all the instances which will check-in in the next interval. All diagnostics are based on the _instances_active set, which holds all the instances which checked-in the previous interval. 
""" self._instances = self._instances_active self._instances_active = set() def record(self, hostname, pid): """Record that process (hostname,pid) checked in.""" self._instances_active.add(WorkerInstance(hostname, pid)) class System(object): def __init__(self, system_name, system_id, workers): self.name = system_name self.system_id = system_id self.workers = workers def to_dict(self): """Serialize information to basic dicts""" obj = { 'name': self.name, 'id': self.system_id, 'timestamp': int(time.time()), 'workers': [wkr.to_dict() for wkr in self.workers], } return obj def dumps(self): """Dump to a JSON string""" return json.dumps(self.to_dict()) class HeartBeatMonitor(BaseWorker): class CONFIG_CLASS(BaseWorker.CONFIG_CLASS): deadline = ConfigInt( "Check-in deadline for participating workers", required=True, static=True) redis_manager = ConfigDict( "Redis client configuration.", required=True, static=True) monitored_systems = ConfigDict( "Tree of systems and workers.", required=True, static=True) _task = None @inlineCallbacks def startWorker(self): log.msg("Heartbeat monitor initializing") config = self.get_static_config() self.deadline = config.deadline redis_config = config.redis_manager self._redis = yield TxRedisManager.from_config(redis_config) self._storage = Storage(self._redis) self._systems, self._workers = self.parse_config( config.monitored_systems) # Start consuming heartbeats yield self.consume("heartbeat.inbound", self._consume_message, exchange_name='vumi.health', message_class=HeartBeatMessage) self._start_task() @inlineCallbacks def stopWorker(self): log.msg("HeartBeat: stopping worker") if self._task: self._task.stop() self._task = None yield self._task_done self._redis.close_manager() def parse_config(self, config): """ Parse configuration and populate in-memory state """ systems = [] workers = {} # loop over each defined system for sys in config.values(): assert_field(sys, 'workers') assert_field(sys, 'system_id') system_id = sys['system_id'] system_workers = [] # loop over each defined worker in the system for wkr_entry in sys['workers'].values(): assert_field(wkr_entry, 'name') assert_field(wkr_entry, 'min_procs') worker_name = wkr_entry['name'] min_procs = wkr_entry['min_procs'] wkr = Worker(system_id, worker_name, min_procs) workers[wkr.worker_id] = wkr system_workers.append(wkr) systems.append(System(system_id, system_id, system_workers)) return systems, workers def update(self, msg): """ Process a heartbeat message. """ worker_id = msg['worker_id'] timestamp = msg['timestamp'] hostname = msg['hostname'] pid = msg['pid'] # A bunch of discard rules: # 1. Unknown worker (Monitored workers need to be in the config) # 2. Message which are too old. wkr = self._workers.get(worker_id, None) if wkr is None: log.msg("Discarding message. worker '%s' is unknown" % worker_id) return if timestamp < (time.time() - self.deadline): log.msg("Discarding heartbeat from '%s'. Too old" % worker_id) return wkr.record(hostname, pid) @inlineCallbacks def _sync_to_storage(self): """ Write systems data to storage """ # write system ids system_ids = [sys.system_id for sys in self._systems] yield self._storage.add_system_ids(system_ids) # dump each system for sys in self._systems: yield self._storage.write_system(sys) @inlineCallbacks def _periodic_task(self): """ Iterate over worker instance sets and check to see whether any have not checked-in on time. We call snapshot() first, since the execution of tasks here is interleaved with the processing of worker heartbeat messages. 
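Each interval therefore runs: snapshot all workers, audit each
against its min_procs threshold, then sync system state to storage.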
""" # snapshot the the set of checked-in instances for wkr in self._workers.values(): wkr.snapshot() # run diagnostic audits on all workers for wkr in self._workers.values(): yield wkr.audit(self._storage) # write everything to redis yield self._sync_to_storage() def _start_task(self): """Create a timer task to check for missing worker""" self._task = LoopingCall(self._periodic_task) self._task_done = self._task.start(self.deadline, now=False) errfn = lambda failure: log.err(failure, "Heartbeat verify: timer task died") self._task_done.addErrback(errfn) def _consume_message(self, msg): log.msg("Received message: %s" % msg) self.update(msg.payload) PK=JGUsII'vumi/blinkenlights/heartbeat/storage.py# -*- test-case-name: vumi.blinkenlights.heartbeat.tests.test_storage -*- """ Storage Schema: Timestamp (UNIX timestamp): key = timestamp List of systems (JSON list): key = systems System state (JSON dict): key = system:$SYSTEM_ID Worker issue (JSON dict): key = worker:$WORKER_ID:issue """ import json from vumi.persist.redis_base import Manager TIMESTAMP_KEY = "timestamp" SYSTEMS_KEY = "systems" def issue_key(worker_id): return "worker:%s:issue" % worker_id def system_key(system_id): return "system:%s" % system_id class Storage(object): """ TxRedis interface for the heartbeat monitor. Basically only supports mutating operations since the monitor does not do any reads """ def __init__(self, redis): self._redis = redis self.manager = redis @Manager.calls_manager def add_system_ids(self, system_ids): yield self._redis.sadd(SYSTEMS_KEY, *system_ids) @Manager.calls_manager def write_system(self, sys): key = system_key(sys.system_id) yield self._redis.set(key, sys.dumps()) def _issue_to_dict(self, issue): return { 'issue_type': issue.issue_type, 'start_time': issue.start_time, 'procs_count': issue.procs_count, } @Manager.calls_manager def delete_worker_issue(self, worker_id): key = issue_key(worker_id) yield self._redis.delete(key) @Manager.calls_manager def open_or_update_issue(self, worker_id, issue): key = issue_key(worker_id) issue_raw = yield self._redis.get(key) if issue_raw is None: issue_data = self._issue_to_dict(issue) else: issue_data = json.loads(issue_raw) issue_data['procs_count'] = issue.procs_count yield self._redis.set(key, json.dumps(issue_data)) PK=JG߂)vumi/blinkenlights/heartbeat/publisher.py# -*- test-case-name: vumi.blinkenlights.heartbeat.tests.test_publisher -*- from twisted.internet.task import LoopingCall from vumi.service import Publisher from vumi.message import Message from vumi import log class HeartBeatMessage(Message): """ Basically just a wrapper around a dict for now, with some minimal validation and version identification """ VERSION_20130319 = "20130319" def __init__(self, **kw): super(HeartBeatMessage, self).__init__(**kw) def validate_fields(self): # these basic fields must be present, irrespective of version self.assert_field_present( 'version', 'system_id', 'worker_id', 'worker_name', 'hostname', 'pid', ) class HeartBeatPublisher(Publisher): """ A publisher which send periodic heartbeat messages to the AMQP heartbeat.inbound queue """ HEARTBEAT_PERIOD_SECS = 10 def __init__(self, gen_attrs_func): self.routing_key = "heartbeat.inbound" self.exchange_name = "vumi.health" self.durable = True self._task = None self._gen_attrs_func = gen_attrs_func def _beat(self): """ Read various host and worker attributes and wrap them in a message """ attrs = self._gen_attrs_func() msg = HeartBeatMessage(**attrs) self.publish_message(msg) def start(self, channel): 
super(HeartBeatPublisher, self).start(channel) self._start_looping_task() def _start_looping_task(self): self._task = LoopingCall(self._beat) done = self._task.start(HeartBeatPublisher.HEARTBEAT_PERIOD_SECS, now=False) done.addErrback( lambda failure: log.err(failure, "HeartBeatPublisher task died")) def stop(self): """Stop publishing metrics.""" if self._task: self._task.stop() self._task = None PK=JGȈ(vumi/blinkenlights/heartbeat/__init__.py"""Vumi worker heartbeating.""" from vumi.blinkenlights.heartbeat.publisher import (HeartBeatMessage, HeartBeatPublisher) __all__ = ["HeartBeatMessage", "HeartBeatPublisher"] PK=JG6 ϥ!!2vumi/blinkenlights/heartbeat/tests/test_monitor.py# -*- encoding: utf-8 -*- """Tests for vumi.blinkenlights.heartbeat.monitor""" import time import json from twisted.internet.defer import inlineCallbacks from vumi.blinkenlights.heartbeat import publisher from vumi.blinkenlights.heartbeat import monitor from vumi.blinkenlights.heartbeat.storage import issue_key from vumi.utils import generate_worker_id from vumi.tests.helpers import VumiTestCase, WorkerHelper, PersistenceHelper def expected_wkr_dict(): wkr = { 'id': 'system-1:foo', 'name': 'foo', 'system_id': 'system-1', 'min_procs': 1, 'hosts': [{'host': 'host-1', 'proc_count': 1}], } return wkr def expected_sys_dict(): sys = { 'name': 'system-1', 'id': 'system-1', 'timestamp': int(435), 'workers': [expected_wkr_dict()], } return sys class TestWorkerInstance(VumiTestCase): def test_create(self): worker = monitor.WorkerInstance('foo', 34) self.assertEqual(worker.hostname, 'foo') self.assertEqual(worker.pid, 34) def test_equiv(self): self.assertEqual(monitor.WorkerInstance('foo', 34), monitor.WorkerInstance('foo', 34)) self.failIfEqual(monitor.WorkerInstance('foo', 4), monitor.WorkerInstance('foo', 34)) self.failIfEqual(monitor.WorkerInstance('fo', 34), monitor.WorkerInstance('foo', 34)) def test_hash(self): worker1 = monitor.WorkerInstance('foo', 34) worker2 = monitor.WorkerInstance('foo', 34) worker3 = monitor.WorkerInstance('foo', 35) worker4 = monitor.WorkerInstance('bar', 34) self.assertEqual(hash(worker1), hash(worker2)) self.assertNotEqual(hash(worker1), hash(worker3)) self.assertNotEqual(hash(worker1), hash(worker4)) class TestWorker(VumiTestCase): def test_to_dict(self): wkr = monitor.Worker('system-1', 'foo', 1) wkr.record('host-1', 34) wkr.snapshot() obj = wkr.to_dict() self.assertEqual(obj, expected_wkr_dict()) def test_compute_host_info(self): wkr = monitor.Worker('system-1', 'foo', 1) wkr.record('host-1', 34) wkr.record('host-1', 546) wkr.snapshot() counts = wkr._compute_host_info(wkr._instances) self.assertEqual(counts['host-1'], 2) def test_snapshot(self): wkr = monitor.Worker('system-1', 'foo', 1) wkr.record('host-1', 34) wkr.record('host-1', 546) self.assertEqual(len(wkr._instances_active), 2) self.assertEqual(len(wkr._instances), 0) wkr.snapshot() self.assertEqual(len(wkr._instances_active), 0) self.assertEqual(len(wkr._instances), 2) class TestSystem(VumiTestCase): def test_to_dict(self): wkr = monitor.Worker('system-1', 'foo', 1) sys = monitor.System('system-1', 'system-1', [wkr]) wkr.record('host-1', 34) wkr.snapshot() obj = sys.to_dict() obj['timestamp'] = 435 self.assertEqual(obj, expected_sys_dict()) def test_dumps(self): wkr = monitor.Worker('system-1', 'foo', 1) sys = monitor.System('system-1', 'system-1', [wkr]) wkr.record('host-1', 34) wkr.snapshot() obj_json = sys.dumps() obj = json.loads(obj_json) obj['timestamp'] = 435 self.assertEqual(obj, expected_sys_dict()) class 
TestHeartBeatMonitor(VumiTestCase): @inlineCallbacks def setUp(self): self.persistence_helper = self.add_helper(PersistenceHelper()) self.worker_helper = self.add_helper(WorkerHelper()) config = { 'deadline': 30, 'redis_manager': { 'key_prefix': 'heartbeats', 'db': 5, 'FAKE_REDIS': True, }, 'monitored_systems': { 'system-1': { 'system_name': 'system-1', 'system_id': 'system-1', 'workers': { 'twitter_transport': { 'name': 'twitter_transport', 'min_procs': 2, } } } } } self.worker = yield self.worker_helper.get_worker( monitor.HeartBeatMonitor, config, start=False) def gen_fake_attrs(self, timestamp): sys_id = 'system-1' wkr_name = 'twitter_transport' wkr_id = generate_worker_id(sys_id, wkr_name) attrs = { 'version': publisher.HeartBeatMessage.VERSION_20130319, 'system_id': sys_id, 'worker_id': wkr_id, 'worker_name': wkr_name, 'hostname': "test-host-1", 'timestamp': timestamp, 'pid': 345, } return attrs @inlineCallbacks def test_update(self): # Test the processing of a message. yield self.worker.startWorker() attrs1 = self.gen_fake_attrs(time.time()) attrs2 = self.gen_fake_attrs(time.time()) # process the fake message (and process it twice to verify idempotency) self.worker.update(attrs1) self.worker.update(attrs1) # retrieve the instance set corresponding to the worker_id in the # fake message wkr = self.worker._workers[attrs1['worker_id']] self.assertEqual(len(wkr._instances_active), 1) inst = wkr._instances_active.pop() wkr._instances_active.add(inst) self.assertEqual(inst.hostname, "test-host-1") self.assertEqual(inst.pid, 345) # now process a message from another instance of the worker # and verify that there are two recorded instances attrs2['hostname'] = 'test-host-2' self.worker.update(attrs2) self.assertEqual(len(wkr._instances_active), 2) @inlineCallbacks def test_audit_fail(self): # here we test the verification of a worker who # who had less than min_procs check in yield self.worker.startWorker() fkredis = self.worker._redis attrs = self.gen_fake_attrs(time.time()) wkr_id = attrs['worker_id'] # process the fake message () yield self.worker.update(attrs) wkr = self.worker._workers[attrs['worker_id']] wkr.snapshot() yield wkr.audit(self.worker._storage) # test that an issue was opened self.assertEqual(wkr.procs_count, 1) key = issue_key(wkr_id) issue = json.loads((yield fkredis.get(key))) self.assertEqual(issue['issue_type'], 'min-procs-fail') @inlineCallbacks def test_audit_pass(self): # here we test the verification of a worker who # who had more than min_procs check in yield self.worker.startWorker() fkredis = self.worker._redis attrs = self.gen_fake_attrs(time.time()) wkr_id = attrs['worker_id'] # process the fake message () yield self.worker.update(attrs) attrs['pid'] = 2342 yield self.worker.update(attrs) wkr = self.worker._workers[attrs['worker_id']] wkr.snapshot() yield wkr.audit(self.worker._storage) # verify that no issue has been opened self.assertEqual(wkr.procs_count, 2) key = issue_key(wkr_id) issue = yield fkredis.get(key) self.assertEqual(issue, None) @inlineCallbacks def test_serialize_to_redis(self): # This covers a lot of the serialization methods # as well as the _sync_to_storage() function. 
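# Flow under test: update() records the checked-in instance, then
# _periodic_task() snapshots, audits and persists the system blob
# to redis via _sync_to_storage().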
yield self.worker.startWorker() fkredis = self.worker._redis attrs = self.gen_fake_attrs(time.time()) # process the fake message self.worker.update(attrs) yield self.worker._periodic_task() # this blob is what should be persisted into redis (as JSON) expected = { u'name': u'system-1', u'id': u'system-1', u'timestamp': 2, u'workers': [{ u'id': u'system-1:twitter_transport', u'name': u'twitter_transport', u'system_id': u'system-1', u'min_procs': 2, u'hosts': [{u'host': u'test-host-1', u'proc_count': 1}] }], } # verify that the system data was persisted correctly system = json.loads((yield fkredis.get('system:system-1'))) system['timestamp'] = 2 self.assertEqual(system, expected) PK=JG.vumi/blinkenlights/heartbeat/tests/__init__.pyPKqG,-4vumi/blinkenlights/heartbeat/tests/test_publisher.py# -*- encoding: utf-8 -*- """Tests for vumi.blinkenlights.heartbeat.publisher""" import json from twisted.internet.defer import inlineCallbacks from vumi.tests.fake_amqp import FakeAMQPBroker from vumi.blinkenlights.heartbeat import publisher from vumi.errors import MissingMessageField from vumi.tests.helpers import VumiTestCase, WorkerHelper class MockHeartBeatPublisher(publisher.HeartBeatPublisher): # stub out the LoopingCall task def _start_looping_task(self): self._task = None class TestHeartBeatPublisher(VumiTestCase): def gen_fake_attrs(self): attrs = { 'version': publisher.HeartBeatMessage.VERSION_20130319, 'system_id': "system-1", 'worker_id': "worker-1", 'worker_name': "worker-1", 'hostname': "test-host-1", 'timestamp': 100, 'pid': 43, } return attrs @inlineCallbacks def test_publish_heartbeat(self): broker = FakeAMQPBroker() client = WorkerHelper.get_fake_amqp_client(broker) channel = yield client.get_channel() pub = MockHeartBeatPublisher(self.gen_fake_attrs) pub.start(channel) pub._beat() [msg] = broker.get_dispatched("vumi.health", "heartbeat.inbound") self.assertEqual(json.loads(msg.body), self.gen_fake_attrs()) def test_message_validation(self): attrs = self.gen_fake_attrs() attrs.pop("version") self.assertRaises(MissingMessageField, publisher.HeartBeatMessage, **attrs) attrs = self.gen_fake_attrs() attrs.pop("system_id") self.assertRaises(MissingMessageField, publisher.HeartBeatMessage, **attrs) attrs = self.gen_fake_attrs() attrs.pop("worker_id") self.assertRaises(MissingMessageField, publisher.HeartBeatMessage, **attrs) PK=JG%k  2vumi/blinkenlights/heartbeat/tests/test_storage.py# -*- encoding: utf-8 -*- """Tests for vumi.blinkenlights.heartbeat.monitor""" import json from twisted.internet.defer import inlineCallbacks from vumi.persist.txredis_manager import TxRedisManager from vumi.blinkenlights.heartbeat import storage from vumi.blinkenlights.heartbeat import monitor from vumi.tests.helpers import VumiTestCase class DummySystem: def __init__(self): self.system_id = 'haha' def dumps(self): return "Ha!" 
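# Key layout exercised below (per the schema notes at the top of
# vumi/blinkenlights/heartbeat/storage.py):
#   systems                 -> set of system ids
#   system:$SYSTEM_ID       -> JSON dump of a system snapshot
#   worker:$WORKER_ID:issue -> JSON dict with issue_type, start_time,
#                              procs_count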
class TestStorage(VumiTestCase): @inlineCallbacks def setUp(self): config = { 'key_prefix': 'heartbeats', 'db': 5, 'FAKE_REDIS': True, } self.redis = yield TxRedisManager.from_config(config) self.add_cleanup(self.cleanup_redis) self.stg = storage.Storage(self.redis) @inlineCallbacks def cleanup_redis(self): yield self.redis._purge_all() yield self.redis.close_manager() @inlineCallbacks def test_add_system_ids(self): yield self.stg.add_system_ids(['foo', 'bar']) yield self.stg.add_system_ids(['bar']) res = yield self.redis.smembers(storage.SYSTEMS_KEY) self.assertEqual(sorted(res), sorted(['foo', 'bar'])) @inlineCallbacks def test_write_system(self): yield self.stg.write_system(DummySystem()) res = yield self.redis.get(storage.system_key('haha')) self.assertEqual(res, 'Ha!') @inlineCallbacks def test_delete_issue(self): iss = monitor.WorkerIssue('min-procs-fail', 5, 78) yield self.stg.open_or_update_issue('worker-1', iss) res = yield self.redis.get(storage.issue_key('worker-1')) self.assertEqual(type(res), str) yield self.stg.delete_worker_issue('worker-1') res = yield self.redis.get(storage.issue_key('worker-1')) self.assertEqual(res, None) @inlineCallbacks def test_open_or_update_issue(self): obj = { 'issue_type': 'min-procs-fail', 'start_time': 5, 'procs_count': 78, } iss = monitor.WorkerIssue('min-procs-fail', 5, 78) yield self.stg.open_or_update_issue('foo', iss) res = yield self.redis.get(storage.issue_key('foo')) self.assertEqual(res, json.dumps(obj)) # now update the issue iss = monitor.WorkerIssue('min-procs-fail', 5, 77) obj['procs_count'] = 77 yield self.stg.open_or_update_issue('foo', iss) res = yield self.redis.get(storage.issue_key('foo')) self.assertEqual(res, json.dumps(obj)) PK=JG0vumi/blinkenlights/tests/test_message20110818.pyimport time import vumi.blinkenlights.message20110818 as message from vumi.tests.helpers import VumiTestCase class TestMessage(VumiTestCase): def test_to_dict(self): now = time.time() datapoint = ("vumi.w1.a_metric", now, 1.5, ("sum",)) msg = message.MetricMessage() msg.append(datapoint) self.assertEqual(msg.to_dict(), { 'datapoints': [datapoint], }) def test_from_dict(self): now = time.time() datapoint = ("vumi.w1.a_metric", now, 1.5, ("avg",)) msgdict = {"datapoints": [datapoint]} msg = message.MetricMessage.from_dict(msgdict) self.assertEqual(msg.datapoints(), [datapoint]) def test_extend(self): now = time.time() datapoint = ("vumi.w1.a_metric", now, 1.5, ("min", "max")) msg = message.MetricMessage() msg.extend([datapoint, datapoint, datapoint]) self.assertEqual(msg.datapoints(), [ datapoint, datapoint, datapoint]) PK=JGR~0vumi/blinkenlights/tests/test_message20110707.pyfrom datetime import datetime from vumi.blinkenlights import message20110707 as message from vumi.tests.helpers import VumiTestCase TIMEOBJ = datetime(2011, 07, 07, 12, 00, 00) TIMELIST = [2011, 07, 07, 12, 00, 00] def mkmsg(message_version, message_type, source_name, source_id, payload, timestamp): return { "message_version": message_version, "message_type": message_type, "source_name": source_name, "source_id": source_id, "payload": payload, "timestamp": timestamp, } def mkmsgobj(message_type, source_name, source_id, payload, timestamp): return message.Message(message_type, source_name, source_id, payload, timestamp) class TestMessage(VumiTestCase): def test_decode_valid_message(self): """A valid message dict should decode into an appropriate Message object. 
""" msg_data = mkmsg("20110707", "custom", "myworker", "abc123", ["foo"], TIMELIST) msg = message.Message.from_dict(msg_data) self.assertEquals("custom", msg.message_type) self.assertEquals("myworker", msg.source_name) self.assertEquals("abc123", msg.source_id) self.assertEquals(["foo"], msg.payload) self.assertEquals(TIMEOBJ, msg.timestamp) def test_encode_valid_message(self): """A Message object should encode into an appropriate message dict. """ msg_data = mkmsg("20110707", "custom", "myworker", "abc123", ["foo"], TIMELIST) msg = message.Message("custom", "myworker", "abc123", ["foo"], TIMEOBJ) self.assertEquals(msg_data, msg.to_dict()) def test_decode_invalid_messages(self): """Various kinds of invalid messages should fail to decode. """ msg_data = mkmsg("19800902", "custom", "myworker", "abc123", ["foo"], TIMELIST) self.assertRaises(ValueError, message.Message.from_dict, msg_data) msg_data = mkmsg("20110707", "custom", "myworker", "abc123", ["foo"], None) self.assertRaises(ValueError, message.Message.from_dict, msg_data) msg_data = mkmsg("20110707", "custom", "myworker", "abc123", ["foo"], TIMELIST) msg_data.pop('payload') self.assertRaises(ValueError, message.Message.from_dict, msg_data) msg_data = mkmsg("20110707", "custom", "myworker", "abc123", ["foo"], TIMELIST) msg_data['foo'] = 'bar' self.assertRaises(ValueError, message.Message.from_dict, msg_data) def test_message_equality(self): """Identical messages should compare equal. Different messages should not. """ msg_data = mkmsg("20110707", "custom", "myworker", "abc123", ["foo"], TIMELIST) msg1 = mkmsgobj("custom", "myworker", "abc123", ["foo"], TIMEOBJ) msg2 = message.Message.from_dict(msg_data) msg3 = message.Message.from_dict(msg1.to_dict()) diff_msgs = [ mkmsgobj("custom1", "myworker", "abc123", ["foo"], TIMEOBJ), mkmsgobj("custom", "myworker1", "abc123", ["foo"], TIMEOBJ), mkmsgobj("custom", "myworker", "abc1231", ["foo"], TIMEOBJ), mkmsgobj("custom", "myworker", "abc123", ["foo1"], TIMEOBJ), ] self.assertEquals(msg1, msg1) self.assertEquals(msg1, msg2) self.assertEquals(msg1, msg3) for msg in diff_msgs: self.assertNotEquals(msg1, msg) def test_timestamp_injection(self): """A message created without a timestamp should get one. 
""" start = datetime.utcnow() msg = mkmsgobj("custom", "myworker", "abc123", ["foo"], None) self.assertTrue(start <= msg.timestamp <= datetime.utcnow(), "Expected a time near %s, got %s" % (start, msg.timestamp)) def mkmetricsmsg(metrics): payload = metrics return mkmsg("20110707", "metrics", "myworker", "abc123", payload, TIMELIST) class TestMetricsMessage(VumiTestCase): def test_parse_empty_metrics(self): msg_data = mkmetricsmsg([]) msg = message.MetricsMessage.from_dict(msg_data) self.assertEquals({}, msg.metrics) def test_parse_metrics(self): msg_data = mkmetricsmsg([ {'name': 'vumi.metrics.test.foo', 'method': 'do_stuff', 'count': 5}, {'name': 'vumi.metrics.test.foo', 'method': 'do_more_stuff', 'count': 6}, {'name': 'vumi.metrics.test.foo', 'method': 'do_stuff', 'count': 7}, {'name': 'vumi.metrics.test.bar', 'method': 'do_stuff', 'count': 3, 'time': 120}, ]) msg = message.MetricsMessage.from_dict(msg_data) expected = { 'vumi.metrics.test.foo': [ (5, None, {'method': 'do_stuff'}), (6, None, {'method': 'do_more_stuff'}), (7, None, {'method': 'do_stuff'}), ], 'vumi.metrics.test.bar': [ (3, 120, {'method': 'do_stuff'}), ], } self.assertEquals(expected, msg.metrics) PKqGrEEE(vumi/blinkenlights/tests/test_metrics.pyimport time from twisted.internet import reactor from twisted.internet.defer import inlineCallbacks, Deferred from vumi.blinkenlights import metrics from vumi.message import Message from vumi.service import Worker from vumi.tests.helpers import VumiTestCase, WorkerHelper class TestMetricPublisher(VumiTestCase): def setUp(self): self.worker_helper = self.add_helper(WorkerHelper()) @inlineCallbacks def start_publisher(self, publisher): client = WorkerHelper.get_fake_amqp_client(self.worker_helper.broker) channel = yield client.get_channel() publisher.start(channel) def _sleep(self, delay): d = Deferred() reactor.callLater(delay, lambda: d.callback(None)) return d @inlineCallbacks def _check_msg(self, prefix, metric, values): msgs = yield self.worker_helper.wait_for_dispatched_metrics() if values is None: self.assertEqual(msgs, []) return [datapoint] = msgs[-1] self.assertEqual(datapoint[0], prefix + metric.name) self.assertEqual(datapoint[1], list(metric.aggs)) # check datapoints within 2s of now -- the truncating of # time.time() to an int for timestamps can cause a 1s # difference by itself now = time.time() self.assertTrue(all(abs(p[0] - now) < 2.0 for p in datapoint[2]), "Not all datapoints near now (%f): %r" % (now, datapoint)) self.assertEqual([dp[1] for dp in datapoint[2]], values) @inlineCallbacks def test_publish_single_metric(self): publisher = metrics.MetricPublisher() yield self.start_publisher(publisher) msg = metrics.MetricMessage() cnt = metrics.Count("my.count") msg.append( ("vumi.test.%s" % (cnt.name,), cnt.aggs, [(time.time(), 1)])) publisher.publish_message(msg) self._check_msg("vumi.test.", cnt, [1]) def test_publisher_provides_interface(self): publisher = metrics.MetricPublisher() self.assertTrue(metrics.IMetricPublisher.providedBy(publisher)) class TestMetricManager(VumiTestCase): def setUp(self): self._next_publish = Deferred() self.add_cleanup(lambda: self._next_publish.callback(None)) self.worker_helper = self.add_helper(WorkerHelper()) def on_publish(self, mm): d, self._next_publish = self._next_publish, Deferred() d.callback(mm) def wait_publish(self): return self._next_publish @inlineCallbacks def start_manager_as_publisher(self, manager): client = WorkerHelper.get_fake_amqp_client(self.worker_helper.broker) channel = yield client.get_channel() 
manager.start(channel) self.add_cleanup(manager.stop) def _sleep(self, delay): d = Deferred() reactor.callLater(delay, lambda: d.callback(None)) return d @inlineCallbacks def _check_msg(self, manager, metric, values): msgs = yield self.worker_helper.wait_for_dispatched_metrics() if values is None: self.assertEqual(msgs, []) return [datapoint] = msgs[-1] self.assertEqual(datapoint[0], manager.prefix + metric.name) self.assertEqual(datapoint[1], list(metric.aggs)) # check datapoints within 2s of now -- the truncating of # time.time() to an int for timestamps can cause a 1s # difference by itself now = time.time() self.assertTrue(all(abs(p[0] - now) < 2.0 for p in datapoint[2]), "Not all datapoints near now (%f): %r" % (now, datapoint)) self.assertEqual([dp[1] for dp in datapoint[2]], values) @inlineCallbacks def test_start_manager_no_publisher(self): mm = metrics.MetricManager("vumi.test.") self.assertEqual(mm._publisher, None) self.assertEqual(mm._task, None) yield self.start_manager_as_publisher(mm) self.assertIsInstance(mm._publisher, metrics.MetricPublisher) self.assertNotEqual(mm._task, None) @inlineCallbacks def test_start_manager_publisher_and_channel(self): publisher = metrics.MetricPublisher() mm = metrics.MetricManager("vumi.test.", publisher=publisher) self.assertEqual(mm._publisher, publisher) self.assertEqual(mm._task, None) yield self.assertFailure( self.start_manager_as_publisher(mm), RuntimeError) def test_start_polling_no_publisher(self): mm = metrics.MetricManager("vumi.test.") self.assertEqual(mm._publisher, None) self.assertEqual(mm._task, None) mm.start_polling() self.add_cleanup(mm.stop_polling) self.assertEqual(mm._publisher, None) self.assertNotEqual(mm._task, None) def test_start_polling_with_publisher(self): publisher = metrics.MetricPublisher() mm = metrics.MetricManager("vumi.test.", publisher=publisher) self.assertEqual(mm._publisher, publisher) self.assertEqual(mm._task, None) mm.start_polling() self.add_cleanup(mm.stop_polling) self.assertEqual(mm._publisher, publisher) self.assertNotEqual(mm._task, None) def test_oneshot(self): self.patch(time, "time", lambda: 12345) mm = metrics.MetricManager("vumi.test.") cnt = metrics.Count("my.count") mm.oneshot(cnt, 3) self.assertEqual(cnt.name, "my.count") self.assertEqual(mm._oneshot_msgs, [ (cnt, [(12345, 3)]), ]) def test_register(self): mm = metrics.MetricManager("vumi.test.") cnt = mm.register(metrics.Count("my.count")) self.assertEqual(cnt.name, "my.count") self.assertEqual(mm._metrics, [cnt]) def test_double_register(self): mm = metrics.MetricManager("vumi.test.") mm.register(metrics.Count("my.count")) self.assertRaises(metrics.MetricRegistrationError, mm.register, metrics.Count("my.count")) def test_lookup(self): mm = metrics.MetricManager("vumi.test.") cnt = mm.register(metrics.Count("my.count")) self.assertTrue("my.count" in mm) self.assertTrue(mm["my.count"] is cnt) self.assertEqual(mm["my.count"].name, "my.count") @inlineCallbacks def test_publish_metrics_poll(self): mm = metrics.MetricManager("vumi.test.", 0.1, self.on_publish) cnt = mm.register(metrics.Count("my.count")) yield self.start_manager_as_publisher(mm) cnt.inc() mm.publish_metrics() self._check_msg(mm, cnt, [1]) @inlineCallbacks def test_publish_metrics_oneshot(self): mm = metrics.MetricManager("vumi.test.", 0.1, self.on_publish) cnt = metrics.Count("my.count") yield self.start_manager_as_publisher(mm) mm.oneshot(cnt, 1) mm.publish_metrics() self._check_msg(mm, cnt, [1]) @inlineCallbacks def test_start(self): mm = 
metrics.MetricManager("vumi.test.", 0.1, self.on_publish) cnt = mm.register(metrics.Count("my.count")) yield self.start_manager_as_publisher(mm) self.assertTrue(mm._task is not None) self._check_msg(mm, cnt, None) cnt.inc() yield self.wait_publish() self._check_msg(mm, cnt, [1]) cnt.inc() cnt.inc() yield self.wait_publish() self._check_msg(mm, cnt, [1, 1]) @inlineCallbacks def test_publish_metrics(self): mm = metrics.MetricManager("vumi.test.", 0.1, self.on_publish) cnt = metrics.Count("my.count") yield self.start_manager_as_publisher(mm) mm.oneshot(cnt, 1) self.assertEqual(len(mm._oneshot_msgs), 1) mm.publish_metrics() self.assertEqual(mm._oneshot_msgs, []) self._check_msg(mm, cnt, [1]) def test_publish_metrics_not_started_no_publisher(self): mm = metrics.MetricManager("vumi.test.") self.assertEqual(mm._publisher, None) mm.oneshot(metrics.Count("my.count"), 1) self.assertRaises(ValueError, mm.publish_metrics) def test_stop_unstarted(self): mm = metrics.MetricManager("vumi.test.", 0.1, self.on_publish) mm.stop() mm.stop() # Check that .stop() is idempotent. @inlineCallbacks def test_in_worker(self): worker = yield self.worker_helper.get_worker(Worker, {}, start=False) mm = yield worker.start_publisher(metrics.MetricManager, "vumi.test.", 0.1, self.on_publish) acc = mm.register(metrics.Metric("my.acc")) try: self.assertTrue(mm._task is not None) self._check_msg(mm, acc, None) acc.set(1.5) acc.set(1.0) yield self.wait_publish() self._check_msg(mm, acc, [1.5, 1.0]) finally: mm.stop() @inlineCallbacks def test_task_failure(self): mm = metrics.MetricManager("vumi.test.", 0.1) wait_error = Deferred() class BadMetricError(Exception): pass class BadMetric(metrics.Metric): def poll(self): wait_error.callback(None) raise BadMetricError("bad metric") mm.register(BadMetric("bad")) yield self.start_manager_as_publisher(mm) yield wait_error yield self._sleep(0) # allow log message to be processed error, = self.flushLoggedErrors(BadMetricError) self.assertTrue(error.type is BadMetricError) class TestAggregators(VumiTestCase): def test_sum(self): self.assertEqual(metrics.SUM([]), 0.0) self.assertEqual(metrics.SUM([1.0, 2.0]), 3.0) self.assertEqual(metrics.SUM([2.0, 1.0]), 3.0) self.assertEqual(metrics.SUM.name, "sum") self.assertEqual(metrics.Aggregator.from_name("sum"), metrics.SUM) def test_avg(self): self.assertEqual(metrics.AVG([]), 0.0) self.assertEqual(metrics.AVG([1.0, 2.0]), 1.5) self.assertEqual(metrics.AVG([2.0, 1.0]), 1.5) self.assertEqual(metrics.AVG.name, "avg") self.assertEqual(metrics.Aggregator.from_name("avg"), metrics.AVG) def test_min(self): self.assertEqual(metrics.MIN([]), 0.0) self.assertEqual(metrics.MIN([1.0, 2.0]), 1.0) self.assertEqual(metrics.MIN([2.0, 1.0]), 1.0) self.assertEqual(metrics.MIN.name, "min") self.assertEqual(metrics.Aggregator.from_name("min"), metrics.MIN) def test_max(self): self.assertEqual(metrics.MAX([]), 0.0) self.assertEqual(metrics.MAX([1.0, 2.0]), 2.0) self.assertEqual(metrics.MAX([2.0, 1.0]), 2.0) self.assertEqual(metrics.MAX.name, "max") self.assertEqual(metrics.Aggregator.from_name("max"), metrics.MAX) def test_last(self): self.assertEqual(metrics.LAST([]), 0.0) self.assertEqual(metrics.LAST([1.0, 2.0]), 2.0) self.assertEqual(metrics.LAST([2.0, 1.0]), 1.0) self.assertEqual(metrics.LAST.name, "last") self.assertEqual(metrics.Aggregator.from_name("last"), metrics.LAST) def test_already_registered(self): self.assertRaises(metrics.AggregatorAlreadyDefinedError, metrics.Aggregator, "sum", sum) class CheckValuesMixin(object): def _check_poll_base(self, 
metric, n): datapoints = metric.poll() # check datapoints within 2s of now -- the truncating of # time.time() to an int for timestamps can cause a 1s # difference by itself now = time.time() self.assertTrue(all(abs(d[0] - now) <= 2.0 for d in datapoints), "Not all datapoints near now (%f): %r" % (now, datapoints)) self.assertTrue(all(isinstance(d[0], (int, long)) for d in datapoints)) actual_values = [dp[1] for dp in datapoints] return actual_values def check_poll_func(self, metric, n, test): actual_values = self._check_poll_base(metric, n) self.assertEqual([test(v) for v in actual_values], [True] * n) def check_poll(self, metric, expected_values): n = len(expected_values) actual_values = self._check_poll_base(metric, n) self.assertEqual(actual_values, expected_values) class TestMetric(VumiTestCase, CheckValuesMixin): def test_manage(self): mm = metrics.MetricManager("vumi.test.") metric = metrics.Metric("foo") metric.manage(mm) self.assertEqual(metric.name, "foo") mm2 = metrics.MetricManager("vumi.othertest.") self.assertRaises(metrics.MetricRegistrationError, metric.manage, mm2) def test_managed(self): metric = metrics.Metric("foo") self.assertFalse(metric.managed) mm = metrics.MetricManager("vumi.test.") metric.manage(mm) self.assertTrue(metric.managed) def test_poll(self): metric = metrics.Metric("foo") self.check_poll(metric, []) metric.set(1.0) metric.set(2.0) self.check_poll(metric, [1.0, 2.0]) class TestCount(VumiTestCase, CheckValuesMixin): def test_inc_and_poll(self): metric = metrics.Count("foo") self.check_poll(metric, []) metric.inc() self.check_poll(metric, [1.0]) self.check_poll(metric, []) metric.inc() metric.inc() self.check_poll(metric, [1.0, 1.0]) class TestTimer(VumiTestCase, CheckValuesMixin): def patch_time(self, starting_value): def fake_time(): return self._fake_time self.patch(time, 'time', fake_time) self._fake_time = starting_value def incr_fake_time(self, value): self._fake_time += value def test_start_and_stop(self): timer = metrics.Timer("foo") self.patch_time(12345.0) timer.start() self.incr_fake_time(0.1) timer.stop() self.check_poll_func(timer, 1, lambda x: 0.09 < x < 0.11) self.check_poll(timer, []) def test_already_started(self): timer = metrics.Timer("foo") timer.start() self.assertRaises(metrics.TimerAlreadyStartedError, timer.start) def test_not_started(self): timer = metrics.Timer("foo") self.assertRaises(metrics.TimerNotStartedError, timer.stop) def test_stop_and_stop(self): timer = metrics.Timer("foo") timer.start() timer.stop() self.assertRaises(metrics.TimerNotStartedError, timer.stop) def test_double_start_and_stop(self): timer = metrics.Timer("foo") self.patch_time(12345.0) timer.start() self.incr_fake_time(0.1) timer.stop() timer.start() self.incr_fake_time(0.1) timer.stop() self.check_poll_func(timer, 2, lambda x: 0.09 < x < 0.11) self.check_poll(timer, []) def test_context_manager(self): timer = metrics.Timer("foo") self.patch_time(12345.0) with timer: self.incr_fake_time(0.1) # feign sleep self.check_poll_func(timer, 1, lambda x: 0.09 < x < 0.11) self.check_poll(timer, []) def test_accumulate_times(self): timer = metrics.Timer("foo") self.patch_time(12345.0) with timer: self.incr_fake_time(0.1) # feign sleep with timer: self.incr_fake_time(0.1) # feign sleep self.check_poll_func(timer, 2, lambda x: 0.09 < x < 0.11) self.check_poll(timer, []) def test_timeit(self): timer = metrics.Timer("foo") self.patch_time(12345.0) with timer.timeit(): self.incr_fake_time(0.1) self.check_poll_func(timer, 1, lambda x: 0.09 < x < 0.11) self.check_poll(timer, 
[]) def test_timeit_start_and_stop(self): timer = metrics.Timer("foo") self.patch_time(12345.0) event_timer = timer.timeit() event_timer.start() self.incr_fake_time(0.1) event_timer.stop() self.check_poll_func(timer, 1, lambda x: 0.09 < x < 0.11) self.check_poll(timer, []) def test_timeit_start_and_start(self): event_timer = metrics.Timer("foo").timeit() event_timer.start() self.assertRaises(metrics.TimerAlreadyStartedError, event_timer.start) def test_timeit_stop_without_start(self): event_timer = metrics.Timer("foo").timeit() self.assertRaises(metrics.TimerNotStartedError, event_timer.stop) def test_timeit_stop_and_stop(self): event_timer = metrics.Timer("foo").timeit() event_timer.start() event_timer.stop() self.assertRaises(metrics.TimerAlreadyStoppedError, event_timer.stop) def test_timeit_autostart(self): timer = metrics.Timer("foo") self.patch_time(12345.0) event_timer = timer.timeit(start=True) self.incr_fake_time(0.1) event_timer.stop() self.check_poll_func(timer, 1, lambda x: 0.09 < x < 0.11) self.check_poll(timer, []) class TestMetricsConsumer(VumiTestCase): def test_consume_message(self): expected_datapoints = [ ("vumi.test.v1", 1234, 1.0), ("vumi.test.v2", 3456, 2.0), ] datapoints = [] callback = lambda *v: datapoints.append(v) consumer = metrics.MetricsConsumer(None, callback) msg = metrics.MetricMessage() msg.extend(expected_datapoints) vumi_msg = Message.from_json(msg.to_json()) consumer.consume_message(vumi_msg) self.assertEqual(datapoints, expected_datapoints) PKqGlq5q50vumi/blinkenlights/tests/test_metrics_workers.pyfrom twisted.internet.defer import inlineCallbacks, Deferred, DeferredQueue from twisted.internet.protocol import DatagramProtocol from twisted.internet import reactor from vumi.blinkenlights import metrics_workers from vumi.blinkenlights.message20110818 import MetricMessage from vumi.tests.helpers import VumiTestCase, WorkerHelper class BrokerWrapper(object): """Wrap utility methods around a FakeAMQPBroker.""" def __init__(self, broker): self._broker = broker def __getattr__(self, name): return getattr(self._broker, name) def send_datapoints(self, exchange, queue, datapoints): """Publish datapoints to a broker.""" msg = MetricMessage() msg.extend(datapoints) self._broker.publish_message(exchange, queue, msg) def recv_datapoints(self, exchange, queue): """Retrieve datapoints from a broker.""" vumi_msgs = self._broker.get_messages(exchange, queue) msgs = [MetricMessage.from_dict(vm.payload) for vm in vumi_msgs] return [msg.datapoints() for msg in msgs] class TestMetricTimeBucket(VumiTestCase): def setUp(self): self.worker_helper = self.add_helper(WorkerHelper()) @inlineCallbacks def test_bucketing(self): config = {'buckets': 4, 'bucket_size': 5} worker = yield self.worker_helper.get_worker( metrics_workers.MetricTimeBucket, config=config) broker = BrokerWrapper(self.worker_helper.broker) datapoints = [ ("vumi.test.foo", ("agg",), [(1230, 1.5), (1235, 2.0)]), ("vumi.test.bar", ("sum",), [(1240, 1.0)]), ] broker.send_datapoints("vumi.metrics", "vumi.metrics", datapoints) yield broker.kick_delivery() buckets = [broker.recv_datapoints("vumi.metrics.buckets", "bucket.%d" % i) for i in range(4)] expected_buckets = [ [], [[[u'vumi.test.bar', ['sum'], [[1240, 1.0]]]]], [[[u'vumi.test.foo', ['agg'], [[1230, 1.5]]]], [[u'vumi.test.foo', ['agg'], [[1235, 2.0]]]]], [], ] self.assertEqual(buckets, expected_buckets) yield worker.stopWorker() class TestMetricAggregator(VumiTestCase): def setUp(self): self.now = 0 self.worker_helper = self.add_helper(WorkerHelper()) 
self.broker = BrokerWrapper(self.worker_helper.broker) def fake_time(self): return self.now @inlineCallbacks def test_aggregating(self): config = {'bucket': 3, 'bucket_size': 5} worker = yield self.worker_helper.get_worker( metrics_workers.MetricAggregator, config, start=False) worker._time = self.fake_time yield worker.startWorker() datapoints = [ ("vumi.test.foo", ("avg",), [(1235, 1.5), (1236, 2.0)]), ("vumi.test.foo", ("sum",), [(1240, 1.0)]), ] self.broker.send_datapoints( "vumi.metrics.buckets", "bucket.3", datapoints) self.broker.send_datapoints( "vumi.metrics.buckets", "bucket.3", datapoints) self.broker.send_datapoints( "vumi.metrics.buckets", "bucket.2", datapoints) yield self.broker.kick_delivery() def recv(): return self.broker.recv_datapoints("vumi.metrics.aggregates", "vumi.metrics.aggregates") expected = [] self.now = 1241 worker.check_buckets() self.assertEqual(recv(), expected) expected.append([["vumi.test.foo.avg", [], [[1235, 1.75]]]]) self.now = 1246 worker.check_buckets() self.assertEqual(recv(), expected) # skip a few checks expected.append([["vumi.test.foo.sum", [], [[1240, 2.0]]]]) self.now = 1261 worker.check_buckets() self.assertEqual(recv(), expected) @inlineCallbacks def test_aggregating_last(self): config = {'bucket': 3, 'bucket_size': 5} worker = yield self.worker_helper.get_worker( metrics_workers.MetricAggregator, config, start=False) worker._time = self.fake_time yield worker.startWorker() datapoints = [ ("vumi.test.foo", ("last",), [(1235, 1.5), (1236, 2.0)]), ("vumi.test.bar", ("last",), [(1241, 1.0), (1240, 2.0)]), ] self.broker.send_datapoints( "vumi.metrics.buckets", "bucket.3", datapoints) self.broker.send_datapoints( "vumi.metrics.buckets", "bucket.3", datapoints) self.broker.send_datapoints( "vumi.metrics.buckets", "bucket.2", datapoints) yield self.broker.kick_delivery() def recv(): return self.broker.recv_datapoints("vumi.metrics.aggregates", "vumi.metrics.aggregates") expected = [] self.now = 1241 worker.check_buckets() self.assertEqual(recv(), expected) expected.append([["vumi.test.foo.last", [], [[1235, 2.0]]]]) self.now = 1246 worker.check_buckets() self.assertEqual(recv(), expected) # skip a few checks expected.append([["vumi.test.bar.last", [], [[1240, 1.0]]]]) self.now = 1261 worker.check_buckets() self.assertEqual(recv(), expected) @inlineCallbacks def test_aggregating_lag(self): config = {'bucket': 3, 'bucket_size': 5, 'lag': 1} worker = yield self.worker_helper.get_worker( metrics_workers.MetricAggregator, config, start=False) worker._time = self.fake_time yield worker.startWorker() datapoints = [ ("vumi.test.foo", ("avg",), [(1235, 1.5), (1236, 2.0)]), ("vumi.test.foo", ("sum",), [(1240, 1.0)]), ] self.broker.send_datapoints( "vumi.metrics.buckets", "bucket.3", datapoints) self.broker.send_datapoints( "vumi.metrics.buckets", "bucket.3", datapoints) self.broker.send_datapoints( "vumi.metrics.buckets", "bucket.2", datapoints) yield self.broker.kick_delivery() def recv(): return self.broker.recv_datapoints("vumi.metrics.aggregates", "vumi.metrics.aggregates") expected = [] self.now = 1237 worker.check_buckets() self.assertEqual(recv(), expected) expected.append([["vumi.test.foo.avg", [], [[1235, 1.75]]]]) self.now = 1242 worker.check_buckets() self.assertEqual(recv(), expected) # skip a few checks expected.append([["vumi.test.foo.sum", [], [[1240, 2.0]]]]) self.now = 1257 worker.check_buckets() self.assertEqual(recv(), expected) class TestAggregationSystem(VumiTestCase): """Tests tying MetricTimeBucket and MetricAggregator together.""" def 
setUp(self): self.aggregator_workers = [] self.now = 0 self.worker_helper = self.add_helper(WorkerHelper()) self.broker = BrokerWrapper(self.worker_helper.broker) def fake_time(self): return self.now def send(self, datapoints): self.broker.send_datapoints("vumi.metrics", "vumi.metrics", datapoints) def recv(self): return self.broker.recv_datapoints("vumi.metrics.aggregates", "vumi.metrics.aggregates") @inlineCallbacks def _setup_workers(self, bucketters, aggregators, bucket_size): bucket_config = { 'buckets': aggregators, 'bucket_size': bucket_size, } for _i in range(bucketters): worker = yield self.worker_helper.get_worker( metrics_workers.MetricTimeBucket, bucket_config) aggregator_config = { 'bucket_size': bucket_size, } for i in range(aggregators): config = aggregator_config.copy() config['bucket'] = i worker = yield self.worker_helper.get_worker( metrics_workers.MetricAggregator, config=config, start=False) worker._time = self.fake_time yield worker.startWorker() self.aggregator_workers.append(worker) # TODO: use parameteric test cases to test many combinations of workers @inlineCallbacks def test_aggregating_one_metric(self): yield self._setup_workers(1, 1, 5) datapoints = [("vumi.test.foo", ["sum"], [(12345, 1.0), (12346, 2.0)])] self.send(datapoints) self.send(datapoints) yield self.broker.kick_delivery() # deliver to bucketters yield self.broker.kick_delivery() # deliver to aggregators self.now = 12355 for worker in self.aggregator_workers: worker.check_buckets() datapoints, = self.recv() self.assertEqual(datapoints, [ ["vumi.test.foo.sum", [], [[12345, 6.0]]] ]) class TestGraphitePublisher(VumiTestCase): def setUp(self): self.worker_helper = self.add_helper(WorkerHelper()) def _check_msg(self, channel, metric, value, timestamp): [msg] = self.worker_helper.broker.get_dispatched("graphite", metric) self.assertEqual(msg.properties, {"delivery mode": 2}) self.assertEqual(msg.body, "%f %d" % (value, timestamp)) @inlineCallbacks def test_publish_metric(self): datapoint = ("vumi.test.v1", 1.0, 1234) client = WorkerHelper.get_fake_amqp_client(self.worker_helper.broker) channel = yield client.get_channel() pub = metrics_workers.GraphitePublisher() pub.start(channel) pub.publish_metric(*datapoint) self._check_msg(channel, *datapoint) class TestGraphiteMetricsCollector(VumiTestCase): def setUp(self): self.worker_helper = self.add_helper(WorkerHelper()) self.broker = BrokerWrapper(self.worker_helper.broker) @inlineCallbacks def test_single_message(self): yield self.worker_helper.get_worker( metrics_workers.GraphiteMetricsCollector, {}) datapoints = [("vumi.test.foo", "", [(1234, 1.5)])] self.broker.send_datapoints("vumi.metrics.aggregates", "vumi.metrics.aggregates", datapoints) yield self.broker.kick_delivery() content, = self.broker.get_dispatched("graphite", "vumi.test.foo") parts = content.body.split() value, ts = float(parts[0]), int(parts[1]) self.assertEqual(value, 1.5) self.assertEqual(ts, 1234) class UDPMetricsCatcher(DatagramProtocol): def __init__(self): self.queue = DeferredQueue() def datagramReceived(self, datagram, addr): self.queue.put(datagram) class TestUDPMetricsCollector(VumiTestCase): @inlineCallbacks def setUp(self): self.worker_helper = self.add_helper(WorkerHelper()) self.broker = BrokerWrapper(self.worker_helper.broker) self.udp_protocol = UDPMetricsCatcher() self.udp_server = yield reactor.listenUDP(0, self.udp_protocol) self.add_cleanup(self.udp_server.stopListening) self.worker = yield self.worker_helper.get_worker( metrics_workers.UDPMetricsCollector, { 
'metrics_host': '127.0.0.1', 'metrics_port': self.udp_server.getHost().port, }) def send_metrics(self, *metrics): datapoints = [("vumi.test.foo", "", list(metrics))] self.broker.send_datapoints("vumi.metrics.aggregates", "vumi.metrics.aggregates", datapoints) return self.broker.kick_delivery() @inlineCallbacks def test_single_message(self): yield self.send_metrics((1234, 1.5)) received = yield self.udp_protocol.queue.get() self.assertEqual('1970-01-01 00:20:34 vumi.test.foo 1.5\n', received) @inlineCallbacks def test_multiple_messages(self): yield self.send_metrics((1234, 1.5), (1235, 2.5)) received = yield self.udp_protocol.queue.get() self.assertEqual('1970-01-01 00:20:34 vumi.test.foo 1.5\n', received) received = yield self.udp_protocol.queue.get() self.assertEqual('1970-01-01 00:20:35 vumi.test.foo 2.5\n', received) class TestRandomMetricsGenerator(VumiTestCase): def setUp(self): self.worker_helper = self.add_helper(WorkerHelper()) self.broker = BrokerWrapper(self.worker_helper.broker) self._on_run = Deferred() self.add_cleanup(lambda: self._on_run.callback(None)) def on_run(self, worker): d, self._on_run = self._on_run, Deferred() d.callback(None) def wake_after_run(self): return self._on_run @inlineCallbacks def test_one_run(self): worker = yield self.worker_helper.get_worker( metrics_workers.RandomMetricsGenerator, { "manager_period": "0.1", "generator_period": "0.1", }, start=False) worker.on_run = self.on_run yield worker.startWorker() yield self.wake_after_run() yield self.wake_after_run() datasets = self.broker.recv_datapoints('vumi.metrics', 'vumi.metrics') # there should be a least one but there may be more # than one if the tests are running slowly datapoints = datasets[0] self.assertEqual(sorted(d[0] for d in datapoints), ["vumi.random.count", "vumi.random.timer", "vumi.random.value"]) PK=JG$vumi/blinkenlights/tests/__init__.pyPK=JGBKKvumi/persist/model.py# -*- test-case-name: vumi.persist.tests.test_model -*- """Base classes for Vumi persistence models.""" from functools import wraps import urllib from vumi.errors import VumiError from vumi.persist.fields import Field, FieldDescriptor, ValidationError class ModelMigrationError(VumiError): pass class VumiRiakError(VumiError): pass class ModelMetaClass(type): def __new__(mcs, name, bases, dict): # set default bucket suffix if "bucket" not in dict: dict["bucket"] = name.lower() # locate Field instances fields, descriptors = {}, {} class_dicts = [dict] + [base.__dict__ for base in reversed(bases)] for cls_dict in class_dicts: for key, possible_field in cls_dict.items(): if key in fields: continue if isinstance(possible_field, FieldDescriptor): possible_field = possible_field.field # copy descriptors if isinstance(possible_field, Field): descriptors[key] = possible_field.get_descriptor(key) dict[key] = descriptors[key] fields[key] = possible_field dict["field_descriptors"] = descriptors # add backlinks object dict["backlinks"] = BackLinks() cls = type.__new__(mcs, name, bases, dict) # inform field instances which classes they belong to for field_descriptor in descriptors.itervalues(): field_descriptor.setup(cls) return cls class BackLinks(object): """Object for holding reverse-key look-up functions for a Model class.""" def __init__(self): self.functions = {} def declare_backlink(self, name, function): if name in self.functions: raise RuntimeError("Backlink %r already registered" % (name,)) self.functions[name] = function def __get__(self, instance, owner): if instance is None: return self return BackLinkProxy(self, instance) 
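# A sketch of how the backlink machinery above and below is used. The
# ``Owner``/``Shop`` models and the lookup function here are hypothetical,
# not part of this module: ``declare_backlink`` registers a reverse-lookup
# function on the model class, and instance access goes through
# ``BackLinkProxy`` (below), which binds the instance as the function's
# first argument.
#
#     def shops_for_owner(owner, manager):
#         return Shop.index_keys(manager, 'owner_id', owner.key)
#
#     Owner.backlinks.declare_backlink('shops', shops_for_owner)
#
#     # Later, on an instance:
#     shop_keys = owner.backlinks.shops(manager)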
class BackLinkProxy(object): def __init__(self, backlinks, modelobj): self._backlinks = backlinks self._modelobj = modelobj def __getattr__(self, key): if key not in self._backlinks.functions: raise AttributeError( "No backlink function registered for %r" % (key,)) def wrapped_backlink(*args, **kwargs): return self._backlinks.functions[key](self._modelobj, *args, **kwargs) return wrapped_backlink class ModelMigrator(object): """ Migration handler for old Model versions. Subclasses of this should implement ``migrate_from_<version>()`` methods for each previous version of the model being migrated. This method will be called with a :class:`MigrationData` instance and must return a :class:`MigrationData` instance. (This will likely be the same instance, but may be different.) The ``migrate_from_<version>()`` method is allowed to do whatever other operations may be required (for example, modifying related objects). However, care should be taken to avoid lengthy delays, race conditions, etc. There is a special-case ``migrate_from_unversioned()`` method that is called for objects that do not contain a model version. In order to facilitate different processes using different model versions, reverse migrations are also supported. These are similar to forward migrations, except they are applied at save time (rather than load time) and methods are named ``reverse_from_<version>()``. """ def __init__(self, model_class, manager, data_version, reverse=False): self.model_class = model_class self.manager = manager self.data_version = data_version self.reverse = reverse prefix = "reverse" if reverse else "migrate" if data_version is not None: migration_method_name = '%s_from_%s' % (prefix, str(data_version)) else: migration_method_name = '%s_from_unversioned' % (prefix,) self.migration_method = getattr(self, migration_method_name, None) def __call__(self, riak_object): if self.migration_method is None: prefix = "reverse " if self.reverse else "" raise ModelMigrationError( 'No %smigrators defined for %s version %s' % ( prefix, self.model_class.__name__, self.data_version)) return self.migration_method(MigrationData(riak_object)) class MigrationData(object): def __init__(self, riak_object): self.riak_object = riak_object self.old_data = riak_object.get_data() self.new_data = {} self.old_index = {} self.new_index = {} for name, value in riak_object.get_indexes(): self.old_index.setdefault(name, []).append(value) def get_riak_object(self): self.riak_object.set_data(self.new_data) # We need to explicitly remove old indexes before adding new ones.
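# (Riak keeps existing index entries around, so any index that a migration
# does not copy into ``new_index`` via copy_indexes() or add_index() is
# deliberately dropped here rather than silently carried over.)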
for field in self.old_index: self.riak_object.remove_index(field) for field, values in self.new_index.iteritems(): for value in values: self.riak_object.add_index(field, value) return self.riak_object def copy_values(self, *fields): """Copy field values from old data to new data.""" for field in fields: self.new_data[field] = self.old_data[field] def copy_indexes(self, *indexes): """Copy indexes from old data to new data.""" for index in indexes: self.new_index[index] = self.old_index.get(index, [])[:] def copy_dynamic_values(self, *dynamic_prefixes): """Copy dynamic field values from old data to new data.""" for prefix in dynamic_prefixes: for key in self.old_data: if key.startswith(prefix): self.new_data[key] = self.old_data[key] def add_index(self, index, value): """Add a new index value to new data.""" if index is None: index = '' else: index = str(index) if isinstance(value, unicode): value = value.encode('utf-8') self.new_index.setdefault(index, []).append(value) def clear_index(self, index): """Remove all values for a given index from new data.""" del self.new_index[index] def set_value(self, field, value, index=None, index_value=None): """Set the value (and optionally the index) for a field. Indexes are usually set by :class:`FieldDescriptor` objects. Since we don't have those here, we need to explicitly set the index values for fields that are indexed. """ self.new_data[field] = value if index is not None: if index_value is None: index_value = value if index_value is not None: self.add_index(index, index_value) class Model(object): """A model is a description of an entity persisted in a data store.""" __metaclass__ = ModelMetaClass VERSION = None MIGRATOR = ModelMigrator bucket = None # TODO: maybe replace .backlinks with a class-level .query # or .by_ method def __init__(self, manager, key, _riak_object=None, **field_values): self._fields_changed = [] self.manager = manager self.key = key if _riak_object is not None: self._riak_object = _riak_object else: self._riak_object = manager.riak_object(type(self), key) for field_name, descriptor in self.field_descriptors.iteritems(): field = descriptor.field if not field.initializable: continue field_value = field_values.pop(field_name, field.default) if callable(field_value): field_value = field_value() descriptor.initialize(self, field_value) if field_values: raise ValidationError("Unexpected extra initial fields %r passed" " to model %s" % (field_values.keys(), self.__class__)) self.clean() self.was_migrated = False def __repr__(self): str_items = ["%s=%r" % item for item in sorted(self.get_data().items())] return "<%s %s>" % (self.__class__.__name__, " ".join(str_items)) def clean(self): for field_name, descriptor in self.field_descriptors.iteritems(): descriptor.clean(self) def _field_changed(self, changed_field_name): """ Called when a field value changes. """ already_notifying = bool(self._fields_changed) if changed_field_name not in self._fields_changed: self._fields_changed.append(changed_field_name) if not already_notifying: self._notify_fields_changed() def _notify_fields_changed(self): while self._fields_changed: # We only update self._fields_changed after processing, because # we're also using it to track whether we're currently processing. 
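# Any field changed as a side effect of this notification is appended to
# self._fields_changed by _field_changed() and picked up by a later
# iteration of this loop, so the notification machinery never re-enters.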
self._notify_field_changed(self._fields_changed[0]) self._fields_changed[:1] = [] def _notify_field_changed(self, changed_field_name): for field_name, descriptor in self.field_descriptors.iteritems(): if field_name != changed_field_name: descriptor.model_field_changed(self, changed_field_name) def get_data(self): """ Returns a dictionary of all known field names & values. Useful when needing to represent a model instance as a dictionary. :returns: A dict of all values, including the key. """ data = self._riak_object.get_data() data.update({ 'key': self.key, }) return data def save(self): """Save the object to Riak. :returns: A deferred that fires once the data is saved (or None if using a synchronous manager). """ return self.manager.store(self) def delete(self): """Delete the object from Riak. :returns: A deferred that fires once the data is deleted (or None if using a synchronous manager). """ return self.manager.delete(self) @classmethod def load(cls, manager, key, result=None): """Load an object from Riak. :returns: A deferred that fires with the new model object. """ return manager.load(cls, key, result=result) @classmethod def load_all_bunches(cls, manager, keys): """Load batches of objects for the given list of keys. :returns: An iterator over (possibly deferred) lists of model instances. """ return manager.load_all_bunches(cls, keys) @classmethod def all_keys(cls, manager): """Return all keys in this model's bucket. Uses Riak's special `$bucket` index. Beware of tombstones (i.e. the keys returned might have been deleted from Riak in the near past). :returns: List of keys from this model's bucket. """ return manager.index_keys( cls, '$bucket', manager.bucket_name(cls), None) @classmethod def index_keys(cls, manager, field_name, value, end_value=None, return_terms=None): """Find object keys by index. :param manager: A :class:`Manager` object. :param str field_name: The name of the field to get the index from. The index type (integer or binary) is determined by the field and this may affect the behaviour of range queries. :param value: The index value to look up. This is processed by the field in question to get the actual value to send to Riak. If ``end_value`` is provided, ``value`` is used as the start of a range query, otherwise an exact match is performed. :param end_value: The index value to use as the end of a range query. This is processed by the field in question to get the actual value to send to Riak. If provided, a range query is performed. :param bool return_terms: If ``True``, the raw index values will be returned along with the object keys in a ``(term, key)`` tuple. These raw values are not processed by the field and may therefore be different from the expected field values. :returns: List of keys matching the index param. If ``return_terms`` is ``True``, a list of ``(term, key)`` tuples will be returned instead. """ index_name, start_value, end_value = index_vals_for_field( cls, field_name, value, end_value) return manager.index_keys( cls, index_name, start_value, end_value, return_terms=return_terms) @classmethod def all_keys_page(cls, manager, max_results=None, continuation=None): """Return all keys in this model's bucket. Uses Riak's special `$bucket` index. Beware of tombstones (i.e. the keys returned might have been deleted from Riak in the near past). :param int max_results: The maximum number of results to return per page. If ``None``, pagination will be disabled and a single page containing all results will be returned.
:param continuation: An opaque continuation token indicating which page of results to fetch. The index page object returned from this method has a ``continuation`` attribute that contains this value. If ``None``, the first page of results will be returned. :returns: :class:`VumiIndexPage` or :class:`VumiTxIndexPage` object containing all keys from this model's bucket. """ return manager.index_keys_page( cls, '$bucket', manager.bucket_name(cls), None, max_results=max_results, continuation=continuation) @classmethod def index_keys_page(cls, manager, field_name, value, end_value=None, return_terms=None, max_results=None, continuation=None): """Find object keys by index, using pagination. :param manager: A :class:`Manager` object. :param str field_name: The name of the field to get the index from. The index type (integer or binary) is determined by the field and this may affect the behaviour of range queries. :param value: The index value to look up. This is processed by the field in question to get the actual value to send to Riak. If ``end_value`` is provided, ``value`` is used as the start of a range query, otherwise an exact match is performed. :param end_value: The index value to use as the end of a range query. This is processed by the field in question to get the actual value to send to Riak. If provided, a range query is performed. :param bool return_terms: If ``True``, the raw index values will be returned along with the object keys in a ``(term, key)`` tuple. These raw values are not processed by the field and may therefore be different from the expected field values. :param int max_results: The maximum number of results to return per page. If ``None``, pagination will be disabled and a single page containing all results will be returned. :param continuation: An opaque continuation token indicating which page of results to fetch. The index page object returned from this method has a ``continuation`` attribute that contains this value. If ``None``, the first page of results will be returned. :returns: :class:`VumiIndexPage` or :class:`VumiTxIndexPage` object containing results. If ``return_terms`` is ``True``, the object returned will contain ``(term, key)`` tuples instead of keys. """ index_name, start_value, end_value = index_vals_for_field( cls, field_name, value, end_value) return manager.index_keys_page( cls, index_name, start_value, end_value, return_terms=return_terms, max_results=max_results, continuation=continuation) @classmethod def index_lookup(cls, manager, field_name, value): """Find objects by index. :returns: :class:`VumiMapReduce` instance based on the index param. """ return manager.mr_from_field(cls, field_name, value) @classmethod def index_match(cls, manager, query, field_name, value): """ Finds objects in the index that match the regex patterns in ``query``. :param list query: A list of dictionaries with query information. Each dictionary should have the following structure: { "key": "the key to use to lookup the value in the JSON doc", "pattern": "the regex to match the value with", "flags": "the flags to set on the RegExp object", } :returns: :class:`VumiMapReduce` instance based on the index param, with a map phase for matching against the query. """ return manager.mr_from_field_match(cls, query, field_name, value) @classmethod def search(cls, manager, **kw): """Search for instances of this model matching keys/values. :returns: :class:`VumiMapReduce` instance based on the search params.
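A usage sketch (the ``Person`` model and its field names are
hypothetical)::

    mr = Person.search(manager, surname="Jones", lang="en")
    keys = mr.get_keys()

The keyword values are escaped and joined into a single Riak search
query, e.g. ``surname:'Jones' AND lang:'en'``.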
""" # TODO: build the queries more intelligently for k, value in kw.iteritems(): value = unicode(value) value = value.replace('\\', '\\\\') value = value.replace("'", "\\'") kw[k] = value query = " AND ".join("%s:'%s'" % (k, v) for k, v in kw.iteritems()) return cls.raw_search(manager, query) @classmethod def raw_search(cls, manager, query): """ Performs a raw riak search, does no inspection on the given query. :returns: :class:`VumiMapReduce` instance based on the search params. """ return manager.mr_from_search(cls, query) @classmethod def real_search(cls, manager, query, rows=None, start=None): """ Performs a real riak search, does no inspection on the given query. :returns: list of keys. """ return manager.real_search(cls, query, rows=rows, start=start) @classmethod def enable_search(cls, manager): """Enable solr indexing over for this model and manager.""" return manager.riak_enable_search(cls) def index_vals_for_field(model, field_name, start_value, end_value): descriptor = model.field_descriptors[field_name] if descriptor.index_name is None: raise ValueError("%s.%s is not indexed" % ( model.__name__, field_name)) # The Riak client library does silly things under the hood. start_value = descriptor.field.to_riak(start_value) if start_value is None: # FIXME: We should be raising an exception here, but we still rely on # this having the value "None" in places. :-( start_value = 'None' else: start_value = str(start_value) if end_value is not None: end_value = str(descriptor.field.to_riak(end_value)) return descriptor.index_name, start_value, end_value class VumiMapReduceError(Exception): pass class VumiMapReduce(object): def __init__(self, mgr, riak_mapreduce_obj): self._has_run = False self._manager = mgr self._riak_mapreduce_obj = riak_mapreduce_obj @classmethod def from_field(cls, mgr, model, field_name, start_value, end_value=None): index_name, sv, ev = index_vals_for_field( model, field_name, start_value, end_value) return cls.from_index(mgr, model, index_name, sv, ev) @classmethod def from_index(cls, mgr, model, index_name, start_value, end_value=None): return cls(mgr, mgr.riak_map_reduce().index( mgr.bucket_name(model), index_name, start_value, end_value)) @classmethod def from_search(cls, mgr, model, query): return cls( mgr, mgr.riak_map_reduce().search(mgr.bucket_name(model), query)) @classmethod def from_field_match(cls, mgr, model, query, field_name, start_value, end_value=None): index_name, sv, ev = index_vals_for_field( model, field_name, start_value, end_value) return cls.from_index_match(mgr, model, query, index_name, sv, ev) @classmethod def from_index_match(cls, mgr, model, query, index_name, start_value, end_value=None): """ Do a regex OR search across the keys found in a secondary index. :param Manager mgr: The manager to use. :param Model model: The model to use. :param query: A list of dictionaries to use to search with. The dictionary is in the following format: { "key": "the key to lookup value for in the JSON dictionary", "pattern": "the regex pattern the value of `key` should match", "flags": "the modifier flags to give to the RegExp object", } :param str index_name: The name of the index :param str start_value: The start value to search the 2i on :param str end_value: The end value to search on. Defaults to `None`. 
""" mr = mgr.riak_map_reduce().index( mgr.bucket_name(model), index_name, start_value, end_value).map( """ function(value, keyData, arg) { /* skip deleted values, might show up during a test */ var values = value.values.filter(function(val) { return !val.metadata['X-Riak-Deleted']; }); if(values.length) { var data = JSON.parse(values[0].data); for (j in arg) { var query = arg[j]; var content = data[query.key]; var regex = RegExp(query.pattern, query.flags) if(content && regex.test(content)) { return [value.key]; } } } return []; } """, { 'arg': query, # Client lib turns this to JSON for us. }) return cls(mgr, mr) @classmethod def from_keys(cls, mgr, model, keys): bucket_name = mgr.bucket_name(model) mr = mgr.riak_map_reduce() for key in keys: mr.add_bucket_key_data(bucket_name, key, None) return cls(mgr, mr) def _assert_not_run(self): if self._has_run: raise VumiMapReduceError("This mapreduce has already run.") self._has_run = True def filter_not_found(self): self._riak_mapreduce_obj.map(function=""" function(v) { values = v.values.filter(function(val) { return !val.metadata['X-Riak-Deleted'] }) if (values) { return [v.key]; } else { return []; } }""") self._riak_mapreduce_obj.filter_not_found() def get_count(self): self._assert_not_run() self._riak_mapreduce_obj.reduce( function=["riak_kv_mapreduce", "reduce_count_inputs"]) return self._manager.run_map_reduce( self._riak_mapreduce_obj, reducer_func=lambda mgr, obj: obj[0]) def _results_to_keys(self, mgr, obj): if isinstance(obj, basestring): # Assume strings are keys. return obj else: # If we haven't been given a string, we probably have a riak link. _bucket, key, _tag = obj return key def get_keys(self): self._assert_not_run() return self._manager.run_map_reduce( self._riak_mapreduce_obj, self._results_to_keys) class Manager(object): """A wrapper around a Riak client.""" DEFAULT_LOAD_BUNCH_SIZE = 100 DEFAULT_MAPREDUCE_TIMEOUT = 4 * 60 * 1000 # in milliseconds # This is a temporary measure to give us an easy way to switch back to the # old mechanism if the new one causes problems. USE_MAPREDUCE_BUNCH_LOADING = False def __init__(self, client, bucket_prefix, load_bunch_size=None, mapreduce_timeout=None, store_versions=None, parent=None): self.client = client self.bucket_prefix = bucket_prefix self.load_bunch_size = load_bunch_size or self.DEFAULT_LOAD_BUNCH_SIZE self.mapreduce_timeout = (mapreduce_timeout or self.DEFAULT_MAPREDUCE_TIMEOUT) self._bucket_cache = {} self.store_versions = store_versions or {} self._parent = parent def proxy(self, modelcls): return ModelProxy(self, modelcls) def sub_manager(self, sub_prefix): return self.__class__( self.client, self.bucket_prefix + sub_prefix, parent=self) def bucket_name(self, modelcls_or_obj): return self.bucket_prefix + modelcls_or_obj.bucket def bucket_for_modelcls(self, modelcls): modelcls_id = id(modelcls) bucket = self._bucket_cache.get(modelcls_id) if bucket is None: bucket_name = self.bucket_name(modelcls) bucket = self.riak_bucket(bucket_name) self._bucket_cache[modelcls_id] = bucket return bucket @staticmethod def calls_manager(manager_attr): """Decorate a method that calls a manager. This redecorates with the `call_decorator` attribute on the Manager subclass used, which should be either @inlineCallbacks or @flatten_generator. """ if callable(manager_attr): # If we don't get a manager attribute name, default to 'manager'. 
return Manager.calls_manager('manager')(manager_attr) def redecorate(func): @wraps(func) def wrapper(self, *args, **kw): manager = getattr(self, manager_attr) return manager.call_decorator(func)(self, *args, **kw) return wrapper return redecorate @classmethod def from_config(cls, config): """Construct a manager from a dictionary of options. :param dict config: Dictionary of options for the manager. """ raise NotImplementedError("Sub-classes of Manager should implement" " .from_config(...)") def close_manager(self): """Close the client underlying this manager instance, if necessary. """ raise NotImplementedError("Sub-classes of Manager should implement" " .close_manager(...)") def riak_object(self, cls, key): """Construct an empty RiakObject for the given model class and key.""" raise NotImplementedError("Sub-classes of Manager should implement" " .riak_object(...)") def store(self, modelobj): """Store the modelobj in Riak.""" raise NotImplementedError("Sub-classes of Manager should implement" " .store(...)") def delete(self, modelobj): """Delete the modelobj from Riak.""" raise NotImplementedError("Sub-classes of Manager should implement" " .delete(...)") def load(self, cls, key, result=None): """Load a model instance for the key from Riak. If the key doesn't exist, this method should return None instead of an instance of cls. """ raise NotImplementedError("Sub-classes of Manager should implement" " .load(...)") def _migrate_riak_object(self, modelcls, key, riak_object): """ Migrate a loaded riak_object to the latest schema version. NOTE: This should only be called by subclasses. """ was_migrated = False # Run migrators until we have the correct version of the data. while riak_object.get_data() is not None: data_version = riak_object.get_data().get('$VERSION', None) if data_version == modelcls.VERSION: obj = modelcls(self, key, _riak_object=riak_object) obj.was_migrated = was_migrated return obj migrator = modelcls.MIGRATOR(modelcls, self, data_version) riak_object = migrator(riak_object).get_riak_object() was_migrated = True return None def _reverse_migrate_riak_object(self, modelobj): """ Migrate a riak_object to the required schema version before storing. NOTE: This should only be called by subclasses. """ riak_object = modelobj._riak_object modelcls = type(modelobj) model_name = "%s.%s" % (modelcls.__module__, modelcls.__name__) store_version = self.store_versions.get(model_name, modelcls.VERSION) # Run reverse migrators until we have the correct version of the data. data_version = riak_object.get_data().get('$VERSION', None) while data_version != store_version: migrator = modelcls.MIGRATOR( modelcls, self, data_version, reverse=True) riak_object = migrator(riak_object).get_riak_object() data_version = riak_object.get_data().get('$VERSION', None) return riak_object def _load_multiple(self, cls, keys): """Load the model instances for a batch of keys from Riak. If a key doesn't exist, no object will be returned for it. """ raise NotImplementedError("Sub-classes of Manager should implement" " ._load_multiple(...)") def _load_bunch_mapreduce(self, model, keys): """Load the model instances for a batch of keys from Riak. If a key doesn't exist, no object will be returned for it. 
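The map phase below skips Riak tombstones and emits ``[key, value]``
pairs, which are then turned into model instances via
:meth:`Model.load`.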
""" mr = self.mr_from_keys(model, keys) mr._riak_mapreduce_obj.map(function=""" function (v) { values = v.values.filter(function(val) { return !val.metadata['X-Riak-Deleted']; }) if (!values.length) { return []; } return [[v.key, values[0]]] } """).filter_not_found() return self.run_map_reduce( mr._riak_mapreduce_obj, lambda mgr, obj: model.load(mgr, *obj)) def _load_bunch(self, model, keys): """Load the model instances for a batch of keys from Riak. If a key doesn't exist, no object will be returned for it. """ assert len(keys) <= self.load_bunch_size if not keys: return [] if self.USE_MAPREDUCE_BUNCH_LOADING: return self._load_bunch_mapreduce(model, keys) else: return self._load_multiple(model, keys) def load_all_bunches(self, model, keys): """Load batches of model instances for a list of keys from Riak. :returns: An iterator over (possibly deferred) lists of model instances. """ while keys: batch_keys = keys[:self.load_bunch_size] keys = keys[self.load_bunch_size:] yield self._load_bunch(model, batch_keys) def riak_map_reduce(self): """Construct a RiakMapReduce object for this client.""" raise NotImplementedError("Sub-classes of Manager should implement" " .riak_map_reduce(...)") def run_map_reduce(self, mapreduce, mapper_func=None, reducer_func=None): """Run a map reduce instance and return the results mapped to objects by the map_function.""" raise NotImplementedError("Sub-classes of Manager should implement" " .run_map_reduce(...)") def should_quote_index_values(self): raise NotImplementedError("Sub-classes of Manager should implement" " .should_quote_index_values()") def index_keys(self, model, index_name, start_value, end_value=None, return_terms=None): bucket = self.bucket_for_modelcls(model) if self.should_quote_index_values(): if start_value is not None: start_value = urllib.quote(start_value) if end_value is not None: end_value = urllib.quote(end_value) return bucket.get_index( index_name, start_value, end_value, return_terms=return_terms) def index_keys_page(self, model, index_name, start_value, end_value=None, return_terms=None, max_results=None, continuation=None): bucket = self.bucket_for_modelcls(model) if self.should_quote_index_values(): if start_value is not None: start_value = urllib.quote(start_value) if end_value is not None: end_value = urllib.quote(end_value) return bucket.get_index_page( index_name, start_value, end_value, return_terms=return_terms, max_results=max_results, continuation=continuation) def mr_from_field(self, model, field_name, start_value, end_value=None): return VumiMapReduce.from_field( self, model, field_name, start_value, end_value) def mr_from_index(self, model, index_name, start_value, end_value=None): return VumiMapReduce.from_index( self, model, index_name, start_value, end_value) def mr_from_search(self, model, query): return VumiMapReduce.from_search(self, model, query) def mr_from_index_match(self, model, query, index_name, start_value, end_value=None): return VumiMapReduce.from_index_match(self, model, query, index_name, start_value, end_value) def mr_from_field_match(self, model, query, field_name, start_value, end_value=None): return VumiMapReduce.from_field_match(self, model, query, field_name, start_value, end_value) def mr_from_keys(self, model, keys): return VumiMapReduce.from_keys(self, model, keys) def real_search(self, model, query, rows=None, start=None): raise NotImplementedError() def riak_enable_search(self, model): """Enable search indexing for the model's bucket.""" raise NotImplementedError("Sub-classes of Manager should 
implement" " .riak_enable_search(...)") def purge_all(self): """Delete *ALL* keys in buckets whose names start buckets with this manager's bucket prefix. Use only in tests. """ raise NotImplementedError("Sub-classes of Manager should implement" " .purge_all()") class ModelProxy(object): def __init__(self, manager, modelcls): self._manager = manager self._modelcls = modelcls self.bucket = modelcls.bucket def __call__(self, key, **data): return self._modelcls(self._manager, key, **data) def load(self, key): return self._modelcls.load(self._manager, key) def load_all_bunches(self, *args, **kw): return self._modelcls.load_all_bunches(self._manager, *args, **kw) def all_keys(self): return self._modelcls.all_keys(self._manager) def index_keys(self, field_name, value, end_value=None, return_terms=None): return self._modelcls.index_keys( self._manager, field_name, value, end_value, return_terms=return_terms) def all_keys_page(self, max_results=None, continuation=None): return self._modelcls.all_keys_page( self._manager, max_results=max_results, continuation=continuation) def index_keys_page(self, field_name, value, end_value=None, return_terms=None, max_results=None, continuation=None): return self._modelcls.index_keys_page( self._manager, field_name, value, end_value, return_terms=return_terms, max_results=max_results, continuation=continuation) def index_lookup(self, field_name, value): return self._modelcls.index_lookup(self._manager, field_name, value) def index_match(self, query, field_name, value): return self._modelcls.index_match(self._manager, query, field_name, value) def search(self, **kw): return self._modelcls.search(self._manager, **kw) def raw_search(self, query): return self._modelcls.raw_search(self._manager, query) def real_search(self, query, rows=None, start=None): return self._modelcls.real_search( self._manager, query, rows=rows, start=start) def enable_search(self): return self._modelcls.enable_search(self._manager) PK=JG2vumi/persist/riak_base.py"""Basic tools for building a Riak manager.""" import json from riak import RiakClient from vumi.persist.model import VumiRiakError def _to_unicode(text, encoding='utf-8'): # If we already have unicode or `None`, there's nothing to do. if isinstance(text, (unicode, type(None))): return text # If we have a tuple, we need to do our thing with every element in it. if isinstance(text, tuple): return tuple(_to_unicode(item, encoding) for item in text) # If we get here, then we should have a bytestring. return text.decode(encoding) class VumiRiakClientBase(object): """ Wrapper around a RiakClient to manage resources better. """ def __init__(self, **client_args): self._closed = False self._raw_client = RiakClient(**client_args) # Some versions of the riak client library use simplejson by # preference, which breaks some of our unicode assumptions. This makes # sure we're using stdlib json which doesn't sometimes return # bytestrings instead of unicode. self._client.set_encoder('application/json', json.dumps) self._client.set_encoder('text/json', json.dumps) self._client.set_decoder('application/json', json.loads) self._client.set_decoder('text/json', json.loads) @property def protocol(self): return self._raw_client.protocol @property def _client(self): """ Raise an exception if closed, otherwise return underlying client. 
""" if self._closed: raise VumiRiakError("Can't use closed Riak client.") return self._raw_client def close(self): self._closed = True return self._raw_client.close() def bucket(self, bucket_name): return self._client.bucket(bucket_name) def put(self, *args, **kw): return self._client.put(*args, **kw) def get(self, *args, **kw): return self._client.get(*args, **kw) def delete(self, *args, **kw): return self._client.delete(*args, **kw) def mapred(self, *args, **kw): return self._client.mapred(*args, **kw) def _purge_all(self, bucket_prefix): """ Purge all objects and buckets properties belonging to buckets with the given prefix. NOTE: This operation should *ONLY* be used in tests. """ # We need to use a potentially closed client here, so we bypass the # check and reclose afterwards if necessary. buckets = self._raw_client.get_buckets() for bucket in buckets: if bucket.name.startswith(bucket_prefix): for key in bucket.get_keys(): obj = bucket.get(key) obj.delete() bucket.clear_properties() if self._closed: self.close() class VumiIndexPageBase(object): """ Wrapper around a page of index query results. Iterating over this object will return the results for the current page. """ def __init__(self, index_page): self._index_page = index_page def __iter__(self): if self._index_page.stream: raise NotImplementedError("Streaming is not currently supported.") return (_to_unicode(item) for item in self._index_page) def __len__(self): return len(self._index_page) def __eq__(self, other): return self._index_page.__eq__(other) def has_next_page(self): """ Indicate whether there are more results to follow. :returns: ``True`` if there are more results, ``False`` if this is the last page. """ return self._index_page.has_next_page() @property def continuation(self): return _to_unicode(self._index_page.continuation) # Methods that touch the network. def next_page(self): raise NotImplementedError("Subclasses must implement this.") class VumiRiakBucketBase(object): """ Wrapper around a RiakBucket to manage network access better. """ def __init__(self, riak_bucket): self._riak_bucket = riak_bucket def get_name(self): return self._riak_bucket.name # Methods that touch the network. def get_index(self, index_name, start_value, end_value=None, return_terms=None): raise NotImplementedError("Subclasses must implement this.") def get_index_page(self, index_name, start_value, end_value=None, return_terms=None, max_results=None, continuation=None): raise NotImplementedError("Subclasses must implement this.") class VumiRiakObjectBase(object): """ Wrapper around a RiakObject to manage network access better. 
""" def __init__(self, riak_obj): self._riak_obj = riak_obj @property def key(self): return self._riak_obj.key def get_key(self): return self.key def get_content_type(self): return self._riak_obj.content_type def set_content_type(self, content_type): self._riak_obj.content_type = content_type def get_data(self): return self._riak_obj.data def set_data(self, data): self._riak_obj.data = data def set_encoded_data(self, encoded_data): self._riak_obj.encoded_data = encoded_data def set_data_field(self, key, value): self._riak_obj.data[key] = value def delete_data_field(self, key): del self._riak_obj.data[key] def get_indexes(self): return self._riak_obj.indexes def set_indexes(self, indexes): self._riak_obj.indexes = indexes def add_index(self, index_name, index_value): self._riak_obj.add_index(index_name, index_value) def remove_index(self, index_name, index_value=None): self._riak_obj.remove_index(index_name, index_value) def get_user_metadata(self): return self._riak_obj.usermeta def set_user_metadata(self, usermeta): self._riak_obj.usermeta = usermeta def get_bucket(self): raise NotImplementedError("Subclasses must implement this.") # Methods that touch the network. def _call_and_wrap(self, func): """ Call a function that touches the network and wrap the result in this class. """ raise NotImplementedError("Subclasses must implement this.") def store(self): return self._call_and_wrap(self._riak_obj.store) def reload(self): return self._call_and_wrap(self._riak_obj.reload) def delete(self): return self._call_and_wrap(self._riak_obj.delete) PK[Hz++vumi/persist/txredis_manager.py# -*- test-case-name: vumi.persist.tests.test_txredis_manager -*- # txredis is made of silliness. # There are two variants, both of which call themselves version 2.2. One has # everything in txredis.protocol, the other has the client stuff in # txredis.client. try: import txredis.client as txrc txr = txrc except ImportError: import txredis.protocol as txrp txr = txrp import txredis.exceptions from twisted.internet import reactor from twisted.internet.defer import inlineCallbacks, succeed, Deferred from vumi.persist.redis_base import Manager from vumi.persist.fake_redis import ( FakeRedis, ResponseError as FakeResponseError) class VumiRedis(txr.Redis): """Wrapper around txredis to make it more suitable for our needs. Aside from the various API operations we need to implement to match the other redis client, we add a deferred that fires when we've finished connecting to the redis server. This avoids problems with trying to use a client that hasn't completely connected yet. TODO: We need to find a way to test this stuff """ def __init__(self, *args, **kw): super(VumiRedis, self).__init__(*args, **kw) self.connected_d = Deferred() self._disconnected_d = Deferred() self._client_shutdown_called = False def connectionMade(self): d = super(VumiRedis, self).connectionMade() d.addCallback(lambda _: self) return d.chainDeferred(self.connected_d) def connectionLost(self, reason): super(VumiRedis, self).connectionLost(reason) self._disconnected_d.callback(None) def _client_shutdown(self): """ Issue a ``QUIT`` command and wait for the connection to close. A single client may be used by multiple manager instances, so we only issue the ``QUIT`` once. This still leaves us with a potential race condition if the connection is being used elsewhere, but we can't do anything useful about that here. 
""" self.factory.stopTrying() d = succeed(None) if not self._client_shutdown_called: self._client_shutdown_called = True d.addCallback(lambda _: self.quit()) return d.addCallback(lambda _: self._disconnected_d) def _ok_to_true(self, r): """ Some commands return 'OK', but we expect True. """ return True if r == 'OK' else r def hget(self, key, field): d = super(VumiRedis, self).hget(key, field) d.addCallback(lambda r: r.get(field) if r else None) return d def lrem(self, key, value, num=0): return super(VumiRedis, self).lrem(key, value, count=num) def ltrim(self, key, start, end): d = super(VumiRedis, self).ltrim(key, start, end) d.addCallback(self._ok_to_true) return d # lpop() and rpop() are implemented in txredis 2.2.1 (which is in Ubuntu), # but not 2.2 (which is in pypi). Annoyingly, pop() in 2.2.1 calls lpop() # and rpop(), so we can't just delegate to that as we did before. def rpop(self, key): self._send('RPOP', key) return self.getResponse() def lpop(self, key): self._send('LPOP', key) return self.getResponse() def set(self, key, value, *args, **kw): d = super(VumiRedis, self).set(key, value, *args, **kw) d.addCallback(self._ok_to_true) return d def setex(self, key, seconds, value): return self.set(key, value, expire=seconds) def rename(self, key, newkey): d = super(VumiRedis, self).rename(key, newkey) d.addCallback(self._ok_to_true) return d # setnx() is implemented in txredis 2.2.1 (which is in Ubuntu), but not 2.2 # (which is in pypi). Annoyingly, set() in 2.2.1 calls setnx(), so we can't # just delegate to that as we did before. def setnx(self, key, value): self._send('SETNX', key, value) return self.getResponse() def zadd(self, key, *args, **kwargs): if args: if len(args) % 2 != 0: raise ValueError("ZADD requires an equal number of " "values and scores") pieces = zip(args[::2], args[1::2]) pieces.extend(kwargs.iteritems()) orig_zadd = super(VumiRedis, self).zadd d = succeed(0) def do_zadd(s, key, member, score): d = orig_zadd(key, member, score) d.addCallback(lambda r: r + s) return d for member, score in pieces: d.addCallback(do_zadd, key, member, score) return d def zrange(self, key, start, end, desc=False, withscores=False): return super(VumiRedis, self).zrange(key, start, end, withscores=withscores, reverse=desc) def zrangebyscore(self, key, min, max, start=None, num=None, withscores=False, score_cast_func=float): d = super(VumiRedis, self).zrangebyscore( key, min, max, offset=start, count=num, withscores=withscores) if withscores: d.addCallback(lambda r: [(v, score_cast_func(s)) for v, s in r]) return d def scan(self, cursor, match=None, count=None): """ Scan through all the keys in the database returning those that match the pattern ``match``. The ``cursor`` specifies where to start a scan and ``count`` determines how much work to do looking for keys on each scan. ``cursor`` may be ``None`` or ``'0'`` to indicate a new scan. Any other value should be treated as an opaque string. .. note:: Requires redis server 2.8 or later. """ args = [] if cursor is None: cursor = '0' if match is not None: args.extend(("MATCH", match)) if count is not None: args.extend(("COUNT", count)) self._send("SCAN", cursor, *args) d = self.getResponse() d.addCallback( lambda r: [(None if r[0] == '0' or r[0] == 0 else r[0]), r[1]]) return d def ttl(self, key): # Synchronous redis returns None if -1 or -2 is returned but # txredis doesn't. Older sync redis' return -2 if the key does not # exist so we require redis >= 2.7.1 in setup.py (WAT). 
d = super(VumiRedis, self).ttl(key) d.addCallback(lambda r: (None if r < 0 else r)) return d # txredis doesn't implement this. def persist(self, key): """ Remove the expiration from a key, causing it to persist indefinitely. """ self._send('PERSIST', key) return self.getResponse() def type(self, key): d = self.get_type(key) # txredis turns 'none' into None, so we reverse that for consistency. d.addCallback(lambda r: r if r is not None else 'none') return d # txredis doesn't implement this. def pfadd(self, key, *values): """ Add the values to the HyperLogLog data structure at the given key. .. note:: Requires redis server 2.8.9 or later. """ self._send('PFADD', key, *values) return self.getResponse() # txredis doesn't implement this. def pfcount(self, key): """ Return the approximate cardinality of the HyperLogLog at the given key. .. note:: Requires redis server 2.8.9 or later. """ self._send('PFCOUNT', key) return self.getResponse() class VumiRedisClientFactory(txr.RedisClientFactory): protocol = VumiRedis # Faster reconnecting. maxDelay = 5.0 initialDelay = 0.01 def buildProtocol(self, addr): self.client = self.protocol(*self._args, **self._kwargs) self.client.factory = self self.resetDelay() prev_d, self.deferred = self.deferred, Deferred() prev_d.callback(self.client) return self.client class TxRedisManager(Manager): call_decorator = staticmethod(inlineCallbacks) RESPONSE_ERROR = txredis.exceptions.ResponseError def __init__(self, *args, **kwargs): super(TxRedisManager, self).__init__(*args, **kwargs) @classmethod def _fake_manager(cls, fake_redis, manager_config): if fake_redis is None: fake_redis = FakeRedis(async=True) manager_config['config']['FAKE_REDIS'] = fake_redis manager = cls(fake_redis, **manager_config) # Because ._close() assumes a real connection. manager._close = fake_redis.teardown manager.RESPONSE_ERROR = FakeResponseError return succeed(manager) @classmethod def _manager_from_config(cls, client_config, manager_config): """Construct a manager from a dictionary of options. :param dict config: Dictionary of options for the manager. :param str key_prefix: Key prefix for namespacing. """ host = client_config.pop('host', '127.0.0.1') port = client_config.pop('port', 6379) factory = VumiRedisClientFactory(**client_config) reactor.connectTCP(host, port, factory) d = factory.deferred.addCallback(lambda client: client.connected_d) d.addCallback(cls._make_manager, manager_config) return d @classmethod def _make_manager(cls, client, manager_config): manager = cls(client, **manager_config) cls._attach_reconnector(manager) return manager def set_client(self, client): self._client_proxy.client = client return client @staticmethod def _attach_reconnector(manager): def set_client(client): return manager.set_client(client) def reconnect(client): client.factory.deferred.addCallback(reconnect) return client.connected_d.addCallback(set_client) manager._client.factory.deferred.addCallback(reconnect) return manager def _close(self): """ Close redis connection. """ return self._client._client_shutdown() @inlineCallbacks def _purge_all(self): """Delete *ALL* keys whose names start with this manager's key prefix. Use only in tests. """ # Given the races around connection closing, the easiest thing to do # here is to create a new manager with the same config for cleanup # operations. new_manager = yield self.from_config(self._config) # If we're a submanager we might have a different key prefix. 
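# (from_config gave the new manager the key prefix from the config
# dict, which is not necessarily ours, so copy ours across first.)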
new_manager._key_prefix = self._key_prefix yield new_manager._do_purge() yield new_manager._close() @inlineCallbacks def _do_purge(self): for key in (yield self.keys()): yield self.delete(key) def _make_redis_call(self, call, *args, **kw): """Make a redis API call using the underlying client library. """ return getattr(self._client, call)(*args, **kw) def _filter_redis_results(self, func, results): """Filter results of a redis call. """ return results.addCallback(func) PK*gcH vumi/persist/redis_manager.py# -*- test-case-name: vumi.persist.tests.test_redis_manager -*- import redis import redis.exceptions from vumi.persist.redis_base import Manager from vumi.persist.fake_redis import ( FakeRedis, ResponseError as FakeResponseError) from vumi.utils import flatten_generator class VumiRedis(redis.Redis): """ Custom Vumi redis client implementation. """ def setex(self, key, seconds, value): """ The underlying .setex() signature doesn't match our implementation in the txredis manager. This wrapper swaps the last two parameters, seconds and value, so that they do. """ return super(VumiRedis, self).setex(key, value, seconds) def scan(self, cursor, match=None, count=None): """ Scan through all the keys in the database returning those that match the pattern ``match``. The ``cursor`` specifies where to start a scan and ``count`` determines how much work to do looking for keys on each scan. ``cursor`` may be ``None`` or ``'0'`` to indicate a new scan. Any other value should be treated as an opaque string. .. note:: Requires redis server 2.8 or later. """ args = [] if cursor is None: cursor = '0' if match is not None: args.extend(("MATCH", match)) if count is not None: args.extend(("COUNT", count)) cursor, keys = self.execute_command("SCAN", cursor, *args) if cursor == '0' or cursor == 0: cursor = None return (cursor, keys) class RedisManager(Manager): RESPONSE_ERROR = redis.exceptions.ResponseError call_decorator = staticmethod(flatten_generator) @classmethod def _fake_manager(cls, fake_redis, manager_config): if fake_redis is None: fake_redis = FakeRedis(async=False) manager_config['config']['FAKE_REDIS'] = fake_redis manager = cls(fake_redis, **manager_config) # Because ._close() assumes a real connection. manager._close = fake_redis.teardown manager.RESPONSE_ERROR = FakeResponseError return manager @classmethod def _manager_from_config(cls, config, manager_config): """Construct a manager from a dictionary of options. :param dict config: Dictionary of options for the manager. :param str key_prefix: Key prefix for namespacing. """ return cls(VumiRedis(**config), **manager_config) def _close(self): """Close redis connection.""" # Close all the connections this client may have open. self._client.connection_pool.disconnect() def _purge_all(self): """Delete *ALL* keys whose names start with this manager's key prefix. Use only in tests. """ for key in self.keys(): self.delete(key) def _make_redis_call(self, call, *args, **kw): """Make a redis API call using the underlying client library. """ return getattr(self._client, call)(*args, **kw) def _filter_redis_results(self, func, results): """Filter results of a redis call. 
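Unlike the async manager, results here are plain values rather than
Deferreds, so the filter is simply applied. E.g. (hypothetical)::

    manager._filter_redis_results(int, '42')  # -> 42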
""" return func(results) PK=JG?##vumi/persist/txriak_manager.py# -*- test-case-name: vumi.persist.tests.test_txriak_manager -*- """An async manager implementation on top of the riak Python package.""" from riak import RiakObject, RiakMapReduce, RiakError from twisted.internet.threads import deferToThread from twisted.internet.defer import ( inlineCallbacks, returnValue, gatherResults, maybeDeferred, succeed) from vumi.persist.model import Manager, VumiRiakError from vumi.persist.riak_base import ( VumiRiakClientBase, VumiIndexPageBase, VumiRiakBucketBase, VumiRiakObjectBase) def riakErrorHandler(failure): e = failure.trap(RiakError) raise VumiRiakError(e) class VumiTxRiakClient(VumiRiakClientBase): """ Wrapper around a RiakClient to manage resources better. """ class VumiTxIndexPage(VumiIndexPageBase): """ Wrapper around a page of index query results. Iterating over this object will return the results for the current page. """ # Methods that touch the network. def next_page(self): """ Fetch the next page of results. :returns: A new :class:`VumiTxIndexPage` object containing the next page of results. """ if not self.has_next_page(): return succeed(None) d = deferToThread(self._index_page.next_page) d.addCallback(type(self)) d.addErrback(riakErrorHandler) return d class VumiTxRiakBucket(VumiRiakBucketBase): """ Wrapper around a RiakBucket to manage network access better. """ # Methods that touch the network. def get_index(self, index_name, start_value, end_value=None, return_terms=None): d = self.get_index_page( index_name, start_value, end_value, return_terms=return_terms) d.addCallback(list) return d def get_index_page(self, index_name, start_value, end_value=None, return_terms=None, max_results=None, continuation=None): d = deferToThread( self._riak_bucket.get_index, index_name, start_value, end_value, return_terms=return_terms, max_results=max_results, continuation=continuation) d.addCallback(VumiTxIndexPage) d.addErrback(riakErrorHandler) return d class VumiTxRiakObject(VumiRiakObjectBase): """ Wrapper around a RiakObject to manage network access better. """ def get_bucket(self): return VumiTxRiakBucket(self._riak_obj.bucket) # Methods that touch the network. def _call_and_wrap(self, func): """ Call a function that touches the network and wrap the result in this class. 
""" d = deferToThread(func) d.addCallback(type(self)) return d class TxRiakManager(Manager): """An async persistence manager for the riak Python package.""" call_decorator = staticmethod(inlineCallbacks) @classmethod def from_config(cls, config): config = config.copy() bucket_prefix = config.pop('bucket_prefix') load_bunch_size = config.pop( 'load_bunch_size', cls.DEFAULT_LOAD_BUNCH_SIZE) mapreduce_timeout = config.pop( 'mapreduce_timeout', cls.DEFAULT_MAPREDUCE_TIMEOUT) transport_type = config.pop('transport_type', 'http') store_versions = config.pop('store_versions', None) host = config.get('host', '127.0.0.1') port = config.get('port') prefix = config.get('prefix', 'riak') mapred_prefix = config.get('mapred_prefix', 'mapred') client_id = config.get('client_id') transport_options = config.get('transport_options', {}) client_args = dict( host=host, prefix=prefix, mapred_prefix=mapred_prefix, protocol=transport_type, client_id=client_id, transport_options=transport_options) if port is not None: client_args['port'] = port client = VumiTxRiakClient(**client_args) return cls( client, bucket_prefix, load_bunch_size=load_bunch_size, mapreduce_timeout=mapreduce_timeout, store_versions=store_versions) def close_manager(self): if self._parent is None: # Only top-level managers may close the client. return deferToThread(self.client.close) return succeed(None) def _is_unclosed(self): # This returns `True` if the manager needs to be explicitly closed and # hasn't been closed yet. It should only be used in tests that ensure # client objects aren't leaked. if self._parent is not None: return False return not self.client._closed def riak_bucket(self, bucket_name): bucket = self.client.bucket(bucket_name) if bucket is not None: bucket = VumiTxRiakBucket(bucket) return bucket def riak_object(self, modelcls, key, result=None): bucket = self.bucket_for_modelcls(modelcls)._riak_bucket riak_object = VumiTxRiakObject(RiakObject(self.client, bucket, key)) if result: metadata = result['metadata'] indexes = metadata['index'] if hasattr(indexes, 'items'): # TODO: I think this is a Riak bug. In some cases # (maybe when there are no indexes?) the index # comes back as a list, in others (maybe when # there are indexes?) it comes back as a dict. indexes = indexes.items() data = result['data'] riak_object.set_content_type(metadata['content-type']) riak_object.set_indexes(indexes) riak_object.set_encoded_data(data) else: riak_object.set_content_type("application/json") riak_object.set_data({'$VERSION': modelcls.VERSION}) return riak_object def store(self, modelobj): riak_object = self._reverse_migrate_riak_object(modelobj) d = riak_object.store() d.addCallback(lambda _: modelobj) return d def delete(self, modelobj): d = modelobj._riak_object.delete() d.addCallback(lambda _: None) return d @inlineCallbacks def load(self, modelcls, key, result=None): riak_object = self.riak_object(modelcls, key, result) if not result: yield riak_object.reload() returnValue(self._migrate_riak_object(modelcls, key, riak_object)) def _load_multiple(self, modelcls, keys): d = gatherResults([self.load(modelcls, key) for key in keys]) d.addCallback(lambda objs: [obj for obj in objs if obj is not None]) return d def riak_map_reduce(self): mapreduce = RiakMapReduce(self.client) # Hack: We replace the two methods that hit the network with # deferToThread wrappers to prevent accidental sync calls in # other code. 
run = mapreduce.run stream = mapreduce.stream mapreduce.run = lambda *a, **kw: deferToThread(run, *a, **kw) mapreduce.stream = lambda *a, **kw: deferToThread(stream, *a, **kw) return mapreduce def run_map_reduce(self, mapreduce, mapper_func=None, reducer_func=None): def map_results(raw_results): deferreds = [] for row in raw_results: deferreds.append(maybeDeferred(mapper_func, self, row)) return gatherResults(deferreds) mapreduce_done = mapreduce.run(timeout=self.mapreduce_timeout) if mapper_func is not None: mapreduce_done.addCallback(map_results) if reducer_func is not None: mapreduce_done.addCallback(lambda r: reducer_func(self, r)) return mapreduce_done def _search_iteration(self, bucket, query, rows, start): d = deferToThread(bucket.search, query, rows=rows, start=start) d.addCallback(lambda r: [doc["id"] for doc in r["docs"]]) return d @inlineCallbacks def real_search(self, modelcls, query, rows=None, start=None): rows = 1000 if rows is None else rows bucket_name = self.bucket_name(modelcls) bucket = self.client.bucket(bucket_name) if start is not None: keys = yield self._search_iteration(bucket, query, rows, start) returnValue(keys) keys = [] new_keys = yield self._search_iteration(bucket, query, rows, 0) while new_keys: keys.extend(new_keys) new_keys = yield self._search_iteration( bucket, query, rows, len(keys)) returnValue(keys) def riak_enable_search(self, modelcls): bucket_name = self.bucket_name(modelcls) bucket = self.client.bucket(bucket_name) return deferToThread(bucket.enable_search) def riak_search_enabled(self, modelcls): bucket_name = self.bucket_name(modelcls) bucket = self.client.bucket(bucket_name) return deferToThread(bucket.search_enabled) def should_quote_index_values(self): return False def purge_all(self): return deferToThread(self.client._purge_all, self.bucket_prefix) PK=JGܵvumi/persist/__init__.py"""The vumi.persist API.""" PK=JGrTÎÎvumi/persist/fields.py# -*- test-case-name: vumi.persist.tests.test_fields -*- """Field types for Vumi's persistence models.""" import iso8601 from datetime import datetime from vumi.message import format_vumi_date, parse_vumi_date from vumi.utils import to_kwargs # Index values in Riak have to be non-empty, so a zero-length string # counts as "no value". Since we still have legacy data that was # inadvertantly indexed with "None" because of the str() call in the # library and we still have legacy code that relies on an index search # for a value of "None", fixing this properly here will break existing # functionality. Once we have rewritten the offending code to not use # "None" in the index, we can remove the hack below and be happier. STORE_NONE_FOR_EMPTY_INDEX = False class ValidationError(Exception): """Raised when a value assigned to a field is invalid.""" class FieldDescriptor(object): """Property for getting and setting fields.""" def __init__(self, key, field): self.key = key self.field = field if self.field.index: if self.field.index_name is None: self.index_name = "%s_bin" % self.key else: self.index_name = field.index_name else: self.index_name = None def setup(self, model_cls): self.model_cls = model_cls def validate(self, value): self.field.validate(value) def initialize(self, modelobj, value): self.validate(value) self.set_value(modelobj, value) def _add_index(self, modelobj, value): # XXX: The underlying libraries call str() on whatever index values we # provide, so we do this explicitly here and special-case None. if value is None: if STORE_NONE_FOR_EMPTY_INDEX: # Hackery for things that need "None" index values. 
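# Store the literal string "None" so legacy index queries that look
# for "None" still find these objects.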
modelobj._riak_object.add_index(self.index_name, "None") return modelobj._riak_object.add_index(self.index_name, str(value)) def get_riak_data(self, modelobj, default=None, key=None): if key is None: key = self.key return modelobj._riak_object.get_data().get(key, default) def set_riak_data(self, modelobj, raw_value, key=None): if key is None: key = self.key old_raw_value = modelobj._riak_object.get_data().get(key) # We set this even if it's "unmodified" because we want explicit `None` # values rather than missing fields. modelobj._riak_object.set_data_field(key, raw_value) if old_raw_value != raw_value: modelobj._field_changed(self.key) def delete_riak_data(self, modelobj, key=None): if key is None: key = self.key old_raw_value = modelobj._riak_object.get_data().get(key) modelobj._riak_object.delete_data_field(key) if old_raw_value is not None: modelobj._field_changed(self.key) def set_value(self, modelobj, value): """Set the value associated with this descriptor.""" raw_value = self.field.to_riak(value) self.set_riak_data(modelobj, raw_value) if self.index_name is not None: modelobj._riak_object.remove_index(self.index_name) self._add_index(modelobj, raw_value) def get_value(self, modelobj): """Get the value associated with this descriptor.""" return self.field.from_riak(self.get_riak_data(modelobj)) def clean(self, modelobj): """Do any cleanup of the model data for this descriptor after loading the data from Riak.""" pass def model_field_changed(self, modelobj, changed_field_name): """ Do any necessary computation when a field changes. """ pass def __repr__(self): return "<%s key=%s field=%r>" % (self.__class__.__name__, self.key, self.field) def __get__(self, instance, owner): if instance is None: return self.field return self.get_value(instance) def __set__(self, instance, value): # instance can never be None here self.validate(value) self.set_value(instance, value) class Field(object): """Base class for model attributes / fields. :param object default: Default value for the field. The default default is None. :param boolean null: Whether None is allowed as a value. Default is False (which means the field must either be specified explicitly or by a non-None default). :param boolen index: Whether the field should also be indexed. Default is False. :param string index_name: The name to use for the index. The default is the field name followed by _bin. """ descriptor_class = FieldDescriptor # whether an attempt should be made to initialize the field on # model instance creation initializable = True def __init__(self, default=None, null=False, index=False, index_name=None): self.default = default self.null = null self.index = index self.index_name = index_name def get_descriptor(self, key): return self.descriptor_class(key, self) def validate(self, value): """Validate a value. Checks null values and calls .validate() for non-null values. Raises ValidationError if a value is invalid. 
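For example, with the :class:`Unicode` field defined below
(hypothetical values)::

    Unicode(null=True).validate(None)  # ok, null allowed
    Unicode().validate(None)           # raises ValidationError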
""" if not self.null and value is None: raise ValidationError("None is not allowed as a value for non-null" " fields.") if value is not None: self.custom_validate(value) def custom_validate(self, value): """Check whether a non-null value is valid for this field.""" pass def to_riak(self, value): return self.custom_to_riak(value) if value is not None else None def custom_to_riak(self, value): """Convert a non-None value to something storable by Riak.""" return value def from_riak(self, raw_value): return (self.custom_from_riak(raw_value) if raw_value is not None else None) def custom_from_riak(self, raw_value): """Convert a non-None value stored by Riak to Python.""" return raw_value class Integer(Field): """Field that accepts integers. Additional parameters: :param integer min: Minimum allowed value (default is `None` which indicates no minimum). :param integer max: Maximum allowed value (default is `None` which indicates no maximum). """ def __init__(self, min=None, max=None, **kw): super(Integer, self).__init__(**kw) self.min = min self.max = max def custom_validate(self, value): if not isinstance(value, (int, long)): raise ValidationError("Value %r is not an integer." % (value,)) if self.min is not None and value < self.min: raise ValidationError("Value %r too low (minimum value is %d)." % (value, self.min)) if self.max is not None and value > self.max: raise ValidationError("Value %r too high (maximum value is %d)." % (value, self.max)) class Boolean(Field): """Field that is either True or False. """ def custom_validate(self, value): if not isinstance(value, bool): raise ValidationError('Value %r is not a boolean.' % (value,)) class Unicode(Field): """Field that accepts unicode strings. Additional parameters: :param integer max_length: Maximum allowed length (default is `None` which indicates no maximum). """ def __init__(self, max_length=None, **kw): super(Unicode, self).__init__(**kw) self.max_length = max_length def custom_validate(self, value): if not isinstance(value, unicode): raise ValidationError("Value %r is not a unicode string." % (value,)) if self.max_length is not None and len(value) > self.max_length: raise ValidationError("Value %r too long (maximum length is %d)." 
% (value, self.max_length)) class Tag(Field): """Field that represents a Vumi tag.""" def custom_validate(self, value): if not isinstance(value, tuple) or len(value) != 2: raise ValidationError("Tags %r should be a (pool, tag_name)" " tuple" % (value,)) def custom_to_riak(self, value): return list(value) def custom_from_riak(self, value): return tuple(value) class TimestampDescriptor(FieldDescriptor): """A field descriptor for timestamp fields.""" def set_value(self, modelobj, value): if value is not None and not isinstance(value, datetime): # we can be sure that this is a iso8601 parseable string, since it # passed validation value = iso8601.parse_date(value) super(TimestampDescriptor, self).set_value(modelobj, value) class Timestamp(Field): """Field that stores a datetime.""" descriptor_class = TimestampDescriptor def custom_validate(self, value): if isinstance(value, datetime): return try: iso8601.parse_date(value) return except iso8601.ParseError: pass raise ValidationError("Timestamp field expects a datetime or an " "iso8601 formatted string.") def custom_to_riak(self, value): return format_vumi_date(value) def custom_from_riak(self, value): return parse_vumi_date(value) class Json(Field): """Field that stores an object that can be serialized to/from JSON.""" pass class VumiMessageDescriptor(FieldDescriptor): """Property for getting and setting fields.""" def setup(self, model_cls): super(VumiMessageDescriptor, self).setup(model_cls) self.message_class = self.field.message_class if self.field.prefix is None: self.prefix = "%s." % self.key else: self.prefix = self.field.prefix def _clear_keys(self, modelobj): for key in modelobj._riak_object.get_data().keys(): if key.startswith(self.prefix): self.delete_riak_data(modelobj, key) def _timestamp_to_json(self, dt): return format_vumi_date(dt) def _timestamp_from_json(self, value): return parse_vumi_date(value) def set_value(self, modelobj, msg): """Set the value associated with this descriptor.""" self._clear_keys(modelobj) if msg is None: return for key, value in msg.payload.iteritems(): if key == self.message_class._CACHE_ATTRIBUTE: continue # TODO: timestamp as datetime in payload must die. if key == "timestamp": value = self._timestamp_to_json(value) full_key = "%s%s" % (self.prefix, key) self.set_riak_data(modelobj, value, full_key) def get_value(self, modelobj): """Get the value associated with this descriptor.""" payload = {} for key, value in modelobj._riak_object.get_data().iteritems(): if key.startswith(self.prefix): key = key[len(self.prefix):] # TODO: timestamp as datetime in payload must die. if key == "timestamp": value = self._timestamp_from_json(value) payload[key] = value if not payload: return None return self.field.message_class(**to_kwargs(payload)) class VumiMessage(Field): """Field that represents a Vumi message. Additional parameters: :param class message_class: The class of the message objects being stored. Usually one of Message, TransportUserMessage or TransportEvent. :param string prefix: The prefix to use when storing message payload keys in Riak. Default is the name of the field followed by a dot ('.'). Note:: The special message attribute ``__cache__`` is not stored by this field. 
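A hypothetical model definition using this field::

    from vumi.message import TransportUserMessage
    from vumi.persist.model import Model

    class StoredMessage(Model):
        msg = VumiMessage(TransportUserMessage)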
""" descriptor_class = VumiMessageDescriptor def __init__(self, message_class, prefix=None, **kw): super(VumiMessage, self).__init__(**kw) self.message_class = message_class self.prefix = prefix def custom_validate(self, value): if not isinstance(value, self.message_class): raise ValidationError("Message %r should be an instance of %r" % (value, self.message_class)) class FieldWithSubtype(Field): """Base class for a field that is a collection of other fields of a single type. :param Field field_type: The field specification for the dynamic values. Default is Unicode(). """ def __init__(self, field_type=None, **kw): super(FieldWithSubtype, self).__init__(**kw) if field_type is None: field_type = Unicode() if field_type.descriptor_class is not FieldDescriptor: raise RuntimeError("Dynamic fields only supports fields that" " that use the basic FieldDescriptor class") self.field_type = field_type def validate_subfield(self, value): self.field_type.validate(value) def subfield_to_riak(self, value): return self.field_type.to_riak(value) def subfield_from_riak(self, value): return self.field_type.from_riak(value) class DynamicDescriptor(FieldDescriptor): """A field descriptor for dynamic fields.""" def setup(self, model_cls): super(DynamicDescriptor, self).setup(model_cls) if self.field.prefix is None: self.prefix = "%s." % self.key else: self.prefix = self.field.prefix def initialize(self, modelobj, valuedict): if valuedict is not None: self.update(modelobj, valuedict) def get_value(self, modelobj): return DynamicProxy(self, modelobj) def set_value(self, modelobj, valuedict): self.clear(modelobj) self.update(modelobj, valuedict) def clear(self, modelobj): keys = list(self.iterkeys(modelobj)) for key in keys: self.delete_dynamic_value(modelobj, key) def iterkeys(self, modelobj): prefix_len = len(self.prefix) data = modelobj._riak_object.get_data() return (key[prefix_len:] for key in data.iterkeys() if key.startswith(self.prefix)) def iteritems(self, modelobj): prefix_len = len(self.prefix) from_riak = self.field.subfield_from_riak data = modelobj._riak_object.get_data() return ((key[prefix_len:], from_riak(value)) for key, value in data.iteritems() if key.startswith(self.prefix)) def update(self, modelobj, otherdict): # this is a separate method so it can succeed or fail # somewhat atomically in the case where otherdict contains # bad keys or values items = [(self.prefix + key, self.field.subfield_to_riak(value)) for key, value in otherdict.iteritems()] for key, value in items: self.set_riak_data(modelobj, value, key=key) def get_dynamic_value(self, modelobj, dynamic_key): return self.field.subfield_from_riak( self.get_riak_data(modelobj, key=(self.prefix + dynamic_key))) def set_dynamic_value(self, modelobj, dynamic_key, value): self.field.validate_subfield(value) value = self.field.subfield_to_riak(value) self.set_riak_data(modelobj, value, self.prefix + dynamic_key) def delete_dynamic_value(self, modelobj, dynamic_key): self.delete_riak_data(modelobj, self.prefix + dynamic_key) def has_dynamic_key(self, modelobj, dynamic_key): key = self.prefix + dynamic_key return key in modelobj._riak_object.get_data() class DynamicProxy(object): def __init__(self, descriptor, modelobj): self._descriptor = descriptor self._modelobj = modelobj def iterkeys(self): return self._descriptor.iterkeys(self._modelobj) def keys(self): return list(self.iterkeys()) def iteritems(self): return self._descriptor.iteritems(self._modelobj) def items(self): return list(self.iteritems()) def itervalues(self): return (value for 
_key, value in self.iteritems()) def values(self): return list(self.itervalues()) def update(self, otherdict): return self._descriptor.update(self._modelobj, otherdict) def clear(self): self._descriptor.clear(self._modelobj) def copy(self): return dict(self.iteritems()) def __getitem__(self, key): return self._descriptor.get_dynamic_value(self._modelobj, key) def __setitem__(self, key, value): self._descriptor.set_dynamic_value(self._modelobj, key, value) def __delitem__(self, key): self._descriptor.delete_dynamic_value(self._modelobj, key) def __contains__(self, key): return self._descriptor.has_dynamic_key(self._modelobj, key) class Dynamic(FieldWithSubtype): """A field that allows sub-fields to be added dynamically. :param Field field_type: The field specification for the dynamic values. Default is Unicode(). :param string prefix: The prefix to use when storing these values in Riak. Default is the name of the field followed by a dot ('.'). """ descriptor_class = DynamicDescriptor def __init__(self, field_type=None, prefix=None): super(Dynamic, self).__init__(field_type=field_type) self.prefix = prefix def custom_validate(self, valuedict): if not isinstance(valuedict, dict): raise ValidationError( "Value %r should be a dict of subfield name-value pairs" % valuedict) for key, value in valuedict.iteritems(): self.validate_subfield(value) if not isinstance(key, unicode): raise ValidationError("Dynamic field needs unicode keys.") class ListOfDescriptor(FieldDescriptor): """A field descriptor for ListOf fields.""" def get_value(self, modelobj): return ListProxy(self, modelobj) def get_list_item(self, modelobj, list_idx): raw_item = self.get_riak_data(modelobj, [])[list_idx] return self.field.subfield_from_riak(raw_item) def _set_model_data(self, modelobj, raw_values): self.set_riak_data(modelobj, raw_values) if self.index_name is not None: modelobj._riak_object.remove_index(self.index_name) for value in raw_values: self._add_index(modelobj, value) def set_value(self, modelobj, values): map(self.field.validate_subfield, values) raw_values = [self.field.subfield_to_riak(value) for value in values] self._set_model_data(modelobj, raw_values) def set_list_item(self, modelobj, list_idx, value): self.field.validate_subfield(value) raw_value = self.field.subfield_to_riak(value) field_list = self.get_riak_data(modelobj, []) field_list[list_idx] = raw_value self._set_model_data(modelobj, field_list) def del_list_item(self, modelobj, list_idx): field_list = self.get_riak_data(modelobj, []) del field_list[list_idx] self._set_model_data(modelobj, field_list) def append_list_item(self, modelobj, value): self.field.validate_subfield(value) raw_value = self.field.subfield_to_riak(value) field_list = self.get_riak_data(modelobj, []) field_list.append(raw_value) self._set_model_data(modelobj, field_list) def remove_list_item(self, modelobj, value): self.field.validate_subfield(value) raw_value = self.field.subfield_to_riak(value) field_list = self.get_riak_data(modelobj, []) field_list.remove(raw_value) self._set_model_data(modelobj, field_list) def extend_list(self, modelobj, values): map(self.field.validate_subfield, values) raw_values = [self.field.subfield_to_riak(value) for value in values] field_list = self.get_riak_data(modelobj, []) field_list.extend(raw_values) self._set_model_data(modelobj, field_list) def iter_list(self, modelobj): raw_list = self.get_riak_data(modelobj, []) for raw_value in raw_list: yield self.field.subfield_from_riak(raw_value) class ListProxy(object): def __init__(self, 
descriptor, modelobj): self._descriptor = descriptor self._modelobj = modelobj def __getitem__(self, idx): return self._descriptor.get_list_item(self._modelobj, idx) def __setitem__(self, idx, value): self._descriptor.set_list_item(self._modelobj, idx, value) def __delitem__(self, idx): self._descriptor.del_list_item(self._modelobj, idx) def remove(self, value): self._descriptor.remove_list_item(self._modelobj, value) def append(self, value): self._descriptor.append_list_item(self._modelobj, value) def extend(self, values): self._descriptor.extend_list(self._modelobj, values) def __iter__(self): return self._descriptor.iter_list(self._modelobj) class ListOf(FieldWithSubtype): """A field that contains a list of values of some other type. :param Field field_type: The field specification for the dynamic values. Default is Unicode(). """ descriptor_class = ListOfDescriptor def __init__(self, field_type=None, **kw): super(ListOf, self).__init__(field_type=field_type, default=list, **kw) def custom_validate(self, valuelist): if not isinstance(valuelist, list): raise ValidationError( "Value %r should be a list of values" % valuelist) map(self.validate_subfield, valuelist) class SetOfDescriptor(FieldDescriptor): """ A field descriptor for SetOf fields. """ def get_value(self, modelobj): return SetProxy(self, modelobj) def _get_model_data(self, modelobj): return set(self.get_riak_data(modelobj, [])) def _set_model_data(self, modelobj, raw_values): raw_values = sorted(set(raw_values)) self.set_riak_data(modelobj, raw_values) if self.index_name is not None: modelobj._riak_object.remove_index(self.index_name) for value in raw_values: self._add_index(modelobj, value) def set_contains_item(self, modelobj, value): field_set = self._get_model_data(modelobj) return value in field_set def set_value(self, modelobj, values): map(self.field.validate_subfield, values) raw_values = [self.field.subfield_to_riak(value) for value in values] self._set_model_data(modelobj, raw_values) def add_set_item(self, modelobj, value): self.field.validate_subfield(value) field_set = self._get_model_data(modelobj) field_set.add(self.field.subfield_to_riak(value)) self._set_model_data(modelobj, field_set) def remove_set_item(self, modelobj, value): self.field.validate_subfield(value) field_set = self._get_model_data(modelobj) field_set.remove(value) self._set_model_data(modelobj, field_set) def discard_set_item(self, modelobj, value): self.field.validate_subfield(value) field_set = self._get_model_data(modelobj) field_set.discard(value) self._set_model_data(modelobj, field_set) def update_set(self, modelobj, values): map(self.field.validate_subfield, values) raw_values = [self.field.subfield_to_riak(value) for value in values] field_set = self._get_model_data(modelobj) field_set.update(raw_values) self._set_model_data(modelobj, field_set) def iter_set(self, modelobj): field_set = self._get_model_data(modelobj) for raw_value in field_set: yield self.field.subfield_from_riak(raw_value) class SetProxy(object): def __init__(self, descriptor, modelobj): self._descriptor = descriptor self._modelobj = modelobj def __contains__(self, value): return self._descriptor.set_contains_item(self._modelobj, value) def add(self, value): self._descriptor.add_set_item(self._modelobj, value) def remove(self, value): self._descriptor.remove_set_item(self._modelobj, value) def discard(self, value): self._descriptor.discard_set_item(self._modelobj, value) def update(self, values): self._descriptor.update_set(self._modelobj, values) def __iter__(self): 
return self._descriptor.iter_set(self._modelobj) class SetOf(FieldWithSubtype): """ A field that contains a set of values of some other type. :param Field field_type: The field specification for the dynamic values. Default is Unicode(). """ descriptor_class = SetOfDescriptor def __init__(self, field_type=None, **kw): super(SetOf, self).__init__(field_type=field_type, default=set, **kw) def custom_validate(self, valueset): if not isinstance(valueset, set): raise ValidationError( "Value %r should be a set of values" % valueset) map(self.validate_subfield, valueset) def custom_to_riak(self, value): return sorted(value) def custom_from_riak(self, raw_value): return set(raw_value) class ForeignKeyDescriptor(FieldDescriptor): def setup(self, model_cls): super(ForeignKeyDescriptor, self).setup(model_cls) self.other_model = self.field.other_model if self.field.index is None: self.index_name = "%s_bin" % self.key else: self.index_name = self.field.index backlink_name = self.field.backlink if backlink_name is None: backlink_name = model_cls.__name__.lower() + "s" self.other_model.backlinks.declare_backlink( backlink_name, self.reverse_lookup_keys) backlink_keys_name = backlink_name + "_keys" if backlink_keys_name.endswith("s_keys"): backlink_keys_name = backlink_name[:-1] + "_keys" self.other_model.backlinks.declare_backlink( backlink_keys_name, self.reverse_lookup_keys_paginated) def reverse_lookup_keys(self, modelobj, manager=None): if manager is None: manager = modelobj.manager return manager.index_keys( self.model_cls, self.index_name, modelobj.key) def reverse_lookup_keys_paginated(self, modelobj, manager=None, max_results=None, continuation=None): """ Perform a paginated index query for backlinked objects. """ if manager is None: manager = modelobj.manager return manager.index_keys_page( self.model_cls, self.index_name, modelobj.key, max_results=max_results, continuation=continuation) def clean(self, modelobj): if self.key not in modelobj._riak_object.get_data(): # We might have an old-style index-only version of the data. 
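# Pull the key out of the index entry and write it into the body data
# so newer code can read it from the data field directly.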
indexes = [ value for name, value in modelobj._riak_object.get_indexes() if name == self.index_name] self.set_riak_data(modelobj, (indexes or [None])[0]) def get_value(self, modelobj): return ForeignKeyProxy(self, modelobj) def get_foreign_key(self, modelobj): return self.get_riak_data(modelobj) def set_foreign_key(self, modelobj, foreign_key): self.set_riak_data(modelobj, foreign_key) modelobj._riak_object.remove_index(self.index_name) if foreign_key is not None: self._add_index(modelobj, foreign_key) def get_foreign_object(self, modelobj, manager=None): key = self.get_foreign_key(modelobj) if key is None: return None if manager is None: manager = modelobj.manager return self.other_model.load(manager, key) def initialize(self, modelobj, value): if isinstance(value, basestring): self.set_foreign_key(modelobj, value) else: self.set_foreign_object(modelobj, value) def set_value(self, modelobj, value): raise RuntimeError("ForeignKeyDescriptors should never be assigned" " to.") def set_foreign_object(self, modelobj, otherobj): self.validate(otherobj) foreign_key = otherobj.key if otherobj is not None else None self.set_foreign_key(modelobj, foreign_key) class ForeignKeyProxy(object): def __init__(self, descriptor, modelobj): self._descriptor = descriptor self._modelobj = modelobj def _get_key(self): return self._descriptor.get_foreign_key(self._modelobj) def _set_key(self, foreign_key): return self._descriptor.set_foreign_key(self._modelobj, foreign_key) key = property(fget=_get_key, fset=_set_key) def get(self, manager=None): return self._descriptor.get_foreign_object(self._modelobj, manager) def set(self, otherobj): self._descriptor.set_foreign_object(self._modelobj, otherobj) class ForeignKey(Field): """A field that links to another class. Additional parameters: :param Model other_model: The type of model linked to. :param string index: The name to use for the index. The default is the field name followed by _bin. :param string backlink: The name to use for the backlink on :attr:`other_model.backlinks`. The default is the name of the class the field is on converted to lowercase and with 's' appended (e.g. 'FooModel' would result in :attr:`other_model.backlinks.foomodels`). This is also used (with `_keys` appended and a trailing `s` omitted if one is present) for the paginated keys backlink function. """ descriptor_class = ForeignKeyDescriptor def __init__(self, other_model, index=None, backlink=None, **kw): super(ForeignKey, self).__init__(**kw) self.other_model = other_model self.index = index self.backlink = backlink def custom_validate(self, value): if not isinstance(value, self.other_model): raise ValidationError("ForeignKey requires a %r instance" % (self.other_model,)) class ManyToManyDescriptor(ForeignKeyDescriptor): def get_value(self, modelobj): return ManyToManyProxy(self, modelobj) def set_value(self, modelobj, value): raise RuntimeError("ManyToManyDescriptors should never be assigned" " to.") def clean(self, modelobj): if self.key not in modelobj._riak_object.get_data(): # We might have an old-style index-only version of the data. 
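# In the many-to-many case every matching index value is copied into
# the stored key list.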
indexes = [ value for name, value in modelobj._riak_object.get_indexes() if name == self.index_name] self.set_riak_data(modelobj, indexes[:]) def get_foreign_keys(self, modelobj): return self.get_riak_data(modelobj, [])[:] def add_foreign_key(self, modelobj, foreign_key): if foreign_key not in self.get_foreign_keys(modelobj): field_list = self.get_riak_data(modelobj, []) field_list.append(foreign_key) self.set_riak_data(modelobj, field_list) self._add_index(modelobj, foreign_key) def remove_foreign_key(self, modelobj, foreign_key): if foreign_key in self.get_foreign_keys(modelobj): field_list = self.get_riak_data(modelobj, []) field_list.remove(foreign_key) self.set_riak_data(modelobj, field_list) modelobj._riak_object.remove_index(self.index_name, foreign_key) def load_foreign_objects(self, modelobj, manager=None): keys = self.get_foreign_keys(modelobj) if manager is None: manager = modelobj.manager return manager.load_all_bunches(self.other_model, keys) def add_foreign_object(self, modelobj, otherobj): self.validate(otherobj) self.add_foreign_key(modelobj, otherobj.key) def remove_foreign_object(self, modelobj, otherobj): self.validate(otherobj) self.remove_foreign_key(modelobj, otherobj.key) def clear_keys(self, modelobj): self.set_riak_data(modelobj, []) modelobj._riak_object.remove_index(self.index_name) class ManyToManyProxy(object): def __init__(self, descriptor, modelobj): self._descriptor = descriptor self._modelobj = modelobj def keys(self): return self._descriptor.get_foreign_keys(self._modelobj) def add_key(self, foreign_key): self._descriptor.add_foreign_key(self._modelobj, foreign_key) def remove_key(self, foreign_key): self._descriptor.remove_foreign_key(self._modelobj, foreign_key) def load_all_bunches(self, manager=None): return self._descriptor.load_foreign_objects(self._modelobj, manager) def add(self, otherobj): self._descriptor.add_foreign_object(self._modelobj, otherobj) def remove(self, otherobj): self._descriptor.remove_foreign_object(self._modelobj, otherobj) def clear(self): self._descriptor.clear_keys(self._modelobj) class ManyToMany(ForeignKey): """A field that links to multiple instances of another class. :param Model other_model: The type of model linked to. :param string index: The name to use for the index. The default is the field name followed by _bin. :param string backlink: The name to use for the backlink on :attr:`other_model.backlinks`. The default is the name of the class the field is on converted to lowercase and with 's' appended (e.g. 'FooModel' would result in :attr:`other_model.backlinks.foomodels`). This is also used (with `_keys` appended and a trailing `s` omitted if one is present) for the paginated keys backlink function. 
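A hypothetical sketch, assuming ``User`` is another model class::

    class Group(Model):
        users = ManyToMany(User)

    # given a loaded ``group`` instance:
    group.users.add_key('user-1')
    d = group.users.load_all_bunches()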
""" descriptor_class = ManyToManyDescriptor initializable = False def __init__(self, other_model, index=None, backlink=None): super(ManyToMany, self).__init__(other_model, index, backlink) class ComputedValueDescriptor(FieldDescriptor): """A field descriptor for computed value fields.""" def __init__(self, key, field): super(ComputedValueDescriptor, self).__init__(key, field) self.subfield_descriptor = field.field_type.get_descriptor(key) def model_field_changed(self, modelobj, changed_field_name): self.set_value(modelobj, self.field.value_func(modelobj)) def set_value(self, modelobj, value): return self.subfield_descriptor.set_value(modelobj, value) def get_value(self, modelobj): return self.subfield_descriptor.get_value(modelobj) def __set__(self, instance, value): raise RuntimeError("Can't set value of computed field.") class ComputedValue(Field): """ Field that stores a computed value. :param value_func: A function that takes a model instance as its only parameter and returns a value for the field. This is called whenever the value of another field changes to compute the value of this field. :param Field field_type: The field specification for the computed value. Default is Unicode(). """ descriptor_class = ComputedValueDescriptor initializable = False def __init__(self, value_func, field_type=None): super(ComputedValue, self).__init__() if field_type is None: field_type = Unicode() if not isinstance(field_type, Field): raise TypeError("field_type must be a Field object.") self.field_type = field_type self.value_func = value_func def validate(self, value): return self.field_type.validate(value) def to_riak(self, value): return self.field_type.to_riak(value) def from_riak(self, value): return self.field_type.from_riak(value) PK=JG36rvumi/persist/riak_manager.py# -*- test-case-name: vumi.persist.tests.test_riak_manager -*- """A manager implementation on top of the riak Python package.""" from riak import RiakObject, RiakMapReduce, RiakError from vumi.persist.model import Manager, VumiRiakError from vumi.persist.riak_base import ( VumiRiakClientBase, VumiIndexPageBase, VumiRiakBucketBase, VumiRiakObjectBase) from vumi.utils import flatten_generator class VumiRiakClient(VumiRiakClientBase): """ Wrapper around a RiakClient to manage resources better. """ class VumiIndexPage(VumiIndexPageBase): """ Wrapper around a page of index query results. Iterating over this object will return the results for the current page. """ # Methods that touch the network. def next_page(self): """ Fetch the next page of results. :returns: A new :class:`VumiIndexPage` object containing the next page of results. """ if not self.has_next_page(): return None try: result = self._index_page.next_page() except RiakError as e: raise VumiRiakError(e) return type(self)(result) class VumiRiakBucket(VumiRiakBucketBase): """ Wrapper around a RiakBucket to manage network access better. """ # Methods that touch the network. 
def get_index(self, index_name, start_value, end_value=None, return_terms=None): keys = self.get_index_page( index_name, start_value, end_value, return_terms=return_terms) return list(keys) def get_index_page(self, index_name, start_value, end_value=None, return_terms=None, max_results=None, continuation=None): try: result = self._riak_bucket.get_index( index_name, start_value, end_value, return_terms=return_terms, max_results=max_results, continuation=continuation) except RiakError as e: raise VumiRiakError(e) return VumiIndexPage(result) class VumiRiakObject(VumiRiakObjectBase): """ Wrapper around a RiakObject to manage network access better. """ def get_bucket(self): return VumiRiakBucket(self._riak_obj.bucket) # Methods that touch the network. def _call_and_wrap(self, func): """ Call a function that touches the network and wrap the result in this class. """ return type(self)(func()) class RiakManager(Manager): """A persistence manager for the riak Python package.""" call_decorator = staticmethod(flatten_generator) @classmethod def from_config(cls, config): config = config.copy() bucket_prefix = config.pop('bucket_prefix') load_bunch_size = config.pop( 'load_bunch_size', cls.DEFAULT_LOAD_BUNCH_SIZE) mapreduce_timeout = config.pop( 'mapreduce_timeout', cls.DEFAULT_MAPREDUCE_TIMEOUT) transport_type = config.pop('transport_type', 'http') store_versions = config.pop('store_versions', None) host = config.get('host', '127.0.0.1') port = config.get('port') prefix = config.get('prefix', 'riak') mapred_prefix = config.get('mapred_prefix', 'mapred') client_id = config.get('client_id') transport_options = config.get('transport_options', {}) client_args = dict( host=host, prefix=prefix, mapred_prefix=mapred_prefix, protocol=transport_type, client_id=client_id, transport_options=transport_options) if port is not None: client_args['port'] = port client = VumiRiakClient(**client_args) return cls( client, bucket_prefix, load_bunch_size=load_bunch_size, mapreduce_timeout=mapreduce_timeout, store_versions=store_versions) def close_manager(self): if self._parent is None: # Only top-level managers may close the client. self.client.close() def _is_unclosed(self): # This returns `True` if the manager needs to be explicitly closed and # hasn't been closed yet. It should only be used in tests that ensure # client objects aren't leaked. if self._parent is not None: return False return not self.client._closed def riak_bucket(self, bucket_name): bucket = self.client.bucket(bucket_name) if bucket is not None: bucket = VumiRiakBucket(bucket) return bucket def riak_object(self, modelcls, key, result=None): bucket = self.bucket_for_modelcls(modelcls)._riak_bucket riak_object = VumiRiakObject(RiakObject(self.client, bucket, key)) if result: metadata = result['metadata'] indexes = metadata['index'] if hasattr(indexes, 'items'): # TODO: I think this is a Riak bug. In some cases # (maybe when there are no indexes?) the index # comes back as a list, in others (maybe when # there are indexes?) it comes back as a dict. 
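# Normalise the dict form to a list of (index_name, value) pairs.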
indexes = indexes.items() data = result['data'] riak_object.set_content_type(metadata['content-type']) riak_object.set_indexes(indexes) riak_object.set_encoded_data(data) else: riak_object.set_content_type("application/json") riak_object.set_data({'$VERSION': modelcls.VERSION}) return riak_object def store(self, modelobj): riak_object = self._reverse_migrate_riak_object(modelobj) riak_object.store() return modelobj def delete(self, modelobj): modelobj._riak_object.delete() def load(self, modelcls, key, result=None): riak_object = self.riak_object(modelcls, key, result) if not result: riak_object.reload() return self._migrate_riak_object(modelcls, key, riak_object) def _load_multiple(self, modelcls, keys): objs = (self.load(modelcls, key) for key in keys) return [obj for obj in objs if obj is not None] def riak_map_reduce(self): return RiakMapReduce(self.client) def run_map_reduce(self, mapreduce, mapper_func=None, reducer_func=None): results = mapreduce.run(timeout=self.mapreduce_timeout) if mapper_func is not None: results = [mapper_func(self, row) for row in results] if reducer_func is not None: results = reducer_func(self, results) return results def _search_iteration(self, bucket, query, rows, start): results = bucket.search(query, rows=rows, start=start) return [doc["id"] for doc in results["docs"]] def real_search(self, modelcls, query, rows=None, start=None): rows = 1000 if rows is None else rows bucket_name = self.bucket_name(modelcls) bucket = self.client.bucket(bucket_name) if start is not None: return self._search_iteration(bucket, query, rows, start) keys = [] new_keys = self._search_iteration(bucket, query, rows, 0) while new_keys: keys.extend(new_keys) new_keys = self._search_iteration(bucket, query, rows, len(keys)) return keys def riak_enable_search(self, modelcls): bucket_name = self.bucket_name(modelcls) bucket = self.client.bucket(bucket_name) return bucket.enable_search() def riak_search_enabled(self, modelcls): bucket_name = self.bucket_name(modelcls) bucket = self.client.bucket(bucket_name) return bucket.search_enabled() def should_quote_index_values(self): return False def purge_all(self): self.client._purge_all(self.bucket_prefix) PK[Hk7)7)vumi/persist/redis_base.py# -*- test-case-name: vumi.persist.tests.test_redis_base -*- import os from functools import wraps from vumi.persist.ast_magic import make_function from vumi.persist.fake_redis import FakeRedis def make_callfunc(name, redis_call): def func(self, *a, **kw): def _f(k, v): if k in redis_call.key_args: return self._key(v) return v arg_names = list(redis_call.args) + [redis_call.vararg] * len(a) aa = [_f(k, v) for k, v in zip(arg_names, a)] kk = dict((k, _f(k, v)) for k, v in kw.items()) result = self._make_redis_call(name, *aa, **kk) f_func = redis_call.filter_func if f_func: if isinstance(f_func, basestring): f_func = getattr(self, f_func) result = self._filter_redis_results(f_func, result) return result fargs = ['self'] + list(redis_call.args) return make_function(name, func, fargs, redis_call.vararg, redis_call.kwarg, redis_call.defaults) class RedisCall(object): def __init__(self, args, vararg=None, kwarg=None, defaults=(), filter_func=None, key_args=('key',)): self.args = args self.vararg = vararg self.kwarg = kwarg self.defaults = defaults self.filter_func = filter_func self.key_args = key_args class CallMakerMetaclass(type): def __new__(meta, classname, bases, class_dict): new_class_dict = {} for name, attr in class_dict.items(): if isinstance(attr, RedisCall): attr = make_callfunc(name, attr) 
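# The generated function replaces the RedisCall declaration on the
# final class.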
new_class_dict[name] = attr return type.__new__(meta, classname, bases, new_class_dict) class ClientProxy(object): def __init__(self, client): self.client = client class Manager(object): __metaclass__ = CallMakerMetaclass def __init__( self, client, config, key_prefix, key_separator=None, client_proxy=None): assert \ (client is None and client_proxy is not None) or \ (client is not None and client_proxy is None), \ 'Only one of client or client_proxy may be specified' if key_separator is None: key_separator = ':' if client_proxy is None: self._client_proxy = ClientProxy(client) else: self._client_proxy = client_proxy self._config = config self._key_prefix = key_prefix self._key_separator = key_separator def __deepcopy__(self, memo): "This is to let managers pass through config deepcopies in tests." return self @property def _client(self): return self._client_proxy.client def get_key_prefix(self): """This is only intended for use in testing, not production.""" return self._key_prefix def sub_manager(self, sub_prefix): key_prefix = self._key(sub_prefix) sub_man = self.__class__( None, self._config, key_prefix, client_proxy=self._client_proxy) if isinstance(self._client, FakeRedis): sub_man._close = self._client.teardown return sub_man @staticmethod def calls_manager(manager_attr): """Decorate a method that calls a manager. This redecorates with the `call_decorator` attribute on the Manager subclass used, which should be either @inlineCallbacks or @flatten_generator. """ if callable(manager_attr): # If we don't get a manager attribute name, default to 'manager'. return Manager.calls_manager('manager')(manager_attr) def redecorate(func): @wraps(func) def wrapper(self, *args, **kw): manager = getattr(self, manager_attr) return manager.call_decorator(func)(self, *args, **kw) return wrapper return redecorate @classmethod def from_config(cls, config): """Construct a manager from a dictionary of options. :param dict config: Dictionary of options for the manager. """ # So we can mangle it client_config = config.copy() manager_config = { 'config': config.copy(), 'key_prefix': client_config.pop('key_prefix', None), 'key_separator': client_config.pop('key_separator', ':'), } fake_redis = client_config.pop('FAKE_REDIS', None) if 'VUMITEST_REDIS_DB' in os.environ: fake_redis = None client_config['db'] = int(os.environ['VUMITEST_REDIS_DB']) if fake_redis is not None: if isinstance(fake_redis, cls): # We want to unwrap the existing fake_redis to rewrap it. fake_redis = fake_redis._client if isinstance(fake_redis, FakeRedis): # We want to wrap the existing fake_redis. pass else: # We want a new fake redis. fake_redis = None return cls._fake_manager(fake_redis, manager_config) return cls._manager_from_config(client_config, manager_config) @classmethod def _fake_manager(cls, fake_redis, manager_config): raise NotImplementedError("Sub-classes of Manager should implement" " ._fake_manager(...)") @classmethod def _manager_from_config(cls, client_config, manager_config): """Construct a client from a dictionary of options. :param dict config: Dictionary of options for the manager. :param str key_prefix: Key prefix for namespacing. """ raise NotImplementedError("Sub-classes of Manager should implement" " ._manager_from_config(...)") def close_manager(self): return self._close() def _close(self): """Close redis connection.""" raise NotImplementedError("Sub-classes of Manager should implement" " ._close()") def _purge_all(self): """Delete *ALL* keys whose names start with this manager's key prefix. Use only in tests. 
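
        For example (illustrative prefix), a manager created with
        ``key_prefix='app'`` and the default ':' separator would remove
        every key matching ``app:*``.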
""" raise NotImplementedError("Sub-classes of Manager should implement" " ._purge_all()") def _make_redis_call(self, call, *args, **kw): """Make a redis API call using the underlying client library. """ raise NotImplementedError("Sub-classes of Manager should implement" " ._make_redis_call()") def _filter_redis_results(self, func, results): """Filter results of a redis call. """ raise NotImplementedError("Sub-classes of Manager should implement" " ._filter_redis_results()") def _key(self, key): """ Generate a key using this manager's key prefix """ if self._key_prefix is None: return key return "%s%s%s" % (self._key_prefix, self._key_separator, key) def _unkey(self, key): """ Strip off manager's key prefix from a key """ prefix = "%s%s" % (self._key_prefix, self._key_separator) if key.startswith(prefix): return key[len(prefix):] return key def _unkeys(self, keys): return [self._unkey(k) for k in keys] def _unkeys_scan(self, scan_results): return [scan_results[0], self._unkeys(scan_results[1])] # Global operations type = RedisCall(['key']) exists = RedisCall(['key']) keys = RedisCall(['pattern'], defaults=['*'], key_args=['pattern'], filter_func='_unkeys') scan = RedisCall(['cursor', 'match', 'count'], defaults=['*', None], key_args=['match'], filter_func='_unkeys_scan') # String operations get = RedisCall(['key']) set = RedisCall(['key', 'value']) setnx = RedisCall(['key', 'value']) delete = RedisCall(['key']) setex = RedisCall(['key', 'seconds', 'value']) rename = RedisCall(['key', 'newkey'], key_args=('key', 'newkey')) # Integer operations incr = RedisCall(['key', 'amount'], defaults=[1]) incrby = RedisCall(['key', 'amount']) decr = RedisCall(['key', 'amount'], defaults=[1]) decrby = RedisCall(['key', 'amount']) # Hash operations hset = RedisCall(['key', 'field', 'value']) hsetnx = RedisCall(['key', 'field', 'value']) hget = RedisCall(['key', 'field']) hdel = RedisCall(['key'], vararg='fields') hmset = RedisCall(['key', 'mapping']) hgetall = RedisCall(['key']) hlen = RedisCall(['key']) hvals = RedisCall(['key']) hincrby = RedisCall(['key', 'field', 'amount'], defaults=[1]) hexists = RedisCall(['key', 'field']) # Set operations sadd = RedisCall(['key'], vararg='values') smembers = RedisCall(['key']) spop = RedisCall(['key']) srem = RedisCall(['key', 'value']) scard = RedisCall(['key']) smove = RedisCall(['src', 'dst', 'value'], key_args=['src', 'dst']) sunion = RedisCall(['key'], vararg='args', key_args=['key', 'args']) sismember = RedisCall(['key', 'value']) # Sorted set operations zadd = RedisCall(['key'], kwarg='valscores') zrem = RedisCall(['key', 'value']) zcard = RedisCall(['key']) zrange = RedisCall(['key', 'start', 'stop', 'desc', 'withscores'], defaults=[False, False]) zrangebyscore = RedisCall( ['key', 'min', 'max', 'start', 'num', 'withscores'], defaults=['-inf', '+inf', None, None, False]) zscore = RedisCall(['key', 'value']) zcount = RedisCall(['key', 'min', 'max']) zremrangebyrank = RedisCall(['key', 'start', 'stop']) # List operations llen = RedisCall(['key']) lpop = RedisCall(['key']) rpop = RedisCall(['key']) lpush = RedisCall(['key', 'obj']) rpush = RedisCall(['key', 'obj']) lrange = RedisCall(['key', 'start', 'end']) lrem = RedisCall(['key', 'value', 'num'], defaults=[0]) rpoplpush = RedisCall( ['source'], vararg='destination', key_args=['source', 'destination']) ltrim = RedisCall(['key', 'start', 'stop']) # Expiry operations expire = RedisCall(['key', 'seconds']) persist = RedisCall(['key']) ttl = RedisCall(['key']) # HyperLogLog operations pfadd = RedisCall(['key'], 
                      vararg='values')
    pfcount = RedisCall(['key'])

vumi/persist/ast_magic.py

import ast
from functools import partial


def _mknode(cls, **kw):
    "Make an AST node with the relevant bits attached."
    node = cls()
    node.lineno = 0
    node.col_offset = 0
    for k, v in kw.items():
        setattr(node, k, v)
    return node


# Some conveniences for building the AST.
arguments = partial(_mknode, ast.arguments)
Call = partial(_mknode, ast.Call)
Attribute = partial(_mknode, ast.Attribute)
FunctionDef = partial(_mknode, ast.FunctionDef)
Return = partial(_mknode, ast.Return)
Module = partial(_mknode, ast.Module)

_param = lambda name: _mknode(ast.Name, id=name, ctx=ast.Param())
_load = lambda name: _mknode(ast.Name, id=name, ctx=ast.Load())
_kw = lambda name: _mknode(ast.keyword, arg=name, value=_load(name))


def make_function(name, func, args, vararg=None, kwarg=None, defaults=()):
    "Create a function that has a nice signature and calls out to ``func``."
    # Give our default arguments names so we can shove them in globals.
    dflts = [("default_%s" % i, d) for i, d in enumerate(defaults)]
    # Build args and default lists for our function def.
    a_args = [_param(a) for a in args]
    a_defaults = [_load(k) for k, v in dflts]
    # Build args and keywords lists for our function call.
    c_args = [_load(a) for a in args[:len(args) - len(defaults)]]
    c_keywords = [_kw(a) for a in args[len(args) - len(defaults):]]
    # Construct the call to our external function.
    call = Call(func=_load('func'), args=c_args, keywords=c_keywords,
                starargs=(vararg and _load(vararg)),
                kwargs=(kwarg and _load(kwarg)))
    # Construct the function definition we're actually making.
    func_def = FunctionDef(
        name=name,
        args=arguments(
            args=a_args, vararg=vararg, kwarg=kwarg, defaults=a_defaults),
        body=[Return(value=call)],
        decorator_list=[])
    # Build up locals and globals, then compile and extract our function.
    locs = {}
    globs = dict(globals(), func=func, **dict(dflts))
    eval(compile(Module(body=[func_def]), '', 'exec'), globs, locs)
    return locs[name]

vumi/persist/fake_redis.py

# -*- test-case-name: vumi.persist.tests.test_fake_redis -*-

import fnmatch
from functools import wraps
from itertools import takewhile, dropwhile
import os
from zlib import crc32

from hyperloglog import HyperLogLog

from twisted.internet import reactor
from twisted.internet.defer import Deferred, execute
from twisted.internet.task import Clock


FAKE_REDIS_WAIT = float(os.environ.get('VUMI_FAKE_REDIS_WAIT', '0.002'))


def maybe_async(func):
    @wraps(func)
    def wrapper(self, *args, **kw):
        return self._delay_operation(func, args, kw)
    wrapper.sync = func
    return wrapper


def call_to_deferred(deferred, func, *args, **kw):
    execute(func, *args, **kw).chainDeferred(deferred)


class ResponseError(Exception):
    """
    Exception class for things we throw to match the real Redis client
    libraries.
    """


class FakeRedis(object):
    """In process and memory implementation of redis-like data store.

    It's intended to match the Python redis module API closely so that
    it can be used in place of the redis module when testing.

    Known limitations:

    * Exceptions raised are not guaranteed to match the exception types
      raised by the real Python redis module.
""" def __init__(self, charset='utf-8', errors='strict', async=False): self._data = {} self._known_key_existence = {} self._expiries = {} self._is_async = async self.clock = Clock() self._charset = charset self._charset_errors = errors self._delayed_calls = [] def teardown(self): self._clean_up_expires() self._clean_up_delayed_calls() def _encode(self, value): # Replicated from # redis-py's redis/connection.py if isinstance(value, str): return value if not isinstance(value, unicode): value = str(value) if isinstance(value, unicode): value = value.encode(self._charset, self._charset_errors) return value def _clean_up_expires(self): for key in self._expiries.keys(): delayed = self._expiries.pop(key) if not (delayed.cancelled or delayed.called): delayed.cancel() def _clean_up_delayed_calls(self): for delayed in self._delayed_calls: if not (delayed.cancelled or delayed.called): delayed.cancel() def _delay_operation(self, func, args, kw): """ Return the result with some fake delay. If we're in async mode, add some real delay to catch code that doesn't properly wait for the deferred to fire. """ self.clock.advance(0.1) if self._is_async: # Add some latency to catch things that don't wait on deferreds. We # can't use deferLater() here because we want to keep track of the # delayed call object. d = Deferred() delayed = reactor.callLater( FAKE_REDIS_WAIT, call_to_deferred, d, func, self, *args, **kw) self._delayed_calls.append(delayed) return d else: return func(self, *args, **kw) def _set_key(self, key, value): self._known_key_existence[key] = True self._data[key] = value def _setdefault_key(self, key, default): self._known_key_existence[key] = True return self._data.setdefault(key, default) def _sort_keys_by_hash(self, keys): """ Sort keys in a consistent but non-obvious way. We sort by the crc32 of the key, that being cheap and good enough for our purposes here. """ return sorted(keys, key=crc32) # Global operations @maybe_async def type(self, key): value = self._data.get(key) if value is None: return 'none' if isinstance(value, basestring): return 'string' if isinstance(value, list): return 'list' if isinstance(value, set): return 'set' if isinstance(value, Zset): return 'zset' if isinstance(value, dict): return 'hash' @maybe_async def exists(self, key): return key in self._data @maybe_async def keys(self, pattern='*'): return fnmatch.filter(self._data.keys(), pattern) @maybe_async def scan(self, cursor, match=None, count=None): if cursor is None: start = 0 else: start = int(cursor) if match is None: match = '*' if count is None: count = 10 output = [] # Start with all the keys we've ever seen, ordered in a consistent but # non-obvious way. keys = self._sort_keys_by_hash(self._known_key_existence.keys()) # Then throw away the number of keys our cursor has already walked. # This means we may miss new keys that have been added since we started # iterating and/or return duplicates, but that's what Redis does. i = None for i, key in enumerate(keys[start:]): if not self._known_key_existence[key]: # This key has been deleted. continue output.append(key) if len(output) >= count: break # Update the cursor to reflect the new position in the key list. 
if i is None or start + i + 1 >= len(keys): cursor = None else: cursor = str(start + i + 1) return [cursor, fnmatch.filter(output, match)] @maybe_async def flushdb(self): self._data = {} self._known_key_existence = {} # String operations @maybe_async def get(self, key): return self._data.get(key) @maybe_async def set(self, key, value): value = self._encode(value) # set() sets string value self._set_key(key, value) return True @maybe_async def setex(self, key, time, value): self.set.sync(self, key, value) self.expire.sync(self, key, time) return True @maybe_async def setnx(self, key, value): value = self._encode(value) # set() sets string value if key not in self._data: self._set_key(key, value) return 1 return 0 @maybe_async def delete(self, key): existed = (key in self._data) self._data.pop(key, None) if existed: self._known_key_existence[key] = False return existed @maybe_async def rename(self, key, newkey): if key == newkey: raise ResponseError("source and destination objects are the same") if key not in self._data: raise ResponseError("no such key") data = self._data.pop(key) self._set_key(newkey, data) return True # Integer operations # The python redis lib combines incr & incrby into incr(key, amount=1) @maybe_async def incr(self, key, amount=1): old_value = self._data.get(key) if old_value is None: old_value = 0 new_value = int(old_value) + amount self.set.sync(self, key, new_value) return new_value @maybe_async def decr(self, key, amount=1): old_value = self._data.get(key) if old_value is None: old_value = 0 new_value = int(old_value) - amount self.set.sync(self, key, new_value) return new_value # Hash operations @maybe_async def hset(self, key, field, value): mapping = self._setdefault_key(key, {}) new_field = field not in mapping mapping[field] = value return int(new_field) @maybe_async def hsetnx(self, key, field, value): if self.hexists.sync(self, key, field): return 0 return self.hset.sync(self, key, field, value) @maybe_async def hget(self, key, field): value = self._data.get(key, {}).get(field) if value is not None: return self._encode(value) @maybe_async def hdel(self, key, *fields): mapping = self._data.get(key) if mapping is None: return 0 deleted = 0 for field in fields: if field in mapping: del mapping[field] deleted += 1 return deleted @maybe_async def hmset(self, key, mapping): hval = self._setdefault_key(key, {}) hval.update(dict([(k, v) for k, v in mapping.items()])) @maybe_async def hgetall(self, key): return dict((self._encode(k), self._encode(v)) for k, v in self._data.get(key, {}).items()) @maybe_async def hlen(self, key): return len(self._data.get(key, {})) @maybe_async def hvals(self, key): return map(self._encode, self._data.get(key, {}).values()) @maybe_async def hincrby(self, key, field, amount=1): try: value = self._data.get(key, {}).get(field, "0") except AttributeError: raise ResponseError("WRONGTYPE Operation against a key holding" " the wrong kind of value") # the int(str(..)) coerces amount to an int but rejects floats try: value = int(value) + int(str(amount)) except (TypeError, ValueError): raise ResponseError("value is not an integer or out of range") self._setdefault_key(key, {})[field] = str(value) return value @maybe_async def hexists(self, key, field): return int(field in self._data.get(key, {})) # Set operations @maybe_async def sadd(self, key, *values): sval = self._setdefault_key(key, set()) old_len = len(sval) sval.update(map(self._encode, values)) return len(sval) - old_len @maybe_async def smembers(self, key): return self._data.get(key, 
set()) @maybe_async def spop(self, key): sval = self._data.get(key, set()) if not sval: return None return sval.pop() @maybe_async def srem(self, key, value): sval = self._data.get(key, set()) if value in sval: sval.remove(value) return 1 return 0 @maybe_async def scard(self, key): return len(self._data.get(key, set())) @maybe_async def smove(self, src, dst, value): result = self.srem.sync(self, src, value) if result: self.sadd.sync(self, dst, value) return result @maybe_async def sunion(self, key, *args): union = set() for rkey in (key,) + args: union.update(self._data.get(rkey, set())) return union @maybe_async def sismember(self, key, value): sval = self._data.get(key, set()) return value in sval # Sorted set operations @maybe_async def zadd(self, key, **valscores): zval = self._setdefault_key(key, Zset()) return zval.zadd(**valscores) @maybe_async def zrem(self, key, value): zval = self._setdefault_key(key, Zset()) return zval.zrem(value) @maybe_async def zcard(self, key): zval = self._data.get(key, Zset()) return zval.zcard() @maybe_async def zrange(self, key, start, stop, desc=False, withscores=False, score_cast_func=float): zval = self._data.get(key, Zset()) results = zval.zrange(start, stop, desc=desc, score_cast_func=score_cast_func) if withscores: return results else: return [v for v, k in results] @maybe_async def zrangebyscore(self, key, min='-inf', max='+inf', start=0, num=None, withscores=False, score_cast_func=float): zval = self._data.get(key, Zset()) results = zval.zrangebyscore( min, max, start, num, score_cast_func=score_cast_func) if withscores: return results else: return [v for v, k in results] @maybe_async def zcount(self, key, min, max): return len(self.zrangebyscore.sync(self, key, min, max)) @maybe_async def zscore(self, key, value): zval = self._data.get(key, Zset()) return zval.zscore(value) @maybe_async def zremrangebyrank(self, key, start, stop): zval = self._setdefault_key(key, Zset()) return zval.zremrangebyrank(start, stop) # List operations @maybe_async def llen(self, key): return len(self._data.get(key, [])) @maybe_async def lpop(self, key): if self.llen.sync(self, key): return self._data[key].pop(0) @maybe_async def rpop(self, key): if self.llen.sync(self, key): return self._data[key].pop(-1) @maybe_async def lpush(self, key, obj): self._setdefault_key(key, []).insert(0, self._encode(obj)) return self.llen.sync(self, key) @maybe_async def rpush(self, key, obj): self._setdefault_key(key, []).append(self._encode(obj)) return self.llen.sync(self, key) @maybe_async def lrange(self, key, start, end): lval = self._data.get(key, []) if end >= 0 or end < -1: end += 1 else: end = None return lval[start:end] @maybe_async def lrem(self, key, value, num=0): removed = [0] value = self._encode(value) def keep(v): if v == value and (num == 0 or removed[0] < abs(num)): removed[0] += 1 return False return True lval = self._data.get(key, []) if num >= 0: lval = [v for v in lval if keep(v)] else: lval.reverse() lval = [v for v in lval if keep(v)] lval.reverse() self._set_key(key, lval) return removed[0] @maybe_async def rpoplpush(self, source, destination): value = self.rpop.sync(self, source) if value: self.lpush.sync(self, destination, value) return value @maybe_async def ltrim(self, key, start, stop): lval = self._data.get(key, []) if stop != -1: # -1 means "end of list", so we skip the deletion. Otherwise we # increment the "stop" value to avoid deleting the last value we # want to keep. 
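            # E.g. (hypothetical data) ltrim(key, 1, 2) on ['a', 'b', 'c', 'd']
            # runs del lval[3:] and then del lval[:1], leaving ['b', 'c'].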
del lval[stop + 1:] del lval[:start] return True # Expiry operations @maybe_async def expire(self, key, seconds): if key not in self._data: return 0 self.persist.sync(self, key) delayed = self.clock.callLater(seconds, self.delete.sync, self, key) self._expiries[key] = delayed return 1 @maybe_async def ttl(self, key): delayed = self._expiries.get(key) if delayed is not None and delayed.active(): return round(delayed.getTime() - self.clock.seconds()) return None @maybe_async def persist(self, key): delayed = self._expiries.get(key) if delayed is not None and delayed.active(): delayed.cancel() return 1 return 0 # HyperLogLog operations @maybe_async def pfadd(self, key, *values): hll = self._setdefault_key(key, HyperLogLog(0.01)) old_card = hll.card() for value in values: hll.add(value) return hll.card() != old_card @maybe_async def pfcount(self, key): hll = self._data.get(key, HyperLogLog(0.01)) return len(hll) class Zset(object): """A Redis-like ordered set implementation.""" def __init__(self): self._zval = [] def _redis_range_to_py_range(self, start, end): end += 1 # redis start/end are element indexes if end == 0: end = None return start, end def _to_float(self, value): try: return float(value) except (ValueError, TypeError): raise ResponseError("value is not a valid float") def zadd(self, **valscores): new_zval = [val for val in self._zval if val[1] not in valscores] new_zval.extend((self._to_float(score), value) for value, score in valscores.items()) new_zval.sort() added = len(new_zval) - len(self._zval) self._zval = new_zval return added def zrem(self, value): new_zval = [val for val in self._zval if val[1] != value] existed = len(new_zval) != len(self._zval) self._zval = new_zval return existed def zcard(self): return len(self._zval) def zrange(self, start, stop, desc=False, score_cast_func=float): start, stop = self._redis_range_to_py_range(start, stop) # copy before changing in place zval = self._zval[:] zval.sort(reverse=desc) return [(v, score_cast_func(k)) for k, v in zval[start:stop]] def zrangebyscore(self, min='-inf', max='+inf', start=0, num=None, score_cast_func=float): results = self.zrange(0, -1, score_cast_func=score_cast_func) results.sort(key=lambda val: val[1]) def mkcheck(spec, is_upper_bound): spec = str(spec) # Handling infinities are easy, so get them out the way first. if spec.endswith('-inf'): return lambda val: False if spec.endswith('+inf'): return lambda val: True is_exclusive = False if spec.startswith('('): is_exclusive = True spec = spec[1:] spec = score_cast_func(spec) # For the lower bound, exclusive means drop less than or equal to. # For the upper bound, exclusive means take less than. 
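            # Worked example (hypothetical scores): with min='(0.2' the
            # dropwhile predicate is val[1] <= 0.2, and with max='(0.4' the
            # takewhile predicate is val[1] < 0.4, so a score of 0.3 is
            # returned while 0.2 and 0.4 themselves are excluded.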
if is_exclusive == is_upper_bound: return lambda val: val[1] < spec return lambda val: val[1] <= spec results = dropwhile(mkcheck(min, False), results) results = takewhile(mkcheck(max, True), results) results = list(results)[start:] if num is not None: results = results[:num] return list(results) def zscore(self, val): for score, value in self._zval: if value == val: return score def zremrangebyrank(self, start, stop): start, stop = self._redis_range_to_py_range(start, stop) deleted_keys = self._zval[start:stop] del self._zval[start:stop] return len(deleted_keys) PK=JGJI;DD*vumi/persist/tests/test_txredis_manager.py"""Tests for vumi.persist.txredis_manager.""" import os from functools import wraps from twisted.internet import reactor from twisted.internet.defer import inlineCallbacks, returnValue, Deferred from twisted.trial.unittest import SkipTest from vumi.persist.txredis_manager import TxRedisManager from vumi.tests.helpers import VumiTestCase def wait(secs): d = Deferred() reactor.callLater(secs, d.callback, None) return d def skip_fake_redis(func): @wraps(func) def wrapper(*args, **kw): if 'VUMITEST_REDIS_DB' not in os.environ: # We're using a fake redis, so skip this test. raise SkipTest( "This test requires a real Redis server. Set VUMITEST_REDIS_DB" " to run it.") return func(*args, **kw) return wrapper class TestTxRedisManager(VumiTestCase): @inlineCallbacks def get_manager(self): manager = yield TxRedisManager.from_config({ 'FAKE_REDIS': 'yes', 'key_prefix': 'redistest', }) self.add_cleanup(self.cleanup_manager, manager) yield manager._purge_all() returnValue(manager) @inlineCallbacks def cleanup_manager(self, manager): yield manager._purge_all() yield manager._close() @inlineCallbacks def test_key_unkey(self): manager = yield self.get_manager() self.assertEqual('redistest:foo', manager._key('foo')) self.assertEqual('foo', manager._unkey('redistest:foo')) self.assertEqual('redistest:redistest:foo', manager._key('redistest:foo')) self.assertEqual('redistest:foo', manager._unkey('redistest:redistest:foo')) @inlineCallbacks def test_set_get_keys(self): manager = yield self.get_manager() self.assertEqual([], (yield manager.keys())) self.assertEqual(None, (yield manager.get('foo'))) yield manager.set('foo', 'bar') self.assertEqual(['foo'], (yield manager.keys())) self.assertEqual('bar', (yield manager.get('foo'))) yield manager.set('foo', 'baz') self.assertEqual(['foo'], (yield manager.keys())) self.assertEqual('baz', (yield manager.get('foo'))) @inlineCallbacks def test_disconnect_twice(self): manager = yield self.get_manager() yield manager._close() yield manager._close() @inlineCallbacks def test_scan(self): manager = yield self.get_manager() self.assertEqual([], (yield manager.keys())) for i in range(10): yield manager.set('key%d' % i, 'value%d' % i) all_keys = set() cursor = None for i in range(20): # loop enough times to have gone through all the keys in our test # redis instance but not forever so we can assert on the value of # cursor if we get stuck. 
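            # A minimal sketch of the same cursor-following idiom outside a
            # test (hedged; `process` is a hypothetical callback):
            #     cursor = None
            #     while True:
            #         cursor, keys = yield manager.scan(cursor)
            #         process(keys)
            #         if cursor is None:
            #             break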
            cursor, keys = yield manager.scan(cursor)
            all_keys.update(keys)
            if cursor is None:
                break
        self.assertEqual(cursor, None)
        self.assertEqual(all_keys, set('key%d' % i for i in range(10)))

    @inlineCallbacks
    def test_ttl(self):
        manager = yield self.get_manager()
        missing_ttl = yield manager.ttl("missing_key")
        self.assertEqual(missing_ttl, None)
        yield manager.set("key-no-ttl", "value")
        no_ttl = yield manager.ttl("key-no-ttl")
        self.assertEqual(no_ttl, None)
        yield manager.setex("key-ttl", 30, "value")
        ttl = yield manager.ttl("key-ttl")
        self.assertTrue(10 <= ttl <= 30)

    @skip_fake_redis
    @inlineCallbacks
    def test_reconnect_sub_managers(self):
        manager = yield self.get_manager()
        sub_manager = manager.sub_manager('subredis')
        sub_sub_manager = sub_manager.sub_manager('subsubredis')
        yield manager.set("foo", "1")
        yield sub_manager.set("foo", "2")
        yield sub_sub_manager.set("foo", "3")
        # Our three managers are all connected properly.
        f1 = yield manager.get("foo")
        f2 = yield sub_manager.get("foo")
        f3 = yield sub_sub_manager.get("foo")
        self.assertEqual([f1, f2, f3], ["1", "2", "3"])
        # Kill the connection and wait a few moments for the reconnect.
        yield manager._client.quit()
        yield wait(manager._client.factory.initialDelay + 0.05)
        # Our three managers are all reconnected properly.
        f1 = yield manager.get("foo")
        f2 = yield sub_manager.get("foo")
        f3 = yield sub_sub_manager.get("foo")
        self.assertEqual([f1, f2, f3], ["1", "2", "3"])

vumi/persist/tests/test_fake_redis.py

# -*- coding: utf-8 -*-

import os

from twisted.internet import reactor
from twisted.internet.defer import inlineCallbacks, returnValue, Deferred

from vumi.persist.fake_redis import FakeRedis, ResponseError
from vumi.tests.helpers import VumiTestCase


class FakeRedisTestMixin(object):
    """
    Test methods (and some unimplemented stubs) for FakeRedis.
    """

    def get_redis(self, **kwargs):
        """
        Return a Redis object (or a wrapper around one).
        """
        raise NotImplementedError(".get_redis() method not implemented.")

    def assert_redis_op(self, redis, expected, op, *args, **kw):
        """
        Assert that a redis operation returns the expected result.
        """
        raise NotImplementedError(".assert_redis_op() method not implemented.")

    def assert_redis_error(self, redis, op, *args, **kw):
        """
        Assert that a redis operation raises an exception.
        """
        raise NotImplementedError(
            ".assert_redis_error() method not implemented.")

    def wait(self, redis, delay):
        """
        Wait some number of seconds, either for real or by advancing a clock.
""" raise NotImplementedError(".wait() method not implemented.") @inlineCallbacks def test_rename(self): redis = yield self.get_redis() yield redis.set("old_me", "oldval") yield self.assert_redis_op(redis, True, 'rename', "old_me", "new_me") self.assertEqual((yield redis.exists("old_me")), False) self.assertEqual((yield redis.get("new_me")), "oldval") yield self.assert_redis_error(redis, 'rename', "old_me", "old_me") yield self.assert_redis_error(redis, 'rename', "other", "new_me") @inlineCallbacks def test_delete(self): redis = yield self.get_redis() yield redis.set("delete_me", 1) yield self.assert_redis_op(redis, True, 'delete', "delete_me") yield self.assert_redis_op(redis, False, 'delete', "delete_me") @inlineCallbacks def test_incr(self): redis = yield self.get_redis() yield redis.set("inc", 1) yield self.assert_redis_op(redis, '1', 'get', "inc") yield self.assert_redis_op(redis, 2, 'incr', "inc") yield self.assert_redis_op(redis, 3, 'incr', "inc") yield self.assert_redis_op(redis, '3', 'get', "inc") @inlineCallbacks def test_incrby(self): redis = yield self.get_redis() yield redis.set("inc", 1) yield self.assert_redis_op(redis, '1', 'get', "inc") yield self.assert_redis_op(redis, 3, 'incr', "inc", 2) yield self.assert_redis_op(redis, '3', 'get', "inc") @inlineCallbacks def test_decr(self): redis = yield self.get_redis() yield redis.set("dec", 4) yield self.assert_redis_op(redis, '4', 'get', "dec") yield self.assert_redis_op(redis, 3, 'decr', "dec") yield self.assert_redis_op(redis, 2, 'decr', "dec") yield self.assert_redis_op(redis, '2', 'get', "dec") @inlineCallbacks def test_decrby(self): redis = yield self.get_redis() yield redis.set("dec", 4) yield self.assert_redis_op(redis, '4', 'get', "dec") yield self.assert_redis_op(redis, 2, 'decr', "dec", 2) yield self.assert_redis_op(redis, '2', 'get', "dec") @inlineCallbacks def test_setnx(self): redis = yield self.get_redis() yield self.assert_redis_op(redis, False, 'exists', "mykey") yield self.assert_redis_op(redis, True, 'setnx', "mykey", "value") yield self.assert_redis_op(redis, "value", 'get', "mykey") yield self.assert_redis_op(redis, False, 'setnx', "mykey", "other") yield self.assert_redis_op(redis, "value", 'get', "mykey") @inlineCallbacks def test_setex(self): redis = yield self.get_redis() yield self.assert_redis_op(redis, False, 'exists', "mykey") yield self.assert_redis_op(redis, True, 'setex', "mykey", 10, "value") yield self.assert_redis_op(redis, "value", 'get', "mykey") yield self.assert_redis_op(redis, 10, 'ttl', "mykey") @inlineCallbacks def test_incr_with_by_param(self): redis = yield self.get_redis() yield redis.set("inc", 1) yield self.assert_redis_op(redis, '1', 'get', "inc") yield self.assert_redis_op(redis, 2, 'incr', "inc", 1) yield self.assert_redis_op(redis, 4, 'incr', "inc", 2) yield self.assert_redis_op(redis, 7, 'incr', "inc", 3) yield self.assert_redis_op(redis, 11, 'incr', "inc", 4) yield self.assert_redis_op(redis, 111, 'incr', "inc", 100) yield self.assert_redis_op(redis, '111', 'get', "inc") @inlineCallbacks def test_zadd(self): redis = yield self.get_redis() yield self.assert_redis_op(redis, 1, 'zadd', 'set', one=1.0) yield self.assert_redis_op(redis, 0, 'zadd', 'set', one=2.0) yield self.assert_redis_op( redis, [('one', 2.0)], 'zrange', 'set', 0, -1, withscores=True) yield self.assert_redis_error(redis, "zadd", "set", one='foo') yield self.assert_redis_error(redis, "zadd", "set", one=None) @inlineCallbacks def test_zrange(self): redis = yield self.get_redis() yield redis.zadd('set', one=0.1, 
two=0.2, three=0.3) yield self.assert_redis_op(redis, ['one'], 'zrange', 'set', 0, 0) yield self.assert_redis_op( redis, ['one', 'two'], 'zrange', 'set', 0, 1) yield self.assert_redis_op( redis, ['one', 'two', 'three'], 'zrange', 'set', 0, 2) yield self.assert_redis_op( redis, ['one', 'two', 'three'], 'zrange', 'set', 0, 3) yield self.assert_redis_op( redis, ['one', 'two', 'three'], 'zrange', 'set', 0, -1) yield self.assert_redis_op( redis, [('one', 0.1), ('two', 0.2), ('three', 0.3)], 'zrange', 'set', 0, -1, withscores=True) yield self.assert_redis_op( redis, ['three', 'two', 'one'], 'zrange', 'set', 0, -1, desc=True) yield self.assert_redis_op( redis, [('three', 0.3), ('two', 0.2), ('one', 0.1)], 'zrange', 'set', 0, -1, withscores=True, desc=True) yield self.assert_redis_op( redis, [('three', 0.3)], 'zrange', 'set', 0, 0, withscores=True, desc=True) @inlineCallbacks def test_zrangebyscore(self): redis = yield self.get_redis() yield redis.zadd( 'set', one=0.1, two=0.2, three=0.3, four=0.4, five=0.5) yield self.assert_redis_op( redis, ['two', 'three', 'four'], 'zrangebyscore', 'set', 0.2, 0.4) yield self.assert_redis_op( redis, ['two', 'three'], 'zrangebyscore', 'set', 0.2, 0.4, 0, 2) yield self.assert_redis_op( redis, ['three'], 'zrangebyscore', 'set', '(0.2', '(0.4') yield self.assert_redis_op( redis, ['two', 'three', 'four', 'five'], 'zrangebyscore', 'set', '0.2', '+inf') yield self.assert_redis_op( redis, ['one', 'two'], 'zrangebyscore', 'set', '-inf', '0.2') @inlineCallbacks def test_zcount(self): redis = yield self.get_redis() yield redis.zadd( 'set', one=0.1, two=0.2, three=0.3, four=0.4, five=0.5) yield self.assert_redis_op(redis, 3, 'zcount', 'set', 0.2, 0.4) @inlineCallbacks def test_zrangebyscore_with_scores(self): redis = yield self.get_redis() yield redis.zadd( 'set', one=0.1, two=0.2, three=0.3, four=0.4, five=0.5) yield self.assert_redis_op( redis, [('two', 0.2), ('three', 0.3), ('four', 0.4)], 'zrangebyscore', 'set', 0.2, 0.4, withscores=True) @inlineCallbacks def test_zcard(self): redis = yield self.get_redis() yield self.assert_redis_op(redis, 0, 'zcard', 'set') yield redis.zadd('set', one=0.1, two=0.2) yield self.assert_redis_op(redis, 2, 'zcard', 'set') yield redis.zadd('set', three=0.3) yield self.assert_redis_op(redis, 3, 'zcard', 'set') @inlineCallbacks def test_zrem(self): redis = yield self.get_redis() yield redis.zadd('set', one=0.1, two=0.2) yield self.assert_redis_op(redis, True, 'zrem', 'set', 'one') yield self.assert_redis_op(redis, False, 'zrem', 'set', 'one') yield self.assert_redis_op( redis, [('two', 0.2)], 'zrange', 'set', 0, -1, withscores=True) @inlineCallbacks def test_zremrangebyrank(self): redis = yield self.get_redis() yield redis.zadd('set', one=1, two=2, three=3) yield self.assert_redis_op(redis, 2, 'zremrangebyrank', 'set', 0, 1) yield self.assert_redis_op( redis, [('three', 3)], 'zrange', 'set', 0, -1, withscores=True) @inlineCallbacks def test_zremrangebyrank_empty_range(self): redis = yield self.get_redis() yield redis.zadd('set', one=1, two=2, three=3) yield self.assert_redis_op(redis, 0, 'zremrangebyrank', 'set', 10, 11) yield self.assert_redis_op( redis, [('one', 1), ('two', 2), ('three', 3)], 'zrange', 'set', 0, -1, withscores=True) @inlineCallbacks def test_zremrangebyrank_negative_empty_range(self): redis = yield self.get_redis() yield redis.zadd('set', one=1, two=2, three=3) yield self.assert_redis_op( redis, 0, 'zremrangebyrank', 'set', -11, -10) yield self.assert_redis_op( redis, [('one', 1), ('two', 2), ('three', 3)], 'zrange', 
'set', 0, -1, withscores=True) @inlineCallbacks def test_zremrangebyrank_negative_start(self): redis = yield self.get_redis() yield redis.zadd('set', one=1, two=2, three=3) yield self.assert_redis_op(redis, 2, 'zremrangebyrank', 'set', -2, 2) yield self.assert_redis_op( redis, [('one', 1)], 'zrange', 'set', 0, -1, withscores=True) @inlineCallbacks def test_zremrangebyrank_negative_start_empty_range(self): redis = yield self.get_redis() yield redis.zadd('set', one=1, two=2, three=3) yield self.assert_redis_op(redis, 0, 'zremrangebyrank', 'set', -1, 1) yield self.assert_redis_op( redis, [('one', 1), ('two', 2), ('three', 3)], 'zrange', 'set', 0, -1, withscores=True) @inlineCallbacks def test_zremrangebyrank_negative_stop(self): redis = yield self.get_redis() yield redis.zadd('set', one=1, two=2, three=3) yield self.assert_redis_op(redis, 2, 'zremrangebyrank', 'set', 1, -1) yield self.assert_redis_op( redis, [('one', 1)], 'zrange', 'set', 0, -1, withscores=True) @inlineCallbacks def test_zremrangebyrank_negative_stop_empty_range(self): redis = yield self.get_redis() yield redis.zadd('set', one=1, two=2, three=3) yield self.assert_redis_op(redis, 0, 'zremrangebyrank', 'set', 0, -5) yield self.assert_redis_op( redis, [('one', 1), ('two', 2), ('three', 3)], 'zrange', 'set', 0, -1, withscores=True) @inlineCallbacks def test_zscore(self): redis = yield self.get_redis() yield redis.zadd('set', one=0.1, two=0.2) yield self.assert_redis_op(redis, 0.1, 'zscore', 'set', 'one') yield self.assert_redis_op(redis, 0.2, 'zscore', 'set', 'two') @inlineCallbacks def test_hgetall_returns_copy(self): redis = yield self.get_redis() yield redis.hset("hash", "foo", "1") data = yield redis.hgetall("hash") data["foo"] = "2" yield self.assert_redis_op(redis, {"foo": "1"}, 'hgetall', "hash") @inlineCallbacks def test_hincrby(self): redis = yield self.get_redis() yield self.assert_redis_op(redis, 1, 'hincrby', "inc", "field1") yield self.assert_redis_op(redis, 2, 'hincrby', "inc", "field1") yield self.assert_redis_op(redis, 5, 'hincrby', "inc", "field1", 3) yield self.assert_redis_op(redis, 7, 'hincrby', "inc", "field1", "2") yield self.assert_redis_error(redis, "hincrby", "inc", "field1", "1.5") yield redis.hset("inc", "field2", "a") yield self.assert_redis_error(redis, "hincrby", "inc", "field2") yield redis.set("key", "string") yield self.assert_redis_error(redis, "hincrby", "key", "field1") @inlineCallbacks def test_hexists(self): redis = yield self.get_redis() yield redis.hset('key', 'field', 1) yield self.assert_redis_op(redis, True, 'hexists', 'key', 'field') yield redis.hdel('key', 'field') yield self.assert_redis_op(redis, False, 'hexists', 'key', 'field') @inlineCallbacks def test_hsetnx(self): redis = yield self.get_redis() yield redis.hset('key', 'field', 1) self.assert_redis_op(redis, 0, 'hsetnx', 'key', 'field', 2) self.assertEqual((yield redis.hget('key', 'field')), '1') self.assert_redis_op(redis, 1, 'hsetnx', 'key', 'other-field', 2) self.assertEqual((yield redis.hget('key', 'other-field')), '2') @inlineCallbacks def test_sadd(self): redis = yield self.get_redis() yield self.assert_redis_op(redis, 1, 'sadd', 'set', 1) yield self.assert_redis_op(redis, 3, 'sadd', 'set', 2, 3, 4) yield self.assert_redis_op( redis, set(['1', '2', '3', '4']), 'smembers', 'set') @inlineCallbacks def test_smove(self): redis = yield self.get_redis() yield self.assert_redis_op(redis, 1, 'sadd', 'set1', 1) yield self.assert_redis_op(redis, 1, 'sadd', 'set2', 2) yield self.assert_redis_op(redis, True, 'smove', 'set1', 'set2', 
'1') yield self.assert_redis_op(redis, set(), 'smembers', 'set1') yield self.assert_redis_op(redis, set(['1', '2']), 'smembers', 'set2') yield self.assert_redis_op(redis, False, 'smove', 'set1', 'set2', '1') yield self.assert_redis_op(redis, True, 'smove', 'set2', 'set3', '1') yield self.assert_redis_op(redis, set(['2']), 'smembers', 'set2') yield self.assert_redis_op(redis, set(['1']), 'smembers', 'set3') @inlineCallbacks def test_sunion(self): redis = yield self.get_redis() yield self.assert_redis_op(redis, 1, 'sadd', 'set1', 1) yield self.assert_redis_op(redis, 1, 'sadd', 'set2', 2) yield self.assert_redis_op(redis, set(['1']), 'sunion', 'set1') yield self.assert_redis_op( redis, set(['1', '2']), 'sunion', 'set1', 'set2') yield self.assert_redis_op(redis, set(), 'sunion', 'other') @inlineCallbacks def test_lpush(self): redis = yield self.get_redis() yield self.assert_redis_op(redis, 1, 'lpush', 'list', 1) yield self.assert_redis_op(redis, ['1'], 'lrange', 'list', 0, -1) yield self.assert_redis_op(redis, 2, 'lpush', 'list', 'a') yield self.assert_redis_op(redis, ['a', '1'], 'lrange', 'list', 0, -1) yield self.assert_redis_op(redis, 3, 'lpush', 'list', '7') yield self.assert_redis_op( redis, ['7', 'a', '1'], 'lrange', 'list', 0, -1) @inlineCallbacks def test_rpush(self): redis = yield self.get_redis() yield self.assert_redis_op(redis, 1, 'rpush', 'list', 1) yield self.assert_redis_op(redis, ['1'], 'lrange', 'list', 0, -1) yield self.assert_redis_op(redis, 2, 'rpush', 'list', 'a') yield self.assert_redis_op(redis, ['1', 'a'], 'lrange', 'list', 0, -1) yield self.assert_redis_op(redis, 3, 'rpush', 'list', '7') yield self.assert_redis_op( redis, ['1', 'a', '7'], 'lrange', 'list', 0, -1) @inlineCallbacks def test_rpop(self): redis = yield self.get_redis() yield redis.lpush('key', 1) yield redis.lpush('key', 'a') yield redis.lpush('key', '3') yield self.assert_redis_op(redis, '1', 'rpop', 'key') yield self.assert_redis_op(redis, 'a', 'rpop', 'key') yield self.assert_redis_op(redis, '3', 'rpop', 'key') yield self.assert_redis_op(redis, None, 'rpop', 'key') @inlineCallbacks def test_rpoplpush(self): redis = yield self.get_redis() yield redis.lpush('source', 1) yield redis.lpush('source', 'a') yield redis.lpush('source', '3') yield self.assert_redis_op( redis, '1', 'rpoplpush', 'source', 'destination') yield self.assert_redis_op( redis, 'a', 'rpoplpush', 'source', 'destination') yield self.assert_redis_op( redis, '3', 'rpoplpush', 'source', 'destination') yield self.assert_redis_op(redis, None, 'rpop', 'source') yield self.assert_redis_op(redis, '1', 'rpop', 'destination') yield self.assert_redis_op(redis, 'a', 'rpop', 'destination') yield self.assert_redis_op(redis, '3', 'rpop', 'destination') yield self.assert_redis_op(redis, None, 'rpop', 'destination') @inlineCallbacks def test_lrem(self): redis = yield self.get_redis() for i in range(5): yield self.assert_redis_op( redis, 2 * i + 1, 'rpush', 'list', 'v%d' % i) yield self.assert_redis_op(redis, 2 * i + 2, 'rpush', 'list', 1) yield self.assert_redis_op(redis, 5, 'lrem', 'list', 1) yield self.assert_redis_op( redis, ['v0', 'v1', 'v2', 'v3', 'v4'], 'lrange', 'list', 0, -1) @inlineCallbacks def test_lrem_positive_num(self): redis = yield self.get_redis() for i in range(5): yield self.assert_redis_op( redis, 2 * i + 1, 'rpush', 'list', 'v%d' % i) yield self.assert_redis_op(redis, 2 * i + 2, 'rpush', 'list', 1) yield self.assert_redis_op(redis, 2, 'lrem', 'list', 1, 2) yield self.assert_redis_op( redis, ['v0', 'v1', 'v2', '1', 'v3', '1', 'v4', '1'], 
'lrange', 'list', 0, -1) @inlineCallbacks def test_lrem_negative_num(self): redis = yield self.get_redis() for i in range(5): yield self.assert_redis_op( redis, 2 * i + 1, 'rpush', 'list', 'v%d' % i) yield self.assert_redis_op(redis, 2 * i + 2, 'rpush', 'list', 1) yield self.assert_redis_op(redis, 2, 'lrem', 'list', 1, -2) yield self.assert_redis_op( redis, ['v0', '1', 'v1', '1', 'v2', '1', 'v3', 'v4'], 'lrange', 'list', 0, -1) @inlineCallbacks def test_ltrim(self): redis = yield self.get_redis() for i in range(1, 5): yield self.assert_redis_op(redis, i, 'rpush', 'list', str(i)) yield self.assert_redis_op( redis, ['1', '2', '3', '4'], 'lrange', 'list', 0, -1) yield self.assert_redis_op(redis, True, 'ltrim', 'list', 1, 2) yield self.assert_redis_op(redis, ['2', '3'], 'lrange', 'list', 0, -1) @inlineCallbacks def test_ltrim_mid_range(self): redis = yield self.get_redis() for i in range(1, 6): yield self.assert_redis_op(redis, i, 'rpush', 'list', str(i)) yield self.assert_redis_op( redis, ['1', '2', '3', '4', '5'], 'lrange', 'list', 0, -1) yield self.assert_redis_op(redis, True, 'ltrim', 'list', 2, 3) yield self.assert_redis_op(redis, ['3', '4'], 'lrange', 'list', 0, -1) @inlineCallbacks def test_ltrim_keep_all(self): redis = yield self.get_redis() for i in range(1, 4): yield self.assert_redis_op(redis, i, 'rpush', 'list', str(i)) yield self.assert_redis_op( redis, ['1', '2', '3'], 'lrange', 'list', 0, -1) yield self.assert_redis_op(redis, True, 'ltrim', 'list', 0, -1) yield self.assert_redis_op( redis, ['1', '2', '3'], 'lrange', 'list', 0, -1) @inlineCallbacks def test_expire_persist_ttl(self): redis = yield self.get_redis() # Missing key. yield self.assert_redis_op(redis, None, 'ttl', "tempval") yield self.assert_redis_op(redis, 0, 'expire', "tempval", 10) yield self.assert_redis_op(redis, 0, 'persist', "tempval") # Persistent key. yield redis.set("tempval", 1) yield self.assert_redis_op(redis, None, 'ttl', "tempval") yield self.assert_redis_op(redis, 0, 'persist', "tempval") yield self.assert_redis_op(redis, 1, 'expire', "tempval", 10) # Temporary key. yield self.assert_redis_op(redis, 10, 'ttl', "tempval") yield self.wait(redis, 0.6) # Wait a bit for the TTL to change. yield self.assert_redis_op(redis, 9, 'ttl', "tempval") yield self.assert_redis_op(redis, 1, 'expire', "tempval", 5) yield self.assert_redis_op(redis, 5, 'ttl', "tempval") yield self.assert_redis_op(redis, 1, 'persist', "tempval") # Persistent key again. 
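        # Return-value conventions exercised by this test: expire and persist
        # return 1 on success and 0 when the key is missing or not temporary,
        # and ttl returns None here rather than the -1/-2 sentinels that
        # newer Redis servers report.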
yield redis.set("tempval", 1) yield self.assert_redis_op(redis, None, 'ttl', "tempval") yield self.assert_redis_op(redis, 0, 'persist', "tempval") yield self.assert_redis_op(redis, 1, 'expire', "tempval", 10) @inlineCallbacks def test_type(self): redis = yield self.get_redis() yield self.assert_redis_op(redis, 'none', 'type', 'unknown_key') yield redis.set("string_key", "a") yield self.assert_redis_op(redis, 'string', 'type', 'string_key') yield redis.lpush("list_key", "a") yield self.assert_redis_op(redis, 'list', 'type', 'list_key') yield redis.sadd("set_key", "a") yield self.assert_redis_op(redis, 'set', 'type', 'set_key') yield redis.zadd("zset_key", a=1.0) yield self.assert_redis_op(redis, 'zset', 'type', 'zset_key') yield redis.hset("hash_key", "a", 1.0) yield self.assert_redis_op(redis, 'hash', 'type', 'hash_key') @inlineCallbacks def test_charset_encoding_default(self): # Redis client assumes utf-8 redis = yield self.get_redis() yield redis.set('name', u'Zoë Destroyer of Ascii') yield self.assert_redis_op( redis, 'Zo\xc3\xab Destroyer of Ascii', 'get', 'name') @inlineCallbacks def test_charset_encoding_custom_replace(self): redis = yield self.get_redis(charset='ascii', errors='replace') yield redis.set('name', u'Zoë Destroyer of Ascii') yield self.assert_redis_op( redis, 'Zo? Destroyer of Ascii', 'get', 'name') @inlineCallbacks def test_charset_encoding_custom_ignore(self): redis = yield self.get_redis(charset='ascii', errors='ignore') yield redis.set('name', u'Zoë Destroyer of Ascii') yield self.assert_redis_op( redis, 'Zo Destroyer of Ascii', 'get', 'name') @inlineCallbacks def test_scan_no_keys(self): """ Real and fake Redis implementation all return the same (empty) response when we scan for keys that don't exist. Other scanning methods are in FakeRedisUnverifiedTestMixin because we can't fake the same arbitrary order in which keys are returned from real Redis. """ redis = yield self.get_redis() yield self.assert_redis_op(redis, [None, []], 'scan', None) @inlineCallbacks def test_pfadd_and_pfcount(self): """ We can't test these two things separately, so test them together. """ redis = yield self.get_redis() yield self.assert_redis_op(redis, 1, 'pfadd', 'hll1', 'a') yield self.assert_redis_op(redis, 1, 'pfcount', 'hll1') yield self.assert_redis_op(redis, 0, 'pfadd', 'hll1', 'a') yield self.assert_redis_op(redis, 1, 'pfcount', 'hll1') yield self.assert_redis_op(redis, 1, 'pfadd', 'hll2', 'a', 'b') yield self.assert_redis_op(redis, 2, 'pfcount', 'hll2') yield self.assert_redis_op(redis, 0, 'pfadd', 'hll2', 'a', 'b') yield self.assert_redis_op(redis, 2, 'pfcount', 'hll2') class FakeRedisUnverifiedTestMixin(object): """ This mixin adds some extra tests that are not verified against real Redis. Each test in here should explain why verification isn't possible. """ @inlineCallbacks def test_scan_simple(self): """ Scanning returns keys in an order that depends on arbitrary state in the Redis server, so we can't fake it in a way that's identical to real Redis. """ redis = yield self.get_redis() for i in range(20): yield redis.set("key%02d" % i, str(i)) # Ordered the way FakeRedis.scan() returns them. 
result_keys = redis._sort_keys_by_hash( ["key%02d" % i for i in range(20)]) self.assert_redis_op(redis, ['10', result_keys[:10]], 'scan', None) self.assert_redis_op( redis, ['5', result_keys[:5]], 'scan', None, count=5) self.assert_redis_op( redis, ['10', result_keys[5:10]], 'scan', '5', count=5) self.assert_redis_op( redis, [None, result_keys[15:]], 'scan', '15', count=5) self.assert_redis_op( redis, [None, result_keys], 'scan', None, count=20) @inlineCallbacks def test_scan_interleaved_key_changes(self): """ Scanning returns keys in an order that depends on arbitrary state in the Redis server, so we can't fake it in a way that's identical to real Redis. """ redis = yield self.get_redis() for i in range(20): yield redis.set("key%02d" % i, str(i)) # Ordered the way FakeRedis.scan() returns them. result_keys = redis._sort_keys_by_hash( ["key%02d" % i for i in range(20)]) self.assert_redis_op(redis, ['10', result_keys[:10]], 'scan', None) # Set and delete a bunch of keys to change some internal state. The # next call to scan() will return duplicates. for i in range(20): yield redis.set("transient%02d" % i, str(i)) yield redis.delete("transient%02d" % i) self.assert_redis_op(redis, ['31', result_keys[5:15]], 'scan', '10') self.assert_redis_op(redis, [None, result_keys[15:]], 'scan', '31') @inlineCallbacks def test_pfadd_and_pfcount_large(self): """ for large sets, we get approximate counts. Redis and hyperloglog use different hash functions, so we get different approximations out of them and can't verify the results. """ redis = yield self.get_redis() values = ['v%s' % i for i in xrange(1000)] yield self.assert_redis_op(redis, 1, 'pfadd', 'hll1', *values) yield self.assert_redis_op(redis, 998, 'pfcount', 'hll1') yield self.assert_redis_op(redis, 0, 'pfadd', 'hll1', *values) yield self.assert_redis_op(redis, 998, 'pfcount', 'hll1') class TestFakeRedis(FakeRedisUnverifiedTestMixin, FakeRedisTestMixin, VumiTestCase): def get_redis(self, **kwargs): redis = FakeRedis(**kwargs) self.add_cleanup(redis.teardown) return redis def assert_redis_op(self, redis, expected, op, *args, **kw): self.assertEqual(expected, getattr(redis, op)(*args, **kw)) def assert_redis_error(self, redis, op, *args, **kw): self.assertRaises( ResponseError, getattr(redis, op), *args, **kw) def wait(self, redis, delay): redis.clock.advance(delay) class TestFakeRedisAsync(FakeRedisUnverifiedTestMixin, FakeRedisTestMixin, VumiTestCase): def get_redis(self, **kwargs): redis = FakeRedis(async=True, **kwargs) self.add_cleanup(redis.teardown) return redis def assert_redis_op(self, redis, expected, op, *args, **kw): d = getattr(redis, op)(*args, **kw) return d.addCallback(lambda r: self.assertEqual(expected, r)) def assert_redis_error(self, redis, op, *args, **kw): d = getattr(redis, op)(*args, **kw) return self.assertFailure(d, ResponseError) def wait(self, redis, delay): redis.clock.advance(delay) class RedisPairWrapper(object): def __init__(self, test_case, fake_redis, real_redis): self._test_case = test_case self._fake_redis = fake_redis self._real_redis = real_redis self._perform_operation = self._real_redis.call_decorator( self._perform_operation_gen) def _perform_operation_gen(self, op, *args, **kw): """ Perform an operation on both the fake and real Redises and assert that the responses and errors are the same. NOTE: This method is a generator and is not used directly. It's wrapped with an appropriate sync/async wrapper in __init__() above. 
""" results = [] errors = [] for redis in [self._fake_redis, self._real_redis]: try: result = yield getattr(redis, op)(*args, **kw) except Exception as e: errors.append(e) if results != []: self._test_case.fail( "Fake redis returned %r but real redis raised %r" % ( results[0], errors[0])) else: results.append(result) if errors != []: self._test_case.fail( "Real redis returned %r but fake redis raised %r" % ( results[0], errors[0])) # First, handle errors. if errors: # We convert ResponseErrors from the real redis to the fake redis # ResponseError type. We ignore the error message in the check, but # we display it to aid debugging. fake_type, real_type = type(errors[0]), type(errors[1]) if real_type is self._real_redis.RESPONSE_ERROR: real_type = self._fake_redis.RESPONSE_ERROR self._test_case.assertEqual( fake_type, real_type, ("Fake redis (a) and real redis (b) errors different:" "\n a = %r\n b = %r") % tuple(errors)) raise errors[0] # Now handle results. self._test_case.assertEqual( results[0], results[1], "Fake redis (a) and real redis (b) responses different:" "\n a = %r\n b = %r" % tuple(results)) returnValue(results[0]) def __getattr__(self, name): return lambda *args, **kw: self._perform_operation(name, *args, **kw) class TestFakeRedisVerify(FakeRedisTestMixin, VumiTestCase): if 'VUMITEST_REDIS_DB' not in os.environ: skip = ("This test requires a real Redis server. Set VUMITEST_REDIS_DB" " to run it.") def get_redis(self, **kwargs): from vumi.persist.redis_manager import RedisManager # Fake redis fake_redis = RedisManager._fake_manager(FakeRedis(**kwargs), { "config": {}, "key_prefix": 'redistest', }) self.add_cleanup(fake_redis._close) # Real redis config = { 'FAKE_REDIS': 'yes', 'key_prefix': 'redistest', } config.update(kwargs) real_redis = RedisManager.from_config(config) self.add_cleanup(self.cleanup_manager, real_redis) real_redis._purge_all() # Both redises return RedisPairWrapper(self, fake_redis, real_redis) def cleanup_manager(self, manager): manager._purge_all() manager._close() def assert_redis_op(self, redis, expected, op, *args, **kw): self.assertEqual(expected, getattr(redis, op)(*args, **kw)) def assert_redis_error(self, redis, op, *args, **kw): self.assertRaises(ResponseError, getattr(redis, op), *args, **kw) def wait(self, redis, delay): redis._fake_redis._client.clock.advance(delay) d = Deferred() reactor.callLater(delay, d.callback, None) return d class TestFakeRedisVerifyAsync(FakeRedisTestMixin, VumiTestCase): if 'VUMITEST_REDIS_DB' not in os.environ: skip = ("This test requires a real Redis server. 
Set VUMITEST_REDIS_DB" " to run it.") @inlineCallbacks def get_redis(self, **kwargs): from vumi.persist.txredis_manager import TxRedisManager # Fake redis fake_redis = yield TxRedisManager._fake_manager( FakeRedis(async=True, **kwargs), { "config": {}, "key_prefix": 'redistest', }) self.add_cleanup(fake_redis._close) # Real redis config = { 'FAKE_REDIS': 'yes', 'key_prefix': 'redistest', } config.update(kwargs) real_redis = yield TxRedisManager.from_config(config) self.add_cleanup(self.cleanup_manager, real_redis) # Both redises yield real_redis._purge_all() returnValue(RedisPairWrapper(self, fake_redis, real_redis)) @inlineCallbacks def cleanup_manager(self, manager): yield manager._purge_all() yield manager._close() def assert_redis_op(self, redis, expected, op, *args, **kw): d = getattr(redis, op)(*args, **kw) return d.addCallback(lambda r: self.assertEqual(expected, r)) def assert_redis_error(self, redis, op, *args, **kw): d = getattr(redis, op)(*args, **kw) return self.assertFailure(d, ResponseError) def wait(self, redis, delay): redis._fake_redis._client.clock.advance(delay) d = Deferred() reactor.callLater(delay, d.callback, None) return d PK=JG vumi/persist/tests/test_model.py# -*- coding: utf-8 -*- """Tests for vumi.persist.model.""" from twisted.internet.defer import inlineCallbacks, returnValue from vumi.persist.model import ( Model, Manager, ModelMigrator, ModelMigrationError, VumiRiakError) from vumi.persist import fields from vumi.persist.fields import ( ValidationError, Integer, Unicode, Dynamic, Field, FieldDescriptor) from vumi.tests.helpers import VumiTestCase, import_skip class SimpleModel(Model): a = Integer() b = Unicode() class IndexedModel(Model): a = Integer(index=True) b = Unicode(index=True, null=True) class InheritedModel(SimpleModel): c = Integer() class OverriddenModel(InheritedModel): c = Integer(min=0, max=5) class VersionedModelMigrator(ModelMigrator): def migrate_from_unversioned(self, migration_data): # Migrator assertions assert self.data_version is None assert self.model_class is VersionedModel assert isinstance(self.manager, Manager) # Data assertions assert set(migration_data.old_data.keys()) == set(['$VERSION', 'a']) assert migration_data.old_data['$VERSION'] is None assert migration_data.old_index == {} # Actual migration migration_data.set_value('$VERSION', 1) migration_data.set_value('b', migration_data.old_data['a']) return migration_data def reverse_from_1(self, migration_data): # Migrator assertions assert self.data_version == 1 assert self.model_class is VersionedModel assert isinstance(self.manager, Manager) # Data assertions assert set(migration_data.old_data.keys()) == set(['$VERSION', 'b']) assert migration_data.old_data['$VERSION'] == 1 assert migration_data.old_index == {} # Actual migration migration_data.set_value('$VERSION', None) migration_data.set_value('a', migration_data.old_data['b']) return migration_data def migrate_from_1(self, migration_data): # Migrator assertions assert self.data_version == 1 assert self.model_class is VersionedModel assert isinstance(self.manager, Manager) # Data assertions assert set(migration_data.old_data.keys()) == set(['$VERSION', 'b']) assert migration_data.old_data['$VERSION'] == 1 assert migration_data.old_index == {} # Actual migration migration_data.set_value('$VERSION', 2) migration_data.set_value('c', migration_data.old_data['b']) migration_data.set_value('text', 'hello') return migration_data def reverse_from_2(self, migration_data): # Migrator assertions assert self.data_version == 2 assert 
self.model_class is VersionedModel assert isinstance(self.manager, Manager) # Data assertions assert set(migration_data.old_data.keys()) == set( ['$VERSION', 'c', 'text']) assert migration_data.old_data['$VERSION'] == 2 assert migration_data.old_index == {} # Actual migration migration_data.set_value('$VERSION', 1) migration_data.set_value('b', migration_data.old_data['c']) # Drop the text field. return migration_data def migrate_from_2(self, migration_data): # Migrator assertions assert self.data_version == 2 assert self.model_class is IndexedVersionedModel assert isinstance(self.manager, Manager) # Data assertions assert set(migration_data.old_data.keys()) == set( ['$VERSION', 'c', 'text']) assert migration_data.old_data['$VERSION'] == 2 assert migration_data.old_index == {} # Actual migration migration_data.set_value('$VERSION', 3) migration_data.copy_values('c') migration_data.set_value( 'text', migration_data.old_data['text'], index='text_bin') return migration_data def reverse_from_3(self, migration_data): # Migrator assertions assert self.data_version == 3 assert self.model_class is IndexedVersionedModel assert isinstance(self.manager, Manager) # Data assertions assert set(migration_data.old_data.keys()) == set( ['$VERSION', 'c', 'text']) assert migration_data.old_data['$VERSION'] == 3 assert migration_data.old_index == {"text_bin": ["hi"]} # Actual migration migration_data.set_value('$VERSION', 2) migration_data.copy_values('c') migration_data.set_value('text', migration_data.old_data['text']) return migration_data def migrate_from_3(self, migration_data): # Migrator assertions assert self.data_version == 3 assert self.model_class is IndexRemovedVersionedModel assert isinstance(self.manager, Manager) # Data assertions assert set(migration_data.old_data.keys()) == set( ['$VERSION', 'c', 'text']) assert migration_data.old_data['$VERSION'] == 3 assert migration_data.old_index == {"text_bin": ["hi"]} # Actual migration migration_data.set_value('$VERSION', 4) migration_data.copy_values('c') migration_data.set_value('text', migration_data.old_data['text']) return migration_data def reverse_from_4(self, migration_data): # Migrator assertions assert self.data_version == 4 assert self.model_class is IndexRemovedVersionedModel assert isinstance(self.manager, Manager) # Data assertions assert set(migration_data.old_data.keys()) == set( ['$VERSION', 'c', 'text']) assert migration_data.old_data['$VERSION'] == 4 assert migration_data.old_index == {} # Actual migration migration_data.set_value('$VERSION', 3) migration_data.copy_values('c') migration_data.set_value( 'text', migration_data.old_data['text'], index='text_bin') return migration_data class UnversionedModel(Model): bucket = 'versionedmodel' a = Integer() class OldVersionedModel(Model): VERSION = 1 bucket = 'versionedmodel' b = Integer() class VersionedModel(Model): VERSION = 2 MIGRATOR = VersionedModelMigrator c = Integer() text = Unicode(null=True) class IndexedVersionedModel(Model): VERSION = 3 MIGRATOR = VersionedModelMigrator bucket = 'versionedmodel' c = Integer() text = Unicode(null=True, index=True) class IndexRemovedVersionedModel(Model): VERSION = 4 MIGRATOR = VersionedModelMigrator bucket = 'versionedmodel' c = Integer() text = Unicode(null=True) class UnknownVersionedModel(Model): VERSION = 5 bucket = 'versionedmodel' d = Integer() class VersionedDynamicModelMigrator(ModelMigrator): def migrate_from_unversioned(self, migration_data): migration_data.copy_dynamic_values('keep-') migration_data.set_value('$VERSION', 1) return 
migration_data class UnversionedDynamicModel(Model): bucket = 'versioneddynamicmodel' drop = Dynamic(prefix='drop-') keep = Dynamic(prefix='keep-') class VersionedDynamicModel(Model): bucket = 'versioneddynamicmodel' VERSION = 1 MIGRATOR = VersionedDynamicModelMigrator drop = Dynamic(prefix='drop-') keep = Dynamic(prefix='keep-') class ModelTestMixin(object): @Manager.calls_manager def filter_tombstones(self, model_cls, keys): live_keys = [] for key in keys: model = yield model_cls.load(key) if model is not None: live_keys.append(key) returnValue(live_keys) def get_model_indexes(self, model): indexes = {} for name, value in model._riak_object.get_indexes(): indexes.setdefault(name, []).append(value) return indexes def test_simple_class(self): field_names = SimpleModel.field_descriptors.keys() self.assertEqual(sorted(field_names), ['a', 'b']) self.assertTrue(isinstance(SimpleModel.a, Integer)) self.assertTrue(isinstance(SimpleModel.b, Unicode)) def test_repr(self): simple_model = self.manager.proxy(SimpleModel) s = simple_model("foo", a=1, b=u"bar") self.assertEqual( repr(s), "<SimpleModel $VERSION=None a=1 b=u'bar' key='foo'>") def test_get_data(self): simple_model = self.manager.proxy(SimpleModel) s = simple_model("foo", a=1, b=u"bar") self.assertEqual(s.get_data(), { '$VERSION': None, 'key': 'foo', 'a': 1, 'b': 'bar', }) def test_declare_backlinks(self): class TestModel(Model): pass TestModel.backlinks.declare_backlink("foo", lambda m, o: None) self.assertRaises(RuntimeError, TestModel.backlinks.declare_backlink, "foo", lambda m, o: None) t = TestModel(self.manager, "key") self.assertTrue(callable(t.backlinks.foo)) self.assertRaises(AttributeError, getattr, t.backlinks, 'bar') @inlineCallbacks def assert_mapreduce_results(self, expected_keys, mr_func, *args, **kw): keys = yield mr_func(*args, **kw).get_keys() count = yield mr_func(*args, **kw).get_count() self.assertEqual(expected_keys, sorted(keys)) self.assertEqual(len(expected_keys), count) @inlineCallbacks def assert_search_results(self, expected_keys, func, *args, **kw): keys = yield func(*args, **kw) self.assertEqual(expected_keys, sorted(keys)) @Manager.calls_manager def test_simple_search(self): simple_model = self.manager.proxy(SimpleModel) yield simple_model.enable_search() yield simple_model("one", a=1, b=u'abc').save() yield simple_model("two", a=2, b=u'def').save() yield simple_model("three", a=2, b=u'ghi').save() search = simple_model.search yield self.assert_mapreduce_results(["one"], search, a=1) yield self.assert_mapreduce_results(["two"], search, a=2, b='def') yield self.assert_mapreduce_results(["three", "two"], search, a=2) @Manager.calls_manager def test_simple_search_escaping(self): simple_model = self.manager.proxy(SimpleModel) search = simple_model.search yield simple_model.enable_search() yield simple_model("one", a=1, b=u'a\'bc').save() search = simple_model.search yield self.assert_mapreduce_results([], search, b=" OR a:1") yield self.assert_mapreduce_results([], search, b="b' OR a:1 '") yield self.assert_mapreduce_results(["one"], search, b="a\'bc") @Manager.calls_manager def test_simple_raw_search(self): simple_model = self.manager.proxy(SimpleModel) yield simple_model.enable_search() yield simple_model("one", a=1, b=u'abc').save() yield simple_model("two", a=2, b=u'def').save() yield simple_model("three", a=2, b=u'ghi').save() search = simple_model.raw_search yield self.assert_mapreduce_results(["one"], search, 'a:1') yield self.assert_mapreduce_results(["two"], search, 'a:2 AND b:def') yield self.assert_mapreduce_results( ["one", "two"], search, 'b:abc OR
b:def') yield self.assert_mapreduce_results(["three", "two"], search, 'a:2') @Manager.calls_manager def test_simple_real_search(self): simple_model = self.manager.proxy(SimpleModel) yield simple_model.enable_search() yield simple_model("one", a=1, b=u'abc').save() yield simple_model("two", a=2, b=u'def').save() yield simple_model("three", a=2, b=u'ghi').save() search = simple_model.real_search yield self.assert_search_results(["one"], search, 'a:1') yield self.assert_search_results(["two"], search, 'a:2 AND b:def') yield self.assert_search_results( ["one", "two"], search, 'b:abc OR b:def') yield self.assert_search_results(["three", "two"], search, 'a:2') @Manager.calls_manager def test_big_real_search(self): simple_model = self.manager.proxy(SimpleModel) yield simple_model.enable_search() keys = [] for i in range(100): key = "xx%06d" % (i + 1) keys.append(key) yield simple_model(key, a=99, b=u'abc').save() yield simple_model("yy000001", a=98, b=u'def').save() yield simple_model("yy000002", a=98, b=u'ghi').save() search = lambda q: simple_model.real_search(q, rows=11) yield self.assert_search_results(keys, search, 'a:99') @Manager.calls_manager def test_empty_real_search(self): simple_model = self.manager.proxy(SimpleModel) yield simple_model.enable_search() yield simple_model("one", a=1, b=u'abc').save() yield simple_model("two", a=2, b=u'def').save() yield simple_model("three", a=2, b=u'ghi').save() search = simple_model.real_search yield self.assert_search_results([], search, 'a:7') @Manager.calls_manager def test_limited_results_real_search(self): simple_model = self.manager.proxy(SimpleModel) yield simple_model.enable_search() yield simple_model("1one", a=1, b=u'abc').save() yield simple_model("2two", a=2, b=u'def').save() yield simple_model("3three", a=2, b=u'ghi').save() yield simple_model("4four", a=2, b=u'jkl').save() @inlineCallbacks def search(q): results = yield simple_model.real_search(q, rows=2, start=0) self.assertEqual(len(results), 2) results_new = yield simple_model.real_search(q, rows=2, start=2) self.assertEqual(len(results_new), 1) returnValue(results + results_new) yield self.assert_search_results( [u'2two', u'3three', u'4four'], search, 'a:2') @Manager.calls_manager def test_load_all_bunches(self): self.assertFalse(self.manager.USE_MAPREDUCE_BUNCH_LOADING) simple_model = self.manager.proxy(SimpleModel) yield simple_model("one", a=1, b=u'abc').save() yield simple_model("two", a=2, b=u'def').save() yield simple_model("three", a=2, b=u'ghi').save() objs_iter = simple_model.load_all_bunches(['one', 'two', 'bad']) objs = [] for obj_bunch in objs_iter: objs.extend((yield obj_bunch)) self.assertEqual(["one", "two"], sorted(obj.key for obj in objs)) @Manager.calls_manager def test_load_all_bunches_skips_tombstones(self): self.assertFalse(self.manager.USE_MAPREDUCE_BUNCH_LOADING) simple_model = self.manager.proxy(SimpleModel) yield simple_model("one", a=1, b=u'abc').save() yield simple_model("two", a=2, b=u'def').save() tombstone = yield simple_model("tombstone", a=2, b=u'ghi').save() yield tombstone.delete() objs_iter = simple_model.load_all_bunches(['one', 'two', 'tombstone']) objs = [] for obj_bunch in objs_iter: objs.extend((yield obj_bunch)) self.assertEqual(["one", "two"], sorted(obj.key for obj in objs)) @Manager.calls_manager def test_load_all_bunches_mapreduce(self): self.manager.USE_MAPREDUCE_BUNCH_LOADING = True simple_model = self.manager.proxy(SimpleModel) yield simple_model("one", a=1, b=u'abc').save() yield simple_model("two", a=2, b=u'def').save() yield 
simple_model("three", a=2, b=u'ghi').save() objs_iter = simple_model.load_all_bunches(['one', 'two', 'bad']) objs = [] for obj_bunch in objs_iter: objs.extend((yield obj_bunch)) self.assertEqual(["one", "two"], sorted(obj.key for obj in objs)) @Manager.calls_manager def test_load_all_bunches_mapreduce_skips_tombstones(self): self.manager.USE_MAPREDUCE_BUNCH_LOADING = True simple_model = self.manager.proxy(SimpleModel) yield simple_model("one", a=1, b=u'abc').save() yield simple_model("two", a=2, b=u'def').save() tombstone = yield simple_model("tombstone", a=2, b=u'ghi').save() yield tombstone.delete() objs_iter = simple_model.load_all_bunches(['one', 'two', 'tombstone']) objs = [] for obj_bunch in objs_iter: objs.extend((yield obj_bunch)) self.assertEqual(["one", "two"], sorted(obj.key for obj in objs)) @Manager.calls_manager def test_load_all_bunches_performance(self): """ A performance test that is handy to occasionally but shouldn't happen on every test run. This should go away once we're happy with the non-mapreduce bunch loading. """ import time start_setup = time.time() simple_model = self.manager.proxy(SimpleModel) keys = [] for i in xrange(2000): obj = yield simple_model("item%s" % i, a=i, b=u'abc').save() keys.append(obj.key) end_setup = time.time() print "\n\nSetup time: %s" % (end_setup - start_setup,) start_mr = time.time() self.manager.USE_MAPREDUCE_BUNCH_LOADING = True objs_iter = simple_model.load_all_bunches(keys) objs = [] for obj_bunch in objs_iter: objs.extend((yield obj_bunch)) end_mr = time.time() print "Mapreduce time: %s" % (end_mr - start_mr,) start_mult = time.time() self.manager.USE_MAPREDUCE_BUNCH_LOADING = False objs_iter = simple_model.load_all_bunches(keys) objs = [] for obj_bunch in objs_iter: objs.extend((yield obj_bunch)) end_mult = time.time() print "Multiple time: %s\n" % (end_mult - start_mult,) self.assertEqual(sorted(keys), sorted(obj.key for obj in objs)) test_load_all_bunches_performance.skip = ( "This takes a long time to run. 
Enable it if you need it.") @Manager.calls_manager def test_simple_instance(self): simple_model = self.manager.proxy(SimpleModel) s1 = simple_model("foo", a=5, b=u'3') yield s1.save() s2 = yield simple_model.load("foo") self.assertEqual(s2.a, 5) self.assertEqual(s2.b, u'3') self.assertEqual(s2.was_migrated, False) @Manager.calls_manager def test_simple_instance_delete(self): simple_model = self.manager.proxy(SimpleModel) s1 = simple_model("foo", a=5, b=u'3') yield s1.save() s2 = yield simple_model.load("foo") yield s2.delete() s3 = yield simple_model.load("foo") self.assertEqual(s3, None) @Manager.calls_manager def test_nonexist_keys_return_none(self): simple_model = self.manager.proxy(SimpleModel) s = yield simple_model.load("foo") self.assertEqual(s, None) @Manager.calls_manager def test_all_keys(self): simple_model = self.manager.proxy(SimpleModel) keys = yield self.filter_tombstones( simple_model, (yield simple_model.all_keys())) self.assertEqual(keys, []) yield simple_model("foo-1", a=5, b=u'1').save() yield simple_model("foo-2", a=5, b=u'2').save() keys = yield self.filter_tombstones( simple_model, (yield simple_model.all_keys())) self.assertEqual(sorted(keys), [u"foo-1", u"foo-2"]) @Manager.calls_manager def test_index_keys(self): indexed_model = self.manager.proxy(IndexedModel) yield indexed_model("foo1", a=1, b=u"one").save() yield indexed_model("foo2", a=2, b=u"one").save() yield indexed_model("foo3", a=2, b=None).save() keys = yield indexed_model.index_keys('a', 1) self.assertEqual(keys, ["foo1"]) # We should get a list object, not an IndexPage wrapper. self.assertTrue(isinstance(keys, list)) keys = yield indexed_model.index_keys('b', u"one") self.assertEqual(sorted(keys), ["foo1", "foo2"]) keys = yield indexed_model.index_keys('b', None) self.assertEqual(keys, []) @Manager.calls_manager def test_index_keys_store_none_for_empty(self): self.patch(fields, "STORE_NONE_FOR_EMPTY_INDEX", True) indexed_model = self.manager.proxy(IndexedModel) yield indexed_model("foo1", a=1, b=u"one").save() yield indexed_model("foo2", a=2, b=u"one").save() yield indexed_model("foo3", a=2, b=None).save() keys = yield indexed_model.index_keys('a', 1) self.assertEqual(keys, ["foo1"]) # We should get a list object, not an IndexPage wrapper. 
self.assertTrue(isinstance(keys, list)) keys = yield indexed_model.index_keys('b', u"one") self.assertEqual(sorted(keys), ["foo1", "foo2"]) keys = yield indexed_model.index_keys('b', None) self.assertEqual(keys, ["foo3"]) @Manager.calls_manager def test_index_keys_return_terms(self): indexed_model = self.manager.proxy(IndexedModel) yield indexed_model("foo1", a=1, b=u"one").save() yield indexed_model("foo2", a=2, b=u"one").save() yield indexed_model("foo3", a=2, b=None).save() keys = yield indexed_model.index_keys('a', 1, return_terms=True) self.assertEqual(keys, [("1", "foo1")]) keys = yield indexed_model.index_keys('b', u"one", return_terms=True) self.assertEqual(sorted(keys), [(u"one", "foo1"), (u"one", "foo2")]) keys = yield indexed_model.index_keys('b', None, return_terms=True) self.assertEqual(list(keys), []) @Manager.calls_manager def test_index_keys_return_terms_store_none_for_empty(self): self.patch(fields, "STORE_NONE_FOR_EMPTY_INDEX", True) indexed_model = self.manager.proxy(IndexedModel) yield indexed_model("foo1", a=1, b=u"one").save() yield indexed_model("foo2", a=2, b=u"one").save() yield indexed_model("foo3", a=2, b=None).save() keys = yield indexed_model.index_keys('a', 1, return_terms=True) self.assertEqual(keys, [("1", "foo1")]) keys = yield indexed_model.index_keys('b', u"one", return_terms=True) self.assertEqual(sorted(keys), [(u"one", "foo1"), (u"one", "foo2")]) keys = yield indexed_model.index_keys('b', None, return_terms=True) self.assertEqual(list(keys), [(u"None", "foo3")]) @Manager.calls_manager def test_index_keys_range(self): indexed_model = self.manager.proxy(IndexedModel) yield indexed_model("foo1", a=1, b=u"one").save() yield indexed_model("foo2", a=2, b=u"one").save() yield indexed_model("foo3", a=3, b=None).save() keys = yield indexed_model.index_keys('a', 1, 2) self.assertEqual(sorted(keys), ["foo1", "foo2"]) keys = yield indexed_model.index_keys('a', 2, 3) self.assertEqual(sorted(keys), ["foo2", "foo3"]) @Manager.calls_manager def test_index_keys_range_return_terms(self): indexed_model = self.manager.proxy(IndexedModel) yield indexed_model("foo1", a=1, b=u"one").save() yield indexed_model("foo2", a=2, b=u"one").save() yield indexed_model("foo3", a=3, b=None).save() keys = yield indexed_model.index_keys('a', 1, 2, return_terms=True) self.assertEqual(sorted(keys), [("1", "foo1"), ("2", "foo2")]) keys = yield indexed_model.index_keys('a', 2, 3, return_terms=True) self.assertEqual(sorted(keys), [("2", "foo2"), ("3", "foo3")]) @Manager.calls_manager def test_all_keys_page(self): simple_model = self.manager.proxy(SimpleModel) keys_page = yield simple_model.all_keys_page() keys = yield self.filter_tombstones(simple_model, list(keys_page)) self.assertEqual(keys, []) yield simple_model("foo-1", a=5, b=u'1').save() yield simple_model("foo-2", a=5, b=u'2').save() keys_page = yield simple_model.all_keys_page() keys = yield self.filter_tombstones(simple_model, list(keys_page)) self.assertEqual(sorted(keys), [u"foo-1", u"foo-2"]) @Manager.calls_manager def test_all_keys_page_multiple_pages(self): simple_model = self.manager.proxy(SimpleModel) yield simple_model("foo-1", a=5, b=u'1').save() yield simple_model("foo-2", a=5, b=u'2').save() keys = [] # We get results in arbitrary order and we may have tombstones left # over from prior tests. Therefore, we iterate through all index pages # and assert that we have exactly one result in each page except the # last. 
keys_page = yield simple_model.all_keys_page(max_results=1) while keys_page is not None: keys.extend(list(keys_page)) if keys_page.has_next_page(): self.assertEqual(len(keys_page), 1) keys_page = yield keys_page.next_page() else: keys_page = None keys = yield self.filter_tombstones(simple_model, keys) self.assertEqual(sorted(keys), [u"foo-1", u"foo-2"]) @Manager.calls_manager def test_index_keys_page(self): indexed_model = self.manager.proxy(IndexedModel) yield indexed_model("foo1", a=1, b=u"one").save() yield indexed_model("foo2", a=1, b=u"one").save() yield indexed_model("foo3", a=1, b=None).save() yield indexed_model("foo4", a=1, b=None).save() keys1 = yield indexed_model.index_keys_page('a', 1, max_results=2) self.assertEqual(sorted(keys1), ["foo1", "foo2"]) self.assertEqual(keys1.has_next_page(), True) keys2 = yield keys1.next_page() self.assertEqual(sorted(keys2), ["foo3", "foo4"]) self.assertEqual(keys2.has_next_page(), True) keys3 = yield keys2.next_page() self.assertEqual(sorted(keys3), []) self.assertEqual(keys3.has_next_page(), False) no_keys = yield keys3.next_page() self.assertEqual(no_keys, None) @Manager.calls_manager def test_index_keys_page_explicit_continuation(self): indexed_model = self.manager.proxy(IndexedModel) yield indexed_model("foo1", a=1, b=u"one").save() yield indexed_model("foo2", a=1, b=u"one").save() yield indexed_model("foo3", a=1, b=None).save() yield indexed_model("foo4", a=1, b=None).save() keys1 = yield indexed_model.index_keys_page('a', 1, max_results=1) self.assertEqual(sorted(keys1), ["foo1"]) self.assertEqual(keys1.has_next_page(), True) self.assertTrue(isinstance(keys1.continuation, unicode)) keys2 = yield indexed_model.index_keys_page( 'a', 1, max_results=2, continuation=keys1.continuation) self.assertEqual(sorted(keys2), ["foo2", "foo3"]) self.assertEqual(keys2.has_next_page(), True) keys3 = yield keys2.next_page() self.assertEqual(sorted(keys3), ["foo4"]) self.assertEqual(keys3.has_next_page(), False) @Manager.calls_manager def test_index_keys_page_none_continuation(self): indexed_model = self.manager.proxy(IndexedModel) yield indexed_model("foo1", a=1, b=u'one').save() keys1 = yield indexed_model.index_keys_page('a', 1, max_results=2) self.assertEqual(sorted(keys1), ['foo1']) self.assertEqual(keys1.has_next_page(), False) self.assertEqual(keys1.continuation, None) @Manager.calls_manager def test_index_keys_page_bad_continuation(self): indexed_model = self.manager.proxy(IndexedModel) yield indexed_model("foo1", a=1, b=u'one').save() try: yield indexed_model.index_keys_page( 'a', 1, max_results=1, continuation='bad-id') self.fail('Expected VumiRiakError.') except VumiRiakError: pass @Manager.calls_manager def test_index_keys_page_length(self): """ The length function for the page returns the correct length for two pages of different lengths. """ indexed_model = self.manager.proxy(IndexedModel) yield indexed_model("foo1", a=1, b=u"one").save() yield indexed_model("foo2", a=1, b=u"one").save() yield indexed_model("foo3", a=1, b=None).save() keys1 = yield indexed_model.index_keys_page('a', 1, max_results=2) self.assertEqual(len(keys1), 2) keys2 = yield keys1.next_page() self.assertEqual(len(keys2), 1) @Manager.calls_manager def test_index_keys_empty_page_length(self): """ The length function for the page returns a length of 0 for an empty page.
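# Index pages form a forward-linked sequence: has_next_page() says whether
# more results exist, next_page() fetches them, and `continuation` is an
# opaque token that may be stored and passed back to index_keys_page()
# later to resume paging. A walking sketch (assumes a configured manager
# and the IndexedModel defined earlier in this file):

from twisted.internet.defer import inlineCallbacks, returnValue


@inlineCallbacks
def walk_index(manager, value):
    indexed_model = manager.proxy(IndexedModel)
    keys = []
    page = yield indexed_model.index_keys_page('a', value, max_results=100)
    while page is not None:
        keys.extend(list(page))
        if not page.has_next_page():
            break
        page = yield page.next_page()
    returnValue(keys)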
""" indexed_model = self.manager.proxy(IndexedModel) yield indexed_model("foo1", a=2, b=u"one").save() keys1 = yield indexed_model.index_keys_page('a', 1) self.assertEqual(len(keys1), 0) @Manager.calls_manager def test_index_keys_quoting(self): indexed_model = self.manager.proxy(IndexedModel) yield indexed_model("foo1", a=1, b=u"+one").save() yield indexed_model("foo2", a=2, b=u"one").save() yield indexed_model("foo3", a=2, b=None).save() keys = yield indexed_model.index_keys('b', u"+one") self.assertEqual(sorted(keys), ["foo1"]) keys = yield indexed_model.index_keys('b', u"one") self.assertEqual(sorted(keys), ["foo2"]) keys = yield indexed_model.index_keys('b', None) self.assertEqual(keys, []) @Manager.calls_manager def test_index_keys_quoting_store_none_for_empty(self): self.patch(fields, "STORE_NONE_FOR_EMPTY_INDEX", True) indexed_model = self.manager.proxy(IndexedModel) yield indexed_model("foo1", a=1, b=u"+one").save() yield indexed_model("foo2", a=2, b=u"one").save() yield indexed_model("foo3", a=2, b=None).save() keys = yield indexed_model.index_keys('b', u"+one") self.assertEqual(sorted(keys), ["foo1"]) keys = yield indexed_model.index_keys('b', u"one") self.assertEqual(sorted(keys), ["foo2"]) keys = yield indexed_model.index_keys('b', None) self.assertEqual(keys, ["foo3"]) @Manager.calls_manager def test_index_lookup(self): indexed_model = self.manager.proxy(IndexedModel) yield indexed_model("foo1", a=1, b=u"one").save() yield indexed_model("foo2", a=2, b=u"one").save() yield indexed_model("foo3", a=2, b=None).save() lookup = indexed_model.index_lookup yield self.assert_mapreduce_results(["foo1"], lookup, 'a', 1) yield self.assert_mapreduce_results( ["foo1", "foo2"], lookup, 'b', u"one") yield self.assert_mapreduce_results([], lookup, 'b', None) @Manager.calls_manager def test_index_lookup_store_none_for_empty(self): self.patch(fields, "STORE_NONE_FOR_EMPTY_INDEX", True) indexed_model = self.manager.proxy(IndexedModel) yield indexed_model("foo1", a=1, b=u"one").save() yield indexed_model("foo2", a=2, b=u"one").save() yield indexed_model("foo3", a=2, b=None).save() lookup = indexed_model.index_lookup yield self.assert_mapreduce_results(["foo1"], lookup, 'a', 1) yield self.assert_mapreduce_results( ["foo1", "foo2"], lookup, 'b', u"one") yield self.assert_mapreduce_results(["foo3"], lookup, 'b', None) @Manager.calls_manager def test_index_match(self): indexed_model = self.manager.proxy(IndexedModel) yield indexed_model("foo1", a=1, b=u"one").save() yield indexed_model("foo2", a=2, b=u"one").save() yield indexed_model("foo3", a=2, b=None).save() match = indexed_model.index_match yield self.assert_mapreduce_results( ["foo1"], match, [{'key': 'b', 'pattern': 'one', 'flags': 'i'}], 'a', 1) yield self.assert_mapreduce_results( ["foo1", "foo2"], match, [{'key': 'b', 'pattern': 'one', 'flags': 'i'}], 'b', u"one") yield self.assert_mapreduce_results( [], match, [{'key': 'a', 'pattern': '2', 'flags': 'i'}], 'b', None) # test with non-existent key yield self.assert_mapreduce_results( [], match, [{'key': 'foo', 'pattern': 'one', 'flags': 'i'}], 'a', 1) # test case sensitivity yield self.assert_mapreduce_results( ['foo1'], match, [{'key': 'b', 'pattern': 'ONE', 'flags': 'i'}], 'a', 1) yield self.assert_mapreduce_results( [], match, [{'key': 'b', 'pattern': 'ONE', 'flags': ''}], 'a', 1) @Manager.calls_manager def test_index_match_store_none_for_empty(self): self.patch(fields, "STORE_NONE_FOR_EMPTY_INDEX", True) indexed_model = self.manager.proxy(IndexedModel) yield indexed_model("foo1", a=1, 
b=u"one").save() yield indexed_model("foo2", a=2, b=u"one").save() yield indexed_model("foo3", a=2, b=None).save() match = indexed_model.index_match yield self.assert_mapreduce_results( ["foo1"], match, [{'key': 'b', 'pattern': 'one', 'flags': 'i'}], 'a', 1) yield self.assert_mapreduce_results( ["foo1", "foo2"], match, [{'key': 'b', 'pattern': 'one', 'flags': 'i'}], 'b', u"one") yield self.assert_mapreduce_results( ["foo3"], match, [{'key': 'a', 'pattern': '2', 'flags': 'i'}], 'b', None) # test with non-existent key yield self.assert_mapreduce_results( [], match, [{'key': 'foo', 'pattern': 'one', 'flags': 'i'}], 'a', 1) # test case sensitivity yield self.assert_mapreduce_results( ['foo1'], match, [{'key': 'b', 'pattern': 'ONE', 'flags': 'i'}], 'a', 1) yield self.assert_mapreduce_results( [], match, [{'key': 'b', 'pattern': 'ONE', 'flags': ''}], 'a', 1) @Manager.calls_manager def test_inherited_model(self): field_names = InheritedModel.field_descriptors.keys() self.assertEqual(sorted(field_names), ["a", "b", "c"]) inherited_model = self.manager.proxy(InheritedModel) im1 = inherited_model("foo", a=1, b=u"2", c=3) yield im1.save() im2 = yield inherited_model.load("foo") self.assertEqual(im2.a, 1) self.assertEqual(im2.b, u'2') self.assertEqual(im2.c, 3) def test_overriden_model(self): int_field = OverriddenModel.field_descriptors['c'].field self.assertEqual(int_field.max, 5) self.assertEqual(int_field.min, 0) overridden_model = self.manager.proxy(OverriddenModel) overridden_model("foo", a=1, b=u"2", c=3) self.assertRaises(ValidationError, overridden_model, "foo", a=1, b=u"2", c=-1) @Manager.calls_manager def test_unversioned_migration(self): old_model = self.manager.proxy(UnversionedModel) new_model = self.manager.proxy(VersionedModel) foo_old = old_model("foo", a=1) yield foo_old.save() foo_new = yield new_model.load("foo") self.assertEqual(foo_new.c, 1) self.assertEqual(foo_new.was_migrated, True) @Manager.calls_manager def test_unversioned_reverse_migration(self): old_model = self.manager.proxy(UnversionedModel) new_model = self.manager.proxy(VersionedModel) foo_new = new_model("foo", c=1) model_name = "%s.%s" % ( VersionedModel.__module__, VersionedModel.__name__) self.manager.store_versions[model_name] = None yield foo_new.save() foo_old = yield old_model.load("foo") self.assertEqual(foo_old.a, 1) self.assertEqual(foo_old.was_migrated, False) @Manager.calls_manager def test_version_migration(self): old_model = self.manager.proxy(OldVersionedModel) new_model = self.manager.proxy(VersionedModel) foo_old = old_model("foo", b=1) yield foo_old.save() foo_new = yield new_model.load("foo") self.assertEqual(foo_new.c, 1) self.assertEqual(foo_new.text, "hello") self.assertEqual(foo_new.was_migrated, True) @Manager.calls_manager def test_version_reverse_migration(self): old_model = self.manager.proxy(OldVersionedModel) new_model = self.manager.proxy(VersionedModel) foo_new = new_model("foo", c=1) model_name = "%s.%s" % ( VersionedModel.__module__, VersionedModel.__name__) self.manager.store_versions[model_name] = OldVersionedModel.VERSION yield foo_new.save() foo_old = yield old_model.load("foo") self.assertEqual(foo_old.b, 1) self.assertEqual(foo_old.was_migrated, False) @Manager.calls_manager def test_version_migration_new_index(self): old_model = self.manager.proxy(VersionedModel) new_model = self.manager.proxy(IndexedVersionedModel) foo_old = old_model("foo", c=1, text=u"hi") yield foo_old.save() foo_new = yield new_model.load("foo") self.assertEqual(foo_new.c, 1) self.assertEqual(foo_new.text, 
"hi") self.assertEqual(self.get_model_indexes(foo_new), {"text_bin": ["hi"]}) self.assertEqual(foo_new.was_migrated, True) @Manager.calls_manager def test_version_reverse_migration_new_index(self): old_model = self.manager.proxy(VersionedModel) new_model = self.manager.proxy(IndexedVersionedModel) foo_new = new_model("foo", c=1, text=u"hi") model_name = "%s.%s" % ( VersionedModel.__module__, IndexedVersionedModel.__name__) self.manager.store_versions[model_name] = VersionedModel.VERSION yield foo_new.save() foo_old = yield old_model.load("foo") self.assertEqual(foo_old.c, 1) self.assertEqual(foo_old.text, "hi") # Old indexes are no longer kept across migrations. self.assertEqual(self.get_model_indexes(foo_old), {}) self.assertEqual(foo_old.was_migrated, False) @Manager.calls_manager def test_version_migration_new_index_with_unicode(self): old_model = self.manager.proxy(VersionedModel) new_model = self.manager.proxy(IndexedVersionedModel) foo_old = old_model("foo", c=1, text=u"hi Zoë") yield foo_old.save() foo_new = yield new_model.load("foo") self.assertEqual(foo_new.c, 1) self.assertEqual(foo_new.text, u"hi Zoë") self.assertEqual( self.get_model_indexes(foo_new), {"text_bin": ["hi Zo\xc3\xab"]}) self.assertEqual(foo_new.was_migrated, True) @Manager.calls_manager def test_version_migration_new_index_None(self): old_model = self.manager.proxy(VersionedModel) new_model = self.manager.proxy(IndexedVersionedModel) foo_old = old_model("foo", c=1, text=None) yield foo_old.save() foo_new = yield new_model.load("foo") self.assertEqual(foo_new.c, 1) self.assertEqual(foo_new.text, None) self.assertEqual(self.get_model_indexes(foo_new), {}) @Manager.calls_manager def test_version_migration_remove_index(self): """ An index can be removed in a migration. """ old_model = self.manager.proxy(IndexedVersionedModel) new_model = self.manager.proxy(IndexRemovedVersionedModel) foo_old = old_model("foo", c=1, text=u"hi") yield foo_old.save() foo_new = yield new_model.load("foo") self.assertEqual(foo_new.c, 1) self.assertEqual(foo_new.text, "hi") self.assertEqual(self.get_model_indexes(foo_new), {}) self.assertEqual(foo_new.was_migrated, True) @Manager.calls_manager def test_version_reverse_migration_remove_index(self): """ A removed index can be restored in a reverse migration. 
""" old_model = self.manager.proxy(IndexedVersionedModel) new_model = self.manager.proxy(IndexRemovedVersionedModel) foo_new = new_model("foo", c=1, text=u"hi") model_name = "%s.%s" % ( VersionedModel.__module__, IndexRemovedVersionedModel.__name__) self.manager.store_versions[model_name] = IndexedVersionedModel.VERSION yield foo_new.save() foo_old = yield old_model.load("foo") self.assertEqual(foo_old.c, 1) self.assertEqual(foo_old.text, "hi") self.assertEqual(self.get_model_indexes(foo_new), {"text_bin": ["hi"]}) @Manager.calls_manager def test_version_migration_failure(self): odd_model = self.manager.proxy(UnknownVersionedModel) new_model = self.manager.proxy(VersionedModel) foo_odd = odd_model("foo", d=1) yield foo_odd.save() try: yield new_model.load("foo") self.fail('Expected ModelMigrationError.') except ModelMigrationError, e: self.assertEqual( e.args[0], 'No migrators defined for VersionedModel version 5') @Manager.calls_manager def test_dynamic_field_migration(self): old_model = self.manager.proxy(UnversionedDynamicModel) new_model = self.manager.proxy(VersionedDynamicModel) old = old_model("foo") old.keep['bar'] = u"bar-val" old.keep['baz'] = u"baz-val" old.drop['bar'] = u"drop" yield old.save() new = yield new_model.load("foo") self.assertEqual(new.keep['bar'], u"bar-val") self.assertEqual(new.keep['baz'], u"baz-val") self.assertFalse("bar" in new.drop) def test_update_if_field_changed(self): """ If a field value changes, model_field_changed() is called on other field descriptors. """ mfc_called_fields = [] class DetectChangedFieldDescripor(FieldDescriptor): def model_field_changed(self, modelobj, changed_field_name): mfc_called_fields.append((self.key, changed_field_name)) class DetectChangedField(Field): descriptor_class = DetectChangedFieldDescripor class DetectChangedFieldModel(Model): a = DetectChangedField(null=True) b = DetectChangedField(null=True) dcf_model = self.manager.proxy(DetectChangedFieldModel) dcf = dcf_model("foo") self.assertEqual(mfc_called_fields, []) # Change .a to a new value and assert that .b was notified. dcf.a = "aval" self.assertEqual(mfc_called_fields, [("b", "a")]) # Change .b to a new value and assert that .a was notified. dcf.b = "bval" self.assertEqual(mfc_called_fields, [("b", "a"), ("a", "b")]) def test_no_update_if_field_unchanged(self): """ If a field value is set to its previous value, model_field_changed() is not called on other field descriptors. """ mfc_called_fields = [] class DetectChangedFieldDescripor(FieldDescriptor): def model_field_changed(self, modelobj, changed_field_name): mfc_called_fields.append((self.key, changed_field_name)) class DetectChangedField(Field): descriptor_class = DetectChangedFieldDescripor class DetectChangedFieldModel(Model): a = DetectChangedField(null=True) b = DetectChangedField(null=True) dcf_model = self.manager.proxy(DetectChangedFieldModel) dcf = dcf_model("foo") self.assertEqual(mfc_called_fields, []) # Change .a to a new value and assert that .b was notified. dcf.a = "aval" self.assertEqual(mfc_called_fields, [("b", "a")]) # Change .a to its existing value and assert that .b was not notified. dcf.a = "aval" self.assertEqual(mfc_called_fields, [("b", "a")]) # Change .a to another new value and assert that .b was notified. 
dcf.a = "aval2" self.assertEqual(mfc_called_fields, [("b", "a"), ("b", "a")]) class TestModelOnTxRiak(VumiTestCase, ModelTestMixin): @inlineCallbacks def setUp(self): try: from vumi.persist.txriak_manager import TxRiakManager except ImportError, e: import_skip(e, 'riak') self.manager = TxRiakManager.from_config({'bucket_prefix': 'test.'}) self.add_cleanup(self.cleanup_manager) yield self.manager.purge_all() @inlineCallbacks def cleanup_manager(self): yield self.manager.purge_all() yield self.manager.close_manager() class TestModelOnRiak(VumiTestCase, ModelTestMixin): def setUp(self): try: from vumi.persist.riak_manager import RiakManager except ImportError, e: import_skip(e, 'riak') self.manager = RiakManager.from_config({'bucket_prefix': 'test.'}) self.add_cleanup(self.cleanup_manager) self.manager.purge_all() def cleanup_manager(self): self.manager.purge_all() self.manager.close_manager() PK=JGvumi/persist/tests/__init__.pyPK[H)%vumi/persist/tests/test_redis_base.py"""Tests for vumi.persist.redis_base.""" from vumi.persist.redis_base import Manager from vumi.tests.helpers import VumiTestCase class TestBaseRedisManager(VumiTestCase): def mk_manager(self, key_prefix='test', client=None, config=None): if client is None: client = object() return Manager(client, config=config, key_prefix=key_prefix) def test_key_prefix(self): manager = self.mk_manager() self.assertEqual('test:None', manager._key(None)) self.assertEqual('test:foo', manager._key('foo')) def test_no_key_prefix(self): manager = self.mk_manager(None) self.assertEqual(None, manager._key(None)) self.assertEqual('foo', manager._key('foo')) def test_sub_manager(self): manager = self.mk_manager() sub_manager = manager.sub_manager("foo") self.assertEqual(sub_manager._key_prefix, "test:foo") self.assertEqual(sub_manager._client, manager._client) self.assertEqual(sub_manager._key_separator, manager._key_separator) def test_no_key_prefix_sub_manager(self): manager = self.mk_manager(None) sub_manager = manager.sub_manager("foo") self.assertEqual(sub_manager._key_prefix, "foo") self.assertEqual(sub_manager._client, manager._client) self.assertEqual(sub_manager._key_separator, manager._key_separator) def test_client_and_client_proxy_disallowed(self): '''If both the client and the client proxy are specified when creating a manager, then an exception should be raised.''' e = self.assertRaises( AssertionError, Manager, object(), None, None, client_proxy=object()) self.assertEqual( str(e), 'Only one of client or client_proxy may be specified') PK=JG!vumi/persist/tests/test_fields.py# -*- coding: utf-8 -*- """Tests for vumi.persist.fields.""" from datetime import datetime from functools import wraps from twisted.internet.defer import inlineCallbacks, returnValue from vumi.message import Message, TransportUserMessage from vumi.persist.fields import ( ValidationError, Field, Integer, Unicode, Tag, Timestamp, Json, ListOf, SetOf, Dynamic, FieldWithSubtype, Boolean, VumiMessage, ForeignKey, ManyToMany, ComputedValue) from vumi.persist.model import Manager, Model from vumi.tests.helpers import VumiTestCase, MessageHelper, import_skip def needs_riak(method): """ Mark a test method as needing Riak setup. """ method.needs_riak = True return method class ModelFieldTestsDecorator(object): """ Class decorator for replacing `@needs_riak`-marked test methods with two wrapped versions, one for each Riak manager. 
This is used here instead of the more usual mechanism of a mixin and two subclasses because we have a lot of small test classes which have several test methods that don't need Riak. """ def __call__(deco, cls): """ Find all methods on the given class that are marked with `@needs_riak` and replace them with wrapped versions for both RiakManager and TxRiakManager. """ # We can't use `inspect.getmembers()` because of a bug in Python 2.6 # around empty slots: http://bugs.python.org/issue1162154 needs_riak_methods = [] for member_name in dir(cls): # If the class has an empty slot (`__provides__` from # zope.interface, in this case) we get a name for a member that # does not exist. member = getattr(cls, member_name, None) if getattr(member, "needs_riak", False): needs_riak_methods.append((member_name, member)) for name, meth in needs_riak_methods: delattr(cls, name) setattr(cls, name + "__on_riak", deco.wrap_riak_setup(meth)) setattr(cls, name + "__on_txriak", deco.wrap_txriak_setup(meth)) return cls def wrap_riak_setup(deco, meth): """ Return a wrapper around `meth` that sets up a RiakManager. """ @wraps(meth) def wrapper(self): deco.setup_riak(self) return meth(self) return wrapper def wrap_txriak_setup(deco, meth): """ Return a wrapper around `meth` that sets up a TxRiakManager. """ @wraps(meth) def wrapper(self): d = deco.setup_txriak(self) return d.addCallback(lambda _: meth(self)) return wrapper def setup_riak(deco, self): """ Set up a RiakManager on the given test class. """ try: from vumi.persist.riak_manager import RiakManager except ImportError, e: import_skip(e, 'riak') self.manager = RiakManager.from_config({'bucket_prefix': 'test.'}) self.add_cleanup(deco.cleanup_manager, self) self.manager.purge_all() @inlineCallbacks def setup_txriak(deco, self): """ Set up a TxRiakManager on the given test class. """ try: from vumi.persist.txriak_manager import TxRiakManager except ImportError, e: import_skip(e, 'riak') self.manager = TxRiakManager.from_config({'bucket_prefix': 'test.'}) self.add_cleanup(deco.cleanup_manager, self) yield self.manager.purge_all() @inlineCallbacks def cleanup_manager(deco, self): """ Clean up the Riak manager on the given test class. """ yield self.manager.purge_all() yield self.manager.close_manager() model_field_tests = ModelFieldTestsDecorator() def watch_model_changes(modelobj): changed_field_names = [] old_field_changed = modelobj._field_changed def new_field_changed(changed_field_name): changed_field_names.append(changed_field_name) return old_field_changed modelobj._field_changed = new_field_changed return changed_field_names @model_field_tests class TestBaseField(VumiTestCase): def test_validate(self): f = Field() f.validate("foo") f.validate(object()) self.assertRaises(ValidationError, f.validate, None) def test_validate_null(self): f = Field(null=True) f.validate("foo") f.validate(None) def test_to_riak(self): f = Field() obj = object() self.assertEqual(f.to_riak(obj), obj) def test_from_riak(self): f = Field() obj = object() self.assertEqual(f.from_riak(obj), obj) def test_get_descriptor(self): f = Field() descriptor = f.get_descriptor("foo") self.assertEqual(descriptor.key, "foo") self.assertEqual(descriptor.field, f) self.assertTrue("Field object" in repr(descriptor)) class BaseFieldModel(Model): """ Toy model for Field tests. """ f = Field() @needs_riak @Manager.calls_manager def test_assorted_values(self): """ Values are preserved when the field is stored and later loaded.
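# The canonical round-trip these field tests exercise: build a model
# object through the manager proxy, save it, load it back by key, and
# compare field values. As a reusable sketch with its own toy model
# (assumes a configured Riak manager):

from twisted.internet.defer import inlineCallbacks
from vumi.persist.model import Model
from vumi.persist.fields import Integer, Unicode


class RoundTripModel(Model):
    a = Integer()
    b = Unicode()


@inlineCallbacks
def assert_roundtrip(manager):
    proxy = manager.proxy(RoundTripModel)
    yield proxy("foo", a=1, b=u"bar").save()
    loaded = yield proxy.load("foo")
    assert loaded.a == 1
    assert loaded.b == u"bar"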
""" base_model = self.manager.proxy(self.BaseFieldModel) yield base_model("m_str", f="string").save() yield base_model("m_int", f=1).save() yield base_model("m_list", f=["string", 1]).save() yield base_model("m_dict", f={"key": "val"}).save() m_str = yield base_model.load("m_str") self.assertEqual(m_str.f, "string") m_int = yield base_model.load("m_int") self.assertEqual(m_int.f, 1) m_list = yield base_model.load("m_list") self.assertEqual(m_list.f, ["string", 1]) m_dict = yield base_model.load("m_dict") self.assertEqual(m_dict.f, {"key": "val"}) @needs_riak @Manager.calls_manager def test_field_changed_notification(self): """ The model object is notified about changes to the field value. """ base_model = self.manager.proxy(self.BaseFieldModel) m_str = base_model("m_str", f="string") field_changes = watch_model_changes(m_str) self.assertEqual(field_changes, []) # Field set to previous value, no notification. m_str.f = "string" self.assertEqual(field_changes, []) # Field set to new value, notification sent. m_str.f = "string2" self.assertEqual(field_changes, ["f"]) # Model saved, no notification. yield m_str.save() self.assertEqual(field_changes, ["f"]) @model_field_tests class TestInteger(VumiTestCase): def test_validate_unbounded(self): i = Integer() i.validate(5) i.validate(-3) self.assertRaises(ValidationError, i.validate, 5.0) self.assertRaises(ValidationError, i.validate, "5") def test_validate_minimum(self): i = Integer(min=3) i.validate(3) i.validate(4) self.assertRaises(ValidationError, i.validate, 2) def test_validate_maximum(self): i = Integer(max=5) i.validate(5) i.validate(4) self.assertRaises(ValidationError, i.validate, 6) class IntegerModel(Model): """ Toy model for Integer field tests. """ i = Integer() @needs_riak @Manager.calls_manager def test_assorted_values(self): """ Values are preserved when the field is stored and later loaded. """ int_model = self.manager.proxy(self.IntegerModel) yield int_model("m_1", i=1).save() yield int_model("m_leet", i=1337).save() m_1 = yield int_model.load("m_1") self.assertEqual(m_1.i, 1) m_leet = yield int_model.load("m_leet") self.assertEqual(m_leet.i, 1337) @model_field_tests class TestBoolean(VumiTestCase): def test_validate(self): b = Boolean() b.validate(True) b.validate(False) self.assertRaises(ValidationError, b.validate, 'True') self.assertRaises(ValidationError, b.validate, 'False') self.assertRaises(ValidationError, b.validate, 1) self.assertRaises(ValidationError, b.validate, 0) class BooleanModel(Model): """ Toy model for Boolean field tests. """ b = Boolean() @needs_riak @Manager.calls_manager def test_assorted_values(self): """ Values are preserved when the field is stored and later loaded. """ bool_model = self.manager.proxy(self.BooleanModel) yield bool_model("m_t", b=True).save() yield bool_model("m_f", b=False).save() m_t = yield bool_model.load("m_t") self.assertEqual(m_t.b, True) m_f = yield bool_model.load("m_f") self.assertEqual(m_f.b, False) @model_field_tests class TestUnicode(VumiTestCase): def test_validate(self): u = Unicode() u.validate(u"") u.validate(u"a") u.validate(u"æ") u.validate(u"foé") self.assertRaises(ValidationError, u.validate, "") self.assertRaises(ValidationError, u.validate, "foo") self.assertRaises(ValidationError, u.validate, 3) def test_validate_max_length(self): u = Unicode(max_length=5) u.validate(u"12345") u.validate(u"1234") self.assertRaises(ValidationError, u.validate, u"123456") class UnicodeModel(Model): """ Toy model for Unicode field tests. 
""" u = Unicode() @needs_riak @Manager.calls_manager def test_assorted_values(self): """ Values are preserved when the field is stored and later loaded. """ unicode_model = self.manager.proxy(self.UnicodeModel) yield unicode_model("m_empty", u=u"").save() yield unicode_model("m_full", u=u"You must be an optimist").save() yield unicode_model("m_unicode", u=u"foé").save() m_empty = yield unicode_model.load("m_empty") self.assertEqual(m_empty.u, u"") m_full = yield unicode_model.load("m_full") self.assertEqual(m_full.u, u"You must be an optimist") m_unicode = yield unicode_model.load("m_unicode") self.assertEqual(m_unicode.u, u"foé") @model_field_tests class TestTag(VumiTestCase): def test_validate(self): t = Tag() t.validate(("pool", "tagname")) self.assertRaises(ValidationError, t.validate, ["pool", "tagname"]) self.assertRaises(ValidationError, t.validate, ("pool",)) def test_to_riak(self): t = Tag() self.assertEqual(t.to_riak(("pool", "tagname")), ["pool", "tagname"]) def test_from_riak(self): t = Tag() self.assertEqual(t.from_riak(["pool", "tagname"]), ("pool", "tagname")) @model_field_tests class TestTimestamp(VumiTestCase): def test_validate(self): t = Timestamp() t.validate(datetime.now()) t.validate("2007-01-25T12:00:00Z") t.validate(u"2007-01-25T12:00:00Z") self.assertRaises(ValidationError, t.validate, "foo") def test_to_riak(self): t = Timestamp() dt = datetime(2100, 10, 5, 11, 10, 9) self.assertEqual(t.to_riak(dt), "2100-10-05 11:10:09.000000") def test_from_riak(self): t = Timestamp() dt = datetime(2100, 10, 5, 11, 10, 9) self.assertEqual(t.from_riak("2100-10-05 11:10:09.000000"), dt) class TimestampModel(Model): """ Toy model for Timestamp tests. """ time = Timestamp(null=True) @needs_riak def test_set_field(self): """ A timestamp field can be set to a datetime or string value through its descriptor. """ timestamp_model = self.manager.proxy(self.TimestampModel) t = timestamp_model("foo") now = datetime.now() t.time = now self.assertEqual(t.time, now) t.time = u"2007-01-25T12:00:00Z" self.assertEqual(t.time, datetime(2007, 01, 25, 12, 0)) @needs_riak @Manager.calls_manager def test_assorted_values(self): """ Values are preserved when the field is stored and later loaded. """ timestamp_model = self.manager.proxy(self.TimestampModel) now = datetime.now() yield timestamp_model("m_now", time=now).save() yield timestamp_model("m_string", time=u"2007-01-25T12:00:00Z").save() m_now = yield timestamp_model.load("m_now") self.assertEqual(m_now.time, now) m_string = yield timestamp_model.load("m_string") self.assertEqual(m_string.time, datetime(2007, 01, 25, 12, 0)) @model_field_tests class TestJson(VumiTestCase): def test_validate(self): j = Json() j.validate({"foo": None}) self.assertRaises(ValidationError, j.validate, None) def test_to_riak(self): j = Json() d = {"foo": 5} self.assertEqual(j.to_riak(d), d) def test_from_riak(self): j = Json() d = {"foo": [1, 2, 3]} self.assertEqual(j.from_riak(d), d) class JsonModel(Model): """ Toy model for Json tests. """ j = Json() @needs_riak @Manager.calls_manager def test_assorted_values(self): """ Values are preserved when the field is stored and later loaded. 
""" json_model = self.manager.proxy(self.JsonModel) yield json_model("m_str", j="string").save() yield json_model("m_int", j=1).save() yield json_model("m_list", j=["string", 1]).save() yield json_model("m_dict", j={"key": "val"}).save() m_str = yield json_model.load("m_str") self.assertEqual(m_str.j, "string") m_int = yield json_model.load("m_int") self.assertEqual(m_int.j, 1) m_list = yield json_model.load("m_list") self.assertEqual(m_list.j, ["string", 1]) m_dict = yield json_model.load("m_dict") self.assertEqual(m_dict.j, {"key": "val"}) class TestFieldWithSubtype(VumiTestCase): def test_fails_on_fancy_subtype(self): self.assertRaises(RuntimeError, FieldWithSubtype, Dynamic()) @model_field_tests class TestDynamic(VumiTestCase): def test_validate(self): dynamic = Dynamic() dynamic.validate({u'a': u'foo', u'b': u'bar'}) self.assertRaises(ValidationError, dynamic.validate, {u'a': 'foo', u'b': u'bar'}) self.assertRaises(ValidationError, dynamic.validate, u'this is not a dict') self.assertRaises(ValidationError, dynamic.validate, {u'a': 'foo', u'b': 2}) class DynamicModel(Model): """ Toy model for Dynamic tests. """ a = Unicode() contact_info = Dynamic() @needs_riak def test_get_data_with_dynamic_proxy(self): """ A Dynamic field creates more than one field in the Riak object. """ dynamic_model = self.manager.proxy(self.DynamicModel) m = dynamic_model("foo", a=u"ab") m.contact_info['foo'] = u'bar' m.contact_info['zip'] = u'zap' self.assertEqual(m.get_data(), { 'key': 'foo', '$VERSION': None, 'a': 'ab', 'contact_info.foo': 'bar', 'contact_info.zip': 'zap', }) def _create_dynamic_instance(self, dynamic_model, key): m = dynamic_model(key, a=u"ab") m.contact_info['cellphone'] = u"+27123" m.contact_info['telephone'] = u"+2755" m.contact_info['honorific'] = u"BDFL" return m @needs_riak @Manager.calls_manager def test_dynamic_value(self): """ Values are preserved when the field is stored and later loaded. """ dynamic_model = self.manager.proxy(self.DynamicModel) yield self._create_dynamic_instance(dynamic_model, "foo").save() m = yield dynamic_model.load("foo") self.assertEqual(m.a, u"ab") self.assertEqual(m.contact_info['cellphone'], u"+27123") self.assertEqual(m.contact_info['telephone'], u"+2755") self.assertEqual(m.contact_info['honorific'], u"BDFL") @needs_riak def test_dynamic_field_init(self): """ Dynamic fields can be initialised with dicts. """ dynamic_model = self.manager.proxy(self.DynamicModel) contact_info = {'cellphone': u'+27123', 'telephone': u'+2755'} m = dynamic_model("foo", a=u"ab", contact_info=contact_info) self.assertEqual(m.contact_info.copy(), contact_info) @needs_riak def test_dynamic_field_keys_and_values(self): """ Dynamic field keys and values are available as lists or and iterators, similar to a dict. 
""" dynamic_model = self.manager.proxy(self.DynamicModel) m = self._create_dynamic_instance(dynamic_model, "foo") keys = m.contact_info.keys() iterkeys = m.contact_info.iterkeys() self.assertTrue(isinstance(keys, list)) self.assertTrue(hasattr(iterkeys, 'next')) self.assertEqual(sorted(keys), ['cellphone', 'honorific', 'telephone']) self.assertEqual(sorted(iterkeys), sorted(keys)) values = m.contact_info.values() itervalues = m.contact_info.itervalues() self.assertTrue(isinstance(values, list)) self.assertTrue(hasattr(itervalues, 'next')) self.assertEqual(sorted(values), ["+27123", "+2755", "BDFL"]) self.assertEqual(sorted(itervalues), sorted(values)) items = m.contact_info.items() iteritems = m.contact_info.iteritems() self.assertTrue(isinstance(items, list)) self.assertTrue(hasattr(iteritems, 'next')) self.assertEqual(sorted(items), [('cellphone', "+27123"), ('honorific', "BDFL"), ('telephone', "+2755")]) self.assertEqual(sorted(iteritems), sorted(items)) @needs_riak def test_dynamic_field_clear(self): """ Dynamic fields can be cleared. """ dynamic_model = self.manager.proxy(self.DynamicModel) m = self._create_dynamic_instance(dynamic_model, "foo") m.contact_info.clear() self.assertEqual(m.contact_info.items(), []) @needs_riak def test_dynamic_field_update(self): """ Dynamic fields can be bulk-updated. """ dynamic_model = self.manager.proxy(self.DynamicModel) m = self._create_dynamic_instance(dynamic_model, "foo") m.contact_info.update({"cellphone": "123", "name": "foo"}) self.assertEqual(sorted(m.contact_info.items()), [ ('cellphone', "123"), ('honorific', "BDFL"), ('name', "foo"), ('telephone', "+2755")]) @needs_riak def test_dynamic_field_contains(self): """ Dynamic fields support `in`. """ dynamic_model = self.manager.proxy(self.DynamicModel) m = self._create_dynamic_instance(dynamic_model, "foo") self.assertTrue("cellphone" in m.contact_info) self.assertFalse("landline" in m.contact_info) @needs_riak def test_dynamic_field_del(self): """ Values can be removed from dynamic fields. """ dynamic_model = self.manager.proxy(self.DynamicModel) m = self._create_dynamic_instance(dynamic_model, "foo") del m.contact_info["telephone"] self.assertEqual(sorted(m.contact_info.keys()), ['cellphone', 'honorific']) @needs_riak def test_dynamic_field_setting(self): """ Setting a dynamic field to a dict replaces its contents. """ dynamic_model = self.manager.proxy(self.DynamicModel) m = self._create_dynamic_instance(dynamic_model, "foo") m.contact_info = {u'cellphone': u'789', u'name': u'foo'} self.assertEqual(sorted(m.contact_info.items()), [ (u'cellphone', u'789'), (u'name', u'foo'), ]) @model_field_tests class TestListOf(VumiTestCase): def test_validate(self): """ By default, a ListOf field is a list of Unicode fields. """ listof = ListOf() listof.validate([u'foo', u'bar']) self.assertRaises(ValidationError, listof.validate, u'this is not a list') self.assertRaises(ValidationError, listof.validate, ['a', 2]) self.assertRaises(ValidationError, listof.validate, [1, 2]) def test_validate_with_subtype(self): """ If an explicit subtype is provided, its validation is used. 
""" listof_unicode = ListOf(Unicode()) listof_unicode.validate([u"a", u"b"]) self.assertRaises(ValidationError, listof_unicode.validate, [1, 2]) listof_int = ListOf(Integer()) listof_int.validate([1, 2]) self.assertRaises(ValidationError, listof_int.validate, [u"a", u"b"]) listof_smallint = ListOf(Integer(max=10)) listof_smallint.validate([1, 2]) self.assertRaises( ValidationError, listof_smallint.validate, [1, 100]) class ListOfModel(Model): """ Toy model for ListOf tests. """ items = ListOf(Integer()) texts = ListOf(Unicode()) class IndexedListOfModel(Model): """ Toy model for ListOf index tests. """ items = ListOf(Integer(), index=True) @needs_riak def test_get_data_with_list_proxy(self): """ A ListOf field creates a list field in the Riak object. """ list_model = self.manager.proxy(self.ListOfModel) m = list_model("foo") m.items.append(1) m.items.append(42) m.texts.append(u"Thing 1.") m.texts.append(u"Thing 42.") self.assertEqual(m.get_data(), { 'key': 'foo', '$VERSION': None, 'items': [1, 42], 'texts': [u"Thing 1.", u"Thing 42."], }) @needs_riak @Manager.calls_manager def test_listof_fields(self): """ A ListOf field can be manipulated as if it were a list. """ list_model = self.manager.proxy(self.ListOfModel) l1 = list_model("foo") l1.items.append(1) l1.items.append(2) yield l1.save() l2 = yield list_model.load("foo") self.assertEqual(l2.items[0], 1) self.assertEqual(l2.items[1], 2) self.assertEqual(list(l2.items), [1, 2]) l2.items[0] = 5 self.assertEqual(l2.items[0], 5) del l2.items[0] self.assertEqual(list(l2.items), [2]) l2.items.append(5) self.assertEqual(list(l2.items), [2, 5]) l2.items.remove(5) self.assertEqual(list(l2.items), [2]) l2.items.extend([3, 4, 5]) self.assertEqual(list(l2.items), [2, 3, 4, 5]) l2.items = [1] self.assertEqual(list(l2.items), [1]) @needs_riak @Manager.calls_manager def test_listof_fields_indexes(self): """ An indexed ListOf field has an index value for each item in the list. """ list_model = self.manager.proxy(self.IndexedListOfModel) l1 = list_model("foo") l1.items.append(1) l1.items.append(2) yield l1.save() assert_indexes = lambda mdl, values: self.assertEqual( mdl._riak_object.get_indexes(), set(('items_bin', str(v)) for v in values)) l2 = yield list_model.load("foo") self.assertEqual(l2.items[0], 1) self.assertEqual(l2.items[1], 2) self.assertEqual(list(l2.items), [1, 2]) assert_indexes(l2, [1, 2]) l2.items[0] = 5 self.assertEqual(l2.items[0], 5) assert_indexes(l2, [2, 5]) del l2.items[0] self.assertEqual(list(l2.items), [2]) assert_indexes(l2, [2]) l2.items.append(5) self.assertEqual(list(l2.items), [2, 5]) assert_indexes(l2, [2, 5]) l2.items.remove(5) self.assertEqual(list(l2.items), [2]) assert_indexes(l2, [2]) l2.items.extend([3, 4, 5]) self.assertEqual(list(l2.items), [2, 3, 4, 5]) assert_indexes(l2, [2, 3, 4, 5]) l2.items = [1] self.assertEqual(list(l2.items), [1]) assert_indexes(l2, [1]) @model_field_tests class TestSetOf(VumiTestCase): def test_validate(self): """ By default, a SetOf field is a set of Unicode fields. """ f = SetOf() f.validate(set([u'foo', u'bar'])) self.assertRaises(ValidationError, f.validate, u'this is not a set') self.assertRaises(ValidationError, f.validate, set(['a', 2])) self.assertRaises(ValidationError, f.validate, [u'a', u'b']) def test_validate_with_subtype(self): """ If an explicit subtype is provided, its validation is used. 
""" setof_unicode = SetOf(Unicode()) setof_unicode.validate(set([u"a", u"b"])) self.assertRaises(ValidationError, setof_unicode.validate, set([1, 2])) setof_int = SetOf(Integer()) setof_int.validate(set([1, 2])) self.assertRaises( ValidationError, setof_int.validate, set([u"a", u"b"])) setof_smallint = SetOf(Integer(max=10)) setof_smallint.validate(set([1, 2])) self.assertRaises( ValidationError, setof_smallint.validate, set([1, 100])) def test_to_riak(self): """ The JSON representation of a SetOf field is a sorted list. """ f = SetOf() self.assertEqual(f.to_riak(set([1, 2, 3])), [1, 2, 3]) def test_from_riak(self): """ The JSON list is turned into a set when read. """ f = SetOf() self.assertEqual(f.from_riak([1, 2, 3]), set([1, 2, 3])) class SetOfModel(Model): """ Toy model for SetOf tests. """ items = SetOf(Integer()) texts = SetOf(Unicode()) class IndexedSetOfModel(Model): """ Toy model for SetOf index tests. """ items = SetOf(Integer(), index=True) @needs_riak def test_setof_fields_validation(self): set_model = self.manager.proxy(self.SetOfModel) m1 = set_model("foo") self.assertRaises(ValidationError, m1.items.add, "foo") self.assertRaises(ValidationError, m1.items.remove, "foo") self.assertRaises(ValidationError, m1.items.discard, "foo") self.assertRaises(ValidationError, m1.items.update, set(["foo"])) @needs_riak @Manager.calls_manager def test_setof_fields(self): """ A SetOf field can be manipulated as if it were a set. """ set_model = self.manager.proxy(self.SetOfModel) m1 = set_model("foo") m1.items.add(1) m1.items.add(2) yield m1.save() m2 = yield set_model.load("foo") self.assertTrue(1 in m2.items) self.assertTrue(2 in m2.items) self.assertEqual(set(m2.items), set([1, 2])) m2.items.add(5) self.assertTrue(5 in m2.items) m2.items.remove(1) self.assertTrue(1 not in m2.items) self.assertRaises(KeyError, m2.items.remove, 1) m2.items.add(1) m2.items.discard(1) self.assertTrue(1 not in m2.items) m2.items.discard(1) self.assertTrue(1 not in m2.items) m2.items.update([3, 4, 5]) self.assertEqual(set(m2.items), set([2, 3, 4, 5])) m2.items = set([7, 8]) self.assertEqual(set(m2.items), set([7, 8])) @needs_riak @Manager.calls_manager def test_setof_fields_indexes(self): """ An indexed SetOf field has an index value for each item in the set. 
""" set_model = self.manager.proxy(self.IndexedSetOfModel) m1 = set_model("foo") m1.items.add(1) m1.items.add(2) yield m1.save() assert_indexes = lambda mdl, values: self.assertEqual( mdl._riak_object.get_indexes(), set(('items_bin', str(v)) for v in values)) m2 = yield set_model.load("foo") self.assertTrue(1 in m2.items) self.assertTrue(2 in m2.items) self.assertEqual(set(m2.items), set([1, 2])) assert_indexes(m2, [1, 2]) m2.items.add(5) self.assertTrue(5 in m2.items) assert_indexes(m2, [1, 2, 5]) m2.items.remove(1) self.assertTrue(1 not in m2.items) assert_indexes(m2, [2, 5]) m2.items.add(1) m2.items.discard(1) self.assertTrue(1 not in m2.items) assert_indexes(m2, [2, 5]) m2.items.discard(1) self.assertTrue(1 not in m2.items) assert_indexes(m2, [2, 5]) m2.items.update([3, 4, 5]) self.assertEqual(set(m2.items), set([2, 3, 4, 5])) assert_indexes(m2, [2, 3, 4, 5]) m2.items = set([7, 8]) self.assertEqual(set(m2.items), set([7, 8])) assert_indexes(m2, [7, 8]) @model_field_tests class TestVumiMessage(VumiTestCase): def test_validate(self): f = VumiMessage(Message) msg = Message() f.validate(msg) self.assertRaises( ValidationError, f.validate, u'this is not a vumi message') self.assertRaises( ValidationError, f.validate, None) class VumiMessageModel(Model): """ Toy model for VumiMessage tests. """ msg = VumiMessage(TransportUserMessage) @needs_riak @Manager.calls_manager def test_vumimessage_field(self): msg_helper = self.add_helper(MessageHelper()) msg_model = self.manager.proxy(self.VumiMessageModel) msg = msg_helper.make_inbound("foo", extra="bar") m1 = msg_model("foo", msg=msg) yield m1.save() m2 = yield msg_model.load("foo") self.assertEqual(m1.msg, m2.msg) self.assertEqual(m2.msg, msg) self.assertRaises(ValidationError, setattr, m1, "msg", "foo") # test extra keys are removed msg2 = msg_helper.make_inbound("foo") m1.msg = msg2 self.assertTrue("extra" not in m1.msg) @needs_riak @Manager.calls_manager def test_vumimessage_field_excludes_cache(self): msg_helper = self.add_helper(MessageHelper()) msg_model = self.manager.proxy(self.VumiMessageModel) cache_attr = TransportUserMessage._CACHE_ATTRIBUTE msg = msg_helper.make_inbound("foo", extra="bar") msg.cache["cache"] = "me" self.assertEqual(msg[cache_attr], {"cache": "me"}) m1 = msg_model("foo", msg=msg) self.assertTrue(cache_attr not in m1.msg) yield m1.save() m2 = yield msg_model.load("foo") self.assertTrue(cache_attr not in m2.msg) self.assertEqual(m2.msg, m1.msg) class ReferencedModel(Model): """ Toy model for testing fields that reference other models. """ a = Integer() b = Unicode() @model_field_tests class TestForeignKey(VumiTestCase): class ForeignKeyModel(Model): """ Toy model for ForeignKey tests. """ referenced = ForeignKey(ReferencedModel, null=True) @needs_riak def test_get_data_with_foreign_key_proxy(self): """ A ForeignKey field stores the referenced key in the Riak object. """ referenced_model = self.manager.proxy(ReferencedModel) s1 = referenced_model("foo", a=5, b=u'3') fk_model = self.manager.proxy(self.ForeignKeyModel) f1 = fk_model("bar1") f1.referenced.set(s1) self.assertEqual(f1.get_data(), { 'key': 'bar1', '$VERSION': None, 'referenced': 'foo' }) @needs_riak @Manager.calls_manager def test_foreignkey_fields(self): """ ForeignKey fields can operate on both keys and model instances. 
""" fk_model = self.manager.proxy(self.ForeignKeyModel) referenced_model = self.manager.proxy(ReferencedModel) s1 = referenced_model("foo", a=5, b=u'3') f1 = fk_model("bar") f1.referenced.set(s1) yield s1.save() yield f1.save() self.assertEqual(f1._riak_object.get_data()['referenced'], s1.key) f2 = yield fk_model.load("bar") s2 = yield f2.referenced.get() self.assertEqual(f2.referenced.key, "foo") self.assertEqual(s2.a, 5) self.assertEqual(s2.b, u"3") f2.referenced.set(None) s3 = yield f2.referenced.get() self.assertEqual(s3, None) f2.referenced.key = "foo" s4 = yield f2.referenced.get() self.assertEqual(s4.key, "foo") f2.referenced.key = None s5 = yield f2.referenced.get() self.assertEqual(s5, None) self.assertRaises(ValidationError, f2.referenced.set, object()) @needs_riak @Manager.calls_manager def test_old_foreignkey_fields(self): """ Old versions of the ForeignKey field relied entirely on indexes and didn't store the referenced key in the model data. """ fk_model = self.manager.proxy(self.ForeignKeyModel) referenced_model = self.manager.proxy(ReferencedModel) s1 = referenced_model("foo", a=5, b=u'3') f1 = fk_model("bar") # Create index directly and remove data field to simulate old-style # index-only implementation f1._riak_object.add_index('referenced_bin', s1.key) data = f1._riak_object.get_data() data.pop('referenced') f1._riak_object.set_data(data) yield s1.save() yield f1.save() f2 = yield fk_model.load("bar") s2 = yield f2.referenced.get() self.assertEqual(f2.referenced.key, "foo") self.assertEqual(s2.a, 5) self.assertEqual(s2.b, u"3") f2.referenced.set(None) s3 = yield f2.referenced.get() self.assertEqual(s3, None) f2.referenced.key = "foo" s4 = yield f2.referenced.get() self.assertEqual(s4.key, "foo") f2.referenced.key = None s5 = yield f2.referenced.get() self.assertEqual(s5, None) self.assertRaises(ValidationError, f2.referenced.set, object()) @needs_riak @Manager.calls_manager def test_reverse_foreignkey_fields(self): """ When we declare a ForeignKey field, we add both a paginated index lookup method and a legacy non-paginated index lookup method to the foreign model's backlinks attribute. """ fk_model = self.manager.proxy(self.ForeignKeyModel) referenced_model = self.manager.proxy(ReferencedModel) s1 = referenced_model("foo", a=5, b=u'3') f1 = fk_model("bar1") f1.referenced.set(s1) f2 = fk_model("bar2") f2.referenced.set(s1) yield s1.save() yield f1.save() yield f2.save() s2 = yield referenced_model.load("foo") results = yield s2.backlinks.foreignkeymodels() self.assertEqual(sorted(results), ["bar1", "bar2"]) results_p1 = yield s2.backlinks.foreignkeymodel_keys() self.assertEqual(sorted(results_p1), ["bar1", "bar2"]) self.assertEqual(results_p1.has_next_page(), False) @model_field_tests class TestManyToMany(VumiTestCase): @Manager.calls_manager def load_all_bunches_flat(self, m2m_field): results = [] for result_bunch in m2m_field.load_all_bunches(): results.extend((yield result_bunch)) returnValue(results) class ManyToManyModel(Model): references = ManyToMany(ReferencedModel) @needs_riak def test_get_data_with_many_to_many_proxy(self): """ A ManyToMany field stores the referenced keys in the Riak object. 
""" referenced_model = self.manager.proxy(ReferencedModel) s1 = referenced_model("foo", a=5, b=u'3') mm_model = self.manager.proxy(self.ManyToManyModel) m1 = mm_model("bar") m1.references.add(s1) m1.save() self.assertEqual(m1.get_data(), { 'key': 'bar', '$VERSION': None, 'references': ['foo'], }) @needs_riak @Manager.calls_manager def test_manytomany_field(self): """ ManyToMany fields can operate on both keys and model instances. """ mm_model = self.manager.proxy(self.ManyToManyModel) referenced_model = self.manager.proxy(ReferencedModel) s1 = referenced_model("foo", a=5, b=u'3') m1 = mm_model("bar") m1.references.add(s1) yield s1.save() yield m1.save() self.assertEqual(m1._riak_object.get_data()['references'], [s1.key]) m2 = yield mm_model.load("bar") [s2] = yield self.load_all_bunches_flat(m2.references) self.assertEqual(m2.references.keys(), ["foo"]) self.assertEqual(s2.a, 5) self.assertEqual(s2.b, u"3") m2.references.remove(s2) references = yield self.load_all_bunches_flat(m2.references) self.assertEqual(references, []) m2.references.add_key("foo") [s4] = yield self.load_all_bunches_flat(m2.references) self.assertEqual(s4.key, "foo") m2.references.remove_key("foo") references = yield self.load_all_bunches_flat(m2.references) self.assertEqual(references, []) self.assertRaises(ValidationError, m2.references.add, object()) self.assertRaises(ValidationError, m2.references.remove, object()) t1 = referenced_model("bar1", a=3, b=u'4') t2 = referenced_model("bar2", a=4, b=u'4') m2.references.add(t1) m2.references.add(t2) yield t1.save() yield t2.save() references = yield self.load_all_bunches_flat(m2.references) references.sort(key=lambda s: s.key) self.assertEqual([s.key for s in references], ["bar1", "bar2"]) self.assertEqual(references[0].a, 3) self.assertEqual(references[1].a, 4) m2.references.clear() m2.references.add_key("unknown") self.assertEqual([], (yield self.load_all_bunches_flat(m2.references))) @needs_riak @Manager.calls_manager def test_old_manytomany_field(self): """ Old versions of the ManyToMany field relied entirely on indexes and didn't store the referenced keys in the model data. 
""" mm_model = self.manager.proxy(self.ManyToManyModel) referenced_model = self.manager.proxy(ReferencedModel) s1 = referenced_model("foo", a=5, b=u'3') m1 = mm_model("bar") # Create index directly to simulate old-style index-only implementation m1._riak_object.add_index('references_bin', s1.key) # Manually remove the entry from the data dict to allow it to be # set from the index value in descriptor.clean() data = m1._riak_object.get_data() data.pop('references') m1._riak_object.set_data(data) yield s1.save() yield m1.save() m2 = yield mm_model.load("bar") [s2] = yield self.load_all_bunches_flat(m2.references) self.assertEqual(m2.references.keys(), ["foo"]) self.assertEqual(s2.a, 5) self.assertEqual(s2.b, u"3") m2.references.remove(s2) references = yield self.load_all_bunches_flat(m2.references) self.assertEqual(references, []) m2.references.add_key("foo") [s4] = yield self.load_all_bunches_flat(m2.references) self.assertEqual(s4.key, "foo") m2.references.remove_key("foo") references = yield self.load_all_bunches_flat(m2.references) self.assertEqual(references, []) self.assertRaises(ValidationError, m2.references.add, object()) self.assertRaises(ValidationError, m2.references.remove, object()) t1 = referenced_model("bar1", a=3, b=u'4') t2 = referenced_model("bar2", a=4, b=u'4') m2.references.add(t1) m2.references.add(t2) yield t1.save() yield t2.save() references = yield self.load_all_bunches_flat(m2.references) references.sort(key=lambda s: s.key) self.assertEqual([s.key for s in references], ["bar1", "bar2"]) self.assertEqual(references[0].a, 3) self.assertEqual(references[1].a, 4) m2.references.clear() m2.references.add_key("unknown") self.assertEqual([], (yield self.load_all_bunches_flat(m2.references))) @needs_riak @Manager.calls_manager def test_reverse_manytomany_fields(self): """ When we declare a ManyToMany field, we add both a paginated index lookup method and a legacy non-paginated index lookup method to the foreign model's backlinks attribute. """ mm_model = self.manager.proxy(self.ManyToManyModel) referenced_model = self.manager.proxy(ReferencedModel) s1 = referenced_model("foo1", a=5, b=u'3') s2 = referenced_model("foo2", a=4, b=u'4') m1 = mm_model("bar1") m1.references.add(s1) m1.references.add(s2) m2 = mm_model("bar2") m2.references.add(s1) yield s1.save() yield s2.save() yield m1.save() yield m2.save() s1 = yield referenced_model.load("foo1") results = yield s1.backlinks.manytomanymodels() self.assertEqual(sorted(results), ["bar1", "bar2"]) results_p1 = yield s1.backlinks.manytomanymodel_keys() self.assertEqual(sorted(results_p1), ["bar1", "bar2"]) self.assertEqual(results_p1.has_next_page(), False) s2 = yield referenced_model.load("foo2") results = yield s2.backlinks.manytomanymodels() self.assertEqual(sorted(results), ["bar1"]) results_p1 = yield s2.backlinks.manytomanymodel_keys() self.assertEqual(sorted(results_p1), ["bar1"]) self.assertEqual(results_p1.has_next_page(), False) @model_field_tests class TestComputedValue(VumiTestCase): def test_validate_default_unicode(self): """ By default, a ComputedValue field has a Unicode value. """ comp = ComputedValue(lambda m: NotImplemented) comp.validate(u"") comp.validate(u"a") comp.validate(u"æ") comp.validate(u"foé") self.assertRaises(ValidationError, comp.validate, "") self.assertRaises(ValidationError, comp.validate, "foo") self.assertRaises(ValidationError, comp.validate, 3) def test_validate_listof(self): """ If an explicit subtype is provided, its validation is used. 
""" comp = ComputedValue(lambda m: NotImplemented, ListOf()) comp.validate([u'foo', u'bar']) self.assertRaises(ValidationError, comp.validate, u'this is not a list') self.assertRaises(ValidationError, comp.validate, ['a', 2]) self.assertRaises(ValidationError, comp.validate, [1, 2]) class ComputedValueModel(Model): """ Toy model for ComputedValue tests. """ a = Integer() b = Unicode() c = ListOf(Unicode()) a_with_b = ComputedValue( lambda m: u"%s::%s" % (m.a, m.b), Unicode(index=True)) b_with_a = ComputedValue(lambda m: u"%s::%s" % (m.b, m.a), Unicode()) a_with_c = ComputedValue( lambda m: [u"%s::%s" % (m.a, c) for c in m.c], ListOf(Unicode(), index=True)) def assert_indexes(self, mdl, indexes): self.assertEqual( mdl._riak_object.get_indexes(), set((k, v) for k, vs in indexes.items() for v in vs)) @needs_riak @Manager.calls_manager def test_computedvalue_field(self): """ A `ComputedValue` field gets its value from the function it's given. """ ci_model = self.manager.proxy(self.ComputedValueModel) m1 = ci_model("foo", a=7, b=u"bar", c=[u"thing1", u"thing2"]) # Value and index are computed at creation time. self.assertEqual(m1.a_with_b, u"7::bar") self.assertEqual(m1.b_with_a, u"bar::7") self.assertEqual(list(m1.a_with_c), [u"7::thing1", u"7::thing2"]) self.assert_indexes(m1, { "a_with_b_bin": ["7::bar"], "a_with_c_bin": ["7::thing1", "7::thing2"], }) # Value and index are correct after save. yield m1.save() self.assertEqual(m1.a_with_b, u"7::bar") self.assertEqual(m1.b_with_a, u"bar::7") self.assertEqual(list(m1.a_with_c), [u"7::thing1", u"7::thing2"]) self.assert_indexes(m1, { "a_with_b_bin": ["7::bar"], "a_with_c_bin": ["7::thing1", "7::thing2"], }) # Value and index are correct after load. m2 = yield ci_model.load("foo") self.assertEqual(m1.a, m2.a) self.assertEqual(m1.b, m2.b) self.assertEqual(m1.a_with_b, "7::bar") self.assertEqual(list(m1.a_with_c), ["7::thing1", "7::thing2"]) self.assert_indexes(m1, { "a_with_b_bin": ["7::bar"], "a_with_c_bin": ["7::thing1", "7::thing2"], }) @needs_riak def test_computedvalue_field_update(self): """ A `ComputedValue` field's value (and index) are updated when a field value changes. This test needs a Riak manager, but never actually calls Riak. """ ci_model = self.manager.proxy(self.ComputedValueModel) m1 = ci_model("foo", a=7, b=u"bar", c=[u"thing1", u"thing2"]) # Value and index are computed at creation time. self.assertEqual(m1.a_with_b, u"7::bar") self.assertEqual(m1.b_with_a, u"bar::7") self.assertEqual(list(m1.a_with_c), [u"7::thing1", u"7::thing2"]) self.assert_indexes(m1, { "a_with_b_bin": ["7::bar"], "a_with_c_bin": ["7::thing1", "7::thing2"], }) # Value and index are recomputed after a field changes. m1.a = 8 self.assertEqual(m1.a_with_b, u"8::bar") self.assertEqual(m1.b_with_a, u"bar::8") self.assertEqual(list(m1.a_with_c), [u"8::thing1", u"8::thing2"]) self.assert_indexes(m1, { "a_with_b_bin": ["8::bar"], "a_with_c_bin": ["8::thing1", "8::thing2"], }) @needs_riak def test_computedindex_field_cannot_set_value(self): """ A `ComputedValue` field cannot have its value set manually. This test needs a Riak manager, but never actually calls Riak. """ ci_model = self.manager.proxy(self.ComputedValueModel) m1 = ci_model("foo", a=7, b=u"bar") # Manually setting the value is impossible. 
def set_value(): m1.a_with_b = u"a thing" self.assertRaises(RuntimeError, set_value) PK=JG6bb'vumi/persist/tests/test_riak_manager.py"""Tests for vumi.persist.riak_manager.""" from itertools import count from twisted.internet.defer import returnValue from vumi.persist.tests.test_txriak_manager import ( CommonRiakManagerTests, DummyModel) from vumi.persist.model import Manager from vumi.tests.helpers import VumiTestCase, import_skip class TestRiakManager(CommonRiakManagerTests, VumiTestCase): """Most tests are inherited from the CommonRiakManagerTests mixin.""" def setUp(self): try: from vumi.persist.riak_manager import ( RiakManager, flatten_generator) except ImportError, e: import_skip(e, 'riak') self.call_decorator = flatten_generator self.manager = RiakManager.from_config({'bucket_prefix': 'test.'}) self.add_cleanup(self.manager.purge_all) self.manager.purge_all() def test_call_decorator(self): self.assertEqual(type(self.manager).call_decorator, self.call_decorator) def test_flatten_generator(self): results = [] counter = count() @self.call_decorator def f(): for i in range(3): a = yield counter.next() results.append(a) ret = f() self.assertEqual(ret, None) self.assertEqual(results, list(range(3))) def test_flatten_generator_with_return_value(self): @self.call_decorator def f(): yield None returnValue("foo") ret = f() self.assertEqual(ret, "foo") @Manager.calls_manager def test_run_riak_map_reduce_and_fetch_results(self): dummies = [self.mkdummy(str(i), {"a": i}) for i in range(4)] for dummy in dummies: dummy.add_index('test_index_bin', 'test_key') yield self.manager.store(dummy) mr = self.manager.riak_map_reduce() mr.index('test.dummy_model', 'test_index_bin', 'test_key') mr.map(function='function(v) { return [[v.key, v.values[0]]] }') mr_results = [] def mapper(manager, key_and_result_tuple): self.assertEqual(manager, self.manager) key, result = key_and_result_tuple model_instance = manager.load(DummyModel, key, result) mr_results.append(model_instance) return model_instance results = yield self.manager.run_map_reduce(mr, mapper) results.sort(key=lambda d: d.key) expected_keys = [str(i) for i in range(4)] expected_data = [{"a": i} for i in range(4)] self.assertEqual([d.key for d in results], expected_keys) mr_results.sort(key=lambda model_instance: model_instance.key) self.assertEqual([model.key for model in mr_results], expected_keys) self.assertEqual( [model.get_data() for model in mr_results], expected_data) def test_transport_class_protocol_buffer(self): manager_class = type(self.manager) manager = manager_class.from_config({ 'transport_type': 'pbc', 'bucket_prefix': 'test.', }) self.assertEqual(manager.client.protocol, 'pbc') def test_transport_class_http(self): manager_class = type(self.manager) manager = manager_class.from_config({ 'transport_type': 'http', 'bucket_prefix': 'test.', }) self.assertEqual(manager.client.protocol, 'http') def test_transport_class_default(self): manager_class = type(self.manager) manager = manager_class.from_config({ 'bucket_prefix': 'test.', }) self.assertEqual(manager.client.protocol, 'http') PK=JG&7&7)vumi/persist/tests/test_txriak_manager.py"""Tests for vumi.persist.txriak_manager.""" from twisted.internet.defer import inlineCallbacks from vumi.persist.model import Manager, VumiRiakError from vumi.tests.helpers import VumiTestCase, import_skip class DummyModel(object): bucket = "dummy_model" VERSION = None MIGRATORS = None def __init__(self, manager, key, _riak_object=None): self.manager = manager self.key = key self._riak_object = _riak_object
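# The flatten_generator decorator tested in TestRiakManager above runs a
# generator to completion synchronously, feeding each yielded value straight
# back in and unwrapping returnValue(). A rough sketch of the idea -- not
# vumi's actual implementation (_DefGen_Return is the internal exception
# Twisted's returnValue() raises):
from twisted.internet.defer import _DefGen_Return

def flatten_generator_sketch(func):
    def wrapper(*args, **kw):
        gen = func(*args, **kw)
        value = None
        try:
            while True:
                value = gen.send(value)  # hand each yielded value back in
        except StopIteration:
            return None
        except _DefGen_Return, e:
            return e.value
    return wrapper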
@classmethod def load(cls, manager, key, result=None): return manager.load(cls, key, result=result) def set_riak(self, riak_object): self._riak_object = riak_object def get_data(self): return self._riak_object.get_data() def set_data(self, data): self._riak_object.set_data(data) def add_index(self, index_name, key): self._riak_object.add_index(index_name, key) def get_link_key(link): return link[1] class CommonRiakManagerTests(object): """Common tests for Riak managers. Tests assume self.manager is set to a suitable Riak manager. """ def mkdummy(self, key, data=None, dummy_class=DummyModel): dummy = dummy_class(self.manager, key) dummy.set_riak(self.manager.riak_object(dummy, key)) if data is not None: dummy.set_data(data) return dummy def test_from_config(self): manager_cls = self.manager.__class__ manager = manager_cls.from_config({'bucket_prefix': 'test.'}) self.assertEqual(manager.__class__, manager_cls) self.assertEqual(manager.load_bunch_size, manager.DEFAULT_LOAD_BUNCH_SIZE) self.assertEqual(manager.mapreduce_timeout, manager.DEFAULT_MAPREDUCE_TIMEOUT) def test_from_config_with_bunch_size(self): manager_cls = self.manager.__class__ manager = manager_cls.from_config({'bucket_prefix': 'test.', 'load_bunch_size': 10, }) self.assertEqual(manager.load_bunch_size, 10) def test_from_config_with_mapreduce_timeout(self): manager_cls = self.manager.__class__ manager = manager_cls.from_config({'bucket_prefix': 'test.', 'mapreduce_timeout': 1000, }) self.assertEqual(manager.mapreduce_timeout, 1000) def test_from_config_with_store_versions(self): manager_cls = self.manager.__class__ manager = manager_cls.from_config({ 'bucket_prefix': 'test.', 'store_versions': { 'foo.Foo': 3, 'bar.Bar': None, }, }) self.assertEqual(manager.store_versions, { 'foo.Foo': 3, 'bar.Bar': None, }) def test_sub_manager(self): """ A sub-manager shares its parent's client object, but has an additional suffix on its bucket_prefix. """ sub_manager = self.manager.sub_manager("foo.") self.assertEqual(sub_manager.client, self.manager.client) self.assertEqual(sub_manager._parent, self.manager) self.assertEqual(sub_manager.bucket_prefix, 'test.foo.') def test_sub_manager_unclosed(self): """ A sub-manager is never "unclosed", because the parent is responsible for managing the client object. """ sub_manager = self.manager.sub_manager("foo.") self.assertEqual(sub_manager.client, self.manager.client) self.assertEqual(sub_manager.client._closed, False) self.assertEqual(sub_manager._is_unclosed(), False) @Manager.calls_manager def test_sub_manager_close(self): """ A sub-manager does not close its client object, because the parent is responsible for managing the client object. 
""" sub_manager = self.manager.sub_manager("foo.") self.assertEqual(sub_manager.client, self.manager.client) self.assertEqual(sub_manager.client._closed, False) yield sub_manager.close_manager() self.assertEqual(sub_manager.client._closed, False) def test_bucket_name_on_modelcls(self): dummy = self.mkdummy("bar") bucket_name = self.manager.bucket_name(type(dummy)) self.assertEqual(bucket_name, "test.dummy_model") def test_bucket_name_on_instance(self): dummy = self.mkdummy("bar") bucket_name = self.manager.bucket_name(dummy) self.assertEqual(bucket_name, "test.dummy_model") def test_bucket_for_modelcls(self): dummy_cls = type(self.mkdummy("foo")) bucket1 = self.manager.bucket_for_modelcls(dummy_cls) bucket2 = self.manager.bucket_for_modelcls(dummy_cls) self.assertEqual(id(bucket1), id(bucket2)) self.assertEqual(bucket1.get_name(), "test.dummy_model") def test_riak_object(self): dummy = DummyModel(self.manager, "foo") riak_object = self.manager.riak_object(dummy, "foo") self.assertEqual(riak_object.get_data(), {'$VERSION': None}) self.assertEqual(riak_object.get_content_type(), "application/json") self.assertEqual( riak_object.get_bucket().get_name(), "test.dummy_model") self.assertEqual(riak_object.key, "foo") @Manager.calls_manager def test_store_and_load(self): dummy1 = self.mkdummy("foo", {"a": 1}) result1 = yield self.manager.store(dummy1) self.assertEqual(dummy1, result1) dummy2 = yield self.manager.load(DummyModel, "foo") self.assertEqual(dummy2.get_data(), {"a": 1}) @Manager.calls_manager def test_delete(self): dummy1 = self.mkdummy("foo", {"a": 1}) yield self.manager.store(dummy1) dummy2 = yield self.manager.load(DummyModel, "foo") yield self.manager.delete(dummy2) dummy3 = yield self.manager.load(DummyModel, "foo") self.assertEqual(dummy3, None) @Manager.calls_manager def test_load_missing(self): dummy = self.mkdummy("unknown") result = yield self.manager.load(DummyModel, dummy.key) self.assertEqual(result, None) @Manager.calls_manager def test_load_all_bunches(self): yield self.manager.store(self.mkdummy("foo", {"a": 0})) yield self.manager.store(self.mkdummy("bar", {"a": 1})) yield self.manager.store(self.mkdummy("baz", {"a": 2})) self.manager.load_bunch_size = load_bunch_size = 2 keys = ["foo", "unknown", "bar", "baz"] result_data = [] for result_bunch in self.manager.load_all_bunches(DummyModel, keys): bunch = yield result_bunch self.assertTrue(len(bunch) <= load_bunch_size) result_data.extend(result.get_data() for result in bunch) result_data.sort(key=lambda d: d["a"]) self.assertEqual(result_data, [{"a": 0}, {"a": 1}, {"a": 2}]) @Manager.calls_manager def test_run_riak_map_reduce(self): dummies = [self.mkdummy(str(i), {"a": i}) for i in range(4)] for dummy in dummies: dummy.add_index('test_index_bin', 'test_key') yield self.manager.store(dummy) mr = self.manager.riak_map_reduce() mr.index('test.dummy_model', 'test_index_bin', 'test_key') mr_results = [] def mapper(manager, link): self.assertEqual(manager, self.manager) mr_results.append(link) dummy = self.mkdummy(get_link_key(link)) return manager.load(DummyModel, dummy.key) results = yield self.manager.run_map_reduce(mr, mapper) results.sort(key=lambda d: d.key) expected_keys = [str(i) for i in range(4)] self.assertEqual([d.key for d in results], expected_keys) mr_results.sort(key=get_link_key) self.assertEqual([get_link_key(l) for l in mr_results], expected_keys) @Manager.calls_manager def test_run_riak_map_reduce_with_timeout(self): dummies = [self.mkdummy(str(i), {"a": i}) for i in range(4)] for dummy in dummies: 
dummy.add_index('test_index_bin', 'test_key') yield self.manager.store(dummy) # override mapreduce_timeout for testing self.manager.mapreduce_timeout = 10 # milliseconds mr = self.manager.riak_map_reduce() mr.index('test.dummy_model', 'test_index_bin', 'test_key') mr.map( """ function(value, keyData) { var date = new Date(); var curDate = null; // Wait 11ms so we run past the 10ms timeout. do { curDate = new Date(); } while(curDate-date < 11); return value; } """) try: yield self.manager.run_map_reduce(mr, lambda m, l: None) except Exception, err: msg = str(err)[1:-1].decode("string-escape") if not all([ msg.startswith("Error running MapReduce operation."), msg.endswith("Body: '{\"error\":\"timeout\"}'")]): # This doesn't look like a timeout error, reraise it. raise else: self.fail("Map reduce operation did not timeout") @Manager.calls_manager def test_purge_all(self): dummy = self.mkdummy("foo", {"baz": 0}) yield self.manager.store(dummy) yield self.manager.purge_all() result = yield self.manager.load(DummyModel, dummy.key) self.assertEqual(result, None) @Manager.calls_manager def test_purge_all_clears_bucket_properties(self): search_enabled = yield self.manager.riak_search_enabled(DummyModel) self.assertEqual(search_enabled, False) yield self.manager.riak_enable_search(DummyModel) search_enabled = yield self.manager.riak_search_enabled(DummyModel) self.assertEqual(search_enabled, True) # We need at least one key in here so the bucket can be found and # purged. dummy = self.mkdummy("foo", {"baz": 0}) yield self.manager.store(dummy) yield self.manager.purge_all() search_enabled = yield self.manager.riak_search_enabled(DummyModel) self.assertEqual(search_enabled, False) @Manager.calls_manager def test_json_decoding(self): # Some versions of the riak client library use simplejson by # preference, which breaks some of our unicode assumptions. This test # only fails when such a version is being used and our workaround # fails. If we're using a good version of the client library, the test # will pass even if the workaround fails. dummy1 = self.mkdummy("foo", {"a": "b"}) result1 = yield self.manager.store(dummy1) self.assertTrue(isinstance(result1.get_data()["a"], unicode)) dummy2 = yield self.manager.load(DummyModel, "foo") self.assertEqual(dummy2.get_data(), {"a": "b"}) self.assertTrue(isinstance(dummy2.get_data()["a"], unicode)) @Manager.calls_manager def test_json_decoding_index_keys(self): # Some versions of the riak client library use simplejson by # preference, which breaks some of our unicode assumptions. This test # only fails when such a version is being used and our workaround # fails. If we're using a good version of the client library, the test # will pass even if the workaround fails. class MyDummy(DummyModel): # Use a fresh bucket name here so we don't get leftover keys. bucket = 'decoding_index_dummy' dummy1 = self.mkdummy("foo", {"a": "b"}, dummy_class=MyDummy) yield self.manager.store(dummy1) [key] = yield self.manager.index_keys( MyDummy, '$bucket', self.manager.bucket_name(MyDummy), None) self.assertEqual(key, u"foo") self.assertTrue(isinstance(key, unicode)) @Manager.calls_manager def test_error_when_closed(self): """ We get an exception if we try to use a closed manager. """ # Load a missing object while open, no exception. dummy = self.mkdummy("unknown") result = yield self.manager.load(DummyModel, dummy.key) self.assertEqual(result, None) # Load a missing object while closed. 
yield self.manager.close_manager() try: yield self.manager.load(DummyModel, dummy.key) except VumiRiakError, err: self.assertEqual(err.args[0], "Can't use closed Riak client.") else: self.fail( "Expected VumiRiakError using closed manager, nothing raised.") class TestTxRiakManager(CommonRiakManagerTests, VumiTestCase): @inlineCallbacks def setUp(self): try: from vumi.persist.txriak_manager import TxRiakManager except ImportError, e: import_skip(e, 'riak', 'riak') self.manager = TxRiakManager.from_config({'bucket_prefix': 'test.'}) self.add_cleanup(self.manager.purge_all) yield self.manager.purge_all() def test_call_decorator(self): self.assertEqual(type(self.manager).call_decorator, inlineCallbacks) def test_transport_class_protocol_buffer(self): manager_class = type(self.manager) manager = manager_class.from_config({ 'transport_type': 'pbc', 'bucket_prefix': 'test.', }) self.assertEqual(manager.client.protocol, 'pbc') def test_transport_class_http(self): manager_class = type(self.manager) manager = manager_class.from_config({ 'transport_type': 'http', 'bucket_prefix': 'test.', }) self.assertEqual(manager.client.protocol, 'http') def test_transport_class_default(self): manager_class = type(self.manager) manager = manager_class.from_config({ 'bucket_prefix': 'test.', }) self.assertEqual(manager.client.protocol, 'http') PK,\cH)i (vumi/persist/tests/test_redis_manager.py"""Tests for vumi.persist.redis_manager.""" from vumi.tests.helpers import VumiTestCase, import_skip class TestRedisManager(VumiTestCase): def setUp(self): try: from vumi.persist.redis_manager import RedisManager except ImportError, e: import_skip(e, 'redis') self.manager = RedisManager.from_config( {'FAKE_REDIS': 'yes', 'key_prefix': 'redistest'}) self.add_cleanup(self.cleanup_manager) self.manager._purge_all() def cleanup_manager(self): self.manager._purge_all() self.manager._close() def test_key_unkey(self): self.assertEqual('redistest:foo', self.manager._key('foo')) self.assertEqual('foo', self.manager._unkey('redistest:foo')) self.assertEqual('redistest:redistest:foo', self.manager._key('redistest:foo')) self.assertEqual('redistest:foo', self.manager._unkey('redistest:redistest:foo')) def test_set_get_keys(self): self.assertEqual([], self.manager.keys()) self.assertEqual(None, self.manager.get('foo')) self.manager.set('foo', 'bar') self.assertEqual(['foo'], self.manager.keys()) self.assertEqual('bar', self.manager.get('foo')) self.manager.set('foo', 'baz') self.assertEqual(['foo'], self.manager.keys()) self.assertEqual('baz', self.manager.get('foo')) def test_disconnect_twice(self): self.manager._close() self.manager._close() def test_scan(self): self.assertEqual([], self.manager.keys()) for i in range(10): self.manager.set('key%d' % i, 'value%d' % i) all_keys = set() cursor = None for i in range(20): # loop enough times to have gone through all the keys in our test # redis instance but not forever so we can assert on the value of # cursor if we get stuck. Also prevents hanging the whole test # suite (since this test doesn't yield to the reactor). 
cursor, keys = self.manager.scan(cursor) all_keys.update(keys) if cursor is None: break self.assertEqual(cursor, None) self.assertEqual(all_keys, set( 'key%d' % i for i in range(10))) def test_ttl(self): missing_ttl = self.manager.ttl("missing_key") self.assertEqual(missing_ttl, None) self.manager.set("key-no-ttl", "value") no_ttl = self.manager.ttl("key-no-ttl") self.assertEqual(no_ttl, None) self.manager.setex("key-ttl", 30, "value") ttl = self.manager.ttl("key-ttl") self.assertTrue(10 <= ttl <= 30) PK=JG`99 vumi/resources/amqp-spec-0-8.xml AMQ Protocol 0.80 Indicates that the method completed successfully. This reply code is reserved for future use - the current protocol design does not use positive confirmation and reply codes are sent only in case of an error. The client asked for a specific message that is no longer available. The message was delivered to another client, or was purged from the queue for some other reason. The client attempted to transfer content larger than the server could accept at the present time. The client may retry at a later time. An operator intervened to close the connection for some reason. The client may retry at some later date. The client tried to work with an unknown virtual host or cluster. The client attempted to work with a server entity to which it has no access due to security settings. The client attempted to work with a server entity that does not exist. The client attempted to work with a server entity to which it has no access because another client is working with it. The client sent a malformed frame that the server could not decode. This strongly implies a programming error in the client. The client sent a frame that contained illegal values for one or more fields. This strongly implies a programming error in the client. The client sent an invalid sequence of frames, attempting to perform an operation that was considered invalid by the server. This usually implies a programming error in the client. The client attempted to work with a channel that had not been correctly opened. This most likely indicates a fault in the client layer. The server could not complete the method because it lacked sufficient resources. This may be due to the client creating too many of some type of entity. The client tried to work with some entity in a manner that is prohibited by the server, due to security settings or by some other criteria. The client tried to use functionality that is not implemented in the server. The server could not complete the method because of an internal error. The server may require intervention by an operator in order to resume normal operations. access ticket granted by server An access ticket granted by the server for a certain set of access rights within a specific realm. Access tickets are valid within the channel where they were created, and expire when the channel closes. consumer tag Identifier for the consumer, valid within the current connection. The consumer tag is valid only within the channel from which the consumer was created. I.e. a client MUST NOT create a consumer in one channel and then use it in another. server-assigned delivery tag The server-assigned and channel-specific delivery tag. The delivery tag is valid only within the channel from which the message was received. I.e. a client MUST NOT receive a message on one channel and then acknowledge it on another. The server MUST NOT use a zero value for delivery tags. Zero is reserved for client use, meaning "all messages so far received".
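To make the acknowledgement rule concrete, a small Python sketch (a hypothetical helper, not part of vumi or the spec):

    def tags_to_ack(unacked_tags, delivery_tag):
        # Delivery tag zero is reserved for clients and means
        # "all messages so far received"; anything else acks one delivery.
        if delivery_tag == 0:
            return sorted(unacked_tags)
        return [delivery_tag]

    assert tags_to_ack(set([1, 2, 3]), 0) == [1, 2, 3]
    assert tags_to_ack(set([1, 2, 3]), 2) == [2]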
exchange name The exchange name is a client-selected string that identifies the exchange for publish methods. Exchange names may consist of any mixture of digits, letters, and underscores. Exchange names are scoped by the virtual host. list of known hosts Specifies the list of equivalent or alternative hosts that the server knows about, which will normally include the current server itself. Clients can cache this information and use it when reconnecting to a server after a failure. The server MAY leave this field empty if it knows of no other hosts than itself. no acknowledgement needed If this field is set the server does not expect acknowledgments for messages. That is, when a message is delivered to the client the server automatically and silently acknowledges it on behalf of the client. This functionality increases performance but at the cost of reliability. Messages can get lost if a client dies before it can deliver them to the application. do not deliver own messages If the no-local field is set the server will not send messages to the client that published them. Must start with a slash "/" and continue with path names separated by slashes. A path name consists of any combination of at least one of [A-Za-z0-9] plus zero or more of [.-_+!=:]. This string provides a set of peer properties, used for identification, debugging, and general information. The properties SHOULD contain these fields: "product", giving the name of the peer product, "version", giving the name of the peer version, "platform", giving the name of the operating system, "copyright", if appropriate, and "information", giving other general information. queue name The queue name identifies the queue within the vhost. Queue names may consist of any mixture of digits, letters, and underscores. message is being redelivered This indicates that the message has been previously delivered to this or another client. The server SHOULD try to signal redelivered messages when it can. When redelivering a message that was not successfully acknowledged, the server SHOULD deliver it to the original client if possible. The client MUST NOT rely on the redelivered field but MUST take it as a hint that the message may already have been processed. A fully robust client must be able to track duplicate received messages on non-transacted, and locally-transacted channels. reply code from server The reply code. The AMQ reply codes are defined in AMQ RFC 011. localised reply text The localised reply text. This text can be logged as an aid to resolving issues. work with socket connections The connection class provides methods for a client to establish a network connection to a server, and for both peers to operate the connection thereafter. connection = open-connection *use-connection close-connection open-connection = C:protocol-header S:START C:START-OK *challenge S:TUNE C:TUNE-OK C:OPEN S:OPEN-OK | S:REDIRECT challenge = S:SECURE C:SECURE-OK use-connection = *channel close-connection = C:CLOSE S:CLOSE-OK / S:CLOSE C:CLOSE-OK start connection negotiation This method starts the connection negotiation process by telling the client the protocol version that the server proposes, along with a list of security mechanisms which the client can use for authentication. If the client cannot handle the protocol version suggested by the server it MUST close the socket connection. The server MUST provide a protocol version that is lower than or equal to that requested by the client in the protocol header. 
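The version rule above is easy to state in code; a hedged sketch (hypothetical helper, with versions as (major, minor) tuples):

    def acceptable_server_version(server_version, client_version):
        # The server may only propose a protocol version lower than or
        # equal to the one in the client's protocol header.
        return server_version <= client_version

    assert acceptable_server_version((0, 8), (0, 9))
    assert not acceptable_server_version((0, 9), (0, 8))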
If the server cannot support the specified protocol it MUST NOT send this method, but MUST close the socket connection. protocol major version The protocol major version that the server agrees to use, which cannot be higher than the client's major version. protocol minor version The protocol minor version that the server agrees to use, which cannot be higher than the client's minor version. server properties available security mechanisms A list of the security mechanisms that the server supports, delimited by spaces. Currently ASL supports these mechanisms: PLAIN. available message locales A list of the message locales that the server supports, delimited by spaces. The locale defines the language in which the server will send reply texts. All servers MUST support at least the en_US locale. select security mechanism and locale This method selects a SASL security mechanism. ASL uses SASL (RFC2222) to negotiate authentication and encryption. client properties selected security mechanism A single security mechanism selected by the client, which must be one of those specified by the server. The client SHOULD authenticate using the highest-level security profile it can handle from the list provided by the server. The mechanism field MUST contain one of the security mechanisms proposed by the server in the Start method. If it doesn't, the server MUST close the socket. security response data A block of opaque data passed to the security mechanism. The contents of this data are defined by the SASL security mechanism. For the PLAIN security mechanism this is defined as a field table holding two fields, LOGIN and PASSWORD. selected message locale A single message locale selected by the client, which must be one of those specified by the server. security mechanism challenge The SASL protocol works by exchanging challenges and responses until both peers have received sufficient information to authenticate each other. This method challenges the client to provide more information. security challenge data Challenge information, a block of opaque binary data passed to the security mechanism. security mechanism response This method attempts to authenticate, passing a block of SASL data for the security mechanism at the server side. security response data A block of opaque data passed to the security mechanism. The contents of this data are defined by the SASL security mechanism. propose connection tuning parameters This method proposes a set of connection configuration values to the client. The client can accept and/or adjust these. proposed maximum channels The maximum total number of channels that the server allows per connection. Zero means that the server does not impose a fixed limit, but the number of allowed channels may be limited by available server resources. proposed maximum frame size The largest frame size that the server proposes for the connection. The client can negotiate a lower value. Zero means that the server does not impose any specific limit but may reject very large frames if it cannot allocate resources for them. Until the frame-max has been negotiated, both peers MUST accept frames of up to 4096 octets large. The minimum non-zero value for the frame-max field is 4096. desired heartbeat delay The delay, in seconds, of the connection heartbeat that the server wants. Zero means the server does not want a heartbeat. negotiate connection tuning parameters This method sends the client's connection tuning parameters to the server.
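The Tune/Tune-Ok exchange amounts to both peers settling on mutually acceptable limits. A sketch under the stated rules (hypothetical helper; zero means "no specific limit", and a non-zero frame-max may never drop below 4096):

    def negotiate_limit(server_value, client_value, minimum=0):
        # Zero on either side means that peer imposes no limit.
        if server_value == 0:
            agreed = client_value
        elif client_value == 0:
            agreed = server_value
        else:
            agreed = min(server_value, client_value)
        if agreed and agreed < minimum:
            raise ValueError("below protocol minimum: %d" % (minimum,))
        return agreed

    assert negotiate_limit(131072, 65536, minimum=4096) == 65536  # frame-max
    assert negotiate_limit(0, 16) == 16                           # channel-max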
Certain fields are negotiated, others provide capability information. negotiated maximum channels The maximum total number of channels that the client will use per connection. May not be higher than the value specified by the server. The server MAY ignore the channel-max value or MAY use it for tuning its resource allocation. negotiated maximum frame size The largest frame size that the client and server will use for the connection. Zero means that the client does not impose any specific limit but may reject very large frames if it cannot allocate resources for them. Note that the frame-max limit applies principally to content frames, where large contents can be broken into frames of arbitrary size. Until the frame-max has been negotiated, both peers must accept frames of up to 4096 octets large. The minimum non-zero value for the frame-max field is 4096. desired heartbeat delay The delay, in seconds, of the connection heartbeat that the client wants. Zero means the client does not want a heartbeat. open connection to virtual host This method opens a connection to a virtual host, which is a collection of resources, and acts to separate multiple application domains within a server. The client MUST open the context before doing any work on the connection. virtual host name The name of the virtual host to work with. If the server supports multiple virtual hosts, it MUST enforce a full separation of exchanges, queues, and all associated entities per virtual host. An application, connected to a specific virtual host, MUST NOT be able to access resources of another virtual host. The server SHOULD verify that the client has permission to access the specified virtual host. The server MAY configure arbitrary limits per virtual host, such as the number of each type of entity that may be used, per connection and/or in total. required capabilities The client may specify a number of capability names, delimited by spaces. The server can use this string to decide how to process the client's connection request. insist on connecting to server In a configuration with multiple load-sharing servers, the server may respond to a Connection.Open method with a Connection.Redirect. The insist option tells the server that the client is insisting on a connection to the specified server. When the client uses the insist option, the server SHOULD accept the client connection unless it is technically unable to do so. signal that the connection is ready This method signals to the client that the connection is ready for use. asks the client to use a different server This method redirects the client to another server, based on the requested virtual host and/or capabilities. When getting the Connection.Redirect method, the client SHOULD reconnect to the host specified, and if that host is not present, to any of the hosts specified in the known-hosts list. server to connect to Specifies the server to connect to. This is an IP address or a DNS name, optionally followed by a colon and a port number. If no port number is specified, the client should use the default port number for the protocol. request a connection close This method indicates that the sender wants to close the connection. This may be due to internal conditions (e.g. a forced shut-down) or due to an error handling a specific method, i.e. an exception. When a close is due to an exception, the sender provides the class and method id of the method which caused the exception. After sending this method any received method except the Close-OK method MUST be discarded.
The peer sending this method MAY use a counter or timeout to detect failure of the other peer to respond correctly with the Close-OK method. When a server receives the Close method from a client it MUST delete all server-side resources associated with the client's context. A client CANNOT reconnect to a context after sending or receiving a Close method. failing method class When the close is provoked by a method exception, this is the class of the method. failing method ID When the close is provoked by a method exception, this is the ID of the method. confirm a connection close This method confirms a Connection.Close method and tells the recipient that it is safe to release resources for the connection and close the socket. A peer that detects a socket closure without having received a Close-Ok handshake method SHOULD log the error. work with channels The channel class provides methods for a client to establish a virtual connection - a channel - to a server and for both peers to operate the virtual connection thereafter. channel = open-channel *use-channel close-channel open-channel = C:OPEN S:OPEN-OK use-channel = C:FLOW S:FLOW-OK / S:FLOW C:FLOW-OK / S:ALERT / functional-class close-channel = C:CLOSE S:CLOSE-OK / S:CLOSE C:CLOSE-OK open a channel for use This method opens a virtual connection (a channel). This method MUST NOT be called when the channel is already open. out-of-band settings Configures out-of-band transfers on this channel. The syntax and meaning of this field will be formally defined at a later date. signal that the channel is ready This method signals to the client that the channel is ready for use. enable/disable flow from peer This method asks the peer to pause or restart the flow of content data. This is a simple flow-control mechanism that a peer can use to avoid overflowing its queues or otherwise finding itself receiving more messages than it can process. Note that this method is not intended for window control. The peer that receives a request to stop sending content should finish sending the current content, if any, and then wait until it receives a Flow restart method. When a new channel is opened, it is active. Some applications assume that channels are inactive until started. To emulate this behaviour a client MAY open the channel, then pause it. When sending content data in multiple frames, a peer SHOULD monitor the channel for incoming methods and respond to a Channel.Flow as rapidly as possible. A peer MAY use the Channel.Flow method to throttle incoming content data for internal reasons, for example, when exchanging data over a slower connection. The peer that requests a Channel.Flow method MAY disconnect and/or ban a peer that does not respect the request. start/stop content frames If 1, the peer starts sending content frames. If 0, the peer stops sending content frames. confirm a flow method Confirms to the peer that a flow command was received and processed. current flow setting Confirms the setting of the processed flow method: 1 means the peer will start sending or continue to send content frames; 0 means it will not. send a non-fatal warning message This method allows the server to send a non-fatal warning to the client. This is used for methods that are normally asynchronous and thus do not have confirmations, and for which the server may detect errors that need to be reported. Fatal errors are handled as channel or connection exceptions; non-fatal errors are sent through this method.
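The Channel.Flow handshake described above can be modelled as a toggle that gates content frames (an illustrative toy, not transport code):

    class ChannelFlow(object):
        def __init__(self):
            self.active = True  # a newly opened channel is active

        def handle_flow(self, active):
            # The peer asks us to pause (0) or restart (1) content frames;
            # Flow-Ok confirms the setting now in effect.
            self.active = bool(active)
            return {'method': 'channel.flow-ok', 'active': self.active}

    ch = ChannelFlow()
    ch.handle_flow(0)
    assert ch.active is False  # stop sending content frames until restarted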
detailed information for warning A set of fields that provide more information about the problem. The meaning of these fields is defined on a per-reply-code basis (TO BE DEFINED). request a channel close This method indicates that the sender wants to close the channel. This may be due to internal conditions (e.g. a forced shut-down) or due to an error handling a specific method, i.e. an exception. When a close is due to an exception, the sender provides the class and method id of the method which caused the exception. After sending this method any received method except Channel.Close-OK MUST be discarded. The peer sending this method MAY use a counter or timeout to detect failure of the other peer to respond correctly with Channel.Close-OK. failing method class When the close is provoked by a method exception, this is the class of the method. failing method ID When the close is provoked by a method exception, this is the ID of the method. confirm a channel close This method confirms a Channel.Close method and tells the recipient that it is safe to release resources for the channel and close the socket. A peer that detects a socket closure without having received a Channel.Close-Ok handshake method SHOULD log the error. work with access tickets The protocol controls access to server resources using access tickets. A client must explicitly request access tickets before doing work. An access ticket grants a client the right to use a specific set of resources - called a "realm" - in specific ways. access = C:REQUEST S:REQUEST-OK request an access ticket This method requests an access ticket for an access realm. The server responds by granting the access ticket. If the client does not have access rights to the requested realm this causes a connection exception. Access tickets are a per-channel resource. The realm name MUST start with either "/data" (for application resources) or "/admin" (for server administration resources). If the realm starts with any other path, the server MUST raise a connection exception with reply code 403 (access refused). The server MUST implement the /data realm and MAY implement the /admin realm. The mapping of resources to realms is not defined in the protocol - this is a server-side configuration issue. name of requested realm If the specified realm is not known to the server, the server must raise a channel exception with reply code 402 (invalid path). request exclusive access Request exclusive access to the realm. If the server cannot grant this - because there are other active tickets for the realm - it raises a channel exception. request passive access Request passive access to the specified access realm. Passive access lets a client get information about resources in the realm but not make any changes to them. request active access Request active access to the specified access realm. Active access lets a client create and delete resources in the realm. request write access Request write access to the specified access realm. Write access lets a client publish messages to all exchanges in the realm. request read access Request read access to the specified access realm. Read access lets a client consume messages from queues in the realm. grant access to server resources This method provides the client with an access ticket. The access ticket is valid within the current channel and for the lifespan of the channel. The client MUST NOT use access tickets except within the same channel as originally granted.
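The realm-naming rule is mechanical enough to sketch (hypothetical helper; 403 is the "access refused" reply code defined earlier):

    def check_realm_name(realm):
        # Realms must live under /data (application) or /admin (server
        # administration); anything else is a connection exception.
        if realm.startswith("/data") or realm.startswith("/admin"):
            return None
        return (403, "access refused")

    assert check_realm_name("/data/myapp") is None
    assert check_realm_name("/other") == (403, "access refused")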
The server MUST isolate access tickets per channel and treat an attempt by a client to mix these as a connection exception. work with exchanges Exchanges match and distribute messages across queues. Exchanges can be configured in the server or created at runtime. exchange = C:DECLARE S:DECLARE-OK / C:DELETE S:DELETE-OK amq_exchange_19 The server MUST implement the direct and fanout exchange types, and predeclare the corresponding exchanges named amq.direct and amq.fanout in each virtual host. The server MUST also predeclare a direct exchange to act as the default exchange for content Publish methods and for default queue bindings. amq_exchange_20 The server SHOULD implement the topic exchange type, and predeclare the corresponding exchange named amq.topic in each virtual host. amq_exchange_21 The server MAY implement the system exchange type, and predeclare the corresponding exchanges named amq.system in each virtual host. If the client attempts to bind a queue to the system exchange, the server MUST raise a connection exception with reply code 507 (not allowed). amq_exchange_22 The default exchange MUST be defined as internal, and be inaccessible to the client except by specifying an empty exchange name in a content Publish method. That is, the server MUST NOT let clients make explicit bindings to this exchange. declare exchange, create if needed This method creates an exchange if it does not already exist, and if the exchange exists, verifies that it is of the correct and expected class. amq_exchange_23 The server SHOULD support a minimum of 16 exchanges per virtual host and ideally, impose no limit except as defined by available resources. When a client defines a new exchange, this belongs to the access realm of the ticket used. All further work done with that exchange must be done with an access ticket for the same realm. The client MUST provide a valid access ticket giving "active" access to the realm in which the exchange exists or will be created, or "passive" access if the if-exists flag is set. amq_exchange_15 Exchange names starting with "amq." are reserved for predeclared and standardised exchanges. If the client attempts to create an exchange starting with "amq.", the server MUST raise a channel exception with reply code 403 (access refused). exchange type Each exchange belongs to one of a set of exchange types implemented by the server. The exchange types define the functionality of the exchange - i.e. how messages are routed through it. It is not valid or meaningful to attempt to change the type of an existing exchange. amq_exchange_16 If the exchange already exists with a different type, the server MUST raise a connection exception with a reply code 507 (not allowed). amq_exchange_18 If the server does not support the requested exchange type it MUST raise a connection exception with a reply code 503 (command invalid). do not create exchange If set, the server will not create the exchange. The client can use this to check whether an exchange exists without modifying the server state. amq_exchange_05 If set, and the exchange does not already exist, the server MUST raise a channel exception with reply code 404 (not found). request a durable exchange If set when creating a new exchange, the exchange will be marked as durable. Durable exchanges remain active when a server restarts. Non-durable exchanges (transient exchanges) are purged if/when a server restarts. amq_exchange_24 The server MUST support both durable and transient exchanges. 
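The "amq." reservation above translates directly into a declare-time check (hypothetical helper; passive declares merely test for existence and may name predeclared exchanges):

    def check_exchange_declare(name, passive=False):
        # amq_exchange_15: creating an exchange named "amq.*" is refused
        # with reply code 403; passively checking one is fine.
        if name.startswith("amq.") and not passive:
            return (403, "access refused")
        return None

    assert check_exchange_declare("amq.direct") == (403, "access refused")
    assert check_exchange_declare("amq.direct", passive=True) is None
    assert check_exchange_declare("my_exchange") is None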
The server MUST ignore the durable field if the exchange already exists. auto-delete when unused If set, the exchange is deleted when all queues have finished using it. amq_exchange_02 The server SHOULD allow for a reasonable delay between the point when it determines that an exchange is not being used (or no longer used), and the point when it deletes the exchange. At the least it must allow a client to create an exchange and then bind a queue to it, with a small but non-zero delay between these two actions. amq_exchange_25 The server MUST ignore the auto-delete field if the exchange already exists. create internal exchange If set, the exchange may not be used directly by publishers, but only when bound to other exchanges. Internal exchanges are used to construct wiring that is not visible to applications. do not send a reply method If set, the server will not respond to the method. The client should not wait for a reply method. If the server could not complete the method it will raise a channel or connection exception. arguments for declaration A set of arguments for the declaration. The syntax and semantics of these arguments depends on the server implementation. This field is ignored if passive is 1. confirms an exchange declaration This method confirms a Declare method and confirms the name of the exchange, essential for automatically-named exchanges. delete an exchange This method deletes an exchange. When an exchange is deleted all queue bindings on the exchange are cancelled. The client MUST provide a valid access ticket giving "active" access rights to the exchange's access realm. amq_exchange_11 The exchange MUST exist. Attempting to delete a non-existing exchange causes a channel exception. delete only if unused If set, the server will only delete the exchange if it has no queue bindings. If the exchange has queue bindings the server does not delete it but raises a channel exception instead. amq_exchange_12 If set, the server SHOULD delete the exchange but only if it has no queue bindings. amq_exchange_13 If set, the server SHOULD raise a channel exception if the exchange is in use. do not send a reply method If set, the server will not respond to the method. The client should not wait for a reply method. If the server could not complete the method it will raise a channel or connection exception. confirm deletion of an exchange This method confirms the deletion of an exchange.
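A minimal client-side sketch of the Exchange.Declare and Exchange.Delete methods above, using the third-party pika library (which speaks AMQP 0-9-1, where access tickets no longer exist); the exchange name "orders" is illustrative only.

import pika

connection = pika.BlockingConnection(pika.ConnectionParameters('localhost'))
channel = connection.channel()

# Declare (create if needed) a durable direct exchange.
channel.exchange_declare(exchange='orders', exchange_type='direct',
                         durable=True, auto_delete=False)

# Passive declare: check that the exchange exists without modifying
# server state; the broker closes the channel with 404 if it does not.
channel.exchange_declare(exchange='orders', passive=True)

# Delete the exchange, but only if it has no queue bindings.
channel.exchange_delete(exchange='orders', if_unused=True)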
work with queues Queues store and forward messages. Queues can be configured in the server or created at runtime. Queues must be attached to at least one exchange in order to receive messages from publishers. queue = C:DECLARE S:DECLARE-OK / C:BIND S:BIND-OK / C:PURGE S:PURGE-OK / C:DELETE S:DELETE-OK amq_queue_33 A server MUST allow any content class to be sent to any queue, in any mix, and queue and deliver these content classes independently. Note that all methods that fetch content off queues are specific to a given content class. declare queue, create if needed This method creates or checks a queue. When creating a new queue the client can specify various properties that control the durability of the queue and its contents, and the level of sharing for the queue. amq_queue_34 The server MUST create a default binding for a newly-created queue to the default exchange, which is an exchange of type 'direct'. amq_queue_35 The server SHOULD support a minimum of 256 queues per virtual host and ideally, impose no limit except as defined by available resources. When a client defines a new queue, this belongs to the access realm of the ticket used. All further work done with that queue must be done with an access ticket for the same realm. The client provides a valid access ticket giving "active" access to the realm in which the queue exists or will be created, or "passive" access if the if-exists flag is set. amq_queue_10 The queue name MAY be empty, in which case the server MUST create a new queue with a unique generated name and return this to the client in the Declare-Ok method. amq_queue_32 Queue names starting with "amq." are reserved for predeclared and standardised server queues. If the queue name starts with "amq." and the passive option is zero, the server MUST raise a connection exception with reply code 403 (access refused). do not create queue If set, the server will not create the queue. The client can use this to check whether a queue exists without modifying the server state. amq_queue_05 If set, and the queue does not already exist, the server MUST respond with a reply code 404 (not found) and raise a channel exception. request a durable queue If set when creating a new queue, the queue will be marked as durable. Durable queues remain active when a server restarts. Non-durable queues (transient queues) are purged if/when a server restarts. Note that durable queues do not necessarily hold persistent messages, although it does not make sense to send persistent messages to a transient queue. amq_queue_03 The server MUST recreate the durable queue after a restart. amq_queue_36 The server MUST support both durable and transient queues. amq_queue_37 The server MUST ignore the durable field if the queue already exists. request an exclusive queue Exclusive queues may only be consumed from by the current connection. Setting the 'exclusive' flag always implies 'auto-delete'. amq_queue_38 The server MUST support both exclusive (private) and non-exclusive (shared) queues. amq_queue_04 The server MUST raise a channel exception if 'exclusive' is specified and the queue already exists and is owned by a different connection. auto-delete queue when unused If set, the queue is deleted when all consumers have finished using it. The last consumer can be cancelled either explicitly or because its channel is closed. If there was no consumer ever on the queue, it won't be deleted. amq_queue_02 The server SHOULD allow for a reasonable delay between the point when it determines that a queue is not being used (or no longer used), and the point when it deletes the queue. At the least it must allow a client to create a queue and then create a consumer to read from it, with a small but non-zero delay between these two actions. The server should equally allow for clients that may be disconnected prematurely, and wish to re-consume from the same queue without losing messages. We would recommend a configurable timeout, with a suitable default value being one minute. amq_queue_31 The server MUST ignore the auto-delete field if the queue already exists. do not send a reply method If set, the server will not respond to the method. The client should not wait for a reply method. If the server could not complete the method it will raise a channel or connection exception. arguments for declaration A set of arguments for the declaration. The syntax and semantics of these arguments depends on the server implementation. This field is ignored if passive is 1.
confirms a queue definition This method confirms a Declare method and confirms the name of the queue, essential for automatically-named queues. Reports the name of the queue. If the server generated a queue name, this field contains that name. number of messages in queue Reports the number of messages in the queue, which will be zero for newly-created queues. number of consumers Reports the number of active consumers for the queue. Note that consumers can suspend activity (Channel.Flow) in which case they do not appear in this count. bind queue to an exchange This method binds a queue to an exchange. Until a queue is bound it will not receive any messages. In a classic messaging model, store-and-forward queues are bound to a dest exchange and subscription queues are bound to a dest_wild exchange. amq_queue_25 A server MUST allow duplicate bindings - that is, two or more bind methods for a specific queue, with identical arguments - without treating these as an error. amq_queue_39 If a bind fails, the server MUST raise a connection exception. amq_queue_12 The server MUST NOT allow a durable queue to bind to a transient exchange. If the client attempts this the server MUST raise a channel exception. amq_queue_13 Bindings for durable queues are automatically durable and the server SHOULD restore such bindings after a server restart. amq_queue_17 If the client attempts to bind to an exchange that was declared as internal, the server MUST raise a connection exception with reply code 530 (not allowed). amq_queue_40 The server SHOULD support at least 4 bindings per queue, and ideally, impose no limit except as defined by available resources. The client provides a valid access ticket giving "active" access rights to the queue's access realm. Specifies the name of the queue to bind. If the queue name is empty, refers to the current queue for the channel, which is the last declared queue. If the client did not previously declare a queue, and the queue name in this method is empty, the server MUST raise a connection exception with reply code 530 (not allowed). If the queue does not exist the server MUST raise a channel exception with reply code 404 (not found). The name of the exchange to bind to. amq_queue_14 If the exchange does not exist the server MUST raise a channel exception with reply code 404 (not found). message routing key Specifies the routing key for the binding. The routing key is used for routing messages depending on the exchange configuration. Not all exchanges use a routing key - refer to the specific exchange documentation. If the routing key is empty and the queue name is empty, the routing key will be the current queue for the channel, which is the last declared queue. do not send a reply method If set, the server will not respond to the method. The client should not wait for a reply method. If the server could not complete the method it will raise a channel or connection exception. arguments for binding A set of arguments for the binding. The syntax and semantics of these arguments depends on the exchange class. confirm bind successful This method confirms that the bind was successful.
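A minimal sketch of the Queue.Declare and Queue.Bind flow above, continuing the pika example (an open channel named channel is assumed; queue and key names are illustrative).

# Durable, shared queue that survives broker restarts; Declare-Ok carries
# the queue name and the message and consumer counts.
declare_ok = channel.queue_declare(queue='orders', durable=True,
                                   exclusive=False, auto_delete=False)
print(declare_ok.method.message_count, declare_ok.method.consumer_count)

# Server-named exclusive queue: pass an empty name, read the generated
# name back from Declare-Ok.
tmp = channel.queue_declare(queue='', exclusive=True)
tmp_name = tmp.method.queue

# Bind the queue to a direct exchange under a routing key; the server
# answers with Bind-Ok.
channel.queue_bind(queue='orders', exchange='amq.direct',
                   routing_key='orders')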
purge a queue This method removes all messages from a queue. It does not cancel consumers. Purged messages are deleted without any formal "undo" mechanism. amq_queue_15 A call to purge MUST result in an empty queue. amq_queue_41 On transacted channels the server MUST NOT purge messages that have already been sent to a client but not yet acknowledged. amq_queue_42 The server MAY implement a purge queue or log that allows system administrators to recover accidentally-purged messages. The server SHOULD NOT keep purged messages in the same storage spaces as the live messages since the volumes of purged messages may get very large. The access ticket must be for the access realm that holds the queue. The client MUST provide a valid access ticket giving "read" access rights to the queue's access realm. Note that purging a queue is equivalent to reading all messages and discarding them. Specifies the name of the queue to purge. If the queue name is empty, refers to the current queue for the channel, which is the last declared queue. If the client did not previously declare a queue, and the queue name in this method is empty, the server MUST raise a connection exception with reply code 530 (not allowed). The queue must exist. Attempting to purge a non-existing queue causes a channel exception. do not send a reply method If set, the server will not respond to the method. The client should not wait for a reply method. If the server could not complete the method it will raise a channel or connection exception. confirms a queue purge This method confirms the purge of a queue. number of messages purged Reports the number of messages purged. delete a queue This method deletes a queue. When a queue is deleted any pending messages are sent to a dead-letter queue if this is defined in the server configuration, and all consumers on the queue are cancelled. amq_queue_43 The server SHOULD use a dead-letter queue to hold messages that were pending on a deleted queue, and MAY provide facilities for a system administrator to move these messages back to an active queue. The client provides a valid access ticket giving "active" access rights to the queue's access realm. Specifies the name of the queue to delete. If the queue name is empty, refers to the current queue for the channel, which is the last declared queue. If the client did not previously declare a queue, and the queue name in this method is empty, the server MUST raise a connection exception with reply code 530 (not allowed). The queue must exist. Attempting to delete a non-existing queue causes a channel exception. delete only if unused If set, the server will only delete the queue if it has no consumers. If the queue has consumers the server does not delete it but raises a channel exception instead. amq_queue_29 amq_queue_30 The server MUST respect the if-unused flag when deleting a queue. delete only if empty amq_queue_27 If set, the server will only delete the queue if it has no messages. If the queue is not empty the server raises a channel exception. do not send a reply method If set, the server will not respond to the method. The client should not wait for a reply method. If the server could not complete the method it will raise a channel or connection exception. confirm deletion of a queue This method confirms the deletion of a queue. number of messages purged Reports the number of messages purged.
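Continuing the pika sketch, Queue.Purge and Queue.Delete map directly onto channel methods; the Purge-Ok and Delete-Ok message counts are exposed on the returned frames.

# Purge all messages; Purge-Ok reports how many were removed.
purge_ok = channel.queue_purge(queue='orders')
print('purged:', purge_ok.method.message_count)

# Delete, but refuse if consumers or messages remain (if-unused/if-empty).
channel.queue_delete(queue='orders', if_unused=True, if_empty=True)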
work with basic content The Basic class provides methods that support an industry-standard messaging model. basic = C:QOS S:QOS-OK / C:CONSUME S:CONSUME-OK / C:CANCEL S:CANCEL-OK / C:PUBLISH content / S:RETURN content / S:DELIVER content / C:GET ( S:GET-OK content / S:GET-EMPTY ) / C:ACK / C:REJECT The server SHOULD respect the persistent property of basic messages and SHOULD make a best-effort to hold persistent basic messages on a reliable storage mechanism. The server MUST NOT discard a persistent basic message in case of a queue overflow. The server MAY use the Channel.Flow method to slow or stop a basic message publisher when necessary. The server MAY overflow non-persistent basic messages to persistent storage and MAY discard or dead-letter non-persistent basic messages on a priority basis if the queue size exceeds some configured limit. The server MUST implement at least 2 priority levels for basic messages, where priorities 0-4 and 5-9 are treated as two distinct levels. The server MAY implement up to 10 priority levels. The server MUST deliver messages of the same priority in order irrespective of their individual persistence. The server MUST support both automatic and explicit acknowledgements on Basic content. MIME content type MIME content encoding Message header field table Non-persistent (1) or persistent (2) The message priority, 0 to 9 The application correlation identifier The destination to reply to Message expiration specification The application message identifier The message timestamp The message type name The creating user id The creating application id Intra-cluster routing identifier specify quality of service This method requests a specific quality of service. The QoS can be specified for the current channel or for all channels on the connection. The particular properties and semantics of a qos method always depend on the content class semantics. Though the qos method could in principle apply to both peers, it is currently meaningful only for the server. prefetch window in octets The client can request that messages be sent in advance so that when the client finishes processing a message, the following message is already held locally, rather than needing to be sent down the channel. Prefetching gives a performance improvement. This field specifies the prefetch window size in octets. The server will send a message in advance if it is equal to or smaller in size than the available prefetch size (and also falls into other prefetch limits). May be set to zero, meaning "no specific limit", although other prefetch limits may still apply. The prefetch-size is ignored if the no-ack option is set. The server MUST ignore this setting when the client is not processing any messages - i.e. the prefetch size does not limit the transfer of single messages to a client, only the sending in advance of more messages while the client still has one or more unacknowledged messages. prefetch window in messages Specifies a prefetch window in terms of whole messages. This field may be used in combination with the prefetch-size field; a message will only be sent in advance if both prefetch windows (and those at the channel and connection level) allow it. The prefetch-count is ignored if the no-ack option is set. The server MAY send less data in advance than allowed by the client's specified prefetch windows but it MUST NOT send more. apply to entire connection By default the QoS settings apply to the current channel only. If this field is set, they are applied to the entire connection. confirm the requested qos This method tells the client that the requested QoS levels could be handled by the server. The requested QoS applies to all active consumers until a new QoS is defined.
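In pika, Basic.Qos is a single channel call; the sketch below continues the earlier example (the global_qos naming is pika 1.x's spelling of the "apply to entire connection" field).

# At most 10 unacknowledged messages sent ahead of the consumer;
# prefetch_size=0 means "no specific byte limit".
channel.basic_qos(prefetch_size=0, prefetch_count=10, global_qos=False)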
start a queue consumer This method asks the server to start a "consumer", which is a transient request for messages from a specific queue. Consumers last as long as the channel they were created on, or until the client cancels them. The server SHOULD support at least 16 consumers per queue, unless the queue was declared as private, and ideally, impose no limit except as defined by available resources. The client MUST provide a valid access ticket giving "read" access rights to the realm for the queue. Specifies the name of the queue to consume from. If the queue name is null, refers to the current queue for the channel, which is the last declared queue. If the client did not previously declare a queue, and the queue name in this method is empty, the server MUST raise a connection exception with reply code 530 (not allowed). Specifies the identifier for the consumer. The consumer tag is local to a connection, so two clients can use the same consumer tags. If this field is empty the server will generate a unique tag. The tag MUST NOT refer to an existing consumer. If the client attempts to create two consumers with the same non-empty tag the server MUST raise a connection exception with reply code 530 (not allowed). request exclusive access Request exclusive consumer access, meaning only this consumer can access the queue. If the server cannot grant exclusive access to the queue when asked - because there are other consumers active - it MUST raise a channel exception with return code 403 (access refused). do not send a reply method If set, the server will not respond to the method. The client should not wait for a reply method. If the server could not complete the method it will raise a channel or connection exception. confirm a new consumer The server provides the client with a consumer tag, which is used by the client for methods called on the consumer at a later stage. Holds the consumer tag specified by the client or provided by the server. end a queue consumer This method cancels a consumer. This does not affect already delivered messages, but it does mean the server will not send any more messages for that consumer. The client may receive an arbitrary number of messages in between sending the cancel method and receiving the cancel-ok reply. If the queue no longer exists when the client sends a cancel command, or the consumer has been cancelled for other reasons, this command has no effect. do not send a reply method If set, the server will not respond to the method. The client should not wait for a reply method. If the server could not complete the method it will raise a channel or connection exception. confirm a cancelled consumer This method confirms that the cancellation was completed.
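A sketch of the Basic.Consume and Basic.Cancel pair above, again with pika (open channel assumed); pika returns the Consume-Ok consumer tag from basic_consume.

# Each delivery arrives with its consumer tag and a channel-scoped
# delivery tag; acknowledge explicitly since auto_ack is off.
def on_message(ch, method, properties, body):
    print(method.consumer_tag, method.delivery_tag, body)
    ch.basic_ack(delivery_tag=method.delivery_tag)

tag = channel.basic_consume(queue='orders',
                            on_message_callback=on_message,
                            exclusive=False, auto_ack=False)
# ... later: cancel; in-flight deliveries may still arrive until Cancel-Ok.
channel.basic_cancel(consumer_tag=tag)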
publish a message This method publishes a message to a specific exchange. The message will be routed to queues as defined by the exchange configuration and distributed to any active consumers when the transaction, if any, is committed. The client MUST provide a valid access ticket giving "write" access rights to the access realm for the exchange. Specifies the name of the exchange to publish to. The exchange name can be empty, meaning the default exchange. If the exchange name is specified, and that exchange does not exist, the server will raise a channel exception. The server MUST accept a blank exchange name to mean the default exchange. If the exchange was declared as an internal exchange, the server MUST raise a channel exception with a reply code 403 (access refused). The exchange MAY refuse basic content in which case it MUST raise a channel exception with reply code 540 (not implemented). Message routing key Specifies the routing key for the message. The routing key is used for routing messages depending on the exchange configuration. indicate mandatory routing This flag tells the server how to react if the message cannot be routed to a queue. If this flag is set, the server will return an unroutable message with a Return method. If this flag is zero, the server silently drops the message. The server SHOULD implement the mandatory flag. request immediate delivery This flag tells the server how to react if the message cannot be routed to a queue consumer immediately. If this flag is set, the server will return an undeliverable message with a Return method. If this flag is zero, the server will queue the message, but with no guarantee that it will ever be consumed. The server SHOULD implement the immediate flag. return a failed message This method returns an undeliverable message that was published with the "immediate" flag set, or an unroutable message published with the "mandatory" flag set. The reply code and text provide information about the reason that the message was undeliverable. Specifies the name of the exchange that the message was originally published to. Message routing key Specifies the routing key name specified when the message was published. notify the client of a consumer message This method delivers a message to the client, via a consumer. In the asynchronous message delivery model, the client starts a consumer using the Consume method, then the server responds with Deliver methods as and when messages arrive for that consumer. The server SHOULD track the number of times a message has been delivered to clients and when a message is redelivered a certain number of times - e.g. 5 times - without being acknowledged, the server SHOULD consider the message to be unprocessable (possibly causing client applications to abort), and move the message to a dead letter queue. Specifies the name of the exchange that the message was originally published to. Message routing key Specifies the routing key name specified when the message was published. direct access to a queue This method provides direct access to the messages in a queue using a synchronous dialogue that is designed for specific types of application where synchronous functionality is more important than performance. The client MUST provide a valid access ticket giving "read" access rights to the realm for the queue. Specifies the name of the queue to consume from. If the queue name is null, refers to the current queue for the channel, which is the last declared queue. If the client did not previously declare a queue, and the queue name in this method is empty, the server MUST raise a connection exception with reply code 530 (not allowed). provide client with a message This method delivers a message to the client following a get method. A message delivered by 'get-ok' must be acknowledged unless the no-ack option was set in the get method. Specifies the name of the exchange that the message was originally published to. If empty, the message was published to the default exchange. Message routing key Specifies the routing key name specified when the message was published. number of messages pending This field reports the number of messages pending on the queue, excluding the message being delivered. Note that this figure is indicative, not reliable, and can change arbitrarily as messages are added to the queue and removed by other clients. indicate no messages available This method tells the client that the queue has no messages available for the client. Cluster id For use by cluster applications, should not be used by client applications.
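The publish flags and the synchronous Basic.Get dialogue above look like this in pika (open channel assumed); pika signals Get-Empty by returning a tuple of None values.

import pika

# Mandatory publish: an unroutable message comes back via Basic.Return
# instead of being silently dropped. delivery_mode=2 marks it persistent.
channel.basic_publish(exchange='amq.direct', routing_key='orders',
                      body=b'hello',
                      properties=pika.BasicProperties(delivery_mode=2),
                      mandatory=True)

# Synchronous Basic.Get; (None, None, None) corresponds to Get-Empty.
method, header, body = channel.basic_get(queue='orders', auto_ack=False)
if method is not None:
    print('pending after this one:', method.message_count)  # indicative only
    channel.basic_ack(delivery_tag=method.delivery_tag)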
acknowledge one or more messages This method acknowledges one or more messages delivered via the Deliver or Get-Ok methods. The client can ask to confirm a single message or a set of messages up to and including a specific message. acknowledge multiple messages If set to 1, the delivery tag is treated as "up to and including", so that the client can acknowledge multiple messages with a single method. If set to zero, the delivery tag refers to a single message. If the multiple field is 1, and the delivery tag is zero, tells the server to acknowledge all outstanding messages. The server MUST validate that a non-zero delivery-tag refers to a delivered message, and raise a channel exception if this is not the case. reject an incoming message This method allows a client to reject a message. It can be used to interrupt and cancel large incoming messages, or return untreatable messages to their original queue. The server SHOULD be capable of accepting and processing the Reject method while sending message content with a Deliver or Get-Ok method. I.e. the server should read and process incoming methods while sending output frames. To cancel a partially-sent content, the server sends a content body frame of size 1 (i.e. with no data except the frame-end octet). The server SHOULD interpret this method as meaning that the client is unable to process the message at this time. A client MUST NOT use this method as a means of selecting messages to process. A rejected message MAY be discarded or dead-lettered, not necessarily passed to another client. requeue the message If this field is zero, the message will be discarded. If this bit is 1, the server will attempt to requeue the message. The server MUST NOT deliver the message to the same client within the context of the current channel. The recommended strategy is to attempt to deliver the message to an alternative consumer, and if that is not possible, to move the message to a dead-letter queue. The server MAY use more sophisticated tracking to hold the message on the queue and redeliver it to the same client at a later stage. redeliver unacknowledged messages. This method is only allowed on non-transacted channels. This method asks the broker to redeliver all unacknowledged messages on a specified channel. Zero or more messages may be redelivered. requeue the message If this field is zero, the message will be redelivered to the original recipient. If this bit is 1, the server will attempt to requeue the message, potentially then delivering it to an alternative subscriber. The server MUST set the redelivered flag on all messages that are resent. The server MUST raise a channel exception if this is called on a transacted channel.
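The acknowledgement methods above map onto three pika channel calls (open channel assumed; the method frame is taken from a prior delivery as in the consume sketch earlier).

# Acknowledge everything up to and including a given delivery tag.
channel.basic_ack(delivery_tag=method.delivery_tag, multiple=True)

# Reject one message and ask the server to requeue it; note the server
# MUST NOT redeliver it to this client on the same channel.
channel.basic_reject(delivery_tag=method.delivery_tag, requeue=True)

# Redeliver all unacknowledged messages on this (non-transacted) channel.
channel.basic_recover(requeue=True)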
work with file content The file class provides methods that support reliable file transfer. File messages have a specific set of properties that are required for interoperability with file transfer applications. File messages and acknowledgements are subject to channel transactions. Note that the file class does not provide message browsing methods; these are not compatible with the staging model. Applications that need browsable file transfer should use Basic content and the Basic class. file = C:QOS S:QOS-OK / C:CONSUME S:CONSUME-OK / C:CANCEL S:CANCEL-OK / C:OPEN S:OPEN-OK C:STAGE content / S:OPEN C:OPEN-OK S:STAGE content / C:PUBLISH / S:DELIVER / S:RETURN / C:ACK / C:REJECT The server MUST make a best-effort to hold file messages on a reliable storage mechanism. The server MUST NOT discard a file message in case of a queue overflow. The server MUST use the Channel.Flow method to slow or stop a file message publisher when necessary. The server MUST implement at least 2 priority levels for file messages, where priorities 0-4 and 5-9 are treated as two distinct levels. The server MAY implement up to 10 priority levels. The server MUST support both automatic and explicit acknowledgements on file content. MIME content type MIME content encoding Message header field table The message priority, 0 to 9 The destination to reply to The application message identifier The message filename The message timestamp Intra-cluster routing identifier specify quality of service This method requests a specific quality of service. The QoS can be specified for the current channel or for all channels on the connection. The particular properties and semantics of a qos method always depend on the content class semantics. Though the qos method could in principle apply to both peers, it is currently meaningful only for the server. prefetch window in octets The client can request that messages be sent in advance so that when the client finishes processing a message, the following message is already held locally, rather than needing to be sent down the channel. Prefetching gives a performance improvement. This field specifies the prefetch window size in octets. May be set to zero, meaning "no specific limit". Note that other prefetch limits may still apply. The prefetch-size is ignored if the no-ack option is set. prefetch window in messages Specifies a prefetch window in terms of whole messages. This is compatible with some file API implementations. This field may be used in combination with the prefetch-size field; a message will only be sent in advance if both prefetch windows (and those at the channel and connection level) allow it. The prefetch-count is ignored if the no-ack option is set. The server MAY send less data in advance than allowed by the client's specified prefetch windows but it MUST NOT send more. apply to entire connection By default the QoS settings apply to the current channel only. If this field is set, they are applied to the entire connection. confirm the requested qos This method tells the client that the requested QoS levels could be handled by the server. The requested QoS applies to all active consumers until a new QoS is defined. start a queue consumer This method asks the server to start a "consumer", which is a transient request for messages from a specific queue. Consumers last as long as the channel they were created on, or until the client cancels them. The server SHOULD support at least 16 consumers per queue, unless the queue was declared as private, and ideally, impose no limit except as defined by available resources. The client MUST provide a valid access ticket giving "read" access rights to the realm for the queue. Specifies the name of the queue to consume from. If the queue name is null, refers to the current queue for the channel, which is the last declared queue. If the client did not previously declare a queue, and the queue name in this method is empty, the server MUST raise a connection exception with reply code 530 (not allowed).
Specifies the identifier for the consumer. The consumer tag is local to a connection, so two clients can use the same consumer tags. If this field is empty the server will generate a unique tag. The tag MUST NOT refer to an existing consumer. If the client attempts to create two consumers with the same non-empty tag the server MUST raise a connection exception with reply code 530 (not allowed). request exclusive access Request exclusive consumer access, meaning only this consumer can access the queue. If the server cannot grant exclusive access to the queue when asked - because there are other consumers active - it MUST raise a channel exception with return code 405 (resource locked). do not send a reply method If set, the server will not respond to the method. The client should not wait for a reply method. If the server could not complete the method it will raise a channel or connection exception. confirm a new consumer This method provides the client with a consumer tag which it MUST use in methods that work with the consumer. Holds the consumer tag specified by the client or provided by the server. end a queue consumer This method cancels a consumer. This does not affect already delivered messages, but it does mean the server will not send any more messages for that consumer. do not send a reply method If set, the server will not respond to the method. The client should not wait for a reply method. If the server could not complete the method it will raise a channel or connection exception. confirm a cancelled consumer This method confirms that the cancellation was completed. request to start staging This method requests permission to start staging a message. Staging means sending the message into a temporary area at the recipient end and then delivering the message by referring to this temporary area. Staging is how the protocol handles partial file transfers - if a message is partially staged and the connection breaks, the next time the sender starts to stage it, it can restart from where it left off. staging identifier This is the staging identifier. This is an arbitrary string chosen by the sender. For staging to work correctly the sender must use the same staging identifier when staging the same message a second time after recovery from a failure. A good choice for the staging identifier would be the SHA1 hash of the message properties data (including the original filename, revised time, etc.). message content size The size of the content in octets. The recipient may use this information to allocate or check available space in advance, to avoid "disk full" errors during staging of very large messages. The sender MUST accurately fill the content-size field. Zero-length content is permitted. confirm staging ready This method confirms that the recipient is ready to accept staged data. If the message was already partially-staged at a previous time the recipient will report the number of octets already staged. already staged amount The amount of previously-staged content in octets. For a new message this will be zero. The sender MUST start sending data from this octet offset in the message, counting from zero. The recipient MAY decide how long to hold partially-staged content and MAY implement staging by always discarding partially-staged content. However if it uses the file content type it MUST support the staging methods. stage message content This method stages the message, sending the message content to the recipient from the octet offset specified in the Open-Ok method.
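A hypothetical sketch of the resumable staging handshake (File.Open, Open-Ok, File.Stage); the file class was dropped after AMQP 0-8, so the channel object and its file_open/file_stage methods below mirror the spec rather than any real library.

import hashlib

def stage_file(channel, payload, properties):
    # A stable staging identifier, so a retry after a broken connection
    # resumes rather than restarts (the spec suggests a SHA1 hash of the
    # message properties).
    identifier = hashlib.sha1(repr(properties).encode()).hexdigest()
    open_ok = channel.file_open(identifier=identifier,
                                content_size=len(payload))
    # Open-Ok reports how many octets the recipient already holds;
    # resume sending from that zero-based offset.
    channel.file_stage(content=payload[open_ok.staged_size:])
    return identifier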
publish a message This method publishes a staged file message to a specific exchange. The file message will be routed to queues as defined by the exchange configuration and distributed to any active consumers when the transaction, if any, is committed. The client MUST provide a valid access ticket giving "write" access rights to the access realm for the exchange. Specifies the name of the exchange to publish to. The exchange name can be empty, meaning the default exchange. If the exchange name is specified, and that exchange does not exist, the server will raise a channel exception. The server MUST accept a blank exchange name to mean the default exchange. If the exchange was declared as an internal exchange, the server MUST respond with a reply code 403 (access refused) and raise a channel exception. The exchange MAY refuse file content in which case it MUST respond with a reply code 540 (not implemented) and raise a channel exception. Message routing key Specifies the routing key for the message. The routing key is used for routing messages depending on the exchange configuration. indicate mandatory routing This flag tells the server how to react if the message cannot be routed to a queue. If this flag is set, the server will return an unroutable message with a Return method. If this flag is zero, the server silently drops the message. The server SHOULD implement the mandatory flag. request immediate delivery This flag tells the server how to react if the message cannot be routed to a queue consumer immediately. If this flag is set, the server will return an undeliverable message with a Return method. If this flag is zero, the server will queue the message, but with no guarantee that it will ever be consumed. The server SHOULD implement the immediate flag. staging identifier This is the staging identifier of the message to publish. The message must have been staged. Note that a client can send the Publish method asynchronously without waiting for staging to finish. return a failed message This method returns an undeliverable message that was published with the "immediate" flag set, or an unroutable message published with the "mandatory" flag set. The reply code and text provide information about the reason that the message was undeliverable. Specifies the name of the exchange that the message was originally published to. Message routing key Specifies the routing key name specified when the message was published. notify the client of a consumer message This method delivers a staged file message to the client, via a consumer. In the asynchronous message delivery model, the client starts a consumer using the Consume method, then the server responds with Deliver methods as and when messages arrive for that consumer. The server SHOULD track the number of times a message has been delivered to clients and when a message is redelivered a certain number of times - e.g. 5 times - without being acknowledged, the server SHOULD consider the message to be unprocessable (possibly causing client applications to abort), and move the message to a dead letter queue. Specifies the name of the exchange that the message was originally published to. Message routing key Specifies the routing key name specified when the message was published. staging identifier This is the staging identifier of the message to deliver. The message must have been staged. Note that a server can send the Deliver method asynchronously without waiting for staging to finish. 
acknowledge one or more messages This method acknowledges one or more messages delivered via the Deliver method. The client can ask to confirm a single message or a set of messages up to and including a specific message. acknowledge multiple messages If set to 1, the delivery tag is treated as "up to and including", so that the client can acknowledge multiple messages with a single method. If set to zero, the delivery tag refers to a single message. If the multiple field is 1, and the delivery tag is zero, tells the server to acknowledge all outstanding messages. The server MUST validate that a non-zero delivery-tag refers to a delivered message, and raise a channel exception if this is not the case. reject an incoming message This method allows a client to reject a message. It can be used to return untreatable messages to their original queue. Note that file content is staged before delivery, so the client will not use this method to interrupt delivery of a large message. The server SHOULD interpret this method as meaning that the client is unable to process the message at this time. A client MUST NOT use this method as a means of selecting messages to process. A rejected message MAY be discarded or dead-lettered, not necessarily passed to another client. requeue the message If this field is zero, the message will be discarded. If this bit is 1, the server will attempt to requeue the message. The server MUST NOT deliver the message to the same client within the context of the current channel. The recommended strategy is to attempt to deliver the message to an alternative consumer, and if that is not possible, to move the message to a dead-letter queue. The server MAY use more sophisticated tracking to hold the message on the queue and redeliver it to the same client at a later stage. work with streaming content The stream class provides methods that support multimedia streaming. The stream class uses the following semantics: one message is one packet of data; delivery is unacknowledged and unreliable; the consumer can specify quality of service parameters that the server can try to adhere to; lower-priority messages may be discarded in favour of high priority messages. stream = C:QOS S:QOS-OK / C:CONSUME S:CONSUME-OK / C:CANCEL S:CANCEL-OK / C:PUBLISH content / S:RETURN / S:DELIVER content The server SHOULD discard stream messages on a priority basis if the queue size exceeds some configured limit. The server MUST implement at least 2 priority levels for stream messages, where priorities 0-4 and 5-9 are treated as two distinct levels. The server MAY implement up to 10 priority levels. The server MUST implement automatic acknowledgements on stream content. That is, as soon as a message is delivered to a client via a Deliver method, the server must remove it from the queue. MIME content type MIME content encoding Message header field table The message priority, 0 to 9 The message timestamp specify quality of service This method requests a specific quality of service. The QoS can be specified for the current channel or for all channels on the connection. The particular properties and semantics of a qos method always depend on the content class semantics. Though the qos method could in principle apply to both peers, it is currently meaningful only for the server. prefetch window in octets The client can request that messages be sent in advance so that when the client finishes processing a message, the following message is already held locally, rather than needing to be sent down the channel.
Prefetching gives a performance improvement. This field specifies the prefetch window size in octets. May be set to zero, meaning "no specific limit". Note that other prefetch limits may still apply. prefetch window in messages Specifies a prefetch window in terms of whole messages. This field may be used in combination with the prefetch-size field; a message will only be sent in advance if both prefetch windows (and those at the channel and connection level) allow it. transfer rate in octets/second Specifies a desired transfer rate in octets per second. This is usually determined by the application that uses the streaming data. A value of zero means "no limit", i.e. as rapidly as possible. The server MAY ignore the prefetch values and consume rates, depending on the type of stream and the ability of the server to queue and/or replay it. The server MAY drop low-priority messages in favour of high-priority messages. apply to entire connection By default the QoS settings apply to the current channel only. If this field is set, they are applied to the entire connection. confirm the requested qos This method tells the client that the requested QoS levels could be handled by the server. The requested QoS applies to all active consumers until a new QoS is defined. start a queue consumer This method asks the server to start a "consumer", which is a transient request for messages from a specific queue. Consumers last as long as the channel they were created on, or until the client cancels them. The server SHOULD support at least 16 consumers per queue, unless the queue was declared as private, and ideally, impose no limit except as defined by available resources. Streaming applications SHOULD use different channels to select different streaming resolutions. AMQP makes no provision for filtering and/or transforming streams except on the basis of priority-based selective delivery of individual messages. The client MUST provide a valid access ticket giving "read" access rights to the realm for the queue. Specifies the name of the queue to consume from. If the queue name is null, refers to the current queue for the channel, which is the last declared queue. If the client did not previously declare a queue, and the queue name in this method is empty, the server MUST raise a connection exception with reply code 530 (not allowed). Specifies the identifier for the consumer. The consumer tag is local to a connection, so two clients can use the same consumer tags. If this field is empty the server will generate a unique tag. The tag MUST NOT refer to an existing consumer. If the client attempts to create two consumers with the same non-empty tag the server MUST raise a connection exception with reply code 530 (not allowed). request exclusive access Request exclusive consumer access, meaning only this consumer can access the queue. If the server cannot grant exclusive access to the queue when asked - because there are other consumers active - it MUST raise a channel exception with return code 405 (resource locked). do not send a reply method If set, the server will not respond to the method. The client should not wait for a reply method. If the server could not complete the method it will raise a channel or connection exception. confirm a new consumer This method provides the client with a consumer tag which it may use in methods that work with the consumer. Holds the consumer tag specified by the client or provided by the server.
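As a rough illustration of the Stream.Qos method described above: the stream class was dropped after AMQP 0-8, so the channel object and stream_qos method below are hypothetical and mirror the spec's fields rather than any real library.

# Cap delivery at roughly 256 kB/s with one message of read-ahead; the
# server MAY ignore these hints and MAY drop low-priority messages.
channel.stream_qos(prefetch_size=0,      # no byte limit on the window
                   prefetch_count=1,     # one whole message in advance
                   consume_rate=262144,  # desired octets/second; 0 = no limit
                   global_=False)        # apply to this channel only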
end a queue consumer This method cancels a consumer. Since message delivery is asynchronous the client may continue to receive messages for a short while after cancelling a consumer. It may process or discard these as appropriate. do not send a reply method If set, the server will not respond to the method. The client should not wait for a reply method. If the server could not complete the method it will raise a channel or connection exception. confirm a cancelled consumer This method confirms that the cancellation was completed. publish a message This method publishes a message to a specific exchange. The message will be routed to queues as defined by the exchange configuration and distributed to any active consumers as appropriate. The client MUST provide a valid access ticket giving "write" access rights to the access realm for the exchange. Specifies the name of the exchange to publish to. The exchange name can be empty, meaning the default exchange. If the exchange name is specified, and that exchange does not exist, the server will raise a channel exception. The server MUST accept a blank exchange name to mean the default exchange. If the exchange was declared as an internal exchange, the server MUST respond with a reply code 403 (access refused) and raise a channel exception. The exchange MAY refuse stream content in which case it MUST respond with a reply code 540 (not implemented) and raise a channel exception. Message routing key Specifies the routing key for the message. The routing key is used for routing messages depending on the exchange configuration. indicate mandatory routing This flag tells the server how to react if the message cannot be routed to a queue. If this flag is set, the server will return an unroutable message with a Return method. If this flag is zero, the server silently drops the message. The server SHOULD implement the mandatory flag. request immediate delivery This flag tells the server how to react if the message cannot be routed to a queue consumer immediately. If this flag is set, the server will return an undeliverable message with a Return method. If this flag is zero, the server will queue the message, but with no guarantee that it will ever be consumed. The server SHOULD implement the immediate flag. return a failed message This method returns an undeliverable message that was published with the "immediate" flag set, or an unroutable message published with the "mandatory" flag set. The reply code and text provide information about the reason that the message was undeliverable. Specifies the name of the exchange that the message was originally published to. Message routing key Specifies the routing key name specified when the message was published. notify the client of a consumer message This method delivers a message to the client, via a consumer. In the asynchronous message delivery model, the client starts a consumer using the Consume method, then the server responds with Deliver methods as and when messages arrive for that consumer. Specifies the name of the exchange that the message was originally published to. Specifies the name of the queue that the message came from. Note that a single channel can start many consumers on different queues. work with standard transactions Standard transactions provide so-called "1.5 phase commit". We can ensure that work is never lost, but there is a chance of confirmations being lost, so that messages may be resent. Applications that use standard transactions must be able to detect and ignore duplicate messages.
A client using standard transactions SHOULD be able to track all messages received within a reasonable period, and thus detect and reject duplicates of the same message. It SHOULD NOT pass these to the application layer. tx = C:SELECT S:SELECT-OK / C:COMMIT S:COMMIT-OK / C:ROLLBACK S:ROLLBACK-OK select standard transaction mode This method sets the channel to use standard transactions. The client must use this method at least once on a channel before using the Commit or Rollback methods. confirm transaction mode This method confirms to the client that the channel was successfully set to use standard transactions. commit the current transaction This method commits all messages published and acknowledged in the current transaction. A new transaction starts immediately after a commit. confirm a successful commit This method confirms to the client that the commit succeeded. Note that if a commit fails, the server raises a channel exception. abandon the current transaction This method abandons all messages published and acknowledged in the current transaction. A new transaction starts immediately after a rollback. confirm a successful rollback This method confirms to the client that the rollback succeeded. Note that if a rollback fails, the server raises a channel exception. work with distributed transactions Distributed transactions provide so-called "2-phase commit". The AMQP distributed transaction model supports the X-Open XA architecture and other distributed transaction implementations. The Dtx class assumes that the server has a private communications channel (not AMQP) to a distributed transaction coordinator. dtx = C:SELECT S:SELECT-OK C:START S:START-OK select distributed transaction mode This method sets the channel to use distributed transactions. The client must use this method at least once on a channel before using the Start method. confirm transaction mode This method confirms to the client that the channel was successfully set to use distributed transactions. start a new distributed transaction This method starts a new distributed transaction. This must be the first method on a new channel that uses the distributed transaction mode, before any methods that publish or consume messages. transaction identifier The distributed transaction key. This identifies the transaction so that the AMQP server can coordinate with the distributed transaction coordinator. confirm the start of a new distributed transaction This method confirms to the client that the transaction started. Note that if a start fails, the server raises a channel exception. methods for protocol tunnelling The tunnel methods are used to send blocks of binary data - which can be serialised AMQP methods or other protocol frames - between AMQP peers. tunnel = C:REQUEST / S:REQUEST Message header field table The identity of the tunnelling proxy The name or type of the message being tunnelled The message durability indicator The message broadcast mode sends a tunnelled method This method tunnels a block of binary data, which can be an encoded AMQP method or other data. The binary data is sent as the content for the Tunnel.Request method. meta data for the tunnelled block This field table holds arbitrary meta-data that the sender needs to pass to the recipient.
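The standard (Tx-class) transaction flow described above is a three-method dance in pika (open channel assumed); publishes and acknowledgements are batched until commit, and per the "1.5 phase commit" caveat a consumer must still de-duplicate if a Commit-Ok is lost.

channel.tx_select()                     # put the channel in tx mode once
try:
    channel.basic_publish(exchange='amq.direct', routing_key='orders',
                          body=b'order-42')
    channel.tx_commit()                 # a new transaction starts here
except Exception:
    channel.tx_rollback()               # abandon the current transaction
    raise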
test functional primitives of the implementation The test class provides methods for a peer to test the basic operational correctness of another peer. The test methods are intended to ensure that all peers respect at least the basic elements of the protocol, such as frame and content organisation and field types. We assume that a specially-designed peer, a "monitor client", would perform such tests. test = C:INTEGER S:INTEGER-OK / S:INTEGER C:INTEGER-OK / C:STRING S:STRING-OK / S:STRING C:STRING-OK / C:TABLE S:TABLE-OK / S:TABLE C:TABLE-OK / C:CONTENT S:CONTENT-OK / S:CONTENT C:CONTENT-OK test integer handling This method tests the peer's capability to correctly marshal integer data. octet test value An octet integer test value. short test value A short integer test value. long test value A long integer test value. long-long test value A long long integer test value. operation to test The client must execute this operation on the provided integer test fields and return the result. return sum of test values return lowest of test values return highest of test values report integer test result This method reports the result of an Integer method. result value The result of the tested operation. test string handling This method tests the peer's capability to correctly marshal string data. short string test value A short string test value. long string test value A long string test value. operation to test The client must execute this operation on the provided string test fields and return the result. return concatenation of test strings return shortest of test strings return longest of test strings report string test result This method reports the result of a String method. result value The result of the tested operation. test field table handling This method tests the peer's capability to correctly marshal field table data. field table of test values A field table of test values. operation to test on integers The client must execute this operation on the provided field table integer values and return the result. return sum of numeric field values return min of numeric field values return max of numeric field values operation to test on strings The client must execute this operation on the provided field table string values and return the result. return concatenation of string field values return shortest of string field values return longest of string field values report table test result This method reports the result of a Table method. integer result value The result of the tested integer operation. string result value The result of the tested string operation. test content handling This method tests the peer's capability to correctly marshal content. report content test result This method reports the result of a Content method. It contains the content checksum and echoes the original content as provided. content hash The 32-bit checksum of the content, calculated by adding the content into a 32-bit accumulator. PK=JG66vumi/resources/__init__.py"""Package for holding miscellaneous package data.""" PK=JGM.O.O vumi/resources/amqp-spec-0-9.xml Indicates that the method completed successfully. This reply code is reserved for future use - the current protocol design does not use positive confirmation and reply codes are sent only in case of an error. The client asked for a specific message that is no longer available. The message was delivered to another client, or was purged from the queue for some other reason. The client attempted to transfer content larger than the server could accept at the present time. The client may retry at a later time. When the exchange cannot route the result of a Publish method, most likely due to an invalid routing key.
Only when the mandatory flag is set. When the exchange cannot deliver to a consumer when the immediate flag is set. As a result of pending data on the queue or the absence of any consumers of the queue. An operator intervened to close the connection for some reason. The client may retry at some later date. The client tried to work with an unknown virtual host. The client attempted to work with a server entity to which it has no access due to security settings. The client attempted to work with a server entity that does not exist. The client attempted to work with a server entity to which it has no access because another client is working with it. The client requested a method that was not allowed because some precondition failed. The client sent a malformed frame that the server could not decode. This strongly implies a programming error in the client. The client sent a frame that contained illegal values for one or more fields. This strongly implies a programming error in the client. The client sent an invalid sequence of frames, attempting to perform an operation that was considered invalid by the server. This usually implies a programming error in the client. The client attempted to work with a channel that had not been correctly opened. This most likely indicates a fault in the client layer. The server could not complete the method because it lacked sufficient resources. This may be due to the client creating too many of some type of entity. The client tried to work with some entity in a manner that is prohibited by the server, due to security settings or by some other criteria. The client tried to use functionality that is not implemented in the server. The server could not complete the method because of an internal error. The server may require intervention by an operator in order to resume normal operations. An access ticket granted by the server for a certain set of access rights within a specific realm. Access tickets are valid within the channel where they were created, and expire when the channel closes. Identifier for the consumer, valid within the current connection. The server-assigned and channel-specific delivery tag The delivery tag is valid only within the channel from which the message was received. I.e. a client MUST NOT receive a message on one channel and then acknowledge it on another. The server MUST NOT use a zero value for delivery tags. Zero is reserved for client use, meaning "all messages so far received". The exchange name is a client-selected string that identifies the exchange for publish methods. Exchange names may consist of any mixture of digits, letters, and underscores. Exchange names are scoped by the virtual host. Specifies the list of equivalent or alternative hosts that the server knows about, which will normally include the current server itself. Clients can cache this information and use it when reconnecting to a server after a failure. This field may be empty. If this field is set the server does not expect acknowledgements for messages. That is, when a message is delivered to the client the server automatically and silently acknowledges it on behalf of the client. This functionality increases performance but at the cost of reliability. Messages can get lost if a client dies before it can deliver them to the application. If the no-local field is set the server will not send messages to the connection that published them. Must start with a slash "/" and continue with path names separated by slashes. 
A path name consists of any combination of at least one of [A-Za-z0-9] plus zero or more of [.-_+!=:]. This string provides a set of peer properties, used for identification, debugging, and general information. The queue name identifies the queue within the vhost. Queue names may consist of any mixture of digits, letters, and underscores. This indicates that the message has been previously delivered to this or another client. The server SHOULD try to signal redelivered messages when it can. When redelivering a message that was not successfully acknowledged, the server SHOULD deliver it to the original client if possible. Create a shared queue and publish a message to the queue. Consume the message using explicit acknowledgements, but do not acknowledge the message. Close the connection, reconnect, and consume from the queue again. The message should arrive with the redelivered flag set. The client MUST NOT rely on the redelivered field but should take it as a hint that the message may already have been processed. A fully robust client must be able to track duplicate received messages on non-transacted, and locally-transacted channels. The reply code. The AMQ reply codes are defined as constants at the start of this formal specification. The localised reply text. This text can be logged as an aid to resolving issues. Specifies the destination to which the message is to be transferred. The destination can be empty, meaning the default exchange or consumer. The reject code must be one of 0 (generic) or 1 (immediate delivery was attempted but failed). Used for authentication, replay prevention, and encrypted bodies. The connection class provides methods for a client to establish a network connection to a server, and for both peers to operate the connection thereafter. connection = open-connection *use-connection close-connection open-connection = C:protocol-header S:START C:START-OK *challenge S:TUNE C:TUNE-OK C:OPEN S:OPEN-OK | S:REDIRECT challenge = S:SECURE C:SECURE-OK use-connection = *channel close-connection = C:CLOSE S:CLOSE-OK / S:CLOSE C:CLOSE-OK This method starts the connection negotiation process by telling the client the protocol version that the server proposes, along with a list of security mechanisms which the client can use for authentication. If the server cannot support the protocol specified in the protocol header, it MUST close the socket connection without sending any response method. The client sends a protocol header containing an invalid protocol name. The server must respond by closing the connection. The server MUST provide a protocol version that is lower than or equal to that requested by the client in the protocol header. The client requests a protocol version that is higher than any valid implementation, e.g. 9.0. The server must respond with a current protocol version, e.g. 1.0. If the client cannot handle the protocol version suggested by the server it MUST close the socket connection. The server sends a protocol version that is lower than any valid implementation, e.g. 0.1. The client must respond by closing the connection. The protocol version, major component, as transmitted in the AMQP protocol header. This, combined with the protocol minor component fully describe the protocol version, which is written in the format major-minor. Hence, with major=1, minor=3, the protocol version would be "1-3". The protocol version, minor component, as transmitted in the AMQP protocol header. 
This, combined with the protocol major component fully describe the protocol version, which is written in the format major-minor. Hence, with major=1, minor=3, the protocol version would be "1-3". The properties SHOULD contain at least these fields: "host", specifying the server host name or address, "product", giving the name of the server product, "version", giving the name of the server version, "platform", giving the name of the operating system, "copyright", if appropriate, and "information", giving other general information. Client connects to server and inspects the server properties. It checks for the presence of the required fields. A list of the security mechanisms that the server supports, delimited by spaces. A list of the message locales that the server supports, delimited by spaces. The locale defines the language in which the server will send reply texts. The server MUST support at least the en_US locale. Client connects to server and inspects the locales field. It checks for the presence of the required locale(s). This method selects a SASL security mechanism. The properties SHOULD contain at least these fields: "product", giving the name of the client product, "version", giving the name of the client version, "platform", giving the name of the operating system, "copyright", if appropriate, and "information", giving other general information. A single security mechanisms selected by the client, which must be one of those specified by the server. The client SHOULD authenticate using the highest-level security profile it can handle from the list provided by the server. If the mechanism field does not contain one of the security mechanisms proposed by the server in the Start method, the server MUST close the connection without sending any further data. Client connects to server and sends an invalid security mechanism. The server must respond by closing the connection (a socket close, with no connection close negotiation). A block of opaque data passed to the security mechanism. The contents of this data are defined by the SASL security mechanism. A single message locale selected by the client, which must be one of those specified by the server. The SASL protocol works by exchanging challenges and responses until both peers have received sufficient information to authenticate each other. This method challenges the client to provide more information. Challenge information, a block of opaque binary data passed to the security mechanism. This method attempts to authenticate, passing a block of SASL data for the security mechanism at the server side. A block of opaque data passed to the security mechanism. The contents of this data are defined by the SASL security mechanism. This method proposes a set of connection configuration values to the client. The client can accept and/or adjust these. The maximum total number of channels that the server allows per connection. Zero means that the server does not impose a fixed limit, but the number of allowed channels may be limited by available server resources. The largest frame size that the server proposes for the connection. The client can negotiate a lower value. Zero means that the server does not impose any specific limit but may reject very large frames if it cannot allocate resources for them. Until the frame-max has been negotiated, both peers MUST accept frames of up to frame-min-size octets large, and the minimum negotiated value for frame-max is also frame-min-size. 
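As an illustration of the Start/Start-Ok and Secure/Secure-Ok exchange described above, the following plain-Python sketch (an editor's illustration, not part of the specification or this package) shows how a client might select a mechanism from the server's space-delimited list and build a PLAIN response as defined by RFC 4616; the preference order here is an assumption:

    def choose_mechanism(server_mechanisms, preferred=("EXTERNAL", "PLAIN")):
        # The server advertises a space-delimited list; the client must pick
        # one of those, or the server closes the connection without replying.
        offered = set(server_mechanisms.split())
        for mech in preferred:  # preference order is an assumption
            if mech in offered:
                return mech
        raise ValueError("no mutually supported SASL mechanism")

    def plain_response(username, password):
        # RFC 4616 PLAIN response: [authzid] NUL authcid NUL password.
        return b"\x00" + username.encode("utf-8") + b"\x00" + password.encode("utf-8")

    assert choose_mechanism("AMQPLAIN PLAIN") == "PLAIN"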
Client connects to server and sends a large properties field, creating a frame of frame-min-size octets. The server must accept this frame. The delay, in seconds, of the connection heartbeat that the server wants. Zero means the server does not want a heartbeat. This method sends the client's connection tuning parameters to the server. Certain fields are negotiated, others provide capability information. The maximum total number of channels that the client will use per connection. If the client specifies a channel max that is higher than the value provided by the server, the server MUST close the connection without attempting a negotiated close. The server may report the error in some fashion to assist implementors. The largest frame size that the client and server will use for the connection. Zero means that the client does not impose any specific limit but may reject very large frames if it cannot allocate resources for them. Note that the frame-max limit applies principally to content frames, where large contents can be broken into frames of arbitrary size. Until the frame-max has been negotiated, both peers MUST accept frames of up to frame-min-size octets large, and the minimum negotiated value for frame-max is also frame-min-size. If the client specifies a frame max that is higher than the value provided by the server, the server MUST close the connection without attempting a negotiated close. The server may report the error in some fashion to assist implementors. The delay, in seconds, of the connection heartbeat that the client wants. Zero means the client does not want a heartbeat. This method opens a connection to a virtual host, which is a collection of resources, and acts to separate multiple application domains within a server. The server may apply arbitrary limits per virtual host, such as the number of each type of entity that may be used, per connection and/or in total. The name of the virtual host to work with. If the server supports multiple virtual hosts, it MUST enforce a full separation of exchanges, queues, and all associated entities per virtual host. An application, connected to a specific virtual host, MUST NOT be able to access resources of another virtual host. The server SHOULD verify that the client has permission to access the specified virtual host. The client can specify zero or more capability names, delimited by spaces. The server can use this string to decide how to process the client's connection request. In a configuration with multiple collaborating servers, the server may respond to a Connection.Open method with a Connection.Redirect. The insist option tells the server that the client is insisting on a connection to the specified server. When the client uses the insist option, the server MUST NOT respond with a Connection.Redirect method. If it cannot accept the client's connection request it should respond by closing the connection with a suitable reply code. This method signals to the client that the connection is ready for use. This method redirects the client to another server, based on the requested virtual host and/or capabilities. When getting the Connection.Redirect method, the client SHOULD reconnect to the host specified, and if that host is not present, to any of the hosts specified in the known-hosts list. Specifies the server to connect to. This is an IP address or a DNS name, optionally followed by a colon and a port number. If no port number is specified, the client should use the default port number for the protocol.
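A minimal sketch of the tuning rules above, assuming the convention that zero means "no specific limit" and that the client must not reply with a value larger than the server proposed (treating heartbeat the same way is a simplification):

    def negotiate(client_wants, server_offers):
        # Zero means "no specific limit" on either side; otherwise the client
        # must not exceed the server's value, so the smaller value wins.
        if client_wants == 0:
            return server_offers
        if server_offers == 0:
            return client_wants
        return min(client_wants, server_offers)

    # Example: server proposes channel-max 2047, frame-max 131072, heartbeat 60.
    tune_ok = {
        "channel_max": negotiate(0, 2047),       # -> 2047
        "frame_max": negotiate(65536, 131072),   # -> 65536
        "heartbeat": negotiate(30, 60),          # -> 30
    }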
This method indicates that the sender wants to close the connection. This may be due to internal conditions (e.g. a forced shut-down) or due to an error handling a specific method, i.e. an exception. When a close is due to an exception, the sender provides the class and method id of the method which caused the exception. After sending this method any received method except the Close-OK method MUST be discarded. When the close is provoked by a method exception, this is the class of the method. When the close is provoked by a method exception, this is the ID of the method. This method confirms a Connection.Close method and tells the recipient that it is safe to release resources for the connection and close the socket. A peer that detects a socket closure without having received a Close-Ok handshake method SHOULD log the error. The channel class provides methods for a client to establish a channel to a server and for both peers to operate the channel thereafter. channel = open-channel *use-channel close-channel open-channel = C:OPEN S:OPEN-OK / C:RESUME S:OK use-channel = C:FLOW S:FLOW-OK / S:FLOW C:FLOW-OK / S:PING C:OK / C:PONG S:OK / C:PING S:OK / S:PONG C:OK / functional-class close-channel = C:CLOSE S:CLOSE-OK / S:CLOSE C:CLOSE-OK This method opens a channel to the server. The client MUST NOT use this method on an already-opened channel. Client opens a channel and then reopens the same channel. Configures out-of-band transfers on this channel. The syntax and meaning of this field will be formally defined at a later date. This method signals to the client that the channel is ready for use. This method asks the peer to pause or restart the flow of content data. This is a simple flow-control mechanism that a peer can use to avoid overflowing its queues or otherwise finding itself receiving more messages than it can process. Note that this method is not intended for window control. The peer that receives a disable flow method should finish sending the current content frame, if any, then pause. When a new channel is opened, it is active (flow is active). Some applications assume that channels are inactive until started. To emulate this behaviour a client MAY open the channel, then pause it. When sending content frames, a peer SHOULD monitor the channel for incoming methods and respond to a Channel.Flow as rapidly as possible. A peer MAY use the Channel.Flow method to throttle incoming content data for internal reasons, for example, when exchanging data over a slower connection. The peer that requests a Channel.Flow method MAY disconnect and/or ban a peer that does not respect the request. This is to prevent badly-behaved clients from overwhelming a broker. If 1, the peer starts sending content frames. If 0, the peer stops sending content frames. Confirms to the peer that a flow command was received and processed. Confirms the setting of the processed flow method: 1 means the peer will start sending or continue to send content frames; 0 means it will not. This method indicates that the sender wants to close the channel. This may be due to internal conditions (e.g. a forced shut-down) or due to an error handling a specific method, i.e. an exception. When a close is due to an exception, the sender provides the class and method id of the method which caused the exception. After sending this method any received method except the Close-OK method MUST be discarded. When the close is provoked by a method exception, this is the class of the method. 
When the close is provoked by a method exception, this is the ID of the method. This method confirms a Channel.Close method and tells the recipient that it is safe to release resources for the channel. A peer that detects a socket closure without having received a Channel.Close-Ok handshake method SHOULD log the error. This method resumes a previously interrupted channel. [WORK IN PROGRESS] Request that the recipient issue a pong request. [WORK IN PROGRESS] Issued after a ping request is received. Note that this is a request issued after receiving a ping, not a response to receiving a ping. [WORK IN PROGRESS] Signals normal completion of a method. The protocol controls access to server resources using access tickets. A client must explicitly request access tickets before doing work. An access ticket grants a client the right to use a specific set of resources - called a "realm" - in specific ways. access = C:REQUEST S:REQUEST-OK This method requests an access ticket for an access realm. The server responds by granting the access ticket. If the client does not have access rights to the requested realm this causes a connection exception. Access tickets are a per-channel resource. Specifies the name of the realm to which the client is requesting access. The realm is a configured server-side object that collects a set of resources (exchanges, queues, etc.). If the channel has already requested an access ticket onto this realm, the previous ticket is destroyed and a new ticket is created with the requested access rights, if allowed. The client MUST specify a realm that is known to the server. The server makes an identical response for undefined realms as it does for realms that are defined but inaccessible to this client. Client specifies an undefined realm. Request exclusive access to the realm, meaning that this will be the only channel that uses the realm's resources. The client MAY NOT request exclusive access to a realm that has active access tickets, unless the same channel already had the only access ticket onto that realm. Client opens two channels and requests exclusive access to the same realm. Request passive access to the specified access realm. Passive access lets a client get information about resources in the realm but not make any changes to them. Request active access to the specified access realm. Active access lets a client create and delete resources in the realm. Request write access to the specified access realm. Write access lets a client publish messages to all exchanges in the realm. Request read access to the specified access realm. Read access lets a client consume messages from queues in the realm. This method provides the client with an access ticket. The access ticket is valid within the current channel and for the lifespan of the channel. The client MUST NOT use access tickets except within the same channel as originally granted. Client opens two channels, requests a ticket on one channel, and then tries to use that ticket in a second channel. Exchanges match and distribute messages across queues. Exchanges can be configured in the server or created at runtime. exchange = C:DECLARE S:DECLARE-OK / C:DELETE S:DELETE-OK The server MUST implement these standard exchange types: fanout, direct. Client attempts to declare an exchange with each of these standard types. The server SHOULD implement these standard exchange types: topic, headers. Client attempts to declare an exchange with each of these standard types.
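To exercise the standard exchange types listed above, a test client could declare one exchange of each type. The sketch below uses the pika client library as an assumption (this package does not use pika, and parameter names follow pika 1.x):

    import pika

    connection = pika.BlockingConnection(pika.ConnectionParameters("localhost"))
    channel = connection.channel()
    # "fanout" and "direct" are mandatory server types; "topic" and
    # "headers" are only recommended, so a server may reject the latter two.
    for ex_type in ("fanout", "direct", "topic", "headers"):
        channel.exchange_declare(exchange="test." + ex_type,
                                 exchange_type=ex_type)
    connection.close()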
The server MUST, in each virtual host, pre-declare an exchange instance for each standard exchange type that it implements, where the name of the exchange instance, if defined, is "amq." followed by the exchange type name. The server MUST, in each virtual host, pre-declare at least two direct exchange instances: one named "amq.direct", the other with no public name that serves as a default exchange for Publish methods. Client creates a temporary queue and attempts to bind to each required exchange instance ("amq.fanout", "amq.direct", "amq.topic", and "amq.headers" if those types are defined). The server MUST pre-declare a direct exchange with no public name to act as the default exchange for content Publish methods and for default queue bindings. Client checks that the default exchange is active by specifying a queue binding with no exchange name, and publishing a message with a suitable routing key but without specifying the exchange name, then ensuring that the message arrives in the queue correctly. The server MUST NOT allow clients to access the default exchange except by specifying an empty exchange name in the Queue.Bind and content Publish methods. The server MAY implement other exchange types as wanted. This method creates an exchange if it does not already exist, and if the exchange exists, verifies that it is of the correct and expected class. The server SHOULD support a minimum of 16 exchanges per virtual host and ideally, impose no limit except as defined by available resources. The client creates as many exchanges as it can until the server reports an error; the number of exchanges successfully created must be at least sixteen. When a client defines a new exchange, this belongs to the access realm of the ticket used. All further work done with that exchange must be done with an access ticket for the same realm. The client MUST provide a valid access ticket giving "active" access to the realm in which the exchange exists or will be created, or "passive" access if the if-exists flag is set. Client creates access ticket with wrong access rights and attempts to use in this method. Exchange names starting with "amq." are reserved for pre-declared and standardised exchanges. The client MUST NOT attempt to create an exchange starting with "amq.". TODO. Each exchange belongs to one of a set of exchange types implemented by the server. The exchange types define the functionality of the exchange - i.e. how messages are routed through it. It is not valid or meaningful to attempt to change the type of an existing exchange. Exchanges cannot be redeclared with different types. The client MUST not attempt to redeclare an existing exchange with a different type than used in the original Exchange.Declare method. TODO. The client MUST NOT attempt to create an exchange with a type that the server does not support. TODO. If set, the server will not create the exchange. The client can use this to check whether an exchange exists without modifying the server state. If set, and the exchange does not already exist, the server MUST raise a channel exception with reply code 404 (not found). TODO. If set when creating a new exchange, the exchange will be marked as durable. Durable exchanges remain active when a server restarts. Non-durable exchanges (transient exchanges) are purged if/when a server restarts. The server MUST support both durable and transient exchanges. TODO. The server MUST ignore the durable field if the exchange already exists. TODO. 
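The default-exchange scenario described above (bind nothing, publish with the queue name as routing key, verify arrival) might look like this, again assuming pika:

    import pika

    connection = pika.BlockingConnection(pika.ConnectionParameters("localhost"))
    channel = connection.channel()
    # An empty queue name asks the server to generate a unique one.
    name = channel.queue_declare(queue="", exclusive=True).method.queue
    # The empty exchange name selects the pre-declared default exchange,
    # which routes on the queue name.
    channel.basic_publish(exchange="", routing_key=name, body=b"probe")
    method, properties, body = channel.basic_get(queue=name, auto_ack=True)
    assert body == b"probe"  # a busy broker may need a short retry loop here
    connection.close()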
If set, the exchange is deleted when all queues have finished using it. The server MUST ignore the auto-delete field if the exchange already exists. TODO. If set, the exchange may not be used directly by publishers, but only when bound to other exchanges. Internal exchanges are used to construct wiring that is not visible to applications. If set, the server will not respond to the method. The client should not wait for a reply method. If the server could not complete the method it will raise a channel or connection exception. A set of arguments for the declaration. The syntax and semantics of these arguments depends on the server implementation. This field is ignored if passive is 1. This method confirms a Declare method and confirms the name of the exchange, essential for automatically-named exchanges. This method deletes an exchange. When an exchange is deleted all queue bindings on the exchange are cancelled. The client MUST provide a valid access ticket giving "active" access rights to the exchange's access realm. Client creates access ticket with wrong access rights and attempts to use in this method. The client MUST NOT attempt to delete an exchange that does not exist. If set, the server will only delete the exchange if it has no queue bindings. If the exchange has queue bindings the server does not delete it but raises a channel exception instead. If set, the server will not respond to the method. The client should not wait for a reply method. If the server could not complete the method it will raise a channel or connection exception. This method confirms the deletion of an exchange. Queues store and forward messages. Queues can be configured in the server or created at runtime. Queues must be attached to at least one exchange in order to receive messages from publishers. queue = C:DECLARE S:DECLARE-OK / C:BIND S:BIND-OK / C:PURGE S:PURGE-OK / C:DELETE S:DELETE-OK A server MUST allow any content class to be sent to any queue, in any mix, and queue and deliver these content classes independently. Note that all methods that fetch content off queues are specific to a given content class. Client creates an exchange of each standard type and several queues that it binds to each exchange. It must then successfully send each of the standard content types to each of the available queues. This method creates or checks a queue. When creating a new queue the client can specify various properties that control the durability of the queue and its contents, and the level of sharing for the queue. The server MUST create a default binding for a newly-created queue to the default exchange, which is an exchange of type 'direct' and use the queue name as the routing key. Client creates a new queue, and then without explicitly binding it to an exchange, attempts to send a message through the default exchange binding, i.e. publish a message to the empty exchange, with the queue name as routing key. The server SHOULD support a minimum of 256 queues per virtual host and ideally, impose no limit except as defined by available resources. Client attempts to create as many queues as it can until the server reports an error. The resulting count must at least be 256. When a client defines a new queue, this belongs to the access realm of the ticket used. All further work done with that queue must be done with an access ticket for the same realm. The client MUST provide a valid access ticket giving "active" access to the realm in which the queue exists or will be created. 
Client creates access ticket with wrong access rights and attempts to use in this method. The queue name MAY be empty, in which case the server MUST create a new queue with a unique generated name and return this to the client in the Declare-Ok method. Client attempts to create several queues with an empty name. The client then verifies that the server-assigned names are unique and different. Queue names starting with "amq." are reserved for pre-declared and standardised server queues. A client MAY NOT attempt to declare a queue with a name that starts with "amq." and the passive option set to zero. A client attempts to create a queue with a name starting with "amq." and with the passive option set to zero. If set, the server will not create the queue. This field allows the client to assert the presence of a queue without modifying the server state. The client MAY ask the server to assert that a queue exists without creating the queue if not. If the queue does not exist, the server treats this as a failure. Client declares an existing queue with the passive option and expects the server to respond with a declare-ok. Client then attempts to declare a non-existent queue with the passive option, and the server must close the channel with the correct reply-code. If set when creating a new queue, the queue will be marked as durable. Durable queues remain active when a server restarts. Non-durable queues (transient queues) are purged if/when a server restarts. Note that durable queues do not necessarily hold persistent messages, although it does not make sense to send persistent messages to a transient queue. The server MUST recreate the durable queue after a restart. A client creates a durable queue. The server is then restarted. The client then attempts to send a message to the queue. The message should be successfully delivered. The server MUST support both durable and transient queues. A client creates two named queues, one durable and one transient. The server MUST ignore the durable field if the queue already exists. A client creates two named queues, one durable and one transient. The client then attempts to declare the two queues using the same names again, but reversing the value of the durable flag in each case. Verify that the queues still exist with the original durable flag values. Exclusive queues may only be consumed from by the current connection. Setting the 'exclusive' flag always implies 'auto-delete'. The server MUST support both exclusive (private) and non-exclusive (shared) queues. A client creates two named queues, one exclusive and one non-exclusive. The client MAY NOT attempt to declare any existing and exclusive queue on multiple connections. A client declares an exclusive named queue. A second client on a different connection attempts to declare a queue of the same name. If set, the queue is deleted when all consumers have finished using it. Last consumer can be cancelled either explicitly or because its channel is closed. If there was no consumer ever on the queue, it won't be deleted. The server MUST ignore the auto-delete field if the queue already exists. A client creates two named queues, one as auto-delete and one explicit-delete. The client then attempts to declare the two queues using the same names again, but reversing the value of the auto-delete field in each case. Verify that the queues still exist with the original auto-delete flag values. If set, the server will not respond to the method. The client should not wait for a reply method. 
If the server could not complete the method it will raise a channel or connection exception. A set of arguments for the declaration. The syntax and semantics of these arguments depend on the server implementation. This field is ignored if passive is 1. This method confirms a Declare method and confirms the name of the queue, essential for automatically-named queues. Reports the name of the queue. If the server generated a queue name, this field contains that name. Reports the number of messages in the queue, which will be zero for newly-created queues. Reports the number of active consumers for the queue. Note that consumers can suspend activity (Channel.Flow) in which case they do not appear in this count. This method binds a queue to an exchange. Until a queue is bound it will not receive any messages. In a classic messaging model, store-and-forward queues are bound to a direct exchange and subscription queues are bound to a topic exchange. A server MUST allow and ignore duplicate bindings - that is, two or more bind methods for a specific queue, with identical arguments - without treating these as an error. A client binds a named queue to an exchange. The client then repeats the bind (with identical arguments). If a bind fails, the server MUST raise a connection exception. TODO The server MUST NOT allow a durable queue to bind to a transient exchange. A client creates a transient exchange. The client then declares a named durable queue and then attempts to bind the durable queue to the transient exchange. Bindings for durable queues are automatically durable and the server SHOULD restore such bindings after a server restart. A server creates a named durable queue and binds it to a durable exchange. The server is restarted. The client then attempts to use the queue/exchange combination. If the client attempts to bind to an exchange that was declared as internal, the server MUST raise a connection exception with reply code 530 (not allowed). A client attempts to bind a named queue to an internal exchange. The server SHOULD support at least 4 bindings per queue, and ideally, impose no limit except as defined by available resources. A client creates a named queue and attempts to bind it to 4 different non-internal exchanges. The client provides a valid access ticket giving "active" access rights to the queue's access realm. Specifies the name of the queue to bind. If the queue name is empty, refers to the current queue for the channel, which is the last declared queue. A client MUST NOT be allowed to bind a non-existent and unnamed queue (i.e. empty queue name) to an exchange. A client attempts to bind with an unnamed (empty) queue name to an exchange. A client MUST NOT be allowed to bind a non-existent queue (i.e. not previously declared) to an exchange. A client attempts to bind an undeclared queue name to an exchange. A client MUST NOT be allowed to bind a queue to a non-existent exchange. A client attempts to bind a named queue to an undeclared exchange. Specifies the routing key for the binding. The routing key is used for routing messages depending on the exchange configuration. Not all exchanges use a routing key - refer to the specific exchange documentation. If the queue name is empty, the server uses the last queue declared on the channel. If the routing key is also empty, the server uses this queue name for the routing key as well. If the queue name is provided but the routing key is empty, the server does the binding with that empty routing key.
The meaning of empty routing keys depends on the exchange implementation. If a message queue binds to a direct exchange using routing key K and a publisher sends the exchange a message with routing key R, then the message MUST be passed to the message queue if K = R. If set, the server will not respond to the method. The client should not wait for a reply method. If the server could not complete the method it will raise a channel or connection exception. A set of arguments for the binding. The syntax and semantics of these arguments depends on the exchange class. This method confirms that the bind was successful. This method unbinds a queue from an exchange. If a unbind fails, the server MUST raise a connection exception. The client provides a valid access ticket giving "active" access rights to the queue's access realm. Specifies the name of the queue to unbind. If the queue does not exist the server MUST raise a channel exception with reply code 404 (not found). The name of the exchange to unbind from. If the exchange does not exist the server MUST raise a channel exception with reply code 404 (not found). Specifies the routing key of the binding to unbind. Specifies the arguments of the binding to unbind. This method confirms that the unbind was successful. This method removes all messages from a queue. It does not cancel consumers. Purged messages are deleted without any formal "undo" mechanism. A call to purge MUST result in an empty queue. On transacted channels the server MUST not purge messages that have already been sent to a client but not yet acknowledged. The server MAY implement a purge queue or log that allows system administrators to recover accidentally-purged messages. The server SHOULD NOT keep purged messages in the same storage spaces as the live messages since the volumes of purged messages may get very large. The access ticket must be for the access realm that holds the queue. The client MUST provide a valid access ticket giving "read" access rights to the queue's access realm. Note that purging a queue is equivalent to reading all messages and discarding them. Specifies the name of the queue to purge. If the queue name is empty, refers to the current queue for the channel, which is the last declared queue. If the client did not previously declare a queue, and the queue name in this method is empty, the server MUST raise a connection exception with reply code 530 (not allowed). The queue MUST exist. Attempting to purge a non-existing queue MUST cause a channel exception. If set, the server will not respond to the method. The client should not wait for a reply method. If the server could not complete the method it will raise a channel or connection exception. This method confirms the purge of a queue. Reports the number of messages purged. This method deletes a queue. When a queue is deleted any pending messages are sent to a dead-letter queue if this is defined in the server configuration, and all consumers on the queue are cancelled. The server SHOULD use a dead-letter queue to hold messages that were pending on a deleted queue, and MAY provide facilities for a system administrator to move these messages back to an active queue. The client provides a valid access ticket giving "active" access rights to the queue's access realm. Specifies the name of the queue to delete. If the queue name is empty, refers to the current queue for the channel, which is the last declared queue. 
If the client did not previously declare a queue, and the queue name in this method is empty, the server MUST raise a connection exception with reply code 530 (not allowed). The queue must exist. If the client attempts to delete a non-existing queue the server MUST raise a channel exception with reply code 404 (not found). If set, the server will only delete the queue if it has no consumers. If the queue has consumers the server does not delete it but raises a channel exception instead. The server MUST respect the if-unused flag when deleting a queue. If set, the server will only delete the queue if it has no messages. If the queue is not empty the server MUST raise a channel exception with reply code 406 (precondition failed). If set, the server will not respond to the method. The client should not wait for a reply method. If the server could not complete the method it will raise a channel or connection exception. This method confirms the deletion of a queue. Reports the number of messages purged. The Basic class provides methods that support an industry-standard messaging model. basic = C:QOS S:QOS-OK / C:CONSUME S:CONSUME-OK / C:CANCEL S:CANCEL-OK / C:PUBLISH content / S:RETURN content / S:DELIVER content / C:GET ( S:GET-OK content / S:GET-EMPTY ) / C:ACK / C:REJECT The server SHOULD respect the persistent property of basic messages and SHOULD make a best-effort to hold persistent basic messages on a reliable storage mechanism. Send a persistent message to a queue, stop the server, restart the server and then verify whether the message is still present. Assumes that queues are durable. Persistence without durable queues makes no sense. The server MUST NOT discard a persistent basic message in case of a queue overflow. Create a queue overflow situation with persistent messages and verify that messages do not get lost (presumably the server will write them to disk). The server MAY use the Channel.Flow method to slow or stop a basic message publisher when necessary. Create a queue overflow situation with non-persistent messages and verify whether the server responds with Channel.Flow or not. Repeat with persistent messages. The server MAY overflow non-persistent basic messages to persistent storage. The server MAY discard or dead-letter non-persistent basic messages on a priority basis if the queue size exceeds some configured limit. The server MUST implement at least 2 priority levels for basic messages, where priorities 0-4 and 5-9 are treated as two distinct levels. Send a number of priority 0 messages to a queue. Send one priority 9 message. Consume messages from the queue and verify that the first message received was priority 9. The server MAY implement up to 10 priority levels. Send a number of messages with mixed priorities to a queue, so that all priority values from 0 to 9 are exercised. A good scenario would be ten messages in low-to-high priority. Consume from the queue and verify how many priority levels emerge. The server MUST deliver messages of the same priority in order irrespective of their individual persistence. Send a set of messages with the same priority but different persistence settings to a queue. Consume and verify that messages arrive in the same order as originally published. The server MUST support automatic acknowledgements on Basic content, i.e. consumers with the no-ack field set to TRUE. Create a queue and a consumer using automatic acknowledgements. Publish a set of messages to the queue. Consume the messages and verify that all messages are received.
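Publishing a persistent, high-priority basic message as discussed above, again sketched with pika; delivery_mode=2 marks the message persistent and priority selects one of the 0-9 levels:

    import pika

    connection = pika.BlockingConnection(pika.ConnectionParameters("localhost"))
    channel = connection.channel()
    channel.queue_declare(queue="jobs", durable=True)
    channel.basic_publish(
        exchange="",
        routing_key="jobs",
        body=b"important work",
        # delivery_mode=2 requests persistence; priority 9 falls in the
        # upper of the two mandatory priority bands (0-4 and 5-9).
        properties=pika.BasicProperties(delivery_mode=2, priority=9),
    )
    connection.close()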
The server MUST support explicit acknowledgements on Basic content, i.e. consumers with the no-ack field set to FALSE. Create a queue and a consumer using explicit acknowledgements. Publish a set of messages to the queue. Consume the messages but acknowledge only half of them. Disconnect and reconnect, and consume from the queue. Verify that the remaining messages are received. This method requests a specific quality of service. The QoS can be specified for the current channel or for all channels on the connection. The particular properties and semantics of a qos method always depend on the content class semantics. Though the qos method could in principle apply to both peers, it is currently meaningful only for the server. The client can request that messages be sent in advance so that when the client finishes processing a message, the following message is already held locally, rather than needing to be sent down the channel. Prefetching gives a performance improvement. This field specifies the prefetch window size in octets. The server will send a message in advance if it is equal to or smaller in size than the available prefetch size (and also falls into other prefetch limits). May be set to zero, meaning "no specific limit", although other prefetch limits may still apply. The prefetch-size is ignored if the no-ack option is set. The server MUST ignore this setting when the client is not processing any messages - i.e. the prefetch size does not limit the transfer of single messages to a client, only the sending in advance of more messages while the client still has one or more unacknowledged messages. Define a QoS prefetch-size limit and send a single message that exceeds that limit. Verify that the message arrives correctly. Specifies a prefetch window in terms of whole messages. This field may be used in combination with the prefetch-size field; a message will only be sent in advance if both prefetch windows (and those at the channel and connection level) allow it. The prefetch-count is ignored if the no-ack option is set. The server may send less data in advance than allowed by the client's specified prefetch windows but it MUST NOT send more. Define a QoS prefetch-size limit and a prefetch-count limit greater than one. Send multiple messages that exceed the prefetch size. Verify that no more than one message arrives at once. By default the QoS settings apply to the current channel only. If this field is set, they are applied to the entire connection. This method tells the client that the requested QoS levels could be handled by the server. The requested QoS applies to all active consumers until a new QoS is defined. This method asks the server to start a "consumer", which is a transient request for messages from a specific queue. Consumers last as long as the channel they were created on, or until the client cancels them. The server SHOULD support at least 16 consumers per queue, and ideally, impose no limit except as defined by available resources. Create a queue and create consumers on that queue until the server closes the connection. Verify that the number of consumers created was at least sixteen and report the total number. The client MUST provide a valid access ticket giving "read" access rights to the realm for the queue. Attempt to create a consumer with an invalid (non-zero) access ticket. Specifies the name of the queue to consume from. If the queue name is null, refers to the current queue for the channel, which is the last declared queue.
If the queue name is empty the client MUST have previously declared a queue using this channel. Attempt to create a consumer with an empty queue name and no previously declared queue on the channel. Specifies the identifier for the consumer. The consumer tag is local to a connection, so two clients can use the same consumer tags. If this field is empty the server will generate a unique tag. The client MUST NOT specify a tag that refers to an existing consumer. Attempt to create two consumers with the same non-empty tag. The consumer tag is valid only within the channel from which the consumer was created. I.e. a client MUST NOT create a consumer in one channel and then use it in another. Attempt to create a consumer in one channel, then use in another channel, in which consumers have also been created (to test that the server uses unique consumer tags). Request exclusive consumer access, meaning only this consumer can access the queue. The client MAY NOT gain exclusive access to a queue that already has active consumers. Open two connections to a server, and in one connection create a shared (non-exclusive) queue and then consume from the queue. In the second connection attempt to consume from the same queue using the exclusive option. If set, the server will not respond to the method. The client should not wait for a reply method. If the server could not complete the method it will raise a channel or connection exception. A set of filters for the consume. The syntax and semantics of these filters depends on the providers implementation. The server provides the client with a consumer tag, which is used by the client for methods called on the consumer at a later stage. Holds the consumer tag specified by the client or provided by the server. This method cancels a consumer. This does not affect already delivered messages, but it does mean the server will not send any more messages for that consumer. The client may receive an arbitrary number of messages in between sending the cancel method and receiving the cancel-ok reply. If the queue does not exist the server MUST ignore the cancel method, so long as the consumer tag is valid for that channel. TODO. If set, the server will not respond to the method. The client should not wait for a reply method. If the server could not complete the method it will raise a channel or connection exception. This method confirms that the cancellation was completed. This method publishes a message to a specific exchange. The message will be routed to queues as defined by the exchange configuration and distributed to any active consumers when the transaction, if any, is committed. The client MUST provide a valid access ticket giving "write" access rights to the access realm for the exchange. TODO. Specifies the name of the exchange to publish to. The exchange name can be empty, meaning the default exchange. If the exchange name is specified, and that exchange does not exist, the server will raise a channel exception. The server MUST accept a blank exchange name to mean the default exchange. TODO. If the exchange was declared as an internal exchange, the server MUST raise a channel exception with a reply code 403 (access refused). TODO. The exchange MAY refuse basic content in which case it MUST raise a channel exception with reply code 540 (not implemented). TODO. Specifies the routing key for the message. The routing key is used for routing messages depending on the exchange configuration. 
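A consumer combining the Qos and Consume methods described above, with explicit acknowledgements, sketched with pika (an assumed client library; names follow pika 1.x):

    import pika

    connection = pika.BlockingConnection(pika.ConnectionParameters("localhost"))
    channel = connection.channel()
    channel.queue_declare(queue="jobs", durable=True)
    # A whole-message prefetch window of one: the server sends the next
    # message only after the previous delivery is acknowledged.
    channel.basic_qos(prefetch_count=1)

    def handle(ch, method, properties, body):
        print("received %r" % (body,))
        ch.basic_ack(delivery_tag=method.delivery_tag)  # explicit ack

    channel.basic_consume(queue="jobs", on_message_callback=handle)
    channel.start_consuming()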
This flag tells the server how to react if the message cannot be routed to a queue. If this flag is set, the server will return an unroutable message with a Return method. If this flag is zero, the server silently drops the message. The server SHOULD implement the mandatory flag. TODO. This flag tells the server how to react if the message cannot be routed to a queue consumer immediately. If this flag is set, the server will return an undeliverable message with a Return method. If this flag is zero, the server will queue the message, but with no guarantee that it will ever be consumed. The server SHOULD implement the immediate flag. TODO. This method returns an undeliverable message that was published with the "immediate" flag set, or an unroutable message published with the "mandatory" flag set. The reply code and text provide information about the reason that the message was undeliverable. Specifies the name of the exchange that the message was originally published to. Specifies the routing key name specified when the message was published. This method delivers a message to the client, via a consumer. In the asynchronous message delivery model, the client starts a consumer using the Consume method, then the server responds with Deliver methods as and when messages arrive for that consumer. The server SHOULD track the number of times a message has been delivered to clients and when a message is redelivered a certain number of times - e.g. 5 times - without being acknowledged, the server SHOULD consider the message to be unprocessable (possibly causing client applications to abort), and move the message to a dead letter queue. TODO. Specifies the name of the exchange that the message was originally published to. Specifies the routing key name specified when the message was published. This method provides a direct access to the messages in a queue using a synchronous dialogue that is designed for specific types of application where synchronous functionality is more important than performance. The client MUST provide a valid access ticket giving "read" access rights to the realm for the queue. TODO. Specifies the name of the queue to consume from. If the queue name is null, refers to the current queue for the channel, which is the last declared queue. If the client did not previously declare a queue, and the queue name in this method is empty, the server MUST raise a connection exception with reply code 530 (not allowed). TODO. This method delivers a message to the client following a get method. A message delivered by 'get-ok' must be acknowledged unless the no-ack option was set in the get method. Specifies the name of the exchange that the message was originally published to. If empty, the message was published to the default exchange. Specifies the routing key name specified when the message was published. This field reports the number of messages pending on the queue, excluding the message being delivered. Note that this figure is indicative, not reliable, and can change arbitrarily as messages are added to the queue and removed by other clients. This method tells the client that the queue has no messages available for the client. For use by cluster applications, should not be used by client applications. This method acknowledges one or more messages delivered via the Deliver or Get-Ok methods. The client can ask to confirm a single message or a set of messages up to and including a specific message. 
If set to 1, the delivery tag is treated as "up to and including", so that the client can acknowledge multiple messages with a single method. If set to zero, the delivery tag refers to a single message. If the multiple field is 1, and the delivery tag is zero, tells the server to acknowledge all outstanding messages. The server MUST validate that a non-zero delivery-tag refers to a delivered message, and raise a channel exception if this is not the case. TODO. This method allows a client to reject a message. It can be used to interrupt and cancel large incoming messages, or return untreatable messages to their original queue. The server SHOULD be capable of accepting and processing the Reject method while sending message content with a Deliver or Get-Ok method. I.e. the server should read and process incoming methods while sending output frames. To cancel partially-sent content, the server sends a content body frame of size 1 (i.e. with no data except the frame-end octet). The server SHOULD interpret this method as meaning that the client is unable to process the message at this time. TODO. A client MUST NOT use this method as a means of selecting messages to process. A rejected message MAY be discarded or dead-lettered, not necessarily passed to another client. TODO. If this field is zero, the message will be discarded. If this bit is 1, the server will attempt to requeue the message. The server MUST NOT deliver the message to the same client within the context of the current channel. The recommended strategy is to attempt to deliver the message to an alternative consumer, and if that is not possible, to move the message to a dead-letter queue. The server MAY use more sophisticated tracking to hold the message on the queue and redeliver it to the same client at a later stage. TODO. This method asks the broker to redeliver all unacknowledged messages on a specified channel. Zero or more messages may be redelivered. This method is only allowed on non-transacted channels. The server MUST set the redelivered flag on all messages that are resent. TODO. The server MUST raise a channel exception if this is called on a transacted channel. TODO. If this field is zero, the message will be redelivered to the original recipient. If this bit is 1, the server will attempt to requeue the message, potentially then delivering it to an alternative subscriber. The file class provides methods that support reliable file transfer. File messages have a specific set of properties that are required for interoperability with file transfer applications. File messages and acknowledgements are subject to channel transactions. Note that the file class does not provide message browsing methods; these are not compatible with the staging model. Applications that need browsable file transfer should use Basic content and the Basic class. file = C:QOS S:QOS-OK / C:CONSUME S:CONSUME-OK / C:CANCEL S:CANCEL-OK / C:OPEN S:OPEN-OK C:STAGE content / S:OPEN C:OPEN-OK S:STAGE content / C:PUBLISH / S:DELIVER / S:RETURN / C:ACK / C:REJECT The server MUST make a best-effort to hold file messages on a reliable storage mechanism. The server MUST NOT discard a file message in case of a queue overflow. The server MUST use the Channel.Flow method to slow or stop a file message publisher when necessary. The server MUST implement at least 2 priority levels for file messages, where priorities 0-4 and 5-9 are treated as two distinct levels. The server MAY implement up to 10 priority levels.
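The acknowledge/reject cycle above, sketched with pika; process() here is a hypothetical application handler, not a real API:

    import pika

    connection = pika.BlockingConnection(pika.ConnectionParameters("localhost"))
    channel = connection.channel()
    method, properties, body = channel.basic_get(queue="jobs", auto_ack=False)
    if method is not None:
        try:
            process(body)  # hypothetical application handler
            channel.basic_ack(delivery_tag=method.delivery_tag)
        except Exception:
            # requeue=True asks the server to redeliver (not on this channel);
            # requeue=False lets it discard or dead-letter the message.
            channel.basic_reject(delivery_tag=method.delivery_tag, requeue=True)
    connection.close()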
The server MUST support both automatic and explicit acknowledgements on file content. This method requests a specific quality of service. The QoS can be specified for the current channel or for all channels on the connection. The particular properties and semantics of a qos method always depend on the content class semantics. Though the qos method could in principle apply to both peers, it is currently meaningful only for the server. The client can request that messages be sent in advance so that when the client finishes processing a message, the following message is already held locally, rather than needing to be sent down the channel. Prefetching gives a performance improvement. This field specifies the prefetch window size in octets. May be set to zero, meaning "no specific limit". Note that other prefetch limits may still apply. The prefetch-size is ignored if the no-ack option is set. Specifies a prefetch window in terms of whole messages. This is compatible with some file API implementations. This field may be used in combination with the prefetch-size field; a message will only be sent in advance if both prefetch windows (and those at the channel and connection level) allow it. The prefetch-count is ignored if the no-ack option is set. The server MAY send less data in advance than allowed by the client's specified prefetch windows but it MUST NOT send more. By default the QoS settings apply to the current channel only. If this field is set, they are applied to the entire connection. This method tells the client that the requested QoS levels could be handled by the server. The requested QoS applies to all active consumers until a new QoS is defined. This method asks the server to start a "consumer", which is a transient request for messages from a specific queue. Consumers last as long as the channel they were created on, or until the client cancels them. The server SHOULD support at least 16 consumers per queue, unless the queue was declared as private, and ideally, impose no limit except as defined by available resources. The client MUST provide a valid access ticket giving "read" access rights to the realm for the queue. Specifies the name of the queue to consume from. If the queue name is null, refers to the current queue for the channel, which is the last declared queue. If the client did not previously declare a queue, and the queue name in this method is empty, the server MUST raise a connection exception with reply code 530 (not allowed). Specifies the identifier for the consumer. The consumer tag is local to a connection, so two clients can use the same consumer tags. If this field is empty the server will generate a unique tag. The tag MUST NOT refer to an existing consumer. If the client attempts to create two consumers with the same non-empty tag the server MUST raise a connection exception with reply code 530 (not allowed). Request exclusive consumer access, meaning only this consumer can access the queue. If the server cannot grant exclusive access to the queue when asked, - because there are other consumers active - it MUST raise a channel exception with return code 405 (resource locked). If set, the server will not respond to the method. The client should not wait for a reply method. If the server could not complete the method it will raise a channel or connection exception. A set of filters for the consume. The syntax and semantics of these filters depends on the providers implementation. 
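The two prefetch windows described for the Basic and File qos methods can be expressed as a simple predicate; this is an editor's sketch of the rule, not code from the specification:

    def may_send_in_advance(msg_size, unacked_bytes, unacked_count,
                            prefetch_size, prefetch_count):
        # Zero disables a window; a message is sent ahead only if it fits
        # both the octet window and the whole-message window.
        size_ok = prefetch_size == 0 or unacked_bytes + msg_size <= prefetch_size
        count_ok = prefetch_count == 0 or unacked_count < prefetch_count
        return size_ok and count_ok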
This method provides the client with a consumer tag which it MUST use in methods that work with the consumer. Holds the consumer tag specified by the client or provided by the server. This method cancels a consumer. This does not affect already delivered messages, but it does mean the server will not send any more messages for that consumer. If set, the server will not respond to the method. The client should not wait for a reply method. If the server could not complete the method it will raise a channel or connection exception. This method confirms that the cancellation was completed. This method requests permission to start staging a message. Staging means sending the message into a temporary area at the recipient end and then delivering the message by referring to this temporary area. Staging is how the protocol handles partial file transfers - if a message is partially staged and the connection breaks, the next time the sender starts to stage it, it can restart from where it left off. This is the staging identifier. This is an arbitrary string chosen by the sender. For staging to work correctly the sender must use the same staging identifier when staging the same message a second time after recovery from a failure. A good choice for the staging identifier would be the SHA1 hash of the message properties data (including the original filename, revised time, etc.). The size of the content in octets. The recipient may use this information to allocate or check available space in advance, to avoid "disk full" errors during staging of very large messages. The sender MUST accurately fill the content-size field. Zero-length content is permitted. This method confirms that the recipient is ready to accept staged data. If the message was already partially-staged at a previous time the recipient will report the number of octets already staged. The amount of previously-staged content in octets. For a new message this will be zero. The sender MUST start sending data from this octet offset in the message, counting from zero. The recipient MAY decide how long to hold partially-staged content and MAY implement staging by always discarding partially-staged content. However if it uses the file content type it MUST support the staging methods. This method stages the message, sending the message content to the recipient from the octet offset specified in the Open-Ok method. This method publishes a staged file message to a specific exchange. The file message will be routed to queues as defined by the exchange configuration and distributed to any active consumers when the transaction, if any, is committed. The client MUST provide a valid access ticket giving "write" access rights to the access realm for the exchange. Specifies the name of the exchange to publish to. The exchange name can be empty, meaning the default exchange. If the exchange name is specified, and that exchange does not exist, the server will raise a channel exception. The server MUST accept a blank exchange name to mean the default exchange. If the exchange was declared as an internal exchange, the server MUST respond with a reply code 403 (access refused) and raise a channel exception. The exchange MAY refuse file content in which case it MUST respond with a reply code 540 (not implemented) and raise a channel exception. Specifies the routing key for the message. The routing key is used for routing messages depending on the exchange configuration. This flag tells the server how to react if the message cannot be routed to a queue. 
If this flag is set, the server will return an unroutable message with a Return method. If this flag is zero, the server silently drops the message. The server SHOULD implement the mandatory flag. This flag tells the server how to react if the message cannot be routed to a queue consumer immediately. If this flag is set, the server will return an undeliverable message with a Return method. If this flag is zero, the server will queue the message, but with no guarantee that it will ever be consumed. The server SHOULD implement the immediate flag. This is the staging identifier of the message to publish. The message must have been staged. Note that a client can send the Publish method asynchronously without waiting for staging to finish. This method returns an undeliverable message that was published with the "immediate" flag set, or an unroutable message published with the "mandatory" flag set. The reply code and text provide information about the reason that the message was undeliverable. Specifies the name of the exchange that the message was originally published to. Specifies the routing key name specified when the message was published. This method delivers a staged file message to the client, via a consumer. In the asynchronous message delivery model, the client starts a consumer using the Consume method, then the server responds with Deliver methods as and when messages arrive for that consumer. The server SHOULD track the number of times a message has been delivered to clients and when a message is redelivered a certain number of times - e.g. 5 times - without being acknowledged, the server SHOULD consider the message to be unprocessable (possibly causing client applications to abort), and move the message to a dead letter queue. Specifies the name of the exchange that the message was originally published to. Specifies the routing key name specified when the message was published. This is the staging identifier of the message to deliver. The message must have been staged. Note that a server can send the Deliver method asynchronously without waiting for staging to finish. This method acknowledges one or more messages delivered via the Deliver method. The client can ask to confirm a single message or a set of messages up to and including a specific message. If set to 1, the delivery tag is treated as "up to and including", so that the client can acknowledge multiple messages with a single method. If set to zero, the delivery tag refers to a single message. If the multiple field is 1, and the delivery tag is zero, this tells the server to acknowledge all outstanding messages. The server MUST validate that a non-zero delivery-tag refers to a delivered message, and raise a channel exception if this is not the case. This method allows a client to reject a message. It can be used to return untreatable messages to their original queue. Note that file content is staged before delivery, so the client will not use this method to interrupt delivery of a large message. The server SHOULD interpret this method as meaning that the client is unable to process the message at this time. A client MUST NOT use this method as a means of selecting messages to process. A rejected message MAY be discarded or dead-lettered, not necessarily passed to another client. If this field is zero, the message will be discarded. If this bit is 1, the server will attempt to requeue the message. The server MUST NOT deliver the message to the same client within the context of the current channel. 
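The acknowledge and reject flags described above behave the same way in the Basic class. A sketch of both, using pika; the queue name and the poison-message check are hypothetical:

    import pika

    def on_message(ch, method, properties, body):
        if body.startswith(b"poison"):  # hypothetical application check
            # requeue=True asks the server to requeue rather than discard;
            # the server must not redeliver it on this same channel.
            ch.basic_reject(delivery_tag=method.delivery_tag, requeue=True)
        else:
            # multiple=True acknowledges every outstanding message up to and
            # including this delivery tag with a single method.
            ch.basic_ack(delivery_tag=method.delivery_tag, multiple=True)

    connection = pika.BlockingConnection(pika.ConnectionParameters("localhost"))
    channel = connection.channel()
    channel.queue_declare(queue="demo")
    channel.basic_consume(queue="demo", on_message_callback=on_message)
    channel.start_consuming()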
The recommended strategy is to attempt to deliver the message to an alternative consumer, and if that is not possible, to move the message to a dead-letter queue. The server MAY use more sophisticated tracking to hold the message on the queue and redeliver it to the same client at a later stage. The stream class provides methods that support multimedia streaming. The stream class uses the following semantics: one message is one packet of data; delivery is unacknowledged and unreliable; the consumer can specify quality of service parameters that the server can try to adhere to; lower-priority messages may be discarded in favour of high-priority messages. stream = C:QOS S:QOS-OK / C:CONSUME S:CONSUME-OK / C:CANCEL S:CANCEL-OK / C:PUBLISH content / S:RETURN / S:DELIVER content The server SHOULD discard stream messages on a priority basis if the queue size exceeds some configured limit. The server MUST implement at least 2 priority levels for stream messages, where priorities 0-4 and 5-9 are treated as two distinct levels. The server MAY implement up to 10 priority levels. The server MUST implement automatic acknowledgements on stream content. That is, as soon as a message is delivered to a client via a Deliver method, the server must remove it from the queue. This method requests a specific quality of service. The QoS can be specified for the current channel or for all channels on the connection. The particular properties and semantics of a qos method always depend on the content class semantics. Though the qos method could in principle apply to both peers, it is currently meaningful only for the server. The client can request that messages be sent in advance so that when the client finishes processing a message, the following message is already held locally, rather than needing to be sent down the channel. Prefetching gives a performance improvement. This field specifies the prefetch window size in octets. May be set to zero, meaning "no specific limit". Note that other prefetch limits may still apply. Specifies a prefetch window in terms of whole messages. This field may be used in combination with the prefetch-size field; a message will only be sent in advance if both prefetch windows (and those at the channel and connection level) allow it. Specifies a desired transfer rate in octets per second. This is usually determined by the application that uses the streaming data. A value of zero means "no limit", i.e. as rapidly as possible. The server MAY ignore the prefetch values and consume rates, depending on the type of stream and the ability of the server to queue and/or replay it. The server MAY drop low-priority messages in favour of high-priority messages. By default the QoS settings apply to the current channel only. If this field is set, they are applied to the entire connection. This method tells the client that the requested QoS levels could be handled by the server. The requested QoS applies to all active consumers until a new QoS is defined. This method asks the server to start a "consumer", which is a transient request for messages from a specific queue. Consumers last as long as the channel they were created on, or until the client cancels them. The server SHOULD support at least 16 consumers per queue, unless the queue was declared as private, and ideally, impose no limit except as defined by available resources. Streaming applications SHOULD use different channels to select different streaming resolutions. 
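The two mandated priority levels can be exercised through the Basic class. A sketch with pika; note that the x-max-priority queue argument is a RabbitMQ extension rather than part of this specification:

    import pika

    connection = pika.BlockingConnection(pika.ConnectionParameters("localhost"))
    channel = connection.channel()

    # RabbitMQ only honours priorities on queues declared with x-max-priority.
    channel.queue_declare(queue="prio-demo", arguments={"x-max-priority": 10})

    # Priorities 0-4 and 5-9 must behave, at minimum, as two distinct levels.
    for priority in (0, 9):
        channel.basic_publish(
            exchange="",
            routing_key="prio-demo",
            body=("level %d" % priority).encode(),
            properties=pika.BasicProperties(priority=priority),
        )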
AMQP makes no provision for filtering and/or transforming streams except on the basis of priority-based selective delivery of individual messages. The client MUST provide a valid access ticket giving "read" access rights to the realm for the queue. Specifies the name of the queue to consume from. If the queue name is null, refers to the current queue for the channel, which is the last declared queue. If the client did not previously declare a queue, and the queue name in this method is empty, the server MUST raise a connection exception with reply code 530 (not allowed). Specifies the identifier for the consumer. The consumer tag is local to a connection, so two clients can use the same consumer tags. If this field is empty the server will generate a unique tag. The tag MUST NOT refer to an existing consumer. If the client attempts to create two consumers with the same non-empty tag the server MUST raise a connection exception with reply code 530 (not allowed). Request exclusive consumer access, meaning only this consumer can access the queue. If the server cannot grant exclusive access to the queue when asked (because there are other consumers active), it MUST raise a channel exception with reply code 405 (resource locked). If set, the server will not respond to the method. The client should not wait for a reply method. If the server could not complete the method it will raise a channel or connection exception. A set of filters for the consume. The syntax and semantics of these filters depend on the provider's implementation. This method provides the client with a consumer tag which it may use in methods that work with the consumer. Holds the consumer tag specified by the client or provided by the server. This method cancels a consumer. Since message delivery is asynchronous the client may continue to receive messages for a short while after cancelling a consumer. It may process or discard these as appropriate. If set, the server will not respond to the method. The client should not wait for a reply method. If the server could not complete the method it will raise a channel or connection exception. This method confirms that the cancellation was completed. This method publishes a message to a specific exchange. The message will be routed to queues as defined by the exchange configuration and distributed to any active consumers as appropriate. The client MUST provide a valid access ticket giving "write" access rights to the access realm for the exchange. Specifies the name of the exchange to publish to. The exchange name can be empty, meaning the default exchange. If the exchange name is specified, and that exchange does not exist, the server will raise a channel exception. The server MUST accept a blank exchange name to mean the default exchange. If the exchange was declared as an internal exchange, the server MUST respond with a reply code 403 (access refused) and raise a channel exception. The exchange MAY refuse stream content in which case it MUST respond with a reply code 540 (not implemented) and raise a channel exception. Specifies the routing key for the message. The routing key is used for routing messages depending on the exchange configuration. This flag tells the server how to react if the message cannot be routed to a queue. If this flag is set, the server will return an unroutable message with a Return method. If this flag is zero, the server silently drops the message. The server SHOULD implement the mandatory flag. 
This flag tells the server how to react if the message cannot be routed to a queue consumer immediately. If this flag is set, the server will return an undeliverable message with a Return method. If this flag is zero, the server will queue the message, but with no guarantee that it will ever be consumed. The server SHOULD implement the immediate flag. This method returns an undeliverable message that was published with the "immediate" flag set, or an unroutable message published with the "mandatory" flag set. The reply code and text provide information about the reason that the message was undeliverable. Specifies the name of the exchange that the message was originally published to. Specifies the routing key name specified when the message was published. This method delivers a message to the client, via a consumer. In the asynchronous message delivery model, the client starts a consumer using the Consume method, then the server responds with Deliver methods as and when messages arrive for that consumer. Specifies the name of the exchange that the message was originally published to. Specifies the name of the queue that the message came from. Note that a single channel can start many consumers on different queues. Standard transactions provide so-called "1.5 phase commit". We can ensure that work is never lost, but there is a chance of confirmations being lost, so that messages may be resent. Applications that use standard transactions must be able to detect and ignore duplicate messages. A client using standard transactions SHOULD be able to track all messages received within a reasonable period, and thus detect and reject duplicates of the same message. It SHOULD NOT pass these to the application layer. tx = C:SELECT S:SELECT-OK / C:COMMIT S:COMMIT-OK / C:ROLLBACK S:ROLLBACK-OK This method sets the channel to use standard transactions. The client must use this method at least once on a channel before using the Commit or Rollback methods. This method confirms to the client that the channel was successfully set to use standard transactions. This method commits all messages published and acknowledged in the current transaction. A new transaction starts immediately after a commit. This method confirms to the client that the commit succeeded. Note that if a commit fails, the server raises a channel exception. This method abandons all messages published and acknowledged in the current transaction. A new transaction starts immediately after a rollback. This method confirms to the client that the rollback succeeded. Note that if a rollback fails, the server raises a channel exception. Distributed transactions provide so-called "2-phase commit". The AMQP distributed transaction model supports the X-Open XA architecture and other distributed transaction implementations. The Dtx class assumes that the server has a private communications channel (not AMQP) to a distributed transaction coordinator. dtx = C:SELECT S:SELECT-OK C:START S:START-OK This method sets the channel to use distributed transactions. The client must use this method at least once on a channel before using the Start method. This method confirms to the client that the channel was successfully set to use distributed transactions. This method starts a new distributed transaction. This must be the first method on a new channel that uses the distributed transaction mode, before any methods that publish or consume messages. The distributed transaction key. 
This identifies the transaction so that the AMQP server can coordinate with the distributed transaction coordinator. This method confirms to the client that the transaction started. Note that if a start fails, the server raises a channel exception. The tunnel methods are used to send blocks of binary data - which can be serialised AMQP methods or other protocol frames - between AMQP peers. tunnel = C:REQUEST / S:REQUEST This method tunnels a block of binary data, which can be an encoded AMQP method or other data. The binary data is sent as the content for the Tunnel.Request method. This field table holds arbitrary meta-data that the sender needs to pass to the recipient. [WORK IN PROGRESS] The message class provides methods that support an industry-standard messaging model. message = C:QOS S:OK / C:CONSUME S:OK / C:CANCEL S:OK / C:TRANSFER ( S:OK / S:REJECT ) / S:TRANSFER ( C:OK / C:REJECT ) / C:GET ( S:OK / S:EMPTY ) / C:RECOVER S:OK / C:OPEN S:OK / S:OPEN C:OK / C:APPEND S:OK / S:APPEND C:OK / C:CLOSE S:OK / S:CLOSE C:OK / C:CHECKPOINT S:OK / S:CHECKPOINT C:OK / C:RESUME S:OFFSET / S:RESUME C:OFFSET The server SHOULD respect the persistent property of messages and SHOULD make a best-effort to hold persistent messages on a reliable storage mechanism. Send a persistent message to queue, stop server, restart server and then verify whether message is still present. Assumes that queues are durable. Persistence without durable queues makes no sense. The server MUST NOT discard a persistent message in case of a queue overflow. Create a queue overflow situation with persistent messages and verify that messages do not get lost (presumably the server will write them to disk). The server MAY use the Channel.Flow method to slow or stop a message publisher when necessary. Create a queue overflow situation with non-persistent messages and verify whether the server responds with Channel.Flow or not. Repeat with persistent messages. The server MAY overflow non-persistent messages to persistent storage. The server MAY discard or dead-letter non-persistent messages on a priority basis if the queue size exceeds some configured limit. The server MUST implement at least 2 priority levels for messages, where priorities 0-4 and 5-9 are treated as two distinct levels. Send a number of priority 0 messages to a queue. Send one priority 9 message. Consume messages from the queue and verify that the first message received was priority 9. The server MAY implement up to 10 priority levels. Send a number of messages with mixed priorities to a queue, so that all priority values from 0 to 9 are exercised. A good scenario would be ten messages in low-to-high priority. Consume from queue and verify how many priority levels emerge. The server MUST deliver messages of the same priority in order irrespective of their individual persistence. Send a set of messages with the same priority but different persistence settings to a queue. Consume and verify that messages arrive in same order as originally published. The server MUST support automatic acknowledgements on messages, i.e. consumers with the no-ack field set to TRUE. Create a queue and a consumer using automatic acknowledgements. Publish a set of messages to the queue. Consume the messages and verify that all messages are received. The server MUST support explicit acknowledgements on messages, i.e. consumers with the no-ack field set to FALSE. Create a queue and a consumer using explicit acknowledgements. Publish a set of messages to the queue. 
Consume the messages but acknowledge only half of them. Disconnect and reconnect, and consume from the queue. Verify that the remaining messages are received. [WORK IN PROGRESS] This method transfers a message between two peers. When a client uses this method to publish a message to a broker, the destination identifies a specific exchange. The message will then be routed to queues as defined by the exchange configuration and distributed to any active consumers when the transaction, if any, is committed. In the asynchronous message delivery model, the client starts a consumer using the Consume method and passing in a destination, then the broker responds with transfer methods to the specified destination as and when messages arrive for that consumer. If synchronous message delivery is required, the client may issue a get request which on success causes a single message to be transferred to the specified destination. Message acknowledgement is signalled by the return result of this method. The recipient MUST NOT return ok before the message has been processed as defined by the QoS settings. The client MUST provide a valid access ticket giving "write" access rights to the access realm for the exchange. Specifies the destination to which the message is to be transferred. The destination can be empty, meaning the default exchange or consumer. If the destination is specified, and that exchange or consumer does not exist, the peer must raise a channel exception. The server MUST accept a blank destination to mean the default exchange. If the destination refers to an internal exchange, the server MUST raise a channel exception with a reply code 403 (access refused). A destination MAY refuse message content in which case it MUST raise a channel exception with reply code 540 (not implemented). This flag tells the server how to react if the message cannot be routed to a queue consumer immediately. If this flag is set, the server will reject the message. If this flag is zero, the server will queue the message, but with no guarantee that it will ever be consumed. The server SHOULD implement the immediate flag. If this is set to a non-zero value then a message expiration time will be computed based on the current time plus this value. Messages that live longer than their expiration time will be discarded (or dead lettered). If a message is transferred between brokers before delivery to a final consumer the ttl should be decremented before peer-to-peer transfer and both timestamp and expiration should be cleared. Set on arrival by the broker. The expiration header assigned by the broker. After receiving the message the broker sets expiration to the sum of the ttl specified in the publish method and the current time. (ttl = expiration - timestamp) [WORK IN PROGRESS] This method asks the server to start a "consumer", which is a transient request for messages from a specific queue. Consumers last as long as the channel they were created on, or until the client cancels them. The server SHOULD support at least 16 consumers per queue, and ideally, impose no limit except as defined by available resources. Create a queue and create consumers on that queue until the server closes the connection. Verify that the number of consumers created was at least sixteen and report the total number. The client MUST provide a valid access ticket giving "read" access rights to the realm for the queue. Attempt to create a consumer with an invalid (non-zero) access ticket. Specifies the name of the queue to consume from. 
If the queue name is null, refers to the current queue for the channel, which is the last declared queue. If the queue name is empty the client MUST have previously declared a queue using this channel. Attempt to create a consumer with an empty queue name and no previously declared queue on the channel. Specifies the destination for the consumer. The destination is local to a connection, so two clients can use the same destination. The client MUST NOT specify a destination that refers to an existing consumer. Attempt to create two consumers with the same non-empty destination. The destination is valid only within the channel from which the consumer was created. I.e. a client MUST NOT create a consumer in one channel and then use it in another. Attempt to create a consumer in one channel, then use it in another channel, in which consumers have also been created (to test that the server uses unique destinations). Request exclusive consumer access, meaning only this consumer can access the queue. The client MAY NOT gain exclusive access to a queue that already has active consumers. Open two connections to a server, and in one connection create a shared (non-exclusive) queue and then consume from the queue. In the second connection attempt to consume from the same queue using the exclusive option. A set of filters for the consume. The syntax and semantics of these filters depend on the provider's implementation. [WORK IN PROGRESS] This method cancels a consumer. This does not affect already delivered messages, but it does mean the server will not send any more messages for that consumer. The client may receive an arbitrary number of messages in between sending the cancel method and receiving the cancel-ok reply. If the queue does not exist the server MUST ignore the cancel method, so long as the consumer tag is valid for that channel. [WORK IN PROGRESS] This method provides direct access to the messages in a queue using a synchronous dialogue that is designed for specific types of application where synchronous functionality is more important than performance. The client MUST provide a valid access ticket giving "read" access rights to the realm for the queue. Specifies the name of the queue to consume from. If the queue name is null, refers to the current queue for the channel, which is the last declared queue. If the client did not previously declare a queue, and the queue name in this method is empty, the server MUST raise a connection exception with reply code 530 (not allowed). On normal completion of the get request (i.e. a response of ok), a message will be transferred to the supplied destination. [WORK IN PROGRESS] This method asks the broker to redeliver all unacknowledged messages on a specified channel. Zero or more messages may be redelivered. This method is only allowed on non-transacted channels. The server MUST set the redelivered flag on all messages that are resent. The server MUST raise a channel exception if this is called on a transacted channel. If this field is zero, the message will be redelivered to the original recipient. If this bit is 1, the server will attempt to requeue the message, potentially then delivering it to an alternative subscriber. [WORK IN PROGRESS] This method creates a reference. A reference provides a means to send a message body into a temporary area at the recipient end and then deliver the message by referring to this temporary area. This is how the protocol handles large message transfers. 
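The synchronous Get dialogue described above is exposed as basic_get in common clients. A sketch with pika (queue name illustrative):

    import pika

    connection = pika.BlockingConnection(pika.ConnectionParameters("localhost"))
    channel = connection.channel()
    channel.queue_declare(queue="demo")

    # One synchronous request per message: simple, but slower than consuming.
    method, properties, body = channel.basic_get(queue="demo", auto_ack=False)
    if method is None:
        print("queue is empty")  # the server answered with Get-Empty
    else:
        print(body)
        channel.basic_ack(delivery_tag=method.delivery_tag)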
The scope of a ref is defined to be between calls to open (or resume) and close. Between these points it is valid for a ref to be used from any content data type, and so the receiver must hold onto its contents. Should the channel be closed when a ref is still in scope, the receiver may discard its contents (unless it is checkpointed). A ref that is in scope is considered open. The recipient MUST generate an error if the reference is currently open (in scope). [WORK IN PROGRESS] This method signals the recipient that no more data will be appended to the reference. A recipient CANNOT acknowledge a message until its reference is closed (not in scope). The recipient MUST generate an error if the reference was not previously open (in scope). [WORK IN PROGRESS] This method appends data to a reference. The recipient MUST generate an error if the reference is not open (not in scope). [WORK IN PROGRESS] This method provides a means to checkpoint large message transfer. The sender may ask the recipient to checkpoint the contents of a reference using the supplied identifier. The sender may then resume the transfer at a later point. It is at the discretion of the recipient how much data to save with the checkpoint, and the sender MUST honour the offset returned by the resume method. The recipient MUST generate an error if the reference is not open (not in scope). This is the checkpoint identifier. This is an arbitrary string chosen by the sender. For checkpointing to work correctly the sender must use the same checkpoint identifier when resuming the message. A good choice for the checkpoint identifier would be the SHA1 hash of the message properties data (including the original filename, revised time, etc.). [WORK IN PROGRESS] This method resumes a reference from the last checkpoint. A reference is considered to be open (in scope) after a resume even though it will not have been opened via the open method during this session. The recipient MUST generate an error if the reference is currently open (in scope). [WORK IN PROGRESS] This method requests a specific quality of service. The QoS can be specified for the current channel or for all channels on the connection. The particular properties and semantics of a qos method always depend on the content class semantics. Though the qos method could in principle apply to both peers, it is currently meaningful only for the server. The client can request that messages be sent in advance so that when the client finishes processing a message, the following message is already held locally, rather than needing to be sent down the channel. Prefetching gives a performance improvement. This field specifies the prefetch window size in octets. The server will send a message in advance if it is equal to or smaller in size than the available prefetch size (and also falls into other prefetch limits). May be set to zero, meaning "no specific limit", although other prefetch limits may still apply. The prefetch-size is ignored if the no-ack option is set. The server MUST ignore this setting when the client is not processing any messages - i.e. the prefetch size does not limit the transfer of single messages to a client, only the sending in advance of more messages while the client still has one or more unacknowledged messages. Define a QoS prefetch-size limit and send a single message that exceeds that limit. Verify that the message arrives correctly. Specifies a prefetch window in terms of whole messages. 
This field may be used in combination with the prefetch-size field; a message will only be sent in advance if both prefetch windows (and those at the channel and connection level) allow it. The prefetch-count is ignored if the no-ack option is set. The server may send less data in advance than allowed by the client's specified prefetch windows but it MUST NOT send more. Define a QoS prefetch-size limit and a prefetch-count limit greater than one. Send multiple messages that exceed the prefetch size. Verify that no more than one message arrives at once. By default the QoS settings apply to the current channel only. If this field is set, they are applied to the entire connection. [WORK IN PROGRESS] Signals the normal completion of a method. [WORK IN PROGRESS] Signals that a queue does not contain any messages. [WORK IN PROGRESS] This response rejects a message. A message may be rejected for a number of reasons. [WORK IN PROGRESS] Returns the data offset into a reference body.
vumi/resources/amqp-spec-0-9-1.xml
Indicates that the method completed successfully. This reply code is reserved for future use - the current protocol design does not use positive confirmation and reply codes are sent only in case of an error. The client attempted to transfer content larger than the server could accept at the present time. The client may retry at a later time. When the exchange cannot deliver to a consumer when the immediate flag is set, as a result of pending data on the queue or the absence of any consumers of the queue. An operator intervened to close the connection for some reason. The client may retry at some later date. The client tried to work with an unknown virtual host. The client attempted to work with a server entity to which it has no access due to security settings. The client attempted to work with a server entity that does not exist. The client attempted to work with a server entity to which it has no access because another client is working with it. The client requested a method that was not allowed because some precondition failed. The sender sent a malformed frame that the recipient could not decode. This strongly implies a programming error in the sending peer. The sender sent a frame that contained illegal values for one or more fields. This strongly implies a programming error in the sending peer. The client sent an invalid sequence of frames, attempting to perform an operation that was considered invalid by the server. This usually implies a programming error in the client. The client attempted to work with a channel that had not been correctly opened. This most likely indicates a fault in the client layer. The peer sent a frame that was not expected, usually in the context of a content header and body. This strongly indicates a fault in the peer's content processing. The server could not complete the method because it lacked sufficient resources. This may be due to the client creating too many of some type of entity. The client tried to work with some entity in a manner that is prohibited by the server, due to security settings or by some other criteria. The client tried to use functionality that is not implemented in the server. The server could not complete the method because of an internal error. The server may require intervention by an operator in order to resume normal operations. Identifier for the consumer, valid within the current channel. 
The server-assigned and channel-specific delivery tag. The delivery tag is valid only within the channel from which the message was received. I.e. a client MUST NOT receive a message on one channel and then acknowledge it on another. The server MUST NOT use a zero value for delivery tags. Zero is reserved for client use, meaning "all messages so far received". The exchange name is a client-selected string that identifies the exchange for publish methods. If this field is set, the server does not expect acknowledgements for messages. That is, when a message is delivered to the client the server assumes the delivery will succeed and immediately dequeues it. This functionality may increase performance but at the cost of reliability. Messages can get lost if a client dies before they are delivered to the application. If the no-local field is set, the server will not send messages to the connection that published them. If set, the server will not respond to the method. The client should not wait for a reply method. If the server could not complete the method it will raise a channel or connection exception. Unconstrained. This table provides a set of peer properties, used for identification, debugging, and general information. The queue name identifies the queue within the vhost. In methods where the queue name may be blank, and that has no specific significance, this refers to the 'current' queue for the channel, meaning the last queue that the client declared on the channel. If the client did not declare a queue, and the method needs a queue name, this will result in a 502 (syntax error) channel exception. This indicates that the message has been previously delivered to this or another client. The server SHOULD try to signal redelivered messages when it can. When redelivering a message that was not successfully acknowledged, the server SHOULD deliver it to the original client if possible. Declare a shared queue and publish a message to the queue. Consume the message using explicit acknowledgements, but do not acknowledge the message. Close the connection, reconnect, and consume from the queue again. The message should arrive with the redelivered flag set. The client MUST NOT rely on the redelivered field but should take it as a hint that the message may already have been processed. A fully robust client must be able to track duplicate received messages on non-transacted and locally-transacted channels. The number of messages in the queue, which will be zero for newly-declared queues. This is the number of messages present in the queue, and committed if the channel on which they were published is transacted, that are not waiting acknowledgement. The reply code. The AMQ reply codes are defined as constants at the start of this formal specification. The localised reply text. This text can be logged as an aid to resolving issues. The connection class provides methods for a client to establish a network connection to a server, and for both peers to operate the connection thereafter. connection = open-connection *use-connection close-connection open-connection = C:protocol-header S:START C:START-OK *challenge S:TUNE C:TUNE-OK C:OPEN S:OPEN-OK challenge = S:SECURE C:SECURE-OK use-connection = *channel close-connection = C:CLOSE S:CLOSE-OK / S:CLOSE C:CLOSE-OK This method starts the connection negotiation process by telling the client the protocol version that the server proposes, along with a list of security mechanisms which the client can use for authentication. 
If the server cannot support the protocol specified in the protocol header, it MUST respond with a valid protocol header and then close the socket connection. The client sends a protocol header containing an invalid protocol name. The server MUST respond by sending a valid protocol header and then closing the connection. The server MUST provide a protocol version that is lower than or equal to that requested by the client in the protocol header. The client requests a protocol version that is higher than any valid implementation, e.g. 2.0. The server must respond with a protocol header indicating its supported protocol version, e.g. 1.0. If the client cannot handle the protocol version suggested by the server it MUST close the socket connection without sending any further data. The server sends a protocol version that is lower than any valid implementation, e.g. 0.1. The client must respond by closing the connection without sending any further data. The major version number can take any value from 0 to 99 as defined in the AMQP specification. The minor version number can take any value from 0 to 99 as defined in the AMQP specification. The properties SHOULD contain at least these fields: "host", specifying the server host name or address, "product", giving the name of the server product, "version", giving the name of the server version, "platform", giving the name of the operating system, "copyright", if appropriate, and "information", giving other general information. Client connects to server and inspects the server properties. It checks for the presence of the required fields. A list of the security mechanisms that the server supports, delimited by spaces. A list of the message locales that the server supports, delimited by spaces. The locale defines the language in which the server will send reply texts. The server MUST support at least the en_US locale. Client connects to server and inspects the locales field. It checks for the presence of the required locale(s). This method selects a SASL security mechanism. The properties SHOULD contain at least these fields: "product", giving the name of the client product, "version", giving the name of the client version, "platform", giving the name of the operating system, "copyright", if appropriate, and "information", giving other general information. A single security mechanism selected by the client, which must be one of those specified by the server. The client SHOULD authenticate using the highest-level security profile it can handle from the list provided by the server. If the mechanism field does not contain one of the security mechanisms proposed by the server in the Start method, the server MUST close the connection without sending any further data. Client connects to server and sends an invalid security mechanism. The server must respond by closing the connection (a socket close, with no connection close negotiation). A block of opaque data passed to the security mechanism. The contents of this data are defined by the SASL security mechanism. A single message locale selected by the client, which must be one of those specified by the server. The SASL protocol works by exchanging challenges and responses until both peers have received sufficient information to authenticate each other. This method challenges the client to provide more information. Challenge information, a block of opaque binary data passed to the security mechanism. 
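In practice the mechanism most servers and clients agree on is PLAIN. A sketch of supplying SASL credentials with pika; the username, password, and tuning values are placeholders (the heartbeat and frame-max fields belong to the Tune step described below):

    import pika

    # pika selects a mechanism from the server's advertised list (PLAIN by
    # default) and sends the credentials during connection negotiation.
    credentials = pika.PlainCredentials("guest", "guest")
    parameters = pika.ConnectionParameters(
        host="localhost",
        credentials=credentials,
        heartbeat=30,      # requested heartbeat delay, in seconds
        frame_max=131072,  # largest frame size the client will accept
    )
    connection = pika.BlockingConnection(parameters)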
This method attempts to authenticate, passing a block of SASL data for the security mechanism at the server side. A block of opaque data passed to the security mechanism. The contents of this data are defined by the SASL security mechanism. This method proposes a set of connection configuration values to the client. The client can accept and/or adjust these. Specifies the highest channel number that the server permits. Usable channel numbers are in the range 1..channel-max. Zero indicates no specified limit. The largest frame size that the server proposes for the connection, including frame header and end-byte. The client can negotiate a lower value. Zero means that the server does not impose any specific limit but may reject very large frames if it cannot allocate resources for them. Until the frame-max has been negotiated, both peers MUST accept frames of up to frame-min-size octets in size, and the minimum negotiated value for frame-max is also frame-min-size. Client connects to server and sends a large properties field, creating a frame of frame-min-size octets. The server must accept this frame. The delay, in seconds, of the connection heartbeat that the server wants. Zero means the server does not want a heartbeat. This method sends the client's connection tuning parameters to the server. Certain fields are negotiated, others provide capability information. The maximum total number of channels that the client will use per connection. If the client specifies a channel max that is higher than the value provided by the server, the server MUST close the connection without attempting a negotiated close. The server may report the error in some fashion to assist implementors. The largest frame size that the client and server will use for the connection. Zero means that the client does not impose any specific limit but may reject very large frames if it cannot allocate resources for them. Note that the frame-max limit applies principally to content frames, where large contents can be broken into frames of arbitrary size. Until the frame-max has been negotiated, both peers MUST accept frames of up to frame-min-size octets in size, and the minimum negotiated value for frame-max is also frame-min-size. If the client specifies a frame max that is higher than the value provided by the server, the server MUST close the connection without attempting a negotiated close. The server may report the error in some fashion to assist implementors. The delay, in seconds, of the connection heartbeat that the client wants. Zero means the client does not want a heartbeat. This method opens a connection to a virtual host, which is a collection of resources, and acts to separate multiple application domains within a server. The server may apply arbitrary limits per virtual host, such as the number of each type of entity that may be used, per connection and/or in total. The name of the virtual host to work with. If the server supports multiple virtual hosts, it MUST enforce a full separation of exchanges, queues, and all associated entities per virtual host. An application, connected to a specific virtual host, MUST NOT be able to access resources of another virtual host. The server SHOULD verify that the client has permission to access the specified virtual host. This method signals to the client that the connection is ready for use. This method indicates that the sender wants to close the connection. This may be due to internal conditions (e.g. a forced shut-down) or due to an error handling a specific method, i.e. 
an exception. When a close is due to an exception, the sender provides the class and method id of the method which caused the exception. After sending this method, any received methods except Close and Close-OK MUST be discarded. The response to receiving a Close after sending Close must be to send Close-Ok. When the close is provoked by a method exception, this is the class of the method. When the close is provoked by a method exception, this is the ID of the method. This method confirms a Connection.Close method and tells the recipient that it is safe to release resources for the connection and close the socket. A peer that detects a socket closure without having received a Close-Ok handshake method SHOULD log the error. The channel class provides methods for a client to establish a channel to a server and for both peers to operate the channel thereafter. channel = open-channel *use-channel close-channel open-channel = C:OPEN S:OPEN-OK use-channel = C:FLOW S:FLOW-OK / S:FLOW C:FLOW-OK / functional-class close-channel = C:CLOSE S:CLOSE-OK / S:CLOSE C:CLOSE-OK This method opens a channel to the server. The client MUST NOT use this method on an already-opened channel. Client opens a channel and then reopens the same channel. This method signals to the client that the channel is ready for use. This method asks the peer to pause or restart the flow of content data sent by a consumer. This is a simple flow-control mechanism that a peer can use to avoid overflowing its queues or otherwise finding itself receiving more messages than it can process. Note that this method is not intended for window control. It does not affect contents returned by Basic.Get-Ok methods. When a new channel is opened, it is active (flow is active). Some applications assume that channels are inactive until started. To emulate this behaviour a client MAY open the channel, then pause it. When sending content frames, a peer SHOULD monitor the channel for incoming methods and respond to a Channel.Flow as rapidly as possible. A peer MAY use the Channel.Flow method to throttle incoming content data for internal reasons, for example, when exchanging data over a slower connection. The peer that requests a Channel.Flow method MAY disconnect and/or ban a peer that does not respect the request. This is to prevent badly-behaved clients from overwhelming a server. If 1, the peer starts sending content frames. If 0, the peer stops sending content frames. Confirms to the peer that a flow command was received and processed. Confirms the setting of the processed flow method: 1 means the peer will start sending or continue to send content frames; 0 means it will not. This method indicates that the sender wants to close the channel. This may be due to internal conditions (e.g. a forced shut-down) or due to an error handling a specific method, i.e. an exception. When a close is due to an exception, the sender provides the class and method id of the method which caused the exception. After sending this method, any received methods except Close and Close-OK MUST be discarded. The response to receiving a Close after sending Close must be to send Close-Ok. When the close is provoked by a method exception, this is the class of the method. When the close is provoked by a method exception, this is the ID of the method. This method confirms a Channel.Close method and tells the recipient that it is safe to release resources for the channel. 
A peer that detects a socket closure without having received a Channel.Close-Ok handshake method SHOULD log the error. Exchanges match and distribute messages across queues. Exchanges can be configured in the server or declared at runtime. exchange = C:DECLARE S:DECLARE-OK / C:DELETE S:DELETE-OK The server MUST implement these standard exchange types: fanout, direct. Client attempts to declare an exchange with each of these standard types. The server SHOULD implement these standard exchange types: topic, headers. Client attempts to declare an exchange with each of these standard types. The server MUST, in each virtual host, pre-declare an exchange instance for each standard exchange type that it implements, where the name of the exchange instance, if defined, is "amq." followed by the exchange type name. The server MUST, in each virtual host, pre-declare at least two direct exchange instances: one named "amq.direct", the other with no public name that serves as a default exchange for Publish methods. Client declares a temporary queue and attempts to bind to each required exchange instance ("amq.fanout", "amq.direct", "amq.topic", and "amq.headers" if those types are defined). The server MUST pre-declare a direct exchange with no public name to act as the default exchange for content Publish methods and for default queue bindings. Client checks that the default exchange is active by specifying a queue binding with no exchange name, and publishing a message with a suitable routing key but without specifying the exchange name, then ensuring that the message arrives in the queue correctly. The server MUST NOT allow clients to access the default exchange except by specifying an empty exchange name in the Queue.Bind and content Publish methods. The server MAY implement other exchange types as wanted. This method creates an exchange if it does not already exist, and if the exchange exists, verifies that it is of the correct and expected class. The server SHOULD support a minimum of 16 exchanges per virtual host and ideally, impose no limit except as defined by available resources. The client declares as many exchanges as it can until the server reports an error; the number of exchanges successfully declared must be at least sixteen. Exchange names starting with "amq." are reserved for pre-declared and standardised exchanges. The client MAY declare an exchange starting with "amq." if the passive option is set, or the exchange already exists. The client attempts to declare a non-existing exchange starting with "amq." and with the passive option set to zero. The exchange name consists of a non-empty sequence of these characters: letters, digits, hyphen, underscore, period, or colon. The client attempts to declare an exchange with an illegal name. Each exchange belongs to one of a set of exchange types implemented by the server. The exchange types define the functionality of the exchange - i.e. how messages are routed through it. It is not valid or meaningful to attempt to change the type of an existing exchange. Exchanges cannot be redeclared with different types. The client MUST NOT attempt to redeclare an existing exchange with a different type than used in the original Exchange.Declare method. TODO. The client MUST NOT attempt to declare an exchange with a type that the server does not support. TODO. If set, the server will reply with Declare-Ok if the exchange already exists with the same name, and raise an error if not. 
The client can use this to check whether an exchange exists without modifying the server state. When set, all other method fields except name and no-wait are ignored. A declare with both passive and no-wait has no effect. Arguments are compared for semantic equivalence. If set, and the exchange does not already exist, the server MUST raise a channel exception with reply code 404 (not found). TODO. If not set and the exchange exists, the server MUST check that the existing exchange has the same values for type, durable, and arguments fields. The server MUST respond with Declare-Ok if the requested exchange matches these fields, and MUST raise a channel exception if not. TODO. If set when creating a new exchange, the exchange will be marked as durable. Durable exchanges remain active when a server restarts. Non-durable exchanges (transient exchanges) are purged if/when a server restarts. The server MUST support both durable and transient exchanges. TODO. A set of arguments for the declaration. The syntax and semantics of these arguments depends on the server implementation. This method confirms a Declare method and confirms the name of the exchange, essential for automatically-named exchanges. This method deletes an exchange. When an exchange is deleted all queue bindings on the exchange are cancelled. The client MUST NOT attempt to delete an exchange that does not exist. If set, the server will only delete the exchange if it has no queue bindings. If the exchange has queue bindings the server does not delete it but raises a channel exception instead. The server MUST NOT delete an exchange that has bindings on it, if the if-unused field is true. The client declares an exchange, binds a queue to it, then tries to delete it setting if-unused to true. This method confirms the deletion of an exchange. Queues store and forward messages. Queues can be configured in the server or created at runtime. Queues must be attached to at least one exchange in order to receive messages from publishers. queue = C:DECLARE S:DECLARE-OK / C:BIND S:BIND-OK / C:UNBIND S:UNBIND-OK / C:PURGE S:PURGE-OK / C:DELETE S:DELETE-OK This method creates or checks a queue. When creating a new queue the client can specify various properties that control the durability of the queue and its contents, and the level of sharing for the queue. The server MUST create a default binding for a newly-declared queue to the default exchange, which is an exchange of type 'direct' and use the queue name as the routing key. Client declares a new queue, and then without explicitly binding it to an exchange, attempts to send a message through the default exchange binding, i.e. publish a message to the empty exchange, with the queue name as routing key. The server SHOULD support a minimum of 256 queues per virtual host and ideally, impose no limit except as defined by available resources. Client attempts to declare as many queues as it can until the server reports an error. The resulting count must at least be 256. The queue name MAY be empty, in which case the server MUST create a new queue with a unique generated name and return this to the client in the Declare-Ok method. Client attempts to declare several queues with an empty name. The client then verifies that the server-assigned names are unique and different. Queue names starting with "amq." are reserved for pre-declared and standardised queues. The client MAY declare a queue starting with "amq." if the passive option is set, or the queue already exists. 
The client attempts to declare a non-existing queue starting with "amq." and with the passive option set to zero. The queue name can be empty, or a sequence of these characters: letters, digits, hyphen, underscore, period, or colon. The client attempts to declare a queue with an illegal name. If set, the server will reply with Declare-Ok if the queue already exists with the same name, and raise an error if not. The client can use this to check whether a queue exists without modifying the server state. When set, all other method fields except name and no-wait are ignored. A declare with both passive and no-wait has no effect. Arguments are compared for semantic equivalence. The client MAY ask the server to assert that a queue exists without creating the queue if not. If the queue does not exist, the server treats this as a failure. Client declares an existing queue with the passive option and expects the server to respond with a declare-ok. Client then attempts to declare a non-existent queue with the passive option, and the server must close the channel with the correct reply-code. If not set and the queue exists, the server MUST check that the existing queue has the same values for durable, exclusive, auto-delete, and arguments fields. The server MUST respond with Declare-Ok if the requested queue matches these fields, and MUST raise a channel exception if not. TODO. If set when creating a new queue, the queue will be marked as durable. Durable queues remain active when a server restarts. Non-durable queues (transient queues) are purged if/when a server restarts. Note that durable queues do not necessarily hold persistent messages, although it does not make sense to send persistent messages to a transient queue. The server MUST recreate the durable queue after a restart. Client declares a durable queue. The server is then restarted. The client then attempts to send a message to the queue. The message should be successfully delivered. The server MUST support both durable and transient queues. A client declares two named queues, one durable and one transient. Exclusive queues may only be accessed by the current connection, and are deleted when that connection closes. Passive declaration of an exclusive queue by other connections is not allowed. The server MUST support both exclusive (private) and non-exclusive (shared) queues. A client declares two named queues, one exclusive and one non-exclusive. The client MAY NOT attempt to use a queue that was declared as exclusive by another still-open connection. One client declares an exclusive queue. A second client on a different connection attempts to declare, bind, consume, purge, or delete a queue of the same name. If set, the queue is deleted when all consumers have finished using it. The last consumer can be cancelled either explicitly or because its channel is closed. If there was no consumer ever on the queue, it will not be deleted. Applications can explicitly delete auto-delete queues using the Delete method as normal. The server MUST ignore the auto-delete field if the queue already exists. Client declares two named queues, one as auto-delete and one explicit-delete. Client then attempts to declare the two queues using the same names again, but reversing the value of the auto-delete field in each case. Verify that the queues still exist with the original auto-delete flag values. A set of arguments for the declaration. The syntax and semantics of these arguments depends on the server implementation. 
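A sketch tying the declaration options above to the default-exchange binding behaviour, using pika; since durable queues only preserve messages that are themselves persistent, the publish marks the message with delivery_mode=2 (names are illustrative):

    import pika

    connection = pika.BlockingConnection(pika.ConnectionParameters("localhost"))
    channel = connection.channel()

    # A durable, shared (non-exclusive) queue that survives broker restarts.
    declare_ok = channel.queue_declare(
        queue="orders", durable=True, exclusive=False, auto_delete=False)
    print(declare_ok.method.queue,          # the server confirms the name
          declare_ok.method.message_count)  # zero for a newly-declared queue

    # passive=True only checks existence; a missing queue closes the channel
    # with reply code 404 (not found).
    channel.queue_declare(queue="orders", passive=True)

    # The default binding routes through the default exchange ("") with the
    # queue name as routing key; delivery_mode=2 marks the message persistent.
    channel.basic_publish(
        exchange="", routing_key="orders", body=b"order 1",
        properties=pika.BasicProperties(delivery_mode=2))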
This method confirms a Declare method and confirms the name of the queue, essential for automatically-named queues. Reports the name of the queue. If the server generated a queue name, this field contains that name. Reports the number of active consumers for the queue. Note that consumers can suspend activity (Channel.Flow) in which case they do not appear in this count. This method binds a queue to an exchange. Until a queue is bound it will not receive any messages. In a classic messaging model, store-and-forward queues are bound to a direct exchange and subscription queues are bound to a topic exchange. A server MUST ignore duplicate bindings - that is, two or more bind methods for a specific queue, with identical arguments - without treating these as an error. A client binds a named queue to an exchange. The client then repeats the bind (with identical arguments). A server MUST NOT deliver the same message more than once to a queue, even if the queue has multiple bindings that match the message. A client declares a named queue and binds it using multiple bindings to the amq.topic exchange. The client then publishes a message that matches all its bindings. The server MUST allow a durable queue to bind to a transient exchange. A client declares a transient exchange. The client then declares a named durable queue and then attempts to bind the durable queue to the transient exchange. Bindings of durable queues to durable exchanges are automatically durable and the server MUST restore such bindings after a server restart. A client declares a named durable queue and binds it to a durable exchange. The server is restarted. The client then attempts to use the queue/exchange combination. The server SHOULD support at least 4 bindings per queue, and ideally, impose no limit except as defined by available resources. A client declares a named queue and attempts to bind it to 4 different exchanges. Specifies the name of the queue to bind. The client MUST either specify a queue name or have previously declared a queue on the same channel. The client opens a channel and attempts to bind an unnamed queue. The client MUST NOT attempt to bind a queue that does not exist. The client attempts to bind a non-existent queue. A client MUST NOT be allowed to bind a queue to a non-existent exchange. A client attempts to bind a named queue to an undeclared exchange. The server MUST accept a blank exchange name to mean the default exchange. The client declares a queue and binds it to a blank exchange name. Specifies the routing key for the binding. The routing key is used for routing messages depending on the exchange configuration. Not all exchanges use a routing key - refer to the specific exchange documentation. If the queue name is empty, the server uses the last queue declared on the channel. If the routing key is also empty, the server uses this queue name for the routing key as well. If the queue name is provided but the routing key is empty, the server does the binding with that empty routing key. The meaning of empty routing keys depends on the exchange implementation. If a message queue binds to a direct exchange using routing key K and a publisher sends the exchange a message with routing key R, then the message MUST be passed to the message queue if K = R. A set of arguments for the binding. The syntax and semantics of these arguments depends on the exchange class. This method confirms that the bind was successful. This method unbinds a queue from an exchange. 
If an unbind fails, the server MUST raise a connection exception. Specifies the name of the queue to unbind. The client MUST either specify a queue name or have previously declared a queue on the same channel. The client opens a channel and attempts to unbind an unnamed queue. The client MUST NOT attempt to unbind a queue that does not exist. The client attempts to unbind a non-existent queue. The name of the exchange to unbind from. The client MUST NOT attempt to unbind a queue from an exchange that does not exist. The client attempts to unbind a queue from a non-existent exchange. The server MUST accept a blank exchange name to mean the default exchange. The client declares a queue and binds it to a blank exchange name. Specifies the routing key of the binding to unbind. Specifies the arguments of the binding to unbind. This method confirms that the unbind was successful. This method removes all messages from a queue which are not awaiting acknowledgment. The server MUST NOT purge messages that have already been sent to a client but not yet acknowledged. The server MAY implement a purge queue or log that allows system administrators to recover accidentally-purged messages. The server SHOULD NOT keep purged messages in the same storage spaces as the live messages since the volumes of purged messages may get very large. Specifies the name of the queue to purge. The client MUST either specify a queue name or have previously declared a queue on the same channel. The client opens a channel and attempts to purge an unnamed queue. The client MUST NOT attempt to purge a queue that does not exist. The client attempts to purge a non-existent queue. This method confirms the purge of a queue. Reports the number of messages purged. This method deletes a queue. When a queue is deleted any pending messages are sent to a dead-letter queue if this is defined in the server configuration, and all consumers on the queue are cancelled. The server SHOULD use a dead-letter queue to hold messages that were pending on a deleted queue, and MAY provide facilities for a system administrator to move these messages back to an active queue. Specifies the name of the queue to delete. The client MUST either specify a queue name or have previously declared a queue on the same channel. The client opens a channel and attempts to delete an unnamed queue. The client MUST NOT attempt to delete a queue that does not exist. The client attempts to delete a non-existent queue. If set, the server will only delete the queue if it has no consumers. If the queue has consumers, the server does not delete it but raises a channel exception instead. The server MUST NOT delete a queue that has consumers on it, if the if-unused field is true. The client declares a queue, and consumes from it, then tries to delete it setting if-unused to true. If set, the server will only delete the queue if it has no messages. The server MUST NOT delete a queue that has messages on it, if the if-empty field is true. The client declares a queue, binds it and publishes some messages into it, then tries to delete it setting if-empty to true. This method confirms the deletion of a queue. Reports the number of messages deleted. The Basic class provides methods that support an industry-standard messaging model.
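Before moving on to the Basic class, the unbind, purge, and conditional delete described above round out the pika sketch:

    ch.queue_unbind(queue="orders", exchange="amq.topic", routing_key="order.*")

    # Remove all messages not awaiting acknowledgement.
    purged = ch.queue_purge(queue="orders")
    print(purged.method.message_count)

    # Conditional delete: the broker raises a channel exception if the
    # queue still has consumers or messages.
    ch.queue_delete(queue="orders", if_unused=True, if_empty=True)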
basic = C:QOS S:QOS-OK / C:CONSUME S:CONSUME-OK / C:CANCEL S:CANCEL-OK / C:PUBLISH content / S:RETURN content / S:DELIVER content / C:GET ( S:GET-OK content / S:GET-EMPTY ) / C:ACK / C:REJECT / C:RECOVER-ASYNC / C:RECOVER S:RECOVER-OK The server SHOULD respect the persistent property of basic messages and SHOULD make a best effort to hold persistent basic messages on a reliable storage mechanism. Send a persistent message to a queue, stop the server, restart the server and then verify whether the message is still present. Assumes that queues are durable. Persistence without durable queues makes no sense. The server MUST NOT discard a persistent basic message in case of a queue overflow. Create a queue overflow situation with persistent messages and verify that messages do not get lost (presumably the server will write them to disk). The server MAY use the Channel.Flow method to slow or stop a basic message publisher when necessary. Create a queue overflow situation with non-persistent messages and verify whether the server responds with Channel.Flow or not. Repeat with persistent messages. The server MAY overflow non-persistent basic messages to persistent storage. The server MAY discard or dead-letter non-persistent basic messages on a priority basis if the queue size exceeds some configured limit. The server MUST implement at least 2 priority levels for basic messages, where priorities 0-4 and 5-9 are treated as two distinct levels. Send a number of priority 0 messages to a queue. Send one priority 9 message. Consume messages from the queue and verify that the first message received was priority 9. The server MAY implement up to 10 priority levels. Send a number of messages with mixed priorities to a queue, so that all priority values from 0 to 9 are exercised. A good scenario would be ten messages in low-to-high priority order. Consume from the queue and verify how many priority levels emerge. The server MUST deliver messages of the same priority in order irrespective of their individual persistence. Send a set of messages with the same priority but different persistence settings to a queue. Consume and verify that messages arrive in the same order as originally published. The server MUST support un-acknowledged delivery of Basic content, i.e. consumers with the no-ack field set to TRUE. The server MUST support explicitly acknowledged delivery of Basic content, i.e. consumers with the no-ack field set to FALSE. Declare a queue and a consumer using explicit acknowledgements. Publish a set of messages to the queue. Consume the messages but acknowledge only half of them. Disconnect and reconnect, and consume from the queue. Verify that the remaining messages are received. This method requests a specific quality of service. The QoS can be specified for the current channel or for all channels on the connection. The particular properties and semantics of a qos method always depend on the content class semantics. Though the qos method could in principle apply to both peers, it is currently meaningful only for the server. The client can request that messages be sent in advance so that when the client finishes processing a message, the following message is already held locally, rather than needing to be sent down the channel. Prefetching gives a performance improvement. This field specifies the prefetch window size in octets. The server will send a message in advance if it is equal to or smaller in size than the available prefetch size (and also falls into other prefetch limits).
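Setting QoS aside for a moment (the prefetch discussion continues below), the persistence and priority rules above amount to a publish like this in the pika sketch:

    # Persistent (delivery_mode=2), high-priority message to a durable
    # queue; per the rules above it should survive a broker restart.
    ch.basic_publish(
        exchange="", routing_key="orders", body=b"order-1234",
        properties=pika.BasicProperties(delivery_mode=2, priority=9))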
May be set to zero, meaning "no specific limit", although other prefetch limits may still apply. The prefetch-size is ignored if the no-ack option is set. The server MUST ignore this setting when the client is not processing any messages - i.e. the prefetch size does not limit the transfer of single messages to a client, only the sending in advance of more messages while the client still has one or more unacknowledged messages. Define a QoS prefetch-size limit and send a single message that exceeds that limit. Verify that the message arrives correctly. Specifies a prefetch window in terms of whole messages. This field may be used in combination with the prefetch-size field; a message will only be sent in advance if both prefetch windows (and those at the channel and connection level) allow it. The prefetch-count is ignored if the no-ack option is set. The server may send less data in advance than allowed by the client's specified prefetch windows but it MUST NOT send more. Define a QoS prefetch-size limit and a prefetch-count limit greater than one. Send multiple messages that exceed the prefetch size. Verify that no more than one message arrives at once. By default the QoS settings apply to the current channel only. If this field is set, they are applied to the entire connection. This method tells the client that the requested QoS levels could be handled by the server. The requested QoS applies to all active consumers until a new QoS is defined. This method asks the server to start a "consumer", which is a transient request for messages from a specific queue. Consumers last as long as the channel they were declared on, or until the client cancels them. The server SHOULD support at least 16 consumers per queue, and ideally, impose no limit except as defined by available resources. Declare a queue and create consumers on that queue until the server closes the connection. Verify that the number of consumers created was at least sixteen and report the total number. Specifies the name of the queue to consume from. Specifies the identifier for the consumer. The consumer tag is local to a channel, so two clients can use the same consumer tags. If this field is empty the server will generate a unique tag. The client MUST NOT specify a tag that refers to an existing consumer. Attempt to create two consumers with the same non-empty tag, on the same channel. The consumer tag is valid only within the channel from which the consumer was created. I.e. a client MUST NOT create a consumer in one channel and then use it in another. Attempt to create a consumer in one channel, then use it in another channel, in which consumers have also been created (to test that the server uses unique consumer tags). Request exclusive consumer access, meaning only this consumer can access the queue. The client MAY NOT gain exclusive access to a queue that already has active consumers. Open two connections to a server, and in one connection declare a shared (non-exclusive) queue and then consume from the queue. In the second connection attempt to consume from the same queue using the exclusive option. A set of arguments for the consume. The syntax and semantics of these arguments depends on the server implementation. The server provides the client with a consumer tag, which is used by the client for methods called on the consumer at a later stage. Holds the consumer tag specified by the client or provided by the server. This method cancels a consumer.
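In the pika sketch, the QoS and consume flow described above (with the cancellation that the next section details) looks like this:

    # Allow at most 10 unacknowledged messages in flight on this channel.
    ch.basic_qos(prefetch_count=10)

    def on_message(channel, method, properties, body):
        # Explicitly acknowledge each delivery (no-ack is off).
        channel.basic_ack(delivery_tag=method.delivery_tag)

    tag = ch.basic_consume(queue="orders", on_message_callback=on_message,
                           auto_ack=False)
    # ... later, stop the consumer (Basic.Cancel / Cancel-Ok):
    ch.basic_cancel(tag)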
This does not affect already delivered messages, but it does mean the server will not send any more messages for that consumer. The client may receive an arbitrary number of messages in between sending the cancel method and receiving the cancel-ok reply. If the queue does not exist the server MUST ignore the cancel method, so long as the consumer tag is valid for that channel. TODO. This method confirms that the cancellation was completed. This method publishes a message to a specific exchange. The message will be routed to queues as defined by the exchange configuration and distributed to any active consumers when the transaction, if any, is committed. Specifies the name of the exchange to publish to. The exchange name can be empty, meaning the default exchange. If the exchange name is specified, and that exchange does not exist, the server will raise a channel exception. The client MUST NOT attempt to publish a content to an exchange that does not exist. The client attempts to publish a content to a non-existent exchange. The server MUST accept a blank exchange name to mean the default exchange. The client declares a queue and binds it to a blank exchange name. If the exchange was declared as an internal exchange, the server MUST raise a channel exception with a reply code 403 (access refused). TODO. The exchange MAY refuse basic content in which case it MUST raise a channel exception with reply code 540 (not implemented). TODO. Specifies the routing key for the message. The routing key is used for routing messages depending on the exchange configuration. This flag tells the server how to react if the message cannot be routed to a queue. If this flag is set, the server will return an unroutable message with a Return method. If this flag is zero, the server silently drops the message. The server SHOULD implement the mandatory flag. TODO. This flag tells the server how to react if the message cannot be routed to a queue consumer immediately. If this flag is set, the server will return an undeliverable message with a Return method. If this flag is zero, the server will queue the message, but with no guarantee that it will ever be consumed. The server SHOULD implement the immediate flag. TODO. This method returns an undeliverable message that was published with the "immediate" flag set, or an unroutable message published with the "mandatory" flag set. The reply code and text provide information about the reason that the message was undeliverable. Specifies the name of the exchange that the message was originally published to. May be empty, meaning the default exchange. Specifies the routing key name specified when the message was published. This method delivers a message to the client, via a consumer. In the asynchronous message delivery model, the client starts a consumer using the Consume method, then the server responds with Deliver methods as and when messages arrive for that consumer. The server SHOULD track the number of times a message has been delivered to clients and when a message is redelivered a certain number of times - e.g. 5 times - without being acknowledged, the server SHOULD consider the message to be unprocessable (possibly causing client applications to abort), and move the message to a dead letter queue. TODO. Specifies the name of the exchange that the message was originally published to. May be empty, indicating the default exchange. Specifies the routing key name specified when the message was published. 
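Pulling together the publish and return semantics above: pika's blocking channel surfaces a Basic.Return as an UnroutableError once publisher confirms are enabled, so a mandatory publish to an unroutable key can be sketched as:

    ch.confirm_delivery()
    try:
        ch.basic_publish(exchange="", routing_key="no-such-queue",
                         body=b"hello", mandatory=True)
    except pika.exceptions.UnroutableError:
        # The broker returned the message with Basic.Return.
        print("message was unroutable")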
This method provides direct access to the messages in a queue using a synchronous dialogue that is designed for specific types of application where synchronous functionality is more important than performance. Specifies the name of the queue to get a message from. This method delivers a message to the client following a get method. A message delivered by 'get-ok' must be acknowledged unless the no-ack option was set in the get method. Specifies the name of the exchange that the message was originally published to. If empty, the message was published to the default exchange. Specifies the routing key name specified when the message was published. This method tells the client that the queue has no messages available for the client. This method acknowledges one or more messages delivered via the Deliver or Get-Ok methods. The client can ask to confirm a single message or a set of messages up to and including a specific message. If set to 1, the delivery tag is treated as "up to and including", so that the client can acknowledge multiple messages with a single method. If set to zero, the delivery tag refers to a single message. If the multiple field is 1, and the delivery tag is zero, this tells the server to acknowledge all outstanding messages. The server MUST validate that a non-zero delivery-tag refers to a delivered message, and raise a channel exception if this is not the case. On a transacted channel, this check MUST be done immediately and not delayed until a Tx.Commit. Specifically, a client MUST NOT acknowledge the same message more than once. TODO. This method allows a client to reject a message. It can be used to interrupt and cancel large incoming messages, or return untreatable messages to their original queue. The server SHOULD be capable of accepting and processing the Reject method while sending message content with a Deliver or Get-Ok method. I.e. the server should read and process incoming methods while sending output frames. To cancel a partially-sent content, the server sends a content body frame of size 1 (i.e. with no data except the frame-end octet). The server SHOULD interpret this method as meaning that the client is unable to process the message at this time. TODO. The client MUST NOT use this method as a means of selecting messages to process. TODO. If requeue is true, the server will attempt to requeue the message. If requeue is false or the requeue attempt fails, the messages are discarded or dead-lettered. The server MUST NOT deliver the message to the same client within the context of the current channel. The recommended strategy is to attempt to deliver the message to an alternative consumer, and if that is not possible, to move the message to a dead-letter queue. The server MAY use more sophisticated tracking to hold the message on the queue and redeliver it to the same client at a later stage. TODO. This method asks the server to redeliver all unacknowledged messages on a specified channel. Zero or more messages may be redelivered. This method is deprecated in favour of the synchronous Recover/Recover-Ok. The server MUST set the redelivered flag on all messages that are resent. TODO. If this field is zero, the message will be redelivered to the original recipient. If this bit is 1, the server will attempt to requeue the message, potentially then delivering it to an alternative subscriber. This method asks the server to redeliver all unacknowledged messages on a specified channel. Zero or more messages may be redelivered.
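Before the synchronous Recover variant below, the get/ack/reject cycle just described can be sketched with pika as:

    method, properties, body = ch.basic_get(queue="orders", auto_ack=False)
    if method is None:
        print("Get-Empty: no message available")
    else:
        # Acknowledge everything up to and including this delivery.
        ch.basic_ack(delivery_tag=method.delivery_tag, multiple=True)
        # Rejecting instead would requeue (or dead-letter) the message:
        # ch.basic_reject(delivery_tag=method.delivery_tag, requeue=True)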
This method replaces the asynchronous Recover. The server MUST set the redelivered flag on all messages that are resent. TODO. If this field is zero, the message will be redelivered to the original recipient. If this bit is 1, the server will attempt to requeue the message, potentially then delivering it to an alternative subscriber. This method acknowledges a Basic.Recover method. The Tx class allows publish and ack operations to be batched into atomic units of work. The intention is that all publish and ack requests issued within a transaction will complete successfully or none of them will. Servers SHOULD implement atomic transactions at least where all publish or ack requests affect a single queue. Transactions that cover multiple queues may be non-atomic, given that queues can be created and destroyed asynchronously, and such events do not form part of any transaction. Further, the behaviour of transactions with respect to the immediate and mandatory flags on Basic.Publish methods is not defined. Applications MUST NOT rely on the atomicity of transactions that affect more than one queue. Applications MUST NOT rely on the behaviour of transactions that include messages published with the immediate option. Applications MUST NOT rely on the behaviour of transactions that include messages published with the mandatory option. tx = C:SELECT S:SELECT-OK / C:COMMIT S:COMMIT-OK / C:ROLLBACK S:ROLLBACK-OK This method sets the channel to use standard transactions. The client must use this method at least once on a channel before using the Commit or Rollback methods. This method confirms to the client that the channel was successfully set to use standard transactions. This method commits all message publications and acknowledgments performed in the current transaction. A new transaction starts immediately after a commit. The client MUST NOT use the Commit method on non-transacted channels. The client opens a channel and then uses Tx.Commit. This method confirms to the client that the commit succeeded. Note that if a commit fails, the server raises a channel exception. This method abandons all message publications and acknowledgments performed in the current transaction. A new transaction starts immediately after a rollback. Note that unacked messages will not be automatically redelivered by rollback; if that is required, an explicit recover call should be issued. The client MUST NOT use the Rollback method on non-transacted channels. The client opens a channel and then uses Tx.Rollback. This method confirms to the client that the rollback succeeded. Note that if a rollback fails, the server raises a channel exception.
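Closing out the pika sketch, the transactional flow is:

    ch.tx_select()  # put the channel in transactional mode
    try:
        ch.basic_publish(exchange="", routing_key="orders", body=b"order-1")
        ch.basic_publish(exchange="", routing_key="orders", body=b"order-2")
        ch.tx_commit()  # both publishes take effect atomically
    except Exception:
        # Abandon the pending work. Note that rollback does not redeliver
        # unacked messages; an explicit Basic.Recover is needed for that.
        ch.tx_rollback()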
PK=JGm6!!vumi/middleware/tagger.py# -*- test-case-name: vumi.middleware.tests.test_tagger -*- from confmodel import Config from confmodel.fields import ConfigDict import re from vumi.middleware.base import TransportMiddleware, BaseMiddlewareConfig class TaggingMiddlewareConfig(BaseMiddlewareConfig): class ConfigIncoming(ConfigDict): def clean(self, value): if 'addr_pattern' not in value: self.raise_config_error( "does not contain the `addr_pattern` key.") if not isinstance(value['addr_pattern'], basestring): self.raise_config_error( "does not have an `addr_pattern` key with type `string`.") if 'tagpool_template' not in value: self.raise_config_error( "does not contain the `tagpool_template` key.") if not isinstance(value['tagpool_template'], basestring): self.raise_config_error( "does not have an `tagpool_template` key with type " "`string`.") if 'tagname_template' not in value: self.raise_config_error( "does not contain the `tagname_template` key.") if not isinstance(value['tagname_template'], basestring): self.raise_config_error( "does not have an `tagname_template` key with type " "`string`.") return super(self.__class__, self).clean(value) class ConfigOutgoing(ConfigDict): def clean(self, value): if 'tagname_pattern' not in value: self.raise_config_error( "does not contain the `tagname_pattern` key.") if not isinstance(value['tagname_pattern'], basestring): self.raise_config_error( "does not have an `tagname_pattern` key with type " "`string`.") if 'msg_template' not in value: self.raise_config_error( "does not contain the `msg_template` key.") if not isinstance(value['msg_template'], dict): self.raise_config_error( "does not have an `msg_template` key with type `dict`.") return super(self.__class__, self).clean(value) incoming = ConfigIncoming( "Dict containing " """* **addr_pattern** (*string*): Regular expression matching the to_addr of incoming messages. Incoming messages with to_addr values that don't match the pattern are not modified. * **tagpool_template** (*string*): Template for producing tag pool from successful matches of `addr_pattern`. The string is expanded using `match.expand(tagpool_template)`. * **tagname_template** (*string*): Template for producing tag name from successful matches of `addr_pattern`. The string is expanded using `match.expand(tagname_template)`.""", required=True, static=True) outgoing = ConfigOutgoing( "Dict containing " """* **tagname_pattern** (*string*): Regular expression matching the tag name of outgoing messages. Outgoing messages with tag names that don't match the pattern are not modified. Note: The tag pool the tag belongs to is not examined. * **msg_template** (*dict*): A dictionary of additional key-value pairs to add to the outgoing message payloads whose tag matches `tag_pattern`. Values which are strings are expanded using `match.expand(value)`. Values which are dicts are recursed into. Values which are neither are left as is.""", required=True, static=True) class TaggingMiddleware(TransportMiddleware): """ Transport middleware for adding tag names to inbound messages and for adding additional parameters to outbound messages based on their tag. Transports that wish to eventually have incoming messages associated with an existing message batch by :class:`vumi.application.MessageStore` or via :class:`vumi.middleware.StoringMiddleware` need to ensure that incoming messages are provided with a tag by this or some other middleware. 
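For example, a config like the following (mirroring this middleware's tests; the pool and tag names are illustrative) tags inbound messages by the last three digits of their ``to_addr``::

    config = {
        'incoming': {
            'addr_pattern': r'^\d+(\d{3})$',
            'tagpool_template': r'pool1',
            'tagname_template': r'mytag-\1',
        },
        'outgoing': {
            'tagname_pattern': r'mytag-(\d{3})$',
            'msg_template': {'from_addr': r'1234*\1'},
        },
    }

The individual options are described below.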
Configuration options: :param dict incoming: * **addr_pattern** (*string*): Regular expression matching the to_addr of incoming messages. Incoming messages with to_addr values that don't match the pattern are not modified. * **tagpool_template** (*string*): Template for producing tag pool from successful matches of `addr_pattern`. The string is expanded using `match.expand(tagpool_template)`. * **tagname_template** (*string*): Template for producing tag name from successful matches of `addr_pattern`. The string is expanded using `match.expand(tagname_template)`. :param dict outgoing: * **tagname_pattern** (*string*): Regular expression matching the tag name of outgoing messages. Outgoing messages with tag names that don't match the pattern are not modified. Note: The tag pool the tag belongs to is not examined. * **msg_template** (*dict*): A dictionary of additional key-value pairs to add to the outgoing message payloads whose tag matches `tag_pattern`. Values which are strings are expanded using `match.expand(value)`. Values which are dicts are recursed into. Values which are neither are left as is. """ CONFIG_CLASS = TaggingMiddlewareConfig def setup_middleware(self): config_incoming = self.config.incoming config_outgoing = self.config.outgoing self.to_addr_re = re.compile(config_incoming['addr_pattern']) self.tagpool_template = config_incoming['tagpool_template'] self.tagname_template = config_incoming['tagname_template'] self.tag_re = re.compile(config_outgoing['tagname_pattern']) self.msg_template = config_outgoing['msg_template'] def handle_inbound(self, message, connector_name): to_addr = message.get('to_addr') if to_addr is not None: match = self.to_addr_re.match(to_addr) else: match = None if match is not None: tag = (match.expand(self.tagpool_template), match.expand(self.tagname_template)) self.add_tag_to_msg(message, tag) return message def handle_outbound(self, message, connector_name): tag = self.map_msg_to_tag(message) if tag is not None: match = self.tag_re.match(tag[1]) else: match = None if match is not None: self._deepupdate(match, message.payload, self.msg_template) return message @staticmethod def _deepupdate(match, origdict, newdict): # set of ids of processed dicts (to avoid recursion) seen = set([id(newdict)]) stack = [(origdict, newdict)] while stack: current_dict, current_new_dict = stack.pop() for key, value in current_new_dict.iteritems(): if isinstance(value, dict): if id(value) in seen: continue next_dict = current_dict.setdefault(key, {}) seen.add(id(value)) stack.append((next_dict, value)) elif isinstance(value, basestring): current_dict[key] = match.expand(value) else: current_dict[key] = value @staticmethod def add_tag_to_msg(msg, tag): """Convenience method for adding a tag to a message.""" tag_metadata = msg['helper_metadata'].setdefault('tag', {}) # convert tag to list so that msg == json.loads(json.dumps(msg)) tag_metadata['tag'] = list(tag) @staticmethod def add_tag_to_payload(payload, tag): """Convenience method for adding a tag to a message payload.""" helper_metadata = payload.setdefault('helper_metadata', {}) tag_metadata = helper_metadata.setdefault('tag', {}) tag_metadata['tag'] = list(tag) @staticmethod def map_msg_to_tag(msg): """Convenience method for retrieving a tag that was added to a message by this middleware. 
""" tag = msg.get('helper_metadata', {}).get('tag', {}).get('tag') if tag is not None: # convert JSON list to a proper tag tuple return tuple(tag) return None PK=JGkF F vumi/middleware/manhole.py# -*- test-case-name: vumi.middleware.tests.test_manhole -*- from confmodel.fields import ConfigList from vumi.middleware import BaseMiddleware from vumi.middleware.base import BaseMiddlewareConfig from vumi.config import ConfigServerEndpoint from twisted.internet import reactor from twisted.internet.defer import inlineCallbacks from twisted.cred import portal from twisted.conch import manhole_ssh, manhole_tap from twisted.conch.checkers import SSHPublicKeyDatabase from twisted.python.filepath import FilePath from twisted.internet.endpoints import serverFromString class ManholeMiddlewareConfig(BaseMiddlewareConfig): twisted_endpoint = ConfigServerEndpoint( "Twisted endpoint to listen on", default="tcp:0", static=True) autorized_keys = ConfigList( "List of absolute paths to `authorized_keys` files containing SSH " "public keys that are allowed access.", default=None, static=True) class SSHPubKeyDatabase(SSHPublicKeyDatabase): """ Checker for authorizing people against a list of `authorized_keys` files. If nothing is specified then it defaults to `authorized_keys` and `authorized_keys2` for the logged in user. """ def __init__(self, authorized_keys): self.authorized_keys = authorized_keys def getAuthorizedKeysFiles(self, credentials): if self.authorized_keys is not None: return [FilePath(ak) for ak in self.authorized_keys] return SSHPublicKeyDatabase.getAuthorizedKeysFiles(self, credentials) class ManholeMiddleware(BaseMiddleware): """ Middleware providing SSH access into the worker this middleware is attached to. Requires the following packages to be installed: * pyasn1 * pycrypto :param str twisted_endpoint: The Twisted endpoint to listen on. Defaults to `tcp:0` which has the reactor select any available port. :param list authorized_keys: List of absolute paths to `authorized_keys` files containing SSH public keys that are allowed access. """ CONFIG_CLASS = ManholeMiddlewareConfig def validate_config(self): self.twisted_endpoint = self.config.twisted_endpoint self.authorized_keys = self.config.authorized_keys @inlineCallbacks def setup_middleware(self): self.validate_config() checker = SSHPubKeyDatabase(self.authorized_keys) ssh_realm = manhole_ssh.TerminalRealm() ssh_realm.chainedProtocolFactory = manhole_tap.chainedProtocolFactory({ 'worker': self.worker, }) ssh_portal = portal.Portal(ssh_realm, [checker]) factory = manhole_ssh.ConchFactory(ssh_portal) endpoint = serverFromString(reactor, self.twisted_endpoint) self.socket = yield endpoint.listen(factory) def teardown_middleware(self): return self.socket.stopListening() PK=JGjDŽ''vumi/middleware/base.py# -*- test-case-name: vumi.middleware.tests.test_base -*- from confmodel import Config from twisted.internet.defer import inlineCallbacks, returnValue from vumi.utils import load_class_by_string from vumi.errors import ConfigError, VumiError class MiddlewareError(VumiError): pass class BaseMiddlewareConfig(Config): """ Config class for the base middleware. """ class BaseMiddleware(object): """Common middleware base class. This is a convenient definition of and set of common functionality for middleware classes. You need not subclass this and should not instantiate this directly. 
The :meth:`__init__` method should take exactly the following options so that your class can be instantiated from configuration in a standard way: :param string name: Name of the middleware. :param dict config: Dictionary of configuration items. :type worker: vumi.service.Worker :param worker: Reference to the transport or application being wrapped by this middleware. If you are subclassing this class, you should not override :meth:`__init__`. Custom setup should be done in :meth:`setup_middleware` instead. The config class can be overridden by replacing the ``CONFIG_CLASS`` class variable. """ CONFIG_CLASS = BaseMiddlewareConfig def __init__(self, name, config, worker): self.name = name self.config = self.CONFIG_CLASS(config, static=True) self.consume_priority = config.get('consume_priority') self.publish_priority = config.get('publish_priority') self.worker = worker def setup_middleware(self): """Any custom setup may be done here. :rtype: Deferred or None :returns: May return a Deferred that is called when setup is complete. """ pass def teardown_middleware(self): """Any custom teardown may be done here. :rtype: Deferred or None :returns: May return a Deferred that is called when teardown is complete. """ pass def handle_consume_inbound(self, message, connector_name): """Called when an inbound transport user message is consumed. The other methods listed below all function in the same way. Only the kind and direction of the message being processed differs. * :meth:`handle_publish_inbound` * :meth:`handle_consume_outbound` * :meth:`handle_publish_outbound` * :meth:`handle_consume_event` * :meth:`handle_publish_event` * :meth:`handle_failure` By default, the ``handle_consume_*`` and ``handle_publish_*`` methods call their ``handle_*`` equivalents. :param vumi.message.TransportUserMessage message: Inbound message to process. :param string connector_name: The name of the connector the message is being received on or sent to. :rtype: vumi.message.TransportUserMessage :returns: The processed message. """ return self.handle_inbound(message, connector_name) def handle_publish_inbound(self, message, connector_name): """Called when an inbound transport user message is published. See :meth:`handle_consume_inbound`. """ return self.handle_inbound(message, connector_name) def handle_inbound(self, message, connector_name): """Default handler for published and consumed inbound messages. See :meth:`handle_consume_inbound`. """ return message def handle_consume_outbound(self, message, connector_name): """Called when an outbound transport user message is consumed. See :meth:`handle_consume_inbound`. """ return self.handle_outbound(message, connector_name) def handle_publish_outbound(self, message, connector_name): """Called when an outbound transport user message is published. See :meth:`handle_consume_inbound`. """ return self.handle_outbound(message, connector_name) def handle_outbound(self, message, connector_name): """Default handler for published and consumed outbound messages. See :meth:`handle_consume_inbound`. """ return message def handle_consume_event(self, event, connector_name): """Called when a transport event is consumed. See :meth:`handle_consume_inbound`. """ return self.handle_event(event, connector_name) def handle_publish_event(self, event, connector_name): """Called when a transport event is published. See :meth:`handle_consume_inbound`. """ return self.handle_event(event, connector_name) def handle_event(self, event, connector_name): """Default handler for published and consumed events.
See :meth:`handle_consume_inbound`. """ return event def handle_consume_failure(self, failure, connector_name): """Called when a failure message is consumed. See :meth:`handle_consume_inbound`. """ return self.handle_failure(failure, connector_name) def handle_publish_failure(self, failure, connector_name): """Called when a failure message is published. See :meth:`handle_consume_inbound`. """ return self.handle_failure(failure, connector_name) def handle_failure(self, failure, connector_name): """Called to process a failure message ( :class:`vumi.transports.failures.FailureMessage`). See :meth:`handle_consume_inbound`. """ return failure class TransportMiddleware(BaseMiddleware): """Message processor middleware for Transports. """ class ApplicationMiddleware(BaseMiddleware): """Message processor middleware for Applications. """ class MiddlewareStack(object): """Ordered list of middlewares to pass a Message through. """ def __init__(self, middlewares): self.consume_middlewares = self._sort_by_priority( middlewares, 'consume_priority') self.publish_middlewares = self._sort_by_priority( reversed(middlewares), 'publish_priority') @staticmethod def _sort_by_priority(middlewares, priority_key): # We rely on Python's sorting algorithm being stable to preserve # order within priority levels. return sorted(middlewares, key=lambda mw: getattr(mw, priority_key)) @inlineCallbacks def _handle(self, middlewares, handler_name, message, connector_name): method_name = 'handle_%s' % (handler_name,) for middleware in middlewares: handler = getattr(middleware, method_name) message = yield handler(message, connector_name) if message is None: raise MiddlewareError( 'Returned value of %s.%s should never be None' % ( middleware, method_name,)) returnValue(message) def apply_consume(self, handler_name, message, connector_name): handler_name = 'consume_%s' % (handler_name,) return self._handle( self.consume_middlewares, handler_name, message, connector_name) def apply_publish(self, handler_name, message, connector_name): handler_name = 'publish_%s' % (handler_name,) return self._handle( self.publish_middlewares, handler_name, message, connector_name) @inlineCallbacks def teardown(self): for mw in self.publish_middlewares: yield mw.teardown_middleware() def create_middlewares_from_config(worker, config): """Return a list of middleware objects created from a worker configuration. """ middlewares = [] for item in config.get("middleware", []): keys = item.keys() if len(keys) != 1: raise ConfigError( "Middleware items contain only a single key-value pair. The" " key should be a name for the middleware. 
The value should be" " the full dotted name of the class implementing the" " middleware, or a mapping containing the keys 'class' with a" " value of the full dotted class name, 'consume_priority' with" " the priority level for consuming, and 'publish_priority'" " with the priority level for publishing, both integers.") middleware_name = keys[0] middleware_config = config.get(middleware_name, {}) if isinstance(item[middleware_name], basestring): cls_name = item[middleware_name] middleware_config['consume_priority'] = 0 middleware_config['publish_priority'] = 0 elif isinstance(item[middleware_name], dict): conf = item[middleware_name] cls_name = conf.get('class') try: middleware_config['consume_priority'] = int(conf.get( 'consume_priority', 0)) middleware_config['publish_priority'] = int(conf.get( 'publish_priority', 0)) except ValueError: raise ConfigError( "Middleware priority level must be an integer") else: raise ConfigError( "Middleware item values must either be a string with the" " full dotted name of the class implementing the middleware," " or a dictionary with 'class', 'consume_priority', and" " 'publish_priority' keys.") cls = load_class_by_string(cls_name) middleware = cls(middleware_name, middleware_config, worker) middlewares.append(middleware) return middlewares @inlineCallbacks def setup_middlewares_from_config(worker, config): """Create a list of middleware objects, call .setup_middleware() on them and then return the list. """ middlewares = create_middlewares_from_config(worker, config) for mw in middlewares: yield mw.setup_middleware() returnValue(middlewares) PK=JG%vumi/middleware/address_translator.py# -*- test-case-name: vumi.middleware.tests.test_address_translator -*- from confmodel.fields import ConfigDict from vumi.middleware import BaseMiddleware from vumi.middleware.base import BaseMiddlewareConfig class AddressTranslatorMiddlewareConfig(BaseMiddlewareConfig): """ Configuration class for the address translator middleware. """ outbound_map = ConfigDict( "Mapping of old ``to_addr`` values to new ``to_addr`` values", required=True, static=True) class AddressTranslationMiddleware(BaseMiddleware): """Address translation middleware. Used for mapping a set of `to_addr` values in outbound messages to new values. Inbound messages have the inverse mapping applied to their `from_addr` values. This is useful during debugging, testing and development. For example, you might want to make your Gmail address look like an MSISDN to an application to test SMS address handling. Or you might want to have an outgoing SMS end up at your Gmail account. Configuration options: :param dict outbound_map: Mapping of old `to_addr` values to new `to_addr` values for outbound messages. Inbound messages have the inverse mapping applied to `from_addr` values. Addresses not in this dictionary are not affected.
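For example (a hypothetical mapping for local development)::

    config = {
        'outbound_map': {
            '+27831234567': 'someone@example.com',
        },
    }

Outbound messages addressed to ``+27831234567`` are rewritten to ``someone@example.com``, and inbound messages from that address get the inverse treatment.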
""" CONFIG_CLASS = AddressTranslatorMiddlewareConfig def setup_middleware(self): self.outbound_map = self.config.outbound_map self.inbound_map = dict((v, k) for k, v in self.outbound_map.items()) def handle_outbound(self, message, connector_name): fake_addr = message['to_addr'] real_addr = self.outbound_map.get(fake_addr) if real_addr is not None: message['to_addr'] = real_addr return message def handle_inbound(self, message, connector_name): real_addr = message['from_addr'] fake_addr = self.inbound_map.get(real_addr) if fake_addr is not None: message['from_addr'] = fake_addr return message PK=JGiV 99 vumi/middleware/manhole_utils.pyimport os from twisted.conch.ssh import transport, userauth, connection, channel from twisted.internet import defer from twisted.conch.manhole_ssh import ConchFactory # these are shipped along with Twisted private_key = ConchFactory.privateKeys['ssh-rsa'] public_key = ConchFactory.publicKeys['ssh-rsa'] class ClientTransport(transport.SSHClientTransport): def verifyHostKey(self, pub_key, fingerprint): return defer.succeed(1) def connectionSecure(self): return self.requestService(ClientUserAuth( os.getlogin(), ClientConnection(self.factory.channelConnected))) class ClientUserAuth(userauth.SSHUserAuthClient): def getPassword(self, prompt=None): # Not doing password based auth return def getPublicKey(self): return public_key.blob() def getPrivateKey(self): return defer.succeed(private_key.keyObject) class ClientConnection(connection.SSHConnection): def __init__(self, channel_connected): connection.SSHConnection.__init__(self) self._channel_connected = channel_connected def serviceStarted(self): channel = ClientChannel(self._channel_connected, conn=self) self.openChannel(channel) class ClientChannel(channel.SSHChannel): name = 'session' def __init__(self, channel_connected, *args, **kwargs): channel.SSHChannel.__init__(self, *args, **kwargs) self._channel_connected = channel_connected self.buffer = u'' self.queue = defer.DeferredQueue() def channelOpen(self, data): self._channel_connected.callback(self) def dataReceived(self, data): self.buffer += data lines = self.buffer.split('\r\n') for line in lines[:-1]: self.queue.put(line) self.buffer = lines[-1] PK=JGvumi/middleware/logging.py# -*- test-case-name: vumi.middleware.tests.test_logging -*- from confmodel.fields import ConfigText from vumi.middleware import BaseMiddleware from vumi.middleware.base import BaseMiddlewareConfig from vumi import log class LoggingMiddlewareConfig(BaseMiddlewareConfig): """ Configuration class for the logging middleware. """ log_level = ConfigText( "Log level from :mod:`vumi.log` to log inbound and outbound messages " "and events at", default='info', static=True) failure_log_level = ConfigText( "Log level from :mod:`vumi.log` to log failure messages at", default='error', static=True) class LoggingMiddleware(BaseMiddleware): """Middleware for logging messages published and consumed by transports and applications. Optional configuration: :param string log_level: Log level from :mod:`vumi.log` to log inbound and outbound messages and events at. Default is `info`. :param string failure_log_level: Log level from :mod:`vumi.log` to log failure messages at. Default is `error`. 
""" CONFIG_CLASS = LoggingMiddlewareConfig def setup_middleware(self): log_level = self.config.log_level self.message_logger = getattr(log, log_level) failure_log_level = self.config.failure_log_level self.failure_logger = getattr(log, failure_log_level) def _log(self, direction, logger, msg, connector_name): logger("Processed %s message for %s: %s" % ( direction, connector_name, msg.to_json())) return msg def handle_inbound(self, message, connector_name): return self._log( "inbound", self.message_logger, message, connector_name) def handle_outbound(self, message, connector_name): return self._log( "outbound", self.message_logger, message, connector_name) def handle_event(self, event, connector_name): return self._log("event", self.message_logger, event, connector_name) def handle_failure(self, failure, connector_name): return self._log( "failure", self.failure_logger, failure, connector_name) PK=JG."vumi/middleware/message_storing.py# -*- test-case-name: vumi.middleware.tests.test_message_storing -*- from confmodel.fields import ConfigBool, ConfigDict, ConfigText from twisted.internet.defer import inlineCallbacks, returnValue from vumi.middleware.base import BaseMiddleware, BaseMiddlewareConfig from vumi.middleware.tagger import TaggingMiddleware from vumi.components.message_store import MessageStore from vumi.config import ConfigRiak from vumi.persist.txriak_manager import TxRiakManager from vumi.persist.txredis_manager import TxRedisManager class StoringMiddlewareConfig(BaseMiddlewareConfig): """ Config class for the storing middleware. """ store_prefix = ConfigText( "Prefix for message store keys in key-value store.", default='message_store', static=True) redis_manager = ConfigDict( "Redis configuration parameters", default={}, static=True) riak_manager = ConfigRiak( "Riak configuration parameters. Must contain at least a bucket_prefix" " key", required=True, static=True) store_on_consume = ConfigBool( "``True`` to store consumed messages as well as published ones, " "``False`` to store only published messages.", default=True, static=True) class StoringMiddleware(BaseMiddleware): """Middleware for storing inbound and outbound messages and events. Failures are not stored currently because these are typically stored by :class:`vumi.transports.FailureWorker` instances. Messages are always stored. However, in order for messages to be associated with a particular batch_id ( see :class:`vumi.application.MessageStore`) a batch needs to be created in the message store (typically by an application worker that initiates sending outbound messages) and messages need to be tagged with a tag associated with the batch (typically by an application worker or middleware such as :class:`vumi.middleware.TaggingMiddleware`). Configuration options: :param string store_prefix: Prefix for message store keys in key-value store. Default is 'message_store'. :param dict redis_manager: Redis configuration parameters. :param dict riak_manager: Riak configuration parameters. Must contain at least a bucket_prefix key. :param bool store_on_consume: ``True`` to store consumed messages as well as published ones, ``False`` to store only published messages. Default is ``True``. 
""" CONFIG_CLASS = StoringMiddlewareConfig @inlineCallbacks def setup_middleware(self): store_prefix = self.config.store_prefix r_config = self.config.redis_manager self.redis = yield TxRedisManager.from_config(r_config) self.manager = TxRiakManager.from_config(self.config.riak_manager) self.store = MessageStore( self.manager, self.redis.sub_manager(store_prefix)) self.store_on_consume = self.config.store_on_consume @inlineCallbacks def teardown_middleware(self): yield self.redis.close_manager() yield self.manager.close_manager() def handle_consume_inbound(self, message, connector_name): if not self.store_on_consume: return message return self.handle_inbound(message, connector_name) @inlineCallbacks def handle_inbound(self, message, connector_name): tag = TaggingMiddleware.map_msg_to_tag(message) yield self.store.add_inbound_message(message, tag=tag) returnValue(message) def handle_consume_outbound(self, message, connector_name): if not self.store_on_consume: return message return self.handle_outbound(message, connector_name) @inlineCallbacks def handle_outbound(self, message, connector_name): tag = TaggingMiddleware.map_msg_to_tag(message) yield self.store.add_outbound_message(message, tag=tag) returnValue(message) def handle_consume_event(self, event, connector_name): if not self.store_on_consume: return event return self.handle_event(event, connector_name) @inlineCallbacks def handle_event(self, event, connector_name): transport_metadata = event.get('transport_metadata', {}) # FIXME: The SMPP transport writes a 'datetime' object # in the 'date' of the transport_metadata. # json.dumps() that RiakObject uses is unhappy with that. if 'date' in transport_metadata: date = transport_metadata['date'] if not isinstance(date, basestring): transport_metadata['date'] = date.isoformat() yield self.store.add_event(event) returnValue(event) PK=JGSAvumi/middleware/__init__.py"""Middleware classes to process messages on their way in and out of workers. """ from vumi.middleware.base import ( BaseMiddleware, TransportMiddleware, ApplicationMiddleware, MiddlewareStack, create_middlewares_from_config, setup_middlewares_from_config) __all__ = [ 'BaseMiddleware', 'TransportMiddleware', 'ApplicationMiddleware', 'MiddlewareStack', 'create_middlewares_from_config', 'setup_middlewares_from_config'] PK=JGTe!vumi/middleware/session_length.py# -*- test-case-name: vumi.middleware.tests.test_session_length -*- from confmodel.fields import ConfigDict, ConfigInt, ConfigText from twisted.internet.defer import inlineCallbacks, returnValue from twisted.internet import reactor from vumi import log from vumi.errors import VumiError from vumi.message import TransportUserMessage from vumi.middleware.base import BaseMiddleware, BaseMiddlewareConfig from vumi.middleware.tagger import TaggingMiddleware from vumi.persist.txredis_manager import TxRedisManager class SessionLengthMiddlewareError(VumiError): """ Raised when the SessionLengthMiddleware encounters unexcepted messages. """ class SessionLengthMiddlewareConfig(BaseMiddlewareConfig): """ Configuration class for the session length middleware. """ redis_manager = ConfigDict("Redis config", default={}, static=True) timeout = ConfigInt( "Redis key timeout (secs)", default=600, static=True) namespace_type = ConfigText( "Namespace to use to lookup and set the (address, timestamp) " "key-value pairs in redis. Possible types: " " - transport_name: the message's `transport_name` field is used." " - tag: the tag associated to the message is used. 
*Note*: this " " requires the `TaggingMiddleware` to be earlier in the " " middleware chain.", default='transport_name', static=True) field_name = ConfigText( "Field name in message helper_metadata", default="session", static=True) class SessionLengthMiddleware(BaseMiddleware): """ Middleware for storing the session length in the message. Stores the start timestamp of a session in the message payload in message['helper_metadata'][field_name]['session_start'] for all messages in the session, as well as the end timestamp if the session in message['helper_metadata'][field_name]['session_end'] if the message marks the end of the session. Configuration options: :param dict redis: Redis configuration parameters. :param int timeout: Redis key timeout (in seconds). Defaults to 600. :param str field_name: The field name to use when storing the timestamps in the message helper_metadata. Defaults to 'session'. """ CONFIG_CLASS = SessionLengthMiddlewareConfig SESSION_NEW, SESSION_CLOSE = ( TransportUserMessage.SESSION_NEW, TransportUserMessage.SESSION_CLOSE) SESSION_START = 'session_start' SESSION_END = 'session_end' DIRECTION_INBOUND = 'inbound' DIRECTION_OUTBOUND = 'outbound' NAMESPACE_HANDLERS = { 'tag': 'get_tag', 'transport_name': 'get_transport_name' } @inlineCallbacks def setup_middleware(self): self.redis = yield TxRedisManager.from_config( self.config.redis_manager) self.timeout = self.config.timeout self.field_name = self.config.field_name self.clock = reactor namespace_type = self.config.namespace_type self.namespace_handler = getattr( self, self.NAMESPACE_HANDLERS[namespace_type]) @inlineCallbacks def teardown_middleware(self): yield self.redis.close_manager() def get_transport_name(self, message): return message['transport_name'] def get_tag(self, message): tag = TaggingMiddleware.map_msg_to_tag(message) if tag is None: return None else: return ":".join(tag) def _time(self): return self.clock.seconds() def _key_address(self, message, direction): if direction == self.DIRECTION_INBOUND: return message['from_addr'] elif direction == self.DIRECTION_OUTBOUND: return message['to_addr'] def _key(self, namespace, key_addr): return ':'.join((namespace, key_addr, 'session_created')) def _set_metadata(self, message, key, value): metadata = message['helper_metadata'].setdefault(self.field_name, {}) metadata[key] = float(value) def _has_metadata(self, message, key): metadata = message['helper_metadata'].get(self.field_name, {}) return key in metadata @inlineCallbacks def _set_new_start_time(self, message, redis_key, time): yield self.redis.setex(redis_key, self.timeout, str(time)) self._set_metadata(message, self.SESSION_START, time) @inlineCallbacks def _set_current_start_time(self, message, redis_key, clear): created_time = yield self.redis.get(redis_key) if created_time is not None: self._set_metadata(message, self.SESSION_START, created_time) if clear: yield self.redis.delete(redis_key) def _set_end_time(self, message, time): self._set_metadata(message, self.SESSION_END, time) @inlineCallbacks def _handle_start(self, message, redis_key): if not self._has_metadata(message, self.SESSION_START): time = self._time() yield self._set_new_start_time(message, redis_key, time) @inlineCallbacks def _handle_end(self, message, redis_key): if not self._has_metadata(message, self.SESSION_END): self._set_end_time(message, self._time()) if not self._has_metadata(message, self.SESSION_START): yield self._set_current_start_time(message, redis_key, clear=True) @inlineCallbacks def _handle_default(self, message, 
redis_key): if not self._has_metadata(message, self.SESSION_START): yield self._set_current_start_time(message, redis_key, clear=False) @inlineCallbacks def _process_message(self, message, direction): namespace = self.namespace_handler(message) if namespace is None: log.error(SessionLengthMiddlewareError( "Session length redis namespace cannot be None, " "skipping message")) returnValue(message) key_addr = self._key_address(message, direction) if key_addr is None: log.error(SessionLengthMiddlewareError( "Session length key address cannot be None, " "skipping message")) returnValue(message) redis_key = self._key(namespace, key_addr) if message.get('session_event') == self.SESSION_NEW: yield self._handle_start(message, redis_key) elif message.get('session_event') == self.SESSION_CLOSE: yield self._handle_end(message, redis_key) else: yield self._handle_default(message, redis_key) returnValue(message) def handle_inbound(self, message, connector_name): return self._process_message(message, self.DIRECTION_INBOUND) def handle_outbound(self, message, connector_name): return self._process_message(message, self.DIRECTION_OUTBOUND) PK=JGEi"vumi/middleware/provider_setter.py# -*- test-case-name: vumi.middleware.tests.test_provider_setter -*- from confmodel.fields import ConfigDict, ConfigText from vumi import log from vumi.errors import VumiError from vumi.middleware.base import TransportMiddleware, BaseMiddlewareConfig from vumi.utils import normalize_msisdn class ProviderSettingMiddlewareError(VumiError): """ Raised when provider setting middleware encounters an error. """ class StaticProviderSetterMiddlewareConfig(BaseMiddlewareConfig): """ Config class for the static provider setting middleware. """ provider = ConfigText( "Value to set the ``provider`` field to", required=True, static=True) class StaticProviderSettingMiddleware(TransportMiddleware): """ Transport middleware that sets a static ``provider`` on each inbound message and outbound message. Configuration options: :param str provider: Value to set the ``provider`` field to. .. note:: If you rely on the provider value in other middleware, please order your middleware carefully. If another middleware requires the provider for both inbound and outbound messages, you might need two copies of this middleware (one on either side of the other middleware). """ CONFIG_CLASS = StaticProviderSetterMiddlewareConfig def setup_middleware(self): self.provider_value = self.config.provider def handle_inbound(self, message, connector_name): message["provider"] = self.provider_value return message def handle_outbound(self, message, connector_name): message["provider"] = self.provider_value return message class AddressPrefixProviderSettingMiddlewareConfig(BaseMiddlewareConfig): """ Config for the address prefix provider setting middleware. """ class ConfigNormalizeMsisdn(ConfigDict): def clean(self, value): if 'country_code' not in value: self.raise_config_error( "does not contain the `country_code` key.") if 'strip_plus' not in value: value['strip_plus'] = False return super(self.__class__, self).clean(value) provider_prefixes = ConfigDict( "Mapping from address prefix to provider value. Longer prefixes are " "checked first to avoid ambiguity. If no prefix matches, the provider" " value will be set to ``None``.", required=True, static=True) normalize_msisdn = ConfigNormalizeMsisdn( "Optional MSISDN normalization config. 
If present, this dict should " "contain a (mandatory) ``country_code`` field and an optional boolean " "``strip_plus`` field (default ``False``). If absent, the " "``from_addr`` field will not be normalized prior to the prefix check." " (This normalization is only used for the prefix check. The " "``from_addr`` field on the message is not modified.)", static=True) class AddressPrefixProviderSettingMiddleware(TransportMiddleware): """ Transport middleware that sets a ``provider`` on each message based on configured address prefixes. Inbound messages have their provider set based on their ``from_addr``. Outbound messages have their provider set based on their ``to_addr``. Configuration options: :param dict provider_prefixes: Mapping from address prefix to provider value. Longer prefixes are checked first to avoid ambiguity. If no prefix matches, the provider value will be set to ``None``. :param dict normalize_msisdn: Optional MSISDN normalization config. If present, this dict should contain a (mandatory) ``country_code`` field and an optional boolean ``strip_plus`` field (default ``False``). If absent, the ``from_addr`` field will not be normalized prior to the prefix check. (This normalization is only used for the prefix check. The ``from_addr`` field on the message is not modified.) .. note:: If you rely on the provider value in other middleware, please order your middleware carefully. If another middleware requires the provider for both inbound and outbound messages, you might need two copies of this middleware (one on either side of the other middleware). """ CONFIG_CLASS = AddressPrefixProviderSettingMiddlewareConfig def setup_middleware(self): prefixes = self.config.provider_prefixes.items() self.provider_prefixes = sorted( prefixes, key=lambda item: -len(item[0])) self.normalize_config = self.config.normalize_msisdn def normalize_addr(self, addr): if self.normalize_config: addr = normalize_msisdn( addr, country_code=self.normalize_config["country_code"]) if self.normalize_config.get("strip_plus"): addr = addr.lstrip("+") return addr def get_provider(self, addr): if addr is None: log.error(ProviderSettingMiddlewareError( "Address for determining message provider cannot be None," " skipping message")) return None addr = self.normalize_addr(addr) for prefix, provider in self.provider_prefixes: if addr.startswith(prefix): return provider return None def handle_inbound(self, message, connector_name): if message.get("provider") is None: message["provider"] = self.get_provider(message["from_addr"]) return message def handle_outbound(self, message, connector_name): if message.get("provider") is None: message["provider"] = self.get_provider(message["to_addr"]) return message PK=JG@1$vumi/middleware/tests/test_tagger.py"""Tests for vumi.middleware.tagger.""" import re from vumi.middleware.tagger import TaggingMiddleware from vumi.message import TransportUserMessage from vumi.tests.helpers import VumiTestCase class TestTaggingMiddleware(VumiTestCase): DEFAULT_CONFIG = { 'incoming': { 'addr_pattern': r'^\d+(\d{3})$', 'tagpool_template': r'pool1', 'tagname_template': r'mytag-\1', }, 'outgoing': { 'tagname_pattern': r'mytag-(\d{3})$', 'msg_template': { 'from_addr': r'1234*\1', }, }, } def mk_tagger(self, config=None): dummy_worker = object() if config is None: config = self.DEFAULT_CONFIG self.mw = TaggingMiddleware("dummy_tagger", config, dummy_worker) self.mw.setup_middleware() def mk_msg(self, to_addr, tag=None, from_addr="12345"): msg = TransportUserMessage(to_addr=to_addr, 
from_addr=from_addr, transport_name="dummy_connector", transport_type="dummy_transport_type") if tag is not None: TaggingMiddleware.add_tag_to_msg(msg, tag) return msg def get_tag(self, to_addr): msg = self.mk_msg(to_addr) msg = self.mw.handle_inbound(msg, "dummy_connector") return TaggingMiddleware.map_msg_to_tag(msg) def get_from_addr(self, to_addr, tag): msg = self.mk_msg(to_addr, tag, from_addr=None) msg = self.mw.handle_outbound(msg, "dummy_connector") return msg['from_addr'] def test_inbound_matching_to_addr(self): self.mk_tagger() self.assertEqual(self.get_tag("123456"), ("pool1", "mytag-456")) self.assertEqual(self.get_tag("1234"), ("pool1", "mytag-234")) def test_inbound_nonmatching_to_addr(self): self.mk_tagger() self.assertEqual(self.get_tag("a1234"), None) def test_inbound_nonmatching_to_addr_leaves_msg_unmodified(self): self.mk_tagger() tag = ("dont", "modify") orig_msg = self.mk_msg("a1234", tag=tag) msg = orig_msg.from_json(orig_msg.to_json()) msg = self.mw.handle_inbound(msg, "dummy_connector") self.assertEqual(msg, orig_msg) def test_inbound_none_to_addr(self): self.mk_tagger() self.assertEqual(self.get_tag(None), None) def test_outbound_matching_tag(self): self.mk_tagger() self.assertEqual(self.get_from_addr("111", ("pool1", "mytag-456")), "1234*456") self.assertEqual(self.get_from_addr("111", ("pool1", "mytag-789")), "1234*789") def test_outbound_nonmatching_tag(self): self.mk_tagger() self.assertEqual(self.get_from_addr("111", ("pool1", "othertag-456")), None) def test_outbound_nonmatching_tag_leaves_msg_unmodified(self): self.mk_tagger() orig_msg = self.mk_msg("a1234", tag=("pool1", "othertag-456")) msg = orig_msg.from_json(orig_msg.to_json()) msg = self.mw.handle_outbound(msg, "dummy_connector") for key in msg.payload.keys(): self.assertEqual(msg[key], orig_msg[key], "Key %r not equal" % key) self.assertEqual(msg, orig_msg) def test_outbound_no_tag(self): self.mk_tagger() self.assertEqual(self.get_from_addr("111", None), None) def test_deepupdate(self): orig = {'a': {'b': "foo"}, 'c': "bar"} TaggingMiddleware._deepupdate(re.match(".*", "foo"), orig, {'a': {'b': "baz"}, 'd': r'\g<0>!', 'e': 1}) self.assertEqual(orig, {'a': {'b': "baz"}, 'c': "bar", 'd': "foo!", 'e': 1}) def test_deepupdate_with_recursion(self): self.mk_tagger() orig = {'a': {'b': "foo"}, 'c': "bar"} new = {'a': {'b': "baz"}} new['a']['d'] = new TaggingMiddleware._deepupdate(re.match(".*", "foo"), orig, new) self.assertEqual(orig, {'a': {'b': "baz"}, 'c': "bar"}) def test_map_msg_to_tag(self): msg = self.mk_msg("123456") self.assertEqual(TaggingMiddleware.map_msg_to_tag(msg), None) msg['helper_metadata']['tag'] = {'tag': ['pool', 'mytag']} self.assertEqual(TaggingMiddleware.map_msg_to_tag(msg), ("pool", "mytag")) def test_add_tag_to_msg(self): msg = self.mk_msg("123456") TaggingMiddleware.add_tag_to_msg(msg, ('pool', 'mytag')) self.assertEqual(msg['helper_metadata']['tag'], { 'tag': ['pool', 'mytag'], }) def test_add_tag_to_payload(self): payload = {} TaggingMiddleware.add_tag_to_payload(payload, ('pool', 'mytag')) self.assertEqual(payload, { 'helper_metadata': { 'tag': { 'tag': ['pool', 'mytag'], }, }, }) PK=JG3YY,vumi/middleware/tests/test_session_length.py"""Tests for vumi.middleware.session_length.""" from twisted.internet.defer import inlineCallbacks, returnValue from twisted.internet.task import Clock from vumi.message import TransportUserMessage from vumi.middleware.session_length import ( SessionLengthMiddleware, SessionLengthMiddlewareError) from vumi.middleware.tagger import TaggingMiddleware 
from vumi.tests.utils import LogCatcher from vumi.tests.helpers import VumiTestCase, PersistenceHelper SESSION_NEW, SESSION_CLOSE, SESSION_NONE = ( TransportUserMessage.SESSION_NEW, TransportUserMessage.SESSION_CLOSE, TransportUserMessage.SESSION_NONE) class TestSessionLengthMiddleware(VumiTestCase):
def setUp(self): self.persistence_helper = self.add_helper(PersistenceHelper()) self.clock = Clock()
@inlineCallbacks def mk_middleware(self, **kw): dummy_worker = object() config = self.persistence_helper.mk_config(kw) mw = SessionLengthMiddleware("session_length", config, dummy_worker) yield mw.setup_middleware() self.patch(mw, 'clock', self.clock) self.redis = mw.redis self.add_cleanup(mw.teardown_middleware) returnValue(mw)
def mk_msg(self, to_addr, from_addr, session_event=SESSION_NEW, session_start=None, session_end=None, tag=None, transport_name='dummy_transport'): msg = TransportUserMessage( to_addr=to_addr, from_addr=from_addr, transport_name=transport_name, transport_type="dummy_transport_type", session_event=session_event) if tag is not None: TaggingMiddleware.add_tag_to_msg(msg, tag) if session_start is not None: self._set_metadata(msg, 'session_start', session_start) if session_end is not None: self._set_metadata(msg, 'session_end', session_end) return msg
def _set_metadata(self, msg, name, value, metadata_field_name='session'): metadata = msg['helper_metadata'].setdefault(metadata_field_name, {}) metadata[name] = value
def assert_middleware_error(self, msg): [err] = self.flushLoggedErrors(SessionLengthMiddlewareError) self.assertEqual(str(err.value), msg)
@inlineCallbacks def test_incoming_message_session_start(self): mw = yield self.mk_middleware() msg_start = self.mk_msg('+12345', '+54321') msg = yield mw.handle_inbound(msg_start, "dummy_connector") value = yield self.redis.get('dummy_transport:+54321:session_created') self.assertEqual(value, '0.0') self.assertEqual( msg['helper_metadata']['session']['session_start'], 0.0)
@inlineCallbacks def test_incoming_message_session_end(self): mw = yield self.mk_middleware() msg_end = self.mk_msg('+12345', '+54321', session_event=SESSION_CLOSE) msg = yield mw.handle_inbound(msg_end, "dummy_connector") self.assertEqual( msg['helper_metadata']['session']['session_end'], 0.0)
@inlineCallbacks def test_incoming_message_session_start_end(self): mw = yield self.mk_middleware() msg_start = self.mk_msg('+12345', '+54321') msg_end = self.mk_msg('+12345', '+54321', session_event=SESSION_CLOSE) msg = yield mw.handle_inbound(msg_start, "dummy_connector") value = yield self.redis.get('dummy_transport:+54321:session_created') self.assertEqual(value, '0.0') self.assertEqual( msg['helper_metadata']['session']['session_start'], 0.0) self.clock.advance(1.0) msg = yield mw.handle_inbound(msg_end, "dummy_connector") value = yield self.redis.get('dummy_transport:+54321:session_created') self.assertTrue(value is None) self.assertEqual( msg['helper_metadata']['session']['session_start'], 0.0) self.assertEqual( msg['helper_metadata']['session']['session_end'], 1.0)
@inlineCallbacks def test_incoming_message_session_start_no_overwrite(self): mw = yield self.mk_middleware() msg = self.mk_msg( '+12345', '+54321', session_start=23, session_event=SESSION_NEW) yield self.redis.set('dummy_transport:+54321:session_created', '23.0') msg = yield mw.handle_inbound(msg, "dummy_connector") value = yield self.redis.get('dummy_transport:+54321:session_created') self.assertEqual(value, '23.0') self.assertEqual( msg['helper_metadata']['session']['session_start'], 23.0)
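# Illustrative sketch (not part of the test suite above): the middleware
# stores 'session_created' timestamps under redis keys of the form
# '<namespace>:<address>:session_created' and stamps each message's
# helper_metadata with 'session_start' / 'session_end' fields. The
# hypothetical helper below shows how downstream code could compute a
# session's length from that metadata; the field names are taken from the
# assertions in these tests, while the helper itself is an assumption.
def session_length_seconds(message, field_name='session'):
    """Return the session length in seconds, or None if a bound is missing."""
    metadata = message['helper_metadata'].get(field_name, {})
    start = metadata.get('session_start')
    end = metadata.get('session_end')
    if start is None or end is None:
        return None
    return end - start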
@inlineCallbacks def test_incoming_message_session_end_no_overwrite(self): mw = yield self.mk_middleware() msg = self.mk_msg( '+12345', '+54321', session_start=23.0, session_end=32.0, session_event=SESSION_CLOSE) msg = yield mw.handle_inbound(msg, "dummy_connector") self.assertEqual( msg['helper_metadata']['session']['session_start'], 23.0) self.assertEqual( msg['helper_metadata']['session']['session_end'], 32.0) @inlineCallbacks def test_incoming_message_session_none_no_overwrite(self): mw = yield self.mk_middleware() msg = self.mk_msg( '+12345', '+54321', session_start=23.0, session_event=SESSION_NONE) msg = yield mw.handle_inbound(msg, "dummy_connector") self.assertEqual( msg['helper_metadata']['session']['session_start'], 23.0) @inlineCallbacks def test_incoming_message_session_start_transport_namespace_type(self): mw = yield self.mk_middleware(namespace_type='transport_name') msg = self.mk_msg('+12345', '+54321', transport_name='foo') msg = yield mw.handle_inbound(msg, "dummy_connector") value = yield self.redis.get('foo:+54321:session_created') self.assertEqual(value, '0.0') self.assertEqual( msg['helper_metadata']['session']['session_start'], 0.0) @inlineCallbacks def test_incoming_message_session_end_transport_namespace_type(self): mw = yield self.mk_middleware(namespace_type='transport_name') yield self.redis.set('foo:+54321:session_created', '23.0') msg = self.mk_msg( '+12345', '+54321', session_event=SESSION_CLOSE, transport_name='foo') msg = yield mw.handle_inbound(msg, "dummy_connector") value = yield self.redis.get('foo:+54321:session_created') self.assertEqual(value, None) self.assertEqual( msg['helper_metadata']['session']['session_start'], 23.0) self.assertEqual( msg['helper_metadata']['session']['session_end'], 0.0) @inlineCallbacks def test_incoming_message_session_none_transport_namespace_type(self): mw = yield self.mk_middleware(namespace_type='transport_name') yield self.redis.set('foo:+54321:session_created', '23.0') msg = self.mk_msg( '+12345', '+54321', session_event=SESSION_CLOSE, transport_name='foo') msg = yield mw.handle_inbound(msg, "dummy_connector") self.assertEqual( msg['helper_metadata']['session']['session_start'], 23.0) @inlineCallbacks def test_incoming_message_session_start_no_transport_name(self): mw = yield self.mk_middleware(namespace_type='transport_name') msg = self.mk_msg('+12345', '+54321', transport_name=None) result_msg = yield mw.handle_inbound(msg, "dummy_connector") self.assert_middleware_error( "Session length redis namespace cannot be None, skipping message") self.assertEqual(msg, result_msg) self.assertTrue('session' not in msg['helper_metadata']) @inlineCallbacks def test_incoming_message_session_end_no_transport_name(self): mw = yield self.mk_middleware(namespace_type='transport_name') msg = self.mk_msg( '+12345', '+54321', session_event=SESSION_CLOSE, transport_name=None) result_msg = yield mw.handle_inbound(msg, "dummy_connector") self.assert_middleware_error( "Session length redis namespace cannot be None, skipping message") self.assertEqual(msg, result_msg) self.assertTrue('session' not in msg['helper_metadata']) @inlineCallbacks def test_incoming_message_session_none_no_transport_name(self): mw = yield self.mk_middleware(namespace_type='transport_name') msg = self.mk_msg( '+12345', '+54321', session_event=SESSION_NONE, transport_name=None) result_msg = yield mw.handle_inbound(msg, "dummy_connector") self.assert_middleware_error( "Session length redis namespace cannot be None, skipping message") self.assertEqual(msg, result_msg) 
self.assertTrue('session' not in msg['helper_metadata']) @inlineCallbacks def test_incoming_message_session_start_tag_namespace_type(self): mw = yield self.mk_middleware(namespace_type='tag') msg = self.mk_msg('+12345', '+54321', tag=('pool1', 'tag1')) msg = yield mw.handle_inbound(msg, "dummy_connector") value = yield self.redis.get('pool1:tag1:+54321:session_created') self.assertEqual(value, '0.0') self.assertEqual( msg['helper_metadata']['session']['session_start'], 0.0) @inlineCallbacks def test_incoming_message_session_end_tag_namespace_type(self): mw = yield self.mk_middleware(namespace_type='tag') yield self.redis.set('pool1:tag1:+54321:session_created', '23.0') msg = self.mk_msg( '+12345', '+54321', session_event=SESSION_CLOSE, tag=('pool1', 'tag1')) msg = yield mw.handle_inbound(msg, "dummy_connector") value = yield self.redis.get('pool1:tag1:+54321:session_created') self.assertEqual(value, None) self.assertEqual( msg['helper_metadata']['session']['session_start'], 23.0) self.assertEqual( msg['helper_metadata']['session']['session_end'], 0.0) @inlineCallbacks def test_incoming_message_session_none_tag_namespace_type(self): mw = yield self.mk_middleware(namespace_type='tag') yield self.redis.set('pool1:tag1:+54321:session_created', '23.0') msg = self.mk_msg( '+12345', '+54321', session_event=SESSION_CLOSE, tag=('pool1', 'tag1')) msg = yield mw.handle_inbound(msg, "dummy_connector") self.assertEqual( msg['helper_metadata']['session']['session_start'], 23.0) @inlineCallbacks def test_incoming_message_session_start_no_tag(self): mw = yield self.mk_middleware(namespace_type='tag') # create message with no tag msg = self.mk_msg('+12345', '+54321') result_msg = yield mw.handle_inbound(msg, "dummy_connector") self.assert_middleware_error( "Session length redis namespace cannot be None, skipping message") self.assertEqual(msg, result_msg) self.assertTrue('session' not in msg['helper_metadata']) @inlineCallbacks def test_incoming_message_session_end_tag_no_tag(self): mw = yield self.mk_middleware(namespace_type='tag') # create message with no tag msg = self.mk_msg('+12345', '+54321', session_event=SESSION_CLOSE) result_msg = yield mw.handle_inbound(msg, "dummy_connector") self.assert_middleware_error( "Session length redis namespace cannot be None, skipping message") self.assertEqual(msg, result_msg) self.assertTrue('session' not in msg['helper_metadata']) @inlineCallbacks def test_incoming_message_session_none_tag_no_tag(self): mw = yield self.mk_middleware(namespace_type='tag') yield self.redis.set('pool1:tag1:+54321:session_created', '23.0') # create message with no tag msg = self.mk_msg('+12345', '+54321', session_event=SESSION_NONE) result_msg = yield mw.handle_inbound(msg, "dummy_connector") self.assert_middleware_error( "Session length redis namespace cannot be None, skipping message") self.assertEqual(msg, result_msg) self.assertTrue('session' not in msg['helper_metadata']) @inlineCallbacks def test_outgoing_message_session_start(self): mw = yield self.mk_middleware() msg_start = self.mk_msg('+12345', '+54321') msg = yield mw.handle_outbound(msg_start, "dummy_connector") value = yield self.redis.get('dummy_transport:+12345:session_created') self.assertEqual(value, '0.0') self.assertEqual( msg['helper_metadata']['session']['session_start'], 0.0) @inlineCallbacks def test_outgoing_message_session_end(self): mw = yield self.mk_middleware() msg_end = self.mk_msg('+12345', '+54321', session_event=SESSION_CLOSE) msg = yield mw.handle_outbound(msg_end, "dummy_connector") self.assertEqual( 
msg['helper_metadata']['session']['session_end'], 0.0)
@inlineCallbacks def test_outgoing_message_session_start_end(self): mw = yield self.mk_middleware() msg_start = self.mk_msg('+12345', '+54321') msg_end = self.mk_msg('+12345', '+54321', session_event=SESSION_CLOSE) msg = yield mw.handle_outbound(msg_start, "dummy_connector") value = yield self.redis.get('dummy_transport:+12345:session_created') self.assertEqual(value, '0.0') self.assertEqual( msg['helper_metadata']['session']['session_start'], 0.0) self.clock.advance(1.0) msg = yield mw.handle_outbound(msg_end, "dummy_connector") value = yield self.redis.get('dummy_transport:+12345:session_created') self.assertTrue(value is None) self.assertEqual( msg['helper_metadata']['session']['session_start'], 0.0) self.assertEqual( msg['helper_metadata']['session']['session_end'], 1.0)
@inlineCallbacks def test_outgoing_message_session_start_no_overwrite(self): mw = yield self.mk_middleware() msg = self.mk_msg( '+12345', '+54321', session_start=23, session_event=SESSION_NEW) yield self.redis.set('dummy_transport:+12345:session_created', '23') msg = yield mw.handle_outbound(msg, "dummy_connector") value = yield self.redis.get('dummy_transport:+12345:session_created') self.assertEqual(value, '23') self.assertEqual( msg['helper_metadata']['session']['session_start'], 23)
@inlineCallbacks def test_outgoing_message_session_end_no_overwrite(self): mw = yield self.mk_middleware() msg = self.mk_msg( '+12345', '+54321', session_start=23, session_end=32, session_event=SESSION_CLOSE) msg = yield mw.handle_outbound(msg, "dummy_connector") self.assertEqual( msg['helper_metadata']['session']['session_start'], 23) self.assertEqual( msg['helper_metadata']['session']['session_end'], 32)
@inlineCallbacks def test_outgoing_message_session_none_no_overwrite(self): mw = yield self.mk_middleware() msg = self.mk_msg( '+12345', '+54321', session_start=23, session_event=SESSION_NONE) msg = yield mw.handle_outbound(msg, "dummy_connector") self.assertEqual( msg['helper_metadata']['session']['session_start'], 23)
@inlineCallbacks def test_outgoing_message_session_start_no_transport_name(self): mw = yield self.mk_middleware(namespace_type='transport_name') msg = self.mk_msg('+12345', '+54321', transport_name=None) result_msg = yield mw.handle_outbound(msg, "dummy_connector") self.assert_middleware_error( "Session length redis namespace cannot be None, skipping message") self.assertEqual(msg, result_msg) self.assertTrue('session' not in msg['helper_metadata'])
@inlineCallbacks def test_outgoing_message_session_end_no_transport_name(self): mw = yield self.mk_middleware(namespace_type='transport_name') msg = self.mk_msg( '+12345', '+54321', session_event=SESSION_CLOSE, transport_name=None) result_msg = yield mw.handle_outbound(msg, "dummy_connector") self.assert_middleware_error( "Session length redis namespace cannot be None, skipping message") self.assertEqual(msg, result_msg) self.assertTrue('session' not in msg['helper_metadata'])
@inlineCallbacks def test_outgoing_message_session_none_no_transport_name(self): mw = yield self.mk_middleware(namespace_type='transport_name') msg = self.mk_msg( '+12345', '+54321', session_event=SESSION_NONE, transport_name=None) result_msg = yield mw.handle_outbound(msg, "dummy_connector") self.assert_middleware_error( "Session length redis namespace cannot be None, skipping message") self.assertEqual(msg, result_msg) self.assertTrue('session' not in msg['helper_metadata'])
@inlineCallbacks def test_outgoing_message_session_start_tag_namespace_type(self): mw = yield self.mk_middleware(namespace_type='tag') msg = self.mk_msg('+12345', '+54321', tag=('pool1', 'tag1')) msg = yield mw.handle_outbound(msg, "dummy_connector") value = yield self.redis.get('pool1:tag1:+12345:session_created') self.assertEqual(value, '0.0') self.assertEqual( msg['helper_metadata']['session']['session_start'], 0.0)
@inlineCallbacks def test_outgoing_message_session_end_tag_namespace_type(self): mw = yield self.mk_middleware(namespace_type='tag') yield self.redis.set('pool1:tag1:+12345:session_created', '23.0') msg = self.mk_msg( '+12345', '+54321', session_event=SESSION_CLOSE, tag=('pool1', 'tag1')) msg = yield mw.handle_outbound(msg, "dummy_connector") value = yield self.redis.get('pool1:tag1:+12345:session_created') self.assertEqual(value, None) self.assertEqual( msg['helper_metadata']['session']['session_start'], 23.0) self.assertEqual( msg['helper_metadata']['session']['session_end'], 0.0)
@inlineCallbacks def test_outgoing_message_session_none_tag_namespace_type(self): mw = yield self.mk_middleware(namespace_type='tag') yield self.redis.set('pool1:tag1:+12345:session_created', '23.0') msg = self.mk_msg( '+12345', '+54321', session_event=SESSION_NONE, tag=('pool1', 'tag1')) msg = yield mw.handle_outbound(msg, "dummy_connector") self.assertEqual( msg['helper_metadata']['session']['session_start'], 23.0)
@inlineCallbacks def test_outgoing_message_session_start_no_tag(self): mw = yield self.mk_middleware(namespace_type='tag') # create message with no tag msg = self.mk_msg('+12345', '+54321') result_msg = yield mw.handle_outbound(msg, "dummy_connector") self.assert_middleware_error( "Session length redis namespace cannot be None, skipping message") self.assertEqual(msg, result_msg) self.assertTrue('session' not in msg['helper_metadata'])
@inlineCallbacks def test_outgoing_message_session_end_tag_no_tag(self): mw = yield self.mk_middleware(namespace_type='tag') # create message with no tag msg = self.mk_msg('+12345', '+54321', session_event=SESSION_CLOSE) result_msg = yield mw.handle_outbound(msg, "dummy_connector") self.assert_middleware_error( "Session length redis namespace cannot be None, skipping message") self.assertEqual(msg, result_msg) self.assertTrue('session' not in msg['helper_metadata'])
@inlineCallbacks def test_outgoing_message_session_none_tag_no_tag(self): mw = yield self.mk_middleware(namespace_type='tag') yield self.redis.set('pool1:tag1:+54321:session_created', '23.0') # create message with no tag msg = self.mk_msg('+12345', '+54321', session_event=SESSION_NONE) result_msg = yield mw.handle_outbound(msg, "dummy_connector") self.assert_middleware_error( "Session length redis namespace cannot be None, skipping message") self.assertEqual(msg, result_msg) self.assertTrue('session' not in msg['helper_metadata'])
@inlineCallbacks def test_redis_key_timeout(self): mw = yield self.mk_middleware() msg_start = self.mk_msg('+12345', '+54321') yield mw.handle_inbound(msg_start, "dummy_connector") ttl = yield self.redis.ttl('dummy_transport:+54321:session_created') self.assertTrue(ttl <= 120)
@inlineCallbacks def test_redis_key_custom_timeout(self): mw = yield self.mk_middleware(timeout=20) msg_start = self.mk_msg('+12345', '+54321') yield mw.handle_inbound(msg_start, "dummy_connector") ttl = yield self.redis.ttl('dummy_transport:+54321:session_created') self.assertTrue(ttl <= 20)
@inlineCallbacks def test_custom_message_field_name(self): mw = yield self.mk_middleware(field_name='foobar') msg_start =
self.mk_msg('+12345', '+54321') msg = yield mw.handle_inbound(msg_start, "dummy_connector") self.assertEqual( msg['helper_metadata']['foobar']['session_start'], 0.0)
@inlineCallbacks def test_incoming_message_session_no_from_addr(self): mw = yield self.mk_middleware() msg_start = self.mk_msg('+12345', None) msg = yield mw.handle_inbound(msg_start, "dummy_connector") self.assertEqual(msg, msg_start) self.assert_middleware_error( "Session length key address cannot be None, skipping message")
@inlineCallbacks def test_outgoing_message_session_no_to_addr(self): mw = yield self.mk_middleware() msg_start = self.mk_msg(None, '+54321') msg = yield mw.handle_outbound(msg_start, "dummy_connector") self.assertEqual(msg, msg_start) self.assert_middleware_error( "Session length key address cannot be None, skipping message")
PK=JG6gD %vumi/middleware/tests/test_logging.py"""Tests for vumi.middleware.logging.""" from vumi.middleware.logging import LoggingMiddleware from vumi.tests.utils import LogCatcher from vumi.tests.helpers import VumiTestCase
class DummyMessage(object): def __init__(self, json): self._json = json def to_json(self): return self._json
class TestLoggingMiddleware(VumiTestCase): def mklogger(self, config): worker = object() mw = LoggingMiddleware("test_logger", config, worker) mw.setup_middleware() return mw
def test_default_config(self): mw = self.mklogger({}) with LogCatcher() as lc: for handler, rkey in [ (mw.handle_inbound, "inbound"), (mw.handle_outbound, "outbound"), (mw.handle_event, "event"), (mw.handle_failure, "failure")]: msg = DummyMessage(rkey) result = handler(msg, "dummy_connector") self.assertEqual(result, msg) logs = lc.logs self.assertEqual([log['logLevel'] for log in logs], [20, 20, 20, 40]) self.assertEqual([log['message'][0] for log in logs], [ "Processed inbound message for dummy_connector: inbound", "Processed outbound message for dummy_connector: outbound", "Processed event message for dummy_connector: event", "'Processed failure message for dummy_connector: failure'", ])
def test_custom_log_level(self): mw = self.mklogger({'log_level': 'warning'}) with LogCatcher() as lc: msg = DummyMessage("inbound") result = mw.handle_inbound(msg, "dummy_connector") self.assertEqual(result, msg) logs = lc.logs self.assertEqual([log['logLevel'] for log in logs], [30]) self.assertEqual([log['message'][0] for log in logs], [ "Processed inbound message for dummy_connector: inbound", ])
def test_custom_failure_log_level(self): mw = self.mklogger({'failure_log_level': 'info'}) with LogCatcher() as lc: msg = DummyMessage("failure") result = mw.handle_failure(msg, "dummy_connector") self.assertEqual(result, msg) logs = lc.logs self.assertEqual([log['logLevel'] for log in logs], [20]) self.assertEqual([log['message'][0] for log in logs], [ "Processed failure message for dummy_connector: failure", ])
PK=JG1"bvumi/middleware/tests/utils.py"""Helpers for testing middleware.""" from vumi.middleware.base import BaseMiddleware
class RecordingMiddleware(BaseMiddleware): """Records each middleware call in a custom attribute on the processed message. Useful for testing that middleware is being called correctly.
""" def _handle(self, direction, message, connector_name): record = message.payload.setdefault('record', []) record.append((self.name, direction, connector_name)) return message def handle_inbound(self, message, connector_name): return self._handle('inbound', message, connector_name) def handle_outbound(self, message, connector_name): return self._handle('outbound', message, connector_name) def handle_event(self, message, connector_name): return self._handle('event', message, connector_name) def handle_failure(self, message, connector_name): return self._handle('failure', message, connector_name) PK=JG0W//"vumi/middleware/tests/test_base.pyimport yaml import itertools from confmodel.fields import ConfigInt from twisted.internet.defer import inlineCallbacks, returnValue from vumi.middleware.base import ( BaseMiddleware, MiddlewareStack, create_middlewares_from_config, setup_middlewares_from_config, BaseMiddlewareConfig) from vumi.tests.helpers import VumiTestCase class ToyMiddlewareConfig(BaseMiddlewareConfig): """ Config for the toy middleware. """ param_foo = ConfigInt("Foo parameter", static=True) param_bar = ConfigInt("Bar parameter", static=True) class ToyMiddleware(BaseMiddleware): CONFIG_CLASS = ToyMiddlewareConfig # simple attribute to check that setup_middleware is called _setup_done = False _teardown_count = itertools.count(1) def _handle(self, direction, message, connector_name): message = '%s.%s' % (message, self.name) self.worker.processed(self.name, direction, message, connector_name) return message def setup_middleware(self): self._setup_done = True self._teardown_done = False def teardown_middleware(self): self._teardown_done = next(self._teardown_count) def handle_inbound(self, message, connector_name): return self._handle('inbound', message, connector_name) def handle_outbound(self, message, connector_name): return self._handle('outbound', message, connector_name) def handle_event(self, message, connector_name): return self._handle('event', message, connector_name) def handle_failure(self, message, connector_name): return self._handle('failure', message, connector_name) class ToyAsymmetricMiddleware(ToyMiddleware): def _handle(self, direction, message, connector_name): message = '%s.%s' % (message, self.name) self.worker.processed(self.name, direction, message, connector_name) return message def handle_consume_inbound(self, message, connector_name): return self._handle('consume_inbound', message, connector_name) def handle_publish_inbound(self, message, connector_name): return self._handle('publish_inbound', message, connector_name) def handle_consume_outbound(self, message, connector_name): return self._handle('consume_outbound', message, connector_name) def handle_publish_outbound(self, message, connector_name): return self._handle('publish_outbound', message, connector_name) def handle_consume_event(self, message, connector_name): return self._handle('consume_event', message, connector_name) def handle_publish_event(self, message, connector_name): return self._handle('publish_event', message, connector_name) def handle_consume_failure(self, message, connector_name): return self._handle('consume_failure', message, connector_name) def handle_publish_failure(self, message, connector_name): return self._handle('publish_failure', message, connector_name) class TestMiddlewareStack(VumiTestCase): @inlineCallbacks def setUp(self): self.stack = MiddlewareStack([ (yield self.mkmiddleware('mw1', ToyMiddleware)), (yield self.mkmiddleware('mw2', ToyAsymmetricMiddleware)), (yield 
self.mkmiddleware('mw3', ToyMiddleware)), ]) self.processed_messages = [] @inlineCallbacks def mkmiddleware(self, name, mw_class): mw = mw_class( name, {'consume_priority': 0, 'publish_priority': 0}, self) yield mw.setup_middleware() returnValue(mw) @inlineCallbacks def mk_priority_middleware(self, name, mw_class, consume_pri, publish_pri): mw = mw_class( name, { 'consume_priority': consume_pri, 'publish_priority': publish_pri, }, self) yield mw.setup_middleware() returnValue(mw) def processed(self, name, direction, message, connector_name): self.processed_messages.append( (name, direction, message, connector_name)) def assert_processed(self, expected): self.assertEqual(expected, self.processed_messages) @inlineCallbacks def test_apply_consume_inbound(self): self.assert_processed([]) yield self.stack.apply_consume('inbound', 'dummy_msg', 'end_foo') self.assert_processed([ ('mw1', 'inbound', 'dummy_msg.mw1', 'end_foo'), ('mw2', 'consume_inbound', 'dummy_msg.mw1.mw2', 'end_foo'), ('mw3', 'inbound', 'dummy_msg.mw1.mw2.mw3', 'end_foo'), ]) @inlineCallbacks def test_apply_publish_inbound(self): self.assert_processed([]) yield self.stack.apply_publish('inbound', 'dummy_msg', 'end_foo') self.assert_processed([ ('mw3', 'inbound', 'dummy_msg.mw3', 'end_foo'), ('mw2', 'publish_inbound', 'dummy_msg.mw3.mw2', 'end_foo'), ('mw1', 'inbound', 'dummy_msg.mw3.mw2.mw1', 'end_foo'), ]) @inlineCallbacks def test_apply_consume_outbound(self): self.assert_processed([]) yield self.stack.apply_consume('outbound', 'dummy_msg', 'end_foo') self.assert_processed([ ('mw1', 'outbound', 'dummy_msg.mw1', 'end_foo'), ('mw2', 'consume_outbound', 'dummy_msg.mw1.mw2', 'end_foo'), ('mw3', 'outbound', 'dummy_msg.mw1.mw2.mw3', 'end_foo'), ]) @inlineCallbacks def test_apply_publish_outbound(self): self.assert_processed([]) yield self.stack.apply_publish('outbound', 'dummy_msg', 'end_foo') self.assert_processed([ ('mw3', 'outbound', 'dummy_msg.mw3', 'end_foo'), ('mw2', 'publish_outbound', 'dummy_msg.mw3.mw2', 'end_foo'), ('mw1', 'outbound', 'dummy_msg.mw3.mw2.mw1', 'end_foo'), ]) @inlineCallbacks def test_apply_consume_event(self): self.assert_processed([]) yield self.stack.apply_consume('event', 'dummy_msg', 'end_foo') self.assert_processed([ ('mw1', 'event', 'dummy_msg.mw1', 'end_foo'), ('mw2', 'consume_event', 'dummy_msg.mw1.mw2', 'end_foo'), ('mw3', 'event', 'dummy_msg.mw1.mw2.mw3', 'end_foo'), ]) @inlineCallbacks def test_apply_publish_event(self): self.assert_processed([]) yield self.stack.apply_publish('event', 'dummy_msg', 'end_foo') self.assert_processed([ ('mw3', 'event', 'dummy_msg.mw3', 'end_foo'), ('mw2', 'publish_event', 'dummy_msg.mw3.mw2', 'end_foo'), ('mw1', 'event', 'dummy_msg.mw3.mw2.mw1', 'end_foo'), ]) @inlineCallbacks def test_teardown_in_reverse_order(self): def get_teardown_timestamps(): return [mw._teardown_done for mw in self.stack.consume_middlewares] self.assertFalse(any(get_teardown_timestamps())) yield self.stack.teardown() self.assertTrue(all(get_teardown_timestamps())) teardown_order = sorted(self.stack.consume_middlewares, key=lambda mw: mw._teardown_done) self.assertEqual([mw.name for mw in teardown_order], ['mw3', 'mw2', 'mw1']) @inlineCallbacks def test_middleware_priority_ordering(self): self.stack = MiddlewareStack([ (yield self.mk_priority_middleware('p2', ToyMiddleware, 3, 3)), (yield self.mkmiddleware('pn', ToyMiddleware)), (yield self.mk_priority_middleware('p1_1', ToyMiddleware, 2, 2)), (yield self.mk_priority_middleware('p1_2', ToyMiddleware, 2, 2)), (yield 
self.mk_priority_middleware('pasym', ToyMiddleware, 1, 4)), ]) # test consume self.assert_processed([]) yield self.stack.apply_consume('event', 'dummy_msg', 'end_foo') self.assert_processed([ ('pn', 'event', 'dummy_msg.pn', 'end_foo'), ('pasym', 'event', 'dummy_msg.pn.pasym', 'end_foo'), ('p1_1', 'event', 'dummy_msg.pn.pasym.p1_1', 'end_foo'), ('p1_2', 'event', 'dummy_msg.pn.pasym.p1_1.p1_2', 'end_foo'), ('p2', 'event', 'dummy_msg.pn.pasym.p1_1.p1_2.p2', 'end_foo'), ]) # test publish self.processed_messages = [] yield self.stack.apply_publish('event', 'dummy_msg', 'end_foo') self.assert_processed([ ('pn', 'event', 'dummy_msg.pn', 'end_foo'), ('p1_2', 'event', 'dummy_msg.pn.p1_2', 'end_foo'), ('p1_1', 'event', 'dummy_msg.pn.p1_2.p1_1', 'end_foo'), ('p2', 'event', 'dummy_msg.pn.p1_2.p1_1.p2', 'end_foo'), ('pasym', 'event', 'dummy_msg.pn.p1_2.p1_1.p2.pasym', 'end_foo'), ]) class TestUtilityFunctions(VumiTestCase): TEST_CONFIG_1 = { "middleware": [ {"mw1": "vumi.middleware.tests.test_base.ToyMiddleware"}, {"mw2": { 'class': "vumi.middleware.tests.test_base.ToyMiddleware", 'consume_priority': 1, 'publish_priority': -1}}, ], "mw1": { "param_foo": 1, "param_bar": 2, } } TEST_YAML = """ middleware: - mw1: vumi.middleware.tests.test_base.ToyMiddleware - mw2: vumi.middleware.tests.test_base.ToyMiddleware """ def test_create_middleware_from_config(self): worker = object() middlewares = create_middlewares_from_config(worker, self.TEST_CONFIG_1) self.assertEqual([type(mw) for mw in middlewares], [ToyMiddleware, ToyMiddleware]) self.assertEqual([mw._setup_done for mw in middlewares], [False, False]) self.assertEqual(middlewares[0].config.param_foo, 1) self.assertEqual(middlewares[0].config.param_bar, 2) self.assertEqual(middlewares[0].consume_priority, 0) self.assertEqual(middlewares[0].publish_priority, 0) self.assertEqual(middlewares[1].consume_priority, 1) self.assertEqual(middlewares[1].publish_priority, -1) @inlineCallbacks def test_setup_middleware_from_config(self): worker = object() middlewares = yield setup_middlewares_from_config(worker, self.TEST_CONFIG_1) self.assertEqual([type(mw) for mw in middlewares], [ToyMiddleware, ToyMiddleware]) self.assertEqual([mw._setup_done for mw in middlewares], [True, True]) self.assertEqual(middlewares[0].config.param_foo, 1) self.assertEqual(middlewares[0].config.param_bar, 2) self.assertEqual(middlewares[0].consume_priority, 0) self.assertEqual(middlewares[0].publish_priority, 0) self.assertEqual(middlewares[1].consume_priority, 1) self.assertEqual(middlewares[1].publish_priority, -1) def test_parse_yaml(self): # this test is here to ensure the YAML one has to # type looks nice worker = object() config = yaml.safe_load(self.TEST_YAML) middlewares = create_middlewares_from_config(worker, config) self.assertEqual([type(mw) for mw in middlewares], [ToyMiddleware, ToyMiddleware]) self.assertEqual([mw._setup_done for mw in middlewares], [False, False]) @inlineCallbacks def test_sort_by_priority(self): priority2 = ToyMiddleware('priority2', {}, self) priority2.priority = 2 priority1_1 = ToyMiddleware('priority1_1', {}, self) priority1_1.priority = 1 priority1_2 = ToyMiddleware('priority1_2', {}, self) priority1_2.priority = 1 middlewares = [priority2, priority1_1, priority1_2] for mw in middlewares: yield mw.setup_middleware() mw_sorted = MiddlewareStack._sort_by_priority(middlewares, 'priority') self.assertEqual( mw_sorted, [priority1_1, priority1_2, priority2]) 
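# Illustrative usage sketch (an assumption, not part of the tests above):
# building and applying a middleware stack from the same config shape that
# TEST_CONFIG_1 exercises. Lower priority values are applied earlier on the
# consume path, and the publish path runs the stack in the opposite
# direction. `ExampleWorker` and `run_example_stack` are hypothetical names
# introduced here; ToyMiddleware reports each hop via worker.processed().
from twisted.internet.defer import inlineCallbacks

from vumi.middleware.base import (
    MiddlewareStack, setup_middlewares_from_config)


class ExampleWorker(object):
    def processed(self, name, direction, message, connector_name):
        # ToyMiddleware calls this for every message it handles.
        print "%s handled %s %r for %s" % (
            name, direction, message, connector_name)


EXAMPLE_CONFIG = {
    "middleware": [
        {"mw_early": "vumi.middleware.tests.test_base.ToyMiddleware"},
        {"mw_late": {
            "class": "vumi.middleware.tests.test_base.ToyMiddleware",
            "consume_priority": 1,
            "publish_priority": -1,
        }},
    ],
}


@inlineCallbacks
def run_example_stack():
    # Create, set up and exercise the stack, then tear it down again.
    middlewares = yield setup_middlewares_from_config(
        ExampleWorker(), EXAMPLE_CONFIG)
    stack = MiddlewareStack(middlewares)
    yield stack.apply_consume('inbound', 'dummy_msg', 'example_connector')
    yield stack.teardown()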
PK=JG!vumi/middleware/tests/__init__.pyPK=JGG$G$-vumi/middleware/tests/test_provider_setter.py"""Tests for vumi.middleware.provider_setter.""" from vumi.middleware.provider_setter import ( StaticProviderSettingMiddleware, AddressPrefixProviderSettingMiddleware, ProviderSettingMiddlewareError) from vumi.tests.helpers import VumiTestCase, MessageHelper
class TestStaticProviderSettingMiddleware(VumiTestCase): def setUp(self): self.msg_helper = self.add_helper(MessageHelper())
def mk_middleware(self, config): dummy_worker = object() mw = StaticProviderSettingMiddleware( "static_provider_setter", config, dummy_worker) mw.setup_middleware() return mw
def test_set_provider_on_inbound_if_unset(self): """ The statically configured provider value is set on inbound messages that have no provider. """ mw = self.mk_middleware({"provider": "MY-MNO"}) msg = self.msg_helper.make_inbound(None) self.assertEqual(msg.get("provider"), None) processed_msg = mw.handle_inbound(msg, "dummy_connector") self.assertEqual(processed_msg.get("provider"), "MY-MNO")
def test_replace_provider_on_inbound_if_set(self): """ The statically configured provider value replaces any existing provider a message may already have set. """ mw = self.mk_middleware({"provider": "MY-MNO"}) msg = self.msg_helper.make_inbound(None, provider="YOUR-MNO") self.assertEqual(msg.get("provider"), "YOUR-MNO") processed_msg = mw.handle_inbound(msg, "dummy_connector") self.assertEqual(processed_msg.get("provider"), "MY-MNO")
def test_set_provider_on_outbound_if_unset(self): """ The statically configured provider value is also set on outbound messages that have no provider. """ mw = self.mk_middleware({"provider": "MY-MNO"}) msg = self.msg_helper.make_outbound(None) self.assertEqual(msg.get("provider"), None) processed_msg = mw.handle_outbound(msg, "dummy_connector") self.assertEqual(processed_msg.get("provider"), "MY-MNO")
class TestAddressPrefixProviderSettingMiddleware(VumiTestCase): def setUp(self): self.msg_helper = self.add_helper(MessageHelper())
def mk_middleware(self, config): dummy_worker = object() mw = AddressPrefixProviderSettingMiddleware( "address_prefix_provider_setter", config, dummy_worker) mw.setup_middleware() return mw
def assert_middleware_error(self, msg): [err] = self.flushLoggedErrors(ProviderSettingMiddlewareError) self.assertEqual(str(err.value), msg)
def test_set_provider_unique_matching_prefix(self): """ If exactly one prefix matches the address, its corresponding provider value is set on the inbound message. """ mw = self.mk_middleware({"provider_prefixes": { "+123": "MY-MNO", "+124": "YOUR-MNO", }}) msg = self.msg_helper.make_inbound(None, from_addr="+12345") self.assertEqual(msg.get("provider"), None) processed_msg = mw.handle_inbound(msg, "dummy_connector") self.assertEqual(processed_msg.get("provider"), "MY-MNO")
def test_set_provider_longest_matching_prefix(self): """ If more than one prefix matches the address, the provider value for the longest matching prefix is set on the inbound message. """ mw = self.mk_middleware({"provider_prefixes": { "+12": "YOUR-MNO", "+123": "YOUR-MNO", "+1234": "YOUR-MNO", "+12345": "MY-MNO", "+123456": "YOUR-MNO", }}) msg = self.msg_helper.make_inbound(None, from_addr="+12345") self.assertEqual(msg.get("provider"), None) processed_msg = mw.handle_inbound(msg, "dummy_connector") self.assertEqual(processed_msg.get("provider"), "MY-MNO")
def test_no_provider_for_no_matching_prefix(self): """ If no prefix matches the address, the provider value will be set to ``None`` on the inbound message.
""" mw = self.mk_middleware({"provider_prefixes": { "+124": "YOUR-MNO", "+125": "YOUR-MNO", }}) msg = self.msg_helper.make_inbound(None, from_addr="+12345") self.assertEqual(msg.get("provider"), None) processed_msg = mw.handle_inbound(msg, "dummy_connector") self.assertEqual(processed_msg.get("provider"), None) def test_set_provider_no_normalize_msisdn(self): """ If exactly one prefix matches the address, its corresponding provider value is set on the inbound message. """ mw = self.mk_middleware({ "provider_prefixes": { "083": "MY-MNO", "+2783": "YOUR-MNO", }, }) msg = self.msg_helper.make_inbound(None, from_addr="0831234567") self.assertEqual(msg.get("provider"), None) processed_msg = mw.handle_inbound(msg, "dummy_connector") self.assertEqual(processed_msg.get("provider"), "MY-MNO") def test_set_provider_normalize_msisdn(self): """ If exactly one prefix matches the address, its corresponding provider value is set on the inbound message. """ mw = self.mk_middleware({ "normalize_msisdn": {"country_code": "27"}, "provider_prefixes": { "083": "YOUR-MNO", "+2783": "MY-MNO", }, }) msg = self.msg_helper.make_inbound(None, from_addr="0831234567") self.assertEqual(msg.get("provider"), None) processed_msg = mw.handle_inbound(msg, "dummy_connector") self.assertEqual(processed_msg.get("provider"), "MY-MNO") def test_set_provider_normalize_msisdn_strip_plus(self): """ If exactly one prefix matches the address, its corresponding provider value is set on the inbound message. """ mw = self.mk_middleware({ "normalize_msisdn": {"country_code": "27", "strip_plus": True}, "provider_prefixes": { "083": "YOUR-MNO", "+2783": "YOUR-MNO", "2783": "MY-MNO", }, }) msg = self.msg_helper.make_inbound(None, from_addr="0831234567") self.assertEqual(msg.get("provider"), None) processed_msg = mw.handle_inbound(msg, "dummy_connector") self.assertEqual(processed_msg.get("provider"), "MY-MNO") def test_set_provider_on_outbound(self): """ Outbound messages are left as they are. """ mw = self.mk_middleware({"provider_prefixes": {"+123": "MY-MNO"}}) msg = self.msg_helper.make_outbound( None, to_addr="+1234567", from_addr="+12345") self.assertEqual(msg.get("provider"), None) processed_msg = mw.handle_outbound(msg, "dummy_connector") self.assertEqual(processed_msg.get("provider"), "MY-MNO") def test_provider_not_overwritten_for_inbound(self): """ If a provider already exists for an inbound message, it isn't overwritten. """ mw = self.mk_middleware({"provider_prefixes": {"+123": "MY-MNO"}}) msg = self.msg_helper.make_inbound( None, to_addr="+345", from_addr="+12345", provider="OTHER-MNO") processed_msg = mw.handle_inbound(msg, "dummy_connector") self.assertEqual(processed_msg.get("provider"), "OTHER-MNO") def test_provider_not_overwritten_for_outbound(self): """ If a provider already exists for an outbound message, it isn't overwritten. """ mw = self.mk_middleware({"provider_prefixes": {"+123": "MY-MNO"}}) msg = self.msg_helper.make_outbound( None, to_addr="+1234567", from_addr="+345", provider="OTHER-MNO") processed_msg = mw.handle_outbound(msg, "dummy_connector") self.assertEqual(processed_msg.get("provider"), "OTHER-MNO") def test_provider_logs_no_address_error_for_inbound(self): """ If the from_addr of an inbound message is None, an error should be logged and the message returned. 
""" mw = self.mk_middleware({"provider_prefixes": {"+123": "MY-MNO"}}) msg = self.msg_helper.make_inbound( None, to_addr="+1234567", from_addr=None) processed_msg = mw.handle_inbound(msg, "dummy_connector") self.assertEqual(processed_msg.get("provider"), None) self.assert_middleware_error( "Address for determining message provider cannot be None," " skipping message") def test_provider_logs_no_address_error_for_outbound(self): """ If the to_addr of an outbound message is None, an error should be logged and the message returned. """ mw = self.mk_middleware({"provider_prefixes": {"+123": "MY-MNO"}}) msg = self.msg_helper.make_outbound( None, to_addr=None, from_addr="+345") processed_msg = mw.handle_outbound(msg, "dummy_connector") self.assertEqual(processed_msg.get("provider"), None) self.assert_middleware_error( "Address for determining message provider cannot be None," " skipping message") PK=JG $W W %vumi/middleware/tests/test_manhole.pyfrom twisted.trial.unittest import SkipTest from twisted.internet import defer, protocol, reactor from twisted.internet.defer import inlineCallbacks from twisted.internet.endpoints import TCP4ClientEndpoint from vumi.tests.helpers import VumiTestCase try: from twisted.conch.manhole_ssh import ConchFactory from twisted.conch.ssh import session from vumi.middleware.manhole import ManholeMiddleware from vumi.middleware.manhole_utils import ClientTransport # these are shipped along with Twisted private_key = ConchFactory.privateKeys['ssh-rsa'] public_key = ConchFactory.publicKeys['ssh-rsa'] except ImportError: ssh = False else: ssh = True class DummyWorker(object): pass class TestManholeMiddleware(VumiTestCase): def setUp(self): if not ssh: raise SkipTest('Crypto requirements missing. Skipping Test.') self.pub_key_file_name = self.mktemp() self.pub_key_file = open(self.pub_key_file_name, 'w') self.pub_key_file.write(public_key.toString('OPENSSH')) self.pub_key_file.flush() self._middlewares = [] self._client_sockets = [] self.mw = self.get_middleware({ 'authorized_keys': [self.pub_key_file.name] }) @inlineCallbacks def open_shell(self, middleware): host = middleware.socket.getHost() factory = protocol.ClientFactory() factory.protocol = ClientTransport factory.channelConnected = defer.Deferred() endpoint = TCP4ClientEndpoint(reactor, host.host, host.port) proto = yield endpoint.connect(factory) channel = yield factory.channelConnected conn = channel.conn term = session.packRequest_pty_req("vt100", (0, 0, 0, 0), '') yield conn.sendRequest(channel, 'pty-req', term, wantReply=1) yield conn.sendRequest(channel, 'shell', '', wantReply=1) self._client_sockets.append(proto) self.add_cleanup(proto.loseConnection) defer.returnValue(channel) def get_middleware(self, config={}): config = dict({ 'port': '0', }, **config) worker = DummyWorker() worker.transport_name = 'foo' mw = ManholeMiddleware("test_manhole_mw", config, worker) mw.setup_middleware() self._middlewares.append(mw) self.add_cleanup(mw.teardown_middleware) return mw @inlineCallbacks def test_mw(self): shell = yield self.open_shell(self.mw) shell.write('print worker.transport_name\n') # read the echoed line we sent first, this is hard to test because # I'm not seeing how I can tell Twisted not to use color in the # returned response. 
yield shell.queue.get() # next is the server response received_line = yield shell.queue.get() self.assertEqual(received_line, 'foo') PK=JGC--vumi/middleware/tests/test_message_storing.py"""Tests for vumi.middleware.message_storing.""" from twisted.internet.defer import inlineCallbacks, returnValue from vumi.middleware.tagger import TaggingMiddleware from vumi.message import TransportUserMessage, TransportEvent from vumi.tests.helpers import VumiTestCase, PersistenceHelper class TestStoringMiddleware(VumiTestCase): def setUp(self): self.persistence_helper = self.add_helper( PersistenceHelper(use_riak=True)) # Create and stash a riak manager to clean up afterwards, because we # don't get access to the one inside the middleware. manager = self.persistence_helper.get_riak_manager() self.add_cleanup(manager.close_manager) @inlineCallbacks def setup_middleware(self, config={}): # We've already skipped the test by now if we don't have riakasaurus, # so it's safe to import stuff that pulls it in without guards. from vumi.middleware.message_storing import StoringMiddleware config = self.persistence_helper.mk_config(config) dummy_worker = object() mw = StoringMiddleware("dummy_storer", config, dummy_worker) self.add_cleanup(mw.teardown_middleware) yield mw.setup_middleware() self.store = mw.store self.add_cleanup(self.store.manager.purge_all) yield self.store.manager.purge_all() yield self.store.cache.redis._purge_all() # just in case returnValue(mw) def mk_msg(self): msg = TransportUserMessage(to_addr="45678", from_addr="12345", transport_name="dummy_connector", transport_type="dummy_transport_type") return msg def mk_ack(self, user_message_id="1"): ack = TransportEvent(event_type="ack", user_message_id=user_message_id, sent_message_id="1") return ack @inlineCallbacks def assert_batch_keys(self, batch_id, outbound=[], inbound=[]): outbound_keys = yield self.store.batch_outbound_keys(batch_id) self.assertEqual(sorted(outbound_keys), sorted(outbound)) inbound_keys = yield self.store.batch_inbound_keys(batch_id) self.assertEqual(sorted(inbound_keys), sorted(inbound)) @inlineCallbacks def assert_outbound_stored(self, msg, batch_id=None, events=[]): msg_id = msg['message_id'] stored_msg = yield self.store.get_outbound_message(msg_id) self.assertEqual(stored_msg, msg) event_keys = yield self.store.message_event_keys(msg_id) self.assertEqual(sorted(event_keys), sorted(events)) if batch_id is not None: yield self.assert_batch_keys(batch_id, outbound=[msg_id]) @inlineCallbacks def assert_outbound_not_stored(self, msg): msg_id = msg['message_id'] stored_msg = yield self.store.get_outbound_message(msg_id) self.assertEqual(stored_msg, None) @inlineCallbacks def assert_inbound_stored(self, msg, batch_id=None): msg_id = msg['message_id'] stored_msg = yield self.store.get_inbound_message(msg_id) self.assertEqual(stored_msg, msg) if batch_id is not None: yield self.assert_batch_keys(batch_id, inbound=[msg_id]) @inlineCallbacks def assert_inbound_not_stored(self, msg): msg_id = msg['message_id'] stored_msg = yield self.store.get_inbound_message(msg_id) self.assertEqual(stored_msg, None) @inlineCallbacks def test_handle_outbound(self): mw = yield self.setup_middleware() msg1 = self.mk_msg() resp1 = yield mw.handle_consume_outbound(msg1, "dummy_connector") self.assertEqual(msg1, resp1) yield self.assert_outbound_stored(msg1) msg2 = self.mk_msg() resp2 = yield mw.handle_publish_outbound(msg2, "dummy_connector") self.assertEqual(msg2, resp2) yield self.assert_outbound_stored(msg2) 
self.assertNotEqual(msg1['message_id'], msg2['message_id']) @inlineCallbacks def test_handle_outbound_no_consume_store(self): mw = yield self.setup_middleware({'store_on_consume': False}) msg1 = self.mk_msg() resp1 = yield mw.handle_consume_outbound(msg1, "dummy_connector") self.assertEqual(msg1, resp1) yield self.assert_outbound_not_stored(msg1) msg2 = self.mk_msg() resp2 = yield mw.handle_publish_outbound(msg2, "dummy_connector") self.assertEqual(msg2, resp2) yield self.assert_outbound_stored(msg2) self.assertNotEqual(msg1['message_id'], msg2['message_id']) @inlineCallbacks def test_handle_outbound_with_tag(self): mw = yield self.setup_middleware() batch_id = yield self.store.batch_start([("pool", "tag")]) msg = self.mk_msg() TaggingMiddleware.add_tag_to_msg(msg, ["pool", "tag"]) response = yield mw.handle_outbound(msg, "dummy_connector") self.assertEqual(response, msg) yield self.assert_outbound_stored(msg, batch_id) @inlineCallbacks def test_handle_inbound(self): mw = yield self.setup_middleware() msg1 = self.mk_msg() resp1 = yield mw.handle_consume_inbound(msg1, "dummy_connector") self.assertEqual(resp1, msg1) yield self.assert_inbound_stored(msg1) msg2 = self.mk_msg() resp2 = yield mw.handle_publish_inbound(msg2, "dummy_connector") self.assertEqual(resp2, msg2) yield self.assert_inbound_stored(msg2) self.assertNotEqual(msg1['message_id'], msg2['message_id']) @inlineCallbacks def test_handle_inbound_no_consume_store(self): mw = yield self.setup_middleware({'store_on_consume': False}) msg1 = self.mk_msg() resp1 = yield mw.handle_consume_inbound(msg1, "dummy_connector") self.assertEqual(resp1, msg1) yield self.assert_inbound_not_stored(msg1) msg2 = self.mk_msg() resp2 = yield mw.handle_publish_inbound(msg2, "dummy_connector") self.assertEqual(resp2, msg2) yield self.assert_inbound_stored(msg2) self.assertNotEqual(msg1['message_id'], msg2['message_id']) @inlineCallbacks def test_handle_inbound_with_tag(self): mw = yield self.setup_middleware() batch_id = yield self.store.batch_start([("pool", "tag")]) msg = self.mk_msg() TaggingMiddleware.add_tag_to_msg(msg, ["pool", "tag"]) response = yield mw.handle_inbound(msg, "dummy_connector") self.assertEqual(response, msg) yield self.assert_inbound_stored(msg, batch_id) @inlineCallbacks def test_handle_event(self): mw = yield self.setup_middleware() msg = self.mk_msg() msg_id = msg["message_id"] yield self.store.add_outbound_message(msg) ack1 = self.mk_ack(user_message_id=msg_id) event_id1 = ack1['event_id'] resp1 = yield mw.handle_consume_event(ack1, "dummy_connector") self.assertEqual(ack1, resp1) yield self.assert_outbound_stored(msg, events=[event_id1]) ack2 = self.mk_ack(user_message_id=msg_id) event_id2 = ack2['event_id'] resp2 = yield mw.handle_publish_event(ack2, "dummy_connector") self.assertEqual(ack2, resp2) yield self.assert_outbound_stored(msg, events=[event_id1, event_id2]) @inlineCallbacks def test_handle_event_no_consume_store(self): mw = yield self.setup_middleware({'store_on_consume': False}) msg = self.mk_msg() msg_id = msg["message_id"] yield self.store.add_outbound_message(msg) ack1 = self.mk_ack(user_message_id=msg_id) resp1 = yield mw.handle_consume_event(ack1, "dummy_connector") self.assertEqual(resp1, ack1) yield self.assert_outbound_stored(msg, events=[]) ack2 = self.mk_ack(user_message_id=msg_id) event_id2 = ack2['event_id'] resp2 = yield mw.handle_publish_event(ack2, "dummy_connector") self.assertEqual(resp2, ack2) yield self.assert_outbound_stored(msg, events=[event_id2]) 
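# Illustrative config sketch (an assumption, not drawn from the tests
# above): attaching the storing middleware to a worker using the same
# config shape as other middleware in this package. 'store_on_consume'
# mirrors the option exercised above and defaults to True; disabling it
# defers storage to the publish side only. Persistence settings (e.g. a
# 'riak_manager' section) are omitted here and would normally be supplied
# alongside.
EXAMPLE_WORKER_CONFIG = {
    "middleware": [
        {"message_storer":
            "vumi.middleware.message_storing.StoringMiddleware"},
    ],
    "message_storer": {
        "store_on_consume": False,
    },
}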
PK=JGNj**0vumi/middleware/tests/test_address_translator.py"""Tests for vumi.middleware.address_translator.""" from vumi.middleware.address_translator import AddressTranslationMiddleware from vumi.tests.helpers import VumiTestCase
class TestAddressTranslationMiddleware(VumiTestCase): def mk_addr_trans(self, outbound_map): worker = object() config = {'outbound_map': outbound_map} mw = AddressTranslationMiddleware("test_addr_trans", config, worker) mw.setup_middleware() return mw
def mk_msg(self, to_addr='unknown', from_addr='unknown'): return { 'to_addr': to_addr, 'from_addr': from_addr, }
def test_handle_outbound(self): mw = self.mk_addr_trans({'555OUT': '555IN'}) msg = mw.handle_outbound(self.mk_msg(to_addr="555OUT"), "outbound") self.assertEqual(msg['to_addr'], "555IN") msg = mw.handle_outbound(self.mk_msg(to_addr="555UNK"), "outbound") self.assertEqual(msg['to_addr'], "555UNK")
def test_handle_inbound(self): mw = self.mk_addr_trans({'555OUT': '555IN'}) msg = mw.handle_inbound(self.mk_msg(from_addr="555IN"), "inbound") self.assertEqual(msg['from_addr'], "555OUT") msg = mw.handle_inbound(self.mk_msg(from_addr="555UNK"), "inbound") self.assertEqual(msg['from_addr'], "555UNK")
PK=JGSVvumi/demos/tictactoe.py# -*- test-case-name: vumi.demos.tests.test_tictactoe -*- """Simple Tic Tac Toe game.""" from twisted.internet.defer import inlineCallbacks from twisted.python import log from vumi.application import ApplicationWorker
class TicTacToeGame(object): def __init__(self, player_X): self.board = [ [' ', ' ', ' '], [' ', ' ', ' '], [' ', ' ', ' '], ] self.player_X = player_X self.player_O = None
def set_player_O(self, player_O): self.player_O = player_O
def draw_line(self): return "+---+---+---+"
def draw_row(self, row): return '| %s | %s | %s |' % tuple(row)
def draw_board(self): return '\n'.join(['\n'.join([self.draw_line(), self.draw_row(r)]) for r in self.board] + [self.draw_line()])
def _move(self, val, x, y): if self.board[y][x] != ' ': return False self.board[y][x] = val return True
def move(self, sid, x, y): if sid == self.player_X: return self._move('X', x, y), self.player_O if sid == self.player_O: return self._move('O', x, y), self.player_X
def check_line(self, v1, v2, v3): if (v1 != ' ') and (v1 == v2) and (v1 == v3): return v1 return False
def check_win(self): for l in [((0, 0), (0, 1), (0, 2)), ((1, 0), (1, 1), (1, 2)), ((2, 0), (2, 1), (2, 2)), ((0, 0), (1, 0), (2, 0)), ((0, 1), (1, 1), (2, 1)), ((0, 2), (1, 2), (2, 2)), ((0, 0), (1, 1), (2, 2)), ((0, 2), (1, 1), (2, 0))]: ll = [self.board[y][x] for x, y in l] result = self.check_line(*ll) if result: return result return False
def check_draw(self): for x in [0, 1, 2]: for y in [0, 1, 2]: if self.board[y][x] == ' ': return False return True
class TicTacToeWorker(ApplicationWorker): @inlineCallbacks def startWorker(self): """Initialise game- and message-tracking state for the worker.""" yield super(TicTacToeWorker, self).startWorker() self.games = {} self.open_game = None self.messages = {}
def reply(self, player, content, continue_session=True): orig = self.messages.pop(player, None) if orig is None: log.msg("Can't reply to %r, no stored message."
% player) return return self.reply_to(orig, content, continue_session=continue_session) def end(self, player, content): return self.reply(player, content, continue_session=False) def new_session(self, msg): log.msg("New session:", msg) log.msg("Open game:", self.open_game) log.msg("Games:", self.games) user_id = msg.user() self.messages[user_id] = msg if self.open_game: game = self.open_game game.set_player_O(user_id) self.open_game = None self.reply(game.player_X, game.draw_board()) else: game = TicTacToeGame(user_id) self.open_game = game self.games[user_id] = game def close_session(self, msg): log.msg("Close session:", msg) user_id = msg.user() game = self.games.get(user_id) if game: if self.open_game == game: self.open_game = None for uid in (game.player_X, game.player_O): if uid is not None: self.games.pop(uid, None) self.end(uid, "Other side timed out.") def consume_user_message(self, msg): log.msg("Resume session:", msg) user_id = msg.user() self.messages[user_id] = msg if user_id not in self.games: return game = self.games[user_id] move = self.parse_move(msg['content']) if move is None: self.end(game.player_X, "Cheerio.") self.end(game.player_O, "Cheerio.") return log.msg("Move:", move) resp, other_uid = game.move(user_id, *move) if game.check_win(): self.end(user_id, "You won!") self.end(other_uid, "You lost!") return if game.check_draw(): self.end(user_id, "Draw. :-(") self.end(other_uid, "Draw. :-(") return self.reply(other_uid, game.draw_board()) def parse_move(self, move): moves = { '1': (0, 0), '2': (1, 0), '3': (2, 0), '4': (0, 1), '5': (1, 1), '6': (2, 1), '7': (0, 2), '8': (1, 2), '9': (2, 2), } if move[0] in moves: return moves[move[0]] return None PK=JGRw vumi/demos/calculator.py# -*- test-case-name: vumi.demos.tests.test_calculator -*- import operator from vumi.application import ApplicationWorker def mk_menu(preamble, options): return '\n'.join( [preamble] + ['%s. 
%s' % (idx, action) for (idx, (action, op)) in enumerate(options, 1)]) class CalculatorApp(ApplicationWorker): def setup_application(self): self._sessions = {} self.actions = [ ('Add', operator.add), ('Subtract', operator.sub), ('Multiply', operator.mul), ] def teardown_application(self): pass def save_session(self, user_id, data): self._sessions[user_id] = data return data def get_session(self, user_id): return self._sessions.get(user_id, {}) def clear_session(self, user_id): return self.save_session(user_id, {}) def new_session(self, message): self.clear_session(message.user()) return self.reply_to(message, mk_menu( 'What would you like to do?', self.actions)) def close_session(self, message): self.clear_session(message.user()) def consume_user_message(self, message): user_id = message.user() session = self.get_session(user_id) try: numeric_input = int(message['content']) except (ValueError, TypeError): self.clear_session(message.user()) return self.reply_to( message, 'Sorry invalid input!', continue_session=False) if 'action' not in session: action_index = numeric_input - 1 try: action, op = self.actions[action_index] except IndexError: return self.new_session(message) session['action'] = action_index d = self.reply_to(message, 'What is the first number?') d.addCallback(lambda *a: self.save_session(user_id, session)) return d if 'first_number' not in session: session['first_number'] = numeric_input d = self.reply_to(message, 'What is the second number?') d.addCallback(lambda *a: self.save_session(user_id, session)) return d if 'second_number' not in session: session['second_number'] = numeric_input result = self.calculate(session) d = self.reply_to(message, result, continue_session=False) d.addCallback(lambda *a: self.save_session(user_id, session)) return d def calculate(self, session): action, op = self.actions[session['action']] result = op(session['first_number'], session['second_number']) return 'The result is: %s.' % (result,) PK=JGvumi/demos/__init__.pyPK=JGUvumi/demos/hangman.py# -*- test-case-name: vumi.demos.tests.test_hangman -*- import string from twisted.internet.defer import inlineCallbacks, returnValue from twisted.python import log from vumi.application import ApplicationWorker from vumi.utils import http_request from vumi.components.session import SessionManager from vumi.config import ConfigText, ConfigDict class HangmanGame(object): """Represents a game of Hangman. Parameters ---------- word : str Word to guess. guesses : set, optional Characters guessed so far. If None, defaults to the empty set. msg : str, optional Message set in reply to last user action. Defaults to 'New game!'. """ UI_TEMPLATE = ( u"%(msg)s\n" u"Word: %(word)s\n" u"Letters guessed so far: %(guesses)s\n" u"%(prompt)s (0 to quit):\n") # exit codes NOT_DONE, DONE, DONE_WANTS_NEW = range(3) def __init__(self, word, guesses=None, msg="New game!"): self.word = word self.guesses = guesses if guesses is not None else set() self.msg = msg self.exit_code = self.NOT_DONE def state(self): """Return the game state as a dict.""" return { 'guesses': u"".join(sorted(self.guesses)), 'word': self.word, 'msg': self.msg, } @classmethod def from_state(cls, state): return cls(word=state['word'], guesses=set(state['guesses']), msg=state['msg']) def event(self, message): """Handle an user input string. Parameters ---------- message : unicode Message received from user. """ message = message.lower() if not message: self.msg = u"Some input required please." 
elif len(message) > 1: self.msg = u"Single characters only please." elif message == '0': self.exit_code = self.DONE self.msg = u"Game ended." elif self.won(): self.exit_code = self.DONE_WANTS_NEW elif message not in string.lowercase: self.msg = u"Letters of the alphabet only please." elif message in self.guesses: self.msg = u"You've already guessed '%s'." % (message,) else: assert len(message) == 1 self.guesses.add(message) log.msg("Message: %r, word: %r" % (message, self.word)) if message in self.word: self.msg = u"Word contains at least one '%s'! :D" % (message,) else: self.msg = u"Word contains no '%s'. :(" % (message,) if self.won(): self.msg = self.victory_message() def victory_message(self): uniques = len(set(self.word)) guesses = len(self.guesses) for factor, msg in [ (1, u"Flawless victory!"), (1.5, u"Epic victory!"), (2, u"Standard victory!"), (3, u"Sub-par victory!"), (4, u"Random victory!")]: if guesses <= uniques * factor: return msg return u"Button mashing!" def won(self): return all(x in self.guesses for x in self.word) def draw_board(self): """Return a text-based UI.""" if self.exit_code != self.NOT_DONE: return u"Adieu!" word = u"".join((x if x in self.guesses else '_') for x in self.word) guesses = "".join(sorted(self.guesses)) if self.won(): prompt = u"Enter anything to start a new game" else: prompt = u"Enter next guess" return self.UI_TEMPLATE % {'word': word, 'guesses': guesses, 'msg': self.msg, 'prompt': prompt, } class HangmanConfig(ApplicationWorker.CONFIG_CLASS): "Hangman worker config." worker_name = ConfigText( "Name of this hangman worker.", required=True, static=True) redis_manager = ConfigDict( "Redis client configuration.", default={}, static=True) random_word_url = ConfigText( "URL to GET a random word from.", required=True) class HangmanWorker(ApplicationWorker): """Worker that plays Hangman. """ CONFIG_CLASS = HangmanConfig @inlineCallbacks def setup_application(self): """Start the worker""" config = self.get_static_config() # Connect to Redis r_prefix = "hangman_game:%s:%s" % ( config.transport_name, config.worker_name) self.session_manager = yield SessionManager.from_redis_config( config.redis_manager, r_prefix) @inlineCallbacks def teardown_application(self): yield self.session_manager.stop() def random_word(self, random_word_url): log.msg('Fetching random word from %s' % (random_word_url,)) d = http_request(random_word_url, None, method='GET') def _decode(word): # result from http_request should always be bytes # convert to unicode, strip BOMs and whitespace word = word.decode("utf-8", "ignore") word = word.lstrip(u'\ufeff\ufffe') word = word.strip() return word return d.addCallback(_decode) def game_key(self, user_id): """Key for looking up a users game in data store.""" return user_id.lstrip('+') @inlineCallbacks def load_game(self, msisdn): """Fetch a game for the given user ID. """ game_key = self.game_key(msisdn) state = yield self.session_manager.load_session(game_key) if state: game = HangmanGame.from_state(state) else: game = None returnValue(game) @inlineCallbacks def new_game(self, msisdn, random_word_url): """Create a new game for the given user ID. 
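The word is fetched from ``random_word_url``, stripped, lower-cased and persisted as a new session keyed on the player's msisdn. A usage sketch (msisdn and URL hypothetical)::

    game = yield self.new_game('+27831234567', 'http://example.com/word')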
""" word = yield self.random_word(random_word_url) word = word.strip().lower() game = HangmanGame(word) game_key = self.game_key(msisdn) yield self.session_manager.create_session(game_key, **game.state()) returnValue(game) def save_game(self, msisdn, game): """Save the game state for the given game.""" game_key = self.game_key(msisdn) state = game.state() return self.session_manager.save_session(game_key, state) def delete_game(self, msisdn): """Delete the users saved game.""" game_key = self.game_key(msisdn) return self.session_manager.clear_session(game_key) @inlineCallbacks def consume_user_message(self, msg): """Find or create a hangman game for this player. Then process the user's message. """ log.msg("User message: %s" % msg['content']) user_id = msg.user() config = yield self.get_config(msg) game = yield self.load_game(user_id) if game is None: game = yield self.new_game(user_id, config.random_word_url) if msg['content'] is None: # probably new session -- just send the user the board self.reply_to(msg, game.draw_board(), True) return message = msg['content'].strip() game.event(message) continue_session = True if game.exit_code == game.DONE: yield self.delete_game(user_id) continue_session = False elif game.exit_code == game.DONE_WANTS_NEW: game = yield self.new_game(user_id, config.random_word_url) else: yield self.save_game(user_id, game) self.reply_to(msg, game.draw_board(), continue_session) def close_session(self, msg): """We ignore session closing and wait for the user to return.""" pass PK=JG\p p vumi/demos/ircbot.py# -*- test-case-name: vumi.demos.tests.test_ircbot -*- """Demo workers for constructing a simple IRC bot.""" import re import json from twisted.internet.defer import inlineCallbacks, returnValue from vumi import log from vumi.application import ApplicationWorker from vumi.persist.txredis_manager import TxRedisManager from vumi.config import ConfigText, ConfigDict class MemoConfig(ApplicationWorker.CONFIG_CLASS): "Memo worker config." worker_name = ConfigText("Name of worker.", required=True, static=True) redis_manager = ConfigDict( "Redis client configuration.", default={}, static=True) class MemoWorker(ApplicationWorker): """Watches for memos to users and notifies users of memos when users appear. 
""" CONFIG_CLASS = MemoConfig MEMO_RE = re.compile(r'^\S+ tell (\S+) (.*)$') @inlineCallbacks def setup_application(self): config = self.get_static_config() r_prefix = "ircbot:memos:%s" % (config.worker_name,) redis = yield TxRedisManager.from_config(config.redis_manager) self.redis = redis.sub_manager(r_prefix) def teardown_application(self): return self.redis._close() def rkey_memo(self, channel, recipient): return "%s:%s" % (channel, recipient) def store_memo(self, channel, recipient, sender, text): memo_key = self.rkey_memo(channel, recipient) value = json.dumps([sender, text]) return self.redis.rpush(memo_key, value) @inlineCallbacks def retrieve_memos(self, channel, recipient, delete=False): memo_key = self.rkey_memo(channel, recipient) memos = yield self.redis.lrange(memo_key, 0, -1) if delete: yield self.redis.delete(memo_key) returnValue([json.loads(value) for value in memos]) @inlineCallbacks def consume_user_message(self, msg): """Log message from a user.""" nickname = msg.user() irc_metadata = msg['helper_metadata'].get('irc', {}) channel = irc_metadata.get('irc_channel', 'unknown') addressed_to = irc_metadata.get('addressed_to_transport', True) if addressed_to: yield self.process_potential_memo(channel, nickname, msg) memos = yield self.retrieve_memos(channel, nickname, delete=True) if memos: log.msg("Time to deliver some memos:", memos) for memo_sender, memo_text in memos: self.reply_to(msg, "%s, %s asked me tell you: %s" % (nickname, memo_sender, memo_text)) @inlineCallbacks def process_potential_memo(self, channel, sender, msg): match = self.MEMO_RE.match(msg['content']) if match: recipient = match.group(1).lower() memo_text = match.group(2) yield self.store_memo(channel, recipient, sender, memo_text) self.reply_to(msg, "%s: Sure thing, boss." % (sender,)) PK=JGG*vumi/demos/words.py# -*- test-case-name: vumi.demos.tests.test_words -*- """Demo ApplicationWorkers that perform simple text manipulations.""" from twisted.python import log from vumi.application import ApplicationWorker class SimpleAppWorker(ApplicationWorker): """Base class for very simple application workers. Configuration ------------- transport_name : str Name of the transport. """ def consume_user_message(self, msg): """Find or create a hangman game for this player. Then process the user's message. 
""" content = msg['content'].encode('utf-8') if msg['content'] else None log.msg("User message: %s" % content) text = msg['content'] if text is None: reply = self.get_help() else: reply = self.process_message(text) return self.reply_to(msg, reply) def process_message(self, text): raise NotImplementedError("Sub-classes should implement" " process_message.") def get_help(self): return "Enter text:" class EchoWorker(SimpleAppWorker): """Echos text back to the sender.""" def process_message(self, data): return data def get_help(self): return "Enter text to echo:" class ReverseWorker(SimpleAppWorker): """Replies with reversed text.""" def process_message(self, data): return data[::-1] def get_help(self): return "Enter text to reverse:" class WordCountWorker(SimpleAppWorker): """Returns word and letter counts to the sender.""" def process_message(self, data): response = [] words = len(data.split()) response.append("%s word%s" % (words, "s" * (words != 1))) chars = len(data) response.append("%s char%s" % (chars, "s" * (chars != 1))) return ', '.join(response) def get_help(self): return "Enter text to return word and character counts for:" PK=JGρFk==vumi/demos/rps.py# -*- test-case-name: vumi.demos.tests.test_rps -*- from twisted.internet.defer import inlineCallbacks from twisted.python import log from vumi.application import ApplicationWorker class MultiPlayerGameWorker(ApplicationWorker): @inlineCallbacks def startWorker(self): """docstring for startWorker""" self.games = {} self.open_game = None self.game_setup() self.messages = {} yield super(MultiPlayerGameWorker, self).startWorker() def new_session(self, msg): log.msg("New session:", msg) log.msg("Open game:", self.open_game) log.msg("Games:", self.games) user_id = msg.user() self.messages[user_id] = msg if self.open_game: game = self.open_game if not self.add_player_to_game(game, user_id): self.open_game = None else: game = self.create_new_game(user_id) self.open_game = game self.games[user_id] = game def close_session(self, msg): log.msg("Close session:", msg) user_id = msg.user() game = self.games.get(user_id) if game: if self.open_game == game: self.open_game = None self.clean_up_game(game) for uid, sgame in self.games.items(): if game == sgame: self.games.pop(uid, None) msg = "Game terminated due to remote player disconnect." self.end(uid, msg) self.messages.pop(user_id, None) def consume_user_message(self, msg): log.msg("Resume session:", msg) user_id = msg.user() self.messages[user_id] = msg if user_id not in self.games: return game = self.games[user_id] self.continue_game(game, user_id, msg['content']) def game_setup(self): pass def create_new_game(self, session_id): pass def add_player_to_game(self, game, session_id): pass def clean_up_game(self, game): pass def continue_game(self, game, session_id, message): pass def reply(self, player, content, continue_session=True): orig = self.messages.pop(player, None) if orig is None: log.msg("Can't reply to %r, no stored message." 
% player) return return self.reply_to(orig, content, continue_session=continue_session) def end(self, player, content): return self.reply(player, content, continue_session=False) class RockPaperScissorsGame(object): def __init__(self, best_of, player_1): self.best_of = best_of self.player_1 = player_1 self.player_2 = None self.current_move = None self.scores = (0, 0) self.last_result = None def set_player_2(self, player_2): self.player_2 = player_2 def get_other_player(self, sid): if sid == self.player_1: return self.player_2 return self.player_1 def draw_board(self, sid): scores = self.scores if sid == self.player_2: scores = tuple(reversed(scores)) result = [] if self.last_result == (0, 0): result.append("Draw.") elif self.last_result == (1, 0): if sid == self.player_1: result.append("You won!") else: result.append("You lost!") elif self.last_result == (0, 1): if sid == self.player_1: result.append("You lost!") else: result.append("You won!") result.extend([ 'You: %s, opponent: %s' % scores, '1. rock', '2. paper', '3. scissors', ]) return '\n'.join(result) def move(self, sid, choice): if not self.current_move: self.current_move = choice return False if sid == self.player_1: player_1, player_2 = choice, self.current_move else: player_1, player_2 = self.current_move, choice self.current_move = None result = self.decide(player_1, player_2) self.last_result = result self.scores = ( self.scores[0] + result[0], self.scores[1] + result[1], ) return True def decide(self, player_1, player_2): return { (1, 1): (0, 0), (1, 2): (1, 0), (1, 3): (0, 1), (2, 1): (0, 1), (2, 2): (0, 0), (2, 3): (1, 0), (3, 1): (1, 0), (3, 2): (0, 1), (3, 3): (0, 0), }[(player_1, player_2)] def check_win(self): return sum(self.scores) >= self.best_of class RockPaperScissorsWorker(MultiPlayerGameWorker): def create_new_game(self, session_id): return RockPaperScissorsGame(5, session_id) def add_player_to_game(self, game, session_id): game.set_player_2(session_id) self.turn_reply(game) return False def clean_up_game(self, game): pass def continue_game(self, game, session_id, message): move = self.parse_move(message) if move is None: self.end(session_id, 'You disconnected.') self.end(game.get_other_player(session_id), 'Your opponent disconnected.') return if game.move(session_id, move): self.turn_reply(game) def parse_move(self, message): char = (message + ' ')[0] if char in '123': return int(char) return None @inlineCallbacks def turn_reply(self, game): if game.check_win(): if game.scores[0] > game.scores[1]: self.end(game.player_1, "You won! :-)") self.end(game.player_2, "You lost. :-(") else: self.end(game.player_1, "You lost. :-(") self.end(game.player_2, "You won! :-)") return yield self.reply(game.player_1, game.draw_board(game.player_1)) yield self.reply(game.player_2, game.draw_board(game.player_2)) PK=JG]LJvumi/demos/static_reply.py# -*- test-case-name: vumi.demos.tests.test_static_reply -*- from datetime import date from twisted.internet.defer import succeed, inlineCallbacks from vumi.application import ApplicationWorker from vumi.config import ConfigText class StaticReplyConfig(ApplicationWorker.CONFIG_CLASS): reply_text = ConfigText( "Reply text to send in response to inbound messages.", static=False, default="Hello {user} at {now}.") class StaticReplyApplication(ApplicationWorker): """ Application that replies to incoming messages with a configured response. 
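``reply_text`` is formatted with the sender's address as ``{user}`` and today's date as ``{now}``; with ``reply_text: 'Hi {user}'``, an inbound message from ``+27831234567`` gets the reply ``Hi +27831234567`` and the session is closed.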
""" CONFIG_CLASS = StaticReplyConfig @inlineCallbacks def consume_user_message(self, message): config = yield self.get_config(message) yield self.reply_to( message, config.reply_text.format( user=message.user(), now=date.today()), continue_session=False) PK=JGYTV %vumi/demos/tests/wikipedia_sample.xml africa
Africa
Africa is the world's second largest and second most populous continent, after Asia.
http://en.wikipedia.org/wiki/Africa

Race and ethnicity in the United States Census
Race and ethnicity in the United States Census, as defined by the Federal Office of Management and Budget (OMB) and the United States Census Bureau, are self-identification data items in which residents choose the race or races with which they most closely identify, and indicate whether or not they are of Hispanic or Latino origin (ethnicity).
http://en.wikipedia.org/wiki/Race_and_ethnicity_in_the_United_States_Census

African American
African Americans (also referred to as Black Americans or Afro-Americans, and formerly as American Negroes) are citizens or residents of the United States who have origins in any of the black populations of Africa.
http://en.wikipedia.org/wiki/African_American

African people
African people refers to natives, inhabitants, or citizen of Africa and to people of African descent.
http://en.wikipedia.org/wiki/African_people
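A new text-manipulation worker only needs to subclass ``SimpleAppWorker`` (defined in vumi/demos/words.py above) and implement ``process_message`` plus, optionally, ``get_help``. A minimal sketch; the ``ShoutWorker`` name and behaviour are illustrative, not part of vumi:

from vumi.demos.words import SimpleAppWorker


class ShoutWorker(SimpleAppWorker):
    """Replies with upper-cased text."""

    def process_message(self, data):
        # data is the inbound message content (unicode)
        return data.upper()

    def get_help(self):
        # sent when a message arrives with no content
        return "Enter text to shout:"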
PK=JGҿIIvumi/demos/tests/test_words.py# -*- coding: utf-8 -*- """Tests for vumi.demos.words.""" from twisted.internet.defer import inlineCallbacks from vumi.demos.words import (SimpleAppWorker, EchoWorker, ReverseWorker, WordCountWorker) from vumi.message import TransportUserMessage from vumi.tests.helpers import VumiTestCase from vumi.application.tests.helpers import ApplicationHelper from vumi.tests.utils import LogCatcher class EchoTestApp(SimpleAppWorker): """Test worker that echos calls to process_message.""" def process_message(self, data): return 'echo:%s' % data class TestSimpleAppWorker(VumiTestCase): @inlineCallbacks def setUp(self): self.app_helper = self.add_helper(ApplicationHelper(None)) self.worker = yield self.app_helper.get_application({}, EchoTestApp) @inlineCallbacks def test_help(self): yield self.app_helper.make_dispatch_inbound( None, session_event=TransportUserMessage.SESSION_NEW) [reply] = self.app_helper.get_dispatched_outbound() self.assertEqual(reply['session_event'], None) self.assertEqual(reply['content'], 'Enter text:') @inlineCallbacks def test_content_text(self): yield self.app_helper.make_dispatch_inbound( "test", session_event=TransportUserMessage.SESSION_NEW) [reply] = self.app_helper.get_dispatched_outbound() self.assertEqual(reply['session_event'], None) self.assertEqual(reply['content'], 'echo:test') @inlineCallbacks def test_base_process_message(self): worker = yield self.app_helper.get_application({}, SimpleAppWorker) self.assertRaises(NotImplementedError, worker.process_message, 'foo') class TestEchoWorker(VumiTestCase): @inlineCallbacks def setUp(self): self.app_helper = self.add_helper(ApplicationHelper(None)) self.worker = yield self.app_helper.get_application({}, EchoWorker) def test_process_message(self): self.assertEqual(self.worker.process_message("foo"), "foo") def test_help(self): self.assertEqual(self.worker.get_help(), "Enter text to echo:") @inlineCallbacks def test_echo_non_ascii(self): content = u'Zoë destroyer of Ascii' with LogCatcher() as log: yield self.app_helper.make_dispatch_inbound(content) [reply] = self.app_helper.get_dispatched_outbound() self.assertEqual( log.messages(), ['User message: Zo\xc3\xab destroyer of Ascii']) class TestReverseWorker(VumiTestCase): @inlineCallbacks def setUp(self): self.app_helper = self.add_helper(ApplicationHelper(None)) self.worker = yield self.app_helper.get_application({}, ReverseWorker) def test_process_message(self): self.assertEqual(self.worker.process_message("foo"), "oof") def test_help(self): self.assertEqual(self.worker.get_help(), "Enter text to reverse:") class TestWordCountWorker(VumiTestCase): @inlineCallbacks def setUp(self): self.app_helper = self.add_helper(ApplicationHelper(None)) self.worker = yield self.app_helper.get_application( {}, WordCountWorker) def test_process_message(self): self.assertEqual(self.worker.process_message("foo bar"), "2 words, 7 chars") def test_singular(self): self.assertEqual(self.worker.process_message("f"), "1 word, 1 char") def test_help(self): self.assertEqual(self.worker.get_help(), "Enter text to return word" " and character counts for:") PK=JGX## vumi/demos/tests/test_hangman.py# -*- encoding: utf-8 -*- """Tests for vumi.demos.hangman.""" import string from twisted.internet.defer import inlineCallbacks, returnValue from twisted.internet import reactor from twisted.web.server import Site from twisted.web.resource import Resource from twisted.web.static import Data from vumi.demos.hangman import HangmanGame, HangmanWorker from vumi.message import 
TransportUserMessage from vumi.application.tests.helpers import ApplicationHelper from vumi.tests.helpers import VumiTestCase def mkstate(word, guesses, msg): return {'word': word, 'guesses': guesses, 'msg': msg} class TestHangmanGame(VumiTestCase): def test_easy_game(self): game = HangmanGame(word='moo') game.event('m') game.event('o') self.assertTrue(game.won()) self.assertEqual( game.state(), mkstate('moo', 'mo', 'Flawless victory!')) def test_incorrect_guesses(self): game = HangmanGame(word='moo') game.event('f') game.event('g') self.assertFalse(game.won()) self.assertEqual( game.state(), mkstate('moo', 'fg', "Word contains no 'g'. :(")) def test_repeated_guesses(self): game = HangmanGame(word='moo') game.event('f') game.event('f') self.assertFalse(game.won()) self.assertEqual( game.state(), mkstate('moo', 'f', "You've already guessed 'f'.")) def test_button_mashing(self): game = HangmanGame(word='moo') for event in string.lowercase.replace('o', ''): game.event(event) game.event('o') self.assertTrue(game.won()) self.assertEqual( game.state(), mkstate('moo', string.lowercase, "Button mashing!")) def test_new_game(self): game = HangmanGame(word='moo') for event in ('m', 'o', '-'): game.event(event) self.assertEqual( game.state(), mkstate('moo', 'mo', 'Flawless victory!')) self.assertEqual(game.exit_code, game.DONE_WANTS_NEW) def test_from_state(self): game = HangmanGame.from_state(mkstate("bar", "xyz", "Eep?")) self.assertEqual(game.word, "bar") self.assertEqual(game.guesses, set("xyz")) self.assertEqual(game.msg, "Eep?") self.assertEqual(game.exit_code, game.NOT_DONE) def test_from_state_non_ascii(self): game = HangmanGame.from_state( mkstate("b\xc3\xa4r".decode("utf-8"), "xyz", "Eep?")) self.assertEqual(game.word, u"b\u00e4r") self.assertEqual(game.guesses, set("xyz")) self.assertEqual(game.msg, "Eep?") self.assertEqual(game.exit_code, game.NOT_DONE) def test_exit(self): game = HangmanGame('elephant') game.event('0') self.assertEqual(game.exit_code, game.DONE) self.assertEqual(game.draw_board(), "Adieu!") def test_draw_board(self): game = HangmanGame('word') board = game.draw_board() msg, word, guesses, prompt, end = board.split("\n") self.assertEqual(msg, "New game!") self.assertEqual(word, "Word: ____") self.assertEqual(guesses, "Letters guessed so far: ") self.assertEqual(prompt, "Enter next guess (0 to quit):") def test_draw_board_at_end_of_game(self): game = HangmanGame('m') game.event('m') board = game.draw_board() msg, word, guesses, prompt, end = board.split("\n") self.assertEqual(msg, "Flawless victory!") self.assertEqual(word, "Word: m") self.assertEqual(guesses, "Letters guessed so far: m") self.assertEqual(prompt, "Enter anything to start a new game" " (0 to quit):") def test_displaying_word(self): game = HangmanGame('word') game.event('w') game.event('r') board = game.draw_board() _msg, word, _guesses, _prompt, _end = board.split("\n") self.assertEqual(word, "Word: w_r_") def test_displaying_guesses(self): game = HangmanGame('word') game.event('w') board = game.draw_board() msg, _word, _guesses, _prompt, _end = board.split("\n") self.assertEqual(msg, "Word contains at least one 'w'! :D") game.event('w') board = game.draw_board() msg, _word, _guesses, _prompt, _end = board.split("\n") self.assertEqual(msg, "You've already guessed 'w'.") game.event('x') board = game.draw_board() msg, _word, _guesses, _prompt, _end = board.split("\n") self.assertEqual(msg, "Word contains no 'x'. 
:(") def test_garbage_input(self): game = HangmanGame(word="zoo") for garbage in [":", "!", "\x00", "+", "abc", ""]: game.event(garbage) self.assertEqual(game.guesses, set()) game.event('z') game.event('o') self.assertTrue(game.won()) class TestHangmanWorker(VumiTestCase): @inlineCallbacks def setUp(self): root = Resource() # data is elephant with a UTF-8 encoded BOM # it is a sad elephant (as seen in the wild) root.putChild("word", Data('\xef\xbb\xbfelephant\r\n', 'text/html')) site_factory = Site(root) self.webserver = yield reactor.listenTCP( 0, site_factory, interface='127.0.0.1') self.add_cleanup(self.webserver.loseConnection) addr = self.webserver.getHost() random_word_url = "http://%s:%s/word" % (addr.host, addr.port) self.app_helper = self.add_helper(ApplicationHelper(HangmanWorker)) self.worker = yield self.app_helper.get_application({ 'worker_name': 'test_hangman', 'random_word_url': random_word_url, }) yield self.worker.session_manager.redis._purge_all() # just in case def send(self, content, session_event=None): return self.app_helper.make_dispatch_inbound( content, session_event=session_event) @inlineCallbacks def recv(self, n=0): msgs = yield self.app_helper.wait_for_dispatched_outbound(n) def reply_code(msg): if msg['session_event'] == TransportUserMessage.SESSION_CLOSE: return 'end' return 'reply' returnValue([(reply_code(msg), msg['content']) for msg in msgs]) @inlineCallbacks def test_new_session(self): yield self.send(None, TransportUserMessage.SESSION_NEW) replies = yield self.recv(1) self.assertEqual(len(replies), 1) reply = replies[0] self.assertEqual(reply[0], 'reply') self.assertEqual(reply[1], "New game!\n" "Word: ________\n" "Letters guessed so far: \n" "Enter next guess (0 to quit):\n") @inlineCallbacks def test_random_word(self): word = yield self.worker.random_word( self.worker.config['random_word_url']) self.assertEqual(word, 'elephant') @inlineCallbacks def test_full_session(self): yield self.send(None, TransportUserMessage.SESSION_NEW) for event in ('e', 'l', 'p', 'h', 'a', 'n', 'o', 't'): yield self.send(event, TransportUserMessage.SESSION_RESUME) replies = yield self.recv(9) self.assertEqual(len(replies), 9) last_reply = replies[-1] self.assertEqual(last_reply[0], 'reply') self.assertEqual(last_reply[1], "Epic victory!\n" "Word: elephant\n" "Letters guessed so far: aehlnopt\n" "Enter anything to start a new game (0 to quit):\n") yield self.send('1') replies = yield self.recv(10) last_reply = replies[-1] self.assertEqual(last_reply[0], 'reply') self.assertEqual(last_reply[1], "New game!\n" "Word: ________\n" "Letters guessed so far: \n" "Enter next guess (0 to quit):\n") yield self.send('0') replies = yield self.recv(11) last_reply = replies[-1] self.assertEqual(last_reply[0], 'end') self.assertEqual(last_reply[1], "Adieu!") @inlineCallbacks def test_close_session(self): yield self.send(None, TransportUserMessage.SESSION_CLOSE) replies = yield self.recv() self.assertEqual(replies, []) @inlineCallbacks def test_non_ascii_input(self): yield self.send(None, TransportUserMessage.SESSION_NEW) for event in (u'ü', u'æ'): yield self.send(event, TransportUserMessage.SESSION_RESUME) replies = yield self.recv(3) self.assertEqual(len(replies), 3) for reply in replies[1:]: self.assertEqual(reply[0], 'reply') self.assertTrue(reply[1].startswith( 'Letters of the alphabet only please.')) PK=JG%vumi/demos/tests/test_static_reply.pyfrom datetime import date from twisted.internet.defer import inlineCallbacks from vumi.application.tests.helpers import ApplicationHelper from 
vumi.demos.static_reply import StaticReplyApplication from vumi.tests.helpers import VumiTestCase class TestStaticReplyApplication(VumiTestCase): def setUp(self): self.app_helper = self.add_helper( ApplicationHelper(StaticReplyApplication)) @inlineCallbacks def test_receive_message_custom(self): yield self.app_helper.get_application({ 'reply_text': 'Hi {user}', }) yield self.app_helper.make_dispatch_inbound( "Hello", from_addr='from_addr') [reply] = self.app_helper.get_dispatched_outbound() self.assertEqual('Hi from_addr', reply['content']) self.assertEqual(u'close', reply['session_event']) @inlineCallbacks def test_receive_message(self): yield self.app_helper.get_application({}) yield self.app_helper.make_dispatch_inbound( "Hello", from_addr='from_addr') [reply] = self.app_helper.get_dispatched_outbound() self.assertEqual('Hello from_addr at %s.' % (date.today(),), reply['content']) self.assertEqual(u'close', reply['session_event']) PK=JGiC"vumi/demos/tests/test_tictactoe.pyfrom twisted.internet.defer import inlineCallbacks from vumi.demos.tictactoe import TicTacToeGame, TicTacToeWorker from vumi.message import TransportUserMessage from vumi.application.tests.helpers import ApplicationHelper from vumi.tests.helpers import VumiTestCase class TestTicTacToeGame(VumiTestCase): def get_game(self, moves=()): game = TicTacToeGame('pX') game.set_player_O('pO') for sid, x, y in moves: game.move(sid, x, y) return game def test_game_init(self): game = TicTacToeGame('pX') self.assertEquals('pX', game.player_X) self.assertEquals(None, game.player_O) self.assertEquals([[' '] * 3] * 3, game.board) game.set_player_O('pO') self.assertEquals('pX', game.player_X) self.assertEquals('pO', game.player_O) self.assertEquals([[' '] * 3] * 3, game.board) def test_move(self): game = self.get_game() expected_board = [[' ', ' ', ' '] for _i in range(3)] self.assertEqual((True, 'pO'), game.move('pX', 0, 0)) expected_board[0][0] = 'X' self.assertEqual(expected_board, game.board) self.assertEqual((True, 'pX'), game.move('pO', 1, 0)) expected_board[0][1] = 'O' self.assertEqual(expected_board, game.board) self.assertEqual((False, 'pO'), game.move('pX', 0, 0)) self.assertEqual(expected_board, game.board) def test_draw_board(self): game = self.get_game(moves=[('pX', 0, 0), ('pO', 1, 0), ('pX', 1, 2)]) self.assertEqual("+---+---+---+\n" "| X | O | |\n" "+---+---+---+\n" "| | | |\n" "+---+---+---+\n" "| | X | |\n" "+---+---+---+", game.draw_board()) def test_check_draw(self): game = self.get_game() for y, x in [(y, x) for x in range(3) for y in range(3)]: self.assertEqual(False, game.check_draw()) game.move('pX' if x + y % 2 == 0 else 'pO', x, y) self.assertEqual(True, game.check_draw()) def test_check_win(self): game = self.get_game() for i, (y, x) in enumerate([(0, 0), (1, 0), (1, 2), (1, 1), (0, 2), (0, 1), (2, 2)]): self.assertEqual(False, game.check_win()) game.move('pX' if i % 2 == 0 else 'pO', x, y) self.assertEqual('X', game.check_win()) class TestTicTacToeWorker(VumiTestCase): @inlineCallbacks def setUp(self): self.app_helper = self.add_helper(ApplicationHelper(TicTacToeWorker)) self.worker = yield self.app_helper.get_application({}) def dispatch_start_message(self, from_addr): return self.app_helper.make_dispatch_inbound( None, from_addr=from_addr, session_event=TransportUserMessage.SESSION_NEW) @inlineCallbacks def test_new_sessions(self): self.assertEquals({}, self.worker.games) self.assertEquals(None, self.worker.open_game) user1 = '+27831234567' user2 = '+27831234568' yield self.dispatch_start_message(user1) 
self.assertNotEquals(None, self.worker.open_game) game = self.worker.open_game self.assertEquals({user1: game}, self.worker.games) yield self.dispatch_start_message(user2) self.assertEquals(None, self.worker.open_game) self.assertEquals({user1: game, user2: game}, self.worker.games) [msg] = self.app_helper.get_dispatched_outbound() self.assertTrue(msg['content'].startswith('+---+---+---+')) @inlineCallbacks def test_moves(self): user1 = '+27831234567' user2 = '+27831234568' yield self.dispatch_start_message(user1) game = self.worker.open_game yield self.dispatch_start_message(user2) self.assertEquals(1, len(self.app_helper.get_dispatched_outbound())) yield self.app_helper.make_dispatch_inbound('1', from_addr=user1) self.assertEquals(2, len(self.app_helper.get_dispatched_outbound())) yield self.app_helper.make_dispatch_inbound('2', from_addr=user2) self.assertEquals(3, len(self.app_helper.get_dispatched_outbound())) self.assertEqual('X', game.board[0][0]) self.assertEqual('O', game.board[0][1]) @inlineCallbacks def test_full_game(self): user1 = '+27831234567' user2 = '+27831234568' yield self.dispatch_start_message(user1) game = self.worker.open_game yield self.dispatch_start_message(user2) for user, content in [ (user1, '1'), (user2, '4'), (user1, '2'), (user2, '5'), (user1, '3')]: yield self.app_helper.make_dispatch_inbound( content, from_addr=user) self.assertEqual('X', game.check_win()) [end1, end2] = self.app_helper.get_dispatched_outbound()[-2:] self.assertEqual(user1, end1["to_addr"]) self.assertEqual("You won!", end1["content"]) self.assertEqual(user2, end2["to_addr"]) self.assertEqual("You lost!", end2["content"]) PK=JGvumi/demos/tests/__init__.pyPK=JGPbb#vumi/demos/tests/test_calculator.pyfrom twisted.internet.defer import inlineCallbacks from vumi.tests.helpers import VumiTestCase from vumi.application.tests.helpers import ApplicationHelper from vumi.demos.calculator import CalculatorApp from vumi.message import TransportUserMessage class TestCalculatorApp(VumiTestCase): @inlineCallbacks def setUp(self): self.app_helper = self.add_helper(ApplicationHelper(CalculatorApp)) self.worker = yield self.app_helper.get_application({}) @inlineCallbacks def test_session_start(self): yield self.app_helper.make_dispatch_inbound( None, session_event=TransportUserMessage.SESSION_NEW) [resp] = yield self.app_helper.wait_for_dispatched_outbound(1) self.assertEqual( resp['content'], 'What would you like to do?\n' '1. Add\n' '2. Subtract\n' '3. 
Multiply') @inlineCallbacks def test_first_number(self): yield self.app_helper.make_dispatch_inbound( '1', session_event=TransportUserMessage.SESSION_RESUME) [resp] = yield self.app_helper.wait_for_dispatched_outbound(1) self.assertEqual(resp['content'], 'What is the first number?') @inlineCallbacks def test_second_number(self): self.worker.save_session('+41791234567', { 'action': 1, }) yield self.app_helper.make_dispatch_inbound( '1', session_event=TransportUserMessage.SESSION_RESUME) [resp] = yield self.app_helper.wait_for_dispatched_outbound(1) self.assertEqual(resp['content'], 'What is the second number?') @inlineCallbacks def test_action(self): self.worker.save_session('+41791234567', { 'action': 0, # add 'first_number': 2, }) yield self.app_helper.make_dispatch_inbound( '2', session_event=TransportUserMessage.SESSION_RESUME) [resp] = yield self.app_helper.wait_for_dispatched_outbound(1) self.assertEqual(resp['content'], 'The result is: 4.') self.assertEqual(resp['session_event'], TransportUserMessage.SESSION_CLOSE) @inlineCallbacks def test_invalid_input(self): self.worker.save_session('+41791234567', { 'action': 0, # add }) yield self.app_helper.make_dispatch_inbound( 'not-an-int', session_event=TransportUserMessage.SESSION_RESUME) [resp] = yield self.app_helper.wait_for_dispatched_outbound(1) self.assertEqual(resp['content'], 'Sorry invalid input!') self.assertEqual(resp['session_event'], TransportUserMessage.SESSION_CLOSE) @inlineCallbacks def test_invalid_action(self): yield self.app_helper.make_dispatch_inbound( 'not-an-option', session_event=TransportUserMessage.SESSION_RESUME) [resp] = yield self.app_helper.wait_for_dispatched_outbound(1) self.assertTrue( resp['content'].startswith('Sorry invalid input!')) @inlineCallbacks def test_user_cancellation(self): self.worker.save_session('+41791234567', {'foo': 'bar'}) yield self.app_helper.make_dispatch_inbound( None, session_event=TransportUserMessage.SESSION_CLOSE) self.assertEqual(self.worker.get_session('+41791234567'), {}) @inlineCallbacks def test_none_input_on_session_resume(self): yield self.app_helper.make_dispatch_inbound( None, session_event=TransportUserMessage.SESSION_RESUME) [resp] = yield self.app_helper.wait_for_dispatched_outbound(1) self.assertEqual(resp['content'], 'Sorry invalid input!') PK=JGO O vumi/demos/tests/test_ircbot.py"""Tests for vumi.demos.ircbot.""" from twisted.internet.defer import inlineCallbacks, returnValue from vumi.demos.ircbot import MemoWorker from vumi.message import TransportUserMessage from vumi.application.tests.helpers import ApplicationHelper from vumi.tests.helpers import VumiTestCase class TestMemoWorker(VumiTestCase): @inlineCallbacks def setUp(self): self.app_helper = self.add_helper(ApplicationHelper(MemoWorker)) self.worker = yield self.app_helper.get_application( {'worker_name': 'testmemo'}) yield self.worker.redis._purge_all() # just in case def send(self, content, from_addr='testnick', channel=None): transport_metadata = {} helper_metadata = {} if channel is not None: transport_metadata['irc_channel'] = channel helper_metadata['irc'] = {'irc_channel': channel} return self.app_helper.make_dispatch_inbound( content, from_addr=from_addr, helper_metadata=helper_metadata, transport_metadata=transport_metadata) @inlineCallbacks def recv(self, n=0): msgs = yield self.app_helper.wait_for_dispatched_outbound(n) def reply_code(msg): if msg['session_event'] == TransportUserMessage.SESSION_CLOSE: return 'end' return 'reply' returnValue([(reply_code(msg), msg['content']) for msg in msgs]) 
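# The send/recv helpers above drive the worker through ApplicationHelper:
# send() fakes an inbound IRC message (the channel travels in both
# helper_metadata and transport_metadata), and recv() collects outbound
# replies, tagging each one 'end' or 'reply' based on its session_event.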
@inlineCallbacks def test_no_memos(self): yield self.send("Message from someone with no messages.") replies = yield self.recv() self.assertEquals([], replies) @inlineCallbacks def test_leave_memo(self): yield self.send('bot: tell memoed hey there', channel='#test') memos = yield self.worker.retrieve_memos('#test', 'memoed') self.assertEquals(memos, [['testnick', 'hey there']]) replies = yield self.recv() self.assertEqual(replies, [ ('reply', 'testnick: Sure thing, boss.'), ]) @inlineCallbacks def test_leave_memo_nick_canonicalization(self): yield self.send('bot: tell MeMoEd hey there', channel='#test') memos = yield self.worker.retrieve_memos('#test', 'memoed') self.assertEquals(memos, [['testnick', 'hey there']]) @inlineCallbacks def test_send_memos(self): yield self.send('bot: tell testmemo this is memo 1', channel='#test') yield self.send('bot: tell testmemo this is memo 2', channel='#test') yield self.send('bot: tell testmemo this is a different channel', channel='#another') # replies to setting memos replies = yield self.recv(3) self.app_helper.clear_dispatched_outbound() yield self.send('ping', channel='#test', from_addr='testmemo') replies = yield self.recv(2) self.assertEqual(replies, [ ('reply', 'testmemo, testnick asked me tell you:' ' this is memo 1'), ('reply', 'testmemo, testnick asked me tell you:' ' this is memo 2'), ]) self.app_helper.clear_dispatched_outbound() yield self.send('ping', channel='#another', from_addr='testmemo') replies = yield self.recv(1) self.assertEqual(replies, [ ('reply', 'testmemo, testnick asked me tell you:' ' this is a different channel'), ]) PK=JG7vumi/demos/tests/test_rps.pyfrom twisted.internet.defer import inlineCallbacks from vumi.message import TransportUserMessage from vumi.demos.rps import RockPaperScissorsGame, RockPaperScissorsWorker from vumi.application.tests.helpers import ApplicationHelper from vumi.tests.helpers import VumiTestCase class TestRockPaperScissorsGame(VumiTestCase): def get_game(self, scores=None): game = RockPaperScissorsGame(5, 'p1') game.set_player_2('p2') if scores is not None: game.scores = scores return game def test_game_init(self): game = RockPaperScissorsGame(5, 'p1') self.assertEquals('p1', game.player_1) self.assertEquals(None, game.player_2) self.assertEquals((0, 0), game.scores) self.assertEquals(None, game.current_move) game.set_player_2('p2') self.assertEquals('p1', game.player_1) self.assertEquals('p2', game.player_2) self.assertEquals((0, 0), game.scores) self.assertEquals(None, game.current_move) def test_game_moves_draw(self): game = self.get_game((1, 1)) game.move('p1', 1) self.assertEquals(1, game.current_move) self.assertEquals((1, 1), game.scores) game.move('p2', 1) self.assertEquals(None, game.current_move) self.assertEquals((1, 1), game.scores) def test_game_moves_win(self): game = self.get_game((1, 1)) game.move('p1', 1) self.assertEquals(1, game.current_move) self.assertEquals((1, 1), game.scores) game.move('p2', 2) self.assertEquals(None, game.current_move) self.assertEquals((2, 1), game.scores) def test_game_moves_lose(self): game = self.get_game((1, 1)) game.move('p1', 1) self.assertEquals(1, game.current_move) self.assertEquals((1, 1), game.scores) game.move('p2', 3) self.assertEquals(None, game.current_move) self.assertEquals((1, 2), game.scores) class TestRockPaperScissorsWorker(VumiTestCase): @inlineCallbacks def setUp(self): self.app_helper = self.add_helper( ApplicationHelper(RockPaperScissorsWorker)) self.worker = yield self.app_helper.get_application({}) def 
dispatch_start_message(self, from_addr): return self.app_helper.make_dispatch_inbound( None, from_addr=from_addr, session_event=TransportUserMessage.SESSION_NEW) @inlineCallbacks def test_new_sessions(self): self.assertEquals({}, self.worker.games) self.assertEquals(None, self.worker.open_game) user1 = '+27831234567' user2 = '+27831234568' yield self.dispatch_start_message(user1) self.assertNotEquals(None, self.worker.open_game) game = self.worker.open_game self.assertEquals({user1: game}, self.worker.games) yield self.dispatch_start_message(user2) self.assertEquals(None, self.worker.open_game) self.assertEquals({user1: game, user2: game}, self.worker.games) self.assertEquals(2, len(self.app_helper.get_dispatched_outbound())) @inlineCallbacks def test_moves(self): user1 = '+27831234567' user2 = '+27831234568' yield self.dispatch_start_message(user1) game = self.worker.open_game yield self.dispatch_start_message(user2) self.assertEquals(2, len(self.app_helper.get_dispatched_outbound())) yield self.app_helper.make_dispatch_inbound('1', from_addr=user2) self.assertEquals(2, len(self.app_helper.get_dispatched_outbound())) yield self.app_helper.make_dispatch_inbound('2', from_addr=user1) self.assertEquals(4, len(self.app_helper.get_dispatched_outbound())) self.assertEquals((0, 1), game.scores) @inlineCallbacks def test_full_game(self): user1 = '+27831234567' user2 = '+27831234568' yield self.dispatch_start_message(user1) game = self.worker.open_game yield self.dispatch_start_message(user2) for user, content in [(user1, '1'), (user2, '2')] * 5: # best-of 5 yield self.app_helper.make_dispatch_inbound( content, from_addr=user) self.assertEqual((5, 0), game.scores) [end1, end2] = self.app_helper.get_dispatched_outbound()[-2:] self.assertEqual("You won! :-)", end1["content"]) self.assertEqual(user1, end1["to_addr"]) self.assertEqual("You lost. :-(", end2["content"]) self.assertEqual(user2, end2["to_addr"]) PKqGIIvumi/application/http_relay.py# -*- test-case-name: vumi.application.tests.test_http_relay -*- from base64 import b64encode from twisted.python import log from twisted.web import http from twisted.internet.defer import inlineCallbacks from vumi.application.base import ApplicationWorker from vumi.utils import http_request_full from vumi.errors import VumiError from vumi.config import ConfigText, ConfigUrl class HTTPRelayError(VumiError): pass class HTTPRelayConfig(ApplicationWorker.CONFIG_CLASS): # TODO: Make these less static? url = ConfigUrl( "URL to submit incoming message to.", required=True, static=True) event_url = ConfigUrl( "URL to submit incoming events to. (Defaults to the same as 'url').", static=True) http_method = ConfigText( "HTTP method for submitting messages.", default='POST', static=True) auth_method = ConfigText( "HTTP authentication method.", default='basic', static=True) username = ConfigText("Username for HTTP authentication.", default='') password = ConfigText("Password for HTTP authentication.", default='') class HTTPRelayApplication(ApplicationWorker): CONFIG_CLASS = HTTPRelayConfig reply_header = 'X-Vumi-HTTPRelay-Reply' agent_factory = None # For swapping out the Agent we use in tests. def validate_config(self): self.supported_auth_methods = { 'basic': self.generate_basic_auth_headers, } # XXX: Is this the best way to do this? 
if 'event_url' not in self.config: self.config['event_url'] = self.config['url'] config = self.get_static_config() if config.auth_method not in self.supported_auth_methods: raise HTTPRelayError( 'HTTP Authentication method %s not supported' % ( repr(config.auth_method,))) def generate_basic_auth_headers(self, username, password): credentials = ':'.join([username, password]) auth_string = b64encode(credentials.encode('utf-8')) return { 'Authorization': ['Basic %s' % (auth_string,)] } def get_auth_headers(self, config): if config.username: handler = self.supported_auth_methods.get(config.auth_method) return handler(config.username, config.password) return {} @inlineCallbacks def consume_user_message(self, message): config = yield self.get_config(message) headers = self.get_auth_headers(config) response = yield http_request_full( config.url.geturl(), message.to_json(), headers, config.http_method, agent_class=self.agent_factory) headers = response.headers if response.code == http.OK: if headers.hasHeader(self.reply_header): raw_headers = headers.getRawHeaders(self.reply_header) content = response.delivered_body.strip() if (raw_headers[0].lower() == 'true') and content: self.reply_to(message, content) else: log.err('%s responded with %s' % ( config.url.geturl(), response.code)) @inlineCallbacks def relay_event(self, event): config = yield self.get_config(event) headers = self.get_auth_headers(config) yield http_request_full( config.event_url.geturl(), event.to_json(), headers, config.http_method, agent_class=self.agent_factory) @inlineCallbacks def consume_ack(self, event): yield self.relay_event(event) @inlineCallbacks def consume_delivery_report(self, event): yield self.relay_event(event) PK=JG|eevumi/application/base.py# -*- test-case-name: vumi.application.tests.test_base -*- """Basic tools for building a vumi ApplicationWorker.""" import copy from twisted.internet.defer import maybeDeferred from vumi.config import ConfigText, ConfigDict from vumi.worker import BaseWorker from vumi import log from vumi.message import TransportUserMessage from vumi.errors import InvalidEndpoint SESSION_NEW = TransportUserMessage.SESSION_NEW SESSION_CLOSE = TransportUserMessage.SESSION_CLOSE SESSION_RESUME = TransportUserMessage.SESSION_RESUME class ApplicationConfig(BaseWorker.CONFIG_CLASS): """Base config definition for applications. You should subclass this and add application-specific fields. """ transport_name = ConfigText( "The name this application instance will use to create its queues.", required=True, static=True) send_to = ConfigDict( "'send_to' configuration dict.", default={}, static=True) class ApplicationWorker(BaseWorker): """Base class for an application worker. Handles :class:`vumi.message.TransportUserMessage` and :class:`vumi.message.TransportEvent` messages. Application workers may send outgoing messages using :meth:`reply_to` (for replies to incoming messages) or :meth:`send_to` (for messages that are not replies). :meth:`send_to` can take either an `endpoint` parameter to specify the endpoint to send on (and optionally add additional message data from application configuration). :attr:`ALLOWED_ENDPOINTS` lists the endpoints this application is allowed to send messages to using the :meth:`send_to` method. If it is set to `None`, any endpoint is allowed. Messages sent via :meth:`send_to` pass optional additional data from configuration to the TransportUserMessage constructor, based on the endpoint parameter passed to send_to. 
This usually contains information useful for routing the message. An example :meth:`send_to` configuration might look like:: - send_to: - default: transport_name: sms_transport NOTE: If you are using non-endpoint routing, 'transport_name' **must** be defined for each send_to section since dispatchers rely on this for routing outbound messages. The available set of endpoints defaults to just the single endpoint named `default`. If applications wish to define their own set of available endpoints they should override :attr:`ALLOWED_ENDPOINTS`. Setting :attr:`ALLOWED_ENDPOINTS` to `None` allows the application to send on arbitrary endpoint names. """ transport_name = None UNPAUSE_CONNECTORS = True CONFIG_CLASS = ApplicationConfig ALLOWED_ENDPOINTS = frozenset(['default']) def _validate_config(self): config = self.get_static_config() self.transport_name = config.transport_name self.validate_config() def setup_connectors(self): d = self.setup_ri_connector(self.transport_name) def cb(connector): connector.set_inbound_handler(self.dispatch_user_message) connector.set_event_handler(self.dispatch_event) return connector return d.addCallback(cb) def setup_worker(self): """ Set up basic application worker stuff. You shouldn't have to override this in subclasses. """ self._event_handlers = { 'ack': self.consume_ack, 'nack': self.consume_nack, 'delivery_report': self.consume_delivery_report, } self._session_handlers = { SESSION_NEW: self.new_session, SESSION_CLOSE: self.close_session, } d = maybeDeferred(self.setup_application) if self.UNPAUSE_CONNECTORS: d.addCallback(lambda r: self.unpause_connectors()) return d def teardown_worker(self): d = self.pause_connectors() d.addCallback(lambda r: self.teardown_application()) return d def setup_application(self): """ All application specific setup should happen in here. Subclasses should override this method to perform extra setup. """ pass def teardown_application(self): """ Clean-up of setup done in setup_application should happen here. """ pass def _dispatch_event_raw(self, event): event_type = event.get('event_type') handler = self._event_handlers.get(event_type, self.consume_unknown_event) return handler(event) def dispatch_event(self, event): """Dispatch to event_type specific handlers.""" return self._dispatch_event_raw(event) def consume_unknown_event(self, event): log.msg("Unknown event type in message %r" % (event,)) def consume_ack(self, event): """Handle an ack message.""" pass def consume_nack(self, event): """Handle a nack message""" pass def consume_delivery_report(self, event): """Handle a delivery report.""" pass def _dispatch_user_message_raw(self, message): session_event = message.get('session_event') handler = self._session_handlers.get(session_event, self.consume_user_message) return handler(message) def dispatch_user_message(self, message): """Dispatch user messages to handler.""" return self._dispatch_user_message_raw(message) def consume_user_message(self, message): """Respond to user message.""" pass def new_session(self, message): """Respond to a new session. Defaults to calling consume_user_message. """ return self.consume_user_message(message) def close_session(self, message): """Close a session. The .reply_to() method should not be called when the session is closed. 
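Subclasses that keep per-session state (as the demo workers above do) should clean it up here.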
""" pass def _publish_message(self, message, endpoint_name=None): publisher = self.connectors[self.transport_name] return publisher.publish_outbound(message, endpoint_name=endpoint_name) @staticmethod def check_endpoint(allowed_endpoints, endpoint): """Check that endpoint is in the list of allowed endpoints. :param list allowed_endpoints: List (or set) of allowed endpoints. If ``allowed_endpoints`` is ``None``, all endpoints are allowed. :param str endpoint: Endpoint to check. The special value ``None`` is equivalent to ``default``. """ if allowed_endpoints is None: return if endpoint is None: endpoint = "default" if endpoint not in allowed_endpoints: raise InvalidEndpoint( "Endpoint %r not defined in list of allowed endpoints %r" % (endpoint, allowed_endpoints)) def reply_to(self, original_message, content, continue_session=True, **kws): reply = original_message.reply(content, continue_session, **kws) endpoint_name = original_message.get_routing_endpoint() return self._publish_message(reply, endpoint_name=endpoint_name) def reply_to_group(self, original_message, content, continue_session=True, **kws): reply = original_message.reply_group(content, continue_session, **kws) endpoint_name = original_message.get_routing_endpoint() return self._publish_message(reply, endpoint_name=endpoint_name) def send_to(self, to_addr, content, endpoint=None, **kw): if endpoint is None: endpoint = 'default' self.check_endpoint(self.ALLOWED_ENDPOINTS, endpoint) options = copy.deepcopy( self.get_static_config().send_to.get(endpoint, {})) options.update(kw) msg = TransportUserMessage.send(to_addr, content, **options) return self._publish_message(msg, endpoint_name=endpoint) PK=JG(+vumi/application/sandboxer.jsvar vm = require('vm'); var events = require('events'); var EventEmitter = events.EventEmitter; var SandboxApi = function () { // API for use by applications var self = this; self.id = 0; self.emitter = new EventEmitter(); self.next_id = function () { self.id += 1; return self.id.toString(); }; self.populate_command = function (command, msg) { msg.cmd = command; msg.reply = false; msg.cmd_id = self.next_id(); return msg; }; self.request = function (command, msg, callback) { // callback is optional and is called once a reply to // the request is received. 
self.populate_command(command, msg); self.emitter.emit('request', { msg: msg, callback: callback }); }; self.log_info = function (msg, callback) { self.request('log.info', {msg: msg}, callback); }; self.done = function () { self.log_info('Done.', function() { self.emitter.emit('done'); }); }; // handlers: // * on_unknown_command is the default message handler // * other handlers are looked up based on the command name self.on_unknown_command = function(command) {}; }; var SandboxRunner = function (api) { // Runner for a sandboxed app var self = this; self.emitter = new EventEmitter(); self.api = api; self.chunk = ""; self.pending_requests = {}; self.loaded = false; self.emitter.on('command', function (command) { var handler_name = "on_" + command.cmd.replace('.', '_').replace('-', '_'); var handler = api[handler_name]; if (!handler) { handler = api.on_unknown_command; } if (handler) { handler.call(self.api, command); } }); self.emitter.on('reply', function (reply) { var handler = self.pending_requests[reply.cmd_id]; if (handler && handler.callback) { handler.callback.call(self.api, reply); } }); self.api.emitter.on('request', function(request) { setImmediate(function() { if (request.callback) { self.pending_requests[request.msg.cmd_id] = { callback: request.callback }; } self.send_command(request.msg); }); }); self.api.emitter.on('done', function() { self.exit(); }); self.exit = function() { process.exit(0); }; self.load_code = function (command) { self.log("Loading sandboxed code ..."); var ctxt; var loaded_module = vm.createScript(command.javascript); if (command.app_context) { // TODO use vm stuff instead of eval eval("ctxt = " + command.app_context + ";"); // jshint ignore:line } else { ctxt = {}; } ctxt.api = self.api; loaded_module.runInNewContext(ctxt); self.loaded = true; }; self.send_command = function (cmd) { process.stdout.write(JSON.stringify(cmd)); process.stdout.write("\n"); }; self.log = function(msg) { var cmd = self.api.populate_command("log.info", {"msg": msg}); self.send_command(cmd); }; self.data_from_stdin = function (data) { var parts = data.split("\n"); parts[0] = self.chunk + parts[0]; for (i = 0; i < parts.length - 1; i++) { if (!parts[i]) { continue; } var msg = JSON.parse(parts[i]); if (!self.loaded) { if (msg.cmd == 'initialize') { self.load_code(msg); } } else if (!msg.reply) { self.emitter.emit('command', msg); } else { self.emitter.emit('reply', msg); } } self.chunk = parts[parts.length - 1]; }; self.run = function () { process.stdin.resume(); process.stdin.setEncoding('ascii'); process.stdin.on('data', function(data) { self.data_from_stdin(data); }); }; }; var api = new SandboxApi(); var runner = new SandboxRunner(api); runner.run(); runner.log("Starting sandbox ..."); PK=JGjX ``vumi/application/sandbox.py# -*- test-case-name: vumi.application.tests.test_sandbox -*- """An application for sandboxing message processing.""" import base64 import resource import os import json import pkg_resources import logging import operator from uuid import uuid4 from StringIO import StringIO import warnings from treq.client import HTTPClient from twisted.internet import reactor from twisted.internet.protocol import ProcessProtocol from twisted.internet.defer import ( Deferred, inlineCallbacks, maybeDeferred, returnValue, DeferredList, succeed) from twisted.internet.error import ProcessDone from twisted.python.failure import Failure from twisted.web.client import WebClientContextFactory, Agent from OpenSSL.SSL import ( VERIFY_PEER, VERIFY_FAIL_IF_NO_PEER_CERT, 
VERIFY_CLIENT_ONCE, VERIFY_NONE, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD) from vumi.config import ConfigText, ConfigInt, ConfigList, ConfigDict from vumi.application.base import ApplicationWorker from vumi.message import Message from vumi.errors import ConfigError from vumi.persist.txredis_manager import TxRedisManager from vumi.utils import load_class_by_string, HttpDataLimitError, to_kwargs from vumi import log from vumi.application.sandbox_rlimiter import SandboxRlimiter warnings.warn( "Use of vumi.application.sandbox is deprecated, the vumi sandbox worker " "and its components have moved to the vxsandbox package: " "pypi.python.org/pypi/vxsandbox", category=DeprecationWarning) class MultiDeferred(object): """A helper that hands out new deferreds each time and then fires them all together.""" NOT_FIRED = object() def __init__(self): self._result = self.NOT_FIRED self._deferreds = [] def callback(self, result): self._result = result for d in self._deferreds: d.callback(result) self._deferreds = [] def get(self): d = Deferred() if self.fired(): d.callback(self._result) else: self._deferreds.append(d) return d def fired(self): return self._result is not self.NOT_FIRED class SandboxError(Exception): """An error occurred inside the sandbox.""" class SandboxProtocol(ProcessProtocol): """A protocol for communicating over stdin and stdout with a sandboxed process. The sandbox process is created by calling :meth:`spawn`. This: * Spawns a new Python process that applies the supplied rlimits. * The spawned process then `execs` the supplied executable. Once a spawned process starts, the parent process communicates with it over `stdin`, `stdout` and `stderr`, reading and writing a stream of newline separated JSON commands that are parsed and formatted by :class:`SandboxCommand`. Incoming commands are dispatched to :class:`SandboxResource` instances via the supplied :class:`SandboxApi`. """ def __init__(self, sandbox_id, api, executable, spawn_kwargs, rlimits, timeout, recv_limit): self.sandbox_id = sandbox_id self.api = api self.executable = executable self.spawn_kwargs = spawn_kwargs self.rlimits = rlimits self._started = MultiDeferred() self._done = MultiDeferred() self._pending_requests = [] self.exit_reason = None self.timeout_task = reactor.callLater(timeout, self.kill) self.recv_limit = recv_limit self.recv_bytes = 0 self.chunk = '' self.error_chunk = '' self.error_lines = [] api.set_sandbox(self) def spawn(self): SandboxRlimiter.spawn( reactor, self, self.executable, self.rlimits, **self.spawn_kwargs) def done(self): """Returns a deferred that will be called when the process ends.""" return self._done.get() def started(self): """Returns a deferred that will be called once the process starts.""" return self._started.get() def kill(self): """Kills the underlying process.""" if self.transport.pid is not None: self.transport.signalProcess('KILL') def send(self, command): """Writes the command to the process's stdin.""" self.transport.write(command.to_json()) self.transport.write("\n") def check_recv(self, nbytes): self.recv_bytes += nbytes if self.recv_bytes <= self.recv_limit: return True else: self.kill() self.api.log("Sandbox %r killed for producing too much data on" " stderr and stdout."
% (self.sandbox_id), level=logging.ERROR) return False def connectionMade(self): self._started.callback(self) def _process_data(self, chunk, data): if not self.check_recv(len(data)): return [''] # skip the data if it's too big line_parts = data.split("\n") line_parts[0] = chunk + line_parts[0] return line_parts def _parse_command(self, line): try: return SandboxCommand.from_json(line) except Exception, e: return SandboxCommand(cmd="unknown", line=line, exception=e) def outReceived(self, data): lines = self._process_data(self.chunk, data) for i in range(len(lines) - 1): d = self.api.dispatch_request(self._parse_command(lines[i])) self._pending_requests.append(d) self.chunk = lines[-1] def outConnectionLost(self): if self.chunk: line, self.chunk = self.chunk, "" d = self.api.dispatch_request(self._parse_command(line)) self._pending_requests.append(d) def errReceived(self, data): lines = self._process_data(self.error_chunk, data) for i in range(len(lines) - 1): self.error_lines.append(lines[i]) self.error_chunk = lines[-1] def errConnectionLost(self): if self.error_chunk: self.error_lines.append(self.error_chunk) self.error_chunk = "" def _process_request_results(self, results): for success, result in results: if not success: # errors here are bugs in Vumi and thus should always # be logged via Twisted too. log.error(result) # we log them again in a simplified form via the sandbox # api so that the sandbox owner gets to see them too self.api.log(result.getErrorMessage(), logging.ERROR) def processEnded(self, reason): if self.timeout_task.active(): self.timeout_task.cancel() if isinstance(reason.value, ProcessDone): result = reason.value.status else: result = reason if not self._started.fired(): self._started.callback(Failure( SandboxError("Process failed to start."))) if self.error_lines: self.api.log("\n".join(self.error_lines), logging.ERROR) self.error_lines = [] requests_done = DeferredList(self._pending_requests) requests_done.addCallback(self._process_request_results) requests_done.addCallback(lambda _r: self._done.callback(result)) class SandboxResources(object): """Class for holding resources common to a set of sandboxes.""" def __init__(self, app_worker, config): self.app_worker = app_worker self.config = config self.resources = {} def add_resource(self, resource_name, resource): """Add additional resources -- should only be called before calling :meth:`setup_resources`.""" self.resources[resource_name] = resource def validate_config(self): # FIXME: The name of this method is a vicious lie. # It does not validate configs. It constructs resources objects. # Fixing that is beyond the scope of this commit, however. for name, config in self.config.iteritems(): cls = load_class_by_string(config.pop('cls')) self.resources[name] = cls(name, self.app_worker, config) @inlineCallbacks def setup_resources(self): for resource in self.resources.itervalues(): yield resource.setup() @inlineCallbacks def teardown_resources(self): for resource in self.resources.itervalues(): yield resource.teardown() class SandboxResource(object): """Base class for sandbox resources.""" # TODO: SandboxResources should probably have their own config definitions. # Is that overkill? 
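    # A minimal subclass sketch (illustrative, not part of vumi): a resource
    # exposes commands by defining handle_<command> methods and answering
    # with self.reply(). Configured under the name 'echo' in the worker's
    # `sandbox` config, it would be reachable from sandboxed code as
    # api.request('echo.echo', ...):
    #
    #     class EchoResource(SandboxResource):
    #         def handle_echo(self, api, command):
    #             return self.reply(command, success=True,
    #                               text=command.get('text'))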
def __init__(self, name, app_worker, config): self.name = name self.app_worker = app_worker self.config = config def setup(self): pass def teardown(self): pass def sandbox_init(self, api): pass def reply(self, command, **kwargs): return SandboxCommand(cmd=command['cmd'], reply=True, cmd_id=command['cmd_id'], **kwargs) def reply_error(self, command, reason): return self.reply(command, success=False, reason=reason) def dispatch_request(self, api, command): handler_name = 'handle_%s' % (command['cmd'],) handler = getattr(self, handler_name, self.unknown_request) return maybeDeferred(handler, api, command) def unknown_request(self, api, command): api.log("Resource %s received unknown command %r from" " sandbox %r. Killing sandbox. [Full command: %r]" % (self.name, command['cmd'], api.sandbox_id, command), logging.ERROR) api.sandbox_kill() # it's a harsh world class RedisResource(SandboxResource): """ Resource that provides access to a simple key-value store. Configuration options: :param dict redis_manager: Redis manager configuration options. :param int keys_per_user_soft: Maximum number of keys each user may make use of in redis before usage warnings are logged. (default: 80% of hard limit). :param int keys_per_user_hard: Maximum number of keys each user may make use of in redis (default: 100). Falls back to keys_per_user. :param int keys_per_user: Synonym for `keys_per_user_hard`. Deprecated. """ # FIXME: # - Currently we allow key expiry to be set. Keys that expire are # not decremented from the sandbox's key limit. This means that # some sandboxes might hit their key limit too soon. This is # better than not allowing expiry of keys and filling up Redis # though. @inlineCallbacks def setup(self): self.r_config = self.config.get('redis_manager', {}) self.keys_per_user_hard = self.config.get( 'keys_per_user_hard', self.config.get('keys_per_user', 100)) self.keys_per_user_soft = self.config.get( 'keys_per_user_soft', int(0.8 * self.keys_per_user_hard)) self.redis = yield TxRedisManager.from_config(self.r_config) def teardown(self): return self.redis.close_manager() def _count_key(self, sandbox_id): return "#".join(["count", sandbox_id]) def _sandboxed_key(self, sandbox_id, key): return "#".join(["sandboxes", sandbox_id, key]) def _too_many_keys(self, command): return self.reply(command, success=False, reason="Too many keys") @inlineCallbacks def check_keys(self, api, key): if (yield self.redis.exists(key)): returnValue(True) count_key = self._count_key(api.sandbox_id) key_count = yield self.redis.incr(count_key, 1) if key_count > self.keys_per_user_soft: if key_count < self.keys_per_user_hard: api.log('Redis soft limit of %s keys reached for sandbox %s. ' 'Once the hard limit of %s is reached no more keys ' 'can be written.' % ( self.keys_per_user_soft, api.sandbox_id, self.keys_per_user_hard), logging.WARNING) else: api.log('Redis hard limit of %s keys reached for sandbox %s. ' 'No more keys can be written.' % ( self.keys_per_user_hard, api.sandbox_id), logging.ERROR) yield self.redis.incr(count_key, -1) returnValue(False) returnValue(True) @inlineCallbacks def handle_set(self, api, command): """ Set the value of a key. Command fields: - ``key``: The key whose value should be set. - ``value``: The value to store. May be any JSON serializable object. - ``seconds``: Lifetime of the key in seconds. The default ``null`` indicates that the key should not expire. Reply fields: - ``success``: ``true`` if the operation was successful, otherwise ``false``. Example: .. 
code-block:: javascript api.request( 'kv.set', {key: 'foo', value: {x: '42'}}, function(reply) { api.log_info('Value stored: ' + reply.success); }); """ key = self._sandboxed_key(api.sandbox_id, command.get('key')) seconds = command.get('seconds') if not (seconds is None or isinstance(seconds, (int, long))): returnValue(self.reply_error( command, "seconds must be a number or null")) if not (yield self.check_keys(api, key)): returnValue(self._too_many_keys(command)) json_value = json.dumps(command.get('value')) if seconds is None: yield self.redis.set(key, json_value) else: yield self.redis.setex(key, seconds, json_value) returnValue(self.reply(command, success=True)) @inlineCallbacks def handle_get(self, api, command): """ Retrieve the value of a key. Command fields: - ``key``: The key whose value should be retrieved. Reply fields: - ``success``: ``true`` if the operation was successful, otherwise ``false``. - ``value``: The value retrieved. Example: .. code-block:: javascript api.request( 'kv.get', {key: 'foo'}, function(reply) { api.log_info( 'Value retrieved: ' + JSON.stringify(reply.value)); } ); """ key = self._sandboxed_key(api.sandbox_id, command.get('key')) raw_value = yield self.redis.get(key) value = json.loads(raw_value) if raw_value is not None else None returnValue(self.reply(command, success=True, value=value)) @inlineCallbacks def handle_delete(self, api, command): """ Delete a key. Command fields: - ``key``: The key to delete. Reply fields: - ``success``: ``true`` if the operation was successful, otherwise ``false``. Example: .. code-block:: javascript api.request( 'kv.delete', {key: 'foo'}, function(reply) { api.log_info('Value deleted: ' + reply.success); } ); """ key = self._sandboxed_key(api.sandbox_id, command.get('key')) existed = bool((yield self.redis.delete(key))) if existed: count_key = self._count_key(api.sandbox_id) yield self.redis.incr(count_key, -1) returnValue(self.reply(command, success=True, existed=existed)) @inlineCallbacks def handle_incr(self, api, command): """ Atomically increment the value of an integer key. The current value of the key must be an integer. If the key does not exist, it is set to zero. Command fields: - ``key``: The key to increment. - ``amount``: The integer amount to increment the key by. Defaults to 1. Reply fields: - ``success``: ``true`` if the operation was successful, otherwise ``false``. - ``value``: The new value of the key. Example: .. code-block:: javascript api.request( 'kv.incr', {key: 'foo', amount: 3}, function(reply) { api.log_info('New value: ' + reply.value); } ); """ key = self._sandboxed_key(api.sandbox_id, command.get('key')) if not (yield self.check_keys(api, key)): returnValue(self._too_many_keys(command)) amount = command.get('amount', 1) try: value = yield self.redis.incr(key, amount=amount) except Exception, e: returnValue(self.reply(command, success=False, reason=unicode(e))) returnValue(self.reply(command, value=int(value), success=True)) class OutboundResource(SandboxResource): """ Resource that provides the ability to send outbound messages. Includes support for replying to the sender of the current message, replying to the group the current message was from and sending messages that aren't replies. """ def handle_reply_to(self, api, command): """ Sends a reply to the individual who sent a received message. Command fields: - ``content``: The body of the reply message. - ``in_reply_to``: The ``message id`` of the message being replied to. - ``continue_session``: Whether to continue the session (if any).
Defaults to ``true``. Reply fields: - ``success``: ``true`` if the operation was successful, otherwise ``false``. Example: .. code-block:: javascript api.request( 'outbound.reply_to', {content: 'Welcome!', in_reply_to: '06233d4eede945a3803bf9f3b78069ec'}, function(reply) { api.log_info('Reply sent: ' + reply.success); }); """ content = command['content'] continue_session = command.get('continue_session', True) orig_msg = api.get_inbound_message(command['in_reply_to']) d = self.app_worker.reply_to(orig_msg, content, continue_session=continue_session) d.addCallback(lambda r: self.reply(command, success=True)) return d def handle_reply_to_group(self, api, command): """ Sends a reply to the group from which a received message was sent. Command fields: - ``content``: The body of the reply message. - ``in_reply_to``: The ``message id`` of the message being replied to. - ``continue_session``: Whether to continue the session (if any). Defaults to ``true``. Reply fields: - ``success``: ``true`` if the operation was successful, otherwise ``false``. Example: .. code-block:: javascript api.request( 'outbound.reply_to_group', {content: 'Welcome!', in_reply_to: '06233d4eede945a3803bf9f3b78069ec'}, function(reply) { api.log_info('Reply to group sent: ' + reply.success); }); """ content = command['content'] continue_session = command.get('continue_session', True) orig_msg = api.get_inbound_message(command['in_reply_to']) d = self.app_worker.reply_to_group(orig_msg, content, continue_session=continue_session) d.addCallback(lambda r: self.reply(command, success=True)) return d def handle_send_to(self, api, command): """ Sends a message to a specified address. Command fields: - ``content``: The body of the message to send. - ``to_addr``: The address of the recipient (e.g. an MSISDN). - ``endpoint``: The name of the endpoint to send the message via. Optional (default is ``"default"``). Reply fields: - ``success``: ``true`` if the operation was successful, otherwise ``false``. Example: .. code-block:: javascript api.request( 'outbound.send_to', {content: 'Welcome!', to_addr: '+27831234567', endpoint: 'default'}, function(reply) { api.log_info('Message sent: ' + reply.success); }); """ content = command['content'] to_addr = command['to_addr'] endpoint = command.get('endpoint', 'default') d = self.app_worker.send_to(to_addr, content, endpoint=endpoint) d.addCallback(lambda r: self.reply(command, success=True)) return d class JsSandboxResource(SandboxResource): """ Resource that initializes a Javascript sandbox. Typically used alongside vumi/application/sandboxer.js which is a simple node.js based Javascript sandbox. Requires the worker to have a `javascript_for_api` method. """ def sandbox_init(self, api): javascript = self.app_worker.javascript_for_api(api) app_context = self.app_worker.app_context_for_api(api) api.sandbox_send(SandboxCommand(cmd="initialize", javascript=javascript, app_context=app_context)) class LoggingResource(SandboxResource): """ Resource that allows a sandbox to log messages via Twisted's logging framework. """ def log(self, api, msg, level): """Logs a message via vumi.log (i.e. Twisted logging). Sub-classes should override this if they wish to log messages elsewhere. The `api` parameter is provided for use by such sub-classes. The `log` method should always return a deferred. """ return succeed(log.msg(msg, logLevel=level)) @inlineCallbacks def handle_log(self, api, command, level=None): """ Log a message at the specified severity level.
The other log commands are identical except that ``level`` need not be specified. Using the log-level specific commands is preferred. Command fields: - ``level``: The severity level to log at. Must be an integer log level. Default severity is the ``INFO`` log level. - ``msg``: The message to log. Reply fields: - ``success``: ``true`` if the operation was successful, otherwise ``false``. Example: .. code-block:: javascript api.request( 'log.log', {level: 20, msg: 'Abandon ship!'}, function(reply) { api.log_info('Logged: ' + reply.success); } ); """ level = command.get('level', level) if level is None: level = logging.INFO msg = command.get('msg') if msg is None: returnValue(self.reply(command, success=False, reason="Value expected for msg")) if not isinstance(msg, basestring): msg = str(msg) elif isinstance(msg, unicode): msg = msg.encode('utf-8') yield self.log(api, msg, level) returnValue(self.reply(command, success=True)) def handle_debug(self, api, command): """ Logs a message at the ``DEBUG`` log level. See :func:`handle_log` for details. """ return self.handle_log(api, command, level=logging.DEBUG) def handle_info(self, api, command): """ Logs a message at the ``INFO`` log level. See :func:`handle_log` for details. """ return self.handle_log(api, command, level=logging.INFO) def handle_warning(self, api, command): """ Logs a message at the ``WARNING`` log level. See :func:`handle_log` for details. """ return self.handle_log(api, command, level=logging.WARNING) def handle_error(self, api, command): """ Logs a message at the ``ERROR`` log level. See :func:`handle_log` for details. """ return self.handle_log(api, command, level=logging.ERROR) def handle_critical(self, api, command): """ Logs a message at the ``CRITICAL`` log level. See :func:`handle_log` for details. """ return self.handle_log(api, command, level=logging.CRITICAL) try: from twisted.web.client import BrowserLikePolicyForHTTPS from twisted.internet.ssl import optionsForClientTLS class HttpClientPolicyForHTTPS(BrowserLikePolicyForHTTPS): """ This client policy is used if we have Twisted 14.0.0 or newer and are not explicitly disabling host verification. """ def __init__(self, ssl_method=None): super(HttpClientPolicyForHTTPS, self).__init__() self.ssl_method = ssl_method def creatorForNetloc(self, hostname, port): options = {} if self.ssl_method is not None: options['method'] = self.ssl_method return optionsForClientTLS( hostname.decode("ascii"), extraCertificateOptions=options) except ImportError: HttpClientPolicyForHTTPS = None class HttpClientContextFactory(object): """ This context factory is used if we have a Twisted version older than 14.0.0 or if we are explicitly disabling host verification. """ def __init__(self, verify_options=None, ssl_method=None): self.verify_options = verify_options self.ssl_method = ssl_method def getContext(self, hostname, port): context = self._get_noverify_context() if self.verify_options in (None, VERIFY_NONE): # We don't want to do anything with verification here. return context if self.verify_options is not None: def verify_callback(conn, cert, errno, errdepth, ok): return ok context.set_verify(self.verify_options, verify_callback) return context def _get_noverify_context(self): """ Use ClientContextFactory directly and set the method if necessary. This will perform no host verification at all.
""" from twisted.internet.ssl import ClientContextFactory context_factory = ClientContextFactory() if self.ssl_method is not None: context_factory.method = self.ssl_method return context_factory.getContext() def make_context_factory(ssl_method=None, verify_options=None): if HttpClientPolicyForHTTPS is None or verify_options == VERIFY_NONE: return HttpClientContextFactory( verify_options=verify_options, ssl_method=ssl_method) else: return HttpClientPolicyForHTTPS(ssl_method=ssl_method) class HttpClientResource(SandboxResource): """ Resource that allows making HTTP calls to outside services. All command on this resource share a common set of command and response fields: Command fields: - ``url``: The URL to request - ``verify_options``: A list of options to verify when doing an HTTPS request. Possible string values are ``VERIFY_NONE``, ``VERIFY_PEER``, ``VERIFY_CLIENT_ONCE`` and ``VERIFY_FAIL_IF_NO_PEER_CERT``. Specifying multiple values results in passing along a reduced ``OR`` value (e.g. VERIFY_PEER | VERIFY_FAIL_IF_NO_PEER_CERT) - ``headers``: A dictionary of keys for the header name and a list of values to provide as header values. - ``data``: The payload to submit as part of the request. - ``files``: A dictionary, submitted as multipart/form-data in the request: .. code-block:: javascript [{ "field name": { "file_name": "the file name", "content_type": "content-type", "data": "data to submit, encoded as base64", } }, ...] The ``data`` field in the dictionary will be base64 decoded before the HTTP request is made. Success reply fields: - ``success``: Set to ``true`` - ``body``: The response body - ``code``: The HTTP response code Failure reply fields: - ``success``: set to ``false`` - ``reason``: Reason for the failure Example: .. code-block:: javascript api.request( 'http.get', {url: 'http://foo/'}, function(reply) { api.log_info(reply.body); }); """ DEFAULT_TIMEOUT = 30 # seconds DEFAULT_DATA_LIMIT = 128 * 1024 # 128 KB agent_class = Agent http_client_class = HTTPClient def setup(self): self.timeout = self.config.get('timeout', self.DEFAULT_TIMEOUT) self.data_limit = self.config.get('data_limit', self.DEFAULT_DATA_LIMIT) def _make_request_from_command(self, method, command): url = command.get('url', None) if not isinstance(url, basestring): return succeed(self.reply(command, success=False, reason="No URL given")) url = url.encode("utf-8") verify_map = { 'VERIFY_NONE': VERIFY_NONE, 'VERIFY_PEER': VERIFY_PEER, 'VERIFY_CLIENT_ONCE': VERIFY_CLIENT_ONCE, 'VERIFY_FAIL_IF_NO_PEER_CERT': VERIFY_FAIL_IF_NO_PEER_CERT, } method_map = { 'SSLv3': SSLv3_METHOD, 'SSLv23': SSLv23_METHOD, 'TLSv1': TLSv1_METHOD, } if 'verify_options' in command: verify_options = [verify_map[key] for key in command.get('verify_options', [])] verify_options = reduce(operator.or_, verify_options) else: verify_options = None if 'ssl_method' in command: # TODO: Fail better with unknown method. 
ssl_method = method_map[command['ssl_method']] else: ssl_method = None context_factory = make_context_factory( verify_options=verify_options, ssl_method=ssl_method) headers = command.get('headers', None) data = command.get('data', None) files = command.get('files', None) d = self._make_request(method, url, headers=headers, data=data, files=files, timeout=self.timeout, context_factory=context_factory, data_limit=self.data_limit) d.addCallback(self._make_success_reply, command) d.addErrback(self._make_failure_reply, command) return d def _make_request(self, method, url, headers=None, data=None, files=None, timeout=None, context_factory=None, data_limit=None): context_factory = (context_factory if context_factory is not None else WebClientContextFactory()) if headers is not None: headers = dict((k.encode("utf-8"), [x.encode("utf-8") for x in v]) for k, v in headers.items()) if data is not None: data = data.encode("utf-8") if files is not None: files = dict([ (key, (value['file_name'], value['content_type'], StringIO(base64.b64decode(value['data'])))) for key, value in files.iteritems()]) agent = self.agent_class(reactor, contextFactory=context_factory) http_client = self.http_client_class(agent) d = http_client.request(method, url, headers=headers, data=data, files=files, timeout=timeout) d.addCallback(self._ensure_data_limit, data_limit) return d def _ensure_data_limit(self, response, data_limit): header = response.headers.getRawHeaders('Content-Length') def data_limit_check(response, length): if data_limit is not None and length > data_limit: raise HttpDataLimitError( "Received %d bytes, maximum of %d bytes allowed." % (length, data_limit,)) return response if header is None: d = response.content() d.addCallback(lambda body: data_limit_check(response, len(body))) return d content_length = header[0] return maybeDeferred(data_limit_check, response, int(content_length)) def _make_success_reply(self, response, command): d = response.content() d.addCallback( lambda body: self.reply(command, success=True, body=body, code=response.code)) return d def _make_failure_reply(self, failure, command): return self.reply(command, success=False, reason=failure.getErrorMessage()) def handle_get(self, api, command): """ Make an HTTP GET request. See :class:`HttpClientResource` for details. """ return self._make_request_from_command('GET', command) def handle_put(self, api, command): """ Make an HTTP PUT request. See :class:`HttpClientResource` for details. """ return self._make_request_from_command('PUT', command) def handle_delete(self, api, command): """ Make an HTTP DELETE request. See :class:`HttpClientResource` for details. """ return self._make_request_from_command('DELETE', command) def handle_head(self, api, command): """ Make an HTTP HEAD request. See :class:`HttpClientResource` for details. """ return self._make_request_from_command('HEAD', command) def handle_post(self, api, command): """ Make an HTTP POST request. See :class:`HttpClientResource` for details. """ return self._make_request_from_command('POST', command) def handle_patch(self, api, command): """ Make an HTTP PATCH request. See :class:`HttpClientResource` for details.
""" return self._make_request_from_command('PATCH', command) class SandboxApi(object): """A sandbox API instance for a particular sandbox run.""" def __init__(self, resources, config): self._sandbox = None self._inbound_messages = {} self.resources = resources self.fallback_resource = SandboxResource("fallback", None, {}) potential_logger = None if config.logging_resource: potential_logger = self.resources.resources.get( config.logging_resource) if potential_logger is None: log.warning("Failed to find logging resource %r." " Falling back to Twisted logging." % (config.logging_resource,)) elif not hasattr(potential_logger, 'log'): log.warning("Logging resource %r has no attribute 'log'." " Falling abck to Twisted logging." % (config.logging_resource,)) potential_logger = None self.logging_resource = potential_logger self.config = config @property def sandbox_id(self): return self._sandbox.sandbox_id def set_sandbox(self, sandbox): if self._sandbox is not None: raise SandboxError("Sandbox already set (" "existing id: %r, new id: %r)." % (self.sandbox_id, sandbox.sandbox_id)) self._sandbox = sandbox def sandbox_init(self): for resource in self.resources.resources.values(): resource.sandbox_init(self) def sandbox_inbound_message(self, msg): self._inbound_messages[msg['message_id']] = msg self.sandbox_send(SandboxCommand(cmd="inbound-message", msg=msg.payload)) def sandbox_inbound_event(self, event): self.sandbox_send(SandboxCommand(cmd="inbound-event", msg=event.payload)) def sandbox_send(self, msg): self._sandbox.send(msg) def sandbox_kill(self): self._sandbox.kill() def get_inbound_message(self, message_id): return self._inbound_messages.get(message_id) def log(self, msg, level): if self.logging_resource is None: # fallback to vumi.log logging if we don't # have a logging resource. return succeed(log.msg(msg, logLevel=level)) else: return self.logging_resource.log(self, msg, level=level) @inlineCallbacks def dispatch_request(self, command): resource_name, sep, rest = command['cmd'].partition('.') if not sep: resource_name, rest = '', resource_name command['cmd'] = rest resource = self.resources.resources.get(resource_name, self.fallback_resource) try: reply = yield resource.dispatch_request(self, command) except Exception, e: # errors here are bugs in Vumi so we always log them # via Twisted. However, we reply to the sandbox with # a failure and log via the sandbox api so that the # sandbox owner can be notified. log.error() self.log(str(e), level=logging.ERROR) reply = SandboxCommand( reply=True, cmd_id=command['cmd_id'], success=False, reason=unicode(e)) if reply is not None: reply['cmd'] = '%s%s%s' % (resource_name, sep, rest) self.sandbox_send(reply) class SandboxCommand(Message): @staticmethod def generate_id(): return uuid4().get_hex() def process_fields(self, fields): fields = super(SandboxCommand, self).process_fields(fields) fields.setdefault('cmd', 'unknown') fields.setdefault('cmd_id', self.generate_id()) fields.setdefault('reply', False) return fields def validate_fields(self): super(SandboxCommand, self).validate_fields() self.assert_field_present( 'cmd', 'cmd_id', 'reply', ) @classmethod def from_json(cls, json_string): # We override this to avoid the datetime conversions. return cls(_process_fields=False, **to_kwargs(json.loads(json_string))) class SandboxConfig(ApplicationWorker.CONFIG_CLASS): sandbox = ConfigDict( "Dictionary of resources to provide to the sandbox." " Keys are the names of resources (as seen inside the sandbox)." 
" Values are dictionaries which must contain a `cls` key that" " gives the full name of the class that provides the resource." " Other keys are additional configuration for that resource.", default={}, static=True) executable = ConfigText( "Full path to the executable to run in the sandbox.") args = ConfigList( "List of arguments to pass to the executable (not including" " the path of the executable itself).", default=[]) path = ConfigText("Current working directory to run the executable in.") env = ConfigDict( "Custom environment variables for the sandboxed process.", default={}) timeout = ConfigInt( "Length of time the subprocess is given to process a message.", default=60) recv_limit = ConfigInt( "Maximum number of bytes that will be read from a sandboxed" " process' stdout and stderr combined.", default=1024 * 1024) rlimits = ConfigDict( "Dictionary of resource limits to be applied to sandboxed" " processes. Defaults are fairly restricted. Keys maybe" " names or values of the RLIMIT constants in" " Python `resource` module. Values should be appropriate integers.", default={}) logging_resource = ConfigText( "Name of the logging resource to use to report errors detected" " in sandboxed code (e.g. lines written to stderr, unexpected" " process termination). Set to null to disable and report" " these directly using Twisted logging instead.", default=None) sandbox_id = ConfigText("This is set based on individual messages.") class Sandbox(ApplicationWorker): """Sandbox application worker.""" CONFIG_CLASS = SandboxConfig KB, MB = 1024, 1024 * 1024 DEFAULT_RLIMITS = { resource.RLIMIT_CORE: (1 * MB, 1 * MB), resource.RLIMIT_CPU: (60, 60), resource.RLIMIT_FSIZE: (1 * MB, 1 * MB), resource.RLIMIT_DATA: (64 * MB, 64 * MB), resource.RLIMIT_STACK: (1 * MB, 1 * MB), resource.RLIMIT_RSS: (10 * MB, 10 * MB), resource.RLIMIT_NOFILE: (15, 15), resource.RLIMIT_MEMLOCK: (64 * KB, 64 * KB), resource.RLIMIT_AS: (196 * MB, 196 * MB), } def validate_config(self): config = self.get_static_config() self.resources = self.create_sandbox_resources(config.sandbox) self.resources.validate_config() def get_config(self, msg): config = self.config.copy() config['sandbox_id'] = self.sandbox_id_for_message(msg) return succeed(self.CONFIG_CLASS(config)) def _convert_rlimits(self, rlimits_config): rlimits = dict((getattr(resource, key, key), value) for key, value in rlimits_config.iteritems()) for key in rlimits.iterkeys(): if not isinstance(key, (int, long)): raise ConfigError("Unknown resource limit key %r" % (key,)) return rlimits def setup_application(self): return self.resources.setup_resources() def teardown_application(self): return self.resources.teardown_resources() def setup_connectors(self): # Set the default event handler so we can handle events from any # endpoint. 
d = super(Sandbox, self).setup_connectors() def cb(connector): connector.set_default_event_handler(self.dispatch_event) return connector return d.addCallback(cb) def create_sandbox_resources(self, config): return SandboxResources(self, config) def get_executable_and_args(self, config): return config.executable, [config.executable] + config.args def get_rlimits(self, config): rlimits = self.DEFAULT_RLIMITS.copy() rlimits.update(self._convert_rlimits(config.rlimits)) return rlimits def create_sandbox_protocol(self, api): executable, args = self.get_executable_and_args(api.config) rlimits = self.get_rlimits(api.config) spawn_kwargs = dict( args=args, env=api.config.env, path=api.config.path) return SandboxProtocol( api.config.sandbox_id, api, executable, spawn_kwargs, rlimits, api.config.timeout, api.config.recv_limit) def create_sandbox_api(self, resources, config): return SandboxApi(resources, config) def sandbox_id_for_message(self, msg_or_event): """Return a sandbox id for a message or event. Sub-classes may override this to retrieve an appropriate id. """ return msg_or_event['sandbox_id'] def sandbox_protocol_for_message(self, msg_or_event, config): """Return a sandbox protocol for a message or event. Sub-classes may override this to retrieve an appropriate protocol. """ api = self.create_sandbox_api(self.resources, config) protocol = self.create_sandbox_protocol(api) return protocol def _process_in_sandbox(self, sandbox_protocol, api_callback): sandbox_protocol.spawn() def on_start(_result): sandbox_protocol.api.sandbox_init() api_callback() d = sandbox_protocol.done() d.addErrback(log.error) return d d = sandbox_protocol.started() d.addCallbacks(on_start, log.error) return d @inlineCallbacks def process_message_in_sandbox(self, msg): config = yield self.get_config(msg) sandbox_protocol = yield self.sandbox_protocol_for_message(msg, config) def sandbox_init(): sandbox_protocol.api.sandbox_inbound_message(msg) status = yield self._process_in_sandbox(sandbox_protocol, sandbox_init) returnValue(status) @inlineCallbacks def process_event_in_sandbox(self, event): config = yield self.get_config(event) sandbox_protocol = yield self.sandbox_protocol_for_message( event, config) def sandbox_init(): sandbox_protocol.api.sandbox_inbound_event(event) status = yield self._process_in_sandbox(sandbox_protocol, sandbox_init) returnValue(status) def consume_user_message(self, msg): return self.process_message_in_sandbox(msg) def close_session(self, msg): return self.process_message_in_sandbox(msg) def consume_ack(self, event): return self.process_event_in_sandbox(event) def consume_nack(self, event): return self.process_event_in_sandbox(event) def consume_delivery_report(self, event): return self.process_event_in_sandbox(event) class JsSandboxConfig(SandboxConfig): "JavaScript sandbox configuration." javascript = ConfigText("JavaScript code to run.", required=True) app_context = ConfigText("Custom context to execute JS with.") logging_resource = ConfigText( "Name of the logging resource to use to report errors detected" " in sandboxed code (e.g. lines written to stderr, unexpected" " process termination). Set to null to disable and report" " these directly using Twisted logging instead.", default='log') class JsSandbox(Sandbox): """ Configuration options: As for :class:`Sandbox` except: * `executable` defaults to searching for a `node.js` binary. * `args` defaults to the JS sandbox script in the `vumi.application` module. 
* An instance of :class:`JsSandboxResource` is added to the sandbox resources under the name `js` if no `js` resource exists. * An instance of :class:`LoggingResource` is added to the sandbox resources under the name `log` if no `log` resource exists. * `logging_resource` is set to `log` if it is not set. * An extra 'javascript' parameter specifies the javascript to execute. * An extra optional 'app_context' parameter specifies a custom context for the 'javascript' application to execute with. Example 'javascript' that logs information via the sandbox API (provided as 'this' to 'on_inbound_message') and checks that logging was successful:: api.on_inbound_message = function(command) { this.log_info("From command: inbound-message", function (reply) { this.log_info("Log successful: " + reply.success); this.done(); }); } Example 'app_context' that makes the Node.js 'path' module available under the name 'path' in the context that the sandboxed javascript executes in:: {path: require('path')} """ CONFIG_CLASS = JsSandboxConfig POSSIBLE_NODEJS_EXECUTABLES = [ '/usr/local/bin/node', '/usr/local/bin/nodejs', '/usr/bin/node', '/usr/bin/nodejs', ] @classmethod def find_nodejs(cls): for path in cls.POSSIBLE_NODEJS_EXECUTABLES: if os.path.isfile(path): return path return None @classmethod def find_sandbox_js(cls): return pkg_resources.resource_filename('vumi.application', 'sandboxer.js') def get_js_resource(self): return JsSandboxResource('js', self, {}) def get_log_resource(self): return LoggingResource('log', self, {}) def javascript_for_api(self, api): """Called by JsSandboxResource. :returns: String containing Javascript for the app to run. """ return api.config.javascript def app_context_for_api(self, api): """Called by JsSandboxResource. :returns: String containing Javascript expression that returns additional context for the namespace the app is being run in. This Javascript is expected to be trusted code. """ return api.config.app_context def get_executable_and_args(self, config): executable = config.executable if executable is None: executable = self.find_nodejs() args = [executable] + (config.args or [self.find_sandbox_js()]) return executable, args def validate_config(self): super(JsSandbox, self).validate_config() if 'js' not in self.resources.resources: self.resources.add_resource('js', self.get_js_resource()) if 'log' not in self.resources.resources: self.resources.add_resource('log', self.get_log_resource()) class JsFileSandbox(JsSandbox): class CONFIG_CLASS(SandboxConfig): javascript_file = ConfigText( "The file containing the Javascript to run.", required=True) app_context = ConfigText("Custom context to execute JS with.") def javascript_for_api(self, api): return file(api.config.javascript_file).read() PK=JGkCvumi/application/__init__.py"""The vumi.application API.""" __all__ = ["ApplicationWorker", "SessionManager", "HTTPRelayApplication"] from vumi.application.base import ApplicationWorker from vumi.application.session import SessionManager from vumi.application.http_relay import HTTPRelayApplication PK=JGOvumi/application/session.py# -*- test-case-name: vumi.application.tests.test_session -*- """Session management utilities for ApplicationWorkers.""" import warnings import time from twisted.internet import task class SessionManager(object): """A manager for sessions. :type r_server: redis.Redis :param r_server: Redis db connection. :type prefix: str :param prefix: Prefix to use for Redis keys. :type max_session_length: float :param max_session_length: Time before a session expires.
Default is None (never expire). :type gc_period: float :param gc_period: Time in seconds between checking for session expiry. """ def __init__(self, r_server, prefix, max_session_length=None, gc_period=1.0): warnings.warn("vumi.application.SessionManager is deprecated. Use " "vumi.components.session instead.", category=DeprecationWarning) self.max_session_length = max_session_length self.r_server = r_server self.r_prefix = prefix self.gc = task.LoopingCall(lambda: self.active_sessions()) self.gc.start(gc_period) def stop(self): if self.gc.running: return self.gc.stop() def active_sessions(self): """ Return a list of active user_ids and associated sessions. Loops over known active_sessions, some of which might have auto expired. Implements lazy garbage collection: for each entry it checks if the user's session still exists, if not it is removed from the set. """ skey = self.r_key('active_sessions') sessions_to_expire = [] for user_id in self.r_server.smembers(skey): ukey = self.r_key('session', user_id) if self.r_server.exists(ukey): yield user_id, self.load_session(user_id) else: sessions_to_expire.append(user_id) # clear empty ones for user_id in sessions_to_expire: self.r_server.srem(skey, user_id) def r_key(self, *args): """ Generate a keyname using this worker's prefix """ parts = [self.r_prefix] parts.extend(args) return ":".join(parts) def load_session(self, user_id): """ Load session data from Redis """ ukey = self.r_key('session', user_id) return self.r_server.hgetall(ukey) def schedule_session_expiry(self, user_id, timeout): """ Schedule a session to time out Parameters ---------- user_id : str The user's id. timeout : int The number of seconds after which this session should expire """ ukey = self.r_key('session', user_id) self.r_server.expire(ukey, timeout) def create_session(self, user_id, **kwargs): """ Create a new session using the given user_id """ defaults = { 'created_at': time.time() } defaults.update(kwargs) self.save_session(user_id, defaults) if self.max_session_length: self.schedule_session_expiry(user_id, self.max_session_length) return self.load_session(user_id) def clear_session(self, user_id): ukey = self.r_key('session', user_id) self.r_server.delete(ukey) def save_session(self, user_id, session): """ Save a session Parameters ---------- user_id : str The user's id. session : dict The session info, nested dictionaries are not supported. Any values that are dictionaries are converted to strings by Redis. """ ukey = self.r_key('session', user_id) for s_key, s_value in session.items(): self.r_server.hset(ukey, s_key, s_value) skey = self.r_key('active_sessions') self.r_server.sadd(skey, user_id) return session PK=JG> K $vumi/application/sandbox_rlimiter.py# -*- test-case-name: vumi.application.tests.test_sandbox_rlimiter -*- """NOTE: This module is also used as a standalone Python program that is executed by the sandbox machinery. It must never, ever import non-stdlib modules. """ import os import sys import json import signal import resource class SandboxRlimiter(object): """This reads rlimits from its environment, applies them and then execs a new executable. It's necessary because Twisted's spawnProcess has no equivalent of the `preexec_fn` argument to :class:`subprocess.Popen`. See http://twistedmatrix.com/trac/ticket/4159.
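    An illustrative round trip (inferred from the code below): the parent
    serialises the rlimits dict to JSON and places it in the child's
    environment, e.g.::

        _SANDBOX_RLIMITS_='{"9": [67108864, 67108864]}'

    i.e. a mapping of RLIMIT numbers to ``[soft, hard]`` pairs (9 is
    ``RLIMIT_AS`` on Linux). This script applies each limit, capped to its
    own current hard limits, removes the variable from the environment and
    then execs the real sandbox executable.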
""" def __init__(self, argv, env): start = argv.index('--') + 1 self._executable = argv[start] self._args = [self._executable] + argv[start + 1:] self._env = env def _apply_rlimits(self): data = os.environ[self._SANDBOX_RLIMITS_] rlimits = json.loads(data) if data.strip() else {} for rlimit, (soft, hard) in rlimits.iteritems(): # Cap our rlimits to the maximum allowed. rsoft, rhard = resource.getrlimit(int(rlimit)) soft = min(soft, rsoft) hard = min(hard, rhard) resource.setrlimit(int(rlimit), (soft, hard)) def _reset_signals(self): # reset all signal handlers to their defaults for i in range(1, signal.NSIG): if signal.getsignal(i) == signal.SIG_IGN: signal.signal(i, signal.SIG_DFL) def _sanitize_fds(self): # close everything except stdin, stdout and stderr maxfds = resource.getrlimit(resource.RLIMIT_NOFILE)[1] os.closerange(3, maxfds) def execute(self): self._apply_rlimits() self._restore_child_env(os.environ) self._sanitize_fds() self._reset_signals() os.execvpe(self._executable, self._args, self._env) _SANDBOX_RLIMITS_ = "_SANDBOX_RLIMITS_" @classmethod def _override_child_env(cls, env, rlimits): """Put RLIMIT config in the env.""" env[cls._SANDBOX_RLIMITS_] = json.dumps(rlimits) @classmethod def _restore_child_env(cls, env): """Remove RLIMIT config.""" del env[cls._SANDBOX_RLIMITS_] @classmethod def script_name(cls): # we need to pass Python the actual filename of this script # (rather than using -m __name__) so that is doesn't import # Twisted's reactor (since that causes errors when we close # all the file handles if using certain reactors). script_name = __file__ if script_name.endswith('.pyc') or script_name.endswith('.pyo'): script_name = script_name[:-len('.pyc')] + '.py' return script_name @classmethod def spawn(cls, reactor, protocol, executable, rlimits, **kwargs): # spawns a SandboxRlimiter, connectionMade then passes the rlimits # through to stdin and the SandboxRlimiter applies them args = kwargs.pop('args', []) # the -u for unbuffered I/O is important (otherwise the process # execed will be very confused about where its stdin data has # gone) args = [sys.executable, '-u', cls.script_name(), '--'] + args env = kwargs.pop('env', {}) cls._override_child_env(env, rlimits) reactor.spawnProcess(protocol, sys.executable, args=args, env=env, **kwargs) if __name__ == "__main__": rlimiter = SandboxRlimiter(sys.argv, os.environ) rlimiter.execute() PKqG`!1!1"vumi/application/rapidsms_relay.py# -*- test-case-name: vumi.application.tests.test_rapidsms_relay -*- import json from base64 import b64encode from zope.interface import implements from twisted.internet.defer import ( inlineCallbacks, returnValue, DeferredList) from twisted.web import http from twisted.web.resource import Resource, IResource from twisted.web.server import NOT_DONE_YET from twisted.cred import portal, checkers, credentials, error from twisted.web.guard import HTTPAuthSessionWrapper, BasicCredentialFactory from vumi.application.base import ApplicationWorker from vumi.persist.txredis_manager import TxRedisManager from vumi.config import ( ConfigUrl, ConfigText, ConfigInt, ConfigDict, ConfigBool, ConfigContext, ConfigList) from vumi.message import to_json, TransportUserMessage from vumi.utils import http_request_full from vumi.errors import ConfigError, InvalidEndpoint from vumi import log class HealthResource(Resource): isLeaf = True def render_GET(self, request): request.setResponseCode(http.OK) request.do_not_log = True return 'OK' class SendResource(Resource): isLeaf = True def __init__(self, application): 
self.application = application Resource.__init__(self) def finish_request(self, request, msgs): request.setResponseCode(http.OK) request.write(to_json([msg.payload for msg in msgs])) request.finish() def fail_request(self, request, f): if f.check(BadRequestError): code = http.BAD_REQUEST else: code = http.INTERNAL_SERVER_ERROR log.err(f) request.setResponseCode(code) request.write(f.getErrorMessage()) request.finish() def render_(self, request): log.msg("Send request: %s" % (request,)) request.setHeader("content-type", "application/json; charset=utf-8") d = self.application.handle_raw_outbound_message(request) d.addCallback(lambda msgs: self.finish_request(request, msgs)) d.addErrback(lambda f: self.fail_request(request, f)) return NOT_DONE_YET def render_PUT(self, request): return self.render_(request) def render_GET(self, request): return self.render_(request) def render_POST(self, request): return self.render_(request) class BadRequestError(Exception): """Raised when an invalid request was received from RapidSMS.""" class RapidSMSRelayRealm(object): implements(portal.IRealm) def __init__(self, resource): self.resource = resource def requestAvatar(self, user, mind, *interfaces): if IResource in interfaces: return (IResource, self.resource, lambda: None) raise NotImplementedError() class RapidSMSRelayAccessChecker(object): implements(checkers.ICredentialsChecker) credentialInterfaces = (credentials.IUsernamePassword, credentials.IAnonymous) def __init__(self, get_avatar_id): self._get_avatar_id = get_avatar_id def requestAvatarId(self, credentials): return self._get_avatar_id(credentials) class RapidSMSRelayConfig(ApplicationWorker.CONFIG_CLASS): """RapidSMS relay configuration.""" web_path = ConfigText( "Path to listen for outbound messages from RapidSMS on.", static=True) web_port = ConfigInt( "Port to listen for outbound messages from RapidSMS on.", static=True) redis_manager = ConfigDict( "Redis manager configuration (only required if" " `allow_replies` is true)", default={}, static=True) allow_replies = ConfigBool( "Whether to support replies via the `in_reply_to` argument" " from RapidSMS.", default=True, static=True) vumi_username = ConfigText( "Username required when calling `web_path` (default: no" " authentication)", default=None) vumi_password = ConfigText( "Password required when calling `web_path`", default=None) vumi_auth_method = ConfigText( "Authentication method required when calling `web_path`." "The 'basic' method is currently the only available method", default='basic') vumi_reply_timeout = ConfigInt( "Number of seconds to keep original messages in redis so that" " replies may be sent via `in_reply_to`.", default=10 * 60) allowed_endpoints = ConfigList( 'List of allowed endpoints to send from.', required=True, default=("default",)) rapidsms_url = ConfigUrl("URL of the rapidsms http backend.") rapidsms_username = ConfigText( "Username to use for the `rapidsms_url` (default: no authentication)", default=None) rapidsms_password = ConfigText( "Password to use for the `rapidsms_url`", default=None) rapidsms_auth_method = ConfigText( "Authentication method to use with `rapidsms_url`." 
" The 'basic' method is currently the only available method.", default='basic') rapidsms_http_method = ConfigText( "HTTP request method to use for the `rapidsms_url`", default='POST') class RapidSMSRelay(ApplicationWorker): """Application that relays messages to RapidSMS.""" CONFIG_CLASS = RapidSMSRelayConfig ALLOWED_ENDPOINTS = None agent_factory = None # For swapping out the Agent we use in tests. def validate_config(self): self.supported_auth_methods = { 'basic': self.generate_basic_auth_headers, } def generate_basic_auth_headers(self, username, password): credentials = ':'.join([username, password]) auth_string = b64encode(credentials.encode('utf-8')) return { 'Authorization': ['Basic %s' % (auth_string,)] } def get_auth_headers(self, config): auth_method, username, password = (config.rapidsms_auth_method, config.rapidsms_username, config.rapidsms_password) if auth_method not in self.supported_auth_methods: raise ConfigError('HTTP Authentication method %s' ' not supported' % (repr(auth_method,))) if username is not None: handler = self.supported_auth_methods.get(auth_method) return handler(username, password) return {} def get_protected_resource(self, resource): checker = RapidSMSRelayAccessChecker(self.get_avatar_id) realm = RapidSMSRelayRealm(resource) p = portal.Portal(realm, [checker]) factory = BasicCredentialFactory("RapidSMS Relay") protected_resource = HTTPAuthSessionWrapper(p, [factory]) return protected_resource @inlineCallbacks def get_avatar_id(self, creds): # The ConfigContext(username=...) passed into .get_config() is to # allow sub-classes to change how config.vumi_username and # config.vumi_password are looked up by overriding .get_config(). if credentials.IAnonymous.providedBy(creds): config = yield self.get_config(None, ConfigContext(username=None)) # allow anonymous authentication if no username is configured if config.vumi_username is None: returnValue(None) elif credentials.IUsernamePassword.providedBy(creds): username, password = creds.username, creds.password config = yield self.get_config(None, ConfigContext(username=username)) if (username == config.vumi_username and password == config.vumi_password): returnValue(username) raise error.UnauthorizedLogin() @inlineCallbacks def setup_application(self): config = self.get_static_config() self.redis = None if config.allow_replies: self.redis = yield TxRedisManager.from_config(config.redis_manager) send_resource = self.get_protected_resource(SendResource(self)) self.web_resource = yield self.start_web_resources( [ (send_resource, config.web_path), (HealthResource(), 'health'), ], config.web_port) @inlineCallbacks def teardown_application(self): yield self.web_resource.loseConnection() if self.redis is not None: yield self.redis.close_manager() def _msg_key(self, message_id): return ":".join(["messages", message_id]) def _load_message(self, message_id): d = self.redis.get(self._msg_key(message_id)) d.addCallback(lambda r: r and TransportUserMessage.from_json(r)) return d def _store_message(self, message, timeout): msg_key = self._msg_key(message['message_id']) d = self.redis.set(msg_key, message.to_json()) d.addCallback(lambda r: self.redis.expire(msg_key, timeout)) return d @inlineCallbacks def _handle_reply_to(self, config, content, to_addrs, in_reply_to): if not config.allow_replies: raise BadRequestError("Support for `in_reply_to` not configured.") orig_msg = yield self._load_message(in_reply_to) if not orig_msg: raise BadRequestError("Original message %r not found." 
% (in_reply_to,)) if to_addrs: if len(to_addrs) > 1 or to_addrs[0] != orig_msg['from_addr']: raise BadRequestError( "Supplied `to_addrs` don't match `from_addr` of original" " message %r" % (in_reply_to,)) reply = yield self.reply_to(orig_msg, content) returnValue([reply]) def send_rapidsms_nonreply(self, to_addr, content, config, endpoint): """Call .send_to() for a message from RapidSMS that is not a reply. This is for overriding by sub-classes that need to add additional message options. """ return self.send_to(to_addr, content, endpoint=endpoint) def _handle_send_to(self, config, content, to_addrs, endpoint): sends = [] try: self.check_endpoint(config.allowed_endpoints, endpoint) for to_addr in to_addrs: sends.append(self.send_rapidsms_nonreply( to_addr, content, config, endpoint)) except InvalidEndpoint, e: raise BadRequestError(e) d = DeferredList(sends, consumeErrors=True) d.addCallback(lambda msgs: [msg[1] for msg in msgs if msg[0]]) return d @inlineCallbacks def handle_raw_outbound_message(self, request): config = yield self.get_config( None, ConfigContext(username=request.getUser())) data = json.loads(request.content.read()) content = data['content'] to_addrs = data['to_addr'] if not isinstance(to_addrs, list): raise BadRequestError( "Supplied `to_addr` (%r) was not a list." % (to_addrs,)) in_reply_to = data.get('in_reply_to') endpoint = data.get('endpoint') if in_reply_to is not None: msgs = yield self._handle_reply_to(config, content, to_addrs, in_reply_to) else: msgs = yield self._handle_send_to(config, content, to_addrs, endpoint) returnValue(msgs) @inlineCallbacks def _call_rapidsms(self, message): config = yield self.get_config(message) http_method = config.rapidsms_http_method.encode("utf-8") headers = self.get_auth_headers(config) yield self._store_message(message, config.vumi_reply_timeout) response = http_request_full( config.rapidsms_url.geturl(), message.to_json(), headers, http_method, agent_class=self.agent_factory) response.addCallback(lambda response: log.info(response.code)) response.addErrback(lambda failure: log.err(failure)) yield response def consume_user_message(self, message): return self._call_rapidsms(message) def close_session(self, message): return self._call_rapidsms(message) def consume_ack(self, event): log.info("Acknowledgement received for message %r" % (event['user_message_id'])) def consume_delivery_report(self, event): log.info("Delivery report received for message %r, status %r" % (event['user_message_id'], event['delivery_status'])) PK=JGf&vumi/application/tests/test_session.py"""Tests for vumi.application.session.""" import time from vumi.persist.fake_redis import FakeRedis from vumi.application import SessionManager from vumi.tests.helpers import VumiTestCase class TestSessionManager(VumiTestCase): def setUp(self): self.fake_redis = FakeRedis() self.add_cleanup(self.fake_redis.teardown) self.sm = SessionManager(self.fake_redis, prefix="test") self.add_cleanup(self.sm.stop) def test_active_sessions(self): def get_sessions(): return sorted(self.sm.active_sessions()) def ids(): return [x[0] for x in get_sessions()] self.assertEqual(ids(), []) self.sm.create_session("u1") self.assertEqual(ids(), ["u1"]) # 10 seconds later self.sm.create_session("u2", created_at=time.time() + 10) self.assertEqual(ids(), ["u1", "u2"]) s1, s2 = get_sessions() self.assertTrue(s1[1]['created_at'] < s2[1]['created_at']) def test_schedule_session_expiry(self): self.sm.max_session_length = 60.0 self.sm.create_session("u1") def test_create_and_retrieve_session(self): 
session = self.sm.create_session("u1") self.assertEqual(sorted(session.keys()), ['created_at']) self.assertTrue(time.time() - float(session['created_at']) < 10.0) loaded = self.sm.load_session("u1") self.assertEqual(loaded, session) def test_save_session(self): test_session = {"foo": 5, "bar": "baz"} self.sm.create_session("u1") self.sm.save_session("u1", test_session) session = self.sm.load_session("u1") self.assertTrue(session.pop('created_at') is not None) # Redis saves & returns all session values as strings self.assertEqual(session, dict([map(str, kvs) for kvs in test_session.items()])) def test_lazy_clearing(self): self.sm.save_session('user_id', {}) self.assertEqual(list(self.sm.active_sessions()), []) PK=JG!./vumi/application/tests/test_sandbox_rlimiter.py"""Tests for vumi.application.sandbox_rlimiter.""" from vumi.application import sandbox_rlimiter from vumi.application.sandbox_rlimiter import SandboxRlimiter from vumi.tests.helpers import VumiTestCase class TestSandboxRlimiter(VumiTestCase): def test_script_name_dot_py(self): self.patch(sandbox_rlimiter, '__file__', 'foo.py') self.assertEqual(SandboxRlimiter.script_name(), 'foo.py') def test_script_name_dot_pyc(self): self.patch(sandbox_rlimiter, '__file__', 'foo.pyc') self.assertEqual(SandboxRlimiter.script_name(), 'foo.py') def test_script_name_dot_pyo(self): self.patch(sandbox_rlimiter, '__file__', 'foo.pyo') self.assertEqual(SandboxRlimiter.script_name(), 'foo.py') PK=JGCڷvumi/application/tests/app.js// Demonstration App api.log_info("From init!"); api.on_unknown_command = function(command) { // Called for any command that doesn't have an explicit // command handler. this.log_info("From unknown: " + command.cmd); } api.on_inbound_message = function(command) { this.log_info("From command: inbound-message", function (reply) { this.log_info("Log successful: " + reply.success); this.done(); }); } PK=JG .vumi/application/tests/app_delayed_requests.js// Demonstration App api.log_info("From init!"); api.on_unknown_command = function(command) { setImmediate(function() { // Called for any command that doesn't have an explicit // command handler. api.log_info("From unknown: " + command.cmd); }); }; api.on_inbound_message = function(command) { setImmediate(function() { api.log_info("From command: inbound-message", function (reply) { api.log_info("Log successful: " + reply.success); api.done(); }); }); }; PK=JG}ކvumi/application/tests/utils.pyfrom twisted.internet.defer import inlineCallbacks from vumi.tests.utils import VumiWorkerTestCase, PersistenceMixin class ApplicationTestCase(VumiWorkerTestCase, PersistenceMixin): """ This is a base class for testing application workers. """ application_class = None def setUp(self): self._persist_setUp() super(ApplicationTestCase, self).setUp() @inlineCallbacks def tearDown(self): yield super(ApplicationTestCase, self).tearDown() yield self._persist_tearDown() def get_application(self, config, cls=None, start=True): """ Get an instance of a worker class. :param config: Config dict. :param cls: The Application class to instantiate. Defaults to :attr:`application_class` :param start: True to start the application (default), False otherwise. 
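        For example (illustrative; ``cls`` falls back to
        :attr:`application_class`)::

            worker = yield self.get_application({})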
Some default config values are helpfully provided in the interests of reducing boilerplate: * ``transport_name`` defaults to :attr:`self.transport_name` """ if cls is None: cls = self.application_class config = self.mk_config(config) config.setdefault('transport_name', self.transport_name) return self.get_worker(config, cls, start) def get_dispatched_messages(self): return self.get_dispatched_outbound() def wait_for_dispatched_messages(self, amount): return self.wait_for_dispatched_outbound(amount) def clear_dispatched_messages(self): return self.clear_dispatched_outbound() def dispatch(self, message, rkey=None, exchange='vumi'): if rkey is None: rkey = self.rkey('inbound') return self._dispatch(message, rkey, exchange) PKqG>'D4D4#vumi/application/tests/test_base.pyfrom twisted.internet.defer import inlineCallbacks, returnValue from vumi.application.base import ApplicationWorker, SESSION_NEW, SESSION_CLOSE from vumi.message import TransportUserMessage from vumi.application.tests.helpers import ApplicationHelper from vumi.tests.helpers import VumiTestCase, WorkerHelper from vumi.errors import InvalidEndpoint class DummyApplicationWorker(ApplicationWorker): ALLOWED_ENDPOINTS = frozenset(['default', 'outbound1']) def __init__(self, *args, **kwargs): super(DummyApplicationWorker, self).__init__(*args, **kwargs) self.record = [] def consume_unknown_event(self, event): self.record.append(('unknown_event', event)) def consume_ack(self, event): self.record.append(('ack', event)) def consume_nack(self, event): self.record.append(('nack', event)) def consume_delivery_report(self, event): self.record.append(('delivery_report', event)) def consume_user_message(self, message): self.record.append(('user_message', message)) def new_session(self, message): self.record.append(('new_session', message)) def close_session(self, message): self.record.append(('close_session', message)) class EchoApplicationWorker(ApplicationWorker): def consume_user_message(self, message): self.reply_to(message, message['content']) class TestApplicationWorker(VumiTestCase): @inlineCallbacks def setUp(self): self.app_helper = self.add_helper( ApplicationHelper(DummyApplicationWorker)) self.worker = yield self.app_helper.get_application({}) def assert_msgs_match(self, msgs, expected_msgs): for key in ['timestamp', 'message_id']: for msg in msgs + expected_msgs: self.assertTrue(key in msg.payload) msg[key] = 'OVERRIDDEN_BY_TEST' if not msg.get('routing_metadata'): msg['routing_metadata'] = {'endpoint_name': 'default'} for msg, expected_msg in zip(msgs, expected_msgs): self.assertEqual(msg, expected_msg) self.assertEqual(len(msgs), len(expected_msgs)) @inlineCallbacks def test_event_dispatch(self): events = [ ('ack', self.app_helper.make_ack()), ('nack', self.app_helper.make_nack()), ('delivery_report', self.app_helper.make_delivery_report()), ] for name, event in events: yield self.app_helper.dispatch_event(event) self.assertEqual(self.worker.record, [(name, event)]) del self.worker.record[:] @inlineCallbacks def test_unknown_event_dispatch(self): # temporarily pretend the worker doesn't know about acks del self.worker._event_handlers['ack'] bad_event = yield self.app_helper.make_dispatch_ack() self.assertEqual(self.worker.record, [('unknown_event', bad_event)]) @inlineCallbacks def test_user_message_dispatch(self): messages = [ ('user_message', self.app_helper.make_inbound("foo")), ('new_session', self.app_helper.make_inbound( "foo", session_event=SESSION_NEW)), ('close_session', self.app_helper.make_inbound( "foo", 
session_event=SESSION_CLOSE)), ] for name, message in messages: yield self.app_helper.dispatch_inbound(message) self.assertEqual(self.worker.record, [(name, message)]) del self.worker.record[:] @inlineCallbacks def test_reply_to(self): msg = self.app_helper.make_inbound("foo") yield self.worker.reply_to(msg, "More!") yield self.worker.reply_to(msg, "End!", False) replies = self.app_helper.get_dispatched_outbound() expecteds = [msg.reply("More!"), msg.reply("End!", False)] self.assert_msgs_match(replies, expecteds) @inlineCallbacks def test_waiting_message(self): # Get rid of the old worker. yield self.app_helper.cleanup_worker(self.worker) self.worker = None # Stick a message on the queue before starting the worker so it will be # received as soon as the message consumer starts consuming. msg = yield self.app_helper.make_dispatch_inbound("Hello!") # Start the app and process stuff. self.worker = yield self.app_helper.get_application( {}, EchoApplicationWorker) replies = yield self.app_helper.wait_for_dispatched_outbound(1) expecteds = [msg.reply("Hello!")] self.assert_msgs_match(replies, expecteds) @inlineCallbacks def test_reply_to_group(self): msg = self.app_helper.make_inbound("foo") yield self.worker.reply_to_group(msg, "Group!") replies = self.app_helper.get_dispatched_outbound() expecteds = [msg.reply_group("Group!")] self.assert_msgs_match(replies, expecteds) @inlineCallbacks def test_send_to(self): sent_msg = yield self.worker.send_to( '+12345', "Hi!", endpoint="default") sends = self.app_helper.get_dispatched_outbound() expecteds = [TransportUserMessage.send( '+12345', "Hi!", transport_name=None)] self.assert_msgs_match(sends, expecteds) self.assert_msgs_match(sends, [sent_msg]) @inlineCallbacks def test_send_to_with_different_endpoint(self): sent_msg = yield self.worker.send_to( '+12345', "Hi!", endpoint="outbound1", transport_type=TransportUserMessage.TT_USSD) sends = self.app_helper.get_dispatched_outbound() expecteds = [TransportUserMessage.send( '+12345', "Hi!", transport_type=TransportUserMessage.TT_USSD)] expecteds[0].set_routing_endpoint("outbound1") self.assert_msgs_match(sends, [sent_msg]) self.assert_msgs_match(sends, expecteds) def test_subclassing_api(self): worker = WorkerHelper.get_worker_raw( ApplicationWorker, {'transport_name': 'test'}) worker.consume_ack(self.app_helper.make_ack()) worker.consume_nack(self.app_helper.make_nack()) worker.consume_delivery_report(self.app_helper.make_delivery_report()) worker.consume_unknown_event(self.app_helper.make_inbound("foo")) worker.consume_user_message(self.app_helper.make_inbound("foo")) worker.new_session(self.app_helper.make_inbound("foo")) worker.close_session(self.app_helper.make_inbound("foo")) def get_app_consumers(self, app): for connector in app.connectors.values(): for consumer in connector._consumers.values(): yield consumer @inlineCallbacks def test_application_prefetch_count_custom(self): app = yield self.app_helper.get_application({ 'transport_name': 'test', 'amqp_prefetch_count': 10, }) for consumer in self.get_app_consumers(app): fake_channel = consumer.channel._fake_channel self.assertEqual(fake_channel.qos_prefetch_count, 10) @inlineCallbacks def test_application_prefetch_count_default(self): app = yield self.app_helper.get_application({ 'transport_name': 'test', }) for consumer in self.get_app_consumers(app): fake_channel = consumer.channel._fake_channel self.assertEqual(fake_channel.qos_prefetch_count, 20) @inlineCallbacks def test_application_prefetch_count_none(self): app = yield 
self.app_helper.get_application({ 'transport_name': 'test', 'amqp_prefetch_count': None, }) for consumer in self.get_app_consumers(app): fake_channel = consumer.channel._fake_channel self.assertEqual(fake_channel.qos_prefetch_count, 0) def assertNotRaises(self, error_class, f, *args, **kw): try: f(*args, **kw) except error_class as e: self.fail("%s unexpectedly raised: %s" % (error_class, e)) @inlineCallbacks def test_check_endpoints(self): app = yield self.app_helper.get_application({}) check = app.check_endpoint self.assertNotRaises(InvalidEndpoint, check, None, None) self.assertNotRaises(InvalidEndpoint, check, None, 'foo') self.assertNotRaises(InvalidEndpoint, check, ['default'], None) self.assertNotRaises(InvalidEndpoint, check, ['foo'], 'foo') self.assertRaises(InvalidEndpoint, check, [], None) self.assertRaises(InvalidEndpoint, check, ['foo'], 'bar') class TestApplicationWorkerWithSendToConfig(VumiTestCase): @inlineCallbacks def setUp(self): self.app_helper = self.add_helper( ApplicationHelper(DummyApplicationWorker)) self.worker = yield self.app_helper.get_application({ 'send_to': { 'default': { 'transport_name': 'default_transport', }, 'outbound1': { 'transport_name': 'outbound1_transport', }, }, }) def assert_msgs_match(self, msgs, expected_msgs): for key in ['timestamp', 'message_id']: for msg in msgs + expected_msgs: self.assertTrue(key in msg.payload) msg[key] = 'OVERRIDDEN_BY_TEST' if not msg.get('routing_metadata'): msg['routing_metadata'] = {'endpoint_name': 'default'} for msg, expected_msg in zip(msgs, expected_msgs): self.assertEqual(msg, expected_msg) self.assertEqual(len(msgs), len(expected_msgs)) @inlineCallbacks def send_to(self, *args, **kw): sent_msg = yield self.worker.send_to(*args, **kw) returnValue(sent_msg) @inlineCallbacks def test_send_to(self): sent_msg = yield self.send_to('+12345', "Hi!") sends = self.app_helper.get_dispatched_outbound() expecteds = [TransportUserMessage.send('+12345', "Hi!", transport_name='default_transport')] self.assert_msgs_match(sends, expecteds) self.assert_msgs_match(sends, [sent_msg]) @inlineCallbacks def test_send_to_with_options(self): sent_msg = yield self.send_to( '+12345', "Hi!", transport_type=TransportUserMessage.TT_USSD) sends = self.app_helper.get_dispatched_outbound() expecteds = [TransportUserMessage.send('+12345', "Hi!", transport_type=TransportUserMessage.TT_USSD, transport_name='default_transport')] self.assert_msgs_match(sends, expecteds) self.assert_msgs_match(sends, [sent_msg]) @inlineCallbacks def test_send_to_with_endpoint(self): sent_msg = yield self.send_to('+12345', "Hi!", "outbound1", transport_type=TransportUserMessage.TT_USSD) sends = self.app_helper.get_dispatched_outbound() expecteds = [TransportUserMessage.send('+12345', "Hi!", transport_type=TransportUserMessage.TT_USSD, transport_name='outbound1_transport')] expecteds[0].set_routing_endpoint("outbound1") self.assert_msgs_match(sends, expecteds) self.assert_msgs_match(sends, [sent_msg]) @inlineCallbacks def test_send_to_with_bad_endpoint(self): yield self.assertFailure( self.send_to('+12345', "Hi!", "outbound_unknown"), InvalidEndpoint) class TestApplicationMiddlewareHooks(VumiTestCase): TEST_MIDDLEWARE_CONFIG = { "middleware": [ {"mw1": "vumi.middleware.tests.utils.RecordingMiddleware"}, {"mw2": "vumi.middleware.tests.utils.RecordingMiddleware"}, ], } def setUp(self): self.app_helper = self.add_helper(ApplicationHelper(ApplicationWorker)) @inlineCallbacks def test_middleware_for_inbound_messages(self): app = yield self.app_helper.get_application( 
self.TEST_MIDDLEWARE_CONFIG) msgs = [] app.consume_user_message = msgs.append yield self.app_helper.make_dispatch_inbound("hi") [msg] = msgs self.assertEqual(msg['record'], [ ('mw1', 'inbound', self.app_helper.transport_name), ('mw2', 'inbound', self.app_helper.transport_name), ]) @inlineCallbacks def test_middleware_for_events(self): app = yield self.app_helper.get_application( self.TEST_MIDDLEWARE_CONFIG) msgs = [] app._event_handlers['ack'] = msgs.append yield self.app_helper.make_dispatch_ack() [msg] = msgs self.assertEqual(msg['record'], [ ('mw1', 'event', self.app_helper.transport_name), ('mw2', 'event', self.app_helper.transport_name), ]) @inlineCallbacks def test_middleware_for_outbound_messages(self): app = yield self.app_helper.get_application( self.TEST_MIDDLEWARE_CONFIG) orig_msg = self.app_helper.make_inbound("hi") yield app.reply_to(orig_msg, 'Hello!') msgs = self.app_helper.get_dispatched_outbound() [msg] = msgs self.assertEqual(msg['record'], [ ['mw2', 'outbound', self.app_helper.transport_name], ['mw1', 'outbound', self.app_helper.transport_name], ]) PK=JG+vumi/application/tests/app_requires_path.js// Demonstration App api.log_info("From init!"); api.on_inbound_message = function(command) { if (path) { this.log_info("We have access to path!"); } else { this.log_info("We don't have access to path. :("); } this.done(); } PK=JGqQei&vumi/application/tests/test_sandbox.py"""Tests for vumi.application.sandbox.""" import base64 import os import sys import json import resource import pkg_resources import logging from collections import defaultdict from datetime import datetime import warnings from OpenSSL.SSL import ( VERIFY_PEER, VERIFY_FAIL_IF_NO_PEER_CERT, VERIFY_NONE, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD) from twisted.internet.defer import ( inlineCallbacks, fail, succeed, DeferredQueue) from twisted.internet.error import ProcessTerminated from twisted.web.http_headers import Headers from vumi.application.sandbox import ( Sandbox, SandboxApi, SandboxCommand, SandboxResources, SandboxResource, RedisResource, OutboundResource, JsSandboxResource, LoggingResource, HttpClientResource, JsSandbox, JsFileSandbox, HttpClientContextFactory, HttpClientPolicyForHTTPS, make_context_factory) from vumi.application.tests.helpers import ( ApplicationHelper, find_nodejs_or_skip_test) from vumi.tests.utils import LogCatcher from vumi.tests.helpers import VumiTestCase, PersistenceHelper warnings.warn( "Use of vumi.application.tests.test_sandbox is deprecated, the vumi " "sandbox worker and its components have moved to the vxsandbox package:" "pypi.python.org/pypi/vxsandbox", category=DeprecationWarning) class MockResource(SandboxResource): def __init__(self, name, app_worker, **handlers): super(MockResource, self).__init__(name, app_worker, {}) for name, handler in handlers.iteritems(): setattr(self, "handle_%s" % name, handler) class ListLoggingResource(LoggingResource): def __init__(self, name, app_worker, config): super(ListLoggingResource, self).__init__(name, app_worker, config) self.msgs = [] def log(self, api, msg, level): self.msgs.append((level, msg)) class SandboxTestCaseBase(VumiTestCase): application_class = Sandbox def setUp(self): self.app_helper = self.add_helper( ApplicationHelper(self.application_class)) def setup_app(self, executable=None, args=None, extra_config=None): tmp_path = self.mktemp() os.mkdir(tmp_path) config = { 'path': tmp_path, 'timeout': '10', } if executable is not None: config['executable'] = executable if args is not None: config['args'] = args if 
extra_config is not None: config.update(extra_config) return self.app_helper.get_application(config) class TestSandbox(SandboxTestCaseBase): def setup_app(self, python_code, extra_config=None): return super(TestSandbox, self).setup_app( sys.executable, ['-c', python_code], extra_config=extra_config) @inlineCallbacks def test_bad_command_from_sandbox(self): app = yield self.setup_app( "import sys, time\n" "sys.stdout.write('{}\\n')\n" "sys.stdout.flush()\n" "time.sleep(5)\n" ) with LogCatcher(log_level=logging.ERROR) as lc: status = yield app.process_event_in_sandbox( self.app_helper.make_ack(sandbox_id='sandbox1')) [msg] = lc.messages() self.assertTrue(msg.startswith( "Resource fallback received unknown command 'unknown'" " from sandbox 'sandbox1'. Killing sandbox." " [Full command: vumi-0.6.9.dist-info/DESCRIPTION.rstVumi ==== NOTE: Version 0.6.x is backward-compatible with 0.5.x for the most part, with some caveats. The first few releases will be removing a bunch of obsolete and deprecated code and replacing some of the internals of the base worker. While this will almost certainly not break the majority of things built on vumi, old code or code that relies too heavily on the details of worker setup may need to be fixed. Documentation available online at http://vumi.readthedocs.org/ and in the `docs` directory of the repository. |vumi-ver| |vumi-ci| |vumi-cover| |python-ver| |vumi-docs| |vumi-downloads| |vumi-license| .. |vumi-ver| image:: https://pypip.in/v/vumi/badge.png?text=pypi :alt: Vumi version :scale: 100% :target: https://pypi.python.org/pypi/vumi .. |vumi-ci| image:: https://travis-ci.org/praekelt/vumi.png?branch=develop :alt: Vumi Travis CI build status :scale: 100% :target: https://travis-ci.org/praekelt/vumi .. |vumi-cover| image:: https://coveralls.io/repos/praekelt/vumi/badge.png?branch=develop :alt: Vumi coverage on Coveralls :scale: 100% :target: https://coveralls.io/r/praekelt/vumi .. |python-ver| image:: https://pypip.in/py_versions/vumi/badge.svg :alt: Python version :scale: 100% :target: https://pypi.python.org/pypi/vumi .. |vumi-docs| image:: https://readthedocs.org/projects/vumi/badge/?version=latest :alt: Vumi documentation :scale: 100% :target: http://vumi.readthedocs.org/ .. |vumi-downloads| image:: https://pypip.in/download/vumi/badge.svg :alt: Vumi downloads from PyPI :scale: 100% :target: https://pypi.python.org/pypi/vumi .. |vumi-license| image:: https://pypip.in/license/vumi/badge.svg :target: https://pypi.python.org/pypi/vumi :alt: Vumi license To build the docs locally:: $ virtualenv --no-site-packages ve/ $ source ve/bin/activate (ve)$ pip install Sphinx (ve)$ cd docs (ve)$ make html You'll find the docs in `docs/_build/index.html` You can contact the Vumi development team in the following ways: * via *email* by joining the `vumi-dev@googlegroups.com`_ mailing list * on *irc* in *#vumi* on the `Freenode IRC network`_ .. _vumi-dev@googlegroups.com: https://groups.google.com/forum/?fromgroups#!forum/vumi-dev .. _Freenode IRC network: https://webchat.freenode.net/?channels=#vumi Issues can be filed in the GitHub issue tracker. Please don't use the issue tracker for general support queries.
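The RapidSMS relay's handle_raw_outbound_message (vumi/application/rapidsms_relay.py, earlier in this archive) defines a small JSON-over-HTTP API: the POST body must carry 'content' and a 'to_addr' list, plus optional 'in_reply_to' and 'endpoint' fields, and the caller is identified via HTTP basic auth (request.getUser()). The following is a minimal client sketch, not part of the distribution; the relay URL, port, and credentials are illustrative assumptions, and only the JSON keys come from the handler itself.

# Minimal sketch of a client for the RapidSMS relay outbound API.
# RELAY_URL and the basic-auth credentials below are assumptions for
# illustration; the payload keys are the ones the handler reads.
import base64
import json
import urllib2

RELAY_URL = 'http://localhost:8000/send'  # assumed deployment URL

payload = {
    'content': 'Hello from RapidSMS',
    'to_addr': ['+27831234567'],  # must be a list, else BadRequestError
    # 'in_reply_to': '<message-id>',  # optional: reply to an inbound message;
    #                                 # to_addr must then match its from_addr
    # 'endpoint': 'default',          # optional: checked against allowed_endpoints
}

request = urllib2.Request(RELAY_URL, json.dumps(payload),
                          {'Content-Type': 'application/json'})
# The relay looks up its config by the basic-auth username.
request.add_header('Authorization',
                   'Basic ' + base64.b64encode('rapidsms:secret'))
print urllib2.urlopen(request).read()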
PK=HgglA"vumi-0.6.9.dist-info/metadata.json{"classifiers": ["Development Status :: 4 - Beta", "Intended Audience :: Developers", "License :: OSI Approved :: BSD License", "Operating System :: POSIX", "Programming Language :: Python", "Programming Language :: Python :: 2.7", "Topic :: Software Development :: Libraries :: Python Modules", "Topic :: System :: Networking"], "extensions": {"python.details": {"contacts": [{"email": "dev@praekeltfoundation.org", "name": "Praekelt Foundation", "role": "author"}], "document_names": {"description": "DESCRIPTION.rst"}, "project_urls": {"Home": "http://github.com/praekelt/vumi"}}}, "extras": [], "generator": "bdist_wheel (0.26.0)", "license": "BSD", "metadata_version": "2.0", "name": "vumi", "run_requires": [{"requires": ["PyYAML", "Twisted (>=13.2.0)", "certifi", "confmodel (>=0.2.0)", "cryptography", "hyperloglog", "iso8601", "pyOpenSSL", "python-smpp (>=0.1.5)", "pytz", "redis (>=2.10.0)", "riak (>=2.1)", "service-identity", "treq", "txAMQP (>=0.6.2)", "txJSON-RPC (==0.3.1)", "txTwitter (>=0.1.4a)", "txredis", "txssmi (>=0.3.0)", "wokkel", "zope.interface"]}], "summary": "Super-scalable messaging engine for the delivery of SMS, Star Menu and chat messages to diverse audiences in emerging markets and beyond.", "version": "0.6.9"}PK=H "vumi-0.6.9.dist-info/top_level.txttwisted vumi PK=H''\\vumi-0.6.9.dist-info/WHEELWheel-Version: 1.0 Generator: bdist_wheel (0.26.0) Root-Is-Purelib: true Tag: py2-none-any PK=HJ-3vumi-0.6.9.dist-info/METADATAMetadata-Version: 2.0 Name: vumi Version: 0.6.9 Summary: Super-scalable messaging engine for the delivery of SMS, Star Menu and chat messages to diverse audiences in emerging markets and beyond. Home-page: http://github.com/praekelt/vumi Author: Praekelt Foundation Author-email: dev@praekeltfoundation.org License: BSD Platform: UNKNOWN Classifier: Development Status :: 4 - Beta Classifier: Intended Audience :: Developers Classifier: License :: OSI Approved :: BSD License Classifier: Operating System :: POSIX Classifier: Programming Language :: Python Classifier: Programming Language :: Python :: 2.7 Classifier: Topic :: Software Development :: Libraries :: Python Modules Classifier: Topic :: System :: Networking Requires-Dist: PyYAML Requires-Dist: Twisted (>=13.2.0) Requires-Dist: certifi Requires-Dist: confmodel (>=0.2.0) Requires-Dist: cryptography Requires-Dist: hyperloglog Requires-Dist: iso8601 Requires-Dist: pyOpenSSL Requires-Dist: python-smpp (>=0.1.5) Requires-Dist: pytz Requires-Dist: redis (>=2.10.0) Requires-Dist: riak (>=2.1) Requires-Dist: service-identity Requires-Dist: treq Requires-Dist: txAMQP (>=0.6.2) Requires-Dist: txJSON-RPC (==0.3.1) Requires-Dist: txTwitter (>=0.1.4a) Requires-Dist: txredis Requires-Dist: txssmi (>=0.3.0) Requires-Dist: wokkel Requires-Dist: zope.interface PK=Hږvumi-0.6.9.dist-info/RECORDtwisted/plugins/vumi_worker_starter.py,sha256=vX00JU5zm3NL05NDNMpBQgXHQDQiA1CNb9Zd-V0f3JE,432 vumi/__init__.py,sha256=OmXM8GXzs5dixMv8uHBrqtk6Dllfdn0g4LIDcpKUL0E,68 vumi/config.py,sha256=7tUIkaqMH0j07d2VyEDyImR1HXJYgiWD-GGiZMBzpKc,3999 vumi/connectors.py,sha256=dxOVowCEJYHHvJv0E-ogJFbhW6jdclgFZH-2zFxI8gQ,7052 vumi/errors.py,sha256=8SB1GSIR0CfnPwoE-0qS5yvaM0JtR-bn7IySSyinCeY,602 vumi/log.py,sha256=TDHRbjaiKR7nGEdFGAAMHa-QgPeZoBv8Al2wh2zCiFM,952 vumi/message.py,sha256=nWRmrOm9urwns8Llbl4voOZO41P0Ccb2pClor16gz0Y,15886 vumi/multiworker.py,sha256=bICYO050h_DHHc3IF5OPJGBH_U4vn_uiAdw67co4IV4,1736 vumi/reconnecting_client.py,sha256=UO1yjBLZdz-h8k-pteXJ1UX-JtYb0eDh20sAaV9-DKA,7689 vumi/rpc.py,sha256=moa-NTlPTWi_eZoi2SX9sUEKzIjYBEGxhRFmzhIMQfk,8843 vumi/sentry.py,sha256=rYWpVQwwsz1jV6HJoSTivpsFHjyhbBMxFxZMvI2T9yA,5872 vumi/service.py,sha256=sGNusnz9MAMzrOZypuVtALFAUPbAmAu7wy49avV_ydI,17157 vumi/servicemaker.py,sha256=KwMGL6JgPRlmO5YLpIM0apdHXRITgs5gZKlxbahvm2I,7895 vumi/utils.py,sha256=Exn860YE27OPnAPmk3EtcNHx51H6erK_3ZrqmklE5DQ,15222 vumi/worker.py,sha256=T9kxmD9WySPBQZxYo11yiJPiYSNFpak2cFgQ7WQClh4,8252 vumi/application/__init__.py,sha256=N2eMtKNcGcIEyDVSodyvY1ogki8LkoH4Z4_w_pUdMOQ,273 vumi/application/base.py,sha256=QfNanjR82ZPDD-bQU6TNRlQGr4kmSzPnjtofMUAfxjQ,8037 vumi/application/http_relay.py,sha256=Q5tOcJyV3Rl-pYO1e7TzTPpTV6koGhZlpCI5Jf9ZPaE,3657 vumi/application/rapidsms_relay.py,sha256=1hNIcVSFFK2cFDDN75TJ-h2uwmt3yn-n1OkWS91F5pE,12577 vumi/application/sandbox.py,sha256=ZRni-L0WXn5Y8_RDJpqmPY9CjMgb21d-qkQ3B5Wz-Ww,50784 vumi/application/sandbox_rlimiter.py,sha256=TMjaSeGT73ve5Q5_44qq9ZmKMYB6_SiAfx0boeivY0U,3521 vumi/application/sandboxer.js,sha256=gwrDI2oYQiQLN3LrhHqjuk1DT8VvGXZJQ2XfuEp9B1w,4317
vumi/application/session.py,sha256=lmquge5G_VhNEcVyPt6Y9RSe42-VTZN6dezbY0AAyp0,3976 vumi/application/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/application/tests/app.js,sha256=-A_n1BjGburLt6s6QOf3yqjga2nWhpSdbQwBsNf1CaY,439 vumi/application/tests/app_delayed_requests.js,sha256=E9P2RojlVfWrOOq42CvkqXeJEnRGhA9gDmlpLJ5QY_0,541 vumi/application/tests/app_requires_path.js,sha256=uVHuwxYeXGgbfDi_ialfWUJ_YJcLIyjsRXB8nZ6tSHM,259 vumi/application/tests/helpers.py,sha256=Ud-GLejS3yDKJoNYA1MTJFMcYlWWLlcs3A7puIqftnk,3399 vumi/application/tests/test_base.py,sha256=7WdQSDqDsMCz3fEc6tCdMdNYVAYV5sznACx6DKMcQkU,13380 vumi/application/tests/test_http_relay.py,sha256=6VAnLghBzLVXkeICNqzI_F6wap87duwXXS0ZnVI5vCM,4393 vumi/application/tests/test_rapidsms_relay.py,sha256=gWhUHp-As8dZNAxCXOg-2dcc-1UAVQiATgpBA4aXtVg,11424 vumi/application/tests/test_sandbox.py,sha256=VVLmEcEAFJhRzVA9EuQ8fWozHLKZHCX0JiPHqfZEm6A,54291 vumi/application/tests/test_sandbox_rlimiter.py,sha256=MWrYKjNROw0VGsF9k64USyxln_Q3IV7tUtXGqDNh-CQ,745 vumi/application/tests/test_session.py,sha256=NaYG-caAGcMJebSFxDZW21lkiqRYE9jrVV79jzWFcfU,2042 vumi/application/tests/test_test_helpers.py,sha256=NdpEB_QhA2YdAXPajJrtDU8VUxs0bYttozCWaDlIXRw,4538 vumi/application/tests/utils.py,sha256=fk7MXoyK6L1kc3K7K4y0DXvCJlCqPA7T8MF2pFxqfe8,1766 vumi/blinkenlights/__init__.py,sha256=3_e2WrFis1IPtdmO0a8o2-ID9lOvsB_OFBLhzaumJSU,332 vumi/blinkenlights/message20110707.py,sha256=SsN3mmm__0aRtd6o34CLWKeeJ4fcI2ufSE8AFMZp4FU,3594 vumi/blinkenlights/message20110818.py,sha256=AyG5ZiIK6KQ0CxaWx8kc3HDWkeP3EZ-6JHTZKfC73ww,1088 vumi/blinkenlights/metrics.py,sha256=6XBAqZmmfMs8lyAj7Uje83Q4tkDhiShE-GwQoWzLOyc,14009 vumi/blinkenlights/metrics_workers.py,sha256=2DDL6i3fUBJ0Ei0h8F0f2h6wpMxDsQy8MUBCQ9iwQ4Q,14906 vumi/blinkenlights/heartbeat/__init__.py,sha256=l8Zdd0svp05i0feUh_cGMhSHtMZaLEVGFxX5-frOcg0,229 vumi/blinkenlights/heartbeat/monitor.py,sha256=887mrOPi0mGptmzZLJ1vLrquzmqjeDVITc2lPv-HVbk,9473 vumi/blinkenlights/heartbeat/publisher.py,sha256=SJ0-Y7Qi1twZ56qaz2HocpIdjSrvv7HXsW8TEyhjtOQ,2037 vumi/blinkenlights/heartbeat/storage.py,sha256=ld8tBrfmKObv-WQqSi95AKYMgZ4ujOwtkZW8KC7_qH8,1865 vumi/blinkenlights/heartbeat/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/blinkenlights/heartbeat/tests/test_monitor.py,sha256=FkvLUEICaEZ0JZ28_Zn2P3uBDO373tLJa-sv8Kd-X5g,8613 vumi/blinkenlights/heartbeat/tests/test_publisher.py,sha256=xzK1S9_TKZH7Gn3LsHz8mmBE0vYD0TAzHHclsh_OHB0,1931 vumi/blinkenlights/heartbeat/tests/test_storage.py,sha256=mu6VcvqJY4Rss60EWi8ugbfWge5DudXgHLu2u0PhAxM,2579 vumi/blinkenlights/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/blinkenlights/tests/test_message20110707.py,sha256=IMW3Dw54VOK64Euo9B6uAgLyAyMxEl3N8Jv3NIE-y7E,5351 vumi/blinkenlights/tests/test_message20110818.py,sha256=dfVJk5RZYPQ6Nxofk5TCpz3dGmbkbLsH6yKStn2PqTI,1004 vumi/blinkenlights/tests/test_metrics.py,sha256=5hdxNRRoINXYqMIxKGRUzwSTXoE_X07zfkBa6peGXvc,17818 vumi/blinkenlights/tests/test_metrics_workers.py,sha256=IBzY5pon0zARfi7Tr_7DT6b8jJGsqmQ4sEPhEeRHydo,13681 vumi/codecs/__init__.py,sha256=3PKbFzD2KerAr6M4DX5ettI_qm-BXpXPeRSzcyUlIaE,71 vumi/codecs/ivumi_codecs.py,sha256=uOIjFA0F9L-jQsiRoP8TgD_I_1tT5zUfqC3XpvBZryY,395 vumi/codecs/vumi_codecs.py,sha256=wzVmjBRgkuzZSbEAlmePO4kixwsRl6n8RaBkVCzGgdU,5088 vumi/codecs/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/codecs/tests/test_vumi_codecs.py,sha256=CVE2ECWfzVxvt0GCt4jAUjwUfFkCGrTgc_7FQrwjWZ8,3353 
vumi/components/__init__.py,sha256=Du-U8PBYr5R3LDefDcOv8lcEVPRjDnC4AOIDzGkz4YI,33 vumi/components/message_formatters.py,sha256=QhsSk_rTxbK1v2G2GE_0xnJv4S3PkH2AhmLqm2knHas,2893 vumi/components/message_store.py,sha256=spR4-mh2Su_IZd-cujwxofQqnW_-jtR5Bw69sqw-xDg,42913 vumi/components/message_store_api.py,sha256=JkeMlT7HuFksTWtoTI_JWFvqbazJPpIda1TXjwhUq1Y,8350 vumi/components/message_store_cache.py,sha256=GvMQXXywlJqSg97LsOb1AjnxucSiKr1c9GtZu6LSCZI,24603 vumi/components/message_store_migrators.py,sha256=ozbx2oT1hY1jhjOD01jidVDwAVWM0EsLYRXnxqRrj9w,9839 vumi/components/message_store_resource.py,sha256=_jP21g1y-CKeNi_t7tONwbjhxbeX5eTlYSq-bu_uO4w,10887 vumi/components/schedule_manager.py,sha256=AXXlmkxVx6lAVB85qMLJIFQsh--C67HLR4u2xYzBe4A,4350 vumi/components/session.py,sha256=9V-HpBvsKJEeqnbAsvOSlj1HHB98WkLVeMsmFJdgyEM,3834 vumi/components/tagpool.py,sha256=J12F0DrqBW8AC8DRNMja1tHtJZkL9tVyAxOzCC5bFxI,8957 vumi/components/tagpool_api.py,sha256=sdKshRmeLVMmUlWuoBJmy4Bq7FUK1x_drbDMyPHlSOk,6493 vumi/components/window_manager.py,sha256=PwuBJ4WdzZFFkf1-b1BTSEpd1t-xgFFk8fCMob2r2YQ,7542 vumi/components/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/components/tests/message_store_old_models.py,sha256=swzC9gpMwQ2kt4cztzcD1V2TNBCZfAIlZcjFOn20knk,9317 vumi/components/tests/test_message_formatters.py,sha256=IQR7dyn0FnP3EjEpjdFt7CmFpo_vg1Wq49_u_GL2aSE,5930 vumi/components/tests/test_message_store.py,sha256=JZPBCoJsZkiilB02N9HNxIjG1fYe34O21Ue5jYEWtkU,81214 vumi/components/tests/test_message_store_api.py,sha256=ihFNlsHE7RbO9C3NeVXpsfqtbj-1k8qEvI28nWP-tGU,11269 vumi/components/tests/test_message_store_cache.py,sha256=PuMzYcCE6kgvWkZUUpmILS-DKtE7ap8nQv3OiU5ZT2w,21218 vumi/components/tests/test_message_store_migrators.py,sha256=oD0qLN7dp10X50U1HaQ3cFKoUFAaR9QEyXN-BKe7dNM,50004 vumi/components/tests/test_message_store_resource.py,sha256=VSdVRJ8CSQPmAt_H7ah2gdkk069bJxr1mekyPL37SpE,21522 vumi/components/tests/test_schedule_manager.py,sha256=QFMqrwChm6SBLDMP0jV0Mq6KfaP7M-QCNqB8l62rHX0,6069 vumi/components/tests/test_session.py,sha256=ZDQc58pTqjjyTnJBtPPuW9tp1QLpQmmFv39cSYQO5os,2785 vumi/components/tests/test_tagpool.py,sha256=GG5aYH5XVIAryo6n-bqYAaMsndftZe0KhZRKVl_7YGE,10817 vumi/components/tests/test_tagpool_api.py,sha256=Rj05HdMFciH49cVv-fcJhf-s4IBFf8IfF_mCYTJU3aQ,10732 vumi/components/tests/test_window_manager.py,sha256=FUZ5E6TR_kjNl5a8I6kqIpvpYx1EQocsTZ4kh-azJ90,9682 vumi/demos/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/demos/calculator.py,sha256=pybW_6S9MPAw8oVFpUCScXLAu8zS6Cf_LG2FIGPPMt4,2737 vumi/demos/hangman.py,sha256=fYFCBYUWrMIjLYwvT7nrKh4zkQI6hKuA-Mk-BfKhsKY,7923 vumi/demos/ircbot.py,sha256=Qtf9shIfMtY4r5bXNCrrwvB4xZvbkJvUQLGO2U9WkhI,2928 vumi/demos/rps.py,sha256=fmE1i7ivi3IaiYqcVv-MtNdYfCOW_NFtG7jFGUuwX38,6205 vumi/demos/static_reply.py,sha256=xpJN_vGO_wi7TcUk8uIgb3XRg_jBgrfWJ3RxJZ_IzD0,903 vumi/demos/tictactoe.py,sha256=meCK3N6mds7_DdHqUTw1BUgAVage2Xuu8dNPm2WGGcg,4848 vumi/demos/words.py,sha256=BEEm6vQK6a8bvt-RQYuFUAEkcO94TG3ub9QD0MucvCQ,1988 vumi/demos/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/demos/tests/test_calculator.py,sha256=W4ZdhKgMcNIo7nw2n0CQsPaMfTe2xPCmnzlMPXquRfM,3682 vumi/demos/tests/test_hangman.py,sha256=TxjRiO_llsrzrYLAjJsvFJekeOwc-VcHZef3IkBj5lY,8987 vumi/demos/tests/test_ircbot.py,sha256=csuKOuGUfwAbeYwalYhhTM5yYjjbn1wcZHBK_HpS160,3407 vumi/demos/tests/test_rps.py,sha256=4ZtOBOJqGuAueuAwKbqU9MYA16TROp2VWarkZIouZyA,4557 
vumi/demos/tests/test_static_reply.py,sha256=dCnyn1Dh1gpDRuvfBg_eeYun7kOKOxL7ua5vGzuQmKI,1286 vumi/demos/tests/test_tictactoe.py,sha256=MV12y3TxysOH8-mEMagnWYHZ8XYMlRS8JDLt3-TIn2Q,5315 vumi/demos/tests/test_words.py,sha256=XLYsax2JerKT8oqZsV2NS2WdBipilP4BKm4uvjEz1yk,3657 vumi/demos/tests/wikipedia_sample.xml,sha256=Rz1HZMlGCzdlwYpGQGwsPX1gmV4WM-_CnvDXBjnzTzQ,2535 vumi/dispatchers/__init__.py,sha256=XYyezSpnoWgW5jTRmXEK2HbkvMr62FVhyCRs4Bundmc,611 vumi/dispatchers/base.py,sha256=gp7EvF_EMb4MoBJab4CGP9lFzVuixxEcCwCC_x9Q_X8,24645 vumi/dispatchers/endpoint_dispatchers.py,sha256=sUi93tFQOa71Uezuz5OuPwwUL4cEaUA_ImcspcLWvJM,5949 vumi/dispatchers/load_balancer.py,sha256=Jq_lpwbt_aLwpSkUQmPspBlND-K--ZwzHBjP0Rg9Wkk,3626 vumi/dispatchers/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/dispatchers/tests/helpers.py,sha256=t-GNpm-3mMRbc2C_6cJnzdN60EJkNbSdz5wZLkHv2GM,4376 vumi/dispatchers/tests/test_base.py,sha256=OXn8u34DJSI9hnsqhTd__X5y4a4h-tiFmdAjMX0Un20,29225 vumi/dispatchers/tests/test_endpoint_dispatchers.py,sha256=3uX8Y6WpKbVcUCcCzf0Bq-jeQIwNs9994Z5BaH-lHfw,9472 vumi/dispatchers/tests/test_load_balancer.py,sha256=OiP1AK4grmD2nFCrpT24dcW8FI3Jmu3I2YbFV8YrNI4,7477 vumi/dispatchers/tests/test_test_helpers.py,sha256=uoC4jamKOaa-golNvCZXJKAywuqy2HT-cpwI7fIWxek,7475 vumi/dispatchers/tests/utils.py,sha256=L0DFSCe0ma4sXEnbnbO8xTMrkhSd5Wok72hEPcwgMlk,2173 vumi/middleware/__init__.py,sha256=D2GmAHiK5vxsSiz2IfoX-q2d9JETH-Qj5HyZXu7zpGk,447 vumi/middleware/address_translator.py,sha256=E9jZEfdCvd-eQWqthOIvr4t1hTY539WIr2lydaoL-HA,2066 vumi/middleware/base.py,sha256=CjPl16JWacgUrduE8Ay4a2lmiETaoGDxemRqQ-XXrbg,10142 vumi/middleware/logging.py,sha256=xKfquVrORmcum5X11M80dvzkJilCrYYZqqYLJMGcF-0,2175 vumi/middleware/manhole.py,sha256=lKdaEGclU_JihyHJWPhVl4DXG_Pwlpyhl6bBYW3sHp8,2886 vumi/middleware/manhole_utils.py,sha256=z9goLdUZApQXwrWUzLGzLXNlHqKFm3vXZ5lPXkxuXmY,1849 vumi/middleware/message_storing.py,sha256=JwEnIIor-7cgQxCB8ivbZ8DDCrMhwCF287AgbwqdFGs,4736 vumi/middleware/provider_setter.py,sha256=CSB3TIt7GcoCLCOCle200cIC21y9DKfeNNaiTGRBtOs,5770 vumi/middleware/session_length.py,sha256=E3KXdKDAF6gbFrVGqgaLhlrU5NlKUH2JX9BZfi5ONU4,6904 vumi/middleware/tagger.py,sha256=afGUPW733yT1JTzVrJoiRSeH3ShHB7QKvvNTCz9S13Q,8654 vumi/middleware/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/middleware/tests/test_address_translator.py,sha256=Q9v-JtyuW07xaz8QQNfXhzrT7AVzNR8gT5f75jqv-jU,1322 vumi/middleware/tests/test_base.py,sha256=H_4aPlEL2s2esc9c3YVsdMHBnFEWnYKFYfecl7aSPV4,12211 vumi/middleware/tests/test_logging.py,sha256=OMTkEsIDG6c5oa4rm2HTTrpptnG6hfsk3J3hJk9G85I,2545 vumi/middleware/tests/test_manhole.py,sha256=C8bZ6RnwI1IL-udsHA6ZaSwtt7VkNt2YZ5BgvvKnTpM,2903 vumi/middleware/tests/test_message_storing.py,sha256=SLmbP6r9DMTfMVIXDRkijQXZO0_yx1S9lYa2vAYsZ5U,8119 vumi/middleware/tests/test_provider_setter.py,sha256=1zlE-plkZhoZTZ0D4qVtIcRRZ2EZiANXYxg1EWSuXgM,9287 vumi/middleware/tests/test_session_length.py,sha256=u2uKR5BPj-W8eafdsLOVwvfvKOrBlXDAk5Bmr9BE1HA,22912 vumi/middleware/tests/test_tagger.py,sha256=yW0SY5kjdIa91iRL2JOj696P_X4wvtUyN_C2iI1DhYA,5053 vumi/middleware/tests/utils.py,sha256=3LxCmYHNrhhS-6xqIrPdM7WE7iUVA7dOjUxXIu6Bktg,982 vumi/persist/__init__.py,sha256=uGw1imNTsn9JXsRaWrbY75gISRbrYKE8jXd3Mz5OVG8,28 vumi/persist/ast_magic.py,sha256=COHduyQo9yE-E2LWljhKykkw5uVmJ31TpC9elACOHqE,2096 vumi/persist/fake_redis.py,sha256=dfscIDpxgQVd2GjuIndm5VLMrK2i3E69PkQZ3SpiSyY,18716 vumi/persist/fields.py,sha256=5VEoxtJ07F1TrxyWveyQEvDF_6SlPub_qd_HYpsVkD0,36547 
vumi/persist/model.py,sha256=qiwMKS8DTLTjHEf31zX3vezGQl0c5d6xml9uqSDS81A,38219 vumi/persist/redis_base.py,sha256=EIgvFviWHpnug09lx4WLM7w2bKybHVmR2PsT3FTkAQs,10551 vumi/persist/redis_manager.py,sha256=cEWfceD-8f90vCgI5HBhZv1uR2yH6nNEdBmV8dSpLYo,3244 vumi/persist/riak_base.py,sha256=HC4LmDUhfm8fgojXWhoztyj15dEC_atA94ys9-RjlRw,6645 vumi/persist/riak_manager.py,sha256=4xD3CCsnIsEoZ5y-Ats6VAjUi12LUnAylS_Kdc4Dz70,7835 vumi/persist/txredis_manager.py,sha256=W6p5pMd9s_eIaNdNXu-9FY0yssWKYSzyPm4KYqI3Nf4,11252 vumi/persist/txriak_manager.py,sha256=P4oAlp9TSdjeegtEiVySM_bzr0dG76NdKB8_jqTVJgc,9183 vumi/persist/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/persist/tests/test_fake_redis.py,sha256=mMeamC6LiWo22uaZg_CgwSQDWDCA-GjOS7UPjK_Z4C8,33777 vumi/persist/tests/test_fields.py,sha256=ffv2HyTRbX52kyCoyq8kYFQBLMgIJwWZ8DbGfPipXFU,47103 vumi/persist/tests/test_model.py,sha256=TeLV68ZSXdqJq27z7jD1bLCUD9rFTuKA-u28HqVYEkc,45245 vumi/persist/tests/test_redis_base.py,sha256=quL0cEgRfZLEtKLidhAHIrTypT2RYQ7eF_Hg4pndaDo,1818 vumi/persist/tests/test_redis_manager.py,sha256=KWKCYXPCO2X7qUypYW3cAbOwEKSLnuMoWLezqESg9r0,2739 vumi/persist/tests/test_riak_manager.py,sha256=EcyAM3VtQ1pOZpEfqbN4BowpEzXCy4iIC53ZJYfGow4,3682 vumi/persist/tests/test_txredis_manager.py,sha256=z3W3Qv4t5uBP3hwf60lSJqj8WQG_J5FkSdPJ7kiGMQk,4676 vumi/persist/tests/test_txriak_manager.py,sha256=YfLmkKUTDlgbtAA-vnMwf-S7dzH5-eck9w4PSegPsdI,14118 vumi/resources/__init__.py,sha256=SEN2jht5YsvRkUUTz84r0Lt4pPOensicAl4EWd_s2eU,54 vumi/resources/amqp-spec-0-8.xml,sha256=fkr5trOYUD2xFYQZ5sev7NW5C1uJTIHDKcIU2RkAb2U,145900 vumi/resources/amqp-spec-0-9-1.xml,sha256=r2gmM9Dzt92IkHhYIWLbkTlowPvrXxXwSqO4cXHj2yc,119305 vumi/resources/amqp-spec-0-9.xml,sha256=_iM48u1xgaIpZKj00HXdKvvgS2D-tIN2-gcAojjA8So,216878 vumi/scripts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/scripts/benchmark_persist.py,sha256=gVjhjwpj1_OqzhYkjVHXm_ZqTu378ZX3kB0tsaR_fE0,4255 vumi/scripts/db_backup.py,sha256=kgLnCSVPr4Blqu27u8g-aFgVsc8grqyMYhOFAJK3ltE,13532 vumi/scripts/inject_messages.py,sha256=518XFEYmnOIggHdNWxdbJu4CXA0TF4L_uwaoPPYg7fg,3079 vumi/scripts/model_migrator.py,sha256=4OxOfowQIwqM0E8RGaxXUPGjzuqgaBFAUeCiGIv61uA,3913 vumi/scripts/parse_log_messages.py,sha256=KButRyJxr9nJM1eJ7lU-zyI4NIPWJtoi1kfphQHFDYI,5001 vumi/scripts/vumi_count_models.py,sha256=qWcHajwP-vKlJkMKVSqB0emLeG-g00RcJekLnuuXXQI,6057 vumi/scripts/vumi_list_messages.py,sha256=DT77cXG4grrnC90L0K38bXFI56YiSe0pzw3hpAg31to,3142 vumi/scripts/vumi_model_migrator.py,sha256=YWXMZc6b7jNON8xNqeyONeTBKjbeQCnz6GghJEQeefQ,7899 vumi/scripts/vumi_redis_tools.py,sha256=8vI7lylabRpgVnVfkMoyh5C-KKZwOehHOpagR5pSEn8,6315 vumi/scripts/vumi_tagpools.py,sha256=ECGbeG-c2-8lRshiowMb_gN4ugJ6WEodX0D9DGk-NPk,7491 vumi/scripts/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/scripts/tests/sample-smpp-output.log,sha256=1pNcwtmLoAj3lJheqY-5Bvsg6zCH3e1B3JWKir_9GKk,868 vumi/scripts/tests/test_db_backup.py,sha256=XMtLv62v1GnfW_13uG4BcPFNb__4t-2Tek2jPZSZkpU,14332 vumi/scripts/tests/test_inject_messages.py,sha256=4pqgixwzVR-ysEVX6tkNBg5pbuaq5GOVG6eQFu5Sums,2698 vumi/scripts/tests/test_model_migrator.py,sha256=tu-r5AJoU777voy4sJcFd3lg21lKpB08ILbSFy0n_ck,5767 vumi/scripts/tests/test_parse_log_messages.py,sha256=eRTBzOCSAU7SxpCnlDtppSO8nI8ynaKTB0jcX-hlC2w,8324 vumi/scripts/tests/test_vumi_count_models.py,sha256=il3p0hCp3fuhIpcwP8QyqN6l1NDeKnbwwU6vHMxjqVk,7507 vumi/scripts/tests/test_vumi_list_messages.py,sha256=Va3Ofp2oEEque65rBOy-4iDHs4Ude1uQVrBcdR5ZlTU,6464 
vumi/scripts/tests/test_vumi_model_migrator.py,sha256=LtK-lACfW5o7HfV-KNxZhpLj-Fjz2ul4L8pfcCiXvCk,16546 vumi/scripts/tests/test_vumi_redis_tools.py,sha256=jfyoKs4g9XV_vFa_bEWQ0AxYoJC6GjZrrBq8SkR9vnw,11682 vumi/scripts/tests/test_vumi_tagpools.py,sha256=cfKg2weB2IY-DaoSuQOJKMbak1j_vsU0AyjhC7YxCIc,7935 vumi/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/tests/fake_amqp.py,sha256=_2aYxc6wTmpMvmEvtgU3Kvj5gJDhUZQnA2tPdeuMLlo,18555 vumi/tests/fake_connection.py,sha256=0Mrs5iY-Jy6RG8YvNgyDyUX-e50XBioLoenN_gsIqS8,9175 vumi/tests/helpers.py,sha256=Ey-5LAaJeyacBQ6nM36Yv8OzS96gBNrazmA4OHAp3f8,60262 vumi/tests/test_config.py,sha256=LL9Gpl6Bm65DhH9MXj6242fLRQg_j4FkkC-7N7GVvM8,7299 vumi/tests/test_connectors.py,sha256=k0_s0RDe3uuPJ5sP4RUtP4yhyrWoz8n1NzUDDMBv4BY,16455 vumi/tests/test_fake_amqp.py,sha256=HM_qRvBiwUlDZ_9mNoz99yt1BrDhBbNqvj6wTs3I7hw,15065 vumi/tests/test_fake_connection.py,sha256=SQhjhQON2S5M934un8X4jeejm3rYOB3RwBKR5p8VzaY,16308 vumi/tests/test_log.py,sha256=V7Mi4V4BtLbWJnldOtkhyFoMRRaVk0PjocXB7yLOzuk,2299 vumi/tests/test_message.py,sha256=woF4Xmn5Q-Te7ajIGaTlLDyqHsv2OLgFqmHjdnf202s,21267 vumi/tests/test_multiworker.py,sha256=APqGM5d6PwTYP9nF6VcNjyPQD0EJIKH_NIy9L5UBKuc,4312 vumi/tests/test_reconnecting_client.py,sha256=62bFhwM9p6kUUoHAi0cthL1a_NUaBosf8tMqqUMbMsA,9833 vumi/tests/test_rpc.py,sha256=vb5hi3bWN9CddZCBj2a5S7ntM6JTjdCfHK7sU-ErSYI,8426 vumi/tests/test_sentry.py,sha256=6QYNjBGmg84h8IMdD3HIFe9c0Fe_bIucdY-xZbiItow,6306 vumi/tests/test_service.py,sha256=oGGaS0XX9Jgfji2EHZvhLIicjTfRmPlkzQLQXud14D4,5369 vumi/tests/test_servicemaker.py,sha256=zD1qN3k7Jfxf2wAMUMq-4z8mQwHOXiuEg_52jtF-i-c,10667 vumi/tests/test_test_helpers.py,sha256=-8ZRNqt6VZ8H5cW_CBzFeCsGkxivJ5adL8dWJsE8FA4,85896 vumi/tests/test_testutils.py,sha256=VW-sOKQClm0CKreZ4MU3NkGJtuUcZNc_fil1doN1o3E,4864 vumi/tests/test_utils.py,sha256=VF3tpuHfPJY_RxUI5-WC9XPkylt-x5wvXBwxTo8Oxq8,19050 vumi/tests/test_worker.py,sha256=x0mnVtIhMkw6L98mp7WUrrTs6rexFiUZg9t9uN3ms6s,9346 vumi/tests/utils.py,sha256=4twCDQeuTbEUgNqttiFwkjv8YaQXaNA4IhtHuws4p3U,6096 vumi/transports/__init__.py,sha256=kWmi8h1tgQcSr7sQPJz5QEbvJYufFgXN3-hFlJ44CHs,319 vumi/transports/base.py,sha256=O6siinfvn4uQrxe5VeOn5Ye8pQ02DCbf41YIqXHbKhU,7576 vumi/transports/failures.py,sha256=iWIv9HfW5N_dm3RjfKs5uuu4EyS8WurGsD3cL4DbQ-A,8850 vumi/transports/scheduler.py,sha256=bMrG4c9myMKqFhbro6rLgfblr6nfHUeCiM2eJAZgDoQ,6018 vumi/transports/airtel/__init__.py,sha256=knVxAsN3wOtwpYIpk4RU-4ytsMiSIf1BGapYMoGX79I,97 vumi/transports/airtel/airtel.py,sha256=2AzYfRKPm4Ha1rTI1pERyrYUfwX6-GEpmtOKXkTt0jA,9639 vumi/transports/airtel/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/transports/airtel/tests/test_airtel.py,sha256=Zb0tjOiwG4FcKOWqJd1H8swjr4I1MZ3F2GKpsW8GFV8,17618 vumi/transports/api/__init__.py,sha256=SlaeRcD6DHyqEjQDdJAfmd-N1wxY_nsE3t7hu7xJnLo,345 vumi/transports/api/api.py,sha256=pxw2k878Hny7_CynE0YTIOwkfIbJsSw3C5XuibO85r4,3560 vumi/transports/api/oldapi.py,sha256=_Y8HTOKQUJlHGHVsQZvnHT-FBbypbJpmF2M_o4rDQdE,5723 vumi/transports/api/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/transports/api/tests/test_api.py,sha256=mBKKHocQEQTg2kt3DALEgidFbjmAmc7Ad5orkvqlycc,5952 vumi/transports/api/tests/test_oldapi.py,sha256=coTzcorIGWrk6XHqhnjwgNysJ955ZzjkCvoFrHRpipI,5374 vumi/transports/apposit/__init__.py,sha256=el7drLmOttMfCiKfDR3VWNA3wKIgyIPEZHrRBrM4Xxo,93 vumi/transports/apposit/apposit.py,sha256=WeKxbSA_7muedH9ariGySf0WaDNCpoElWti6p_Q5s6k,6631 
vumi/transports/apposit/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/transports/apposit/tests/test_apposit.py,sha256=oP67v4-1byNUUTlcEISIqJNk_ygBv8H2o26W_DXLVxo,10407 vumi/transports/cellulant/__init__.py,sha256=ebLFunxY2MVt2xgVdCBthv7C4SUndUkFPvOec0JJDBg,241 vumi/transports/cellulant/cellulant.py,sha256=LXmk3JR8CcTlcbU8h3AECNhUNGxirRa6RP-ZY84uwd4,4655 vumi/transports/cellulant/cellulant_sms.py,sha256=lqNWbOalTTn49YjdW6y03_MVJiM3HzfEtrdsXoUBQfo,4054 vumi/transports/cellulant/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/transports/cellulant/tests/test_cellulant.py,sha256=oFEtKSSnc2s6otViqiJeZ9knDynZ0GX-4AQYS8i1FIo,5363 vumi/transports/cellulant/tests/test_cellulant_sms.py,sha256=_6D5S-u8l5Xaed5whFd7DXNJQTTQH0ToA1pOSsf_ZMc,12507 vumi/transports/devnull/__init__.py,sha256=IHne370ovuygC2hcvHR1Vhw035JQAz-qJqBucTuZdtU,93 vumi/transports/devnull/devnull.py,sha256=qmC6NP53TP2EDYsIPVldlBy7tkjgby8FIr7rd38Jr2w,3130 vumi/transports/devnull/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/transports/devnull/tests/test_devnull.py,sha256=by-Jm-BOvI-u3akxGRqc6WtHuplZMfnXNF0_YgMZGak,2239 vumi/transports/dmark/__init__.py,sha256=jTJUvLJgWDz3zrDeGXDZkJ0OPLdgM5HET-HeNQAv9TI,164 vumi/transports/dmark/dmark_ussd.py,sha256=bGBwaUL3fIeIlh0pRMbRG7EgufJwhuulDgOKuQ_jXxU,9892 vumi/transports/dmark/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/transports/dmark/tests/test_dmark_ussd.py,sha256=zkWLCcwChZ0ra7GN0_Lm-Mk3BuYWjm79PBUZfZvXK5c,17523 vumi/transports/httprpc/__init__.py,sha256=b5R7k7Z5XnfvPrjC9FJmvPp7P4YttVKEF7gLuNGk0zA,147 vumi/transports/httprpc/auth.py,sha256=LwNmaaBuPFEOpZk5BQRq5D9c4B5aZPM4grVZtUds-Zk,1173 vumi/transports/httprpc/httprpc.py,sha256=NVwCBl3U2TUYbSPAgbLR7z2YFReOpP8PHRotMDpi-YE,16056 vumi/transports/httprpc/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/transports/httprpc/tests/helpers.py,sha256=9kC6LYQ61MCbFbpNOtm85OZuKT-ycuSe6Q0xOL4-eBY,3638 vumi/transports/httprpc/tests/test_auth.py,sha256=cQb3xb4TxLADuQJ68HSTf3tUY_tYXziW_KigKu047l8,1583 vumi/transports/httprpc/tests/test_httprpc.py,sha256=7kjv5Y5Cz-XAEIbXr8He7BIHLJonxbPh5RvnmEE3vak,9651 vumi/transports/imimobile/__init__.py,sha256=D3gGRS1hWIWIOFYpGUH3FFC62ftS72YG-IIMneAEzYI,149 vumi/transports/imimobile/imimobile_ussd.py,sha256=q5YQQyj3qVX-V1WuWoO2Y0SPzmmOrnu2ywyA6EuQzXA,7713 vumi/transports/imimobile/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/transports/imimobile/tests/test_imimobile_ussd.py,sha256=SyMCWB7NU_Q7Xn9js02hVakkPHW3eN85obmm2_B6Ews,9798 vumi/transports/infobip/__init__.py,sha256=vzFV237MeGSmFjkIpBo6RlwxNvuS24QLHVOWxRMHDsE,149 vumi/transports/infobip/infobip.py,sha256=85mX-Rx8EIn3XAKBn4ORpcpmJNqF8nJmNz06YmZMlFc,10777 vumi/transports/infobip/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/transports/infobip/tests/test_infobip.py,sha256=DMg7RceJVrSagSxGa7RkN0HHNa0k_nWbOzJs5ZbuRco,11097 vumi/transports/integrat/__init__.py,sha256=uSyifLqqJknT6jLOJqENV7AZ2aFtgkjSOnP1Ex6EesM,131 vumi/transports/integrat/integrat.py,sha256=rUrLCG0o5P1-NBIwbAdPWJRUsA7aZX9--1H1WLM7T-8,5301 vumi/transports/integrat/utils.py,sha256=vBwR9tF7RruMMuXyOEX7ng_3l-JlcP3KWW--HElEie4,4367 vumi/transports/integrat/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/transports/integrat/tests/test_integrat.py,sha256=hTK6tpGP9U8g2lKrz9uSXgcZvBPZPtzv4ph0oHqgWNc,9625 
vumi/transports/integrat/tests/test_utils.py,sha256=NtlzUWjHsJrh5a5O2hqnRpPXjAgaoXDmM8kJ0HBUVa0,13923 vumi/transports/irc/__init__.py,sha256=vnalCaweZYPAs9xEDSjWL7Vbbe7O7oiWeZiRtGXjpYA,99 vumi/transports/irc/irc.py,sha256=Dwg-h5wj206h5L5iVoK_PvAygowKhfcFPGhppK5yDNo,10630 vumi/transports/irc/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/transports/irc/tests/test_irc.py,sha256=SpTbtTbG5DcxCtBnhfAGjdERprC02dyb3QyK0WIL5a0,17711 vumi/transports/mediaedgegsm/__init__.py,sha256=00G1eOj_5uIF1F-jjqhKf_qC2POWvkM4UDgLdPPqhGg,114 vumi/transports/mediaedgegsm/mediaedgegsm.py,sha256=AIuMIvFs9yeysGs79rS5puWG4yvgAei_cZ9G7_HVfzQ,4418 vumi/transports/mediaedgegsm/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/transports/mediaedgegsm/tests/test_mediaedgegsm.py,sha256=WDGFJJw54Zworb1janp0drJMdVuxGjL781vzBps2zRE,6900 vumi/transports/mediafonemc/__init__.py,sha256=xZ9n-HTKDXBFAkebUDsHk2MOGUBE_WH-7m5r0-GHp9M,148 vumi/transports/mediafonemc/mediafonemc.py,sha256=zdowfqXwBTnWgoO3sWS_DabeuRsELrxOR9cSPnUq3CE,3142 vumi/transports/mediafonemc/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/transports/mediafonemc/tests/test_mediafonemc.py,sha256=rfbWhsltGJPJ9ZFrHL-Icb055rKzMun93yZ0sfFubHc,4949 vumi/transports/mtech_kenya/__init__.py,sha256=PEjDqRRfflQT5A8BvXIYYusFrVz0cVI5z8eOFK0AOws,128 vumi/transports/mtech_kenya/mtech_kenya.py,sha256=4Z9rS6on0_uC-tQ-sQ4yCmKNqO4N8YC_XUY3RpmUIAs,4183 vumi/transports/mtech_kenya/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/transports/mtech_kenya/tests/test_mtech_kenya.py,sha256=cU-HYQPDtzjbkoSd4Gcn180Rs0s3a3UOnnrihwcC9fw,7585 vumi/transports/mtech_ussd/__init__.py,sha256=VKF0g8SA_IAI9rvsGGmAQefVeLhOyTWxx2fuUZNrklU,132 vumi/transports/mtech_ussd/mtech_ussd.py,sha256=LmlHoHsObwLQDGXQ_szim0VHXn7vEeb25a9nwxTv5rs,6754 vumi/transports/mtech_ussd/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/transports/mtech_ussd/tests/test_mtech_ussd.py,sha256=6yZ-Cl9BNR7Olsec_3v0Dpj6P2jTvT52GCHPWmwNeYE,11978 vumi/transports/mtn_nigeria/__init__.py,sha256=fZoC48ZMngRBoBcieYW6sQDQH1Z3bUDJyadGuE3YmbE,185 vumi/transports/mtn_nigeria/mtn_nigeria_ussd.py,sha256=Vskds-Nj7lFV5NvaGZra4JYXD7rCeSNTXBTgIcT8RL4,9177 vumi/transports/mtn_nigeria/xml_over_tcp.py,sha256=XGHVcMzZHbqh8HmlyRekOBW7mSnTctXp93R9WVC9A24,17961 vumi/transports/mtn_nigeria/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/transports/mtn_nigeria/tests/test_mtn_nigeria.py,sha256=f-bN815DLjQkcs8meAccHIQhwPA680064jqn1OY1VOE,11307 vumi/transports/mtn_nigeria/tests/test_xml_over_tcp.py,sha256=ltYWxaGjmapYSY2KMxBG8u0zRNdMWW6W0q7yFXFy_W4,29241 vumi/transports/mtn_nigeria/tests/utils.py,sha256=bCaR4LzTTpuRKxzTzyaGkJTQCcHSueumAvD9F4JLVbY,3213 vumi/transports/mtn_rwanda/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/transports/mtn_rwanda/mtn_rwanda_ussd.py,sha256=WQPEUclGX4nsfc9x1YMRspS14mV4bJ8DqnouC9XGI2o,8197 vumi/transports/mtn_rwanda/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/transports/mtn_rwanda/tests/test_mtn_rwanda_ussd.py,sha256=L5gckodDHMnZVjTwecJulACLo1PJywoFm91ORS6sGBQ,6453 vumi/transports/mxit/__init__.py,sha256=PBDjbL_GLYxCC148HizQvCYglmd2VF-0_bxZKihWA3M,81 vumi/transports/mxit/mxit.py,sha256=klPWDTDOtnM5UlHm7OKYYwXxWmhW5ibzanW5cM250k4,8503 vumi/transports/mxit/responses.py,sha256=t-5RvCc7KQ-7PBpGOXx_IicOclpt78Y-NJKC6NqR9Pw,1514 
vumi/transports/mxit/templates/response.xml,sha256=l7LiO1EtqU3kdpgYXmGUqaZjJ5trKyBmDKFfplqXLZ4,578 vumi/transports/mxit/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/transports/mxit/tests/test_mxit.py,sha256=W0SzQWY4fD4ACon4tDvpBX7l40oifUkZrHuCHnwCpQE,9733 vumi/transports/netcore/__init__.py,sha256=7LXOACw9EJXbDjSzAcbS1dTkHcd3EOu8IUQzeDMgC7c,93 vumi/transports/netcore/netcore.py,sha256=607ypJWyIkV8csAYYmcWgWQqNP-QzkkC_ShDOsR5cJg,4587 vumi/transports/netcore/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/transports/netcore/tests/test_netcore.py,sha256=6Ep9u92WZrZv1BMcK4oVpzq_Mk98CuHVpfowwhD42Uk,4145 vumi/transports/opera/__init__.py,sha256=UwmMDeOvo1u38sdZUqyW5XWY0h6M9PHiQfR3tOh83rw,85 vumi/transports/opera/opera.py,sha256=kDPlCA7nWrbtNjWpB6Jv7AXpd2hL28JdHWpKAZJWmXQ,10974 vumi/transports/opera/utils.py,sha256=lCeHc2uf88Y4XruW8fg0enCRE8-BA5lKVOoJcY6Ai84,1185 vumi/transports/opera/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/transports/opera/tests/test_opera.py,sha256=6dkUH0ToR3mDJSS-bQcqk_5NmHO1kyyNfA6ORPClgls,19397 vumi/transports/parlayx/__init__.py,sha256=kpCcktcp6fAwqKlzUJM4Bv8QNXR3k1B55vv_gycNB9Q,120 vumi/transports/parlayx/client.py,sha256=raZ2BLb46ruJsD9QI4PPP1GmBb8Iv8XDvP8uSxy5DjI,7613 vumi/transports/parlayx/parlayx.py,sha256=3GN46GK4w5PLaYpmFoE8X2Kz5Vq4wv27oNWhDGjKbmY,6831 vumi/transports/parlayx/server.py,sha256=hL-YkWdGJ7QD3dTk4J1LTIze-HXkso0F6Y18X9wTgmU,7130 vumi/transports/parlayx/soaputil.py,sha256=VRo5ScpV3CIIBdCJR18XU47O4L8vfzW5OkPc70d5DHE,7992 vumi/transports/parlayx/xmlutil.py,sha256=WRKwI_mqMeJ9hRBXbyQyWYm_ZEp--5IpJl-j3FWCn00,13918 vumi/transports/parlayx/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/transports/parlayx/tests/test_client.py,sha256=Oz4i7F3nH_N-OazBLdUUiCsEwEHBe8J2M6ysjeIeEWs,14015 vumi/transports/parlayx/tests/test_parlayx.py,sha256=EomqaeKGEQKv33As4shTtFWadbxHchs2wTCtKpzOilw,9147 vumi/transports/parlayx/tests/test_server.py,sha256=fJSkvwz5a2silgV6zJVaMPsyO49WFArkG_S3Oqbbf1M,10995 vumi/transports/parlayx/tests/test_soaputil.py,sha256=dptvMrli702L4874yLAyWFhhNTSxkqqa-KSGb1WJYlc,10729 vumi/transports/parlayx/tests/test_xmlutil.py,sha256=KH-6Rn77abEdaNnDZlV5S0YMeWnGXdflqgGPSfsbOXo,17796 vumi/transports/parlayx/tests/utils.py,sha256=1clli7CXbkOZCepAS0KaiLTEakzkJg7JCadbsTqYUjE,4167 vumi/transports/safaricom/__init__.py,sha256=ZrrPJNlfqQ0hrBzs-R-G3rI0PDWHx-K08kvmfr5gTjM,136 vumi/transports/safaricom/safaricom.py,sha256=6lEH4bn-XMod8chlHuiXuhNMRvOrcqtqpvrf2EJm11o,4818 vumi/transports/safaricom/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/transports/safaricom/tests/test_safaricom.py,sha256=1yPGEThQDGBG2a-cAmtlTxrRqrlLnYjrw8iEHlnSr58,8919 vumi/transports/smpp/__init__.py,sha256=MfAuJZF3lKTw8xnmINyOyiYse_Wb1OqPDqhAUlO4PSE,347 vumi/transports/smpp/config.py,sha256=ArNQpuwuTlfh8hvQ5BmyJLg8n2Bt0D-LZcqlswcJpKs,7615 vumi/transports/smpp/iprocessors.py,sha256=cLU5IT7gs4n58co5WxkYOOxEKuY-AZ3N751PGNgENuQ,2809 vumi/transports/smpp/pdu_utils.py,sha256=48jD-_2gTWmfNJdyyGh0ptBjqj_ym8jkVkVgn18xZSY,995 vumi/transports/smpp/protocol.py,sha256=X-P9wmDG4JCdXemRh1uy_BS3WoEjJ_9nPvNA2URUoX0,22381 vumi/transports/smpp/sequence.py,sha256=KimGSkN0QkfPihS-SqzlTcV309nmmDJoc5br3jx9K54,3056 vumi/transports/smpp/smpp_service.py,sha256=rFWSM1BhSMoRag7F4ewmX2-x3AG101Sy6zI_3SvucSI,15840 vumi/transports/smpp/smpp_transport.py,sha256=Lu3D2LzV1oRb92gOcEYDK5ldwK5o6hQz-jQPKNaz2t4,20569 
vumi/transports/smpp/smpp_utils.py,sha256=RS3m6JT3YD4NiwiJP5jhWzQNhjdmShX5tpz1Lu4zfyk,654 vumi/transports/smpp/deprecated/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/transports/smpp/deprecated/service.py,sha256=lInLibl7Tt2dYMqrc8Gx57TJP2n2aH6JKW_1Qvt30G0,1023 vumi/transports/smpp/deprecated/transport.py,sha256=yatQYEXrz2qk3uHj0or4nfu1EnoW6U1c5IyZODtEeFM,22397 vumi/transports/smpp/deprecated/utils.py,sha256=EPgGPg2IVFb8B7qAlfR_Ia5qvzsVa01mm_QpdJLEWiI,1266 vumi/transports/smpp/deprecated/clientserver/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/transports/smpp/deprecated/clientserver/client.py,sha256=-6l2eTAp5YvQUyGYDjPw51vhfgUOtU1tY2ijdd_LMI0,28727 vumi/transports/smpp/deprecated/clientserver/server.py,sha256=-3VBGiFzjeXVUqcrfUIOdtdGqBhYQgVaquUXdYCF7Ws,5253 vumi/transports/smpp/deprecated/clientserver/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/transports/smpp/deprecated/clientserver/tests/test_client.py,sha256=K8qha8St1v3CysUU2yQb1Ons-A4HUKoKMMtPAzPoZcQ,21523 vumi/transports/smpp/deprecated/clientserver/tests/test_server.py,sha256=8F0tEM9maF7kZWyNzC0sVebkEu5uqBPSdbuKXwwpbxk,69 vumi/transports/smpp/deprecated/clientserver/tests/utils.py,sha256=kWivE9g2dLyRTJPOae_64gB6NHmD7ko7w2f-Kt0KDKI,935 vumi/transports/smpp/deprecated/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/transports/smpp/deprecated/tests/test_smpp.py,sha256=xqzJ3Ka1dHiZ2F7MefNzqDXhyUyjbNSqMl7wSJZ-304,41655 vumi/transports/smpp/processors/__init__.py,sha256=YIeT6uJ8cLp0yGoHcPWpzn-pk3FA13y-jlZfV0pclqc,488 vumi/transports/smpp/processors/default.py,sha256=9Nbmp_3g3ym6YFuuTmnRIU5-Ht5Q8rBK6dExaHjb1Fg,20999 vumi/transports/smpp/processors/mica.py,sha256=1lwJWAdFtktT0CVx8RGmse3gGAO6CyHvEi_7-LT2vVE,6247 vumi/transports/smpp/processors/sixdee.py,sha256=v6HCq6UGl9z02uLr_v_qgbnaQNKF03tdMBRcAGkXtpo,6995 vumi/transports/smpp/processors/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/transports/smpp/processors/tests/test_mica.py,sha256=zwaaVbk8RyJO-T9j7-tkcgZldo5K4tclA6QzvQVUWlI,13370 vumi/transports/smpp/processors/tests/test_sixdee.py,sha256=8H0mxCyerMQZ4RUUXmsOSEJ5KLAY8gBx6GLIPalC1PU,14862 vumi/transports/smpp/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/transports/smpp/tests/fake_smsc.py,sha256=SB-dW_1hVMZfQBrqEo9yGrmatXP5IJat-eWOZKUqWgQ,10508 vumi/transports/smpp/tests/test_fake_smsc.py,sha256=OKFNzwuD2UQvDOg6f1vpNKf_1soccNj5MX1rwHOQz7I,24686 vumi/transports/smpp/tests/test_protocol.py,sha256=Ab_B3UkLOToAuccHHJnClfuNXmjts60zgVz4qh-BxSw,14349 vumi/transports/smpp/tests/test_sequence.py,sha256=4PWzEKRATJYoI0-OAHleG68wp4-_66fSDV1LRaIdrfY,1412 vumi/transports/smpp/tests/test_smpp_service.py,sha256=u2XqYYcU4sBHthUrGWA7ppoo8wKWNQfJeckpQyKI-WE,18494 vumi/transports/smpp/tests/test_smpp_transport.py,sha256=jKY8b1KPcX-S_LwPjOFCIp9tFhpfxR6EVRURQFuWOVM,90241 vumi/transports/smssync/__init__.py,sha256=_aX8K6R_EKBxw7DN_sge56BpljFOOXZTRlcsshKY9QQ,193 vumi/transports/smssync/smssync.py,sha256=4aZOrVsH1F5RBRVipN4iaik3-v6Z9U2dLcmV896pO7E,12711 vumi/transports/smssync/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/transports/smssync/tests/test_smssync.py,sha256=5-vTHgrD-Ek__HJLLSCI6bItpAmaNtU5KZ1kSIpTXcg,8931 vumi/transports/telnet/__init__.py,sha256=z_CY6A08wg3wrzD1b3lTIERw7FOy8_aaKaIdZI9mepw,245 vumi/transports/telnet/telnet.py,sha256=dR05FW5u5cC7_qJrUQVIWL5jTw5df65fn1ZCAeBD54k,7536 
vumi/transports/telnet/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/transports/telnet/tests/test_telnet.py,sha256=GprXjfZH8oBe8DPMDobDxPpJai1uJZCgFgaw5wofZkM,8495 vumi/transports/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/transports/tests/helpers.py,sha256=WUd-80lSPwEVXyvKTZxkXT23MOsuC9qqKsTZ0n54O3A,3229 vumi/transports/tests/test_base.py,sha256=qfaQC1rAxxSdM1Iu6uC2zU-pEZx4vLfYaBrNamJbByc,6258 vumi/transports/tests/test_failures.py,sha256=BBSPmW3sU6IQpN-68JXQG6XNzCPb11gluVCkNI7vxHM,13796 vumi/transports/tests/test_scheduler.py,sha256=Iy_p3ph6HEVyizPiDvbYzrAFWRRDjeWoa1L6l3EGlX4,5095 vumi/transports/tests/test_test_helpers.py,sha256=53Pc5hG753htmX35TqmNSwjh0wJplJn4V-1l6aHrX2Q,6023 vumi/transports/tests/utils.py,sha256=L7hc6ll3U6FtZpprphd3imbul53amkobH_qfItsR2CI,2208 vumi/transports/trueafrican/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/transports/trueafrican/transport.py,sha256=Pg0yQlHYPD-AuJssCIXli7yKJVGI6rGSCrOuc6ixNDY,11619 vumi/transports/trueafrican/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/transports/trueafrican/tests/test_transport.py,sha256=Uu5j1n52eKHz4xEPswfL5A0cnM-XfD1BCkrGjR4cXJo,11065 vumi/transports/truteq/__init__.py,sha256=6Ja9SvoDjDAYvXM05uf6_EbQVYAEzN08kDEEUF1Gx-0,114 vumi/transports/truteq/truteq.py,sha256=UZ-BMv-hiXFNpxiO-Q4spojbjtvo9Bfb0KVoPcjjSo0,7736 vumi/transports/truteq/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/transports/truteq/tests/test_truteq.py,sha256=aBk4qD6qUYReMZu1ukRKfsS_cWksxj1NDb_EN0rTuMg,11018 vumi/transports/twitter/__init__.py,sha256=2xYkibNke12QBfx4eW9g3lDXI6mawbBywPLv8OCVx3M,150 vumi/transports/twitter/twitter.py,sha256=gghreAm3XerHQg8ap8EIE1fMclkkfFGfw_2k2zc8Yk0,10640 vumi/transports/twitter/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/transports/twitter/tests/test_twitter.py,sha256=zIK-GIrcq7kZh69f0FPvAM0Da4jyr2-CmSRWiQAQUHI,16928 vumi/transports/vas2nets/__init__.py,sha256=BOkNbvzRk3yrZlQuKgi0M9eLq4bMjs7BnaF8V87aHCY,130 vumi/transports/vas2nets/transport_stubs.py,sha256=T8qWIxEM8hrvzpNkVWfT7U9raQLyPo4WRrtgiDt-uDA,4527 vumi/transports/vas2nets/vas2nets.py,sha256=l1e9K6dpNhN6W9Z7PPCHcpJBEksQOAMRch43hP1meX8,11042 vumi/transports/vas2nets/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/transports/vas2nets/tests/test_failures.py,sha256=UTYosxiXe8VekCjhihvnxBFEbRfXZT9nThC8aRVZhgQ,6550 vumi/transports/vas2nets/tests/test_vas2nets.py,sha256=JcCIVuSY8ZnmHFTkQd0rAS7c5-AeGopogpZrR7zHa1s,12842 vumi/transports/vas2nets/tests/test_vas2nets_stubs.py,sha256=6KbLCo1qSUBhsRwHBLnkKZliDZV2Dpec9Jho7Xfkqrc,5119 vumi/transports/vodacom_messaging/__init__.py,sha256=ubA48aem6XLnviHYP5YKxw-Ucoqq8r6G3mEQG5cZBpk,246 vumi/transports/vodacom_messaging/vodacom_messaging.py,sha256=srYYyH67FSOKMlzV2vEssBOGMzOLlJXJ1wEyXn3ItAM,4289 vumi/transports/vodacom_messaging/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vumi/transports/vodacom_messaging/tests/test_vodacom_messaging.py,sha256=-iwgrAe21CJJkSd-jmguciXnBvpmQCsimO82ZSfWXBs,12163 vumi/transports/vumi_bridge/__init__.py,sha256=DV-6kfgd22GPbdn5qDWY4DC2exOg_kFifGR6FnSj94A,122 vumi/transports/vumi_bridge/client.py,sha256=UaXQEcV7Mne0lJyxDPBUb_aPjLagFf2Chz8utBLfJDU,3490 vumi/transports/vumi_bridge/vumi_bridge.py,sha256=XP63QIlHQGAz-Gn8yw7S9d9-L6NTaBGY4gktmLIpbLk,11194 vumi/transports/vumi_bridge/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 
vumi/transports/vumi_bridge/tests/test_client.py,sha256=Jn1LJf7lyXTrGBYo9cZFFETUJYoO3mbLfnsLHHT2bJc,2462
vumi/transports/vumi_bridge/tests/test_vumi_bridge.py,sha256=cOFmdzkdwPwW8W01r_s3E4UpRVCMMxH01tyg8skIs8I,10320
vumi/transports/wechat/__init__.py,sha256=1WLeTkVDQHOS_MWXCo3YsBoJQCRtDk_yBo1HxaeLaEg,89
vumi/transports/wechat/errors.py,sha256=3TU9cjQwT3lHYKKlluXN80WxwHFesANrJmluv-PYSVM,154
vumi/transports/wechat/message_types.py,sha256=E6HRQ0tjIatiVGPltLUTWaTe2OTsbyLqOy2MgsZTL7U,6113
vumi/transports/wechat/test_message_types.py,sha256=dI60XfD5pb9rQaPahTFCf9IIsSAXaX1reolP59KuWq0,5068
vumi/transports/wechat/wechat.py,sha256=XGYbJKgqrAgE3jvaygfF2J21Zf00in5zlGB7nGu7YwI,21236
vumi/transports/wechat/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
vumi/transports/wechat/tests/test_wechat.py,sha256=8URfwzCQAn4uc12yCfzVwYOF5dD2MhdNb6uw9ZbXU1s,28906
vumi/transports/xmpp/__init__.py,sha256=G8L4G3-zvbJF7Mdz_bM4HJVNBLFV_VCj6WluLck51s8,112
vumi/transports/xmpp/xmpp.py,sha256=f6Q14bEPAmriiRPia7gsB6zVVzDVliY7bjsOc-4zJ5s,7980
vumi/transports/xmpp/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
vumi/transports/xmpp/tests/test_xmpp.py,sha256=cuQBux1ki7jXSTRQ-Nms501a_iancpEWaNemTTLrn6M,8072
vumi-0.6.9.data/scripts/vumi_count_models.py,sha256=xPNBXkkjWRpH1DxYwm89JN24NVJa-IxrxpHky7_i19Y,6044
vumi-0.6.9.data/scripts/vumi_list_messages.py,sha256=EA2suU_Q_d41tky6MRR7kRu4at7wUNoyM2cptSL21dE,3129
vumi-0.6.9.data/scripts/vumi_model_migrator.py,sha256=-0C1cTPmNG7clMF8JLrjdKTMOfvwGhln9LOmhjJulDA,7886
vumi-0.6.9.data/scripts/vumi_redis_tools.py,sha256=0gpes79os4WcySeWbKx3Qu6PJE_ciaAMX_oHJwltUwY,6302
vumi-0.6.9.data/scripts/vumi_tagpools.py,sha256=SIexkiozVvpDXYSdzsQd1MMjHc9F1MjNYDMImfIaaBY,7478
vumi-0.6.9.dist-info/DESCRIPTION.rst,sha256=DtY9pi-nU4Vb5Bll3nbbWAH61VLX02ta0AY5-zOIxQg,2474
vumi-0.6.9.dist-info/METADATA,sha256=fX_KMT1foLCvF_M4faxIqStPEtj6KUGs_5k5SABNaRU,3794
vumi-0.6.9.dist-info/RECORD,,
vumi-0.6.9.dist-info/WHEEL,sha256=JTb7YztR8fkPg6aSjc571Q4eiVHCwmUDlX8PhuuqIIE,92
vumi-0.6.9.dist-info/metadata.json,sha256=T7pUO2cECdE4mVo8-LTq6HLGLEzFSd8HNgpelPOrbOk,1247
vumi-0.6.9.dist-info/top_level.txt,sha256=fuJlO7sihz5a651o3osLOeoNoKh_GBO6VurzifxdnBI,13
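# The RECORD manifest above follows the wheel format (PEP 376/427):
# each row is `path,sha256=<digest>,size`, where <digest> is the
# URL-safe base64 encoding of the file's SHA-256 digest with trailing
# '=' padding stripped. The sketch below (not part of the wheel, added
# for illustration) shows how those entries can be verified; it assumes
# the wheel has been unpacked into the current directory.

import base64
import csv
import hashlib
import os


def record_hash(path):
    """Return the RECORD-style hash string for a file on disk."""
    with open(path, 'rb') as f:
        digest = hashlib.sha256(f.read()).digest()
    encoded = base64.urlsafe_b64encode(digest).rstrip(b'=')
    return 'sha256=' + encoded.decode('ascii')


with open('vumi-0.6.9.dist-info/RECORD') as f:
    for row in csv.reader(f):
        if len(row) != 3 or not row[1]:
            continue  # skip blank rows and RECORD's own unhashed entry
        path, expected = row[0], row[1]
        if os.path.exists(path) and record_hash(path) != expected:
            print('hash mismatch: %s' % (path,))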