# -*- coding: utf-8 -*-
"""
Base utilities and constants for ObsPy.

:copyright:
    The ObsPy Development Team (devs@obspy.org)
:license:
    GNU Lesser General Public License, Version 3
    (https://www.gnu.org/copyleft/lesser.html)
"""
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)
from future.builtins import *  # NOQA
from future.utils import PY2

import builtins
import doctest
import inspect
import io
import os
import re
import sys
import tempfile
import unicodedata
from collections import OrderedDict

import numpy as np
import pkg_resources
import requests
from future.utils import native_str
from pkg_resources import iter_entry_points

from obspy.core.util.misc import to_int_or_zero, buffered_load_entry_point


# defining ObsPy modules currently used by runtests and the path function
# (get_example_file() searches these for "tests/data" and "data" dirs).
# Names are relative to the ``obspy`` package, e.g. 'io.mseed' refers to
# ``obspy.io.mseed``.
DEFAULT_MODULES = ['clients.filesystem', 'core', 'db', 'geodetics', 'imaging',
                   'io.ah', 'io.arclink', 'io.ascii', 'io.cmtsolution',
                   'io.cnv', 'io.css', 'io.iaspei', 'io.win', 'io.gcf',
                   'io.gse2', 'io.json', 'io.kinemetrics', 'io.kml',
                   'io.mseed', 'io.ndk', 'io.nied', 'io.nlloc', 'io.nordic',
                   'io.pdas', 'io.pde', 'io.quakeml', 'io.reftek', 'io.sac',
                   'io.scardec', 'io.seg2', 'io.segy', 'io.seisan', 'io.sh',
                   'io.shapefile', 'io.seiscomp', 'io.stationtxt',
                   'io.stationxml', 'io.wav', 'io.xseed', 'io.y', 'io.zmap',
                   'realtime', 'scripts', 'signal', 'taup']
# client modules kept in a separate list -- presumably because they talk to
# remote services (NOTE(review): confirm against runtests usage)
NETWORK_MODULES = ['clients.arclink', 'clients.earthworm', 'clients.fdsn',
                   'clients.iris', 'clients.neic', 'clients.nrl',
                   'clients.seedlink', 'clients.seishub', 'clients.syngine']
# all of the above; default modules first, then network modules
ALL_MODULES = DEFAULT_MODULES + NETWORK_MODULES

# default order of automatic format detection
# _get_ordered_entry_points() lists the plug-ins named here first (in this
# order, skipping the ones not installed) and appends all other installed
# plug-ins afterwards; _read_from_plugin() probes formats in that order.
WAVEFORM_PREFERRED_ORDER = ['MSEED', 'SAC', 'GSE2', 'SEISAN', 'SACXY', 'GSE1',
                            'Q', 'SH_ASC', 'SLIST', 'TSPAIR', 'Y', 'PICKLE',
                            'SEGY', 'SU', 'SEG2', 'WAV', 'WIN', 'CSS',
                            'NNSA_KB_CORE', 'AH', 'PDAS', 'KINEMETRICS_EVT',
                            'GCF']
EVENT_PREFERRED_ORDER = ['QUAKEML', 'NLLOC_HYP']
INVENTORY_PREFERRED_ORDER = ['STATIONXML', 'SEED', 'RESP']
# waveform plugins accepting a byteorder keyword
WAVEFORM_ACCEPT_BYTEORDER = ['MSEED', 'Q', 'SAC', 'SEGY', 'SU']

_sys_is_le = sys.byteorder == 'little'
# byte order character (as used by numpy dtypes / struct) of this machine:
# '<' for little endian, '>' for big endian
NATIVE_BYTEORDER = '<' if _sys_is_le else '>'


