wake_word_detect.cc 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
#include "wake_word_detect.h"
#include "application.h"

#include <cstring>
#include <sstream>
#include <utility>

#include <esp_log.h>
#include <model_path.h>
#include <arpa/inet.h>
#include <esp_mn_iface.h>
#include <esp_mn_models.h>
#include <esp_process_sdkconfig.h>
  10. #define DETECTION_RUNNING_EVENT 1
  11. srmodel_list_t *models = NULL;
  12. static const char* TAG = "WakeWordDetect";
  13. WakeWordDetect::WakeWordDetect()
  14. : afe_detection_data_(nullptr),
  15. wake_word_pcm_(),
  16. wake_word_opus_() {
  17. event_group_ = xEventGroupCreate();
  18. }
  19. WakeWordDetect::~WakeWordDetect() {
  20. if (afe_detection_data_ != nullptr) {
  21. esp_afe_sr_v1.destroy(afe_detection_data_);
  22. }
  23. if (wake_word_encode_task_stack_ != nullptr) {
  24. heap_caps_free(wake_word_encode_task_stack_);
  25. }
  26. vEventGroupDelete(event_group_);
  27. }
  28. void WakeWordDetect::Initialize(int channels, bool reference) {
  29. channels_ = channels;
  30. reference_ = reference;
  31. int ref_num = reference_ ? 1 : 0;
  32. models = esp_srmodel_init("model");
  33. for (int i = 0; i < models->num; i++) {
  34. ESP_LOGI(TAG, "Model %d: %s", i, models->model_name[i]);
  35. if (strstr(models->model_name[i], ESP_WN_PREFIX) != NULL) {
  36. wakenet_model_ = models->model_name[i];
  37. auto words = esp_srmodel_get_wake_words(models, wakenet_model_);
  38. // split by ";" to get all wake words
  39. std::stringstream ss(words);
  40. std::string word;
  41. while (std::getline(ss, word, ';')) {
  42. wake_words_.push_back(word);
  43. }
  44. }
  45. }
  46. afe_config_t afe_config = {
  47. .aec_init = reference_,
  48. .se_init = true,
  49. .vad_init = true,
  50. .wakenet_init = true,
  51. .voice_communication_init = false,
  52. .voice_communication_agc_init = false,
  53. .voice_communication_agc_gain = 10,
  54. .vad_mode = VAD_MODE_3,
  55. .wakenet_model_name = wakenet_model_,
  56. .wakenet_model_name_2 = NULL,
  57. .wakenet_mode = DET_MODE_90,
  58. .afe_mode = SR_MODE_HIGH_PERF,
  59. .afe_perferred_core = 1,
  60. .afe_perferred_priority = 1,
  61. .afe_ringbuf_size = 50,
  62. .memory_alloc_mode = AFE_MEMORY_ALLOC_MORE_PSRAM,
  63. .afe_linear_gain = 1.0,
  64. .agc_mode = AFE_MN_PEAK_AGC_MODE_2,
  65. .pcm_config = {
  66. .total_ch_num = channels_,
  67. .mic_num = channels_ - ref_num,
  68. .ref_num = ref_num,
  69. .sample_rate = 16000
  70. },
  71. .debug_init = false,
  72. .debug_hook = {{ AFE_DEBUG_HOOK_MASE_TASK_IN, NULL }, { AFE_DEBUG_HOOK_FETCH_TASK_IN, NULL }},
  73. .afe_ns_mode = NS_MODE_SSP,
  74. .afe_ns_model_name = NULL,
  75. .fixed_first_channel = true,
  76. };
  77. afe_detection_data_ = esp_afe_sr_v1.create_from_config(&afe_config);
  78. xTaskCreate([](void* arg) {
  79. auto this_ = (WakeWordDetect*)arg;
  80. this_->AudioDetectionTask();
  81. vTaskDelete(NULL);
  82. }, "audio_detection", 4096 * 2, this, 2, nullptr);
  83. }
  84. void WakeWordDetect::OnWakeWordDetected(std::function<void(const std::string& wake_word)> callback) {
  85. wake_word_detected_callback_ = callback;
  86. }
  87. void WakeWordDetect::OnVadStateChange(std::function<void(bool speaking)> callback) {
  88. vad_state_change_callback_ = callback;
  89. }
  90. void WakeWordDetect::StartDetection() {
  91. xEventGroupSetBits(event_group_, DETECTION_RUNNING_EVENT);
  92. }
  93. void WakeWordDetect::StopDetection() {
  94. xEventGroupClearBits(event_group_, DETECTION_RUNNING_EVENT);
  95. }
  96. bool WakeWordDetect::IsDetectionRunning() {
  97. return xEventGroupGetBits(event_group_) & DETECTION_RUNNING_EVENT;
  98. }
  99. void WakeWordDetect::Feed(const std::vector<int16_t>& data) {
  100. input_buffer_.insert(input_buffer_.end(), data.begin(), data.end());
  101. auto feed_size = esp_afe_sr_v1.get_feed_chunksize(afe_detection_data_) * channels_;
  102. while (input_buffer_.size() >= feed_size) {
  103. esp_afe_sr_v1.feed(afe_detection_data_, input_buffer_.data());
  104. input_buffer_.erase(input_buffer_.begin(), input_buffer_.begin() + feed_size);
  105. }
  106. }
  107. void WakeWordDetect::AudioDetectionTask() {
  108. auto fetch_size = esp_afe_sr_v1.get_fetch_chunksize(afe_detection_data_);
  109. auto feed_size = esp_afe_sr_v1.get_feed_chunksize(afe_detection_data_);
  110. ESP_LOGI(TAG, "Audio detection task started, feed size: %d fetch size: %d",
  111. feed_size, fetch_size);
  112. //以下内容为暂时加入的,用于测试命令词
  113. // char *mn_name=esp_srmodel_filter(models,ESP_MN_PREFIX,ESP_MN_CHINESE);
  114. // ESP_LOGI(TAG,"mn_name:%s",mn_name);
  115. // esp_mn_iface_t *multinet=esp_mn_handle_from_name(mn_name);
  116. // model_iface_data_t *model_data=multinet->create(mn_name,6000);
  117. // esp_mn_commands_update_from_sdkconfig(multinet,model_data);
  118. // int mu_chunksize = multinet->get_samp_chunksize(model_data);
  119. while (true) {
  120. xEventGroupWaitBits(event_group_, DETECTION_RUNNING_EVENT, pdFALSE, pdTRUE, portMAX_DELAY);
  121. auto res = esp_afe_sr_v1.fetch(afe_detection_data_);
  122. if (res == nullptr || res->ret_value == ESP_FAIL) {
  123. continue;;
  124. }
  125. // Store the wake word data for voice recognition, like who is speaking
  126. StoreWakeWordData((uint16_t*)res->data, res->data_size / sizeof(uint16_t));
  127. // VAD state change
  128. if (vad_state_change_callback_) {
  129. if (res->vad_state == AFE_VAD_SPEECH && !is_speaking_) {
  130. is_speaking_ = true;
  131. vad_state_change_callback_(true);
  132. } else if (res->vad_state == AFE_VAD_SILENCE && is_speaking_) {
  133. is_speaking_ = false;
  134. vad_state_change_callback_(false);
  135. }
  136. }
  137. if (res->wakeup_state == WAKENET_DETECTED) {
  138. StopDetection();
  139. last_detected_wake_word_ = wake_words_[res->wake_word_index - 1];
  140. if (wake_word_detected_callback_) {
  141. wake_word_detected_callback_(last_detected_wake_word_);
  142. }
  143. }
  144. //以下内容为暂时加入的,用于测试命令词
  145. // esp_mn_state_t mn_state = multinet->detect(model_data, res->data);
  146. // if (mn_state == ESP_MN_STATE_DETECTING) {
  147. // continue;
  148. // }
  149. // if (mn_state == ESP_MN_STATE_DETECTED) {
  150. // esp_mn_results_t *mn_result = multinet->get_results(model_data);
  151. // std::string wake_word="你好鹅大厨";
  152. // Application::GetInstance().WakeWordInvoke(wake_word);
  153. // for (int i = 0; i < mn_result->num; i++) {
  154. // ESP_LOGI(TAG,"TOP %d, command_id: %d, phrase_id: %d, string:%s prob: %f",
  155. // i+1, mn_result->command_id[i], mn_result->phrase_id[i], mn_result->string, mn_result->prob[i]);
  156. // }
  157. // }
  158. // if (mn_state == ESP_MN_STATE_TIMEOUT) {
  159. // esp_mn_results_t *mn_result = multinet->get_results(model_data);
  160. // ESP_LOGI(TAG,"timeout, string:%s", mn_result->string);
  161. // // afe_handle->enable_wakenet(afe_data);
  162. // continue;
  163. // }
  164. }
  165. }
  166. void WakeWordDetect::StoreWakeWordData(uint16_t* data, size_t samples) {
  167. // store audio data to wake_word_pcm_
  168. wake_word_pcm_.emplace_back(std::vector<int16_t>(data, data + samples));
  169. // keep about 2 seconds of data, detect duration is 32ms (sample_rate == 16000, chunksize == 512)
  170. while (wake_word_pcm_.size() > 2000 / 32) {
  171. wake_word_pcm_.pop_front();
  172. }
  173. }
  174. void WakeWordDetect::EncodeWakeWordData() {
  175. wake_word_opus_.clear();
  176. if (wake_word_encode_task_stack_ == nullptr) {
  177. wake_word_encode_task_stack_ = (StackType_t*)heap_caps_malloc(4096 * 8, MALLOC_CAP_SPIRAM);
  178. }
  179. wake_word_encode_task_ = xTaskCreateStatic([](void* arg) {
  180. auto this_ = (WakeWordDetect*)arg;
  181. {
  182. auto start_time = esp_timer_get_time();
  183. auto encoder = std::make_unique<OpusEncoderWrapper>(16000, 1, OPUS_FRAME_DURATION_MS);
  184. encoder->SetComplexity(0); // 0 is the fastest
  185. for (auto& pcm: this_->wake_word_pcm_) {
  186. encoder->Encode(std::move(pcm), [this_](std::vector<uint8_t>&& opus) {
  187. std::lock_guard<std::mutex> lock(this_->wake_word_mutex_);
  188. this_->wake_word_opus_.emplace_back(std::move(opus));
  189. this_->wake_word_cv_.notify_all();
  190. });
  191. }
  192. this_->wake_word_pcm_.clear();
  193. auto end_time = esp_timer_get_time();
  194. ESP_LOGI(TAG, "Encode wake word opus %zu packets in %lld ms",
  195. this_->wake_word_opus_.size(), (end_time - start_time) / 1000);
  196. std::lock_guard<std::mutex> lock(this_->wake_word_mutex_);
  197. this_->wake_word_opus_.push_back(std::vector<uint8_t>());
  198. this_->wake_word_cv_.notify_all();
  199. }
  200. vTaskDelete(NULL);
  201. }, "encode_detect_packets", 4096 * 8, this, 2, wake_word_encode_task_stack_, &wake_word_encode_task_buffer_);
  202. }
  203. bool WakeWordDetect::GetWakeWordOpus(std::vector<uint8_t>& opus) {
  204. std::unique_lock<std::mutex> lock(wake_word_mutex_);
  205. wake_word_cv_.wait(lock, [this]() {
  206. return !wake_word_opus_.empty();
  207. });
  208. opus.swap(wake_word_opus_.front());
  209. wake_word_opus_.pop_front();
  210. return !opus.empty();
  211. }