#! python
"""
Copyright 2016 Oliver Schoenborn. BSD 3-Clause license (see __license__ at bottom of this file for details).

This script transforms nose.tools.assert_* function calls into raw assert statements, while preserving format
of original arguments as much as possible. A small subset of nose.tools.assert_* function calls is not
transformed because there is no raw assert statement equivalent. However, if you don't use those functions
in your code, you will be able to remove nose as a test dependency of your library/app.

Requires Python 3.4.

This script relies heavily on lib2to3, using it to find patterns of code to transform and convert transformed
code nodes back into Python source code. The following article was very useful:
http://python3porting.com/fixers.html#find-pattern.
"""

__version__ = "1.0.0"

import argparse
import logging

from lib2to3 import refactor, fixer_base, pygram, pytree, pgen2
from lib2to3.pytree import Node as PyNode, Leaf as PyLeaf
from lib2to3.pgen2 import token
from lib2to3.fixer_util import parenthesize


log = logging.getLogger('nose2pytest')


def override_required(func: callable):
    """
    Decorator used to document that the decorated function must be overridden in derived class.

    This is purely a marker for readers of the code: the decorated function is returned unchanged and
    nothing is checked or enforced at runtime.

    :param func: the function being marked
    :return: func itself
    """
    return func


def override_optional(func: callable):
    """
    Decorator used to document that the decorated function can be overridden in derived class, but need not be.

    This is purely a marker for readers of the code: the decorated function is returned unchanged and
    nothing is checked or enforced at runtime.

    :param func: the function being marked
    :return: func itself
    """
    return func


def override(BaseClass):
    """
    Decorator factory used to document that the decorated function overrides the function of same name
    in BaseClass.

    BaseClass is accepted only so that it appears at the decoration site; it is never inspected, and the
    decorated function is handed back untouched.

    :param BaseClass: the class whose method is being overridden (documentation only)
    :return: a decorator that returns the function it is given
    """

    def passthrough(method):
        # nothing to wrap or verify: the marker has no runtime effect
        return method

    return passthrough


# Transformations:

# Grammar and parser driver shared by the fixers below: FixAssertBase.__init__ uses driver.parse_string()
# to turn each "assert ..." template into a lib2to3 tree. Messages from the driver go to this module's logger.
grammar = pygram.python_grammar
driver = pgen2.driver.Driver(grammar, convert=pytree.convert, logger=log)


# lib2to3 patterns matching a call to a function literally named 'func' with one or two arguments; the
# matched arguments are bound to obj1 (and obj2).
# NOTE(review): none of these four is referenced by the code visible in this file - confirm before removing.
PATTERN_ONE_ARG_OR_KWARG =   """power< 'func' trailer< '(' not(arglist) obj1=any                         ')' > >"""
PATTERN_ONE_ARG =            """power< 'func' trailer< '(' not(arglist | argument<any '=' any>) obj1=any ')' > >"""
PATTERN_ONE_KWARG =          """power< 'func' trailer< '(' obj1=argument< any '=' any >                  ')' > >"""
PATTERN_TWO_ARGS_OR_KWARGS = """power< 'func' trailer< '(' arglist< obj1=any ',' obj2=any >              ')' > >"""

# The three patterns below are templates: the '{}' placeholder is filled with the nose function name by
# FixAssertBase.__init__ (via str.format) before the pattern is handed to lib2to3.

# Call with a single non-keyword argument (bound to 'test'), or with exactly two arguments (bound to
# 'test' and 'msg'). Used by FixAssert1Arg.
PATTERN_1_OR_2_ARGS = """
    power< '{}' trailer< '('
        ( not(arglist | argument<any '=' any>) test=any
        | arglist< test=any ',' msg=any > )
    ')' > >
    """

# Call with two arguments (bound to 'lhs' and 'rhs', optional trailing comma) or three arguments (the third
# bound to 'msg'). Used by FixAssert2Args.
PATTERN_2_OR_3_ARGS = """
    power< '{}' trailer< '('
        ( arglist< lhs=any ',' rhs=any [','] >
        | arglist< lhs=any ',' rhs=any ',' msg=any > )
    ')' > >
    """

