from __future__ import print_function
from jcampdx import *
import sys
import numpy as np
import struct

import matplotlib.pyplot as plt

from inout.codeXchange import codeXchange

__author__ = 'Tomas Psorn'


def readBrukerParamFile(path, dataobject):
    """
    Read a Bruker parameter file in JCAMP-DX format and store its records on ``dataobject``.

    The records are attached to ``getattr(dataobject, <file name>)``, where the file name
    is the last '/'-separated component of ``path``. For example, for ``.../acqp`` the
    target is ``dataobject.acqp``.

    Each ``##NAME=value`` record is parsed with ``parse_record`` and stored as attribute
    ``NAME``. A leading '$' is removed, so ``##$NI=...`` becomes ``NI``. Records without
    '$', such as ``##TITLE=...``, keep their full name.

    Parameters
    ----------
    path : str
        Path to the parameter file ('/'-separated).
    dataobject
        Object that already has an attribute named after the parameter file.

    Returns
    -------
    None
    """
    fileType = path.split('/')[-1]
    target = getattr(dataobject, fileType)

    with open(path) as paramFile:
        paramText = paramFile.read()
    paramLines, _comments = strip_comments(paramText)

    ldr_sep, ldr_usr_sep, ldr_dict_sep = '##', '$', '='

    # Skip empty / whitespace-only fragments produced by the split.
    ldrs = [ldr for ldr in paramLines.split(ldr_sep) if ldr.strip()]

    for ldr in ldrs:
        # Split on the first '=' only, so values that contain '=' are kept intact.
        ldr_name, found_sep, ldr_val = ldr.partition(ldr_dict_sep)
        if not found_sep:
            print('Parameter file line is immposible to process, program continues to run')
            print(ldr)
            continue
        # Only the '$' marker of user-defined parameters is stripped from the name.
        if ldr_name.startswith(ldr_usr_sep):
            ldr_name = ldr_name[len(ldr_usr_sep):]
        setattr(target, ldr_name, parse_record(ldr_val))

    return

def readBrukerParamFile2(path, dataobject):
    """
    Read a Bruker JCAMP-DX parameter file and keep the raw record names.

    This variant stores the records on ``getattr(dataobject, <file name> + '2')``, where
    the file name is the last '/'-separated component of ``path``. For example, for
    ``.../acqp`` the target is ``dataobject.acqp2``.

    Record names are used verbatim, so a leading '$' is kept. Each value is parsed with
    ``parse_record``.

    Parameters
    ----------
    path : str
        Path to the parameter file ('/'-separated).
    dataobject
        Object that already has an attribute named ``<file name> + '2'``.

    Returns
    -------
    None
    """
    fileType = path.split('/')[-1]
    target = getattr(dataobject, fileType + '2')

    with open(path) as paramFile:
        paramText = paramFile.read()
    paramLines, _comments = strip_comments(paramText)

    ldr_sep, ldr_dict_sep = '##', '='

    # Skip empty / whitespace-only fragments produced by the split.
    ldrs = [ldr for ldr in paramLines.split(ldr_sep) if ldr.strip()]

    for ldr in ldrs:
        # Split on the first '=' only, so values that contain '=' are kept intact.
        ldr_name, found_sep, ldr_val = ldr.partition(ldr_dict_sep)
        if not found_sep:
            print('Parameter file line is immposible to process, program continues to run')
            print(ldr)
            continue
        setattr(target, ldr_name, parse_record(ldr_val))

    return

def readBrukerTrajFile(path, rawdataobject):
    """
    Load a Bruker trajectory file and attach it as ``rawdataobject.traj.data``.

    The file is read as a flat stream of float64 values in which x and y coordinates
    alternate (x0, y0, x1, y1, ...). Each coordinate stream is reshaped in column-major
    (Fortran) order to (PVM_TrajSamples, PVM_TrajIntAll).

    Parameters
    ----------
    path : str
        Location of the trajectory file.
    rawdataobject
        Object providing ``method.PVM_TrajIntAll``, ``method.PVM_TrajSamples`` and a
        ``traj`` attribute that receives the result.

    Returns
    -------
    None. ``rawdataobject.traj.data`` is set to a float64 array of shape
    (PVM_TrajSamples, PVM_TrajIntAll, 2). The last axis holds (x, y).
    """
    coord_axes = 2  # interleaved x, y
    grid_shape = (rawdataobject.method.PVM_TrajSamples,
                  rawdataobject.method.PVM_TrajIntAll)

    with open(path, "rb") as traj_handle:
        flat = np.fromfile(traj_handle, dtype='float64', count=-1)

    # Deinterleave: every coord_axes-th value, starting at the axis offset.
    per_axis = [np.reshape(flat[axis::coord_axes], grid_shape, order='F')
                for axis in range(coord_axes)]
    rawdataobject.traj.data = np.stack(per_axis, axis=-1)

    return

