import { Dataset } from '../../../types/dataset'
import mapboxgl, { CustomLayerInterface } from 'mapbox-gl'
import { boundsStringToArray } from '../../../utils'
import { getLatLngForTile } from '../TileLoadingLayer'
import { sourceId, dimensionsToString } from '../utils'
import { Layer, MapSourceDataEvent } from 'react-map-gl'
import vertexShaderSource from './RasterFillLayer.vert.glsl?raw'
import fragmentShaderSource from './RasterFillLayer.frag.glsl?raw'
import React from 'react'
import { CanonicalCoordinate } from 'mapbox-gl'

/**
 * Report whether two axis-aligned `[west, south, east, north]` boxes overlap
 * with positive area. Boxes that only share an edge or a corner are treated as
 * non-intersecting.
 *
 * Plain numeric comparisons are used, so boxes that cross the antimeridian
 * (west > east) are not handled specially.
 */
function doBoundsIntersect(
    boundsA: readonly number[],
    boundsB: readonly number[]
): boolean {
    const [westA, southA, eastA, northA] = boundsA
    const [westB, southB, eastB, northB] = boundsB

    // Horizontally separated: one box ends at or before the other starts.
    if (eastA <= westB || eastB <= westA) return false

    // Vertically separated: one box is entirely above or below the other.
    if (northA <= southB || northB <= southA) return false

    return true
}

/**
 * Intersect two axis-aligned `[west, south, east, north]` boxes.
 *
 * @param boundsA - first box as `[west, south, east, north]`
 * @param boundsB - second box as `[west, south, east, north]`
 * @returns the overlapping box as `[west, south, east, north]`, or `null`
 *   when `doBoundsIntersect` reports no positive-area overlap (boxes that
 *   merely touch along an edge also yield `null`).
 */
function getIntersectionBounds(
    boundsA: readonly number[],
    boundsB: readonly number[]
): [number, number, number, number] | null {
    if (!doBoundsIntersect(boundsA, boundsB)) {
        return null
    }

    const [westA, southA, eastA, northA] = boundsA
    const [westB, southB, eastB, northB] = boundsB

    // The overlap is bounded by the innermost edge on each side.
    return [
        Math.max(westA, westB),
        Math.max(southA, southB),
        Math.min(eastA, eastB),
        Math.min(northA, northB),
    ]
}

type RasterFillLayerProps = {
    /** Id of the Mapbox custom layer. */
    id: string
    /** Dataset whose loaded tile footprints are filled; read live via a ref. */
    dataset: Dataset
}

// Indices are uploaded as a Uint16Array, so a single draw call can address at
// most vertices 0..65535. Quads past that limit would wrap around and draw
// garbage triangles, so they are dropped instead.
const MAX_UINT16_VERTEX_COUNT = 65536

// Don't render on globe below this zoom level since the web-mercator
// assumption used to build the quads breaks down, leading to weird effects.
const GLOBE_MIN_ZOOM = 8

// Delay before a freshly loaded tile is added to the fill; empirically tuned
// to avoid flickering.
const TILE_ADD_DELAY_MS = 500

/**
 * Tile zoom the fill uses for the current view: the map zoom rounded to an
 * integer and capped at `maxZoom` when one is provided.
 */
function getFillZoom(
    map: mapboxgl.Map,
    maxZoom: number | null | undefined
): number {
    const zoom = Math.round(map.getZoom())
    return maxZoom != null && zoom > maxZoom ? maxZoom : zoom
}

/**
 * Create and compile one shader stage, logging the driver's info log when
 * compilation fails.
 *
 * @throws Error if the context cannot allocate a shader object.
 */
function compileShader(
    gl: WebGLRenderingContext,
    type: number,
    source: string
): WebGLShader {
    const shader = gl.createShader(type)
    if (shader === null) {
        throw new Error('RasterFillLayer: gl.createShader returned null')
    }
    gl.shaderSource(shader, source)
    gl.compileShader(shader)
    if (!gl.getShaderParameter(shader, gl.COMPILE_STATUS)) {
        console.error(
            'RasterFillLayer: shader compilation failed:',
            gl.getShaderInfoLog(shader)
        )
    }
    return shader
}

/**
 * Build the fill program from the vertex/fragment shader sources, logging
 * link failures. The intermediate shader objects are released once linked.
 *
 * @throws Error if the context cannot allocate a shader or program object.
 */
