Skip to content

Commit

Permalink
Merge pull request #1 from huweihuang/feat/fix-memory-leak
Browse files Browse the repository at this point in the history
fix memory leak
  • Loading branch information
rambohe-ch authored May 18, 2023
2 parents 48fbcdd + 1e7abb5 commit cb59e86
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 10 deletions.
8 changes: 4 additions & 4 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"errors"
"fmt"
"io"
"net"
"strconv"
"strings"
"sync"
Expand Down Expand Up @@ -49,7 +48,8 @@ var (
type ProxyClientConnection struct {
Mode string
Grpc client.ProxyService_ProxyServer
HTTP net.Conn
HTTP io.ReadWriter
CloseHTTP func() error
connected chan struct{}
connectID int64
agentID string
Expand All @@ -67,13 +67,13 @@ func (c *ProxyClientConnection) send(pkt *client.Packet) error {
return stream.Send(pkt)
} else if c.Mode == "http-connect" {
if pkt.Type == client.PacketType_CLOSE_RSP {
return c.HTTP.Close()
return c.CloseHTTP()
} else if pkt.Type == client.PacketType_DATA {
_, err := c.HTTP.Write(pkt.GetData().Data)
return err
} else if pkt.Type == client.PacketType_DIAL_RSP {
if pkt.GetDialResponse().Error != "" {
return c.HTTP.Close()
return c.CloseHTTP()
}
return nil
} else {
Expand Down
15 changes: 12 additions & 3 deletions pkg/server/tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"io"
"math/rand"
"net/http"
"sync"
"time"

"k8s.io/klog/v2"
Expand Down Expand Up @@ -54,7 +55,8 @@ func (t *Tunnel) ServeHTTP(w http.ResponseWriter, r *http.Request) {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
defer conn.Close()
var closeOnce sync.Once
defer closeOnce.Do(func() { conn.Close() })

random := rand.Int63() /* #nosec G404 */
dialRequest := &client.Packet{
Expand All @@ -75,10 +77,16 @@ func (t *Tunnel) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}

closed := make(chan struct{})
connected := make(chan struct{})
connection := &ProxyClientConnection{
Mode: "http-connect",
HTTP: conn,
Mode: "http-connect",
HTTP: io.ReadWriter(conn), // pass as ReadWriter so the caller must close with CloseHTTP
CloseHTTP: func() error {
closeOnce.Do(func() { conn.Close() })
close(closed)
return nil
},
connected: connected,
start: time.Now(),
backend: backend,
Expand All @@ -103,6 +111,7 @@ func (t *Tunnel) ServeHTTP(w http.ResponseWriter, r *http.Request) {

select {
case <-connection.connected: // Waiting for response before we begin full communication.
case <-closed: // Connection was closed before being established
}

if connection.connectID == 0 {
Expand Down
81 changes: 78 additions & 3 deletions tests/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"net/http/httptest"
"net/url"
"strings"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -254,6 +255,71 @@ func TestBasicProxy_HTTPCONN(t *testing.T) {

}

func TestFailedDial_HTTPCONN(t *testing.T) {
server := httptest.NewServer(newEchoServer("hello"))
server.Close() // cleanup immediately so connections will fail

stopCh := make(chan struct{})
defer close(stopCh)

proxy, cleanup, err := runHTTPConnProxyServer()
if err != nil {
t.Fatal(err)
}
defer cleanup()

runAgent(proxy.agent, stopCh)

// Wait for agent to register on proxy server
time.Sleep(time.Second)

conn, err := net.Dial("tcp", proxy.front)
if err != nil {
t.Error(err)
}

serverURL, _ := url.Parse(server.URL)

// Send HTTP-Connect request
_, err = fmt.Fprintf(conn, "CONNECT %s HTTP/1.1\r\nHost: %s\r\n\r\n", serverURL.Host, "127.0.0.1")
if err != nil {
t.Error(err)
}

// Parse the HTTP response for Connect
br := bufio.NewReader(conn)
res, err := http.ReadResponse(br, nil)
if err != nil {
t.Errorf("reading HTTP response from CONNECT: %v", err)
}
if res.StatusCode != 200 {
t.Errorf("expect 200; got %d", res.StatusCode)
}

dialer := func(network, addr string) (net.Conn, error) {
return conn, nil
}

c := &http.Client{
Transport: &http.Transport{
Dial: dialer,
},
}

_, err = c.Get(server.URL)
if err == nil || !strings.Contains(err.Error(), "connection reset by peer") {
t.Error(err)
}

for i := 0; i < 20; i++ {
if proxy.getActiveHTTPConnectConns() == 0 {
return
}
time.Sleep(time.Millisecond * 10)
}
t.Errorf("expected connection to eventually be closed")
}

func localAddr(addr net.Addr) string {
return addr.String()
}
Expand All @@ -262,6 +328,8 @@ type proxy struct {
server *server.ProxyServer
front string
agent string

getActiveHTTPConnectConns func() int
}

func runGRPCProxyServer() (proxy, func(), error) {
Expand Down Expand Up @@ -326,10 +394,17 @@ func runHTTPConnProxyServer() (proxy, func(), error) {
proxy.agent = localAddr(lis.Addr())

// http-connect
active := int32(0)
proxy.getActiveHTTPConnectConns = func() int { return int(atomic.LoadInt32(&active)) }
handler := &server.Tunnel{
Server: s,
}
httpServer := &http.Server{
Handler: &server.Tunnel{
Server: s,
},
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&active, 1)
defer atomic.AddInt32(&active, -1)
handler.ServeHTTP(w, r)
}),
}
lis2, err := net.Listen("tcp", "")
if err != nil {
Expand Down

0 comments on commit cb59e86

Please sign in to comment.