WeNet是一款开源端到端ASR工具包,它与ESPnet等开源语音项目相比,最大的优势在于提供了从锻炼到部署的一整套工具链,使ASR效劳的工业落地愈加简单。如图1所示,WeNet工具包完好依赖于PyTorch生态:运用TorchScript中止模型开发,运用Torchaudio中止动态特征提取,运用DistributedDataParallel中止分布式锻炼,运用torch JIT(Just In Time)中止模型导出,运用LibTorch作为消费环境运转时。本系列将对WeNet云端推理部署代码中止解析。
图1:WeNet系统设计[1]
1. 代码结构
WeNet云端推理和部署代码位于wenet/runtime/server/x86途径下,编程言语为C++,其结构如下所示:
其中:
语音文件读入与特征提取相关代码位于frontend文件夹下;端到端模型导入、端点检测与语音解码识别相关代码位于decoder文件夹下,WeNet支持CTC prefix beam search和融合了WFST的CTC beam search这两种解码算法,后者的完成大量自创了Kaldi,相关代码放在kaldi文件夹下;在效劳化方面,WeNet分别完成了基于WebSocket和基于gRPC的两套效劳端与客户端,基于WebSocket的完成位于websocket文件夹下,基于gRPC的完成位于grpc文件夹下,两种完成的入口main函数代码都位于bin文件夹下。日志、计时、字符串处置等辅助代码位于utils文件夹下。
WeNet提供了CMakeLists.txt和Dockerfile,使得用户能便当地中止项目编译和镜像构建。
2. 前端:frontend文件夹
1)语音文件读入
WeNet只支持44字节header的wav格式音频数据,wav header定义在WavHeader结构体中,包括音频格式、声道数、采样率等音频元信息。WavReader类用于语音文件读入,调用fopen翻开语音文件后,WavReader先读入WavHeader大小的数据(也就是44字节),再根据WavHeader中的元信息肯定待读入音频数据的大小,最后调用fread把音频数据读入buffer,并经过static_cast把数据转化为float类型。
struct WavHeader { char riff[4]; // "riff" unsigned int size; char wav[4]; // "WAVE" char fmt[4]; // "fmt " unsigned int fmt_size; uint16_t format; uint16_t channels; unsigned int sample_rate; unsigned int bytes_per_second; uint16_t block_size; uint16_t bit; char data[4]; // "data" unsigned int data_size;};
这里存在的一个风险是,假设WavHeader中存放的元信息有误,则会影响到语音数据的正确读入。
2)特征提取
WeNet运用的特征是fbank,经过FeaturePipelineConfig结构体中止特征设置。默许帧长为25ms,帧移为10ms,采样率和fbank维数则由用户输入。
用于特征提取的类是FeaturePipeline。为了同时支持流式与非流式语音识别,FeaturePipeline类中设置了input_finished_属性来标志输入能否终了,并经过set_input_finished()成员函数来对input_finished_属性中止操作。
提取出来的fbank特征放在feature_queue_中,feature_queue_的类型是BlockingQueue<std::vector>。BlockingQueue类是WeNet完成的一个阻塞队列,初始化的时分需求提供队列的容量(capacity),经过Push()函数向队列中增加特征,经过Pop()函数从队列中读取特征:
当feature_queue_中的feature数量超越capacity,则Push线程被挂起,等候feature_queue_.Pop()释放出空间。当feature_queue_为空,则Pop线程被挂起,等候feature_queue_.Push()。线程的挂起和恢复是经过C++标准库中的线程同步原语std::mutex、std::condition_variable等完成。线程同步还用在AcceptWaveform和ReadOne两个成员函数中,AcceptWaveform把语音数据提取得到的fbank特征放到feature_queue_中,ReadOne成员函数则把特征从feature_queue_中读出,是经典的消费者消费者方式。
3. 解码器:decoder文件夹
1)TorchAsrModel
经过torch::jit::load对存在磁盘上的模型中止反序列化,得到一个ScriptModule对象。
torch::jit::script::Module model = torch::jit::load(model_path);
2)SearchInterface
WeNet推理支持的解码方式都继承自基类SearchInterface,假设要新增解码算法,则需继承SearchInterface类,并提供该类中一切纯虚函数的完成,包括:
// 解码算法的细致完成virtual void Search(const torch::Tensor& logp) = 0;// 重置解码过程virtual void Reset() = 0;// 终了解码过程virtual void FinalizeSearch() = 0;// 解码算法类型,返回一个枚举常量SearchTypevirtual SearchType Type() const = 0;// 返回解码输入virtual const std::vector<std::vector<int>>& Inputs() const = 0;// 返回解码输出virtual const std::vector<std::vector<int>>& Outputs() const = 0;// 返回解码输出对应的似然值virtual const std::vector<float>& Likelihood() const = 0;// 返回解码输出对应的次数virtual const std::vector<std::vector<int>>& Times() const = 0;
目前WeNet只提供了SearchInterface的两种子类完成,也即两种解码算法,分别定义在CtcPrefixBeamSearch和CtcWfstBeamSearch两个类中。
3)CtcEndpoint
WeNet支持语音端点检测,提供了一种基于规则的完成方式,用户可以经过CtcEndpointConfig结构体和CtcEndpointRule结构体中止规则配置。WeNet默许的规则有三条:
检测到了5s的静音,则以为检测到端点;解码出了恣意时长的语音后,检测到了1s的静音,则以为检测到端点;解码出了20s的语音,则以为检测到端点。一旦检测到端点,则终了解码。另外,WeNet把解码得到的空白符(blank)视作静音。
4)TorchAsrDecoder
WeNet提供的解码器定义在TorchAsrDecoder类中。如图3所示,WeNet支持双向解码,即叠加从左往右解码和从右往左解码的结果。在CTC beam search之后,用户还可以选择中止attention重打分。
图2:WeNet解码计算流程[2]
可以经过DecodeOptions结构体中止解码参数配置,包括如下参数:
struct DecodeOptions { int chunk_size = 16; int num_left_chunks = -1; float ctc_weight = 0.0; float rescoring_weight = 1.0; float reverse_weight = 0.0; CtcEndpointConfig ctc_endpoint_config; CtcPrefixBeamSearchOptions ctc_prefix_search_opts; CtcWfstBeamSearchOptions ctc_wfst_search_opts;};
其中,ctc_weight表示CTC解码权重,rescoring_weight表示重打分权重,reverse_weight表示从右往左解码权重。最终解码打分的计算方式为:
final_score = rescoring_weight * rescoring_score + ctc_weight * ctc_score;rescoring_score = left_to_right_score * (1 - reverse_weight) +right_to_left_score * reverse_weight
TorchAsrDecoder对外提供的解码接口是Decode(),重打分接口是Rescoring()。Decode()返回的是枚举类型DecodeState,包括三个枚举常量:kEndBatch,kEndpoint和kEndFeats,分别表示当前批数据解码终了、检测到端点、一切特征解码终了。
为了支持长语音识别,WeNet还提供了连续解码接口ResetContinuousDecoding(),它与解码注重置接口Reset()的区别在于:连续解码接口会记载全局曾经解码的语音帧数,并保管当前feature_pipeline_的状态。
由于流式ASR效劳需求在客户端和效劳端之间中止双向的流式数据传输,WeNet完成了两种支持双向流式通讯的效劳化接口,分别基于WebSocket和gRPC。
4. 基于WebSocket
1)WebSocket简介
WebSocket是基于TCP的一种新的网络协议,与HTTP协议不同,WebSocket允许效劳器主动发送信息给客户端。 在衔接树立后,客户端和效劳端可以连续互相发送数据,而无需在每次发送数据时重新发起衔接央求。因此大大减小了网络带宽的资源消耗 ,在性能上更有优势。
WebSocket支持文本和二进制两种格式的数据传输 。
2)WeNet的WebSocket接口
WeNet运用了boost库的WebSocket完成,定义了WebSocketClient(客户端)和WebSocketServer(效劳端)两个类。
在流式ASR过程中,WebSocketClient给WebSocketServer发送数据可以分为三个步骤:1)发送开端信号与解码配置;2)发送二进制语音数据:pcm字节流;3)发送中止信号。从WebSocketClient::SendStartSignal()和WebSocketClient::SendEndSignal()可以看到,开端信号、解码配置和中止信号都是包装在json字符串中,经过WebSocket文本格式传输。pcm字节流则经过WebSocket二进制格式中止传输。
void WebSocketClient::SendStartSignal() { // TODO(Binbin Zhang): Add sample rate and other setting surpport json::value start_tag = {{"signal", "start"}, {"nbest", nbest_}, {"continuous_decoding", continuous_decoding_}}; std::string start_message = json::serialize(start_tag); this->SendTextData(start_message);}void WebSocketClient::SendEndSignal() { json::value end_tag = {{"signal", "end"}}; std::string end_message = json::serialize(end_tag); this->SendTextData(end_message);}
WebSocketServer在收到数据后,需求先判别收到的数据是文本还是二进制格式:假设是文本数据,则中止json解析,并根据解析结果中止解码配置、启动或中止,处置逻辑定义在ConnectionHandler::OnText()函数中。假设是二进制数据,则中止语音识别,处置逻辑定义在ConnectionHandler::OnSpeechData()中。
3)缺陷
WebSocket需求开发者在WebSocketClient和WebSocketServer写好对应的消息构造和解析代码,容易出错。另外,从以上代码来看,效劳需求借助json格式来序列化和反序列化数据,效率没有protobuf格式高。
关于这些缺陷,gRPC框架提供了更好的处置方法。
5. 基于gRPC
1)gRPC简介
gRPC是谷歌推出的开源RPC框架,运用HTTP2作为网络传输协议,并运用protobuf作为数据交流格式,有更高的数据传输效率。在gRPC框架下,开发者只需经过一个.proto文件定义好RPC效劳(service)与消息(message),便可经过gRPC提供的代码生成工具(protoc compiler)自动生成消息构造和解析代码,使开发者能更好地聚焦于接口设计本身。
中止RPC调用时,gRPC Stub(客户端)向gRPC Server(效劳端)发送.proto文件中定义的Request消息,gRPC Server在处置完央求之后,经过.proto文件中定义的Response消息将结果返回给gRPC Stub。
gRPC具有跨言语特性,支持不同言语写的微效劳中止互动,比如压效劳端用C++完成,客户端用Ruby完成。protoc compiler支持12种言语的代码生成。
图1:gRPC Server和gRPC Stub交互[1]
2)WeNet的proto文件
WeNet定义的效劳为ASR,包含一个Recognize方法,该方法的输入(Request)、输出(Response)都是流式数据(stream)。在运用protoc compiler编译proto文件后,会得到4个文件:wenet.grpc.pb.h,http://wenet.grpc.pb.cc,wenet.pb.h,http://wenet.pb.cc。其中,wenet.pb.h/cc中存储了protobuf数据格式的定义,wenet.grpc.pb.h中存储了gRPC效劳端/客户端的定义。经过在代码中包括wenet.pb.h和wenet.grpc.pb.h两个头文件,开发者可以直接运用Request消息和Response消息类,访问其字段。
service ASR { rpc Recognize (stream Request) returns (stream Response) {}}message Request { message DecodeConfig { int32 nbest_config = 1; bool continuous_decoding_config = 2; } oneof RequestPayload { DecodeConfig decode_config = 1; bytes audio_data = 2; }}message Response { message OneBest { string sentence = 1; repeated OnePiece wordpieces = 2; } message OnePiece { string word = 1; int32 start = 2; int32 end = 3; } enum Status { ok = 0; failed = 1; } enum Type { server_ready = 0; partial_result = 1; final_result = 2; speech_end = 3; } Status status = 1; Type type = 2; repeated OneBest nbest = 3;}
3)WeNet的gRPC完成
WeNet gRPC效劳端定义了GrpcServer类,该类继承自wenet.grpc.pb.h中的纯虚基类ASR::Service。
语音识别的入口函数是GrpcServer::Recognize,该函数初始化一个GRPCConnectionHandler实例来中止语音识别,并经过ServerReaderWriter类的stream对象来传送输入输出。
Status GrpcServer::Recognize(ServerContext* context, ServerReaderWriter<Response, Request>* stream) { LOG(INFO) << "Get Recognize request" << std::endl; auto request = std::make_shared(); auto response = std::make_shared(); GrpcConnectionHandler handler(stream, request, response, feature_config_, decode_config_, symbol_table_, model_, fst_); std::thread t(std::move(handler)); t.join(); return Status::OK;}
WeNet gRPC客户端定义了GrpcClient类。客户端在树立与效劳端的衔接时需实例化ASR::Stub,并经过ClientReaderWriter类的stream对象,完成双向流式通讯。
void GrpcClient::Connect() { channel_ = grpc::CreateChannel(host_ + ":" + std::to_string(port_), grpc::InsecureChannelCredentials()); stub_ = ASR::NewStub(channel_); context_ = std::make_shared(); stream_ = stub_->Recognize(context_.get()); request_ = std::make_shared(); response_ = std::make_shared(); request_->mutable_decode_config()->set_nbest_config(nbest_); request_->mutable_decode_config()->set_continuous_decoding_config( continuous_decoding_); stream_->Write(*request_);}
http://grpc_client_main.cc中,客户端分段传输语音数据,每0.5s中止一次传输,即关于一个采样率为8k的语音文件来说,每次传4000帧数据。为了减小传输数据的大小,提升数据传输速度,先在客户端将float类型转为int16_t,效劳端在接受到数据后,再将int16_t转为float。c++中float为32位。
int main(int argc, char *argv[]) { ... // Send data every 0.5 second const float interval = 0.5; const int sample_interval = interval * sample_rate; for (int start = 0; start < num_sample; start += sample_interval) { if (client.done()) { break; } int end = std::min(start + sample_interval, num_sample); // Convert to short std::vector<int16_t> data; data.reserve(end - start); for (int j = start; j < end; j++) { data.push_back(static_cast<int16_t>(pcm_data[j])); } // Send PCM data client.SendBinaryData(data.data(), data.size() * sizeof(int16_t)); ...}
相关推荐
© 2020 asciim码
人生就是一场修行