import unittest
import struct
import math
try:
    import numpy as np
except ImportError:
    # numpy is optional for this module; ``np`` is None when it cannot be imported.
    np = None

from PyNumpress import MSNumpress

@unittest.skipIf(np is None, 'Numpy is required for this test.')
class test_MSNumpress(unittest.TestCase):

    """
    unittest for MSNumpress en- and decoding.

    Exercises the linear, slof and pic codecs of ``PyNumpress.MSNumpress``
    against fixed byte vectors and through encode/decode round trips.
    """

    def setUp(self):
        """Build the fixture data and a fresh ``MSNumpress`` codec for each test."""
        self.mz_data = np.asarray([100.1, 100.01, 100.001, 100.0001], dtype=np.float64)
        # Expected ``encode_linear`` output for ``mz_data`` when it is encoded
        # with the fixed point returned by ``optimal_linear_fixed_point(mz_data)``.
        self.mz_linear_encoded = [
            0x41, 0x74, 0x75, 0xa4, 0x70, 0x00, 0x00, 0x00,
            0xf6, 0xff, 0xff, 0x7f, 0xc2, 0x89, 0xe2, 0x7f,
            0x2b, 0xf3, 0x8a, 0x13, 0xdc, 0x6a, 0x20,
        ]
        self.i_data = [5e6, 6.5e5, 2e6, 12e6]
        self.i_slof_data = [1, 2, 4]
        self.Decoder = MSNumpress([])

    def _assert_sequence_almost_equal(self, actual, expected, places):
        """Assert *actual* matches *expected* element-wise to *places* decimals.

        The lengths are compared first so that an empty or truncated decode
        cannot pass vacuously.
        """
        self.assertEqual(len(actual), len(expected), msg='decoded length differs')
        for pos, (act, exp) in enumerate(zip(actual, expected)):
            self.assertAlmostEqual(
                act, exp, places=places, msg='error at pos {0}'.format(pos)
            )

    def test_encode_slof(self):
        """``encode_slof`` with the optimal fixed point yields the reference bytes."""
        fp = self.Decoder.optimal_slof_fixed_point(self.i_slof_data)
        encoded_array = self.Decoder.encode_slof(
            np.asarray(self.i_slof_data, dtype=np.float64), fp
        )
        expected = [
            0x40, 0xe3, 0xe1, 0xe0, 0x00, 0x00, 0x00, 0x00,
            0x40, 0x6e, 0xbe, 0xae, 0xff, 0xff,
        ]
        self.assertEqual(
            list(encoded_array),
            expected,
            msg='{}'.format([hex(x) for x in encoded_array]),
        )

    def test_decode_slof(self):
        """Decoding a fixed slof byte vector recovers ``i_slof_data``."""
        test_array = bytearray([
            0x40, 0xc3, 0x88, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x13, 0x1b, 0xea, 0x2a, 0xde, 0x3e,
        ])
        decoded_array = self.Decoder.decode_slof(test_array)
        self._assert_sequence_almost_equal(decoded_array, self.i_slof_data, places=2)

    def test_encode_decode_slof(self):
        """A slof encode/decode round trip preserves ``i_slof_data`` to 2 places."""
        fp = self.Decoder.optimal_slof_fixed_point(self.i_slof_data)
        encoded_array = self.Decoder.encode_slof(
            np.asarray(self.i_slof_data, dtype=np.float64),
            fp
        )
        decoded_array = self.Decoder.decode_slof(encoded_array)
        self._assert_sequence_almost_equal(decoded_array, self.i_slof_data, places=2)

    def test_encode_pic_i_data(self):
        """``encode_pic`` of ``i_data`` yields the reference bytes."""
        encoded_array = self.Decoder.encode_pic(np.asarray(self.i_data, dtype=np.float64))
        expected = [
            0x20, 0x4b, 0x4c, 0x43, 0x01, 0xbe, 0x92,
            0x08, 0x48, 0xe1, 0x20, 0x0b, 0x17, 0xb0,
        ]
        self.assertEqual(
            list(encoded_array),
            expected,
            msg='{}'.format([hex(x) for x in encoded_array]),
        )

    def test_decode_pic_i_data(self):
        """Decoding the reference pic bytes recovers ``i_data`` exactly and in order."""
        test_array = [
            0x20, 0x4b, 0x4c, 0x43,
            0x1, 0xbe, 0x92, 0x8,
            0x48, 0xe1, 0x20, 0xb,
            0x17, 0xb0
        ]
        decoded_array = self.Decoder.decode_pic(test_array)
        # Ordered comparison: assertCountEqual would accept permuted output.
        self.assertEqual(
            list(decoded_array),
            self.i_data,
            msg='{}\n{}'.format(list(decoded_array), list(self.i_data)),
        )

    def test_encode_decode_pic(self):
        """A pic encode/decode round trip of integral values is exact and ordered."""
        encoded_array = self.Decoder.encode_pic(
            np.asarray(self.i_slof_data, dtype=np.float64)
        )
        decoded_array = self.Decoder.decode_pic(encoded_array)
        self.assertEqual(list(decoded_array), self.i_slof_data)

    def test_encode_linear(self):
        """``encode_linear`` of ``mz_data`` with its optimal fixed point yields the reference bytes."""
        fp = self.Decoder.optimal_linear_fixed_point(self.mz_data)
        encoded_array = self.Decoder.encode_linear(self.mz_data, fp)
        self.assertEqual(
            list(encoded_array),
            self.mz_linear_encoded,
            msg='{}'.format([hex(x) for x in encoded_array]),
        )

    def test_decode_linear(self):
        """Decoding the reference linear bytes recovers ``mz_data``."""
        decoded_array = self.Decoder.decode_linear(bytearray(self.mz_linear_encoded))
        self._assert_sequence_almost_equal(decoded_array, self.mz_data, places=4)

    def test_encode_decode_linear(self):
        """A linear round trip with the optimal fixed point preserves 4 decimals."""
        test_array = [
            100.00066,
            100.00217,
            100.00368,
            100.00519,
            111.73335,
            111.73513,
            111.73692,
            111.7387,
            111.74049,
            111.74227,
            111.74406,
            111.74584,
            111.74763,
            111.74941,
            111.7512,
            111.75298,
            111.75477,
            112.00694,
            112.00873,
            112.01052,
            112.01231,
            112.0141,
            112.01589,
            112.01768
        ]
        test_array = np.asarray(test_array, dtype=np.float64)
        fp = self.Decoder.optimal_linear_fixed_point(test_array)
        encoded_array = self.Decoder.encode_linear(test_array, fp)
        decoded_array = self.Decoder.decode_linear(encoded_array)
        self._assert_sequence_almost_equal(decoded_array, test_array, places=4)

    def test_encode_decode_self_mz(self):
        """A linear round trip of ``mz_data`` preserves 4 decimals."""
        mz_data = np.asarray(self.mz_data, dtype=np.float64)
        fp = self.Decoder.optimal_linear_fixed_point(mz_data)
        encoded_mz_data = self.Decoder.encode_linear(mz_data, fp)
        decoded_mz_data = self.Decoder.decode_linear(encoded_mz_data)
        self._assert_sequence_almost_equal(decoded_mz_data, mz_data, places=4)

if __name__ == "__main__":
    # Script entry point: run every test in this module at verbosity level 3.
    unittest.main(verbosity=3)
