Skip to content

Commit bd8e398

Browse files
Merge pull request #178 from qinguoyi/feat-update-modelloader
feat:update model loader
2 parents da979a1 + a64129c commit bd8e398

File tree

24 files changed

+239
-131
lines changed

24 files changed

+239
-131
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,5 @@ cache
3131
__pycache__
3232
*.pyc
3333
.pytest_cache
34-
*.tgz
34+
*.tgz
35+
.huggingface

api/core/v1alpha1/model_types.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,18 @@ type ModelHub struct {
4444
// the whole repo which includes all kinds of quantized models.
4545
// TODO: this is only supported with Huggingface, add support for ModelScope
4646
// in the near future.
47+
// Note: once filename is set, allowPatterns and ignorePatterns should be left unset.
4748
Filename *string `json:"filename,omitempty"`
4849
// Revision refers to a Git revision id which can be a branch name, a tag, or a commit hash.
4950
// +kubebuilder:default=main
5051
// +optional
5152
Revision *string `json:"revision,omitempty"`
53+
// AllowPatterns is a list of patterns; only files matching at least one of them will be downloaded.
54+
// +optional
55+
AllowPatterns []string `json:"allowPatterns,omitempty"`
56+
// IgnorePatterns is a list of patterns; files matching any of them will not be downloaded.
57+
// +optional
58+
IgnorePatterns []string `json:"ignorePatterns,omitempty"`
5259
}
5360

5461
// URIProtocol represents the protocol of the URI.

api/core/v1alpha1/zz_generated.deepcopy.go

Lines changed: 10 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

client-go/applyconfiguration/core/v1alpha1/modelhub.go

Lines changed: 26 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

config/crd/bases/llmaz.io_openmodels.yaml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,13 +110,25 @@ spec:
110110
description: ModelHub represents the model registry for model
111111
downloads.
112112
properties:
113+
allowPatterns:
114+
description: AllowPatterns refers to only files matching at
115+
least one pattern are downloaded.
116+
items:
117+
type: string
118+
type: array
113119
filename:
114120
description: |-
115121
Filename refers to a specified model file rather than the whole repo.
116122
This is helpful to download a specified GGUF model rather than downloading
117123
the whole repo which includes all kinds of quantized models.
118124
in the near future.
119125
type: string
126+
ignorePatterns:
127+
description: IgnorePatterns refers to files matching any of
128+
the patterns are not downloaded.
129+
items:
130+
type: string
131+
type: array
120132
modelID:
121133
description: |-
122134
ModelID refers to the model identifier on model hub,

llmaz/main.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,32 +17,39 @@
1717
import os
1818
from datetime import datetime
1919

20+
from llmaz.model_loader.constant import *
21+
2022
from llmaz.model_loader.objstore.objstore import model_download
2123
from llmaz.model_loader.model_hub.hub_factory import HubFactory
22-
from llmaz.model_loader.model_hub.huggingface import HUGGING_FACE
24+
from llmaz.model_loader.model_hub.huggingface import HUB_HUGGING_FACE
2325
from llmaz.util.logger import Logger
2426

25-
2627
if __name__ == "__main__":
27-
model_source_type = os.getenv("MODEL_SOURCE_TYPE")
28+
model_source_type = os.getenv(ENV_HUB_MODEL_SOURCE_TYPE)
2829
start_time = datetime.now()
2930

3031
if model_source_type == "modelhub":
31-
hub_name = os.getenv("MODEL_HUB_NAME", HUGGING_FACE)
32-
revision = os.getenv("REVISION")
33-
model_id = os.getenv("MODEL_ID")
34-
model_file_name = os.getenv("MODEL_FILENAME")
32+
hub_name = os.getenv(ENV_HUB_MODEL_HUB_NAME, HUB_HUGGING_FACE)
33+
revision = os.getenv(ENV_HUB_REVISION)
34+
model_id = os.getenv(ENV_HUB_MODEL_ID)
35+
model_file_name = os.getenv(ENV_HUB_MODEL_FILENAME)
36+
model_allow_patterns = os.getenv(ENV_HUB_MODEL_ALLOW_PATTERNS)
37+
model_ignore_patterns = os.getenv(ENV_HUB_MODEL_IGNORE_PATTERNS)
3538

