#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Fetches information from the variable positions table"""

import os
import sys
import copy
import random
import argparse

from collections import Counter

import anvio
import anvio.tables as t
import anvio.dbops as dbops
import anvio.utils as utils
import anvio.dictio as dictio
import anvio.terminal as terminal
import anvio.filesnpaths as filesnpath
import anvio.ccollections as ccollections

from anvio.errors import ConfigError


__author__ = "A. Murat Eren"
__copyright__ = "Copyright 2015, The anvio Project"
__credits__ = []
__license__ = "GPL 3.0"
__version__ = anvio.__version__
__maintainer__ = "A. Murat Eren"
__email__ = "a.murat.eren@gmail.com"
__status__ = "Development"


# short alias; used throughout this file to format counts shown in messages
pp = terminal.pretty_print
# module-level reporting objects; they are the default values of the `p` and
# `r` parameters of VariablePositions.__init__
progress = terminal.Progress()
run = terminal.Run(width = 60)


class VariablePositions:
    def __init__(self, args = None, p=progress, r=run):
        """Set a default for every parameter, then let `args` override them.

        `args` is an object carrying the attributes read by
        `process_cmd_line_args` (the argparse namespace of this program, for
        instance). When it is None, the defaults below stay in place and the
        caller is expected to set what it needs before calling `init`.

        `p` and `r` are the objects used for progress updates and reporting.
        """
        self.progress = p
        self.run = r

        # parameters (every one of them may be overwritten by `args`)
        self.args = None
        self.samples_of_interest = set([])
        self.min_ratio = 0
        self.min_occurrence = 1
        self.min_scatter = 0
        self.num_positions_from_each_split = 2
        self.profile_db_path = None
        self.annotation_db_path = None
        self.output_file_path = None
        self.quince_mode = False

        # sources for the split names to work with
        self.splits_of_interest = set([])
        self.splits_of_interest_path = None
        self.collection_id = None
        self.bin_id = None

        # internal state, filled in by `init` and the filters
        self.variable_positions_table = {}
        self.unique_pos_identifier = 0
        self.split_name_position_dict = {}
        self.unique_pos_id_to_entry_id = {}
        self.summaries = None
        self.contig_sequences = None
        self.input_file_path = None

        if args:
            self.process_cmd_line_args(args)


    def process_cmd_line_args(self, args):
        """Copy the parameters found in `args` onto this instance.

        `args` must provide: bin_id, collection_id, splits_of_interest,
        samples_of_interest, min_ratio, min_occurrence,
        num_positions_from_each_split, min_scatter, profile_db, annotation_db,
        quince_mode, and output.

        When a samples of interest file is declared, it is read here (one
        sample name per line). When an output path is declared, it is checked
        for writability.
        """
        self.args = args
        self.bin_id = args.bin_id
        self.collection_id = args.collection_id
        self.splits_of_interest_path = args.splits_of_interest

        if args.samples_of_interest:
            filesnpath.is_file_exists(args.samples_of_interest)
            # `with` closes the file handle once we are done, and blank lines are
            # skipped so they do not end up as a sample named ''
            with open(args.samples_of_interest) as samples_file:
                sample_names = [line.strip() for line in samples_file]
            self.samples_of_interest = set(name for name in sample_names if name)

        self.min_ratio = float(args.min_ratio)
        self.min_occurrence = int(args.min_occurrence)
        self.num_positions_from_each_split = int(args.num_positions_from_each_split)
        self.min_scatter = int(args.min_scatter)
        self.profile_db_path = args.profile_db
        self.annotation_db_path = args.annotation_db
        self.quince_mode = args.quince_mode

        if args.output:
            self.output_file_path = args.output
            filesnpath.is_output_file_writable(self.output_file_path)


    def init(self):
        """Validate the inputs, load the tables, and run every filter.

        Steps, in order:
          - make sure both databases are declared, exist, and are compatible,
          - resolve the splits of interest: from a collection id / bin name
            pair, from a file of split names, or from a set that was already
            stored in `self.splits_of_interest`,
          - read the variable positions table and the splits info table (and,
            in quince mode, the summaries and contig sequences),
          - run the filters, and recover base frequencies in quince mode.

        Raises ConfigError if a database is not declared, or if the source of
        split names is missing or declared in two competing ways.
        """
        self.progress.new('Init')

        self.progress.update('Making sure our databases are here, and they are compatible ..')
        if not self.profile_db_path:
            self.progress.end()
            raise ConfigError('You need to provide a profile database.')

        if not self.annotation_db_path:
            self.progress.end()
            raise ConfigError('You need to provide an annotation database.')

        filesnpath.is_file_exists(self.annotation_db_path)
        filesnpath.is_file_exists(self.profile_db_path)

        dbops.is_annotation_and_profile_dbs_compatible(self.annotation_db_path, self.profile_db_path)

        self.progress.update('Attempting to get our splits of interest sorted ..')
        # the path attribute may not exist if this instance was configured
        # without command line args
        splits_of_interest_path = getattr(self, 'splits_of_interest_path', None)
        if self.collection_id:
            # the user wants to go with the collection id path. fine. we will get our split names from
            # the profile database.
            if not self.bin_id:
                self.progress.end()
                raise ConfigError('When you declare a collection id, you must also declare a bin name\
                                    (from which the split names of interest will be acquired)')
            # split names may arrive as a file path (command line) or as a set
            # filled in by another program: either one conflicts with a collection
            if self.splits_of_interest or splits_of_interest_path:
                self.progress.end()
                raise ConfigError('You declared a collection id and bin name so anvi\'o can find out\
                                    splits of interest, but you also have specified a file for split names?\
                                    This is confusing, and you should choose one way or another :/')

            self.splits_of_interest = ccollections.GetSplitNamesInBins(self.args).get_split_names_only()
        else:
            # OK. no collection id. we will go oldschool. we hope to find what we are looking for in
            # the splits of interest path at this point (which may have been filled through the command
            # line client), or in self.splits_of_interest (which may have been filled in by another program)
            if not self.splits_of_interest:
                if not splits_of_interest_path:
                    self.progress.end()
                    raise ConfigError('You did not declare a source for split names. You either should give me\
                                        a file with split names you are interested in, or a collection id and\
                                        bin name so I can learn split names from the profile database.')
                filesnpath.is_file_exists(splits_of_interest_path)
                # `with` closes the file handle once read, and blank lines are
                # skipped so they do not end up as a split named ''
                with open(splits_of_interest_path) as splits_file:
                    split_names = [line.strip().replace('\r', '') for line in splits_file]
                self.splits_of_interest = set(name for name in split_names if name)

        # the directory of the profile database: summaries are read from there
        self.input_file_path = os.path.dirname(os.path.abspath(self.profile_db_path))

        self.progress.update('Reading variable positions table ...')
        profile_db = dbops.ProfileDatabase(self.profile_db_path)
        self.sample_ids = profile_db.samples # we set this now, but we will overwrite it with args.samples_of_interest if necessary
        self.variable_positions_table = profile_db.db.get_table_as_dict(t.variable_positions_table_name)
        profile_db.disconnect()

        self.progress.update('Reading splits info table ...')
        annotation_db = dbops.AnnotationDatabase(self.annotation_db_path)
        self.splits_info_table = annotation_db.db.get_table_as_dict(t.splits_info_table_name)
        self.num_splits_in_db = len(self.splits_info_table)
        if self.quince_mode:
            self.progress.update('Reading summaries ...')
            self.summaries = dictio.read_serialized_object(os.path.join(self.input_file_path, 'SUMMARY.cp'))
            self.progress.update('Reading contig sequences table ...')
            self.contig_sequences = annotation_db.db.get_table_as_dict(t.contig_sequences_table_name)
        annotation_db.disconnect()

        self.progress.end()

        self.process_variable_positions_table()

        self.set_unique_pos_identification_numbers() # which allows us to track every unique position across samples

        self.filter_based_on_scattering_factor()

        self.filter_based_on_num_positions_from_each_split()

        if self.quince_mode: # will be very costly...
            self.recover_base_frequencies_for_all_samples()


    def filter(self, filter_name, test_func):
        """Remove the entries of the variable positions table that `test_func` flags.

        `test_func` is called with each entry (the value stored for an entry id),
        and the entry is removed when the call returns a true value.
        `filter_name` only appears in the messages. The number of remaining and
        removed entries is reported, and the emptiness check runs at the end.
        """
        self.progress.new('Filtering based on "%s"' % filter_name)

        table = self.variable_positions_table
        size_before = len(table)

        # collect first, delete later: the table can not change size while we
        # are iterating over it
        doomed_entry_ids = set([])
        for num_seen, entry_id in enumerate(table):
            if num_seen % 1000 == 0:
                self.progress.update('identifying entries to remove :: %s' % pp(num_seen))

            if test_func(table[entry_id]):
                doomed_entry_ids.add(entry_id)

        self.progress.update('removing %s entries from table ...' % pp(len(doomed_entry_ids)))
        for entry_id in doomed_entry_ids:
            del table[entry_id]

        size_after = len(table)

        self.progress.end()

        summary = '%s (filter removed %s entries)' % (pp(size_after), pp(size_before - size_after))
        self.run.info('Remaining entries after "%s" filter' % filter_name, summary, mc = 'green')

        self.check_if_variable_table_is_empty()


    def check_if_variable_table_is_empty(self):
        """Quit the program (via sys.exit) if the variable positions table has no entries."""
        if len(self.variable_positions_table) > 0:
            return

        # nothing left: close any active progress display, tell the user, and leave
        self.progress.end()
        self.run.info_single('No variable positions left to work with. Quitting.', 'red', 1, 1)
        sys.exit()


    def process_variable_positions_table(self):
        """Apply the sample, split, n2/n1 ratio, and min occurrence filters.

        Along the way every remaining entry gets a 'unique_position_id' key
        (split name and position joined with '_'), which is what identifies the
        same nucleotide position across samples.

        Raises ConfigError if a sample of interest is not among the samples of
        the profile database.
        """
        self.run.info('Variability table', '%s entries in %s splits across %s samples'\
                % (pp(len(self.variable_positions_table)), pp(self.num_splits_in_db), pp(len(self.sample_ids))))

        self.run.info('Samples in the profile db', ', '.join(sorted(self.sample_ids)))
        if self.samples_of_interest:
            samples_missing_from_db = [sample for sample in self.samples_of_interest if sample not in self.sample_ids]

            if samples_missing_from_db:
                raise ConfigError('One or more samples you are interested in seem to be missing from\
                                    the profile database: %s' % ', '.join(samples_missing_from_db))

            self.run.info('Samples of interest', ', '.join(sorted(list(self.samples_of_interest))))
            self.sample_ids = self.samples_of_interest
            self.filter('samples of interest', lambda x: x['sample_id'] not in self.samples_of_interest)

        if self.splits_of_interest:
            self.run.info('Num splits of interest', pp(len(self.splits_of_interest)))
            self.filter('splits of interest', lambda x: x['split_name'] not in self.splits_of_interest)

        # let's report the number of positions reported in each sample before filtering any further:
        num_positions_each_sample = Counter([v['sample_id'] for v in self.variable_positions_table.values()])
        self.run.info('Total number of variable positions in samples', '; '.join(['%s: %s' % (s, num_positions_each_sample[s]) for s in sorted(self.sample_ids)]))

        if self.min_ratio:
            self.run.info('Min n2/n1 ratio', self.min_ratio)
            self.filter('n2/n1', lambda x: x['n2n1ratio'] < self.min_ratio)

        for v in self.variable_positions_table.values():
            v['unique_position_id'] = '_'.join([v['split_name'], str(v['pos'])])

        if self.min_occurrence <= 1:
            # every position in the table occurs in at least one sample, so
            # there is nothing to remove
            return

        self.run.info('Min occurrence requested', self.min_occurrence)

        self.progress.new('Processing positions further')

        self.progress.update('counting occurrences of each position across samples ...')
        unique_position_id_occurrences = Counter(v['unique_position_id'] for v in self.variable_positions_table.values())

        self.progress.update('identifying entries that occurr in less than %d samples ...' % (self.min_occurrence))
        entry_ids_to_remove = set([])
        for entry_id in self.variable_positions_table:
            v = self.variable_positions_table[entry_id]
            if unique_position_id_occurrences[v['unique_position_id']] < self.min_occurrence:
                entry_ids_to_remove.add(entry_id)

        self.progress.update('removing %s entries from table ...' % pp(len(entry_ids_to_remove)))
        for entry_id in entry_ids_to_remove:
            self.variable_positions_table.pop(entry_id)

        self.progress.end()

        self.check_if_variable_table_is_empty()


    def get_unique_pos_identification_number(self, unique_position_id):
        """Return the number that stands for `unique_position_id`.

        The first time an id is seen it gets the current value of
        `self.unique_pos_identifier`, and that counter is incremented. Later
        calls with the same id return the number that was assigned back then.
        """
        assigned_numbers = self.split_name_position_dict

        if unique_position_id not in assigned_numbers:
            assigned_numbers[unique_position_id] = self.unique_pos_identifier
            self.unique_pos_identifier += 1

        return assigned_numbers[unique_position_id]


    def gen_unique_pos_identifier_to_entry_id_dict(self):
        """Fill `self.unique_pos_id_to_entry_id`.

        Keys are the 'unique_position_id' values found in the variable positions
        table, and each value is the set of entry ids that carry that key.
        """
        self.progress.new('Generating the `unique pos` -> `entry id` dict')
        self.progress.update('...')

        entry_ids_for = self.unique_pos_id_to_entry_id
        for entry_id, entry in self.variable_positions_table.items():
            entry_ids_for.setdefault(entry['unique_position_id'], set([])).add(entry_id)

        self.progress.end()


    def set_unique_pos_identification_numbers(self):
        """Add 'unique_pos_identifier' and 'parent' keys to every entry in the table.

        The identifier comes from `get_unique_pos_identification_number`, and the
        parent is looked up in the splits info table through the split name of
        the entry.
        """
        self.progress.new('Further processing')
        self.progress.update('re-setting unique identifiers to track split/position pairs across samples')

        number_for = self.get_unique_pos_identification_number
        for entry in self.variable_positions_table.values():
            entry['unique_pos_identifier'] = number_for(entry['unique_position_id'])
            entry['parent'] = self.splits_info_table[entry['split_name']]['parent']

        self.progress.end()


    def filter_based_on_scattering_factor(self):
        """To remove any unique entry from the variable positions table that describes a variable position
           and yet is not helpful to distinguish samples from eachother.

           For each unique position, `scatter` is the smaller one of these two numbers: the number of
           entries (samples) the position is reported in, and the number of samples it is not reported in.
           All entries of a position are removed if its scatter is less than `self.min_scatter`. Nothing
           happens if `self.min_scatter` is 0.

           Raises ConfigError if `self.min_scatter` is more than half of the number of samples."""

        if self.min_scatter == 0:
            return

        num_samples = len(self.sample_ids)
        # same test as `min_scatter > num_samples / 2`, without depending on
        # how the interpreter divides integers
        if 2 * self.min_scatter > num_samples:
            raise ConfigError('Minimum scatter (%d) can not be more than half of the number of samples\
                                (%d) :/' % (self.min_scatter, num_samples))

        self.run.info('Min scatter', self.min_scatter)

        num_entries_before_filter = len(self.variable_positions_table)

        # we need the unique pos_id to entry id dict filled for this function:
        self.gen_unique_pos_identifier_to_entry_id_dict()

        self.progress.new('Examining scatter')
        self.progress.update('...')

        entry_ids_to_remove = set([])

        for entry_ids in self.unique_pos_id_to_entry_id.values():
            # the number of samples this position occurs in:
            num_occurrence = len(entry_ids)

            # the position splits samples into two groups (those that report it, and those
            # that do not), and the size of the smaller group is the scatter:
            scatter = min(num_occurrence, num_samples - num_occurrence)
            if scatter < self.min_scatter:
                entry_ids_to_remove.update(entry_ids)

        self.progress.update('removing %s entries from table ...' % pp(len(entry_ids_to_remove)))
        for entry_id in entry_ids_to_remove:
            self.variable_positions_table.pop(entry_id)

        num_entries_after_filter = len(self.variable_positions_table)

        self.progress.end()

        self.run.info('Remaining entries after "minimum scatter" filter',
                      '%s (filter removed %s entries)' % (pp(num_entries_after_filter),
                                                          pp(num_entries_before_filter - num_entries_after_filter)),
                      mc = 'green')

        self.check_if_variable_table_is_empty()


    def filter_based_on_num_positions_from_each_split(self):
        """Keep at most `self.num_positions_from_each_split` positions per split.

        When a split has more unique positions than that, the ones to keep are
        chosen randomly, and every entry of the other positions is removed from
        the table. A value of 0 disables the subsampling.
        """
        self.run.info('Num positions to keep from each split (-n)', self.num_positions_from_each_split)

        num_entries_before_filter = len(self.variable_positions_table)

        self.progress.new('Filtering based on -n')

        self.progress.update('Generating positions in splits dictionary ...')
        positions_in_splits = {}
        for v in self.variable_positions_table.values():
            positions_in_splits.setdefault(v['split_name'], set([])).add(v['unique_pos_identifier'])

        self.progress.update('Randomly subsampling from splits with num positions > %d' % self.num_positions_from_each_split)
        max_num_positions = self.num_positions_from_each_split
        unique_positions_to_remove = set([])
        for positions in positions_in_splits.values():
            if max_num_positions and len(positions) > max_num_positions:
                # random.sample wants a sequence: recent Python versions refuse
                # to sample from a set
                positions_to_keep = set(random.sample(sorted(positions), max_num_positions))
                unique_positions_to_remove.update(positions - positions_to_keep)

        self.progress.update('Identifying entry ids to remove ...')
        entry_ids_to_remove = set([entry_id for entry_id in self.variable_positions_table if self.variable_positions_table[entry_id]['unique_pos_identifier'] in unique_positions_to_remove])

        self.progress.update('Removing %d positions ...' % len(unique_positions_to_remove))
        for entry_id in entry_ids_to_remove:
            self.variable_positions_table.pop(entry_id)

        num_entries_after_filter = len(self.variable_positions_table)

        self.progress.end()

        self.run.info('Remaining entries after "-n" filter',
                      '%s (filter removed %s entries)' % (pp(num_entries_after_filter),
                                                          pp(num_entries_before_filter - num_entries_after_filter)),
                      mc = 'green')

        self.check_if_variable_table_is_empty()


    def recover_base_frequencies_for_all_samples(self):
        """this function populates variable_positions_table dict with entries from samples that have no
           variation at nucleotide positions reported in the table.

           For every position in the table, and for every wanted sample that has no entry for that
           position, a new entry is added: the base found in the parent sequence is used both as the
           consensus and as the two competing nucleotides, and the coverage of the sample at that
           position (read from the summary of the split) is stored as the coverage of the entry and
           as the count of that base."""

        self.progress.new('Recovering base frequencies for all')

        samples_wanted = self.samples_of_interest if self.samples_of_interest else self.sample_ids
        splits_wanted = self.splits_of_interest if self.splits_of_interest else set(self.splits_info_table.keys())
        next_available_entry_id = max(self.variable_positions_table.keys()) + 1

        self.progress.update('creating a dicts to track missing base frequencies for each sample / split / pos')
        split_pos_to_unique_pos_identifier = {}
        splits_to_consider = {}
        for split_name in splits_wanted:
            splits_to_consider[split_name] = {}
            split_pos_to_unique_pos_identifier[split_name] = {}

        self.progress.update('populating the dict to track missing base frequencies for each sample / split / pos')
        for v in self.variable_positions_table.values():
            p = v['pos']
            d = splits_to_consider[v['split_name']]
            u = split_pos_to_unique_pos_identifier[v['split_name']]

            if p not in d:
                # every position starts with all samples, and loses the ones that already have
                # an entry. sample ids are plain strings, a shallow copy is all we need
                d[p] = copy.copy(samples_wanted)
            d[p].remove(v['sample_id'])

            if p not in u:
                u[p] = v['unique_pos_identifier']

        counter = 0
        for split in splits_to_consider:
            counter += 1

            missing_samples_at = splits_to_consider[split]
            if not any(missing_samples_at.values()):
                # no sample misses an entry in this split: do not pay for reading its summary
                continue

            self.progress.update('accessing split summaries and updating variable positions dict :: %s' % pp(counter))

            split_info = self.splits_info_table[split]
            summary = dictio.read_serialized_object(os.path.join(self.input_file_path, self.summaries[split]))
            parent_seq = self.contig_sequences[split_info['parent']]['sequence']
            for pos in missing_samples_at:
                # positions are relative to the split, the sequence is the one of the parent
                base_at_pos = parent_seq[split_info['start'] + pos]
                for sample in missing_samples_at[pos]:
                    coverage = summary[sample]['coverage'][pos]
                    entry = {'parent': split_info['parent'],
                             'n2n1ratio': 0,
                             'consensus': base_at_pos,
                             'A': 0, 'T': 0, 'C': 0, 'G': 0, 'N': 0,
                             'pos': pos,
                             'coverage': coverage,
                             'sample_id': sample,
                             'competing_nts': base_at_pos + base_at_pos,
                             'unique_pos_identifier': split_pos_to_unique_pos_identifier[split][pos],
                             'split_name': split}
                    # no variation was reported: all coverage goes to the base of the parent
                    entry[base_at_pos] = coverage
                    self.variable_positions_table[next_available_entry_id] = entry
                    next_available_entry_id += 1

        self.progress.end()

    def report(self):
        """Store the variable positions table at `self.output_file_path` as a TAB-delimited file.

        The columns are the ones in `t.variable_positions_table_structure`, with
        'unique_pos_identifier' inserted after the first one and 'parent' added
        to the end. The letters of 'competing_nts' are sorted in every entry
        before the table is stored.

        Raises ConfigError if no output file path is set.
        """
        # the path must come from the instance: a global `args` exists only
        # when this file runs as a program
        if not self.output_file_path:
            raise ConfigError('There is no output file path to report the variable positions table.')

        self.progress.new('Reporting')

        new_structure = [t.variable_positions_table_structure[0]] + ['unique_pos_identifier'] + t.variable_positions_table_structure[1:] + ['parent']

        self.progress.update('exporting variable positions table as a TAB-delimited file ...')

        # FIXME: THIS HAS TO GO INTO THE TABLE THIS WAY
        for entry in self.variable_positions_table.values():
            entry['competing_nts'] = ''.join(sorted(entry['competing_nts']))

        utils.store_dict_as_TAB_delimited_file(self.variable_positions_table, self.output_file_path, new_structure)
        self.progress.end()

        self.run.info('Num entries reported', pp(len(self.variable_positions_table)))
        self.run.info('Output File', self.output_file_path)
        self.run.info('Num nt positions reported', pp(len(set([e['unique_pos_identifier'] for e in self.variable_positions_table.values()]))))