function createFillProgram(gl: WebGLRenderingContext): WebGLProgram {
    const vertexShader = compileShader(gl, gl.VERTEX_SHADER, vertexShaderSource)
    const fragmentShader = compileShader(
        gl,
        gl.FRAGMENT_SHADER,
        fragmentShaderSource
    )

    const program = gl.createProgram()
    if (program === null) {
        throw new Error('RasterFillLayer: gl.createProgram returned null')
    }
    gl.attachShader(program, vertexShader)
    gl.attachShader(program, fragmentShader)
    gl.linkProgram(program)
    if (!gl.getProgramParameter(program, gl.LINK_STATUS)) {
        console.error(
            'RasterFillLayer: program link failed:',
            gl.getProgramInfoLog(program)
        )
    }

    // Attached shaders are only flagged for deletion; they are freed together
    // with the program.
    gl.deleteShader(vertexShader)
    gl.deleteShader(fragmentShader)
    return program
}

/**
 * Mercator-space corners of a tile's footprint clipped to the dataset bounds,
 * flattened as `[x, y]` pairs in the order NW, SW, SE, NE. Returns `null` when
 * the tile does not overlap the dataset bounds.
 */
function clippedTileQuad(
    tileID: CanonicalCoordinate,
    datasetBounds: ReturnType<typeof boundsStringToArray>
): number[] | null {
    const [north, west] = getLatLngForTile(tileID.x, tileID.y, tileID.z)
    const [south, east] = getLatLngForTile(
        tileID.x + 1,
        tileID.y + 1,
        tileID.z
    )

    const clipped = getIntersectionBounds(
        [west, south, east, north],
        datasetBounds
    )
    if (!clipped) return null

    const [clippedWest, clippedSouth, clippedEast, clippedNorth] = clipped
    const corners: [number, number][] = [
        [clippedWest, clippedNorth],
        [clippedWest, clippedSouth],
        [clippedEast, clippedSouth],
        [clippedEast, clippedNorth],
    ]

    const quad: number[] = []
    for (const corner of corners) {
        // Standard Web Mercator conversion
        const { x, y } = mapboxgl.MercatorCoordinate.fromLngLat(corner)
        quad.push(x, y)
    }
    return quad
}

/**
 * Mapbox custom WebGL layer that fills the footprint of every loaded tile of
 * `dataset`'s source at the current tile zoom, clipped to the dataset bounds.
 * Drawing is gated on `vizParams.transparencyFillLayerEnabled` and uses
 * `vizParams.transparencyFillLayerOpacity` (clamped to [0, 1]).
 */
