Skip to content

Commit

Permalink
Merge pull request #445 from weaviate/fix-where-filter-error-message
Browse files Browse the repository at this point in the history
move client-side filter check and error from gql.get to batch.delete_objects
  • Loading branch information
tsmith023 authored Aug 25, 2023
2 parents 9c00952 + f456f77 commit 2b762f7
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 45 deletions.
45 changes: 45 additions & 0 deletions integration/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import weaviate
from weaviate import Tenant
from weaviate.gql.filter import VALUE_ARRAY_TYPES, WHERE_OPERATORS

UUID = Union[str, uuid.UUID]

Expand Down Expand Up @@ -131,6 +132,50 @@ def test_delete_objects(client: weaviate.Client):
assert "four" in names
assert "five" in names

with pytest.raises(ValueError) as error:
with client.batch as batch:
batch.delete_objects(
"Test",
where={
"path": ["name"],
"operator": "ContainsAny",
"valueText": ["four"],
},
)
assert (
error.value.args[0]
== f"Operator 'ContainsAny' is not supported for value type 'valueText'. Supported value types are: {VALUE_ARRAY_TYPES}"
)

where = {
"path": ["name"],
"valueTextArray": ["four"],
}
with pytest.raises(ValueError) as error:
with client.batch as batch:
batch.delete_objects(
"Test",
where=where,
)
assert (
error.value.args[0] == f"Where filter is missing required field `operator`. Given: {where}"
)

with pytest.raises(ValueError) as error:
with client.batch as batch:
batch.delete_objects(
"Test",
where={
"path": ["name"],
"operator": "Wrong",
"valueText": ["four"],
},
)
assert (
error.value.args[0]
== f"Operator Wrong is not allowed. Allowed operators are: {WHERE_OPERATORS}"
)


@pytest.mark.parametrize("from_object_uuid", [uuid.uuid4(), str(uuid.uuid4()), uuid.uuid4().hex])
@pytest.mark.parametrize("to_object_uuid", [uuid.uuid4().hex, uuid.uuid4(), str(uuid.uuid4())])
Expand Down
38 changes: 13 additions & 25 deletions test/gql/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Where,
Ask,
WHERE_OPERATORS,
VALUE_TYPES,
)


Expand Down Expand Up @@ -551,14 +552,11 @@ def test___init__(self):
content_error_msg = lambda dt: f"Where filter is expected to be type dict but is {dt}"
content_key_error_msg = "Filter is missing required fields `path` or `operands`. Given: "
path_key_error = "Filter is missing required field `operator`. Given: "
dtype_no_value_error_msg = "Filter is missing required field 'value<TYPE>': "
dtype_no_value_error_msg = "'value<TYPE>' field is either missing or incorrect: "
dtype_multiple_value_error_msg = "Multiple fields 'value<TYPE>' are not supported: "
operator_error_msg = (
lambda op: f"Operator {op} is not allowed. Allowed operators are: {', '.join(WHERE_OPERATORS)}"
)
contains_operator_value_type_mismatch_msg = (
lambda op, vt: f"Operator {op} requires a value of type {vt}List. Given value type: {vt}"
)
geo_operator_value_type_mismatch_msg = (
lambda op, vt: f"Operator {op} requires a value of type valueGeoRange. Given value type: {vt}"
)
Expand Down Expand Up @@ -603,18 +601,6 @@ def test___init__(self):
Where({"path": "some_path", "operator": "NotValid"})
check_error_message(self, error, operator_error_msg("NotValid"))

with self.assertRaises(ValueError) as error:
Where({"path": "some_path", "operator": "ContainsAll", "valueString": "A"})
check_error_message(
self, error, contains_operator_value_type_mismatch_msg("ContainsAll", "valueString")
)

with self.assertRaises(ValueError) as error:
Where({"path": "some_path", "operator": "ContainsAny", "valueInt": 1})
check_error_message(
self, error, contains_operator_value_type_mismatch_msg("ContainsAny", "valueInt")
)

with self.assertRaises(ValueError) as error:
Where({"path": "some_path", "operator": "WithinGeoRange", "valueBoolean": True})
check_error_message(
Expand Down Expand Up @@ -795,15 +781,17 @@ def test___str__(self):
self, error, value_is_list_err(["test-2021-02-02", "test-2021-02-03"], "valueDate")
)

# test_filter = {
# "path": ["name"],
# "operator": "Equal",
# "valueText": "😃",
# }
# result = str(Where(test_filter))
# self.assertEqual(
# 'where: {path: ["name"] operator: Equal valueText: "\\ud83d\\ude03"} ', str(result)
# )
test_filter = {
"path": ["name"],
"operator": "Equal",
"valueWrong": "whatever",
}
with self.assertRaises(ValueError) as error:
str(Where(test_filter))
assert (
error.exception.args[0]
== f"'value<TYPE>' field is either missing or incorrect: {test_filter}. Valid values are: {VALUE_TYPES}."
)


