import { ArrowBack, ContentCopy } from '@mui/icons-material'
import {
    Box,
    Button,
    Checkbox,
    ClickAwayListener,
    Divider,
    FormControl,
    FormControlLabel,
    IconButton,
    InputLabel,
    MenuItem,
    Select,
    Slider,
    Stack,
    Switch,
    TextField,
    Tooltip,
    Typography,
} from '@mui/material'
import { getDefaultRasterVizParams } from '../../api/dataset'
import React, { useEffect, useState } from 'react'
import { SketchPicker } from 'react-color'
import ReactDOM from 'react-dom'
import CacheManager from '../../context/cache'
import { useMapContext } from '../../context/map/mapContext'
import { Dataset } from '../../types/dataset'
import {
    BinnedRasterVizParams,
    CategoricalRasterVizParams,
    MultibandRasterVizParams,
    PseudocolorRasterVizParams,
    VectorVizParams,
    VizParams,
    VizType,
} from '../../types/viz'
import VisibilityButton from '../DatasetCard/VisibilityButton'
import ZoomButton from '../DatasetCard/ZoomButton'
import CopyVizParamsDialog from './CopyVizParamsDialog'
import VizParamsEditorStyle from './VizParamsEditor.module.css'

// Named colour ramps (not referenced in the visible part of this file;
// presumably offered by the single-band editor — verify).
const COLOR_RAMPS: string[] = [
    'viridis',
    'plasma',
    'inferno',
    'magma',
    'cividis',
]

// Colormap names offered for categorical rasters. CategoricalVizParamsEditor
// lists these in its Select and treats any other `colorMap` as invalid.
const CATEGORICAL_COLOR_MAPS: string[] = [
    'paired', 'accent', 'viridis', 'hsv',
    'tab10', 'tab20', 'tab20b', 'tab20c',
]

/**
 * Renders `children` in a portal attached to `document.body`, pinned at the
 * given viewport coordinates and stacked above other content (z-index 2000).
 *
 * Callers compute the coordinates from `MouseEvent.clientX/clientY`, which
 * are viewport-relative, so the box is `position: fixed`. With `absolute`
 * (relative to the document) the popup was displaced by the scroll offset
 * whenever the page was scrolled.
 *
 * @param top - Distance from the viewport's top edge in px; wins over `bottom`.
 * @param left - Distance from the viewport's left edge in px.
 * @param bottom - Distance from the viewport's bottom edge in px; used only
 *   when `top` is omitted.
 */
const FloatingBox = ({
    children,
    top,
    left,
    bottom,
}: {
    children: React.ReactNode
    top?: number
    left: number
    bottom?: number
}) => {
    return ReactDOM.createPortal(
        <Box
            sx={{
                position: 'fixed',
                // Anchor from the top when given, otherwise from the bottom.
                ...(top !== undefined ? { top } : { bottom }),
                left,
                zIndex: 2000,
            }}
        >
            {children}
        </Box>,
        document.body
    )
}

/**
 * Editor for a vector dataset's styling: draw mode (fill / outline), colour
 * (chosen in a floating SketchPicker), and stroke width in outline mode.
 *
 * Every change is sent straight to the map context via `updateVizParams`;
 * there is no local draft as in the editors that take `setLocalVizParams`.
 */
function VectorVizParamsEditor({
    dataset,
    setHasValidationErrors,
}: {
    dataset: Dataset
    setHasValidationErrors: React.Dispatch<React.SetStateAction<boolean>>
}) {
    const { updateVizParams } = useMapContext()
    const vizParams = dataset.vizParams as VectorVizParams
    const [colorPickerVisible, setColorPickerVisible] = React.useState(false)
    // Viewport position (clientX/clientY) of the click that opened the picker.
    const [colorPickerPosition, setColorPickerPosition] = React.useState({
        top: 0,
        left: 0,
    })

    // No vector constraints are enforced yet, so every combination is
    // accepted; kept so the error flag is driven like in the other editors.
    const validateVectorParams = (_params: VectorVizParams): boolean => true

    // Replaces one field and pushes the result to the map immediately.
    const handleChange = (
        field: keyof VectorVizParams,
        value: string | number
    ) => {
        const newParams = {
            ...vizParams,
            [field]: value,
        }
        updateVizParams(dataset, newParams, dataset.vizType)
        setHasValidationErrors(!validateVectorParams(newParams))
    }

    // Opens the picker with its top-left corner at the click point.
    const handleLocalColorPickerClick = (event: React.MouseEvent) => {
        setColorPickerPosition({ top: event.clientY, left: event.clientX })
        setColorPickerVisible(true)
    }

    return (
        <Box sx={{ minWidth: 120 }}>
            <FormControl fullWidth sx={{ mt: 2, mb: 2 }}>
                <InputLabel id="vector-type-label">
                    Visualization Type
                </InputLabel>
                <Select
                    labelId="vector-type-label"
                    value={vizParams.mode}
                    label="Visualization Type"
                    sx={{
                        '& fieldset': {
                            borderColor: 'primary.main',
                        },
                    }}
                    onChange={(e) => handleChange('mode', e.target.value)}
                >
                    <MenuItem value="fill">Fill</MenuItem>
                    <MenuItem value="outline">Outline</MenuItem>
                </Select>
            </FormControl>

            <Box sx={{ mb: 2 }}>
                <Typography gutterBottom>Color</Typography>
                <Box
                    sx={{
                        display: 'flex',
                        alignItems: 'center',
                        cursor: 'pointer',
                    }}
                    onClick={handleLocalColorPickerClick}
                >
                    <Box
                        sx={{
                            width: 36,
                            height: 14,
                            borderRadius: 1,
                            backgroundColor: vizParams.color,
                            marginRight: 1,
                        }}
                    />
                    <Typography>{vizParams.color}</Typography>
                </Box>
                {colorPickerVisible && (
                    <FloatingBox
                        top={colorPickerPosition.top}
                        left={colorPickerPosition.left}
                    >
                        <ClickAwayListener
                            onClickAway={() => setColorPickerVisible(false)}
                        >
                            <Box>
                                <SketchPicker
                                    color={vizParams.color}
                                    onChangeComplete={(color) => {
                                        // Commit the colour and close the picker.
                                        handleChange('color', color.hex)
                                        setColorPickerVisible(false)
                                    }}
                                />
                            </Box>
                        </ClickAwayListener>
                    </FloatingBox>
                )}
            </Box>

            {vizParams.mode === 'outline' && (
                <Box sx={{ width: '100%', mb: 2 }}>
                    <Typography id="width-slider" gutterBottom>
                        Width (px)
                    </Typography>
                    <Slider
                        aria-labelledby="width-slider"
                        value={vizParams.width || 1}
                        // MUI types the value as number | number[]; this is a
                        // single-thumb slider, so it is always a number.
                        onChange={(_event, newValue) =>
                            handleChange('width', newValue as number)
                        }
                        valueLabelDisplay="auto"
                        step={0.5}
                        marks
                        min={0.5}
                        max={5}
                    />
                </Box>
            )}
        </Box>
    )
}

