#!/usr/bin/env python

import itertools
import logging
import logging.config
import sys, os

import docopt
import multiprocessing
import psycopg2

# Parse the command line with docopt: the usage string below IS the parser
# specification (option names, defaults, positionals), so its wording must
# not be edited casually.
options = docopt.docopt("""
usage: codex-process-events [options] [<processor> ...]

options:
  -s <string>, --schema <string>    Database schema [default: public]
  -d <string>, --database <string>  Database name or URI [default: postgres]
  -v, --verbose                     Display additional debugging information
  -h, --help                        Display this help then exit
  <processor>                       Python script with processing function
""")

# Root-logger configuration: one stream handler with a timestamped format.
# Verbosity is driven by the --verbose command-line flag.
_log_level = logging.DEBUG if options["--verbose"] else logging.INFO

_logging_config = {
    "version": 1,
    "disable_existing_loggers": True,
    "formatters": {
        "default": {"format": "[%(asctime)s] %(levelname)s: %(message)s"},
    },
    "handlers": {
        "default": {"class": "logging.StreamHandler", "formatter": "default"},
    },
    "loggers": {
        "": {
            "handlers": ["default"],
            "level": _log_level,
            "propagate": True,
        },
    },
}
logging.config.dictConfig(_logging_config)

# Module-wide logger, named after this script file.
logger = logging.getLogger(os.path.basename(__file__))

def error (msg, is_exception = False):
    """Log *msg* as an error, then terminate the program with exit code 1.

    When *is_exception* is true and --verbose was requested, the full
    traceback of the active exception is logged instead of the bare message.
    """
    show_traceback = is_exception and options["--verbose"]
    if show_traceback:
        logger.exception(msg)
    else:
        logger.error(msg)
    sys.exit(1)

# Fail fast if any processor script named on the command line is missing;
# error() exits on the first one not found.
for processor_path in options["<processor>"]:
    if not os.path.exists(processor_path):
        error("file '%s' not found" % processor_path)

#:::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::

logger.info("connecting to database server")

# Connection parameters from the command line; --database may be either a
# full postgresql:// URI or a bare database name (see create_connection).
postgres_uri = options["--database"]
postgres_schema = options["--schema"]

def create_connection():
    """Open a psycopg2 connection and point search_path at the configured schema.

    ``postgres_uri`` is passed through verbatim when it looks like a full
    ``postgresql://`` URI, and is otherwise treated as a plain database
    name. Any database failure aborts the program via error().
    """
    try:
        looks_like_uri = postgres_uri.startswith("postgresql://")
        connection = (
            psycopg2.connect(postgres_uri) if looks_like_uri
            else psycopg2.connect(database = postgres_uri))

        # psycopg2 renders %s as a quoted literal, which is an acceptable
        # search_path value
        with connection.cursor() as cursor:
            cursor.execute("SET search_path TO %s", (postgres_schema,))

        return connection

    except psycopg2.Error as e:
        error("error from database server:\n%s" % e, True)

# NOTE(review): logged at import time, right after defining
# create_connection() — no connection has actually been opened yet.
logger.info("connecting to database server: done")

#:::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::

# Shared queue of person identifiers, intended for worker processes.
person_ids = multiprocessing.Queue()

def retrieve_person_events (connection, person_id):
    """Yield every codex_events row belonging to *person_id*.

    Bug fix: under the DB-API (psycopg2 included), ``cursor.execute``
    returns ``None``, so the original ``for entry in cursor.execute(...)``
    raised a TypeError. Execute first, then iterate the cursor itself.
    """
    query = "SELECT * FROM codex_events WHERE (person_id = %s)"
    with connection.cursor() as cursor:
        cursor.execute(query, (person_id,))
        for entry in cursor:
            yield entry

def process_person_events (connection, person_id):
    """Run every registered event processor over *person_id*'s event timeline.

    The timeline generator is duplicated with itertools.tee so that each
    processor receives an independent iterator over the same events.
    """
    timeline = retrieve_person_events(connection, person_id)

    # Bug fix: itertools.tee() does not accept its copy count as a keyword
    # argument ("n=..."); it must be passed positionally.
    copies = itertools.tee(timeline, len(event_processors))

    # NOTE(review): event_processors is expected to be a module-level list
    # of callables taking a single iterable of events — confirm against the
    # (currently disabled) registration code further down this file.
    for event_processor, processor_timeline in zip(event_processors, copies):
        event_processor(processor_timeline)


def _enqueue_person_ids ():
    # Push every distinct person identifier from codex_events onto the
    # shared person_ids queue for later consumption by worker processes.
    # (execute() returns None, so we iterate the cursor itself.)
    query = "SELECT DISTINCT person_id FROM codex_events"
    with create_connection() as connection:
        with connection.cursor() as cursor:
            cursor.execute(query)
            for (person_id,) in cursor:
                person_ids.put(person_id, block = False)

if (__name__ == "__main__"):
    # NOTE(review): this section of the original file was truncated and did
    # not parse — an indented query block floated at module level, and
    # sys.exit(0) ran before any work was done. The original also created
    # multiprocessing.Pool(None, process_person_events), which would have
    # invoked the per-person worker as a no-argument pool *initializer*;
    # the worker fan-out still needs to be designed. For now we only fill
    # the queue, then exit successfully.
    _enqueue_person_ids()
    sys.exit(0)