def readBrukerFidFile(path, rawdataobject):
    """
    Read Bruker raw fid data (port of the pvtools reader) into ``rawdataobject.fid.data``.

    It lacks some of the advanced functionality of the original pvtools function.

    The sample word type and byte order come from ``acqp.GO_raw_data_format`` and
    ``acqp.BYTORDA``. Unknown combinations fall back to little-endian int32 with a
    printed warning.

    With ``GO_block_size == 'Standard_KBlock_Format'``, each readout block is padded
    to a multiple of 1024 bytes, and the padding is discarded.

    Parameters
    ----------
    path : str
        Path to the fid file.
    rawdataobject
        Object whose ``acqp`` holds the parsed acqp parameters and whose ``fid``
        attribute receives the result.

    Returns
    -------
    None. ``rawdataobject.fid.data`` is set to a complex array of shape
    (receivers, ACQ_size[0] // 2, prod(ACQ_size[1:]) * NI * NR).
    """
    #   Test the presence of all required parameters in acqp
    minCondition = ('GO_raw_data_format', 'BYTORDA', 'NI', 'NR', 'ACQ_size', 'GO_data_save',
                    'GO_block_size', 'AQ_mod')
    all_here = bruker_requires(rawdataobject, minCondition, 'acqp')
    if not all_here:
        print("ERROR: acqp file does not provide enough info")
        sys.exit(1)

    acqp = rawdataobject.acqp
    NI = int(acqp.NI)
    NR = int(acqp.NR)
    acqSamples = int(acqp.ACQ_size[0])
    # int(): np.prod of an empty sequence is the float 1.0, which np.reshape rejects.
    numDataHighDim = int(np.prod(acqp.ACQ_size[1:]))
    numSelectedReceivers = bruker_getSelectedReceivers(rawdataobject)

    #   Get data type and number of bits.
    #   'i2' is used for 16-bit data: 'i' would be a 32-bit C int.
    wordFormats = {
        'GO_32BIT_SGN_INT': ('i4', 32),
        'GO_16BIT_SGN_INT': ('i2', 16),
        'GO_32BIT_FLOAT': ('f4', 32),
    }
    byteOrders = {'little': '<', 'big': '>'}
    word = wordFormats.get(acqp.GO_raw_data_format)
    order = byteOrders.get(acqp.BYTORDA)
    if word is None or order is None:
        dataType = np.dtype('<i4')
        print('Data format not specified correctly, set to int32, little endian')
        bits = 32
    else:
        dataType = np.dtype(order + word[0])
        bits = word[1]
    bytesPerWord = bits // 8

    wordsPerReadout = acqSamples * numSelectedReceivers
    if acqp.GO_block_size == 'Standard_KBlock_Format':
        # Blocks are padded up to a multiple of 1024 bytes; keep the size an int for reshape.
        blockBytes = int(np.ceil(wordsPerReadout * bytesPerWord / 1024.0)) * 1024
        blockSize = blockBytes // bytesPerWord
    else:
        blockSize = wordsPerReadout

    with open(path, "rb") as fidFile:
        fidData = np.fromfile(fidFile, dtype=dataType, count=-1)

    numBlocks = numDataHighDim * NI * NR
    if len(fidData) != blockSize * numBlocks:
        print('Missmatch')

    fidData = np.reshape(fidData, (blockSize, numBlocks), order='F')

    if blockSize != wordsPerReadout:
        # Drop the padding words at the end of each block, then split samples / receivers.
        fidData = np.transpose(fidData, (1, 0))
        fidData = fidData[:, :wordsPerReadout]
        fidData = np.reshape(fidData, (numBlocks, acqSamples, numSelectedReceivers), order='F')
        fidData = np.transpose(fidData, (2, 1, 0))
    else:
        fidData = np.reshape(fidData, (acqSamples, numSelectedReceivers, numBlocks), order='F')
        fidData = np.transpose(fidData, (1, 0, 2))

    # Even entries along the sample axis are real parts, odd entries imaginary parts.
    rawdataobject.fid.data = fidData[:, 0::2, :] + 1j * fidData[:, 1::2, :]

    return