/**
 * Editor for categorical rasters: choose a colormap from
 * CATEGORICAL_COLOR_MAPS and optionally reverse it. Each change is pushed to
 * the map immediately via `updateVizParams`, and the validation flag is set
 * when the chosen colormap is not one of the known names.
 */
function CategoricalVizParamsEditor({
    dataset,
    setHasValidationErrors,
}: {
    dataset: Dataset
    setHasValidationErrors: React.Dispatch<React.SetStateAction<boolean>>
}) {
    const { updateVizParams } = useMapContext()
    const vizParams = dataset.vizParams as CategoricalRasterVizParams

    // Valid only when a known categorical colormap is selected.
    const validateCategoricalParams = (params: CategoricalRasterVizParams) =>
        Boolean(
            params.colorMap && CATEGORICAL_COLOR_MAPS.includes(params.colorMap)
        )

    // Replaces one field, re-validates, and pushes the result to the map.
    const updateParam = <K extends keyof CategoricalRasterVizParams>(
        param: K,
        value: CategoricalRasterVizParams[K]
    ) => {
        const newParams: CategoricalRasterVizParams = {
            ...vizParams,
            [param]: value,
        }

        setHasValidationErrors(!validateCategoricalParams(newParams))
        updateVizParams(dataset, newParams, dataset.vizType)
    }

    return (
        <Box>
            <FormControl fullWidth sx={{ mb: 2, mt: '4vh' }}>
                {/* `id` must equal the Select's `labelId`, and the Select's
                    `label` must match this text so the outline notch fits. */}
                <InputLabel id="color-map-label">Color Map</InputLabel>
                <Select
                    labelId="color-map-label"
                    id="color-ramp"
                    value={vizParams.colorMap}
                    label="Color Map"
                    onChange={(event) => {
                        updateParam('colorMap', event.target.value)
                    }}
                    sx={{
                        '& fieldset': {
                            borderColor: 'primary.main',
                        },
                    }}
                >
                    {CATEGORICAL_COLOR_MAPS.map((colorMap) => (
                        <MenuItem key={colorMap} value={colorMap}>
                            {colorMap}
                        </MenuItem>
                    ))}
                </Select>
            </FormControl>
            <FormControlLabel
                control={
                    <Checkbox
                        checked={vizParams?.reverseColormap ?? false}
                        onChange={(event) => {
                            updateParam('reverseColormap', event.target.checked)
                        }}
                    />
                }
                label="Reverse Colormap"
                sx={{ justifyContent: 'center', width: '100%' }}
            />
        </Box>
    )
}

/**
 * Shared controls for the optional "transparency fill layer": a switch that
 * enables it and, while enabled, a 0–100 slider stored as a 0–1 opacity.
 * Both write into the caller's draft through `setLocalVizParams` using the
 * functional updater, so they always build on the latest draft.
 *
 * NOTE(review): both controls call `setHasValidationErrors(false)`, which
 * also clears any error raised by the host editor's own fields — confirm
 * that is intended.
 */
function TransparencyFillLayerVizParamsEditor<T extends VizParams>({
    localVizParams,
    setLocalVizParams,
    setHasValidationErrors,
}: {
    localVizParams: T
    setLocalVizParams: (vizParams: T | ((prev: T) => T)) => void
    setHasValidationErrors: React.Dispatch<React.SetStateAction<boolean>>
}) {
    return (
        <>
            <Stack
                direction="row"
                alignItems="center"
                spacing={1}
                sx={{ marginRight: '25px', marginTop: '8px' }}
            >
                <Typography variant="body2">Transparency Fill Layer</Typography>
                <Switch
                    // `?? false` keeps the Switch controlled if the flag is unset.
                    checked={localVizParams.transparencyFillLayerEnabled ?? false}
                    onChange={(event) => {
                        const enabled = event.target.checked
                        setLocalVizParams((prev: T) => ({
                            ...prev,
                            transparencyFillLayerEnabled: enabled,
                        }))
                        setHasValidationErrors(false)
                    }}
                    size="small"
                />
            </Stack>
            {localVizParams.transparencyFillLayerEnabled && (
                <Stack
                    direction="row"
                    alignItems="center"
                    spacing={1}
                    sx={{ marginRight: '25px', marginTop: '8px' }}
                >
                    {/* Target of the slider's aria-labelledby. */}
                    <Typography variant="body2" id="fill-layer-opacity-slider">
                        Transparency Fill Layer Opacity
                    </Typography>
                    <Slider
                        value={
                            localVizParams.transparencyFillLayerOpacity * 100
                        }
                        onChange={(_event, newValue) => {
                            // The slider works in percent; params store 0–1.
                            const opacity = Math.min(
                                1,
                                Math.max(0, (newValue as number) / 100)
                            )
                            setLocalVizParams((prev: T) => ({
                                ...prev,
                                transparencyFillLayerOpacity: opacity,
                            }))
                            setHasValidationErrors(false)
                        }}
                        aria-labelledby="fill-layer-opacity-slider"
                        valueLabelDisplay="auto"
                        size={'small'}
                        step={1}
                        min={0}
                        max={100}
                    />
                </Stack>
            )}
        </>
    )
}

