import galsim
import os
import sys
import numpy as np
import time
import math
import astropy.constants as cons
from astropy.io import fits
from scipy.interpolate import griddata
from astropy.table import Table
from observation_sim.mock_objects.SpecDisperser import SpecDisperser
from scipy import interpolate
import gc

from observation_sim.mock_objects.MockObject import MockObject

try:
    import importlib.resources as pkg_resources
except ImportError:
    # Try backported to PY<37 'importlib_resources'
    import importlib_resources as pkg_resources


# Historical local location of the LED model files; FlatLED now takes the
# directory from its `flatDir` argument or the packaged data directory.
# flatDir = '/Volumes/EAGET/LED_FLAT/'
# Names of the 14 calibration LEDs.
LED_name = ['LED1', 'LED2', 'LED3', 'LED4', 'LED5', 'LED6', 'LED7', 'LED8', 'LED9', 'LED10', 'LED11', 'LED12', 'LED13',
            'LED14']
# LED -> wavelength tag in nm; the LED model image is read from
# '<flatDir>/model_<tag>nm.fits'.
# NOTE(review): several tags disagree with `cwaves` below (LED5 '545' vs 5250 A,
# LED9 '940' vs 8800 A, LED12 '1050' vs 15500 A) -- presumably some LEDs share a
# model file; confirm.
cwaves_name = {'LED1': '275', 'LED2': '310', 'LED3': '430', 'LED4': '505', 'LED5': '545', 'LED6': '590', 'LED7': '670',
               'LED8': '760', 'LED9': '940', 'LED10': '940', 'LED11': '1050', 'LED12': '1050',
               'LED13': '340', 'LED14': '365'}

# LED central wavelength in Angstrom (centre of the Gaussian LED spectrum).
cwaves = {'LED1': 2750, 'LED2': 3100, 'LED3': 4300, 'LED4': 5050, 'LED5': 5250, 'LED6': 5900, 'LED7': 6700,
          'LED8': 7600, 'LED9': 8800, 'LED10': 9400, 'LED11': 10500, 'LED12': 15500, 'LED13': 3400, 'LED14': 3650}
# LED spectral FWHM in Angstrom.
cwaves_fwhm = {'LED1': 110, 'LED2': 120, 'LED3': 200, 'LED4': 300, 'LED5': 300, 'LED6': 130, 'LED7': 210,
               'LED8': 260, 'LED9': 400, 'LED10': 370, 'LED11': 500, 'LED12': 1400, 'LED13': 90, 'LED14': 100}
# LED_QE = {'LED1': 0.3, 'LED2': 0.4, 'LED13': 0.5, 'LED14': 0.5, 'LED10': 0.4}
# Previous (unused) flux table, e-/ms:
# fluxLED = {'LED1': 0.16478729, 'LED2': 0.084220931, 'LED3': 2.263360617, 'LED4': 2.190623489, 'LED5': 0.703504768,
#            'LED6': 0.446117963, 'LED7': 0.647122098, 'LED8': 0.922313442,
#            'LED9': 0.987278143, 'LED10': 2.043989167, 'LED11': 0.612571429, 'LED12': 1.228915663, 'LED13': 0.17029384,
#            'LED14': 0.27842925}

# LED flux scale in e-/ms; FlatLED multiplies the unit-mean flat by
# fluxLED[led] * 1000 to obtain e-/s.
fluxLED = {'LED1': 15, 'LED2': 15, 'LED3': 12.5, 'LED4': 9, 'LED5': 9,
           'LED6': 9, 'LED7': 9, 'LED8': 9, 'LED9': 9, 'LED10': 12.5, 'LED11': 15, 'LED12': 15, 'LED13': 12.5,
           'LED14': 12.5}
# fluxLEDL = {'LED1': 10, 'LED2': 10, 'LED3': 10, 'LED4': 10, 'LED5': 10,
#            'LED6': 10, 'LED7': 10, 'LED8': 10, 'LED9': 10, 'LED10': 10, 'LED11': 10, 'LED12':10, 'LED13': 10,
#            'LED14': 10}

# Per slitless band factor; the slitless LED flat is divided by it.
mirro_eff = {'GU': 0.61, 'GV': 0.8, 'GI': 0.8}

