mqtt_protocol.cc 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. #include "mqtt_protocol.h"
  2. #include "board.h"
  3. #include "application.h"
  4. #include "settings.h"
  5. #include <esp_log.h>
  6. #include <ml307_mqtt.h>
  7. #include <ml307_udp.h>
  8. #include <cstring>
  9. #include <arpa/inet.h>
  10. #include "assets/lang_config.h"
  11. #define TAG "MQTT"
  12. MqttProtocol::MqttProtocol() {
  13. event_group_handle_ = xEventGroupCreate();
  14. }
  15. MqttProtocol::~MqttProtocol() {
  16. ESP_LOGI(TAG, "MqttProtocol deinit");
  17. if (udp_ != nullptr) {
  18. delete udp_;
  19. }
  20. if (mqtt_ != nullptr) {
  21. delete mqtt_;
  22. }
  23. vEventGroupDelete(event_group_handle_);
  24. }
  25. void MqttProtocol::Start() {
  26. StartMqttClient(false);
  27. }
  28. bool MqttProtocol::StartMqttClient(bool report_error) {
  29. if (mqtt_ != nullptr) {
  30. ESP_LOGW(TAG, "Mqtt client already started");
  31. delete mqtt_;
  32. }
  33. Settings settings("mqtt", false);
  34. endpoint_ = settings.GetString("endpoint");
  35. client_id_ = settings.GetString("client_id");
  36. username_ = settings.GetString("username");
  37. password_ = settings.GetString("password");
  38. publish_topic_ = settings.GetString("publish_topic");
  39. if (endpoint_.empty()) {
  40. ESP_LOGW(TAG, "MQTT endpoint is not specified");
  41. if (report_error) {
  42. SetError(Lang::Strings::SERVER_NOT_FOUND);
  43. }
  44. return false;
  45. }
  46. mqtt_ = Board::GetInstance().CreateMqtt();
  47. mqtt_->SetKeepAlive(90);
  48. mqtt_->OnDisconnected([this]() {
  49. ESP_LOGI(TAG, "Disconnected from endpoint");
  50. });
  51. mqtt_->OnMessage([this](const std::string& topic, const std::string& payload) {
  52. cJSON* root = cJSON_Parse(payload.c_str());
  53. if (root == nullptr) {
  54. ESP_LOGE(TAG, "Failed to parse json message %s", payload.c_str());
  55. return;
  56. }
  57. cJSON* type = cJSON_GetObjectItem(root, "type");
  58. if (type == nullptr) {
  59. ESP_LOGE(TAG, "Message type is not specified");
  60. cJSON_Delete(root);
  61. return;
  62. }
  63. if (strcmp(type->valuestring, "hello") == 0) {
  64. ParseServerHello(root);
  65. } else if (strcmp(type->valuestring, "goodbye") == 0) {
  66. auto session_id = cJSON_GetObjectItem(root, "session_id");
  67. ESP_LOGI(TAG, "Received goodbye message, session_id: %s", session_id ? session_id->valuestring : "null");
  68. if (session_id == nullptr || session_id_ == session_id->valuestring) {
  69. Application::GetInstance().Schedule([this]() {
  70. CloseAudioChannel();
  71. });
  72. }
  73. } else if (on_incoming_json_ != nullptr) {
  74. on_incoming_json_(root);
  75. }
  76. cJSON_Delete(root);
  77. last_incoming_time_ = std::chrono::steady_clock::now();
  78. });
  79. // ESP_LOGI(TAG, "Connecting to endpoint %s", endpoint_.c_str());
  80. if (!mqtt_->Connect(endpoint_, 8883, client_id_, username_, password_)) {
  81. ESP_LOGE(TAG, "Failed to connect to endpoint");
  82. SetError(Lang::Strings::SERVER_NOT_CONNECTED);
  83. return false;
  84. }
  85. ESP_LOGI(TAG, "Connected to endpoint");
  86. return true;
  87. }
  88. void MqttProtocol::SendText(const std::string& text) {
  89. if (publish_topic_.empty()) {
  90. return;
  91. }
  92. if (!mqtt_->Publish(publish_topic_, text)) {
  93. ESP_LOGE(TAG, "Failed to publish message: %s", text.c_str());
  94. SetError(Lang::Strings::SERVER_ERROR);
  95. }
  96. }
  97. void MqttProtocol::SendAudio(const std::vector<uint8_t>& data) {
  98. std::lock_guard<std::mutex> lock(channel_mutex_);
  99. if (udp_ == nullptr) {
  100. return;
  101. }
  102. std::string nonce(aes_nonce_);
  103. *(uint16_t*)&nonce[2] = htons(data.size());
  104. *(uint32_t*)&nonce[12] = htonl(++local_sequence_);
  105. std::string encrypted;
  106. encrypted.resize(aes_nonce_.size() + data.size());
  107. memcpy(encrypted.data(), nonce.data(), nonce.size());
  108. size_t nc_off = 0;
  109. uint8_t stream_block[16] = {0};
  110. if (mbedtls_aes_crypt_ctr(&aes_ctx_, data.size(), &nc_off, (uint8_t*)nonce.c_str(), stream_block,
  111. (uint8_t*)data.data(), (uint8_t*)&encrypted[nonce.size()]) != 0) {
  112. ESP_LOGE(TAG, "Failed to encrypt audio data");
  113. return;
  114. }
  115. udp_->Send(encrypted);
  116. }
  117. void MqttProtocol::CloseAudioChannel() {
  118. {
  119. std::lock_guard<std::mutex> lock(channel_mutex_);
  120. if (udp_ != nullptr) {
  121. delete udp_;
  122. udp_ = nullptr;
  123. }
  124. }
  125. std::string message = "{";
  126. message += "\"session_id\":\"" + session_id_ + "\",";
  127. message += "\"type\":\"goodbye\"";
  128. message += "}";
  129. SendText(message);
  130. if (on_audio_channel_closed_ != nullptr) {
  131. on_audio_channel_closed_();
  132. }
  133. }
  134. bool MqttProtocol::OpenAudioChannel() {
  135. if (mqtt_ == nullptr || !mqtt_->IsConnected()) {
  136. ESP_LOGI(TAG, "MQTT is not connected, try to connect now");
  137. if (!StartMqttClient(true)) {
  138. return false;
  139. }
  140. }
  141. error_occurred_ = false;
  142. session_id_ = "";
  143. xEventGroupClearBits(event_group_handle_, MQTT_PROTOCOL_SERVER_HELLO_EVENT);
  144. // 发送 hello 消息申请 UDP 通道
  145. std::string message = "{";
  146. message += "\"type\":\"hello\",";
  147. message += "\"version\": 3,";
  148. message += "\"transport\":\"udp\",";
  149. message += "\"audio_params\":{";
  150. message += "\"format\":\"opus\", \"sample_rate\":16000, \"channels\":1, \"frame_duration\":" + std::to_string(OPUS_FRAME_DURATION_MS);
  151. message += "}}";
  152. SendText(message);
  153. // 等待服务器响应
  154. EventBits_t bits = xEventGroupWaitBits(event_group_handle_, MQTT_PROTOCOL_SERVER_HELLO_EVENT, pdTRUE, pdFALSE, pdMS_TO_TICKS(10000));
  155. if (!(bits & MQTT_PROTOCOL_SERVER_HELLO_EVENT)) {
  156. ESP_LOGE(TAG, "Failed to receive server hello");
  157. SetError(Lang::Strings::SERVER_TIMEOUT);
  158. return false;
  159. }
  160. std::lock_guard<std::mutex> lock(channel_mutex_);
  161. if (udp_ != nullptr) {
  162. delete udp_;
  163. }
  164. udp_ = Board::GetInstance().CreateUdp();
  165. udp_->OnMessage([this](const std::string& data) {
  166. if (data.size() < sizeof(aes_nonce_)) {
  167. ESP_LOGE(TAG, "Invalid audio packet size: %zu", data.size());
  168. return;
  169. }
  170. if (data[0] != 0x01) {
  171. ESP_LOGE(TAG, "Invalid audio packet type: %x", data[0]);
  172. return;
  173. }
  174. uint32_t sequence = ntohl(*(uint32_t*)&data[12]);
  175. if (sequence < remote_sequence_) {
  176. ESP_LOGW(TAG, "Received audio packet with old sequence: %lu, expected: %lu", sequence, remote_sequence_);
  177. return;
  178. }
  179. if (sequence != remote_sequence_ + 1) {
  180. ESP_LOGW(TAG, "Received audio packet with wrong sequence: %lu, expected: %lu", sequence, remote_sequence_ + 1);
  181. }
  182. std::vector<uint8_t> decrypted;
  183. size_t decrypted_size = data.size() - aes_nonce_.size();
  184. size_t nc_off = 0;
  185. uint8_t stream_block[16] = {0};
  186. decrypted.resize(decrypted_size);
  187. auto nonce = (uint8_t*)data.data();
  188. auto encrypted = (uint8_t*)data.data() + aes_nonce_.size();
  189. int ret = mbedtls_aes_crypt_ctr(&aes_ctx_, decrypted_size, &nc_off, nonce, stream_block, encrypted, (uint8_t*)decrypted.data());
  190. if (ret != 0) {
  191. ESP_LOGE(TAG, "Failed to decrypt audio data, ret: %d", ret);
  192. return;
  193. }
  194. if (on_incoming_audio_ != nullptr) {
  195. on_incoming_audio_(std::move(decrypted));
  196. }
  197. remote_sequence_ = sequence;
  198. last_incoming_time_ = std::chrono::steady_clock::now();
  199. });
  200. udp_->Connect(udp_server_, udp_port_);
  201. if (on_audio_channel_opened_ != nullptr) {
  202. on_audio_channel_opened_();
  203. }
  204. return true;
  205. }
  206. void MqttProtocol::ParseServerHello(const cJSON* root) {
  207. auto transport = cJSON_GetObjectItem(root, "transport");
  208. if (transport == nullptr || strcmp(transport->valuestring, "udp") != 0) {
  209. ESP_LOGE(TAG, "Unsupported transport: %s", transport->valuestring);
  210. return;
  211. }
  212. auto session_id = cJSON_GetObjectItem(root, "session_id");
  213. if (session_id != nullptr) {
  214. session_id_ = session_id->valuestring;
  215. ESP_LOGI(TAG, "Session ID: %s", session_id_.c_str());
  216. }
  217. // Get sample rate from hello message
  218. auto audio_params = cJSON_GetObjectItem(root, "audio_params");
  219. if (audio_params != NULL) {
  220. auto sample_rate = cJSON_GetObjectItem(audio_params, "sample_rate");
  221. if (sample_rate != NULL) {
  222. server_sample_rate_ = sample_rate->valueint;
  223. }
  224. }
  225. auto udp = cJSON_GetObjectItem(root, "udp");
  226. if (udp == nullptr) {
  227. ESP_LOGE(TAG, "UDP is not specified");
  228. return;
  229. }
  230. udp_server_ = cJSON_GetObjectItem(udp, "server")->valuestring;
  231. udp_port_ = cJSON_GetObjectItem(udp, "port")->valueint;
  232. auto key = cJSON_GetObjectItem(udp, "key")->valuestring;
  233. auto nonce = cJSON_GetObjectItem(udp, "nonce")->valuestring;
  234. // auto encryption = cJSON_GetObjectItem(udp, "encryption")->valuestring;
  235. // ESP_LOGI(TAG, "UDP server: %s, port: %d, encryption: %s", udp_server_.c_str(), udp_port_, encryption);
  236. aes_nonce_ = DecodeHexString(nonce);
  237. mbedtls_aes_init(&aes_ctx_);
  238. mbedtls_aes_setkey_enc(&aes_ctx_, (const unsigned char*)DecodeHexString(key).c_str(), 128);
  239. local_sequence_ = 0;
  240. remote_sequence_ = 0;
  241. xEventGroupSetBits(event_group_handle_, MQTT_PROTOCOL_SERVER_HELLO_EVENT);
  242. }
  // Upper-case hexadecimal digit alphabet, indexed by nibble value (0..15).
  static constexpr char hex_chars[] = "0123456789ABCDEF";
  // Helper: converts one hexadecimal digit ('0'-'9', 'A'-'F', 'a'-'f') to its
  // 4-bit value. Any other character yields 0 (invalid input is not reported).
  static inline uint8_t CharToHex(char c) {
      if ('0' <= c && c <= '9') {
          return static_cast<uint8_t>(c - '0');
      }
      // Fold lower-case hex letters onto upper case so one range check covers both.
      const char folded = ('a' <= c && c <= 'f') ? static_cast<char>(c - ('a' - 'A')) : c;
      if ('A' <= folded && folded <= 'F') {
          return static_cast<uint8_t>(folded - 'A' + 10);
      }
      return 0;
  }
  251. std::string MqttProtocol::DecodeHexString(const std::string& hex_string) {
  252. std::string decoded;
  253. decoded.reserve(hex_string.size() / 2);
  254. for (size_t i = 0; i < hex_string.size(); i += 2) {
  255. char byte = (CharToHex(hex_string[i]) << 4) | CharToHex(hex_string[i + 1]);
  256. decoded.push_back(byte);
  257. }
  258. return decoded;
  259. }
  260. bool MqttProtocol::IsAudioChannelOpened() const {
  261. return udp_ != nullptr && !error_occurred_ && !IsTimeout();
  262. }