#   Compatibility with python 3.x
from __future__ import print_function
from __future__ import unicode_literals

import copy
from collections import namedtuple
import utils
import matplotlib.pyplot as plt
import numpy.matlib as npml
import numpy as np
import math
import sys
from scipy.spatial import Voronoi
import collections
import time

from scipy import signal

from datatypes.RecoProtocols import griddingProtocol

def uteReco (dataObject, **kwargs):
    """Set up gridding parameters for a UTE reconstruction and run density compensation.

    Keyword arguments
    -----------------
    kernel : float, optional
        Gridding kernel width. Missing or falsy values fall back to 4.
    overgridd : float, optional
        Overgridding factor. Missing or falsy values fall back to 1.5.

    Side effects
    ------------
    Stores ``kernelWidth``, ``overgriddFactor``, ``beta``, ``kernel``,
    ``imageSize`` (``method.PVM_Matrix[0]``) and ``gridSize``
    (``overgriddFactor * imageSize``) on ``dataObject.visu_pars_local``, then
    calls ``gridding_dcf(dataObject)``.

    Exits the interpreter with status 1 (``sys.exit``) when
    ``dataObject.traj.data`` does not exist.
    """
    kernelWidth = kwargs.get('kernel')
    overgriddFactor = kwargs.get('overgridd')
    kernelSamples_nr = 100

    if not kernelWidth:
        kernelWidth = 4
        print('Kernel width not passed. It is set to: ',kernelWidth)

    if not overgriddFactor:
        overgriddFactor = 1.5
        print('OVergridding factor not passed. It is set to: ',overgriddFactor)

    # Only a *missing* attribute means "no trajectory". Any other exception
    # raised while reading it is a genuine bug and must propagate instead of
    # being reported as missing data (the former bare ``except:`` also
    # swallowed KeyboardInterrupt).
    try:
        getattr(dataObject.traj, 'data')
    except AttributeError:
        print("Data object, you're trying to reconstruct doesn't have trajectory data.")
        print("RECONSTRUCTION NOT POSSIBLE")
        sys.exit(1)

    visu = dataObject.visu_pars_local
    visu.kernelWidth = kernelWidth
    visu.overgriddFactor = overgriddFactor

    # Kaiser-Bessel shape parameter:
    # beta = pi * sqrt(W^2 / a^2 * (a - 0.5)^2 - 0.8), with W the kernel
    # width and a the overgridding factor.
    beta = np.pi * np.sqrt(math.pow(kernelWidth, 2) / math.pow(overgriddFactor, 2)
                           * math.pow((overgriddFactor - 0.5), 2) - 0.8)
    visu.beta = beta

    # Right half of a 2*N-point Kaiser window, i.e. the kernel sampled from
    # its centre outwards on N points.
    kernel = np.kaiser(2 * kernelSamples_nr, beta)[-kernelSamples_nr:]
    visu.kernel = kernel

    visu.imageSize = dataObject.method.PVM_Matrix[0]
    visu.gridSize = visu.overgriddFactor * visu.imageSize

    # Now the object has the trajectory info and visu_pars_local stuff as well

    # Density compensation
    gridding_dcf(dataObject)

    # TODO: the remaining reconstruction steps are currently disabled:
    #
    # # Gridding
    # gridding(dataObject)
    #
    # # FFT
    # kspace = dataObject.visu_pars_local.kspace
    # image = np.fft.fftshift(np.fft.ifft2(kspace))
    # setattr(dataObject.data2dseq_local, 'data', image)

    return

