-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
189 lines (160 loc) · 7.65 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
from logging import captureWarnings
import os, time, argparse
# GPIO - Pi Buttons
from gpiozero import Button
# Audio
from gtts import gTTS
from pydub import AudioSegment
from pydub.playback import play
from recording import record_to_file
from voice_recognition import VoiceRecognition
# TFLite detection
from TFLite_detection_webcam import initialize_detector, safari_mode, query_mode
class VMobi:
    """Class that represents the VMobi system as a whole.

    Constructing an instance immediately runs :meth:`main`, whose loop never
    exits on its own, so the constructor does not return during normal
    operation.
    """

    def __init__(self, args, lang="en"):
        self.args = args                    # Parsed command-line arguments
        self.MODEL_DIR = args.modeldir      # Directory of the .tflite file and label-map file
        self.RESOLUTION = args.resolution   # Camera resolution string, e.g. "1280x720"
        self.USE_EDGETPU = args.edgetpu     # Flag to use the Google Coral Edge TPU
        self.lang = lang                    # Language code used for the TTS voice
        self.main()                         # Runs on the Raspberry Pi with GPIO buttons
        # self.test()                       # Use instead of main() to test TTS on a PC

    def test(self):
        """Smoke-test TTS on a PC: print every category, then speak each one."""
        categories = self.get_all_categories()
        print(categories)
        time.sleep(1)
        print("Now playing the audio")
        play_voice("Query mode activated. Which category do you want?", self.lang)
        for cat in categories:
            play_voice(cat, self.lang)

    def main(self):
        """Main loop: run Safari mode and switch to Query mode on request."""
        print("On main function!")
        # List of the category names the model can detect.
        self.categories = self.get_all_categories()
        print(f"Got all categories: {self.categories}")
        # Query button wired between GPIO2 and Ground.
        # Watch out for the connections shown in 'pin_layout.svg'.
        self.query_button = Button(2)
        detector_args = initialize_detector(self.args)
        while True:
            # NOTE(review): safari_mode presumably returns > 0 when the user asks
            # for Query mode via the query button -- confirm in TFLite_detection_webcam.
            if safari_mode(detector_args, self.query_button) > 0:
                query_cat = self.query_mode_type2()  # Category chosen by voice
                query_mode(detector_args, query_cat)

    def query_mode_selection(self):
        """[Type 1] Query mode that works only with buttons.

        The up (GPIO18) and down (GPIO23) buttons cycle through
        ``self.categories``, wrapping around at either end. Each newly
        highlighted category is spoken aloud, and the query button confirms
        the current one.

        Returns:
            The selected category name.

        Raises:
            ValueError: if ``self.categories`` is empty.
        """
        if not self.categories:
            raise ValueError("No categories available to select from")
        count = len(self.categories)
        up_button = Button(18)    # GPIO18 -> Up button
        down_button = Button(23)  # GPIO23 -> Down button
        try:
            print("Entering query mode only with buttons. (Type 1)")
            play_voice("Query mode activated. Which category do you want?", self.lang)
            index = 0
            play_voice(self.categories[index], self.lang)  # Read the first category
            while True:
                moved = False
                if up_button.is_pressed:
                    print("Up Button was pressed!")
                    index = (index + 1) % count  # Wrap from the last to the first
                    moved = True
                if down_button.is_pressed:
                    print("Down Button was pressed!")
                    index = (index - 1) % count  # Wrap from the first to the last
                    moved = True
                if self.query_button.is_pressed:
                    # The user chose self.categories[index].
                    selection = self.categories[index]
                    print("Query Button was pressed!")
                    break
                if moved:
                    play_voice(self.categories[index], self.lang)
                else:
                    time.sleep(0.05)  # Avoid busy-spinning while no button is pressed
        finally:
            # Release the pins so this mode can be entered again; gpiozero will
            # not open a pin that is still held by another device object.
            up_button.close()
            down_button.close()
        play_voice(f"You chose the category: {selection}", self.lang)
        return selection

    def query_mode_type2(self):
        """[Type 2] Query mode that uses voice recognition and only the query button.

        Keeps asking until the recognised answer is one of ``self.categories``.
        Each retry records a fresh answer and recognises it again.

        Returns:
            The selected category name.
        """
        print("Entering query mode with voice recognition. (Type 2)")
        qmode = VoiceRecognition()
        qmode.greetings()
        record_to_file("output.wav")
        categ = qmode.speech_recog()
        while categ is None or categ in ("list", "least") or categ not in self.categories:
            if categ is None:
                # Nothing was recognised: ask the user to repeat.
                qmode.repeat("category")
            elif categ in ("list", "least"):
                # "least" is presumably a common mis-recognition of "list".
                qmode.list_categories(self.categories)
            else:
                qmode.play_voice("Category not in dataset. Which category do you want?")
            # Record and recognise a new answer; without this the loop never ends.
            record_to_file("output.wav")
            categ = qmode.speech_recog()
        # TODO: implement a text-recognition mode (e.g. when categ == "text").
        # TODO: element-level selection inside the chosen category is not implemented.
        play_voice(f"You chose the category: {categ}", self.lang)
        return categ

    def get_all_categories(self):
        """Return the category names listed in the model's label-map file.

        Searches ``self.MODEL_DIR`` top-down for the first file whose name
        contains the label-map name. That name is ``args.labels`` when set,
        otherwise ``"labelmap.txt"``. Returns one category per non-empty line
        and skips placeholder lines that contain ``"?"``.

        Raises:
            FileNotFoundError: if no label-map file exists under MODEL_DIR.
        """
        args = getattr(self, "args", None)
        labels_name = getattr(args, "labels", None) or "labelmap.txt"
        labelmap_path = None
        for root, _dirs, files in os.walk(self.MODEL_DIR):
            match = next((name for name in files if labels_name in name), None)
            if match is not None:
                # Join with the directory it was actually found in.
                labelmap_path = os.path.join(root, match)
                break
        if labelmap_path is None:
            raise FileNotFoundError(
                f"No '{labels_name}' file found under {self.MODEL_DIR!r}"
            )
        categories = []
        with open(labelmap_path, "r", encoding="utf-8") as labelmap:
            for line in labelmap:
                if "?" in line:
                    continue
                name = line.rstrip("\r\n")
                if name:  # Blank lines are not categories (and gTTS cannot speak "")
                    categories.append(name)
        return categories