# Call with three arguments (bound to 'aaa', 'bbb' and 'delta', optional trailing comma) or four arguments
# (the fourth bound to 'msg'). Used by FixAssertAlmostEq.
PATTERN_ALMOST_ARGS = """
    power< '{}' trailer< '('
        ( arglist< aaa=any ',' bbb=any ',' delta=any [','] >
        | arglist< aaa=any ',' bbb=any ',' delta=any ',' msg=any > )
    ')' > >
    """

# for the following node types, contains_newline() will return False even if newlines are between ()[]{}
NEWLINE_OK_TOKENS = (token.LPAR, token.LSQB, token.LBRACE)

# these operators require parens around function arg if binop is ==, !=, ...
COMPARISON_TOKENS = (token.EQEQUAL, token.NOTEQUAL, token.LESS, token.LESSEQUAL, token.GREATER, token.GREATEREQUAL)

# Each triplet below is (type of the parent node, type of the operator child, source text of the operator),
# which is exactly what has_weak_op_for_comparison() builds for every child it visits. The symbol numbers are
# generated by the grammar driver and differ between Python versions, so they are looked up by name in
# pygram.python_symbols instead of being hard-coded (on Python 3.4 this yields 271, 270, 302, 258 and 305).
MEMBERSHIP_SYMBOLS = (
    (pygram.python_symbols.comparison, token.NAME, 'in'),
    (pygram.python_symbols.comparison, pygram.python_symbols.comp_op, 'not in'),
)
IDENTITY_SYMBOLS = (
    (pygram.python_symbols.comparison, token.NAME, 'is'),
    (pygram.python_symbols.comparison, pygram.python_symbols.comp_op, 'is not'),
)
BOOLEAN_OPS = (
    (pygram.python_symbols.not_test, token.NAME, 'not'),
    (pygram.python_symbols.and_test, token.NAME, 'and'),
    (pygram.python_symbols.or_test, token.NAME, 'or'),
)

# these operators require parens around function arg if binop is + or -
ADD_SUB_GROUP_TOKENS = (
    token.PLUS, token.MINUS,
    token.RIGHTSHIFT, token.LEFTSHIFT,
    token.VBAR, token.AMPER, token.CIRCUMFLEX,
)


def contains_newline(node: PyNode) -> bool:
    """
    Tell whether node spans several lines outside of any brackets.

    The children of node are scanned in order. As soon as an opening bracket token (one of
    NEWLINE_OK_TOKENS) is met the scan stops with False, since line breaks after it are legal without extra
    grouping. Before that point, a child whose prefix holds a newline character, or a child node for which
    this function itself returns True, makes the result True. Example: the node for 'a' + newline + '  in b'
    gives True, whereas the node for '(a' + newline + '   b)' gives False.

    :param node: the node whose children are inspected
    :return: True if an unbracketed line break was found, False otherwise
    """
    for kid in node.children:
        if kid.type in NEWLINE_OK_TOKENS:
            # opening bracket: whatever follows may wrap freely, stop looking
            return False

        breaks_here = '\n' in kid.prefix
        if breaks_here or (isinstance(kid, PyNode) and contains_newline(kid)):
            return True

    return False


