import os
import galsim
import random
import numpy as np
import h5py as h5
import healpy as hp
import astropy.constants as cons
import traceback
from astropy.coordinates import spherical_to_cartesian
from astropy.table import Table
from scipy import interpolate
from datetime import datetime

from ObservationSim.MockObject import CatalogBase, Star, Galaxy, Quasar
from ObservationSim.MockObject._util import tag_sed, getObservedSED, getABMAG, integrate_sed_bandpass, comoving_dist
from ObservationSim.Astrometry.Astrometry_util import on_orbit_obs_position

# (TEST)
from astropy.cosmology import FlatLambdaCDM
from astropy import constants
from astropy import units as U
from astropy.coordinates import SkyCoord
from astropy.io import fits

try:
    import importlib.resources as pkg_resources
except ImportError:
    # Try backported to PY<37 'importlib_resources'
    import importlib_resources as pkg_resources

# HEALPix NSIDE used when looking up catalog pixels.
NSIDE = 128

# Galaxy catalog bundle files known to this module.
bundle_file_list = [
    'galaxies_C6_bundle%06d.h5' % bundle_id
    for bundle_id in (
        199, 200, 241, 242, 287, 288,
        714, 715, 778, 779, 842, 843,
        2046, 2110, 2111, 2173, 2174, 2238,
        2596, 2597, 2656, 2657, 2711, 2712,
        2844, 2845, 2884, 2885, 2921, 2922,
    )
]

# Quasar SED library files; entry i is paired with bundle_file_list[i]
# (see get_agnsed_file).
qsosed_file_list = [
    'quickspeclib_interp1d_run%d.fits' % run_number
    for run_number in range(1, 31)
]

# Star catalog files and, in the same order, the (RA, Dec) field centers
# in degrees that get_star_cat compares the pointing against.
star_file_list = [
    'C7_Gaia_Galaxia_RA170DECm23_healpix.hdf5',
    'C7_Gaia_Galaxia_RA180DECp60_healpix.hdf5',
    'C7_Gaia_Galaxia_RA240DECp30_healpix.hdf5',
    'C7_Gaia_Galaxia_RA300DECm60_healpix.hdf5',
    'C7_Gaia_Galaxia_RA30DECm48_healpix.hdf5',
]
star_center_list = [
    (170., -23.),
    (180., 60.),
    (240., 30.),
    (300., -60.),
    (30., -48.),
]

def get_bundleIndex(healpixID_ring, bundleOrder=4, healpixOrder=7):
    """Return the RING-ordered bundle pixel that contains a catalog pixel.

    Parameters
    ----------
    healpixID_ring : int or array of int
        RING-ordered HEALPix pixel ID at order ``healpixOrder``.
    bundleOrder : int
        HEALPix order of the bundle pixelisation.
    healpixOrder : int
        HEALPix order of the catalog pixelisation; ``2**healpixOrder`` must
        equal the module-level ``NSIDE``.

    Returns
    -------
    int or array of int
        RING-ordered pixel ID at order ``bundleOrder``.
    """
    assert NSIDE == 2**healpixOrder

    # The NESTED index is shifted right by two bits per order of difference.
    n_bits = 2 * (healpixOrder - bundleOrder)

    fine_nest = hp.ring2nest(2**healpixOrder, healpixID_ring)
    coarse_nest = fine_nest >> n_bits
    return hp.nest2ring(2**bundleOrder, coarse_nest)

def get_agnsed_file(bundle_file_name):
    """Return the quasar SED library file paired with a galaxy bundle file.

    The pairing is positional: the entry of ``qsosed_file_list`` at the index
    where ``bundle_file_name`` appears in ``bundle_file_list``.

    Raises
    ------
    ValueError
        If ``bundle_file_name`` is not in ``bundle_file_list``.
    """
    position = bundle_file_list.index(bundle_file_name)
    return qsosed_file_list[position]

