"""
Generate FITS image headers (primary and image-extension) for simulated CSST CCD frames.
"""
import numpy as np

from astropy.io import fits
import astropy.wcs as pywcs

from scipy import math
import random

import os
import sys

def chara2digit(char):
    """ Convert a text field to a float when it represents a number.

    Parameters
    ----------
    char : str
        Raw field text, e.g. the value column of a header parameter file.

    Returns
    -------
    float or str
        ``float(char)`` if the text is accepted by ``float`` (integers,
        decimals, exponent notation, surrounding whitespace), otherwise
        ``char`` unchanged.
    """
    try:
        return float(char)
    except ValueError:
        # Not numeric: keep the original text (e.g. a string-valued keyword).
        return char


def read_header_parameter(filename='global_header.param'):
    """ Read header keyword definitions from a '|'-separated parameter file.

    Each non-blank line must contain at least three '|'-separated fields,
    ``NAME|VALUE|DESCRIPTION``; any further fields are ignored. Values that
    parse as numbers are converted to ``float`` via ``chara2digit``; all
    other values are kept as strings.

    Parameters
    ----------
    filename : str
        Path to the parameter file.

    Returns
    -------
    name, value, description : list, list, list
        Parallel lists with one entry per non-blank line of the file.
    """
    name = []
    value = []
    description = []
    # Context manager so the handle is closed even if a line is malformed.
    with open(filename) as param_file:
        for line in param_file:
            # Strip both LF and CR so CRLF files don't leave '\r' in the description.
            line = line.rstrip("\r\n")
            if not line.strip():
                # Skip blank lines (e.g. a trailing empty line) instead of
                # failing with IndexError on the missing fields.
                continue
            arr = line.split('|')
            name.append(arr[0])
            value.append(chara2digit(arr[1]))
            description.append(arr[2])

    return name, value, description

def rotate_CD_matrix(cd, pa_aper):
    """Rotate a CD matrix by a detector position angle.

    Parameters
    ----------
    cd: (2,2) array
        CD matrix

    pa_aper: float
        Position angle, in degrees E from N, of y axis of the detector

    Returns
    -------
    cd_rot: (2,2) array
        Rotated CD matrix, ``R @ cd``, where ``R`` is the standard 2-D
        rotation matrix for an angle of ``-pa_aper`` degrees.

    Comments
    --------
    `astropy.wcs.WCS.rotateCD` doesn't work for non-square pixels in that it
    doesn't preserve the pixel scale!  The bug seems to come from the fact
    that `rotateCD` assumes a transposed version of its own CD matrix.

    """
    theta = np.deg2rad(-pa_aper)
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    # Always float64, matching an explicitly float64-allocated rotation matrix.
    rotation = np.array([[cos_t, -sin_t],
                         [sin_t, cos_t]], dtype=np.float64)
    return np.dot(rotation, cd)