def wrap_parens(arg_node: PyNode, checker_fn: callable) -> PyNode or PyLeaf:
    """
    Put parentheses around an argument of an assert_ function when checker_fn says grouping is needed.

    Leaves are never wrapped, and checker_fn is not even called for them. When wrapping happens, the
    whitespace/comment prefix of arg_node is moved onto the new parenthesized node, and if arg_node had a
    parent the new node takes its place among that parent's children.

    :param arg_node: the arg_node to parenthesize
    :param checker_fn: called with arg_node; a true result means parentheses must be added
    :return: the arg_node for the parenthesized expression, or the arg_node itself
    """
    if not (isinstance(arg_node, PyNode) and checker_fn(arg_node)):
        return arg_node

    owner = arg_node.parent
    if owner is None:
        wrapped = parenthesize(arg_node)
    else:
        # arg_node must be detached before being given to parenthesize(); the wrapped node is then
        # re-inserted at the position arg_node used to occupy
        position = arg_node.remove()
        wrapped = parenthesize(arg_node)
        owner.insert_child(position, wrapped)

    # the prefix belongs in front of the opening paren, not between the paren and the expression
    wrapped.prefix = arg_node.prefix
    arg_node.prefix = ''
    return wrapped


def is_if_else_op(node: PyNode) -> bool:
    """
    Tell whether node is a conditional expression, i.e. has the shape "<x> if <y> else <z>".

    :param node: the node to inspect
    :return: True if node has exactly five children with 'if' in second and 'else' in fourth position
    """
    parts = node.children
    if len(parts) != 5:
        return False

    keyword_if = PyLeaf(token.NAME, 'if')
    keyword_else = PyLeaf(token.NAME, 'else')
    return parts[1] == keyword_if and parts[3] == keyword_else


def has_weak_op_for_comparison(node: PyNode) -> bool:
    """
    Test if node contains operators that are weaker than (or as weak as) comparison operators.

    A conditional expression always qualifies. Otherwise the children are scanned in order and searched
    recursively; the scan gives up with False at the first opening bracket token, because anything after it
    is already grouped.

    :param node: the node to inspect
    :return: True if a comparison, membership, identity or boolean operator was found outside brackets
    """
    if is_if_else_op(node):
        return True

    operator_tables = (BOOLEAN_OPS, MEMBERSHIP_SYMBOLS, IDENTITY_SYMBOLS)

    for child in node.children:
        child_type = child.type
        if child_type in NEWLINE_OK_TOKENS:
            return False

        # plain comparison tokens (==, !=, <, ...) are recognisable by type alone
        if child_type in COMPARISON_TOKENS:
            return True

        # keyword operators (in, not in, is, is not, not, and, or) are identified by the triplet
        # (type of parent node, type of child, text of child)
        signature = (node.type, child_type, str(child).strip())
        if any(signature in table for table in operator_tables):
            return True

        # leaves have nothing more to offer; nodes are searched in depth
        if isinstance(child, PyNode) and has_weak_op_for_comparison(child):
            return True

    return False


def wrap_parens_for_comparison(arg_node: PyNode or PyLeaf) -> PyNode or PyLeaf:
    """
    Assuming arg_node represents an argument to an assert_ function that uses comparison operators, then if
    arg_node has any operators that have equal or weaker precedence than those operators (including
    membership and identity test operators), return a new node that adds parentheses around arg_node.
    Otherwise, return arg_node.

    The decision is made by has_weak_op_for_comparison() and the wrapping itself by wrap_parens().

    :param arg_node: the arg_node to parenthesize
    :return: the arg_node for the parenthesized expression, or the arg_node itself
    """
    return wrap_parens(arg_node, has_weak_op_for_comparison)


def has_weak_op_for_addsub(node: PyNode, check_comparison: bool=True) -> bool:
    """
    Test if node contains operators that are as weak as, or weaker than, the + and - operators.

    The children are scanned in order and searched recursively; the scan gives up with False at the first
    opening bracket token, because anything after it is already grouped.

    :param node: the node to inspect
    :param check_comparison: if True, first apply has_weak_op_for_comparison() to node; the recursive calls
        made here pass False since that check already searches the whole node
    :return: True if such an operator was found outside brackets
    """
    if check_comparison and has_weak_op_for_comparison(node):
        return True

    for child in node.children:
        child_type = child.type
        if child_type in NEWLINE_OK_TOKENS:
            return False

        is_weak_token = child_type in ADD_SUB_GROUP_TOKENS
        if is_weak_token or (isinstance(child, PyNode) and has_weak_op_for_addsub(child, check_comparison=False)):
            return True

    return False


