-
Notifications
You must be signed in to change notification settings - Fork 32
/
snowboydecoder.py
184 lines (155 loc) · 6.67 KB
/
snowboydecoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
#!/usr/bin/env python
import collections
import logging
import os
import threading
import time
import wave

import pyaudio
import snowboydetect
# Module-wide logger. Importing this module configures logging as a side
# effect: basicConfig() sets up the root logger and the "snowboy" logger is
# switched to DEBUG.
logger = logging.getLogger("snowboy")
logger.setLevel(logging.DEBUG)
logging.basicConfig()

# Bundled resources are resolved relative to the directory holding this file,
# so they are found regardless of the current working directory.
TOP_DIR = os.path.dirname(os.path.abspath(__file__))
RESOURCE_FILE, DETECT_DING, DETECT_DONG = (
    os.path.join(TOP_DIR, relative_path)
    for relative_path in (
        "resources/common.res",  # resource file handed to the detector
        "resources/ding.wav",    # default sound played on detection
        "resources/dong.wav",    # alternative notification sound
    )
)
class RingBuffer(object):
    """Bounded FIFO buffer holding audio bytes delivered by PortAudio.

    At most ``size`` items are kept; when more data arrives than fits, the
    oldest items are discarded (``collections.deque`` with ``maxlen``).

    The buffer is filled from the audio stream callback and drained from the
    detection loop. Those run on different threads, so every access goes
    through a lock.
    """

    def __init__(self, size=4096):
        """Create an empty buffer.

        :param int size: maximum number of items kept in the buffer.
        """
        self._buf = collections.deque(maxlen=size)
        # Without the lock, data appended between the snapshot and the
        # clear() in get() would be silently dropped.
        self._lock = threading.Lock()

    def extend(self, data):
        """Adds data to the end of buffer"""
        with self._lock:
            self._buf.extend(data)

    def get(self):
        """Retrieves data from the beginning of buffer and clears it

        :return: everything buffered so far, as ``bytes`` (empty when
                 nothing is buffered).
        """
        with self._lock:
            # Snapshot and clear must happen atomically with respect to
            # extend(), otherwise freshly appended audio could be lost.
            tmp = bytes(bytearray(self._buf))
            self._buf.clear()
        return tmp
def play_audio_file(fname=DETECT_DING):
    """Simple callback function to play a wave file. By default it plays
    a Ding sound.

    The call blocks until the sound has been written to the output stream.
    The wave file, the output stream and the PyAudio instance are always
    released, even if opening the stream or playback raises.

    :param str fname: wave file name
    :return: None
    """
    ding_wav = wave.open(fname, 'rb')
    try:
        # Read everything needed up front so the file can be closed before
        # the audio device is touched.
        ding_data = ding_wav.readframes(ding_wav.getnframes())
        sample_width = ding_wav.getsampwidth()
        num_channels = ding_wav.getnchannels()
        frame_rate = ding_wav.getframerate()
    finally:
        ding_wav.close()

    audio = pyaudio.PyAudio()
    try:
        stream_out = audio.open(
            format=audio.get_format_from_width(sample_width),
            channels=num_channels,
            rate=frame_rate, input=False, output=True)
        try:
            stream_out.start_stream()
            stream_out.write(ding_data)
            # NOTE(review): presumably gives the tail of the sound time to
            # play out before the stream is stopped -- confirm.
            time.sleep(0.2)
            stream_out.stop_stream()
        finally:
            stream_out.close()
    finally:
        audio.terminate()
