Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/5750.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Abstract `storage-proxy` storage types
14 changes: 0 additions & 14 deletions src/ai/backend/common/dto/storage/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,13 +190,6 @@ class HuggingFaceImportModelsReq(BaseRequestModel):
""",
examples=["default-minio", "s3-storage", "local-storage"],
)
bucket_name: str = Field(
description="""
Target bucket name within the storage for all models.
The bucket must exist and be writable by the service.
""",
examples=["models", "huggingface-models", "ai-models"],
)


class ReservoirImportModelsReq(BaseRequestModel):
Expand Down Expand Up @@ -228,13 +221,6 @@ class ReservoirImportModelsReq(BaseRequestModel):
""",
examples=["default-minio", "s3-storage", "local-storage"],
)
bucket_name: str = Field(
description="""
Target bucket name within the storage for all models.
The bucket must exist and be writable by the service.
""",
examples=["models", "huggingface-models", "ai-models"],
)


class DeleteObjectReq(BaseRequestModel):
Expand Down
2 changes: 0 additions & 2 deletions src/ai/backend/manager/services/artifact_revision/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,6 @@ async def import_revision(
],
registry_name=registry_data.name,
storage_name=storage_data.name,
bucket_name=storage_namespace.bucket,
)
)
case ArtifactRegistryType.RESERVOIR:
Expand All @@ -202,7 +201,6 @@ async def import_revision(
],
registry_name=registry_data.name,
storage_name=storage_data.name,
bucket_name=storage_namespace.bucket,
)
)
case _:
Expand Down
6 changes: 1 addition & 5 deletions src/ai/backend/storage/api/v1/registries/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
HuggingFaceService,
HuggingFaceServiceArgs,
)
from ai.backend.storage.services.storages.object_storage import ObjectStorageService

from ....utils import log_client_api_entry

Expand Down Expand Up @@ -143,7 +142,6 @@ async def import_models(
registry_name=body.parsed.registry_name,
models=body.parsed.models,
storage_name=body.parsed.storage_name,
bucket_name=body.parsed.bucket_name,
)

