Skip to content

Commit 22f388e

Browse files
authored
Merge pull request #339 from aiverify-foundation/apgw-update-schema-nested-filepaths
Support nested artifact filepath and updated schema and misc fixes
2 parents 6c17fad + 0661ec3 commit 22f388e

File tree

9 files changed

+149
-77
lines changed

9 files changed

+149
-77
lines changed
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from pathlib import PurePath
2+
import io
3+
import re
4+
import hashlib
5+
6+
7+
class InvalidFilename(Exception):
    """Signal that a supplied filename failed validation and cannot be used."""
10+
11+
12+
def check_valid_filename(filename: str) -> bool:
    """Return True when ``filename`` is acceptable as a relative file path.

    A name is accepted when it is non-empty, starts with an alphanumeric
    character (which rules out a leading ``/`` or ``.``) and does not contain
    ``..`` anywhere (which rules out parent-directory references). Nested
    paths such as ``images/plot.png`` are allowed.

    Args:
        filename: Candidate file name or relative path.

    Returns:
        True if the name passes the checks above, otherwise False.
    """
    # An empty name has no first character to inspect; report it as invalid
    # instead of letting filename[0] raise IndexError.
    if not filename or not filename[0].isalnum():
        return False
    return ".." not in filename
20+
21+
22+
def append_filename(filename: str, append_name: str) -> str:
    """Insert ``append_name`` between the stem and the suffix of ``filename``.

    Only the final path component is used: any parent directories in
    ``filename`` are not part of the returned name.
    """
    parsed = PurePath(filename)
    pieces = (parsed.stem, append_name, parsed.suffix)
    return "".join(pieces)
25+
26+
27+
def get_suffix(filename: str) -> str:
    """Return the final extension of ``filename`` in lowercase, dot included.

    An empty string is returned when the name has no extension.
    """
    extension = PurePath(filename).suffix
    return extension.lower()
29+
30+
31+
def get_stem(filename: str) -> str:
    """Return the last path component of ``filename`` without its final suffix."""
    parsed = PurePath(filename)
    return parsed.stem
33+
34+
35+
def get_file_digest(contents: io.BytesIO):
36+
contents.seek(0)
37+
return hashlib.file_digest(contents, "sha256").digest()
38+
39+
40+
def sanitize_filename(filename: str) -> str:
    """Strip disallowed characters from ``filename`` and return the result.

    Only ASCII letters, digits, ``.`` and ``/`` are kept; every other
    character (spaces, ``_``, ``-`` and so on) is removed.

    Args:
        filename: File name or relative path to clean up.

    Returns:
        The filename with all disallowed characters removed.

    Raises:
        InvalidFilename: If ``filename`` is empty, does not start with an
            alphanumeric character, or contains a ``..`` path segment once
            the disallowed characters have been removed.
    """
    # Check for emptiness first so that filename[0] cannot raise IndexError.
    if not filename or not filename[0].isalnum():
        raise InvalidFilename("The first character of the filename must be alphanumeric.")

    # Use regex to replace invalid characters
    sanitized = re.sub(r'[^a-zA-Z0-9./]', '', filename)

    # Both "." and "/" survive sanitizing, and removing characters can even
    # create a ".." segment (e.g. "a/.%./b"), so reject parent-directory
    # references explicitly.
    if ".." in sanitized.split("/"):
        raise InvalidFilename("The filename must not contain parent directory references.")
    return sanitized

aiverify-apigw/aiverify_apigw/lib/filestore.py

Lines changed: 39 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import os
2-
import io
3-
import re
4-
import hashlib
52
import shutil
6-
from pathlib import Path, PurePath
3+
from pathlib import Path
4+
import urllib.parse
5+
urllib.parse.uses_relative.append("s3")
6+
urllib.parse.uses_netloc.append("s3")
77
from urllib.parse import urljoin
88

99
from .s3 import MyS3
@@ -23,18 +23,13 @@ class FileStoreError(Exception):
2323
pass
2424

2525

