Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1943469 Implement Strategy Pattern for I/O in Snowpark-Checkpoints-Validators #121

Draft
wants to merge 1 commit into
base: feature/SNOW-1928735/main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = ["EnvStrategy", "IOFileManager", "IODefaultStrategy"]

from snowflake.snowpark_checkpoints.io_utils.io_env_strategy import (
EnvStrategy,
)
from snowflake.snowpark_checkpoints.io_utils.io_default_strategy import (
IODefaultStrategy,
)
from snowflake.snowpark_checkpoints.io_utils.io_file_manager import (
IOFileManager,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import os

from typing import Optional

from snowflake.snowpark_checkpoints.io_utils import EnvStrategy


class IODefaultStrategy(EnvStrategy):
def mkdir(self, path: str, exist_ok=False) -> bool:
try:
os.makedirs(path, exist_ok=exist_ok)
return True
except Exception:
return False

def folder_exists(self, path: str) -> bool:
try:
return os.path.isdir(path)
except Exception:
return False

def file_exists(self, path: str) -> bool:
try:
return os.path.isfile(path)
except Exception:
return False

def write(self, file_path: str, file_content: str, overwrite: bool = True) -> bool:
try:
mode = "w" if overwrite else "x"
with open(file_path, mode) as file:
file.write(file_content)
return True
except Exception:
return False

def read(
self, file_path: str, mode: str = "r", encoding: str = None
) -> Optional[str]:
try:
with open(file_path, mode=mode, encoding=encoding) as file:
return file.read()
except Exception:
return None

def read_bytes(self, file_path: str) -> Optional[bytes]:
try:
with open(file_path, mode="rb") as f:
return f.read()
except Exception:
return None

def ls(self, path: str, recursive: bool = False) -> list[str]:
try:
return glob.glob(path, recursive=recursive)
except Exception:
return []

def getcwd(self) -> str:
try:
return os.getcwd()
except Exception:
return ""
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Optional


class EnvStrategy(ABC):

"""An abstract base class that defines methods for file and directory operations.

Subclasses should implement these methods to provide environment-specific behavior.
"""

@abstractmethod
def mkdir(self, path: str, exist_ok=False) -> bool:
"""Create a directory.

Args:
path: The name of the directory to create.
exist_ok: If False, an error is raised if the directory already exists.

Returns:
bool: True if the directory was created successfully, False otherwise.

"""

@abstractmethod
def folder_exists(self, path: str) -> bool:
"""Check if a folder exists.

Args:
path: The path to the folder.

Returns:
bool: True if the folder exists, False otherwise.

"""

@abstractmethod
def file_exists(self, path: str) -> bool:
"""Check if a file exists.

Args:
path: The path to the file.

Returns:
bool: True if the file exists, False otherwise.

"""

@abstractmethod
def write(self, file_path: str, file_content: str, overwrite: bool = True) -> bool:
"""Write content to a file.

Args:
file_path: The name of the file to write to.
file_content: The content to write to the file.
overwrite: If True, overwrite the file if it exists.

Returns:
bool: True if the file was written successfully, False otherwise.

"""

@abstractmethod
def read(
self, file_path: str, mode: str = "r", encoding: str = None
) -> Optional[str]:
"""Read content from a file.

Args:
file_path: The path to the file to read from.
mode: The mode in which to open the file.
encoding: The encoding to use for reading the file.

Returns:
Optional[str]: The content of the file, or None if an error occurred.

"""

@abstractmethod
def read_bytes(self, file_path: str) -> Optional[bytes]:
"""Read binary content from a file.

Args:
file_path: The path to the file to read from.

Returns:
Optional[BinaryIO]: The binary content of the file, or None if an error occurred.

"""

@abstractmethod
def ls(self, path: str, recursive: bool = False) -> list[str]:
"""List the contents of a directory.

Args:
path: The path to the directory.
recursive: If True, list the contents recursively.

Returns:
list[str]: A list of the contents of the directory.

"""

@abstractmethod
def getcwd(self) -> str:
"""Get the current working directory.

Returns:
str: The current working directory.

"""
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from snowflake.snowpark_checkpoints.io_utils import (
EnvStrategy,
IODefaultStrategy,
)
from snowflake.snowpark_checkpoints.singleton import Singleton


class IOFileManager(metaclass=Singleton):
def __init__(self, strategy: EnvStrategy = None):
self.strategy = strategy or IODefaultStrategy()

def mkdir(self, path: str, exist_ok=False) -> bool:
return self.strategy.mkdir(path, exist_ok)

def folder_exists(self, path: str) -> bool:
return self.strategy.folder_exists(path)

def file_exists(self, path: str) -> bool:
return self.strategy.file_exists(path)

def write(self, file_path: str, file_content: str, overwrite: bool = True) -> bool:
return self.strategy.write(file_path, file_content, overwrite)

def read(
self, file_path: str, mode: str = "r", encoding: str = None
) -> Optional[str]:
return self.strategy.read(file_path, mode, encoding)

def read_bytes(self, file_path: str) -> Optional[bytes]:
return self.strategy.read_bytes(file_path)

def ls(self, path: str, recursive: bool = False) -> list[str]:
return self.strategy.ls(path, recursive)

def getcwd(self) -> str:
return self.strategy.getcwd()

def set_strategy(self, strategy: EnvStrategy):
"""Set the strategy for file and directory operations.

Args:
strategy (EnvStrategy): The strategy to use for file and directory operations.

"""
self.strategy = strategy


def get_io_file_manager():
"""Get the singleton instance of IOFileManager.

Returns:
IOFileManager: The singleton instance of IOFileManager.

"""
return IOFileManager()
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,17 @@

from typing import Optional

from snowflake.snowpark_checkpoints.io_utils.io_file_manager import get_io_file_manager
from snowflake.snowpark_checkpoints.utils.constants import (
SNOWFLAKE_CHECKPOINT_CONTRACT_FILE_PATH_ENV_VAR,
)


# noinspection DuplicatedCode
def _get_checkpoint_contract_file_path() -> str:
return os.environ.get(SNOWFLAKE_CHECKPOINT_CONTRACT_FILE_PATH_ENV_VAR, os.getcwd())
return os.environ.get(
SNOWFLAKE_CHECKPOINT_CONTRACT_FILE_PATH_ENV_VAR, get_io_file_manager().getcwd()
)


# noinspection DuplicatedCode
Expand Down
Loading
Loading