[Update] Add compute type management for CTranslate2 and Whisper models
This commit is contained in:
@@ -11,7 +11,7 @@ from models.translation.translation_languages import translation_lang
|
||||
from models.translation.translation_utils import ctranslate2_weights
|
||||
from models.transcription.transcription_languages import transcription_lang
|
||||
from models.transcription.transcription_whisper import _MODELS as whisper_models
|
||||
from utils import errorLogging, validateDictStructure
|
||||
from utils import errorLogging, validateDictStructure, getComputeTypeList
|
||||
|
||||
json_serializable_vars = {}
|
||||
def json_serializable(var_name):
|
||||
@@ -135,6 +135,14 @@ class Config:
|
||||
def SELECTABLE_COMPUTE_DEVICE_LIST(self):
|
||||
return self._SELECTABLE_COMPUTE_DEVICE_LIST
|
||||
|
||||
@property
|
||||
def SELECTABLE_CTRANSLATE2_COMPUTE_TYPE_LIST(self):
|
||||
return self._SELECTABLE_CTRANSLATE2_COMPUTE_TYPE_LIST
|
||||
|
||||
@property
|
||||
def SELECTABLE_WHISPER_COMPUTE_TYPE_LIST(self):
|
||||
return self._SELECTABLE_WHISPER_COMPUTE_TYPE_LIST
|
||||
|
||||
@property
|
||||
def SEND_MESSAGE_BUTTON_TYPE_LIST(self):
|
||||
return self._SEND_MESSAGE_BUTTON_TYPE_LIST
|
||||
@@ -814,6 +822,18 @@ class Config:
|
||||
self._CTRANSLATE2_WEIGHT_TYPE = value
|
||||
self.saveConfig(inspect.currentframe().f_code.co_name, value)
|
||||
|
||||
@property
|
||||
@json_serializable('CTRANSLATE2_COMPUTE_TYPE')
|
||||
def CTRANSLATE2_COMPUTE_TYPE(self):
|
||||
return self._CTRANSLATE2_COMPUTE_TYPE
|
||||
|
||||
@CTRANSLATE2_COMPUTE_TYPE.setter
|
||||
def CTRANSLATE2_COMPUTE_TYPE(self, value):
|
||||
if isinstance(value, str):
|
||||
if value in self.SELECTABLE_CTRANSLATE2_COMPUTE_TYPE_LIST:
|
||||
self._CTRANSLATE2_COMPUTE_TYPE = value
|
||||
self.saveConfig(inspect.currentframe().f_code.co_name, value)
|
||||
|
||||
@property
|
||||
@json_serializable('WHISPER_WEIGHT_TYPE')
|
||||
def WHISPER_WEIGHT_TYPE(self):
|
||||
@@ -826,6 +846,18 @@ class Config:
|
||||
self._WHISPER_WEIGHT_TYPE = value
|
||||
self.saveConfig(inspect.currentframe().f_code.co_name, value)
|
||||
|
||||
@property
|
||||
@json_serializable('WHISPER_COMPUTE_TYPE')
|
||||
def WHISPER_COMPUTE_TYPE(self):
|
||||
return self._WHISPER_COMPUTE_TYPE
|
||||
|
||||
@WHISPER_COMPUTE_TYPE.setter
|
||||
def WHISPER_COMPUTE_TYPE(self, value):
|
||||
if isinstance(value, str):
|
||||
if value in self.SELECTABLE_WHISPER_COMPUTE_TYPE_LIST:
|
||||
self._WHISPER_COMPUTE_TYPE = value
|
||||
self.saveConfig(inspect.currentframe().f_code.co_name, value)
|
||||
|
||||
@property
|
||||
@json_serializable('AUTO_CLEAR_MESSAGE_BOX')
|
||||
def AUTO_CLEAR_MESSAGE_BOX(self):
|
||||
@@ -1051,6 +1083,8 @@ class Config:
|
||||
for i in range(torch.cuda.device_count()):
|
||||
self._SELECTABLE_COMPUTE_DEVICE_LIST.append({"device":"cuda", "device_index": i, "device_name": torch.cuda.get_device_name(i)})
|
||||
self._SELECTABLE_COMPUTE_DEVICE_LIST.append({"device":"cpu", "device_index": 0, "device_name": "cpu"})
|
||||
self._SELECTABLE_CTRANSLATE2_COMPUTE_TYPE_LIST = ["auto"] + getComputeTypeList()
|
||||
self._SELECTABLE_WHISPER_COMPUTE_TYPE_LIST = ["auto"] + getComputeTypeList()
|
||||
self._SEND_MESSAGE_BUTTON_TYPE_LIST = ["show", "hide", "show_and_disable_enter_key"]
|
||||
self._SEND_MESSAGE_FORMAT_PARTS = {
|
||||
"message": {
|
||||
@@ -1189,7 +1223,9 @@ class Config:
|
||||
self._SELECTED_TRANSLATION_COMPUTE_DEVICE = copy.deepcopy(self.SELECTABLE_COMPUTE_DEVICE_LIST[0])
|
||||
self._SELECTED_TRANSCRIPTION_COMPUTE_DEVICE = copy.deepcopy(self.SELECTABLE_COMPUTE_DEVICE_LIST[0])
|
||||
self._CTRANSLATE2_WEIGHT_TYPE = "small"
|
||||
self._CTRANSLATE2_COMPUTE_TYPE = "auto"
|
||||
self._WHISPER_WEIGHT_TYPE = "base"
|
||||
self._WHISPER_COMPUTE_TYPE = "auto"
|
||||
self._AUTO_CLEAR_MESSAGE_BOX = True
|
||||
self._SEND_ONLY_TRANSLATED_MESSAGES = False
|
||||
self._OVERLAY_SMALL_LOG = False
|
||||
|
||||
Reference in New Issue
Block a user