/*
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License").
You may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import { PayloadAction, createAsyncThunk, createSlice } from '@reduxjs/toolkit';
import {
  InferenceApi,
  TranslateRequestDtoTargetLanguageEnum,
  ParaphraseRequestDtoTaskEnum,
  ParaphraseCustomRequestDtoTaskEnum,
} from '@ink-ai/insight-service-sdk';
import { RootState } from '.';
import { getApi } from '../common/requestHelper';
import { ChatStreamResponseDto } from './chat';

/** Every task the compose feature knows about. */
export const COMPOSE_TASKS = [
  'generate article',
  'paraphrase',
  'translate',
  'summarize',
  'continue writing',
  'custom',
] as const;

/** Union of the task names listed in `COMPOSE_TASKS`. */
export type ComposeTasksType = typeof COMPOSE_TASKS[number];

/**
 * Compose tasks that `startCompose` sends to the generic paraphrase endpoint
 * (each is translated to a paraphrase task value via `paraphraseTaskMap`).
 */
export const paraphraseComposeTasks = [
  'paraphrase',
  'continue writing',
  'summarize',
] as const;

/** Subset of compose tasks that rewrite existing text (everything except article generation). */
export const REWRITE_COMPOSE_TASKS = [
  'paraphrase',
  'translate',
  'summarize',
  'continue writing',
  'custom',
] as const;

/**
 * Translates each paraphrase-backed compose task into the `task` value that
 * `inferenceApi.paraphrase` is called with.
 */
const paraphraseTaskMap: Record<
  (typeof paraphraseComposeTasks)[number],
  ParaphraseRequestDtoTaskEnum
> = {
  paraphrase: ParaphraseRequestDtoTaskEnum.Revise,
  'continue writing': ParaphraseRequestDtoTaskEnum.ContinueWriting,
  summarize: ParaphraseRequestDtoTaskEnum.Summary,
};

/**
 * Parameters for one `startCompose` run. Stored in the slice so a run can be
 * repeated with the same inputs.
 */
type TaskParams = {
  /** Which compose task to run. */
  task: ComposeTasksType;
  /** Input text the task operates on. */
  text: string;

  // --- read by the 'translate' task ---
  sourceLanguage?: TranslateRequestDtoTargetLanguageEnum;
  targetLanguage?: TranslateRequestDtoTargetLanguageEnum;
  glossaryIdList?: string[];
  referenceStoreIdList?: string[];

  // --- read by the 'custom' task ---
  userPrompt?: string;
  referenceIdList?: string[];
};

/** Shape of the compose slice state. */
export type ComposeState = {
  /** Current contents of the compose input. */
  inputtedText: string;
  /** Set by `inputText` when the user enters a non-empty value; reset by `clearText` and `syncSelectedText`. */
  textEdited: boolean;
  /** Task used by `startCompose` when it is dispatched without params. */
  task: ComposeTasksType;
  sourceLanguage: TranslateRequestDtoTargetLanguageEnum;
  targetLanguage: TranslateRequestDtoTargetLanguageEnum;
  selectedGlossaries: string[];
  selectedReferenceStores: string[];
  /** Output assembled from the streamed chunks of the running task. */
  generatedText: string;
  isGenerating: boolean;
  /** ID of the running task; stream chunks carrying any other ID are ignored. */
  taskId: string;
  /** Parameters saved from the last parameterless `startCompose` run. */
  taskParams: TaskParams | undefined;
};

const initialState: ComposeState = {
  inputtedText: '',
  textEdited: false,
  task: 'generate article',
  sourceLanguage: TranslateRequestDtoTargetLanguageEnum.ZhCn,
  targetLanguage: TranslateRequestDtoTargetLanguageEnum.EnUs,
  selectedGlossaries: [],
  selectedReferenceStores: [],
  generatedText: '',
  isGenerating: false,
  taskId: '',
  taskParams: undefined,
};

/** Compose task names served by the generic paraphrase endpoint. */
type ParaphraseComposeTask = (typeof paraphraseComposeTasks)[number];

/**
 * Type guard narrowing `task` to one of `paraphraseComposeTasks`, so it can
 * index `paraphraseTaskMap` without an `any` cast.
 */
const isParaphraseComposeTask = (task: string): task is ParaphraseComposeTask =>
  (paraphraseComposeTasks as readonly string[]).includes(task);

/**
 * Asynchronous thunk that starts a compose task on the inference service.
 *
 * When dispatched without `params`, the request is built from the current
 * compose state, and the resolved parameters are saved with `updateTaskParam`.
 * Dispatching `startCompose` again with those saved params re-runs the same
 * request even if the UI selections have changed since.
 *
 * @param params - Explicit task parameters (e.g. for a re-run), or `undefined`
 *   to use the current compose state.
 * @returns The `id` returned by the service for the started task, or `null`
 *   when the task has no endpoint here (e.g. `'generate article'`).
 */
