#This file is part of Higgins. The COPYRIGHT file at the top level of
#this repository contains the full copyright notices and license
#terms.

from collections import defaultdict
from contextlib import contextmanager
from datetime import date, datetime, timedelta
import logging
import os
import threading
import glob
from sqlalchemy import (Column, Integer, String, Boolean, Date, TIMESTAMP,
                        ForeignKey, Sequence, UniqueConstraint, orm, update,
                        create_engine, event, or_, and_, not_, func, desc)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, validates, relationship

from . import config
from .utils import call, mkdir, UserException

# NOTE(review): configuring the root logger at import time affects every
# program that imports this module; kept because callers may rely on it.
logging.basicConfig(level=logging.INFO)
# Logger used by every model helper in this module.
logger = logging.getLogger('higgins')

# Process-wide Context singleton: assigned by Context.__init__ and read by
# the model helpers through CTX.get('session').
CTX = None


class ValidationError(Exception):
    """Raised when a model field is given a value it does not accept."""


class Base(object):
    """Columns shared by every mapped model of this module.

    Provides the integer primary key and a row timestamp.
    """

    id = Column(Integer, primary_key=True)
    # Set to datetime.now on insert and refreshed on every update.
    timestamp = Column(TIMESTAMP, default=datetime.now, onupdate=datetime.now,
                       index=True)

# Rebind the name: the mixin above becomes the ``cls`` of the declarative
# base, so every model subclassing ``Base`` inherits ``id``/``timestamp``.
Base = declarative_base(cls=Base)


def resolve_dependencies(keys, triage, deps, level=0, parent=None, _path=()):
    """Assign a triage level to every key reachable from *keys*.

    *triage* maps a key to the keys that must come after it.  *deps* is
    filled in place with ``key -> (level, parents)``: ``level`` is the
    deepest position at which the key was reached (starting at *level*),
    and ``parents`` is the set of keys it was reached from (``None`` for the
    initial *keys*).

    A key reached again through a longer chain is moved down, and its own
    dependents are re-levelled below it.  Cycles are broken by refusing to
    follow an edge back to a key currently being expanded (*_path*, the
    keys on the recursion stack; internal, callers should not pass it).
    """
    for k in keys:
        if k in _path:
            logger.debug('Triage: skip cyclic transition "%s" -> "%s"'
                         % (parent, k))
            continue
        known = k in deps
        cur_lev, ancestors = deps.get(k, (0, set()))
        # When known, this mutates the set already stored in deps.
        ancestors.add(parent)
        if known and cur_lev >= level:
            # Already placed at least this deep: its dependents were placed
            # below it when it was expanded, nothing new to propagate.
            continue
        new_level = max(level, cur_lev)
        deps[k] = (new_level, ancestors)

        resolve_dependencies(
            triage.get(k, []),
            triage,
            deps,
            level=new_level + 1,
            parent=k,
            _path=_path + (k,),
        )

def check_config(cls, conf, **ctx_filter):
    """Synchronise the rows of model *cls* with configuration *conf*.

    ``cls.parse_config(conf)`` yields ``(record_filter, default_values)``
    pairs.  Each filter, extended with *ctx_filter*, selects an existing row
    whose fields are refreshed from *default_values*; when no row matches, a
    new one is created.  Afterwards every row within *ctx_filter* gets
    ``active`` set to whether it was seen, so records that disappeared from
    the configuration are disabled.

    Uses the session stored in the global Context (``CTX``).
    """
    session = CTX.get('session')
    active_ids = []

    for record_filter, default_values in cls.parse_config(conf):
        # Work on copies: parse_config implementations may reuse the dicts
        # they yield.
        record_filter = dict(record_filter)
        record_filter.update(ctx_filter)
        default_values = dict(default_values)
        default_values.update(record_filter)

        record = session.query(cls).filter_by(**record_filter).first()
        if record is not None:
            for k, v in default_values.iteritems():
                setattr(record, k, v)
            active_ids.append(record.id)
            continue

        # Create new record in repo table
        logger.info('New %s "%s" created'
                    % (cls.__tablename__, default_values['key']))
        new_record = cls(**default_values)
        session.add(new_record)
        session.flush()
        active_ids.append(new_record.id)

    # Disable unknown records.  With an empty configuration every record in
    # scope is unknown: deactivate them all (and avoid an empty IN ()).
    if active_ids:
        active = cls.id.in_(active_ids)
    else:
        active = False
    stm = update(cls).values(active=active)
    if ctx_filter:
        # and_() keeps this valid on SQLAlchemy versions whose
        # Update.where() accepts a single clause only.
        clauses = [getattr(cls, k) == v for k, v in ctx_filter.iteritems()]
        stm = stm.where(and_(*clauses))
    session.execute(stm)
    session.flush()


