diff --git a/app/gateway/internal/handler/proxy.go b/app/gateway/internal/handler/proxy.go index 3e1546c..4d42752 100644 --- a/app/gateway/internal/handler/proxy.go +++ b/app/gateway/internal/handler/proxy.go @@ -39,6 +39,7 @@ func NewProxyRouter(upstreams []config.Upstream) *ProxyRouter { target := targetURL // 显式捕获循环变量 proxy := &httputil.ReverseProxy{ + FlushInterval: -1, // 立即将数据刷新到客户端,确保 SSE 等流式接口能够实时响应 Rewrite: func(pr *httputil.ProxyRequest) { pr.SetXForwarded() pr.Out.URL.Scheme = target.Scheme diff --git a/app/plant/api/etc/plant-api.yaml b/app/plant/api/etc/plant-api.yaml index f00e684..379f727 100644 --- a/app/plant/api/etc/plant-api.yaml +++ b/app/plant/api/etc/plant-api.yaml @@ -34,17 +34,3 @@ DB: BaiduOcr: ApiKey: hpBfjwy8ifv3qswYGYjUCNKN SecretKey: i5aXZdM4XZVuDroBslL0f3uIuwbAyXFS - -# OpenAI-compatible AI/RAG 配置。未配置 ChatApiUrl/ChatApiKey 时,AI 问答返回明确错误。 -Ai: - ChatApiUrl: - ChatApiKey: - ChatModelName: - EmbeddingApiUrl: - EmbeddingApiKey: - EmbeddingModelName: - QdrantUrl: - QdrantApiKey: - QdrantCollection: - VectorDimension: 0 - DailyQuota: 20 diff --git a/app/plant/api/internal/config/config.go b/app/plant/api/internal/config/config.go index b187fd7..d6838e0 100644 --- a/app/plant/api/internal/config/config.go +++ b/app/plant/api/internal/config/config.go @@ -22,17 +22,4 @@ type Config struct { ApiKey string SecretKey string } `json:",optional"` - Ai struct { - ChatApiUrl string `json:",optional"` - ChatApiKey string `json:",optional"` - ChatModelName string `json:",optional"` - EmbeddingApiUrl string `json:",optional"` - EmbeddingApiKey string `json:",optional"` - EmbeddingModelName string `json:",optional"` - QdrantUrl string `json:",optional"` - QdrantApiKey string `json:",optional"` - QdrantCollection string `json:",optional"` - VectorDimension int `json:",optional"` - DailyQuota int64 `json:",optional"` - } `json:",optional"` } diff --git a/app/plant/api/internal/handler/legacy/legacy.go b/app/plant/api/internal/handler/legacy/legacy.go index fd24f21..f20211b 100644 --- a/app/plant/api/internal/handler/legacy/legacy.go +++ b/app/plant/api/internal/handler/legacy/legacy.go @@ -2,14 +2,20 @@ package legacy import ( "context" + "encoding/base64" "encoding/json" "fmt" + "io" "net/http" + "net/url" + "strings" "time" filePb "sundynix-micro-go/app/file/rpc/file" + aiLogic "sundynix-micro-go/app/plant/api/internal/logic/ai" "sundynix-micro-go/app/plant/api/internal/logic/complete" plantLogic "sundynix-micro-go/app/plant/api/internal/logic/myPlant" + ocrLogic "sundynix-micro-go/app/plant/api/internal/logic/ocr" postLogic "sundynix-micro-go/app/plant/api/internal/logic/post" topicLogic "sundynix-micro-go/app/plant/api/internal/logic/topic" wikiLogic "sundynix-micro-go/app/plant/api/internal/logic/wiki" @@ -892,3 +898,227 @@ func MediaCheckCallbackHandler(svcCtx *svc.ServiceContext) http.HandlerFunc { response.Ok(w) } } + +func AiChatStreamHandler(svcCtx *svc.ServiceContext) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + query := r.URL.Query().Get("query") + if query == "" { + query = r.URL.Query().Get("question") + } + if query == "" { + response.Fail(w, "query 不能为空") + return + } + userId := fmt.Sprintf("%v", r.Context().Value("userId")) + + header := w.Header() + header.Set("Content-Type", "text/event-stream") + header.Set("Cache-Control", "no-cache") + header.Set("Connection", "keep-alive") + header.Set("Transfer-Encoding", "chunked") + w.WriteHeader(http.StatusOK) + + err := aiLogic.StreamChat(r.Context(), svcCtx, userId, query, w) + if err != nil { + _, _ = fmt.Fprintf(w, "data: [ERROR] %v\n\n", err) + if flusher, ok := w.(http.Flusher); ok { + flusher.Flush() + } + } else { + _, _ = fmt.Fprint(w, "data: [DONE]\n\n") + if flusher, ok := w.(http.Flusher); ok { + flusher.Flush() + } + } + } +} + +func getBaiduAccessToken(apiKey, secretKey string) (string, error) { + tokenURL := fmt.Sprintf( + "https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s", + apiKey, secretKey, + ) + resp, err := http.Post(tokenURL, "application/x-www-form-urlencoded", nil) + if err != nil { + return "", err + } + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + var tokenObj struct { + AccessToken string `json:"access_token"` + } + if err = json.Unmarshal(body, &tokenObj); err != nil || tokenObj.AccessToken == "" { + return "", fmt.Errorf("解析百度 token 失败") + } + return tokenObj.AccessToken, nil +} + +func ClassifyPlantHandler(svcCtx *svc.ServiceContext) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + contentType := r.Header.Get("Content-Type") + if !strings.Contains(contentType, "multipart/form-data") { + var req types.OcrReq + if err := httpx.Parse(r, &req); err != nil { + _ = json.NewDecoder(r.Body).Decode(&req) + } + if req.ImageUrl == "" { + response.Fail(w, "接收文件失败: imageUrl 不能为空并且必须使用 multipart/form-data 上传文件") + return + } + l := ocrLogic.NewOcrClassifyLogic(r.Context(), svcCtx) + resp, err := l.OcrClassify(&req) + if err != nil { + response.Fail(w, err.Error()) + } else { + response.OkWithData(w, resp) + } + return + } + + file, _, err := r.FormFile("file") + if err != nil { + response.Fail(w, "接收文件失败: "+err.Error()) + return + } + defer file.Close() + + fileBytes, err := io.ReadAll(file) + if err != nil { + response.Fail(w, "读取文件失败: "+err.Error()) + return + } + + apiKey := svcCtx.Config.BaiduOcr.ApiKey + secretKey := svcCtx.Config.BaiduOcr.SecretKey + if apiKey == "" || secretKey == "" { + response.Fail(w, "百度 OCR 未配置 ApiKey/SecretKey") + return + } + + accessToken, err := getBaiduAccessToken(apiKey, secretKey) + if err != nil { + response.Fail(w, err.Error()) + return + } + + base64Str := base64.StdEncoding.EncodeToString(fileBytes) + escapedBase64 := url.QueryEscape(base64Str) + payload := strings.NewReader("image=" + escapedBase64 + "&baike_num=1") + + apiURL := "https://aip.baidubce.com/rest/2.0/image-classify/v1/plant?access_token=" + accessToken + classifyReq, err := http.NewRequest("POST", apiURL, payload) + if err != nil { + response.Fail(w, "创建请求失败: "+err.Error()) + return + } + classifyReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + client := &http.Client{} + classifyResp, err := client.Do(classifyReq) + if err != nil { + response.Fail(w, "调用百度植物识别接口失败: "+err.Error()) + return + } + defer classifyResp.Body.Close() + + body, err := io.ReadAll(classifyResp.Body) + if err != nil { + response.Fail(w, "读取识别结果失败: "+err.Error()) + return + } + + var baiduResp struct { + LogId uint64 `json:"log_id"` + Result []struct { + Score float64 `json:"score"` + Name string `json:"name"` + BaikeInfo *struct { + BaikeUrl string `json:"baike_url"` + ImageUrl string `json:"image_url"` + Description string `json:"description"` + } `json:"baike_info"` + } `json:"result"` + } + _ = json.Unmarshal(body, &baiduResp) + + if baiduResp.LogId > 0 { + var dbResults plantModel.ResultsArray = make(plantModel.ResultsArray, 0, len(baiduResp.Result)) + for _, item := range baiduResp.Result { + var baikeInfo *plantModel.BaikeInfo + if item.BaikeInfo != nil { + baikeInfo = &plantModel.BaikeInfo{ + BaikeUrl: item.BaikeInfo.BaikeUrl, + ImageUrl: item.BaikeInfo.ImageUrl, + Description: item.BaikeInfo.Description, + } + } + dbResults = append(dbResults, plantModel.ResultItem{ + Score: item.Score, + Name: item.Name, + BaikeInfo: baikeInfo, + }) + } + + userID := fmt.Sprintf("%v", r.Context().Value("userId")) + record := plantModel.ClassifyRecord{ + UserID: userID, + LogID: baiduResp.LogId, + AllResults: dbResults, + } + if errDb := svcCtx.DB.Create(&record).Error; errDb != nil { + fmt.Printf("植物识别记录写入数据库失败: %v\n", errDb) + } + } + + var result interface{} + if err = json.Unmarshal(body, &result); err != nil { + response.Fail(w, "解析识别结果失败: "+err.Error()) + return + } + + response.OkWithData(w, result) + } +} + +// AddCarePlanHandler 兼容旧小程序的批量添加养护计划 +func AddCarePlanHandler(svcCtx *svc.ServiceContext) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + var req struct { + CarePlan []struct { + PlantID string `json:"plantId"` + Name string `json:"name"` + Period int32 `json:"period"` + Icon string `json:"icon"` + TargetAction string `json:"targetAction"` + } `json:"carePlan"` + } + + if err := httpx.Parse(r, &req); err != nil { + response.Fail(w, "解析请求参数失败: "+err.Error()) + return + } + + userId := fmt.Sprintf("%v", r.Context().Value("userId")) + + for _, item := range req.CarePlan { + if item.PlantID == "" { + response.Fail(w, "plantId 不能为空") + return + } + _, err := svcCtx.PlantRpc.AddCarePlan(r.Context(), &plantPb.AddCarePlanReq{ + UserId: userId, + PlantId: item.PlantID, + Name: item.Name, + Icon: item.Icon, + Period: item.Period, + TargetAction: item.TargetAction, + }) + if err != nil { + response.Fail(w, "添加养护计划失败: "+err.Error()) + return + } + } + + response.Ok(w) + } +} diff --git a/app/plant/api/internal/handler/routes.go b/app/plant/api/internal/handler/routes.go index ca54c1e..f9b7648 100644 --- a/app/plant/api/internal/handler/routes.go +++ b/app/plant/api/internal/handler/routes.go @@ -573,7 +573,7 @@ func RegisterHandlers(server *rest.Server, serverCtx *svc.ServiceContext) { {Method: http.MethodPost, Path: "/update", Handler: myPlant.UpdatePlantHandler(serverCtx)}, {Method: http.MethodPost, Path: "/deletePlant", Handler: myPlant.DeletePlantHandler(serverCtx)}, {Method: http.MethodPost, Path: "/deletePlan", Handler: myPlant.DeleteCarePlanHandler(serverCtx)}, - {Method: http.MethodPost, Path: "/plan/add", Handler: myPlant.AddCarePlanHandler(serverCtx)}, + {Method: http.MethodPost, Path: "/plan/add", Handler: legacy.AddCarePlanHandler(serverCtx)}, {Method: http.MethodGet, Path: "/plan/delete", Handler: legacy.DeletePlanHandler(serverCtx)}, {Method: http.MethodGet, Path: "/todayTask", Handler: myPlant.GetTodayTaskListHandler(serverCtx)}, {Method: http.MethodPost, Path: "/completeTask", Handler: complete.CompleteTaskHandler(serverCtx)}, @@ -604,7 +604,6 @@ func RegisterHandlers(server *rest.Server, serverCtx *svc.ServiceContext) { {Method: http.MethodPost, Path: "/topic/page", Handler: topic.GetTopicListHandler(serverCtx)}, {Method: http.MethodGet, Path: "/topic/detail", Handler: legacy.TopicDetailHandler(serverCtx)}, - {Method: http.MethodPost, Path: "/classify/plant", Handler: ocr.OcrClassifyHandler(serverCtx)}, {Method: http.MethodPost, Path: "/classify/myClassifyLog", Handler: ocr.GetMyClassifyLogHandler(serverCtx)}, {Method: http.MethodPost, Path: "/classify/deleteClassifyLog", Handler: ocr.DeleteClassifyLogHandler(serverCtx)}, @@ -626,4 +625,15 @@ func RegisterHandlers(server *rest.Server, serverCtx *svc.ServiceContext) { rest.WithJwt(serverCtx.Config.Auth.AccessSecret), rest.WithPrefix("/api/plant"), ) + + // SSE 流式问答 & 百度 OCR 识别:需要禁用 go-zero 默认超时(TimeoutHandler 会中断长连接) + server.AddRoutes( + []rest.Route{ + {Method: http.MethodGet, Path: "/chat/stream", Handler: legacy.AiChatStreamHandler(serverCtx)}, + {Method: http.MethodPost, Path: "/classify/plant", Handler: legacy.ClassifyPlantHandler(serverCtx)}, + }, + rest.WithJwt(serverCtx.Config.Auth.AccessSecret), + rest.WithPrefix("/api/plant"), + rest.WithTimeout(0), + ) } diff --git a/app/plant/api/internal/logic/ai/openai.go b/app/plant/api/internal/logic/ai/openai.go index dec8e66..7199f24 100644 --- a/app/plant/api/internal/logic/ai/openai.go +++ b/app/plant/api/internal/logic/ai/openai.go @@ -1,131 +1,101 @@ package ai import ( - "bufio" - "bytes" "context" - "encoding/json" "errors" "fmt" "io" "net/http" "strings" + qdrant "github.com/qdrant/go-client/qdrant" + openai "github.com/sashabaranov/go-openai" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/metadata" + "sundynix-micro-go/app/plant/api/internal/svc" + plantModel "sundynix-micro-go/app/plant/model" plantPb "sundynix-micro-go/app/plant/rpc/plant" ) -type chatMessage struct { - Role string `json:"role"` - Content string `json:"content"` +func getActiveAiConfig(svcCtx *svc.ServiceContext) (*plantModel.SysAiConfig, error) { + var cfg plantModel.SysAiConfig + err := svcCtx.DB.Where("is_active = 1").First(&cfg).Error + if err != nil { + return nil, errors.New("AI/RAG 问答服务暂未激活或数据库配置缺失") + } + return &cfg, nil } -type chatRequest struct { - Model string `json:"model,omitempty"` - Messages []chatMessage `json:"messages"` - Stream bool `json:"stream"` -} - -func chatModel(svcCtx *svc.ServiceContext) string { - if svcCtx.Config.Ai.ChatModelName != "" { - return svcCtx.Config.Ai.ChatModelName +func chatModel(dbCfg *plantModel.SysAiConfig) string { + if dbCfg.ChatModelName != "" { + return dbCfg.ChatModelName } return "gpt-4o-mini" } -func requestBody(svcCtx *svc.ServiceContext, question string, stream bool) ([]byte, error) { - systemPrompt := "你是一个专业的植物百科助手。回答规则:基于知识库信息回答,不够则结合通用知识;不要使用 Markdown;用纯文本分段;回答简洁专业、条理清晰。" - if ctxText := retrieveRAGContext(context.Background(), svcCtx, question); ctxText != "" { - systemPrompt += "\n--- 知识库 ---\n" + ctxText + "\n--------------" +func newQdrantConn(cfg *plantModel.SysAiConfig) (*grpc.ClientConn, context.Context, error) { + addr := strings.TrimPrefix(cfg.QdrantUrl, "http://") + addr = strings.TrimPrefix(addr, "https://") + conn, err := grpc.NewClient(addr, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return nil, nil, fmt.Errorf("qdrant grpc dial failed: %w", err) } - return json.Marshal(chatRequest{ - Model: chatModel(svcCtx), - Messages: []chatMessage{ - {Role: "system", Content: systemPrompt}, - {Role: "user", Content: question}, - }, - Stream: stream, - }) + ctx := context.Background() + if cfg.QdrantApiKey != "" { + ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("api-key", cfg.QdrantApiKey)) + } + return conn, ctx, nil } -func retrieveRAGContext(ctx context.Context, svcCtx *svc.ServiceContext, question string) string { - c := svcCtx.Config.Ai - if c.EmbeddingApiUrl == "" || c.EmbeddingApiKey == "" || c.QdrantUrl == "" || c.QdrantCollection == "" { - return "" - } - body, _ := json.Marshal(map[string]interface{}{ - "model": c.EmbeddingModelName, - "input": question, - }) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.EmbeddingApiUrl, bytes.NewReader(body)) - if err != nil { - return "" - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+c.EmbeddingApiKey) - resp, err := http.DefaultClient.Do(req) - if err != nil { - return "" - } - defer resp.Body.Close() - var emb struct { - Data []struct { - Embedding []float32 `json:"embedding"` - } `json:"data"` - } - if resp.StatusCode < 200 || resp.StatusCode >= 300 || json.NewDecoder(resp.Body).Decode(&emb) != nil || len(emb.Data) == 0 { +func retrieveRAGContext(ctx context.Context, svcCtx *svc.ServiceContext, dbCfg *plantModel.SysAiConfig, question string) string { + if dbCfg.EmbeddingApiUrl == "" || dbCfg.EmbeddingApiKey == "" || dbCfg.QdrantUrl == "" || dbCfg.QdrantCollection == "" { return "" } - searchBody, _ := json.Marshal(map[string]interface{}{ - "vector": emb.Data[0].Embedding, - "limit": 3, - "with_payload": true, + config := openai.DefaultConfig(dbCfg.EmbeddingApiKey) + if dbCfg.EmbeddingApiUrl != "" { + config.BaseURL = dbCfg.EmbeddingApiUrl + } + client := openai.NewClientWithConfig(config) + embResp, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequest{ + Input: []string{question}, + Model: openai.EmbeddingModel(dbCfg.EmbeddingModelName), }) - searchReq, err := http.NewRequestWithContext(ctx, http.MethodPost, strings.TrimRight(c.QdrantUrl, "/")+"/collections/"+c.QdrantCollection+"/points/search", bytes.NewReader(searchBody)) if err != nil { return "" } - searchReq.Header.Set("Content-Type", "application/json") - if c.QdrantApiKey != "" { - searchReq.Header.Set("api-key", c.QdrantApiKey) - } - searchResp, err := http.DefaultClient.Do(searchReq) - if err != nil { + + conn, qdCtx, connErr := newQdrantConn(dbCfg) + if connErr != nil { return "" } - defer searchResp.Body.Close() - var parsed struct { - Result []struct { - Payload map[string]interface{} `json:"payload"` - } `json:"result"` - } - if searchResp.StatusCode < 200 || searchResp.StatusCode >= 300 || json.NewDecoder(searchResp.Body).Decode(&parsed) != nil { + defer conn.Close() + + ptsClient := qdrant.NewPointsClient(conn) + searchRes, searchErr := ptsClient.Search(qdCtx, &qdrant.SearchPoints{ + CollectionName: dbCfg.QdrantCollection, + Vector: embResp.Data[0].Embedding, + Limit: 3, + WithPayload: &qdrant.WithPayloadSelector{ + SelectorOptions: &qdrant.WithPayloadSelector_Enable{Enable: true}, + }, + }) + if searchErr != nil { return "" } + var b strings.Builder - for _, item := range parsed.Result { - if text, ok := item.Payload["full_text"].(string); ok && text != "" { - b.WriteString(text) + for _, pt := range searchRes.GetResult() { + if txt, ok := pt.GetPayload()["full_text"]; ok { + b.WriteString(txt.GetStringValue()) b.WriteString("\n") } } return b.String() } -func newChatRequest(ctx context.Context, svcCtx *svc.ServiceContext, body []byte) (*http.Request, error) { - if svcCtx.Config.Ai.ChatApiUrl == "" || svcCtx.Config.Ai.ChatApiKey == "" { - return nil, errors.New("AI/RAG 未配置 ChatApiUrl 或 ChatApiKey") - } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, svcCtx.Config.Ai.ChatApiUrl, bytes.NewReader(body)) - if err != nil { - return nil, err - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+svcCtx.Config.Ai.ChatApiKey) - return req, nil -} - func SaveHistory(ctx context.Context, svcCtx *svc.ServiceContext, userID, question, answer string) { if userID == "" || question == "" || answer == "" { return @@ -136,102 +106,100 @@ func SaveHistory(ctx context.Context, svcCtx *svc.ServiceContext, userID, questi } func ChatCompletion(ctx context.Context, svcCtx *svc.ServiceContext, userID, question string) (string, error) { - if err := ensureQuota(ctx, svcCtx, userID); err != nil { - return "", err - } - body, err := requestBody(svcCtx, question, false) + dbCfg, err := getActiveAiConfig(svcCtx) if err != nil { return "", err } - req, err := newChatRequest(ctx, svcCtx, body) - if err != nil { + if err := ensureQuota(ctx, svcCtx, userID, dbCfg); err != nil { return "", err } - resp, err := http.DefaultClient.Do(req) - if err != nil { - return "", err - } - defer resp.Body.Close() - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - raw, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) - return "", fmt.Errorf("AI 请求失败: %s %s", resp.Status, strings.TrimSpace(string(raw))) - } - var parsed struct { - Choices []struct { - Message struct { - Content string `json:"content"` - } `json:"message"` - } `json:"choices"` + systemPrompt := "你是一个专业的植物百科助手。回答规则:基于知识库信息回答,不够则结合通用知识;不要使用 Markdown;用纯文本分段;回答简洁专业、条理清晰。" + if ctxText := retrieveRAGContext(ctx, svcCtx, dbCfg, question); ctxText != "" { + systemPrompt += "\n--- 知识库 ---\n" + ctxText + "\n--------------" } - if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil { + + config := openai.DefaultConfig(dbCfg.ChatApiKey) + if dbCfg.ChatApiUrl != "" { + config.BaseURL = dbCfg.ChatApiUrl + } + client := openai.NewClientWithConfig(config) + resp, err := client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{ + Model: chatModel(dbCfg), + Messages: []openai.ChatCompletionMessage{ + {Role: openai.ChatMessageRoleSystem, Content: systemPrompt}, + {Role: openai.ChatMessageRoleUser, Content: question}, + }, + }) + if err != nil { return "", err } - if len(parsed.Choices) == 0 || parsed.Choices[0].Message.Content == "" { + if len(resp.Choices) == 0 || resp.Choices[0].Message.Content == "" { return "", errors.New("AI 响应为空") } - answer := parsed.Choices[0].Message.Content + answer := resp.Choices[0].Message.Content SaveHistory(ctx, svcCtx, userID, question, answer) return answer, nil } func StreamChat(ctx context.Context, svcCtx *svc.ServiceContext, userID, question string, w io.Writer) error { - if err := ensureQuota(ctx, svcCtx, userID); err != nil { - return err - } - body, err := requestBody(svcCtx, question, true) + dbCfg, err := getActiveAiConfig(svcCtx) if err != nil { return err } - req, err := newChatRequest(ctx, svcCtx, body) + if err := ensureQuota(ctx, svcCtx, userID, dbCfg); err != nil { + return err + } + + systemPrompt := "你是一个专业的植物百科助手。回答规则:基于知识库信息回答,不够则结合通用知识;不要使用 Markdown;用纯文本分段;回答简洁专业、条理清晰。" + if ctxText := retrieveRAGContext(ctx, svcCtx, dbCfg, question); ctxText != "" { + systemPrompt += "\n--- 知识库 ---\n" + ctxText + "\n--------------" + } + + config := openai.DefaultConfig(dbCfg.ChatApiKey) + if dbCfg.ChatApiUrl != "" { + config.BaseURL = dbCfg.ChatApiUrl + } + client := openai.NewClientWithConfig(config) + stream, err := client.CreateChatCompletionStream(ctx, openai.ChatCompletionRequest{ + Model: chatModel(dbCfg), + Messages: []openai.ChatCompletionMessage{ + {Role: openai.ChatMessageRoleSystem, Content: systemPrompt}, + {Role: openai.ChatMessageRoleUser, Content: question}, + }, + Stream: true, + }) if err != nil { return err } - resp, err := http.DefaultClient.Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - raw, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) - return fmt.Errorf("AI 请求失败: %s %s", resp.Status, strings.TrimSpace(string(raw))) - } + defer stream.Close() var answer strings.Builder - scanner := bufio.NewScanner(resp.Body) - for scanner.Scan() { - line := scanner.Text() - _, _ = fmt.Fprintln(w, line) - if flusher, ok := w.(http.Flusher); ok { - flusher.Flush() + for { + resp, recvErr := stream.Recv() + if errors.Is(recvErr, io.EOF) { + break } + if recvErr != nil { + return recvErr + } + if len(resp.Choices) > 0 { + content := resp.Choices[0].Delta.Content + if content != "" { + _, _ = fmt.Fprintf(w, "data: %s\n\n", content) + if flusher, ok := w.(http.Flusher); ok { + flusher.Flush() + } + answer.WriteString(content) + } + } + } - if !strings.HasPrefix(line, "data: ") { - continue - } - data := strings.TrimSpace(strings.TrimPrefix(line, "data: ")) - if data == "[DONE]" { - continue - } - var chunk struct { - Choices []struct { - Delta struct { - Content string `json:"content"` - } `json:"delta"` - } `json:"choices"` - } - if json.Unmarshal([]byte(data), &chunk) == nil && len(chunk.Choices) > 0 { - answer.WriteString(chunk.Choices[0].Delta.Content) - } - } - if err := scanner.Err(); err != nil { - return err - } SaveHistory(ctx, svcCtx, userID, question, answer.String()) return nil } -func ensureQuota(ctx context.Context, svcCtx *svc.ServiceContext, userID string) error { +func ensureQuota(ctx context.Context, svcCtx *svc.ServiceContext, userID string, dbCfg *plantModel.SysAiConfig) error { if userID == "" { return nil } @@ -239,8 +207,9 @@ func ensureQuota(ctx context.Context, svcCtx *svc.ServiceContext, userID string) if err != nil { return err } - if quota.Limit > 0 && quota.Remaining <= 0 { - return fmt.Errorf("今日问答次数已达上限(%d次),明天再来吧", quota.Limit) + limit := int64(dbCfg.DailyQueryLimit) + if limit > 0 && quota.Remaining <= 0 { + return fmt.Errorf("今日问答次数已达上限(%d次),明天再来吧", limit) } return nil } diff --git a/app/plant/api/internal/logic/myPlant/addGrowthRecordLogic.go b/app/plant/api/internal/logic/myPlant/addGrowthRecordLogic.go index 28b8c57..ae84652 100644 --- a/app/plant/api/internal/logic/myPlant/addGrowthRecordLogic.go +++ b/app/plant/api/internal/logic/myPlant/addGrowthRecordLogic.go @@ -3,10 +3,13 @@ package myPlant import ( "context" "fmt" - "github.com/zeromicro/go-zero/core/logx" + "sundynix-micro-go/app/plant/api/internal/svc" "sundynix-micro-go/app/plant/api/internal/types" - "sundynix-micro-go/app/plant/rpc/plant" + plantModel "sundynix-micro-go/app/plant/model" + + "github.com/zeromicro/go-zero/core/logx" + "gorm.io/gorm" ) type AddGrowthRecordLogic struct { @@ -21,8 +24,36 @@ func NewAddGrowthRecordLogic(ctx context.Context, svcCtx *svc.ServiceContext) *A func (l *AddGrowthRecordLogic) AddGrowthRecord(req *types.GrowthRecordReq) error { userId := fmt.Sprintf("%v", l.ctx.Value("userId")) - _, err := l.svcCtx.PlantRpc.AddGrowthRecord(l.ctx, &plant.AddGrowthRecordReq{ - UserId: userId, PlantId: req.PlantId, Content: req.Content, ImgIds: req.ImgIds, + imgIds := req.ImgIds + if len(imgIds) == 0 && len(req.OssIds) > 0 { + imgIds = req.OssIds + } + + err := l.svcCtx.DB.Transaction(func(tx *gorm.DB) error { + record := plantModel.GrowthRecord{ + UserID: userId, + PlantID: req.PlantId, + Name: req.Name, + Tag: req.Tag, + Desc: req.Desc, + Content: req.Content, + } + if err := tx.Create(&record).Error; err != nil { + return err + } + // 保存图片关联 + if len(imgIds) > 0 { + relations := make([]plantModel.GrowthRecordOss, 0, len(imgIds)) + for _, ossId := range imgIds { + relations = append(relations, plantModel.GrowthRecordOss{ + GrowthRecordID: record.ID, OssID: ossId, + }) + } + if err := tx.Create(&relations).Error; err != nil { + return err + } + } + return nil }) return err } diff --git a/app/plant/api/internal/logic/ocr/getmyclassifyloglogic.go b/app/plant/api/internal/logic/ocr/getmyclassifyloglogic.go index 0f6bc5d..b20ffb5 100644 --- a/app/plant/api/internal/logic/ocr/getmyclassifyloglogic.go +++ b/app/plant/api/internal/logic/ocr/getmyclassifyloglogic.go @@ -3,9 +3,11 @@ package ocr import ( "context" "fmt" + "time" + "github.com/zeromicro/go-zero/core/logx" "sundynix-micro-go/app/plant/api/internal/svc" - plantPb "sundynix-micro-go/app/plant/rpc/plant" + plantModel "sundynix-micro-go/app/plant/model" ) type GetMyClassifyLogLogic struct { @@ -14,11 +16,53 @@ type GetMyClassifyLogLogic struct { svcCtx *svc.ServiceContext } +type ClassifyRecordResp struct { + List []ClassifyRecordInfo `json:"list"` + Total int64 `json:"total"` +} + +type ClassifyRecordInfo struct { + ID string `json:"id"` + UserID string `json:"userId"` + LogID uint64 `json:"logId"` + AllResults plantModel.ResultsArray `json:"allResults"` + CreatedAt string `json:"createdAt"` + CreatedAtStr string `json:"createdAtStr"` +} + func NewGetMyClassifyLogLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetMyClassifyLogLogic { return &GetMyClassifyLogLogic{Logger: logx.WithContext(ctx), ctx: ctx, svcCtx: svcCtx} } -func (l *GetMyClassifyLogLogic) GetMyClassifyLog() (*plantPb.ClassifyLogListResp, error) { +func (l *GetMyClassifyLogLogic) GetMyClassifyLog() (*ClassifyRecordResp, error) { userId := fmt.Sprintf("%v", l.ctx.Value("userId")) - return l.svcCtx.PlantRpc.GetMyClassifyLog(l.ctx, &plantPb.GetProfileReq{UserId: userId}) + + var records []plantModel.ClassifyRecord + var total int64 + + db := l.svcCtx.DB.Model(&plantModel.ClassifyRecord{}).Where("user_id = ?", userId) + if err := db.Count(&total).Error; err != nil { + return nil, err + } + + if err := db.Order("created_at desc").Limit(50).Find(&records).Error; err != nil { + return nil, err + } + + list := make([]ClassifyRecordInfo, 0, len(records)) + for _, og := range records { + list = append(list, ClassifyRecordInfo{ + ID: og.ID, + UserID: og.UserID, + LogID: og.LogID, + AllResults: og.AllResults, + CreatedAt: og.CreatedAt.Format(time.RFC3339), + CreatedAtStr: og.CreatedAt.Format("2006-01-02 15:04:05"), + }) + } + + return &ClassifyRecordResp{ + List: list, + Total: total, + }, nil } diff --git a/app/plant/api/internal/logic/ocr/ocrClassifyLogic.go b/app/plant/api/internal/logic/ocr/ocrClassifyLogic.go index 702a41c..d616aec 100644 --- a/app/plant/api/internal/logic/ocr/ocrClassifyLogic.go +++ b/app/plant/api/internal/logic/ocr/ocrClassifyLogic.go @@ -14,6 +14,7 @@ import ( "sundynix-micro-go/app/plant/api/internal/svc" "sundynix-micro-go/app/plant/api/internal/types" + plantModel "sundynix-micro-go/app/plant/model" "github.com/zeromicro/go-zero/core/logx" ) @@ -83,7 +84,51 @@ func (l *OcrClassifyLogic) OcrClassify(req *types.OcrReq) (interface{}, error) { return nil, fmt.Errorf("读取识别结果失败: %w", err) } - // 3. 直接返回百度原始结果(前端自行解析 result 字段) + // 3. 解析为结构化识别结果并写入 ClassifyRecord 表 + var baiduResp struct { + LogId uint64 `json:"log_id"` + Result []struct { + Score float64 `json:"score"` + Name string `json:"name"` + BaikeInfo *struct { + BaikeUrl string `json:"baike_url"` + ImageUrl string `json:"image_url"` + Description string `json:"description"` + } `json:"baike_info"` + } `json:"result"` + } + _ = json.Unmarshal(body, &baiduResp) + + if baiduResp.LogId > 0 { + var dbResults plantModel.ResultsArray = make(plantModel.ResultsArray, 0, len(baiduResp.Result)) + for _, item := range baiduResp.Result { + var baikeInfo *plantModel.BaikeInfo + if item.BaikeInfo != nil { + baikeInfo = &plantModel.BaikeInfo{ + BaikeUrl: item.BaikeInfo.BaikeUrl, + ImageUrl: item.BaikeInfo.ImageUrl, + Description: item.BaikeInfo.Description, + } + } + dbResults = append(dbResults, plantModel.ResultItem{ + Score: item.Score, + Name: item.Name, + BaikeInfo: baikeInfo, + }) + } + + userID := fmt.Sprintf("%v", l.ctx.Value("userId")) + record := plantModel.ClassifyRecord{ + UserID: userID, + LogID: baiduResp.LogId, + AllResults: dbResults, + } + if errDb := l.svcCtx.DB.Create(&record).Error; errDb != nil { + l.Logger.Errorf("植物识别记录写入数据库失败: %v", errDb) + } + } + + // 4. 直接返回百度原始结果 var result interface{} if err = json.Unmarshal(body, &result); err != nil { return nil, fmt.Errorf("解析识别结果失败: %w", err) diff --git a/app/plant/api/internal/logic/post/createPostLogic.go b/app/plant/api/internal/logic/post/createPostLogic.go index c8b83db..c38530e 100644 --- a/app/plant/api/internal/logic/post/createPostLogic.go +++ b/app/plant/api/internal/logic/post/createPostLogic.go @@ -31,12 +31,16 @@ func NewCreatePostLogic(ctx context.Context, svcCtx *svc.ServiceContext) *Create func (l *CreatePostLogic) CreatePost(req *types.CreatePostReq) error { userId := fmt.Sprintf("%v", l.ctx.Value("userId")) + imgIds := req.ImgIds + if len(imgIds) == 0 && len(req.OssIds) > 0 { + imgIds = req.OssIds + } _, err := l.svcCtx.PlantRpc.CreatePost(l.ctx, &plant.CreatePostReq{ UserId: userId, Title: req.Title, Content: req.Content, Location: req.Location, - ImgIds: req.ImgIds, + ImgIds: imgIds, TopicId: req.TopicId, }) return err diff --git a/app/plant/api/internal/logic/userProfile/getmystarslogic.go b/app/plant/api/internal/logic/userProfile/getmystarslogic.go index 98dec63..42816fb 100644 --- a/app/plant/api/internal/logic/userProfile/getmystarslogic.go +++ b/app/plant/api/internal/logic/userProfile/getmystarslogic.go @@ -3,10 +3,15 @@ package userProfile import ( "context" "fmt" - "github.com/zeromicro/go-zero/core/logx" + "time" + + filePb "sundynix-micro-go/app/file/rpc/file" "sundynix-micro-go/app/plant/api/internal/svc" "sundynix-micro-go/app/plant/api/internal/types" - plantPb "sundynix-micro-go/app/plant/rpc/plant" + plantModel "sundynix-micro-go/app/plant/model" + + "github.com/zeromicro/go-zero/core/logx" + "gorm.io/gorm" ) type GetMyStarsLogic struct { @@ -19,7 +24,348 @@ func NewGetMyStarsLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetMyS return &GetMyStarsLogic{Logger: logx.WithContext(ctx), ctx: ctx, svcCtx: svcCtx} } -func (l *GetMyStarsLogic) GetMyStars(req *types.PageReq) (*plantPb.UserStarListResp, error) { +func (l *GetMyStarsLogic) GetMyStars(req *types.PageReq) (interface{}, error) { userId := fmt.Sprintf("%v", l.ctx.Value("userId")) - return l.svcCtx.PlantRpc.GetMyStars(l.ctx, &plantPb.GetProfileReq{UserId: userId}) + + db := l.svcCtx.DB.Model(&plantModel.UserStar{}).Where("user_id = ?", userId) + if req.Class == 1 { + db = db.Where("type = ?", "wiki") + } else if req.Class == 2 { + db = db.Where("type = ?", "post") + } + + var total int64 + if err := db.Count(&total).Error; err != nil { + return nil, err + } + + pageSize := req.PageSize + if pageSize <= 0 { + pageSize = 10 + } + page := req.Current + if page <= 0 { + page = 1 + } + offset := (page - 1) * pageSize + + var stars []*plantModel.UserStar + if err := db.Limit(pageSize).Offset(offset).Order("created_at desc").Find(&stars).Error; err != nil { + return nil, err + } + + // 提取 TargetID + var wikiIds []string + var postIds []string + for _, s := range stars { + if s.Type == "wiki" { + wikiIds = append(wikiIds, s.TargetID) + } else if s.Type == "post" { + postIds = append(postIds, s.TargetID) + } + } + + // 1. 查询 Wiki 详情 + wikiMap := make(map[string]map[string]interface{}) + if len(wikiIds) > 0 { + var wikis []*plantModel.Wiki + if err := l.svcCtx.DB.Where("id IN ?", wikiIds).Find(&wikis).Error; err == nil { + // 查本地 WikiOss + type rel struct { + WikiID string `gorm:"column:wiki_id"` + OssID string `gorm:"column:oss_id"` + } + var rels []rel + l.svcCtx.DB.Table("sundynix_plant_wiki_oss").Where("wiki_id IN ?", wikiIds).Find(&rels) + wikiOssMap := make(map[string][]string) + var allOssIds []string + for _, r := range rels { + wikiOssMap[r.WikiID] = append(wikiOssMap[r.WikiID], r.OssID) + allOssIds = append(allOssIds, r.OssID) + } + + // 通过 FileRpc 获取图片信息 + fileMap := l.fetchFileMap(allOssIds) + + for _, w := range wikis { + ossIds := wikiOssMap[w.ID] + imgList := l.imgListToList(fileMap, ossIds) + wikiMap[w.ID] = map[string]interface{}{ + "id": w.ID, "name": w.Name, "latinName": w.LatinName, + "aliases": w.Aliases, "genus": w.Genus, "difficulty": w.Difficulty, + "isHot": w.IsHot, "growthHabit": w.GrowthHabit, + "lightIntensity": w.LightIntensity, "classId": w.ClassID, + "createdAt": w.CreatedAt.Format("2006-01-02 15:04:05"), + "hasStar": 1, // 既然在这个列表中,说明一定是被收藏的 + "imgList": imgList, + } + } + } + } + + // 2. 查询 Post 详情 + postMap := make(map[string]map[string]interface{}) + if len(postIds) > 0 { + var posts []*plantModel.Post + if err := l.svcCtx.DB. + Preload("CommentList", func(db *gorm.DB) *gorm.DB { + return db.Order("created_at asc") + }). + Preload("LikeList"). + Where("id IN ?", postIds).Find(&posts).Error; err == nil { + + // 查帖子图片 + postImgMap := l.queryPostImages(postIds) + // 查用户信息 + allUserIds := l.collectPostUserIds(posts) + userMap := l.queryUserMap(allUserIds) + // 点赞收藏状态 + likedSet, starredSet := l.queryLikeStarStatus(userId, postIds) + + for _, p := range posts { + item := map[string]interface{}{ + "id": p.ID, "title": p.Title, "content": p.Content, + "userId": p.UserID, "location": p.Location, + "viewCount": p.ViewCount, "commentCount": p.CommentCount, + "likeCount": p.LikeCount, "starCount": p.StarCount, + "hasReviewed": p.HasReviewed, + "createdAt": p.CreatedAt.Format(time.RFC3339), + "updatedAt": p.UpdatedAt.Format(time.RFC3339), + "createdAtStr": p.CreatedAt.Format("2006-01-02 15:04:05"), + "hasLiked": 0, "hasStar": 0, + "imgList": postImgMap[p.ID], + "publisher": l.buildPublisherInfo(userMap, p.UserID), + "commentList": l.buildCommentList(userMap, p.CommentList), + "likeList": l.buildLikeList(userMap, p.LikeList), + "starList": []map[string]interface{}{}, + } + if likedSet[p.ID] { + item["hasLiked"] = 1 + } + if starredSet[p.ID] { + item["hasStar"] = 1 + } + postMap[p.ID] = item + } + } + } + + // 3. 按照 Stars 排序组装最终 List + var list []map[string]interface{} + for _, s := range stars { + item := map[string]interface{}{ + "id": s.ID, "userId": s.UserID, "targetId": s.TargetID, "type": s.Type, + "createdAt": s.CreatedAt.Format("2006-01-02 15:04:05"), + } + if s.Type == "wiki" { + if w, ok := wikiMap[s.TargetID]; ok { + item["wiki"] = w + list = append(list, item) + } + } else if s.Type == "post" { + if p, ok := postMap[s.TargetID]; ok { + item["post"] = p + list = append(list, item) + } + } + } + + if list == nil { + list = []map[string]interface{}{} + } + + return map[string]interface{}{ + "list": list, + "total": total, + "page": page, + "pageSize": pageSize, + }, nil +} + +func (l *GetMyStarsLogic) fetchFileMap(ids []string) map[string]map[string]interface{} { + result := make(map[string]map[string]interface{}) + if len(ids) == 0 { + return result + } + resp, err := l.svcCtx.FileRpc.GetFilesByIds(l.ctx, &filePb.GetFilesByIdsReq{Ids: ids}) + if err != nil || resp == nil { + return result + } + for _, f := range resp.Files { + result[f.Id] = map[string]interface{}{ + "id": f.Id, "name": f.Name, "url": f.Url, "tag": f.Tag, + "key": f.Key, "suffix": f.Suffix, "md5": f.Md5, + "createdAt": time.Unix(f.CreatedAt, 0).Format(time.RFC3339), + "updatedAt": time.Unix(f.CreatedAt, 0).Format(time.RFC3339), + "createdAtStr": time.Unix(f.CreatedAt, 0).Format("2006-01-02 15:04:05"), + } + } + return result +} + +func (l *GetMyStarsLogic) imgListToList(fileMap map[string]map[string]interface{}, ossIds []string) []map[string]interface{} { + var list []map[string]interface{} + for _, id := range ossIds { + if img, ok := fileMap[id]; ok { + list = append(list, img) + } + } + if list == nil { + list = []map[string]interface{}{} + } + return list +} + +func (l *GetMyStarsLogic) queryPostImages(postIds []string) map[string][]map[string]interface{} { + type rel struct { + PostID string `gorm:"column:post_id"` + OssID string `gorm:"column:oss_id"` + } + var rels []rel + l.svcCtx.DB.Table("sundynix_plant_post_oss").Where("post_id IN ?", postIds).Find(&rels) + + var allOssIds []string + pidMap := make(map[string][]string) + for _, r := range rels { + pidMap[r.PostID] = append(pidMap[r.PostID], r.OssID) + allOssIds = append(allOssIds, r.OssID) + } + + fileInfos := l.fetchFileMap(allOssIds) + + result := make(map[string][]map[string]interface{}) + for pid, ids := range pidMap { + var imgs []map[string]interface{} + for _, oid := range ids { + if info, ok := fileInfos[oid]; ok { + imgs = append(imgs, info) + } + } + if imgs == nil { + imgs = []map[string]interface{}{} + } + result[pid] = imgs + } + for _, pid := range postIds { + if _, ok := result[pid]; !ok { + result[pid] = []map[string]interface{}{} + } + } + return result +} + +func (l *GetMyStarsLogic) collectPostUserIds(posts []*plantModel.Post) []string { + set := make(map[string]bool) + for _, p := range posts { + set[p.UserID] = true + for _, c := range p.CommentList { + set[c.UserID] = true + } + for _, l := range p.LikeList { + set[l.UserID] = true + } + } + var ids []string + for id := range set { + ids = append(ids, id) + } + return ids +} + +func (l *GetMyStarsLogic) queryUserMap(ids []string) map[string]map[string]interface{} { + result := make(map[string]map[string]interface{}) + if len(ids) == 0 { + return result + } + type userRow struct { + ID string `gorm:"column:id"` + NickName string `gorm:"column:nick_name"` + Name string `gorm:"column:name"` + AvatarID string `gorm:"column:avatar_id"` + } + var rows []userRow + l.svcCtx.DB.Table("sundynix_user").Where("id IN ?", ids).Find(&rows) + + var avatarIds []string + for _, row := range rows { + if row.AvatarID != "" { + avatarIds = append(avatarIds, row.AvatarID) + } + } + avatarMap := l.fetchFileMap(avatarIds) + + for _, row := range rows { + avatarData := map[string]interface{}{} + if av, ok := avatarMap[row.AvatarID]; ok { + avatarData = av + } + result[row.ID] = map[string]interface{}{ + "id": row.ID, "nickName": row.NickName, "name": row.Name, + "avatarId": row.AvatarID, "avatar": avatarData, + } + } + return result +} + +func (l *GetMyStarsLogic) queryLikeStarStatus(userId string, postIds []string) (likedSet, starredSet map[string]bool) { + likedSet = make(map[string]bool) + starredSet = make(map[string]bool) + if len(postIds) == 0 { + return + } + type rel struct { + PostID string `gorm:"column:post_id"` + } + var likes []rel + l.svcCtx.DB.Table("sundynix_plant_post_like").Where("post_id IN ? AND user_id = ?", postIds, userId).Find(&likes) + for _, l := range likes { + likedSet[l.PostID] = true + } + var stars []rel + l.svcCtx.DB.Table("sundynix_plant_user_star").Where("target_id IN ? AND user_id = ? AND type = 'post'", postIds, userId).Find(&stars) + for _, s := range stars { + starredSet[s.PostID] = true + } + return +} + +func (l *GetMyStarsLogic) buildPublisherInfo(userMap map[string]map[string]interface{}, userId string) map[string]interface{} { + if u, ok := userMap[userId]; ok { + return u + } + return map[string]interface{}{ + "id": userId, "nickName": "", "name": "", "avatarId": "", "avatar": map[string]interface{}{}, + } +} + +func (l *GetMyStarsLogic) buildCommentList(userMap map[string]map[string]interface{}, comments []*plantModel.PostComment) []map[string]interface{} { + var list []map[string]interface{} + for _, c := range comments { + list = append(list, map[string]interface{}{ + "id": c.ID, "postId": c.PostID, "userId": c.UserID, + "content": c.Content, "parentId": c.ParentID, + "createdAt": c.CreatedAt.Format(time.RFC3339), + "updatedAt": c.UpdatedAt.Format(time.RFC3339), + "createdAtStr": c.CreatedAt.Format("2006-01-02 15:04:05"), + "commentator": l.buildPublisherInfo(userMap, c.UserID), + }) + } + if list == nil { + list = []map[string]interface{}{} + } + return list +} + +func (l *GetMyStarsLogic) buildLikeList(userMap map[string]map[string]interface{}, likes []*plantModel.PostLike) []map[string]interface{} { + var list []map[string]interface{} + for _, like := range likes { + list = append(list, map[string]interface{}{ + "id": like.ID, "postId": like.PostID, "userId": like.UserID, + "liker": l.buildPublisherInfo(userMap, like.UserID), + }) + } + if list == nil { + list = []map[string]interface{}{} + } + return list } diff --git a/app/plant/api/internal/logic/wiki/createWikiClassLogic.go b/app/plant/api/internal/logic/wiki/createWikiClassLogic.go index 1e4d683..b1b0491 100644 --- a/app/plant/api/internal/logic/wiki/createWikiClassLogic.go +++ b/app/plant/api/internal/logic/wiki/createWikiClassLogic.go @@ -29,9 +29,13 @@ func NewCreateWikiClassLogic(ctx context.Context, svcCtx *svc.ServiceContext) *C } func (l *CreateWikiClassLogic) CreateWikiClass(req *types.WikiClassReq) error { + icon := req.Icon + if icon == "" && req.OssId != "" { + icon = req.OssId + } _, err := l.svcCtx.PlantRpc.CreateWikiClass(l.ctx, &plant.CreateWikiClassReq{ Name: req.Name, - Icon: req.Icon, + Icon: icon, }) return err } diff --git a/app/plant/api/internal/logic/wiki/updatewikiclasslogic.go b/app/plant/api/internal/logic/wiki/updatewikiclasslogic.go index 7842a51..8b4dbc8 100644 --- a/app/plant/api/internal/logic/wiki/updatewikiclasslogic.go +++ b/app/plant/api/internal/logic/wiki/updatewikiclasslogic.go @@ -19,8 +19,12 @@ func NewUpdateWikiClassLogic(ctx context.Context, svcCtx *svc.ServiceContext) *U } func (l *UpdateWikiClassLogic) UpdateWikiClass(req *types.UpdateWikiClassReq) error { + icon := req.Icon + if icon == "" && req.OssId != "" { + icon = req.OssId + } _, err := l.svcCtx.PlantRpc.UpdateWikiClass(l.ctx, &plantPb.UpdateWikiClassReq{ - Id: req.Id, Name: req.Name, Icon: req.Icon, + Id: req.Id, Name: req.Name, Icon: icon, }) return err } diff --git a/app/plant/api/internal/types/types.go b/app/plant/api/internal/types/types.go index 0cc8c94..262e20f 100644 --- a/app/plant/api/internal/types/types.go +++ b/app/plant/api/internal/types/types.go @@ -88,6 +88,7 @@ type CreatePostReq struct { Content string `json:"content"` Location string `json:"location,optional"` ImgIds []string `json:"imgIds,optional"` + OssIds []string `json:"ossIds,optional"` // 向下兼容旧版 TopicId string `json:"topicId,optional"` } @@ -180,6 +181,7 @@ type PageReq struct { Current int `json:"current,optional"` PageSize int `json:"pageSize,optional"` Keyword string `json:"keyword,optional"` + Class int `json:"class,optional"` // 分类过滤: 0 全部, 1 百科, 2 动态 } type PlantListReq struct { @@ -289,9 +291,10 @@ type UpdateTopicReq struct { } type UpdateWikiClassReq struct { - Id string `json:"id"` - Name string `json:"name,optional"` - Icon string `json:"icon,optional"` + Id string `json:"id"` + Name string `json:"name,optional"` + Icon string `json:"icon,optional"` + OssId string `json:"ossId,optional"` // 向下兼容旧版 } type UpdateWikiReq struct { @@ -335,8 +338,9 @@ type UpdatePlanReq struct { } type WikiClassReq struct { - Name string `json:"name"` - Icon string `json:"icon,optional"` + Name string `json:"name"` + Icon string `json:"icon,optional"` + OssId string `json:"ossId,optional"` // 向下兼容旧版 } // ========== Banner ========== diff --git a/app/plant/model/plant_model.go b/app/plant/model/plant_model.go index 58828c0..9a64b06 100644 --- a/app/plant/model/plant_model.go +++ b/app/plant/model/plant_model.go @@ -1,6 +1,9 @@ package model import ( + "database/sql/driver" + "encoding/json" + "errors" "sundynix-micro-go/common/model" "time" ) @@ -247,19 +250,6 @@ type Topic struct { func (Topic) TableName() string { return "sundynix_plant_topic" } -// ========== OCR ========== - -// OcrLog OCR识别记录 -type OcrLog struct { - model.BaseModel - UserID string `gorm:"size:50;index;column:user_id" json:"userId"` - ImageUrl string `gorm:"size:500;column:image_url" json:"imageUrl"` - Result string `gorm:"type:text;column:result" json:"result"` - LogID uint64 `gorm:"column:log_id;index" json:"logId"` -} - -func (OcrLog) TableName() string { return "sundynix_plant_ocr_log" } - // ========== 积分商城 ========== // ExchangeItem status: 1=上架 2=下架 @@ -391,3 +381,75 @@ type MediaCheckResult struct { } func (MediaCheckResult) TableName() string { return "sundynix_plant_media_check_result" } + +// ========== AI RAG 配置 ========== + +type SysAiConfig struct { + model.BaseModel + IsActive int `gorm:"column:is_active;type:tinyint;default:0;comment:是否激活(1是0否)" json:"isActive"` + QdrantUrl string `gorm:"column:qdrant_url;type:varchar(255);comment:Qdrant接口地址" json:"qdrantUrl"` + QdrantApiKey string `gorm:"column:qdrant_api_key;type:varchar(255);comment:Qdrant密钥" json:"qdrantApiKey"` + QdrantCollection string `gorm:"column:qdrant_collection;type:varchar(100);comment:Qdrant集合名" json:"qdrantCollection"` + VectorDimension int `gorm:"column:vector_dimension;type:int;comment:向量维度" json:"vectorDimension"` + ChatProvider string `gorm:"column:chat_provider;type:varchar(50);comment:对话模型供应商" json:"chatProvider"` + ChatApiUrl string `gorm:"column:chat_api_url;type:varchar(255);comment:对话模型接口地址" json:"chatApiUrl"` + ChatApiKey string `gorm:"column:chat_api_key;type:varchar(255);comment:对话模型ApiKey" json:"chatApiKey"` + ChatModelName string `gorm:"column:chat_model_name;type:varchar(100);comment:对话模型名称" json:"chatModelName"` + EmbeddingProvider string `gorm:"column:embedding_provider;type:varchar(50);comment:Embedding模型供应商" json:"embeddingProvider"` + EmbeddingApiUrl string `gorm:"column:embedding_api_url;type:varchar(255);comment:Embedding模型接口地址" json:"embeddingApiUrl"` + EmbeddingApiKey string `gorm:"column:embedding_api_key;type:varchar(255);comment:Embedding模型ApiKey" json:"embeddingApiKey"` + EmbeddingModelName string `gorm:"column:embedding_model_name;type:varchar(100);comment:Embedding模型名称" json:"embeddingModelName"` + DailyQueryLimit int `gorm:"column:daily_query_limit;type:int;default:20;comment:每用户每日问答上限(0=不限)" json:"dailyQueryLimit"` +} + +func (SysAiConfig) TableName() string { return "sundynix_plant_ai_config" } + +// ========== 植物识别记录 ========== + +type BaikeInfo struct { + BaikeUrl string `json:"baike_url"` // 百度百科链接 + ImageUrl string `json:"image_url"` // 植物图片链接 + Description string `json:"description"` // 植物百科描述文本 +} + +type ResultItem struct { + Score float64 `json:"score"` // 匹配相似度得分(0-1) + Name string `json:"name"` // 植物名称 + BaikeInfo *BaikeInfo `json:"baike_info"` // 植物百科信息 +} + +type ResultsArray []ResultItem + +// Scan 实现 sql.Scanner 接口:JSON String -> Go Struct (读库) +func (r *ResultsArray) Scan(value interface{}) error { + if value == nil { + *r = make([]ResultItem, 0) + return nil + } + bytes, ok := value.([]byte) + if !ok { + if str, ok := value.(string); ok { + bytes = []byte(str) + } else { + return errors.New("type assertion to []byte/string failed") + } + } + return json.Unmarshal(bytes, r) +} + +// Value 实现 driver.Valuer 接口:Go Struct -> JSON String (存库) +func (r ResultsArray) Value() (driver.Value, error) { + if len(r) == 0 { + return "[]", nil + } + return json.Marshal(r) +} + +type ClassifyRecord struct { + model.BaseModel + UserID string `gorm:"size:50;index;column:user_id" json:"userId"` + LogID uint64 `gorm:"column:log_id;index" json:"logId"` + AllResults ResultsArray `gorm:"type:json;column:all_results" json:"allResults"` +} + +func (ClassifyRecord) TableName() string { return "sundynix_plant_classify_record" } diff --git a/app/plant/rpc/etc/plant.yaml b/app/plant/rpc/etc/plant.yaml index 94c85ba..853eea7 100644 --- a/app/plant/rpc/etc/plant.yaml +++ b/app/plant/rpc/etc/plant.yaml @@ -10,13 +10,3 @@ Etcd: DB: DataSource: root:root@tcp(192.168.100.127:3307)/sundynix_micro_go?charset=utf8mb4&parseTime=True&loc=Local - -Ai: - EmbeddingApiUrl: - EmbeddingApiKey: - EmbeddingModelName: - QdrantUrl: - QdrantApiKey: - QdrantCollection: - VectorDimension: 0 - DailyQuota: 20 diff --git a/app/plant/rpc/internal/config/config.go b/app/plant/rpc/internal/config/config.go index 3a0e1e3..897b703 100755 --- a/app/plant/rpc/internal/config/config.go +++ b/app/plant/rpc/internal/config/config.go @@ -7,14 +7,4 @@ type Config struct { DB struct { DataSource string } - Ai struct { - EmbeddingApiUrl string - EmbeddingApiKey string - EmbeddingModelName string - QdrantUrl string - QdrantApiKey string - QdrantCollection string - VectorDimension int - DailyQuota int64 - } } diff --git a/app/plant/rpc/internal/logic/addCareRecordLogic.go b/app/plant/rpc/internal/logic/addCareRecordLogic.go index 1cf5f9c..85dfeed 100644 --- a/app/plant/rpc/internal/logic/addCareRecordLogic.go +++ b/app/plant/rpc/internal/logic/addCareRecordLogic.go @@ -57,19 +57,29 @@ func (l *AddCareRecordLogic) AddCareRecord(in *plant.AddCareRecordReq) (*plant.C // 4. 生成下一期任务(以今天为基准,+period 天) nextDue := time.Now().Truncate(24*time.Hour).AddDate(0, 0, plan.Period) - nextTask := plantModel.CareTask{ - UserID: in.UserId, - PlantID: in.PlantId, - PlanID: in.PlanId, - Name: plan.Name, - Icon: plan.Icon, - TargetAction: plan.TargetAction, - DueDate: nextDue, - Status: 1, - } - if err := tx.Create(&nextTask).Error; err != nil { + + // 检查是否已经存在该计划的待办任务,避免重复生成相同事项的多个待办任务 + var activeCount int64 + if err := tx.Model(&plantModel.CareTask{}). + Where("plan_id = ? AND status = 1", plan.ID). + Count(&activeCount).Error; err != nil { return err } + if activeCount == 0 { + nextTask := plantModel.CareTask{ + UserID: in.UserId, + PlantID: in.PlantId, + PlanID: in.PlanId, + Name: plan.Name, + Icon: plan.Icon, + TargetAction: plan.TargetAction, + DueDate: nextDue, + Status: 1, + } + if err := tx.Create(&nextTask).Error; err != nil { + return err + } + } // 5. 更新用户 care_count 统计 actionMap := map[string]string{ diff --git a/app/plant/rpc/internal/logic/completeTaskLogic.go b/app/plant/rpc/internal/logic/completeTaskLogic.go index bd0d1db..a720d77 100644 --- a/app/plant/rpc/internal/logic/completeTaskLogic.go +++ b/app/plant/rpc/internal/logic/completeTaskLogic.go @@ -51,14 +51,24 @@ func (l *CompleteTaskLogic) CompleteTask(in *plant.CompleteTaskReq) (*plant.Task today := time.Now() todayZero := time.Date(today.Year(), today.Month(), today.Day(), 0, 0, 0, 0, today.Location()) nextDue := todayZero.AddDate(0, 0, plan.Period) - newTask := plantModel.CareTask{ - UserID: plan.UserID, PlantID: plan.PlantID, PlanID: plan.ID, - Name: plan.Name, Icon: plan.Icon, TargetAction: plan.TargetAction, - DueDate: nextDue, Status: 1, - } - if err := tx.Create(&newTask).Error; err != nil { + + // 检查是否已经存在该计划的待办任务,避免重复生成相同事项的多个待办任务 + var activeCount int64 + if err := tx.Model(&plantModel.CareTask{}). + Where("plan_id = ? AND status = 1", plan.ID). + Count(&activeCount).Error; err != nil { return err } + if activeCount == 0 { + newTask := plantModel.CareTask{ + UserID: plan.UserID, PlantID: plan.PlantID, PlanID: plan.ID, + Name: plan.Name, Icon: plan.Icon, TargetAction: plan.TargetAction, + DueDate: nextDue, Status: 1, + } + if err := tx.Create(&newTask).Error; err != nil { + return err + } + } // 4. 保存养护记录 record := plantModel.CareRecord{ UserID: plan.UserID, PlantID: plan.PlantID, PlanID: plan.ID, diff --git a/app/plant/rpc/internal/logic/deleteClassifyLogLogic.go b/app/plant/rpc/internal/logic/deleteClassifyLogLogic.go index da4dba6..9426e3b 100644 --- a/app/plant/rpc/internal/logic/deleteClassifyLogLogic.go +++ b/app/plant/rpc/internal/logic/deleteClassifyLogLogic.go @@ -19,7 +19,7 @@ func NewDeleteClassifyLogLogic(ctx context.Context, svcCtx *svc.ServiceContext) } func (l *DeleteClassifyLogLogic) DeleteClassifyLog(in *plant.IdsReq) (*plant.CommonResp, error) { - if err := l.svcCtx.DB.Where("id IN ?", in.Ids).Delete(&plantModel.OcrLog{}).Error; err != nil { + if err := l.svcCtx.DB.Where("id IN ?", in.Ids).Delete(&plantModel.ClassifyRecord{}).Error; err != nil { return nil, err } return &plant.CommonResp{Code: 0, Msg: "ok"}, nil diff --git a/app/plant/rpc/internal/logic/deleteWikiVectorLogic.go b/app/plant/rpc/internal/logic/deleteWikiVectorLogic.go index bd1e9bc..2536f97 100644 --- a/app/plant/rpc/internal/logic/deleteWikiVectorLogic.go +++ b/app/plant/rpc/internal/logic/deleteWikiVectorLogic.go @@ -22,10 +22,14 @@ func NewDeleteWikiVectorLogic(ctx context.Context, svcCtx *svc.ServiceContext) * } func (l *DeleteWikiVectorLogic) DeleteWikiVector(in *plant.SyncWikiVectorReq) (*plant.CommonResp, error) { - if l.svcCtx.Config.Ai.QdrantUrl == "" || l.svcCtx.Config.Ai.QdrantCollection == "" { + dbCfg, err := getActiveAiConfig(l.svcCtx.DB) + if err != nil { + return nil, err + } + if dbCfg.QdrantUrl == "" || dbCfg.QdrantCollection == "" { return nil, errors.New("AI/RAG 未配置 QdrantUrl 或 QdrantCollection") } - if err := deleteWikiVector(l.ctx, l.svcCtx.Config, in.WikiId); err != nil { + if err := deleteWikiVector(l.ctx, dbCfg, in.WikiId); err != nil { return nil, err } if err := l.svcCtx.DB.Model(&plantModel.Wiki{}).Where("id = ?", in.WikiId).Update("is_vector_synced", false).Error; err != nil { diff --git a/app/plant/rpc/internal/logic/getAiChatQuotaLogic.go b/app/plant/rpc/internal/logic/getAiChatQuotaLogic.go index ddc5729..4445a7d 100644 --- a/app/plant/rpc/internal/logic/getAiChatQuotaLogic.go +++ b/app/plant/rpc/internal/logic/getAiChatQuotaLogic.go @@ -26,10 +26,12 @@ func (l *GetAiChatQuotaLogic) GetAiChatQuota(in *plant.GetProfileReq) (*plant.Ai l.svcCtx.DB.Model(&plantModel.AiChatHistory{}). Where("user_id = ? AND created_at >= ?", in.UserId, todayStart). Count(&used) - limit := l.svcCtx.Config.Ai.DailyQuota - if limit <= 0 { - limit = 20 + + limit := int64(20) + if dbCfg, err := getActiveAiConfig(l.svcCtx.DB); err == nil && dbCfg.DailyQueryLimit > 0 { + limit = int64(dbCfg.DailyQueryLimit) } + remaining := limit - used if remaining < 0 { remaining = 0 diff --git a/app/plant/rpc/internal/logic/getMyClassifyLogLogic.go b/app/plant/rpc/internal/logic/getMyClassifyLogLogic.go index db5f92f..7cd1996 100644 --- a/app/plant/rpc/internal/logic/getMyClassifyLogLogic.go +++ b/app/plant/rpc/internal/logic/getMyClassifyLogLogic.go @@ -2,6 +2,7 @@ package logic import ( "context" + "encoding/json" "github.com/zeromicro/go-zero/core/logx" plantModel "sundynix-micro-go/app/plant/model" "sundynix-micro-go/app/plant/rpc/internal/svc" @@ -19,14 +20,19 @@ func NewGetMyClassifyLogLogic(ctx context.Context, svcCtx *svc.ServiceContext) * } func (l *GetMyClassifyLogLogic) GetMyClassifyLog(in *plant.GetProfileReq) (*plant.ClassifyLogListResp, error) { - var logs []plantModel.OcrLog - if err := l.svcCtx.DB.Where("user_id = ?", in.UserId).Order("created_at desc").Limit(50).Find(&logs).Error; err != nil { + var records []plantModel.ClassifyRecord + if err := l.svcCtx.DB.Where("user_id = ?", in.UserId).Order("created_at desc").Limit(50).Find(&records).Error; err != nil { return nil, err } - list := make([]*plant.ClassifyLogInfo, 0, len(logs)) - for _, og := range logs { + list := make([]*plant.ClassifyLogInfo, 0, len(records)) + for _, og := range records { + var imgUrl string + if len(og.AllResults) > 0 && og.AllResults[0].BaikeInfo != nil { + imgUrl = og.AllResults[0].BaikeInfo.ImageUrl + } + resultBytes, _ := json.Marshal(og.AllResults) list = append(list, &plant.ClassifyLogInfo{ - Id: og.ID, UserId: og.UserID, ImageUrl: og.ImageUrl, Result: og.Result, + Id: og.ID, UserId: og.UserID, ImageUrl: imgUrl, Result: string(resultBytes), CreatedAt: og.CreatedAt.Format("2006-01-02 15:04:05"), }) } diff --git a/app/plant/rpc/internal/logic/qdrantVector.go b/app/plant/rpc/internal/logic/qdrantVector.go index 516a371..f47a216 100644 --- a/app/plant/rpc/internal/logic/qdrantVector.go +++ b/app/plant/rpc/internal/logic/qdrantVector.go @@ -1,24 +1,24 @@ package logic import ( - "bytes" "context" - "crypto/md5" - "encoding/hex" - "encoding/json" "errors" "fmt" - "io" - "net/http" "strings" + "github.com/google/uuid" + qdrant "github.com/qdrant/go-client/qdrant" + openai "github.com/sashabaranov/go-openai" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/metadata" + "gorm.io/gorm" + plantModel "sundynix-micro-go/app/plant/model" - "sundynix-micro-go/app/plant/rpc/internal/config" ) func wikiVectorID(wikiID string) string { - sum := md5.Sum([]byte("sundynix-plant-wiki:" + wikiID)) - return hex.EncodeToString(sum[:]) + return uuid.NewMD5(uuid.NameSpaceOID, []byte(wikiID)).String() } func buildWikiVectorText(w plantModel.Wiki) string { @@ -30,131 +30,129 @@ func buildWikiVectorText(w plantModel.Wiki) string { w.FloweringShape, w.FlowerDiameter, w.Fruit) } -func embeddingModel(c config.Config) string { - if c.Ai.EmbeddingModelName != "" { - return c.Ai.EmbeddingModelName +func getActiveAiConfig(db *gorm.DB) (*plantModel.SysAiConfig, error) { + var cfg plantModel.SysAiConfig + if err := db.Where("is_active = 1").First(&cfg).Error; err != nil { + return nil, errors.New("数据库未找到已激活的 AI 配置") + } + return &cfg, nil +} + +func embeddingModel(cfg *plantModel.SysAiConfig) string { + if cfg.EmbeddingModelName != "" { + return cfg.EmbeddingModelName } return "text-embedding-3-small" } -func createEmbedding(ctx context.Context, c config.Config, text string) ([]float32, error) { - body, _ := json.Marshal(map[string]interface{}{ - "model": embeddingModel(c), - "input": text, +func createEmbedding(ctx context.Context, cfg *plantModel.SysAiConfig, text string) ([]float32, error) { + config := openai.DefaultConfig(cfg.EmbeddingApiKey) + if cfg.EmbeddingApiUrl != "" { + config.BaseURL = cfg.EmbeddingApiUrl + } + client := openai.NewClientWithConfig(config) + resp, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequest{ + Input: []string{text}, + Model: openai.EmbeddingModel(embeddingModel(cfg)), }) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.Ai.EmbeddingApiUrl, bytes.NewReader(body)) if err != nil { return nil, err } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+c.Ai.EmbeddingApiKey) - resp, err := http.DefaultClient.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - raw, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) - return nil, fmt.Errorf("embedding 请求失败: %s %s", resp.Status, strings.TrimSpace(string(raw))) - } - var parsed struct { - Data []struct { - Embedding []float32 `json:"embedding"` - } `json:"data"` - } - if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil { - return nil, err - } - if len(parsed.Data) == 0 || len(parsed.Data[0].Embedding) == 0 { + if len(resp.Data) == 0 || len(resp.Data[0].Embedding) == 0 { return nil, errors.New("embedding 响应为空") } - return parsed.Data[0].Embedding, nil + return resp.Data[0].Embedding, nil } -func qdrantURL(c config.Config, path string) string { - return strings.TrimRight(c.Ai.QdrantUrl, "/") + path -} - -func doQdrant(ctx context.Context, c config.Config, method, path string, body interface{}) error { - var reader io.Reader - if body != nil { - raw, _ := json.Marshal(body) - reader = bytes.NewReader(raw) +func newQdrantConn(cfg *plantModel.SysAiConfig) (*grpc.ClientConn, context.Context, error) { + addr := strings.TrimPrefix(cfg.QdrantUrl, "http://") + addr = strings.TrimPrefix(addr, "https://") + conn, err := grpc.NewClient(addr, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return nil, nil, fmt.Errorf("qdrant grpc dial failed: %w", err) } - req, err := http.NewRequestWithContext(ctx, method, qdrantURL(c, path), reader) + ctx := context.Background() + if cfg.QdrantApiKey != "" { + ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("api-key", cfg.QdrantApiKey)) + } + return conn, ctx, nil +} + +func ensureQdrantCollection(cfg *plantModel.SysAiConfig, dim int) error { + conn, ctx, err := newQdrantConn(cfg) if err != nil { return err } - req.Header.Set("Content-Type", "application/json") - if c.Ai.QdrantApiKey != "" { - req.Header.Set("api-key", c.Ai.QdrantApiKey) - } - resp, err := http.DefaultClient.Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - raw, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) - return fmt.Errorf("qdrant 请求失败: %s %s", resp.Status, strings.TrimSpace(string(raw))) - } - return nil -} - -func ensureQdrantCollection(ctx context.Context, c config.Config, dim int) error { - getReq, err := http.NewRequestWithContext(ctx, http.MethodGet, qdrantURL(c, "/collections/"+c.Ai.QdrantCollection), nil) - if err != nil { - return err - } - if c.Ai.QdrantApiKey != "" { - getReq.Header.Set("api-key", c.Ai.QdrantApiKey) - } - if resp, err := http.DefaultClient.Do(getReq); err == nil { - _ = resp.Body.Close() - if resp.StatusCode >= 200 && resp.StatusCode < 300 { - return nil - } - } + defer conn.Close() if dim <= 0 { - dim = c.Ai.VectorDimension + dim = cfg.VectorDimension } if dim <= 0 { dim = 1536 } - return doQdrant(ctx, c, http.MethodPut, "/collections/"+c.Ai.QdrantCollection, map[string]interface{}{ - "vectors": map[string]interface{}{ - "size": dim, - "distance": "Cosine", + collClient := qdrant.NewCollectionsClient(conn) + if _, getErr := collClient.Get(ctx, &qdrant.GetCollectionInfoRequest{CollectionName: cfg.QdrantCollection}); getErr == nil { + return nil + } + _, err = collClient.Create(ctx, &qdrant.CreateCollection{ + CollectionName: cfg.QdrantCollection, + VectorsConfig: &qdrant.VectorsConfig{ + Config: &qdrant.VectorsConfig_Params{ + Params: &qdrant.VectorParams{Size: uint64(dim), Distance: qdrant.Distance_Cosine}, + }, }, }) + return err } -func upsertWikiVector(ctx context.Context, c config.Config, w plantModel.Wiki) error { +func upsertWikiVector(ctx context.Context, cfg *plantModel.SysAiConfig, w plantModel.Wiki) error { text := buildWikiVectorText(w) - vector, err := createEmbedding(ctx, c, text) + vector, err := createEmbedding(ctx, cfg, text) if err != nil { return err } - if err := ensureQdrantCollection(ctx, c, len(vector)); err != nil { + if err := ensureQdrantCollection(cfg, len(vector)); err != nil { return err } - return doQdrant(ctx, c, http.MethodPut, "/collections/"+c.Ai.QdrantCollection+"/points?wait=true", map[string]interface{}{ - "points": []map[string]interface{}{ - { - "id": wikiVectorID(w.ID), - "vector": vector, - "payload": map[string]interface{}{ - "wiki_id": w.ID, - "name": w.Name, - "full_text": text, + conn, qdCtx, err := newQdrantConn(cfg) + if err != nil { + return err + } + defer conn.Close() + ptsClient := qdrant.NewPointsClient(conn) + + _, err = ptsClient.Upsert(qdCtx, &qdrant.UpsertPoints{ + CollectionName: cfg.QdrantCollection, + Points: []*qdrant.PointStruct{{ + Id: qdrant.NewID(wikiVectorID(w.ID)), + Vectors: qdrant.NewVectors(vector...), + Payload: map[string]*qdrant.Value{ + "wiki_id": qdrant.NewValueString(w.ID), + "name": qdrant.NewValueString(w.Name), + "full_text": qdrant.NewValueString(text), + }, + }}, + }) + return err +} + +func deleteWikiVector(ctx context.Context, cfg *plantModel.SysAiConfig, wikiID string) error { + conn, qdCtx, err := newQdrantConn(cfg) + if err != nil { + return err + } + defer conn.Close() + ptsClient := qdrant.NewPointsClient(conn) + + _, err = ptsClient.Delete(qdCtx, &qdrant.DeletePoints{ + CollectionName: cfg.QdrantCollection, + Points: &qdrant.PointsSelector{ + PointsSelectorOneOf: &qdrant.PointsSelector_Points{ + Points: &qdrant.PointsIdsList{ + Ids: []*qdrant.PointId{qdrant.NewID(wikiVectorID(wikiID))}, }, }, }, }) -} - -func deleteWikiVector(ctx context.Context, c config.Config, wikiID string) error { - return doQdrant(ctx, c, http.MethodPost, "/collections/"+c.Ai.QdrantCollection+"/points/delete?wait=true", map[string]interface{}{ - "points": []string{wikiVectorID(wikiID)}, - }) + return err } diff --git a/app/plant/rpc/internal/logic/syncAllWikiVectorLogic.go b/app/plant/rpc/internal/logic/syncAllWikiVectorLogic.go index 7254a58..c8bd204 100644 --- a/app/plant/rpc/internal/logic/syncAllWikiVectorLogic.go +++ b/app/plant/rpc/internal/logic/syncAllWikiVectorLogic.go @@ -23,7 +23,11 @@ func NewSyncAllWikiVectorLogic(ctx context.Context, svcCtx *svc.ServiceContext) } func (l *SyncAllWikiVectorLogic) SyncAllWikiVector(in *plant.PageReq) (*plant.CommonResp, error) { - if l.svcCtx.Config.Ai.EmbeddingApiUrl == "" || l.svcCtx.Config.Ai.QdrantUrl == "" || l.svcCtx.Config.Ai.QdrantCollection == "" { + dbCfg, err := getActiveAiConfig(l.svcCtx.DB) + if err != nil { + return nil, err + } + if dbCfg.EmbeddingApiUrl == "" || dbCfg.QdrantUrl == "" || dbCfg.QdrantCollection == "" { return nil, errors.New("AI/RAG 未配置 EmbeddingApiUrl、QdrantUrl 或 QdrantCollection") } var wikis []plantModel.Wiki @@ -32,7 +36,7 @@ func (l *SyncAllWikiVectorLogic) SyncAllWikiVector(in *plant.PageReq) (*plant.Co } success := 0 for _, wiki := range wikis { - if err := upsertWikiVector(l.ctx, l.svcCtx.Config, wiki); err != nil { + if err := upsertWikiVector(l.ctx, dbCfg, wiki); err != nil { l.Logger.Errorf("sync wiki vector failed, wiki_id=%s, err=%v", wiki.ID, err) continue } diff --git a/app/plant/rpc/internal/logic/syncWikiVectorLogic.go b/app/plant/rpc/internal/logic/syncWikiVectorLogic.go index c985635..898518c 100644 --- a/app/plant/rpc/internal/logic/syncWikiVectorLogic.go +++ b/app/plant/rpc/internal/logic/syncWikiVectorLogic.go @@ -25,14 +25,18 @@ func (l *SyncWikiVectorLogic) SyncWikiVector(in *plant.SyncWikiVectorReq) (*plan if in.WikiId == "" { return nil, errors.New("wikiId 不能为空") } - if l.svcCtx.Config.Ai.EmbeddingApiUrl == "" || l.svcCtx.Config.Ai.QdrantUrl == "" || l.svcCtx.Config.Ai.QdrantCollection == "" { + dbCfg, err := getActiveAiConfig(l.svcCtx.DB) + if err != nil { + return nil, err + } + if dbCfg.EmbeddingApiUrl == "" || dbCfg.QdrantUrl == "" || dbCfg.QdrantCollection == "" { return nil, errors.New("AI/RAG 未配置 EmbeddingApiUrl、QdrantUrl 或 QdrantCollection") } var wiki plantModel.Wiki if err := l.svcCtx.DB.Where("id = ?", in.WikiId).First(&wiki).Error; err != nil { return nil, err } - if err := upsertWikiVector(l.ctx, l.svcCtx.Config, wiki); err != nil { + if err := upsertWikiVector(l.ctx, dbCfg, wiki); err != nil { return nil, err } if err := l.svcCtx.DB.Model(&plantModel.Wiki{}).Where("id = ?", in.WikiId).Update("is_vector_synced", true).Error; err != nil { diff --git a/app/plant/rpc/internal/svc/serviceContext.go b/app/plant/rpc/internal/svc/serviceContext.go index fa98489..f72515f 100644 --- a/app/plant/rpc/internal/svc/serviceContext.go +++ b/app/plant/rpc/internal/svc/serviceContext.go @@ -40,7 +40,6 @@ func NewServiceContext(c config.Config) *ServiceContext { &plantModel.PostLike{}, &plantModel.PostOss{}, &plantModel.Topic{}, - &plantModel.OcrLog{}, &plantModel.MediaCheckResult{}, &plantModel.ExchangeItem{}, &plantModel.ExchangeOrder{}, @@ -50,6 +49,8 @@ func NewServiceContext(c config.Config) *ServiceContext { &plantModel.AiChatHistory{}, &plantModel.GrowthRecordOss{}, &plantModel.Banner{}, + &plantModel.SysAiConfig{}, + &plantModel.ClassifyRecord{}, ); err != nil { logx.Errorf("数据库迁移失败: %v", err) } diff --git a/common/model/base_model.go b/common/model/base_model.go index 6650120..919d6a4 100644 --- a/common/model/base_model.go +++ b/common/model/base_model.go @@ -18,7 +18,10 @@ type BaseModel struct { // BeforeCreate 创建前自动生成雪花ID func (m *BaseModel) BeforeCreate(db *gorm.DB) (err error) { - db.Statement.SetColumn("id", uniqueid.GenerateID()) + if m.ID == "" { + m.ID = uniqueid.GenerateID() + } + db.Statement.SetColumn("id", m.ID) return } diff --git a/zero-gateway b/zero-gateway deleted file mode 100755 index ebc74dd..0000000 Binary files a/zero-gateway and /dev/null differ