now with websockets
This commit is contained in:
@@ -15,6 +15,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Simple structure to extract just the method from JSON-RPC requests
|
// Simple structure to extract just the method from JSON-RPC requests
|
||||||
@@ -36,6 +38,16 @@ type ResponseStats struct {
|
|||||||
Method string // Added method field
|
Method string // Added method field
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WebSocketStats tracks information about websocket connections
|
||||||
|
type WebSocketStats struct {
|
||||||
|
Backend string
|
||||||
|
Error error
|
||||||
|
ConnectTime time.Duration
|
||||||
|
IsActive bool
|
||||||
|
MessagesSent int
|
||||||
|
MessagesReceived int
|
||||||
|
}
|
||||||
|
|
||||||
// StatsCollector maintains statistics for periodic summaries
|
// StatsCollector maintains statistics for periodic summaries
|
||||||
type StatsCollector struct {
|
type StatsCollector struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
@@ -43,6 +55,8 @@ type StatsCollector struct {
|
|||||||
methodStats map[string][]time.Duration // Track durations by method
|
methodStats map[string][]time.Duration // Track durations by method
|
||||||
totalRequests int
|
totalRequests int
|
||||||
errorCount int
|
errorCount int
|
||||||
|
wsConnections []WebSocketStats // Track websocket connections
|
||||||
|
totalWsConnections int
|
||||||
startTime time.Time
|
startTime time.Time
|
||||||
summaryInterval time.Duration
|
summaryInterval time.Duration
|
||||||
}
|
}
|
||||||
@@ -84,6 +98,18 @@ func (sc *StatsCollector) AddStats(stats []ResponseStats, totalDuration time.Dur
|
|||||||
sc.totalRequests++
|
sc.totalRequests++
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (sc *StatsCollector) AddWebSocketStats(stats WebSocketStats) {
|
||||||
|
sc.mu.Lock()
|
||||||
|
defer sc.mu.Unlock()
|
||||||
|
|
||||||
|
sc.wsConnections = append(sc.wsConnections, stats)
|
||||||
|
sc.totalWsConnections++
|
||||||
|
|
||||||
|
if stats.Error != nil {
|
||||||
|
sc.errorCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (sc *StatsCollector) periodicSummary() {
|
func (sc *StatsCollector) periodicSummary() {
|
||||||
ticker := time.NewTicker(sc.summaryInterval)
|
ticker := time.NewTicker(sc.summaryInterval)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
@@ -100,8 +126,9 @@ func (sc *StatsCollector) printSummary() {
|
|||||||
uptime := time.Since(sc.startTime)
|
uptime := time.Since(sc.startTime)
|
||||||
fmt.Printf("\n=== BENCHMARK PROXY SUMMARY ===\n")
|
fmt.Printf("\n=== BENCHMARK PROXY SUMMARY ===\n")
|
||||||
fmt.Printf("Uptime: %s\n", uptime.Round(time.Second))
|
fmt.Printf("Uptime: %s\n", uptime.Round(time.Second))
|
||||||
fmt.Printf("Total Requests: %d\n", sc.totalRequests)
|
fmt.Printf("Total HTTP Requests: %d\n", sc.totalRequests)
|
||||||
fmt.Printf("Error Rate: %.2f%%\n", float64(sc.errorCount)/float64(sc.totalRequests)*100)
|
fmt.Printf("Total WebSocket Connections: %d\n", sc.totalWsConnections)
|
||||||
|
fmt.Printf("Error Rate: %.2f%%\n", float64(sc.errorCount)/float64(sc.totalRequests+sc.totalWsConnections)*100)
|
||||||
|
|
||||||
// Calculate response time statistics for primary backend
|
// Calculate response time statistics for primary backend
|
||||||
var primaryDurations []time.Duration
|
var primaryDurations []time.Duration
|
||||||
@@ -165,18 +192,18 @@ func (sc *StatsCollector) printSummary() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
avg := sum / time.Duration(len(durations))
|
avg := sum / time.Duration(len(durations))
|
||||||
min := durations[0]
|
minDuration := durations[0]
|
||||||
max := durations[len(durations)-1]
|
max := durations[len(durations)-1]
|
||||||
|
|
||||||
// Only calculate percentiles if we have enough samples
|
// Only calculate percentiles if we have enough samples
|
||||||
p50 := min
|
p50 := minDuration
|
||||||
p90 := min
|
p90 := minDuration
|
||||||
p99 := min
|
p99 := minDuration
|
||||||
|
|
||||||
if len(durations) >= 2 {
|
if len(durations) >= 2 {
|
||||||
p50idx := len(durations) * 50 / 100
|
p50idx := len(durations) * 50 / 100
|
||||||
p90idx := len(durations) * 90 / 100
|
p90idx := len(durations) * 90 / 100
|
||||||
p99idx := min(len(durations)-1, len(durations)*99/100)
|
p99idx := minInt(len(durations)-1, len(durations)*99/100)
|
||||||
|
|
||||||
p50 = durations[p50idx]
|
p50 = durations[p50idx]
|
||||||
p90 = durations[p90idx]
|
p90 = durations[p90idx]
|
||||||
@@ -184,7 +211,7 @@ func (sc *StatsCollector) printSummary() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf(" %-20s Count: %-5d Avg: %-10s Min: %-10s Max: %-10s p50: %-10s p90: %-10s p99: %-10s\n",
|
fmt.Printf(" %-20s Count: %-5d Avg: %-10s Min: %-10s Max: %-10s p50: %-10s p90: %-10s p99: %-10s\n",
|
||||||
method, len(durations), avg, min, max, p50, p90, p99)
|
method, len(durations), avg, minDuration, max, p50, p90, p99)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -201,10 +228,15 @@ func (sc *StatsCollector) printSummary() {
|
|||||||
sc.methodStats[method] = durations[len(durations)-1000:]
|
sc.methodStats[method] = durations[len(durations)-1000:]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Keep only the last 1000 websocket connections to prevent unlimited memory growth
|
||||||
|
if len(sc.wsConnections) > 1000 {
|
||||||
|
sc.wsConnections = sc.wsConnections[len(sc.wsConnections)-1000:]
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper function to avoid potential index out of bounds
|
// Helper function to avoid potential index out of bounds
|
||||||
func min(a, b int) int {
|
func minInt(a, b int) int {
|
||||||
if a < b {
|
if a < b {
|
||||||
return a
|
return a
|
||||||
}
|
}
|
||||||
@@ -261,9 +293,23 @@ func main() {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Configure websocket upgrader
|
||||||
|
upgrader := websocket.Upgrader{
|
||||||
|
ReadBufferSize: 1024,
|
||||||
|
WriteBufferSize: 1024,
|
||||||
|
// Allow all origins
|
||||||
|
CheckOrigin: func(r *http.Request) bool { return true },
|
||||||
|
}
|
||||||
|
|
||||||
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Check if this is a WebSocket upgrade request
|
||||||
|
if websocket.IsWebSocketUpgrade(r) {
|
||||||
|
handleWebSocketRequest(w, r, backends, client, &upgrader, statsCollector)
|
||||||
|
} else {
|
||||||
|
// Handle regular HTTP request
|
||||||
stats := handleRequest(w, r, backends, client)
|
stats := handleRequest(w, r, backends, client)
|
||||||
statsCollector.AddStats(stats, 0) // The 0 is a placeholder, we're not using totalDuration in the collector
|
statsCollector.AddStats(stats, 0) // The 0 is a placeholder, we're not using totalDuration in the collector
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
log.Fatal(http.ListenAndServe(listenAddr, nil))
|
log.Fatal(http.ListenAndServe(listenAddr, nil))
|
||||||
@@ -423,6 +469,114 @@ func logResponseStats(totalDuration time.Duration, stats []ResponseStats) {
|
|||||||
fmt.Println(strings.Join(parts, " | "))
|
fmt.Println(strings.Join(parts, " | "))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handleWebSocketRequest manages WebSocket proxying
|
||||||
|
func handleWebSocketRequest(w http.ResponseWriter, r *http.Request, backends []Backend,
|
||||||
|
httpClient *http.Client, upgrader *websocket.Upgrader,
|
||||||
|
statsCollector *StatsCollector) {
|
||||||
|
// Upgrade the client connection
|
||||||
|
clientConn, err := upgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to upgrade client connection: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer clientConn.Close()
|
||||||
|
|
||||||
|
// Connect to all backends
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for _, backend := range backends {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(b Backend) {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
// Create backend URL with ws/wss instead of http/https
|
||||||
|
backendURL := strings.Replace(b.URL, "http://", "ws://", 1)
|
||||||
|
backendURL = strings.Replace(backendURL, "https://", "wss://", 1)
|
||||||
|
|
||||||
|
// Copy headers for the dialer
|
||||||
|
header := http.Header{}
|
||||||
|
for name, values := range r.Header {
|
||||||
|
for _, value := range values {
|
||||||
|
header.Add(name, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
startTime := time.Now()
|
||||||
|
// Connect to backend WebSocket
|
||||||
|
dialer := websocket.DefaultDialer
|
||||||
|
backendConn, resp, err := dialer.Dial(backendURL, header)
|
||||||
|
connectDuration := time.Since(startTime)
|
||||||
|
|
||||||
|
stats := WebSocketStats{
|
||||||
|
Backend: b.Name,
|
||||||
|
ConnectTime: connectDuration,
|
||||||
|
IsActive: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
status := 0
|
||||||
|
if resp != nil {
|
||||||
|
status = resp.StatusCode
|
||||||
|
}
|
||||||
|
log.Printf("Failed to connect to backend %s: %v (status: %d)", b.Name, err, status)
|
||||||
|
stats.Error = err
|
||||||
|
statsCollector.AddWebSocketStats(stats)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer backendConn.Close()
|
||||||
|
|
||||||
|
stats.IsActive = true
|
||||||
|
statsCollector.AddWebSocketStats(stats)
|
||||||
|
|
||||||
|
// If this is the primary backend, set up bidirectional proxying
|
||||||
|
if b.Role == "primary" {
|
||||||
|
// Forward messages from client to primary backend
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
messageType, message, err := clientConn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error reading from client: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = backendConn.WriteMessage(messageType, message)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error writing to primary backend: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Forward messages from primary backend to client
|
||||||
|
for {
|
||||||
|
messageType, message, err := backendConn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error reading from primary backend: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = clientConn.WriteMessage(messageType, message)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error writing to client: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// For secondary backends, just read and discard messages
|
||||||
|
for {
|
||||||
|
_, _, err := backendConn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Secondary backend %s connection closed: %v", b.Name, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}(backend)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for all connections to terminate
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
func getEnv(key, fallback string) string {
|
func getEnv(key, fallback string) string {
|
||||||
if value, exists := os.LookupEnv(key); exists {
|
if value, exists := os.LookupEnv(key); exists {
|
||||||
return value
|
return value
|
||||||
|
|||||||
Reference in New Issue
Block a user