def wrap_parens_for_addsub(arg_node: PyNode or PyLeaf) -> PyNode or PyLeaf:
    """
    Assuming arg_node represents an argument to an assert_ function that uses + or - operators, then if
    arg_node has any operators that have equal or weaker precedence than those operators, return a new node
    that adds parentheses around arg_node. Otherwise, return arg_node.

    The decision is made by has_weak_op_for_addsub() and the wrapping itself by wrap_parens().

    :param arg_node: the arg_node to parenthesize
    :return: the arg_node for the parenthesized expression, or the arg_node itself
    """
    return wrap_parens(arg_node, has_weak_op_for_addsub)


def get_prev_sibling(node: PyNode) -> PyNode:
    """
    Find what immediately precedes node in the tree: its previous sibling or, when it is the first child,
    the previous sibling of its closest ancestor that has one.

    :param node: the node (or leaf) to start from; may be None
    :return: the preceding node or leaf, or None if neither node nor any of its ancestors has one
    """
    current = node
    while current is not None:
        sibling = current.prev_sibling
        if sibling is not None:
            return sibling
        # first child of its parent: climb one level and try again
        current = current.parent

    return None  # could not find


def adjust_prefix_first_arg(node: PyNode or PyLeaf, orig_prefix: str):
    """
    Set the prefix of a node that has just been placed in the assertion expression, based on what precedes it.

    When the preceding token (see get_prev_sibling) is a name, the node keeps orig_prefix, or gets a single
    space if orig_prefix is empty; after anything else the prefix is cleared.

    :param node: the node whose prefix is set; something must precede it in the tree
    :param orig_prefix: the prefix the argument had in the original function call
    """
    preceding = get_prev_sibling(node)
    follows_name = preceding.type == token.NAME
    node.prefix = (orig_prefix or " ") if follows_name else ''


