import moment from "moment-timezone";
import { Array, String } from "runtypes";

import { nullthrows } from "../../assertions.js";
import {
  BinaryColumnPredicate,
  BinaryColumnPredicateOp,
  ColumnPredicate,
  CompoundColumnPredicate,
  CompoundColumnPredicateOp,
  ListBinaryColumnPredicate,
  ListBinaryColumnPredicateOp,
  UnaryColumnPredicate,
  UnaryColumnPredicateOp,
} from "../../display-table/columnPredicateTypes.js";
import { ColumnFilter } from "../../display-table/filterTypes.js";
import { FilledDynamicValueTableColumnType } from "../../DynamicValue.js";
import { assertNever } from "../../errors.js";
import { getColumnFromColumnNameWithQueryPath } from "../../explore/exploreChartSemanticUtils.js";
import { SemanticAwareFieldGroup } from "../../explore/semanticTypes.js";
import { FilterType } from "../../filter/filterTypes.js";
import { DisplayTableColumnId } from "../../idTypeBrands.js";
import { notEmpty } from "../../notEmpty.js";
import {
  typedObjectEntries,
  typedObjectKeys,
  typedObjectValues,
} from "../../utils/typedObjects.js";
import { unescapeVegaString } from "../../vega-chart-cell/vegaEscape.js";
import { isMondayStartWeek } from "../chartTimeUnitUtils.js";
import {
  ChartSelection,
  SelectionDerived,
  TimeunitBinnedField,
  VegaIntervalDatetimeRange,
  VegaIntervalNumericRange,
  VegaIntervalSelection,
  VegaPointSelection,
} from "../types.js";
import { TimezoneName } from "../../dateTypes.js";

/**
 * Converts a chart selection into the column filters that reproduce it.
 *
 * Each value of `selection` is classified with the `VegaIntervalSelection` and
 * `VegaPointSelection` guards (interval is checked first) and expanded into
 * per-field filters; values matching neither guard contribute nothing. The
 * collected filters are then merged per column by `consolidateFilters`.
 *
 * @param columnTypes - Column types keyed by source column name (default `{}`).
 * @param derived - Derived-field descriptors used to map chart fields back to
 *   their source columns (default `[]`).
 * @param displayTimezone - Timezone used to adjust timestamp selection values.
 * @param fieldGroups - When provided, forwarded to interval handling and to
 *   consolidation; point selections do not receive it.
 * @param filterType - Defaults to `FilterType.KEEP`.
 * @param selection - The chart selection to translate.
 * @param timestampFilteringEnabled - Forwarded to date filter construction
 *   (default `false`).
 * @returns The filters produced by `consolidateFilters`.
 */
export function getChartSelectionFilters({
  columnTypes = {},
  derived = [],
  displayTimezone,
  fieldGroups,
  filterType = FilterType.KEEP,
  selection,
  timestampFilteringEnabled = false,
}: {
  columnTypes?: Record<string, FilledDynamicValueTableColumnType>;
  selection: ChartSelection;
  filterType?: FilterType;
  derived?: SelectionDerived;
  timestampFilteringEnabled?: boolean;
  fieldGroups?: SemanticAwareFieldGroup[];
  displayTimezone: TimezoneName;
}): ColumnFilter[] {
  // Arguments common to both selection kinds.
  const shared = {
    filterType,
    derived,
    columnTypes,
    timestampFilteringEnabled,
    displayTimezone,
  };

  const collected: ColumnFilter[] = [];
  for (const entry of typedObjectValues(selection)) {
    if (VegaIntervalSelection.guard(entry)) {
      collected.push(
        ...intervalSelectionFilters({
          ...shared,
          selection: entry,
          fieldGroups,
        }),
      );
    } else if (VegaPointSelection.guard(entry)) {
      collected.push(
        ...pointSelectionFilters({ ...shared, pointSelection: entry }),
      );
    }
  }

  return consolidateFilters(collected, filterType, fieldGroups);
}