# NOTE(review): the triple-quoted string below is a large chunk of dead
# legacy code (processor registration, cohort event retrieval) that was
# disabled by turning it into a bare string literal. It references names
# that do not exist in this file (magenta, _event_processors,
# _event_serializers, _forced_print, execfile, an attribute-style
# `options` object) and contains Python-2-only syntax (the octal literal
# `01`), so it would need substantial porting before it could be revived.
"""


        (SELECT DISTINCT person_id FROM %(schema)s.codex_events

event_processors = []

def events_processor (variable_names, accepts_only = None):
    if (accepts_only is None):
        filtered_event_classes = []
        events_filter = lambda event: True
    else:
        if (isinstance(accepts_only, magenta.EVENT_CLASS)):
            accepts_only = [accepts_only]
        elif (not isinstance(accepts_only, collections.Iterable)) or \
             (not all(isinstance(ec, magenta.EVENT_CLASS) for ec in accepts_only)):
            raise ValueError("invalid value for 'accepts_only': %s" % accepts_only)

        filtered_event_classes = accepts_only
        events_filter = lambda event: all(
            lambda event: (event.event_class == ec) for ec in accepts_only)

    def wrapped (fn):
        def wrapper (patient_id, patient_events):
            try:
                results = fn(patient_id, filter(events_filter, patient_events))
            except Exception as e:
                _forced_print('\n')
                error("error in processor '%s':" % fn.__name__, True)

            if (results is None):
                results = [None] * len(variable_names)
            else:
                assert (len(results) == len(variable_names))

            return zip(variable_names, results)

        msg = "new event processor: '%s'" % fn.__name__
        if (len(filtered_event_classes) > 0):
            msg += ", accepting only %s" % ' '.join(sorted(filtered_event_classes))

        logger.debug(msg)
        _event_processors.append(wrapper)
        return wrapper

    return wrapped

if (options.cohort_processors_fns is not None):
    for cohort_processors_fn in options.cohort_processors_fns:
        execfile(cohort_processors_fn)
else:
    @events_processor(
        variable_names = (
            "person_id",                # person identifier
            "age",                      # age (either current or at death)
            "number_of_visit_occurrences",
            "number_of_drug_exposures",
            "number_of_procedure_occurrences",
            "number_of_condition_occurrences",
            "number_of_observations",
            "first_event_date",         # date of first event (birth excluded)
            "last_event_date",          # date of last event (death included)
            "events_span_in_years")     # span of events (birth excluded)
    )
    def default_processor (person_id, person_events):
        birth_date, death_date = None, datetime.date.today()
        number_of_events = {}
        first_date, last_date = None, None

        for event in person_events:
            event_class = event.event_class
            event_date = event.event_start_date

            if (event_class == magenta.EVENT_CLASS.BIRTH_RECORD):
                birth_date = event_date
                continue

            if (not event_class in number_of_events):
                number_of_events[event_class] = 0

            number_of_events[event_class] += 1

            if (first_date is None):
                first_date = event_date

            last_date = event_date

            if (event_class == magenta.EVENT_CLASS.DEATH_RECORD):
                death_date = event_date

        return (
            person_id,
            magenta.utils.time_delta_in_years(birth_date, death_date),
            number_of_events.get(magenta.EVENT_CLASS.VISIT_OCCURRENCE, 0),
            number_of_events.get(magenta.EVENT_CLASS.DRUG_EXPOSURE, 0),
            number_of_events.get(magenta.EVENT_CLASS.PROCEDURE_OCCURRENCE, 0),
            number_of_events.get(magenta.EVENT_CLASS.CONDITION_OCCURRENCE, 0),
            number_of_events.get(magenta.EVENT_CLASS.OBSERVATION, 0),
            first_date,
            last_date,
            magenta.utils.time_delta_in_years(first_date, last_date),
        )

logger.info("%d processor%s defined" % (
    len(_event_processors),
    {True: 's', False: ''}[(len(_event_processors) > 1)]))

#:::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::

logger.info(
    "retrieving events for members of cohort '%s'" % options.cohort_name)

try:
    events = magenta.fetch_events_for_cohort(
        options.db_schema,
        options.cohort_name,
        batch_size = options.batch_size,
        max_events_per_member = options.max_events_per_member,
        call_backs = (
            (01, lambda: _forced_print('.')),
            (10, lambda: _forced_print(' ')),
            (50, lambda: _forced_print('\n'))
        ))

    n_persons, n_events, n_empty = 0, 0, 0
    for (person_id, person_events) in events:
        n_persons += 1
        n_events += len(person_events)

        # aggregation of key/value pairs returned by each event
        # processor when applied on the same list of person events
        person_kv = []
        for ep in _event_processors:
            person_kv.extend(ep(person_id, person_events))

        # if all the values are None, we skip this person entirely
        is_empty = True
        for k, v in person_kv:
            if (v is not None):
                is_empty = False

        if (is_empty):
            n_empty += 1
            continue

        # send all key/value pairs to the user-selected serializers
        for es in _event_serializers:
            es.add_event(person_kv)

        del person_events
        del person_kv

    _forced_print('\n')

except KeyboardInterrupt:
    error("operation cancelled by user")

except magenta.RedshiftException as e:
    error(e)

logger.info(
    "retrieving events for members of cohort '%s': done (%s members%s; %s events)" % (
    options.cohort_name,
    magenta.utils.str_ts(n_persons),
    ", plus %s excluded by processors" % magenta.utils.str_ts(n_empty) if (n_empty > 0) else '',
    magenta.utils.str_ts(n_events)))
"""
# Final status message once module-level processing has completed.
logger.info("all done")
