Files

58 lines
1.4 KiB
Go

package middleware
import (
"net/http"
"strings"
)
// CorsMiddleware is an HTTP middleware that writes CORS (cross-origin
// resource sharing) response headers according to a fixed configuration.
type CorsMiddleware struct {
	// allowOrigins lists the origins compared against the request's Origin
	// header; the entry "*" matches any origin.
	allowOrigins []string
	// allowMethods is joined with ", " to form Access-Control-Allow-Methods.
	allowMethods []string
	// allowHeaders is joined with ", " to form Access-Control-Allow-Headers.
	allowHeaders []string
}
// NewCorsMiddleware returns a CorsMiddleware configured with the given
// allowed origins, methods and request headers.
//
// The slices are stored as passed in, without copying, so later changes the
// caller makes to their elements are visible to the middleware.
func NewCorsMiddleware(origins, methods, headers []string) *CorsMiddleware {
	mw := new(CorsMiddleware)
	mw.allowOrigins = origins
	mw.allowMethods = methods
	mw.allowHeaders = headers
	return mw
}
// Handle wraps next with CORS processing and returns the wrapped handler.
//
// The request's Origin header is matched against the configured origins; a
// configured "*" matches any origin. On a match the Access-Control-* response
// headers are written:
//
//   - If "*" is the only configured origin, Access-Control-Allow-Origin is "*"
//     and Access-Control-Allow-Credentials is omitted, because browsers reject
//     a wildcard origin combined with credentials.
//   - Otherwise the request's origin is echoed back with credentials enabled.
//     A request without an Origin header gets no CORS headers in this mode, so
//     an empty Access-Control-Allow-Origin is never emitted.
//
// Unless "*" is the only configured origin, "Vary: Origin" is added to every
// response so that shared caches do not reuse a response across origins.
//
// Every OPTIONS request is answered with 204 No Content and is not forwarded
// to next; all other requests are passed through to next.
//
// NOTE(review): a configuration that mixes "*" with other entries echoes any
// origin back with credentials enabled — confirm that this is intended.
func (m *CorsMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		origin := r.Header.Get("Origin")
		wildcardOnly := len(m.allowOrigins) == 1 && m.allowOrigins[0] == "*"

		allowed := false
		for _, o := range m.allowOrigins {
			if o == "*" || o == origin {
				allowed = true
				break
			}
		}

		allowOrigin := origin
		if wildcardOnly {
			allowOrigin = "*"
		} else {
			// The CORS headers (or their absence) depend on the request's
			// Origin, so caches must key on it.
			w.Header().Add("Vary", "Origin")
		}

		// allowOrigin is empty only when the request had no Origin header and
		// the origin would have been echoed; there is nothing valid to send.
		if allowed && allowOrigin != "" {
			h := w.Header()
			h.Set("Access-Control-Allow-Origin", allowOrigin)
			h.Set("Access-Control-Allow-Methods", strings.Join(m.allowMethods, ", "))
			h.Set("Access-Control-Allow-Headers", strings.Join(m.allowHeaders, ", "))
			if !wildcardOnly {
				// Credentials are only valid alongside an explicit origin.
				h.Set("Access-Control-Allow-Credentials", "true")
			}
			h.Set("Access-Control-Max-Age", "3600")
			// Let front-end JS read this custom response header; by default a
			// cross-origin response exposes only the CORS-safelisted headers.
			h.Set("Access-Control-Expose-Headers", "X-Refresh-Token")
		}

		// Preflight requests are answered directly with 204.
		if r.Method == http.MethodOptions {
			w.WriteHeader(http.StatusNoContent)
			return
		}
		next(w, r)
	}
}