# command line client: parse the parameters, run the analysis, store the report
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Extract information from the variable positions table')
    groupD = parser.add_argument_group('DATABASES', 'Declaring relevant anvi\'o databases. First things first.')
    groupD.add_argument('-a', '--annotation-db', metavar = "ANNOTATION_DB", required = True,\
                        help = 'Annotation database.')
    groupD.add_argument('-p', '--profile-db', metavar = "PROFILE_DB", required = True,\
                        help = 'Profile database.')

    groupS = parser.add_argument_group('SPLITS', 'Declaring relevant splits for the analysis. There are two\
                                                  ways to do it. One, you can give a file path with split names,\
                                                  or, as an alternative, you can provide a collection id with a\
                                                  bin name.')
    groupS.add_argument('-s', '--splits-of-interest', metavar = "SPLITS",\
                         help = 'File path that contains the list of splits to analyze.')
    groupS.add_argument('-c', '--collection-id', metavar = 'COLLECTION-ID',
                        help = 'Collection ID to look for the bin you are interested in.')
    groupS.add_argument('-b', '--bin-id', metavar = 'BIN-NAME',
                        help = 'Bin name in the collection to recover split names of interest.')

    groupO = parser.add_argument_group('OUTPUT', 'Output file and output style')
    groupO.add_argument('-o', '--output', type=str, default = 'variability.txt', help = 'Output path for the resulting matrix.')
    groupO.add_argument('-S', '--samples-of-interest', metavar = "SAMPLES", default = None,\
                        help = 'File that contains the list of samples to foucs. If not declared, all samples are used.')
    groupO.add_argument('--quince-mode', action='store_true', default=False,
                  help = 'The default behavior is to report base frequencies of nucleotide positions only if there\
                          is any variation reported during profiling (which by default uses some heuristics to minimize\
                          the impact of error-driven variation). So, if there are 10 samples, and a given position has been\
                          reported as a varaible site during profiling in only one of those samples, there will be no\
                          information will be stored in the database for the remaining 9. When this flag is\
                          used, we go back to each sample, and report base frequencies for each sample at this position\
                          even if they do not vary. It will take considerably longer to report when this flag is on, and the use\
                          of it will increase the file size dramatically, however it is inevitable for some statistical approaches\
                          (as well as for some beautiful visualizations).')

    groupE = parser.add_argument_group('EXTRAS', 'Parameters that will help you to do a very precise analysis.\
                                                  If you declare nothing from this bunch, you will get "everything"\
                                                  to play with, which is not necessarily a good thing...')
    groupE.add_argument('-n', '--num-positions-from-each-split', type=int, default = 2,
                  help = 'Each split may have one or more variable positions. What is the maximum number of positons\
                          to report from each split is described via this paramter. The default is %(default)d. Which\
                          means from every split, a maximum of %(default)d eligable SNP is going to be reported.')
    groupE.add_argument('-r', '--min-ratio', type=float, default = 0, metavar = 'RATIO',
                  help = 'Minimum ratio of the competing nucleotides at a given position. Default is %(default)d.')
    # a number of samples: parsed as an integer so a value such as 2.7 is rejected
    # here instead of being silently truncated later
    groupE.add_argument('-x', '--min-occurrence', type=int, default = 1, metavar = 'NUM_SAMPLES',
                  help = 'Minimum number of samples a nucleotide position should be reported as variable. Default is %(default)d.\
                          If you set it to 2, for instance, each eligable variable position will be expected to appear in at least\
                          two samples, which will reduce the impact of stochastic, or unintelligeable varaible positions.')
    # the example in this help text uses -m 3: a position with a scatter of 2
    # survives -m 2, since only positions with scatter less than -m are removed
    groupE.add_argument('-m', '--min-scatter', type=int, default = 0,
                  help = 'This one is tricky. If you have N samples in your dataset, a given variable position x in one\
                          of your splits can split your N samples into `t` groups based on the identity of the\
                          variation they harbor at position x. For instance, `t` would have been 1, if all samples had the same\
                          type of variation at position x (which would not be very interesting, because in this case\
                          position x would have zero contribution to a deeper understanding of how these samples differ\
                          based on variability. When `t` > 1, it would mean that identities at position x across samples\
                          do differ. But how much scattering occurs based on position x when t > 1? If t=2, how many\
                          samples ended in each group? Obviously, even distribution of samples across groups may tell\
                          us something different than uneven distribution of samples across groups. So, this parameter\
                          filters out any x if "the number of samples in the second largest group" (=scatter) is less\
                          than -m. Here is an example: lets assume you have 7 samples. While 5 of those have AG, 2\
                          of them have TC at position x. This would mean scatter of x is 2. If you set -m to 3, this\
                          position would not be reported in your output matrix. The default value for -m is\
                          %(default)d, which means every x found in the database and survived previous filtering\
                          criteria will be reported. Naturally, -m can not be more than half of the number of samples.\
                          Please refer to the user documentation if this is confusing.')
    args = parser.parse_args()

    try:
        variable_positions = VariablePositions(args)
        variable_positions.init()
        variable_positions.report()
    except ConfigError as e:
        print(e)
        # let the caller (a shell, a pipeline) know that we have failed
        sys.exit(1)
