import { create } from 'zustand';
import { BallTree, BallTreeNode } from '~/models';
import { useViewer } from '~/state/useViewer.tsx';
import { BoundingSphere, Cartesian3, Color, HeadingPitchRange, Math as CesiumMath } from 'cesium';
import { useWebSocket } from '~/state/useWebSocket.tsx';
import { usePlaces } from '~/state/usePlaces.tsx';
import { BallTreeResponse, IResponse } from '@types';

/** Data held by the ball-tree viewer store. */
interface BallTreeViewerState {
    /** Node currently selected by the user, if any. */
    selectedNode?: BallTreeNode;
    /** The most recently loaded ball tree, if any. */
    ballTree?: BallTree;
}

/** Operations exposed by the ball-tree viewer store. */
interface BallTreeViewerActions {
    /** Replaces the stored ball tree. */
    loadBallTree: (ballTree: BallTree) => void;
    /** Marks `node` as the current selection. */
    selectNode: (node: BallTreeNode) => void;
    /** Clears the current selection. */
    clearSelection: () => void;
    /** Adds entities for the leaves of the stored tree to the Cesium viewer. */
    drawBallTree: () => void;
    /**
     * Moves the camera so the stored tree's leaves are in view.
     * @param callback Optional hook invoked by the implementation as part of the fit.
     */
    fitBallTreeInView: (callback?: () => void) => void;
    /**
     * Returns false when the tree's points are spread too far apart to frame sensibly;
     * true when there is no tree.
     */
    checkTreeFitInView: () => boolean;
    /**
     * Requests a ball tree for the current places over the WebSocket and loads it.
     * @param callback Optional; invoked after a ball tree response has been loaded.
     *   Optional to match the implementation (which calls `callback?.()`) and
     *   `fitBallTreeInView`; existing callers that pass a callback are unaffected.
     */
    makeBallTreeRequest: (callback?: () => void) => void;
}

/**
 * Maximum distance allowed between the first point of a tree and any other point
 * for `checkTreeFitInView` to report that the tree fits in view.
 * NOTE(review): assumes `cartesian3` values are Cesium world coordinates (metres),
 * making this roughly 10,000 km — confirm.
 */
const MAX_TREE_EXTENT = 1e7;

/**
 * Returns the leaf nodes under `root` in depth-first, left-to-right order.
 *
 * Descends into whichever children are present, so an internal node with only one
 * child is still traversed. Uses an explicit stack instead of recursion so very
 * deep trees cannot overflow the call stack.
 */
const collectLeaves = (root: BallTreeNode): BallTreeNode[] => {
    const leaves: BallTreeNode[] = [];
    const stack: BallTreeNode[] = [root];
    for (let node = stack.pop(); node !== undefined; node = stack.pop()) {
        if (node.isLeaf) {
            leaves.push(node);
            continue;
        }
        // Push right before left so the left subtree is visited first.
        if (node.right) stack.push(node.right);
        if (node.left) stack.push(node.left);
    }
    return leaves;
};

export const useBallTreeViewer = create<BallTreeViewerState & BallTreeViewerActions>((set, get) => ({
    ballTree: undefined,
    selectedNode: undefined,

    loadBallTree: (ballTree) => set({ ballTree }),
    selectNode: (node) => set({ selectedNode: node }),
    clearSelection: () => set({ selectedNode: undefined }),

    /**
     * Adds one entity per leaf of the loaded tree to the viewer. Each entity is a
     * translucent red ellipse of the leaf's radius plus a point marker at its
     * centroid. Internal nodes are not drawn. Does nothing without a viewer or tree.
     *
     * NOTE(review): entities from earlier calls are not removed here, so drawing
     * twice adds duplicates — confirm callers clear the viewer first.
     */
    drawBallTree: () => {
        const viewer = useViewer.getState().viewer;
        const tree = get().ballTree;
        if (!viewer || !tree?.root) return;

        for (const leaf of collectLeaves(tree.root)) {
            viewer.entities.add({
                position: leaf.centroid.cartesian3,
                ellipse: {
                    semiMinorAxis: leaf.radius,
                    semiMajorAxis: leaf.radius,
                    material: Color.RED.withAlpha(0.5),
                    outline: true,
                    outlineColor: Color.BLACK,
                    outlineWidth: 2,
                },
                point: {
                    pixelSize: 10,
                    color: Color.DARKBLUE,
                    outlineColor: Color.AQUAMARINE,
                    outlineWidth: 2,
                },
            });
        }
    },

    /**
     * Flies the camera (over 2 s, pitched straight down) to frame the centroids of
     * every leaf in the loaded tree, at a range of four times their bounding radius.
     *
     * @param callback Invoked once the camera flight ends, whether it completes or
     *   is cancelled. Not invoked when there is no viewer, no tree, or no leaves.
     */
    fitBallTreeInView: (callback) => {
        const viewer = useViewer.getState().viewer;
        const tree = get().ballTree;
        if (!viewer || !tree?.root) return;

        const positions = collectLeaves(tree.root).map((leaf) => leaf.centroid.cartesian3);
        // BoundingSphere.fromPoints never returns undefined: for an empty list it yields
        // a zero-radius sphere at the Earth's centre, so an explicit guard is required.
        if (positions.length === 0) return;

        const boundingSphere = BoundingSphere.fromPoints(positions);
        viewer.camera.flyToBoundingSphere(boundingSphere, {
            duration: 2.0, // seconds
            offset: new HeadingPitchRange(0, CesiumMath.toRadians(-90), boundingSphere.radius * 4),
            // The flight is asynchronous; run the callback when it actually ends.
            complete: callback,
            cancel: callback,
        });
    },

    /**
     * Returns whether every point stored in the tree's leaves lies within
     * `MAX_TREE_EXTENT` of the first such point. Trivially true when there is no
     * tree or fewer than two points.
     */
    checkTreeFitInView: () => {
        const tree = get().ballTree;
        if (!tree?.root) return true;

        const positions: Cartesian3[] = [];
        for (const leaf of collectLeaves(tree.root)) {
            for (const point of leaf.points ?? []) positions.push(point.cartesian3);
        }

        if (positions.length <= 1) return true; // Single point or no points trivially fit

        const reference = positions[0];
        return !positions.some((position) => Cartesian3.distance(reference, position) > MAX_TREE_EXTENT);
    },

    /**
     * Sends the current places over the WebSocket as a `ballTree` request. When a
     * response carrying a ball tree arrives, loads it and then invokes `callback`.
     * Other responses are logged as warnings and otherwise ignored.
     */
    makeBallTreeRequest: (callback) => {
        const { sendData } = useWebSocket.getState();
        const { places } = usePlaces.getState();
        sendData({ type: 'ballTree', places }, (response: IResponse) => {
            if (!isBallTreeResponse(response)) {
                console.warn('Unexpected response to ballTree request:', response);
                return;
            }
            get().loadBallTree(response.ballTree);
            callback?.();
        });
    },
}));

/**
 * Type guard: true when `response` carries a ball tree payload.
 *
 * Checks `!= null` rather than `!== undefined`. Serialized payloads (presumably
 * what the WebSocket delivers — confirm) express absence as `null`, and a `null`
 * tree must not be accepted as a ball tree response.
 */
const isBallTreeResponse = (response: IResponse): response is BallTreeResponse => {
    return response.ballTree != null;
};
