import { HierarchyPointNode, stratify, tree } from 'd3-hierarchy';
import { timer } from 'd3-timer';
import { useEffect, useRef } from 'react';
import {
  Edge,
  Node,
  ReactFlowState,
  useReactFlow,
  useStore,
  useUpdateNodeInternals,
} from 'reactflow';
import { StepsConfig } from 'pages/app-pages/chatbot-builder-page/nodes/config';
import { StepCategory } from 'pages/app-pages/chatbot-builder-page/nodes/types';
import {
  LayoutDirection,
  LayoutDirectionStore,
  useBuilderStore,
  useLayoutDirection,
} from 'pages/app-pages/chatbot-builder-page/store';
import { BaseNodeProps, Mention } from 'types';
import { defaultMentions } from 'utils';

// The d3 tree layout used by this module is created by the `layout` factory further down
// (see https://observablehq.com/@d3/tree for examples).

/**
 * Collects the distinct mention/variable ids referenced anywhere in `nodes`.
 *
 * Three sources are inspected for every node, in this order:
 *  1. `node.data.mentions` – the ids of the mentions attached to the node,
 *  2. `node.data.formData.variable` – the variable stored in the node's form data,
 *  3. `StepsConfig[node.data.selectedStep].mentions` – the ids declared for the selected step.
 *
 * @param nodes - the nodes to scan
 * @returns the collected ids without duplicates, ordered by first occurrence
 */
function getAllMentionIdsFromNodes(nodes: Node<BaseNodeProps>[]): string[] {
  // a Set iterates in insertion order, so de-duplicating here keeps the first-seen order
  const mentionIds = new Set<string>();

  nodes.forEach((node) => {
    // mentions already attached to the node
    node.data?.mentions?.forEach((mention) => mentionIds.add(mention.id));

    // variable stored in the node's form data
    if (node.data?.formData?.variable) {
      mentionIds.add(node.data.formData.variable);
    }

    // mentions declared by the configuration of the selected step
    if (node.data?.selectedStep) {
      // a step key without a config entry contributes nothing instead of throwing
      const stepMentions = StepsConfig[node.data.selectedStep]?.mentions ?? [];
      stepMentions.forEach((mention) => mentionIds.add(mention.id));
    }
  });

  // Every id above was read from one of the nodes, so checking again that "some node
  // references it" can never remove anything; the ids are returned as collected.
  return [...mentionIds];
}

/**
 * Creates a d3 tree layout configured for the requested direction.
 *
 * @param direction - the direction the tree is going to be rendered in
 * @returns a tree layout with a fixed node size and a constant separation of 1
 */
const layout = (direction: LayoutDirection) => {
  // spacing between the nodes as [width, height]
  let spacing: [number, number];
  if (direction === LayoutDirection.Horizontal) {
    spacing = [200, 400];
  } else {
    spacing = [300, 250];
  }

  // returning the same separation for every pair of nodes spaces all of them equally
  const equalSeparation = () => 1;

  return tree<Node<BaseNodeProps>>().nodeSize(spacing).separation(equalSeparation);
};

// animation settings: `duration` is compared against the elapsed time reported by the d3 timer
// (milliseconds) to decide how far the nodes have moved and when the animation ends
const options = { duration: 300 };

/**
 * The layouting function: arranges the nodes as a tree and returns new node objects that carry
 * their updated position, step number, label and the mentions inherited from their ancestors.
 *
 * The parent of a node is the `source` of the first edge whose `target` is that node. The input
 * nodes and edges are not modified.
 *
 * @param nodes - current nodes
 * @param edges - current edges, used to derive the parent of every node
 * @param direction - layout direction; horizontal layouts swap the computed x/y coordinates
 * @returns the layouted nodes in depth-first pre-order, or `[]` when there are no nodes
 */
