Merge pull request #4480 from ChatGPTNextWeb/chore-fix

feat: Solve the problem of using openai interface protocol for user-d…
This commit is contained in:
DeanYao 2024-04-10 09:26:21 +08:00 committed by GitHub
commit 67acc38a1f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 49 additions and 7 deletions

View File

@ -40,6 +40,7 @@ import { EXPORT_MESSAGE_CLASS_NAME, ModelProvider } from "../constant";
import { getClientConfig } from "../config/client"; import { getClientConfig } from "../config/client";
import { ClientApi } from "../client/api"; import { ClientApi } from "../client/api";
import { getMessageTextContent } from "../utils"; import { getMessageTextContent } from "../utils";
import { identifyDefaultClaudeModel } from "../utils/checkers";
const Markdown = dynamic(async () => (await import("./markdown")).Markdown, { const Markdown = dynamic(async () => (await import("./markdown")).Markdown, {
loading: () => <LoadingIcon />, loading: () => <LoadingIcon />,
@ -315,7 +316,7 @@ export function PreviewActions(props: {
var api: ClientApi; var api: ClientApi;
if (config.modelConfig.model.startsWith("gemini")) { if (config.modelConfig.model.startsWith("gemini")) {
api = new ClientApi(ModelProvider.GeminiPro); api = new ClientApi(ModelProvider.GeminiPro);
} else if (config.modelConfig.model.startsWith("claude")) { } else if (identifyDefaultClaudeModel(config.modelConfig.model)) {
api = new ClientApi(ModelProvider.Claude); api = new ClientApi(ModelProvider.Claude);
} else { } else {
api = new ClientApi(ModelProvider.GPT); api = new ClientApi(ModelProvider.GPT);

View File

@ -29,6 +29,7 @@ import { AuthPage } from "./auth";
import { getClientConfig } from "../config/client"; import { getClientConfig } from "../config/client";
import { ClientApi } from "../client/api"; import { ClientApi } from "../client/api";
import { useAccessStore } from "../store"; import { useAccessStore } from "../store";
import { identifyDefaultClaudeModel } from "../utils/checkers";
export function Loading(props: { noLogo?: boolean }) { export function Loading(props: { noLogo?: boolean }) {
return ( return (
@ -173,7 +174,7 @@ export function useLoadData() {
var api: ClientApi; var api: ClientApi;
if (config.modelConfig.model.startsWith("gemini")) { if (config.modelConfig.model.startsWith("gemini")) {
api = new ClientApi(ModelProvider.GeminiPro); api = new ClientApi(ModelProvider.GeminiPro);
} else if (config.modelConfig.model.startsWith("claude")) { } else if (identifyDefaultClaudeModel(config.modelConfig.model)) {
api = new ClientApi(ModelProvider.Claude); api = new ClientApi(ModelProvider.Claude);
} else { } else {
api = new ClientApi(ModelProvider.GPT); api = new ClientApi(ModelProvider.GPT);

View File

@ -367,4 +367,14 @@ export const DEFAULT_MODELS = [
export const CHAT_PAGE_SIZE = 15; export const CHAT_PAGE_SIZE = 15;
export const MAX_RENDER_MSG_COUNT = 45; export const MAX_RENDER_MSG_COUNT = 45;
export const internalWhiteWebDavEndpoints = ["https://dav.jianguoyun.com"]; // some famous webdav endpoints
export const internalWhiteWebDavEndpoints = [
"https://dav.jianguoyun.com/dav/",
"https://dav.dropdav.com/",
"https://dav.box.com/dav",
"https://nanao.teracloud.jp/dav/",
"https://webdav.4shared.com/",
"https://dav.idrivesync.com",
"https://webdav.yandex.com",
"https://app.koofr.net/dav/Koofr",
];

View File

@ -20,6 +20,7 @@ import { prettyObject } from "../utils/format";
import { estimateTokenLength } from "../utils/token"; import { estimateTokenLength } from "../utils/token";
import { nanoid } from "nanoid"; import { nanoid } from "nanoid";
import { createPersistStore } from "../utils/store"; import { createPersistStore } from "../utils/store";
import { identifyDefaultClaudeModel } from "../utils/checkers";
export type ChatMessage = RequestMessage & { export type ChatMessage = RequestMessage & {
date: string; date: string;
@ -353,7 +354,7 @@ export const useChatStore = createPersistStore(
var api: ClientApi; var api: ClientApi;
if (modelConfig.model.startsWith("gemini")) { if (modelConfig.model.startsWith("gemini")) {
api = new ClientApi(ModelProvider.GeminiPro); api = new ClientApi(ModelProvider.GeminiPro);
} else if (modelConfig.model.startsWith("claude")) { } else if (identifyDefaultClaudeModel(modelConfig.model)) {
api = new ClientApi(ModelProvider.Claude); api = new ClientApi(ModelProvider.Claude);
} else { } else {
api = new ClientApi(ModelProvider.GPT); api = new ClientApi(ModelProvider.GPT);
@ -539,7 +540,7 @@ export const useChatStore = createPersistStore(
var api: ClientApi; var api: ClientApi;
if (modelConfig.model.startsWith("gemini")) { if (modelConfig.model.startsWith("gemini")) {
api = new ClientApi(ModelProvider.GeminiPro); api = new ClientApi(ModelProvider.GeminiPro);
} else if (modelConfig.model.startsWith("claude")) { } else if (identifyDefaultClaudeModel(modelConfig.model)) {
api = new ClientApi(ModelProvider.Claude); api = new ClientApi(ModelProvider.Claude);
} else { } else {
api = new ClientApi(ModelProvider.GPT); api = new ClientApi(ModelProvider.GPT);

21
app/utils/checkers.ts Normal file
View File

@ -0,0 +1,21 @@
import { useAccessStore } from "../store/access";
import { useAppConfig } from "../store/config";
import { collectModels } from "./model";
export function identifyDefaultClaudeModel(modelName: string) {
const accessStore = useAccessStore.getState();
const configStore = useAppConfig.getState();
const allModals = collectModels(
configStore.models,
[configStore.customModels, accessStore.customModels].join(","),
);
const modelMeta = allModals.find((m) => m.name === modelName);
return (
modelName.startsWith("claude") &&
modelMeta &&
modelMeta.provider?.providerType === "anthropic"
);
}

View File

@ -22,6 +22,12 @@ export function collectModelTable(
}; };
}); });
const customProvider = (modelName: string) => ({
id: modelName,
providerName: "",
providerType: "custom",
});
// server custom models // server custom models
customModels customModels
.split(",") .split(",")
@ -34,13 +40,15 @@ export function collectModelTable(
// enable or disable all models // enable or disable all models
if (name === "all") { if (name === "all") {
Object.values(modelTable).forEach((model) => (model.available = available)); Object.values(modelTable).forEach(
(model) => (model.available = available),
);
} else { } else {
modelTable[name] = { modelTable[name] = {
name, name,
displayName: displayName || name, displayName: displayName || name,
available, available,
provider: modelTable[name]?.provider, // Use optional chaining provider: modelTable[name]?.provider ?? customProvider(name), // Use optional chaining
}; };
} }
}); });