diff --git a/.github/workflows/build-refactor.yml b/.github/workflows/build-refactor.yml
deleted file mode 100644
index 53b284e6..00000000
--- a/.github/workflows/build-refactor.yml
+++ /dev/null
@@ -1,66 +0,0 @@
-name: build-refactor
-# this action is covering internal/ tree with go1.21
-
-on:
- push:
- branches:
- - main
- pull_request:
- branches:
- - main
-
-jobs:
- short-tests:
- runs-on: ubuntu-latest
- steps:
- - uses: actions/checkout@v4
- - name: setup go
- uses: actions/setup-go@v5
- with:
- go-version: '1.21'
- - name: Run short tests
- run: go test --short -cover ./internal/...
-
- lint:
- runs-on: ubuntu-latest
- steps:
- - uses: actions/checkout@v4
- - name: Lint with revive action, from pre-built image
- uses: docker://morphy/revive-action:v2
- with:
- path: "internal/..."
-
- gosec:
- runs-on: ubuntu-latest
- env:
- GO111MODULE: on
- steps:
- - name: Checkout Source
- uses: actions/checkout@v4
- - name: Run Gosec security scanner
- uses: securego/gosec@master
- with:
- args: '-no-fail ./...'
-
- coverage-threshold:
- runs-on: ubuntu-latest
- steps:
- - uses: actions/checkout@v4
- - name: setup go
- uses: actions/setup-go@v5
- with:
- go-version: '1.21'
- - name: Ensure coverage threshold
- run: make test-coverage-threshold-refactor
-
- integration:
- runs-on: ubuntu-latest
- steps:
- - uses: actions/checkout@v4
- - name: setup go
- uses: actions/setup-go@v5
- with:
- go-version: '1.21'
- - name: run integration tests
- run: go run ./tests/integration
-
diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 4995f6a8..140d09ed 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -1,24 +1,34 @@
name: build
+# this action is covering internal/ tree with go1.20
on:
push:
branches:
- - main
+ - 'main'
pull_request:
branches:
- - main
+ - 'main'
jobs:
short-tests:
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v4
- name: setup go
- uses: actions/setup-go@v2
+ uses: actions/setup-go@v5
with:
go-version: '1.20'
- name: Run short tests
- run: go test --short -cover ./vpn
+ run: go test --short -cover ./internal/...
+
+ lint:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v4
+ - name: Lint with revive action, from pre-built image
+ uses: docker://morphy/revive-action:v2
+ with:
+ path: "internal/..."
gosec:
runs-on: ubuntu-latest
@@ -26,7 +36,7 @@ jobs:
GO111MODULE: on
steps:
- name: Checkout Source
- uses: actions/checkout@v2
+ uses: actions/checkout@v4
- name: Run Gosec security scanner
uses: securego/gosec@master
with:
@@ -35,10 +45,22 @@ jobs:
coverage-threshold:
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v4
- name: setup go
- uses: actions/setup-go@v2
+ uses: actions/setup-go@v5
with:
go-version: '1.20'
- name: Ensure coverage threshold
run: make test-coverage-threshold
+
+ integration:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v4
+ - name: setup go
+ uses: actions/setup-go@v5
+ with:
+ go-version: '1.20'
+ - name: run integration tests
+ run: go run ./tests/integration
+
diff --git a/Makefile b/Makefile
index df14cc49..8db170cc 100644
--- a/Makefile
+++ b/Makefile
@@ -16,13 +16,7 @@ build-rel:
@GOOS=windows go build ${FLAGS} -o minivpn.exe
build-race:
- @go build -race
-
-build-ping:
- @go build -v ./cmd/vpnping
-
-build-ndt7:
- @go build -o ndt7 ./cmd/ndt7
+ @go build -race ./cmd/minivpn
bootstrap:
@./scripts/bootstrap-provider ${PROVIDER}
@@ -45,10 +39,6 @@ test-combined-coverage:
scripts/go-coverage-check.sh ./coverage/profile.out ${COVERAGE_THRESHOLD}
test-coverage-threshold:
- go test --short -coverprofile=cov-threshold.out ./vpn
- ./scripts/go-coverage-check.sh cov-threshold.out ${COVERAGE_THRESHOLD}
-
-test-coverage-threshold-refactor:
go test --short -coverprofile=cov-threshold-refactor.out ./internal/...
./scripts/go-coverage-check.sh cov-threshold-refactor.out ${COVERAGE_THRESHOLD}
@@ -56,7 +46,7 @@ test-short:
go test -race -short -v ./...
test-ping:
- ./minivpn -c data/${PROVIDER}/config -t ${TARGET} -n ${COUNT} ping
+ ./minivpn -c data/${PROVIDER}/config -ping
integration-server:
# this needs the container from https://github.com/ainghazal/docker-openvpn
@@ -66,12 +56,6 @@ test-fetch-config:
rm -rf data/tests
mkdir -p data/tests && curl 172.17.0.2:8080/ > data/tests/config
-test-ping-local:
- # run the integration-server first
- ./minivpn -c data/tests/config -t 172.17.0.1 -n ${COUNT} ping
-
-test-local: test-fetch-config test-ping-local
-
qa:
@# all the steps at once
cd tests/integration && ./run-server.sh &
@@ -79,7 +63,7 @@ qa:
@rm -rf data/tests
@mkdir -p data/tests && curl 172.17.0.2:8080/ > data/tests/config
@sleep 1
- ./minivpn -c data/tests/config -t 172.17.0.1 -n ${COUNT} ping
+ ./minivpn -c data/tests/config -ping
@docker stop ovpn1
integration:
@@ -88,17 +72,6 @@ integration:
filternet-qa:
cd tests/qa && ./run-filternet.sh remote-block-all
-coverage:
- go test -coverprofile=coverage.out ./vpn
- go tool cover -html=coverage.out
-
-coverage-ping:
- go test -coverprofile=coverage-ping.out ./extras/ping
- go tool cover -html=coverage-ping.out
-
-proxy:
- ./minivpn -c data/${PROVIDER}/config proxy
-
backup-data:
@tar cvzf ../data-vpn-`date +'%F'`.tar.gz
@@ -122,5 +95,9 @@ go-sec:
go-revive:
revive internal/...
+install-linters:
+ go install github.com/mgechev/revive@latest
+ go install github.com/securego/gosec/v2/cmd/gosec@latest
+
clean:
@rm -f coverage.out
diff --git a/extras/example_pinger_test.go b/extras/example_pinger_test.go
deleted file mode 100644
index b197510c..00000000
--- a/extras/example_pinger_test.go
+++ /dev/null
@@ -1,29 +0,0 @@
-package extras
-
-import (
- "context"
- "os"
- "time"
-
- "github.com/ooni/minivpn/extras/ping"
- "github.com/ooni/minivpn/vpn"
-)
-
-var (
- cfg = "data/riseup/config"
- target = "8.8.8.8"
- count = 3
-)
-
-func ExamplePinger() {
- opts, err := vpn.NewOptionsFromFilePath(cfg)
- if err != nil {
- os.Exit(1)
- }
- tunnel := vpn.NewClientFromOptions(opts)
- tunnel.Start(context.Background())
- pinger := ping.New(target, tunnel)
- pinger.Count = 3
- pinger.Timeout = 5 * time.Second
- pinger.Run(context.Background())
-}
diff --git a/extras/ndt7.go b/extras/ndt7.go
deleted file mode 100644
index 80d2e5b5..00000000
--- a/extras/ndt7.go
+++ /dev/null
@@ -1,134 +0,0 @@
-package extras
-
-/*
- Vendoring of m-lab's ndt7-client reference code, to be able to experiment
- and manipulate parts that are exposed as internal in the library.
- Upstream: https://github.com/m-lab/ndt7-client-go/
-
- SPDX-License-Identifier: Apache-2.0
-
- (c) Stephen Soltesz
- (c) Peter Boothe
- (c) Simone Basso
- (c) Ain Ghazal
-*/
-
-import (
- "context"
- "crypto/tls"
- "log"
- "os"
- "time"
-
- "github.com/gorilla/websocket"
- "github.com/m-lab/ndt7-client-go"
- "github.com/m-lab/ndt7-client-go/spec"
- "github.com/ooni/minivpn/extras/ndt7/emitter"
- "github.com/ooni/minivpn/vpn"
-)
-
-const (
- clientName = "minivpn-ndt7-client"
- clientVersion = "0.6.1"
- defaultTimeout = 10 * time.Second
-)
-
-type runner struct {
- client *ndt7.Client
- emitter emitter.Emitter
-}
-
-func (r runner) runDownload(ctx context.Context) int {
- return r.runTest(ctx, spec.TestDownload, r.client.StartDownload,
- r.emitter.OnDownloadEvent)
-}
-
-func (r runner) runUpload(ctx context.Context) int {
- return r.runTest(ctx, spec.TestUpload, r.client.StartUpload,
- r.emitter.OnUploadEvent)
-}
-
-func (r runner) runTest(
- ctx context.Context, test spec.TestKind,
- start func(context.Context) (<-chan spec.Measurement, error),
- emitEvent func(m *spec.Measurement) error,
-) int {
- // Implementation note: we want to always emit the initial and the
- // final events regardless of how the actual test goes. What's more,
- // we want the exit code to be nonzero in case of any error.
- err := r.emitter.OnStarting(test)
- if err != nil {
- return 1
- }
- code := r.doRunTest(ctx, test, start, emitEvent)
- err = r.emitter.OnComplete(test)
- if err != nil {
- return 1
- }
- return code
-}
-
-func (r runner) doRunTest(
- ctx context.Context, test spec.TestKind,
- start func(context.Context) (<-chan spec.Measurement, error),
- emitEvent func(m *spec.Measurement) error,
-) int {
- ch, err := start(ctx)
- if err != nil {
- _ = r.emitter.OnError(test, err)
- return 1
- }
- err = r.emitter.OnConnected(test, r.client.FQDN)
- if err != nil {
- return 1
- }
- for ev := range ch {
- err = emitEvent(&ev)
- if err != nil {
- return 1
- }
- }
- return 0
-}
-
-// TODO use memoryless to repeat measurements, gather the json outputs and
-// return a measurement batch.
-
-// RunMeasurement performs a download & upload measurement against a given ndt7 server.
-// It expects a vpn Dialer and a server string (ip:port).
-// If the direct parameter is set to true, the vpn Dialer will not be used and
-// a direct connection will be used instead.
-func RunMeasurement(d vpn.TunDialer, ndt7Server string, mode string, direct bool) {
- ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
- defer cancel()
- var r runner
-
- insecureTLS := false
- if os.Getenv("TLS_NOVERIFY") == "1" {
- insecureTLS = true
- }
-
- vpnDialer := websocket.Dialer{
- // TODO(ainghazal): pass a config flag to force the InsecureSkipVerify config,
- // this should not be used in production.
- TLSClientConfig: &tls.Config{
- InsecureSkipVerify: insecureTLS,
- },
- } //#nosec G402
- if direct == false {
- vpnDialer.NetDialContext = d.DialContext
- } else {
- log.Println("using a direct connection to ndt7 server")
- }
-
- r.client = ndt7.NewClient(clientName, clientVersion)
- r.client.Server = ndt7Server
- r.client.Dialer = vpnDialer
- r.emitter = emitter.NewJSON(os.Stdout)
- switch mode {
- case "download":
- r.runDownload(ctx)
- case "upload":
- r.runUpload(ctx)
- }
-}
diff --git a/extras/ndt7/emitter/emitter.go b/extras/ndt7/emitter/emitter.go
deleted file mode 100644
index a96da615..00000000
--- a/extras/ndt7/emitter/emitter.go
+++ /dev/null
@@ -1,50 +0,0 @@
-// Package emitter contains the ndt7-client emitter.
-package emitter
-
-/*
- Vendoring of m-lab's ndt7-client reference code, to be able to experiment
- and manipulate parts that are exposed as internal in the library.
- Upstream: https://github.com/m-lab/ndt7-client-go/
-
- SPDX-License-Identifier: Apache-2.0
-
- (c) Stephen Soltesz
- (c) Peter Boothe
- (c) Simone Basso
- (c) Ain Ghazal
-*/
-
-import (
- "github.com/m-lab/ndt7-client-go/spec"
-)
-
-// Emitter is a generic emitter. When an event occurs, the
-// corresponding method will be called. An error will generally
-// mean that it's not possible to write the output. A common
-// case where this happen is where the output is redirected to
-// a file on a full hard disk.
-//
-// See the documentation of the main package for more details
-// on the sequence in which events may occur.
-type Emitter interface {
- // OnStarting is emitted before attempting to start a test.
- OnStarting(test spec.TestKind) error
-
- // OnError is emitted if a test cannot start.
- OnError(test spec.TestKind, err error) error
-
- // OnConnected is emitted when we connected to the ndt7 server.
- OnConnected(test spec.TestKind, fqdn string) error
-
- // OnDownloadEvent is emitted during the download.
- OnDownloadEvent(m *spec.Measurement) error
-
- // OnUploadEvent is emitted during the upload.
- OnUploadEvent(m *spec.Measurement) error
-
- // OnComplete is always emitted when the test is over.
- OnComplete(test spec.TestKind) error
-
- // OnSummary is emitted after the test is over.
- OnSummary(s *Summary) error
-}
diff --git a/extras/ndt7/emitter/json.go b/extras/ndt7/emitter/json.go
deleted file mode 100644
index f33c40f4..00000000
--- a/extras/ndt7/emitter/json.go
+++ /dev/null
@@ -1,129 +0,0 @@
-package emitter
-
-/*
- Vendoring of m-lab's ndt7-client reference code, to be able to experiment
- and manipulate parts that are exposed as internal in the library.
- Upstream: https://github.com/m-lab/ndt7-client-go/
-
- SPDX-License-Identifier: Apache-2.0
-
- (c) Stephen Soltesz
- (c) Peter Boothe
- (c) Simone Basso
- (c) Ain Ghazal
-*/
-
-import (
- "encoding/json"
- "io"
-
- "github.com/m-lab/ndt7-client-go/spec"
-)
-
-// jsonEmitter is a jsonEmitter emitter. It emits messages consistent with
-// the cmd/ndt7-client/main.go documentation for `-format=json`.
-type jsonEmitter struct {
- io.Writer
-}
-
-// NewJSON creates a new JSON emitter
-func NewJSON(w io.Writer) Emitter {
- return jsonEmitter{
- Writer: w,
- }
-}
-
-func (j jsonEmitter) emitData(data []byte) error {
- _, err := j.Write(append(data, byte('\n')))
- return err
-}
-
-func (j jsonEmitter) emitInterface(any interface{}) error {
- data, err := json.Marshal(any)
- if err != nil {
- return err
- }
- return j.emitData(data)
-}
-
-type batchEvent struct {
- Key string
- Value interface{}
-}
-
-type batchValue struct {
- spec.Measurement
- Failure string `json:",omitempty"`
- Server string `json:",omitempty"`
-}
-
-// OnStarting emits the starting event
-func (j jsonEmitter) OnStarting(test spec.TestKind) error {
- return j.emitInterface(batchEvent{
- Key: "starting",
- Value: batchValue{
- Measurement: spec.Measurement{
- Test: test,
- },
- },
- })
-}
-
-// OnError emits the error event
-func (j jsonEmitter) OnError(test spec.TestKind, err error) error {
- return j.emitInterface(batchEvent{
- Key: "error",
- Value: batchValue{
- Measurement: spec.Measurement{
- Test: test,
- },
- Failure: err.Error(),
- },
- })
-}
-
-// OnConnected emits the connected event
-func (j jsonEmitter) OnConnected(test spec.TestKind, fqdn string) error {
- return j.emitInterface(batchEvent{
- Key: "connected",
- Value: batchValue{
- Measurement: spec.Measurement{
- Test: test,
- },
- Server: fqdn,
- },
- })
-}
-
-// OnDownloadEvent handles an event emitted during the download
-func (j jsonEmitter) OnDownloadEvent(m *spec.Measurement) error {
- return j.emitInterface(batchEvent{
- Key: "measurement",
- Value: m,
- })
-}
-
-// OnUploadEvent handles an event emitted during the upload
-func (j jsonEmitter) OnUploadEvent(m *spec.Measurement) error {
- return j.emitInterface(batchEvent{
- Key: "measurement",
- Value: m,
- })
-}
-
-// OnComplete is the event signalling the end of the test
-func (j jsonEmitter) OnComplete(test spec.TestKind) error {
- return j.emitInterface(batchEvent{
- Key: "complete",
- Value: batchValue{
- Measurement: spec.Measurement{
- Test: test,
- },
- },
- })
-}
-
-// OnSummary handles the summary event, emitted after the test is over.
-func (j jsonEmitter) OnSummary(s *Summary) error {
- return j.emitInterface(s)
-}
diff --git a/extras/ndt7/emitter/summary.go b/extras/ndt7/emitter/summary.go
deleted file mode 100644
index 7282d1a1..00000000
--- a/extras/ndt7/emitter/summary.go
+++ /dev/null
@@ -1,58 +0,0 @@
-package emitter
-
-/*
- Vendoring of m-lab's ndt7-client reference code, to be able to experiment
- and manipulate parts that are exposed as internal in the library.
- Upstream: https://github.com/m-lab/ndt7-client-go/
-
- SPDX-License-Identifier: Apache-2.0
-
- (c) Stephen Soltesz
- (c) Peter Boothe
- (c) Simone Basso
-*/
-
-// ValueUnitPair represents a {"Value": ..., "Unit": ...} pair.
-type ValueUnitPair struct {
- Value float64
- Unit string
-}
-
-// Summary is a struct containing the values displayed to the user at
-// the end of an ndt7 test.
-type Summary struct {
- // ServerFQDN is the FQDN of the server used for this test.
- ServerFQDN string
-
- // ServerIP is the (v4 or v6) IP address of the server.
- ServerIP string
-
- // ClientIP is the (v4 or v6) IP address of the client.
- ClientIP string
-
- // DownloadUUID is the UUID of the download test.
- // TODO: add UploadUUID after we start processing counterflow messages.
- DownloadUUID string
-
- // Download is the download speed, in Mbit/s. This is measured at the
- // receiver.
- Download ValueUnitPair
-
- // Upload is the upload speed, in Mbit/s. This is measured at the sender.
- Upload ValueUnitPair
-
- // DownloadRetrans is the retransmission rate. This is based on the TCPInfo
- // values provided by the server during a download test.
- DownloadRetrans ValueUnitPair
-
- // RTT is the round-trip time of the latest measurement, in milliseconds.
- // This is provided by the server during a download test.
- MinRTT ValueUnitPair
-}
-
-// NewSummary returns a new Summary struct for a given FQDN.
-func NewSummary(FQDN string) *Summary {
- return &Summary{
- ServerFQDN: FQDN,
- }
-}
diff --git a/extras/ping/ping_test.go b/extras/ping/ping_test.go
index e2d6c3b8..02d1ecbb 100644
--- a/extras/ping/ping_test.go
+++ b/extras/ping/ping_test.go
@@ -10,8 +10,9 @@ import (
"testing"
"time"
+ "github.com/ooni/minivpn/internal/mocks"
+
"github.com/google/uuid"
- "github.com/ooni/minivpn/vpn/mocks"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
)
diff --git a/go.mod b/go.mod
index 3b12fd46..bf5a7ed7 100644
--- a/go.mod
+++ b/go.mod
@@ -7,32 +7,34 @@ go 1.20
require (
git.torproject.org/pluggable-transports/goptlib.git v1.3.0
+ github.com/Doridian/water v1.6.1
github.com/apex/log v1.9.0
- github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5
github.com/google/go-cmp v0.5.9
github.com/google/gopacket v1.1.19
github.com/google/martian v2.1.0+incompatible
github.com/google/uuid v1.3.0
github.com/gorilla/websocket v1.5.0
+ github.com/jackpal/gateway v1.0.11 // pinned to a previous version until we can use go1.21
github.com/m-lab/ndt7-client-go v0.7.0
github.com/ory/dockertest/v3 v3.9.1
- github.com/pborman/getopt/v2 v2.1.0
github.com/refraction-networking/utls v1.3.1
gitlab.com/yawning/obfs4.git v0.0.0-20220904064028-336a71d6e4cf
- golang.org/x/net v0.17.0
- golang.org/x/sync v0.4.0
+ golang.org/x/net v0.22.0
+ golang.org/x/sync v0.6.0
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
)
require (
filippo.io/edwards25519 v1.0.0-rc.1.0.20210721174708-390f27c3be20 // indirect
github.com/Azure/go-ansiterm v0.0.0-20170929234023-d6e3b3328b78 // indirect
+ github.com/Doridian/gopacket v1.2.1 // indirect
github.com/Microsoft/go-winio v0.6.0 // indirect
github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 // indirect
github.com/andybalholm/brotli v1.0.4 // indirect
github.com/araddon/dateparse v0.0.0-20200409225146-d820a6159ab1 // indirect
github.com/cenkalti/backoff/v4 v4.1.3 // indirect
github.com/containerd/continuity v0.3.0 // indirect
+ github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dchest/siphash v1.2.1 // indirect
github.com/docker/cli v20.10.14+incompatible // indirect
github.com/docker/docker v20.10.7+incompatible // indirect
@@ -53,18 +55,22 @@ require (
github.com/opencontainers/image-spec v1.0.2 // indirect
github.com/opencontainers/runc v1.1.2 // indirect
github.com/pkg/errors v0.9.1 // indirect
+ github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
+ github.com/stretchr/objx v0.5.0 // indirect
github.com/stretchr/testify v1.8.4 // indirect
github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f // indirect
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect
github.com/xeipuuv/gojsonschema v1.2.0 // indirect
gitlab.com/yawning/edwards25519-extra.git v0.0.0-20211229043746-2f91fcc9fbdb // indirect
- golang.org/x/crypto v0.14.0 // indirect
- golang.org/x/mod v0.13.0 // indirect
- golang.org/x/sys v0.13.0 // indirect
+ golang.org/x/crypto v0.21.0 // indirect
+ golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 // indirect
+ golang.org/x/mod v0.16.0 // indirect
+ golang.org/x/sys v0.18.0 // indirect
golang.org/x/time v0.3.0 // indirect
- golang.org/x/tools v0.14.0 // indirect
+ golang.org/x/tools v0.19.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
+ gopkg.in/yaml.v3 v3.0.1 // indirect
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 // indirect
)
diff --git a/go.sum b/go.sum
index 34049e97..ed35b385 100644
--- a/go.sum
+++ b/go.sum
@@ -35,6 +35,10 @@ github.com/Azure/go-ansiterm v0.0.0-20170929234023-d6e3b3328b78 h1:w+iIsaOQNcT7O
github.com/Azure/go-ansiterm v0.0.0-20170929234023-d6e3b3328b78/go.mod h1:LmzpDX56iTiv29bbRTIsUNlaFfuhWRQBWjQdVyAevI8=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
+github.com/Doridian/gopacket v1.2.1 h1:z0Iu5zplIq01nGNwKoreAhc/RMIUqu6vZLxLsHjpO48=
+github.com/Doridian/gopacket v1.2.1/go.mod h1:16EwY3JsEHp3TFeSRcmSC9yOdG8GkFAWImZaL13kOGc=
+github.com/Doridian/water v1.6.1 h1:cszUUfRlk9duYwv1bE5mFluhaHJK0TAIdPLJk2DV0cI=
+github.com/Doridian/water v1.6.1/go.mod h1:bjQW67+p0YKqdrjoDWpJ+bcRvrrHphxwZOC6OTfBuf8=
github.com/Microsoft/go-winio v0.6.0 h1:slsWYD/zyx7lCXoZVlvQrj0hPTM1HI4+v1sIda2yDvg=
github.com/Microsoft/go-winio v0.6.0/go.mod h1:cTAf44im0RAYeL23bpB+fzCyDH2MJiz2BO69KH/soAE=
github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 h1:TngWCqHvy9oXAN6lEVMRuU21PR1EtLVZJmdB18Gu3Rw=
@@ -52,8 +56,6 @@ github.com/aphistic/golf v0.0.0-20180712155816-02c07f170c5a/go.mod h1:3NqKYiepwy
github.com/aphistic/sweet v0.2.0/go.mod h1:fWDlIh/isSE9n6EPsRmC0det+whmX6dJid3stzu0Xys=
github.com/araddon/dateparse v0.0.0-20200409225146-d820a6159ab1 h1:TEBmxO80TM04L8IuMWk77SGL1HomBmKTdzdJLLWznxI=
github.com/araddon/dateparse v0.0.0-20200409225146-d820a6159ab1/go.mod h1:SLqhdZcd+dF3TEVL2RMoob5bBP5R1P1qkox+HtCBgGI=
-github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio=
-github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
github.com/aws/aws-sdk-go v1.20.6/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo=
github.com/aybabtme/rgbterm v0.0.0-20170906152045-cc83f3b3ce59/go.mod h1:q/89r3U2H7sSsE2t6Kca0lfwTK8JdoNGS/yzM/4iH5I=
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
@@ -111,6 +113,7 @@ github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2
github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE=
github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk=
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
+github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
github.com/go-test/deep v1.0.6 h1:UHSEyLZUwX9Qoi99vVwvewiMC8mM2bf7XEM2nqvzEn8=
github.com/go-test/deep v1.0.6/go.mod h1:QV8Hv/iy04NyLBxAdO9njL0iVPN1S4d/A3NVv1V36o8=
@@ -146,6 +149,7 @@ github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QD
github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw=
+github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU=
@@ -191,6 +195,10 @@ github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpO
github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
github.com/imdario/mergo v0.3.12 h1:b6R2BslTbIEToALKP7LxUvijTsNI9TAe80pLWN2g/HU=
github.com/imdario/mergo v0.3.12/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA=
+github.com/jackpal/gateway v1.0.11 h1:XqCVFIyo2LtQYXjz9nis1WMTvAadJiFP/Zc04xmdEYE=
+github.com/jackpal/gateway v1.0.11/go.mod h1:NqRwEsSP/DD8d4YXIsHEMNUSYetesFXjmL6QZFrul+M=
+github.com/jackpal/gateway v1.0.13 h1:fJccMvawxx0k7S1q7Fy/SXFE0R3hMXkMuw8y9SofWAk=
+github.com/jackpal/gateway v1.0.13/go.mod h1:6c8LjW+FVESFmwxaXySkt7fU98Yv806ADS3OY6Cvh2U=
github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k=
github.com/jpillora/backoff v0.0.0-20180909062703-3050d21c67d7/go.mod h1:2iMrUgbbvHEiQClaW2NsSzMyGHqN+rDFqY705q49KG0=
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
@@ -217,6 +225,7 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/lib/pq v0.0.0-20180327071824-d34b9ff171c2 h1:hRGSmZu7j271trc9sneMrpOW7GN5ngLm8YUZIPzf394=
+github.com/lib/pq v0.0.0-20180327071824-d34b9ff171c2/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/m-lab/access v0.0.3 h1:ixaAZNvN/ggOj1/FOZQhw4r31QA/WlApn0HH71LSHgQ=
github.com/m-lab/access v0.0.3/go.mod h1:gZ7YN3SeMTZYeRv5EFaLdG+XVI/F/X4njM1G1BfwuE4=
github.com/m-lab/go v0.1.43 h1:AKmrhhi5a5rUL9nNMo0YNLSJLNFLfV5PmdAXqRCBGE8=
@@ -262,8 +271,6 @@ github.com/opencontainers/runtime-spec v1.0.3-0.20210326190908-1c3f411f0417/go.m
github.com/opencontainers/selinux v1.10.0/go.mod h1:2i0OySw99QjzBBQByd1Gr9gSjvuho1lHsJxIJ3gGbJI=
github.com/ory/dockertest/v3 v3.9.1 h1:v4dkG+dlu76goxMiTT2j8zV7s4oPPEppKT8K8p2f1kY=
github.com/ory/dockertest/v3 v3.9.1/go.mod h1:42Ir9hmvaAPm0Mgibk6mBPi7SFvTXxEcnztDYOJ//uM=
-github.com/pborman/getopt/v2 v2.1.0 h1:eNfR+r+dWLdWmV8g5OlpyrTYHkhVNxHBdN2cCrJmOEA=
-github.com/pborman/getopt/v2 v2.1.0/go.mod h1:4NtW75ny4eBw9fO1bhtNdYTlZKYX5/tBLtsOpwKIKd0=
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
@@ -307,11 +314,16 @@ github.com/smartystreets/gunit v1.0.0/go.mod h1:qwPWnhz6pn0NnRBP++URONOVyNkPyr4S
github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
+github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
+github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/syndtr/gocapability v0.0.0-20200815063812-42c35b437635/go.mod h1:hkRG7XYTFWNJGYcbNJQlaLq0fg1yr4J4t/NcTQtrfww=
@@ -325,6 +337,7 @@ github.com/tj/go-spin v1.1.0/go.mod h1:Mg1mzmePZm4dva8Qz60H2lHwmJ2loum4VIrLgVnKw
github.com/urfave/cli v1.22.1/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0=
github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYppBueQtXaqoE=
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU=
+github.com/vishvananda/netns v0.0.1/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f h1:J9EGpcZtP0E/raorCMxlFGSTBrsSlaDGf3jU/qvAE2c=
github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU=
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHovont7NscjpAxXsDA8S8BMYve8Y5+7cuRE7R0=
@@ -334,6 +347,7 @@ github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQ
github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
+github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
gitlab.com/yawning/edwards25519-extra.git v0.0.0-20211229043746-2f91fcc9fbdb h1:qRSZHsODmAP5qDvb3YsO7Qnf3TRiVbGxNG/WYnlM4/o=
gitlab.com/yawning/edwards25519-extra.git v0.0.0-20211229043746-2f91fcc9fbdb/go.mod h1:gvdJuZuO/tPZyhEV8K3Hmoxv/DWud5L4qEQxfYjEUTo=
gitlab.com/yawning/obfs4.git v0.0.0-20220904064028-336a71d6e4cf h1:k9czJST0Jvc6fnz4Jp1sxRmA4dSuiWFq+DVpxLZP5yM=
@@ -350,8 +364,11 @@ golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8U
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
+golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
+golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
+golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@@ -362,6 +379,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0
golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4=
golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM=
golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU=
+golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 h1:aAcj0Da7eBAtrTp03QXWvm88pSyOt+UgdZw2BFZ+lEw=
+golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8/go.mod h1:CQ1k9gNrJ50XIzaKCRR2hssIjF07kZFEiieALBM/ARQ=
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
@@ -382,8 +401,11 @@ golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzB
golang.org/x/mod v0.1.1-0.20191107180719-034126e5016b/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
+golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.13.0 h1:I/DsJXRlw/8l/0c24sM9yb0T4z9liZTduXvdAWYiysY=
golang.org/x/mod v0.13.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
+golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic=
+golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -409,8 +431,12 @@ golang.org/x/net v0.0.0-20200421231249-e086a090c8fd/go.mod h1:qpuaurCH72eLCgpAm/
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
+golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
+golang.org/x/net v0.5.0/go.mod h1:DivGGAXEgPSlEBzxGzZI+ZLohi+xUj054jfeKui00ws=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
+golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc=
+golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
@@ -424,8 +450,11 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.4.0 h1:zxkM55ReGkDlKSM+Fu41A+zmbZuaPVbGMzvvdUPznYQ=
golang.org/x/sync v0.4.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
+golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ=
+golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -451,6 +480,7 @@ golang.org/x/sys v0.0.0-20200113162924-86b910548bc1/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20200122134326-e047566fdf82/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200212091648-12a6c2dcc1e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -466,15 +496,24 @@ golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210906170528-6f6e22806c34/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211025201205-69cdffdb9359/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211116061358-0a5406a5449c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
+golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
+golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
+golang.org/x/term v0.4.0/go.mod h1:9P2UbLfCdcvo3p/nzKvsmas4TnlujnuoV9hGgYzW1lQ=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
+golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
+golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
@@ -518,8 +557,11 @@ golang.org/x/tools v0.0.0-20200409170454-77362c5149f0/go.mod h1:EkVYQZoAsY45+roY
golang.org/x/tools v0.0.0-20200422205258-72e4a01eba43/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
+golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.14.0 h1:jvNa2pY0M4r62jkRQ6RwEZZyPcymeL9XZMLBbV7U2nc=
golang.org/x/tools v0.14.0/go.mod h1:uYBEerGOWcJyEORxN+Ek8+TT266gXkNlHdJBwexUsBg=
+golang.org/x/tools v0.19.0 h1:tfGCXNR1OsFG+sVdLAitlpjAvD/I6dHDKnYrpEZUHkw=
+golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -594,6 +636,7 @@ google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlba
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.28.2-0.20230118093459-a9481185b34d h1:qp0AnQCvRCMlu9jBjtdbTaaEmThIgZOrbVyDEOcmKhQ=
+google.golang.org/protobuf v1.28.2-0.20230118093459-a9481185b34d/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
@@ -615,8 +658,10 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
+gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gotest.tools/v3 v3.0.2/go.mod h1:3SzNCllyD9/Y+b5r9JIKQ474KzkZyqLqEfYqMsX94Bk=
gotest.tools/v3 v3.4.0 h1:ZazjZUfuVeZGLAmlKKuyv3IKP5orXcwtOwDQH6YVr6o=
+gotest.tools/v3 v3.4.0/go.mod h1:CtbdzLSsqVhDgMtKsx03ird5YTGB3ar27v0u/yKBW5g=
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ=
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
diff --git a/vpn/mocks/README.md b/internal/mocks/README.md
similarity index 100%
rename from vpn/mocks/README.md
rename to internal/mocks/README.md
diff --git a/vpn/mocks/addr.go b/internal/mocks/addr.go
similarity index 100%
rename from vpn/mocks/addr.go
rename to internal/mocks/addr.go
diff --git a/vpn/mocks/dialer.go b/internal/mocks/dialer.go
similarity index 100%
rename from vpn/mocks/dialer.go
rename to internal/mocks/dialer.go
diff --git a/internal/model/packet.go b/internal/model/packet.go
index 2ed5f674..1ead02a2 100644
--- a/internal/model/packet.go
+++ b/internal/model/packet.go
@@ -339,6 +339,7 @@ func (p *Packet) IsData() bool {
var pingPayload = []byte{0x2A, 0x18, 0x7B, 0xF3, 0x64, 0x1E, 0xB4, 0xCB, 0x07, 0xED, 0x2D, 0x0A, 0x98, 0x1F, 0xC7, 0x48}
+// IsPing returns true if this packet matches a openvpn ping packet.
func (p *Packet) IsPing() bool {
return bytes.Equal(pingPayload, p.Payload)
}
diff --git a/internal/model/session.go b/internal/model/session.go
index 5e181737..c2f10823 100644
--- a/internal/model/session.go
+++ b/internal/model/session.go
@@ -7,7 +7,7 @@ const (
// S_ERROR means there was some form of protocol error.
S_ERROR = NegotiationState(iota) - 1
- // S_UNDER is the undefined state.
+ // S_UNDEF is the undefined state.
S_UNDEF
// S_INITIAL means we're ready to begin the three-way handshake.
diff --git a/internal/reliabletransport/reliable_ack_test.go b/internal/reliabletransport/reliable_ack_test.go
index 56cf3fe9..d3ab453b 100644
--- a/internal/reliabletransport/reliable_ack_test.go
+++ b/internal/reliabletransport/reliable_ack_test.go
@@ -1,7 +1,6 @@
package reliabletransport
import (
- "slices"
"testing"
"time"
@@ -9,6 +8,9 @@ import (
"github.com/ooni/minivpn/internal/model"
"github.com/ooni/minivpn/internal/vpntest"
"github.com/ooni/minivpn/pkg/config"
+
+ // TODO: replace with stlib slices after 1.21
+ "golang.org/x/exp/slices"
)
// test that everything that is received from below is eventually ACKed to the sender.
diff --git a/internal/reliabletransport/sender_test.go b/internal/reliabletransport/sender_test.go
index eb25afb3..024b4c46 100644
--- a/internal/reliabletransport/sender_test.go
+++ b/internal/reliabletransport/sender_test.go
@@ -2,13 +2,15 @@ package reliabletransport
import (
"reflect"
- "slices"
"testing"
"time"
"github.com/apex/log"
"github.com/ooni/minivpn/internal/model"
"github.com/ooni/minivpn/internal/optional"
+
+ // TODO: replace with stdlib slices after 1.21
+ "golang.org/x/exp/slices"
)
func idSequence(s inflightSequence) []model.PacketID {
diff --git a/internal/session/datachannelkey.go b/internal/session/datachannelkey.go
index 3c903f0b..de8885ac 100644
--- a/internal/session/datachannelkey.go
+++ b/internal/session/datachannelkey.go
@@ -6,6 +6,11 @@ import (
"sync"
)
+var (
+ // ErrDataChannelKey is a [DataChannelKey] error.
+ ErrDataChannelKey = errors.New("bad data-channel key")
+)
+
// DataChannelKey represents a pair of key sources that have been negotiated
// over the control channel, and from which we will derive local and remote
// keys for encryption and decrption over the data channel. The index refers to
@@ -23,9 +28,6 @@ type DataChannelKey struct {
mu sync.Mutex
}
-// errDayaChannelKey is a [DataChannelKey] error.
-var errDataChannelKey = errors.New("bad data-channel key")
-
// Local returns the local [KeySource]
func (dck *DataChannelKey) Local() *KeySource {
return dck.local
@@ -42,7 +44,7 @@ func (dck *DataChannelKey) AddRemoteKey(k *KeySource) error {
dck.mu.Lock()
defer dck.mu.Unlock()
if dck.ready {
- return fmt.Errorf("%w: %s", errDataChannelKey, "cannot overwrite remote key slot")
+ return fmt.Errorf("%w: %s", ErrDataChannelKey, "cannot overwrite remote key slot")
}
dck.remote = k
dck.ready = true
diff --git a/internal/session/manager.go b/internal/session/manager.go
index c24637b1..113e9ee4 100644
--- a/internal/session/manager.go
+++ b/internal/session/manager.go
@@ -14,6 +14,14 @@ import (
"github.com/ooni/minivpn/pkg/config"
)
+var (
+ // ErrExpiredKey is the error we raise when we have an expired key.
+ ErrExpiredKey = errors.New("expired key")
+
+ // ErrNoRemoteSessionID indicates we are missing the remote session ID.
+ ErrNoRemoteSessionID = errors.New("missing remote session ID")
+)
+
// Manager manages the session. The zero value is invalid. Please, construct
// using [NewManager]. This struct is concurrency safe.
type Manager struct {
@@ -103,9 +111,6 @@ func (m *Manager) IsRemoteSessionIDSet() bool {
return !m.remoteSessionID.IsNone()
}
-// ErrNoRemoteSessionID indicates we are missing the remote session ID.
-var ErrNoRemoteSessionID = errors.New("missing remote session ID")
-
// NewACKForPacket creates a new ACK for the given packet IDs.
func (m *Manager) NewACKForPacketIDs(ids []model.PacketID) (*model.Packet, error) {
defer m.mu.Unlock()
@@ -170,8 +175,6 @@ func (m *Manager) NewHardResetPacket() *model.Packet {
return packet
}
-var ErrExpiredKey = errors.New("expired key")
-
// LocalDataPacketID returns an unique Packet ID for the Data Channel. It
// increments the counter for the local data packet ID.
func (m *Manager) LocalDataPacketID() (model.PacketID, error) {
@@ -228,7 +231,7 @@ func (m *Manager) ActiveKey() (*DataChannelKey, error) {
defer m.mu.Unlock()
m.mu.Lock()
if len(m.keys) > math.MaxUint8 || m.keyID >= uint8(len(m.keys)) {
- return nil, fmt.Errorf("%w: %s", errDataChannelKey, "no such key id")
+ return nil, fmt.Errorf("%w: %s", ErrDataChannelKey, "no such key id")
}
dck := m.keys[m.keyID]
return dck, nil
diff --git a/internal/tlssession/tlshandshake_test.go b/internal/tlssession/tlshandshake_test.go
index 46b2a4d4..95bdb6df 100644
--- a/internal/tlssession/tlshandshake_test.go
+++ b/internal/tlssession/tlshandshake_test.go
@@ -14,8 +14,8 @@ import (
"time"
"github.com/google/martian/mitm"
+ "github.com/ooni/minivpn/internal/mocks"
"github.com/ooni/minivpn/pkg/config"
- "github.com/ooni/minivpn/vpn/mocks"
tls "github.com/refraction-networking/utls"
)
diff --git a/internal/vpntest/certs.go b/internal/vpntest/certs.go
index d31fb6a9..3b0b4f8f 100644
--- a/internal/vpntest/certs.go
+++ b/internal/vpntest/certs.go
@@ -78,12 +78,14 @@ S5nL4GaRzx84PB1HWONlh0Wp7KBk2j6Lp0acoJwI2mHJcJoOPpaYiWWYNNTjMv2/
XXNUizTI136liavLslSMoYkjYAun+5HOux/keA1L+lm2XeG06Ew1qS4=
-----END CERTIFICATE-----`)
+// TestingCert holds key, cert and ca to pass to tests needing to mock certificates.
type TestingCert struct {
Cert string
Key string
CA string
}
+// WriteTEestingCerts will write valid certificates in the passed dir, and return a [TestingCert] and any error.
func WriteTestingCerts(dir string) (TestingCert, error) {
certFile, err := os.CreateTemp(dir, "tmpfile-")
if err != nil {
diff --git a/internal/vpntest/packetio.go b/internal/vpntest/packetio.go
index f2e74ccb..48bc9d06 100644
--- a/internal/vpntest/packetio.go
+++ b/internal/vpntest/packetio.go
@@ -3,7 +3,6 @@ package vpntest
import (
"fmt"
"regexp"
- "slices"
"strconv"
"sync"
"time"
@@ -11,6 +10,9 @@ import (
"github.com/apex/log"
"github.com/ooni/minivpn/internal/bytesx"
"github.com/ooni/minivpn/internal/model"
+
+ // TODO: replace with stdlib slices after 1.21
+ "golang.org/x/exp/slices"
)
// PacketWriter writes packets into a channel.
diff --git a/internal/vpntest/packetio_test.go b/internal/vpntest/packetio_test.go
index f39f9903..0744f8d5 100644
--- a/internal/vpntest/packetio_test.go
+++ b/internal/vpntest/packetio_test.go
@@ -3,12 +3,14 @@ package vpntest
import (
"bytes"
"reflect"
- "slices"
"testing"
"time"
"github.com/apex/log"
"github.com/ooni/minivpn/internal/model"
+
+ // TODO: replace with stdlib slices after 1.21
+ "golang.org/x/exp/slices"
)
func TestPacketLog_ACKs(t *testing.T) {
diff --git a/pkg/config/vpnoptions.go b/pkg/config/vpnoptions.go
index 2b3e320f..06d8ef36 100644
--- a/pkg/config/vpnoptions.go
+++ b/pkg/config/vpnoptions.go
@@ -77,7 +77,7 @@ const ProtoUDP = Proto("udp")
// ErrBadConfig is the generic error returned for invalid config files
var ErrBadConfig = errors.New("openvpn: bad config")
-// SupportCiphers defines the supported ciphers.
+// SupportedCiphers defines the supported ciphers.
var SupportedCiphers = []string{
"AES-128-CBC",
"AES-192-CBC",
diff --git a/pkg/config/vpnoptions_test.go b/pkg/config/vpnoptions_test.go
index b9cda01a..1000ac7a 100644
--- a/pkg/config/vpnoptions_test.go
+++ b/pkg/config/vpnoptions_test.go
@@ -365,7 +365,7 @@ func Test_parseProxyOBFS4(t *testing.T) {
opt := &OpenVPNOptions{}
obfs4Uri := "obfs4://foobar"
o, err := parseProxyOBFS4([]string{obfs4Uri}, opt)
- var wantErr error = nil
+ var wantErr error
if !errors.Is(err, wantErr) {
t.Errorf("parseProxyOBFS4(): wantErr: %v, got %v", wantErr, err)
}
@@ -386,7 +386,7 @@ func Test_parseCA(t *testing.T) {
t.Run("empty part should fail", func(t *testing.T) {
_, err := parseCA([]string{}, &OpenVPNOptions{}, "")
- var wantErr error = ErrBadConfig
+ wantErr := ErrBadConfig
if !errors.Is(err, wantErr) {
t.Errorf("parseCA(): want %v, got %v", wantErr, err)
}
@@ -404,7 +404,7 @@ func Test_parseCert(t *testing.T) {
t.Run("empty parts should fail", func(t *testing.T) {
_, err := parseCert([]string{}, &OpenVPNOptions{}, "")
- var wantErr error = ErrBadConfig
+ wantErr := ErrBadConfig
if !errors.Is(err, wantErr) {
t.Errorf("parseCert(): want %v, got %v", wantErr, err)
}
@@ -412,7 +412,7 @@ func Test_parseCert(t *testing.T) {
t.Run("non-existent cert should fail", func(t *testing.T) {
_, err := parseCert([]string{"/tmp/nonexistent"}, &OpenVPNOptions{}, "")
- var wantErr error = ErrBadConfig
+ var wantErr = ErrBadConfig
if !errors.Is(err, wantErr) {
t.Errorf("parseCert(): want %v, got %v", wantErr, err)
}
diff --git a/vpn/bytes.go b/vpn/bytes.go
deleted file mode 100644
index c018f364..00000000
--- a/vpn/bytes.go
+++ /dev/null
@@ -1,157 +0,0 @@
-package vpn
-
-//
-// Functions operating on bytes:
-//
-// 1. generating random bytes;
-//
-// 2. OpenVPN options encoding and decoding;
-//
-// 3. PKCS#7 padding and unpadding.
-//
-
-import (
- "bytes"
- "crypto/rand"
- "encoding/binary"
- "errors"
- "fmt"
- "io"
- "math"
-)
-
-var (
- // errEncodeOption indicates an option encoding error occurred.
- errEncodeOption = errors.New("can't encode option")
-
- // errDecodeOption indicates an option decoding error occurred.
- errDecodeOption = errors.New("can't decode option")
-
- // errPaddingPKCS7 indicates that a PKCS#7 padding error has occurred.
- errPaddingPKCS7 = errors.New("PKCS#7 padding error")
-
- // errUnpaddingPKCS7 indicates that a PKCS#7 unpadding error has occurred.
- errUnpaddingPKCS7 = errors.New("PKCS#7 unpadding error")
-)
-
-// genRandomBytes returns an array of bytes with the given size using
-// a CSRNG, on success, or an error, in case of failure.
-func genRandomBytes(size int) ([]byte, error) {
- b := make([]byte, size)
- _, err := rand.Read(b)
- return b, err
-}
-
-// encodeOptionStringToBytes is used to encode the options string, username and password.
-//
-// According to the OpenVPN protocol, options are represented as a two-byte word,
-// plus the byte representation of the string, null-terminated.
-//
-// See https://openvpn.net/community-resources/openvpn-protocol/.
-//
-// This function returns errEncodeOption in case of failure.
-func encodeOptionStringToBytes(s string) ([]byte, error) {
- if len(s) >= math.MaxUint16 { // Using >= b/c we need to account for the final \0
- return nil, fmt.Errorf("%w:%s", errEncodeOption, "string too large")
- }
- data := make([]byte, 2)
- binary.BigEndian.PutUint16(data, uint16(len(s))+1)
- data = append(data, []byte(s)...)
- data = append(data, 0x00)
- return data, nil
-}
-
-// decodeOptionStringFromBytes returns the string-value for the null-terminated string
-// returned by the server when sending remote options to us.
-//
-// This function returns errDecodeOption on failure.
-func decodeOptionStringFromBytes(b []byte) (string, error) {
- if len(b) < 2 {
- return "", fmt.Errorf("%w: expected at least two bytes", errDecodeOption)
- }
- length := int(binary.BigEndian.Uint16(b[:2]))
- b = b[2:] // skip over the length
- // the server sends padding, so we cannot do a strict check
- if len(b) < length {
- return "", fmt.Errorf("%w: got %d, expected %d", errDecodeOption, len(b), length)
- }
- if len(b) <= 0 || length == 0 {
- return "", fmt.Errorf("%w: zero length encoded option is not possible: %s", errDecodeOption,
- "we need at least one byte for the trailing \\0")
- }
- if b[length-1] != 0x00 {
- return "", fmt.Errorf("%w: missing trailing \\0", errDecodeOption)
- }
- return string(b[:len(b)-1]), nil
-}
-
-// bytesUnpadPKCS7 performs the PKCS#7 unpadding of a byte array.
-func bytesUnpadPKCS7(b []byte, blockSize int) ([]byte, error) {
- // 1. check whether we can unpad at all
- if blockSize > math.MaxUint8 {
- return nil, fmt.Errorf("%w: blockSize too large", errUnpaddingPKCS7)
- }
- // 2. trivial case
- if len(b) <= 0 {
- return nil, fmt.Errorf("%w: passed empty buffer", errUnpaddingPKCS7)
- }
- // 4. read the padding size
- psiz := int(b[len(b)-1])
- // 5. enforce padding size constraints
- if psiz <= 0x00 {
- return nil, fmt.Errorf("%w: padding size cannot be zero", errUnpaddingPKCS7)
- }
- if psiz > blockSize {
- return nil, fmt.Errorf("%w: padding size cannot be larger than blockSize", errUnpaddingPKCS7)
- }
- // 6. compute the padding offset
- off := len(b) - psiz
- // 7. return unpadded bytes
- panicIfFalse(off >= 0 && off <= len(b), "off is out of bounds")
- return b[:off], nil
-}
-
-// bytesPadPKCS7 returns the PKCS#7 padding of a byte array.
-func bytesPadPKCS7(b []byte, blockSize int) ([]byte, error) {
- if blockSize <= 0 {
- return nil, fmt.Errorf("%w: %s", errBadInput, "blocksize cannot be negative or zero")
- }
- // If lth mod blockSize == 0, then the input gets appended a whole block size
- // See https://datatracker.ietf.org/doc/html/rfc5652#section-6.3
- if blockSize > math.MaxUint8 {
- // This padding method is well defined iff blockSize is less than 256.
- return nil, errPaddingPKCS7
- }
- psiz := blockSize - len(b)%blockSize
- padding := bytes.Repeat([]byte{byte(psiz)}, psiz)
- return append(b, padding...), nil
-}
-
-// bufReadUint32 is a convenience function that reads a uint32 from a 4-byte
-// buffer, returning an error if the operation failed.
-func bufReadUint32(buf *bytes.Buffer) (uint32, error) {
- var numBuf [4]byte
- _, err := io.ReadFull(buf, numBuf[:])
- if err != nil {
- return 0, err
- }
- return binary.BigEndian.Uint32(numBuf[:]), nil
-}
-
-// bufWriteUint32 is a convenience function that appends to the given buffer
-// 4 bytes containing the big-endian representation of the given uint32 value.
-func bufWriteUint32(buf *bytes.Buffer, val uint32) {
- var numBuf [4]byte
- binary.BigEndian.PutUint32(numBuf[:], val)
- buf.Write(numBuf[:])
-}
-
-// bufWriteUint24 is a convenience function that appends to the given buffer
-// 3 bytes containing the big-endian representation of the given uint32 value.
-// Caller is responsible to ensure the passed value does not overflow the
-// maximal capacity of 3 bytes.
-func bufWriteUint24(buf *bytes.Buffer, val uint32) {
- b := &bytes.Buffer{}
- bufWriteUint32(b, val)
- buf.Write(b.Bytes()[1:])
-}
diff --git a/vpn/bytes_test.go b/vpn/bytes_test.go
deleted file mode 100644
index 632fe47a..00000000
--- a/vpn/bytes_test.go
+++ /dev/null
@@ -1,399 +0,0 @@
-package vpn
-
-import (
- "errors"
- "math"
- "testing"
-
- "github.com/google/go-cmp/cmp"
-)
-
-func Test_genRandomBytes(t *testing.T) {
- const smallBuffer = 128
- data, err := genRandomBytes(smallBuffer)
- if err != nil {
- t.Fatal("unexpected error", err)
- }
- if len(data) != smallBuffer {
- t.Fatal("unexpected returned buffer length")
- }
-}
-
-func Test_encodeOptionStringToBytes(t *testing.T) {
- type args struct {
- s string
- }
- tests := []struct {
- name string
- args args
- want []byte
- wantErr error
- }{{
- name: "common case",
- args: args{
- s: "test",
- },
- want: []byte{0, 5, 116, 101, 115, 116, 0},
- wantErr: nil,
- }, {
- name: "encoding empty string",
- args: args{
- s: "",
- },
- want: []byte{0, 1, 0},
- wantErr: nil,
- }, {
- name: "encoding a very large string",
- args: args{
- s: string(make([]byte, 1<<16)),
- },
- want: nil,
- wantErr: errEncodeOption,
- }}
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got, err := encodeOptionStringToBytes(tt.args.s)
- if !errors.Is(err, tt.wantErr) {
- t.Fatalf("encodeOptionStringToBytes() error = %v, wantErr %v", err, tt.wantErr)
- }
- if diff := cmp.Diff(tt.want, got); diff != "" {
- t.Fatal(diff)
- }
- })
- }
-}
-
-func Test_decodeOptionStringFromBytes(t *testing.T) {
- type args struct {
- b []byte
- }
- tests := []struct {
- name string
- args args
- want string
- wantErr error
- }{{
- name: "with zero-length input",
- args: args{
- b: nil,
- },
- want: "",
- wantErr: errDecodeOption,
- }, {
- name: "with input length equal to one",
- args: args{
- b: []byte{0x00},
- },
- want: "",
- wantErr: errDecodeOption,
- }, {
- name: "with input length equal to two",
- args: args{
- b: []byte{0x00, 0x00},
- },
- want: "",
- wantErr: errDecodeOption,
- }, {
- name: "with length mismatch and length < actual length",
- args: args{
- b: []byte{
- 0x00, 0x03, // length = 3
- 0x61, 0x61, 0x61, 0x61, 0x61, // aaaaa
- 0x00, // trailing zero
- },
- },
- want: "",
- wantErr: errDecodeOption,
- }, {
- name: "with length mismatch and length > actual length",
- args: args{
- b: []byte{
- 0x00, 0x44, // length = 68
- 0x61, 0x61, 0x61, 0x61, 0x61, // aaaaa
- 0x00, // trailing zero
- },
- },
- want: "",
- wantErr: errDecodeOption,
- }, {
- name: "with missing trailing \\0",
- args: args{
- b: []byte{
- 0x00, 0x05, // length = 5
- 0x61, 0x61, 0x61, 0x61, 0x61, // aaaaa
- },
- },
- want: "",
- wantErr: errDecodeOption,
- }, {
- name: "with valid input",
- args: args{
- b: []byte{
- 0x00, 0x06, // length = 6
- 0x61, 0x61, 0x61, 0x61, 0x61, // aaaaa
- 0x00, // trailing zero
- },
- },
- want: "aaaaa",
- wantErr: nil,
- }}
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got, err := decodeOptionStringFromBytes(tt.args.b)
- if !errors.Is(err, tt.wantErr) {
- t.Fatalf("decodeOptionStringFromBytes() error = %v, wantErr %v", err, tt.wantErr)
- }
- if diff := cmp.Diff(tt.want, got); diff != "" {
- t.Fatal(diff)
- }
- })
- }
-}
-
-func Test_bytesUnpadPKCS7(t *testing.T) {
- type args struct {
- b []byte
- blockSize int
- }
- tests := []struct {
- name string
- args args
- want []byte
- wantErr error
- }{{
- name: "with too-large blockSize",
- args: args{
- b: []byte{0x00, 0x00, 0x00},
- blockSize: math.MaxUint8 + 1, // too large
- },
- want: nil,
- wantErr: errUnpaddingPKCS7,
- }, {
- name: "with zero-length array",
- args: args{
- b: nil,
- blockSize: 2,
- },
- want: nil,
- wantErr: errUnpaddingPKCS7,
- }, {
- name: "with 0x00 used as padding",
- args: args{
- b: []byte{
- 0x61, 0x61, // block ("aa")
- 0x00, 0x00, // padding
- },
- blockSize: 2,
- },
- want: nil,
- wantErr: errUnpaddingPKCS7,
- }, {
- name: "with padding larger than block size",
- args: args{
- b: []byte{
- 0x61, 0x61, // block ("aa")
- 0x03, 0x03, // padding
- },
- blockSize: 2,
- },
- want: nil,
- wantErr: errUnpaddingPKCS7,
- }, {
- name: "with blocksize == 4 and len(data) == 0",
- args: args{
- b: []byte{
- 0x04, 0x04, 0x04, 0x04, // padding
- },
- blockSize: 4,
- },
- want: []byte{},
- wantErr: nil,
- }, {
- name: "with blocksize == 4 and len(data) == 1",
- args: args{
- b: []byte{
- 0xde, // data
- 0x03, 0x03, 0x03, // padding
- },
- blockSize: 4,
- },
- want: []byte{0xde},
- wantErr: nil,
- }, {
- name: "with blocksize == 4 and len(data) == 2",
- args: args{
- b: []byte{
- 0xde, 0xad, // data
- 0x02, 0x02, // padding
- },
- blockSize: 4,
- },
- want: []byte{0xde, 0xad},
- wantErr: nil,
- }, {
- name: "with blocksize == 4 and len(data) == 3",
- args: args{
- b: []byte{
- 0xde, 0xad, 0xbe, // data
- 0x01, // padding
- },
- blockSize: 4,
- },
- want: []byte{0xde, 0xad, 0xbe},
- wantErr: nil,
- }, {
- name: "with blocksize == 4 and len(data) == 4",
- args: args{
- b: []byte{
- 0xde, 0xad, 0xbe, 0xff, // data
- 0x04, 0x04, 0x04, 0x04, // padding
- },
- blockSize: 4,
- },
- want: []byte{0xde, 0xad, 0xbe, 0xff},
- wantErr: nil,
- }, {
- name: "with blocksize == 4 and len(data) == 5",
- args: args{
- b: []byte{
- 0xde, 0xad, 0xbe, 0xff, 0xab, // data
- 0x03, 0x03, 0x03, // padding
- },
- blockSize: 4,
- },
- want: []byte{0xde, 0xad, 0xbe, 0xff, 0xab},
- wantErr: nil,
- }}
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got, err := bytesUnpadPKCS7(tt.args.b, tt.args.blockSize)
- if !errors.Is(err, tt.wantErr) {
- t.Fatalf("bytesUnpadPKCS7() error = %v, wantErr %v", err, tt.wantErr)
- }
- if diff := cmp.Diff(tt.want, got); diff != "" {
- t.Fatal(diff)
- }
- })
- }
-}
-
-func Test_bytesPadPKCS7(t *testing.T) {
- type args struct {
- b []byte
- blockSize int
- }
- tests := []struct {
- name string
- args args
- want []byte
- wantErr error
- }{{
- name: "with too-large block size",
- args: args{
- b: []byte{0x00, 0x00, 0x00},
- blockSize: math.MaxUint8 + 1,
- },
- want: nil,
- wantErr: errPaddingPKCS7,
- }, {
- name: "with negative block size",
- args: args{
- b: []byte{0x00, 0x00, 0x00},
- blockSize: -1,
- },
- want: nil,
- wantErr: errBadInput,
- }, {
- name: "with blockSize == 4 and len(data) == 0",
- args: args{
- b: nil,
- blockSize: 4,
- },
- want: []byte{
- 0x04, 0x04, 0x04, 0x04, // only padding
- },
- wantErr: nil,
- }, {
- name: "with blockSize == 4 and len(data) == 1",
- args: args{
- b: []byte{
- 0xde, // len(data) == 1
- },
- blockSize: 4,
- },
- want: []byte{
- 0xde, // data
- 0x03, 0x03, 0x03, // padding
- },
- wantErr: nil,
- }, {
- name: "with blockSize == 4 and len(data) == 2",
- args: args{
- b: []byte{
- 0xde, 0xad, // len(data) == 2
- },
- blockSize: 4,
- },
- want: []byte{
- 0xde, 0xad, // data
- 0x02, 0x02, // padding
- },
- wantErr: nil,
- }, {
- name: "with blockSize == 4 and len(data) == 3",
- args: args{
- b: []byte{
- 0xde, 0xad, 0xbe, // len(data) == 3
- },
- blockSize: 4,
- },
- want: []byte{
- 0xde, 0xad, 0xbe, //data
- 0x01, // padding
- },
- wantErr: nil,
- }, {
- name: "with blockSize == 4 and len(data) == 4",
- args: args{
- b: []byte{
- 0xde, 0xad, 0xbe, 0xef, // len(data) == 4
- },
- blockSize: 4,
- },
- want: []byte{
- 0xde, 0xad, 0xbe, 0xef, // data
- 0x04, 0x04, 0x04, 0x04, // padding
- },
- wantErr: nil,
- }, {
- name: "with blocksize == 4 and len(data) == 5",
- args: args{
- b: []byte{
- 0xde, 0xad, 0xbe, 0xef, 0xab, // len(data) == 5
- },
- blockSize: 4,
- },
- want: []byte{
- 0xde, 0xad, 0xbe, 0xef, 0xab, // data
- 0x03, 0x03, 0x03, // padding
- },
- wantErr: nil,
- }}
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got, err := bytesPadPKCS7(tt.args.b, tt.args.blockSize)
- if !errors.Is(err, tt.wantErr) {
- t.Fatalf("bytesPadPKCS7() error = %v, wantErr %v", err, tt.wantErr)
- }
- if diff := cmp.Diff(tt.want, got); diff != "" {
- t.Fatal(diff)
- }
- })
- }
-}
-
-// Regression test for MIV-01-002
-func Test_Crash_bytesPadPCKS7(t *testing.T) {
- bytesPadPKCS7(nil, 0)
- bytesPadPKCS7([]byte{0xaa, 0xab}, -1)
-}
diff --git a/vpn/client.go b/vpn/client.go
deleted file mode 100644
index 94dc2ebb..00000000
--- a/vpn/client.go
+++ /dev/null
@@ -1,257 +0,0 @@
-package vpn
-
-//
-// Client initialization and public methods
-//
-
-import (
- "context"
- "errors"
- "fmt"
- "net"
- "strings"
- "sync"
- "time"
-)
-
-var (
- // ErrDialError is a generic error while dialing
- ErrDialError = errors.New("dial error")
-
- // ErrAlreadyStarted is returned when trying to start the tunnel more than once
- ErrAlreadyStarted = errors.New("tunnel already started")
-
- // ErrNotReady is returned when a Read/Write attempt is made before the tunnel is ready.
- ErrNotReady = errors.New("tunnel not ready")
-)
-
-// tunnelInfo holds state about the VPN tunnelInfo that has longer duration than a
-// given session. This information is gathered at different stages:
-// - during the handshake (mtu).
-// - after server pushes config options(ip, gw).
-type tunnelInfo struct {
- mtu int
- ip string
- gw string
- peerID int
-}
-
-// vpnClient is a net.Conn that uses the VPN tunnel. It is a net.Conn with an
-// additional `Start()` method.
-type vpnClient interface {
- net.Conn
- Start(ctx context.Context) error
-}
-
-type dialContextFn func(context.Context, string, string) (net.Conn, error)
-
-// DialerContext is anything that features a net.Dialer-like DialContext method.
-type DialerContext interface {
- DialContext(context.Context, string, string) (net.Conn, error)
-}
-
-// Client implements the OpenVPN protocol. A Client object satisfies the
-// net.Conn interface. plus Start().
-// The Read and Write operations send and receive bytes to and from the tunnel
-// - they are writing to and reading from the OpenVPN Data channel, with the
-// control channel being handled in the background.
-// To Dial sockets through the Tunnel, you should use the NewTunDialer constructor,
-// that accepts a Client object.
-// Client is only intended to be directly instantiated by users that need a
-// finer control of the protocol steps, or for the case in which you need the
-// equivalent of raw sockets.
-type Client struct {
- Opts *Options
- Dialer DialerContext
-
- // If this channel is not nil, a series of Event* will be
- // sent to the channel. The user of the Client can set a
- // channel externally to subscribe to discrete transitions. A sufficiently
- // buffered-channel should be provided to avoid losing events (~10
- // events should do it).
- EventListener chan uint8
-
- Log Logger
-
- conn net.Conn
- mux vpnMuxer
- tunInfo *tunnelInfo
-
- // muxerFactoryFn allows to inject a different factory
- // for testing.
- muxerFactoryFn muxFactory
-
- startOnce sync.Once
- startErr error
-}
-
-var _ net.Conn = &Client{} // Ensure that we implement net.Conn
-var _ vpnClient = &Client{} // Ensure that we implement vpnClient
-
-// NewClientFromOptions returns a Client configured with the given Options.
-func NewClientFromOptions(opt *Options) *Client {
- if opt == nil {
- return &Client{}
- }
- return &Client{
- Opts: opt,
- tunInfo: &tunnelInfo{},
- Dialer: &net.Dialer{},
- }
-}
-
-//
-// observability
-//
-
-// emit sends the passed stage into any configured EventListener.
-func (c *Client) emit(stage uint8) {
- select {
- case c.EventListener <- stage:
- default:
- // don't deliver
- }
-}
-
-// Start starts the OpenVPN tunnel.
-func (c *Client) Start(ctx context.Context) error {
- c.startOnce.Do(func() {
- c.startErr = c.start(ctx)
- })
- return c.startErr
-}
-
-func (c *Client) start(ctx context.Context) error {
- c.emit(EventReady)
-
- conn, err := c.dial(ctx)
- if err != nil {
- return err
- }
-
- c.emit(EventDialDone)
-
- muxFactory := c.muxerFactory()
- mux, err := muxFactory(conn, c.Opts, c.tunInfo)
- if err != nil {
- conn.Close()
- return err
- }
-
- mux.SetEventListener(c.EventListener)
-
- err = mux.Handshake(ctx)
- if err != nil {
- conn.Close()
- return err
- }
-
- c.emit(EventHandshakeDone)
-
- c.conn = conn
- c.mux = mux
- return nil
-}
-
-// muxerFactory returns the default muxer Factory, or any other one that has
-// been injected into the `muxerFactoryFn` private field in Client for testing.
-func (c *Client) muxerFactory() muxFactory {
- muxFactory := newMuxerFromOptions
- if c.muxerFactoryFn == nil {
- return muxFactory
- }
- return c.muxerFactoryFn
-}
-
-// dial opens a TCP/UDP socket against the remote, and creates an internal
-// data channel. It is the second step in an OpenVPN connection (out of five).
-// (In UDP mode no network connection is done at this step).
-func (c *Client) dial(ctx context.Context) (net.Conn, error) {
- if c.Opts == nil {
- return nil, fmt.Errorf("%w:%s", errBadInput, "nil options")
-
- }
- var proto string
- switch c.Opts.Proto {
- case UDPMode:
- proto = protoUDP.String()
- case TCPMode:
- proto = protoTCP.String()
- default:
- return nil, fmt.Errorf("%w: unknown proto %d", errBadInput, c.Opts.Proto)
-
- }
-
- select {
- case <-ctx.Done():
- return nil, ctx.Err()
- default:
- msg := fmt.Sprintf("Connecting to %s:%s with proto %s",
- c.Opts.Remote, c.Opts.Port, strings.ToUpper(proto))
- logger.Info(msg)
-
- conn, err := c.Dialer.DialContext(ctx, proto, net.JoinHostPort(c.Opts.Remote, c.Opts.Port))
- if err != nil {
- return nil, fmt.Errorf("%w: %s", ErrDialError, err)
- }
- return conn, nil
- }
-}
-
-// Write sends bytes into the tunnel.
-func (c *Client) Write(b []byte) (int, error) {
- return c.mux.Write(b)
-}
-
-// Read reads bytes from the tunnel.
-func (c *Client) Read(b []byte) (int, error) {
- if c.mux == nil {
- return 0, ErrNotReady
-
- }
- return c.mux.Read(b)
-}
-
-// Close closes the tunnel connection.
-func (c *Client) Close() error {
- if c.conn != nil {
- return c.conn.Close()
- }
- return nil
-}
-
-// LocalAddr returns the local address on the tunnel virtual device, if known.
-// In case the Addr is not known, a zero-value net.Addr will be returned.
-func (c *Client) LocalAddr() net.Addr {
- addr := &net.IPAddr{}
- if c.tunInfo != nil {
- if ip := net.ParseIP(c.tunInfo.ip); ip != nil {
- addr.IP = ip
- }
- }
- return addr
-}
-
-// RemoteAddr returns the address of the tun interface of the tunnel gateway,
-// if known. In case the Addr is not known, a zero-value net.Addr will be returned.
-func (c *Client) RemoteAddr() net.Addr {
- addr := &net.IPAddr{}
- if c.tunInfo != nil {
- if ip := net.ParseIP(c.tunInfo.gw); ip != nil {
- addr.IP = ip
- }
- }
- return addr
-}
-
-func (c *Client) SetDeadline(t time.Time) error {
- return c.conn.SetDeadline(t)
-}
-
-func (c *Client) SetReadDeadline(t time.Time) error {
- return c.conn.SetReadDeadline(t)
-}
-
-func (c *Client) SetWriteDeadline(t time.Time) error {
- return c.conn.SetWriteDeadline(t)
-}
diff --git a/vpn/client_test.go b/vpn/client_test.go
deleted file mode 100644
index 1e6975e3..00000000
--- a/vpn/client_test.go
+++ /dev/null
@@ -1,286 +0,0 @@
-package vpn
-
-import (
- "context"
- "errors"
- "net"
- "reflect"
- "testing"
- "time"
-)
-
-// the name is confusing, but we're just getting a generic mocked conn
-// that serves as witness of calls
-// TODO can copy the mockTLSConn here to avoid confusion with names and
-// decouple these tests from those.
-func makeTestingClientConn() (*Client, *MockTLSConn) {
- c := makeConnForTransportTest()
- cl := &Client{}
- cl.conn = c
- return cl, c
-}
-
-func TestNewClientFromOptions(t *testing.T) {
- t.Run("proper options does not fail getting client", func(t *testing.T) {
- randomFn = func(int) ([]byte, error) {
- return []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, nil
- }
- opts := makeTestingOptions(t, "AES-128-GCM", "sha512")
- _ = NewClientFromOptions(opts)
- })
-
- t.Run("nil options return empty client", func(t *testing.T) {
- c := NewClientFromOptions(nil)
- if !reflect.DeepEqual(c, &Client{}) {
- t.Error("Client.NewClientFromOptions(): expected empty client with nil options")
- }
- })
-}
-
-type mockMuxerForClient struct {
- muxer
- writeCalled bool
- readCalled bool
-}
-
-func (mm *mockMuxerForClient) Read([]byte) (int, error) {
- mm.readCalled = true
- return 42, nil
-}
-
-func (mm *mockMuxerForClient) Write(b []byte) (int, error) {
- mm.writeCalled = true
- return len(b), nil
-}
-
-func mockMuxerFactory() muxFactory {
- fn := func(net.Conn, *Options, *tunnelInfo) (vpnMuxer, error) {
- m := &mockMuxerWithDummyHandshake{}
- return m, nil
- }
- return fn
-}
-
-func TestClient_Write(t *testing.T) {
- // test that call to write calls the muxer method
- cl, _ := makeTestingClientConn()
- mux := &mockMuxerForClient{}
- cl.mux = mux
- _, err := cl.Write([]byte("alles ist green"))
- if err != nil {
- t.Errorf("Client.Write(): expected err = nil, got %v", err)
- }
- if !mux.writeCalled {
- t.Errorf("Client.Write(): client.mux.Write() not called")
- }
-}
-
-func TestClient_Read(t *testing.T) {
- cl, _ := makeTestingClientConn()
- cl.mux = nil
- b := make([]byte, 255)
- _, err := cl.Read(b)
- if !errors.Is(err, ErrNotReady) {
- t.Errorf("Client.Read(): nil mux, expected error %v, got %v ", errBadInput, err)
- }
-
- // test that call to read calls the muxer method
- cl, _ = makeTestingClientConn()
- mux := &mockMuxerForClient{}
- cl.mux = mux
- b = make([]byte, 255)
- _, err = cl.Read(b)
- if err != nil {
- t.Errorf("Client.Read(): expected err = nil, got %v", err)
- }
- if !mux.readCalled {
- t.Errorf("Client.Read(): client.mux.Read() not called")
- }
-}
-
-func TestClient_LocalAddr(t *testing.T) {
- cl, _ := makeTestingClientConn()
- cl.tunInfo = nil
- a := cl.LocalAddr()
- if a.String() != "" {
- t.Errorf("Client.LocalAddr(): expected empty string, got %v", a.String())
- }
-}
-
-func TestClient_RemoteAddr(t *testing.T) {
- cl, _ := makeTestingClientConn()
- a := cl.RemoteAddr()
- if a.String() != "" {
- t.Errorf("Client.RemoteAddr(): expected empty string, got %v", a.String())
- }
-}
-
-// for the tests that test the delegation of methods to the underlying conn we
-// can reuse the mock used in transport_test
-
-func TestClient_SetDeadline(t *testing.T) {
- cl, conn := makeTestingClientConn()
- err := cl.SetDeadline(time.Now().Add(time.Second))
- if err != nil {
- t.Errorf("Client.SetDeadline() error = %v, want = nil", err)
- }
- if !conn.setDeadlineCalled {
- t.Error("Client.SetDeadline(): conn.SetDeadline() not called")
- }
-
-}
-
-func TestClient_SetReadDeadline(t *testing.T) {
- cl, conn := makeTestingClientConn()
- err := cl.SetReadDeadline(time.Now().Add(time.Second))
- if err != nil {
- t.Errorf("Client.SetDeadline() error = %v, want = nil", err)
- }
- if !conn.setReadDeadlineCalled {
- t.Error("Client.SetReadDeadline(): conn.SetReadDeadline() not called")
- }
-}
-
-func TestClient_SetWriteDeadline(t *testing.T) {
- cl, conn := makeTestingClientConn()
- err := cl.SetWriteDeadline(time.Now().Add(time.Second))
- if err != nil {
- t.Errorf("Client.SetWriteDeadline() error = %v, want = nil", err)
- }
- if !conn.setWriteDeadlineCalled {
- t.Error("Client.SetWriteDeadline(): conn.SetWriteReadDeadline() not called")
- }
-}
-
-func TestClient_Close(t *testing.T) {
- cl, conn := makeTestingClientConn()
- err := cl.Close()
- if err != nil {
- t.Errorf("Client.Close() error = %v, want = nil", err)
- }
- if !conn.closedCalled {
- t.Error("Client.Close(): conn.Close() not called")
- }
-}
-
-type badDialer struct{}
-
-func (bd *badDialer) DialContext(context.Context, string, string) (net.Conn, error) {
- return nil, errors.New("cannot dial")
-}
-
-func TestClient_dialFailsWithBadOptions(t *testing.T) {
- c := &Client{}
- _, err := c.dial(context.Background())
- wantErr := errBadInput
- if !errors.Is(err, wantErr) {
- t.Error("Client.Dial(): should fail with nil options")
- }
-
- c = &Client{
- Opts: &Options{
- Proto: 3,
- },
- }
- _, err = c.dial(context.Background())
- wantErr = errBadInput
- if !errors.Is(err, wantErr) {
- t.Error("Client.Dial(): should fail with bad proto")
- }
-
- c = &Client{
- Opts: &Options{
- Proto: TCPMode,
- },
- Dialer: &badDialer{},
- }
- _, err = c.dial(context.Background())
- wantErr = ErrDialError
- if !errors.Is(err, wantErr) {
- t.Errorf("Client.Dial(): should fail with ErrDialError, err = %v", err)
- }
-}
-
-func TestCient_DialRaisesError(t *testing.T) {
- c := &Client{
- Opts: &Options{
- Proto: TCPMode,
- },
- }
- ctx := context.Background()
- ctx, cancel := context.WithCancel(ctx)
- cancel()
- _, err := c.dial(ctx)
- if err != context.Canceled {
- t.Errorf("Client.Dial(): expected context.Canceled, err = %v", err)
- }
-}
-
-func TestClient_StartRaisesDialError(t *testing.T) {
- c := &Client{
- Opts: &Options{
- Proto: TCPMode,
- },
- Dialer: &badDialer{},
- }
- err := c.Start(context.Background())
- if !errors.Is(err, ErrDialError) {
- t.Errorf("Client.Start(): expected = %v, got = %v", ErrDialError, err)
- }
-}
-
-func TestClientStartWithMockedMuxerFactory(t *testing.T) {
- c := &Client{
- Opts: &Options{
- Proto: TCPMode,
- },
- Dialer: &mockedDialerContext{},
- }
- c.muxerFactoryFn = mockMuxerFactory()
- err := c.Start(context.Background())
- if err != nil {
- t.Errorf("expected no error, got %v", err)
- }
-}
-
-func TestClient_emitSendsToListener(t *testing.T) {
- t.Run("emit writes event if listener not null", func(t *testing.T) {
- l := make(chan uint8, 2)
- c := &Client{}
- c.EventListener = l
- sent := uint8(2)
- c.emit(sent)
- got := <-l
- if got != sent {
- t.Errorf("expected %v, got %v", sent, got)
- }
- })
- t.Run("emit is a noop if evenlistener not set", func(t *testing.T) {
- c := &Client{}
- sent := uint8(2)
- c.emit(sent)
- if c.EventListener != nil {
- t.Errorf("expected EventListener to be nil")
- }
- })
- t.Run("listener receives several events", func(t *testing.T) {
- l := make(chan uint8, 5)
- c := &Client{}
- c.EventListener = l
- received := []uint8{}
- sent := []uint8{1, 2, 3, 4, 5}
- for _, i := range sent {
- c.emit(i)
- }
- for _ = range sent {
- got := <-l
- received = append(received, got)
- }
- for i := range sent {
- if sent[i] != received[i] {
- t.Errorf("at [%d]: expected %v, got %v", i, sent, received)
- return
- }
- }
- })
-}
diff --git a/vpn/control.go b/vpn/control.go
deleted file mode 100644
index 2f02f8d5..00000000
--- a/vpn/control.go
+++ /dev/null
@@ -1,278 +0,0 @@
-package vpn
-
-//
-// OpenVPN control channel
-//
-
-import (
- "bytes"
- "encoding/binary"
- "encoding/hex"
- "errors"
- "fmt"
- "math"
- "net"
- "sync"
-)
-
-var (
- errBadReset = errors.New("bad reset packet")
- errExpiredKey = errors.New("max packet id reached")
-)
-
-var (
- serverPushReply = []byte("PUSH_REPLY")
- serverBadAuth = []byte("AUTH_FAILED")
-)
-
-// session keeps mutable state related to an OpenVPN session.
-type session struct {
- RemoteSessionID sessionID
- LocalSessionID sessionID
- keys []*dataChannelKey
- keyID int
- localPacketID packetID
- lastACK packetID
- ackQueue chan *packet
- mu sync.Mutex
- Log Logger
-}
-
-// newSession returns a session ready to be used.
-func newSession() (*session, error) {
- key0 := &dataChannelKey{}
- ackQueue := make(chan *packet, 100)
- session := &session{
- keys: []*dataChannelKey{key0},
- ackQueue: ackQueue,
- }
-
- randomBytes, err := randomFn(8)
- if err != nil {
- return session, err
- }
-
- // in go 1.17, one could do:
- // localSession := (*sessionID)(lsid)
- var localSession sessionID
- copy(localSession[:], randomBytes[:8])
- session.LocalSessionID = localSession
-
- localKey, err := newKeySource()
- if err != nil {
- return session, err
- }
-
- k, err := session.ActiveKey()
- if err != nil {
- return session, err
- }
- k.local = localKey
- return session, nil
-}
-
-// ActiveKey returns the dataChannelKey that is actively being used.
-func (s *session) ActiveKey() (*dataChannelKey, error) {
- if len(s.keys) < s.keyID {
- return nil, fmt.Errorf("%w: %s", errDataChannelKey, "no such key id")
- }
- dck := s.keys[s.keyID]
- return dck, nil
-}
-
-// localPacketID returns an unique Packet ID. It increments the counter.
-// In the future, this call could detect (or warn us) when we're approaching
-// the key end of life.
-func (s *session) LocalPacketID() (packetID, error) {
- s.mu.Lock()
- defer s.mu.Unlock()
- pid := s.localPacketID
- if pid == math.MaxUint32 {
- // we reached the max packetID, increment will overflow
- return 0, errExpiredKey
- }
- s.localPacketID++
- return pid, nil
-}
-
-// UpdateLastACK will update the internal variable for the last acknowledged
-// packet to the passed packetID, only if packetID is greater than the lastACK.
-func (s *session) UpdateLastACK(newPacketID packetID) error {
- s.mu.Lock()
- defer s.mu.Unlock()
- if s.lastACK == math.MaxUint32 {
- return errExpiredKey
- }
- if s.lastACK != 0 && newPacketID <= s.lastACK {
- logger.Warnf("tried to write ack %d; last was %d", newPacketID, s.lastACK)
- }
- s.lastACK = newPacketID
- return nil
-}
-
-// isNextPacket returns true if the packetID is the next integer
-// from the last acknowledged packet.
-func (s *session) isNextPacket(p *packet) bool {
- s.mu.Lock()
- defer s.mu.Unlock()
- if p == nil {
- return false
- }
- return p.id-s.lastACK == 1
-}
-
-// control implements the controlHandler interface.
-// Like for true pirates, there is no state in control.
-type control struct{}
-
-// SendHardReset sends a control packet with the HardResetClientv2 header,
-// over the passed net.Conn.
-func (c *control) SendHardReset(conn net.Conn, s *session) error {
- _, err := sendControlPacket(conn, s, pControlHardResetClientV2, 0, []byte(""))
- return err
-}
-
-// ParseHardReset extracts the sessionID from a hard-reset server response, and
-// an error if the operation was not successful.
-func (c *control) ParseHardReset(b []byte) (sessionID, error) {
- p, err := newServerHardReset(b)
- if err != nil {
- return sessionID{}, err
- }
- return parseServerHardResetPacket(p)
-}
-
-// PushRequest returns a byte array with the PUSH_REQUEST command.
-func (c *control) PushRequest() []byte {
- var out bytes.Buffer
- out.Write([]byte("PUSH_REQUEST"))
- out.WriteByte(0x00)
- return out.Bytes()
-}
-
-// ReadReadPushResponse reads a byte array returned from the server,
-// as the response to a Push Request, and returns a string containing the
-// tunnel IP.
-// For now, this is a single string containing _only_ the tunnel ip,
-// but we might want to pass a pointer to the tunnel struct in the
-// future.
-func (*control) ReadPushResponse(b []byte) map[string][]string {
- return pushedOptionsAsMap(b)
-}
-
-// ControlMessage returns a byte array containing a message over the control
-// channel.
-// This is not a P_CONTROL, but a message over the TLS encrypted channel.
-func (c *control) ControlMessage(s *session, opt *Options) ([]byte, error) {
- key, err := s.ActiveKey()
- if err != nil {
- return []byte{}, err
- }
- return encodeClientControlMessageAsBytes(key.local, opt)
-}
-
-// ReadControlMessage reads a control message with authentication result data.
-// it returns the remote key, remote options and an error if we cannot parse
-// the data.
-func (c *control) ReadControlMessage(b []byte) (*keySource, string, error) {
- cm := newServerControlMessageFromBytes(b)
- return parseServerControlMessage(cm)
-}
-
-// SendACK builds an ACK control packet for the given packetID, and writes it
-// over the passed connection. It returns an error if the operation cannot be
-// completed successfully.
-func (c *control) SendACK(conn net.Conn, s *session, pid packetID) error {
- return sendACKFn(conn, s, pid)
-}
-
-// sendACK is used by controlHandler.SendACK() and by TLSConn.Read()
-func sendACK(conn net.Conn, s *session, pid packetID) error {
- panicIfFalse(len(s.RemoteSessionID) != 0, "tried to ack with null remote")
-
- p := newACKPacket(pid, s)
- payload := p.Bytes()
- payload = maybeAddSizeFrame(conn, payload)
-
- _, err := conn.Write(payload)
- if err != nil {
- return err
- }
-
- logger.Debug(fmt.Sprintln("write ack:", pid))
- logger.Debug(fmt.Sprintln(hex.Dump(payload)))
-
- return s.UpdateLastACK(pid)
-}
-
-var sendACKFn = sendACK
-
-var _ controlHandler = &control{} // Ensure that we implement controlHandler
-
-// sendControlPacket crafts a control packet with the given opcode and payload,
-// and writes it to the passed net.Conn.
-func sendControlPacket(conn net.Conn, s *session, opcode int, ack int, payload []byte) (n int, err error) {
- if s == nil {
- return 0, fmt.Errorf("%w:%s", errBadInput, "nil session")
- }
- p := newPacketFromPayload(uint8(opcode), 0, payload)
- p.localSessionID = s.LocalSessionID
-
- p.id, err = s.LocalPacketID()
- if err != nil {
- return 0, err
- }
- out := p.Bytes()
-
- out = maybeAddSizeFrame(conn, out)
-
- logger.Debug(fmt.Sprintf("control write: (%d bytes)\n", len(out)))
- logger.Debug(fmt.Sprintln(hex.Dump(out)))
- return conn.Write(out)
-}
-
-// isControlMessage returns a boolean indicating whether the header of a
-// payload indicates a control message.
-func isControlMessage(b []byte) bool {
- if len(b) < 4 {
- return false
- }
- return bytes.Equal(b[:4], controlMessageHeader)
-}
-
-// maybeAddSizeFrame prepends a two-byte header containing the size of the
-// payload if the network type for the passed net.Conn is not UDP (assumed to
-// be TCP).
-func maybeAddSizeFrame(conn net.Conn, payload []byte) []byte {
- switch conn.LocalAddr().Network() {
- case "udp", "udp4", "udp6":
- // nothing to do for UDP
- return payload
- case "tcp", "tcp4", "tcp6":
- length := make([]byte, 2)
- binary.BigEndian.PutUint16(length, uint16(len(payload)))
- return append(length, payload...)
- default:
- return []byte{}
- }
-}
-
-// isBadAuthReply returns true if the passed payload is a "bad auth" server
-// response; false otherwise.
-func isBadAuthReply(b []byte) bool {
- l := len(serverBadAuth)
- if len(b) < l {
- return false
- }
- return bytes.Equal(b[:l], serverBadAuth)
-}
-
-// isPushReply returns true if the passed payload is a "push reply" server
-// response; false otherwise.
-func isPushReply(b []byte) bool {
- l := len(serverPushReply)
- if len(b) < l {
- return false
- }
- return bytes.Equal(b[:l], serverPushReply)
-}
diff --git a/vpn/control_test.go b/vpn/control_test.go
deleted file mode 100644
index 56653bb4..00000000
--- a/vpn/control_test.go
+++ /dev/null
@@ -1,286 +0,0 @@
-package vpn
-
-import (
- "bytes"
- "errors"
- "math"
- "net"
- "reflect"
- "testing"
-
- "github.com/ooni/minivpn/vpn/mocks"
-)
-
-func Test_newSession(t *testing.T) {
- tests := []struct {
- name string
- want *session
- wantErr bool
- }{
- {"get session", &session{}, false},
- }
- // TODO(ainghazal): get smarter and use test values (turn sesion into an interface).
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- _, err := newSession()
- if (err != nil) != tt.wantErr {
- t.Errorf("newSession() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- })
- }
-}
-
-func Test_maybeAddSizeFrame(t *testing.T) {
-
- type args struct {
- conn net.Conn
- payload []byte
- }
- tests := []struct {
- name string
- args args
- want []byte
- }{
-
- // FIXME ---- fix these tests ---
- /*
- {
- name: "udp",
- args: args{
- makeTestinConnFromNetwork("udp"),
- []byte{0xff, 0xfe, 0xfd},
- },
- want: []byte{0xff, 0xfe, 0xfd},
- },
- {
- name: "tcp",
- args: args{
- makeTestinConnFromNetwork("udp"),
- []byte{0xff, 0xfe, 0xfd},
- },
- want: []byte{0x00, 0x03, 0xff, 0xfe, 0xfd},
- },
- */
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if got := maybeAddSizeFrame(tt.args.conn, tt.args.payload); !reflect.DeepEqual(got, tt.want) {
- t.Errorf("maybeAddSizeFrame() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func Test_session_ActiveKey(t *testing.T) {
- s := &session{
- keys: make([]*dataChannelKey, 2),
- keyID: 10,
- }
- _, err := s.ActiveKey()
- wantErr := errDataChannelKey
- if !errors.Is(err, wantErr) {
- t.Errorf("session.ActiveKey() = got err %v, want %v", err, wantErr)
- }
-}
-
-func Test_session_LocalPacketID(t *testing.T) {
- type fields struct {
- RemoteSessionID sessionID
- LocalSessionID sessionID
- keys []*dataChannelKey
- keyID int
- localPacketID packetID
- lastACK packetID
- ackQueue chan *packet
- }
-
- tests := []struct {
- name string
- fields fields
- want packetID
- wantErr error
- }{
- {
- "return arbitrary value",
- fields{localPacketID: packetID(42)},
- packetID(42),
- nil,
- },
- {
- "return zero",
- fields{localPacketID: packetID(0)},
- packetID(0),
- nil,
- },
- {
- "overflow",
- fields{localPacketID: math.MaxUint32},
- packetID(0),
- errExpiredKey,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- s := &session{
- localPacketID: tt.fields.localPacketID,
- }
- if got, err := s.LocalPacketID(); got != tt.want || err != tt.wantErr {
- t.Errorf("session.LocalPacketID() = %v, want %v", got, tt.want)
- }
- })
- }
-
- // increments
- val := packetID(1000)
- s := &session{localPacketID: packetID(1000)}
-
- if got, _ := s.LocalPacketID(); got != val {
- t.Errorf("session.LocalPacketID() = %v, want %v", got, val)
- }
- val++
- if got, _ := s.LocalPacketID(); got != val {
- t.Errorf("session.LocalPacketID() = %v, want %v", got, val)
- }
- val++
- if got, _ := s.LocalPacketID(); got != val {
- t.Errorf("session.LocalPacketID() = %v, want %v", got, val)
- }
-}
-
-func Test_session_isNextPacket(t *testing.T) {
- type fields struct {
- lastACK packetID
- }
- type args struct {
- p *packet
- }
- tests := []struct {
- name string
- fields fields
- args args
- want bool
- }{
- {
- "is next",
- fields{lastACK: packetID(0)},
- args{&packet{id: packetID(1)}},
- true,
- },
- {
- "is two more",
- fields{lastACK: packetID(0)},
- args{&packet{id: packetID(2)}},
- false,
- },
- {
- "is lesser",
- fields{lastACK: packetID(100)},
- args{&packet{id: packetID(99)}},
- false,
- },
- {
- "is nil",
- fields{lastACK: packetID(100)},
- args{nil},
- false,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- s := &session{
- lastACK: tt.fields.lastACK,
- }
- if got := s.isNextPacket(tt.args.p); got != tt.want {
- t.Errorf("session.isNextPacket() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func Test_isBadAuthReply(t *testing.T) {
- type args struct {
- b []byte
- }
- tests := []struct {
- name string
- args args
- want bool
- }{
- {"bad_auth", args{[]byte("AUTH_FAILED")}, true},
- {"too_large", args{[]byte("AUTH_FAILEDAAAAAA")}, true},
- {"too_short", args{[]byte("AAA")}, false},
- {"empty", args{[]byte("")}, false},
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if got := isBadAuthReply(tt.args.b); got != tt.want {
- t.Errorf("isBadAuthReply() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func Test_isPushReply(t *testing.T) {
- type args struct {
- b []byte
- }
- tests := []struct {
- name string
- args args
- want bool
- }{
- {"push_reply", args{serverPushReply}, true},
- {"too_large", args{[]byte("PUSH_REPLYAAA")}, true},
- {"too_short", args{[]byte("AAA")}, false},
- {"empty", args{[]byte("")}, false},
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if got := isPushReply(tt.args.b); got != tt.want {
- t.Errorf("isPushReply() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func Test_sendControlPacket(t *testing.T) {
- _, err := sendControlPacket(&mocks.Conn{}, nil, 1, 1, []byte(""))
- wantErr := errBadInput
- if !errors.Is(err, wantErr) {
- t.Errorf("sendControlPacket(): empty session should fail with err=%v, got=%v", wantErr, err)
- }
-}
-
-func Test_isControlMessage(t *testing.T) {
- type args struct {
- b []byte
- }
- tests := []struct {
- name string
- args args
- want bool
- }{
- {"good_control", args{controlMessageHeader}, true},
- {"bad_control", args{[]byte{0x00, 0x00, 0x00, 0x01}}, false},
- {"too_short", args{[]byte{0x00}}, false},
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if got := isControlMessage(tt.args.b); got != tt.want {
- t.Errorf("isControlMessage() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func Test_control_PushRequest(t *testing.T) {
- c := &control{}
- got := c.PushRequest()
- if !bytes.Equal(got[:len(got)-1], []byte("PUSH_REQUEST")) {
- t.Errorf("control_PushRequest() = %v", got)
- }
- if got[len(got)-1] != 0x00 {
- t.Errorf("control_PushRequest(): expected trailing null byte")
- }
-}
diff --git a/vpn/crypto.go b/vpn/crypto.go
deleted file mode 100644
index 93b64076..00000000
--- a/vpn/crypto.go
+++ /dev/null
@@ -1,373 +0,0 @@
-package vpn
-
-//
-// Code to perform encryption, decryption and key derivation.
-//
-
-import (
- "crypto/aes"
- "crypto/cipher"
- "crypto/hmac"
- "crypto/md5"
- "crypto/sha1"
- "crypto/sha256"
- "crypto/sha512"
- "errors"
- "fmt"
- "hash"
- "log"
-) //#nosec G501,G505
-
-// TODO(ainghazal,bassosimone): see if it's feasible to use stdlib
-// functionality rather than using the code below.
-
-type (
- // cipherMode describes a cipher mode (e.g., GCM).
- cipherMode string
-
- // cipherName is a cipher name (e.g., AES).
- cipherName string
-)
-
-const (
- // cipherModeCBC is the CBC cipher mode.
- cipherModeCBC = cipherMode("cbc")
-
- // cipherModeGCM is the GCM cipher mode.
- cipherModeGCM = cipherMode("gcm")
-
- // cipherNameAES is an AES-based cipher.
- cipherNameAES = cipherName("aes")
-)
-
-var (
- // errInvalidKeySize means that the key size is invalid.
- errInvalidKeySize = errors.New("invalid key size")
-
- // errUnsupportedCipher indicates we don't support the desired cipher.
- errUnsupportedCipher = errors.New("unsupported cipher")
-
- // errUnsupportedMode indicates that the mode is not uspported.
- errUnsupportedMode = errors.New("unsupported mode")
-
- // errBadInput indicates invalid inputs to encrypt/decrypt functions.
- errBadInput = errors.New("bad input")
-)
-
-// encrypteData holds the different parts needed to decrypt an encrypted data
-// packet.
-type encryptedData struct {
- iv []byte
- ciphertext []byte
- aead []byte
-}
-
-// plaintextData holds the different parts needed to encrypt a plaintext
-// payload (after padding).
-type plaintextData struct {
- iv []byte
- plaintext []byte
- aead []byte
-}
-
-// dataCipher encrypts and decrypts OpenVPN data.
-type dataCipher interface {
- // keySizeBytes returns the key size (in bytes).
- keySizeBytes() int
-
- // isAEAD returns whether this cipher has AEAD properties.
- isAEAD() bool
-
- // blockSize returns the expected block size.
- blockSize() uint8
-
- // encrypt encripts a plaintext.
- //
- // Arguments:
- //
- // - key is the key, whose size must be consistent with the cipher;
- //
- // - plaintextData is the data to be encrypted;
- //
- // Returns the ciphertext on success and an error on failure.
- encrypt([]byte, *plaintextData) ([]byte, error)
-
- // decrypt is the opposite operation of encrypt. It takes in input the
- // ciphertext and returns the plaintext of an error.
- decrypt([]byte, *encryptedData) ([]byte, error)
-
- // mode returns the cipherMode
- cipherMode() cipherMode
-}
-
-// dataCipherAES implements dataCipher for AES.
-type dataCipherAES struct {
- // ksb is the key size in bytes
- ksb int
-
- // mode is the cipher mode
- mode cipherMode
-}
-
-var _ dataCipher = &dataCipherAES{} // Ensure we implement dataCipher
-
-// keySizeBytes implements dataCipher.keySizeBytes
-func (a *dataCipherAES) keySizeBytes() int {
- return a.ksb
-}
-
-// isAEAD implements dataCipher.isAEAD
-func (a *dataCipherAES) isAEAD() bool {
- return a.mode != cipherModeCBC
-}
-
-// blockSize implements dataCipher.BlockSize
-func (a *dataCipherAES) blockSize() uint8 {
- switch a.mode {
- case cipherModeCBC, cipherModeGCM:
- return 16
- default:
- return 0
- }
-}
-
-// decrypt implements dataCipher.decrypt.
-// Since key comes from a prf derivation, we only take as many bytes as we need to match
-// our key size.
-func (a *dataCipherAES) decrypt(key []byte, data *encryptedData) ([]byte, error) {
- if len(key) < a.keySizeBytes() {
- return nil, errInvalidKeySize
- }
-
- // they key material might be longer
- k := key[:a.keySizeBytes()]
- block, err := aes.NewCipher(k)
- if err != nil {
- return nil, err
- }
- switch a.mode {
- case cipherModeCBC:
- if len(data.iv) != block.BlockSize() {
- return nil, fmt.Errorf("%w: wrong size for iv: %v", errCannotDecrypt, len(data.iv))
- }
- mode := cipher.NewCBCDecrypter(block, data.iv)
- plaintext := make([]byte, len(data.ciphertext))
- mode.CryptBlocks(plaintext, data.ciphertext)
- plaintext, err := bytesUnpadPKCS7(plaintext, block.BlockSize())
- if err != nil {
- return nil, err
- }
- padLen := len(data.ciphertext) - len(plaintext)
- if padLen > block.BlockSize() || padLen > len(plaintext) {
- // TODO(bassosimone, ainghazal): discuss the cases in which
- // this set of conditions actually occurs.
- // TODO(ainghazal): this assertion might actually be moved into a
- // boundary assertion in the unpad fun.
- return nil, errors.New("unpadding error")
- }
- return plaintext, nil
-
- case cipherModeGCM:
- // standard nonce size is 12. more is surely ok, but let's stick to it.
- // https://github.com/golang/go/blob/master/src/crypto/aes/aes_gcm.go#L37
- if len(data.iv) != 12 {
- return nil, fmt.Errorf("%w: wrong size for iv: %v", errCannotDecrypt, len(data.iv))
- }
- aesGCM, err := cipher.NewGCM(block)
- if err != nil {
- return nil, err
- }
-
- plaintext, err := aesGCM.Open(nil, data.iv, data.ciphertext, data.aead)
- if err != nil {
- log.Println("gdm decryption failed:", err.Error())
- /*
- log.Println("dump begins----")
- log.Println("len:", len(data.ciphertext))
- log.Println("iv:", data.iv)
- log.Printf("%v\n", data.ciphertext)
- log.Printf("%x\n", data.ciphertext)
- log.Printf("aead: %x\n", data.aead)
- log.Println("dump ends------")
- */
- return nil, err
- }
- return plaintext, nil
-
- default:
- return nil, errUnsupportedMode
- }
-}
-
-func (a *dataCipherAES) cipherMode() cipherMode {
- return a.mode
-}
-
-// encrypt implements dataCipher.encrypt
-// Since key comes from a prf derivation, we only take as many bytes as we need to match
-// our key size.
-func (a *dataCipherAES) encrypt(key []byte, data *plaintextData) ([]byte, error) {
- if len(key) < a.keySizeBytes() {
- return nil, errInvalidKeySize
- }
- k := key[:a.keySizeBytes()]
- block, err := aes.NewCipher(k)
- if err != nil {
- return nil, err
- }
- blockSize := block.BlockSize()
- switch a.mode {
- case cipherModeCBC:
- if len(data.iv) != blockSize {
- return []byte{}, fmt.Errorf("%w: wrong size for iv: %v", errCannotEncrypt, len(data.iv))
- }
- if len(data.plaintext)%blockSize != 0 {
- return []byte{}, fmt.Errorf("%w: wrong padding", errCannotEncrypt)
- }
- mode := cipher.NewCBCEncrypter(block, data.iv)
-
- ciphertext := make([]byte, len(data.plaintext))
- mode.CryptBlocks(ciphertext, data.plaintext)
- return ciphertext, nil
-
- case cipherModeGCM:
- if len(data.iv) != 12 {
- return []byte{}, fmt.Errorf("%w: wrong size for iv: %v", errCannotEncrypt, len(data.iv))
- }
- aesGCM, err := cipher.NewGCM(block)
- if err != nil {
- return nil, err
- }
- // In GCM mode, the IV consists of the 32-bit packet counter
- // followed by data from the HMAC key. The HMAC key can be used
- // as IV, since in GCM mode the HMAC key is not used for the
- // HMAC. The packet counter may not roll over within a single
- // TLS session. This results in a unique IV for each packet, as
- // required by GCM.
- ciphertext := aesGCM.Seal(nil, data.iv, data.plaintext, data.aead)
- return ciphertext, nil
-
- default:
- return nil, errUnsupportedMode
- }
-}
-
-// newDataCipherFromCipherSuite constructs a new dataCipher from the cipher suite string.
-func newDataCipherFromCipherSuite(c string) (dataCipher, error) {
- switch c {
- case "AES-128-CBC":
- return newDataCipher(cipherNameAES, 128, cipherModeCBC)
- case "AES-192-CBC":
- return newDataCipher(cipherNameAES, 192, cipherModeCBC)
- case "AES-256-CBC":
- return newDataCipher(cipherNameAES, 256, cipherModeCBC)
- case "AES-128-GCM":
- return newDataCipher(cipherNameAES, 128, cipherModeGCM)
- case "AES-256-GCM":
- return newDataCipher(cipherNameAES, 256, cipherModeGCM)
- default:
- return nil, errUnsupportedCipher
- }
-}
-
-// newDataCipher constructs a new dataCipher from the given name, bits, and mode.
-func newDataCipher(name cipherName, bits int, mode cipherMode) (dataCipher, error) {
- if bits%8 != 0 || bits > 512 || bits < 64 {
- return nil, fmt.Errorf("%w: %d", errInvalidKeySize, bits)
- }
- switch name {
- case cipherNameAES:
- default:
- return nil, fmt.Errorf("%w: %s", errUnsupportedCipher, name)
- }
- switch mode {
- case cipherModeCBC, cipherModeGCM:
- default:
- return nil, fmt.Errorf("%w: %s", errUnsupportedMode, mode)
- }
- dc := &dataCipherAES{
- ksb: bits / 8,
- mode: mode,
- }
- return dc, nil
-}
-
-// newHMACFactory accepts a label coming from an OpenVPN auth label, and returns two
-// values: a function that will return a Hash implementation, and a boolean
-// indicating if the operation was successful.
-func newHMACFactory(name string) (func() hash.Hash, bool) {
- switch name {
- case "sha1":
- return sha1.New, true
- case "sha256":
- return sha256.New, true
- case "sha512":
- return sha512.New, true
- default:
- return nil, false
- }
-}
-
-// prf function is used to derive master and client keys
-func prf(secret, label, clientSeed, serverSeed, clientSid, serverSid []byte, olen int) []byte {
- seed := append(clientSeed, serverSeed...)
- if len(clientSid) != 0 {
- seed = append(seed, clientSid...)
- }
- if len(serverSid) != 0 {
- seed = append(seed, serverSid...)
- }
- result := make([]byte, olen)
- return prf10(result, secret, label, seed)
-}
-
-// Code below is taken from crypto/tls/prf.go
-// Copyright 2009 The Go Authors. All rights reserved.
-// SPDX-License-Identifier: BSD-3-Clause
-// prf10 implements the TLS 1.0 pseudo-random function, as defined in RFC 2246, Section 5.
-func prf10(result, secret, label, seed []byte) []byte {
- hashSHA1 := sha1.New
- hashMD5 := md5.New
-
- labelAndSeed := make([]byte, len(label)+len(seed))
- copy(labelAndSeed, label)
- copy(labelAndSeed[len(label):], seed)
-
- s1, s2 := splitPreMasterSecret(secret)
- pHash(result, s1, labelAndSeed, hashMD5)
- result2 := make([]byte, len(result))
- pHash(result2, s2, labelAndSeed, hashSHA1)
- for i, b := range result2 {
- result[i] ^= b
- }
- return result
-}
-
-// SPDX-License-Identifier: BSD-3-Clause
-// Split a premaster secret in two as specified in RFC 4346, Section 5.
-func splitPreMasterSecret(secret []byte) (s1, s2 []byte) {
- s1 = secret[0 : (len(secret)+1)/2]
- s2 = secret[len(secret)/2:]
- return
-
-}
-
-// SPDX-License-Identifier: BSD-3-Clause
-// pHash implements the P_hash function, as defined in RFC 4346, Section 5.
-func pHash(result, secret, seed []byte, hash func() hash.Hash) {
- h := hmac.New(hash, secret)
- h.Write(seed)
- a := h.Sum(nil)
- j := 0
- for j < len(result) {
- h.Reset()
- h.Write(a)
- h.Write(seed)
- b := h.Sum(nil)
- copy(result[j:], b)
- j += len(b)
- h.Reset()
- h.Write(a)
- a = h.Sum(nil)
- }
-}
diff --git a/vpn/crypto_test.go b/vpn/crypto_test.go
deleted file mode 100644
index 5c22d447..00000000
--- a/vpn/crypto_test.go
+++ /dev/null
@@ -1,405 +0,0 @@
-package vpn
-
-import (
- "bytes"
- "crypto/sha1"
- "crypto/sha256"
- "crypto/sha512"
- "encoding/hex"
- "errors"
- "hash"
- "log"
- "reflect"
- "testing"
-)
-
-func TestDataCipherAES(t *testing.T) {
- _, err := newDataCipher("aes", 128, "cbc")
- if err != nil {
- t.Errorf("Cannot instantiate aes-128-cbc")
- }
-}
-
-func TestBadCipher(t *testing.T) {
- _, err := newDataCipher("bad", 128, "cbc")
- if err == nil {
- t.Errorf("Should fail with bad cipher")
- }
-}
-
-func TestBadMode(t *testing.T) {
- _, err := newDataCipher("aes", 128, "bad")
- if err == nil {
- t.Errorf("Should fail with bad mode")
- }
-}
-
-func TestBadKeySize(t *testing.T) {
- _, err := newDataCipher("aes", 1024, "cbc")
- if err == nil {
- t.Errorf("Should fail with bad key size")
- }
- _, err = newDataCipher("aes", 8, "cbc")
- if err == nil {
- t.Errorf("Should fail with bad key size")
- }
-}
-
-func Test_newDataCipherFromCipherSuite(t *testing.T) {
- type args struct {
- c string
- }
- tests := []struct {
- name string
- args args
- want dataCipher
- wantErr bool
- }{
- {"aes-128-cbc", args{"AES-128-CBC"}, &dataCipherAES{16, "cbc"}, false},
- {"aes-192-cbc", args{"AES-192-CBC"}, &dataCipherAES{24, "cbc"}, false},
- {"aes-256-cbc", args{"AES-256-CBC"}, &dataCipherAES{32, "cbc"}, false},
- {"aes-128-gcm", args{"AES-128-GCM"}, &dataCipherAES{16, "gcm"}, false},
- {"aes-256-gcm", args{"AES-256-GCM"}, &dataCipherAES{32, "gcm"}, false},
- {"bad-256-gcm", args{"AES-512-GCM"}, nil, true},
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got, err := newDataCipherFromCipherSuite(tt.args.c)
- if (err != nil) != tt.wantErr {
- t.Errorf("newCipherFromCipherSuite() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if !reflect.DeepEqual(got, tt.want) {
- t.Errorf("newCipherFromCipherSuite() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func Test_newCipher(t *testing.T) {
- type args struct {
- name cipherName
- bits int
- mode cipherMode
- }
- tests := []struct {
- name string
- args args
- want dataCipher
- wantErr bool
- }{
- {"aesOK", args{"aes", 256, "cbc"}, &dataCipherAES{32, "cbc"}, false},
- {"badCipher", args{"blowfish", 256, "cbc"}, nil, true},
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got, err := newDataCipher(tt.args.name, tt.args.bits, tt.args.mode)
- if (err != nil) != tt.wantErr {
- t.Errorf("newCipher() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if !reflect.DeepEqual(got, tt.want) {
- t.Errorf("newCipher() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-// this particular test is basically equivalent to reimplementing the factory, but okay,
-// it's somehow useful to catch allowed values.
-func Test_newHMACFactory(t *testing.T) {
- type args struct {
- name string
- }
- tests := []struct {
- name string
- args args
- want func() hash.Hash
- want1 bool
- }{
- {"sha1", args{"sha1"}, sha1.New, true},
- {"sha256", args{"sha256"}, sha256.New, true},
- {"sha512", args{"sha512"}, sha512.New, true},
- {"shabad", args{"sha192"}, nil, false},
- }
-
- //str := "hello"
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got, got1 := newHMACFactory(tt.args.name)
- if got == nil {
- if !reflect.DeepEqual(got, tt.want) {
- t.Errorf("newHMACFactory() got = %v, want %v", &got, &tt.want)
- }
- if got1 != tt.want1 {
- t.Errorf("newHMACFactory() got1 = %v, want %v", got1, tt.want1)
- }
- } else {
- // it is a function factory, so let's get the function to compare
- if !reflect.DeepEqual(got(), tt.want()) {
- t.Errorf("newHMACFactory() got = %v, want %v", &got, &tt.want)
- }
- if got1 != tt.want1 {
- t.Errorf("newHMACFactory() got1 = %v, want %v", got1, tt.want1)
- }
- }
- })
- }
-}
-
-func TestPrf(t *testing.T) {
- expected := []byte{
- 0x67, 0x18, 0x7c, 0x52, 0xac, 0xd2, 0x4d, 0x95,
- 0x9a, 0x55, 0xd3, 0x1c, 0xdb, 0x97, 0x80, 0x11}
- secret := []byte("secret")
- label := []byte("master key")
- cseed := []byte("aaa")
- sseed := []byte("bbb")
- out := prf(secret, label, cseed, sseed, []byte{}, []byte{}, 16)
- if !bytes.Equal(out, expected) {
- t.Errorf("Bad output in prf call: %v", out)
- }
-}
-
-func Test_dataCipherAES_decrypt(t *testing.T) {
- key := bytes.Repeat([]byte("A"), 64)
- iv12, _ := hex.DecodeString("000000006868686868686868")
- iv16, _ := hex.DecodeString("00000000686868686868686865656565")
- ciphertextGCM, _ := hex.DecodeString("a949df311c57ec762428a7ba98d1d0d8213134925bf1cd2cb4ab4ea9066c569b0579")
- ciphertextCBC, _ := hex.DecodeString("f908ff8dedbe4e2097c992c67e603d25606c76a460cd785503cf0a2a9e6ec961")
-
- type fields struct {
- ksb int
- mode cipherMode
- }
- type args struct {
- key []byte
- data *encryptedData
- }
- tests := []struct {
- name string
- fields fields
- args args
- want []byte
- wantErr error
- }{
- {
- name: "good decrypt gcm",
- fields: fields{
- ksb: 16,
- mode: cipherModeGCM,
- },
- args: args{
- key: key,
- data: &encryptedData{
- iv: iv12,
- ciphertext: ciphertextGCM,
- aead: []byte{0x00, 0x01, 0x02, 0x03},
- },
- },
- want: []byte("this test is green"),
- wantErr: nil,
- },
- {
- name: "iv too short gcm",
- fields: fields{
- ksb: 16,
- mode: cipherModeGCM,
- },
- args: args{
- key: key,
- data: &encryptedData{
- iv: []byte{0x00},
- ciphertext: ciphertextGCM,
- aead: []byte{0x00, 0x01, 0x02, 0x03},
- },
- },
- want: nil,
- wantErr: errCannotDecrypt,
- },
- {
- name: "good decrypt cbc",
- fields: fields{
- ksb: 16,
- mode: cipherModeCBC,
- },
- args: args{
- key: key,
- data: &encryptedData{
- iv: iv16,
- ciphertext: ciphertextCBC,
- aead: []byte{0x00, 0x01, 0x02, 0x03},
- },
- },
- want: []byte("this test is green"),
- wantErr: nil,
- },
- {
- name: "iv too short cbc",
- fields: fields{
- ksb: 16,
- mode: cipherModeGCM,
- },
- args: args{
- key: key,
- data: &encryptedData{
- iv: []byte{0x00},
- ciphertext: ciphertextGCM,
- aead: []byte{0x00, 0x01, 0x02, 0x03},
- },
- },
- want: []byte{},
- wantErr: errCannotDecrypt,
- },
- // TODO: Add moar test cases, with failing.
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- a := &dataCipherAES{
- ksb: tt.fields.ksb,
- mode: tt.fields.mode,
- }
- got, err := a.decrypt(tt.args.key, tt.args.data)
- if !errors.Is(err, tt.wantErr) {
- t.Errorf("dataCipherAES.decrypt() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if !bytes.Equal(got, tt.want) {
- t.Errorf("dataCipherAES.decrypt() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func doPaddingForTest(payload []byte, blockSize int) []byte {
- padded, _ := bytesPadPKCS7(payload, blockSize)
- return padded
-}
-
-func Test_dataCipherAES_encrypt(t *testing.T) {
- key := bytes.Repeat([]byte("A"), 64)
- iv12, _ := hex.DecodeString("000000006868686868686868")
- iv16, _ := hex.DecodeString("00000000686868686868686865656565")
-
- ciphertextGCM, _ := hex.DecodeString("a949df311c57ec762428a7ba98d1d0d8213134925bf1cd2cb4ab4ea9066c569b0579")
- ciphertextCBC, _ := hex.DecodeString("f908ff8dedbe4e2097c992c67e603d25606c76a460cd785503cf0a2a9e6ec961")
-
- type fields struct {
- ksb int
- mode cipherMode
- }
- type args struct {
- key []byte
- data *plaintextData
- }
- tests := []struct {
- name string
- fields fields
- args args
- want []byte
- wantErr error
- }{
- {
- name: "good encrypt aes-128-gcm",
- fields: fields{
- ksb: 16,
- mode: cipherModeGCM,
- },
- args: args{
- key: key,
- data: &plaintextData{
- iv: iv12,
- plaintext: []byte("this test is green"),
- aead: []byte{0x00, 0x01, 0x02, 0x03},
- },
- },
- want: ciphertextGCM,
- wantErr: nil,
- },
- {
- name: "iv too short aes-128-gcm",
- fields: fields{
- ksb: 16,
- mode: cipherModeGCM,
- },
- args: args{
- key: key,
- data: &plaintextData{
- iv: []byte{0x00},
- plaintext: []byte("should fail"),
- aead: []byte{0x00, 0x01, 0x02, 0x03},
- },
- },
- want: []byte(""),
- wantErr: errCannotEncrypt,
- },
- {
- name: "iv too short aes-128-cbc",
- fields: fields{
- ksb: 16,
- mode: cipherModeCBC,
- },
- args: args{
- key: key,
- data: &plaintextData{
- iv: iv12,
- plaintext: []byte("should fail"),
- },
- },
- want: []byte(""),
- wantErr: errCannotEncrypt,
- },
- {
- name: "bad padding aes-128-cbc",
- fields: fields{
- ksb: 16,
- mode: cipherModeCBC,
- },
- args: args{
- key: key,
- data: &plaintextData{
- iv: iv16,
- plaintext: []byte("should fail"),
- },
- },
- want: []byte(""),
- wantErr: errCannotEncrypt,
- },
- {
- name: "good encrypt aes-128-cbc",
- fields: fields{
- ksb: 16,
- mode: cipherModeCBC,
- },
- args: args{
- key: key,
- data: &plaintextData{
- iv: iv16,
- plaintext: doPaddingForTest([]byte("this test is green"), 16),
- },
- },
- want: ciphertextCBC,
- wantErr: nil,
- },
- // TODO: Add moar test cases.
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- a := &dataCipherAES{
- ksb: tt.fields.ksb,
- mode: tt.fields.mode,
- }
- got, err := a.encrypt(tt.args.key, tt.args.data)
- if !errors.Is(err, tt.wantErr) {
- t.Errorf("dataCipherAES.encrypt() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if !reflect.DeepEqual(got, tt.want) {
- log.Println(hex.EncodeToString(got))
-
- t.Errorf("dataCipherAES.encrypt() = %v, want %v", got, tt.want)
- }
- })
- }
-}
diff --git a/vpn/data.go b/vpn/data.go
deleted file mode 100644
index 15546ca9..00000000
--- a/vpn/data.go
+++ /dev/null
@@ -1,673 +0,0 @@
-package vpn
-
-//
-// OpenVPN data channel
-//
-
-import (
- "bytes"
- "crypto/hmac"
- "encoding/binary"
- "encoding/hex"
- "errors"
- "fmt"
- "hash"
- "math"
- "net"
- "strings"
- "sync"
-)
-
-var (
- errDataChannelKey = errors.New("bad key")
- errBadCompression = errors.New("bad compression")
- errReplayAttack = errors.New("replay attack")
- errCannotEncrypt = errors.New("cannot encrypt")
- errCannotDecrypt = errors.New("cannot decrypt")
- errBadHMAC = errors.New("bad hmac")
- errInitError = errors.New("improperly initialized")
-)
-
-// keySlot holds the different local and remote keys.
-type keySlot [64]byte
-
-// dataChannelState is the state of the data channel.
-type dataChannelState struct {
- dataCipher dataCipher
- hash func() hash.Hash
- // outgoing and incoming nomenclature is probably more adequate here.
- hmacLocal hash.Hash
- hmacRemote hash.Hash
- remotePacketID packetID
- cipherKeyLocal keySlot
- cipherKeyRemote keySlot
- hmacKeyLocal keySlot
- hmacKeyRemote keySlot
- keyID int // not used at the moment, paving the way for key rotation.
- peerID int
-
- mu sync.Mutex
-}
-
-// SetSetRemotePacketID stores the passed packetID internally.
-func (dcs *dataChannelState) SetRemotePacketID(id packetID) {
- dcs.mu.Lock()
- defer dcs.mu.Unlock()
- dcs.remotePacketID = packetID(id)
-}
-
-// RemotePacketID returns the last known remote packetID. It returns an error
-// if the stored packet id has reached the maximum capacity of the packetID
-// type.
-func (dcs *dataChannelState) RemotePacketID() (packetID, error) {
- dcs.mu.Lock()
- defer dcs.mu.Unlock()
- pid := dcs.remotePacketID
- if pid == math.MaxUint32 {
- // we reached the max packetID, increment will overflow
- return 0, errExpiredKey
- }
- return pid, nil
-}
-
-// dataChannelKey represents a pair of key sources that have been negotiated
-// over the control channel, and from which we will derive local and remote
-// keys for encryption and decrption over the data channel. The index refers to
-// the short key_id that is passed in the lower 3 bits if a packet header.
-// The setup of the keys for a given data channel (that is, for every key_id)
-// is made by expanding the keysources using the prf function.
-// Do note that we are not yet implementing key renegotiation - but the index
-// is provided for convenience when/if we support that in the future.
-type dataChannelKey struct {
- index uint32
- ready bool
- local *keySource
- remote *keySource
- mu sync.Mutex
-}
-
-// addRemoteKey adds the server keySource to our dataChannelKey. This makes the
-// dataChannelKey ready to be used.
-func (dck *dataChannelKey) addRemoteKey(k *keySource) error {
- dck.mu.Lock()
- defer dck.mu.Unlock()
- if dck.ready {
- return fmt.Errorf("%w:%s", errDataChannelKey, "cannot overwrite remote key slot")
- }
- dck.remote = k
- dck.ready = true
- return nil
-}
-
-var (
- randomFn = genRandomBytes
- errRandomBytes = errors.New("Error generating random bytes")
-)
-
-// keySource contains random data to generate keys.
-type keySource struct {
- r1 [32]byte
- r2 [32]byte
- preMaster [48]byte
-}
-
-// Bytes returns the byte representation of a keySource.
-func (k *keySource) Bytes() []byte {
- buf := &bytes.Buffer{}
- buf.Write(k.preMaster[:])
- buf.Write(k.r1[:])
- buf.Write(k.r2[:])
- return buf.Bytes()
-}
-
-// newKeySource returns a keySource and an error.
-func newKeySource() (*keySource, error) {
- random1, err := randomFn(32)
- if err != nil {
- return nil, fmt.Errorf("%w: %s", errRandomBytes, err.Error())
- }
-
- var r1, r2 [32]byte
- var preMaster [48]byte
- copy(r1[:], random1)
-
- random2, err := randomFn(32)
- if err != nil {
- return nil, fmt.Errorf("%w: %s", errRandomBytes, err.Error())
- }
- copy(r2[:], random2)
-
- random3, err := randomFn(48)
- if err != nil {
- return nil, fmt.Errorf("%w: %s", errRandomBytes, err.Error())
- }
- copy(preMaster[:], random3)
- return &keySource{
- r1: r1,
- r2: r2,
- preMaster: preMaster,
- }, nil
-}
-
-// data represents the data "channel", that will encrypt and decrypt the tunnel payloads.
-// data implements the dataHandler interface.
-type data struct {
- options *Options
- session *session
- state *dataChannelState
- decodeFn func([]byte, *dataChannelState) (*encryptedData, error)
- encryptEncodeFn func([]byte, *session, *dataChannelState) ([]byte, error)
- decryptFn func([]byte, *encryptedData) ([]byte, error)
-}
-
-var _ dataHandler = &data{} // Ensure that we implement dataHandler
-
-// newDataFromOptions returns a new data object, initialized with the
-// options given. it also returns any error raised.
-func newDataFromOptions(opt *Options, s *session) (*data, error) {
- if opt == nil || s == nil {
- return nil, fmt.Errorf("%w: %s", errBadInput, "found nil on init")
- }
- if len(opt.Cipher) == 0 || len(opt.Auth) == 0 {
- return nil, fmt.Errorf("%w: %s", errBadInput, "empty options")
- }
- state := &dataChannelState{}
- data := &data{options: opt, session: s, state: state}
-
- logger.Info(fmt.Sprintf("Cipher: %s", opt.Cipher))
-
- dataCipher, err := newDataCipherFromCipherSuite(opt.Cipher)
- if err != nil {
- return data, err
- }
- data.state.dataCipher = dataCipher
- switch dataCipher.isAEAD() {
- case true:
- data.decodeFn = decodeEncryptedPayloadAEAD
- data.encryptEncodeFn = encryptAndEncodePayloadAEAD
- case false:
- data.decodeFn = decodeEncryptedPayloadNonAEAD
- data.encryptEncodeFn = encryptAndEncodePayloadNonAEAD
- }
-
- logger.Info(fmt.Sprintf("Auth: %s", opt.Auth))
-
- hmacHash, ok := newHMACFactory(strings.ToLower(opt.Auth))
- if !ok {
- return data, fmt.Errorf("%w:%s", errBadInput, "no such mac")
- }
- data.state.hash = hmacHash
- data.decryptFn = state.dataCipher.decrypt
-
- return data, nil
-}
-
-// DecodeEncryptedPayload calls the corresponding function for AEAD or Non-AEAD decryption.
-func (d *data) DecodeEncryptedPayload(b []byte, dcs *dataChannelState) (*encryptedData, error) {
- return d.decodeFn(b, dcs)
-}
-
-// SetSetupKeys performs the key expansion from the local and remote
-// keySources, initializing the data channel state.
-func (d *data) SetupKeys(dck *dataChannelKey) error {
- if dck == nil {
- return fmt.Errorf("%w: %s", errBadInput, "nil args")
- }
- if !dck.ready {
- return fmt.Errorf("%w: %s", errDataChannelKey, "key not ready")
- }
- master := prf(
- dck.local.preMaster[:],
- []byte("OpenVPN master secret"),
- dck.local.r1[:],
- dck.remote.r1[:],
- []byte{}, []byte{},
- 48)
-
- keys := prf(
- master,
- []byte("OpenVPN key expansion"),
- dck.local.r2[:],
- dck.remote.r2[:],
- d.session.LocalSessionID[:], d.session.RemoteSessionID[:],
- 256)
-
- var keyLocal, hmacLocal, keyRemote, hmacRemote keySlot
- copy(keyLocal[:], keys[0:64])
- copy(hmacLocal[:], keys[64:128])
- copy(keyRemote[:], keys[128:192])
- copy(hmacRemote[:], keys[192:256])
-
- d.state.cipherKeyLocal = keyLocal
- d.state.hmacKeyLocal = hmacLocal
- d.state.cipherKeyRemote = keyRemote
- d.state.hmacKeyRemote = hmacRemote
-
- logger.Debugf("Cipher key local: %x", keyLocal)
- logger.Debugf("Cipher key remote: %x", keyRemote)
- logger.Debugf("Hmac key local: %x", hmacLocal)
- logger.Debugf("Hmac key remote: %x", hmacRemote)
-
- hashSize := d.state.hash().Size()
- d.state.hmacLocal = hmac.New(d.state.hash, hmacLocal[:hashSize])
- d.state.hmacRemote = hmac.New(d.state.hash, hmacRemote[:hashSize])
-
- logger.Info("Key derivation OK")
- return nil
-}
-
-// SetPeerID updates the data state field with the info sent by the server.
-func (d *data) SetPeerID(i int) error {
- d.state.peerID = i
- return nil
-}
-
-//
-// write + encrypt
-//
-
-// encrypt calls the corresponding function for AEAD or Non-AEAD decryption.
-// Due to the particularities of the iv generation on each of the modes, encryption and encoding are
-// done together in the same function.
-// TODO accept state for symmetry
-func (d *data) EncryptAndEncodePayload(plaintext []byte, dcs *dataChannelState) ([]byte, error) {
- if len(plaintext) == 0 {
- return []byte{}, fmt.Errorf("%w: nothing to encrypt", errCannotEncrypt)
- }
- if dcs == nil || dcs.dataCipher == nil {
- return []byte{}, fmt.Errorf("%w: %s", errCannotEncrypt, fmt.Errorf("data chan not initialized"))
- }
-
- padded, err := doPadding(plaintext, d.options.Compress, dcs.dataCipher.blockSize())
- if err != nil {
- return []byte{}, fmt.Errorf("%w: %s", errCannotEncrypt, err)
- }
-
- encrypted, err := d.encryptEncodeFn(padded, d.session, d.state)
- if err != nil {
- return []byte{}, fmt.Errorf("%w: %s", errCannotEncrypt, err)
- }
- return encrypted, nil
-
-}
-
-// encryptAndEncodePayloadAEAD peforms encryption and encoding of the payload in AEAD modes (i.e., AES-GCM).
-// TODO(ainghazal): for testing we can pass both the state object and the encryptFn
-func encryptAndEncodePayloadAEAD(padded []byte, session *session, state *dataChannelState) ([]byte, error) {
- nextPacketID, err := session.LocalPacketID()
- if err != nil {
- return []byte{}, fmt.Errorf("bad packet id")
- }
-
- // in AEAD mode, we authenticate:
- // - 1 byte: opcode/key
- // - 3 bytes: peer-id (we're using P_DATA_V2)
- // - 4 bytes: packet-id
- aead := &bytes.Buffer{}
- aead.WriteByte(opcodeAndKeyHeader(state))
- bufWriteUint24(aead, uint32(state.peerID))
- bufWriteUint32(aead, uint32(nextPacketID))
-
- // the iv is the packetID (again) concatenated with the 8 bytes of the
- // key derived for local hmac (which we do not use for anything else in AEAD mode).
- iv := &bytes.Buffer{}
- bufWriteUint32(iv, uint32(nextPacketID))
- iv.Write(state.hmacKeyLocal[:8])
-
- data := &plaintextData{
- iv: iv.Bytes(),
- plaintext: padded,
- aead: aead.Bytes(),
- }
-
- encryptFn := state.dataCipher.encrypt
- encrypted, err := encryptFn(state.cipherKeyLocal[:], data)
- if err != nil {
- return []byte{}, err
- }
-
- // some reordering, because openvpn uses tag | payload
- boundary := len(encrypted) - 16
- tag := encrypted[boundary:]
- ciphertext := encrypted[:boundary]
-
- // we now write to the output buffer
- out := bytes.Buffer{}
- out.Write(data.aead) // opcode|peer-id|packet_id
- out.Write(tag)
- out.Write(ciphertext)
- return out.Bytes(), nil
-
-}
-
-// encryptAndEncodePayloadNonAEAD peforms encryption and encoding of the payload in Non-AEAD modes (i.e., AES-CBC).
-func encryptAndEncodePayloadNonAEAD(padded []byte, session *session, state *dataChannelState) ([]byte, error) {
- // For iv generation, OpenVPN uses a nonce-based PRNG that is initially seeded with
- // OpenSSL RAND_bytes function. I am assuming this is good enough for our current purposes.
- blockSize := state.dataCipher.blockSize()
-
- iv, err := randomFn(int(blockSize))
- if err != nil {
- return []byte{}, err
- }
- data := &plaintextData{
- iv: iv,
- plaintext: padded,
- aead: nil,
- }
-
- encryptFn := state.dataCipher.encrypt
- ciphertext, err := encryptFn(state.cipherKeyLocal[:], data)
- if err != nil {
- return []byte{}, err
- }
-
- state.hmacLocal.Reset()
- state.hmacLocal.Write(iv)
- state.hmacLocal.Write(ciphertext)
- computedMAC := state.hmacLocal.Sum(nil)
-
- out := &bytes.Buffer{}
- out.WriteByte(opcodeAndKeyHeader(state))
- bufWriteUint24(out, uint32(state.peerID))
-
- out.Write(computedMAC)
- out.Write(iv)
- out.Write(ciphertext)
- return out.Bytes(), nil
-}
-
-// doCompress adds compression bytes if needed by the passed compression options.
-// if the compression stub is on, it sends the first byte to the last position,
-// and it adds the compression preamble, according to the spec. compression
-// lzo-no also adds a preamble. It returns a byte array and an error if the
-// operation could not be completed.
-func doCompress(b []byte, c compression) ([]byte, error) {
- switch c {
- case "stub":
- // compression stub: send first byte to last
- // and add 0xfb marker on the first byte.
- b = append(b, b[0])
- b[0] = 0xfb
- case "lzo-no":
- // old "comp-lzo no" option
- b = append([]byte{0xfa}, b...)
- }
- return b, nil
-}
-
-// doPadding does pkcs7 padding of the encryption payloads as
-// needed. if we're using the compression stub the padding is applied without taking the
-// trailing bit into account. it returns the resulting byte array, and an error
-// if the operatio could not be completed.
-func doPadding(b []byte, compress compression, blockSize uint8) ([]byte, error) {
- if len(b) == 0 {
- return nil, fmt.Errorf("%w: nothing to pad", errBadInput)
- }
- if compress == "stub" {
- // if we're using the compression stub
- // we need to account for a trailing byte
- // that we have appended in the doCompress stage.
- endByte := b[len(b)-1]
- padded, err := bytesPadPKCS7(b[:len(b)-1], int(blockSize))
- if err != nil {
- return nil, err
- }
- padded[len(padded)-1] = endByte
- return padded, nil
- }
- padded, err := bytesPadPKCS7(b, int(blockSize))
- if err != nil {
- return nil, err
- }
- return padded, nil
-}
-
-// prependPacketID returns the original buffer with the passed packetID
-// concatenated at the beginning.
-func prependPacketID(p packetID, buf []byte) []byte {
- newbuf := &bytes.Buffer{}
- packetID := make([]byte, 4)
- binary.BigEndian.PutUint32(packetID, uint32(p))
- newbuf.Write(packetID[:])
- newbuf.Write(buf)
- return newbuf.Bytes()
-}
-
-func (d *data) WritePacket(conn net.Conn, payload []byte) (int, error) {
- if d.state == nil || d.state.dataCipher == nil {
- return 0, fmt.Errorf("%w: %s", errBadInput, "bad state")
- }
-
- var plain []byte
- var err error
-
- // TODO(ainghazal): separate into two different implementations
- // and get rid of multiple switch.
- switch d.state.dataCipher.isAEAD() {
- case true:
- plain, err = doCompress(payload, d.options.Compress)
- if err != nil {
- return 0, fmt.Errorf("%w: %s", errCannotEncrypt, err)
- }
- case false: // non-aead
- localPacketID, _ := d.session.LocalPacketID()
- plain = prependPacketID(localPacketID, payload)
-
- plain, err = doCompress(plain, d.options.Compress)
- if err != nil {
- return 0, fmt.Errorf("%w: %s", errCannotEncrypt, err)
- }
- }
-
- // encrypted adds padding, if needed, and it also includes the
- // opcode/keyid and peer-id headers and, if used, any authenticated
- // parts in the packet.
- encrypted, err := d.EncryptAndEncodePayload(plain, d.state)
- if err != nil {
- return 0, fmt.Errorf("%w: %s", errCannotEncrypt, err)
- }
-
- // TODO(ainghazal): increment counter for used bytes, and
- // trigger renegotiation if we're near the end of the key useful lifetime.
-
- out := maybeAddSizeFrame(conn, encrypted)
-
- logger.Debug("data: write packet")
- logger.Debugf("\n" + hex.Dump(out))
-
- return conn.Write(out)
-}
-
-//
-// read + decrypt
-//
-
-func (d *data) decrypt(encrypted []byte) ([]byte, error) {
- if d.decryptFn == nil {
- return []byte{}, errInitError
- }
- if len(d.state.hmacKeyRemote) == 0 {
- logger.Error("decrypt: not ready yet")
- return []byte{}, errCannotDecrypt
- }
- encryptedData, err := d.DecodeEncryptedPayload(encrypted, d.state)
-
- if err != nil {
- return []byte{}, fmt.Errorf("%w: %s", errCannotDecrypt, err)
- }
- plainText, err := d.decryptFn(d.state.cipherKeyRemote[:], encryptedData)
- if err != nil {
- return []byte{}, fmt.Errorf("%w: %s", errCannotDecrypt, err)
- }
- return plainText, nil
-}
-
-func decodeEncryptedPayloadAEAD(buf []byte, state *dataChannelState) (*encryptedData, error) {
- // P_DATA_V2 GCM data channel crypto format
- // 48000001 00000005 7e7046bd 444a7e28 cc6387b1 64a4d6c1 380275a...
- // [ OP32 ] [seq # ] [ auth tag ] [ payload ... ]
- // - means authenticated - * means encrypted *
- // [ - opcode/peer-id - ] [ - packet ID - ] [ TAG ] [ * packet payload * ]
-
- // preconditions
-
- if len(buf) == 0 || len(buf) < 20 {
- return &encryptedData{}, fmt.Errorf("too short: %d bytes", len(buf))
- }
- if len(state.hmacKeyRemote) < 8 {
- return &encryptedData{}, fmt.Errorf("bad remote hmac")
- }
- remoteHMAC := state.hmacKeyRemote[:8]
- packet_id := buf[:4]
-
- headers := &bytes.Buffer{}
- headers.WriteByte(opcodeAndKeyHeader(state))
- bufWriteUint24(headers, uint32(state.peerID))
- headers.Write(packet_id)
-
- // we need to swap because decryption expects payload|tag
- // but we've got tag | payload instead
- payload := &bytes.Buffer{}
- payload.Write(buf[20:]) // ciphertext
- payload.Write(buf[4:20]) // tag
-
- // iv := packetID | remoteHMAC
- iv := &bytes.Buffer{}
- iv.Write(packet_id)
- iv.Write(remoteHMAC)
-
- encrypted := &encryptedData{
- iv: iv.Bytes(),
- ciphertext: payload.Bytes(),
- aead: headers.Bytes(),
- }
- return encrypted, nil
-}
-
-func decodeEncryptedPayloadNonAEAD(buf []byte, state *dataChannelState) (*encryptedData, error) {
- if state == nil || state.dataCipher == nil {
- return &encryptedData{}, fmt.Errorf("%w: bad state", errBadInput)
- }
- hashSize := uint8(state.hmacRemote.Size())
- blockSize := state.dataCipher.blockSize()
-
- minLen := hashSize + blockSize
-
- if len(buf) < int(minLen) {
- return &encryptedData{}, fmt.Errorf("%w: too short (%d bytes)", errBadInput, len(buf))
- }
-
- receivedHMAC := buf[:hashSize]
- iv := buf[hashSize : hashSize+blockSize]
- cipherText := buf[hashSize+blockSize:]
-
- state.hmacRemote.Reset()
- state.hmacRemote.Write(iv)
- state.hmacRemote.Write(cipherText)
- computedHMAC := state.hmacRemote.Sum(nil)
-
- if !hmac.Equal(computedHMAC, receivedHMAC) {
- logger.Errorf("expected: %x, got: %x", computedHMAC, receivedHMAC)
- return &encryptedData{}, fmt.Errorf("%w: %s", errCannotDecrypt, errBadHMAC)
- }
-
- encrypted := &encryptedData{
- iv: iv,
- ciphertext: cipherText,
- aead: []byte{}, // no AEAD data in this mode, leaving it empty to satisfy common interface
- }
- return encrypted, nil
-}
-
-func (d *data) ReadPacket(p *packet) ([]byte, error) {
- if len(p.payload) == 0 {
- return []byte{}, fmt.Errorf("%w: %s", errCannotDecrypt, "empty payload")
- }
- panicIfFalse(p.isData(), "ReadPacket expects data packet")
-
- plaintext, err := d.decrypt(p.payload)
- if err != nil {
- return []byte{}, err
- }
-
- // get plaintext payload from the decrypted plaintext
- return maybeDecompress(plaintext, d.state, d.options)
-}
-
-// maybeDecompress de-serializes the data from the payload according to the framing
-// given by different compression methods. only the different no-compression
-// modes are supported at the moment, so no real decompression is done. It
-// returns a byte array, and an error if the operation could not be completed
-// successfully.
-func maybeDecompress(b []byte, st *dataChannelState, opt *Options) ([]byte, error) {
- if st == nil || st.dataCipher == nil {
- return []byte{}, fmt.Errorf("%w:%s", errBadInput, "bad state")
- }
- if opt == nil {
- return []byte{}, fmt.Errorf("%w:%s", errBadInput, "bad options")
- }
-
- var compr byte // compression type
- var payload []byte
-
- // TODO(ainghazal): have two different decompress implementations
- // instead of this switch
- switch st.dataCipher.isAEAD() {
- case true:
- switch opt.Compress {
- case compressionStub, compressionLZONo:
- // these are deprecated in openvpn 2.5.x
- compr = b[0]
- payload = b[1:]
- default:
- compr = 0x00
- payload = b[:]
- }
- default: // non-aead
- remotePacketID := packetID(binary.BigEndian.Uint32(b[:4]))
- lastKnownRemote, err := st.RemotePacketID()
- if err != nil {
- return payload, err
- }
- if remotePacketID <= lastKnownRemote {
- return []byte{}, errReplayAttack
- }
- st.SetRemotePacketID(remotePacketID)
-
- switch opt.Compress {
- case compressionStub, compressionLZONo:
- compr = b[4]
- payload = b[5:]
- default:
- compr = 0x00
- payload = b[4:]
- }
- }
-
- switch compr {
- case 0xfb:
- // compression stub swap:
- // we get the last byte and replace the compression byte
- // these are deprecated in openvpn 2.5.x
- end := payload[len(payload)-1]
- b := payload[:len(payload)-1]
- payload = append([]byte{end}, b...)
- case 0x00, 0xfa:
- // do nothing
- // 0x00 is compress-no,
- // 0xfa is the old no compression or comp-lzo no case.
- // http://build.openvpn.net/doxygen/comp_8h_source.html
- // see: https://community.openvpn.net/openvpn/ticket/952#comment:5
- default:
- errMsg := fmt.Sprintf("cannot handle compression:%x", compr)
- return []byte{}, fmt.Errorf("%w:%s", errBadCompression, errMsg)
- }
- return payload, nil
-}
-
-// opcodeAndKeyHeader returns the header byte encoding the opcode and keyID (3 upper
-// and 5 lower bits, respectively)
-func opcodeAndKeyHeader(st *dataChannelState) byte {
- return byte((pDataV2 << 3) | (st.keyID & 0x07))
-}
diff --git a/vpn/data_test.go b/vpn/data_test.go
deleted file mode 100644
index b311275e..00000000
--- a/vpn/data_test.go
+++ /dev/null
@@ -1,1425 +0,0 @@
-package vpn
-
-import (
- "bytes"
- "crypto/hmac"
- "crypto/sha1"
- "encoding/base64"
- "encoding/hex"
- "errors"
- "fmt"
- "math"
- "net"
- "reflect"
- "sync"
- "testing"
-
- "github.com/ooni/minivpn/vpn/mocks"
-)
-
-const (
- rnd16 = "0123456789012345"
- rnd32 = "01234567890123456789012345678901"
- rnd48 = "012345678901234567890123456789012345678901234567"
-)
-
-func makeTestKeys() ([32]byte, [32]byte, [48]byte) {
- r1 := *(*[32]byte)([]byte(rnd32))
- r2 := *(*[32]byte)([]byte(rnd32))
- r3 := *(*[48]byte)([]byte(rnd48))
- return r1, r2, r3
-}
-
-// getDeterministicRandomKeySize returns a sequence of integers
-// using the map in the closure. we use this to construct a deterministic
-// random function to replace the random function used in the real client.
-func getDeterministicRandomKeySizeFn() func() int {
- var rndSeq = map[int]int{
- 1: 32,
- 2: 32,
- 3: 48,
- }
- i := 1
- f := func() int {
- v := rndSeq[i]
- i += 1
- return v
- }
- return f
-}
-
-func Test_newKeySource(t *testing.T) {
-
- genKeySizeFn := getDeterministicRandomKeySizeFn()
-
- // we replace the global random function used in the constructor
- randomFn = func(int) ([]byte, error) {
- switch genKeySizeFn() {
- case 48:
- return []byte(rnd48), nil
- default:
- return []byte(rnd32), nil
- }
- }
-
- r1, r2, premaster := makeTestKeys()
- ks := &keySource{r1, r2, premaster}
-
- tests := []struct {
- name string
- want *keySource
- }{
- {
- name: "test generation of a new key with mocked random data",
- want: ks,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if got, _ := newKeySource(); !reflect.DeepEqual(got, tt.want) {
- t.Errorf("newKeySource() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func makeTestingSession() *session {
- s := &session{
- RemoteSessionID: sessionID{0x01},
- LocalSessionID: sessionID{0x02},
- mu: sync.Mutex{},
- }
- return s
-}
-
-func makeTestingOptions(t *testing.T, cipher, auth string) *Options {
- crt, _ := writeTestingCerts(t.TempDir())
- opt := &Options{
- Cipher: cipher,
- Auth: auth,
- CertPath: crt.cert,
- KeyPath: crt.key,
- CaPath: crt.ca,
- }
- return opt
-}
-
-func Test_newDataFromOptions(t *testing.T) {
- type args struct {
- opt *Options
- s *session
- }
- tests := []struct {
- name string
- args args
- want *data
- wantWhatever bool
- wantErr error
- }{
- {
- name: "nil args should fail",
- args: args{},
- want: nil,
- wantErr: errBadInput,
- },
- {
- name: "empty Options should fail",
- args: args{
- opt: &Options{},
- s: makeTestingSession(),
- },
- want: nil,
- wantErr: errBadInput,
- },
- {
- name: "bad auth in Options should fail",
- args: args{
- opt: makeTestingOptions(t, "AES-128-GCM", "shabad"),
- s: makeTestingSession(),
- },
- wantWhatever: true,
- wantErr: errBadInput,
- },
- {
- name: "empty session should not fail",
- args: args{
- opt: makeTestingOptions(t, "AES-128-GCM", "sha512"),
- s: &session{},
- },
- wantWhatever: true,
- wantErr: nil,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got, err := newDataFromOptions(tt.args.opt, tt.args.s)
- if !errors.Is(err, tt.wantErr) {
- t.Errorf("newDataFromOptions() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if !tt.wantWhatever && !reflect.DeepEqual(got, tt.want) {
- t.Errorf("newDataFromOptions() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func makeTestingDataChannelKey() *dataChannelKey {
- rl1, rl2, preml := makeTestKeys()
- rr1, rr2, premr := makeTestKeys()
-
- ksLocal := &keySource{rl1, rl2, preml}
- ksRemote := &keySource{rr1, rr2, premr}
-
- dck := &dataChannelKey{
- ready: true,
- local: ksLocal,
- remote: ksRemote,
- }
- return dck
-}
-
-func Test_data_SetupKeys(t *testing.T) {
- type fields struct {
- session *session
- state *dataChannelState
- }
- type args struct {
- dck *dataChannelKey
- }
- tests := []struct {
- name string
- fields fields
- args args
- wantErr error
- }{
- {
- name: "nil in arguments should fail",
- fields: fields{
- session: makeTestingSession(),
- state: makeTestingState(),
- },
- args: args{},
- wantErr: errBadInput,
- },
- {
- name: "dataChannelKey not ready",
- fields: fields{
- session: makeTestingSession(),
- state: makeTestingState(),
- },
- args: args{
- dck: &dataChannelKey{},
- },
- wantErr: errDataChannelKey,
- },
- {
- name: "good setup",
- fields: fields{
- session: makeTestingSession(),
- state: makeTestingState(),
- },
- args: args{
- dck: makeTestingDataChannelKey(),
- },
- wantErr: nil,
- // TODO(ainghazal): should write another test to verify the key derivation?
- // but what that would be testing, if not the implementation?
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- d := &data{
- session: tt.fields.session,
- state: tt.fields.state,
- }
- if err := d.SetupKeys(tt.args.dck); !errors.Is(err, tt.wantErr) {
- t.Errorf("data.SetupKeys() error = %v, wantErr %v", err, tt.wantErr)
- }
- })
- }
-}
-
-func Test_data_EncryptAndEncodePayload(t *testing.T) {
- // TODO(ainghazal): this is exercising only one encryption method
-
- opt := &Options{}
-
- type fields struct {
- options *Options
- session *session
- state *dataChannelState
- decodeFn func([]byte, *dataChannelState) (*encryptedData, error)
- encryptEncodeFn func([]byte, *session, *dataChannelState) ([]byte, error)
- }
- type args struct {
- plaintext []byte
- dcs *dataChannelState
- }
- tests := []struct {
- name string
- fields fields
- args args
- want []byte
- wantErr error
- }{
- {
- name: "dummy encryptEncodeFn does not fail",
- fields: fields{
- options: opt,
- session: makeTestingSession(),
- state: makeTestingState(),
- decodeFn: nil,
- encryptEncodeFn: func(b []byte, s *session, st *dataChannelState) ([]byte, error) {
- return []byte{}, nil
- },
- },
- args: args{
- plaintext: []byte("hello"),
- dcs: makeTestingState(),
- },
- want: []byte{},
- wantErr: nil,
- },
- {
- name: "empty plaintext fails",
- fields: fields{
- options: opt,
- session: makeTestingSession(),
- state: makeTestingState(),
- decodeFn: nil,
- encryptEncodeFn: func(b []byte, s *session, st *dataChannelState) ([]byte, error) {
- return []byte{}, nil
- },
- },
- args: args{
- plaintext: []byte{},
- dcs: makeTestingState(),
- },
- want: []byte{},
- wantErr: errCannotEncrypt,
- },
- {
- name: "error on encryptEncodeFn gets propagated",
- fields: fields{
- options: opt,
- session: makeTestingSession(),
- state: makeTestingState(),
- decodeFn: nil,
- encryptEncodeFn: func(b []byte, s *session, st *dataChannelState) ([]byte, error) {
- return []byte{}, errors.New("dummyTestError")
- },
- },
- args: args{
- plaintext: []byte{},
- dcs: makeTestingState(),
- },
- want: []byte{},
- wantErr: errCannotEncrypt,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- d := &data{
- options: tt.fields.options,
- session: tt.fields.session,
- state: tt.fields.state,
- decodeFn: tt.fields.decodeFn,
- encryptEncodeFn: tt.fields.encryptEncodeFn,
- }
- got, err := d.EncryptAndEncodePayload(tt.args.plaintext, tt.args.dcs)
- if !errors.Is(err, tt.wantErr) {
- t.Errorf("data.EncryptAndEncodePayload() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if !reflect.DeepEqual(got, tt.want) {
- t.Errorf("data.EncryptAndEncodePayload() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func Test_dataChannelState_RemotePacketID(t *testing.T) {
- type fields struct {
- remotePacketID packetID
- }
- tests := []struct {
- name string
- fields fields
- want packetID
- wantErr error
- }{
- {
- "zero",
- fields{0},
- packetID(0),
- nil,
- },
- {
- "one",
- fields{1},
- packetID(1),
- nil,
- },
- {
- "overflow",
- fields{math.MaxUint32},
- packetID(0),
- errExpiredKey,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- dcs := &dataChannelState{
- remotePacketID: tt.fields.remotePacketID,
- }
- if got, err := dcs.RemotePacketID(); got != tt.want || err != tt.wantErr {
- t.Errorf("dataChannelState.RemotePacketID() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func Test_keySource_Bytes(t *testing.T) {
- r1, r2, premaster := makeTestKeys()
- goodSerialized := append(premaster[:], r1[:]...)
- goodSerialized = append(goodSerialized, r2[:]...)
-
- type fields struct {
- r1 [32]byte
- r2 [32]byte
- preMaster [48]byte
- }
- tests := []struct {
- name string
- fields fields
- want []byte
- }{
- {
- name: "good keysource",
- fields: fields{
- r1: r1,
- r2: r2,
- preMaster: premaster,
- },
- want: goodSerialized,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- k := &keySource{
- r1: tt.fields.r1,
- r2: tt.fields.r2,
- preMaster: tt.fields.preMaster,
- }
- if got := k.Bytes(); !bytes.Equal(got, tt.want) {
- t.Errorf("keySource.Bytes() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func Test_dataChannelKey_addRemoteKey(t *testing.T) {
- type fields struct {
- ready bool
- remote *keySource
- }
- type args struct {
- k *keySource
- }
- tests := []struct {
- name string
- fields fields
- args args
- wantErr bool
- }{
- {
- "passing a keysource should make it ready",
- fields{false, &keySource{}},
- args{&keySource{}},
- false,
- },
- {
- "fail if ready",
- fields{true, &keySource{}},
- args{&keySource{}},
- true,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- dck := &dataChannelKey{
- ready: tt.fields.ready,
- remote: tt.fields.remote,
- }
- if err := dck.addRemoteKey(tt.args.k); (err != nil) != tt.wantErr {
- t.Errorf("dataChannelKey.addRemoteKey() error = %v, wantErr %v", err, tt.wantErr)
- }
- })
- }
-}
-
-func makeTestingState() *dataChannelState {
- dataCipher, _ := newDataCipher(cipherNameAES, 128, cipherModeGCM)
- st := &dataChannelState{
- hash: sha1.New,
- // my linter doesn't like it, but this is the proper way of casting to keySlot
- cipherKeyLocal: *(*keySlot)(bytes.Repeat([]byte{0x65}, 64)),
- cipherKeyRemote: *(*keySlot)(bytes.Repeat([]byte{0x66}, 64)),
- hmacKeyLocal: *(*keySlot)(bytes.Repeat([]byte{0x67}, 64)),
- hmacKeyRemote: *(*keySlot)(bytes.Repeat([]byte{0x68}, 64)),
- }
- st.hmacLocal = hmac.New(st.hash, st.hmacKeyLocal[:20])
- st.hmacRemote = hmac.New(st.hash, st.hmacKeyRemote[:20])
- st.dataCipher = dataCipher
- return st
-}
-
-func makeTestingStateNonAEAD() *dataChannelState {
- dataCipher, _ := newDataCipher(cipherNameAES, 128, cipherModeCBC)
- st := &dataChannelState{
- hash: sha1.New,
- // my linter doesn't like it, but this is the proper way of casting to keySlot
- cipherKeyLocal: *(*keySlot)(bytes.Repeat([]byte{0x65}, 64)),
- cipherKeyRemote: *(*keySlot)(bytes.Repeat([]byte{0x66}, 64)),
- hmacKeyLocal: *(*keySlot)(bytes.Repeat([]byte{0x67}, 64)),
- hmacKeyRemote: *(*keySlot)(bytes.Repeat([]byte{0x68}, 64)),
- }
- st.hmacLocal = hmac.New(st.hash, st.hmacKeyLocal[:20])
- st.hmacRemote = hmac.New(st.hash, st.hmacKeyRemote[:20])
- st.dataCipher = dataCipher
- return st
-}
-
-func makeTestingStateNonAEADReversed() *dataChannelState {
- dataCipher, _ := newDataCipher(cipherNameAES, 128, cipherModeCBC)
- st := &dataChannelState{
- hash: sha1.New,
- // my linter doesn't like it, but this is the proper way of casting to keySlot
- cipherKeyRemote: *(*keySlot)(bytes.Repeat([]byte{0x65}, 64)),
- cipherKeyLocal: *(*keySlot)(bytes.Repeat([]byte{0x66}, 64)),
- hmacKeyRemote: *(*keySlot)(bytes.Repeat([]byte{0x67}, 64)),
- hmacKeyLocal: *(*keySlot)(bytes.Repeat([]byte{0x68}, 64)),
- }
- st.hmacLocal = hmac.New(st.hash, st.hmacKeyLocal[:20])
- st.hmacRemote = hmac.New(st.hash, st.hmacKeyRemote[:20])
- st.dataCipher = dataCipher
- return st
-}
-
-func Test_data_decrypt(t *testing.T) {
-
- goodMockDecryptFn := func([]byte, *encryptedData) ([]byte, error) {
- return []byte("alles ist gut"), nil
- }
-
- failingMockDecryptFn := func([]byte, *encryptedData) ([]byte, error) {
- return []byte{}, errCannotDecrypt
- }
-
- opt := &Options{}
-
- type fields struct {
- options *Options
- session *session
- state *dataChannelState
- decodeFn func([]byte, *dataChannelState) (*encryptedData, error)
- encryptEncodeFn func([]byte, *session, *dataChannelState) ([]byte, error)
- decryptFn func([]byte, *encryptedData) ([]byte, error)
- }
- type args struct {
- encrypted []byte
- }
- tests := []struct {
- name string
- fields fields
- args args
- want []byte
- wantErr error
- }{
- {
- name: "empty output in decodeFn does fail",
- fields: fields{
- options: opt,
- session: makeTestingSession(),
- state: makeTestingState(),
- decodeFn: func(b []byte, st *dataChannelState) (*encryptedData, error) {
- return &encryptedData{}, nil
- },
- encryptEncodeFn: nil,
- decryptFn: makeTestingState().dataCipher.decrypt,
- },
- args: args{
- encrypted: bytes.Repeat([]byte{0x0a}, 20),
- },
- want: []byte{},
- wantErr: errCannotDecrypt,
- },
- {
- name: "empty encrypted input does fail",
- fields: fields{
- options: opt,
- session: makeTestingSession(),
- state: makeTestingState(),
- decodeFn: func(b []byte, st *dataChannelState) (*encryptedData, error) {
- return &encryptedData{}, nil
- },
- encryptEncodeFn: nil,
- decryptFn: makeTestingState().dataCipher.decrypt,
- },
- args: args{
- encrypted: []byte{},
- },
- want: []byte{},
- wantErr: errCannotDecrypt,
- },
- {
- name: "error in decrypt propagates",
- fields: fields{
- options: opt,
- session: makeTestingSession(),
- state: makeTestingState(),
- decodeFn: func(b []byte, st *dataChannelState) (*encryptedData, error) {
- return &encryptedData{}, nil
- },
- encryptEncodeFn: nil,
- decryptFn: failingMockDecryptFn,
- },
- args: args{
- encrypted: []byte{},
- },
- want: []byte{},
- wantErr: errCannotDecrypt,
- },
- {
- name: "good decrypt returns expected output",
- fields: fields{
- options: opt,
- session: makeTestingSession(),
- state: makeTestingState(),
- decodeFn: func(b []byte, st *dataChannelState) (*encryptedData, error) {
- return &encryptedData{}, nil
- },
- encryptEncodeFn: nil,
- decryptFn: goodMockDecryptFn,
- },
- args: args{
- encrypted: []byte{},
- },
- want: []byte("alles ist gut"),
- wantErr: nil,
- },
- // TODO we already are testing decrypt + encrypt in the crypto module
- // so we can mock the decrypt here in the state.
- // TODO empty ciphertext raises error
- // TODO: Add moar test cases.
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- d := &data{
- options: tt.fields.options,
- session: tt.fields.session,
- state: tt.fields.state,
- decodeFn: tt.fields.decodeFn,
- encryptEncodeFn: tt.fields.encryptEncodeFn,
- decryptFn: tt.fields.decryptFn,
- }
- got, err := d.decrypt(tt.args.encrypted)
- if !errors.Is(err, tt.wantErr) {
- t.Errorf("data.decrypt() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if !reflect.DeepEqual(got, tt.want) {
- t.Errorf("data.decrypt() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func Test_decodeEncryptedPayloadAEAD(t *testing.T) {
-
- state := makeTestingState()
-
- goodEncryptedPayload, _ := hex.DecodeString("00000000b3653a842f2b8a148de26375218fb01d31278ff328ff2fc65c4dbf9eb8e67766")
- goodDecodeIV, _ := hex.DecodeString("000000006868686868686868")
- goodDecodeCipherText, _ := hex.DecodeString("31278ff328ff2fc65c4dbf9eb8e67766b3653a842f2b8a148de26375218fb01d")
- goodDecodeAEAD, _ := hex.DecodeString("4800000000000000")
-
- type args struct {
- buf []byte
- state *dataChannelState
- }
- tests := []struct {
- name string
- args args
- want *encryptedData
- wantErr bool
- }{
- {
- "empty",
- args{[]byte{}, &dataChannelState{}},
- &encryptedData{},
- true,
- },
- {
- "too short",
- args{bytes.Repeat([]byte{0xff}, 19), &dataChannelState{}},
- &encryptedData{},
- true,
- },
- {
- "good decode",
- args{goodEncryptedPayload, state},
- &encryptedData{
- iv: goodDecodeIV,
- ciphertext: goodDecodeCipherText,
- aead: goodDecodeAEAD,
- },
- false,
- },
- // TODO: Add moar test cases.
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got, err := decodeEncryptedPayloadAEAD(tt.args.buf, tt.args.state)
- if (err != nil) != tt.wantErr {
- t.Errorf("decodeEncryptedPayloadAEAD() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if !reflect.DeepEqual(got, tt.want) {
- t.Errorf("decodeEncryptedPayloadAEAD() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func Test_decodeEncryptedPayloadNonAEAD(t *testing.T) {
-
- goodInput, _ := hex.DecodeString("fdf9b069b2e5a637fa7b5c9231166ea96307e4123031323334353637383930313233343581e4878c5eec602c2d2f5a95139c84af")
- iv, _ := hex.DecodeString("30313233343536373839303132333435")
- ciphertext, _ := hex.DecodeString("81e4878c5eec602c2d2f5a95139c84af")
-
- type args struct {
- buf []byte
- state *dataChannelState
- }
- tests := []struct {
- name string
- args args
- want *encryptedData
- wantErr error
- }{
- {
- name: "empty",
- args: args{[]byte{}, &dataChannelState{}},
- want: &encryptedData{},
- wantErr: errBadInput,
- },
- {
- name: "too short",
- args: args{bytes.Repeat([]byte{0xff}, 27), &dataChannelState{}},
- want: &encryptedData{},
- wantErr: errBadInput,
- },
- {
- name: "nil state should fail",
- args: args{goodInput, nil},
- want: &encryptedData{},
- wantErr: errBadInput,
- },
- {
- name: "empty state.dataCipher should fail",
- args: args{goodInput, &dataChannelState{}},
- want: &encryptedData{},
- wantErr: errBadInput,
- },
- {
- name: "good decode",
- args: args{goodInput, makeTestingStateNonAEADReversed()},
- want: &encryptedData{
- iv: iv,
- ciphertext: ciphertext,
- },
- wantErr: nil,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got, err := decodeEncryptedPayloadNonAEAD(tt.args.buf, tt.args.state)
- if !errors.Is(err, tt.wantErr) {
- t.Errorf("decodeEncryptedPayloadNonAEAD() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if !bytes.Equal(got.iv, tt.want.iv) {
- t.Errorf("decodeEncryptedPayloadNonAEAD().iv = %v, want %v", got.iv, tt.want.iv)
- }
- if !bytes.Equal(got.ciphertext, tt.want.ciphertext) {
- t.Errorf("decodeEncryptedPayloadNonAEAD().iv = %v, want %v", got.iv, tt.want.iv)
- }
- })
- }
-}
-
-func Test_encryptAndEncodePayloadAEAD(t *testing.T) {
-
- state := makeTestingState()
- padded, _ := doPadding([]byte("hello go tests"), "", state.dataCipher.blockSize())
-
- goodEncryptedPayload, _ := hex.DecodeString("48000000000000006ba730fd633b1d5f11478f6f601cb84231278ff328ff2fc65c4dbf9eb8e67766")
-
- type args struct {
- padded []byte
- session *session
- state *dataChannelState
- }
- tests := []struct {
- name string
- args args
- want []byte
- wantErr bool
- }{
- {
- "good encrypt",
- args{padded, &session{}, state},
- goodEncryptedPayload,
- false,
- },
- // TODO: Add test cases.
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got, err := encryptAndEncodePayloadAEAD(tt.args.padded, tt.args.session, tt.args.state)
- if (err != nil) != tt.wantErr {
- t.Errorf("encryptAndEncodePayloadAEAD() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if !reflect.DeepEqual(got, tt.want) {
- t.Errorf("encryptAndEncodePayloadAEAD() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func Test_encryptAndEncodePayloadNonAEAD(t *testing.T) {
-
- padded16 := bytes.Repeat([]byte{0xff}, 16)
- padded15 := bytes.Repeat([]byte{0xff}, 15)
-
- // including OP32 header + peerid (v2)
- goodEncrypted, _ := hex.DecodeString("48000000fdf9b069b2e5a637fa7b5c9231166ea96307e4123031323334353637383930313233343581e4878c5eec602c2d2f5a95139c84af")
-
- // we replace the global random function that is used for the iv in, e.g., CBC mode.
- randomFn = func(i int) ([]byte, error) {
- switch i {
- case 16:
- return []byte(rnd16), nil
- default:
- return []byte(rnd32), nil
- }
- }
-
- type args struct {
- padded []byte
- session *session
- state *dataChannelState
- }
- tests := []struct {
- name string
- args args
- want []byte
- wantErr error
- }{
- {
- name: "good encrypt",
- args: args{
- padded: padded16,
- session: &session{},
- state: makeTestingStateNonAEAD()},
- want: goodEncrypted,
- wantErr: nil,
- },
- {
- name: "badly padded input should fail",
- args: args{
- padded: padded15,
- session: &session{},
- state: makeTestingStateNonAEAD()},
- want: nil,
- wantErr: errCannotEncrypt,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got, err := encryptAndEncodePayloadNonAEAD(tt.args.padded, tt.args.session, tt.args.state)
- if !errors.Is(err, tt.wantErr) {
- t.Errorf("encryptAndEncodePayloadNonAEAD() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if !bytes.Equal(got, tt.want) {
- fmt.Println(hex.EncodeToString(got))
- t.Errorf("encryptAndEncodePayloadNonAEAD() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func Test_doCompress(t *testing.T) {
- type args struct {
- b []byte
- opt compression
- }
- tests := []struct {
- name string
- args args
- want []byte
- wantErr error
- }{
- {
- name: "null compression should not fail",
- args: args{},
- want: []byte{},
- wantErr: nil,
- },
- {
- name: "do nothing by default",
- args: args{
- b: []byte{0xde, 0xad, 0xbe, 0xef},
- opt: "",
- },
- want: []byte{0xde, 0xad, 0xbe, 0xef},
- wantErr: nil,
- },
- {
- name: "stub appends the first byte at the end",
- args: args{
- b: []byte{0xde, 0xad, 0xbe, 0xef},
- opt: "stub",
- },
- want: []byte{0xfb, 0xad, 0xbe, 0xef, 0xde},
- wantErr: nil,
- },
- {
- name: "lzo-no adds 0xfa preamble",
- args: args{
- b: []byte{0xde, 0xad, 0xbe, 0xef},
- opt: "lzo-no",
- },
- want: []byte{0xfa, 0xde, 0xad, 0xbe, 0xef},
- wantErr: nil,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got, err := doCompress(tt.args.b, tt.args.opt)
- if !errors.Is(err, tt.wantErr) {
- t.Errorf("maybeAddCompressStub() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if !bytes.Equal(got, tt.want) {
- t.Errorf("maybeAddCompressStub() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func Test_doPadding(t *testing.T) {
- type args struct {
- b []byte
- compress compression
- blockSize uint8
- }
- tests := []struct {
- name string
- args args
- want []byte
- wantErr error
- }{
- {
- name: "add a whole padding block if len equal to block size, no padding stub",
- args: args{
- b: []byte{0x00, 0x01, 0x02, 0x03},
- compress: compression(""),
- blockSize: 4,
- },
- want: []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x04, 0x04, 0x04},
- wantErr: nil,
- },
- {
- name: "compression stub with len == blocksize",
- args: args{
- b: []byte{0x00, 0x01, 0x02, 0x03},
- compress: compressionStub,
- blockSize: 4,
- },
- want: []byte{0x00, 0x01, 0x02, 0x03},
- wantErr: nil,
- },
- {
- name: "compression stub with len < blocksize",
- args: args{
- b: []byte{0x00, 0x01, 0xff},
- compress: compressionStub,
- blockSize: 4,
- },
- want: []byte{0x00, 0x01, 0x02, 0xff},
- wantErr: nil,
- },
- {
- name: "compression stub with len = blocksize + 1",
- args: args{
- b: []byte{0x00, 0x01, 0x02, 0x03, 0xff},
- compress: compressionStub,
- blockSize: 4,
- },
- want: []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x04, 0x04, 0xff},
- wantErr: nil,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got, err := doPadding(tt.args.b, tt.args.compress, tt.args.blockSize)
- if !errors.Is(err, tt.wantErr) {
- t.Errorf("doPadding() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if !reflect.DeepEqual(got, tt.want) {
- t.Errorf("doPadding() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func Test_prependPacketID(t *testing.T) {
- type args struct {
- p packetID
- buf []byte
- }
- tests := []struct {
- name string
- args args
- want []byte
- }{
- {
- name: "good append",
- args: args{
- packetID(0x01),
- []byte{0x07, 0x08},
- },
- want: []byte{0x00, 0x00, 0x00, 0x01, 0x07, 0x08},
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if got := prependPacketID(tt.args.p, tt.args.buf); !reflect.DeepEqual(got, tt.want) {
- t.Errorf("prependPacketID() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func Test_maybeDecompress(t *testing.T) {
-
- getStateForDecompressTestNonAEAD := func() *dataChannelState {
- st := makeTestingStateNonAEAD()
- st.remotePacketID = packetID(0x42)
- return st
- }
-
- type args struct {
- b []byte
- st *dataChannelState
- opt *Options
- }
- tests := []struct {
- name string
- args args
- want []byte
- wantErr error
- }{
- {
- name: "nil state should fail",
- args: args{
- b: []byte{},
- st: nil,
- opt: &Options{},
- },
- want: []byte{},
- wantErr: errBadInput,
- },
- {
- name: "nil options should fail",
- args: args{
- b: []byte{},
- st: makeTestingState(),
- opt: nil,
- },
- want: []byte{},
- wantErr: errBadInput,
- },
- {
- name: "aead cipher, no compression",
- args: args{
- b: []byte{0xaa, 0xbb, 0xcc},
- st: makeTestingState(),
- opt: &Options{},
- },
- want: []byte{0xaa, 0xbb, 0xcc},
- wantErr: nil,
- },
- {
- name: "aead cipher, no compr",
- args: args{
- b: []byte{0xfa, 0xbb, 0xcc},
- st: makeTestingState(),
- opt: &Options{Compress: "stub"},
- },
- want: []byte{0xbb, 0xcc},
- wantErr: nil,
- },
- {
- name: "aead cipher, stub on options and stub on header",
- args: args{
- b: []byte{0xfb, 0xbb, 0xcc, 0xdd},
- st: makeTestingState(),
- opt: &Options{Compress: "stub"},
- },
- want: []byte{0xdd, 0xbb, 0xcc},
- wantErr: nil,
- },
- {
- name: "aead cipher, stub, unsupported compression",
- args: args{
- b: []byte{0xff, 0xbb, 0xcc},
- st: makeTestingState(),
- opt: &Options{Compress: "stub"},
- },
- want: []byte{},
- wantErr: errBadCompression,
- },
- {
- name: "aead cipher, lzo-no",
- args: args{
- b: []byte{0xfa, 0xbb, 0xcc},
- st: makeTestingState(),
- opt: &Options{Compress: "lzo-no"},
- },
- want: []byte{0xbb, 0xcc},
- wantErr: nil,
- },
- {
- name: "aead cipher, compress-no",
- args: args{
- b: []byte{0x00, 0xbb, 0xcc},
- st: makeTestingState(),
- opt: &Options{Compress: "no"},
- },
- want: []byte{0x00, 0xbb, 0xcc},
- wantErr: nil,
- },
- {
- name: "non-aead cipher, stub",
- args: args{
- b: []byte{0x00, 0x00, 0x00, 0x43, 0x00, 0xbb, 0xcc},
- st: getStateForDecompressTestNonAEAD(),
- opt: &Options{Compress: "stub"},
- },
- want: []byte{0xbb, 0xcc},
- wantErr: nil,
- },
- {
- name: "non-aead cipher, stub, unsupported compression",
- args: args{
- b: []byte{0x00, 0x00, 0x00, 0x43, 0x0ff, 0xbb, 0xcc},
- st: getStateForDecompressTestNonAEAD(),
- opt: &Options{Compress: "stub"},
- },
- want: []byte{},
- wantErr: errBadCompression,
- },
- {
- name: "non-aead cipher, compress-no",
- args: args{
- b: []byte{0x00, 0x00, 0x00, 0x43, 0x00, 0xbb, 0xcc},
- st: getStateForDecompressTestNonAEAD(),
- opt: &Options{Compress: "no"},
- },
- want: []byte{0x00, 0xbb, 0xcc},
- wantErr: nil,
- },
- {
- name: "non-aead cipher, replay detected (equal remote packetID)",
- args: args{
- b: []byte{0x00, 0x00, 0x00, 0x42, 0x00, 0xbb, 0xcc},
- st: getStateForDecompressTestNonAEAD(),
- opt: &Options{Compress: "stub"},
- },
- want: []byte{},
- wantErr: errReplayAttack,
- },
- {
- name: "non-aead cipher, replay detected (lesser remote packetID)",
- args: args{
- b: []byte{0x00, 0x00, 0x00, 0x42, 0x00, 0xbb, 0xcc},
- st: getStateForDecompressTestNonAEAD(),
- opt: &Options{Compress: "stub"},
- },
- want: []byte{},
- wantErr: errReplayAttack,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got, err := maybeDecompress(tt.args.b, tt.args.st, tt.args.opt)
- if !errors.Is(err, tt.wantErr) {
- t.Errorf("maybeDecompress() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if !reflect.DeepEqual(got, tt.want) {
- t.Errorf("maybeDecompress() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func Test_data_ReadPacket(t *testing.T) {
-
- goodMockDecodeFn := func([]byte, *dataChannelState) (*encryptedData, error) {
- d := &encryptedData{
- iv: []byte{0xee},
- ciphertext: []byte("garbledpayload"),
- aead: []byte{0xff},
- }
- return d, nil
- }
-
- goodMockDecryptFn := func([]byte, *encryptedData) ([]byte, error) {
- return []byte("alles ist gut"), nil
- }
-
- type fields struct {
- options *Options
- state *dataChannelState
- decryptFn func([]byte, *encryptedData) ([]byte, error)
- decodeFn func([]byte, *dataChannelState) (*encryptedData, error)
- }
- type args struct {
- p *packet
- }
- tests := []struct {
- name string
- fields fields
- args args
- want []byte
- wantErr error
- }{
- {
- name: "good decrypt using mocked decrypt fn and decode fn",
- fields: fields{
- options: makeTestingOptions(t, "AES-128-GCM", "sha1"),
- state: makeTestingState(),
- decryptFn: goodMockDecryptFn,
- decodeFn: goodMockDecodeFn,
- },
- args: args{&packet{
- opcode: pDataV1,
- payload: []byte("garbled")},
- },
- want: []byte("alles ist gut"),
- wantErr: nil,
- },
- // TODO panic when call to DecodeEncryptedPayload
- // TODO error if empty payload
- // TODO make sure decompress fn is called?
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- d := &data{
- options: tt.fields.options,
- state: tt.fields.state,
- decryptFn: tt.fields.decryptFn,
- decodeFn: tt.fields.decodeFn,
- }
- got, err := d.ReadPacket(tt.args.p)
- if !errors.Is(err, tt.wantErr) {
- t.Errorf("data.ReadPacket() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if !reflect.DeepEqual(got, tt.want) {
- t.Errorf("data.ReadPacket() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-// we'll use a mocked net.Conn for WritePacket
-
-func makeTestingConnForWrite(network, addr string, n int) net.Conn {
- mockAddr := &mocks.Addr{}
- mockAddr.MockString = func() string {
- return addr
- }
- mockAddr.MockNetwork = func() string {
- return network
- }
-
- mockConn := &mocks.Conn{}
- mockConn.MockLocalAddr = func() net.Addr {
- return mockAddr
- }
- mockConn.MockWrite = func([]byte) (int, error) {
- return n, nil
- }
- mockConn.MockRead = func([]byte) (int, error) {
- return n, nil
- }
- return mockConn
-}
-
-func Test_data_WritePacket(t *testing.T) {
- opt := &Options{}
-
- goodMockEncodedEncryptFn := func([]byte, *session, *dataChannelState) ([]byte, error) {
- return []byte("alles ist garbled gut"), nil
- }
-
- type fields struct {
- options *Options
- // session is only used for NonAEAD encryption
- session *session
- state *dataChannelState
- encryptEncodeFn func([]byte, *session, *dataChannelState) ([]byte, error)
- }
- type args struct {
- conn net.Conn
- payload []byte
- }
- tests := []struct {
- name string
- fields fields
- args args
- want int
- wantErr error
- }{
- {
- name: "good write, aead encryption",
- fields: fields{
- options: opt,
- session: nil,
- state: makeTestingState(),
- encryptEncodeFn: goodMockEncodedEncryptFn,
- },
- args: args{
- conn: makeTestingConnForWrite("udp", "10.0.42.1", 42),
- payload: []byte("hello test"),
- },
- want: 42,
- wantErr: nil,
- },
-
- // TODO: Add moar test cases.
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- d := &data{
- options: tt.fields.options,
- session: tt.fields.session,
- state: tt.fields.state,
- encryptEncodeFn: tt.fields.encryptEncodeFn,
- }
- got, err := d.WritePacket(tt.args.conn, tt.args.payload)
- if !errors.Is(err, tt.wantErr) {
- t.Errorf("data.WritePacket() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if got != tt.want {
- t.Errorf("data.WritePacket() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-// Regression test for MIV-01-003
-func Test_Crash_EncryptAndEncodePayload(t *testing.T) {
- opt := &Options{}
- st := &dataChannelState{
- hash: sha1.New,
- cipherKeyLocal: *(*keySlot)(bytes.Repeat([]byte{0x65}, 64)),
- cipherKeyRemote: *(*keySlot)(bytes.Repeat([]byte{0x66}, 64)),
- hmacKeyLocal: *(*keySlot)(bytes.Repeat([]byte{0x67}, 64)),
- hmacKeyRemote: *(*keySlot)(bytes.Repeat([]byte{0x68}, 64)),
- }
- a := &data{
- options: opt,
- session: makeTestingSession(),
- state: st,
- decodeFn: nil,
- encryptEncodeFn: func(b []byte, s *session, st *dataChannelState) ([]byte, error) {
- return []byte{}, nil
- },
- }
- a.EncryptAndEncodePayload(nil, a.state)
-}
-
-// Regression test for MIV-01-004
-func Test_Crash_doPadding(t *testing.T) {
- arr := []byte{}
- doPadding(arr, "stub", 16)
-}
-
-// Regression test for MIV-01-004
-func Test_Crash_EncryptAndEncodePayload_Zero_Len_Array(t *testing.T) {
- opt := &Options{}
- st := &dataChannelState{
- hash: sha1.New,
- cipherKeyLocal: *(*keySlot)(bytes.Repeat([]byte{0x65}, 64)),
- cipherKeyRemote: *(*keySlot)(bytes.Repeat([]byte{0x66}, 64)),
- hmacKeyLocal: *(*keySlot)(bytes.Repeat([]byte{0x67}, 64)),
- hmacKeyRemote: *(*keySlot)(bytes.Repeat([]byte{0x68}, 64)),
- }
- a := &data{
- options: opt,
- session: makeTestingSession(),
- state: st,
- decodeFn: nil,
- encryptEncodeFn: func(b []byte, s *session, st *dataChannelState) ([]byte, error) {
- if len(b) == 0 {
- return nil, fmt.Errorf("should not receive zero len")
- }
- return []byte{}, nil
- },
- }
- _, err := a.EncryptAndEncodePayload([]byte{}, a.state)
- if err == nil || !errors.Is(err, errCannotEncrypt) {
- t.Error("should not fail with zero len")
- }
-}
-
-func base64Decode(str string) (string, bool) {
- data, err := base64.StdEncoding.DecodeString(str)
- if err != nil {
- return "", true
- }
- return string(data), false
-}
-
-// Regression test for MIV-01-004
-func Test_Crash_DecodeEncryptedPayload_Too_Short(t *testing.T) {
- input, _ := base64Decode("////////mv//////////////////////////xxk=")
- opt := &Options{}
- type args struct {
- encrypted []byte
- dcs *dataChannelState
- }
- a := &data{
- options: opt,
- session: makeTestingSession(),
- state: makeTestingState(),
- decodeFn: nil,
- encryptEncodeFn: func(b []byte, s *session, st *dataChannelState) ([]byte, error) {
- return []byte{}, nil
- },
- }
- a.decodeFn = decodeEncryptedPayloadNonAEAD
- b := &args{[]byte(input), makeTestingState()}
- a.DecodeEncryptedPayload(b.encrypted, b.dcs)
-}
diff --git a/vpn/dialer.go b/vpn/dialer.go
deleted file mode 100644
index ffb43b8d..00000000
--- a/vpn/dialer.go
+++ /dev/null
@@ -1,217 +0,0 @@
-package vpn
-
-//
-// This file contains dialer types and functions that allow transparent use of
-// an OpenVPN connection.
-//
-
-import (
- "context"
- "fmt"
- "net"
- "net/netip"
- "sync"
- "time"
-
- "golang.zx2c4.com/wireguard/tun"
- "golang.zx2c4.com/wireguard/tun/netstack"
-)
-
-var (
- openDNSPrimary = "208.67.222.222"
- openDNSSecondary = "208.67.220.220"
-)
-
-// A TunDialer contains options for obtaining a network connection tunneled
-// through an OpenVPN endpoint. It uses a userspace gVisor virtual device over
-// the raw OpenVPN tunnel.
-//
-// You need to be careful and create only one instance of TunDialer for each
-// Client, since the underlying virtual device will connect both ends of the
-// tunnel.
-type TunDialer struct {
- // Dialer will be passed to the underlying Client constructor.
- Dialer DialerContext
- client *Client
- ns1 string
- ns2 string
- skipDeviceSetup bool
- device *device
- tun *netstack.Net
- mu sync.Mutex
-
- // dependency injection to test client start
- clientStartFn func(context.Context) error
-}
-
-// NewTunDialer creates a new Dialer with the default nameservers (OpenDNS).
-func NewTunDialer(client *Client) *TunDialer {
- td := &TunDialer{
- client: client,
- ns1: openDNSPrimary,
- ns2: openDNSSecondary,
- }
- return td
-}
-
-// NewTunDialerWithNameservers creates a new TunDialer with the passed nameservers.
-// You probably want to pass the nameservers for your own VPN service here.
-func NewTunDialerWithNameservers(client *Client, ns1, ns2 string) *TunDialer {
- td := &TunDialer{
- client: client,
- ns1: ns1,
- ns2: ns2,
- }
- return td
-}
-
-// StartNewTunDialerFromOptions creates a new Dialer directly from an Options
-// object. It also starts the underlying client.
-func StartNewTunDialerFromOptions(opt *Options, dialer DialerContext) (*TunDialer, error) {
- if dialer == nil {
- return nil, fmt.Errorf("%w: nil dialer", errBadInput)
- }
- client := NewClientFromOptions(opt)
- client.Dialer = dialer
- err := client.Start(context.Background())
- if err != nil {
- defer client.Close()
- return nil, err
- }
- td := &TunDialer{
- client: client,
- ns1: openDNSPrimary,
- ns2: openDNSSecondary,
- }
- return td, nil
-}
-
-// Dial connects to the address on the named network, via the OpenVPN endpoint
-// in the Client that this TunDialer is initialized with.
-//
-// The return value implements the net.Conn interface, but it is a socket created
-// on a virtual device, using gVisor userspace network stack. This means that the
-// kernel only sees UDP packets with an encrypted payload.
-//
-// The addresses are resolved via the OpenVPN tunnel too, and against the nameservers
-// configured in the dialer. This feature uses wireguard's little custom DNS client
-// implementation.
-//
-// Dial calls DialContext with the background context. See documentation of
-// DialContext for more details.
-//
-// Known networks are "tcp", "tcp4" (IPv4-only), "tcp6" (IPv6-only),
-// "udp", "udp4" (IPv4-only), "udp6" (IPv6-only), "ping4", "ping6".
-func (td *TunDialer) Dial(network, address string) (net.Conn, error) {
- ctx := context.Background()
- tnet, err := td.createNetTUN(ctx)
- if err != nil {
- return nil, err
- }
- return tnet.Dial(network, address)
-}
-
-// DialContext connects to the address on the named network using
-// the provided context.
-//
-// The underlying tun is created just once upon successive invocations of
-// DialContext.
-func (td *TunDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
- td.mu.Lock()
- defer td.mu.Unlock()
- if td.tun == nil {
- tnet, err := td.createNetTUN(ctx)
- if err != nil {
- return nil, err
- }
- td.tun = tnet
- }
- return td.tun.DialContext(ctx, network, address)
-}
-
-// DialTimeout acts like Dial but takes a timeout.
-func (td *TunDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
- conn, err := td.Dial(network, address)
- if err != nil {
- return nil, err
- }
- err = conn.SetReadDeadline(time.Now().Add(timeout))
- return conn, err
-}
-
-func (td *TunDialer) createNetTUN(ctx context.Context) (*netstack.Net, error) {
- localIP := td.client.LocalAddr().String()
-
- // create a virtual device in userspace, courtesy of wireguard-go
- tun, tnet, err := netstack.CreateNetTUN(
- []netip.Addr{netip.Addr(netip.MustParseAddr(localIP))},
- []netip.Addr{
- netip.MustParseAddr(td.ns1),
- netip.MustParseAddr(td.ns2)},
- td.client.tunInfo.mtu-100,
- )
- // TODO(https://github.com/ooni/minivpn/issues/26):
- // we cannot use the tun-mtu that the remote advertises, so we subtract
- // a "safety" margin for the time being.
-
- if err != nil {
- return nil, err
- }
-
- // connect the virtual device to our openvpn tunnel
- if !td.skipDeviceSetup {
- dev := &device{tun, td.client}
- dev.Up()
- td.device = dev
- }
- return tnet, nil
-}
-
-// device contains the two halves of the tunnel that we are connecting in our
-// toy implementation: the virtual tun device that is handled by netstack, and
-// the vpn.Client (that satisfies a net.Conn) that writes and reads to sockets
-// provided by the kernel.
-type device struct {
- tun tun.Device
- vpn net.Conn
-}
-
-// Up spawns two goroutines that communicate the two halves of a device.
-// TODO(https://github.com/ooni/minivpn/issues/27): we probably want a way of
-// shutting them down too.
-func (d *device) Up() {
- go func() {
- b := make([]byte, 4096)
- bufs := [][]byte{b}
- sizes := []int{4096}
- for {
- n, err := d.tun.Read(bufs, sizes, 0) // zero offset
- if err != nil {
- logger.Errorf("tun read error: %v", err)
- break
- }
- _, err = d.vpn.Write(b[0:n])
- if err != nil {
- logger.Errorf("vpn write error: %v", err)
- break
- }
-
- }
- }()
- go func() {
- b := make([]byte, 4096)
- for {
- n, err := d.vpn.Read(b)
- if err != nil {
- logger.Errorf("vpn read error: %v", err)
- break
- }
-
- _, err = d.tun.Write([][]byte{b[0:n]}, 0) // zero offset
- if err != nil {
- logger.Errorf("tun write error: %v", err)
- break
- }
- }
- }()
-}
diff --git a/vpn/dialer_test.go b/vpn/dialer_test.go
deleted file mode 100644
index ed5d361c..00000000
--- a/vpn/dialer_test.go
+++ /dev/null
@@ -1,449 +0,0 @@
-package vpn
-
-import (
- "bytes"
- "context"
- "errors"
- "net"
- "net/netip"
- "reflect"
- "testing"
- "time"
-
- "github.com/ooni/minivpn/vpn/mocks"
- tls "github.com/refraction-networking/utls"
- "golang.zx2c4.com/wireguard/tun/netstack"
-)
-
-func makeTestingClient(opt *Options) *Client {
- client := &Client{Opts: opt}
- client.conn = makeTestingConnForHandshake("udp", "10.0.0.1", 42)
- client.tunInfo = &tunnelInfo{ip: "10.0.0.1", mtu: 1500}
- client.mux = &mockMuxerForClient{}
- return client
-}
-
-func TestNewTunDialer(t *testing.T) {
- opt := makeTestingOptions(t, "AES-128-GCM", "sha512")
- mockClient := makeTestingClient(opt)
- type args struct {
- client *Client
- }
- tests := []struct {
- name string
- args args
- want *TunDialer
- }{
- {
- name: "get dialer ok",
- args: args{
- client: mockClient,
- },
- want: &TunDialer{
- client: mockClient,
- ns1: openDNSPrimary,
- ns2: openDNSSecondary,
- },
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got := NewTunDialer(tt.args.client)
- if tt.want != nil && got == nil {
- t.Errorf("expected non-nil result")
- return
- }
- if got.client == nil {
- t.Errorf("client should not be nil")
- return
- }
- if !reflect.DeepEqual(got.client.Opts, tt.want.client.Opts) {
- t.Errorf("NewTunDialerFromOptions() = %v, want %v", got, tt.want)
- return
- }
- })
- }
-}
-
-func TestNewTunDialerWithNameservers(t *testing.T) {
- opt := makeTestingOptions(t, "AES-128-GCM", "sha512")
- mockClient := makeTestingClient(opt)
- type args struct {
- client *Client
- ns1 string
- ns2 string
- }
- tests := []struct {
- name string
- args args
- want *TunDialer
- }{
- {
- name: "get tundialer with passed nameservers",
- args: args{
- client: mockClient,
- ns1: "8.8.8.8",
- ns2: "8.8.4.4",
- },
- want: &TunDialer{
- client: mockClient,
- ns1: "8.8.8.8",
- ns2: "8.8.4.4",
- },
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got := NewTunDialerWithNameservers(tt.args.client, tt.args.ns1, tt.args.ns2)
- if tt.want != nil && got == nil {
- t.Errorf("expected non-nil result")
- return
- }
- if got.client == nil {
- t.Errorf("client should not be nil")
- return
- }
- if got.ns1 != tt.want.ns1 {
- t.Errorf("NewTunDialerWithNameservers() ns1 = %v, want %v", got.ns1, tt.want.ns1)
- }
- if got.ns2 != tt.want.ns2 {
- t.Errorf("NewTunDialerWithNameservers() ns2 = %v, want %v", got.ns2, tt.want.ns2)
- }
- })
- }
-}
-
-type mockDialer struct {
- called bool
-}
-
-func (d *mockDialer) DialContext(ctx context.Context, a, b string) (net.Conn, error) {
- d.called = true
- conn := makeTestingConnForHandshake("udp", "10.0.0.0", 42)
- return conn, nil
-}
-
-func TestStartNewTunDialerFromOptions(t *testing.T) {
- opt := makeTestingOptions(t, "AES-128-GCM", "sha512")
-
- type args struct {
- opt *Options
- dialer *mockDialer
- }
- tests := []struct {
- name string
- args args
- want *TunDialer
- wantErr error
- }{
- {
- name: "get tundialer from options calls start and fails on tls handshake",
- args: args{
- opt: opt,
- dialer: &mockDialer{},
- },
- want: nil,
- // TODO(ainghazal): I'd like to return nil here, but that would force
- // me to leak even more internals from the client
- // initialization. maybe it's not a good idea to have a
- // convenience function that returns an started client after all?
- wantErr: ErrBadTLSHandshake,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got, err := StartNewTunDialerFromOptions(tt.args.opt, tt.args.dialer)
- if !errors.Is(err, tt.wantErr) {
- t.Errorf("expected error %v, got %v", tt.wantErr, err)
- return
- }
- if tt.want == nil && got == nil {
- return
- }
- if tt.want != nil && got != nil {
- t.Errorf("expected non-nil result")
- return
- }
- if tt.want != nil || got.client == nil {
- t.Errorf("client should not be nil")
- return
- }
- if !tt.args.dialer.called {
- t.Errorf("the mock Dialer has not been called")
- return
- }
- if !reflect.DeepEqual(got.client.Opts, tt.want.client.Opts) {
- t.Errorf("NewTunDialerFromOptions() = %v, want %v", got, tt.want)
- return
- }
- })
- }
-}
-
-type mockedDialerContext struct{}
-
-func (md *mockedDialerContext) DialContext(context.Context, string, string) (net.Conn, error) {
- conn := makeTestingConnForHandshake("udp", "10.0.0.0", 42)
- return conn, nil
-}
-
-func makeTestingConnForReadWrite(network, addr string, n int) net.Conn {
- mockAddr := &mocks.Addr{}
- mockAddr.MockString = func() string {
- return addr
- }
- mockAddr.MockNetwork = func() string {
- return network
- }
-
- mockConn := &mocks.Conn{}
- mockConn.MockLocalAddr = func() net.Addr {
- return mockAddr
- }
- mockConn.MockWrite = func([]byte) (int, error) {
- return n, nil
- }
- mockConn.MockRead = func(b []byte) (int, error) {
- switch mockConn.Count {
- case 0:
- // control message data (to load remote key)
- p := []byte{0x00, 0x00, 0x00, 0x00, 0x02}
- p = append(p, bytes.Repeat([]byte{0x01}, 70)...)
- copy(b[:], p)
- mockConn.Count += 1
- return len(p), nil
- case 1:
- // control message data (pushed options)
- p := []byte("PUSH_REPLY,ifconfig 2.2.2.2")
- copy(b[:], p)
- mockConn.Count += 1
- return len(p), nil
- }
-
- return 0, nil
- }
- return mockConn
-}
-
-// TODO(https://github.com/ooni/minivpn/issues/28):
-// refactor test to use custom dialers
-func TestTunDialer_Dial(t *testing.T) {
- opt := makeTestingOptions(t, "AES-128-GCM", "sha512")
- mockClient := makeTestingClient(opt)
-
- orig := initTLSFn
- defer func() {
- initTLSFn = orig
- }()
-
- initTLSFn = func(*session, *certConfig) (*tls.Config, error) {
- return &tls.Config{InsecureSkipVerify: true}, nil
- }
- tlsHandshakeFn = func(tc *controlChannelTLSConn, tconf *tls.Config) (net.Conn, error) {
- conn := makeTestingConnForReadWrite("udp", "10.1.1.1", 42)
- return conn, nil
- }
-
- type fields struct {
- Dialer DialerContext
- client *Client
- ns1 string
- ns2 string
- skipDeviceSetup bool
- }
- type args struct {
- network string
- address string
- }
- tests := []struct {
- name string
- fields fields
- args args
- want net.Conn
- wantErr error
- }{
- {
- name: "dial ok with mocked dialFn",
- fields: fields{
- client: mockClient,
- ns1: "8.8.8.8",
- ns2: "8.8.4.4",
- skipDeviceSetup: true,
- },
- args: args{
- network: "udp",
- address: "10.0.88.88:443",
- },
- wantErr: nil,
- },
- // TODO: Add test cases.
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- td := TunDialer{
- client: tt.fields.client,
- ns1: tt.fields.ns1,
- ns2: tt.fields.ns2,
- skipDeviceSetup: tt.fields.skipDeviceSetup,
- }
- conn, err := td.Dial(tt.args.network, tt.args.address)
- if !errors.Is(err, tt.wantErr) {
- t.Errorf("TunDialer.Dial() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- conn.Close()
- })
- }
-}
-
-// TODO(https://github.com/ooni/minivpn/issues/28):
-// refactor test to use custom dialers
-func TestTunDialer_DialTimeout(t *testing.T) {
- opt := makeTestingOptions(t, "AES-128-GCM", "sha512")
- mockClient := makeTestingClient(opt)
-
- orig := initTLSFn
- defer func() {
- initTLSFn = orig
- }()
- initTLSFn = func(*session, *certConfig) (*tls.Config, error) {
- return &tls.Config{InsecureSkipVerify: true}, nil
- }
- tlsHandshakeFn = func(tc *controlChannelTLSConn, tconf *tls.Config) (net.Conn, error) {
- conn := makeTestingConnForReadWrite("udp", "10.1.1.1", 42)
- return conn, nil
- }
- type fields struct {
- client *Client
- ns1 string
- ns2 string
- skipDeviceSetup bool
- }
- type args struct {
- network string
- address string
- timeout time.Duration
- }
- tests := []struct {
- name string
- fields fields
- args args
- want net.Conn
- wantErr error
- }{
- {
- name: "dial ok with mocked dialFn",
- fields: fields{
- client: mockClient,
- ns1: "8.8.8.8",
- ns2: "8.8.4.4",
- skipDeviceSetup: true,
- },
- args: args{
- network: "udp",
- address: "10.0.88.88:443",
- timeout: time.Second,
- },
- wantErr: nil,
- },
- // TODO: Add test cases.
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- td := TunDialer{
- client: tt.fields.client,
- ns1: tt.fields.ns1,
- ns2: tt.fields.ns2,
- skipDeviceSetup: true,
- }
- conn, err := td.DialTimeout(tt.args.network, tt.args.address, tt.args.timeout)
- if !errors.Is(err, tt.wantErr) {
- t.Errorf("TunDialer.DialTimeout() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- conn.Close()
- })
- }
-}
-
-// TODO(https://github.com/ooni/minivpn/issues/28):
-// refactor test to use custom dialers
-func TestTunDialer_DialContext(t *testing.T) {
- opt := makeTestingOptions(t, "AES-128-GCM", "sha512")
- mockClient := makeTestingClient(opt)
-
- orig := initTLSFn
- defer func() {
- initTLSFn = orig
- }()
- initTLSFn = func(*session, *certConfig) (*tls.Config, error) {
- return &tls.Config{InsecureSkipVerify: true}, nil
- }
- tlsHandshakeFn = func(tc *controlChannelTLSConn, tconf *tls.Config) (net.Conn, error) {
- conn := makeTestingConnForReadWrite("udp", "10.1.1.1", 42)
- return conn, nil
- }
-
- type fields struct {
- client *Client
- ns1 string
- ns2 string
- skipDeviceSetup bool
- }
- type args struct {
- ctx context.Context
- network string
- address string
- }
- tests := []struct {
- name string
- fields fields
- args args
- want net.Conn
- wantErr error
- }{
- {
- name: "dial ok with mocked dialer",
- fields: fields{
- client: mockClient,
- ns1: "8.8.8.8",
- ns2: "8.8.4.4",
- skipDeviceSetup: true,
- },
- args: args{
- ctx: context.Background(),
- network: "udp",
- address: "10.0.88.88:443",
- },
- wantErr: nil,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- td := TunDialer{
- client: tt.fields.client,
- ns1: tt.fields.ns1,
- ns2: tt.fields.ns2,
- skipDeviceSetup: true,
- }
- conn, err := td.DialContext(tt.args.ctx, tt.args.network, tt.args.address)
- if !errors.Is(err, tt.wantErr) {
- t.Errorf("TunDialer.DialContext() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- conn.Close()
- })
- }
-}
-
-func Test_device_Up(t *testing.T) {
- tun, _, _ := netstack.CreateNetTUN(
- []netip.Addr{netip.MustParseAddr("10.0.0.1")},
- []netip.Addr{
- netip.MustParseAddr("8.8.8.8"),
- netip.MustParseAddr("4.4.4.4")},
- 1500)
- vpn := makeTestinConnFromNetwork("udp")
- d := device{tun: tun, vpn: vpn}
- d.Up()
-}
diff --git a/vpn/doc.go b/vpn/doc.go
deleted file mode 100644
index 1818a6ad..00000000
--- a/vpn/doc.go
+++ /dev/null
@@ -1,21 +0,0 @@
-//
-// Package vpn contains the API to create an OpenVPN client that can connect to
-// a remote OpenVPN endpoint and provide you with a tunnel where to send packets.
-//
-// The recommended way to use this package is to use the TunDialer
-// constructors, that gives you a way to transparently Dial() and get TCP or
-// UDP sockets over a virtual gVisor interface, that uses the VPN tunnel as an
-// underlying transport. For examples, see the `proxy' implementation in the
-// `extras` package.
-//
-// If you need to write raw packets to the tunnel instead, you can construct
-// and use a `Client` object directly. `Client` is an implementer of the
-// `net.Conn` interface. You need to `Start()` the Client before you can Read
-// or Write packets to it. An example of this can be found in the `extras/ping`
-// package.
-//
-// Reads and Writes to the Client tunnel object are
-// actually reading and writing to the initialized Data channel of the Client.
-// Any incoming packet while reading that is not an OpenVPN data packet will be
-// dispatched accordingly.
-package vpn
diff --git a/vpn/events.go b/vpn/events.go
deleted file mode 100644
index 1b6e51ec..00000000
--- a/vpn/events.go
+++ /dev/null
@@ -1,20 +0,0 @@
-package vpn
-
-//
-// Catalog of the events that can be emitted, so that users of the library can
-// observe progress of the client's bootstrap.
-//
-
-// The events to be emitted. This is treated as an uint8, so if we ever go past
-// 255 events we need to accomodate the data type.
-const (
- EventReady = iota
- EventDialDone
- EventHandshake
- EventReset
- EventTLSConn
- EventTLSHandshake
- EventTLSHandshakeDone
- EventDataInitDone
- EventHandshakeDone
-)
diff --git a/vpn/logger.go b/vpn/logger.go
deleted file mode 100644
index 90066229..00000000
--- a/vpn/logger.go
+++ /dev/null
@@ -1,82 +0,0 @@
-package vpn
-
-//
-// Logging capabilities.
-//
-
-import (
- "log"
- "os"
-)
-
-// logger uses an implementation from the standard library in case the
-// binary does not set its own.
-var logger Logger = &defaultLogger{}
-
-// Logger is compatible with github.com/apex/log
-type Logger interface {
- // Debug emits a debug message.
- Debug(msg string)
-
- // Debugf formats and emits a debug message.
- Debugf(format string, v ...interface{})
-
- // Info emits an informational message.
- Info(msg string)
-
- // Infof formats and emits an informational message.
- Infof(format string, v ...interface{})
-
- // Warn emits a warning message.
- Warn(msg string)
-
- // Warnf formats and emits a warning message.
- Warnf(format string, v ...interface{})
-
- // Error emits an error message
- Error(msg string)
-
- // Errorf formats and emits an error message.
- Errorf(format string, v ...interface{})
-}
-
-// defaultLogger uses the standard log package for logs in case
-// the user does not provide a custom Log implementation.
-
-type defaultLogger struct{}
-
-func (dl *defaultLogger) Debug(msg string) {
- if os.Getenv("EXTRA_DEBUG") == "1" {
- log.Println(msg)
- }
-}
-
-func (dl *defaultLogger) Debugf(format string, v ...interface{}) {
- if os.Getenv("EXTRA_DEBUG") == "1" {
- log.Printf(format, v...)
- }
-}
-
-func (dl *defaultLogger) Info(msg string) {
- log.Printf("info : %s\n", msg)
-}
-
-func (dl *defaultLogger) Infof(format string, v ...interface{}) {
- log.Printf("info : "+format, v...)
-}
-
-func (dl *defaultLogger) Warn(msg string) {
- log.Printf("warn: %s\n", msg)
-}
-
-func (dl *defaultLogger) Warnf(format string, v ...interface{}) {
- log.Printf("warn: "+format, v...)
-}
-
-func (dl *defaultLogger) Error(msg string) {
- log.Printf("error: %s\n", msg)
-}
-
-func (dl *defaultLogger) Errorf(format string, v ...interface{}) {
- log.Printf("error: "+format, v...)
-}
diff --git a/vpn/logger_test.go b/vpn/logger_test.go
deleted file mode 100644
index b80a8f2d..00000000
--- a/vpn/logger_test.go
+++ /dev/null
@@ -1,19 +0,0 @@
-package vpn
-
-import (
- "os"
- "testing"
-)
-
-func TestDefaultLoggerDoesNotFail(t *testing.T) {
- os.Setenv("EXTRA_DEBUG", "1")
- logger := defaultLogger{}
- logger.Debug("foo")
- logger.Debugf("%s", "foo")
- logger.Info("foo")
- logger.Infof("%s", "foo")
- logger.Warn("foo")
- logger.Warnf("%s", "foo")
- logger.Error("foo")
- logger.Errorf("%s", "foo")
-}
diff --git a/vpn/muxer.go b/vpn/muxer.go
deleted file mode 100644
index b480f63e..00000000
--- a/vpn/muxer.go
+++ /dev/null
@@ -1,553 +0,0 @@
-package vpn
-
-import (
- "bytes"
- "context"
- "encoding/hex"
- "errors"
- "fmt"
- "log"
- "net"
-)
-
-//
-// OpenVPN Multiplexer
-//
-
-var (
- ErrBadHandshake = errors.New("bad vpn handshake")
- ErrBadDataHandshake = errors.New("bad data handshake")
-)
-
-/*
- The vpnMuxer interface represents the VPN transport multiplexer.
-
- One important limitation of the current implementation at this moment is that
- the processing of incoming packets needs to be driven by reads from the user of
- the library. This means that if you don't do reads during some time, any packets
- on the control channel that the server sends us (e.g., openvpn-pings) will not
- be processed (and so, not acknowledged) until triggered by a muxer.Read().
-
- From the original documentation:
- https://community.openvpn.net/openvpn/wiki/SecurityOverview
-
- "OpenVPN multiplexes the SSL/TLS session used for authentication and key
- exchange with the actual encrypted tunnel data stream. OpenVPN provides the
- SSL/TLS connection with a reliable transport layer (as it is designed to
- operate over). The actual IP packets, after being encrypted and signed with an
- HMAC, are tunnelled over UDP without any reliability layer. So if --proto udp
- is used, no IP packets are tunneled over a reliable transport, eliminating the
- problem of reliability-layer collisions -- Of course, if you are tunneling a
- TCP session over OpenVPN running in UDP mode, the TCP protocol itself will
- provide the reliability layer."
-
- SSL/TLS -> Reliability Layer -> \
- --tls-auth HMAC \
- \
- > Multiplexer ----> UDP/TCP
- / Transport
- IP Encrypt and HMAC /
- Tunnel -> using OpenSSL EVP --> /
- Packets interface.
-
-"This model has the benefit that SSL/TLS sees a reliable transport layer while
-the IP packet forwarder sees an unreliable transport layer -- exactly what both
-components want to see. The reliability and authentication layers are
-completely independent of one another, i.e. the sequence number is embedded
-inside the HMAC-signed envelope and is not used for authentication purposes."
-
-*/
-
-// muxer implements vpnMuxer
-type muxer struct {
-
- // A net.Conn that has access to the "wire" transport. this can
- // represent an UDP/TCP socket, or a net.Conn coming from a Pluggable
- // Transport etc.
- conn net.Conn
-
- // After completing the TLS handshake, we get a tls transport that implements
- // net.Conn. All the control packets from that moment on are read from
- // and written to the tls Conn.
- tls net.Conn
-
- // control and data are the handlers for the control and data channels.
- // they implement the methods needed for the handshake and handling of
- // packets.
- control controlHandler
- data dataHandler
-
- // bufReader is used to buffer data channel reads. We only write to
- // this buffer when we have correctly decrypted an incoming
- bufReader *bytes.Buffer
-
- // Mutable state tied to a concrete session.
- session *session
-
- // Mutable state tied to a particular vpn run.
- tunnel *tunnelInfo
-
- // Options are OpenVPN options that come from parsing a subset of the OpenVPN
- // configuration directives, plus some non-standard config directives.
- options *Options
-
- // eventListener is a channel to which Event_*- will be sent if
- // the channel is not nil.
- eventListener chan uint8
-
- failed bool
-}
-
-var _ vpnMuxer = &muxer{} // Ensure that we implement the vpnMuxer interface.
-
-//
-// Interfaces
-//
-
-// vpnMuxer contains all the behavior expected by the muxer.
-type vpnMuxer interface {
- Handshake(ctx context.Context) error
- Reset(net.Conn, *session) error
- InitDataWithRemoteKey() error
- SetEventListener(chan uint8)
- Write([]byte) (int, error)
- Read([]byte) (int, error)
-}
-
-// controlHandler manages the control "channel".
-type controlHandler interface {
- SendHardReset(net.Conn, *session) error
- ParseHardReset([]byte) (sessionID, error)
- SendACK(net.Conn, *session, packetID) error
- PushRequest() []byte
- ReadPushResponse([]byte) map[string][]string
- ControlMessage(*session, *Options) ([]byte, error)
- ReadControlMessage([]byte) (*keySource, string, error)
-}
-
-// dataHandler manages the data "channel".
-type dataHandler interface {
- SetupKeys(*dataChannelKey) error
- SetPeerID(int) error
- WritePacket(net.Conn, []byte) (int, error)
- ReadPacket(*packet) ([]byte, error)
- DecodeEncryptedPayload([]byte, *dataChannelState) (*encryptedData, error)
- EncryptAndEncodePayload([]byte, *dataChannelState) ([]byte, error)
-}
-
-//
-// muxer initialization
-//
-
-// muxFactory acepts a net.Conn, a pointer to an Options object, and another
-// pointer to a tunnelInfo object, and returns a vpnMuxer and an error if it
-// could not be initialized. This type is used to be able to mock a muxer while
-// testing the Client.
-type muxFactory func(conn net.Conn, options *Options, tunnel *tunnelInfo) (vpnMuxer, error)
-
-// newMuxerFromOptions returns a configured muxer, and any error if the
-// operation could not be completed.
-func newMuxerFromOptions(conn net.Conn, options *Options, tunnel *tunnelInfo) (vpnMuxer, error) {
- control := &control{}
- session, err := newSession()
- if err != nil {
- return &muxer{}, err
- }
- data, err := newDataFromOptions(options, session)
- if err != nil {
- return &muxer{}, err
- }
- br := bytes.NewBuffer(nil)
-
- m := &muxer{
- conn: conn,
- session: session,
- options: options,
- control: control,
- data: data,
- tunnel: tunnel,
- bufReader: br,
- }
- return m, nil
-}
-
-//
-// observability
-//
-
-// SetEvenSetEventListener assigns the passed channel as the event listener for
-// this muxer.
-func (m *muxer) SetEventListener(el chan uint8) {
- m.eventListener = el
-}
-
-// emit sends the passed stage into any configured EventListener
-func (m *muxer) emit(stage uint8) {
- select {
- case m.eventListener <- stage:
- default:
- // do not deliver
- }
-}
-
-//
-// muxer handshake
-//
-
-// Handshake performs the OpenVPN "handshake" operations serially. Accepts a
-// Context, and itt returns any error that is raised at any of the underlying
-// steps.
-func (m *muxer) Handshake(ctx context.Context) (err error) {
- errch := make(chan error, 1)
- go func() {
- errch <- m.handshake()
- }()
- select {
- case err = <-errch:
- case <-ctx.Done():
- err = ctx.Err()
- }
- return
-}
-
-func (m *muxer) handshake() error {
-
- // 1. control channel sends reset, parse response.
-
- m.emit(EventReset)
-
- if err := m.Reset(m.conn, m.session); err != nil {
- return fmt.Errorf("%w: %s", ErrBadHandshake, err)
-
- }
-
- // 2. TLS handshake.
-
- // TODO(ainghazal): move the initialization step to an early phase and keep a ref in the muxer
- if !m.options.hasAuthInfo() {
- return fmt.Errorf("%w: %s", errBadInput, "expected certificate or username/password")
- }
- certCfg, err := newCertConfigFromOptions(m.options)
- if err != nil {
- return err
- }
-
- tlsConf, err := initTLSFn(m.session, certCfg)
- if err != nil {
- return fmt.Errorf("%w: %s", ErrBadTLSHandshake, err)
-
- }
- tlsConn, err := newControlChannelTLSConn(m.conn, m.session)
- m.emit(EventTLSConn)
-
- if err != nil {
- return fmt.Errorf("%w: %s", ErrBadTLSHandshake, err)
- }
-
- m.emit(EventTLSHandshake)
-
- tls, err := tlsHandshakeFn(tlsConn, tlsConf)
- if err != nil {
- return fmt.Errorf("%w: %s", ErrBadTLSHandshake, err)
-
- }
-
- m.emit(EventTLSHandshakeDone)
-
- m.tls = tls
- logger.Info("TLS handshake done")
-
- // 3. data channel init (auth, push, data initialization).
-
- if err := m.InitDataWithRemoteKey(); err != nil {
- return fmt.Errorf("%w: %s", ErrBadDataHandshake, err)
-
- }
-
- m.emit(EventDataInitDone)
-
- logger.Info("VPN handshake done")
- return nil
-}
-
-// Reset sends a hard-reset packet to the server, and awaits the server
-// confirmation.
-func (m *muxer) Reset(conn net.Conn, s *session) error {
- if m.control == nil {
- return fmt.Errorf("%w:%s", errBadInput, "bad control")
- }
- if err := m.control.SendHardReset(conn, s); err != nil {
- return err
- }
-
- resp, err := readPacket(m.conn)
- if err != nil {
- return err
- }
-
- remoteSessionID, err := m.control.ParseHardReset(resp)
-
- // here we could check if we have received a remote session id but
- // our session.remoteSessionID is != from all zeros
- if err != nil {
- return err
- }
- m.session.RemoteSessionID = remoteSessionID
-
- logger.Infof("Remote session ID: %x", m.session.RemoteSessionID)
- logger.Infof("Local session ID: %x", m.session.LocalSessionID)
-
- // we assume id is 0, this is the first packet we ack.
- // XXX I could parse the real packet id from server instead. this
- // _might_ be important when re-keying?
- return m.control.SendACK(m.conn, m.session, packetID(0))
-}
-
-//
-// muxer: read and handle packets
-//
-
-// handleIncoming packet reads the next packet available in the underlying
-// socket. It returns true if the packet was a data packet; otherwise it will
-// process it but return false.
-func (m *muxer) handleIncomingPacket(data []byte) (bool, error) {
- if m.data == nil {
- logger.Errorf("uninitialized muxer")
- return false, errBadInput
- }
- var input []byte
- if data == nil {
- parsed, err := readPacket(m.conn)
- if err != nil {
- return false, err
- }
- input = parsed
- } else {
- input = data
- }
-
- if isPing(input) {
- err := handleDataPing(m.conn, m.data)
- if err != nil {
- logger.Errorf("cannot handle ping: %s", err.Error())
- }
- return false, nil
- }
-
- p, err := parsePacketFromBytes(input)
- if err != nil {
- logger.Error(err.Error())
- return false, err
- }
- if p.isACK() {
- logger.Warn("muxer: got ACK (ignored)")
- return false, err
- }
- if p.isControl() {
- logger.Infof("Got control packet: %d", len(data))
- // Here the server might be requesting us to reset, or to
- // re-key (but I keep ignoring that case for now).
- // we're doing nothing for now.
- fmt.Println(hex.Dump(p.payload))
- return false, err
- }
- if !p.isData() {
- logger.Warnf("unhandled data. (op: %d)", p.opcode)
- fmt.Println(hex.Dump(data))
- return false, err
- }
-
- // at this point, the incoming packet should be
- // a data packet that needs to be processed
- // (decompress+decrypt)
-
- plaintext, err := m.data.ReadPacket(p)
- if err != nil {
- logger.Errorf("bad decryption: %s", err.Error())
- // XXX I'm not sure returning false is the right thing to do here.
- return false, err
- }
-
- // all good! we write the plaintext into the read buffer.
- // the caller is responsible for reading from there.
- m.bufReader.Write(plaintext)
- return true, nil
-}
-
-// handleDataPing replies to an openvpn-ping with a canned response.
-func handleDataPing(conn net.Conn, data dataHandler) error {
- log.Println("openvpn-ping, sending reply")
- _, err := data.WritePacket(conn, pingPayload)
- return err
-}
-
-// readTLSPacket reads a packet over the TLS connection.
-func (m *muxer) readTLSPacket() ([]byte, error) {
- data := make([]byte, 4096)
- _, err := m.tls.Read(data)
- return data, err
-}
-
-// readAndLoadRemoteKey reads one incoming TLS packet, and tries to parse the
-// response contained in it. If the server response is the right kind of
-// packet, it will store the remote key and the parts of the remote options
-// that will be of use later.
-func (m *muxer) readAndLoadRemoteKey() error {
- data, err := m.readTLSPacket()
- if err != nil {
- return err
- }
- if !isControlMessage(data) {
- return fmt.Errorf("%w: %s", errBadControlMessage, "expected null header")
- }
-
- // Parse the received data: we expect remote key and remote options.
- remoteKey, remoteOptStr, err := m.control.ReadControlMessage(data)
- if err != nil {
- logger.Errorf("cannot parse control message")
- return fmt.Errorf("%w: %s", ErrBadHandshake, err)
- }
-
- // Store the remote key.
- key, err := m.session.ActiveKey()
- if err != nil {
- logger.Errorf("cannot get active key")
- return fmt.Errorf("%w: %s", ErrBadHandshake, err)
- }
- err = key.addRemoteKey(remoteKey)
- if err != nil {
- logger.Errorf("cannot add remote key")
- return fmt.Errorf("%w: %s", ErrBadHandshake, err)
- }
-
- // Parse and update the useful fields from the remote options (mtu).
- ti := newTunnelInfoFromRemoteOptionsString(remoteOptStr)
- m.tunnel.mtu = ti.mtu
- return nil
-}
-
-// sendPushRequest sends a push request over the TLS channel.
-func (m *muxer) sendPushRequest() (int, error) {
- return m.tls.Write(m.control.PushRequest())
-}
-
-// readPushReply reads one incoming TLS packet, where we expect to find the
-// response to our push request. If the server response is the right kind of
-// packet, it will store the parts of the pushed options that will be of use
-// later.
-func (m *muxer) readPushReply() error {
- if m.control == nil || m.tunnel == nil {
- return fmt.Errorf("%w:%s", errBadInput, "muxer badly initialized")
-
- }
- resp, err := m.readTLSPacket()
- if err != nil {
- return err
- }
-
- logger.Info("Server pushed options")
-
- if isBadAuthReply(resp) {
- return errBadAuth
- }
-
- if !isPushReply(resp) {
- return fmt.Errorf("%w:%s", errBadServerReply, "expected push reply")
- }
-
- optsMap := m.control.ReadPushResponse(resp)
- ti := newTunnelInfoFromPushedOptions(optsMap)
-
- m.tunnel.ip = ti.ip
- m.tunnel.gw = ti.gw
- m.tunnel.peerID = ti.peerID
-
- logger.Infof("Tunnel IP: %s", m.tunnel.ip)
- logger.Infof("Gateway IP: %s", m.tunnel.gw)
- logger.Infof("Peer ID: %d", m.tunnel.peerID)
-
- return nil
-}
-
-// sendControl message sends a control message over the TLS channel.
-func (m *muxer) sendControlMessage() error {
- cm, err := m.control.ControlMessage(m.session, m.options)
- if err != nil {
- return err
- }
- if _, err := m.tls.Write(cm); err != nil {
- return err
- }
- return nil
-}
-
-// InitDataWithRemoteKey initializes the internal data channel. To do that, it sends a
-// control packet, parses the response, and derives the cryptographic material
-// that will be used to encrypt and decrypt data through the tunnel. At the end
-// of this exchange, the data channel is ready to be used.
-func (m *muxer) InitDataWithRemoteKey() error {
-
- // 1. first we send a control message.
-
- if err := m.sendControlMessage(); err != nil {
- return err
- }
-
- // 2. then we read the server response and load the remote key.
-
- if err := m.readAndLoadRemoteKey(); err != nil {
- return err
- }
-
- // 3. now we can initialize the data channel.
-
- key0, err := m.session.ActiveKey()
- if err != nil {
- return err
- }
-
- err = m.data.SetupKeys(key0)
- if err != nil {
- return err
- }
-
- // 4. finally, we ask the server to push remote options to us. we parse
- // them and keep some useful info.
-
- if _, err := m.sendPushRequest(); err != nil {
- return err
- }
- if err := m.readPushReply(); err != nil {
- return err
- }
-
- m.data.SetPeerID(m.tunnel.peerID)
-
- return nil
-}
-
-// Write sends user bytes as encrypted packets in the data channel. It returns
-// the number of written bytes, and an error if the operation could not succeed.
-func (m *muxer) Write(b []byte) (int, error) {
- if m.data == nil {
- return 0, fmt.Errorf("%w:%s", errBadInput, "data not initialized")
-
- }
- return m.data.WritePacket(m.conn, b)
-}
-
-// Read reads bytes after decrypting packets from the data channel. This is the
-// user-view of the VPN connection reads. It returns the number of bytes read,
-// and an error if the operation could not succeed.
-func (m *muxer) Read(b []byte) (int, error) {
- for {
- ok, err := m.handleIncomingPacket(nil)
- if err != nil {
- return 0, err
- }
- if ok {
- break
- }
- }
- return m.bufReader.Read(b)
-}
diff --git a/vpn/muxer_test.go b/vpn/muxer_test.go
deleted file mode 100644
index c03ca7b8..00000000
--- a/vpn/muxer_test.go
+++ /dev/null
@@ -1,606 +0,0 @@
-package vpn
-
-import (
- "bytes"
- "context"
- "errors"
- "net"
- "reflect"
- "testing"
-
- "github.com/ooni/minivpn/vpn/mocks"
- tls "github.com/refraction-networking/utls"
-)
-
-func Test_newMuxerFromOptions(t *testing.T) {
- randomFn = func(int) ([]byte, error) {
- return []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, nil
- }
- testSession, _ := newSession()
-
- type args struct {
- conn net.Conn
- options *Options
- tunnel *tunnelInfo
- }
- tests := []struct {
- name string
- args args
- want *muxer
- wantErr error
- }{
- {
- name: "get muxer ok",
- args: args{
- conn: makeTestingConnForWrite("udp", "10.0.42.2", 42),
- options: makeTestingOptions(t, "AES-128-GCM", "sha1"),
- tunnel: &tunnelInfo{},
- },
- want: &muxer{
- conn: makeTestingConnForWrite("udp", "10.0.42.2", 42),
- control: &control{},
- session: testSession,
- options: makeTestingOptions(t, "AES-128-GCM", "sha1"),
- },
- wantErr: nil,
- },
- // TODO: Add more test cases:
- // failure on newSession()
- // failure in newData()
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- _, err := newMuxerFromOptions(tt.args.conn, tt.args.options, tt.args.tunnel)
- if !errors.Is(err, tt.wantErr) {
- t.Errorf("newMuxerFromOptions() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- // TODO(ainghazal): we cannot compare the options because the paths for the certs are going to be different
- // I think this calls from separating the initial options from a more structured config
- // with the parsed, loaded certs instead.
- })
- }
-}
-
-func makeTestingConnForHandshake(network, addr string, n int) net.Conn {
- ma := &mocks.Addr{}
- ma.MockString = func() string {
- return addr
- }
- ma.MockNetwork = func() string {
- return network
- }
-
- c := &mocks.Conn{}
- c.MockLocalAddr = func() net.Addr {
- return ma
- }
- c.MockWrite = func([]byte) (int, error) {
- return n, nil
- }
- c.MockRead = func(b []byte) (int, error) {
- switch c.Count {
- case 0:
- // this is the expected reset response from server
- rp := []byte{
- 0x40,
- 0x00, 0x01, 0x02, 0x03, 0x04,
- 0x05, 0x06, 0x07, 0x08,
- }
- copy(b[:], rp)
- c.Count += 1
- return len(rp), nil
- case 1:
- // control message data (to load remote key)
- p := []byte{0x00, 0x00, 0x00, 0x00, 0x02}
- p = append(p, bytes.Repeat([]byte{0x01}, 70)...)
- copy(b[:], p)
- c.Count += 1
- return len(p), nil
- case 2:
- // control message data (to load remote key)
- p := []byte("PUSH_REPLY")
- copy(b[:], p)
- c.Count += 1
- return len(p), nil
- }
-
- return 0, nil
- }
- c.MockClose = func() error {
- return nil
- }
- return c
-}
-
-type mockMuxerForHandshake struct {
- muxer
-}
-
-func (md *mockMuxerForHandshake) sendControlMessage() error {
- return nil
-}
-
-func (md *mockMuxerForHandshake) readAndLoadRemoteKey() error {
- return nil
-}
-
-type mockMuxerWithDummyHandshake struct {
- mockMuxerForHandshake
-}
-
-func (md *mockMuxerWithDummyHandshake) Handshake(context.Context) error {
- return nil
-}
-
-func Test_muxer_Handshake(t *testing.T) {
- makeData := func() *data {
- options := makeTestingOptions(t, "AES-128-GCM", "sha1")
- data, _ := newDataFromOptions(options, makeTestingSession())
- return data
- }
-
- m := &mockMuxerForHandshake{}
- m.control = &control{}
- m.data = makeData()
- m.tunnel = &tunnelInfo{}
- s, err := newSession()
- if err != nil {
- t.Error("session failed, cannot run handshake test")
- }
- m.session = s
- m.options = makeTestingOptions(t, "AES-128-GCM", "sha512")
- m.tls = makeTestingConnForWrite("udp", "0.0.0.0", 42)
- m.conn = makeTestingConnForHandshake("udp", "10.0.0.0", 42)
-
- origInit := initTLSFn
- origHandshake := tlsHandshakeFn
-
- defer func() {
- initTLSFn = origInit
- tlsHandshakeFn = origHandshake
- }()
-
- // monkey patch the global functions
-
- initTLSFn = func(*session, *certConfig) (*tls.Config, error) {
- return &tls.Config{InsecureSkipVerify: true}, nil
- }
- tlsHandshakeFn = func(tc *controlChannelTLSConn, tconf *tls.Config) (net.Conn, error) {
- return m.conn, nil
- }
-
- // and now for the test itself...
-
- err = m.Handshake(context.Background())
- if err != nil {
- t.Errorf("muxer.Handshake() error = %v, wantErr nil", err)
- return
- }
-}
-
-func makePacketForHandleIncomingTest(opcode byte, s *session) *packet {
- p := &packet{
- id: packetID(1), // always a good packet for a clean session
- opcode: opcode,
- keyID: 0x00,
- payload: []byte("aaa"),
- localSessionID: s.LocalSessionID,
- remoteSessionID: s.RemoteSessionID,
- acks: []packetID{},
- }
- return p
-}
-
-// I have modified muxer.handleIncomingPacket() so that it optionally receives a []byte
-// in order to make it easier to test payloads. here we go:
-type mockDataHandler struct{}
-
-func (m *mockDataHandler) SetupKeys(*dataChannelKey) error {
- return nil
-}
-
-func (m *mockDataHandler) WritePacket(net.Conn, []byte) (int, error) {
- return 42, nil
-}
-
-func (m *mockDataHandler) ReadPacket(*packet) ([]byte, error) {
- return []byte("alles ist gut"), nil
-}
-
-func (m *mockDataHandler) DecodeEncryptedPayload([]byte, *dataChannelState) (*encryptedData, error) {
- return &encryptedData{}, nil
-}
-
-func (m *mockDataHandler) EncryptAndEncodePayload([]byte, *dataChannelState) ([]byte, error) {
- return []byte("this is not a payload"), nil
-}
-
-func (m *mockDataHandler) SetPeerID(int) error {
- return nil
-}
-
-type mockDataHandlerBadReadPacket struct {
- mockDataHandler
-}
-
-func (m *mockDataHandlerBadReadPacket) ReadPacket(*packet) ([]byte, error) {
- dummy := errors.New("dummy error")
- return []byte{}, dummy
-}
-
-var _ dataHandler = &mockData{}
-
-func Test_muxer_handleIncomingPacket(t *testing.T) {
- m := muxer{
- data: &mockData{},
- bufReader: &bytes.Buffer{},
- }
-
- // ping data
- if ok, _ := m.handleIncomingPacket(pingPayload); ok {
- t.Errorf("muxer.handleIncomingPacket(): expected !ok with ping payload")
- return
- }
- // packets with different opcodes
- if ok, _ := m.handleIncomingPacket([]byte{}); ok {
- t.Errorf("muxer.handleIncomingPacket(): expected !ok with empty bytes")
- return
- }
- p := &packet{opcode: pACKV1}
- if ok, _ := m.handleIncomingPacket(p.Bytes()); ok {
- t.Errorf("muxer.handleIncomingPacket(): expected !ok with ack packet")
- return
- }
- p = &packet{opcode: pControlV1}
- if ok, _ := m.handleIncomingPacket(p.Bytes()); ok {
- t.Errorf("muxer.handleIncomingPacket(): expected !ok with control packet")
- return
- }
- p = &packet{opcode: pControlV1}
- if ok, _ := m.handleIncomingPacket(p.Bytes()); ok {
- t.Errorf("muxer.handleIncomingPacket(): expected !ok with control packet")
- return
- }
- p = &packet{opcode: byte(0xff)}
- if ok, _ := m.handleIncomingPacket(p.Bytes()); ok {
- t.Errorf("muxer.handleIncomingPacket(): expected !ok with unknown opcode")
- return
- }
- p = &packet{opcode: pDataV1}
- if ok, _ := m.handleIncomingPacket(p.Bytes()); !ok {
- t.Errorf("muxer.handleIncomingPacket(): expected ok with data opcode")
- return
- }
-
- // replace dataHandler in muxer with a method that raises error on ReadPacket()
- t.Run("error in ReadPacket() should propagate", func(t *testing.T) {
- m = muxer{
- data: &mockDataHandlerBadReadPacket{},
- bufReader: &bytes.Buffer{},
- }
- p = &packet{opcode: pDataV1}
- if ok, _ := m.handleIncomingPacket(p.Bytes()); ok {
- t.Errorf("muxer.handleIncomingPacket(): expected !ok with error in ReadPacket()")
- }
- })
-
- t.Run("null data raises error", func(t *testing.T) {
- m = muxer{
- data: nil,
- bufReader: &bytes.Buffer{},
- }
- p = &packet{opcode: pDataV1}
- ok, err := m.handleIncomingPacket(p.Bytes())
- if ok {
- t.Errorf("muxer.handleIncomingPacket(): expected !ok with null data")
- }
- if err != errBadInput {
- t.Errorf("muxer.handleIncomingPacket(): expected errBadInput")
- }
- })
-}
-
-func Test_muxer_Write(t *testing.T) {
-
- makeData := func() *data {
- options := makeTestingOptions(t, "AES-128-GCM", "sha1")
- data, _ := newDataFromOptions(options, makeTestingSession())
- return data
- }
-
- type fields struct {
- conn net.Conn
- data dataHandler
- }
- type args struct {
- b []byte
- }
- tests := []struct {
- name string
- fields fields
- args args
- want int
- wantErr error
- }{
- {
- name: "write fails if data state is not initialized",
- fields: fields{
- conn: makeTestinConnFromNetwork("udp"),
- data: &data{},
- },
- args: args{[]byte("alles ist bad")},
- want: 0,
- wantErr: errBadInput,
- },
- {
- name: "write fails if data state is nil",
- fields: fields{
- conn: makeTestinConnFromNetwork("udp"),
- data: nil,
- },
- args: args{[]byte("alles ist bad")},
- want: 0,
- wantErr: errBadInput,
- },
- {
- name: "write calls data.WritePacket",
- fields: fields{
- conn: makeTestingConnForWrite("udp", "10.0.1.1", 42),
- data: makeData(),
- },
- args: args{[]byte("alles ist gut")},
- want: 42,
- wantErr: nil,
- },
-
- // TODO can add more tests:
- // [ ] check that the error raised by the underlying data read is the error we
- // expect to be returned.
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- m := &muxer{
- conn: tt.fields.conn,
- data: tt.fields.data,
- }
- got, err := m.Write(tt.args.b)
- if !errors.Is(err, tt.wantErr) {
- t.Errorf("muxer.Write() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if got != tt.want {
- t.Errorf("muxer.Write() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func makeTestingConnForRead(retInt int, retErr error, payload []byte) net.Conn {
- ma := &mocks.Addr{}
- ma.MockString = func() string {
- return "10.0.42.2"
- }
- ma.MockNetwork = func() string {
- return "udp"
- }
-
- mc := &mocks.Conn{}
- mc.MockLocalAddr = func() net.Addr {
- return ma
- }
- mc.MockRead = func(b []byte) (int, error) {
- copy(b[:], payload)
- return retInt, retErr
- }
- return mc
-}
-
-type mockData struct {
- data
-}
-
-func (md *mockData) ReadPacket(*packet) ([]byte, error) {
- return []byte("alles ist gut"), nil
-}
-
-func Test_muxer_Read(t *testing.T) {
- // XXX(ainghazal): I'm not sure this is a very good test.
- // what I want to test:
- // - that I call readPacket(mockConn) - I'm assuming we get a good data packet
- // - that I call data.ReadPacket(p)
- // - that we get the right return from muxer.Read()
- // - that the expected buffer is written into the buffer that we pass to Read()
-
- testDataPacket := &packet{opcode: pDataV1, payload: []byte("discarded")}
- bufData := "alles ist gut"
-
- b := make([]byte, 4096)
- want := len(bufData)
- m := &muxer{
- conn: makeTestingConnForRead(want, nil, testDataPacket.Bytes()),
- data: &mockData{},
- bufReader: bytes.NewBuffer(nil),
- }
- got, err := m.Read(b)
- if err != nil {
- t.Errorf("muxer.Read() error = %v, wantErr nil", err)
- return
- }
- if got != want {
- t.Errorf("muxer.Read() = %v, want %v", got, want)
- }
- if !bytes.Equal(b[:len(bufData)], []byte(bufData)) {
- t.Errorf("muxer.Read() = %v, want %v", string(b[:len(bufData)]), string(bufData))
- }
-}
-
-func Test_muxer_readTLSPacket(t *testing.T) {
- type fields struct {
- conn net.Conn
- tls net.Conn
- control controlHandler
- data dataHandler
- bufReader *bytes.Buffer
- session *session
- tunnel *tunnelInfo
- options *Options
- }
- tests := []struct {
- name string
- fields fields
- want []byte
- wantErr bool
- }{
- // TODO: Add test cases.
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- m := &muxer{
- conn: tt.fields.conn,
- tls: tt.fields.tls,
- control: tt.fields.control,
- data: tt.fields.data,
- bufReader: tt.fields.bufReader,
- session: tt.fields.session,
- tunnel: tt.fields.tunnel,
- options: tt.fields.options,
- }
- got, err := m.readTLSPacket()
- if (err != nil) != tt.wantErr {
- t.Errorf("muxer.readTLSPacket() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if !reflect.DeepEqual(got, tt.want) {
- t.Errorf("muxer.readTLSPacket() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func Test_muxer_readAndLoadRemoteKey(t *testing.T) {
- type fields struct {
- conn net.Conn
- tls net.Conn
- control controlHandler
- data dataHandler
- bufReader *bytes.Buffer
- session *session
- tunnel *tunnelInfo
- options *Options
- }
- tests := []struct {
- name string
- fields fields
- wantErr bool
- }{
- // TODO: Add test cases.
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- m := &muxer{
- conn: tt.fields.conn,
- tls: tt.fields.tls,
- control: tt.fields.control,
- data: tt.fields.data,
- bufReader: tt.fields.bufReader,
- session: tt.fields.session,
- tunnel: tt.fields.tunnel,
- options: tt.fields.options,
- }
- if err := m.readAndLoadRemoteKey(); (err != nil) != tt.wantErr {
- t.Errorf("muxer.readAndLoadRemoteKey() error = %v, wantErr %v", err, tt.wantErr)
- }
- })
- }
-}
-
-func Test_muxer_readPushReply(t *testing.T) {
- type fields struct {
- conn net.Conn
- tls net.Conn
- control controlHandler
- data dataHandler
- bufReader *bytes.Buffer
- session *session
- tunnel *tunnelInfo
- options *Options
- }
- tests := []struct {
- name string
- fields fields
- wantErr error
- }{
- {
- name: "control == nil should return error",
- fields: fields{
- control: nil,
- },
- wantErr: errBadInput,
- },
- {
- name: "tunnel == nil should return error",
- fields: fields{
- tunnel: nil,
- },
- wantErr: errBadInput,
- },
- // TODO: Add moar test cases.
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- m := &muxer{
- conn: tt.fields.conn,
- tls: tt.fields.tls,
- control: tt.fields.control,
- data: tt.fields.data,
- bufReader: tt.fields.bufReader,
- session: tt.fields.session,
- tunnel: tt.fields.tunnel,
- options: tt.fields.options,
- }
- if err := m.readPushReply(); !errors.Is(err, tt.wantErr) {
- t.Errorf("muxer.readPushReply() error = %v, wantErr %v", err, tt.wantErr)
- }
- })
- }
-}
-
-func Test_muxer_emitSendsToListener(t *testing.T) {
- t.Run("emit writes event if listener not null", func(t *testing.T) {
- l := make(chan uint8, 2)
- m := &muxer{}
- m.SetEventListener(l)
- sent := uint8(2)
- m.emit(sent)
- got := <-l
- if got != sent {
- t.Errorf("expected %v, got %v", sent, got)
- }
- })
- t.Run("emit is a noop if evenlistener not set", func(t *testing.T) {
- m := &muxer{}
- sent := uint8(2)
- m.emit(sent)
- })
- t.Run("listener receives several events", func(t *testing.T) {
- l := make(chan uint8, 5)
- m := &muxer{}
- m.SetEventListener(l)
- received := []uint8{}
- sent := []uint8{1, 2, 3, 4, 5}
- for _, i := range sent {
- m.emit(i)
- }
- for _ = range sent {
- got := <-l
- received = append(received, got)
- }
- for i := range sent {
- if sent[i] != received[i] {
- t.Errorf("at [%d]: expected %v, got %v", i, sent, received)
- return
- }
- }
- })
-}
diff --git a/vpn/options.go b/vpn/options.go
deleted file mode 100644
index 8f57fcbc..00000000
--- a/vpn/options.go
+++ /dev/null
@@ -1,673 +0,0 @@
-package vpn
-
-//
-// Parse VPN options.
-//
-// Mostly, this file conforms to the format in the reference implementation.
-// However, there are some additions that are specific. To avoid feature creep
-// and fat dependencies, the main `vpn` module only supports mainline
-// capabilities. It is still useful to carry all options in a single type,
-// so it's up to the user of this library to do something useful with
-// such options. The `extra` package provides some of these extra features, like
-// obfuscation support.
-//
-// Following the configuration format in the reference implementation, `minivpn`
-// allows including files in the main configuration file, but only for the `ca`,
-// `cert` and `key` options.
-//
-// Each inline file is started by the line .
-//
-// Here is an example of an inline file usage:
-//
-// ```
-//
-// -----BEGIN CERTIFICATE-----
-// [...]
-// -----END CERTIFICATE-----
-//
-// ```
-
-import (
- "bufio"
- "bytes"
- "errors"
- "fmt"
- "log"
- "os"
- "path/filepath"
- "strconv"
- "strings"
-)
-
-type (
- // compression describes a compression type (e.g., stub).
- compression string
-)
-
-const (
- // compressionStub adds the (empty) compression stub to the packets.
- compressionStub = compression("stub")
-
- // compressionEmpty is the empty compression.
- compressionEmpty = compression("empty")
-
- // compressionLZONo is lzo-no (another type of no-compression, older).
- compressionLZONo = compression("lzo-no")
-)
-
-type (
- // proto is the main vpn mode (e.g., TCP or UDP).
- proto string
-)
-
-func (p proto) String() string {
- return string(p)
-}
-
-const (
- // protoTCP is used for vpn in TCP mode.
- protoTCP = proto("tcp")
-
- // protoUDP is used for vpn in UDP mode.
- protoUDP = proto("udp")
-)
-
-var (
- // errBadCfg is the generic error returned for invalid config files
- errBadCfg = errors.New("bad config")
-)
-
-var supportedCiphers = []string{
- "AES-128-CBC",
- "AES-192-CBC",
- "AES-256-CBC",
- "AES-128-GCM",
- "AES-192-GCM",
- "AES-256-GCM",
-}
-
-var supportedAuth = []string{
- "SHA1",
- "SHA256",
- "SHA512",
-}
-
-// Options make all the relevant configuration options accessible to the
-// different modules that need it.
-type Options struct {
- Remote string
- Port string
- //TODO(https://github.com/ooni/minivpn/issues/25): Proto should be changed to a string and checked against known types.
- Proto int
- Username string
- Password string
- CaPath string
- CertPath string
- KeyPath string
- Ca []byte
- Cert []byte
- Key []byte
- Compress compression
- Cipher string
- Auth string
- TLSMaxVer string
- // below are options that do not conform to the OpenVPN configuration format.
- ProxyOBFS4 string
- Log Logger
-}
-
-// NewOptionsFromFilePath expects a string with a path to a valid config file,
-// and returns a pointer to a Options struct after parsing the file, and an
-// error if the operation could not be completed.
-func NewOptionsFromFilePath(filePath string) (*Options, error) {
- lines, err := getLinesFromFile(filePath)
- dir, _ := filepath.Split(filePath)
- if err != nil {
- return nil, err
- }
- return getOptionsFromLines(lines, dir)
-}
-
-// certsFromPath returns true when the options object is configured to load
-// certificates from paths; false when we have inline certificates.
-func (o *Options) certsFromPath() bool {
- return o.CertPath != "" && o.KeyPath != "" && o.CaPath != ""
-}
-
-// hasAuthInfo returns true if:
-// - we have paths for cert, key and ca; or
-// - we have inline byte arrays for cert, key and ca; or
-// - we have username + password info.
-func (o *Options) hasAuthInfo() bool {
- if o.CertPath != "" && o.KeyPath != "" && o.CaPath != "" {
- return true
- }
- if len(o.Cert) != 0 && len(o.Key) != 0 && len(o.Ca) != 0 {
- return true
- }
- if o.Username != "" && o.Password != "" {
- return true
- }
- return false
-}
-
-const clientOptions = "V4,dev-type tun,link-mtu 1549,tun-mtu 1500,proto %sv4,cipher %s,auth %s,keysize %s,key-method 2,tls-client"
-
-// String produces a comma-separated representation of the options, in the same
-// order and format that the openvpn server expects from us.
-func (o *Options) String() string {
- if o.Cipher == "" {
- return ""
- }
- keysize := strings.Split(o.Cipher, "-")[1]
- proto := strings.ToUpper(protoUDP.String())
- if o.Proto == TCPMode {
- proto = strings.ToUpper(protoTCP.String())
- }
- s := fmt.Sprintf(
- clientOptions,
- proto, o.Cipher, o.Auth, keysize)
- if o.Compress == compressionStub {
- s = s + ",compress stub"
- } else if o.Compress == "lzo-no" {
- s = s + ",lzo-comp no"
- } else if o.Compress == compressionEmpty {
- s = s + ",compress"
- }
- logger.Debugf("Local opts: %s", s)
- return s
-}
-
-// newTunnelInfoFromRemoteOptionsString parses the options string returned by
-// server. it returns a new tunnel object where the needed fields have been
-// updated. At the moment, we only parse the tun-mtu parameter.
-func newTunnelInfoFromRemoteOptionsString(remoteOpts string) *tunnelInfo {
- t := &tunnelInfo{}
- opts := strings.Split(remoteOpts, ",")
- for _, opt := range opts {
- vals := strings.Split(opt, " ")
- if len(vals) < 2 {
- continue
- }
- k, v := vals[0], vals[1:]
- if k == "tun-mtu" {
- mtu, err := strconv.Atoi(v[0])
- if err != nil {
- log.Println("bad mtu:", err)
- continue
- }
- t.mtu = mtu
- }
- }
- return t
-}
-
-// newTunnelInfoFromPushedOptions takes a map of string to array of strings, and returns
-// a new tunnel struct with the relevant info.
-func newTunnelInfoFromPushedOptions(opts map[string][]string) *tunnelInfo {
- t := &tunnelInfo{}
- if r := opts["route"]; len(r) >= 1 {
- t.gw = r[0]
- } else if r := opts["route-gateway"]; len(r) >= 1 {
- t.gw = r[0]
- }
- ip := opts["ifconfig"]
- if len(ip) >= 1 {
- t.ip = ip[0]
- }
- peerID := opts["peer-id"]
- if len(peerID) == 1 {
- i, err := parseIntFromOption(peerID[0])
- if err == nil {
- t.peerID = i
- } else {
- log.Println("Cannot parse peer-id:", err.Error())
- }
- }
- return t
-}
-
-// parseIntFromOption parses an int from a null-terminated string
-func parseIntFromOption(s string) (int, error) {
- str := ""
- for i := 0; i < len(s); i++ {
- if byte(s[i]) == 0x00 {
- return strconv.Atoi(str)
- }
- str = str + string(s[i])
- }
- return 0, nil
-}
-
-// pushedOptionsAsMap returns a map for the server-pushed options,
-// where the options are the keys and each space-separated value is the value.
-// This function always returns an initialized map, even if empty.
-func pushedOptionsAsMap(pushedOptions []byte) map[string][]string {
- optMap := make(map[string][]string)
- if pushedOptions == nil || len(pushedOptions) == 0 {
- return optMap
- }
-
- optStr := string(pushedOptions[:len(pushedOptions)-1])
-
- opts := strings.Split(optStr, ",")
- for _, opt := range opts {
- vals := strings.Split(opt, " ")
- k, v := vals[0], vals[1:]
- optMap[k] = v
- }
- return optMap
-}
-
-func parseProto(p []string, o *Options) error {
- if len(p) != 1 {
- return fmt.Errorf("%w: %s", errBadCfg, "proto needs one arg")
- }
- m := p[0]
- switch m {
- case protoUDP.String():
- o.Proto = UDPMode
- case protoTCP.String():
- o.Proto = TCPMode
- default:
- return fmt.Errorf("%w: bad proto: %s", errBadCfg, m)
-
- }
- return nil
-}
-
-// TODO(ainghazal): all these little functions can be better tested if we return the options object too
-
-func parseRemote(p []string, o *Options) error {
- if len(p) != 2 {
- return fmt.Errorf("%w: %s", errBadCfg, "remote needs two args")
- }
- o.Remote, o.Port = p[0], p[1]
- return nil
-}
-
-func parseCipher(p []string, o *Options) error {
- if len(p) != 1 {
- return fmt.Errorf("%w: %s", errBadCfg, "cipher expects one arg")
- }
- cipher := p[0]
- if !hasElement(cipher, supportedCiphers) {
- return fmt.Errorf("%w: unsupported cipher: %s", errBadCfg, cipher)
- }
- o.Cipher = cipher
- return nil
-}
-
-func parseAuth(p []string, o *Options) error {
- if len(p) != 1 {
- return fmt.Errorf("%w: %s", errBadCfg, "invalid auth entry")
- }
- auth := p[0]
- if !hasElement(auth, supportedAuth) {
- return fmt.Errorf("%w: unsupported auth: %s", errBadCfg, auth)
- }
- o.Auth = auth
- return nil
-}
-
-func parseCA(p []string, o *Options, basedir string) error {
- e := fmt.Errorf("%w: %s", errBadCfg, "ca expects a valid file")
- if len(p) != 1 {
- return e
- }
- ca := toAbs(p[0], basedir)
- if sub, _ := isSubdir(basedir, ca); !sub {
- return fmt.Errorf("%w: %s", errBadCfg, "ca must be below config path")
- }
- if !existsFile(ca) {
- return e
- }
- o.CaPath = ca
- return nil
-}
-
-func parseCert(p []string, o *Options, basedir string) error {
- e := fmt.Errorf("%w: %s", errBadCfg, "cert expects a valid file")
- if len(p) != 1 {
- return e
- }
- cert := toAbs(p[0], basedir)
- if sub, _ := isSubdir(basedir, cert); !sub {
- return fmt.Errorf("%w: %s", errBadCfg, "cert must be below config path")
- }
- if !existsFile(cert) {
- return e
- }
- o.CertPath = cert
- return nil
-}
-
-func parseKey(p []string, o *Options, basedir string) error {
- e := fmt.Errorf("%w: %s", errBadCfg, "key expects a valid file")
- if len(p) != 1 {
- return e
- }
- key := toAbs(p[0], basedir)
- if sub, _ := isSubdir(basedir, key); !sub {
- return fmt.Errorf("%w: %s", errBadCfg, "key must be below config path")
- }
- if !existsFile(key) {
- return e
- }
- o.KeyPath = key
- return nil
-}
-
-// parseAuthUser reads credentials from a given file, according to the openvpn
-// format (user and pass on a line each). To avoid path traversal / LFI, the
-// credentials file is expected to be in a subdirectory of the base dir.
-func parseAuthUser(p []string, o *Options, basedir string) error {
- e := fmt.Errorf("%w: %s", errBadCfg, "auth-user-pass expects a valid file")
- if len(p) != 1 {
- return e
- }
- auth := toAbs(p[0], basedir)
- if sub, _ := isSubdir(basedir, auth); !sub {
- return fmt.Errorf("%w: %s", errBadCfg, "auth must be below config path")
- }
- if !existsFile(auth) {
- return e
- }
- creds, err := getCredentialsFromFile(auth)
- if err != nil {
- return err
- }
- o.Username, o.Password = creds[0], creds[1]
- return nil
-}
-
-func parseCompress(p []string, o *Options) error {
- if len(p) > 1 {
- return fmt.Errorf("%w: %s", errBadCfg, "compress: only empty/stub options supported")
- }
- if len(p) == 0 {
- o.Compress = compressionEmpty
- return nil
- }
- if p[0] == "stub" {
- o.Compress = compressionStub
- return nil
- }
- return fmt.Errorf("%w: %s", errBadCfg, "compress: only empty/stub options supported")
-}
-
-func parseCompLZO(p []string, o *Options) error {
- if p[0] != "no" {
- return fmt.Errorf("%w: %s", errBadCfg, "comp-lzo: compression not supported")
- }
- o.Compress = "lzo-no"
- return nil
-}
-
-// parseTLSVerMax sets the maximum TLS version. This is currently ignored
-// because we're using uTLS to parrot the Client Hello.
-func parseTLSVerMax(p []string, o *Options) error {
- if o == nil {
- return errBadInput
- }
- if len(p) == 0 {
- o.TLSMaxVer = "1.3"
- return nil
- }
- if p[0] == "1.2" {
- o.TLSMaxVer = "1.2"
- }
- return nil
-}
-
-func parseProxyOBFS4(p []string, o *Options) error {
- if len(p) != 1 {
- return fmt.Errorf("%w: %s", errBadCfg, "proto-obfs4: need a properly configured proxy")
- }
- // TODO(ainghazal): can validate the obfs4://... scheme here
- o.ProxyOBFS4 = p[0]
- return nil
-}
-
-var pMap = map[string]interface{}{
- "proto": parseProto,
- "remote": parseRemote,
- "cipher": parseCipher,
- "auth": parseAuth,
- "compress": parseCompress,
- "comp-lzo": parseCompLZO,
- "proxy-obfs4": parseProxyOBFS4,
- "tls-version-max": parseTLSVerMax, // this is currently ignored because of uTLS
-}
-
-var pMapDir = map[string]interface{}{
- "ca": parseCA,
- "cert": parseCert,
- "key": parseKey,
- "auth-user-pass": parseAuthUser,
-}
-
-func parseOption(o *Options, dir, key string, p []string, lineno int) error {
- switch key {
- case "proto", "remote", "cipher", "auth", "compress", "comp-lzo", "tls-version-max", "proxy-obfs4":
- fn := pMap[key].(func([]string, *Options) error)
- if e := fn(p, o); e != nil {
- return e
- }
- case "ca", "cert", "key", "auth-user-pass":
- fn := pMapDir[key].(func([]string, *Options, string) error)
- if e := fn(p, o, dir); e != nil {
- return e
- }
- default:
- log.Printf("warn: unsupported key in line %d\n", lineno)
- }
- return nil
-}
-
-// getOptionsFromLines tries to parse all the lines coming from a config file
-// and raises validation errors if the values do not conform to the expected
-// format.
-// the config file supports inline file inclusion for , and .
-func getOptionsFromLines(lines []string, dir string) (*Options, error) {
- opt := &Options{}
-
- // tag and inlineBuf are used to parse inline files.
- // these follow the format used by the reference openvpn implementation.
- // each block (any of ca, key, cert) is marked by a line; lines in between are expected to contain
- // the crypto block.
- tag := ""
- inlineBuf := new(bytes.Buffer)
-
- for lineno, l := range lines {
- if strings.HasPrefix(l, "#") {
- continue
- }
- l = strings.TrimSpace(l)
-
- // inline certs
- if isClosingTag(l) {
- // we expect an already existing inlineBuf
- e := parseInlineTag(opt, tag, inlineBuf)
- if e != nil {
- return nil, e
- }
- tag = ""
- inlineBuf = new(bytes.Buffer)
- continue
- }
- if tag != "" {
- inlineBuf.Write([]byte(l))
- inlineBuf.Write([]byte("\n"))
- continue
- }
- if isOpeningTag(l) {
- if len(inlineBuf.Bytes()) != 0 {
- // something wrong: an opening tag should not be found
- // when we still have bytes in the inline buffer.
- return opt, fmt.Errorf("%w: %s", errBadInput, "tag not closed")
- }
- tag = parseTag(l)
- continue
- }
-
- // parse parts in the same line
- p := strings.Split(l, " ")
- if len(p) == 0 {
- continue
- }
- var (
- key string
- parts []string
- )
- if len(p) == 1 {
- key = p[0]
- } else {
- key, parts = p[0], p[1:]
- }
- e := parseOption(opt, dir, key, parts, lineno)
- if e != nil {
- return nil, e
- }
- }
- return opt, nil
-}
-
-func isOpeningTag(key string) bool {
- switch key {
- case "", "", "":
- return true
- default:
- return false
- }
-}
-
-func isClosingTag(key string) bool {
- switch key {
- case "", "", "":
- return true
- default:
- return false
- }
-}
-
-func parseTag(tag string) string {
- switch tag {
- case "", "":
- return "ca"
- case "", "":
- return "cert"
- case "", "":
- return "key"
- default:
- return ""
- }
-}
-
-// parseInlineTag
-func parseInlineTag(o *Options, tag string, buf *bytes.Buffer) error {
- b := buf.Bytes()
- if len(b) == 0 {
- return fmt.Errorf("%w: empty inline tag: %d", errBadInput, len(b))
- }
- switch tag {
- case "ca":
- o.Ca = b
- case "cert":
- o.Cert = b
- case "key":
- o.Key = b
- default:
- return fmt.Errorf("%w: unknown tag: %s", errBadInput, tag)
-
- }
- return nil
-}
-
-// hasElement checks if a given string is present in a string array. returns
-// true if that is the case, false otherwise.
-func hasElement(el string, arr []string) bool {
- for _, v := range arr {
- if v == el {
- return true
- }
- }
- return false
-}
-
-// existsFile returns true if the file to which the path refers to exists and
-// is a regular file.
-func existsFile(path string) bool {
- statbuf, err := os.Stat(path)
- return !errors.Is(err, os.ErrNotExist) && statbuf.Mode().IsRegular()
-}
-
-// getLinesFromFile accepts a path parameter, and return a string array with
-// its content and an error if the operation cannot be completed.
-func getLinesFromFile(path string) ([]string, error) {
- f, err := os.Open(path) //#nosec G304
- defer func() {
- if err := f.Close(); err != nil {
- logger.Errorf("Error closing file: %s\n", err)
- }
- }()
- if err != nil {
- return nil, err
- }
-
- lines := make([]string, 0)
- scanner := bufio.NewScanner(f)
- for scanner.Scan() {
- lines = append(lines, scanner.Text())
- }
- err = scanner.Err()
- if err != nil {
- return nil, err
- }
- return lines, nil
-}
-
-// getCredentialsFromFile accepts a path string parameter, and return a string
-// array containing the credentials in that file, and an error if the operation
-// could not be completed.
-func getCredentialsFromFile(path string) ([]string, error) {
- lines, err := getLinesFromFile(path)
- if err != nil {
- return nil, fmt.Errorf("%w: %s", errBadCfg, err)
- }
- if len(lines) != 2 {
- return nil, fmt.Errorf("%w: %s", errBadCfg, "malformed credentials file")
- }
- if len(lines[0]) == 0 {
- return nil, fmt.Errorf("%w: %s", errBadCfg, "empty username in creds file")
- }
- if len(lines[1]) == 0 {
- return nil, fmt.Errorf("%w: %s", errBadCfg, "empty password in creds file")
- }
- return lines, nil
-}
-
-// toAbs return an absolute path if the given path is not already absolute; to
-// do so, it will append the path to the given basedir.
-func toAbs(path, basedir string) string {
- if filepath.IsAbs(path) {
- return path
- }
- return filepath.Join(basedir, path)
-}
-
-// isSubdir checks if a given path is a subdirectory of another. It returns
-// true if that's the case, and any error raise during the check.
-func isSubdir(parent, sub string) (bool, error) {
- p, err := filepath.Abs(parent)
- if err != nil {
- return false, err
- }
- s, err := filepath.Abs(sub)
- if err != nil {
- return false, err
- }
- return strings.HasPrefix(s, p), nil
-}
diff --git a/vpn/options_test.go b/vpn/options_test.go
deleted file mode 100644
index 9c7d072f..00000000
--- a/vpn/options_test.go
+++ /dev/null
@@ -1,971 +0,0 @@
-package vpn
-
-import (
- "errors"
- "os"
- fp "path/filepath"
- "reflect"
- "testing"
-)
-
-func writeDummyCertFiles(d string) {
- os.WriteFile(fp.Join(d, "ca.crt"), []byte("dummy"), 0600)
- os.WriteFile(fp.Join(d, "cert.pem"), []byte("dummy"), 0600)
- os.WriteFile(fp.Join(d, "key.pem"), []byte("dummy"), 0600)
-}
-
-func TestOptions_String(t *testing.T) {
- type fields struct {
- Remote string
- Port string
- Proto int
- Username string
- Password string
- Ca string
- Cert string
- Key string
- Compress compression
- Cipher string
- Auth string
- TLSMaxVer string
- ProxyOBFS4 string
- Log Logger
- }
- tests := []struct {
- name string
- fields fields
- want string
- }{
- {
- name: "empty cipher",
- fields: fields{},
- want: "",
- },
- {
- name: "proto tcp",
- fields: fields{
- Cipher: "AES-128-GCM",
- Auth: "sha512",
- Proto: 1,
- },
- want: "V4,dev-type tun,link-mtu 1549,tun-mtu 1500,proto TCPv4,cipher AES-128-GCM,auth sha512,keysize 128,key-method 2,tls-client",
- },
- {
- name: "compress stub",
- fields: fields{
- Cipher: "AES-128-GCM",
- Auth: "sha512",
- Proto: 2,
- Compress: compressionStub,
- },
- want: "V4,dev-type tun,link-mtu 1549,tun-mtu 1500,proto UDPv4,cipher AES-128-GCM,auth sha512,keysize 128,key-method 2,tls-client,compress stub",
- },
- {
- name: "compress lzo-no",
- fields: fields{
- Cipher: "AES-128-GCM",
- Auth: "sha512",
- Proto: 2,
- Compress: compressionLZONo,
- },
- want: "V4,dev-type tun,link-mtu 1549,tun-mtu 1500,proto UDPv4,cipher AES-128-GCM,auth sha512,keysize 128,key-method 2,tls-client,lzo-comp no",
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- o := &Options{
- Remote: tt.fields.Remote,
- Port: tt.fields.Port,
- Proto: tt.fields.Proto,
- Username: tt.fields.Username,
- Password: tt.fields.Password,
- CaPath: tt.fields.Ca,
- CertPath: tt.fields.Cert,
- KeyPath: tt.fields.Key,
- Compress: tt.fields.Compress,
- Cipher: tt.fields.Cipher,
- Auth: tt.fields.Auth,
- TLSMaxVer: tt.fields.TLSMaxVer,
- ProxyOBFS4: tt.fields.ProxyOBFS4,
- Log: tt.fields.Log,
- }
- if got := o.String(); got != tt.want {
- t.Errorf("Options.string() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func TestGetOptionsFromLines(t *testing.T) {
- d := t.TempDir()
- l := []string{
- "remote 0.0.0.0 1194",
- "cipher AES-256-GCM",
- "auth SHA512",
- "ca ca.crt",
- "cert cert.pem",
- "key cert.pem",
- }
- writeDummyCertFiles(d)
- o, err := getOptionsFromLines(l, d)
- if err != nil {
- t.Errorf("Good options should not fail: %s", err)
- }
- if o.Cipher != "AES-256-GCM" {
- t.Errorf("Cipher not what expected")
- }
- if o.Auth != "SHA512" {
- t.Errorf("Auth not what expected")
- }
-}
-
-func TestGetOptionsFromLinesInlineCerts(t *testing.T) {
- l := []string{
- "",
- "ca_string",
- "",
- "",
- "cert_string",
- "",
- "",
- "key_string",
- "",
- }
- o, err := getOptionsFromLines(l, "")
- if err != nil {
- t.Errorf("Good options should not fail: %s", err)
- }
- if string(o.Ca) != "ca_string\n" {
- t.Errorf("Expected ca_string, got: %s.", string(o.Ca))
- }
- if string(o.Cert) != "cert_string\n" {
- t.Errorf("Expected cert_string, got: %s.", string(o.Cert))
- }
- if string(o.Key) != "key_string\n" {
- t.Errorf("Expected key_string, got: %s.", string(o.Key))
- }
-}
-
-func TestGetOptionsFromLinesNoFiles(t *testing.T) {
- d := t.TempDir()
- l := []string{
- "ca ca.crt",
- }
- _, err := getOptionsFromLines(l, d)
- if err == nil {
- t.Errorf("Should fail if no files provided")
- }
-}
-
-func TestGetOptionsNoCompression(t *testing.T) {
- d := t.TempDir()
- l := []string{
- "compress",
- }
- // should fail if no certs
- // writeDummyCertFiles(d)
- o, err := getOptionsFromLines(l, d)
- if err != nil {
- t.Errorf("Should not fail: compress")
- }
- if o.Compress != "empty" {
- t.Errorf("Expected compress==empty")
- }
-}
-
-func TestGetOptionsCompressionStub(t *testing.T) {
- d := t.TempDir()
- l := []string{
- "compress stub",
- }
- // should fail if no certs
- // writeDummyCertFiles(d)
- o, err := getOptionsFromLines(l, d)
- if err != nil {
- t.Errorf("Should not fail: compress stub")
- }
- if o.Compress != "stub" {
- t.Errorf("expected compress==stub")
- }
-}
-
-func TestGetOptionsCompressionBad(t *testing.T) {
- d := t.TempDir()
- l := []string{
- "compress foo",
- }
- // should fail if no certs
- // writeDummyCertFiles(d)
- _, err := getOptionsFromLines(l, d)
- if err == nil {
- t.Errorf("Unknown compress: should fail")
- }
-}
-
-func TestGetOptionsCompressLZO(t *testing.T) {
- d := t.TempDir()
- l := []string{
- "comp-lzo no",
- }
- // should fail if no certs
- // writeDummyCertFiles(d)
- o, err := getOptionsFromLines(l, d)
- if err != nil {
- t.Errorf("Should not fail: lzo-comp no")
- }
- if o.Compress != "lzo-no" {
- t.Errorf("expected compress=lzo-no")
- }
-}
-
-func TestGetOptionsBadRemote(t *testing.T) {
- d := t.TempDir()
- l := []string{
- "remote",
- }
- // should fail if no certs
- // writeDummyCertFiles(d)
- _, err := getOptionsFromLines(l, d)
- if err == nil {
- t.Errorf("Should fail: malformed remote")
- }
-}
-
-func TestGetOptionsBadCipher(t *testing.T) {
- d := t.TempDir()
- l := []string{
- "cipher",
- }
- // should fail if no certs
- // writeDummyCertFiles(d)
- _, err := getOptionsFromLines(l, d)
- if err == nil {
- t.Errorf("Should fail: malformed cipher")
- }
- l = []string{
- "cipher AES-111-CBC",
- }
- _, err = getOptionsFromLines(l, d)
- if err == nil {
- t.Errorf("Should fail: bad cipher")
- }
-}
-
-func TestGetOptionsComment(t *testing.T) {
- d := t.TempDir()
- l := []string{
- "cipher AES-256-GCM",
- "#cipher AES-128-GCM",
- }
- // should fail if no certs
- // writeDummyCertFiles(d)
- o, err := getOptionsFromLines(l, d)
- if err != nil {
- t.Errorf("Should not fail: commented line")
- }
- if o.Cipher != "AES-256-GCM" {
- t.Errorf("Expected cipher: AES-256-GCM")
- }
-}
-
-var dummyConfigFile = []byte(`proto udp
-cipher AES-128-GCM
-auth SHA1`)
-
-func writeDummyConfigFile(dir string) (string, error) {
- f, err := os.CreateTemp(dir, "tmpfile-")
- if err != nil {
- return "", err
- }
- f.Write(dummyConfigFile)
- return f.Name(), nil
-}
-
-func Test_ParseConfigFile(t *testing.T) {
- // parse good file
- f, err := writeDummyConfigFile(t.TempDir())
- if err != nil {
- t.Fatal("ParseConfigFile(): cannot write cert needed for the test")
- }
- o, err := NewOptionsFromFilePath(f)
- if err != nil {
- t.Errorf("ParseConfigFile(): expected err=%v, got=%v", nil, err)
- }
- wantProto := UDPMode
- if o.Proto != wantProto {
- t.Errorf("ParseConfigFile(): expected Proto=%v, got=%v", wantProto, o.Proto)
- }
- wantCipher := "AES-128-GCM"
- if o.Cipher != wantCipher {
- t.Errorf("ParseConfigFile(): expected=%v, got=%v", wantCipher, o.Cipher)
- }
-
- // expect error when parsing a bad filepath
- _, err = NewOptionsFromFilePath("")
- if err == nil {
- t.Errorf("expected error with empty file")
- }
-
- _, err = NewOptionsFromFilePath("http://example.com")
- if err == nil {
- t.Errorf("expected error with http uri")
- }
-}
-
-func Test_parseProto(t *testing.T) {
- // empty parts
- err := parseProto([]string{}, &Options{})
- wantErr := errBadCfg
- if !errors.Is(err, wantErr) {
- t.Errorf("parseProto(): wantErr: %v, got %v", wantErr, err)
- }
-
- // two parts
- err = parseProto([]string{"foo", "bar"}, &Options{})
- wantErr = errBadCfg
- if !errors.Is(err, wantErr) {
- t.Errorf("parseProto(): wantErr %v, got %v", wantErr, err)
- }
-
- // udp
- opt := &Options{}
- err = parseProto([]string{"udp"}, opt)
- if !errors.Is(err, nil) {
- t.Errorf("parseProto(): wantErr: %v, got %v", nil, err)
- }
- if opt.Proto != UDPMode {
- t.Errorf("parseProto(): wantErr %v, got %v", nil, err)
- }
-
- // tcp
- opt = &Options{}
- err = parseProto([]string{"tcp"}, opt)
- if !errors.Is(err, nil) {
- t.Errorf("parseProto(): wantErr: %v, got %v", nil, err)
- }
- if opt.Proto != TCPMode {
- t.Errorf("parseProto(): wantErr %v, got %v", nil, err)
- }
-
- // bad
- opt = &Options{}
- err = parseProto([]string{"kcp"}, opt)
- wantErr = errBadCfg
- if !errors.Is(err, wantErr) {
- t.Errorf("parseProto(): wantErr: %v, got %v", errBadCfg, err)
- }
-
-}
-
-func Test_parseProxyOBFS4(t *testing.T) {
- // empty parts
- err := parseProxyOBFS4([]string{}, &Options{})
- wantErr := errBadCfg
- if !errors.Is(err, wantErr) {
- t.Errorf("parseProxyOBFS4(): wantErr: %v, got %v", wantErr, err)
- }
-
- // obfs4 string
- opt := &Options{}
- obfs4Uri := "obfs4://foobar"
- err = parseProxyOBFS4([]string{obfs4Uri}, opt)
- wantErr = nil
- if !errors.Is(err, wantErr) {
- t.Errorf("parseProxyOBFS4(): wantErr: %v, got %v", wantErr, err)
- }
- if opt.ProxyOBFS4 != obfs4Uri {
- t.Errorf("parseProxyOBFS4(): want %v, got %v", obfs4Uri, opt.ProxyOBFS4)
- }
-
-}
-
-func Test_parseCA(t *testing.T) {
- // more than one part should fail
- err := parseCA([]string{"one", "two"}, &Options{}, "")
- wantErr := errBadCfg
- if !errors.Is(err, wantErr) {
- t.Errorf("parseCA(): want %v, got %v", wantErr, err)
- }
-
- // empty part should fail
- err = parseCA([]string{}, &Options{}, "")
- wantErr = errBadCfg
- if !errors.Is(err, wantErr) {
- t.Errorf("parseCA(): want %v, got %v", wantErr, err)
- }
-}
-
-func Test_parseCert(t *testing.T) {
- // more than one part should fail
- err := parseCert([]string{"one", "two"}, &Options{}, "")
- wantErr := errBadCfg
- if !errors.Is(err, wantErr) {
- t.Errorf("parseCert(): want %v, got %v", wantErr, err)
- }
-
- // empty part should fail
- err = parseCert([]string{}, &Options{}, "")
- wantErr = errBadCfg
- if !errors.Is(err, wantErr) {
- t.Errorf("parseCert(): want %v, got %v", wantErr, err)
- }
-
- // non-existent cert should fail
- err = parseCert([]string{"/tmp/nonexistent"}, &Options{}, "")
- wantErr = errBadCfg
- if !errors.Is(err, wantErr) {
- t.Errorf("parseCert(): want %v, got %v", wantErr, err)
- }
-}
-
-func Test_parseKey(t *testing.T) {
- // more than one part should fail
- err := parseKey([]string{"one", "two"}, &Options{}, "")
- wantErr := errBadCfg
- if !errors.Is(err, wantErr) {
- t.Errorf("parseKey(): want %v, got %v", wantErr, err)
- }
-
- // empty part should fail
- err = parseKey([]string{}, &Options{}, "")
- wantErr = errBadCfg
- if !errors.Is(err, wantErr) {
- t.Errorf("parseKey(): want %v, got %v", wantErr, err)
- }
-
- // non-existent key should fail
- err = parseKey([]string{"/tmp/nonexistent"}, &Options{}, "")
- wantErr = errBadCfg
- if !errors.Is(err, wantErr) {
- t.Errorf("parseKey(): want %v, got %v", wantErr, err)
- }
-}
-
-func Test_parseCompress(t *testing.T) {
- // more than one part should fail
- err := parseCompress([]string{"one", "two"}, &Options{})
- wantErr := errBadCfg
- if !errors.Is(err, wantErr) {
- t.Errorf("parseCompress(): want %v, got %v", wantErr, err)
- }
-}
-
-func Test_parseCompLZO(t *testing.T) {
- // only "no" is supported
- err := parseCompLZO([]string{"yes"}, &Options{})
- wantErr := errBadCfg
- if !errors.Is(err, wantErr) {
- t.Errorf("parseCompLZO(): want %v, got %v", wantErr, err)
- }
-}
-
-func Test_parseOption(t *testing.T) {
- // unknown key should not fail
- err := parseOption(&Options{}, t.TempDir(), "unknownKey", []string{"a", "b"}, 0)
- if err != nil {
- t.Errorf("parseOption(): want %v, got %v", nil, err)
- }
-}
-
-func Test_newTunnelInfoFromRemoteOptionsString(t *testing.T) {
- type args struct {
- remoteOpts string
- }
- tests := []struct {
- name string
- args args
- want *tunnelInfo
- }{
- {
- name: "parse good tun-mtu",
- args: args{
- remoteOpts: "foo bar,tun-mtu 1500",
- },
- want: &tunnelInfo{
- mtu: 1500,
- },
- },
- {
- name: "empty string",
- args: args{
- remoteOpts: "",
- },
- want: &tunnelInfo{},
- },
- {
- name: "empty field",
- args: args{
- remoteOpts: "tun-mtu 1200,,",
- },
- want: &tunnelInfo{
- mtu: 1200,
- },
- },
- {
- name: "extra space",
- args: args{
- remoteOpts: "tun-mtu 1200",
- },
- want: &tunnelInfo{},
- },
- {
- name: "mtu not an int",
- args: args{
- remoteOpts: "tun-mtu aaa",
- },
- want: &tunnelInfo{},
- },
- {
- name: "entry with single field",
- args: args{
- remoteOpts: "bad",
- },
- want: &tunnelInfo{},
- },
- {
- name: "entries with single field",
- args: args{
- remoteOpts: "bad,worse",
- },
- want: &tunnelInfo{},
- },
- {
- name: "tun-mtu no value",
- args: args{
- remoteOpts: "tun-mtu",
- },
- want: &tunnelInfo{},
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if got := newTunnelInfoFromRemoteOptionsString(tt.args.remoteOpts); !reflect.DeepEqual(got, tt.want) {
- t.Errorf("parseRemoteOptions() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func Test_pushedOptionsAsMap(t *testing.T) {
- type args struct {
- pushedOptions []byte
- }
- tests := []struct {
- name string
- args args
- want map[string][]string
- }{
- {
- name: "do parse tunnel ip",
- args: args{[]byte("foo bar,ifconfig 10.0.0.3,")},
- want: map[string][]string{
- "foo": []string{"bar"},
- "ifconfig": []string{"10.0.0.3"},
- },
- },
- {
- name: "empty string",
- args: args{[]byte{}},
- want: map[string][]string{},
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if got := pushedOptionsAsMap(tt.args.pushedOptions); !reflect.DeepEqual(got, tt.want) {
- t.Errorf("pushedOptionsAsMap() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func Test_parseAuth(t *testing.T) {
- type args struct {
- p []string
- o *Options
- }
- tests := []struct {
- name string
- args args
- wantErr error
- }{
- {
- name: "should fail with empty array",
- args: args{[]string{}, &Options{}},
- wantErr: errBadCfg,
- },
- {
- name: "should fail with 2-element array",
- args: args{[]string{"foo", "bar"}, &Options{}},
- wantErr: errBadCfg,
- },
- {
- name: "should fail with lowercase option",
- args: args{[]string{"sha1"}, &Options{}},
- wantErr: errBadCfg,
- },
- {
- name: "should fail with unknown option",
- args: args{[]string{"SHA666"}, &Options{}},
- wantErr: errBadCfg,
- },
- {
- name: "should not fail with good option",
- args: args{[]string{"SHA512"}, &Options{}},
- wantErr: nil,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if err := parseAuth(tt.args.p, tt.args.o); !errors.Is(err, tt.wantErr) {
- t.Errorf("parseAuth() error = %v, wantErr %v", err, tt.wantErr)
- }
-
- })
- }
-}
-
-func Test_parseAuthUser(t *testing.T) {
- makeCreds := func(credStr string) string {
- f, err := os.CreateTemp(t.TempDir(), "tmpfile-")
- if err != nil {
- t.Fatal(err)
- }
- if _, err := f.Write([]byte(credStr)); err != nil {
- t.Fatal(err)
- }
- return f.Name()
- }
-
- baseDir := func() string {
- return os.TempDir()
- }
-
- type args struct {
- p []string
- o *Options
- d string
- }
- tests := []struct {
- name string
- args args
- wantErr error
- }{
- {
- name: "parse good auth",
- args: args{
- p: []string{makeCreds("foo\nbar\n")},
- o: &Options{},
- d: baseDir(),
- },
- wantErr: nil,
- },
- {
- name: "path traversal should fail",
- args: args{
- p: []string{"/tmp/../etc/passwd"},
- o: &Options{},
- d: baseDir(),
- },
- wantErr: errBadCfg,
- },
- {
- name: "parse empty file should fail",
- args: args{
- p: []string{""},
- o: &Options{},
- d: baseDir(),
- },
- wantErr: errBadCfg,
- },
- {
- name: "parse empty parts should fail",
- args: args{
- p: []string{},
- o: &Options{},
- d: baseDir(),
- },
- wantErr: errBadCfg,
- },
- {
- name: "parse less than two lines should fail",
- args: args{
- p: []string{makeCreds("foo\n")},
- o: &Options{},
- d: baseDir(),
- },
- wantErr: errBadCfg,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if err := parseAuthUser(tt.args.p, tt.args.o, tt.args.d); !errors.Is(err, tt.wantErr) {
- t.Errorf("parseAuthUser() error = %v, wantErr %v", err, tt.wantErr)
- }
- })
- }
-}
-
-// TODO(ainghazal): return options object so that it's testable too
-func Test_parseTLSVerMax(t *testing.T) {
- type args struct {
- p []string
- o *Options
- }
- tests := []struct {
- name string
- args args
- wantErr error
- }{
- {
- name: "nil options should fail",
- args: args{},
- wantErr: errBadInput,
- },
- {
- name: "default",
- args: args{o: &Options{}},
- wantErr: nil,
- },
- {
- name: "default with good tls opt",
- args: args{p: []string{"1.2"}, o: &Options{}},
- wantErr: nil,
- },
- {
- // FIXME this case should probably fail
- name: "default with too many parts",
- args: args{p: []string{"1.2", "1.3"}, o: &Options{}},
- wantErr: nil,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if err := parseTLSVerMax(tt.args.p, tt.args.o); !errors.Is(err, tt.wantErr) {
- t.Errorf("parseTLSVerMax() error = %v, wantErr %v", err, tt.wantErr)
- }
- })
- }
-}
-
-func Test_proto_String(t *testing.T) {
- tests := []struct {
- name string
- p proto
- want string
- }{
- // TODO: Add test cases.
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if got := tt.p.String(); got != tt.want {
- t.Errorf("proto.String() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func Test_getCredentialsFromFile(t *testing.T) {
- makeCreds := func(credStr string) string {
- f, err := os.CreateTemp(t.TempDir(), "tmpfile-")
- if err != nil {
- t.Fatal(err)
- }
- if _, err := f.Write([]byte(credStr)); err != nil {
- t.Fatal(err)
- }
- return f.Name()
- }
-
- type args struct {
- path string
- }
- tests := []struct {
- name string
- args args
- want []string
- wantErr error
- }{
- {
- name: "should fail with non-existing file",
- args: args{"/tmp/nonexistent"},
- want: nil,
- wantErr: errBadCfg,
- },
- {
- name: "should fail with empty file",
- args: args{makeCreds("")},
- want: nil,
- wantErr: errBadCfg,
- },
- {
- name: "should fail with empty user",
- args: args{makeCreds("\n\n")},
- want: nil,
- wantErr: errBadCfg,
- },
- {
- name: "should fail with empty pass",
- args: args{makeCreds("user\n\n")},
- want: nil,
- wantErr: errBadCfg,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got, err := getCredentialsFromFile(tt.args.path)
- if !errors.Is(err, tt.wantErr) {
- t.Errorf("getCredentialsFromFile() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if !reflect.DeepEqual(got, tt.want) {
- t.Errorf("getCredentialsFromFile() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func Test_isSubdir(t *testing.T) {
- type args struct {
- parent string
- sub string
- }
- tests := []struct {
- name string
- args args
- want bool
- wantErr bool
- }{
- {
- name: "sunny path",
- args: args{
- parent: "/foo/bar",
- sub: "/foo/bar/baz",
- },
- want: true,
- wantErr: false,
- },
- {
- name: "same dir",
- args: args{
- parent: "/foo/bar",
- sub: "/foo/bar",
- },
- want: true,
- wantErr: false,
- },
- {
- name: "same dir w/ slash",
- args: args{
- parent: "/foo/bar",
- sub: "/foo/bar/",
- },
- want: true,
- wantErr: false,
- },
- {
- name: "not subdir",
- args: args{
- parent: "/foo/bar",
- sub: "/foo",
- },
- want: false,
- wantErr: false,
- },
- {
- name: "path traversal",
- args: args{
- parent: "/foo/bar",
- sub: "/foo/bar/./../",
- },
- want: false,
- wantErr: false,
- },
- {
- name: "path traversal with .",
- args: args{
- parent: ".",
- sub: "/etc/",
- },
- want: false,
- wantErr: false,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got, err := isSubdir(tt.args.parent, tt.args.sub)
- if (err != nil) != tt.wantErr {
- t.Errorf("isSubdir() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if got != tt.want {
- t.Errorf("isSubdir() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func Test_newTunnelInfoFromPushedOptions(t *testing.T) {
- type args struct {
- opts map[string][]string
- }
- tests := []struct {
- name string
- args args
- want *tunnelInfo
- }{
- {
- name: "get route",
- args: args{
- map[string][]string{
- "route": []string{"1.1.1.1"},
- },
- },
- want: &tunnelInfo{
- gw: "1.1.1.1",
- },
- },
- {
- name: "get route from gw",
- args: args{
- map[string][]string{
- "route-gateway": []string{"1.1.2.2"},
- },
- },
- want: &tunnelInfo{
- gw: "1.1.2.2",
- },
- },
- {
- name: "get ip",
- args: args{
- map[string][]string{
- "ifconfig": []string{"1.1.3.3", "foo", "bar"},
- },
- },
- want: &tunnelInfo{
- ip: "1.1.3.3",
- },
- },
- {
- name: "get ip and route",
- args: args{
- map[string][]string{
- "ifconfig": []string{"1.1.3.3", "foo", "bar"},
- "route": []string{"1.1.1.1"},
- "route-gateway": []string{"1.1.2.2"},
- },
- },
- want: &tunnelInfo{
- ip: "1.1.3.3",
- gw: "1.1.1.1",
- },
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if got := newTunnelInfoFromPushedOptions(tt.args.opts); !reflect.DeepEqual(got, tt.want) {
- t.Errorf("newTunnelInfoFromPushedOptions() = %v, want %v", got, tt.want)
- }
- })
- }
-}
diff --git a/vpn/packet.go b/vpn/packet.go
deleted file mode 100644
index e5f012ba..00000000
--- a/vpn/packet.go
+++ /dev/null
@@ -1,382 +0,0 @@
-package vpn
-
-//
-// Encode and decode packets according to the OpenVPN protocol.
-//
-
-import (
- "bytes"
- "errors"
- "fmt"
- "io"
-)
-
-const (
- stNothing = iota
- stControlChannelOpen
- stControlMessageSent
- stKeyExchanged
- stPullRequestSent
- stOptionsPushed
- stInitialized
- stDataReady
-)
-
-const (
- pControlHardResetClientV1 = iota + 1
- pControlHardResetServerV1 // 2
- pControlSoftResetV1 // 3
- pControlV1 // 4
- pACKV1 // 5
- pDataV1 // 6
- pControlHardResetClientV2 // 7
- pControlHardResetServerV2 // 8
- pDataV2 // 9
-)
-
-const (
- UDPMode = iota
- TCPMode
-)
-
-var (
- errEmptyPayload = errors.New("empty payload")
- errBadKeyMethod = errors.New("unsupported key method")
- errBadControlMessage = errors.New("bad message")
- errBadServerReply = errors.New("bad server reply")
- errBadAuth = errors.New("server says: bad auth")
-
- controlMessageHeader = []byte{0x00, 0x00, 0x00, 0x00}
- pingPayload = []byte{0x2A, 0x18, 0x7B, 0xF3, 0x64, 0x1E, 0xB4, 0xCB, 0x07, 0xED, 0x2D, 0x0A, 0x98, 0x1F, 0xC7, 0x48}
-
- IV_Ver = "2.5.5" // OpenVPN version compat that we declare to the server
- IV_Proto = "2" // IV_PROTO declared to the server. We need to be sure to enable the peer-id bit to use P_DATA_V2.
-)
-
-// sessionID is the session identifier.
-type sessionID [8]byte
-
-// packetID is a packet identifier.
-type packetID uint32
-
-// ackArray holds the identifiers of packets to ack.
-type ackArray []packetID
-
-// packet represents a packet according to the OpenVPN protocol.
-type packet struct {
-
- // id is the packet-id for replay protection.
- // According to the spec: "4 or 8 bytes, includes sequence number and optional time_t timestamp".
- // We do not use the timestamp.
- id packetID
-
- // opcode is the packet message type (a P_* constant; high 5-bits of
- // the first packet byte).
- opcode byte
-
- // The key_id refers to an already negotiated TLS session.
- // This is the shortened version of the key-id (low 3-bits of the first
- // packet byte).
- keyID byte
-
- // The 64 bit form (of the key) is referred to as a session_id.
- localSessionID sessionID
- remoteSessionID sessionID
- payload []byte
- acks ackArray
-}
-
-// parsePacketFromBytes produces a packet after parsing the common header.
-// In TCP mode, it is assumed that the packet length (part of the header) has
-// already been stripped out.
-func parsePacketFromBytes(buf []byte) (*packet, error) {
- if len(buf) < 2 {
- return &packet{}, errBadInput
- }
- opcode := buf[0] >> 3
- keyID := buf[0] & 0x07
-
- var payload = []byte{}
-
- switch opcode {
- case pDataV2:
- payload = buf[4:]
- default:
- payload = buf[1:]
- }
-
- // TODO missing peerID
- p := &packet{
- opcode: opcode,
- keyID: keyID,
- payload: payload,
- }
- return parsePacket(p)
-}
-
-// newPacketFromPayload returns a packet from the passed arguments: opcode,
-// keyID and a raw byte array payload.
-func newPacketFromPayload(opcode uint8, keyID uint8, payload []byte) *packet {
- p := &packet{
- opcode: opcode,
- keyID: keyID,
- payload: payload,
- }
- return p
-}
-
-// Bytes returns a byte array that is ready to be sent on the wire.
-func (packet *packet) Bytes() []byte {
- buf := &bytes.Buffer{}
- buf.WriteByte((packet.opcode << 3) | (packet.keyID & 0x07))
- buf.Write(packet.localSessionID[:])
- // we write a byte with the number of acks, and then
- // serialize each ack.
- nAcks := len(packet.acks)
- if nAcks > 255 {
- logger.Warnf("packet %d had too many acks (%d)", packet.id, nAcks)
- nAcks = 255
- }
- buf.WriteByte(byte(nAcks))
- for i := 0; i < nAcks; i++ {
- bufWriteUint32(buf, uint32(packet.acks[i]))
- }
- // remote session id
- if len(packet.acks) > 0 {
- buf.Write(packet.remoteSessionID[:])
- }
- if packet.opcode != pACKV1 {
- bufWriteUint32(buf, uint32(packet.id))
- }
- // payload
- buf.Write(packet.payload)
- return buf.Bytes()
-}
-
-// isACK returns true if the packet is an ACK packet.
-func (p *packet) isACK() bool {
- return p.opcode == byte(pACKV1)
-}
-
-// isControl returns true if the packet is any of the control types.
-func (p *packet) isControl() bool {
- switch p.opcode {
- case byte(pControlHardResetServerV2), byte(pControlV1):
- return true
- default:
- return false
- }
-}
-
-// isControlV1 returns true if the packet is of the control v1 type.
-func (p *packet) isControlV1() bool {
- return p.opcode == byte(pControlV1)
-}
-
-// isData returns true if the packet is of data type.
-func (p *packet) isData() bool {
- switch p.opcode {
- case byte(pDataV1), byte(pDataV2):
- return true
- default:
- return false
- }
-}
-
-// parse tries to parse the payload of the packet, and returns a packet and an
-// error. it does only parse control packets (for now - parsing of data packets
-// is done on the data handler methods).
-func parsePacket(p *packet) (*packet, error) {
- if p.isControl() {
- return parseControlPacket(p)
- }
- return p, nil
-}
-
-// parseControlPacket parses the contents of a control packet, and returns a
-// packet and an error.
-func parseControlPacket(p *packet) (*packet, error) {
- if len(p.payload) == 0 {
- return p, errEmptyPayload
- }
- if !p.isControl() {
- return p, fmt.Errorf("%w: %s", errBadInput, "expected control packet")
- }
-
- buf := bytes.NewBuffer(p.payload)
-
- // local session id
- _, err := io.ReadFull(buf, p.localSessionID[:])
- if err != nil {
- return p, fmt.Errorf("%w: bad sessionID: %s", errBadInput, err)
- }
-
- // ack array
- ackBuf, err := buf.ReadByte()
- if err != nil {
- return p, fmt.Errorf("%w: bad ack: %s", errBadInput, err)
- }
- nAcks := int(ackBuf)
- p.acks = make([]packetID, nAcks)
- for i := 0; i < nAcks; i++ {
- val, err := bufReadUint32(buf)
- if err != nil {
- return p, fmt.Errorf("%w: cannot parse ack id: %s", errBadInput, err)
- }
- p.acks[i] = packetID(val)
- }
-
- // remote session id
- if nAcks > 0 {
- _, err = io.ReadFull(buf, p.remoteSessionID[:])
- if err != nil {
- return p, fmt.Errorf("%w: bad remote sessionID: %s", errBadInput, err)
- }
- }
-
- // packet id
- if p.opcode != pACKV1 {
- val, err := bufReadUint32(buf)
- if err != nil {
- return p, fmt.Errorf("%w: bad packetID: %s", errBadInput, err)
- }
- p.id = packetID(val)
- }
-
- // payload
- p.payload = buf.Bytes()
- return p, nil
-}
-
-// isPingPacket returns true if the packet payload matches a hard-coded ping
-// payload.
-func isPing(b []byte) bool {
- return bytes.Equal(b, pingPayload)
-}
-
-// serverControlMessage is sent by the server. it contains reply to the auth
-// and push requests. we initialize client's internal state after parsing the
-// fields contained in here.
-type serverControlMessage struct {
- payload []byte
-}
-
-// valid returns true if the packet has a control-message header.
-func (sc *serverControlMessage) valid() bool {
- if len(sc.payload) < 4 {
- return false
- }
- return bytes.Equal(sc.payload[:4], controlMessageHeader)
-}
-
-// newServerControlMessageFromBytes returns a server control message from the
-// passed byte array.
-func newServerControlMessageFromBytes(buf []byte) *serverControlMessage {
- return &serverControlMessage{buf}
-}
-
-// parseControlMessage gets a server control message and returns the value for
-// the remote key, the server remote options, and an error indicating if the
-// operation could not be completed.
-func parseServerControlMessage(sc *serverControlMessage) (*keySource, string, error) {
- if !sc.valid() {
- return nil, "", fmt.Errorf("%w: %s", errBadControlMessage, "bad header")
- }
- if len(sc.payload) < 71 {
- return nil, "", fmt.Errorf("%w: bad len from server:%d", errBadControlMessage, len(sc.payload))
- }
- keyMethod := sc.payload[4]
- if keyMethod != 2 {
- return nil, "", fmt.Errorf("%w: %d", errBadKeyMethod, keyMethod)
-
- }
- var random1, random2 [32]byte
- // first chunk of random bytes
- copy(random1[:], sc.payload[5:37])
- // second chunk of random bytes
- copy(random2[:], sc.payload[37:69])
-
- options, err := decodeOptionStringFromBytes(sc.payload[69:])
- if err != nil {
- return nil, "", fmt.Errorf("%w:%s", errBadControlMessage, "bad options string")
- }
-
- logger.Debugf("Remote opts: %s", options)
- remoteKey := &keySource{r1: random1, r2: random2}
- return remoteKey, options, nil
-}
-
-// encodeClientControlMessage returns a byte array with the payload for a control channel packet.
-// This is the packet that the client sends to the server with the key
-// material, local options and credentials (if username+password authentication is used).
-func encodeClientControlMessageAsBytes(k *keySource, o *Options) ([]byte, error) {
- opt, err := encodeOptionStringToBytes(o.String())
- if err != nil {
- return nil, err
- }
- user, err := encodeOptionStringToBytes(string(o.Username))
- if err != nil {
- return nil, err
- }
- pass, err := encodeOptionStringToBytes(string(o.Password))
- if err != nil {
- return nil, err
- }
-
- var out bytes.Buffer
- out.Write(controlMessageHeader)
- out.WriteByte(0x02) // key method (2)
- out.Write(k.Bytes())
- out.Write(opt)
- out.Write(user)
- out.Write(pass)
-
- // we could send IV_PLAT too, but afaik declaring the platform does not
- // make any difference for our purposes.
- rawInfo := fmt.Sprintf("IV_VER=%s\nIV_PROTO=%s\n", IV_Ver, IV_Proto)
- peerInfo, _ := encodeOptionStringToBytes(rawInfo)
- out.Write(peerInfo)
- return out.Bytes(), nil
-}
-
-// serverHard reset contains the payload for a serverHardReset message type.
-type serverHardReset struct {
- payload []byte
-}
-
-// newServerHardReset returns a serverHardReset message type, and an error if
-// the passed payload is empty.
-func newServerHardReset(b []byte) (*serverHardReset, error) {
- if len(b) == 0 {
- return nil, fmt.Errorf("%w: %s", errBadReset, "zero len")
- }
- p := &serverHardReset{b}
- return p, nil
-}
-
-// parseServerHardResetPacket returns the sessionID received from the server, or an
-// error if we could not parse the message.
-func parseServerHardResetPacket(p *serverHardReset) (sessionID, error) {
- if len(p.payload) < 10 {
- return sessionID{}, fmt.Errorf("%w: %s", errBadReset, "not enough bytes")
- }
- // BUG: this function assumes keyID == 0
- if p.payload[0] != 0x40 {
- return sessionID{}, fmt.Errorf("%w: %s", errBadReset, "bad header")
- }
- var rs sessionID
- copy(rs[:], p.payload[1:9])
- return rs, nil
-}
-
-// newACKPacket returns a packet with the P_ACK_V1 opcode.
-func newACKPacket(ackID packetID, s *session) *packet {
- acks := []packetID{ackID}
- p := &packet{
- opcode: pACKV1,
- localSessionID: s.LocalSessionID,
- remoteSessionID: s.RemoteSessionID,
- acks: acks,
- }
- return p
-}
diff --git a/vpn/packet_test.go b/vpn/packet_test.go
deleted file mode 100644
index 28ac9c99..00000000
--- a/vpn/packet_test.go
+++ /dev/null
@@ -1,286 +0,0 @@
-package vpn
-
-import (
- "bytes"
- "errors"
- "reflect"
- "testing"
-)
-
-func Test_newACKPacket(t *testing.T) {
- type args struct {
- ackID packetID
- s *session
- }
- tests := []struct {
- name string
- args args
- want *packet
- }{
- {"good_ack",
- args{42, &session{}},
- &packet{opcode: pACKV1, acks: []packetID{42}},
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if got := newACKPacket(tt.args.ackID, tt.args.s); !reflect.DeepEqual(got, tt.want) {
- t.Errorf("newACKPacket() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func Test_isPing(t *testing.T) {
- type args struct {
- b []byte
- }
- tests := []struct {
- name string
- args args
- want bool
- }{
- {"good ping", args{pingPayload}, true},
- {"bad ping", args{append(pingPayload, 0x00)}, false},
- {"empty", args{[]byte{}}, false},
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if got := isPing(tt.args.b); got != tt.want {
- t.Errorf("isPing() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func Test_newServerControlMessageFromBytes(t *testing.T) {
- payload := []byte{0xff, 0xfe, 0xfd}
- m := newServerControlMessageFromBytes(payload)
- if !bytes.Equal(m.payload, payload) {
- t.Errorf("newServerControlMessageFromBytes() = got %v, want %v", m.payload, payload)
- }
-}
-
-func Test_serverControlMessage_valid(t *testing.T) {
- type fields struct {
- payload []byte
- }
- tests := []struct {
- name string
- fields fields
- want bool
- }{
- {
- "good control message",
- fields{controlMessageHeader},
- true,
- },
- {
- "bad control message",
- fields{[]byte{0x00, 0x00, 0x00, 0x01}},
- false,
- },
- {
- "empty control message",
- fields{[]byte{}},
- false,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- sc := &serverControlMessage{
- payload: tt.fields.payload,
- }
- if got := sc.valid(); got != tt.want {
- t.Errorf("serverControlMessage.valid() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func Test_encodeClientControlMessageAsBytes(t *testing.T) {
-
- var manyA, manyB [32]byte
- var manyC [48]byte
-
- copy(manyA[:], bytes.Repeat([]byte{0x65}, 32))
- copy(manyB[:], bytes.Repeat([]byte{0x66}, 32))
- copy(manyC[:], bytes.Repeat([]byte{0x67}, 48))
-
- type args struct {
- k *keySource
- o *Options
- }
- tests := []struct {
- name string
- args args
- want []byte
- wantErr bool
- }{
- {
- "empty options",
- args{
- &keySource{manyA, manyB, manyC},
- &Options{},
- },
- func() []byte {
- buf := []byte{0x00, 0x00, 0x00, 0x00, 0x02}
- buf = append(buf, manyC[:]...)
- buf = append(buf, manyA[:]...)
- buf = append(buf, manyB[:]...)
- buf = append(buf, []byte{
- // options, null-terminated
- 0x00, 0x01, 0x00,
- // auth strings
- 0x00, 0x01, 0x00,
- 0x00, 0x01, 0x00}...)
- buf = append(buf, []byte{0x00, 0x19}...)
- buf = append(buf, []byte("IV_VER=2.5.5\nIV_PROTO=2\n")...)
- buf = append(buf, 0x00)
- return buf
- }(),
- false,
- },
- {
- "good options",
- args{
- &keySource{manyA, manyB, manyC},
- &Options{Cipher: "AES-128-CBC"},
- },
- func() []byte {
- buf := []byte{0x00, 0x00, 0x00, 0x00, 0x02}
- buf = append(buf, manyC[:]...)
- buf = append(buf, manyA[:]...)
- buf = append(buf, manyB[:]...)
- buf = append(buf, []byte{0x00, 0x74}...)
- buf = append(buf, []byte("V4,dev-type tun,link-mtu 1549,tun-mtu 1500,proto UDPv4,cipher AES-128-CBC,auth ,keysize 128,key-method 2,tls-client")...)
- // null-terminate + auth
- buf = append(buf, []byte{
- 0x00,
- 0x00, 0x01, 0x00,
- 0x00, 0x01, 0x00}...)
- buf = append(buf, []byte{0x00, 0x19}...)
- buf = append(buf, []byte("IV_VER=2.5.5\nIV_PROTO=2\n")...)
- buf = append(buf, 0x00)
- return buf
- }(),
- false,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got, err := encodeClientControlMessageAsBytes(tt.args.k, tt.args.o)
- if (err != nil) != tt.wantErr {
- t.Errorf("encodeClientControlMessageAsBytes() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if !reflect.DeepEqual(got, tt.want) {
- t.Errorf("encodeClientControlMessageAsBytes() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func Test_newServerHardReset(t *testing.T) {
- type args struct {
- b []byte
- }
- tests := []struct {
- name string
- args args
- want *serverHardReset
- wantErr error
- }{
- {
- name: "good payload",
- args: args{[]byte("not a payload")},
- want: &serverHardReset{[]byte("not a payload")},
- wantErr: nil,
- },
- {
- name: "empty",
- args: args{[]byte{}},
- want: nil,
- wantErr: errBadReset,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got, err := newServerHardReset(tt.args.b)
- if !errors.Is(err, tt.wantErr) {
- t.Errorf("newServerHardReset() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if !reflect.DeepEqual(got, tt.want) {
- t.Errorf("newServerHardReset() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func Test_parseServerHardResetPacket(t *testing.T) {
-
- var goodSessionID sessionID
- goodPayload := []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}
- shortPayload := []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}
- copy(goodSessionID[:], goodPayload)
-
- type args struct {
- p *serverHardReset
- }
- tests := []struct {
- name string
- args args
- want sessionID
- wantErr error
- }{
- {
- name: "good server hard reset",
- args: args{
- &serverHardReset{
- payload: append([]byte{0x40}, goodPayload...),
- },
- },
- want: goodSessionID,
- wantErr: nil,
- },
- {
- name: "payload too short should fail",
- args: args{
- &serverHardReset{
- payload: append([]byte{0x40}, shortPayload...),
- },
- },
- want: sessionID{},
- wantErr: errBadReset,
- },
- {
- name: "bad header should fail",
- args: args{
- &serverHardReset{
- payload: append([]byte{0x41}, goodPayload...),
- },
- },
- want: sessionID{},
- wantErr: errBadReset,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got, err := parseServerHardResetPacket(tt.args.p)
- if !errors.Is(err, tt.wantErr) {
- t.Errorf("parseServerHardResetPacket() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if !reflect.DeepEqual(got, tt.want) {
- t.Errorf("parseServerHardResetPacket() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-// Regression test for MIV-01-001
-func Test_Crash_parseServerHardResetPacket(t *testing.T) {
- p := &serverHardReset{}
- parseServerHardResetPacket(p)
-}
diff --git a/vpn/tls.go b/vpn/tls.go
deleted file mode 100644
index 0651847d..00000000
--- a/vpn/tls.go
+++ /dev/null
@@ -1,267 +0,0 @@
-package vpn
-
-//
-// TLS initialization and read/write wrappers.
-//
-// TODO(ainghazal): for the time being, we're using uTLS to parrot a ClientHello that can reasonably blend
-// with a recent openvpn+openssl client (2.5.x). We might want to revisit this
-// in the near future and perhaps expose other TLS Factories.
-//
-
-import (
- "crypto/x509"
- "encoding/hex"
- "errors"
- "fmt"
- "io/ioutil"
- "net"
-
- tls "github.com/refraction-networking/utls"
-)
-
-var (
- // ErrBadTLSHandshake is returned when the OpenVPN handshake failed.
- ErrBadTLSHandshake = errors.New("handshake failure")
- // ErrBadCA is returned when the CA file cannot be found or is not valid.
- ErrBadCA = errors.New("bad ca conf")
- // ErrBadKeypair is returned when the key or cert file cannot be found or is not valid.
- ErrBadKeypair = errors.New("bad keypair conf")
- // ErrBadParrot is returned for errors during TLS parroting
- ErrBadParrot = errors.New("cannot parrot")
- // ErrCannotVerifyCertChain is returned for certificate chain validation errors.
- ErrCannotVerifyCertChain = errors.New("cannot verify chain")
-)
-
-// certVerifyOptionsNoCommonNameCheck returns a x509.VerifyOptions initialized with
-// an empty string for the DNSName. This allows to skip CN verification.
-func certVerifyOptionsNoCommonNameCheck() x509.VerifyOptions {
- return x509.VerifyOptions{DNSName: ""}
-}
-
-// certVerifyOptions is the options factory that the customVerify function will
-// use; by default it configures VerifyOptions to skip the DNSName check.
-var certVerifyOptions = certVerifyOptionsNoCommonNameCheck
-
-// certPaths holds the paths for the cert, key, and ca used for OpenVPN
-// certificate authentication.
-type certPaths struct {
- certPath string
- keyPath string
- caPath string
-}
-
-// loadCertAndCAFromPath parses the PEM certificates contained in the paths pointed by
-// the passed certPaths and return a certConfig with the client and CA certificates.
-func loadCertAndCAFromPath(pth certPaths) (*certConfig, error) {
- ca := x509.NewCertPool()
- caData, err := ioutil.ReadFile(pth.caPath)
- if err != nil {
- return nil, fmt.Errorf("%w: %s", ErrBadCA, err)
- }
- ok := ca.AppendCertsFromPEM(caData)
- if !ok {
- return nil, fmt.Errorf("%w: %s", ErrBadCA, "cannot parse ca cert")
- }
-
- cfg := &certConfig{ca: ca}
- if pth.certPath != "" && pth.keyPath != "" {
- cert, err := tls.LoadX509KeyPair(pth.certPath, pth.keyPath)
- if err != nil {
- return nil, fmt.Errorf("%w: %s", ErrBadKeypair, err)
- }
- cfg.cert = cert
- }
- return cfg, nil
-}
-
-// certBytes holds the byte arrays for the cert, key, and ca used for OpenVPN
-// certificate authentication.
-type certBytes struct {
- cert []byte
- key []byte
- ca []byte
-}
-
-// loadCertAndCAFromBytes parses the PEM certificates from the byte arrays in the
-// the passed certBytes, and return a certConfig with the client and CA certificates.
-func loadCertAndCAFromBytes(crt certBytes) (*certConfig, error) {
- ca := x509.NewCertPool()
- ok := ca.AppendCertsFromPEM(crt.ca)
- if !ok {
- return nil, fmt.Errorf("%w: %s", ErrBadCA, "cannot parse ca cert")
- }
- cfg := &certConfig{ca: ca}
- if crt.cert != nil && crt.key != nil {
- cert, err := tls.X509KeyPair(crt.cert, crt.key)
- if err != nil {
- return nil, fmt.Errorf("%w: %s", ErrBadKeypair, err)
- }
- cfg.cert = cert
- }
- return cfg, nil
-}
-
-// authorityPinner is any object from which we can obtain a certpool containing
-// a pinned Certificate Authority for verification.
-type authorityPinner interface {
- authority() *x509.CertPool
-}
-
-// certConfig holds the parsed certificate and CA used for OpenVPN mutual
-// certificate authentication.
-type certConfig struct {
- cert tls.Certificate
- ca *x509.CertPool
-}
-
-// newCertConfigFromOptions is a constructor that returns a certConfig object initialized
-// from the paths specified in the passed Options object, and an error if it
-// could not be properly built.
-func newCertConfigFromOptions(o *Options) (*certConfig, error) {
- var cfg *certConfig
- var err error
- if o.certsFromPath() {
- cfg, err = loadCertAndCAFromPath(certPaths{
- certPath: o.CertPath,
- keyPath: o.KeyPath,
- caPath: o.CaPath,
- })
- } else {
- cfg, err = loadCertAndCAFromBytes(certBytes{
- cert: o.Cert,
- key: o.Key,
- ca: o.Ca,
- })
- }
- return cfg, err
-}
-
-// authority implements authorityPinner interface.
-func (c *certConfig) authority() *x509.CertPool {
- return c.ca
-}
-
-// ensure certConfig implements authorityPinner.
-var _ authorityPinner = &certConfig{}
-
-// verifyFun is the type expected by the VerifyPeerCertificate callback in tls.Config.
-type verifyFun func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error
-
-// customVerifyFactory returns a verifyFun callback that will verify any received certificates
-// against the ca provided by the pased implementation of authorityPinner
-func customVerifyFactory(pinner authorityPinner) verifyFun {
- // customVerify is a version of the verification routines that does not try to verify
- // the Common Name, since we don't know it a priori for a VPN gateway. Returns
- // an error if the verification fails.
- // From tls/common documentation: If normal verification is disabled by
- // setting InsecureSkipVerify, [...] then this callback will be considered but
- // the verifiedChains argument will always be nil.
- customVerify := func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
- // we assume (from docs) that we're always given the
- // leaf certificate as the first cert in the array.
- leaf, _ := x509.ParseCertificate(rawCerts[0])
- if leaf == nil {
- return fmt.Errorf("%w: %s", ErrCannotVerifyCertChain, "nothing to verify")
- }
- // By default has DNSName verification disabled.
- opts := certVerifyOptions()
- // Set the configured CA(s) as the certificate pool to verify against.
- opts.Roots = pinner.authority()
-
- if _, err := leaf.Verify(opts); err != nil {
- return fmt.Errorf("%w: %s", ErrCannotVerifyCertChain, err)
- }
- return nil
- }
- return customVerify
-}
-
-// initTLS returns a tls.Config matching the VPN options. Internally, it uses
-// the verify function returned by the global customVerifyFactory,
-// verification function since verifying the ServerName does not make sense in
-// the context of establishing a VPN session: we perform mutual TLS
-// Authentication with the custom CA.
-func initTLS(session *session, cfg *certConfig) (*tls.Config, error) {
- if session == nil || cfg == nil {
- return nil, fmt.Errorf("%w: %s", errBadInput, "nil args")
- }
-
- customVerify := customVerifyFactory(cfg)
-
- tlsConf := &tls.Config{
- // the certificate we've loaded from the config file
- Certificates: []tls.Certificate{cfg.cert},
- // crypto/tls wants either ServerName or InsecureSkipVerify set ...
- InsecureSkipVerify: true,
- // ...but we pass our own verification function that verifies against the CA and ignores the ServerName
- VerifyPeerCertificate: customVerify,
- // disable DynamicRecordSizing to lower distinguishability.
- DynamicRecordSizingDisabled: true,
- // uTLS does not pick min/max version from the passed spec
- MinVersion: tls.VersionTLS12,
- MaxVersion: tls.VersionTLS13,
- } //#nosec G402
-
- return tlsConf, nil
-}
-
-// tlsHandshake performs the TLS handshake over the control channel, and return
-// the TLS Client as a net.Conn; returns also any error during the handshake.
-func tlsHandshake(tlsConn *controlChannelTLSConn, tlsConf *tls.Config) (net.Conn, error) {
- tlsClient, err := tlsFactoryFn(tlsConn, tlsConf)
- if err != nil {
- return nil, err
- }
- if err := tlsClient.Handshake(); err != nil {
- return nil, fmt.Errorf("%w: %s", ErrBadTLSHandshake, err)
- }
- return tlsClient, nil
-}
-
-// handshaker is a custom interface that we define here to be able to mock
-// the tls.Conn implementation.
-type handshaker interface {
- net.Conn
- Handshake() error
-}
-
-// defaultTLSFactory returns an implementer of the handshaker interface; that
-// is, the default tls.Client factory; and an error.
-// we're not using the default factory right now, but it comes handy to be able
-// to compare the fingerprints with a golang TLS handshake.
-func defaultTLSFactory(conn net.Conn, config *tls.Config) (handshaker, error) {
- c := tls.Client(conn, config)
- return c, nil
-}
-
-// vpnClientHelloHex is the hexadecimal representation of a capture from the reference openvpn implementation.
-// openvpn=2.5.5,openssl=3.0.2
-// You can use https://github.com/ainghazal/sniff/tree/main/clienthello to
-// analyze a ClientHello from the wire or pcap.
-var vpnClientHelloHex = `1603010114010001100303534e0a0f2687b240f7c7dfbb51c4aac33639f28173aa5d7bcebb159695ab0855208b835bf240a83df66885d6747b5bbf1b631e8c34ae469c629d7eb76e247128eb0032130213031301c02cc030009fcca9cca8ccaac02bc02f009ec024c028006bc023c0270067c00ac0140039c009c013003300ff01000095000b000403000102000a00160014001d0017001e00190018010001010102010301040016000000170000000d002a0028040305030603080708080809080a080b080408050806040105010601030303010302040205020602002b0009080304030303020301002d00020101003300260024001d0020a10bc24becb583293c317220e6725205d3a177a4a974090f6ffcf13a43da7035`
-
-// parrotTLSFactory returns an implementer of the handshaker interface; in this
-// case, a parroting implementation; and an error.
-func parrotTLSFactory(conn net.Conn, config *tls.Config) (handshaker, error) {
- fingerprinter := &tls.Fingerprinter{AllowBluntMimicry: true}
- rawOpenVPNClientHelloBytes, err := hex.DecodeString(vpnClientHelloHex)
- if err != nil {
- return nil, fmt.Errorf("%w: cannot decode raw fingerprint: %s", ErrBadParrot, err)
- }
- generatedSpec, err := fingerprinter.FingerprintClientHello(rawOpenVPNClientHelloBytes)
- if err != nil {
- return nil, fmt.Errorf("%w: fingerprinting failed: %s", ErrBadParrot, err)
- }
- client := tls.UClient(conn, config, tls.HelloCustom)
- if err := client.ApplyPreset(generatedSpec); err != nil {
- return nil, fmt.Errorf("%w: cannot apply spec: %s", ErrBadParrot, err)
- }
- return client, nil
-}
-
-// global variables to allow monkeypatching in tests.
-var (
- initTLSFn = initTLS
- tlsFactoryFn = parrotTLSFactory
- tlsHandshakeFn = tlsHandshake
-)
diff --git a/vpn/tls_test.go b/vpn/tls_test.go
deleted file mode 100644
index eae2c80c..00000000
--- a/vpn/tls_test.go
+++ /dev/null
@@ -1,897 +0,0 @@
-package vpn
-
-import (
- "crypto/rand"
- "crypto/rsa"
- "crypto/x509"
- "crypto/x509/pkix"
- "errors"
- "math/big"
- "net"
- "os"
- "reflect"
- "testing"
- "time"
-
- "github.com/google/martian/mitm"
- "github.com/ooni/minivpn/vpn/mocks"
- tls "github.com/refraction-networking/utls"
-)
-
-func makeDummyOptionsForCertPaths() *Options {
- return &Options{
- CertPath: "aa",
- KeyPath: "aa",
- CaPath: "aa",
- }
-}
-
-// TODO(ainghazal): many of these tests right now create certs, which is costly
-// bacause creating a CA each time is expensive.
-// I should mock any certs for the tests that do not need to actually verify
-// the cert chain.
-
-func Test_initTLS(t *testing.T) {
- type args struct {
- session *session
- cfg *certConfig
- }
- tests := []struct {
- name string
- args args
- want *tls.Config
- wantErr error
- }{
- {
- name: "empty opts should fail",
- args: args{
- session: makeTestingSession(),
- },
- want: nil,
- wantErr: errBadInput,
- },
- {
- name: "empty session should fail",
- args: args{
- cfg: func() *certConfig {
- crt, err := writeTestingCerts(t.TempDir())
- if err != nil {
- t.Errorf("initTLS() cannot create certs for test %v", err.Error())
- }
- cfg, err := newCertConfigFromOptions(&Options{
- CertPath: crt.cert,
- KeyPath: crt.key,
- CaPath: crt.ca,
- })
- if err != nil {
- t.Errorf("initTLS() cannot load config from opts %v", err.Error())
- }
- return cfg
- }(),
- },
- want: nil,
- wantErr: errBadInput,
- },
- {
- name: "default tls config should not fail with proper cert paths",
- args: args{
- session: makeTestingSession(),
- cfg: func() *certConfig {
- c, _ := writeTestingCerts("")
- cfg, err := newCertConfigFromOptions(
- &Options{
- CertPath: c.cert,
- KeyPath: c.key,
- CaPath: c.ca,
- })
- if err != nil {
- t.Errorf("error while testing: %v", err)
- }
- return cfg
- }(),
- },
- want: nil,
- wantErr: nil,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got, err := initTLS(tt.args.session, tt.args.cfg)
- if !errors.Is(err, tt.wantErr) {
- t.Errorf("initTLS() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if tt.want == nil {
- return
- }
- if !reflect.DeepEqual(got.InsecureSkipVerify, tt.want.InsecureSkipVerify) {
- t.Errorf("initTLS() InsecureSkipVerify = %v, want %v", got.InsecureSkipVerify, tt.want.InsecureSkipVerify)
- return
- }
- })
- }
-}
-
-var pemTestingKey = []byte(`-----BEGIN PRIVATE KEY-----
-MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQC/vw0YScdbP2wg
-3M+N6BlsCQePUVFlyLh3faPtfqKTeWfyMYhGMeUE4fMcO1H0l7b/+zfwfA85AhlT
-dU152AXvizBidnaQXwVxsxzLPiPxn3qH5KxD72vkMHMyUrRh/tdJzIj1bqlCiLcw
-SK5EDPMwuUSAIk7evRzLUdGu1JkUxi7xox03R5rvC8ZohAPSRxFAg6rajkk7HlUi
-BepNz5PRlPGJ0Kfn0oa/BF+5F3Y4WU+75r9tK+H691eRL65exTGrYIOZE9Rd6i8C
-S3WoFNmlO6tv0HMAh/GYR6/mrekOkSZdjNIbDfcNiFsvNtMIO9jztd7g/3BcQg/3
-eFydHplrAgMBAAECggEAM8lBnCGw+e/zIB0C4WyiEQ+PPyHTPg4r4/nG4EmnVvUf
-IcZG685l8B+mLSXISKsA/bm3rfeTlO4AMQ4pUpMJZ1zMQIuGEg/XxJF/YVTzGDre
-OP2FmQN8vDBprFmx5hWRx5i6FK9Cf3m1IBFBH5fvxmUDHygk7PteX3tFilZY0ccM
-TpK8nOOpbbK/8S8dC6ePXYgjamLotAnKdgKnpmxQjiprsRAWiOr7DFdjMLCUyZkC
-NYwRszVNX84wLOFNzFdU653gFKNcJ/8NI2MBQ5EaBMWOcxNgdfBtCXE9GwQVNzp2
-tjTt2QYbTdaw6LAMKgrWgaZBp0VSK4WTlYLifwrSQQKBgQD4Ah39r/l+QyTLwr6d
-AkMp/rgpOYzvaRzuUcZnObvi8yfFlJJ6EM4zfNICXNexdqeL+WTaSV1yuc4/rsRx
-nAgXklgz2UpATccLJ7JrCDsWgZm71tfUWQM5IbMgkyVixwGYiTsW+kMxFD0n2sNK
-sPkEgr2IiSEDfjzTf0LPr7sLyQKBgQDF7NCTTEp92FSz5OcKNSI7iH+lsVgV+U88
-Widc/thn/vRnyRqpvyjUvl9D9jMTz2/9DiV06lCYfN8KpknCb3jCWY5cjmOSZQTs
-oHQQX145Exe8cj2z+66QK6CsE1tlUC99Y684hn+eDlLMIQGMtRz8aSYb8oZo68sM
-hcTaP8CtkwKBgQDK0RhrrWyQWCKQS9uMFRyODFPYysq5wzE4qEFji3BeodFFoEHF
-d1bZ/lrUOc7evxU3wCU86kB0oQTNSYQ3EI4BkNl21V0Gh1Seh8E+DIYd2rC5T3JD
-ouOi5i9SFWO+itaAQsHDAbjPOyjkHeAVhfKvQKf1L4eDDsp5f5pItAJ4GQKBgDvF
-EwuYW1p7jMCynG7Bsu/Ffb68unwQSLRSCVcVAqcNICODYJDoUF1GjCBK5gvSdeA2
-eGtBI0uZUgW2R8n2vcH7J3md6kXYSc9neQVEt4CG2oEnAqkqlQGmmyO7yLrkpyK3
-ir+IJlvFuY05Xm1ueC1lV4PTDnH62tuSPesmm3oPAoGBANsj/l6xgcMZK6VKZHGV
-gG59FoMudCvMP1pITJh+TQPIJbD4TgYnDUG7z14zrYhxChWHYysVrIT35Iuu7k6S
-JlkPybAiLmv2nulx9fRkTzcGgvPtG3iHS/WQLvr9umWrfmQYMMW1Udr0IdflS1Sk
-fIeuXWkQrCE24uKSInkRupLO
------END PRIVATE KEY-----`)
-
-var pemTestingCertificate = []byte(`-----BEGIN CERTIFICATE-----
-MIIDjTCCAnUCFGb3X7au5DHHCSd8n6e5vG1/HGtyMA0GCSqGSIb3DQEBCwUAMIGB
-MQswCQYDVQQGEwJOWjELMAkGA1UECAwCTk8xEjAQBgNVBAcMCUludGVybmV0ejEN
-MAsGA1UECgwEQW5vbjENMAsGA1UECwwEcm9vdDESMBAGA1UEAwwJbG9jYWxob3N0
-MR8wHQYJKoZIhvcNAQkBFhB1c2VyQGV4YW1wbGUuY29tMB4XDTIyMDUyMDE4Mzk0
-N1oXDTIyMDYxOTE4Mzk0N1owgYMxCzAJBgNVBAYTAk5aMQswCQYDVQQIDAJOTzES
-MBAGA1UEBwwJSW50ZXJuZXR6MQ0wCwYDVQQKDARBbm9uMQ8wDQYDVQQLDAZzZXJ2
-ZXIxEjAQBgNVBAMMCWxvY2FsaG9zdDEfMB0GCSqGSIb3DQEJARYQdXNlckBleGFt
-cGxlLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAL+/DRhJx1s/
-bCDcz43oGWwJB49RUWXIuHd9o+1+opN5Z/IxiEYx5QTh8xw7UfSXtv/7N/B8DzkC
-GVN1TXnYBe+LMGJ2dpBfBXGzHMs+I/GfeofkrEPva+QwczJStGH+10nMiPVuqUKI
-tzBIrkQM8zC5RIAiTt69HMtR0a7UmRTGLvGjHTdHmu8LxmiEA9JHEUCDqtqOSTse
-VSIF6k3Pk9GU8YnQp+fShr8EX7kXdjhZT7vmv20r4fr3V5Evrl7FMatgg5kT1F3q
-LwJLdagU2aU7q2/QcwCH8ZhHr+at6Q6RJl2M0hsN9w2IWy820wg72PO13uD/cFxC
-D/d4XJ0emWsCAwEAATANBgkqhkiG9w0BAQsFAAOCAQEAGt+m0kwuULOVEr7QvbOI
-6pxEd9AysxWxGzGBM6G9jrhlgch10wWuhDZq0LqahlWQ8DK9Kjg+pHEYYN8B1m0L
-2lloFpXb+AXJR9RKsBr4iU2HdJkPIAwYlDhPUTeskfWP61JGGQC6oem3UXCbLldE
-VxcY3vSifP9/pIyjHVULa83FQwwsseavav3NvBgYIyglz+BLl6azMdFLXyzGzEUv
-iiN6MdNrJ34iDKHCYSlNvJktJY91eTsQ1GLYD6O9C5KrCJRp0ibQ1keSE7vdhnTY
-doKeoNOwq224DcktFdFAYnOM/q3dKxz3m8TsM5OLel4kebqDovPt0hJl2Wwwx43k
-0A==
------END CERTIFICATE-----`)
-
-var pemTestingCa = []byte(`-----BEGIN CERTIFICATE-----
-MIID5TCCAs2gAwIBAgIUecMREJYMxFeQEWNBRSCM1x/pAEIwDQYJKoZIhvcNAQEL
-BQAwgYExCzAJBgNVBAYTAk5aMQswCQYDVQQIDAJOTzESMBAGA1UEBwwJSW50ZXJu
-ZXR6MQ0wCwYDVQQKDARBbm9uMQ0wCwYDVQQLDARyb290MRIwEAYDVQQDDAlsb2Nh
-bGhvc3QxHzAdBgkqhkiG9w0BCQEWEHVzZXJAZXhhbXBsZS5jb20wHhcNMjIwNTIw
-MTgzOTQ3WhcNMjIwNjE5MTgzOTQ3WjCBgTELMAkGA1UEBhMCTloxCzAJBgNVBAgM
-Ak5PMRIwEAYDVQQHDAlJbnRlcm5ldHoxDTALBgNVBAoMBEFub24xDTALBgNVBAsM
-BHJvb3QxEjAQBgNVBAMMCWxvY2FsaG9zdDEfMB0GCSqGSIb3DQEJARYQdXNlckBl
-eGFtcGxlLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAMxO6abV
-xOy/2VuekAAvJnM2bFIpqSoWK1uMDHJc7NRWVPy2UFaDvCL2g+CSqEyqMN0NI0El
-J2cIAgUYOa0+wHJWQhAL60veR6ew9JfIDk3S7YNeKzUGgrRzKvTLdms5mL8fZpT+
-GFwHprx58EZwg2TDQ6bGdThsSYNbx72PRngIOl5k6NWdIgd0wiAAYIpNQQUc8rDC
-IG4VvoitbpzYcAFCxCVGivodLP02pk2hokbidnLyTj5wIVTccA3u9FeEq2+IIAfr
-OW+3LjCpH9SC+3qPjA0UHv2bCLMVzIp86lUsbx6Qcoy0RPh5qC28cLk19wQj5+pw
-XtOeL90d2Hokf40CAwEAAaNTMFEwHQYDVR0OBBYEFNuQwyljbQs208ZCI5NFuzvo
-1ez8MB8GA1UdIwQYMBaAFNuQwyljbQs208ZCI5NFuzvo1ez8MA8GA1UdEwEB/wQF
-MAMBAf8wDQYJKoZIhvcNAQELBQADggEBAHPkGlDDq79rdxFfbt0dMKm1dWZtPlZl
-iIY9Pcet/hgf69OKXwb4h3E0IjFW7JHwo4Bfr4mqrTQLTC1qCRNEMC9XUyc4neQy
-3r2LRk+D7XAN1zwL6QPw550ukbLk4R4I1xQr+9Sap9h0QUaJj5tts6XSzhZ1AylJ
-HgmkOnPOpcIWm+yUMEDESGnhE8hfXR1nhb5lLrg2HIqp9qRRH1w/wc7jG3bYV3jg
-S5nL4GaRzx84PB1HWONlh0Wp7KBk2j6Lp0acoJwI2mHJcJoOPpaYiWWYNNTjMv2/
-XXNUizTI136liavLslSMoYkjYAun+5HOux/keA1L+lm2XeG06Ew1qS4=
------END CERTIFICATE-----`)
-
-type testingCert struct {
- cert string
- key string
- ca string
-}
-
-func writeTestingCerts(dir string) (testingCert, error) {
- certFile, err := os.CreateTemp(dir, "tmpfile-")
- if err != nil {
- return testingCert{}, err
- }
- certFile.Write(pemTestingCertificate)
-
- keyFile, err := os.CreateTemp(dir, "tmpfile-")
- if err != nil {
- return testingCert{}, err
- }
- keyFile.Write(pemTestingKey)
-
- caFile, err := os.CreateTemp(dir, "tmpfile-")
- if err != nil {
- return testingCert{}, err
- }
- caFile.Write(pemTestingCa)
-
- testingCert := testingCert{
- cert: certFile.Name(),
- key: keyFile.Name(),
- ca: caFile.Name(),
- }
- return testingCert, nil
-}
-
-func writeTestingCertsBadCAFile(dir string) (testingCert, error) {
- certFile, err := os.CreateTemp(dir, "tmpfile-")
- if err != nil {
- return testingCert{}, err
- }
- certFile.Write(pemTestingCertificate)
-
- keyFile, err := os.CreateTemp(dir, "tmpfile-")
- if err != nil {
- return testingCert{}, err
- }
- keyFile.Write(pemTestingKey)
-
- caFile, err := os.CreateTemp(dir, "tmpfile-")
- if err != nil {
- return testingCert{}, err
- }
- caFile.Write(pemTestingCa[:len(pemTestingCa)-10])
-
- testingCert := testingCert{
- cert: certFile.Name(),
- key: keyFile.Name(),
- ca: caFile.Name() + "-non-existent",
- }
- return testingCert, nil
-}
-
-func writeTestingCertsBadCA(dir string) (testingCert, error) {
- certFile, err := os.CreateTemp(dir, "tmpfile-")
- if err != nil {
- return testingCert{}, err
- }
- certFile.Write(pemTestingCertificate)
-
- keyFile, err := os.CreateTemp(dir, "tmpfile-")
- if err != nil {
- return testingCert{}, err
- }
- keyFile.Write(pemTestingKey)
-
- caFile, err := os.CreateTemp(dir, "tmpfile-")
- if err != nil {
- return testingCert{}, err
- }
- caFile.Write(pemTestingCa[:len(pemTestingCa)-10])
-
- testingCert := testingCert{
- cert: certFile.Name(),
- key: keyFile.Name(),
- ca: caFile.Name(),
- }
- return testingCert, nil
-}
-
-func writeTestingCertsBadKey(dir string) (testingCert, error) {
- certFile, err := os.CreateTemp(dir, "tmpfile-")
- if err != nil {
- return testingCert{}, err
- }
- certFile.Write(pemTestingCertificate)
-
- keyFile, err := os.CreateTemp(dir, "tmpfile-")
- if err != nil {
- return testingCert{}, err
- }
- keyFile.Write(pemTestingKey[:len(pemTestingKey)-10])
-
- caFile, err := os.CreateTemp(dir, "tmpfile-")
- if err != nil {
- return testingCert{}, err
- }
- caFile.Write(pemTestingCa)
-
- testingCert := testingCert{
- cert: certFile.Name(),
- key: keyFile.Name(),
- ca: caFile.Name(),
- }
- return testingCert, nil
-}
-
-func writeTestingCertsBadCert(dir string) (testingCert, error) {
- certFile, err := os.CreateTemp(dir, "tmpfile-")
- if err != nil {
- return testingCert{}, err
- }
- certFile.Write(pemTestingCertificate[:len(pemTestingCertificate)-10])
-
- keyFile, err := os.CreateTemp(dir, "tmpfile-")
- if err != nil {
- return testingCert{}, err
- }
- keyFile.Write(pemTestingKey[:len(pemTestingKey)-10])
-
- caFile, err := os.CreateTemp(dir, "tmpfile-")
- if err != nil {
- return testingCert{}, err
- }
- caFile.Write(pemTestingCa)
-
- testingCert := testingCert{
- cert: certFile.Name(),
- key: keyFile.Name(),
- ca: caFile.Name(),
- }
- return testingCert, nil
-}
-
-func Test_loadCertAndCAFromPath(t *testing.T) {
- type args struct {
- pth certPaths
- }
- tests := []struct {
- name string
- args args
- want *certConfig
- wantErr error
- }{
- {
- name: "bad ca (non existent file) should fail",
- args: func() args {
- crt, err := writeTestingCertsBadCAFile(t.TempDir())
- if err != nil {
- t.Errorf("error while testing: %v", err)
- }
- return args{pth: certPaths{crt.cert, crt.key, crt.ca}}
-
- }(),
- want: nil,
- wantErr: ErrBadCA,
- },
- {
- name: "bad ca (malformed) should fail",
- args: func() args {
- crt, err := writeTestingCertsBadCA(t.TempDir())
- if err != nil {
- t.Errorf("error while testing: %v", err)
- }
- return args{pth: certPaths{crt.cert, crt.key, crt.ca}}
-
- }(),
- want: nil,
- wantErr: ErrBadCA,
- },
- {
- name: "bad key",
- args: func() args {
- crt, err := writeTestingCertsBadKey(t.TempDir())
- if err != nil {
- t.Errorf("error while testing: %v", err)
- }
- return args{pth: certPaths{crt.cert, crt.key, crt.ca}}
-
- }(),
- want: nil,
- wantErr: ErrBadKeypair,
- },
- {
- name: "bad cert",
- args: func() args {
- crt, err := writeTestingCertsBadCert(t.TempDir())
- if err != nil {
- t.Errorf("error while testing: %v", err)
- }
- return args{pth: certPaths{crt.cert, crt.key, crt.ca}}
-
- }(),
- want: nil,
- wantErr: ErrBadKeypair,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got, err := loadCertAndCAFromPath(tt.args.pth)
- if !errors.Is(err, tt.wantErr) {
- t.Errorf("loadCertAndCA() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if !reflect.DeepEqual(got, tt.want) {
- t.Errorf("loadCertAndCA() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func Test_loadCertAndCAFromBytes(t *testing.T) {
- type args struct {
- crt certBytes
- }
- tests := []struct {
- name string
- args args
- want *certConfig
- wantErr error
- }{
- {
- name: "bad ca should fail",
- args: args{crt: certBytes{
- ca: pemTestingCa[:len(pemTestingCa)-10],
- cert: pemTestingCertificate,
- key: pemTestingKey}},
- want: nil,
- wantErr: ErrBadCA,
- },
- {
- name: "bad cert should fail",
- args: args{crt: certBytes{
- ca: pemTestingCa,
- cert: pemTestingCertificate[:len(pemTestingCertificate)-10],
- key: pemTestingKey}},
- want: nil,
- wantErr: ErrBadKeypair,
- },
- {
- name: "bad key should fail",
- args: args{crt: certBytes{
- ca: pemTestingCa,
- cert: pemTestingCertificate,
- key: pemTestingKey[10:]}},
- want: nil,
- wantErr: ErrBadKeypair,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got, err := loadCertAndCAFromBytes(tt.args.crt)
- if !errors.Is(err, tt.wantErr) {
- t.Errorf("loadCertAndCAFromBytes() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if !reflect.DeepEqual(got, tt.want) {
- t.Errorf("loadCertAndCAFromBytes() = %v, want %v", got, tt.want)
- }
- })
- }
- t.Run("sunny path should not fail", func(t *testing.T) {
- crt := certBytes{
- ca: pemTestingCa,
- cert: pemTestingCertificate,
- key: pemTestingKey,
- }
- _, err := loadCertAndCAFromBytes(crt)
- if err != nil {
- t.Errorf("loadCertAndCAFromBytes() err = %v, want %v", err, nil)
- }
- })
-}
-
-func Test_initTLSLoadTestCertificates(t *testing.T) {
-
- t.Run("default options should not fail", func(t *testing.T) {
- session := makeTestingSession()
- crt, err := writeTestingCerts(t.TempDir())
- if err != nil {
- t.Errorf("error while testing: %v", err)
- }
- cfg, err := newCertConfigFromOptions(&Options{
- CertPath: crt.cert,
- KeyPath: crt.key,
- CaPath: crt.ca,
- })
- if err != nil {
- t.Errorf("error while testing: %v", err)
- }
-
- _, err = initTLS(session, cfg)
- if err != nil {
- t.Errorf("initTLS() error = %v, want: nil", err)
- }
- })
-
- t.Run("default options from bytes should not fail", func(t *testing.T) {
- session := makeTestingSession()
- cfg, err := newCertConfigFromOptions(&Options{
- Cert: pemTestingCertificate,
- Key: pemTestingKey,
- Ca: pemTestingCa,
- })
- if err != nil {
- t.Errorf("error while testing: %v", err)
- }
- _, err = initTLS(session, cfg)
- if err != nil {
- t.Errorf("initTLS() error = %v, want: nil", err)
- }
- })
-}
-
-// mock good handshake
-
-type dummyTLSConn struct {
- tls.Conn
-}
-
-var _ handshaker = &dummyTLSConn{} // Ensure that we implement handshaker
-
-func (d *dummyTLSConn) Handshake() error {
- return nil
-}
-
-func dummyTLSFactory(net.Conn, *tls.Config) (handshaker, error) {
- return &dummyTLSConn{tls.Conn{}}, nil
-}
-
-// mock bad handshake
-
-type dummyTLSConnBadHandshake struct {
- tls.Conn
-}
-
-var _ handshaker = &dummyTLSConnBadHandshake{} // Ensure that we implement handshaker
-
-func (d *dummyTLSConnBadHandshake) Handshake() error {
- return errors.New("dummy error")
-}
-
-func dummyTLSFactoryBadHandshake(net.Conn, *tls.Config) (handshaker, error) {
- return &dummyTLSConnBadHandshake{tls.Conn{}}, nil
-}
-
-var tlsFactoryError = errors.New("tlsFactory error")
-
-func errorRaisingTLSFactory(net.Conn, *tls.Config) (handshaker, error) {
- return nil, tlsFactoryError
-}
-
-func Test_tlsHandshake(t *testing.T) {
-
- makeConnAndConf := func() (*controlChannelTLSConn, *tls.Config) {
- conn := &mocks.Conn{}
- s := makeTestingSession()
- tc, _ := newControlChannelTLSConn(conn, s)
-
- conf := &tls.Config{
- InsecureSkipVerify: true,
- }
- return tc, conf
- }
-
- t.Run("mocked good handshake should not fail", func(t *testing.T) {
- origTLS := tlsFactoryFn
- tlsFactoryFn = dummyTLSFactory
- defer func() {
- tlsFactoryFn = origTLS
- }()
-
- conn, conf := makeConnAndConf()
-
- _, err := tlsHandshake(conn, conf)
- if err != nil {
- t.Errorf("tlsHandshake() error = %v, wantErr %v", err, nil)
- return
- }
- })
-
- t.Run("mocked bad handshake should fail", func(t *testing.T) {
- origTLS := tlsFactoryFn
- tlsFactoryFn = dummyTLSFactoryBadHandshake
- defer func() {
- tlsFactoryFn = origTLS
- }()
-
- conn, conf := makeConnAndConf()
-
- wantErr := ErrBadTLSHandshake
- _, err := tlsHandshake(conn, conf)
- if !errors.Is(err, wantErr) {
- t.Errorf("tlsHandshake() error = %v, wantErr %v", err, wantErr)
- return
- }
- })
-
- t.Run("any error from the factory should be bubbled up", func(t *testing.T) {
- origTLS := tlsFactoryFn
- tlsFactoryFn = errorRaisingTLSFactory
- defer func() {
- tlsFactoryFn = origTLS
- }()
- wantErr := tlsFactoryError
-
- conn, conf := makeConnAndConf()
-
- _, err := tlsHandshake(conn, conf)
- if !errors.Is(err, wantErr) {
- t.Errorf("tlsHandshake() error = %v, wantErr %v", err, wantErr)
- return
- }
- })
-}
-
-func Test_defaultTLSFactory(t *testing.T) {
- conn := &mocks.Conn{}
- conf := &tls.Config{}
- defaultTLSFactory(conn, conf)
-}
-
-func Test_parrotTLSFactory(t *testing.T) {
- conn := &mocks.Conn{}
- conf := &tls.Config{InsecureSkipVerify: true}
-
- t.Run("parrotTLS factory does not return any error by default", func(t *testing.T) {
- _, err := parrotTLSFactory(conn, conf)
- if err != nil {
- t.Errorf("parrotTLSFactory() error = %v, wantErr %v", err, nil)
- return
- }
- })
-
- t.Run("an hex clienthello that cannot be decoded to raw bytes should raise ErrBadParrot", func(t *testing.T) {
- defer func(original string) {
- vpnClientHelloHex = original
- }(vpnClientHelloHex)
- vpnClientHelloHex = `aaa`
-
- _, err := parrotTLSFactory(conn, conf)
- wantErr := ErrBadParrot
- if !errors.Is(err, wantErr) {
- t.Errorf("tlsHandshake() error = %v, wantErr %v", err, wantErr)
- return
- }
- })
-
- t.Run("an hex representation that is not a valid clienthello should raise ErrBadParrot", func(t *testing.T) {
- defer func(original string) {
- vpnClientHelloHex = original
- }(vpnClientHelloHex)
- vpnClientHelloHex = `deadbeef`
-
- _, err := parrotTLSFactory(conn, conf)
- wantErr := ErrBadParrot
- if !errors.Is(err, wantErr) {
- t.Errorf("tlsHandshake() error = %v, wantErr %v", err, wantErr)
- return
- }
- })
-
- // TODO(ainghazal): there's an extra error case that I'm not pretty sure how to reach
- // (error on client.ApplyPreset)
-}
-
-func Test_customVerify(t *testing.T) {
-
- t.Run("happy path: a correct certChain should validate if we pin with the good ca", func(t *testing.T) {
- rawCerts, ca, vpnCert, vpnKey, err := makeRawCertsForTesting()
- if err != nil {
- t.Errorf("error getting raw certs")
- return
- }
-
- auth, err := makeCertAndCAFromMemory(ca, vpnCert, vpnKey)
- customVerify := customVerifyFactory(auth)
-
- err = customVerify(rawCerts, nil)
- if err != nil {
- t.Errorf("customVerify() error = %v, wantErr %v", err, nil)
- }
- })
-
- t.Run("a certChain should not validate if we do not pin the proper ca", func(t *testing.T) {
- rawCerts, _, vpnCert, vpnKey, err := makeRawCertsForTesting()
- if err != nil {
- t.Errorf("error getting raw certs")
- return
- }
- _, badCa, _, _, err := makeRawCertsForTesting()
- if err != nil {
- t.Errorf("error getting raw certs")
- return
- }
-
- auth, err := makeCertAndCAFromMemory(badCa, vpnCert, vpnKey)
- customVerify := customVerifyFactory(auth)
-
- wantErr := ErrCannotVerifyCertChain
- err = customVerify(rawCerts, nil)
-
- if !errors.Is(err, wantErr) {
- t.Errorf("customVerify() error = %v, wantErr %v", err, nil)
- }
- })
-
- t.Run("a certChain should not validate if we pass an empty ca", func(t *testing.T) {
- rawCerts, _, vpnCert, vpnKey, err := makeRawCertsForTesting()
- if err != nil {
- t.Errorf("error getting raw certs")
- return
- }
-
- emptyCa := &x509.Certificate{}
-
- auth, err := makeCertAndCAFromMemory(emptyCa, vpnCert, vpnKey)
- customVerify := customVerifyFactory(auth)
-
- wantErr := ErrCannotVerifyCertChain
- err = customVerify(rawCerts, nil)
-
- if !errors.Is(err, wantErr) {
- t.Errorf("customVerify() error = %v, wantErr %v", err, nil)
- }
- })
-
- t.Run("a correct certChain fails if DNSName is set in VerifyOptions", func(t *testing.T) {
- // this test is really only testing the behavior of golang x509 validation
- // in the stdlib, but it gives me more faith in the correctness
- // of the custom verify function
- rawCerts, ca, vpnCert, vpnKey, err := makeRawCertsForTesting()
- if err != nil {
- t.Errorf("error getting raw certs")
- return
- }
-
- defer func(orig func() x509.VerifyOptions) {
- certVerifyOptions = orig
- }(certVerifyOptions)
-
- // the test cert has random.gateway set as the DNSName, so we're just verifying
- // that the verification actually fails with options different from the default that we're
- // setting in the certVerifyOptions global.
- certVerifyOptions = func() x509.VerifyOptions {
- return x509.VerifyOptions{DNSName: "other.gateway"}
- }
-
- wantErr := ErrCannotVerifyCertChain
-
- auth, err := makeCertAndCAFromMemory(ca, vpnCert, vpnKey)
- customVerify := customVerifyFactory(auth)
-
- err = customVerify(rawCerts, nil)
- if !errors.Is(err, wantErr) {
- t.Errorf("customVerify() error = %v, wantErr %v", err, nil)
- }
- })
-
- t.Run("empty certchain raises error", func(t *testing.T) {
- emptyCerts := [][]byte{[]byte{}, []byte{}}
- wantErr := ErrCannotVerifyCertChain
-
- _, ca, vpnCert, vpnKey, err := makeRawCertsForTesting()
- if err != nil {
- t.Errorf("error getting raw certs")
- return
- }
-
- auth, err := makeCertAndCAFromMemory(ca, vpnCert, vpnKey)
- customVerify := customVerifyFactory(auth)
-
- err = customVerify(emptyCerts, nil)
- if !errors.Is(err, wantErr) {
- t.Errorf("customVerify() error = %v, wantErr %v", err, wantErr)
- }
- })
-
- t.Run("garbage certchain raises error", func(t *testing.T) {
- garbageCerts := [][]byte{[]byte{0xde, 0xad}, []byte{0xbe, 0xef}}
- wantErr := ErrCannotVerifyCertChain
-
- _, ca, vpnCert, vpnKey, err := makeRawCertsForTesting()
- if err != nil {
- t.Errorf("error getting raw certs")
- return
- }
-
- auth, err := makeCertAndCAFromMemory(ca, vpnCert, vpnKey)
- customVerify := customVerifyFactory(auth)
-
- err = customVerify(garbageCerts, nil)
- if !errors.Is(err, wantErr) {
- t.Errorf("customVerify() error = %v, wantErr %v", err, wantErr)
- }
- })
-
- t.Run("attempting to verify one cert with a different ca raises error", func(t *testing.T) {
- certChainOne, _, _, _, _ := makeRawCertsForTesting()
- certChainTwo, _, _, _, _ := makeRawCertsForTesting()
- badChain := [][]byte{certChainOne[0], certChainTwo[1]}
- wantErr := ErrCannotVerifyCertChain
-
- _, ca, vpnCert, vpnKey, err := makeRawCertsForTesting()
- if err != nil {
- t.Errorf("error getting raw certs")
- return
- }
-
- auth, err := makeCertAndCAFromMemory(ca, vpnCert, vpnKey)
- customVerify := customVerifyFactory(auth)
-
- err = customVerify(badChain, nil)
- if !errors.Is(err, wantErr) {
- t.Errorf("customVerify() error = %v, wantErr %v", err, wantErr)
- }
- })
-}
-
-// makeRawCertsForTesting creates a CA, and returns:
-// * an array of byte arrays containing a cert signed with that CA and the CA itself (to be used to test the verify routine).
-// * the ca used to sign the certs
-// * a cert that simulates a vpn certificate signed by the ca (rsa)
-// * the private key for the vpn certificate
-// * an error if it could not build any of the certs correctly.
-func makeRawCertsForTesting() ([][]byte, *x509.Certificate, []byte, []byte, error) {
- // set up a CA certificate. this sets up a 2048 cert for the ca, if we ever
- // want to shave milliseconds we can roll a ca with a smaller key size.
- ca, caPrivKey, err := mitm.NewAuthority("ca", "oonitarians united", 1*time.Hour)
- if err != nil {
- return nil, nil, nil, nil, err
- }
- caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caPrivKey.PublicKey, caPrivKey)
- if err != nil {
- return nil, nil, nil, nil, err
- }
-
- // set up a leaf certificate - this would be the gateway cert
- cert := &x509.Certificate{
- SerialNumber: big.NewInt(1984),
- Subject: pkix.Name{
- Organization: []string{"Oonitarians united"},
- StreetAddress: []string{"On a pinneaple at the bottom of the sea"},
- CommonName: "random.gateway",
- },
- NotBefore: time.Now(),
- NotAfter: time.Now().AddDate(10, 0, 0),
- DNSNames: []string{"random.gateway", "randomgw"},
- }
-
- // tiny cert size to make tests go brrr
- certPrivKey, err := rsa.GenerateKey(rand.Reader, 512)
- if err != nil {
- return nil, nil, nil, nil, err
- }
-
- certBytes, err := x509.CreateCertificate(rand.Reader, cert, ca, &certPrivKey.PublicKey, caPrivKey)
- if err != nil {
- return nil, nil, nil, nil, err
- }
-
- // set up a vpn certificate - this would be the client cert
- vpnCert := &x509.Certificate{
- SerialNumber: big.NewInt(1984),
- Subject: pkix.Name{
- Organization: []string{"Oonitarians united"},
- StreetAddress: []string{"On a pinneaple at the bottom of the sea"},
- CommonName: "client cert",
- },
- NotBefore: time.Now(),
- NotAfter: time.Now().AddDate(10, 0, 0),
- }
-
- // tiny cert size to make tests go brrr
- vpnCertPrivKey, err := rsa.GenerateKey(rand.Reader, 512)
- if err != nil {
- return nil, nil, nil, nil, err
- }
-
- vpnCertBytes, err := x509.CreateCertificate(rand.Reader, vpnCert, ca, &vpnCertPrivKey.PublicKey, caPrivKey)
- if err != nil {
- return nil, nil, nil, nil, err
- }
-
- vpnKeyBytes := x509.MarshalPKCS1PrivateKey(vpnCertPrivKey)
-
- result := [][]byte{certBytes, caBytes}
- return result, ca, vpnCertBytes, vpnKeyBytes, nil
-}
-
-func makeCertAndCAFromMemory(caCert *x509.Certificate, vpnCert []byte, vpnKey []byte) (*certConfig, error) {
- ca := x509.NewCertPool()
- ca.AddCert(caCert)
- cert, _ := tls.X509KeyPair(vpnCert, vpnKey)
- auth := &certConfig{
- ca: ca,
- cert: cert,
- }
- return auth, nil
-}
diff --git a/vpn/transport.go b/vpn/transport.go
deleted file mode 100644
index baa3cdc4..00000000
--- a/vpn/transport.go
+++ /dev/null
@@ -1,310 +0,0 @@
-package vpn
-
-//
-// Transports for OpenVPN over TCP and over UDP.
-// This file includes:
-// 1. Methods for reading packets from the wire
-// 2. A TLS transport that reads and writes TLS records as part of control packets.
-//
-
-import (
- "bytes"
- "encoding/binary"
- "encoding/hex"
- "errors"
- "fmt"
- "io"
- "net"
- "time"
-)
-
-var (
- // ErrBadConnNetwork indicates that the conn's network is neither TCP nor UDP.
- ErrBadConnNetwork = errors.New("bad conn.Network value")
-
- // ErrPacketTooShort indicates that a packet is too short.
- ErrPacketTooShort = errors.New("packet too short")
-)
-
-// direct reads on the underlying conn
-
-func readPacket(conn net.Conn) ([]byte, error) {
- switch network := conn.LocalAddr().Network(); network {
- case "tcp", "tcp4", "tcp6":
- return readPacketFromTCP(conn)
- case "udp", "udp4", "upd6":
- // for UDP we don't need to parse size frames
- return readPacketFromUDP(conn)
- default:
- return nil, fmt.Errorf("%w: %s", ErrBadConnNetwork, network)
- }
-}
-
-func readPacketFromUDP(conn net.Conn) ([]byte, error) {
- const enough = 1 << 17
- buf := make([]byte, enough)
-
- count, err := conn.Read(buf)
- if err != nil {
- return nil, err
- }
- buf = buf[:count]
- return buf, nil
-}
-
-func readPacketFromTCP(conn net.Conn) ([]byte, error) {
- lenbuf := make([]byte, 2)
- if _, err := io.ReadFull(conn, lenbuf); err != nil {
- return nil, err
- }
- length := binary.BigEndian.Uint16(lenbuf)
- buf := make([]byte, length)
- if _, err := io.ReadFull(conn, buf); err != nil {
- return nil, err
- }
- return buf, nil
-}
-
-// tlsModeTransporter is a transport for OpenVPN in TLS mode.
-//
-// See https://openvpn.net/community-resources/openvpn-protocol/ for documentation
-// on the protocol used by OpenVPN on the wire.
-type tlsModeTransporter interface {
- // ReadPacket reads an OpenVPN packet from the wire.
- ReadPacket() (p *packet, err error)
-
- // WritePacket writes an OpenVPN packet to the wire.
- WritePacket(opcodeKeyID uint8, data []byte) error
-
- // SetDeadline sets the underlying conn's deadline.
- SetDeadline(deadline time.Time) error
-
- // SetReadDeadline sets the underlying conn's read deadline.
- SetReadDeadline(deadline time.Time) error
-
- // SetWriteDeadline sets the underlying conn's write deadline.
- SetWriteDeadline(deadline time.Time) error
-
- // Close closes the underlying conn.
- Close() error
-
- // LocalAddr returns the underlying conn's local addr.
- LocalAddr() net.Addr
-
- // RemoteAddr returns the underlying conn's remote addr.
- RemoteAddr() net.Addr
-}
-
-// newTLSModeTransport creates a new TLSModeTransporter using the given net.Conn.
-func newTLSModeTransport(conn net.Conn, s *session) (tlsModeTransporter, error) {
- return &tlsTransport{Conn: conn, session: s}, nil
-}
-
-// tlsTransport implements TLSModeTransporter.
-type tlsTransport struct {
- net.Conn
- session *session
-}
-
-// ReadPacket returns a packet reading from the underlying conn, and an error
-// if the read did not succeed.
-func (t *tlsTransport) ReadPacket() (*packet, error) {
- buf, err := readPacket(t.Conn)
- if err != nil {
- return nil, err
- }
-
- p, err := parsePacketFromBytes(buf)
- if err != nil {
- return &packet{}, err
- }
- if p.isACK() {
- logger.Warn("tls: got ACK (ignored)")
- return &packet{}, nil
- }
- return p, nil
-}
-
-// WritePacket writes a packet to the underlying conn. It expect the opcode of the packet and a byte array containing the serialized data. It returns an error if the write did not succeed.
-func (t *tlsTransport) WritePacket(opcodeKeyID uint8, data []byte) error {
- if t.session == nil {
- return fmt.Errorf("%w:%s", errBadInput, "tlsTransport badly initialized")
-
- }
- p := newPacketFromPayload(opcodeKeyID, 0, data)
- id, err := t.session.LocalPacketID()
- if err != nil {
- return err
- }
- p.id = id
- p.localSessionID = t.session.LocalSessionID
- payload := p.Bytes()
-
- out := maybeAddSizeFrame(t.Conn, payload)
-
- logger.Debug(fmt.Sprintln("tls write:", len(out)))
- logger.Debug(fmt.Sprintln(hex.Dump(out)))
-
- _, err = t.Conn.Write(out)
- return err
-}
-
-var _ tlsModeTransporter = &tlsTransport{} // Ensure that we implement TLSModelTransporter
-
-// controlChannelTLSConn implements net.Conn, and is passed to the tls.Client to perform a
-// TLS Handshake over OpenVPN control packets.
-type controlChannelTLSConn struct {
- conn net.Conn
- session *session
- transport tlsModeTransporter
- // we need to buffer reads because the tls records request less than
- // the payload we receive.
- bufReader *bytes.Buffer
-
- doReadFromConnFn func(*controlChannelTLSConn, []byte) (bool, int, error)
- doReadFromQueueFn func(*controlChannelTLSConn, []byte) (bool, int, error)
-}
-
-// newControlChannelTLSConn returns a controlChannelTLSConn. It requires the on-the-wire
-// net.Conn that will be used underneath, and a configured session. It returns
-// also an error if the operation cannot be completed.
-func newControlChannelTLSConn(conn net.Conn, s *session) (*controlChannelTLSConn, error) {
- transport, err := newTLSModeTransport(conn, s)
- if err != nil {
- return &controlChannelTLSConn{}, err
- }
- buf := bytes.NewBuffer(nil)
- tlsConn := &controlChannelTLSConn{
- conn: conn,
- session: s,
- transport: transport,
- bufReader: buf,
- }
- tlsConn.doReadFromConnFn = doReadFromConn
- tlsConn.doReadFromQueueFn = doReadFromQueue
- return tlsConn, err
-}
-
-// Read over the control channel. This method implements the reliability layer:
-// it retries reads until the _next_ packet is received (according to the
-// packetID). Returns also an error if the operation cannot be completed.
-func (c *controlChannelTLSConn) Read(b []byte) (int, error) {
- if c.session == nil || c.session.ackQueue == nil {
- return 0, fmt.Errorf("%w: %s", errBadInput, "bad session in TLSConn.Read()")
- }
- for {
- switch len(c.session.ackQueue) {
- case 0:
- ok, n, err := c.doReadFromConnFn(c, b)
- if ok {
- return n, err
- }
- default:
- ok, n, err := c.doReadFromQueueFn(c, b)
- if ok {
- return n, err
- }
- }
- }
-}
-
-func doReadFromConn(c *controlChannelTLSConn, b []byte) (bool, int, error) {
- p, err := c.doRead()
-
- if err != nil {
- return true, 0, err
- }
- switch c.canRead(p) {
- case true:
- if err := sendACKFn(c.conn, c.session, p.id); err != nil {
- return true, 0, err
- }
- n, err := writeAndReadFromBufferFn(c.bufReader, b, p.payload)
- return true, n, err
- case false:
- if p != nil {
- c.session.ackQueue <- p
- }
- }
-
- return false, 0, nil
-}
-
-func doReadFromQueue(c *controlChannelTLSConn, b []byte) (bool, int, error) {
- for p := range c.session.ackQueue {
- if c.canRead(p) {
- if err := sendACKFn(c.conn, c.session, p.id); err != nil {
- return true, 0, err
- }
- n, err := writeAndReadFromBufferFn(c.bufReader, b, p.payload)
- return true, n, err
- } else {
- c.session.ackQueue <- p
- return doReadFromConn(c, b)
- }
- }
- return false, 0, nil
-}
-
-// doRead() calls ReadPacket() in the underlying transport implementation. It
-// returns a packet and an error.
-func (c *controlChannelTLSConn) doRead() (*packet, error) {
- if c.transport == nil {
- return nil, fmt.Errorf("%w:%s", errBadInput, "tlsConn is missing transport")
-
- }
- return c.transport.ReadPacket()
-}
-
-// canRead returns true if the packet is not nil and its packetID is the next
-// integer in the expected sequence; returns false otherwise.
-func (c *controlChannelTLSConn) canRead(p *packet) bool {
- return p != nil && c.session.isNextPacket(p)
-}
-
-// writeAndReadPayloadFromBuffer writes a given payload to a buffered reader, and returns
-// a read from that same buffered reader into the passed byte array. it returns both an integer
-// denoting the amount of bytes read, and any error during the operation.
-func writeAndReadFromBuffer(bb *bytes.Buffer, b []byte, payload []byte) (int, error) {
- bb.Write(payload)
- return bb.Read(b)
-}
-
-var writeAndReadFromBufferFn = writeAndReadFromBuffer
-
-// Write writes the given data to the tls connection.
-func (c *controlChannelTLSConn) Write(b []byte) (int, error) {
- err := c.transport.WritePacket(uint8(pControlV1), b)
- if err != nil {
- logger.Errorf("tls write: %s", err.Error())
- return 0, err
- }
- return len(b), err
-}
-
-// Close closes the tls connection.
-func (c *controlChannelTLSConn) Close() error {
- return c.conn.Close()
-}
-
-func (c *controlChannelTLSConn) LocalAddr() net.Addr {
- return c.conn.LocalAddr()
-}
-
-func (c *controlChannelTLSConn) RemoteAddr() net.Addr {
- return c.conn.RemoteAddr()
-}
-
-func (c *controlChannelTLSConn) SetDeadline(tt time.Time) error {
- return c.conn.SetDeadline(tt)
-}
-
-func (c *controlChannelTLSConn) SetReadDeadline(tt time.Time) error {
- return c.conn.SetReadDeadline(tt)
-}
-
-func (c *controlChannelTLSConn) SetWriteDeadline(tt time.Time) error {
- return c.conn.SetWriteDeadline(tt)
-}
-
-var _ net.Conn = &controlChannelTLSConn{} // Ensure that we implement net.Conn
diff --git a/vpn/transport_test.go b/vpn/transport_test.go
deleted file mode 100644
index afd7ab53..00000000
--- a/vpn/transport_test.go
+++ /dev/null
@@ -1,647 +0,0 @@
-package vpn
-
-import (
- "bytes"
- "errors"
- "net"
- "reflect"
- "testing"
- "time"
-
- "github.com/ooni/minivpn/vpn/mocks"
-)
-
-func Test_readPacketFromUDP(t *testing.T) {
- conn := makeTestinConnFromNetwork("udp")
- got, err := readPacketFromUDP(conn)
- want := []byte("alles ist gut")
- if err != nil {
- t.Errorf("readPacketFromUDP() error = %v, want %v", err, nil)
- }
- if !bytes.Equal(got, want) {
- t.Errorf("readPacketFromTCP() got = %s, want %s", got, want)
- }
-}
-
-func Test_readPacketFromTCP(t *testing.T) {
- conn := makeTestinConnFromNetwork("tcp")
- got, err := readPacketFromTCP(conn)
- want := []byte("alles ist gut")
- if err != nil {
- t.Errorf("readPacketFromTCP() error = %v, want %v", err, nil)
- }
- if !bytes.Equal(got, want) {
- t.Errorf("readPacketFromTCP() got = %s, want %s", got, want)
- }
-}
-
-func Test_readPacket_BadNetwork(t *testing.T) {
- conn := makeTestinConnFromNetwork("unix")
- _, err := readPacket(conn)
- wantErr := ErrBadConnNetwork
- if !errors.Is(err, wantErr) {
- t.Errorf("readPacket() got = %v, want %v", err, wantErr)
- }
-}
-
-type MockTLSTransportConn struct {
- *mocks.Conn
- written []byte
-}
-
-func makeTestingTLSTransportWithPacket(packetPayload *packet) (*tlsTransport, *MockTLSTransportConn) {
- s := makeTestingSession()
- a := &mocks.Addr{}
- a.MockNetwork = func() string { return "udp" }
- c := &MockTLSTransportConn{Conn: &mocks.Conn{}}
- c.MockLocalAddr = func() net.Addr { return a }
- c.MockRead = func(b []byte) (int, error) {
- out := packetPayload.Bytes()
- copy(b, out)
- return len(out), nil
- }
- c.MockWrite = func(b []byte) (int, error) {
- c.written = b
- return 0, nil
- }
- return &tlsTransport{Conn: c, session: s}, c
-}
-
-func makeTestingTLSTransportWithDefaultPacketPayload() (*tlsTransport, *MockTLSTransportConn) {
- readPayload := &packet{opcode: pDataV1, payload: []byte("this is not a payload")}
- return makeTestingTLSTransportWithPacket(readPayload)
-}
-
-func Test_tlsTransport_ReadPacket(t *testing.T) {
- fakePayload := append(
- // fake tag
- bytes.Repeat([]byte{0x00}, 13),
- []byte("this is not a payload")...)
- want := &packet{opcode: pDataV1, payload: fakePayload}
-
- tt, _ := makeTestingTLSTransportWithDefaultPacketPayload()
- got, err := tt.ReadPacket()
-
- if err != nil {
- t.Errorf("ReadPacket() error = %v, wantErr %v", err, nil)
- }
- if !bytes.Equal(got.payload, want.payload) {
- t.Errorf("ReadPacket() got = %v, want = %v", got.payload, want.payload)
- }
-}
-
-func Test_tlsTransport_ReadPacket_ACK(t *testing.T) {
- ackPacket := &packet{opcode: pACKV1}
- tt, _ := makeTestingTLSTransportWithPacket(ackPacket)
- got, err := tt.ReadPacket()
- if err != nil {
- t.Errorf("ReadPacket() error = %v, wantErr %v", err, nil)
- }
- if !bytes.Equal(got.payload, ackPacket.payload) {
- t.Errorf("ReadPacket() got = %v, want = %v", got.payload, ackPacket.payload)
- }
-
-}
-
-func Test_tlsTransport_WritePacket(t *testing.T) {
- payload := []byte("this is not a payload")
- fakePacket := append([]byte{0x30, 0x02}, bytes.Repeat([]byte{0x00}, 12)...)
- fakePacket = append(fakePacket, payload...)
-
- tt, conn := makeTestingTLSTransportWithDefaultPacketPayload()
- err := tt.WritePacket(pDataV1, payload)
- if err != nil {
- t.Errorf("ReadPacket() error = %v, want = %v", err, nil)
- }
- if !bytes.Equal(conn.written, fakePacket) {
- t.Errorf("ReadPacket(): got = %v, want = %v", conn.written, fakePacket)
- }
-}
-
-func makeTestinConnFromNetwork(network string) net.Conn {
- mockAddr := &mocks.Addr{}
- mockAddr.MockNetwork = func() string {
- return network
- }
- c := &mocks.Conn{}
- c.MockLocalAddr = func() net.Addr {
- return mockAddr
- }
- switch network {
- case "udp":
- c.MockRead = func(b []byte) (int, error) {
- out := []byte("alles ist gut")
- copy(b, out)
- return len(out), nil
- }
- case "tcp":
- c.MockRead = func(b []byte) (int, error) {
- var out []byte
- switch c.Count {
- case 0:
- out = []byte{0x00, 0x0d}
- copy(b, out)
- c.Count += 1
- case 1:
- out = []byte("alles ist gut")
- copy(b, out)
- }
- return len(out), nil
- }
- default:
- c.MockRead = func([]byte) (int, error) {
- return 0, nil
- }
- }
- return c
-}
-
-func Test_readPacket(t *testing.T) {
-
- type args struct {
- conn net.Conn
- }
- tests := []struct {
- name string
- args args
- want []byte
- wantErr error
- }{
- {
- name: "test read from udp conn is ok",
- args: args{
- conn: makeTestinConnFromNetwork("udp"),
- },
- want: []byte("alles ist gut"),
- wantErr: nil,
- },
- {
- name: "test read from tcp conn is ok",
- args: args{
- conn: makeTestinConnFromNetwork("tcp"),
- },
- want: []byte("alles ist gut"),
- wantErr: nil,
- },
- // TODO: Add test cases.
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got, err := readPacket(tt.args.conn)
- if !errors.Is(err, tt.wantErr) {
- t.Errorf("readPacket() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if !reflect.DeepEqual(got, tt.want) {
- t.Errorf("readPacket() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func Test_NewTLSConn(t *testing.T) {
- conn := makeTestinConnFromNetwork("udp")
- s := makeTestingSession()
- _, err := newControlChannelTLSConn(conn, s)
- if err != nil {
- t.Errorf("NewTLSConn() error = %v, want = nil", err)
- }
-}
-
-type MockTLSConn struct {
- mocks.Conn
- closedCalled bool
- localAddrCalled bool
- remoteAddrCalled bool
- setDeadlineCalled bool
- setReadDeadlineCalled bool
- setWriteDeadlineCalled bool
-}
-
-func makeConnForTransportTest() *MockTLSConn {
- localAddr := &mocks.Addr{}
- localAddr.MockString = func() string { return "1.1.1.1" }
- localAddr.MockNetwork = func() string { return "udp" }
-
- remoteAddr := &mocks.Addr{}
- remoteAddr.MockString = func() string { return "2.2.2.2" }
- remoteAddr.MockNetwork = func() string { return "udp" }
-
- c := &MockTLSConn{}
- c.MockClose = func() error {
- c.closedCalled = true
- return nil
- }
- c.MockLocalAddr = func() net.Addr {
- c.localAddrCalled = true
- return localAddr
- }
- c.MockRemoteAddr = func() net.Addr {
- c.remoteAddrCalled = true
- return remoteAddr
- }
- c.MockSetDeadline = func(time.Time) error {
- c.setDeadlineCalled = true
- return nil
- }
- c.MockSetReadDeadline = func(time.Time) error {
- c.setReadDeadlineCalled = true
- return nil
- }
- c.MockSetWriteDeadline = func(time.Time) error {
- c.setWriteDeadlineCalled = true
- return nil
- }
- return c
-}
-
-func makeTestingTLSConn() (*controlChannelTLSConn, *MockTLSConn) {
- c := makeConnForTransportTest()
- t := &controlChannelTLSConn{}
- t.conn = c
- return t, c
-}
-
-func TestTLSConn_Read_Fails_With_Bad_Data(t *testing.T) {
- tc, _ := makeTestingTLSConn()
- b := make([]byte, 16)
- _, err := tc.Read(b)
- wantErr := errBadInput
- if !errors.Is(err, wantErr) {
- t.Errorf("TLSConn.Read(): empty session; gotErr = %v, wantErr = %v ", err, wantErr)
- }
-
-}
-
-func TestTLSConn_Read(t *testing.T) {
- // call witnesses
- readFromConnCalled := false
- readFromQueueCalled := false
-
- payload := []byte("alles ist gut")
-
- // setup the fields we need
- tc, _ := makeTestingTLSConn()
- tc.session = makeTestingSession()
- ackQueue := make(chan *packet, 16)
- tc.session.ackQueue = ackQueue
-
- // mock read functions
- tc.doReadFromConnFn = func(tcn *controlChannelTLSConn, b []byte) (bool, int, error) {
- readFromConnCalled = true
- copy(b[:], payload)
- return true, len(payload), nil
- }
- tc.doReadFromQueueFn = func(tcn *controlChannelTLSConn, b []byte) (bool, int, error) {
- readFromQueueCalled = true
- copy(b[:], payload)
- return true, len(payload), nil
- }
-
- // first we read from conn
-
- b := make([]byte, 255)
- n, err := tc.Read(b)
-
- if err != nil {
- t.Errorf("TLSConn.Read(): expected no error, got %v", err)
- }
- if n != len(payload) {
- t.Errorf("TLSConn.Read(): readFromConn returned wrong len %v", n)
- }
- if !readFromConnCalled {
- t.Errorf("TLSConn.Read(): readFromConn not called")
- }
- if readFromQueueCalled {
- t.Errorf("TLSConn.Read(): readFromQueue should have not been called")
- }
-
- // now we read from queue. reset the witnesses:
-
- readFromConnCalled = false
- readFromQueueCalled = false
-
- // inject one packet in the queue
- p := &packet{opcode: pDataV1, payload: []byte("alles ist gut")}
- tc.session.ackQueue <- p
-
- b = make([]byte, 255)
-
- // and do another call to Read()
- n, err = tc.Read(b)
- if err != nil {
- t.Errorf("TLSConn.Read(): expected no error, got %v", err)
- }
- if !readFromQueueCalled {
- t.Errorf("TLSConn.Read(): readFromQueue not called")
- }
- if readFromConnCalled {
- t.Errorf("TLSConn.Read(): readFromConn should not have been called")
- }
-}
-
-func makeTestingTLSTransportFromPayload(payload []byte) (*tlsTransport, *MockTLSTransportConn) {
- s := makeTestingSession()
- a := &mocks.Addr{}
- a.MockNetwork = func() string { return "udp" }
- c := &MockTLSTransportConn{Conn: &mocks.Conn{}}
- c.MockLocalAddr = func() net.Addr { return a }
- c.MockRead = func(b []byte) (int, error) {
- out := payload
- copy(b, out)
- return len(out), nil
- }
- c.MockWrite = func(b []byte) (int, error) {
- c.written = b
- return 0, nil
- }
- return &tlsTransport{Conn: c, session: s}, c
-}
-
-func makePacketForTLSConnTest(id int, s *session) *packet {
- p := &packet{
- id: packetID(id),
- opcode: pControlV1,
- keyID: 0x00,
- payload: []byte("aaa"),
- localSessionID: s.LocalSessionID,
- remoteSessionID: s.RemoteSessionID,
- acks: []packetID{},
- }
- return p
-}
-
-func makeTestingTLSConnForReadTest(payload []byte) *controlChannelTLSConn {
- tc, _ := makeTestingTLSConn()
- tt, _ := makeTestingTLSTransportFromPayload(payload)
- tc.transport = tt
- tc.session = makeTestingSession()
- ackQueue := make(chan *packet, 16)
- tc.session.ackQueue = ackQueue
- return tc
-}
-
-func Test_doReadFromConn(t *testing.T) {
- s := makeTestingSession()
- p := makePacketForTLSConnTest(1, s) // next packet
- payload := p.Bytes()
-
- tc := makeTestingTLSConnForReadTest(payload)
- sendACKFn = func(net.Conn, *session, packetID) error {
- return nil
- }
- writeAndReadFromBufferFn = func(*bytes.Buffer, []byte, []byte) (int, error) {
- return 42, nil
- }
- b := make([]byte, 255)
- ok, n, err := doReadFromConn(tc, b)
- if err != nil {
- t.Errorf("doReadFromBuffer(): wanted error=%v, got=%v", nil, err)
- return
- }
- if !ok {
- t.Errorf("doReadFromBuffer(): expected ok=true, got ok=%v", ok)
- return
- }
- if n != 42 {
- t.Errorf("doReadFromBuffer(): expected %v, got %v", 42, n)
- }
- if len(tc.session.ackQueue) != 0 {
- t.Errorf("doReadFromBuffer(): ackQueue should be 0")
- }
-}
-
-func Test_doReadFromConn_Out_Of_Order_Packet(t *testing.T) {
- s := makeTestingSession()
- p := makePacketForTLSConnTest(2, s) // not next packet
- payload := p.Bytes()
-
- tc := makeTestingTLSConnForReadTest(payload)
-
- sendACKFn = func(net.Conn, *session, packetID) error {
- return nil
- }
- writeAndReadFromBufferFn = func(*bytes.Buffer, []byte, []byte) (int, error) {
- return 42, nil
- }
- b := make([]byte, 255)
- ok, n, err := doReadFromConn(tc, b)
- if err != nil {
- t.Errorf("doReadFromBuffer(): wanted error=%v, got=%v", nil, err)
- return
- }
- if ok {
- t.Errorf("doReadFromBuffer(): expected ok=false, got ok=%v", ok)
- return
- }
- if n != 0 {
- t.Errorf("doReadFromBuffer(): expected %v, got %v", 0, n)
- }
- if len(tc.session.ackQueue) != 1 {
- t.Errorf("doReadFromBuffer(): ackQueue should be 1")
- }
-}
-
-func Test_doReadFromConn_Bubble_Up_Errors(t *testing.T) {
- s := makeTestingSession()
- p := makePacketForTLSConnTest(1, s) // next packet
- payload := p.Bytes()
-
- tc := makeTestingTLSConnForReadTest(payload)
-
- makeUpError := errors.New("silly error")
-
- sendACKFn = func(net.Conn, *session, packetID) error {
- return makeUpError
- }
- writeAndReadFromBufferFn = func(*bytes.Buffer, []byte, []byte) (int, error) {
- return 42, nil
- }
- b := make([]byte, 255)
- _, _, err := doReadFromConn(tc, b)
- if !errors.Is(err, makeUpError) {
- t.Errorf("doReadFromBuffer(): wanted error=%v, got=%v", makeUpError, err)
- return
- }
-}
-
-func Test_doReadFromQueue(t *testing.T) {
- s := makeTestingSession()
- p := makePacketForTLSConnTest(2, s) // not next packet
- tc := makeTestingTLSConnForReadTest(p.Bytes()) // dont care, not going to use it
- tc.session.ackQueue <- p
-
- // mock ack and writes
- sendACKFn = func(net.Conn, *session, packetID) error {
- return nil
- }
- writeAndReadFromBufferFn = func(*bytes.Buffer, []byte, []byte) (int, error) {
- return 42, nil
- }
- b := make([]byte, 255)
- _, _, err := doReadFromQueue(tc, b)
- if err != nil {
- t.Errorf("doReadFromQueue(): wanted error=%v, got=%v", nil, err)
- }
-
-}
-
-func TestTLSConn_doRead(t *testing.T) {
- tt, _ := makeTestingTLSTransportWithDefaultPacketPayload()
- tc := &controlChannelTLSConn{transport: tt}
- _, err := tc.doRead()
- if err != nil {
- t.Errorf("TLSConn.doRead(): expected nil error")
- return
- }
-
- tc = &controlChannelTLSConn{}
- _, err = tc.doRead()
- if !errors.Is(err, errBadInput) {
- t.Errorf("TLSConn.doRead(): should fail with nil transport. got: %v, wanted: %v", err, errBadInput)
- return
- }
-
-}
-
-func TestTLSConn_canRead(t *testing.T) {
- tc := &controlChannelTLSConn{
- session: makeTestingSession(),
- }
- canRead := tc.canRead(nil)
- if canRead {
- t.Errorf("TLSConn.canRead() should return false with nil packet")
- }
-
- pNext := &packet{id: 1}
- canRead = tc.canRead(pNext)
- if !canRead {
- t.Errorf("TLSConn.canRead() should be able to read pID = 1")
- }
-
- pEq := &packet{id: 0}
- canRead = tc.canRead(pEq)
- if canRead {
- t.Errorf("TLSConn.canRead() should not able to read pID = 0")
- }
-
- tc.session.localPacketID = packetID(42)
- pMore := &packet{id: 44}
- canRead = tc.canRead(pMore)
- if canRead {
- t.Errorf("TLSConn.canRead() should not able to read pID = 44")
- }
-
- pLess := &packet{id: 41}
- canRead = tc.canRead(pLess)
- if canRead {
- t.Errorf("TLSConn.canRead() should not able to read pID = 41")
- }
-}
-
-func Test_writeAndReadFromBuffer(t *testing.T) {
- bb := &bytes.Buffer{}
- b := make([]byte, 255)
- payload := []byte("this test is green")
- n, err := writeAndReadFromBuffer(bb, b, payload)
- if err != nil {
- t.Error("writeAndReadFromBuffer(): expected no error")
- }
- if n != len(payload) {
- t.Errorf("writeAndReadFromBuffer(): got len = %v, wanted = %v", n, len(payload))
- }
-}
-
-func TestTLSConn_Close(t *testing.T) {
- tc, conn := makeTestingTLSConn()
- err := tc.Close()
- if err != nil {
- t.Errorf("TLSConn.Close() error = %v, want = nil", err)
- }
- if !conn.closedCalled {
- t.Error("TLSConn.Close(): conn.Close() not called")
- }
-}
-
-func TestTLSConn_LocalAddr(t *testing.T) {
- tc, conn := makeTestingTLSConn()
- want := "1.1.1.1"
- if addr := tc.LocalAddr(); addr.String() != want {
- t.Errorf("TLSConn.LocalAddr() got = %s, want = %s", addr, want)
- }
- if !conn.localAddrCalled {
- t.Error("TLSConn.LocalAddr(): conn.LocalAddr() not called")
- }
-}
-
-func TestTLSConn_RemoteAddr(t *testing.T) {
- tc, conn := makeTestingTLSConn()
- want := "2.2.2.2"
- if addr := tc.RemoteAddr(); addr.String() != want {
- t.Errorf("TLSConn.RemoteAddr() got = %s, want = %s", addr, want)
- }
- if !conn.remoteAddrCalled {
- t.Error("TLSConn.RemoteAddr(): conn.RemoteAddr() not called")
- }
-}
-
-func TestTLSConn_SetDeadline(t *testing.T) {
- tc, conn := makeTestingTLSConn()
- err := tc.SetDeadline(time.Now().Add(time.Second))
- if err != nil {
- t.Errorf("TLSConn.SetDeadline() error = %v, want = nil", err)
- }
- if !conn.setDeadlineCalled {
- t.Error("TLSConn.SetDeadline(): conn.SetDeadline() not called")
- }
-}
-
-func TestTLSConn_SetReadDeadline(t *testing.T) {
- tc, conn := makeTestingTLSConn()
- err := tc.SetReadDeadline(time.Now().Add(time.Second))
- if err != nil {
- t.Errorf("TLSConn.SetReadDeadline() error = %v, want = nil", err)
- }
- if !conn.setReadDeadlineCalled {
- t.Error("TLSConn.SetReadDeadline(): conn.SetReadDeadline() not called")
- }
-}
-
-func TestTLSConn_SetWriteDeadline(t *testing.T) {
- tc, conn := makeTestingTLSConn()
- err := tc.SetWriteDeadline(time.Now().Add(time.Second))
- if err != nil {
- t.Errorf("TLSConn.SetWriteDeadline() error = %v, want = nil", err)
- }
- if !conn.setWriteDeadlineCalled {
- t.Error("TLSConn.SetWriteDeadline(): conn.SetWriteDeadline() not called")
- }
-}
-
-func TestTLSConn_Write(t *testing.T) {
- a := &mocks.Addr{}
- a.MockNetwork = func() string { return "udp" }
- conn := &mocks.Conn{}
- conn.MockLocalAddr = func() net.Addr { return a }
- c := &MockTLSTransportConn{Conn: conn}
- c.MockWrite = func(b []byte) (int, error) {
- c.written = b
- return len(b), nil
- }
- s := makeTestingSession()
- tlsTr := &tlsTransport{Conn: c, session: s}
- tc := &controlChannelTLSConn{transport: tlsTr, session: s}
-
- payload := []byte("this is fine")
- want := append(
- []byte{0x20, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
- payload...)
- _, err := tc.Write(payload)
- if err != nil {
- t.Errorf("TLSConn.Write(): expected err = nil, got = %v", err)
- }
- if !bytes.Equal(c.written, want) {
- t.Errorf("TLSConn.Write(): written = %v, want = %v", c.written, want)
- }
-}
diff --git a/vpn/utils.go b/vpn/utils.go
deleted file mode 100644
index fd234a6c..00000000
--- a/vpn/utils.go
+++ /dev/null
@@ -1,12 +0,0 @@
-package vpn
-
-//
-// Utility functions
-//
-
-// panicIfFalse calls panic with the given message if the given statement is false.
-func panicIfFalse(stmt bool, message interface{}) {
- if !stmt {
- panic(message)
- }
-}
diff --git a/vpn/utils_test.go b/vpn/utils_test.go
deleted file mode 100644
index 51f4eec1..00000000
--- a/vpn/utils_test.go
+++ /dev/null
@@ -1,22 +0,0 @@
-package vpn
-
-import "testing"
-
-func Test_panicIfFalse(t *testing.T) {
- t.Run("panics when false", func(t *testing.T) {
- var happened bool
- func() {
- defer func() {
- happened = recover() != nil
- }()
- panicIfFalse(false, "should happen")
- }()
- if !happened {
- t.Fatal("did not panic")
- }
- })
-
- t.Run("does nothing when true", func(t *testing.T) {
- panicIfFalse(true, "should not happen")
- })
-}