import {useMemo, useReducer} from "react";
import {createSlice, PayloadAction} from "@reduxjs/toolkit";
import {enableMapSet, castDraft} from "immer";

// Immer's Map/Set support is opt-in; the selection state below stores Sets
// (allItemsSet / selectedSet / disabledSet), so the plugin is enabled once at
// module load, before any reducer runs.
enableMapSet();

/**
 * Arguments accepted by `useSelection`, used to seed its initial state.
 *
 * Only the values seen on the first render end up in the state (they feed the
 * `useReducer` initial state); later changes must be pushed in through the
 * `setAllItems` / `setSelectedItems` / `setDisabledItems` actions.
 */
interface SelectionHandlerArgs<T> {
    /** Every item that can take part in the selection. */
    allItems: T[]
    /** Items selected initially; defaults to none. */
    selectedItems?: T[]
    /** Items disabled initially; defaults to none. */
    disabledItems?: T[]
}

/** The raw Sets that make up the selection state; the derived flags are computed from these. */
interface SelectionStateSets<T> {
    /** Every known item. */
    allItemsSet: Set<T>
    /** Items currently selected. */
    selectedSet: Set<T>
    /** Items marked disabled; every member of `allItemsSet` not in here counts as enabled. */
    disabledSet: Set<T>
}

/** Full selection state: the raw Sets plus flags derived from them after every action. */
export interface UseSelectionState<T> extends SelectionStateSets<T> {
    /** True when at least one item is enabled and every enabled item is selected. */
    isAllSelected: boolean
    /** True when at least one enabled item is selected. */
    isAnySelected: boolean
    /** True when at least one item in `allItemsSet` is not disabled. */
    isAnyEnabled: boolean
}

/**
 * Derive the aggregate selection flags from the raw Sets.
 *
 * Only enabled items (members of `allItemsSet` that are not in `disabledSet`)
 * are considered; selected items outside that group never affect the flags.
 */
const selectedFlags = <T>({allItemsSet, selectedSet, disabledSet}: SelectionStateSets<T>) => {
    const candidates = enabledItems({allItemsSet, disabledSet});
    const isSelected = (item: T) => selectedSet.has(item);
    const isAnyEnabled = candidates.length > 0;

    return {
        // `every` is vacuously true on an empty list, so require at least one enabled item.
        isAllSelected: isAnyEnabled && candidates.every(isSelected),
        isAnySelected: candidates.some(isSelected),
        isAnyEnabled,
    };
};

/**
 * Build the hook's state from the caller-supplied arrays.
 *
 * Missing `selectedItems` / `disabledItems` default to empty. An item listed
 * as both selected and disabled starts out unselected, matching
 * `setDisabled` / `setDisabledItems`, which drop disabled items from the
 * selection.
 *
 * @returns the three Sets plus the flags computed by `selectedFlags`.
 */
const createInitialState = <T>({allItems, selectedItems = [], disabledItems = []}: SelectionHandlerArgs<T>): UseSelectionState<T> => {
    const allItemsSet = new Set(allItems);
    const disabledSet = new Set(disabledItems);
    // Keep disabled items out of the selection from the very first state.
    const selectedSet = new Set(selectedItems.filter((item) => !disabledSet.has(item)));

    return {
        allItemsSet,
        selectedSet,
        disabledSet,
        ...selectedFlags({allItemsSet, selectedSet, disabledSet}),
    };
};

/**
 * List every item of `allItemsSet` that is not in `disabledSet`, in the
 * iteration order of `allItemsSet`.
 */
const enabledItems = <T>({allItemsSet, disabledSet}: {allItemsSet: Set<T>, disabledSet: Set<T>}) => {
    const result: T[] = [];
    for (const candidate of allItemsSet) {
        if (!disabledSet.has(candidate)) {
            result.push(candidate);
        }
    }
    return result;
};

/**
 * True when both Sets hold exactly the same members (compared with `Set.has`,
 * i.e. SameValueZero identity). Membership is probed by iterating `set1` and
 * looking each member up in `set2`.
 */
const isSetEqual = <T>(set1: Set<T>, set2: Set<T>): boolean => {
    if (set1.size !== set2.size) {
        return false;
    }
    for (const member of set1) {
        if (!set2.has(member)) {
            return false;
        }
    }
    return true;
};

/**
 * Create the Redux Toolkit slice implementing the selection reducer.
 *
 * `initialState` is handed to `createSlice` as the slice's default state; none
 * of the case reducers read it. Every case reducer mutates an Immer draft, and
 * a catch-all matcher recomputes the derived flags (`isAllSelected`,
 * `isAnySelected`, `isAnyEnabled`) after each action.
 *
 * Disabled items are kept out of the selection: disabling an item deselects
 * it, and requests to select a disabled item are ignored.
 */
