import numpy as np
from numpy import empty,array
from flopy.mbase import Package
from flopy.utils import util_2d,util_3d

class Mt3dBtn(Package):
    """
    MT3DMS Basic Transport (BTN) package.

    The constructor converts the array inputs into flopy ``util_2d`` /
    ``util_3d`` objects sized from the companion MODFLOW model
    (``model.mf``) and registers the package with ``model``;
    ``write_file`` writes the BTN input file.

    Parameters
    ----------
    model : object
        Parent MT3D model.  Must expose ``mf`` (the MODFLOW model, used for
        ``nrow_ncol_nlay_nper`` and ``get_package``) and ``add_package``.
    ncomp : int or None, optional
        Number of species, written as NCOMP.  If None and ``sconc`` is a
        scalar, it defaults to 1.
    mcomp : int, optional
        Value written as MCOMP next to NCOMP.
    tunit, lunit, munit : str, optional
        Unit labels, each written in a 4-character field.
    prsity : scalar or array, optional
        Porosity, stored as a float32 (nlay, nrow, ncol) ``util_3d``.
    icbund : scalar or array, optional
        ICBUND, stored as an integer (nlay, nrow, ncol) ``util_3d``.
    sconc : scalar or sequence, optional
        Starting concentration(s).  A scalar is used for every species; a
        sequence must have exactly ``ncomp`` entries, one per species.
    cinact, thkmin : float, optional
        Written together on one record.
    ifmtcn, ifmtnp, ifmtrf, ifmtdp : int, optional
        Print-format codes written together on one record.
    savucn : bool, optional
        Written as 'T' when true, 'F' otherwise.
    nprs : int, optional
        Written as NPRS when ``timprs`` is None.
    timprs : sequence of float or None, optional
        Output times; when given, NPRS is written as their count followed
        by the times themselves.
    obs : array-like with 3 columns, or None, optional
        Observation points; the first three columns of each row are written
        as integers (presumably layer, row, column -- confirm index base).
    nprobs : int, optional
        Written after the number of observation points.
    chkmas : bool, optional
        Written as 'T' when true, 'F' otherwise.
    nprmas : int, optional
        Written next to the CHKMAS flag.
    dt0, mxstrn, ttsmult, ttsmax : scalar or sequence, optional
        Per-stress-period values, stored as length-nper ``util_2d`` arrays.
    species_names : list of str or None, optional
        Stored on the instance; not written by ``write_file``.
    extension : str, optional
        File extension passed to ``Package.__init__``.

    Raises
    ------
    Exception
        If ``sconc`` is not a scalar and ``ncomp`` is None or differs from
        ``len(sconc)``.
    """

    def __init__(self, model, ncomp=None, mcomp=1, tunit='D', lunit='M',
                 munit='KG', prsity=0.30, icbund=1, sconc=0.0,
                 cinact=1e30, thkmin=0.01, ifmtcn=0, ifmtnp=0,
                 ifmtrf=0, ifmtdp=0, savucn=True, nprs=0, timprs=None,
                 obs=None, nprobs=1, chkmas=True, nprmas=1, dt0=0,
                 mxstrn=50000, ttsmult=1.0, ttsmax=0,
                 species_names=None, extension='btn'):
        Package.__init__(self, model, extension, 'BTN', 31)
        nrow, ncol, nlay, nper = self.parent.mf.nrow_ncol_nlay_nper
        shape3d = (nlay, nrow, ncol)
        locat = self.unit_number[0]
        self.heading1 = '# BTN for MT3DMS, generated by Flopy.'
        self.heading2 = '#'
        self.mcomp = mcomp
        self.tunit = tunit
        self.lunit = lunit
        self.munit = munit
        self.cinact = cinact
        self.thkmin = thkmin
        self.ifmtcn = ifmtcn
        self.ifmtnp = ifmtnp
        self.ifmtrf = ifmtrf
        self.ifmtdp = ifmtdp
        self.savucn = savucn
        self.nprs = nprs
        self.timprs = timprs
        self.obs = obs
        self.nprobs = nprobs
        self.chkmas = chkmas
        self.nprmas = nprmas
        # A fresh list per instance: a shared ``[]`` default would be
        # mutated across every Mt3dBtn created without species_names.
        self.species_names = [] if species_names is None else species_names
        # ``int`` replaces the ``np.int`` alias removed in NumPy 1.24
        # (it was always the builtin ``int``).
        self.prsity = util_3d(model, shape3d, np.float32, prsity,
                              name='prsity', locat=locat)
        self.icbund = util_3d(model, shape3d, int, icbund,
                              name='icbund', locat=locat)

        # Starting concentrations: one util_3d per species.
        if np.isscalar(sconc):
            # A scalar applies to every species; default to a single one.
            if ncomp is None:
                ncomp = 1
            sconc = [sconc] * ncomp
        elif ncomp is None or ncomp != len(sconc):
            raise Exception('BTN input error - ncomp not equal to len(sconc)')
        self.ncomp = ncomp
        self.sconc = [
            util_3d(model, shape3d, np.float32, conc,
                    name='sconc' + str(ispecies), locat=locat)
            for ispecies, conc in enumerate(sconc, start=1)
        ]

        # Per-stress-period time-stepping controls.
        self.dt0 = util_2d(model, (nper,), np.float32, dt0, name='dt0')
        self.mxstrn = util_2d(model, (nper,), int, mxstrn, name='mxstrn')
        self.ttsmult = util_2d(model, (nper,), np.float32, ttsmult,
                               name='ttmult')
        self.ttsmax = util_2d(model, (nper,), np.float32, ttsmax,
                              name='ttsmax')
        self.parent.add_package(self)

    @staticmethod
    def _format_f10(value):
        """Return ``value`` right-justified in a 10-character field.

        Uses the largest number of significant digits (at most 9) whose
        text still fits in 10 characters, so fractional values are kept
        instead of being rounded to integers.
        """
        text = ''
        for digits in range(9, 0, -1):
            text = '{0:.{1}G}'.format(float(value), digits)
            if len(text) <= 10:
                break
        return text.rjust(10)

    def _laytyp_array(self, nlay):
        """Return a (nlay,) integer util_2d of layer types for the BTN file.

        Values come from the BCF6 package's ``laycon`` if present, otherwise
        from the LPF package's ``laytyp``.

        Raises
        ------
        Exception
            If the MODFLOW model has neither a BCF6 nor an LPF package.
        """
        mf = self.parent.mf
        flow_package = mf.get_package('BCF6')
        if flow_package is not None:
            values = flow_package.laycon.get_value()
        else:
            flow_package = mf.get_package('LPF')
            if flow_package is None:
                raise Exception('BTN write error - a BCF6 or LPF package is '
                                'required in the MODFLOW model to write '
                                'the layer types')
            values = flow_package.laytyp.get_value()
        lc = util_2d(self.parent, (nlay,), int, values,
                     name='btn - laytype', locat=self.unit_number[0])
        # The layer-type record is always written with a fixed (40I2) format,
        # whatever format util_2d chose by default.
        lc.set_fmtin('(40I2)')
        return lc

    def write_file(self):
        """Write the BTN input file to ``self.fn_path``.

        Grid geometry (DELR, DELC, top, thickness) and stress-period timing
        (PERLEN, NSTP, TSMULT) are taken from the MODFLOW DIS package; layer
        types come from BCF6 (laycon) or LPF (laytyp).

        Raises
        ------
        Exception
            If the MODFLOW model has neither a BCF6 nor an LPF package.
        """
        nrow, ncol, nlay, nper = self.parent.mf.nrow_ncol_nlay_nper
        dis = self.parent.mf.get_package('DIS')
        locat = self.unit_number[0]
        # Resolve layer types before opening the file so a missing flow
        # package does not leave a truncated BTN file behind.
        lc = self._laytyp_array(nlay)

        with open(self.fn_path, 'w') as f_btn:
            f_btn.write('#{0:s}\n#{1:s}\n'.format(self.heading1,
                                                  self.heading2))
            f_btn.write('{0:10d}{1:10d}{2:10d}{3:10d}{4:10d}{5:10d}\n'
                        .format(nlay, nrow, ncol, nper,
                                self.ncomp, self.mcomp))
            f_btn.write('{0:4s}{1:4s}{2:4s}\n'
                        .format(self.tunit, self.lunit, self.munit))

            # One T/F flag per optional transport package, in this order:
            # ADV, DSP, SSM, RCT, GCG.
            for pkg in (self.parent.adv, self.parent.dsp, self.parent.ssm,
                        self.parent.rct, self.parent.gcg):
                f_btn.write('{0:2s}'.format('T' if pkg is not None else 'F'))
            f_btn.write('\n')

            f_btn.write(lc.string)

            delr = util_2d(self.parent, (ncol,), np.float32,
                           dis.delr.get_value(), name='delr', locat=locat)
            f_btn.write(delr.get_file_entry())

            delc = util_2d(self.parent, (nrow,), np.float32,
                           dis.delc.get_value(), name='delc', locat=locat)
            f_btn.write(delc.get_file_entry())

            top = util_2d(self.parent, (nrow, ncol), np.float32,
                          dis.top.array, name='top', locat=locat)
            f_btn.write(top.get_file_entry())

            thickness = util_3d(self.parent, (nlay, nrow, ncol), np.float32,
                                dis.thickness.get_value(),
                                name='thickness', locat=locat)
            f_btn.write(thickness.get_file_entry())

            f_btn.write(self.prsity.get_file_entry())
            f_btn.write(self.icbund.get_file_entry())

            # Starting concentrations, one array per species.
            for species_conc in self.sconc:
                f_btn.write(species_conc.get_file_entry())

            f_btn.write('{0:10.0E}{1:10.4f}\n'.format(self.cinact,
                                                      self.thkmin))

            f_btn.write('{0:10d}{1:10d}{2:10d}{3:10d}'
                        .format(self.ifmtcn, self.ifmtnp,
                                self.ifmtrf, self.ifmtdp))
            f_btn.write('{0:>10s}\n'.format('T' if self.savucn else 'F'))

            # NPRS and, if given, the output times.  ``is None`` (not
            # ``== None``) so ndarray input is not compared elementwise.
            if self.timprs is None:
                f_btn.write('{0:10d}\n'.format(self.nprs))
            else:
                times = np.atleast_1d(np.asarray(self.timprs,
                                                 dtype=np.float64)).ravel()
                f_btn.write('{0:10d}\n'.format(len(times)))
                # Eight 10-character fields per line (the 8F10.0 layout).
                # Times are real values: the previous integer array
                # truncated fractional output times.
                for start in range(0, len(times), 8):
                    f_btn.write(''.join(self._format_f10(t)
                                        for t in times[start:start + 8]))
                    f_btn.write('\n')

            # Observation points.
            if self.obs is None:
                f_btn.write('{0:10d}{1:10d}\n'.format(0, self.nprobs))
            else:
                obs = np.atleast_2d(self.obs)
                f_btn.write('{0:10d}{1:10d}\n'.format(obs.shape[0],
                                                      self.nprobs))
                for row in obs:
                    f_btn.write('{0:10d}{1:10d}{2:10d}\n'
                                .format(row[0], row[1], row[2]))

            # CHKMAS, NPRMAS
            f_btn.write('{0:>10s}{1:10d}\n'.format(
                'T' if self.chkmas else 'F', self.nprmas))

            # Per stress period: PERLEN, NSTP, TSMULT from DIS, then
            # DT0, MXSTRN, TTSMULT, TTSMAX from this package.
            for t in range(nper):
                f_btn.write('{0:10.4g}{1:10d}{2:10f}\n'
                            .format(dis.perlen[t], dis.nstp[t],
                                    dis.tsmult[t]))
                f_btn.write('{0:10.4g}{1:10d}{2:10f}{3:10f}\n'
                            .format(self.dt0[t], self.mxstrn[t],
                                    self.ttsmult[t], self.ttsmax[t]))