import os
import numpy as np
import mpi4py.MPI as MPI
import galsim
import logging
import psutil
from astropy.io import fits
from datetime import datetime

import traceback

from ObservationSim.Config import config_dir, ChipOutput
from ObservationSim.Config.Header import generatePrimaryHeader, generateExtensionHeader
from ObservationSim.Instrument import Telescope, Filter, FilterParam, FocalPlane, Chip
from ObservationSim.Instrument.Chip import Effects
from ObservationSim.MockObject import calculateSkyMap_split_g
from ObservationSim.PSF import PSFGauss, FieldDistortion, PSFInterp
from ObservationSim._util import get_shear_field, makeSubDir_PointingList
from ObservationSim.Astrometry.Astrometry_util import on_orbit_obs_position

class Observation(object):
    """Drive the image simulation of a focal plane over a list of pointings.

    The constructor builds the telescope, the focal plane, one Chip/Filter pair
    per non-ignored chip, an optional field-distortion model and the default
    shear field from ``config``. :meth:`runExposure_MPI_PointingList` then
    iterates over (pointing, chip) tasks, optionally distributed round-robin
    over MPI ranks, and calls :meth:`run_one_chip` for each task.
    """

    def __init__(self, config, Catalog, work_dir=None, data_dir=None):
        """
        Parameters
        ----------
        config : dict
            Nested simulation configuration. This class reads the sections
            "obs_setting", "ins_effects", "psf_setting", "shear_setting",
            "run_option" and "random_seeds".
        Catalog : class
            Catalog class. It is instantiated once per chip in ``run_one_chip``.
        work_dir, data_dir : str, optional
            Forwarded to ``config_dir`` to build ``self.path_dict``.
        """
        self.path_dict = config_dir(config=config, work_dir=work_dir, data_dir=data_dir)
        self.config = config
        self.tel = Telescope()
        self.focal_plane = FocalPlane(survey_type=self.config["obs_setting"]["survey_type"])
        self.filter_param = FilterParam()
        self.chip_list = []     # chips to simulate (not ignored by the focal plane)
        self.filter_list = []   # filters paired index-by-index with chip_list
        self.all_filter = []    # filters of every chip, ignored ones included
        self.Catalog = Catalog

        # Optional field-distortion model.
        if self.config["ins_effects"]["field_dist"] == True:
            self.fd_model = FieldDistortion(fdModel_path=self.path_dict["fd_path"])
        else:
            self.fd_model = None

        # One Chip (plus its Filter) per focal-plane position; chip IDs are 1-based.
        nchips = self.focal_plane.nchip_x * self.focal_plane.nchip_y
        for chipID in range(1, nchips + 1):
            chip = Chip(chipID=chipID, config=self.config)
            filter_id, filter_type = chip.getChipFilter()
            filt = Filter(filter_id=filter_id,
                          filter_type=filter_type,
                          filter_param=self.filter_param)
            if not self.focal_plane.isIgnored(chipID=chipID):
                self.chip_list.append(chip)
                self.filter_list.append(filt)
            self.all_filter.append(filt)

        # Default shear field(s); they may be replaced later from a shear catalog file.
        self.g1_field, self.g2_field, self.nshear = get_shear_field(config=self.config)

    @staticmethod
    def _log_pointing_info(chip, pointing, chip_output):
        """Log a summary of the current pointing and chip."""
        logger = chip_output.logger
        logger.info(':::::::::::::::::::Current Pointing Information::::::::::::::::::')
        logger.info("RA: %f, DEC; %f" % (pointing.ra, pointing.dec))
        logger.info("Time: %s" % datetime.utcfromtimestamp(pointing.timestamp).isoformat())
        logger.info("Exposure time: %f" % pointing.exp_time)
        logger.info("Satellite Position (x, y, z): (%f, %f, %f)" % (pointing.sat_x, pointing.sat_y, pointing.sat_z))
        logger.info("Satellite Velocity (x, y, z): (%f, %f, %f)" % (pointing.sat_vx, pointing.sat_vy, pointing.sat_vz))
        logger.info("Position Angle: %f" % pointing.img_pa.deg)
        logger.info('Chip : %d' % chip.chipID)
        logger.info(':::::::::::::::::::::::::::END:::::::::::::::::::::::::::::::::::')

    def _build_psf_model(self, chip, chip_output):
        """Return the PSF model selected by ``config["psf_setting"]["psf_model"]``.

        Raises
        ------
        ValueError
            If the model type is neither "Gauss" nor "Interp".
        """
        psf_type = self.config["psf_setting"]["psf_model"]
        if psf_type == "Gauss":
            return PSFGauss(chip=chip, psfRa=self.config["psf_setting"]["psf_rcont"])
        if psf_type == "Interp":
            return PSFInterp(chip=chip, PSF_data_file=self.path_dict["psf_dir"])
        # Logger.error() accepts no ``flush`` keyword, so log plainly and then fail loudly.
        chip_output.logger.error("unrecognized PSF model type!!")
        raise ValueError("unrecognized PSF model type: %r" % (psf_type,))

    def _get_pointing_center(self, pointing):
        """Return the (ra, dec) center used to build the focal-plane WCS.

        When ``obs_setting.enable_astrometric_model`` is set, the nominal
        pointing is passed through ``on_orbit_obs_position`` together with the
        satellite position and velocity at the exposure time. Otherwise the
        nominal pointing is returned unchanged.
        """
        if not self.config["obs_setting"]["enable_astrometric_model"]:
            return pointing.ra, pointing.dec
        dt = datetime.utcfromtimestamp(pointing.timestamp)
        ra_cen, dec_cen = on_orbit_obs_position(
            input_ra_list=[pointing.ra],
            input_dec_list=[pointing.dec],
            input_pmra_list=[0.],
            input_pmdec_list=[0.],
            input_rv_list=[0.],
            input_parallax_list=[1e-9],
            input_nstars=1,
            input_x=pointing.sat_x,
            input_y=pointing.sat_y,
            input_z=pointing.sat_z,
            input_vx=pointing.sat_vx,
            input_vy=pointing.sat_vy,
            input_vz=pointing.sat_vz,
            input_epoch="J2015.5",
            input_date_str=dt.date().isoformat(),
            input_time_str=dt.time().isoformat()
        )
        return ra_cen[0], dec_cen[0]

    def _build_sky_map(self, chip, filt, chip_output):
        """Return the sky map handed to ``chip.addEffects``.

        The result is None unless the chip is spectroscopic. For spectroscopic
        chips, a normalized flat (optionally multiplied by a smooth flat and by
        a shutter-effect map) is passed to ``calculateSkyMap_split_g``.
        ``chip.img`` must already exist, because its shape and bounds are used.
        """
        if chip.survey_type != "spectroscopic":
            return None
        flat_normal = np.ones_like(chip.img.array)
        if self.config["ins_effects"]["flat_fielding"] == True:
            chip_output.logger.info("SLS flat preprocess,CHIP %d : Creating and applying Flat-Fielding" % chip.chipID)
            chip_output.logger.info(str(chip.img.bounds))
            flat_img = Effects.MakeFlatSmooth(
                chip.img.bounds,
                int(self.config["random_seeds"]["seed_flat"]))
            flat_normal = flat_normal * flat_img.array / np.mean(flat_img.array)
        if self.config["ins_effects"]["shutter_effect"] == True:
            chip_output.logger.info("SLS flat preprocess,CHIP %d : Apply shutter effect" % chip.chipID)
            # Shutter-effect normalized image for this chip.
            shuttimg = Effects.ShutterEffectArr(chip.img, t_shutter=1.3, dist_bearing=735, dt=1E-3)
            flat_normal = flat_normal * shuttimg
            flat_normal = np.array(flat_normal, dtype='float32')
        return calculateSkyMap_split_g(
            skyMap=flat_normal,
            blueLimit=filt.blue_limit,
            redLimit=filt.red_limit,
            conf=chip.sls_conf,
            pixelSize=chip.pix_scale,
            isAlongY=0,
            flat_cube=chip.flat_cube)

    def _select_cut_filter(self, chip, pointing, chip_output):
        """Refresh each filter's limiting/saturation magnitudes and return the cut-in band filter.

        If several filters match ``obs_setting.cut_in_band`` (case-insensitive),
        the last one wins.

        Raises
        ------
        ValueError
            If no filter matches, because every object would otherwise fail
            later with an unbound ``cut_filter``.
        """
        cut_band = self.config["obs_setting"]["cut_in_band"]
        cut_filter = None
        for temp_filter in self.all_filter:
            # Update the limiting magnitude using the exposure time of this pointing.
            temp_filter.update_limit_saturation_mags(exptime=pointing.exp_time, chip=chip)
            if temp_filter.filter_type.lower() == cut_band.lower():
                cut_filter = temp_filter
        if cut_filter is None:
            msg = "cut_in_band %r does not match any filter" % (cut_band,)
            chip_output.logger.error(msg)
            raise ValueError(msg)
        return cut_filter

    def _write_raw_image(self, chip, pointing, chip_output):
        """Convert ``chip.img`` to uint16 and write it as a raw SCI FITS file.

        The file contains a primary header and one 'raw' image extension, and
        is written to ``chip_output.subdir`` under the header's FILENAME.
        """
        datetime_obs = datetime.utcfromtimestamp(pointing.timestamp)
        h_prim = generatePrimaryHeader(
            xlen=chip.npix_x,
            ylen=chip.npix_y,
            pointNum=str(pointing.id),
            ra=pointing.ra,
            dec=pointing.dec,
            psize=chip.pix_scale,
            row_num=chip.rowID,
            col_num=chip.colID,
            date=datetime_obs.strftime("%y%m%d"),
            time_obs=datetime_obs.strftime("%H%M%S"),
            exptime=pointing.exp_time,
            im_type='SCI',
            sat_pos=[pointing.sat_x, pointing.sat_y, pointing.sat_z],
            sat_vel=[pointing.sat_vx, pointing.sat_vy, pointing.sat_vz])
        h_ext = generateExtensionHeader(
            xlen=chip.npix_x,
            ylen=chip.npix_y,
            ra=pointing.ra,
            dec=pointing.dec,
            pa=pointing.img_pa.deg,
            gain=chip.gain,
            readout=chip.read_noise,
            dark=chip.dark_noise,
            saturation=90000,
            psize=chip.pix_scale,
            row_num=chip.rowID,
            col_num=chip.colID,
            extName='raw')
        chip.img = galsim.Image(chip.img.array, dtype=np.uint16)
        hdu_list = fits.HDUList([
            fits.PrimaryHDU(header=h_prim),
            fits.ImageHDU(chip.img.array, header=h_ext),
        ])
        fname = os.path.join(chip_output.subdir, h_prim['FILENAME'] + '.fits')
        hdu_list.writeto(fname, output_verify='ignore', overwrite=True)

    @staticmethod
    def _log_memory_usage(stage, pointing, chip, chip_output):
        """Log this process's resident memory (in GB), tagged with a checkpoint number."""
        rss_gb = psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024
        chip_output.logger.info("check running:%d: pointing-%d chip-%d pid-%d memory-%6.2fGB" % (
            stage, pointing.id, chip.chipID, os.getpid(), rss_gb))

    def run_one_chip(self, chip, filt, pointing, chip_output, wcs_fp=None, psf_model=None, shear_cat_file=None, cat_dir=None, sed_dir=None):
        """Simulate one chip for one pointing and write its outputs.

        Parameters
        ----------
        chip, filt :
            The chip to simulate and its filter.
        pointing :
            Pointing record. Its position, timestamp, exposure time, satellite
            state, ``img_pa``, ``id`` and ``pointing_type`` are read.
        chip_output :
            Output manager. Its ``logger``, ``subdir``, ``create_output_file``
            and ``cat_add_obj`` are used.
        wcs_fp : optional
            Focal-plane WCS. When None, a TAN WCS is built around the
            (possibly astrometry-corrected) pointing center.
        psf_model : optional
            Ignored. It is always replaced by the model selected in
            ``config["psf_setting"]``.
        shear_cat_file : str, optional
            If given, the shear field(s) are reloaded from this file.
        cat_dir, sed_dir : str, optional
            Forwarded to the catalog class.

        Raises
        ------
        ValueError
            If the PSF model type, the cut-in band or the shear type is unknown.
        """
        self._log_pointing_info(chip, pointing, chip_output)

        psf_model = self._build_psf_model(chip, chip_output)

        # Get (extra) shear fields.
        if shear_cat_file is not None:
            self.g1_field, self.g2_field, self.nshear = get_shear_field(config=self.config, shear_cat_file=shear_cat_file)

        ra_cen, dec_cen = self._get_pointing_center(pointing)

        # Get WCS for the focal plane.
        if wcs_fp is None:
            wcs_fp = self.focal_plane.getTanWCS(ra_cen, dec_cen, pointing.img_pa, chip.pix_scale)

        # Create the float chip image on the chip's own pixel bounds.
        chip.img = galsim.ImageF(chip.npix_x, chip.npix_y)
        chip.img.setOrigin(chip.bound.xmin, chip.bound.ymin)
        chip.img.wcs = wcs_fp

        sky_map = self._build_sky_map(chip, filt, chip_output)

        run_option = self.config["run_option"]
        obs_setting = self.config["obs_setting"]

        if pointing.pointing_type == 'MS':
            # Load catalogues and templates.
            self.cat = self.Catalog(config=self.config, chip=chip, pointing=pointing, cat_dir=cat_dir, sed_dir=sed_dir, chip_output=chip_output, filt=filt)
            chip_output.create_output_file()
            self.nobj = len(self.cat.objs)

            cut_filter = self._select_cut_filter(chip, pointing, chip_output)
            shear_type = self.config["shear_setting"]["shear_type"]

            missed_obj = 0
            bright_obj = 0
            dim_obj = 0
            for j, obj in enumerate(self.cat.objs):
                # Honour the star-only / galaxy-only run options.
                if obj.type == 'star' and run_option["galaxy_only"]:
                    continue
                if obj.type in ('galaxy', 'quasar') and run_option["star_only"]:
                    continue

                # Load and convert the SED. Also compute the magnitude and flux
                # in this chip's band and in the cut-in band.
                try:
                    sed_data = self.cat.load_sed(obj)
                    norm_filt = self.cat.load_norm_filt(obj)
                    obj.sed, obj.param["mag_%s" % filt.filter_type], obj.param["flux_%s" % filt.filter_type] = self.cat.convert_sed(
                        mag=obj.param["mag_use_normal"],
                        sed=sed_data,
                        target_filt=filt,
                        norm_filt=norm_filt,
                    )
                    _, obj.param["mag_%s" % cut_filter.filter_type], obj.param["flux_%s" % cut_filter.filter_type] = self.cat.convert_sed(
                        mag=obj.param["mag_use_normal"],
                        sed=sed_data,
                        target_filt=cut_filter,
                        norm_filt=norm_filt,
                    )
                except Exception as e:
                    traceback.print_exc()
                    chip_output.logger.error(e)
                    continue

                # Exclude objects that are too bright in the cut-in band. The magnitude
                # is read back under the same key it was stored with just above.
                if cut_filter.is_too_bright(
                        mag=obj.param["mag_%s" % cut_filter.filter_type],
                        margin=obs_setting["mag_sat_margin"]):
                    bright_obj += 1
                    obj.unload_SED()
                    continue
                # Exclude objects that are too dim in this chip's band.
                if filt.is_too_dim(
                        mag=obj.getMagFilter(filt),
                        margin=obs_setting["mag_lim_margin"]):
                    dim_obj += 1
                    obj.unload_SED()
                    continue

                if shear_type == "constant":
                    if obj.type == 'star':
                        obj.g1, obj.g2 = 0., 0.
                    else:
                        obj.g1, obj.g2 = self.g1_field, self.g2_field
                elif shear_type == "extra":
                    try:
                        # TODO: every object with individual shear from input catalog(s)
                        obj.g1, obj.g2 = self.g1_field[j], self.g2_field[j]
                    except Exception:
                        chip_output.logger.error("failed to load external shear.")
                elif shear_type == "catalog":
                    pass
                else:
                    chip_output.logger.error("Unknown shear input")
                    raise ValueError("Unknown shear input")

                header_wcs = generateExtensionHeader(
                    xlen=chip.npix_x,
                    ylen=chip.npix_y,
                    ra=ra_cen,
                    dec=dec_cen,
                    pa=pointing.img_pa.deg,
                    gain=chip.gain,
                    readout=chip.read_noise,
                    dark=chip.dark_noise,
                    saturation=90000,
                    psize=chip.pix_scale,
                    row_num=chip.rowID,
                    col_num=chip.colID,
                    extName='raw')

                pos_img, offset, local_wcs, real_wcs = obj.getPosImg_Offset_WCS(img=chip.img, fdmodel=self.fd_model, chip=chip, verbose=False, img_header=header_wcs)
                if pos_img.x == -1 or pos_img.y == -1:
                    # The object falls outside the chip area (after field distortion).
                    missed_obj += 1
                    obj.unload_SED()
                    continue

                # Draw the object and update the output catalog.
                try:
                    if run_option["out_cat_only"]:
                        isUpdated = True
                        obj.real_pos = obj.getRealPos(chip.img, global_x=obj.posImg.x, global_y=obj.posImg.y,
                                                      img_real_wcs=obj.real_wcs)
                        pos_shear = 0.
                    elif chip.survey_type == "photometric":
                        isUpdated, pos_shear = obj.drawObj_multiband(
                            tel=self.tel,
                            pos_img=pos_img,
                            psf_model=psf_model,
                            bandpass_list=filt.bandpass_sub_list,
                            filt=filt,
                            chip=chip,
                            g1=obj.g1,
                            g2=obj.g2,
                            exptime=pointing.exp_time)
                    elif chip.survey_type == "spectroscopic":
                        isUpdated, pos_shear = obj.drawObj_slitless(
                            tel=self.tel,
                            pos_img=pos_img,
                            psf_model=psf_model,
                            bandpass_list=filt.bandpass_sub_list,
                            filt=filt,
                            chip=chip,
                            g1=obj.g1,
                            g2=obj.g2,
                            exptime=pointing.exp_time,
                            normFilter=norm_filt)
                    if isUpdated:
                        chip_output.cat_add_obj(obj, pos_img, pos_shear)
                except Exception as e:
                    traceback.print_exc()
                    chip_output.logger.error(e)
                # Always release the SED, including for objects that were not drawn
                # (previously a `continue` skipped this and leaked the loaded SED).
                obj.unload_SED()
                del obj

            del psf_model
            del self.cat

        self._log_memory_usage(1, pointing, chip, chip_output)

        # Detector effects, then (for MS pointings) the raw FITS output.
        if not run_option["out_cat_only"]:
            chip.img = chip.addEffects(
                config=self.config,
                img=chip.img,
                chip_output=chip_output,
                filt=filt,
                ra_cen=pointing.ra,
                dec_cen=pointing.dec,
                img_rot=pointing.img_pa,
                pointing_ID=pointing.id,
                timestamp_obs=pointing.timestamp,
                pointing_type=pointing.pointing_type,
                sky_map=sky_map, tel=self.tel,
                logger=chip_output.logger)

            if pointing.pointing_type == 'MS':
                self._write_raw_image(chip, pointing, chip_output)
                chip_output.logger.info("# objects that are too bright %d out of %d" % (bright_obj, self.nobj))
                chip_output.logger.info("# objects that are too dim %d out of %d" % (dim_obj, self.nobj))
                chip_output.logger.info("# objects that are missed %d out of %d" % (missed_obj, self.nobj))
        del chip.img

        self._log_memory_usage(2, pointing, chip, chip_output)

    def runExposure_MPI_PointingList(self, pointing_list, shear_cat_file=None, chips=None, use_mpi=False):
        """Run ``run_one_chip`` for every (pointing, chip) task.

        Parameters
        ----------
        pointing_list : sequence
            Pointings to simulate.
        shear_cat_file : str, optional
            If given, the shear field(s) are loaded from this file once,
            before any chip is run.
        chips : container of int, optional
            Chip IDs to run. Defaults to every non-ignored chip. IDs that are
            ignored or do not exist are skipped.
        use_mpi : bool
            If True, task ``i`` (pointing-major, chip-minor order) runs on MPI
            rank ``i % size``.
        """
        if use_mpi:
            comm = MPI.COMM_WORLD
            ind_thread = comm.Get_rank()
            num_thread = comm.Get_size()

        if chips is None:
            run_chips = self.chip_list
            run_filts = self.filter_list
        else:
            # Only run a particular set of chips.
            run_chips = []
            run_filts = []
            for chip, filt in zip(self.chip_list, self.filter_list):
                if chip.chipID in chips:
                    run_chips.append(chip)
                    run_filts.append(filt)
        # Count the chips actually selected, not the requested IDs. A requested ID
        # may be ignored or nonexistent, which would otherwise cause an IndexError below.
        nchips_per_fp = len(run_chips)

        # Honour the caller-supplied shear catalog; it was previously accepted but ignored.
        if shear_cat_file is not None:
            self.g1_field, self.g2_field, self.nshear = get_shear_field(config=self.config, shear_cat_file=shear_cat_file)

        for ipoint, pointing in enumerate(pointing_list):
            pointing_ID = pointing.id
            for ichip in range(nchips_per_fp):
                # Flattened task index used for the round-robin MPI distribution.
                i = ipoint * nchips_per_fp + ichip
                if use_mpi and i % num_thread != ind_thread:
                    continue

                pid = os.getpid()

                sub_img_dir, prefix = makeSubDir_PointingList(path_dict=self.path_dict, config=self.config, pointing_ID=pointing_ID)

                chip = run_chips[ichip]
                filt = run_filts[ichip]
                chip_output = ChipOutput(
                    config=self.config,
                    focal_plane=self.focal_plane,
                    chip=chip,
                    filt=filt,
                    exptime=pointing.exp_time,
                    pointing_type=pointing.pointing_type,
                    pointing_ID=pointing_ID,
                    subdir=sub_img_dir,
                    prefix=prefix)
                chip_output.logger.info("running pointing#%d, chip#%d, at PID#%d..." % (pointing_ID, chip.chipID, pid))
                try:
                    self.run_one_chip(
                        chip=chip,
                        filt=filt,
                        chip_output=chip_output,
                        pointing=pointing,
                        cat_dir=self.path_dict["cat_dir"])
                    print("finished running chip#%d..." % (chip.chipID), flush=True)
                    chip_output.logger.info("finished running chip#%d..." % (chip.chipID))
                finally:
                    # Detach AND close every handler. Otherwise each chip leaves an
                    # open file handle behind, even when run_one_chip raises.
                    for handler in chip_output.logger.handlers[:]:
                        chip_output.logger.removeHandler(handler)
                        handler.close()