#!/usr/bin/env python

#$Id$
# -------------------------------------------------------------
# Module metadata.  The slices are meant to strip the VCS keyword wrapper from
# '$Revision$' / '$Date$' once the keywords are expanded; while they are
# unexpanded (as in this copy) both slices evaluate to ''.
# NOTE(review): the common idiom is [11:-2] / [7:-2]; with an expanded
# '$Revision: 1234 $' the [11:-3] slice would drop the last character --
# confirm against the keyword format of the VCS actually in use.
__version__      = '$Revision$'[11:-3]
__version_date__ = '$Date$'[7:-3]
__author__       = 'R. Bouwens, <bouwens@ucolick.org>, D. Magee, <magee@ucolick.org>'


import numpy as N
import ndimage as nd
import pyfits,os,imagestats,glob,sys
import shutil 

def measurenoise(wghtdata):
    """Return the peak of *wghtdata* sampled on a coarse grid.

    Only every 10th pixel along each of the two indexed axes is inspected,
    which keeps the scan cheap on large weight maps.  ``findbkg`` uses the
    returned value as the reference level for its weight and clip thresholds.
    """
    coarse = wghtdata[::10, ::10]
    return N.amax(coarse.ravel())

def _neighbour_median(Median, NContr, i, j, minpix):
    """Return the background value for tile (i, j) taken from nearby tiles.

    Starting from the 3x3 block of tiles centred on (i, j), take the median
    of the values in *Median* whose contributor count in *NContr* exceeds
    *minpix*.  The window grows by one tile on every side until at least two
    such tiles give a non-NaN median.

    Once the window spans the whole grid, growing it cannot add tiles.  If it
    still fails at that point, fall back to the plain median of the finite
    qualifying values, or 0.0 when there are none.  0.0 is also the value an
    under-sampled tile gets in ``findbkg``.  Without this fallback the search
    would never terminate, e.g. for a grid holding a single tile.
    """
    ngridx, ngridy = Median.shape
    sizem = 1
    while True:
        xl = max(i - sizem, 0)
        xh = min(i + sizem + 1, ngridx)
        yl = max(j - sizem, 0)
        yh = min(j + sizem + 1, ngridy)
        good = NContr[xl:xh, yl:yh] > minpix
        curdata = Median[xl:xh, yl:yh][good]
        if len(curdata) > 1:
            stats = imagestats.ImageStats(curdata.astype(N.float32),
                                          fields='median', binwidth=0.00005)
            if not N.isnan(stats.median):
                return stats.median
        if xl == 0 and yl == 0 and xh == ngridx and yh == ngridy:
            finite = curdata[N.isfinite(curdata)]
            if len(finite) > 0:
                return float(N.median(finite))
            return 0.0
        sizem += 1


def _edge_pad(grid):
    """Return *grid* surrounded by a one-cell border that replicates its edges.

    Works for every grid shape, including a single row or column; there the
    previous if/elif bookkeeping left the far border at 0.  Only basic slicing
    is used, so no particular numpy version is required.
    """
    nx, ny = grid.shape
    padded = N.zeros((nx + 2, ny + 2), dtype=N.float64)
    padded[1:nx + 1, 1:ny + 1] = grid
    # Rows first, then columns, so the corners take the diagonal neighbour.
    padded[0, :] = padded[1, :]
    padded[nx + 1, :] = padded[nx, :]
    padded[:, 0] = padded[:, 1]
    padded[:, ny + 1] = padded[:, ny]
    return padded


