from ObservationSim.MockObject.SpecDisperser import SpecDisperser
from ObservationSim.MockObject.SpecDisperser import rotate90

import galsim
import numpy as np
from astropy.table import Table
from scipy import interpolate

import galsim

import os

try:
    import importlib.resources as pkg_resources
except ImportError:
    # Try backported to PY<37 'importlib_resources'
    import importlib_resources as pkg_resources


### Calculate the dispersed sky-background map from the sky SED

def _add_dispersed_orders(fImg, orig_img, origin, spec, band_start, band_end, conf):
    """Disperse ``orig_img`` with ``SpecDisperser`` and add every order into ``fImg`` in place.

    Each value returned by ``compute_spec_orders()`` is indexed here as
    ``(order_image, origin_x, origin_y, ...)``. The order image is placed at
    that origin, and only the part overlapping ``fImg`` is accumulated.
    """
    sdp = SpecDisperser(orig_img=orig_img, xcenter=orig_img.center.x, ycenter=orig_img.center.y, origin=origin,
                        tar_spec=spec,
                        band_start=band_start, band_end=band_end,
                        conf=conf)
    for order in sdp.compute_spec_orders().values():
        ssImg = galsim.ImageF(order[0])
        ssImg.setOrigin(order[1], order[2])
        bounds = ssImg.bounds & fImg.bounds
        # An order lying completely outside fImg gives an empty (undefined)
        # intersection; indexing an image with it would fail, so skip it.
        if bounds.area() == 0:
            continue
        fImg[bounds] = fImg[bounds] + ssImg[bounds]


def calculateSkyMap_split_g(skyMap=None, blueLimit=4200, redLimit=6500, skyfn='sky_emiss_hubble_50_50_A.dat', conf=('',), pixelSize=0.074, isAlongY=0,
                            split_pos=3685):
    """Disperse an input sky image with one or two slitless configurations.

    The image is optionally split at column ``split_pos``. The left part
    (columns ``< split_pos``) is dispersed with ``conf[0]``. The right part
    (columns ``>= split_pos``) is dispersed with ``conf[1]``, or with
    ``conf[0]`` if only one configuration is given. All spectral orders are
    summed into one image the same shape as ``skyMap``.

    Parameters
    ----------
    skyMap : 2-D array
        Image to disperse. It is required despite the ``None`` default.
    blueLimit, redLimit : float
        Band limits passed to ``SpecDisperser`` as ``band_start`` / ``band_end``.
        The units are presumably Angstrom, matching the sky SED file; confirm.
    skyfn : str
        Name of a whitespace-separated text file in
        ``ObservationSim.MockObject.data``. Column 0 is used as WAVELENGTH and
        column 1 as FLUX.
    conf : sequence of str
        One or two dispersion configuration names (see above). When no split
        happens, the whole image uses ``conf[1]`` if present, else ``conf[0]``.
    pixelSize : float
        The result is multiplied by ``pixelSize * pixelSize``.
    isAlongY : int
        0 or 1. When 1, the summed image is rotated with
        ``rotate90(..., isClockwise=0)`` before scaling.
    split_pos : int
        Column at which to split the image. No split is made when
        ``split_pos`` is not smaller than the image length along the axis
        selected by ``isAlongY``.

    Returns
    -------
    numpy.ndarray
        The dispersed sky map scaled by ``pixelSize ** 2``.

    Raises
    ------
    ValueError
        If ``skyMap`` is None or ``isAlongY`` is not 0 or 1.
    """
    if skyMap is None:
        raise ValueError("skyMap must be a 2-D array, got None")
    if isAlongY not in (0, 1):
        raise ValueError("isAlongY must be 0 or 1, got %r" % (isAlongY,))

    conf1 = conf[0]
    conf2 = conf[1] if np.size(conf) == 2 else conf[0]

    skyImg = galsim.Image(skyMap, xmin=0, ymin=0)
    fImg = galsim.Image(np.zeros_like(skyMap))

    with pkg_resources.path('ObservationSim.MockObject.data', skyfn) as data_path:
        skySpec = np.loadtxt(data_path)
    spec = Table(np.array([skySpec[:, 0], skySpec[:, 1]]).T, names=('WAVELENGTH', 'FLUX'))

    # NOTE(review): the split below is always along axis 1 (columns), but this
    # "no split" test compares against axis 0 when isAlongY == 0. That is kept
    # as-is; confirm it is intended.
    directParm = 1 if isAlongY == 1 else 0

    if split_pos >= skyImg.array.shape[directParm]:
        # NOTE(review): the unsplit image uses conf2, not conf1. This preserves
        # the existing behaviour; confirm it is intended.
        _add_dispersed_orders(fImg, galsim.Image(skyImg.array), [0, 0],
                              spec, blueLimit, redLimit, conf2)
    else:
        _add_dispersed_orders(fImg, galsim.Image(skyImg.array[:, :split_pos]), [0, 0],
                              spec, blueLimit, redLimit, conf1)
        # Slice to the end: a ":-1" stop would silently drop the last column.
        _add_dispersed_orders(fImg, galsim.Image(skyImg.array[:, split_pos:]), [0, split_pos],
                              spec, blueLimit, redLimit, conf2)

    if isAlongY == 1:
        fimg, _, _ = rotate90(array_orig=fImg.array, xc=0, yc=0, isClockwise=0)
    else:
        fimg = fImg.array

    return fimg * pixelSize * pixelSize
    
