Feng Chike Claude Opus 4.6 (1M context) committed on
Commit
408f650
·
0 Parent(s):

Freud Zero MVP: 心理咨询AI系统(清洁部署)

Browse files

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Files changed (13) hide show
  1. .gitignore +14 -0
  2. README.md +14 -0
  3. app.py +4 -0
  4. counselor.py +152 -0
  5. evaluator.py +35 -0
  6. main.py +227 -0
  7. mcts_reasoner.py +184 -0
  8. prompts.py +295 -0
  9. requirements.txt +5 -0
  10. session_logger.py +38 -0
  11. strategic_advisor.py +533 -0
  12. strategy_visualizer.py +126 -0
  13. supervisor_advisor.py +122 -0
.gitignore ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ venv/
2
+ __pycache__/
3
+ .gradio/
4
+ .DS_Store
5
+ sessions/
6
+ *.pyc
7
+ .env
8
+ test_*.py
9
+ test.py
10
+ *.output
11
+ .claude/
12
+ .worktrees/
13
+ INIT.md
14
+ freud_zero_3day_plan.md
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Freud-Zero MVP
3
+ emoji: 🧠
4
+ colorFrom: indigo
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: "5.0"
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ # Freud-Zero MVP
13
+
14
+ 精神动力学取向回应性咨询师 · V4 PUCT 战略推理版
app.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from main import app
2
+
3
# Script entry point: when this file is executed directly, serve the Gradio
# app imported from main on all network interfaces, port 7860.
if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=7860)
counselor.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import threading
4
+ from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
5
+ from langchain_openai import ChatOpenAI
6
+ from prompts import COUNSELOR_SYSTEM_PROMPT, STRATEGIC_GUIDANCE_TEMPLATE
7
+ from evaluator import DisclosureEvaluator
8
+ from strategic_advisor import StrategicAdvisor
9
+ from session_logger import SessionLogger
10
+ from strategy_visualizer import StrategyVisualizer
11
+
12
+
13
class PsychodynamicCounselor:
    """Counselling agent that pairs a front-line chat model with a strategic advisor.

    Every call to :meth:`respond` scores the client's message for
    self-disclosure depth, runs the strategic advisor synchronously, injects
    the resulting supervision guidance as an extra system message, asks the
    chat model for a reply and logs the turn.
    """

    def __init__(self):
        self.llm = ChatOpenAI(
            model="qwen-plus",
            base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
            api_key=os.getenv("DASHSCOPE_API_KEY"),
            temperature=0.7,
        )
        self.evaluator = DisclosureEvaluator()
        self.advisor = StrategicAdvisor()
        self.logger = SessionLogger()
        self.visualizer = StrategyVisualizer()
        self.history = [SystemMessage(content=COUNSELOR_SYSTEM_PROMPT)]
        self.turn_number = 0
        self.current_guidance = None
        self._last_disclosure_score = 1  # current disclosure level, used for relative scoring
        self._last_dimensions = {}  # most recent disclosure dimensions A-E
        self._last_reasoning = ""  # most recent evaluation rationale
        self._disclosure_history = []  # trajectory of disclosure scores
        self._last_trace_stats = None  # statistics of the most recent reasoning run
        self._pending_trace = None  # strategic trace finished by the background path
        self._bg_thread = None
        self._lock = threading.Lock()

    @staticmethod
    def _make_guidance(best, guidance):
        """Build the guidance dict stored in ``current_guidance``.

        Returns ``None`` when ``best`` is missing or its score is not positive.
        ``guidance`` may be ``None``; the direction then falls back to
        ``best["seed"]`` (looked up only when actually needed).
        """
        if not best or best.get("score", 0) <= 0:
            return None
        guidance = guidance or {}
        direction = guidance["direction"] if "direction" in guidance else best["seed"]
        return {
            "direction": direction,
            "principles": guidance.get("principles", []),
            "evidence": guidance.get("evidence", ""),
        }

    def _inject_guidance(self):
        """No-op hook kept for interface compatibility.

        Guidance is not written into ``history``; :meth:`respond` inserts it
        dynamically when it builds the message list for the model.
        """
        pass

    def _run_strategic_reasoning(self, history_snapshot, current_disclosure):
        """Run the advisor and publish its result for a later turn (background variant).

        Updates ``current_guidance`` and ``_pending_trace`` under the lock.
        Any failure is reported on stdout and otherwise ignored.
        """
        print(f"[战略推理] 第{self.turn_number}轮触发,当前揭露={current_disclosure},开始后台推理...")
        try:
            best, guidance, strategic_trace = self.advisor.run(history_snapshot, current_disclosure)
            new_guidance = self._make_guidance(best, guidance)
            guidance = guidance or {}
            with self._lock:
                if new_guidance is not None:
                    self.current_guidance = new_guidance
                    print(f"[战略推理] 完成! 选中: {best['id']}.{best['branch']} score={best['score']} delta={best['delta']}")
                    print(f" 方向: {guidance.get('direction', '?')}")
                    for p in guidance.get("principles", []):
                        print(f" 原则: {p}")
                else:
                    print("[战略推理] 完成,但未产生有效方向建议")
                self._pending_trace = strategic_trace
            # Produce the visual report for this reasoning run.
            self.visualizer.render(strategic_trace, self.turn_number)
        except Exception as e:
            print(f"[战略推理] 后台推理失败,跳过本轮: {e}")

    def _take_pending_trace(self):
        """Return and clear a trace left behind by the background path, if any."""
        with self._lock:
            trace, self._pending_trace = self._pending_trace, None
        return trace

    def _plan_strategy(self):
        """Run the advisor synchronously; return its trace, or ``None`` on failure.

        On success ``current_guidance`` is replaced only when the advisor
        produced a positively scored path, and ``_last_trace_stats`` is
        refreshed for the status panel.
        """
        print(f"[战略推理] 第{self.turn_number}轮,当前揭露={self._last_disclosure_score},同步推理中...")
        try:
            best, guidance, strategic_trace = self.advisor.run(
                list(self.history), self._last_disclosure_score
            )
            new_guidance = self._make_guidance(best, guidance)
            if new_guidance is not None:
                self.current_guidance = new_guidance
                print(f"[战略推理] 完成! {best['id']}.{best['branch']} score={best['score']} delta={best['delta']}")
            # A missing guidance dict must not discard an otherwise valid trace.
            guidance = guidance or {}
            self._last_trace_stats = {
                "total_paths": len(strategic_trace.get("candidates", [])),
                "deep_paths": len(strategic_trace.get("deep_paths", [])),
                "seeds": list(strategic_trace.get("seeds", {}).keys()),
                "selected": strategic_trace.get("selected", ""),
                "timing": strategic_trace.get("timing", {}),
                "best_score": best.get("score", 0) if best else 0,
                "best_delta": best.get("delta", 0) if best else 0,
                "predicted_disclosure": guidance.get("disclosure_level", best.get("score", "?")) if best else "?",
            }
            self.visualizer.render(strategic_trace, self.turn_number)
            return strategic_trace
        except Exception as e:
            print(f"[战略推理] 推理失败,跳过: {e}")
            return None

    def _compose_messages(self):
        """Build the message list for the chat model.

        When guidance exists it is inserted as a system message immediately
        before the newest user message; otherwise the history is sent as is.
        """
        if not self.current_guidance:
            print(f"[前台] 第{self.turn_number}轮 | 无督导指令")
            return self.history

        principles_text = "\n".join(f"- {p}" for p in self.current_guidance.get("principles", []))
        guidance_text = (
            STRATEGIC_GUIDANCE_TEMPLATE
            .replace("{direction}", self.current_guidance["direction"])
            .replace("{principles}", principles_text)
            .replace("{evidence}", self.current_guidance.get("evidence", ""))
        )
        print(f"[前台] 第{self.turn_number}轮 | 督导方向已注入: {self.current_guidance['direction'][:40]}")
        return self.history[:-1] + [SystemMessage(content=guidance_text)] + [self.history[-1]]

    def _invoke_with_retry(self, messages, attempts=3):
        """Call the chat model, retrying after a 1 s pause; re-raise the last failure.

        The retry exists to ride out rate-limit rejections under concurrency.
        """
        for attempt in range(attempts):
            try:
                return self.llm.invoke(messages)
            except Exception as e:
                if attempt == attempts - 1:
                    raise
                print(f"[前台] 调用失败({e}), 重试中...")
                time.sleep(1)

    def respond(self, user_message):
        """Process one client message and return the counsellor's reply text.

        Raises whatever the evaluator, the chat model (after retries) or the
        logger raises. In that case the turn is rolled back first: the
        messages added to ``history``, the recorded disclosure score and the
        turn counter are restored, so a failed turn does not leave a dangling
        user message behind for the next call.
        """
        history_mark = len(self.history)
        scores_mark = len(self._disclosure_history)
        self.turn_number += 1
        try:
            self.history.append(HumanMessage(content=user_message))

            # A trace left by the background path is logged unless this
            # turn's synchronous run supersedes it below.
            self._take_pending_trace()

            # Score the disclosure depth of the client's current message.
            disclosure_result = self.evaluator.evaluate_disclosure(user_message)
            self._last_disclosure_score = disclosure_result["score"]
            self._last_dimensions = disclosure_result.get("dimensions", {})
            self._last_reasoning = disclosure_result.get("reasoning", "")
            self._disclosure_history.append(disclosure_result["score"])

            # Synchronous strategic reasoning: reason first, then reply using
            # the supervision result. A failed run logs no trace at all.
            logged_trace = self._plan_strategy()

            response = self._invoke_with_retry(self._compose_messages())
            self.history.append(AIMessage(content=response.content))

            self.logger.log_turn(
                self.turn_number,
                user_message,
                response.content,
                disclosure_result["score"],
                disclosure_result["dimensions"],
                disclosure_result["reasoning"],
                mcts_trace=logged_trace,
            )
        except Exception:
            del self.history[history_mark:]
            del self._disclosure_history[scores_mark:]
            self.turn_number -= 1
            raise

        return response.content

    def get_session_filepath(self):
        """Return the path of this session's log file as reported by the logger."""
        return self.logger.get_filepath()
evaluator.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from langchain_openai import ChatOpenAI
4
+ from prompts import DISCLOSURE_EVAL_PROMPT
5
+
6
+
7
class DisclosureEvaluator:
    """Scores the self-disclosure depth of a client message with an LLM rubric (1-10)."""

    def __init__(self):
        self.llm = ChatOpenAI(
            model="qwen-turbo",
            base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
            api_key=os.getenv("DASHSCOPE_API_KEY"),
            temperature=0.0,
            max_tokens=256,
        )

    @staticmethod
    def _parse_result(content):
        """Turn raw model output into ``{"score", "dimensions", "reasoning"}``.

        The JSON object is taken as the span from the first ``{`` to the last
        ``}``. The score is coerced to ``int`` and clamped to 1..10.

        Raises:
            ValueError: when no JSON object can be parsed, the payload is not
                an object, or the score is not convertible to an integer
                (``json.JSONDecodeError`` is a ``ValueError`` subclass).
        """
        text = content.strip()
        start = text.find("{")
        end = text.rfind("}") + 1
        if start == -1 or end <= start:
            raise ValueError("no JSON object in evaluator output")
        result = json.loads(text[start:end])
        if not isinstance(result, dict):
            raise ValueError("evaluator output is not a JSON object")
        try:
            score = int(result.get("score", 1))
        except TypeError as exc:  # e.g. "score": null
            raise ValueError("score is not a number") from exc
        dimensions = result.get("dimensions", {})
        if not isinstance(dimensions, dict):
            dimensions = {}
        return {
            "score": max(1, min(10, score)),
            "dimensions": dimensions,
            "reasoning": result.get("reasoning", ""),
        }

    def evaluate_disclosure(self, user_message, max_retries=2):
        """Evaluate ``user_message`` and return its score, dimensions and reasoning.

        The model is asked up to ``max_retries + 1`` times; if every answer is
        unparsable a fallback with score 1 is returned. Errors raised by the
        model call itself are not caught here.
        """
        # The template escapes its JSON example with doubled braces, i.e. it
        # is written for str.format; str.replace would send "{{ ... }}"
        # verbatim to the model.
        prompt = DISCLOSURE_EVAL_PROMPT.format(user_message=user_message)

        for _ in range(max_retries + 1):
            try:
                llm_message = self.llm.invoke(prompt)
                return self._parse_result(llm_message.content)
            except ValueError:
                continue  # unparsable answer: try again
        return {"score": 1, "dimensions": {}, "reasoning": "评估解析失败"}
main.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import gradio as gr
3
+ from pathlib import Path
4
+ from counselor import PsychodynamicCounselor
5
+
6
+ counselor = None
7
+
8
+
9
def build_status_panel(c):
    """Render the session-monitor panel for counselor ``c`` as a Markdown string.

    Reads ``c._last_disclosure_score``, ``c.turn_number``,
    ``c._disclosure_history``, ``c._last_dimensions``, ``c.current_guidance``,
    ``c._last_trace_stats`` and ``c._last_reasoning``; nothing is modified.
    """
    score = c._last_disclosure_score
    turn = c.turn_number
    history = c._disclosure_history

    # Ten-cell progress bar for the current disclosure depth.
    filled = "█" * score + "░" * (10 - score)

    # Trend arrow compares the last two recorded scores.
    if len(history) >= 2:
        diff = history[-1] - history[-2]
        trend = "↑" if diff > 0 else ("↓" if diff < 0 else "→")
    else:
        trend = "·"

    # One indicator light per disclosure dimension.
    dims = c._last_dimensions
    dim_labels = {
        "A": "具体事件", "B": "情绪表达", "C": "具体情绪",
        "D": "自我反思", "E": "回避触及",
    }
    dim_line = " ".join(
        f"{'🟢' if dims.get(k, False) else '⚫'} {k}:{label}"
        for k, label in dim_labels.items()
    )

    # Sparkline over the last 20 scores. The index is clamped to the glyph
    # string's own length so the top scores can never index past its end.
    spark_chars = " ▁▂▃▄▅▆▇█"
    top = len(spark_chars) - 1
    spark = "".join(spark_chars[max(0, min(s, top))] for s in history[-20:]) or "—"

    lines = [
        "### 🧠 SESSION MONITOR",
        "",
        "| 指标 | 值 |",
        "|:---|:---|",
        f"| **轮次** | `{turn}` |",
        f"| **揭露深度** | `{filled}` **{score}/10** {trend} |",
        f"| **深度轨迹** | `{spark}` |",
        "",
        f"**维度分析** {dim_line}",
    ]

    guidance = c.current_guidance
    if guidance:
        # Supervision guidance currently in force.
        principles = guidance.get("principles", [])
        evidence = guidance.get("evidence", "")

        lines += ["", "---", "#### ▸ 督导指令", f"> **方向**: {guidance.get('direction', '—')}"]
        if principles:
            lines.append(">")
            lines.extend(f"> · {p}" for p in principles[:3])
        if evidence:
            lines.append(">")
            lines.append(f"> **证据**: {evidence[:80]}{'...' if len(evidence) > 80 else ''}")

        # Statistics of the reasoning run that produced the guidance.
        ts = c._last_trace_stats
        if ts:
            timing = ts.get("timing", {})
            total_s = timing.get("total_seconds", 0)
            total_paths = ts.get("total_paths", 0)
            deep = ts.get("deep_paths", 0)
            seeds = ts.get("seeds", [])
            selected = ts.get("selected", "")
            best_score = ts.get("best_score", 0)
            best_delta = ts.get("best_delta", 0)

            lines += [
                "",
                "#### ▸ PUCT 推理引擎",
                "| 参数 | 值 |",
                "|:---|:---|",
                "| **搜索树** | L1→L2→L3→L4→L5→L6 (6层) |",
                f"| **种子方向** | {', '.join(seeds)} |",
                f"| **候选路径** | {total_paths} 条 → 深探 {deep} 条 |",
                f"| **最优路径** | `{selected}` score={best_score} Δ={best_delta:+.1f} |",
                f"| **推理耗时** | {total_s}s |",
            ]
    else:
        lines += ["", "---", "#### ▸ 督导引擎", "> ⏳ 每轮同步推理中…"]

    # Collapsible rationale of the latest disclosure evaluation.
    if c._last_reasoning:
        lines.append("")
        lines.append(f"<details><summary>📋 评估理由</summary>\n\n{c._last_reasoning}\n\n</details>")

    return "\n".join(lines)
109
+
110
+
111
def start_session():
    """Replace the module-level counselor with a fresh one.

    Returns an empty chat history and the initial status-panel Markdown.
    """
    global counselor
    counselor = PsychodynamicCounselor()
    empty_chat = []
    banner = "### 🧠 SESSION MONITOR\n\n> 新会话已开始,等待来访者发言…"
    return empty_chat, banner
115
+
116
+
117
def chat(user_message, chat_history):
    """Handle one submitted message and refresh the status panel.

    Returns a 4-tuple in the order of the event outputs wired in this file:
    chat history, new textbox value, status Markdown, log textbox value.
    A counselor is created on demand when no session is active.

    NOTE(review): ``counselor`` is a single module-level object, so it looks
    like every browser session served by this process shares one conversation
    and one log. Confirm whether per-visitor state is needed.
    """
    global counselor
    if counselor is None:
        counselor = PsychodynamicCounselor()

    if not user_message or not user_message.strip():
        # Nothing to send: clear the textbox but leave the status panel and
        # the log view untouched (returning "" for them would blank both).
        return chat_history, "", gr.update(), gr.update()

    response = counselor.respond(user_message)

    chat_history = chat_history or []
    chat_history.append({"role": "user", "content": user_message})
    chat_history.append({"role": "assistant", "content": response})

    # Rebuild the rich-text status panel from the counselor's latest state.
    status = build_status_panel(counselor)

    return chat_history, "", status, ""
135
+
136
+
137
def end_session():
    """Drop the active counselor and report where its session log was saved."""
    global counselor
    if not counselor:
        return "当前无活跃会话。"
    path = counselor.get_session_filepath()
    counselor = None
    return f"会话已结束。日志保存于:{path}"
144
+
145
+
146
def view_sessions():
    """Render every ``sessions/session_*.json`` log as plain text.

    Returns a placeholder string when no log exists. A file that cannot be
    read or parsed is reported in the output and skipped instead of aborting
    the whole view.
    """
    files = sorted(Path("sessions").glob("session_*.json"))
    if not files:
        return "暂无会话记录"

    rule = "=" * 50
    parts = []
    for f in files:
        try:
            with open(f, encoding="utf-8") as fp:
                data = json.load(fp)
            turns = data["turns"]
            session_id = data["session_id"]
        except (OSError, json.JSONDecodeError, KeyError, TypeError) as exc:
            parts.append(f"\n[跳过无法读取的会话文件 {f.name}: {exc}]\n")
            continue

        total = data.get("total_turns", len(turns))
        parts.append(f"\n{rule}\n")
        parts.append(f"Session: {session_id} | 轮次: {total}\n")
        parts.append(f"{rule}\n")

        for t in turns:
            score = t.get("disclosure_score", "?")
            # Disclosure scores are on a 1-10 scale everywhere else in the app.
            parts.append(f"\n[轮次 {t['turn_number']}] 揭露评分: {score}/10\n")
            parts.append(f"来访者: {t['user_message']}\n")
            parts.append(f"咨询师: {t['counselor_message']}\n")
            dims = t.get("dimension_score", {})
            reason = t.get("reason", "")
            if dims:
                parts.append(f"维度: {dims}\n")
            if reason:
                parts.append(f"理由: {reason}\n")

            # Strategic-reasoning record attached to this turn, if present.
            trace = t.get("mcts_trace")
            if trace and "selected_direction" in trace:
                selected = trace.get("selected", "?")
                parts.append(f"\n === 战略推理(第{t['turn_number']}轮触发) ===\n")
                parts.append(f" 总结: {trace.get('summary', '')[:80]}...\n")
                parts.append(f" 选中: {selected} → {trace['selected_direction'][:50]}\n")
                parts.append(f" 预测揭露: {trace.get('selected_score', '?')}/10\n")
                for d in trace.get("directions", []):
                    d_id = d.get("id", "?")
                    marker = " ★" if d_id == selected else ""
                    parts.append(
                        f" [{d_id}]{marker} {d.get('direction', '')[:40]} → 揭露={d.get('disclosure_level', '?')}/10\n"
                    )
                parts.append(" ========================\n")

    return "".join(parts)
186
+
187
+
188
def download_all_sessions():
    """Return the saved session logs as a sorted list of path strings, or None if there are none."""
    paths = [str(p) for p in sorted(Path("sessions").glob("session_*.json"))]
    return paths or None
193
+
194
+
195
# Gradio UI: chat area, live status panel and a collapsible researcher panel.
# NOTE(review): chat() appends {"role", "content"} dicts to the chatbot value;
# confirm the pinned Gradio version accepts that format by default.
with gr.Blocks(title="Freud-Zero MVP") as app:
    gr.Markdown("# Freud-Zero MVP")
    gr.Markdown("精神动力学取向回应性咨询师 · 自我揭露深度追踪")

    with gr.Row():
        btn_start = gr.Button("开始新会话", variant="primary")
        btn_end = gr.Button("结束会话", variant="stop")

    chatbot = gr.Chatbot(label="对话", height=480)

    with gr.Row():
        user_input = gr.Textbox(placeholder="说你想说的……", show_label=False, scale=4)
        btn_send = gr.Button("发送", scale=1)

    status_output = gr.Markdown(value="### 🧠 SESSION MONITOR\n\n> 等待开始会话…")

    with gr.Accordion("研究者面板", open=False):
        with gr.Row():
            btn_view = gr.Button("查看所有会话记录")
            btn_download = gr.Button("下载日志文件")
        log_display = gr.Textbox(label="会话日志", lines=20, interactive=False)
        file_output = gr.File(label="日志文件")

    # Event wiring. The send button and the Enter key share one handler and
    # one input/output mapping.
    btn_start.click(start_session, outputs=[chatbot, status_output])
    btn_end.click(end_session, outputs=[status_output])
    chat_io = dict(
        inputs=[user_input, chatbot],
        outputs=[chatbot, user_input, status_output, log_display],
    )
    for trigger in (btn_send.click, user_input.submit):
        trigger(chat, **chat_io)
    btn_view.click(view_sessions, outputs=[log_display])
    btn_download.click(download_all_sessions, outputs=[file_output])

if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=7860, share=True)
mcts_reasoner.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from concurrent.futures import ThreadPoolExecutor, as_completed
4
+ from langchain_openai import ChatOpenAI
5
+ from langchain_core.messages import HumanMessage, AIMessage
6
+ from prompts import (
7
+ CANDIDATE_GENERATION_PROMPT,
8
+ CLIENT_SIMULATOR_PROMPT,
9
+ MCTS_EVALUATOR_PROMPT,
10
+ )
11
+
12
+
13
class MCTSReasoner:
    """One-step look-ahead: generate candidate replies, simulate the client, score, pick the best.

    NOTE(review): the three prompt constants this module imports were not
    found in the prompts.py added by the same commit. Confirm the module is
    still importable before relying on it.
    """

    def __init__(self):
        base_kwargs = dict(
            model="deepseek-chat",
            base_url="https://api.deepseek.com/v1",
            api_key=os.getenv("DEEPSEEK_API_KEY"),
        )
        self.gen_llm = ChatOpenAI(**base_kwargs, temperature=0.7, max_tokens=1024)
        self.sim_llm = ChatOpenAI(**base_kwargs, temperature=0.7, max_tokens=256)
        self.eval_llm = ChatOpenAI(**base_kwargs, temperature=0.0, max_tokens=64)

    def _format_history(self, history):
        """Format the message history as readable text, skipping system messages."""
        lines = []
        for msg in history:
            if isinstance(msg, HumanMessage):
                lines.append(f"来访者:{msg.content}")
            elif isinstance(msg, AIMessage):
                lines.append(f"咨询师:{msg.content}")
        return "\n".join(lines) if lines else "(首次对话)"

    def _parse_json(self, text):
        """Return the first JSON value (object or array) embedded in ``text``.

        Decoding is attempted at each ``{`` or ``[`` from left to right, so an
        object whose strings or values contain brackets is returned whole
        rather than being cut down to its inner ``[...]`` span.

        Raises:
            ValueError: when no position yields a valid JSON value.
        """
        content = text.strip()
        decoder = json.JSONDecoder()
        for pos, ch in enumerate(content):
            if ch not in "[{":
                continue
            try:
                value, _ = decoder.raw_decode(content, pos)
            except json.JSONDecodeError:
                continue
            return value
        raise ValueError(f"无法解析 JSON: {content[:100]}")

    @staticmethod
    def _as_candidate_list(parsed):
        """Normalise parsed generator output to a non-empty list of candidate dicts.

        A wrapping object such as ``{"candidates": [...]}`` is unwrapped to its
        first list value. Raises ValueError when the result is not a list of
        dicts that each carry ``id`` and ``response``.
        """
        if isinstance(parsed, dict):
            parsed = next((v for v in parsed.values() if isinstance(v, list)), None)
        if not isinstance(parsed, list) or not parsed:
            raise ValueError("candidate output is not a non-empty JSON array")
        for item in parsed:
            if not isinstance(item, dict) or "id" not in item or "response" not in item:
                raise ValueError("candidate entry lacks 'id' or 'response'")
        return parsed

    def generate_candidates(self, history, user_message):
        """Step 1: ask the generator model for candidate counsellor replies.

        Retries up to three times on unparsable output, then re-raises.
        """
        prompt = CANDIDATE_GENERATION_PROMPT.replace(
            "{conversation_history}", self._format_history(history)
        ).replace("{user_message}", user_message)

        for attempt in range(3):
            try:
                result = self.gen_llm.invoke(prompt)
                return self._as_candidate_list(self._parse_json(result.content))
            except (json.JSONDecodeError, ValueError):
                if attempt == 2:
                    raise

    def _simulate_one(self, candidate, history_text, user_message):
        """Simulate the client's reaction to one candidate reply.

        Two attempts are made; if both answers are unparsable, a placeholder
        result is returned.
        """
        prompt = CLIENT_SIMULATOR_PROMPT.replace(
            "{conversation_history}", history_text
        ).replace(
            "{user_message}", user_message
        ).replace(
            "{therapist_response}", candidate["response"]
        )

        for attempt in range(2):
            try:
                result = self.sim_llm.invoke(prompt)
                parsed = self._parse_json(result.content)
                if not isinstance(parsed, dict):
                    raise ValueError("simulator output is not a JSON object")
                return {
                    "id": candidate["id"],
                    "simulated_client_response": parsed.get("simulated_response", ""),
                    "emotional_state": parsed.get("emotional_state", ""),
                }
            except (json.JSONDecodeError, ValueError):
                if attempt == 1:
                    return {
                        "id": candidate["id"],
                        "simulated_client_response": "(模拟失败)",
                        "emotional_state": "未知",
                    }

    def simulate_client_reactions(self, candidates, history, user_message):
        """Step 2: simulate the client's reaction to every candidate in parallel.

        Results are returned sorted by candidate id.
        """
        history_text = self._format_history(history)
        simulations = []

        with ThreadPoolExecutor(max_workers=5) as executor:
            futures = {
                executor.submit(
                    self._simulate_one, c, history_text, user_message
                ): c["id"]
                for c in candidates
            }
            for future in as_completed(futures):
                simulations.append(future.result())

        simulations.sort(key=lambda x: x["id"])
        return simulations

    def _evaluate_one(self, simulation):
        """Score the disclosure depth of one simulated reaction (clamped to 0..10).

        Two attempts are made; if both answers are unusable, score 0 is returned.
        """
        prompt = MCTS_EVALUATOR_PROMPT.replace(
            "{client_response}", simulation["simulated_client_response"]
        )

        for attempt in range(2):
            try:
                result = self.eval_llm.invoke(prompt)
                parsed = self._parse_json(result.content)
                if not isinstance(parsed, dict):
                    raise ValueError("evaluator output is not a JSON object")
                return {
                    "id": simulation["id"],
                    "score": max(0, min(10, int(parsed.get("score", 0)))),
                    "reason": parsed.get("reason", ""),
                }
            # TypeError covers a null score, which int() rejects.
            except (json.JSONDecodeError, ValueError, TypeError):
                if attempt == 1:
                    return {"id": simulation["id"], "score": 0, "reason": "评估解析失败"}

    def evaluate_disclosures(self, simulations):
        """Step 3: score every simulated reaction in parallel, sorted by id."""
        evaluations = []

        with ThreadPoolExecutor(max_workers=5) as executor:
            futures = {
                executor.submit(self._evaluate_one, s): s["id"] for s in simulations
            }
            for future in as_completed(futures):
                evaluations.append(future.result())

        evaluations.sort(key=lambda x: x["id"])
        return evaluations

    def select_best(self, candidates, simulations, evaluations):
        """Step 4: pick the highest-scoring candidate.

        Ties are broken by the longest ``emotional_state`` text (a rough proxy
        for emotional depth); the earliest such candidate wins.

        Returns:
            ``(best_id, best_response, reason)``.

        Raises:
            ValueError: when ``evaluations`` is empty or the winning id has no
                matching candidate.
        """
        if not evaluations:
            raise ValueError("no evaluations to select from")

        max_score = max(e["score"] for e in evaluations)
        top_candidates = [e for e in evaluations if e["score"] == max_score]

        if len(top_candidates) == 1:
            best_id = top_candidates[0]["id"]
            reason = "最高揭露深度评分"
        else:
            depth_by_id = {}
            for sim in simulations:
                # First simulation per id wins, as in a linear search.
                depth_by_id.setdefault(sim["id"], len(sim.get("emotional_state", "")))
            best_id = max(top_candidates, key=lambda e: depth_by_id.get(e["id"], 0))["id"]
            reason = "同分中情感深度更高"

        for c in candidates:
            if c["id"] == best_id:
                return best_id, c["response"], reason
        raise ValueError(f"no candidate with id {best_id!r}")

    def run(self, history, user_message):
        """Run the whole pipeline; return the best reply and the full trace."""
        candidates = self.generate_candidates(history, user_message)
        simulations = self.simulate_client_reactions(candidates, history, user_message)
        evaluations = self.evaluate_disclosures(simulations)
        selected_id, best_response, selection_reason = self.select_best(
            candidates, simulations, evaluations
        )

        mcts_trace = {
            "candidates": candidates,
            "simulations": simulations,
            "evaluations": evaluations,
            "selected": selected_id,
            "selection_reason": selection_reason,
        }

        return best_response, mcts_trace
prompts.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ COUNSELOR_SYSTEM_PROMPT = """你是一位精神动力学取向的心理咨询师,正在和来访者进行一对一的咨询。
2
+
3
+ ## 你是什么样的咨询师
4
+ 你是一个真实的人在做咨询。你说话就像平时说话一样——直接、自然、不表演。不要用括号描述动作或状态(如"(沉默)""(轻轻叹气)"),不要描写自己的表情或肢体语言。你只需要说话,或者不说话。
5
+
6
+ ## 怎么回应
7
+ - 少说。你的工作是让来访者说,不是你说。
8
+ - 大多数时候一句话就够。有时候半句话、一个词、甚至不说话。
9
+ - 绝对不要超过两句话。如果你想说第三句,删掉前两句,只留最准的那一句。
10
+ - 不要复述、不要摘要、不要解释你为什么这么问。
11
+ - 不要用"我能感受到""听起来你……"这类套话。
12
+ - 一次只做一件事:要么点一下情绪,要么问一个问题,要么就在那儿。
13
+
14
+ ## 干预方式
15
+ - 当来访者正在展开的时候,不要打断。一个"嗯"或者短暂的沉默就是最好的回应。
16
+ - 当你要说话时,试着命名来访者可能还没完全意识到的那个情绪——不是解释,是轻轻点一下。
17
+ - 提问时只问一个。问那个来访者似乎在绕着走的东西。
18
+ - 诠释要非常谨慎。用"我在想……""不知道是不是……"开头,让来访者有拒绝的空间。
19
+
20
+ ## 不要做的事
21
+ - 不诊断、不给建议、不布置作业。
22
+ - 不鼓励、不安慰、不扮演朋友。
23
+ - 不用CBT的方式说话("换个角度想""这个想法合理吗")。
24
+ - 不要用身体隐喻或躯体化意象(如"胸口像石头""吞了玻璃""卡在喉咙")。来访者自己说可以,你不要主动去猜测或命名身体感受。
25
+
26
+ ## 危机处理
27
+ 如来访者表达自杀/自伤意图,立即回应:
28
+ "我听到你说的了,这让我很担心你的安全。我希望你现在就拨打400-161-9995,这是24小时心理危机热线。你愿意现在拨打吗?"
29
+ 然后停止一切其他话题。
30
+
31
+ ## 语言
32
+ 口语化中文,像真人说话一样自然。不要书面腔。不用emoji。
33
+ """
34
+
35
+ DISCLOSURE_EVAL_PROMPT = """你是自我揭露深度评估器。先归纳维度,再对照表打分(10分制)。
36
+
37
+ ## 第一步:逐条提取(用是/否回答)
38
+ A. 是否提及了具体的个人事件或经历(而非泛泛而谈)?
39
+ B. 是否表达了情绪或情感体验(而非纯粹叙事)?
40
+ C. 情绪表达是否具体而非笼统?("我很烦"=笼统,"我对他的忽视感到愤怒"=具体)
41
+ D. 是否涉及自我认知、关系模式或内心冲突?
42
+ E. 是否触及了通常会回避的主题,或伴有犹豫/修正的表达方式?
43
+
44
+ ## 第二步:对照表评分(就低不就高)
45
+ 1分:完全回避或拒绝回应。
46
+ 2分:A否+B否。纯寒暄、闲聊,无个人信息。
47
+ 3分:A是+B否。提到个人事件但零情感。
48
+ 4分:B是+C否。有情绪但笼统("不太好""有点烦")。
49
+ 5分:C是+D否。情绪具体但没有自我反思。
50
+ 6分:C是+D是。情绪具体,且有自我模式的反思。
51
+ 7分:D是+E部分。深度自我反思,接近回避内容但未完全展开。
52
+ 8分:E是。触及核心回避主题,有犹豫但在说。
53
+ 9分:E是+情感强度高。触及回避主题并伴随强烈情感。
54
+ 10分:完全突破防御,袒露最脆弱的核心体验。
55
+
56
+ ## 输出格式(严格JSON)
57
+ {{"dimensions": {{"A": true/false, "B": true/false, "C": true/false, "D": true/false, "E": true/false}}, "score": <1-10>, "reasoning": "<一句话,不超过50字>"}}
58
+
59
+ ## 来访者本轮发言:
60
+ {user_message}
61
+ """
62
+
63
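+ # ===== Strategic-reasoning prompts (counselor.respond runs the advisor on every turn; the earlier "every 5 turns" note was stale) =====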
64
+
65
+ SUMMARY_AND_SEEDS_PROMPT = """你是精神动力学督导团协调者。完成两件事:
66
+
67
+ ## 对话记录
68
+ {conversation_history}
69
+
70
+ ## 任务
71
+ 1. 用2句话总结:核心议题+防御模式
72
+ 2. 生成3个不同的临床切入点(具体到此刻状态,不涉及身体感受)
73
+
74
+ 只输出JSON(总共不超过150字):{{"summary":"2句话总结","A":"切入点A","B":"切入点B","C":"切入点C"}}"""
75
+
76
+ L2_MERGED_PROMPT = """你是来访者心理模拟器。基于对话和咨询师最新回应,直接生成2个不同的来访者回应。
77
+
78
+ ## 对话历史
79
+ {conversation_history}
80
+
81
+ ## 咨询师刚说
82
+ {therapist_reply}
83
+
84
+ ## 要求
85
+ 2个回应必须反映来访者此刻听到这句话后可能的不同内在反应。每个回应1-2句,口语化,符合来访者一贯风格。
86
+
87
+ 只输出JSON:{{"A":"来访者回应A","B":"来访者回应B"}}"""
88
+
89
+ L3_MERGED_PROMPT = """你是精神动力学咨询师。对话中发生了以下交流,现在生成3个不同方向的咨询师回应。
90
+
91
+ ## 对话记录
92
+ {conversation_history}
93
+
94
+ ## 刚才的交流(模拟)
95
+ 咨询师:{l1_therapist_reply}
96
+ 来访者:{l2_client_response}
97
+
98
+ ## 要求
99
+ 3个回应从不同临床角度出发,各一句话,口语化中文,不要套话,不要问身体感受。
100
+
101
+ 只输出JSON:{{"A":"咨询师回应A","B":"咨询师回应B","C":"咨询师回应C"}}"""
102
+
103
+ L4_MERGED_PROMPT = """你是来访者心理模拟器。基于完整交流脉络,生成2个不同的来访者回应。
104
+
105
+ ## 对话历史
106
+ {conversation_history}
107
+
108
+ ## 模拟交流
109
+ 咨询师①:{l1_therapist_reply}
110
+ 来访者①:{l2_client_response}
111
+ 咨询师②:{l3_therapist_reply}
112
+
113
+ ## 要求
114
+ 2个回应反映来访者此刻可能的不同反应。1-2句,口语化,符合来访者风格。
115
+
116
+ 只输出JSON:{{"A":"来访者回应A","B":"来访者回应B"}}"""
117
+
118
+ QUICK_EVAL_PROMPT = """你是自我揭露度快速评估器。评估来访者这句话的揭露深度(1-10分)。
119
+ 1=回避 2=寒暄 3=事件无情感 4=笼统情绪 5=具体情绪 6=情绪+反思 7=深度反思 8=触及回避主题 9=回避+强烈情感 10=完全突破
120
+ 当前揭露水平:{current_disclosure_score}分
121
+ 来访者说:{user_message}
122
+ 只输出JSON:{{"score":<1-10>}}"""
123
+
124
+ L5_L6_MERGED_PROMPT = """你是精神动力学对话模拟器。基于完整4轮交流脉络,生成第三轮:咨询师回应+来访者反应。
125
+
126
+ ## 对话历史
127
+ {conversation_history}
128
+
129
+ ## 模拟交流(前2轮)
130
+ 咨询师①:{l1_therapist_reply}
131
+ 来访者①:{l2_client_response}
132
+ 咨询师②:{l3_therapist_reply}
133
+ 来访者②:{l4_client_response}
134
+
135
+ ## 要求
136
+ 1. 咨询师第三轮:基于来访者②的回应,写一句最有推动力的咨询师回应(精神动力学取向,口语化,不要套话)
137
+ 2. 来访者第三轮:基于咨询师第三轮的回应,模拟来访者最可能的反应(口语化,符合来访者风格)
138
+
139
+ 只输出JSON:{{"l5_reply":"咨询师第三轮","l6_client":"来访者第三轮"}}"""
140
+
141
+ SESSION_SUMMARY_PROMPT = """你是精神动力学督导。总结以下咨询对话的走向。
142
+
143
+ ## 对话记录
144
+ {conversation_history}
145
+
146
+ ## 要求
147
+ 总结来访者的核心议题、情感状态、防御模式、治疗联盟状态。3-5句话。
148
+
149
+ 只输出JSON:{{"summary":"总结内容"}}"""
150
+
151
+ SEED_GENERATION_PROMPT = """你是精神动力学督导团的协调者。基于会话总结和最近对话,为5位督导各生成一个独特的观察切入点。
152
+
153
+ ## 会话总结
154
+ {summary}
155
+
156
+ ## 最近对话
157
+ {recent_history}
158
+
159
+ ## 要求
160
+ 5个切入点必须彼此不同,覆盖不同的临床维度(不限制,请自由发挥)。每个切入点一句话,要具体到这个来访者此刻的状态。注意:不要从身体感受或躯体体验入手。
161
+
162
+ 只输出JSON:{{"A":"切入点A","B":"切入点B","C":"切入点C","D":"切入点D","E":"切入点E"}}"""
163
+
164
+ THERAPIST_REPLY_PROMPT = """你是精神动力学咨询师。基于对话历史,从以下切入点出发,写一句咨询师的回应。
165
+
166
+ ## 对话记录
167
+ {conversation_history}
168
+
169
+ ## 切入点
170
+ {seed_perspective}
171
+
172
+ 要求:一句话,最多两句。口语化中文,精准,不用套话。不要问身体感受。只输出咨询师说的话。"""
173
+
174
+ CLIENT_RESPONSE_PROMPT = """你是来访者。根据对话历史和你的性格,从指定的回应方向出发,回应咨询师的最新发言。保持与之前一致的语言风格和防御水平。
175
+
176
+ ## 对话历史
177
+ {conversation_history}
178
+
179
+ ## 咨询师刚说
180
+ {therapist_reply}
181
+
182
+ ## 你的回应方向
183
+ {client_direction}
184
+
185
+ 只输出来访者说的话,不要任何其他内容。"""
186
+
187
+ L2_DIRECTION_PROMPT = """你是来访者心理模拟器。基于对话历史和咨询师刚才的话,生成3个来访者可能的回应方向。
188
+
189
+ ## 对话历史
190
+ {conversation_history}
191
+
192
+ ## 咨询师刚说
193
+ {therapist_reply}
194
+
195
+ ## 要求
196
+ 3个方向必须彼此不同,反映来访者可能的不同心理状态:
197
+ - 可能更防御/回避
198
+ - 可能开始松动/试探性回应
199
+ - 可能意外打开/情绪流露
200
+ 每个方向一句话描述来访者的内在状态和可能的反应倾向。
201
+
202
+ 只输出JSON:{{"A":"方向A","B":"方向B","C":"方向C"}}"""
203
+
204
+
205
+ L3_SEED_GENERATION_PROMPT = """你是精神动力学督导。基于以下对话和最新一轮模拟交流,为下一步探索生成5个不同的切入方向。
206
+
207
+ ## 对话记录
208
+ {conversation_history}
209
+
210
+ ## 最新交流(模拟)
211
+ 咨询师:{l1_therapist_reply}
212
+ 来访者:{l2_client_response}
213
+
214
+ ## 要求
215
+ 5个方向必须彼此不同,针对来访者刚才的回应中可以深入的不同面向。每个方向一句话。不要从身体感受入手。
216
+
217
+ 只输出JSON:{{"A":"方向A","B":"方向B","C":"方向C","D":"方向D","E":"方向E"}}"""
218
+
219
+ THERAPIST_CONTINUATION_PROMPT = """你是精神动力学咨询师。对话中发生了以下交流,现在从指定方向写你的下一句回应。
220
+
221
+ ## 对话记录
222
+ {conversation_history}
223
+
224
+ ## 刚才的交流(模拟)
225
+ 咨询师:{l1_therapist_reply}
226
+ 来访者:{l2_client_response}
227
+
228
+ ## 你的探索方向
229
+ {l3_seed}
230
+
231
+ 要求:一句话,最多两句。自然延续,不重复之前说过的。不要问身体感受。只输出咨询师说的话。"""
232
+
233
+ RELATIVE_DISCLOSURE_EVAL_PROMPT = """你是自我揭露深度评估器。先提取信息,再严格打分(10分制)。
234
+
235
+ ## 第一步:逐条提取(用是/否回答)
236
+ A. 是否提及了之前未说过的具体个人事件或经历?(重复已说过的不算)
237
+ B. 是否表达了情绪或情感体验?
238
+ C. 情绪表达是否具体而非笼统?("我很烦"=笼统,"我对他的忽视感到愤怒"=具体)
239
+ D. 是否涉及自我认知、关系模式或内心冲突的反思?
240
+ E. 是否触及了通常会回避的主题,或伴有犹豫/修正/欲言又止?
241
+
242
+ ## 第二步:对照表评分(10分制,就低不就高)
243
+ 1分:完全回避或拒绝回应("不想说""没什么")
244
+ 2分:A否+B否。纯寒暄、闲聊、泛泛而谈,无个人信息。
245
+ 3分:A是+B否。提到个人事件但零情感(纯叙事)。
246
+ 4分:B是+C否。有情绪但笼统("不太好""有点烦""挺累的")。
247
+ 5分:C是+D否。情绪具体(能说出是什么情绪、针对谁),但没有自我反思。
248
+ 6分:C是+D是。情绪具体,且开始反思自己的模式("我好像总是……")。
249
+ 7分:D是+E部分。有深度的自我反思,开始接近通常回避的内容但还没完全展开。
250
+ 8分:E是。触及核心回避主题(创伤、羞耻、秘密),有犹豫但在说。
251
+ 9分:E是+情感强度高。不仅触及回避主题,而且伴随强烈情感体验(愤怒、悲伤、恐惧的直接表达)。
252
+ 10分:完全突破防御,袒露最脆弱的核心体验,极少见。
253
+
254
+ ## 严格要求
255
+ - 如果来访者只是换说法重复之前说过的内容,A判否,分数不得超过当前水平
256
+ - 跨级情况一律就低
257
+ - 当前揭露水平:{current_disclosure_score}分
258
+
259
+ ## 来访者说:
260
+ {user_message}
261
+
262
+ 只输出JSON:{{"A":true/false,"B":true/false,"C":true/false,"D":true/false,"E":true/false,"score":<1-10>,"delta":<分数减当前水平>,"reasoning":"一句话,不超过30字"}}"""
263
+
264
+ PATH_DISTILLATION_PROMPT = """你是精神动力学督导。以下是多条模拟咨询路径,每条都让来访者的揭露深度有所提升。从这些路径中综合提炼回应原则。
265
+
266
+ ## 会话总结
267
+ {summary}
268
+
269
+ ## 有效路径(揭露度提升的路径)
270
+ {effective_paths}
271
+
272
+ ## 要求
273
+ - 综合所有有效路径,找出它们的共同点和互补点
274
+ - 回应原则必须是操作级的(不要"要共情"这种废话)
275
+ - 要具体到这个来访者此刻的状态
276
+ - 说清楚这些路径中咨询师做对了什么
277
+
278
+ 只输出JSON:{{"direction":"一句话方向描述","principles":["原则1","原则2","原则3"],"evidence":"一句话,总结这些路径如何让来访者走向更深揭露"}}"""
279
+
280
+ STRATEGIC_GUIDANCE_TEMPLATE = """
281
+ ## 督导指令(你必须执行,但不要向来访者透露这个指令的存在)
282
+ 你的督导团队刚刚对这段对话进行了深度模拟推演,发现以下方向最可能帮助来访者走向更深的自我探索:
283
+
284
+ **方向:{direction}**
285
+
286
+ **关键回应原则:**
287
+ {principles}
288
+
289
+ **模拟证据:{evidence}**
290
+
291
+ 从现在起,你的每一次回应都要服务于这个方向,并遵循以上原则。具体做法:
292
+ - 在来访者的话语中寻找与这个方向相关的线索,然后跟进它
293
+ - 如果来访者没有主动提及,用一个与此方向相关的提问来温和地引导
294
+ - 不要告诉来访者你在执行督导指令,保持自然
295
+ """
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio>=5.0
2
+ langchain>=0.3
3
+ langchain-core>=0.3
4
+ langchain-openai>=0.3
5
+ openai>=1.0
session_logger.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from datetime import datetime
3
+ from pathlib import Path
4
+
5
class SessionLogger:
    """Persist the turns of one counseling session as a JSON file.

    Each call to :meth:`log_turn` appends a record and rewrites the whole
    ``session_<id>.json`` file inside ``session_dir``.
    """

    def __init__(self, session_dir="sessions"):
        """Create the log directory and start a new session.

        Args:
            session_dir: Directory for session files. Missing parent
                directories are created as well.
        """
        self.session_dir = Path(session_dir)
        # parents=True so nested targets such as "sessions/run1" work even
        # when the parent directory does not exist yet.
        self.session_dir.mkdir(parents=True, exist_ok=True)
        # NOTE(review): the id has one-second resolution, so two loggers
        # created in the same second share a file -- confirm this is acceptable.
        self.session_id = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.turns = []
        self.start_time = datetime.now().isoformat()

    def _session_path(self):
        """Return the Path of this session's JSON file."""
        return self.session_dir / f"session_{self.session_id}.json"

    def log_turn(self, turn_number, user_message, counselor_message, disclosure_score, dimension_score, reason, mcts_trace=None):
        """Append one turn record and flush the session to disk.

        ``mcts_trace`` is stored under the ``"mcts_trace"`` key only when it
        is not ``None``.
        """
        turn_log = {
            "timestamp": datetime.now().isoformat(),
            "turn_number": turn_number,
            "user_message": user_message,
            "counselor_message": counselor_message,
            "disclosure_score": disclosure_score,
            "dimension_score": dimension_score,
            "reason": reason,
        }
        if mcts_trace is not None:
            turn_log["mcts_trace"] = mcts_trace
        self.turns.append(turn_log)
        self._save()

    def _save(self):
        """Rewrite the session file with every turn logged so far."""
        session_log = {
            "session_id": self.session_id,
            "start_time": self.start_time,
            "turns": self.turns,
        }
        with open(self._session_path(), "w", encoding="utf-8") as f:
            json.dump(session_log, f, ensure_ascii=False, indent=4)

    def get_filepath(self):
        """Return the session file path as a string."""
        return str(self._session_path())
strategic_advisor.py ADDED
@@ -0,0 +1,533 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import math
3
+ import os
4
+ import time
5
+ from collections import Counter, defaultdict
6
+ from concurrent.futures import ThreadPoolExecutor, as_completed
7
+ from langchain_openai import ChatOpenAI
8
+ from langchain_core.messages import HumanMessage, AIMessage
9
+ from prompts import (
10
+ SUMMARY_AND_SEEDS_PROMPT,
11
+ THERAPIST_REPLY_PROMPT, L2_MERGED_PROMPT, L3_MERGED_PROMPT, L4_MERGED_PROMPT,
12
+ QUICK_EVAL_PROMPT, L5_L6_MERGED_PROMPT,
13
+ RELATIVE_DISCLOSURE_EVAL_PROMPT, PATH_DISTILLATION_PROMPT,
14
+ )
15
+
16
+
17
class StrategicAdvisor:
    """PUCT-style advisor: UCB-driven budget allocation plus variable-depth exploration.

    The pipeline simulates therapist/client exchanges with an LLM
    (L1 therapist -> L2 client -> L3 therapist -> L4 client, optionally
    L5/L6), scores the simulated client turns for disclosure depth, and
    distills the best paths into guidance for the real counselor.
    """

    def __init__(self, c_puct=1.5):
        """Create the LLM client and store the exploration constant ``c_puct``."""
        dashscope_base = dict(
            base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
            api_key=os.getenv("DASHSCOPE_API_KEY"),
        )
        self.llm = ChatOpenAI(model="qwen-turbo", **dashscope_base, temperature=0.7, max_tokens=256)
        self.c_puct = c_puct

    # ===== Utilities =====

    @staticmethod
    def _map_parallel(fn, items, max_workers=None):
        """Apply ``fn`` to every item on a thread pool; results keep input order.

        Returns ``[]`` for empty input, because ``ThreadPoolExecutor`` rejects
        ``max_workers=0``. ``max_workers`` caps the pool size; by default one
        thread per item is used.
        """
        items = list(items)
        if not items:
            return []
        workers = len(items) if max_workers is None else max(1, min(max_workers, len(items)))
        with ThreadPoolExecutor(max_workers=workers) as executor:
            return list(executor.map(fn, items))

    @staticmethod
    def _signed(delta):
        """Format a score delta with an explicit sign (``+2``, ``+0``, ``-1``)."""
        if isinstance(delta, (int, float)) and delta < 0:
            return str(delta)
        return f"+{delta}"

    def _format_history(self, history):
        """Render Human/AI messages as "client:/therapist:" lines; other message types are skipped."""
        lines = []
        for msg in history:
            if isinstance(msg, HumanMessage):
                lines.append(f"来访者:{msg.content}")
            elif isinstance(msg, AIMessage):
                lines.append(f"咨询师:{msg.content}")
        return "\n".join(lines) if lines else "(无)"

    def _parse_json(self, text):
        """Parse the span from the first ``{`` to the last ``}`` in ``text`` as JSON.

        Raises:
            ValueError: if no brace pair is found (``json.JSONDecodeError``,
                a ``ValueError`` subclass, is raised for malformed JSON).
        """
        content = text.strip()
        start = content.find("{")
        end = content.rfind("}") + 1
        if start == -1 or end == 0:
            raise ValueError(f"无法解析 JSON: {content[:100]}")
        return json.loads(content[start:end])

    # ===== PUCT core =====

    def compute_ucb(self, paths):
        """Write a ``ucb`` value into every path dict and return ``paths``.

        ``ucb = q + c_puct * sqrt(n_total) / (1 + n_seed)``, where ``q`` is the
        path score divided by 10 and ``n_seed`` is the number of paths sharing
        the same seed ``id``. The final ``score`` is preferred when present;
        ``_quick_score`` is only a fallback, because scored paths still carry
        the quick score inherited from an earlier layer.
        """
        seed_counts = Counter(p["id"] for p in paths)
        n_total = len(paths)
        for p in paths:
            raw = p["score"] if "score" in p else p.get("_quick_score", 1)
            q = raw / 10.0
            n_seed = seed_counts[p["id"]]
            exploration = self.c_puct * math.sqrt(n_total) / (1 + n_seed)
            p["ucb"] = q + exploration
        return paths

    def allocate_budget(self, paths, total_budget, min_per_seed=1):
        """Select up to ``total_budget`` paths by descending UCB.

        Every seed first receives up to ``min_per_seed`` slots (in UCB order);
        the remaining budget is filled by UCB rank. Paths are tracked by
        identity, so two distinct paths with equal content are both eligible.
        """
        if total_budget <= 0 or not paths:
            return []
        self.compute_ucb(paths)
        ranked = sorted(paths, key=lambda x: -x["ucb"])

        selected = []
        chosen = set()
        per_seed = Counter()

        # Guaranteed slots: each seed gets its best path(s) first.
        for p in ranked:
            if id(p) in chosen or per_seed[p["id"]] >= min_per_seed:
                continue
            selected.append(p)
            chosen.add(id(p))
            per_seed[p["id"]] += 1
            if len(selected) >= total_budget:
                return selected

        # Remaining budget: fill strictly by UCB rank.
        for p in ranked:
            if len(selected) >= total_budget:
                break
            if id(p) not in chosen:
                selected.append(p)
                chosen.add(id(p))

        return selected

    # ===== Quick scoring =====

    def _quick_score_one(self, text, current_disclosure):
        """Score one text with the quick-eval prompt; returns 1 on any failure."""
        prompt = QUICK_EVAL_PROMPT.replace(
            "{current_disclosure_score}", str(current_disclosure)
        ).replace("{user_message}", text)
        try:
            result = self.llm.invoke(prompt)
            parsed = self._parse_json(result.content)
            return max(1, min(10, int(parsed.get("score", 1))))
        except Exception:
            return 1

    def quick_score_paths(self, paths, text_key, current_disclosure):
        """Quick-score ``p[text_key]`` for every path in parallel; stores ``_quick_score``."""
        scores = self._map_parallel(
            lambda p: self._quick_score_one(p[text_key], current_disclosure), paths
        )
        for p, score in zip(paths, scores):
            p["_quick_score"] = score
        return paths

    # ===== Step 1: summary + 3 seeds =====

    def summarize_and_seeds(self, history):
        """Return ``(summary, seeds)`` where ``seeds`` maps "A"/"B"/"C" to a perspective.

        Retries up to three times on unparsable output, then falls back to a
        generic summary and seeds. Missing seeds get the generic perspective.
        """
        seed_ids = ("A", "B", "C")
        fallback_seed = "从你独特的临床视角出发"
        prompt = SUMMARY_AND_SEEDS_PROMPT.replace(
            "{conversation_history}", self._format_history(history)
        )
        for attempt in range(3):
            try:
                result = self.llm.invoke(prompt)
                parsed = self._parse_json(result.content)
                summary = parsed.pop("summary", "总结失败")
                if not isinstance(summary, str):
                    summary = str(summary)
                # Exact key match: `k in "ABC"` would also accept "" or "AB".
                seeds = {
                    k: (v if isinstance(v, str) else str(v))
                    for k, v in parsed.items() if k in seed_ids
                }
                for k in seed_ids:
                    seeds.setdefault(k, fallback_seed)
                return summary, seeds
            except (json.JSONDecodeError, ValueError):
                if attempt == 2:
                    return "总结失败", {k: fallback_seed for k in seed_ids}

    # ===== Step 2 / L1: one therapist reply per seed =====

    def _gen_therapist_reply(self, seed_id, seed, history_text):
        """Generate one L1 therapist reply; on failure the reply carries the error text."""
        prompt = THERAPIST_REPLY_PROMPT.replace(
            "{conversation_history}", history_text
        ).replace("{seed_perspective}", seed)
        try:
            result = self.llm.invoke(prompt)
            return {"id": seed_id, "seed": seed, "reply": result.content.strip()}
        except Exception as e:
            return {"id": seed_id, "seed": seed, "reply": f"(生成失败: {e})"}

    def generate_l1(self, seeds, history):
        """Generate one L1 reply per seed, sorted by seed id."""
        history_text = self._format_history(history)
        results = self._map_parallel(
            lambda pair: self._gen_therapist_reply(pair[0], pair[1], history_text),
            seeds.items(),
            max_workers=3,
        )
        results.sort(key=lambda x: x["id"])
        return results

    # ===== Step 3 / L2: two simulated client responses per L1 reply =====

    def _gen_l2_merged(self, l1_item, history_text):
        """Return two L2 items (directions A and B) for one L1 item."""
        prompt = L2_MERGED_PROMPT.replace(
            "{conversation_history}", history_text
        ).replace("{therapist_reply}", l1_item["reply"])
        try:
            result = self.llm.invoke(prompt)
            parsed = self._parse_json(result.content)
            return [{**l1_item, "l2_dir": did, "client_response": parsed.get(did, "(模拟失败)")}
                    for did in ["A", "B"]]
        except Exception:
            return [{**l1_item, "l2_dir": d, "client_response": "(模拟失败)"} for d in ["A", "B"]]

    def generate_l2(self, l1_results, history):
        """Expand every L1 item into its L2 items, sorted by (id, l2_dir)."""
        history_text = self._format_history(history)
        groups = self._map_parallel(
            lambda item: self._gen_l2_merged(item, history_text), l1_results, max_workers=3
        )
        results = [entry for group in groups for entry in group]
        results.sort(key=lambda x: (x["id"], x["l2_dir"]))
        return results

    # ===== Step 4 / L3: three therapist continuations per L2 item =====

    def _gen_l3_merged(self, l2_item, history_text):
        """Return three L3 items (branches A, B, C) for one L2 item."""
        prompt = L3_MERGED_PROMPT.replace(
            "{conversation_history}", history_text
        ).replace("{l1_therapist_reply}", l2_item["reply"]
        ).replace("{l2_client_response}", l2_item["client_response"])

        def build(branch, reply):
            return {
                "id": l2_item["id"], "l2_dir": l2_item["l2_dir"], "branch": branch,
                "seed": l2_item["seed"], "l1_reply": l2_item["reply"],
                "l2_client": l2_item["client_response"], "l3_reply": reply,
            }

        try:
            result = self.llm.invoke(prompt)
            parsed = self._parse_json(result.content)
            return [build(bid, parsed.get(bid, "(生成失败)")) for bid in ["A", "B", "C"]]
        except Exception:
            return [build(b, "(生成失败)") for b in ["A", "B", "C"]]

    def generate_l3(self, l2_selected, history):
        """Expand the selected L2 items into L3 items, sorted by (id, l2_dir, branch)."""
        history_text = self._format_history(history)
        groups = self._map_parallel(
            lambda item: self._gen_l3_merged(item, history_text), l2_selected
        )
        results = [entry for group in groups for entry in group]
        results.sort(key=lambda x: (x["id"], x["l2_dir"], x["branch"]))
        return results

    # ===== Step 5 / L4: two simulated client responses per L3 item =====

    def _gen_l4_merged(self, l3_item, history_text):
        """Return two L4 items (directions A and B) for one L3 item."""
        prompt = L4_MERGED_PROMPT.replace(
            "{conversation_history}", history_text
        ).replace("{l1_therapist_reply}", l3_item["l1_reply"]
        ).replace("{l2_client_response}", l3_item["l2_client"]
        ).replace("{l3_therapist_reply}", l3_item["l3_reply"])
        try:
            result = self.llm.invoke(prompt)
            parsed = self._parse_json(result.content)
            return [{**l3_item, "l4_dir": did, "l4_client": parsed.get(did, "(模拟失败)")}
                    for did in ["A", "B"]]
        except Exception:
            return [{**l3_item, "l4_dir": d, "l4_client": "(模拟失败)"} for d in ["A", "B"]]

    def generate_l4(self, l3_selected, history):
        """Expand the selected L3 items into L4 items, sorted by (id, l2_dir, branch, l4_dir)."""
        history_text = self._format_history(history)
        groups = self._map_parallel(
            lambda item: self._gen_l4_merged(item, history_text), l3_selected
        )
        results = [entry for group in groups for entry in group]
        results.sort(key=lambda x: (x["id"], x["l2_dir"], x["branch"], x.get("l4_dir", "")))
        return results

    # ===== Step 5.5: final scoring =====

    def _score_relative(self, item, current_disclosure):
        """Score ``item["l4_client"]`` with the full rubric; score 1 on failure."""
        prompt = RELATIVE_DISCLOSURE_EVAL_PROMPT.replace(
            "{current_disclosure_score}", str(current_disclosure)
        ).replace("{user_message}", item["l4_client"])
        try:
            result = self.llm.invoke(prompt)
            parsed = self._parse_json(result.content)
            score = max(1, min(10, int(parsed.get("score", 1))))
            dims = {k: parsed.get(k, False) for k in "ABCDE"}
            return {**item, "score": score, "delta": score - current_disclosure,
                    "dims": dims, "reason": parsed.get("reasoning", "")}
        except Exception:
            return {**item, "score": 1, "delta": 1 - current_disclosure,
                    "dims": {}, "reason": "评分失败"}

    def score_all(self, l4_results, current_disclosure=1):
        """Score every L4 item in parallel, sorted by (id, l2_dir, branch)."""
        results = self._map_parallel(
            lambda item: self._score_relative(item, current_disclosure), l4_results
        )
        results.sort(key=lambda x: (x["id"], x.get("l2_dir", ""), x["branch"]))
        return results

    # ===== Step 6: deep exploration of high-UCB paths (L5 + L6) =====

    def _gen_l5_l6(self, item, history_text):
        """Extend one scored path by a therapist turn (L5) and a client turn (L6)."""
        prompt = L5_L6_MERGED_PROMPT.replace(
            "{conversation_history}", history_text
        ).replace("{l1_therapist_reply}", item["l1_reply"]
        ).replace("{l2_client_response}", item["l2_client"]
        ).replace("{l3_therapist_reply}", item["l3_reply"]
        ).replace("{l4_client_response}", item["l4_client"])
        try:
            result = self.llm.invoke(prompt)
            parsed = self._parse_json(result.content)
            return {**item,
                    "l5_reply": parsed.get("l5_reply", "(生成失败)"),
                    "l6_client": parsed.get("l6_client", "(模拟失败)"),
                    "depth": 6}
        except Exception:
            return {**item, "l5_reply": "(生成失败)", "l6_client": "(模拟失败)", "depth": 6}

    def _score_l6(self, item, current_disclosure):
        """Score ``item["l6_client"]``; on failure reuse the L4 score with delta 0."""
        prompt = RELATIVE_DISCLOSURE_EVAL_PROMPT.replace(
            "{current_disclosure_score}", str(current_disclosure)
        ).replace("{user_message}", item["l6_client"])
        try:
            result = self.llm.invoke(prompt)
            parsed = self._parse_json(result.content)
            score = max(1, min(10, int(parsed.get("score", 1))))
            return {**item, "l6_score": score, "l6_delta": score - current_disclosure,
                    "reason": parsed.get("reasoning", item.get("reason", ""))}
        except Exception:
            return {**item, "l6_score": item.get("score", 1), "l6_delta": 0}

    def deep_explore(self, top_paths, history, current_disclosure):
        """Run L5+L6 generation and then L6 scoring for the given paths."""
        history_text = self._format_history(history)
        deep_results = self._map_parallel(
            lambda item: self._gen_l5_l6(item, history_text), top_paths
        )
        return self._map_parallel(
            lambda item: self._score_l6(item, current_disclosure), deep_results
        )

    # ===== Step 7: distillation =====

    def distill_paths(self, scored_4layer, deep_paths, summary):
        """Distill the best 4-turn and 6-turn paths into guidance via the LLM.

        4-turn paths are included only when ``delta > 0``; every deep path is
        included. Paths are ranked by score, with 6-turn paths multiplied by a
        1.5 depth bonus, and the top five are sent to the model. If the LLM
        call or parsing fails, a generic guidance dict built from the best
        path's seed is returned.
        """
        effective_4 = [item for item in scored_4layer if item.get("delta", 0) > 0]
        effective_6 = []
        for item in deep_paths:
            # Marker kept on the item for downstream consumers; the ranking
            # below uses the 1.5 depth bonus, not this value.
            item["_distill_weight"] = 2
            effective_6.append(item)

        all_effective = effective_4 + effective_6
        if not all_effective:
            # Degenerate case: fall back to the best 4-turn path, if any.
            ranked = sorted(scored_4layer, key=lambda x: x["score"], reverse=True)
            all_effective = [ranked[0]] if ranked else []

        def sort_key(x):
            base = x.get("l6_score", x.get("score", 0))
            depth_bonus = 1.5 if x.get("depth") == 6 else 1.0
            return base * depth_bonus

        ranked = sorted(all_effective, key=sort_key, reverse=True)
        top = ranked[:5]

        path_texts = []
        for i, item in enumerate(top, 1):
            depth = item.get("depth", 4)
            if depth == 6:
                path_texts.append(
                    f"路径{i}(种子{item['id']}.{item['branch']},深度=6轮,揭露度{self._signed(item.get('l6_delta', 0))}):\n"
                    f" 咨询师①:{item['l1_reply']}\n"
                    f" 来访者①:{item['l2_client']}\n"
                    f" 咨询师②:{item['l3_reply']}\n"
                    f" 来访者②:{item['l4_client']}\n"
                    f" 咨询师③:{item['l5_reply']}\n"
                    f" 来访者③:{item['l6_client']}"
                )
            else:
                path_texts.append(
                    f"路径{i}(种子{item['id']}.{item['branch']},深度=4轮,揭露度{self._signed(item.get('delta', 0))}):\n"
                    f" 咨询师①:{item['l1_reply']}\n"
                    f" 来访者①:{item['l2_client']}\n"
                    f" 咨询师②:{item['l3_reply']}\n"
                    f" 来访者②:{item['l4_client']}"
                )
        effective_paths_text = "\n\n".join(path_texts)

        prompt = PATH_DISTILLATION_PROMPT.replace(
            "{summary}", summary
        ).replace("{effective_paths}", effective_paths_text)

        n_deep = sum(1 for t in top if t.get("depth") == 6)
        seeds_in = set(t["id"] for t in top)
        print(f"[PUCT] 蒸馏输入: {len(top)}条路径({n_deep}条6轮深度, {len(seeds_in)}个种子覆盖)")
        try:
            result = self.llm.invoke(prompt)
            parsed = self._parse_json(result.content)
            parsed["_distill_count"] = len(top)
            parsed["_deep_count"] = n_deep
            parsed["_distill_ids"] = [f"{i['id']}.{i['branch']}" for i in top]
            return parsed
        except Exception:
            # Guard against having no path at all (both inputs empty).
            if top:
                best = top[0]
            elif scored_4layer:
                best = scored_4layer[0]
            else:
                best = None
            seed_text = best.get("seed", "") if best else ""
            return {
                "direction": seed_text,
                "principles": [f"沿着「{seed_text}」的方向继续探索"],
                "evidence": f"模拟显示{len(top)}条路径有效",
                "_distill_count": len(top), "_deep_count": 0,
                "_distill_ids": [f"{best['id']}.{best['branch']}"] if best else [],
            }

    # ===== Full PUCT pipeline =====

    def run(self, history, current_disclosure=1):
        """Run the full pipeline and return ``(best, guidance, strategic_trace)``.

        ``best`` is the highest-scoring path (a deep path replaces the 4-turn
        best only if its L6 score beats the best value seen so far),
        ``guidance`` is the distilled dict, and ``strategic_trace`` records
        candidates, deep paths, the selected path id and per-step timing.
        """
        total_start = time.time()

        # Step 1: summary + 3 seeds
        t = time.time()
        summary, seeds = self.summarize_and_seeds(history)
        t1 = time.time() - t
        print(f"[PUCT] Step1 总结+种子: {t1:.1f}s | {summary[:60]}")
        for sid, seed in seeds.items():
            print(f" {sid}: {seed[:50]}")

        # Step 2 / L1
        t = time.time()
        l1 = self.generate_l1(seeds, history)
        t2 = time.time() - t
        print(f"[PUCT] L1 {len(l1)}×咨询师: {t2:.1f}s")

        # Step 3 / L2
        t = time.time()
        l2 = self.generate_l2(l1, history)
        t3 = time.time() - t
        print(f"[PUCT] L2 {len(l2)}×来访者: {t3:.1f}s")

        # Step 3.5: quick-score L2
        t = time.time()
        l2 = self.quick_score_paths(l2, "client_response", current_disclosure)
        t3_5 = time.time() - t
        print(f"[PUCT] L2快评: {t3_5:.1f}s")
        for item in l2:
            print(f" L2-{item['id']}.{item['l2_dir']}: qs={item['_quick_score']} | {item['client_response'][:30]}")

        # Step 4: UCB selection -> L3 (at most 6 L2 items)
        l2_budget = min(6, len(l2))
        l2_selected = self.allocate_budget(l2, l2_budget)
        print(f"[PUCT] UCB选择L2→L3: {len(l2_selected)}条 (from {len(l2)})")
        for item in l2_selected:
            print(f" 选中 {item['id']}.{item['l2_dir']}: ucb={item['ucb']:.2f} qs={item['_quick_score']}")

        t = time.time()
        l3 = self.generate_l3(l2_selected, history)
        t4 = time.time() - t
        print(f"[PUCT] L3 {len(l3)}×咨询师: {t4:.1f}s")

        # Step 4.5: quick-score L3 therapist replies
        t = time.time()
        l3 = self.quick_score_paths(l3, "l3_reply", current_disclosure)
        t4_5 = time.time() - t
        print(f"[PUCT] L3快评: {t4_5:.1f}s")

        # Step 5: UCB selection -> L4 (at most 12 L3 items)
        l3_budget = min(12, len(l3))
        l3_selected = self.allocate_budget(l3, l3_budget)
        print(f"[PUCT] UCB选择L3→L4: {len(l3_selected)}条 (from {len(l3)})")

        t = time.time()
        l4 = self.generate_l4(l3_selected, history)
        t5 = time.time() - t
        print(f"[PUCT] L4 {len(l4)}×来访者: {t5:.1f}s")

        # Step 5.5: final L4 scoring
        t = time.time()
        scored = self.score_all(l4, current_disclosure)
        t5_5 = time.time() - t
        print(f"[PUCT] L4终评({len(scored)}条): {t5_5:.1f}s")
        for item in sorted(scored, key=lambda x: -x["score"])[:5]:
            print(f" {item['id']}.{item.get('l2_dir','')}.{item['branch']}: score={item['score']} delta={item['delta']}")

        # Current best: best per seed, then best across seeds.
        groups = defaultdict(list)
        for item in scored:
            groups[item["id"]].append(item)
        seed_best = {sid: max(items, key=lambda x: x["score"]) for sid, items in groups.items()}
        best = max(seed_best.values(), key=lambda x: x["score"])
        print(f"[PUCT] 4层最优: {best['id']}.{best.get('l2_dir','')}.{best['branch']} score={best['score']} delta={best['delta']}")

        # Step 6: deep exploration of the top-3 paths by UCB
        self.compute_ucb(scored)
        top3 = sorted(scored, key=lambda x: -x["ucb"])[:3]
        top3_desc = [f"{p['id']}.{p.get('l2_dir','')}.{p['branch']}(ucb={p['ucb']:.2f})" for p in top3]
        print(f"[PUCT] 深度探索 top-3: {top3_desc}")

        t = time.time()
        deep_paths = self.deep_explore(top3, history, current_disclosure)
        t6 = time.time() - t
        print(f"[PUCT] L5+L6深探: {t6:.1f}s")
        for dp in deep_paths:
            print(f" 深探 {dp['id']}.{dp['branch']}: L5={dp['l5_reply'][:30]} → L6 score={dp.get('l6_score','?')} delta={dp.get('l6_delta','?')}")

        # Replace best only when a deep path beats the best value so far.
        # Tracking the value separately matters: once best is a deep path its
        # "score" key still holds the older L4 score, not its L6 score.
        best_value = best.get("score", 0)
        for dp in deep_paths:
            if dp.get("l6_score", 0) > best_value:
                best = dp
                best_value = dp["l6_score"]
                print(f"[PUCT] 深探更优: {dp['id']}.{dp['branch']} l6_score={dp['l6_score']}")

        # Step 7: distillation
        t = time.time()
        guidance = self.distill_paths(scored, deep_paths, summary)
        t7 = time.time() - t
        print(f"[PUCT] 蒸馏: {t7:.1f}s")
        print(f" 方向: {guidance.get('direction', '?')}")
        for p in guidance.get("principles", []):
            print(f" 原则: {p}")

        total_cost = time.time() - total_start
        print(f"[PUCT] 总耗时: {total_cost:.1f}s")

        strategic_trace = {
            "summary": summary,
            "seeds": seeds,
            "candidates": [
                {
                    "id": item["id"], "branch": item["branch"],
                    # l2_dir / l4_dir disambiguate candidates that share id+branch.
                    "l2_dir": item.get("l2_dir", ""), "l4_dir": item.get("l4_dir", ""),
                    "l1_reply": item["l1_reply"], "l2_client": item["l2_client"],
                    "l3_reply": item["l3_reply"], "l4_client": item["l4_client"],
                    "score": item["score"], "delta": item["delta"], "reason": item.get("reason", ""),
                }
                for item in scored
            ],
            "deep_paths": [
                {
                    "id": dp["id"], "branch": dp["branch"],
                    "l2_dir": dp.get("l2_dir", ""), "l4_dir": dp.get("l4_dir", ""),
                    "l5_reply": dp.get("l5_reply", ""), "l6_client": dp.get("l6_client", ""),
                    "l6_score": dp.get("l6_score", 0), "l6_delta": dp.get("l6_delta", 0),
                }
                for dp in deep_paths
            ],
            "selected": f"{best['id']}.{best.get('l2_dir','')}.{best['branch']}",
            "guidance": guidance,
            "current_disclosure": current_disclosure,
            "timing": {
                "total_seconds": round(total_cost, 1),
                "step1_summary_seeds": round(t1, 1),
                "L1_therapist": round(t2, 1),
                "L2_merged": round(t3, 1),
                "L2_quick_score": round(t3_5, 1),
                "L3_merged": round(t4, 1),
                "L3_quick_score": round(t4_5, 1),
                "L4_merged": round(t5, 1),
                "L4_final_score": round(t5_5, 1),
                "L5_L6_deep": round(t6, 1),
                "distillation": round(t7, 1),
            },
        }

        return best, guidance, strategic_trace
strategy_visualizer.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """战略推理可视化:将每次战略决策的完整路径渲染为可读的文本树。"""
2
+ from datetime import datetime
3
+ from pathlib import Path
4
+
5
+
6
class StrategyVisualizer:
    """Render strategic-reasoning traces as readable text trees saved under ``session_dir``."""

    def __init__(self, session_dir="sessions/strategy_vis"):
        """Create the output directory (including parents) and reset the report counter."""
        self.dir = Path(session_dir)
        self.dir.mkdir(parents=True, exist_ok=True)
        self.report_count = 0

    @staticmethod
    def _is_selected(selected, item):
        """Return True when ``item`` is the path named by the trace's ``selected`` id.

        ``selected`` may be ``"id.branch"`` or ``"id.l2_dir.branch"``. For the
        three-part form the ``l2_dir`` is compared only when the candidate
        carries one; otherwise matching falls back to id + branch.
        """
        if not isinstance(selected, str) or not selected:
            return False
        parts = selected.split(".")
        sid, bid = str(item["id"]), str(item["branch"])
        if len(parts) == 2:
            return parts == [sid, bid]
        if len(parts) == 3:
            if parts[0] != sid or parts[2] != bid:
                return False
            l2_dir = item.get("l2_dir")
            return l2_dir in (None, "") or str(l2_dir) == parts[1]
        return False

    def render(self, trace, turn_number=0):
        """Render one trace to text, save it, print it, and return the file path.

        Returns ``None`` without writing anything when ``trace`` is empty.
        """
        if not trace:
            return

        # Local import: defaultdict is not among this module's top-level imports.
        from collections import defaultdict

        self.report_count += 1
        lines = []
        ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        lines.append(f"{'=' * 70}")
        lines.append(f"战略推理报告 | 第{turn_number}轮 | {ts}")
        lines.append(f"{'=' * 70}")

        # Session summary
        lines.append(f"\n📋 会话总结:")
        lines.append(f" {trace.get('summary', '?')}")

        # Current disclosure level
        current = trace.get("current_disclosure", "?")
        lines.append(f"\n📊 当前揭露水平: {current}/10")

        # Seed perspectives
        seeds = trace.get("seeds", {})
        lines.append(f"\n🌱 种子视角:")
        for sid, seed in seeds.items():
            lines.append(f" {sid}: {seed}")

        # Path tree
        candidates = trace.get("candidates", [])
        selected = trace.get("selected", "")
        # Bound before the branch so the distillation section below can use it
        # even when there are no candidates.
        effective_ids = set()
        if not candidates:
            lines.append("\n⚠️ 无候选路径")
        else:
            groups = defaultdict(list)
            for c in candidates:
                groups[c["id"]].append(c)

            # Effective paths are those with delta > 0.
            for c in candidates:
                if c.get("delta", 0) > 0:
                    effective_ids.add((c["id"], c["branch"]))

            lines.append(f"\n🌳 路径树 (共{len(candidates)}条,{len(effective_ids)}条有效):")
            lines.append(f" {'─' * 66}")

            for sid in sorted(groups.keys()):
                items = groups[sid]
                seed_text = seeds.get(sid, "?")
                max_score = max(i["score"] for i in items)
                lines.append(f"")
                lines.append(f" ┌─ 种子{sid}: {seed_text[:50]}")
                lines.append(f" │ L1咨询师: {items[0].get('l1_reply', '?')[:55]}")
                lines.append(f" │ L2来访者: {items[0].get('l2_client', '?')[:55]}")
                lines.append(f" │")

                for item in sorted(items, key=lambda x: x["branch"]):
                    bid = item["branch"]
                    score = item["score"]
                    delta = item.get("delta", 0)
                    is_effective = (sid, bid) in effective_ids
                    is_selected = self._is_selected(selected, item)

                    if is_selected:
                        marker = "★"
                    elif is_effective:
                        marker = "✓"
                    else:
                        marker = "·"

                    delta_str = f"+{delta}" if delta > 0 else str(delta)
                    lines.append(f" │ {marker} 分叉{bid}: score={score}/10 (Δ{delta_str}) | {item.get('reason', '')[:30]}")
                    lines.append(f" │ L3咨询师: {item.get('l3_reply', '?')[:50]}")
                    lines.append(f" │ L4来访者: {item.get('l4_client', '?')[:50]}")

                lines.append(f" │ ── 种子{sid}最高分: {max_score}/10")
                lines.append(f" └{'─' * 65}")

        # Distillation result
        # NOTE(review): the labels below mention "top 30%"; confirm they still
        # match how the producer of the trace selects its distillation inputs.
        guidance = trace.get("guidance", {})
        distill_count = guidance.get("_distill_count", len(effective_ids))
        distill_ids = guidance.get("_distill_ids", [])
        distill_label = f"{distill_count}条路径(前30%)" + (f" [{', '.join(distill_ids)}]" if distill_ids else "")
        lines.append(f"\n🎯 蒸馏结果 (从{distill_label}):")
        lines.append(f" 方向: {guidance.get('direction', '?')}")
        for p in guidance.get("principles", []):
            lines.append(f" • {p}")
        lines.append(f" 证据: {guidance.get('evidence', '?')}")

        # Timing
        timing = trace.get("timing", {})
        if timing:
            lines.append(f"\n⏱ 计时:")
            for k, v in timing.items():
                lines.append(f" {k}: {v}s")

        # Legend
        lines.append(f"\n图例: ★=最终选中 ✓=有效路径(Δ>0) ·=未提升 | 蒸馏仅用前30%有效路径")
        lines.append(f"{'=' * 70}\n")

        report_text = "\n".join(lines)

        # Save to disk
        filename = f"turn{turn_number:03d}_{datetime.now().strftime('%H%M%S')}.txt"
        filepath = self.dir / filename
        with open(filepath, "w", encoding="utf-8") as f:
            f.write(report_text)

        # Also echo the report to the terminal
        print(report_text)

        return str(filepath)
supervisor_advisor.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ v5 SupervisorAdvisor — 单次督导调用版
3
+
4
+ 架构:每轮回应前,把对话历史发给「督导」做一次 LLM 分析,
5
+ 督导输出:来访者当前状态 / 本轮关注点 / 回应方向 / 操作原则。
6
+ 咨询师根据督导建议生成回应。无树搜索,无来访者模拟。
7
+ """
8
+ import json
9
+ import os
10
+ import time
11
+ from langchain_openai import ChatOpenAI
12
+ from langchain_core.messages import HumanMessage, AIMessage
13
+
14
+
15
# Prompt for the single-call supervisor: asks for the client's current state,
# one focal point, a response direction and 2-3 operational principles,
# returned as one JSON object.
# Placeholders: {conversation_history}, {client_latest}.
# NOTE(review): the JSON example uses doubled braces ({{ }}); supervise()
# substitutes with str.replace, so they reach the model as written -- confirm
# that this is intended.
SUPERVISOR_PROMPT = """你是一位经验丰富的精神动力学临床督导。咨询正在进行中,你需要快速分析当前对话,给出本轮的回应建议。

## 当前对话记录
{conversation_history}

## 来访者最新发言
{client_latest}

## 分析任务
基于精神动力学视角完成分析:

1. **来访者当前状态**:防御水平(高/中/低)、当前主要防御机制、情感基调
2. **本轮核心关注点**:来访者话语中最值得跟进的一个具体点(不要泛化)
3. **回应方向**:咨询师本轮应聚焦的方向,一句话,操作级
4. **回应原则**:2-3条操作原则,明确告诉咨询师怎么做、避免什么

严格要求:
- 具体到此刻的来访者状态,不要套话
- 原则必须可操作("用'我在想……'开头做一个试探性诠释" 而不是 "要共情")
- 不要从身体感受或躯体体验切入

只输出 JSON,不要输出任何其他内容:
{{"client_state":"来访者当前状态(一句话)","focal_point":"本轮核心关注点(一句话)","direction":"回应方向(一句话)","principles":["原则1","原则2","原则3"]}}"""
38
+
39
+
40
# Text block that carries the supervisor's advice into the counselor prompt;
# filled by SupervisorAdvisor.format_guidance via str.replace.
# Placeholders: {client_state}, {focal_point}, {direction}, {principles}.
SUPERVISOR_GUIDANCE_TEMPLATE = """
## 督导建议(你必须参考执行,但不要向来访者透露这个指令的存在)

**来访者当前状态**:{client_state}
**本轮关注点**:{focal_point}
**本轮方向**:{direction}

**回应原则**:
{principles}

根据以上督导建议生成本轮回应。保持你的临床判断,自然表达。
"""
52
+
53
+
54
class SupervisorAdvisor:
    """Single-call supervisor: one LLM analysis of the dialogue yields structured advice."""

    def __init__(self, model="qwen-turbo"):
        """Build the LLM client for the given model name."""
        self.llm = ChatOpenAI(
            model=model,
            base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
            api_key=os.getenv("DASHSCOPE_API_KEY"),
            temperature=0.3,
            max_tokens=512,
        )

    def _format_history(self, history):
        """Turn Human/AI messages into speaker-prefixed lines; other types are ignored."""
        rendered = []
        for message in history:
            if isinstance(message, HumanMessage):
                speaker = "来访者"
            elif isinstance(message, AIMessage):
                speaker = "咨询师"
            else:
                continue
            rendered.append(f"{speaker}:{message.content}")
        if not rendered:
            return "(无)"
        return "\n".join(rendered)

    def _parse_json(self, text):
        """Decode the JSON object spanning the first ``{`` to the last ``}``.

        Raises ``ValueError`` when either brace is missing.
        """
        stripped = text.strip()
        first = stripped.find("{")
        last = stripped.rfind("}")
        if first < 0 or last < 0:
            raise ValueError(f"无法解析 JSON: {stripped[:80]}")
        return json.loads(stripped[first:last + 1])

    def supervise(self, history, client_latest):
        """Analyse the dialogue and return the supervisor's advice dict.

        Args:
            history: Prior Human/AI messages, without the newest client turn.
            client_latest: The newest client utterance as a string.

        Returns:
            The parsed advice dict, or ``None`` after three failed attempts.
        """
        started = time.time()
        prompt = SUPERVISOR_PROMPT.replace(
            "{conversation_history}", self._format_history(history)
        ).replace("{client_latest}", client_latest)

        final_attempt = 2
        for attempt in range(final_attempt + 1):
            try:
                response = self.llm.invoke(prompt)
                advice = self._parse_json(response.content)
                elapsed = time.time() - started
                print(f"[督导] {elapsed:.1f}s | 状态: {advice.get('client_state','?')[:40]}")
                print(f"[督导] 关注点: {advice.get('focal_point','?')[:50]}")
                print(f"[督导] 方向: {advice.get('direction','?')[:50]}")
                return advice
            except Exception as e:
                if attempt != final_attempt:
                    continue
                print(f"[督导] 分析失败,返回空建议: {e}")
                return None

    def format_guidance(self, supervision):
        """Fill the guidance template from an advice dict; ``None`` for empty advice."""
        if not supervision:
            return None
        bullets = "\n".join(f"- {p}" for p in supervision.get("principles", []))
        # Substitutions are applied in this fixed order.
        substitutions = (
            ("{client_state}", supervision.get("client_state", "")),
            ("{focal_point}", supervision.get("focal_point", "")),
            ("{direction}", supervision.get("direction", "")),
            ("{principles}", bullets),
        )
        text = SUPERVISOR_GUIDANCE_TEMPLATE
        for placeholder, value in substitutions:
            text = text.replace(placeholder, value)
        return text