import os
import galsim
import random
import numpy as np
import h5py as h5
import healpy as hp
import astropy.constants as cons
import traceback
from astropy.coordinates import spherical_to_cartesian
from astropy.table import Table
from scipy import interpolate
from datetime import datetime

from ObservationSim.MockObject import CatalogBase, Star, Galaxy, Quasar, Stamp
from ObservationSim.MockObject._util import tag_sed, getObservedSED, getABMAG, integrate_sed_bandpass, comoving_dist
from ObservationSim.Astrometry.Astrometry_util import on_orbit_obs_position

import astropy.io.fits as fitsio
from ObservationSim.MockObject._util import seds, sed_assign, extAv

# (TEST)
from astropy.cosmology import FlatLambdaCDM
from astropy import constants
from astropy import units as U

try:
    import importlib.resources as pkg_resources
except ImportError:
    # Try backported to PY<37 'importlib_resources'
    import importlib_resources as pkg_resources

NSIDE = 128

def get_bundleIndex(healpixID_ring, bundleOrder=4, healpixOrder=7):
    """Return the RING index of the coarse "bundle" pixel containing a pixel.

    The input pixel is converted from RING to NESTED ordering, right-shifted
    by two bits for every order separating ``healpixOrder`` from
    ``bundleOrder``, and the result is converted back to RING ordering at
    the bundle resolution.

    Parameters
    ----------
    healpixID_ring
        Pixel index in RING ordering at ``nside = 2**healpixOrder``.
    bundleOrder : int
        HEALPix order of the bundle grid (``nside = 2**bundleOrder``).
    healpixOrder : int
        HEALPix order of the input pixel; must match the module-level NSIDE.

    Returns
    -------
    Bundle pixel index in RING ordering at ``nside = 2**bundleOrder``.
    """
    assert NSIDE == 2**healpixOrder

    fine_nside = 2**healpixOrder
    coarse_nside = 2**bundleOrder
    # Two bits of the NESTED index are dropped per order of resolution.
    bits_to_drop = 2 * (healpixOrder - bundleOrder)

    fine_nested = hp.ring2nest(fine_nside, healpixID_ring)
    coarse_nested = fine_nested >> bits_to_drop
    return hp.nest2ring(coarse_nside, coarse_nested)

