#!/usr/bin/python
# emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
#
#   See COPYING file distributed along with the NiBabel package for the
#   copyright and license terms.
#
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
# Copyright (C) 2011 Christian Haselgrove

import sys
import os
import stat
import errno
import time
import locale
import fuse
import nibabel.dft as dft

# Owner reported for every entry of the mounted tree (see DICOMFS.getattr).
uid = os.getuid()
gid = os.getgid()
# Encoding used to decode the DICOM directory given on the command line.
# locale.getdefaultlocale() reports None as the encoding when it cannot
# determine one; str.decode(None) raises TypeError, so fall back to the
# filesystem encoding and finally to ASCII.
encoding = (locale.getdefaultlocale()[1]
            or sys.getfilesystemencoding()
            or 'ascii')

# Declare which python-fuse API version this script is written against;
# this is set at import time, before the DICOMFS class below is defined.
fuse.fuse_python_api = (0, 2)

class FileHandle:
    """Handle object returned by DICOMFS.open() for one open file.

    ``fno`` is the integer key under which the file's content is stored in
    ``DICOMFS.fhs``.  ``keep_cache`` and ``direct_io`` are always False here;
    presumably they are read by python-fuse from the returned handle --
    TODO confirm against the python-fuse documentation.
    """

    def __init__(self, fno):
        self.fno = fno
        # Both flags are fixed; nothing in this file changes them.
        self.keep_cache = self.direct_io = False

    def __str__(self):
        # Used by the debugging print statements in DICOMFS.
        return 'FileHandle(%d)' % (self.fno,)

class DICOMFS(fuse.Fuse):

    def __init__(self, *args, **kwargs):
        fuse.Fuse.__init__(self, *args, **kwargs)
        self.fhs = {}
        return

    def get_paths(self):
        paths = {}
        for study in dft.get_studies(self.dicom_path):
            pd = paths.setdefault(study.patient_name_or_uid(), {})
            patient_info = 'patient information\n'
            patient_info = 'name: %s\n' % study.patient_name
            patient_info += 'ID: %s\n' % study.patient_id
            patient_info += 'birth date: %s\n' % study.patient_birth_date
            patient_info += 'sex: %s\n' % study.patient_sex
            pd['INFO'] = patient_info.encode('ascii', 'replace')
            study_datetime = '%s_%s' % (study.date, study.time)
            study_info = 'study info\n'
            study_info += 'UID: %s\n' % study.uid
            study_info += 'date: %s\n' % study.date
            study_info += 'time: %s\n' % study.time
            study_info += 'comments: %s\n' % study.comments
            d = {'INFO': study_info.encode('ascii', 'replace')}
            for series in study.series:
                series_info = 'series info\n'
                series_info += 'UID: %s\n' % series.uid
                series_info += 'number: %s\n' % series.number
                series_info += 'description: %s\n' % series.description
                series_info += 'rows: %d\n' % series.rows
                series_info += 'columns: %d\n' % series.columns
                series_info += 'bits allocated: %d\n' % series.bits_allocated
                series_info += 'bits stored: %d\n' % series.bits_stored
                series_info += 'storage instances: %d\n' % len(series.storage_instances)
                d[series.number] = {'INFO': series_info.encode('ascii', 'replace'), 
                                    '%s.nii' % series.number: (series.nifti_size, series.as_nifti), 
                                    '%s.png' % series.number: (series.png_size, series.as_png)}
            pd[study_datetime] = d
        return paths

    def match_path(self, path):
        wd = self.get_paths()
        if path == '/':
            print 'return root'
            return wd
        for part in path.lstrip('/').split('/'):
            print path, part
            if part not in wd:
                return None
            wd = wd[part]
        print 'return'
        return wd

    def readdir(self, path, fh):
        print 'readdir'
        print path
        matched_path = self.match_path(path)
        if matched_path is None:
            return -errno.ENOENT
        print 'match', matched_path
        fnames = [ k.encode('ascii', 'replace') for k in matched_path.keys() ]
        fnames.append('.')
        fnames.append('..')
        return [ fuse.Direntry(f) for f in fnames ]
    
    def getattr(self, path):
        print 'getattr'
        print 'path:', path
        matched_path = self.match_path(path)
        print matched_path
        now = time.time()
        st = fuse.Stat()
        if isinstance(matched_path, dict):
            st.st_mode = stat.S_IFDIR | 0755
            st.st_ctime = now
            st.st_mtime = now
            st.st_atime = now
            st.st_uid = uid
            st.st_gid = gid
            st.st_nlink = len(matched_path)
            return st
        if isinstance(matched_path, str):
            st.st_mode = stat.S_IFREG | 0644
            st.st_ctime = now
            st.st_mtime = now
            st.st_atime = now
            st.st_uid = uid
            st.st_gid = gid
            st.st_size = len(matched_path)
            st.st_nlink = 1
            return st
        if isinstance(matched_path, tuple):
            st.st_mode = stat.S_IFREG | 0644
            st.st_ctime = now
            st.st_mtime = now
            st.st_atime = now
            st.st_uid = uid
            st.st_gid = gid
            st.st_size = matched_path[0]()
            st.st_nlink = 1
            return st
        return -errno.ENOENT
    
    def open(self, path, flags):
        print 'open'
        print path
        matched_path = self.match_path(path)
        if matched_path is None:
            return -errno.ENOENT
        for i in xrange(1, 10):
            if i not in self.fhs:
                if isinstance(matched_path, str):
                    self.fhs[i] = matched_path
                elif isinstance(matched_path, tuple):
                    self.fhs[i] = matched_path[1]()
                else:
                    raise -errno.EFTYPE
                return FileHandle(i)
        raise -errno.ENFILE

    # not done
    def read(self, path, size, offset, fh):
        print 'read'
        print path
        print size
        print offset
        print fh
        return self.fhs[fh.fno][offset:offset+size]

    def release(self, path, flags, fh):
        print 'release'
        print path
        print fh
        del self.fhs[fh.fno]
        return
    
# Script entry point: <program> <directory containing DICOMs> <mount point>
progname = os.path.basename(sys.argv[0])
if len(sys.argv) != 3:
    sys.stderr.write('usage: %s <directory containing DICOMs> <mount point>\n' % progname)
    sys.exit(1)

fs = DICOMFS(dash_s_do='setsingle')
# Hand FUSE a fixed option list plus the mount point.  -f and -s are
# presumably the standard FUSE "foreground" and "single-threaded" switches
# -- confirm against the FUSE option reference.
fs.parse(['-f', '-s', sys.argv[2]])
# The locale may report no encoding (None), which str.decode() rejects with
# a TypeError; fall back to the filesystem encoding, then to ASCII.
fs.dicom_path = sys.argv[1].decode(encoding
                                   or sys.getfilesystemencoding()
                                   or 'ascii')
try:
    fs.main()
except fuse.FuseError:
    # fuse prints the error message
    sys.exit(1)

sys.exit(0)

# eof
