#!/usr/bin/env python

from __future__ import print_function
from __future__ import unicode_literals

import io
import logging
import logging.config
import os
import re
import sys

import docopt
import pkg_resources
import psycopg2
import psycopg2.extras
import six
import sqlparse

# Command-line interface: docopt builds the parser from the usage text below
# and returns a dictionary keyed by option name (e.g. options["--schema"]).
options = docopt.docopt("""
usage: codex-stage-events [options]

options:
  -s <string>, --schema <string>    Database schema [default: public]
  -d <string>, --database <string>  Database name or URI [default: postgres]
  --cohort-from-script <filename>   Cohort to consider, as a SQL script
  --cohort-from-query <string>      Cohort to consider, as a SQL query
  --cohort-from-list <string>       Cohort to consider, as a list of identifiers
  --cohort-from-table <string>      Cohort to consider, as a table
  --events-table <string>           Events table name [default: codex_events]
  -v, --verbose                     Display additional debugging information
  -h, --help                        Display this help then exit

notes:
  - if --cohort-from-script or --cohort-from-query is used, the corresponding
    SELECT statement must return a 'person_id' field with the cohort member
    identifiers; all other fields in this statement will be ignored
  - if --cohort-from-list is used, the corresponding string must be a
    comma-separated list of integers
  - if --cohort-from-table is used, the corresponding table must exist in
    the database schema and have a 'person_id' field with the cohort member
    identifiers; all other fields in this table will be ignored
""")

# Route all log records through a single StreamHandler attached to the root
# logger; the threshold is DEBUG with --verbose and INFO otherwise.
logging.config.dictConfig({
    "version": 1,
    "disable_existing_loggers": True,
    "formatters": {
        "default": {"format": "[%(asctime)s] %(levelname)s: %(message)s"},
    },
    "handlers": {
        "default": {"class": "logging.StreamHandler", "formatter": "default"},
    },
    "loggers": {
        "": {
            "handlers": ["default"],
            "level": (logging.DEBUG if options["--verbose"] else logging.INFO),
            "propagate": True,
        },
    },
})

# the logger is named after this script's file name
logger = logging.getLogger(os.path.split(__file__)[1])

def error (msg, is_exception = False):
    """Log *msg* as an error, then terminate the script with exit status 1.

    When *is_exception* is true and --verbose was given, logger.exception is
    used instead of logger.error so that the traceback of the exception being
    handled is logged together with *msg*.
    """
    with_traceback = is_exception and options["--verbose"]
    report = logger.exception if with_traceback else logger.error
    report(msg)
    sys.exit(1)

# Make sure at most one cohort source was given, then normalize it:
# - --cohort-from-script is read (from stdin when '-') and turned into a
#   --cohort-from-query
# - a --cohort-from-query of '-' is read from stdin
# - --cohort-from-list is parsed into a sorted list of unique integers
cohort_n_options = sum(
    options[option] is not None for option in (
        "--cohort-from-script",
        "--cohort-from-query",
        "--cohort-from-list",
        "--cohort-from-table"))

if (cohort_n_options > 1):
    error("only one of --cohort-from-* options is allowed")

if (options["--cohort-from-script"] is not None):
    arg = options["--cohort-from-script"]
    if (arg == '-'):
        query = sys.stdin.read()
    else:
        if (not os.path.isfile(arg)):
            error("file not found: %s" % arg)

        # the 'U' mode flag was removed in Python 3.11; io.open applies
        # universal newlines by default on both Python 2 and 3, decodes the
        # script as UTF-8 (like the packaged SQL scripts), and the 'with'
        # block closes the file handle
        with io.open(arg, "r", encoding = "utf-8") as script_file:
            query = script_file.read()

    options["--cohort-from-script"] = None
    options["--cohort-from-query"] = query

if (options["--cohort-from-query"] == '-'):
    options["--cohort-from-query"] = sys.stdin.read()

