58 lines
1.4 KiB
Go
58 lines
1.4 KiB
Go
package middleware
|
|
|
|
import (
|
|
"net/http"
|
|
"strings"
|
|
)
|
|
|
|
// CorsMiddleware is a cross-origin resource sharing (CORS) middleware.
// Build it with NewCorsMiddleware and wrap handlers with its Handle method.
type CorsMiddleware struct {
	// Allowed origins, HTTP methods and request headers, in that order.
	// An origin entry of "*" acts as a wildcard.
	allowOrigins, allowMethods, allowHeaders []string
}
|
|
|
|
func NewCorsMiddleware(origins, methods, headers []string) *CorsMiddleware {
|
|
return &CorsMiddleware{
|
|
allowOrigins: origins,
|
|
allowMethods: methods,
|
|
allowHeaders: headers,
|
|
}
|
|
}
|
|
|
|
func (m *CorsMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
origin := r.Header.Get("Origin")
|
|
|
|
allowed := false
|
|
for _, o := range m.allowOrigins {
|
|
if o == "*" || o == origin {
|
|
allowed = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if allowed {
|
|
allowOrigin := origin
|
|
if len(m.allowOrigins) == 1 && m.allowOrigins[0] == "*" {
|
|
allowOrigin = "*"
|
|
}
|
|
w.Header().Set("Access-Control-Allow-Origin", allowOrigin)
|
|
w.Header().Set("Access-Control-Allow-Methods", strings.Join(m.allowMethods, ", "))
|
|
w.Header().Set("Access-Control-Allow-Headers", strings.Join(m.allowHeaders, ", "))
|
|
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
|
w.Header().Set("Access-Control-Max-Age", "3600")
|
|
// 允许前端 JS 读取自定义响应头(默认跨域只能读 6 个安全头)
|
|
w.Header().Set("Access-Control-Expose-Headers", "X-Refresh-Token")
|
|
}
|
|
|
|
// 预检请求直接返回 204
|
|
if r.Method == http.MethodOptions {
|
|
w.WriteHeader(http.StatusNoContent)
|
|
return
|
|
}
|
|
|
|
next(w, r)
|
|
}
|
|
}
|