// The `[threshold, colour]` list stored on `BinnedRasterVizParams.bins`.
type BinnedThresholdBins = NonNullable<BinnedRasterVizParams['bins']>

/** Text shown in a threshold input for a stored value; a missing value reads as '0'. */
const formatBinThreshold = (value: number | null | undefined): string =>
    value !== null && value !== undefined ? value.toString() : '0'

/**
 * True when `text` is non-blank and converts with `Number` to a finite
 * number. Blank text is rejected explicitly because `Number('')` is 0: the
 * previous check accepted an emptied field, after which `parseFloat('')`
 * wrote NaN into the bins.
 */
const isValidThresholdInput = (text: string): boolean =>
    text.trim() !== '' && Number.isFinite(Number(text))

/**
 * Copy of `bins` with the threshold at `index` set to `value`, re-sorted
 * ascending by threshold. A missing or malformed bin at `index` becomes
 * `[value, '#000000']`; otherwise its colour is kept (black if empty).
 * `bins` itself is not mutated.
 */
function binsWithThreshold(
    bins: BinnedThresholdBins,
    index: number,
    value: number
): BinnedThresholdBins {
    const next = [...bins]
    const existing = next[index]
    const color =
        Array.isArray(existing) && existing.length >= 2
            ? existing[1] || '#000000'
            : '#000000'
    next[index] = [value, color]
    // Missing thresholds sort as 0.
    next.sort((a, b) => (a[0] ?? 0) - (b[0] ?? 0))
    return next
}

/**
 * Copy of `bins` with one bin appended: threshold = last threshold + 1
 * (1 when the list is empty), colour = last bin's colour (black when empty).
 */
function binsWithAppendedThreshold(
    bins: BinnedThresholdBins
): BinnedThresholdBins {
    let lastValue = 0
    let lastColor = '#000000'
    const lastBin = bins[bins.length - 1]
    if (Array.isArray(lastBin) && lastBin.length >= 2) {
        lastValue = lastBin[0] ?? 0
        lastColor = lastBin[1] ?? '#000000'
    }
    const newBin: [number, string] = [lastValue + 1, lastColor]
    return [...bins, newBin]
}

/**
 * Editor for binned rasters: choose the band, then edit an ascending list of
 * `[threshold, colour]` bins, each shown as a "< threshold → colour" row.
 *
 * Threshold text is kept locally and only committed on blur, Enter, or
 * keyboard navigation. ArrowUp/ArrowDown and Shift+Tab/Tab move between
 * rows (wrapping); Tab on the last row commits and appends a new bin;
 * Escape reverts the row to its stored value.
 *
 * All changes go to the caller's draft via `setLocalVizParams`, and
 * `setHasValidationErrors` is updated after every change.
 */