/**
 * Builds the filters for an interval selection, at most one per field.
 *
 * For every `[field, interval]` entry of `selection`:
 * - empty intervals are skipped;
 * - only the first and last interval values are used as the range ends;
 * - two numeric ends produce a compound range predicate (skipped when both
 *   ends are equal), with bounds/operators from `maybeBinNumericInterval`;
 * - two ISO date/timestamp strings on a DATE / DATETIME / DATETIMETZ column
 *   produce a date filter, after `timezoneAdjust` is applied to both ends;
 * - otherwise, an all-string interval produces a categorical filter.
 *
 * Entries matching none of the above produce no filter.
 */
function intervalSelectionFilters({
  columnTypes,
  derived,
  displayTimezone,
  fieldGroups,
  filterType,
  selection,
  timestampFilteringEnabled,
}: {
  displayTimezone: TimezoneName;
  columnTypes: Record<string, FilledDynamicValueTableColumnType>;
  derived: SelectionDerived;
  filterType: FilterType;
  selection: VegaIntervalSelection;
  fieldGroups?: SemanticAwareFieldGroup[];
  timestampFilteringEnabled: boolean;
}): ColumnFilter[] {
  const result: ColumnFilter[] = [];
  // A kept range is "lower AND upper"; an excluded range is "below OR above".
  const rangeOp =
    filterType === FilterType.KEEP
      ? CompoundColumnPredicateOp.AND
      : CompoundColumnPredicateOp.OR;

  for (const [rawField, interval] of typedObjectEntries(selection)) {
    if (interval.length < 1) {
      continue;
    }

    // until we can restore the "bin" entries in `derived`, post-pushdown
    // pattern-match to filter on the underlying field
    const isPushedDownHistogramBin =
      rawField.includes("col_hist_") && rawField.endsWith("_bin_start");
    const field = isPushedDownHistogramBin
      ? rawField.slice(0, -"_bin_start".length)
      : rawField;

    const sourceField = maybeRenamed(field, derived);

    // `columnTypes` only contains the types of the columns on the "base"
    // dataframe, so if `fieldGroups` is present (i.e. in explore
    // context), then we should get the type from field groups to handle
    // joined fields.
    const columnType = fieldGroups
      ? getColumnFromColumnNameWithQueryPath(sourceField, fieldGroups)
          ?.columnType
      : columnTypes[sourceField];

    const lower = interval[0];
    const upper = interval[interval.length - 1];

    // Numeric interval selection
    if (typeof lower === "number" && typeof upper === "number") {
      if (lower === upper) {
        continue;
      }

      const {
        interval: [lowerBound, upperBound],
        ops: [lowerOp, upperOp],
      } = maybeBinNumericInterval(field, [lower, upper], filterType, derived);

      result.push({
        column: sourceField as DisplayTableColumnId,
        columnType,
        predicate: {
          op: rangeOp,
          args: [
            { op: lowerOp, arg: `${lowerBound}` },
            { op: upperOp, arg: `${upperBound}` },
          ],
        },
      });
      continue;
    }

    // Date / datetime interval selection
    if (
      typeof lower === "string" &&
      typeof upper === "string" &&
      isIsoDateOrTimestamp(lower) &&
      isIsoDateOrTimestamp(upper) &&
      (columnType === "DATE" ||
        columnType === "DATETIME" ||
        columnType === "DATETIMETZ")
    ) {
      result.push(
        getDateFilter(
          field,
          timezoneAdjust(lower, displayTimezone),
          timezoneAdjust(upper, displayTimezone),
          filterType,
          columnType,
          derived,
          timestampFilteringEnabled,
        ),
      );
      continue;
    }

    // Categorical interval selection
    if (Array(String).guard(interval)) {
      result.push(
        getCategoricalFilter(interval, filterType, columnType, sourceField),
      );
    }
  }

  return result;
}

/**
 * Builds the filters for a point selection (a list of selected points).
 *
 * The fields are taken from the keys of the first selected point. For each
 * field:
 * - if every non-empty selected value is a string and the column is not a
 *   DATE / DATETIME / DATETIMETZ column, a single categorical filter covering
 *   all selected values is produced;
 * - otherwise each selected point is handled on its own: numbers become a
 *   NUMBER filter from `maybeBinNumericPointSelection`, and ISO date/timestamp
 *   strings on date-like columns become a date filter whose start and end are
 *   the same timezone-adjusted value. Other values produce nothing.
 *
 * Returns `[]` when the selection has no first point.
 *
 * NOTE(review): column types come from `columnTypes` only; unlike the interval
 * path, `fieldGroups` is not consulted here — confirm this is intended.
 * NOTE(review): if every value of a field is filtered out by `notEmpty`, the
 * empty list still passes the string-array guard and reaches
 * `getCategoricalFilter` — confirm the resulting empty predicate is acceptable.
 */
