-
Notifications
You must be signed in to change notification settings - Fork 56
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
dadef19
commit ea1f3ec
Showing
5 changed files
with
192 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
"""Defines the class for None types so they can be serialized in cbor. Check the gottas section at the bottom of the main readme for details""" | ||
from typing import Any, Optional | ||
|
||
|
||
class NoneType: | ||
""" | ||
A None type that can be serialized in cbor | ||
""" | ||
@staticmethod | ||
def parse_value(value: Optional[Any]) -> Any: | ||
if value is None: | ||
return NoneType() | ||
return value | ||
|
||
|
||
def replace_none(obj: Any) -> Any: | ||
""" | ||
Recursively replace None values with NoneType instances in any structure. | ||
Args: | ||
obj: The object to be scanned for None types | ||
returns: | ||
the same object but will all None types replaced for NonType objects | ||
""" | ||
if obj is None: | ||
return NoneType() | ||
elif isinstance(obj, dict): | ||
return {key: replace_none(value) for key, value in obj.items()} | ||
elif isinstance(obj, list): | ||
return [replace_none(item) for item in obj] | ||
elif isinstance(obj, tuple): | ||
return tuple(replace_none(item) for item in obj) | ||
elif isinstance(obj, set): | ||
return {replace_none(item) for item in obj} | ||
else: | ||
return obj |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
from unittest import main, IsolatedAsyncioTestCase | ||
|
||
from surrealdb.connections.async_ws import AsyncWsSurrealConnection | ||
from surrealdb.data.types.record_id import RecordID | ||
from surrealdb.data.types.none import NoneType | ||
import os | ||
|
||
|
||
class TestAsyncWsSurrealConnectionNone(IsolatedAsyncioTestCase): | ||
|
||
async def asyncSetUp(self): | ||
self.url = "ws://localhost:8000/rpc" | ||
self.password = "root" | ||
self.username = "root" | ||
self.vars_params = { | ||
"username": self.username, | ||
"password": self.password, | ||
} | ||
self.database_name = "test_db" | ||
self.namespace = "test_ns" | ||
self.connection = AsyncWsSurrealConnection(self.url) | ||
|
||
# Sign in and select DB | ||
await self.connection.signin(self.vars_params) | ||
await self.connection.use(namespace=self.namespace, database=self.database_name) | ||
|
||
# Cleanup | ||
await self.connection.query("DELETE person;") | ||
|
||
async def test_none(self): | ||
schema = """ | ||
DEFINE TABLE person SCHEMAFULL TYPE NORMAL; | ||
DEFINE FIELD name ON person TYPE string; | ||
DEFINE FIELD age ON person TYPE option<int>; | ||
""" | ||
await self.connection.query(schema) | ||
outcome = await self.connection.create( | ||
"person:john", | ||
{"name": "John", "age": None} | ||
) | ||
record_check = RecordID(table_name="person", identifier="john") | ||
self.assertEqual(record_check, outcome["id"]) | ||
self.assertEqual("John", outcome["name"]) | ||
self.assertEqual(None, outcome.get("age")) | ||
|
||
outcome = await self.connection.create( | ||
"person:dave", | ||
{"name": "Dave", "age": 34} | ||
) | ||
record_check = RecordID(table_name="person", identifier="dave") | ||
self.assertEqual(record_check, outcome["id"]) | ||
self.assertEqual("Dave", outcome["name"]) | ||
self.assertEqual(34, outcome["age"]) | ||
|
||
outcome = await self.connection.query("SELECT * FROM person") | ||
self.assertEqual(2, len(outcome)) | ||
|
||
await self.connection.query("DELETE person;") | ||
await self.connection.close() | ||
|
||
if __name__ == "__main__": | ||
main() |