Files
plates-server/src/ai-coach/ai-coach.service.ts

312 lines
13 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import { Injectable, Logger } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import { OpenAI } from 'openai';
import { Readable } from 'stream';
import { AiMessage, RoleType } from './models/ai-message.model';
import { AiConversation } from './models/ai-conversation.model';
import { PostureAssessment } from './models/posture-assessment.model';
import { UserProfile } from '../users/models/user-profile.model';
import { UsersService } from '../users/users.service';
const SYSTEM_PROMPT = `作为一名资深的普拉提与运动康复教练Pilates Coach我拥有丰富的专业知识包括但不限于运动解剖学、体态评估、疼痛预防、功能性训练、力量与柔韧性训练以及营养与饮食建议。请遵循以下指导原则进行交流 - **话题范围**:讨论将仅限于健康、健身、普拉提、康复、形体训练、柔韧性提升、力量训练、运动损伤预防与恢复、营养与饮食等领域。 - **拒绝回答的内容**:对于医疗诊断、情感心理支持、时政金融分析或编程等非相关或高风险问题,我会礼貌地解释为何这些不在我的专业范围内,并尝试将对话引导回上述合适的话题领域内。 - **语言风格**:我的回复将以亲切且专业的态度呈现,尽量做到条理清晰、分点阐述;当需要时,会提供可以在家轻松实践的具体步骤指南及注意事项;同时考虑到不同水平参与者的需求,特别是那些可能有轻微不适或曾受过伤的人群,我会给出相应的调整建议和安全提示。 - **个性化与安全性**:强调每个人的身体状况都是独一无二的,在提出任何锻炼计划之前都会提醒大家根据自身情况适当调整强度;如果涉及到具体的疼痛问题或是旧伤复发的情况,则强烈建议先咨询医生的意见再开始新的训练项目。 - **设备要求**:所有推荐的练习都假设参与者只有基础的家庭健身器材可用,比如瑜伽垫、弹力带或者泡沫轴等;此外还会对每项活动的大致持续时间和频率做出估计,并分享一些自我监测进步的方法。 请告诉我您具体想了解哪方面的信息,以便我能更好地为您提供帮助。`;
@Injectable()
export class AiCoachService {
private readonly logger = new Logger(AiCoachService.name);
private readonly client: OpenAI;
private readonly model: string;
private readonly visionModel: string;
constructor(private readonly configService: ConfigService, private readonly usersService: UsersService) {
const dashScopeApiKey = this.configService.get<string>('DASHSCOPE_API_KEY') || 'sk-e3ff4494c2f1463a8910d5b3d05d3143';
const baseURL = this.configService.get<string>('DASHSCOPE_BASE_URL') || 'https://dashscope.aliyuncs.com/compatible-mode/v1';
this.client = new OpenAI({
apiKey: dashScopeApiKey,
baseURL,
});
// 默认选择通义千问对话模型OpenAI兼容名可通过环境覆盖
this.model = this.configService.get<string>('DASHSCOPE_MODEL') || 'qwen-flash';
this.visionModel = this.configService.get<string>('DASHSCOPE_VISION_MODEL') || 'qwen-vl-plus';
}
async createOrAppendMessages(params: {
userId: string;
conversationId?: string;
userContent: string;
}): Promise<{ conversationId: string }> {
let conversationId = params.conversationId;
if (!conversationId) {
conversationId = `${params.userId}-${Date.now()}`;
await AiConversation.create({ id: conversationId, userId: params.userId, title: null, lastMessageAt: new Date() });
} else {
await AiConversation.upsert({ id: conversationId, userId: params.userId, lastMessageAt: new Date() });
}
await AiMessage.create({
conversationId,
userId: params.userId,
role: RoleType.User,
content: params.userContent,
metadata: null,
});
return { conversationId };
}
buildChatHistory = async (userId: string, conversationId: string) => {
const history = await AiMessage.findAll({
where: { userId, conversationId },
order: [['created_at', 'ASC']],
});
const messages = [
{ role: 'system' as const, content: SYSTEM_PROMPT },
...history.map((m) => ({ role: m.role as 'user' | 'assistant' | 'system', content: m.content })),
];
return messages;
};
async streamChat(params: {
userId: string;
conversationId: string;
userContent: string;
systemNotice?: string;
}): Promise<Readable> {
// 上下文:系统提示 + 历史 + 当前用户消息
const messages = await this.buildChatHistory(params.userId, params.conversationId);
if (params.systemNotice) {
messages.unshift({ role: 'system', content: params.systemNotice });
}
const stream = await this.client.chat.completions.create({
model: this.model,
messages,
stream: true,
temperature: 0.7,
max_tokens: 1024,
});
const readable = new Readable({ read() { } });
let assistantContent = '';
(async () => {
try {
for await (const chunk of stream) {
const delta = chunk.choices?.[0]?.delta?.content || '';
if (delta) {
assistantContent += delta;
readable.push(delta);
}
}
// 结束将assistant消息入库
await AiMessage.create({
conversationId: params.conversationId,
userId: params.userId,
role: RoleType.Assistant,
content: assistantContent,
metadata: { model: this.model },
});
await AiConversation.update({ lastMessageAt: new Date(), title: this.deriveTitleIfEmpty(assistantContent) }, { where: { id: params.conversationId, userId: params.userId } });
} catch (error) {
this.logger.error(`stream error: ${error?.message || error}`);
readable.push('\n[对话发生错误,请稍后重试]');
} finally {
readable.push(null);
}
})();
return readable;
}
private deriveTitleIfEmpty(assistantReply: string): string | null {
if (!assistantReply) return null;
const firstLine = assistantReply.split(/\r?\n/).find(Boolean) || '';
return firstLine.slice(0, 50) || null;
}
async listConversations(userId: string, params: { page?: number; pageSize?: number }) {
const page = Math.max(1, params.page || 1);
const pageSize = Math.min(50, Math.max(1, params.pageSize || 20));
const offset = (page - 1) * pageSize;
const { rows, count } = await AiConversation.findAndCountAll({
where: { userId },
order: [['last_message_at', 'DESC']],
offset,
limit: pageSize,
});
return {
page,
pageSize,
total: count,
items: rows.map((c) => ({
conversationId: c.id,
title: c.title,
lastMessageAt: c.lastMessageAt,
createdAt: c.createdAt,
})),
};
}
async getConversationDetail(userId: string, conversationId: string) {
const conv = await AiConversation.findOne({ where: { id: conversationId, userId } });
if (!conv) return null;
const messages = await AiMessage.findAll({
where: { userId, conversationId },
order: [['created_at', 'ASC']],
});
return {
conversationId: conv.id,
title: conv.title,
lastMessageAt: conv.lastMessageAt,
createdAt: conv.createdAt,
messages: messages.map((m) => ({ role: m.role, content: m.content, createdAt: m.createdAt })),
};
}
async deleteConversation(userId: string, conversationId: string): Promise<boolean> {
const conv = await AiConversation.findOne({ where: { id: conversationId, userId } });
if (!conv) return false;
await AiMessage.destroy({ where: { userId, conversationId } });
await AiConversation.destroy({ where: { id: conversationId, userId } });
return true;
}
/**
* AI体态评估
* - 汇总用户身高体重
* - 使用视觉模型读取三张图片(正/侧/背)
* - 通过强约束的 JSON Schema 产出结构化结果
* - 存储评估记录并返回
*/
async assessPosture(params: {
userId: string;
frontImageUrl: string;
sideImageUrl: string;
backImageUrl: string;
heightCm?: number;
weightKg?: number;
}) {
// 获取默认身高体重
let heightCm: number | undefined = params.heightCm;
let weightKg: number | undefined = params.weightKg;
if (heightCm == null || weightKg == null) {
const profile = await UserProfile.findOne({ where: { userId: params.userId } });
if (heightCm == null) heightCm = profile?.height ?? undefined;
if (weightKg == null) weightKg = profile?.weight ?? undefined;
}
const schemaInstruction = `请以严格合法的JSON返回体态评估结果键名与类型必须匹配以下Schema不要输出多余文本
{
"overallScore": number(0-5),
"radar": {
"骨盆中立": number(0-5),
"肩带稳": number(0-5),
"胸廓控": number(0-5),
"主排列": number(0-5),
"柔对线": number(0-5),
"核心": number(0-5)
},
"frontView": {
"描述": string,
"问题要点": string[],
"建议动作": string[]
},
"sideView": {
"描述": string,
"问题要点": string[],
"建议动作": string[]
},
"backView": {
"描述": string,
"问题要点": string[],
"建议动作": string[]
}
}`;
const persona = `你是一名资深体态评估与普拉提康复教练。结合用户提供的三张照片(正面/侧面/背面)进行体态评估。严格限制话题在健康、姿势、普拉提与训练建议范围内。用词亲切但专业,强调安全、循序渐进与个体差异。用户资料:身高${heightCm ?? '未知'}cm体重${weightKg ?? '未知'}kg。`;
const completion = await this.client.chat.completions.create({
model: this.visionModel,
messages: [
{ role: 'system', content: persona },
{
role: 'user',
content: [
{ type: 'text', text: schemaInstruction },
{ type: 'text', text: '这三张图分别是正面、侧面、背面:' },
{ type: 'image_url', image_url: { url: params.frontImageUrl } as any },
{ type: 'image_url', image_url: { url: params.sideImageUrl } as any },
{ type: 'image_url', image_url: { url: params.backImageUrl } as any },
] as any,
},
],
temperature: 0,
response_format: { type: 'json_object' } as any,
});
const raw = completion.choices?.[0]?.message?.content || '{}';
let result: any = {};
try { result = JSON.parse(raw); } catch { }
const overallScore = typeof result.overallScore === 'number' ? result.overallScore : null;
const rec = await PostureAssessment.create({
userId: params.userId,
frontImageUrl: params.frontImageUrl,
sideImageUrl: params.sideImageUrl,
backImageUrl: params.backImageUrl,
heightCm: heightCm != null ? heightCm : null,
weightKg: weightKg != null ? weightKg : null,
overallScore,
result,
});
return { id: rec.id, overallScore, result };
}
private isLikelyWeightLogIntent(text: string | undefined): boolean {
if (!text) return false;
const t = text.toLowerCase();
return /体重|称重|秤|kg|公斤|weigh|weight/.test(t);
}
async maybeExtractAndUpdateWeight(userId: string, imageUrl?: string, userText?: string): Promise<{ weightKg?: number }> {
if (!imageUrl || !this.isLikelyWeightLogIntent(userText)) return {};
try {
const sys = '从照片中读取电子秤的数字单位通常为kg。仅返回JSON例如 {"weightKg": 65.2},若无法识别,返回 {"weightKg": null}。不要添加其他文本。';
const completion = await this.client.chat.completions.create({
model: this.visionModel,
messages: [
{ role: 'system', content: sys },
{
role: 'user',
content: [
{ type: 'text', text: '请从图片中提取体重kg。若图中单位为斤或lb请换算为kg。' },
{ type: 'image_url', image_url: { url: imageUrl } as any },
] as any,
},
],
temperature: 0,
response_format: { type: 'json_object' } as any,
});
const raw = completion.choices?.[0]?.message?.content || '';
let weightKg: number | undefined;
try {
const obj = JSON.parse(raw);
weightKg = typeof obj.weightKg === 'number' ? obj.weightKg : undefined;
} catch {
const m = raw.match(/\d+(?:\.\d+)?/);
weightKg = m ? parseFloat(m[0]) : undefined;
}
if (weightKg && isFinite(weightKg) && weightKg > 0 && weightKg < 400) {
await this.usersService.addWeightByVision(userId, weightKg);
return { weightKg };
}
return {};
} catch (err) {
this.logger.error(`maybeExtractAndUpdateWeight error: ${err instanceof Error ? err.message : String(err)}`);
return {};
}
}
}