class Catalog(CatalogBase):
    def __init__(self, config, chip, pointing, chip_output, filt, **kwargs):
        """Configure the catalog for one chip/pointing/filter and load its objects.

        Resolves catalog and SED-template paths from ``config``, loads the SED
        library of every enabled object type, extends the output catalog
        header, determines the HEALPix pixels covering the chip and finally
        fills ``self.objs`` through ``_load``.

        Parameters
        ----------
        config : mapping
            Simulation configuration. ``data_dir`` and ``catalog_options``
            (``input_path``, ``SED_templates_path``, ``seed_Av``,
            ``galaxy_only``, ``star_only``, ``stamp_yes`` and optionally
            ``rotateEll``) are read here.
        chip, pointing, filt
            Chip, pointing and filter objects, kept for the loaders.
        chip_output
            Output handler; provides the logger and receives the additional
            header columns.
        **kwargs
            Accepted but not used.
        """
        super().__init__()

        data_dir = config["data_dir"]
        cat_opts = config["catalog_options"]
        in_paths = cat_opts["input_path"]

        self.cat_dir = os.path.join(data_dir, in_paths["cat_dir"])
        self.seed_Av = cat_opts["seed_Av"]

        # (TEST) cosmology used by load_sed to scale galaxy SEDs
        self.cosmo = FlatLambdaCDM(H0=67.66, Om0=0.3111)

        self.chip_output = chip_output
        self.filt = filt
        self.logger = chip_output.logger

        # Filter curve shipped with the package; returned by load_norm_filt for stars.
        with pkg_resources.path('Catalog.data', 'SLOAN_SDSS.g.fits') as filter_path:
            self.normF_star = Table.read(str(filter_path))

        self.config = config
        self.chip = chip
        self.pointing = pointing

        self.max_size = 0.

        # Stars
        if "star_cat" in in_paths and in_paths["star_cat"] and not cat_opts["galaxy_only"]:
            star_file = in_paths["star_cat"]
            star_SED_file = cat_opts["SED_templates_path"]["star_SED"]
            self.star_path = os.path.join(self.cat_dir, star_file)
            self.star_SED_path = os.path.join(data_dir, star_SED_file)
            self._load_SED_lib_star()

        # Galaxies
        if "galaxy_cat" in in_paths and in_paths["galaxy_cat"] and not cat_opts["star_only"]:
            self.galaxy_path = os.path.join(self.cat_dir, in_paths["galaxy_cat"])
            self.galaxy_SED_path = os.path.join(data_dir, cat_opts["SED_templates_path"]["galaxy_SED"])
            self._load_SED_lib_gals()

        # AGNs / quasars (note: this catalog path is relative to data_dir, not cat_dir)
        if "AGN_cat" in in_paths and in_paths["AGN_cat"] and not cat_opts["star_only"]:
            self.AGN_path = os.path.join(data_dir, in_paths["AGN_cat"])
            self.AGN_SED_path = os.path.join(data_dir, cat_opts["SED_templates_path"]["AGN_SED"])
            self.AGN_SED_wave_path = os.path.join(data_dir, cat_opts["SED_templates_path"]["AGN_SED_WAVE"])
            self._load_SED_lib_AGN()

        ###mock_stamp_START
        if "stamp_cat" in in_paths and in_paths["stamp_cat"] and cat_opts["stamp_yes"]:
            self.stamp_path = os.path.join(self.cat_dir, in_paths["stamp_cat"])
            # Template directory is hard-coded: only for test
            self.tempSed_gal, self.tempRed_gal = seds("galaxy.list", seddir="/share/simudata/CSSOSDataProductsSims/data/Templates/Galaxy/")
        ###mock_stamp_END

        # Optional rotation, given in degrees in the config and stored in radians.
        self.rotation = np.radians(float(cat_opts["rotateEll"])) if "rotateEll" in cat_opts else 0.

        # Update output .cat header with catalog specific output columns
        self._add_output_columns_header()

        self._get_healpix_list()
        self._load()
    
    def _add_output_columns_header(self):
        self.add_hdr = " model_tag teff logg feh"
        self.add_hdr += " bulgemass diskmass detA e1 e2 kappa g1 g2 size galType veldisp "

        self.add_fmt = " %10s %8.4f %8.4f %8.4f"
        self.add_fmt += " %8.4f %8.4f %8.4f %8.4f %8.4f %8.4f %8.4f %8.4f %8.4f %4d %8.4f "
        self.chip_output.update_ouptut_header(additional_column_names=self.add_hdr)

    def _get_healpix_list(self):
        """Find the HEALPix pixels (at NSIDE) overlapping the enlarged chip footprint.

        Stores the enlarged sky coverage on ``self.sky_coverage`` and the
        RING-ordered pixel indices returned by ``hp.query_polygon`` on
        ``self.pix_list``, then reports the list through the logger (or
        stdout when no logger is set).
        """
        self.sky_coverage = self.chip.getSkyCoverageEnlarged(self.chip.img.wcs, margin=0.2)
        box = self.sky_coverage
        ra_lo, ra_hi, dec_lo, dec_hi = box.xmin, box.xmax, box.ymin, box.ymax

        # Corners of the RA/Dec bounding box, converted from degrees to radians.
        corner_ra = np.deg2rad(np.array([ra_lo, ra_hi, ra_hi, ra_lo]))
        corner_dec = np.deg2rad(np.array([dec_hi, dec_hi, dec_lo, dec_lo]))

        # healpy takes colatitude (pi/2 - dec) rather than declination.
        corner_colat = np.radians(90.) - corner_dec
        corner_vectors = hp.ang2vec(corner_colat, corner_ra)
        self.pix_list = hp.query_polygon(NSIDE, corner_vectors, inclusive=True)

        if self.logger is None:
            print("HEALPix List: ", self.pix_list)
        else:
            self.logger.info(str(("HEALPix List: ", self.pix_list)))
    def load_norm_filt(self, obj):
        if obj.type == "star":
            return self.normF_star
        elif obj.type == "galaxy" or obj.type == "quasar":
            # return self.normF_galaxy
            return None
        ###mock_stamp_START
        elif obj.type == "stamp":
            #return self.normF_galaxy  ###normalize_filter for stamp
            return None
        ###mock_stamp_END
        else:
            return None

    def _load_SED_lib_star(self):
        """Open the stellar SED template file read-only as ``self.tempSED_star``."""
        # The handle stays open on purpose: load_sed hands it to tag_sed for every star.
        self.tempSED_star = h5.File(self.star_SED_path, mode='r')

    def _load_SED_lib_gals(self):
        """Load the galaxy SED basis and its wavelength grid into memory.

        Reads dataset ``lamb`` from ``lamb.h5`` into ``self.lamb_gal`` and
        dataset ``pcs`` from ``pcs.h5`` into ``self.pcs`` (the matrix that
        load_sed multiplies with each galaxy's coefficients). Both files live
        in ``self.galaxy_SED_path``.
        """
        pcs_path = os.path.join(self.galaxy_SED_path, "pcs.h5")
        lamb_path = os.path.join(self.galaxy_SED_path, "lamb.h5")
        # ``[()]`` copies the whole dataset into a NumPy array, so both files
        # can be closed as soon as the reads are done.
        with h5.File(pcs_path, "r") as pcs_file, h5.File(lamb_path, "r") as lamb_file:
            self.lamb_gal = lamb_file['lamb'][()]
            self.pcs = pcs_file['pcs'][()]

    def _load_SED_lib_AGN(self):
        """Load the AGN SED library and its wavelength grid.

        ``self.SED_AGN`` is the data of the primary HDU of the FITS file at
        ``self.AGN_SED_path`` (load_sed indexes it by object id);
        ``self.lamb_AGN`` is the array stored in the ``.npy`` file at
        ``self.AGN_SED_wave_path``.
        """
        from astropy.io import fits

        # NOTE(review): the FITS handle is not closed here; the data array may
        # be memory-mapped from it, so it is left open as in the original.
        hdu_list = fits.open(self.AGN_SED_path)
        self.SED_AGN = hdu_list[0].data
        self.lamb_AGN = np.load(self.AGN_SED_wave_path)


    def _load_gals(self, gals, pix_id=None, cat_id=0):
        """Create Galaxy objects from one HEALPix pixel group of the galaxy catalog.

        Galaxies are skipped when they are too dim for the current filter,
        when their total ellipticity exceeds 0.9, or when they fall outside
        the chip (with ``margin=200``). Survivors are appended to
        ``self.objs``; ``self.max_size`` and ``self.ids`` are updated.

        Parameters
        ----------
        gals
            Mapping of column name to dataset for one pixel (the HDF5 group
            passed in by ``_load``).
        pix_id
            HEALPix pixel index; first 6 digits of the object id.
        cat_id : int
            Bundle index; middle 6 digits of the object id.
        """
        # Read every column once. Indexing an HDF5 dataset element by element
        # inside the loop issues a separate file read for each access.
        ra_orig_arr = gals['ra'][:]
        dec_orig_arr = gals['dec'][:]
        ngals = len(ra_orig_arr)
        mag_arr = gals['mag_csst_%s'%(self.filt.filter_type)][:]
        z_arr = gals['redshift'][:]
        shear_arr = gals['shear'][:]
        kappa_arr = gals['kappa'][:]
        ell_arr = gals['ellipticity_true'][:]
        bulgemass_arr = gals['bulgemass'][:]
        diskmass_arr = gals['diskmass'][:]
        size_arr = gals['size'][:]
        coeff_arr = gals['coeff'][:]
        detA_arr = gals['detA'][:]
        type_arr = gals['type'][:]
        veldisp_arr = gals['veldisp'][:]

        # Apply astrometric modeling
        # in C3 case only aberration
        ra_arr = ra_orig_arr
        dec_arr = dec_orig_arr
        if self.config["obs_setting"]["enable_astrometric_model"]:
            ra_list = ra_arr.tolist()
            dec_list = dec_arr.tolist()
            # Galaxies get zero proper motion / radial velocity and a tiny parallax.
            pmra_list = np.zeros(ngals).tolist()
            pmdec_list = np.zeros(ngals).tolist()
            rv_list = np.zeros(ngals).tolist()
            parallax_list = [1e-9] * ngals
            dt = datetime.utcfromtimestamp(self.pointing.timestamp)
            date_str = dt.date().isoformat()
            time_str = dt.time().isoformat()
            ra_arr, dec_arr = on_orbit_obs_position(
                input_ra_list=ra_list,
                input_dec_list=dec_list,
                input_pmra_list=pmra_list,
                input_pmdec_list=pmdec_list,
                input_rv_list=rv_list,
                input_parallax_list=parallax_list,
                input_nstars=ngals,
                input_x=self.pointing.sat_x,
                input_y=self.pointing.sat_y,
                input_z=self.pointing.sat_z,
                input_vx=self.pointing.sat_vx,
                input_vy=self.pointing.sat_vy,
                input_vz=self.pointing.sat_vz,
                input_epoch="J2000",
                input_date_str=date_str,
                input_time_str=time_str
            )

        for igals in range(ngals):
            param = self.initialize_param()
            param['ra'] = ra_arr[igals]
            param['dec'] = dec_arr[igals]
            param['ra_orig'] = ra_orig_arr[igals]
            param['dec_orig'] = dec_orig_arr[igals]
            param['mag_use_normal'] = mag_arr[igals]
            if self.filt.is_too_dim(mag=param['mag_use_normal'], margin=self.config["obs_setting"]["mag_lim_margin"]):
                continue

            param['z'] = z_arr[igals]
            param['model_tag'] = 'None'
            param['g1'] = shear_arr[igals][0]
            param['g2'] = shear_arr[igals][1]
            param['kappa'] = kappa_arr[igals]
            param['e1'] = ell_arr[igals][0]
            param['e2'] = ell_arr[igals][1]

            # For shape calculation: bulge and disk share the total ellipticity
            param['ell_total'] = np.sqrt(param['e1']**2 + param['e2']**2)
            if param['ell_total'] > 0.9:
                continue
            param['e1_disk'] = param['e1']
            param['e2_disk'] = param['e2']
            param['e1_bulge'] = param['e1']
            param['e2_bulge'] = param['e2']

            param['delta_ra'] = 0
            param['delta_dec'] = 0

            # Masses
            param['bulgemass'] = bulgemass_arr[igals]
            param['diskmass'] = diskmass_arr[igals]

            # max_size is tracked before the chip footprint cut below, so it
            # also covers galaxies that are later rejected by that cut.
            param['size'] = size_arr[igals]
            if param['size'] > self.max_size:
                self.max_size = param['size']

            # Sizes: the dominant component gets the catalog size, the other a scaled-down one
            # NOTE(review): bulgemass + diskmass == 0 is not guarded against — confirm the catalog excludes it.
            param['bfrac'] = param['bulgemass']/(param['bulgemass'] + param['diskmass'])
            if param['bfrac'] >= 0.6:
                param['hlr_bulge'] = param['size']
                param['hlr_disk'] = param['size'] * (1. - param['bfrac'])
            else:
                param['hlr_disk'] = param['size']
                param['hlr_bulge'] = param['size'] * param['bfrac']

            # SED coefficients; copied so the object does not keep the whole column alive
            param['coeff'] = coeff_arr[igals].copy()
            param['detA'] = detA_arr[igals]

            # Others
            param['galType'] = type_arr[igals]
            param['veldisp'] = veldisp_arr[igals]

            # TEST no reddening and no extinction
            param['av'] = 0.0
            param['redden'] = 0

            param['star'] = 0   # Galaxy

            # Footprint cut: skip galaxies that do not land on this chip
            if not self.chip.isContainObj(ra_obj=param['ra'], dec_obj=param['dec'], margin=200):
                continue

            # TEMP: running counter is kept, but the id is built from pixel, bundle and row index
            self.ids += 1
            param['id'] = '%06d'%(int(pix_id)) + '%06d'%(cat_id) + '%08d'%(igals)

            obj = Galaxy(param, logger=self.logger)

            # Additional output columns: stellar columns are placeholders for galaxies
            obj.additional_output_str = self.add_fmt%("n", 0., 0., 0.,
                                                    param['bulgemass'], param['diskmass'], param['detA'],
                                                    param['e1'], param['e2'], param['kappa'], param['g1'], param['g2'], param['size'],
                                                    param['galType'], param['veldisp'])

            self.objs.append(obj)

    def _load_stars(self, stars, pix_id=None):
        """Create Star objects from one HEALPix pixel group of the star catalog.

        Stars outside the chip (with ``margin=200``) are skipped; the rest
        are appended to ``self.objs`` and counted in ``self.ids``.

        Parameters
        ----------
        stars
            Mapping of column name to dataset for one pixel (the HDF5 group
            passed in by ``_load``).
        pix_id
            HEALPix pixel index; not used by this loader.
        """
        # Read every column once. Indexing an HDF5 dataset element by element
        # inside the loop issues a separate file read for each access.
        source_id_arr = stars['sourceID'][:]
        nstars = len(source_id_arr)
        ra_orig_arr = stars["RA"][:]
        dec_orig_arr = stars["Dec"][:]
        pmra_arr = stars['pmra'][:]
        pmdec_arr = stars['pmdec'][:]
        rv_arr = stars['RV'][:]
        parallax_arr = stars['parallax'][:]
        mag_arr = stars['app_sdss_g'][:]
        model_tag_arr = stars['model_tag'][:]
        teff_arr = stars['teff'][:]
        grav_arr = stars['grav'][:]
        feh_arr = stars['feh'][:]

        # Apply astrometric modeling
        ra_arr = ra_orig_arr
        dec_arr = dec_orig_arr
        if self.config["obs_setting"]["enable_astrometric_model"]:
            dt = datetime.utcfromtimestamp(self.pointing.timestamp)
            date_str = dt.date().isoformat()
            time_str = dt.time().isoformat()
            ra_arr, dec_arr = on_orbit_obs_position(
                input_ra_list=ra_arr.tolist(),
                input_dec_list=dec_arr.tolist(),
                input_pmra_list=pmra_arr.tolist(),
                input_pmdec_list=pmdec_arr.tolist(),
                input_rv_list=rv_arr.tolist(),
                input_parallax_list=parallax_arr.tolist(),
                input_nstars=nstars,
                input_x=self.pointing.sat_x,
                input_y=self.pointing.sat_y,
                input_z=self.pointing.sat_z,
                input_vx=self.pointing.sat_vx,
                input_vy=self.pointing.sat_vy,
                input_vz=self.pointing.sat_vz,
                input_epoch="J2000",
                input_date_str=date_str,
                input_time_str=time_str
            )

        for istars in range(nstars):
            param = self.initialize_param()
            param['ra'] = ra_arr[istars]
            param['dec'] = dec_arr[istars]
            param['ra_orig'] = ra_orig_arr[istars]
            param['dec_orig'] = dec_orig_arr[istars]
            param['pmra'] = pmra_arr[istars]
            param['pmdec'] = pmdec_arr[istars]
            param['rv'] = rv_arr[istars]
            param['parallax'] = parallax_arr[istars]
            # Footprint cut: skip stars that do not land on this chip
            if not self.chip.isContainObj(ra_obj=param['ra'], dec_obj=param['dec'], margin=200):
                continue
            param['mag_use_normal'] = mag_arr[istars]
            self.ids += 1
            # The catalog source id doubles as object id and SED type.
            param['id'] = source_id_arr[istars]
            param['sed_type'] = source_id_arr[istars]
            param['model_tag'] = model_tag_arr[istars]
            param['teff'] = teff_arr[istars]
            param['logg'] = grav_arr[istars]
            param['feh'] = feh_arr[istars]
            param['z'] = 0.0
            param['star'] = 1   # Star
            obj = Star(param, logger=self.logger)

            # Additional output columns: galaxy columns are placeholders for stars
            obj.additional_output_str = self.add_fmt%(param["model_tag"], param['teff'], param['logg'], param['feh'],
                                                    0., 0., 0., 0., 0., 0., 0., 0., 0., -1, 0.)

            self.objs.append(obj)

    def _load_AGNs(self):
        """Create Quasar objects from the AGN catalog table at ``self.AGN_path``.

        The whole table is read (it is not split by HEALPix pixel); rows
        outside the chip (with ``margin=200``) are skipped and the rest are
        appended to ``self.objs``.
        """
        agn_table = Table.read(self.AGN_path)
        ra_obs = agn_table['ra']
        dec_obs = agn_table['dec']
        n_agn = len(agn_table)

        if self.config["obs_setting"]["enable_astrometric_model"]:
            obs_time = datetime.utcfromtimestamp(self.pointing.timestamp)
            sat = self.pointing
            # AGNs get zero proper motion / radial velocity and a tiny parallax.
            ra_obs, dec_obs = on_orbit_obs_position(
                input_ra_list=ra_obs.tolist(),
                input_dec_list=dec_obs.tolist(),
                input_pmra_list=np.zeros(n_agn).tolist(),
                input_pmdec_list=np.zeros(n_agn).tolist(),
                input_rv_list=np.zeros(n_agn).tolist(),
                input_parallax_list=[1e-9] * n_agn,
                input_nstars=n_agn,
                input_x=sat.sat_x,
                input_y=sat.sat_y,
                input_z=sat.sat_z,
                input_vx=sat.sat_vx,
                input_vy=sat.sat_vy,
                input_vz=sat.sat_vz,
                input_epoch="J2000",
                input_date_str=obs_time.date().isoformat(),
                input_time_str=obs_time.time().isoformat()
            )

        # Placeholder values for every additional output column of a quasar.
        placeholder_columns = ("n",) + (0.,) * 12 + (-1, 0.)

        for row in range(n_agn):
            param = self.initialize_param()
            param['ra'] = ra_obs[row]
            param['dec'] = dec_obs[row]
            param['ra_orig'] = agn_table['ra'][row]
            param['dec_orig'] = agn_table['dec'][row]
            param['z'] = agn_table['z'][row]
            param['appMag'] = agn_table['appMag'][row]
            param['absMag'] = agn_table['absMag'][row]

            # Footprint cut: skip AGNs that do not land on this chip
            if not self.chip.isContainObj(ra_obj=param['ra'], dec_obj=param['dec'], margin=200):
                continue

            # TEST no reddening and no extinction
            param['av'] = 0.0
            param['redden'] = 0

            param['star'] = 2   # Quasar
            # load_sed uses this id as the row index into the AGN SED library.
            param['id'] = agn_table['igmlos'][row]

            quasar = Quasar(param, logger=self.logger)

            # Append additional output columns to the .cat file
            quasar.additional_output_str = self.add_fmt % placeholder_columns
            self.objs.append(quasar)

    ###mock_stamp_START
    def _load_stamps(self, stamps, pix_id=None):
        """Create Stamp objects from one HEALPix pixel group of the stamp catalog.

        For every entry, the FITS file ``<cat_dir>/stampCats/<filename>`` is
        read: position, pixel scale and ids come from the primary header and
        the primary image, normalized to unit sum, becomes the stamp image.

        Parameters
        ----------
        stamps
            Mapping with the columns ``filename`` and ``lensGID`` for one pixel.
        pix_id
            HEALPix pixel index, used as the random seed.

        Raises
        ------
        ValueError
            If the ``lensGID`` in the catalog and in the FITS header disagree.
        """
        nstamps = len(stamps['filename'])
        # Use healpix index as the random seed. It is converted to a built-in
        # int because random.Random.seed rejects NumPy integers on Python >= 3.11.
        seed_value = None if pix_id is None else int(pix_id)
        self.rng_sedGal = random.Random()
        self.rng_sedGal.seed(seed_value)
        self.ud = galsim.UniformDeviate(seed_value)

        for istamp in range(nstamps):
            fitsfile = os.path.join(self.cat_dir, "stampCats/"+stamps['filename'][istamp].decode('utf-8'))

            param = self.initialize_param()
            # Close each FITS file as soon as it has been read; one handle per
            # stamp would otherwise stay open.
            with fitsio.open(fitsfile) as hdu:
                header = hdu[0].header
                param['id'] = header['index']
                param['star'] = 3      # Stamp type in .cat file
                param['lensGalaxyID'] = header['lensGID']
                param['ra'] = header['ra']
                param['dec'] = header['dec']
                param['pixScale'] = header['pixScale']
                param['mag_use_normal'] = 22  # fixed placeholder magnitude

                catalog_lens_id = stamps['lensGID'][istamp]
                if catalog_lens_id != param['lensGalaxyID']:
                    raise ValueError(
                        "lensGID mismatch for stamp %s: catalog has %s, FITS header has %s"
                        % (fitsfile, catalog_lens_id, param['lensGalaxyID'])
                    )

                # The division builds a new in-memory array, so nothing refers
                # to the file once it is closed.
                raw_image = hdu[0].data
                stamp_image = raw_image/(np.sum(raw_image))

            # Apply astrometric modeling
            # in C3 case only aberration
            param['ra_orig'] = param['ra']
            param['dec_orig'] = param['dec']
            if self.config["obs_setting"]["enable_astrometric_model"]:
                # UTC, as in the other loaders (local time was used here before).
                dt = datetime.utcfromtimestamp(self.pointing.timestamp)
                date_str = dt.date().isoformat()
                time_str = dt.time().isoformat()
                # NOTE(review): epoch is "J2015.5" here but "J2000" in the other loaders — confirm which is intended.
                ra_arr, dec_arr = on_orbit_obs_position(
                    input_ra_list=[param['ra']],
                    input_dec_list=[param['dec']],
                    input_pmra_list=np.zeros(1).tolist(),
                    input_pmdec_list=np.zeros(1).tolist(),
                    input_rv_list=np.zeros(1).tolist(),
                    input_parallax_list=[1e-9] * 1,
                    input_nstars=1,
                    input_x=self.pointing.sat_x,
                    input_y=self.pointing.sat_y,
                    input_z=self.pointing.sat_z,
                    input_vx=self.pointing.sat_vx,
                    input_vy=self.pointing.sat_vy,
                    input_vz=self.pointing.sat_vz,
                    input_epoch="J2015.5",
                    input_date_str=date_str,
                    input_time_str=time_str
                )
                param['ra'] = ra_arr[0]
                param['dec'] = dec_arr[0]

            # Assign each galaxy a template SED
            # NOTE(review): 'z' and 'bfrac' are not read from the header, so this
            # relies on whatever initialize_param() provides — confirm.
            param['sed_type'] = sed_assign(phz=param['z'], btt=param['bfrac'], rng=self.rng_sedGal)
            # TEST no reddening and no extinction (the template reddening in
            # self.tempRed_gal is deliberately not applied)
            param['av'] = 0.0
            param['redden'] = 0

            ###more keywords for stamp###
            param['image'] = stamp_image
            obj = Stamp(param)
            self.objs.append(obj)
    ###mock_stamp_END

    def _load(self, **kwargs):
        self.objs = []
        self.ids = 0
        
        if "star_cat" in self.config["catalog_options"]["input_path"] and self.config["catalog_options"]["input_path"]["star_cat"] and not self.config["catalog_options"]["galaxy_only"]:
            star_cat = h5.File(self.star_path, 'r')['catalog']
            for pix in self.pix_list:
                try:
                    stars = star_cat[str(pix)]
                    self._load_stars(stars, pix_id=pix)
                    del stars
                except Exception as e:
                    self.logger.error(str(e))
                    print(e)
        
        if "galaxy_cat" in self.config["catalog_options"]["input_path"] and self.config["catalog_options"]["input_path"]["galaxy_cat"] and not self.config["catalog_options"]["star_only"]:
            for pix in self.pix_list:
                try:
                    bundleID  = get_bundleIndex(pix)
                    file_path = os.path.join(self.galaxy_path, "galaxies_C6_bundle{:06}.h5".format(bundleID))
                    gals_cat = h5.File(file_path, 'r')['galaxies']
                    gals = gals_cat[str(pix)]
                    self._load_gals(gals, pix_id=pix, cat_id=bundleID)
                    del gals
                except Exception as e:
                    traceback.print_exc()
                    self.logger.error(str(e))
                    print(e)

        if "AGN_cat" in self.config["catalog_options"]["input_path"] and self.config["catalog_options"]["input_path"]["AGN_cat"] and not self.config["catalog_options"]["star_only"]:
            try:
                self._load_AGNs()
            except Exception as e:
                traceback.print_exc()
                self.logger.error(str(e))
                print(e)

        ###mock_stamp_START
        if "stamp_cat" in self.config["catalog_options"]["input_path"] and self.config["catalog_options"]["input_path"]["stamp_cat"] and self.config["catalog_options"]["stamp_yes"]:
            stamps_cat = h5.File(self.stamp_path, 'r')['Stamps']
            for pix in self.pix_list:
                try:
                    stamps = stamps_cat[str(pix)]
                    self._load_stamps(stamps, pix_id=pix)
                    del stamps
                except Exception as e:
                    self.logger.error(str(e))
                    print(e)
        ###mock_stamp_END

        if self.logger is not None:
            self.logger.info("maximum galaxy size: %.4f"%(self.max_size))
            self.logger.info("number of objects in catalog: %d"%(len(self.objs)))
        else:
            print("number of objects in catalog: ", len(self.objs))

    def load_sed(self, obj, **kwargs):
        if obj.type == 'star':
            _, wave, flux = tag_sed(
                h5file=self.tempSED_star,
                model_tag=obj.param['model_tag'],
                teff=obj.param['teff'],
                logg=obj.param['logg'],
                feh=obj.param['feh']
            )
        elif obj.type == 'galaxy' or obj.type == 'quasar':
            factor = 10**(-.4 * self.cosmo.distmod(obj.z).value)
            if obj.type == 'galaxy':
                flux = np.matmul(self.pcs, obj.coeff) * factor
                #  if np.any(flux < 0):
                #     raise ValueError("Glaxy %s: negative SED fluxes"%obj.id)
                flux[flux < 0] = 0.
                sedcat = np.vstack((self.lamb_gal, flux)).T
                sed_data = getObservedSED(
                    sedCat=sedcat,
                    redshift=obj.z,
                    av=obj.param["av"],
                    redden=obj.param["redden"]
                )
                wave, flux = sed_data[0], sed_data[1]
            elif obj.type == 'quasar':
                flux = self.SED_AGN[int(obj.id)] * 1e-17
                # if np.any(flux < 0):
                #     raise ValueError("Glaxy %s: negative SED fluxes"%obj.id)
                flux[flux < 0] = 0.
                # sedcat = np.vstack((self.lamb_AGN, flux)).T
                wave = self.lamb_AGN
            # print("sed (erg/s/cm2/A) = ", sed_data)
            # np.savetxt(os.path.join(self.config["work_dir"], "%s_sed.txt"%(obj.id)), sedcat)
        ###mock_stamp_START
        elif obj.type == 'stamp':
            sed_data = getObservedSED(
                sedCat=self.tempSed_gal[obj.sed_type],
                redshift=obj.z,
                av=obj.param["av"],
                redden=obj.param["redden"]
            )
            wave, flux = sed_data[0], sed_data[1]
        ###mock_stamp_END
        else:
            raise ValueError("Object type not known")
        speci = interpolate.interp1d(wave, flux)
        lamb = np.arange(2000, 11001+0.5, 0.5)
        y = speci(lamb)
        # erg/s/cm2/A --> photon/s/m2/A
        all_sed = y * lamb / (cons.h.value * cons.c.value) * 1e-13
        sed = Table(np.array([lamb, all_sed]).T, names=('WAVELENGTH', 'FLUX'))
        
        if obj.type == 'quasar':
            # integrate to get the magnitudes
            sed_photon = np.array([sed['WAVELENGTH'], sed['FLUX']]).T
            sed_photon = galsim.LookupTable(x=np.array(sed_photon[:, 0]), f=np.array(sed_photon[:, 1]), interpolant='nearest')
            sed_photon = galsim.SED(sed_photon, wave_type='A', flux_type='1', fast=False)
            interFlux = integrate_sed_bandpass(sed=sed_photon, bandpass=self.filt.bandpass_full)
            obj.param['mag_use_normal'] = getABMAG(interFlux, self.filt.bandpass_full)
            # if obj.param['mag_use_normal'] >= 30:
            #     print("obj ID = %d"%obj.id)
            #     print("mag_use_normal = %.3f"%obj.param['mag_use_normal'])
            #     print("integrated flux = %.7f"%(interFlux))
            #     print("app mag = %.3f"%obj.param['appMag'])
            #     np.savetxt('./AGN_SED_test/sed_objID_%d.txt'%obj.id, np.transpose([self.lamb_AGN, self.SED_AGN[int(obj.id)]]))
            # print("obj ID = %d"%obj.id)
            # print("mag_use_normal = %.3f"%obj.param['mag_use_normal'])
            # print("integrated flux = %.7f"%(interFlux))
            # print("app mag = %.3f"%obj.param['appMag'])
            # print("abs mag = %.3f"%obj.param['absMag'])
            # mag = getABMAG(interFlux, self.filt.bandpass_full)
            # print("mag diff = %.3f"%(mag - obj.param['mag_use_normal']))
        del wave
        del flux
        return sed