function BinnedVizParamsEditor({
    dataset,
    localVizParams,
    setLocalVizParams,
    setHasValidationErrors,
}: {
    dataset: Dataset
    localVizParams: BinnedRasterVizParams
    setLocalVizParams: (
        vizParams:
            | BinnedRasterVizParams
            | ((prev: BinnedRasterVizParams) => BinnedRasterVizParams)
    ) => void
    setHasValidationErrors: React.Dispatch<React.SetStateAction<boolean>>
}) {
    // Threshold <input> elements, indexed like `bins`, for keyboard focus moves.
    const inputRefs = React.useRef<(HTMLInputElement | null)[]>([])
    // Bin whose colour the picker is editing.
    const [editingIndex, setEditingIndex] = useState<number | null>(null)
    // Raw text of each threshold input keyed by bin index. Edits stay here
    // until committed, so half-typed numbers never reach the params.
    const [localValues, setLocalValues] = useState<{ [key: number]: string }>(
        {}
    )
    const [colorPickerVisible, setColorPickerVisible] = React.useState(false)
    // The picker opens upwards from the click: its bottom-left corner is
    // placed `bottom` px above the viewport's bottom edge.
    const [colorPickerPosition, setColorPickerPosition] = React.useState({
        bottom: 0,
        left: 0,
    })

    // Memoised so the array identity changes only when the stored bins do.
    const bins = React.useMemo(
        () => localVizParams.bins || [],
        [localVizParams.bins]
    )

    // Re-sync every input's text whenever the bins change (including after
    // a commit re-sorts them).
    React.useEffect(() => {
        const next: { [key: number]: string } = {}
        bins.forEach(([value], index) => {
            next[index] = formatBinThreshold(value)
        })
        setLocalValues(next)
    }, [bins])

    // Drop refs for rows that no longer exist.
    React.useEffect(() => {
        inputRefs.current = inputRefs.current.slice(0, bins.length)
    }, [bins.length])

    // Valid when a band is chosen and the bins are a non-empty list of
    // [finite number, string] pairs with strictly increasing thresholds.
    const validateBinnedParams = React.useCallback(
        (params: BinnedRasterVizParams): boolean => {
            if (!params.band) return false
            if (!Array.isArray(params.bins) || params.bins.length === 0)
                return false

            return params.bins.every((bin, i, all) => {
                if (!Array.isArray(bin) || bin.length < 2) return false
                const [threshold, color] = bin
                if (
                    typeof threshold !== 'number' ||
                    !Number.isFinite(threshold)
                )
                    return false
                if (typeof color !== 'string') return false
                return i === 0 || all[i - 1][0] < threshold
            })
        },
        []
    )

    // Writes `newBins` into the draft params and refreshes the error flag.
    const commitBins = React.useCallback(
        (newBins: BinnedThresholdBins) => {
            const newParams = {
                ...localVizParams,
                bins: newBins,
            }
            setLocalVizParams(newParams)
            setHasValidationErrors(!validateBinnedParams(newParams))
        },
        [
            localVizParams,
            setHasValidationErrors,
            setLocalVizParams,
            validateBinnedParams,
        ]
    )

    // Bins with the text typed into row `index` applied, or null when the
    // text is not a valid number or equals the stored threshold.
    const binsWithPendingEdit = React.useCallback(
        (index: number, text: string): BinnedThresholdBins | null => {
            if (!isValidThresholdInput(text)) return null
            const numValue = Number(text)
            if (bins[index]?.[0] === numValue) return null
            return binsWithThreshold(bins, index, numValue)
        },
        [bins]
    )

    // Only record the typed text; it is committed on blur / Enter / navigation.
    const handleValueChange = React.useCallback(
        (index: number, value: string) => {
            setLocalValues((prev) => ({ ...prev, [index]: value }))
        },
        []
    )

    // Commits the text for row `index`; invalid or unchanged text is ignored.
    const applyValueChange = React.useCallback(
        (index: number, value: string) => {
            const newBins = binsWithPendingEdit(index, value)
            if (newBins) commitBins(newBins)
        },
        [binsWithPendingEdit, commitBins]
    )

    // Reverts row `index`'s text to the stored threshold.
    const resetLocalValue = React.useCallback(
        (index: number) => {
            setLocalValues((prev) => ({
                ...prev,
                [index]: formatBinThreshold(bins[index]?.[0]),
            }))
        },
        [bins]
    )

    const handleAddBin = React.useCallback(() => {
        commitBins(binsWithAppendedThreshold(bins))
    }, [bins, commitBins])

    const handleRemoveBin = React.useCallback(
        (index: number) => {
            // Also drops any malformed entries while rebuilding the list.
            commitBins(
                bins.filter(
                    (bin, i) =>
                        i !== index && Array.isArray(bin) && bin.length >= 2
                )
            )
        },
        [bins, commitBins]
    )

    // Applies the picked colour to the bin being edited and closes the picker.
    const handleColorChange = React.useCallback(
        (color: string) => {
            const target =
                editingIndex === null ? undefined : bins[editingIndex]
            if (editingIndex !== null && target) {
                const newBins = [...bins]
                newBins[editingIndex] = [target[0], color]
                commitBins(newBins)
            }
            setColorPickerVisible(false)
        },
        [bins, commitBins, editingIndex]
    )

    const handleLocalColorPickerClick = React.useCallback(
        (event: React.MouseEvent, index: number) => {
            setColorPickerPosition({
                bottom: window.innerHeight - event.clientY,
                left: event.clientX,
            })
            setColorPickerVisible(true)
            setEditingIndex(index)
        },
        []
    )

    const handleKeyDown = React.useCallback(
        (event: React.KeyboardEvent, index: number) => {
            const totalInputs = bins.length
            const pending = localValues[index]

            // Index `delta` rows away from `index`, wrapping at both ends.
            const wrapIndex = (delta: number) =>
                (index + delta + totalInputs) % totalInputs

            // Commits this row's pending text, then focuses `targetIndex`.
            const commitAndFocus = (targetIndex: number) => {
                event.preventDefault()
                if (pending !== undefined) applyValueChange(index, pending)
                inputRefs.current[targetIndex]?.focus()
            }

            switch (event.key) {
                case 'ArrowUp':
                    commitAndFocus(wrapIndex(-1))
                    break
                case 'ArrowDown':
                    commitAndFocus(wrapIndex(1))
                    break
                case 'Tab':
                    if (event.shiftKey) {
                        commitAndFocus(wrapIndex(-1))
                    } else if (index !== totalInputs - 1) {
                        commitAndFocus(wrapIndex(1))
                    } else {
                        // Last row: apply the pending edit AND append a bin in
                        // one commit. Committing them separately let the append
                        // (built from the same stale `bins`) overwrite the edit.
                        event.preventDefault()
                        const edited =
                            pending === undefined
                                ? null
                                : binsWithPendingEdit(index, pending)
                        commitBins(binsWithAppendedThreshold(edited ?? bins))
                        // Focus the new last row once it has rendered.
                        setTimeout(() => {
                            const refs = inputRefs.current
                            refs[refs.length - 1]?.focus()
                        }, 0)
                    }
                    break
                case 'Enter':
                    event.preventDefault()
                    if (pending !== undefined) applyValueChange(index, pending)
                    break
                case 'Escape':
                    event.preventDefault()
                    resetLocalValue(index)
                    break
            }
        },
        [
            applyValueChange,
            bins,
            binsWithPendingEdit,
            commitBins,
            localValues,
            resetLocalValue,
        ]
    )

    return (
        <Box sx={{ minWidth: 120 }}>
            <FormControl fullWidth sx={{ mb: 2 }}>
                <InputLabel id="band-selector-label">Band</InputLabel>
                <Select
                    required
                    labelId="band-selector-label"
                    value={localVizParams.band ?? ''}
                    label="Band"
                    onChange={(event) => {
                        const newParams = {
                            ...localVizParams,
                            band: event.target.value,
                        }
                        setLocalVizParams(newParams)
                        setHasValidationErrors(!validateBinnedParams(newParams))
                    }}
                    sx={{
                        '& fieldset': {
                            borderColor: 'primary.main',
                        },
                    }}
                >
                    {dataset.bands.map((band) => (
                        <MenuItem key={band} value={band}>
                            {band}
                        </MenuItem>
                    ))}
                </Select>
            </FormControl>

            <Typography variant="subtitle1" gutterBottom>
                Color Thresholds
            </Typography>

            <Box sx={{ mb: 2 }}>
                {bins.map(([value, color], index) => {
                    // Text shown: the in-progress edit, else the stored threshold.
                    const displayed =
                        localValues[index] ?? formatBinThreshold(value)

                    return (
                        <Box
                            key={`threshold-container-${index}`}
                            sx={{
                                display: 'flex',
                                alignItems: 'center',
                                mb: 1,
                                gap: 1,
                            }}
                        >
                            <Typography variant="body1">&lt;</Typography>
                            <TextField
                                type="number"
                                sx={{
                                    flexGrow: 1,
                                    '& input::-webkit-outer-spin-button, & input::-webkit-inner-spin-button':
                                        {
                                            display: 'none',
                                        },
                                    '& input[type=number]': {
                                        MozAppearance: 'textfield',
                                    },
                                }}
                                inputRef={(el) => {
                                    inputRefs.current[index] = el
                                }}
                                value={displayed}
                                key={`threshold-${index}`}
                                onChange={(e) =>
                                    handleValueChange(index, e.target.value)
                                }
                                onBlur={() => {
                                    // Commit valid text; otherwise (blank or
                                    // invalid) show the stored threshold again.
                                    const pending = localValues[index]
                                    if (
                                        pending !== undefined &&
                                        isValidThresholdInput(pending)
                                    ) {
                                        applyValueChange(index, pending)
                                    } else {
                                        resetLocalValue(index)
                                    }
                                }}
                                onKeyDown={(e) => handleKeyDown(e, index)}
                                error={!isValidThresholdInput(displayed)}
                                size="small"
                            />
                            <Typography variant="body1">→</Typography>
                            <Box
                                sx={{
                                    width: 36,
                                    height: 36,
                                    minWidth: 36,
                                    minHeight: 36,
                                    borderRadius: 1,
                                    backgroundColor: color,
                                    cursor: 'pointer',
                                    border: '1px solid #ccc',
                                    flexShrink: 0,
                                }}
                                onClick={(e) =>
                                    handleLocalColorPickerClick(e, index)
                                }
                                // Keep keyboard focus on the threshold inputs.
                                tabIndex={-1}
                            />
                            <IconButton
                                size="small"
                                onClick={() => handleRemoveBin(index)}
                                sx={{
                                    ml: 1,
                                    width: 36,
                                    height: 36,
                                    minWidth: 36,
                                    minHeight: 36,
                                    flexShrink: 0,
                                }}
                                // Keep keyboard focus on the threshold inputs.
                                tabIndex={-1}
                            >
                                <Typography>×</Typography>
                            </IconButton>
                        </Box>
                    )
                })}

                <Button
                    variant="outlined"
                    onClick={handleAddBin}
                    sx={{ mt: 1 }}
                    size="small"
                    // Not focusable: Tab on the last input adds a bin instead.
                    tabIndex={-1}
                >
                    Add Threshold
                </Button>
            </Box>

            {colorPickerVisible && (
                <FloatingBox
                    bottom={colorPickerPosition.bottom}
                    left={colorPickerPosition.left}
                >
                    <ClickAwayListener
                        onClickAway={() => setColorPickerVisible(false)}
                    >
                        <Box>
                            <SketchPicker
                                color={
                                    (editingIndex !== null
                                        ? bins[editingIndex]?.[1]
                                        : undefined) ?? '#000000'
                                }
                                onChangeComplete={(color) =>
                                    handleColorChange(color.hex)
                                }
                            />
                        </Box>
                    </ClickAwayListener>
                </FloatingBox>
            )}
        </Box>
    )
}