class FixAssertBase(fixer_base.BaseFix):
    """
    Base class of the fixers that replace a call to one nose assertion function by a raw assert statement.

    One instance handles one nose function. The instance parses its "assert ..." template once, and for every
    matching call clones that template, lets the derived class drop the call's arguments into it, and
    appends the optional assertion message.
    """
    # BM_compatible = True

    # Each derived class should define a dictionary where the key is the name of the nose function to convert,
    # and the value is a pair where the first item is the assertion statement expression, and the second item
    # is data that will be available in _transform_dest() override as self._conv_data.
    conversions = None

    @classmethod
    def create_all(cls, *args, **kwargs) -> [fixer_base.BaseFix]:
        """
        Create an instance for each key in cls.conversions, assumed to be defined by derived class.
        The *args and **kwargs are those of BaseFix.
        :return: list of instances created
        """
        return [cls(func_name, *args, **kwargs) for func_name in cls.conversions]

    def __init__(self, nose_func_name: str, *args, **kwargs):
        """
        :param nose_func_name: name of the nose function handled by this instance; a key of self.conversions
        :param args: positional arguments forwarded to BaseFix
        :param kwargs: keyword arguments forwarded to BaseFix
        """
        assertion_expr, self._conv_data = self.conversions[nose_func_name]
        self.nose_func_name = nose_func_name

        # the class-level PATTERN is a template: give this instance its own copy with the function name
        # filled in, and do so before the base class initializer runs
        self.PATTERN = self.PATTERN.format(nose_func_name)
        log.info('%s will convert %s as "assert %s"', type(self).__name__, nose_func_name, assertion_expr)
        super().__init__(*args, **kwargs)

        self.dest_tree = driver.parse_string('assert ' + assertion_expr + '\n')
        # remove the \n we added
        statement_line = self.dest_tree.children[0]
        del statement_line.children[1]

    @override(fixer_base.BaseFix)
    def transform(self, node: PyNode, results: {str: PyNode}) -> PyNode:
        """
        Build the assert statement that replaces a matched nose function call.

        :param node: the node of the matched function call
        :param results: the results of pattern matching
        :return: the new tree for the assert statement, or node itself when the derived class declined
        """
        assert results
        dest_tree = self.dest_tree.clone()
        test_path = (0, 0, 1)  # from the tree root to the expression that follows the 'assert' keyword
        test_node = self._get_node(dest_tree, test_path)
        assert_args = test_node.parent

        if not self._transform_dest(test_node, results):
            return node

        # the derived class may have replaced the expression node, so fetch it again
        test_node = self._get_node(dest_tree, test_path)
        if contains_newline(test_node):
            # an expression spanning several lines must be parenthesized: whatever preceded the first line
            # break goes in front of the opening paren, the rest stays with the expression
            leading, line_break, remainder = test_node.prefix.partition('\n')
            test_node.prefix = line_break + remainder
            grouped = parenthesize(test_node.clone())
            grouped.prefix = leading or ' '
            test_node.replace(grouped)

        self.__handle_opt_msg(assert_args, results)

        dest_tree.prefix = node.prefix
        return dest_tree

    @override_required
    def _transform_dest(self, assert_arg_test_node: PyNode, results: {str: PyNode}) -> bool:
        """
        Transform the given node to use the results.
        :param assert_arg_test_node: the destination node representing the assertion test argument
        :param results: the results of pattern matching
        :return: a true value if the transformation was done; this base implementation returns None, which
            makes transform() leave the function call unchanged
        """
        pass

    def _get_node(self, from_node, indices_path: None or int or [int]) -> PyLeaf or PyNode:
        """
        Get a node relative to another node.
        :param from_node: the node from which to start
        :param indices_path: the path through children
        :return: node found (could be leaf); if indices_path is None, this is from_node itself; if it is a
            number, return from_node.children[indices_path]; else returns according to sequence of children
            indices

        Example: if indices_path is (1, 2, 3), will return from_node.children[1].children[2].children[3].
        """
        if indices_path is None:
            return from_node

        try:
            target = from_node
            for child_index in indices_path:
                target = target.children[child_index]
        except TypeError:
            # not a sequence: indices_path is a single child index
            return from_node.children[indices_path]

        return target

    def __handle_opt_msg(self, assertion_args_node: PyNode, results: {str: PyNode}):
        """
        Append a message argument to assertion args node, if one appears in results.
        :param assertion_args_node: the node representing all the arguments of assertion function
        :param results: results from pattern matching
        """
        if 'msg' not in results:
            return

        msg = results["msg"]
        parts = msg.children
        if len(parts) > 1:
            # the message text might have been passed by name, extract the text:
            if parts[0] == PyLeaf(token.NAME, 'msg') and parts[1] == PyLeaf(token.EQUAL, '='):
                msg = parts[2]

        msg = msg.clone()
        msg.prefix = ' '
        assertion_args_node.children.extend([PyLeaf(token.COMMA, ','), msg])