class NamedTemporaryFile(io.BufferedIOBase):
    """
    Weak replacement for Python's :func:`tempfile.NamedTemporaryFile`.

    This class is a replacement for :func:`tempfile.NamedTemporaryFile` but
    will work also with Windows 7/Vista's UAC. The file is created with
    :func:`tempfile.mkstemp` and is opened unbuffered for reading and writing
    in binary mode. When used as a context manager, the file is closed and
    deleted on exit.

    :type dir: str
    :param dir: If specified, the file will be created in that directory,
        otherwise the default directory for temporary files is used.
    :type suffix: str
    :param suffix: The temporary file name will end with that suffix. Defaults
        to ``'.tmp'``.
    :type prefix: str
    :param prefix: The temporary file name will start with that prefix.
        Defaults to ``'obspy-'``.

    .. rubric:: Example

    >>> with NamedTemporaryFile() as tf:
    ...     _ = tf.write(b"test")
    ...     os.path.exists(tf.name)
    True
    >>> # when using the with statement, the file is deleted at the end:
    >>> os.path.exists(tf.name)
    False

    >>> with NamedTemporaryFile() as tf:
    ...     filename = tf.name
    ...     with open(filename, 'wb') as fh:
    ...         _ = fh.write(b"just a test")
    ...     with open(filename, 'r') as fh:
    ...         print(fh.read())
    just a test
    >>> # when using the with statement, the file is deleted at the end:
    >>> os.path.exists(tf.name)
    False
    """
    def __init__(self, dir=None, suffix='.tmp', prefix='obspy-'):
        fd, self.name = tempfile.mkstemp(dir=dir, prefix=prefix, suffix=suffix)
        # 0 -> do not buffer, so data written through this object is on disk
        # right away for anybody reopening ``self.name``
        self._fileobj = os.fdopen(fd, 'w+b', 0)

    # The capability queries are delegated to the wrapped file: the IOBase
    # defaults report False, which makes wrappers such as io.TextIOWrapper
    # reject this object even though reading, writing and seeking work.
    def readable(self):
        return self._fileobj.readable()

    def writable(self):
        return self._fileobj.writable()

    def seekable(self):
        return self._fileobj.seekable()

    def read(self, *args, **kwargs):
        return self._fileobj.read(*args, **kwargs)

    def write(self, *args, **kwargs):
        return self._fileobj.write(*args, **kwargs)

    def seek(self, *args, **kwargs):
        self._fileobj.seek(*args, **kwargs)
        # return the new absolute position, as io.IOBase.seek() does
        return self._fileobj.tell()

    def tell(self, *args, **kwargs):
        return self._fileobj.tell(*args, **kwargs)

    def close(self, *args, **kwargs):
        super(NamedTemporaryFile, self).close(*args, **kwargs)
        self._fileobj.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):  # @UnusedVariable
        # close first so the file can be removed on Windows as well
        self.close()
        os.remove(self.name)


def create_empty_data_chunk(delta, dtype, fill_value=None):
    """
    Creates a NumPy array depending on the given data type and fill value.

    If no ``fill_value`` is given a masked array will be returned.

    :type delta: int
    :param delta: Number of samples for data chunk
    :param dtype: NumPy dtype for returned data chunk
    :param fill_value: If ``None``, masked array is returned, else the
        array is filled with the corresponding value. A list or tuple of
        exactly two values is interpreted as the samples bordering the chunk
        on the left and right side; the chunk is then linearly interpolated
        between them (the two bordering samples themselves are not
        included).

    .. rubric:: Example

    >>> create_empty_data_chunk(3, 'int', 10)
    array([10, 10, 10])

    >>> create_empty_data_chunk(
    ...     3, 'f')  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
    masked_array(data = [-- -- --],
                 mask = ...,
                 ...)
    """
    # For compatibility with NumPy 1.4
    if isinstance(dtype, str):
        dtype = native_str(dtype)
    dtype = np.dtype(dtype)
    if fill_value is None:
        temp = np.ma.masked_all(delta, dtype=dtype)
        # the values under the mask are filled with nan for floating point
        # types and with the smallest representable value for integer types
        if issubclass(temp.data.dtype.type, np.integer):
            temp.data[:] = np.iinfo(temp.data.dtype).min
        else:
            temp.data[:] = np.nan
    elif isinstance(fill_value, (list, tuple)) and len(fill_value) == 2:
        # if two values are supplied use these as samples bordering to our data
        # and interpolate between:
        left_sample, right_sample = fill_value
        # include left and right sample (delta + 2) ...
        interpolation = np.linspace(left_sample, right_sample, delta + 2)
        # ... then cut them off again and ensure correct data type
        temp = np.require(interpolation[1:-1], dtype=dtype)
    else:
        temp = np.ones(delta, dtype=dtype)
        temp *= fill_value
    return temp