function SingleBandVizParamsEditor({
    dataset,
    localVizParams,
    setLocalVizParams,
    setHasValidationErrors,
}: {
    dataset: Dataset
    localVizParams: PseudocolorRasterVizParams
    setLocalVizParams: (
        vizParams:
            | PseudocolorRasterVizParams
            | ((prev: PseudocolorRasterVizParams) => PseudocolorRasterVizParams)
    ) => void
    setHasValidationErrors: React.Dispatch<React.SetStateAction<boolean>>
}) {
    /** A range is valid when both bounds are set and min is strictly below max. */
    const validateMinMax = (min: number | null, max: number | null) =>
        min !== null && max !== null && min < max

    /**
     * Stored [min, max] for `band`. Returns [null, null] when no band is
     * selected or the band has no entry yet, instead of throwing on the lookup.
     */
    const getMinMax = (
        band: string | undefined
    ): [number | null, number | null] =>
        band === undefined
            ? [null, null]
            : localVizParams.minMaxesPerBand[band] ?? [null, null]

    const updateParam = (
        param: keyof PseudocolorRasterVizParams,
        value: any
    ) => {
        setLocalVizParams({
            ...localVizParams,
            [param]: value,
        })
    }

    /**
     * Selects a new band and re-validates against that band's stored range.
     * Without this, an error flagged on the previous band would stay set.
     */
    const handleBandChange = (band: string) => {
        updateParam('band', band)
        const [min, max] = getMinMax(band)
        setHasValidationErrors(!validateMinMax(min, max))
    }

    /**
     * Writes one bound of `band`'s range and updates the validation flag.
     * `value` is null while the field is empty; that state is flagged invalid.
     */
    const updateMinOrMax = (
        band: string,
        minOrMax: 'min' | 'max',
        value: number | null
    ) => {
        const [currentMin, currentMax] = getMinMax(band)
        const newMin = minOrMax === 'min' ? value : currentMin
        const newMax = minOrMax === 'max' ? value : currentMax

        setLocalVizParams({
            ...localVizParams,
            minMaxesPerBand: {
                ...localVizParams.minMaxesPerBand,
                // May hold nulls mid-edit; flagged via setHasValidationErrors below.
                [band]: [newMin, newMax] as [number, number],
            },
        })
        setHasValidationErrors(!validateMinMax(newMin, newMax))
    }

    const currentBand = localVizParams.band
    const [currentMin, currentMax] = getMinMax(currentBand)

    /**
     * Builds the onChange handler for the Min or Max field. An empty field
     * stores null; unparseable input is ignored.
     */
    const handleBoundChange =
        (minOrMax: 'min' | 'max') =>
        (event: React.ChangeEvent<HTMLInputElement | HTMLTextAreaElement>) => {
            // A range can only be stored once a band is selected.
            if (currentBand === undefined) {
                return
            }
            const raw = event.target.value
            if (raw === '') {
                updateMinOrMax(currentBand, minOrMax, null)
                return
            }
            const parsed = parseFloat(raw)
            if (!isNaN(parsed)) {
                updateMinOrMax(currentBand, minOrMax, parsed)
            }
        }

    return (
        <Box sx={{ minWidth: 120 }}>
            <FormControl fullWidth sx={{ mb: 2 }}>
                <InputLabel id="band-selector-label">Band</InputLabel>
                <Select
                    required
                    labelId="band-selector-label"
                    id="band-selector"
                    value={currentBand ?? ''}
                    label="Band"
                    onChange={(event) => handleBandChange(event.target.value)}
                    sx={{
                        '& fieldset': {
                            borderColor: 'primary.main',
                        },
                    }}
                    error={currentBand === undefined}
                >
                    {dataset.bands.map((band) => (
                        <MenuItem key={band} value={band}>
                            {band}
                        </MenuItem>
                    ))}
                </Select>
            </FormControl>
            <FormControl fullWidth sx={{ mb: 2 }}>
                {/* id must match the Select's labelId exactly for labelling to work */}
                <InputLabel id="color-ramp-label">Color Ramp</InputLabel>
                <Select
                    labelId="color-ramp-label"
                    id="color-ramp"
                    value={localVizParams.colorRamp}
                    label="Color Ramp"
                    onChange={(event) => {
                        updateParam('colorRamp', event.target.value)
                    }}
                    sx={{
                        '& fieldset': {
                            borderColor: 'primary.main',
                        },
                    }}
                >
                    {COLOR_RAMPS.map((colorRamp) => (
                        <MenuItem key={colorRamp} value={colorRamp}>
                            {colorRamp}
                        </MenuItem>
                    ))}
                </Select>
            </FormControl>
            <Box
                sx={{ display: 'flex', justifyContent: 'space-between', mb: 2 }}
            >
                <FormControl sx={{ width: '48%' }}>
                    <TextField
                        required
                        id="min-value"
                        label="Min"
                        type="number"
                        InputLabelProps={{
                            shrink: true,
                        }}
                        sx={{
                            '& fieldset': {
                                borderColor: 'primary.main',
                            },
                        }}
                        value={currentMin ?? ''}
                        onChange={handleBoundChange('min')}
                        error={currentMin === null}
                        helperText={currentMin === null ? 'Required' : ''}
                    />
                </FormControl>
                <FormControl sx={{ width: '48%' }}>
                    <TextField
                        required
                        id="max-value"
                        label="Max"
                        type="number"
                        InputLabelProps={{
                            shrink: true,
                        }}
                        sx={{
                            '& fieldset': {
                                borderColor: 'primary.main',
                            },
                        }}
                        value={currentMax ?? ''}
                        onChange={handleBoundChange('max')}
                        error={currentMax === null}
                        helperText={currentMax === null ? 'Required' : ''}
                    />
                </FormControl>
            </Box>
            <Box>
                <FormControlLabel
                    control={
                        <Checkbox
                            checked={Boolean(localVizParams.reverseColormap)}
                            onChange={(event) => {
                                updateParam(
                                    'reverseColormap',
                                    event.target.checked
                                )
                            }}
                        />
                    }
                    label={<Typography>Reverse Color Ramp</Typography>}
                    sx={{ justifyContent: 'center', width: '100%' }}
                />
            </Box>
        </Box>
    )
}