def play_voice(mText, lang="en"):
    """Speak ``mText`` aloud using Google Text-to-Speech.

    The speech is synthesised with gTTS and written to a unique temporary MP3
    file. The file is played with pydub and always removed afterwards, even
    if synthesis or playback fails.

    Args:
        mText: Text to speak. Empty or whitespace-only text is skipped, since
            gTTS raises an error when given nothing to speak.
        lang: gTTS language code (default ``"en"``).
    """
    import tempfile  # Only needed here

    print(f"[play_voice] now playing: '{mText}'")
    if not mText or not mText.strip():
        return
    # gTTS writes MP3 data, so name and decode the file as MP3. A unique temp
    # file also avoids clobbering a "voice.wav" in the working directory.
    fd, path = tempfile.mkstemp(suffix=".mp3")
    os.close(fd)
    try:
        gTTS(text=mText, lang=lang, slow=False).save(path)
        play(AudioSegment.from_file(path, format="mp3"))
    finally:
        os.remove(path)
if __name__ == '__main__':
    # Command-line options; the detector receives the parsed namespace as-is.
    parser = argparse.ArgumentParser()
    parser.add_argument('--modeldir', help='Folder the .tflite file is located in',
                        default="Sample_TFLite_model/")
    parser.add_argument('--graph', help='Name of the .tflite file, if different than detect.tflite',
                        default='detect.tflite')
    parser.add_argument('--labels', help='Name of the labelmap file, if different than labelmap.txt',
                        default='labelmap.txt')
    # type=float so a user-supplied threshold has the same type as the default.
    parser.add_argument('--threshold', help='Minimum confidence threshold for displaying detected objects',
                        type=float, default=0.5)
    parser.add_argument('--resolution', help='Desired webcam resolution in WxH. If the webcam does not support the resolution entered, errors may occur.',
                        default='1280x720')
    parser.add_argument('--edgetpu', help='Use Coral Edge TPU Accelerator to speed up detection',
                        action='store_true')
    parser.add_argument('--safari', help='Start Safari Mode', action='store_true')
    parser.add_argument('--query', help='Start Query Mode', default='?')
    args = parser.parse_args()
    # Constructing VMobi starts its main loop, which does not return on its own.
    helper = VMobi(args)