def get_example_file(filename):
    """
    Return the absolute path of an example/test data file shipped with ObsPy.

    Since ObsPy may be installed anywhere, the location of its data files
    cannot be hard-coded. Instead every importable module listed in
    ``ALL_MODULES`` is searched, looking first in its ``tests/data/`` and then
    in its ``data/`` subdirectory. Modules that fail to import are skipped.

    :param filename: A test file name to which the path should be returned.
    :return: Full path to file.
    :raises OSError: If no module contains a file of that name.

    .. rubric:: Example

    >>> get_example_file('slist.ascii')  # doctest: +SKIP
    /custom/path/to/obspy/io/ascii/tests/data/slist.ascii

    >>> get_example_file('does.not.exists')  # doctest: +ELLIPSIS
    Traceback (most recent call last):
    ...
    OSError: Could not find file does.not.exists ...
    """
    search_subdirs = (("tests", "data"), ("data",))
    for module in ALL_MODULES:
        try:
            mod = __import__("obspy.%s" % module,
                             fromlist=[native_str("obspy")])
        except ImportError:
            continue
        package_dir = mod.__path__[0]
        for subdir in search_subdirs:
            candidate = os.path.join(package_dir, *(subdir + (filename,)))
            if os.path.isfile(candidate):
                return candidate
    raise OSError(("Could not find file %s in tests/data or data "
                   "directory of ObsPy modules") % filename)


def _get_entry_points(group, subgroup=None):
    """
    Collect the installed plug-ins registered for an entry point group.

    If ``subgroup`` is given, a plug-in is only included when the entry point
    group ``<group>.<plug-in name>`` contains an entry point named
    ``subgroup`` (e.g. ``'readFormat'``).

    :type group: str
    :param group: Group name.
    :type subgroup: str, optional
    :param subgroup: Subgroup name (defaults to None).
    :rtype: dict
    :returns: Dictionary mapping each plug-in name to its entry point. If
        several entry points share a name, the last one found wins.

    .. rubric:: Example

    >>> _get_entry_points('obspy.plugin.waveform')  # doctest: +ELLIPSIS
    {...'SLIST': EntryPoint.parse('SLIST = obspy.io.ascii.core')...}
    """
    def _provides_subgroup(entry_point):
        # True if at least one matching sub entry point is registered
        sub_entry_points = iter_entry_points(
            group + '.' + entry_point.name, subgroup)
        return any(True for _ in sub_entry_points)

    return {ep.name: ep for ep in iter_entry_points(group)
            if not subgroup or _provides_subgroup(ep)}


def _get_ordered_entry_points(group, subgroup=None, order_list=()):
    """
    Gets an ordered dictionary of all available plug-ins of a group or
    subgroup.

    :type group: str
    :param group: Group name.
    :type subgroup: str, optional
    :param subgroup: Subgroup name, see :func:`_get_entry_points`.
    :type order_list: list of str
    :param order_list: Plug-in names that should come first, in this order.
        Names of plug-ins that are not installed are ignored.
    :rtype: :class:`collections.OrderedDict`
    :returns: Entry points of the installed plug-ins named in ``order_list``
        (in that order), followed by all remaining plug-ins of the group.
    """
    # get all available entry points
    ep_dict = _get_entry_points(group, subgroup)
    # loop through officially supported plug-ins and add them to the ordered
    # dict of entry points
    entry_points = OrderedDict()
    for name in order_list:
        # skip plug-ins which are not installed (or were already added)
        if name in ep_dict:
            entry_points[name] = ep_dict.pop(name)
    # extend entry points with any left over plug-ins
    entry_points.update(ep_dict)
    return entry_points


# Registry of installed plug-ins, built once at import time: maps a short
# group key to a dict {plug-in name: EntryPoint}.
# The 'waveform', 'event' and 'inventory' keys only contain plug-ins that
# provide a 'readFormat' entry point; the '*_write' keys only those providing
# 'writeFormat'. 'waveform', 'waveform_write', 'event' and 'inventory' are
# ordered by the *_PREFERRED_ORDER lists, so format auto-detection in
# _read_from_plugin() probes the preferred formats first.
ENTRY_POINTS = {
    'trigger': _get_entry_points('obspy.plugin.trigger'),
    'filter': _get_entry_points('obspy.plugin.filter'),
    'rotate': _get_entry_points('obspy.plugin.rotate'),
    'detrend': _get_entry_points('obspy.plugin.detrend'),
    'interpolate': _get_entry_points('obspy.plugin.interpolate'),
    'integrate': _get_entry_points('obspy.plugin.integrate'),
    'differentiate': _get_entry_points('obspy.plugin.differentiate'),
    'waveform': _get_ordered_entry_points(
        'obspy.plugin.waveform', 'readFormat', WAVEFORM_PREFERRED_ORDER),
    'waveform_write': _get_ordered_entry_points(
        'obspy.plugin.waveform', 'writeFormat', WAVEFORM_PREFERRED_ORDER),
    'event': _get_ordered_entry_points('obspy.plugin.event', 'readFormat',
                                       EVENT_PREFERRED_ORDER),
    # NOTE(review): unlike 'waveform_write', the event/inventory write groups
    # are not ordered -- presumably because no detection order is needed
    'event_write': _get_entry_points('obspy.plugin.event', 'writeFormat'),
    'taper': _get_entry_points('obspy.plugin.taper'),
    'inventory': _get_ordered_entry_points(
        'obspy.plugin.inventory', 'readFormat', INVENTORY_PREFERRED_ORDER),
    'inventory_write': _get_entry_points(
        'obspy.plugin.inventory', 'writeFormat'),
}


