feat: rbac迁移完成,并已部署至dev服务器
This commit is contained in:
@@ -0,0 +1,142 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
jwtUtil "sundynix-micro-go/common/utils/jwt"
|
||||
|
||||
jwtv5 "github.com/golang-jwt/jwt/v5"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
)
|
||||
|
||||
// RefreshTokenHeader 续期后新 Token 放在此响应头里,前端读取后静默替换
|
||||
const RefreshTokenHeader = "X-Refresh-Token"
|
||||
|
||||
// AuthMiddleware 网关鉴权 + 自动续期中间件
|
||||
type AuthMiddleware struct {
|
||||
jwtSecret string
|
||||
whitelist map[string]bool
|
||||
}
|
||||
|
||||
func NewAuthMiddleware(jwtSecret string, whitelist []string) *AuthMiddleware {
|
||||
wl := make(map[string]bool, len(whitelist))
|
||||
for _, p := range whitelist {
|
||||
wl[p] = true
|
||||
}
|
||||
return &AuthMiddleware{
|
||||
jwtSecret: jwtSecret,
|
||||
whitelist: wl,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *AuthMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// OPTIONS 预检直接放行
|
||||
if r.Method == http.MethodOptions {
|
||||
next(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// 白名单路径放行(支持精确匹配和 /* 前缀通配)
|
||||
if m.isWhitelisted(r.URL.Path) {
|
||||
next(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// 解析 Authorization 头
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
writeUnauthorized(w, "缺少 Authorization 请求头")
|
||||
return
|
||||
}
|
||||
tokenStr := jwtUtil.GetTokenFromHeader(authHeader)
|
||||
if tokenStr == "" {
|
||||
writeUnauthorized(w, "Token 格式错误")
|
||||
return
|
||||
}
|
||||
|
||||
j := jwtUtil.NewJWT(m.jwtSecret)
|
||||
claims, err := j.ParseToken(tokenStr)
|
||||
if err != nil {
|
||||
logx.Infof("[zero-gateway] JWT 解析失败: %v, path: %s", err, r.URL.Path)
|
||||
writeUnauthorized(w, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 将用户信息透传到上游,避免上游重复解析 JWT
|
||||
r.Header.Set("X-User-Id", claims.BaseClaims.ID)
|
||||
r.Header.Set("X-User-Account", claims.BaseClaims.Account)
|
||||
|
||||
// ---- 滑动窗口续期 ----
|
||||
// 剩余有效时间 < BufferTime(存储在 token claims 里),说明进入缓冲窗口
|
||||
if newToken, ok := m.tryRefresh(j, claims); ok {
|
||||
// 在响应头写入新 Token,前端收到后静默替换本地存储的 Token
|
||||
w.Header().Set(RefreshTokenHeader, newToken)
|
||||
logx.Infof("[zero-gateway] Token 已续期, userId: %s", claims.BaseClaims.ID)
|
||||
}
|
||||
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// tryRefresh 判断是否需要续期,需要则签发新 Token 并返回
|
||||
// 续期规则:剩余有效时间 < BufferTime → 以原始有效时长(ExpiresAt - NotBefore)重新签发
|
||||
func (m *AuthMiddleware) tryRefresh(j *jwtUtil.JWT, claims *jwtUtil.CustomClaims) (string, bool) {
|
||||
bufferTime := time.Duration(claims.BufferTime) * time.Second
|
||||
expiresAt := claims.RegisteredClaims.ExpiresAt.Time
|
||||
remaining := time.Until(expiresAt)
|
||||
|
||||
// 未进入缓冲窗口,无需续期
|
||||
if remaining >= bufferTime {
|
||||
return "", false
|
||||
}
|
||||
|
||||
// 计算原始有效时长:ExpiresAt - NotBefore ≈ 当初登录时配置的 activeTimeout
|
||||
notBefore := claims.RegisteredClaims.NotBefore.Time
|
||||
originalDuration := expiresAt.Sub(notBefore)
|
||||
|
||||
// 构建新 Claims,保持 BaseClaims 和 BufferTime 不变,重新计算有效期
|
||||
newClaims := jwtUtil.CustomClaims{
|
||||
BaseClaims: claims.BaseClaims,
|
||||
BufferTime: claims.BufferTime,
|
||||
RegisteredClaims: jwtv5.RegisteredClaims{
|
||||
Audience: claims.RegisteredClaims.Audience,
|
||||
Issuer: claims.RegisteredClaims.Issuer,
|
||||
NotBefore: jwtv5.NewNumericDate(time.Now()),
|
||||
ExpiresAt: jwtv5.NewNumericDate(time.Now().Add(originalDuration)),
|
||||
},
|
||||
}
|
||||
|
||||
newToken, err := j.CreateToken(newClaims)
|
||||
if err != nil {
|
||||
logx.Errorf("[zero-gateway] Token 续期失败: %v", err)
|
||||
return "", false
|
||||
}
|
||||
return newToken, true
|
||||
}
|
||||
|
||||
// isWhitelisted 支持精确匹配和 /* 前缀通配
|
||||
func (m *AuthMiddleware) isWhitelisted(path string) bool {
|
||||
if m.whitelist[path] {
|
||||
return true
|
||||
}
|
||||
for p := range m.whitelist {
|
||||
if strings.HasSuffix(p, "/*") && strings.HasPrefix(path, strings.TrimSuffix(p, "*")) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// writeUnauthorized 返回统一的 401 响应
|
||||
func writeUnauthorized(w http.ResponseWriter, msg string) {
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"code": 401,
|
||||
"msg": msg,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// CorsMiddleware 跨域中间件
|
||||
type CorsMiddleware struct {
|
||||
allowOrigins []string
|
||||
allowMethods []string
|
||||
allowHeaders []string
|
||||
}
|
||||
|
||||
func NewCorsMiddleware(origins, methods, headers []string) *CorsMiddleware {
|
||||
return &CorsMiddleware{
|
||||
allowOrigins: origins,
|
||||
allowMethods: methods,
|
||||
allowHeaders: headers,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *CorsMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
origin := r.Header.Get("Origin")
|
||||
|
||||
allowed := false
|
||||
for _, o := range m.allowOrigins {
|
||||
if o == "*" || o == origin {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if allowed {
|
||||
allowOrigin := origin
|
||||
if len(m.allowOrigins) == 1 && m.allowOrigins[0] == "*" {
|
||||
allowOrigin = "*"
|
||||
}
|
||||
w.Header().Set("Access-Control-Allow-Origin", allowOrigin)
|
||||
w.Header().Set("Access-Control-Allow-Methods", strings.Join(m.allowMethods, ", "))
|
||||
w.Header().Set("Access-Control-Allow-Headers", strings.Join(m.allowHeaders, ", "))
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
w.Header().Set("Access-Control-Max-Age", "3600")
|
||||
// 允许前端 JS 读取自定义响应头(默认跨域只能读 6 个安全头)
|
||||
w.Header().Set("Access-Control-Expose-Headers", "X-Refresh-Token")
|
||||
}
|
||||
|
||||
// 预检请求直接返回 204
|
||||
if r.Method == http.MethodOptions {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,168 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"sundynix-micro-go/app/system/rpc/system"
|
||||
"sundynix-micro-go/app/system/rpc/systemservice"
|
||||
jwtUtil "sundynix-micro-go/common/utils/jwt"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
)
|
||||
|
||||
// OperationLogMiddleware 操作日志中间件(异步写入 system-rpc)
|
||||
type OperationLogMiddleware struct {
|
||||
systemRpc systemservice.SystemService
|
||||
jwtSecret string
|
||||
logChan chan *system.CreateOperationRecordReq
|
||||
}
|
||||
|
||||
func NewOperationLogMiddleware(systemRpc systemservice.SystemService, jwtSecret string) *OperationLogMiddleware {
|
||||
m := &OperationLogMiddleware{
|
||||
systemRpc: systemRpc,
|
||||
jwtSecret: jwtSecret,
|
||||
logChan: make(chan *system.CreateOperationRecordReq, 500),
|
||||
}
|
||||
// 启动异步消费者,避免阻塞请求
|
||||
go m.consumer()
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *OperationLogMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// 跳过健康检查和 OPTIONS 预检
|
||||
if r.URL.Path == "/health" || r.Method == http.MethodOptions {
|
||||
next(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
// 读取并缓存请求体(限制大小,跳过文件上传)
|
||||
var bodyStr string
|
||||
if r.Body != nil && !strings.HasPrefix(r.Header.Get("Content-Type"), "multipart/form-data") {
|
||||
bodyBytes, _ := io.ReadAll(io.LimitReader(r.Body, 2048))
|
||||
bodyStr = string(bodyBytes)
|
||||
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // 恢复以供上游读取
|
||||
}
|
||||
|
||||
clientId := r.Header.Get("X-Client-Id")
|
||||
userId := m.extractUserId(r)
|
||||
clientIP := getClientIP(r)
|
||||
|
||||
// 包装 ResponseWriter 捕获响应状态码和响应体
|
||||
rw := &responseCapture{ResponseWriter: w, statusCode: http.StatusOK}
|
||||
|
||||
next(rw, r)
|
||||
|
||||
latency := time.Since(startTime)
|
||||
|
||||
record := &system.CreateOperationRecordReq{
|
||||
ClientId: clientId,
|
||||
Ip: clientIP,
|
||||
Method: r.Method,
|
||||
Path: r.URL.Path,
|
||||
Status: int32(rw.statusCode),
|
||||
Latency: latency.Nanoseconds(),
|
||||
Agent: truncate(r.UserAgent(), 500),
|
||||
ErrorMessage: rw.errorMsg(),
|
||||
Body: truncate(bodyStr, 2000),
|
||||
Resp: truncate(rw.body.String(), 2000),
|
||||
UserId: userId,
|
||||
}
|
||||
|
||||
// 异步写入,不阻塞响应
|
||||
select {
|
||||
case m.logChan <- record:
|
||||
default:
|
||||
logx.Error("[zero-gateway] 操作日志缓冲区满,丢弃日志")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// consumer 异步消费操作日志并通过 system-rpc 写入数据库
|
||||
func (m *OperationLogMiddleware) consumer() {
|
||||
for record := range m.logChan {
|
||||
_, err := m.systemRpc.CreateOperationRecord(context.Background(), record)
|
||||
if err != nil {
|
||||
logx.Errorf("[zero-gateway] 写入操作日志失败: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// extractUserId 从 Authorization 头解析 JWT 获取 userId
|
||||
func (m *OperationLogMiddleware) extractUserId(r *http.Request) string {
|
||||
// 优先从鉴权中间件注入的请求头获取(避免重复解析 JWT)
|
||||
if uid := r.Header.Get("X-User-Id"); uid != "" {
|
||||
return uid
|
||||
}
|
||||
// fallback: 自己解析
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
return ""
|
||||
}
|
||||
tokenStr := jwtUtil.GetTokenFromHeader(authHeader)
|
||||
if tokenStr == "" {
|
||||
return ""
|
||||
}
|
||||
j := jwtUtil.NewJWT(m.jwtSecret)
|
||||
claims, err := j.ParseToken(tokenStr)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return claims.BaseClaims.ID
|
||||
}
|
||||
|
||||
// responseCapture 捕获响应状态码和响应体
|
||||
type responseCapture struct {
|
||||
http.ResponseWriter
|
||||
statusCode int
|
||||
body bytes.Buffer
|
||||
}
|
||||
|
||||
func (rc *responseCapture) WriteHeader(code int) {
|
||||
rc.statusCode = code
|
||||
rc.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
func (rc *responseCapture) Write(b []byte) (int, error) {
|
||||
if rc.body.Len() < 2048 {
|
||||
rc.body.Write(b)
|
||||
}
|
||||
return rc.ResponseWriter.Write(b)
|
||||
}
|
||||
|
||||
func (rc *responseCapture) errorMsg() string {
|
||||
if rc.statusCode >= 400 {
|
||||
return rc.body.String()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// getClientIP 获取真实客户端 IP
|
||||
func getClientIP(r *http.Request) string {
|
||||
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
|
||||
return strings.Split(ip, ",")[0]
|
||||
}
|
||||
if ip := r.Header.Get("X-Real-Ip"); ip != "" {
|
||||
return ip
|
||||
}
|
||||
addr := r.RemoteAddr
|
||||
if idx := strings.LastIndex(addr, ":"); idx != -1 {
|
||||
return addr[:idx]
|
||||
}
|
||||
return addr
|
||||
}
|
||||
|
||||
// truncate 截断字符串,防止写入过长内容
|
||||
func truncate(s string, maxLen int) string {
|
||||
if len(s) > maxLen {
|
||||
return s[:maxLen] + "..."
|
||||
}
|
||||
return s
|
||||
}
|
||||
Reference in New Issue
Block a user