def get_star_cat(ra_pointing, dec_pointing):
    """Pick the star catalog file whose field center is nearest the pointing.

    Parameters
    ----------
    ra_pointing, dec_pointing : float
        Pointing coordinates in degrees.

    Returns
    -------
    str or None
        Entry of ``star_file_list`` with the smallest angular separation to
        the pointing, or ``None`` when no field center is closer than
        10 degrees.
    """
    target = SkyCoord(ra=ra_pointing*U.deg, dec=dec_pointing*U.deg)

    nearest_file = None
    nearest_sep = 10  # degrees; only closer fields qualify
    for file_name, (ra_center, dec_center) in zip(star_file_list, star_center_list):
        field_center = SkyCoord(ra=ra_center*U.deg, dec=dec_center*U.deg)
        sep = target.separation(field_center).to(U.deg).value
        if sep < nearest_sep:
            nearest_file, nearest_sep = file_name, sep
    return nearest_file

class Catalog(CatalogBase):
    def __init__(self, config, chip, pointing, chip_output, filt, **kwargs):
        """Set up the catalog for one chip of one pointing and load its objects.

        Parameters
        ----------
        config : dict-like
            Simulation configuration. Reads ``data_dir`` and, under
            ``catalog_options``: ``input_path``, ``SED_templates_path``,
            ``seed_Av``, ``galaxy_only``, ``star_only`` and the optional
            ``rotateEll`` (ellipticity rotation in degrees).
        chip : object
            Chip being simulated; used for the sky coverage and on-chip cuts.
        pointing : object
            Pointing; its ``ra``/``dec`` select the star catalog file.
        chip_output : object
            Output handler; supplies ``logger`` and receives the extra
            catalog header columns.
        filt : object
            Filter of the chip.
        **kwargs
            Accepted for interface compatibility; not used here.

        Raises
        ------
        ValueError
            If stars are requested but ``get_star_cat`` finds no star catalog
            field near the pointing.
        """
        super().__init__()
        cat_opts = config["catalog_options"]
        input_paths = cat_opts["input_path"]

        self.cat_dir = os.path.join(config["data_dir"], input_paths["cat_dir"])
        self.seed_Av = cat_opts["seed_Av"]

        self.cosmo = FlatLambdaCDM(H0=67.66, Om0=0.3111)

        self.chip_output = chip_output
        self.filt = filt
        self.logger = chip_output.logger

        with pkg_resources.path('Catalog.data', 'SLOAN_SDSS.g.fits') as filter_path:
            self.normF_star = Table.read(str(filter_path))

        self.config = config
        self.chip = chip
        self.pointing = pointing

        self.max_size = 0.

        if "star_cat" in input_paths and input_paths["star_cat"] and not cat_opts["galaxy_only"]:
            # Get the closest star catalog file
            star_file_name = get_star_cat(ra_pointing=self.pointing.ra, dec_pointing=self.pointing.dec)
            if star_file_name is None:
                # Without this check the None would reach os.path.join and
                # fail with an unhelpful TypeError.
                raise ValueError(
                    "No star catalog field found near pointing (ra={}, dec={})".format(
                        self.pointing.ra, self.pointing.dec))
            star_path = os.path.join(input_paths["star_cat"], star_file_name)
            star_SED_file = cat_opts["SED_templates_path"]["star_SED"]
            self.star_path = os.path.join(self.cat_dir, star_path)
            self.star_SED_path = os.path.join(config["data_dir"], star_SED_file)
            self._load_SED_lib_star()

        if "galaxy_cat" in input_paths and input_paths["galaxy_cat"] and not cat_opts["star_only"]:
            galaxy_dir = input_paths["galaxy_cat"]
            self.galaxy_path = os.path.join(self.cat_dir, galaxy_dir)
            self.galaxy_SED_path = os.path.join(config["data_dir"], cat_opts["SED_templates_path"]["galaxy_SED"])
            self._load_SED_lib_gals()
            # Cache of quasar SED libraries, keyed by file name
            self.agn_seds = {}

        if "AGN_SED" in cat_opts["SED_templates_path"] and not cat_opts["star_only"]:
            self.AGN_SED_path = os.path.join(config["data_dir"], cat_opts["SED_templates_path"]["AGN_SED"])

        if "rotateEll" in cat_opts:
            self.rotation = np.radians(float(cat_opts["rotateEll"]))
        else:
            self.rotation = 0.

        # Update output .cat header with catalog specific output columns
        self._add_output_columns_header()

        self._get_healpix_list()
        self._load()
    
    def _add_output_columns_header(self):
        self.add_hdr = " model_tag teff logg feh"
        self.add_hdr += " bulgemass diskmass detA e1 e2 kappa g1 g2 size galType veldisp "

        self.add_fmt = " %10s %8.4f %8.4f %8.4f"
        self.add_fmt += " %8.4f %8.4f %8.4f %8.4f %8.4f %8.4f %8.4f %8.4f %8.4f %4d %8.4f "
        self.chip_output.update_ouptut_header(additional_column_names=self.add_hdr)

    def _get_healpix_list(self):
        """Find the HEALPix pixels that overlap the enlarged chip footprint.

        Stores the enlarged sky coverage in ``self.sky_coverage`` and the
        overlapping pixel IDs (at ``NSIDE``) in ``self.pix_list``, then
        reports the list through the logger, or ``print`` when no logger is set.
        """
        self.sky_coverage = self.chip.getSkyCoverageEnlarged(self.chip.img.wcs, margin=0.2)
        bounds = self.sky_coverage
        ra_min, ra_max = bounds.xmin, bounds.xmax
        dec_min, dec_max = bounds.ymin, bounds.ymax

        # Four corners of the RA/Dec bounding box, converted to radians
        corner_ra = np.deg2rad(np.array([ra_min, ra_max, ra_max, ra_min]))
        corner_dec = np.deg2rad(np.array([dec_max, dec_max, dec_min, dec_min]))

        # ang2vec is given (90 deg - dec) as its first angle
        colatitude = np.radians(90.) - corner_dec
        vertices = hp.ang2vec(colatitude, corner_ra)
        self.pix_list = hp.query_polygon(NSIDE, vertices, inclusive=True)

        if self.logger is None:
            print("HEALPix List: ", self.pix_list)
        else:
            self.logger.info(str(("HEALPix List: ", self.pix_list)))

    def load_norm_filt(self, obj):
        if obj.type == "star":
            return self.normF_star
        elif obj.type == "galaxy" or obj.type == "quasar":
            # return self.normF_galaxy
            return None
        else:
            return None

    def _load_SED_lib_star(self):
        """Open the stellar SED template library read-only.

        The open HDF5 file handle is kept in ``self.tempSED_star`` because
        ``load_sed`` passes it to ``tag_sed`` later.
        """
        sed_library = h5.File(self.star_SED_path, mode='r')
        self.tempSED_star = sed_library

    def _load_SED_lib_gals(self):
        """Read the galaxy SED basis into memory.

        Loads dataset ``pcs`` from ``pcs.h5`` into ``self.pcs`` and dataset
        ``lamb`` from ``lamb.h5`` into ``self.lamb_gal``; both files live in
        ``self.galaxy_SED_path``. ``[()]`` reads each dataset in full, so the
        files are closed again as soon as the data has been read.
        """
        with h5.File(os.path.join(self.galaxy_SED_path, "pcs.h5"), "r") as pcs_file:
            self.pcs = pcs_file['pcs'][()]
        with h5.File(os.path.join(self.galaxy_SED_path, "lamb.h5"), "r") as lamb_file:
            self.lamb_gal = lamb_file['lamb'][()]

    def _load_gals(self, gals, pix_id=None, cat_id=0, agnsed_file=""):
        """Create Galaxy/Quasar objects from one HEALPix pixel of a galaxy bundle.

        Accepted objects are appended to ``self.objs``. An entry is skipped
        when the filter reports it as too dim, when its total ellipticity
        (after rotation) exceeds 0.9, or when it is not on the chip
        (``isContainObj`` with ``margin=200``). ``self.max_size`` tracks the
        largest ``size`` seen among entries that pass the first two cuts.

        Parameters
        ----------
        gals : mapping of column name -> array-like
            Galaxy columns for one pixel (presumably an h5py group; each
            column must support ``[:]``).
        pix_id : int
            HEALPix pixel ID; forms the first six digits of the object ID.
        cat_id : int
            Bundle ID; forms the next six digits of the object ID.
        agnsed_file : str
            Quasar SED library file name stored on quasar entries.
        """
        ra_orig = gals['ra'][:]
        dec_orig = gals['dec'][:]
        ngals = len(ra_orig)
        if ngals == 0:
            return

        # Apply astrometric modeling
        ra_arr, dec_arr = ra_orig, dec_orig
        if self.config["obs_setting"]["enable_astrometric_model"]:
            dt = datetime.utcfromtimestamp(self.pointing.timestamp)
            # Galaxies get zero proper motion / radial velocity and a tiny parallax
            ra_arr, dec_arr = on_orbit_obs_position(
                input_ra_list=ra_orig.tolist(),
                input_dec_list=dec_orig.tolist(),
                input_pmra_list=np.zeros(ngals).tolist(),
                input_pmdec_list=np.zeros(ngals).tolist(),
                input_rv_list=np.zeros(ngals).tolist(),
                input_parallax_list=[1e-9] * ngals,
                input_nstars=ngals,
                input_x=self.pointing.sat_x,
                input_y=self.pointing.sat_y,
                input_z=self.pointing.sat_z,
                input_vx=self.pointing.sat_vx,
                input_vy=self.pointing.sat_vy,
                input_vz=self.pointing.sat_vz,
                input_epoch="J2000",
                input_date_str=dt.date().isoformat(),
                input_time_str=dt.time().isoformat()
            )

        # Read every per-object column once. Indexing the catalog columns
        # element by element inside the loop repeats the read per object.
        if self.filt.filter_type == 'NUV':
            mag_key = 'mag_csst_nuv'
        else:
            mag_key = 'mag_csst_%s'%(self.filt.filter_type)
        mag_col = gals[mag_key][:]
        redshift_col = gals['redshift'][:]
        shear_col = gals['shear'][:]
        kappa_col = gals['kappa'][:]
        ell_col = gals['ellipticity_true'][:]
        bulgemass_col = gals['bulgemass'][:]
        diskmass_col = gals['diskmass'][:]
        size_col = gals['size'][:]
        coeff_col = gals['coeff'][:]
        detA_col = gals['detA'][:]
        type_col = gals['type'][:]
        veldisp_col = gals['veldisp'][:]
        qsoindex_col = gals['qsoindex'][:]

        mag_margin = self.config["obs_setting"]["mag_lim_margin"]

        for igals in range(ngals):
            param = self.initialize_param()
            param['ra'] = ra_arr[igals]
            param['dec'] = dec_arr[igals]
            param['ra_orig'] = ra_orig[igals]
            param['dec_orig'] = dec_orig[igals]
            param['mag_use_normal'] = mag_col[igals]
            if self.filt.is_too_dim(mag=param['mag_use_normal'], margin=mag_margin):
                continue

            param['z'] = redshift_col[igals]
            param['model_tag'] = 'None'
            param['g1'] = shear_col[igals][0]
            param['g2'] = shear_col[igals][1]
            param['kappa'] = kappa_col[igals]

            # For shape calculation
            param['e1'], param['e2'], param['ell_total'] = self.rotate_ellipticity(
                e1=ell_col[igals][0],
                e2=ell_col[igals][1],
                rotation=self.rotation,
                unit='radians')
            if param['ell_total'] > 0.9:
                continue

            # Bulge and disk share the rotated ellipticity
            param['e1_disk'] = param['e1']
            param['e2_disk'] = param['e2']
            param['e1_bulge'] = param['e1']
            param['e2_bulge'] = param['e2']

            param['delta_ra'] = 0
            param['delta_dec'] = 0

            # Masses
            param['bulgemass'] = bulgemass_col[igals]
            param['diskmass'] = diskmass_col[igals]

            param['size'] = size_col[igals]
            if param['size'] > self.max_size:
                self.max_size = param['size']

            # Sersic index
            param['disk_sersic_idx'] = 1.
            param['bulge_sersic_idx'] = 4.

            # Sizes: the dominant component gets the catalog size, the other
            # one a scaled-down size.
            # NOTE(review): bfrac is NaN when both masses are zero; confirm
            # whether the catalog can contain such entries.
            param['bfrac'] = param['bulgemass']/(param['bulgemass'] + param['diskmass'])
            if param['bfrac'] >= 0.6:
                param['hlr_bulge'] = param['size']
                param['hlr_disk'] = param['size'] * (1. - param['bfrac'])
            else:
                param['hlr_disk'] = param['size']
                param['hlr_bulge'] = param['size'] * param['bfrac']

            # SED coefficients; copy the row so the object does not keep the
            # whole per-pixel coefficient array alive.
            param['coeff'] = coeff_col[igals].copy()
            param['detA'] = detA_col[igals]

            # Others
            param['galType'] = type_col[igals]
            param['veldisp'] = veldisp_col[igals]

            # TEST no reddening and no extinction
            param['av'] = 0.0
            param['redden'] = 0

            # Is this a quasar? (qsoindex == -1 marks a plain galaxy)
            param['qsoindex'] = qsoindex_col[igals]
            is_quasar = param['qsoindex'] != -1
            if is_quasar:
                param['star'] = 2   # Quasar
                param['agnsed_file'] = agnsed_file
            else:
                param['star'] = 0   # Galaxy
                param['agnsed_file'] = ""

            # NOTE: this cut cannot be put before the SED type has been assigned
            if not self.chip.isContainObj(ra_obj=param['ra'], dec_obj=param['dec'], margin=200):
                continue

            # TEMP
            self.ids += 1
            param['id'] = '%06d'%(int(pix_id)) + '%06d'%(cat_id) + '%08d'%(igals)

            if is_quasar:
                obj = Quasar(param, logger=self.logger)
            else:
                obj = Galaxy(param, logger=self.logger)

            # Need to deal with additional output columns
            obj.additional_output_str = self.add_fmt%("n", 0., 0., 0.,
                                                    param['bulgemass'], param['diskmass'], param['detA'],
                                                    param['e1'], param['e2'], param['kappa'], param['g1'], param['g2'], param['size'],
                                                    param['galType'], param['veldisp'])

            self.objs.append(obj)

    def _load_stars(self, stars, pix_id=None):
        """Create Star objects from one HEALPix pixel of the star catalog.

        Stars that are not on the chip (``isContainObj`` with ``margin=200``)
        are skipped; the rest are appended to ``self.objs``.

        Parameters
        ----------
        stars : mapping of column name -> array-like
            Star columns for one pixel (presumably an h5py group; each
            column must support ``[:]``).
        pix_id : int, optional
            HEALPix pixel ID; accepted for symmetry with ``_load_gals`` but
            not used here.
        """
        source_id = stars['sourceID'][:]
        nstars = len(source_id)
        ra_orig = stars["RA"][:]
        dec_orig = stars["Dec"][:]
        pmra_arr = stars['pmra'][:]
        pmdec_arr = stars['pmdec'][:]
        rv_arr = stars['RV'][:]
        parallax_arr = stars['parallax'][:]
        if nstars == 0:
            return

        # Apply astrometric modeling
        ra_arr, dec_arr = ra_orig, dec_orig
        if self.config["obs_setting"]["enable_astrometric_model"]:
            dt = datetime.utcfromtimestamp(self.pointing.timestamp)
            ra_arr, dec_arr = on_orbit_obs_position(
                input_ra_list=ra_orig.tolist(),
                input_dec_list=dec_orig.tolist(),
                input_pmra_list=pmra_arr.tolist(),
                input_pmdec_list=pmdec_arr.tolist(),
                input_rv_list=rv_arr.tolist(),
                input_parallax_list=parallax_arr.tolist(),
                input_nstars=nstars,
                input_x=self.pointing.sat_x,
                input_y=self.pointing.sat_y,
                input_z=self.pointing.sat_z,
                input_vx=self.pointing.sat_vx,
                input_vy=self.pointing.sat_vy,
                input_vz=self.pointing.sat_vz,
                input_epoch="J2000",
                input_date_str=dt.date().isoformat(),
                input_time_str=dt.time().isoformat()
            )

        # Read the remaining columns once. Indexing the catalog columns
        # element by element inside the loop repeats the read per star.
        mag_col = stars['app_sdss_g'][:]
        model_tag_col = stars['model_tag'][:]
        teff_col = stars['teff'][:]
        logg_col = stars['grav'][:]
        feh_col = stars['feh'][:]

        for istars in range(nstars):
            param = self.initialize_param()
            param['ra'] = ra_arr[istars]
            param['dec'] = dec_arr[istars]
            param['ra_orig'] = ra_orig[istars]
            param['dec_orig'] = dec_orig[istars]
            param['pmra'] = pmra_arr[istars]
            param['pmdec'] = pmdec_arr[istars]
            param['rv'] = rv_arr[istars]
            param['parallax'] = parallax_arr[istars]
            if not self.chip.isContainObj(ra_obj=param['ra'], dec_obj=param['dec'], margin=200):
                continue
            param['mag_use_normal'] = mag_col[istars]
            self.ids += 1
            param['id'] = source_id[istars]
            param['sed_type'] = source_id[istars]
            param['model_tag'] = model_tag_col[istars]
            param['teff'] = teff_col[istars]
            param['logg'] = logg_col[istars]
            param['feh'] = feh_col[istars]
            param['z'] = 0.0
            param['star'] = 1   # Star
            obj = Star(param, logger=self.logger)

            # Append additional output columns to the .cat file; the galaxy
            # columns are filled with placeholders (galType = -1).
            obj.additional_output_str = self.add_fmt%(param["model_tag"], param['teff'], param['logg'], param['feh'],
                                                    0., 0., 0., 0., 0., 0., 0., 0., 0., -1, 0.)

            self.objs.append(obj)

    def _load(self, **kwargs):
        """Load all stars and galaxies/quasars for the pixels in ``self.pix_list``.

        Resets ``self.objs`` and ``self.ids``, then fills ``self.objs`` from
        the star catalog and/or the galaxy bundles, depending on the
        ``catalog_options`` in the configuration. A failure while loading one
        pixel is reported and the remaining pixels are still processed.

        Parameters
        ----------
        **kwargs
            Accepted for interface compatibility; not used here.
        """
        self.objs = []
        self.ids = 0

        cat_opts = self.config["catalog_options"]
        input_paths = cat_opts["input_path"]

        if "star_cat" in input_paths and input_paths["star_cat"] and not cat_opts["galaxy_only"]:
            # Everything _load_stars stores is read out of the file, so the
            # file can be closed once all pixels are processed.
            with h5.File(self.star_path, 'r') as star_file:
                star_cat = star_file['catalog']
                for pix in self.pix_list:
                    try:
                        self._load_stars(star_cat[str(pix)], pix_id=pix)
                    except Exception as e:
                        # Best effort: report and go on with the next pixel
                        if self.logger is not None:
                            self.logger.error(str(e))
                        print(e)

        if "galaxy_cat" in input_paths and input_paths["galaxy_cat"] and not cat_opts["star_only"]:
            for pix in self.pix_list:
                try:
                    bundleID = get_bundleIndex(pix)
                    bundle_file = "galaxies_C6_bundle{:06}.h5".format(bundleID)
                    file_path = os.path.join(self.galaxy_path, bundle_file)
                    with h5.File(file_path, 'r') as gals_file:
                        gals = gals_file['galaxies'][str(pix)]

                        # Get corresponding AGN SED file; each library is
                        # opened only once and then reused for later pixels.
                        agnsed_file = get_agnsed_file(bundle_file)
                        if agnsed_file not in self.agn_seds:
                            agnsed_path = os.path.join(self.AGN_SED_path, agnsed_file)
                            self.agn_seds[agnsed_file] = fits.open(agnsed_path)[0].data

                        self._load_gals(gals, pix_id=pix, cat_id=bundleID, agnsed_file=agnsed_file)
                except Exception as e:
                    # Best effort: report and go on with the next pixel
                    traceback.print_exc()
                    if self.logger is not None:
                        self.logger.error(str(e))
                    print(e)

        if self.logger is not None:
            self.logger.info("maximum galaxy size: %.4f"%(self.max_size))
            self.logger.info("number of objects in catalog: %d"%(len(self.objs)))
        else:
            print("number of objects in catalog: ", len(self.objs))

    def load_sed(self, obj, **kwargs):
        """Build the SED table of ``obj`` on a fixed wavelength grid.

        Stars take their spectrum from the stellar template library via
        ``tag_sed``. Galaxies are reconstructed from ``self.pcs`` and the
        object's coefficients, scaled by the distance modulus and passed
        through ``getObservedSED``. Quasars are read from the cached AGN SED
        library on the galaxy wavelength grid. For quasars,
        ``obj.param['mag_use_normal']`` is also overwritten with the AB
        magnitude integrated over ``self.filt.bandpass_full``.

        Parameters
        ----------
        obj : object
            Catalog object with ``type`` equal to 'star', 'galaxy' or 'quasar'.
        **kwargs
            Accepted for interface compatibility; not used here.

        Returns
        -------
        astropy.table.Table
            Columns 'WAVELENGTH' and 'FLUX', sampled every 0.5 from 2000 to
            11001 (presumably Angstrom).

        Raises
        ------
        ValueError
            If ``obj.type`` is not one of the known types.
        """
        if obj.type == 'star':
            _, wave, flux = tag_sed(
                h5file=self.tempSED_star,
                model_tag=obj.param['model_tag'],
                teff=obj.param['teff'],
                logg=obj.param['logg'],
                feh=obj.param['feh']
            )
        elif obj.type in ('galaxy', 'quasar'):
            factor = 10**(-.4 * self.cosmo.distmod(obj.z).value)
            if obj.type == 'galaxy':
                rest_flux = np.matmul(self.pcs, obj.coeff) * factor
                # Negative reconstructed fluxes are clipped to zero
                rest_flux[rest_flux < 0] = 0.
                observed = getObservedSED(
                    sedCat=np.vstack((self.lamb_gal, rest_flux)).T,
                    redshift=obj.z,
                    av=obj.param["av"],
                    redden=obj.param["redden"]
                )
                wave, flux = observed[0], observed[1]
            else:
                library = self.agn_seds[obj.agnsed_file]
                flux = library[int(obj.qsoindex)] * 1e-17
                flux[flux < 0] = 0.
                wave = self.lamb_gal
        else:
            raise ValueError("Object type not known")

        resample = interpolate.interp1d(wave, flux)
        grid = np.arange(2000, 11001+0.5, 0.5)
        # erg/s/cm2/A --> photon/s/m2/A
        photon_flux = resample(grid) * grid / (cons.h.value * cons.c.value) * 1e-13
        sed = Table(np.array([grid, photon_flux]).T, names=('WAVELENGTH', 'FLUX'))

        if obj.type == 'quasar':
            # integrate to get the magnitudes
            columns = np.array([sed['WAVELENGTH'], sed['FLUX']]).T
            lookup = galsim.LookupTable(x=np.array(columns[:, 0]), f=np.array(columns[:, 1]), interpolant='nearest')
            photon_sed = galsim.SED(lookup, wave_type='A', flux_type='1', fast=False)
            band_flux = integrate_sed_bandpass(sed=photon_sed, bandpass=self.filt.bandpass_full)
            obj.param['mag_use_normal'] = getABMAG(band_flux, self.filt.bandpass_full)
        return sed