function pointSelectionFilters({
  columnTypes,
  derived,
  displayTimezone,
  filterType,
  pointSelection,
  timestampFilteringEnabled,
}: {
  displayTimezone: TimezoneName;
  columnTypes: Record<string, FilledDynamicValueTableColumnType>;
  derived: SelectionDerived;
  filterType: FilterType;
  pointSelection: VegaPointSelection;
  timestampFilteringEnabled: boolean;
}): ColumnFilter[] {
  const firstPoint = pointSelection[0];
  if (firstPoint == null) {
    return [];
  }

  const result: ColumnFilter[] = [];
  for (const field of typedObjectKeys(firstPoint)) {
    const sourceField = maybeRenamed(field, derived);
    const columnType = columnTypes[sourceField];

    // Filtering empties shouldn't be necessary, as all selections should have
    // the same fields, but do it just in case.
    const presentValues = pointSelection
      .map((point) => point[field])
      .filter(notEmpty);

    if (
      Array(String).guard(presentValues) &&
      columnType !== "DATE" &&
      columnType !== "DATETIME" &&
      columnType !== "DATETIMETZ"
    ) {
      result.push(
        getCategoricalFilter(presentValues, filterType, columnType, sourceField),
      );
      continue;
    }

    for (const point of pointSelection) {
      const value = point[field];

      if (typeof value === "number") {
        result.push({
          column: sourceField as DisplayTableColumnId,
          columnType: FilledDynamicValueTableColumnType.NUMBER,
          predicate: maybeBinNumericPointSelection(
            field,
            value,
            filterType,
            derived,
          ),
        });
        continue;
      }

      if (
        typeof value === "string" &&
        isIsoDateOrTimestamp(value) &&
        (columnType === "DATE" ||
          columnType === "DATETIME" ||
          columnType === "DATETIMETZ")
      ) {
        // A point is a degenerate range: same start and end.
        const adjusted = timezoneAdjust(value, displayTimezone);
        result.push(
          getDateFilter(
            field,
            adjusted,
            adjusted,
            filterType,
            columnType,
            derived,
            timestampFilteringEnabled,
          ),
        );
      }
    }
  }

  return result;
}

/**
 * Builds a categorical filter on `sourceField` from the selected string values.
 *
 * Behavior:
 * - Duplicate selections are ignored; each distinct value is handled once.
 * - The string `"null"` always adds an IS_NULL (KEEP) / NOT_NULL (otherwise)
 *   predicate. When `columnType` is `"STRING"` it is additionally included as
 *   a literal value in the list predicate.
 * - All other values are gathered into one IS_ONE_OF (KEEP) / NOT_ONE_OF
 *   (otherwise) list predicate.
 * - When exactly one predicate results it is used directly; otherwise the
 *   predicates are wrapped in a compound OR (KEEP) / AND (otherwise).
 *
 * Note: an empty `selections` array yields a compound predicate with no args.
 *
 * @param selections - Selected values, possibly containing duplicates.
 * @param filterType - Whether matching rows are kept or excluded.
 * @param columnType - Type of the filtered column, if known.
 * @param sourceField - Name of the source column to filter on.
 */