26-
class InvalidFilename(Exception):
27-
"""Raised when the filename is invalid"""
28-
pass
29-
30-
31-
def sanitize_filename(filename: str) -> str:
32-
if not filename[0].isalnum():
33-
raise InvalidFilename("The first character of the filename must be alphanumeric.")
34-
35-
# Use regex to replace invalid characters
36-
sanitized = re.sub(r'[^a-zA-Z0-9.]', '', filename)
37-
return sanitized
26+
def check_relative_to_base(base_path: Path | str, filepath: str) -> bool:
    """Return True when ``filepath``, joined onto ``base_path``, stays inside it.

    Args:
        base_path: Base location. A ``Path`` is treated as a local directory;
            a string is treated as a full URL prefix (used for S3).
        filepath: Path to check, interpreted relative to ``base_path``.

    Returns:
        True if the joined location is the base itself or lies beneath it,
        otherwise False.
    """
    logger.debug(f"check_relative_to_base: {base_path} -> {filepath}")
    if isinstance(base_path, Path):
        # Resolve the base as well: comparing a resolved candidate against an
        # unresolved (relative or symlinked) base would reject valid paths.
        resolved_base = base_path.resolve()
        return resolved_base.joinpath(filepath).resolve().is_relative_to(resolved_base)

    # for s3 must be full url
    # Compare against the prefix with a trailing slash so that a sibling such
    # as ".../abcd/x" is not accepted for the base ".../abc".
    base_url = base_path if base_path.endswith("/") else f"{base_path}/"
    return urljoin(base_url, filepath).startswith(base_url)
3833

3934

4035
def get_base_data_dir() -> Path | str:
@@ -80,27 +75,6 @@ def is_s3(path: Path | str) -> bool:
8075
raise InvalidFileStore(f"Invalid path: {path}")
8176

8277

83-
def check_valid_filename(filename: str):
84-
return PurePath(filename).stem.isalnum
85-
86-
87-
def append_filename(filename: str, append_name: str) -> str:
88-
fpath = PurePath(filename)
89-
return fpath.stem + append_name + fpath.suffix
90-
91-
92-
def get_suffix(filename: str) -> str:
93-
return PurePath(filename).suffix.lower()
94-
95-
96-
def get_stem(filename: str) -> str:
97-
return PurePath(filename).stem
98-
99-
100-
def get_file_digest(contents: io.BytesIO):
101-
contents.seek(0)
102-
return hashlib.file_digest(contents, "sha256").digest()
103-
10478
# def absolute_plugin_base_path(source: str) -> Path | str:
10579
# if isinstance(base_plugin_dir, Path):
10680
# return base_plugin_dir.joinpath(source).resolve()
@@ -143,14 +117,28 @@ def get_plugin_component_folder(gid: str, component_type: str) -> Path | str:
143117
return urljoin(plugin_path, f"{component_type}/")
144118

145119

120+
# Names excluded when a plugin folder is copied with shutil.copytree:
# virtual environments, build/output folders, caches and compiled bytecode.
plugin_ignore_patten = shutil.ignore_patterns(
    ".venv",
    "venv",
    "output",
    "node_modules",
    "build",
    "temp",
    "__pycache__",
    ".pytest_cache",
    ".cache",  # the comma matters: without it this merges with "*.pyc"
    "*.pyc",
)
132+
133+
146134
def save_plugin(gid: str, source_dir: Path):
147135
folder = get_plugin_folder(gid)
148136
logger.debug(f"Copy plugin folder from {source_dir} to {folder}")
149137
if isinstance(folder, Path):
150138
if folder.exists():
151139
shutil.rmtree(folder)
152140
folder.mkdir(parents=True, exist_ok=True)
153-
shutil.copytree(source_dir, folder, dirs_exist_ok=True)
141+
shutil.copytree(source_dir, folder, dirs_exist_ok=True, ignore=plugin_ignore_patten)
154142
elif s3 is not None:
155143
# folder is s3 prefix
156144
if s3.check_s3_prefix_exists(folder):
@@ -201,10 +189,17 @@ def get_artifacts_folder(test_result_id: str):
201189