class Project(Base):
    """A configured project: a ``path`` holding several repositories."""
    __tablename__ = 'project'
    __table_args__ = (
        UniqueConstraint('key', 'path'),
    )

    key = Column(String)
    path = Column(String)
    name = Column(String)
    # Optional command syncing every repository at once; when unset, each
    # active repository is synced individually (see sync_repositories).
    sync_all = Column(String)
    active = Column(Boolean)
    repositories = relationship("Repository")

    @classmethod
    def parse_config(cls, conf):
        """Yield ``(index, default_values)`` pairs for check_config.

        Entries without a ``path`` are ignored.  ``name`` defaults to the
        entry key and ``sync_all`` comes from the ``sync-all`` option.
        """
        for key, value in conf.iteritems():
            path = value.get('path')
            if path:
                index = {
                    'key': key,
                    'path': path,
                }
                default_values = {
                    'name': value.get('name', key),
                    'sync_all': value.get('sync-all'),
                }
                yield index, default_values

    def sync_repositories(self):
        """Bring the repositories of this project up to date."""
        if self.sync_all:
            logger.debug('Sync repositories for project %s' % self.key)
            call(self.sync_all, self.path, True)
            return

        for repo in self.repositories:
            if not repo.active:
                continue
            logger.debug('Sync repository %s for project %s' % (
                repo.key, self.key
            ))
            repo.sync()

    def dirty_repositories(self):
        """Return the repository keys that need a new build, or None.

        A repository is dirty when no Run exists yet for it at its current
        revision.  The dirty keys, plus every key reachable from them
        through the project's ``triage`` configuration, are returned as a
        list ordered by triage level and then alphabetically.  Returns None
        when some repository revision is unknown or nothing is dirty.
        """
        session = CTX.get('session')
        # Current revision of every repository
        current = []
        for repo in self.repositories:
            rev = repo.current_revision()
            if rev is None:
                # Nothing sensible to do
                return
            current.append((repo, rev))
        if not current:
            return

        # (repository, revision) pairs that already have a run.  Matching
        # the pair, not the revision alone, keeps repositories sitting on
        # the same revision from masking each other.
        query = session.query(Run.repository_id, Run.revision).filter(
            Run.revision.in_([rev for _, rev in current]),
            Run.repository_id.in_([repo.id for repo, _ in current]),
        )
        seen = set((repo_id, rev) for repo_id, rev in query)

        dirty = [repo for repo, rev in current if (repo.id, rev) not in seen]
        if not dirty:
            return

        # Resolve dependencies
        triage = config.get(self.key).get('triage', {})
        deps = {}
        resolve_dependencies([r.key for r in dirty], triage, deps)

        # Order them by level and alphabetically.  A list (not a set) is
        # returned so that this ordering is preserved.
        levels = defaultdict(list)
        for k, (level, _) in deps.iteritems():
            levels[level].append(k)
        return [k for i in sorted(levels) for k in sorted(levels[i])]


