From 4e991f84d802c918839f2c058219307008aebdfc Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Sat, 3 Jun 2023 01:42:02 +0200 Subject: [PATCH 01/26] feat: add textproto fork --- textproto/header.go | 56 +++ textproto/header_test.go | 54 +++ textproto/pipeline.go | 118 ++++++ textproto/reader.go | 822 +++++++++++++++++++++++++++++++++++++++ textproto/reader_test.go | 525 +++++++++++++++++++++++++ textproto/textproto.go | 152 ++++++++ textproto/writer.go | 119 ++++++ textproto/writer_test.go | 61 +++ 8 files changed, 1907 insertions(+) create mode 100644 textproto/header.go create mode 100644 textproto/header_test.go create mode 100644 textproto/pipeline.go create mode 100644 textproto/reader.go create mode 100644 textproto/reader_test.go create mode 100644 textproto/textproto.go create mode 100644 textproto/writer.go create mode 100644 textproto/writer_test.go diff --git a/textproto/header.go b/textproto/header.go new file mode 100644 index 00000000..a58df7ae --- /dev/null +++ b/textproto/header.go @@ -0,0 +1,56 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package textproto + +// A MIMEHeader represents a MIME-style header mapping +// keys to sets of values. +type MIMEHeader map[string][]string + +// Add adds the key, value pair to the header. +// It appends to any existing values associated with key. +func (h MIMEHeader) Add(key, value string) { + key = CanonicalMIMEHeaderKey(key) + h[key] = append(h[key], value) +} + +// Set sets the header entries associated with key to +// the single element value. It replaces any existing +// values associated with key. +func (h MIMEHeader) Set(key, value string) { + h[CanonicalMIMEHeaderKey(key)] = []string{value} +} + +// Get gets the first value associated with the given key. +// It is case insensitive; CanonicalMIMEHeaderKey is used +// to canonicalize the provided key. +// If there are no values associated with the key, Get returns "". +// To use non-canonical keys, access the map directly. +func (h MIMEHeader) Get(key string) string { + if h == nil { + return "" + } + v := h[CanonicalMIMEHeaderKey(key)] + if len(v) == 0 { + return "" + } + return v[0] +} + +// Values returns all values associated with the given key. +// It is case insensitive; CanonicalMIMEHeaderKey is +// used to canonicalize the provided key. To use non-canonical +// keys, access the map directly. +// The returned slice is not a copy. +func (h MIMEHeader) Values(key string) []string { + if h == nil { + return nil + } + return h[CanonicalMIMEHeaderKey(key)] +} + +// Del deletes the values associated with key. +func (h MIMEHeader) Del(key string) { + delete(h, CanonicalMIMEHeaderKey(key)) +} diff --git a/textproto/header_test.go b/textproto/header_test.go new file mode 100644 index 00000000..de9405ca --- /dev/null +++ b/textproto/header_test.go @@ -0,0 +1,54 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package textproto + +import "testing" + +type canonicalHeaderKeyTest struct { + in, out string +} + +var canonicalHeaderKeyTests = []canonicalHeaderKeyTest{ + {"a-b-c", "A-B-C"}, + {"a-1-c", "A-1-C"}, + {"User-Agent", "User-Agent"}, + {"uSER-aGENT", "User-Agent"}, + {"user-agent", "User-Agent"}, + {"USER-AGENT", "User-Agent"}, + + // Other valid tchar bytes in tokens: + {"foo-bar_baz", "Foo-Bar_baz"}, + {"foo-bar$baz", "Foo-Bar$baz"}, + {"foo-bar~baz", "Foo-Bar~baz"}, + {"foo-bar*baz", "Foo-Bar*baz"}, + + // Non-ASCII or anything with spaces or non-token chars is unchanged: + {"üser-agenT", "üser-agenT"}, + {"a B", "a B"}, + + // This caused a panic due to mishandling of a space: + {"C Ontent-Transfer-Encoding", "C Ontent-Transfer-Encoding"}, + {"foo bar", "foo bar"}, +} + +func TestCanonicalMIMEHeaderKey(t *testing.T) { + for _, tt := range canonicalHeaderKeyTests { + if s := CanonicalMIMEHeaderKey(tt.in); s != tt.out { + t.Errorf("CanonicalMIMEHeaderKey(%q) = %q, want %q", tt.in, s, tt.out) + } + } +} + +// Issue #34799 add a Header method to get multiple values []string, with canonicalized key +func TestMIMEHeaderMultipleValues(t *testing.T) { + testHeader := MIMEHeader{ + "Set-Cookie": {"cookie 1", "cookie 2"}, + } + values := testHeader.Values("set-cookie") + n := len(values) + if n != 2 { + t.Errorf("count: %d; want 2", n) + } +} diff --git a/textproto/pipeline.go b/textproto/pipeline.go new file mode 100644 index 00000000..1928a306 --- /dev/null +++ b/textproto/pipeline.go @@ -0,0 +1,118 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package textproto + +import ( + "sync" +) + +// A Pipeline manages a pipelined in-order request/response sequence. +// +// To use a Pipeline p to manage multiple clients on a connection, +// each client should run: +// +// id := p.Next() // take a number +// +// p.StartRequest(id) // wait for turn to send request +// «send request» +// p.EndRequest(id) // notify Pipeline that request is sent +// +// p.StartResponse(id) // wait for turn to read response +// «read response» +// p.EndResponse(id) // notify Pipeline that response is read +// +// A pipelined server can use the same calls to ensure that +// responses computed in parallel are written in the correct order. +type Pipeline struct { + mu sync.Mutex + id uint + request sequencer + response sequencer +} + +// Next returns the next id for a request/response pair. +func (p *Pipeline) Next() uint { + p.mu.Lock() + id := p.id + p.id++ + p.mu.Unlock() + return id +} + +// StartRequest blocks until it is time to send (or, if this is a server, receive) +// the request with the given id. +func (p *Pipeline) StartRequest(id uint) { + p.request.Start(id) +} + +// EndRequest notifies p that the request with the given id has been sent +// (or, if this is a server, received). +func (p *Pipeline) EndRequest(id uint) { + p.request.End(id) +} + +// StartResponse blocks until it is time to receive (or, if this is a server, send) +// the request with the given id. +func (p *Pipeline) StartResponse(id uint) { + p.response.Start(id) +} + +// EndResponse notifies p that the response with the given id has been received +// (or, if this is a server, sent). +func (p *Pipeline) EndResponse(id uint) { + p.response.End(id) +} + +// A sequencer schedules a sequence of numbered events that must +// happen in order, one after the other. The event numbering must start +// at 0 and increment without skipping. The event number wraps around +// safely as long as there are not 2^32 simultaneous events pending. +type sequencer struct { + mu sync.Mutex + id uint + wait map[uint]chan struct{} +} + +// Start waits until it is time for the event numbered id to begin. +// That is, except for the first event, it waits until End(id-1) has +// been called. +func (s *sequencer) Start(id uint) { + s.mu.Lock() + if s.id == id { + s.mu.Unlock() + return + } + c := make(chan struct{}) + if s.wait == nil { + s.wait = make(map[uint]chan struct{}) + } + s.wait[id] = c + s.mu.Unlock() + <-c +} + +// End notifies the sequencer that the event numbered id has completed, +// allowing it to schedule the event numbered id+1. It is a run-time error +// to call End with an id that is not the number of the active event. +func (s *sequencer) End(id uint) { + s.mu.Lock() + if s.id != id { + s.mu.Unlock() + panic("out of sync") + } + id++ + s.id = id + if s.wait == nil { + s.wait = make(map[uint]chan struct{}) + } + c, ok := s.wait[id] + if ok { + delete(s.wait, id) + } + s.mu.Unlock() + if ok { + close(c) + } +} diff --git a/textproto/reader.go b/textproto/reader.go new file mode 100644 index 00000000..fc2590b1 --- /dev/null +++ b/textproto/reader.go @@ -0,0 +1,822 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package textproto + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" + "math" + "strconv" + "strings" + "sync" +) + +// A Reader implements convenience methods for reading requests +// or responses from a text protocol network connection. +type Reader struct { + R *bufio.Reader + dot *dotReader + buf []byte // a re-usable buffer for readContinuedLineSlice +} + +// NewReader returns a new Reader reading from r. +// +// To avoid denial of service attacks, the provided bufio.Reader +// should be reading from an io.LimitReader or similar Reader to bound +// the size of responses. +func NewReader(r *bufio.Reader) *Reader { + return &Reader{R: r} +} + +// ReadLine reads a single line from r, +// eliding the final \n or \r\n from the returned string. +func (r *Reader) ReadLine() (string, error) { + line, err := r.readLineSlice() + return string(line), err +} + +// ReadLineBytes is like ReadLine but returns a []byte instead of a string. +func (r *Reader) ReadLineBytes() ([]byte, error) { + line, err := r.readLineSlice() + if line != nil { + line = bytes.Clone(line) + } + return line, err +} + +func (r *Reader) readLineSlice() ([]byte, error) { + r.closeDot() + var line []byte + for { + l, more, err := r.R.ReadLine() + if err != nil { + return nil, err + } + // Avoid the copy if the first call produced a full line. + if line == nil && !more { + return l, nil + } + line = append(line, l...) + if !more { + break + } + } + return line, nil +} + +// ReadContinuedLine reads a possibly continued line from r, +// eliding the final trailing ASCII white space. +// Lines after the first are considered continuations if they +// begin with a space or tab character. In the returned data, +// continuation lines are separated from the previous line +// only by a single space: the newline and leading white space +// are removed. +// +// For example, consider this input: +// +// Line 1 +// continued... +// Line 2 +// +// The first call to ReadContinuedLine will return "Line 1 continued..." +// and the second will return "Line 2". +// +// Empty lines are never continued. +func (r *Reader) ReadContinuedLine() (string, error) { + line, err := r.readContinuedLineSlice(noValidation) + return string(line), err +} + +// trim returns s with leading and trailing spaces and tabs removed. +// It does not assume Unicode or UTF-8. +func trim(s []byte) []byte { + i := 0 + for i < len(s) && (s[i] == ' ' || s[i] == '\t') { + i++ + } + n := len(s) + for n > i && (s[n-1] == ' ' || s[n-1] == '\t') { + n-- + } + return s[i:n] +} + +// ReadContinuedLineBytes is like ReadContinuedLine but +// returns a []byte instead of a string. +func (r *Reader) ReadContinuedLineBytes() ([]byte, error) { + line, err := r.readContinuedLineSlice(noValidation) + if line != nil { + line = bytes.Clone(line) + } + return line, err +} + +// readContinuedLineSlice reads continued lines from the reader buffer, +// returning a byte slice with all lines. The validateFirstLine function +// is run on the first read line, and if it returns an error then this +// error is returned from readContinuedLineSlice. +func (r *Reader) readContinuedLineSlice(validateFirstLine func([]byte) error) ([]byte, error) { + if validateFirstLine == nil { + return nil, fmt.Errorf("missing validateFirstLine func") + } + + // Read the first line. + line, err := r.readLineSlice() + if err != nil { + return nil, err + } + if len(line) == 0 { // blank line - no continuation + return line, nil + } + + if err := validateFirstLine(line); err != nil { + return nil, err + } + + // Optimistically assume that we have started to buffer the next line + // and it starts with an ASCII letter (the next header key), or a blank + // line, so we can avoid copying that buffered data around in memory + // and skipping over non-existent whitespace. + if r.R.Buffered() > 1 { + peek, _ := r.R.Peek(2) + if len(peek) > 0 && (isASCIILetter(peek[0]) || peek[0] == '\n') || + len(peek) == 2 && peek[0] == '\r' && peek[1] == '\n' { + return trim(line), nil + } + } + + // ReadByte or the next readLineSlice will flush the read buffer; + // copy the slice into buf. + r.buf = append(r.buf[:0], trim(line)...) + + // Read continuation lines. + for r.skipSpace() > 0 { + line, err := r.readLineSlice() + if err != nil { + break + } + r.buf = append(r.buf, ' ') + r.buf = append(r.buf, trim(line)...) + } + return r.buf, nil +} + +// skipSpace skips R over all spaces and returns the number of bytes skipped. +func (r *Reader) skipSpace() int { + n := 0 + for { + c, err := r.R.ReadByte() + if err != nil { + // Bufio will keep err until next read. + break + } + if c != ' ' && c != '\t' { + r.R.UnreadByte() + break + } + n++ + } + return n +} + +func (r *Reader) readCodeLine(expectCode int) (code int, continued bool, message string, err error) { + line, err := r.ReadLine() + if err != nil { + return + } + return parseCodeLine(line, expectCode) +} + +func parseCodeLine(line string, expectCode int) (code int, continued bool, message string, err error) { + if len(line) < 4 || line[3] != ' ' && line[3] != '-' { + err = ProtocolError("short response: " + line) + return + } + continued = line[3] == '-' + code, err = strconv.Atoi(line[0:3]) + if err != nil || code < 100 { + err = ProtocolError("invalid response code: " + line) + return + } + message = line[4:] + if 1 <= expectCode && expectCode < 10 && code/100 != expectCode || + 10 <= expectCode && expectCode < 100 && code/10 != expectCode || + 100 <= expectCode && expectCode < 1000 && code != expectCode { + err = &Error{code, message} + } + return +} + +// ReadCodeLine reads a response code line of the form +// +// code message +// +// where code is a three-digit status code and the message +// extends to the rest of the line. An example of such a line is: +// +// 220 plan9.bell-labs.com ESMTP +// +// If the prefix of the status does not match the digits in expectCode, +// ReadCodeLine returns with err set to &Error{code, message}. +// For example, if expectCode is 31, an error will be returned if +// the status is not in the range [310,319]. +// +// If the response is multi-line, ReadCodeLine returns an error. +// +// An expectCode <= 0 disables the check of the status code. +func (r *Reader) ReadCodeLine(expectCode int) (code int, message string, err error) { + code, continued, message, err := r.readCodeLine(expectCode) + if err == nil && continued { + err = ProtocolError("unexpected multi-line response: " + message) + } + return +} + +// ReadResponse reads a multi-line response of the form: +// +// code-message line 1 +// code-message line 2 +// ... +// code message line n +// +// where code is a three-digit status code. The first line starts with the +// code and a hyphen. The response is terminated by a line that starts +// with the same code followed by a space. Each line in message is +// separated by a newline (\n). +// +// See page 36 of RFC 959 (https://www.ietf.org/rfc/rfc959.txt) for +// details of another form of response accepted: +// +// code-message line 1 +// message line 2 +// ... +// code message line n +// +// If the prefix of the status does not match the digits in expectCode, +// ReadResponse returns with err set to &Error{code, message}. +// For example, if expectCode is 31, an error will be returned if +// the status is not in the range [310,319]. +// +// An expectCode <= 0 disables the check of the status code. +func (r *Reader) ReadResponse(expectCode int) (code int, message string, err error) { + code, continued, message, err := r.readCodeLine(expectCode) + multi := continued + for continued { + line, err := r.ReadLine() + if err != nil { + return 0, "", err + } + + var code2 int + var moreMessage string + code2, continued, moreMessage, err = parseCodeLine(line, 0) + if err != nil || code2 != code { + message += "\n" + strings.TrimRight(line, "\r\n") + continued = true + continue + } + message += "\n" + moreMessage + } + if err != nil && multi && message != "" { + // replace one line error message with all lines (full message) + err = &Error{code, message} + } + return +} + +// DotReader returns a new Reader that satisfies Reads using the +// decoded text of a dot-encoded block read from r. +// The returned Reader is only valid until the next call +// to a method on r. +// +// Dot encoding is a common framing used for data blocks +// in text protocols such as SMTP. The data consists of a sequence +// of lines, each of which ends in "\r\n". The sequence itself +// ends at a line containing just a dot: ".\r\n". Lines beginning +// with a dot are escaped with an additional dot to avoid +// looking like the end of the sequence. +// +// The decoded form returned by the Reader's Read method +// rewrites the "\r\n" line endings into the simpler "\n", +// removes leading dot escapes if present, and stops with error io.EOF +// after consuming (and discarding) the end-of-sequence line. +func (r *Reader) DotReader() io.Reader { + r.closeDot() + r.dot = &dotReader{r: r} + return r.dot +} + +type dotReader struct { + r *Reader + state int +} + +// Read satisfies reads by decoding dot-encoded data read from d.r. +func (d *dotReader) Read(b []byte) (n int, err error) { + // Run data through a simple state machine to + // elide leading dots, rewrite trailing \r\n into \n, + // and detect ending .\r\n line. + const ( + stateBeginLine = iota // beginning of line; initial state; must be zero + stateDot // read . at beginning of line + stateDotCR // read .\r at beginning of line + stateCR // read \r (possibly at end of line) + stateData // reading data in middle of line + stateEOF // reached .\r\n end marker line + ) + br := d.r.R + for n < len(b) && d.state != stateEOF { + var c byte + c, err = br.ReadByte() + if err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + break + } + switch d.state { + case stateBeginLine: + if c == '.' { + d.state = stateDot + continue + } + if c == '\r' { + d.state = stateCR + continue + } + d.state = stateData + + case stateDot: + if c == '\r' { + d.state = stateDotCR + continue + } + if c == '\n' { + d.state = stateEOF + continue + } + d.state = stateData + + case stateDotCR: + if c == '\n' { + d.state = stateEOF + continue + } + // Not part of .\r\n. + // Consume leading dot and emit saved \r. + br.UnreadByte() + c = '\r' + d.state = stateData + + case stateCR: + if c == '\n' { + d.state = stateBeginLine + break + } + // Not part of \r\n. Emit saved \r + br.UnreadByte() + c = '\r' + d.state = stateData + + case stateData: + if c == '\r' { + d.state = stateCR + continue + } + if c == '\n' { + d.state = stateBeginLine + } + } + b[n] = c + n++ + } + if err == nil && d.state == stateEOF { + err = io.EOF + } + if err != nil && d.r.dot == d { + d.r.dot = nil + } + return +} + +// closeDot drains the current DotReader if any, +// making sure that it reads until the ending dot line. +func (r *Reader) closeDot() { + if r.dot == nil { + return + } + buf := make([]byte, 128) + for r.dot != nil { + // When Read reaches EOF or an error, + // it will set r.dot == nil. + r.dot.Read(buf) + } +} + +// ReadDotBytes reads a dot-encoding and returns the decoded data. +// +// See the documentation for the DotReader method for details about dot-encoding. +func (r *Reader) ReadDotBytes() ([]byte, error) { + return io.ReadAll(r.DotReader()) +} + +// ReadDotLines reads a dot-encoding and returns a slice +// containing the decoded lines, with the final \r\n or \n elided from each. +// +// See the documentation for the DotReader method for details about dot-encoding. +func (r *Reader) ReadDotLines() ([]string, error) { + // We could use ReadDotBytes and then Split it, + // but reading a line at a time avoids needing a + // large contiguous block of memory and is simpler. + var v []string + var err error + for { + var line string + line, err = r.ReadLine() + if err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + break + } + + // Dot by itself marks end; otherwise cut one dot. + if len(line) > 0 && line[0] == '.' { + if len(line) == 1 { + break + } + line = line[1:] + } + v = append(v, line) + } + return v, err +} + +var colon = []byte(":") + +// ReadMIMEHeader reads a MIME-style header from r. +// The header is a sequence of possibly continued Key: Value lines +// ending in a blank line. +// The returned map m maps CanonicalMIMEHeaderKey(key) to a +// sequence of values in the same order encountered in the input. +// +// For example, consider this input: +// +// My-Key: Value 1 +// Long-Key: Even +// Longer Value +// My-Key: Value 2 +// +// Given that input, ReadMIMEHeader returns the map: +// +// map[string][]string{ +// "My-Key": {"Value 1", "Value 2"}, +// "Long-Key": {"Even Longer Value"}, +// } +func (r *Reader) ReadMIMEHeader() (MIMEHeader, error) { + return readMIMEHeader(r, math.MaxInt64, math.MaxInt64) +} + +// readMIMEHeader is a version of ReadMIMEHeader which takes a limit on the header size. +// It is called by the mime/multipart package. +func readMIMEHeader(r *Reader, maxMemory, maxHeaders int64) (MIMEHeader, error) { + // Avoid lots of small slice allocations later by allocating one + // large one ahead of time which we'll cut up into smaller + // slices. If this isn't big enough later, we allocate small ones. + var strs []string + hint := r.upcomingHeaderKeys() + if hint > 0 { + if hint > 1000 { + hint = 1000 // set a cap to avoid overallocation + } + strs = make([]string, hint) + } + + m := make(MIMEHeader, hint) + + // Account for 400 bytes of overhead for the MIMEHeader, plus 200 bytes per entry. + // Benchmarking map creation as of go1.20, a one-entry MIMEHeader is 416 bytes and large + // MIMEHeaders average about 200 bytes per entry. + maxMemory -= 400 + const mapEntryOverhead = 200 + + // The first line cannot start with a leading space. + if buf, err := r.R.Peek(1); err == nil && (buf[0] == ' ' || buf[0] == '\t') { + line, err := r.readLineSlice() + if err != nil { + return m, err + } + return m, ProtocolError("malformed MIME header initial line: " + string(line)) + } + + for { + kv, err := r.readContinuedLineSlice(mustHaveFieldNameColon) + if len(kv) == 0 { + return m, err + } + + // Key ends at first colon. + k, v, ok := bytes.Cut(kv, colon) + if !ok { + return m, ProtocolError("malformed MIME header line: " + string(kv)) + } + key, ok := canonicalMIMEHeaderKey(k) + if !ok { + return m, ProtocolError("malformed MIME header line: " + string(kv)) + } + for _, c := range v { + if !validHeaderValueByte(c) { + return m, ProtocolError("malformed MIME header line: " + string(kv)) + } + } + + // As per RFC 7230 field-name is a token, tokens consist of one or more chars. + // We could return a ProtocolError here, but better to be liberal in what we + // accept, so if we get an empty key, skip it. + if key == "" { + continue + } + + maxHeaders-- + if maxHeaders < 0 { + return nil, errors.New("message too large") + } + + // Skip initial spaces in value. + value := string(bytes.TrimLeft(v, " \t")) + + vv := m[key] + if vv == nil { + maxMemory -= int64(len(key)) + maxMemory -= mapEntryOverhead + } + maxMemory -= int64(len(value)) + if maxMemory < 0 { + // TODO: This should be a distinguishable error (ErrMessageTooLarge) + // to allow mime/multipart to detect it. + return m, errors.New("message too large") + } + if vv == nil && len(strs) > 0 { + // More than likely this will be a single-element key. + // Most headers aren't multi-valued. + // Set the capacity on strs[0] to 1, so any future append + // won't extend the slice into the other strings. + vv, strs = strs[:1:1], strs[1:] + vv[0] = value + m[key] = vv + } else { + m[key] = append(vv, value) + } + + if err != nil { + return m, err + } + } +} + +// noValidation is a no-op validation func for readContinuedLineSlice +// that permits any lines. +func noValidation(_ []byte) error { return nil } + +// mustHaveFieldNameColon ensures that, per RFC 7230, the +// field-name is on a single line, so the first line must +// contain a colon. +func mustHaveFieldNameColon(line []byte) error { + if bytes.IndexByte(line, ':') < 0 { + return ProtocolError(fmt.Sprintf("malformed MIME header: missing colon: %q", line)) + } + return nil +} + +var nl = []byte("\n") + +// upcomingHeaderKeys returns an approximation of the number of keys +// that will be in this header. If it gets confused, it returns 0. +func (r *Reader) upcomingHeaderKeys() (n int) { + // Try to determine the 'hint' size. + r.R.Peek(1) // force a buffer load if empty + s := r.R.Buffered() + if s == 0 { + return + } + peek, _ := r.R.Peek(s) + for len(peek) > 0 && n < 1000 { + var line []byte + line, peek, _ = bytes.Cut(peek, nl) + if len(line) == 0 || (len(line) == 1 && line[0] == '\r') { + // Blank line separating headers from the body. + break + } + if line[0] == ' ' || line[0] == '\t' { + // Folded continuation of the previous line. + continue + } + n++ + } + return n +} + +// CanonicalMIMEHeaderKey returns the canonical format of the +// MIME header key s. The canonicalization converts the first +// letter and any letter following a hyphen to upper case; +// the rest are converted to lowercase. For example, the +// canonical key for "accept-encoding" is "Accept-Encoding". +// MIME header keys are assumed to be ASCII only. +// If s contains a space or invalid header field bytes, it is +// returned without modifications. +func CanonicalMIMEHeaderKey(s string) string { + // Quick check for canonical encoding. + upper := true + for i := 0; i < len(s); i++ { + c := s[i] + if !validHeaderFieldByte(c) { + return s + } + if upper && 'a' <= c && c <= 'z' { + s, _ = canonicalMIMEHeaderKey([]byte(s)) + return s + } + if !upper && 'A' <= c && c <= 'Z' { + s, _ = canonicalMIMEHeaderKey([]byte(s)) + return s + } + upper = c == '-' + } + return s +} + +const toLower = 'a' - 'A' + +// validHeaderFieldByte reports whether c is a valid byte in a header +// field name. RFC 7230 says: +// +// header-field = field-name ":" OWS field-value OWS +// field-name = token +// tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." / +// "^" / "_" / "`" / "|" / "~" / DIGIT / ALPHA +// token = 1*tchar +func validHeaderFieldByte(c byte) bool { + // mask is a 128-bit bitmap with 1s for allowed bytes, + // so that the byte c can be tested with a shift and an and. + // If c >= 128, then 1<>64)) != 0 +} + +// validHeaderValueByte reports whether c is a valid byte in a header +// field value. RFC 7230 says: +// +// field-content = field-vchar [ 1*( SP / HTAB ) field-vchar ] +// field-vchar = VCHAR / obs-text +// obs-text = %x80-FF +// +// RFC 5234 says: +// +// HTAB = %x09 +// SP = %x20 +// VCHAR = %x21-7E +func validHeaderValueByte(c byte) bool { + // mask is a 128-bit bitmap with 1s for allowed bytes, + // so that the byte c can be tested with a shift and an and. + // If c >= 128, then 1<>64)) == 0 +} + +// canonicalMIMEHeaderKey is like CanonicalMIMEHeaderKey but is +// allowed to mutate the provided byte slice before returning the +// string. +// +// For invalid inputs (if a contains spaces or non-token bytes), a +// is unchanged and a string copy is returned. +// +// ok is true if the header key contains only valid characters and spaces. +// ReadMIMEHeader accepts header keys containing spaces, but does not +// canonicalize them. +func canonicalMIMEHeaderKey(a []byte) (_ string, ok bool) { + // See if a looks like a header key. If not, return it unchanged. + noCanon := false + for _, c := range a { + if validHeaderFieldByte(c) { + continue + } + // Don't canonicalize. + if c == ' ' { + // We accept invalid headers with a space before the + // colon, but must not canonicalize them. + // See https://go.dev/issue/34540. + noCanon = true + continue + } + return string(a), false + } + if noCanon { + return string(a), true + } + + upper := true + for i, c := range a { + // Canonicalize: first letter upper case + // and upper case after each dash. + // (Host, User-Agent, If-Modified-Since). + // MIME headers are ASCII only, so no Unicode issues. + if upper && 'a' <= c && c <= 'z' { + c -= toLower + } else if !upper && 'A' <= c && c <= 'Z' { + c += toLower + } + a[i] = c + upper = c == '-' // for next time + } + commonHeaderOnce.Do(initCommonHeader) + // The compiler recognizes m[string(byteSlice)] as a special + // case, so a copy of a's bytes into a new string does not + // happen in this map lookup: + if v := commonHeader[string(a)]; v != "" { + return v, true + } + return string(a), true +} + +// commonHeader interns common header strings. +var commonHeader map[string]string + +var commonHeaderOnce sync.Once + +func initCommonHeader() { + commonHeader = make(map[string]string) + for _, v := range []string{ + "Accept", + "Accept-Charset", + "Accept-Encoding", + "Accept-Language", + "Accept-Ranges", + "Cache-Control", + "Cc", + "Connection", + "Content-Id", + "Content-Language", + "Content-Length", + "Content-Transfer-Encoding", + "Content-Type", + "Cookie", + "Date", + "Dkim-Signature", + "Etag", + "Expires", + "From", + "Host", + "If-Modified-Since", + "If-None-Match", + "In-Reply-To", + "Last-Modified", + "Location", + "Message-Id", + "Mime-Version", + "Pragma", + "Received", + "Return-Path", + "Server", + "Set-Cookie", + "Subject", + "To", + "User-Agent", + "Via", + "X-Forwarded-For", + "X-Imforwards", + "X-Powered-By", + } { + commonHeader[v] = v + } +} diff --git a/textproto/reader_test.go b/textproto/reader_test.go new file mode 100644 index 00000000..696ae406 --- /dev/null +++ b/textproto/reader_test.go @@ -0,0 +1,525 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package textproto + +import ( + "bufio" + "bytes" + "io" + "net" + "reflect" + "runtime" + "strings" + "sync" + "testing" +) + +func reader(s string) *Reader { + return NewReader(bufio.NewReader(strings.NewReader(s))) +} + +func TestReadLine(t *testing.T) { + r := reader("line1\nline2\n") + s, err := r.ReadLine() + if s != "line1" || err != nil { + t.Fatalf("Line 1: %s, %v", s, err) + } + s, err = r.ReadLine() + if s != "line2" || err != nil { + t.Fatalf("Line 2: %s, %v", s, err) + } + s, err = r.ReadLine() + if s != "" || err != io.EOF { + t.Fatalf("EOF: %s, %v", s, err) + } +} + +func TestReadContinuedLine(t *testing.T) { + r := reader("line1\nline\n 2\nline3\n") + s, err := r.ReadContinuedLine() + if s != "line1" || err != nil { + t.Fatalf("Line 1: %s, %v", s, err) + } + s, err = r.ReadContinuedLine() + if s != "line 2" || err != nil { + t.Fatalf("Line 2: %s, %v", s, err) + } + s, err = r.ReadContinuedLine() + if s != "line3" || err != nil { + t.Fatalf("Line 3: %s, %v", s, err) + } + s, err = r.ReadContinuedLine() + if s != "" || err != io.EOF { + t.Fatalf("EOF: %s, %v", s, err) + } +} + +func TestReadCodeLine(t *testing.T) { + r := reader("123 hi\n234 bye\n345 no way\n") + code, msg, err := r.ReadCodeLine(0) + if code != 123 || msg != "hi" || err != nil { + t.Fatalf("Line 1: %d, %s, %v", code, msg, err) + } + code, msg, err = r.ReadCodeLine(23) + if code != 234 || msg != "bye" || err != nil { + t.Fatalf("Line 2: %d, %s, %v", code, msg, err) + } + code, msg, err = r.ReadCodeLine(346) + if code != 345 || msg != "no way" || err == nil { + t.Fatalf("Line 3: %d, %s, %v", code, msg, err) + } + if e, ok := err.(*Error); !ok || e.Code != code || e.Msg != msg { + t.Fatalf("Line 3: wrong error %v\n", err) + } + code, msg, err = r.ReadCodeLine(1) + if code != 0 || msg != "" || err != io.EOF { + t.Fatalf("EOF: %d, %s, %v", code, msg, err) + } +} + +func TestReadDotLines(t *testing.T) { + r := reader("dotlines\r\n.foo\r\n..bar\n...baz\nquux\r\n\r\n.\r\nanother\n") + s, err := r.ReadDotLines() + want := []string{"dotlines", "foo", ".bar", "..baz", "quux", ""} + if !reflect.DeepEqual(s, want) || err != nil { + t.Fatalf("ReadDotLines: %v, %v", s, err) + } + + s, err = r.ReadDotLines() + want = []string{"another"} + if !reflect.DeepEqual(s, want) || err != io.ErrUnexpectedEOF { + t.Fatalf("ReadDotLines2: %v, %v", s, err) + } +} + +func TestReadDotBytes(t *testing.T) { + r := reader("dotlines\r\n.foo\r\n..bar\n...baz\nquux\r\n\r\n.\r\nanot.her\r\n") + b, err := r.ReadDotBytes() + want := []byte("dotlines\nfoo\n.bar\n..baz\nquux\n\n") + if !reflect.DeepEqual(b, want) || err != nil { + t.Fatalf("ReadDotBytes: %q, %v", b, err) + } + + b, err = r.ReadDotBytes() + want = []byte("anot.her\n") + if !reflect.DeepEqual(b, want) || err != io.ErrUnexpectedEOF { + t.Fatalf("ReadDotBytes2: %q, %v", b, err) + } +} + +func TestReadMIMEHeader(t *testing.T) { + r := reader("my-key: Value 1 \r\nLong-key: Even \n Longer Value\r\nmy-Key: Value 2\r\n\n") + m, err := r.ReadMIMEHeader() + want := MIMEHeader{ + "My-Key": {"Value 1", "Value 2"}, + "Long-Key": {"Even Longer Value"}, + } + if !reflect.DeepEqual(m, want) || err != nil { + t.Fatalf("ReadMIMEHeader: %v, %v; want %v", m, err, want) + } +} + +func TestReadMIMEHeaderSingle(t *testing.T) { + r := reader("Foo: bar\n\n") + m, err := r.ReadMIMEHeader() + want := MIMEHeader{"Foo": {"bar"}} + if !reflect.DeepEqual(m, want) || err != nil { + t.Fatalf("ReadMIMEHeader: %v, %v; want %v", m, err, want) + } +} + +// TestReaderUpcomingHeaderKeys is testing an internal function, but it's very +// difficult to test well via the external API. +func TestReaderUpcomingHeaderKeys(t *testing.T) { + for _, test := range []struct { + input string + want int + }{{ + input: "", + want: 0, + }, { + input: "A: v", + want: 1, + }, { + input: "A: v\r\nB: v\r\n", + want: 2, + }, { + input: "A: v\nB: v\n", + want: 2, + }, { + input: "A: v\r\n continued\r\n still continued\r\nB: v\r\n\r\n", + want: 2, + }, { + input: "A: v\r\n\r\nB: v\r\nC: v\r\n", + want: 1, + }, { + input: "A: v" + strings.Repeat("\n", 1000), + want: 1, + }} { + r := reader(test.input) + got := r.upcomingHeaderKeys() + if test.want != got { + t.Fatalf("upcomingHeaderKeys(%q): %v; want %v", test.input, got, test.want) + } + } +} + +func TestReadMIMEHeaderNoKey(t *testing.T) { + r := reader(": bar\ntest-1: 1\n\n") + m, err := r.ReadMIMEHeader() + want := MIMEHeader{"Test-1": {"1"}} + if !reflect.DeepEqual(m, want) || err != nil { + t.Fatalf("ReadMIMEHeader: %v, %v; want %v", m, err, want) + } +} + +func TestLargeReadMIMEHeader(t *testing.T) { + data := make([]byte, 16*1024) + for i := 0; i < len(data); i++ { + data[i] = 'x' + } + sdata := string(data) + r := reader("Cookie: " + sdata + "\r\n\n") + m, err := r.ReadMIMEHeader() + if err != nil { + t.Fatalf("ReadMIMEHeader: %v", err) + } + cookie := m.Get("Cookie") + if cookie != sdata { + t.Fatalf("ReadMIMEHeader: %v bytes, want %v bytes", len(cookie), len(sdata)) + } +} + +// TestReadMIMEHeaderNonCompliant checks that we don't normalize headers +// with spaces before colons, and accept spaces in keys. +func TestReadMIMEHeaderNonCompliant(t *testing.T) { + // These invalid headers will be rejected by net/http according to RFC 7230. + r := reader("Foo: bar\r\n" + + "Content-Language: en\r\n" + + "SID : 0\r\n" + + "Audio Mode : None\r\n" + + "Privilege : 127\r\n\r\n") + m, err := r.ReadMIMEHeader() + want := MIMEHeader{ + "Foo": {"bar"}, + "Content-Language": {"en"}, + "SID ": {"0"}, + "Audio Mode ": {"None"}, + "Privilege ": {"127"}, + } + if !reflect.DeepEqual(m, want) || err != nil { + t.Fatalf("ReadMIMEHeader =\n%v, %v; want:\n%v", m, err, want) + } +} + +func TestReadMIMEHeaderMalformed(t *testing.T) { + inputs := []string{ + "No colon first line\r\nFoo: foo\r\n\r\n", + " No colon first line with leading space\r\nFoo: foo\r\n\r\n", + "\tNo colon first line with leading tab\r\nFoo: foo\r\n\r\n", + " First: line with leading space\r\nFoo: foo\r\n\r\n", + "\tFirst: line with leading tab\r\nFoo: foo\r\n\r\n", + "Foo: foo\r\nNo colon second line\r\n\r\n", + "Foo-\n\tBar: foo\r\n\r\n", + "Foo-\r\n\tBar: foo\r\n\r\n", + "Foo\r\n\t: foo\r\n\r\n", + "Foo-\n\tBar", + "Foo \tBar: foo\r\n\r\n", + } + for _, input := range inputs { + r := reader(input) + if m, err := r.ReadMIMEHeader(); err == nil || err == io.EOF { + t.Errorf("ReadMIMEHeader(%q) = %v, %v; want nil, err", input, m, err) + } + } +} + +func TestReadMIMEHeaderBytes(t *testing.T) { + for i := 0; i <= 0xff; i++ { + s := "Foo" + string(rune(i)) + "Bar: foo\r\n\r\n" + r := reader(s) + wantErr := true + switch { + case i >= '0' && i <= '9': + wantErr = false + case i >= 'a' && i <= 'z': + wantErr = false + case i >= 'A' && i <= 'Z': + wantErr = false + case i == '!' || i == '#' || i == '$' || i == '%' || i == '&' || i == '\'' || i == '*' || i == '+' || i == '-' || i == '.' || i == '^' || i == '_' || i == '`' || i == '|' || i == '~': + wantErr = false + case i == ':': + // Special case: "Foo:Bar: foo" is the header "Foo". + wantErr = false + case i == ' ': + wantErr = false + } + m, err := r.ReadMIMEHeader() + if err != nil != wantErr { + t.Errorf("ReadMIMEHeader(%q) = %v, %v; want error=%v", s, m, err, wantErr) + } + } + for i := 0; i <= 0xff; i++ { + s := "Foo: foo" + string(rune(i)) + "bar\r\n\r\n" + r := reader(s) + wantErr := true + switch { + case i >= 0x21 && i <= 0x7e: + wantErr = false + case i == ' ': + wantErr = false + case i == '\t': + wantErr = false + case i >= 0x80 && i <= 0xff: + wantErr = false + } + m, err := r.ReadMIMEHeader() + if (err != nil) != wantErr { + t.Errorf("ReadMIMEHeader(%q) = %v, %v; want error=%v", s, m, err, wantErr) + } + } +} + +// Test that continued lines are properly trimmed. Issue 11204. +func TestReadMIMEHeaderTrimContinued(t *testing.T) { + // In this header, \n and \r\n terminated lines are mixed on purpose. + // We expect each line to be trimmed (prefix and suffix) before being concatenated. + // Keep the spaces as they are. + r := reader("" + // for code formatting purpose. + "a:\n" + + " 0 \r\n" + + "b:1 \t\r\n" + + "c: 2\r\n" + + " 3\t\n" + + " \t 4 \r\n\n") + m, err := r.ReadMIMEHeader() + if err != nil { + t.Fatal(err) + } + want := MIMEHeader{ + "A": {"0"}, + "B": {"1"}, + "C": {"2 3 4"}, + } + if !reflect.DeepEqual(m, want) { + t.Fatalf("ReadMIMEHeader mismatch.\n got: %q\nwant: %q", m, want) + } +} + +// Test that reading a header doesn't overallocate. Issue 58975. +func TestReadMIMEHeaderAllocations(t *testing.T) { + var totalAlloc uint64 + const count = 200 + for i := 0; i < count; i++ { + r := reader("A: b\r\n\r\n" + strings.Repeat("\n", 4096)) + var m1, m2 runtime.MemStats + runtime.ReadMemStats(&m1) + _, err := r.ReadMIMEHeader() + if err != nil { + t.Fatalf("ReadMIMEHeader: %v", err) + } + runtime.ReadMemStats(&m2) + totalAlloc += m2.TotalAlloc - m1.TotalAlloc + } + // 32k is large and we actually allocate substantially less, + // but prior to the fix for #58975 we allocated ~400k in this case. + if got, want := totalAlloc/count, uint64(32768); got > want { + t.Fatalf("ReadMIMEHeader allocated %v bytes, want < %v", got, want) + } +} + +type readResponseTest struct { + in string + inCode int + wantCode int + wantMsg string +} + +var readResponseTests = []readResponseTest{ + {"230-Anonymous access granted, restrictions apply\n" + + "Read the file README.txt,\n" + + "230 please", + 23, + 230, + "Anonymous access granted, restrictions apply\nRead the file README.txt,\n please", + }, + + {"230 Anonymous access granted, restrictions apply\n", + 23, + 230, + "Anonymous access granted, restrictions apply", + }, + + {"400-A\n400-B\n400 C", + 4, + 400, + "A\nB\nC", + }, + + {"400-A\r\n400-B\r\n400 C\r\n", + 4, + 400, + "A\nB\nC", + }, +} + +// See https://www.ietf.org/rfc/rfc959.txt page 36. +func TestRFC959Lines(t *testing.T) { + for i, tt := range readResponseTests { + r := reader(tt.in + "\nFOLLOWING DATA") + code, msg, err := r.ReadResponse(tt.inCode) + if err != nil { + t.Errorf("#%d: ReadResponse: %v", i, err) + continue + } + if code != tt.wantCode { + t.Errorf("#%d: code=%d, want %d", i, code, tt.wantCode) + } + if msg != tt.wantMsg { + t.Errorf("#%d: msg=%q, want %q", i, msg, tt.wantMsg) + } + } +} + +// Test that multi-line errors are appropriately and fully read. Issue 10230. +func TestReadMultiLineError(t *testing.T) { + r := reader("550-5.1.1 The email account that you tried to reach does not exist. Please try\n" + + "550-5.1.1 double-checking the recipient's email address for typos or\n" + + "550-5.1.1 unnecessary spaces. Learn more at\n" + + "Unexpected but legal text!\n" + + "550 5.1.1 https://support.google.com/mail/answer/6596 h20si25154304pfd.166 - gsmtp\n") + + wantMsg := "5.1.1 The email account that you tried to reach does not exist. Please try\n" + + "5.1.1 double-checking the recipient's email address for typos or\n" + + "5.1.1 unnecessary spaces. Learn more at\n" + + "Unexpected but legal text!\n" + + "5.1.1 https://support.google.com/mail/answer/6596 h20si25154304pfd.166 - gsmtp" + + code, msg, err := r.ReadResponse(250) + if err == nil { + t.Errorf("ReadResponse: no error, want error") + } + if code != 550 { + t.Errorf("ReadResponse: code=%d, want %d", code, 550) + } + if msg != wantMsg { + t.Errorf("ReadResponse: msg=%q, want %q", msg, wantMsg) + } + if err != nil && err.Error() != "550 "+wantMsg { + t.Errorf("ReadResponse: error=%q, want %q", err.Error(), "550 "+wantMsg) + } +} + +func TestCommonHeaders(t *testing.T) { + commonHeaderOnce.Do(initCommonHeader) + for h := range commonHeader { + if h != CanonicalMIMEHeaderKey(h) { + t.Errorf("Non-canonical header %q in commonHeader", h) + } + } + b := []byte("content-Length") + want := "Content-Length" + n := testing.AllocsPerRun(200, func() { + if x, _ := canonicalMIMEHeaderKey(b); x != want { + t.Fatalf("canonicalMIMEHeaderKey(%q) = %q; want %q", b, x, want) + } + }) + if n > 0 { + t.Errorf("canonicalMIMEHeaderKey allocs = %v; want 0", n) + } +} + +func TestIssue46363(t *testing.T) { + // Regression test for data race reported in issue 46363: + // ReadMIMEHeader reads commonHeader before commonHeader has been initialized. + // Run this test with the race detector enabled to catch the reported data race. + + // Reset commonHeaderOnce, so that commonHeader will have to be initialized + commonHeaderOnce = sync.Once{} + commonHeader = nil + + // Test for data race by calling ReadMIMEHeader and CanonicalMIMEHeaderKey concurrently + + // Send MIME header over net.Conn + r, w := net.Pipe() + go func() { + // ReadMIMEHeader calls canonicalMIMEHeaderKey, which reads from commonHeader + NewConn(r).ReadMIMEHeader() + }() + w.Write([]byte("A: 1\r\nB: 2\r\nC: 3\r\n\r\n")) + + // CanonicalMIMEHeaderKey calls commonHeaderOnce.Do(initCommonHeader) which initializes commonHeader + CanonicalMIMEHeaderKey("a") + + if commonHeader == nil { + t.Fatal("CanonicalMIMEHeaderKey should initialize commonHeader") + } +} + +var clientHeaders = strings.Replace(`Host: golang.org +Connection: keep-alive +Cache-Control: max-age=0 +Accept: application/xml,application/xhtml+xml,text/html;q=0.9,text/plain;q=0.8,image/png,*/*;q=0.5 +User-Agent: Mozilla/5.0 (X11; U; Linux x86_64; en-US) AppleWebKit/534.3 (KHTML, like Gecko) Chrome/6.0.472.63 Safari/534.3 +Accept-Encoding: gzip,deflate,sdch +Accept-Language: en-US,en;q=0.8,fr-CH;q=0.6 +Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3 +COOKIE: __utma=000000000.0000000000.0000000000.0000000000.0000000000.00; __utmb=000000000.0.00.0000000000; __utmc=000000000; __utmz=000000000.0000000000.00.0.utmcsr=code.google.com|utmccn=(referral)|utmcmd=referral|utmcct=/p/go/issues/detail +Non-Interned: test + +`, "\n", "\r\n", -1) + +var serverHeaders = strings.Replace(`Content-Type: text/html; charset=utf-8 +Content-Encoding: gzip +Date: Thu, 27 Sep 2012 09:03:33 GMT +Server: Google Frontend +Cache-Control: private +Content-Length: 2298 +VIA: 1.1 proxy.example.com:80 (XXX/n.n.n-nnn) +Connection: Close +Non-Interned: test + +`, "\n", "\r\n", -1) + +func BenchmarkReadMIMEHeader(b *testing.B) { + b.ReportAllocs() + for _, set := range []struct { + name string + headers string + }{ + {"client_headers", clientHeaders}, + {"server_headers", serverHeaders}, + } { + b.Run(set.name, func(b *testing.B) { + var buf bytes.Buffer + br := bufio.NewReader(&buf) + r := NewReader(br) + + for i := 0; i < b.N; i++ { + buf.WriteString(set.headers) + if _, err := r.ReadMIMEHeader(); err != nil { + b.Fatal(err) + } + } + }) + } +} + +func BenchmarkUncommon(b *testing.B) { + b.ReportAllocs() + var buf bytes.Buffer + br := bufio.NewReader(&buf) + r := NewReader(br) + for i := 0; i < b.N; i++ { + buf.WriteString("uncommon-header-for-benchmark: foo\r\n\r\n") + h, err := r.ReadMIMEHeader() + if err != nil { + b.Fatal(err) + } + if _, ok := h["Uncommon-Header-For-Benchmark"]; !ok { + b.Fatal("Missing result header.") + } + } +} diff --git a/textproto/textproto.go b/textproto/textproto.go new file mode 100644 index 00000000..70038d58 --- /dev/null +++ b/textproto/textproto.go @@ -0,0 +1,152 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package textproto implements generic support for text-based request/response +// protocols in the style of HTTP, NNTP, and SMTP. +// +// The package provides: +// +// Error, which represents a numeric error response from +// a server. +// +// Pipeline, to manage pipelined requests and responses +// in a client. +// +// Reader, to read numeric response code lines, +// key: value headers, lines wrapped with leading spaces +// on continuation lines, and whole text blocks ending +// with a dot on a line by itself. +// +// Writer, to write dot-encoded text blocks. +// +// Conn, a convenient packaging of Reader, Writer, and Pipeline for use +// with a single network connection. +package textproto + +import ( + "bufio" + "fmt" + "io" + "net" +) + +// An Error represents a numeric error response from a server. +type Error struct { + Code int + Msg string +} + +func (e *Error) Error() string { + return fmt.Sprintf("%03d %s", e.Code, e.Msg) +} + +// A ProtocolError describes a protocol violation such +// as an invalid response or a hung-up connection. +type ProtocolError string + +func (p ProtocolError) Error() string { + return string(p) +} + +// A Conn represents a textual network protocol connection. +// It consists of a Reader and Writer to manage I/O +// and a Pipeline to sequence concurrent requests on the connection. +// These embedded types carry methods with them; +// see the documentation of those types for details. +type Conn struct { + Reader + Writer + Pipeline + conn io.ReadWriteCloser +} + +// NewConn returns a new Conn using conn for I/O. +func NewConn(conn io.ReadWriteCloser) *Conn { + return &Conn{ + Reader: Reader{R: bufio.NewReader(conn)}, + Writer: Writer{W: bufio.NewWriter(conn)}, + conn: conn, + } +} + +// Close closes the connection. +func (c *Conn) Close() error { + return c.conn.Close() +} + +// Dial connects to the given address on the given network using net.Dial +// and then returns a new Conn for the connection. +func Dial(network, addr string) (*Conn, error) { + c, err := net.Dial(network, addr) + if err != nil { + return nil, err + } + return NewConn(c), nil +} + +// Cmd is a convenience method that sends a command after +// waiting its turn in the pipeline. The command text is the +// result of formatting format with args and appending \r\n. +// Cmd returns the id of the command, for use with StartResponse and EndResponse. +// +// For example, a client might run a HELP command that returns a dot-body +// by using: +// +// id, err := c.Cmd("HELP") +// if err != nil { +// return nil, err +// } +// +// c.StartResponse(id) +// defer c.EndResponse(id) +// +// if _, _, err = c.ReadCodeLine(110); err != nil { +// return nil, err +// } +// text, err := c.ReadDotBytes() +// if err != nil { +// return nil, err +// } +// return c.ReadCodeLine(250) +func (c *Conn) Cmd(format string, args ...any) (id uint, err error) { + id = c.Next() + c.StartRequest(id) + err = c.PrintfLine(format, args...) + c.EndRequest(id) + if err != nil { + return 0, err + } + return id, nil +} + +// TrimString returns s without leading and trailing ASCII space. +func TrimString(s string) string { + for len(s) > 0 && isASCIISpace(s[0]) { + s = s[1:] + } + for len(s) > 0 && isASCIISpace(s[len(s)-1]) { + s = s[:len(s)-1] + } + return s +} + +// TrimBytes returns b without leading and trailing ASCII space. +func TrimBytes(b []byte) []byte { + for len(b) > 0 && isASCIISpace(b[0]) { + b = b[1:] + } + for len(b) > 0 && isASCIISpace(b[len(b)-1]) { + b = b[:len(b)-1] + } + return b +} + +func isASCIISpace(b byte) bool { + return b == ' ' || b == '\t' || b == '\n' || b == '\r' +} + +func isASCIILetter(b byte) bool { + b |= 0x20 // make lower case + return 'a' <= b && b <= 'z' +} diff --git a/textproto/writer.go b/textproto/writer.go new file mode 100644 index 00000000..2ece3f51 --- /dev/null +++ b/textproto/writer.go @@ -0,0 +1,119 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package textproto + +import ( + "bufio" + "fmt" + "io" +) + +// A Writer implements convenience methods for writing +// requests or responses to a text protocol network connection. +type Writer struct { + W *bufio.Writer + dot *dotWriter +} + +// NewWriter returns a new Writer writing to w. +func NewWriter(w *bufio.Writer) *Writer { + return &Writer{W: w} +} + +var crnl = []byte{'\r', '\n'} +var dotcrnl = []byte{'.', '\r', '\n'} + +// PrintfLine writes the formatted output followed by \r\n. +func (w *Writer) PrintfLine(format string, args ...any) error { + w.closeDot() + fmt.Fprintf(w.W, format, args...) + w.W.Write(crnl) + return w.W.Flush() +} + +// DotWriter returns a writer that can be used to write a dot-encoding to w. +// It takes care of inserting leading dots when necessary, +// translating line-ending \n into \r\n, and adding the final .\r\n line +// when the DotWriter is closed. The caller should close the +// DotWriter before the next call to a method on w. +// +// See the documentation for Reader's DotReader method for details about dot-encoding. +func (w *Writer) DotWriter() io.WriteCloser { + w.closeDot() + w.dot = &dotWriter{w: w} + return w.dot +} + +func (w *Writer) closeDot() { + if w.dot != nil { + w.dot.Close() // sets w.dot = nil + } +} + +type dotWriter struct { + w *Writer + state int +} + +const ( + wstateBegin = iota // initial state; must be zero + wstateBeginLine // beginning of line + wstateCR // wrote \r (possibly at end of line) + wstateData // writing data in middle of line +) + +func (d *dotWriter) Write(b []byte) (n int, err error) { + bw := d.w.W + for n < len(b) { + c := b[n] + switch d.state { + case wstateBegin, wstateBeginLine: + d.state = wstateData + if c == '.' { + // escape leading dot + bw.WriteByte('.') + } + fallthrough + + case wstateData: + if c == '\r' { + d.state = wstateCR + } + if c == '\n' { + bw.WriteByte('\r') + d.state = wstateBeginLine + } + + case wstateCR: + d.state = wstateData + if c == '\n' { + d.state = wstateBeginLine + } + } + if err = bw.WriteByte(c); err != nil { + break + } + n++ + } + return +} + +func (d *dotWriter) Close() error { + if d.w.dot == d { + d.w.dot = nil + } + bw := d.w.W + switch d.state { + default: + bw.WriteByte('\r') + fallthrough + case wstateCR: + bw.WriteByte('\n') + fallthrough + case wstateBeginLine: + bw.Write(dotcrnl) + } + return bw.Flush() +} diff --git a/textproto/writer_test.go b/textproto/writer_test.go new file mode 100644 index 00000000..8f11b107 --- /dev/null +++ b/textproto/writer_test.go @@ -0,0 +1,61 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package textproto + +import ( + "bufio" + "strings" + "testing" +) + +func TestPrintfLine(t *testing.T) { + var buf strings.Builder + w := NewWriter(bufio.NewWriter(&buf)) + err := w.PrintfLine("foo %d", 123) + if s := buf.String(); s != "foo 123\r\n" || err != nil { + t.Fatalf("s=%q; err=%s", s, err) + } +} + +func TestDotWriter(t *testing.T) { + var buf strings.Builder + w := NewWriter(bufio.NewWriter(&buf)) + d := w.DotWriter() + n, err := d.Write([]byte("abc\n.def\n..ghi\n.jkl\n.")) + if n != 21 || err != nil { + t.Fatalf("Write: %d, %s", n, err) + } + d.Close() + want := "abc\r\n..def\r\n...ghi\r\n..jkl\r\n..\r\n.\r\n" + if s := buf.String(); s != want { + t.Fatalf("wrote %q", s) + } +} + +func TestDotWriterCloseEmptyWrite(t *testing.T) { + var buf strings.Builder + w := NewWriter(bufio.NewWriter(&buf)) + d := w.DotWriter() + n, err := d.Write([]byte{}) + if n != 0 || err != nil { + t.Fatalf("Write: %d, %s", n, err) + } + d.Close() + want := "\r\n.\r\n" + if s := buf.String(); s != want { + t.Fatalf("wrote %q; want %q", s, want) + } +} + +func TestDotWriterCloseNoWrite(t *testing.T) { + var buf strings.Builder + w := NewWriter(bufio.NewWriter(&buf)) + d := w.DotWriter() + d.Close() + want := "\r\n.\r\n" + if s := buf.String(); s != want { + t.Fatalf("wrote %q; want %q", s, want) + } +} From 9d9ff5dddc8475f703e486e15deda2c3e9465bf9 Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Sat, 3 Jun 2023 10:50:48 +0200 Subject: [PATCH 02/26] feat: add header order --- h2_bundle.go | 114 +++++++++++++++++++++++++++++++++----------- header.go | 87 +++++++++++++++++++++++++++++++-- request.go | 33 ++++--------- response.go | 2 +- stdlibwrapper.go | 10 ++-- textproto/reader.go | 2 + transfer.go | 49 +++++++++++++++++++ transport.go | 4 ++ 8 files changed, 237 insertions(+), 64 deletions(-) diff --git a/h2_bundle.go b/h2_bundle.go index 529c43cd..c6de0adb 100644 --- a/h2_bundle.go +++ b/h2_bundle.go @@ -7755,6 +7755,8 @@ func (t *http2Transport) newClientConn(c net.Conn, singleUse bool) (*http2Client initialSettings := []http2Setting{ {ID: http2SettingEnablePush, Val: 0}, {ID: http2SettingInitialWindowSize, Val: http2transportDefaultStreamFlow}, + {ID: http2SettingHeaderTableSize, Val: http2initialHeaderTableSize}, + {ID: http2SettingMaxConcurrentStreams, Val: cc.maxConcurrentStreams}, } if max := t.maxFrameReadSize(); max != 0 { initialSettings = append(initialSettings, http2Setting{ID: http2SettingMaxFrameSize, Val: max}) @@ -8828,6 +8830,11 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail // continue to reuse the hpack encoder for future requests) for k, vv := range req.Header { if !httpguts.ValidHeaderFieldName(k) { + // If the header is magic key, the headers would have been ordered + // by this step. It is ok to delete and not raise an error + if k == HeaderOrderKey || k == PHeaderOrderKey { + continue + } return nil, fmt.Errorf("invalid HTTP header name %q", k) } for _, v := range vv { @@ -8844,54 +8851,79 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail // target URI (the path-absolute production and optionally a '?' character // followed by the query production (see Sections 3.3 and 3.4 of // [RFC3986]). - f(":authority", host) + pHeaderOrder, ok := req.Header[PHeaderOrderKey] m := req.Method if m == "" { m = MethodGet } - f(":method", m) - if req.Method != "CONNECT" { - f(":path", path) - f(":scheme", req.URL.Scheme) + if ok { + // follow based on pseudo header order + for _, p := range pHeaderOrder { + switch p { + case ":authority": + f(":authority", host) + case ":method": + f(":method", req.Method) + case ":path": + if req.Method != "CONNECT" { + f(":path", path) + } + case ":scheme": + if req.Method != "CONNECT" { + f(":scheme", req.URL.Scheme) + } + + // (zMrKrabz): Currently skips over unrecognized pheader fields, + // should throw error or something but works for now. + default: + continue + } + } + } else { + f(":authority", host) + f(":method", m) + if req.Method != "CONNECT" { + f(":path", path) + f(":scheme", req.URL.Scheme) + } } if trailers != "" { f("trailer", trailers) } var didUA bool - for k, vv := range req.Header { - if http2asciiEqualFold(k, "host") || http2asciiEqualFold(k, "content-length") { + + var kvs []keyValues + if headerOrder, ok := req.Header[HeaderOrderKey]; ok { + order := make(map[string]int) + for i, v := range headerOrder { + order[v] = i + } + + kvs, _ = req.Header.sortedKeyValuesBy(order, make(map[string]bool)) + } else { + kvs, _ = req.Header.sortedKeyValues(make(map[string]bool)) + } + + for _, kv := range kvs { + + if strings.EqualFold(kv.key, "host") || strings.EqualFold(kv.key, "content-length") { // Host is :authority, already sent. // Content-Length is automatic, set below. continue - } else if http2asciiEqualFold(k, "connection") || - http2asciiEqualFold(k, "proxy-connection") || - http2asciiEqualFold(k, "transfer-encoding") || - http2asciiEqualFold(k, "upgrade") || - http2asciiEqualFold(k, "keep-alive") { + } else if strings.EqualFold(kv.key, "connection") || strings.EqualFold(kv.key, "proxy-connection") || + strings.EqualFold(kv.key, "transfer-encoding") || strings.EqualFold(kv.key, "upgrade") || + strings.EqualFold(kv.key, "keep-alive") { // Per 8.1.2.2 Connection-Specific Header // Fields, don't send connection-specific // fields. We have already checked if any // are error-worthy so just ignore the rest. continue - } else if http2asciiEqualFold(k, "user-agent") { - // Match Go's http1 behavior: at most one - // User-Agent. If set to nil or empty string, - // then omit it. Otherwise if not mentioned, - // include the default (below). - didUA = true - if len(vv) < 1 { - continue - } - vv = vv[:1] - if vv[0] == "" { - continue - } - } else if http2asciiEqualFold(k, "cookie") { + } else if strings.EqualFold(kv.key, "cookie") { // Per 8.1.2.5 To allow for better compression efficiency, the // Cookie header field MAY be split into separate header fields, // each with one or more cookie-pairs. - for _, v := range vv { + for _, v := range kv.values { for { p := strings.IndexByte(v, ';') if p < 0 { @@ -8910,15 +8942,34 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail } } continue + } else if strings.EqualFold(kv.key, "user-agent") { + // Match Go's http1 behavior: at most one + // User-Agent. If set to nil or empty string, + // then omit it. Otherwise if not mentioned, + // include the default (below). + didUA = true + if len(kv.values) > 1 { + kv.values = kv.values[:1] + } + + if kv.values[0] == "" { + continue + } + + } else if strings.EqualFold(kv.key, "accept-encoding") { + addGzipHeader = false } - for _, v := range vv { - f(k, v) + for _, v := range kv.values { + f(kv.key, v) } } + if http2shouldSendReqContentLength(req.Method, contentLength) { f("content-length", strconv.FormatInt(contentLength, 10)) } + + // Does not include accept-encoding header if its defined in req.Header if addGzipHeader { f("accept-encoding", "gzip") } @@ -8946,6 +8997,11 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail // Header list size is ok. Write the headers. enumerateHeaders(func(name, value string) { + // skips over writing magic key headers + if name == PHeaderOrderKey || name == HeaderOrderKey { + return + } + name, ascii := http2lowerHeader(name) if !ascii { // Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header diff --git a/header.go b/header.go index 5f453803..fec9ac3a 100644 --- a/header.go +++ b/header.go @@ -23,6 +23,25 @@ import ( // CanonicalHeaderKey. type Header map[string][]string +// HeaderOrderKey is a magic Key for ResponseWriter.Header map keys +// that, if present, defines a header order that will be used to +// write the headers onto wire. The order of the slice defined how the headers +// will be sorted. A defined Key goes before an undefined Key. +// +// This is the only way to specify some order, because maps don't +// have a stable iteration order. If no order is given, headers will +// be sorted lexicographically. +// +// According to RFC2616 it is good practice sending general-header fields +// first, followed by request-header or response-header fields and ending +// with entity-header fields. +const HeaderOrderKey = "Header-Order:" + +// PHeaderOrderKey is a magic Key for setting http2 pseudo header order. +// If the header is nil it will use regular GoLang header order. +// Valid fields are :authority, :method, :path, :scheme +const PHeaderOrderKey = "PHeader-Order:" + // Add adds the key, value pair to the header. // It appends to any existing values associated with key. // The key is case insensitive; it is canonicalized by @@ -156,17 +175,34 @@ type keyValues struct { // by key. It's used as a pointer, so it can fit in a sort.Interface // interface value without allocation. type headerSorter struct { - kvs []keyValues + kvs []keyValues + order map[string]int } -func (s *headerSorter) Len() int { return len(s.kvs) } -func (s *headerSorter) Swap(i, j int) { s.kvs[i], s.kvs[j] = s.kvs[j], s.kvs[i] } -func (s *headerSorter) Less(i, j int) bool { return s.kvs[i].key < s.kvs[j].key } +func (s *headerSorter) Len() int { return len(s.kvs) } +func (s *headerSorter) Swap(i, j int) { s.kvs[i], s.kvs[j] = s.kvs[j], s.kvs[i] } +func (s *headerSorter) Less(i, j int) bool { + if s.order == nil { + return s.kvs[i].key < s.kvs[j].key + } + idxi, iok := s.order[s.kvs[i].key] + idxj, jok := s.order[s.kvs[j].key] + if !iok && !jok { + return s.kvs[i].key < s.kvs[j].key + } else if !iok && jok { + return false + } else if iok && !jok { + return true + } + return idxi < idxj +} var headerSorterPool = sync.Pool{ New: func() any { return new(headerSorter) }, } +var mutex = &sync.RWMutex{} + // sortedKeyValues returns h's keys sorted in the returned kvs // slice. The headerSorter used to sort is also returned, for possible // return to headerSorterCache. @@ -177,11 +213,32 @@ func (h Header) sortedKeyValues(exclude map[string]bool) (kvs []keyValues, hs *h } kvs = hs.kvs[:0] for k, vv := range h { + mutex.RLock() + if !exclude[k] { + kvs = append(kvs, keyValues{k, vv}) + } + mutex.RUnlock() + } + hs.kvs = kvs + sort.Sort(hs) + return kvs, hs +} + +func (h Header) sortedKeyValuesBy(order map[string]int, exclude map[string]bool) (kvs []keyValues, hs *headerSorter) { + hs = headerSorterPool.Get().(*headerSorter) + if cap(hs.kvs) < len(h) { + hs.kvs = make([]keyValues, 0, len(h)) + } + kvs = hs.kvs[:0] + for k, vv := range h { + mutex.RLock() if !exclude[k] { kvs = append(kvs, keyValues{k, vv}) } + mutex.RUnlock() } hs.kvs = kvs + hs.order = order sort.Sort(hs) return kvs, hs } @@ -198,7 +255,27 @@ func (h Header) writeSubset(w io.Writer, exclude map[string]bool, trace *httptra if !ok { ws = stringWriter{w} } - kvs, sorter := h.sortedKeyValues(exclude) + + var kvs []keyValues + var sorter *headerSorter + // Check if the HeaderOrder is defined. + if headerOrder, ok := h[HeaderOrderKey]; ok { + order := make(map[string]int) + for i, v := range headerOrder { + order[v] = i + } + if exclude == nil { + exclude = make(map[string]bool) + } + mutex.Lock() + exclude[HeaderOrderKey] = true + exclude[PHeaderOrderKey] = true + mutex.Unlock() + kvs, sorter = h.sortedKeyValuesBy(order, exclude) + } else { + kvs, sorter = h.sortedKeyValues(exclude) + } + var formattedVals []string for _, kv := range kvs { if !httpguts.ValidHeaderFieldName(kv.key) { diff --git a/request.go b/request.go index 2bcc77f4..534d10b4 100644 --- a/request.go +++ b/request.go @@ -14,11 +14,11 @@ import ( "encoding/base64" "errors" "fmt" + "github.com/ooni/oohttp/textproto" "io" "mime" "mime/multipart" "net" - "net/textproto" "net/url" urlpkg "net/url" "strconv" @@ -620,28 +620,15 @@ func (r *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitF return err } - // Header lines - _, err = fmt.Fprintf(w, "Host: %s\r\n", host) - if err != nil { - return err - } - if trace != nil && trace.WroteHeaderField != nil { - trace.WroteHeaderField("Host", []string{host}) + if _, ok := r.Header["Host"]; !ok { + if _, ok := r.Header["host"]; !ok { + r.Header.Set("Host", host) + } } - // Use the defaultUserAgent unless the Header contains one, which - // may be blank to not send the header. - userAgent := defaultUserAgent - if r.Header.has("User-Agent") { - userAgent = r.Header.Get("User-Agent") - } - if userAgent != "" { - _, err = fmt.Fprintf(w, "User-Agent: %s\r\n", userAgent) - if err != nil { - return err - } - if trace != nil && trace.WroteHeaderField != nil { - trace.WroteHeaderField("User-Agent", []string{userAgent}) + if _, ok := r.Header["User-Agent"]; !ok { + if _, ok := r.Header["user-agent"]; !ok { + r.Header.Set("User-Agent", defaultUserAgent) } } @@ -650,12 +637,12 @@ func (r *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitF if err != nil { return err } - err = tw.writeHeader(w, trace) + err = tw.addHeaders(&r.Header, trace) if err != nil { return err } - err = r.Header.writeSubset(w, reqWriteExcludeHeader, trace) + err = r.Header.write(w, trace) if err != nil { return err } diff --git a/response.go b/response.go index 755c6965..b46eff2a 100644 --- a/response.go +++ b/response.go @@ -12,8 +12,8 @@ import ( "crypto/tls" "errors" "fmt" + "github.com/ooni/oohttp/textproto" "io" - "net/textproto" "net/url" "strconv" "strings" diff --git a/stdlibwrapper.go b/stdlibwrapper.go index f0187744..28490d94 100644 --- a/stdlibwrapper.go +++ b/stdlibwrapper.go @@ -1,7 +1,5 @@ package http -import "net/http" - // StdlibTransport is an adapter for integrating net/http dependend code. // It looks like an http.RoundTripper but uses this fork internally. type StdlibTransport struct { @@ -9,7 +7,7 @@ type StdlibTransport struct { } // RoundTrip implements the http.RoundTripper interface. -func (txp *StdlibTransport) RoundTrip(stdReq *http.Request) (*http.Response, error) { +func (txp *StdlibTransport) RoundTrip(stdReq *Request) (*Response, error) { req := &Request{ Method: stdReq.Method, URL: stdReq.URL, @@ -38,19 +36,19 @@ func (txp *StdlibTransport) RoundTrip(stdReq *http.Request) (*http.Response, err if err != nil { return nil, err } - stdResp := &http.Response{ + stdResp := &Response{ Status: resp.Status, StatusCode: resp.StatusCode, Proto: resp.Proto, ProtoMinor: resp.ProtoMinor, ProtoMajor: resp.ProtoMajor, - Header: http.Header(resp.Header), + Header: resp.Header, Body: resp.Body, ContentLength: resp.ContentLength, TransferEncoding: resp.TransferEncoding, Close: resp.Close, Uncompressed: resp.Uncompressed, - Trailer: http.Header(resp.Trailer), + Trailer: resp.Trailer, Request: stdReq, TLS: resp.TLS, } diff --git a/textproto/reader.go b/textproto/reader.go index fc2590b1..b4483721 100644 --- a/textproto/reader.go +++ b/textproto/reader.go @@ -573,6 +573,8 @@ func readMIMEHeader(r *Reader, maxMemory, maxHeaders int64) (MIMEHeader, error) m[key] = append(vv, value) } + m["Header-Order:"] = append(m["Header-Order:"], key) + if err != nil { return m, err } diff --git a/transfer.go b/transfer.go index cc7d98d0..f116d65a 100644 --- a/transfer.go +++ b/transfer.go @@ -273,6 +273,55 @@ func (t *transferWriter) shouldSendContentLength() bool { return false } +// addHeaders adds transfer headers to an existing header object +func (t *transferWriter) addHeaders(hdrs *Header, trace *httptrace.ClientTrace) error { + if t.Close && !hasToken(t.Header.get("Connection"), "close") { + hdrs.Add("Connection", "close") + if trace != nil && trace.WroteHeaderField != nil { + trace.WroteHeaderField("Connection", []string{"close"}) + } + } + + // Write Content-Length and/or Transfer-Encoding whose Values are a + // function of the sanitized field triple (Body, ContentLength, + // TransferEncoding) + if t.shouldSendContentLength() { + hdrs.Add("Content-Length", strconv.FormatInt(t.ContentLength, 10)) + if trace != nil && trace.WroteHeaderField != nil { + trace.WroteHeaderField("Content-Length", []string{strconv.FormatInt(t.ContentLength, 10)}) + } + } else if chunked(t.TransferEncoding) { + hdrs.Add("Transfer-Encoding", "chunked") + if trace != nil && trace.WroteHeaderField != nil { + trace.WroteHeaderField("Transfer-Encoding", []string{"chunked"}) + } + } + + // Write Trailer header + if t.Trailer != nil { + keys := make([]string, 0, len(t.Trailer)) + for k := range t.Trailer { + k = CanonicalHeaderKey(k) + switch k { + case "Transfer-Encoding", "Trailer", "Content-Length": + return badStringError("invalid Trailer Key", k) + } + keys = append(keys, k) + } + if len(keys) > 0 { + sort.Strings(keys) + // TODO: could do better allocation-wise here, but trailers are rare, + // so being lazy for now. + hdrs.Add("Trailer", strings.Join(keys, ",")) + if trace != nil && trace.WroteHeaderField != nil { + trace.WroteHeaderField("Trailer", keys) + } + } + } + + return nil +} + func (t *transferWriter) writeHeader(w io.Writer, trace *httptrace.ClientTrace) error { if t.Close && !hasToken(t.Header.get("Connection"), "close") { if _, err := io.WriteString(w, "Connection: close\r\n"); err != nil { diff --git a/transport.go b/transport.go index b422a919..fd3c695e 100644 --- a/transport.go +++ b/transport.go @@ -529,6 +529,10 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) { if isHTTP { for k, vv := range req.Header { if !httpguts.ValidHeaderFieldName(k) { + // Allow the HeaderOrderKey and PHeaderOrderKey magic string, this will be handled further. + if k == HeaderOrderKey || k == PHeaderOrderKey { + continue + } req.closeBody() return nil, fmt.Errorf("net/http: invalid header field name %q", k) } From 5878d39f799b79a746524608685e1afda405a5c2 Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Sat, 3 Jun 2023 11:28:24 +0200 Subject: [PATCH 03/26] feat: add server header order --- server.go | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/server.go b/server.go index a0d6b57e..d6785011 100644 --- a/server.go +++ b/server.go @@ -486,6 +486,8 @@ type response struct { // non-nil. Make this lazily-created again as it used to be? closeNotifyCh chan bool didCloseNotify atomic.Bool // atomic (only false->true winner should send) + // Option for opt-in sorting headers by defined order in a special header. + enableOrderHeaders bool } func (c *response) SetReadDeadline(deadline time.Time) error { @@ -1011,6 +1013,9 @@ func (c *conn) readRequest(ctx context.Context) (w *response, err error) { } for k, vv := range req.Header { if !httpguts.ValidHeaderFieldName(k) { + if k == HeaderOrderKey || k == PHeaderOrderKey { + continue + } return nil, badRequestError("invalid header name") } for _, v := range vv { @@ -1132,6 +1137,9 @@ func relevantCaller() runtime.Frame { } func (w *response) WriteHeader(code int) { + if _, ok := w.handlerHeader[HeaderOrderKey]; ok && !w.enableOrderHeaders { + delete(w.handlerHeader, HeaderOrderKey) + } if w.conn.hijacked() { caller := relevantCaller() w.conn.server.logf("http: response.WriteHeader on hijacked connection from %s (%s:%d)", caller.Function, path.Base(caller.File), caller.Line) @@ -3621,6 +3629,16 @@ func strSliceContains(ss []string, s string) bool { return false } +// EnableHeaderOrder set the option to enable the ResponseWriter to use the +// HeaderOrderKey in its headers, for sorting them. +func EnableHeaderOrder(writer ResponseWriter) ResponseWriter { + if res, ok := writer.(*response); ok { + res.enableOrderHeaders = true + return res + } + return writer +} + // tlsRecordHeaderLooksLikeHTTP reports whether a TLS record header // looks like it might've been a misdirected plaintext HTTP request. func tlsRecordHeaderLooksLikeHTTP(hdr [5]byte) bool { From 6ac0447b1a8ec184729ef0b03ee9e02c850e6149 Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Sat, 3 Jun 2023 12:58:12 +0200 Subject: [PATCH 04/26] feat: enable custom initial HTTP2 SETTINGS frame --- h2_bundle.go | 38 ++++++++++++++++++++++++-------------- transport.go | 43 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 14 deletions(-) diff --git a/h2_bundle.go b/h2_bundle.go index c6de0adb..e3524c0d 100644 --- a/h2_bundle.go +++ b/h2_bundle.go @@ -7752,20 +7752,30 @@ func (t *http2Transport) newClientConn(c net.Conn, singleUse bool) (*http2Client cc.tlsState = &state } - initialSettings := []http2Setting{ - {ID: http2SettingEnablePush, Val: 0}, - {ID: http2SettingInitialWindowSize, Val: http2transportDefaultStreamFlow}, - {ID: http2SettingHeaderTableSize, Val: http2initialHeaderTableSize}, - {ID: http2SettingMaxConcurrentStreams, Val: cc.maxConcurrentStreams}, - } - if max := t.maxFrameReadSize(); max != 0 { - initialSettings = append(initialSettings, http2Setting{ID: http2SettingMaxFrameSize, Val: max}) - } - if max := t.maxHeaderListSize(); max != 0 { - initialSettings = append(initialSettings, http2Setting{ID: http2SettingMaxHeaderListSize, Val: max}) - } - if maxHeaderTableSize != http2initialHeaderTableSize { - initialSettings = append(initialSettings, http2Setting{ID: http2SettingHeaderTableSize, Val: maxHeaderTableSize}) + var initialSettings []http2Setting + if t.t1.hasCustomInitialSettings { + initialSettings = []http2Setting{ + {ID: http2SettingHeaderTableSize, Val: t.t1.HeaderTableSize}, + {ID: http2SettingEnablePush, Val: t.t1.EnablePush}, + {ID: http2SettingMaxConcurrentStreams, Val: t.t1.MaxConcurrentStreams}, + {ID: http2SettingInitialWindowSize, Val: t.t1.InitialWindowSize}, + {ID: http2SettingMaxFrameSize, Val: t.t1.MaxFrameSize}, + {ID: http2SettingMaxHeaderListSize, Val: t.t1.MaxHeaderListSize}, + } + } else { + initialSettings = []http2Setting{ + {ID: http2SettingEnablePush, Val: 0}, + {ID: http2SettingInitialWindowSize, Val: http2transportDefaultStreamFlow}, + } + if max := t.maxFrameReadSize(); max != 0 { + initialSettings = append(initialSettings, http2Setting{ID: http2SettingMaxFrameSize, Val: max}) + } + if max := t.maxHeaderListSize(); max != 0 { + initialSettings = append(initialSettings, http2Setting{ID: http2SettingMaxHeaderListSize, Val: max}) + } + if maxHeaderTableSize != http2initialHeaderTableSize { + initialSettings = append(initialSettings, http2Setting{ID: http2SettingHeaderTableSize, Val: maxHeaderTableSize}) + } } cc.bw.Write(http2clientPreface) diff --git a/transport.go b/transport.go index fd3c695e..fcc8680c 100644 --- a/transport.go +++ b/transport.go @@ -292,6 +292,45 @@ type Transport struct { // DialTLSContext function, you'll completely bypass this // per-Transport-or-global TLSClientFactory mechanism.) TLSClientFactory func(conn net.Conn, config *tls.Config) TLSConn + + hasCustomInitialSettings bool + + // MaxHeaderListSize is the http2 SETTINGS_MAX_HEADER_LIST_SIZE to + // send in the initial settings frame. It is how many bytes + // of response headers are allowed. Unlike the http2 spec, zero here + // means to use a default limit (currently 10MB). If you actually + // want to advertise an unlimited value to the peer, Transport + // interprets the highest possible value here (0xffffffff or 1<<32-1) + // to mean no limit. + MaxHeaderListSize uint32 + + // MaxReadFrameSize is the http2 SETTINGS_MAX_FRAME_SIZE to send in the + // initial settings frame. It is the size in bytes of the largest frame + // payload that the sender is willing to receive. If 0, no setting is + // sent, and the value is provided by the peer, which should be 16384 + // according to the spec: + // https://datatracker.ietf.org/doc/html/rfc7540#section-6.5.2. + // Values are bounded in the range 16k to 16M. + MaxFrameSize uint32 + + // MaxDecoderHeaderTableSize optionally specifies the http2 + // SETTINGS_HEADER_TABLE_SIZE to send in the initial settings frame. It + // informs the remote endpoint of the maximum size of the header compression + // table used to decode header blocks, in octets. If zero, the default value + // of 4096 is used. + HeaderTableSize uint32 + + // MaxDecoderHeaderTableSize optionally specifies the http2 + // SETTINGS_ENABLE_PUSH to send in the initial settings frame. + EnablePush uint32 + + // MaxDecoderHeaderTableSize optionally specifies the http2 + // SETTINGS_MAX_CONCURRENT_STREAMS to send in the initial settings frame. + MaxConcurrentStreams uint32 + + // MaxDecoderHeaderTableSize optionally specifies the http2 + // SETTINGS_INITIAL_WINDOW_SIZE to send in the initial settings frame. + InitialWindowSize uint32 } // A cancelKey is the key of the reqCanceler map. @@ -301,6 +340,10 @@ type cancelKey struct { req *Request } +func (t *Transport) EnableCustomInitialSettings() { + t.hasCustomInitialSettings = true +} + func (t *Transport) writeBufferSize() int { if t.WriteBufferSize > 0 { return t.WriteBufferSize From 04b2891c342cba5c598018fd5233be8f86925857 Mon Sep 17 00:00:00 2001 From: BRUHItsABunny <53124399+BRUHItsABunny@users.noreply.github.com> Date: Mon, 18 Sep 2023 18:53:53 -0500 Subject: [PATCH 05/26] fix: incompatibilities --- request.go | 1 - stdlibwrapper.go | 10 ++++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/request.go b/request.go index 61ae82e4..00cc74b6 100644 --- a/request.go +++ b/request.go @@ -18,7 +18,6 @@ import ( "io" "mime" "mime/multipart" - "net" "net/url" urlpkg "net/url" "strconv" diff --git a/stdlibwrapper.go b/stdlibwrapper.go index 28490d94..f0187744 100644 --- a/stdlibwrapper.go +++ b/stdlibwrapper.go @@ -1,5 +1,7 @@ package http +import "net/http" + // StdlibTransport is an adapter for integrating net/http dependend code. // It looks like an http.RoundTripper but uses this fork internally. type StdlibTransport struct { @@ -7,7 +9,7 @@ type StdlibTransport struct { } // RoundTrip implements the http.RoundTripper interface. -func (txp *StdlibTransport) RoundTrip(stdReq *Request) (*Response, error) { +func (txp *StdlibTransport) RoundTrip(stdReq *http.Request) (*http.Response, error) { req := &Request{ Method: stdReq.Method, URL: stdReq.URL, @@ -36,19 +38,19 @@ func (txp *StdlibTransport) RoundTrip(stdReq *Request) (*Response, error) { if err != nil { return nil, err } - stdResp := &Response{ + stdResp := &http.Response{ Status: resp.Status, StatusCode: resp.StatusCode, Proto: resp.Proto, ProtoMinor: resp.ProtoMinor, ProtoMajor: resp.ProtoMajor, - Header: resp.Header, + Header: http.Header(resp.Header), Body: resp.Body, ContentLength: resp.ContentLength, TransferEncoding: resp.TransferEncoding, Close: resp.Close, Uncompressed: resp.Uncompressed, - Trailer: resp.Trailer, + Trailer: http.Header(resp.Trailer), Request: stdReq, TLS: resp.TLS, } From 411d5894992ff634e23fa853dd9ad1ab5d658490 Mon Sep 17 00:00:00 2001 From: BRUHItsABunny <53124399+BRUHItsABunny@users.noreply.github.com> Date: Fri, 22 Sep 2023 12:30:50 -0500 Subject: [PATCH 06/26] fix: http 1.1 and capitalization inconsistency --- header.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/header.go b/header.go index fec9ac3a..06836427 100644 --- a/header.go +++ b/header.go @@ -185,8 +185,10 @@ func (s *headerSorter) Less(i, j int) bool { if s.order == nil { return s.kvs[i].key < s.kvs[j].key } - idxi, iok := s.order[s.kvs[i].key] - idxj, jok := s.order[s.kvs[j].key] + // idxi, iok := s.order[s.kvs[i].key] + // idxj, jok := s.order[s.kvs[j].key] + idxi, iok := s.order[strings.ToLower(s.kvs[i].key)] + idxj, jok := s.order[strings.ToLower(s.kvs[j].key)] if !iok && !jok { return s.kvs[i].key < s.kvs[j].key } else if !iok && jok { From b754cb3caa8fd58cd00e589847c79cc5545ba20b Mon Sep 17 00:00:00 2001 From: BRUHItsABunny <53124399+BRUHItsABunny@users.noreply.github.com> Date: Fri, 22 Sep 2023 17:31:00 -0500 Subject: [PATCH 07/26] feat: all headers should be order --- h2_bundle.go | 48 +++++++++++++++++++++++++++++++----------------- request.go | 6 +++--- 2 files changed, 34 insertions(+), 20 deletions(-) diff --git a/h2_bundle.go b/h2_bundle.go index e3524c0d..fb36cbca 100644 --- a/h2_bundle.go +++ b/h2_bundle.go @@ -8901,7 +8901,36 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail f("trailer", trailers) } - var didUA bool + if http2shouldSendReqContentLength(req.Method, contentLength) { + req.Header.Set("content-length", strconv.FormatInt(contentLength, 10)) + } + + // Does not include accept-encoding header if its defined in req.Header + _, addGzipHeader = req.Header["accept-encoding"] + if !addGzipHeader { // presence check + req.Header.Set("accept-encoding", "gzip") + // we just aded it, set to true + addGzipHeader = true + } else { + // we didnt add it + addGzipHeader = false + } + + UA, didUA := req.Header["user-agent"] + if didUA { + switch len(UA) { + case 0: + // Default to to default UA if none provided + req.Header.Set("user-agent", defaultUserAgent) + case 1: + // Don't do anything for UA provided as expected + break + default: + // Unexpected UA provided, only take first element + req.Header.Set("user-agent", UA[0]) + + } + } var kvs []keyValues if headerOrder, ok := req.Header[HeaderOrderKey]; ok { @@ -8917,9 +8946,8 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail for _, kv := range kvs { - if strings.EqualFold(kv.key, "host") || strings.EqualFold(kv.key, "content-length") { + if strings.EqualFold(kv.key, "host") { // Host is :authority, already sent. - // Content-Length is automatic, set below. continue } else if strings.EqualFold(kv.key, "connection") || strings.EqualFold(kv.key, "proxy-connection") || strings.EqualFold(kv.key, "transfer-encoding") || strings.EqualFold(kv.key, "upgrade") || @@ -8966,26 +8994,12 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail continue } - } else if strings.EqualFold(kv.key, "accept-encoding") { - addGzipHeader = false } for _, v := range kv.values { f(kv.key, v) } } - - if http2shouldSendReqContentLength(req.Method, contentLength) { - f("content-length", strconv.FormatInt(contentLength, 10)) - } - - // Does not include accept-encoding header if its defined in req.Header - if addGzipHeader { - f("accept-encoding", "gzip") - } - if !didUA { - f("user-agent", http2defaultUserAgent) - } } // Do a first pass over the headers counting bytes to ensure diff --git a/request.go b/request.go index 00cc74b6..d9c81bab 100644 --- a/request.go +++ b/request.go @@ -654,10 +654,10 @@ func (r *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitF return err } + // Make sure can be ordered too Accept-Encoding, Connection if extraHeaders != nil { - err = extraHeaders.write(w, trace) - if err != nil { - return err + for key, values := range extraHeaders { + r.Header[key] = values } } From 1d02990de13cbbca22fc70e1c7b4131ecebf88b9 Mon Sep 17 00:00:00 2001 From: BRUHItsABunny <53124399+BRUHItsABunny@users.noreply.github.com> Date: Wed, 10 Jan 2024 20:51:35 -0600 Subject: [PATCH 08/26] feat: add http2 framing settings, priority header params and priority frames support --- h2_bundle.go | 54 ++++++++++++++++++++----- transport.go | 111 +++++++++++++++++++++++---------------------------- 2 files changed, 94 insertions(+), 71 deletions(-) diff --git a/h2_bundle.go b/h2_bundle.go index fb36cbca..8c4cd937 100644 --- a/h2_bundle.go +++ b/h2_bundle.go @@ -1659,7 +1659,8 @@ type http2Framer struct { debugReadLoggerf func(string, ...interface{}) debugWriteLoggerf func(string, ...interface{}) - frameCache *http2frameCache // nil if frames aren't reused (default) + frameCache *http2frameCache // nil if frames aren't reused (default) + HeaderPriority *http2PriorityParam } func (fr *http2Framer) maxHeaderListSize() uint32 { @@ -2419,6 +2420,9 @@ type http2HeadersFrameParam struct { // It will perform exactly one Write to the underlying Writer. // It is the caller's responsibility to not call other Write methods concurrently. func (f *http2Framer) WriteHeaders(p http2HeadersFrameParam) error { + if f.HeaderPriority != nil { + p.Priority = *f.HeaderPriority + } if !http2validStreamID(p.StreamID) && !f.AllowIllegalWrites { return http2errStreamID } @@ -7753,14 +7757,13 @@ func (t *http2Transport) newClientConn(c net.Conn, singleUse bool) (*http2Client } var initialSettings []http2Setting - if t.t1.hasCustomInitialSettings { - initialSettings = []http2Setting{ - {ID: http2SettingHeaderTableSize, Val: t.t1.HeaderTableSize}, - {ID: http2SettingEnablePush, Val: t.t1.EnablePush}, - {ID: http2SettingMaxConcurrentStreams, Val: t.t1.MaxConcurrentStreams}, - {ID: http2SettingInitialWindowSize, Val: t.t1.InitialWindowSize}, - {ID: http2SettingMaxFrameSize, Val: t.t1.MaxFrameSize}, - {ID: http2SettingMaxHeaderListSize, Val: t.t1.MaxHeaderListSize}, + if t.t1.HasCustomInitialSettings { + for id, value := range t.t1.HTTP2SettingsFrameParameters { + if value < 0 || value > 4294967295 { + // Skip because value is invalid + continue + } + initialSettings = append(initialSettings, http2Setting{ID: http2SettingID(id + 1), Val: uint32(value)}) } } else { initialSettings = []http2Setting{ @@ -7778,9 +7781,40 @@ func (t *http2Transport) newClientConn(c net.Conn, singleUse bool) (*http2Client } } + windowUpdateIncrement := uint32(http2transportDefaultConnFlow) + if t.t1.HasCustomWindowUpdate { + windowUpdateIncrement = t.t1.WindowUpdateIncrement + } + cc.bw.Write(http2clientPreface) cc.fr.WriteSettings(initialSettings...) - cc.fr.WriteWindowUpdate(0, http2transportDefaultConnFlow) + cc.fr.WriteWindowUpdate(0, windowUpdateIncrement) + // cc.addStreamLocked() + + if t.t1.HTTP2PriorityFrameSettings != nil { + if t.t1.HTTP2PriorityFrameSettings.HeaderFrame != nil { + cc.fr.HeaderPriority = &http2PriorityParam{ + StreamDep: t.t1.HTTP2PriorityFrameSettings.HeaderFrame.StreamDep, + Exclusive: t.t1.HTTP2PriorityFrameSettings.HeaderFrame.Exclusive, + Weight: t.t1.HTTP2PriorityFrameSettings.HeaderFrame.Weight, + } + } + + for streamId, priority := range t.t1.HTTP2PriorityFrameSettings.PriorityFrames { + cc.mu.Lock() + cc.nextStreamID++ + cc.mu.Unlock() + if priority == nil { + continue + } + cc.fr.WritePriority(uint32((streamId)), http2PriorityParam{ + StreamDep: priority.StreamDep, + Exclusive: priority.Exclusive, + Weight: priority.Weight, + }) + } + } + cc.inflow.add(http2transportDefaultConnFlow + http2initialWindowSize) cc.bw.Flush() if cc.werr != nil { diff --git a/transport.go b/transport.go index fcc8680c..93145de8 100644 --- a/transport.go +++ b/transport.go @@ -293,44 +293,29 @@ type Transport struct { // per-Transport-or-global TLSClientFactory mechanism.) TLSClientFactory func(conn net.Conn, config *tls.Config) TLSConn - hasCustomInitialSettings bool - - // MaxHeaderListSize is the http2 SETTINGS_MAX_HEADER_LIST_SIZE to - // send in the initial settings frame. It is how many bytes - // of response headers are allowed. Unlike the http2 spec, zero here - // means to use a default limit (currently 10MB). If you actually - // want to advertise an unlimited value to the peer, Transport - // interprets the highest possible value here (0xffffffff or 1<<32-1) - // to mean no limit. - MaxHeaderListSize uint32 - - // MaxReadFrameSize is the http2 SETTINGS_MAX_FRAME_SIZE to send in the - // initial settings frame. It is the size in bytes of the largest frame - // payload that the sender is willing to receive. If 0, no setting is - // sent, and the value is provided by the peer, which should be 16384 - // according to the spec: - // https://datatracker.ietf.org/doc/html/rfc7540#section-6.5.2. - // Values are bounded in the range 16k to 16M. - MaxFrameSize uint32 - - // MaxDecoderHeaderTableSize optionally specifies the http2 - // SETTINGS_HEADER_TABLE_SIZE to send in the initial settings frame. It - // informs the remote endpoint of the maximum size of the header compression - // table used to decode header blocks, in octets. If zero, the default value - // of 4096 is used. - HeaderTableSize uint32 - - // MaxDecoderHeaderTableSize optionally specifies the http2 - // SETTINGS_ENABLE_PUSH to send in the initial settings frame. - EnablePush uint32 - - // MaxDecoderHeaderTableSize optionally specifies the http2 - // SETTINGS_MAX_CONCURRENT_STREAMS to send in the initial settings frame. - MaxConcurrentStreams uint32 - - // MaxDecoderHeaderTableSize optionally specifies the http2 - // SETTINGS_INITIAL_WINDOW_SIZE to send in the initial settings frame. - InitialWindowSize uint32 + HasCustomInitialSettings bool + HasCustomWindowUpdate bool + + HTTP2PriorityFrameSettings *HTTP2PriorityFrameSettings + + // HTTP2SettingsFrameParameters contains all the parameters you can send in the SETTINGS frame. + // The index + 1 is equal to the parameter ID so index 0 would control HEADER_TABLE_SIZE etc + // If the value is -1 or larger than the max size of uint32, it will NOT be sent. Not all browsers send all frames. + HTTP2SettingsFrameParameters []int64 + + // increment to send in the WINDOW_UPDATE frame. + WindowUpdateIncrement uint32 +} + +type HTTP2PriorityFrameSettings struct { + PriorityFrames []*HTTP2Priority + HeaderFrame *HTTP2Priority +} + +type HTTP2Priority struct { + StreamDep uint32 + Exclusive bool + Weight uint8 } // A cancelKey is the key of the reqCanceler map. @@ -341,7 +326,7 @@ type cancelKey struct { } func (t *Transport) EnableCustomInitialSettings() { - t.hasCustomInitialSettings = true + t.HasCustomInitialSettings = true } func (t *Transport) writeBufferSize() int { @@ -362,28 +347,32 @@ func (t *Transport) readBufferSize() int { func (t *Transport) Clone() *Transport { t.nextProtoOnce.Do(t.onceSetNextProtoDefaults) t2 := &Transport{ - Proxy: t.Proxy, - OnProxyConnectResponse: t.OnProxyConnectResponse, - DialContext: t.DialContext, - Dial: t.Dial, - DialTLS: t.DialTLS, - DialTLSContext: t.DialTLSContext, - TLSHandshakeTimeout: t.TLSHandshakeTimeout, - DisableKeepAlives: t.DisableKeepAlives, - DisableCompression: t.DisableCompression, - MaxIdleConns: t.MaxIdleConns, - MaxIdleConnsPerHost: t.MaxIdleConnsPerHost, - MaxConnsPerHost: t.MaxConnsPerHost, - IdleConnTimeout: t.IdleConnTimeout, - ResponseHeaderTimeout: t.ResponseHeaderTimeout, - ExpectContinueTimeout: t.ExpectContinueTimeout, - ProxyConnectHeader: t.ProxyConnectHeader.Clone(), - GetProxyConnectHeader: t.GetProxyConnectHeader, - MaxResponseHeaderBytes: t.MaxResponseHeaderBytes, - ForceAttemptHTTP2: t.ForceAttemptHTTP2, - WriteBufferSize: t.WriteBufferSize, - ReadBufferSize: t.ReadBufferSize, - TLSClientFactory: t.TLSClientFactory, + Proxy: t.Proxy, + OnProxyConnectResponse: t.OnProxyConnectResponse, + DialContext: t.DialContext, + Dial: t.Dial, + DialTLS: t.DialTLS, + DialTLSContext: t.DialTLSContext, + TLSHandshakeTimeout: t.TLSHandshakeTimeout, + DisableKeepAlives: t.DisableKeepAlives, + DisableCompression: t.DisableCompression, + MaxIdleConns: t.MaxIdleConns, + MaxIdleConnsPerHost: t.MaxIdleConnsPerHost, + MaxConnsPerHost: t.MaxConnsPerHost, + IdleConnTimeout: t.IdleConnTimeout, + ResponseHeaderTimeout: t.ResponseHeaderTimeout, + ExpectContinueTimeout: t.ExpectContinueTimeout, + ProxyConnectHeader: t.ProxyConnectHeader.Clone(), + GetProxyConnectHeader: t.GetProxyConnectHeader, + MaxResponseHeaderBytes: t.MaxResponseHeaderBytes, + ForceAttemptHTTP2: t.ForceAttemptHTTP2, + WriteBufferSize: t.WriteBufferSize, + ReadBufferSize: t.ReadBufferSize, + TLSClientFactory: t.TLSClientFactory, + HasCustomInitialSettings: t.HasCustomInitialSettings, + HasCustomWindowUpdate: t.HasCustomWindowUpdate, + HasHeaderPriority: t.HasHeaderPriority, + WindowUpdateIncrement: t.WindowUpdateIncrement, } if t.TLSClientConfig != nil { t2.TLSClientConfig = t.TLSClientConfig.Clone() From 11c521938971e351de7c0216ac7cf2eec4173673 Mon Sep 17 00:00:00 2001 From: BRUHItsABunny <53124399+BRUHItsABunny@users.noreply.github.com> Date: Wed, 10 Jan 2024 20:58:54 -0600 Subject: [PATCH 09/26] feat: add http2 framing settings, priority header params and priority frames support --- transport.go | 1 - 1 file changed, 1 deletion(-) diff --git a/transport.go b/transport.go index 93145de8..8fc54fb9 100644 --- a/transport.go +++ b/transport.go @@ -371,7 +371,6 @@ func (t *Transport) Clone() *Transport { TLSClientFactory: t.TLSClientFactory, HasCustomInitialSettings: t.HasCustomInitialSettings, HasCustomWindowUpdate: t.HasCustomWindowUpdate, - HasHeaderPriority: t.HasHeaderPriority, WindowUpdateIncrement: t.WindowUpdateIncrement, } if t.TLSClientConfig != nil { From dce64bbf8c66a0d889a4d7d923b3bbea29a831cd Mon Sep 17 00:00:00 2001 From: BRUHItsABunny <53124399+BRUHItsABunny@users.noreply.github.com> Date: Wed, 10 Jan 2024 22:46:17 -0600 Subject: [PATCH 10/26] fix: compression error caused by header table size --- h2_bundle.go | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/h2_bundle.go b/h2_bundle.go index 8c4cd937..2f088755 100644 --- a/h2_bundle.go +++ b/h2_bundle.go @@ -1831,6 +1831,7 @@ func (fr *http2Framer) ReadFrame() (http2Frame, error) { } fh, err := http2readFrameHeader(fr.headerBuf[:], fr.r) if err != nil { + fmt.Println("http2readFrameHeader") return nil, err } if fh.Length > fr.maxReadSize { @@ -1843,6 +1844,7 @@ func (fr *http2Framer) ReadFrame() (http2Frame, error) { f, err := http2typeFrameParser(fh.Type)(fr.frameCache, fh, fr.countError, payload) if err != nil { if ce, ok := err.(http2connError); ok { + fmt.Println("http2typeFrameParser") return nil, fr.connError(ce.Code, ce.Reason) } return nil, err @@ -2906,6 +2908,7 @@ func (fr *http2Framer) readMetaFrame(hf *http2HeadersFrame) (*http2MetaHeadersFr for { frag := hc.HeaderBlockFragment() if _, err := hdec.Write(frag); err != nil { + fr.debugReadLoggerf("http2: hdec.Write error: (%T) %v", err, err) return nil, http2ConnectionError(http2ErrCodeCompression) } @@ -2923,6 +2926,7 @@ func (fr *http2Framer) readMetaFrame(hf *http2HeadersFrame) (*http2MetaHeadersFr mh.http2HeadersFrame.invalidate() if err := hdec.Close(); err != nil { + fmt.Println("hdec.Close() ", err) return nil, http2ConnectionError(http2ErrCodeCompression) } if invalid != nil { @@ -3299,8 +3303,9 @@ const ( // HTTP/2's TLS setup. http2NextProtoTLS = "h2" + // TODO: This has consequences, should // https://httpwg.org/specs/rfc7540.html#SettingValues - http2initialHeaderTableSize = 4096 + http2initialHeaderTableSize = 4096 // 65536 http2initialWindowSize = 65535 // 6.9.2 Initial Flow Control Window Size @@ -7683,6 +7688,12 @@ func (t *http2Transport) maxDecoderHeaderTableSize() uint32 { if v := t.MaxDecoderHeaderTableSize; v > 0 { return v } + // Needed else you may see connection error: COMPRESSION_ERROR upon hdec.Write(frag) + if t.t1.HTTP2SettingsFrameParameters != nil && len(t.t1.HTTP2SettingsFrameParameters) > 0 { + if t.t1.HTTP2SettingsFrameParameters[0] > 1 && t.t1.HTTP2SettingsFrameParameters[0] < 4294967295 { + return uint32(t.t1.HTTP2SettingsFrameParameters[0]) + } + } return http2initialHeaderTableSize } From 0d5c9f655c88deef316ecb0f3f2f972165075e1f Mon Sep 17 00:00:00 2001 From: BRUHItsABunny <53124399+BRUHItsABunny@users.noreply.github.com> Date: Wed, 10 Jan 2024 22:46:41 -0600 Subject: [PATCH 11/26] cleanup --- h2_bundle.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/h2_bundle.go b/h2_bundle.go index 2f088755..834a4dce 100644 --- a/h2_bundle.go +++ b/h2_bundle.go @@ -1831,7 +1831,6 @@ func (fr *http2Framer) ReadFrame() (http2Frame, error) { } fh, err := http2readFrameHeader(fr.headerBuf[:], fr.r) if err != nil { - fmt.Println("http2readFrameHeader") return nil, err } if fh.Length > fr.maxReadSize { @@ -1844,7 +1843,6 @@ func (fr *http2Framer) ReadFrame() (http2Frame, error) { f, err := http2typeFrameParser(fh.Type)(fr.frameCache, fh, fr.countError, payload) if err != nil { if ce, ok := err.(http2connError); ok { - fmt.Println("http2typeFrameParser") return nil, fr.connError(ce.Code, ce.Reason) } return nil, err @@ -2926,7 +2924,6 @@ func (fr *http2Framer) readMetaFrame(hf *http2HeadersFrame) (*http2MetaHeadersFr mh.http2HeadersFrame.invalidate() if err := hdec.Close(); err != nil { - fmt.Println("hdec.Close() ", err) return nil, http2ConnectionError(http2ErrCodeCompression) } if invalid != nil { From 9852578fd7c201baa3cd593568f154f920f6e1e7 Mon Sep 17 00:00:00 2001 From: BRUHItsABunny <53124399+BRUHItsABunny@users.noreply.github.com> Date: Sun, 3 Nov 2024 15:53:12 -0600 Subject: [PATCH 12/26] feat: add ability to use browser encodings --- compression.go | 194 ++++++++++++++++++++++++++++++++++++++++++++ compression_test.go | 34 ++++++++ go.mod | 6 +- go.sum | 6 ++ h2_bundle.go | 10 ++- transport.go | 24 ++++-- 6 files changed, 262 insertions(+), 12 deletions(-) create mode 100644 compression.go create mode 100644 compression_test.go diff --git a/compression.go b/compression.go new file mode 100644 index 00000000..b1db6b23 --- /dev/null +++ b/compression.go @@ -0,0 +1,194 @@ +package http + +import ( + "bytes" + "compress/flate" + "compress/gzip" + "compress/lzw" + "compress/zlib" + "errors" + "fmt" + "github.com/andybalholm/brotli" + "github.com/klauspost/compress/zstd" + "io" +) + +type CompressionFactory func(writer io.Writer) (io.Writer, error) +type DecompressionFactory func(reader io.Reader) (io.Reader, error) + +var ( + defaultCompressionFactories = map[string]CompressionFactory{ + "": func(writer io.Writer) (io.Writer, error) { return writer, nil }, + "identity": func(writer io.Writer) (io.Writer, error) { return writer, nil }, + "gzip": func(writer io.Writer) (io.Writer, error) { return gzip.NewWriter(writer), nil }, + "zlib": func(writer io.Writer) (io.Writer, error) { return zlib.NewWriter(writer), nil }, + "br": func(writer io.Writer) (io.Writer, error) { return brotli.NewWriter(writer), nil }, + "deflate": func(writer io.Writer) (io.Writer, error) { return flate.NewWriter(writer, -1) }, + // TODO: Confirm compress + "compress": func(writer io.Writer) (io.Writer, error) { return lzw.NewWriter(writer, lzw.LSB, 8), nil }, + "zstd": func(writer io.Writer) (io.Writer, error) { return zstd.NewWriter(writer) }, + } + + defaultDecompressionFactories = map[string]DecompressionFactory{ + "": func(reader io.Reader) (io.Reader, error) { return reader, nil }, + "identity": func(reader io.Reader) (io.Reader, error) { return reader, nil }, + "gzip": func(reader io.Reader) (io.Reader, error) { return gzip.NewReader(reader) }, + "zlib": func(reader io.Reader) (io.Reader, error) { return zlib.NewReader(reader) }, + "br": func(reader io.Reader) (io.Reader, error) { return brotli.NewReader(reader), nil }, + "deflate": func(reader io.Reader) (io.Reader, error) { return flate.NewReader(reader), nil }, + "compress": func(reader io.Reader) (io.Reader, error) { return lzw.NewReader(reader, lzw.LSB, 8), nil }, + "zstd": func(reader io.Reader) (io.Reader, error) { return zstd.NewReader(reader) }, + } +) + +func compress(data []byte, compressions map[string]CompressionFactory, compressionOrder ...string) ([]byte, error) { + var ( + err error + writers []io.Writer + writer io.Writer + ) + + if compressions == nil { + compressions = defaultCompressionFactories + } + + dst := bytes.NewBuffer(nil) + writer = dst + + for idx, compression := range compressionOrder { + // fmt.Println(fmt.Sprintf("Compression added: %s", compression)) + mapping, ok := compressions[compression] + if !ok { + return nil, errors.New(compression + " is not supported") + } + writer, err = mapping(writer) + if err != nil { + return nil, fmt.Errorf("mapping[%d:%s]: %w", idx, compression, err) + } + + writers = append(writers, writer) + } + + _, err = writers[len(writers)-1].Write(data) + if err != nil { + return nil, fmt.Errorf("writer.Write: %w", err) + } + + // Close all writers in reverse order to ensure all data is flushed + for i := len(writers) - 1; i >= 0; i-- { + err = writers[i].(io.Closer).Close() + if err != nil { + return nil, fmt.Errorf("writers[%d].(io.Closer).Close: %w", i, err) + } + } + + // fmt.Printf("lenIn: %d lenOut: %d\n", len(data), dst.Len()) + return dst.Bytes(), nil +} + +func decompress(data []byte, compressions map[string]DecompressionFactory, compressionOrder ...string) ([]byte, error) { + var ( + err error + reader io.Reader + readers []io.Reader + ) + + if compressions == nil { + compressions = defaultDecompressionFactories + } + + src := bytes.NewBuffer(data) + reader = src + + readers = append(readers, src) + + // Reverse the order of compressions for decompression + for idx := 0; idx < len(compressionOrder); idx++ { + compression := compressionOrder[idx] + // fmt.Println(fmt.Sprintf("Decompression added: %s", compression)) + mapping, ok := compressions[compression] + if !ok { + return nil, errors.New(compression + " is not supported") + } + reader, err = mapping(reader) + if err != nil { + return nil, fmt.Errorf("mapping[%d:%s]: %w", idx, compression, err) + } + + readers = append(readers, reader) + } + + dataOut, err := io.ReadAll(reader) + if err != nil { + return nil, fmt.Errorf("io.ReadAll: %w", err) + } + + for _, readerObj := range readers { + typedReader, ok := readerObj.(io.Closer) + if ok { + defer typedReader.Close() + } + } + + // fmt.Printf("lenIn: %d lenOut: %d\n", len(data), len(dataOut)) + return dataOut, nil +} + +func decompressReader(src io.Reader, compressions map[string]DecompressionFactory, compressionOrder []string) (io.ReadCloser, error) { + var ( + err error + ) + + if compressions == nil { + compressions = defaultDecompressionFactories + } + + result := &bodyDecompressorReader{ + reader: src, + CompressionOrder: compressionOrder, + } + + result.readers = append(result.readers, result.reader) + + // Reverse the order of compressions for decompression + for idx := 0; idx < len(compressionOrder); idx++ { + compression := compressionOrder[idx] + // fmt.Println(fmt.Sprintf("Decompression added: %s", compression)) + mapping, ok := compressions[compression] + if !ok { + return nil, errors.New(compression + " is not supported") + } + result.reader, err = mapping(result.reader) + if err != nil { + return nil, fmt.Errorf("mapping[%d:%s]: %w", idx, compression, err) + } + + result.readers = append(result.readers, result.reader) + } + + return result, nil +} + +type bodyDecompressorReader struct { + reader io.Reader + readers []io.Reader + Factory map[string]DecompressionFactory + CompressionOrder []string +} + +func (body *bodyDecompressorReader) Read(p []byte) (n int, err error) { + return body.reader.Read(p) +} + +func (body *bodyDecompressorReader) Close() error { + for _, readerObj := range body.readers { + typedReader, ok := readerObj.(io.Closer) + if ok { + err := typedReader.Close() + if err != nil { + return err + } + } + } + return nil +} diff --git a/compression_test.go b/compression_test.go new file mode 100644 index 00000000..66ccad4e --- /dev/null +++ b/compression_test.go @@ -0,0 +1,34 @@ +package http + +import ( + "encoding/base64" + "encoding/hex" + "fmt" + "testing" +) + +func Test_Sanity(t *testing.T) { + data := []byte("test data 123456789012345678901234567890") + + fmt.Println("Original Data:", string(data)) + fmt.Println("Original Data (Hex):", hex.EncodeToString(data)) + + // Compress using multiple algorithms + compressions := []string{"gzip", "deflate", "br", "zstd"} + dataCompressed, err := compress(data, defaultCompressionFactories, compressions...) + if err != nil { + t.Fatalf("compress: %v", err) + } + + fmt.Println("Compressed Data (Hex):", hex.EncodeToString(dataCompressed)) + fmt.Println("Compressed Data (Base64):", base64.StdEncoding.EncodeToString(dataCompressed)) + + // Decompress using the same algorithms + dataUncompressed, err := decompress(dataCompressed, defaultDecompressionFactories, compressions...) + if err != nil { + t.Fatalf("decompress: %v", err) + } + + fmt.Println("Decompressed Data:", string(dataUncompressed)) + fmt.Println("Decompressed Data (Hex):", hex.EncodeToString(dataUncompressed)) +} diff --git a/go.mod b/go.mod index 13987b60..1046fd08 100644 --- a/go.mod +++ b/go.mod @@ -2,6 +2,10 @@ module github.com/ooni/oohttp go 1.21 -require golang.org/x/net v0.22.0 +require ( + github.com/andybalholm/brotli v1.1.1 + github.com/klauspost/compress v1.17.11 + golang.org/x/net v0.22.0 +) require golang.org/x/text v0.14.0 // indirect diff --git a/go.sum b/go.sum index 6ce074b7..3bd65ead 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,9 @@ +github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA= +github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA= +github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= +github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0= +github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= +github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc= golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= diff --git a/h2_bundle.go b/h2_bundle.go index 9de9740e..39b83dca 100644 --- a/h2_bundle.go +++ b/h2_bundle.go @@ -8551,8 +8551,8 @@ func (cs *http2clientStream) writeRequest(req *Request) (err error) { cc.mu.Unlock() // TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere? + // Amendment: We are supporting encoding based on the header present in the request now, not just if the transport decided. if !cc.t.disableCompression() && - req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && !cs.isHead { // Request gzip only, not deflate. Deflate is ambiguous and @@ -9732,11 +9732,13 @@ func (rl *http2clientConnReadLoop) handleResponse(cs *http2clientStream, f *http cs.bytesRemain = res.ContentLength res.Body = http2transportResponseBody{cs} - if cs.requestedGzip && http2asciiEqualFold(res.Header.Get("Content-Encoding"), "gzip") { - res.Header.Del("Content-Encoding") + // if cs.requestedGzip && http2asciiEqualFold(res.Header.Get("Content-Encoding"), "gzip") { + if cs.requestedGzip { + // res.Header.Del("Content-Encoding") res.Header.Del("Content-Length") res.ContentLength = -1 - res.Body = &http2gzipReader{body: res.Body} + // res.Body = &http2gzipReader{body: res.Body} + res.Body, err = decompressReader(res.Body, rl.cc.t.t1.DecompressionFactories, strings.Split(res.Header.Get("Content-Encoding"), ",")) res.Uncompressed = true } return res, nil diff --git a/transport.go b/transport.go index b06a2bb2..97c86037 100644 --- a/transport.go +++ b/transport.go @@ -29,7 +29,6 @@ import ( "time" httptrace "github.com/ooni/oohttp/httptrace" - ascii "github.com/ooni/oohttp/internal/ascii" "golang.org/x/net/http/httpguts" "golang.org/x/net/http/httpproxy" ) @@ -191,7 +190,9 @@ type Transport struct { // decoded in the Response.Body. However, if the user // explicitly requested gzip it is not automatically // uncompressed. - DisableCompression bool + DisableCompression bool + CompressionFactories map[string]CompressionFactory + DecompressionFactories map[string]DecompressionFactory // MaxIdleConns controls the maximum number of idle (keep-alive) // connections across all hosts. Zero means no limit. @@ -2254,9 +2255,16 @@ func (pc *persistConn) readLoop() { } resp.Body = body - if rc.addedGzip && ascii.EqualFold(resp.Header.Get("Content-Encoding"), "gzip") { - resp.Body = &gzipReader{body: body} - resp.Header.Del("Content-Encoding") + + //if rc.addedGzip && ascii.EqualFold(resp.Header.Get("Content-Encoding"), "gzip") { + // fmt.Println("Check decompression", rc.addedGzip) + if rc.addedGzip { + resp.Body, err = decompressReader(body, pc.t.DecompressionFactories, strings.Split(resp.Header.Get("Content-Encoding"), ",")) + if err != nil { + panic(err) + } + // resp.Body = &gzipReader{body: body} + // resp.Header.Del("Content-Encoding") resp.Header.Del("Content-Length") resp.ContentLength = -1 resp.Uncompressed = true @@ -2618,9 +2626,9 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err // own value for Accept-Encoding. We only attempt to // uncompress the gzip stream if we were the layer that // requested it. + // Amendment: We are supporting encoding based on the header present in the request now, not just if the transport decided. requestedGzip := false if !pc.t.DisableCompression && - req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" { // Request gzip only, not deflate. Deflate is ambiguous and @@ -2635,8 +2643,10 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err // We don't request gzip if the request is for a range, since // auto-decoding a portion of a gzipped document will just fail // anyway. See https://golang.org/issue/8923 + requestedGzip = true - req.extraHeaders().Set("Accept-Encoding", "gzip") + req.extraHeaders().Set("Accept-Encoding", req.Header.Get("Accept-Encoding")) + req.Header.Del("Accept-Encoding") } var continueCh chan struct{} From 4b34daf53b0ac39284bd687ece1769a45295a01c Mon Sep 17 00:00:00 2001 From: BRUHItsABunny <53124399+BRUHItsABunny@users.noreply.github.com> Date: Sun, 3 Nov 2024 15:58:25 -0600 Subject: [PATCH 13/26] feat: add ability to use browser encodings --- transport.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/transport.go b/transport.go index 97c86037..7b90cf09 100644 --- a/transport.go +++ b/transport.go @@ -17,6 +17,7 @@ import ( "crypto/tls" "errors" "fmt" + "github.com/ooni/oohttp/internal/ascii" "io" "log" "net" @@ -2644,6 +2645,11 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err // auto-decoding a portion of a gzipped document will just fail // anyway. See https://golang.org/issue/8923 + // Default std lib behavior is to default to gzip + if ascii.EqualFold(resp.Header.Get("Content-Encoding"), "") { + resp.Header.Set("Content-Encoding", "gzip") + } + requestedGzip = true req.extraHeaders().Set("Accept-Encoding", req.Header.Get("Accept-Encoding")) req.Header.Del("Accept-Encoding") From a3cdb82d07d37cf6acd56830836176c150d8b8ea Mon Sep 17 00:00:00 2001 From: BRUHItsABunny <53124399+BRUHItsABunny@users.noreply.github.com> Date: Sun, 3 Nov 2024 22:56:49 -0600 Subject: [PATCH 14/26] fix: default encoding to gzip --- transport.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/transport.go b/transport.go index 7b90cf09..2f4755fe 100644 --- a/transport.go +++ b/transport.go @@ -17,7 +17,6 @@ import ( "crypto/tls" "errors" "fmt" - "github.com/ooni/oohttp/internal/ascii" "io" "log" "net" @@ -2646,8 +2645,8 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err // anyway. See https://golang.org/issue/8923 // Default std lib behavior is to default to gzip - if ascii.EqualFold(resp.Header.Get("Content-Encoding"), "") { - resp.Header.Set("Content-Encoding", "gzip") + if req.Header.Get("Accept-Encoding") == "" { + req.Header.Set("Accept-Encoding", "gzip") } requestedGzip = true From f2e7ae5134c6df93587cfc29b57f44cd899ff926 Mon Sep 17 00:00:00 2001 From: BRUHItsABunny <53124399+BRUHItsABunny@users.noreply.github.com> Date: Mon, 4 Nov 2024 02:34:56 -0500 Subject: [PATCH 15/26] feat: add post handshake callback for certificate pinning --- transport.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/transport.go b/transport.go index 2f4755fe..167779fa 100644 --- a/transport.go +++ b/transport.go @@ -306,6 +306,8 @@ type Transport struct { // increment to send in the WINDOW_UPDATE frame. WindowUpdateIncrement uint32 + + PostHandshakeCallback func(string, *tls.ConnectionState) error } type HTTP2PriorityFrameSettings struct { @@ -1635,11 +1637,14 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers } if cm.scheme() == "https" && t.hasCustomTLSDialer() { var err error + fmt.Println("custom dialer") pconn.conn, err = t.customDialTLS(ctx, "tcp", cm.addr()) if err != nil { return nil, wrapErr(err) } + fmt.Println(reflect.TypeOf(pconn.conn)) if tc, ok := pconn.conn.(TLSConn); ok { + fmt.Println("Shaking hands") // Handshake here, in case DialTLS didn't. TLSNextProto below // depends on it for knowing the connection state. if trace != nil && trace.TLSHandshakeStart != nil { @@ -1653,6 +1658,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers return nil, err } cs := tc.ConnectionState() + fmt.Println("Shook hands") if trace != nil && trace.TLSHandshakeDone != nil { trace.TLSHandshakeDone(cs, nil) } @@ -1672,6 +1678,13 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers if err = pconn.addTLS(ctx, firstTLSHost, trace); err != nil { return nil, wrapErr(err) } + + // Callback to add certificate Pinning feature + if t.PostHandshakeCallback != nil { + if err = t.PostHandshakeCallback(firstTLSHost, pconn.tlsState); err != nil { + return nil, fmt.Errorf("oohttp: t.PostHandshakeCallback: %w", err) + } + } } } From acfb16b772c4e6205b14d16223eacb0c7718238e Mon Sep 17 00:00:00 2001 From: BRUHItsABunny <53124399+BRUHItsABunny@users.noreply.github.com> Date: Thu, 7 Nov 2024 02:50:51 -0500 Subject: [PATCH 16/26] feat: add post handshake callback for certificate pinning --- transport.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/transport.go b/transport.go index 167779fa..f3dde1b0 100644 --- a/transport.go +++ b/transport.go @@ -1680,6 +1680,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers } // Callback to add certificate Pinning feature + fmt.Println("POST HANDSHAKE") if t.PostHandshakeCallback != nil { if err = t.PostHandshakeCallback(firstTLSHost, pconn.tlsState); err != nil { return nil, fmt.Errorf("oohttp: t.PostHandshakeCallback: %w", err) @@ -1807,6 +1808,13 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers if err := pconn.addTLS(ctx, cm.tlsHost(), trace); err != nil { return nil, err } + + // Callback to add certificate Pinning feature + if t.PostHandshakeCallback != nil { + if err = t.PostHandshakeCallback(cm.tlsHost(), pconn.tlsState); err != nil { + return nil, fmt.Errorf("oohttp: t.PostHandshakeCallback: %w", err) + } + } } if s := pconn.tlsState; s != nil && s.NegotiatedProtocolIsMutual && s.NegotiatedProtocol != "" { From a9672b1dbb4b146ae2e5bebe6ef63edc3555ac84 Mon Sep 17 00:00:00 2001 From: BRUHItsABunny <53124399+BRUHItsABunny@users.noreply.github.com> Date: Thu, 7 Nov 2024 02:51:17 -0500 Subject: [PATCH 17/26] feat: add post handshake callback for certificate pinning --- transport.go | 1 - 1 file changed, 1 deletion(-) diff --git a/transport.go b/transport.go index f3dde1b0..7a3d5c44 100644 --- a/transport.go +++ b/transport.go @@ -1680,7 +1680,6 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers } // Callback to add certificate Pinning feature - fmt.Println("POST HANDSHAKE") if t.PostHandshakeCallback != nil { if err = t.PostHandshakeCallback(firstTLSHost, pconn.tlsState); err != nil { return nil, fmt.Errorf("oohttp: t.PostHandshakeCallback: %w", err) From 469fd7db1c3fa5aeb0d5be829c2a373e66e3f3c9 Mon Sep 17 00:00:00 2001 From: Nugraha Date: Fri, 6 Dec 2024 21:12:32 +0700 Subject: [PATCH 18/26] feat: compressor streamable and lazy-init, hot path --- compression.go | 196 ++++++++++++++++++++++++++++++++------------ compression_test.go | 127 +++++++++++++++++++++++++++- h2_bundle.go | 8 +- transport.go | 15 ++-- 4 files changed, 284 insertions(+), 62 deletions(-) diff --git a/compression.go b/compression.go index b1db6b23..27fbabea 100644 --- a/compression.go +++ b/compression.go @@ -8,16 +8,24 @@ import ( "compress/zlib" "errors" "fmt" + "io" + "strings" + "sync" + "github.com/andybalholm/brotli" "github.com/klauspost/compress/zstd" - "io" ) type CompressionFactory func(writer io.Writer) (io.Writer, error) type DecompressionFactory func(reader io.Reader) (io.Reader, error) +type EncodingName = string + +type CompressionRegistry map[EncodingName]CompressionFactory +type DecompressionRegistry map[EncodingName]DecompressionFactory + var ( - defaultCompressionFactories = map[string]CompressionFactory{ + DefaultCompressionFactories = CompressionRegistry{ "": func(writer io.Writer) (io.Writer, error) { return writer, nil }, "identity": func(writer io.Writer) (io.Writer, error) { return writer, nil }, "gzip": func(writer io.Writer) (io.Writer, error) { return gzip.NewWriter(writer), nil }, @@ -29,7 +37,7 @@ var ( "zstd": func(writer io.Writer) (io.Writer, error) { return zstd.NewWriter(writer) }, } - defaultDecompressionFactories = map[string]DecompressionFactory{ + DefaultDecompressionFactories = DecompressionRegistry{ "": func(reader io.Reader) (io.Reader, error) { return reader, nil }, "identity": func(reader io.Reader) (io.Reader, error) { return reader, nil }, "gzip": func(reader io.Reader) (io.Reader, error) { return gzip.NewReader(reader) }, @@ -41,23 +49,23 @@ var ( } ) -func compress(data []byte, compressions map[string]CompressionFactory, compressionOrder ...string) ([]byte, error) { +func compress(data []byte, registry CompressionRegistry, order ...string) ([]byte, error) { var ( err error writers []io.Writer writer io.Writer ) - if compressions == nil { - compressions = defaultCompressionFactories + if registry == nil { + registry = DefaultCompressionFactories } dst := bytes.NewBuffer(nil) writer = dst - for idx, compression := range compressionOrder { + for idx, compression := range order { // fmt.Println(fmt.Sprintf("Compression added: %s", compression)) - mapping, ok := compressions[compression] + mapping, ok := registry[compression] if !ok { return nil, errors.New(compression + " is not supported") } @@ -86,15 +94,87 @@ func compress(data []byte, compressions map[string]CompressionFactory, compressi return dst.Bytes(), nil } -func decompress(data []byte, compressions map[string]DecompressionFactory, compressionOrder ...string) ([]byte, error) { +// CompressorWriter compressor writer +type CompressorWriter struct { + io.Writer + + Registry CompressionRegistry + Order []string + + wrs []io.Writer + + once sync.Once +} + +var _ io.WriteCloser = (*CompressorWriter)(nil) + +func (cw *CompressorWriter) init() error { + if cw.Registry == nil { + cw.Registry = DefaultCompressionFactories + } + cw.wrs = nil + for i := 0; i < len(cw.Order); i++ { + directive := cw.Order[i] + directive = strings.Trim(directive, " ") + if directive == "" { + continue + } + compressorWrapper, exist := cw.Registry[directive] + if !exist { + return fmt.Errorf("%s is not supported", directive) + } + writer, err := compressorWrapper(cw.Writer) + if err != nil { + return fmt.Errorf("compressor wrapper init: %s: %w", directive, err) + } + cw.wrs = append(cw.wrs, writer) + cw.Writer = writer + } + return nil +} + +// Init initialize decompressor early instead of lazy-initialize on first read op +func (cw *CompressorWriter) Init() (err error) { + cw.once.Do(func() { + err = cw.init() + }) + return +} + +// Write write buffer to compressor +func (cw *CompressorWriter) Write(b []byte) (nb int, err error) { + cw.once.Do(func() { + err = cw.init() + }) + if err != nil { + return + } + nb, err = cw.Writer.Write(b) + return +} + +// Close close compressor +func (cw *CompressorWriter) Close() error { + for i := len(cw.wrs) - 1; i >= 0; i-- { + if closer, ok := cw.wrs[i].(io.Closer); ok { + err := closer.Close() + if err != nil { + return err + } + } + } + return nil +} + +func decompress(data []byte, registry DecompressionRegistry, order ...string) ([]byte, error) { var ( err error reader io.Reader readers []io.Reader ) - if compressions == nil { - compressions = defaultDecompressionFactories + if registry == nil { + registry = DefaultDecompressionFactories } src := bytes.NewBuffer(data) @@ -103,10 +183,10 @@ func decompress(data []byte, compressions map[string]DecompressionFactory, compr readers = append(readers, src) // Reverse the order of compressions for decompression - for idx := 0; idx < len(compressionOrder); idx++ { - compression := compressionOrder[idx] + for idx := 0; idx < len(order); idx++ { + compression := order[idx] // fmt.Println(fmt.Sprintf("Decompression added: %s", compression)) - mapping, ok := compressions[compression] + mapping, ok := registry[compression] if !ok { return nil, errors.New(compression + " is not supported") } @@ -134,57 +214,71 @@ func decompress(data []byte, compressions map[string]DecompressionFactory, compr return dataOut, nil } -func decompressReader(src io.Reader, compressions map[string]DecompressionFactory, compressionOrder []string) (io.ReadCloser, error) { - var ( - err error - ) +// DecompressorReader decompressor reader +type DecompressorReader struct { + io.Reader - if compressions == nil { - compressions = defaultDecompressionFactories - } + Registry DecompressionRegistry + Order []string - result := &bodyDecompressorReader{ - reader: src, - CompressionOrder: compressionOrder, - } + rds []io.Reader - result.readers = append(result.readers, result.reader) + once sync.Once +} - // Reverse the order of compressions for decompression - for idx := 0; idx < len(compressionOrder); idx++ { - compression := compressionOrder[idx] - // fmt.Println(fmt.Sprintf("Decompression added: %s", compression)) - mapping, ok := compressions[compression] - if !ok { - return nil, errors.New(compression + " is not supported") +var _ io.ReadCloser = (*DecompressorReader)(nil) + +func (dr *DecompressorReader) init() error { + if dr.Registry == nil { + dr.Registry = DefaultDecompressionFactories + } + dr.rds = nil + for i := 0; i < len(dr.Order); i++ { + directive := dr.Order[i] + directive = strings.Trim(directive, " ") + if directive == "" { + continue } - result.reader, err = mapping(result.reader) + // fmt.Println(directive) + decompressorWrapper, exist := dr.Registry[directive] + if !exist { + return fmt.Errorf("%s is not supported", directive) + } + reader, err := decompressorWrapper(dr.Reader) if err != nil { - return nil, fmt.Errorf("mapping[%d:%s]: %w", idx, compression, err) + return fmt.Errorf("decompressor wrapper init: %s: %w", directive, err) } - - result.readers = append(result.readers, result.reader) + dr.rds = append(dr.rds, reader) + dr.Reader = reader } - - return result, nil + return nil } -type bodyDecompressorReader struct { - reader io.Reader - readers []io.Reader - Factory map[string]DecompressionFactory - CompressionOrder []string +// Init initialize decompressor early instead of lazy-initialize on first read op +func (dr *DecompressorReader) Init() (err error) { + dr.once.Do(func() { + err = dr.init() + }) + return } -func (body *bodyDecompressorReader) Read(p []byte) (n int, err error) { - return body.reader.Read(p) +// Read read buffer from decompressor +func (dr *DecompressorReader) Read(b []byte) (nb int, err error) { + dr.once.Do(func() { + err = dr.init() + }) + if err != nil { + return + } + nb, err = dr.Reader.Read(b) + return } -func (body *bodyDecompressorReader) Close() error { - for _, readerObj := range body.readers { - typedReader, ok := readerObj.(io.Closer) - if ok { - err := typedReader.Close() +// Close close decompressor +func (dr *DecompressorReader) Close() error { + for i := len(dr.rds) - 1; i >= 0; i-- { + if closer, ok := dr.rds[i].(io.Closer); ok { + err := closer.Close() if err != nil { return err } diff --git a/compression_test.go b/compression_test.go index 66ccad4e..8d246694 100644 --- a/compression_test.go +++ b/compression_test.go @@ -1,12 +1,135 @@ package http import ( + "bytes" "encoding/base64" "encoding/hex" "fmt" + "io" "testing" ) +func TestCompressionDecompressionRoundTrip(t *testing.T) { + type tc struct { + input []byte + compressions []string + expectedCompressed []byte + expected []byte + } + + for i, tc := range []tc{ + { + input: []byte("test data 123456789012345678901234567890"), + compressions: []string{"gzip", "deflate", "br", "zstd"}, + expectedCompressed: (func() []byte { + dec, err := hex.DecodeString("1f8b08000000000000ff7a249679e1c6aaffd35b1f6a323773277b6615763a7925792e6a5a387163e7ca993b5b97322677086c5e30e5f86173f3998c8c0c0cffff03020000ffff66bd043d34000000") + if err != nil { + panic(err) + } + return dec + })(), + expected: []byte("test data 123456789012345678901234567890"), + }, + { + input: []byte("foo bar baz\n\n"), + compressions: []string{}, + expectedCompressed: []byte("foo bar baz\n\n"), + expected: []byte("foo bar baz\n\n"), + }, + { + input: []byte("hello"), + compressions: []string{""}, + expectedCompressed: []byte("hello"), + expected: []byte("hello"), + }, + { + input: []byte("hello"), + compressions: []string{"identity"}, + expectedCompressed: []byte("hello"), + expected: []byte("hello"), + }, + { + input: []byte("hello"), + compressions: []string{"gzip"}, + expectedCompressed: []byte{ + 0x1f, 0x8b, 0x8, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xff, + 0xca, 0x48, 0xcd, 0xc9, 0xc9, 0x7, 0x4, 0x0, 0x0, 0xff, 0xff, + 0x86, 0xa6, 0x10, 0x36, 0x5, 0x0, 0x0, 0x0, + }, + expected: []byte("hello"), + }, + { + input: []byte("foo bar baz\n\n"), + compressions: []string{"gzip", "br"}, + expectedCompressed: []byte{ + 0x1f, 0x8b, 0x8, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xff, + 0xe2, 0x66, 0x6b, 0x48, 0xcb, 0xcf, 0x57, 0x48, 0x4a, + 0x2c, 0x52, 0x48, 0x4a, 0xac, 0xe2, 0xe2, 0x62, 0x6, 0x4, + 0x0, 0x0, 0xff, 0xff, 0xcb, 0xa9, 0xea, 0xd4, 0x11, 0x0, 0x0, 0x0, + }, + expected: []byte("foo bar baz\n\n"), + }, + { + input: []byte("hello"), + compressions: []string{"gzip", "deflate", "br", "zstd"}, + expectedCompressed: []byte{ + 0x1f, 0x8b, 0x8, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xff, 0x7a, 0xf5, + 0x2c, 0xe3, 0xc2, 0x8d, 0x55, 0xff, 0xa7, 0xb7, 0x3a, 0x4e, 0x6e, + 0x54, 0x54, 0x36, 0x55, 0x57, 0xaf, 0x2f, 0x7c, 0xf7, 0x87, 0x2f, 0xcd, 0x81, + 0x81, 0x81, 0xe1, 0xff, 0x7f, 0x40, 0x0, 0x0, 0x0, 0xff, 0xff, 0x51, 0x13, 0xb7, + 0x9e, 0x1d, 0x0, 0x0, 0x0, + }, + expected: []byte("hello"), + }, + } { + buf := bytes.NewBuffer([]byte{}) + + writer := &CompressorWriter{ + Writer: buf, + Order: tc.compressions, + } + + (func() { + var writer io.Writer = writer + if closer, ok := writer.(io.Closer); ok { + defer closer.Close() + } + + nb, err := writer.Write(tc.input) + if err != nil { + t.Errorf("compressor write: %s", err) + } + fmt.Println("compressor write nb", nb) + })() + + // peek buffer + if !bytes.Equal(tc.expectedCompressed, buf.Bytes()) { + t.Errorf("unexpected compression result: %d: %+#v", i, tc.compressions) + } + fmt.Printf("raw buf %+#v\n", buf.Bytes()) + + reader := &DecompressorReader{ + Reader: bytes.NewReader(buf.Bytes()), + Order: tc.compressions, + } + + (func() { + var reader io.Reader = reader + if closer, ok := reader.(io.Closer); ok { + defer closer.Close() + } + + actual, err := io.ReadAll(reader) + if err != nil { + t.Errorf("decompressor read: %s", err) + } + if !bytes.Equal(actual, tc.expected) { + t.Errorf("result mismatch: %d: %+#v", i, tc.compressions) + } + })() + } +} + func Test_Sanity(t *testing.T) { data := []byte("test data 123456789012345678901234567890") @@ -15,7 +138,7 @@ func Test_Sanity(t *testing.T) { // Compress using multiple algorithms compressions := []string{"gzip", "deflate", "br", "zstd"} - dataCompressed, err := compress(data, defaultCompressionFactories, compressions...) + dataCompressed, err := compress(data, DefaultCompressionFactories, compressions...) if err != nil { t.Fatalf("compress: %v", err) } @@ -24,7 +147,7 @@ func Test_Sanity(t *testing.T) { fmt.Println("Compressed Data (Base64):", base64.StdEncoding.EncodeToString(dataCompressed)) // Decompress using the same algorithms - dataUncompressed, err := decompress(dataCompressed, defaultDecompressionFactories, compressions...) + dataUncompressed, err := decompress(dataCompressed, DefaultDecompressionFactories, compressions...) if err != nil { t.Fatalf("decompress: %v", err) } diff --git a/h2_bundle.go b/h2_bundle.go index 39b83dca..736b5b01 100644 --- a/h2_bundle.go +++ b/h2_bundle.go @@ -9734,11 +9734,15 @@ func (rl *http2clientConnReadLoop) handleResponse(cs *http2clientStream, f *http // if cs.requestedGzip && http2asciiEqualFold(res.Header.Get("Content-Encoding"), "gzip") { if cs.requestedGzip { - // res.Header.Del("Content-Encoding") + res.Header.Del("Content-Encoding") res.Header.Del("Content-Length") res.ContentLength = -1 // res.Body = &http2gzipReader{body: res.Body} - res.Body, err = decompressReader(res.Body, rl.cc.t.t1.DecompressionFactories, strings.Split(res.Header.Get("Content-Encoding"), ",")) + res.Body= &DecompressorReader{ + Reader: res.Body, + Registry: rl.cc.t.t1.DecompressionRegistry, + Order: strings.Split(res.Header.Get("Content-Encoding"), ","), + } res.Uncompressed = true } return res, nil diff --git a/transport.go b/transport.go index 7a3d5c44..1cb89703 100644 --- a/transport.go +++ b/transport.go @@ -190,9 +190,9 @@ type Transport struct { // decoded in the Response.Body. However, if the user // explicitly requested gzip it is not automatically // uncompressed. - DisableCompression bool - CompressionFactories map[string]CompressionFactory - DecompressionFactories map[string]DecompressionFactory + DisableCompression bool + CompressionRegistry CompressionRegistry + DecompressionRegistry DecompressionRegistry // MaxIdleConns controls the maximum number of idle (keep-alive) // connections across all hosts. Zero means no limit. @@ -2279,12 +2279,13 @@ func (pc *persistConn) readLoop() { //if rc.addedGzip && ascii.EqualFold(resp.Header.Get("Content-Encoding"), "gzip") { // fmt.Println("Check decompression", rc.addedGzip) if rc.addedGzip { - resp.Body, err = decompressReader(body, pc.t.DecompressionFactories, strings.Split(resp.Header.Get("Content-Encoding"), ",")) - if err != nil { - panic(err) + resp.Body = &DecompressorReader{ + Reader: resp.Body, + Registry: pc.t.DecompressionRegistry, + Order: strings.Split(resp.Header.Get("Content-Encoding"), ","), } // resp.Body = &gzipReader{body: body} - // resp.Header.Del("Content-Encoding") + resp.Header.Del("Content-Encoding") resp.Header.Del("Content-Length") resp.ContentLength = -1 resp.Uncompressed = true From 9fa006de9c1eed2a1036dc5a1c831eb986930153 Mon Sep 17 00:00:00 2001 From: Nugraha Date: Fri, 6 Dec 2024 21:14:52 +0700 Subject: [PATCH 19/26] chore: use encoding name type --- compression.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/compression.go b/compression.go index 27fbabea..1e5338d4 100644 --- a/compression.go +++ b/compression.go @@ -49,7 +49,7 @@ var ( } ) -func compress(data []byte, registry CompressionRegistry, order ...string) ([]byte, error) { +func compress(data []byte, registry CompressionRegistry, order ...EncodingName) ([]byte, error) { var ( err error writers []io.Writer @@ -99,7 +99,7 @@ type CompressorWriter struct { io.Writer Registry CompressionRegistry - Order []string + Order []EncodingName wrs []io.Writer @@ -166,7 +166,7 @@ func (cw *CompressorWriter) Close() error { return nil } -func decompress(data []byte, registry DecompressionRegistry, order ...string) ([]byte, error) { +func decompress(data []byte, registry DecompressionRegistry, order ...EncodingName) ([]byte, error) { var ( err error reader io.Reader @@ -219,7 +219,7 @@ type DecompressorReader struct { io.Reader Registry DecompressionRegistry - Order []string + Order []EncodingName rds []io.Reader From bb817384bda6736c7d0fdd51c8df66ecd6c422b7 Mon Sep 17 00:00:00 2001 From: Nugraha Date: Fri, 6 Dec 2024 21:17:36 +0700 Subject: [PATCH 20/26] chore: h2 bundle decompressor reader space --- h2_bundle.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/h2_bundle.go b/h2_bundle.go index 736b5b01..d52bd8c3 100644 --- a/h2_bundle.go +++ b/h2_bundle.go @@ -9738,7 +9738,7 @@ func (rl *http2clientConnReadLoop) handleResponse(cs *http2clientStream, f *http res.Header.Del("Content-Length") res.ContentLength = -1 // res.Body = &http2gzipReader{body: res.Body} - res.Body= &DecompressorReader{ + res.Body = &DecompressorReader{ Reader: res.Body, Registry: rl.cc.t.t1.DecompressionRegistry, Order: strings.Split(res.Header.Get("Content-Encoding"), ","), From cb35a8cbe56f22ab27d8ed80ce8735a208f6d413 Mon Sep 17 00:00:00 2001 From: Nugraha Date: Fri, 6 Dec 2024 21:23:19 +0700 Subject: [PATCH 21/26] chore: rename factories to registry --- compression.go | 12 ++++++------ compression_test.go | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/compression.go b/compression.go index 1e5338d4..2aff921e 100644 --- a/compression.go +++ b/compression.go @@ -25,7 +25,7 @@ type CompressionRegistry map[EncodingName]CompressionFactory type DecompressionRegistry map[EncodingName]DecompressionFactory var ( - DefaultCompressionFactories = CompressionRegistry{ + DefaultCompressionRegistry = CompressionRegistry{ "": func(writer io.Writer) (io.Writer, error) { return writer, nil }, "identity": func(writer io.Writer) (io.Writer, error) { return writer, nil }, "gzip": func(writer io.Writer) (io.Writer, error) { return gzip.NewWriter(writer), nil }, @@ -37,7 +37,7 @@ var ( "zstd": func(writer io.Writer) (io.Writer, error) { return zstd.NewWriter(writer) }, } - DefaultDecompressionFactories = DecompressionRegistry{ + DefaultDecompressionRegistry = DecompressionRegistry{ "": func(reader io.Reader) (io.Reader, error) { return reader, nil }, "identity": func(reader io.Reader) (io.Reader, error) { return reader, nil }, "gzip": func(reader io.Reader) (io.Reader, error) { return gzip.NewReader(reader) }, @@ -57,7 +57,7 @@ func compress(data []byte, registry CompressionRegistry, order ...EncodingName) ) if registry == nil { - registry = DefaultCompressionFactories + registry = DefaultCompressionRegistry } dst := bytes.NewBuffer(nil) @@ -110,7 +110,7 @@ var _ io.WriteCloser = (*CompressorWriter)(nil) func (cw *CompressorWriter) init() error { if cw.Registry == nil { - cw.Registry = DefaultCompressionFactories + cw.Registry = DefaultCompressionRegistry } cw.wrs = nil for i := 0; i < len(cw.Order); i++ { @@ -174,7 +174,7 @@ func decompress(data []byte, registry DecompressionRegistry, order ...EncodingNa ) if registry == nil { - registry = DefaultDecompressionFactories + registry = DefaultDecompressionRegistry } src := bytes.NewBuffer(data) @@ -230,7 +230,7 @@ var _ io.ReadCloser = (*DecompressorReader)(nil) func (dr *DecompressorReader) init() error { if dr.Registry == nil { - dr.Registry = DefaultDecompressionFactories + dr.Registry = DefaultDecompressionRegistry } dr.rds = nil for i := 0; i < len(dr.Order); i++ { diff --git a/compression_test.go b/compression_test.go index 8d246694..5b83c1f9 100644 --- a/compression_test.go +++ b/compression_test.go @@ -138,7 +138,7 @@ func Test_Sanity(t *testing.T) { // Compress using multiple algorithms compressions := []string{"gzip", "deflate", "br", "zstd"} - dataCompressed, err := compress(data, DefaultCompressionFactories, compressions...) + dataCompressed, err := compress(data, DefaultCompressionRegistry, compressions...) if err != nil { t.Fatalf("compress: %v", err) } @@ -147,7 +147,7 @@ func Test_Sanity(t *testing.T) { fmt.Println("Compressed Data (Base64):", base64.StdEncoding.EncodeToString(dataCompressed)) // Decompress using the same algorithms - dataUncompressed, err := decompress(dataCompressed, DefaultDecompressionFactories, compressions...) + dataUncompressed, err := decompress(dataCompressed, DefaultDecompressionRegistry, compressions...) if err != nil { t.Fatalf("decompress: %v", err) } From e3308c167e8f01ec87b69b311a2cdd8738cccb56 Mon Sep 17 00:00:00 2001 From: Nugraha Date: Fri, 6 Dec 2024 23:02:13 +0700 Subject: [PATCH 22/26] fix: write request extra header to wire --- request.go | 8 +++++--- transport.go | 11 +++++++---- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/request.go b/request.go index 7a218d35..b81e2aef 100644 --- a/request.go +++ b/request.go @@ -14,7 +14,6 @@ import ( "encoding/base64" "errors" "fmt" - "github.com/ooni/oohttp/textproto" "io" "mime" "mime/multipart" @@ -24,6 +23,8 @@ import ( "strings" "sync" + "github.com/ooni/oohttp/textproto" + httptrace "github.com/ooni/oohttp/httptrace" ascii "github.com/ooni/oohttp/internal/ascii" "golang.org/x/net/http/httpguts" @@ -682,8 +683,9 @@ func (r *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitF // Make sure can be ordered too Accept-Encoding, Connection if extraHeaders != nil { - for key, values := range extraHeaders { - r.Header[key] = values + err = extraHeaders.write(w, trace) + if err != nil { + return err } } diff --git a/transport.go b/transport.go index 1cb89703..42cf2e5e 100644 --- a/transport.go +++ b/transport.go @@ -2665,14 +2665,17 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err // auto-decoding a portion of a gzipped document will just fail // anyway. See https://golang.org/issue/8923 - // Default std lib behavior is to default to gzip - if req.Header.Get("Accept-Encoding") == "" { - req.Header.Set("Accept-Encoding", "gzip") + // Leave control to the user + if req.Header.has("Accept-Encoding") { + // Default std lib behavior is to default to gzip + if req.Header.Get("Accept-Encoding") == "" { + req.Header.Set("Accept-Encoding", "gzip") + } } requestedGzip = true req.extraHeaders().Set("Accept-Encoding", req.Header.Get("Accept-Encoding")) - req.Header.Del("Accept-Encoding") + req.Header.Del("Accept-Encoding") // dedup } var continueCh chan struct{} From 726f96d8d46c6cfafb7614a43b3a0d8fe422294a Mon Sep 17 00:00:00 2001 From: Nugraha Date: Fri, 6 Dec 2024 23:38:23 +0700 Subject: [PATCH 23/26] fix: accept-encoding fallback gzip --- transport.go | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/transport.go b/transport.go index 42cf2e5e..5bbc8012 100644 --- a/transport.go +++ b/transport.go @@ -2665,12 +2665,9 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err // auto-decoding a portion of a gzipped document will just fail // anyway. See https://golang.org/issue/8923 - // Leave control to the user - if req.Header.has("Accept-Encoding") { - // Default std lib behavior is to default to gzip - if req.Header.Get("Accept-Encoding") == "" { - req.Header.Set("Accept-Encoding", "gzip") - } + // Default std lib behavior is to default to gzip + if req.Header.Get("Accept-Encoding") == "" { + req.Header.Set("Accept-Encoding", "gzip") } requestedGzip = true From 3d4b7cb5a90731f6e8030724bae47c83f524dc95 Mon Sep 17 00:00:00 2001 From: Nugraha Date: Fri, 6 Dec 2024 23:46:44 +0700 Subject: [PATCH 24/26] fix: extra header orderable --- request.go | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/request.go b/request.go index b81e2aef..eafafbcb 100644 --- a/request.go +++ b/request.go @@ -676,19 +676,18 @@ func (r *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitF return err } - err = r.Header.write(w, trace) - if err != nil { - return err - } - // Make sure can be ordered too Accept-Encoding, Connection if extraHeaders != nil { - err = extraHeaders.write(w, trace) - if err != nil { - return err + for key, values := range extraHeaders { + r.Header[key] = values } } + err = r.Header.write(w, trace) + if err != nil { + return err + } + _, err = io.WriteString(w, "\r\n") if err != nil { return err From 48b2637db41050e1ca97c18741654c9a128d378b Mon Sep 17 00:00:00 2001 From: Nugraha Date: Mon, 9 Dec 2024 14:55:01 +0700 Subject: [PATCH 25/26] fix: h2 bundle decompress encoding order --- h2_bundle.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/h2_bundle.go b/h2_bundle.go index d52bd8c3..41d3111f 100644 --- a/h2_bundle.go +++ b/h2_bundle.go @@ -9734,7 +9734,6 @@ func (rl *http2clientConnReadLoop) handleResponse(cs *http2clientStream, f *http // if cs.requestedGzip && http2asciiEqualFold(res.Header.Get("Content-Encoding"), "gzip") { if cs.requestedGzip { - res.Header.Del("Content-Encoding") res.Header.Del("Content-Length") res.ContentLength = -1 // res.Body = &http2gzipReader{body: res.Body} @@ -9743,6 +9742,7 @@ func (rl *http2clientConnReadLoop) handleResponse(cs *http2clientStream, f *http Registry: rl.cc.t.t1.DecompressionRegistry, Order: strings.Split(res.Header.Get("Content-Encoding"), ","), } + res.Header.Del("Content-Encoding") res.Uncompressed = true } return res, nil From 1cccadf550c41c31ffae90db798e010fe133e3d8 Mon Sep 17 00:00:00 2001 From: BRUHItsABunny <53124399+BRUHItsABunny@users.noreply.github.com> Date: Fri, 3 Jan 2025 12:51:36 -0600 Subject: [PATCH 26/26] fix: revert to the old way of checking for user agent presence --- go.sum | 2 ++ request.go | 27 ++++++++++++--------------- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/go.sum b/go.sum index 3bd65ead..e3854fdf 100644 --- a/go.sum +++ b/go.sum @@ -6,5 +6,7 @@ github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZ github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc= golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= +golang.org/x/net v0.31.0/go.mod h1:P4fl1q7dY2hnZFxEk4pPSkDHF+QqjitcnDjUQyMM+pM= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4= diff --git a/request.go b/request.go index 216e1c57..e3fb2e6a 100644 --- a/request.go +++ b/request.go @@ -670,29 +670,26 @@ func (r *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitF } // Header lines - if !r.Header.has("Host") { - r.Header.Set("Host", host) - if trace != nil && trace.WroteHeaderField != nil { - trace.WroteHeaderField("Host", []string{host}) + if _, ok := r.Header["Host"]; !ok { + if _, ok := r.Header["host"]; !ok { + r.Header.Set("Host", host) + if trace != nil && trace.WroteHeaderField != nil { + trace.WroteHeaderField("Host", []string{host}) + } } } // Use the defaultUserAgent unless the Header contains one, which // may be blank to not send the header. - userAgent := defaultUserAgent - if r.Header.has("User-Agent") { - userAgent = r.Header.Get("User-Agent") - } - if userAgent != "" { - userAgent = headerNewlineToSpace.Replace(userAgent) - userAgent = textproto.TrimString(userAgent) - r.Header.Set("User-Agent", userAgent) - if trace != nil && trace.WroteHeaderField != nil { - trace.WroteHeaderField("User-Agent", []string{userAgent}) + if _, ok := r.Header["User-Agent"]; !ok { + if _, ok := r.Header["user-agent"]; !ok { + r.Header.Set("User-Agent", defaultUserAgent) + if trace != nil && trace.WroteHeaderField != nil { + trace.WroteHeaderField("User-Agent", []string{defaultUserAgent}) + } } } - // Process Body,ContentLength,Close,Trailer tw, err := newTransferWriter(r) if err != nil {