merge main
This commit is contained in:
commit
6667ee1c7f
|
@ -301,6 +301,14 @@ iflytek Api Key.
|
||||||
|
|
||||||
iflytek Api Secret.
|
iflytek Api Secret.
|
||||||
|
|
||||||
|
### `CHATGLM_API_KEY` (optional)
|
||||||
|
|
||||||
|
ChatGLM Api Key.
|
||||||
|
|
||||||
|
### `CHATGLM_URL` (optional)
|
||||||
|
|
||||||
|
ChatGLM Api Url.
|
||||||
|
|
||||||
### `HIDE_USER_API_KEY` (optional)
|
### `HIDE_USER_API_KEY` (optional)
|
||||||
|
|
||||||
> Default: Empty
|
> Default: Empty
|
||||||
|
|
|
@ -184,6 +184,13 @@ ByteDance Api Url.
|
||||||
|
|
||||||
讯飞星火Api Secret.
|
讯飞星火Api Secret.
|
||||||
|
|
||||||
|
### `CHATGLM_API_KEY` (可选)
|
||||||
|
|
||||||
|
ChatGLM Api Key.
|
||||||
|
|
||||||
|
### `CHATGLM_URL` (可选)
|
||||||
|
|
||||||
|
ChatGLM Api Url.
|
||||||
|
|
||||||
|
|
||||||
### `HIDE_USER_API_KEY` (可选)
|
### `HIDE_USER_API_KEY` (可选)
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
import { NextRequest, NextResponse } from "next/server";
|
import { NextRequest, NextResponse } from "next/server";
|
||||||
import { getServerSideConfig } from "../config/server";
|
import { getServerSideConfig } from "../config/server";
|
||||||
import { OPENAI_BASE_URL, ServiceProvider } from "../constant";
|
import { OPENAI_BASE_URL, ServiceProvider } from "../constant";
|
||||||
import { isModelAvailableInServer } from "../utils/model";
|
|
||||||
import { cloudflareAIGatewayUrl } from "../utils/cloudflare";
|
import { cloudflareAIGatewayUrl } from "../utils/cloudflare";
|
||||||
|
import { getModelProvider, isModelAvailableInServer } from "../utils/model";
|
||||||
|
|
||||||
const serverConfig = getServerSideConfig();
|
const serverConfig = getServerSideConfig();
|
||||||
|
|
||||||
|
@ -71,7 +71,7 @@ export async function requestOpenai(req: NextRequest) {
|
||||||
.filter((v) => !!v && !v.startsWith("-") && v.includes(modelName))
|
.filter((v) => !!v && !v.startsWith("-") && v.includes(modelName))
|
||||||
.forEach((m) => {
|
.forEach((m) => {
|
||||||
const [fullName, displayName] = m.split("=");
|
const [fullName, displayName] = m.split("=");
|
||||||
const [_, providerName] = fullName.split("@");
|
const [_, providerName] = getModelProvider(fullName);
|
||||||
if (providerName === "azure" && !displayName) {
|
if (providerName === "azure" && !displayName) {
|
||||||
const [_, deployId] = (serverConfig?.azureUrl ?? "").split(
|
const [_, deployId] = (serverConfig?.azureUrl ?? "").split(
|
||||||
"deployments/",
|
"deployments/",
|
||||||
|
|
|
@ -120,6 +120,7 @@ import { createTTSPlayer } from "../utils/audio";
|
||||||
import { MsEdgeTTS, OUTPUT_FORMAT } from "../utils/ms_edge_tts";
|
import { MsEdgeTTS, OUTPUT_FORMAT } from "../utils/ms_edge_tts";
|
||||||
|
|
||||||
import { isEmpty } from "lodash-es";
|
import { isEmpty } from "lodash-es";
|
||||||
|
import { getModelProvider } from "../utils/model";
|
||||||
|
|
||||||
const localStorage = safeLocalStorage();
|
const localStorage = safeLocalStorage();
|
||||||
|
|
||||||
|
@ -148,7 +149,8 @@ export function SessionConfigModel(props: { onClose: () => void }) {
|
||||||
text={Locale.Chat.Config.Reset}
|
text={Locale.Chat.Config.Reset}
|
||||||
onClick={async () => {
|
onClick={async () => {
|
||||||
if (await showConfirm(Locale.Memory.ResetConfirm)) {
|
if (await showConfirm(Locale.Memory.ResetConfirm)) {
|
||||||
chatStore.updateCurrentSession(
|
chatStore.updateTargetSession(
|
||||||
|
session,
|
||||||
(session) => (session.memoryPrompt = ""),
|
(session) => (session.memoryPrompt = ""),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -173,7 +175,10 @@ export function SessionConfigModel(props: { onClose: () => void }) {
|
||||||
updateMask={(updater) => {
|
updateMask={(updater) => {
|
||||||
const mask = { ...session.mask };
|
const mask = { ...session.mask };
|
||||||
updater(mask);
|
updater(mask);
|
||||||
chatStore.updateCurrentSession((session) => (session.mask = mask));
|
chatStore.updateTargetSession(
|
||||||
|
session,
|
||||||
|
(session) => (session.mask = mask),
|
||||||
|
);
|
||||||
}}
|
}}
|
||||||
shouldSyncFromGlobal
|
shouldSyncFromGlobal
|
||||||
extraListItems={
|
extraListItems={
|
||||||
|
@ -345,12 +350,14 @@ export function PromptHints(props: {
|
||||||
|
|
||||||
function ClearContextDivider() {
|
function ClearContextDivider() {
|
||||||
const chatStore = useChatStore();
|
const chatStore = useChatStore();
|
||||||
|
const session = chatStore.currentSession();
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
className={styles["clear-context"]}
|
className={styles["clear-context"]}
|
||||||
onClick={() =>
|
onClick={() =>
|
||||||
chatStore.updateCurrentSession(
|
chatStore.updateTargetSession(
|
||||||
|
session,
|
||||||
(session) => (session.clearContextIndex = undefined),
|
(session) => (session.clearContextIndex = undefined),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -460,6 +467,7 @@ export function ChatActions(props: {
|
||||||
const navigate = useNavigate();
|
const navigate = useNavigate();
|
||||||
const chatStore = useChatStore();
|
const chatStore = useChatStore();
|
||||||
const pluginStore = usePluginStore();
|
const pluginStore = usePluginStore();
|
||||||
|
const session = chatStore.currentSession();
|
||||||
|
|
||||||
// switch themes
|
// switch themes
|
||||||
const theme = config.theme;
|
const theme = config.theme;
|
||||||
|
@ -476,10 +484,9 @@ export function ChatActions(props: {
|
||||||
const stopAll = () => ChatControllerPool.stopAll();
|
const stopAll = () => ChatControllerPool.stopAll();
|
||||||
|
|
||||||
// switch model
|
// switch model
|
||||||
const currentModel = chatStore.currentSession().mask.modelConfig.model;
|
const currentModel = session.mask.modelConfig.model;
|
||||||
const currentProviderName =
|
const currentProviderName =
|
||||||
chatStore.currentSession().mask.modelConfig?.providerName ||
|
session.mask.modelConfig?.providerName || ServiceProvider.OpenAI;
|
||||||
ServiceProvider.OpenAI;
|
|
||||||
const allModels = useAllModels();
|
const allModels = useAllModels();
|
||||||
const models = useMemo(() => {
|
const models = useMemo(() => {
|
||||||
const filteredModels = allModels.filter((m) => m.available);
|
const filteredModels = allModels.filter((m) => m.available);
|
||||||
|
@ -513,12 +520,9 @@ export function ChatActions(props: {
|
||||||
const dalle3Sizes: DalleSize[] = ["1024x1024", "1792x1024", "1024x1792"];
|
const dalle3Sizes: DalleSize[] = ["1024x1024", "1792x1024", "1024x1792"];
|
||||||
const dalle3Qualitys: DalleQuality[] = ["standard", "hd"];
|
const dalle3Qualitys: DalleQuality[] = ["standard", "hd"];
|
||||||
const dalle3Styles: DalleStyle[] = ["vivid", "natural"];
|
const dalle3Styles: DalleStyle[] = ["vivid", "natural"];
|
||||||
const currentSize =
|
const currentSize = session.mask.modelConfig?.size ?? "1024x1024";
|
||||||
chatStore.currentSession().mask.modelConfig?.size ?? "1024x1024";
|
const currentQuality = session.mask.modelConfig?.quality ?? "standard";
|
||||||
const currentQuality =
|
const currentStyle = session.mask.modelConfig?.style ?? "vivid";
|
||||||
chatStore.currentSession().mask.modelConfig?.quality ?? "standard";
|
|
||||||
const currentStyle =
|
|
||||||
chatStore.currentSession().mask.modelConfig?.style ?? "vivid";
|
|
||||||
|
|
||||||
const isMobileScreen = useMobileScreen();
|
const isMobileScreen = useMobileScreen();
|
||||||
|
|
||||||
|
@ -536,7 +540,7 @@ export function ChatActions(props: {
|
||||||
if (isUnavailableModel && models.length > 0) {
|
if (isUnavailableModel && models.length > 0) {
|
||||||
// show next model to default model if exist
|
// show next model to default model if exist
|
||||||
let nextModel = models.find((model) => model.isDefault) || models[0];
|
let nextModel = models.find((model) => model.isDefault) || models[0];
|
||||||
chatStore.updateCurrentSession((session) => {
|
chatStore.updateTargetSession(session, (session) => {
|
||||||
session.mask.modelConfig.model = nextModel.name;
|
session.mask.modelConfig.model = nextModel.name;
|
||||||
session.mask.modelConfig.providerName = nextModel?.provider
|
session.mask.modelConfig.providerName = nextModel?.provider
|
||||||
?.providerName as ServiceProvider;
|
?.providerName as ServiceProvider;
|
||||||
|
@ -547,7 +551,7 @@ export function ChatActions(props: {
|
||||||
: nextModel.name,
|
: nextModel.name,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}, [chatStore, currentModel, models]);
|
}, [chatStore, currentModel, models, session]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className={styles["chat-input-actions"]}>
|
<div className={styles["chat-input-actions"]}>
|
||||||
|
@ -614,7 +618,7 @@ export function ChatActions(props: {
|
||||||
text={Locale.Chat.InputActions.Clear}
|
text={Locale.Chat.InputActions.Clear}
|
||||||
icon={<BreakIcon />}
|
icon={<BreakIcon />}
|
||||||
onClick={() => {
|
onClick={() => {
|
||||||
chatStore.updateCurrentSession((session) => {
|
chatStore.updateTargetSession(session, (session) => {
|
||||||
if (session.clearContextIndex === session.messages.length) {
|
if (session.clearContextIndex === session.messages.length) {
|
||||||
session.clearContextIndex = undefined;
|
session.clearContextIndex = undefined;
|
||||||
} else {
|
} else {
|
||||||
|
@ -645,8 +649,8 @@ export function ChatActions(props: {
|
||||||
onClose={() => setShowModelSelector(false)}
|
onClose={() => setShowModelSelector(false)}
|
||||||
onSelection={(s) => {
|
onSelection={(s) => {
|
||||||
if (s.length === 0) return;
|
if (s.length === 0) return;
|
||||||
const [model, providerName] = s[0].split("@");
|
const [model, providerName] = getModelProvider(s[0]);
|
||||||
chatStore.updateCurrentSession((session) => {
|
chatStore.updateTargetSession(session, (session) => {
|
||||||
session.mask.modelConfig.model = model as ModelType;
|
session.mask.modelConfig.model = model as ModelType;
|
||||||
session.mask.modelConfig.providerName =
|
session.mask.modelConfig.providerName =
|
||||||
providerName as ServiceProvider;
|
providerName as ServiceProvider;
|
||||||
|
@ -684,7 +688,7 @@ export function ChatActions(props: {
|
||||||
onSelection={(s) => {
|
onSelection={(s) => {
|
||||||
if (s.length === 0) return;
|
if (s.length === 0) return;
|
||||||
const size = s[0];
|
const size = s[0];
|
||||||
chatStore.updateCurrentSession((session) => {
|
chatStore.updateTargetSession(session, (session) => {
|
||||||
session.mask.modelConfig.size = size;
|
session.mask.modelConfig.size = size;
|
||||||
});
|
});
|
||||||
showToast(size);
|
showToast(size);
|
||||||
|
@ -711,7 +715,7 @@ export function ChatActions(props: {
|
||||||
onSelection={(q) => {
|
onSelection={(q) => {
|
||||||
if (q.length === 0) return;
|
if (q.length === 0) return;
|
||||||
const quality = q[0];
|
const quality = q[0];
|
||||||
chatStore.updateCurrentSession((session) => {
|
chatStore.updateTargetSession(session, (session) => {
|
||||||
session.mask.modelConfig.quality = quality;
|
session.mask.modelConfig.quality = quality;
|
||||||
});
|
});
|
||||||
showToast(quality);
|
showToast(quality);
|
||||||
|
@ -738,7 +742,7 @@ export function ChatActions(props: {
|
||||||
onSelection={(s) => {
|
onSelection={(s) => {
|
||||||
if (s.length === 0) return;
|
if (s.length === 0) return;
|
||||||
const style = s[0];
|
const style = s[0];
|
||||||
chatStore.updateCurrentSession((session) => {
|
chatStore.updateTargetSession(session, (session) => {
|
||||||
session.mask.modelConfig.style = style;
|
session.mask.modelConfig.style = style;
|
||||||
});
|
});
|
||||||
showToast(style);
|
showToast(style);
|
||||||
|
@ -769,7 +773,7 @@ export function ChatActions(props: {
|
||||||
}))}
|
}))}
|
||||||
onClose={() => setShowPluginSelector(false)}
|
onClose={() => setShowPluginSelector(false)}
|
||||||
onSelection={(s) => {
|
onSelection={(s) => {
|
||||||
chatStore.updateCurrentSession((session) => {
|
chatStore.updateTargetSession(session, (session) => {
|
||||||
session.mask.plugin = s as string[];
|
session.mask.plugin = s as string[];
|
||||||
});
|
});
|
||||||
}}
|
}}
|
||||||
|
@ -812,7 +816,8 @@ export function EditMessageModal(props: { onClose: () => void }) {
|
||||||
icon={<ConfirmIcon />}
|
icon={<ConfirmIcon />}
|
||||||
key="ok"
|
key="ok"
|
||||||
onClick={() => {
|
onClick={() => {
|
||||||
chatStore.updateCurrentSession(
|
chatStore.updateTargetSession(
|
||||||
|
session,
|
||||||
(session) => (session.messages = messages),
|
(session) => (session.messages = messages),
|
||||||
);
|
);
|
||||||
props.onClose();
|
props.onClose();
|
||||||
|
@ -829,7 +834,8 @@ export function EditMessageModal(props: { onClose: () => void }) {
|
||||||
type="text"
|
type="text"
|
||||||
value={session.topic}
|
value={session.topic}
|
||||||
onInput={(e) =>
|
onInput={(e) =>
|
||||||
chatStore.updateCurrentSession(
|
chatStore.updateTargetSession(
|
||||||
|
session,
|
||||||
(session) => (session.topic = e.currentTarget.value),
|
(session) => (session.topic = e.currentTarget.value),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -990,7 +996,8 @@ function _Chat() {
|
||||||
prev: () => chatStore.nextSession(-1),
|
prev: () => chatStore.nextSession(-1),
|
||||||
next: () => chatStore.nextSession(1),
|
next: () => chatStore.nextSession(1),
|
||||||
clear: () =>
|
clear: () =>
|
||||||
chatStore.updateCurrentSession(
|
chatStore.updateTargetSession(
|
||||||
|
session,
|
||||||
(session) => (session.clearContextIndex = session.messages.length),
|
(session) => (session.clearContextIndex = session.messages.length),
|
||||||
),
|
),
|
||||||
fork: () => chatStore.forkSession(),
|
fork: () => chatStore.forkSession(),
|
||||||
|
@ -1061,7 +1068,7 @@ function _Chat() {
|
||||||
};
|
};
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
chatStore.updateCurrentSession((session) => {
|
chatStore.updateTargetSession(session, (session) => {
|
||||||
const stopTiming = Date.now() - REQUEST_TIMEOUT_MS;
|
const stopTiming = Date.now() - REQUEST_TIMEOUT_MS;
|
||||||
session.messages.forEach((m) => {
|
session.messages.forEach((m) => {
|
||||||
// check if should stop all stale messages
|
// check if should stop all stale messages
|
||||||
|
@ -1087,7 +1094,7 @@ function _Chat() {
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||||
}, []);
|
}, [session]);
|
||||||
|
|
||||||
// check if should send message
|
// check if should send message
|
||||||
const onInputKeyDown = (e: React.KeyboardEvent<HTMLTextAreaElement>) => {
|
const onInputKeyDown = (e: React.KeyboardEvent<HTMLTextAreaElement>) => {
|
||||||
|
@ -1118,7 +1125,8 @@ function _Chat() {
|
||||||
};
|
};
|
||||||
|
|
||||||
const deleteMessage = (msgId?: string) => {
|
const deleteMessage = (msgId?: string) => {
|
||||||
chatStore.updateCurrentSession(
|
chatStore.updateTargetSession(
|
||||||
|
session,
|
||||||
(session) =>
|
(session) =>
|
||||||
(session.messages = session.messages.filter((m) => m.id !== msgId)),
|
(session.messages = session.messages.filter((m) => m.id !== msgId)),
|
||||||
);
|
);
|
||||||
|
@ -1185,7 +1193,7 @@ function _Chat() {
|
||||||
};
|
};
|
||||||
|
|
||||||
const onPinMessage = (message: ChatMessage) => {
|
const onPinMessage = (message: ChatMessage) => {
|
||||||
chatStore.updateCurrentSession((session) =>
|
chatStore.updateTargetSession(session, (session) =>
|
||||||
session.mask.context.push(message),
|
session.mask.context.push(message),
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -1607,7 +1615,7 @@ function _Chat() {
|
||||||
title={Locale.Chat.Actions.RefreshTitle}
|
title={Locale.Chat.Actions.RefreshTitle}
|
||||||
onClick={() => {
|
onClick={() => {
|
||||||
showToast(Locale.Chat.Actions.RefreshToast);
|
showToast(Locale.Chat.Actions.RefreshToast);
|
||||||
chatStore.summarizeSession(true);
|
chatStore.summarizeSession(true, session);
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
|
@ -1711,14 +1719,17 @@ function _Chat() {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
chatStore.updateCurrentSession((session) => {
|
chatStore.updateTargetSession(
|
||||||
const m = session.mask.context
|
session,
|
||||||
.concat(session.messages)
|
(session) => {
|
||||||
.find((m) => m.id === message.id);
|
const m = session.mask.context
|
||||||
if (m) {
|
.concat(session.messages)
|
||||||
m.content = newContent;
|
.find((m) => m.id === message.id);
|
||||||
}
|
if (m) {
|
||||||
});
|
m.content = newContent;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
);
|
||||||
}}
|
}}
|
||||||
></IconButton>
|
></IconButton>
|
||||||
</div>
|
</div>
|
||||||
|
|
|
@ -7,6 +7,7 @@ import { ListItem, Select } from "./ui-lib";
|
||||||
import { useAllModels } from "../utils/hooks";
|
import { useAllModels } from "../utils/hooks";
|
||||||
import { groupBy } from "lodash-es";
|
import { groupBy } from "lodash-es";
|
||||||
import styles from "./model-config.module.scss";
|
import styles from "./model-config.module.scss";
|
||||||
|
import { getModelProvider } from "../utils/model";
|
||||||
|
|
||||||
export function ModelConfigList(props: {
|
export function ModelConfigList(props: {
|
||||||
modelConfig: ModelConfig;
|
modelConfig: ModelConfig;
|
||||||
|
@ -28,7 +29,9 @@ export function ModelConfigList(props: {
|
||||||
value={value}
|
value={value}
|
||||||
align="left"
|
align="left"
|
||||||
onChange={(e) => {
|
onChange={(e) => {
|
||||||
const [model, providerName] = e.currentTarget.value.split("@");
|
const [model, providerName] = getModelProvider(
|
||||||
|
e.currentTarget.value,
|
||||||
|
);
|
||||||
props.updateConfig((config) => {
|
props.updateConfig((config) => {
|
||||||
config.model = ModalConfigValidator.model(model);
|
config.model = ModalConfigValidator.model(model);
|
||||||
config.providerName = providerName as ServiceProvider;
|
config.providerName = providerName as ServiceProvider;
|
||||||
|
@ -247,7 +250,9 @@ export function ModelConfigList(props: {
|
||||||
aria-label={Locale.Settings.CompressModel.Title}
|
aria-label={Locale.Settings.CompressModel.Title}
|
||||||
value={compressModelValue}
|
value={compressModelValue}
|
||||||
onChange={(e) => {
|
onChange={(e) => {
|
||||||
const [model, providerName] = e.currentTarget.value.split("@");
|
const [model, providerName] = getModelProvider(
|
||||||
|
e.currentTarget.value,
|
||||||
|
);
|
||||||
props.updateConfig((config) => {
|
props.updateConfig((config) => {
|
||||||
config.compressModel = ModalConfigValidator.model(model);
|
config.compressModel = ModalConfigValidator.model(model);
|
||||||
config.compressProviderName = providerName as ServiceProvider;
|
config.compressProviderName = providerName as ServiceProvider;
|
||||||
|
|
|
@ -232,7 +232,7 @@ export const XAI = {
|
||||||
|
|
||||||
export const ChatGLM = {
|
export const ChatGLM = {
|
||||||
ExampleEndpoint: CHATGLM_BASE_URL,
|
ExampleEndpoint: CHATGLM_BASE_URL,
|
||||||
ChatPath: "/api/paas/v4/chat/completions",
|
ChatPath: "api/paas/v4/chat/completions",
|
||||||
};
|
};
|
||||||
|
|
||||||
export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang
|
export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang
|
||||||
|
@ -327,12 +327,13 @@ const anthropicModels = [
|
||||||
"claude-2.1",
|
"claude-2.1",
|
||||||
"claude-3-sonnet-20240229",
|
"claude-3-sonnet-20240229",
|
||||||
"claude-3-opus-20240229",
|
"claude-3-opus-20240229",
|
||||||
|
"claude-3-opus-latest",
|
||||||
"claude-3-haiku-20240307",
|
"claude-3-haiku-20240307",
|
||||||
"claude-3-5-haiku-20241022",
|
"claude-3-5-haiku-20241022",
|
||||||
"claude-3-5-sonnet-20240620",
|
"claude-3-5-sonnet-20240620",
|
||||||
"claude-3-5-sonnet-20241022",
|
"claude-3-5-sonnet-20241022",
|
||||||
"claude-3-5-sonnet-latest",
|
"claude-3-5-sonnet-latest",
|
||||||
"claude-3-opus-latest",
|
"claude-3-5-haiku-latest",
|
||||||
];
|
];
|
||||||
|
|
||||||
const baiduModels = [
|
const baiduModels = [
|
||||||
|
|
|
@ -21,6 +21,7 @@ import { getClientConfig } from "../config/client";
|
||||||
import { createPersistStore } from "../utils/store";
|
import { createPersistStore } from "../utils/store";
|
||||||
import { ensure } from "../utils/clone";
|
import { ensure } from "../utils/clone";
|
||||||
import { DEFAULT_CONFIG } from "./config";
|
import { DEFAULT_CONFIG } from "./config";
|
||||||
|
import { getModelProvider } from "../utils/model";
|
||||||
|
|
||||||
let fetchState = 0; // 0 not fetch, 1 fetching, 2 done
|
let fetchState = 0; // 0 not fetch, 1 fetching, 2 done
|
||||||
|
|
||||||
|
@ -226,9 +227,9 @@ export const useAccessStore = createPersistStore(
|
||||||
.then((res) => {
|
.then((res) => {
|
||||||
const defaultModel = res.defaultModel ?? "";
|
const defaultModel = res.defaultModel ?? "";
|
||||||
if (defaultModel !== "") {
|
if (defaultModel !== "") {
|
||||||
const [model, providerName] = defaultModel.split("@");
|
const [model, providerName] = getModelProvider(defaultModel);
|
||||||
DEFAULT_CONFIG.modelConfig.model = model;
|
DEFAULT_CONFIG.modelConfig.model = model;
|
||||||
DEFAULT_CONFIG.modelConfig.providerName = providerName;
|
DEFAULT_CONFIG.modelConfig.providerName = providerName as any;
|
||||||
}
|
}
|
||||||
|
|
||||||
return res;
|
return res;
|
||||||
|
|
|
@ -352,13 +352,13 @@ export const useChatStore = createPersistStore(
|
||||||
return session;
|
return session;
|
||||||
},
|
},
|
||||||
|
|
||||||
onNewMessage(message: ChatMessage) {
|
onNewMessage(message: ChatMessage, targetSession: ChatSession) {
|
||||||
get().updateCurrentSession((session) => {
|
get().updateTargetSession(targetSession, (session) => {
|
||||||
session.messages = session.messages.concat();
|
session.messages = session.messages.concat();
|
||||||
session.lastUpdate = Date.now();
|
session.lastUpdate = Date.now();
|
||||||
});
|
});
|
||||||
get().updateStat(message);
|
get().updateStat(message, targetSession);
|
||||||
get().summarizeSession();
|
get().summarizeSession(false, targetSession);
|
||||||
},
|
},
|
||||||
|
|
||||||
async onUserInput(content: string, attachImages?: string[]) {
|
async onUserInput(content: string, attachImages?: string[]) {
|
||||||
|
@ -396,10 +396,10 @@ export const useChatStore = createPersistStore(
|
||||||
// get recent messages
|
// get recent messages
|
||||||
const recentMessages = get().getMessagesWithMemory();
|
const recentMessages = get().getMessagesWithMemory();
|
||||||
const sendMessages = recentMessages.concat(userMessage);
|
const sendMessages = recentMessages.concat(userMessage);
|
||||||
const messageIndex = get().currentSession().messages.length + 1;
|
const messageIndex = session.messages.length + 1;
|
||||||
|
|
||||||
// save user's and bot's message
|
// save user's and bot's message
|
||||||
get().updateCurrentSession((session) => {
|
get().updateTargetSession(session, (session) => {
|
||||||
const savedUserMessage = {
|
const savedUserMessage = {
|
||||||
...userMessage,
|
...userMessage,
|
||||||
content: mContent,
|
content: mContent,
|
||||||
|
@ -420,7 +420,7 @@ export const useChatStore = createPersistStore(
|
||||||
if (message) {
|
if (message) {
|
||||||
botMessage.content = message;
|
botMessage.content = message;
|
||||||
}
|
}
|
||||||
get().updateCurrentSession((session) => {
|
get().updateTargetSession(session, (session) => {
|
||||||
session.messages = session.messages.concat();
|
session.messages = session.messages.concat();
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
|
@ -428,13 +428,14 @@ export const useChatStore = createPersistStore(
|
||||||
botMessage.streaming = false;
|
botMessage.streaming = false;
|
||||||
if (message) {
|
if (message) {
|
||||||
botMessage.content = message;
|
botMessage.content = message;
|
||||||
get().onNewMessage(botMessage);
|
botMessage.date = new Date().toLocaleString();
|
||||||
|
get().onNewMessage(botMessage, session);
|
||||||
}
|
}
|
||||||
ChatControllerPool.remove(session.id, botMessage.id);
|
ChatControllerPool.remove(session.id, botMessage.id);
|
||||||
},
|
},
|
||||||
onBeforeTool(tool: ChatMessageTool) {
|
onBeforeTool(tool: ChatMessageTool) {
|
||||||
(botMessage.tools = botMessage?.tools || []).push(tool);
|
(botMessage.tools = botMessage?.tools || []).push(tool);
|
||||||
get().updateCurrentSession((session) => {
|
get().updateTargetSession(session, (session) => {
|
||||||
session.messages = session.messages.concat();
|
session.messages = session.messages.concat();
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
|
@ -444,7 +445,7 @@ export const useChatStore = createPersistStore(
|
||||||
tools[i] = { ...tool };
|
tools[i] = { ...tool };
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
get().updateCurrentSession((session) => {
|
get().updateTargetSession(session, (session) => {
|
||||||
session.messages = session.messages.concat();
|
session.messages = session.messages.concat();
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
|
@ -459,7 +460,7 @@ export const useChatStore = createPersistStore(
|
||||||
botMessage.streaming = false;
|
botMessage.streaming = false;
|
||||||
userMessage.isError = !isAborted;
|
userMessage.isError = !isAborted;
|
||||||
botMessage.isError = !isAborted;
|
botMessage.isError = !isAborted;
|
||||||
get().updateCurrentSession((session) => {
|
get().updateTargetSession(session, (session) => {
|
||||||
session.messages = session.messages.concat();
|
session.messages = session.messages.concat();
|
||||||
});
|
});
|
||||||
ChatControllerPool.remove(
|
ChatControllerPool.remove(
|
||||||
|
@ -591,16 +592,19 @@ export const useChatStore = createPersistStore(
|
||||||
set(() => ({ sessions }));
|
set(() => ({ sessions }));
|
||||||
},
|
},
|
||||||
|
|
||||||
resetSession() {
|
resetSession(session: ChatSession) {
|
||||||
get().updateCurrentSession((session) => {
|
get().updateTargetSession(session, (session) => {
|
||||||
session.messages = [];
|
session.messages = [];
|
||||||
session.memoryPrompt = "";
|
session.memoryPrompt = "";
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
|
|
||||||
summarizeSession(refreshTitle: boolean = false) {
|
summarizeSession(
|
||||||
|
refreshTitle: boolean = false,
|
||||||
|
targetSession: ChatSession,
|
||||||
|
) {
|
||||||
const config = useAppConfig.getState();
|
const config = useAppConfig.getState();
|
||||||
const session = get().currentSession();
|
const session = targetSession;
|
||||||
const modelConfig = session.mask.modelConfig;
|
const modelConfig = session.mask.modelConfig;
|
||||||
// skip summarize when using dalle3?
|
// skip summarize when using dalle3?
|
||||||
if (isDalle3(modelConfig.model)) {
|
if (isDalle3(modelConfig.model)) {
|
||||||
|
@ -651,7 +655,8 @@ export const useChatStore = createPersistStore(
|
||||||
},
|
},
|
||||||
onFinish(message, responseRes) {
|
onFinish(message, responseRes) {
|
||||||
if (responseRes?.status === 200) {
|
if (responseRes?.status === 200) {
|
||||||
get().updateCurrentSession(
|
get().updateTargetSession(
|
||||||
|
session,
|
||||||
(session) =>
|
(session) =>
|
||||||
(session.topic =
|
(session.topic =
|
||||||
message.length > 0 ? trimTopic(message) : DEFAULT_TOPIC),
|
message.length > 0 ? trimTopic(message) : DEFAULT_TOPIC),
|
||||||
|
@ -719,7 +724,7 @@ export const useChatStore = createPersistStore(
|
||||||
onFinish(message, responseRes) {
|
onFinish(message, responseRes) {
|
||||||
if (responseRes?.status === 200) {
|
if (responseRes?.status === 200) {
|
||||||
console.log("[Memory] ", message);
|
console.log("[Memory] ", message);
|
||||||
get().updateCurrentSession((session) => {
|
get().updateTargetSession(session, (session) => {
|
||||||
session.lastSummarizeIndex = lastSummarizeIndex;
|
session.lastSummarizeIndex = lastSummarizeIndex;
|
||||||
session.memoryPrompt = message; // Update the memory prompt for stored it in local storage
|
session.memoryPrompt = message; // Update the memory prompt for stored it in local storage
|
||||||
});
|
});
|
||||||
|
@ -732,20 +737,22 @@ export const useChatStore = createPersistStore(
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
updateStat(message: ChatMessage) {
|
updateStat(message: ChatMessage, session: ChatSession) {
|
||||||
get().updateCurrentSession((session) => {
|
get().updateTargetSession(session, (session) => {
|
||||||
session.stat.charCount += message.content.length;
|
session.stat.charCount += message.content.length;
|
||||||
// TODO: should update chat count and word count
|
// TODO: should update chat count and word count
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
|
updateTargetSession(
|
||||||
updateCurrentSession(updater: (session: ChatSession) => void) {
|
targetSession: ChatSession,
|
||||||
|
updater: (session: ChatSession) => void,
|
||||||
|
) {
|
||||||
const sessions = get().sessions;
|
const sessions = get().sessions;
|
||||||
const index = get().currentSessionIndex;
|
const index = sessions.findIndex((s) => s.id === targetSession.id);
|
||||||
|
if (index < 0) return;
|
||||||
updater(sessions[index]);
|
updater(sessions[index]);
|
||||||
set(() => ({ sessions }));
|
set(() => ({ sessions }));
|
||||||
},
|
},
|
||||||
|
|
||||||
async clearAllData() {
|
async clearAllData() {
|
||||||
await indexedDBStorage.clear();
|
await indexedDBStorage.clear();
|
||||||
localStorage.clear();
|
localStorage.clear();
|
||||||
|
|
|
@ -37,6 +37,17 @@ const sortModelTable = (models: ReturnType<typeof collectModels>) =>
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get model name and provider from a formatted string,
|
||||||
|
* e.g. `gpt-4@OpenAi` or `claude-3-5-sonnet@20240620@Google`
|
||||||
|
* @param modelWithProvider model name with provider separated by last `@` char,
|
||||||
|
* @returns [model, provider] tuple, if no `@` char found, provider is undefined
|
||||||
|
*/
|
||||||
|
export function getModelProvider(modelWithProvider: string): [string, string?] {
|
||||||
|
const [model, provider] = modelWithProvider.split(/@(?!.*@)/);
|
||||||
|
return [model, provider];
|
||||||
|
}
|
||||||
|
|
||||||
export function collectModelTable(
|
export function collectModelTable(
|
||||||
models: readonly LLMModel[],
|
models: readonly LLMModel[],
|
||||||
customModels: string,
|
customModels: string,
|
||||||
|
@ -79,10 +90,10 @@ export function collectModelTable(
|
||||||
);
|
);
|
||||||
} else {
|
} else {
|
||||||
// 1. find model by name, and set available value
|
// 1. find model by name, and set available value
|
||||||
const [customModelName, customProviderName] = name.split("@");
|
const [customModelName, customProviderName] = getModelProvider(name);
|
||||||
let count = 0;
|
let count = 0;
|
||||||
for (const fullName in modelTable) {
|
for (const fullName in modelTable) {
|
||||||
const [modelName, providerName] = fullName.split("@");
|
const [modelName, providerName] = getModelProvider(fullName);
|
||||||
if (
|
if (
|
||||||
customModelName == modelName &&
|
customModelName == modelName &&
|
||||||
(customProviderName === undefined ||
|
(customProviderName === undefined ||
|
||||||
|
@ -102,7 +113,7 @@ export function collectModelTable(
|
||||||
}
|
}
|
||||||
// 2. if model not exists, create new model with available value
|
// 2. if model not exists, create new model with available value
|
||||||
if (count === 0) {
|
if (count === 0) {
|
||||||
let [customModelName, customProviderName] = name.split("@");
|
let [customModelName, customProviderName] = getModelProvider(name);
|
||||||
const provider = customProvider(
|
const provider = customProvider(
|
||||||
customProviderName || customModelName,
|
customProviderName || customModelName,
|
||||||
);
|
);
|
||||||
|
@ -139,7 +150,7 @@ export function collectModelTableWithDefaultModel(
|
||||||
for (const key of Object.keys(modelTable)) {
|
for (const key of Object.keys(modelTable)) {
|
||||||
if (
|
if (
|
||||||
modelTable[key].available &&
|
modelTable[key].available &&
|
||||||
key.split("@").shift() == defaultModel
|
getModelProvider(key)[0] == defaultModel
|
||||||
) {
|
) {
|
||||||
modelTable[key].isDefault = true;
|
modelTable[key].isDefault = true;
|
||||||
break;
|
break;
|
||||||
|
|
|
@ -9,7 +9,7 @@
|
||||||
},
|
},
|
||||||
"package": {
|
"package": {
|
||||||
"productName": "NextChat",
|
"productName": "NextChat",
|
||||||
"version": "2.15.6"
|
"version": "2.15.7"
|
||||||
},
|
},
|
||||||
"tauri": {
|
"tauri": {
|
||||||
"allowlist": {
|
"allowlist": {
|
||||||
|
|
|
@ -0,0 +1,31 @@
|
||||||
|
import { getModelProvider } from "../app/utils/model";
|
||||||
|
|
||||||
|
describe("getModelProvider", () => {
|
||||||
|
test("should return model and provider when input contains '@'", () => {
|
||||||
|
const input = "model@provider";
|
||||||
|
const [model, provider] = getModelProvider(input);
|
||||||
|
expect(model).toBe("model");
|
||||||
|
expect(provider).toBe("provider");
|
||||||
|
});
|
||||||
|
|
||||||
|
test("should return model and undefined provider when input does not contain '@'", () => {
|
||||||
|
const input = "model";
|
||||||
|
const [model, provider] = getModelProvider(input);
|
||||||
|
expect(model).toBe("model");
|
||||||
|
expect(provider).toBeUndefined();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("should handle multiple '@' characters correctly", () => {
|
||||||
|
const input = "model@provider@extra";
|
||||||
|
const [model, provider] = getModelProvider(input);
|
||||||
|
expect(model).toBe("model@provider");
|
||||||
|
expect(provider).toBe("extra");
|
||||||
|
});
|
||||||
|
|
||||||
|
test("should return empty strings when input is empty", () => {
|
||||||
|
const input = "";
|
||||||
|
const [model, provider] = getModelProvider(input);
|
||||||
|
expect(model).toBe("");
|
||||||
|
expect(provider).toBeUndefined();
|
||||||
|
});
|
||||||
|
});
|
Loading…
Reference in New Issue