package mcp import ( "context" "fmt" "io" "net" "net/http" "net/url" "os" "strings" "time" "github.com/sundynix/sundynix-shared/contract" ) const ( extTimeout = 10 * time.Second extMaxBytes = 256 * 1024 // 响应体读取上限 ) // externalAPI 是通用出站 HTTP 工具(GET/POST)。带 SSRF 防护:拒环回/内网/链路本地/ // 云元数据地址;可选 EXTERNAL_API_ALLOWLIST 收窄到白名单主机。限超时 + 限响应体大小 + 限重定向。 func (g *Gateway) externalAPI(ctx context.Context, call *contract.ToolCall) *contract.ToolResult { raw, _ := call.Args["url"].(string) raw = strings.TrimSpace(raw) if raw == "" { return &contract.ToolResult{OK: false, Error: "external_api: url 必填"} } method := strings.ToUpper(strings.TrimSpace(fmt.Sprint(call.Args["method"]))) if method == "" || method == "" { method = "GET" } if method != "GET" && method != "POST" { return &contract.ToolResult{OK: false, Error: "external_api: 仅支持 GET/POST"} } allow := extAllowlist() if reason, ok := validateExternalURL(raw, allow); !ok { return &contract.ToolResult{OK: false, Error: "external_api: URL 被拦截 —— " + reason} } var body io.Reader if b, _ := call.Args["body"].(string); b != "" { body = strings.NewReader(b) } req, err := http.NewRequestWithContext(ctx, method, raw, body) if err != nil { return &contract.ToolResult{OK: false, Error: "external_api: " + err.Error()} } if hm, ok := call.Args["headers"].(map[string]any); ok { for k, v := range hm { req.Header.Set(k, fmt.Sprint(v)) } } client := &http.Client{ Timeout: extTimeout, CheckRedirect: func(r *http.Request, via []*http.Request) error { if len(via) >= 3 { return fmt.Errorf("重定向过多") } if reason, ok := validateExternalURL(r.URL.String(), allow); !ok { return fmt.Errorf("重定向被拦截:%s", reason) } return nil }, } resp, err := client.Do(req) if err != nil { return &contract.ToolResult{OK: false, Error: "external_api: " + err.Error()} } defer resp.Body.Close() data, _ := io.ReadAll(io.LimitReader(resp.Body, extMaxBytes)) return &contract.ToolResult{OK: true, Content: fmt.Sprintf("HTTP %d\n%s", resp.StatusCode, string(data))} } // extAllowlist 
// reads EXTERNAL_API_ALLOWLIST (comma-separated hosts). Entries are trimmed and
// blank entries dropped. It returns nil when the variable is unset or blank,
// which means "no host restriction" (the IP-based SSRF checks still apply).
func extAllowlist() []string {
	v := strings.TrimSpace(os.Getenv("EXTERNAL_API_ALLOWLIST"))
	if v == "" {
		return nil
	}
	var out []string
	for _, p := range strings.Split(v, ",") {
		if p = strings.TrimSpace(p); p != "" {
			out = append(out, p)
		}
	}
	return out
}

// validateExternalURL vets an outbound URL. The scheme must be http or https,
// the host must be non-empty and (when allow is non-empty) match the
// allowlist, and none of the IPs the host resolves to may fall in a range
// rejected by isBlockedIP. This guards against SSRF toward internal services
// and cloud metadata endpoints such as 169.254.169.254.
//
// It returns ok=true with an empty reason when the URL is acceptable;
// otherwise ok=false and a user-facing reason.
func validateExternalURL(raw string, allow []string) (reason string, ok bool) {
	u, err := url.Parse(raw)
	if err != nil || (u.Scheme != "http" && u.Scheme != "https") {
		return "仅支持 http/https", false
	}
	host := u.Hostname()
	if host == "" {
		return "缺少主机", false
	}
	if len(allow) > 0 && !hostAllowed(host, allow) {
		return "主机不在允许清单", false
	}
	ips, err := net.LookupIP(host)
	if err != nil || len(ips) == 0 {
		return "域名解析失败", false
	}
	// A single blocked address is enough to reject the host: the caller
	// cannot control which of the returned IPs is dialed.
	for _, ip := range ips {
		if isBlockedIP(ip) {
			return "禁止访问内网/环回/元数据地址", false
		}
	}
	return "", true
}

// hostAllowed reports whether host equals an allowlist entry or is a
// subdomain of one. Matching is case-insensitive.
func hostAllowed(host string, allow []string) bool {
	host = strings.ToLower(host)
	for _, a := range allow {
		a = strings.ToLower(a)
		// The "." prefix keeps "evilexample.com" from matching "example.com".
		if host == a || strings.HasSuffix(host, "."+a) {
			return true
		}
	}
	return false
}

// isBlockedIP reports whether ip lies in a range that outbound requests must
// never reach (SSRF protection).
func isBlockedIP(ip net.IP) bool {
	if ip.IsLoopback() || // 127.0.0.0/8, ::1
		ip.IsPrivate() || // 10/8, 172.16/12, 192.168/16, fc00::/7
		ip.IsLinkLocalUnicast() || // 169.254/16 (incl. cloud metadata), fe80::/10
		ip.IsLinkLocalMulticast() ||
		ip.IsUnspecified() { // 0.0.0.0, ::
		return true
	}
	// 100.64.0.0/10 is the RFC 6598 shared address space. IsPrivate does not
	// cover it, yet it is internal-only: carrier-grade NAT, overlay networks,
	// and the Alibaba Cloud metadata endpoint 100.100.100.200 live here.
	if v4 := ip.To4(); v4 != nil && v4[0] == 100 && v4[1]&0xc0 == 64 {
		return true
	}
	return false
}