Skip to content

Commit e9c3170

Browse files
MoSheikh and MoIMC authored
Add support for boolean expressions and quoted columns (#1286)
* Add support for boolean expressions and quoted columns
* Add AlwaysTrue & AlwaysFalse support plus tests
* Add test for quoted column
* Remove commented code

---------

Co-authored-by: Mohammad Sheikh <[email protected]>
1 parent 2778ec2 commit e9c3170

File tree

2 files changed: 51 additions and 16 deletions

pyiceberg/expressions/parser.py

Lines changed: 35 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@
2121
CaselessKeyword,
2222
DelimitedList,
2323
Group,
24+
MatchFirst,
2425
ParserElement,
2526
ParseResults,
2627
Suppress,
@@ -57,6 +58,7 @@
5758
StartsWith,
5859
)
5960
from pyiceberg.expressions.literals import (
61+
BooleanLiteral,
6062
DecimalLiteral,
6163
Literal,
6264
LongLiteral,
@@ -77,7 +79,9 @@
7779
NAN = CaselessKeyword("nan")
7880
LIKE = CaselessKeyword("like")
7981

80-
identifier = Word(alphas, alphanums + "_$").set_results_name("identifier")
82+
unquoted_identifier = Word(alphas, alphanums + "_$")
83+
quoted_identifier = Suppress('"') + unquoted_identifier + Suppress('"')
84+
identifier = MatchFirst([unquoted_identifier, quoted_identifier]).set_results_name("identifier")
8185
column = DelimitedList(identifier, delim=".", combine=False).set_results_name("column")
8286

8387
like_regex = r"(?P<valid_wildcard>(?<!\\)%$)|(?P<invalid_wildcard>(?<!\\)%)"
@@ -100,16 +104,18 @@ def _(result: ParseResults) -> Reference:
100104
string = sgl_quoted_string.set_results_name("raw_quoted_string")
101105
decimal = common.real().set_results_name("decimal")
102106
integer = common.signed_integer().set_results_name("integer")
103-
literal = Group(string | decimal | integer).set_results_name("literal")
104-
literal_set = Group(DelimitedList(string) | DelimitedList(decimal) | DelimitedList(integer)).set_results_name("literal_set")
107+
literal = Group(string | decimal | integer | boolean).set_results_name("literal")
108+
literal_set = Group(
109+
DelimitedList(string) | DelimitedList(decimal) | DelimitedList(integer) | DelimitedList(boolean)
110+
).set_results_name("literal_set")
105111

106112

107113
@boolean.set_parse_action
108-
def _(result: ParseResults) -> BooleanExpression:
114+
def _(result: ParseResults) -> Literal[bool]:
109115
if strtobool(result.boolean):
110-
return AlwaysTrue()
116+
return BooleanLiteral(True)
111117
else:
112-
return AlwaysFalse()
118+
return BooleanLiteral(False)
113119

114120

115121
@string.set_parse_action
@@ -265,14 +271,29 @@ def handle_or(result: ParseResults) -> Or:
265271
return Or(*result[0])
266272

267273

268-
boolean_expression = infix_notation(
269-
predicate,
270-
[
271-
(Suppress(NOT), 1, opAssoc.RIGHT, handle_not),
272-
(Suppress(AND), 2, opAssoc.LEFT, handle_and),
273-
(Suppress(OR), 2, opAssoc.LEFT, handle_or),
274-
],
275-
).set_name("expr")
274+
def handle_always_expression(result: ParseResults) -> BooleanExpression:
275+
# If the entire result is "true" or "false", return AlwaysTrue or AlwaysFalse
276+
expr = result[0]
277+
if isinstance(expr, BooleanLiteral):
278+
if expr.value:
279+
return AlwaysTrue()
280+
else:
281+
return AlwaysFalse()
282+
return result[0]
283+
284+
285+
boolean_expression = (
286+
infix_notation(
287+
predicate,
288+
[
289+
(Suppress(NOT), 1, opAssoc.RIGHT, handle_not),
290+
(Suppress(AND), 2, opAssoc.LEFT, handle_and),
291+
(Suppress(OR), 2, opAssoc.LEFT, handle_or),
292+
],
293+
)
294+
.set_name("expr")
295+
.add_parse_action(handle_always_expression)
296+
)
276297

277298

278299
def parse(expr: str) -> BooleanExpression:

tests/expressions/test_parser.py

Lines changed: 16 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -41,14 +41,28 @@
4141
)
4242

4343

44-
def test_true() -> None:
44+
def test_always_true() -> None:
4545
assert AlwaysTrue() == parser.parse("true")
4646

4747

48-
def test_false() -> None:
48+
def test_always_false() -> None:
4949
assert AlwaysFalse() == parser.parse("false")
5050

5151

52+
def test_quoted_column() -> None:
53+
assert EqualTo("foo", True) == parser.parse('"foo" = TRUE')
54+
55+
56+
def test_equals_true() -> None:
57+
assert EqualTo("foo", True) == parser.parse("foo = true")
58+
assert EqualTo("foo", True) == parser.parse("foo == TRUE")
59+
60+
61+
def test_equals_false() -> None:
62+
assert EqualTo("foo", False) == parser.parse("foo = false")
63+
assert EqualTo("foo", False) == parser.parse("foo == FALSE")
64+
65+
5266
def test_is_null() -> None:
5367
assert IsNull("foo") == parser.parse("foo is null")
5468
assert IsNull("foo") == parser.parse("foo IS NULL")

0 commit comments

Comments
 (0)