def Header_extention(xlen = 9232, ylen = 9216, gain = 1.0, readout = 5.0, dark = 0.02,saturation=90000, row_num = 1, col_num = 1):

    """ Build the keyword list for a CSST image-extension header.

    Only one extension (amplifier number 1) is generated.

    Parameters
    ----------
    xlen, ylen : int or float
        Detector size in pixels along x and y (converted with ``int``);
        used for ``DATASEC`` and the LTV/ATV offsets.
    gain : float
        Written to ``GAIN`` (e-/ADU).
    readout : float
        Written to ``RDNOISE`` (e-/pixel).
    dark : float
        Written to ``DARK`` (e-/pixel/s).
    saturation : float
        Written to ``SATURATE`` (e-).
    row_num, col_num : int
        1-based row and column of the CCD; the CCD number used in
        ``CCDNAME``/``AMPNAME``/``CCDCHIP`` is ``(row_num - 1) * 5 + col_num``.

    Returns
    -------
    name, value, description : list, list, list
        Parallel lists of FITS keywords, values and comments.
    """

    # Orientation flags indexed by amplifier number (index 0 unused).
    flag_ltm_x = [0,1,-1,1,-1]
    flag_ltm_y = [0,1,1,-1,-1]
    flag_ltv_x = [0,0,1,0,1]
    flag_ltv_y = [0,0,0,1,1]

    # Only amplifier/extension number 1 is generated.
    k = 1

    detector_size_x = int(xlen)
    detector_size_y = int(ylen)
    data_sec = '[1:' + str(detector_size_x) + ',1:' + str(detector_size_y) + ']'

    # CCDs are numbered row by row, 5 per row.
    ccdnum = str((row_num - 1) * 5 + col_num)
    ccd_name = 'ccd' + ccdnum.rjust(2, '0')

    ltm1 = flag_ltm_x[k]
    ltm2 = flag_ltm_y[k]
    # NOTE(review): the 20*2 term presumably removes 20-pixel prescan/overscan
    # columns on each side -- confirm against the detector layout.
    ltv1 = flag_ltv_x[k] * (detector_size_x - 20 * 2 + 1)
    ltv2 = flag_ltv_y[k] * (detector_size_y + 1)

    # One (keyword, value, comment) triple per card keeps the three fields aligned.
    cards = [
        ('EXTNAME', 'IM' + str(k), 'Extension name'),
        ('BSCALE', 1.0, ' '),
        ('BZERO', 0.0, ' '),
        ('OBSID', 'CSST.20200101T000000', 'Observation ID'),
        ('CCDNAME', ccd_name, 'CCD name'),
        ('AMPNAME', ccd_name + ':' + str(k), 'Amplifier name'),
        ('GAIN', gain, 'Gain (e-/ADU)'),
        ('RDNOISE', readout, 'Readout noise (e-/pixel)'),
        ('DARK', dark, 'Dark noise (e-/pixel/s)'),
        ('SATURATE', saturation, 'Saturation (e-)'),
        ('RSPEED', 10.0, 'Read speed'),
        ('CHIPTEMP', -100.0, 'Chip temperature'),
        ('CCDCHIP', ccd_name, 'CCD chip ID'),
        ('DATASEC', data_sec, 'Data section'),
        ('CCDSUM', '1 1', 'CCD pixel summing'),
        ('NSUM', '1 1', 'CCD pixel summing'),
        ('LTM1_1', ltm1, 'CCD to image transformation'),
        ('LTM2_2', ltm2, 'CCD to image transformation'),
        ('LTV1', ltv1, 'CCD to image transformation'),
        ('LTV2', ltv2, 'CCD to image transformation'),
        ('ATM1_1', ltm1, 'CCD to amplifier transformation'),
        ('ATM2_2', ltm2, 'CCD to amplifier transformation'),
        ('ATV1', ltv1, 'CCD to amplifier transformation'),
        ('ATV2', ltv2, 'CCD to amplifier transformation'),
        ('DTV1', 0, 'CCD to detector transformation'),
        ('DTV2', 0, 'CCD to detector transformation'),
        ('DTM1_1', 1, 'CCD to detector transformation'),
        ('DTM2_2', 1, 'CCD to detector transformation'),
    ]

    name = [card[0] for card in cards]
    value = [card[1] for card in cards]
    description = [card[2] for card in cards]
    return name, value, description


