package middleware

import (
	"net/http"
	"strings"
)

// CorsMiddleware is a cross-origin resource sharing (CORS) middleware. It adds
// CORS response headers for requests whose Origin is in the configured allow
// list and answers OPTIONS requests that carry an Origin header with 204.
type CorsMiddleware struct {
	allowOrigins []string // origins matched exactly; an entry "*" matches any origin
	allowMethods []string // joined with ", " into Access-Control-Allow-Methods
	allowHeaders []string // joined with ", " into Access-Control-Allow-Headers
}

// NewCorsMiddleware returns a CorsMiddleware that accepts the given origins and
// advertises the given methods and request headers. The slices are copied, so
// modifying them after the call does not affect the middleware.
func NewCorsMiddleware(origins, methods, headers []string) *CorsMiddleware {
	return &CorsMiddleware{
		allowOrigins: append([]string(nil), origins...),
		allowMethods: append([]string(nil), methods...),
		allowHeaders: append([]string(nil), headers...),
	}
}

// Handle wraps next with CORS handling:
//
//   - Every response gets "Vary: Origin", because the CORS headers depend on
//     the request's Origin. Without it a shared cache could serve one origin's
//     response (with or without Access-Control-Allow-Origin) to another origin.
//   - Requests without an Origin header are passed to next unchanged.
//   - If Origin matches the allow list, it is echoed in
//     Access-Control-Allow-Origin along with the configured methods and
//     headers (omitted when empty), Allow-Credentials "true", Max-Age "3600"
//     and Expose-Headers "X-Refresh-Token".
//   - OPTIONS requests that carry an Origin are answered with 204 No Content
//     and never reach next, whether or not the origin was allowed.
func (m *CorsMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
	// These values do not depend on the request, so build them once.
	methods := strings.Join(m.allowMethods, ", ")
	headers := strings.Join(m.allowHeaders, ", ")

	return func(w http.ResponseWriter, r *http.Request) {
		// Add (not Set) so Vary values set by other layers are preserved.
		w.Header().Add("Vary", "Origin")

		origin := r.Header.Get("Origin")
		if origin == "" {
			next(w, r)
			return
		}

		if m.originAllowed(origin) {
			h := w.Header()
			// NOTE(review): with "*" in the allow list, any origin is echoed
			// back while credentials are allowed, which lets every site make
			// credentialed requests. Prefer configuring explicit origins.
			h.Set("Access-Control-Allow-Origin", origin)
			if methods != "" {
				h.Set("Access-Control-Allow-Methods", methods)
			}
			if headers != "" {
				h.Set("Access-Control-Allow-Headers", headers)
			}
			h.Set("Access-Control-Allow-Credentials", "true")
			h.Set("Access-Control-Max-Age", "3600")
			// Allow front-end JS to read the custom response header
			// (used for automatic token renewal).
			h.Set("Access-Control-Expose-Headers", "X-Refresh-Token")
		}

		// Answer preflight (OPTIONS) requests here instead of passing them on.
		if r.Method == http.MethodOptions {
			w.WriteHeader(http.StatusNoContent)
			return
		}

		next(w, r)
	}
}

// originAllowed reports whether origin equals an allow-list entry, or the
// allow list contains the wildcard "*".
func (m *CorsMiddleware) originAllowed(origin string) bool {
	for _, o := range m.allowOrigins {
		if o == "*" || o == origin {
			return true
		}
	}
	return false
}