function getCategoricalFilter(
  selections: string[],
  filterType: FilterType,
  columnType: FilledDynamicValueTableColumnType | undefined,
  sourceField: string,
): ColumnFilter {
  const isKeep = filterType === FilterType.KEEP;
  const predicates: ColumnPredicate[] = [];
  const args: string[] = [];

  const seenValues = new Set<string>();
  for (const selection of selections) {
    if (seenValues.has(selection)) {
      continue;
    }
    // Record every value, including "null" on non-STRING columns, so that a
    // repeated "null" selection cannot emit the null check more than once.
    seenValues.add(selection);

    if (selection === "null") {
      // If value is the string "null", always check IS NULL / IS NOT NULL
      predicates.push({
        op: isKeep
          ? UnaryColumnPredicateOp.IS_NULL
          : UnaryColumnPredicateOp.NOT_NULL,
      });
      // If the column type is STRING, also check for the string "null"
      if (columnType === "STRING") {
        args.push(selection);
      }
    } else {
      args.push(selection);
    }
  }

  if (args.length > 0) {
    predicates.push({
      op: isKeep
        ? ListBinaryColumnPredicateOp.IS_ONE_OF
        : ListBinaryColumnPredicateOp.NOT_ONE_OF,
      arg: args,
    });
  }

  const column = sourceField as DisplayTableColumnId;
  if (predicates.length === 1) {
    return {
      column,
      columnType,
      predicate: nullthrows(predicates[0]),
    };
  }

  return {
    column,
    columnType,
    predicate: {
      op: isKeep
        ? CompoundColumnPredicateOp.OR
        : CompoundColumnPredicateOp.AND,
      args: predicates,
    },
  };
}

/**
 * Builds the `ColumnFilter` for a date/datetime selection on `field`.
 *
 * The predicate comes from `getDateFilterPredicate`; the filter's column is
 * the source column that `field` maps to according to `derived`.
 *
 * @param field - Chart field name (possibly a derived field).
 * @param start - Range start. It is cast to `string` without a runtime check,
 *   so callers are expected to pass strings.
 * @param end - Range end; same expectation as `start`.
 * @param filterType - Whether matching rows are kept or excluded.
 * @param columnType - Type of the filtered column; copied onto the filter.
 * @param derived - Derived-field descriptors.
 * @param timestampFilteringEnabled - Forwarded to `getDateFilterPredicate`.
 */
function getDateFilter(
  field: string,
  start: string | number | boolean | undefined,
  end: string | number | boolean | undefined,
  filterType: FilterType,
  columnType: FilledDynamicValueTableColumnType,
  derived: SelectionDerived,
  timestampFilteringEnabled: boolean,
): ColumnFilter {
  const predicate = getDateFilterPredicate(
    field,
    [start as string, end as string],
    filterType,
    derived,
    columnType,
    timestampFilteringEnabled,
  );

  return {
    column: maybeRenamed(field, derived) as DisplayTableColumnId,
    columnType,
    predicate,
  };
}

/**
 * Builds the predicate for a single selected numeric value.
 *
 * If `derived` contains a "bin" entry for `field` with a bins config whose
 * `[start, stop]` range contains `value`, the predicate covers the whole bin
 * that `value` falls into:
 * - KEEP: `>= lowerEdge AND < upperEdge` (`<=` for the last bin);
 * - otherwise: `< lowerEdge OR >= upperEdge` (`>` for the last bin).
 *
 * Without such an entry the value is compared exactly (EQ for KEEP, NEQ
 * otherwise).
 */
function maybeBinNumericPointSelection(
  field: string,
  value: number,
  filterType: FilterType,
  derived: SelectionDerived,
): ColumnPredicate {
  const isKeep = filterType === FilterType.KEEP;

  for (const d of derived) {
    if (d.type !== "bin" || d.derivedName !== field || d.binsConfig == null) {
      continue;
    }

    const { start, step, stop } = d.binsConfig;
    // Written as a negated "in range" test so that NaN never counts as inside.
    if (!(start <= value && value <= stop)) {
      continue;
    }

    const binIndex = Math.floor((value - start) / step);
    const lowerEdge = start + step * binIndex;
    const upperEdge = start + step * (binIndex + 1);
    // Only the last bin includes its upper edge.
    const isLastBin = Math.round((stop - start) / step) === binIndex + 1;

    if (!isKeep) {
      return {
        op: "OR",
        args: [
          { op: BinaryColumnPredicateOp.LT, arg: `${lowerEdge}` },
          {
            op: isLastBin
              ? BinaryColumnPredicateOp.GT
              : BinaryColumnPredicateOp.GTE,
            arg: `${upperEdge}`,
          },
        ],
      };
    }

    return {
      op: "AND",
      args: [
        { op: BinaryColumnPredicateOp.GTE, arg: `${lowerEdge}` },
        {
          op: isLastBin
            ? BinaryColumnPredicateOp.LTE
            : BinaryColumnPredicateOp.LT,
          arg: `${upperEdge}`,
        },
      ],
    };
  }

  // Not binned: match the exact value.
  return {
    op: isKeep ? BinaryColumnPredicateOp.EQ : BinaryColumnPredicateOp.NEQ,
    arg: `${value}`,
  };
}

