Files
MSE-PI-E2EEDA-Plein-de-eeee…/db/src/mqtt/mqtt.go

249 lines
7.4 KiB
Go

// Package mqtt provides an abstraction over an MQTT broker client.
package mqtt
import (
"crypto/tls"
"encoding/json"
"errors"
"fmt"
dp "gateway/point"
"log"
"strings"
"time"
mqtt "github.com/eclipse/paho.mqtt.golang"
)
const (
	// maxQoS is the highest quality-of-service level NewMqttGateway accepts
	// (MQTT defines levels 0, 1 and 2).
	maxQoS = 2

	// defaultTimeout replaces a zero MqttParams.Timeout in NewMqttGateway.
	defaultTimeout = 5000 * time.Millisecond
)
// MqttParams holds the settings NewMqttGateway uses to configure the MQTT
// client and connect it to the broker. The resulting MqttGateway keeps a copy
// of them and reuses Qos and Timeout for publishing, subscribing and
// unsubscribing.
type MqttParams struct {
	// Broker is the broker address passed to the client options' AddBroker.
	// Required (must be non-empty).
	// NOTE(review): presumably a URL such as "tcp://host:1883" — confirm.
	Broker string
	// ClientId is the MQTT client identifier. Required (must be non-empty).
	ClientId string
	// Qos is the quality-of-service level used for Publish and Subscribe.
	// NewMqttGateway rejects values greater than 2.
	Qos byte
	// Username and Password are only set on the client when Username is
	// non-empty; Password alone is ignored.
	Username string
	Password string
	// TlsConfig, when non-nil, is passed to the client options' SetTLSConfig.
	TlsConfig *tls.Config
	// OnConnect is installed as the on-connect handler; when nil, a handler
	// that only logs the connection is used instead.
	OnConnect mqtt.OnConnectHandler
	// OnConnectionLost is installed as the connection-lost handler; when nil,
	// a handler that only logs the error is used instead.
	OnConnectionLost mqtt.ConnectionLostHandler
	// Timeout bounds each token wait (connect, publish, subscribe,
	// unsubscribe). NewMqttGateway replaces a zero value with 5 seconds.
	Timeout time.Duration
}
// MqttGateway wraps an MQTT client together with the parameters it was built
// from. Use NewMqttGateway to obtain one with an already-connected client; its
// methods publish data (SendData), manage subscriptions (Subscribe,
// Unsubscribe) and close the connection (Disconnect).
type MqttGateway struct {
	// MqttParams are the parameters passed to NewMqttGateway, with Timeout
	// defaulted when it was zero.
	MqttParams MqttParams
	// Client is the underlying paho client that all methods delegate to.
	Client mqtt.Client
}
// mqttPayload is the JSON object SendData publishes: the data point's payload
// fields plus a "timestamp" key holding the data point's time in Unix seconds.
// It is marshaled as a single JSON object and published on the topic derived
// from the data point's tags.
type mqttPayload map[string]any
var (
	// connectHandler is the default on-connect handler NewMqttGateway installs
	// when MqttParams.OnConnect is nil. It only logs the event.
	connectHandler mqtt.OnConnectHandler = func(_ mqtt.Client) {
		log.Println("[MQTT Gateway] Connected to MQTT Broker")
	}

	// connectLostHandler is the default connection-lost handler NewMqttGateway
	// installs when MqttParams.OnConnectionLost is nil. It only logs the error.
	connectLostHandler mqtt.ConnectionLostHandler = func(_ mqtt.Client, err error) {
		log.Printf("[MQTT Gateway] Connection lost: %v\n", err)
	}
)
// getTopic builds an MQTT topic by joining, in order, the Content of each tag
// with "/". Tag subjects are not part of the topic. The result is "" when tags
// is empty (or holds a single tag whose Content is empty).
func getTopic(tags []dp.Topic) string {
	levels := make([]string, 0, len(tags))
	for _, tag := range tags {
		levels = append(levels, tag.Content)
	}
	return strings.Join(levels, "/")
}
// NewMqttGateway validates p, creates an MQTT client from it and connects to
// the broker, returning a gateway that wraps the connected client.
//
// Validation and defaults:
//   - Broker and ClientId must be non-empty, and Qos must not exceed maxQoS.
//   - A zero or negative Timeout is replaced by defaultTimeout. The resulting
//     value is stored in the returned gateway's MqttParams and bounds both the
//     connection attempt and later publish/subscribe/unsubscribe waits.
//   - TlsConfig is applied only when non-nil.
//   - connectHandler and connectLostHandler are used when OnConnect or
//     OnConnectionLost is nil.
//   - Username and Password are only set when Username is non-empty.
//
// If the connection does not complete within Timeout, the pending attempt is
// abandoned (Disconnect) and an error is returned. A connection failure
// reported by the client is wrapped and returned.
func NewMqttGateway(p MqttParams) (*MqttGateway, error) {
	if p.Broker == "" {
		return nil, errors.New("[MQTT Gateway] Invalid broker address")
	}
	if p.ClientId == "" {
		return nil, errors.New("[MQTT Gateway] Invalid client id")
	}
	if p.Qos > maxQoS {
		return nil, errors.New("[MQTT Gateway] Invalid QoS level")
	}
	if p.Timeout <= 0 {
		// A non-positive wait would make every WaitTimeout below give up
		// (almost) immediately, so treat it as "not set".
		p.Timeout = defaultTimeout
	}

	onConnect := p.OnConnect
	if onConnect == nil {
		onConnect = connectHandler
	}
	onConnectionLost := p.OnConnectionLost
	if onConnectionLost == nil {
		onConnectionLost = connectLostHandler
	}

	opts := mqtt.NewClientOptions()
	opts.AddBroker(p.Broker)
	opts.SetClientID(p.ClientId)
	opts.SetOnConnectHandler(onConnect)
	opts.SetConnectionLostHandler(onConnectionLost)
	if p.TlsConfig != nil {
		opts.SetTLSConfig(p.TlsConfig)
	}
	if p.Username != "" {
		opts.SetUsername(p.Username)
		opts.SetPassword(p.Password)
	}

	client := mqtt.NewClient(opts)
	token := client.Connect()
	if !token.WaitTimeout(p.Timeout) {
		// The connection attempt is still in progress. Abandon it so that a
		// late success cannot leave behind a connected client that nothing
		// references or ever closes.
		client.Disconnect(0)
		return nil, errors.New("[MQTT Gateway] Mqtt connect timed out")
	}
	if err := token.Error(); err != nil {
		return nil, fmt.Errorf("[MQTT Gateway] Mqtt connect failed: %w", err)
	}
	return &MqttGateway{MqttParams: p, Client: client}, nil
}
// SendData publishes msg to the broker as a JSON object.
//
// The topic is the Content of msg.Tags() joined with "/"; an empty topic is
// rejected. The JSON object contains a "timestamp" key set to msg.Timestamp()
// in Unix seconds, plus every entry of msg.PayloadAsAny(). The payload entries
// are copied after the timestamp, so a payload field literally named
// "timestamp" overrides it.
//
// The message is published with the gateway's Qos, without the retained flag.
// SendData waits at most MqttParams.Timeout for the publish token to complete
// and returns an error on timeout or if the client reports a failure.
func (g *MqttGateway) SendData(msg dp.DataPointInfo) error {
	topic := getTopic(msg.Tags())
	if topic == "" {
		return errors.New("[MQTT Gateway] Invalid topic")
	}

	payload := mqttPayload{
		"timestamp": msg.Timestamp().Unix(),
	}
	for key, value := range msg.PayloadAsAny() {
		payload[key] = value
	}
	payloadJSON, err := json.Marshal(payload)
	if err != nil {
		return fmt.Errorf("[MQTT Gateway] Failed to marshal payload: %w", err)
	}

	token := g.Client.Publish(topic, g.MqttParams.Qos, false, payloadJSON)
	if !token.WaitTimeout(g.MqttParams.Timeout) {
		return fmt.Errorf("[MQTT Gateway] Mqtt publish to topic %q timed out", topic)
	}
	if err := token.Error(); err != nil {
		return fmt.Errorf("[MQTT Gateway] Failed to publish message: %w", err)
	}
	return nil
}
// Disconnect closes the client's connection to the broker, passing a quiesce
// period of 0 so no time is granted to pending work, and then logs that the
// gateway is disconnected.
func (g *MqttGateway) Disconnect() {
	const quiesce uint = 0 // milliseconds allowed for pending work before closing
	client := g.Client
	client.Disconnect(quiesce)
	log.Println("[MQTT Gateway] Disconnected from MQTT Broker")
}
// Subscribe subscribes to topic using the gateway's Qos and registers
// callback as the handler for messages received on it.
//
// It waits at most MqttParams.Timeout for the subscribe token to complete. It
// returns an error, naming the topic, on timeout or if the client reports a
// failure. On success the topic is logged.
func (g *MqttGateway) Subscribe(topic string, callback mqtt.MessageHandler) error {
	token := g.Client.Subscribe(topic, g.MqttParams.Qos, callback)
	if !token.WaitTimeout(g.MqttParams.Timeout) {
		return fmt.Errorf("[MQTT Gateway] MQTT gateway timed out subscribing to topic %q", topic)
	}
	if err := token.Error(); err != nil {
		return fmt.Errorf("[MQTT Gateway] MQTT gateway failed to subscribe to topic %q: %w", topic, err)
	}
	log.Printf("[MQTT Gateway] Subscribed to topic: %s\n", topic)
	return nil
}
// Unsubscribe removes the subscription to topic.
//
// It waits at most MqttParams.Timeout for the unsubscribe token to complete.
// It returns an error, naming the topic, on timeout or if the client reports a
// failure. On success the topic is logged.
func (g *MqttGateway) Unsubscribe(topic string) error {
	token := g.Client.Unsubscribe(topic)
	if !token.WaitTimeout(g.MqttParams.Timeout) {
		return fmt.Errorf("[MQTT Gateway] MQTT gateway timed out unsubscribing from topic %q", topic)
	}
	if err := token.Error(); err != nil {
		return fmt.Errorf("[MQTT Gateway] MQTT gateway failed to unsubscribe from topic %q: %w", topic, err)
	}
	log.Printf("[MQTT Gateway] Unsubscribed from topic: %s\n", topic)
	return nil
}
// SubscribeTyped subscribes to topic and turns every JSON message received
// into a dp.DataPoint[T] (built with m.CreateDataPoint), which is passed to
// handler.
//
// Payload decoding:
//   - The payload is unmarshaled into T for the fields. T should be a struct or
//     a map matching the JSON payload (excluding timestamp).
//   - The payload is unmarshaled a second time into a generic map to read the
//     optional "timestamp" key. It may be an RFC 3339 string, or a number of
//     Unix seconds (fractional seconds are kept). When it is missing, of
//     another type, or unparseable, the time of reception is used.
//   - If either unmarshal fails, the error is logged, the message is dropped
//     and handler is not called.
//
// Tags:
//   - tagSubjects names the topic levels in order. For example, with the topic
//     "provence/B3/update" and tagSubjects ["city", "room"], the tags
//     {Subject: "city", Content: "provence"} and
//     {Subject: "room", Content: "B3"} are created.
//   - Empty subjects skip their level.
//   - When tagSubjects is empty, the second topic level (if any) becomes a
//     single "id" tag, for backward compatibility.
//
// An error is returned if handler is nil or if the subscription fails.
func SubscribeTyped[T any](g *MqttGateway, topic string, m dp.Measurement[T], tagSubjects []string, handler func(dp.DataPoint[T])) error {
	if handler == nil {
		// Without this check the nil handler would only blow up later, when
		// the first message arrives inside the client's callback.
		return errors.New("[MQTT Gateway] Invalid handler")
	}
	return g.Subscribe(topic, func(_ mqtt.Client, msg mqtt.Message) {
		var fields T
		if err := json.Unmarshal(msg.Payload(), &fields); err != nil {
			log.Printf("[MQTT Gateway] Error unmarshaling fields: %v", err)
			return
		}
		// Second pass over the same bytes, only to reach the timestamp key
		// independently of T's shape.
		var raw map[string]any
		if err := json.Unmarshal(msg.Payload(), &raw); err != nil {
			log.Printf("[MQTT Gateway] Error unmarshaling raw: %v", err)
			return
		}
		ts := payloadTimestamp(raw["timestamp"])
		tags := topicTags(msg.Topic(), tagSubjects)
		handler(m.CreateDataPoint(tags, fields, ts))
	})
}

