ドキュメントを更新し、型注釈を追加してコードの可読性とメンテナンス性を向上。各モジュールの使用例や依存関係を明示化し、エラーハンドリングを改善。
This commit is contained in:
@@ -1,3 +1,78 @@
|
||||
## 文字起こしモジュール (models.transcription)
|
||||
|
||||
このドキュメントでは `models/transcription` に関する設計・セットアップ・使用例・テスト方針・トラブルシュートをまとめます。
|
||||
|
||||
### 概要
|
||||
- `models/transcription` は音声入力をテキストに変換する機能を提供します。主に:
|
||||
- `transcription_recorder.py` — マイクやスピーカからの音声取得ラッパー
|
||||
- `transcription_transcriber.py` — 音声バッファを認識エンジンに渡して文字起こしを行うロジック
|
||||
- `transcription_whisper.py` — faster-whisper(WhisperModel)周りのダウンロード/ロード補助
|
||||
- `transcription_languages.py` — 各言語・国別のエンジン別コードマップ
|
||||
|
||||
### 最近の変更点
|
||||
- 各モジュールに型注釈と docstring を追加しました。これによりメンテナンス性が向上します。
|
||||
- `transcription_whisper.py` にダウンロード進捗コールバックを明記した実装を追加しました。
|
||||
|
||||
### 依存関係
|
||||
主要な依存:
|
||||
- `speech_recognition` — オーディオ録音と Google 音声認識のラッパー
|
||||
- `pyaudiowpatch` — クロスプラットフォームのオーディオ設定
|
||||
- `pydub` — 音声のチャンネル変換や処理
|
||||
- `faster_whisper`(オプショナル)— ローカルで Whisper を使う場合
|
||||
- `huggingface_hub`(オプショナル)— モデルアーティファクトのダウンロード
|
||||
|
||||
注意: `pydub` は `ffmpeg` が必要です。環境に ffmpeg が無いとワーニングが出ます。
|
||||
|
||||
推奨インストール(任意):
|
||||
|
||||
```powershell
|
||||
pip install speechrecognition pyaudiowpatch pydub faster-whisper huggingface-hub
|
||||
```
|
||||
|
||||
テストでは多くの外部依存をモックするため、全てをインストールする必要はありません。
|
||||
|
||||
### 初回セットアップ
|
||||
1. 必要に応じて `ffmpeg` をインストールしてください(pydub の動作に必要)。
|
||||
2. Whisper ローカルモデルを使う場合、`transcription_whisper.downloadWhisperWeight(root, weight_type, callback, end_callback)` を呼んでモデルを取得します。
|
||||
- `callback(progress: float)` は 0.0〜1.0 の進捗通知です。
|
||||
- 例:
|
||||
|
||||
```python
|
||||
from models.transcription import transcription_whisper as tw
|
||||
tw.downloadWhisperWeight("./", "tiny", callback=lambda p: print(f"{p*100:.1f}%"), end_callback=lambda: print("done"))
|
||||
```
|
||||
|
||||
### API 使用例
|
||||
簡単な `AudioTranscriber` の使い方:
|
||||
|
||||
```python
|
||||
from models.transcription.transcription_transcriber import AudioTranscriber
|
||||
|
||||
# source はライブラリが提供するオーディオソースオブジェクト
|
||||
tr = AudioTranscriber(speaker=False, source=source, phrase_timeout=3, max_phrases=10, transcription_engine="Google")
|
||||
# audio_queue は録音スレッドがプッシュするキュー
|
||||
tr.transcribeAudioQueue(audio_queue, languages=["English"], countries=["United States"])
|
||||
```
|
||||
|
||||
戻り値やエラー処理のルールについては各関数の docstring を参照してください。
|
||||
|
||||
### テスト方針
|
||||
- `AudioTranscriber` と `Whisper` ラッパーはユニットテストでモック化して検証します。
|
||||
- 推奨: `pytest` と `unittest.mock` を使い、以下のケースをカバーします:
|
||||
- 正常系: Google/Whisper の成功パス(モックで期待テキストを返す)
|
||||
- エッジ: 無音、低確信、複数言語
|
||||
- フォールバック: Whisper が利用不可の場合のフォールバック動作
|
||||
|
||||
### トラブルシュート
|
||||
- ffmpeg が見つからない: `pydub` がワーニングを出します。OS に合わせて ffmpeg をインストールしてください。
|
||||
- Whisper のロード時に VRAM エラー: `getWhisperModel` は VRAM 不足を検出して `ValueError("VRAM_OUT_OF_MEMORY", message)` を投げます。デバイス設定や compute_type を調整してください。
|
||||
- ハッシュ不一致やダウンロード失敗: キャッシュや weights ディレクトリを削除して再ダウンロードしてください。
|
||||
|
||||
### 変更履歴
|
||||
- 2025-10-09: 型注釈と docstring を追加、ダウンロード/コールバック仕様を明記。
|
||||
|
||||
---
|
||||
このドキュメントは簡潔な参照用です。さらに詳細な実行手順(ログ収集方法、ffmpeg のインストール手順例など)が必要であれば追記します。
|
||||
# transcription — 文字起こしモジュール
|
||||
概要: マイク/スピーカー音声の録音と Whisper/Google などのエンジンを使った文字起こしを提供するモジュール群です。主なクラスは録音用の Recorder と `AudioTranscriber` です。
|
||||
|
||||
|
||||
@@ -1,3 +1,8 @@
|
||||
"""Language table used by transcription components.
|
||||
|
||||
Maps a display language and country to engine-specific language codes.
|
||||
"""
|
||||
|
||||
transcription_lang = {
|
||||
"Afrikaans":{
|
||||
"South Africa":{
|
||||
|
||||
@@ -1,9 +1,18 @@
|
||||
"""Recorders that wrap speech_recognition microphone interfaces.
|
||||
|
||||
These classes provide small adapters that push raw audio bytes into queues.
|
||||
They intentionally keep a thin API so the rest of the system can mock them
|
||||
in tests.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from speech_recognition import Recognizer, Microphone
|
||||
from pyaudiowpatch import get_sample_size, paInt16
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class BaseRecorder:
|
||||
def __init__(self, source, energy_threshold, dynamic_energy_threshold, record_timeout):
|
||||
def __init__(self, source: Any, energy_threshold: int, dynamic_energy_threshold: bool, record_timeout: int) -> None:
|
||||
self.recorder = Recognizer()
|
||||
self.recorder.energy_threshold = energy_threshold
|
||||
self.recorder.dynamic_energy_threshold = dynamic_energy_threshold
|
||||
@@ -15,27 +24,29 @@ class BaseRecorder:
|
||||
|
||||
self.source = source
|
||||
|
||||
def adjustForNoise(self):
|
||||
def adjustForNoise(self) -> None:
|
||||
with self.source:
|
||||
self.recorder.adjust_for_ambient_noise(self.source)
|
||||
|
||||
def recordIntoQueue(self, audio_queue):
|
||||
def recordIntoQueue(self, audio_queue: Any) -> None:
|
||||
def record_callback(_, audio):
|
||||
audio_queue.put((audio.get_raw_data(), datetime.now()))
|
||||
|
||||
self.stop, self.pause, self.resume = self.recorder.listen_in_background(self.source, record_callback, phrase_time_limit=self.record_timeout)
|
||||
|
||||
|
||||
class SelectedMicRecorder(BaseRecorder):
|
||||
def __init__(self, device, energy_threshold, dynamic_energy_threshold, record_timeout):
|
||||
source=Microphone(
|
||||
def __init__(self, device: dict, energy_threshold: int, dynamic_energy_threshold: bool, record_timeout: int) -> None:
|
||||
source = Microphone(
|
||||
device_index=device['index'],
|
||||
sample_rate=int(device["defaultSampleRate"]),
|
||||
)
|
||||
super().__init__(source=source, energy_threshold=energy_threshold, dynamic_energy_threshold=dynamic_energy_threshold, record_timeout=record_timeout)
|
||||
# self.adjustForNoise()
|
||||
|
||||
|
||||
class SelectedSpeakerRecorder(BaseRecorder):
|
||||
def __init__(self, device, energy_threshold, dynamic_energy_threshold, record_timeout):
|
||||
def __init__(self, device: dict, energy_threshold: int, dynamic_energy_threshold: bool, record_timeout: int) -> None:
|
||||
|
||||
source = Microphone(speaker=True,
|
||||
device_index= device["index"],
|
||||
@@ -47,7 +58,7 @@ class SelectedSpeakerRecorder(BaseRecorder):
|
||||
# self.adjustForNoise()
|
||||
|
||||
class BaseEnergyRecorder:
|
||||
def __init__(self, source):
|
||||
def __init__(self, source: Any) -> None:
|
||||
self.recorder = Recognizer()
|
||||
self.recorder.energy_threshold = 0
|
||||
self.recorder.dynamic_energy_threshold = False
|
||||
@@ -59,27 +70,29 @@ class BaseEnergyRecorder:
|
||||
|
||||
self.source = source
|
||||
|
||||
def adjustForNoise(self):
|
||||
def adjustForNoise(self) -> None:
|
||||
with self.source:
|
||||
self.recorder.adjust_for_ambient_noise(self.source)
|
||||
|
||||
def recordIntoQueue(self, energy_queue):
|
||||
def recordIntoQueue(self, energy_queue: Any) -> None:
|
||||
def recordCallback(_, energy):
|
||||
energy_queue.put(energy)
|
||||
|
||||
self.stop, self.pause, self.resume = self.recorder.listen_energy_in_background(self.source, recordCallback)
|
||||
|
||||
|
||||
class SelectedMicEnergyRecorder(BaseEnergyRecorder):
|
||||
def __init__(self, device):
|
||||
source=Microphone(
|
||||
def __init__(self, device: dict) -> None:
|
||||
source = Microphone(
|
||||
device_index=device['index'],
|
||||
sample_rate=int(device["defaultSampleRate"]),
|
||||
)
|
||||
super().__init__(source=source)
|
||||
# self.adjustForNoise()
|
||||
|
||||
|
||||
class SelectedSpeakerEnergyRecorder(BaseEnergyRecorder):
|
||||
def __init__(self, device):
|
||||
def __init__(self, device: dict) -> None:
|
||||
|
||||
source = Microphone(speaker=True,
|
||||
device_index= device["index"],
|
||||
@@ -90,7 +103,15 @@ class SelectedSpeakerEnergyRecorder(BaseEnergyRecorder):
|
||||
# self.adjustForNoise()
|
||||
|
||||
class BaseEnergyAndAudioRecorder:
|
||||
def __init__(self, source, energy_threshold, dynamic_energy_threshold, phrase_time_limit, phrase_timeout, record_timeout):
|
||||
def __init__(
|
||||
self,
|
||||
source: Any,
|
||||
energy_threshold: int,
|
||||
dynamic_energy_threshold: bool,
|
||||
phrase_time_limit: int,
|
||||
phrase_timeout: int,
|
||||
record_timeout: int,
|
||||
) -> None:
|
||||
self.recorder = Recognizer()
|
||||
self.recorder.energy_threshold = energy_threshold
|
||||
self.recorder.dynamic_energy_threshold = dynamic_energy_threshold
|
||||
@@ -104,11 +125,11 @@ class BaseEnergyAndAudioRecorder:
|
||||
|
||||
self.source = source
|
||||
|
||||
def adjustForNoise(self):
|
||||
def adjustForNoise(self) -> None:
|
||||
with self.source:
|
||||
self.recorder.adjust_for_ambient_noise(self.source)
|
||||
|
||||
def recordIntoQueue(self, audio_queue, energy_queue=None):
|
||||
def recordIntoQueue(self, audio_queue: Any, energy_queue: Any = None) -> None:
|
||||
def audioRecordCallback(_, audio):
|
||||
audio_queue.put((audio.get_raw_data(), datetime.now()))
|
||||
|
||||
@@ -121,11 +142,21 @@ class BaseEnergyAndAudioRecorder:
|
||||
phrase_time_limit=self.phrase_time_limit,
|
||||
callback_energy=energyRecordCallback if energy_queue is not None else None,
|
||||
phrase_timeout=self.phrase_timeout,
|
||||
record_timeout=self.record_timeout)
|
||||
record_timeout=self.record_timeout,
|
||||
)
|
||||
|
||||
|
||||
class SelectedMicEnergyAndAudioRecorder(BaseEnergyAndAudioRecorder):
|
||||
def __init__(self, device, energy_threshold, dynamic_energy_threshold, phrase_time_limit, phrase_timeout:int=1, record_timeout:int=5):
|
||||
source=Microphone(
|
||||
def __init__(
|
||||
self,
|
||||
device: dict,
|
||||
energy_threshold: int,
|
||||
dynamic_energy_threshold: bool,
|
||||
phrase_time_limit: int,
|
||||
phrase_timeout: int = 1,
|
||||
record_timeout: int = 5,
|
||||
) -> None:
|
||||
source = Microphone(
|
||||
device_index=device['index'],
|
||||
sample_rate=int(device["defaultSampleRate"]),
|
||||
)
|
||||
@@ -139,14 +170,23 @@ class SelectedMicEnergyAndAudioRecorder(BaseEnergyAndAudioRecorder):
|
||||
)
|
||||
# self.adjustForNoise()
|
||||
|
||||
|
||||
class SelectedSpeakerEnergyAndAudioRecorder(BaseEnergyAndAudioRecorder):
|
||||
def __init__(self, device, energy_threshold, dynamic_energy_threshold, phrase_time_limit, phrase_timeout:int=1, record_timeout:int=5):
|
||||
def __init__(
|
||||
self,
|
||||
device: dict,
|
||||
energy_threshold: int,
|
||||
dynamic_energy_threshold: bool,
|
||||
phrase_time_limit: int,
|
||||
phrase_timeout: int = 1,
|
||||
record_timeout: int = 5,
|
||||
) -> None:
|
||||
|
||||
source = Microphone(speaker=True,
|
||||
device_index= device["index"],
|
||||
sample_rate=int(device["defaultSampleRate"]),
|
||||
chunk_size=get_sample_size(paInt16),
|
||||
channels=device["maxInputChannels"]
|
||||
channels=device["maxInputChannels"],
|
||||
)
|
||||
super().__init__(
|
||||
source=source,
|
||||
|
||||
@@ -1,7 +1,14 @@
|
||||
"""Runtime transcriber that wraps Google SpeechRecognition and faster-whisper.
|
||||
|
||||
This class focuses on converting incoming raw audio buffers into text using
|
||||
either the Google web recognizer (online) or a local Whisper model (offline).
|
||||
"""
|
||||
|
||||
import time
|
||||
from io import BytesIO
|
||||
from threading import Event
|
||||
import wave
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
from speech_recognition import Recognizer, AudioData, AudioFile
|
||||
from speech_recognition.exceptions import UnknownValueError
|
||||
from datetime import timedelta
|
||||
@@ -20,38 +27,71 @@ warnings.simplefilter('ignore', RuntimeWarning)
|
||||
PHRASE_TIMEOUT = 3
|
||||
MAX_PHRASES = 10
|
||||
|
||||
|
||||
class AudioTranscriber:
|
||||
def __init__(self, speaker, source, phrase_timeout, max_phrases, transcription_engine, root=None, whisper_weight_type=None, device="cpu", device_index=0, compute_type="auto"):
|
||||
"""Convert queued audio buffers into transcripts.
|
||||
|
||||
Public attributes set by the constructor:
|
||||
- speaker: bool
|
||||
- phrase_timeout: int
|
||||
- max_phrases: int
|
||||
|
||||
Methods are intentionally permissive about input types to match the
|
||||
existing codebase; this wrapper adds typing for clarity.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
speaker: bool,
|
||||
source: Any,
|
||||
phrase_timeout: int,
|
||||
max_phrases: int,
|
||||
transcription_engine: str,
|
||||
root: Optional[str] = None,
|
||||
whisper_weight_type: Optional[str] = None,
|
||||
device: str = "cpu",
|
||||
device_index: int = 0,
|
||||
compute_type: str = "auto",
|
||||
) -> None:
|
||||
self.speaker = speaker
|
||||
self.phrase_timeout = phrase_timeout
|
||||
self.max_phrases = max_phrases
|
||||
self.transcript_data = []
|
||||
self.transcript_data: List[Dict[str, Any]] = []
|
||||
self.transcript_changed_event = Event()
|
||||
self.audio_recognizer = Recognizer()
|
||||
self.transcription_engine = "Google"
|
||||
self.whisper_model = None
|
||||
self.audio_sources = {
|
||||
"sample_rate": source.SAMPLE_RATE,
|
||||
"sample_width": source.SAMPLE_WIDTH,
|
||||
"channels": source.channels,
|
||||
"last_sample": bytes(),
|
||||
"last_spoken": None,
|
||||
"new_phrase": True,
|
||||
"process_data_func": self.processSpeakerData if speaker else self.processSpeakerData
|
||||
self.audio_sources: Dict[str, Any] = {
|
||||
"sample_rate": source.SAMPLE_RATE,
|
||||
"sample_width": source.SAMPLE_WIDTH,
|
||||
"channels": source.channels,
|
||||
"last_sample": bytes(),
|
||||
"last_spoken": None,
|
||||
"new_phrase": True,
|
||||
"process_data_func": self.processSpeakerData if speaker else self.processSpeakerData,
|
||||
}
|
||||
|
||||
if transcription_engine == "Whisper" and checkWhisperWeight(root, whisper_weight_type) is True:
|
||||
self.whisper_model = getWhisperModel(root, whisper_weight_type, device=device, device_index=device_index, compute_type=compute_type)
|
||||
self.whisper_model = getWhisperModel(
|
||||
root, whisper_weight_type, device=device, device_index=device_index, compute_type=compute_type
|
||||
)
|
||||
self.transcription_engine = "Whisper"
|
||||
|
||||
def transcribeAudioQueue(self, audio_queue, languages, countries, avg_logprob=-0.8, no_speech_prob=0.6):
|
||||
def transcribeAudioQueue(
|
||||
self,
|
||||
audio_queue: Any,
|
||||
languages: List[str],
|
||||
countries: List[str],
|
||||
avg_logprob: float = -0.8,
|
||||
no_speech_prob: float = 0.6,
|
||||
) -> bool:
|
||||
if audio_queue.empty():
|
||||
time.sleep(0.01)
|
||||
return False
|
||||
audio, time_spoken = audio_queue.get()
|
||||
self.updateLastSampleAndPhraseStatus(audio, time_spoken)
|
||||
|
||||
confidences = [{"confidence": 0, "text": "", "language": None}]
|
||||
confidences: List[Dict[str, Any]] = [{"confidence": 0, "text": "", "language": None}]
|
||||
try:
|
||||
audio_data = self.audio_sources["process_data_func"]()
|
||||
match self.transcription_engine:
|
||||
@@ -67,13 +107,19 @@ class AudioTranscriber:
|
||||
except Exception:
|
||||
pass
|
||||
case "Whisper":
|
||||
audio_data = np.frombuffer(audio_data.get_raw_data(convert_rate=16000, convert_width=2), np.int16).flatten().astype(np.float32) / 32768.0
|
||||
audio_data = np.frombuffer(
|
||||
audio_data.get_raw_data(convert_rate=16000, convert_width=2), np.int16
|
||||
).flatten().astype(np.float32) / 32768.0
|
||||
if isinstance(audio_data, torch.Tensor):
|
||||
audio_data = audio_data.detach().numpy()
|
||||
|
||||
for language, country in zip(languages, countries):
|
||||
text = ""
|
||||
source_language = transcription_lang[language][country][self.transcription_engine] if len(languages) == 1 else None
|
||||
source_language = (
|
||||
transcription_lang[language][country][self.transcription_engine]
|
||||
if len(languages) == 1
|
||||
else None
|
||||
)
|
||||
segments, info = self.whisper_model.transcribe(
|
||||
audio_data,
|
||||
beam_size=5,
|
||||
@@ -85,13 +131,15 @@ class AudioTranscriber:
|
||||
without_timestamps=True,
|
||||
task="transcribe",
|
||||
vad_filter=False,
|
||||
)
|
||||
)
|
||||
for s in segments:
|
||||
if s.avg_logprob < avg_logprob or s.no_speech_prob > no_speech_prob:
|
||||
continue
|
||||
text += s.text
|
||||
confidences.append({"confidence": info.language_probability, "text": text, "language": language})
|
||||
if (len(languages) == 1) or (transcription_lang[language][country][self.transcription_engine] == info.language):
|
||||
if (len(languages) == 1) or (
|
||||
transcription_lang[language][country][self.transcription_engine] == info.language
|
||||
):
|
||||
break
|
||||
|
||||
except UnknownValueError:
|
||||
@@ -106,7 +154,7 @@ class AudioTranscriber:
|
||||
self.updateTranscript(result)
|
||||
return True
|
||||
|
||||
def updateLastSampleAndPhraseStatus(self, data, time_spoken):
|
||||
def updateLastSampleAndPhraseStatus(self, data: bytes, time_spoken) -> None:
|
||||
source_info = self.audio_sources
|
||||
if source_info["last_spoken"] and time_spoken - source_info["last_spoken"] > timedelta(seconds=self.phrase_timeout):
|
||||
source_info["last_sample"] = bytes()
|
||||
@@ -117,11 +165,13 @@ class AudioTranscriber:
|
||||
source_info["last_sample"] += data
|
||||
source_info["last_spoken"] = time_spoken
|
||||
|
||||
def processMicData(self):
|
||||
audio_data = AudioData(self.audio_sources["last_sample"], self.audio_sources["sample_rate"], self.audio_sources["sample_width"])
|
||||
def processMicData(self) -> AudioData:
|
||||
audio_data = AudioData(
|
||||
self.audio_sources["last_sample"], self.audio_sources["sample_rate"], self.audio_sources["sample_width"]
|
||||
)
|
||||
return audio_data
|
||||
|
||||
def processSpeakerData(self):
|
||||
def processSpeakerData(self) -> AudioData:
|
||||
temp_file = BytesIO()
|
||||
with wave.open(temp_file, 'wb') as wf:
|
||||
wf.setnchannels(self.audio_sources["channels"])
|
||||
@@ -141,7 +191,7 @@ class AudioTranscriber:
|
||||
audio = self.audio_recognizer.record(source)
|
||||
return audio
|
||||
|
||||
def updateTranscript(self, result):
|
||||
def updateTranscript(self, result: dict) -> None:
|
||||
source_info = self.audio_sources
|
||||
transcript = self.transcript_data
|
||||
|
||||
@@ -152,14 +202,14 @@ class AudioTranscriber:
|
||||
else:
|
||||
transcript[0] = result
|
||||
|
||||
def getTranscript(self):
|
||||
def getTranscript(self) -> dict:
|
||||
if len(self.transcript_data) > 0:
|
||||
result = self.transcript_data.pop(-1)
|
||||
else:
|
||||
result = {"confidence": 0, "text": "", "language": None}
|
||||
return result
|
||||
|
||||
def clearTranscriptData(self):
|
||||
def clearTranscriptData(self) -> None:
|
||||
self.transcript_data.clear()
|
||||
self.audio_sources["last_sample"] = bytes()
|
||||
self.audio_sources["new_phrase"] = True
|
||||
@@ -1,6 +1,17 @@
|
||||
"""Helpers for downloading and loading Whisper (faster-whisper) models.
|
||||
|
||||
This module exposes small utilities used by the transcription subsystem:
|
||||
- downloadFile: stream-download a file with optional progress callback
|
||||
- checkWhisperWeight: quick local availability check
|
||||
- downloadWhisperWeight: download model artifacts from HF hub
|
||||
- getWhisperModel: construct and return a WhisperModel instance
|
||||
|
||||
The functions are defensive: failures are caught and reported by the caller.
|
||||
"""
|
||||
|
||||
from os import path as os_path, makedirs as os_makedirs
|
||||
from requests import get as requests_get
|
||||
from typing import Callable
|
||||
from typing import Callable, Optional
|
||||
import huggingface_hub
|
||||
from faster_whisper import WhisperModel
|
||||
import logging
|
||||
@@ -30,24 +41,36 @@ _FILENAMES = [
|
||||
"vocabulary.json",
|
||||
]
|
||||
|
||||
def downloadFile(url, path, func=None):
|
||||
def downloadFile(url: str, path: str, func: Optional[Callable[[float], None]] = None) -> None:
|
||||
"""Download a file from `url` to `path`.
|
||||
|
||||
Args:
|
||||
url: remote URL to download from
|
||||
path: local filepath to write
|
||||
func: optional callback(progress: float) called with a 0.0-1.0 progress
|
||||
"""
|
||||
try:
|
||||
res = requests_get(url, stream=True)
|
||||
res.raise_for_status()
|
||||
file_size = int(res.headers.get('content-length', 0))
|
||||
total_chunk = 0
|
||||
with open(os_path.join(path), 'wb') as file:
|
||||
for chunk in res.iter_content(chunk_size=1024*2000):
|
||||
for chunk in res.iter_content(chunk_size=1024 * 2000):
|
||||
file.write(chunk)
|
||||
if isinstance(func, Callable):
|
||||
if callable(func) and file_size:
|
||||
total_chunk += len(chunk)
|
||||
func(total_chunk/file_size)
|
||||
func(total_chunk / file_size)
|
||||
except Exception:
|
||||
# Silent failure here; caller may re-check or log
|
||||
pass
|
||||
|
||||
def checkWhisperWeight(root, weight_type):
|
||||
def checkWhisperWeight(root: str, weight_type: str) -> bool:
|
||||
"""Return True if a Whisper model for `weight_type` is loadable from disk.
|
||||
|
||||
This attempts to construct a local `WhisperModel` with local_files_only=True
|
||||
to verify required files exist and are compatible.
|
||||
"""
|
||||
path = os_path.join(root, "weights", "whisper", weight_type)
|
||||
result = False
|
||||
try:
|
||||
WhisperModel(
|
||||
path,
|
||||
@@ -58,23 +81,47 @@ def checkWhisperWeight(root, weight_type):
|
||||
num_workers=1,
|
||||
local_files_only=True,
|
||||
)
|
||||
result = True
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
return result
|
||||
return False
|
||||
|
||||
def downloadWhisperWeight(root, weight_type, callback=None, end_callback=None):
|
||||
def downloadWhisperWeight(
|
||||
root: str,
|
||||
weight_type: str,
|
||||
callback: Optional[Callable[[float], None]] = None,
|
||||
end_callback: Optional[Callable[[], None]] = None,
|
||||
) -> None:
|
||||
"""Ensure Whisper weight files are present locally; download them if missing.
|
||||
|
||||
Args:
|
||||
root: project root where `weights/whisper` lives
|
||||
weight_type: key from `_MODELS` (eg. "tiny", "base")
|
||||
callback: progress callback for the main model file
|
||||
end_callback: called when download completes
|
||||
"""
|
||||
path = os_path.join(root, "weights", "whisper", weight_type)
|
||||
os_makedirs(path, exist_ok=True)
|
||||
if checkWhisperWeight(root, weight_type) is False:
|
||||
if not checkWhisperWeight(root, weight_type):
|
||||
for filename in _FILENAMES:
|
||||
file_path = os_path.join(path, filename)
|
||||
url = huggingface_hub.hf_hub_url(_MODELS[weight_type], filename)
|
||||
downloadFile(url, file_path, func=callback if filename == "model.bin" else None)
|
||||
if isinstance(end_callback, Callable):
|
||||
if callable(end_callback):
|
||||
end_callback()
|
||||
|
||||
def getWhisperModel(root, weight_type, device="cpu", device_index=0, compute_type="auto"):
|
||||
def getWhisperModel(
|
||||
root: str,
|
||||
weight_type: str,
|
||||
device: str = "cpu",
|
||||
device_index: int = 0,
|
||||
compute_type: str = "auto",
|
||||
) -> WhisperModel:
|
||||
"""Return a `WhisperModel` instance loaded from local weights.
|
||||
|
||||
Raises:
|
||||
ValueError: when VRAM shortage is detected (wrapped from RuntimeError)
|
||||
Exception: other loading errors are propagated.
|
||||
"""
|
||||
path = os_path.join(root, "weights", "whisper", weight_type)
|
||||
if compute_type == "auto":
|
||||
compute_type = getBestComputeType(device, device_index)
|
||||
@@ -90,11 +137,10 @@ def getWhisperModel(root, weight_type, device="cpu", device_index=0, compute_typ
|
||||
)
|
||||
return model
|
||||
except RuntimeError as e:
|
||||
# VRAM不足エラーの検出
|
||||
# Detect VRAM out-of-memory-like errors and raise a clear ValueError
|
||||
error_message = str(e)
|
||||
if "CUDA out of memory" in error_message or "CUBLAS_STATUS_ALLOC_FAILED" in error_message:
|
||||
raise ValueError("VRAM_OUT_OF_MEMORY", error_message)
|
||||
# その他のエラーは通常通り再送出
|
||||
raise
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user