3639
if not model_id:
3740
raise EnvironmentError(f"Environment variable '{model_id}' not found.")
38-
3941
hub = HubFactory.new(hub_name)
40-
hub.load_model(model_id, model_file_name, revision)
42+
model_allow_patterns_list, model_ignore_patterns_list = [], []
43+
if model_allow_patterns:
44+
model_allow_patterns_list = model_allow_patterns.split(',')
45+
if model_ignore_patterns:
46+
model_ignore_patterns_list = model_ignore_patterns.split(',')
47+
hub.load_model(model_id, model_file_name, revision, model_allow_patterns_list, model_ignore_patterns_list)
4148
elif model_source_type == "objstore":
42-
provider = os.getenv("PROVIDER")
43-
endpoint = os.getenv("ENDPOINT")
44-
bucket = os.getenv("BUCKET")
45-
src = os.getenv("MODEL_PATH")
49+
provider = os.getenv(ENV_OBJ_PROVIDER)
50+
endpoint = os.getenv(ENV_OBJ_ENDPOINT)
51+
bucket = os.getenv(ENV_OBJ_BUCKET)
52+
src = os.getenv(ENV_OBJ_MODEL_PATH)
4653

4754
model_download(provider=provider, endpoint=endpoint, bucket=bucket, src=src)
4855
else:

llmaz/model_loader/constant.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
MODEL_LOCAL_DIR = "/workspace/models/"
2+
HUB_HUGGING_FACE = "Huggingface"
3+
HUB_MODEL_SCOPE = "ModelScope"
4+
5+
ENV_HUB_MODEL_SOURCE_TYPE = "MODEL_SOURCE_TYPE"
6+
ENV_HUB_MODEL_HUB_NAME = "MODEL_HUB_NAME"
7+
ENV_HUB_REVISION = "REVISION"
8+
ENV_HUB_MODEL_ID = "MODEL_ID"
9+
ENV_HUB_MODEL_FILENAME = "MODEL_FILENAME"
10+
ENV_HUB_MODEL_ALLOW_PATTERNS = "MODEL_ALLOW_PATTERNS"
11+
ENV_HUB_MODEL_IGNORE_PATTERNS = "MODEL_IGNORE_PATTERNS"
12+
13+
ENV_OBJ_PROVIDER = "PROVIDER"
14+
ENV_OBJ_ENDPOINT = "ENDPOINT"
15+
ENV_OBJ_BUCKET = "BUCKET"
16+
ENV_OBJ_MODEL_PATH = "MODEL_PATH"

llmaz/model_loader/defaults.py

Lines changed: 0 additions & 1 deletion
This file was deleted.