// payloadTimestamp converts a decoded JSON "timestamp" value to a time.Time:
// RFC 3339 strings are parsed, JSON numbers are Unix seconds (fraction kept),
// and anything else (or an unparseable string) yields time.Now().
func payloadTimestamp(v any) time.Time {
	switch tsRaw := v.(type) {
	case string:
		parsed, err := time.Parse(time.RFC3339, tsRaw)
		if err != nil {
			log.Printf("[MQTT Gateway] Failed to parse timestamp '%s' as RFC3339: %v", tsRaw, err)
			return time.Now()
		}
		return parsed
	case float64:
		// encoding/json decodes every JSON number into float64 here.
		sec := int64(tsRaw)
		nsec := int64((tsRaw - float64(sec)) * float64(time.Second))
		return time.Unix(sec, nsec)
	default:
		return time.Now()
	}
}

// topicTags maps the "/"-separated levels of topic onto the given subjects,
// skipping empty subjects and subjects beyond the topic depth. With no
// subjects it falls back to a single "id" tag built from the second level.
func topicTags(topic string, subjects []string) []dp.Topic {
	levels := strings.Split(topic, "/")
	var tags []dp.Topic
	for i, subject := range subjects {
		if i < len(levels) && subject != "" {
			tags = append(tags, dp.Topic{Subject: subject, Content: levels[i]})
		}
	}
	if len(subjects) == 0 && len(levels) > 1 {
		tags = append(tags, dp.Topic{Subject: "id", Content: levels[1]})
	}
	return tags
}