/**
 * Merges filters so that each column appears exactly once.
 *
 * Filters are grouped by `column`. Within a group:
 * - the column type is the first non-null `columnType` seen;
 * - predicates that stringify identically (see `stringifyPredicate`) are
 *   kept once;
 * - a single remaining predicate is used as-is, otherwise the predicates are
 *   combined with OR (KEEP) / AND (otherwise).
 *
 * When `fieldGroups` is provided, each filter is additionally enriched from
 * the matching semantic column: `column` becomes its `columnId` (when
 * present), `fieldType` / `queryPath` are copied, and its `columnType` takes
 * precedence over the one collected from the filters.
 *
 * Groups are held in a `Map` rather than a plain object so that column names
 * that collide with `Object.prototype` members (e.g. `constructor`,
 * `toString`) are handled like any other name. Results are returned in the
 * order in which each column first appears in `filters`.
 *
 * @param filters - Filters to merge; may contain several per column.
 * @param filterType - Determines how multiple predicates are combined.
 * @param fieldGroups - Optional semantic field groups used for enrichment.
 */
function consolidateFilters(
  filters: ColumnFilter[],
  filterType: FilterType,
  fieldGroups?: SemanticAwareFieldGroup[],
): ColumnFilter[] {
  interface ColumnGroup {
    columnType: FilledDynamicValueTableColumnType | undefined;
    predicates: ColumnPredicate[];
    seenPredicates: Set<string>;
  }

  const groups = new Map<DisplayTableColumnId, ColumnGroup>();
  for (const filter of filters) {
    let group = groups.get(filter.column);
    if (group == null) {
      group = {
        columnType: undefined,
        predicates: [],
        seenPredicates: new Set<string>(),
      };
      groups.set(filter.column, group);
    }

    if (group.columnType == null) {
      group.columnType = filter.columnType;
    }

    const predicateStr = stringifyPredicate(filter.predicate);
    if (!group.seenPredicates.has(predicateStr)) {
      group.seenPredicates.add(predicateStr);
      group.predicates.push(filter.predicate);
    }
  }

  const consolidated: ColumnFilter[] = [];
  groups.forEach((group, column) => {
    const columnFilter: ColumnFilter = {
      column,
      columnType: group.columnType,
      predicate:
        group.predicates.length === 1
          ? nullthrows(group.predicates[0])
          : {
              op:
                filterType === FilterType.KEEP
                  ? CompoundColumnPredicateOp.OR
                  : CompoundColumnPredicateOp.AND,
              args: group.predicates,
            },
    };

    if (fieldGroups) {
      // add semantic properties if the consuming chart needs
      // to be semantically aware.
      const semanticColumn = getColumnFromColumnNameWithQueryPath(
        column,
        fieldGroups,
      );

      const semanticColumnId = semanticColumn?.columnId as DisplayTableColumnId;
      columnFilter.column = semanticColumnId ?? column;
      columnFilter.fieldType = semanticColumn?.fieldType;
      columnFilter.queryPath = semanticColumn?.queryPath;
      // Use explore's column type if available
      columnFilter.columnType =
        semanticColumn?.columnType ?? columnFilter.columnType;
    }

    consolidated.push(columnFilter);
  });

  return consolidated;
}

/**
 * Serializes a predicate into a string key used to detect duplicates.
 *
 * The key is JSON so that it is unambiguous: the operator and argument(s) are
 * encoded as separate, escaped JSON values. Plain string concatenation would
 * let different predicates share a key — for example a list argument
 * `["a,b"]` and `["a", "b"]` both coerce to `a,b`, and an argument containing
 * the separator could imitate a nested predicate.
 *
 * Compound predicates are serialized recursively, preserving argument order.
 */
