import os
import random
from datetime import datetime, timezone

import galsim
import numpy as np
import h5py as h5
import healpy as hp
import astropy.constants as cons
from astropy.coordinates import spherical_to_cartesian
from astropy.table import Table
from scipy import interpolate

from ObservationSim.MockObject import CatalogBase, Star, Galaxy, Quasar
from ObservationSim.MockObject._util import seds, sed_assign, extAv, tag_sed, getObservedSED
from ObservationSim.Astrometry.Astrometry_util import on_orbit_obs_position

try:
    import importlib.resources as pkg_resources
except ImportError:
    # Try backported to PY<37 'importlib_resources'
    import importlib_resources as pkg_resources

# HEALPix resolution passed to hp.query_polygon when selecting the sky pixels
# that overlap a chip. Input HDF5 catalogs are read per pixel via the group
# name str(pix), so this presumably must match the NSIDE the catalogs were
# partitioned with — TODO confirm.
NSIDE = 128

class Catalog(CatalogBase):
    """Per-chip mock catalog built from HEALPix-partitioned HDF5 input files.

    On construction it loads the normalization filters and SED template
    libraries, finds the HEALPix pixels (at ``NSIDE``) overlapping the chip's
    enlarged sky footprint, and reads the star and/or galaxy groups stored
    under those pixel indices. Sources accepted by ``chip.isContainObj``
    (``margin=200``) are appended to ``self.objs`` as ``Star``, ``Galaxy`` or
    ``Quasar`` objects. SEDs are resolved on demand by :meth:`load_sed`.
    """

    def __init__(self, config, chip, pointing, chip_output, **kwargs):
        """Load filters/SED libraries and populate ``self.objs`` for ``chip``.

        Parameters
        ----------
        config : dict
            Simulation configuration. ``config["catalog_options"]`` selects
            the input catalogs, SED template paths, ``seed_Av`` and the
            optional ``rotateEll``; ``config["data_dir"]`` is the data root.
        chip
            Detector whose WCS/sky coverage and ``isContainObj`` select sources.
        pointing
            Supplies ``timestamp`` and satellite state ``sat_x`` ... ``sat_vz``
            for the astrometric model.
        chip_output
            Output handler; its ``logger`` (may be ``None``) receives messages
            and its .cat header is extended with extra columns.
        """
        super().__init__()
        cat_opts = config["catalog_options"]
        self.cat_dir = os.path.join(config["data_dir"], cat_opts["input_path"]["cat_dir"])
        self.seed_Av = cat_opts["seed_Av"]

        self.chip_output = chip_output
        self.logger = chip_output.logger

        # Filters returned by load_norm_filt(): SDSS g for stars, LSST g for
        # galaxies and quasars.
        with pkg_resources.path('Catalog.data', 'SLOAN_SDSS.g.fits') as filter_path:
            self.normF_star = Table.read(str(filter_path))
        with pkg_resources.path('Catalog.data', 'lsst_throuput_g.fits') as filter_path:
            self.normF_galaxy = Table.read(str(filter_path))

        self.config = config
        self.chip = chip
        self.pointing = pointing

        if self._use_star_cat():
            self.star_path = os.path.join(self.cat_dir, cat_opts["input_path"]["star_cat"])
            self.star_SED_path = os.path.join(
                config["data_dir"], cat_opts["SED_templates_path"]["star_SED"])
            self._load_SED_lib_star()
        if self._use_galaxy_cat():
            self.galaxy_path = os.path.join(self.cat_dir, cat_opts["input_path"]["galaxy_cat"])
            self.galaxy_SED_path = os.path.join(
                config["data_dir"], cat_opts["SED_templates_path"]["galaxy_SED"])
            self._load_SED_lib_gals()
        if "rotateEll" in cat_opts:
            # Stored as the number of whole 45-unit steps (presumably degrees)
            # and handed to Galaxy — TODO confirm the unit Galaxy expects.
            self.rotation = float(int(cat_opts["rotateEll"] / 45.))
        else:
            self.rotation = 0.

        # Update output .cat header with catalog specific output columns
        self._add_output_columns_header()

        self._get_healpix_list()
        self._load()

    def _use_star_cat(self):
        """Return True if a star catalog is configured and ``galaxy_only`` is off."""
        opts = self.config["catalog_options"]
        paths = opts["input_path"]
        # Short-circuit order matters: "galaxy_only" is only read when a
        # star catalog is actually configured.
        return "star_cat" in paths and bool(paths["star_cat"]) and not opts["galaxy_only"]

    def _use_galaxy_cat(self):
        """Return True if a galaxy catalog is configured and ``star_only`` is off."""
        opts = self.config["catalog_options"]
        paths = opts["input_path"]
        return "galaxy_cat" in paths and bool(paths["galaxy_cat"]) and not opts["star_only"]

    def _add_output_columns_header(self):
        """Declare the extra per-object .cat columns and their printf format."""
        self.add_hdr = " model_tag teff logg feh"
        self.add_fmt = " %10s %8.4f %8.4f %8.4f"
        self.chip_output.update_ouptut_header(additional_column_names=self.add_hdr)

    def _get_healpix_list(self):
        """Set ``self.pix_list`` to the HEALPix pixels overlapping the chip.

        The chip's sky coverage (enlarged with ``margin=0.2``) is treated as
        an RA/Dec box in degrees. Its four corners form a polygon, queried
        inclusively at ``NSIDE``.
        """
        self.sky_coverage = self.chip.getSkyCoverageEnlarged(self.chip.img.wcs, margin=0.2)
        ra_min, ra_max, dec_min, dec_max = self.sky_coverage.xmin, self.sky_coverage.xmax, self.sky_coverage.ymin, self.sky_coverage.ymax
        ra = np.deg2rad(np.array([ra_min, ra_max, ra_max, ra_min]))
        dec = np.deg2rad(np.array([dec_max, dec_max, dec_min, dec_min]))
        vertices = spherical_to_cartesian(1., dec, ra)
        self.pix_list = hp.query_polygon(NSIDE, np.array(vertices).T, inclusive=True)
        if self.logger is not None:
            msg = str(("HEALPix List: ", self.pix_list))
            self.logger.info(msg)
        else:
            print("HEALPix List: ", self.pix_list)

    def load_norm_filt(self, obj):
        """Return the normalization filter table for ``obj.type``, or None if unknown."""
        if obj.type == "star":
            return self.normF_star
        elif obj.type == "galaxy" or obj.type == "quasar":
            return self.normF_galaxy
        else:
            return None

    def _load_SED_lib_star(self):
        # Kept open for the lifetime of the catalog: load_sed() reads star
        # templates from it on demand.
        self.tempSED_star = h5.File(self.star_SED_path, 'r')

    def _load_SED_lib_gals(self):
        self.tempSed_gal, self.tempRed_gal = seds("galaxy.list", seddir=self.galaxy_SED_path)

    def _report_error(self, err):
        """Log ``err`` when a logger is attached, and always echo it to stdout."""
        if self.logger is not None:
            self.logger.error(str(err))
        print(err)

    @staticmethod
    def _read_columns(group, names):
        """Read each named 1-D dataset of ``group`` into memory with a single call.

        Bulk ``[:]`` reads avoid one HDF5 round-trip per element inside the
        per-object loops.
        """
        return {name: group[name][:] for name in names}

    @staticmethod
    def _ellipticity(major, minor):
        """Return ``(major - minor) / (major + minor)``; 0 for a zero-size component."""
        total = major + minor
        if total <= 0:
            return 0.0
        return (major - minor) / total

    def _obs_date_time_strings(self):
        """Return the pointing timestamp as UTC ISO-format (date, time) strings."""
        dt = datetime.fromtimestamp(self.pointing.timestamp, tz=timezone.utc)
        # date()/time() drop tzinfo, so the strings match the old
        # utcfromtimestamp() output exactly.
        return dt.date().isoformat(), dt.time().isoformat()

    def _apply_astrometry(self, ra, dec, pmra, pmdec, rv, parallax):
        """Return (ra, dec) after the on-orbit astrometric model, if it is enabled.

        When ``config["obs_setting"]["enable_astrometric_model"]`` is false,
        ``ra`` and ``dec`` are returned unchanged. Otherwise the NumPy arrays
        are converted to lists and passed to ``on_orbit_obs_position``, together
        with the satellite state from ``self.pointing`` and epoch J2015.5.
        """
        if not self.config["obs_setting"]["enable_astrometric_model"]:
            return ra, dec
        date_str, time_str = self._obs_date_time_strings()
        p = self.pointing
        return on_orbit_obs_position(
            input_ra_list=ra.tolist(),
            input_dec_list=dec.tolist(),
            input_pmra_list=pmra.tolist(),
            input_pmdec_list=pmdec.tolist(),
            input_rv_list=rv.tolist(),
            input_parallax_list=parallax.tolist(),
            input_nstars=len(ra),
            input_x=p.sat_x,
            input_y=p.sat_y,
            input_z=p.sat_z,
            input_vx=p.sat_vx,
            input_vy=p.sat_vy,
            input_vz=p.sat_vz,
            input_epoch="J2015.5",
            input_date_str=date_str,
            input_time_str=time_str
        )

    def _load_gals(self, gals, pix_id=None):
        """Append the galaxies/quasars of one HEALPix pixel group that fall on the chip.

        ``pix_id`` seeds both the SED-assignment RNG (``self.rng_sedGal``) and
        ``self.ud``, so SED assignment is reproducible per pixel. Objects with
        ``sed_type >= 29`` become ``Quasar``; all others become ``Galaxy``.
        """
        cols = self._read_columns(gals, (
            'galaxyID', 'ra_true', 'dec_true', 'mag_true_g_lsst', 'redshift_true',
            'size_bulge_true', 'size_minor_bulge_true',
            'size_disk_true', 'size_minor_disk_true',
            'size_true', 'size_minor_true',
            'bulge_to_total_ratio_i', 'position_angle_true',
        ))
        ngals = len(cols['galaxyID'])

        # Pixel indices from healpy are NumPy integers, which
        # random.Random.seed() rejects on Python >= 3.11. int() keeps the
        # exact random stream that older Pythons produced.
        seed = None if pix_id is None else int(pix_id)
        self.rng_sedGal = random.Random(seed)
        self.ud = galsim.UniformDeviate(seed)

        # Galaxies have no proper motion or radial velocity, and a negligible
        # parallax, so only aberration is applied (C3 case).
        zeros = np.zeros(ngals)
        ra_arr, dec_arr = self._apply_astrometry(
            cols['ra_true'], cols['dec_true'],
            pmra=zeros, pmdec=zeros, rv=zeros, parallax=np.full(ngals, 1e-9))

        for i in range(ngals):
            param = self.initialize_param()
            param['ra'] = ra_arr[i]
            param['dec'] = dec_arr[i]
            param['ra_orig'] = cols['ra_true'][i]
            param['dec_orig'] = cols['dec_true'][i]
            param['mag_use_normal'] = cols['mag_true_g_lsst'][i]
            param['z'] = cols['redshift_true'][i]
            param['model_tag'] = 'None'
            param['g1'] = 0
            param['g2'] = 0
            param['kappa'] = 0
            param['delta_ra'] = 0
            param['delta_dec'] = 0
            hlr_maj_b = cols['size_bulge_true'][i]
            hlr_min_b = cols['size_minor_bulge_true'][i]
            hlr_maj_d = cols['size_disk_true'][i]
            hlr_min_d = cols['size_minor_disk_true'][i]
            a_gal = cols['size_true'][i]
            b_gal = cols['size_minor_true'][i]
            param['bfrac'] = cols['bulge_to_total_ratio_i'][i]
            param['theta'] = cols['position_angle_true'][i]
            param['hlr_bulge'] = np.sqrt(hlr_maj_b * hlr_min_b)
            param['hlr_disk'] = np.sqrt(hlr_maj_d * hlr_min_d)
            param['ell_bulge'] = self._ellipticity(hlr_maj_b, hlr_min_b)
            param['ell_disk'] = self._ellipticity(hlr_maj_d, hlr_min_d)
            param['ell_tot'] = self._ellipticity(a_gal, b_gal)

            # Assign each galaxy a template SED
            param['sed_type'] = sed_assign(phz=param['z'], btt=param['bfrac'], rng=self.rng_sedGal)
            # param['redden'] = self.tempRed_gal[param['sed_type']]
            # param['av'] = self.avGal[int(self.ud()*self.nav)]

            # TEST: no reddening and no extinction for now.
            param['av'] = 0.0
            param['redden'] = 0

            if param['sed_type'] <= 5:
                param['av'] = 0.0
                param['redden'] = 0
            param['star'] = 0   # Galaxy
            if param['sed_type'] >= 29:
                param['av'] = 0.6 * param['av'] / 3.0  # for quasar, av=[0, 0.2], 3.0=av.max-av.min
                param['star'] = 2  # Quasar

            # NOTE: this cut cannot be put before the SED type has been
            # assigned; presumably so every galaxy consumes the same
            # rng_sedGal draws regardless of which chip is simulated.
            if not self.chip.isContainObj(ra_obj=param['ra'], dec_obj=param['dec'], margin=200):
                continue

            self.ids += 1
            param['id'] = cols['galaxyID'][i]

            if param['star'] == 2:
                obj = Quasar(param, logger=self.logger)
            else:
                obj = Galaxy(param, self.rotation, logger=self.logger)

            # Galaxies have no stellar-model values; fill the extra .cat
            # columns with placeholders.
            obj.additional_output_str = self.add_fmt % ("n", 0., 0., 0.)

            self.objs.append(obj)

    def _load_stars(self, stars, pix_id=None):
        """Append the stars of one HEALPix pixel group that fall on the chip.

        ``pix_id`` is accepted for interface symmetry with :meth:`_load_gals`
        and is not used.
        """
        cols = self._read_columns(stars, (
            'sourceID', 'RA', 'Dec', 'pmra', 'pmdec', 'RV', 'parallax',
            'app_sdss_g', 'model_tag', 'teff', 'grav', 'feh',
        ))
        ra_arr, dec_arr = self._apply_astrometry(
            cols['RA'], cols['Dec'],
            pmra=cols['pmra'], pmdec=cols['pmdec'], rv=cols['RV'], parallax=cols['parallax'])

        for i in range(len(cols['sourceID'])):
            param = self.initialize_param()
            param['ra'] = ra_arr[i]
            param['dec'] = dec_arr[i]
            param['ra_orig'] = cols['RA'][i]
            param['dec_orig'] = cols['Dec'][i]
            param['pmra'] = cols['pmra'][i]
            param['pmdec'] = cols['pmdec'][i]
            param['rv'] = cols['RV'][i]
            param['parallax'] = cols['parallax'][i]
            if not self.chip.isContainObj(ra_obj=param['ra'], dec_obj=param['dec'], margin=200):
                continue
            param['mag_use_normal'] = cols['app_sdss_g'][i]
            self.ids += 1
            param['id'] = cols['sourceID'][i]
            param['sed_type'] = cols['sourceID'][i]
            param['model_tag'] = cols['model_tag'][i]
            param['teff'] = cols['teff'][i]
            param['logg'] = cols['grav'][i]
            param['feh'] = cols['feh'][i]
            param['z'] = 0.0
            param['star'] = 1   # Star
            obj = Star(param, logger=self.logger)

            # Append additional output columns to the .cat file
            obj.additional_output_str = self.add_fmt % (param["model_tag"], param['teff'], param['logg'], param['feh'])

            self.objs.append(obj)

    def _load_pixels(self, cat, loader):
        """Call ``loader(cat[str(pix)], pix_id=pix)`` for every pixel in ``self.pix_list``.

        This is a best-effort loop: any error for a pixel (e.g. the pixel
        group is absent from the catalog) is reported and that pixel skipped.
        """
        for pix in self.pix_list:
            try:
                loader(cat[str(pix)], pix_id=pix)
            except Exception as e:
                self._report_error(e)

    def _load(self, **kwargs):
        """Populate ``self.objs`` from the configured catalogs for ``self.pix_list``.

        Each input HDF5 file is opened read-only and closed once all its
        pixels have been read; the objects hold copied values, not file
        handles.
        """
        self.nav = 15005
        self.avGal = extAv(self.nav, seed=self.seed_Av)
        self.objs = []
        self.ids = 0
        if self._use_star_cat():
            with h5.File(self.star_path, 'r') as star_file:
                self._load_pixels(star_file['catalog'], self._load_stars)
        if self._use_galaxy_cat():
            with h5.File(self.galaxy_path, 'r') as gal_file:
                self._load_pixels(gal_file['galaxies'], self._load_gals)
        if self.logger is not None:
            self.logger.info("number of objects in catalog: %d"%(len(self.objs)))
        else:
            print("number of objects in catalog: ", len(self.objs))
        del self.avGal

    def load_sed(self, obj, **kwargs):
        """Return ``obj``'s SED sampled on a 2000-18001 A grid with 0.5 A steps.

        Star templates come from the HDF5 library, looked up by
        model_tag/teff/logg/feh. Galaxy and quasar templates are
        ``tempSed_gal[obj.sed_type]``, passed through ``getObservedSED`` with
        ``redshift=obj.z`` and the object's ``av``/``redden`` values.

        Returns
        -------
        astropy.table.Table
            Columns ``WAVELENGTH`` (A) and ``FLUX``, converted from
            erg/s/cm^2/A to photons/s/m^2/A.

        Raises
        ------
        ValueError
            For an unknown ``obj.type``. ``interp1d`` also raises ValueError
            if the template does not span the output grid.
        """
        if obj.type == 'star':
            _, wave, flux = tag_sed(
                h5file=self.tempSED_star,
                model_tag=obj.param['model_tag'],
                teff=obj.param['teff'],
                logg=obj.param['logg'],
                feh=obj.param['feh']
            )
        elif obj.type == 'galaxy' or obj.type == 'quasar':
            sed_data = getObservedSED(
                sedCat=self.tempSed_gal[obj.sed_type],
                redshift=obj.z,
                av=obj.param["av"],
                redden=obj.param["redden"]
            )
            wave, flux = sed_data[0], sed_data[1]
        else:
            raise ValueError("Object type not known")
        speci = interpolate.interp1d(wave, flux)
        lamb = np.arange(2000, 18001 + 0.5, 0.5)
        y = speci(lamb)
        # erg/s/cm2/A --> photo/s/m2/A: multiply by lambda/(h c), with
        # 1e-13 = 1e-7 (erg->J) * 1e4 (cm2->m2) * 1e-10 (A->m).
        all_sed = y * lamb / (cons.h.value * cons.c.value) * 1e-13
        sed = Table(np.array([lamb, all_sed]).T, names=('WAVELENGTH', 'FLUX'))
        return sed