function layoutNodes(
  nodes: Node<BaseNodeProps>[],
  edges: Edge[],
  direction: LayoutDirection = LayoutDirection.Horizontal,
): Node<BaseNodeProps>[] {
  if (nodes.length === 0) {
    return [];
  }

  // target id -> source id; only the first edge per target is kept, which matches a
  // first-match search through `edges` without rescanning the list for every node
  const parentIdByTarget = new Map<string, string>();
  edges.forEach((edge) => {
    if (!parentIdByTarget.has(edge.target)) {
      parentIdByTarget.set(edge.target, edge.source);
    }
  });

  const hierarchy = stratify<Node<BaseNodeProps>>()
    .id((d) => d.id)
    .parentId((d) => parentIdByTarget.get(d.id))(nodes);

  const root = layout(direction)(hierarchy);

  // the ids referenced by the node list do not change during the traversal, so they are
  // collected once here instead of once per visited child
  const knownMentionIds = new Set(getAllMentionIdsFromNodes(nodes));

  // filled in visiting order (a node first, then each of its subtrees)
  const layoutedNodes: Node<BaseNodeProps>[] = [];

  const visit = (
    node: HierarchyPointNode<Node<BaseNodeProps>>,
    inheritedMentions: Mention[],
    parentCategory: StepCategory | null, // category of the parent's selected step, if any
  ) => {
    const flowNode = node.data;
    const { selectedStep } = flowNode.data;
    // a step key without a config entry is treated like a node without a selected step
    const stepConfig = selectedStep ? StepsConfig[selectedStep] : undefined;

    // mentions this node makes available to its descendants:
    // the ones declared for its step plus its own form variable
    const ownMentions: Mention[] = [...(stepConfig?.mentions ?? [])];
    const variable = flowNode.data.formData?.variable;
    if (variable) {
      ownMentions.push({ id: variable, value: variable });
    }

    const category = stepConfig?.category ?? null;

    // a multi-node input that directly follows another multi-node input gets a back button
    const needsBackButton =
      category === StepCategory.MultiNodeInput && parentCategory === StepCategory.MultiNodeInput;

    // The flag is added on a copy so the node objects passed in by the caller stay untouched.
    // NOTE(review): an existing `backButton` in the node data is carried over by the spread and
    // is never cleared here - confirm whether it should be reset once the condition stops holding.
    const nodeData = needsBackButton ? { ...flowNode.data, backButton: true } : flowNode.data;

    const layoutedNode = {
      ...flowNode,
      data: {
        ...nodeData,
        step: node.depth,
        label: node.depth === 0 ? flowNode.data.label : `Step ${node.depth}`,
        // copied so that sibling nodes never share one array instance
        mentions: [...inheritedMentions],
      },
      // horizontal layouts swap the axes of the computed tree coordinates
      position:
        direction === LayoutDirection.Horizontal
          ? { x: node.y, y: node.x }
          : { x: node.x, y: node.y },
    };
    layoutedNodes.push(layoutedNode);

    if (node.children) {
      // all children receive the same list: the inherited mentions followed by this node's own
      // ones, keeping only the first occurrence of every id
      const seenIds = new Set<string>();
      const mentionsForChildren = [
        ...inheritedMentions,
        ...ownMentions.filter((mention) => knownMentionIds.has(mention.id)),
      ].filter((mention) => {
        if (seenIds.has(mention.id)) {
          return false;
        }
        seenIds.add(mention.id);
        return true;
      });

      node.children.forEach((child) => visit(child, mentionsForChildren, category));
    }
  };

  visit(root, defaultMentions, null);

  return layoutedNodes;
}

// store selector used for triggering the layout: it returns the number of entries in
// `nodeInternals`, so the selected value changes when the number of nodes changes
const nodeCountSelector = ({ nodeInternals }: ReactFlowState) => nodeInternals.size;
// store selector that returns the currently selected layout direction
const layoutDirectionSelector = ({ layoutDirection }: LayoutDirectionStore) => layoutDirection;

/**
 * React hook that re-layouts the flow whenever the number of nodes, the builder's
 * `actionCount` or the layout direction changes.
 *
 * On every run it
 *  - sets the type of each edge: 'default' when the edge targets a node whose `selectedStep`
 *    is `null`, 'workflow' otherwise,
 *  - computes the target positions with `layoutNodes`,
 *  - either jumps straight to the new layout (when the layout direction differs from the one
 *    used by the previous run) or animates the nodes to their new positions over
 *    `options.duration`.
 *
 * The hook returns nothing; it only writes nodes and edges back through `useReactFlow`.
 */