# Typical arguments (xlen ylen gapx gapy1 gapy2 ra dec pa): 9232 9216 898 534 1309 60 -40 -23.4333
def WCS_def(xlen = 9232, ylen = 9216, gapx = 898.0, gapy1 = 534, gapy2 = 1309, ra = 60, dec = -40, pa = -23.433,psize = 0.074, row_num = 1, col_num = 1):

    """ Build TAN-projection WCS keywords for one CCD of the CSST mosaic.

    The focal plane is modelled as 5 columns x 6 rows of CCDs separated by
    gaps. ``CRVAL1``/``CRVAL2`` are set to (``ra``, ``dec``), and
    ``CRPIX1``/``CRPIX2`` place that reference point at the centre of the
    whole mosaic, expressed in the pixel frame of the selected CCD.

    Parameters
    ----------
    xlen, ylen : float
        Detector size in pixels along x and y.
    gapx : float
        Horizontal gap between adjacent CCD columns, in pixels.
    gapy1, gapy2 : float
        The two vertical gap sizes between CCD rows, in pixels; their
        placement per column is given by ``gap_y_map`` below.
    ra, dec : float
        Reference sky coordinates written to ``CRVAL1``/``CRVAL2``.
    pa : float
        Position angle passed to ``rotate_CD_matrix``.
    psize : float
        Pixel scale; divided by 3600 to build the CD matrix (i.e. arcsec/pixel).
    row_num, col_num : int
        1-based row and column of the CCD in the mosaic.

    Returns
    -------
    name, value, description : list, list, list
        Parallel lists of FITS keywords, values and comments.
    """

    # Orientation flags indexed by amplifier number (index 0 unused);
    # only amplifier 1 is generated.
    flag_x = [0, 1, -1, 1, -1]
    flag_ext_x = [0,-1,1,-1,1]
    flag_ext_y = [0,-1,-1,1,1]
    k = 1

    x_num = 5   # CCD columns in the mosaic
    y_num = 6   # CCD rows in the mosaic
    # Number of gapy1- and gapy2-sized gaps stacked along each column.
    gap_y1_num = 3
    gap_y2_num = 2

    detector_size_x = xlen
    detector_size_y = ylen
    gap_x = gapx
    gap_y = [gapy1,gapy2]
    pixel_size = psize

    # Centre of the full mosaic in mosaic pixel coordinates.
    x_center = (detector_size_x*x_num+gap_x*(x_num-1))/2
    y_center = (detector_size_y*y_num+gap_y[0]*gap_y1_num+gap_y[1]*gap_y2_num)/2

    # gap_y_map[r, c] for r >= 1 is the vertical gap between CCD rows r and
    # r+1 in column c+1; row 0 is zero so the cumulative sum for row 1 is 0.
    gap_y_map = np.array([[0, 0, 0, 0, 0],
                          [gap_y[0], gap_y[1], gap_y[1], gap_y[1], gap_y[1]],
                          [gap_y[1], gap_y[0], gap_y[0], gap_y[0], gap_y[0]],
                          [gap_y[0], gap_y[0], gap_y[0], gap_y[0], gap_y[0]],
                          [gap_y[0], gap_y[0], gap_y[0], gap_y[0], gap_y[1]],
                          [gap_y[1], gap_y[1], gap_y[1], gap_y[1], gap_y[0]]])

    j = row_num
    i = col_num

    # Centre of the selected CCD in mosaic pixel coordinates.
    x_ref = (detector_size_x+gap_x)*i-gap_x-detector_size_x/2
    y_ref = detector_size_y*j + sum(gap_y_map[0:j,i-1]) - detector_size_y/2

    # Pixel scale converted from arcsec to degrees, then rotated by the position angle.
    cd = np.array([[ pixel_size,  0], [0, pixel_size]])/3600.*flag_x[k]
    cd_rot = rotate_CD_matrix(cd, pa)

    cards = [
        ('EQUINOX', 2000.0, 'Equinox of WCS'),
        ('WCSDIM', 2.0, 'WCS Dimensionality'),
        ('CTYPE1', 'RA---TAN', 'Coordinate type'),
        ('CTYPE2', 'DEC--TAN', 'Coordinate type'),
        ('CRVAL1', ra, 'Coordinate reference value'),
        ('CRVAL2', dec, 'Coordinate reference value'),
        # Mosaic centre measured from the CCD's lower-left edge (for amplifier 1).
        ('CRPIX1', flag_ext_x[k]*((x_ref+flag_ext_x[k]*detector_size_x/2)-x_center), 'Coordinate reference pixel'),
        ('CRPIX2', flag_ext_y[k]*((y_ref+flag_ext_y[k]*detector_size_y/2)-y_center), 'Coordinate reference pixel'),
        ('CD1_1', cd_rot[0,0], 'Coordinate matrix'),
        ('CD1_2', cd_rot[0,1], 'Coordinate matrix'),
        ('CD2_1', cd_rot[1,0], 'Coordinate matrix'),
        ('CD2_2', cd_rot[1,1], 'Coordinate matrix'),
    ]

    name = [card[0] for card in cards]
    value = [card[1] for card in cards]
    description = [card[2] for card in cards]
    return name, value, description



