package middleware import ( "net/http" "strings" ) // CorsMiddleware 跨域中间件 type CorsMiddleware struct { allowOrigins []string allowMethods []string allowHeaders []string } func NewCorsMiddleware(origins, methods, headers []string) *CorsMiddleware { return &CorsMiddleware{ allowOrigins: origins, allowMethods: methods, allowHeaders: headers, } } func (m *CorsMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { origin := r.Header.Get("Origin") allowed := false for _, o := range m.allowOrigins { if o == "*" || o == origin { allowed = true break } } if allowed { allowOrigin := origin if len(m.allowOrigins) == 1 && m.allowOrigins[0] == "*" { allowOrigin = "*" } w.Header().Set("Access-Control-Allow-Origin", allowOrigin) w.Header().Set("Access-Control-Allow-Methods", strings.Join(m.allowMethods, ", ")) w.Header().Set("Access-Control-Allow-Headers", strings.Join(m.allowHeaders, ", ")) w.Header().Set("Access-Control-Allow-Credentials", "true") w.Header().Set("Access-Control-Max-Age", "3600") // 允许前端 JS 读取自定义响应头(默认跨域只能读 6 个安全头) w.Header().Set("Access-Control-Expose-Headers", "X-Refresh-Token") } // 预检请求直接返回 204 if r.Method == http.MethodOptions { w.WriteHeader(http.StatusNoContent) return } next(w, r) } }