"""
generate image header
"""
import numpy as np

from astropy.io import fits
import astropy.wcs as pywcs
from collections import OrderedDict

# from scipy import math
import random

import os
import sys
import astropy.coordinates as coord
from astropy.coordinates import SkyCoord
from astropy.wcs.utils import fit_wcs_from_points
from astropy.time import Time
from astropy import wcs
from observation_sim.config._util import get_obs_id, get_file_type

from datetime import datetime, timezone
# import socket
import platform
import toml


def chara2digit(char):
    """Convert a string to a float when it represents a number.

    Parameters
    ----------
    char : str
        Text to convert, e.g. a value field read from a parameter file.

    Returns
    -------
    float or str
        ``float(char)`` if ``char`` parses as a float (``float`` accepts
        surrounding whitespace and int/float/exponent notation), otherwise
        ``char`` unchanged.
    """
    # Parse once; non-numeric text is returned as-is.
    try:
        return float(char)
    except ValueError:
        return char


def read_header_parameter(filename='global_header.param'):
    """Read a '|'-separated header parameter file.

    Each non-blank line has the form ``NAME|VALUE|DESCRIPTION``. VALUE is
    converted with :func:`chara2digit` (a float when numeric, otherwise the
    original string). Fields beyond the third are ignored.

    Parameters
    ----------
    filename : str
        Path of the parameter file.

    Returns
    -------
    name, value, description : list, list, list
        Parallel lists with one entry per parameter line.

    Raises
    ------
    IndexError
        If a non-blank line has fewer than three '|'-separated fields.
    """
    name = []
    value = []
    description = []
    # Context manager so the file is closed even if a line is malformed.
    with open(filename) as param_file:
        for line in param_file:
            # Strip both LF and CR so CRLF files don't leave '\r' in the
            # description field.
            line = line.rstrip("\r\n")
            if not line.strip():
                # Tolerate blank lines (e.g. extra empty lines at EOF).
                continue
            arr = line.split('|')
            name.append(arr[0])
            value.append(chara2digit(arr[1]))
            description.append(arr[2])

    return name, value, description


def rotate_CD_matrix(cd, pa_aper):
    """Rotate a CD matrix by a detector position angle.

    Parameters
    ----------
    cd: (2,2) array
        CD matrix

    pa_aper: float
        Position angle, in degrees E from N, of y axis of the detector

    Returns
    -------
    cd_rot: (2,2) array
        ``R(-pa_aper) . cd``, where ``R(t) = [[cos t, -sin t], [sin t, cos t]]``.

    Comments
    --------
    `astropy.wcs.WCS.rotateCD` doesn't work for non-square pixels in that it
    doesn't preserve the pixel scale!  The bug seems to come from the fact
    that `rotateCD` assumes a transposed version of its own CD matrix.

    """
    theta = np.deg2rad(-pa_aper)
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    rotation = np.array([[cos_t, -sin_t],
                         [sin_t, cos_t]], dtype=float)
    return np.dot(rotation, cd)


def calcaluteSLSRotSkyCoor(pix_xy=None, rot_angle=1, xlen=9216, ylen=9232, w=None):
    """Return the sky position of a pixel after rotating it about the chip centre.

    ``pix_xy`` is rotated by ``rot_angle`` degrees using the matrix
    ``[[cos, -sin], [sin, cos]]`` about the pivot ``(xlen / 2, ylen / 2)``.
    The rotated position is converted with ``w.wcs_pix2world`` using a
    1-based pixel origin.

    Parameters
    ----------
    pix_xy : (2,) array
        Pixel position (x, y).
    rot_angle : float
        Rotation angle in degrees.
    xlen, ylen : int
        Chip size in pixels; defines the rotation pivot.
    w : astropy.wcs.WCS
        WCS used for the pixel-to-world conversion.

    Returns
    -------
    array
        Result of ``w.wcs_pix2world`` for a single (1, 2) input row.
    """
    theta = np.deg2rad(rot_angle)
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rotation = np.array([[cos_t, -sin_t],
                         [sin_t, cos_t]], dtype=float)
    pivot = np.array([xlen / 2, ylen / 2])
    rotated_pix = np.dot(rotation, pix_xy - pivot) + pivot
    return w.wcs_pix2world(np.array([rotated_pix]), 1)