class FixAssert1Arg(FixAssertBase):
    """
    Fixer class for any 1-argument assertion function (assert_func(a)). It supports optional 2nd arg for the
    assertion message, ie assert_func(a, msg) -> assert a binop something, msg.

    When the assertion expression applies an operator to the argument ('not a', 'a is None', ...), the
    argument is parenthesized if it contains an operator of weaker precedence, so that for example
    assert_false(a or b) becomes "assert not (a or b)" rather than "assert not a or b".
    """

    PATTERN = PATTERN_1_OR_2_ARGS

    # the conv data is a node children indices path from the PyNode that represents the assertion expression.
    # Example: assert_false(a) becomes "assert not a", so the PyNode for assertion expression is 'not a', and
    # the 'a' is its children[1] so self._conv_data needs to be 1. None means that the argument is the whole
    # assertion expression.
    conversions = dict(
        assert_true=('a', None),
        assert_false=('not a', 1),
        assert_is_none=('a is None', 0),
        assert_is_not_none=('a is not None', 0),
    )

    @staticmethod
    def _has_weak_op_for_not(node: PyNode) -> bool:
        """Test if node is an expression that binds less tightly than the 'not' operator."""
        syms = pygram.python_symbols
        return is_if_else_op(node) or node.type in (syms.or_test, syms.and_test, syms.lambdef)

    @override(FixAssertBase)
    def _transform_dest(self, assert_arg_test_node: PyNode, results: {str: PyNode}) -> bool:
        """
        Put the single argument of the nose function call into the assertion expression.
        :param assert_arg_test_node: the destination node representing the assertion test argument
        :param results: the results of pattern matching; the argument is under "test"
        :return: always True
        """
        test = results["test"].clone()

        if self._conv_data is not None:
            # the argument becomes the operand of an operator, so it may need grouping; 'not' binds less
            # tightly than comparisons, hence the two different checks
            if assert_arg_test_node.type == pygram.python_symbols.not_test:
                test = wrap_parens(test, self._has_weak_op_for_not)
            else:
                test = wrap_parens_for_comparison(test)

        test.prefix = " "

        # the destination node for 'a' is in conv_data:
        dest_node = self._get_node(assert_arg_test_node, self._conv_data)
        dest_node.replace(test)

        return True


class FixAssert2Args(FixAssertBase):
    """
    Fixer class for any 2-argument assertion function (assert_func(a, b)). It supports optional third arg
    as the assertion message, ie assert_func(a, b, msg) -> assert a binop b, msg.
    """

    PATTERN = PATTERN_2_OR_3_ARGS

    # The conversion data (2nd item of the value; see base class docs) holds two "node paths", both relative
    # to the assertion expression: the first leads to "a", the second to "b". An optional third item tells
    # whether the arguments become operands of a comparison operator (assumed True when absent).
    #
    # Example 1: assert_equal(a, b) will convert to "assert a == b" so the PyNode for assertion expression
    # is 'a == b' and a is that node's children[0], whereas b is that node's children[2], so the conversion
    # data is simply (0, 2).
    #
    # Example 2: assert_is_instance(a, b) converts to "assert isinstance(a, b)" so the conversion data starts
    # with the node paths (1, 1, 0) and (1, 1, 2): from the PyNode for the assertion expression
    # "isinstance(a, b)", 'a' is that node's children[1].children[1].children[0], whereas 'b' is that node's
    # children[1].children[1].children[2] (index 1 of the argument list being the comma). The third item is
    # False because function call arguments never need extra parentheses.
    conversions = {
        'assert_equal': ('a == b', (0, 2)),
        'assert_equals': ('a == b', (0, 2)),
        'assert_not_equal': ('a != b', (0, 2)),
        'assert_not_equals': ('a != b', (0, 2)),

        'assert_list_equal': ('a == b', (0, 2)),
        'assert_dict_equal': ('a == b', (0, 2)),
        'assert_set_equal': ('a == b', (0, 2)),
        'assert_sequence_equal': ('a == b', (0, 2)),
        'assert_tuple_equal': ('a == b', (0, 2)),
        'assert_multi_line_equal': ('a == b', (0, 2)),

        'assert_greater': ('a > b', (0, 2)),
        'assert_greater_equal': ('a >= b', (0, 2)),
        'assert_less': ('a < b', (0, 2)),
        'assert_less_equal': ('a <= b', (0, 2)),

        'assert_in': ('a in b', (0, 2)),
        'assert_not_in': ('a not in b', (0, 2)),

        'assert_is': ('a is b', (0, 2)),
        'assert_is_not': ('a is not b', (0, 2)),

        'assert_is_instance': ('isinstance(a, b)', ((1, 1, 0), (1, 1, 2), False)),
        'assert_count_equal': ('collections.Counter(a) == collections.Counter(b)', ((0, 2, 1), (2, 2, 1), False)),
        'assert_not_regex': ('not re.search(b, a)', ((1, 2, 1, 2), (1, 2, 1, 0), False)),
        'assert_regex': ('re.search(b, a)', ((2, 1, 2), (2, 1, 0), False)),
    }

    @override(FixAssertBase)
    def _transform_dest(self, assert_arg_test_node: PyNode, results: {str: PyNode}) -> bool:
        """
        Put the two arguments of the nose function call into the assertion expression.
        :param assert_arg_test_node: the destination node representing the assertion test argument
        :param results: the results of pattern matching; the arguments are under "lhs" and "rhs"
        :return: always True
        """
        conv_data = self._conv_data
        # only transformations that involve a comparison operator may need wrapping in parens
        operands_of_comparison = len(conv_data) <= 2 or conv_data[2]

        def grouped_if_needed(operand):
            return wrap_parens_for_comparison(operand) if operands_of_comparison else operand

        lhs = results["lhs"].clone()
        rhs = results["rhs"].clone()

        # both placeholders are located before either one is replaced
        lhs_slot = self._get_node(assert_arg_test_node, conv_data[0])
        rhs_slot = self._get_node(assert_arg_test_node, conv_data[1])

        new_lhs = grouped_if_needed(lhs)
        lhs_slot.replace(new_lhs)
        adjust_prefix_first_arg(new_lhs, results["lhs"].prefix)

        new_rhs = grouped_if_needed(rhs)
        rhs_slot.replace(new_rhs)
        # no whitespace directly after an opening bracket
        if get_prev_sibling(new_rhs).type in NEWLINE_OK_TOKENS:
            new_rhs.prefix = ''

        return True


