import { Edge, Node } from 'reactflow';
import * as d3 from 'd3-hierarchy';

/**
 * Positions React Flow nodes as a horizontal (left-to-right) tree.
 *
 * The hierarchy is derived from `edges`: a node's parent is the source of the
 * first edge that targets it, and a node with no incoming edge is treated as
 * the root. The drawing area is sized from the subtree that starts at the
 * node of type `'trigger'` (falling back to the hierarchy root when no such
 * node exists): 150 units per leaf along the vertical axis and 500 units per
 * level along the horizontal axis.
 *
 * @param nodes - Nodes to position. Neither the array nor its nodes are mutated.
 * @param edges - Edges describing the parent/child links between nodes.
 * @returns Copies of `nodes`, in the same order, with `position` replaced by
 *   the computed layout coordinates. An empty input yields an empty array.
 * @throws Propagates the error raised by `d3.stratify` when the nodes and
 *   edges cannot be arranged into a single-rooted tree.
 */
export function autoLayout(nodes: Node[], edges: Edge[]): Node[] {
  // d3.stratify cannot build a hierarchy without a root, so an empty graph
  // is answered directly instead of letting it throw.
  if (nodes.length === 0) {
    return [];
  }

  // Index target -> source once. Only the first edge per target is kept,
  // which matches what a per-node `edges.find` lookup would return.
  const parentByTarget = new Map<string, string>();
  for (const edge of edges) {
    if (!parentByTarget.has(edge.target)) {
      parentByTarget.set(edge.target, edge.source);
    }
  }

  const root = d3
    .stratify<Node>()
    .id((d) => d.id)
    .parentId((d) => parentByTarget.get(d.id))(nodes);

  // Measure the tree from the trigger node; without one, measure from the
  // hierarchy root so the layout never gets a zero/negative extent.
  const startingNode =
    nodes.find((node) => node.type === 'trigger') ?? root.data;

  // Number of levels on the longest path (a lone node counts as 1).
  const width = maxHeight(startingNode, edges, nodes);

  const leafsCount = countLeafs(startingNode, edges, nodes);

  // Vertical room reserved for each leaf and horizontal gap between levels.
  const nodeHeight = 150;
  const levelSpacing = 500;

  // Calculate the size of the tree for horizontal layout
  const height = leafsCount * nodeHeight;

  // Typing the layout with the datum type lets it accept `root` directly.
  const treeLayout = d3
    .tree<Node>()
    .size([height, (width - 1) * levelSpacing]);

  // Apply the tree layout to the root
  const layout = treeLayout(root);

  // Collect every computed position once, keyed by node id. The axes are
  // swapped (d3's x becomes React Flow's y) to turn the top-down tree into a
  // left-to-right one. The first occurrence of an id wins.
  const positionById = new Map<string, { x: number; y: number }>();
  for (const placed of layout.descendants()) {
    if (placed.id !== undefined && !positionById.has(placed.id)) {
      positionById.set(placed.id, { x: placed.y, y: placed.x });
    }
  }

  // Convert the layout to the format expected by React Flow
  return nodes.map((node) => {
    const position = positionById.get(node.id);
    // A node the layout did not place keeps its current position.
    return position ? { ...node, position: { ...position } } : { ...node };
  });
}

/**
 * Counts the leaf nodes reachable from `startNode` by following edges from
 * source to target, added on top of an initial `count`.
 *
 * A leaf is a node with no outgoing edge. An edge whose target is not present
 * in `nodes` contributes nothing to the total.
 *
 * @param startNode - Node to start from; a missing node leaves `count` as is.
 * @param edges - All edges of the graph.
 * @param nodes - All nodes of the graph, used to resolve edge targets.
 * @param count - Running total carried through the recursion.
 * @returns `count` plus the number of leaves found below `startNode`.
 */
function countLeafs(
  startNode: Node | null,
  edges: Edge[],
  nodes: Node[],
  count = 0
): number {
  if (startNode) {
    const outgoing = edges.filter(({ source }) => source === startNode.id);

    // No outgoing edge: this node is itself a leaf.
    if (outgoing.length === 0) {
      return count + 1;
    }

    // Thread the running total through each child subtree in edge order.
    return outgoing.reduce((running, { target }) => {
      const child = nodes.find(({ id }) => id === target) ?? null;
      return countLeafs(child, edges, nodes, running);
    }, count);
  }

  return count;
}

/**
 * Computes the number of nodes on the longest source-to-target path that
 * starts at `startNode`.
 *
 * @param startNode - Node to start from; a missing node has height 0.
 * @param edges - All edges of the graph.
 * @param nodes - All nodes of the graph, used to resolve edge targets.
 * @returns 0 for a missing node, 1 for a node without outgoing edges,
 *   otherwise 1 plus the greatest height among its children.
 */
function maxHeight(
  startNode: Node | null,
  edges: Edge[],
  nodes: Node[]
): number {
  if (!startNode) {
    return 0;
  }

  const outgoing = edges.filter(({ source }) => source === startNode.id);
  if (outgoing.length === 0) {
    // A node with no children is a path of exactly one node.
    return 1;
  }

  // Height of each child subtree; unresolved targets count as height 0.
  const childHeights = outgoing.map(({ target }) => {
    const child = nodes.find(({ id }) => id === target) ?? null;
    return maxHeight(child, edges, nodes);
  });

  // The current node sits on top of its tallest child.
  return 1 + Math.max(...childHeights);
}