# def Header_extention(xlen = 9216, ylen = 9232, gain = 1.0, readout = 5.0, dark = 0.02,saturation=90000, row_num = 1, col_num = 1):
#
#     """ Creat an image frame for CCST with multiple extensions
#
#     Parameters
#     ----------
#
#     """
#
#     flag_ltm_x = [0,1,-1,1,-1]
#     flag_ltm_y = [0,1,1,-1,-1]
#     flag_ltv_x = [0,0,1,0,1]
#     flag_ltv_y = [0,0,0,1,1]
#
#     detector_size_x = int(xlen)
#     detector_size_y = int(ylen)
#
#     data_x = str(int(detector_size_x))
#     data_y = str(int(detector_size_y))
#
#     data_sec = '[1:'+data_x+',1:'+data_y+']'
#     e_header_fn = os.path.split(os.path.realpath(__file__))[0] + '/extension_header.param'
#     name, value, description = read_header_parameter(e_header_fn)
#     f = open(os.path.split(os.path.realpath(__file__))[0] + '/filter.lst')
#     s = f.readline()
#     s = s.strip("\n")
#     filters = s.split(' ')
#     s = f.readline()
#     s = s.strip("\n")
#     filterID = s.split()
#
#     s = f.readline()
#     s = s.strip("\n")
#     CCDID = s.split()
#
#     k = (row_num-1)*6+col_num
#
#     h_iter = 0
#     for n1,v1,d1 in zip(name, value, description):
#         if n1=='EXTNAME':
#             value[h_iter] = 'RAW,'+CCDID[k-1].rjust(2,'0')
#         if n1=='CCDNAME':
#             value[h_iter] = 'ccd' + CCDID[k-1].rjust(2,'0')
#         if n1=='AMPNAME':
#             value[h_iter] = 'ccd' + CCDID[k-1].rjust(2,'0') + ':A'
#         if n1=='GAIN':
#             value[h_iter] = gain
#         if n1=='RDNOISE':
#             value[h_iter] = readout
#         if n1=='SATURATE':
#             value[h_iter] = saturation
#         if n1=='CCDCHIP':
#             value[h_iter] = 'ccd' + CCDID[k-1].rjust(2,'0')
#         if n1=='CCDLABEL':
#             value[h_iter] = filters[k-1] + '-' + filterID[k-1]
#         if n1=='DATASEC':
#             value[h_iter] = data_sec
#
#         h_iter = h_iter + 1
#
#
#     return name, value, description


