Training a Chinese model on WenetSpeech, and running it with sherpa-ncnn/sherpa-onnx
Training
Environment setup (must read)
See https://icefall.readthedocs.io/en/latest/docker/intro.html#view-available-tags
On the server, pull the officially configured image with docker:
docker pull k2fsa/icefall:torch2.2.0-cuda12.1  # pull the image; pitfall: do not use a tag that is too old, or many of the signatures inside it will have expired and you will hit lots of bugs
docker run -itd --gpus all --privileged -v ~/nlp/:/nlp --name nlp k2fsa/icefall:torch2.2.0-cuda12.1  # start the container
docker exec -it nlp bash  # enter the container; you land in /workspace/icefall by default
apt update  # refresh package lists
git pull  # pull the latest icefall code
Preparing the dataset (must read)
Run in a terminal on the server:
cd /workspace/icefall/egs/wenetspeech/ASR/  # enter the wenetspeech recipe directory
git clone https://github.com/wenet-e2e/WenetSpeech.git  # clone the WenetSpeech repository
cd /workspace/icefall/egs/wenetspeech/ASR/WenetSpeech  # enter the cloned repository
echo 'PASSWORD' > SAFEBOX/password  # store the download password
bash utils/download_wenetspeech.sh DOWNLOAD_DIR UNTAR_DIR  # start the download
# The dataset is very large, so the download is slow
mv DOWNLOAD_DIR ../download  # move the dataset into place
bash prepare.sh  # preprocess the data
cp -r data pruned_transducer_stateless5  # copy the generated files into pruned_transducer_stateless5
Note: prepare.sh does a lot of processing on the dataset and creates a data/lang_char folder in the current directory, containing tokens.txt and other files we will use later.
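As a quick check that prepare.sh produced what we need, here is a minimal Python sketch (an assumption on my part, not part of the recipe; run it from egs/wenetspeech/ASR):
from pathlib import Path

lang_dir = Path("data/lang_char")
for name in ["tokens.txt", "words.txt", "text"]:
    p = lang_dir / name
    print(p, "OK" if p.is_file() else "MISSING")
# tokens.txt has one modeling unit per line
print(sum(1 for _ in (lang_dir / "tokens.txt").open(encoding="utf-8")), "tokens")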
Training the model (must read)
Run in a terminal:
cd /workspace/icefall/egs/wenetspeech/ASR/
export CUDA_VISIBLE_DEVICES="0,1"  # adjust to your GPU count; --world-size below must match the number of visible GPUs
# Train a streaming model; the flags differ from the non-streaming recipe
python ./pruned_transducer_stateless5/train.py \
--lang-dir data/lang_char \
--exp-dir pruned_transducer_stateless5/exp \
--world-size 8 \
--num-epochs 99 \
--start-epoch 1 \
--max-duration 140 \
--valid-interval 3000 \
--model-warm-step 3000 \
--save-every-n 8000 \
--average-period 1000 \
--training-subset L \
--dynamic-chunk-training True \
--causal-convolution True \
--short-chunk-size 25 \
--num-left-chunks 4
When training finishes, the generated files can be found under pruned_transducer_stateless5/exp.
Decoding test (optional)
export CUDA_VISIBLE_DEVICES='0'
python pruned_transducer_stateless5/streaming_decode.py \
--epoch 6 \
--avg 1 \
--decode-chunk-size 16 \
--left-context 64 \
--right-context 0 \
--exp-dir ./pruned_transducer_stateless5/exp \
--use-averaged-model True \
--decoding-method greedy_search \
--num-decode-streams 200
Viewing with tensorboard (must read)
After training finishes we have the training log files and the tensorboard loss records. Run the following in a terminal:
cd pruned_transducer_stateless5/exp
tensorboard --logdir tensorboard --host=0.0.0.0
Then open http://localhost:6006 to see the results (if tensorboard is running on a remote server, replace localhost with the server's address, or use SSH port forwarding).
Note: the WenetSpeech dataset is very large, so the download takes a long time.
Preparing audio files and extracting features
(Note: we also augment the training data with the musan dataset; see the musan-related commands in prepare.sh. It is not covered here.)
Download and extract the data
To keep file names consistent, the data directory is renamed to WenetSpeech; audio contains all the training and test audio:
>> tree download/WenetSpeech -L 1
download/WenetSpeech
├── audio
├── TERMS_OF_ACCESS
└── WenetSpeech.json
>> tree download/WenetSpeech/audio -L 1
download/WenetSpeech/audio
├── dev
├── test_meeting
├── test_net
└── train
WenetSpeech.json contains the audio file paths and the associated supervision information. Part of the file looks like this:
"audios": [
{
"aid": "Y0000000000_--5llN02F84",
"duration": 2494.57,
"md5": "48af998ec7dab6964386c3522386fa4b",
"path": "audio/train/youtube/B00000/Y0000000000_--5llN02F84.opus",
"source": "youtube",
"tags": [
"drama"
],
"url": "https://www.youtube.com/watch?v=--5llN02F84",
"segments": [
{
"sid": "Y0000000000_--5llN02F84_S00000",
"confidence": 1.0,
"begin_time": 20.08,
"end_time": 24.4,
"subsets": [
"L"
],
"text": "怎么样这些日子住得还习惯吧"
},
{
"sid": "Y0000000000_--5llN02F84_S00002",
"confidence": 1.0,
"begin_time": 25.0,
"end_time": 26.28,
"subsets": [
"L"
],
"text": "挺好的"
(Note: the WenetSpeech corpus provides three training subsets of different sizes: S, M, and L.)
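To poke around in the metadata, here is a minimal sketch (my own addition) that loads WenetSpeech.json and prints the first segment. Caution: the file is several GB, so json.load is slow and memory-hungry; a streaming parser such as ijson is preferable for real work.
import json

with open("download/WenetSpeech/WenetSpeech.json", encoding="utf-8") as f:
    meta = json.load(f)  # very large file; this takes a while and a lot of RAM

audio = meta["audios"][0]
segment = audio["segments"][0]
print(audio["aid"], audio["duration"], audio["path"])
print(segment["sid"], segment["subsets"], segment["text"])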
Generating manifests with lhotse
For how lhotse turns the raw data into jsonl.gz files, see wenet_speech.py; its main job is to generate the recordings and supervisions jsonl.gz files:
>> lhotse prepare wenet-speech download/WenetSpeech data/manifests -j 15
>> tree data/manifests -L 1
├── wenetspeech_recordings_DEV.jsonl.gz
├── wenetspeech_recordings_L.jsonl.gz
├── wenetspeech_recordings_M.jsonl.gz
├── wenetspeech_recordings_S.jsonl.gz
├── wenetspeech_recordings_TEST_MEETING.jsonl.gz
├── wenetspeech_recordings_TEST_NET.jsonl.gz
├── wenetspeech_supervisions_DEV.jsonl.gz
├── wenetspeech_supervisions_L.jsonl.gz
├── wenetspeech_supervisions_M.jsonl.gz
├── wenetspeech_supervisions_S.jsonl.gz
├── wenetspeech_supervisions_TEST_MEETING.jsonl.gz
└── wenetspeech_supervisions_TEST_NET.jsonl.gz
You can inspect the recordings and supervisions jsonl.gz files with vim, for example wenetspeech_recordings_S.jsonl.gz and wenetspeech_supervisions_S.jsonl.gz (screenshots omitted).
As the two files show, recordings describes the audio files: each entry carries the sample's id, path, channel, sampling rate, number of samples, duration, and so on. supervisions records the supervision information: each entry carries the corresponding sample id, start time, duration, channel, text, language, and so on.
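The same information can be inspected programmatically. A minimal sketch (my own addition) using lhotse, which prepare.sh already depends on:
from lhotse import load_manifest

recs = load_manifest("data/manifests/wenetspeech_recordings_S.jsonl.gz")
sups = load_manifest("data/manifests/wenetspeech_supervisions_S.jsonl.gz")
# first Recording: id, source path, sampling rate, num_samples, duration
print(next(iter(recs)))
# first SupervisionSegment: id, recording_id, start, duration, channel, text, language
print(next(iter(sups)))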
Next, we extract features from the audio data.
Computing, extracting, and storing audio features
First, preprocess the data: normalize the text and apply time-domain augmentation to the audio; see preprocess_wenetspeech.py.
python3 ./local/preprocess_wenetspeech.py
Next, split the dataset and extract features for each split; see compute_fbank_wenetspeech_splits.py.
(Note: splitting is done so that several processes can extract features from this large dataset in parallel, which improves efficiency. For a small dataset, splitting is not necessary.)
# L below can also be changed to M or S to select the training subset
lhotse split 1000 ./data/fbank/cuts_L_raw.jsonl.gz data/fbank/L_split_1000
python3 ./local/compute_fbank_wenetspeech_splits.py \
--training-subset L \
--num-workers 20 \
--batch-duration 600 \
--start 0 \
--num-splits 1000
Finally, once the features of every split have been extracted, merge them into one overall feature set:
# L below can also be changed to M or S to select the training subset
pieces=$(find data/fbank/L_split_1000 -name "cuts_L.*.jsonl.gz")
lhotse combine $pieces data/fbank/cuts_L.jsonl.gz
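A quick way to confirm the merged manifest is readable (a sketch of my own; load_manifest_lazy streams entries instead of loading everything into memory):
from lhotse import load_manifest_lazy

cuts = load_manifest_lazy("data/fbank/cuts_L.jsonl.gz")
cut = next(iter(cuts))
print(cut.id, cut.duration)
print(cut.supervisions[0].text)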
With that, audio preparation and feature extraction are essentially done. Next, we build the language-modeling files.
Building the language-modeling files
In the RNN-T framework, the modeling files we actually need for training and testing are tokens.txt, words.txt, and Linv.pt. We build them in the following steps:
Normalize the text and generate text
The text-normalization code used in this step is in text2token.py.
# Note: in Linux, you can install jq with the following command:
# 1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
# 2. chmod +x ./jq
# 3. cp jq /usr/bin
gunzip -c data/manifests/wenetspeech_supervisions_L.jsonl.gz \
| jq '.text' | sed 's/"//g' \
| ./local/text2token.py -t "char" > data/lang_char/text
text looks like this:
怎么样这些日子住得还习惯吧
挺好的
对了美静这段日子经常不和我们一起用餐
是不是对我回来有什么想法啊
哪有的事啊
她这两天挺累的身体也不太舒服
我让她多睡一会那就好如果要是觉得不方便
我就搬出去住
............
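If jq is unavailable, the jq/sed part of the pipeline above can be reproduced in a few lines of Python (a sketch of my own; text.raw is a hypothetical intermediate file whose contents should still be piped through ./local/text2token.py as above):
import gzip
import json

# extract the "text" field of every supervision, one utterance per line
with gzip.open("data/manifests/wenetspeech_supervisions_L.jsonl.gz", "rt", encoding="utf-8") as fin, \
     open("data/lang_char/text.raw", "w", encoding="utf-8") as fout:
    for line in fin:
        fout.write(json.loads(line)["text"] + "\n")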
Segmenting words and generating words.txt
Here we segment the Chinese sentences with jieba; see text2segments.py.
python3 ./local/text2segments.py \
--input-file data/lang_char/text \
--output-file data/lang_char/text_words_segmentation
cat data/lang_char/text_words_segmentation | sed 's/ /\n/g' \
| sort -u | sed '/^$/d' | uniq > data/lang_char/words_no_ids.txt
python3 ./local/prepare_words.py \
--input-file data/lang_char/words_no_ids.txt \
--output-file data/lang_char/words.txt
text_words_segmentation looks like this:
怎么样 这些 日子 住 得 还 习惯 吧
挺 好 的
对 了 美静 这段 日子 经常 不 和 我们 一起 用餐
是不是 对 我 回来 有 什么 想法 啊
哪有 的 事 啊
她 这 两天 挺累 的 身体 也 不 太 舒服
我 让 她 多 睡 一会 那就好 如果 要是 觉得 不 方便
我 就 搬出去 住
............
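At its core, text2segments.py performs jieba word segmentation. A minimal sketch (my own illustration; the exact split may differ across jieba versions and dictionaries):
import jieba

line = "怎么样这些日子住得还习惯吧"
print(" ".join(jieba.cut(line)))
# expected output, approximately: 怎么样 这些 日子 住 得 还 习惯 吧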
words_no_ids.txt looks like this:
............
阿
阿Q
阿阿虎
阿阿离
阿阿玛
阿阿毛
阿阿强
阿阿淑
阿安
............
words.txt looks like this:
............
阿 225
阿Q 226
阿阿虎 227
阿阿离 228
阿阿玛 229
阿阿毛 230
阿阿强 231
阿阿淑 232
阿安 233
............
Generating tokens.txt and lexicon.txt
The code that generates tokens.txt and lexicon.txt is in prepare_char.py.
python3 ./local/prepare_char.py \
--lang-dir data/lang_char
tokens.txt looks like this:
<blk> 0
<sos/eos> 1
<unk> 2
怎 3
么 4
样 5
这 6
些 7
日 8
子 9
............
lexicon.txt looks like this:
............
X光 X 光
X光线 X 光 线
X射线 X 射 线
Y Y
YC Y C
YS Y S
YY Y Y
Z Z
ZO Z O
ZSU Z S U
○ ○
一 一
一一 一 一
一一二 一 一 二
一一例 一 一 例
............
With this, the first stage is complete. The basic approach is similar for other datasets: in the data preparation and processing stage we mainly do two things: prepare the audio files and extract features, and build the language-modeling files.
The example here is Mandarin Chinese with characters as the modeling unit. For English data, BPE is normally used as the modeling unit; see egs/librispeech/ASR/prepare.sh.
Converting to ncnn (run on the server)
See https://k2-fsa.github.io/icefall/model-export/export-ncnn-zipformer.html
Building ncnn
Create a directory and clone the code. Run in a terminal:
cd /workspace/icefall/egs/wenetspeech/ASR
mkdir -p open-source
cd open-source
git clone https://github.com/csukuangfj/ncnn
cd ncnn
git submodule update --recursive --init
Build it. Run in a terminal:
cd /workspace/icefall/egs/wenetspeech/ASR/open-source/ncnn
mkdir -p build-wheel
cd build-wheel
cmake \
-DCMAKE_BUILD_TYPE=Release \
-DNCNN_PYTHON=ON \
-DNCNN_BUILD_BENCHMARK=OFF \
-DNCNN_BUILD_EXAMPLES=OFF \
-DNCNN_BUILD_TOOLS=ON \
..
make -j4
cd /workspace/icefall/egs/wenetspeech/ASR/open-source/ncnn
export PYTHONPATH=$PWD/python:$PYTHONPATH
export PATH=$PWD/tools/pnnx/build/src:$PATH
export PATH=$PWD/build-wheel/tools/quantize:$PATH
export TORCH_CUDA_ARCH_LIST="8.0 8.6 8.9 9.0"
cd tools/pnnx
vim CMakeLists.txt
mkdir build
cd build
cmake ..
make -j4
./src/pnnx
# Note: the author built with C++14; if compiling against PyTorch demands C++17, edit CMakeLists.txt accordingly
Converting the model
Disclaimer: there is no official code for converting a wenetspeech-trained model to ncnn, but someone has published an already-converted ncnn model, so we use it directly (out of necessity).
That author adapted https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless7_streaming, trained a model, and converted it to ncnn, but did not release the conversion code. In principle all that is needed is export-for-ncnn.py; unfortunately the wenetspeech recipe does not ship that file.
dir=/workspace/icefall/egs/wenetspeech/ASR/pruned_transducer_stateless5
./pruned_transducer_stateless5/export-for-ncnn.py \
--tokens $dir/data/lang_char/tokens.txt \
--exp-dir $dir/exp \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--decode-chunk-len 32 \
--num-left-chunks 4 \
--num-encoder-layers "2,4,3,2,4" \
--feedforward-dims "1024,1024,2048,2048,1024" \
--nhead "8,8,8,8,8" \
--encoder-dims "384,384,384,384,384" \
--attention-dims "192,192,192,192,192" \
--encoder-unmasked-dims "256,256,256,256,256" \
--zipformer-downsampling-factors "1,2,4,8,2" \
--cnn-module-kernels "31,31,31,31,31" \
--decoder-dim 512 \
--joiner-dim 512
Download the ncnn model converted by others:
mkdir sherpa-ncnn
wget https://github.com/k2-fsa/sherpa-ncnn/releases/download/models/sherpa-ncnn-streaming-zipformer-zh-14M-2023-02-23.tar.bz2
tar xvf sherpa-ncnn-streaming-zipformer-zh-14M-2023-02-23.tar.bz2
Deploying to a Raspberry Pi (ncnn)
Download the model
cd ~
mkdir sherpa-ncnn
cd sherpa-ncnn
wget https://github.com/k2-fsa/sherpa-ncnn/releases/download/models/sherpa-ncnn-streaming-zipformer-zh-14M-2023-02-23.tar.bz2
tar xvf sherpa-ncnn-streaming-zipformer-zh-14M-2023-02-23.tar.bz2
Set up sherpa-ncnn
pip install sherpa-ncnn
pip install sounddevice
pip install playsound
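Before the microphone demo, it is worth verifying the model on a wav file. A minimal sketch, assuming a 16 kHz, 16-bit mono test.wav (a hypothetical file you supply) and the model directory extracted above:
import wave

import numpy as np
import sherpa_ncnn

d = "./sherpa-ncnn-streaming-zipformer-zh-14M-2023-02-23"
recognizer = sherpa_ncnn.Recognizer(
    tokens=f"{d}/tokens.txt",
    encoder_param=f"{d}/encoder_jit_trace-pnnx.ncnn.param",
    encoder_bin=f"{d}/encoder_jit_trace-pnnx.ncnn.bin",
    decoder_param=f"{d}/decoder_jit_trace-pnnx.ncnn.param",
    decoder_bin=f"{d}/decoder_jit_trace-pnnx.ncnn.bin",
    joiner_param=f"{d}/joiner_jit_trace-pnnx.ncnn.param",
    joiner_bin=f"{d}/joiner_jit_trace-pnnx.ncnn.bin",
    num_threads=4,
)

with wave.open("test.wav") as f:  # assumed: 16 kHz, 16-bit, mono
    assert f.getframerate() == recognizer.sample_rate
    assert f.getnchannels() == 1 and f.getsampwidth() == 2
    samples = np.frombuffer(f.readframes(f.getnframes()), dtype=np.int16)

recognizer.accept_waveform(recognizer.sample_rate, samples.astype(np.float32) / 32768)
# feed trailing silence so the final frames are flushed, then signal end of input
recognizer.accept_waveform(recognizer.sample_rate, np.zeros(8000, dtype=np.float32))
recognizer.input_finished()
print(recognizer.text)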
Write the code
Adapt the official example https://github.com/k2-fsa/sherpa-ncnn/blob/master/python-api-examples/speech-recognition-from-microphone.py and run it. The full code:
import sys

try:
    import sounddevice as sd
    from playsound import playsound
    import sherpa_ncnn
except ImportError:
    print("Please install sounddevice first. You can use")
    print()
    print("  pip install sounddevice")
    print()
    print("to install it")
    sys.exit(-1)


def create_recognizer():
    # Please replace the model files if needed.
    # See https://k2-fsa.github.io/sherpa/ncnn/pretrained_models/index.html
    # for download links.
    # The paths below point at the model downloaded above.
    d = "./sherpa-ncnn-streaming-zipformer-zh-14M-2023-02-23"
    recognizer = sherpa_ncnn.Recognizer(
        tokens=f"{d}/tokens.txt",
        encoder_param=f"{d}/encoder_jit_trace-pnnx.ncnn.param",
        encoder_bin=f"{d}/encoder_jit_trace-pnnx.ncnn.bin",
        decoder_param=f"{d}/decoder_jit_trace-pnnx.ncnn.param",
        decoder_bin=f"{d}/decoder_jit_trace-pnnx.ncnn.bin",
        joiner_param=f"{d}/joiner_jit_trace-pnnx.ncnn.param",
        joiner_bin=f"{d}/joiner_jit_trace-pnnx.ncnn.bin",
        num_threads=4,
    )
    return recognizer


def main():
    print("Started! Please speak")
    recognizer = create_recognizer()
    sample_rate = recognizer.sample_rate
    samples_per_read = int(0.1 * sample_rate)  # 0.1 second = 100 ms
    last_result = ""
    with sd.InputStream(
        channels=1, dtype="float32", samplerate=sample_rate
    ) as s:
        while True:
            samples, _ = s.read(samples_per_read)  # a blocking read
            samples = samples.reshape(-1)
            recognizer.accept_waveform(sample_rate, samples)
            result = recognizer.text
            if "SB" in result:  # modified here: react when a keyword appears
                print("shut down")
            if last_result != result:
                last_result = result
                print(result)


if __name__ == "__main__":
    devices = sd.query_devices()
    print(devices)
    default_input_device_idx = sd.default.device[0]
    print(f'Use default device: {devices[default_input_device_idx]["name"]}')
    try:
        main()
    except KeyboardInterrupt:
        print("\nCaught Ctrl + C. Exiting")
Converting to ONNX (run on the server)
See https://k2-fsa.github.io/icefall/model-export/export-onnx.html
Install the ONNX conversion tools
pip install onnx
Run the conversion
cd /workspace/icefall/egs/wenetspeech/ASR/pruned_transducer_stateless5
python3 export-onnx-streaming.py --tokens ./data/lang_char/tokens.txt --use-averaged-model 0 --epoch 99 --avg 1 --exp-dir ./exp/
Go to /workspace/icefall/egs/wenetspeech/ASR/pruned_transducer_stateless5/exp and you will see the exported models: encoder-epoch-99-avg-1.onnx, decoder-epoch-99-avg-1.onnx, and joiner-epoch-99-avg-1.onnx.
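To sanity-check the export, here is a minimal sketch of my own using onnxruntime (assuming pip install onnxruntime; the file names match those used later in this article):
import onnxruntime as ort

for name in ["encoder-epoch-99-avg-1.onnx",
             "decoder-epoch-99-avg-1.onnx",
             "joiner-epoch-99-avg-1.onnx"]:
    sess = ort.InferenceSession(f"./exp/{name}", providers=["CPUExecutionProvider"])
    print(name)
    for inp in sess.get_inputs():
        # print each input's name, shape, and element type
        print("  ", inp.name, inp.shape, inp.type)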
Deployment (Raspberry Pi)
See https://github.com/k2-fsa/sherpa-onnx
cd ~
git clone https://github.com/k2-fsa/sherpa-onnx
cd sherpa-onnx
python3 setup.py install
python3 -c "import sherpa_onnx; print(sherpa_onnx.__file__)"
You will see something like:
/Users/fangjun/py38/lib/python3.8/site-packages/sherpa_onnx/__init__.py
Install your own model files
cd ~/sherpa-onnx
# Extract the model folder here (no need to go looking for it; it is paid content)
Modify ~/sherpa-onnx/python-api-examples/speech-recognition-from-microphone.py as follows:
#!/usr/bin/env python3
# Real-time speech recognition from a microphone with sherpa-onnx Python API
#
# Please refer to
# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
# to download pre-trained models
import argparse
import sys
from pathlib import Path
import re
from typing import List

try:
    import sounddevice as sd
except ImportError:
    print("Please install sounddevice first. You can use")
    print()
    print("  pip install sounddevice")
    print()
    print("to install it")
    sys.exit(-1)

import sherpa_onnx

# Words to flag in the recognition result
dirty_words = ['傻逼', '白痴', '操']


# Profanity detection
def detect_dirty_words(text):
    # Match any of the listed words with a regular expression
    pattern = r'|'.join(map(re.escape, dirty_words))
    if re.search(pattern, text, re.IGNORECASE):
        print('(警告: 检测到脏话!)', end='')


def assert_file_exists(filename: str):
    assert Path(filename).is_file(), (
        f"{filename} does not exist!\n"
        "Please refer to "
        "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
    )


def get_args():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--tokens",
        type=str,
        default="./model/tokens.txt",
        help="Path to tokens.txt",
    )

    parser.add_argument(
        "--encoder",
        type=str,
        default="./model/encoder-epoch-99-avg-1.onnx",
        help="Path to the encoder model",
    )

    parser.add_argument(
        "--decoder",
        type=str,
        default="./model/decoder-epoch-99-avg-1.onnx",
        help="Path to the decoder model",
    )

    parser.add_argument(
        "--joiner",
        type=str,
        default="./model/joiner-epoch-99-avg-1.onnx",
        help="Path to the joiner model",
    )

    parser.add_argument(
        "--decoding-method",
        type=str,
        default="greedy_search",
        help="Valid values are greedy_search and modified_beam_search",
    )

    parser.add_argument(
        "--max-active-paths",
        type=int,
        default=4,
        help="""Used only when --decoding-method is modified_beam_search.
        It specifies number of active paths to keep during decoding.
        """,
    )

    parser.add_argument(
        "--provider",
        type=str,
        default="cpu",
        help="Valid values: cpu, cuda, coreml",
    )

    parser.add_argument(
        "--hotwords-file",
        type=str,
        default="",
        help="""
        The file containing hotwords, one words/phrases per line, and for each
        phrase the bpe/cjkchar are separated by a space. For example:

        ▁HE LL O ▁WORLD
        你 好 世 界
        """,
    )

    parser.add_argument(
        "--hotwords-score",
        type=float,
        default=1.5,
        help="""
        The hotword score of each token for biasing word/phrase. Used only if
        --hotwords-file is given.
        """,
    )

    parser.add_argument(
        "--blank-penalty",
        type=float,
        default=0.0,
        help="""
        The penalty applied on blank symbol during decoding.
        Note: It is a positive value that would be applied to logits like
        this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
        [batch_size, vocab] and blank id is 0).
        """,
    )

    return parser.parse_args()


def create_recognizer(args):
    assert_file_exists(args.encoder)
    assert_file_exists(args.decoder)
    assert_file_exists(args.joiner)
    assert_file_exists(args.tokens)
    # Please replace the model files if needed.
    # See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
    # for download links.
    recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
        tokens=args.tokens,
        encoder=args.encoder,
        decoder=args.decoder,
        joiner=args.joiner,
        num_threads=1,
        sample_rate=16000,
        feature_dim=80,
        decoding_method=args.decoding_method,
        max_active_paths=args.max_active_paths,
        provider=args.provider,
        hotwords_file=args.hotwords_file,
        hotwords_score=args.hotwords_score,
        blank_penalty=args.blank_penalty,
    )
    return recognizer


def main():
    args = get_args()
    devices = sd.query_devices()
    if len(devices) == 0:
        print("No microphone devices found")
        sys.exit(0)
    print(devices)
    default_input_device_idx = sd.default.device[0]
    print(f'Use default device: {devices[default_input_device_idx]["name"]}')

    recognizer = create_recognizer(args)
    print("Started! Please speak")

    # The model is using 16 kHz, we use 48 kHz here to demonstrate that
    # sherpa-onnx will do resampling inside.
    sample_rate = 48000
    samples_per_read = int(0.1 * sample_rate)  # 0.1 second = 100 ms
    last_result = ""
    stream = recognizer.create_stream()
    with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
        while True:
            samples, _ = s.read(samples_per_read)  # a blocking read
            samples = samples.reshape(-1)
            stream.accept_waveform(sample_rate, samples)
            while recognizer.is_ready(stream):
                recognizer.decode_stream(stream)
            result = recognizer.get_result(stream)
            if last_result != result:
                last_result = result
                print("\r{}".format(result), end="", flush=True)
                detect_dirty_words(result)


if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        print("\nCaught Ctrl + C. Exiting")
Run:
cd ~/sherpa-onnx
python3 ./python-api-examples/speech-recognition-from-microphone.py
Testing on a PC
Install sherpa-onnx
cd ~
mkdir sherpa-onnx
cd sherpa-onnx
git clone https://github.com/k2-fsa/sherpa-onnx  # to get a copy of the example code
pip install sherpa-onnx  # to install the package
python3 -c "import sherpa_onnx; print(sherpa_onnx.__file__)"  # sanity check
Load the model
cd ~/sherpa-onnx
# Extract my model files here (shared above; they cannot be linked again at this point)
Modify ~/sherpa-onnx/sherpa-onnx/python-api-examples/speech-recognition-from-microphone.py as follows:
#!/usr/bin/env python3
# Real-time speech recognition from a microphone with sherpa-onnx Python API
#
# Please refer to
# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
# to download pre-trained models
import argparse
import sys
from pathlib import Path
from typing import List
import re

# A list of words to flag
dirty_words = ['傻逼', '白痴', '操']


# Profanity detection
def detect_dirty_words(text):
    # Match any of the listed words with a regular expression
    pattern = r'|'.join(map(re.escape, dirty_words))
    if re.search(pattern, text, re.IGNORECASE):
        print('警告: 检测到脏话!')
    else:
        print('识别结果没有脏话')


try:
    import sounddevice as sd
except ImportError:
    print("Please install sounddevice first. You can use")
    print()
    print("  pip install sounddevice")
    print()
    print("to install it")
    sys.exit(-1)

import sherpa_onnx


def assert_file_exists(filename: str):
    assert Path(filename).is_file(), (
        f"{filename} does not exist!\n"
        "Please refer to "
        "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
    )


def get_args():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--tokens",
        type=str,
        required=True,
        help="Path to tokens.txt",
    )

    parser.add_argument(
        "--encoder",
        type=str,
        required=True,
        help="Path to the encoder model",
    )

    parser.add_argument(
        "--decoder",
        type=str,
        required=True,
        help="Path to the decoder model",
    )

    parser.add_argument(
        "--joiner",
        type=str,
        required=True,
        help="Path to the joiner model",
    )

    parser.add_argument(
        "--decoding-method",
        type=str,
        default="greedy_search",
        help="Valid values are greedy_search and modified_beam_search",
    )

    parser.add_argument(
        "--max-active-paths",
        type=int,
        default=4,
        help="""Used only when --decoding-method is modified_beam_search.
        It specifies number of active paths to keep during decoding.
        """,
    )

    parser.add_argument(
        "--provider",
        type=str,
        default="cpu",
        help="Valid values: cpu, cuda, coreml",
    )

    parser.add_argument(
        "--hotwords-file",
        type=str,
        default="",
        help="""
        The file containing hotwords, one words/phrases per line, and for each
        phrase the bpe/cjkchar are separated by a space. For example:

        ▁HE LL O ▁WORLD
        你 好 世 界
        """,
    )

    parser.add_argument(
        "--hotwords-score",
        type=float,
        default=1.5,
        help="""
        The hotword score of each token for biasing word/phrase. Used only if
        --hotwords-file is given.
        """,
    )

    parser.add_argument(
        "--blank-penalty",
        type=float,
        default=0.0,
        help="""
        The penalty applied on blank symbol during decoding.
        Note: It is a positive value that would be applied to logits like
        this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
        [batch_size, vocab] and blank id is 0).
        """,
    )

    return parser.parse_args()


def create_recognizer(args):
    assert_file_exists(args.encoder)
    assert_file_exists(args.decoder)
    assert_file_exists(args.joiner)
    assert_file_exists(args.tokens)
    # Please replace the model files if needed.
    # See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
    # for download links.
    recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
        tokens=args.tokens,
        encoder=args.encoder,
        decoder=args.decoder,
        joiner=args.joiner,
        num_threads=1,
        sample_rate=16000,
        feature_dim=80,
        decoding_method=args.decoding_method,
        max_active_paths=args.max_active_paths,
        provider=args.provider,
        hotwords_file=args.hotwords_file,
        hotwords_score=args.hotwords_score,
        blank_penalty=args.blank_penalty,
    )
    return recognizer


def main():
    args = get_args()
    devices = sd.query_devices()
    if len(devices) == 0:
        print("No microphone devices found")
        sys.exit(0)
    print(devices)
    default_input_device_idx = sd.default.device[0]
    print(f'Use default device: {devices[default_input_device_idx]["name"]}')

    recognizer = create_recognizer(args)
    print("Started! Please speak")

    # The model is using 16 kHz, we use 48 kHz here to demonstrate that
    # sherpa-onnx will do resampling inside.
    sample_rate = 48000
    samples_per_read = int(0.1 * sample_rate)  # 0.1 second = 100 ms
    last_result = ""
    stream = recognizer.create_stream()
    with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
        while True:
            samples, _ = s.read(samples_per_read)  # a blocking read
            samples = samples.reshape(-1)
            stream.accept_waveform(sample_rate, samples)
            while recognizer.is_ready(stream):
                recognizer.decode_stream(stream)
            result = recognizer.get_result(stream)
            if last_result != result:
                last_result = result
                print("\r{}".format(result), end="", flush=True)
                # Check the result for profanity
                detect_dirty_words(result)


if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        print("\nCaught Ctrl + C. Exiting")
Run the code
python3 .\sherpa-onnx\python-api-examples\speech-recognition-from-microphone.py --tokens=./model/tokens.txt --encoder=model/encoder-epoch-99-avg-1.onnx --decoder=model/decoder-epoch-99-avg-1.onnx --joiner=model/joiner-epoch-99-avg-1.onnx
The recognition results are printed to the terminal (screenshot omitted).
PS: the code is paid content.
Full code link: https://pan.baidu.com/s/1uPurik6NgGplEUSdoMqSvg?