const createSelectionSlice = <T>(initialState: UseSelectionState<T>) => {
    return createSlice({
        name: "selection",
        initialState,
        reducers: {
            /** Replace the set of known items (no-op when the contents are unchanged). */
            setAllItems(state, action: PayloadAction<{items: T[]}>) {
                const newAllItemsSet = castDraft(new Set(action.payload.items));

                // Iterate the plain Set and probe the draft: iterating an Immer draft
                // Set may yield draft proxies for object items, which would never be
                // found by identity in a plain Set.
                if (!isSetEqual(newAllItemsSet, state.allItemsSet)) {
                    state.allItemsSet.clear();
                    action.payload.items.forEach((item) => state.allItemsSet.add(castDraft(item)));
                }
                // NOTE(review): selectedSet/disabledSet are not pruned of items that
                // disappeared from allItemsSet; the derived flags ignore such items,
                // but the sets are still exposed to callers. Confirm this is intended.
            },
            /** Select or deselect a single item; selecting a disabled item is ignored. */
            setSelected(state, action: PayloadAction<{item: T, isSelected: boolean}>) {
                const item = castDraft(action.payload.item);

                if (!action.payload.isSelected) {
                    state.selectedSet.delete(item);
                } else if (!state.disabledSet.has(item)) {
                    // Same invariant as setDisabled/setDisabledItems: a disabled item
                    // is never part of the selection.
                    state.selectedSet.add(item);
                }
            },
            /** Replace the whole selection; disabled items in the payload are dropped. */
            setSelectedItems(state, action: PayloadAction<{items: T[]}>) {
                const selectableItems = action.payload.items
                    .filter((item) => !state.disabledSet.has(castDraft(item)));
                const newSelectedSet = castDraft(new Set(selectableItems));

                // Plain Set first, draft second (see setAllItems).
                if (!isSetEqual(newSelectedSet, state.selectedSet)) {
                    state.selectedSet.clear();
                    selectableItems.forEach((item) => state.selectedSet.add(castDraft(item)));
                }
            },
            /** Enable or disable a single item; disabling also deselects it. */
            setDisabled(state, action: PayloadAction<{item: T, isDisabled: boolean}>) {
                const item = castDraft(action.payload.item);

                if (action.payload.isDisabled) {
                    state.disabledSet.add(item);
                    // Mutating the draft Set is all Immer needs to produce a new Set;
                    // deleting an absent item is a harmless no-op.
                    state.selectedSet.delete(item);
                } else {
                    state.disabledSet.delete(item);
                }
            },
            /** Replace the disabled set and deselect every item in it. */
            setDisabledItems(state, action: PayloadAction<{items: T[]}>) {
                const newDisabledSet = castDraft(new Set(action.payload.items));

                // Plain Set first, draft second (see setAllItems).
                if (!isSetEqual(newDisabledSet, state.disabledSet)) {
                    state.disabledSet = newDisabledSet;
                    newDisabledSet.forEach((item) => state.selectedSet.delete(item));
                }
            },
            /**
             * If every enabled item is selected, clear the selection; otherwise
             * select exactly the enabled items.
             */
            toggleAllSelected(state) {
                state.selectedSet = new Set((state.isAllSelected) ? [] : enabledItems(state));
            },
        },
        extraReducers: (builder) => {
            // after every action, update the selected flags (isAllSelected, etc)
            builder
                .addMatcher(() => true, (state) => {
                    Object.assign(state, selectedFlags(state));
                })
        }
    })
};


type AnyFunction = (...args: any[]) => any;

/**
 * The slice's action creators after binding them to `dispatch`: they take the
 * same parameters, but calling one dispatches the action and returns nothing.
 */
type BoundActions<Actions> = {
    [K in keyof Actions]: Actions[K] extends AnyFunction
        ? (...args: Parameters<Actions[K]>) => void
        : never
};

/**
 * React hook tracking which of `args.allItems` are selected and/or disabled.
 *
 * `args` only seeds the state on the first render, like any `useReducer`
 * initial state. Later changes must be pushed in through `setAllItems`,
 * `setSelectedItems` or `setDisabledItems`.
 *
 * @returns the current selection state merged with the slice's actions, each
 *   already bound to this hook's `dispatch`. The bound actions keep the same
 *   identity across renders.
 */
function useSelection<T>(args: SelectionHandlerArgs<T>) {
    // No case reducer reads the slice's initialState; it only serves as the
    // default for an undefined state. useReducer always passes its own state.
    // So one slice per hook instance is enough. Not depending on `args` keeps
    // the slice, and therefore the bound actions below, stable even when callers
    // pass a fresh `args` object on every render.
    const selectionSlice = useMemo(
        () => createSelectionSlice<T>(createInitialState<T>({allItems: []})),
        []
    );
    // Lazy initializer: the real initial state is built once, on mount,
    // instead of being recomputed on every render.
    const [state, dispatch] = useReducer(
        selectionSlice.reducer,
        args,
        (initialArgs: SelectionHandlerArgs<T>) => createInitialState(initialArgs)
    );

    // Wrap each slice action creator in a call to dispatch so callers can
    // invoke e.g. `setSelected({item, isSelected: true})` directly.
    const actions = useMemo(() => {
        const bindToDispatch = <Func extends AnyFunction>(actionCreator: Func) =>
            (...actionArgs: Parameters<Func>): void => {
                dispatch(actionCreator(...actionArgs));
            };
        // Object.fromEntries loses the per-key types. The cast restores them with
        // the bound (void-returning) signatures, rather than claiming the results
        // are still action creators.
        return Object.fromEntries(
            Object.entries(selectionSlice.actions)
                .map(([name, actionCreator]) => [name, bindToDispatch(actionCreator)])
        ) as BoundActions<typeof selectionSlice.actions>;
    }, [selectionSlice, dispatch]);

    return {...state, ...actions} as const;
}

/**
 * Type of the object returned by `useSelection<T>`: the selection state merged
 * with the slice actions wrapped in `dispatch`.
 */
export type UseSelection<T> = ReturnType<typeof useSelection<T>>;

export default useSelection;