Skip to content

Commit

Permalink
fixing None type serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
maxwellflitton committed Feb 25, 2025
1 parent dadef19 commit ea1f3ec
Show file tree
Hide file tree
Showing 5 changed files with 192 additions and 3 deletions.
86 changes: 85 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

# surrealdb.py

The official SurrealDB SDK for Python.
The official SurrealDB SDK for Python. If you find that the python SDK is not behaving exactly how you expect, please check the gottas section at the bottom of this README to see if your problem can be quickly solved.

## Documentation

Expand Down Expand Up @@ -184,3 +184,87 @@ To exit the terminal session merely execute the following command:
exit
```
And there we have it, our tests are passing.

## Gottas

Due to quirks either unearthed by python or how the `cbor` serialization library handles data, there are some slight quirks that you might not expect. This section clarifies these quirks, why they exist, and how to handle them.

### None types

Python's `cbor` library serializes `None` types automatically before they have a chance of reaching our encoder. While we are looking into ways to override default serialization methods, we have our own data type that denotes a `None` type which can be shown below:

```python
from surrealdb import AsyncSurreal, NoneType, RecordID

vars = {
"username": "root",
"password": "root"
}

schema = """
DEFINE TABLE person SCHEMAFULL TYPE NORMAL;
DEFINE FIELD name ON person TYPE string;
DEFINE FIELD age ON person TYPE option<int>;
"""
connection = AsyncSurreal("ws://localhost:8000/rpc")
await connection.query(schema)

outcome = await connection.create(
"person:john",
{"name": "John", "age": None}
)
record_check = RecordID(table_name="person", identifier="john")
self.assertEqual(record_check, outcome["id"])
self.assertEqual("John", outcome["name"])
# below we need a .get because fields with None are currently not serialized
# a .get gives the same result
self.assertEqual(None, outcome.get("age"))
```

It must be noted that the field that had a `None` is not returned as a field at all. Using a `.get()` function will give the same effect as if the field is there but it is a `None`. Using a `outcome["age"]` will throw an error. We can also see how it works when the field is not None with the following code:

```python
outcome = await connection.create(
"person:dave",
{"name": "Dave", "age": 34}
)
record_check = RecordID(table_name="person", identifier="dave")
self.assertEqual(record_check, outcome["id"])
self.assertEqual("Dave", outcome["name"])
self.assertEqual(34, outcome["age"])
```
Here we can see that the age is returned because it is not `None`. There is a slight performance cost for this `None` safety as the client needs to recursively go through the data passed into the client replacing `None` with `NoneType`. If you do not want this performance cost, you can disable the check but you have to ensure that all `None` types you pass into the client are replaced yourself. You can us a `NoneType` via the following code:

```python
from surrealdb import AsyncSurreal, NoneType, RecordID
import os

vars = {
"username": "root",
"password": "root"
}

schema = """
DEFINE TABLE person SCHEMAFULL TYPE NORMAL;
DEFINE FIELD name ON person TYPE string;
DEFINE FIELD age ON person TYPE option<int>;
"""
connection = AsyncSurreal("ws://localhost:8000/rpc")
await connection.query(schema)

# bypass the recursive check to replace None with NoneType
os.environ["SURREALDB_BYPASS_CHECKS"] = "true"