/** Output channels of a multiband composite, in display order. */
const RGB_CHANNELS = ['red', 'green', 'blue'] as const

function MultibandVizParamsEditor({
    dataset,
    localVizParams,
    setLocalVizParams,
    setHasValidationErrors,
}: {
    dataset: Dataset
    localVizParams: MultibandRasterVizParams
    setLocalVizParams: (
        vizParams:
            | MultibandRasterVizParams
            | ((prev: MultibandRasterVizParams) => MultibandRasterVizParams)
    ) => void
    setHasValidationErrors: React.Dispatch<React.SetStateAction<boolean>>
}) {
    /**
     * Stored [min, max] for `band`. Returns [null, null] when the band has no
     * entry yet, instead of throwing on the lookup.
     */
    const getMinMax = (
        params: MultibandRasterVizParams,
        band: string
    ): [number | null, number | null] =>
        params.minMaxesPerBand[band] ?? [null, null]

    /** True when every selected channel's band has both bounds set and min < max. */
    const validateMultibandParams = (params: MultibandRasterVizParams) =>
        RGB_CHANNELS.every((channel) => {
            const band = params[channel]
            // Unselected channels are skipped here; their Select shows the error.
            if (!band) {
                return true
            }
            const [min, max] = getMinMax(params, band)
            return min !== null && max !== null && min < max
        })

    const updateParam = (param: keyof MultibandRasterVizParams, value: any) => {
        const newParams = {
            ...localVizParams,
            [param]: value,
        }
        setLocalVizParams(newParams)
        setHasValidationErrors(!validateMultibandParams(newParams))
    }

    /** Writes one bound of `band`'s range and re-validates all channels. */
    const updateMinOrMax = (
        band: string,
        minOrMax: 'min' | 'max',
        value: number | null
    ) => {
        const [currentMin, currentMax] = getMinMax(localVizParams, band)
        const newParams = {
            ...localVizParams,
            minMaxesPerBand: {
                ...localVizParams.minMaxesPerBand,
                // May hold nulls mid-edit; validateMultibandParams flags that.
                [band]: [
                    minOrMax === 'min' ? value : currentMin,
                    minOrMax === 'max' ? value : currentMax,
                ] as [number, number],
            },
        }
        setLocalVizParams(newParams)
        setHasValidationErrors(!validateMultibandParams(newParams))
    }

    /** Converts a number-input string to a bound; empty input means "unset". */
    const toBound = (raw: string) => (raw === '' ? null : Number(raw))

    const rangeFieldSx = {
        width: '48%',
        '& fieldset': {
            borderColor: 'primary.main',
        },
    }

    /** Min/Max inputs for the band currently assigned to a channel. */
    const renderRangeFields = (band: string) => {
        const [min, max] = getMinMax(localVizParams, band)
        return (
            <Box
                sx={{
                    display: 'flex',
                    justifyContent: 'space-between',
                    mt: 2,
                }}
            >
                <TextField
                    label="Min"
                    type="number"
                    value={min ?? ''}
                    onChange={(e) =>
                        updateMinOrMax(band, 'min', toBound(e.target.value))
                    }
                    error={min === null}
                    sx={rangeFieldSx}
                />
                <TextField
                    label="Max"
                    type="number"
                    value={max ?? ''}
                    onChange={(e) =>
                        updateMinOrMax(band, 'max', toBound(e.target.value))
                    }
                    error={max === null}
                    sx={rangeFieldSx}
                />
            </Box>
        )
    }

    return (
        <Box sx={{ minWidth: 120 }}>
            {RGB_CHANNELS.map((channel) => {
                const label = `${channel.charAt(0).toUpperCase()}${channel.slice(1)} Band`
                const labelId = `${channel}-band-label`
                const band = localVizParams[channel]
                return (
                    <React.Fragment key={channel}>
                        <Divider>
                            <Typography variant="body2">{label}</Typography>
                        </Divider>
                        <FormControl fullWidth sx={{ mt: '20px', mb: '10px' }}>
                            <InputLabel id={labelId}>{label}</InputLabel>
                            <Select
                                required
                                labelId={labelId}
                                value={band ?? ''}
                                sx={{
                                    '& fieldset': {
                                        borderColor: 'primary.main',
                                    },
                                }}
                                label={label}
                                onChange={(event) => {
                                    const value = event.target.value
                                    updateParam(
                                        channel,
                                        value === '' ? undefined : value
                                    )
                                }}
                                error={band === undefined}
                            >
                                {dataset.bands.map((bandName) => (
                                    <MenuItem key={bandName} value={bandName}>
                                        {bandName}
                                    </MenuItem>
                                ))}
                            </Select>
                            {band !== undefined && renderRangeFields(band)}
                        </FormControl>
                    </React.Fragment>
                )
            })}
        </Box>
    )
}