def _get_function_from_entry_point(group, type):
    """
    Look up a plug-in by name in one ``ENTRY_POINTS`` group and load it.

    An exact (case-sensitive) match on the entry point name is tried first,
    then a case-insensitive one. The function registered for the matching
    entry point is loaded and returned.

    :type group: str
    :param group: Key of ``ENTRY_POINTS``, e.g. ``'detrend'``.
    :type type: str
    :param type: Name of the requested plug-in, e.g. ``'simple'``.
    :raises ImportError: If the group has no entry points at all.
    :raises ValueError: If the group has entry points but none matches
        ``type``.

    .. rubric:: Example

    >>> _get_function_from_entry_point(
    ...     'detrend', 'simple')  # doctest: +ELLIPSIS
    <function simple at 0x...>

    >>> _get_function_from_entry_point('detrend', 'XXX')  # doctest: +ELLIPSIS
    Traceback (most recent call last):
    ...
    ValueError: Detrend type "XXX" is not supported. Supported types: ...
    """
    available = ENTRY_POINTS[group]
    if type in available:
        entry_point = available[type]
    else:
        # fall back to a case-insensitive match
        candidates = [ep for name, ep in available.items()
                      if name.lower() == type.lower()]
        if not candidates:
            if not available:
                # nothing registered for this group at all
                msg = ("Your current ObsPy installation does not support "
                       "any %s functions. Please make sure "
                       "SciPy is installed properly.")
                raise ImportError(msg % group.capitalize())
            # entry points exist, but the requested one is not among them
            msg = "%s type \"%s\" is not supported. Supported types: %s"
            raise ValueError(
                msg % (group.capitalize(), type, ', '.join(available)))
        entry_point = candidates[0]
    # Problems while importing the plug-in are deliberately not caught, so
    # the user gets to see (and fix) the underlying error.
    return buffered_load_entry_point(entry_point.dist.key,
                                     'obspy.plugin.%s' % group,
                                     entry_point.name)


def get_dependency_version(package_name, raw_string=False):
    """
    Get version information of a dependency package.

    :type package_name: str
    :param package_name: Name of the distribution (as known to
        ``pkg_resources``) to return version info for.
    :type raw_string: bool
    :param raw_string: If ``True``, return the raw version string instead of
        a parsed list of integers.
    :returns: Package version as a list of (at least) three integers, e.g.
        ``[3, 1, 0]`` for version ``'3.1'``. With option ``raw_string=True``
        the raw version string is returned instead. If the package is not
        installed an empty list is returned (in both cases).
        A release candidate suffix (``'rc...'``) is dropped before parsing.
        The last version number can indicate different things like it being a
        version from the old svn trunk, the latest git repo, some release
        candidate version, ...
        If a number cannot be converted to an integer it will be set to 0.
    """
    try:
        version_string = pkg_resources.get_distribution(package_name).version
    except pkg_resources.DistributionNotFound:
        return []
    if raw_string:
        return version_string
    # cut off a release candidate suffix ("1.2.0rc1" -> "1.2.0") and strip
    # leading/trailing "~"
    version_string = version_string.split("rc")[0].strip("~")
    version_list = [to_int_or_zero(no) for no in version_string.split(".")]
    # short versions like "3.1" are padded to three numbers (as in
    # get_proj_version) so comparisons like ``>= [3, 1, 0]`` behave
    while len(version_list) < 3:
        version_list.append(0)
    return version_list


