import React, { useState, useEffect, useRef } from 'react';
import { useAuth } from '../authcontext';
import moment from 'moment';

function CustomPivotTable({ selectedFields, filteredData, dateField, sumField, topNValues, setTableDataCallback }) {
  const [tableData, setTableData] = useState([]);
  const [visiblePeriods, setVisiblePeriods] = useState([]);
  const [expandedRows, setExpandedRows] = useState(new Set());
  const { currentModel } = useAuth();
  const [histPeriods, setHistPeriods] = useState([]);
  const [forePeriods, setForePeriods] = useState([]);

  useEffect(() => {
    // Fetch historical and forecast headers from currentModel based on basis
    if (!currentModel) return;

    let histHeaders = [];
    let foreHeaders = [];

    // Safely parse headers from JSON strings if necessary
    try {
      switch (currentModel.basis) {
        case 'Yearly':
          histHeaders = currentModel.annualHistHeaders ? JSON.parse(currentModel.annualHistHeaders) : [];
          foreHeaders = currentModel.annualForeHeaders ? JSON.parse(currentModel.annualForeHeaders) : [];
          break;
        case 'Quarterly':
          histHeaders = currentModel.quarterlyHistHeaders ? JSON.parse(currentModel.quarterlyHistHeaders) : [];
          foreHeaders = currentModel.quarterlyForeHeaders ? JSON.parse(currentModel.quarterlyForeHeaders) : [];
          break;
        case 'Monthly':
          histHeaders = currentModel.monthlyHistHeaders ? JSON.parse(currentModel.monthlyHistHeaders) : [];
          foreHeaders = currentModel.monthlyForeHeaders ? JSON.parse(currentModel.monthlyForeHeaders) : [];
          break;
        default:
          console.error('Invalid basis:', currentModel.basis);
          return;
      }
    } catch (error) {
      console.error('Error parsing headers from JSON:', error);
      return;
    }

    // Extract the last X historical periods and first Y forecast periods
    const extractedHistPeriods = histHeaders.slice(-currentModel.histPeriods).map(header => ({
      label: header[0], // The label, e.g., "2015" or "Q1 2015"
      fromDate: header[1], // The from date, e.g., "20/09/2014"
      toDate: header[2], // The to date, e.g., "19/09/2015"
    }));
    const extractedForePeriods = foreHeaders.slice(0, currentModel.forePeriods).map(header => ({
      label: header[0], // The forecast label, e.g., "2025"
    }));

    setHistPeriods(extractedHistPeriods);
    setForePeriods(extractedForePeriods);
  }, [currentModel]);

  useEffect(() => {
    if (!filteredData || !filteredData.rows || !filteredData.headers || !dateField || !sumField || !histPeriods.length) {
      setTableData([]);
      setVisiblePeriods([]);
      return;
    }

    const { headers, rows } = filteredData;
    const { aggregatedData, uniquePeriods } = aggregateData(headers, rows, selectedFields, dateField, sumField, histPeriods, forePeriods);

    const restrictedData = applyTopN(aggregatedData, selectedFields, topNValues);

    // Only update if the aggregatedData has actually changed
    if (JSON.stringify(aggregatedData) !== JSON.stringify(tableData)) {
      setTableData(aggregatedData);
      setVisiblePeriods(uniquePeriods);

      // Pass both restrictedData and visiblePeriods back to the parent
      if (setTableDataCallback) {
        setTableDataCallback({ restrictedData, visiblePeriods: uniquePeriods });
      }
    }

  }, [filteredData, selectedFields, dateField, sumField, topNValues, histPeriods, forePeriods, setTableDataCallback]);

  const toggleRowExpansion = (rowKey) => {
    setExpandedRows((prev) => {
      const newExpandedRows = new Set(prev);
      if (newExpandedRows.has(rowKey)) {
        newExpandedRows.delete(rowKey);
      } else {
        newExpandedRows.add(rowKey);
      }
      return newExpandedRows;
    });
  };

  const renderRows = (data, depth = 0) => {
    // Sort the keys based on the total sum in descending order, but ensure "Other" is always last
    const sortedKeys = Object.keys(data).sort((a, b) => {
      if (a === 'Other') return 1;
      if (b === 'Other') return -1;
  
      const sumA = Object.values(data[a].total).reduce((acc, period) => acc + period.sum, 0);
      const sumB = Object.values(data[b].total).reduce((acc, period) => acc + period.sum, 0);
      return sumB - sumA;
    });
  
    return sortedKeys.map((key) => {
      const row = data[key];
      const isBottomLevel = depth === selectedFields.length - 1;
      const isExpanded = expandedRows.has(key);
  
      return (
        <React.Fragment key={key}>
          <tr>
            <td style={{ paddingLeft: `${depth * 20}px` }} onClick={() => !isBottomLevel && toggleRowExpansion(key)}>
              {!isBottomLevel ? (isExpanded ? '▼' : '►') : ''} {key}
            </td>
            {visiblePeriods.map((period) => (
              <td key={period}>
                {row.total && row.total[period]
                  ? `${row.total[period].sum.toFixed(2)}`
                  : '0.00'}
              </td>
            ))}
          </tr>
          {!isBottomLevel && isExpanded && renderRows(row.children, depth + 1)}
        </React.Fragment>
      );
    });
  };

  return (
    <div className="custom-pivot-table">
      <table className="data-table">
        <thead>
          <tr>
            <th>Category</th>
            {visiblePeriods.map((period, index) => (
              <th key={index}>{period}</th>
            ))}
          </tr>
        </thead>
        <tbody>{renderRows(tableData)}</tbody>
      </table>
    </div>
  );
}

