Skip to content

Commit 515b3a3

Browse files
committed
Add ability to close connection pool in mem store
This adds an __enter__ and __exit__ function to the PostgresMemoryStore so that it can be used as a context manager. Also modify the integration tests to utilize the context manager, so that connections are freed up after the tests are run. Without this change, the tests will hold onto the connections and cause the database to run out of connections.
1 parent 508feca commit 515b3a3

File tree

2 files changed

+127
-118
lines changed

2 files changed

+127
-118
lines changed

python/semantic_kernel/connectors/memory/postgres/postgres_memory_store.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -500,3 +500,13 @@ def __serialize_metadata(self, record: MemoryRecord) -> str:
500500
"description": record._description,
501501
"additional_metadata": record._additional_metadata,
502502
})
503+
504+
# Enable the connection pool to be closed when using as a context manager
505+
def __enter__(self) -> "PostgresMemoryStore":
506+
"""Enter the runtime context."""
507+
return self
508+
509+
def __exit__(self, exc_type, exc_value, traceback) -> bool:
510+
"""Exit the runtime context and dispose of the connection pool."""
511+
self._connection_pool.close()
512+
return False

python/tests/integration/connectors/memory/test_postgres_memory_store.py

Lines changed: 117 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -46,166 +46,165 @@ def connection_string():
4646

4747

4848
def test_constructor(connection_string):
49-
memory = PostgresMemoryStore(connection_string, 2, 1, 5)
50-
assert memory._connection_pool is not None
49+
with PostgresMemoryStore(connection_string, 2, 1, 5) as memory:
50+
assert memory._connection_pool is not None
5151

5252

5353
@pytest.mark.asyncio
5454
async def test_create_and_does_collection_exist(connection_string):
55-
memory = PostgresMemoryStore(connection_string, 2, 1, 5)
56-
await memory.create_collection("test_collection")
57-
result = await memory.does_collection_exist("test_collection")
58-
assert result is not None
55+
with PostgresMemoryStore(connection_string, 2, 1, 5) as memory:
56+
await memory.create_collection("test_collection")
57+
result = await memory.does_collection_exist("test_collection")
58+
assert result is not None
5959

6060

6161
@pytest.mark.asyncio
6262
async def test_get_collections(connection_string):
63-
memory = PostgresMemoryStore(connection_string, 2, 1, 5)
64-
65-
try:
66-
await memory.create_collection("test_collection")
67-
result = await memory.get_collections()
68-
assert "test_collection" in result
69-
except PoolTimeout:
70-
pytest.skip("PoolTimeout exception raised, skipping test.")
63+
with PostgresMemoryStore(connection_string, 2, 1, 5) as memory:
64+
try:
65+
await memory.create_collection("test_collection")
66+
result = await memory.get_collections()
67+
assert "test_collection" in result
68+
except PoolTimeout:
69+
pytest.skip("PoolTimeout exception raised, skipping test.")
7170

7271

7372
@pytest.mark.asyncio
7473
async def test_delete_collection(connection_string):
75-
memory = PostgresMemoryStore(connection_string, 2, 1, 5)
76-
try:
77-
await memory.create_collection("test_collection")
74+
with PostgresMemoryStore(connection_string, 2, 1, 5) as memory:
75+
try:
76+
await memory.create_collection("test_collection")
7877

79-
result = await memory.get_collections()
80-
assert "test_collection" in result
78+
result = await memory.get_collections()
79+
assert "test_collection" in result
8180

82-
await memory.delete_collection("test_collection")
83-
result = await memory.get_collections()
84-
assert "test_collection" not in result
85-
except PoolTimeout:
86-
pytest.skip("PoolTimeout exception raised, skipping test.")
81+
await memory.delete_collection("test_collection")
82+
result = await memory.get_collections()
83+
assert "test_collection" not in result
84+
except PoolTimeout:
85+
pytest.skip("PoolTimeout exception raised, skipping test.")
8786

8887

8988
@pytest.mark.asyncio
9089
async def test_does_collection_exist(connection_string):
91-
memory = PostgresMemoryStore(connection_string, 2, 1, 5)
92-
try:
93-
await memory.create_collection("test_collection")
94-
result = await memory.does_collection_exist("test_collection")
95-
assert result is True
96-
except PoolTimeout:
97-
pytest.skip("PoolTimeout exception raised, skipping test.")
90+
with PostgresMemoryStore(connection_string, 2, 1, 5) as memory:
91+
try:
92+
await memory.create_collection("test_collection")
93+
result = await memory.does_collection_exist("test_collection")
94+
assert result is True
95+
except PoolTimeout:
96+
pytest.skip("PoolTimeout exception raised, skipping test.")
9897

9998