def DtiEpiReco(dataObject):
    """Reconstruct every frame of a multi-channel EPI dataset.

    For each index ``frame`` along axis 3 of ``dataObject.fid.data``, the 2-D
    slab ``fid.data[:, :, channel, frame]`` of every channel is passed to
    ``epiRecoRaw``. The resulting channel images are merged with
    ``utils.combineChannels(..., 'AdaptiveCombine')``. The stacked result,
    shaped (PVM_Matrix[0], PVM_Matrix[1], n_frames), is stored on
    ``dataObject.data2dseq_local.data``.

    Phase-correction coefficients are read from
    ``method.PVM_EpiPhaseCorrection``:
    slope = [channel, frame % NI], offset = [channel, frame % NI + 1].
    NOTE(review): EpiReco reads the same table as [0, 0] / [0, 1]; the two
    index layouts differ -- confirm the table's shape.
    """
    fid = dataObject.fid.data
    n_frames = fid.shape[3]
    n_channels = fid.shape[2]
    matrix = dataObject.method.PVM_Matrix
    n_rows, n_cols = matrix[0], matrix[1]

    n_slices = dataObject.acqp.NI
    _n_repetitions = dataObject.acqp.NR  # read as before; not used below

    combined = np.zeros((n_rows, n_cols, n_frames), dtype=fid.dtype)
    per_channel = np.zeros((n_channels, n_rows, n_cols), dtype=fid.dtype)

    for frame in range(n_frames):
        for ch in range(n_channels):
            column = frame % n_slices
            table = dataObject.method.PVM_EpiPhaseCorrection
            per_channel[ch] = epiRecoRaw(fid[:, :, ch, frame], dataObject,
                                         table[ch, column], table[ch, column + 1])

        combined[:, :, frame] = utils.combineChannels(per_channel, 'AdaptiveCombine')

    dataObject.data2dseq_local.data = combined

def EpiReco(dataObject):
    """Reconstruct an EPI series repetition by repetition and slice by slice.

    ``dataObject.fid.data`` is indexed as (line, column, channel, frame). The
    frame axis is moved to the front and reshaped into (NR, NI): repetitions
    are outer and slices are inner in C order. Each (repetition, slice) frame
    is reconstructed with ``epiRecoRaw`` and merged with
    ``utils.combineChannels(..., 'SumOfSquares')``. The result, shaped
    (PVM_Matrix[1], columns, NI, NR), is stored on
    ``dataObject.data2dseq_local.data``.

    NOTE(review): only channel 0 is reconstructed, always with
    PVM_EpiPhaseCorrection[0, 0] / [0, 1] as slope / offset. The other channel
    slots stay zero, so the sum-of-squares combine presumably reduces to
    channel 0 alone. DtiEpiReco loops over every channel -- confirm whether
    this is intentional.
    """
    raw = dataObject.fid.data
    n_slices = dataObject.acqp.NI
    n_reps = dataObject.acqp.NR
    n_lines_acq = raw.shape[0]
    n_lines_img = dataObject.method.PVM_Matrix[1]
    n_cols = raw.shape[1]
    n_channels = raw.shape[2]
    _n_frames = raw.shape[3]  # unused; still raises IndexError for < 4-D input

    frames = np.reshape(np.transpose(raw, (3, 0, 1, 2)),
                        (n_reps, n_slices, n_lines_acq, n_cols, n_channels))

    combined = np.zeros((n_lines_img, n_cols, n_slices, n_reps), dtype=frames.dtype)
    per_channel = np.zeros((n_channels, n_lines_img, n_cols), dtype=frames.dtype)
    correction = dataObject.method.PVM_EpiPhaseCorrection

    for rep in range(n_reps):
        for slc in range(n_slices):
            slope, offset = correction[0, 0], correction[0, 1]
            per_channel[0] = epiRecoRaw(frames[rep, slc, :, :, 0], dataObject, slope, offset)

            combined[:, :, slc, rep] = utils.combineChannels(per_channel, 'SumOfSquares')

    dataObject.data2dseq_local.data = combined