202190

203191
def save_artifact(test_result_id: str, filename: str, data: bytes):
204-
filename = sanitize_filename(filename)
192+
# validate input
193+
if not test_result_id.isalnum():
194+
raise FileStoreError(f"Invalid test result id {test_result_id}")
205195
folder = get_artifacts_folder(test_result_id)
196+
if not check_relative_to_base(folder, filename):
197+
raise FileStoreError(f"Invalid filename {filename}")
206198
if isinstance(folder, Path):
207199
filepath = folder.joinpath(filename)
200+
# create the nested parent directory if it does not exist yet
201+
if not filepath.parent.exists():
202+
filepath.parent.mkdir(parents=True, exist_ok=True)
208203
with open(filepath, "wb") as fp:
209204
fp.write(data)
210205
return filepath
@@ -215,10 +210,13 @@ def save_artifact(test_result_id: str, filename: str, data: bytes):
215210

216211

217212
def get_artifact(test_result_id: str, filename: str):
218-
filename = sanitize_filename(filename)
213+
if not test_result_id.isalnum():
214+
raise FileStoreError(f"Invalid test result id {test_result_id}")
219215
folder = get_artifacts_folder(test_result_id)
216+
if not check_relative_to_base(folder, filename):
217+
raise FileStoreError(f"Invalid filename {filename}")
220218
if isinstance(folder, Path):
221-
filepath = folder.joinpath(filename)
219+
filepath = folder.joinpath(filename).resolve()
222220
with open(filepath, "rb") as fp:
223221
data = fp.read()
224222
return data

aiverify-apigw/aiverify_apigw/lib/plugin_store.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from pathlib import Path
22
import json
33
import tomllib
4+
from typing import List
45

56
from .schemas_utils import read_and_validate, plugin_schema, algorithm_schema
67
from ..lib.syntax_checker import validate_python_script
@@ -62,8 +63,9 @@ def delete_plugin(cls, gid: str):
6263
def scan_stock_plugins(cls):
6364
logger.info(f"Scanning stock plugins in folder {str(cls.stock_plugin_folder)}..")
6465
cls.delete_all_plugins() # remove all current plugins first
66+
plugins: List[PluginModel] = []
6567
for plugin_dir in cls.stock_plugin_folder.iterdir():
66-
if not plugin_dir.is_dir():
68+
if not plugin_dir.is_dir() or not plugin_dir.name[0].isalnum() or plugin_dir.name == "user_defined_files":
6769
continue
6870
logger.debug(f"Scanning directory {plugin_dir.name}")
6971
try:
@@ -72,13 +74,23 @@ def scan_stock_plugins(cls):
7274
logger.debug(f"Invalid plugin: {e}")
7375
continue
7476
try:
75-
cls.scan_plugin_directory(plugin_dir)
77+
plugin = cls.scan_plugin_directory(plugin_dir)
78+
if plugin:
79+
plugins.append(plugin)
7680
except Exception as e:
7781
logger.warning(f"Error saving plugin in directory {plugin_dir.name}: {e}")
7882
with SessionLocal() as session:
79-
stmt = select(func.count("*")).select_from(PluginModel)
80-
count = session.scalar(stmt)
81-
logger.info(f"Finished scanning stock plugins. {count} plugins found")
83+
# stmt = select(func.count("*")).select_from(PluginModel)
84+
# count = session.scalar(stmt)
85+
stmt = select(PluginModel)
86+
my_plugins = list(session.scalars(stmt))
87+
logger.info(f"Finished scanning stock plugins. {len(my_plugins)} plugins found")
88+
for plugin in my_plugins:
89+
logger.info(f"Stock plugin: gid {plugin.gid}, version {plugin.version}, name {plugin.name}")
90+
if plugin.algorithms and len(plugin.algorithms) > 0:
91+
logger.info(f" Number of algoritms: {len(plugin.algorithms)}")
92+
for algo in plugin.algorithms:
93+
logger.info(f" Algorithm: cid {algo.cid}, version {algo.version}, name {algo.name}")
8294