def generatePrimaryHeader(xlen = 9232, ylen = 9216,pointNum = '1', ra = 60, dec = -40, psize = 0.074, row_num = 1, col_num = 1):
    """ Build the primary FITS header for one simulated CSST CCD frame.

    Keyword names, default values and comments are read from
    ``global_header.param`` in this module's directory. A few keywords are
    then overridden from the arguments:

    * ``FILTER``   -- entry ``k - 1`` of the first line of ``filter.lst``
      (space separated), where ``k = (row_num - 1) * 5 + col_num``;
    * ``FILENAME`` -- ``CSST_<date>_<time>_<pointNum padded to 6>_<k padded to 2>_raw``;
    * ``DETSIZE``  -- ``[1:xlen,1:ylen]``;
    * ``PIXSCAL1``/``PIXSCAL2`` -- ``str(psize)``.

    Section-separator COMMENT cards are added after ``FILETYPE``,
    ``EQUINOX``, ``MJDEND``, ``REFFRAME`` and ``FILTER``, so those keywords
    must be defined in the parameter file.

    Parameters
    ----------
    xlen, ylen : int or float
        Detector size used for ``DETSIZE``.
    pointNum : str or int
        Pointing number, zero-padded to 6 characters in ``FILENAME``.
    ra, dec : float
        Currently unused; accepted for interface compatibility.
    psize : float
        Pixel scale written (as a string) to ``PIXSCAL1``/``PIXSCAL2``.
    row_num, col_num : int
        1-based row and column of the CCD.

    Returns
    -------
    astropy.io.fits.Header
    """
    module_dir = os.path.dirname(os.path.realpath(__file__))

    # Only the first line of filter.lst is used: one filter name per CCD.
    with open(os.path.join(module_dir, 'filter.lst')) as filter_file:
        filters = filter_file.readline().strip("\n").split(' ')

    k = (row_num-1)*5+col_num
    ccdnum = str(k)

    name, value, description = read_header_parameter(os.path.join(module_dir, 'global_header.param'))

    h_prim = fits.Header()

    # NOTE(review): observation date/time are hard-coded placeholders.
    date = '200930'
    time_obs = '120000'

    for key, val, comment in zip(name, value, description):
        if key == 'FILTER':
            val = filters[k-1]
        elif key == 'FILENAME':
            # str() lets callers pass the pointing number as an int as well.
            val = 'CSST_' + date + '_' + time_obs + '_' + str(pointNum).rjust(6,'0') + '_' + ccdnum.rjust(2,'0') + '_raw'
        elif key == 'DETSIZE':
            val = '[1:' + str(int(xlen)) + ',1:' + str(int(ylen)) + ']'
        elif key in ('PIXSCAL1', 'PIXSCAL2'):
            val = str(psize)

        h_prim[key] = (val, comment)

    separator = '=================================================================='
    # Each title is added without an anchor; astropy places it after the last
    # existing COMMENT card, i.e. right after the separator just inserted.
    sections = (('FILETYPE', 'Target information'),
                ('EQUINOX', 'Exposure information'),
                ('MJDEND', 'Telescope information'),
                ('REFFRAME', 'Detector information'),
                ('FILTER', 'Other information'))
    for anchor, title in sections:
        h_prim.add_comment(separator, after=anchor)
        h_prim.add_comment(title)
        h_prim.add_comment(separator)

    return h_prim

