"""
Generate FITS primary and extension headers for CSST MSC simulated CCD images.
"""
import numpy as np

from astropy.io import fits
import astropy.wcs as pywcs
from collections import OrderedDict

# from scipy import math
import random

import os
import sys
import astropy.coordinates as coord
from astropy.time import Time

def chara2digit(char):
    """Convert a text token to a float when it parses as a number.

    Parameters
    ----------
    char : str
        Token read from a header-parameter file.

    Returns
    -------
    float or str
        ``float(char)`` if the token parses as a number (int, long or float
        syntax), otherwise the token unchanged.
    """
    try:
        number = float(char)
    except ValueError:
        # Not numeric: keep the original text.
        return char
    return number


def read_header_parameter(filename='global_header.param'):
    """Read ``name|value|description`` records from a header-parameter file.

    Parameters
    ----------
    filename : str
        Path to a text file with one ``|``-separated record per line.

    Returns
    -------
    name, value, description : list, list, list
        Parallel lists, one entry per line. ``value`` entries are passed
        through ``chara2digit``, so numeric fields become floats.

    Raises
    ------
    IndexError
        If a line has fewer than three ``|``-separated fields.
    """
    name = []
    value = []
    description = []
    # The context manager closes the file even if a malformed line raises.
    with open(filename) as param_file:
        for line in param_file:
            arr = line.strip("\n").split('|')
            name.append(arr[0])
            value.append(chara2digit(arr[1]))
            description.append(arr[2])
    return name, value, description

def rotate_CD_matrix(cd, pa_aper):
    """Rotate a CD matrix by a position angle.

    Parameters
    ----------
    cd : (2, 2) array
        CD matrix.
    pa_aper : float
        Position angle, in degrees E from N, of the y axis of the detector.

    Returns
    -------
    cd_rot : (2, 2) array
        ``R(-pa_aper) . cd``, where ``R(t)`` is the 2-D rotation matrix
        ``[[cos t, -sin t], [sin t, cos t]]``.

    Notes
    -----
    `astropy.wcs.WCS.rotateCD` doesn't work for non-square pixels in that it
    doesn't preserve the pixel scale! The bug seems to come from the fact
    that `rotateCD` assumes a transposed version of its own CD matrix.
    """
    theta = np.deg2rad(-pa_aper)
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    rotation = np.array([[cos_t, -sin_t],
                         [sin_t, cos_t]])
    return rotation.dot(cd)

def calcaluteSLSRotSkyCoor(pix_xy=None, rot_angle=1, xlen=9216, ylen=9232, w=None):
    """Rotate a pixel position about the detector centre, then map it to the sky.

    Parameters
    ----------
    pix_xy : (2,) array
        Pixel position ``(x, y)``.
    rot_angle : float
        Rotation angle in degrees, applied with the matrix
        ``[[cos, -sin], [sin, cos]]``.
    xlen, ylen : int
        Detector size in pixels; the rotation centre is ``(xlen/2, ylen/2)``.
    w : WCS-like
        Object providing ``wcs_pix2world`` (e.g. ``astropy.wcs.WCS``).

    Returns
    -------
    (1, 2) array
        World coordinates of the rotated pixel, from
        ``w.wcs_pix2world(..., 1)`` (origin 1).
    """
    theta = np.deg2rad(rot_angle)
    c, s = np.cos(theta), np.sin(theta)
    rotation = np.array([[c, -s], [s, c]])
    centre = np.array([xlen / 2, ylen / 2])
    offset = pix_xy - centre
    rotated = rotation.dot(offset) + centre
    return w.wcs_pix2world(np.array([rotated]), 1)