8395
@classmethod
8496
def check_plugin_registry(cls):
@@ -167,7 +179,8 @@ def scan_plugin_directory(cls, folder: Path):
167179
if "project" not in pyproject_data or "name" not in pyproject_data["project"]:
168180
logger.debug(f"Algorithm folder {algopath.name} has invalid pyproject.toml")
169181
continue
170-
project_name = pyproject_data["project"]["name"]
182+
# TODO: is this the best way to get algorithm folder?
183+
project_name = pyproject_data["project"]["name"].replace("-", "_")
171184
sub_path = algopath.joinpath(project_name)
172185
module_name = project_name
173186
meta_path = sub_path.joinpath("algo.meta.json")
@@ -185,9 +198,9 @@ def scan_plugin_directory(cls, folder: Path):
185198
script_path = algopath.joinpath(f"{meta.cid}.py")
186199
if not script_path.exists():
187200
script_path = algopath.joinpath(f"algo.py")
188-
if script_path.exists(): # if script exists
201+
if script_path.exists(): # if script exists
189202
if not validate_python_script(script_path):
190-
logger.warning(f"algorithm {cid} script is not valid")
203+
logger.warning(f"algorithm {meta.cid} script is not valid")
191204
continue
192205

193206
# validate requirements.txt
@@ -221,7 +234,7 @@ def scan_plugin_directory(cls, folder: Path):
221234
input_schema=json.dumps(input_schema).encode("utf-8"),
222235
output_schema=json.dumps(output_schema).encode("utf-8"),
223236
algo_dir=algopath.relative_to(folder).as_posix(),
224-
language="python", # fixed to python first. To support other languages in future
237+
language="python", # fixed to python first. To support other languages in future
225238
script=script_path.name,
226239
module_name=module_name,
227240
)
@@ -239,6 +252,7 @@ def scan_plugin_directory(cls, folder: Path):
239252
session.commit()
240253

241254
fs_save_plugin(plugin_meta.gid, folder)
255+
return plugin
242256

243257
@classmethod
244258
def validate_plugin_directory(cls, folder: Path) -> PluginMeta:
@@ -288,7 +302,8 @@ def validate_plugin_directory(cls, folder: Path) -> PluginMeta:
288302
if "project" not in pyproject_data or "name" not in pyproject_data["project"]:
289303
logger.debug(f"Algorithm folder {algopath.name} has invalid pyproject.toml")
290304
continue
291-
project_name = pyproject_data["project"]["name"]
305+
# TODO: is this the best way to get algorithm folder?
306+
project_name = pyproject_data["project"]["name"].replace("-", "_")
292307
sub_path = algopath.joinpath(project_name)
293308
meta_path = sub_path.joinpath("algo.meta.json")
294309
if not meta_path.exists():

aiverify-apigw/aiverify_apigw/lib/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import mimetypes
2-
from .filestore import get_suffix
2+
from .file_utils import get_suffix
33

44

55
def guess_mimetype_from_filename(filename: str):

aiverify-apigw/aiverify_apigw/routers/test_result_router.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212
from ..lib.logging import logger
1313
from ..lib.database import get_db_session
1414
from ..lib.constants import TestDatasetFileType, TestDatasetStatus, TestModelMode, TestModelStatus
15-
from ..lib.filestore import save_artifact, get_artifact, get_suffix
15+
from ..lib.filestore import save_artifact, get_artifact
1616
from ..lib.utils import guess_mimetype_from_filename
17+
from ..lib.file_utils import get_suffix, check_valid_filename
1718
from ..schemas import TestResult, TestResultOutput, TestResultUpdate
1819
from ..schemas.load_examples import test_result_examples
1920
from ..models import AlgorithmModel, TestModelModel, TestResultModel, TestDatasetModel, TestArtifactModel
@@ -37,10 +38,10 @@ async def _save_test_result(session: Session, test_result: TestResult, artifact_
3738
)
3839
algorithm = session.scalar(stmt)
3940
if algorithm is None:
40-
raise HTTPException(status_code=400, detail=f"Algorithm {test_result.gid} not found")
41+
raise HTTPException(status_code=400, detail=f"Algorithm not found: gid: {test_result.gid}, cid: {test_result.cid}")
4142