# 9232 9216  898 534 1309 60 -40  -23.4333
def WCS_def(xlen=9216, ylen=9232, gapy=898.0, gapx1=534, gapx2=1309, ra_ref=60, dec_ref=-40, pa=-23.433, pixel_scale=0.074, pixel_size=1e-2,
            rotate_chip=0., filter='GI', row_num=None, col_num=None, xcen=None, ycen=None):
    """Create a TAN WCS description for one CSST detector.

    The detector can be placed in the focal plane in one of two ways.

    * ``row_num``/``col_num`` give the chip's grid position in the 6 x 5
      mosaic. The reference pixel is derived from the chip size and the
      inter-chip gaps. For ``filter`` in ('GU', 'GV', 'GI'), presumably the
      slitless-spectroscopy filters, the WCS is refitted for a chip rotated
      by ``rotate_chip`` degrees. The sign of the rotation is flipped for
      ``col_num > 2``.
    * ``xcen``/``ycen`` give the chip centre. It is divided by
      ``pixel_size`` to convert it to pixels. The WCS is always refitted
      for a chip rotated by ``rotate_chip`` degrees.

    Parameters
    ----------
    xlen, ylen : int
        Detector size in pixels along x and y.
    gapy, gapx1, gapx2 : float
        Inter-chip gaps in pixels. Only used with ``row_num``/``col_num``.
    ra_ref, dec_ref : float
        Reference sky position in degrees, used as CRVAL before refitting.
    pa : float
        Position angle in degrees, passed to :func:`rotate_CD_matrix`.
    pixel_scale : float
        Pixel scale in arcsec/pixel (divided by 3600 to build the CD matrix).
    pixel_size : float
        Divisor converting ``xcen``/``ycen`` to pixels.
    rotate_chip : float
        Extra in-plane chip rotation in degrees.
    filter : str
        Filter name. Selects the rotated refit in the ``row_num``/``col_num``
        mode.
    row_num, col_num : int, optional
        1-based grid position of the chip.
    xcen, ycen : float, optional
        Chip centre, in the same unit as ``pixel_size``.

    Returns
    -------
    OrderedDict
        Keywords EQUINOX, WCSDIM, CTYPE1/2, CRVAL1/2, CRPIX1/2, CD1_1, CD1_2,
        CD2_1 and CD2_2.

    Raises
    ------
    ValueError
        If neither (row_num, col_num) nor (xcen, ycen) is fully given.
    """
    r_dat = OrderedDict()
    r_dat['EQUINOX'] = 2000.0
    r_dat['WCSDIM'] = 2.0
    r_dat['CTYPE1'] = 'RA---TAN'
    r_dat['CTYPE2'] = 'DEC--TAN'
    r_dat['CRVAL1'] = ra_ref
    r_dat['CRVAL2'] = dec_ref

    if (row_num is not None) and (col_num is not None):
        # Mosaic of 6 chips along x and 5 along y. Along x there are three
        # gaps of width gapx1 and two of width gapx2.
        x_num = 6
        y_num = 5
        gap_x = [gapx1, gapx2]

        y_center = (ylen*y_num + gapy*(y_num-1))/2
        x_center = (xlen*x_num + gap_x[0]*3 + gap_x[1]*2)/2

        # Summing rows 0..i-1 of column j-1 gives the x gap accumulated up to
        # chip column i in row j. Row 0 is zero: no gap before the first
        # column.
        gap_x_map = np.array([
            [0, 0, 0, 0, 0],
            [gap_x[0], gap_x[1], gap_x[1], gap_x[1], gap_x[1]],
            [gap_x[1], gap_x[0], gap_x[0], gap_x[0], gap_x[0]],
            [gap_x[0], gap_x[0], gap_x[0], gap_x[0], gap_x[0]],
            [gap_x[0], gap_x[0], gap_x[0], gap_x[0], gap_x[1]],
            [gap_x[1], gap_x[1], gap_x[1], gap_x[1], gap_x[0]],
        ])

        j = row_num
        i = col_num

        # Centre of chip (i, j) in mosaic pixel coordinates.
        x_ref = xlen*i + sum(gap_x_map[0:i, j-1]) - xlen/2.
        y_ref = (ylen+gapy)*j - gapy - ylen/2

        # Only extension k = 1 of the older multi-extension layout is used.
        k = 1
        flag_x = [0, 1, -1, 1, -1]
        flag_ext_x = [0, -1, 1, -1, 1]
        flag_ext_y = [0, -1, -1, 1, 1]

        cd = np.array([[pixel_scale, 0], [0, pixel_scale]]) / 3600.*flag_x[k]
        cd_rot = rotate_CD_matrix(cd, pa)

        r_dat['CRPIX1'] = flag_ext_x[k] * ((x_ref+flag_ext_x[k]*xlen/2)-x_center)
        r_dat['CRPIX2'] = flag_ext_y[k] * ((y_ref+flag_ext_y[k]*ylen/2)-y_center)
        r_dat['CD1_1'] = cd_rot[0, 0]
        r_dat['CD1_2'] = cd_rot[0, 1]
        r_dat['CD2_1'] = cd_rot[1, 0]
        r_dat['CD2_2'] = cd_rot[1, 1]

        if filter in ['GU', 'GV', 'GI']:
            # Chips in columns i > 2 are rotated in the opposite sense.
            sls_rot = -rotate_chip if i > 2 else rotate_chip
            _refit_rotated_wcs(r_dat, cd_rot, ra_ref, dec_ref, sls_rot, xlen, ylen)

    elif (xcen is not None) and (ycen is not None):
        xcen, ycen = xcen/pixel_size, ycen/pixel_size
        x1, y1 = xcen - xlen/2., ycen - ylen/2.
        r_dat['CRPIX1'] = -x1
        r_dat['CRPIX2'] = -y1

        # Unlike the row/col branch, the y axis of the CD matrix is negated.
        cd = np.array([[pixel_scale, 0], [0, -pixel_scale]])/3600.
        cd_rot = rotate_CD_matrix(cd, pa)
        r_dat['CD1_1'] = cd_rot[0, 0]
        r_dat['CD1_2'] = cd_rot[0, 1]
        r_dat['CD2_1'] = cd_rot[1, 0]
        r_dat['CD2_2'] = cd_rot[1, 1]

        _refit_rotated_wcs(r_dat, cd_rot, ra_ref, dec_ref, rotate_chip, xlen, ylen)

    else:
        raise ValueError(
            'In function WCS_def(): Either (row_num, col_num) or (xcen, ycen, pixel_size) should be given')

    return r_dat