function stringifyPredicate(predicate: ColumnPredicate): string {
  if (UnaryColumnPredicate.guard(predicate)) {
    return JSON.stringify([predicate.op]);
  } else if (
    BinaryColumnPredicate.guard(predicate) ||
    ListBinaryColumnPredicate.guard(predicate)
  ) {
    return JSON.stringify([predicate.op, predicate.arg]);
  } else if (CompoundColumnPredicate.guard(predicate)) {
    return JSON.stringify([
      predicate.op,
      predicate.args.map((arg) => stringifyPredicate(arg)),
    ]);
  } else {
    assertNever(predicate, predicate);
  }
}

/**
 * Resolves a chart field name to the name of its source column.
 *
 * Returns the `sourceName` of the first `derived` entry whose `derivedName`
 * equals `field`. When no entry matches, the field itself is returned after
 * `unescapeVegaString` is applied.
 */
function maybeRenamed(field: string, derived: SelectionDerived): string {
  const match = derived.find((d) => d.derivedName === field);
  if (match != null) {
    // the source name has already been unescaped where needed so
    // dont need to do it again here
    return match.sourceName;
  }
  // No derived field found
  return unescapeVegaString(field);
}

/**
 * Computes the bounds and comparison operators for a numeric interval filter.
 *
 * If `derived` contains a "bin" entry for `field` with a bins config, the
 * first such entry is used to snap the interval outwards to bin edges: an end
 * lying within `[start, stop]` is moved to the enclosing bin edge (start
 * rounded down, stop rounded up); an end outside that range is left as-is.
 *
 * Without a bin entry, both ends are rounded by `roundNumericInterval` and
 * the interval is treated as inclusive on both sides.
 *
 * @returns `interval` with the bounds to filter on, and `ops` with the
 *   operators for the lower and upper bound respectively.
 */
function maybeBinNumericInterval(
  field: string,
  interval: VegaIntervalNumericRange,
  filterType: FilterType,
  derived: SelectionDerived,
): {
  interval: VegaIntervalNumericRange;
  ops: [BinaryColumnPredicateOp, BinaryColumnPredicateOp];
} {
  const isKeep = filterType === FilterType.KEEP;

  // We have a numeric interval that is eligible to be binned
  for (const d of derived) {
    if (d.type !== "bin" || d.derivedName !== field || d.binsConfig == null) {
      continue;
    }

    const { start, step, stop } = d.binsConfig;
    const [inputStart, inputStop] = interval;
    const isWithinBins = (v: number): boolean => start <= v && v <= stop;

    // Bin start
    const snappedStart = isWithinBins(inputStart)
      ? start + step * Math.floor((inputStart - start) / step)
      : inputStart;

    // Bin stop
    let snappedStop = inputStop;
    let includeLastEdge: boolean;
    if (isWithinBins(inputStop)) {
      const stopBinIndex = Math.ceil((inputStop - start) / step);
      snappedStop = start + step * stopBinIndex;
      includeLastEdge = Math.round((stop - start) / step) === stopBinIndex;
    } else {
      includeLastEdge = inputStop >= stop;
    }

    // Binned interval is only inclusive on the right side if
    // stop falls in the last bin. This matches Vega.
    return {
      interval: [snappedStart, snappedStop],
      ops: isKeep
        ? [
            BinaryColumnPredicateOp.GTE,
            includeLastEdge
              ? BinaryColumnPredicateOp.LTE
              : BinaryColumnPredicateOp.LT,
          ]
        : [
            BinaryColumnPredicateOp.LT,
            includeLastEdge
              ? BinaryColumnPredicateOp.GT
              : BinaryColumnPredicateOp.GTE,
          ],
    };
  }

  // Un-binned interval is inclusive on both sides
  return {
    interval: roundNumericInterval(interval),
    ops: isKeep
      ? [BinaryColumnPredicateOp.GTE, BinaryColumnPredicateOp.LTE]
      : [BinaryColumnPredicateOp.LT, BinaryColumnPredicateOp.GT],
  };
}

/**
 * Builds the predicate for a date/datetime selection over `interval`.
 *
 * 1. Both ends are parsed with `moment`. If `derived` has a "timeunit" entry
 *    for `field`, the range is widened to the start/end of that time unit
 *    (week boundaries depend on `isMondayStartWeek`). Time units not listed
 *    in the switch leave the range untouched.
 * 2. The ends are formatted as local timestamps and shortened to a date
 *    unless timestamp granularity is being preserved.
 * 3. If both formatted ends are equal, or the range is a day-or-larger bin
 *    spanning zero whole days, an equality predicate is returned. Otherwise
 *    KEEP yields DATE_BETWEEN and exclusion yields DATE_BEFORE OR DATE_AFTER.
 */