# def Header_extention(xlen = 9216, ylen = 9232, gain = 1.0, readout = 5.0, dark = 0.02,saturation=90000, row_num = 1, col_num = 1):
#
#     """ Creat an image frame for CCST with multiple extensions
#
#     Parameters
#     ----------
#
#     """
#
#     flag_ltm_x = [0,1,-1,1,-1]
#     flag_ltm_y = [0,1,1,-1,-1]
#     flag_ltv_x = [0,0,1,0,1]
#     flag_ltv_y = [0,0,0,1,1]
#
#     detector_size_x = int(xlen)
#     detector_size_y = int(ylen)
#
#     data_x = str(int(detector_size_x))
#     data_y = str(int(detector_size_y))
#
#     data_sec = '[1:'+data_x+',1:'+data_y+']'
#     e_header_fn = os.path.split(os.path.realpath(__file__))[0] + '/extension_header.param'
#     name, value, description = read_header_parameter(e_header_fn)
#     f = open(os.path.split(os.path.realpath(__file__))[0] + '/filter.lst')
#     s = f.readline()
#     s = s.strip("\n")
#     filters = s.split(' ')
#     s = f.readline()
#     s = s.strip("\n")
#     filterID = s.split()
#
#     s = f.readline()
#     s = s.strip("\n")
#     CCDID = s.split()
#
#     k = (row_num-1)*6+col_num
#
#     h_iter = 0
#     for n1,v1,d1 in zip(name, value, description):
#         if n1=='EXTNAME':
#             value[h_iter] = 'RAW,'+CCDID[k-1].rjust(2,'0')
#         if n1=='CCDNAME':
#             value[h_iter] = 'ccd' + CCDID[k-1].rjust(2,'0')
#         if n1=='AMPNAME':
#             value[h_iter] = 'ccd' + CCDID[k-1].rjust(2,'0') + ':A'
#         if n1=='GAIN':
#             value[h_iter] = gain
#         if n1=='RDNOISE':
#             value[h_iter] = readout
#         if n1=='SATURATE':
#             value[h_iter] = saturation
#         if n1=='CCDCHIP':
#             value[h_iter] = 'ccd' + CCDID[k-1].rjust(2,'0')
#         if n1=='CCDLABEL':
#             value[h_iter] = filters[k-1] + '-' + filterID[k-1]
#         if n1=='DATASEC':
#             value[h_iter] = data_sec
#
#         h_iter = h_iter + 1
#
#
#     return name, value, description


