import { AxisOptions, Chart } from 'react-charts';
import { useSelector } from 'react-redux';
import { Usage, getUsage } from '../../slices/usage';
import { useMemo } from 'react';
import flatMap from 'lodash/flatMap';
import groupBy from 'lodash/groupBy';
import sumBy from 'lodash/sumBy';
import Card from '../common/Card';

/** Extracts one numeric usage metric from a single usage record. */
type UsageMetric = (datum: Usage) => number;

/** Text (LLM) metric: prompt tokens plus completion tokens. */
const textTokens: UsageMetric = (datum) =>
  datum.completion_tokens + datum.prompt_tokens;

/** Image (`sd` model type) metric: image steps. */
const imageSteps: UsageMetric = (datum) => datum.image_steps;

/** Audio metric: audio tokens. */
const audioTokens: UsageMetric = (datum) => datum.audio_tokens;

/**
 * Shared x-axis for every chart: each usage record is placed on its date.
 *
 * Axis options are module-level constants, so react-charts receives the same
 * object identity on every render without needing a hook.
 */
const primaryAxis: AxisOptions<Usage> = {
  // `date` is read through `any` because this file does not show it on `Usage`.
  // NOTE(review): a date-only ISO string (e.g. "2024-01-05") is parsed as UTC
  // midnight, so in negative UTC offsets setHours(17) lands on the previous
  // local day — confirm the format of `date`.
  getValue: (datum: any) => {
    const date = new Date(datum.date);
    date.setHours(17);
    return date;
  },
};

/** Builds a y-axis list holding a single line series that plots `metric`. */
const lineAxes = (metric: UsageMetric): AxisOptions<Usage>[] => [
  { getValue: metric, elementType: 'line' },
];

const textAxes = lineAxes(textTokens);
const imageAxes = lineAxes(imageSteps);
const audioAxes = lineAxes(audioTokens);

/** One chart series: a model name and that model's usage records. */
type UsageSeries = { label: string; data: Usage[] };

/** Summary card for one total; renders nothing unless `value` is positive. */
const TotalCard = ({ label, value }: { label: string; value: number }) =>
  value > 0 ? (
    <Card noHover className="text-black px-4 py-3">
      <div className="mb-2 text-sm">{label}</div>
      <div className="text-xl font-bold">{value}</div>
    </Card>
  ) : null;

/** Titled line chart of per-model usage; renders nothing when `data` is empty. */
const UsageChartCard = ({
  title,
  data,
  secondaryAxes,
}: {
  title: string;
  data: UsageSeries[];
  secondaryAxes: AxisOptions<Usage>[];
}) =>
  data.length > 0 ? (
    <Card noHover className="w-full">
      <div className="text-black mb-4">{title}</div>
      <div className="h-40 w-full">
        <Chart options={{ data, primaryAxis, secondaryAxes }} />
      </div>
    </Card>
  ) : null;

/**
 * Usage dashboard: per-category totals plus one line chart per category
 * (text/`llm`, image/`sd`, audio), with one series per model name.
 *
 * Renders nothing when the selected usage data contains no records.
 */
const UsageComponent = () => {
  const usageData = useSelector(getUsage);

  // Everything derived below depends only on `usageData`, so it is recomputed
  // only when that changes. This hook must run before the early return.
  const {
    hasUsage,
    textData,
    imageData,
    audioData,
    allTextUsage,
    allImageUsage,
    allAudioUsage,
  } = useMemo(() => {
    // Group records by model name; lodash never produces an empty group, so
    // the first record is a valid source for the group's model type.
    const groupedUsage = Object.entries(
      groupBy(usageData, (x) => x.model_name)
    ).map(([name, usage]) => ({
      name,
      type: usage[0]?.model_type,
      usage,
    }));

    // Series for one model type, dropping models whose total for `metric` is
    // not positive — the same rule the summary cards use.
    const seriesFor = (modelType: string, metric: UsageMetric): UsageSeries[] =>
      groupedUsage
        .filter((group) => group.type === modelType)
        .filter((group) => sumBy(group.usage, metric) > 0)
        .map((group) => ({ label: group.name, data: group.usage }));

    const allUsage = flatMap(groupedUsage, (group) => group.usage);

    return {
      hasUsage: groupedUsage.length > 0,
      textData: seriesFor('llm', textTokens),
      imageData: seriesFor('sd', imageSteps),
      audioData: seriesFor('audio', audioTokens),
      allTextUsage: sumBy(allUsage, textTokens),
      allImageUsage: sumBy(allUsage, imageSteps),
      allAudioUsage: sumBy(allUsage, audioTokens),
    };
  }, [usageData]);

  if (!hasUsage) return null;

  return (
    <div className="flex flex-col mb-10 w-full">
      <h2>Usage</h2>
      <div className="flex flex-col gap-5">
        <div className="flex gap-4">
          <TotalCard label="Total Text Tokens" value={allTextUsage} />
          <TotalCard label="Total Image Steps" value={allImageUsage} />
          <TotalCard label="Total Audio Tokens" value={allAudioUsage} />
        </div>
        <UsageChartCard
          title="Text Generation"
          data={textData}
          secondaryAxes={textAxes}
        />
        <UsageChartCard
          title="Image Generation"
          data={imageData}
          secondaryAxes={imageAxes}
        />
        <UsageChartCard
          title="Audio Generation"
          data={audioData}
          secondaryAxes={audioAxes}
        />
      </div>
    </div>
  );
};

export default UsageComponent;
