import { useEffect, useMemo, useRef, useState } from "react";
import * as d3 from "d3";
import React from "react";

/**
 * Pixels reserved on each side of the <svg> around the plotting area.
 * The axes and their tick labels are drawn into this space.
 */
const MARGIN = {
  top: 30,
  right: 30,
  bottom: 50,
  left: 50,
} as const;

/**
 * One input row: `x` labels the bar; every other key holds a value for one
 * stacked segment.
 *
 * The previous form `{ x: string } & { [key: string]: number }` made `x`
 * resolve to `string & number` (i.e. `never`), so no row could satisfy it.
 * A single object type whose index signature also admits `string` keeps
 * `x: string` valid.
 */
type Group = {
  x: string;
  [key: string]: string | number;
};

/**
 * Props accepted by the StackedBarplot component.
 *
 * NOTE(review): every field except `width` is typed `any`; tightening these
 * would catch caller mistakes but must be checked against existing call sites.
 */
type StackedBarplotProps = {
  /** Total <svg> width in pixels; the plotting area is this minus the left/right margins. */
  width: number;
  /** Total <svg> height in pixels; the plotting area is this minus the top/bottom margins. */
  height: any;
  /** Rows to plot. Each row's `x` labels its bar; the keys named in `columns` are stacked. */
  data: any;
  /** Keys of each `data` row that are stacked into bar segments. */
  columns: any;
  /** Passed as the `range` of the ordinal color scale that fills the bar segments. */
  colors: any;
  /** Invoked with `{ x, subgroup, value }` when a bar segment is clicked. */
  onclick: any;
};

/**
 * Sums the numeric values stored in the own enumerable properties of `obj`.
 *
 * Non-number values (e.g. a string `x` label) are skipped. Inherited
 * properties are ignored.
 *
 * @param obj Object whose own properties are summed; `null`/`undefined` yield 0.
 * @returns The sum of all own numeric property values (0 if there are none).
 */
function sumObjectKeys(obj: any): number {
  if (obj == null) {
    return 0;
  }
  let sum = 0;
  // Object.keys yields exactly the own enumerable string keys, the same set
  // `for..in` + `hasOwnProperty` produced. It never calls `obj.hasOwnProperty`,
  // which is missing on Object.create(null) objects and can be shadowed by a
  // data key literally named "hasOwnProperty".
  for (const key of Object.keys(obj)) {
    const value = obj[key];
    if (typeof value === "number") {
      sum += value;
    }
  }
  return sum;
}

/** Payload describing one bar segment; shown in the tooltip and passed to `onclick`. */
type BarDatum = {
  x: string | number;
  subgroup: string;
  value: number;
};

/** Visible tooltip: the hovered segment plus its position inside the wrapper div. */
type TooltipState = {
  data: BarDatum;
  x: number;
  y: number;
};

/**
 * Stacked bar chart: one bar per `data` row (labelled by `row.x`), one stacked
 * segment per key in `columns`.
 *
 * Bars are rendered by React; the axes are drawn imperatively by d3 into a
 * dedicated <g> inside an effect. Hovering a segment shows a tooltip; clicking
 * it calls `onclick({ x, subgroup, value })`.
 */