##9232 9216  898 534 1309 60 -40  -23.4333
def WCS_def(xlen = 9216, ylen = 9232, gapy = 898.0, gapx1 = 534, gapx2 = 1309, ra = 60, dec = -40, pa = -23.433,psize = 0.074, row_num = 1, col_num = 1, filter = 'GI'):
    """Build the WCS keywords for one CCD of the CSST focal-plane mosaic.

    The mosaic is modelled as 6 detector columns x 5 detector rows separated
    by gaps. CRVAL is set to ``(ra, dec)`` and CRPIX is the position of the
    mosaic centre expressed in the selected CCD's pixel coordinates.

    Parameters
    ----------
    xlen, ylen : int
        Detector size in pixels along x and y.
    gapy : float
        Gap between detector rows, in pixels.
    gapx1, gapx2 : float
        The two gap widths between detector columns, in pixels; which one
        applies between two columns depends on the row (see ``gap_x_map``).
    ra, dec : float
        Reference sky coordinates in degrees (CRVAL1/CRVAL2).
    pa : float
        Position angle in degrees, passed to ``rotate_CD_matrix``.
    psize : float
        Pixel scale in arcsec per pixel.
    row_num, col_num : int
        1-based row and column of the CCD in the mosaic.
    filter : str
        Filter name. For ``'GU'``, ``'GV'`` and ``'GI'`` the WCS is refitted
        after rotating the detector by +1 deg (columns 1-2) or -1 deg
        (columns 3-6) about its centre.

    Returns
    -------
    OrderedDict
        Keys EQUINOX, WCSDIM, CTYPE1/2, CRVAL1/2, CRPIX1/2, CD1_1..CD2_2.
    """
    # Sign flags indexed by extension number; only extension 1 is used.
    flag_x = [0, 1, -1, 1, -1]
    flag_ext_x = [0, -1, 1, -1, 1]
    flag_ext_y = [0, -1, -1, 1, 1]
    ext = 1

    # Mosaic layout: 6 detector columns x 5 detector rows.
    x_num = 6
    y_num = 5
    gap_x = [gapx1, gapx2]
    gap_x1_num = 3  # number of gapx1-sized column gaps across the mosaic width
    gap_x2_num = 2  # number of gapx2-sized column gaps across the mosaic width

    # Mosaic centre in mosaic pixel coordinates.
    y_center = (ylen * y_num + gapy * (y_num - 1)) / 2
    x_center = (xlen * x_num + gap_x[0] * gap_x1_num + gap_x[1] * gap_x2_num) / 2

    # gap_x_map[c, r]: gap to the left of 0-based detector column c in
    # 0-based detector row r; column 0 has no gap to its left.
    gap_x_map = np.array([[0, 0, 0, 0, 0],
                          [gap_x[0], gap_x[1], gap_x[1], gap_x[1], gap_x[1]],
                          [gap_x[1], gap_x[0], gap_x[0], gap_x[0], gap_x[0]],
                          [gap_x[0], gap_x[0], gap_x[0], gap_x[0], gap_x[0]],
                          [gap_x[0], gap_x[0], gap_x[0], gap_x[0], gap_x[1]],
                          [gap_x[1], gap_x[1], gap_x[1], gap_x[1], gap_x[0]]])

    # Centre of the selected CCD in mosaic pixel coordinates.
    j = row_num
    i = col_num
    x_ref = xlen * i + sum(gap_x_map[0:i, j - 1]) - xlen / 2.
    y_ref = (ylen + gapy) * j - gapy - ylen / 2

    cd = np.array([[psize, 0], [0, psize]]) / 3600. * flag_x[ext]
    cd_rot = rotate_CD_matrix(cd, pa)

    r_dat = OrderedDict()
    r_dat['EQUINOX'] = 2000.0
    r_dat['WCSDIM'] = 2.0
    r_dat['CTYPE1'] = 'RA---TAN'
    r_dat['CTYPE2'] = 'DEC--TAN'
    r_dat['CRVAL1'] = ra
    r_dat['CRVAL2'] = dec
    r_dat['CRPIX1'] = flag_ext_x[ext] * ((x_ref + flag_ext_x[ext] * xlen / 2) - x_center)
    r_dat['CRPIX2'] = flag_ext_y[ext] * ((y_ref + flag_ext_y[ext] * ylen / 2) - y_center)
    r_dat['CD1_1'] = cd_rot[0, 0]
    r_dat['CD1_2'] = cd_rot[0, 1]
    r_dat['CD2_1'] = cd_rot[1, 0]
    r_dat['CD2_2'] = cd_rot[1, 1]

    if filter in ['GU', 'GV', 'GI']:
        from astropy import wcs
        from astropy.coordinates import SkyCoord
        from astropy.wcs.utils import fit_wcs_from_points

        w = wcs.WCS(naxis=2)
        w.wcs.crpix = [r_dat['CRPIX1'], r_dat['CRPIX2']]
        w.wcs.cd = cd_rot
        w.wcs.crval = [ra, dec]
        w.wcs.ctype = [r_dat['CTYPE1'], r_dat['CTYPE2']]

        # Columns 1-2 rotate by +1 deg, columns 3-6 by -1 deg.
        sls_rot = -1 if i > 2 else 1

        # Sample a 30 x 30 grid of pixel positions (x varies fastest),
        # rotate each about this CCD's centre and record its sky position.
        sn_x = 30
        sn_y = 30
        x_pixs = np.tile(np.linspace(1, xlen, sn_x), sn_y)
        y_pixs = np.repeat(np.linspace(1, ylen, sn_y), sn_x)
        sky_coors = []
        for x, y in zip(x_pixs, y_pixs):
            # Pass the real detector size so the rotation centre is the
            # centre of this CCD, not of a default 9216x9232 detector.
            sc1 = calcaluteSLSRotSkyCoor(pix_xy=np.array([x, y]), rot_angle=sls_rot,
                                         xlen=xlen, ylen=ylen, w=w)
            sky_coors.append((sc1[0, 0], sc1[0, 1]))

        # NOTE(review): the sky positions above were computed with origin=1
        # (1-based pixels); confirm the pixel convention fit_wcs_from_points
        # assumes for ``xy`` -- CRPIX may be off by one pixel.
        wcs_new = fit_wcs_from_points(xy=np.array([x_pixs, y_pixs]),
                                      world_coords=SkyCoord(sky_coors, frame="icrs", unit="deg"),
                                      projection='TAN')

        r_dat['CD1_1'] = wcs_new.wcs.cd[0, 0]
        r_dat['CD1_2'] = wcs_new.wcs.cd[0, 1]
        r_dat['CD2_1'] = wcs_new.wcs.cd[1, 0]
        r_dat['CD2_2'] = wcs_new.wcs.cd[1, 1]
        r_dat['CRPIX1'] = wcs_new.wcs.crpix[0]
        r_dat['CRPIX2'] = wcs_new.wcs.crpix[1]
        r_dat['CRVAL1'] = wcs_new.wcs.crval[0]
        r_dat['CRVAL2'] = wcs_new.wcs.crval[1]

    return r_dat



