# MockObject.py (10.3 KiB)
# Author: Fang Yuedong
import galsim
import numpy as np
import astropy.constants as cons
from astropy.table import Table
from ._util import magToFlux, vc_A
from ._util import integrate_sed_bandpass, getNormFactorForSpecWithABMAG, getObservedSED, getABMAG
from .SpecDisperser import SpecDisperser

class MockObject(object):
    """Base class for a simulated source (galaxy, star or quasar) drawn on a chip.

    The object is configured from a catalog-row-like mapping ``param``
    (see ``__init__`` for the keys read there). Drawing methods here treat the
    source as point-like; subclasses may override them for extended sources.
    """

    def __init__(self, param):
        """Store ``param`` and copy frequently used entries onto attributes.

        Keys read: "star", "id", "ra", "dec", "z", "sed_type", "model_tag",
        "mag_use_normal". ``param["star"]`` selects ``self.type``:
        0 -> "galaxy", 1 -> "star", 2 -> "quasar"; any other value leaves
        ``self.type`` unset.
        """
        self.param = param

        if self.param["star"] == 0:
            self.type = "galaxy"
        elif self.param["star"] == 1:
            self.type = "star"
        elif self.param["star"] == 2:
            self.type = "quasar"
        self.id = self.param["id"]
        self.ra = self.param["ra"]
        self.dec = self.param["dec"]
        self.z = self.param["z"]
        self.sed_type = self.param["sed_type"]
        self.model_tag = self.param["model_tag"]
        self.mag_use_normal = self.param["mag_use_normal"]

    def getMagFilter(self, filt):
        """Return the magnitude used for filter ``filt``.

        Filters whose ``filter_type`` is "GI", "GV" or "GU" use
        ``param["mag_use_normal"]``; any other filter reads
        ``param["mag_<filter_type>"]``.
        """
        if filt.filter_type in ["GI", "GV", "GU"]:
            return self.param["mag_use_normal"]
        return self.param["mag_%s" % filt.filter_type]

    def _getFoldingThreshold(self, filt):
        """Return the GalSim ``folding_threshold`` used when drawing in ``filt``.

        Sources with magnitude <= 15 in ``filt`` get the tighter value 5e-4,
        fainter ones 5e-3 (a smaller threshold makes GalSim use larger FFT
        images, keeping more of the profile wings; see ``galsim.GSParams``).
        """
        if self.getMagFilter(filt) <= 15:
            return 5.e-4
        return 5.e-3

    def getNumPhotons(self, flux, tel, exptime=150.):
        """Return ``flux * collecting area * exptime``.

        ``tel.pupil_area`` is converted from m^2 to cm^2 before use, so
        ``flux`` is presumably a per-cm^2-per-second rate — confirm with callers.
        """
        pupil_area = tel.pupil_area * (100.)**2  # m^2 to cm^2
        return flux * pupil_area * exptime

    def getElectronFluxFilt(self, filt, tel, exptime=150.):
        """Return the expected electron count in ``filt`` over ``exptime``.

        The magnitude from ``getMagFilter`` is converted with ``magToFlux``,
        divided by ``filt.getPhotonE()``, multiplied by
        ``vc_A * (1/blue_limit - 1/red_limit)`` (presumably the band's frequency
        width if ``vc_A`` is c in Angstrom/s — confirm), and scaled by the
        filter efficiency, ``tel.pupil_area`` and ``exptime``.
        """
        photonEnergy = filt.getPhotonE()
        flux = magToFlux(self.getMagFilter(filt))
        # NOTE(review): 1.0e4 presumably converts the m^2 pupil area to cm^2
        # (cf. getNumPhotons) — confirm.
        factor = 1.0e4 * flux / photonEnergy * vc_A * (1.0 / filt.blue_limit - 1.0 / filt.red_limit)
        return factor * filt.efficiency * tel.pupil_area * exptime

    def getPosWorld(self):
        """Return the object's sky position as a ``galsim.CelestialCoord``.

        ``param["ra"]`` and ``param["dec"]`` are interpreted in degrees.
        """
        ra = self.param["ra"]
        dec = self.param["dec"]
        return galsim.CelestialCoord(ra=ra*galsim.degrees, dec=dec*galsim.degrees)

    def getPosImg_Offset_WCS(self, img, fdmodel=None, chip=None, verbose=True):
        """Compute the image position, sub-pixel offset and local WCS on ``img``.

        The sky position is mapped through ``img.wcs``; when both ``fdmodel``
        and ``chip`` are given, ``fdmodel.get_Distorted`` is then applied.
        The local WCS is evaluated at the *undistorted* position.

        Sets ``self.posImg``, ``self.localWCS``, ``self.x_nominal``,
        ``self.y_nominal`` and ``self.offset``, which the drawing methods use.

        Returns:
            (posImg, offset, localWCS)
        """
        self.posImg = img.wcs.toImage(self.getPosWorld())
        self.localWCS = img.wcs.local(self.posImg)
        if (fdmodel is not None) and (chip is not None):
            if verbose:
                print("\n")
                print("Before field distortion:\n")
                print("x = %.2f, y = %.2f\n" % (self.posImg.x, self.posImg.y), flush=True)
            self.posImg = fdmodel.get_Distorted(chip=chip, pos_img=self.posImg)
            if verbose:
                print("After field distortion:\n")
                print("x = %.2f, y = %.2f\n" % (self.posImg.x, self.posImg.y), flush=True)
        x, y = self.posImg.x + 0.5, self.posImg.y + 0.5
        # Nominal pixel = nearest integer to x (resp. y); the remaining
        # fraction in [-0.5, 0.5) becomes the drawing offset.
        self.x_nominal = int(np.floor(x + 0.5))
        self.y_nominal = int(np.floor(y + 0.5))
        dx = x - self.x_nominal
        dy = y - self.y_nominal
        self.offset = galsim.PositionD(dx, dy)
        return self.posImg, self.offset, self.localWCS

    def drawObject(self, img, final, flux=None, filt=None, tel=None, exptime=150.):
        """Draw a (point-like) object onto ``img`` by photon shooting.

        Should be overridden for extended sources, e.g. galaxies. Uses
        ``self.localWCS``, ``self.offset`` and ``self.x_nominal``/``y_nominal``
        as set by ``getPosImg_Offset_WCS``.

        Parameters:
            img: the "canvas" image.
            final: the final GSObject (after shear, PSF, etc.).
            flux, filt, tel, exptime: accepted for interface compatibility;
                not used by this implementation.

        Returns:
            img: the canvas, with the object added if it overlaps.
            stamp: the drawn postage stamp (zeroed if it contained NaNs).
            isUpdated: whether the canvas was modified (a flag for updating
                statistics).
        """
        # Photon shooting; FFT drawing (drop method='phot') is the alternative.
        stamp = final.drawImage(wcs=self.localWCS, method='phot', offset=self.offset)
        stamp.setCenter(self.x_nominal, self.y_nominal)
        # Never propagate NaNs into the shared canvas.
        if np.isnan(stamp.array).any():
            stamp.setZero()
        bounds = stamp.bounds & img.bounds
        if bounds.area() == 0:
            return img, stamp, False
        img[bounds] += stamp[bounds]
        return img, stamp, True

    def drawObj_multiband(self, tel, pos_img, psf_model, bandpass_list, filt, chip, nphotons_tot=None, g1=0, g2=0, exptime=150.):
        """Photon-shoot a point source through sub-band PSFs onto ``chip.img``.

        ``nphotons_tot`` is split across ``bandpass_list`` in proportion to
        the SED (``self.sed``) integrated over each sub-band relative to its
        integral over ``filt.bandpass_full``. Each sub-band is convolved with
        the PSF from ``psf_model.get_PSF`` and shot with ``save_photons=True``;
        all photons are then accumulated through ``chip.sensor``.

        Parameters:
            tel: telescope; used only when ``nphotons_tot`` is None.
            pos_img: image position forwarded to ``psf_model.get_PSF``.
            psf_model: provides ``get_PSF(chip, pos_img, bandpass, folding_threshold)``.
            bandpass_list: sub-bandpasses to draw.
            filt: filter providing ``bandpass_full`` and the magnitude key.
            chip: provides ``img`` and ``sensor``.
            nphotons_tot: total photon budget; computed with
                ``getElectronFluxFilt`` when None.
            g1, g2: accepted for interface compatibility; not used here.
            exptime: forwarded to ``getElectronFluxFilt``.

        Returns:
            (isUpdated, pos_shear): ``(True, pos_shear)`` once photons are
            written to the chip, ``pos_shear`` coming from the last ``get_PSF``
            call; ``(False, pos_shear or None)`` when nothing was drawn.
        """
        if nphotons_tot is None:
            nphotons_tot = self.getElectronFluxFilt(filt, tel, exptime)

        try:
            full = integrate_sed_bandpass(sed=self.sed, bandpass=filt.bandpass_full)
        except Exception as e:
            print(e)
            return False, None
        # Without flux in the full band the sub-band ratios are undefined.
        if full == 0:
            return False, None

        folding_threshold = self._getFoldingThreshold(filt)
        gsp = galsim.GSParams(folding_threshold=folding_threshold)

        photons_list = []
        xmax, ymax = 0, 0
        pos_shear = None
        for bandpass in bandpass_list:
            try:
                sub = integrate_sed_bandpass(sed=self.sed, bandpass=bandpass)
            except Exception as e:
                # Best effort: skip sub-bands whose integration fails.
                print(e)
                continue

            ratio = sub / full
            if ratio == -1 or np.isnan(ratio):
                continue
            nphotons = ratio * nphotons_tot

            psf, pos_shear = psf_model.get_PSF(chip=chip, pos_img=pos_img, bandpass=bandpass, folding_threshold=folding_threshold)
            star = galsim.DeltaFunction(gsparams=gsp)
            star = star.withFlux(nphotons)
            star = galsim.Convolve(psf, star)

            stamp = star.drawImage(wcs=self.localWCS, method='phot', offset=self.offset, save_photons=True)
            xmax = max(xmax, stamp.xmax)
            ymax = max(ymax, stamp.ymax)
            photons = stamp.photons
            # Shift photon positions onto the nominal pixel on the chip.
            photons.x += self.x_nominal
            photons.y += self.y_nominal
            photons_list.append(photons)

        if not photons_list:
            # No sub-band contributed photons.
            return False, pos_shear

        # Stamp sized from the largest sub-band stamp plus ~10% margin.
        stamp = galsim.ImageF(int(xmax * 1.1), int(ymax * 1.1))
        stamp.wcs = self.localWCS
        stamp.setCenter(self.x_nominal, self.y_nominal)
        bounds = stamp.bounds & chip.img.bounds
        if bounds.area() == 0:
            # Stamp falls entirely off the chip: nothing to write back.
            return False, pos_shear
        # Seed with the current chip pixels so accumulation adds on top of them.
        stamp[bounds] = chip.img[bounds]
        for i, photons in enumerate(photons_list):
            if i == 0:
                chip.sensor.accumulate(photons, stamp)
            else:
                # Continue accumulating onto the same image as the previous call.
                chip.sensor.accumulate(photons, stamp, resume=True)

        chip.img[bounds] = stamp[bounds]
        return True, pos_shear

    def drawObj_slitless(self, tel, pos_img, psf_model, bandpass_list, filt, chip, nphotons_tot=None, g1=0, g2=0,
                         exptime=150., normFilter=None):
        """Disperse the source's normalised SED through ``SpecDisperser`` onto ``chip.img``.

        ``self.sed`` is scaled by ``getNormFactorForSpecWithABMAG`` against
        ``normFilter`` and ``param['mag_use_normal']``. For each sub-band a
        PSF image (100x100 pixels) is drawn and passed to ``SpecDisperser``,
        and every computed spectral order is turned into photons accumulated
        through ``chip.sensor``.

        Parameters:
            tel: telescope; ``tel.pupil_area`` scales the PSF image flux.
            pos_img: image position forwarded to ``psf_model.get_PSF``.
            psf_model: provides ``get_PSF(chip, pos_img, bandpass, folding_threshold)``.
            bandpass_list: sub-bandpasses to draw.
            filt: filter used to pick the folding threshold.
            chip: provides ``img``, ``sensor``, ``bound`` and ``sls_conf``.
            nphotons_tot, g1, g2: accepted for interface compatibility; unused.
            exptime: exposure time scaling the PSF image flux.
            normFilter: table with a 'SENSITIVITY' column whose rows' first
                field is presumably wavelength — confirm. Required.

        Returns:
            (isUpdated, pos_shear): ``(True, pos_shear)`` after all sub-bands are
            processed; ``(False, None)`` when the normalisation factor is 0 or
            ``bandpass_list`` is empty.

        Raises:
            ValueError: if ``normFilter`` is None.
        """
        if normFilter is None:
            raise ValueError("drawObj_slitless requires normFilter")
        if len(bandpass_list) == 0:
            return False, None

        # Normalisation range: first to last row with SENSITIVITY > 0.001.
        norm_thr_rang_ids = normFilter['SENSITIVITY'] > 0.001
        sedNormFactor = getNormFactorForSpecWithABMAG(ABMag=self.param['mag_use_normal'], spectrum=self.sed,
                                                      norm_thr=normFilter,
                                                      sWave=np.floor(normFilter[norm_thr_rang_ids][0][0]),
                                                      eWave=np.ceil(normFilter[norm_thr_rang_ids][-1][0]))
        if sedNormFactor == 0:
            return False, None

        folding_threshold = self._getFoldingThreshold(filt)
        gsp = galsim.GSParams(folding_threshold=folding_threshold)

        normalSED = Table(np.array([self.sed['WAVELENGTH'], self.sed['FLUX'] * sedNormFactor]).T,
                          names=('WAVELENGTH', 'FLUX'))

        pos_shear = None
        for bandpass in bandpass_list:
            psf, pos_shear = psf_model.get_PSF(chip=chip, pos_img=pos_img, bandpass=bandpass, folding_threshold=folding_threshold)
            star = galsim.DeltaFunction(gsparams=gsp)
            star = star.withFlux(tel.pupil_area * exptime)
            star = galsim.Convolve(psf, star)
            starImg = star.drawImage(nx=100, ny=100, wcs=self.localWCS)

            origin_star = [self.y_nominal - (starImg.center.y - starImg.ymin),
                           self.x_nominal - (starImg.center.x - starImg.xmin)]

            sdp = SpecDisperser(orig_img=starImg, xcenter=self.x_nominal-chip.bound.xmin,
                                ycenter=self.y_nominal-chip.bound.ymin, origin=origin_star,
                                tar_spec=normalSED,
                                band_start=bandpass.blue_limit * 10, band_end=bandpass.red_limit * 10,
                                conf=chip.sls_conf[1],
                                isAlongY=0)

            spec_orders = sdp.compute_spec_orders()
            for order in spec_orders.values():
                img_s, origin_order_x, origin_order_y = order[0], order[1], order[2]
                specImg = galsim.ImageF(img_s)
                photons = galsim.PhotonArray.makeFromImage(specImg)
                photons.x += origin_order_x
                photons.y += origin_order_y

                xlen_imf = int(specImg.xmax - specImg.xmin + 1)
                ylen_imf = int(specImg.ymax - specImg.ymin + 1)
                stamp = galsim.ImageF(xlen_imf, ylen_imf)
                stamp.wcs = self.localWCS
                stamp.setOrigin(origin_order_x, origin_order_y)

                bounds = stamp.bounds & chip.img.bounds
                if bounds.area() == 0:
                    # This order falls entirely off the chip.
                    continue
                stamp[bounds] = chip.img[bounds]
                chip.sensor.accumulate(photons, stamp)
                chip.img[bounds] = stamp[bounds]
                del stamp
            # Release per-band intermediates before the next band.
            del sdp
            del spec_orders
            del psf
        return True, pos_shear

    def SNRestimate(self, img_obj, flux, noise_level=0.0, seed=31415):
        """Estimate the SNR as ``img_obj.added_flux`` over a noise estimate.

        The noise is the standard deviation of a zeroed copy of ``img_obj``
        with ``galsim.GaussianNoise(sigma=noise_level)`` added (seeded by
        ``seed``). ``flux`` is not used.

        NOTE(review): with the default ``noise_level=0.0`` the noise estimate
        is 0, so the ratio is not finite — callers should pass a positive level.
        """
        img_flux = img_obj.added_flux
        stamp = img_obj.copy() * 0.0
        rng = galsim.BaseDeviate(seed)
        gaussianNoise = galsim.GaussianNoise(rng, sigma=noise_level)
        stamp.addNoise(gaussianNoise)
        sig_obj = np.std(stamp.array)
        snr_obj = img_flux / sig_obj
        return snr_obj

    def getObservedEll(self, g1=0, g2=0):
        """Return six zeros; the base class reports no ellipticity."""
        return 0.0, 0.0, 0.0, 0.0, 0.0, 0.0