class TestAskFilter(unittest.TestCase):
Expand Down
14 changes: 13 additions & 1 deletion weaviate/batch/crud_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from weaviate.connect import Connection
from weaviate.data.replication import ConsistencyLevel
from weaviate.gql.filter import _find_value_type
from weaviate.gql.filter import _find_value_type, VALUE_ARRAY_TYPES, WHERE_OPERATORS
from weaviate.types import UUID
from .requests import BatchRequest, ObjectsBatchRequest, ReferenceBatchRequest, BatchResponse
from ..cluster import Cluster
Expand Down Expand Up @@ -1814,6 +1814,18 @@ def _clean_delete_objects_where(where: dict) -> dict:
"""
py_value_type = _find_value_type(where)
weaviate_value_type = _convert_value_type(py_value_type)
if "operator" not in where:
raise ValueError("Where filter is missing required field `operator`." f" Given: {where}")
if where["operator"] not in WHERE_OPERATORS:
raise ValueError(
f"Operator {where['operator']} is not allowed. "
f"Allowed operators are: {WHERE_OPERATORS}"
)
operator = where["operator"]
if "Contains" in operator and "Array" not in weaviate_value_type:
raise ValueError(
f"Operator '{operator}' is not supported for value type '{weaviate_value_type}'. Supported value types are: {VALUE_ARRAY_TYPES}"
)
where[weaviate_value_type] = where.pop(py_value_type)
return where

Expand Down
31 changes: 12 additions & 19 deletions weaviate/gql/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,19 @@
from weaviate.util import get_vector, _sanitize_str

VALUE_LIST_TYPES = {
"valueStringArray",
"valueTextArray",
"valueIntArray",
"valueNumberArray",
"valueBooleanArray",
"valueDateArray",
"valueStringList",
"valueTextList",
"valueIntList",
"valueNumberList",
"valueBooleanList",
"valueDateList",
}

# Array-typed `value<TYPE>` field names. Each entry is the `...Array`
# counterpart of a `...List` name in VALUE_LIST_TYPES, so the two sets must
# stay in one-to-one correspondence. `valueDateArray` is included so that a
# date-array filter is recognised as a valid value type and is listed in the
# "Supported value types" error message for the `Contains*` operators.
VALUE_ARRAY_TYPES = {
    "valueStringArray",
    "valueTextArray",
    "valueIntArray",
    "valueNumberArray",
    "valueBooleanArray",
    "valueDateArray",
}

VALUE_PRIMITIVE_TYPES = {
Expand All @@ -41,7 +42,8 @@
"valueGeoRange",
}

VALUE_TYPES = VALUE_LIST_TYPES.union(VALUE_PRIMITIVE_TYPES)
ALL_VALUE_TYPES = VALUE_LIST_TYPES.union(VALUE_ARRAY_TYPES).union(VALUE_PRIMITIVE_TYPES)
VALUE_TYPES = VALUE_ARRAY_TYPES.union(VALUE_PRIMITIVE_TYPES)

WHERE_OPERATORS = [
"And",
Expand Down Expand Up @@ -796,15 +798,6 @@ def _parse_filter(self, content: dict) -> None:
self.value_type = _find_value_type(content)
self.value = content[self.value_type]

if (
self.operator in ["ContainsAny", "ContainsAll"]
and self.value_type not in VALUE_LIST_TYPES
):
raise ValueError(
f"Operator {self.operator} requires a value of type {self.value_type}List. "
f"Given value type: {self.value_type}"
)

if self.operator == "WithinGeoRange" and self.value_type != "valueGeoRange":
raise ValueError(
f"Operator {self.operator} requires a value of type valueGeoRange. "
Expand Down Expand Up @@ -1147,11 +1140,11 @@ def _find_value_type(content: dict) -> str:
If missing required fields.
"""

value_type = VALUE_TYPES & set(content.keys())
value_type = ALL_VALUE_TYPES & set(content.keys())

if len(value_type) == 0:
raise ValueError(
f"Filter is missing required field 'value<TYPE>': {content}. Valid values are: {VALUE_TYPES}."
f"'value<TYPE>' field is either missing or incorrect: {content}. Valid values are: {VALUE_TYPES}."
)
if len(value_type) != 1:
raise ValueError(f"Multiple fields 'value<TYPE>' are not supported: {content}")
Expand Down

0 comments on commit 2b762f7

Please sign in to comment.