def calculateSkyMap(xLen=9232, yLen=9126, blueLimit=4200, redLimit=6500,
                    skyfn='sky_emiss_hubble_50_50_A.dat', conf='', pixelSize=0.074, isAlongY=0):
    """Disperse a uniform (all-ones) sky image with a single slitless configuration.

    Parameters
    ----------
    xLen, yLen : int
        Image size. The array is ``(yLen, xLen)``, or ``(xLen, yLen)`` when
        ``isAlongY == 1``.
    blueLimit, redLimit : float
        Band limits passed to ``SpecDisperser`` as ``band_start`` / ``band_end``.
    skyfn : str
        Name of a whitespace-separated text file in
        ``ObservationSim.MockObject.data``. Column 0 is used as WAVELENGTH and
        column 1 as FLUX.
    conf : str
        Dispersion configuration passed to ``SpecDisperser``.
    pixelSize : float
        The result is multiplied by ``pixelSize * pixelSize``.
    isAlongY : int
        When 1, the summed image is rotated with
        ``rotate90(..., isClockwise=0)`` before scaling.

    Returns
    -------
    numpy.ndarray
        The dispersed sky map scaled by ``pixelSize ** 2``.
    """
    shape = [xLen, yLen] if isAlongY == 1 else [yLen, xLen]
    skyMap = np.ones(shape, dtype='float32')

    skyImg = galsim.Image(skyMap)
    fImg = galsim.Image(np.zeros_like(skyMap))

    with pkg_resources.path('ObservationSim.MockObject.data', skyfn) as data_path:
        skySpec = np.loadtxt(data_path)
    spec = Table(np.array([skySpec[:, 0], skySpec[:, 1]]).T, names=('WAVELENGTH', 'FLUX'))

    sdp = SpecDisperser(orig_img=skyImg, xcenter=skyImg.center.x, ycenter=skyImg.center.y, origin=[1, 1],
                        tar_spec=spec,
                        band_start=blueLimit, band_end=redLimit,
                        conf=conf)

    # Each value is indexed as (order_image, origin_x, origin_y, ...).
    for order in sdp.compute_spec_orders().values():
        ssImg = galsim.ImageF(order[0])
        ssImg.setOrigin(order[1], order[2])
        bounds = ssImg.bounds & fImg.bounds
        # An order that misses fImg entirely gives an empty intersection;
        # indexing with it would fail, so skip it.
        if bounds.area() == 0:
            continue
        fImg[bounds] = fImg[bounds] + ssImg[bounds]

    if isAlongY == 1:
        fimg, _, _ = rotate90(array_orig=fImg.array, xc=0, yc=0, isClockwise=0)
    else:
        fimg = fImg.array

    return fimg * pixelSize * pixelSize