export const startCompose = createAsyncThunk(
  'compose/StartCompose',
  async (params: TaskParams | undefined, { getState, dispatch }) => {
    const state = getState() as RootState;
    const inferenceApi = await getApi(InferenceApi);
    const task = params?.task ?? state.compose.task;

    // No params: this is a fresh run. Save every value that influences the
    // request. Otherwise a rerun would fall back to whatever languages,
    // glossaries and reference stores are selected at rerun time.
    if (!params) {
      dispatch(
        compose.actions.updateTaskParam({
          task: state.compose.task,
          text: state.compose.inputtedText,
          targetLanguage: state.compose.targetLanguage,
          sourceLanguage: state.compose.sourceLanguage,
          glossaryIdList: state.compose.selectedGlossaries,
          referenceStoreIdList: state.compose.selectedReferenceStores,
        }),
      );
    }

    const text = params?.text ?? state.compose.inputtedText;

    if (task === 'translate') {
      const res = await inferenceApi.translate({
        text,
        targetLanguage: params?.targetLanguage ?? state.compose.targetLanguage,
        sourceLanguage: params?.sourceLanguage ?? state.compose.sourceLanguage,
        instanceId: state.auth.instanceId,
        glossaries:
          params?.glossaryIdList ?? state.compose.selectedGlossaries ?? [],
        referenceStores:
          params?.referenceStoreIdList ??
          state.compose.selectedReferenceStores ??
          [],
      });
      return res.data.id;
    } else if (task === 'custom') {
      const res = await inferenceApi.paraphraseCustom({
        text,
        userPrompt: params?.userPrompt ?? '',
        referenceIdList: params?.referenceIdList ?? [],
        instanceId: state.auth.instanceId,
        task: ParaphraseCustomRequestDtoTaskEnum.Revise,
      });
      return res.data.id;
    } else if (isParaphraseComposeTask(task)) {
      const res = await inferenceApi.paraphrase({
        text,
        task: paraphraseTaskMap[task],
        instanceId: state.auth.instanceId,
      });
      return res.data.id;
    }
    // No endpoint for this task; the slice treats a null result as "nothing started".
    return null;
  },
);

/**
 * Redux slice for the compose feature. It holds the input text, the selected
 * task and its options, and the output streamed back for the running task.
 */
export const compose = createSlice({
  name: 'compose',
  initialState,
  reducers: {
    /**
     * Applies one streamed chunk. Chunks whose `id` differs from `taskId` are
     * dropped. Non-final chunks are appended to `generatedText`. A final chunk
     * replaces the text (unless its text is exactly 'Stream finished') and
     * clears `isGenerating`.
     */
    updateGeneratedText: (
      state,
      { payload }: PayloadAction<ChatStreamResponseDto>,
    ) => {
      if (payload.id !== state.taskId) {
        console.warn('Id mismatch, ignore text generation result');
        return;
      }
      if (payload.isFinal) {
        if (payload.text !== 'Stream finished') {
          state.generatedText = payload.text;
        }
        state.isGenerating = false;
      } else {
        state.generatedText += payload.text;
      }
    },
    /** Saves the parameters of a run so it can be re-run later. */
    updateTaskParam: (state, { payload }: PayloadAction<TaskParams>) => {
      state.taskParams = payload;
    },
    /** User edit of the input; any non-empty value marks the text as edited. */
    inputText: (state, { payload }: PayloadAction<string>) => {
      state.inputtedText = payload;
      state.textEdited = payload !== '';
    },
    /** Stops tracking the running task; further chunks for it are ignored. */
    stopGenerating: (state) => {
      state.isGenerating = false;
      state.taskId = '';
    },
    clearText: (state) => {
      state.inputtedText = '';
      state.textEdited = false;
    },
    /** Resets the whole slice to its initial state. */
    clearAll: () => {
      return initialState;
    },
    /** Replaces the input with externally selected text (not a user edit). */
    syncSelectedText: (state, { payload }: PayloadAction<string>) => {
      state.inputtedText = payload;
      state.textEdited = false;
    },
    setTargetLanguage: (
      state,
      { payload }: PayloadAction<TranslateRequestDtoTargetLanguageEnum>,
    ) => {
      state.targetLanguage = payload;
    },
    setSourceLanguage: (
      state,
      { payload }: PayloadAction<TranslateRequestDtoTargetLanguageEnum>,
    ) => {
      state.sourceLanguage = payload;
    },
    setSelectedGlossaries(state, action: PayloadAction<string[]>) {
      state.selectedGlossaries = action.payload;
    },
    setSelectedReferenceStores(state, action: PayloadAction<string[]>) {
      state.selectedReferenceStores = action.payload;
    },
    setTask: (state, { payload }: PayloadAction<ComposeTasksType>) => {
      state.task = payload;
    },
  },
  extraReducers: (builder) => {
    builder.addCase(startCompose.pending, (state) => {
      state.isGenerating = true;
      state.generatedText = '';
      // Forget the previous task's ID while the new request is in flight.
      // Otherwise late chunks from an older stream would still match and be
      // appended to the freshly cleared output, or end generation early.
      state.taskId = '';
    });
    builder.addCase(startCompose.fulfilled, (state, { payload }) => {
      if (payload == null) {
        // No task was started, so there is no ID for updateGeneratedText to
        // match and no final chunk will ever clear the flag; clear it here.
        state.isGenerating = false;
        state.taskId = '';
        return;
      }
      state.taskId = payload;
    });
    builder.addCase(startCompose.rejected, (state) => {
      state.isGenerating = false;
    });
  },
});
