@@ -57,6 +57,7 @@ def mock_context(shared_data):
5757 "mod_exp" : modified_exp ,
5858 "mod_haz" : modified_haz ,
5959 "mod_imp" : modified_imp ,
60+ "date" : pd .Timestamp (2023 ),
6061 }
6162
6263
@@ -78,9 +79,9 @@ def test_not_from_factory_warning(mock_context):
7879@pytest .mark .parametrize (
7980 "input_date,expected" ,
8081 [
81- (2023 , datetime . date (2023 , 1 , 1 )),
82- ("2023-01-01" , datetime . date (2023 , 1 , 1 )),
83- (datetime .date (2023 , 1 , 1 ), datetime . date (2023 , 1 , 1 )),
82+ (2023 , pd . Timestamp (2023 , 1 , 1 )),
83+ ("2023-01-01" , pd . Timestamp (2023 , 1 , 1 )),
84+ (datetime .date (2023 , 1 , 1 ), pd . Timestamp (2023 , 1 , 1 )),
8485 ],
8586)
8687def test_init_valid_dates (mock_context , input_date , expected ):
@@ -106,7 +107,8 @@ def test_init_invalid_date_format(mock_context):
106107
107108def test_init_invalid_date_type (mock_context ):
108109 with pytest .raises (
109- TypeError , match = r"date_arg must be an int, str, or datetime.date"
110+ TypeError ,
111+ match = r"date_arg must be an int, str, datetime.date or pandas.Timestamp" ,
110112 ):
111113 Snapshot .from_triplet (exposure = mock_context ["exp" ], hazard = mock_context ["haz" ], impfset = mock_context ["imp" ], date = 2023.5 ) # type: ignore
112114
@@ -116,7 +118,7 @@ def test_properties(mock_context):
116118 exposure = mock_context ["exp" ],
117119 hazard = mock_context ["haz" ],
118120 impfset = mock_context ["imp" ],
119- date = 2023 ,
121+ date = mock_context [ "date" ] ,
120122 )
121123
122124 # Check that it's a deep copy (new reference)
@@ -129,14 +131,15 @@ def test_properties(mock_context):
129131 pd .testing .assert_frame_equal (snapshot .exposure .gdf , mock_context ["exp" ].gdf )
130132 assert snapshot .hazard .haz_type == mock_context ["haz" ].haz_type
131133 assert snapshot .impfset == mock_context ["imp" ]
134+ assert snapshot .date == mock_context ["date" ]
132135
133136
134137def test_reference (mock_context ):
135138 snapshot = Snapshot .from_triplet (
136139 exposure = mock_context ["exp" ],
137140 hazard = mock_context ["haz" ],
138141 impfset = mock_context ["imp" ],
139- date = 2023 ,
142+ date = mock_context [ "date" ] ,
140143 ref_only = True ,
141144 )
142145
@@ -146,18 +149,13 @@ def test_reference(mock_context):
146149 assert snapshot .impfset is mock_context ["imp" ]
147150 assert snapshot .measure is None
148151
149- # Check data equality
150- pd .testing .assert_frame_equal (snapshot .exposure .gdf , mock_context ["exp" ].gdf )
151- assert snapshot .hazard .haz_type == mock_context ["haz" ].haz_type
152- assert snapshot .impfset == mock_context ["imp" ]
153-
154152
155153def test_apply_measure (mock_context ):
156154 snapshot = Snapshot .from_triplet (
157155 exposure = mock_context ["exp" ],
158156 hazard = mock_context ["haz" ],
159157 impfset = mock_context ["imp" ],
160- date = 2023 ,
158+ date = mock_context [ "date" ] ,
161159 )
162160 new_snapshot = snapshot .apply_measure (mock_context ["measure" ])
163161
@@ -166,3 +164,4 @@ def test_apply_measure(mock_context):
166164 assert new_snapshot .exposure == mock_context ["mod_exp" ]
167165 assert new_snapshot .hazard == mock_context ["mod_haz" ]
168166 assert new_snapshot .impfset == mock_context ["mod_imp" ]
167+ assert new_snapshot .date == mock_context ["date" ]
0 commit comments