response = HuggingFaceImportModelsResponse(
Expand All @@ -161,8 +159,6 @@ def create_app(ctx: RootContext) -> web.Application:
app["ctx"] = ctx
app["prefix"] = "v1/registries/huggingface"

storage_service = ObjectStorageService(storage_configs=ctx.local_config.storages)

huggingface_registry_configs = dict(
(r.name, r.config)
for r in ctx.local_config.registries
Expand All @@ -171,7 +167,7 @@ def create_app(ctx: RootContext) -> web.Application:
huggingface_service = HuggingFaceService(
HuggingFaceServiceArgs(
background_task_manager=ctx.background_task_manager,
storage_service=storage_service,
storage_pool=ctx.storage_pool,
registry_configs=huggingface_registry_configs,
event_producer=ctx.event_producer,
)
Expand Down
3 changes: 1 addition & 2 deletions src/ai/backend/storage/api/v1/registries/reservoir.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ async def import_models(
registry_name=body.parsed.registry_name,
models=body.parsed.models,
storage_name=body.parsed.storage_name,
bucket_name=body.parsed.bucket_name,
)

return APIResponse.build(
Expand All @@ -74,7 +73,7 @@ def create_app(ctx: RootContext) -> web.Application:
ReservoirServiceArgs(
background_task_manager=ctx.background_task_manager,
event_producer=ctx.event_producer,
storage_configs=ctx.local_config.storages,
storage_pool=ctx.storage_pool,
reservoir_registry_configs=reservoir_registry_configs,
)
)
Expand Down
24 changes: 11 additions & 13 deletions src/ai/backend/storage/api/v1/storages/object_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,9 @@
UploadObjectReq,
)
from ai.backend.logging import BraceStyleAdapter
from ai.backend.storage.config.unified import (
ObjectStorageConfig,
)

from ....services.storages.object_storage import ObjectStorageService
from ....storages.base import StoragePool
from ....utils import log_client_api_entry

if TYPE_CHECKING:
Expand All @@ -42,13 +40,13 @@


class ObjectStorageAPIHandler:
_storage_configs: list[ObjectStorageConfig]
_storage_pool: StoragePool

def __init__(
self,
storage_configs: list[ObjectStorageConfig],
storage_pool: StoragePool,
) -> None:
self._storage_configs = storage_configs
self._storage_pool = storage_pool

@api_handler
async def upload_object(
Expand All @@ -69,7 +67,7 @@ async def upload_object(

await log_client_api_entry(log, "upload_object", req)

storage_service = ObjectStorageService(self._storage_configs)
storage_service = ObjectStorageService(self._storage_pool)

file_part = await file_reader.next()
while file_part and not getattr(file_part, "filename", None):
Expand Down Expand Up @@ -115,7 +113,7 @@ async def download_file(
bucket_name = path.parsed.bucket_name

await log_client_api_entry(log, "download_file", req)
storage_service = ObjectStorageService(self._storage_configs)
storage_service = ObjectStorageService(self._storage_pool)
download_stream = storage_service.stream_download(storage_name, bucket_name, filepath)

return APIStreamResponse(
Expand All @@ -141,7 +139,7 @@ async def presigned_upload_url(
bucket_name = path.parsed.bucket_name

await log_client_api_entry(log, "presigned_upload_url", req)
storage_service = ObjectStorageService(self._storage_configs)
storage_service = ObjectStorageService(self._storage_pool)
response = await storage_service.generate_presigned_upload_url(
storage_name, bucket_name, req.key
)
Expand All @@ -167,7 +165,7 @@ async def presigned_download_url(
bucket_name = path.parsed.bucket_name

await log_client_api_entry(log, "presigned_download_url", req)
storage_service = ObjectStorageService(self._storage_configs)
storage_service = ObjectStorageService(self._storage_pool)
response = await storage_service.generate_presigned_download_url(
storage_name, bucket_name, filepath
)
Expand All @@ -194,7 +192,7 @@ async def get_object_meta(

await log_client_api_entry(log, "get_object_meta", req)

storage_service = ObjectStorageService(self._storage_configs)
storage_service = ObjectStorageService(self._storage_pool)
response = await storage_service.get_object_info(storage_name, bucket_name, filepath)

return APIResponse.build(
Expand All @@ -218,7 +216,7 @@ async def delete_object(
bucket_name = path.parsed.bucket_name

await log_client_api_entry(log, "delete_object", req)
storage_service = ObjectStorageService(self._storage_configs)
storage_service = ObjectStorageService(self._storage_pool)

await storage_service.delete_object(storage_name, bucket_name, prefix)

Expand All @@ -233,7 +231,7 @@ def create_app(ctx: RootContext) -> web.Application:
app["prefix"] = "v1/storages/s3"

api_handler = ObjectStorageAPIHandler(
storage_configs=ctx.local_config.storages,
storage_pool=ctx.storage_pool,
)
app.router.add_route(
"GET", "/{storage_name}/buckets/{bucket_name}/object/meta", api_handler.get_object_meta
Expand Down
4 changes: 4 additions & 0 deletions src/ai/backend/storage/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
StoragePluginContext,
)
from .services.service import VolumeService
from .storages.base import StoragePool
from .types import VolumeInfo
from .volumes.abc import AbstractVolume
from .volumes.cephfs import CephFSVolume
Expand Down Expand Up @@ -117,6 +118,7 @@ class RootContext:
local_config: StorageProxyUnifiedConfig
dsn: str | None
volume_pool: VolumePool
storage_pool: StoragePool
event_producer: EventProducer
event_dispatcher: EventDispatcher
watcher: WatcherClient | None
Expand All @@ -132,6 +134,7 @@ def __init__(
etcd: AsyncEtcd,
*,
volume_pool: VolumePool,
storage_pool: StoragePool,
background_task_manager: BackgroundTaskManager,
event_producer: EventProducer,
event_dispatcher: EventDispatcher,
Expand All @@ -158,6 +161,7 @@ def __init__(
}
self.metric_registry = metric_registry
self.volume_pool = volume_pool
self.storage_pool = storage_pool
self.background_task_manager = background_task_manager

async def __aenter__(self) -> None:
Expand Down
13 changes: 13 additions & 0 deletions src/ai/backend/storage/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,3 +391,16 @@ def error_code(cls) -> ErrorCode:
operation=ErrorOperation.CREATE,
error_detail=ErrorDetail.INTERNAL_ERROR,
)


class NotImplementedAPI(BackendAIError, web.HTTPBadRequest):
error_type = "https://api.backend.ai/probs/storage/api/not-implemented"
error_title = "API Not Implemented"

@classmethod
def error_code(cls) -> ErrorCode:
return ErrorCode(
domain=ErrorDomain.STORAGE_PROXY,
operation=ErrorOperation.GENERIC,
error_detail=ErrorDetail.NOT_IMPLEMENTED,
)
1 change: 1 addition & 0 deletions src/ai/backend/storage/migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ async def check_and_upgrade(
event_dispatcher=event_dispatcher,
watcher=None,
volume_pool=None, # type: ignore[arg-type]
storage_pool=None, # type: ignore[arg-type]
background_task_manager=None, # type: ignore[arg-type]
)

Expand Down
10 changes: 10 additions & 0 deletions src/ai/backend/storage/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ async def server_main(
from .bgtask.registry import BgtaskHandlerRegistryCreator
from .context import RootContext
from .migration import check_latest
from .storages.base import AbstractStorage, StoragePool
from .storages.object_storage import ObjectStorage
from .volumes.pool import VolumePool
from .watcher import WatcherClient

Expand Down Expand Up @@ -211,13 +213,21 @@ async def server_main(
valkey_client=valkey_client,
server_id=local_config.storage_proxy.node_id,
)

# Create StoragePool for object storage
storages: dict[str, AbstractStorage] = {}
for config in local_config.storages:
storages[config.name] = ObjectStorage(config)
storage_pool = StoragePool(storages)

ctx = RootContext(
pid=os.getpid(),
node_id=local_config.storage_proxy.node_id,
pidx=pidx,
local_config=local_config,
etcd=etcd,
volume_pool=volume_pool,
storage_pool=storage_pool,
background_task_manager=bgtask_mgr,
event_producer=event_producer,
event_dispatcher=event_dispatcher,
Expand Down
Loading
Loading