def readBruker2dseq(path, imagedataobject):
    """
    Read Bruker reconstructed image data (2dseq) into ``imagedataobject.data2dseq.data``.

    This is a port of the pvtools reader and lacks some of its advanced functionality.
    In particular, VisuCoreDataSlope / VisuCoreDataOffs are NOT applied, so the stored
    values are returned unscaled.

    The word type and byte order come from ``visu_pars.VisuCoreWordType`` and
    ``visu_pars.VisuCoreByteOrder``. The number of frames is taken as
    ``len(visu_pars.VisuCoreDataSlope)``.

    Parameters
    ----------
    path : str
        Path to the 2dseq file.
    imagedataobject
        Object whose ``visu_pars`` holds the parsed visu_pars parameters and whose
        ``data2dseq`` attribute receives the result.

    Returns
    -------
    None. ``imagedataobject.data2dseq.data`` is set to an array of shape
    (VisuCoreSize[0], VisuCoreSize[1], number of frames) with the file's dtype.

    Raises
    ------
    ValueError
        If VisuCoreWordType is not one of the supported word types.
    """
    #   Test the presence of all required parameters in visu_pars
    minCondition = ('VisuCoreWordType', 'VisuCoreByteOrder', 'VisuCoreSize', 'VisuCoreFrameCount',
                    'VisuCoreDataSlope', 'VisuCoreDataOffs', 'VisuCoreFrameType', 'VisuCoreDim',
                    'VisuCoreDimDesc')
    all_here = bruker_requires(imagedataobject, minCondition, 'visu_pars')
    if not all_here:
        print("ERROR: visu_pars file does not provide enough info")
        sys.exit(1)

    visu = imagedataobject.visu_pars
    VisuCoreWordType = visu.VisuCoreWordType
    VisuCoreByteOrder = visu.VisuCoreByteOrder
    VisuCoreSize = visu.VisuCoreSize
    # The frame count is inferred from the length of VisuCoreDataSlope.
    NI = len(visu.VisuCoreDataSlope)

    #   Get data type
    wordTypes = {
        '_32BIT_SGN_INT': 'int32',
        '_16BIT_SGN_INT': 'int16',
        '_32BIT_FLOAT': 'float32',
        '_8BIT_USGN_INT': 'uint8',
    }
    if VisuCoreWordType not in wordTypes:
        print('Data format not specified correctly!')
        raise ValueError('Unsupported VisuCoreWordType: %r' % (VisuCoreWordType,))
    dataType = np.dtype(wordTypes[VisuCoreWordType])

    #   Apply the byte order; the big-endian test must look at VisuCoreByteOrder.
    if VisuCoreByteOrder == 'littleEndian':
        dataType = dataType.newbyteorder('<')
    elif VisuCoreByteOrder == 'bigEndian':
        dataType = dataType.newbyteorder('>')
    else:
        print('Byte order not specified correctly!')

    #   Read 2dseq file
    with open(path, "rb") as twodseqFile:
        twodseqData = np.fromfile(twodseqFile, dtype=dataType, count=-1)

    # NOTE(review): VisuCoreSize[1] is used as the fastest-varying axis of each frame
    # here; confirm this layout against the 2dseq specification, especially for
    # non-square frames.
    twodseqData = np.reshape(twodseqData, (VisuCoreSize[1], -1), order='F')

    dataOut = np.zeros((VisuCoreSize[1], VisuCoreSize[0], NI), dataType, order='F')
    for i in range(NI):
        dataOut[:, :, i] = twodseqData[:, i * VisuCoreSize[0]:(i + 1) * VisuCoreSize[0]]

    imagedataobject.data2dseq.data = np.transpose(dataOut, (1, 0, 2))
    return