4243
now = datetime.now()
43-
test_arguments = test_result.test_arguments
44+
test_arguments = test_result.testArguments
4445

4546
# validate output
4647
output_schema = json.loads(algorithm.output_schema.decode("utf-8"))
@@ -133,8 +134,8 @@ async def _save_test_result(session: Session, test_result: TestResult, artifact_
133134
test_dataset=test_dataset,
134135
ground_truth_dataset=ground_truth_dataset,
135136
ground_truth=test_arguments.groundTruth,
136-
start_time=test_result.start_time,
137-
time_taken=test_result.time_taken,
137+
start_time=test_result.startTime,
138+
time_taken=test_result.timeTaken,
138139
algo_arguments=json.dumps(test_arguments.algorithmArgs).encode('utf-8'),
139140
output=json.dumps(test_result.output).encode('utf-8'),
140141
created_at=now,
@@ -152,6 +153,9 @@ async def _save_test_result(session: Session, test_result: TestResult, artifact_
152153
if filename not in artifact_set:
153154
logger.warn(f"Unable to find artifact filename {filename} in uploaded files, skipping")
154155
continue
156+
if not check_valid_filename(filename):
157+
logger.warn(f"Invalid artifact filename {filename}, skipping")
158+
raise HTTPException(status_code=400, detail=f"Invalid artifact filename in result output: {filename}")
155159
artifact_file = artifact_set[filename]
156160
data = artifact_file.file.read()
157161
save_artifact(test_result_id, filename, data)
@@ -271,12 +275,12 @@ async def upload_zip_file(
271275
for zip_info in zip_infos:
272276
if zip_info.filename != result_filename and zip_info.filename != foldername:
273277
with zip_ref.open(zip_info) as artifact_file:
274-
filename = zip_info.filename[len(foldername):] # remove foldername
278+
filename = zip_info.filename[len(foldername):] # remove foldername
275279
artifact_set[filename] = UploadFile(
276280
filename=filename,
277281
file=io.BytesIO(artifact_file.read())
278-
)
279-
282+
)
283+
280284
logger.debug(f"artifact_set: {artifact_set}")
281285

282286
if isinstance(results_data, dict):
@@ -311,7 +315,7 @@ async def upload_zip_file(
311315
raise HTTPException(status_code=500, detail="Internal server error")
312316

313317

314-
@router.get("/{test_result_id}/artifacts/{filename}")
318+
@router.get("/{test_result_id}/artifacts/{filename:path}")
315319
async def get_test_result_artifact(
316320
test_result_id: str,
317321
filename: str,
@@ -320,6 +324,7 @@ async def get_test_result_artifact(
320324
"""
321325
Endpoint to retrieve an artifact file by test_result_id and filename.
322326
"""
327+
logger.debug(f"get_test_result_artifact: {test_result_id}, {filename}")
323328
try:
324329
stmt = (
325330
select(TestArtifactModel)

aiverify-apigw/aiverify_apigw/schemas/examples/test_result_examples.json

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
"gid": "aiverify.stock.fairness_metrics_toolbox_for_classification",
44
"version": "0.9.0",
55
"cid": "fairness_metrics_toolbox_for_classification",
6-
"start_time": "2024-07-24T09:20:24.822881",
7-
"time_taken": 0,
8-
"test_arguments": {
6+
"startTime": "2024-07-24T09:20:24.822881",
7+
"timeTaken": 0,
8+
"testArguments": {
99
"testDataset": "file:///examples/data/sample_bc_credit_data.sav",
1010
"modelFile": "file:///examples/model/sample_bc_credit_sklearn_linear.LogisticRegression.sav",
1111
"groundTruthDataset": "file:///examples/data/sample_bc_credit_data.sav",

0 commit comments

Comments
 (0)