#!/usr/bin/env python
import os
import sys
import string
import numpy as np
import nibabel as nib
from subprocess import call
from optparse import OptionParser
from dipy.reconst.dti import Tensor
from dipy.io.utils import nifti1_symmat
from dipy.io.bvectxt import read_bvec_file, orientation_to_string
from nibabel.trackvis import empty_header, write

# Command-line interface: declares the options understood by the script.
# Help strings must name the flags exactly as they are declared, since they
# are what users read from --help.
usage = """fit_tensor [options] dwi_images"""
parser = OptionParser(usage)
parser.add_option("-b", "--bvec", help="text file with gradient directions")
parser.add_option("-r", "--root", help="root for files to be saved")
parser.add_option("-m", "--mask", default="BET",
                  help="use BET by default, --mask=none to not use mask")
# The threshold stays a string because it is forwarded verbatim on the BET
# command line.
parser.add_option("--threshold", help="threshold passed to BET", default='.2')
parser.add_option("--min_signal", help="minimum valid signal value",
                  type='float', default=1.)
parser.add_option("--save_tensor", action='store_true', help="Save tensor in "
                        "nifti symmat format")
parser.add_option("--scale", type='float', default=1., help="used to scale "
                        "tensor file when --save_tensor is used")
# Parse the command line: `opts` holds the option values declared above and
# `args` the remaining positional arguments.
opts, args = parser.parse_args()

def dipysave(img, filename):
    """Save `img` to `filename` with identical sform and qform transforms.

    Some DTI tools require the qform code to be 1, so the image's affine is
    written to both the sform and the qform, each with code 1, for maximum
    portability. Note that `img` itself is modified before it is saved.
    """
    aff = img.get_affine()
    # Same affine and same code for both transforms; sform first, then qform.
    for set_xform in (img.set_sform, img.set_qform):
        set_xform(aff, 1)
    nib.save(img, filename)

# Exactly one positional argument is accepted: the DWI image file. Anything
# else prints the help text and exits with status 2.
if len(args) == 1:
    dwi_file = args[0]
else:
    parser.print_help()
    parser.exit(2)

# Output root: every file written by this script is named root + suffix.
if opts.root is None:
    # Default to the DWI path with everything from the first extension
    # separator of the file name removed, e.g. "sub/dwi.nii.gz" -> "sub/dwi".
    # The str.split method is used rather than string.split, which no longer
    # exists in Python 3.
    directory, basename = os.path.split(dwi_file)
    root = os.path.join(directory, basename.split(os.path.extsep, 1)[0])
else:
    root = opts.root

# Gradient file: defaults to root + '.bvec' (note this follows --root when
# that option is given).
if opts.bvec is None:
    bvec = root + '.bvec'
else:
    bvec = opts.bvec

# Load the DWI volume together with its affine and the first three zooms
# (kept as the voxel size).
img = nib.load(dwi_file)
data = img.get_data()
affine = img.get_affine()
header = img.get_header()
voxel_size = header.get_zooms()[:3]
# `bvec` is rebound here from the gradient file name to the first value
# returned by read_bvec_file; `bval` is the second.
bvec, bval = read_bvec_file(bvec)

# Average the volumes whose bval is not greater than zero and save the result
# as <root>_t2di.nii.gz.
where_dwi = bval > 0
b0_mean = data[..., ~where_dwi].mean(axis=-1)
b0_mean = np.asarray(b0_mean, 'float32')
b0_path = root + '_t2di.nii.gz'
dipysave(nib.Nifti1Image(b0_mean, affine), b0_path)
del b0_mean, b0_path

# Resolve the mask option. The default value 'BET' means: run FSL's bet2 on
# the saved <root>_t2di image and use the mask file it writes.
mask = opts.mask
if mask == 'BET':
    # Give bet2 a private copy of the environment rather than modifying
    # os.environ for the rest of this process.
    env = dict(os.environ, FSLOUTPUTTYPE='NIFTI_GZ')
    bet_mask_file = root + '_mask.nii.gz'
    # opts.threshold is forwarded as the value of the -f flag.
    bet_status = call(['bet2', root+'_t2di', root, '-n', '-f', opts.threshold,
                       '-m'], env=env)
    # Check the exit status as well as the file: a mask left over from an
    # earlier run must not hide a failed BET invocation.
    if bet_status != 0 or not os.path.exists(bet_mask_file):
        raise RuntimeError("There was a problem running BET")
    mask = bet_mask_file

