feat: 更新AI教练控制器,增加用户聊天次数管理功能

- 在AI教练控制器中引入用户聊天次数的检查,确保用户在进行对话前有足够的聊天次数。
- 新增用户服务方法以获取和扣减用户的聊天次数,优化用户体验。
- 调整默认免费聊天次数为5次,提升系统的使用限制管理。
This commit is contained in:
richarjiang
2025-08-18 19:20:01 +08:00
parent ede5730647
commit a56d1d5255
3 changed files with 62 additions and 100 deletions

View File

@@ -1,4 +1,4 @@
import { Body, Controller, Delete, Get, Param, Post, Query, Res, StreamableFile, UseGuards } from '@nestjs/common'; import { Body, Controller, Delete, Get, HttpException, HttpStatus, Logger, Param, Post, Query, Res, StreamableFile, UseGuards } from '@nestjs/common';
import { ApiTags, ApiOperation, ApiBody, ApiQuery, ApiParam } from '@nestjs/swagger'; import { ApiTags, ApiOperation, ApiBody, ApiQuery, ApiParam } from '@nestjs/swagger';
import { Response } from 'express'; import { Response } from 'express';
import { JwtAuthGuard } from '../common/guards/jwt-auth.guard'; import { JwtAuthGuard } from '../common/guards/jwt-auth.guard';
@@ -7,12 +7,17 @@ import { AccessTokenPayload } from '../users/services/apple-auth.service';
import { AiCoachService } from './ai-coach.service'; import { AiCoachService } from './ai-coach.service';
import { AiChatRequestDto, AiChatResponseDto, AiResponseDataDto } from './dto/ai-chat.dto'; import { AiChatRequestDto, AiChatResponseDto, AiResponseDataDto } from './dto/ai-chat.dto';
import { PostureAssessmentRequestDto, PostureAssessmentResponseDto } from './dto/posture-assessment.dto'; import { PostureAssessmentRequestDto, PostureAssessmentResponseDto } from './dto/posture-assessment.dto';
import { UsersService } from '../users/users.service';
@ApiTags('ai-coach') @ApiTags('ai-coach')
@Controller('ai-coach') @Controller('ai-coach')
@UseGuards(JwtAuthGuard) @UseGuards(JwtAuthGuard)
export class AiCoachController { export class AiCoachController {
constructor(private readonly aiCoachService: AiCoachService) { } private readonly logger = new Logger(AiCoachController.name);
constructor(
private readonly aiCoachService: AiCoachService,
private readonly usersService: UsersService,
) { }
@Post('chat') @Post('chat')
@ApiOperation({ summary: '流式大模型对话(普拉提教练)' }) @ApiOperation({ summary: '流式大模型对话(普拉提教练)' })
@@ -26,6 +31,13 @@ export class AiCoachController {
const stream = body.stream !== false; // 默认流式 const stream = body.stream !== false; // 默认流式
const userContent = body.messages?.[body.messages.length - 1]?.content || ''; const userContent = body.messages?.[body.messages.length - 1]?.content || '';
// 判断用户是否有聊天次数
const usageCount = await this.usersService.getUserUsageCount(userId);
if (usageCount <= 0) {
this.logger.error(`chat: ${userId} has no usage count`);
throw new HttpException('用户没有聊天次数', HttpStatus.FORBIDDEN);
}
// 创建或沿用会话ID并保存用户消息 // 创建或沿用会话ID并保存用户消息
const { conversationId } = await this.aiCoachService.createOrAppendMessages({ const { conversationId } = await this.aiCoachService.createOrAppendMessages({
userId, userId,
@@ -45,6 +57,8 @@ export class AiCoachController {
confirmationData: body.confirmationData, confirmationData: body.confirmationData,
}); });
await this.usersService.deductUserUsageCount(userId);
// 检查是否返回结构化数据(如确认选项) // 检查是否返回结构化数据(如确认选项)
// 结构化数据必须使用非流式模式返回 // 结构化数据必须使用非流式模式返回
if (typeof result === 'object' && 'type' in result) { if (typeof result === 'object' && 'type' in result) {

View File

@@ -317,99 +317,4 @@ export class UsersController {
} }
} }
/**
* 获取营养汇总分析
*/
@UseGuards(JwtAuthGuard)
@Get('nutrition-summary')
@HttpCode(HttpStatus.OK)
@ApiOperation({ summary: '获取最近饮食的营养汇总分析' })
@ApiQuery({ name: 'mealCount', required: false, description: '分析最近几顿饮食默认10顿' })
@ApiResponse({ status: 200, description: '成功获取营养汇总', type: DietAnalysisResponseDto })
async getNutritionSummary(
@Query('mealCount') mealCount: string = '10',
@CurrentUser() user: AccessTokenPayload,
): Promise<DietAnalysisResponseDto> {
this.logger.log(`获取营养汇总 - 用户ID: ${user.sub}, 分析${mealCount}顿饮食`);
const count = Math.min(20, Math.max(1, parseInt(mealCount) || 10));
const nutritionSummary = await this.usersService.getRecentNutritionSummary(user.sub, count);
// 获取最近的饮食记录用于分析
const recentRecords = await this.usersService.getDietHistory(user.sub, { limit: count });
// 简单的营养评分算法(可以后续优化)
const nutritionScore = this.calculateNutritionScore(nutritionSummary);
// 生成基础建议后续可以接入AI分析
const recommendations = this.generateBasicRecommendations(nutritionSummary);
return {
nutritionSummary,
recentRecords: recentRecords.records,
healthAnalysis: '基于您最近的饮食记录,我将为您提供个性化的营养分析和健康建议。',
nutritionScore,
recommendations,
};
}
/**
* 简单的营养评分算法
*/
private calculateNutritionScore(summary: any): number {
let score = 50; // 基础分数
// 基于热量是否合理调整分数
const dailyCalories = summary.totalCalories / (summary.recordCount / 3); // 假设一天3餐
if (dailyCalories >= 1500 && dailyCalories <= 2500) score += 20;
else if (dailyCalories < 1200 || dailyCalories > 3000) score -= 20;
// 基于蛋白质摄入调整分数
const dailyProtein = summary.totalProtein / (summary.recordCount / 3);
if (dailyProtein >= 50 && dailyProtein <= 150) score += 15;
else if (dailyProtein < 30) score -= 15;
// 基于膳食纤维调整分数
const dailyFiber = summary.totalFiber / (summary.recordCount / 3);
if (dailyFiber >= 25) score += 15;
else if (dailyFiber < 10) score -= 10;
return Math.max(0, Math.min(100, score));
}
/**
* 生成基础营养建议
*/
private generateBasicRecommendations(summary: any): string[] {
const recommendations: string[] = [];
const dailyCalories = summary.totalCalories / (summary.recordCount / 3);
const dailyProtein = summary.totalProtein / (summary.recordCount / 3);
const dailyFiber = summary.totalFiber / (summary.recordCount / 3);
const dailySodium = summary.totalSodium / (summary.recordCount / 3);
if (dailyCalories < 1200) {
recommendations.push('您的日均热量摄入偏低,建议适当增加营养密度高的食物。');
} else if (dailyCalories > 2500) {
recommendations.push('您的日均热量摄入偏高建议控制portion size或选择低热量食物。');
}
if (dailyProtein < 50) {
recommendations.push('建议增加优质蛋白质摄入,如鸡胸肉、鱼类、豆制品等。');
}
if (dailyFiber < 25) {
recommendations.push('建议增加膳食纤维摄入,多吃蔬菜、水果和全谷物。');
}
if (dailySodium > 2000) {
recommendations.push('钠摄入偏高,建议减少加工食品和调味料的使用。');
}
if (recommendations.length === 0) {
recommendations.push('您的饮食结构相对均衡,继续保持良好的饮食习惯!');
}
return recommendations;
}
} }

View File

@@ -36,7 +36,7 @@ import { ActivityLogsService } from '../activity-logs/activity-logs.service';
import { ActivityActionType, ActivityEntityType } from '../activity-logs/models/activity-log.model'; import { ActivityActionType, ActivityEntityType } from '../activity-logs/models/activity-log.model';
import { CreateDietRecordDto, UpdateDietRecordDto, GetDietHistoryQueryDto, DietRecordResponseDto, DietHistoryResponseDto, NutritionSummaryDto } from './dto/diet-record.dto'; import { CreateDietRecordDto, UpdateDietRecordDto, GetDietHistoryQueryDto, DietRecordResponseDto, DietHistoryResponseDto, NutritionSummaryDto } from './dto/diet-record.dto';
const DEFAULT_FREE_USAGE_COUNT = 10; const DEFAULT_FREE_USAGE_COUNT = 5;
@Injectable() @Injectable()
export class UsersService { export class UsersService {
@@ -126,6 +126,49 @@ export class UsersService {
} }
} }
/**
* @desc 获取用户剩余的聊天次数
*/
async getUserUsageCount(userId: string): Promise<number> {
try {
const user = await this.userModel.findOne({ where: { id: userId } });
if (!user) {
this.logger.log(`getUserUsageCount: ${userId} not found, return 0`);
return 0
}
if (user.isVip) {
// 会员用户无限次
this.logger.log(`getUserUsageCount: ${userId} is vip, return 999`);
return 999
}
this.logger.log(`getUserUsageCount: ${userId} freeUsageCount: ${user.freeUsageCount}`);
return user.freeUsageCount || 0;
} catch (error) {
this.logger.error(`getUserUsageCount error: ${error instanceof Error ? error.message : String(error)}`);
return 0
}
}
// 扣减用户免费次数
async deductUserUsageCount(userId: string, count: number = 1): Promise<void> {
try {
this.logger.log(`deductUserUsageCount: ${userId} deduct ${count} times`);
const user = await this.userModel.findOne({ where: { id: userId } });
if (!user) {
throw new NotFoundException(`ID为${userId}的用户不存在`);
}
user.freeUsageCount -= count;
await user.save();
} catch (error) {
this.logger.error(`deductUserUsageCount error: ${error instanceof Error ? error.message : String(error)}`);
throw error;
}
}
// 更新用户昵称、头像 // 更新用户昵称、头像
async updateUser(updateUserDto: UpdateUserDto): Promise<UpdateUserResponseDto> { async updateUser(updateUserDto: UpdateUserDto): Promise<UpdateUserResponseDto> {
@@ -453,8 +496,8 @@ export class UsersService {
} }
/** /**
* 获取最近N顿饮食的营养汇总 * 获取最近N顿饮食的营养汇总
*/ */
async getRecentNutritionSummary(userId: string, mealCount: number = 10): Promise<NutritionSummaryDto> { async getRecentNutritionSummary(userId: string, mealCount: number = 10): Promise<NutritionSummaryDto> {
const records = await this.userDietHistoryModel.findAll({ const records = await this.userDietHistoryModel.findAll({
where: { userId, deleted: false }, where: { userId, deleted: false },