class Repository(Base):
    """A repository of a project; only Mercurial checkouts are handled."""
    __tablename__ = 'repository'
    __table_args__ = (
        UniqueConstraint('project_id', 'key', 'path'),
    )

    project_id = Column(Integer, ForeignKey('project.id'))
    key = Column(String)
    # Relative to the project path (see full_path).
    path = Column(String)
    active = Column(Boolean)
    project = relationship("Project")

    @classmethod
    def parse_config(cls, conf):
        """Yield ``(index, default_values)`` for each ``key: path`` entry."""
        for key, value in conf.iteritems():
            index = {
                'key': key,
                'path': value,
            }
            yield index, {}

    @staticmethod
    def _is_hg(path):
        """Return True when *path* contains a Mercurial ``.hg`` directory."""
        return os.path.isdir(os.path.join(path, '.hg'))

    def _forget_revision(self):
        """Drop the revision cached by current_revision()."""
        self.__dict__.pop('_current_revision', None)

    def sync(self):
        """Pull and update the working copy with ``hg pull -u``."""
        path = self.full_path()
        if not self._is_hg(path):
            logger.error('Repository "%s" cannot be synced' % path)
            return
        call('hg pull -u', path, True)
        # The working copy may have moved: recompute it on next request.
        self._forget_revision()

    def full_path(self):
        """Return the project path joined with this repository's path."""
        return os.path.join(self.project.path, self.path)

    def current_revision(self):
        """Return the working copy's changeset node, or None if unknown.

        A value obtained from hg (including None on hg failure) is cached on
        the instance until sync() or update_to() changes the working copy.
        """
        if hasattr(self, '_current_revision'):
            return self._current_revision

        path = self.full_path()
        if not self._is_hg(path):
            logger.error(
                'Cannot find active revision for repository "%s"' % path)
            return None
        process, stdout, stderr = call('hg parent --template {node}', path,
                                       True)
        if process.returncode != 0:
            revision = None
        else:
            revision = stdout.strip()

        self._current_revision = revision
        return revision

    def update_to(self, revision):
        """Update the working copy to *revision* with ``hg update``."""
        path = self.full_path()
        if not self._is_hg(path):
            logger.error('Repository "%s" cannot be updated' % path)
            return
        if revision is None:
            # Would otherwise run "hg update None".
            logger.error('Repository "%s" cannot be updated: unknown '
                         'revision' % path)
            return
        call('hg update %s' % revision, path, True)
        self._forget_revision()


class Build(Base):
    __tablename__ = 'build'

    status = Column(Integer, default='waiting')
    environment = Column(String)
    success = Column(Boolean)
    project_id = Column(Integer, ForeignKey('project.id'))
    project = relationship("Project")
    runs = relationship("Run")

    @validates('status')
    def validate_status(self, field, value):
        if not value in ('waiting', 'done'):
            raise ValidationError('Invalid value %s for field %s',
                                  (value, field))
        return value

    @classmethod
    def new_build(self, conf, project, dirty):
        # Skip build creation if not needed
        if not any(conf.get(k) for k in dirty):
            logger.debug("No build needed on project %s" % project.key)
            return

        # Create build
        session = CTX.get('session')
        build = Build(status='waiting', project_id=project.id)
        session.add(build)
        session.flush()
        logger.info("New build in project %s for repositories: %s" % (
            project.key,
            ', '.join(r.key for r in project.repositories if r.key in dirty)),
        )

        # Cache mapping from key to test records
        key2tests = defaultdict(list)
        query = session.query(Test.key, Test).filter_by(
            active=True,
            project_id=project.id,
        )
        for key, test in query:
            key2tests[key].append(test)

        # Create runs
        for repo in project.repositories:
            if repo.key not in dirty or not conf.get(repo.key):
                # An empty run (not linked to a test) allows to
                # capture the revision
                r = Run(
                    build_id=build.id,
                    repository_id=repo.id,
                    revision=repo.current_revision(),
                )
                session.add(r)
                continue

            # Create all the tests
            for k in conf.get(repo.key):
                for test in key2tests[k]:
                    for command in test.expand_command(repo):
                        r = Run(
                            build_id=build.id,
                            repository_id=repo.id,
                            revision=repo.current_revision(),
                            test_id=test.id,
                            command=command,
                        )
                        session.add(r)

    @classmethod
    def next_waiting_build(cls):
        session = CTX.get('session')
        query = session.query(Build)\
                       .filter_by(status='waiting')\
                       .order_by(Build.id)
        return query.first()

    def launch(self):
        for run in self.runs:
            run.repository.update_to(run.revision)

        logger.info('Launch build #%s ' % self.id)
        results = [r.launch() for r in self.runs]
        success = all(results)

        logger.info('Build #%s finished %s success' % (
            self.id,
            'with' if success else 'without'
        ))
        self.status = 'done'
        self.success = success

    def run_group(self):
        by_key = defaultdict(list)
        for run in self.runs:
            if not run.test:
                continue
            key = '%s:%s' % (run.repository.key, run.test.key)
            by_key[key].append(run)

        for key in sorted(by_key):
            runs = by_key[key]
            statuses = defaultdict(int)
            for r in runs:
                statuses[r.get_status()] += 1
            yield key, statuses

    def get_runs(self, key):
        repo_key, test_key = key.split(':')
        for run in self.runs:
            if run.repository.key == repo_key and run.test.key == test_key:
                yield run


