from __future__ import print_function

import numpy as np
import sys
import inout.writers as writers
import scipy.ndimage as ndi
import math


def mirrorOddLines2d(dataIn):
    """
    Description
    -----------
    Reverse the column order of every odd-indexed row of a 2D array.

    Even-indexed rows (0, 2, 4, ...) are copied unchanged and
    odd-indexed rows (1, 3, 5, ...) are flipped left-to-right.
    The input is left untouched.

    Parameters
    ----------
    dataIn
        numpy array in format [rows, columns]

    Returns
    -------
    dataOut
        New array with the same shape and dtype as dataIn.
    """
    evenRows = slice(None, None, 2)
    oddRows = slice(1, None, 2)

    result = np.zeros(dataIn.shape, dataIn.dtype)
    result[evenRows, :] = dataIn[evenRows, :]
    result[oddRows, :] = dataIn[oddRows, ::-1]
    return result

def combineChannels (dataIn, mode):
    """
    Description
    -----------
    Function to combine images reconstructed from the signal
    of individual channels. It uses either adaptive combine
    or sum of squares.

    Parameters
    ----------
    dataIn
        In format:
        [channels, rows, columns]
        or
        [channels, slices, rows, columns]
    mode
        Mode of combination:
        "AdaptiveCombine",  or its equivalent 0
        or
        "SumOfSquares",     or its equivalent 1
        Any other value falls back to the default mode
        ("AdaptiveCombine").

    Returns
    -------
    dataOut
        Combined data
        In format:
        [rows, columns]
        or
        [slices, rows, columns]

    Raises
    ------
    ValueError
        If dataIn is neither 3- nor 4-dimensional.
    """
    defaultMode = "AdaptiveCombine"
    # defaultMode = "SumOfSquares"

    dataIn = np.asarray(dataIn)
    dim = dataIn.shape

    if len(dim) == 3:
        # Add a singleton slice axis so 3-D and 4-D input share one code path.
        dataIn = dataIn[:, np.newaxis, :, :]
    elif len(dim) != 4:
        raise ValueError("Data dimensions {} are not valid".format(dim))

    channels_nr = dataIn.shape[0]
    midSlice = dataIn.shape[1] // 2

    # Compare by value, not identity: `is` against int/str literals only
    # works by accident of CPython interning (and warns on Python 3.8+).
    if mode not in (0, 1, "AdaptiveCombine", "SumOfSquares"):
        mode = defaultMode

    if mode in (1, "SumOfSquares"):
        #   Do sum of squares: per pixel sqrt(sum_c d_c * conj(d_c))
        dataOut = np.sqrt(np.sum(dataIn * np.conj(dataIn), axis=0))
    else:
        #   Do adaptive combine
        # Channel correlation matrix, estimated from the middle slice only:
        # R[x, y] = sum over pixels of d_x * conj(d_y)
        flat = dataIn[:, midSlice, :, :].reshape(channels_nr, -1)
        correlationMatrix = np.dot(flat, np.conj(flat).T)

        # R is Hermitian, so eigh applies; channels are weighted by the
        # conjugated eigenvector of the largest-magnitude eigenvalue.
        val, vec = np.linalg.eigh(correlationMatrix)
        max_vec = vec[:, np.argmax(np.abs(val))]

        dataOut = np.zeros(dataIn.shape[1:], dataIn.dtype)
        for _channel in range(channels_nr):
            dataOut = dataOut + dataIn[_channel] * np.conj(max_vec[_channel])

    if len(dim) == 3:
        # Drop the singleton slice axis added above.
        dataOut = dataOut[0]

    return dataOut