export function RasterFillLayer({ id, dataset }: RasterFillLayerProps) {
    // onAdd runs once per layer instance, so every handler reads the latest
    // dataset through this ref instead of the render-time `dataset` binding.
    const datasetRef = React.useRef(dataset)

    React.useEffect(() => {
        datasetRef.current = dataset
    }, [dataset])

    const customLayer: CustomLayerInterface = {
        id,
        type: 'custom',
        renderingMode: '2d',

        onAdd: function (map: mapboxgl.Map, gl: WebGLRenderingContext) {
            this.map = map
            // Clamp like zoomHandler does so the initial zoom matches the
            // zoom level under which tiles are stored.
            this.zoom = getFillZoom(map, datasetRef.current.maxZoom)
            // Pending "add tile" timers; each removes itself when it fires.
            this.timeouts = new Set()

            this.program = createFillProgram(gl)

            this.aPos = gl.getAttribLocation(this.program, 'a_pos')
            this.uZoom = gl.getUniformLocation(this.program, 'u_zoom')
            this.uOpacity = gl.getUniformLocation(this.program, 'u_opacity')
            this.uMatrix = gl.getUniformLocation(this.program, 'u_matrix')

            this.vertexBuffer = gl.createBuffer()
            this.indexBuffer = gl.createBuffer()

            // dict of zoom level -> tile id key -> tile id value
            this.loadedTiles = new Map()
            this.vertices = []
            this.indices = []

            // Rebuild the vertex/index buffers from the tiles stored under
            // this.zoom, two triangles per tile.
            this.regenerateBuffers = () => {
                const datasetBounds = boundsStringToArray(
                    datasetRef.current.bounds
                )
                this.vertices = []
                this.indices = []
                if (
                    this.map.getProjection().name === 'globe' &&
                    this.zoom < GLOBE_MIN_ZOOM
                )
                    return

                const zoomTiles = this.loadedTiles.get(this.zoom)
                if (zoomTiles !== undefined) {
                    let vertexOffset = 0
                    for (const tileID of zoomTiles.values()) {
                        if (vertexOffset + 4 > MAX_UINT16_VERTEX_COUNT) break

                        const quad = clippedTileQuad(tileID, datasetBounds)
                        // skip tile if it doesn't intersect with bounds
                        if (quad === null) continue

                        this.vertices.push(...quad)
                        // Triangles (NW, SW, SE) and (NW, SE, NE).
                        this.indices.push(
                            vertexOffset,
                            vertexOffset + 1,
                            vertexOffset + 2,
                            vertexOffset,
                            vertexOffset + 2,
                            vertexOffset + 3
                        )
                        vertexOffset += 4
                    }
                }

                gl.bindBuffer(gl.ARRAY_BUFFER, this.vertexBuffer)
                gl.bufferData(
                    gl.ARRAY_BUFFER,
                    new Float32Array(this.vertices),
                    gl.DYNAMIC_DRAW
                )
                gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, this.indexBuffer)
                gl.bufferData(
                    gl.ELEMENT_ARRAY_BUFFER,
                    new Uint16Array(this.indices),
                    gl.DYNAMIC_DRAW
                )
            }

            this.addTile = (tileIDCanonical: CanonicalCoordinate) => {
                let zoomTiles = this.loadedTiles.get(tileIDCanonical.z)
                if (zoomTiles === undefined) {
                    zoomTiles = new Map()
                    this.loadedTiles.set(tileIDCanonical.z, zoomTiles)
                }
                zoomTiles.set(tileIDCanonical.key, tileIDCanonical)
            }

            this.deleteTile = (tileIDCanonical: CanonicalCoordinate) => {
                const zoomTiles = this.loadedTiles.get(tileIDCanonical.z)
                if (zoomTiles === undefined) {
                    return
                }
                zoomTiles.delete(tileIDCanonical.key)
            }

            this.sourceDataHandler = (e: MapSourceDataEvent) => {
                const current = datasetRef.current
                const expectedSourceId = sourceId(
                    current,
                    dimensionsToString(current.selectedDimensions)
                )
                if (e.sourceId !== expectedSourceId) return

                const tile = e.tile
                if (!tile) return

                if (tile.state === 'loaded') {
                    const timeout = setTimeout(() => {
                        this.timeouts.delete(timeout)
                        this.addTile(tile.tileID.canonical)
                        this.regenerateBuffers()
                    }, TILE_ADD_DELAY_MS)
                    this.timeouts.add(timeout)
                } else if (tile.state === 'error') {
                    this.deleteTile(tile.tileID.canonical)
                    this.regenerateBuffers()
                }
            }

            this.zoomHandler = () => {
                const zoom = getFillZoom(map, datasetRef.current.maxZoom)
                if (zoom !== this.zoom) {
                    this.zoom = zoom
                    this.regenerateBuffers()
                }
            }

            map.on('sourcedata', this.sourceDataHandler)
            map.on('zoom', this.zoomHandler)
        },

        render: function (gl: WebGLRenderingContext, matrix: number[]) {
            if (this.indices.length === 0) return
            if (
                this.map.getProjection().name === 'globe' &&
                this.zoom < GLOBE_MIN_ZOOM
            )
                return
            if (
                datasetRef.current.vizParams?.transparencyFillLayerEnabled !==
                true
            ) {
                return
            }

            const opacity =
                datasetRef.current.vizParams?.transparencyFillLayerOpacity
            if (opacity == null) {
                return
            }

            gl.useProgram(this.program)
            gl.uniformMatrix4fv(this.uMatrix, false, matrix)
            gl.uniform1f(this.uZoom, this.zoom)
            gl.uniform1f(this.uOpacity, Math.min(1, Math.max(0, opacity)))

            gl.bindBuffer(gl.ARRAY_BUFFER, this.vertexBuffer)
            gl.enableVertexAttribArray(this.aPos)
            gl.vertexAttribPointer(this.aPos, 2, gl.FLOAT, false, 0, 0)

            gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, this.indexBuffer)
            gl.drawElements(
                gl.TRIANGLES,
                this.indices.length,
                gl.UNSIGNED_SHORT,
                0
            )
        },

        onRemove: function (_map: mapboxgl.Map, gl: WebGLRenderingContext) {
            if (this.map) {
                this.map.off('sourcedata', this.sourceDataHandler)
                this.map.off('zoom', this.zoomHandler)
            }
            if (this.timeouts) {
                this.timeouts.forEach((timeout) => clearTimeout(timeout))
                this.timeouts.clear()
            }
            // Release the GPU objects created in onAdd so that repeated
            // add/remove cycles do not leak them.
            if (gl) {
                if (this.vertexBuffer) gl.deleteBuffer(this.vertexBuffer)
                if (this.indexBuffer) gl.deleteBuffer(this.indexBuffer)
                if (this.program) gl.deleteProgram(this.program)
            }
            this.vertexBuffer = null
            this.indexBuffer = null
            this.program = null
        },

        prerender: function () {},
    }

    // @ts-ignore This is not directly exposed in react-map-gl, but it works
    return <Layer {...customLayer} />
}
