mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-02-28 08:41:21 +08:00
fix: pass channel_type correctly in multi-channel mode
This commit is contained in:
@@ -451,8 +451,7 @@ def attach_scheduler_to_tool(tool, context: Context = None):
|
||||
if context:
|
||||
tool.current_context = context
|
||||
|
||||
# Also set channel_type from config
|
||||
channel_type = conf().get("channel_type", "unknown")
|
||||
channel_type = context.get("channel_type") or conf().get("channel_type", "unknown")
|
||||
if not tool.config:
|
||||
tool.config = {}
|
||||
tool.config["channel_type"] = channel_type
|
||||
|
||||
4
app.py
4
app.py
@@ -211,14 +211,14 @@ def run():
|
||||
sigterm_handler_wrap(signal.SIGTERM)
|
||||
|
||||
# Parse channel_type into a list
|
||||
raw_channel = conf().get("channel_type", "wx")
|
||||
raw_channel = conf().get("channel_type", "web")
|
||||
|
||||
if "--cmd" in sys.argv:
|
||||
channel_names = ["terminal"]
|
||||
else:
|
||||
channel_names = _parse_channel_type(raw_channel)
|
||||
if not channel_names:
|
||||
channel_names = ["wx"]
|
||||
channel_names = ["web"]
|
||||
|
||||
if "wxy" in channel_names:
|
||||
os.environ["WECHATY_LOG"] = "warn"
|
||||
|
||||
@@ -135,7 +135,7 @@ class AgentLLMModel(LLMModel):
|
||||
# Use tool-enabled streaming call if available
|
||||
# Extract system prompt if present
|
||||
system_prompt = getattr(request, 'system', None)
|
||||
|
||||
|
||||
# Build kwargs for call_with_tools
|
||||
kwargs = {
|
||||
'messages': request.messages,
|
||||
@@ -143,15 +143,20 @@ class AgentLLMModel(LLMModel):
|
||||
'stream': True,
|
||||
'model': self.model # Pass model parameter
|
||||
}
|
||||
|
||||
|
||||
# Only pass max_tokens if explicitly set, let the bot use its default
|
||||
if request.max_tokens is not None:
|
||||
kwargs['max_tokens'] = request.max_tokens
|
||||
|
||||
|
||||
# Add system prompt if present
|
||||
if system_prompt:
|
||||
kwargs['system'] = system_prompt
|
||||
|
||||
|
||||
# Pass channel_type for linkai tracking
|
||||
channel_type = getattr(self, 'channel_type', None)
|
||||
if channel_type:
|
||||
kwargs['channel_type'] = channel_type
|
||||
|
||||
stream = self.bot.call_with_tools(**kwargs)
|
||||
|
||||
# Convert stream format to our expected format
|
||||
@@ -325,6 +330,10 @@ class AgentBridge:
|
||||
logger.warning(f"[AgentBridge] Failed to attach context to scheduler: {e}")
|
||||
break
|
||||
|
||||
# Pass channel_type to model so linkai requests carry it
|
||||
if context and hasattr(agent, 'model'):
|
||||
agent.model.channel_type = context.get("channel_type", "")
|
||||
|
||||
# Record message count before execution so we can diff new messages
|
||||
with agent.messages_lock:
|
||||
pre_run_len = len(agent.messages)
|
||||
|
||||
@@ -322,7 +322,14 @@ class AgentInitializer:
|
||||
tool.scheduler_service = scheduler_service
|
||||
if not tool.config:
|
||||
tool.config = {}
|
||||
tool.config["channel_type"] = conf().get("channel_type", "unknown")
|
||||
raw_ct = conf().get("channel_type", "unknown")
|
||||
if isinstance(raw_ct, list):
|
||||
ct = raw_ct[0] if raw_ct else "unknown"
|
||||
elif isinstance(raw_ct, str) and "," in raw_ct:
|
||||
ct = raw_ct.split(",")[0].strip()
|
||||
else:
|
||||
ct = raw_ct
|
||||
tool.config["channel_type"] = ct
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Failed to inject scheduler dependencies: {e}")
|
||||
|
||||
@@ -369,7 +376,7 @@ class AgentInitializer:
|
||||
return {
|
||||
"model": conf().get("model", "unknown"),
|
||||
"workspace": workspace_root,
|
||||
"channel": conf().get("channel_type", "unknown"),
|
||||
"channel": ", ".join(conf().get("channel_type")) if isinstance(conf().get("channel_type"), list) else conf().get("channel_type", "unknown"),
|
||||
"_get_current_time": get_current_time # Dynamic time function
|
||||
}
|
||||
|
||||
|
||||
@@ -42,9 +42,8 @@ class ChatChannel(Channel):
|
||||
def _compose_context(self, ctype: ContextType, content, **kwargs):
|
||||
context = Context(ctype, content)
|
||||
context.kwargs = kwargs
|
||||
# context首次传入时,origin_ctype是None,
|
||||
# 引入的起因是:当输入语音时,会嵌套生成两个context,第一步语音转文本,第二步通过文本生成文字回复。
|
||||
# origin_ctype用于第二步文本回复时,判断是否需要匹配前缀,如果是私聊的语音,就不需要匹配前缀
|
||||
if "channel_type" not in context:
|
||||
context["channel_type"] = self.channel_type
|
||||
if "origin_ctype" not in context:
|
||||
context["origin_ctype"] = ctype
|
||||
# context首次传入时,receiver是None,根据类型设置receiver
|
||||
|
||||
@@ -698,6 +698,8 @@ class FeiShuChanel(ChatChannel):
|
||||
def _compose_context(self, ctype: ContextType, content, **kwargs):
|
||||
context = Context(ctype, content)
|
||||
context.kwargs = kwargs
|
||||
if "channel_type" not in context:
|
||||
context["channel_type"] = self.channel_type
|
||||
if "origin_ctype" not in context:
|
||||
context["origin_ctype"] = ctype
|
||||
|
||||
|
||||
@@ -107,7 +107,7 @@ class LinkAIBot(Bot, OpenAICompatibleBot):
|
||||
"presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"session_id": session_id,
|
||||
"sender_id": session_id,
|
||||
"channel_type": conf().get("channel_type", "wx")
|
||||
"channel_type": context.get("channel_type") or conf().get("channel_type", "web")
|
||||
}
|
||||
try:
|
||||
from linkai import LinkAIClient
|
||||
@@ -526,6 +526,14 @@ def _linkai_call_with_tools(self, messages, tools=None, stream=False, **kwargs):
|
||||
logger.debug(f"[LinkAI] messages: {len(messages)}, tools: {len(tools) if tools else 0}, stream: {stream}")
|
||||
|
||||
# Build request parameters (LinkAI uses OpenAI-compatible format)
|
||||
raw_ct = conf().get("channel_type", "web")
|
||||
if isinstance(raw_ct, list):
|
||||
channel_type = raw_ct[0] if raw_ct else "web"
|
||||
elif isinstance(raw_ct, str) and "," in raw_ct:
|
||||
channel_type = raw_ct.split(",")[0].strip()
|
||||
else:
|
||||
channel_type = raw_ct
|
||||
|
||||
body = {
|
||||
"messages": messages,
|
||||
"model": kwargs.get("model", conf().get("model") or "gpt-3.5-turbo"),
|
||||
@@ -533,7 +541,8 @@ def _linkai_call_with_tools(self, messages, tools=None, stream=False, **kwargs):
|
||||
"top_p": kwargs.get("top_p", conf().get("top_p", 1)),
|
||||
"frequency_penalty": kwargs.get("frequency_penalty", conf().get("frequency_penalty", 0.0)),
|
||||
"presence_penalty": kwargs.get("presence_penalty", conf().get("presence_penalty", 0.0)),
|
||||
"stream": stream
|
||||
"stream": stream,
|
||||
"channel_type": kwargs.get("channel_type", channel_type),
|
||||
}
|
||||
|
||||
if tools:
|
||||
|
||||
@@ -140,7 +140,9 @@ def get_help_text(isadmin, isgroup):
|
||||
for cmd, info in COMMANDS.items():
|
||||
if cmd in ["auth", "set_openai_api_key", "reset_openai_api_key", "set_gpt_model", "reset_gpt_model", "gpt_model"]: # 不显示帮助指令
|
||||
continue
|
||||
if cmd == "id" and conf().get("channel_type", "wx") not in ["wxy", "wechatmp"]:
|
||||
raw_ct = conf().get("channel_type", "web")
|
||||
active_channels = raw_ct if isinstance(raw_ct, list) else [c.strip() for c in str(raw_ct).split(",")]
|
||||
if cmd == "id" and not any(c in ["wxy", "wechatmp"] for c in active_channels):
|
||||
continue
|
||||
alias = ["#" + a for a in info["alias"][:1]]
|
||||
help_text += f"{','.join(alias)} "
|
||||
|
||||
Reference in New Issue
Block a user