def get_proj_version(raw_string=False):
    """
    Get the version number for proj4 as a list.

    proj4 >= 5 does not play nicely for pseudocyl projections
    (see basemap issue 433).  Checking this will allow us to raise a warning
    when plotting using basemap.

    :type raw_string: bool
    :param raw_string: If ``True``, return the raw version string instead of
        a parsed list of integers.
    :returns: Package version as a list of (at least) three integers. Empty
        list if pyproj is not installed (also with ``raw_string=True``).
        With option ``raw_string=True`` returns raw version string instead.
        The last version number can indicate different things like it being a
        version from the old svn trunk, the latest git repo, some release
        candidate version, ...
        If a number cannot be converted to an integer it will be set to 0.
    """
    try:
        import pyproj
    except ImportError:
        return []

    # proj4 is a C library, pyproj wraps it. Newer pyproj versions expose the
    # PROJ version string at module level, while ``Proj.proj_version`` is
    # deprecated/removed there, so prefer the module attribute.
    version_string = getattr(pyproj, 'proj_version_str', None)
    if version_string is None:
        # Old pyproj: proj_version is an attribute of the Proj class that is
        # only set when the projection is made. Make a dummy projection and
        # get the version.
        version_string = pyproj.Proj(proj='utm', zone=10,
                                     ellps='WGS84').proj_version
    version_string = str(version_string)
    if raw_string:
        return version_string
    version_list = [to_int_or_zero(no) for no in version_string.split(".")]
    # For version 5.2.0 the version number is given as 5.2
    while len(version_list) < 3:
        version_list.append(0)
    return version_list


# Versions of (optional) dependencies, determined once at import time. Each
# is a list of integers or an empty list if the package is not installed.
NUMPY_VERSION = get_dependency_version('numpy')
SCIPY_VERSION = get_dependency_version('scipy')
MATPLOTLIB_VERSION = get_dependency_version('matplotlib')
BASEMAP_VERSION = get_dependency_version('basemap')
# version of the PROJ library wrapped by pyproj (not of pyproj itself)
PROJ4_VERSION = get_proj_version()
CARTOPY_VERSION = get_dependency_version('cartopy')


if PY2:
    # Python 2 has no FileNotFoundError; alias it to IOError so that
    # _read_from_plugin() can raise FileNotFoundError on both versions.
    # NOTE(review): the getattr() lookup is presumably there to keep linters
    # quiet; a plain ``IOError`` should be equivalent.
    FileNotFoundError = getattr(builtins, 'IOError')


def _read_from_plugin(plugin_type, filename, format=None, **kwargs):
    """
    Reads a single file from a plug-in's readFormat function.

    :type plugin_type: str
    :param plugin_type: Key of ``ENTRY_POINTS`` selecting the plug-in group,
        e.g. ``'waveform'``, ``'event'`` or ``'inventory'``.
    :param filename: Name of the file to read or a file-like object.
    :type format: str, optional
    :param format: Name of the format (case-insensitive). If not given, the
        format is detected by calling the ``isFormat`` function of each entry
        point of the group in order until one returns a true value.
    :param kwargs: Passed on to the plug-in's ``readFormat`` function.
    :returns: Tuple of whatever ``readFormat`` returns and the name of the
        format that was used.
    :raises FileNotFoundError: If ``filename`` is a string naming a
        non-existing path.
    :raises TypeError: If the format cannot be detected, is not supported or
        its ``readFormat`` function cannot be imported.
    """
    if isinstance(filename, (str, native_str)):
        if not os.path.exists(filename):
            msg = "[Errno 2] No such file or directory: '{}'".format(
                filename)
            raise FileNotFoundError(msg)
    eps = ENTRY_POINTS[plugin_type]
    if not format:
        # If it is a file-like object, store the position and restore it
        # after every isFormat() probe to avoid that the probes move the file
        # pointer for the next probe and for the actual read.
        if hasattr(filename, "tell") and hasattr(filename, "seek"):
            position = filename.tell()
        else:
            position = None
        # auto detect format - go through all known formats in given sort order
        for format_ep in eps.values():
            # search isFormat for given entry point
            is_format = buffered_load_entry_point(
                format_ep.dist.key,
                'obspy.plugin.%s.%s' % (plugin_type, format_ep.name),
                'isFormat')
            try:
                matches = is_format(filename)
            finally:
                # restore the stored position (not simply the start of the
                # stream), also if the probe raised
                if position is not None:
                    filename.seek(position, 0)
            if matches:
                break
        else:
            raise TypeError('Unknown format for file %s' % filename)
    else:
        # format given via argument
        format = format.upper()
        try:
            format_ep = eps[format]
        except KeyError:
            msg = "Format \"%s\" is not supported. Supported types: %s"
            raise TypeError(msg % (format, ', '.join(eps)))
    # file format should be known by now
    try:
        # search readFormat for given entry point
        read_format = buffered_load_entry_point(
            format_ep.dist.key,
            'obspy.plugin.%s.%s' % (plugin_type, format_ep.name),
            'readFormat')
    except ImportError:
        msg = "Format \"%s\" is not supported. Supported types: %s"
        raise TypeError(msg % (format_ep.name, ', '.join(eps)))
    # read
    list_obj = read_format(filename, **kwargs)
    return list_obj, format_ep.name