class FixAssertAlmostEq(FixAssertBase):
    """
    Fixer class for any 3-argument assertion function (assert_func(a, b, c)). It supports optional fourth arg
    as the assertion message, ie assert_func(a, b, c, msg) -> assert a op b op c, msg.

    The call is only converted when the tolerance is given by keyword as delta=...; the msg=... keyword, if
    any, may come before or after it. Any other form of call is left unchanged.
    """

    PATTERN = PATTERN_ALMOST_ARGS

    # The conversion data holds the node paths to 'a', 'b' and 'delta' (in that order), relative to the
    # assertion expression 'abs(a - b) <= delta' or 'abs(a - b) > delta'.
    conversions = dict(
            assert_almost_equal=('abs(a - b) <= delta', ((0, 1, 1, 0), (0, 1, 1, 2), 2)),
            assert_almost_equals=('abs(a - b) <= delta', ((0, 1, 1, 0), (0, 1, 1, 2), 2)),
            assert_not_almost_equal=('abs(a - b) > delta', ((0, 1, 1, 0), (0, 1, 1, 2), 2)),
            assert_not_almost_equals=('abs(a - b) > delta', ((0, 1, 1, 0), (0, 1, 1, 2), 2)),
    )

    @staticmethod
    def _kwarg_value(arg_node: PyNode or PyLeaf, name: str) -> PyNode or PyLeaf or None:
        """
        Get the value part of a keyword argument.
        :param arg_node: the node of one function call argument
        :param name: the keyword expected
        :return: the node of the value if arg_node has the form "name=value", None otherwise (positional
            argument, other keyword, or an expression that merely starts with name)
        """
        parts = arg_node.children
        is_named = (len(parts) == 3 and
                    parts[0] == PyLeaf(token.NAME, name) and
                    parts[1] == PyLeaf(token.EQUAL, '='))
        return parts[2] if is_named else None

    @override(FixAssertBase)
    def _transform_dest(self, assert_arg_test_node: PyNode, results: {str: PyNode}) -> bool:
        """
        Put the arguments of the nose function call into the assertion expression.
        :param assert_arg_test_node: the destination node representing the assertion test argument
        :param results: the results of pattern matching; the arguments are under "aaa", "bbb", "delta" and
            possibly "msg". When the call has msg=... before delta=..., results["msg"] is changed to the
            msg=... argument so that the base class appends the right message.
        :return: True if the call was converted, False if it does not specify the tolerance as delta=...
        """
        third_arg = results["delta"].clone()
        msg_given_first = None

        delta_val = self._kwarg_value(third_arg, 'delta')
        if delta_val is None:
            # the only other supported form is msg=... followed by delta=...
            if 'msg' not in results or self._kwarg_value(third_arg, 'msg') is None:
                return False
            delta_val = self._kwarg_value(results['msg'].clone(), 'delta')
            if delta_val is None:
                return False
            msg_given_first = third_arg

        aaa = results["aaa"].clone()
        bbb = results["bbb"].clone()

        dest1 = self._get_node(assert_arg_test_node, self._conv_data[0])
        new_aaa = wrap_parens_for_addsub(aaa)
        dest1.replace(new_aaa)
        adjust_prefix_first_arg(new_aaa, results["aaa"].prefix)

        dest2 = self._get_node(assert_arg_test_node, self._conv_data[1])
        dest2.replace(wrap_parens_for_addsub(bbb))

        dest3 = self._get_node(assert_arg_test_node, self._conv_data[2])
        delta_val.prefix = " "
        dest3.replace(wrap_parens_for_comparison(delta_val))

        if msg_given_first is not None:
            results['msg'] = msg_given_first

        return True