def bruker_requires(dataobject, minCondition, fileType):
    """
    Check that ``dataobject.<fileType>`` provides every parameter in ``minCondition``.

    This is a port of the pvtools function. Every missing parameter is reported with a
    printed error message. If ``dataobject`` has no ``fileType`` attribute at all, every
    parameter is reported as missing.

    Parameters
    ----------
    dataobject
        Object holding parsed parameter files as attributes (e.g. ``acqp``, ``method``).
    minCondition : tuple of str
        Names of the required parameters.
    fileType : str
        Name of the parameter-file attribute on ``dataobject`` (e.g. 'acqp').

    Returns
    -------
    all_here : bool
        True when all required parameters are present.
    """
    all_here = True
    for conditionElement in minCondition:
        try:
            # Plain attribute lookups replace the former string-built eval().
            getattr(getattr(dataobject, fileType), conditionElement)
        except AttributeError:
            print('ERROR: ', fileType, ' file does not contain essential parameter: ', conditionElement)
            all_here = False
    return all_here

def bruker_getSelectedReceivers(rawdataobject):
    """
    Return the number of receiver channels used for the acquisition (pvtools port).

    When ``acqp.ACQ_experiment_mode`` is 'ParallelExperiment', the channels are counted
    as the 'Yes' entries of ``acqp.GO_ReceiverSelect``. If that attribute is absent,
    ``acqp.ACQ_ReceiverSelect`` is used instead. Any other experiment mode means a
    single receiver.

    If the receiver-selection information is missing, empty, or does not start with an
    alphabetic entry, a message is printed and 1 is returned.

    Parameters
    ----------
    rawdataobject
        Object whose ``acqp`` attribute holds the parsed acqp parameters.

    Returns
    -------
    int
        Number of selected receivers.
    """
    acqp = rawdataobject.acqp
    if acqp.ACQ_experiment_mode != 'ParallelExperiment':
        return 1

    if hasattr(acqp, 'GO_ReceiverSelect'):
        selection = acqp.GO_ReceiverSelect
    elif hasattr(acqp, 'ACQ_ReceiverSelect'):
        selection = acqp.ACQ_ReceiverSelect
    else:
        selection = None

    # Treat a lone string as a one-element list so it is not iterated char by char.
    if isinstance(selection, str):
        selection = [selection]

    if selection is None or len(selection) == 0 or not str(selection[0]).isalpha():
        print('Information about number of receivers is unknown, check your acqp file')
        return 1

    return sum(1 for channel in selection if channel == 'Yes')

def fidMethodBasedReshape(rawdataobject):
    """
    Description
    -----------
    Reshape fid data with the handler registered for the pulse program
    (``acqp.PULPROG``). If no handler exists for the given pulse program, a message is
    printed and the data is left unchanged.

    To support a new sequence, add its pulse-program name and handler to ``handlers``.

    Parameters
    ----------
    rawdataobject

    Returns
    -------
    -
    """
    handlers = {
        'UTE.ppg': fidHandle_UTE,
        'FAIR_RARE.ppg': fidHandle_FAIR_RARE,
        'DtiEpi.ppg': fidHandle_Epi,
        'EPI.ppg': fidHandle_Epi,
    }
    handler = handlers.get(rawdataobject.acqp.PULPROG)
    if handler is None:
        # Adjacent literals avoid the stray indentation a backslash continuation embeds.
        print('Function to reshape fid data of this sequence is not developed yet, '
              'your data is in a pv-tools like form')
    else:
        handler(rawdataobject)
    return

def fidHandle_UTE(rawdataobject):
    """
    Placeholder reshaping handler for the UTE pulse program.

    It only checks that ``acqp.NI`` is present and exits the program if it is not.
    ``rawdataobject.fid.data`` is left unchanged.

    Parameters
    ----------
    rawdataobject

    Returns
    -------
    -
    """
    minCondition = ('NI',)
    all_here = bruker_requires(rawdataobject, minCondition, 'acqp')
    if not all_here:
        print("ERROR: acqp file does not provide enough info")
        sys.exit(1)
    # TODO: UTE-specific reshaping of fid data is not implemented yet.
    return

