Skip to content

Commit

Permalink
Conversation mode, mic input, eos token for faster chat, debug command, redo command, output streaming
Browse files Browse the repository at this point in the history
  • Loading branch information
FontaineRiant committed Apr 17, 2024
1 parent bb52d51 commit 2c2aed9
Show file tree
Hide file tree
Showing 7 changed files with 301 additions and 104 deletions.
49 changes: 49 additions & 0 deletions audio/stt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import os
import sys
from contextlib import contextmanager
from whisper_mic import WhisperMic, get_logger

# Hide error output from ALSA, JACK... (pyaudio)
@contextmanager
def ignoreStderr():
    """Temporarily point file descriptor 2 (stderr) at os.devnull.

    Silences native-library chatter (ALSA, JACK, pyaudio) that writes
    straight to the stderr fd and therefore bypasses sys.stderr. The
    original descriptor is restored when the context exits, even on error.
    """
    saved_fd = os.dup(2)          # keep a handle on the real stderr
    sys.stderr.flush()            # drain buffered Python-side output first
    null_fd = os.open(os.devnull, os.O_WRONLY)
    os.dup2(null_fd, 2)           # fd 2 now discards everything
    os.close(null_fd)             # fd 2 keeps the devnull file open
    try:
        yield
    finally:
        os.dup2(saved_fd, 2)      # put the real stderr back
        os.close(saved_fd)

# custom WhisperMic for various fixes
class CustomMic(WhisperMic):
    """WhisperMic subclass with quieter logging and CPU idling.

    Parks the whisper model on the CPU right after construction so the GPU
    stays free for other models; callers are expected to move it to CUDA
    around actual listening (see the module-level ``listen`` helper).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Raise the library logger's threshold to hide its INFO chatter.
        self.logger = get_logger('whisper_mic', level='warning')
        # Offload the model until someone actually listens.
        self.audio_model.to('cpu')

    def listen(self, timeout=None, phrase_time_limit=None):
        """Block until a transcription is available and return it."""
        self.logger.info("Listening...")
        # Keep driving the capture handler until it enqueues a result;
        # the name-mangled attribute reaches WhisperMic's private
        # __listen_handler from this subclass.
        while self.result_queue.empty():
            self._WhisperMic__listen_handler(timeout, phrase_time_limit)
            if self.result_queue.empty():
                print('Too quiet, please repeat')
        # Blocking get() replaces the original `while True` busy-wait
        # that spun on result_queue.empty() burning a CPU core.
        return self.result_queue.get()

# init mic
# NOTE: instantiated once at import time (module-level side effect).
# ignoreStderr() hides the ALSA/JACK noise pyaudio emits on fd 2 while
# probing audio devices. The model is loaded for CUDA here, but
# CustomMic.__init__ immediately parks it on the CPU until used.
with ignoreStderr():
    mic = CustomMic(english=True, device='cuda')

def listen():
    """Capture one non-empty utterance from the mic and return its text.

    Moves the whisper model onto the GPU only for the duration of the
    capture, and always parks it back on the CPU afterwards — including
    when the user hits ctrl+c (KeyboardInterrupt) to reach the menu,
    which previously leaked the model on the GPU.
    """
    print('\n> Listening (ctrl+c for menu)')
    with ignoreStderr():
        mic.audio_model.to('cuda')
        try:
            result = None
            # mic.listen() may return an empty transcription; retry
            # until we get something non-empty.
            while not result:
                result = mic.listen()
        finally:
            # Always offload, even if listen() raised (e.g. ctrl+c).
            mic.audio_model.to('cpu')
    return result
11 changes: 6 additions & 5 deletions generator/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ def __init__(self,
self.max_history = min(self.model.config.max_position_embeddings - self.length, 6000)
self.streamer = TextStreamer(self.enc, skip_prompt=True)

def __del__(self):
pass
def generate(self, prompt: str, stream=True, eos_tokens=[]):
eos_token_ids = [self.enc.encode(term)[1] for term in eos_tokens]

def generate(self, prompt: str):
model_inputs = self.enc([prompt], return_tensors='pt').to(self.device)


if self.offload_to_memory:
self.model.to(self.device)

Expand All @@ -52,8 +52,9 @@ def generate(self, prompt: str):
do_sample=True,
use_cache=True,
pad_token_id=self.enc.eos_token_id,
streamer=self.streamer,
penalty_alpha=0.5
streamer=self.streamer if stream else None,
repetition_penalty=1.05,
eos_token_id=eos_token_ids + [self.enc.eos_token_id]
)
print('\033[00m', end='')

Expand Down
Loading

0 comments on commit 2c2aed9

Please sign in to comment.