# ------------ Main portion of script -------------------------------

class NoseConversionRefactoringTool(refactor.MultiprocessRefactoringTool):
    """
    Refactoring tool that applies the nose assertion fixers defined in this module, instead of fixers
    looked up by name.
    """

    def __init__(self, verbose: bool=False):
        """
        :param verbose: if True, the logging level is set to DEBUG instead of INFO
        """
        flags = dict(print_function=True)
        # no fixer names are given: the fixers come from the get_fixers() override below
        super().__init__([], flags)
        level = logging.DEBUG if verbose else logging.INFO
        logging.basicConfig(format='%(name)s: %(message)s', level=level)

    def get_fixers(self):
        """
        Create one fixer per nose function supported by FixAssert1Arg, FixAssert2Args and FixAssertAlmostEq.
        :return: the pair (pre-order fixers, post-order fixers); the second list is always empty
        """
        pre_fixers = []
        post_fixers = []

        for fixer_class in (FixAssert1Arg, FixAssert2Args, FixAssertAlmostEq):
            pre_fixers.extend(fixer_class.create_all(self.options, self.fixer_log))

        return pre_fixers, post_fixers


def setup():
    # from nose import tools as nosetools
    # import inspect
    # for key in dir(nosetools):
    #     if key.startswith('assert_'):
    #         argspec = inspect.getargspec(getattr(nosetools, key))
    #         print(key, argspec)

    parser = argparse.ArgumentParser(description='Convert nose assertions to regular assertions for use by pytest')
    parser.add_argument('dir_name', type=str,
                        help='folder name from which to start; all .py files under it will be converted')
    parser.add_argument('-w', dest='write', action='store_false',
                        help='disable overwriting of original files')
    parser.add_argument('-v', dest='verbose', action='store_true',
                        help='verbose output (list files changed, etc)')

    return parser.parse_args()


# Script entry point: parse the command line, then convert the files found under the requested folder.
if __name__ == '__main__':
    args = setup()
    refac = NoseConversionRefactoringTool(args.verbose)
    # args.write is True unless -w was given on the command line (see setup())
    refac.refactor_dir(args.dir_name, write=args.write)


__license__ = """
    Copyright (c) 2016, Oliver Schoenborn
    All rights reserved.

    Redistribution and use in source and binary forms, with or without
    modification, are permitted provided that the following conditions are met:

    * Redistributions of source code must retain the above copyright notice, this
      list of conditions and the following disclaimer.

    * Redistributions in binary form must reproduce the above copyright notice,
      this list of conditions and the following disclaimer in the documentation
      and/or other materials provided with the distribution.

    * Neither the name of nose2pytest nor the names of its
      contributors may be used to endorse or promote products derived from
      this software without specific prior written permission.

    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
    AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
    IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
    DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
    FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
    DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
    SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
    CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
    OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