def _refit_rotated_wcs(r_dat, cd, ra_ref, dec_ref, rot_angle, xlen, ylen, n_grid=30):
    """Refit the TAN WCS in ``r_dat`` for a chip rotated by ``rot_angle`` degrees.

    First, a reference WCS is built from ``r_dat``'s CRPIX/CTYPE entries,
    ``cd`` and (``ra_ref``, ``dec_ref``). Next, sky coordinates are computed
    on an ``n_grid`` x ``n_grid`` grid of pixel positions, each rotated about
    the chip centre (see :func:`calcaluteSLSRotSkyCoor`). A new TAN WCS is
    then fitted to the (unrotated pixel, sky) pairs. Finally, ``r_dat``'s CD,
    CRPIX and CRVAL entries are overwritten in place with the fitted values.
    """
    w = wcs.WCS(naxis=2)
    w.wcs.crpix = [r_dat['CRPIX1'], r_dat['CRPIX2']]
    w.wcs.cd = cd
    w.wcs.crval = [ra_ref, dec_ref]
    w.wcs.ctype = [r_dat['CTYPE1'], r_dat['CTYPE2']]

    # Flattened grid in y-outer / x-inner order: index = iy * n_grid + ix.
    x_grid, y_grid = np.meshgrid(np.linspace(1, xlen, n_grid),
                                 np.linspace(1, ylen, n_grid))
    x_pixs = x_grid.ravel()
    y_pixs = y_grid.ravel()

    sky_coors = []
    for x, y in zip(x_pixs, y_pixs):
        sc = calcaluteSLSRotSkyCoor(
            pix_xy=np.array([x, y]), rot_angle=rot_angle, xlen=xlen, ylen=ylen, w=w)
        sky_coors.append((sc[0, 0], sc[0, 1]))

    wcs_new = fit_wcs_from_points(xy=np.array([x_pixs, y_pixs]),
                                  world_coords=SkyCoord(
                                      sky_coors, frame="icrs", unit="deg"),
                                  projection='TAN')

    r_dat['CD1_1'] = wcs_new.wcs.cd[0, 0]
    r_dat['CD1_2'] = wcs_new.wcs.cd[0, 1]
    r_dat['CD2_1'] = wcs_new.wcs.cd[1, 0]
    r_dat['CD2_2'] = wcs_new.wcs.cd[1, 1]
    r_dat['CRPIX1'] = wcs_new.wcs.crpix[0]
    r_dat['CRPIX2'] = wcs_new.wcs.crpix[1]
    r_dat['CRVAL1'] = wcs_new.wcs.crval[0]
    r_dat['CRVAL2'] = wcs_new.wcs.crval[1]


