Merge branch 'CTranslate2' into develop

This commit is contained in:
misyaguziya
2024-01-12 13:42:41 +09:00

View File

@@ -3,7 +3,6 @@ from zipfile import ZipFile
from os import path as os_path
from os import makedirs as os_makedirs
from requests import get as requests_get
from tqdm import tqdm
from typing import Callable
import hashlib
@@ -39,7 +38,7 @@ def calculate_file_hash(file_path, block_size=65536):
return hash_object.hexdigest()
def downloadCTranslate2Weight(path, weight_type="small", func=None):
def downloadCTranslate2Weight(path, weight_type="m2m100_418m", func=None):
url = ctranslate2_weights[weight_type]["url"]
filename = 'weight.zip'
directory_name = 'weight'
@@ -64,16 +63,13 @@ def downloadCTranslate2Weight(path, weight_type="small", func=None):
with tempfile.TemporaryDirectory() as tmp_path:
res = requests_get(url, stream=True)
file_size = int(res.headers.get('content-length', 0))
pbar = tqdm(total=file_size, unit="B", unit_scale=True)
total_chunk = 0
with open(os_path.join(tmp_path, filename), 'wb') as file:
for chunk in res.iter_content(chunk_size=1024*5):
file.write(chunk)
pbar.update(len(chunk))
if isinstance(func, Callable):
total_chunk += len(chunk)
func(total_chunk/file_size)
pbar.close()
with ZipFile(os_path.join(tmp_path, filename)) as zf:
zf.extractall(os_path.join(current_directory, directory_name))