class Run(Base):
    """Execution of one test command on one repository revision.

    A Run without a test only records the revision its repository was at
    when the build was created.
    """
    __tablename__ = 'run'
    __table_args__ = (
        UniqueConstraint('build_id', 'repository_id', 'revision', 'test_id',
                         'command'),
    )

    build_id = Column(Integer, ForeignKey('build.id'))
    repository_id = Column(Integer, ForeignKey('repository.id'))
    revision = Column(String)
    test_id = Column(Integer, ForeignKey('test.id'))
    command = Column(String)
    return_code = Column(Integer)
    repository = relationship('Repository')
    test = relationship('Test')
    build = relationship('Build')

    def _run_dir(self):
        """Return the directory holding this run's output files."""
        return os.path.join(config.BUILD_DIR, str(self.id))

    def launch(self):
        """Execute the test command and store its output.

        Returns True when the command exits with status 0.  Runs not linked
        to a test are skipped and count as successful.
        """
        # Skip run not linked to a test
        if not self.test_id:
            return True

        logger.debug('Launch test %s' % self.test.key)
        cwd = self.repository.full_path()
        process, stdout, stderr = call(self.command, cwd=cwd)
        self.return_code = process.returncode
        self.save_data(process, stdout, stderr)
        return process.returncode == 0

    def save_data(self, process, stdout, stderr):
        """Write *stdout* and *stderr* into the run directory.

        *process* is currently unused.
        """
        # TODO read config for extra output files
        outputs = {
            'stdout': stdout,
            'stderr': stderr,
        }

        run_dir = self._run_dir()
        mkdir(run_dir)

        # Save output of each test. XXX zip ?
        for name, content in outputs.iteritems():
            data_path = os.path.join(run_dir, name)
            logger.debug('Create output file %s' % data_path)
            with open(data_path, 'w') as fh:
                fh.write(content)

    def get_data(self, filename=None):
        """Return the saved output of this run.

        Returns {} while the run is waiting.  With *filename*, returns that
        file's contents; otherwise a dict mapping each regular file of the
        run directory to its contents.
        """
        if self.get_status() == 'waiting':
            return {}

        run_dir = self._run_dir()
        if filename:
            # NOTE(review): *filename* is joined unchecked; if it can come
            # from an untrusted request, "../" would escape run_dir.
            full_name = os.path.join(run_dir, filename)
            with open(full_name) as fh:
                return fh.read()

        data = {}
        for name in os.listdir(run_dir):
            full_name = os.path.join(run_dir, name)
            # Test the entry inside run_dir, not a name relative to the CWD.
            if os.path.isdir(full_name):
                continue
            with open(full_name) as fh:
                data[name] = fh.read()
        return data

    def get_status(self):
        """Return 'waiting', 'success' or 'failed' from the return code."""
        if self.return_code is None:
            return 'waiting'
        return 'success' if self.return_code == 0 else 'failed'

