import galsim
import numpy as np
import astropy.constants as cons
from astropy import wcs
from astropy.table import Table

from ObservationSim.MockObject._util import magToFlux, VC_A, convolveGaussXorders
from ObservationSim.MockObject._util import integrate_sed_bandpass, getNormFactorForSpecWithABMAG, getObservedSED, \
    getABMAG
from ObservationSim.MockObject.SpecDisperser import SpecDisperser


class MockObject(object):
    def __init__(self, param, logger=None):
        """Build a mock object from a catalog parameter mapping.

        Every key of ``param`` is mirrored onto the instance as an attribute
        with the same name, and ``self.type`` is derived from
        ``param["star"]`` (0 -> "galaxy", 1 -> "star", 2 -> "quasar").
        For any other value of ``param["star"]`` this constructor does not
        assign ``self.type``.

        Parameters:
            param: mapping of catalog quantities; must contain the key "star".
            logger: optional logger object, stored as ``self.logger``.
        """
        self.param = param

        # Mirror each catalog entry as an instance attribute of the same name.
        for name in self.param:
            setattr(self, name, self.param[name])

        # Translate the numeric "star" flag into a type label.
        for flag, label in ((0, "galaxy"), (1, "star"), (2, "quasar")):
            if self.param["star"] == flag:
                self.type = label
                break

        # SED placeholder; the drawObj_* methods read self.sed.
        self.sed = None

        # Place holder for outputs
        self.additional_output_str = ""

        self.logger = logger

    def getMagFilter(self, filt):
        if filt.filter_type in ["GI", "GV", "GU"]:
            return self.param["mag_use_normal"]
        return self.param["mag_%s" % filt.filter_type]
        # (TEST) stamp size
        # return 13.0

    def getFluxFilter(self, filt):
        """Return the ``flux_<filter_type>`` entry of ``self.param`` for ``filt``."""
        key = "flux_%s" % filt.filter_type
        return self.param[key]

    def getNumPhotons(self, flux, tel, exptime=150.):
        """Return ``flux`` scaled by the collecting area and exposure time.

        ``tel.pupil_area`` is multiplied by 100**2 (m^2 to cm^2) before being
        combined with ``flux`` and ``exptime``.
        """
        area_cm2 = tel.pupil_area * 100. ** 2  # m^2 to cm^2
        per_second = flux * area_cm2
        return per_second * exptime

    def getElectronFluxFilt(self, filt, tel, exptime=150.):
        """Return the catalog flux for ``filt`` times pupil area times exposure time.

        The flux is looked up with ``getFluxFilter``.
        NOTE(review): unlike ``getNumPhotons``, ``tel.pupil_area`` is used here
        without the m^2 -> cm^2 factor; presumably the catalog flux is already
        expressed per m^2 -- confirm.
        """
        band_flux = self.getFluxFilter(filt)
        collected = band_flux * tel.pupil_area
        return collected * exptime

    def getPosWorld(self):
        """Return the object's sky position as a ``galsim.CelestialCoord``.

        ``param["ra"]`` and ``param["dec"]`` are interpreted as degrees.
        """
        ra_value = self.param["ra"]
        dec_value = self.param["dec"]
        ra_angle = ra_value * galsim.degrees
        dec_angle = dec_value * galsim.degrees
        return galsim.CelestialCoord(ra=ra_angle, dec=dec_angle)

    def getPosImg_Offset_WCS(self, img, fdmodel=None, chip=None, verbose=True, img_header=None):
        """Locate the object on ``img`` and cache position, offset and WCS data.

        Sets ``self.posImg``, ``self.localWCS``, ``self.x_nominal``,
        ``self.y_nominal``, ``self.offset`` and ``self.real_wcs``.

        Parameters:
            img: image whose ``wcs`` maps the sky position to pixels.
            fdmodel: optional field-distortion model; applied only when
                ``chip`` is also given.
            chip: chip passed to ``fdmodel.get_Distorted``.
            verbose: print the position before/after field distortion.
            img_header: optional header from which ``galsim.FitsWCS`` builds
                ``self.real_wcs``; ``self.real_wcs`` is None without it.

        Returns:
            (posImg, offset, localWCS, real_wcs)
        """
        image_wcs = img.wcs
        self.posImg = image_wcs.toImage(self.getPosWorld())
        # The local WCS is evaluated at the undistorted position.
        self.localWCS = image_wcs.local(self.posImg)

        distort = (fdmodel is not None) and (chip is not None)
        if distort:
            if verbose:
                print("\n")
                print("Before field distortion:\n")
                print("x = %.2f, y = %.2f\n" % (self.posImg.x, self.posImg.y), flush=True)
            self.posImg = fdmodel.get_Distorted(chip=chip, pos_img=self.posImg)
            if verbose:
                print("After field distortion:\n")
                print("x = %.2f, y = %.2f\n" % (self.posImg.x, self.posImg.y), flush=True)

        # Shift by half a pixel, round to the nearest integer pixel, and keep
        # the fractional remainder as the sub-pixel drawing offset.
        shifted_x = self.posImg.x + 0.5
        shifted_y = self.posImg.y + 0.5
        self.x_nominal = int(np.floor(shifted_x + 0.5))
        self.y_nominal = int(np.floor(shifted_y + 0.5))
        frac_x = shifted_x - self.x_nominal
        frac_y = shifted_y - self.y_nominal
        self.offset = galsim.PositionD(frac_x, frac_y)

        self.real_wcs = galsim.FitsWCS(header=img_header) if img_header is not None else None

        return self.posImg, self.offset, self.localWCS, self.real_wcs

    def getRealPos(self, img, global_x=0., global_y=0., img_real_wcs=None):
        """Re-project a pixel position from ``img.wcs`` into ``img_real_wcs``.

        The point (global_x, global_y) is converted to world coordinates with
        ``img.wcs`` and then back to image coordinates with ``img_real_wcs``.
        ``img_real_wcs`` must not be None despite its default.
        """
        pixel_pos = galsim.PositionD(global_x, global_y)
        sky_pos = img.wcs.toWorld(pixel_pos)
        return img_real_wcs.toImage(sky_pos)

    def drawObject(self, img, final, flux=None, filt=None, tel=None, exptime=150.):
        """ Draw (point like) object on img.
        Should be overided for extended source, e.g. galaxy...
        Paramter:
            img: the "canvas"
            final: final (after shear, PSF etc.) GSObject
        Return:
            img: the image with the GSObject added (or discarded)
            isUpdated: is the "canvas" been updated? (a flag for updating statistcs)
        """
        isUpdated = True

        # Draw with FFT
        # stamp = final.drawImage(wcs=self.localWCS, offset=self.offset)

        # Draw with Photon Shoot
        stamp = final.drawImage(wcs=self.localWCS, method='phot', offset=self.offset)

        stamp.setCenter(self.x_nominal, self.y_nominal)
        if np.sum(np.isnan(stamp.array)) >= 1:
            stamp.setZero()
        bounds = stamp.bounds & img.bounds
        if bounds.area() == 0:
            isUpdated = False
        else:
            img[bounds] += stamp[bounds]
        return img, stamp, isUpdated

    def drawObj_multiband(self, tel, pos_img, psf_model, bandpass_list, filt, chip, nphotons_tot=None, g1=0, g2=0,
                          exptime=150.):
        """Photon-shoot this point source onto ``chip.img``, one sub-band at a time.

        The total photon count is split across ``bandpass_list`` in proportion
        to the SED integral over each sub-band relative to the integral over
        ``filt.bandpass_full``. Each sub-band is drawn as a delta function
        convolved with the PSF for that band; the resulting photons are then
        accumulated onto the chip image through ``chip.sensor``.

        Requires ``self.sed``, ``self.posImg`` and ``self.real_wcs`` to be set
        (the latter two are assigned by ``getPosImg_Offset_WCS``). Sets
        ``self.real_pos``.

        Parameters:
            tel: telescope object; used only when ``nphotons_tot`` is None.
            pos_img: image position forwarded to ``psf_model.get_PSF``.
            psf_model: object whose ``get_PSF`` returns ``(psf, pos_shear)``.
            bandpass_list: sub-bandpasses to draw.
            filt: filter supplying ``filter_type`` and ``bandpass_full``.
            chip: chip supplying ``img``, ``sensor``, ``npix_x``, ``npix_y``
                and ``bound``.
            nphotons_tot: total count to distribute; computed with
                ``getElectronFluxFilt`` when None.
            g1, g2: not used by this implementation.
            exptime: exposure time, used only when ``nphotons_tot`` is None.

        Returns:
            ``(True, pos_shear)`` on success, where ``pos_shear`` comes from
            the last ``psf_model.get_PSF`` call. The bare value ``False`` when
            the full-band SED integration fails or when no sub-band produced
            any photons.
        """
        if nphotons_tot is None:
            nphotons_tot = self.getElectronFluxFilt(filt, tel, exptime)

        try:
            full = integrate_sed_bandpass(sed=self.sed, bandpass=filt.bandpass_full)
        except Exception as e:
            print(e)
            # The logger is optional (defaults to None in __init__).
            if self.logger is not None:
                self.logger.error(e)
            return False

        # (TEST) Galsim Parameters
        folding_threshold = 5.e-4 if self.getMagFilter(filt) <= 15 else 5.e-3
        gsp = galsim.GSParams(folding_threshold=folding_threshold)

        self.real_pos = self.getRealPos(chip.img, global_x=self.posImg.x, global_y=self.posImg.y,
                                        img_real_wcs=self.real_wcs)

        # Nearest integer pixel of the (half-pixel shifted) position plus the
        # fractional remainder used as the sub-pixel drawing offset.
        x, y = self.real_pos.x + 0.5, self.real_pos.y + 0.5
        x_nominal = int(np.floor(x + 0.5))
        y_nominal = int(np.floor(y + 0.5))
        offset = galsim.PositionD(x - x_nominal, y - y_nominal)

        real_wcs_local = self.real_wcs.local(self.real_pos)

        photons_list = []
        xmax, ymax = 0, 0
        pos_shear = None
        for bandpass in bandpass_list:
            try:
                sub = integrate_sed_bandpass(sed=self.sed, bandpass=bandpass)
            except Exception as e:
                print(e)
                if self.logger is not None:
                    self.logger.error(e)
                # A failing sub-band is skipped rather than aborting the object.
                continue

            ratio = sub / full
            # Skip the sub-band on a ratio of -1 or NaN (ratio != ratio).
            if ratio == -1 or ratio != ratio:
                continue
            nphotons = ratio * nphotons_tot

            psf, pos_shear = psf_model.get_PSF(chip=chip, pos_img=pos_img, bandpass=bandpass,
                                               folding_threshold=folding_threshold)
            star = galsim.DeltaFunction(gsparams=gsp)
            star = star.withFlux(nphotons)
            star = galsim.Convolve(psf, star)

            band_stamp = star.drawImage(wcs=real_wcs_local, method='phot', offset=offset, save_photons=True)
            xmax = max(xmax, band_stamp.xmax)
            ymax = max(ymax, band_stamp.ymax)
            photons = band_stamp.photons
            # Move the photons from stamp coordinates to chip coordinates.
            photons.x += x_nominal
            photons.y += y_nominal
            photons_list.append(photons)

        if not photons_list:
            # Nothing was drawn (empty bandpass_list or every sub-band skipped):
            # report failure instead of building a zero-sized stamp and
            # returning an unassigned pos_shear.
            return False

        # Working stamp sized from the largest sub-band stamp, padded by 10%.
        stamp = galsim.ImageF(int(xmax * 1.1), int(ymax * 1.1))
        stamp.wcs = real_wcs_local
        stamp.setCenter(x_nominal, y_nominal)
        bounds = stamp.bounds & galsim.BoundsI(0, chip.npix_x - 1, 0, chip.npix_y - 1)

        if bounds.area() > 0:
            # The overlap is expressed in 0-based chip pixels, so the chip
            # image origin is moved to (0, 0) while copying in and out.
            chip.img.setOrigin(0, 0)
            try:
                stamp[bounds] = chip.img[bounds]
                for i, photons in enumerate(photons_list):
                    if i == 0:
                        chip.sensor.accumulate(photons, stamp)
                    else:
                        chip.sensor.accumulate(photons, stamp, resume=True)
                chip.img[bounds] = stamp[bounds]
            finally:
                # Always restore the chip's own origin, even if accumulation fails.
                chip.img.setOrigin(chip.bound.xmin, chip.bound.ymin)

        return True, pos_shear

    def addSLStoChipImage(self, sdp=None, chip=None, xOrderSigPlus=None, local_wcs=None):
        """Accumulate every spectral order computed by ``sdp`` onto ``chip.img``.

        For each order returned by ``sdp.compute_spec_orders()`` the order
        image is passed through ``convolveGaussXorders`` with the matching
        ``xOrderSigPlus`` entry, converted to photons and accumulated onto the
        chip image through ``chip.sensor``. Orders that fall entirely outside
        the chip are skipped.

        Parameters:
            sdp: disperser object providing ``compute_spec_orders()``.
            chip: chip supplying ``img``, ``sensor``, ``npix_x``, ``npix_y``
                and ``bound``.
            xOrderSigPlus: mapping from order key to the value handed to
                ``convolveGaussXorders``.
            local_wcs: WCS assigned to each temporary stamp.
        """
        spec_orders = sdp.compute_spec_orders()
        for order, entry in spec_orders.items():
            order_img = entry[0]

            # DEBUG: report NaNs present before the convolution (not modified here).
            nan_mask = np.isnan(order_img)
            nan_count = order_img[nan_mask].shape[0]
            if nan_count > 0:
                print("DEBUG: before convolveGaussXorders specImg nan num is", nan_count)

            order_img, shift = convolveGaussXorders(order_img, xOrderSigPlus[order])
            # The offset returned by the convolution is subtracted from both
            # origin coordinates of the order.
            x0 = entry[1] - shift
            y0 = entry[2] - shift

            # DEBUG: NaNs remaining after the convolution are zeroed and reported.
            nan_mask = np.isnan(order_img)
            nan_count = order_img[nan_mask].shape[0]
            if nan_count > 0:
                order_img[nan_mask] = 0
                print("DEBUG: specImg nan num is", nan_count)

            specImg = galsim.ImageF(order_img)
            photons = galsim.PhotonArray.makeFromImage(specImg)
            photons.x += x0
            photons.y += y0

            width = int(specImg.xmax - specImg.xmin + 1)
            height = int(specImg.ymax - specImg.ymin + 1)
            stamp = galsim.ImageF(width, height)
            stamp.wcs = local_wcs
            stamp.setOrigin(x0, y0)

            overlap = stamp.bounds & galsim.BoundsI(0, chip.npix_x - 1, 0, chip.npix_y - 1)
            if overlap.area() != 0:
                # The overlap is in 0-based chip pixels; the chip image origin
                # is moved to (0, 0) for the copy and restored afterwards.
                chip.img.setOrigin(0, 0)
                stamp[overlap] = chip.img[overlap]
                chip.sensor.accumulate(photons, stamp)
                chip.img[overlap] = stamp[overlap]
                chip.img.setOrigin(chip.bound.xmin, chip.bound.ymin)

    def drawObj_slitless(self, tel, pos_img, psf_model, bandpass_list, filt, chip, nphotons_tot=None, g1=0, g2=0,
                         exptime=150., normFilter=None, grating_split_pos=3685):
        """Disperse this point source into slitless spectra on ``chip.img``.

        The SED is normalized to ``param['mag_use_normal']`` with
        ``getNormFactorForSpecWithABMAG``. For every sub-band a PSF-convolved
        delta function is drawn on a 100x100 stamp, dispersed with
        ``SpecDisperser`` and added to the chip via ``addSLStoChipImage``.
        A stamp straddling ``grating_split_pos`` is cut into two parts that
        use ``chip.sls_conf[0]`` and ``chip.sls_conf[1]`` respectively.

        Requires ``self.sed``, ``self.posImg`` and ``self.real_wcs`` to be
        set. Sets ``self.real_pos``.

        Parameters:
            tel: telescope object; ``tel.pupil_area * exptime`` is the flux of
                the drawn delta function.
            pos_img: image position forwarded to ``psf_model.get_PSF``.
            psf_model: object whose ``get_PSF`` returns ``(psf, pos_shear)``.
            bandpass_list: sub-bandpasses (each with ``blue_limit`` and
                ``red_limit``).
            filt: filter used to look up the magnitude for GSParams tuning.
            chip: chip supplying ``img``, ``flat_cube``, ``sls_conf`` and the
                attributes used by ``addSLStoChipImage``.
            nphotons_tot, g1, g2: not used by this implementation.
            exptime: exposure time.
            normFilter: throughput table with a 'SENSITIVITY' column; rows
                above 0.001 define the normalization wavelength range.
            grating_split_pos: x pixel at which the two grating
                configurations meet.

        Returns:
            ``(True, pos_shear)`` on success, where ``pos_shear`` comes from
            the last ``psf_model.get_PSF`` call. The bare value ``False`` when
            ``bandpass_list`` is empty or the normalization factor is 0.
        """
        if len(bandpass_list) == 0:
            # Nothing to disperse; previously this path ended in an
            # unassigned ``pos_shear`` at the return statement.
            return False

        norm_thr_rang_ids = normFilter['SENSITIVITY'] > 0.001
        sedNormFactor = getNormFactorForSpecWithABMAG(ABMag=self.param['mag_use_normal'], spectrum=self.sed,
                                                      norm_thr=normFilter,
                                                      sWave=np.floor(normFilter[norm_thr_rang_ids][0][0]),
                                                      eWave=np.ceil(normFilter[norm_thr_rang_ids][-1][0]))
        if sedNormFactor == 0:
            return False

        folding_threshold = 5.e-4 if self.getMagFilter(filt) <= 15 else 5.e-3
        gsp = galsim.GSParams(folding_threshold=folding_threshold)

        normalSED = Table(np.array([self.sed['WAVELENGTH'], self.sed['FLUX'] * sedNormFactor]).T,
                          names=('WAVELENGTH', 'FLUX'))

        self.real_pos = self.getRealPos(chip.img, global_x=self.posImg.x, global_y=self.posImg.y,
                                        img_real_wcs=self.real_wcs)

        # Nearest integer pixel of the (half-pixel shifted) position plus the
        # fractional remainder used as the sub-pixel drawing offset.
        x, y = self.real_pos.x + 0.5, self.real_pos.y + 0.5
        x_nominal = int(np.floor(x + 0.5))
        y_nominal = int(np.floor(y + 0.5))
        offset = galsim.PositionD(x - x_nominal, y - y_nominal)

        real_wcs_local = self.real_wcs.local(self.real_pos)

        flat_cube = chip.flat_cube

        # Per-order values handed to convolveGaussXorders by addSLStoChipImage.
        xOrderSigPlus = {'A': 1.3909419820029296, 'B': 1.4760376591236062, 'C': 4.035447379743442,
                         'D': 5.5684364343742825, 'E': 16.260021029735388}
        grating_split_pos_chip = 0 + grating_split_pos

        def _disperse(image, xcenter, ycenter, origin, conf, bandpass):
            # Disperse one stamp with one grating configuration and add it to the chip.
            # The band limits are scaled by 10 (presumably nm -> Angstrom; confirm).
            sdp = SpecDisperser(orig_img=image, xcenter=xcenter,
                                ycenter=ycenter, origin=origin,
                                tar_spec=normalSED,
                                band_start=bandpass.blue_limit * 10, band_end=bandpass.red_limit * 10,
                                conf=conf,
                                isAlongY=0,
                                flat_cube=flat_cube)
            self.addSLStoChipImage(sdp=sdp, chip=chip, xOrderSigPlus=xOrderSigPlus, local_wcs=real_wcs_local)

        for bandpass in bandpass_list:
            psf, pos_shear = psf_model.get_PSF(chip=chip, pos_img=pos_img, bandpass=bandpass,
                                               folding_threshold=folding_threshold)
            star = galsim.DeltaFunction(gsparams=gsp)
            star = star.withFlux(tel.pupil_area * exptime)
            star = galsim.Convolve(psf, star)
            starImg = star.drawImage(nx=100, ny=100, wcs=real_wcs_local, offset=offset)

            # Chip coordinates of the stamp's first pixel, stored as [y, x].
            origin_star = [y_nominal - (starImg.center.y - starImg.ymin),
                           x_nominal - (starImg.center.x - starImg.xmin)]
            starImg.setOrigin(0, 0)
            n_cols = starImg.array.shape[1]
            x_first = origin_star[1]
            x_last = origin_star[1] + n_cols - 1

            if x_first < grating_split_pos_chip < x_last:
                # The stamp straddles the split: disperse each side separately.
                subSlitPos = int(grating_split_pos_chip - x_first + 1)

                star_p1 = galsim.Image(starImg.array[:, 0:subSlitPos])
                star_p1.setOrigin(0, 0)
                _disperse(star_p1, min(x_nominal, grating_split_pos_chip - 1), y_nominal,
                          origin_star, chip.sls_conf[0], bandpass)

                # NOTE(review): column index ``subSlitPos`` is in neither half
                # (p1 stops before it, p2 starts after it) -- confirm that this
                # one-column gap is intended.
                star_p2 = galsim.Image(starImg.array[:, subSlitPos + 1:n_cols])
                star_p2.setOrigin(0, 0)
                _disperse(star_p2, max(x_nominal, grating_split_pos_chip - 1), y_nominal,
                          [origin_star[0], grating_split_pos_chip], chip.sls_conf[1], bandpass)
            elif grating_split_pos_chip <= x_first:
                # Entire stamp at or beyond the split position.
                _disperse(starImg, x_nominal, y_nominal, origin_star, chip.sls_conf[1], bandpass)
            elif grating_split_pos_chip >= x_last:
                # Entire stamp before the split position.
                _disperse(starImg, x_nominal, y_nominal, origin_star, chip.sls_conf[0], bandpass)
        return True, pos_shear

    def SNRestimate(self, img_obj, flux, noise_level=0.0, seed=31415):
        """Estimate a signal-to-noise ratio for a drawn image.

        The signal is ``img_obj.added_flux``. The noise is the standard
        deviation of a zeroed copy of ``img_obj`` filled with Gaussian noise of
        width ``noise_level`` (seeded with ``seed``). ``flux`` is not used.

        NOTE(review): with the default ``noise_level=0.0`` the measured noise
        is presumably zero, making the ratio a division by zero -- confirm
        that callers always pass a positive noise level.
        """
        signal = img_obj.added_flux
        noise_stamp = img_obj.copy() * 0.0
        noise_model = galsim.GaussianNoise(galsim.BaseDeviate(seed), sigma=noise_level)
        noise_stamp.addNoise(noise_model)
        measured_sigma = np.std(noise_stamp.array)
        return signal / measured_sigma

    # def getObservedEll(self, g1=0, g2=0):
    #     return 0.0, 0.0, 0.0, 0.0, 0.0, 0.0