Compare commits


12 Commits

Author | SHA1 | Message | Date
lyf | ebe617b733 | fix max_completions_tokens | 2024-10-16 16:24:02 +08:00
Lloyd Zhou | c139038e01 | Merge pull request #5639 from code-october/fix/auth-ui (improve the access code input box) | 2024-10-11 19:11:35 +08:00
code-october | 4a7fd3a380 | Improve the home page API input box | 2024-10-11 10:36:11 +00:00
code-october | c98dc31cdf | Improve the access code input box | 2024-10-11 09:03:20 +00:00
Lloyd Zhou | c5074f0aa4 | Merge pull request #5581 from ConnectAI-E/feature/gemini-functioncall (google gemini support function call) | 2024-10-10 21:02:36 +08:00
Lloyd Zhou | ba58018a15 | Merge pull request #5211 from ConnectAI-E/feature/jest (feat: jest) | 2024-10-10 21:02:05 +08:00
lloydzhou | 4ae34ea3ee | merge main | 2024-10-09 18:27:23 +08:00
Dogtiti | acf9fa36f9 | Merge branch 'main' of https://github.com/ConnectAI-E/ChatGPT-Next-Web into feature/jest | 2024-10-08 10:30:47 +08:00
Dogtiti | 461154bb03 | fix: format package | 2024-10-08 10:29:42 +08:00
lloydzhou | 450766a44b | google gemini support function call | 2024-10-03 20:28:15 +08:00
Dogtiti | 1287e39cc6 | feat: run test before build | 2024-08-06 19:24:47 +08:00
Dogtiti | 1ef2aa35e9 | feat: jest | 2024-08-06 18:03:27 +08:00
12 changed files with 2104 additions and 165 deletions

app/client/platforms/google.ts

@@ -7,21 +7,25 @@ import {
   LLMUsage,
   SpeechOptions,
 } from "../api";
-import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
+import {
+  useAccessStore,
+  useAppConfig,
+  useChatStore,
+  usePluginStore,
+  ChatMessageTool,
+} from "@/app/store";
+import { stream } from "@/app/utils/chat";
 import { getClientConfig } from "@/app/config/client";
 import { GEMINI_BASE_URL } from "@/app/constant";
-import Locale from "../../locales";
-import {
-  EventStreamContentType,
-  fetchEventSource,
-} from "@fortaine/fetch-event-source";
-import { prettyObject } from "@/app/utils/format";
 import {
   getMessageTextContent,
   getMessageImages,
   isVisionModel,
 } from "@/app/utils";
 import { preProcessImageContent } from "@/app/utils/chat";
+import { nanoid } from "nanoid";
+import { RequestPayload } from "./openai";
 import { fetch } from "@/app/utils/stream";

 export class GeminiProApi implements LLMApi {
export class GeminiProApi implements LLMApi {
@@ -178,115 +182,81 @@ export class GeminiProApi implements LLMApi {
     );

     if (shouldStream) {
-      let responseText = "";
-      let remainText = "";
-      let finished = false;
-
-      const finish = () => {
-        if (!finished) {
-          finished = true;
-          options.onFinish(responseText + remainText);
-        }
-      };
-
-      // animate response to make it looks smooth
-      function animateResponseText() {
-        if (finished || controller.signal.aborted) {
-          responseText += remainText;
-          finish();
-          return;
-        }
-
-        if (remainText.length > 0) {
-          const fetchCount = Math.max(1, Math.round(remainText.length / 60));
-          const fetchText = remainText.slice(0, fetchCount);
-          responseText += fetchText;
-          remainText = remainText.slice(fetchCount);
-          options.onUpdate?.(responseText, fetchText);
-        }
-
-        requestAnimationFrame(animateResponseText);
-      }
-
-      // start animaion
-      animateResponseText();
-
-      controller.signal.onabort = finish;
-
-      fetchEventSource(chatPath, {
-        fetch: fetch as any,
-        ...chatPayload,
-        async onopen(res) {
-          clearTimeout(requestTimeoutId);
-          const contentType = res.headers.get("content-type");
-          console.log(
-            "[Gemini] request response content type: ",
-            contentType,
-          );
-
-          if (contentType?.startsWith("text/plain")) {
-            responseText = await res.clone().text();
-            return finish();
-          }
-
-          if (
-            !res.ok ||
-            !res.headers
-              .get("content-type")
-              ?.startsWith(EventStreamContentType) ||
-            res.status !== 200
-          ) {
-            const responseTexts = [responseText];
-            let extraInfo = await res.clone().text();
-            try {
-              const resJson = await res.clone().json();
-              extraInfo = prettyObject(resJson);
-            } catch {}
-
-            if (res.status === 401) {
-              responseTexts.push(Locale.Error.Unauthorized);
-            }
-
-            if (extraInfo) {
-              responseTexts.push(extraInfo);
-            }
-
-            responseText = responseTexts.join("\n\n");
-
-            return finish();
-          }
-        },
-        onmessage(msg) {
-          if (msg.data === "[DONE]" || finished) {
-            return finish();
-          }
-          const text = msg.data;
-          try {
-            const json = JSON.parse(text);
-            const delta = apiClient.extractMessage(json);
-
-            if (delta) {
-              remainText += delta;
-            }
-
-            const blockReason = json?.promptFeedback?.blockReason;
-            if (blockReason) {
-              // being blocked
-              console.log(`[Google] [Safety Ratings] result:`, blockReason);
-            }
-          } catch (e) {
-            console.error("[Request] parse error", text, msg);
-          }
-        },
-        onclose() {
-          finish();
-        },
-        onerror(e) {
-          options.onError?.(e);
-          throw e;
-        },
-        openWhenHidden: true,
-      });
+      const [tools, funcs] = usePluginStore
+        .getState()
+        .getAsTools(
+          useChatStore.getState().currentSession().mask?.plugin || [],
+        );
+      return stream(
+        chatPath,
+        requestPayload,
+        getHeaders(),
+        // @ts-ignore
+        [{ functionDeclarations: tools.map((tool) => tool.function) }],
+        funcs,
+        controller,
+        // parseSSE
+        (text: string, runTools: ChatMessageTool[]) => {
+          // console.log("parseSSE", text, runTools);
+          const chunkJson = JSON.parse(text);
+
+          const functionCall = chunkJson?.candidates
+            ?.at(0)
+            ?.content.parts.at(0)?.functionCall;
+          if (functionCall) {
+            const { name, args } = functionCall;
+            runTools.push({
+              id: nanoid(),
+              type: "function",
+              function: {
+                name,
+                arguments: JSON.stringify(args), // utils.chat call function, using JSON.parse
+              },
+            });
+          }
+          return chunkJson?.candidates?.at(0)?.content.parts.at(0)?.text;
+        },
+        // processToolMessage, include tool_calls message and tool call results
+        (
+          requestPayload: RequestPayload,
+          toolCallMessage: any,
+          toolCallResult: any[],
+        ) => {
+          // @ts-ignore
+          requestPayload?.contents?.splice(
+            // @ts-ignore
+            requestPayload?.contents?.length,
+            0,
+            {
+              role: "model",
+              parts: toolCallMessage.tool_calls.map(
+                (tool: ChatMessageTool) => ({
+                  functionCall: {
+                    name: tool?.function?.name,
+                    args: JSON.parse(tool?.function?.arguments as string),
+                  },
+                }),
+              ),
+            },
+            // @ts-ignore
+            ...toolCallResult.map((result) => ({
+              role: "function",
+              parts: [
+                {
+                  functionResponse: {
+                    name: result.name,
+                    response: {
+                      name: result.name,
+                      content: result.content, // TODO just text content...
+                    },
+                  },
+                },
+              ],
+            })),
+          );
+        },
+        options,
+      );
     } else {
       const res = await fetch(chatPath, chatPayload);
       clearTimeout(requestTimeoutId);
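The new path replaces the hand-rolled fetchEventSource handling with the shared stream() helper plus two Gemini-specific callbacks: parseSSE lifts a functionCall part out of each chunk, and processToolMessage splices the call and its results back into contents for the follow-up request. A minimal sketch of the chunk-to-tool-call mapping, with a local illustrative ToolCall type standing in for ChatMessageTool:

```ts
import { nanoid } from "nanoid";

// Illustrative stand-in for ChatMessageTool from the store.
interface ToolCall {
  id: string;
  type: "function";
  function: { name: string; arguments: string };
}

// Pulls the first functionCall part out of a parsed Gemini stream chunk,
// mirroring the parseSSE callback above.
function extractToolCall(chunkJson: any): ToolCall | undefined {
  const functionCall = chunkJson?.candidates
    ?.at(0)
    ?.content?.parts?.at(0)?.functionCall;
  if (!functionCall) return undefined;
  const { name, args } = functionCall;
  return {
    id: nanoid(),
    type: "function",
    // Gemini streams args as an object; the shared tool-call plumbing in
    // utils/chat expects a JSON string and parses it back before invoking.
    function: { name, arguments: JSON.stringify(args) },
  };
}
```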

app/client/platforms/openai.ts

@@ -63,7 +63,7 @@ export interface RequestPayload {
   presence_penalty: number;
   frequency_penalty: number;
   top_p: number;
-  max_tokens?: number;
+  max_completions_tokens?: number;
 }

 export interface DalleRequestPayload {
@@ -228,13 +228,16 @@ export class ChatGPTApi implements LLMApi {
       presence_penalty: !isO1 ? modelConfig.presence_penalty : 0,
       frequency_penalty: !isO1 ? modelConfig.frequency_penalty : 0,
       top_p: !isO1 ? modelConfig.top_p : 1,
-      // max_tokens: Math.max(modelConfig.max_tokens, 1024),
-      // Please do not ask me why not send max_tokens, no reason, this param is just shit, I dont want to explain anymore.
+      // max_completions_tokens: Math.max(modelConfig.max_completions_tokens, 1024),
+      // Please do not ask me why not send max_completions_tokens, no reason, this param is just shit, I dont want to explain anymore.
     };

-    // add max_tokens to vision model
+    // add max_completions_tokens to vision model
     if (visionModel) {
-      requestPayload["max_tokens"] = Math.max(modelConfig.max_tokens, 4000);
+      requestPayload["max_completions_tokens"] = Math.max(
+        modelConfig.max_completions_tokens,
+        4000,
+      );
     }
   }
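Aside from the rename, behavior is unchanged: the token cap is still only sent for vision models, with a floor of 4000. A hedged sketch of that branch with a simplified config shape (the real modelConfig lives in the store):

```ts
// Simplified config shape for illustration only.
interface ModelConfigLike {
  max_completions_tokens: number;
}

// Returns the cap to attach to the request payload, or undefined when the
// field should be omitted entirely (non-vision models).
function visionTokenCap(
  modelConfig: ModelConfigLike,
  visionModel: boolean,
): number | undefined {
  if (!visionModel) return undefined;
  return Math.max(modelConfig.max_completions_tokens, 4000);
}
```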

app/components/auth.tsx

@@ -11,6 +11,7 @@ import Logo from "../icons/logo.svg";
 import { useMobileScreen } from "@/app/utils";
 import BotIcon from "../icons/bot.svg";
 import { getClientConfig } from "../config/client";
+import { PasswordInput } from "./ui-lib";
 import LeftIcon from "@/app/icons/left.svg";
 import { safeLocalStorage } from "@/app/utils";
 import {
@@ -60,36 +61,43 @@ export function AuthPage() {
       <div className={styles["auth-title"]}>{Locale.Auth.Title}</div>
       <div className={styles["auth-tips"]}>{Locale.Auth.Tips}</div>

-      <input
-        className={styles["auth-input"]}
-        type="password"
-        placeholder={Locale.Auth.Input}
+      <PasswordInput
+        style={{ marginTop: "3vh", marginBottom: "3vh" }}
+        aria={Locale.Settings.ShowPassword}
+        aria-label={Locale.Auth.Input}
         value={accessStore.accessCode}
+        type="text"
+        placeholder={Locale.Auth.Input}
         onChange={(e) => {
           accessStore.update(
             (access) => (access.accessCode = e.currentTarget.value),
           );
         }}
       />
+
       {!accessStore.hideUserApiKey ? (
         <>
           <div className={styles["auth-tips"]}>{Locale.Auth.SubTips}</div>
-          <input
-            className={styles["auth-input"]}
-            type="password"
-            placeholder={Locale.Settings.Access.OpenAI.ApiKey.Placeholder}
+          <PasswordInput
+            style={{ marginTop: "3vh", marginBottom: "3vh" }}
+            aria={Locale.Settings.ShowPassword}
+            aria-label={Locale.Settings.Access.OpenAI.ApiKey.Placeholder}
             value={accessStore.openaiApiKey}
+            type="text"
+            placeholder={Locale.Settings.Access.OpenAI.ApiKey.Placeholder}
             onChange={(e) => {
               accessStore.update(
                 (access) => (access.openaiApiKey = e.currentTarget.value),
               );
             }}
           />
-          <input
-            className={styles["auth-input-second"]}
-            type="password"
-            placeholder={Locale.Settings.Access.Google.ApiKey.Placeholder}
+          <PasswordInput
+            style={{ marginTop: "3vh", marginBottom: "3vh" }}
+            aria={Locale.Settings.ShowPassword}
+            aria-label={Locale.Settings.Access.Google.ApiKey.Placeholder}
             value={accessStore.googleApiKey}
+            type="text"
+            placeholder={Locale.Settings.Access.Google.ApiKey.Placeholder}
             onChange={(e) => {
               accessStore.update(
                 (access) => (access.googleApiKey = e.currentTarget.value),
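All three fields move from a raw input with className="auth-input" to the shared PasswordInput from ui-lib. Passing type="text" together with the Locale.Settings.ShowPassword aria label suggests the component manages masking itself and exposes a show/hide toggle, so visibility behavior is now consistent across the access-code and API-key fields.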

app/store/config.ts

@@ -65,7 +65,7 @@ export const DEFAULT_CONFIG = {
     providerName: "OpenAI" as ServiceProvider,
     temperature: 0.5,
     top_p: 1,
-    max_tokens: 4000,
+    max_completions_tokens: 4000,
     presence_penalty: 0,
     frequency_penalty: 0,
     sendMemory: true,

@@ -127,7 +127,7 @@ export const ModalConfigValidator = {
   model(x: string) {
     return x as ModelType;
   },
-  max_tokens(x: number) {
+  max_completions_tokens(x: number) {
     return limitNumber(x, 0, 512000, 1024);
   },
   presence_penalty(x: number) {
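The validator keeps the old max_tokens clamping, now keyed by the new name. A sketch of what limitNumber(x, 0, 512000, 1024) plausibly does, assuming the usual clamp-with-default pattern (the actual helper is defined elsewhere in the store and may differ):

```ts
// Assumed behavior: clamp x into [min, max]; fall back to defaultValue for
// non-numeric input.
function limitNumber(
  x: number,
  min: number,
  max: number,
  defaultValue: number,
): number {
  if (typeof x !== "number" || isNaN(x)) return defaultValue;
  return Math.min(max, Math.max(min, x));
}

limitNumber(700_000, 0, 512_000, 1024); // -> 512000
limitNumber(NaN, 0, 512_000, 1024); // -> 1024
```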

app/store/prompt.ts

@@ -151,7 +151,7 @@ export const usePromptStore = createPersistStore(
     if (typeof window === "undefined") {
       return;
     }

     const PROMPT_URL = "./prompts.json";
     type PromptList = Array<[string, string]>;

app/utils.ts

@@ -285,6 +285,9 @@ export function showPlugins(provider: ServiceProvider, model: string) {
   if (provider == ServiceProvider.Anthropic && !model.includes("claude-2")) {
     return true;
   }
+  if (provider == ServiceProvider.Google && !model.includes("vision")) {
+    return true;
+  }
   return false;
 }
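The new branch turns plugins on for Google models unless the model name contains "vision", mirroring the Anthropic carve-out above it. Illustrative calls (model names are examples only):

```ts
showPlugins(ServiceProvider.Google, "gemini-1.5-pro"); // true: plugins enabled
showPlugins(ServiceProvider.Google, "gemini-pro-vision"); // false: vision models excluded
showPlugins(ServiceProvider.Anthropic, "claude-2.1"); // false: claude-2 family excluded
```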

app/utils/chat.ts

@@ -250,6 +250,7 @@ export function stream(
               return e.toString();
             })
             .then((content) => ({
+              name: tool.function.name,
               role: "tool",
               content,
               tool_call_id: tool.id,
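This one-line addition matters for the Gemini path above: processToolMessage in google.ts reads result.name when it builds the functionResponse parts, so each tool result returned through stream() now carries the called function's name alongside the OpenAI-style tool_call_id.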

jest.config.ts (new file, 21 lines)

@@ -0,0 +1,21 @@
+import type { Config } from "jest";
+import nextJest from "next/jest.js";
+
+const createJestConfig = nextJest({
+  // Provide the path to your Next.js app to load next.config.js and .env files in your test environment
+  dir: "./",
+});
+
+// Add any custom config to be passed to Jest
+const config: Config = {
+  coverageProvider: "v8",
+  testEnvironment: "jsdom",
+  testMatch: ["**/*.test.js", "**/*.test.ts", "**/*.test.jsx", "**/*.test.tsx"],
+  setupFilesAfterEnv: ["<rootDir>/jest.setup.ts"],
+  moduleNameMapper: {
+    "^@/(.*)$": "<rootDir>/$1",
+  },
+};
+
+// createJestConfig is exported this way to ensure that next/jest can load the Next.js config which is async
+export default createJestConfig(config);
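Two details worth noting: moduleNameMapper mirrors the `@/*` path alias (presumably defined in tsconfig.json) so imports like `@/app/store` resolve inside tests, and setupFilesAfterEnv points at the jest.setup.ts added below, which registers the jest-dom matchers before each suite runs.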

jest.setup.ts (new file, 2 lines)

@@ -0,0 +1,2 @@
+// Learn more: https://github.com/testing-library/jest-dom
+import "@testing-library/jest-dom";

package.json

@@ -6,16 +6,18 @@
     "mask": "npx tsx app/masks/build.ts",
     "mask:watch": "npx watch \"yarn mask\" app/masks",
     "dev": "concurrently -r \"yarn run mask:watch\" \"next dev\"",
-    "build": "yarn mask && cross-env BUILD_MODE=standalone next build",
+    "build": "yarn test:ci && yarn mask && cross-env BUILD_MODE=standalone next build",
     "start": "next start",
     "lint": "next lint",
-    "export": "yarn mask && cross-env BUILD_MODE=export BUILD_APP=1 next build",
+    "export": "yarn test:ci && yarn mask && cross-env BUILD_MODE=export BUILD_APP=1 next build",
     "export:dev": "concurrently -r \"yarn mask:watch\" \"cross-env BUILD_MODE=export BUILD_APP=1 next dev\"",
     "app:dev": "concurrently -r \"yarn mask:watch\" \"yarn tauri dev\"",
-    "app:build": "yarn mask && yarn tauri build",
+    "app:build": "yarn test:ci && yarn mask && yarn tauri build",
     "prompts": "node ./scripts/fetch-prompts.mjs",
     "prepare": "husky install",
-    "proxy-dev": "sh ./scripts/init-proxy.sh && proxychains -f ./scripts/proxychains.conf yarn dev"
+    "proxy-dev": "sh ./scripts/init-proxy.sh && proxychains -f ./scripts/proxychains.conf yarn dev",
+    "test": "jest --watch",
+    "test:ci": "jest --ci"
   },
   "dependencies": {
     "@fortaine/fetch-event-source": "^3.0.6",

@@ -54,6 +56,9 @@
   "devDependencies": {
     "@tauri-apps/api": "^1.6.0",
     "@tauri-apps/cli": "1.5.11",
+    "@testing-library/jest-dom": "^6.4.8",
+    "@testing-library/react": "^16.0.0",
+    "@types/jest": "^29.5.12",
     "@types/js-yaml": "4.0.9",
     "@types/lodash-es": "^4.17.12",
     "@types/node": "^20.11.30",

@@ -69,8 +74,11 @@
     "eslint-plugin-prettier": "^5.1.3",
     "eslint-plugin-unused-imports": "^3.2.0",
     "husky": "^8.0.0",
+    "jest": "^29.7.0",
+    "jest-environment-jsdom": "^29.7.0",
     "lint-staged": "^13.2.2",
     "prettier": "^3.0.2",
+    "ts-node": "^10.9.2",
     "tsx": "^4.16.0",
     "typescript": "5.2.2",
     "watch": "^1.0.2",

test/sum-module.test.ts (new file, 9 lines)

@@ -0,0 +1,9 @@
+function sum(a: number, b: number) {
+  return a + b;
+}
+
+describe("sum module", () => {
+  test("adds 1 + 2 to equal 3", () => {
+    expect(sum(1, 2)).toBe(3);
+  });
+});

yarn.lock (1970 lines changed)

File diff suppressed because it is too large.