# Chip filter type -> LEDs whose flats are averaged by FlatLED.getInnerFlat.
bandtoLed = {'NUV': ['LED1', 'LED2'], 'u': ['LED13', 'LED14'], 'g': ['LED3', 'LED4', 'LED5'], 'r': ['LED6', 'LED7'], 'i': ['LED8'], 'z': ['LED9', 'LED10'], 'y': ['LED10'], 'GU': ['LED1', 'LED2', 'LED13', 'LED14'], 'GV': ['LED3', 'LED4', 'LED5', 'LED6'], 'GI': ['LED7', 'LED8', 'LED9', 'LED10']}
# mirro_eff = {'GU':1, 'GV':1, 'GI':1}


class FlatLED(MockObject):
    """Mock LED flat-field source for a single detector chip.

    Photometric chips get a flat scaled by the LED spectrum integrated
    through the CCD and filter throughput. Slitless-spectroscopic chips get
    the flat dispersed through the grating configuration with
    ``SpecDisperser``.

    Parameters
    ----------
    chip : object
        Chip description. The methods read ``npix_x``, ``npix_y``, ``rowID``,
        ``colID``, ``filter_type``, ``survey_type`` and ``_getChipEffCurve``.
        Slitless chips additionally need ``sls_conf``, ``pix_scale`` and
        ``flat_cube``.
    filt : object
        Filter description. The methods read ``blue_limit``, ``red_limit``,
        ``filter_bandpass`` and ``filter_type``.
    flatDir : str, optional
        Directory holding the ``model_<nm>nm.fits`` LED model images.
        Defaults to the packaged ``observation_sim.mock_objects.data.led``
        directory.
    logger : optional
        Receives debug timing messages from ``calculateLEDSpec``. May be
        None. It is assumed to be ``logging.Logger``-compatible (TODO
        confirm).
    """

    # NumPy 2.0 renamed ``trapz`` to ``trapezoid`` and deprecated the old
    # name. Use whichever this NumPy provides.
    _trapz = staticmethod(getattr(np, "trapezoid", None) or np.trapz)

    # Segments per axis used by getLEDImage1. Interpolating per segment
    # bounds the memory used for query coordinates.
    _X_SEG_LEN = 4
    _Y_SEG_LEN = 8

    def __init__(self, chip, filt, flatDir=None, logger=None):
        self.filt = filt
        self.chip = chip
        self.logger = logger
        if flatDir is not None:
            self.flatDir = flatDir
        else:
            self.flatDir = self._default_flat_dir()

    @staticmethod
    def _default_flat_dir():
        """Return the packaged LED model directory as a POSIX path string."""
        package = 'observation_sim.mock_objects.data.led'
        try:
            # ``files`` is missing on old importlib.resources (AttributeError).
            # The result is deliberately not used as a context manager:
            # pathlib.Path no longer supports ``with`` in Python 3.13.
            return pkg_resources.files(package).joinpath("").as_posix()
        except AttributeError:
            with pkg_resources.path(package, "") as led_dir:
                return led_dir.as_posix()

    def getInnerFlat(self):
        """Return the unit-mean flat averaged over the LEDs of the chip's band.

        The LEDs come from ``bandtoLed[self.chip.filter_type]``. Each one
        contributes its un-scaled (unit-mean) flat from ``getLEDImage1``.
        """
        ledflats = bandtoLed[self.chip.filter_type]
        iFlat = np.zeros([self.chip.npix_y, self.chip.npix_x])
        for nled in ledflats:
            iFlat = iFlat + self.getLEDImage1(led_type=nled, LED_Img_flag=False)
        return iFlat / len(ledflats)

    def _read_led_model(self, led_type):
        """Read the LED model image for ``led_type``; return (data, NAXIS1)."""
        path = os.path.join(self.flatDir, 'model_' + cwaves_name[led_type] + 'nm.fits')
        # Close the FITS handle. Copy the data so it stays valid after the
        # file is closed (astropy may memory-map it).
        with fits.open(path) as flat:
            xlen = flat[0].header['NAXIS1']
            data = np.array(flat[0].data)
        return data, xlen

    def getLEDImage(self, led_type='LED1', LED_Img_flag=True):
        """Return the LED flat for this chip, interpolated from the mosaic model.

        The first 601 rows of the model image are stretched over a 6 x 5 chip
        mosaic (``6 * npix_x`` by ``5 * npix_y`` pixels). The chip at
        (rowID, colID) is then linearly interpolated at each of its pixels
        and normalised to unit mean. When ``LED_Img_flag`` is True the result
        is scaled by ``fluxLED[led_type] * 1000``, i.e. e-/ms to e-/s.
        """
        model, xlen = self._read_led_model(led_type)
        ylen = 601
        x = np.linspace(0, self.chip.npix_x * 6, xlen)
        y = np.linspace(0, self.chip.npix_y * 5, ylen)
        xx, yy = np.meshgrid(x, y)
        a1 = model[0:ylen, 0:xlen]

        X_ = np.hstack((xx.flatten()[:, None], yy.flatten()[:, None]))
        Z_ = a1.flatten()

        i = self.chip.rowID - 1
        j = self.chip.colID - 1
        # Build query coordinates only for this chip's tile of the mosaic.
        # The whole-mosaic grid would be 30x the chip size.
        M, N = np.meshgrid(
            np.arange(self.chip.npix_x * j, self.chip.npix_x * (j + 1)),
            np.arange(self.chip.npix_y * i, self.chip.npix_y * (i + 1)))
        U = griddata(X_, Z_, (M, N), method='linear')
        U = U / np.mean(U)

        flatImage = U
        if LED_Img_flag:
            flatImage = flatImage * fluxLED[led_type] * 1000
        gc.collect()
        return flatImage

    def getLEDImage1(self, led_type='LED1', LED_Img_flag=True):
        """Return the LED flat for this chip from its own tile of the model.

        The model image (601 rows, ``NAXIS1`` columns) is cut into a 5 x 6
        grid of tiles. The tile at (rowID, colID) is stretched over the chip,
        linearly interpolated at every pixel and normalised to unit mean.
        When ``LED_Img_flag`` is True the result is scaled by
        ``fluxLED[led_type] * 1000``, i.e. e-/ms to e-/s. The output is
        float32.
        """
        model, xlen = self._read_led_model(led_type)
        ylen = 601
        npix_x = self.chip.npix_x
        npix_y = self.chip.npix_y
        i = self.chip.rowID - 1
        j = self.chip.colID - 1
        tile_ny = int(ylen / 5.)
        tile_nx = int(xlen / 6.)
        x = np.linspace(0, npix_x, tile_nx)
        y = np.linspace(0, npix_y, tile_ny)
        xx, yy = np.meshgrid(x, y)
        y0 = int(ylen * i / 5.)
        x0 = int(xlen * j / 6.)
        a1 = model[y0:y0 + tile_ny, x0:x0 + tile_nx]
        X_ = np.hstack((xx.flatten()[:, None], yy.flatten()[:, None]))
        Z_ = a1.flatten()
        # griddata(method='linear') builds exactly this interpolator. Building
        # it once avoids re-triangulating the nodes for every segment.
        interp = interpolate.LinearNDInterpolator(X_, Z_)

        x_seg = int(npix_x / self._X_SEG_LEN)
        y_seg = int(npix_y / self._Y_SEG_LEN)
        U = np.zeros([npix_y, npix_x], dtype=np.float32)
        for y_seg_i in range(self._Y_SEG_LEN):
            ys = y_seg_i * y_seg
            # The last segment runs to the chip edge. This also fills the
            # leftover pixels when npix is not a multiple of the segment count.
            ye = npix_y if y_seg_i == self._Y_SEG_LEN - 1 else ys + y_seg
            for x_seg_i in range(self._X_SEG_LEN):
                xs = x_seg_i * x_seg
                xe = npix_x if x_seg_i == self._X_SEG_LEN - 1 else xs + x_seg
                if ye <= ys or xe <= xs:
                    continue
                M, N = np.meshgrid(np.arange(xs, xe), np.arange(ys, ye))
                U[ys:ye, xs:xe] = interp(M, N)
        U = U / np.mean(U)
        flatImage = U
        if LED_Img_flag:
            flatImage = U * fluxLED[led_type] * 1000
        gc.collect()
        return flatImage

    @staticmethod
    def _default_led_args(led_type_list, exp_t_list):
        """Replace None with the default LED list ['LED1'] / times [0.1]."""
        if led_type_list is None:
            led_type_list = ['LED1']
        if exp_t_list is None:
            exp_t_list = [0.1]
        return led_type_list, exp_t_list

    @staticmethod
    def _empty_led_status():
        """Return the all-off status string and zero times for the 14 LEDs."""
        return '00000000000000', [0.0] * 14

    @staticmethod
    def _mark_led_on(led_stat, led_times, led_type, exp_t):
        """Mark ``led_type`` as on ('2') and record its exposure time in ms.

        ``led_times`` is updated in place. The new status string is returned.
        """
        idx = int(led_type[3:]) - 1  # 'LED7' -> index 6
        led_times[idx] = exp_t * 1000
        return led_stat[:idx] + '2' + led_stat[idx + 1:]

    def _unit_flat_result(self):
        """Return a unit image with all-off LED status (mismatched inputs).

        This uses the same (image, status, times) shape as the normal
        return, so callers can always unpack three values.
        """
        ledStat, ledTimes = self._empty_led_status()
        return np.ones([self.chip.npix_y, self.chip.npix_x]), ledStat, ledTimes

    def drawObj_LEDFlat_img(self, led_type_list=None, exp_t_list=None):
        """Return the photometric LED flat plus LED status and times.

        Each LED's unit flat is scaled by its Gaussian spectrum integrated
        over the filter band through the CCD and filter throughput, then by
        its exposure time. The results are summed.

        Returns
        -------
        (ledFlat, ledStat, ledTimes)
            ledStat is a 14-character string with '2' for LEDs that are on.
            ledTimes holds each LED's exposure time multiplied by 1000.
            If there are more LEDs than exposure times, a unit image with
            all-off status is returned.
        """
        led_type_list, exp_t_list = self._default_led_args(led_type_list, exp_t_list)
        if len(led_type_list) > len(exp_t_list):
            return self._unit_flat_result()

        ledFlat = np.zeros([self.chip.npix_y, self.chip.npix_x])
        ledStat, ledTimes = self._empty_led_status()

        # The wavelength grid and throughput curves are the same for every LED.
        w_list = np.arange(self.filt.blue_limit, self.filt.red_limit, 0.5)  # A
        # Bandpasses are evaluated at w_list / 10 (presumably nm).
        ccd_eff = self.chip._getChipEffCurve(self.chip.filter_type)(w_list / 10.)
        fil_eff = self.filt.filter_bandpass(w_list / 10.)

        for led_type, exp_t in zip(led_type_list, exp_t_list):
            unitFlatImg = self.getLEDImage1(led_type=led_type)
            led_spec = self.gaussian1d_profile_led(cwaves[led_type], cwaves_fwhm[led_type])
            speci = interpolate.interp1d(led_spec['WAVELENGTH'], led_spec['FLUX'])
            t_spec = self._trapz(speci(w_list) * ccd_eff * fil_eff, w_list)
            ledFlat = ledFlat + unitFlatImg * t_spec * exp_t
            ledStat = self._mark_led_on(ledStat, ledTimes, led_type, exp_t)
            gc.collect()
        return ledFlat, ledStat, ledTimes

    def drawObj_LEDFlat_slitless(self, led_type_list=None, exp_t_list=None):
        """Return the slitless LED flat plus LED status and times.

        Each LED's flat (times exposure time, divided by
        ``mirro_eff[filt.filter_type]``) is dispersed with ``calculateLEDSpec``
        using the LED's Gaussian spectrum. The dispersed maps are summed.
        If the two lists differ in length, a unit image with all-off status
        is returned.
        """
        led_type_list, exp_t_list = self._default_led_args(led_type_list, exp_t_list)
        if len(led_type_list) != len(exp_t_list):
            return self._unit_flat_result()

        ledFlat = np.zeros([self.chip.npix_y, self.chip.npix_x])
        ledStat, ledTimes = self._empty_led_status()

        for led_type, exp_t in zip(led_type_list, exp_t_list):
            unitFlatImg = self.getLEDImage1(led_type=led_type)
            ledFlat_ = unitFlatImg * exp_t / mirro_eff[self.filt.filter_type]
            # NOTE(review): a former `ledFlat_.astype(np.float32)` here discarded
            # its result (a no-op) and was dropped. Confirm whether a float32
            # cast was intended; it would change the precision fed below.
            led_spec = self.gaussian1d_profile_led(cwaves[led_type], cwaves_fwhm[led_type])
            ledspec_map = self.calculateLEDSpec(
                skyMap=ledFlat_,
                blueLimit=self.filt.blue_limit,
                redLimit=self.filt.red_limit,
                conf=self.chip.sls_conf,
                pixelSize=self.chip.pix_scale,
                isAlongY=0,
                flat_cube=self.chip.flat_cube, led_spec=led_spec)
            ledFlat = ledFlat + ledspec_map
            ledStat = self._mark_led_on(ledStat, ledTimes, led_type, exp_t)
        return ledFlat, ledStat, ledTimes

    def drawObj_LEDFlat(self, led_type_list=None, exp_t_list=None):
        """Dispatch to the photometric or slitless LED flat by survey type.

        Returns None for any other ``chip.survey_type``.
        """
        if self.chip.survey_type == "photometric":
            return self.drawObj_LEDFlat_img(led_type_list=led_type_list, exp_t_list=exp_t_list)
        if self.chip.survey_type == "spectroscopic":
            return self.drawObj_LEDFlat_slitless(led_type_list=led_type_list, exp_t_list=exp_t_list)
        return None

    def gaussian1d_profile_led(self, xc=5050, fwhm=300):
        """Return a Gaussian LED spectrum as a Table (WAVELENGTH, FLUX), float32.

        The profile is sampled every 0.5 over ``xc +/- int(5 * sigma + 1)``,
        with ``sigma = fwhm / 2.355``. It is scaled so that its trapezoidal
        integral over ``(xc - fwhm, xc + fwhm)`` is 1. A zero-flux point is
        added at each end, at 2000 and 18000. If the samples already reach
        past those values, the point is placed just outside the samples
        instead, so the grid stays strictly increasing.
        """
        sigma = fwhm / 2.355
        x_radii = int(5 * sigma + 1)
        xlist = np.arange(xc - x_radii, xc + x_radii, 0.5)
        data = np.exp(-(xlist - xc) * (xlist - xc) / (2 * sigma * sigma)) / \
            (np.sqrt(2 * math.pi) * sigma)
        core = (xlist > xc - fwhm) & (xlist < xc + fwhm)
        scale = 1 / self._trapz(data[core], xlist[core])

        xlist_ = np.empty(len(xlist) + 2)
        xlist_[1:-1] = xlist
        xlist_[0] = min(2000., xlist[0] - 0.5)
        xlist_[-1] = max(18000., xlist[-1] + 0.5)
        data_ = np.zeros(len(xlist) + 2)
        data_[1:-1] = data * scale
        return Table(np.array([xlist_.astype(np.float32), data_.astype(np.float32)]).T, names=('WAVELENGTH', 'FLUX'))

    @staticmethod
    def _tile_ranges(start, stop, step):
        """Return (begin, end) pairs covering [start, stop) in chunks of ``step``.

        Only the last chunk can be short. An empty range gives [].
        """
        if step <= 0 or stop <= start:
            return []
        starts = np.arange(start, stop, step)
        return list(zip(starts, np.minimum(starts + step, stop)))

    def _disperse_tiles(self, skyImg, fImg, y_ranges, x_ranges, conf, spec,
                        band_start, band_end, flat_cube, report_time=False):
        """Disperse each (y, x) tile of ``skyImg`` and accumulate it into ``fImg``."""
        for sub_y_s, sub_y_e in y_ranges:
            sub_y_center = (sub_y_s + sub_y_e) / 2.
            for sub_x_s, sub_x_e in x_ranges:
                t_start = time.time()
                skyImg_sub = galsim.Image(skyImg.array[sub_y_s:sub_y_e, sub_x_s:sub_x_e])
                sdp = SpecDisperser(orig_img=skyImg_sub,
                                    xcenter=(sub_x_s + sub_x_e) / 2.,
                                    ycenter=sub_y_center,
                                    origin=[sub_y_s, sub_x_s],
                                    tar_spec=spec,
                                    band_start=band_start, band_end=band_end,
                                    conf=conf,
                                    flat_cube=flat_cube)
                for v in sdp.compute_spec_orders().values():
                    # v = (order image, origin x, origin y, ...)
                    ssImg = galsim.ImageF(v[0])
                    ssImg.setOrigin(v[1], v[2])
                    bounds = ssImg.bounds & fImg.bounds
                    if bounds.area() == 0:
                        continue
                    fImg[bounds] = fImg[bounds] + ssImg[bounds]
                if report_time and self.logger is not None:
                    self.logger.debug('time: %s ms', (time.time() - t_start) * 1000)

    def calculateLEDSpec(self, skyMap=None, blueLimit=4200, redLimit=6500,
                         conf=None, pixelSize=0.074, isAlongY=0,
                         split_pos=3685, flat_cube=None, led_spec=None):
        """Disperse a 2-D flat map through the slitless grating model.

        Parameters
        ----------
        skyMap : 2-D ndarray
            Input (undispersed) flat map.
        blueLimit, redLimit : float
            Band limits passed to ``SpecDisperser`` as band_start / band_end.
        conf : sequence, optional
            One or two grating configurations; defaults to ``['']``. With two,
            columns left of ``split_pos`` use ``conf[0]`` and the rest
            ``conf[1]``.
        pixelSize : float
            Currently unused.
        isAlongY : {0, 1}
            With 1 the result is rotated by ``SpecDisperser.rotate90``.
        split_pos : int
            Column at which the configuration changes.
        flat_cube, led_spec
            Passed through to ``SpecDisperser`` (flat_cube / tar_spec).

        Returns
        -------
        ndarray
            Dispersed map. It has the shape of ``skyMap`` unless rotated.

        Raises
        ------
        ValueError
            If ``isAlongY`` is not 0 or 1.
        """
        if conf is None:
            conf = ['']
        conf1 = conf[0]
        conf2 = conf[1] if np.size(conf) == 2 else conf[0]

        if isAlongY not in (0, 1):
            raise ValueError("isAlongY must be 0 or 1, got %r" % (isAlongY,))
        directParm = 1 if isAlongY == 1 else 0

        skyImg = galsim.Image(skyMap, xmin=0, ymin=0)
        fImg = galsim.Image(np.zeros_like(skyMap))
        y_len = skyMap.shape[0]
        x_len = skyMap.shape[1]

        # NOTE(review): the size test uses axis `directParm` (rows when
        # isAlongY == 0), while the split below is along columns. Confirm.
        if split_pos >= skyImg.array.shape[directParm]:
            # Small map: disperse 100 x 100 pixel tiles.
            # NOTE(review): the whole map uses conf2, the configuration used
            # right of split_pos in the split branch. Confirm conf1 is not meant.
            self._disperse_tiles(skyImg, fImg,
                                 self._tile_ranges(0, y_len, 100),
                                 self._tile_ranges(0, x_len, 100),
                                 conf2, led_spec, blueLimit, redLimit, flat_cube)
        else:
            y_ranges = self._tile_ranges(0, y_len, y_len)
            # Columns [0, split_pos) are dispersed with conf1 ...
            self._disperse_tiles(skyImg, fImg, y_ranges,
                                 self._tile_ranges(0, split_pos, split_pos),
                                 conf1, led_spec, blueLimit, redLimit, flat_cube,
                                 report_time=True)
            # ... and [split_pos, x_len) with conf2. This is empty, and
            # skipped, when the map is not wider than split_pos.
            self._disperse_tiles(skyImg, fImg, y_ranges,
                                 self._tile_ranges(split_pos, x_len, x_len - split_pos),
                                 conf2, led_spec, blueLimit, redLimit, flat_cube,
                                 report_time=True)

        if directParm == 1:
            # NOTE(review): only the output is rotated; the input is dispersed
            # unrotated. Confirm this is intended for isAlongY == 1.
            fimg, tmx, tmy = SpecDisperser.rotate90(
                array_orig=fImg.array, xc=0, yc=0, isClockwise=0)
        else:
            fimg = fImg.array
        return fimg