@@ -46,166 +46,165 @@ def connection_string():
46
46
47
47
48
48
def test_constructor (connection_string ):
49
- memory = PostgresMemoryStore (connection_string , 2 , 1 , 5 )
50
- assert memory ._connection_pool is not None
49
+ with PostgresMemoryStore (connection_string , 2 , 1 , 5 ) as memory :
50
+ assert memory ._connection_pool is not None
51
51
52
52
53
53
@pytest .mark .asyncio
54
54
async def test_create_and_does_collection_exist (connection_string ):
55
- memory = PostgresMemoryStore (connection_string , 2 , 1 , 5 )
56
- await memory .create_collection ("test_collection" )
57
- result = await memory .does_collection_exist ("test_collection" )
58
- assert result is not None
55
+ with PostgresMemoryStore (connection_string , 2 , 1 , 5 ) as memory :
56
+ await memory .create_collection ("test_collection" )
57
+ result = await memory .does_collection_exist ("test_collection" )
58
+ assert result is not None
59
59
60
60
61
61
@pytest .mark .asyncio
62
62
async def test_get_collections (connection_string ):
63
- memory = PostgresMemoryStore (connection_string , 2 , 1 , 5 )
64
-
65
- try :
66
- await memory .create_collection ("test_collection" )
67
- result = await memory .get_collections ()
68
- assert "test_collection" in result
69
- except PoolTimeout :
70
- pytest .skip ("PoolTimeout exception raised, skipping test." )
63
+ with PostgresMemoryStore (connection_string , 2 , 1 , 5 ) as memory :
64
+ try :
65
+ await memory .create_collection ("test_collection" )
66
+ result = await memory .get_collections ()
67
+ assert "test_collection" in result
68
+ except PoolTimeout :
69
+ pytest .skip ("PoolTimeout exception raised, skipping test." )
71
70
72
71
73
72
@pytest .mark .asyncio
74
73
async def test_delete_collection (connection_string ):
75
- memory = PostgresMemoryStore (connection_string , 2 , 1 , 5 )
76
- try :
77
- await memory .create_collection ("test_collection" )
74
+ with PostgresMemoryStore (connection_string , 2 , 1 , 5 ) as memory :
75
+ try :
76
+ await memory .create_collection ("test_collection" )
78
77
79
- result = await memory .get_collections ()
80
- assert "test_collection" in result
78
+ result = await memory .get_collections ()
79
+ assert "test_collection" in result
81
80
82
- await memory .delete_collection ("test_collection" )
83
- result = await memory .get_collections ()
84
- assert "test_collection" not in result
85
- except PoolTimeout :
86
- pytest .skip ("PoolTimeout exception raised, skipping test." )
81
+ await memory .delete_collection ("test_collection" )
82
+ result = await memory .get_collections ()
83
+ assert "test_collection" not in result
84
+ except PoolTimeout :
85
+ pytest .skip ("PoolTimeout exception raised, skipping test." )
87
86
88
87
89
88
@pytest .mark .asyncio
90
89
async def test_does_collection_exist (connection_string ):
91
- memory = PostgresMemoryStore (connection_string , 2 , 1 , 5 )
92
- try :
93
- await memory .create_collection ("test_collection" )
94
- result = await memory .does_collection_exist ("test_collection" )
95
- assert result is True
96
- except PoolTimeout :
97
- pytest .skip ("PoolTimeout exception raised, skipping test." )
90
+ with PostgresMemoryStore (connection_string , 2 , 1 , 5 ) as memory :
91
+ try :
92
+ await memory .create_collection ("test_collection" )
93
+ result = await memory .does_collection_exist ("test_collection" )
94
+ assert result is True
95
+ except PoolTimeout :
96
+ pytest .skip ("PoolTimeout exception raised, skipping test." )
98
97
99
98
100
99
@pytest .mark .asyncio
101
100
async def test_upsert_and_get (connection_string , memory_record1 ):
102
- memory = PostgresMemoryStore (connection_string , 2 , 1 , 5 )
103
- try :
104
- await memory .create_collection ("test_collection" )
105
- await memory .upsert ("test_collection" , memory_record1 )
106
- result = await memory .get ("test_collection" , memory_record1 ._id , with_embedding = True )
107
- assert result is not None
108
- assert result ._id == memory_record1 ._id
109
- assert result ._text == memory_record1 ._text
110
- assert result ._timestamp == memory_record1 ._timestamp
111
- for i in range (len (result ._embedding )):
112
- assert result ._embedding [i ] == memory_record1 ._embedding [i ]
113
- except PoolTimeout :
114
- pytest .skip ("PoolTimeout exception raised, skipping test." )
101
+ with PostgresMemoryStore (connection_string , 2 , 1 , 5 ) as memory :
102
+ try :
103
+ await memory .create_collection ("test_collection" )
104
+ await memory .upsert ("test_collection" , memory_record1 )
105
+ result = await memory .get ("test_collection" , memory_record1 ._id , with_embedding = True )
106
+ assert result is not None
107
+ assert result ._id == memory_record1 ._id
108
+ assert result ._text == memory_record1 ._text
109
+ assert result ._timestamp == memory_record1 ._timestamp
110
+ for i in range (len (result ._embedding )):
111
+ assert result ._embedding [i ] == memory_record1 ._embedding [i ]
112
+ except PoolTimeout :
113
+ pytest .skip ("PoolTimeout exception raised, skipping test." )
115
114
116
115
117
116
@pytest .mark .asyncio
118
117
async def test_upsert_batch_and_get_batch (connection_string , memory_record1 , memory_record2 ):
119
- memory = PostgresMemoryStore (connection_string , 2 , 1 , 5 )
120
- try :
121
- await memory .create_collection ("test_collection" )
122
- await memory .upsert_batch ("test_collection" , [memory_record1 , memory_record2 ])
123
-
124
- results = await memory .get_batch (
125
- "test_collection" ,
126
- [memory_record1 ._id , memory_record2 ._id ],
127
- with_embeddings = True ,
128
- )
129
- assert len (results ) == 2
130
- assert results [0 ]._id in [memory_record1 ._id , memory_record2 ._id ]
131
- assert results [1 ]._id in [memory_record1 ._id , memory_record2 ._id ]
132
- except PoolTimeout :
133
- pytest .skip ("PoolTimeout exception raised, skipping test." )
118
+ with PostgresMemoryStore (connection_string , 2 , 1 , 5 ) as memory :
119
+ try :
120
+ await memory .create_collection ("test_collection" )
121
+ await memory .upsert_batch ("test_collection" , [memory_record1 , memory_record2 ])
122
+
123
+ results = await memory .get_batch (
124
+ "test_collection" ,
125
+ [memory_record1 ._id , memory_record2 ._id ],
126
+ with_embeddings = True ,
127
+ )
128
+ assert len (results ) == 2
129
+ assert results [0 ]._id in [memory_record1 ._id , memory_record2 ._id ]
130
+ assert results [1 ]._id in [memory_record1 ._id , memory_record2 ._id ]
131
+ except PoolTimeout :
132
+ pytest .skip ("PoolTimeout exception raised, skipping test." )
134
133
135
134
136
135
@pytest .mark .asyncio
137
136
async def test_remove (connection_string , memory_record1 ):
138
- memory = PostgresMemoryStore (connection_string , 2 , 1 , 5 )
139
- try :
140
- await memory .create_collection ("test_collection" )
141
- await memory .upsert ("test_collection" , memory_record1 )
137
+ with PostgresMemoryStore (connection_string , 2 , 1 , 5 ) as memory :
138
+ try :
139
+ await memory .create_collection ("test_collection" )
140
+ await memory .upsert ("test_collection" , memory_record1 )
142
141
143
- result = await memory .get ("test_collection" , memory_record1 ._id , with_embedding = True )
144
- assert result is not None
142
+ result = await memory .get ("test_collection" , memory_record1 ._id , with_embedding = True )
143
+ assert result is not None
145
144
146
- await memory .remove ("test_collection" , memory_record1 ._id )
147
- with pytest .raises (ServiceResourceNotFoundError ):
148
- await memory .get ("test_collection" , memory_record1 ._id , with_embedding = True )
149
- except PoolTimeout :
150
- pytest .skip ("PoolTimeout exception raised, skipping test." )
145
+ await memory .remove ("test_collection" , memory_record1 ._id )
146
+ with pytest .raises (ServiceResourceNotFoundError ):
147
+ await memory .get ("test_collection" , memory_record1 ._id , with_embedding = True )
148
+ except PoolTimeout :
149
+ pytest .skip ("PoolTimeout exception raised, skipping test." )
151
150
152
151
153
152
@pytest .mark .asyncio
154
153
async def test_remove_batch (connection_string , memory_record1 , memory_record2 ):
155
- memory = PostgresMemoryStore (connection_string , 2 , 1 , 5 )
156
- try :
157
- await memory .create_collection ("test_collection" )
158
- await memory .upsert_batch ("test_collection" , [memory_record1 , memory_record2 ])
159
- await memory .remove_batch ("test_collection" , [memory_record1 ._id , memory_record2 ._id ])
160
- with pytest .raises (ServiceResourceNotFoundError ):
161
- _ = await memory .get ("test_collection" , memory_record1 ._id , with_embedding = True )
154
+ with PostgresMemoryStore (connection_string , 2 , 1 , 5 ) as memory :
155
+ try :
156
+ await memory .create_collection ("test_collection" )
157
+ await memory .upsert_batch ("test_collection" , [memory_record1 , memory_record2 ])
158
+ await memory .remove_batch ("test_collection" , [memory_record1 ._id , memory_record2 ._id ])
159
+ with pytest .raises (ServiceResourceNotFoundError ):
160
+ _ = await memory .get ("test_collection" , memory_record1 ._id , with_embedding = True )
162
161
163
- with pytest .raises (ServiceResourceNotFoundError ):
164
- _ = await memory .get ("test_collection" , memory_record2 ._id , with_embedding = True )
165
- except PoolTimeout :
166
- pytest .skip ("PoolTimeout exception raised, skipping test." )
162
+ with pytest .raises (ServiceResourceNotFoundError ):
163
+ _ = await memory .get ("test_collection" , memory_record2 ._id , with_embedding = True )
164
+ except PoolTimeout :
165
+ pytest .skip ("PoolTimeout exception raised, skipping test." )
167
166
168
167
169
168
@pytest .mark .asyncio
170
169
async def test_get_nearest_match (connection_string , memory_record1 , memory_record2 ):
171
- memory = PostgresMemoryStore (connection_string , 2 , 1 , 5 )
172
- try :
173
- await memory .create_collection ("test_collection" )
174
- await memory .upsert_batch ("test_collection" , [memory_record1 , memory_record2 ])
175
- test_embedding = memory_record1 .embedding .copy ()
176
- test_embedding [0 ] = test_embedding [0 ] + 0.01
177
-
178
- result = await memory .get_nearest_match (
179
- "test_collection" , test_embedding , min_relevance_score = 0.0 , with_embedding = True
180
- )
181
- assert result is not None
182
- assert result [0 ]._id == memory_record1 ._id
183
- assert result [0 ]._text == memory_record1 ._text
184
- assert result [0 ]._timestamp == memory_record1 ._timestamp
185
- for i in range (len (result [0 ]._embedding )):
186
- assert result [0 ]._embedding [i ] == memory_record1 ._embedding [i ]
187
- except PoolTimeout :
188
- pytest .skip ("PoolTimeout exception raised, skipping test." )
170
+ with PostgresMemoryStore (connection_string , 2 , 1 , 5 ) as memory :
171
+ try :
172
+ await memory .create_collection ("test_collection" )
173
+ await memory .upsert_batch ("test_collection" , [memory_record1 , memory_record2 ])
174
+ test_embedding = memory_record1 .embedding .copy ()
175
+ test_embedding [0 ] = test_embedding [0 ] + 0.01
176
+
177
+ result = await memory .get_nearest_match (
178
+ "test_collection" , test_embedding , min_relevance_score = 0.0 , with_embedding = True
179
+ )
180
+ assert result is not None
181
+ assert result [0 ]._id == memory_record1 ._id
182
+ assert result [0 ]._text == memory_record1 ._text
183
+ assert result [0 ]._timestamp == memory_record1 ._timestamp
184
+ for i in range (len (result [0 ]._embedding )):
185
+ assert result [0 ]._embedding [i ] == memory_record1 ._embedding [i ]
186
+ except PoolTimeout :
187
+ pytest .skip ("PoolTimeout exception raised, skipping test." )
189
188
190
189
191
190
@pytest .mark .asyncio
192
191
async def test_get_nearest_matches (connection_string , memory_record1 , memory_record2 , memory_record3 ):
193
- memory = PostgresMemoryStore (connection_string , 2 , 1 , 5 )
194
- try :
195
- await memory .create_collection ("test_collection" )
196
- await memory .upsert_batch ("test_collection" , [memory_record1 , memory_record2 , memory_record3 ])
197
- test_embedding = memory_record2 .embedding
198
- test_embedding [0 ] = test_embedding [0 ] + 0.025
199
-
200
- result = await memory .get_nearest_matches (
201
- "test_collection" ,
202
- test_embedding ,
203
- limit = 2 ,
204
- min_relevance_score = 0.0 ,
205
- with_embeddings = True ,
206
- )
207
- assert len (result ) == 2
208
- assert result [0 ][0 ]._id in [memory_record3 ._id , memory_record2 ._id ]
209
- assert result [1 ][0 ]._id in [memory_record3 ._id , memory_record2 ._id ]
210
- except PoolTimeout :
211
- pytest .skip ("PoolTimeout exception raised, skipping test." )
192
+ with PostgresMemoryStore (connection_string , 2 , 1 , 5 ) as memory :
193
+ try :
194
+ await memory .create_collection ("test_collection" )
195
+ await memory .upsert_batch ("test_collection" , [memory_record1 , memory_record2 , memory_record3 ])
196
+ test_embedding = memory_record2 .embedding
197
+ test_embedding [0 ] = test_embedding [0 ] + 0.025
198
+
199
+ result = await memory .get_nearest_matches (
200
+ "test_collection" ,
201
+ test_embedding ,
202
+ limit = 2 ,
203
+ min_relevance_score = 0.0 ,
204
+ with_embeddings = True ,
205
+ )
206
+ assert len (result ) == 2
207
+ assert result [0 ][0 ]._id in [memory_record3 ._id , memory_record2 ._id ]
208
+ assert result [1 ][0 ]._id in [memory_record3 ._id , memory_record2 ._id ]
209
+ except PoolTimeout :
210
+ pytest .skip ("PoolTimeout exception raised, skipping test." )
0 commit comments