diff --git a/.github/workflows/build-refactor.yml b/.github/workflows/build-refactor.yml deleted file mode 100644 index 53b284e6..00000000 --- a/.github/workflows/build-refactor.yml +++ /dev/null @@ -1,66 +0,0 @@ -name: build-refactor -# this action is covering internal/ tree with go1.21 - -on: - push: - branches: - - main - pull_request: - branches: - - main - -jobs: - short-tests: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - name: setup go - uses: actions/setup-go@v5 - with: - go-version: '1.21' - - name: Run short tests - run: go test --short -cover ./internal/... - - lint: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - name: Lint with revive action, from pre-built image - uses: docker://morphy/revive-action:v2 - with: - path: "internal/..." - - gosec: - runs-on: ubuntu-latest - env: - GO111MODULE: on - steps: - - name: Checkout Source - uses: actions/checkout@v4 - - name: Run Gosec security scanner - uses: securego/gosec@master - with: - args: '-no-fail ./...' - - coverage-threshold: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - name: setup go - uses: actions/setup-go@v5 - with: - go-version: '1.21' - - name: Ensure coverage threshold - run: make test-coverage-threshold-refactor - - integration: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - name: setup go - uses: actions/setup-go@v5 - with: - go-version: '1.21' - - name: run integration tests - run: go run ./tests/integration - diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 4995f6a8..140d09ed 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1,24 +1,34 @@ name: build +# this action is covering internal/ tree with go1.20 on: push: branches: - - main + - 'main' pull_request: branches: - - main + - 'main' jobs: short-tests: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: setup go - uses: actions/setup-go@v2 + uses: actions/setup-go@v5 with: go-version: '1.20' - name: Run short tests - run: go test --short -cover ./vpn + run: go test --short -cover ./internal/... + + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Lint with revive action, from pre-built image + uses: docker://morphy/revive-action:v2 + with: + path: "internal/..." gosec: runs-on: ubuntu-latest @@ -26,7 +36,7 @@ jobs: GO111MODULE: on steps: - name: Checkout Source - uses: actions/checkout@v2 + uses: actions/checkout@v4 - name: Run Gosec security scanner uses: securego/gosec@master with: @@ -35,10 +45,22 @@ jobs: coverage-threshold: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: setup go - uses: actions/setup-go@v2 + uses: actions/setup-go@v5 with: go-version: '1.20' - name: Ensure coverage threshold run: make test-coverage-threshold + + integration: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: setup go + uses: actions/setup-go@v5 + with: + go-version: '1.20' + - name: run integration tests + run: go run ./tests/integration + diff --git a/Makefile b/Makefile index df14cc49..8db170cc 100644 --- a/Makefile +++ b/Makefile @@ -16,13 +16,7 @@ build-rel: @GOOS=windows go build ${FLAGS} -o minivpn.exe build-race: - @go build -race - -build-ping: - @go build -v ./cmd/vpnping - -build-ndt7: - @go build -o ndt7 ./cmd/ndt7 + @go build -race ./cmd/minivpn bootstrap: @./scripts/bootstrap-provider ${PROVIDER} @@ -45,10 +39,6 @@ test-combined-coverage: scripts/go-coverage-check.sh ./coverage/profile.out ${COVERAGE_THRESHOLD} test-coverage-threshold: - go test --short -coverprofile=cov-threshold.out ./vpn - ./scripts/go-coverage-check.sh cov-threshold.out ${COVERAGE_THRESHOLD} - -test-coverage-threshold-refactor: go test --short -coverprofile=cov-threshold-refactor.out ./internal/... ./scripts/go-coverage-check.sh cov-threshold-refactor.out ${COVERAGE_THRESHOLD} @@ -56,7 +46,7 @@ test-short: go test -race -short -v ./... test-ping: - ./minivpn -c data/${PROVIDER}/config -t ${TARGET} -n ${COUNT} ping + ./minivpn -c data/${PROVIDER}/config -ping integration-server: # this needs the container from https://github.com/ainghazal/docker-openvpn @@ -66,12 +56,6 @@ test-fetch-config: rm -rf data/tests mkdir -p data/tests && curl 172.17.0.2:8080/ > data/tests/config -test-ping-local: - # run the integration-server first - ./minivpn -c data/tests/config -t 172.17.0.1 -n ${COUNT} ping - -test-local: test-fetch-config test-ping-local - qa: @# all the steps at once cd tests/integration && ./run-server.sh & @@ -79,7 +63,7 @@ qa: @rm -rf data/tests @mkdir -p data/tests && curl 172.17.0.2:8080/ > data/tests/config @sleep 1 - ./minivpn -c data/tests/config -t 172.17.0.1 -n ${COUNT} ping + ./minivpn -c data/tests/config -ping @docker stop ovpn1 integration: @@ -88,17 +72,6 @@ integration: filternet-qa: cd tests/qa && ./run-filternet.sh remote-block-all -coverage: - go test -coverprofile=coverage.out ./vpn - go tool cover -html=coverage.out - -coverage-ping: - go test -coverprofile=coverage-ping.out ./extras/ping - go tool cover -html=coverage-ping.out - -proxy: - ./minivpn -c data/${PROVIDER}/config proxy - backup-data: @tar cvzf ../data-vpn-`date +'%F'`.tar.gz @@ -122,5 +95,9 @@ go-sec: go-revive: revive internal/... +install-linters: + go install github.com/mgechev/revive@latest + go install github.com/securego/gosec/v2/cmd/gosec@latest + clean: @rm -f coverage.out diff --git a/extras/example_pinger_test.go b/extras/example_pinger_test.go deleted file mode 100644 index b197510c..00000000 --- a/extras/example_pinger_test.go +++ /dev/null @@ -1,29 +0,0 @@ -package extras - -import ( - "context" - "os" - "time" - - "github.com/ooni/minivpn/extras/ping" - "github.com/ooni/minivpn/vpn" -) - -var ( - cfg = "data/riseup/config" - target = "8.8.8.8" - count = 3 -) - -func ExamplePinger() { - opts, err := vpn.NewOptionsFromFilePath(cfg) - if err != nil { - os.Exit(1) - } - tunnel := vpn.NewClientFromOptions(opts) - tunnel.Start(context.Background()) - pinger := ping.New(target, tunnel) - pinger.Count = 3 - pinger.Timeout = 5 * time.Second - pinger.Run(context.Background()) -} diff --git a/extras/ndt7.go b/extras/ndt7.go deleted file mode 100644 index 80d2e5b5..00000000 --- a/extras/ndt7.go +++ /dev/null @@ -1,134 +0,0 @@ -package extras - -/* - Vendoring of m-lab's ndt7-client reference code, to be able to experiment - and manipulate parts that are exposed as internal in the library. - Upstream: https://github.com/m-lab/ndt7-client-go/ - - SPDX-License-Identifier: Apache-2.0 - - (c) Stephen Soltesz - (c) Peter Boothe - (c) Simone Basso - (c) Ain Ghazal -*/ - -import ( - "context" - "crypto/tls" - "log" - "os" - "time" - - "github.com/gorilla/websocket" - "github.com/m-lab/ndt7-client-go" - "github.com/m-lab/ndt7-client-go/spec" - "github.com/ooni/minivpn/extras/ndt7/emitter" - "github.com/ooni/minivpn/vpn" -) - -const ( - clientName = "minivpn-ndt7-client" - clientVersion = "0.6.1" - defaultTimeout = 10 * time.Second -) - -type runner struct { - client *ndt7.Client - emitter emitter.Emitter -} - -func (r runner) runDownload(ctx context.Context) int { - return r.runTest(ctx, spec.TestDownload, r.client.StartDownload, - r.emitter.OnDownloadEvent) -} - -func (r runner) runUpload(ctx context.Context) int { - return r.runTest(ctx, spec.TestUpload, r.client.StartUpload, - r.emitter.OnUploadEvent) -} - -func (r runner) runTest( - ctx context.Context, test spec.TestKind, - start func(context.Context) (<-chan spec.Measurement, error), - emitEvent func(m *spec.Measurement) error, -) int { - // Implementation note: we want to always emit the initial and the - // final events regardless of how the actual test goes. What's more, - // we want the exit code to be nonzero in case of any error. - err := r.emitter.OnStarting(test) - if err != nil { - return 1 - } - code := r.doRunTest(ctx, test, start, emitEvent) - err = r.emitter.OnComplete(test) - if err != nil { - return 1 - } - return code -} - -func (r runner) doRunTest( - ctx context.Context, test spec.TestKind, - start func(context.Context) (<-chan spec.Measurement, error), - emitEvent func(m *spec.Measurement) error, -) int { - ch, err := start(ctx) - if err != nil { - _ = r.emitter.OnError(test, err) - return 1 - } - err = r.emitter.OnConnected(test, r.client.FQDN) - if err != nil { - return 1 - } - for ev := range ch { - err = emitEvent(&ev) - if err != nil { - return 1 - } - } - return 0 -} - -// TODO use memoryless to repeat measurements, gather the json outputs and -// return a measurement batch. - -// RunMeasurement performs a download & upload measurement against a given ndt7 server. -// It expects a vpn Dialer and a server string (ip:port). -// If the direct parameter is set to true, the vpn Dialer will not be used and -// a direct connection will be used instead. -func RunMeasurement(d vpn.TunDialer, ndt7Server string, mode string, direct bool) { - ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout) - defer cancel() - var r runner - - insecureTLS := false - if os.Getenv("TLS_NOVERIFY") == "1" { - insecureTLS = true - } - - vpnDialer := websocket.Dialer{ - // TODO(ainghazal): pass a config flag to force the InsecureSkipVerify config, - // this should not be used in production. - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: insecureTLS, - }, - } //#nosec G402 - if direct == false { - vpnDialer.NetDialContext = d.DialContext - } else { - log.Println("using a direct connection to ndt7 server") - } - - r.client = ndt7.NewClient(clientName, clientVersion) - r.client.Server = ndt7Server - r.client.Dialer = vpnDialer - r.emitter = emitter.NewJSON(os.Stdout) - switch mode { - case "download": - r.runDownload(ctx) - case "upload": - r.runUpload(ctx) - } -} diff --git a/extras/ndt7/emitter/emitter.go b/extras/ndt7/emitter/emitter.go deleted file mode 100644 index a96da615..00000000 --- a/extras/ndt7/emitter/emitter.go +++ /dev/null @@ -1,50 +0,0 @@ -// Package emitter contains the ndt7-client emitter. -package emitter - -/* - Vendoring of m-lab's ndt7-client reference code, to be able to experiment - and manipulate parts that are exposed as internal in the library. - Upstream: https://github.com/m-lab/ndt7-client-go/ - - SPDX-License-Identifier: Apache-2.0 - - (c) Stephen Soltesz - (c) Peter Boothe - (c) Simone Basso - (c) Ain Ghazal -*/ - -import ( - "github.com/m-lab/ndt7-client-go/spec" -) - -// Emitter is a generic emitter. When an event occurs, the -// corresponding method will be called. An error will generally -// mean that it's not possible to write the output. A common -// case where this happen is where the output is redirected to -// a file on a full hard disk. -// -// See the documentation of the main package for more details -// on the sequence in which events may occur. -type Emitter interface { - // OnStarting is emitted before attempting to start a test. - OnStarting(test spec.TestKind) error - - // OnError is emitted if a test cannot start. - OnError(test spec.TestKind, err error) error - - // OnConnected is emitted when we connected to the ndt7 server. - OnConnected(test spec.TestKind, fqdn string) error - - // OnDownloadEvent is emitted during the download. - OnDownloadEvent(m *spec.Measurement) error - - // OnUploadEvent is emitted during the upload. - OnUploadEvent(m *spec.Measurement) error - - // OnComplete is always emitted when the test is over. - OnComplete(test spec.TestKind) error - - // OnSummary is emitted after the test is over. - OnSummary(s *Summary) error -} diff --git a/extras/ndt7/emitter/json.go b/extras/ndt7/emitter/json.go deleted file mode 100644 index f33c40f4..00000000 --- a/extras/ndt7/emitter/json.go +++ /dev/null @@ -1,129 +0,0 @@ -package emitter - -/* - Vendoring of m-lab's ndt7-client reference code, to be able to experiment - and manipulate parts that are exposed as internal in the library. - Upstream: https://github.com/m-lab/ndt7-client-go/ - - SPDX-License-Identifier: Apache-2.0 - - (c) Stephen Soltesz - (c) Peter Boothe - (c) Simone Basso - (c) Ain Ghazal -*/ - -import ( - "encoding/json" - "io" - - "github.com/m-lab/ndt7-client-go/spec" -) - -// jsonEmitter is a jsonEmitter emitter. It emits messages consistent with -// the cmd/ndt7-client/main.go documentation for `-format=json`. -type jsonEmitter struct { - io.Writer -} - -// NewJSON creates a new JSON emitter -func NewJSON(w io.Writer) Emitter { - return jsonEmitter{ - Writer: w, - } -} - -func (j jsonEmitter) emitData(data []byte) error { - _, err := j.Write(append(data, byte('\n'))) - return err -} - -func (j jsonEmitter) emitInterface(any interface{}) error { - data, err := json.Marshal(any) - if err != nil { - return err - } - return j.emitData(data) -} - -type batchEvent struct { - Key string - Value interface{} -} - -type batchValue struct { - spec.Measurement - Failure string `json:",omitempty"` - Server string `json:",omitempty"` -} - -// OnStarting emits the starting event -func (j jsonEmitter) OnStarting(test spec.TestKind) error { - return j.emitInterface(batchEvent{ - Key: "starting", - Value: batchValue{ - Measurement: spec.Measurement{ - Test: test, - }, - }, - }) -} - -// OnError emits the error event -func (j jsonEmitter) OnError(test spec.TestKind, err error) error { - return j.emitInterface(batchEvent{ - Key: "error", - Value: batchValue{ - Measurement: spec.Measurement{ - Test: test, - }, - Failure: err.Error(), - }, - }) -} - -// OnConnected emits the connected event -func (j jsonEmitter) OnConnected(test spec.TestKind, fqdn string) error { - return j.emitInterface(batchEvent{ - Key: "connected", - Value: batchValue{ - Measurement: spec.Measurement{ - Test: test, - }, - Server: fqdn, - }, - }) -} - -// OnDownloadEvent handles an event emitted during the download -func (j jsonEmitter) OnDownloadEvent(m *spec.Measurement) error { - return j.emitInterface(batchEvent{ - Key: "measurement", - Value: m, - }) -} - -// OnUploadEvent handles an event emitted during the upload -func (j jsonEmitter) OnUploadEvent(m *spec.Measurement) error { - return j.emitInterface(batchEvent{ - Key: "measurement", - Value: m, - }) -} - -// OnComplete is the event signalling the end of the test -func (j jsonEmitter) OnComplete(test spec.TestKind) error { - return j.emitInterface(batchEvent{ - Key: "complete", - Value: batchValue{ - Measurement: spec.Measurement{ - Test: test, - }, - }, - }) -} - -// OnSummary handles the summary event, emitted after the test is over. -func (j jsonEmitter) OnSummary(s *Summary) error { - return j.emitInterface(s) -} diff --git a/extras/ndt7/emitter/summary.go b/extras/ndt7/emitter/summary.go deleted file mode 100644 index 7282d1a1..00000000 --- a/extras/ndt7/emitter/summary.go +++ /dev/null @@ -1,58 +0,0 @@ -package emitter - -/* - Vendoring of m-lab's ndt7-client reference code, to be able to experiment - and manipulate parts that are exposed as internal in the library. - Upstream: https://github.com/m-lab/ndt7-client-go/ - - SPDX-License-Identifier: Apache-2.0 - - (c) Stephen Soltesz - (c) Peter Boothe - (c) Simone Basso -*/ - -// ValueUnitPair represents a {"Value": ..., "Unit": ...} pair. -type ValueUnitPair struct { - Value float64 - Unit string -} - -// Summary is a struct containing the values displayed to the user at -// the end of an ndt7 test. -type Summary struct { - // ServerFQDN is the FQDN of the server used for this test. - ServerFQDN string - - // ServerIP is the (v4 or v6) IP address of the server. - ServerIP string - - // ClientIP is the (v4 or v6) IP address of the client. - ClientIP string - - // DownloadUUID is the UUID of the download test. - // TODO: add UploadUUID after we start processing counterflow messages. - DownloadUUID string - - // Download is the download speed, in Mbit/s. This is measured at the - // receiver. - Download ValueUnitPair - - // Upload is the upload speed, in Mbit/s. This is measured at the sender. - Upload ValueUnitPair - - // DownloadRetrans is the retransmission rate. This is based on the TCPInfo - // values provided by the server during a download test. - DownloadRetrans ValueUnitPair - - // RTT is the round-trip time of the latest measurement, in milliseconds. - // This is provided by the server during a download test. - MinRTT ValueUnitPair -} - -// NewSummary returns a new Summary struct for a given FQDN. -func NewSummary(FQDN string) *Summary { - return &Summary{ - ServerFQDN: FQDN, - } -} diff --git a/extras/ping/ping_test.go b/extras/ping/ping_test.go index e2d6c3b8..02d1ecbb 100644 --- a/extras/ping/ping_test.go +++ b/extras/ping/ping_test.go @@ -10,8 +10,9 @@ import ( "testing" "time" + "github.com/ooni/minivpn/internal/mocks" + "github.com/google/uuid" - "github.com/ooni/minivpn/vpn/mocks" "golang.org/x/net/icmp" "golang.org/x/net/ipv4" ) diff --git a/go.mod b/go.mod index 3b12fd46..bf5a7ed7 100644 --- a/go.mod +++ b/go.mod @@ -7,32 +7,34 @@ go 1.20 require ( git.torproject.org/pluggable-transports/goptlib.git v1.3.0 + github.com/Doridian/water v1.6.1 github.com/apex/log v1.9.0 - github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 github.com/google/go-cmp v0.5.9 github.com/google/gopacket v1.1.19 github.com/google/martian v2.1.0+incompatible github.com/google/uuid v1.3.0 github.com/gorilla/websocket v1.5.0 + github.com/jackpal/gateway v1.0.11 // pinned to a previous version until we can use go1.21 github.com/m-lab/ndt7-client-go v0.7.0 github.com/ory/dockertest/v3 v3.9.1 - github.com/pborman/getopt/v2 v2.1.0 github.com/refraction-networking/utls v1.3.1 gitlab.com/yawning/obfs4.git v0.0.0-20220904064028-336a71d6e4cf - golang.org/x/net v0.17.0 - golang.org/x/sync v0.4.0 + golang.org/x/net v0.22.0 + golang.org/x/sync v0.6.0 golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 ) require ( filippo.io/edwards25519 v1.0.0-rc.1.0.20210721174708-390f27c3be20 // indirect github.com/Azure/go-ansiterm v0.0.0-20170929234023-d6e3b3328b78 // indirect + github.com/Doridian/gopacket v1.2.1 // indirect github.com/Microsoft/go-winio v0.6.0 // indirect github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 // indirect github.com/andybalholm/brotli v1.0.4 // indirect github.com/araddon/dateparse v0.0.0-20200409225146-d820a6159ab1 // indirect github.com/cenkalti/backoff/v4 v4.1.3 // indirect github.com/containerd/continuity v0.3.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/dchest/siphash v1.2.1 // indirect github.com/docker/cli v20.10.14+incompatible // indirect github.com/docker/docker v20.10.7+incompatible // indirect @@ -53,18 +55,22 @@ require ( github.com/opencontainers/image-spec v1.0.2 // indirect github.com/opencontainers/runc v1.1.2 // indirect github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/sirupsen/logrus v1.9.3 // indirect + github.com/stretchr/objx v0.5.0 // indirect github.com/stretchr/testify v1.8.4 // indirect github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect github.com/xeipuuv/gojsonschema v1.2.0 // indirect gitlab.com/yawning/edwards25519-extra.git v0.0.0-20211229043746-2f91fcc9fbdb // indirect - golang.org/x/crypto v0.14.0 // indirect - golang.org/x/mod v0.13.0 // indirect - golang.org/x/sys v0.13.0 // indirect + golang.org/x/crypto v0.21.0 // indirect + golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 // indirect + golang.org/x/mod v0.16.0 // indirect + golang.org/x/sys v0.18.0 // indirect golang.org/x/time v0.3.0 // indirect - golang.org/x/tools v0.14.0 // indirect + golang.org/x/tools v0.19.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 // indirect ) diff --git a/go.sum b/go.sum index 34049e97..ed35b385 100644 --- a/go.sum +++ b/go.sum @@ -35,6 +35,10 @@ github.com/Azure/go-ansiterm v0.0.0-20170929234023-d6e3b3328b78 h1:w+iIsaOQNcT7O github.com/Azure/go-ansiterm v0.0.0-20170929234023-d6e3b3328b78/go.mod h1:LmzpDX56iTiv29bbRTIsUNlaFfuhWRQBWjQdVyAevI8= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= +github.com/Doridian/gopacket v1.2.1 h1:z0Iu5zplIq01nGNwKoreAhc/RMIUqu6vZLxLsHjpO48= +github.com/Doridian/gopacket v1.2.1/go.mod h1:16EwY3JsEHp3TFeSRcmSC9yOdG8GkFAWImZaL13kOGc= +github.com/Doridian/water v1.6.1 h1:cszUUfRlk9duYwv1bE5mFluhaHJK0TAIdPLJk2DV0cI= +github.com/Doridian/water v1.6.1/go.mod h1:bjQW67+p0YKqdrjoDWpJ+bcRvrrHphxwZOC6OTfBuf8= github.com/Microsoft/go-winio v0.6.0 h1:slsWYD/zyx7lCXoZVlvQrj0hPTM1HI4+v1sIda2yDvg= github.com/Microsoft/go-winio v0.6.0/go.mod h1:cTAf44im0RAYeL23bpB+fzCyDH2MJiz2BO69KH/soAE= github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 h1:TngWCqHvy9oXAN6lEVMRuU21PR1EtLVZJmdB18Gu3Rw= @@ -52,8 +56,6 @@ github.com/aphistic/golf v0.0.0-20180712155816-02c07f170c5a/go.mod h1:3NqKYiepwy github.com/aphistic/sweet v0.2.0/go.mod h1:fWDlIh/isSE9n6EPsRmC0det+whmX6dJid3stzu0Xys= github.com/araddon/dateparse v0.0.0-20200409225146-d820a6159ab1 h1:TEBmxO80TM04L8IuMWk77SGL1HomBmKTdzdJLLWznxI= github.com/araddon/dateparse v0.0.0-20200409225146-d820a6159ab1/go.mod h1:SLqhdZcd+dF3TEVL2RMoob5bBP5R1P1qkox+HtCBgGI= -github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio= -github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= github.com/aws/aws-sdk-go v1.20.6/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= github.com/aybabtme/rgbterm v0.0.0-20170906152045-cc83f3b3ce59/go.mod h1:q/89r3U2H7sSsE2t6Kca0lfwTK8JdoNGS/yzM/4iH5I= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= @@ -111,6 +113,7 @@ github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2 github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= +github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-test/deep v1.0.6 h1:UHSEyLZUwX9Qoi99vVwvewiMC8mM2bf7XEM2nqvzEn8= github.com/go-test/deep v1.0.6/go.mod h1:QV8Hv/iy04NyLBxAdO9njL0iVPN1S4d/A3NVv1V36o8= @@ -146,6 +149,7 @@ github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QD github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= +github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU= @@ -191,6 +195,10 @@ github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpO github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/imdario/mergo v0.3.12 h1:b6R2BslTbIEToALKP7LxUvijTsNI9TAe80pLWN2g/HU= github.com/imdario/mergo v0.3.12/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA= +github.com/jackpal/gateway v1.0.11 h1:XqCVFIyo2LtQYXjz9nis1WMTvAadJiFP/Zc04xmdEYE= +github.com/jackpal/gateway v1.0.11/go.mod h1:NqRwEsSP/DD8d4YXIsHEMNUSYetesFXjmL6QZFrul+M= +github.com/jackpal/gateway v1.0.13 h1:fJccMvawxx0k7S1q7Fy/SXFE0R3hMXkMuw8y9SofWAk= +github.com/jackpal/gateway v1.0.13/go.mod h1:6c8LjW+FVESFmwxaXySkt7fU98Yv806ADS3OY6Cvh2U= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= github.com/jpillora/backoff v0.0.0-20180909062703-3050d21c67d7/go.mod h1:2iMrUgbbvHEiQClaW2NsSzMyGHqN+rDFqY705q49KG0= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= @@ -217,6 +225,7 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/lib/pq v0.0.0-20180327071824-d34b9ff171c2 h1:hRGSmZu7j271trc9sneMrpOW7GN5ngLm8YUZIPzf394= +github.com/lib/pq v0.0.0-20180327071824-d34b9ff171c2/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/m-lab/access v0.0.3 h1:ixaAZNvN/ggOj1/FOZQhw4r31QA/WlApn0HH71LSHgQ= github.com/m-lab/access v0.0.3/go.mod h1:gZ7YN3SeMTZYeRv5EFaLdG+XVI/F/X4njM1G1BfwuE4= github.com/m-lab/go v0.1.43 h1:AKmrhhi5a5rUL9nNMo0YNLSJLNFLfV5PmdAXqRCBGE8= @@ -262,8 +271,6 @@ github.com/opencontainers/runtime-spec v1.0.3-0.20210326190908-1c3f411f0417/go.m github.com/opencontainers/selinux v1.10.0/go.mod h1:2i0OySw99QjzBBQByd1Gr9gSjvuho1lHsJxIJ3gGbJI= github.com/ory/dockertest/v3 v3.9.1 h1:v4dkG+dlu76goxMiTT2j8zV7s4oPPEppKT8K8p2f1kY= github.com/ory/dockertest/v3 v3.9.1/go.mod h1:42Ir9hmvaAPm0Mgibk6mBPi7SFvTXxEcnztDYOJ//uM= -github.com/pborman/getopt/v2 v2.1.0 h1:eNfR+r+dWLdWmV8g5OlpyrTYHkhVNxHBdN2cCrJmOEA= -github.com/pborman/getopt/v2 v2.1.0/go.mod h1:4NtW75ny4eBw9fO1bhtNdYTlZKYX5/tBLtsOpwKIKd0= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= @@ -307,11 +314,16 @@ github.com/smartystreets/gunit v1.0.0/go.mod h1:qwPWnhz6pn0NnRBP++URONOVyNkPyr4S github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/syndtr/gocapability v0.0.0-20200815063812-42c35b437635/go.mod h1:hkRG7XYTFWNJGYcbNJQlaLq0fg1yr4J4t/NcTQtrfww= @@ -325,6 +337,7 @@ github.com/tj/go-spin v1.1.0/go.mod h1:Mg1mzmePZm4dva8Qz60H2lHwmJ2loum4VIrLgVnKw github.com/urfave/cli v1.22.1/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYppBueQtXaqoE= github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU= +github.com/vishvananda/netns v0.0.1/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f h1:J9EGpcZtP0E/raorCMxlFGSTBrsSlaDGf3jU/qvAE2c= github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHovont7NscjpAxXsDA8S8BMYve8Y5+7cuRE7R0= @@ -334,6 +347,7 @@ github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQ github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= gitlab.com/yawning/edwards25519-extra.git v0.0.0-20211229043746-2f91fcc9fbdb h1:qRSZHsODmAP5qDvb3YsO7Qnf3TRiVbGxNG/WYnlM4/o= gitlab.com/yawning/edwards25519-extra.git v0.0.0-20211229043746-2f91fcc9fbdb/go.mod h1:gvdJuZuO/tPZyhEV8K3Hmoxv/DWud5L4qEQxfYjEUTo= gitlab.com/yawning/obfs4.git v0.0.0-20220904064028-336a71d6e4cf h1:k9czJST0Jvc6fnz4Jp1sxRmA4dSuiWFq+DVpxLZP5yM= @@ -350,8 +364,11 @@ golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= +golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= +golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -362,6 +379,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= +golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 h1:aAcj0Da7eBAtrTp03QXWvm88pSyOt+UgdZw2BFZ+lEw= +golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8/go.mod h1:CQ1k9gNrJ50XIzaKCRR2hssIjF07kZFEiieALBM/ARQ= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -382,8 +401,11 @@ golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzB golang.org/x/mod v0.1.1-0.20191107180719-034126e5016b/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.13.0 h1:I/DsJXRlw/8l/0c24sM9yb0T4z9liZTduXvdAWYiysY= golang.org/x/mod v0.13.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic= +golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -409,8 +431,12 @@ golang.org/x/net v0.0.0-20200421231249-e086a090c8fd/go.mod h1:qpuaurCH72eLCgpAm/ golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.5.0/go.mod h1:DivGGAXEgPSlEBzxGzZI+ZLohi+xUj054jfeKui00ws= golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= +golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc= +golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -424,8 +450,11 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.4.0 h1:zxkM55ReGkDlKSM+Fu41A+zmbZuaPVbGMzvvdUPznYQ= golang.org/x/sync v0.4.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= +golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= +golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -451,6 +480,7 @@ golang.org/x/sys v0.0.0-20200113162924-86b910548bc1/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200122134326-e047566fdf82/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200212091648-12a6c2dcc1e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -466,15 +496,24 @@ golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210906170528-6f6e22806c34/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211025201205-69cdffdb9359/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211116061358-0a5406a5449c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= +golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.4.0/go.mod h1:9P2UbLfCdcvo3p/nzKvsmas4TnlujnuoV9hGgYzW1lQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -518,8 +557,11 @@ golang.org/x/tools v0.0.0-20200409170454-77362c5149f0/go.mod h1:EkVYQZoAsY45+roY golang.org/x/tools v0.0.0-20200422205258-72e4a01eba43/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.14.0 h1:jvNa2pY0M4r62jkRQ6RwEZZyPcymeL9XZMLBbV7U2nc= golang.org/x/tools v0.14.0/go.mod h1:uYBEerGOWcJyEORxN+Ek8+TT266gXkNlHdJBwexUsBg= +golang.org/x/tools v0.19.0 h1:tfGCXNR1OsFG+sVdLAitlpjAvD/I6dHDKnYrpEZUHkw= +golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -594,6 +636,7 @@ google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlba google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.28.2-0.20230118093459-a9481185b34d h1:qp0AnQCvRCMlu9jBjtdbTaaEmThIgZOrbVyDEOcmKhQ= +google.golang.org/protobuf v1.28.2-0.20230118093459-a9481185b34d/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= @@ -615,8 +658,10 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gotest.tools/v3 v3.0.2/go.mod h1:3SzNCllyD9/Y+b5r9JIKQ474KzkZyqLqEfYqMsX94Bk= gotest.tools/v3 v3.4.0 h1:ZazjZUfuVeZGLAmlKKuyv3IKP5orXcwtOwDQH6YVr6o= +gotest.tools/v3 v3.4.0/go.mod h1:CtbdzLSsqVhDgMtKsx03ird5YTGB3ar27v0u/yKBW5g= gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ= gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/vpn/mocks/README.md b/internal/mocks/README.md similarity index 100% rename from vpn/mocks/README.md rename to internal/mocks/README.md diff --git a/vpn/mocks/addr.go b/internal/mocks/addr.go similarity index 100% rename from vpn/mocks/addr.go rename to internal/mocks/addr.go diff --git a/vpn/mocks/dialer.go b/internal/mocks/dialer.go similarity index 100% rename from vpn/mocks/dialer.go rename to internal/mocks/dialer.go diff --git a/internal/model/packet.go b/internal/model/packet.go index 2ed5f674..1ead02a2 100644 --- a/internal/model/packet.go +++ b/internal/model/packet.go @@ -339,6 +339,7 @@ func (p *Packet) IsData() bool { var pingPayload = []byte{0x2A, 0x18, 0x7B, 0xF3, 0x64, 0x1E, 0xB4, 0xCB, 0x07, 0xED, 0x2D, 0x0A, 0x98, 0x1F, 0xC7, 0x48} +// IsPing returns true if this packet matches a openvpn ping packet. func (p *Packet) IsPing() bool { return bytes.Equal(pingPayload, p.Payload) } diff --git a/internal/model/session.go b/internal/model/session.go index 5e181737..c2f10823 100644 --- a/internal/model/session.go +++ b/internal/model/session.go @@ -7,7 +7,7 @@ const ( // S_ERROR means there was some form of protocol error. S_ERROR = NegotiationState(iota) - 1 - // S_UNDER is the undefined state. + // S_UNDEF is the undefined state. S_UNDEF // S_INITIAL means we're ready to begin the three-way handshake. diff --git a/internal/reliabletransport/reliable_ack_test.go b/internal/reliabletransport/reliable_ack_test.go index 56cf3fe9..d3ab453b 100644 --- a/internal/reliabletransport/reliable_ack_test.go +++ b/internal/reliabletransport/reliable_ack_test.go @@ -1,7 +1,6 @@ package reliabletransport import ( - "slices" "testing" "time" @@ -9,6 +8,9 @@ import ( "github.com/ooni/minivpn/internal/model" "github.com/ooni/minivpn/internal/vpntest" "github.com/ooni/minivpn/pkg/config" + + // TODO: replace with stlib slices after 1.21 + "golang.org/x/exp/slices" ) // test that everything that is received from below is eventually ACKed to the sender. diff --git a/internal/reliabletransport/sender_test.go b/internal/reliabletransport/sender_test.go index eb25afb3..024b4c46 100644 --- a/internal/reliabletransport/sender_test.go +++ b/internal/reliabletransport/sender_test.go @@ -2,13 +2,15 @@ package reliabletransport import ( "reflect" - "slices" "testing" "time" "github.com/apex/log" "github.com/ooni/minivpn/internal/model" "github.com/ooni/minivpn/internal/optional" + + // TODO: replace with stdlib slices after 1.21 + "golang.org/x/exp/slices" ) func idSequence(s inflightSequence) []model.PacketID { diff --git a/internal/session/datachannelkey.go b/internal/session/datachannelkey.go index 3c903f0b..de8885ac 100644 --- a/internal/session/datachannelkey.go +++ b/internal/session/datachannelkey.go @@ -6,6 +6,11 @@ import ( "sync" ) +var ( + // ErrDataChannelKey is a [DataChannelKey] error. + ErrDataChannelKey = errors.New("bad data-channel key") +) + // DataChannelKey represents a pair of key sources that have been negotiated // over the control channel, and from which we will derive local and remote // keys for encryption and decrption over the data channel. The index refers to @@ -23,9 +28,6 @@ type DataChannelKey struct { mu sync.Mutex } -// errDayaChannelKey is a [DataChannelKey] error. -var errDataChannelKey = errors.New("bad data-channel key") - // Local returns the local [KeySource] func (dck *DataChannelKey) Local() *KeySource { return dck.local @@ -42,7 +44,7 @@ func (dck *DataChannelKey) AddRemoteKey(k *KeySource) error { dck.mu.Lock() defer dck.mu.Unlock() if dck.ready { - return fmt.Errorf("%w: %s", errDataChannelKey, "cannot overwrite remote key slot") + return fmt.Errorf("%w: %s", ErrDataChannelKey, "cannot overwrite remote key slot") } dck.remote = k dck.ready = true diff --git a/internal/session/manager.go b/internal/session/manager.go index c24637b1..113e9ee4 100644 --- a/internal/session/manager.go +++ b/internal/session/manager.go @@ -14,6 +14,14 @@ import ( "github.com/ooni/minivpn/pkg/config" ) +var ( + // ErrExpiredKey is the error we raise when we have an expired key. + ErrExpiredKey = errors.New("expired key") + + // ErrNoRemoteSessionID indicates we are missing the remote session ID. + ErrNoRemoteSessionID = errors.New("missing remote session ID") +) + // Manager manages the session. The zero value is invalid. Please, construct // using [NewManager]. This struct is concurrency safe. type Manager struct { @@ -103,9 +111,6 @@ func (m *Manager) IsRemoteSessionIDSet() bool { return !m.remoteSessionID.IsNone() } -// ErrNoRemoteSessionID indicates we are missing the remote session ID. -var ErrNoRemoteSessionID = errors.New("missing remote session ID") - // NewACKForPacket creates a new ACK for the given packet IDs. func (m *Manager) NewACKForPacketIDs(ids []model.PacketID) (*model.Packet, error) { defer m.mu.Unlock() @@ -170,8 +175,6 @@ func (m *Manager) NewHardResetPacket() *model.Packet { return packet } -var ErrExpiredKey = errors.New("expired key") - // LocalDataPacketID returns an unique Packet ID for the Data Channel. It // increments the counter for the local data packet ID. func (m *Manager) LocalDataPacketID() (model.PacketID, error) { @@ -228,7 +231,7 @@ func (m *Manager) ActiveKey() (*DataChannelKey, error) { defer m.mu.Unlock() m.mu.Lock() if len(m.keys) > math.MaxUint8 || m.keyID >= uint8(len(m.keys)) { - return nil, fmt.Errorf("%w: %s", errDataChannelKey, "no such key id") + return nil, fmt.Errorf("%w: %s", ErrDataChannelKey, "no such key id") } dck := m.keys[m.keyID] return dck, nil diff --git a/internal/tlssession/tlshandshake_test.go b/internal/tlssession/tlshandshake_test.go index 46b2a4d4..95bdb6df 100644 --- a/internal/tlssession/tlshandshake_test.go +++ b/internal/tlssession/tlshandshake_test.go @@ -14,8 +14,8 @@ import ( "time" "github.com/google/martian/mitm" + "github.com/ooni/minivpn/internal/mocks" "github.com/ooni/minivpn/pkg/config" - "github.com/ooni/minivpn/vpn/mocks" tls "github.com/refraction-networking/utls" ) diff --git a/internal/vpntest/certs.go b/internal/vpntest/certs.go index d31fb6a9..3b0b4f8f 100644 --- a/internal/vpntest/certs.go +++ b/internal/vpntest/certs.go @@ -78,12 +78,14 @@ S5nL4GaRzx84PB1HWONlh0Wp7KBk2j6Lp0acoJwI2mHJcJoOPpaYiWWYNNTjMv2/ XXNUizTI136liavLslSMoYkjYAun+5HOux/keA1L+lm2XeG06Ew1qS4= -----END CERTIFICATE-----`) +// TestingCert holds key, cert and ca to pass to tests needing to mock certificates. type TestingCert struct { Cert string Key string CA string } +// WriteTEestingCerts will write valid certificates in the passed dir, and return a [TestingCert] and any error. func WriteTestingCerts(dir string) (TestingCert, error) { certFile, err := os.CreateTemp(dir, "tmpfile-") if err != nil { diff --git a/internal/vpntest/packetio.go b/internal/vpntest/packetio.go index f2e74ccb..48bc9d06 100644 --- a/internal/vpntest/packetio.go +++ b/internal/vpntest/packetio.go @@ -3,7 +3,6 @@ package vpntest import ( "fmt" "regexp" - "slices" "strconv" "sync" "time" @@ -11,6 +10,9 @@ import ( "github.com/apex/log" "github.com/ooni/minivpn/internal/bytesx" "github.com/ooni/minivpn/internal/model" + + // TODO: replace with stdlib slices after 1.21 + "golang.org/x/exp/slices" ) // PacketWriter writes packets into a channel. diff --git a/internal/vpntest/packetio_test.go b/internal/vpntest/packetio_test.go index f39f9903..0744f8d5 100644 --- a/internal/vpntest/packetio_test.go +++ b/internal/vpntest/packetio_test.go @@ -3,12 +3,14 @@ package vpntest import ( "bytes" "reflect" - "slices" "testing" "time" "github.com/apex/log" "github.com/ooni/minivpn/internal/model" + + // TODO: replace with stdlib slices after 1.21 + "golang.org/x/exp/slices" ) func TestPacketLog_ACKs(t *testing.T) { diff --git a/pkg/config/vpnoptions.go b/pkg/config/vpnoptions.go index 2b3e320f..06d8ef36 100644 --- a/pkg/config/vpnoptions.go +++ b/pkg/config/vpnoptions.go @@ -77,7 +77,7 @@ const ProtoUDP = Proto("udp") // ErrBadConfig is the generic error returned for invalid config files var ErrBadConfig = errors.New("openvpn: bad config") -// SupportCiphers defines the supported ciphers. +// SupportedCiphers defines the supported ciphers. var SupportedCiphers = []string{ "AES-128-CBC", "AES-192-CBC", diff --git a/pkg/config/vpnoptions_test.go b/pkg/config/vpnoptions_test.go index b9cda01a..1000ac7a 100644 --- a/pkg/config/vpnoptions_test.go +++ b/pkg/config/vpnoptions_test.go @@ -365,7 +365,7 @@ func Test_parseProxyOBFS4(t *testing.T) { opt := &OpenVPNOptions{} obfs4Uri := "obfs4://foobar" o, err := parseProxyOBFS4([]string{obfs4Uri}, opt) - var wantErr error = nil + var wantErr error if !errors.Is(err, wantErr) { t.Errorf("parseProxyOBFS4(): wantErr: %v, got %v", wantErr, err) } @@ -386,7 +386,7 @@ func Test_parseCA(t *testing.T) { t.Run("empty part should fail", func(t *testing.T) { _, err := parseCA([]string{}, &OpenVPNOptions{}, "") - var wantErr error = ErrBadConfig + wantErr := ErrBadConfig if !errors.Is(err, wantErr) { t.Errorf("parseCA(): want %v, got %v", wantErr, err) } @@ -404,7 +404,7 @@ func Test_parseCert(t *testing.T) { t.Run("empty parts should fail", func(t *testing.T) { _, err := parseCert([]string{}, &OpenVPNOptions{}, "") - var wantErr error = ErrBadConfig + wantErr := ErrBadConfig if !errors.Is(err, wantErr) { t.Errorf("parseCert(): want %v, got %v", wantErr, err) } @@ -412,7 +412,7 @@ func Test_parseCert(t *testing.T) { t.Run("non-existent cert should fail", func(t *testing.T) { _, err := parseCert([]string{"/tmp/nonexistent"}, &OpenVPNOptions{}, "") - var wantErr error = ErrBadConfig + var wantErr = ErrBadConfig if !errors.Is(err, wantErr) { t.Errorf("parseCert(): want %v, got %v", wantErr, err) } diff --git a/vpn/bytes.go b/vpn/bytes.go deleted file mode 100644 index c018f364..00000000 --- a/vpn/bytes.go +++ /dev/null @@ -1,157 +0,0 @@ -package vpn - -// -// Functions operating on bytes: -// -// 1. generating random bytes; -// -// 2. OpenVPN options encoding and decoding; -// -// 3. PKCS#7 padding and unpadding. -// - -import ( - "bytes" - "crypto/rand" - "encoding/binary" - "errors" - "fmt" - "io" - "math" -) - -var ( - // errEncodeOption indicates an option encoding error occurred. - errEncodeOption = errors.New("can't encode option") - - // errDecodeOption indicates an option decoding error occurred. - errDecodeOption = errors.New("can't decode option") - - // errPaddingPKCS7 indicates that a PKCS#7 padding error has occurred. - errPaddingPKCS7 = errors.New("PKCS#7 padding error") - - // errUnpaddingPKCS7 indicates that a PKCS#7 unpadding error has occurred. - errUnpaddingPKCS7 = errors.New("PKCS#7 unpadding error") -) - -// genRandomBytes returns an array of bytes with the given size using -// a CSRNG, on success, or an error, in case of failure. -func genRandomBytes(size int) ([]byte, error) { - b := make([]byte, size) - _, err := rand.Read(b) - return b, err -} - -// encodeOptionStringToBytes is used to encode the options string, username and password. -// -// According to the OpenVPN protocol, options are represented as a two-byte word, -// plus the byte representation of the string, null-terminated. -// -// See https://openvpn.net/community-resources/openvpn-protocol/. -// -// This function returns errEncodeOption in case of failure. -func encodeOptionStringToBytes(s string) ([]byte, error) { - if len(s) >= math.MaxUint16 { // Using >= b/c we need to account for the final \0 - return nil, fmt.Errorf("%w:%s", errEncodeOption, "string too large") - } - data := make([]byte, 2) - binary.BigEndian.PutUint16(data, uint16(len(s))+1) - data = append(data, []byte(s)...) - data = append(data, 0x00) - return data, nil -} - -// decodeOptionStringFromBytes returns the string-value for the null-terminated string -// returned by the server when sending remote options to us. -// -// This function returns errDecodeOption on failure. -func decodeOptionStringFromBytes(b []byte) (string, error) { - if len(b) < 2 { - return "", fmt.Errorf("%w: expected at least two bytes", errDecodeOption) - } - length := int(binary.BigEndian.Uint16(b[:2])) - b = b[2:] // skip over the length - // the server sends padding, so we cannot do a strict check - if len(b) < length { - return "", fmt.Errorf("%w: got %d, expected %d", errDecodeOption, len(b), length) - } - if len(b) <= 0 || length == 0 { - return "", fmt.Errorf("%w: zero length encoded option is not possible: %s", errDecodeOption, - "we need at least one byte for the trailing \\0") - } - if b[length-1] != 0x00 { - return "", fmt.Errorf("%w: missing trailing \\0", errDecodeOption) - } - return string(b[:len(b)-1]), nil -} - -// bytesUnpadPKCS7 performs the PKCS#7 unpadding of a byte array. -func bytesUnpadPKCS7(b []byte, blockSize int) ([]byte, error) { - // 1. check whether we can unpad at all - if blockSize > math.MaxUint8 { - return nil, fmt.Errorf("%w: blockSize too large", errUnpaddingPKCS7) - } - // 2. trivial case - if len(b) <= 0 { - return nil, fmt.Errorf("%w: passed empty buffer", errUnpaddingPKCS7) - } - // 4. read the padding size - psiz := int(b[len(b)-1]) - // 5. enforce padding size constraints - if psiz <= 0x00 { - return nil, fmt.Errorf("%w: padding size cannot be zero", errUnpaddingPKCS7) - } - if psiz > blockSize { - return nil, fmt.Errorf("%w: padding size cannot be larger than blockSize", errUnpaddingPKCS7) - } - // 6. compute the padding offset - off := len(b) - psiz - // 7. return unpadded bytes - panicIfFalse(off >= 0 && off <= len(b), "off is out of bounds") - return b[:off], nil -} - -// bytesPadPKCS7 returns the PKCS#7 padding of a byte array. -func bytesPadPKCS7(b []byte, blockSize int) ([]byte, error) { - if blockSize <= 0 { - return nil, fmt.Errorf("%w: %s", errBadInput, "blocksize cannot be negative or zero") - } - // If lth mod blockSize == 0, then the input gets appended a whole block size - // See https://datatracker.ietf.org/doc/html/rfc5652#section-6.3 - if blockSize > math.MaxUint8 { - // This padding method is well defined iff blockSize is less than 256. - return nil, errPaddingPKCS7 - } - psiz := blockSize - len(b)%blockSize - padding := bytes.Repeat([]byte{byte(psiz)}, psiz) - return append(b, padding...), nil -} - -// bufReadUint32 is a convenience function that reads a uint32 from a 4-byte -// buffer, returning an error if the operation failed. -func bufReadUint32(buf *bytes.Buffer) (uint32, error) { - var numBuf [4]byte - _, err := io.ReadFull(buf, numBuf[:]) - if err != nil { - return 0, err - } - return binary.BigEndian.Uint32(numBuf[:]), nil -} - -// bufWriteUint32 is a convenience function that appends to the given buffer -// 4 bytes containing the big-endian representation of the given uint32 value. -func bufWriteUint32(buf *bytes.Buffer, val uint32) { - var numBuf [4]byte - binary.BigEndian.PutUint32(numBuf[:], val) - buf.Write(numBuf[:]) -} - -// bufWriteUint24 is a convenience function that appends to the given buffer -// 3 bytes containing the big-endian representation of the given uint32 value. -// Caller is responsible to ensure the passed value does not overflow the -// maximal capacity of 3 bytes. -func bufWriteUint24(buf *bytes.Buffer, val uint32) { - b := &bytes.Buffer{} - bufWriteUint32(b, val) - buf.Write(b.Bytes()[1:]) -} diff --git a/vpn/bytes_test.go b/vpn/bytes_test.go deleted file mode 100644 index 632fe47a..00000000 --- a/vpn/bytes_test.go +++ /dev/null @@ -1,399 +0,0 @@ -package vpn - -import ( - "errors" - "math" - "testing" - - "github.com/google/go-cmp/cmp" -) - -func Test_genRandomBytes(t *testing.T) { - const smallBuffer = 128 - data, err := genRandomBytes(smallBuffer) - if err != nil { - t.Fatal("unexpected error", err) - } - if len(data) != smallBuffer { - t.Fatal("unexpected returned buffer length") - } -} - -func Test_encodeOptionStringToBytes(t *testing.T) { - type args struct { - s string - } - tests := []struct { - name string - args args - want []byte - wantErr error - }{{ - name: "common case", - args: args{ - s: "test", - }, - want: []byte{0, 5, 116, 101, 115, 116, 0}, - wantErr: nil, - }, { - name: "encoding empty string", - args: args{ - s: "", - }, - want: []byte{0, 1, 0}, - wantErr: nil, - }, { - name: "encoding a very large string", - args: args{ - s: string(make([]byte, 1<<16)), - }, - want: nil, - wantErr: errEncodeOption, - }} - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := encodeOptionStringToBytes(tt.args.s) - if !errors.Is(err, tt.wantErr) { - t.Fatalf("encodeOptionStringToBytes() error = %v, wantErr %v", err, tt.wantErr) - } - if diff := cmp.Diff(tt.want, got); diff != "" { - t.Fatal(diff) - } - }) - } -} - -func Test_decodeOptionStringFromBytes(t *testing.T) { - type args struct { - b []byte - } - tests := []struct { - name string - args args - want string - wantErr error - }{{ - name: "with zero-length input", - args: args{ - b: nil, - }, - want: "", - wantErr: errDecodeOption, - }, { - name: "with input length equal to one", - args: args{ - b: []byte{0x00}, - }, - want: "", - wantErr: errDecodeOption, - }, { - name: "with input length equal to two", - args: args{ - b: []byte{0x00, 0x00}, - }, - want: "", - wantErr: errDecodeOption, - }, { - name: "with length mismatch and length < actual length", - args: args{ - b: []byte{ - 0x00, 0x03, // length = 3 - 0x61, 0x61, 0x61, 0x61, 0x61, // aaaaa - 0x00, // trailing zero - }, - }, - want: "", - wantErr: errDecodeOption, - }, { - name: "with length mismatch and length > actual length", - args: args{ - b: []byte{ - 0x00, 0x44, // length = 68 - 0x61, 0x61, 0x61, 0x61, 0x61, // aaaaa - 0x00, // trailing zero - }, - }, - want: "", - wantErr: errDecodeOption, - }, { - name: "with missing trailing \\0", - args: args{ - b: []byte{ - 0x00, 0x05, // length = 5 - 0x61, 0x61, 0x61, 0x61, 0x61, // aaaaa - }, - }, - want: "", - wantErr: errDecodeOption, - }, { - name: "with valid input", - args: args{ - b: []byte{ - 0x00, 0x06, // length = 6 - 0x61, 0x61, 0x61, 0x61, 0x61, // aaaaa - 0x00, // trailing zero - }, - }, - want: "aaaaa", - wantErr: nil, - }} - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := decodeOptionStringFromBytes(tt.args.b) - if !errors.Is(err, tt.wantErr) { - t.Fatalf("decodeOptionStringFromBytes() error = %v, wantErr %v", err, tt.wantErr) - } - if diff := cmp.Diff(tt.want, got); diff != "" { - t.Fatal(diff) - } - }) - } -} - -func Test_bytesUnpadPKCS7(t *testing.T) { - type args struct { - b []byte - blockSize int - } - tests := []struct { - name string - args args - want []byte - wantErr error - }{{ - name: "with too-large blockSize", - args: args{ - b: []byte{0x00, 0x00, 0x00}, - blockSize: math.MaxUint8 + 1, // too large - }, - want: nil, - wantErr: errUnpaddingPKCS7, - }, { - name: "with zero-length array", - args: args{ - b: nil, - blockSize: 2, - }, - want: nil, - wantErr: errUnpaddingPKCS7, - }, { - name: "with 0x00 used as padding", - args: args{ - b: []byte{ - 0x61, 0x61, // block ("aa") - 0x00, 0x00, // padding - }, - blockSize: 2, - }, - want: nil, - wantErr: errUnpaddingPKCS7, - }, { - name: "with padding larger than block size", - args: args{ - b: []byte{ - 0x61, 0x61, // block ("aa") - 0x03, 0x03, // padding - }, - blockSize: 2, - }, - want: nil, - wantErr: errUnpaddingPKCS7, - }, { - name: "with blocksize == 4 and len(data) == 0", - args: args{ - b: []byte{ - 0x04, 0x04, 0x04, 0x04, // padding - }, - blockSize: 4, - }, - want: []byte{}, - wantErr: nil, - }, { - name: "with blocksize == 4 and len(data) == 1", - args: args{ - b: []byte{ - 0xde, // data - 0x03, 0x03, 0x03, // padding - }, - blockSize: 4, - }, - want: []byte{0xde}, - wantErr: nil, - }, { - name: "with blocksize == 4 and len(data) == 2", - args: args{ - b: []byte{ - 0xde, 0xad, // data - 0x02, 0x02, // padding - }, - blockSize: 4, - }, - want: []byte{0xde, 0xad}, - wantErr: nil, - }, { - name: "with blocksize == 4 and len(data) == 3", - args: args{ - b: []byte{ - 0xde, 0xad, 0xbe, // data - 0x01, // padding - }, - blockSize: 4, - }, - want: []byte{0xde, 0xad, 0xbe}, - wantErr: nil, - }, { - name: "with blocksize == 4 and len(data) == 4", - args: args{ - b: []byte{ - 0xde, 0xad, 0xbe, 0xff, // data - 0x04, 0x04, 0x04, 0x04, // padding - }, - blockSize: 4, - }, - want: []byte{0xde, 0xad, 0xbe, 0xff}, - wantErr: nil, - }, { - name: "with blocksize == 4 and len(data) == 5", - args: args{ - b: []byte{ - 0xde, 0xad, 0xbe, 0xff, 0xab, // data - 0x03, 0x03, 0x03, // padding - }, - blockSize: 4, - }, - want: []byte{0xde, 0xad, 0xbe, 0xff, 0xab}, - wantErr: nil, - }} - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := bytesUnpadPKCS7(tt.args.b, tt.args.blockSize) - if !errors.Is(err, tt.wantErr) { - t.Fatalf("bytesUnpadPKCS7() error = %v, wantErr %v", err, tt.wantErr) - } - if diff := cmp.Diff(tt.want, got); diff != "" { - t.Fatal(diff) - } - }) - } -} - -func Test_bytesPadPKCS7(t *testing.T) { - type args struct { - b []byte - blockSize int - } - tests := []struct { - name string - args args - want []byte - wantErr error - }{{ - name: "with too-large block size", - args: args{ - b: []byte{0x00, 0x00, 0x00}, - blockSize: math.MaxUint8 + 1, - }, - want: nil, - wantErr: errPaddingPKCS7, - }, { - name: "with negative block size", - args: args{ - b: []byte{0x00, 0x00, 0x00}, - blockSize: -1, - }, - want: nil, - wantErr: errBadInput, - }, { - name: "with blockSize == 4 and len(data) == 0", - args: args{ - b: nil, - blockSize: 4, - }, - want: []byte{ - 0x04, 0x04, 0x04, 0x04, // only padding - }, - wantErr: nil, - }, { - name: "with blockSize == 4 and len(data) == 1", - args: args{ - b: []byte{ - 0xde, // len(data) == 1 - }, - blockSize: 4, - }, - want: []byte{ - 0xde, // data - 0x03, 0x03, 0x03, // padding - }, - wantErr: nil, - }, { - name: "with blockSize == 4 and len(data) == 2", - args: args{ - b: []byte{ - 0xde, 0xad, // len(data) == 2 - }, - blockSize: 4, - }, - want: []byte{ - 0xde, 0xad, // data - 0x02, 0x02, // padding - }, - wantErr: nil, - }, { - name: "with blockSize == 4 and len(data) == 3", - args: args{ - b: []byte{ - 0xde, 0xad, 0xbe, // len(data) == 3 - }, - blockSize: 4, - }, - want: []byte{ - 0xde, 0xad, 0xbe, //data - 0x01, // padding - }, - wantErr: nil, - }, { - name: "with blockSize == 4 and len(data) == 4", - args: args{ - b: []byte{ - 0xde, 0xad, 0xbe, 0xef, // len(data) == 4 - }, - blockSize: 4, - }, - want: []byte{ - 0xde, 0xad, 0xbe, 0xef, // data - 0x04, 0x04, 0x04, 0x04, // padding - }, - wantErr: nil, - }, { - name: "with blocksize == 4 and len(data) == 5", - args: args{ - b: []byte{ - 0xde, 0xad, 0xbe, 0xef, 0xab, // len(data) == 5 - }, - blockSize: 4, - }, - want: []byte{ - 0xde, 0xad, 0xbe, 0xef, 0xab, // data - 0x03, 0x03, 0x03, // padding - }, - wantErr: nil, - }} - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := bytesPadPKCS7(tt.args.b, tt.args.blockSize) - if !errors.Is(err, tt.wantErr) { - t.Fatalf("bytesPadPKCS7() error = %v, wantErr %v", err, tt.wantErr) - } - if diff := cmp.Diff(tt.want, got); diff != "" { - t.Fatal(diff) - } - }) - } -} - -// Regression test for MIV-01-002 -func Test_Crash_bytesPadPCKS7(t *testing.T) { - bytesPadPKCS7(nil, 0) - bytesPadPKCS7([]byte{0xaa, 0xab}, -1) -} diff --git a/vpn/client.go b/vpn/client.go deleted file mode 100644 index 94dc2ebb..00000000 --- a/vpn/client.go +++ /dev/null @@ -1,257 +0,0 @@ -package vpn - -// -// Client initialization and public methods -// - -import ( - "context" - "errors" - "fmt" - "net" - "strings" - "sync" - "time" -) - -var ( - // ErrDialError is a generic error while dialing - ErrDialError = errors.New("dial error") - - // ErrAlreadyStarted is returned when trying to start the tunnel more than once - ErrAlreadyStarted = errors.New("tunnel already started") - - // ErrNotReady is returned when a Read/Write attempt is made before the tunnel is ready. - ErrNotReady = errors.New("tunnel not ready") -) - -// tunnelInfo holds state about the VPN tunnelInfo that has longer duration than a -// given session. This information is gathered at different stages: -// - during the handshake (mtu). -// - after server pushes config options(ip, gw). -type tunnelInfo struct { - mtu int - ip string - gw string - peerID int -} - -// vpnClient is a net.Conn that uses the VPN tunnel. It is a net.Conn with an -// additional `Start()` method. -type vpnClient interface { - net.Conn - Start(ctx context.Context) error -} - -type dialContextFn func(context.Context, string, string) (net.Conn, error) - -// DialerContext is anything that features a net.Dialer-like DialContext method. -type DialerContext interface { - DialContext(context.Context, string, string) (net.Conn, error) -} - -// Client implements the OpenVPN protocol. A Client object satisfies the -// net.Conn interface. plus Start(). -// The Read and Write operations send and receive bytes to and from the tunnel -// - they are writing to and reading from the OpenVPN Data channel, with the -// control channel being handled in the background. -// To Dial sockets through the Tunnel, you should use the NewTunDialer constructor, -// that accepts a Client object. -// Client is only intended to be directly instantiated by users that need a -// finer control of the protocol steps, or for the case in which you need the -// equivalent of raw sockets. -type Client struct { - Opts *Options - Dialer DialerContext - - // If this channel is not nil, a series of Event* will be - // sent to the channel. The user of the Client can set a - // channel externally to subscribe to discrete transitions. A sufficiently - // buffered-channel should be provided to avoid losing events (~10 - // events should do it). - EventListener chan uint8 - - Log Logger - - conn net.Conn - mux vpnMuxer - tunInfo *tunnelInfo - - // muxerFactoryFn allows to inject a different factory - // for testing. - muxerFactoryFn muxFactory - - startOnce sync.Once - startErr error -} - -var _ net.Conn = &Client{} // Ensure that we implement net.Conn -var _ vpnClient = &Client{} // Ensure that we implement vpnClient - -// NewClientFromOptions returns a Client configured with the given Options. -func NewClientFromOptions(opt *Options) *Client { - if opt == nil { - return &Client{} - } - return &Client{ - Opts: opt, - tunInfo: &tunnelInfo{}, - Dialer: &net.Dialer{}, - } -} - -// -// observability -// - -// emit sends the passed stage into any configured EventListener. -func (c *Client) emit(stage uint8) { - select { - case c.EventListener <- stage: - default: - // don't deliver - } -} - -// Start starts the OpenVPN tunnel. -func (c *Client) Start(ctx context.Context) error { - c.startOnce.Do(func() { - c.startErr = c.start(ctx) - }) - return c.startErr -} - -func (c *Client) start(ctx context.Context) error { - c.emit(EventReady) - - conn, err := c.dial(ctx) - if err != nil { - return err - } - - c.emit(EventDialDone) - - muxFactory := c.muxerFactory() - mux, err := muxFactory(conn, c.Opts, c.tunInfo) - if err != nil { - conn.Close() - return err - } - - mux.SetEventListener(c.EventListener) - - err = mux.Handshake(ctx) - if err != nil { - conn.Close() - return err - } - - c.emit(EventHandshakeDone) - - c.conn = conn - c.mux = mux - return nil -} - -// muxerFactory returns the default muxer Factory, or any other one that has -// been injected into the `muxerFactoryFn` private field in Client for testing. -func (c *Client) muxerFactory() muxFactory { - muxFactory := newMuxerFromOptions - if c.muxerFactoryFn == nil { - return muxFactory - } - return c.muxerFactoryFn -} - -// dial opens a TCP/UDP socket against the remote, and creates an internal -// data channel. It is the second step in an OpenVPN connection (out of five). -// (In UDP mode no network connection is done at this step). -func (c *Client) dial(ctx context.Context) (net.Conn, error) { - if c.Opts == nil { - return nil, fmt.Errorf("%w:%s", errBadInput, "nil options") - - } - var proto string - switch c.Opts.Proto { - case UDPMode: - proto = protoUDP.String() - case TCPMode: - proto = protoTCP.String() - default: - return nil, fmt.Errorf("%w: unknown proto %d", errBadInput, c.Opts.Proto) - - } - - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - msg := fmt.Sprintf("Connecting to %s:%s with proto %s", - c.Opts.Remote, c.Opts.Port, strings.ToUpper(proto)) - logger.Info(msg) - - conn, err := c.Dialer.DialContext(ctx, proto, net.JoinHostPort(c.Opts.Remote, c.Opts.Port)) - if err != nil { - return nil, fmt.Errorf("%w: %s", ErrDialError, err) - } - return conn, nil - } -} - -// Write sends bytes into the tunnel. -func (c *Client) Write(b []byte) (int, error) { - return c.mux.Write(b) -} - -// Read reads bytes from the tunnel. -func (c *Client) Read(b []byte) (int, error) { - if c.mux == nil { - return 0, ErrNotReady - - } - return c.mux.Read(b) -} - -// Close closes the tunnel connection. -func (c *Client) Close() error { - if c.conn != nil { - return c.conn.Close() - } - return nil -} - -// LocalAddr returns the local address on the tunnel virtual device, if known. -// In case the Addr is not known, a zero-value net.Addr will be returned. -func (c *Client) LocalAddr() net.Addr { - addr := &net.IPAddr{} - if c.tunInfo != nil { - if ip := net.ParseIP(c.tunInfo.ip); ip != nil { - addr.IP = ip - } - } - return addr -} - -// RemoteAddr returns the address of the tun interface of the tunnel gateway, -// if known. In case the Addr is not known, a zero-value net.Addr will be returned. -func (c *Client) RemoteAddr() net.Addr { - addr := &net.IPAddr{} - if c.tunInfo != nil { - if ip := net.ParseIP(c.tunInfo.gw); ip != nil { - addr.IP = ip - } - } - return addr -} - -func (c *Client) SetDeadline(t time.Time) error { - return c.conn.SetDeadline(t) -} - -func (c *Client) SetReadDeadline(t time.Time) error { - return c.conn.SetReadDeadline(t) -} - -func (c *Client) SetWriteDeadline(t time.Time) error { - return c.conn.SetWriteDeadline(t) -} diff --git a/vpn/client_test.go b/vpn/client_test.go deleted file mode 100644 index 1e6975e3..00000000 --- a/vpn/client_test.go +++ /dev/null @@ -1,286 +0,0 @@ -package vpn - -import ( - "context" - "errors" - "net" - "reflect" - "testing" - "time" -) - -// the name is confusing, but we're just getting a generic mocked conn -// that serves as witness of calls -// TODO can copy the mockTLSConn here to avoid confusion with names and -// decouple these tests from those. -func makeTestingClientConn() (*Client, *MockTLSConn) { - c := makeConnForTransportTest() - cl := &Client{} - cl.conn = c - return cl, c -} - -func TestNewClientFromOptions(t *testing.T) { - t.Run("proper options does not fail getting client", func(t *testing.T) { - randomFn = func(int) ([]byte, error) { - return []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, nil - } - opts := makeTestingOptions(t, "AES-128-GCM", "sha512") - _ = NewClientFromOptions(opts) - }) - - t.Run("nil options return empty client", func(t *testing.T) { - c := NewClientFromOptions(nil) - if !reflect.DeepEqual(c, &Client{}) { - t.Error("Client.NewClientFromOptions(): expected empty client with nil options") - } - }) -} - -type mockMuxerForClient struct { - muxer - writeCalled bool - readCalled bool -} - -func (mm *mockMuxerForClient) Read([]byte) (int, error) { - mm.readCalled = true - return 42, nil -} - -func (mm *mockMuxerForClient) Write(b []byte) (int, error) { - mm.writeCalled = true - return len(b), nil -} - -func mockMuxerFactory() muxFactory { - fn := func(net.Conn, *Options, *tunnelInfo) (vpnMuxer, error) { - m := &mockMuxerWithDummyHandshake{} - return m, nil - } - return fn -} - -func TestClient_Write(t *testing.T) { - // test that call to write calls the muxer method - cl, _ := makeTestingClientConn() - mux := &mockMuxerForClient{} - cl.mux = mux - _, err := cl.Write([]byte("alles ist green")) - if err != nil { - t.Errorf("Client.Write(): expected err = nil, got %v", err) - } - if !mux.writeCalled { - t.Errorf("Client.Write(): client.mux.Write() not called") - } -} - -func TestClient_Read(t *testing.T) { - cl, _ := makeTestingClientConn() - cl.mux = nil - b := make([]byte, 255) - _, err := cl.Read(b) - if !errors.Is(err, ErrNotReady) { - t.Errorf("Client.Read(): nil mux, expected error %v, got %v ", errBadInput, err) - } - - // test that call to read calls the muxer method - cl, _ = makeTestingClientConn() - mux := &mockMuxerForClient{} - cl.mux = mux - b = make([]byte, 255) - _, err = cl.Read(b) - if err != nil { - t.Errorf("Client.Read(): expected err = nil, got %v", err) - } - if !mux.readCalled { - t.Errorf("Client.Read(): client.mux.Read() not called") - } -} - -func TestClient_LocalAddr(t *testing.T) { - cl, _ := makeTestingClientConn() - cl.tunInfo = nil - a := cl.LocalAddr() - if a.String() != "" { - t.Errorf("Client.LocalAddr(): expected empty string, got %v", a.String()) - } -} - -func TestClient_RemoteAddr(t *testing.T) { - cl, _ := makeTestingClientConn() - a := cl.RemoteAddr() - if a.String() != "" { - t.Errorf("Client.RemoteAddr(): expected empty string, got %v", a.String()) - } -} - -// for the tests that test the delegation of methods to the underlying conn we -// can reuse the mock used in transport_test - -func TestClient_SetDeadline(t *testing.T) { - cl, conn := makeTestingClientConn() - err := cl.SetDeadline(time.Now().Add(time.Second)) - if err != nil { - t.Errorf("Client.SetDeadline() error = %v, want = nil", err) - } - if !conn.setDeadlineCalled { - t.Error("Client.SetDeadline(): conn.SetDeadline() not called") - } - -} - -func TestClient_SetReadDeadline(t *testing.T) { - cl, conn := makeTestingClientConn() - err := cl.SetReadDeadline(time.Now().Add(time.Second)) - if err != nil { - t.Errorf("Client.SetDeadline() error = %v, want = nil", err) - } - if !conn.setReadDeadlineCalled { - t.Error("Client.SetReadDeadline(): conn.SetReadDeadline() not called") - } -} - -func TestClient_SetWriteDeadline(t *testing.T) { - cl, conn := makeTestingClientConn() - err := cl.SetWriteDeadline(time.Now().Add(time.Second)) - if err != nil { - t.Errorf("Client.SetWriteDeadline() error = %v, want = nil", err) - } - if !conn.setWriteDeadlineCalled { - t.Error("Client.SetWriteDeadline(): conn.SetWriteReadDeadline() not called") - } -} - -func TestClient_Close(t *testing.T) { - cl, conn := makeTestingClientConn() - err := cl.Close() - if err != nil { - t.Errorf("Client.Close() error = %v, want = nil", err) - } - if !conn.closedCalled { - t.Error("Client.Close(): conn.Close() not called") - } -} - -type badDialer struct{} - -func (bd *badDialer) DialContext(context.Context, string, string) (net.Conn, error) { - return nil, errors.New("cannot dial") -} - -func TestClient_dialFailsWithBadOptions(t *testing.T) { - c := &Client{} - _, err := c.dial(context.Background()) - wantErr := errBadInput - if !errors.Is(err, wantErr) { - t.Error("Client.Dial(): should fail with nil options") - } - - c = &Client{ - Opts: &Options{ - Proto: 3, - }, - } - _, err = c.dial(context.Background()) - wantErr = errBadInput - if !errors.Is(err, wantErr) { - t.Error("Client.Dial(): should fail with bad proto") - } - - c = &Client{ - Opts: &Options{ - Proto: TCPMode, - }, - Dialer: &badDialer{}, - } - _, err = c.dial(context.Background()) - wantErr = ErrDialError - if !errors.Is(err, wantErr) { - t.Errorf("Client.Dial(): should fail with ErrDialError, err = %v", err) - } -} - -func TestCient_DialRaisesError(t *testing.T) { - c := &Client{ - Opts: &Options{ - Proto: TCPMode, - }, - } - ctx := context.Background() - ctx, cancel := context.WithCancel(ctx) - cancel() - _, err := c.dial(ctx) - if err != context.Canceled { - t.Errorf("Client.Dial(): expected context.Canceled, err = %v", err) - } -} - -func TestClient_StartRaisesDialError(t *testing.T) { - c := &Client{ - Opts: &Options{ - Proto: TCPMode, - }, - Dialer: &badDialer{}, - } - err := c.Start(context.Background()) - if !errors.Is(err, ErrDialError) { - t.Errorf("Client.Start(): expected = %v, got = %v", ErrDialError, err) - } -} - -func TestClientStartWithMockedMuxerFactory(t *testing.T) { - c := &Client{ - Opts: &Options{ - Proto: TCPMode, - }, - Dialer: &mockedDialerContext{}, - } - c.muxerFactoryFn = mockMuxerFactory() - err := c.Start(context.Background()) - if err != nil { - t.Errorf("expected no error, got %v", err) - } -} - -func TestClient_emitSendsToListener(t *testing.T) { - t.Run("emit writes event if listener not null", func(t *testing.T) { - l := make(chan uint8, 2) - c := &Client{} - c.EventListener = l - sent := uint8(2) - c.emit(sent) - got := <-l - if got != sent { - t.Errorf("expected %v, got %v", sent, got) - } - }) - t.Run("emit is a noop if evenlistener not set", func(t *testing.T) { - c := &Client{} - sent := uint8(2) - c.emit(sent) - if c.EventListener != nil { - t.Errorf("expected EventListener to be nil") - } - }) - t.Run("listener receives several events", func(t *testing.T) { - l := make(chan uint8, 5) - c := &Client{} - c.EventListener = l - received := []uint8{} - sent := []uint8{1, 2, 3, 4, 5} - for _, i := range sent { - c.emit(i) - } - for _ = range sent { - got := <-l - received = append(received, got) - } - for i := range sent { - if sent[i] != received[i] { - t.Errorf("at [%d]: expected %v, got %v", i, sent, received) - return - } - } - }) -} diff --git a/vpn/control.go b/vpn/control.go deleted file mode 100644 index 2f02f8d5..00000000 --- a/vpn/control.go +++ /dev/null @@ -1,278 +0,0 @@ -package vpn - -// -// OpenVPN control channel -// - -import ( - "bytes" - "encoding/binary" - "encoding/hex" - "errors" - "fmt" - "math" - "net" - "sync" -) - -var ( - errBadReset = errors.New("bad reset packet") - errExpiredKey = errors.New("max packet id reached") -) - -var ( - serverPushReply = []byte("PUSH_REPLY") - serverBadAuth = []byte("AUTH_FAILED") -) - -// session keeps mutable state related to an OpenVPN session. -type session struct { - RemoteSessionID sessionID - LocalSessionID sessionID - keys []*dataChannelKey - keyID int - localPacketID packetID - lastACK packetID - ackQueue chan *packet - mu sync.Mutex - Log Logger -} - -// newSession returns a session ready to be used. -func newSession() (*session, error) { - key0 := &dataChannelKey{} - ackQueue := make(chan *packet, 100) - session := &session{ - keys: []*dataChannelKey{key0}, - ackQueue: ackQueue, - } - - randomBytes, err := randomFn(8) - if err != nil { - return session, err - } - - // in go 1.17, one could do: - // localSession := (*sessionID)(lsid) - var localSession sessionID - copy(localSession[:], randomBytes[:8]) - session.LocalSessionID = localSession - - localKey, err := newKeySource() - if err != nil { - return session, err - } - - k, err := session.ActiveKey() - if err != nil { - return session, err - } - k.local = localKey - return session, nil -} - -// ActiveKey returns the dataChannelKey that is actively being used. -func (s *session) ActiveKey() (*dataChannelKey, error) { - if len(s.keys) < s.keyID { - return nil, fmt.Errorf("%w: %s", errDataChannelKey, "no such key id") - } - dck := s.keys[s.keyID] - return dck, nil -} - -// localPacketID returns an unique Packet ID. It increments the counter. -// In the future, this call could detect (or warn us) when we're approaching -// the key end of life. -func (s *session) LocalPacketID() (packetID, error) { - s.mu.Lock() - defer s.mu.Unlock() - pid := s.localPacketID - if pid == math.MaxUint32 { - // we reached the max packetID, increment will overflow - return 0, errExpiredKey - } - s.localPacketID++ - return pid, nil -} - -// UpdateLastACK will update the internal variable for the last acknowledged -// packet to the passed packetID, only if packetID is greater than the lastACK. -func (s *session) UpdateLastACK(newPacketID packetID) error { - s.mu.Lock() - defer s.mu.Unlock() - if s.lastACK == math.MaxUint32 { - return errExpiredKey - } - if s.lastACK != 0 && newPacketID <= s.lastACK { - logger.Warnf("tried to write ack %d; last was %d", newPacketID, s.lastACK) - } - s.lastACK = newPacketID - return nil -} - -// isNextPacket returns true if the packetID is the next integer -// from the last acknowledged packet. -func (s *session) isNextPacket(p *packet) bool { - s.mu.Lock() - defer s.mu.Unlock() - if p == nil { - return false - } - return p.id-s.lastACK == 1 -} - -// control implements the controlHandler interface. -// Like for true pirates, there is no state in control. -type control struct{} - -// SendHardReset sends a control packet with the HardResetClientv2 header, -// over the passed net.Conn. -func (c *control) SendHardReset(conn net.Conn, s *session) error { - _, err := sendControlPacket(conn, s, pControlHardResetClientV2, 0, []byte("")) - return err -} - -// ParseHardReset extracts the sessionID from a hard-reset server response, and -// an error if the operation was not successful. -func (c *control) ParseHardReset(b []byte) (sessionID, error) { - p, err := newServerHardReset(b) - if err != nil { - return sessionID{}, err - } - return parseServerHardResetPacket(p) -} - -// PushRequest returns a byte array with the PUSH_REQUEST command. -func (c *control) PushRequest() []byte { - var out bytes.Buffer - out.Write([]byte("PUSH_REQUEST")) - out.WriteByte(0x00) - return out.Bytes() -} - -// ReadReadPushResponse reads a byte array returned from the server, -// as the response to a Push Request, and returns a string containing the -// tunnel IP. -// For now, this is a single string containing _only_ the tunnel ip, -// but we might want to pass a pointer to the tunnel struct in the -// future. -func (*control) ReadPushResponse(b []byte) map[string][]string { - return pushedOptionsAsMap(b) -} - -// ControlMessage returns a byte array containing a message over the control -// channel. -// This is not a P_CONTROL, but a message over the TLS encrypted channel. -func (c *control) ControlMessage(s *session, opt *Options) ([]byte, error) { - key, err := s.ActiveKey() - if err != nil { - return []byte{}, err - } - return encodeClientControlMessageAsBytes(key.local, opt) -} - -// ReadControlMessage reads a control message with authentication result data. -// it returns the remote key, remote options and an error if we cannot parse -// the data. -func (c *control) ReadControlMessage(b []byte) (*keySource, string, error) { - cm := newServerControlMessageFromBytes(b) - return parseServerControlMessage(cm) -} - -// SendACK builds an ACK control packet for the given packetID, and writes it -// over the passed connection. It returns an error if the operation cannot be -// completed successfully. -func (c *control) SendACK(conn net.Conn, s *session, pid packetID) error { - return sendACKFn(conn, s, pid) -} - -// sendACK is used by controlHandler.SendACK() and by TLSConn.Read() -func sendACK(conn net.Conn, s *session, pid packetID) error { - panicIfFalse(len(s.RemoteSessionID) != 0, "tried to ack with null remote") - - p := newACKPacket(pid, s) - payload := p.Bytes() - payload = maybeAddSizeFrame(conn, payload) - - _, err := conn.Write(payload) - if err != nil { - return err - } - - logger.Debug(fmt.Sprintln("write ack:", pid)) - logger.Debug(fmt.Sprintln(hex.Dump(payload))) - - return s.UpdateLastACK(pid) -} - -var sendACKFn = sendACK - -var _ controlHandler = &control{} // Ensure that we implement controlHandler - -// sendControlPacket crafts a control packet with the given opcode and payload, -// and writes it to the passed net.Conn. -func sendControlPacket(conn net.Conn, s *session, opcode int, ack int, payload []byte) (n int, err error) { - if s == nil { - return 0, fmt.Errorf("%w:%s", errBadInput, "nil session") - } - p := newPacketFromPayload(uint8(opcode), 0, payload) - p.localSessionID = s.LocalSessionID - - p.id, err = s.LocalPacketID() - if err != nil { - return 0, err - } - out := p.Bytes() - - out = maybeAddSizeFrame(conn, out) - - logger.Debug(fmt.Sprintf("control write: (%d bytes)\n", len(out))) - logger.Debug(fmt.Sprintln(hex.Dump(out))) - return conn.Write(out) -} - -// isControlMessage returns a boolean indicating whether the header of a -// payload indicates a control message. -func isControlMessage(b []byte) bool { - if len(b) < 4 { - return false - } - return bytes.Equal(b[:4], controlMessageHeader) -} - -// maybeAddSizeFrame prepends a two-byte header containing the size of the -// payload if the network type for the passed net.Conn is not UDP (assumed to -// be TCP). -func maybeAddSizeFrame(conn net.Conn, payload []byte) []byte { - switch conn.LocalAddr().Network() { - case "udp", "udp4", "udp6": - // nothing to do for UDP - return payload - case "tcp", "tcp4", "tcp6": - length := make([]byte, 2) - binary.BigEndian.PutUint16(length, uint16(len(payload))) - return append(length, payload...) - default: - return []byte{} - } -} - -// isBadAuthReply returns true if the passed payload is a "bad auth" server -// response; false otherwise. -func isBadAuthReply(b []byte) bool { - l := len(serverBadAuth) - if len(b) < l { - return false - } - return bytes.Equal(b[:l], serverBadAuth) -} - -// isPushReply returns true if the passed payload is a "push reply" server -// response; false otherwise. -func isPushReply(b []byte) bool { - l := len(serverPushReply) - if len(b) < l { - return false - } - return bytes.Equal(b[:l], serverPushReply) -} diff --git a/vpn/control_test.go b/vpn/control_test.go deleted file mode 100644 index 56653bb4..00000000 --- a/vpn/control_test.go +++ /dev/null @@ -1,286 +0,0 @@ -package vpn - -import ( - "bytes" - "errors" - "math" - "net" - "reflect" - "testing" - - "github.com/ooni/minivpn/vpn/mocks" -) - -func Test_newSession(t *testing.T) { - tests := []struct { - name string - want *session - wantErr bool - }{ - {"get session", &session{}, false}, - } - // TODO(ainghazal): get smarter and use test values (turn sesion into an interface). - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - _, err := newSession() - if (err != nil) != tt.wantErr { - t.Errorf("newSession() error = %v, wantErr %v", err, tt.wantErr) - return - } - }) - } -} - -func Test_maybeAddSizeFrame(t *testing.T) { - - type args struct { - conn net.Conn - payload []byte - } - tests := []struct { - name string - args args - want []byte - }{ - - // FIXME ---- fix these tests --- - /* - { - name: "udp", - args: args{ - makeTestinConnFromNetwork("udp"), - []byte{0xff, 0xfe, 0xfd}, - }, - want: []byte{0xff, 0xfe, 0xfd}, - }, - { - name: "tcp", - args: args{ - makeTestinConnFromNetwork("udp"), - []byte{0xff, 0xfe, 0xfd}, - }, - want: []byte{0x00, 0x03, 0xff, 0xfe, 0xfd}, - }, - */ - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := maybeAddSizeFrame(tt.args.conn, tt.args.payload); !reflect.DeepEqual(got, tt.want) { - t.Errorf("maybeAddSizeFrame() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_session_ActiveKey(t *testing.T) { - s := &session{ - keys: make([]*dataChannelKey, 2), - keyID: 10, - } - _, err := s.ActiveKey() - wantErr := errDataChannelKey - if !errors.Is(err, wantErr) { - t.Errorf("session.ActiveKey() = got err %v, want %v", err, wantErr) - } -} - -func Test_session_LocalPacketID(t *testing.T) { - type fields struct { - RemoteSessionID sessionID - LocalSessionID sessionID - keys []*dataChannelKey - keyID int - localPacketID packetID - lastACK packetID - ackQueue chan *packet - } - - tests := []struct { - name string - fields fields - want packetID - wantErr error - }{ - { - "return arbitrary value", - fields{localPacketID: packetID(42)}, - packetID(42), - nil, - }, - { - "return zero", - fields{localPacketID: packetID(0)}, - packetID(0), - nil, - }, - { - "overflow", - fields{localPacketID: math.MaxUint32}, - packetID(0), - errExpiredKey, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - s := &session{ - localPacketID: tt.fields.localPacketID, - } - if got, err := s.LocalPacketID(); got != tt.want || err != tt.wantErr { - t.Errorf("session.LocalPacketID() = %v, want %v", got, tt.want) - } - }) - } - - // increments - val := packetID(1000) - s := &session{localPacketID: packetID(1000)} - - if got, _ := s.LocalPacketID(); got != val { - t.Errorf("session.LocalPacketID() = %v, want %v", got, val) - } - val++ - if got, _ := s.LocalPacketID(); got != val { - t.Errorf("session.LocalPacketID() = %v, want %v", got, val) - } - val++ - if got, _ := s.LocalPacketID(); got != val { - t.Errorf("session.LocalPacketID() = %v, want %v", got, val) - } -} - -func Test_session_isNextPacket(t *testing.T) { - type fields struct { - lastACK packetID - } - type args struct { - p *packet - } - tests := []struct { - name string - fields fields - args args - want bool - }{ - { - "is next", - fields{lastACK: packetID(0)}, - args{&packet{id: packetID(1)}}, - true, - }, - { - "is two more", - fields{lastACK: packetID(0)}, - args{&packet{id: packetID(2)}}, - false, - }, - { - "is lesser", - fields{lastACK: packetID(100)}, - args{&packet{id: packetID(99)}}, - false, - }, - { - "is nil", - fields{lastACK: packetID(100)}, - args{nil}, - false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - s := &session{ - lastACK: tt.fields.lastACK, - } - if got := s.isNextPacket(tt.args.p); got != tt.want { - t.Errorf("session.isNextPacket() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_isBadAuthReply(t *testing.T) { - type args struct { - b []byte - } - tests := []struct { - name string - args args - want bool - }{ - {"bad_auth", args{[]byte("AUTH_FAILED")}, true}, - {"too_large", args{[]byte("AUTH_FAILEDAAAAAA")}, true}, - {"too_short", args{[]byte("AAA")}, false}, - {"empty", args{[]byte("")}, false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := isBadAuthReply(tt.args.b); got != tt.want { - t.Errorf("isBadAuthReply() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_isPushReply(t *testing.T) { - type args struct { - b []byte - } - tests := []struct { - name string - args args - want bool - }{ - {"push_reply", args{serverPushReply}, true}, - {"too_large", args{[]byte("PUSH_REPLYAAA")}, true}, - {"too_short", args{[]byte("AAA")}, false}, - {"empty", args{[]byte("")}, false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := isPushReply(tt.args.b); got != tt.want { - t.Errorf("isPushReply() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_sendControlPacket(t *testing.T) { - _, err := sendControlPacket(&mocks.Conn{}, nil, 1, 1, []byte("")) - wantErr := errBadInput - if !errors.Is(err, wantErr) { - t.Errorf("sendControlPacket(): empty session should fail with err=%v, got=%v", wantErr, err) - } -} - -func Test_isControlMessage(t *testing.T) { - type args struct { - b []byte - } - tests := []struct { - name string - args args - want bool - }{ - {"good_control", args{controlMessageHeader}, true}, - {"bad_control", args{[]byte{0x00, 0x00, 0x00, 0x01}}, false}, - {"too_short", args{[]byte{0x00}}, false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := isControlMessage(tt.args.b); got != tt.want { - t.Errorf("isControlMessage() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_control_PushRequest(t *testing.T) { - c := &control{} - got := c.PushRequest() - if !bytes.Equal(got[:len(got)-1], []byte("PUSH_REQUEST")) { - t.Errorf("control_PushRequest() = %v", got) - } - if got[len(got)-1] != 0x00 { - t.Errorf("control_PushRequest(): expected trailing null byte") - } -} diff --git a/vpn/crypto.go b/vpn/crypto.go deleted file mode 100644 index 93b64076..00000000 --- a/vpn/crypto.go +++ /dev/null @@ -1,373 +0,0 @@ -package vpn - -// -// Code to perform encryption, decryption and key derivation. -// - -import ( - "crypto/aes" - "crypto/cipher" - "crypto/hmac" - "crypto/md5" - "crypto/sha1" - "crypto/sha256" - "crypto/sha512" - "errors" - "fmt" - "hash" - "log" -) //#nosec G501,G505 - -// TODO(ainghazal,bassosimone): see if it's feasible to use stdlib -// functionality rather than using the code below. - -type ( - // cipherMode describes a cipher mode (e.g., GCM). - cipherMode string - - // cipherName is a cipher name (e.g., AES). - cipherName string -) - -const ( - // cipherModeCBC is the CBC cipher mode. - cipherModeCBC = cipherMode("cbc") - - // cipherModeGCM is the GCM cipher mode. - cipherModeGCM = cipherMode("gcm") - - // cipherNameAES is an AES-based cipher. - cipherNameAES = cipherName("aes") -) - -var ( - // errInvalidKeySize means that the key size is invalid. - errInvalidKeySize = errors.New("invalid key size") - - // errUnsupportedCipher indicates we don't support the desired cipher. - errUnsupportedCipher = errors.New("unsupported cipher") - - // errUnsupportedMode indicates that the mode is not uspported. - errUnsupportedMode = errors.New("unsupported mode") - - // errBadInput indicates invalid inputs to encrypt/decrypt functions. - errBadInput = errors.New("bad input") -) - -// encrypteData holds the different parts needed to decrypt an encrypted data -// packet. -type encryptedData struct { - iv []byte - ciphertext []byte - aead []byte -} - -// plaintextData holds the different parts needed to encrypt a plaintext -// payload (after padding). -type plaintextData struct { - iv []byte - plaintext []byte - aead []byte -} - -// dataCipher encrypts and decrypts OpenVPN data. -type dataCipher interface { - // keySizeBytes returns the key size (in bytes). - keySizeBytes() int - - // isAEAD returns whether this cipher has AEAD properties. - isAEAD() bool - - // blockSize returns the expected block size. - blockSize() uint8 - - // encrypt encripts a plaintext. - // - // Arguments: - // - // - key is the key, whose size must be consistent with the cipher; - // - // - plaintextData is the data to be encrypted; - // - // Returns the ciphertext on success and an error on failure. - encrypt([]byte, *plaintextData) ([]byte, error) - - // decrypt is the opposite operation of encrypt. It takes in input the - // ciphertext and returns the plaintext of an error. - decrypt([]byte, *encryptedData) ([]byte, error) - - // mode returns the cipherMode - cipherMode() cipherMode -} - -// dataCipherAES implements dataCipher for AES. -type dataCipherAES struct { - // ksb is the key size in bytes - ksb int - - // mode is the cipher mode - mode cipherMode -} - -var _ dataCipher = &dataCipherAES{} // Ensure we implement dataCipher - -// keySizeBytes implements dataCipher.keySizeBytes -func (a *dataCipherAES) keySizeBytes() int { - return a.ksb -} - -// isAEAD implements dataCipher.isAEAD -func (a *dataCipherAES) isAEAD() bool { - return a.mode != cipherModeCBC -} - -// blockSize implements dataCipher.BlockSize -func (a *dataCipherAES) blockSize() uint8 { - switch a.mode { - case cipherModeCBC, cipherModeGCM: - return 16 - default: - return 0 - } -} - -// decrypt implements dataCipher.decrypt. -// Since key comes from a prf derivation, we only take as many bytes as we need to match -// our key size. -func (a *dataCipherAES) decrypt(key []byte, data *encryptedData) ([]byte, error) { - if len(key) < a.keySizeBytes() { - return nil, errInvalidKeySize - } - - // they key material might be longer - k := key[:a.keySizeBytes()] - block, err := aes.NewCipher(k) - if err != nil { - return nil, err - } - switch a.mode { - case cipherModeCBC: - if len(data.iv) != block.BlockSize() { - return nil, fmt.Errorf("%w: wrong size for iv: %v", errCannotDecrypt, len(data.iv)) - } - mode := cipher.NewCBCDecrypter(block, data.iv) - plaintext := make([]byte, len(data.ciphertext)) - mode.CryptBlocks(plaintext, data.ciphertext) - plaintext, err := bytesUnpadPKCS7(plaintext, block.BlockSize()) - if err != nil { - return nil, err - } - padLen := len(data.ciphertext) - len(plaintext) - if padLen > block.BlockSize() || padLen > len(plaintext) { - // TODO(bassosimone, ainghazal): discuss the cases in which - // this set of conditions actually occurs. - // TODO(ainghazal): this assertion might actually be moved into a - // boundary assertion in the unpad fun. - return nil, errors.New("unpadding error") - } - return plaintext, nil - - case cipherModeGCM: - // standard nonce size is 12. more is surely ok, but let's stick to it. - // https://github.com/golang/go/blob/master/src/crypto/aes/aes_gcm.go#L37 - if len(data.iv) != 12 { - return nil, fmt.Errorf("%w: wrong size for iv: %v", errCannotDecrypt, len(data.iv)) - } - aesGCM, err := cipher.NewGCM(block) - if err != nil { - return nil, err - } - - plaintext, err := aesGCM.Open(nil, data.iv, data.ciphertext, data.aead) - if err != nil { - log.Println("gdm decryption failed:", err.Error()) - /* - log.Println("dump begins----") - log.Println("len:", len(data.ciphertext)) - log.Println("iv:", data.iv) - log.Printf("%v\n", data.ciphertext) - log.Printf("%x\n", data.ciphertext) - log.Printf("aead: %x\n", data.aead) - log.Println("dump ends------") - */ - return nil, err - } - return plaintext, nil - - default: - return nil, errUnsupportedMode - } -} - -func (a *dataCipherAES) cipherMode() cipherMode { - return a.mode -} - -// encrypt implements dataCipher.encrypt -// Since key comes from a prf derivation, we only take as many bytes as we need to match -// our key size. -func (a *dataCipherAES) encrypt(key []byte, data *plaintextData) ([]byte, error) { - if len(key) < a.keySizeBytes() { - return nil, errInvalidKeySize - } - k := key[:a.keySizeBytes()] - block, err := aes.NewCipher(k) - if err != nil { - return nil, err - } - blockSize := block.BlockSize() - switch a.mode { - case cipherModeCBC: - if len(data.iv) != blockSize { - return []byte{}, fmt.Errorf("%w: wrong size for iv: %v", errCannotEncrypt, len(data.iv)) - } - if len(data.plaintext)%blockSize != 0 { - return []byte{}, fmt.Errorf("%w: wrong padding", errCannotEncrypt) - } - mode := cipher.NewCBCEncrypter(block, data.iv) - - ciphertext := make([]byte, len(data.plaintext)) - mode.CryptBlocks(ciphertext, data.plaintext) - return ciphertext, nil - - case cipherModeGCM: - if len(data.iv) != 12 { - return []byte{}, fmt.Errorf("%w: wrong size for iv: %v", errCannotEncrypt, len(data.iv)) - } - aesGCM, err := cipher.NewGCM(block) - if err != nil { - return nil, err - } - // In GCM mode, the IV consists of the 32-bit packet counter - // followed by data from the HMAC key. The HMAC key can be used - // as IV, since in GCM mode the HMAC key is not used for the - // HMAC. The packet counter may not roll over within a single - // TLS session. This results in a unique IV for each packet, as - // required by GCM. - ciphertext := aesGCM.Seal(nil, data.iv, data.plaintext, data.aead) - return ciphertext, nil - - default: - return nil, errUnsupportedMode - } -} - -// newDataCipherFromCipherSuite constructs a new dataCipher from the cipher suite string. -func newDataCipherFromCipherSuite(c string) (dataCipher, error) { - switch c { - case "AES-128-CBC": - return newDataCipher(cipherNameAES, 128, cipherModeCBC) - case "AES-192-CBC": - return newDataCipher(cipherNameAES, 192, cipherModeCBC) - case "AES-256-CBC": - return newDataCipher(cipherNameAES, 256, cipherModeCBC) - case "AES-128-GCM": - return newDataCipher(cipherNameAES, 128, cipherModeGCM) - case "AES-256-GCM": - return newDataCipher(cipherNameAES, 256, cipherModeGCM) - default: - return nil, errUnsupportedCipher - } -} - -// newDataCipher constructs a new dataCipher from the given name, bits, and mode. -func newDataCipher(name cipherName, bits int, mode cipherMode) (dataCipher, error) { - if bits%8 != 0 || bits > 512 || bits < 64 { - return nil, fmt.Errorf("%w: %d", errInvalidKeySize, bits) - } - switch name { - case cipherNameAES: - default: - return nil, fmt.Errorf("%w: %s", errUnsupportedCipher, name) - } - switch mode { - case cipherModeCBC, cipherModeGCM: - default: - return nil, fmt.Errorf("%w: %s", errUnsupportedMode, mode) - } - dc := &dataCipherAES{ - ksb: bits / 8, - mode: mode, - } - return dc, nil -} - -// newHMACFactory accepts a label coming from an OpenVPN auth label, and returns two -// values: a function that will return a Hash implementation, and a boolean -// indicating if the operation was successful. -func newHMACFactory(name string) (func() hash.Hash, bool) { - switch name { - case "sha1": - return sha1.New, true - case "sha256": - return sha256.New, true - case "sha512": - return sha512.New, true - default: - return nil, false - } -} - -// prf function is used to derive master and client keys -func prf(secret, label, clientSeed, serverSeed, clientSid, serverSid []byte, olen int) []byte { - seed := append(clientSeed, serverSeed...) - if len(clientSid) != 0 { - seed = append(seed, clientSid...) - } - if len(serverSid) != 0 { - seed = append(seed, serverSid...) - } - result := make([]byte, olen) - return prf10(result, secret, label, seed) -} - -// Code below is taken from crypto/tls/prf.go -// Copyright 2009 The Go Authors. All rights reserved. -// SPDX-License-Identifier: BSD-3-Clause -// prf10 implements the TLS 1.0 pseudo-random function, as defined in RFC 2246, Section 5. -func prf10(result, secret, label, seed []byte) []byte { - hashSHA1 := sha1.New - hashMD5 := md5.New - - labelAndSeed := make([]byte, len(label)+len(seed)) - copy(labelAndSeed, label) - copy(labelAndSeed[len(label):], seed) - - s1, s2 := splitPreMasterSecret(secret) - pHash(result, s1, labelAndSeed, hashMD5) - result2 := make([]byte, len(result)) - pHash(result2, s2, labelAndSeed, hashSHA1) - for i, b := range result2 { - result[i] ^= b - } - return result -} - -// SPDX-License-Identifier: BSD-3-Clause -// Split a premaster secret in two as specified in RFC 4346, Section 5. -func splitPreMasterSecret(secret []byte) (s1, s2 []byte) { - s1 = secret[0 : (len(secret)+1)/2] - s2 = secret[len(secret)/2:] - return - -} - -// SPDX-License-Identifier: BSD-3-Clause -// pHash implements the P_hash function, as defined in RFC 4346, Section 5. -func pHash(result, secret, seed []byte, hash func() hash.Hash) { - h := hmac.New(hash, secret) - h.Write(seed) - a := h.Sum(nil) - j := 0 - for j < len(result) { - h.Reset() - h.Write(a) - h.Write(seed) - b := h.Sum(nil) - copy(result[j:], b) - j += len(b) - h.Reset() - h.Write(a) - a = h.Sum(nil) - } -} diff --git a/vpn/crypto_test.go b/vpn/crypto_test.go deleted file mode 100644 index 5c22d447..00000000 --- a/vpn/crypto_test.go +++ /dev/null @@ -1,405 +0,0 @@ -package vpn - -import ( - "bytes" - "crypto/sha1" - "crypto/sha256" - "crypto/sha512" - "encoding/hex" - "errors" - "hash" - "log" - "reflect" - "testing" -) - -func TestDataCipherAES(t *testing.T) { - _, err := newDataCipher("aes", 128, "cbc") - if err != nil { - t.Errorf("Cannot instantiate aes-128-cbc") - } -} - -func TestBadCipher(t *testing.T) { - _, err := newDataCipher("bad", 128, "cbc") - if err == nil { - t.Errorf("Should fail with bad cipher") - } -} - -func TestBadMode(t *testing.T) { - _, err := newDataCipher("aes", 128, "bad") - if err == nil { - t.Errorf("Should fail with bad mode") - } -} - -func TestBadKeySize(t *testing.T) { - _, err := newDataCipher("aes", 1024, "cbc") - if err == nil { - t.Errorf("Should fail with bad key size") - } - _, err = newDataCipher("aes", 8, "cbc") - if err == nil { - t.Errorf("Should fail with bad key size") - } -} - -func Test_newDataCipherFromCipherSuite(t *testing.T) { - type args struct { - c string - } - tests := []struct { - name string - args args - want dataCipher - wantErr bool - }{ - {"aes-128-cbc", args{"AES-128-CBC"}, &dataCipherAES{16, "cbc"}, false}, - {"aes-192-cbc", args{"AES-192-CBC"}, &dataCipherAES{24, "cbc"}, false}, - {"aes-256-cbc", args{"AES-256-CBC"}, &dataCipherAES{32, "cbc"}, false}, - {"aes-128-gcm", args{"AES-128-GCM"}, &dataCipherAES{16, "gcm"}, false}, - {"aes-256-gcm", args{"AES-256-GCM"}, &dataCipherAES{32, "gcm"}, false}, - {"bad-256-gcm", args{"AES-512-GCM"}, nil, true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := newDataCipherFromCipherSuite(tt.args.c) - if (err != nil) != tt.wantErr { - t.Errorf("newCipherFromCipherSuite() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("newCipherFromCipherSuite() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_newCipher(t *testing.T) { - type args struct { - name cipherName - bits int - mode cipherMode - } - tests := []struct { - name string - args args - want dataCipher - wantErr bool - }{ - {"aesOK", args{"aes", 256, "cbc"}, &dataCipherAES{32, "cbc"}, false}, - {"badCipher", args{"blowfish", 256, "cbc"}, nil, true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := newDataCipher(tt.args.name, tt.args.bits, tt.args.mode) - if (err != nil) != tt.wantErr { - t.Errorf("newCipher() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("newCipher() = %v, want %v", got, tt.want) - } - }) - } -} - -// this particular test is basically equivalent to reimplementing the factory, but okay, -// it's somehow useful to catch allowed values. -func Test_newHMACFactory(t *testing.T) { - type args struct { - name string - } - tests := []struct { - name string - args args - want func() hash.Hash - want1 bool - }{ - {"sha1", args{"sha1"}, sha1.New, true}, - {"sha256", args{"sha256"}, sha256.New, true}, - {"sha512", args{"sha512"}, sha512.New, true}, - {"shabad", args{"sha192"}, nil, false}, - } - - //str := "hello" - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, got1 := newHMACFactory(tt.args.name) - if got == nil { - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("newHMACFactory() got = %v, want %v", &got, &tt.want) - } - if got1 != tt.want1 { - t.Errorf("newHMACFactory() got1 = %v, want %v", got1, tt.want1) - } - } else { - // it is a function factory, so let's get the function to compare - if !reflect.DeepEqual(got(), tt.want()) { - t.Errorf("newHMACFactory() got = %v, want %v", &got, &tt.want) - } - if got1 != tt.want1 { - t.Errorf("newHMACFactory() got1 = %v, want %v", got1, tt.want1) - } - } - }) - } -} - -func TestPrf(t *testing.T) { - expected := []byte{ - 0x67, 0x18, 0x7c, 0x52, 0xac, 0xd2, 0x4d, 0x95, - 0x9a, 0x55, 0xd3, 0x1c, 0xdb, 0x97, 0x80, 0x11} - secret := []byte("secret") - label := []byte("master key") - cseed := []byte("aaa") - sseed := []byte("bbb") - out := prf(secret, label, cseed, sseed, []byte{}, []byte{}, 16) - if !bytes.Equal(out, expected) { - t.Errorf("Bad output in prf call: %v", out) - } -} - -func Test_dataCipherAES_decrypt(t *testing.T) { - key := bytes.Repeat([]byte("A"), 64) - iv12, _ := hex.DecodeString("000000006868686868686868") - iv16, _ := hex.DecodeString("00000000686868686868686865656565") - ciphertextGCM, _ := hex.DecodeString("a949df311c57ec762428a7ba98d1d0d8213134925bf1cd2cb4ab4ea9066c569b0579") - ciphertextCBC, _ := hex.DecodeString("f908ff8dedbe4e2097c992c67e603d25606c76a460cd785503cf0a2a9e6ec961") - - type fields struct { - ksb int - mode cipherMode - } - type args struct { - key []byte - data *encryptedData - } - tests := []struct { - name string - fields fields - args args - want []byte - wantErr error - }{ - { - name: "good decrypt gcm", - fields: fields{ - ksb: 16, - mode: cipherModeGCM, - }, - args: args{ - key: key, - data: &encryptedData{ - iv: iv12, - ciphertext: ciphertextGCM, - aead: []byte{0x00, 0x01, 0x02, 0x03}, - }, - }, - want: []byte("this test is green"), - wantErr: nil, - }, - { - name: "iv too short gcm", - fields: fields{ - ksb: 16, - mode: cipherModeGCM, - }, - args: args{ - key: key, - data: &encryptedData{ - iv: []byte{0x00}, - ciphertext: ciphertextGCM, - aead: []byte{0x00, 0x01, 0x02, 0x03}, - }, - }, - want: nil, - wantErr: errCannotDecrypt, - }, - { - name: "good decrypt cbc", - fields: fields{ - ksb: 16, - mode: cipherModeCBC, - }, - args: args{ - key: key, - data: &encryptedData{ - iv: iv16, - ciphertext: ciphertextCBC, - aead: []byte{0x00, 0x01, 0x02, 0x03}, - }, - }, - want: []byte("this test is green"), - wantErr: nil, - }, - { - name: "iv too short cbc", - fields: fields{ - ksb: 16, - mode: cipherModeGCM, - }, - args: args{ - key: key, - data: &encryptedData{ - iv: []byte{0x00}, - ciphertext: ciphertextGCM, - aead: []byte{0x00, 0x01, 0x02, 0x03}, - }, - }, - want: []byte{}, - wantErr: errCannotDecrypt, - }, - // TODO: Add moar test cases, with failing. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - a := &dataCipherAES{ - ksb: tt.fields.ksb, - mode: tt.fields.mode, - } - got, err := a.decrypt(tt.args.key, tt.args.data) - if !errors.Is(err, tt.wantErr) { - t.Errorf("dataCipherAES.decrypt() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !bytes.Equal(got, tt.want) { - t.Errorf("dataCipherAES.decrypt() = %v, want %v", got, tt.want) - } - }) - } -} - -func doPaddingForTest(payload []byte, blockSize int) []byte { - padded, _ := bytesPadPKCS7(payload, blockSize) - return padded -} - -func Test_dataCipherAES_encrypt(t *testing.T) { - key := bytes.Repeat([]byte("A"), 64) - iv12, _ := hex.DecodeString("000000006868686868686868") - iv16, _ := hex.DecodeString("00000000686868686868686865656565") - - ciphertextGCM, _ := hex.DecodeString("a949df311c57ec762428a7ba98d1d0d8213134925bf1cd2cb4ab4ea9066c569b0579") - ciphertextCBC, _ := hex.DecodeString("f908ff8dedbe4e2097c992c67e603d25606c76a460cd785503cf0a2a9e6ec961") - - type fields struct { - ksb int - mode cipherMode - } - type args struct { - key []byte - data *plaintextData - } - tests := []struct { - name string - fields fields - args args - want []byte - wantErr error - }{ - { - name: "good encrypt aes-128-gcm", - fields: fields{ - ksb: 16, - mode: cipherModeGCM, - }, - args: args{ - key: key, - data: &plaintextData{ - iv: iv12, - plaintext: []byte("this test is green"), - aead: []byte{0x00, 0x01, 0x02, 0x03}, - }, - }, - want: ciphertextGCM, - wantErr: nil, - }, - { - name: "iv too short aes-128-gcm", - fields: fields{ - ksb: 16, - mode: cipherModeGCM, - }, - args: args{ - key: key, - data: &plaintextData{ - iv: []byte{0x00}, - plaintext: []byte("should fail"), - aead: []byte{0x00, 0x01, 0x02, 0x03}, - }, - }, - want: []byte(""), - wantErr: errCannotEncrypt, - }, - { - name: "iv too short aes-128-cbc", - fields: fields{ - ksb: 16, - mode: cipherModeCBC, - }, - args: args{ - key: key, - data: &plaintextData{ - iv: iv12, - plaintext: []byte("should fail"), - }, - }, - want: []byte(""), - wantErr: errCannotEncrypt, - }, - { - name: "bad padding aes-128-cbc", - fields: fields{ - ksb: 16, - mode: cipherModeCBC, - }, - args: args{ - key: key, - data: &plaintextData{ - iv: iv16, - plaintext: []byte("should fail"), - }, - }, - want: []byte(""), - wantErr: errCannotEncrypt, - }, - { - name: "good encrypt aes-128-cbc", - fields: fields{ - ksb: 16, - mode: cipherModeCBC, - }, - args: args{ - key: key, - data: &plaintextData{ - iv: iv16, - plaintext: doPaddingForTest([]byte("this test is green"), 16), - }, - }, - want: ciphertextCBC, - wantErr: nil, - }, - // TODO: Add moar test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - a := &dataCipherAES{ - ksb: tt.fields.ksb, - mode: tt.fields.mode, - } - got, err := a.encrypt(tt.args.key, tt.args.data) - if !errors.Is(err, tt.wantErr) { - t.Errorf("dataCipherAES.encrypt() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - log.Println(hex.EncodeToString(got)) - - t.Errorf("dataCipherAES.encrypt() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/vpn/data.go b/vpn/data.go deleted file mode 100644 index 15546ca9..00000000 --- a/vpn/data.go +++ /dev/null @@ -1,673 +0,0 @@ -package vpn - -// -// OpenVPN data channel -// - -import ( - "bytes" - "crypto/hmac" - "encoding/binary" - "encoding/hex" - "errors" - "fmt" - "hash" - "math" - "net" - "strings" - "sync" -) - -var ( - errDataChannelKey = errors.New("bad key") - errBadCompression = errors.New("bad compression") - errReplayAttack = errors.New("replay attack") - errCannotEncrypt = errors.New("cannot encrypt") - errCannotDecrypt = errors.New("cannot decrypt") - errBadHMAC = errors.New("bad hmac") - errInitError = errors.New("improperly initialized") -) - -// keySlot holds the different local and remote keys. -type keySlot [64]byte - -// dataChannelState is the state of the data channel. -type dataChannelState struct { - dataCipher dataCipher - hash func() hash.Hash - // outgoing and incoming nomenclature is probably more adequate here. - hmacLocal hash.Hash - hmacRemote hash.Hash - remotePacketID packetID - cipherKeyLocal keySlot - cipherKeyRemote keySlot - hmacKeyLocal keySlot - hmacKeyRemote keySlot - keyID int // not used at the moment, paving the way for key rotation. - peerID int - - mu sync.Mutex -} - -// SetSetRemotePacketID stores the passed packetID internally. -func (dcs *dataChannelState) SetRemotePacketID(id packetID) { - dcs.mu.Lock() - defer dcs.mu.Unlock() - dcs.remotePacketID = packetID(id) -} - -// RemotePacketID returns the last known remote packetID. It returns an error -// if the stored packet id has reached the maximum capacity of the packetID -// type. -func (dcs *dataChannelState) RemotePacketID() (packetID, error) { - dcs.mu.Lock() - defer dcs.mu.Unlock() - pid := dcs.remotePacketID - if pid == math.MaxUint32 { - // we reached the max packetID, increment will overflow - return 0, errExpiredKey - } - return pid, nil -} - -// dataChannelKey represents a pair of key sources that have been negotiated -// over the control channel, and from which we will derive local and remote -// keys for encryption and decrption over the data channel. The index refers to -// the short key_id that is passed in the lower 3 bits if a packet header. -// The setup of the keys for a given data channel (that is, for every key_id) -// is made by expanding the keysources using the prf function. -// Do note that we are not yet implementing key renegotiation - but the index -// is provided for convenience when/if we support that in the future. -type dataChannelKey struct { - index uint32 - ready bool - local *keySource - remote *keySource - mu sync.Mutex -} - -// addRemoteKey adds the server keySource to our dataChannelKey. This makes the -// dataChannelKey ready to be used. -func (dck *dataChannelKey) addRemoteKey(k *keySource) error { - dck.mu.Lock() - defer dck.mu.Unlock() - if dck.ready { - return fmt.Errorf("%w:%s", errDataChannelKey, "cannot overwrite remote key slot") - } - dck.remote = k - dck.ready = true - return nil -} - -var ( - randomFn = genRandomBytes - errRandomBytes = errors.New("Error generating random bytes") -) - -// keySource contains random data to generate keys. -type keySource struct { - r1 [32]byte - r2 [32]byte - preMaster [48]byte -} - -// Bytes returns the byte representation of a keySource. -func (k *keySource) Bytes() []byte { - buf := &bytes.Buffer{} - buf.Write(k.preMaster[:]) - buf.Write(k.r1[:]) - buf.Write(k.r2[:]) - return buf.Bytes() -} - -// newKeySource returns a keySource and an error. -func newKeySource() (*keySource, error) { - random1, err := randomFn(32) - if err != nil { - return nil, fmt.Errorf("%w: %s", errRandomBytes, err.Error()) - } - - var r1, r2 [32]byte - var preMaster [48]byte - copy(r1[:], random1) - - random2, err := randomFn(32) - if err != nil { - return nil, fmt.Errorf("%w: %s", errRandomBytes, err.Error()) - } - copy(r2[:], random2) - - random3, err := randomFn(48) - if err != nil { - return nil, fmt.Errorf("%w: %s", errRandomBytes, err.Error()) - } - copy(preMaster[:], random3) - return &keySource{ - r1: r1, - r2: r2, - preMaster: preMaster, - }, nil -} - -// data represents the data "channel", that will encrypt and decrypt the tunnel payloads. -// data implements the dataHandler interface. -type data struct { - options *Options - session *session - state *dataChannelState - decodeFn func([]byte, *dataChannelState) (*encryptedData, error) - encryptEncodeFn func([]byte, *session, *dataChannelState) ([]byte, error) - decryptFn func([]byte, *encryptedData) ([]byte, error) -} - -var _ dataHandler = &data{} // Ensure that we implement dataHandler - -// newDataFromOptions returns a new data object, initialized with the -// options given. it also returns any error raised. -func newDataFromOptions(opt *Options, s *session) (*data, error) { - if opt == nil || s == nil { - return nil, fmt.Errorf("%w: %s", errBadInput, "found nil on init") - } - if len(opt.Cipher) == 0 || len(opt.Auth) == 0 { - return nil, fmt.Errorf("%w: %s", errBadInput, "empty options") - } - state := &dataChannelState{} - data := &data{options: opt, session: s, state: state} - - logger.Info(fmt.Sprintf("Cipher: %s", opt.Cipher)) - - dataCipher, err := newDataCipherFromCipherSuite(opt.Cipher) - if err != nil { - return data, err - } - data.state.dataCipher = dataCipher - switch dataCipher.isAEAD() { - case true: - data.decodeFn = decodeEncryptedPayloadAEAD - data.encryptEncodeFn = encryptAndEncodePayloadAEAD - case false: - data.decodeFn = decodeEncryptedPayloadNonAEAD - data.encryptEncodeFn = encryptAndEncodePayloadNonAEAD - } - - logger.Info(fmt.Sprintf("Auth: %s", opt.Auth)) - - hmacHash, ok := newHMACFactory(strings.ToLower(opt.Auth)) - if !ok { - return data, fmt.Errorf("%w:%s", errBadInput, "no such mac") - } - data.state.hash = hmacHash - data.decryptFn = state.dataCipher.decrypt - - return data, nil -} - -// DecodeEncryptedPayload calls the corresponding function for AEAD or Non-AEAD decryption. -func (d *data) DecodeEncryptedPayload(b []byte, dcs *dataChannelState) (*encryptedData, error) { - return d.decodeFn(b, dcs) -} - -// SetSetupKeys performs the key expansion from the local and remote -// keySources, initializing the data channel state. -func (d *data) SetupKeys(dck *dataChannelKey) error { - if dck == nil { - return fmt.Errorf("%w: %s", errBadInput, "nil args") - } - if !dck.ready { - return fmt.Errorf("%w: %s", errDataChannelKey, "key not ready") - } - master := prf( - dck.local.preMaster[:], - []byte("OpenVPN master secret"), - dck.local.r1[:], - dck.remote.r1[:], - []byte{}, []byte{}, - 48) - - keys := prf( - master, - []byte("OpenVPN key expansion"), - dck.local.r2[:], - dck.remote.r2[:], - d.session.LocalSessionID[:], d.session.RemoteSessionID[:], - 256) - - var keyLocal, hmacLocal, keyRemote, hmacRemote keySlot - copy(keyLocal[:], keys[0:64]) - copy(hmacLocal[:], keys[64:128]) - copy(keyRemote[:], keys[128:192]) - copy(hmacRemote[:], keys[192:256]) - - d.state.cipherKeyLocal = keyLocal - d.state.hmacKeyLocal = hmacLocal - d.state.cipherKeyRemote = keyRemote - d.state.hmacKeyRemote = hmacRemote - - logger.Debugf("Cipher key local: %x", keyLocal) - logger.Debugf("Cipher key remote: %x", keyRemote) - logger.Debugf("Hmac key local: %x", hmacLocal) - logger.Debugf("Hmac key remote: %x", hmacRemote) - - hashSize := d.state.hash().Size() - d.state.hmacLocal = hmac.New(d.state.hash, hmacLocal[:hashSize]) - d.state.hmacRemote = hmac.New(d.state.hash, hmacRemote[:hashSize]) - - logger.Info("Key derivation OK") - return nil -} - -// SetPeerID updates the data state field with the info sent by the server. -func (d *data) SetPeerID(i int) error { - d.state.peerID = i - return nil -} - -// -// write + encrypt -// - -// encrypt calls the corresponding function for AEAD or Non-AEAD decryption. -// Due to the particularities of the iv generation on each of the modes, encryption and encoding are -// done together in the same function. -// TODO accept state for symmetry -func (d *data) EncryptAndEncodePayload(plaintext []byte, dcs *dataChannelState) ([]byte, error) { - if len(plaintext) == 0 { - return []byte{}, fmt.Errorf("%w: nothing to encrypt", errCannotEncrypt) - } - if dcs == nil || dcs.dataCipher == nil { - return []byte{}, fmt.Errorf("%w: %s", errCannotEncrypt, fmt.Errorf("data chan not initialized")) - } - - padded, err := doPadding(plaintext, d.options.Compress, dcs.dataCipher.blockSize()) - if err != nil { - return []byte{}, fmt.Errorf("%w: %s", errCannotEncrypt, err) - } - - encrypted, err := d.encryptEncodeFn(padded, d.session, d.state) - if err != nil { - return []byte{}, fmt.Errorf("%w: %s", errCannotEncrypt, err) - } - return encrypted, nil - -} - -// encryptAndEncodePayloadAEAD peforms encryption and encoding of the payload in AEAD modes (i.e., AES-GCM). -// TODO(ainghazal): for testing we can pass both the state object and the encryptFn -func encryptAndEncodePayloadAEAD(padded []byte, session *session, state *dataChannelState) ([]byte, error) { - nextPacketID, err := session.LocalPacketID() - if err != nil { - return []byte{}, fmt.Errorf("bad packet id") - } - - // in AEAD mode, we authenticate: - // - 1 byte: opcode/key - // - 3 bytes: peer-id (we're using P_DATA_V2) - // - 4 bytes: packet-id - aead := &bytes.Buffer{} - aead.WriteByte(opcodeAndKeyHeader(state)) - bufWriteUint24(aead, uint32(state.peerID)) - bufWriteUint32(aead, uint32(nextPacketID)) - - // the iv is the packetID (again) concatenated with the 8 bytes of the - // key derived for local hmac (which we do not use for anything else in AEAD mode). - iv := &bytes.Buffer{} - bufWriteUint32(iv, uint32(nextPacketID)) - iv.Write(state.hmacKeyLocal[:8]) - - data := &plaintextData{ - iv: iv.Bytes(), - plaintext: padded, - aead: aead.Bytes(), - } - - encryptFn := state.dataCipher.encrypt - encrypted, err := encryptFn(state.cipherKeyLocal[:], data) - if err != nil { - return []byte{}, err - } - - // some reordering, because openvpn uses tag | payload - boundary := len(encrypted) - 16 - tag := encrypted[boundary:] - ciphertext := encrypted[:boundary] - - // we now write to the output buffer - out := bytes.Buffer{} - out.Write(data.aead) // opcode|peer-id|packet_id - out.Write(tag) - out.Write(ciphertext) - return out.Bytes(), nil - -} - -// encryptAndEncodePayloadNonAEAD peforms encryption and encoding of the payload in Non-AEAD modes (i.e., AES-CBC). -func encryptAndEncodePayloadNonAEAD(padded []byte, session *session, state *dataChannelState) ([]byte, error) { - // For iv generation, OpenVPN uses a nonce-based PRNG that is initially seeded with - // OpenSSL RAND_bytes function. I am assuming this is good enough for our current purposes. - blockSize := state.dataCipher.blockSize() - - iv, err := randomFn(int(blockSize)) - if err != nil { - return []byte{}, err - } - data := &plaintextData{ - iv: iv, - plaintext: padded, - aead: nil, - } - - encryptFn := state.dataCipher.encrypt - ciphertext, err := encryptFn(state.cipherKeyLocal[:], data) - if err != nil { - return []byte{}, err - } - - state.hmacLocal.Reset() - state.hmacLocal.Write(iv) - state.hmacLocal.Write(ciphertext) - computedMAC := state.hmacLocal.Sum(nil) - - out := &bytes.Buffer{} - out.WriteByte(opcodeAndKeyHeader(state)) - bufWriteUint24(out, uint32(state.peerID)) - - out.Write(computedMAC) - out.Write(iv) - out.Write(ciphertext) - return out.Bytes(), nil -} - -// doCompress adds compression bytes if needed by the passed compression options. -// if the compression stub is on, it sends the first byte to the last position, -// and it adds the compression preamble, according to the spec. compression -// lzo-no also adds a preamble. It returns a byte array and an error if the -// operation could not be completed. -func doCompress(b []byte, c compression) ([]byte, error) { - switch c { - case "stub": - // compression stub: send first byte to last - // and add 0xfb marker on the first byte. - b = append(b, b[0]) - b[0] = 0xfb - case "lzo-no": - // old "comp-lzo no" option - b = append([]byte{0xfa}, b...) - } - return b, nil -} - -// doPadding does pkcs7 padding of the encryption payloads as -// needed. if we're using the compression stub the padding is applied without taking the -// trailing bit into account. it returns the resulting byte array, and an error -// if the operatio could not be completed. -func doPadding(b []byte, compress compression, blockSize uint8) ([]byte, error) { - if len(b) == 0 { - return nil, fmt.Errorf("%w: nothing to pad", errBadInput) - } - if compress == "stub" { - // if we're using the compression stub - // we need to account for a trailing byte - // that we have appended in the doCompress stage. - endByte := b[len(b)-1] - padded, err := bytesPadPKCS7(b[:len(b)-1], int(blockSize)) - if err != nil { - return nil, err - } - padded[len(padded)-1] = endByte - return padded, nil - } - padded, err := bytesPadPKCS7(b, int(blockSize)) - if err != nil { - return nil, err - } - return padded, nil -} - -// prependPacketID returns the original buffer with the passed packetID -// concatenated at the beginning. -func prependPacketID(p packetID, buf []byte) []byte { - newbuf := &bytes.Buffer{} - packetID := make([]byte, 4) - binary.BigEndian.PutUint32(packetID, uint32(p)) - newbuf.Write(packetID[:]) - newbuf.Write(buf) - return newbuf.Bytes() -} - -func (d *data) WritePacket(conn net.Conn, payload []byte) (int, error) { - if d.state == nil || d.state.dataCipher == nil { - return 0, fmt.Errorf("%w: %s", errBadInput, "bad state") - } - - var plain []byte - var err error - - // TODO(ainghazal): separate into two different implementations - // and get rid of multiple switch. - switch d.state.dataCipher.isAEAD() { - case true: - plain, err = doCompress(payload, d.options.Compress) - if err != nil { - return 0, fmt.Errorf("%w: %s", errCannotEncrypt, err) - } - case false: // non-aead - localPacketID, _ := d.session.LocalPacketID() - plain = prependPacketID(localPacketID, payload) - - plain, err = doCompress(plain, d.options.Compress) - if err != nil { - return 0, fmt.Errorf("%w: %s", errCannotEncrypt, err) - } - } - - // encrypted adds padding, if needed, and it also includes the - // opcode/keyid and peer-id headers and, if used, any authenticated - // parts in the packet. - encrypted, err := d.EncryptAndEncodePayload(plain, d.state) - if err != nil { - return 0, fmt.Errorf("%w: %s", errCannotEncrypt, err) - } - - // TODO(ainghazal): increment counter for used bytes, and - // trigger renegotiation if we're near the end of the key useful lifetime. - - out := maybeAddSizeFrame(conn, encrypted) - - logger.Debug("data: write packet") - logger.Debugf("\n" + hex.Dump(out)) - - return conn.Write(out) -} - -// -// read + decrypt -// - -func (d *data) decrypt(encrypted []byte) ([]byte, error) { - if d.decryptFn == nil { - return []byte{}, errInitError - } - if len(d.state.hmacKeyRemote) == 0 { - logger.Error("decrypt: not ready yet") - return []byte{}, errCannotDecrypt - } - encryptedData, err := d.DecodeEncryptedPayload(encrypted, d.state) - - if err != nil { - return []byte{}, fmt.Errorf("%w: %s", errCannotDecrypt, err) - } - plainText, err := d.decryptFn(d.state.cipherKeyRemote[:], encryptedData) - if err != nil { - return []byte{}, fmt.Errorf("%w: %s", errCannotDecrypt, err) - } - return plainText, nil -} - -func decodeEncryptedPayloadAEAD(buf []byte, state *dataChannelState) (*encryptedData, error) { - // P_DATA_V2 GCM data channel crypto format - // 48000001 00000005 7e7046bd 444a7e28 cc6387b1 64a4d6c1 380275a... - // [ OP32 ] [seq # ] [ auth tag ] [ payload ... ] - // - means authenticated - * means encrypted * - // [ - opcode/peer-id - ] [ - packet ID - ] [ TAG ] [ * packet payload * ] - - // preconditions - - if len(buf) == 0 || len(buf) < 20 { - return &encryptedData{}, fmt.Errorf("too short: %d bytes", len(buf)) - } - if len(state.hmacKeyRemote) < 8 { - return &encryptedData{}, fmt.Errorf("bad remote hmac") - } - remoteHMAC := state.hmacKeyRemote[:8] - packet_id := buf[:4] - - headers := &bytes.Buffer{} - headers.WriteByte(opcodeAndKeyHeader(state)) - bufWriteUint24(headers, uint32(state.peerID)) - headers.Write(packet_id) - - // we need to swap because decryption expects payload|tag - // but we've got tag | payload instead - payload := &bytes.Buffer{} - payload.Write(buf[20:]) // ciphertext - payload.Write(buf[4:20]) // tag - - // iv := packetID | remoteHMAC - iv := &bytes.Buffer{} - iv.Write(packet_id) - iv.Write(remoteHMAC) - - encrypted := &encryptedData{ - iv: iv.Bytes(), - ciphertext: payload.Bytes(), - aead: headers.Bytes(), - } - return encrypted, nil -} - -func decodeEncryptedPayloadNonAEAD(buf []byte, state *dataChannelState) (*encryptedData, error) { - if state == nil || state.dataCipher == nil { - return &encryptedData{}, fmt.Errorf("%w: bad state", errBadInput) - } - hashSize := uint8(state.hmacRemote.Size()) - blockSize := state.dataCipher.blockSize() - - minLen := hashSize + blockSize - - if len(buf) < int(minLen) { - return &encryptedData{}, fmt.Errorf("%w: too short (%d bytes)", errBadInput, len(buf)) - } - - receivedHMAC := buf[:hashSize] - iv := buf[hashSize : hashSize+blockSize] - cipherText := buf[hashSize+blockSize:] - - state.hmacRemote.Reset() - state.hmacRemote.Write(iv) - state.hmacRemote.Write(cipherText) - computedHMAC := state.hmacRemote.Sum(nil) - - if !hmac.Equal(computedHMAC, receivedHMAC) { - logger.Errorf("expected: %x, got: %x", computedHMAC, receivedHMAC) - return &encryptedData{}, fmt.Errorf("%w: %s", errCannotDecrypt, errBadHMAC) - } - - encrypted := &encryptedData{ - iv: iv, - ciphertext: cipherText, - aead: []byte{}, // no AEAD data in this mode, leaving it empty to satisfy common interface - } - return encrypted, nil -} - -func (d *data) ReadPacket(p *packet) ([]byte, error) { - if len(p.payload) == 0 { - return []byte{}, fmt.Errorf("%w: %s", errCannotDecrypt, "empty payload") - } - panicIfFalse(p.isData(), "ReadPacket expects data packet") - - plaintext, err := d.decrypt(p.payload) - if err != nil { - return []byte{}, err - } - - // get plaintext payload from the decrypted plaintext - return maybeDecompress(plaintext, d.state, d.options) -} - -// maybeDecompress de-serializes the data from the payload according to the framing -// given by different compression methods. only the different no-compression -// modes are supported at the moment, so no real decompression is done. It -// returns a byte array, and an error if the operation could not be completed -// successfully. -func maybeDecompress(b []byte, st *dataChannelState, opt *Options) ([]byte, error) { - if st == nil || st.dataCipher == nil { - return []byte{}, fmt.Errorf("%w:%s", errBadInput, "bad state") - } - if opt == nil { - return []byte{}, fmt.Errorf("%w:%s", errBadInput, "bad options") - } - - var compr byte // compression type - var payload []byte - - // TODO(ainghazal): have two different decompress implementations - // instead of this switch - switch st.dataCipher.isAEAD() { - case true: - switch opt.Compress { - case compressionStub, compressionLZONo: - // these are deprecated in openvpn 2.5.x - compr = b[0] - payload = b[1:] - default: - compr = 0x00 - payload = b[:] - } - default: // non-aead - remotePacketID := packetID(binary.BigEndian.Uint32(b[:4])) - lastKnownRemote, err := st.RemotePacketID() - if err != nil { - return payload, err - } - if remotePacketID <= lastKnownRemote { - return []byte{}, errReplayAttack - } - st.SetRemotePacketID(remotePacketID) - - switch opt.Compress { - case compressionStub, compressionLZONo: - compr = b[4] - payload = b[5:] - default: - compr = 0x00 - payload = b[4:] - } - } - - switch compr { - case 0xfb: - // compression stub swap: - // we get the last byte and replace the compression byte - // these are deprecated in openvpn 2.5.x - end := payload[len(payload)-1] - b := payload[:len(payload)-1] - payload = append([]byte{end}, b...) - case 0x00, 0xfa: - // do nothing - // 0x00 is compress-no, - // 0xfa is the old no compression or comp-lzo no case. - // http://build.openvpn.net/doxygen/comp_8h_source.html - // see: https://community.openvpn.net/openvpn/ticket/952#comment:5 - default: - errMsg := fmt.Sprintf("cannot handle compression:%x", compr) - return []byte{}, fmt.Errorf("%w:%s", errBadCompression, errMsg) - } - return payload, nil -} - -// opcodeAndKeyHeader returns the header byte encoding the opcode and keyID (3 upper -// and 5 lower bits, respectively) -func opcodeAndKeyHeader(st *dataChannelState) byte { - return byte((pDataV2 << 3) | (st.keyID & 0x07)) -} diff --git a/vpn/data_test.go b/vpn/data_test.go deleted file mode 100644 index b311275e..00000000 --- a/vpn/data_test.go +++ /dev/null @@ -1,1425 +0,0 @@ -package vpn - -import ( - "bytes" - "crypto/hmac" - "crypto/sha1" - "encoding/base64" - "encoding/hex" - "errors" - "fmt" - "math" - "net" - "reflect" - "sync" - "testing" - - "github.com/ooni/minivpn/vpn/mocks" -) - -const ( - rnd16 = "0123456789012345" - rnd32 = "01234567890123456789012345678901" - rnd48 = "012345678901234567890123456789012345678901234567" -) - -func makeTestKeys() ([32]byte, [32]byte, [48]byte) { - r1 := *(*[32]byte)([]byte(rnd32)) - r2 := *(*[32]byte)([]byte(rnd32)) - r3 := *(*[48]byte)([]byte(rnd48)) - return r1, r2, r3 -} - -// getDeterministicRandomKeySize returns a sequence of integers -// using the map in the closure. we use this to construct a deterministic -// random function to replace the random function used in the real client. -func getDeterministicRandomKeySizeFn() func() int { - var rndSeq = map[int]int{ - 1: 32, - 2: 32, - 3: 48, - } - i := 1 - f := func() int { - v := rndSeq[i] - i += 1 - return v - } - return f -} - -func Test_newKeySource(t *testing.T) { - - genKeySizeFn := getDeterministicRandomKeySizeFn() - - // we replace the global random function used in the constructor - randomFn = func(int) ([]byte, error) { - switch genKeySizeFn() { - case 48: - return []byte(rnd48), nil - default: - return []byte(rnd32), nil - } - } - - r1, r2, premaster := makeTestKeys() - ks := &keySource{r1, r2, premaster} - - tests := []struct { - name string - want *keySource - }{ - { - name: "test generation of a new key with mocked random data", - want: ks, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got, _ := newKeySource(); !reflect.DeepEqual(got, tt.want) { - t.Errorf("newKeySource() = %v, want %v", got, tt.want) - } - }) - } -} - -func makeTestingSession() *session { - s := &session{ - RemoteSessionID: sessionID{0x01}, - LocalSessionID: sessionID{0x02}, - mu: sync.Mutex{}, - } - return s -} - -func makeTestingOptions(t *testing.T, cipher, auth string) *Options { - crt, _ := writeTestingCerts(t.TempDir()) - opt := &Options{ - Cipher: cipher, - Auth: auth, - CertPath: crt.cert, - KeyPath: crt.key, - CaPath: crt.ca, - } - return opt -} - -func Test_newDataFromOptions(t *testing.T) { - type args struct { - opt *Options - s *session - } - tests := []struct { - name string - args args - want *data - wantWhatever bool - wantErr error - }{ - { - name: "nil args should fail", - args: args{}, - want: nil, - wantErr: errBadInput, - }, - { - name: "empty Options should fail", - args: args{ - opt: &Options{}, - s: makeTestingSession(), - }, - want: nil, - wantErr: errBadInput, - }, - { - name: "bad auth in Options should fail", - args: args{ - opt: makeTestingOptions(t, "AES-128-GCM", "shabad"), - s: makeTestingSession(), - }, - wantWhatever: true, - wantErr: errBadInput, - }, - { - name: "empty session should not fail", - args: args{ - opt: makeTestingOptions(t, "AES-128-GCM", "sha512"), - s: &session{}, - }, - wantWhatever: true, - wantErr: nil, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := newDataFromOptions(tt.args.opt, tt.args.s) - if !errors.Is(err, tt.wantErr) { - t.Errorf("newDataFromOptions() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !tt.wantWhatever && !reflect.DeepEqual(got, tt.want) { - t.Errorf("newDataFromOptions() = %v, want %v", got, tt.want) - } - }) - } -} - -func makeTestingDataChannelKey() *dataChannelKey { - rl1, rl2, preml := makeTestKeys() - rr1, rr2, premr := makeTestKeys() - - ksLocal := &keySource{rl1, rl2, preml} - ksRemote := &keySource{rr1, rr2, premr} - - dck := &dataChannelKey{ - ready: true, - local: ksLocal, - remote: ksRemote, - } - return dck -} - -func Test_data_SetupKeys(t *testing.T) { - type fields struct { - session *session - state *dataChannelState - } - type args struct { - dck *dataChannelKey - } - tests := []struct { - name string - fields fields - args args - wantErr error - }{ - { - name: "nil in arguments should fail", - fields: fields{ - session: makeTestingSession(), - state: makeTestingState(), - }, - args: args{}, - wantErr: errBadInput, - }, - { - name: "dataChannelKey not ready", - fields: fields{ - session: makeTestingSession(), - state: makeTestingState(), - }, - args: args{ - dck: &dataChannelKey{}, - }, - wantErr: errDataChannelKey, - }, - { - name: "good setup", - fields: fields{ - session: makeTestingSession(), - state: makeTestingState(), - }, - args: args{ - dck: makeTestingDataChannelKey(), - }, - wantErr: nil, - // TODO(ainghazal): should write another test to verify the key derivation? - // but what that would be testing, if not the implementation? - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - d := &data{ - session: tt.fields.session, - state: tt.fields.state, - } - if err := d.SetupKeys(tt.args.dck); !errors.Is(err, tt.wantErr) { - t.Errorf("data.SetupKeys() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} - -func Test_data_EncryptAndEncodePayload(t *testing.T) { - // TODO(ainghazal): this is exercising only one encryption method - - opt := &Options{} - - type fields struct { - options *Options - session *session - state *dataChannelState - decodeFn func([]byte, *dataChannelState) (*encryptedData, error) - encryptEncodeFn func([]byte, *session, *dataChannelState) ([]byte, error) - } - type args struct { - plaintext []byte - dcs *dataChannelState - } - tests := []struct { - name string - fields fields - args args - want []byte - wantErr error - }{ - { - name: "dummy encryptEncodeFn does not fail", - fields: fields{ - options: opt, - session: makeTestingSession(), - state: makeTestingState(), - decodeFn: nil, - encryptEncodeFn: func(b []byte, s *session, st *dataChannelState) ([]byte, error) { - return []byte{}, nil - }, - }, - args: args{ - plaintext: []byte("hello"), - dcs: makeTestingState(), - }, - want: []byte{}, - wantErr: nil, - }, - { - name: "empty plaintext fails", - fields: fields{ - options: opt, - session: makeTestingSession(), - state: makeTestingState(), - decodeFn: nil, - encryptEncodeFn: func(b []byte, s *session, st *dataChannelState) ([]byte, error) { - return []byte{}, nil - }, - }, - args: args{ - plaintext: []byte{}, - dcs: makeTestingState(), - }, - want: []byte{}, - wantErr: errCannotEncrypt, - }, - { - name: "error on encryptEncodeFn gets propagated", - fields: fields{ - options: opt, - session: makeTestingSession(), - state: makeTestingState(), - decodeFn: nil, - encryptEncodeFn: func(b []byte, s *session, st *dataChannelState) ([]byte, error) { - return []byte{}, errors.New("dummyTestError") - }, - }, - args: args{ - plaintext: []byte{}, - dcs: makeTestingState(), - }, - want: []byte{}, - wantErr: errCannotEncrypt, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - d := &data{ - options: tt.fields.options, - session: tt.fields.session, - state: tt.fields.state, - decodeFn: tt.fields.decodeFn, - encryptEncodeFn: tt.fields.encryptEncodeFn, - } - got, err := d.EncryptAndEncodePayload(tt.args.plaintext, tt.args.dcs) - if !errors.Is(err, tt.wantErr) { - t.Errorf("data.EncryptAndEncodePayload() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("data.EncryptAndEncodePayload() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_dataChannelState_RemotePacketID(t *testing.T) { - type fields struct { - remotePacketID packetID - } - tests := []struct { - name string - fields fields - want packetID - wantErr error - }{ - { - "zero", - fields{0}, - packetID(0), - nil, - }, - { - "one", - fields{1}, - packetID(1), - nil, - }, - { - "overflow", - fields{math.MaxUint32}, - packetID(0), - errExpiredKey, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - dcs := &dataChannelState{ - remotePacketID: tt.fields.remotePacketID, - } - if got, err := dcs.RemotePacketID(); got != tt.want || err != tt.wantErr { - t.Errorf("dataChannelState.RemotePacketID() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_keySource_Bytes(t *testing.T) { - r1, r2, premaster := makeTestKeys() - goodSerialized := append(premaster[:], r1[:]...) - goodSerialized = append(goodSerialized, r2[:]...) - - type fields struct { - r1 [32]byte - r2 [32]byte - preMaster [48]byte - } - tests := []struct { - name string - fields fields - want []byte - }{ - { - name: "good keysource", - fields: fields{ - r1: r1, - r2: r2, - preMaster: premaster, - }, - want: goodSerialized, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - k := &keySource{ - r1: tt.fields.r1, - r2: tt.fields.r2, - preMaster: tt.fields.preMaster, - } - if got := k.Bytes(); !bytes.Equal(got, tt.want) { - t.Errorf("keySource.Bytes() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_dataChannelKey_addRemoteKey(t *testing.T) { - type fields struct { - ready bool - remote *keySource - } - type args struct { - k *keySource - } - tests := []struct { - name string - fields fields - args args - wantErr bool - }{ - { - "passing a keysource should make it ready", - fields{false, &keySource{}}, - args{&keySource{}}, - false, - }, - { - "fail if ready", - fields{true, &keySource{}}, - args{&keySource{}}, - true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - dck := &dataChannelKey{ - ready: tt.fields.ready, - remote: tt.fields.remote, - } - if err := dck.addRemoteKey(tt.args.k); (err != nil) != tt.wantErr { - t.Errorf("dataChannelKey.addRemoteKey() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} - -func makeTestingState() *dataChannelState { - dataCipher, _ := newDataCipher(cipherNameAES, 128, cipherModeGCM) - st := &dataChannelState{ - hash: sha1.New, - // my linter doesn't like it, but this is the proper way of casting to keySlot - cipherKeyLocal: *(*keySlot)(bytes.Repeat([]byte{0x65}, 64)), - cipherKeyRemote: *(*keySlot)(bytes.Repeat([]byte{0x66}, 64)), - hmacKeyLocal: *(*keySlot)(bytes.Repeat([]byte{0x67}, 64)), - hmacKeyRemote: *(*keySlot)(bytes.Repeat([]byte{0x68}, 64)), - } - st.hmacLocal = hmac.New(st.hash, st.hmacKeyLocal[:20]) - st.hmacRemote = hmac.New(st.hash, st.hmacKeyRemote[:20]) - st.dataCipher = dataCipher - return st -} - -func makeTestingStateNonAEAD() *dataChannelState { - dataCipher, _ := newDataCipher(cipherNameAES, 128, cipherModeCBC) - st := &dataChannelState{ - hash: sha1.New, - // my linter doesn't like it, but this is the proper way of casting to keySlot - cipherKeyLocal: *(*keySlot)(bytes.Repeat([]byte{0x65}, 64)), - cipherKeyRemote: *(*keySlot)(bytes.Repeat([]byte{0x66}, 64)), - hmacKeyLocal: *(*keySlot)(bytes.Repeat([]byte{0x67}, 64)), - hmacKeyRemote: *(*keySlot)(bytes.Repeat([]byte{0x68}, 64)), - } - st.hmacLocal = hmac.New(st.hash, st.hmacKeyLocal[:20]) - st.hmacRemote = hmac.New(st.hash, st.hmacKeyRemote[:20]) - st.dataCipher = dataCipher - return st -} - -func makeTestingStateNonAEADReversed() *dataChannelState { - dataCipher, _ := newDataCipher(cipherNameAES, 128, cipherModeCBC) - st := &dataChannelState{ - hash: sha1.New, - // my linter doesn't like it, but this is the proper way of casting to keySlot - cipherKeyRemote: *(*keySlot)(bytes.Repeat([]byte{0x65}, 64)), - cipherKeyLocal: *(*keySlot)(bytes.Repeat([]byte{0x66}, 64)), - hmacKeyRemote: *(*keySlot)(bytes.Repeat([]byte{0x67}, 64)), - hmacKeyLocal: *(*keySlot)(bytes.Repeat([]byte{0x68}, 64)), - } - st.hmacLocal = hmac.New(st.hash, st.hmacKeyLocal[:20]) - st.hmacRemote = hmac.New(st.hash, st.hmacKeyRemote[:20]) - st.dataCipher = dataCipher - return st -} - -func Test_data_decrypt(t *testing.T) { - - goodMockDecryptFn := func([]byte, *encryptedData) ([]byte, error) { - return []byte("alles ist gut"), nil - } - - failingMockDecryptFn := func([]byte, *encryptedData) ([]byte, error) { - return []byte{}, errCannotDecrypt - } - - opt := &Options{} - - type fields struct { - options *Options - session *session - state *dataChannelState - decodeFn func([]byte, *dataChannelState) (*encryptedData, error) - encryptEncodeFn func([]byte, *session, *dataChannelState) ([]byte, error) - decryptFn func([]byte, *encryptedData) ([]byte, error) - } - type args struct { - encrypted []byte - } - tests := []struct { - name string - fields fields - args args - want []byte - wantErr error - }{ - { - name: "empty output in decodeFn does fail", - fields: fields{ - options: opt, - session: makeTestingSession(), - state: makeTestingState(), - decodeFn: func(b []byte, st *dataChannelState) (*encryptedData, error) { - return &encryptedData{}, nil - }, - encryptEncodeFn: nil, - decryptFn: makeTestingState().dataCipher.decrypt, - }, - args: args{ - encrypted: bytes.Repeat([]byte{0x0a}, 20), - }, - want: []byte{}, - wantErr: errCannotDecrypt, - }, - { - name: "empty encrypted input does fail", - fields: fields{ - options: opt, - session: makeTestingSession(), - state: makeTestingState(), - decodeFn: func(b []byte, st *dataChannelState) (*encryptedData, error) { - return &encryptedData{}, nil - }, - encryptEncodeFn: nil, - decryptFn: makeTestingState().dataCipher.decrypt, - }, - args: args{ - encrypted: []byte{}, - }, - want: []byte{}, - wantErr: errCannotDecrypt, - }, - { - name: "error in decrypt propagates", - fields: fields{ - options: opt, - session: makeTestingSession(), - state: makeTestingState(), - decodeFn: func(b []byte, st *dataChannelState) (*encryptedData, error) { - return &encryptedData{}, nil - }, - encryptEncodeFn: nil, - decryptFn: failingMockDecryptFn, - }, - args: args{ - encrypted: []byte{}, - }, - want: []byte{}, - wantErr: errCannotDecrypt, - }, - { - name: "good decrypt returns expected output", - fields: fields{ - options: opt, - session: makeTestingSession(), - state: makeTestingState(), - decodeFn: func(b []byte, st *dataChannelState) (*encryptedData, error) { - return &encryptedData{}, nil - }, - encryptEncodeFn: nil, - decryptFn: goodMockDecryptFn, - }, - args: args{ - encrypted: []byte{}, - }, - want: []byte("alles ist gut"), - wantErr: nil, - }, - // TODO we already are testing decrypt + encrypt in the crypto module - // so we can mock the decrypt here in the state. - // TODO empty ciphertext raises error - // TODO: Add moar test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - d := &data{ - options: tt.fields.options, - session: tt.fields.session, - state: tt.fields.state, - decodeFn: tt.fields.decodeFn, - encryptEncodeFn: tt.fields.encryptEncodeFn, - decryptFn: tt.fields.decryptFn, - } - got, err := d.decrypt(tt.args.encrypted) - if !errors.Is(err, tt.wantErr) { - t.Errorf("data.decrypt() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("data.decrypt() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_decodeEncryptedPayloadAEAD(t *testing.T) { - - state := makeTestingState() - - goodEncryptedPayload, _ := hex.DecodeString("00000000b3653a842f2b8a148de26375218fb01d31278ff328ff2fc65c4dbf9eb8e67766") - goodDecodeIV, _ := hex.DecodeString("000000006868686868686868") - goodDecodeCipherText, _ := hex.DecodeString("31278ff328ff2fc65c4dbf9eb8e67766b3653a842f2b8a148de26375218fb01d") - goodDecodeAEAD, _ := hex.DecodeString("4800000000000000") - - type args struct { - buf []byte - state *dataChannelState - } - tests := []struct { - name string - args args - want *encryptedData - wantErr bool - }{ - { - "empty", - args{[]byte{}, &dataChannelState{}}, - &encryptedData{}, - true, - }, - { - "too short", - args{bytes.Repeat([]byte{0xff}, 19), &dataChannelState{}}, - &encryptedData{}, - true, - }, - { - "good decode", - args{goodEncryptedPayload, state}, - &encryptedData{ - iv: goodDecodeIV, - ciphertext: goodDecodeCipherText, - aead: goodDecodeAEAD, - }, - false, - }, - // TODO: Add moar test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := decodeEncryptedPayloadAEAD(tt.args.buf, tt.args.state) - if (err != nil) != tt.wantErr { - t.Errorf("decodeEncryptedPayloadAEAD() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("decodeEncryptedPayloadAEAD() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_decodeEncryptedPayloadNonAEAD(t *testing.T) { - - goodInput, _ := hex.DecodeString("fdf9b069b2e5a637fa7b5c9231166ea96307e4123031323334353637383930313233343581e4878c5eec602c2d2f5a95139c84af") - iv, _ := hex.DecodeString("30313233343536373839303132333435") - ciphertext, _ := hex.DecodeString("81e4878c5eec602c2d2f5a95139c84af") - - type args struct { - buf []byte - state *dataChannelState - } - tests := []struct { - name string - args args - want *encryptedData - wantErr error - }{ - { - name: "empty", - args: args{[]byte{}, &dataChannelState{}}, - want: &encryptedData{}, - wantErr: errBadInput, - }, - { - name: "too short", - args: args{bytes.Repeat([]byte{0xff}, 27), &dataChannelState{}}, - want: &encryptedData{}, - wantErr: errBadInput, - }, - { - name: "nil state should fail", - args: args{goodInput, nil}, - want: &encryptedData{}, - wantErr: errBadInput, - }, - { - name: "empty state.dataCipher should fail", - args: args{goodInput, &dataChannelState{}}, - want: &encryptedData{}, - wantErr: errBadInput, - }, - { - name: "good decode", - args: args{goodInput, makeTestingStateNonAEADReversed()}, - want: &encryptedData{ - iv: iv, - ciphertext: ciphertext, - }, - wantErr: nil, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := decodeEncryptedPayloadNonAEAD(tt.args.buf, tt.args.state) - if !errors.Is(err, tt.wantErr) { - t.Errorf("decodeEncryptedPayloadNonAEAD() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !bytes.Equal(got.iv, tt.want.iv) { - t.Errorf("decodeEncryptedPayloadNonAEAD().iv = %v, want %v", got.iv, tt.want.iv) - } - if !bytes.Equal(got.ciphertext, tt.want.ciphertext) { - t.Errorf("decodeEncryptedPayloadNonAEAD().iv = %v, want %v", got.iv, tt.want.iv) - } - }) - } -} - -func Test_encryptAndEncodePayloadAEAD(t *testing.T) { - - state := makeTestingState() - padded, _ := doPadding([]byte("hello go tests"), "", state.dataCipher.blockSize()) - - goodEncryptedPayload, _ := hex.DecodeString("48000000000000006ba730fd633b1d5f11478f6f601cb84231278ff328ff2fc65c4dbf9eb8e67766") - - type args struct { - padded []byte - session *session - state *dataChannelState - } - tests := []struct { - name string - args args - want []byte - wantErr bool - }{ - { - "good encrypt", - args{padded, &session{}, state}, - goodEncryptedPayload, - false, - }, - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := encryptAndEncodePayloadAEAD(tt.args.padded, tt.args.session, tt.args.state) - if (err != nil) != tt.wantErr { - t.Errorf("encryptAndEncodePayloadAEAD() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("encryptAndEncodePayloadAEAD() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_encryptAndEncodePayloadNonAEAD(t *testing.T) { - - padded16 := bytes.Repeat([]byte{0xff}, 16) - padded15 := bytes.Repeat([]byte{0xff}, 15) - - // including OP32 header + peerid (v2) - goodEncrypted, _ := hex.DecodeString("48000000fdf9b069b2e5a637fa7b5c9231166ea96307e4123031323334353637383930313233343581e4878c5eec602c2d2f5a95139c84af") - - // we replace the global random function that is used for the iv in, e.g., CBC mode. - randomFn = func(i int) ([]byte, error) { - switch i { - case 16: - return []byte(rnd16), nil - default: - return []byte(rnd32), nil - } - } - - type args struct { - padded []byte - session *session - state *dataChannelState - } - tests := []struct { - name string - args args - want []byte - wantErr error - }{ - { - name: "good encrypt", - args: args{ - padded: padded16, - session: &session{}, - state: makeTestingStateNonAEAD()}, - want: goodEncrypted, - wantErr: nil, - }, - { - name: "badly padded input should fail", - args: args{ - padded: padded15, - session: &session{}, - state: makeTestingStateNonAEAD()}, - want: nil, - wantErr: errCannotEncrypt, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := encryptAndEncodePayloadNonAEAD(tt.args.padded, tt.args.session, tt.args.state) - if !errors.Is(err, tt.wantErr) { - t.Errorf("encryptAndEncodePayloadNonAEAD() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !bytes.Equal(got, tt.want) { - fmt.Println(hex.EncodeToString(got)) - t.Errorf("encryptAndEncodePayloadNonAEAD() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_doCompress(t *testing.T) { - type args struct { - b []byte - opt compression - } - tests := []struct { - name string - args args - want []byte - wantErr error - }{ - { - name: "null compression should not fail", - args: args{}, - want: []byte{}, - wantErr: nil, - }, - { - name: "do nothing by default", - args: args{ - b: []byte{0xde, 0xad, 0xbe, 0xef}, - opt: "", - }, - want: []byte{0xde, 0xad, 0xbe, 0xef}, - wantErr: nil, - }, - { - name: "stub appends the first byte at the end", - args: args{ - b: []byte{0xde, 0xad, 0xbe, 0xef}, - opt: "stub", - }, - want: []byte{0xfb, 0xad, 0xbe, 0xef, 0xde}, - wantErr: nil, - }, - { - name: "lzo-no adds 0xfa preamble", - args: args{ - b: []byte{0xde, 0xad, 0xbe, 0xef}, - opt: "lzo-no", - }, - want: []byte{0xfa, 0xde, 0xad, 0xbe, 0xef}, - wantErr: nil, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := doCompress(tt.args.b, tt.args.opt) - if !errors.Is(err, tt.wantErr) { - t.Errorf("maybeAddCompressStub() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !bytes.Equal(got, tt.want) { - t.Errorf("maybeAddCompressStub() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_doPadding(t *testing.T) { - type args struct { - b []byte - compress compression - blockSize uint8 - } - tests := []struct { - name string - args args - want []byte - wantErr error - }{ - { - name: "add a whole padding block if len equal to block size, no padding stub", - args: args{ - b: []byte{0x00, 0x01, 0x02, 0x03}, - compress: compression(""), - blockSize: 4, - }, - want: []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x04, 0x04, 0x04}, - wantErr: nil, - }, - { - name: "compression stub with len == blocksize", - args: args{ - b: []byte{0x00, 0x01, 0x02, 0x03}, - compress: compressionStub, - blockSize: 4, - }, - want: []byte{0x00, 0x01, 0x02, 0x03}, - wantErr: nil, - }, - { - name: "compression stub with len < blocksize", - args: args{ - b: []byte{0x00, 0x01, 0xff}, - compress: compressionStub, - blockSize: 4, - }, - want: []byte{0x00, 0x01, 0x02, 0xff}, - wantErr: nil, - }, - { - name: "compression stub with len = blocksize + 1", - args: args{ - b: []byte{0x00, 0x01, 0x02, 0x03, 0xff}, - compress: compressionStub, - blockSize: 4, - }, - want: []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x04, 0x04, 0xff}, - wantErr: nil, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := doPadding(tt.args.b, tt.args.compress, tt.args.blockSize) - if !errors.Is(err, tt.wantErr) { - t.Errorf("doPadding() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("doPadding() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_prependPacketID(t *testing.T) { - type args struct { - p packetID - buf []byte - } - tests := []struct { - name string - args args - want []byte - }{ - { - name: "good append", - args: args{ - packetID(0x01), - []byte{0x07, 0x08}, - }, - want: []byte{0x00, 0x00, 0x00, 0x01, 0x07, 0x08}, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := prependPacketID(tt.args.p, tt.args.buf); !reflect.DeepEqual(got, tt.want) { - t.Errorf("prependPacketID() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_maybeDecompress(t *testing.T) { - - getStateForDecompressTestNonAEAD := func() *dataChannelState { - st := makeTestingStateNonAEAD() - st.remotePacketID = packetID(0x42) - return st - } - - type args struct { - b []byte - st *dataChannelState - opt *Options - } - tests := []struct { - name string - args args - want []byte - wantErr error - }{ - { - name: "nil state should fail", - args: args{ - b: []byte{}, - st: nil, - opt: &Options{}, - }, - want: []byte{}, - wantErr: errBadInput, - }, - { - name: "nil options should fail", - args: args{ - b: []byte{}, - st: makeTestingState(), - opt: nil, - }, - want: []byte{}, - wantErr: errBadInput, - }, - { - name: "aead cipher, no compression", - args: args{ - b: []byte{0xaa, 0xbb, 0xcc}, - st: makeTestingState(), - opt: &Options{}, - }, - want: []byte{0xaa, 0xbb, 0xcc}, - wantErr: nil, - }, - { - name: "aead cipher, no compr", - args: args{ - b: []byte{0xfa, 0xbb, 0xcc}, - st: makeTestingState(), - opt: &Options{Compress: "stub"}, - }, - want: []byte{0xbb, 0xcc}, - wantErr: nil, - }, - { - name: "aead cipher, stub on options and stub on header", - args: args{ - b: []byte{0xfb, 0xbb, 0xcc, 0xdd}, - st: makeTestingState(), - opt: &Options{Compress: "stub"}, - }, - want: []byte{0xdd, 0xbb, 0xcc}, - wantErr: nil, - }, - { - name: "aead cipher, stub, unsupported compression", - args: args{ - b: []byte{0xff, 0xbb, 0xcc}, - st: makeTestingState(), - opt: &Options{Compress: "stub"}, - }, - want: []byte{}, - wantErr: errBadCompression, - }, - { - name: "aead cipher, lzo-no", - args: args{ - b: []byte{0xfa, 0xbb, 0xcc}, - st: makeTestingState(), - opt: &Options{Compress: "lzo-no"}, - }, - want: []byte{0xbb, 0xcc}, - wantErr: nil, - }, - { - name: "aead cipher, compress-no", - args: args{ - b: []byte{0x00, 0xbb, 0xcc}, - st: makeTestingState(), - opt: &Options{Compress: "no"}, - }, - want: []byte{0x00, 0xbb, 0xcc}, - wantErr: nil, - }, - { - name: "non-aead cipher, stub", - args: args{ - b: []byte{0x00, 0x00, 0x00, 0x43, 0x00, 0xbb, 0xcc}, - st: getStateForDecompressTestNonAEAD(), - opt: &Options{Compress: "stub"}, - }, - want: []byte{0xbb, 0xcc}, - wantErr: nil, - }, - { - name: "non-aead cipher, stub, unsupported compression", - args: args{ - b: []byte{0x00, 0x00, 0x00, 0x43, 0x0ff, 0xbb, 0xcc}, - st: getStateForDecompressTestNonAEAD(), - opt: &Options{Compress: "stub"}, - }, - want: []byte{}, - wantErr: errBadCompression, - }, - { - name: "non-aead cipher, compress-no", - args: args{ - b: []byte{0x00, 0x00, 0x00, 0x43, 0x00, 0xbb, 0xcc}, - st: getStateForDecompressTestNonAEAD(), - opt: &Options{Compress: "no"}, - }, - want: []byte{0x00, 0xbb, 0xcc}, - wantErr: nil, - }, - { - name: "non-aead cipher, replay detected (equal remote packetID)", - args: args{ - b: []byte{0x00, 0x00, 0x00, 0x42, 0x00, 0xbb, 0xcc}, - st: getStateForDecompressTestNonAEAD(), - opt: &Options{Compress: "stub"}, - }, - want: []byte{}, - wantErr: errReplayAttack, - }, - { - name: "non-aead cipher, replay detected (lesser remote packetID)", - args: args{ - b: []byte{0x00, 0x00, 0x00, 0x42, 0x00, 0xbb, 0xcc}, - st: getStateForDecompressTestNonAEAD(), - opt: &Options{Compress: "stub"}, - }, - want: []byte{}, - wantErr: errReplayAttack, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := maybeDecompress(tt.args.b, tt.args.st, tt.args.opt) - if !errors.Is(err, tt.wantErr) { - t.Errorf("maybeDecompress() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("maybeDecompress() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_data_ReadPacket(t *testing.T) { - - goodMockDecodeFn := func([]byte, *dataChannelState) (*encryptedData, error) { - d := &encryptedData{ - iv: []byte{0xee}, - ciphertext: []byte("garbledpayload"), - aead: []byte{0xff}, - } - return d, nil - } - - goodMockDecryptFn := func([]byte, *encryptedData) ([]byte, error) { - return []byte("alles ist gut"), nil - } - - type fields struct { - options *Options - state *dataChannelState - decryptFn func([]byte, *encryptedData) ([]byte, error) - decodeFn func([]byte, *dataChannelState) (*encryptedData, error) - } - type args struct { - p *packet - } - tests := []struct { - name string - fields fields - args args - want []byte - wantErr error - }{ - { - name: "good decrypt using mocked decrypt fn and decode fn", - fields: fields{ - options: makeTestingOptions(t, "AES-128-GCM", "sha1"), - state: makeTestingState(), - decryptFn: goodMockDecryptFn, - decodeFn: goodMockDecodeFn, - }, - args: args{&packet{ - opcode: pDataV1, - payload: []byte("garbled")}, - }, - want: []byte("alles ist gut"), - wantErr: nil, - }, - // TODO panic when call to DecodeEncryptedPayload - // TODO error if empty payload - // TODO make sure decompress fn is called? - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - d := &data{ - options: tt.fields.options, - state: tt.fields.state, - decryptFn: tt.fields.decryptFn, - decodeFn: tt.fields.decodeFn, - } - got, err := d.ReadPacket(tt.args.p) - if !errors.Is(err, tt.wantErr) { - t.Errorf("data.ReadPacket() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("data.ReadPacket() = %v, want %v", got, tt.want) - } - }) - } -} - -// we'll use a mocked net.Conn for WritePacket - -func makeTestingConnForWrite(network, addr string, n int) net.Conn { - mockAddr := &mocks.Addr{} - mockAddr.MockString = func() string { - return addr - } - mockAddr.MockNetwork = func() string { - return network - } - - mockConn := &mocks.Conn{} - mockConn.MockLocalAddr = func() net.Addr { - return mockAddr - } - mockConn.MockWrite = func([]byte) (int, error) { - return n, nil - } - mockConn.MockRead = func([]byte) (int, error) { - return n, nil - } - return mockConn -} - -func Test_data_WritePacket(t *testing.T) { - opt := &Options{} - - goodMockEncodedEncryptFn := func([]byte, *session, *dataChannelState) ([]byte, error) { - return []byte("alles ist garbled gut"), nil - } - - type fields struct { - options *Options - // session is only used for NonAEAD encryption - session *session - state *dataChannelState - encryptEncodeFn func([]byte, *session, *dataChannelState) ([]byte, error) - } - type args struct { - conn net.Conn - payload []byte - } - tests := []struct { - name string - fields fields - args args - want int - wantErr error - }{ - { - name: "good write, aead encryption", - fields: fields{ - options: opt, - session: nil, - state: makeTestingState(), - encryptEncodeFn: goodMockEncodedEncryptFn, - }, - args: args{ - conn: makeTestingConnForWrite("udp", "10.0.42.1", 42), - payload: []byte("hello test"), - }, - want: 42, - wantErr: nil, - }, - - // TODO: Add moar test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - d := &data{ - options: tt.fields.options, - session: tt.fields.session, - state: tt.fields.state, - encryptEncodeFn: tt.fields.encryptEncodeFn, - } - got, err := d.WritePacket(tt.args.conn, tt.args.payload) - if !errors.Is(err, tt.wantErr) { - t.Errorf("data.WritePacket() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("data.WritePacket() = %v, want %v", got, tt.want) - } - }) - } -} - -// Regression test for MIV-01-003 -func Test_Crash_EncryptAndEncodePayload(t *testing.T) { - opt := &Options{} - st := &dataChannelState{ - hash: sha1.New, - cipherKeyLocal: *(*keySlot)(bytes.Repeat([]byte{0x65}, 64)), - cipherKeyRemote: *(*keySlot)(bytes.Repeat([]byte{0x66}, 64)), - hmacKeyLocal: *(*keySlot)(bytes.Repeat([]byte{0x67}, 64)), - hmacKeyRemote: *(*keySlot)(bytes.Repeat([]byte{0x68}, 64)), - } - a := &data{ - options: opt, - session: makeTestingSession(), - state: st, - decodeFn: nil, - encryptEncodeFn: func(b []byte, s *session, st *dataChannelState) ([]byte, error) { - return []byte{}, nil - }, - } - a.EncryptAndEncodePayload(nil, a.state) -} - -// Regression test for MIV-01-004 -func Test_Crash_doPadding(t *testing.T) { - arr := []byte{} - doPadding(arr, "stub", 16) -} - -// Regression test for MIV-01-004 -func Test_Crash_EncryptAndEncodePayload_Zero_Len_Array(t *testing.T) { - opt := &Options{} - st := &dataChannelState{ - hash: sha1.New, - cipherKeyLocal: *(*keySlot)(bytes.Repeat([]byte{0x65}, 64)), - cipherKeyRemote: *(*keySlot)(bytes.Repeat([]byte{0x66}, 64)), - hmacKeyLocal: *(*keySlot)(bytes.Repeat([]byte{0x67}, 64)), - hmacKeyRemote: *(*keySlot)(bytes.Repeat([]byte{0x68}, 64)), - } - a := &data{ - options: opt, - session: makeTestingSession(), - state: st, - decodeFn: nil, - encryptEncodeFn: func(b []byte, s *session, st *dataChannelState) ([]byte, error) { - if len(b) == 0 { - return nil, fmt.Errorf("should not receive zero len") - } - return []byte{}, nil - }, - } - _, err := a.EncryptAndEncodePayload([]byte{}, a.state) - if err == nil || !errors.Is(err, errCannotEncrypt) { - t.Error("should not fail with zero len") - } -} - -func base64Decode(str string) (string, bool) { - data, err := base64.StdEncoding.DecodeString(str) - if err != nil { - return "", true - } - return string(data), false -} - -// Regression test for MIV-01-004 -func Test_Crash_DecodeEncryptedPayload_Too_Short(t *testing.T) { - input, _ := base64Decode("////////mv//////////////////////////xxk=") - opt := &Options{} - type args struct { - encrypted []byte - dcs *dataChannelState - } - a := &data{ - options: opt, - session: makeTestingSession(), - state: makeTestingState(), - decodeFn: nil, - encryptEncodeFn: func(b []byte, s *session, st *dataChannelState) ([]byte, error) { - return []byte{}, nil - }, - } - a.decodeFn = decodeEncryptedPayloadNonAEAD - b := &args{[]byte(input), makeTestingState()} - a.DecodeEncryptedPayload(b.encrypted, b.dcs) -} diff --git a/vpn/dialer.go b/vpn/dialer.go deleted file mode 100644 index ffb43b8d..00000000 --- a/vpn/dialer.go +++ /dev/null @@ -1,217 +0,0 @@ -package vpn - -// -// This file contains dialer types and functions that allow transparent use of -// an OpenVPN connection. -// - -import ( - "context" - "fmt" - "net" - "net/netip" - "sync" - "time" - - "golang.zx2c4.com/wireguard/tun" - "golang.zx2c4.com/wireguard/tun/netstack" -) - -var ( - openDNSPrimary = "208.67.222.222" - openDNSSecondary = "208.67.220.220" -) - -// A TunDialer contains options for obtaining a network connection tunneled -// through an OpenVPN endpoint. It uses a userspace gVisor virtual device over -// the raw OpenVPN tunnel. -// -// You need to be careful and create only one instance of TunDialer for each -// Client, since the underlying virtual device will connect both ends of the -// tunnel. -type TunDialer struct { - // Dialer will be passed to the underlying Client constructor. - Dialer DialerContext - client *Client - ns1 string - ns2 string - skipDeviceSetup bool - device *device - tun *netstack.Net - mu sync.Mutex - - // dependency injection to test client start - clientStartFn func(context.Context) error -} - -// NewTunDialer creates a new Dialer with the default nameservers (OpenDNS). -func NewTunDialer(client *Client) *TunDialer { - td := &TunDialer{ - client: client, - ns1: openDNSPrimary, - ns2: openDNSSecondary, - } - return td -} - -// NewTunDialerWithNameservers creates a new TunDialer with the passed nameservers. -// You probably want to pass the nameservers for your own VPN service here. -func NewTunDialerWithNameservers(client *Client, ns1, ns2 string) *TunDialer { - td := &TunDialer{ - client: client, - ns1: ns1, - ns2: ns2, - } - return td -} - -// StartNewTunDialerFromOptions creates a new Dialer directly from an Options -// object. It also starts the underlying client. -func StartNewTunDialerFromOptions(opt *Options, dialer DialerContext) (*TunDialer, error) { - if dialer == nil { - return nil, fmt.Errorf("%w: nil dialer", errBadInput) - } - client := NewClientFromOptions(opt) - client.Dialer = dialer - err := client.Start(context.Background()) - if err != nil { - defer client.Close() - return nil, err - } - td := &TunDialer{ - client: client, - ns1: openDNSPrimary, - ns2: openDNSSecondary, - } - return td, nil -} - -// Dial connects to the address on the named network, via the OpenVPN endpoint -// in the Client that this TunDialer is initialized with. -// -// The return value implements the net.Conn interface, but it is a socket created -// on a virtual device, using gVisor userspace network stack. This means that the -// kernel only sees UDP packets with an encrypted payload. -// -// The addresses are resolved via the OpenVPN tunnel too, and against the nameservers -// configured in the dialer. This feature uses wireguard's little custom DNS client -// implementation. -// -// Dial calls DialContext with the background context. See documentation of -// DialContext for more details. -// -// Known networks are "tcp", "tcp4" (IPv4-only), "tcp6" (IPv6-only), -// "udp", "udp4" (IPv4-only), "udp6" (IPv6-only), "ping4", "ping6". -func (td *TunDialer) Dial(network, address string) (net.Conn, error) { - ctx := context.Background() - tnet, err := td.createNetTUN(ctx) - if err != nil { - return nil, err - } - return tnet.Dial(network, address) -} - -// DialContext connects to the address on the named network using -// the provided context. -// -// The underlying tun is created just once upon successive invocations of -// DialContext. -func (td *TunDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - td.mu.Lock() - defer td.mu.Unlock() - if td.tun == nil { - tnet, err := td.createNetTUN(ctx) - if err != nil { - return nil, err - } - td.tun = tnet - } - return td.tun.DialContext(ctx, network, address) -} - -// DialTimeout acts like Dial but takes a timeout. -func (td *TunDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) { - conn, err := td.Dial(network, address) - if err != nil { - return nil, err - } - err = conn.SetReadDeadline(time.Now().Add(timeout)) - return conn, err -} - -func (td *TunDialer) createNetTUN(ctx context.Context) (*netstack.Net, error) { - localIP := td.client.LocalAddr().String() - - // create a virtual device in userspace, courtesy of wireguard-go - tun, tnet, err := netstack.CreateNetTUN( - []netip.Addr{netip.Addr(netip.MustParseAddr(localIP))}, - []netip.Addr{ - netip.MustParseAddr(td.ns1), - netip.MustParseAddr(td.ns2)}, - td.client.tunInfo.mtu-100, - ) - // TODO(https://github.com/ooni/minivpn/issues/26): - // we cannot use the tun-mtu that the remote advertises, so we subtract - // a "safety" margin for the time being. - - if err != nil { - return nil, err - } - - // connect the virtual device to our openvpn tunnel - if !td.skipDeviceSetup { - dev := &device{tun, td.client} - dev.Up() - td.device = dev - } - return tnet, nil -} - -// device contains the two halves of the tunnel that we are connecting in our -// toy implementation: the virtual tun device that is handled by netstack, and -// the vpn.Client (that satisfies a net.Conn) that writes and reads to sockets -// provided by the kernel. -type device struct { - tun tun.Device - vpn net.Conn -} - -// Up spawns two goroutines that communicate the two halves of a device. -// TODO(https://github.com/ooni/minivpn/issues/27): we probably want a way of -// shutting them down too. -func (d *device) Up() { - go func() { - b := make([]byte, 4096) - bufs := [][]byte{b} - sizes := []int{4096} - for { - n, err := d.tun.Read(bufs, sizes, 0) // zero offset - if err != nil { - logger.Errorf("tun read error: %v", err) - break - } - _, err = d.vpn.Write(b[0:n]) - if err != nil { - logger.Errorf("vpn write error: %v", err) - break - } - - } - }() - go func() { - b := make([]byte, 4096) - for { - n, err := d.vpn.Read(b) - if err != nil { - logger.Errorf("vpn read error: %v", err) - break - } - - _, err = d.tun.Write([][]byte{b[0:n]}, 0) // zero offset - if err != nil { - logger.Errorf("tun write error: %v", err) - break - } - } - }() -} diff --git a/vpn/dialer_test.go b/vpn/dialer_test.go deleted file mode 100644 index ed5d361c..00000000 --- a/vpn/dialer_test.go +++ /dev/null @@ -1,449 +0,0 @@ -package vpn - -import ( - "bytes" - "context" - "errors" - "net" - "net/netip" - "reflect" - "testing" - "time" - - "github.com/ooni/minivpn/vpn/mocks" - tls "github.com/refraction-networking/utls" - "golang.zx2c4.com/wireguard/tun/netstack" -) - -func makeTestingClient(opt *Options) *Client { - client := &Client{Opts: opt} - client.conn = makeTestingConnForHandshake("udp", "10.0.0.1", 42) - client.tunInfo = &tunnelInfo{ip: "10.0.0.1", mtu: 1500} - client.mux = &mockMuxerForClient{} - return client -} - -func TestNewTunDialer(t *testing.T) { - opt := makeTestingOptions(t, "AES-128-GCM", "sha512") - mockClient := makeTestingClient(opt) - type args struct { - client *Client - } - tests := []struct { - name string - args args - want *TunDialer - }{ - { - name: "get dialer ok", - args: args{ - client: mockClient, - }, - want: &TunDialer{ - client: mockClient, - ns1: openDNSPrimary, - ns2: openDNSSecondary, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := NewTunDialer(tt.args.client) - if tt.want != nil && got == nil { - t.Errorf("expected non-nil result") - return - } - if got.client == nil { - t.Errorf("client should not be nil") - return - } - if !reflect.DeepEqual(got.client.Opts, tt.want.client.Opts) { - t.Errorf("NewTunDialerFromOptions() = %v, want %v", got, tt.want) - return - } - }) - } -} - -func TestNewTunDialerWithNameservers(t *testing.T) { - opt := makeTestingOptions(t, "AES-128-GCM", "sha512") - mockClient := makeTestingClient(opt) - type args struct { - client *Client - ns1 string - ns2 string - } - tests := []struct { - name string - args args - want *TunDialer - }{ - { - name: "get tundialer with passed nameservers", - args: args{ - client: mockClient, - ns1: "8.8.8.8", - ns2: "8.8.4.4", - }, - want: &TunDialer{ - client: mockClient, - ns1: "8.8.8.8", - ns2: "8.8.4.4", - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := NewTunDialerWithNameservers(tt.args.client, tt.args.ns1, tt.args.ns2) - if tt.want != nil && got == nil { - t.Errorf("expected non-nil result") - return - } - if got.client == nil { - t.Errorf("client should not be nil") - return - } - if got.ns1 != tt.want.ns1 { - t.Errorf("NewTunDialerWithNameservers() ns1 = %v, want %v", got.ns1, tt.want.ns1) - } - if got.ns2 != tt.want.ns2 { - t.Errorf("NewTunDialerWithNameservers() ns2 = %v, want %v", got.ns2, tt.want.ns2) - } - }) - } -} - -type mockDialer struct { - called bool -} - -func (d *mockDialer) DialContext(ctx context.Context, a, b string) (net.Conn, error) { - d.called = true - conn := makeTestingConnForHandshake("udp", "10.0.0.0", 42) - return conn, nil -} - -func TestStartNewTunDialerFromOptions(t *testing.T) { - opt := makeTestingOptions(t, "AES-128-GCM", "sha512") - - type args struct { - opt *Options - dialer *mockDialer - } - tests := []struct { - name string - args args - want *TunDialer - wantErr error - }{ - { - name: "get tundialer from options calls start and fails on tls handshake", - args: args{ - opt: opt, - dialer: &mockDialer{}, - }, - want: nil, - // TODO(ainghazal): I'd like to return nil here, but that would force - // me to leak even more internals from the client - // initialization. maybe it's not a good idea to have a - // convenience function that returns an started client after all? - wantErr: ErrBadTLSHandshake, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := StartNewTunDialerFromOptions(tt.args.opt, tt.args.dialer) - if !errors.Is(err, tt.wantErr) { - t.Errorf("expected error %v, got %v", tt.wantErr, err) - return - } - if tt.want == nil && got == nil { - return - } - if tt.want != nil && got != nil { - t.Errorf("expected non-nil result") - return - } - if tt.want != nil || got.client == nil { - t.Errorf("client should not be nil") - return - } - if !tt.args.dialer.called { - t.Errorf("the mock Dialer has not been called") - return - } - if !reflect.DeepEqual(got.client.Opts, tt.want.client.Opts) { - t.Errorf("NewTunDialerFromOptions() = %v, want %v", got, tt.want) - return - } - }) - } -} - -type mockedDialerContext struct{} - -func (md *mockedDialerContext) DialContext(context.Context, string, string) (net.Conn, error) { - conn := makeTestingConnForHandshake("udp", "10.0.0.0", 42) - return conn, nil -} - -func makeTestingConnForReadWrite(network, addr string, n int) net.Conn { - mockAddr := &mocks.Addr{} - mockAddr.MockString = func() string { - return addr - } - mockAddr.MockNetwork = func() string { - return network - } - - mockConn := &mocks.Conn{} - mockConn.MockLocalAddr = func() net.Addr { - return mockAddr - } - mockConn.MockWrite = func([]byte) (int, error) { - return n, nil - } - mockConn.MockRead = func(b []byte) (int, error) { - switch mockConn.Count { - case 0: - // control message data (to load remote key) - p := []byte{0x00, 0x00, 0x00, 0x00, 0x02} - p = append(p, bytes.Repeat([]byte{0x01}, 70)...) - copy(b[:], p) - mockConn.Count += 1 - return len(p), nil - case 1: - // control message data (pushed options) - p := []byte("PUSH_REPLY,ifconfig 2.2.2.2") - copy(b[:], p) - mockConn.Count += 1 - return len(p), nil - } - - return 0, nil - } - return mockConn -} - -// TODO(https://github.com/ooni/minivpn/issues/28): -// refactor test to use custom dialers -func TestTunDialer_Dial(t *testing.T) { - opt := makeTestingOptions(t, "AES-128-GCM", "sha512") - mockClient := makeTestingClient(opt) - - orig := initTLSFn - defer func() { - initTLSFn = orig - }() - - initTLSFn = func(*session, *certConfig) (*tls.Config, error) { - return &tls.Config{InsecureSkipVerify: true}, nil - } - tlsHandshakeFn = func(tc *controlChannelTLSConn, tconf *tls.Config) (net.Conn, error) { - conn := makeTestingConnForReadWrite("udp", "10.1.1.1", 42) - return conn, nil - } - - type fields struct { - Dialer DialerContext - client *Client - ns1 string - ns2 string - skipDeviceSetup bool - } - type args struct { - network string - address string - } - tests := []struct { - name string - fields fields - args args - want net.Conn - wantErr error - }{ - { - name: "dial ok with mocked dialFn", - fields: fields{ - client: mockClient, - ns1: "8.8.8.8", - ns2: "8.8.4.4", - skipDeviceSetup: true, - }, - args: args{ - network: "udp", - address: "10.0.88.88:443", - }, - wantErr: nil, - }, - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - td := TunDialer{ - client: tt.fields.client, - ns1: tt.fields.ns1, - ns2: tt.fields.ns2, - skipDeviceSetup: tt.fields.skipDeviceSetup, - } - conn, err := td.Dial(tt.args.network, tt.args.address) - if !errors.Is(err, tt.wantErr) { - t.Errorf("TunDialer.Dial() error = %v, wantErr %v", err, tt.wantErr) - return - } - conn.Close() - }) - } -} - -// TODO(https://github.com/ooni/minivpn/issues/28): -// refactor test to use custom dialers -func TestTunDialer_DialTimeout(t *testing.T) { - opt := makeTestingOptions(t, "AES-128-GCM", "sha512") - mockClient := makeTestingClient(opt) - - orig := initTLSFn - defer func() { - initTLSFn = orig - }() - initTLSFn = func(*session, *certConfig) (*tls.Config, error) { - return &tls.Config{InsecureSkipVerify: true}, nil - } - tlsHandshakeFn = func(tc *controlChannelTLSConn, tconf *tls.Config) (net.Conn, error) { - conn := makeTestingConnForReadWrite("udp", "10.1.1.1", 42) - return conn, nil - } - type fields struct { - client *Client - ns1 string - ns2 string - skipDeviceSetup bool - } - type args struct { - network string - address string - timeout time.Duration - } - tests := []struct { - name string - fields fields - args args - want net.Conn - wantErr error - }{ - { - name: "dial ok with mocked dialFn", - fields: fields{ - client: mockClient, - ns1: "8.8.8.8", - ns2: "8.8.4.4", - skipDeviceSetup: true, - }, - args: args{ - network: "udp", - address: "10.0.88.88:443", - timeout: time.Second, - }, - wantErr: nil, - }, - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - td := TunDialer{ - client: tt.fields.client, - ns1: tt.fields.ns1, - ns2: tt.fields.ns2, - skipDeviceSetup: true, - } - conn, err := td.DialTimeout(tt.args.network, tt.args.address, tt.args.timeout) - if !errors.Is(err, tt.wantErr) { - t.Errorf("TunDialer.DialTimeout() error = %v, wantErr %v", err, tt.wantErr) - return - } - conn.Close() - }) - } -} - -// TODO(https://github.com/ooni/minivpn/issues/28): -// refactor test to use custom dialers -func TestTunDialer_DialContext(t *testing.T) { - opt := makeTestingOptions(t, "AES-128-GCM", "sha512") - mockClient := makeTestingClient(opt) - - orig := initTLSFn - defer func() { - initTLSFn = orig - }() - initTLSFn = func(*session, *certConfig) (*tls.Config, error) { - return &tls.Config{InsecureSkipVerify: true}, nil - } - tlsHandshakeFn = func(tc *controlChannelTLSConn, tconf *tls.Config) (net.Conn, error) { - conn := makeTestingConnForReadWrite("udp", "10.1.1.1", 42) - return conn, nil - } - - type fields struct { - client *Client - ns1 string - ns2 string - skipDeviceSetup bool - } - type args struct { - ctx context.Context - network string - address string - } - tests := []struct { - name string - fields fields - args args - want net.Conn - wantErr error - }{ - { - name: "dial ok with mocked dialer", - fields: fields{ - client: mockClient, - ns1: "8.8.8.8", - ns2: "8.8.4.4", - skipDeviceSetup: true, - }, - args: args{ - ctx: context.Background(), - network: "udp", - address: "10.0.88.88:443", - }, - wantErr: nil, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - td := TunDialer{ - client: tt.fields.client, - ns1: tt.fields.ns1, - ns2: tt.fields.ns2, - skipDeviceSetup: true, - } - conn, err := td.DialContext(tt.args.ctx, tt.args.network, tt.args.address) - if !errors.Is(err, tt.wantErr) { - t.Errorf("TunDialer.DialContext() error = %v, wantErr %v", err, tt.wantErr) - return - } - conn.Close() - }) - } -} - -func Test_device_Up(t *testing.T) { - tun, _, _ := netstack.CreateNetTUN( - []netip.Addr{netip.MustParseAddr("10.0.0.1")}, - []netip.Addr{ - netip.MustParseAddr("8.8.8.8"), - netip.MustParseAddr("4.4.4.4")}, - 1500) - vpn := makeTestinConnFromNetwork("udp") - d := device{tun: tun, vpn: vpn} - d.Up() -} diff --git a/vpn/doc.go b/vpn/doc.go deleted file mode 100644 index 1818a6ad..00000000 --- a/vpn/doc.go +++ /dev/null @@ -1,21 +0,0 @@ -// -// Package vpn contains the API to create an OpenVPN client that can connect to -// a remote OpenVPN endpoint and provide you with a tunnel where to send packets. -// -// The recommended way to use this package is to use the TunDialer -// constructors, that gives you a way to transparently Dial() and get TCP or -// UDP sockets over a virtual gVisor interface, that uses the VPN tunnel as an -// underlying transport. For examples, see the `proxy' implementation in the -// `extras` package. -// -// If you need to write raw packets to the tunnel instead, you can construct -// and use a `Client` object directly. `Client` is an implementer of the -// `net.Conn` interface. You need to `Start()` the Client before you can Read -// or Write packets to it. An example of this can be found in the `extras/ping` -// package. -// -// Reads and Writes to the Client tunnel object are -// actually reading and writing to the initialized Data channel of the Client. -// Any incoming packet while reading that is not an OpenVPN data packet will be -// dispatched accordingly. -package vpn diff --git a/vpn/events.go b/vpn/events.go deleted file mode 100644 index 1b6e51ec..00000000 --- a/vpn/events.go +++ /dev/null @@ -1,20 +0,0 @@ -package vpn - -// -// Catalog of the events that can be emitted, so that users of the library can -// observe progress of the client's bootstrap. -// - -// The events to be emitted. This is treated as an uint8, so if we ever go past -// 255 events we need to accomodate the data type. -const ( - EventReady = iota - EventDialDone - EventHandshake - EventReset - EventTLSConn - EventTLSHandshake - EventTLSHandshakeDone - EventDataInitDone - EventHandshakeDone -) diff --git a/vpn/logger.go b/vpn/logger.go deleted file mode 100644 index 90066229..00000000 --- a/vpn/logger.go +++ /dev/null @@ -1,82 +0,0 @@ -package vpn - -// -// Logging capabilities. -// - -import ( - "log" - "os" -) - -// logger uses an implementation from the standard library in case the -// binary does not set its own. -var logger Logger = &defaultLogger{} - -// Logger is compatible with github.com/apex/log -type Logger interface { - // Debug emits a debug message. - Debug(msg string) - - // Debugf formats and emits a debug message. - Debugf(format string, v ...interface{}) - - // Info emits an informational message. - Info(msg string) - - // Infof formats and emits an informational message. - Infof(format string, v ...interface{}) - - // Warn emits a warning message. - Warn(msg string) - - // Warnf formats and emits a warning message. - Warnf(format string, v ...interface{}) - - // Error emits an error message - Error(msg string) - - // Errorf formats and emits an error message. - Errorf(format string, v ...interface{}) -} - -// defaultLogger uses the standard log package for logs in case -// the user does not provide a custom Log implementation. - -type defaultLogger struct{} - -func (dl *defaultLogger) Debug(msg string) { - if os.Getenv("EXTRA_DEBUG") == "1" { - log.Println(msg) - } -} - -func (dl *defaultLogger) Debugf(format string, v ...interface{}) { - if os.Getenv("EXTRA_DEBUG") == "1" { - log.Printf(format, v...) - } -} - -func (dl *defaultLogger) Info(msg string) { - log.Printf("info : %s\n", msg) -} - -func (dl *defaultLogger) Infof(format string, v ...interface{}) { - log.Printf("info : "+format, v...) -} - -func (dl *defaultLogger) Warn(msg string) { - log.Printf("warn: %s\n", msg) -} - -func (dl *defaultLogger) Warnf(format string, v ...interface{}) { - log.Printf("warn: "+format, v...) -} - -func (dl *defaultLogger) Error(msg string) { - log.Printf("error: %s\n", msg) -} - -func (dl *defaultLogger) Errorf(format string, v ...interface{}) { - log.Printf("error: "+format, v...) -} diff --git a/vpn/logger_test.go b/vpn/logger_test.go deleted file mode 100644 index b80a8f2d..00000000 --- a/vpn/logger_test.go +++ /dev/null @@ -1,19 +0,0 @@ -package vpn - -import ( - "os" - "testing" -) - -func TestDefaultLoggerDoesNotFail(t *testing.T) { - os.Setenv("EXTRA_DEBUG", "1") - logger := defaultLogger{} - logger.Debug("foo") - logger.Debugf("%s", "foo") - logger.Info("foo") - logger.Infof("%s", "foo") - logger.Warn("foo") - logger.Warnf("%s", "foo") - logger.Error("foo") - logger.Errorf("%s", "foo") -} diff --git a/vpn/muxer.go b/vpn/muxer.go deleted file mode 100644 index b480f63e..00000000 --- a/vpn/muxer.go +++ /dev/null @@ -1,553 +0,0 @@ -package vpn - -import ( - "bytes" - "context" - "encoding/hex" - "errors" - "fmt" - "log" - "net" -) - -// -// OpenVPN Multiplexer -// - -var ( - ErrBadHandshake = errors.New("bad vpn handshake") - ErrBadDataHandshake = errors.New("bad data handshake") -) - -/* - The vpnMuxer interface represents the VPN transport multiplexer. - - One important limitation of the current implementation at this moment is that - the processing of incoming packets needs to be driven by reads from the user of - the library. This means that if you don't do reads during some time, any packets - on the control channel that the server sends us (e.g., openvpn-pings) will not - be processed (and so, not acknowledged) until triggered by a muxer.Read(). - - From the original documentation: - https://community.openvpn.net/openvpn/wiki/SecurityOverview - - "OpenVPN multiplexes the SSL/TLS session used for authentication and key - exchange with the actual encrypted tunnel data stream. OpenVPN provides the - SSL/TLS connection with a reliable transport layer (as it is designed to - operate over). The actual IP packets, after being encrypted and signed with an - HMAC, are tunnelled over UDP without any reliability layer. So if --proto udp - is used, no IP packets are tunneled over a reliable transport, eliminating the - problem of reliability-layer collisions -- Of course, if you are tunneling a - TCP session over OpenVPN running in UDP mode, the TCP protocol itself will - provide the reliability layer." - - SSL/TLS -> Reliability Layer -> \ - --tls-auth HMAC \ - \ - > Multiplexer ----> UDP/TCP - / Transport - IP Encrypt and HMAC / - Tunnel -> using OpenSSL EVP --> / - Packets interface. - -"This model has the benefit that SSL/TLS sees a reliable transport layer while -the IP packet forwarder sees an unreliable transport layer -- exactly what both -components want to see. The reliability and authentication layers are -completely independent of one another, i.e. the sequence number is embedded -inside the HMAC-signed envelope and is not used for authentication purposes." - -*/ - -// muxer implements vpnMuxer -type muxer struct { - - // A net.Conn that has access to the "wire" transport. this can - // represent an UDP/TCP socket, or a net.Conn coming from a Pluggable - // Transport etc. - conn net.Conn - - // After completing the TLS handshake, we get a tls transport that implements - // net.Conn. All the control packets from that moment on are read from - // and written to the tls Conn. - tls net.Conn - - // control and data are the handlers for the control and data channels. - // they implement the methods needed for the handshake and handling of - // packets. - control controlHandler - data dataHandler - - // bufReader is used to buffer data channel reads. We only write to - // this buffer when we have correctly decrypted an incoming - bufReader *bytes.Buffer - - // Mutable state tied to a concrete session. - session *session - - // Mutable state tied to a particular vpn run. - tunnel *tunnelInfo - - // Options are OpenVPN options that come from parsing a subset of the OpenVPN - // configuration directives, plus some non-standard config directives. - options *Options - - // eventListener is a channel to which Event_*- will be sent if - // the channel is not nil. - eventListener chan uint8 - - failed bool -} - -var _ vpnMuxer = &muxer{} // Ensure that we implement the vpnMuxer interface. - -// -// Interfaces -// - -// vpnMuxer contains all the behavior expected by the muxer. -type vpnMuxer interface { - Handshake(ctx context.Context) error - Reset(net.Conn, *session) error - InitDataWithRemoteKey() error - SetEventListener(chan uint8) - Write([]byte) (int, error) - Read([]byte) (int, error) -} - -// controlHandler manages the control "channel". -type controlHandler interface { - SendHardReset(net.Conn, *session) error - ParseHardReset([]byte) (sessionID, error) - SendACK(net.Conn, *session, packetID) error - PushRequest() []byte - ReadPushResponse([]byte) map[string][]string - ControlMessage(*session, *Options) ([]byte, error) - ReadControlMessage([]byte) (*keySource, string, error) -} - -// dataHandler manages the data "channel". -type dataHandler interface { - SetupKeys(*dataChannelKey) error - SetPeerID(int) error - WritePacket(net.Conn, []byte) (int, error) - ReadPacket(*packet) ([]byte, error) - DecodeEncryptedPayload([]byte, *dataChannelState) (*encryptedData, error) - EncryptAndEncodePayload([]byte, *dataChannelState) ([]byte, error) -} - -// -// muxer initialization -// - -// muxFactory acepts a net.Conn, a pointer to an Options object, and another -// pointer to a tunnelInfo object, and returns a vpnMuxer and an error if it -// could not be initialized. This type is used to be able to mock a muxer while -// testing the Client. -type muxFactory func(conn net.Conn, options *Options, tunnel *tunnelInfo) (vpnMuxer, error) - -// newMuxerFromOptions returns a configured muxer, and any error if the -// operation could not be completed. -func newMuxerFromOptions(conn net.Conn, options *Options, tunnel *tunnelInfo) (vpnMuxer, error) { - control := &control{} - session, err := newSession() - if err != nil { - return &muxer{}, err - } - data, err := newDataFromOptions(options, session) - if err != nil { - return &muxer{}, err - } - br := bytes.NewBuffer(nil) - - m := &muxer{ - conn: conn, - session: session, - options: options, - control: control, - data: data, - tunnel: tunnel, - bufReader: br, - } - return m, nil -} - -// -// observability -// - -// SetEvenSetEventListener assigns the passed channel as the event listener for -// this muxer. -func (m *muxer) SetEventListener(el chan uint8) { - m.eventListener = el -} - -// emit sends the passed stage into any configured EventListener -func (m *muxer) emit(stage uint8) { - select { - case m.eventListener <- stage: - default: - // do not deliver - } -} - -// -// muxer handshake -// - -// Handshake performs the OpenVPN "handshake" operations serially. Accepts a -// Context, and itt returns any error that is raised at any of the underlying -// steps. -func (m *muxer) Handshake(ctx context.Context) (err error) { - errch := make(chan error, 1) - go func() { - errch <- m.handshake() - }() - select { - case err = <-errch: - case <-ctx.Done(): - err = ctx.Err() - } - return -} - -func (m *muxer) handshake() error { - - // 1. control channel sends reset, parse response. - - m.emit(EventReset) - - if err := m.Reset(m.conn, m.session); err != nil { - return fmt.Errorf("%w: %s", ErrBadHandshake, err) - - } - - // 2. TLS handshake. - - // TODO(ainghazal): move the initialization step to an early phase and keep a ref in the muxer - if !m.options.hasAuthInfo() { - return fmt.Errorf("%w: %s", errBadInput, "expected certificate or username/password") - } - certCfg, err := newCertConfigFromOptions(m.options) - if err != nil { - return err - } - - tlsConf, err := initTLSFn(m.session, certCfg) - if err != nil { - return fmt.Errorf("%w: %s", ErrBadTLSHandshake, err) - - } - tlsConn, err := newControlChannelTLSConn(m.conn, m.session) - m.emit(EventTLSConn) - - if err != nil { - return fmt.Errorf("%w: %s", ErrBadTLSHandshake, err) - } - - m.emit(EventTLSHandshake) - - tls, err := tlsHandshakeFn(tlsConn, tlsConf) - if err != nil { - return fmt.Errorf("%w: %s", ErrBadTLSHandshake, err) - - } - - m.emit(EventTLSHandshakeDone) - - m.tls = tls - logger.Info("TLS handshake done") - - // 3. data channel init (auth, push, data initialization). - - if err := m.InitDataWithRemoteKey(); err != nil { - return fmt.Errorf("%w: %s", ErrBadDataHandshake, err) - - } - - m.emit(EventDataInitDone) - - logger.Info("VPN handshake done") - return nil -} - -// Reset sends a hard-reset packet to the server, and awaits the server -// confirmation. -func (m *muxer) Reset(conn net.Conn, s *session) error { - if m.control == nil { - return fmt.Errorf("%w:%s", errBadInput, "bad control") - } - if err := m.control.SendHardReset(conn, s); err != nil { - return err - } - - resp, err := readPacket(m.conn) - if err != nil { - return err - } - - remoteSessionID, err := m.control.ParseHardReset(resp) - - // here we could check if we have received a remote session id but - // our session.remoteSessionID is != from all zeros - if err != nil { - return err - } - m.session.RemoteSessionID = remoteSessionID - - logger.Infof("Remote session ID: %x", m.session.RemoteSessionID) - logger.Infof("Local session ID: %x", m.session.LocalSessionID) - - // we assume id is 0, this is the first packet we ack. - // XXX I could parse the real packet id from server instead. this - // _might_ be important when re-keying? - return m.control.SendACK(m.conn, m.session, packetID(0)) -} - -// -// muxer: read and handle packets -// - -// handleIncoming packet reads the next packet available in the underlying -// socket. It returns true if the packet was a data packet; otherwise it will -// process it but return false. -func (m *muxer) handleIncomingPacket(data []byte) (bool, error) { - if m.data == nil { - logger.Errorf("uninitialized muxer") - return false, errBadInput - } - var input []byte - if data == nil { - parsed, err := readPacket(m.conn) - if err != nil { - return false, err - } - input = parsed - } else { - input = data - } - - if isPing(input) { - err := handleDataPing(m.conn, m.data) - if err != nil { - logger.Errorf("cannot handle ping: %s", err.Error()) - } - return false, nil - } - - p, err := parsePacketFromBytes(input) - if err != nil { - logger.Error(err.Error()) - return false, err - } - if p.isACK() { - logger.Warn("muxer: got ACK (ignored)") - return false, err - } - if p.isControl() { - logger.Infof("Got control packet: %d", len(data)) - // Here the server might be requesting us to reset, or to - // re-key (but I keep ignoring that case for now). - // we're doing nothing for now. - fmt.Println(hex.Dump(p.payload)) - return false, err - } - if !p.isData() { - logger.Warnf("unhandled data. (op: %d)", p.opcode) - fmt.Println(hex.Dump(data)) - return false, err - } - - // at this point, the incoming packet should be - // a data packet that needs to be processed - // (decompress+decrypt) - - plaintext, err := m.data.ReadPacket(p) - if err != nil { - logger.Errorf("bad decryption: %s", err.Error()) - // XXX I'm not sure returning false is the right thing to do here. - return false, err - } - - // all good! we write the plaintext into the read buffer. - // the caller is responsible for reading from there. - m.bufReader.Write(plaintext) - return true, nil -} - -// handleDataPing replies to an openvpn-ping with a canned response. -func handleDataPing(conn net.Conn, data dataHandler) error { - log.Println("openvpn-ping, sending reply") - _, err := data.WritePacket(conn, pingPayload) - return err -} - -// readTLSPacket reads a packet over the TLS connection. -func (m *muxer) readTLSPacket() ([]byte, error) { - data := make([]byte, 4096) - _, err := m.tls.Read(data) - return data, err -} - -// readAndLoadRemoteKey reads one incoming TLS packet, and tries to parse the -// response contained in it. If the server response is the right kind of -// packet, it will store the remote key and the parts of the remote options -// that will be of use later. -func (m *muxer) readAndLoadRemoteKey() error { - data, err := m.readTLSPacket() - if err != nil { - return err - } - if !isControlMessage(data) { - return fmt.Errorf("%w: %s", errBadControlMessage, "expected null header") - } - - // Parse the received data: we expect remote key and remote options. - remoteKey, remoteOptStr, err := m.control.ReadControlMessage(data) - if err != nil { - logger.Errorf("cannot parse control message") - return fmt.Errorf("%w: %s", ErrBadHandshake, err) - } - - // Store the remote key. - key, err := m.session.ActiveKey() - if err != nil { - logger.Errorf("cannot get active key") - return fmt.Errorf("%w: %s", ErrBadHandshake, err) - } - err = key.addRemoteKey(remoteKey) - if err != nil { - logger.Errorf("cannot add remote key") - return fmt.Errorf("%w: %s", ErrBadHandshake, err) - } - - // Parse and update the useful fields from the remote options (mtu). - ti := newTunnelInfoFromRemoteOptionsString(remoteOptStr) - m.tunnel.mtu = ti.mtu - return nil -} - -// sendPushRequest sends a push request over the TLS channel. -func (m *muxer) sendPushRequest() (int, error) { - return m.tls.Write(m.control.PushRequest()) -} - -// readPushReply reads one incoming TLS packet, where we expect to find the -// response to our push request. If the server response is the right kind of -// packet, it will store the parts of the pushed options that will be of use -// later. -func (m *muxer) readPushReply() error { - if m.control == nil || m.tunnel == nil { - return fmt.Errorf("%w:%s", errBadInput, "muxer badly initialized") - - } - resp, err := m.readTLSPacket() - if err != nil { - return err - } - - logger.Info("Server pushed options") - - if isBadAuthReply(resp) { - return errBadAuth - } - - if !isPushReply(resp) { - return fmt.Errorf("%w:%s", errBadServerReply, "expected push reply") - } - - optsMap := m.control.ReadPushResponse(resp) - ti := newTunnelInfoFromPushedOptions(optsMap) - - m.tunnel.ip = ti.ip - m.tunnel.gw = ti.gw - m.tunnel.peerID = ti.peerID - - logger.Infof("Tunnel IP: %s", m.tunnel.ip) - logger.Infof("Gateway IP: %s", m.tunnel.gw) - logger.Infof("Peer ID: %d", m.tunnel.peerID) - - return nil -} - -// sendControl message sends a control message over the TLS channel. -func (m *muxer) sendControlMessage() error { - cm, err := m.control.ControlMessage(m.session, m.options) - if err != nil { - return err - } - if _, err := m.tls.Write(cm); err != nil { - return err - } - return nil -} - -// InitDataWithRemoteKey initializes the internal data channel. To do that, it sends a -// control packet, parses the response, and derives the cryptographic material -// that will be used to encrypt and decrypt data through the tunnel. At the end -// of this exchange, the data channel is ready to be used. -func (m *muxer) InitDataWithRemoteKey() error { - - // 1. first we send a control message. - - if err := m.sendControlMessage(); err != nil { - return err - } - - // 2. then we read the server response and load the remote key. - - if err := m.readAndLoadRemoteKey(); err != nil { - return err - } - - // 3. now we can initialize the data channel. - - key0, err := m.session.ActiveKey() - if err != nil { - return err - } - - err = m.data.SetupKeys(key0) - if err != nil { - return err - } - - // 4. finally, we ask the server to push remote options to us. we parse - // them and keep some useful info. - - if _, err := m.sendPushRequest(); err != nil { - return err - } - if err := m.readPushReply(); err != nil { - return err - } - - m.data.SetPeerID(m.tunnel.peerID) - - return nil -} - -// Write sends user bytes as encrypted packets in the data channel. It returns -// the number of written bytes, and an error if the operation could not succeed. -func (m *muxer) Write(b []byte) (int, error) { - if m.data == nil { - return 0, fmt.Errorf("%w:%s", errBadInput, "data not initialized") - - } - return m.data.WritePacket(m.conn, b) -} - -// Read reads bytes after decrypting packets from the data channel. This is the -// user-view of the VPN connection reads. It returns the number of bytes read, -// and an error if the operation could not succeed. -func (m *muxer) Read(b []byte) (int, error) { - for { - ok, err := m.handleIncomingPacket(nil) - if err != nil { - return 0, err - } - if ok { - break - } - } - return m.bufReader.Read(b) -} diff --git a/vpn/muxer_test.go b/vpn/muxer_test.go deleted file mode 100644 index c03ca7b8..00000000 --- a/vpn/muxer_test.go +++ /dev/null @@ -1,606 +0,0 @@ -package vpn - -import ( - "bytes" - "context" - "errors" - "net" - "reflect" - "testing" - - "github.com/ooni/minivpn/vpn/mocks" - tls "github.com/refraction-networking/utls" -) - -func Test_newMuxerFromOptions(t *testing.T) { - randomFn = func(int) ([]byte, error) { - return []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, nil - } - testSession, _ := newSession() - - type args struct { - conn net.Conn - options *Options - tunnel *tunnelInfo - } - tests := []struct { - name string - args args - want *muxer - wantErr error - }{ - { - name: "get muxer ok", - args: args{ - conn: makeTestingConnForWrite("udp", "10.0.42.2", 42), - options: makeTestingOptions(t, "AES-128-GCM", "sha1"), - tunnel: &tunnelInfo{}, - }, - want: &muxer{ - conn: makeTestingConnForWrite("udp", "10.0.42.2", 42), - control: &control{}, - session: testSession, - options: makeTestingOptions(t, "AES-128-GCM", "sha1"), - }, - wantErr: nil, - }, - // TODO: Add more test cases: - // failure on newSession() - // failure in newData() - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - _, err := newMuxerFromOptions(tt.args.conn, tt.args.options, tt.args.tunnel) - if !errors.Is(err, tt.wantErr) { - t.Errorf("newMuxerFromOptions() error = %v, wantErr %v", err, tt.wantErr) - return - } - // TODO(ainghazal): we cannot compare the options because the paths for the certs are going to be different - // I think this calls from separating the initial options from a more structured config - // with the parsed, loaded certs instead. - }) - } -} - -func makeTestingConnForHandshake(network, addr string, n int) net.Conn { - ma := &mocks.Addr{} - ma.MockString = func() string { - return addr - } - ma.MockNetwork = func() string { - return network - } - - c := &mocks.Conn{} - c.MockLocalAddr = func() net.Addr { - return ma - } - c.MockWrite = func([]byte) (int, error) { - return n, nil - } - c.MockRead = func(b []byte) (int, error) { - switch c.Count { - case 0: - // this is the expected reset response from server - rp := []byte{ - 0x40, - 0x00, 0x01, 0x02, 0x03, 0x04, - 0x05, 0x06, 0x07, 0x08, - } - copy(b[:], rp) - c.Count += 1 - return len(rp), nil - case 1: - // control message data (to load remote key) - p := []byte{0x00, 0x00, 0x00, 0x00, 0x02} - p = append(p, bytes.Repeat([]byte{0x01}, 70)...) - copy(b[:], p) - c.Count += 1 - return len(p), nil - case 2: - // control message data (to load remote key) - p := []byte("PUSH_REPLY") - copy(b[:], p) - c.Count += 1 - return len(p), nil - } - - return 0, nil - } - c.MockClose = func() error { - return nil - } - return c -} - -type mockMuxerForHandshake struct { - muxer -} - -func (md *mockMuxerForHandshake) sendControlMessage() error { - return nil -} - -func (md *mockMuxerForHandshake) readAndLoadRemoteKey() error { - return nil -} - -type mockMuxerWithDummyHandshake struct { - mockMuxerForHandshake -} - -func (md *mockMuxerWithDummyHandshake) Handshake(context.Context) error { - return nil -} - -func Test_muxer_Handshake(t *testing.T) { - makeData := func() *data { - options := makeTestingOptions(t, "AES-128-GCM", "sha1") - data, _ := newDataFromOptions(options, makeTestingSession()) - return data - } - - m := &mockMuxerForHandshake{} - m.control = &control{} - m.data = makeData() - m.tunnel = &tunnelInfo{} - s, err := newSession() - if err != nil { - t.Error("session failed, cannot run handshake test") - } - m.session = s - m.options = makeTestingOptions(t, "AES-128-GCM", "sha512") - m.tls = makeTestingConnForWrite("udp", "0.0.0.0", 42) - m.conn = makeTestingConnForHandshake("udp", "10.0.0.0", 42) - - origInit := initTLSFn - origHandshake := tlsHandshakeFn - - defer func() { - initTLSFn = origInit - tlsHandshakeFn = origHandshake - }() - - // monkey patch the global functions - - initTLSFn = func(*session, *certConfig) (*tls.Config, error) { - return &tls.Config{InsecureSkipVerify: true}, nil - } - tlsHandshakeFn = func(tc *controlChannelTLSConn, tconf *tls.Config) (net.Conn, error) { - return m.conn, nil - } - - // and now for the test itself... - - err = m.Handshake(context.Background()) - if err != nil { - t.Errorf("muxer.Handshake() error = %v, wantErr nil", err) - return - } -} - -func makePacketForHandleIncomingTest(opcode byte, s *session) *packet { - p := &packet{ - id: packetID(1), // always a good packet for a clean session - opcode: opcode, - keyID: 0x00, - payload: []byte("aaa"), - localSessionID: s.LocalSessionID, - remoteSessionID: s.RemoteSessionID, - acks: []packetID{}, - } - return p -} - -// I have modified muxer.handleIncomingPacket() so that it optionally receives a []byte -// in order to make it easier to test payloads. here we go: -type mockDataHandler struct{} - -func (m *mockDataHandler) SetupKeys(*dataChannelKey) error { - return nil -} - -func (m *mockDataHandler) WritePacket(net.Conn, []byte) (int, error) { - return 42, nil -} - -func (m *mockDataHandler) ReadPacket(*packet) ([]byte, error) { - return []byte("alles ist gut"), nil -} - -func (m *mockDataHandler) DecodeEncryptedPayload([]byte, *dataChannelState) (*encryptedData, error) { - return &encryptedData{}, nil -} - -func (m *mockDataHandler) EncryptAndEncodePayload([]byte, *dataChannelState) ([]byte, error) { - return []byte("this is not a payload"), nil -} - -func (m *mockDataHandler) SetPeerID(int) error { - return nil -} - -type mockDataHandlerBadReadPacket struct { - mockDataHandler -} - -func (m *mockDataHandlerBadReadPacket) ReadPacket(*packet) ([]byte, error) { - dummy := errors.New("dummy error") - return []byte{}, dummy -} - -var _ dataHandler = &mockData{} - -func Test_muxer_handleIncomingPacket(t *testing.T) { - m := muxer{ - data: &mockData{}, - bufReader: &bytes.Buffer{}, - } - - // ping data - if ok, _ := m.handleIncomingPacket(pingPayload); ok { - t.Errorf("muxer.handleIncomingPacket(): expected !ok with ping payload") - return - } - // packets with different opcodes - if ok, _ := m.handleIncomingPacket([]byte{}); ok { - t.Errorf("muxer.handleIncomingPacket(): expected !ok with empty bytes") - return - } - p := &packet{opcode: pACKV1} - if ok, _ := m.handleIncomingPacket(p.Bytes()); ok { - t.Errorf("muxer.handleIncomingPacket(): expected !ok with ack packet") - return - } - p = &packet{opcode: pControlV1} - if ok, _ := m.handleIncomingPacket(p.Bytes()); ok { - t.Errorf("muxer.handleIncomingPacket(): expected !ok with control packet") - return - } - p = &packet{opcode: pControlV1} - if ok, _ := m.handleIncomingPacket(p.Bytes()); ok { - t.Errorf("muxer.handleIncomingPacket(): expected !ok with control packet") - return - } - p = &packet{opcode: byte(0xff)} - if ok, _ := m.handleIncomingPacket(p.Bytes()); ok { - t.Errorf("muxer.handleIncomingPacket(): expected !ok with unknown opcode") - return - } - p = &packet{opcode: pDataV1} - if ok, _ := m.handleIncomingPacket(p.Bytes()); !ok { - t.Errorf("muxer.handleIncomingPacket(): expected ok with data opcode") - return - } - - // replace dataHandler in muxer with a method that raises error on ReadPacket() - t.Run("error in ReadPacket() should propagate", func(t *testing.T) { - m = muxer{ - data: &mockDataHandlerBadReadPacket{}, - bufReader: &bytes.Buffer{}, - } - p = &packet{opcode: pDataV1} - if ok, _ := m.handleIncomingPacket(p.Bytes()); ok { - t.Errorf("muxer.handleIncomingPacket(): expected !ok with error in ReadPacket()") - } - }) - - t.Run("null data raises error", func(t *testing.T) { - m = muxer{ - data: nil, - bufReader: &bytes.Buffer{}, - } - p = &packet{opcode: pDataV1} - ok, err := m.handleIncomingPacket(p.Bytes()) - if ok { - t.Errorf("muxer.handleIncomingPacket(): expected !ok with null data") - } - if err != errBadInput { - t.Errorf("muxer.handleIncomingPacket(): expected errBadInput") - } - }) -} - -func Test_muxer_Write(t *testing.T) { - - makeData := func() *data { - options := makeTestingOptions(t, "AES-128-GCM", "sha1") - data, _ := newDataFromOptions(options, makeTestingSession()) - return data - } - - type fields struct { - conn net.Conn - data dataHandler - } - type args struct { - b []byte - } - tests := []struct { - name string - fields fields - args args - want int - wantErr error - }{ - { - name: "write fails if data state is not initialized", - fields: fields{ - conn: makeTestinConnFromNetwork("udp"), - data: &data{}, - }, - args: args{[]byte("alles ist bad")}, - want: 0, - wantErr: errBadInput, - }, - { - name: "write fails if data state is nil", - fields: fields{ - conn: makeTestinConnFromNetwork("udp"), - data: nil, - }, - args: args{[]byte("alles ist bad")}, - want: 0, - wantErr: errBadInput, - }, - { - name: "write calls data.WritePacket", - fields: fields{ - conn: makeTestingConnForWrite("udp", "10.0.1.1", 42), - data: makeData(), - }, - args: args{[]byte("alles ist gut")}, - want: 42, - wantErr: nil, - }, - - // TODO can add more tests: - // [ ] check that the error raised by the underlying data read is the error we - // expect to be returned. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - m := &muxer{ - conn: tt.fields.conn, - data: tt.fields.data, - } - got, err := m.Write(tt.args.b) - if !errors.Is(err, tt.wantErr) { - t.Errorf("muxer.Write() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("muxer.Write() = %v, want %v", got, tt.want) - } - }) - } -} - -func makeTestingConnForRead(retInt int, retErr error, payload []byte) net.Conn { - ma := &mocks.Addr{} - ma.MockString = func() string { - return "10.0.42.2" - } - ma.MockNetwork = func() string { - return "udp" - } - - mc := &mocks.Conn{} - mc.MockLocalAddr = func() net.Addr { - return ma - } - mc.MockRead = func(b []byte) (int, error) { - copy(b[:], payload) - return retInt, retErr - } - return mc -} - -type mockData struct { - data -} - -func (md *mockData) ReadPacket(*packet) ([]byte, error) { - return []byte("alles ist gut"), nil -} - -func Test_muxer_Read(t *testing.T) { - // XXX(ainghazal): I'm not sure this is a very good test. - // what I want to test: - // - that I call readPacket(mockConn) - I'm assuming we get a good data packet - // - that I call data.ReadPacket(p) - // - that we get the right return from muxer.Read() - // - that the expected buffer is written into the buffer that we pass to Read() - - testDataPacket := &packet{opcode: pDataV1, payload: []byte("discarded")} - bufData := "alles ist gut" - - b := make([]byte, 4096) - want := len(bufData) - m := &muxer{ - conn: makeTestingConnForRead(want, nil, testDataPacket.Bytes()), - data: &mockData{}, - bufReader: bytes.NewBuffer(nil), - } - got, err := m.Read(b) - if err != nil { - t.Errorf("muxer.Read() error = %v, wantErr nil", err) - return - } - if got != want { - t.Errorf("muxer.Read() = %v, want %v", got, want) - } - if !bytes.Equal(b[:len(bufData)], []byte(bufData)) { - t.Errorf("muxer.Read() = %v, want %v", string(b[:len(bufData)]), string(bufData)) - } -} - -func Test_muxer_readTLSPacket(t *testing.T) { - type fields struct { - conn net.Conn - tls net.Conn - control controlHandler - data dataHandler - bufReader *bytes.Buffer - session *session - tunnel *tunnelInfo - options *Options - } - tests := []struct { - name string - fields fields - want []byte - wantErr bool - }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - m := &muxer{ - conn: tt.fields.conn, - tls: tt.fields.tls, - control: tt.fields.control, - data: tt.fields.data, - bufReader: tt.fields.bufReader, - session: tt.fields.session, - tunnel: tt.fields.tunnel, - options: tt.fields.options, - } - got, err := m.readTLSPacket() - if (err != nil) != tt.wantErr { - t.Errorf("muxer.readTLSPacket() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("muxer.readTLSPacket() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_muxer_readAndLoadRemoteKey(t *testing.T) { - type fields struct { - conn net.Conn - tls net.Conn - control controlHandler - data dataHandler - bufReader *bytes.Buffer - session *session - tunnel *tunnelInfo - options *Options - } - tests := []struct { - name string - fields fields - wantErr bool - }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - m := &muxer{ - conn: tt.fields.conn, - tls: tt.fields.tls, - control: tt.fields.control, - data: tt.fields.data, - bufReader: tt.fields.bufReader, - session: tt.fields.session, - tunnel: tt.fields.tunnel, - options: tt.fields.options, - } - if err := m.readAndLoadRemoteKey(); (err != nil) != tt.wantErr { - t.Errorf("muxer.readAndLoadRemoteKey() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} - -func Test_muxer_readPushReply(t *testing.T) { - type fields struct { - conn net.Conn - tls net.Conn - control controlHandler - data dataHandler - bufReader *bytes.Buffer - session *session - tunnel *tunnelInfo - options *Options - } - tests := []struct { - name string - fields fields - wantErr error - }{ - { - name: "control == nil should return error", - fields: fields{ - control: nil, - }, - wantErr: errBadInput, - }, - { - name: "tunnel == nil should return error", - fields: fields{ - tunnel: nil, - }, - wantErr: errBadInput, - }, - // TODO: Add moar test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - m := &muxer{ - conn: tt.fields.conn, - tls: tt.fields.tls, - control: tt.fields.control, - data: tt.fields.data, - bufReader: tt.fields.bufReader, - session: tt.fields.session, - tunnel: tt.fields.tunnel, - options: tt.fields.options, - } - if err := m.readPushReply(); !errors.Is(err, tt.wantErr) { - t.Errorf("muxer.readPushReply() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} - -func Test_muxer_emitSendsToListener(t *testing.T) { - t.Run("emit writes event if listener not null", func(t *testing.T) { - l := make(chan uint8, 2) - m := &muxer{} - m.SetEventListener(l) - sent := uint8(2) - m.emit(sent) - got := <-l - if got != sent { - t.Errorf("expected %v, got %v", sent, got) - } - }) - t.Run("emit is a noop if evenlistener not set", func(t *testing.T) { - m := &muxer{} - sent := uint8(2) - m.emit(sent) - }) - t.Run("listener receives several events", func(t *testing.T) { - l := make(chan uint8, 5) - m := &muxer{} - m.SetEventListener(l) - received := []uint8{} - sent := []uint8{1, 2, 3, 4, 5} - for _, i := range sent { - m.emit(i) - } - for _ = range sent { - got := <-l - received = append(received, got) - } - for i := range sent { - if sent[i] != received[i] { - t.Errorf("at [%d]: expected %v, got %v", i, sent, received) - return - } - } - }) -} diff --git a/vpn/options.go b/vpn/options.go deleted file mode 100644 index 8f57fcbc..00000000 --- a/vpn/options.go +++ /dev/null @@ -1,673 +0,0 @@ -package vpn - -// -// Parse VPN options. -// -// Mostly, this file conforms to the format in the reference implementation. -// However, there are some additions that are specific. To avoid feature creep -// and fat dependencies, the main `vpn` module only supports mainline -// capabilities. It is still useful to carry all options in a single type, -// so it's up to the user of this library to do something useful with -// such options. The `extra` package provides some of these extra features, like -// obfuscation support. -// -// Following the configuration format in the reference implementation, `minivpn` -// allows including files in the main configuration file, but only for the `ca`, -// `cert` and `key` options. -// -// Each inline file is started by the line . -// -// Here is an example of an inline file usage: -// -// ``` -// -// -----BEGIN CERTIFICATE----- -// [...] -// -----END CERTIFICATE----- -// -// ``` - -import ( - "bufio" - "bytes" - "errors" - "fmt" - "log" - "os" - "path/filepath" - "strconv" - "strings" -) - -type ( - // compression describes a compression type (e.g., stub). - compression string -) - -const ( - // compressionStub adds the (empty) compression stub to the packets. - compressionStub = compression("stub") - - // compressionEmpty is the empty compression. - compressionEmpty = compression("empty") - - // compressionLZONo is lzo-no (another type of no-compression, older). - compressionLZONo = compression("lzo-no") -) - -type ( - // proto is the main vpn mode (e.g., TCP or UDP). - proto string -) - -func (p proto) String() string { - return string(p) -} - -const ( - // protoTCP is used for vpn in TCP mode. - protoTCP = proto("tcp") - - // protoUDP is used for vpn in UDP mode. - protoUDP = proto("udp") -) - -var ( - // errBadCfg is the generic error returned for invalid config files - errBadCfg = errors.New("bad config") -) - -var supportedCiphers = []string{ - "AES-128-CBC", - "AES-192-CBC", - "AES-256-CBC", - "AES-128-GCM", - "AES-192-GCM", - "AES-256-GCM", -} - -var supportedAuth = []string{ - "SHA1", - "SHA256", - "SHA512", -} - -// Options make all the relevant configuration options accessible to the -// different modules that need it. -type Options struct { - Remote string - Port string - //TODO(https://github.com/ooni/minivpn/issues/25): Proto should be changed to a string and checked against known types. - Proto int - Username string - Password string - CaPath string - CertPath string - KeyPath string - Ca []byte - Cert []byte - Key []byte - Compress compression - Cipher string - Auth string - TLSMaxVer string - // below are options that do not conform to the OpenVPN configuration format. - ProxyOBFS4 string - Log Logger -} - -// NewOptionsFromFilePath expects a string with a path to a valid config file, -// and returns a pointer to a Options struct after parsing the file, and an -// error if the operation could not be completed. -func NewOptionsFromFilePath(filePath string) (*Options, error) { - lines, err := getLinesFromFile(filePath) - dir, _ := filepath.Split(filePath) - if err != nil { - return nil, err - } - return getOptionsFromLines(lines, dir) -} - -// certsFromPath returns true when the options object is configured to load -// certificates from paths; false when we have inline certificates. -func (o *Options) certsFromPath() bool { - return o.CertPath != "" && o.KeyPath != "" && o.CaPath != "" -} - -// hasAuthInfo returns true if: -// - we have paths for cert, key and ca; or -// - we have inline byte arrays for cert, key and ca; or -// - we have username + password info. -func (o *Options) hasAuthInfo() bool { - if o.CertPath != "" && o.KeyPath != "" && o.CaPath != "" { - return true - } - if len(o.Cert) != 0 && len(o.Key) != 0 && len(o.Ca) != 0 { - return true - } - if o.Username != "" && o.Password != "" { - return true - } - return false -} - -const clientOptions = "V4,dev-type tun,link-mtu 1549,tun-mtu 1500,proto %sv4,cipher %s,auth %s,keysize %s,key-method 2,tls-client" - -// String produces a comma-separated representation of the options, in the same -// order and format that the openvpn server expects from us. -func (o *Options) String() string { - if o.Cipher == "" { - return "" - } - keysize := strings.Split(o.Cipher, "-")[1] - proto := strings.ToUpper(protoUDP.String()) - if o.Proto == TCPMode { - proto = strings.ToUpper(protoTCP.String()) - } - s := fmt.Sprintf( - clientOptions, - proto, o.Cipher, o.Auth, keysize) - if o.Compress == compressionStub { - s = s + ",compress stub" - } else if o.Compress == "lzo-no" { - s = s + ",lzo-comp no" - } else if o.Compress == compressionEmpty { - s = s + ",compress" - } - logger.Debugf("Local opts: %s", s) - return s -} - -// newTunnelInfoFromRemoteOptionsString parses the options string returned by -// server. it returns a new tunnel object where the needed fields have been -// updated. At the moment, we only parse the tun-mtu parameter. -func newTunnelInfoFromRemoteOptionsString(remoteOpts string) *tunnelInfo { - t := &tunnelInfo{} - opts := strings.Split(remoteOpts, ",") - for _, opt := range opts { - vals := strings.Split(opt, " ") - if len(vals) < 2 { - continue - } - k, v := vals[0], vals[1:] - if k == "tun-mtu" { - mtu, err := strconv.Atoi(v[0]) - if err != nil { - log.Println("bad mtu:", err) - continue - } - t.mtu = mtu - } - } - return t -} - -// newTunnelInfoFromPushedOptions takes a map of string to array of strings, and returns -// a new tunnel struct with the relevant info. -func newTunnelInfoFromPushedOptions(opts map[string][]string) *tunnelInfo { - t := &tunnelInfo{} - if r := opts["route"]; len(r) >= 1 { - t.gw = r[0] - } else if r := opts["route-gateway"]; len(r) >= 1 { - t.gw = r[0] - } - ip := opts["ifconfig"] - if len(ip) >= 1 { - t.ip = ip[0] - } - peerID := opts["peer-id"] - if len(peerID) == 1 { - i, err := parseIntFromOption(peerID[0]) - if err == nil { - t.peerID = i - } else { - log.Println("Cannot parse peer-id:", err.Error()) - } - } - return t -} - -// parseIntFromOption parses an int from a null-terminated string -func parseIntFromOption(s string) (int, error) { - str := "" - for i := 0; i < len(s); i++ { - if byte(s[i]) == 0x00 { - return strconv.Atoi(str) - } - str = str + string(s[i]) - } - return 0, nil -} - -// pushedOptionsAsMap returns a map for the server-pushed options, -// where the options are the keys and each space-separated value is the value. -// This function always returns an initialized map, even if empty. -func pushedOptionsAsMap(pushedOptions []byte) map[string][]string { - optMap := make(map[string][]string) - if pushedOptions == nil || len(pushedOptions) == 0 { - return optMap - } - - optStr := string(pushedOptions[:len(pushedOptions)-1]) - - opts := strings.Split(optStr, ",") - for _, opt := range opts { - vals := strings.Split(opt, " ") - k, v := vals[0], vals[1:] - optMap[k] = v - } - return optMap -} - -func parseProto(p []string, o *Options) error { - if len(p) != 1 { - return fmt.Errorf("%w: %s", errBadCfg, "proto needs one arg") - } - m := p[0] - switch m { - case protoUDP.String(): - o.Proto = UDPMode - case protoTCP.String(): - o.Proto = TCPMode - default: - return fmt.Errorf("%w: bad proto: %s", errBadCfg, m) - - } - return nil -} - -// TODO(ainghazal): all these little functions can be better tested if we return the options object too - -func parseRemote(p []string, o *Options) error { - if len(p) != 2 { - return fmt.Errorf("%w: %s", errBadCfg, "remote needs two args") - } - o.Remote, o.Port = p[0], p[1] - return nil -} - -func parseCipher(p []string, o *Options) error { - if len(p) != 1 { - return fmt.Errorf("%w: %s", errBadCfg, "cipher expects one arg") - } - cipher := p[0] - if !hasElement(cipher, supportedCiphers) { - return fmt.Errorf("%w: unsupported cipher: %s", errBadCfg, cipher) - } - o.Cipher = cipher - return nil -} - -func parseAuth(p []string, o *Options) error { - if len(p) != 1 { - return fmt.Errorf("%w: %s", errBadCfg, "invalid auth entry") - } - auth := p[0] - if !hasElement(auth, supportedAuth) { - return fmt.Errorf("%w: unsupported auth: %s", errBadCfg, auth) - } - o.Auth = auth - return nil -} - -func parseCA(p []string, o *Options, basedir string) error { - e := fmt.Errorf("%w: %s", errBadCfg, "ca expects a valid file") - if len(p) != 1 { - return e - } - ca := toAbs(p[0], basedir) - if sub, _ := isSubdir(basedir, ca); !sub { - return fmt.Errorf("%w: %s", errBadCfg, "ca must be below config path") - } - if !existsFile(ca) { - return e - } - o.CaPath = ca - return nil -} - -func parseCert(p []string, o *Options, basedir string) error { - e := fmt.Errorf("%w: %s", errBadCfg, "cert expects a valid file") - if len(p) != 1 { - return e - } - cert := toAbs(p[0], basedir) - if sub, _ := isSubdir(basedir, cert); !sub { - return fmt.Errorf("%w: %s", errBadCfg, "cert must be below config path") - } - if !existsFile(cert) { - return e - } - o.CertPath = cert - return nil -} - -func parseKey(p []string, o *Options, basedir string) error { - e := fmt.Errorf("%w: %s", errBadCfg, "key expects a valid file") - if len(p) != 1 { - return e - } - key := toAbs(p[0], basedir) - if sub, _ := isSubdir(basedir, key); !sub { - return fmt.Errorf("%w: %s", errBadCfg, "key must be below config path") - } - if !existsFile(key) { - return e - } - o.KeyPath = key - return nil -} - -// parseAuthUser reads credentials from a given file, according to the openvpn -// format (user and pass on a line each). To avoid path traversal / LFI, the -// credentials file is expected to be in a subdirectory of the base dir. -func parseAuthUser(p []string, o *Options, basedir string) error { - e := fmt.Errorf("%w: %s", errBadCfg, "auth-user-pass expects a valid file") - if len(p) != 1 { - return e - } - auth := toAbs(p[0], basedir) - if sub, _ := isSubdir(basedir, auth); !sub { - return fmt.Errorf("%w: %s", errBadCfg, "auth must be below config path") - } - if !existsFile(auth) { - return e - } - creds, err := getCredentialsFromFile(auth) - if err != nil { - return err - } - o.Username, o.Password = creds[0], creds[1] - return nil -} - -func parseCompress(p []string, o *Options) error { - if len(p) > 1 { - return fmt.Errorf("%w: %s", errBadCfg, "compress: only empty/stub options supported") - } - if len(p) == 0 { - o.Compress = compressionEmpty - return nil - } - if p[0] == "stub" { - o.Compress = compressionStub - return nil - } - return fmt.Errorf("%w: %s", errBadCfg, "compress: only empty/stub options supported") -} - -func parseCompLZO(p []string, o *Options) error { - if p[0] != "no" { - return fmt.Errorf("%w: %s", errBadCfg, "comp-lzo: compression not supported") - } - o.Compress = "lzo-no" - return nil -} - -// parseTLSVerMax sets the maximum TLS version. This is currently ignored -// because we're using uTLS to parrot the Client Hello. -func parseTLSVerMax(p []string, o *Options) error { - if o == nil { - return errBadInput - } - if len(p) == 0 { - o.TLSMaxVer = "1.3" - return nil - } - if p[0] == "1.2" { - o.TLSMaxVer = "1.2" - } - return nil -} - -func parseProxyOBFS4(p []string, o *Options) error { - if len(p) != 1 { - return fmt.Errorf("%w: %s", errBadCfg, "proto-obfs4: need a properly configured proxy") - } - // TODO(ainghazal): can validate the obfs4://... scheme here - o.ProxyOBFS4 = p[0] - return nil -} - -var pMap = map[string]interface{}{ - "proto": parseProto, - "remote": parseRemote, - "cipher": parseCipher, - "auth": parseAuth, - "compress": parseCompress, - "comp-lzo": parseCompLZO, - "proxy-obfs4": parseProxyOBFS4, - "tls-version-max": parseTLSVerMax, // this is currently ignored because of uTLS -} - -var pMapDir = map[string]interface{}{ - "ca": parseCA, - "cert": parseCert, - "key": parseKey, - "auth-user-pass": parseAuthUser, -} - -func parseOption(o *Options, dir, key string, p []string, lineno int) error { - switch key { - case "proto", "remote", "cipher", "auth", "compress", "comp-lzo", "tls-version-max", "proxy-obfs4": - fn := pMap[key].(func([]string, *Options) error) - if e := fn(p, o); e != nil { - return e - } - case "ca", "cert", "key", "auth-user-pass": - fn := pMapDir[key].(func([]string, *Options, string) error) - if e := fn(p, o, dir); e != nil { - return e - } - default: - log.Printf("warn: unsupported key in line %d\n", lineno) - } - return nil -} - -// getOptionsFromLines tries to parse all the lines coming from a config file -// and raises validation errors if the values do not conform to the expected -// format. -// the config file supports inline file inclusion for , and . -func getOptionsFromLines(lines []string, dir string) (*Options, error) { - opt := &Options{} - - // tag and inlineBuf are used to parse inline files. - // these follow the format used by the reference openvpn implementation. - // each block (any of ca, key, cert) is marked by a line; lines in between are expected to contain - // the crypto block. - tag := "" - inlineBuf := new(bytes.Buffer) - - for lineno, l := range lines { - if strings.HasPrefix(l, "#") { - continue - } - l = strings.TrimSpace(l) - - // inline certs - if isClosingTag(l) { - // we expect an already existing inlineBuf - e := parseInlineTag(opt, tag, inlineBuf) - if e != nil { - return nil, e - } - tag = "" - inlineBuf = new(bytes.Buffer) - continue - } - if tag != "" { - inlineBuf.Write([]byte(l)) - inlineBuf.Write([]byte("\n")) - continue - } - if isOpeningTag(l) { - if len(inlineBuf.Bytes()) != 0 { - // something wrong: an opening tag should not be found - // when we still have bytes in the inline buffer. - return opt, fmt.Errorf("%w: %s", errBadInput, "tag not closed") - } - tag = parseTag(l) - continue - } - - // parse parts in the same line - p := strings.Split(l, " ") - if len(p) == 0 { - continue - } - var ( - key string - parts []string - ) - if len(p) == 1 { - key = p[0] - } else { - key, parts = p[0], p[1:] - } - e := parseOption(opt, dir, key, parts, lineno) - if e != nil { - return nil, e - } - } - return opt, nil -} - -func isOpeningTag(key string) bool { - switch key { - case "", "", "": - return true - default: - return false - } -} - -func isClosingTag(key string) bool { - switch key { - case "", "", "": - return true - default: - return false - } -} - -func parseTag(tag string) string { - switch tag { - case "", "": - return "ca" - case "", "": - return "cert" - case "", "": - return "key" - default: - return "" - } -} - -// parseInlineTag -func parseInlineTag(o *Options, tag string, buf *bytes.Buffer) error { - b := buf.Bytes() - if len(b) == 0 { - return fmt.Errorf("%w: empty inline tag: %d", errBadInput, len(b)) - } - switch tag { - case "ca": - o.Ca = b - case "cert": - o.Cert = b - case "key": - o.Key = b - default: - return fmt.Errorf("%w: unknown tag: %s", errBadInput, tag) - - } - return nil -} - -// hasElement checks if a given string is present in a string array. returns -// true if that is the case, false otherwise. -func hasElement(el string, arr []string) bool { - for _, v := range arr { - if v == el { - return true - } - } - return false -} - -// existsFile returns true if the file to which the path refers to exists and -// is a regular file. -func existsFile(path string) bool { - statbuf, err := os.Stat(path) - return !errors.Is(err, os.ErrNotExist) && statbuf.Mode().IsRegular() -} - -// getLinesFromFile accepts a path parameter, and return a string array with -// its content and an error if the operation cannot be completed. -func getLinesFromFile(path string) ([]string, error) { - f, err := os.Open(path) //#nosec G304 - defer func() { - if err := f.Close(); err != nil { - logger.Errorf("Error closing file: %s\n", err) - } - }() - if err != nil { - return nil, err - } - - lines := make([]string, 0) - scanner := bufio.NewScanner(f) - for scanner.Scan() { - lines = append(lines, scanner.Text()) - } - err = scanner.Err() - if err != nil { - return nil, err - } - return lines, nil -} - -// getCredentialsFromFile accepts a path string parameter, and return a string -// array containing the credentials in that file, and an error if the operation -// could not be completed. -func getCredentialsFromFile(path string) ([]string, error) { - lines, err := getLinesFromFile(path) - if err != nil { - return nil, fmt.Errorf("%w: %s", errBadCfg, err) - } - if len(lines) != 2 { - return nil, fmt.Errorf("%w: %s", errBadCfg, "malformed credentials file") - } - if len(lines[0]) == 0 { - return nil, fmt.Errorf("%w: %s", errBadCfg, "empty username in creds file") - } - if len(lines[1]) == 0 { - return nil, fmt.Errorf("%w: %s", errBadCfg, "empty password in creds file") - } - return lines, nil -} - -// toAbs return an absolute path if the given path is not already absolute; to -// do so, it will append the path to the given basedir. -func toAbs(path, basedir string) string { - if filepath.IsAbs(path) { - return path - } - return filepath.Join(basedir, path) -} - -// isSubdir checks if a given path is a subdirectory of another. It returns -// true if that's the case, and any error raise during the check. -func isSubdir(parent, sub string) (bool, error) { - p, err := filepath.Abs(parent) - if err != nil { - return false, err - } - s, err := filepath.Abs(sub) - if err != nil { - return false, err - } - return strings.HasPrefix(s, p), nil -} diff --git a/vpn/options_test.go b/vpn/options_test.go deleted file mode 100644 index 9c7d072f..00000000 --- a/vpn/options_test.go +++ /dev/null @@ -1,971 +0,0 @@ -package vpn - -import ( - "errors" - "os" - fp "path/filepath" - "reflect" - "testing" -) - -func writeDummyCertFiles(d string) { - os.WriteFile(fp.Join(d, "ca.crt"), []byte("dummy"), 0600) - os.WriteFile(fp.Join(d, "cert.pem"), []byte("dummy"), 0600) - os.WriteFile(fp.Join(d, "key.pem"), []byte("dummy"), 0600) -} - -func TestOptions_String(t *testing.T) { - type fields struct { - Remote string - Port string - Proto int - Username string - Password string - Ca string - Cert string - Key string - Compress compression - Cipher string - Auth string - TLSMaxVer string - ProxyOBFS4 string - Log Logger - } - tests := []struct { - name string - fields fields - want string - }{ - { - name: "empty cipher", - fields: fields{}, - want: "", - }, - { - name: "proto tcp", - fields: fields{ - Cipher: "AES-128-GCM", - Auth: "sha512", - Proto: 1, - }, - want: "V4,dev-type tun,link-mtu 1549,tun-mtu 1500,proto TCPv4,cipher AES-128-GCM,auth sha512,keysize 128,key-method 2,tls-client", - }, - { - name: "compress stub", - fields: fields{ - Cipher: "AES-128-GCM", - Auth: "sha512", - Proto: 2, - Compress: compressionStub, - }, - want: "V4,dev-type tun,link-mtu 1549,tun-mtu 1500,proto UDPv4,cipher AES-128-GCM,auth sha512,keysize 128,key-method 2,tls-client,compress stub", - }, - { - name: "compress lzo-no", - fields: fields{ - Cipher: "AES-128-GCM", - Auth: "sha512", - Proto: 2, - Compress: compressionLZONo, - }, - want: "V4,dev-type tun,link-mtu 1549,tun-mtu 1500,proto UDPv4,cipher AES-128-GCM,auth sha512,keysize 128,key-method 2,tls-client,lzo-comp no", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - o := &Options{ - Remote: tt.fields.Remote, - Port: tt.fields.Port, - Proto: tt.fields.Proto, - Username: tt.fields.Username, - Password: tt.fields.Password, - CaPath: tt.fields.Ca, - CertPath: tt.fields.Cert, - KeyPath: tt.fields.Key, - Compress: tt.fields.Compress, - Cipher: tt.fields.Cipher, - Auth: tt.fields.Auth, - TLSMaxVer: tt.fields.TLSMaxVer, - ProxyOBFS4: tt.fields.ProxyOBFS4, - Log: tt.fields.Log, - } - if got := o.String(); got != tt.want { - t.Errorf("Options.string() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestGetOptionsFromLines(t *testing.T) { - d := t.TempDir() - l := []string{ - "remote 0.0.0.0 1194", - "cipher AES-256-GCM", - "auth SHA512", - "ca ca.crt", - "cert cert.pem", - "key cert.pem", - } - writeDummyCertFiles(d) - o, err := getOptionsFromLines(l, d) - if err != nil { - t.Errorf("Good options should not fail: %s", err) - } - if o.Cipher != "AES-256-GCM" { - t.Errorf("Cipher not what expected") - } - if o.Auth != "SHA512" { - t.Errorf("Auth not what expected") - } -} - -func TestGetOptionsFromLinesInlineCerts(t *testing.T) { - l := []string{ - "", - "ca_string", - "", - "", - "cert_string", - "", - "", - "key_string", - "", - } - o, err := getOptionsFromLines(l, "") - if err != nil { - t.Errorf("Good options should not fail: %s", err) - } - if string(o.Ca) != "ca_string\n" { - t.Errorf("Expected ca_string, got: %s.", string(o.Ca)) - } - if string(o.Cert) != "cert_string\n" { - t.Errorf("Expected cert_string, got: %s.", string(o.Cert)) - } - if string(o.Key) != "key_string\n" { - t.Errorf("Expected key_string, got: %s.", string(o.Key)) - } -} - -func TestGetOptionsFromLinesNoFiles(t *testing.T) { - d := t.TempDir() - l := []string{ - "ca ca.crt", - } - _, err := getOptionsFromLines(l, d) - if err == nil { - t.Errorf("Should fail if no files provided") - } -} - -func TestGetOptionsNoCompression(t *testing.T) { - d := t.TempDir() - l := []string{ - "compress", - } - // should fail if no certs - // writeDummyCertFiles(d) - o, err := getOptionsFromLines(l, d) - if err != nil { - t.Errorf("Should not fail: compress") - } - if o.Compress != "empty" { - t.Errorf("Expected compress==empty") - } -} - -func TestGetOptionsCompressionStub(t *testing.T) { - d := t.TempDir() - l := []string{ - "compress stub", - } - // should fail if no certs - // writeDummyCertFiles(d) - o, err := getOptionsFromLines(l, d) - if err != nil { - t.Errorf("Should not fail: compress stub") - } - if o.Compress != "stub" { - t.Errorf("expected compress==stub") - } -} - -func TestGetOptionsCompressionBad(t *testing.T) { - d := t.TempDir() - l := []string{ - "compress foo", - } - // should fail if no certs - // writeDummyCertFiles(d) - _, err := getOptionsFromLines(l, d) - if err == nil { - t.Errorf("Unknown compress: should fail") - } -} - -func TestGetOptionsCompressLZO(t *testing.T) { - d := t.TempDir() - l := []string{ - "comp-lzo no", - } - // should fail if no certs - // writeDummyCertFiles(d) - o, err := getOptionsFromLines(l, d) - if err != nil { - t.Errorf("Should not fail: lzo-comp no") - } - if o.Compress != "lzo-no" { - t.Errorf("expected compress=lzo-no") - } -} - -func TestGetOptionsBadRemote(t *testing.T) { - d := t.TempDir() - l := []string{ - "remote", - } - // should fail if no certs - // writeDummyCertFiles(d) - _, err := getOptionsFromLines(l, d) - if err == nil { - t.Errorf("Should fail: malformed remote") - } -} - -func TestGetOptionsBadCipher(t *testing.T) { - d := t.TempDir() - l := []string{ - "cipher", - } - // should fail if no certs - // writeDummyCertFiles(d) - _, err := getOptionsFromLines(l, d) - if err == nil { - t.Errorf("Should fail: malformed cipher") - } - l = []string{ - "cipher AES-111-CBC", - } - _, err = getOptionsFromLines(l, d) - if err == nil { - t.Errorf("Should fail: bad cipher") - } -} - -func TestGetOptionsComment(t *testing.T) { - d := t.TempDir() - l := []string{ - "cipher AES-256-GCM", - "#cipher AES-128-GCM", - } - // should fail if no certs - // writeDummyCertFiles(d) - o, err := getOptionsFromLines(l, d) - if err != nil { - t.Errorf("Should not fail: commented line") - } - if o.Cipher != "AES-256-GCM" { - t.Errorf("Expected cipher: AES-256-GCM") - } -} - -var dummyConfigFile = []byte(`proto udp -cipher AES-128-GCM -auth SHA1`) - -func writeDummyConfigFile(dir string) (string, error) { - f, err := os.CreateTemp(dir, "tmpfile-") - if err != nil { - return "", err - } - f.Write(dummyConfigFile) - return f.Name(), nil -} - -func Test_ParseConfigFile(t *testing.T) { - // parse good file - f, err := writeDummyConfigFile(t.TempDir()) - if err != nil { - t.Fatal("ParseConfigFile(): cannot write cert needed for the test") - } - o, err := NewOptionsFromFilePath(f) - if err != nil { - t.Errorf("ParseConfigFile(): expected err=%v, got=%v", nil, err) - } - wantProto := UDPMode - if o.Proto != wantProto { - t.Errorf("ParseConfigFile(): expected Proto=%v, got=%v", wantProto, o.Proto) - } - wantCipher := "AES-128-GCM" - if o.Cipher != wantCipher { - t.Errorf("ParseConfigFile(): expected=%v, got=%v", wantCipher, o.Cipher) - } - - // expect error when parsing a bad filepath - _, err = NewOptionsFromFilePath("") - if err == nil { - t.Errorf("expected error with empty file") - } - - _, err = NewOptionsFromFilePath("http://example.com") - if err == nil { - t.Errorf("expected error with http uri") - } -} - -func Test_parseProto(t *testing.T) { - // empty parts - err := parseProto([]string{}, &Options{}) - wantErr := errBadCfg - if !errors.Is(err, wantErr) { - t.Errorf("parseProto(): wantErr: %v, got %v", wantErr, err) - } - - // two parts - err = parseProto([]string{"foo", "bar"}, &Options{}) - wantErr = errBadCfg - if !errors.Is(err, wantErr) { - t.Errorf("parseProto(): wantErr %v, got %v", wantErr, err) - } - - // udp - opt := &Options{} - err = parseProto([]string{"udp"}, opt) - if !errors.Is(err, nil) { - t.Errorf("parseProto(): wantErr: %v, got %v", nil, err) - } - if opt.Proto != UDPMode { - t.Errorf("parseProto(): wantErr %v, got %v", nil, err) - } - - // tcp - opt = &Options{} - err = parseProto([]string{"tcp"}, opt) - if !errors.Is(err, nil) { - t.Errorf("parseProto(): wantErr: %v, got %v", nil, err) - } - if opt.Proto != TCPMode { - t.Errorf("parseProto(): wantErr %v, got %v", nil, err) - } - - // bad - opt = &Options{} - err = parseProto([]string{"kcp"}, opt) - wantErr = errBadCfg - if !errors.Is(err, wantErr) { - t.Errorf("parseProto(): wantErr: %v, got %v", errBadCfg, err) - } - -} - -func Test_parseProxyOBFS4(t *testing.T) { - // empty parts - err := parseProxyOBFS4([]string{}, &Options{}) - wantErr := errBadCfg - if !errors.Is(err, wantErr) { - t.Errorf("parseProxyOBFS4(): wantErr: %v, got %v", wantErr, err) - } - - // obfs4 string - opt := &Options{} - obfs4Uri := "obfs4://foobar" - err = parseProxyOBFS4([]string{obfs4Uri}, opt) - wantErr = nil - if !errors.Is(err, wantErr) { - t.Errorf("parseProxyOBFS4(): wantErr: %v, got %v", wantErr, err) - } - if opt.ProxyOBFS4 != obfs4Uri { - t.Errorf("parseProxyOBFS4(): want %v, got %v", obfs4Uri, opt.ProxyOBFS4) - } - -} - -func Test_parseCA(t *testing.T) { - // more than one part should fail - err := parseCA([]string{"one", "two"}, &Options{}, "") - wantErr := errBadCfg - if !errors.Is(err, wantErr) { - t.Errorf("parseCA(): want %v, got %v", wantErr, err) - } - - // empty part should fail - err = parseCA([]string{}, &Options{}, "") - wantErr = errBadCfg - if !errors.Is(err, wantErr) { - t.Errorf("parseCA(): want %v, got %v", wantErr, err) - } -} - -func Test_parseCert(t *testing.T) { - // more than one part should fail - err := parseCert([]string{"one", "two"}, &Options{}, "") - wantErr := errBadCfg - if !errors.Is(err, wantErr) { - t.Errorf("parseCert(): want %v, got %v", wantErr, err) - } - - // empty part should fail - err = parseCert([]string{}, &Options{}, "") - wantErr = errBadCfg - if !errors.Is(err, wantErr) { - t.Errorf("parseCert(): want %v, got %v", wantErr, err) - } - - // non-existent cert should fail - err = parseCert([]string{"/tmp/nonexistent"}, &Options{}, "") - wantErr = errBadCfg - if !errors.Is(err, wantErr) { - t.Errorf("parseCert(): want %v, got %v", wantErr, err) - } -} - -func Test_parseKey(t *testing.T) { - // more than one part should fail - err := parseKey([]string{"one", "two"}, &Options{}, "") - wantErr := errBadCfg - if !errors.Is(err, wantErr) { - t.Errorf("parseKey(): want %v, got %v", wantErr, err) - } - - // empty part should fail - err = parseKey([]string{}, &Options{}, "") - wantErr = errBadCfg - if !errors.Is(err, wantErr) { - t.Errorf("parseKey(): want %v, got %v", wantErr, err) - } - - // non-existent key should fail - err = parseKey([]string{"/tmp/nonexistent"}, &Options{}, "") - wantErr = errBadCfg - if !errors.Is(err, wantErr) { - t.Errorf("parseKey(): want %v, got %v", wantErr, err) - } -} - -func Test_parseCompress(t *testing.T) { - // more than one part should fail - err := parseCompress([]string{"one", "two"}, &Options{}) - wantErr := errBadCfg - if !errors.Is(err, wantErr) { - t.Errorf("parseCompress(): want %v, got %v", wantErr, err) - } -} - -func Test_parseCompLZO(t *testing.T) { - // only "no" is supported - err := parseCompLZO([]string{"yes"}, &Options{}) - wantErr := errBadCfg - if !errors.Is(err, wantErr) { - t.Errorf("parseCompLZO(): want %v, got %v", wantErr, err) - } -} - -func Test_parseOption(t *testing.T) { - // unknown key should not fail - err := parseOption(&Options{}, t.TempDir(), "unknownKey", []string{"a", "b"}, 0) - if err != nil { - t.Errorf("parseOption(): want %v, got %v", nil, err) - } -} - -func Test_newTunnelInfoFromRemoteOptionsString(t *testing.T) { - type args struct { - remoteOpts string - } - tests := []struct { - name string - args args - want *tunnelInfo - }{ - { - name: "parse good tun-mtu", - args: args{ - remoteOpts: "foo bar,tun-mtu 1500", - }, - want: &tunnelInfo{ - mtu: 1500, - }, - }, - { - name: "empty string", - args: args{ - remoteOpts: "", - }, - want: &tunnelInfo{}, - }, - { - name: "empty field", - args: args{ - remoteOpts: "tun-mtu 1200,,", - }, - want: &tunnelInfo{ - mtu: 1200, - }, - }, - { - name: "extra space", - args: args{ - remoteOpts: "tun-mtu 1200", - }, - want: &tunnelInfo{}, - }, - { - name: "mtu not an int", - args: args{ - remoteOpts: "tun-mtu aaa", - }, - want: &tunnelInfo{}, - }, - { - name: "entry with single field", - args: args{ - remoteOpts: "bad", - }, - want: &tunnelInfo{}, - }, - { - name: "entries with single field", - args: args{ - remoteOpts: "bad,worse", - }, - want: &tunnelInfo{}, - }, - { - name: "tun-mtu no value", - args: args{ - remoteOpts: "tun-mtu", - }, - want: &tunnelInfo{}, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := newTunnelInfoFromRemoteOptionsString(tt.args.remoteOpts); !reflect.DeepEqual(got, tt.want) { - t.Errorf("parseRemoteOptions() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_pushedOptionsAsMap(t *testing.T) { - type args struct { - pushedOptions []byte - } - tests := []struct { - name string - args args - want map[string][]string - }{ - { - name: "do parse tunnel ip", - args: args{[]byte("foo bar,ifconfig 10.0.0.3,")}, - want: map[string][]string{ - "foo": []string{"bar"}, - "ifconfig": []string{"10.0.0.3"}, - }, - }, - { - name: "empty string", - args: args{[]byte{}}, - want: map[string][]string{}, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := pushedOptionsAsMap(tt.args.pushedOptions); !reflect.DeepEqual(got, tt.want) { - t.Errorf("pushedOptionsAsMap() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_parseAuth(t *testing.T) { - type args struct { - p []string - o *Options - } - tests := []struct { - name string - args args - wantErr error - }{ - { - name: "should fail with empty array", - args: args{[]string{}, &Options{}}, - wantErr: errBadCfg, - }, - { - name: "should fail with 2-element array", - args: args{[]string{"foo", "bar"}, &Options{}}, - wantErr: errBadCfg, - }, - { - name: "should fail with lowercase option", - args: args{[]string{"sha1"}, &Options{}}, - wantErr: errBadCfg, - }, - { - name: "should fail with unknown option", - args: args{[]string{"SHA666"}, &Options{}}, - wantErr: errBadCfg, - }, - { - name: "should not fail with good option", - args: args{[]string{"SHA512"}, &Options{}}, - wantErr: nil, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if err := parseAuth(tt.args.p, tt.args.o); !errors.Is(err, tt.wantErr) { - t.Errorf("parseAuth() error = %v, wantErr %v", err, tt.wantErr) - } - - }) - } -} - -func Test_parseAuthUser(t *testing.T) { - makeCreds := func(credStr string) string { - f, err := os.CreateTemp(t.TempDir(), "tmpfile-") - if err != nil { - t.Fatal(err) - } - if _, err := f.Write([]byte(credStr)); err != nil { - t.Fatal(err) - } - return f.Name() - } - - baseDir := func() string { - return os.TempDir() - } - - type args struct { - p []string - o *Options - d string - } - tests := []struct { - name string - args args - wantErr error - }{ - { - name: "parse good auth", - args: args{ - p: []string{makeCreds("foo\nbar\n")}, - o: &Options{}, - d: baseDir(), - }, - wantErr: nil, - }, - { - name: "path traversal should fail", - args: args{ - p: []string{"/tmp/../etc/passwd"}, - o: &Options{}, - d: baseDir(), - }, - wantErr: errBadCfg, - }, - { - name: "parse empty file should fail", - args: args{ - p: []string{""}, - o: &Options{}, - d: baseDir(), - }, - wantErr: errBadCfg, - }, - { - name: "parse empty parts should fail", - args: args{ - p: []string{}, - o: &Options{}, - d: baseDir(), - }, - wantErr: errBadCfg, - }, - { - name: "parse less than two lines should fail", - args: args{ - p: []string{makeCreds("foo\n")}, - o: &Options{}, - d: baseDir(), - }, - wantErr: errBadCfg, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if err := parseAuthUser(tt.args.p, tt.args.o, tt.args.d); !errors.Is(err, tt.wantErr) { - t.Errorf("parseAuthUser() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} - -// TODO(ainghazal): return options object so that it's testable too -func Test_parseTLSVerMax(t *testing.T) { - type args struct { - p []string - o *Options - } - tests := []struct { - name string - args args - wantErr error - }{ - { - name: "nil options should fail", - args: args{}, - wantErr: errBadInput, - }, - { - name: "default", - args: args{o: &Options{}}, - wantErr: nil, - }, - { - name: "default with good tls opt", - args: args{p: []string{"1.2"}, o: &Options{}}, - wantErr: nil, - }, - { - // FIXME this case should probably fail - name: "default with too many parts", - args: args{p: []string{"1.2", "1.3"}, o: &Options{}}, - wantErr: nil, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if err := parseTLSVerMax(tt.args.p, tt.args.o); !errors.Is(err, tt.wantErr) { - t.Errorf("parseTLSVerMax() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} - -func Test_proto_String(t *testing.T) { - tests := []struct { - name string - p proto - want string - }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := tt.p.String(); got != tt.want { - t.Errorf("proto.String() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_getCredentialsFromFile(t *testing.T) { - makeCreds := func(credStr string) string { - f, err := os.CreateTemp(t.TempDir(), "tmpfile-") - if err != nil { - t.Fatal(err) - } - if _, err := f.Write([]byte(credStr)); err != nil { - t.Fatal(err) - } - return f.Name() - } - - type args struct { - path string - } - tests := []struct { - name string - args args - want []string - wantErr error - }{ - { - name: "should fail with non-existing file", - args: args{"/tmp/nonexistent"}, - want: nil, - wantErr: errBadCfg, - }, - { - name: "should fail with empty file", - args: args{makeCreds("")}, - want: nil, - wantErr: errBadCfg, - }, - { - name: "should fail with empty user", - args: args{makeCreds("\n\n")}, - want: nil, - wantErr: errBadCfg, - }, - { - name: "should fail with empty pass", - args: args{makeCreds("user\n\n")}, - want: nil, - wantErr: errBadCfg, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := getCredentialsFromFile(tt.args.path) - if !errors.Is(err, tt.wantErr) { - t.Errorf("getCredentialsFromFile() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("getCredentialsFromFile() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_isSubdir(t *testing.T) { - type args struct { - parent string - sub string - } - tests := []struct { - name string - args args - want bool - wantErr bool - }{ - { - name: "sunny path", - args: args{ - parent: "/foo/bar", - sub: "/foo/bar/baz", - }, - want: true, - wantErr: false, - }, - { - name: "same dir", - args: args{ - parent: "/foo/bar", - sub: "/foo/bar", - }, - want: true, - wantErr: false, - }, - { - name: "same dir w/ slash", - args: args{ - parent: "/foo/bar", - sub: "/foo/bar/", - }, - want: true, - wantErr: false, - }, - { - name: "not subdir", - args: args{ - parent: "/foo/bar", - sub: "/foo", - }, - want: false, - wantErr: false, - }, - { - name: "path traversal", - args: args{ - parent: "/foo/bar", - sub: "/foo/bar/./../", - }, - want: false, - wantErr: false, - }, - { - name: "path traversal with .", - args: args{ - parent: ".", - sub: "/etc/", - }, - want: false, - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := isSubdir(tt.args.parent, tt.args.sub) - if (err != nil) != tt.wantErr { - t.Errorf("isSubdir() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("isSubdir() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_newTunnelInfoFromPushedOptions(t *testing.T) { - type args struct { - opts map[string][]string - } - tests := []struct { - name string - args args - want *tunnelInfo - }{ - { - name: "get route", - args: args{ - map[string][]string{ - "route": []string{"1.1.1.1"}, - }, - }, - want: &tunnelInfo{ - gw: "1.1.1.1", - }, - }, - { - name: "get route from gw", - args: args{ - map[string][]string{ - "route-gateway": []string{"1.1.2.2"}, - }, - }, - want: &tunnelInfo{ - gw: "1.1.2.2", - }, - }, - { - name: "get ip", - args: args{ - map[string][]string{ - "ifconfig": []string{"1.1.3.3", "foo", "bar"}, - }, - }, - want: &tunnelInfo{ - ip: "1.1.3.3", - }, - }, - { - name: "get ip and route", - args: args{ - map[string][]string{ - "ifconfig": []string{"1.1.3.3", "foo", "bar"}, - "route": []string{"1.1.1.1"}, - "route-gateway": []string{"1.1.2.2"}, - }, - }, - want: &tunnelInfo{ - ip: "1.1.3.3", - gw: "1.1.1.1", - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := newTunnelInfoFromPushedOptions(tt.args.opts); !reflect.DeepEqual(got, tt.want) { - t.Errorf("newTunnelInfoFromPushedOptions() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/vpn/packet.go b/vpn/packet.go deleted file mode 100644 index e5f012ba..00000000 --- a/vpn/packet.go +++ /dev/null @@ -1,382 +0,0 @@ -package vpn - -// -// Encode and decode packets according to the OpenVPN protocol. -// - -import ( - "bytes" - "errors" - "fmt" - "io" -) - -const ( - stNothing = iota - stControlChannelOpen - stControlMessageSent - stKeyExchanged - stPullRequestSent - stOptionsPushed - stInitialized - stDataReady -) - -const ( - pControlHardResetClientV1 = iota + 1 - pControlHardResetServerV1 // 2 - pControlSoftResetV1 // 3 - pControlV1 // 4 - pACKV1 // 5 - pDataV1 // 6 - pControlHardResetClientV2 // 7 - pControlHardResetServerV2 // 8 - pDataV2 // 9 -) - -const ( - UDPMode = iota - TCPMode -) - -var ( - errEmptyPayload = errors.New("empty payload") - errBadKeyMethod = errors.New("unsupported key method") - errBadControlMessage = errors.New("bad message") - errBadServerReply = errors.New("bad server reply") - errBadAuth = errors.New("server says: bad auth") - - controlMessageHeader = []byte{0x00, 0x00, 0x00, 0x00} - pingPayload = []byte{0x2A, 0x18, 0x7B, 0xF3, 0x64, 0x1E, 0xB4, 0xCB, 0x07, 0xED, 0x2D, 0x0A, 0x98, 0x1F, 0xC7, 0x48} - - IV_Ver = "2.5.5" // OpenVPN version compat that we declare to the server - IV_Proto = "2" // IV_PROTO declared to the server. We need to be sure to enable the peer-id bit to use P_DATA_V2. -) - -// sessionID is the session identifier. -type sessionID [8]byte - -// packetID is a packet identifier. -type packetID uint32 - -// ackArray holds the identifiers of packets to ack. -type ackArray []packetID - -// packet represents a packet according to the OpenVPN protocol. -type packet struct { - - // id is the packet-id for replay protection. - // According to the spec: "4 or 8 bytes, includes sequence number and optional time_t timestamp". - // We do not use the timestamp. - id packetID - - // opcode is the packet message type (a P_* constant; high 5-bits of - // the first packet byte). - opcode byte - - // The key_id refers to an already negotiated TLS session. - // This is the shortened version of the key-id (low 3-bits of the first - // packet byte). - keyID byte - - // The 64 bit form (of the key) is referred to as a session_id. - localSessionID sessionID - remoteSessionID sessionID - payload []byte - acks ackArray -} - -// parsePacketFromBytes produces a packet after parsing the common header. -// In TCP mode, it is assumed that the packet length (part of the header) has -// already been stripped out. -func parsePacketFromBytes(buf []byte) (*packet, error) { - if len(buf) < 2 { - return &packet{}, errBadInput - } - opcode := buf[0] >> 3 - keyID := buf[0] & 0x07 - - var payload = []byte{} - - switch opcode { - case pDataV2: - payload = buf[4:] - default: - payload = buf[1:] - } - - // TODO missing peerID - p := &packet{ - opcode: opcode, - keyID: keyID, - payload: payload, - } - return parsePacket(p) -} - -// newPacketFromPayload returns a packet from the passed arguments: opcode, -// keyID and a raw byte array payload. -func newPacketFromPayload(opcode uint8, keyID uint8, payload []byte) *packet { - p := &packet{ - opcode: opcode, - keyID: keyID, - payload: payload, - } - return p -} - -// Bytes returns a byte array that is ready to be sent on the wire. -func (packet *packet) Bytes() []byte { - buf := &bytes.Buffer{} - buf.WriteByte((packet.opcode << 3) | (packet.keyID & 0x07)) - buf.Write(packet.localSessionID[:]) - // we write a byte with the number of acks, and then - // serialize each ack. - nAcks := len(packet.acks) - if nAcks > 255 { - logger.Warnf("packet %d had too many acks (%d)", packet.id, nAcks) - nAcks = 255 - } - buf.WriteByte(byte(nAcks)) - for i := 0; i < nAcks; i++ { - bufWriteUint32(buf, uint32(packet.acks[i])) - } - // remote session id - if len(packet.acks) > 0 { - buf.Write(packet.remoteSessionID[:]) - } - if packet.opcode != pACKV1 { - bufWriteUint32(buf, uint32(packet.id)) - } - // payload - buf.Write(packet.payload) - return buf.Bytes() -} - -// isACK returns true if the packet is an ACK packet. -func (p *packet) isACK() bool { - return p.opcode == byte(pACKV1) -} - -// isControl returns true if the packet is any of the control types. -func (p *packet) isControl() bool { - switch p.opcode { - case byte(pControlHardResetServerV2), byte(pControlV1): - return true - default: - return false - } -} - -// isControlV1 returns true if the packet is of the control v1 type. -func (p *packet) isControlV1() bool { - return p.opcode == byte(pControlV1) -} - -// isData returns true if the packet is of data type. -func (p *packet) isData() bool { - switch p.opcode { - case byte(pDataV1), byte(pDataV2): - return true - default: - return false - } -} - -// parse tries to parse the payload of the packet, and returns a packet and an -// error. it does only parse control packets (for now - parsing of data packets -// is done on the data handler methods). -func parsePacket(p *packet) (*packet, error) { - if p.isControl() { - return parseControlPacket(p) - } - return p, nil -} - -// parseControlPacket parses the contents of a control packet, and returns a -// packet and an error. -func parseControlPacket(p *packet) (*packet, error) { - if len(p.payload) == 0 { - return p, errEmptyPayload - } - if !p.isControl() { - return p, fmt.Errorf("%w: %s", errBadInput, "expected control packet") - } - - buf := bytes.NewBuffer(p.payload) - - // local session id - _, err := io.ReadFull(buf, p.localSessionID[:]) - if err != nil { - return p, fmt.Errorf("%w: bad sessionID: %s", errBadInput, err) - } - - // ack array - ackBuf, err := buf.ReadByte() - if err != nil { - return p, fmt.Errorf("%w: bad ack: %s", errBadInput, err) - } - nAcks := int(ackBuf) - p.acks = make([]packetID, nAcks) - for i := 0; i < nAcks; i++ { - val, err := bufReadUint32(buf) - if err != nil { - return p, fmt.Errorf("%w: cannot parse ack id: %s", errBadInput, err) - } - p.acks[i] = packetID(val) - } - - // remote session id - if nAcks > 0 { - _, err = io.ReadFull(buf, p.remoteSessionID[:]) - if err != nil { - return p, fmt.Errorf("%w: bad remote sessionID: %s", errBadInput, err) - } - } - - // packet id - if p.opcode != pACKV1 { - val, err := bufReadUint32(buf) - if err != nil { - return p, fmt.Errorf("%w: bad packetID: %s", errBadInput, err) - } - p.id = packetID(val) - } - - // payload - p.payload = buf.Bytes() - return p, nil -} - -// isPingPacket returns true if the packet payload matches a hard-coded ping -// payload. -func isPing(b []byte) bool { - return bytes.Equal(b, pingPayload) -} - -// serverControlMessage is sent by the server. it contains reply to the auth -// and push requests. we initialize client's internal state after parsing the -// fields contained in here. -type serverControlMessage struct { - payload []byte -} - -// valid returns true if the packet has a control-message header. -func (sc *serverControlMessage) valid() bool { - if len(sc.payload) < 4 { - return false - } - return bytes.Equal(sc.payload[:4], controlMessageHeader) -} - -// newServerControlMessageFromBytes returns a server control message from the -// passed byte array. -func newServerControlMessageFromBytes(buf []byte) *serverControlMessage { - return &serverControlMessage{buf} -} - -// parseControlMessage gets a server control message and returns the value for -// the remote key, the server remote options, and an error indicating if the -// operation could not be completed. -func parseServerControlMessage(sc *serverControlMessage) (*keySource, string, error) { - if !sc.valid() { - return nil, "", fmt.Errorf("%w: %s", errBadControlMessage, "bad header") - } - if len(sc.payload) < 71 { - return nil, "", fmt.Errorf("%w: bad len from server:%d", errBadControlMessage, len(sc.payload)) - } - keyMethod := sc.payload[4] - if keyMethod != 2 { - return nil, "", fmt.Errorf("%w: %d", errBadKeyMethod, keyMethod) - - } - var random1, random2 [32]byte - // first chunk of random bytes - copy(random1[:], sc.payload[5:37]) - // second chunk of random bytes - copy(random2[:], sc.payload[37:69]) - - options, err := decodeOptionStringFromBytes(sc.payload[69:]) - if err != nil { - return nil, "", fmt.Errorf("%w:%s", errBadControlMessage, "bad options string") - } - - logger.Debugf("Remote opts: %s", options) - remoteKey := &keySource{r1: random1, r2: random2} - return remoteKey, options, nil -} - -// encodeClientControlMessage returns a byte array with the payload for a control channel packet. -// This is the packet that the client sends to the server with the key -// material, local options and credentials (if username+password authentication is used). -func encodeClientControlMessageAsBytes(k *keySource, o *Options) ([]byte, error) { - opt, err := encodeOptionStringToBytes(o.String()) - if err != nil { - return nil, err - } - user, err := encodeOptionStringToBytes(string(o.Username)) - if err != nil { - return nil, err - } - pass, err := encodeOptionStringToBytes(string(o.Password)) - if err != nil { - return nil, err - } - - var out bytes.Buffer - out.Write(controlMessageHeader) - out.WriteByte(0x02) // key method (2) - out.Write(k.Bytes()) - out.Write(opt) - out.Write(user) - out.Write(pass) - - // we could send IV_PLAT too, but afaik declaring the platform does not - // make any difference for our purposes. - rawInfo := fmt.Sprintf("IV_VER=%s\nIV_PROTO=%s\n", IV_Ver, IV_Proto) - peerInfo, _ := encodeOptionStringToBytes(rawInfo) - out.Write(peerInfo) - return out.Bytes(), nil -} - -// serverHard reset contains the payload for a serverHardReset message type. -type serverHardReset struct { - payload []byte -} - -// newServerHardReset returns a serverHardReset message type, and an error if -// the passed payload is empty. -func newServerHardReset(b []byte) (*serverHardReset, error) { - if len(b) == 0 { - return nil, fmt.Errorf("%w: %s", errBadReset, "zero len") - } - p := &serverHardReset{b} - return p, nil -} - -// parseServerHardResetPacket returns the sessionID received from the server, or an -// error if we could not parse the message. -func parseServerHardResetPacket(p *serverHardReset) (sessionID, error) { - if len(p.payload) < 10 { - return sessionID{}, fmt.Errorf("%w: %s", errBadReset, "not enough bytes") - } - // BUG: this function assumes keyID == 0 - if p.payload[0] != 0x40 { - return sessionID{}, fmt.Errorf("%w: %s", errBadReset, "bad header") - } - var rs sessionID - copy(rs[:], p.payload[1:9]) - return rs, nil -} - -// newACKPacket returns a packet with the P_ACK_V1 opcode. -func newACKPacket(ackID packetID, s *session) *packet { - acks := []packetID{ackID} - p := &packet{ - opcode: pACKV1, - localSessionID: s.LocalSessionID, - remoteSessionID: s.RemoteSessionID, - acks: acks, - } - return p -} diff --git a/vpn/packet_test.go b/vpn/packet_test.go deleted file mode 100644 index 28ac9c99..00000000 --- a/vpn/packet_test.go +++ /dev/null @@ -1,286 +0,0 @@ -package vpn - -import ( - "bytes" - "errors" - "reflect" - "testing" -) - -func Test_newACKPacket(t *testing.T) { - type args struct { - ackID packetID - s *session - } - tests := []struct { - name string - args args - want *packet - }{ - {"good_ack", - args{42, &session{}}, - &packet{opcode: pACKV1, acks: []packetID{42}}, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := newACKPacket(tt.args.ackID, tt.args.s); !reflect.DeepEqual(got, tt.want) { - t.Errorf("newACKPacket() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_isPing(t *testing.T) { - type args struct { - b []byte - } - tests := []struct { - name string - args args - want bool - }{ - {"good ping", args{pingPayload}, true}, - {"bad ping", args{append(pingPayload, 0x00)}, false}, - {"empty", args{[]byte{}}, false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := isPing(tt.args.b); got != tt.want { - t.Errorf("isPing() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_newServerControlMessageFromBytes(t *testing.T) { - payload := []byte{0xff, 0xfe, 0xfd} - m := newServerControlMessageFromBytes(payload) - if !bytes.Equal(m.payload, payload) { - t.Errorf("newServerControlMessageFromBytes() = got %v, want %v", m.payload, payload) - } -} - -func Test_serverControlMessage_valid(t *testing.T) { - type fields struct { - payload []byte - } - tests := []struct { - name string - fields fields - want bool - }{ - { - "good control message", - fields{controlMessageHeader}, - true, - }, - { - "bad control message", - fields{[]byte{0x00, 0x00, 0x00, 0x01}}, - false, - }, - { - "empty control message", - fields{[]byte{}}, - false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - sc := &serverControlMessage{ - payload: tt.fields.payload, - } - if got := sc.valid(); got != tt.want { - t.Errorf("serverControlMessage.valid() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_encodeClientControlMessageAsBytes(t *testing.T) { - - var manyA, manyB [32]byte - var manyC [48]byte - - copy(manyA[:], bytes.Repeat([]byte{0x65}, 32)) - copy(manyB[:], bytes.Repeat([]byte{0x66}, 32)) - copy(manyC[:], bytes.Repeat([]byte{0x67}, 48)) - - type args struct { - k *keySource - o *Options - } - tests := []struct { - name string - args args - want []byte - wantErr bool - }{ - { - "empty options", - args{ - &keySource{manyA, manyB, manyC}, - &Options{}, - }, - func() []byte { - buf := []byte{0x00, 0x00, 0x00, 0x00, 0x02} - buf = append(buf, manyC[:]...) - buf = append(buf, manyA[:]...) - buf = append(buf, manyB[:]...) - buf = append(buf, []byte{ - // options, null-terminated - 0x00, 0x01, 0x00, - // auth strings - 0x00, 0x01, 0x00, - 0x00, 0x01, 0x00}...) - buf = append(buf, []byte{0x00, 0x19}...) - buf = append(buf, []byte("IV_VER=2.5.5\nIV_PROTO=2\n")...) - buf = append(buf, 0x00) - return buf - }(), - false, - }, - { - "good options", - args{ - &keySource{manyA, manyB, manyC}, - &Options{Cipher: "AES-128-CBC"}, - }, - func() []byte { - buf := []byte{0x00, 0x00, 0x00, 0x00, 0x02} - buf = append(buf, manyC[:]...) - buf = append(buf, manyA[:]...) - buf = append(buf, manyB[:]...) - buf = append(buf, []byte{0x00, 0x74}...) - buf = append(buf, []byte("V4,dev-type tun,link-mtu 1549,tun-mtu 1500,proto UDPv4,cipher AES-128-CBC,auth ,keysize 128,key-method 2,tls-client")...) - // null-terminate + auth - buf = append(buf, []byte{ - 0x00, - 0x00, 0x01, 0x00, - 0x00, 0x01, 0x00}...) - buf = append(buf, []byte{0x00, 0x19}...) - buf = append(buf, []byte("IV_VER=2.5.5\nIV_PROTO=2\n")...) - buf = append(buf, 0x00) - return buf - }(), - false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := encodeClientControlMessageAsBytes(tt.args.k, tt.args.o) - if (err != nil) != tt.wantErr { - t.Errorf("encodeClientControlMessageAsBytes() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("encodeClientControlMessageAsBytes() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_newServerHardReset(t *testing.T) { - type args struct { - b []byte - } - tests := []struct { - name string - args args - want *serverHardReset - wantErr error - }{ - { - name: "good payload", - args: args{[]byte("not a payload")}, - want: &serverHardReset{[]byte("not a payload")}, - wantErr: nil, - }, - { - name: "empty", - args: args{[]byte{}}, - want: nil, - wantErr: errBadReset, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := newServerHardReset(tt.args.b) - if !errors.Is(err, tt.wantErr) { - t.Errorf("newServerHardReset() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("newServerHardReset() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_parseServerHardResetPacket(t *testing.T) { - - var goodSessionID sessionID - goodPayload := []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08} - shortPayload := []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07} - copy(goodSessionID[:], goodPayload) - - type args struct { - p *serverHardReset - } - tests := []struct { - name string - args args - want sessionID - wantErr error - }{ - { - name: "good server hard reset", - args: args{ - &serverHardReset{ - payload: append([]byte{0x40}, goodPayload...), - }, - }, - want: goodSessionID, - wantErr: nil, - }, - { - name: "payload too short should fail", - args: args{ - &serverHardReset{ - payload: append([]byte{0x40}, shortPayload...), - }, - }, - want: sessionID{}, - wantErr: errBadReset, - }, - { - name: "bad header should fail", - args: args{ - &serverHardReset{ - payload: append([]byte{0x41}, goodPayload...), - }, - }, - want: sessionID{}, - wantErr: errBadReset, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := parseServerHardResetPacket(tt.args.p) - if !errors.Is(err, tt.wantErr) { - t.Errorf("parseServerHardResetPacket() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("parseServerHardResetPacket() = %v, want %v", got, tt.want) - } - }) - } -} - -// Regression test for MIV-01-001 -func Test_Crash_parseServerHardResetPacket(t *testing.T) { - p := &serverHardReset{} - parseServerHardResetPacket(p) -} diff --git a/vpn/tls.go b/vpn/tls.go deleted file mode 100644 index 0651847d..00000000 --- a/vpn/tls.go +++ /dev/null @@ -1,267 +0,0 @@ -package vpn - -// -// TLS initialization and read/write wrappers. -// -// TODO(ainghazal): for the time being, we're using uTLS to parrot a ClientHello that can reasonably blend -// with a recent openvpn+openssl client (2.5.x). We might want to revisit this -// in the near future and perhaps expose other TLS Factories. -// - -import ( - "crypto/x509" - "encoding/hex" - "errors" - "fmt" - "io/ioutil" - "net" - - tls "github.com/refraction-networking/utls" -) - -var ( - // ErrBadTLSHandshake is returned when the OpenVPN handshake failed. - ErrBadTLSHandshake = errors.New("handshake failure") - // ErrBadCA is returned when the CA file cannot be found or is not valid. - ErrBadCA = errors.New("bad ca conf") - // ErrBadKeypair is returned when the key or cert file cannot be found or is not valid. - ErrBadKeypair = errors.New("bad keypair conf") - // ErrBadParrot is returned for errors during TLS parroting - ErrBadParrot = errors.New("cannot parrot") - // ErrCannotVerifyCertChain is returned for certificate chain validation errors. - ErrCannotVerifyCertChain = errors.New("cannot verify chain") -) - -// certVerifyOptionsNoCommonNameCheck returns a x509.VerifyOptions initialized with -// an empty string for the DNSName. This allows to skip CN verification. -func certVerifyOptionsNoCommonNameCheck() x509.VerifyOptions { - return x509.VerifyOptions{DNSName: ""} -} - -// certVerifyOptions is the options factory that the customVerify function will -// use; by default it configures VerifyOptions to skip the DNSName check. -var certVerifyOptions = certVerifyOptionsNoCommonNameCheck - -// certPaths holds the paths for the cert, key, and ca used for OpenVPN -// certificate authentication. -type certPaths struct { - certPath string - keyPath string - caPath string -} - -// loadCertAndCAFromPath parses the PEM certificates contained in the paths pointed by -// the passed certPaths and return a certConfig with the client and CA certificates. -func loadCertAndCAFromPath(pth certPaths) (*certConfig, error) { - ca := x509.NewCertPool() - caData, err := ioutil.ReadFile(pth.caPath) - if err != nil { - return nil, fmt.Errorf("%w: %s", ErrBadCA, err) - } - ok := ca.AppendCertsFromPEM(caData) - if !ok { - return nil, fmt.Errorf("%w: %s", ErrBadCA, "cannot parse ca cert") - } - - cfg := &certConfig{ca: ca} - if pth.certPath != "" && pth.keyPath != "" { - cert, err := tls.LoadX509KeyPair(pth.certPath, pth.keyPath) - if err != nil { - return nil, fmt.Errorf("%w: %s", ErrBadKeypair, err) - } - cfg.cert = cert - } - return cfg, nil -} - -// certBytes holds the byte arrays for the cert, key, and ca used for OpenVPN -// certificate authentication. -type certBytes struct { - cert []byte - key []byte - ca []byte -} - -// loadCertAndCAFromBytes parses the PEM certificates from the byte arrays in the -// the passed certBytes, and return a certConfig with the client and CA certificates. -func loadCertAndCAFromBytes(crt certBytes) (*certConfig, error) { - ca := x509.NewCertPool() - ok := ca.AppendCertsFromPEM(crt.ca) - if !ok { - return nil, fmt.Errorf("%w: %s", ErrBadCA, "cannot parse ca cert") - } - cfg := &certConfig{ca: ca} - if crt.cert != nil && crt.key != nil { - cert, err := tls.X509KeyPair(crt.cert, crt.key) - if err != nil { - return nil, fmt.Errorf("%w: %s", ErrBadKeypair, err) - } - cfg.cert = cert - } - return cfg, nil -} - -// authorityPinner is any object from which we can obtain a certpool containing -// a pinned Certificate Authority for verification. -type authorityPinner interface { - authority() *x509.CertPool -} - -// certConfig holds the parsed certificate and CA used for OpenVPN mutual -// certificate authentication. -type certConfig struct { - cert tls.Certificate - ca *x509.CertPool -} - -// newCertConfigFromOptions is a constructor that returns a certConfig object initialized -// from the paths specified in the passed Options object, and an error if it -// could not be properly built. -func newCertConfigFromOptions(o *Options) (*certConfig, error) { - var cfg *certConfig - var err error - if o.certsFromPath() { - cfg, err = loadCertAndCAFromPath(certPaths{ - certPath: o.CertPath, - keyPath: o.KeyPath, - caPath: o.CaPath, - }) - } else { - cfg, err = loadCertAndCAFromBytes(certBytes{ - cert: o.Cert, - key: o.Key, - ca: o.Ca, - }) - } - return cfg, err -} - -// authority implements authorityPinner interface. -func (c *certConfig) authority() *x509.CertPool { - return c.ca -} - -// ensure certConfig implements authorityPinner. -var _ authorityPinner = &certConfig{} - -// verifyFun is the type expected by the VerifyPeerCertificate callback in tls.Config. -type verifyFun func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error - -// customVerifyFactory returns a verifyFun callback that will verify any received certificates -// against the ca provided by the pased implementation of authorityPinner -func customVerifyFactory(pinner authorityPinner) verifyFun { - // customVerify is a version of the verification routines that does not try to verify - // the Common Name, since we don't know it a priori for a VPN gateway. Returns - // an error if the verification fails. - // From tls/common documentation: If normal verification is disabled by - // setting InsecureSkipVerify, [...] then this callback will be considered but - // the verifiedChains argument will always be nil. - customVerify := func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { - // we assume (from docs) that we're always given the - // leaf certificate as the first cert in the array. - leaf, _ := x509.ParseCertificate(rawCerts[0]) - if leaf == nil { - return fmt.Errorf("%w: %s", ErrCannotVerifyCertChain, "nothing to verify") - } - // By default has DNSName verification disabled. - opts := certVerifyOptions() - // Set the configured CA(s) as the certificate pool to verify against. - opts.Roots = pinner.authority() - - if _, err := leaf.Verify(opts); err != nil { - return fmt.Errorf("%w: %s", ErrCannotVerifyCertChain, err) - } - return nil - } - return customVerify -} - -// initTLS returns a tls.Config matching the VPN options. Internally, it uses -// the verify function returned by the global customVerifyFactory, -// verification function since verifying the ServerName does not make sense in -// the context of establishing a VPN session: we perform mutual TLS -// Authentication with the custom CA. -func initTLS(session *session, cfg *certConfig) (*tls.Config, error) { - if session == nil || cfg == nil { - return nil, fmt.Errorf("%w: %s", errBadInput, "nil args") - } - - customVerify := customVerifyFactory(cfg) - - tlsConf := &tls.Config{ - // the certificate we've loaded from the config file - Certificates: []tls.Certificate{cfg.cert}, - // crypto/tls wants either ServerName or InsecureSkipVerify set ... - InsecureSkipVerify: true, - // ...but we pass our own verification function that verifies against the CA and ignores the ServerName - VerifyPeerCertificate: customVerify, - // disable DynamicRecordSizing to lower distinguishability. - DynamicRecordSizingDisabled: true, - // uTLS does not pick min/max version from the passed spec - MinVersion: tls.VersionTLS12, - MaxVersion: tls.VersionTLS13, - } //#nosec G402 - - return tlsConf, nil -} - -// tlsHandshake performs the TLS handshake over the control channel, and return -// the TLS Client as a net.Conn; returns also any error during the handshake. -func tlsHandshake(tlsConn *controlChannelTLSConn, tlsConf *tls.Config) (net.Conn, error) { - tlsClient, err := tlsFactoryFn(tlsConn, tlsConf) - if err != nil { - return nil, err - } - if err := tlsClient.Handshake(); err != nil { - return nil, fmt.Errorf("%w: %s", ErrBadTLSHandshake, err) - } - return tlsClient, nil -} - -// handshaker is a custom interface that we define here to be able to mock -// the tls.Conn implementation. -type handshaker interface { - net.Conn - Handshake() error -} - -// defaultTLSFactory returns an implementer of the handshaker interface; that -// is, the default tls.Client factory; and an error. -// we're not using the default factory right now, but it comes handy to be able -// to compare the fingerprints with a golang TLS handshake. -func defaultTLSFactory(conn net.Conn, config *tls.Config) (handshaker, error) { - c := tls.Client(conn, config) - return c, nil -} - -// vpnClientHelloHex is the hexadecimal representation of a capture from the reference openvpn implementation. -// openvpn=2.5.5,openssl=3.0.2 -// You can use https://github.com/ainghazal/sniff/tree/main/clienthello to -// analyze a ClientHello from the wire or pcap. -var vpnClientHelloHex = `1603010114010001100303534e0a0f2687b240f7c7dfbb51c4aac33639f28173aa5d7bcebb159695ab0855208b835bf240a83df66885d6747b5bbf1b631e8c34ae469c629d7eb76e247128eb0032130213031301c02cc030009fcca9cca8ccaac02bc02f009ec024c028006bc023c0270067c00ac0140039c009c013003300ff01000095000b000403000102000a00160014001d0017001e00190018010001010102010301040016000000170000000d002a0028040305030603080708080809080a080b080408050806040105010601030303010302040205020602002b0009080304030303020301002d00020101003300260024001d0020a10bc24becb583293c317220e6725205d3a177a4a974090f6ffcf13a43da7035` - -// parrotTLSFactory returns an implementer of the handshaker interface; in this -// case, a parroting implementation; and an error. -func parrotTLSFactory(conn net.Conn, config *tls.Config) (handshaker, error) { - fingerprinter := &tls.Fingerprinter{AllowBluntMimicry: true} - rawOpenVPNClientHelloBytes, err := hex.DecodeString(vpnClientHelloHex) - if err != nil { - return nil, fmt.Errorf("%w: cannot decode raw fingerprint: %s", ErrBadParrot, err) - } - generatedSpec, err := fingerprinter.FingerprintClientHello(rawOpenVPNClientHelloBytes) - if err != nil { - return nil, fmt.Errorf("%w: fingerprinting failed: %s", ErrBadParrot, err) - } - client := tls.UClient(conn, config, tls.HelloCustom) - if err := client.ApplyPreset(generatedSpec); err != nil { - return nil, fmt.Errorf("%w: cannot apply spec: %s", ErrBadParrot, err) - } - return client, nil -} - -// global variables to allow monkeypatching in tests. -var ( - initTLSFn = initTLS - tlsFactoryFn = parrotTLSFactory - tlsHandshakeFn = tlsHandshake -) diff --git a/vpn/tls_test.go b/vpn/tls_test.go deleted file mode 100644 index eae2c80c..00000000 --- a/vpn/tls_test.go +++ /dev/null @@ -1,897 +0,0 @@ -package vpn - -import ( - "crypto/rand" - "crypto/rsa" - "crypto/x509" - "crypto/x509/pkix" - "errors" - "math/big" - "net" - "os" - "reflect" - "testing" - "time" - - "github.com/google/martian/mitm" - "github.com/ooni/minivpn/vpn/mocks" - tls "github.com/refraction-networking/utls" -) - -func makeDummyOptionsForCertPaths() *Options { - return &Options{ - CertPath: "aa", - KeyPath: "aa", - CaPath: "aa", - } -} - -// TODO(ainghazal): many of these tests right now create certs, which is costly -// bacause creating a CA each time is expensive. -// I should mock any certs for the tests that do not need to actually verify -// the cert chain. - -func Test_initTLS(t *testing.T) { - type args struct { - session *session - cfg *certConfig - } - tests := []struct { - name string - args args - want *tls.Config - wantErr error - }{ - { - name: "empty opts should fail", - args: args{ - session: makeTestingSession(), - }, - want: nil, - wantErr: errBadInput, - }, - { - name: "empty session should fail", - args: args{ - cfg: func() *certConfig { - crt, err := writeTestingCerts(t.TempDir()) - if err != nil { - t.Errorf("initTLS() cannot create certs for test %v", err.Error()) - } - cfg, err := newCertConfigFromOptions(&Options{ - CertPath: crt.cert, - KeyPath: crt.key, - CaPath: crt.ca, - }) - if err != nil { - t.Errorf("initTLS() cannot load config from opts %v", err.Error()) - } - return cfg - }(), - }, - want: nil, - wantErr: errBadInput, - }, - { - name: "default tls config should not fail with proper cert paths", - args: args{ - session: makeTestingSession(), - cfg: func() *certConfig { - c, _ := writeTestingCerts("") - cfg, err := newCertConfigFromOptions( - &Options{ - CertPath: c.cert, - KeyPath: c.key, - CaPath: c.ca, - }) - if err != nil { - t.Errorf("error while testing: %v", err) - } - return cfg - }(), - }, - want: nil, - wantErr: nil, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := initTLS(tt.args.session, tt.args.cfg) - if !errors.Is(err, tt.wantErr) { - t.Errorf("initTLS() error = %v, wantErr %v", err, tt.wantErr) - return - } - if tt.want == nil { - return - } - if !reflect.DeepEqual(got.InsecureSkipVerify, tt.want.InsecureSkipVerify) { - t.Errorf("initTLS() InsecureSkipVerify = %v, want %v", got.InsecureSkipVerify, tt.want.InsecureSkipVerify) - return - } - }) - } -} - -var pemTestingKey = []byte(`-----BEGIN PRIVATE KEY----- -MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQC/vw0YScdbP2wg -3M+N6BlsCQePUVFlyLh3faPtfqKTeWfyMYhGMeUE4fMcO1H0l7b/+zfwfA85AhlT -dU152AXvizBidnaQXwVxsxzLPiPxn3qH5KxD72vkMHMyUrRh/tdJzIj1bqlCiLcw -SK5EDPMwuUSAIk7evRzLUdGu1JkUxi7xox03R5rvC8ZohAPSRxFAg6rajkk7HlUi -BepNz5PRlPGJ0Kfn0oa/BF+5F3Y4WU+75r9tK+H691eRL65exTGrYIOZE9Rd6i8C -S3WoFNmlO6tv0HMAh/GYR6/mrekOkSZdjNIbDfcNiFsvNtMIO9jztd7g/3BcQg/3 -eFydHplrAgMBAAECggEAM8lBnCGw+e/zIB0C4WyiEQ+PPyHTPg4r4/nG4EmnVvUf -IcZG685l8B+mLSXISKsA/bm3rfeTlO4AMQ4pUpMJZ1zMQIuGEg/XxJF/YVTzGDre -OP2FmQN8vDBprFmx5hWRx5i6FK9Cf3m1IBFBH5fvxmUDHygk7PteX3tFilZY0ccM -TpK8nOOpbbK/8S8dC6ePXYgjamLotAnKdgKnpmxQjiprsRAWiOr7DFdjMLCUyZkC -NYwRszVNX84wLOFNzFdU653gFKNcJ/8NI2MBQ5EaBMWOcxNgdfBtCXE9GwQVNzp2 -tjTt2QYbTdaw6LAMKgrWgaZBp0VSK4WTlYLifwrSQQKBgQD4Ah39r/l+QyTLwr6d -AkMp/rgpOYzvaRzuUcZnObvi8yfFlJJ6EM4zfNICXNexdqeL+WTaSV1yuc4/rsRx -nAgXklgz2UpATccLJ7JrCDsWgZm71tfUWQM5IbMgkyVixwGYiTsW+kMxFD0n2sNK -sPkEgr2IiSEDfjzTf0LPr7sLyQKBgQDF7NCTTEp92FSz5OcKNSI7iH+lsVgV+U88 -Widc/thn/vRnyRqpvyjUvl9D9jMTz2/9DiV06lCYfN8KpknCb3jCWY5cjmOSZQTs -oHQQX145Exe8cj2z+66QK6CsE1tlUC99Y684hn+eDlLMIQGMtRz8aSYb8oZo68sM -hcTaP8CtkwKBgQDK0RhrrWyQWCKQS9uMFRyODFPYysq5wzE4qEFji3BeodFFoEHF -d1bZ/lrUOc7evxU3wCU86kB0oQTNSYQ3EI4BkNl21V0Gh1Seh8E+DIYd2rC5T3JD -ouOi5i9SFWO+itaAQsHDAbjPOyjkHeAVhfKvQKf1L4eDDsp5f5pItAJ4GQKBgDvF -EwuYW1p7jMCynG7Bsu/Ffb68unwQSLRSCVcVAqcNICODYJDoUF1GjCBK5gvSdeA2 -eGtBI0uZUgW2R8n2vcH7J3md6kXYSc9neQVEt4CG2oEnAqkqlQGmmyO7yLrkpyK3 -ir+IJlvFuY05Xm1ueC1lV4PTDnH62tuSPesmm3oPAoGBANsj/l6xgcMZK6VKZHGV -gG59FoMudCvMP1pITJh+TQPIJbD4TgYnDUG7z14zrYhxChWHYysVrIT35Iuu7k6S -JlkPybAiLmv2nulx9fRkTzcGgvPtG3iHS/WQLvr9umWrfmQYMMW1Udr0IdflS1Sk -fIeuXWkQrCE24uKSInkRupLO ------END PRIVATE KEY-----`) - -var pemTestingCertificate = []byte(`-----BEGIN CERTIFICATE----- -MIIDjTCCAnUCFGb3X7au5DHHCSd8n6e5vG1/HGtyMA0GCSqGSIb3DQEBCwUAMIGB -MQswCQYDVQQGEwJOWjELMAkGA1UECAwCTk8xEjAQBgNVBAcMCUludGVybmV0ejEN -MAsGA1UECgwEQW5vbjENMAsGA1UECwwEcm9vdDESMBAGA1UEAwwJbG9jYWxob3N0 -MR8wHQYJKoZIhvcNAQkBFhB1c2VyQGV4YW1wbGUuY29tMB4XDTIyMDUyMDE4Mzk0 -N1oXDTIyMDYxOTE4Mzk0N1owgYMxCzAJBgNVBAYTAk5aMQswCQYDVQQIDAJOTzES -MBAGA1UEBwwJSW50ZXJuZXR6MQ0wCwYDVQQKDARBbm9uMQ8wDQYDVQQLDAZzZXJ2 -ZXIxEjAQBgNVBAMMCWxvY2FsaG9zdDEfMB0GCSqGSIb3DQEJARYQdXNlckBleGFt -cGxlLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAL+/DRhJx1s/ -bCDcz43oGWwJB49RUWXIuHd9o+1+opN5Z/IxiEYx5QTh8xw7UfSXtv/7N/B8DzkC -GVN1TXnYBe+LMGJ2dpBfBXGzHMs+I/GfeofkrEPva+QwczJStGH+10nMiPVuqUKI -tzBIrkQM8zC5RIAiTt69HMtR0a7UmRTGLvGjHTdHmu8LxmiEA9JHEUCDqtqOSTse -VSIF6k3Pk9GU8YnQp+fShr8EX7kXdjhZT7vmv20r4fr3V5Evrl7FMatgg5kT1F3q -LwJLdagU2aU7q2/QcwCH8ZhHr+at6Q6RJl2M0hsN9w2IWy820wg72PO13uD/cFxC -D/d4XJ0emWsCAwEAATANBgkqhkiG9w0BAQsFAAOCAQEAGt+m0kwuULOVEr7QvbOI -6pxEd9AysxWxGzGBM6G9jrhlgch10wWuhDZq0LqahlWQ8DK9Kjg+pHEYYN8B1m0L -2lloFpXb+AXJR9RKsBr4iU2HdJkPIAwYlDhPUTeskfWP61JGGQC6oem3UXCbLldE -VxcY3vSifP9/pIyjHVULa83FQwwsseavav3NvBgYIyglz+BLl6azMdFLXyzGzEUv -iiN6MdNrJ34iDKHCYSlNvJktJY91eTsQ1GLYD6O9C5KrCJRp0ibQ1keSE7vdhnTY -doKeoNOwq224DcktFdFAYnOM/q3dKxz3m8TsM5OLel4kebqDovPt0hJl2Wwwx43k -0A== ------END CERTIFICATE-----`) - -var pemTestingCa = []byte(`-----BEGIN CERTIFICATE----- -MIID5TCCAs2gAwIBAgIUecMREJYMxFeQEWNBRSCM1x/pAEIwDQYJKoZIhvcNAQEL -BQAwgYExCzAJBgNVBAYTAk5aMQswCQYDVQQIDAJOTzESMBAGA1UEBwwJSW50ZXJu -ZXR6MQ0wCwYDVQQKDARBbm9uMQ0wCwYDVQQLDARyb290MRIwEAYDVQQDDAlsb2Nh -bGhvc3QxHzAdBgkqhkiG9w0BCQEWEHVzZXJAZXhhbXBsZS5jb20wHhcNMjIwNTIw -MTgzOTQ3WhcNMjIwNjE5MTgzOTQ3WjCBgTELMAkGA1UEBhMCTloxCzAJBgNVBAgM -Ak5PMRIwEAYDVQQHDAlJbnRlcm5ldHoxDTALBgNVBAoMBEFub24xDTALBgNVBAsM -BHJvb3QxEjAQBgNVBAMMCWxvY2FsaG9zdDEfMB0GCSqGSIb3DQEJARYQdXNlckBl -eGFtcGxlLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAMxO6abV -xOy/2VuekAAvJnM2bFIpqSoWK1uMDHJc7NRWVPy2UFaDvCL2g+CSqEyqMN0NI0El -J2cIAgUYOa0+wHJWQhAL60veR6ew9JfIDk3S7YNeKzUGgrRzKvTLdms5mL8fZpT+ -GFwHprx58EZwg2TDQ6bGdThsSYNbx72PRngIOl5k6NWdIgd0wiAAYIpNQQUc8rDC -IG4VvoitbpzYcAFCxCVGivodLP02pk2hokbidnLyTj5wIVTccA3u9FeEq2+IIAfr -OW+3LjCpH9SC+3qPjA0UHv2bCLMVzIp86lUsbx6Qcoy0RPh5qC28cLk19wQj5+pw -XtOeL90d2Hokf40CAwEAAaNTMFEwHQYDVR0OBBYEFNuQwyljbQs208ZCI5NFuzvo -1ez8MB8GA1UdIwQYMBaAFNuQwyljbQs208ZCI5NFuzvo1ez8MA8GA1UdEwEB/wQF -MAMBAf8wDQYJKoZIhvcNAQELBQADggEBAHPkGlDDq79rdxFfbt0dMKm1dWZtPlZl -iIY9Pcet/hgf69OKXwb4h3E0IjFW7JHwo4Bfr4mqrTQLTC1qCRNEMC9XUyc4neQy -3r2LRk+D7XAN1zwL6QPw550ukbLk4R4I1xQr+9Sap9h0QUaJj5tts6XSzhZ1AylJ -HgmkOnPOpcIWm+yUMEDESGnhE8hfXR1nhb5lLrg2HIqp9qRRH1w/wc7jG3bYV3jg -S5nL4GaRzx84PB1HWONlh0Wp7KBk2j6Lp0acoJwI2mHJcJoOPpaYiWWYNNTjMv2/ -XXNUizTI136liavLslSMoYkjYAun+5HOux/keA1L+lm2XeG06Ew1qS4= ------END CERTIFICATE-----`) - -type testingCert struct { - cert string - key string - ca string -} - -func writeTestingCerts(dir string) (testingCert, error) { - certFile, err := os.CreateTemp(dir, "tmpfile-") - if err != nil { - return testingCert{}, err - } - certFile.Write(pemTestingCertificate) - - keyFile, err := os.CreateTemp(dir, "tmpfile-") - if err != nil { - return testingCert{}, err - } - keyFile.Write(pemTestingKey) - - caFile, err := os.CreateTemp(dir, "tmpfile-") - if err != nil { - return testingCert{}, err - } - caFile.Write(pemTestingCa) - - testingCert := testingCert{ - cert: certFile.Name(), - key: keyFile.Name(), - ca: caFile.Name(), - } - return testingCert, nil -} - -func writeTestingCertsBadCAFile(dir string) (testingCert, error) { - certFile, err := os.CreateTemp(dir, "tmpfile-") - if err != nil { - return testingCert{}, err - } - certFile.Write(pemTestingCertificate) - - keyFile, err := os.CreateTemp(dir, "tmpfile-") - if err != nil { - return testingCert{}, err - } - keyFile.Write(pemTestingKey) - - caFile, err := os.CreateTemp(dir, "tmpfile-") - if err != nil { - return testingCert{}, err - } - caFile.Write(pemTestingCa[:len(pemTestingCa)-10]) - - testingCert := testingCert{ - cert: certFile.Name(), - key: keyFile.Name(), - ca: caFile.Name() + "-non-existent", - } - return testingCert, nil -} - -func writeTestingCertsBadCA(dir string) (testingCert, error) { - certFile, err := os.CreateTemp(dir, "tmpfile-") - if err != nil { - return testingCert{}, err - } - certFile.Write(pemTestingCertificate) - - keyFile, err := os.CreateTemp(dir, "tmpfile-") - if err != nil { - return testingCert{}, err - } - keyFile.Write(pemTestingKey) - - caFile, err := os.CreateTemp(dir, "tmpfile-") - if err != nil { - return testingCert{}, err - } - caFile.Write(pemTestingCa[:len(pemTestingCa)-10]) - - testingCert := testingCert{ - cert: certFile.Name(), - key: keyFile.Name(), - ca: caFile.Name(), - } - return testingCert, nil -} - -func writeTestingCertsBadKey(dir string) (testingCert, error) { - certFile, err := os.CreateTemp(dir, "tmpfile-") - if err != nil { - return testingCert{}, err - } - certFile.Write(pemTestingCertificate) - - keyFile, err := os.CreateTemp(dir, "tmpfile-") - if err != nil { - return testingCert{}, err - } - keyFile.Write(pemTestingKey[:len(pemTestingKey)-10]) - - caFile, err := os.CreateTemp(dir, "tmpfile-") - if err != nil { - return testingCert{}, err - } - caFile.Write(pemTestingCa) - - testingCert := testingCert{ - cert: certFile.Name(), - key: keyFile.Name(), - ca: caFile.Name(), - } - return testingCert, nil -} - -func writeTestingCertsBadCert(dir string) (testingCert, error) { - certFile, err := os.CreateTemp(dir, "tmpfile-") - if err != nil { - return testingCert{}, err - } - certFile.Write(pemTestingCertificate[:len(pemTestingCertificate)-10]) - - keyFile, err := os.CreateTemp(dir, "tmpfile-") - if err != nil { - return testingCert{}, err - } - keyFile.Write(pemTestingKey[:len(pemTestingKey)-10]) - - caFile, err := os.CreateTemp(dir, "tmpfile-") - if err != nil { - return testingCert{}, err - } - caFile.Write(pemTestingCa) - - testingCert := testingCert{ - cert: certFile.Name(), - key: keyFile.Name(), - ca: caFile.Name(), - } - return testingCert, nil -} - -func Test_loadCertAndCAFromPath(t *testing.T) { - type args struct { - pth certPaths - } - tests := []struct { - name string - args args - want *certConfig - wantErr error - }{ - { - name: "bad ca (non existent file) should fail", - args: func() args { - crt, err := writeTestingCertsBadCAFile(t.TempDir()) - if err != nil { - t.Errorf("error while testing: %v", err) - } - return args{pth: certPaths{crt.cert, crt.key, crt.ca}} - - }(), - want: nil, - wantErr: ErrBadCA, - }, - { - name: "bad ca (malformed) should fail", - args: func() args { - crt, err := writeTestingCertsBadCA(t.TempDir()) - if err != nil { - t.Errorf("error while testing: %v", err) - } - return args{pth: certPaths{crt.cert, crt.key, crt.ca}} - - }(), - want: nil, - wantErr: ErrBadCA, - }, - { - name: "bad key", - args: func() args { - crt, err := writeTestingCertsBadKey(t.TempDir()) - if err != nil { - t.Errorf("error while testing: %v", err) - } - return args{pth: certPaths{crt.cert, crt.key, crt.ca}} - - }(), - want: nil, - wantErr: ErrBadKeypair, - }, - { - name: "bad cert", - args: func() args { - crt, err := writeTestingCertsBadCert(t.TempDir()) - if err != nil { - t.Errorf("error while testing: %v", err) - } - return args{pth: certPaths{crt.cert, crt.key, crt.ca}} - - }(), - want: nil, - wantErr: ErrBadKeypair, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := loadCertAndCAFromPath(tt.args.pth) - if !errors.Is(err, tt.wantErr) { - t.Errorf("loadCertAndCA() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("loadCertAndCA() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_loadCertAndCAFromBytes(t *testing.T) { - type args struct { - crt certBytes - } - tests := []struct { - name string - args args - want *certConfig - wantErr error - }{ - { - name: "bad ca should fail", - args: args{crt: certBytes{ - ca: pemTestingCa[:len(pemTestingCa)-10], - cert: pemTestingCertificate, - key: pemTestingKey}}, - want: nil, - wantErr: ErrBadCA, - }, - { - name: "bad cert should fail", - args: args{crt: certBytes{ - ca: pemTestingCa, - cert: pemTestingCertificate[:len(pemTestingCertificate)-10], - key: pemTestingKey}}, - want: nil, - wantErr: ErrBadKeypair, - }, - { - name: "bad key should fail", - args: args{crt: certBytes{ - ca: pemTestingCa, - cert: pemTestingCertificate, - key: pemTestingKey[10:]}}, - want: nil, - wantErr: ErrBadKeypair, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := loadCertAndCAFromBytes(tt.args.crt) - if !errors.Is(err, tt.wantErr) { - t.Errorf("loadCertAndCAFromBytes() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("loadCertAndCAFromBytes() = %v, want %v", got, tt.want) - } - }) - } - t.Run("sunny path should not fail", func(t *testing.T) { - crt := certBytes{ - ca: pemTestingCa, - cert: pemTestingCertificate, - key: pemTestingKey, - } - _, err := loadCertAndCAFromBytes(crt) - if err != nil { - t.Errorf("loadCertAndCAFromBytes() err = %v, want %v", err, nil) - } - }) -} - -func Test_initTLSLoadTestCertificates(t *testing.T) { - - t.Run("default options should not fail", func(t *testing.T) { - session := makeTestingSession() - crt, err := writeTestingCerts(t.TempDir()) - if err != nil { - t.Errorf("error while testing: %v", err) - } - cfg, err := newCertConfigFromOptions(&Options{ - CertPath: crt.cert, - KeyPath: crt.key, - CaPath: crt.ca, - }) - if err != nil { - t.Errorf("error while testing: %v", err) - } - - _, err = initTLS(session, cfg) - if err != nil { - t.Errorf("initTLS() error = %v, want: nil", err) - } - }) - - t.Run("default options from bytes should not fail", func(t *testing.T) { - session := makeTestingSession() - cfg, err := newCertConfigFromOptions(&Options{ - Cert: pemTestingCertificate, - Key: pemTestingKey, - Ca: pemTestingCa, - }) - if err != nil { - t.Errorf("error while testing: %v", err) - } - _, err = initTLS(session, cfg) - if err != nil { - t.Errorf("initTLS() error = %v, want: nil", err) - } - }) -} - -// mock good handshake - -type dummyTLSConn struct { - tls.Conn -} - -var _ handshaker = &dummyTLSConn{} // Ensure that we implement handshaker - -func (d *dummyTLSConn) Handshake() error { - return nil -} - -func dummyTLSFactory(net.Conn, *tls.Config) (handshaker, error) { - return &dummyTLSConn{tls.Conn{}}, nil -} - -// mock bad handshake - -type dummyTLSConnBadHandshake struct { - tls.Conn -} - -var _ handshaker = &dummyTLSConnBadHandshake{} // Ensure that we implement handshaker - -func (d *dummyTLSConnBadHandshake) Handshake() error { - return errors.New("dummy error") -} - -func dummyTLSFactoryBadHandshake(net.Conn, *tls.Config) (handshaker, error) { - return &dummyTLSConnBadHandshake{tls.Conn{}}, nil -} - -var tlsFactoryError = errors.New("tlsFactory error") - -func errorRaisingTLSFactory(net.Conn, *tls.Config) (handshaker, error) { - return nil, tlsFactoryError -} - -func Test_tlsHandshake(t *testing.T) { - - makeConnAndConf := func() (*controlChannelTLSConn, *tls.Config) { - conn := &mocks.Conn{} - s := makeTestingSession() - tc, _ := newControlChannelTLSConn(conn, s) - - conf := &tls.Config{ - InsecureSkipVerify: true, - } - return tc, conf - } - - t.Run("mocked good handshake should not fail", func(t *testing.T) { - origTLS := tlsFactoryFn - tlsFactoryFn = dummyTLSFactory - defer func() { - tlsFactoryFn = origTLS - }() - - conn, conf := makeConnAndConf() - - _, err := tlsHandshake(conn, conf) - if err != nil { - t.Errorf("tlsHandshake() error = %v, wantErr %v", err, nil) - return - } - }) - - t.Run("mocked bad handshake should fail", func(t *testing.T) { - origTLS := tlsFactoryFn - tlsFactoryFn = dummyTLSFactoryBadHandshake - defer func() { - tlsFactoryFn = origTLS - }() - - conn, conf := makeConnAndConf() - - wantErr := ErrBadTLSHandshake - _, err := tlsHandshake(conn, conf) - if !errors.Is(err, wantErr) { - t.Errorf("tlsHandshake() error = %v, wantErr %v", err, wantErr) - return - } - }) - - t.Run("any error from the factory should be bubbled up", func(t *testing.T) { - origTLS := tlsFactoryFn - tlsFactoryFn = errorRaisingTLSFactory - defer func() { - tlsFactoryFn = origTLS - }() - wantErr := tlsFactoryError - - conn, conf := makeConnAndConf() - - _, err := tlsHandshake(conn, conf) - if !errors.Is(err, wantErr) { - t.Errorf("tlsHandshake() error = %v, wantErr %v", err, wantErr) - return - } - }) -} - -func Test_defaultTLSFactory(t *testing.T) { - conn := &mocks.Conn{} - conf := &tls.Config{} - defaultTLSFactory(conn, conf) -} - -func Test_parrotTLSFactory(t *testing.T) { - conn := &mocks.Conn{} - conf := &tls.Config{InsecureSkipVerify: true} - - t.Run("parrotTLS factory does not return any error by default", func(t *testing.T) { - _, err := parrotTLSFactory(conn, conf) - if err != nil { - t.Errorf("parrotTLSFactory() error = %v, wantErr %v", err, nil) - return - } - }) - - t.Run("an hex clienthello that cannot be decoded to raw bytes should raise ErrBadParrot", func(t *testing.T) { - defer func(original string) { - vpnClientHelloHex = original - }(vpnClientHelloHex) - vpnClientHelloHex = `aaa` - - _, err := parrotTLSFactory(conn, conf) - wantErr := ErrBadParrot - if !errors.Is(err, wantErr) { - t.Errorf("tlsHandshake() error = %v, wantErr %v", err, wantErr) - return - } - }) - - t.Run("an hex representation that is not a valid clienthello should raise ErrBadParrot", func(t *testing.T) { - defer func(original string) { - vpnClientHelloHex = original - }(vpnClientHelloHex) - vpnClientHelloHex = `deadbeef` - - _, err := parrotTLSFactory(conn, conf) - wantErr := ErrBadParrot - if !errors.Is(err, wantErr) { - t.Errorf("tlsHandshake() error = %v, wantErr %v", err, wantErr) - return - } - }) - - // TODO(ainghazal): there's an extra error case that I'm not pretty sure how to reach - // (error on client.ApplyPreset) -} - -func Test_customVerify(t *testing.T) { - - t.Run("happy path: a correct certChain should validate if we pin with the good ca", func(t *testing.T) { - rawCerts, ca, vpnCert, vpnKey, err := makeRawCertsForTesting() - if err != nil { - t.Errorf("error getting raw certs") - return - } - - auth, err := makeCertAndCAFromMemory(ca, vpnCert, vpnKey) - customVerify := customVerifyFactory(auth) - - err = customVerify(rawCerts, nil) - if err != nil { - t.Errorf("customVerify() error = %v, wantErr %v", err, nil) - } - }) - - t.Run("a certChain should not validate if we do not pin the proper ca", func(t *testing.T) { - rawCerts, _, vpnCert, vpnKey, err := makeRawCertsForTesting() - if err != nil { - t.Errorf("error getting raw certs") - return - } - _, badCa, _, _, err := makeRawCertsForTesting() - if err != nil { - t.Errorf("error getting raw certs") - return - } - - auth, err := makeCertAndCAFromMemory(badCa, vpnCert, vpnKey) - customVerify := customVerifyFactory(auth) - - wantErr := ErrCannotVerifyCertChain - err = customVerify(rawCerts, nil) - - if !errors.Is(err, wantErr) { - t.Errorf("customVerify() error = %v, wantErr %v", err, nil) - } - }) - - t.Run("a certChain should not validate if we pass an empty ca", func(t *testing.T) { - rawCerts, _, vpnCert, vpnKey, err := makeRawCertsForTesting() - if err != nil { - t.Errorf("error getting raw certs") - return - } - - emptyCa := &x509.Certificate{} - - auth, err := makeCertAndCAFromMemory(emptyCa, vpnCert, vpnKey) - customVerify := customVerifyFactory(auth) - - wantErr := ErrCannotVerifyCertChain - err = customVerify(rawCerts, nil) - - if !errors.Is(err, wantErr) { - t.Errorf("customVerify() error = %v, wantErr %v", err, nil) - } - }) - - t.Run("a correct certChain fails if DNSName is set in VerifyOptions", func(t *testing.T) { - // this test is really only testing the behavior of golang x509 validation - // in the stdlib, but it gives me more faith in the correctness - // of the custom verify function - rawCerts, ca, vpnCert, vpnKey, err := makeRawCertsForTesting() - if err != nil { - t.Errorf("error getting raw certs") - return - } - - defer func(orig func() x509.VerifyOptions) { - certVerifyOptions = orig - }(certVerifyOptions) - - // the test cert has random.gateway set as the DNSName, so we're just verifying - // that the verification actually fails with options different from the default that we're - // setting in the certVerifyOptions global. - certVerifyOptions = func() x509.VerifyOptions { - return x509.VerifyOptions{DNSName: "other.gateway"} - } - - wantErr := ErrCannotVerifyCertChain - - auth, err := makeCertAndCAFromMemory(ca, vpnCert, vpnKey) - customVerify := customVerifyFactory(auth) - - err = customVerify(rawCerts, nil) - if !errors.Is(err, wantErr) { - t.Errorf("customVerify() error = %v, wantErr %v", err, nil) - } - }) - - t.Run("empty certchain raises error", func(t *testing.T) { - emptyCerts := [][]byte{[]byte{}, []byte{}} - wantErr := ErrCannotVerifyCertChain - - _, ca, vpnCert, vpnKey, err := makeRawCertsForTesting() - if err != nil { - t.Errorf("error getting raw certs") - return - } - - auth, err := makeCertAndCAFromMemory(ca, vpnCert, vpnKey) - customVerify := customVerifyFactory(auth) - - err = customVerify(emptyCerts, nil) - if !errors.Is(err, wantErr) { - t.Errorf("customVerify() error = %v, wantErr %v", err, wantErr) - } - }) - - t.Run("garbage certchain raises error", func(t *testing.T) { - garbageCerts := [][]byte{[]byte{0xde, 0xad}, []byte{0xbe, 0xef}} - wantErr := ErrCannotVerifyCertChain - - _, ca, vpnCert, vpnKey, err := makeRawCertsForTesting() - if err != nil { - t.Errorf("error getting raw certs") - return - } - - auth, err := makeCertAndCAFromMemory(ca, vpnCert, vpnKey) - customVerify := customVerifyFactory(auth) - - err = customVerify(garbageCerts, nil) - if !errors.Is(err, wantErr) { - t.Errorf("customVerify() error = %v, wantErr %v", err, wantErr) - } - }) - - t.Run("attempting to verify one cert with a different ca raises error", func(t *testing.T) { - certChainOne, _, _, _, _ := makeRawCertsForTesting() - certChainTwo, _, _, _, _ := makeRawCertsForTesting() - badChain := [][]byte{certChainOne[0], certChainTwo[1]} - wantErr := ErrCannotVerifyCertChain - - _, ca, vpnCert, vpnKey, err := makeRawCertsForTesting() - if err != nil { - t.Errorf("error getting raw certs") - return - } - - auth, err := makeCertAndCAFromMemory(ca, vpnCert, vpnKey) - customVerify := customVerifyFactory(auth) - - err = customVerify(badChain, nil) - if !errors.Is(err, wantErr) { - t.Errorf("customVerify() error = %v, wantErr %v", err, wantErr) - } - }) -} - -// makeRawCertsForTesting creates a CA, and returns: -// * an array of byte arrays containing a cert signed with that CA and the CA itself (to be used to test the verify routine). -// * the ca used to sign the certs -// * a cert that simulates a vpn certificate signed by the ca (rsa) -// * the private key for the vpn certificate -// * an error if it could not build any of the certs correctly. -func makeRawCertsForTesting() ([][]byte, *x509.Certificate, []byte, []byte, error) { - // set up a CA certificate. this sets up a 2048 cert for the ca, if we ever - // want to shave milliseconds we can roll a ca with a smaller key size. - ca, caPrivKey, err := mitm.NewAuthority("ca", "oonitarians united", 1*time.Hour) - if err != nil { - return nil, nil, nil, nil, err - } - caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caPrivKey.PublicKey, caPrivKey) - if err != nil { - return nil, nil, nil, nil, err - } - - // set up a leaf certificate - this would be the gateway cert - cert := &x509.Certificate{ - SerialNumber: big.NewInt(1984), - Subject: pkix.Name{ - Organization: []string{"Oonitarians united"}, - StreetAddress: []string{"On a pinneaple at the bottom of the sea"}, - CommonName: "random.gateway", - }, - NotBefore: time.Now(), - NotAfter: time.Now().AddDate(10, 0, 0), - DNSNames: []string{"random.gateway", "randomgw"}, - } - - // tiny cert size to make tests go brrr - certPrivKey, err := rsa.GenerateKey(rand.Reader, 512) - if err != nil { - return nil, nil, nil, nil, err - } - - certBytes, err := x509.CreateCertificate(rand.Reader, cert, ca, &certPrivKey.PublicKey, caPrivKey) - if err != nil { - return nil, nil, nil, nil, err - } - - // set up a vpn certificate - this would be the client cert - vpnCert := &x509.Certificate{ - SerialNumber: big.NewInt(1984), - Subject: pkix.Name{ - Organization: []string{"Oonitarians united"}, - StreetAddress: []string{"On a pinneaple at the bottom of the sea"}, - CommonName: "client cert", - }, - NotBefore: time.Now(), - NotAfter: time.Now().AddDate(10, 0, 0), - } - - // tiny cert size to make tests go brrr - vpnCertPrivKey, err := rsa.GenerateKey(rand.Reader, 512) - if err != nil { - return nil, nil, nil, nil, err - } - - vpnCertBytes, err := x509.CreateCertificate(rand.Reader, vpnCert, ca, &vpnCertPrivKey.PublicKey, caPrivKey) - if err != nil { - return nil, nil, nil, nil, err - } - - vpnKeyBytes := x509.MarshalPKCS1PrivateKey(vpnCertPrivKey) - - result := [][]byte{certBytes, caBytes} - return result, ca, vpnCertBytes, vpnKeyBytes, nil -} - -func makeCertAndCAFromMemory(caCert *x509.Certificate, vpnCert []byte, vpnKey []byte) (*certConfig, error) { - ca := x509.NewCertPool() - ca.AddCert(caCert) - cert, _ := tls.X509KeyPair(vpnCert, vpnKey) - auth := &certConfig{ - ca: ca, - cert: cert, - } - return auth, nil -} diff --git a/vpn/transport.go b/vpn/transport.go deleted file mode 100644 index baa3cdc4..00000000 --- a/vpn/transport.go +++ /dev/null @@ -1,310 +0,0 @@ -package vpn - -// -// Transports for OpenVPN over TCP and over UDP. -// This file includes: -// 1. Methods for reading packets from the wire -// 2. A TLS transport that reads and writes TLS records as part of control packets. -// - -import ( - "bytes" - "encoding/binary" - "encoding/hex" - "errors" - "fmt" - "io" - "net" - "time" -) - -var ( - // ErrBadConnNetwork indicates that the conn's network is neither TCP nor UDP. - ErrBadConnNetwork = errors.New("bad conn.Network value") - - // ErrPacketTooShort indicates that a packet is too short. - ErrPacketTooShort = errors.New("packet too short") -) - -// direct reads on the underlying conn - -func readPacket(conn net.Conn) ([]byte, error) { - switch network := conn.LocalAddr().Network(); network { - case "tcp", "tcp4", "tcp6": - return readPacketFromTCP(conn) - case "udp", "udp4", "upd6": - // for UDP we don't need to parse size frames - return readPacketFromUDP(conn) - default: - return nil, fmt.Errorf("%w: %s", ErrBadConnNetwork, network) - } -} - -func readPacketFromUDP(conn net.Conn) ([]byte, error) { - const enough = 1 << 17 - buf := make([]byte, enough) - - count, err := conn.Read(buf) - if err != nil { - return nil, err - } - buf = buf[:count] - return buf, nil -} - -func readPacketFromTCP(conn net.Conn) ([]byte, error) { - lenbuf := make([]byte, 2) - if _, err := io.ReadFull(conn, lenbuf); err != nil { - return nil, err - } - length := binary.BigEndian.Uint16(lenbuf) - buf := make([]byte, length) - if _, err := io.ReadFull(conn, buf); err != nil { - return nil, err - } - return buf, nil -} - -// tlsModeTransporter is a transport for OpenVPN in TLS mode. -// -// See https://openvpn.net/community-resources/openvpn-protocol/ for documentation -// on the protocol used by OpenVPN on the wire. -type tlsModeTransporter interface { - // ReadPacket reads an OpenVPN packet from the wire. - ReadPacket() (p *packet, err error) - - // WritePacket writes an OpenVPN packet to the wire. - WritePacket(opcodeKeyID uint8, data []byte) error - - // SetDeadline sets the underlying conn's deadline. - SetDeadline(deadline time.Time) error - - // SetReadDeadline sets the underlying conn's read deadline. - SetReadDeadline(deadline time.Time) error - - // SetWriteDeadline sets the underlying conn's write deadline. - SetWriteDeadline(deadline time.Time) error - - // Close closes the underlying conn. - Close() error - - // LocalAddr returns the underlying conn's local addr. - LocalAddr() net.Addr - - // RemoteAddr returns the underlying conn's remote addr. - RemoteAddr() net.Addr -} - -// newTLSModeTransport creates a new TLSModeTransporter using the given net.Conn. -func newTLSModeTransport(conn net.Conn, s *session) (tlsModeTransporter, error) { - return &tlsTransport{Conn: conn, session: s}, nil -} - -// tlsTransport implements TLSModeTransporter. -type tlsTransport struct { - net.Conn - session *session -} - -// ReadPacket returns a packet reading from the underlying conn, and an error -// if the read did not succeed. -func (t *tlsTransport) ReadPacket() (*packet, error) { - buf, err := readPacket(t.Conn) - if err != nil { - return nil, err - } - - p, err := parsePacketFromBytes(buf) - if err != nil { - return &packet{}, err - } - if p.isACK() { - logger.Warn("tls: got ACK (ignored)") - return &packet{}, nil - } - return p, nil -} - -// WritePacket writes a packet to the underlying conn. It expect the opcode of the packet and a byte array containing the serialized data. It returns an error if the write did not succeed. -func (t *tlsTransport) WritePacket(opcodeKeyID uint8, data []byte) error { - if t.session == nil { - return fmt.Errorf("%w:%s", errBadInput, "tlsTransport badly initialized") - - } - p := newPacketFromPayload(opcodeKeyID, 0, data) - id, err := t.session.LocalPacketID() - if err != nil { - return err - } - p.id = id - p.localSessionID = t.session.LocalSessionID - payload := p.Bytes() - - out := maybeAddSizeFrame(t.Conn, payload) - - logger.Debug(fmt.Sprintln("tls write:", len(out))) - logger.Debug(fmt.Sprintln(hex.Dump(out))) - - _, err = t.Conn.Write(out) - return err -} - -var _ tlsModeTransporter = &tlsTransport{} // Ensure that we implement TLSModelTransporter - -// controlChannelTLSConn implements net.Conn, and is passed to the tls.Client to perform a -// TLS Handshake over OpenVPN control packets. -type controlChannelTLSConn struct { - conn net.Conn - session *session - transport tlsModeTransporter - // we need to buffer reads because the tls records request less than - // the payload we receive. - bufReader *bytes.Buffer - - doReadFromConnFn func(*controlChannelTLSConn, []byte) (bool, int, error) - doReadFromQueueFn func(*controlChannelTLSConn, []byte) (bool, int, error) -} - -// newControlChannelTLSConn returns a controlChannelTLSConn. It requires the on-the-wire -// net.Conn that will be used underneath, and a configured session. It returns -// also an error if the operation cannot be completed. -func newControlChannelTLSConn(conn net.Conn, s *session) (*controlChannelTLSConn, error) { - transport, err := newTLSModeTransport(conn, s) - if err != nil { - return &controlChannelTLSConn{}, err - } - buf := bytes.NewBuffer(nil) - tlsConn := &controlChannelTLSConn{ - conn: conn, - session: s, - transport: transport, - bufReader: buf, - } - tlsConn.doReadFromConnFn = doReadFromConn - tlsConn.doReadFromQueueFn = doReadFromQueue - return tlsConn, err -} - -// Read over the control channel. This method implements the reliability layer: -// it retries reads until the _next_ packet is received (according to the -// packetID). Returns also an error if the operation cannot be completed. -func (c *controlChannelTLSConn) Read(b []byte) (int, error) { - if c.session == nil || c.session.ackQueue == nil { - return 0, fmt.Errorf("%w: %s", errBadInput, "bad session in TLSConn.Read()") - } - for { - switch len(c.session.ackQueue) { - case 0: - ok, n, err := c.doReadFromConnFn(c, b) - if ok { - return n, err - } - default: - ok, n, err := c.doReadFromQueueFn(c, b) - if ok { - return n, err - } - } - } -} - -func doReadFromConn(c *controlChannelTLSConn, b []byte) (bool, int, error) { - p, err := c.doRead() - - if err != nil { - return true, 0, err - } - switch c.canRead(p) { - case true: - if err := sendACKFn(c.conn, c.session, p.id); err != nil { - return true, 0, err - } - n, err := writeAndReadFromBufferFn(c.bufReader, b, p.payload) - return true, n, err - case false: - if p != nil { - c.session.ackQueue <- p - } - } - - return false, 0, nil -} - -func doReadFromQueue(c *controlChannelTLSConn, b []byte) (bool, int, error) { - for p := range c.session.ackQueue { - if c.canRead(p) { - if err := sendACKFn(c.conn, c.session, p.id); err != nil { - return true, 0, err - } - n, err := writeAndReadFromBufferFn(c.bufReader, b, p.payload) - return true, n, err - } else { - c.session.ackQueue <- p - return doReadFromConn(c, b) - } - } - return false, 0, nil -} - -// doRead() calls ReadPacket() in the underlying transport implementation. It -// returns a packet and an error. -func (c *controlChannelTLSConn) doRead() (*packet, error) { - if c.transport == nil { - return nil, fmt.Errorf("%w:%s", errBadInput, "tlsConn is missing transport") - - } - return c.transport.ReadPacket() -} - -// canRead returns true if the packet is not nil and its packetID is the next -// integer in the expected sequence; returns false otherwise. -func (c *controlChannelTLSConn) canRead(p *packet) bool { - return p != nil && c.session.isNextPacket(p) -} - -// writeAndReadPayloadFromBuffer writes a given payload to a buffered reader, and returns -// a read from that same buffered reader into the passed byte array. it returns both an integer -// denoting the amount of bytes read, and any error during the operation. -func writeAndReadFromBuffer(bb *bytes.Buffer, b []byte, payload []byte) (int, error) { - bb.Write(payload) - return bb.Read(b) -} - -var writeAndReadFromBufferFn = writeAndReadFromBuffer - -// Write writes the given data to the tls connection. -func (c *controlChannelTLSConn) Write(b []byte) (int, error) { - err := c.transport.WritePacket(uint8(pControlV1), b) - if err != nil { - logger.Errorf("tls write: %s", err.Error()) - return 0, err - } - return len(b), err -} - -// Close closes the tls connection. -func (c *controlChannelTLSConn) Close() error { - return c.conn.Close() -} - -func (c *controlChannelTLSConn) LocalAddr() net.Addr { - return c.conn.LocalAddr() -} - -func (c *controlChannelTLSConn) RemoteAddr() net.Addr { - return c.conn.RemoteAddr() -} - -func (c *controlChannelTLSConn) SetDeadline(tt time.Time) error { - return c.conn.SetDeadline(tt) -} - -func (c *controlChannelTLSConn) SetReadDeadline(tt time.Time) error { - return c.conn.SetReadDeadline(tt) -} - -func (c *controlChannelTLSConn) SetWriteDeadline(tt time.Time) error { - return c.conn.SetWriteDeadline(tt) -} - -var _ net.Conn = &controlChannelTLSConn{} // Ensure that we implement net.Conn diff --git a/vpn/transport_test.go b/vpn/transport_test.go deleted file mode 100644 index afd7ab53..00000000 --- a/vpn/transport_test.go +++ /dev/null @@ -1,647 +0,0 @@ -package vpn - -import ( - "bytes" - "errors" - "net" - "reflect" - "testing" - "time" - - "github.com/ooni/minivpn/vpn/mocks" -) - -func Test_readPacketFromUDP(t *testing.T) { - conn := makeTestinConnFromNetwork("udp") - got, err := readPacketFromUDP(conn) - want := []byte("alles ist gut") - if err != nil { - t.Errorf("readPacketFromUDP() error = %v, want %v", err, nil) - } - if !bytes.Equal(got, want) { - t.Errorf("readPacketFromTCP() got = %s, want %s", got, want) - } -} - -func Test_readPacketFromTCP(t *testing.T) { - conn := makeTestinConnFromNetwork("tcp") - got, err := readPacketFromTCP(conn) - want := []byte("alles ist gut") - if err != nil { - t.Errorf("readPacketFromTCP() error = %v, want %v", err, nil) - } - if !bytes.Equal(got, want) { - t.Errorf("readPacketFromTCP() got = %s, want %s", got, want) - } -} - -func Test_readPacket_BadNetwork(t *testing.T) { - conn := makeTestinConnFromNetwork("unix") - _, err := readPacket(conn) - wantErr := ErrBadConnNetwork - if !errors.Is(err, wantErr) { - t.Errorf("readPacket() got = %v, want %v", err, wantErr) - } -} - -type MockTLSTransportConn struct { - *mocks.Conn - written []byte -} - -func makeTestingTLSTransportWithPacket(packetPayload *packet) (*tlsTransport, *MockTLSTransportConn) { - s := makeTestingSession() - a := &mocks.Addr{} - a.MockNetwork = func() string { return "udp" } - c := &MockTLSTransportConn{Conn: &mocks.Conn{}} - c.MockLocalAddr = func() net.Addr { return a } - c.MockRead = func(b []byte) (int, error) { - out := packetPayload.Bytes() - copy(b, out) - return len(out), nil - } - c.MockWrite = func(b []byte) (int, error) { - c.written = b - return 0, nil - } - return &tlsTransport{Conn: c, session: s}, c -} - -func makeTestingTLSTransportWithDefaultPacketPayload() (*tlsTransport, *MockTLSTransportConn) { - readPayload := &packet{opcode: pDataV1, payload: []byte("this is not a payload")} - return makeTestingTLSTransportWithPacket(readPayload) -} - -func Test_tlsTransport_ReadPacket(t *testing.T) { - fakePayload := append( - // fake tag - bytes.Repeat([]byte{0x00}, 13), - []byte("this is not a payload")...) - want := &packet{opcode: pDataV1, payload: fakePayload} - - tt, _ := makeTestingTLSTransportWithDefaultPacketPayload() - got, err := tt.ReadPacket() - - if err != nil { - t.Errorf("ReadPacket() error = %v, wantErr %v", err, nil) - } - if !bytes.Equal(got.payload, want.payload) { - t.Errorf("ReadPacket() got = %v, want = %v", got.payload, want.payload) - } -} - -func Test_tlsTransport_ReadPacket_ACK(t *testing.T) { - ackPacket := &packet{opcode: pACKV1} - tt, _ := makeTestingTLSTransportWithPacket(ackPacket) - got, err := tt.ReadPacket() - if err != nil { - t.Errorf("ReadPacket() error = %v, wantErr %v", err, nil) - } - if !bytes.Equal(got.payload, ackPacket.payload) { - t.Errorf("ReadPacket() got = %v, want = %v", got.payload, ackPacket.payload) - } - -} - -func Test_tlsTransport_WritePacket(t *testing.T) { - payload := []byte("this is not a payload") - fakePacket := append([]byte{0x30, 0x02}, bytes.Repeat([]byte{0x00}, 12)...) - fakePacket = append(fakePacket, payload...) - - tt, conn := makeTestingTLSTransportWithDefaultPacketPayload() - err := tt.WritePacket(pDataV1, payload) - if err != nil { - t.Errorf("ReadPacket() error = %v, want = %v", err, nil) - } - if !bytes.Equal(conn.written, fakePacket) { - t.Errorf("ReadPacket(): got = %v, want = %v", conn.written, fakePacket) - } -} - -func makeTestinConnFromNetwork(network string) net.Conn { - mockAddr := &mocks.Addr{} - mockAddr.MockNetwork = func() string { - return network - } - c := &mocks.Conn{} - c.MockLocalAddr = func() net.Addr { - return mockAddr - } - switch network { - case "udp": - c.MockRead = func(b []byte) (int, error) { - out := []byte("alles ist gut") - copy(b, out) - return len(out), nil - } - case "tcp": - c.MockRead = func(b []byte) (int, error) { - var out []byte - switch c.Count { - case 0: - out = []byte{0x00, 0x0d} - copy(b, out) - c.Count += 1 - case 1: - out = []byte("alles ist gut") - copy(b, out) - } - return len(out), nil - } - default: - c.MockRead = func([]byte) (int, error) { - return 0, nil - } - } - return c -} - -func Test_readPacket(t *testing.T) { - - type args struct { - conn net.Conn - } - tests := []struct { - name string - args args - want []byte - wantErr error - }{ - { - name: "test read from udp conn is ok", - args: args{ - conn: makeTestinConnFromNetwork("udp"), - }, - want: []byte("alles ist gut"), - wantErr: nil, - }, - { - name: "test read from tcp conn is ok", - args: args{ - conn: makeTestinConnFromNetwork("tcp"), - }, - want: []byte("alles ist gut"), - wantErr: nil, - }, - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := readPacket(tt.args.conn) - if !errors.Is(err, tt.wantErr) { - t.Errorf("readPacket() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("readPacket() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_NewTLSConn(t *testing.T) { - conn := makeTestinConnFromNetwork("udp") - s := makeTestingSession() - _, err := newControlChannelTLSConn(conn, s) - if err != nil { - t.Errorf("NewTLSConn() error = %v, want = nil", err) - } -} - -type MockTLSConn struct { - mocks.Conn - closedCalled bool - localAddrCalled bool - remoteAddrCalled bool - setDeadlineCalled bool - setReadDeadlineCalled bool - setWriteDeadlineCalled bool -} - -func makeConnForTransportTest() *MockTLSConn { - localAddr := &mocks.Addr{} - localAddr.MockString = func() string { return "1.1.1.1" } - localAddr.MockNetwork = func() string { return "udp" } - - remoteAddr := &mocks.Addr{} - remoteAddr.MockString = func() string { return "2.2.2.2" } - remoteAddr.MockNetwork = func() string { return "udp" } - - c := &MockTLSConn{} - c.MockClose = func() error { - c.closedCalled = true - return nil - } - c.MockLocalAddr = func() net.Addr { - c.localAddrCalled = true - return localAddr - } - c.MockRemoteAddr = func() net.Addr { - c.remoteAddrCalled = true - return remoteAddr - } - c.MockSetDeadline = func(time.Time) error { - c.setDeadlineCalled = true - return nil - } - c.MockSetReadDeadline = func(time.Time) error { - c.setReadDeadlineCalled = true - return nil - } - c.MockSetWriteDeadline = func(time.Time) error { - c.setWriteDeadlineCalled = true - return nil - } - return c -} - -func makeTestingTLSConn() (*controlChannelTLSConn, *MockTLSConn) { - c := makeConnForTransportTest() - t := &controlChannelTLSConn{} - t.conn = c - return t, c -} - -func TestTLSConn_Read_Fails_With_Bad_Data(t *testing.T) { - tc, _ := makeTestingTLSConn() - b := make([]byte, 16) - _, err := tc.Read(b) - wantErr := errBadInput - if !errors.Is(err, wantErr) { - t.Errorf("TLSConn.Read(): empty session; gotErr = %v, wantErr = %v ", err, wantErr) - } - -} - -func TestTLSConn_Read(t *testing.T) { - // call witnesses - readFromConnCalled := false - readFromQueueCalled := false - - payload := []byte("alles ist gut") - - // setup the fields we need - tc, _ := makeTestingTLSConn() - tc.session = makeTestingSession() - ackQueue := make(chan *packet, 16) - tc.session.ackQueue = ackQueue - - // mock read functions - tc.doReadFromConnFn = func(tcn *controlChannelTLSConn, b []byte) (bool, int, error) { - readFromConnCalled = true - copy(b[:], payload) - return true, len(payload), nil - } - tc.doReadFromQueueFn = func(tcn *controlChannelTLSConn, b []byte) (bool, int, error) { - readFromQueueCalled = true - copy(b[:], payload) - return true, len(payload), nil - } - - // first we read from conn - - b := make([]byte, 255) - n, err := tc.Read(b) - - if err != nil { - t.Errorf("TLSConn.Read(): expected no error, got %v", err) - } - if n != len(payload) { - t.Errorf("TLSConn.Read(): readFromConn returned wrong len %v", n) - } - if !readFromConnCalled { - t.Errorf("TLSConn.Read(): readFromConn not called") - } - if readFromQueueCalled { - t.Errorf("TLSConn.Read(): readFromQueue should have not been called") - } - - // now we read from queue. reset the witnesses: - - readFromConnCalled = false - readFromQueueCalled = false - - // inject one packet in the queue - p := &packet{opcode: pDataV1, payload: []byte("alles ist gut")} - tc.session.ackQueue <- p - - b = make([]byte, 255) - - // and do another call to Read() - n, err = tc.Read(b) - if err != nil { - t.Errorf("TLSConn.Read(): expected no error, got %v", err) - } - if !readFromQueueCalled { - t.Errorf("TLSConn.Read(): readFromQueue not called") - } - if readFromConnCalled { - t.Errorf("TLSConn.Read(): readFromConn should not have been called") - } -} - -func makeTestingTLSTransportFromPayload(payload []byte) (*tlsTransport, *MockTLSTransportConn) { - s := makeTestingSession() - a := &mocks.Addr{} - a.MockNetwork = func() string { return "udp" } - c := &MockTLSTransportConn{Conn: &mocks.Conn{}} - c.MockLocalAddr = func() net.Addr { return a } - c.MockRead = func(b []byte) (int, error) { - out := payload - copy(b, out) - return len(out), nil - } - c.MockWrite = func(b []byte) (int, error) { - c.written = b - return 0, nil - } - return &tlsTransport{Conn: c, session: s}, c -} - -func makePacketForTLSConnTest(id int, s *session) *packet { - p := &packet{ - id: packetID(id), - opcode: pControlV1, - keyID: 0x00, - payload: []byte("aaa"), - localSessionID: s.LocalSessionID, - remoteSessionID: s.RemoteSessionID, - acks: []packetID{}, - } - return p -} - -func makeTestingTLSConnForReadTest(payload []byte) *controlChannelTLSConn { - tc, _ := makeTestingTLSConn() - tt, _ := makeTestingTLSTransportFromPayload(payload) - tc.transport = tt - tc.session = makeTestingSession() - ackQueue := make(chan *packet, 16) - tc.session.ackQueue = ackQueue - return tc -} - -func Test_doReadFromConn(t *testing.T) { - s := makeTestingSession() - p := makePacketForTLSConnTest(1, s) // next packet - payload := p.Bytes() - - tc := makeTestingTLSConnForReadTest(payload) - sendACKFn = func(net.Conn, *session, packetID) error { - return nil - } - writeAndReadFromBufferFn = func(*bytes.Buffer, []byte, []byte) (int, error) { - return 42, nil - } - b := make([]byte, 255) - ok, n, err := doReadFromConn(tc, b) - if err != nil { - t.Errorf("doReadFromBuffer(): wanted error=%v, got=%v", nil, err) - return - } - if !ok { - t.Errorf("doReadFromBuffer(): expected ok=true, got ok=%v", ok) - return - } - if n != 42 { - t.Errorf("doReadFromBuffer(): expected %v, got %v", 42, n) - } - if len(tc.session.ackQueue) != 0 { - t.Errorf("doReadFromBuffer(): ackQueue should be 0") - } -} - -func Test_doReadFromConn_Out_Of_Order_Packet(t *testing.T) { - s := makeTestingSession() - p := makePacketForTLSConnTest(2, s) // not next packet - payload := p.Bytes() - - tc := makeTestingTLSConnForReadTest(payload) - - sendACKFn = func(net.Conn, *session, packetID) error { - return nil - } - writeAndReadFromBufferFn = func(*bytes.Buffer, []byte, []byte) (int, error) { - return 42, nil - } - b := make([]byte, 255) - ok, n, err := doReadFromConn(tc, b) - if err != nil { - t.Errorf("doReadFromBuffer(): wanted error=%v, got=%v", nil, err) - return - } - if ok { - t.Errorf("doReadFromBuffer(): expected ok=false, got ok=%v", ok) - return - } - if n != 0 { - t.Errorf("doReadFromBuffer(): expected %v, got %v", 0, n) - } - if len(tc.session.ackQueue) != 1 { - t.Errorf("doReadFromBuffer(): ackQueue should be 1") - } -} - -func Test_doReadFromConn_Bubble_Up_Errors(t *testing.T) { - s := makeTestingSession() - p := makePacketForTLSConnTest(1, s) // next packet - payload := p.Bytes() - - tc := makeTestingTLSConnForReadTest(payload) - - makeUpError := errors.New("silly error") - - sendACKFn = func(net.Conn, *session, packetID) error { - return makeUpError - } - writeAndReadFromBufferFn = func(*bytes.Buffer, []byte, []byte) (int, error) { - return 42, nil - } - b := make([]byte, 255) - _, _, err := doReadFromConn(tc, b) - if !errors.Is(err, makeUpError) { - t.Errorf("doReadFromBuffer(): wanted error=%v, got=%v", makeUpError, err) - return - } -} - -func Test_doReadFromQueue(t *testing.T) { - s := makeTestingSession() - p := makePacketForTLSConnTest(2, s) // not next packet - tc := makeTestingTLSConnForReadTest(p.Bytes()) // dont care, not going to use it - tc.session.ackQueue <- p - - // mock ack and writes - sendACKFn = func(net.Conn, *session, packetID) error { - return nil - } - writeAndReadFromBufferFn = func(*bytes.Buffer, []byte, []byte) (int, error) { - return 42, nil - } - b := make([]byte, 255) - _, _, err := doReadFromQueue(tc, b) - if err != nil { - t.Errorf("doReadFromQueue(): wanted error=%v, got=%v", nil, err) - } - -} - -func TestTLSConn_doRead(t *testing.T) { - tt, _ := makeTestingTLSTransportWithDefaultPacketPayload() - tc := &controlChannelTLSConn{transport: tt} - _, err := tc.doRead() - if err != nil { - t.Errorf("TLSConn.doRead(): expected nil error") - return - } - - tc = &controlChannelTLSConn{} - _, err = tc.doRead() - if !errors.Is(err, errBadInput) { - t.Errorf("TLSConn.doRead(): should fail with nil transport. got: %v, wanted: %v", err, errBadInput) - return - } - -} - -func TestTLSConn_canRead(t *testing.T) { - tc := &controlChannelTLSConn{ - session: makeTestingSession(), - } - canRead := tc.canRead(nil) - if canRead { - t.Errorf("TLSConn.canRead() should return false with nil packet") - } - - pNext := &packet{id: 1} - canRead = tc.canRead(pNext) - if !canRead { - t.Errorf("TLSConn.canRead() should be able to read pID = 1") - } - - pEq := &packet{id: 0} - canRead = tc.canRead(pEq) - if canRead { - t.Errorf("TLSConn.canRead() should not able to read pID = 0") - } - - tc.session.localPacketID = packetID(42) - pMore := &packet{id: 44} - canRead = tc.canRead(pMore) - if canRead { - t.Errorf("TLSConn.canRead() should not able to read pID = 44") - } - - pLess := &packet{id: 41} - canRead = tc.canRead(pLess) - if canRead { - t.Errorf("TLSConn.canRead() should not able to read pID = 41") - } -} - -func Test_writeAndReadFromBuffer(t *testing.T) { - bb := &bytes.Buffer{} - b := make([]byte, 255) - payload := []byte("this test is green") - n, err := writeAndReadFromBuffer(bb, b, payload) - if err != nil { - t.Error("writeAndReadFromBuffer(): expected no error") - } - if n != len(payload) { - t.Errorf("writeAndReadFromBuffer(): got len = %v, wanted = %v", n, len(payload)) - } -} - -func TestTLSConn_Close(t *testing.T) { - tc, conn := makeTestingTLSConn() - err := tc.Close() - if err != nil { - t.Errorf("TLSConn.Close() error = %v, want = nil", err) - } - if !conn.closedCalled { - t.Error("TLSConn.Close(): conn.Close() not called") - } -} - -func TestTLSConn_LocalAddr(t *testing.T) { - tc, conn := makeTestingTLSConn() - want := "1.1.1.1" - if addr := tc.LocalAddr(); addr.String() != want { - t.Errorf("TLSConn.LocalAddr() got = %s, want = %s", addr, want) - } - if !conn.localAddrCalled { - t.Error("TLSConn.LocalAddr(): conn.LocalAddr() not called") - } -} - -func TestTLSConn_RemoteAddr(t *testing.T) { - tc, conn := makeTestingTLSConn() - want := "2.2.2.2" - if addr := tc.RemoteAddr(); addr.String() != want { - t.Errorf("TLSConn.RemoteAddr() got = %s, want = %s", addr, want) - } - if !conn.remoteAddrCalled { - t.Error("TLSConn.RemoteAddr(): conn.RemoteAddr() not called") - } -} - -func TestTLSConn_SetDeadline(t *testing.T) { - tc, conn := makeTestingTLSConn() - err := tc.SetDeadline(time.Now().Add(time.Second)) - if err != nil { - t.Errorf("TLSConn.SetDeadline() error = %v, want = nil", err) - } - if !conn.setDeadlineCalled { - t.Error("TLSConn.SetDeadline(): conn.SetDeadline() not called") - } -} - -func TestTLSConn_SetReadDeadline(t *testing.T) { - tc, conn := makeTestingTLSConn() - err := tc.SetReadDeadline(time.Now().Add(time.Second)) - if err != nil { - t.Errorf("TLSConn.SetReadDeadline() error = %v, want = nil", err) - } - if !conn.setReadDeadlineCalled { - t.Error("TLSConn.SetReadDeadline(): conn.SetReadDeadline() not called") - } -} - -func TestTLSConn_SetWriteDeadline(t *testing.T) { - tc, conn := makeTestingTLSConn() - err := tc.SetWriteDeadline(time.Now().Add(time.Second)) - if err != nil { - t.Errorf("TLSConn.SetWriteDeadline() error = %v, want = nil", err) - } - if !conn.setWriteDeadlineCalled { - t.Error("TLSConn.SetWriteDeadline(): conn.SetWriteDeadline() not called") - } -} - -func TestTLSConn_Write(t *testing.T) { - a := &mocks.Addr{} - a.MockNetwork = func() string { return "udp" } - conn := &mocks.Conn{} - conn.MockLocalAddr = func() net.Addr { return a } - c := &MockTLSTransportConn{Conn: conn} - c.MockWrite = func(b []byte) (int, error) { - c.written = b - return len(b), nil - } - s := makeTestingSession() - tlsTr := &tlsTransport{Conn: c, session: s} - tc := &controlChannelTLSConn{transport: tlsTr, session: s} - - payload := []byte("this is fine") - want := append( - []byte{0x20, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, - payload...) - _, err := tc.Write(payload) - if err != nil { - t.Errorf("TLSConn.Write(): expected err = nil, got = %v", err) - } - if !bytes.Equal(c.written, want) { - t.Errorf("TLSConn.Write(): written = %v, want = %v", c.written, want) - } -} diff --git a/vpn/utils.go b/vpn/utils.go deleted file mode 100644 index fd234a6c..00000000 --- a/vpn/utils.go +++ /dev/null @@ -1,12 +0,0 @@ -package vpn - -// -// Utility functions -// - -// panicIfFalse calls panic with the given message if the given statement is false. -func panicIfFalse(stmt bool, message interface{}) { - if !stmt { - panic(message) - } -} diff --git a/vpn/utils_test.go b/vpn/utils_test.go deleted file mode 100644 index 51f4eec1..00000000 --- a/vpn/utils_test.go +++ /dev/null @@ -1,22 +0,0 @@ -package vpn - -import "testing" - -func Test_panicIfFalse(t *testing.T) { - t.Run("panics when false", func(t *testing.T) { - var happened bool - func() { - defer func() { - happened = recover() != nil - }() - panicIfFalse(false, "should happen") - }() - if !happened { - t.Fatal("did not panic") - } - }) - - t.Run("does nothing when true", func(t *testing.T) { - panicIfFalse(true, "should not happen") - }) -}