# TODO: project_cycle is a temporary parameter that is not defined in the header template; remove it in the future.
def generatePrimaryHeader(xlen = 9216, ylen = 9232, pointNum = '1', ra = 60, dec = -40, psize = 0.074, row_num = 1, col_num = 1, date='200930', time_obs='120000', im_type = 'SCI', exptime=150., sat_pos = (0., 0., 0.), sat_vel = (0., 0., 0.), project_cycle=6):
    """Build the primary FITS header for one CCD exposure.

    The header template is read from ``global_header.header`` next to this
    module, and the CCD ID from the third line of ``filter.lst``.

    Parameters
    ----------
    xlen, ylen, psize :
        Accepted for interface compatibility; not used here.
    pointNum : str
        Pointing number, zero-padded into OBJECT/OBSID and FILENAME.
    ra, dec : float
        Pointing coordinates written to OBJ_RA/OBJ_DEC and RA/DEC_PNT0/1.
    row_num, col_num : int
        1-based CCD row/column; selects the CCD ID for FILENAME.
    date : str
        Observation date as ``'YYMMDD'`` (the century is assumed to be 20xx).
    time_obs : str
        Observation start time as ``'HHMMSS'``.
    im_type : str
        One of ``'SCI'``, ``'BIAS'``, ``'DARK'``, ``'FLAT'``, ``'CRS'``,
        ``'CRD'``. The previous default ``'MS'`` was not a valid key and
        always raised KeyError.
    exptime : float
        Exposure time in seconds.
    sat_pos, sat_vel : sequence of 3 floats
        Values written to POSI0_X/Y/Z and VELO0_X/Y/Z.
    project_cycle : int
        Temporary; embedded in OBJECT/OBSID.

    Returns
    -------
    astropy.io.fits.Header

    Raises
    ------
    KeyError
        If ``im_type`` is not one of the supported types.
    """
    # 1-based CCD index in the 6-column mosaic.
    k = (row_num - 1) * 6 + col_num

    module_dir = os.path.split(os.path.realpath(__file__))[0]
    g_header_fn = module_dir + '/global_header.header'
    # Only the third line of filter.lst (whitespace-separated CCD IDs) is
    # needed here; the file is closed when the block exits.
    with open(module_dir + '/filter.lst') as f:
        f.readline()
        f.readline()
        CCDID = f.readline().strip("\n").split()

    h_prim = fits.Header.fromfile(g_header_fn)

    # ISO-8601 'YYYY-MM-DDTHH:MM:SS' built from 'YYMMDD' and 'HHMMSS'.
    obs_datetime = ('20' + date[0:2] + '-' + date[2:4] + '-' + date[4:6] + 'T'
                    + time_obs[0:2] + ':' + time_obs[2:4] + ':' + time_obs[4:6])
    h_prim['DATE'] = obs_datetime
    h_prim['DATE-OBS'] = obs_datetime
    h_prim['OBJ_RA'] = ra
    h_prim['OBJ_DEC'] = dec
    obs_id = '1' + str(int(project_cycle)) + pointNum.rjust(7, '0')
    h_prim['OBJECT'] = obs_id
    h_prim['OBSID'] = obs_id
    h_prim['EXPTIME'] = exptime

    # Map image-type codes to OBSTYPE values.
    file_type = {'SCI': 'sci', 'BIAS': 'bias', 'DARK': 'dark', 'FLAT': 'flat', 'CRS': 'cosmic_ray', 'CRD': 'cosmic_ray'}
    h_prim['OBSTYPE'] = file_type[im_type]

    h_prim['RA_PNT0'] = ra
    h_prim['DEC_PNT0'] = dec
    h_prim['RA_PNT1'] = ra
    h_prim['DEC_PNT1'] = dec

    # Exposure start/end as MJD rounded to 5 decimals.
    tstart = Time(h_prim['DATE'])
    h_prim['EXPSTART'] = round(tstart.mjd, 5)
    h_prim['CABSTART'] = h_prim['EXPSTART']
    tend = Time(tstart.mjd + h_prim['EXPTIME'] / 86400., format="mjd")
    h_prim['EXPEND'] = round(tend.mjd, 5)
    h_prim['CABEND'] = h_prim['EXPEND']

    # FILENAME embeds start/end times as 'YYYYMMDDHHMMSS'.
    file_start_time = '20' + date[0:6] + time_obs[0:6]
    end_time_str = str(tend.datetime)
    file_end_time = (end_time_str[0:4] + end_time_str[5:7] + end_time_str[8:10]
                     + end_time_str[11:13] + end_time_str[14:16] + end_time_str[17:19])
    h_prim['FILENAME'] = ('CSST_MSC_MS_' + im_type + '_' + file_start_time + '_' + file_end_time
                          + '_1' + pointNum.rjust(8, '0') + '_' + CCDID[k - 1].rjust(2, '0') + '_L0_1')

    h_prim['POSI0_X'] = sat_pos[0]
    h_prim['POSI0_Y'] = sat_pos[1]
    h_prim['POSI0_Z'] = sat_pos[2]

    h_prim['VELO0_X'] = sat_vel[0]
    h_prim['VELO0_Y'] = sat_vel[1]
    h_prim['VELO0_Z'] = sat_vel[2]

    # Record the installed CSSTSim package version.
    from pkg_resources import get_distribution
    h_prim['FITSCREA'] = get_distribution("CSSTSim").version

    return h_prim