def zeroFilling(dataIn, loc, size):
    """
    Description
    -----------
    Function to complete non-complete k-spaces with zeros,
    prior to 2D IFFT. The input is placed inside a zero array
    of the requested size at the requested location.

    Parameters
    ----------
    dataIn
        In format:
        [rows, columns]
        or
        [channels, rows, columns]
        or
        [channels, slices, rows, columns]
        Zero filling is applied to the last two axes (rows, columns);
        any leading axes are kept unchanged.
    loc
        Pair (vertical, horizontal) selecting where dataIn is placed:
        vertical:   'up', 'mid' or 'down'
        horizontal: 'left', 'mid' or 'right'
        'mid' puts input index n_in//2 at output index n_out//2.
    size
        Pair (rows, columns) of the output. Each must be at least as
        large as the corresponding input dimension.

    Returns
    -------
    dataOut
        Same leading axes and dtype as dataIn, with the last two
        axes resized to `size`.

    Raises
    ------
    ValueError
        If dataIn has fewer than 2 dimensions, `size` is smaller than
        the input, or `loc` is not a valid combination.
    """
    dataIn = np.asarray(dataIn)
    if dataIn.ndim < 2:
        raise ValueError("dataIn must have at least 2 dimensions, got shape {}".format(dataIn.shape))

    dim0_in, dim1_in = dataIn.shape[-2:]
    dim0_out = size[0]
    dim1_out = size[1]
    if dim0_out < dim0_in or dim1_out < dim1_in:
        raise ValueError("Output size {} is smaller than input size {}".format(
            (dim0_out, dim1_out), (dim0_in, dim1_in)))

    # Start offset of the input block along each axis for every placement.
    # Computing "start, start + n_in" (instead of negative end slices such as
    # 0:-diff) stays correct when no padding is needed (diff == 0) and for
    # odd input sizes in the 'mid' placement.
    rowStarts = {'up': 0,
                 'mid': dim0_out // 2 - dim0_in // 2,
                 'down': dim0_out - dim0_in}
    colStarts = {'left': 0,
                 'mid': dim1_out // 2 - dim1_in // 2,
                 'right': dim1_out - dim1_in}

    vert = loc[0]
    hor = loc[1]
    if vert not in rowStarts or hor not in colStarts:
        raise ValueError("Invalid location {!r}: expected (vertical, horizontal) with "
                         "vertical in 'up'/'mid'/'down' and horizontal in "
                         "'left'/'mid'/'right'".format(loc))

    r0 = rowStarts[vert]
    c0 = colStarts[hor]

    dataOut = np.zeros(dataIn.shape[:-2] + (dim0_out, dim1_out), dataIn.dtype)
    dataOut[..., r0:r0 + dim0_in, c0:c0 + dim1_in] = dataIn

    return dataOut

def snr (dataIn, sigRoi, noiseRoi):
    """
    Description
    -----------
    Signal-to-noise ratio, in dB, between two rectangular regions
    of dataIn.

    The mean power of the noise region is subtracted from the mean
    power of the signal region before forming the ratio. This
    presumably assumes the signal region holds signal plus noise;
    confirm with the caller.

    Parameters
    ----------
    dataIn
        Array indexed as dataIn[rows, columns]
    sigRoi, noiseRoi
        (row_start, row_stop, col_start, col_stop) slice bounds

    Returns
    -------
    snr
        10*log10((P_signal - P_noise) / P_noise), where P is the mean
        of |x|**2 over the region. Evaluates to nan when P_signal is
        below P_noise.
    """
    def _meanPower(roi):
        region = dataIn[roi[0]:roi[1], roi[2]:roi[3]]
        return np.sum(np.abs(region) ** 2) / float(region.size)

    signalPower = _meanPower(sigRoi)
    noisePower = _meanPower(noiseRoi)

    return 10 * np.log10((signalPower - noisePower) / noisePower)

def shoeLace(corners):
    """
    Description
    -----------
    Area of a simple polygon computed with the shoelace formula.

    Parameters
    ----------
    corners
        Array of vertices in traversal order, indexed as
        corners[i][0] (x) and corners[i][1] (y). Either clockwise or
        counter-clockwise order works, because the absolute value is
        taken. The polygon is closed implicitly (last vertex -> first).

    Returns
    -------
    area
        Polygon area (0.0 for an empty vertex array).
    """
    count = corners.shape[0]
    twiceSignedArea = 0

    for idx in range(count):
        current = corners[idx]
        following = corners[(idx + 1) % count]
        twiceSignedArea += current[0] * following[1]
        twiceSignedArea -= following[0] * current[1]

    return abs(twiceSignedArea) * 0.5
