|
3 | 3 | package websocket |
4 | 4 |
|
5 | 5 | import ( |
| 6 | + "bufio" |
6 | 7 | "bytes" |
7 | 8 | "compress/flate" |
| 9 | + "context" |
8 | 10 | "io" |
| 11 | + "net" |
9 | 12 | "strings" |
10 | 13 | "testing" |
| 14 | + "time" |
11 | 15 |
|
12 | 16 | "github.com/coder/websocket/internal/test/assert" |
13 | 17 | "github.com/coder/websocket/internal/test/xrand" |
@@ -59,3 +63,249 @@ func BenchmarkFlateReader(b *testing.B) { |
59 | 63 | io.ReadAll(r) |
60 | 64 | } |
61 | 65 | } |
| 66 | + |
| 67 | +// TestWriteSingleFrameCompressed verifies that Conn.Write sends compressed |
| 68 | +// messages in a single frame instead of multiple frames, and that messages |
| 69 | +// below the flateThreshold are sent uncompressed. |
| 70 | +// This is a regression test for https://github.com/coder/websocket/issues/435 |
| 71 | +func TestWriteSingleFrameCompressed(t *testing.T) { |
| 72 | + t.Parallel() |
| 73 | + |
| 74 | + var ( |
| 75 | + flateThreshold = 64 |
| 76 | + |
| 77 | + largeMsg = []byte(strings.Repeat("hello world ", 100)) |
| 78 | + smallMsg = []byte("small message") |
| 79 | + ) |
| 80 | + |
| 81 | + testCases := []struct { |
| 82 | + name string |
| 83 | + mode CompressionMode |
| 84 | + msg []byte |
| 85 | + wantRsv1 bool // true = compressed, false = uncompressed |
| 86 | + }{ |
| 87 | + {"ContextTakeover/AboveThreshold", CompressionContextTakeover, largeMsg, true}, |
| 88 | + {"NoContextTakeover/AboveThreshold", CompressionNoContextTakeover, largeMsg, true}, |
| 89 | + {"ContextTakeover/BelowThreshold", CompressionContextTakeover, smallMsg, false}, |
| 90 | + {"NoContextTakeover/BelowThreshold", CompressionNoContextTakeover, smallMsg, false}, |
| 91 | + } |
| 92 | + |
| 93 | + for _, tc := range testCases { |
| 94 | + t.Run(tc.name, func(t *testing.T) { |
| 95 | + t.Parallel() |
| 96 | + |
| 97 | + clientConn, serverConn := net.Pipe() |
| 98 | + defer clientConn.Close() |
| 99 | + defer serverConn.Close() |
| 100 | + |
| 101 | + c := newConn(connConfig{ |
| 102 | + rwc: clientConn, |
| 103 | + client: true, |
| 104 | + copts: tc.mode.opts(), |
| 105 | + flateThreshold: flateThreshold, |
| 106 | + br: bufio.NewReader(clientConn), |
| 107 | + bw: bufio.NewWriterSize(clientConn, 4096), |
| 108 | + }) |
| 109 | + |
| 110 | + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100) |
| 111 | + defer cancel() |
| 112 | + |
| 113 | + writeDone := make(chan error, 1) |
| 114 | + go func() { |
| 115 | + writeDone <- c.Write(ctx, MessageText, tc.msg) |
| 116 | + }() |
| 117 | + |
| 118 | + reader := bufio.NewReader(serverConn) |
| 119 | + readBuf := make([]byte, 8) |
| 120 | + |
| 121 | + h, err := readFrameHeader(reader, readBuf) |
| 122 | + assert.Success(t, err) |
| 123 | + |
| 124 | + _, err = io.CopyN(io.Discard, reader, h.payloadLength) |
| 125 | + assert.Success(t, err) |
| 126 | + |
| 127 | + assert.Equal(t, "opcode", opText, h.opcode) |
| 128 | + assert.Equal(t, "rsv1 (compressed)", tc.wantRsv1, h.rsv1) |
| 129 | + assert.Equal(t, "fin", true, h.fin) |
| 130 | + |
| 131 | + err = <-writeDone |
| 132 | + assert.Success(t, err) |
| 133 | + }) |
| 134 | + } |
| 135 | +} |
| 136 | + |
| 137 | +// TestWriteThenWriterContextTakeover verifies that using Conn.Write followed by |
| 138 | +// Conn.Writer works correctly with context takeover enabled. This tests that |
| 139 | +// the flateWriter destination is properly restored after Conn.Write redirects |
| 140 | +// it to a temporary buffer. |
| 141 | +func TestWriteThenWriterContextTakeover(t *testing.T) { |
| 142 | + t.Parallel() |
| 143 | + |
| 144 | + clientConn, serverConn := net.Pipe() |
| 145 | + defer clientConn.Close() |
| 146 | + defer serverConn.Close() |
| 147 | + |
| 148 | + client := newConn(connConfig{ |
| 149 | + rwc: clientConn, |
| 150 | + client: true, |
| 151 | + copts: CompressionContextTakeover.opts(), |
| 152 | + flateThreshold: 64, |
| 153 | + br: bufio.NewReader(clientConn), |
| 154 | + bw: bufio.NewWriterSize(clientConn, 4096), |
| 155 | + }) |
| 156 | + |
| 157 | + server := newConn(connConfig{ |
| 158 | + rwc: serverConn, |
| 159 | + client: false, |
| 160 | + copts: CompressionContextTakeover.opts(), |
| 161 | + flateThreshold: 64, |
| 162 | + br: bufio.NewReader(serverConn), |
| 163 | + bw: bufio.NewWriterSize(serverConn, 4096), |
| 164 | + }) |
| 165 | + |
| 166 | + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*500) |
| 167 | + defer cancel() |
| 168 | + |
| 169 | + msg1 := []byte(strings.Repeat("first message ", 100)) |
| 170 | + msg2 := []byte(strings.Repeat("second message ", 100)) |
| 171 | + |
| 172 | + type readResult struct { |
| 173 | + typ MessageType |
| 174 | + p []byte |
| 175 | + err error |
| 176 | + } |
| 177 | + readCh := make(chan readResult, 3) |
| 178 | + go func() { |
| 179 | + for range 3 { |
| 180 | + typ, p, err := server.Read(ctx) |
| 181 | + readCh <- readResult{typ, p, err} |
| 182 | + } |
| 183 | + }() |
| 184 | + |
| 185 | + // We want to verify mixing `Write` and `Writer` usages still work. |
| 186 | + // |
| 187 | + // To this end, we call them in this order: |
| 188 | + // - `Write` |
| 189 | + // - `Writer` |
| 190 | + // - `Write` |
| 191 | + // |
| 192 | + // This verifies that it works for a `Write` followed by a `Writer` |
| 193 | + // as well as a `Writer` followed by a `Write`. |
| 194 | + |
| 195 | + // 1. `Write` API |
| 196 | + err := client.Write(ctx, MessageText, msg1) |
| 197 | + assert.Success(t, err) |
| 198 | + |
| 199 | + r := <-readCh |
| 200 | + assert.Success(t, r.err) |
| 201 | + assert.Equal(t, "Write type", MessageText, r.typ) |
| 202 | + assert.Equal(t, "Write content", string(msg1), string(r.p)) |
| 203 | + |
| 204 | + // 2. `Writer` API |
| 205 | + w, err := client.Writer(ctx, MessageBinary) |
| 206 | + assert.Success(t, err) |
| 207 | + _, err = w.Write(msg2) |
| 208 | + assert.Success(t, err) |
| 209 | + assert.Success(t, w.Close()) |
| 210 | + |
| 211 | + r = <-readCh |
| 212 | + assert.Success(t, r.err) |
| 213 | + assert.Equal(t, "Writer type", MessageBinary, r.typ) |
| 214 | + assert.Equal(t, "Writer content", string(msg2), string(r.p)) |
| 215 | + |
| 216 | + // 3. `Write` API again |
| 217 | + err = client.Write(ctx, MessageText, msg1) |
| 218 | + assert.Success(t, err) |
| 219 | + |
| 220 | + r = <-readCh |
| 221 | + assert.Success(t, r.err) |
| 222 | + assert.Equal(t, "Write type", MessageText, r.typ) |
| 223 | + assert.Equal(t, "Write content", string(msg1), string(r.p)) |
| 224 | +} |
| 225 | + |
| 226 | +// TestCompressionDictionaryPreserved verifies that context takeover mode |
| 227 | +// preserves the compression dictionary across Conn.Write calls, resulting |
| 228 | +// in better compression for consecutive similar messages. |
| 229 | +func TestCompressionDictionaryPreserved(t *testing.T) { |
| 230 | + t.Parallel() |
| 231 | + |
| 232 | + msg := []byte(strings.Repeat(`{"type":"event","data":"value"}`, 50)) |
| 233 | + |
| 234 | + takeoverClient, takeoverServer := net.Pipe() |
| 235 | + defer takeoverClient.Close() |
| 236 | + defer takeoverServer.Close() |
| 237 | + |
| 238 | + withTakeover := newConn(connConfig{ |
| 239 | + rwc: takeoverClient, |
| 240 | + client: true, |
| 241 | + copts: CompressionContextTakeover.opts(), |
| 242 | + flateThreshold: 64, |
| 243 | + br: bufio.NewReader(takeoverClient), |
| 244 | + bw: bufio.NewWriterSize(takeoverClient, 4096), |
| 245 | + }) |
| 246 | + |
| 247 | + noTakeoverClient, noTakeoverServer := net.Pipe() |
| 248 | + defer noTakeoverClient.Close() |
| 249 | + defer noTakeoverServer.Close() |
| 250 | + |
| 251 | + withoutTakeover := newConn(connConfig{ |
| 252 | + rwc: noTakeoverClient, |
| 253 | + client: true, |
| 254 | + copts: CompressionNoContextTakeover.opts(), |
| 255 | + flateThreshold: 64, |
| 256 | + br: bufio.NewReader(noTakeoverClient), |
| 257 | + bw: bufio.NewWriterSize(noTakeoverClient, 4096), |
| 258 | + }) |
| 259 | + |
| 260 | + ctx, cancel := context.WithTimeout(context.Background(), time.Second) |
| 261 | + defer cancel() |
| 262 | + |
| 263 | + // Capture compressed sizes for both modes |
| 264 | + var withTakeoverSizes, withoutTakeoverSizes []int64 |
| 265 | + |
| 266 | + reader1 := bufio.NewReader(takeoverServer) |
| 267 | + reader2 := bufio.NewReader(noTakeoverServer) |
| 268 | + readBuf := make([]byte, 8) |
| 269 | + |
| 270 | + // Send 3 identical messages each |
| 271 | + for range 3 { |
| 272 | + // With context takeover |
| 273 | + writeDone1 := make(chan error, 1) |
| 274 | + go func() { |
| 275 | + writeDone1 <- withTakeover.Write(ctx, MessageText, msg) |
| 276 | + }() |
| 277 | + |
| 278 | + h1, err := readFrameHeader(reader1, readBuf) |
| 279 | + assert.Success(t, err) |
| 280 | + |
| 281 | + _, err = io.CopyN(io.Discard, reader1, h1.payloadLength) |
| 282 | + assert.Success(t, err) |
| 283 | + |
| 284 | + withTakeoverSizes = append(withTakeoverSizes, h1.payloadLength) |
| 285 | + assert.Success(t, <-writeDone1) |
| 286 | + |
| 287 | + // Without context takeover |
| 288 | + writeDone2 := make(chan error, 1) |
| 289 | + go func() { |
| 290 | + writeDone2 <- withoutTakeover.Write(ctx, MessageText, msg) |
| 291 | + }() |
| 292 | + |
| 293 | + h2, err := readFrameHeader(reader2, readBuf) |
| 294 | + assert.Success(t, err) |
| 295 | + |
| 296 | + _, err = io.CopyN(io.Discard, reader2, h2.payloadLength) |
| 297 | + assert.Success(t, err) |
| 298 | + |
| 299 | + withoutTakeoverSizes = append(withoutTakeoverSizes, h2.payloadLength) |
| 300 | + assert.Success(t, <-writeDone2) |
| 301 | + } |
| 302 | + |
| 303 | + // With context takeover, the 2nd and 3rd messages should be smaller than |
| 304 | + // without context takeover (dictionary helps compress repeated patterns). |
| 305 | + // The first message will be similar size for both modes since there's no |
| 306 | + // prior dictionary. But subsequent messages benefit from context takeover. |
| 307 | + if withTakeoverSizes[2] >= withoutTakeoverSizes[2] { |
| 308 | + t.Errorf("context takeover should compress better: with=%d, without=%d", |
| 309 | + withTakeoverSizes[2], withoutTakeoverSizes[2]) |
| 310 | + } |
| 311 | +} |
0 commit comments