function getDateFilterPredicate(
  field: string,
  interval: VegaIntervalDatetimeRange,
  filterType: FilterType,
  derived: SelectionDerived,
  columnType: FilledDynamicValueTableColumnType,
  timestampFilteringEnabled: boolean,
): ColumnPredicate {
  const [rawStart, rawStop] = interval;
  let outputStart = moment(rawStart);
  let outputStop = moment(rawStop);
  let isDayBinOrGreater = false;

  const timeUnitSelection: TimeunitBinnedField | undefined = derived.find(
    (d): d is TimeunitBinnedField =>
      d.derivedName === field && d.type === "timeunit",
  );

  if (timeUnitSelection) {
    // Pick the moment unit to snap to; the snap itself happens once below.
    let snapUnit: moment.unitOfTime.StartOf | undefined;
    switch (timeUnitSelection.timeUnit) {
      case "year":
        snapUnit = "year";
        isDayBinOrGreater = true;
        break;
      case "yearquarter":
        snapUnit = "quarter";
        isDayBinOrGreater = true;
        break;
      case "yearmonth":
        snapUnit = "month";
        isDayBinOrGreater = true;
        break;
      case "yearweek":
        snapUnit = isMondayStartWeek(
          timeUnitSelection.timeUnit,
          timeUnitSelection.timeUnitConfig,
        )
          ? "isoWeek"
          : "week";
        isDayBinOrGreater = true;
        break;
      case "yearmonthdate":
        snapUnit = "date";
        isDayBinOrGreater = true;
        break;
      case "yearmonthdatehours":
        snapUnit = "hour";
        break;
      case "yearmonthdatehoursminutes":
        snapUnit = "minute";
        break;
      case "yearmonthdatehoursminutesseconds":
        snapUnit = "second";
        break;
    }

    if (snapUnit != null) {
      outputStart = moment(outputStart).startOf(snapUnit);
      outputStop = moment(outputStop).endOf(snapUnit);
    }
  }

  const isDatetimeColumn =
    columnType === FilledDynamicValueTableColumnType.DATETIME ||
    columnType === FilledDynamicValueTableColumnType.DATETIMETZ;
  const isTimestampPointFilter = outputStart.isSame(outputStop);

  // We want to preserve the timestamp granularity for point filters on
  // timestamp columns, but otherwise we shorten to date since we don't
  // support timestamp filtering on date operations.
  const preserveTimestamp =
    !isDayBinOrGreater &&
    (timestampFilteringEnabled || isTimestampPointFilter) &&
    isDatetimeColumn;

  // Shorten if not preserving the timestamp, or if both ends of the interval
  // are date-only (not ISO timestamps).
  const bothEndsAreDateOnly =
    !isIsoTimestamp(rawStart) && !isIsoTimestamp(rawStop);
  const shorten = !preserveTimestamp || bothEndsAreDateOnly;

  const startText = maybeShortenIsoDate(toLocalTimestamp(outputStart), shorten);
  const stopText = maybeShortenIsoDate(toLocalTimestamp(outputStop), shorten);

  const isKeep = filterType === FilterType.KEEP;
  const isSingleDayRange =
    outputStart.diff(outputStop, "days") === 0 && isDayBinOrGreater;

  if (startText === stopText || isSingleDayRange) {
    // If the timestamp start/end values are the same, then we need to special
    // case filtering on timestamp columns and use the `EQ` operator instead so
    // that we preserve legacy behavior of filtering on timestamp values in
    // dialects that support timestamp + string comparison. Once `DATE_EQUAL`
    // supports filtering on timestamps, we can remove this.
    const useRawComparison = preserveTimestamp && !timestampFilteringEnabled;
    let op: BinaryColumnPredicateOp;
    if (useRawComparison) {
      op = isKeep ? BinaryColumnPredicateOp.EQ : BinaryColumnPredicateOp.NEQ;
    } else {
      op = isKeep
        ? BinaryColumnPredicateOp.DATE_EQUAL
        : BinaryColumnPredicateOp.DATE_NOT_EQUAL;
    }

    return {
      op,
      // When comparing specific timestamp values, we want to filter based on
      // the exact raw value
      arg: preserveTimestamp ? rawStart : startText,
    };
  }

  if (!isKeep) {
    return {
      op: CompoundColumnPredicateOp.OR,
      args: [
        { op: BinaryColumnPredicateOp.DATE_BEFORE, arg: startText },
        { op: BinaryColumnPredicateOp.DATE_AFTER, arg: stopText },
      ],
    };
  }

  const outputInterval: [string, string] = [startText, stopText];
  return {
    op: ListBinaryColumnPredicateOp.DATE_BETWEEN,
    arg: outputInterval,
  };
}