export const StackedBarplot = ({
  width,
  height,
  data,
  columns,
  colors,
  onclick,
}: StackedBarplotProps) => {
  // `null` means "no tooltip shown"; this avoids a separate display flag.
  const [tooltip, setTooltip] = useState<TooltipState | null>(null);

  const svgRef = useRef<SVGSVGElement | null>(null);
  const axesRef = useRef<SVGGElement | null>(null);

  // bounds = area inside the graph axes = total size minus the margins
  const boundsWidth = width - MARGIN.right - MARGIN.left;
  const boundsHeight = height - MARGIN.top - MARGIN.bottom;

  // One band per row, keyed by the stringified `x` value.
  const allGroups: string[] = useMemo(
    () => data.map((d: any) => String(d.x)),
    [data]
  );

  // Stack the `columns` of each row; d3.stackOrderAscending puts the series
  // with the smallest total at the bottom.
  const series = useMemo(
    () => d3.stack().keys(columns).order(d3.stackOrderAscending)(data),
    [data, columns]
  );

  // Highest stacked value. Reading it off the stacked series counts only the
  // `columns` keys (a numeric `x` is not added in), and empty data gives 0
  // instead of -Infinity.
  const yMax = useMemo(
    () => d3.max(series, (layer) => d3.max(layer, (point) => point[1])) ?? 0,
    [series]
  );

  // Y axis
  const yScale = useMemo(
    () => d3.scaleLinear().domain([0, yMax]).range([boundsHeight, 0]),
    [yMax, boundsHeight]
  );

  // X axis
  const xScale = useMemo(
    () =>
      d3
        .scaleBand<string>()
        .domain(allGroups)
        .range([0, boundsWidth])
        .padding(0.05),
    [allGroups, boundsWidth]
  );

  // Color scale keyed by subgroup (column). The scale is looked up with
  // `layer.key`, so its domain must be the columns. With the x values as the
  // domain, every column key would be appended implicitly and the colors
  // shifted.
  const colorScale = useMemo(
    () => d3.scaleOrdinal<string>().domain(columns).range(colors),
    [columns, colors]
  );

  // Render the X and Y axes using d3.js, not React.
  useEffect(() => {
    if (!axesRef.current) {
      return;
    }
    const axes = d3.select(axesRef.current);
    axes.selectAll("*").remove();

    axes
      .append("g")
      .attr("transform", `translate(0, ${boundsHeight})`)
      .call(d3.axisBottom(xScale))
      .selectAll("text")
      .attr("transform", "rotate(-45)")
      .style("text-anchor", "end")
      .attr("dx", "-0.8em")
      .attr("dy", "0.15em");

    axes.append("g").call(d3.axisLeft(yScale));
  }, [xScale, yScale, boundsHeight]);

  const showTooltip = (
    event: React.MouseEvent<SVGRectElement>,
    datum: BarDatum
  ) => {
    // Measure relative to the <svg>, which sits at the top-left of the
    // relatively positioned wrapper div, so the result can be used directly as
    // the tooltip's absolute left/top.
    const [x, y] = d3.pointer(event.nativeEvent, svgRef.current);
    setTooltip({ data: datum, x, y });
  };

  const hideTooltip = () => {
    setTooltip(null);
  };

  const rectangles = series.map((layer, i) => (
    <g key={i}>
      {layer.map((point, j) => {
        const datum: BarDatum = {
          x: point.data.x,
          subgroup: layer.key,
          value: point.data[layer.key],
        };
        const top = yScale(point[1]);
        const barHeight = yScale(point[0]) - top;
        return (
          <rect
            key={j}
            x={xScale(String(point.data.x))}
            y={top}
            height={isNaN(barHeight) ? 0 : barHeight}
            width={xScale.bandwidth()}
            fill={colorScale(layer.key)}
            opacity={0.9}
            onMouseOver={(event) => showTooltip(event, datum)}
            onMouseOut={hideTooltip}
            onClick={() => onclick(datum)}
          />
        );
      })}
    </g>
  ));

  return (
    <div style={{ position: "relative" }}>
      <svg ref={svgRef} width={width} height={height}>
        <g transform={`translate(${MARGIN.left},${MARGIN.top})`}>
          {rectangles}
        </g>
        <g
          ref={axesRef}
          transform={`translate(${MARGIN.left},${MARGIN.top})`}
        />
      </svg>
      {tooltip && (
        <div
          style={{
            // Without absolute positioning, left/top have no effect.
            position: "absolute",
            left: tooltip.x + "px",
            top: tooltip.y + "px",
            // Keep the tooltip from stealing the hover and triggering mouseout.
            pointerEvents: "none",
            width: "300px",
            backgroundColor: "white",
            padding: "5px",
            border: "1px solid gray",
            borderRadius: "3px",
          }}
        >
          <p>
            {tooltip.data.subgroup}
            <br />
            {tooltip.data.x}
            <br />
            {tooltip.data.value}
          </p>
        </div>
      )}
    </div>
  );
};