outcome = await connection.create(
"person:john",
{"name": "John", "age": None}
)
record_check = RecordID(table_name="person", identifier="john")
self.assertEqual(record_check, outcome["id"])
self.assertEqual("John", outcome["name"])
# below we need a .get because fields with None are currently not serialized
# a .get gives the same result
self.assertEqual(None, outcome.get("age"))
```

Here we set the environment variable `SURREALDB_BYPASS_CHECKS` to `"true"`.
1 change: 1 addition & 0 deletions src/surrealdb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from surrealdb.data.types.range import Range
from surrealdb.data.types.record_id import RecordID
from surrealdb.data.types.datetime import IsoDateTimeWrapper
from surrealdb.data.types.none import NoneType


class AsyncSurrealDBMeta(type):
Expand Down
11 changes: 9 additions & 2 deletions src/surrealdb/data/cbor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,16 @@
from surrealdb.data.types.range import BoundIncluded, BoundExcluded, Range
from surrealdb.data.types.record_id import RecordID
from surrealdb.data.types.table import Table
from surrealdb.data.types.none import NoneType, replace_none
import os


@cbor2.shareable_encoder
def default_encoder(encoder, obj):
if isinstance(obj, GeometryPoint):
tagged = cbor2.CBORTag(constants.TAG_GEOMETRY_POINT, obj.get_coordinates())

elif isinstance(obj, GeometryLine):
tagged = cbor2.CBORTag(constants.TAG_GEOMETRY_LINE, obj.get_coordinates())

elif isinstance(obj, GeometryPolygon):
tagged = cbor2.CBORTag(constants.TAG_GEOMETRY_POLYGON, obj.get_coordinates())

Expand All @@ -50,6 +50,9 @@ def default_encoder(encoder, obj):
elif isinstance(obj, RecordID):
tagged = cbor2.CBORTag(constants.TAG_RECORD_ID, [obj.table_name, obj.id])

elif isinstance(obj, NoneType):
tagged = cbor2.CBORTag(constants.TAG_NONE, None)

elif isinstance(obj, Table):
tagged = cbor2.CBORTag(constants.TAG_TABLE_NAME, obj.table_name)

Expand Down Expand Up @@ -134,6 +137,10 @@ def tag_decoder(decoder, tag, shareable_index=None):


def encode(obj):
none_check_flag = os.environ.get("SURREALDB_BYPASS_CHECKS")
if none_check_flag is not None and none_check_flag.upper() == "TRUE":
return cbor2.dumps(obj, default=default_encoder, timezone=timezone.utc)
obj = replace_none(obj)
return cbor2.dumps(obj, default=default_encoder, timezone=timezone.utc)


Expand Down
35 changes: 35 additions & 0 deletions src/surrealdb/data/types/none.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Defines the class for None types so they can be serialized in cbor. Check the gottas section at the bottom of the main readme for details"""
from typing import Any, Optional


class NoneType:
"""
A None type that can be serialized in cbor
"""
@staticmethod
def parse_value(value: Optional[Any]) -> Any:
if value is None:
return NoneType()
return value


def replace_none(obj: Any) -> Any:
"""
Recursively replace None values with NoneType instances in any structure.
Args:
obj: The object to be scanned for None types
returns:
the same object but will all None types replaced for NonType objects
"""
if obj is None:
return NoneType()
elif isinstance(obj, dict):
return {key: replace_none(value) for key, value in obj.items()}
elif isinstance(obj, list):
return [replace_none(item) for item in obj]
elif isinstance(obj, tuple):
return tuple(replace_none(item) for item in obj)
elif isinstance(obj, set):
return {replace_none(item) for item in obj}
else:
return obj
62 changes: 62 additions & 0 deletions tests/unit_tests/data_types/test_none.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from unittest import main, IsolatedAsyncioTestCase

from surrealdb.connections.async_ws import AsyncWsSurrealConnection
from surrealdb.data.types.record_id import RecordID
from surrealdb.data.types.none import NoneType
import os


class TestAsyncWsSurrealConnectionNone(IsolatedAsyncioTestCase):

async def asyncSetUp(self):
self.url = "ws://localhost:8000/rpc"
self.password = "root"
self.username = "root"
self.vars_params = {
"username": self.username,
"password": self.password,
}
self.database_name = "test_db"
self.namespace = "test_ns"
self.connection = AsyncWsSurrealConnection(self.url)

# Sign in and select DB
await self.connection.signin(self.vars_params)
await self.connection.use(namespace=self.namespace, database=self.database_name)

# Cleanup
await self.connection.query("DELETE person;")

async def test_none(self):
schema = """
DEFINE TABLE person SCHEMAFULL TYPE NORMAL;
DEFINE FIELD name ON person TYPE string;
DEFINE FIELD age ON person TYPE option<int>;
"""
await self.connection.query(schema)
outcome = await self.connection.create(
"person:john",
{"name": "John", "age": None}
)
record_check = RecordID(table_name="person", identifier="john")
self.assertEqual(record_check, outcome["id"])
self.assertEqual("John", outcome["name"])
self.assertEqual(None, outcome.get("age"))

outcome = await self.connection.create(
"person:dave",
{"name": "Dave", "age": 34}
)
record_check = RecordID(table_name="person", identifier="dave")
self.assertEqual(record_check, outcome["id"])
self.assertEqual("Dave", outcome["name"])
self.assertEqual(34, outcome["age"])

outcome = await self.connection.query("SELECT * FROM person")
self.assertEqual(2, len(outcome))

await self.connection.query("DELETE person;")
await self.connection.close()

if __name__ == "__main__":
main()

0 comments on commit ea1f3ec

Please sign in to comment.