diff --git a/api/v1/plant/ai_chat_api.go b/api/v1/plant/ai_chat_api.go index 777a022..96e586a 100644 --- a/api/v1/plant/ai_chat_api.go +++ b/api/v1/plant/ai_chat_api.go @@ -1,6 +1,7 @@ package plant import ( + "fmt" "sundynix-go/global" "sundynix-go/model/commom/response" "sundynix-go/utils/auth" @@ -24,7 +25,19 @@ func (a *AiChatApi) ChatStreamPlant(c *gin.Context) { return } - // SSE Headers(微信小程序通过 enableChunked: true 配合实现打字机效果) + userId := auth.GetUserId(c) + + // ── 每日用量检查 ── + cfg, cfgErr := sysAiConfigService.GetActiveAiConfig() + if cfgErr == nil && cfg.DailyQueryLimit > 0 { + todayCount, _ := chatHistoryService.GetTodayCount(userId) + if todayCount >= int64(cfg.DailyQueryLimit) { + response.FailWithMsg(fmt.Sprintf("今日问答次数已达上限(%d次),明天再来吧", cfg.DailyQueryLimit), c) + return + } + } + + // SSE Headers w := c.Writer header := w.Header() header.Set("Content-Type", "text/event-stream") @@ -46,14 +59,12 @@ func (a *AiChatApi) ChatStreamPlant(c *gin.Context) { _, _ = w.WriteString("data: [ERROR]" + err.Error() + "\n\n") w.Flush() } else { - // 流结束标志 _, _ = w.WriteString("data: [DONE]\n\n") w.Flush() } // 异步保存问答历史 if fullAnswer != "" { - userId := auth.GetUserId(c) go func() { if saveErr := chatHistoryService.SaveHistory(userId, query, fullAnswer); saveErr != nil { global.Logger.Error("Save chat history failed", zap.Error(saveErr)) @@ -138,6 +149,34 @@ func (a *AiChatApi) ClearChatHistory(c *gin.Context) { response.OkWithMsg("已清空", c) } +// GetChatQuota 获取当前用户今日剩余问答额度 +// @Tags Plant-AiChat +// @Router /plant/chat/quota [get] +func (a *AiChatApi) GetChatQuota(c *gin.Context) { + userId := auth.GetUserId(c) + todayCount, _ := chatHistoryService.GetTodayCount(userId) + + limit := 0 // 0 = unlimited + cfg, err := sysAiConfigService.GetActiveAiConfig() + if err == nil && cfg.DailyQueryLimit > 0 { + limit = cfg.DailyQueryLimit + } + + remaining := -1 // -1 means unlimited + if limit > 0 { + remaining = limit - int(todayCount) + if remaining < 0 { + remaining = 0 + } + } + + response.OkWithData(map[string]interface{}{ + "used": todayCount, + "limit": limit, + "remaining": remaining, + }, c) +} + func parseInt(s string) int { n := 0 for _, c := range s { diff --git a/api/v1/plant/enter.go b/api/v1/plant/enter.go index 2224a84..28ca710 100644 --- a/api/v1/plant/enter.go +++ b/api/v1/plant/enter.go @@ -31,4 +31,5 @@ var ( exchangeService = service.GroupApp.PlantServiceGroup.ExchangeService aiRagService = service.GroupApp.PlantServiceGroup.AiRagService chatHistoryService = service.GroupApp.PlantServiceGroup.AiChatHistoryService + sysAiConfigService = service.GroupApp.SystemServiceGroup.SysAiConfigService ) diff --git a/model/system/sys_ai_config.go b/model/system/sys_ai_config.go index c27a02b..4fa2ab9 100644 --- a/model/system/sys_ai_config.go +++ b/model/system/sys_ai_config.go @@ -20,4 +20,6 @@ type SysAiConfig struct { EmbeddingApiUrl string `gorm:"column:embedding_api_url;type:varchar(255);comment:Embedding模型接口地址" json:"embeddingApiUrl" form:"embeddingApiUrl"` EmbeddingApiKey string `gorm:"column:embedding_api_key;type:varchar(255);comment:Embedding模型ApiKey" json:"embeddingApiKey" form:"embeddingApiKey"` EmbeddingModelName string `gorm:"column:embedding_model_name;type:varchar(100);comment:Embedding模型名称" json:"embeddingModelName" form:"embeddingModelName"` + // 用量限制 + DailyQueryLimit int `gorm:"column:daily_query_limit;type:int;default:20;comment:每用户每日问答上限(0=不限)" json:"dailyQueryLimit" form:"dailyQueryLimit"` } diff --git a/router/plant/ai_chat_router.go b/router/plant/ai_chat_router.go index 1554c01..dc8647e 100644 --- a/router/plant/ai_chat_router.go +++ b/router/plant/ai_chat_router.go @@ -13,6 +13,7 @@ func (s *AiChatRouter) InitAiChatRouter(Router *gin.RouterGroup) (R gin.IRoutes) aiChatRouter.GET("history", aiChatApi.GetChatHistory) // 问答历史列表 aiChatRouter.POST("history/delete", aiChatApi.DeleteChatHistory) // 删除单条 aiChatRouter.POST("history/clear", aiChatApi.ClearChatHistory) // 清空历史 + aiChatRouter.GET("quota", aiChatApi.GetChatQuota) // 今日剩余额度 } return aiChatRouter } diff --git a/router/system/sys_ai_config_router.go b/router/system/sys_ai_config_router.go index b403165..bc6b7fe 100644 --- a/router/system/sys_ai_config_router.go +++ b/router/system/sys_ai_config_router.go @@ -9,7 +9,7 @@ func (s *SysAiConfigRouter) InitSysAiConfigRouter(Router *gin.RouterGroup) (R gi sysAiConfigApi := sysAiConfigApi { sysAiConfigRouter.POST("create", sysAiConfigApi.CreateAiConfig) // 创建配置 - sysAiConfigRouter.PUT("update", sysAiConfigApi.UpdateAiConfig) // 更新配置 + sysAiConfigRouter.POST("update", sysAiConfigApi.UpdateAiConfig) // 更新配置 sysAiConfigRouter.POST("setActive", sysAiConfigApi.SetActive) // 设置激活状态 sysAiConfigRouter.GET("list", sysAiConfigApi.GetList) // 获取列表 } diff --git a/service/plant/ai_chat_history_service.go b/service/plant/ai_chat_history_service.go index ed1360f..809dac2 100644 --- a/service/plant/ai_chat_history_service.go +++ b/service/plant/ai_chat_history_service.go @@ -46,3 +46,12 @@ func (s *AiChatHistoryService) DeleteHistory(userId, id string) error { func (s *AiChatHistoryService) ClearHistory(userId string) error { return global.DB.Where("user_id = ?", userId).Delete(&plantModel.AiChatHistory{}).Error } + +// GetTodayCount 获取用户今日问答数量 +func (s *AiChatHistoryService) GetTodayCount(userId string) (int64, error) { + var count int64 + err := global.DB.Model(&plantModel.AiChatHistory{}). + Where("user_id = ? AND DATE(created_at) = CURDATE()", userId). + Count(&count).Error + return count, err +}