import gc
import os
import traceback
from datetime import datetime, timezone

import numpy as np
import mpi4py.MPI as MPI
import galsim
import psutil

from observation_sim.config import ChipOutput
from observation_sim.instruments import Telescope, Filter, FilterParam, FocalPlane, Chip
from observation_sim.instruments.chip import effects
from observation_sim.instruments.chip import chip_utils as chip_utils
from observation_sim.astrometry.Astrometry_util import on_orbit_obs_position
from observation_sim.sim_steps import SimSteps, SIM_STEP_TYPES


class Observation(object):
    """Simulate exposures for a list of pointings, one chip at a time.

    ``runExposure_MPI_PointingList`` builds the chips/filters of every
    pointing and distributes the chips round-robin over the ranks of
    ``MPI.COMM_WORLD``; each assigned chip is simulated by ``run_one_chip``,
    which executes the steps listed in the pointing's
    ``obs_param["call_sequence"]``.
    """

    # Chip IDs that get per-grating/order/band image stacks when
    # ``slsPSFOptim`` is enabled (their grating IDs are looked up with
    # ``chip_utils.getChipSLSGratingID``; presumably the slitless chips).
    _SLS_CHIP_IDS = frozenset([1, 2, 3, 4, 5, 10, 21, 26, 27, 28, 29, 30])

    def __init__(self, config, Catalog, work_dir=None, data_dir=None):
        """Store the configuration and build telescope/filter parameters.

        Args:
            config: overall simulation configuration (dict-like; read through
                keys such as ``"random_seeds"``, ``"obs_setting"`` and
                ``"run_option"``).
            Catalog: passed unchanged to every simulation step as ``catalog``.
            work_dir, data_dir: accepted for interface compatibility; not used
                by this class.
        """
        self.config = config
        self.tel = Telescope()
        self.filter_param = FilterParam()
        self.Catalog = Catalog

    def _seed(self, key, offset=0):
        """Return ``int(config["random_seeds"][key]) + offset``.

        The base seed is converted to ``int`` before the offset is added, so
        every seed is derived the same way and string-valued entries work.
        """
        return int(self.config["random_seeds"][key]) + offset

    @staticmethod
    def _utc_datetime(timestamp):
        """Return the naive UTC ``datetime`` for a POSIX *timestamp*.

        Same result as the deprecated ``datetime.utcfromtimestamp``.
        """
        return datetime.fromtimestamp(timestamp, tz=timezone.utc).replace(tzinfo=None)

    @staticmethod
    def _new_chip_image(chip, wcs):
        """Return a new ``galsim.ImageF`` sized/positioned like *chip* with *wcs* attached."""
        img = galsim.ImageF(chip.npix_x, chip.npix_y)
        img.setOrigin(chip.bound.xmin, chip.bound.ymin)
        img.wcs = wcs
        return img

    def prepare_chip_for_exposure(self, chip, ra_cen, dec_cen, pointing, wcs_fp=None, slsPSFOptim=False):
        """Attach the per-exposure images and random generators to *chip*.

        Sets on *chip*: ``img`` (blank image with WCS *wcs_fp*), ``slsPSFOptim``,
        ``img_stack`` (per-grating/order/band images for the IDs in
        ``_SLS_CHIP_IDS`` when *slsPSFOptim* is true, otherwise ``{}``),
        ``rng_poisson``/``poisson_noise``, ``flat_img``, ``shutter_img`` and
        ``prnu_img``.

        Args:
            chip: chip object to prepare (mutated in place).
            ra_cen, dec_cen: pointing centre used to build the focal-plane TAN
                WCS when *wcs_fp* is None.
            pointing: provides ``img_pa``, ``id`` (Poisson seed offset) and
                ``exp_time`` (shutter effect).
            wcs_fp: optional precomputed WCS. When None it is built from
                ``self.focal_plane``, which ``runExposure_MPI_PointingList``
                sets per pointing.
            slsPSFOptim: allocate the slitless image stacks for eligible chips.

        Returns:
            The same *chip* object.
        """
        if wcs_fp is None:
            wcs_fp = self.focal_plane.getTanWCS(
                ra_cen, dec_cen, pointing.img_pa, chip.pix_scale)

        chip.img = self._new_chip_image(chip, wcs_fp)

        chip.slsPSFOptim = slsPSFOptim
        chip.img_stack = {}
        if chip.chipID in self._SLS_CHIP_IDS and slsPSFOptim:
            grating_ids = chip_utils.getChipSLSGratingID(chip.chipID)
            for igrating in range(2):
                grating = grating_ids[igrating]
                # Only orders 0 and 1 are stacked (orders -2, -1 and 2 are not).
                chip.img_stack[grating] = {
                    "order" + order: {
                        "w" + band: self._new_chip_image(chip, wcs_fp)
                        for band in ['1', '2', '3', '4']
                    }
                    for order in ['0', '1']
                }

        # Per-chip random generators; the Poisson seed also varies per pointing.
        chip.rng_poisson, chip.poisson_noise = chip_utils.get_poisson(
            seed=self._seed("seed_poisson", pointing.id*30 + chip.chipID), sky_level=0.)

        # Flat, shutter and PRNU images.
        chip.flat_img, _ = chip_utils.get_flat(
            img=chip.img, seed=self._seed("seed_flat"))
        if chip.chipID <= 30:
            chip.flat_img = chip.flat_img*chip_utils.get_innerflat(chip=chip)
            chip.shutter_img = effects.ShutterEffectArr(
                chip.img, t_exp=pointing.exp_time, t_shutter=1.3, dist_bearing=735, dt=1E-3)
        else:
            # Chips with ID > 30 get a uniform shutter map.
            chip.shutter_img = np.ones_like(chip.img.array)
        chip.prnu_img = effects.PRNU_Img(xsize=chip.npix_x, ysize=chip.npix_y, sigma=0.01,
                                         seed=self._seed("seed_prnu", chip.chipID))

        return chip

    def _log_pointing_summary(self, chip, pointing, chip_output):
        """Write a summary of *pointing* and *chip* to the chip log."""
        chip_output.Log_info(
            ':::::::::::::::::::Current Pointing Information::::::::::::::::::')
        chip_output.Log_info("RA: %f, DEC; %f" % (pointing.ra, pointing.dec))
        chip_output.Log_info("Time: %s" % self._utc_datetime(
            pointing.timestamp).isoformat())
        chip_output.Log_info("Exposure time: %f" % pointing.exp_time)
        chip_output.Log_info("Satellite Position (x, y, z): (%f, %f, %f)" % (
            pointing.sat_x, pointing.sat_y, pointing.sat_z))
        chip_output.Log_info("Satellite Velocity (x, y, z): (%f, %f, %f)" % (
            pointing.sat_vx, pointing.sat_vy, pointing.sat_vz))
        chip_output.Log_info("Position Angle: %f" % pointing.img_pa.deg)
        chip_output.Log_info('Chip : %d' % chip.chipID)
        chip_output.Log_info(
            ':::::::::::::::::::::::::::END:::::::::::::::::::::::::::::::::::')

    def _astrometric_offset(self, pointing):
        """Return ``(ra_offset, dec_offset)`` from the astrometric model.

        The offset is the nominal pointing minus the position returned by
        ``on_orbit_obs_position``, evaluated with zero proper motion/RV and
        near-zero parallax. Returns ``(0., 0.)`` when
        ``obs_setting.enable_astrometric_model`` is false.
        """
        if not self.config["obs_setting"]["enable_astrometric_model"]:
            return 0., 0.
        obs_time = self._utc_datetime(pointing.timestamp)
        ra_obs, dec_obs = on_orbit_obs_position(
            input_ra_list=[pointing.ra],
            input_dec_list=[pointing.dec],
            input_pmra_list=[0.],
            input_pmdec_list=[0.],
            input_rv_list=[0.],
            input_parallax_list=[1e-9],
            input_nstars=1,
            input_x=pointing.sat_x,
            input_y=pointing.sat_y,
            input_z=pointing.sat_z,
            input_vx=pointing.sat_vx,
            input_vy=pointing.sat_vy,
            input_vz=pointing.sat_vz,
            input_epoch="J2000",
            input_date_str=obs_time.date().isoformat(),
            input_time_str=obs_time.time().isoformat()
        )
        return pointing.ra - ra_obs[0], pointing.dec - dec_obs[0]

    def run_one_chip(self, chip, filt, pointing, chip_output, wcs_fp=None, psf_model=None, cat_dir=None, sed_dir=None):
        """Simulate one chip for one pointing.

        Logs the pointing summary, computes the astrometric offset, prepares
        the chip images, then runs each step of
        ``pointing.obs_param["call_sequence"]`` through the ``SimSteps`` method
        named by ``SIM_STEP_TYPES[step]``. When ``run_option.out_cat_only`` is
        set, only the ``"scie_obs"`` step runs. The first step that raises is
        logged and ends the sequence. Afterwards the chip's image buffers are
        deleted and memory usage is logged.

        Args:
            chip, filt, pointing: chip, its filter and the pointing to simulate.
            chip_output: per-chip output object providing ``Log_info``/``Log_error``.
            wcs_fp: optional precomputed WCS forwarded to
                ``prepare_chip_for_exposure`` (built from the focal plane if None).
            psf_model, cat_dir, sed_dir: accepted for interface compatibility; unused.
        """
        self._log_pointing_summary(chip, pointing, chip_output)
        ra_offset, dec_offset = self._astrometric_offset(pointing)

        # The chip WCS is centred on the nominal pointing; the astrometric
        # correction reaches the steps through SimSteps' ra/dec offsets.
        slsPSFOpt = False
        chip = self.prepare_chip_for_exposure(
            chip, pointing.ra, pointing.dec, pointing, wcs_fp=wcs_fp, slsPSFOptim=slsPSFOpt)

        sim_steps = SimSteps(overall_config=self.config,
                             chip_output=chip_output,
                             all_filters=self.all_filters,
                             ra_offset=ra_offset,
                             dec_offset=dec_offset)

        for step in pointing.obs_param["call_sequence"]:
            if self.config["run_option"]["out_cat_only"] and step != "scie_obs":
                continue
            step_name = SIM_STEP_TYPES[step]
            chip_output.Log_info("Starting simulation step: %s, calling function: %s" % (
                step, step_name))
            # Looked up on the current `pointing`, which a step may replace.
            obs_param = pointing.obs_param["call_sequence"][step]
            try:
                step_func = getattr(sim_steps, step_name)
                chip, filt, _, pointing = step_func(
                    chip=chip,
                    filt=filt,
                    tel=self.tel,
                    pointing=pointing,
                    catalog=self.Catalog,
                    obs_param=obs_param)
                chip_output.Log_info("Finished simulation step: %s" % (step))
            except Exception as e:
                traceback.print_exc()
                chip_output.Log_error(e)
                chip_output.Log_error("Failed simulation on step: %s" % (step))
                break

        # Release the large per-exposure image buffers before the next chip.
        del chip.img
        del chip.flat_img
        del chip.prnu_img
        del chip.shutter_img
        chip_output.Log_info("check running:1: pointing-%d chip-%d pid-%d memory-%6.2fGB" % (pointing.id,
                             chip.chipID, os.getpid(), (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024)))

    def _build_chips_and_filters(self):
        """Create a Chip and Filter for every focal-plane chip ID ``1..nchips``.

        Every filter goes to ``self.all_filters``. Only chips not ignored by
        ``self.focal_plane.isIgnored`` (and their filters) go to
        ``self.chip_list`` / ``self.filter_list``.
        """
        self.chip_list = []
        self.filter_list = []
        self.all_filters = []
        for chipID in range(1, self.focal_plane.nchips + 1):
            chip = Chip(chipID=chipID, config=self.config)
            filter_id, filter_type = chip.getChipFilter()
            filt = Filter(
                filter_id=filter_id,
                filter_type=filter_type,
                filter_param=self.filter_param)
            if not self.focal_plane.isIgnored(chipID=chipID):
                self.chip_list.append(chip)
                self.filter_list.append(filt)
            self.all_filters.append(filt)

    @staticmethod
    def _select_run_chips(chip_list, filter_list, chips):
        """Return the ``(chips, filters)`` lists to simulate.

        When *chips* is None, *chip_list* and *filter_list* are returned as-is.
        Otherwise only entries whose ``chipID`` is in *chips* are kept, in the
        order of *chip_list*; IDs in *chips* that are not present are ignored.
        """
        if chips is None:
            return chip_list, filter_list
        wanted = set(chips)
        selected = [(chip, filt) for chip, filt in zip(chip_list, filter_list)
                    if chip.chipID in wanted]
        return [chip for chip, _ in selected], [filt for _, filt in selected]

    @staticmethod
    def _release_logger(chip_output):
        """Detach and close every handler of ``chip_output.logger``.

        Closing releases resources such as the file opened by a FileHandler
        instead of waiting for garbage collection.
        """
        logger = chip_output.logger
        for handler in logger.handlers[:]:
            logger.removeHandler(handler)
            handler.close()

    def runExposure_MPI_PointingList(self, pointing_list, chips=None):
        """Simulate every pointing in *pointing_list*, splitting chips over MPI ranks.

        Selected chips are numbered consecutively across all pointings. A chip
        is simulated by the rank equal to that number modulo
        ``MPI.COMM_WORLD.Get_size()``. Assuming every rank receives the same
        *pointing_list*, each chip is handled by exactly one rank.

        Args:
            pointing_list: iterable of pointing objects.
            chips: optional collection of chip IDs. When given, only those of
                the pointing's non-ignored chips are run; None runs them all.
        """
        comm = MPI.COMM_WORLD
        ind_thread = comm.Get_rank()
        num_thread = comm.Get_size()

        process_counter = 0
        for pointing in pointing_list:
            pointing_ID = pointing.obs_id

            pointing.make_output_pointing_dir(
                overall_config=self.config, copy_obs_config=True)

            self.focal_plane = FocalPlane(
                chip_list=pointing.obs_param["run_chips"])
            self._build_chips_and_filters()

            run_chips, run_filts = self._select_run_chips(
                self.chip_list, self.filter_list, chips)
            # Count the chips actually selected; len(chips) would index past
            # run_chips when some requested IDs are ignored or absent.
            nchips_per_fp = len(run_chips)

            for ichip, (chip, filt) in enumerate(zip(run_chips, run_filts)):
                if (process_counter + ichip) % num_thread != ind_thread:
                    continue
                pid = os.getpid()

                chip_output = ChipOutput(
                    config=self.config,
                    chip=chip,
                    filt=filt,
                    pointing=pointing
                )
                chip_output.Log_info("running pointing#%d, chip#%d, at PID#%d..." % (
                    int(pointing_ID), chip.chipID, pid))
                self.run_one_chip(
                    chip=chip,
                    filt=filt,
                    chip_output=chip_output,
                    pointing=pointing)
                chip_output.Log_info(
                    "finished running chip#%d..." % (chip.chipID))
                self._release_logger(chip_output)
                del chip_output
                gc.collect()
            process_counter += nchips_per_fp