/**
 * Rounds both ends of a numeric interval to three significant digits.
 */
function roundNumericInterval(interval: [number, number]): [number, number] {
  const toThreeSignificantDigits = (value: number): number =>
    Number(value.toPrecision(3));
  return [
    toThreeSignificantDigits(interval[0]),
    toThreeSignificantDigits(interval[1]),
  ];
}

/**
 * Formats a moment as an ISO timestamp with the UTC offset suffix removed.
 *
 * The offset-preserving ISO string is cut at the last `+` when the string
 * contains one, otherwise at the last `-`.
 */
export function toLocalTimestamp(date: moment.Moment): string {
  const withOffset = date.toISOString(true);
  // need to account for + and - timezone offsets
  const offsetSign = withOffset.includes("+") ? "+" : "-";
  return withOffset.substring(0, withOffset.lastIndexOf(offsetSign));
}

/**
 * Returns whether `s` has the shape of an ISO-8601 timestamp: a date, a `T`
 * or space separator, `HH:mm:ss`, optional fractional seconds, and no
 * timezone. This is a shape check only; it does not validate the calendar
 * date.
 */
export function isIsoTimestamp(s: string): boolean {
  return /^\d{4}-[01]\d-[0-3]\d[T ][0-2]\d:[0-5]\d:[0-5]\d(\.\d+)?$/.test(s);
}

/**
 * Returns whether `s` has the shape of an ISO-8601 date (`YYYY-MM-DD`) with no
 * time component. This is a shape check only; it does not validate the
 * calendar date.
 */
export function isIsoDate(s: string): boolean {
  return /^\d{4}-[01]\d-[0-3]\d$/.test(s);
}

/**
 * Returns whether `s` is accepted by `isIsoTimestamp` or by `isIsoDate`.
 */
function isIsoDateOrTimestamp(s: string): boolean {
  if (isIsoTimestamp(s)) {
    return true;
  }
  return isIsoDate(s);
}

/**
 * Normalizes a date string for use as a filter argument.
 *
 * - Strings that `moment` cannot parse are returned unchanged.
 * - With `shorten`, the result is the date part only (`YYYY-MM-DD`).
 * - Otherwise the result is formatted to whole seconds when the millisecond
 *   component is zero, and the input is returned unchanged when it is not.
 */
function maybeShortenIsoDate(s: string, shorten: boolean): string {
  const parsed = moment(s);
  if (!parsed.isValid()) {
    return s;
  }

  if (shorten) {
    return parsed.format("YYYY-MM-DD");
  }

  return parsed.millisecond() === 0
    ? parsed.format("YYYY-MM-DDTHH:mm:ss")
    : s;
}

/**
 * Converts a chart timestamp from the display timezone to UTC.
 *
 * Time-based chart selection values are in the timezone of the chart and need
 * to be converted to UTC for filtering. Values that are not ISO timestamps
 * (per `isIsoTimestamp`), such as date-only strings, are returned unchanged.
 * The result is formatted to whole seconds, so fractional seconds are not
 * carried over.
 */
function timezoneAdjust(
  timeValue: string,
  displayTimezone: TimezoneName,
): string {
  if (!isIsoTimestamp(timeValue)) {
    return timeValue;
  }
  const inDisplayTimezone = moment.tz(timeValue, displayTimezone);
  return inDisplayTimezone.utc().format("YYYY-MM-DDTHH:mm:ss");
}
