diff --git a/benchmark-proxy/main.go b/benchmark-proxy/main.go index f40cb0f1..a377583c 100644 --- a/benchmark-proxy/main.go +++ b/benchmark-proxy/main.go @@ -15,6 +15,8 @@ import ( "strings" "sync" "time" + + "github.com/gorilla/websocket" ) // Simple structure to extract just the method from JSON-RPC requests @@ -36,15 +38,27 @@ type ResponseStats struct { Method string // Added method field } +// WebSocketStats tracks information about websocket connections +type WebSocketStats struct { + Backend string + Error error + ConnectTime time.Duration + IsActive bool + MessagesSent int + MessagesReceived int +} + // StatsCollector maintains statistics for periodic summaries type StatsCollector struct { - mu sync.Mutex - requestStats []ResponseStats - methodStats map[string][]time.Duration // Track durations by method - totalRequests int - errorCount int - startTime time.Time - summaryInterval time.Duration + mu sync.Mutex + requestStats []ResponseStats + methodStats map[string][]time.Duration // Track durations by method + totalRequests int + errorCount int + wsConnections []WebSocketStats // Track websocket connections + totalWsConnections int + startTime time.Time + summaryInterval time.Duration } func NewStatsCollector(summaryInterval time.Duration) *StatsCollector { @@ -84,6 +98,18 @@ func (sc *StatsCollector) AddStats(stats []ResponseStats, totalDuration time.Dur sc.totalRequests++ } +func (sc *StatsCollector) AddWebSocketStats(stats WebSocketStats) { + sc.mu.Lock() + defer sc.mu.Unlock() + + sc.wsConnections = append(sc.wsConnections, stats) + sc.totalWsConnections++ + + if stats.Error != nil { + sc.errorCount++ + } +} + func (sc *StatsCollector) periodicSummary() { ticker := time.NewTicker(sc.summaryInterval) defer ticker.Stop() @@ -100,8 +126,9 @@ func (sc *StatsCollector) printSummary() { uptime := time.Since(sc.startTime) fmt.Printf("\n=== BENCHMARK PROXY SUMMARY ===\n") fmt.Printf("Uptime: %s\n", uptime.Round(time.Second)) - fmt.Printf("Total Requests: %d\n", sc.totalRequests) - fmt.Printf("Error Rate: %.2f%%\n", float64(sc.errorCount)/float64(sc.totalRequests)*100) + fmt.Printf("Total HTTP Requests: %d\n", sc.totalRequests) + fmt.Printf("Total WebSocket Connections: %d\n", sc.totalWsConnections) + fmt.Printf("Error Rate: %.2f%%\n", float64(sc.errorCount)/float64(sc.totalRequests+sc.totalWsConnections)*100) // Calculate response time statistics for primary backend var primaryDurations []time.Duration @@ -165,18 +192,18 @@ func (sc *StatsCollector) printSummary() { } avg := sum / time.Duration(len(durations)) - min := durations[0] + minDuration := durations[0] max := durations[len(durations)-1] // Only calculate percentiles if we have enough samples - p50 := min - p90 := min - p99 := min + p50 := minDuration + p90 := minDuration + p99 := minDuration if len(durations) >= 2 { p50idx := len(durations) * 50 / 100 p90idx := len(durations) * 90 / 100 - p99idx := min(len(durations)-1, len(durations)*99/100) + p99idx := minInt(len(durations)-1, len(durations)*99/100) p50 = durations[p50idx] p90 = durations[p90idx] @@ -184,7 +211,7 @@ func (sc *StatsCollector) printSummary() { } fmt.Printf(" %-20s Count: %-5d Avg: %-10s Min: %-10s Max: %-10s p50: %-10s p90: %-10s p99: %-10s\n", - method, len(durations), avg, min, max, p50, p90, p99) + method, len(durations), avg, minDuration, max, p50, p90, p99) } } @@ -201,10 +228,15 @@ func (sc *StatsCollector) printSummary() { sc.methodStats[method] = durations[len(durations)-1000:] } } + + // Keep only the last 1000 websocket connections to prevent unlimited memory growth + if len(sc.wsConnections) > 1000 { + sc.wsConnections = sc.wsConnections[len(sc.wsConnections)-1000:] + } } // Helper function to avoid potential index out of bounds -func min(a, b int) int { +func minInt(a, b int) int { if a < b { return a } @@ -261,9 +293,23 @@ func main() { }, } + // Configure websocket upgrader + upgrader := websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + // Allow all origins + CheckOrigin: func(r *http.Request) bool { return true }, + } + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - stats := handleRequest(w, r, backends, client) - statsCollector.AddStats(stats, 0) // The 0 is a placeholder, we're not using totalDuration in the collector + // Check if this is a WebSocket upgrade request + if websocket.IsWebSocketUpgrade(r) { + handleWebSocketRequest(w, r, backends, client, &upgrader, statsCollector) + } else { + // Handle regular HTTP request + stats := handleRequest(w, r, backends, client) + statsCollector.AddStats(stats, 0) // The 0 is a placeholder, we're not using totalDuration in the collector + } }) log.Fatal(http.ListenAndServe(listenAddr, nil)) @@ -423,6 +469,114 @@ func logResponseStats(totalDuration time.Duration, stats []ResponseStats) { fmt.Println(strings.Join(parts, " | ")) } +// handleWebSocketRequest manages WebSocket proxying +func handleWebSocketRequest(w http.ResponseWriter, r *http.Request, backends []Backend, + httpClient *http.Client, upgrader *websocket.Upgrader, + statsCollector *StatsCollector) { + // Upgrade the client connection + clientConn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Printf("Failed to upgrade client connection: %v", err) + return + } + defer clientConn.Close() + + // Connect to all backends + var wg sync.WaitGroup + for _, backend := range backends { + wg.Add(1) + go func(b Backend) { + defer wg.Done() + + // Create backend URL with ws/wss instead of http/https + backendURL := strings.Replace(b.URL, "http://", "ws://", 1) + backendURL = strings.Replace(backendURL, "https://", "wss://", 1) + + // Copy headers for the dialer + header := http.Header{} + for name, values := range r.Header { + for _, value := range values { + header.Add(name, value) + } + } + + startTime := time.Now() + // Connect to backend WebSocket + dialer := websocket.DefaultDialer + backendConn, resp, err := dialer.Dial(backendURL, header) + connectDuration := time.Since(startTime) + + stats := WebSocketStats{ + Backend: b.Name, + ConnectTime: connectDuration, + IsActive: false, + } + + if err != nil { + status := 0 + if resp != nil { + status = resp.StatusCode + } + log.Printf("Failed to connect to backend %s: %v (status: %d)", b.Name, err, status) + stats.Error = err + statsCollector.AddWebSocketStats(stats) + return + } + defer backendConn.Close() + + stats.IsActive = true + statsCollector.AddWebSocketStats(stats) + + // If this is the primary backend, set up bidirectional proxying + if b.Role == "primary" { + // Forward messages from client to primary backend + go func() { + for { + messageType, message, err := clientConn.ReadMessage() + if err != nil { + log.Printf("Error reading from client: %v", err) + return + } + + err = backendConn.WriteMessage(messageType, message) + if err != nil { + log.Printf("Error writing to primary backend: %v", err) + return + } + } + }() + + // Forward messages from primary backend to client + for { + messageType, message, err := backendConn.ReadMessage() + if err != nil { + log.Printf("Error reading from primary backend: %v", err) + return + } + + err = clientConn.WriteMessage(messageType, message) + if err != nil { + log.Printf("Error writing to client: %v", err) + return + } + } + } else { + // For secondary backends, just read and discard messages + for { + _, _, err := backendConn.ReadMessage() + if err != nil { + log.Printf("Secondary backend %s connection closed: %v", b.Name, err) + return + } + } + } + }(backend) + } + + // Wait for all connections to terminate + wg.Wait() +} + func getEnv(key, fallback string) string { if value, exists := os.LookupEnv(key); exists { return value