def generateExtensionHeader(xlen = 9232, ylen = 9216,ra = 60, dec = -40, pa = -23.433, gain = 1.0, readout = 5.0, dark = 0.02, saturation=90000, psize = 0.074, row_num = 1, col_num = 1, gapx = 898.0, gapy1 = 534, gapy2 = 1309):
    """ Build the image-extension FITS header for one CSST CCD.

    It combines ``NAXIS1``/``NAXIS2``, the readout and chip keywords from
    ``Header_extention`` and the WCS keywords from ``WCS_def``. It then adds
    section-separator COMMENT cards after ``OBSID``, ``CHIPTEMP`` and
    ``DTM2_2``.

    Parameters
    ----------
    xlen, ylen : int
        Image size in pixels (``NAXIS1`` = pixels per row, ``NAXIS2`` = rows).
    ra, dec, pa, psize :
        Forwarded to ``WCS_def``.
    gain, readout, dark, saturation :
        Forwarded to ``Header_extention``.
    row_num, col_num : int
        1-based CCD row and column, forwarded to both helpers.
    gapx, gapy1, gapy2 : float
        Mosaic gap sizes in pixels, forwarded to ``WCS_def``. The defaults
        match the previously hard-coded values.

    Returns
    -------
    astropy.io.fits.Header
    """
    h_ext = fits.Header()

    # NAXIS1: number of pixels per row; NAXIS2: number of rows.
    h_ext['NAXIS1'] = xlen
    h_ext['NAXIS2'] = ylen

    name, value, description = Header_extention(xlen = xlen, ylen = ylen, gain = gain, readout = readout, dark = dark, saturation=saturation, row_num = row_num, col_num = col_num)
    for key, val, comment in zip(name, value, description):
        h_ext[key] = (val, comment)

    name, value, description = WCS_def(xlen = xlen, ylen = ylen, gapx = gapx, gapy1 = gapy1, gapy2 = gapy2, ra = ra, dec = dec, pa = pa ,psize = psize, row_num = row_num, col_num = col_num)
    for key, val, comment in zip(name, value, description):
        h_ext[key] = (val, comment)

    separator = '=================================================================='
    # Each title is added without an anchor; astropy places it after the last
    # existing COMMENT card, i.e. right after the separator just inserted.
    sections = (('OBSID', 'Readout information'),
                ('CHIPTEMP', 'Chip information'),
                ('DTM2_2', 'WCS information'))
    for anchor, title in sections:
        h_ext.add_comment(separator, after=anchor)
        h_ext.add_comment(title)
        h_ext.add_comment(separator)

    return h_ext



def main(argv):
    """ Write a single-extension CSST FITS file from command-line arguments.

    Parameters
    ----------
    argv : sequence of str
        ``sys.argv``-style list. ``argv[0]`` is ignored; ``argv[1]`` to
        ``argv[13]`` are, in order: xlen, ylen, pointing number, ra, dec,
        pixel size, CCD row, CCD column, position angle, gain, readout
        noise, dark current and full well (saturation).

    The file is written to ``<FILENAME>.fits`` relative to the current
    working directory. ``FILENAME`` comes from the generated primary
    header, and the image data are zeros of shape ``(ylen, xlen)``.
    """
    xlen, ylen = int(argv[1]), int(argv[2])
    pointing_num = argv[3]
    ra, dec, pix_size = float(argv[4]), float(argv[5]), float(argv[6])
    ccd_row, ccd_col = int(argv[7]), int(argv[8])
    pa_aper, gain, readout, dark, full_well = (float(argv[n]) for n in range(9, 14))

    primary = generatePrimaryHeader(xlen=xlen, ylen=ylen, ra=ra, dec=dec, psize=pix_size,
                                    row_num=ccd_row, col_num=ccd_col, pointNum=pointing_num)
    extension = generateExtensionHeader(xlen=xlen, ylen=ylen, ra=ra, dec=dec, pa=pa_aper,
                                        gain=gain, readout=readout, dark=dark,
                                        saturation=full_well, psize=pix_size,
                                        row_num=ccd_row, col_num=ccd_col)

    hdus = fits.HDUList([
        fits.PrimaryHDU(header=primary),
        fits.ImageHDU(np.zeros([ylen, xlen]), header=extension),
    ])
    hdus.writeto(primary['FILENAME'] + '.fits', output_verify='ignore')

# if __name__ == "__main__":
#     main(sys.argv)