class Test(Base):
    """A test command of a project, optionally expanded over file globs."""
    __tablename__ = 'test'
    __table_args__ = (
        UniqueConstraint('project_id', 'key', 'command'),
    )

    project_id = Column(Integer, ForeignKey('project.id'))
    key = Column(String)
    command = Column(String)
    # Optional glob patterns, relative to the repository (expand_command).
    path = Column(String)
    exclude = Column(String)
    active = Column(Boolean)
    project = relationship("Project")

    @classmethod
    def parse_config(cls, conf):
        """Yield ``(index, default_values)`` pairs for check_config.

        An entry is a command string, a list of command strings, or a dict
        with ``command`` and optional ``path``/``exclude`` globs (dicts
        without a command are ignored).  Fresh dicts are yielded each time,
        and ``path``/``exclude`` are always present so that switching an
        entry to the string/list form clears stale globs on existing rows.
        """
        for key, value in conf.iteritems():
            if isinstance(value, basestring):
                commands = [value]
                options = {}
            elif isinstance(value, list):
                commands = value
                options = {}
            elif isinstance(value, dict):
                command = value.get('command')
                if not command:
                    continue
                commands = [command]
                options = value
            else:
                continue

            for command in commands:
                index = {
                    'key': key,
                    'command': command,
                }
                default_values = {
                    'path': options.get('path'),
                    'exclude': options.get('exclude'),
                }
                yield index, default_values

    def expand_command(self, repository):
        """Yield the shell command(s) to run for *repository*.

        Without ``path`` the bare command is yielded.  Otherwise one command
        per file matching ``path`` (minus ``exclude``) is yielded, with the
        file path relative to the repository appended, in sorted order.
        """
        if not self.path:
            yield self.command
            return

        repo_path = repository.full_path()
        all_path = set(glob.glob(os.path.join(repo_path, self.path)))
        if self.exclude:
            all_path -= set(glob.glob(os.path.join(repo_path, self.exclude)))

        # glob results are unordered: sort for a stable command order.
        for path in sorted(all_path):
            relpath = os.path.relpath(path, repo_path)
            yield '%s %s' % (self.command, relpath)


def create_all():
    """Create every table declared on ``Base`` with the context engine."""
    metadata = Base.metadata
    metadata.create_all(CTX.engine)


def session():
    """Return the transactional session scope of the global Context.

    The result is the ``Context.session`` context manager; use it in a
    ``with`` statement.
    """
    ctx = CTX
    return ctx.session()


class Context(object):
    """Holder of the database engine and per-thread data.

    Only one instance may exist: it publishes itself as the module-level
    ``CTX`` read by the model helpers.
    """

    def __init__(self, db_uri, echo=False):
        global CTX

        if CTX is not None:
            raise Exception('Context cannot be instantiated twice')
        self.engine = create_engine(db_uri, echo=echo)
        if self.engine.dialect.name == 'sqlite':
            # The pragma is SQLite-specific; other backends would fail on
            # connect.
            event.listen(self.engine, 'connect', self._fk_pragma_on_connect)

        self.sessionmaker = sessionmaker(bind=self.engine)

        # Makes data thread-safe
        self.data = threading.local()

        # Publish the singleton only once fully built, so a failure above
        # does not leave a half-initialised CTX that blocks a retry.
        CTX = self

    @contextmanager
    def session(self):
        """Yield a new session stored as this thread's current session.

        Commits on normal exit, rolls back and re-raises on any exception,
        and always closes the session.  The previously current session
        (None outside any scope) is restored on exit, so nesting is safe.
        """
        session = self.sessionmaker()
        previous = self.get('session')
        self.data.session = session
        try:
            yield session
            session.commit()
        except:
            # Deliberately bare: roll back on anything, then re-raise.
            logger.debug('Rollback session (dirty: %s)' % session.dirty)
            session.rollback()
            raise
        finally:
            # Reset session
            session.close()
            self.data.session = previous

    @staticmethod
    def _fk_pragma_on_connect(dbapi_con, con_record):
        """Enable foreign key enforcement on a new SQLite connection."""
        dbapi_con.execute('pragma foreign_keys=ON')

    def get(self, name, default=None):
        """Return this thread's value for *name*, or *default*."""
        return getattr(self.data, name, default)

    def set(self, name, value):
        """Store *value* under *name* for the current thread."""
        return setattr(self.data, name, value)
