import {useEffect, useState} from 'react'
import CircularProgress from '@mui/material/CircularProgress'
import Stack from '@mui/material/Stack'
import Typography from '@mui/material/Typography'
import {makeApi} from '@equistamp/api'
import {makeSorter} from 'components/filters'
import {smartRound} from 'components/formatters'
import {allOf, filterBetween} from 'components/filters/filters'
import {modelColumns, generalColumns, Column} from './columns'
import ScoresTable, {ScoreModel} from './ScoresTable'
import {asColumn, unique} from './evaluations'
import ColumnsChooser from './ColumnsChooser'
import type {Evaluation, Model, Score, Direction} from '@equistamp/types'

/** Sort state of the scores table: which column id is sorted, and in which direction. */
type SortConfig = {col: string; direction: Direction}

/**
 * Build a `setCurrentSort` updater for a click on column `col`'s sort control.
 *
 * Selecting a different column starts it at 'Desc'; selecting the column that is
 * already sorted flips its direction ('Desc' <-> 'Asc').
 */
const setSort =
  (col: string) =>
  (current: SortConfig): SortConfig => {
    const toggled: Direction = current.direction === 'Desc' ? 'Asc' : 'Desc'
    const direction: Direction = col === current.col ? toggled : 'Desc'
    return {col, direction}
  }

/**
 * Build the table rows: one entry per model that has at least one score for a selected
 * evaluation.
 *
 * Each row is a shallow copy of the model with:
 *  - `scores`: the model's own `scores` (if any) plus the full `Score` record of every
 *    selected evaluation, keyed by evaluation id;
 *  - one top-level property per selected evaluation id holding that score's `score`;
 *  - the fields of the model's `statistics` object (if any) spread onto the top level.
 *
 * Scores whose `evaluatee_id` is not in `models`, or whose `evaluation_id` is not in
 * `selectedEvals`, are ignored. The input `models` objects are never mutated, so calling
 * this repeatedly with different selections cannot leak scores between calls.
 *
 * @param allScores - every fetched score
 * @param models - the models that may be displayed
 * @param selectedEvals - ids of the evaluations whose scores should be attached
 * @returns the scored models (models without a selected score are omitted)
 */
const selectModels = (allScores: Score[], models: Model[], selectedEvals: string[]) => {
  const modelsById = new Map(models.map((m) => [m.id, m]))
  const wanted = new Set(selectedEvals)
  // Null prototype so ids like "constructor" don't resolve to Object.prototype members;
  // a plain object (rather than a Map) keeps Object.values' key ordering for the result.
  const rows: {[k: string]: ScoreModel} = Object.create(null)

  for (const score of allScores) {
    if (!wanted.has(score.evaluation_id)) continue
    const id = score.evaluatee_id
    const current: ScoreModel | undefined =
      rows[id] ?? (modelsById.get(id) as ScoreModel | undefined)
    if (!current) continue
    // Always build a fresh object: `current` may be the caller's Model instance.
    rows[id] = {
      ...current,
      scores: {...(current.scores || {}), [score.evaluation_id]: score},
      [score.evaluation_id]: score.score,
    }
  }

  return Object.values(rows).map((m) => ({...m, ...m?.statistics})) as ScoreModel[]
}

/**
 * Annotate each distinct column with the range of its values across `modelScores`.
 *
 * The rows are ordered with `makeSorter(col.id, 'Desc')`; the first sorted value becomes
 * `maxValue` and the last becomes `minValue`, both passed through `smartRound`.
 */
const addMinMaxVals = (modelScores: ScoreModel[], cols: Column[]) =>
  unique(cols).map((col) => {
    const field = col.id as keyof ScoreModel
    const sortedDesc = makeSorter(col.id, 'Desc')(modelScores)
    const values = sortedDesc.map((row) => row[field])
    const lowest = values[values.length - 1]
    const highest = values[0]
    return {...col, minValue: smartRound(lowest), maxValue: smartRound(highest)}
  })

/** Placeholder rendered instead of the scores table while scores are being fetched. */
function Loading() {
  return (
    <>
      <CircularProgress />
      <Typography>Fetching Scores...</Typography>
    </>
  )
}

/** Props of the `Compare` page. */
interface CompareProps {
  /** All known evaluations; only those with at least one fetched score are offered as columns. */
  evaluations: Evaluation[]
  /** All known models; only models scored on a shown evaluation appear as table rows. */
  models: Model[]
}
/**
 * Model comparison page.
 *
 * Fetches every score once on mount, then shows a sortable, filterable table with one row
 * per scored model. Evaluation columns are limited to evaluations that have scores; the
 * alphabetically first of them is added to the visible columns by default.
 */