def generateExtensionHeader(xlen = 9216, ylen = 9232,ra = 60, dec = -40, pa = -23.433, gain = 1.0, readout = 5.0, dark = 0.02, saturation=90000, psize = 0.074, row_num = 1, col_num = 1, extName='SCI'):
    """Build the image-extension FITS header for one CCD.

    The header template is read from ``extension_header.header`` next to this
    module. Filter names, filter IDs and CCD IDs come from the first three
    lines of ``filter.lst``, and the WCS keywords from ``WCS_def``.

    Parameters
    ----------
    xlen, ylen : int
        Image size in pixels (NAXIS1/NAXIS2).
    ra, dec, pa, psize :
        Passed to ``WCS_def``; ``psize`` is also written to PIXSCAL1/2.
    gain, readout : float
        Written to every GAIN1..GAIN16 and RDNOIS1..RDNOIS16 keyword.
    dark, saturation :
        Accepted for interface compatibility; not used here.
    row_num, col_num : int
        1-based CCD row/column in the 6-column mosaic.
    extName : str
        EXTNAME value.

    Returns
    -------
    astropy.io.fits.Header
    """
    module_dir = os.path.split(os.path.realpath(__file__))[0]
    e_header_fn = module_dir + '/extension_header.header'
    # The file is closed when the block exits.
    with open(module_dir + '/filter.lst') as f:
        filters = f.readline().strip("\n").split(' ')
        filterID = f.readline().strip("\n").split()
        CCDID = f.readline().strip("\n").split()

    # 1-based CCD index in the 6-column mosaic.
    k = (row_num - 1) * 6 + col_num

    h_ext = fits.Header.fromfile(e_header_fn)

    h_ext['CCDCHIP'] = CCDID[k - 1].rjust(2, '0')
    h_ext['CCDLABEL'] = filters[k - 1] + '-' + filterID[k - 1]
    h_ext['FILTER'] = filters[k - 1]
    h_ext['NAXIS1'] = xlen
    h_ext['NAXIS2'] = ylen
    h_ext['EXTNAME'] = extName
    # Same value for all 16 GAINn keywords, then all 16 RDNOISn keywords
    # (two loops keep the original keyword insertion order).
    for n in range(1, 17):
        h_ext['GAIN' + str(n)] = gain
    for n in range(1, 17):
        h_ext['RDNOIS' + str(n)] = readout

    h_ext['PIXSCAL1'] = psize
    h_ext['PIXSCAL2'] = psize

    header_wcs = WCS_def(xlen=xlen, ylen=ylen, gapy=898.0, gapx1=534, gapx2=1309, ra=ra, dec=dec, pa=pa, psize=psize,
                         row_num=row_num, col_num=col_num, filter=h_ext['FILTER'])

    for key in ('CRPIX1', 'CRPIX2', 'CRVAL1', 'CRVAL2',
                'CD1_1', 'CD1_2', 'CD2_1', 'CD2_2',
                'EQUINOX', 'WCSDIM', 'CTYPE1', 'CTYPE2'):
        h_ext[key] = header_wcs[key]

    return h_ext