def get_script_dir_name():
    """
    Return the absolute path of the directory containing this module.

    The path is derived from the currently executing frame (that of this very
    function) rather than from ``__file__``, which is more robust.
    """
    source_file = inspect.getfile(inspect.currentframe())
    source_dir = os.path.dirname(source_file)
    return os.path.abspath(source_dir)


def make_format_plugin_table(group="waveform", method="read", numspaces=4,
                             unindent_first_line=True):
    """
    Returns a reStructuredText formatted (simple) table listing the plug-ins
    of ``group`` that support ``method``, to insert in docstrings.

    Each row holds the format name, the plug-in module (shortened to its
    first three name components) and the function implementing ``method``.
    Rows are sorted by format name.

    >>> table = make_format_plugin_table("event", "write", 4, True)
    >>> print(table)  # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
    ======... ===========... ========================================...
    Format    Used Module    _`Linked Function Call`
    ======... ===========... ========================================...
    CMTSOLUTION  :mod:`...io.cmtsolution` :func:`..._write_cmtsolution`
    CNV       :mod:`...io.cnv`   :func:`obspy.io.cnv.core._write_cnv`
    JSON      :mod:`...io.json`  :func:`obspy.io.json.core._write_json`
    KML       :mod:`obspy.io.kml` :func:`obspy.io.kml.core._write_kml`
    NLLOC_OBS :mod:`...io.nlloc` :func:`obspy.io.nlloc.core.write_nlloc_obs`
    NORDIC    :mod:`obspy.io.nordic` :func:`obspy.io.nordic.core.write_select`
    QUAKEML :mod:`...io.quakeml` :func:`obspy.io.quakeml.core._write_quakeml`
    SC3ML   :mod:`...io.seiscomp` :func:`obspy.io.seiscomp.event._write_sc3ml`
    SCARDEC   :mod:`obspy.io.scardec`
                             :func:`obspy.io.scardec.core._write_scardec`
    SHAPEFILE :mod:`obspy.io.shapefile`
                             :func:`obspy.io.shapefile.core._write_shapefile`
    ZMAP      :mod:`...io.zmap`  :func:`obspy.io.zmap.core._write_zmap`
    ======... ===========... ========================================...

    :type group: str
    :param group: Plugin group to search (e.g. "waveform" or "event").
    :type method: str
    :param method: Either 'read' or 'write' to select plugins based on either
        read or write capability.
    :type numspaces: int
    :param numspaces: Number of spaces prepended to each line (for indentation
        in docstrings).
    :type unindent_first_line: bool
    :param unindent_first_line: Determines if first line should start with
        prepended spaces or not.
    """
    method = method.lower()
    if method not in ("read", "write"):
        raise ValueError("no valid type: %s" % method)
    method = "%sFormat" % method
    eps = _get_ordered_entry_points("obspy.plugin.%s" % group, method,
                                    WAVEFORM_PREFERRED_ORDER)

    def _make_row(name, ep):
        # (format name, ":mod:" role of the module, ":func:" role of function)
        module_path = ep.module_name
        short_module = ".".join(module_path.split(".")[:3])
        func = buffered_load_entry_point(
            ep.dist.key, "obspy.plugin.%s.%s" % (group, name), method)
        return (name,
                ":mod:`%s`" % short_module,
                ':func:`%s`' % ".".join((module_path, func.__name__)))

    rows = sorted([_make_row(name, ep) for name, ep in eps.items()])
    headers = ["Format", "Used Module", "_`Linked Function Call`"]
    widths = [max([len(row[col]) for row in rows] + [len(header)])
              for col, header in enumerate(headers)]

    def _format_line(cells):
        # left-justify every cell to its column width
        return " ".join([cell.ljust(width)
                         for cell, width in zip(cells, widths)])

    rule = " ".join(["=" * width for width in widths])
    lines = [rule, _format_line(headers), rule]
    lines.extend(_format_line(row) for row in rows)
    lines.append(rule)

    indent = " " * numspaces
    ret = indent + ("\n" + indent).join(lines)
    if unindent_first_line:
        ret = ret[numspaces:]
    return ret