function useLayout() {
  // we are using nodeCount as the trigger for the re-layouting
  // whenever the nodes length changes, we calculate the new layout
  const nodeCount = useStore(nodeCountSelector);
  const layoutDirection = useLayoutDirection(layoutDirectionSelector);
  const updateNodeInternals = useUpdateNodeInternals();
  // select the primitive itself so the selected value only changes when actionCount changes
  const actionCount = useBuilderStore((state) => state.actionCount);

  // starts as Horizontal, so a first run with any other direction takes the non-animated branch
  const prevLayoutDirection = useRef(LayoutDirection.Horizontal);

  const { getNodes, getNode, setNodes, setEdges, getEdges, setCenter } = useReactFlow();

  useEffect(() => {
    // get the current nodes and edges
    const nodes = getNodes();
    const edges = getEdges();

    // ids of the empty nodes (no step selected yet)
    const emptyNodeIds = new Set(
      nodes.filter((node) => node.data.selectedStep === null).map((node) => node.id),
    );

    // edges that lead to an empty node get the type 'default', all other edges 'workflow'
    const updatedEdges = edges.map((edge) => ({
      ...edge,
      type: emptyNodeIds.has(edge.target) ? 'default' : 'workflow',
    }));

    // run the layout and get back the nodes with their updated positions
    const targetNodes = layoutNodes(nodes, updatedEdges, layoutDirection);

    // Apply the updated edge types once per layout run. This covers the direction-change
    // branch below as well and avoids re-writing the same edge snapshot on every frame.
    setEdges(updatedEdges);

    // a changed layout direction is applied immediately, without animating the nodes
    if (layoutDirection !== prevLayoutDirection.current) {
      prevLayoutDirection.current = layoutDirection;
      updateNodeInternals(targetNodes.map((node) => node.id));

      const isHorizontal = layoutDirection === LayoutDirection.Horizontal;
      setCenter(isHorizontal ? 200 : 0, isHorizontal ? 0 : 125, { zoom: 1, duration: 800 });

      setNodes(targetNodes);
      // no timer was started, so there is nothing to clean up
      return undefined;
    }

    // to interpolate and animate the new positions, we create objects that contain the current and target position of each node
    const transitions = targetNodes.map((node) => ({
      id: node.id,
      // this is where the node currently is placed
      from: getNode(node.id)?.position ?? node.position,
      // this is where we want the node to be placed
      to: node.position,
      node,
    }));

    // builds the node that is written to the flow for one animation frame
    const nodeAt = (node: Node<BaseNodeProps>, position: { x: number; y: number }) => ({
      id: node.id,
      position,
      data: { ...node.data },
      type: node.type,
      selected: node.selected,
    });

    // create a timer to animate the nodes to their new positions
    // (the callback is synchronous: nothing is awaited, and errors are not hidden in a promise)
    const t = timer((elapsed: number) => {
      // this is the final step of the animation
      if (elapsed > options.duration) {
        // we are moving the nodes exactly to their destination
        // this needs to happen to avoid glitches
        setNodes(transitions.map(({ node, to }) => nodeAt(node, { x: to.x, y: to.y })));
        // stop the animation
        t.stop();
        return;
      }

      // progress of the animation, between 0 and 1 on this path
      const s = elapsed / options.duration;

      setNodes(
        transitions.map(({ node, from, to }) =>
          nodeAt(node, {
            // simple linear interpolation
            x: from.x + (to.x - from.x) * s,
            y: from.y + (to.y - from.y) * s,
          }),
        ),
      );
    });

    // stop a running animation when the effect re-runs or the component unmounts
    return () => {
      t.stop();
    };
  }, [
    actionCount,
    nodeCount,
    layoutDirection,
    getEdges,
    getNodes,
    getNode,
    setNodes,
    setEdges,
    setCenter,
    updateNodeInternals,
  ]);
}

export default useLayout;