10099
@pytest.mark.asyncio
101100
async def test_upsert_and_get(connection_string, memory_record1):
102-
memory = PostgresMemoryStore(connection_string, 2, 1, 5)
103-
try:
104-
await memory.create_collection("test_collection")
105-
await memory.upsert("test_collection", memory_record1)
106-
result = await memory.get("test_collection", memory_record1._id, with_embedding=True)
107-
assert result is not None
108-
assert result._id == memory_record1._id
109-
assert result._text == memory_record1._text
110-
assert result._timestamp == memory_record1._timestamp
111-
for i in range(len(result._embedding)):
112-
assert result._embedding[i] == memory_record1._embedding[i]
113-
except PoolTimeout:
114-
pytest.skip("PoolTimeout exception raised, skipping test.")
101+
with PostgresMemoryStore(connection_string, 2, 1, 5) as memory:
102+
try:
103+
await memory.create_collection("test_collection")
104+
await memory.upsert("test_collection", memory_record1)
105+
result = await memory.get("test_collection", memory_record1._id, with_embedding=True)
106+
assert result is not None
107+
assert result._id == memory_record1._id
108+
assert result._text == memory_record1._text
109+
assert result._timestamp == memory_record1._timestamp
110+
for i in range(len(result._embedding)):
111+
assert result._embedding[i] == memory_record1._embedding[i]
112+
except PoolTimeout:
113+
pytest.skip("PoolTimeout exception raised, skipping test.")
115114

116115

117116
@pytest.mark.asyncio
118117
async def test_upsert_batch_and_get_batch(connection_string, memory_record1, memory_record2):
119-
memory = PostgresMemoryStore(connection_string, 2, 1, 5)
120-
try:
121-
await memory.create_collection("test_collection")
122-
await memory.upsert_batch("test_collection", [memory_record1, memory_record2])
123-
124-
results = await memory.get_batch(
125-
"test_collection",
126-
[memory_record1._id, memory_record2._id],
127-
with_embeddings=True,
128-
)
129-
assert len(results) == 2
130-
assert results[0]._id in [memory_record1._id, memory_record2._id]
131-
assert results[1]._id in [memory_record1._id, memory_record2._id]
132-
except PoolTimeout:
133-
pytest.skip("PoolTimeout exception raised, skipping test.")
118+
with PostgresMemoryStore(connection_string, 2, 1, 5) as memory:
119+
try:
120+
await memory.create_collection("test_collection")
121+
await memory.upsert_batch("test_collection", [memory_record1, memory_record2])
122+
123+
results = await memory.get_batch(
124+
"test_collection",
125+
[memory_record1._id, memory_record2._id],
126+
with_embeddings=True,
127+
)
128+
assert len(results) == 2
129+
assert results[0]._id in [memory_record1._id, memory_record2._id]
130+
assert results[1]._id in [memory_record1._id, memory_record2._id]
131+
except PoolTimeout:
132+
pytest.skip("PoolTimeout exception raised, skipping test.")
134133

135134

136135
@pytest.mark.asyncio
137136
async def test_remove(connection_string, memory_record1):
138-
memory = PostgresMemoryStore(connection_string, 2, 1, 5)
139-
try:
140-
await memory.create_collection("test_collection")
141-
await memory.upsert("test_collection", memory_record1)
137+
with PostgresMemoryStore(connection_string, 2, 1, 5) as memory:
138+
try:
139+
await memory.create_collection("test_collection")
140+
await memory.upsert("test_collection", memory_record1)
142141

143-
result = await memory.get("test_collection", memory_record1._id, with_embedding=True)
144-
assert result is not None
142+
result = await memory.get("test_collection", memory_record1._id, with_embedding=True)
143+
assert result is not None
145144

146-
await memory.remove("test_collection", memory_record1._id)
147-
with pytest.raises(ServiceResourceNotFoundError):
148-
await memory.get("test_collection", memory_record1._id, with_embedding=True)
149-
except PoolTimeout:
150-
pytest.skip("PoolTimeout exception raised, skipping test.")
145+
await memory.remove("test_collection", memory_record1._id)
146+
with pytest.raises(ServiceResourceNotFoundError):
147+
await memory.get("test_collection", memory_record1._id, with_embedding=True)
148+
except PoolTimeout:
149+
pytest.skip("PoolTimeout exception raised, skipping test.")
151150

152151

153152
@pytest.mark.asyncio
154153
async def test_remove_batch(connection_string, memory_record1, memory_record2):
155-
memory = PostgresMemoryStore(connection_string, 2, 1, 5)
156-
try:
157-
await memory.create_collection("test_collection")
158-
await memory.upsert_batch("test_collection", [memory_record1, memory_record2])
159-
await memory.remove_batch("test_collection", [memory_record1._id, memory_record2._id])
160-
with pytest.raises(ServiceResourceNotFoundError):
161-
_ = await memory.get("test_collection", memory_record1._id, with_embedding=True)
154+
with PostgresMemoryStore(connection_string, 2, 1, 5) as memory:
155+
try:
156+
await memory.create_collection("test_collection")
157+
await memory.upsert_batch("test_collection", [memory_record1, memory_record2])
158+
await memory.remove_batch("test_collection", [memory_record1._id, memory_record2._id])
159+
with pytest.raises(ServiceResourceNotFoundError):
160+
_ = await memory.get("test_collection", memory_record1._id, with_embedding=True)
162161