def findbkg(data,flatdata,wghtdata,clip,size):
    """Estimate a smooth background model for *data*.

    Steps:
      1. Build a mask of background pixels from ``noise = measurenoise(wghtdata)``.
         A pixel is kept when its weight is above ``0.2 * noise``, *flatdata*
         is below ``clip / sqrt(noise)``, and the 3x3 median-filtered
         *flatdata* is below ``clip / 2 / sqrt(noise)``.  The mask is also
         written to ``dq.fits`` in the current directory.
      2. Cut *data* into ``size`` x ``size`` tiles.  Pixels beyond the last
         full tile along an axis feed no tile.  Each tile takes the imagestats
         median of its masked pixels, or 0 when it has 2 or fewer of them.
      3. Replace every tile value by the median of nearby tiles that have more
         than ``0.6 * size**2`` masked pixels (see ``_neighbour_median``).
      4. Pad the tile grid with one edge-replicated cell and resample it to the
         full shape of *data* with ``nd.affine_transform``.  Output pixels
         whose weight is not above ``0.001 * noise`` are set to 0.

    Parameters
    ----------
    data : 2-D array
        Image whose background is measured.
    flatdata : 2-D array
        Image used to build the mask.  ``flatten`` passes *data* itself on the
        first pass and *data* minus the first estimate on the second.
        Presumably the same shape as *data* -- not checked.
    wghtdata : 2-D array
        Weight map matching *data*.
    clip : float
        Rejection threshold, scaled by ``1 / sqrt(noise)``.
    size : int
        Tile size in pixels.

    Returns
    -------
    Background model array of shape ``data.shape``.
    """
    goodfact = 0.6
    noise = measurenoise(wghtdata)

    thresh = clip/noise**0.5
    ndq = (wghtdata > noise*0.2) * (flatdata < thresh)
    # Additionally require the 3x3 median-filtered image to be below half
    # the threshold.
    flatdata_med = nd.median_filter(flatdata.astype(N.float32),size=3)
    ndq *= (flatdata_med < clip/2/noise**0.5)

    # Diagnostic dump of the mask; overwritten on every call.
    pyfits.writeto('dq.fits',ndq.astype(N.float32),clobber=True)
    (xdim,ydim) = data.shape
    # Floor division: identical to Python 2 integer '/', and keeps the grid
    # sizes integral (needed by range() and N.zeros) under Python 3 as well.
    ngridx = xdim // size
    ngridy = ydim // size

    # Per-tile median of the masked pixels, and how many pixels it used.
    Median = N.zeros((ngridx,ngridy), dtype=N.float64)
    NContr = N.zeros((ngridx,ngridy), dtype=N.int32)
    for i in range(ngridx):
        xl = i*size
        xh = min(xl+size, xdim)
        for j in range(ngridy):
            yl = j*size
            yh = min(yl+size, ydim)
            curdata = data[xl:xh,yl:yh][ndq[xl:xh,yl:yh] == 1]
            NContr[i,j] = len(curdata)
            if NContr[i,j] > 2:
                stats = imagestats.ImageStats(curdata.astype(N.float32), fields='median',binwidth=0.00005)
                Median[i,j] = stats.median
            # otherwise the tile keeps its initial 0

    # Smooth/fill each tile from its well-sampled neighbours.  The unmodified
    # Median grid is read, so results do not depend on visiting order.
    minpix = size*size*goodfact
    Filled = N.zeros((ngridx,ngridy), dtype=N.float64)
    for i in range(ngridx):
        for j in range(ngridy):
            Filled[i,j] = _neighbour_median(Median, NContr, i, j, minpix)
    MedianM = _edge_pad(Filled)

    # Output pixel (x, y) samples MedianM at (x/size + 0.5, y/size + 0.5).
    # The centre of tile i (pixel ~ i*size + size/2) therefore lands near
    # padded cell i+1, which holds that tile's value.
    newmedian = nd.affine_transform(MedianM.astype(N.float32),[[1./size,0],[0,1./size]],output_shape=(xdim,ydim),offset=(0.5,0.5))
    # Zero the model where the weight is (close to) zero.
    ndq = (wghtdata > noise*0.001)
    newmedian *= ndq
    return newmedian

def flatten(sciname, wghtname, outname, size=100):
    """Subtract a two-pass background model from a science image.

    1. A first model is estimated with ``findbkg`` (clip 5) using the raw
       primary-HDU data to build its mask.
    2. A second, stricter pass (clip 2.5) builds its mask from the data with
       the first model removed.

    The second model is written to a sky image next to *sciname* and
    subtracted from the data.  The result is written to *outname*, which is
    overwritten.

    Parameters
    ----------
    sciname : str
        Science image; its primary HDU is processed.
    wghtname : str
        Matching weight image.
    outname : str
        Destination for the background-subtracted image.
    size : int, optional
        Tile size in pixels passed to ``findbkg`` (default 100, the value that
        was previously hard-coded).
    """
    newf = pyfits.open(sciname)
    try:
        data = newf[0].data
        wghtdata = pyfits.getdata(wghtname)
        cmedian = findbkg(data,data,wghtdata,5.,size=size)
        flatdata = data - cmedian
        cmedian = findbkg(data,flatdata,wghtdata,2.5,size=size)
        data -= cmedian

        # Never let the sky image overwrite the science image.  A plain
        # replace() would return sciname unchanged when it lacks '_sci.fits',
        # so in that case insert '_sky' before the extension instead.
        if '_sci.fits' in sciname:
            skyname = sciname.replace('_sci.fits', '_sky.fits')
        else:
            root, ext = os.path.splitext(sciname)
            skyname = root + '_sky' + ext
        skyheader = pyfits.getheader(sciname)
        print('Writing out sky image %s' % skyname)
        pyfits.writeto(skyname,cmedian,skyheader,clobber=True)
        newf.writeto(outname,clobber=True)
    finally:
        # Release the input file even if background estimation or writing fails.
        newf.close()

# Batch driver: background-subtract, in place, every drizzled science image
# in the current directory that matches the two glob patterns below.  Each
# image's weight map is found by replacing '_sci.fits' with '_weight.fits'.
FileList = glob.glob('HUDF09J?_F125W_wfc3ir_drz_sci.fits')+glob.glob('HUDF09H?_F160W_wfc3ir_drz_sci.fits')
for FN in FileList:
    print(FN)
    # flatten() writes its result to the scratch file A.fits ...
    flatten(FN,FN.replace('_sci.fits','_weight.fits'),'A.fits')
    # ... which is then copied over the original science image.  copyfile
    # copies contents only, like `cp` onto an existing file.  Unlike the
    # previous os.system call, it involves no shell and raises on failure
    # instead of continuing silently.
    shutil.copyfile('A.fits', FN)

