Merge branch 'bugfix_trans_model_name' into develop
This commit is contained in:
@@ -2870,10 +2870,13 @@ class Controller:
|
|||||||
if hasattr(self, '_ctranslate2_available_cache'):
|
if hasattr(self, '_ctranslate2_available_cache'):
|
||||||
# 起動時のキャッシュを使用: 選択中の重みタイプのみ設定
|
# 起動時のキャッシュを使用: 選択中の重みタイプのみ設定
|
||||||
config.SELECTABLE_CTRANSLATE2_WEIGHT_TYPE_DICT[config.CTRANSLATE2_WEIGHT_TYPE] = self._ctranslate2_available_cache
|
config.SELECTABLE_CTRANSLATE2_WEIGHT_TYPE_DICT[config.CTRANSLATE2_WEIGHT_TYPE] = self._ctranslate2_available_cache
|
||||||
else:
|
|
||||||
# 通常時は全重みタイプをチェック
|
# すべての重みタイプをチェック(キャッシュされていないものだけ)
|
||||||
for weight_type in config.SELECTABLE_CTRANSLATE2_WEIGHT_TYPE_DICT.keys():
|
for weight_type in config.SELECTABLE_CTRANSLATE2_WEIGHT_TYPE_DICT.keys():
|
||||||
config.SELECTABLE_CTRANSLATE2_WEIGHT_TYPE_DICT[weight_type] = model.checkTranslatorCTranslate2ModelWeight(weight_type)
|
# 選択中のウェイトはキャッシュで設定済みなのでスキップ
|
||||||
|
if hasattr(self, '_ctranslate2_available_cache') and weight_type == config.CTRANSLATE2_WEIGHT_TYPE:
|
||||||
|
continue
|
||||||
|
config.SELECTABLE_CTRANSLATE2_WEIGHT_TYPE_DICT[weight_type] = model.checkTranslatorCTranslate2ModelWeight(weight_type)
|
||||||
|
|
||||||
def updateTranslationEngineAndEngineList(self):
|
def updateTranslationEngineAndEngineList(self):
|
||||||
engines = config.SELECTED_TRANSLATION_ENGINES
|
engines = config.SELECTED_TRANSLATION_ENGINES
|
||||||
@@ -2899,10 +2902,13 @@ class Controller:
|
|||||||
if hasattr(self, '_whisper_available_cache'):
|
if hasattr(self, '_whisper_available_cache'):
|
||||||
# 起動時のキャッシュを使用: 選択中の重みタイプのみ設定
|
# 起動時のキャッシュを使用: 選択中の重みタイプのみ設定
|
||||||
config.SELECTABLE_WHISPER_WEIGHT_TYPE_DICT[config.WHISPER_WEIGHT_TYPE] = self._whisper_available_cache
|
config.SELECTABLE_WHISPER_WEIGHT_TYPE_DICT[config.WHISPER_WEIGHT_TYPE] = self._whisper_available_cache
|
||||||
else:
|
|
||||||
# 通常時は全重みタイプをチェック
|
# すべての重みタイプをチェック(キャッシュされていないものだけ)
|
||||||
for weight_type in config.SELECTABLE_WHISPER_WEIGHT_TYPE_DICT.keys():
|
for weight_type in config.SELECTABLE_WHISPER_WEIGHT_TYPE_DICT.keys():
|
||||||
config.SELECTABLE_WHISPER_WEIGHT_TYPE_DICT[weight_type] = model.checkTranscriptionWhisperModelWeight(weight_type)
|
# 選択中のウェイトはキャッシュで設定済みなのでスキップ
|
||||||
|
if hasattr(self, '_whisper_available_cache') and weight_type == config.WHISPER_WEIGHT_TYPE:
|
||||||
|
continue
|
||||||
|
config.SELECTABLE_WHISPER_WEIGHT_TYPE_DICT[weight_type] = model.checkTranscriptionWhisperModelWeight(weight_type)
|
||||||
|
|
||||||
def updateTranscriptionEngine(self):
|
def updateTranscriptionEngine(self):
|
||||||
weight_type = config.WHISPER_WEIGHT_TYPE
|
weight_type = config.WHISPER_WEIGHT_TYPE
|
||||||
@@ -3134,6 +3140,9 @@ class Controller:
|
|||||||
# Download weights
|
# Download weights
|
||||||
if connected_network is True:
|
if connected_network is True:
|
||||||
printLog("Download CTranslate2 Model Weight")
|
printLog("Download CTranslate2 Model Weight")
|
||||||
|
# 後方互換用
|
||||||
|
model.backwardCompatibleTranslatorCTranslate2ModelRenameWeightsDir()
|
||||||
|
|
||||||
weight_type = config.CTRANSLATE2_WEIGHT_TYPE
|
weight_type = config.CTRANSLATE2_WEIGHT_TYPE
|
||||||
th_download_ctranslate2 = None
|
th_download_ctranslate2 = None
|
||||||
if model.checkTranslatorCTranslate2ModelWeight(weight_type) is False:
|
if model.checkTranslatorCTranslate2ModelWeight(weight_type) is False:
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from models.transcription.transcription_recorder import SelectedMicEnergyRecorde
|
|||||||
from models.transcription.transcription_transcriber import AudioTranscriber
|
from models.transcription.transcription_transcriber import AudioTranscriber
|
||||||
from models.translation.translation_languages import translation_lang
|
from models.translation.translation_languages import translation_lang
|
||||||
from models.transcription.transcription_languages import transcription_lang
|
from models.transcription.transcription_languages import transcription_lang
|
||||||
from models.translation.translation_utils import checkCTranslate2Weight, downloadCTranslate2Weight, downloadCTranslate2Tokenizer
|
from models.translation.translation_utils import checkCTranslate2Weight, downloadCTranslate2Weight, downloadCTranslate2Tokenizer, backwardCompatibleRenameWeightsDir
|
||||||
from models.transcription.transcription_whisper import checkWhisperWeight, downloadWhisperWeight
|
from models.transcription.transcription_whisper import checkWhisperWeight, downloadWhisperWeight
|
||||||
from models.transliteration.transliteration_transliterator import Transliterator
|
from models.transliteration.transliteration_transliterator import Transliterator
|
||||||
from models.overlay.overlay import Overlay
|
from models.overlay.overlay import Overlay
|
||||||
@@ -158,6 +158,9 @@ class Model:
|
|||||||
# Log and continue; callers should handle missing features.
|
# Log and continue; callers should handle missing features.
|
||||||
errorLogging()
|
errorLogging()
|
||||||
|
|
||||||
|
def backwardCompatibleTranslatorCTranslate2ModelRenameWeightsDir(self):
|
||||||
|
return backwardCompatibleRenameWeightsDir(config.PATH_LOCAL)
|
||||||
|
|
||||||
def checkTranslatorCTranslate2ModelWeight(self, weight_type:str):
|
def checkTranslatorCTranslate2ModelWeight(self, weight_type:str):
|
||||||
return checkCTranslate2Weight(config.PATH_LOCAL, weight_type)
|
return checkCTranslate2Weight(config.PATH_LOCAL, weight_type)
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from os import path as os_path
|
from os import path as os_path
|
||||||
from os import makedirs as os_makedirs
|
from os import makedirs as os_makedirs
|
||||||
|
from os import rename as os_rename
|
||||||
from requests import get as requests_get
|
from requests import get as requests_get
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
import transformers
|
import transformers
|
||||||
@@ -47,9 +48,23 @@ ctranslate2_weights = {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def backwardCompatibleRenameWeightsDir(root: str):
|
||||||
|
# 後方互換のためファイル名を変更する
|
||||||
|
legacy_dirs = {
|
||||||
|
"m2m100_418M": "m2m100_418M-ct2-int8",
|
||||||
|
"m2m100_12b": "m2m100_1.2B-ct2-int8",
|
||||||
|
}
|
||||||
|
|
||||||
|
for weight_type_old, weight_type_new in legacy_dirs.items():
|
||||||
|
path = os_path.join(root, "weights", "ctranslate2", weight_type_new)
|
||||||
|
old_path = os_path.join(root, "weights", "ctranslate2", weight_type_old)
|
||||||
|
if os_path.isdir(old_path):
|
||||||
|
os_rename(old_path, path)
|
||||||
|
|
||||||
def checkCTranslate2Weight(root: str, weight_type: str = "m2m100_418M-ct2-int8"):
|
def checkCTranslate2Weight(root: str, weight_type: str = "m2m100_418M-ct2-int8"):
|
||||||
weight_directory_name = ctranslate2_weights[weight_type]["directory_name"]
|
weight_directory_name = ctranslate2_weights[weight_type]["directory_name"]
|
||||||
path = os_path.join(root, "weights", "ctranslate2", weight_directory_name)
|
path = os_path.join(root, "weights", "ctranslate2", weight_directory_name)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# モデルロード可能かどうかで判定
|
# モデルロード可能かどうかで判定
|
||||||
compute_type = getBestComputeType("cpu", 0)
|
compute_type = getBestComputeType("cpu", 0)
|
||||||
|
|||||||
Reference in New Issue
Block a user