# Turn `mask` (currently a string) into a boolean volume.
if mask.lower() == 'none':
    # No mask requested: keep voxels where any bval == 0 volume is positive.
    mask = data[..., bval == 0].max(axis=-1) > 0
else:
    # Otherwise `mask` is a file name; keep the voxels that are positive in it.
    mask = nib.load(mask).get_data() > 0

# Fit the tensor model inside the mask.
ten = Tensor(data, bval, bvec.T, mask, min_signal=opts.min_signal)

# Optionally write the tensor itself, scaled by --scale, as a float32 nifti
# symmat image.
if opts.save_tensor:
    tril = ten.lower_triangular()
    # Scale in place, then down-cast for storage.
    tril *= opts.scale
    tril = tril.astype('float32')
    symmat_img = nifti1_symmat(tril, affine)
    dipysave(symmat_img, root + '_tensor.nii.gz')
    del symmat_img, tril

# Split ten.evals along its last axis into three eigenvalue volumes
# (presumably ordered largest first -- confirm against the Tensor class).
ev1, ev2, ev3 = np.rollaxis(ten.evals, -1)
# The first eigenvalue is saved as <root>_L1 and the mean of the other two as
# <root>_rd, both as float32.
radial = np.asarray((ev2 + ev3) / 2, 'float32')
axial = np.asarray(ev1, 'float32')
dipysave(nib.Nifti1Image(axial, affine), root + '_L1.nii.gz')
dipysave(nib.Nifti1Image(radial, affine), root + '_rd.nii.gz')
del ev1, ev2, ev3, axial, radial

# Save the result of ten.md() as <root>_md.
mean_diff = np.asarray(ten.md(), "float32")
md_img = nib.Nifti1Image(mean_diff, affine)
dipysave(md_img, root + '_md.nii.gz')
del md_img, mean_diff

# Save the result of ten.fa(0) as float32; it must not contain NaNs.
aniso = np.asarray(ten.fa(0), 'float32')
assert not np.any(np.isnan(aniso))
dipysave(nib.Nifti1Image(aniso, affine), root + '_fa.nii.gz')

# Direction-encoded FA: the FA value times the absolute components of
# ten.evecs[..., 0] (presumably the principal eigenvector -- confirm against
# the Tensor class).
rgb = np.abs(aniso[..., np.newaxis] * ten.evecs[..., 0])
# Scale by a factor just below 256 so every value fits in a uint8.
rgb *= 256 * (1. - np.finfo(float).eps)
assert rgb.max() < 256
assert rgb.min() >= 0
rgb = rgb.astype('uint8')
# Reinterpret the three trailing uint8 values of each voxel as one
# (R, G, B) record, which removes the trailing axis.
dtype = [('R', 'uint8'),
         ('G', 'uint8'),
         ('B', 'uint8')]
rgb = rgb.view(dtype).reshape(rgb.shape[:-1])
dipysave(nib.Nifti1Image(rgb, affine), root + '_dirFA.nii.gz')
del aniso, rgb

# Write a TrackVis file whose header carries the geometry of the fitted
# volume: voxel order derived from the affine, dimensions and voxel size.
voxel_order = orientation_to_string(nib.io_orientation(affine))
trk_hdr = empty_header()
trk_hdr['dim'] = ten.shape
trk_hdr['voxel_size'] = voxel_size
trk_hdr['voxel_order'] = voxel_order
# A single streamline made of two points, both at [0, 0, 0].
origin_pair = np.zeros((2, 3), dtype='float32')
dummy_track = [(origin_pair, None, None)]
write(root + '_dummy.trk', dummy_track, trk_hdr)