163-
with pytest.raises(ServiceResourceNotFoundError):
164-
_ = await memory.get("test_collection", memory_record2._id, with_embedding=True)
165-
except PoolTimeout:
166-
pytest.skip("PoolTimeout exception raised, skipping test.")
162+
with pytest.raises(ServiceResourceNotFoundError):
163+
_ = await memory.get("test_collection", memory_record2._id, with_embedding=True)
164+
except PoolTimeout:
165+
pytest.skip("PoolTimeout exception raised, skipping test.")
167166

168167

169168
@pytest.mark.asyncio
170169
async def test_get_nearest_match(connection_string, memory_record1, memory_record2):
171-
memory = PostgresMemoryStore(connection_string, 2, 1, 5)
172-
try:
173-
await memory.create_collection("test_collection")
174-
await memory.upsert_batch("test_collection", [memory_record1, memory_record2])
175-
test_embedding = memory_record1.embedding.copy()
176-
test_embedding[0] = test_embedding[0] + 0.01
177-
178-
result = await memory.get_nearest_match(
179-
"test_collection", test_embedding, min_relevance_score=0.0, with_embedding=True
180-
)
181-
assert result is not None
182-
assert result[0]._id == memory_record1._id
183-
assert result[0]._text == memory_record1._text
184-
assert result[0]._timestamp == memory_record1._timestamp
185-
for i in range(len(result[0]._embedding)):
186-
assert result[0]._embedding[i] == memory_record1._embedding[i]
187-
except PoolTimeout:
188-
pytest.skip("PoolTimeout exception raised, skipping test.")
170+
with PostgresMemoryStore(connection_string, 2, 1, 5) as memory:
171+
try:
172+
await memory.create_collection("test_collection")
173+
await memory.upsert_batch("test_collection", [memory_record1, memory_record2])
174+
test_embedding = memory_record1.embedding.copy()
175+
test_embedding[0] = test_embedding[0] + 0.01
176+
177+
result = await memory.get_nearest_match(
178+
"test_collection", test_embedding, min_relevance_score=0.0, with_embedding=True
179+
)
180+
assert result is not None
181+
assert result[0]._id == memory_record1._id
182+
assert result[0]._text == memory_record1._text
183+
assert result[0]._timestamp == memory_record1._timestamp
184+
for i in range(len(result[0]._embedding)):
185+
assert result[0]._embedding[i] == memory_record1._embedding[i]
186+
except PoolTimeout:
187+
pytest.skip("PoolTimeout exception raised, skipping test.")
189188

190189

191190
@pytest.mark.asyncio
192191
async def test_get_nearest_matches(connection_string, memory_record1, memory_record2, memory_record3):
193-
memory = PostgresMemoryStore(connection_string, 2, 1, 5)
194-
try:
195-
await memory.create_collection("test_collection")
196-
await memory.upsert_batch("test_collection", [memory_record1, memory_record2, memory_record3])
197-
test_embedding = memory_record2.embedding
198-
test_embedding[0] = test_embedding[0] + 0.025
199-
200-
result = await memory.get_nearest_matches(
201-
"test_collection",
202-
test_embedding,
203-
limit=2,
204-
min_relevance_score=0.0,
205-
with_embeddings=True,
206-
)
207-
assert len(result) == 2
208-
assert result[0][0]._id in [memory_record3._id, memory_record2._id]
209-
assert result[1][0]._id in [memory_record3._id, memory_record2._id]
210-
except PoolTimeout:
211-
pytest.skip("PoolTimeout exception raised, skipping test.")
192+
with PostgresMemoryStore(connection_string, 2, 1, 5) as memory:
193+
try:
194+
await memory.create_collection("test_collection")
195+
await memory.upsert_batch("test_collection", [memory_record1, memory_record2, memory_record3])
196+
test_embedding = memory_record2.embedding
197+
test_embedding[0] = test_embedding[0] + 0.025
198+
199+
result = await memory.get_nearest_matches(
200+
"test_collection",
201+
test_embedding,
202+
limit=2,
203+
min_relevance_score=0.0,
204+
with_embeddings=True,
205+
)
206+
assert len(result) == 2
207+
assert result[0][0]._id in [memory_record3._id, memory_record2._id]
208+
assert result[1][0]._id in [memory_record3._id, memory_record2._id]
209+
except PoolTimeout:
210+
pytest.skip("PoolTimeout exception raised, skipping test.")

0 commit comments

Comments
 (0)