import os
import galsim
import random
import numpy as np
import h5py as h5
import healpy as hp
import astropy.constants as cons
import traceback
from astropy.coordinates import spherical_to_cartesian
from astropy.table import Table
from scipy import interpolate
from datetime import datetime

from observation_sim.mock_objects import CatalogBase, Star, Galaxy, Quasar, Stamp
from observation_sim.mock_objects._util import tag_sed, getObservedSED, getABMAG, integrate_sed_bandpass, comoving_dist
from observation_sim.astrometry.Astrometry_util import on_orbit_obs_position

import astropy.io.fits as fitsio
from observation_sim.mock_objects._util import seds, sed_assign, extAv

# (TEST)
from astropy.cosmology import FlatLambdaCDM
from astropy import constants
from astropy import units as U

try:
    import importlib.resources as pkg_resources
except ImportError:
    # Fall back to the 'importlib_resources' backport on Python < 3.7
    import importlib_resources as pkg_resources

NSIDE = 128


class Catalog(CatalogBase):
    """Catalog of postage-stamp objects indexed by HEALPix pixel.

    On construction the catalog determines which HEALPix pixels (at
    ``NSIDE``) overlap the enlarged sky coverage of ``chip`` and, when stamps
    are enabled in ``config``, reads the stamp entries of those pixels from
    an HDF5 file and turns every entry into a ``Stamp`` object appended to
    ``self.objs``.
    """

    def __init__(self, config, chip, pointing, chip_output, filt, **kwargs):
        """Set up the catalog and immediately load all objects for ``chip``.

        Parameters
        ----------
        config : mapping
            Simulation configuration. Keys read by this class:
            ``catalog_options.input_path.cat_dir``, the optional
            ``catalog_options.input_path.stamp_cat``,
            ``catalog_options.stamp_yes`` and
            ``obs_setting.enable_astrometric_model``.
        chip : object
            Must provide ``getSkyCoverageEnlarged`` and ``img.wcs``.
        pointing : object
            Must provide ``timestamp`` and the ``sat_x/y/z``, ``sat_vx/vy/vz``
            attributes (only read when the astrometric model is enabled).
        chip_output : object
            Must provide ``logger`` (may be ``None``) and
            ``update_output_header``.
        filt : object
            Stored as ``self.filt``; not used further in this class.
        **kwargs
            Accepted for interface compatibility; ignored.
        """
        super().__init__()
        self.cat_dir = config["catalog_options"]["input_path"]["cat_dir"]
        self.seed_Av = 121212  # config["catalog_options"]["seed_Av"]

        # (TEST)
        self.cosmo = FlatLambdaCDM(H0=67.66, Om0=0.3111)

        self.chip_output = chip_output
        self.filt = filt
        self.logger = chip_output.logger

        # Throughput tables shipped in the 'catalog.data' package.
        with pkg_resources.path('catalog.data', 'SLOAN_SDSS.g.fits') as filter_path:
            self.normF_star = Table.read(str(filter_path))
        with pkg_resources.path('catalog.data', 'lsst_throuput_g.fits') as filter_path:
            self.normF_galaxy = Table.read(str(filter_path))

        self.config = config
        self.chip = chip
        self.pointing = pointing

        self.max_size = 0.

        if self._stamps_enabled():
            stamp_file = config["catalog_options"]["input_path"]["stamp_cat"]
            self.stamp_path = os.path.join(self.cat_dir, stamp_file)
            # TODO: a dedicated stamp-SED library should be loaded here
            # (previously sketched as ``self.stamp_SED_path`` +
            # ``self._load_SED_lib_stamps()``).
            # NOTE(review): hard-coded absolute template directory, marked
            # "only for test" by the original author -- should come from
            # the configuration.
            self.tempSed_gal, self.tempRed_gal = seds(
                "galaxy.list", seddir="/public/home/chengliang/CSSOSDataProductsSims/testCats/Templates/Galaxy/")  # only for test

        self._add_output_columns_header()
        self._get_healpix_list()
        self._load()

    def _stamps_enabled(self):
        """Return True when the configuration requests a stamp catalog.

        True only if ``catalog_options.input_path`` contains a non-empty
        ``stamp_cat`` entry and ``catalog_options.stamp_yes`` is truthy.
        ``stamp_yes`` is only looked up when ``stamp_cat`` is present and
        non-empty.
        """
        catalog_options = self.config["catalog_options"]
        input_path = catalog_options["input_path"]
        return bool(
            "stamp_cat" in input_path
            and input_path["stamp_cat"]
            and catalog_options["stamp_yes"]
        )

    def _add_output_columns_header(self):
        """Define the extra output columns and register their names.

        ``self.add_hdr`` holds the column names and ``self.add_fmt`` the
        matching printf-style formats (15 columns each). Only the names are
        passed to ``chip_output.update_output_header`` here.
        """
        self.add_hdr = " model_tag teff logg feh"
        self.add_hdr += " bulgemass diskmass detA e1 e2 kappa g1 g2 size galType veldisp "

        self.add_fmt = " %10s %8.4f %8.4f %8.4f"
        self.add_fmt += " %8.4f %8.4f %8.4f %8.4f %8.4f %8.4f %8.4f %8.4f %8.4f %4d %8.4f "
        self.chip_output.update_output_header(
            additional_column_names=self.add_hdr)

    def _get_healpix_list(self):
        """Compute the HEALPix pixels overlapping the chip's sky coverage.

        The enlarged coverage box (``xmin/xmax`` used as RA, ``ymin/ymax``
        as Dec, both treated as degrees) is converted to a four-vertex
        polygon and queried with ``inclusive=True``. The result is stored
        in ``self.pix_list`` and reported via the logger, or printed when
        no logger is available.
        """
        self.sky_coverage = self.chip.getSkyCoverageEnlarged(
            self.chip.img.wcs, margin=0.2)
        ra_min, ra_max = self.sky_coverage.xmin, self.sky_coverage.xmax
        dec_min, dec_max = self.sky_coverage.ymin, self.sky_coverage.ymax
        # Corners in order: upper-left, upper-right, lower-right, lower-left.
        ra = np.deg2rad(np.array([ra_min, ra_max, ra_max, ra_min]))
        dec = np.deg2rad(np.array([dec_max, dec_max, dec_min, dec_min]))
        # healpy expects the colatitude (pi/2 - dec) as the first angle.
        self.pix_list = hp.query_polygon(
            NSIDE,
            hp.ang2vec(np.radians(90.) - dec, ra),
            inclusive=True
        )
        if self.logger is not None:
            msg = str(("HEALPix List: ", self.pix_list))
            self.logger.info(msg)
        else:
            print("HEALPix List: ", self.pix_list)

    def load_norm_filt(self, obj):
        """Return the normalisation throughput table for ``obj``.

        Stamp objects use ``self.normF_galaxy``; every other object type
        yields ``None``.
        """
        if obj.type != "stamp":
            return None
        return self.normF_galaxy  # normalize_filter for stamp

    def _observed_position(self, ra, dec, date_str, time_str):
        """Run the on-orbit astrometric model for a single position.

        Proper motion and radial velocity are passed as zero and the
        parallax as 1e-9; the satellite position and velocity come from
        ``self.pointing``.

        Returns
        -------
        tuple
            ``(ra, dec)``: the first elements of the arrays returned by
            ``on_orbit_obs_position``.
        """
        ra_arr, dec_arr = on_orbit_obs_position(
            input_ra_list=[ra],
            input_dec_list=[dec],
            input_pmra_list=[0.0],
            input_pmdec_list=[0.0],
            input_rv_list=[0.0],
            input_parallax_list=[1e-9],
            input_nstars=1,
            input_x=self.pointing.sat_x,
            input_y=self.pointing.sat_y,
            input_z=self.pointing.sat_z,
            input_vx=self.pointing.sat_vx,
            input_vy=self.pointing.sat_vy,
            input_vz=self.pointing.sat_vz,
            input_epoch="J2015.5",
            input_date_str=date_str,
            input_time_str=time_str
        )
        return ra_arr[0], dec_arr[0]

    def _load_stamps(self, stamps, pix_id=None):
        """Create a ``Stamp`` object for every entry of one HEALPix pixel.

        Parameters
        ----------
        stamps : mapping
            Must provide a ``'filename'`` sequence of byte strings naming
            FITS files relative to ``<cat_dir>/stampCats/``.
        pix_id : int, optional
            HEALPix index, used to seed the per-pixel random generators.
            With ``None`` both generators are seeded non-deterministically.

        Each FITS primary header must provide ``index``, ``ra``, ``dec``,
        ``pixScale`` and ``mag_g``; the primary image is normalised to unit
        sum. The created objects are appended to ``self.objs``.
        """
        print("debug:: load_stamps")
        filenames = stamps['filename']
        nstamps = len(filenames)

        # Use healpix index as the random seed. float(None) would raise, so
        # the default pix_id=None is mapped to an unseeded generator.
        self.rng_sedGal = random.Random()
        self.rng_sedGal.seed(None if pix_id is None else float(pix_id))
        self.ud = galsim.UniformDeviate(pix_id)

        # The observation epoch is identical for all stamps: compute it once.
        apply_astrometry = self.config["obs_setting"]["enable_astrometric_model"]
        if apply_astrometry:
            # NOTE(review): fromtimestamp() converts to *local* time --
            # confirm whether UTC is intended here.
            dt = datetime.fromtimestamp(self.pointing.timestamp)
            date_str = dt.date().isoformat()
            time_str = dt.time().isoformat()

        for istamp in range(nstamps):
            print("debug::", istamp)
            fitsfile = os.path.join(
                self.cat_dir, "stampCats/" + filenames[istamp].decode('utf-8'))
            print("debug::", istamp, fitsfile)

            param = self.initialize_param()
            # Read everything needed while the file is open, so the handle
            # is released as soon as the stamp has been ingested.
            with fitsio.open(fitsfile) as hdu:
                header = hdu[0].header
                param['id'] = header['index']  # istamp
                param['star'] = 3      # Stamp type in .cat file
                param['ra'] = header['ra']
                param['dec'] = header['dec']
                param['pixScale'] = header['pixScale']
                # The header keys 'srcGID', 'mu', 'PA', 'bfrac' and 'z' are
                # currently not read; 'z' and 'bfrac' used below are
                # presumably the defaults from initialize_param() -- TODO
                # confirm.
                param['mag_use_normal'] = header['mag_g']
                # Division yields a new in-memory array, independent of the
                # FITS file. A stamp summing to zero gives non-finite pixels.
                image = hdu[0].data
                param['image'] = image / np.sum(image)

            # Apply astrometric modeling
            # in C3 case only aberration
            param['ra_orig'] = param['ra']
            param['dec_orig'] = param['dec']
            if apply_astrometry:
                param['ra'], param['dec'] = self._observed_position(
                    param['ra'], param['dec'], date_str, time_str)

            # Assign each galaxy a template SED
            param['sed_type'] = sed_assign(
                phz=param['z'], btt=param['bfrac'], rng=self.rng_sedGal)
            # Extinction, reddening and magnification are fixed for stamps;
            # the template reddening table is deliberately not applied.
            param['av'] = 0.0
            param['redden'] = 0
            param['mu'] = 1

            self.objs.append(Stamp(param))

    def _load(self, **kwargs):
        """Populate ``self.objs`` with the stamps of all overlapping pixels.

        For every pixel in ``self.pix_list`` the group ``Stamps/<pix>`` of
        the HDF5 stamp catalog is passed to ``_load_stamps``. A failure for
        one pixel (for example a pixel missing from the catalog) is reported
        and the remaining pixels are still processed.
        """
        self.objs = []
        self.ids = 0

        if self._stamps_enabled():
            # The context manager closes the catalog once all pixels are read.
            with h5.File(self.stamp_path, 'r') as stamp_file:
                stamps_cat = stamp_file['Stamps']
                print("debug::", stamps_cat.keys())

                for pix in self.pix_list:
                    try:
                        stamps = stamps_cat[str(pix)]
                        print("debug::", stamps.keys())
                        self._load_stamps(stamps, pix_id=pix)
                    except Exception as e:
                        # Best effort: report and continue with the next pixel.
                        if self.logger is not None:
                            self.logger.error(str(e))
                        print(e)

        if self.logger is not None:
            self.logger.info("maximum galaxy size: %.4f" % (self.max_size))
            self.logger.info("number of objects in catalog: %d" %
                             (len(self.objs)))
        else:
            print("number of objects in catalog: ", len(self.objs))

    def load_sed(self, obj, **kwargs):
        """Build the observed SED table of a stamp object.

        The template selected by ``obj.sed_type`` is passed to
        ``getObservedSED`` with the object's redshift, ``av`` and
        ``redden``, then linearly interpolated onto a 0.5-step wavelength
        grid from 2000 to 11001 and converted from erg/s/cm2/A to
        photon/s/m2/A.

        Returns
        -------
        astropy.table.Table
            Columns ``WAVELENGTH`` and ``FLUX``.

        Raises
        ------
        ValueError
            If ``obj.type`` is not ``'stamp'``. ``interp1d`` also raises
            ``ValueError`` when the template does not cover the grid.
        """
        if obj.type != 'stamp':
            raise ValueError("Object type not known")

        sed_data = getObservedSED(
            sedCat=self.tempSed_gal[obj.sed_type],
            redshift=obj.z,
            av=obj.param["av"],
            redden=obj.param["redden"]
        )
        wave, flux = sed_data[0], sed_data[1]

        speci = interpolate.interp1d(wave, flux)
        lamb = np.arange(2000, 11001+0.5, 0.5)
        y = speci(lamb)
        # erg/s/cm2/A --> photon/s/m2/A
        # (E_photon = h*c/lambda; 1e-13 = 1e-7 J/erg * 1e-10 m/A * 1e4 cm2/m2)
        all_sed = y * lamb / (cons.h.value * cons.c.value) * 1e-13
        return Table(np.array([lamb, all_sed]).T, names=('WAVELENGTH', 'FLUX'))