class HotwordDetector(object):
    """
    Snowboy decoder to detect whether a keyword specified by `decoder_model`
    exists in an audio input stream.

    :param decoder_model: decoder model file path, a string or a list/tuple
                          of strings
    :param resource: resource file path.
    :param sensitivity: decoder sensitivity, a float or a list/tuple of
                        floats. The bigger the value, the more sensitive the
                        decoder. If an empty list is provided, then the
                        default sensitivity in the model will be used.
    :param audio_gain: multiply input volume by this factor.
    :raises AssertionError: if the number of sensitivities does not match
                            the number of hotwords reported by the detector.
    """

    def __init__(self, decoder_model,
                 resource=RESOURCE_FILE,
                 sensitivity=[],
                 audio_gain=1):
        # Accept a single value as well as a list/tuple for both arguments.
        if not isinstance(decoder_model, (list, tuple)):
            decoder_model = [decoder_model]
        if not isinstance(sensitivity, (list, tuple)):
            sensitivity = [sensitivity]
        # Work on a private copy: neither the shared default list nor a
        # sequence owned by the caller is ever aliased or modified.
        sensitivity = list(sensitivity)
        model_str = ",".join(decoder_model)

        self.detector = snowboydetect.SnowboyDetect(
            resource_filename=resource.encode(), model_str=model_str.encode())
        self.detector.SetAudioGain(audio_gain)
        self.num_hotwords = self.detector.NumHotwords()

        # A single sensitivity given for several models applies to every
        # hotword.
        if len(decoder_model) > 1 and len(sensitivity) == 1:
            sensitivity = sensitivity * self.num_hotwords
        if sensitivity:
            # Raised explicitly (not via `assert`) so the check survives
            # `python -O`; the exception type is unchanged for callers.
            if self.num_hotwords != len(sensitivity):
                raise AssertionError(
                    "number of hotwords in decoder_model (%d) and sensitivity "
                    "(%d) does not match" % (self.num_hotwords,
                                             len(sensitivity)))
            sensitivity_str = ",".join(str(t) for t in sensitivity)
            self.detector.SetSensitivity(sensitivity_str.encode())

        # Capacity is channels * sample rate * 5 buffer entries.
        self.ring_buffer = RingBuffer(
            self.detector.NumChannels() * self.detector.SampleRate() * 5)

        # Created by start(); kept as None until then so terminate() is safe
        # to call at any time.
        self.audio = None
        self.stream_in = None

    def start(self, detected_callback=play_audio_file,
              interrupt_check=lambda: False,
              sleep_time=0.03):
        """
        Start the voice detector. It opens the audio input stream and then
        loops: buffered audio is handed to the detector, and whenever the
        buffer is empty the loop sleeps for `sleep_time` seconds. If a
        keyword is detected, the corresponding function in
        `detected_callback` is called, which can be a single function
        (single model) or a list of callback functions (multiple models).
        Every loop it also calls `interrupt_check` -- if it returns True,
        then breaks from the loop and return.

        The input stream stays open when this method returns; call
        `terminate()` to release it.

        :param detected_callback: a function or list/tuple of functions. The
                                  number of items must match the number of
                                  models in `decoder_model`. A single
                                  callback is used for every hotword. An
                                  entry may be None to ignore that hotword.
        :param interrupt_check: a function that returns True if the main loop
                                needs to stop.
        :param float sleep_time: how much time in second the loop waits when
                                 no audio is buffered.
        :return: None
        :raises AssertionError: if the number of callbacks does not match the
                                number of hotwords.
        """
        def audio_callback(in_data, frame_count, time_info, status):
            # Called by PyAudio with freshly captured audio: stash it for
            # the detection loop below.
            self.ring_buffer.extend(in_data)
            # The stream is input-only; hand back silence of equal length,
            # as bytes on both Python 2 and Python 3.
            play_data = b"\x00" * len(in_data)
            return play_data, pyaudio.paContinue

        self.audio = pyaudio.PyAudio()
        self.stream_in = self.audio.open(
            input=True, output=False,
            # Integer division: the sample width must be an int, and `/`
            # yields a float on Python 3.
            format=self.audio.get_format_from_width(
                self.detector.BitsPerSample() // 8),
            channels=self.detector.NumChannels(),
            rate=self.detector.SampleRate(),
            frames_per_buffer=2048,
            stream_callback=audio_callback)

        if interrupt_check():
            logger.debug("detect voice return")
            return

        if not isinstance(detected_callback, (list, tuple)):
            detected_callback = [detected_callback]
        # Copy before replicating so a list passed in by the caller is never
        # modified in place.
        callbacks = list(detected_callback)
        if len(callbacks) == 1 and self.num_hotwords > 1:
            callbacks = callbacks * self.num_hotwords

        # Raised explicitly (not via `assert`) so the check survives
        # `python -O`; the exception type is unchanged for callers.
        if self.num_hotwords != len(callbacks):
            raise AssertionError(
                "Error: hotwords in your models (%d) do not match the number of "
                "callbacks (%d)" % (self.num_hotwords, len(callbacks)))

        logger.debug("detecting...")

        while True:
            if interrupt_check():
                logger.debug("detect voice break")
                break
            data = self.ring_buffer.get()
            if len(data) == 0:
                # Nothing captured yet: wait instead of spinning.
                time.sleep(sleep_time)
                continue

            ans = self.detector.RunDetection(data)
            if ans == -1:
                logger.warning("Error initializing streams or reading audio data")
            elif ans > 0:
                # A positive result is the 1-based index of the hotword.
                message = "Keyword " + str(ans) + " detected at time: "
                message += time.strftime("%Y-%m-%d %H:%M:%S",
                                         time.localtime(time.time()))
                logger.info(message)
                callback = callbacks[ans-1]
                if callback is not None:
                    callback()
            # Any other return value is ignored.

        logger.debug("finished.")

    def terminate(self):
        """
        Terminate audio stream. Users cannot call start() again to detect.

        Safe to call when start() was never called, and safe to call more
        than once.

        :return: None
        """
        # Detach first so a second call (or a failure below) cannot touch
        # already released objects.
        stream, self.stream_in = self.stream_in, None
        audio, self.audio = self.audio, None
        try:
            if stream is not None:
                stream.stop_stream()
                stream.close()
        finally:
            if audio is not None:
                audio.terminate()