Merge branch 'main' of https://github.com/ConnectAI-E/ChatGPT-Next-Web into feature/update-target-session
This commit is contained in:
commit
106461a1e7
|
@ -1,8 +1,8 @@
|
||||||
import { NextRequest, NextResponse } from "next/server";
|
import { NextRequest, NextResponse } from "next/server";
|
||||||
import { getServerSideConfig } from "../config/server";
|
import { getServerSideConfig } from "../config/server";
|
||||||
import { OPENAI_BASE_URL, ServiceProvider } from "../constant";
|
import { OPENAI_BASE_URL, ServiceProvider } from "../constant";
|
||||||
import { isModelAvailableInServer } from "../utils/model";
|
|
||||||
import { cloudflareAIGatewayUrl } from "../utils/cloudflare";
|
import { cloudflareAIGatewayUrl } from "../utils/cloudflare";
|
||||||
|
import { getModelProvider, isModelAvailableInServer } from "../utils/model";
|
||||||
|
|
||||||
const serverConfig = getServerSideConfig();
|
const serverConfig = getServerSideConfig();
|
||||||
|
|
||||||
|
@ -71,7 +71,7 @@ export async function requestOpenai(req: NextRequest) {
|
||||||
.filter((v) => !!v && !v.startsWith("-") && v.includes(modelName))
|
.filter((v) => !!v && !v.startsWith("-") && v.includes(modelName))
|
||||||
.forEach((m) => {
|
.forEach((m) => {
|
||||||
const [fullName, displayName] = m.split("=");
|
const [fullName, displayName] = m.split("=");
|
||||||
const [_, providerName] = fullName.split("@");
|
const [_, providerName] = getModelProvider(fullName);
|
||||||
if (providerName === "azure" && !displayName) {
|
if (providerName === "azure" && !displayName) {
|
||||||
const [_, deployId] = (serverConfig?.azureUrl ?? "").split(
|
const [_, deployId] = (serverConfig?.azureUrl ?? "").split(
|
||||||
"deployments/",
|
"deployments/",
|
||||||
|
|
|
@ -120,6 +120,7 @@ import { createTTSPlayer } from "../utils/audio";
|
||||||
import { MsEdgeTTS, OUTPUT_FORMAT } from "../utils/ms_edge_tts";
|
import { MsEdgeTTS, OUTPUT_FORMAT } from "../utils/ms_edge_tts";
|
||||||
|
|
||||||
import { isEmpty } from "lodash-es";
|
import { isEmpty } from "lodash-es";
|
||||||
|
import { getModelProvider } from "../utils/model";
|
||||||
|
|
||||||
const localStorage = safeLocalStorage();
|
const localStorage = safeLocalStorage();
|
||||||
|
|
||||||
|
@ -648,7 +649,7 @@ export function ChatActions(props: {
|
||||||
onClose={() => setShowModelSelector(false)}
|
onClose={() => setShowModelSelector(false)}
|
||||||
onSelection={(s) => {
|
onSelection={(s) => {
|
||||||
if (s.length === 0) return;
|
if (s.length === 0) return;
|
||||||
const [model, providerName] = s[0].split("@");
|
const [model, providerName] = getModelProvider(s[0]);
|
||||||
chatStore.updateTargetSession(session, (session) => {
|
chatStore.updateTargetSession(session, (session) => {
|
||||||
session.mask.modelConfig.model = model as ModelType;
|
session.mask.modelConfig.model = model as ModelType;
|
||||||
session.mask.modelConfig.providerName =
|
session.mask.modelConfig.providerName =
|
||||||
|
|
|
@ -7,6 +7,7 @@ import { ListItem, Select } from "./ui-lib";
|
||||||
import { useAllModels } from "../utils/hooks";
|
import { useAllModels } from "../utils/hooks";
|
||||||
import { groupBy } from "lodash-es";
|
import { groupBy } from "lodash-es";
|
||||||
import styles from "./model-config.module.scss";
|
import styles from "./model-config.module.scss";
|
||||||
|
import { getModelProvider } from "../utils/model";
|
||||||
|
|
||||||
export function ModelConfigList(props: {
|
export function ModelConfigList(props: {
|
||||||
modelConfig: ModelConfig;
|
modelConfig: ModelConfig;
|
||||||
|
@ -28,7 +29,9 @@ export function ModelConfigList(props: {
|
||||||
value={value}
|
value={value}
|
||||||
align="left"
|
align="left"
|
||||||
onChange={(e) => {
|
onChange={(e) => {
|
||||||
const [model, providerName] = e.currentTarget.value.split("@");
|
const [model, providerName] = getModelProvider(
|
||||||
|
e.currentTarget.value,
|
||||||
|
);
|
||||||
props.updateConfig((config) => {
|
props.updateConfig((config) => {
|
||||||
config.model = ModalConfigValidator.model(model);
|
config.model = ModalConfigValidator.model(model);
|
||||||
config.providerName = providerName as ServiceProvider;
|
config.providerName = providerName as ServiceProvider;
|
||||||
|
@ -247,7 +250,9 @@ export function ModelConfigList(props: {
|
||||||
aria-label={Locale.Settings.CompressModel.Title}
|
aria-label={Locale.Settings.CompressModel.Title}
|
||||||
value={compressModelValue}
|
value={compressModelValue}
|
||||||
onChange={(e) => {
|
onChange={(e) => {
|
||||||
const [model, providerName] = e.currentTarget.value.split("@");
|
const [model, providerName] = getModelProvider(
|
||||||
|
e.currentTarget.value,
|
||||||
|
);
|
||||||
props.updateConfig((config) => {
|
props.updateConfig((config) => {
|
||||||
config.compressModel = ModalConfigValidator.model(model);
|
config.compressModel = ModalConfigValidator.model(model);
|
||||||
config.compressProviderName = providerName as ServiceProvider;
|
config.compressProviderName = providerName as ServiceProvider;
|
||||||
|
|
|
@ -21,6 +21,7 @@ import { getClientConfig } from "../config/client";
|
||||||
import { createPersistStore } from "../utils/store";
|
import { createPersistStore } from "../utils/store";
|
||||||
import { ensure } from "../utils/clone";
|
import { ensure } from "../utils/clone";
|
||||||
import { DEFAULT_CONFIG } from "./config";
|
import { DEFAULT_CONFIG } from "./config";
|
||||||
|
import { getModelProvider } from "../utils/model";
|
||||||
|
|
||||||
let fetchState = 0; // 0 not fetch, 1 fetching, 2 done
|
let fetchState = 0; // 0 not fetch, 1 fetching, 2 done
|
||||||
|
|
||||||
|
@ -226,9 +227,9 @@ export const useAccessStore = createPersistStore(
|
||||||
.then((res) => {
|
.then((res) => {
|
||||||
const defaultModel = res.defaultModel ?? "";
|
const defaultModel = res.defaultModel ?? "";
|
||||||
if (defaultModel !== "") {
|
if (defaultModel !== "") {
|
||||||
const [model, providerName] = defaultModel.split("@");
|
const [model, providerName] = getModelProvider(defaultModel);
|
||||||
DEFAULT_CONFIG.modelConfig.model = model;
|
DEFAULT_CONFIG.modelConfig.model = model;
|
||||||
DEFAULT_CONFIG.modelConfig.providerName = providerName;
|
DEFAULT_CONFIG.modelConfig.providerName = providerName as any;
|
||||||
}
|
}
|
||||||
|
|
||||||
return res;
|
return res;
|
||||||
|
|
|
@ -37,6 +37,17 @@ const sortModelTable = (models: ReturnType<typeof collectModels>) =>
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get model name and provider from a formatted string,
|
||||||
|
* e.g. `gpt-4@OpenAi` or `claude-3-5-sonnet@20240620@Google`
|
||||||
|
* @param modelWithProvider model name with provider separated by last `@` char,
|
||||||
|
* @returns [model, provider] tuple, if no `@` char found, provider is undefined
|
||||||
|
*/
|
||||||
|
export function getModelProvider(modelWithProvider: string): [string, string?] {
|
||||||
|
const [model, provider] = modelWithProvider.split(/@(?!.*@)/);
|
||||||
|
return [model, provider];
|
||||||
|
}
|
||||||
|
|
||||||
export function collectModelTable(
|
export function collectModelTable(
|
||||||
models: readonly LLMModel[],
|
models: readonly LLMModel[],
|
||||||
customModels: string,
|
customModels: string,
|
||||||
|
@ -79,10 +90,10 @@ export function collectModelTable(
|
||||||
);
|
);
|
||||||
} else {
|
} else {
|
||||||
// 1. find model by name, and set available value
|
// 1. find model by name, and set available value
|
||||||
const [customModelName, customProviderName] = name.split("@");
|
const [customModelName, customProviderName] = getModelProvider(name);
|
||||||
let count = 0;
|
let count = 0;
|
||||||
for (const fullName in modelTable) {
|
for (const fullName in modelTable) {
|
||||||
const [modelName, providerName] = fullName.split("@");
|
const [modelName, providerName] = getModelProvider(fullName);
|
||||||
if (
|
if (
|
||||||
customModelName == modelName &&
|
customModelName == modelName &&
|
||||||
(customProviderName === undefined ||
|
(customProviderName === undefined ||
|
||||||
|
@ -102,7 +113,7 @@ export function collectModelTable(
|
||||||
}
|
}
|
||||||
// 2. if model not exists, create new model with available value
|
// 2. if model not exists, create new model with available value
|
||||||
if (count === 0) {
|
if (count === 0) {
|
||||||
let [customModelName, customProviderName] = name.split("@");
|
let [customModelName, customProviderName] = getModelProvider(name);
|
||||||
const provider = customProvider(
|
const provider = customProvider(
|
||||||
customProviderName || customModelName,
|
customProviderName || customModelName,
|
||||||
);
|
);
|
||||||
|
@ -139,7 +150,7 @@ export function collectModelTableWithDefaultModel(
|
||||||
for (const key of Object.keys(modelTable)) {
|
for (const key of Object.keys(modelTable)) {
|
||||||
if (
|
if (
|
||||||
modelTable[key].available &&
|
modelTable[key].available &&
|
||||||
key.split("@").shift() == defaultModel
|
getModelProvider(key)[0] == defaultModel
|
||||||
) {
|
) {
|
||||||
modelTable[key].isDefault = true;
|
modelTable[key].isDefault = true;
|
||||||
break;
|
break;
|
||||||
|
|
|
@ -0,0 +1,31 @@
|
||||||
|
import { getModelProvider } from "../app/utils/model";
|
||||||
|
|
||||||
|
describe("getModelProvider", () => {
|
||||||
|
test("should return model and provider when input contains '@'", () => {
|
||||||
|
const input = "model@provider";
|
||||||
|
const [model, provider] = getModelProvider(input);
|
||||||
|
expect(model).toBe("model");
|
||||||
|
expect(provider).toBe("provider");
|
||||||
|
});
|
||||||
|
|
||||||
|
test("should return model and undefined provider when input does not contain '@'", () => {
|
||||||
|
const input = "model";
|
||||||
|
const [model, provider] = getModelProvider(input);
|
||||||
|
expect(model).toBe("model");
|
||||||
|
expect(provider).toBeUndefined();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("should handle multiple '@' characters correctly", () => {
|
||||||
|
const input = "model@provider@extra";
|
||||||
|
const [model, provider] = getModelProvider(input);
|
||||||
|
expect(model).toBe("model@provider");
|
||||||
|
expect(provider).toBe("extra");
|
||||||
|
});
|
||||||
|
|
||||||
|
test("should return empty strings when input is empty", () => {
|
||||||
|
const input = "";
|
||||||
|
const [model, provider] = getModelProvider(input);
|
||||||
|
expect(model).toBe("");
|
||||||
|
expect(provider).toBeUndefined();
|
||||||
|
});
|
||||||
|
});
|
Loading…
Reference in New Issue