/**
 * Build the nested pivot tree from tabular rows.
 *
 * Each row's value in `sumField` (parsed with parseFloat; unparsable -> 0) is
 * added to the leaf node identified by its `selectedFields` values, for every
 * historical period whose inclusive [fromDate, toDate] range contains the row's
 * date. Row dates must match "DD/MM/YYYY" exactly (strict parse). Parent nodes get
 * zero-initialised period entries here and are summed by `propagateTotals`.
 * Every category seen in `rows` also receives a zero entry for each forecast
 * period, regardless of its date.
 *
 * A row is skipped entirely when any of its category values is falsy or equals
 * its column header name (this guards against a header row inside `rows`).
 *
 * @param {string[]} headers - Column names; their positions index into each row.
 * @param {Array<Array<*>>} rows - Data rows.
 * @param {string[]} selectedFields - Columns forming the row hierarchy, outermost first.
 * @param {string} dateField - Column holding the row date.
 * @param {string} sumField - Column holding the value to aggregate.
 * @param {{label: string, fromDate: string, toDate: string}[]} histPeriods
 * @param {{label: string}[]} forePeriods
 * @returns {{aggregatedData: Object, uniquePeriods: string[]}} The tree (category ->
 *   `{ total: { [label]: { sum } }, children }`, using null-prototype dictionaries
 *   for category levels) and the historical then forecast period labels, in order.
 */
const aggregateData = (headers, rows, selectedFields, dateField, sumField, histPeriods, forePeriods) => {
  const DATE_FORMAT = 'DD/MM/YYYY';
  const periodIndex = headers.indexOf(dateField);
  const sumFieldIndex = headers.indexOf(sumField);
  const categoryIndices = selectedFields.map((field) => headers.indexOf(field));
  const leafDepth = categoryIndices.length - 1;

  // Category levels are null-prototype dictionaries so data values such as
  // "constructor" or "__proto__" become ordinary keys instead of colliding with
  // Object.prototype members.
  const aggregatedData = Object.create(null);

  // The row's category values (outermost first), or null when any one of them is
  // falsy or equals its column header, meaning the whole row must be skipped.
  const categoryPathOf = (row) => {
    const path = [];
    for (const index of categoryIndices) {
      const category = row[index];
      if (!category || category === headers[index]) {
        return null;
      }
      path.push(category);
    }
    return path;
  };

  // Walk `path` from the root, creating missing nodes and a `{ sum: 0 }` entry for
  // `periodLabel` on each one; `amount` is added at the leaf only.
  const addToPath = (path, periodLabel, amount) => {
    let level = aggregatedData;
    path.forEach((category, depth) => {
      if (!level[category]) {
        level[category] = { total: {}, children: Object.create(null) };
      }
      const node = level[category];
      if (!node.total[periodLabel]) {
        node.total[periodLabel] = { sum: 0 };
      }
      if (depth === leafDepth) {
        node.total[periodLabel].sum += amount;
      }
      level = node.children;
    });
  };

  // Validate and parse each row once, rather than once per period.
  const candidates = [];
  rows.forEach((row) => {
    const path = categoryPathOf(row);
    if (path === null) return;
    candidates.push({
      path,
      date: moment(row[periodIndex], DATE_FORMAT, true),
      value: parseFloat(row[sumFieldIndex]) || 0,
    });
  });

  // Historical periods: a row counts toward every period whose inclusive range
  // contains its date.
  histPeriods.forEach(({ label, fromDate, toDate }) => {
    const fromMoment = moment(fromDate, DATE_FORMAT);
    const toMoment = moment(toDate, DATE_FORMAT);
    candidates.forEach(({ path, date, value }) => {
      if (date.isBetween(fromMoment, toMoment, null, '[]')) {
        addToPath(path, label, value);
      }
    });
  });

  // Forecast periods: zero entries for every valid row's categories.
  forePeriods.forEach(({ label }) => {
    candidates.forEach(({ path }) => addToPath(path, label, 0));
  });

  const sortedPeriods = [...histPeriods.map((p) => p.label), ...forePeriods.map((p) => p.label)];

  propagateTotals(aggregatedData, sortedPeriods);

  return { aggregatedData, uniquePeriods: sortedPeriods };
};

