ota.cc 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  #include "ota.h"
  #include "system_info.h"
  #include "board.h"
  #include "settings.h"

  #include <cJSON.h>
  #include <esp_log.h>
  #include <esp_partition.h>
  #include <esp_ota_ops.h>
  #include <esp_app_format.h>

  #include <algorithm>
  #include <climits>
  #include <cstdlib>
  #include <cstring>
  #include <memory>
  #include <sstream>
  #include <utility>
  #include <vector>
  14. #define TAG "Ota"
  15. Ota::Ota() {
  16. }
  17. Ota::~Ota() {
  18. }
  19. void Ota::SetCheckVersionUrl(std::string check_version_url) {
  20. check_version_url_ = check_version_url;
  21. }
  22. void Ota::SetHeader(const std::string& key, const std::string& value) {
  23. headers_[key] = value;
  24. }
  25. void Ota::SetPostData(const std::string& post_data) {
  26. post_data_ = post_data;
  27. }
  28. bool Ota::CheckVersion() {
  29. current_version_ = esp_app_get_description()->version;
  30. ESP_LOGI(TAG, "Current version: %s", current_version_.c_str());
  31. if (check_version_url_.length() < 10) {
  32. ESP_LOGE(TAG, "Check version URL is not properly set");
  33. return false;
  34. }
  35. auto http = Board::GetInstance().CreateHttp();
  36. for (const auto& header : headers_) {
  37. http->SetHeader(header.first, header.second);
  38. }
  39. http->SetHeader("Content-Type", "application/json");
  40. std::string method = post_data_.length() > 0 ? "POST" : "GET";
  41. if (!http->Open(method, check_version_url_, post_data_)) {
  42. ESP_LOGE(TAG, "Failed to open HTTP connection");
  43. delete http;
  44. return false;
  45. }
  46. auto response = http->GetBody();
  47. http->Close();
  48. delete http;
  49. // Response: { "firmware": { "version": "1.0.0", "url": "http://" } }
  50. // Parse the JSON response and check if the version is newer
  51. // If it is, set has_new_version_ to true and store the new version and URL
  52. cJSON *root = cJSON_Parse(response.c_str());
  53. if (root == NULL) {
  54. ESP_LOGE(TAG, "Failed to parse JSON response");
  55. return false;
  56. }
  57. has_activation_code_ = false;
  58. cJSON *activation = cJSON_GetObjectItem(root, "activation");
  59. if (activation != NULL) {
  60. cJSON* message = cJSON_GetObjectItem(activation, "message");
  61. if (message != NULL) {
  62. activation_message_ = message->valuestring;
  63. }
  64. cJSON* code = cJSON_GetObjectItem(activation, "code");
  65. if (code != NULL) {
  66. activation_code_ = code->valuestring;
  67. }
  68. has_activation_code_ = true;
  69. }
  70. has_mqtt_config_ = false;
  71. cJSON *mqtt = cJSON_GetObjectItem(root, "mqtt");
  72. if (mqtt != NULL) {
  73. Settings settings("mqtt", true);
  74. cJSON *item = NULL;
  75. cJSON_ArrayForEach(item, mqtt) {
  76. if (item->type == cJSON_String) {
  77. if (settings.GetString(item->string) != item->valuestring) {
  78. settings.SetString(item->string, item->valuestring);
  79. }
  80. }
  81. }
  82. has_mqtt_config_ = true;
  83. }
  84. has_server_time_ = false;
  85. cJSON *server_time = cJSON_GetObjectItem(root, "server_time");
  86. if (server_time != NULL) {
  87. cJSON *timestamp = cJSON_GetObjectItem(server_time, "timestamp");
  88. cJSON *timezone_offset = cJSON_GetObjectItem(server_time, "timezone_offset");
  89. if (timestamp != NULL) {
  90. // 设置系统时间
  91. struct timeval tv;
  92. double ts = timestamp->valuedouble;
  93. // 如果有时区偏移,计算本地时间
  94. if (timezone_offset != NULL) {
  95. ts += (timezone_offset->valueint * 60 * 1000); // 转换分钟为毫秒
  96. }
  97. tv.tv_sec = (time_t)(ts / 1000); // 转换毫秒为秒
  98. tv.tv_usec = (suseconds_t)((long long)ts % 1000) * 1000; // 剩余的毫秒转换为微秒
  99. settimeofday(&tv, NULL);
  100. has_server_time_ = true;
  101. }
  102. }
  103. cJSON *firmware = cJSON_GetObjectItem(root, "firmware");
  104. if (firmware == NULL) {
  105. ESP_LOGE(TAG, "Failed to get firmware object");
  106. cJSON_Delete(root);
  107. return false;
  108. }
  109. cJSON *version = cJSON_GetObjectItem(firmware, "version");
  110. if (version == NULL) {
  111. ESP_LOGE(TAG, "Failed to get version object");
  112. cJSON_Delete(root);
  113. return false;
  114. }
  115. cJSON *url = cJSON_GetObjectItem(firmware, "url");
  116. if (url == NULL) {
  117. ESP_LOGE(TAG, "Failed to get url object");
  118. cJSON_Delete(root);
  119. return false;
  120. }
  121. firmware_version_ = version->valuestring;
  122. firmware_url_ = url->valuestring;
  123. cJSON_Delete(root);
  124. // Check if the version is newer, for example, 0.1.0 is newer than 0.0.1
  125. has_new_version_ = IsNewVersionAvailable(current_version_, firmware_version_);
  126. if (has_new_version_) {
  127. ESP_LOGI(TAG, "New version available: %s", firmware_version_.c_str());
  128. } else {
  129. ESP_LOGI(TAG, "Current is the latest version");
  130. }
  131. return true;
  132. }
  133. void Ota::MarkCurrentVersionValid() {
  134. auto partition = esp_ota_get_running_partition();
  135. if (strcmp(partition->label, "factory") == 0) {
  136. ESP_LOGI(TAG, "Running from factory partition, skipping");
  137. return;
  138. }
  139. ESP_LOGI(TAG, "Running partition: %s", partition->label);
  140. esp_ota_img_states_t state;
  141. if (esp_ota_get_state_partition(partition, &state) != ESP_OK) {
  142. ESP_LOGE(TAG, "Failed to get state of partition");
  143. return;
  144. }
  145. if (state == ESP_OTA_IMG_PENDING_VERIFY) {
  146. ESP_LOGI(TAG, "Marking firmware as valid");
  147. esp_ota_mark_app_valid_cancel_rollback();
  148. }
  149. }
  150. void Ota::Upgrade(const std::string& firmware_url) {
  151. ESP_LOGI(TAG, "Upgrading firmware from %s", firmware_url.c_str());
  152. esp_ota_handle_t update_handle = 0;
  153. auto update_partition = esp_ota_get_next_update_partition(NULL);
  154. if (update_partition == NULL) {
  155. ESP_LOGE(TAG, "Failed to get update partition");
  156. return;
  157. }
  158. ESP_LOGI(TAG, "Writing to partition %s at offset 0x%lx", update_partition->label, update_partition->address);
  159. bool image_header_checked = false;
  160. std::string image_header;
  161. auto http = Board::GetInstance().CreateHttp();
  162. if (!http->Open("GET", firmware_url)) {
  163. ESP_LOGE(TAG, "Failed to open HTTP connection");
  164. delete http;
  165. return;
  166. }
  167. size_t content_length = http->GetBodyLength();
  168. if (content_length == 0) {
  169. ESP_LOGE(TAG, "Failed to get content length");
  170. delete http;
  171. return;
  172. }
  173. char buffer[512];
  174. size_t total_read = 0, recent_read = 0;
  175. auto last_calc_time = esp_timer_get_time();
  176. while (true) {
  177. int ret = http->Read(buffer, sizeof(buffer));
  178. if (ret < 0) {
  179. ESP_LOGE(TAG, "Failed to read HTTP data: %s", esp_err_to_name(ret));
  180. delete http;
  181. return;
  182. }
  183. // Calculate speed and progress every second
  184. recent_read += ret;
  185. total_read += ret;
  186. if (esp_timer_get_time() - last_calc_time >= 1000000 || ret == 0) {
  187. size_t progress = total_read * 100 / content_length;
  188. ESP_LOGI(TAG, "Progress: %zu%% (%zu/%zu), Speed: %zuB/s", progress, total_read, content_length, recent_read);
  189. if (upgrade_callback_) {
  190. upgrade_callback_(progress, recent_read);
  191. }
  192. last_calc_time = esp_timer_get_time();
  193. recent_read = 0;
  194. }
  195. if (ret == 0) {
  196. break;
  197. }
  198. if (!image_header_checked) {
  199. image_header.append(buffer, ret);
  200. if (image_header.size() >= sizeof(esp_image_header_t) + sizeof(esp_image_segment_header_t) + sizeof(esp_app_desc_t)) {
  201. esp_app_desc_t new_app_info;
  202. memcpy(&new_app_info, image_header.data() + sizeof(esp_image_header_t) + sizeof(esp_image_segment_header_t), sizeof(esp_app_desc_t));
  203. ESP_LOGI(TAG, "New firmware version: %s", new_app_info.version);
  204. auto current_version = esp_app_get_description()->version;
  205. if (memcmp(new_app_info.version, current_version, sizeof(new_app_info.version)) == 0) {
  206. ESP_LOGE(TAG, "Firmware version is the same, skipping upgrade");
  207. delete http;
  208. return;
  209. }
  210. if (esp_ota_begin(update_partition, OTA_WITH_SEQUENTIAL_WRITES, &update_handle)) {
  211. esp_ota_abort(update_handle);
  212. delete http;
  213. ESP_LOGE(TAG, "Failed to begin OTA");
  214. return;
  215. }
  216. image_header_checked = true;
  217. std::string().swap(image_header);
  218. }
  219. }
  220. auto err = esp_ota_write(update_handle, buffer, ret);
  221. if (err != ESP_OK) {
  222. ESP_LOGE(TAG, "Failed to write OTA data: %s", esp_err_to_name(err));
  223. esp_ota_abort(update_handle);
  224. delete http;
  225. return;
  226. }
  227. }
  228. delete http;
  229. esp_err_t err = esp_ota_end(update_handle);
  230. if (err != ESP_OK) {
  231. if (err == ESP_ERR_OTA_VALIDATE_FAILED) {
  232. ESP_LOGE(TAG, "Image validation failed, image is corrupted");
  233. } else {
  234. ESP_LOGE(TAG, "Failed to end OTA: %s", esp_err_to_name(err));
  235. }
  236. return;
  237. }
  238. err = esp_ota_set_boot_partition(update_partition);
  239. if (err != ESP_OK) {
  240. ESP_LOGE(TAG, "Failed to set boot partition: %s", esp_err_to_name(err));
  241. return;
  242. }
  243. ESP_LOGI(TAG, "Firmware upgrade successful, rebooting in 3 seconds...");
  244. vTaskDelay(pdMS_TO_TICKS(3000));
  245. esp_restart();
  246. }
  247. void Ota::StartUpgrade(std::function<void(int progress, size_t speed)> callback) {
  248. upgrade_callback_ = callback;
  249. Upgrade(firmware_url_);
  250. }
  251. std::vector<int> Ota::ParseVersion(const std::string& version) {
  252. std::vector<int> versionNumbers;
  253. std::stringstream ss(version);
  254. std::string segment;
  255. while (std::getline(ss, segment, '.')) {
  256. versionNumbers.push_back(std::stoi(segment));
  257. }
  258. return versionNumbers;
  259. }
  260. bool Ota::IsNewVersionAvailable(const std::string& currentVersion, const std::string& newVersion) {
  261. std::vector<int> current = ParseVersion(currentVersion);
  262. std::vector<int> newer = ParseVersion(newVersion);
  263. for (size_t i = 0; i < std::min(current.size(), newer.size()); ++i) {
  264. if (newer[i] > current[i]) {
  265. return true;
  266. } else if (newer[i] < current[i]) {
  267. return false;
  268. }
  269. }
  270. return newer.size() > current.size();
  271. }