def fidHandle_FAIR_RARE(rawdataobject):
    """
    Description
    -----------
    Reshape fid data for the FAIR_RARE sequence using prior knowledge about data
    storage. This is a draft; further improvements are expected.

    Input ``fid.data`` is indexed as [channel, sample, block]. Consecutive groups of
    ``ACQ_rare_factor`` blocks are taken as the views of one repetition.

    Parameters
    ----------
    rawdataobject
        Needs ``acqp.NR`` and ``acqp.ACQ_rare_factor`` (checked), plus ``fid.data``.

    Returns
    -------
    -   ``fid.data`` is replaced by an array of shape
        (channels, NR, ACQ_rare_factor, samples) with the original dtype.
    """
    minCondition = ('NR', 'ACQ_rare_factor')
    paramFile = 'acqp'
    all_here = bruker_requires(rawdataobject, minCondition, paramFile)
    if not all_here:
        print("ERROR: ", paramFile, " file does not provide enough info")
        sys.exit(1)

    dataIn = rawdataobject.fid.data

    #   Get dataOut dimension parameters, create dataOut np.array
    channels_nr = dataIn.shape[0]
    samples_nr = dataIn.shape[1]
    repetitions_nr = rawdataobject.acqp.NR
    views_nr = rawdataobject.acqp.ACQ_rare_factor
    dataOut = np.zeros((channels_nr, repetitions_nr, views_nr, samples_nr), dataIn.dtype)

    # Every repetition is copied, including NR == 1 (otherwise the output stays zeros).
    for repetition in range(repetitions_nr):
        views = dataIn[:, :, repetition * views_nr:(repetition + 1) * views_nr]
        dataOut[:, repetition, :, :] = views.transpose((0, 2, 1))

    rawdataobject.fid.data = dataOut    # replace
    return

def fidHandle_Epi(rawdataobject):
    """
    Description
    -----------
    Reshape EPI / DtiEpi fid data into a line-by-line form.

    ``fid.data`` is expected as [channel, acquired point, numDataHighDim]. Only the
    last ``PVM_EncMatrix[0] * PVM_EncMatrix[1]`` points of each acquisition are kept.

    Parameters
    ----------
    rawdataobject
        Needs ``acqp.NR`` and ``acqp.NI`` (checked), ``method.PVM_EncMatrix`` and
        ``fid.data``.

    Returns
    -------
    -   ``fid.data`` is replaced by an array of shape
        (PVM_EncMatrix[1], PVM_EncMatrix[0], channels, numDataHighDim).
    """
    minCondition = ('NR', 'NI')
    all_here = bruker_requires(rawdataobject, minCondition, 'acqp')
    if not all_here:
        print("ERROR: acqp file does not provide enough info")
        sys.exit(1)

    minCondition = ('PVM_EncMatrix',)
    all_here = bruker_requires(rawdataobject, minCondition, 'method')
    if not all_here:
        print("ERROR: method file does not provide enough info")
        sys.exit(1)

    dataIn = rawdataobject.fid.data
    channels_nr = bruker_getSelectedReceivers(rawdataobject)
    lines_nr = rawdataobject.method.PVM_EncMatrix[0]
    samples_nr = rawdataobject.method.PVM_EncMatrix[1]
    numDataHighDim = dataIn.shape[2]

    # Keep the trailing lines_nr*samples_nr points; leading points (e.g. navigator
    # data, if present) are discarded.
    dataIn = dataIn[:, -lines_nr * samples_nr:, :]
    dataIn = np.transpose(dataIn, (1, 0, 2))
    # NOTE(review): the C-order reshape makes PVM_EncMatrix[0] ("lines_nr") the
    # fastest-varying index. If PVM_EncMatrix is ordered (read, phase, ...), the
    # variable names here are swapped relative to their meaning; please confirm.
    dataIn = np.reshape(dataIn, (samples_nr, lines_nr, channels_nr, numDataHighDim))

    # export as [samples_nr, lines_nr, channels, numDataHighDim]
    rawdataobject.fid.data = dataIn

    return