package middleware import ( "bytes" "context" "io" "net/http" "strings" "time" "sundynix-micro-go/app/system/rpc/system" "sundynix-micro-go/app/system/rpc/systemservice" jwtUtil "sundynix-micro-go/common/utils/jwt" "github.com/zeromicro/go-zero/core/logx" ) // OperationLogMiddleware 操作日志中间件 type OperationLogMiddleware struct { systemRpc systemservice.SystemService jwtSecret string logChan chan *system.CreateOperationRecordReq } // NewOperationLogMiddleware 创建操作日志中间件 func NewOperationLogMiddleware(systemRpc systemservice.SystemService, jwtSecret string) *OperationLogMiddleware { m := &OperationLogMiddleware{ systemRpc: systemRpc, jwtSecret: jwtSecret, logChan: make(chan *system.CreateOperationRecordReq, 500), } // 启动异步消费者,避免阻塞请求 go m.consumer() return m } // Handle 中间件处理函数 func (m *OperationLogMiddleware) Handle(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // 跳过健康检查、OPTIONS 预检 if r.URL.Path == "/health" || r.Method == http.MethodOptions { next.ServeHTTP(w, r) return } startTime := time.Now() // 读取请求体(限制大小,跳过文件上传) var bodyStr string if r.Body != nil && !strings.HasPrefix(r.Header.Get("Content-Type"), "multipart/form-data") { bodyBytes, _ := io.ReadAll(io.LimitReader(r.Body, 2048)) bodyStr = string(bodyBytes) r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) } clientId := r.Header.Get("X-Client-Id") userId := m.extractUserId(r) clientIP := getClientIP(r) // 包装 ResponseWriter 捕获响应 rw := &responseCapture{ResponseWriter: w, statusCode: http.StatusOK} // 执行后续 handler next.ServeHTTP(rw, r) latency := time.Since(startTime) // 构建日志请求 record := &system.CreateOperationRecordReq{ ClientId: clientId, Ip: clientIP, Method: r.Method, Path: r.URL.Path, Status: int32(rw.statusCode), Latency: latency.Nanoseconds(), Agent: truncate(r.UserAgent(), 500), ErrorMessage: rw.errorMsg(), Body: truncate(bodyStr, 2000), Resp: truncate(rw.body.String(), 2000), UserId: userId, } // 异步发送,不阻塞响应 select { case m.logChan <- record: default: logx.Error("操作日志缓冲区满,丢弃日志") } }) } // consumer 异步消费日志并通过 system-rpc 写入 func (m *OperationLogMiddleware) consumer() { for record := range m.logChan { _, err := m.systemRpc.CreateOperationRecord(context.Background(), record) if err != nil { logx.Errorf("写入操作日志失败: %v", err) } } } // extractUserId 从请求中获取 userId // 优先从鉴权中间件注入的 X-User-Id 头获取(避免重复解析 JWT) func (m *OperationLogMiddleware) extractUserId(r *http.Request) string { if uid := r.Header.Get("X-User-Id"); uid != "" { return uid } // fallback: 自己解析 JWT authHeader := r.Header.Get("Authorization") if authHeader == "" { return "" } tokenStr := jwtUtil.GetTokenFromHeader(authHeader) if tokenStr == "" { return "" } j := jwtUtil.NewJWT(m.jwtSecret) claims, err := j.ParseToken(tokenStr) if err != nil { return "" } return claims.BaseClaims.ID } // responseCapture 捕获响应状态码和响应体 type responseCapture struct { http.ResponseWriter statusCode int body bytes.Buffer } func (rc *responseCapture) WriteHeader(code int) { rc.statusCode = code rc.ResponseWriter.WriteHeader(code) } func (rc *responseCapture) Write(b []byte) (int, error) { if rc.body.Len() < 2048 { rc.body.Write(b) } return rc.ResponseWriter.Write(b) } func (rc *responseCapture) errorMsg() string { if rc.statusCode >= 400 { return rc.body.String() } return "" } // getClientIP 获取真实客户端 IP func getClientIP(r *http.Request) string { if ip := r.Header.Get("X-Forwarded-For"); ip != "" { return strings.Split(ip, ",")[0] } if ip := r.Header.Get("X-Real-Ip"); ip != "" { return ip } addr := r.RemoteAddr if idx := strings.LastIndex(addr, ":"); idx != -1 { return addr[:idx] } return addr } // truncate 截断字符串 func truncate(s string, maxLen int) string { if len(s) > maxLen { return s[:maxLen] + "..." } return s }