# TODO: project_cycle is temporary and is not defined in the header; remove it in the future.
def generatePrimaryHeader(xlen=9216, ylen=9232, pointing_id='00000001', pointing_type_code='101', ra=60, dec=-40, pixel_scale=0.074, time_pt=None, im_type='SCI', exptime=150., sat_pos=(0., 0., 0.), sat_vel=(0., 0., 0.), project_cycle=6, run_counter=0, chip_name="01", obstype='WIDE', dataset='csst-msc-c9-25sqdeg-v3'):
    """Build the primary (HDU 0) FITS header of a simulated CSST MSC L0 image.

    The header is read from the ``csst_msc_l0_ms.fits`` template next to
    this module. Observation-specific keywords are then filled in.

    Parameters
    ----------
    xlen, ylen, pixel_scale, project_cycle, run_counter :
        Currently unused; kept for API compatibility.
    pointing_id : str
        Pointing identifier, written to OBJECT and used in OBSID.
    pointing_type_code : str
        Prefix prepended to ``pointing_id`` to form OBSID.
    ra, dec : float
        Pointing coordinates in degrees.
    time_pt : float
        Exposure start as a POSIX timestamp (UTC seconds). Rounded to 0.1 s.
    im_type : str
        Image-type tag embedded in FILENAME.
    exptime : float
        Exposure time in seconds.
    sat_pos, sat_vel : sequence of 3 floats
        Satellite position / velocity written to POSI0_* / VELO0_*.
    chip_name : str
        Detector label embedded in FILENAME.
    obstype, dataset : str
        Written to OBSTYPE and DATASET.

    Returns
    -------
    astropy.io.fits.Header
        The populated primary header.
    """
    # Exposure start rounded to 0.1 s; every time keyword below derives from it.
    start = datetime.fromtimestamp(time_pt, tz=timezone.utc)
    datetime_obs = datetime.fromtimestamp(
        np.round(start.timestamp(), 1), tz=timezone.utc)
    # "%f" yields microseconds; dropping the last 5 digits keeps tenths.
    date_str = datetime_obs.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-5]

    g_header_fn = os.path.split(os.path.realpath(__file__))[
        0] + '/csst_msc_l0_ms.fits'
    # Close the template file right away. The HDU is loaded inside the
    # ``with`` block, so its header remains usable afterwards.
    with fits.open(g_header_fn) as header_fits:
        h_prim = header_fits[0].header

    h_prim['DATE'] = date_str
    h_prim['DATE-OBS'] = date_str
    h_prim['RA_OBJ'] = ra
    h_prim['DEC_OBJ'] = dec

    OBS_id = pointing_type_code + pointing_id
    h_prim['OBJECT'] = pointing_id
    h_prim['OBSID'] = OBS_id
    h_prim['EXPTIME'] = exptime
    h_prim['OBSTYPE'] = obstype
    h_prim['DATASET'] = dataset

    # SkyCoord validates the inputs and wraps RA into [0, 360).
    co = coord.SkyCoord(ra, dec, unit='deg')
    h_prim['TARGET'] = _format_target(co.ra.deg, co.dec.deg)

    h_prim['RA_PNT0'] = ra
    h_prim['DEC_PNT0'] = dec
    h_prim['RA_PNT1'] = ra
    h_prim['DEC_PNT1'] = dec

    tstart = Time(date_str)
    h_prim['EXPSTART'] = round(tstart.mjd, 5)
    h_prim['CABSTART'] = h_prim['EXPSTART']
    tend = Time(tstart.mjd + h_prim['EXPTIME']/86400., format="mjd")
    h_prim['EXPEND'] = round(tend.mjd, 5)
    h_prim['CABEND'] = h_prim['EXPEND']

    file_start_time = datetime_obs.strftime("%Y%m%d%H%M%S")
    file_end_time = tend.datetime.strftime("%Y%m%d%H%M%S")
    h_prim['FILENAME'] = 'CSST_MSC_MS_' + im_type + '_' + \
        file_start_time + '_' + file_end_time + \
        '_' + OBS_id + '_' + chip_name + '_L0_V01'

    h_prim['POSI0_X'] = sat_pos[0]
    h_prim['POSI0_Y'] = sat_pos[1]
    h_prim['POSI0_Z'] = sat_pos[2]

    h_prim['VELO0_X'] = sat_vel[0]
    h_prim['VELO0_Y'] = sat_vel[1]
    h_prim['VELO0_Z'] = sat_vel[2]

    # Installed version of the simulation package. pkg_resources is
    # deprecated, so prefer importlib.metadata and fall back only on old
    # Pythons.
    try:
        from importlib.metadata import version as _dist_version
    except ImportError:  # Python < 3.8
        from pkg_resources import get_distribution

        def _dist_version(name):
            return get_distribution(name).version

    currentDateAndTime = datetime.now()
    compute_name = platform.node()
    h_prim['FITSSWV'] = _dist_version("csst_msc_sim") + '_' + \
        currentDateAndTime.strftime("%Y%m%d") + '_' + compute_name
    h_prim['EPOCH'] = round(
        (Time(h_prim['EXPSTART'], format='mjd', scale='tcb')).jyear, 1)

    return h_prim