/** Viz types rendered as rasters: these get the transparency editor and viz-type selector. */
const RASTER_VIZ_TYPES: readonly VizType[] = [
    'continuous_singleband_raster',
    'continuous_multiband_raster',
    'categorical_raster',
    'binned_raster',
]

/** Remembered viz params for one dataset, keyed by viz type. */
type VizParamsByType = Partial<Record<VizType, VizParams>>

/**
 * Reads the remembered per-viz-type params stored under `cacheKey`.
 *
 * Returns an empty map when nothing is cached or the entry is not a JSON
 * object, so a corrupt entry cannot throw inside a component.
 */
function readCachedVizParams(cacheKey: string): VizParamsByType {
    const raw = CacheManager.getItem(cacheKey)
    if (!raw) {
        return {}
    }
    try {
        const parsed: unknown = JSON.parse(raw)
        return parsed !== null &&
            typeof parsed === 'object' &&
            !Array.isArray(parsed)
            ? (parsed as VizParamsByType)
            : {}
    } catch {
        return {}
    }
}

/**
 * Editor for a dataset's visualization settings. It renders:
 * - header controls (back, zoom, copy style, visibility);
 * - an opacity slider;
 * - for rasters, the transparency editor and a viz-type selector;
 * - the editor for the currently selected viz type.
 *
 * Edits are held locally and pushed with `updateVizParams` after a 500 ms
 * debounce, only while no child editor reports validation errors. Params for
 * each viz type are remembered in CacheManager, so switching type and back
 * restores the earlier settings.
 */
