Skip to content

Commit

Permalink
save complex number to the sqlite DB
Browse files Browse the repository at this point in the history
  • Loading branch information
tmichela committed Jan 22, 2025
1 parent ca4c0f3 commit 6b276c4
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 8 deletions.
19 changes: 14 additions & 5 deletions damnit/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import h5py

from .backend.db import BlobTypes, DamnitDB
from .util import blob2complex


# This is a copy of damnit.ctxsupport.ctxrunner.DataType, purely so that we can
Expand Down Expand Up @@ -166,6 +167,8 @@ def summary(self):
# after creating the VariableData object.
raise RuntimeError(f"Could not find value for '{self.name}' in p{self.proposal}, r{self.name}")
else:
if isinstance(result[0], bytes) and BlobTypes.identify(result[0]) is BlobTypes.complex:
return blob2complex(result[0])
return result[0]

def __repr__(self):
Expand Down Expand Up @@ -385,13 +388,19 @@ def table(self, with_titles=False) -> "pd.DataFrame":
if "comment" not in df:
df.insert(3, "comment", None)

# Convert PNG blobs into a string
def image2str(value):
if isinstance(value, bytes) and BlobTypes.identify(value) is BlobTypes.png:
return "<image>"
# interpret blobs
def blob2type(value):
if isinstance(value, bytes):
match BlobTypes.identify(value):
case BlobTypes.png | BlobTypes.numpy:
return "<image>"
case BlobTypes.complex:
return blob2complex(value)
case BlobTypes.unknown | _:
return "<unknown>"
else:
return value
df = df.applymap(image2str)
df = df.applymap(blob2type)

# Use the full variable titles
if with_titles:
Expand Down
6 changes: 6 additions & 0 deletions damnit/backend/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from ..definitions import UPDATE_TOPIC
from .user_variables import UserEditableVariable
from ..util import complex2blob

DB_NAME = Path('runs.sqlite')

Expand Down Expand Up @@ -71,6 +72,7 @@ class ReducedData:
class BlobTypes(Enum):
png = 'png'
numpy = 'numpy'
complex = 'complex'
unknown = 'unknown'

@classmethod
Expand All @@ -79,6 +81,8 @@ def identify(cls, blob: bytes):
return cls.png
elif blob.startswith(b'\x93NUMPY'):
return cls.numpy
elif blob.startswith(b'_DAMNIT_COMPLEX_'):
return cls.complex

return cls.unknown

Expand Down Expand Up @@ -324,6 +328,8 @@ def set_variable(self, proposal: int, run: int, name: str, reduced):
if variable["value"] is None:
for key in variable:
variable[key] = None
elif isinstance(variable["value"], complex):
variable["value"] = complex2blob(variable["value"])

variable["proposal"] = proposal
variable["run"] = run
Expand Down
2 changes: 1 addition & 1 deletion damnit/backend/extract_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def add_to_db(reduced_data, db: DamnitDB, proposal, run):
db.ensure_run(proposal, run, start_time=start_time.value)

for name, reduced in reduced_data.items():
if not isinstance(reduced.value, (int, float, str, bytes)):
if not isinstance(reduced.value, (int, float, str, bytes, complex)):
raise TypeError(f"Unsupported type for database: {type(reduced.value)}")

db.set_variable(proposal, run, name, reduced)
Expand Down
16 changes: 16 additions & 0 deletions damnit/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,21 @@ class StatusbarStylesheet(Enum):
ERROR = "QStatusBar {background: red; color: white; font-weight: bold;}"


def complex2blob(data: complex) -> bytes:
# convert complex to bytes
real = data.real.hex()
imag = data.imag.hex()
return f"_DAMNIT_COMPLEX_{real}_{imag}".encode()


def blob2complex(data: bytes) -> complex:
# convert bytes to complex
real, _, imag = data[16:].decode().partition("_")
real = float.fromhex(real)
imag = float.fromhex(imag)
return complex(real, imag)


def timestamp2str(timestamp):
if timestamp is None or pd.isna(timestamp):
return None
Expand Down Expand Up @@ -45,6 +60,7 @@ def bool_to_numeric(data):
def fix_data_for_plotting(data):
return bool_to_numeric(make_finite(data))


def delete_variable(db, name):
# Remove from the database
db.delete_variable(name)
Expand Down
9 changes: 7 additions & 2 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,16 @@ def test_variable_data(mock_db_with_data, monkeypatch):

# Insert a DataSet variable
dataset_code = """
from damnit_ctx import Variable
from damnit_ctx import Cell, Variable
import xarray as xr
@Variable(title="Dataset")
def dataset(run):
return xr.Dataset(data_vars={
data = xr.Dataset(data_vars={
"foo": xr.DataArray([1, 2, 3]),
"bar/baz": xr.DataArray([1+2j, 3-4j, 5+6j]),
})
return Cell(data, summary_value=data['bar/baz'][2])
"""
(db_dir / "context.py").write_text(dedent(dataset_code))
extract_mock_run(1)
Expand Down Expand Up @@ -131,6 +132,10 @@ def dataset(run):
# Datasets have a internal _damnit attribute that should be removed
assert len(dataset.attrs) == 0

summary = rv["dataset"].summary()
assert isinstance(summary, complex)
assert summary == complex(5, 6)

fig = rv['plotly_mc_plotface'].read()
assert isinstance(fig, PlotlyFigure)
assert fig == px.bar(x=["a", "b", "c"], y=[1, 3, 2])
Expand Down
18 changes: 18 additions & 0 deletions tests/test_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pytest
from damnit.util import complex2blob, blob2complex


@pytest.mark.parametrize("value", [
1+2j,
0+0j,
-1.5-3.7j,
2.5+0j,
0+3.1j,
float('inf')+0j,
complex(float('inf'), -float('inf')),
])
def test_complex_blob_conversion(value):
# Test that converting complex -> blob -> complex preserves the value
blob = complex2blob(value)
result = blob2complex(blob)
assert result == value

0 comments on commit 6b276c4

Please sign in to comment.