def _format_target(ra, dec):
    """Return the 'HHMMSS.s+DDMMSS' target string for (ra, dec) in degrees.

    RA is wrapped into [0, 360) and rounded to 0.1 s of time. Dec is rounded
    to 1 arcsec. Rounding carries into the higher fields, so no field ever
    reads 60. The Dec sign is always explicit and the degree field is always
    two digits (e.g. '-053000', not '-53000').
    """
    ra_tenths = int(round((ra % 360.0) / 15.0 * 36000))  # units of 0.1 s
    ra_h, ra_rem = divmod(ra_tenths, 36000)
    ra_m, ra_s10 = divmod(ra_rem, 600)
    ra_str = '%02d%02d%04.1f' % (ra_h % 24, ra_m, ra_s10 / 10.)

    dec_arcsec = int(round(abs(dec) * 3600))
    dec_d, dec_rem = divmod(dec_arcsec, 3600)
    dec_m, dec_s = divmod(dec_rem, 60)
    sign = '+' if dec >= 0 else '-'
    return ra_str + sign + '%02d%02d%02d' % (dec_d, dec_m, dec_s)


def generateExtensionHeader(chip, xlen=9216, ylen=9232, ra=60, dec=-40, pa=-23.433, gain=1.0, readout=5.0, dark=0.02, saturation=90000, pixel_scale=0.074, pixel_size=1e-2,
                            extName='SCIE', row_num=None, col_num=None, xcen=None, ycen=None, timestamp=1621915200, exptime=150., readoutTime=40., t_shutter_open=1.3, t_shutter_close=1.3):
    """Build the image-extension (HDU 1) FITS header of a simulated CSST MSC L0 image.

    The header is read from HDU 1 of the ``csst_msc_l0_ms.fits`` template
    next to this module. Detector, timing and WCS keywords are then filled
    in.

    Parameters
    ----------
    chip : object
        Detector description. Its ``chipID``, ``chip_name``, ``filter_type``,
        ``gain_channel`` (at least 16 entries) and ``rotate_angle`` are read.
    xlen, ylen : int
        Image size, written to NAXIS1/NAXIS2.
    ra, dec, pa : float
        Reference position and position angle in degrees, passed to WCS_def.
    gain, dark, saturation :
        Currently unused; kept for API compatibility.
    readout : float
        Written to RON01..RON16.
    pixel_scale : float
        Written to PIXSCAL1/2 and passed to WCS_def.
    pixel_size : float
        Passed to WCS_def.
    extName : str
        Written to EXTNAME, but EXTNAME is overwritten with 'IMAGE' before
        returning.
    row_num, col_num, xcen, ycen :
        Chip placement, passed to WCS_def. One pair must be given.
    timestamp : float
        Exposure start as a POSIX timestamp (UTC seconds).
    exptime : float
        Exposure time in seconds (EXPTIME, DARKTIME).
    readoutTime : float
        Readout duration in seconds.
    t_shutter_open, t_shutter_close : float
        Shutter opening / closing durations in seconds.

    Returns
    -------
    astropy.io.fits.Header
        The populated extension header.
    """
    e_header_fn = os.path.split(os.path.realpath(__file__))[
        0] + '/csst_msc_l0_ms.fits'
    # Close the template file right away. The HDU is loaded inside the
    # ``with`` block, so its header remains usable afterwards.
    with fits.open(e_header_fn) as header_fits:
        h_ext = header_fits[1].header

    h_ext['DETECTOR'] = str(chip.chipID).rjust(2, '0')
    h_ext['DETLABEL'] = chip.chip_name
    h_ext['FILTER'] = chip.filter_type
    h_ext['NAXIS1'] = xlen
    h_ext['NAXIS2'] = ylen
    # NOTE(review): EXTNAME is overwritten with 'IMAGE' at the end of this
    # function, so ``extName`` has no effect on the result. Confirm intent.
    h_ext['EXTNAME'] = extName
    # All GAIN keywords first, then all RON keywords (same order as before).
    for n in range(16):
        h_ext['GAIN%02d' % (n + 1)] = chip.gain_channel[n]
    for n in range(16):
        h_ext['RON%02d' % (n + 1)] = readout

    h_ext['PIXSCAL1'] = pixel_scale
    h_ext['PIXSCAL2'] = pixel_scale
    h_ext['EXPTIME'] = exptime
    h_ext['DARKTIME'] = exptime

    tstart = Time(datetime.fromtimestamp(timestamp, tz=timezone.utc))
    t_shutter_oe = Time(tstart.mjd + t_shutter_open / 86400., format="mjd")
    # End of exposure. This is both the shutter-close start (SHTCLOS0) and
    # the readout start (ROTIME0).
    t_exp_end = Time(tstart.mjd + exptime / 86400., format="mjd")
    t_shutter_ce = Time(
        tstart.mjd + (exptime + t_shutter_close) / 86400., format="mjd")
    t_read_end = Time(tstart.mjd + (exptime + readoutTime) /
                      86400., format="mjd")

    h_ext['SHTOPEN0'] = _fits_time_str(tstart)
    h_ext['SHTOPEN1'] = _fits_time_str(t_shutter_oe)
    h_ext['SHTCLOS0'] = _fits_time_str(t_exp_end)
    h_ext['SHTCLOS1'] = _fits_time_str(t_shutter_ce)
    h_ext['ROTIME0'] = _fits_time_str(t_exp_end)
    h_ext['ROTIME1'] = _fits_time_str(t_read_end)

    header_wcs = WCS_def(xlen=xlen, ylen=ylen, gapy=898.0, gapx1=534, gapx2=1309, ra_ref=ra, dec_ref=dec, pa=pa, pixel_scale=pixel_scale, pixel_size=pixel_size,
                         rotate_chip=chip.rotate_angle, filter=h_ext['FILTER'], row_num=row_num, col_num=col_num, xcen=xcen, ycen=ycen)

    h_ext['CRPIX1'] = header_wcs['CRPIX1']
    h_ext['CRPIX2'] = header_wcs['CRPIX2']
    h_ext['CRVAL1'] = header_wcs['CRVAL1']
    h_ext['CRVAL2'] = header_wcs['CRVAL2']
    h_ext['CD1_1'] = header_wcs['CD1_1']
    h_ext['CD1_2'] = header_wcs['CD1_2']
    h_ext['CD2_1'] = header_wcs['CD2_1']
    h_ext['CD2_2'] = header_wcs['CD2_2']
    h_ext['CTYPE1'] = header_wcs['CTYPE1']
    h_ext['CTYPE2'] = header_wcs['CTYPE2']

    h_ext['EXTNAME'] = 'IMAGE'
    h_ext.comments["XTENSION"] = "image extension"

    return h_ext


