diff --git a/models/transcription/transcription_transcriber.py b/models/transcription/transcription_transcriber.py index c1856b34..a535cd8a 100644 --- a/models/transcription/transcription_transcriber.py +++ b/models/transcription/transcription_transcriber.py @@ -10,6 +10,7 @@ from .transcription_whisper import getWhisperModel, checkWhisperWeight import torch import numpy as np +from pydub import AudioSegment PHRASE_TIMEOUT = 3 MAX_PHRASES = 10 @@ -97,21 +98,21 @@ class AudioTranscriber: return audio_data def processSpeakerData(self): - original_channels = self.audio_sources["channels"] - if original_channels <= 2: - channels = original_channels - sample_rate = self.audio_sources["sample_rate"] - else: - channels = 2 - sample_rate = self.audio_sources["sample_rate"]*original_channels/2 - temp_file = BytesIO() with wave.open(temp_file, 'wb') as wf: - wf.setnchannels(channels) + wf.setnchannels(self.audio_sources["channels"]) wf.setsampwidth(get_sample_size(paInt16)) - wf.setframerate(sample_rate) + wf.setframerate(self.audio_sources["sample_rate"]) wf.writeframes(self.audio_sources["last_sample"]) temp_file.seek(0) + + if self.audio_sources["channels"] > 2: + audio = AudioSegment.from_file(temp_file, format="wav") + mono_audio = audio.set_channels(1) + temp_file = BytesIO() + mono_audio.export(temp_file, format="wav") + temp_file.seek(0) + with AudioFile(temp_file) as source: audio = self.audio_recognizer.record(source) return audio