def main(argv):
    """Write a single-CCD FITS file with generated primary and extension headers.

    Parameters
    ----------
    argv : sequence of str
        Command-line style arguments; ``argv[0]`` is ignored, followed by
        xlen ylen pointingNum ra dec pixel_size ccd_row ccd_col pa gain
        readout dark full_well.

    The output is written to ``<FILENAME>.fits`` in the current directory,
    where FILENAME comes from the primary header.
    """
    xlen = int(argv[1])
    ylen = int(argv[2])
    pointingNum = argv[3]
    ra = float(argv[4])
    dec = float(argv[5])
    pSize = float(argv[6])
    ccd_row_num = int(argv[7])
    ccd_col_num = int(argv[8])
    pa_aper = float(argv[9])
    gain = float(argv[10])
    readout = float(argv[11])
    dark = float(argv[12])
    fw = float(argv[13])

    # Pass im_type explicitly: the original default 'MS' is not a valid
    # OBSTYPE key in generatePrimaryHeader and raises KeyError.
    h_prim = generatePrimaryHeader(xlen=xlen, ylen=ylen, ra=ra, dec=dec, psize=pSize, row_num=ccd_row_num,
                                   col_num=ccd_col_num, pointNum=pointingNum, im_type='SCI')

    h_ext = generateExtensionHeader(xlen=xlen, ylen=ylen, ra=ra, dec=dec, pa=pa_aper, gain=gain, readout=readout,
                                    dark=dark, saturation=fw, psize=pSize, row_num=ccd_row_num, col_num=ccd_col_num)
    hdu1 = fits.PrimaryHDU(header=h_prim)
    hdu2 = fits.ImageHDU(np.zeros([ylen, xlen]), header=h_ext)

    hdul = fits.HDUList([hdu1, hdu2])

    hdul.writeto(h_prim['FILENAME'] + '.fits', output_verify='ignore')

# if __name__ == "__main__":
#     main(sys.argv)