const Compare = ({evaluations, models}: CompareProps) => {
  const [loading, setLoading] = useState(true)
  const [allScores, setAllScores] = useState([] as Score[])
  const [currentModels, setCurrentModels] = useState<ScoreModel[]>([])
  const [currentEvals, setCurrentEvals] = useState<Column[]>([])
  const [currentSort, setCurrentSort] = useState({col: 'score', direction: 'Desc'} as SortConfig)
  const [available, setAvailable] = useState([...generalColumns, ...Object.values(modelColumns)])
  const [currentColumns, setCurrentColumns] = useState<Column[]>([
    modelColumns.cost_per_1M_input_tokens_usd,
    modelColumns.cost_per_1M_output_tokens_usd,
    modelColumns.median_latency,
  ])

  // Fetch all scores once. Failures are logged rather than leaving the spinner up forever
  // (and producing an unhandled rejection); `cancelled` prevents state updates after unmount.
  useEffect(() => {
    let cancelled = false
    const fetchAll = async () => {
      try {
        const {items} = await makeApi().scores.list({perPage: 'all'})
        if (!cancelled) setAllScores(items as Score[])
      } catch (err: unknown) {
        console.error('Failed to fetch scores', err)
      } finally {
        if (!cancelled) setLoading(false)
      }
    }
    void fetchAll()
    return () => {
      cancelled = true
    }
  }, [])

  // Recompute which evaluations have data, refresh every column's min/max range, and make
  // sure the general columns plus a default evaluation column are visible.
  useEffect(() => {
    const evalIds = new Set(allScores.map((s) => s.evaluation_id))
    const selectedEvals = evaluations.filter((e) => evalIds.has(e.id)).map(asColumn)
    // Sort a copy. Sorting `selectedEvals` in place raced with the queued
    // setCurrentColumns updater below, so the default column depended on when React ran it.
    const sortedEvals = [...selectedEvals].sort((a, b) =>
      a.name.toLocaleLowerCase().localeCompare(b.name.toLocaleLowerCase())
    )

    const modelScores = selectModels(
      allScores,
      models,
      selectedEvals.map((e) => e.id)
    )
    setAvailable(
      addMinMaxVals(modelScores, [
        ...generalColumns,
        ...selectedEvals,
        ...Object.values(modelColumns),
      ])
    )
    setCurrentColumns((current) =>
      addMinMaxVals(modelScores, [...current, ...generalColumns, ...sortedEvals.slice(0, 1)])
    )
    setCurrentEvals(sortedEvals)
  }, [allScores, models, evaluations])

  // Rebuild the visible rows whenever the data, sort, or column filters change.
  useEffect(() => {
    const sorter = makeSorter(currentSort.col, currentSort.direction)
    // 0 is a legitimate bound, so it must not be dropped by the truthiness check.
    const hasBound = (value: unknown) => value === 0 || Boolean(value)
    const filter = allOf(
      currentColumns
        .filter((c) => hasBound(c.filters?.max) || hasBound(c.filters?.min))
        .map((c) => filterBetween(c.filters?.min, c.filters?.max, c.id, 0))
    )

    const rows = selectModels(
      allScores,
      models,
      currentEvals.map((e) => e.id)
    )
    setCurrentModels(sorter(rows.filter(filter)))
  }, [currentSort, currentColumns, allScores, models, currentEvals])

  // Attach the sort/filter handlers and current sort direction that ScoresTable expects.
  const addFuncs = (c: Column) => ({
    ...c,
    onSort: () => setCurrentSort(setSort(c.id)),
    onFilter: (min?: number, max?: number) => {
      // Merge into the column's *current* filters (i), not the render-time snapshot (c),
      // so queued updates are not lost.
      setCurrentColumns((current) =>
        current.map((i) => (i.id === c.id ? {...i, filters: {...i.filters, min, max}} : i))
      )
    },
    sortDir: currentSort.col === c.id ? currentSort.direction : undefined,
  })

  return (
    <Stack spacing={4} sx={{margin: 4}} justifyContent="space-between" alignItems="center">
      <Stack direction="row" justifyContent="flex-end" sx={{width: '100%', pr: 11}}>
        <ColumnsChooser
          available={available}
          current={currentColumns}
          onChange={setCurrentColumns}
        />
      </Stack>
      {loading ? (
        <Loading />
      ) : (
        <ScoresTable
          items={currentModels}
          evals={currentColumns.filter((c) => c.type === 'evaluation').map(addFuncs)}
          columns={currentColumns.filter((c) => c.type !== 'evaluation').map(addFuncs)}
        />
      )}
    </Stack>
  )
}

export default Compare
