diff --git a/benchmark-proxy/main.go b/benchmark-proxy/main.go index a377583c..30f2bbbe 100644 --- a/benchmark-proxy/main.go +++ b/benchmark-proxy/main.go @@ -492,12 +492,33 @@ func handleWebSocketRequest(w http.ResponseWriter, r *http.Request, backends []B backendURL := strings.Replace(b.URL, "http://", "ws://", 1) backendURL = strings.Replace(backendURL, "https://", "wss://", 1) - // Copy headers for the dialer + // Create a clean header map for the dialer + // Instead of copying all headers which can cause duplicates header := http.Header{} - for name, values := range r.Header { - for _, value := range values { - header.Add(name, value) - } + + // Only copy specific headers needed for the WebSocket connection + // and avoid the problematic ones like "Connection" and "Upgrade" + if host := r.Header.Get("Host"); host != "" { + header.Set("Host", host) + } + if origin := r.Header.Get("Origin"); origin != "" { + header.Set("Origin", origin) + } + if secWebSocketKey := r.Header.Get("Sec-WebSocket-Key"); secWebSocketKey != "" { + header.Set("Sec-WebSocket-Key", secWebSocketKey) + } + if secWebSocketVersion := r.Header.Get("Sec-WebSocket-Version"); secWebSocketVersion != "" { + header.Set("Sec-WebSocket-Version", secWebSocketVersion) + } + if secWebSocketProtocol := r.Header.Get("Sec-WebSocket-Protocol"); secWebSocketProtocol != "" { + header.Set("Sec-WebSocket-Protocol", secWebSocketProtocol) + } + if secWebSocketExtensions := r.Header.Get("Sec-WebSocket-Extensions"); secWebSocketExtensions != "" { + header.Set("Sec-WebSocket-Extensions", secWebSocketExtensions) + } + // Add user-agent if present + if userAgent := r.Header.Get("User-Agent"); userAgent != "" { + header.Set("User-Agent", userAgent) } startTime := time.Now()