/**
 * Roll subtree sums up the pivot tree, bottom-up, mutating `data` in place.
 *
 * Each node with children gets its child level's per-period sums added to its
 * own `total`, creating `{ sum: 0 }` entries as needed. The return value is this
 * level's combined sums, restricted to the labels in `periods`.
 *
 * @param {Object} data - One tree level: category -> `{ total: { [label]: { sum } }, children }`.
 * @param {string[]} periods - Period labels to include in the returned level totals.
 * @returns {Object} `{ [label]: { sum } }` for every label in `periods` that at
 *   least one node at this level has an entry for.
 */
const propagateTotals = (data, periods) => {
  const levelTotals = {};

  for (const category in data) {
    const node = data[category];

    // Recurse first so this node's totals include everything beneath it.
    const hasChildren = Boolean(node.children) && Object.keys(node.children).length > 0;
    if (hasChildren) {
      const childTotals = propagateTotals(node.children, periods);
      for (const label in childTotals) {
        if (!node.total[label]) {
          node.total[label] = { sum: 0 };
        }
        node.total[label].sum += childTotals[label].sum;
      }
    }

    // Then fold the node's now-complete totals into this level's totals.
    if (!node.total) continue;
    for (const label of periods) {
      const entry = node.total[label];
      if (!entry) continue;
      if (!levelTotals[label]) {
        levelTotals[label] = { sum: 0 };
      }
      levelTotals[label].sum += entry.sum;
    }
  }

  return levelTotals;
};

/**
 * Restrict each level of the pivot tree to its N largest categories.
 *
 * `topNValues[depth]` is parsed as a base-10 integer and used as the limit for
 * tree level `depth` (0 = outermost). A missing, non-numeric or < 1 value keeps
 * every category at that level, in its existing order. Otherwise the N categories
 * with the largest summed `total` are kept, highest first; ties keep their
 * existing order. Recursion stops after `selectedFields.length` levels.
 *
 * The input is not modified: kept nodes are shallow copies whose `children` are
 * restricted recursively. Node totals are left as they are, so a parent's total
 * still includes any children that were dropped.
 *
 * @param {Object} data - One tree level: category -> `{ total: { [label]: { sum } }, children }`.
 * @param {string[]} selectedFields - Row hierarchy; its length bounds the depth.
 * @param {Array<string|number>} [topNValues] - Per-level limits.
 * @param {number} [depth=0] - Level of `data` within the tree.
 * @returns {Object} A new null-prototype dictionary for this level, or `data`
 *   itself once past the last level (or when `data` is falsy).
 */
const applyTopN = (data, selectedFields, topNValues, depth = 0) => {
  if (!data || depth >= selectedFields.length) return data;

  const limitValue = topNValues ? topNValues[depth] : undefined;
  const topN = Number.parseInt(limitValue, 10);

  let keys = Object.keys(data);
  if (!Number.isNaN(topN) && topN >= 1) {
    // Sum each node once rather than on every comparison.
    keys = keys
      .map((key) => ({
        key,
        total: Object.values(data[key].total).reduce((acc, period) => acc + period.sum, 0),
      }))
      .sort((a, b) => b.total - a.total)
      .slice(0, topN)
      .map(({ key }) => key);
  }

  // Null-prototype result so a category named "__proto__" stays an ordinary key.
  const limitedData = Object.create(null);
  keys.forEach((key) => {
    const node = data[key];
    limitedData[key] = node && node.children
      ? Object.assign({}, node, { children: applyTopN(node.children, selectedFields, topNValues, depth + 1) })
      : node;
  });
  return limitedData;
};

export default CustomPivotTable;
