package middleware import ( "bytes" "context" "io" "net/http" "strings" "time" "sundynix-micro-go/app/system/rpc/system" "sundynix-micro-go/app/system/rpc/systemservice" jwtUtil "sundynix-micro-go/common/utils/jwt" "github.com/zeromicro/go-zero/core/logx" ) // OperationLogMiddleware 操作日志中间件(异步写入 system-rpc) type OperationLogMiddleware struct { systemRpc systemservice.SystemService jwtSecret string logChan chan *system.CreateOperationRecordReq } func NewOperationLogMiddleware(systemRpc systemservice.SystemService, jwtSecret string) *OperationLogMiddleware { m := &OperationLogMiddleware{ systemRpc: systemRpc, jwtSecret: jwtSecret, logChan: make(chan *system.CreateOperationRecordReq, 500), } // 启动异步消费者,避免阻塞请求 go m.consumer() return m } func (m *OperationLogMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // 跳过健康检查和 OPTIONS 预检 if r.URL.Path == "/health" || r.Method == http.MethodOptions { next(w, r) return } startTime := time.Now() // 读取并缓存请求体(限制大小,跳过文件上传) var bodyStr string if r.Body != nil && !strings.HasPrefix(r.Header.Get("Content-Type"), "multipart/form-data") { bodyBytes, _ := io.ReadAll(io.LimitReader(r.Body, 2048)) bodyStr = string(bodyBytes) r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // 恢复以供上游读取 } clientId := r.Header.Get("X-Client-Id") userId := m.extractUserId(r) clientIP := getClientIP(r) // 包装 ResponseWriter 捕获响应状态码和响应体 rw := &responseCapture{ResponseWriter: w, statusCode: http.StatusOK} next(rw, r) latency := time.Since(startTime) record := &system.CreateOperationRecordReq{ ClientId: clientId, Ip: clientIP, Method: r.Method, Path: r.URL.Path, Status: int32(rw.statusCode), Latency: latency.Nanoseconds(), Agent: truncate(r.UserAgent(), 500), ErrorMessage: rw.errorMsg(), Body: truncate(bodyStr, 2000), Resp: truncate(rw.body.String(), 2000), UserId: userId, } // 异步写入,不阻塞响应 select { case m.logChan <- record: default: logx.Error("[zero-gateway] 操作日志缓冲区满,丢弃日志") } } } // consumer 异步消费操作日志并通过 system-rpc 写入数据库 func (m *OperationLogMiddleware) consumer() { for record := range m.logChan { _, err := m.systemRpc.CreateOperationRecord(context.Background(), record) if err != nil { logx.Errorf("[zero-gateway] 写入操作日志失败: %v", err) } } } // extractUserId 从 Authorization 头解析 JWT 获取 userId func (m *OperationLogMiddleware) extractUserId(r *http.Request) string { // 优先从鉴权中间件注入的请求头获取(避免重复解析 JWT) if uid := r.Header.Get("X-User-Id"); uid != "" { return uid } // fallback: 自己解析 authHeader := r.Header.Get("Authorization") if authHeader == "" { return "" } tokenStr := jwtUtil.GetTokenFromHeader(authHeader) if tokenStr == "" { return "" } j := jwtUtil.NewJWT(m.jwtSecret) claims, err := j.ParseToken(tokenStr) if err != nil { return "" } return claims.BaseClaims.ID } // responseCapture 捕获响应状态码和响应体 type responseCapture struct { http.ResponseWriter statusCode int body bytes.Buffer } func (rc *responseCapture) WriteHeader(code int) { rc.statusCode = code rc.ResponseWriter.WriteHeader(code) } func (rc *responseCapture) Write(b []byte) (int, error) { if rc.body.Len() < 2048 { rc.body.Write(b) } return rc.ResponseWriter.Write(b) } func (rc *responseCapture) errorMsg() string { if rc.statusCode >= 400 { return rc.body.String() } return "" } // getClientIP 获取真实客户端 IP func getClientIP(r *http.Request) string { if ip := r.Header.Get("X-Forwarded-For"); ip != "" { return strings.Split(ip, ",")[0] } if ip := r.Header.Get("X-Real-Ip"); ip != "" { return ip } addr := r.RemoteAddr if idx := strings.LastIndex(addr, ":"); idx != -1 { return addr[:idx] } return addr } // truncate 截断字符串,防止写入过长内容 func truncate(s string, maxLen int) string { if len(s) > maxLen { return s[:maxLen] + "..." } return s }