class ComparingObject(object):
    """
    Simple base class that implements ``==`` and ``!=`` based on
    ``self.__dict__``.

    ``a == b`` holds if ``b`` is an instance of ``a``'s class (or of a
    subclass of it) and both instance dictionaries compare equal. As
    ``__eq__`` is defined here, instances are unhashable on Python 3 unless a
    subclass defines ``__hash__``.
    """
    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False
        return self.__dict__ == other.__dict__

    def __ne__(self, other):
        # call __eq__ directly (not ``==``) so both always agree
        return not self.__eq__(other)


def _get_deprecated_argument_action(old_name, new_name, real_action='store'):
    """
    Specifies deprecated command-line arguments to scripts
    """
    message = '%s has been deprecated. Please use %s in the future.' % (
        old_name, new_name
    )

    from argparse import Action

    class _Action(Action):
        def __call__(self, parser, namespace, values, option_string=None):
            import warnings
            warnings.warn(message)

            # I wish there were an easier way...
            if real_action == 'store':
                setattr(namespace, self.dest, values)
            elif real_action == 'store_true':
                setattr(namespace, self.dest, True)
            elif real_action == 'store_false':
                setattr(namespace, self.dest, False)

    return _Action


def sanitize_filename(filename):
    """
    Turn ``filename`` into a safe, ASCII-only file name.

    Adapted from Django's slugify functions. Accented characters are reduced
    to their ASCII base letter, every character that is not a word character,
    dot, whitespace or hyphen is dropped, and runs of whitespace/hyphens are
    collapsed into a single hyphen.

    :param filename: The filename. ``bytes`` are decoded with the default
        encoding of ``bytes.decode()`` first.
    :returns: The sanitized filename.
    """
    try:
        filename = filename.decode()
    except AttributeError:
        # already text
        pass

    decomposed = unicodedata.normalize('NFKD', filename)
    ascii_name = decomposed.encode('ascii', 'ignore').decode('ascii')
    # In contrast to Django, dots are kept and the case is left untouched.
    kept = re.sub(r'[^\w\.\s-]', '', ascii_name)
    return re.sub(r'[-\s]+', '-', kept.strip())


def download_to_file(url, filename_or_buffer, chunk_size=1024):
    """
    Helper function to download a potentially large file.

    The response body is read in chunks (streamed, if the installed
    ``requests`` supports ``stream=True``) and written to
    ``filename_or_buffer``. The response is always closed afterwards, so the
    underlying connection is released even if the download fails.

    :param url: The URL to GET the data from.
    :type url: str
    :param filename_or_buffer: The filename_or_buffer or file-like object to
        download to.
    :type filename_or_buffer: str or file-like object
    :param chunk_size: The chunk size in bytes.
    :type chunk_size: int
    :raises requests.HTTPError: If the server answers with any status code
        other than 200.
    """
    # Workaround for old request versions.
    try:
        r = requests.get(url, stream=True)
    except TypeError:
        r = requests.get(url)

    def _copy_chunks(fh):
        # write the body chunk by chunk, skipping empty chunks
        for chunk in r.iter_content(chunk_size=chunk_size):
            if chunk:
                fh.write(chunk)

    try:
        # Raise anything except for 200
        if r.status_code != 200:
            raise requests.HTTPError('%s HTTP Error: %s for url: %s'
                                     % (r.status_code, r.reason, url))

        if hasattr(filename_or_buffer, "write"):
            _copy_chunks(filename_or_buffer)
        else:
            with io.open(filename_or_buffer, "wb") as fh:
                _copy_chunks(fh)
    finally:
        # Release the (possibly still open, streamed) connection; guarded
        # for very old requests versions whose responses lack close().
        close = getattr(r, "close", None)
        if close is not None:
            close()


if __name__ == '__main__':
    # Running this module directly executes its doctests; with
    # exclude_empty=True objects without any doctests are not reported.
    doctest.testmod(exclude_empty=True)