function VizParamsEditor({
    dataset,
    flyToDatasetBounds,
    hideControls,
}: {
    dataset: Dataset
    flyToDatasetBounds: (dataset: Dataset) => void
    /** When true, hides the back button and the header action buttons. */
    hideControls?: boolean
}) {
    const { dispatch, updateVizParams } = useMapContext()
    const [localVizType, setLocalVizType] = useState<VizType>(dataset.vizType)
    const [localVizParams, setLocalVizParams] = useState<VizParams>(
        dataset.vizParams
    )
    const [hasValidationErrors, setHasValidationErrors] = useState(false)
    const [copyDialogOpen, setCopyDialogOpen] = useState(false)

    // Debounced push of local edits. Each new edit (or a change in validation
    // state) cancels the pending push. Nothing is pushed while errors exist.
    // NOTE(review): `dataset` / `updateVizParams` are intentionally not deps;
    // re-running on every dataset update (e.g. opacity) could re-schedule a
    // push of unchanged params. localVizType always changes together with
    // localVizParams, so it does not need to be a dep either.
    useEffect(() => {
        if (localVizParams === dataset.vizParams || hasValidationErrors) {
            return
        }
        const timer = setTimeout(() => {
            updateVizParams(dataset, localVizParams, localVizType)
        }, 500)
        return () => clearTimeout(timer)
    }, [localVizParams, hasValidationErrors])

    const vizParamsCacheKey = `viz_params_${dataset.id}`

    const handleVizTypeChange = (event: { target: { value: string } }) => {
        const newVizType = event.target.value as VizType

        // Build the next cache from what is persisted, so entries for other
        // types survive and everything is written in a single update.
        const nextCache: VizParamsByType = {
            ...readCachedVizParams(vizParamsCacheKey),
        }

        // Remember the params of the type being left. Prefer the local edits
        // (they may not have been pushed yet because of the debounce) unless
        // they are invalid; then keep the last committed params instead.
        if (hasValidationErrors) {
            nextCache[dataset.vizType] = dataset.vizParams
        } else {
            nextCache[localVizType] = localVizParams
        }

        // Restore remembered params for the new type, or fall back to defaults.
        const newVizParams =
            nextCache[newVizType] ??
            getDefaultRasterVizParams(dataset, newVizType)
        nextCache[newVizType] = newVizParams

        CacheManager.setItem(vizParamsCacheKey, JSON.stringify(nextCache))

        setLocalVizParams(newVizParams)
        setLocalVizType(newVizType)
        updateVizParams(dataset, newVizParams, newVizType)
    }

    const updateDatasetOpacity = (datasetId: Dataset['id'], opacity: number) => {
        dispatch({
            type: 'UPDATE_DATASET_OPACITY',
            datasetVersionId: datasetId,
            opacity: opacity,
        })
    }

    let vizTypeText = ''
    let vizElement: React.ReactNode = null

    switch (localVizType) {
        case 'vector':
            vizTypeText = 'Vector'
            vizElement = (
                <VectorVizParamsEditor
                    dataset={dataset}
                    setHasValidationErrors={setHasValidationErrors}
                />
            )
            break
        case 'continuous_singleband_raster':
            vizTypeText = 'Raster'
            vizElement = (
                <SingleBandVizParamsEditor
                    dataset={dataset}
                    localVizParams={localVizParams as PseudocolorRasterVizParams}
                    setLocalVizParams={setLocalVizParams}
                    setHasValidationErrors={setHasValidationErrors}
                />
            )
            break
        case 'continuous_multiband_raster':
            vizTypeText = 'Raster'
            vizElement = (
                <MultibandVizParamsEditor
                    dataset={dataset}
                    localVizParams={localVizParams as MultibandRasterVizParams}
                    setLocalVizParams={setLocalVizParams}
                    setHasValidationErrors={setHasValidationErrors}
                />
            )
            break
        case 'categorical_raster':
            vizTypeText = 'Raster (categorical)'
            vizElement = (
                <CategoricalVizParamsEditor
                    dataset={dataset}
                    setHasValidationErrors={setHasValidationErrors}
                />
            )
            break
        case 'binned_raster':
            vizTypeText = 'Raster (binned)'
            vizElement = (
                <BinnedVizParamsEditor
                    dataset={dataset}
                    localVizParams={localVizParams as BinnedRasterVizParams}
                    setLocalVizParams={setLocalVizParams}
                    setHasValidationErrors={setHasValidationErrors}
                />
            )
            break
    }

    // Derived from the local type so the selector and the rendered editor agree.
    const isRaster = RASTER_VIZ_TYPES.includes(localVizType)
    // Per-dataset id so aria-labelledby resolves to this editor's own label.
    const opacityLabelId = `opacity-slider-${dataset.id}`

    return (
        <Box
            sx={{
                display: 'flex',
                alignItems: 'center',
                width: '90%',
            }}
        >
            {!hideControls && (
                <Box sx={{ alignSelf: 'flex-start' }}>
                    <Tooltip title="Go back (Esc)" placement="bottom">
                        <IconButton
                            onClick={() => {
                                dispatch({
                                    type: 'TOGGLE_DATASET_EDITING',
                                    datasetVersionId: dataset.id,
                                })
                            }}
                            aria-label="back"
                        >
                            <ArrowBack
                                fontSize="inherit"
                                className={VizParamsEditorStyle.icon}
                            />
                        </IconButton>
                    </Tooltip>
                </Box>
            )}
            <Box width={'100%'}>
                {/* Header line */}
                <Box
                    sx={{
                        display: 'flex',
                        justifyContent: 'space-between',
                        alignItems: 'center',
                        width: '100%',
                        marginBottom: '10px',
                    }}
                >
                    <Typography
                        variant="h6"
                        textAlign="center"
                        paddingTop={'2px'}
                    >
                        Editing Visualization
                    </Typography>
                    {!hideControls && (
                        <Box
                            sx={{
                                display: 'flex',
                                alignItems: 'center',
                            }}
                        >
                            <ZoomButton
                                dataset={dataset}
                                flyToDatasetBounds={flyToDatasetBounds}
                            />
                            <Tooltip title="Copy visualization style">
                                <IconButton
                                    onClick={() => setCopyDialogOpen(true)}
                                >
                                    <ContentCopy
                                        className={VizParamsEditorStyle.icon}
                                    />
                                </IconButton>
                            </Tooltip>
                            <VisibilityButton dataset={dataset} />
                        </Box>
                    )}
                </Box>
                {/* Dataset description */}
                <Box
                    sx={{
                        width: '100%',
                        display: 'flex',
                        flexDirection: 'column',
                    }}
                >
                    <Typography variant="body1">{dataset.name}</Typography>
                    <Typography variant="body2">{vizTypeText}</Typography>
                </Box>
                <Divider sx={{ marginBottom: '2px', marginTop: '2px' }} />
                <Stack
                    direction="row"
                    alignItems="center"
                    spacing={1}
                    sx={{ marginRight: '25px' }}
                >
                    <Typography variant="body2" id={opacityLabelId}>
                        Opacity
                    </Typography>
                    <Slider
                        value={dataset.opacity}
                        onChange={(_event, newValue) => {
                            updateDatasetOpacity(dataset.id, newValue as number)
                        }}
                        aria-labelledby={opacityLabelId}
                        valueLabelDisplay="auto"
                        size={'small'}
                        step={1}
                        min={0}
                        max={100}
                    />
                </Stack>
                {isRaster && (
                    <TransparencyFillLayerVizParamsEditor
                        localVizParams={localVizParams}
                        setLocalVizParams={setLocalVizParams}
                        setHasValidationErrors={setHasValidationErrors}
                    />
                )}
                {/* Actual viz param settings */}
                {isRaster && (
                    <FormControl fullWidth sx={{ mt: 2, mb: 2 }}>
                        <InputLabel id="viz-type-label">
                            Visualization Type
                        </InputLabel>
                        <Select
                            labelId="viz-type-label"
                            value={localVizType}
                            label="Visualization Type"
                            sx={{
                                '& fieldset': {
                                    borderColor: 'primary.main',
                                },
                            }}
                            onChange={handleVizTypeChange}
                        >
                            <MenuItem
                                value="continuous_singleband_raster"
                                sx={{ fontSize: '0.85em' }}
                            >
                                Single Band
                            </MenuItem>
                            <MenuItem
                                value="continuous_multiband_raster"
                                sx={{ fontSize: '0.85em' }}
                            >
                                Multiband
                            </MenuItem>
                            <MenuItem
                                value="binned_raster"
                                sx={{ fontSize: '0.85em' }}
                            >
                                Binned
                            </MenuItem>
                        </Select>
                    </FormControl>
                )}
                {vizElement}
            </Box>
            <CopyVizParamsDialog
                open={copyDialogOpen}
                onClose={() => setCopyDialogOpen(false)}
                sourceDataset={dataset}
            />
        </Box>
    )
}

// Default export: the top-level visualization-parameter editor component.
export default VizParamsEditor
