import os
import galsim
import random
import numpy as np
import h5py as h5
import healpy as hp
import astropy.constants as cons
import traceback
from astropy.coordinates import spherical_to_cartesian
from astropy.table import Table
from scipy import interpolate
from datetime import datetime

from ObservationSim.MockObject import CatalogBase, Star, Galaxy, Quasar, Stamp
from ObservationSim.MockObject._util import tag_sed, getObservedSED, getABMAG, integrate_sed_bandpass, comoving_dist
from ObservationSim.Astrometry.Astrometry_util import on_orbit_obs_position

import astropy.io.fits as fitsio
from ObservationSim.MockObject._util import seds, sed_assign, extAv

# (TEST)
from astropy.cosmology import FlatLambdaCDM
from astropy import constants
from astropy import units as U

try:
    import importlib.resources as pkg_resources
except ImportError:
    # Try backported to PY<37 'importlib_resources'
    import importlib_resources as pkg_resources

NSIDE = 128

class Catalog(CatalogBase):
    """Mock catalog of pre-rendered "stamp" objects for a single chip.

    The stamp catalog is an HDF5 file whose ``Stamps`` group has one entry
    per HEALPix pixel (resolution ``NSIDE``), keyed by the pixel index as a
    string. Each entry holds a ``filename`` dataset listing FITS stamp images
    relative to ``<cat_dir>/stampCats``. The primary header of every stamp
    provides ``index``, ``ra``, ``dec``, ``pixScale`` and ``mag_g``.
    """

    def __init__(self, config, chip, pointing, chip_output, filt, **kwargs):
        """Set up the catalog for one chip and load every stamp covering it.

        Parameters
        ----------
        config : dict
            Simulation configuration; ``data_dir``, ``catalog_options`` and
            ``obs_setting`` are read.
        chip
            Chip object; its WCS-based sky coverage selects HEALPix pixels.
        pointing
            Pointing object; timestamp and satellite position/velocity are
            used by the optional astrometric model.
        chip_output
            Output handler; provides ``logger`` and the output .cat header.
        filt
            Filter for this exposure (stored as ``self.filt``).
        """
        super().__init__()
        cat_opts = config["catalog_options"]
        self.cat_dir = os.path.join(config["data_dir"], cat_opts["input_path"]["cat_dir"])
        self.seed_Av = cat_opts["seed_Av"]

        # (TEST) cosmology instance; not used by the methods of this class.
        self.cosmo = FlatLambdaCDM(H0=67.66, Om0=0.3111)

        self.chip_output = chip_output
        self.filt = filt
        self.logger = chip_output.logger

        with pkg_resources.path('Catalog.data', 'SLOAN_SDSS.g.fits') as filter_path:
            self.normF_star = Table.read(str(filter_path))

        self.config = config
        self.chip = chip
        self.pointing = pointing

        # NOTE(review): never updated in this class, so the logged
        # "maximum galaxy size" is always 0.
        self.max_size = 0.

        if self._stamps_enabled(config):
            stamp_file = cat_opts["input_path"]["stamp_cat"]
            self.stamp_path = os.path.join(self.cat_dir, stamp_file)
            # TODO: load a dedicated stamp SED library instead of the generic
            # galaxy templates; this hard-coded path is only for testing.
            self.tempSed_gal, self.tempRed_gal = seds(
                "galaxy.list",
                seddir="/share/simudata/CSSOSDataProductsSims/data/Templates/Galaxy/",
            )

        if "rotateEll" in cat_opts:
            self.rotation = float(int(cat_opts["rotateEll"] / 45.))
        else:
            self.rotation = 0.

        # Update output .cat header with catalog specific output columns
        self._add_output_columns_header()

        self._get_healpix_list()
        self._load()

    @staticmethod
    def _stamps_enabled(config):
        """Return True when a stamp catalog is configured and ``stamp_yes`` is set."""
        cat_opts = config["catalog_options"]
        input_path = cat_opts["input_path"]
        return bool(
            "stamp_cat" in input_path
            and input_path["stamp_cat"]
            and cat_opts["stamp_yes"]
        )

    def _add_output_columns_header(self):
        """Register the catalog-specific extra columns with the output .cat header.

        ``add_hdr`` holds the column names and ``add_fmt`` the matching
        %-format specifiers (one per name, in the same order).
        """
        self.add_hdr = " model_tag teff logg feh"
        self.add_hdr += " bulgemass diskmass detA e1 e2 kappa g1 g2 size galType veldisp "

        self.add_fmt = " %10s %8.4f %8.4f %8.4f"
        self.add_fmt += " %8.4f %8.4f %8.4f %8.4f %8.4f %8.4f %8.4f %8.4f %8.4f %4d %8.4f "
        self.chip_output.update_output_header(additional_column_names=self.add_hdr)

    def _get_healpix_list(self):
        """Store in ``self.pix_list`` the HEALPix pixels overlapping the chip.

        The chip's enlarged sky coverage (``margin=0.2``) is turned into a
        RA/Dec box whose four corners form the query polygon.
        """
        self.sky_coverage = self.chip.getSkyCoverageEnlarged(self.chip.img.wcs, margin=0.2)
        ra_min, ra_max = self.sky_coverage.xmin, self.sky_coverage.xmax
        dec_min, dec_max = self.sky_coverage.ymin, self.sky_coverage.ymax
        ra = np.deg2rad(np.array([ra_min, ra_max, ra_max, ra_min]))
        dec = np.deg2rad(np.array([dec_max, dec_max, dec_min, dec_min]))
        # healpy expects colatitude theta = 90 deg - dec and longitude phi = ra.
        self.pix_list = hp.query_polygon(
            NSIDE,
            hp.ang2vec(np.radians(90.) - dec, ra),
            inclusive=True
        )
        if self.logger is not None:
            msg = str(("HEALPix List: ", self.pix_list))
            self.logger.info(msg)
        else:
            print("HEALPix List: ", self.pix_list)

    def load_norm_filt(self, obj):
        """Return the normalisation filter for ``obj``.

        No normalisation filter is used yet for any object type (stamps may
        eventually use a galaxy filter), so this always returns None.
        """
        return None

    def _apply_astrometric_model(self, param):
        """Move ``param['ra']``/``param['dec']`` to the on-orbit observed position.

        The input coordinates are always saved to ``ra_orig``/``dec_orig``.
        The position is updated only when
        ``config["obs_setting"]["enable_astrometric_model"]`` is true. Proper
        motion and radial velocity are zero and parallax is negligible, so
        (per the original note for the C3 case) only aberration applies.
        """
        param['ra_orig'] = param['ra']
        param['dec_orig'] = param['dec']
        if not self.config["obs_setting"]["enable_astrometric_model"]:
            return

        # NOTE(review): fromtimestamp() yields host-local time; confirm that
        # on_orbit_obs_position expects local time rather than UTC.
        obs_time = datetime.fromtimestamp(self.pointing.timestamp)
        ra_arr, dec_arr = on_orbit_obs_position(
            input_ra_list=[param['ra']],
            input_dec_list=[param['dec']],
            input_pmra_list=[0.0],
            input_pmdec_list=[0.0],
            input_rv_list=[0.0],
            input_parallax_list=[1e-9],
            input_nstars=1,
            input_x=self.pointing.sat_x,
            input_y=self.pointing.sat_y,
            input_z=self.pointing.sat_z,
            input_vx=self.pointing.sat_vx,
            input_vy=self.pointing.sat_vy,
            input_vz=self.pointing.sat_vz,
            input_epoch="J2015.5",
            input_date_str=obs_time.date().isoformat(),
            input_time_str=obs_time.time().isoformat(),
        )
        param['ra'] = ra_arr[0]
        param['dec'] = dec_arr[0]

    def _load_stamps(self, stamps, pix_id=None):
        """Append a ``Stamp`` to ``self.objs`` for every stamp file of one pixel.

        Parameters
        ----------
        stamps
            Catalog entry for one HEALPix pixel; ``stamps['filename']`` lists
            FITS file names (bytes or str) relative to ``<cat_dir>/stampCats``.
        pix_id : int, optional
            HEALPix index. It seeds the per-pixel random generators so that
            SED assignment is reproducible.
        """
        # healpy returns numpy integers, and random.Random.seed() rejects them
        # on Python >= 3.11 (TypeError). Convert to a plain int first.
        seed = None if pix_id is None else int(pix_id)
        self.rng_sedGal = random.Random()
        self.rng_sedGal.seed(seed)  # Use healpix index as the random seed
        self.ud = galsim.UniformDeviate(seed)

        stamp_dir = os.path.join(self.cat_dir, "stampCats")
        filenames = stamps['filename'][:]  # read the dataset once, not per stamp
        for raw_name in filenames:
            name = raw_name.decode('utf-8') if isinstance(raw_name, bytes) else str(raw_name)
            fitsfile = os.path.join(stamp_dir, name)

            param = self.initialize_param()
            with fitsio.open(fitsfile) as hdul:
                header = hdul[0].header
                param['id'] = header['index']
                param['star'] = 3  # Stamp type in .cat file
                param['ra'] = header['ra']
                param['dec'] = header['dec']
                param['pixScale'] = header['pixScale']
                param['mag_use_normal'] = header['mag_g']
                # Normalise to unit total flux. The division allocates a new
                # in-memory array, so it stays valid after the file closes.
                image = hdul[0].data / np.sum(hdul[0].data)

            self._apply_astrometric_model(param)

            # Assign each stamp a galaxy template SED.
            # NOTE(review): z and bfrac are not read from the stamp header, so
            # they come from initialize_param()'s defaults; confirm intended.
            param['sed_type'] = sed_assign(phz=param['z'], btt=param['bfrac'], rng=self.rng_sedGal)
            # The template lookup is kept (an invalid sed_type still raises),
            # but reddening is currently disabled: av and redden are forced to 0.
            param['redden'] = self.tempRed_gal[param['sed_type']]
            param['av'] = 0.0
            param['redden'] = 0

            param['image'] = image
            self.objs.append(Stamp(param))

    def _load(self, **kwargs):
        """Populate ``self.objs`` with the stamps of every pixel in ``self.pix_list``.

        A pixel that fails to load (e.g. it has no entry in the catalog, or a
        stamp file is unreadable) is logged and skipped. Stamps already
        appended from that pixel are kept.
        """
        self.objs = []
        self.ids = 0

        if self._stamps_enabled(self.config):
            with h5.File(self.stamp_path, 'r') as stamp_file:
                stamps_cat = stamp_file['Stamps']
                for pix in self.pix_list:
                    try:
                        self._load_stamps(stamps_cat[str(pix)], pix_id=pix)
                    except Exception as e:
                        # Best effort: one bad pixel must not abort the catalog.
                        if self.logger is not None:
                            self.logger.error(str(e))
                        print(e)

        if self.logger is not None:
            self.logger.info("maximum galaxy size: %.4f"%(self.max_size))
            self.logger.info("number of objects in catalog: %d"%(len(self.objs)))
        else:
            print("number of objects in catalog: ", len(self.objs))

    def load_sed(self, obj, **kwargs):
        """Return the observed SED of a stamp object as a photon-flux table.

        The galaxy template ``obj.sed_type`` is redshifted to ``obj.z`` and
        reddened via ``getObservedSED``. It is then resampled onto a 2000 to
        11001 Angstrom grid with 0.5 Angstrom steps.

        Returns
        -------
        astropy.table.Table
            Columns ``WAVELENGTH`` (Angstrom) and ``FLUX`` (photon/s/m2/A).

        Raises
        ------
        ValueError
            If ``obj.type`` is not ``'stamp'``.
        """
        if obj.type != 'stamp':
            raise ValueError("Object type not known")
        sed_data = getObservedSED(
            sedCat=self.tempSed_gal[obj.sed_type],
            redshift=obj.z,
            av=obj.param["av"],
            redden=obj.param["redden"]
        )
        wave, flux = sed_data[0], sed_data[1]
        speci = interpolate.interp1d(wave, flux)
        lamb = np.arange(2000, 11001 + 0.5, 0.5)
        y = speci(lamb)
        # erg/s/cm2/A --> photon/s/m2/A: divide by photon energy h*c/lambda.
        # 1e-13 = 1e-7 (erg->J) * 1e4 (cm2->m2) * 1e-10 (Angstrom->m).
        all_sed = y * lamb / (cons.h.value * cons.c.value) * 1e-13
        return Table(np.array([lamb, all_sed]).T, names=('WAVELENGTH', 'FLUX'))