if (options["--cohort-from-list"] is not None):
    arg = options["--cohort-from-list"]
    try:
        identifiers = sorted(set(int(value) for value in arg.split(',')))
    except ValueError:
        # only malformed integers are reported here; a bare except would
        # also have swallowed SystemExit and KeyboardInterrupt
        error("invalid value for --cohort-from-list: %s" % arg)

    options["--cohort-from-list"] = identifiers

def force_print (text):
    """Print *text* (followed by a newline) to standard output and flush.

    Flushing right away makes the text appear immediately even when stdout
    is block-buffered, e.g. when it is redirected to a file or a pipe.
    """
    out = sys.stdout
    print(text, file = out)
    out.flush()

force_print("staging events to '%s.%s'\n" % (
    options["--schema"], options["--events-table"]))

#:::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::

# Connect to the PostgreSQL server, select the target schema, and check that
# the server version is at least wanted_version.

logger.info("connecting to database server")

logger.debug("psycopg2.__version__ = %s", psycopg2.__version__)

postgres_uri = options["--database"]
postgres_schema = options["--schema"]

try:
    # libpq accepts both 'postgresql://' and 'postgres://' as URI scheme
    # designators: pass such values through as connection URIs, and treat
    # anything else as a bare database name
    if (postgres_uri.startswith(("postgresql://", "postgres://"))):
        connection = psycopg2.connect(postgres_uri)
    else:
        connection = psycopg2.connect(database = postgres_uri)

    # wait_select lets a Ctrl-C interrupt a running query: per the psycopg2
    # documentation, the query is cancelled and QueryCanceledError raised
    psycopg2.extensions.set_wait_callback(
        psycopg2.extras.wait_select)

    with connection.cursor() as cursor:
        cursor.execute("SET search_path TO %s", (postgres_schema,))

except psycopg2.Error as e:
    error("error from database server:\n%s" % e, True)

server_version = connection.server_version
logger.debug("connection.server_version = %s", server_version)

