feat: add model config to settings

This commit is contained in:
Yifei Zhang
2023-03-21 16:20:32 +00:00
parent 4af8c26d02
commit 2f112ecc54
7 changed files with 517 additions and 258 deletions

View File

@@ -5,7 +5,7 @@ import { type ChatCompletionResponseMessage } from "openai";
import { requestChatStream, requestWithPrompt } from "./requests";
import { trimTopic } from "./utils";
import Locale from './locales'
import Locale from "./locales";
export type Message = ChatCompletionResponseMessage & {
date: string;
@@ -26,7 +26,7 @@ export enum Theme {
}
export interface ChatConfig {
maxToken?: number
maxToken?: number;
historyMessageCount: number; // -1 means all
compressMessageLengthThreshold: number;
sendBotMessages: boolean; // send bot's message or not
@@ -34,6 +34,78 @@ export interface ChatConfig {
avatar: string;
theme: Theme;
tightBorder: boolean;
modelConfig: {
model: string;
temperature: number;
max_tokens: number;
presence_penalty: number;
};
}
export type ModelConfig = ChatConfig["modelConfig"];
export const ALL_MODELS = [
{
name: "gpt-4",
available: false,
},
{
name: "gpt-4-0314",
available: false,
},
{
name: "gpt-4-32k",
available: false,
},
{
name: "gpt-4-32k-0314",
available: false,
},
{
name: "gpt-3.5-turbo",
available: true,
},
{
name: "gpt-3.5-turbo-0301",
available: true,
},
];
export function isValidModel(name: string) {
return ALL_MODELS.some((m) => m.name === name && m.available);
}
export function isValidNumber(x: number, min: number, max: number) {
return typeof x === "number" && x <= max && x >= min;
}
export function filterConfig(config: ModelConfig): Partial<ModelConfig> {
const validator: {
[k in keyof ModelConfig]: (x: ModelConfig[keyof ModelConfig]) => boolean;
} = {
model(x) {
return isValidModel(x as string);
},
max_tokens(x) {
return isValidNumber(x as number, 100, 4000);
},
presence_penalty(x) {
return isValidNumber(x as number, -2, 2);
},
temperature(x) {
return isValidNumber(x as number, 0, 1);
},
};
Object.keys(validator).forEach((k) => {
const key = k as keyof ModelConfig;
if (!validator[key](config[key])) {
delete config[key];
}
});
return config;
}
const DEFAULT_CONFIG: ChatConfig = {
@@ -44,6 +116,13 @@ const DEFAULT_CONFIG: ChatConfig = {
avatar: "1f603",
theme: Theme.Auto as Theme,
tightBorder: false,
modelConfig: {
model: "gpt-3.5-turbo",
temperature: 1,
max_tokens: 2000,
presence_penalty: 0,
},
};
export interface ChatStat {
@@ -107,7 +186,7 @@ interface ChatStore {
updater: (message?: Message) => void
) => void;
getMessagesWithMemory: () => Message[];
getMemoryPrompt: () => Message,
getMemoryPrompt: () => Message;
getConfig: () => ChatConfig;
resetConfig: () => void;
@@ -193,9 +272,9 @@ export const useChatStore = create<ChatStore>()(
},
onNewMessage(message) {
get().updateCurrentSession(session => {
session.lastUpdate = new Date().toLocaleString()
})
get().updateCurrentSession((session) => {
session.lastUpdate = new Date().toLocaleString();
});
get().updateStat(message);
get().summarizeSession();
},
@@ -214,9 +293,9 @@ export const useChatStore = create<ChatStore>()(
streaming: true,
};
// get recent messages
const recentMessages = get().getMessagesWithMemory()
const sendMessages = recentMessages.concat(userMessage)
// get recent messages
const recentMessages = get().getMessagesWithMemory();
const sendMessages = recentMessages.concat(userMessage);
// save user's and bot's message
get().updateCurrentSession((session) => {
@@ -224,12 +303,12 @@ export const useChatStore = create<ChatStore>()(
session.messages.push(botMessage);
});
console.log('[User Input] ', sendMessages)
console.log("[User Input] ", sendMessages);
requestChatStream(sendMessages, {
onMessage(content, done) {
if (done) {
botMessage.streaming = false;
get().onNewMessage(botMessage)
get().onNewMessage(botMessage);
} else {
botMessage.content = content;
set(() => ({}));
@@ -241,32 +320,35 @@ export const useChatStore = create<ChatStore>()(
set(() => ({}));
},
filterBot: !get().config.sendBotMessages,
modelConfig: get().config.modelConfig,
});
},
getMemoryPrompt() {
const session = get().currentSession()
const session = get().currentSession();
return {
role: 'system',
role: "system",
content: Locale.Store.Prompt.History(session.memoryPrompt),
date: ''
} as Message
date: "",
} as Message;
},
getMessagesWithMemory() {
const session = get().currentSession()
const config = get().config
const n = session.messages.length
const recentMessages = session.messages.slice(n - config.historyMessageCount);
const session = get().currentSession();
const config = get().config;
const n = session.messages.length;
const recentMessages = session.messages.slice(
n - config.historyMessageCount
);
const memoryPrompt = get().getMemoryPrompt()
const memoryPrompt = get().getMemoryPrompt();
if (session.memoryPrompt) {
recentMessages.unshift(memoryPrompt)
recentMessages.unshift(memoryPrompt);
}
return recentMessages
return recentMessages;
},
updateMessage(
@@ -286,49 +368,63 @@ export const useChatStore = create<ChatStore>()(
if (session.topic === DEFAULT_TOPIC && session.messages.length >= 3) {
// should summarize topic
requestWithPrompt(
session.messages,
Locale.Store.Prompt.Topic
).then((res) => {
get().updateCurrentSession(
(session) => (session.topic = trimTopic(res))
);
});
requestWithPrompt(session.messages, Locale.Store.Prompt.Topic).then(
(res) => {
get().updateCurrentSession(
(session) => (session.topic = trimTopic(res))
);
}
);
}
const config = get().config
let toBeSummarizedMsgs = session.messages.slice(session.lastSummarizeIndex)
const historyMsgLength = toBeSummarizedMsgs.reduce((pre, cur) => pre + cur.content.length, 0)
const config = get().config;
let toBeSummarizedMsgs = session.messages.slice(
session.lastSummarizeIndex
);
const historyMsgLength = toBeSummarizedMsgs.reduce(
(pre, cur) => pre + cur.content.length,
0
);
if (historyMsgLength > 4000) {
toBeSummarizedMsgs = toBeSummarizedMsgs.slice(-config.historyMessageCount)
toBeSummarizedMsgs = toBeSummarizedMsgs.slice(
-config.historyMessageCount
);
}
// add memory prompt
toBeSummarizedMsgs.unshift(get().getMemoryPrompt())
toBeSummarizedMsgs.unshift(get().getMemoryPrompt());
const lastSummarizeIndex = session.messages.length
const lastSummarizeIndex = session.messages.length;
console.log('[Chat History] ', toBeSummarizedMsgs, historyMsgLength, config.compressMessageLengthThreshold)
console.log(
"[Chat History] ",
toBeSummarizedMsgs,
historyMsgLength,
config.compressMessageLengthThreshold
);
if (historyMsgLength > config.compressMessageLengthThreshold) {
requestChatStream(toBeSummarizedMsgs.concat({
role: 'system',
content: Locale.Store.Prompt.Summarize,
date: ''
}), {
filterBot: false,
onMessage(message, done) {
session.memoryPrompt = message
if (done) {
console.log('[Memory] ', session.memoryPrompt)
session.lastSummarizeIndex = lastSummarizeIndex
}
},
onError(error) {
console.error('[Summarize] ', error)
},
})
requestChatStream(
toBeSummarizedMsgs.concat({
role: "system",
content: Locale.Store.Prompt.Summarize,
date: "",
}),
{
filterBot: false,
onMessage(message, done) {
session.memoryPrompt = message;
if (done) {
console.log("[Memory] ", session.memoryPrompt);
session.lastSummarizeIndex = lastSummarizeIndex;
}
},
onError(error) {
console.error("[Summarize] ", error);
},
}
);
}
},
@@ -348,8 +444,8 @@ export const useChatStore = create<ChatStore>()(
clearAllData() {
if (confirm(Locale.Store.ConfirmClearAll)) {
localStorage.clear()
location.reload()
localStorage.clear();
location.reload();
}
},
}),