🐛[bugfix] Model : tqdmを削除
This commit is contained in:
@@ -3,7 +3,6 @@ from zipfile import ZipFile
|
||||
from os import path as os_path
|
||||
from os import makedirs as os_makedirs
|
||||
from requests import get as requests_get
|
||||
from tqdm import tqdm
|
||||
from typing import Callable
|
||||
import hashlib
|
||||
|
||||
@@ -39,7 +38,7 @@ def calculate_file_hash(file_path, block_size=65536):
|
||||
|
||||
return hash_object.hexdigest()
|
||||
|
||||
def downloadCTranslate2Weight(path, weight_type="small", func=None):
|
||||
def downloadCTranslate2Weight(path, weight_type="m2m100_418m", func=None):
|
||||
url = ctranslate2_weights[weight_type]["url"]
|
||||
filename = 'weight.zip'
|
||||
directory_name = 'weight'
|
||||
@@ -64,16 +63,13 @@ def downloadCTranslate2Weight(path, weight_type="small", func=None):
|
||||
with tempfile.TemporaryDirectory() as tmp_path:
|
||||
res = requests_get(url, stream=True)
|
||||
file_size = int(res.headers.get('content-length', 0))
|
||||
pbar = tqdm(total=file_size, unit="B", unit_scale=True)
|
||||
total_chunk = 0
|
||||
with open(os_path.join(tmp_path, filename), 'wb') as file:
|
||||
for chunk in res.iter_content(chunk_size=1024*5):
|
||||
file.write(chunk)
|
||||
pbar.update(len(chunk))
|
||||
if isinstance(func, Callable):
|
||||
total_chunk += len(chunk)
|
||||
func(total_chunk/file_size)
|
||||
pbar.close()
|
||||
|
||||
with ZipFile(os_path.join(tmp_path, filename)) as zf:
|
||||
zf.extractall(os_path.join(current_directory, directory_name))
|
||||
|
||||
Reference in New Issue
Block a user