From cbadb0a9904b3d666da0497f23f2ba4d541781c9 Mon Sep 17 00:00:00 2001 From: bonan Date: Sun, 12 Apr 2020 23:36:56 +0200 Subject: [PATCH] Add MQTT.Start() to better support detecting connection errors during runtime (#13) Breaking change: Start() must be called after New() to start transport --- transport/mqtt/interface.go | 17 ++- transport/mqtt/libmqtt.go | 115 +++++++++++++++++- transport/mqtt/message.go | 29 ++++- transport/mqtt/mqtt.go | 232 +++++++++++++++++++++++------------- transport/mqtt/raw.go | 11 +- 5 files changed, 315 insertions(+), 89 deletions(-) diff --git a/transport/mqtt/interface.go b/transport/mqtt/interface.go index 4d23cf8..b777628 100644 --- a/transport/mqtt/interface.go +++ b/transport/mqtt/interface.go @@ -1,8 +1,23 @@ package mqtt -import "lib.hemtjan.st/device" +import ( + "lib.hemtjan.st/device" +) type MQTT interface { + // Start connects to mqtt and blocks until disconnected. + // "ok" is true if the client is still valid and should be reused by calling Start() again + // + // Example of running with reconnect: + // for { + // ok, err := client.Start() + // if !ok { + // break + // } + // log.Printf("Error %v - retrying in 5 seconds", err) + // time.Sleep(5 * time.Second) + // } + Start() (ok bool, err error) TopicName(t EventType) string DeviceState() chan *device.State PublishMeta(topic string, payload []byte) diff --git a/transport/mqtt/libmqtt.go b/transport/mqtt/libmqtt.go index fe9eb78..c4408f8 100644 --- a/transport/mqtt/libmqtt.go +++ b/transport/mqtt/libmqtt.go @@ -1,6 +1,9 @@ package mqtt -import "github.com/goiiot/libmqtt" +import ( + "fmt" + "github.com/goiiot/libmqtt" +) type mqttClient interface { // Connect to all specified server with client options @@ -18,3 +21,113 @@ type mqttClient interface { // Destroy all client connection Destroy(force bool) } + +// Code is a wrapper for mqtt codes with utils for displaying the error. 
+// The type does not check if the code is an error or not, it always implements the error interface for convenience +type Code byte + +// Error returns the same information as String(), with the prefix "mqtt: " +func (c Code) Error() string { + return "mqtt: " + c.String() +} + +// String returns the same information as Text(), with the integer code appended +func (c Code) String() string { + return fmt.Sprintf("%s (%d)", c.Text(), uint8(c)) +} + +// Text returns a textual representation of the code +func (c Code) Text() string { + switch c { + case libmqtt.CodeSuccess: // 0 - Packet: ConnAck, PubAck, PubRecv, PubRel, PubComp, UnSubAck, Auth + return "success" + case libmqtt.CodeGrantedQos1: // 1 - Packet: SubAck + return "granted QoS 1" + case libmqtt.CodeGrantedQos2: // 2 - Packet: SubAck + return "granted QoS 2" + case libmqtt.CodeDisconnWithWill: // 4 - Packet: DisConn + return "disconnected with will" + case libmqtt.CodeNoMatchingSubscribers: // 16 - Packet: PubAck, PubRecv + return "no matching subscribers" + case libmqtt.CodeNoSubscriptionExisted: // 17 - Packet: UnSubAck + return "no subscription existed" + case libmqtt.CodeContinueAuth: // 24 - Packet: Auth + return "continue auth" + case libmqtt.CodeReAuth: // 25 - Packet: Auth + return "re auth" + case libmqtt.CodeUnspecifiedError: // 128 - Packet: ConnAck, PubAck, PubRecv, SubAck, UnSubAck, DisConn + return "unspecified error" + case libmqtt.CodeMalformedPacket: // 129 - Packet: ConnAck, DisConn + return "malformed packet" + case libmqtt.CodeProtoError: // 130 - Packet: ConnAck, DisConn + return "protocol error" + case libmqtt.CodeImplementationSpecificError: // 131 - Packet: ConnAck, PubAck, PubRecv, SubAck, UnSubAck, DisConn + return "implementation specific error" + case libmqtt.CodeUnsupportedProtoVersion: // 132 - Packet: ConnAck + return "unsupported protocol version" + case libmqtt.CodeClientIdNotValid: // 133 - Packet: ConnAck + return "client id not valid" + case libmqtt.CodeBadUserPass: // 
134 - Packet: ConnAck + return "bad username or password" + case libmqtt.CodeNotAuthorized: // 135 - Packet: ConnAck, PubAck, PubRecv, SubAck, UnSubAck, DisConn + return "not authorized" + case libmqtt.CodeServerUnavail: // 136 - Packet: ConnAck + return "server unavailable" + case libmqtt.CodeServerBusy: // 137 - Packet: ConnAck, DisConn + return "server busy" + case libmqtt.CodeBanned: // 138 - Packet: ConnAck + return "banned" + case libmqtt.CodeServerShuttingDown: // 139 - Packet: DisConn + return "server is shutting down" + case libmqtt.CodeBadAuthenticationMethod: // 140 - Packet: ConnAck, DisConn + return "bad authentication method" + case libmqtt.CodeKeepaliveTimeout: // 141 - Packet: DisConn + return "keepalive timeout" + case libmqtt.CodeSessionTakenOver: // 142 - Packet: DisConn + return "session taken over" + case libmqtt.CodeTopicFilterInvalid: // 143 - Packet: SubAck, UnSubAck, DisConn + return "topic filter invalid" + case libmqtt.CodeTopicNameInvalid: // 144 - Packet: ConnAck, PubAck, PubRecv, DisConn + return "topic name invalid" + case libmqtt.CodePacketIdentifierInUse: // 145 - Packet: PubAck, PubRecv, SubAck, UnSubAck + return "packet identifier in use" + case libmqtt.CodePacketIdentifierNotFound: // 146 - Packet: PubRel, PubComp + return "packet identifier not found" + case libmqtt.CodeReceiveMaxExceeded: // 147 - Packet: DisConn + return "receive max exceeded" + case libmqtt.CodeTopicAliasInvalid: // 148 - Packet: DisConn + return "topic alias invalid" + case libmqtt.CodePacketTooLarge: // 149 - Packet: ConnAck, DisConn + return "packet too large" + case libmqtt.CodeMessageRateTooHigh: // 150 - Packet: DisConn + return "message rate too high" + case libmqtt.CodeQuotaExceeded: // 151 - Packet: ConnAck, PubAck, PubRecv, SubAck, DisConn + return "quota exceeded" + case libmqtt.CodeAdministrativeAction: // 152 - Packet: DisConn + return "administrative action" + case libmqtt.CodePayloadFormatInvalid: // 153 - Packet: ConnAck, PubAck, PubRecv, 
DisConn + return "payload format invalid" + case libmqtt.CodeRetainNotSupported: // 154 - Packet: ConnAck, DisConn + return "retain not supported" + case libmqtt.CodeQosNoSupported: // 155 - Packet: ConnAck, DisConn + return "QoS not supported" + case libmqtt.CodeUseAnotherServer: // 156 - Packet: ConnAck, DisConn + return "use another server" + case libmqtt.CodeServerMoved: // 157 - Packet: ConnAck, DisConn + return "server moved" + case libmqtt.CodeSharedSubscriptionNotSupported: // 158 - Packet: SubAck, DisConn + return "shared subscription not supported" + case libmqtt.CodeConnectionRateExceeded: // 159 - Packet: ConnAck, DisConn + return "connection rate exceeded" + case libmqtt.CodeMaxConnectTime: // 160 - Packet: DisConn + return "max connect time reached" + case libmqtt.CodeSubscriptionIdentifiersNotSupported: // 161 - Packet: SubAck, DisConn + return "subscription identifiers not supported" + case libmqtt.CodeWildcardSubscriptionNotSupported: // 162 - Packet: SubAck, DisConn + return "wildcard subscriptions not supported" + case 255: + return "network error" + default: + return "unknown" + } +} diff --git a/transport/mqtt/message.go b/transport/mqtt/message.go index 6540e0e..be93cf4 100644 --- a/transport/mqtt/message.go +++ b/transport/mqtt/message.go @@ -162,7 +162,9 @@ func (m *mqtt) Publish(topic string, payload []byte, retain bool) { ) } -func (m *mqtt) Unsubscribe(topic string) bool { +// TODO: Make UnsubscribeRaw()? 
+ +func (m *mqtt) Unsubscribe(topic string) (found bool) { m.Lock() defer m.Unlock() if v, ok := m.sub[topic]; ok { @@ -171,14 +173,31 @@ func (m *mqtt) Unsubscribe(topic string) bool { } delete(m.sub, topic) m.client.UnSubscribe(topic) - return true + found = true } - return false + if m.subRaw == nil { + return + } + if v, ok := m.subRaw[topic]; ok { + for _, ch := range v { + close(ch) + } + delete(m.subRaw, topic) + if !found { + m.client.UnSubscribe(topic) + found = true + } + } + return } func (m *mqtt) Resubscribe(oldTopic, newTopic string) bool { m.Lock() defer m.Unlock() + keep := false + if m.subRaw != nil { + _, keep = m.subRaw[oldTopic] + } if v, ok := m.sub[oldTopic]; ok { if _, ok := m.sub[newTopic]; !ok { m.sub[newTopic] = v @@ -189,7 +208,9 @@ func (m *mqtt) Resubscribe(oldTopic, newTopic string) bool { m.sub[newTopic] = append(m.sub[newTopic], v...) } delete(m.sub, oldTopic) - m.client.UnSubscribe(oldTopic) + if !keep { + m.client.UnSubscribe(oldTopic) + } return true } return false diff --git a/transport/mqtt/mqtt.go b/transport/mqtt/mqtt.go index b7c2a63..a1f5630 100644 --- a/transport/mqtt/mqtt.go +++ b/transport/mqtt/mqtt.go @@ -2,9 +2,6 @@ package mqtt import ( "context" - "errors" - "fmt" - "log" "sync" "time" @@ -18,7 +15,8 @@ type mqtt struct { deviceState chan *device.State client mqttClient addr string - initCh chan error + errCh chan error + stopCh []chan struct{} sub map[string][]chan []byte subRaw map[string][]chan *Packet willMap map[string][]string @@ -30,137 +28,207 @@ type mqtt struct { announceTopic string discoverTopic string leaveTopic string + ctx context.Context sync.RWMutex } -func New(ctx context.Context, c *Config) (m MQTT, err error) { +func New(ctx context.Context, c *Config) (MQTT, error) { if c == nil { return nil, ErrNoConfig } - if err = c.check(); err != nil { - return + if err := c.check(); err != nil { + return nil, err } - mq := &mqtt{ + m := &mqtt{ discoverDelay: c.DiscoverDelay, willID: c.ClientID, announceTopic: 
c.AnnounceTopic, discoverTopic: c.DiscoverTopic, leaveTopic: c.LeaveTopic, willMap: map[string][]string{}, + ctx: ctx, } + if m.announceTopic == "" { + m.announceTopic = "announce" + } + if m.discoverTopic == "" { + m.discoverTopic = "discover" + } + if m.leaveTopic == "" { + m.leaveTopic = "leave" + } + if m.discoverDelay == 0 { + m.discoverDelay = 5 * time.Second + } + opts := []libmqtt.Option{ - libmqtt.WithRouter(newRouter(mq)), + libmqtt.WithRouter(newRouter(m)), } opts = append(opts, c.opts()...) client, err := libmqtt.NewClient(opts...) if err != nil { - m = nil - return + return nil, err + } + m.client = client + + if ctx != nil && ctx.Done() != nil { + go func() { + <-ctx.Done() + m.destroy() + }() } - if err = mq.init(ctx, client); err != nil { - m = nil + return m, nil +} + +func (m *mqtt) isCancelled() bool { + if m.ctx != nil { + select { + case <-m.ctx.Done(): + return true + default: + } } - m = mq - return + return false } -func (m *mqtt) init(ctx context.Context, client mqttClient) (err error) { +type Err string + +func (e Err) Error() string { return string(e) } + +const ( + ErrIsAlreadyRunning Err = "already running" + ErrIsCancelled Err = "cancelled" +) + +func (m *mqtt) Start() (bool, error) { m.Lock() - if m.client != nil { + if m.errCh != nil { m.Unlock() - return errors.New("already initialized") + return false, ErrIsAlreadyRunning } - if m.announceTopic == "" { - m.announceTopic = "announce" + if m.isCancelled() { + m.Unlock() + return false, ErrIsCancelled } - if m.discoverTopic == "" { - m.discoverTopic = "discover" + + m.errCh = make(chan error) + if m.sub == nil { + m.sub = map[string][]chan []byte{} } - if m.leaveTopic == "" { - m.leaveTopic = "leave" + if m.subRaw == nil { + m.subRaw = map[string][]chan *Packet{} } - if m.discoverDelay == 0 { - m.discoverDelay = 5 * time.Second + + defer func() { + m.Lock() + stopCh := m.stopCh + m.stopCh = nil + m.errCh = nil + m.Unlock() + for _, ch := range stopCh { + close(ch) + } + }() + stopCh 
:= make(chan struct{}) + m.stopCh = append(m.stopCh, stopCh) + if m.ctx != nil { + go func(ctx context.Context) { + select { + case <-ctx.Done(): + m.destroy() + case <-stopCh: + return + } + }(m.ctx) + } + m.Unlock() + m.client.Connect(m.onConnect) + err := <-m.errCh + return !m.isCancelled(), err +} + +func (m *mqtt) destroy() { + stopCh := make(chan struct{}) + m.Lock() + dsub := m.discoverSub + subs := m.sub + subRaw := m.subRaw + stateCh := m.deviceState + cl := m.client + if m.errCh != nil { + m.stopCh = append(m.stopCh, stopCh) + } else { + close(stopCh) + } + m.Unlock() + + if cl != nil { + cl.Destroy(true) } - m.initCh = make(chan error) + m.Lock() + if m.errCh != nil { + close(m.errCh) + m.errCh = nil + } + m.discoverSub = []chan struct{}{} m.sub = map[string][]chan []byte{} m.subRaw = map[string][]chan *Packet{} - m.client = client + m.deviceState = nil + m.client = nil m.Unlock() - m.client.Connect(m.onConnect) - err, _ = <-m.initCh - m.initCh = nil - if err != nil { - m.client.Destroy(true) - return + if stateCh != nil { + close(stateCh) } - if ctx != nil { - go func() { - <-ctx.Done() - m.client.Destroy(false) - m.Lock() - dsub := m.discoverSub - m.discoverSub = []chan struct{}{} - subs := m.sub - m.sub = map[string][]chan []byte{} - stateCh := m.deviceState - m.deviceState = nil - m.Unlock() - - if stateCh != nil { - close(stateCh) - } - - for _, ch := range dsub { - close(ch) - } + for _, ch := range dsub { + close(ch) + } - if m.subRaw != nil { - for _, v := range m.subRaw { - for _, vv := range v { - close(vv) - } - } - m.subRaw = map[string][]chan *Packet{} + if subRaw != nil { + for _, v := range subRaw { + for _, vv := range v { + func() { + defer func() { + _ = recover() + }() + close(vv) + }() } + } + } - if subs != nil { - for _, chans := range subs { - for _, ch := range chans { - close(ch) - } - } + if subs != nil { + for _, chans := range subs { + for _, ch := range chans { + func() { + defer func() { + _ = recover() + }() + close(ch) + 
}() } - }() + } } - - return } -func (m *mqtt) onConnect(server string, code byte, err error) { +func (m *mqtt) onConnect(server string, _code byte, err error) { m.Lock() defer m.Unlock() + code := Code(_code) if code != libmqtt.CodeSuccess && err == nil { - err = fmt.Errorf("error code %d", int(code)) - } - - if m.initCh != nil { - if err != nil { - m.initCh <- err - } else { - close(m.initCh) - } + err = code } if err != nil { - log.Printf("MQTT Connect Error: %s (%x) %v", server, code, err) - return + if m.errCh != nil { + m.errCh <- err + return + } } if m.deviceState != nil { diff --git a/transport/mqtt/raw.go b/transport/mqtt/raw.go index e5ded42..8b2a599 100644 --- a/transport/mqtt/raw.go +++ b/transport/mqtt/raw.go @@ -8,6 +8,10 @@ import ( func (m *mqtt) SubscribeRaw(topic string) chan *Packet { m.Lock() defer m.Unlock() + if m.subRaw == nil || m.client == nil { + // Probably not started + return nil + } c := make(chan *Packet, 5) if _, ok := m.subRaw[topic]; ok { @@ -33,7 +37,12 @@ func (m *mqtt) OnRaw(p *libmqtt.PublishPacket) { } m.RUnlock() for _, ch := range chans { - ch <- (*Packet)(p) + func() { + defer func() { + _ = recover() + }() + ch <- (*Packet)(p) + }() } }