def epiRecoRaw (data,dataObject, slope, offset):
    """Reconstruct a single 2-D EPI frame and return its image.

    Pipeline:
    1. Mirror odd lines (``utils.mirrorOddLines2d``).
    2. Regrid every line along the readout with ``gridding1``.
    3. Inverse FFT along the readout into hybrid space.
    4. Apply a linear phase correction line by line.
    5. FFT back to k-space.
    6. Zero-fill to ``PVM_Matrix`` (``utils.zeroFilling``).
    7. Rotate by 90 degrees for ``'A_P'`` readouts.
    8. Transform to image space.

    Parameters
    ----------
    data : numpy.ndarray
        2-D frame indexed as (line, readout sample). Presumably complex raw
        k-space -- TODO confirm against callers.
    dataObject :
        Provides ``method.PVM_Matrix``, ``method.PVM_SPackArrReadOrient`` and
        ``method.PVM_EpiTrajAdjkx``. Receives ``visu_pars_local.kspace``.
    slope, offset : float
        Phase-correction coefficients. The phase for readout step ``k``
        (centred on zero) is ``offset + slope * k``.

    Returns
    -------
    numpy.ndarray
        ``fftshift(ifft2(kspace))`` of the zero-filled (and possibly rotated)
        k-space.
    """
    method = dataObject.method
    dim0_out = method.PVM_Matrix[0]
    dim1_out = method.PVM_Matrix[1]

    dim0 = data.shape[0]
    dim1 = data.shape[1]

    read_orientation = method.PVM_SPackArrReadOrient
    # Readout step indices centred on zero. ``-dim1 // 2`` parses as
    # ``(-dim1) // 2``, which reproduces the Python 2 result of the former
    # ``range(-dim1/2, dim1/2)``; that expression raised TypeError on Python 3.
    read_steps = np.arange(-dim1 // 2, dim1 // 2)

    # Mirror odd lines
    # NOTE(review): if mirrorOddLines2d returns its input unchanged, the
    # row assignments below modify the caller's array in place -- unverified.
    data = utils.mirrorOddLines2d(data)

    # Prepare trajectory for gridding
    protocol = griddingProtocol()
    protocol.overgridFactor = 3
    protocol.kernelWidth = 1.5
    protocol.traj = method.PVM_EpiTrajAdjkx

    # NOTE(review): gridding1 in this module takes seven positional arguments
    # (weight_in, traj, kernel, radiusFOVproduct, windowLength,
    # effectiveMatrix, osf). This single-argument call therefore raises
    # TypeError whenever dim0 > 0. TODO: map the protocol fields onto that
    # signature and make sure the result has dim1 entries.
    for row in range(dim0):
        protocol.data = data[row, :]
        data[row, :] = gridding1(protocol)

    # Go to hybrid space (inverse FFT along the readout axis only)
    data = np.fft.fftshift(np.fft.ifft2(data, axes=(-1,)), axes=1)

    # Linear phase correction, applied to whole lines at once:
    # - odd lines are always multiplied by exp(+i*phase);
    # - even lines are multiplied by exp(-i*phase) only for 'A_P' readouts
    #   and are left unchanged otherwise (including 'L_R').
    # Complex literals replace np.complex, which NumPy 1.24 removed.
    phase = offset + slope * read_steps
    data[1::2, :] *= np.exp(1j * phase)
    if read_orientation == 'A_P':
        data[0::2, :] *= np.exp(-1j * phase)

    data = np.fft.fftshift(data, axes=1)
    data = np.fft.fft2(data, axes=(-1,))

    kspace = utils.zeroFilling(data, ('down', 'mid'), (dim0_out, dim1_out))

    if read_orientation == 'A_P':
        kspace = np.rot90(kspace, 1)

    setattr(dataObject.visu_pars_local, 'kspace', kspace)

    image = np.fft.fftshift(np.fft.ifft2(kspace))

    return image

def gridding_dcf1(traj, effective_matrix, osf , iter_nr, debug=True):
    """Estimate 1-D density-compensation weights by iterative gridding.

    Starting from all-ones weights, each iteration does three things:
    1. Grid the current weights onto the oversampled grid (``gridding1``).
    2. Interpolate the result back at the trajectory points (``degridding1``).
    3. Divide the weights element-wise by that response. Weights whose
       response is exactly zero are set to zero.

    This appears to follow the Pipe-Menon iterative scheme.

    Parameters
    ----------
    traj : numpy.ndarray
        1-D sample positions, forwarded unchanged to gridding1 / degridding1.
    effective_matrix : int or float
        Matrix size. The grid has ``int(effective_matrix * osf)`` points.
    osf : float
        Oversampling factor.
    iter_nr : int
        Number of iterations. With 0, all-ones weights are returned.
    debug : bool, optional
        When True (the default, matching the historical behaviour), print
        progress messages and show stem plots of the last grid, the last
        degridded response and the weights. ``plt.show()`` may block.

    Returns
    -------
    numpy.ndarray
        Weights of length ``int(effective_matrix * osf)``.
    """
    default_kernel_table_size = 1000
    rfp = 0.96960938
    grid_mat = int(effective_matrix * osf)
    norm_rfp = rfp * osf
    winLen = (grid_mat + norm_rfp * 2.0) / float(grid_mat)

    kernel = loadKernelTable(int(default_kernel_table_size * osf))

    # Initialise both diagnostics so the plots below are defined even when
    # iter_nr == 0 (previously weights_tmp raised NameError in that case).
    grid_tmp = np.zeros(grid_mat, order='F')
    weights_tmp = np.zeros(grid_mat, order='F')
    weights = np.ones(grid_mat, order='F')

    for _iter in range(iter_nr):
        grid_tmp = gridding1(weights, traj, kernel, norm_rfp, winLen, effective_matrix, osf)
        if debug:
            print("Gridding ok")
        weights_tmp = degridding1(grid_tmp, traj, kernel, norm_rfp, winLen, effective_matrix, osf)
        if debug:
            print("Degridding ok")

        # weights / response where the response is non-zero, 0 elsewhere.
        weights = np.divide(weights, weights_tmp,
                            out=np.zeros_like(weights),
                            where=(weights_tmp != 0.))

    if debug:
        plt.figure()
        plt.stem(grid_tmp)
        plt.figure()
        plt.stem(weights_tmp)
        plt.figure()
        plt.stem(weights)
        plt.show()

    return weights


def gridding1 (weight_in, traj, kernel, radiusFOVproduct, windowLength, effectiveMatrix, osf):
    """Grid 1-D non-uniform samples onto a uniform, oversampled grid.

    Each value ``weight_in[k]`` at position ``traj[k]`` is spread onto every
    grid point that lies strictly within ``radiusFOVproduct`` grid cells of
    it. The contribution is weighted by a kernel table looked up by
    normalised squared distance.

    Parameters
    ----------
    weight_in : array_like
        One value per trajectory point (real or complex).
    traj : numpy.ndarray
        1-D sample positions. Each is divided by ``windowLength`` and then
        scaled by the grid width around the grid centre.
    kernel : sequence of float
        Kernel table. Entry ``j`` corresponds to a squared distance of
        ``j / (len(kernel) - 1)`` times the squared kernel radius. This is
        the layout produced by ``loadKernelTable``.
    radiusFOVproduct : float
        Kernel radius, in grid cells.
    windowLength : float
        Divisor applied to every trajectory position.
    effectiveMatrix, osf : float
        The grid width is ``int(effectiveMatrix * osf)``.

    Returns
    -------
    numpy.ndarray
        Gridded data of length ``int(effectiveMatrix * osf)``: float64 for
        real input, complex for complex input.
    """
    width = int(effectiveMatrix * osf)

    kernelRadius = radiusFOVproduct / float(width)
    kernelRadius_sqr = kernelRadius ** 2
    kernelRadius_invSqr = 1.0 / kernelRadius_sqr
    width_inv = 1.0 / width
    # Integer centre index under both Python 2 and Python 3.
    center = width // 2
    # Maps squared distances in [0, kernelRadius_sqr) onto table indices
    # [0, len(kernel) - 1]. Scaling by len(kernel), as before, could round
    # up to len(kernel) and raise IndexError. degridding1 uses the same mapping.
    dist_multiplier = kernelRadius_invSqr * (len(kernel) - 1)

    weight_in = np.asarray(weight_in)
    # Keep complex input complex instead of truncating it into a float buffer.
    dataOut = np.zeros(width, dtype=np.result_type(weight_in, np.float64), order='F')

    for ind in range(traj.shape[0]):
        x = traj[ind] / float(windowLength)
        value = weight_in[ind]

        bounds = gridMinMax(x * width + center, width, radiusFOVproduct)

        for i in range(bounds.min, bounds.max):
            dx = (i - center) * width_inv - x
            dx_sqr = dx * dx

            if dx_sqr < kernelRadius_sqr:
                dataOut[i] += value * kernel[int(dx_sqr * dist_multiplier + .5)]

    return dataOut


def degridding1(grid_in, traj, kernel, radiusFOVproduct, winLen, effectiveMatrix, osf):
    """Interpolate a uniform 1-D grid at non-uniform trajectory positions.

    This is the interpolation counterpart of ``gridding1``. For each
    trajectory point, it sums the grid values that lie strictly within
    ``radiusFOVproduct`` grid cells, each weighted by the kernel table looked
    up by normalised squared distance.

    Parameters
    ----------
    grid_in : numpy.ndarray
        1-D grid. Its length defines the grid width.
    traj : numpy.ndarray
        1-D sample positions. Each is divided by ``winLen`` and then scaled
        by the grid width around the grid centre.
    kernel : sequence of float
        Kernel table. Entry ``j`` corresponds to a squared distance of
        ``j / (len(kernel) - 1)`` times the squared kernel radius.
    radiusFOVproduct : float
        Kernel radius, in grid cells.
    winLen : float
        Divisor applied to every trajectory position.
    effectiveMatrix, osf : float
        The output has ``int(effectiveMatrix * osf)`` entries.

    Returns
    -------
    numpy.ndarray
        Interpolated values. Entry ``k`` belongs to ``traj[k]``; trailing
        entries stay zero.
    """
    width = grid_in.shape[0]
    # Integer centre index under both Python 2 and Python 3.
    width_div2 = width // 2
    width_inv = 1.0 / width

    kernelRadius = radiusFOVproduct / float(width)
    kernelRadius_sqr = kernelRadius ** 2
    kernelRadius_invSqr = 1 / kernelRadius_sqr

    dist_multiplier = kernelRadius_invSqr * (len(kernel) - 1)

    # int() is required: a float size (e.g. 64 * 1.5) is rejected by np.zeros.
    # NOTE(review): the length is grid-sized rather than len(traj), so a
    # longer trajectory raises IndexError. It is kept because callers such as
    # gridding_dcf1 pair this output index-by-index with a grid-sized vector.
    dataOut = np.zeros(int(effectiveMatrix * osf),
                       dtype=np.result_type(grid_in, np.float64), order='F')

    for ind in range(traj.shape[0]):
        x = traj[ind] / float(winLen)
        bounds = gridMinMax(x * width + width_div2, width, radiusFOVproduct)

        acc = 0.0
        for i in range(bounds.min, bounds.max):
            dx = (i - width_div2) * width_inv - x
            dx_sqr = dx * dx

            if dx_sqr < kernelRadius_sqr:
                acc += grid_in[i] * kernel[int(dx_sqr * dist_multiplier + .5)]

        dataOut[ind] = acc

    return dataOut


# Index range returned by gridMinMax. Defined once at module level instead of
# building a new namedtuple class on every call.
_MinMax = collections.namedtuple('MinMax', ['min', 'max'])


def gridMinMax(x, maximum, radius):
    """Return the half-open grid index range ``[min, max)`` around *x*.

    The range contains every integer index ``i`` with ``|i - x| < radius``,
    clipped to ``[0, maximum)``. When a bound falls exactly on an integer it
    may contain one extra index. It is meant to be used as
    ``range(r.min, r.max)``.

    The upper bound is ``floor(x + radius) + 1``. The former value,
    ``floor(x + radius)``, was used as an exclusive ``range`` end and so
    dropped the last grid point inside the kernel radius.

    Parameters
    ----------
    x : float
        Centre position in grid-index units.
    maximum : int
        Grid width, i.e. the exclusive upper clip.
    radius : float
        Kernel radius in grid cells.

    Returns
    -------
    MinMax
        Namedtuple instance with ``min`` and ``max`` attributes.
    """
    lo = int(np.ceil(x - radius))
    hi = int(np.floor(x + radius)) + 1
    return _MinMax(max(lo, 0), min(hi, maximum))

def _poly_sdc_kern_0lobes (r):

    POLY_ORDER = 5
    FIT_LEN = 394
    SPECTRAL_LEN = 25600
    FOV = 63

    x = SPECTRAL_LEN * r / float(FOV)

    poly = np.array ((-1.1469041640943728E-13,
                      8.5313956268885989E-11,
                      1.3282009203652969E-08,
                      -1.7986635886194154E-05,
                      3.4511129626832091E-05,
                      0.99992359966186584))

    out = poly[5]

    for i in range(1,POLY_ORDER+1):
        out += pow(x,i) * poly[POLY_ORDER-i]

    if out < 0:
        out = 0

    return out

def loadKernelTable(length):
    """Tabulate the zero-lobe kernel polynomial on *length* points.

    Entry ``i`` holds
    ``_poly_sdc_kern_0lobes(sqrt(rfp**2 * i / (length - 1)))`` with
    ``rfp = 0.96960938``. The table is therefore uniform in squared radius,
    from 0 to rfp**2. A *length* of 1 raises ZeroDivisionError.
    """
    rfp_sqr = 0.96960938 ** 2
    denom = float(length - 1)

    table = np.zeros(length, order='F')
    table[:] = [_poly_sdc_kern_0lobes(math.sqrt(rfp_sqr * float(i) / denom))
                for i in range(length)]

    return table

