import { groupBy, last } from 'ramda'
import { resolveObjectKey } from 'utils'
import { formatter } from 'utils/string'
import { useMemo, useState } from 'react'
import GridTable from './GridTable'
import Checkbox from 'components/Inputs/Checkbox'
import classNames from 'classnames'

// Tolerance used when comparing a deviation against zero and against the
// boundaries of the heatmap ranges.
const EPSILON = 0.001
// Neutral background: reference columns and cells without a notable deviation.
const DEFAULT_VALUE_BG = 'bg-grey-100'
// Background of the compared-value column headers while in heatmap mode.
const HEATMAP_HEADER_BG = 'bg-blue-100'

// Heatmap backgrounds, ordered from the largest positive deviation (index 0)
// through the neutral middle entry to the largest negative deviation (last).
const HEATMAP_CLASSES = [
  'bg-blue-300',
  'bg-blue-200',
  'bg-blue-100',
  DEFAULT_VALUE_BG,
  'bg-error-100',
  'bg-error-200',
  'bg-error-300'
]

// Same classes in reverse, so the legend reads from negative to positive.
const LEGEND_CLASSES = [...HEATMAP_CLASSES].reverse()

// Per-property backgrounds used in the 'keys' coloring mode (cycled by index).
const COLORS = ['bg-blue-100', 'bg-blue-200', 'bg-blue-300', 'bg-blue-400']

/**
 * Sums a row's values across all of its groups.
 *
 * @param {Object<string, Object>} data - Row data keyed by group; each entry
 *   is read with plain property access (`entry[prop]`). A nullish `data` or a
 *   nullish entry contributes nothing instead of throwing.
 * @param {string[]} comparedProps - Keys of the compared values to accumulate.
 * @param {string} refProp - Key of the reference value to accumulate.
 * @returns {Object<string, number>} Totals keyed by `refProp` and by every
 *   compared key; each key is present (0) even when no group has a value.
 */
const calculateHorizontalSum = (data, comparedProps, refProp) => {
  const props = comparedProps ?? [];

  // Seed the reference key first, then the compared keys, so every key is
  // present in the result even for an empty row.
  const totals = { [refProp]: 0 };
  for (const prop of props) {
    totals[prop] = 0;
  }

  for (const entry of Object.values(data ?? {})) {
    if (entry == null) {
      // A group without an entry has nothing to add.
      continue;
    }
    for (const prop of props) {
      totals[prop] += entry[prop] || 0;
    }
    // Missing / falsy reference values count as 0.
    totals[refProp] += entry[refProp] || 0;
  }

  return totals;
};

/**
 * Computes the across-groups totals of a row and of each of its sub-rows.
 *
 * @param {{data: Object, subRows?: Array<{data: Object}>}} data - The row.
 * @param {Array<{objKey: string}>} [comparedPropsObj] - Compared properties;
 *   a nullish value is treated as an empty list.
 * @param {{objKey: string}} [refPropObj] - Reference property.
 * @returns {Object} The row totals keyed by `objKey`, plus a `subRows` array
 *   with the totals of every sub-row (empty when the row has none).
 */
const getHorizontalSum = (data, comparedPropsObj, refPropObj) => {
  // Fall back to an empty list: calculateHorizontalSum iterates this value,
  // so handing it the `undefined` produced by the optional chain would throw.
  const comparedKeys = comparedPropsObj?.map(({ objKey }) => objKey) ?? [];
  const refKey = refPropObj?.objKey;

  const sumRow = (row) => calculateHorizontalSum(row.data, comparedKeys, refKey);

  return {
    ...sumRow(data),
    subRows: data.subRows?.map((subRow) => sumRow(subRow)) ?? [],
  };
};

/**
 * Returns a copy of `array` with a synthetic "Total" row appended. For every
 * group key found in the contributing rows, the total row holds the sum of
 * each compared value and of the reference value.
 *
 * Values are read and written with plain property access on `objKey`.
 * Missing or falsy values count as 0, so one absent value does not turn the
 * whole total into NaN.
 *
 * @param {Array<{data: Object, subRows?: Array}>} array - Top-level rows.
 * @param {Array<{objKey: string}>} comparedProps - Compared properties.
 * @param {{objKey: string}} refProp - Reference property.
 * @returns {Array} A new array: the original rows followed by the total row.
 */
function getVerticalSum(array, comparedProps, refProp) {
  const total = {
    title: 'Total',
    data: {},
    className: 'border-t-2 border-[#5B6673] !font-bold'
  };

  // Compared keys first, reference key last (also the key order of each bucket).
  const keys = [...comparedProps.map(({ objKey }) => objKey), refProp.objKey];

  for (const item of array) {
    // Only rows carrying `subRows` contribute to the total.
    // NOTE(review): rows without `subRows` are skipped entirely — confirm
    // that this is intended and not an accidental filter.
    if (!item.subRows) {
      continue;
    }

    for (const [group, values] of Object.entries(item.data)) {
      if (!total.data[group]) {
        total.data[group] = Object.fromEntries(keys.map((key) => [key, 0]));
      }

      const bucket = total.data[group];
      for (const key of keys) {
        bucket[key] += values?.[key] || 0;
      }
    }
  }

  return [...array, total];
}

/**
 * Recursively collects, for every row and nested sub-row, the difference
 * between each compared value and the reference value of every data entry.
 *
 * For each row the deviations of its sub-rows come first, followed by the
 * row's own deviations.
 *
 * @param {string} refPropPath - Key passed to `resolveObjectKey` for the reference value.
 * @param {string[]} propPathList - Keys passed to `resolveObjectKey` for the compared values.
 * @param {Array<{data: Object, subRows?: Array}>} [items] - Rows to inspect.
 * @returns {number[]} Flat list of `compared - reference` differences
 *   (empty when `items` is nullish).
 */
const findLargestDiviationsInternal = (refPropPath, propPathList, items) => {
  if (items == null) {
    return []
  }

  return items.flatMap((row) => {
    const nestedDeviations = findLargestDiviationsInternal(
      refPropPath,
      propPathList,
      row.subRows
    )

    const ownDeviations = Object.values(row.data).flatMap((entry) => {
      const reference = resolveObjectKey(entry, refPropPath)
      return propPathList.map((path) => resolveObjectKey(entry, path) - reference)
    })

    return [...nestedDeviations, ...ownDeviations]
  })
}

/**
 * Determines the extreme deviations from the reference value over all rows.
 *
 * @param {string} refPropPath - Key of the reference value.
 * @param {string[]} propPathList - Keys of the compared values.
 * @param {Array} items - Rows (with optional nested `subRows`).
 * @returns {{maxPositive: number, maxNegative: number}} Largest positive
 *   deviation (never below 1) and largest negative deviation (never above -1).
 */
const findLargestDiviations = (refPropPath, propPathList, items) => {
  const deviations = findLargestDiviationsInternal(
    refPropPath,
    propPathList,
    items
  )

  // Falsy entries (0, NaN) carry no direction and are ignored.
  const directed = deviations.filter(Boolean)
  const bySign = groupBy(
    (value) => (value >= 0 ? 'positive' : 'negative'),
    directed
  )

  let maxPositive = 1
  for (const value of bySign.positive ?? []) {
    maxPositive = Math.max(maxPositive, value)
  }

  let maxNegative = -1
  for (const value of bySign.negative ?? []) {
    maxNegative = Math.min(maxNegative, value)
  }

  return { maxPositive, maxNegative }
}

/**
 * Splits the positive and the negative deviation span into equally sized
 * ranges and maps each range to an index into HEATMAP_CLASSES.
 *
 * Ranges are emitted in pairs (positive, then negative) from the smallest to
 * the largest magnitude. The outer edge of every range is widened by EPSILON.
 *
 * @param {{maxPositive: number, maxNegative: number}} distances - Extreme deviations.
 * @returns {Array<{start: number, end: number, idx: number}>} Range mappings.
 */
const findRangeMappings = (distances) => {
  // Number of heatmap classes on each side of the neutral middle class.
  const bucketsPerSide = (HEATMAP_CLASSES.length - 1) / 2
  const ranges = []

  for (let step = 0; step < bucketsPerSide; step += 1) {
    const { maxPositive, maxNegative } = distances
    const nearFraction = step / bucketsPerSide
    const farFraction = (step + 1) / bucketsPerSide

    ranges.push(
      {
        start: maxPositive * nearFraction,
        end: maxPositive * farFraction + EPSILON,
        idx: bucketsPerSide - step - 1
      },
      {
        start: maxNegative * farFraction - EPSILON,
        end: maxNegative * nearFraction,
        idx: bucketsPerSide + step + 1
      }
    )
  }

  return ranges
}

// Builds the row context handed to GridTable: a row is expandable exactly
// when its `subRows` property is an array.
const rowContextFactory = ({ subRows }) => {
  const expandable = Array.isArray(subRows)
  return { expandable }
}

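/*
type PropInfo = {
  // Key used to look a value up inside a group's entry; header and legend
  // labels fall back to its last '.'-separated segment.
  objKey: string;
  title?: string;
  legend?: string;
}

type RowData = {
  id: string | number;
  title?: string;
  data: {
    // One entry per group name; each entry holds the values that are looked
    // up via PropInfo.objKey.
    [groupName: string]: {
      [objKey: string]: number | string;
    };
  }
  subRows?: RowData[];
}
*/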

const HeatMapTable = ({
  refProp, // PropInfo
  comparedProps, // PropInfo
  groups,
  data: dataProp, // RowData[] where k must be the same as the group names
  showControls = true,
  showLegend = true,
  showTotal = false,
  defaultColoringMode = 'heatmap', // 'heatmap' | 'keys'
  defaultShowReference = true,
  heatmapLegend = true,
  rowTitles = true,
  expandable = true,
  rowTitleFormatter,
  groupLabelFormatter,
  columnValueFormatter = (v) => v,
  columnSizes
}) => {
  const [showReference, setShowReference] = useState(defaultShowReference)
  const [coloringMode, setColoringMode] = useState(defaultColoringMode)

  const data = useMemo(
    () => showTotal ? getVerticalSum(dataProp, comparedProps, refProp) : [...dataProp],
    [dataProp, showTotal, comparedProps, refProp]
  );

  const columnHeaders = useMemo(
    () => [
      ...(rowTitles ? [undefined] : []),
      ...groups.reduce(
        (acc) => [
          ...acc,
          ...comparedProps.map(({ objKey, title }, idx) => (
            <div
              key={`${title}-${idx}`}
              className={classNames(
                coloringMode === 'heatmap'
                  ? HEATMAP_HEADER_BG
                  : COLORS[idx % COLORS.length],
                'w-full h-full p-2'
              )}
            >
              {title || formatter(last(objKey.split('.')))}
            </div>
          )),
          ...(showReference
            ? [
                <div
                  key={`${refProp.objKey}-${refProp.title}`}
                  className={classNames(DEFAULT_VALUE_BG, 'w-full h-full p-2')}
                >
                  {refProp.title || formatter(last(refProp.objKey.split('.')))}
                </div>
              ]
            : []),
        ],
        []
      ),
      ...(showTotal ? [...comparedProps.map(({ objKey, title }, idx) => (
        <div key={`${title}-${idx}-total`} className='border-l-2 border-[#5B6673] w-full h-full justify-center bg-[#E7E5E5] text-[#5B6673] font-bold flex items-center'>
          {title || formatter(last(objKey.split('.')))}
        </div>
      ))] : []),
      ...(showTotal && showReference
        ? [
            <div
              key={`${refProp.objKey}-${refProp.title}`}
              className='w-full h-full justify-center bg-[#E7E5E5] text-[#5B6673] font-bold flex items-center'
            >
              {refProp.title || formatter(last(refProp.objKey.split('.')))}
            </div>
          ]
        : []),
    ],
    [rowTitles, groups, showTotal, comparedProps, showReference, refProp.objKey, refProp.title, coloringMode]
  )

  const columnHeaderGroups = useMemo(
    () => {
      const basicGroups = groups.map((m, i) => ({
        start: i * (comparedProps.length + !!showReference) + 1 + !!rowTitles,
        span: comparedProps.length + !!showReference,
        title: groupLabelFormatter ? groupLabelFormatter(m) : m
      }));

      if (showTotal) {
        const totalStart = basicGroups.length * (comparedProps.length + !!showReference) + 1 + !!rowTitles;
        basicGroups.push({
          start: totalStart,
          span: comparedProps.length + !!showReference,
          title: 'Total',
          className: 'border-l-2 border-[#5B6673] !bg-[#E7E5E5] !text-[#5B6673] !font-bold',
        });
      }

      return basicGroups;
    },
    [groups, comparedProps.length, showReference, rowTitles, groupLabelFormatter, showTotal]
  );


  const ranges = findRangeMappings(
    findLargestDiviations(
      refProp.objKey,
      comparedProps.map((x) => x.objKey),
      data
    )
  )

  const heatmapClassName = (val, ref) => {
    const dist = val - ref

    if (Math.abs(dist) > EPSILON) {
      for (const mapping of ranges) {
        if (dist >= mapping.start && dist < mapping.end) {
          return HEATMAP_CLASSES[mapping.idx]
        }
      }
    }

    return HEATMAP_CLASSES[(HEATMAP_CLASSES.length - 1) / 2]
  }

  const horizontalSum = useMemo(() => data.map(item => getHorizontalSum(item, comparedProps, refProp)), [comparedProps, data, refProp])

  const columnMappers = [
    ...(rowTitles
      ? [
          {
            className: 'bg-white',
            content: ({ item }) => (
              <div>
                {rowTitleFormatter ? rowTitleFormatter(item.title) : item.title}
              </div>
            )
          }
        ]
      : []),
    ...groups.reduce(
      (acc, groupName) => [
        ...acc,
        ...comparedProps.map(({ objKey }, propIdx) => ({
          className: (item) =>
            classNames(coloringMode === 'heatmap'
              ? heatmapClassName(
                  resolveObjectKey(item.data[groupName], objKey),
                  resolveObjectKey(item.data[groupName], refProp.objKey)
                )
              : COLORS[propIdx % COLORS.length], 'justify-center'),
          content: ({ item }) => {
            const value = resolveObjectKey(item.data[groupName], objKey);
            return (
              <>
                {columnValueFormatter(value)}
              </>
            );
          }
        })),
        ...(showReference
          ? [
              {
                className: classNames(DEFAULT_VALUE_BG, 'justify-center'),
                content: ({ item }) => {
                  return (
                    <>
                      {columnValueFormatter(
                        resolveObjectKey(item.data[groupName], refProp.objKey)
                      )}
                    </>
                )}
              }
            ]
          : []),
      ],
      []
    ),
    ...(showTotal ? [...comparedProps.map(({ objKey }) => {
      return {
        className: 'border-l-2 border-[#5B6673] justify-center bg-[#E7E5E5] text-[#5B6673] font-bold',
        content: ({ rowKey, ...rest }) => {
          const isSubItem = String(rowKey).includes('-');
          let value;

          if (isSubItem) {
            const [parentIdx, subIdx] = rowKey.split('-').map(Number);
            value = horizontalSum[parentIdx].subRows[subIdx][objKey];
          } else {
            value = horizontalSum[rowKey][objKey];
          }

          return (
            <b> {columnValueFormatter(value)} </b>
          );
        }
      };
    })] : []),
    ...(showTotal && showReference
      ? [
          {
            className: 'justify-center bg-[#E7E5E5] text-[#5B6673] font-bold',
            content: ({ rowKey }) => {
              const isSubItem = String(rowKey).includes('-');
              let value;

              if (isSubItem) {
                const [parentIdx, subIdx] = rowKey.split('-').map(Number);
                value = horizontalSum[parentIdx].subRows[subIdx][refProp.objKey];
              } else {
                value = horizontalSum[rowKey][refProp.objKey];
              }

              return (
                <b>{columnValueFormatter(value)}</b>
              );
            }
          }
        ]
      : []),
  ]

  return (
    <div>
      {showControls && (
        <div className='flex flex-row gap-4 mb-8'>
          <span>Table view:</span>
          <Checkbox
            label={`${
              refProp.title || formatter(last(refProp.objKey.split('.')))
            } Comparison`}
            isChecked={showReference}
            onChange={() => setShowReference((prev) => !prev)}
          />
          <Checkbox
            label='Variation heatmap'
            isChecked={coloringMode === 'heatmap'}
            onChange={() =>
              setColoringMode((prev) =>
                prev === 'heatmap' ? 'keys' : 'heatmap'
              )
            }
          />
        </div>
      )}
      <div className='w-full overflow-x-auto'>
        <GridTable
          data={data}
          keyMapper={(row) => row.id}
          rowContext={rowContextFactory}
          columnMappers={columnMappers}
          columnHeaders={columnHeaders}
          columnHeaderGroups={columnHeaderGroups}
          renderExpandedRow={expandable && (({ subRows }) => subRows)}
          expandArrowPosition={'left'}
          cellClassNames={
            'py-1 pr-2 p-2 mx-0.5 mb-1 text-sm font-medium text-gray-900 flex items-center'
          }
          headerClassNames={
            'flex items-center justify-center text-sm m-0.5 mb-2 rounded-sm'
          }
          groupHeaderClassNames={
            'flex items-center justify-center text-sm p-2 m-0.5 bg-blue-200 rounded-sm'
          }
          columnSizes={columnSizes}
          stickyFirstColumns={1}
        />
      </div>
      {coloringMode === 'keys' && showLegend && (
        <div className='flex flex-row gap-4 mt-4'>
          {comparedProps.map(({ objKey, title, legend }, idx) => (
            <LegendItem
              key={idx}
              circleClassName={COLORS[idx % COLORS.length]}
              label={legend || title || formatter(last(objKey.split('.')))}
            />
          ))}
          {showReference && (
            <LegendItem
              circleClassName={DEFAULT_VALUE_BG}
              label={
                refProp.legend ||
                refProp.title ||
                formatter(last(refProp.objKey.split('.')))
              }
            />
          )}
        </div>
      )}
      {coloringMode === 'heatmap' && heatmapLegend && showLegend && (
        <div className='flex flex-row gap-2 mt-4 items-center'>
          <span className='text-sm'>
            {comparedProps
              .map(
                ({ title, objKey }) =>
                  title || formatter(last(objKey.split('.')))
              )
              .join(', ')}{' '}
            - {refProp.title || formatter(last(refProp.objKey.split('.')))}{' '}
            Variation:
          </span>
          <span className='text-xs'>Negative</span>
          <div className='flex flex-row'>
            {LEGEND_CLASSES.map((className, idx) => (
              <div key={idx} className={classNames(className, 'h-5 w-6')} />
            ))}
          </div>
          <span className='text-xs'>Positive</span>
        </div>
      )}
    </div>
  )
}

const LegendItem = ({ circleClassName, label }) => (
  <div className='flex flex-row items-center text-sm'>
    <div
      className={classNames('h-3 w-3 rounded-full mr-1.5', circleClassName)}
    ></div>
    {label}
  </div>
)

export default HeatMapTable