def _fits_time_str(t):
    """Format a time object as 'YYYY-mm-ddTHH:MM:SS.s' in UTC, rounded to 0.1 s.

    ``t`` only needs a ``unix`` attribute (POSIX seconds), such as an
    ``astropy.time.Time``.
    """
    posix = datetime.fromtimestamp(t.unix, tz=timezone.utc).timestamp()
    rounded = datetime.fromtimestamp(np.round(posix, 1), tz=timezone.utc)
    # "%f" yields microseconds; dropping the last 5 digits keeps tenths.
    return rounded.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-5]


def main(argv):
    """Command-line entry point: write a FITS file with generated headers.

    ``argv`` is laid out like ``sys.argv``. argv[1]..argv[13] are read, in
    order, as xlen, ylen, pointing number, ra, dec, pSize, CCD row,
    CCD column, position angle, gain, readout, dark and fw (the last is
    passed as ``saturation``). The file is written as ``<FILENAME>.fits``,
    where FILENAME comes from the primary header.

    NOTE(review): this entry point looks stale and fails as written.
    ``generatePrimaryHeader`` accepts no ``psize``, ``row_num``, ``col_num``
    or ``pointNum`` keyword. ``generateExtensionHeader`` requires a
    positional ``chip`` argument and accepts no ``psize`` keyword. Both calls
    therefore raise TypeError. The ``__main__`` guard at the end of the file
    is commented out; update these calls before re-enabling it.
    """

    xlen = int(argv[1])
    ylen = int(argv[2])
    pointingNum = argv[3]
    ra = float(argv[4])
    dec = float(argv[5])
    pSize = float(argv[6])
    ccd_row_num = int(argv[7])
    ccd_col_num = int(argv[8])
    pa_aper = float(argv[9])
    gain = float(argv[10])
    readout = float(argv[11])
    dark = float(argv[12])
    fw = float(argv[13])

    # NOTE(review): psize/row_num/col_num/pointNum are not parameters of
    # generatePrimaryHeader (and time_pt is left as None).
    h_prim = generatePrimaryHeader(xlen=xlen, ylen=ylen, ra=ra, dec=dec, psize=pSize,
                                   row_num=ccd_row_num, col_num=ccd_col_num, pointNum=pointingNum)

    # NOTE(review): the required ``chip`` argument is missing and ``psize``
    # is not a parameter of generateExtensionHeader.
    h_ext = generateExtensionHeader(xlen=xlen, ylen=ylen, ra=ra, dec=dec, pa=pa_aper, gain=gain,
                                    readout=readout, dark=dark, saturation=fw, psize=pSize, row_num=ccd_row_num, col_num=ccd_col_num)
    hdu1 = fits.PrimaryHDU(header=h_prim)
    # Zero-filled image of shape (ylen, xlen) carrying the extension header.
    hdu2 = fits.ImageHDU(np.zeros([ylen, xlen]), header=h_ext)

    hdul = fits.HDUList([hdu1, hdu2])

    hdul.writeto(h_prim['FILENAME']+'.fits', output_verify='ignore')

# if __name__ == "__main__":
#     main(sys.argv)