llmaz/model_loader/model_hub/hub_factory.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,19 @@
1313
See the License for the specific language governing permissions and
1414
limitations under the License.
1515
"""
16-
16+
from llmaz.model_loader.constant import HUB_HUGGING_FACE, HUB_MODEL_SCOPE
1717
from llmaz.model_loader.model_hub.model_hub import ModelHub
18-
from llmaz.model_loader.model_hub.huggingface import HUGGING_FACE, Huggingface
19-
from llmaz.model_loader.model_hub.modelscope import MODEL_SCOPE, ModelScope
20-
18+
from llmaz.model_loader.model_hub.huggingface import Huggingface
19+
from llmaz.model_loader.model_hub.modelscope import ModelScope
2120

2221
SUPPORT_MODEL_HUBS = {
23-
HUGGING_FACE: Huggingface,
24-
MODEL_SCOPE: ModelScope,
22+
HUB_HUGGING_FACE: Huggingface,
23+
HUB_MODEL_SCOPE: ModelScope,
2524
}
2625

2726

2827
class HubFactory:
28+
2929
@classmethod
3030
def new(cls, hub_name: str) -> ModelHub:
3131
if hub_name not in SUPPORT_MODEL_HUBS.keys():

llmaz/model_loader/model_hub/huggingface.py

Lines changed: 23 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -17,69 +17,51 @@
1717
import concurrent.futures
1818
import os
1919

20-
from huggingface_hub import hf_hub_download, list_repo_files
20+
from huggingface_hub import snapshot_download
2121

22-
from llmaz.model_loader.defaults import MODEL_LOCAL_DIR
22+
from llmaz.model_loader.constant import MODEL_LOCAL_DIR, HUB_HUGGING_FACE
2323
from llmaz.model_loader.model_hub.model_hub import (
24-
HUGGING_FACE,
25-
MAX_WORKERS,
2624
ModelHub,
2725
)
2826
from llmaz.util.logger import Logger
2927
from llmaz.model_loader.model_hub.util import get_folder_total_size
3028

31-
from typing import Optional
29+
from typing import Optional, List
3230

3331

3432
class Huggingface(ModelHub):
3533
@classmethod
3634
def name(cls) -> str:
37-
return HUGGING_FACE
35+
return HUB_HUGGING_FACE
3836

3937
@classmethod
4038
def load_model(
41-
cls, model_id: str, filename: Optional[str], revision: Optional[str]
39+
cls,
40+
model_id: str,
41+
filename: Optional[str],
42+
revision: Optional[str],
43+
allow_patterns: Optional[List[str]],
44+
ignore_patterns: Optional[List[str]],
4245
) -> None:
4346
Logger.info(
4447
f"Start to download, model_id: {model_id}, filename: {filename}, revision: {revision}"
4548
)
4649

47-
if filename:
48-
hf_hub_download(
49-
repo_id=model_id,
50-
filename=filename,
51-
local_dir=MODEL_LOCAL_DIR,
52-
revision=revision,
53-
)
54-
file_size = os.path.getsize(MODEL_LOCAL_DIR + filename) / (1024**3)
55-
Logger.info(
56-
f"The total size of {MODEL_LOCAL_DIR + filename} is {file_size: .2f} GB"
57-
)
58-
return
59-
6050
local_dir = os.path.join(
61-
MODEL_LOCAL_DIR, f"models--{model_id.replace('/','--')}"
51+
MODEL_LOCAL_DIR, f"models--{model_id.replace('/', '--')}"
6252
)
6353

64-
# # TODO: Should we verify the download is finished?
65-
with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
66-
futures = []
67-
for file in list_repo_files(repo_id=model_id):
68-
# TODO: support version management, right now we didn't distinguish with them.
69-
futures.append(
70-
executor.submit(
71-
hf_hub_download,
72-
repo_id=model_id,
73-
filename=file,
74-
local_dir=local_dir,
75-
revision=revision,
76-
).add_done_callback(handle_completion)
77-
)
54+
if filename:
55+
allow_patterns.append(filename)
56+
local_dir = MODEL_LOCAL_DIR
57+
58+
snapshot_download(
59+
repo_id=model_id,
60+
revision=revision,
61+
local_dir=local_dir,
62+
allow_patterns=allow_patterns,
63+
ignore_patterns=ignore_patterns,
64+
)
7865

7966
total_size = get_folder_total_size(local_dir)
80-
Logger.info(f"The total size of {local_dir} is {total_size: .2f} GB")
81-
82-
83-
def handle_completion(future):
84-
filename = future.result()
85-
Logger.info(f"Download completed for {filename}")
67+
Logger.info(f"The total size of {local_dir} is {total_size: .2f} GB")

llmaz/model_loader/model_hub/model_hub.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,7 @@
1515
"""
1616

1717
from abc import ABC, abstractmethod
18-
from typing import Optional
19-
20-
MAX_WORKERS = 4
21-
HUGGING_FACE = "Huggingface"
22-
MODEL_SCOPE = "ModelScope"
18+
from typing import Optional, List
2319

2420

2521
class ModelHub(ABC):
@@ -31,6 +27,11 @@ def name(cls) -> str:
3127
@classmethod
3228
@abstractmethod
3329
def load_model(
34-
cls, model_id: str, filename: Optional[str], revision: Optional[str]
30+
cls,
31+
model_id: str,
32+
filename: Optional[str],
33+
revision: Optional[str],
34+
allow_patterns: Optional[List[str]],
35+
ignore_patterns: Optional[List[str]],
3536
) -> None:
3637
pass

0 commit comments

Comments
 (0)