# server_version is an integer made of two-digit groups, e.g. 90500 for
# 9.5.0 or 100005 for 10.5; split it into a comparable tuple of those groups
wanted_version = (9,5,0)
server_version = (
    server_version // 10000,
    (server_version // 100) % 100,
    server_version % 100)

if (server_version < wanted_version):
    error("invalid PostgreSQL server version: %s (should be %s or above)" % (
        '.'.join(map(str, server_version)),
        '.'.join(map(str, wanted_version))
        ))

logger.info("connecting to database server: done")

#:::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::

# Build the optional cohort restriction:
# - cohort_ddl: SQL creating the temporary view 'codex_cohort' (None when no
#   --cohort-from-* option was given)
# - cohort_sql: JOIN clause against that view (blank when there is no cohort)
# - cohort_arg: value bound to %(cohort_definition)s when the SQL is executed
#
# The staging statements are always executed with a parameter dictionary, so
# psycopg2 applies %-style placeholder substitution to the whole SQL text:
# any literal '%' supplied by the user must be written as '%%'.

cohort_ddl = None
cohort_arg = None

if (options["--cohort-from-query"] is not None):
    cohort_ddl = sqlparse.format(
        options["--cohort-from-query"],
        keyword_case = "upper",
        strip_comments = True,
        reindent = True).strip()

    if (len(sqlparse.split(cohort_ddl)) > 1):
        error("invalid SQL query: only one statement allowed")

    if (cohort_ddl.endswith(';')):
        cohort_ddl = cohort_ddl[:-1].rstrip()

    # an empty (or comment-only) query would produce an invalid view
    if (cohort_ddl == ''):
        error("invalid SQL query: no statement found")

    # e.g. LIKE patterns or the modulo operator; psycopg2 would otherwise
    # reject a lone '%' as an unsupported placeholder
    cohort_ddl = cohort_ddl.replace('%', '%%')

if (options["--cohort-from-list"] is not None):
    cohort_ddl = """\
     SELECT unnest(%(cohort_definition)s) AS person_id
        """
    cohort_arg = options["--cohort-from-list"]

if (options["--cohort-from-table"] is not None):
    cohort_table = options["--cohort-from-table"]

    # a table name cannot be bound as a query parameter: psycopg2 would render
    # it as a quoted string literal ('name'), which is a syntax error after
    # FROM. Accept only a plain, optionally schema-qualified identifier, which
    # is then safe to inline into the SQL text.
    if (re.match(r"[^\W\d][\w$]*(?:\.[^\W\d][\w$]*)?\Z",
        cohort_table, re.UNICODE) is None):
        error("invalid value for --cohort-from-table: %s" % cohort_table)

    cohort_ddl = """\
     SELECT DISTINCT person_id FROM %s
        """ % cohort_table

if (cohort_ddl is None):
    cohort_sql = ' ' * 12
else:
    cohort_ddl = "CREATE TEMPORARY VIEW codex_cohort AS\n%s;" % cohort_ddl
    cohort_sql = """\
 INNER JOIN codex_cohort
         ON (codex_source.person_id = codex_cohort.person_id)
        """

#:::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::

# NOTE(review): cwd is not referenced anywhere else in this file
cwd = os.path.abspath(os.path.dirname(__file__))

def load_sql (name, kwargs):
    """Yield the SQL statements of the packaged script 'share/<name>.sql'.

    The script is read from the 'codex' package and decoded as UTF-8. Every
    "{{key}}" placeholder is replaced with kwargs[key], and the resulting text
    is split into individual statements with sqlparse. Because this is a
    generator, nothing is read until the result is first iterated.
    """
    text = pkg_resources.resource_string(
        "codex", "share/%s.sql" % name).decode("utf-8")

    for placeholder in kwargs:
        text = text.replace("{{%s}}" % placeholder, kwargs[placeholder])

    statements = sqlparse.split(text)
    for statement in statements:
        yield statement

# matches lines starting with "-- INFO: <message>" in the SQL scripts; each
# message is logged just before the statement containing it is executed
info_pragma = re.compile(r"--\s+INFO:\s+(.*?)$")

with connection:
    try:
        with connection.cursor() as cursor:
            def include_sql (name):
                """Execute, in order, every statement of share/<name>.sql.

                The {{codex_events}}, {{cohort_ddl}} and {{cohort_sql}}
                placeholders are substituted first, and every statement is
                executed with 'cohort_definition' bound to cohort_arg.
                """
                script = load_sql(name, {
                    "codex_events": options["--events-table"],
                    # cohort_ddl is None when no --cohort-from-* option was
                    # given; substitute an empty string instead of failing on
                    # None.rstrip() (presumably the scripts tolerate an empty
                    # {{cohort_ddl}}, as they do a blank {{cohort_sql}} --
                    # confirm against share/*.sql)
                    "cohort_ddl": (cohort_ddl or "").rstrip(),
                    "cohort_sql": cohort_sql.rstrip()})

                for statement in script:
                    if (statement.strip() == ''):
                        continue

                    for line in statement.splitlines():
                        m = info_pragma.match(line)
                        if (m is not None):
                            logger.info(m.group(1))

                    cursor.execute(statement, {
                       "cohort_definition": cohort_arg})

            include_sql("codex_functions")
            include_sql("codex_tables")
            include_sql("codex_cdm5")
            include_sql("codex_indices")

    except psycopg2.extensions.QueryCanceledError:
        connection.rollback()
        error("operation cancelled by user")

logger.info("staging done.")

force_print("\ncontent of '%s.%s':" % (
    options["--schema"], options["--events-table"]))

with connection.cursor() as cursor:
    cursor.execute("""\
     SELECT event_class,
            count(event_class)
       FROM %(events_table)s
   GROUP BY event_class;
        """ % {
            "events_table": options["--events-table"]})

    rows = sorted(cursor.fetchall())

    # zip(*rows) cannot be unpacked into two sequences when the table is
    # empty, so report that case explicitly
    if (len(rows) == 0):
        print("  (no entries)")
    else:
        classes, counts = zip(*rows)
        col_a_width = max(len(x) for x in classes)
        col_b_width = max(len("{:,d}".format(x)) for x in counts)

        for (event_class, count) in rows:
            print("  %s  %s entries" % (
                event_class.ljust(col_a_width),
                "{:,d}".format(count).rjust(col_b_width)))

connection.close()
print()
