Commit a019315
Fix memory zones
1 parent 59ac7e6 commit a019315

9 files changed: +76 −71 lines changed

spacy/lang/kmr/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1,5 +1,5 @@
-from .lex_attrs import LEX_ATTRS
 from ...language import BaseDefaults, Language
+from .lex_attrs import LEX_ATTRS
 from .stop_words import STOP_WORDS

spacy/lang/kmr/lex_attrs.py

Lines changed: 0 additions & 1 deletion

@@ -1,6 +1,5 @@
 from ...attrs import LIKE_NUM
 
-
 _num_words = [
     "sifir",
     "yek",

spacy/language.py

Lines changed: 34 additions & 1 deletion

@@ -5,7 +5,7 @@
 import random
 import traceback
 import warnings
-from contextlib import contextmanager
+from contextlib import ExitStack, contextmanager
 from copy import deepcopy
 from dataclasses import dataclass
 from itertools import chain, cycle
@@ -31,6 +31,7 @@
 )
 
 import srsly
+from cymem.cymem import Pool
 from thinc.api import Config, CupyOps, Optimizer, get_current_ops
 
 from . import about, ty, util
@@ -2091,6 +2092,38 @@ def replace_listeners(
             util.replace_model_node(pipe.model, listener, new_model)  # type: ignore[attr-defined]
             tok2vec.remove_listener(listener, pipe_name)
 
+    @contextmanager
+    def memory_zone(self, mem: Optional[Pool] = None) -> Iterator[Pool]:
+        """Begin a block where all resources allocated during the block will
+        be freed at the end of it. If a resource was created within the
+        memory zone block, accessing it outside the block is invalid.
+        Behaviour of this invalid access is undefined. Memory zones should
+        not be nested.
+
+        The memory zone is helpful for services that need to process large
+        volumes of text with a defined memory budget.
+
+        Example
+        -------
+        >>> with nlp.memory_zone():
+        ...     for doc in nlp.pipe(texts):
+        ...         process_my_doc(doc)
+        >>> # use_doc(doc) <-- Invalid: doc was allocated in the memory zone
+        """
+        if mem is None:
+            mem = Pool()
+        # The ExitStack allows programmatic nested context managers.
+        # We don't know how many we need, so it would be awkward to have
+        # them as nested blocks.
+        with ExitStack() as stack:
+            contexts = [stack.enter_context(self.vocab.memory_zone(mem))]
+            if hasattr(self.tokenizer, "memory_zone"):
+                contexts.append(stack.enter_context(self.tokenizer.memory_zone(mem)))
+            for _, pipe in self.pipeline:
+                if hasattr(pipe, "memory_zone"):
+                    contexts.append(stack.enter_context(pipe.memory_zone(mem)))
+            yield mem
+
     def to_disk(
         self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
     ) -> None:
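
For orientation (not part of the commit), a usage sketch of the new Language.memory_zone API: a long-running service can process a stream in fixed-size batches, opening a fresh zone per batch so transient strings and lexemes are freed between batches. The pipeline name, batch size, and the entity-counting step are illustrative assumptions.

import spacy

nlp = spacy.load("en_core_web_sm")  # assumes a build that includes this commit

def count_entity_labels(texts, batch_size=1000):
    """Process a text stream with bounded memory: one zone per batch."""
    counts = {}
    batch = []
    for text in texts:
        batch.append(text)
        if len(batch) == batch_size:
            _process_batch(nlp, batch, counts)
            batch = []
    if batch:
        _process_batch(nlp, batch, counts)
    return counts

def _process_batch(nlp, batch, counts):
    with nlp.memory_zone():
        for doc in nlp.pipe(batch):
            # Copy plain Python data out *inside* the zone; the Doc
            # objects themselves must not be used after the block.
            for ent in doc.ents:
                counts[ent.label_] = counts.get(ent.label_, 0) + 1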

spacy/pipeline/_parser_internals/arc_eager.pyx

Lines changed: 1 addition & 1 deletion

@@ -203,7 +203,7 @@ cdef class ArcEagerGold:
     def __init__(self, ArcEager moves, StateClass stcls, Example example):
         self.mem = Pool()
         heads, labels = example.get_aligned_parse(projectivize=True)
-        labels = [example.x.vocab.strings.add(label) if label is not None else MISSING_DEP for label in labels]
+        labels = [example.x.vocab.strings.add(label, allow_transient=False) if label is not None else MISSING_DEP for label in labels]
         sent_starts = _get_aligned_sent_starts(example)
         assert len(heads) == len(labels) == len(sent_starts), (len(heads), len(labels), len(sent_starts))
         self.c = create_gold_state(self.mem, stcls.c, heads, labels, sent_starts)

spacy/pipeline/_parser_internals/nonproj.pyx

Lines changed: 1 addition & 1 deletion

@@ -183,7 +183,7 @@ cpdef deprojectivize(Doc doc):
             new_label, head_label = label.split(DELIMITER)
             new_head = _find_new_head(doc[i], head_label)
             doc.c[i].head = new_head.i - i
-            doc.c[i].dep = doc.vocab.strings.add(new_label)
+            doc.c[i].dep = doc.vocab.strings.add(new_label, allow_transient=False)
     set_children_from_heads(doc.c, 0, doc.length)
     return doc
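
Both parser changes intern label strings with allow_transient=False because the resulting hashes are written into doc.c[i].dep and into gold states that can outlive an enclosing memory zone. A minimal sketch of the failure mode this guards against, using hypothetical strings ("my_custom_dep", "some_rare_token") that are not built-in symbols:

from spacy.strings import StringStore

strings = StringStore()
dep_hash = strings.add("my_custom_dep", allow_transient=False)  # permanent

with strings.memory_zone():
    tmp_hash = strings.add("some_rare_token", allow_transient=True)
    assert strings[tmp_hash] == "some_rare_token"  # resolvable inside the zone

assert strings[dep_hash] == "my_custom_dep"  # permanent entries survive
try:
    strings[tmp_hash]  # transient entries are evicted with the zone's pool
except KeyError:
    print("transient hash no longer resolves")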

spacy/strings.pxd

Lines changed: 0 additions & 1 deletion

@@ -28,5 +28,4 @@ cdef class StringStore:
     cdef const Utf8Str* intern_unicode(self, str py_string, bint allow_transient)
     cdef const Utf8Str* _intern_utf8(self, char* utf8_string, int length, hash_t* precalculated_hash, bint allow_transient)
     cdef vector[hash_t] _transient_keys
-    cdef PreshMap _transient_map
     cdef Pool _non_temp_mem

spacy/strings.pyx

Lines changed: 28 additions & 55 deletions

@@ -8,6 +8,7 @@ from typing import Iterator, List, Optional
 from libc.stdint cimport uint32_t
 from libc.string cimport memcpy
 from murmurhash.mrmr cimport hash32, hash64
+from preshed.maps cimport map_clear
 
 import srsly
 
@@ -125,10 +126,9 @@ cdef class StringStore:
         self.mem = Pool()
         self._non_temp_mem = self.mem
         self._map = PreshMap()
-        self._transient_map = None
         if strings is not None:
             for string in strings:
-                self.add(string)
+                self.add(string, allow_transient=False)
 
     def __getitem__(self, object string_or_id):
         """Retrieve a string from a given hash, or vice versa.
@@ -158,17 +158,17 @@
                 return SYMBOLS_BY_INT[str_hash]
             else:
                 utf8str = <Utf8Str*>self._map.get(str_hash)
-                if utf8str is NULL and self._transient_map is not None:
-                    utf8str = <Utf8Str*>self._transient_map.get(str_hash)
+                if utf8str is NULL:
+                    raise KeyError(Errors.E018.format(hash_value=string_or_id))
+                else:
+                    return decode_Utf8Str(utf8str)
         else:
             # TODO: Raise an error instead
             utf8str = <Utf8Str*>self._map.get(string_or_id)
-            if utf8str is NULL and self._transient_map is not None:
-                utf8str = <Utf8Str*>self._transient_map.get(str_hash)
-        if utf8str is NULL:
-            raise KeyError(Errors.E018.format(hash_value=string_or_id))
-        else:
-            return decode_Utf8Str(utf8str)
+            if utf8str is NULL:
+                raise KeyError(Errors.E018.format(hash_value=string_or_id))
+            else:
+                return decode_Utf8Str(utf8str)
 
     def as_int(self, key):
         """If key is an int, return it; otherwise, get the int value."""
@@ -184,16 +184,12 @@
         else:
             return self[key]
 
-    def __reduce__(self):
-        strings = list(self.non_transient_keys())
-        return (StringStore, (strings,), None, None, None)
-
     def __len__(self) -> int:
         """The number of strings in the store.
 
         RETURNS (int): The number of strings in the store.
         """
-        return self._keys.size() + self._transient_keys.size()
+        return self.keys.size() + self._transient_keys.size()
 
     @contextmanager
     def memory_zone(self, mem: Optional[Pool] = None) -> Pool:
@@ -209,13 +205,13 @@
         if mem is None:
            mem = Pool()
         self.mem = mem
-        self._transient_map = PreshMap()
         yield mem
-        self.mem = self._non_temp_mem
-        self._transient_map = None
+        for key in self._transient_keys:
+            map_clear(self._map.c_map, key)
         self._transient_keys.clear()
+        self.mem = self._non_temp_mem
 
-    def add(self, string: str, allow_transient: bool = False) -> int:
+    def add(self, string: str, allow_transient: Optional[bool] = None) -> int:
         """Add a string to the StringStore.
 
         string (str): The string to add.
@@ -226,6 +222,8 @@
             internally should not.
         RETURNS (uint64): The string's hash value.
         """
+        if allow_transient is None:
+            allow_transient = self.mem is not self._non_temp_mem
         cdef hash_t str_hash
         if isinstance(string, str):
             if string in SYMBOLS_BY_STR:
@@ -273,17 +271,13 @@
             # TODO: Raise an error instead
             if self._map.get(string_or_id) is not NULL:
                 return True
-            elif self._transient_map is not None and self._transient_map.get(string_or_id) is not NULL:
-                return True
             else:
                 return False
         if str_hash < len(SYMBOLS_BY_INT):
             return True
         else:
             if self._map.get(str_hash) is not NULL:
                 return True
-            elif self._transient_map is not None and self._transient_map.get(string_or_id) is not NULL:
-                return True
             else:
                 return False
@@ -292,32 +286,21 @@
 
         YIELDS (str): A string in the store.
         """
-        yield from self.non_transient_keys()
-        yield from self.transient_keys()
-
-    def non_transient_keys(self) -> Iterator[str]:
-        """Iterate over the stored strings in insertion order.
-
-        RETURNS: A list of strings.
-        """
         cdef int i
         cdef hash_t key
         for i in range(self.keys.size()):
             key = self.keys[i]
             utf8str = <Utf8Str*>self._map.get(key)
             yield decode_Utf8Str(utf8str)
+        for i in range(self._transient_keys.size()):
+            key = self._transient_keys[i]
+            utf8str = <Utf8Str*>self._map.get(key)
+            yield decode_Utf8Str(utf8str)
 
     def __reduce__(self):
         strings = list(self)
         return (StringStore, (strings,), None, None, None)
 
-    def transient_keys(self) -> Iterator[str]:
-        if self._transient_map is None:
-            return []
-        for i in range(self._transient_keys.size()):
-            utf8str = <Utf8Str*>self._transient_map.get(self._transient_keys[i])
-            yield decode_Utf8Str(utf8str)
-
     def values(self) -> List[int]:
         """Iterate over the stored strings hashes in insertion order.
 
@@ -327,12 +310,9 @@
         hashes = [None] * self._keys.size()
         for i in range(self._keys.size()):
             hashes[i] = self._keys[i]
-        if self._transient_map is not None:
-            transient_hashes = [None] * self._transient_keys.size()
-            for i in range(self._transient_keys.size()):
-                transient_hashes[i] = self._transient_keys[i]
-        else:
-            transient_hashes = []
+        transient_hashes = [None] * self._transient_keys.size()
+        for i in range(self._transient_keys.size()):
+            transient_hashes[i] = self._transient_keys[i]
         return hashes + transient_hashes
 
     def to_disk(self, path):
@@ -383,8 +363,10 @@
 
     def _reset_and_load(self, strings):
         self.mem = Pool()
+        self._non_temp_mem = self.mem
         self._map = PreshMap()
         self.keys.clear()
+        self._transient_keys.clear()
         for string in strings:
             self.add(string, allow_transient=False)
@@ -401,19 +383,10 @@
         cdef Utf8Str* value = <Utf8Str*>self._map.get(key)
         if value is not NULL:
             return value
-        if allow_transient and self._transient_map is not None:
-            # If we've already allocated a transient string, and now we
-            # want to intern it permanently, we'll end up with the string
-            # in both places. That seems fine -- I don't see why we need
-            # to remove it from the transient map.
-            value = <Utf8Str*>self._transient_map.get(key)
-            if value is not NULL:
-                return value
         value = _allocate(self.mem, <unsigned char*>utf8_string, length)
-        if allow_transient and self._transient_map is not None:
-            self._transient_map.set(key, value)
+        self._map.set(key, value)
+        if allow_transient and self.mem is not self._non_temp_mem:
             self._transient_keys.push_back(key)
         else:
-            self._map.set(key, value)
             self.keys.push_back(key)
         return value
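
To summarize the StringStore rework: the separate _transient_map is gone; transient strings now live in the single _map, with their hashes recorded in _transient_keys and evicted via map_clear when the zone closes, and add() now defaults allow_transient to whether a zone is active. A pure-Python analogue of that bookkeeping (illustrative only, not the Cython code):

from contextlib import contextmanager

class ToyStringStore:
    def __init__(self):
        self._map = {}            # one map for all strings, as in the commit
        self._transient_keys = []
        self._in_zone = False

    def add(self, string, allow_transient=None):
        if allow_transient is None:
            # Default mirrors `self.mem is not self._non_temp_mem`.
            allow_transient = self._in_zone
        key = hash(string)
        if key not in self._map:
            self._map[key] = string
            if allow_transient and self._in_zone:
                self._transient_keys.append(key)
        return key

    @contextmanager
    def memory_zone(self):
        self._in_zone = True
        try:
            yield
        finally:
            # The real code calls map_clear() on each transient hash.
            for key in self._transient_keys:
                self._map.pop(key, None)
            self._transient_keys.clear()
            self._in_zone = False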

spacy/tokenizer.pyx

Lines changed: 2 additions & 6 deletions

@@ -517,12 +517,8 @@ cdef class Tokenizer:
         if n <= 0:
             # avoid mem alloc of zero length
             return 0
-        # Historically this check was mostly used to avoid caching
-        # chunks that had tokens owned by the Doc. Now that that's
-        # not a thing, I don't think we need this?
-        for i in range(n):
-            if self.vocab._by_orth.get(tokens[i].lex.orth) == NULL:
-                return 0
+        if self.vocab.in_memory_zone:
+            return 0
         # See #1250
         if has_special[0]:
             return 0
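
The tokenizer change replaces the old per-token ownership scan with a single property check: while a memory zone is active, nothing is written to the tokenizer cache at all, since cached chunks would hold pointers to lexemes that die with the zone. A toy analogue of the guard (illustrative only; the class and method names are hypothetical):

from types import SimpleNamespace

class ToyTokenizerCache:
    def __init__(self, vocab):
        self.vocab = vocab
        self._cache = {}

    def try_cache_store(self, key, tokens):
        if not tokens:
            return 0  # avoid caching zero-length chunks
        if self.vocab.in_memory_zone:
            return 0  # cache entries would outlive the zone's lexemes
        self._cache[key] = tuple(tokens)
        return len(tokens)

vocab = SimpleNamespace(in_memory_zone=True)
cache = ToyTokenizerCache(vocab)
assert cache.try_cache_store("chunk", ["a", "b"]) == 0  # skipped inside a zone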

spacy/vocab.pyx

Lines changed: 9 additions & 4 deletions

@@ -5,6 +5,7 @@ from typing import Iterator, Optional
 import numpy
 import srsly
 from thinc.api import get_array_module, get_current_ops
+from preshed.maps cimport map_clear
 
 from .attrs cimport LANG, ORTH
 from .lexeme cimport EMPTY_LEXEME, OOV_RANK, Lexeme
@@ -104,7 +105,7 @@ cdef class Vocab:
     def vectors(self, vectors):
         if hasattr(vectors, "strings"):
             for s in vectors.strings:
-                self.strings.add(s)
+                self.strings.add(s, allow_transient=False)
         self._vectors = vectors
         self._vectors.strings = self.strings
 
@@ -115,6 +116,10 @@
         langfunc = self.lex_attr_getters.get(LANG, None)
         return langfunc("_") if langfunc else ""
 
+    @property
+    def in_memory_zone(self) -> bool:
+        return self.mem is not self._non_temp_mem
+
     def __len__(self):
         """The current number of lexemes stored.
 
@@ -218,7 +223,7 @@
         # this size heuristic.
         mem = self.mem
         lex = <LexemeC*>mem.alloc(1, sizeof(LexemeC))
-        lex.orth = self.strings.add(string)
+        lex.orth = self.strings.add(string, allow_transient=True)
         lex.length = len(string)
         if self.vectors is not None and hasattr(self.vectors, "key2row"):
             lex.id = self.vectors.key2row.get(lex.orth, OOV_RANK)
@@ -239,13 +244,13 @@
     cdef int _add_lex_to_vocab(self, hash_t key, const LexemeC* lex, bint is_transient) except -1:
         self._by_orth.set(lex.orth, <void*>lex)
         self.length += 1
-        if is_transient:
+        if is_transient and self.in_memory_zone:
             self._transient_orths.push_back(lex.orth)
 
     def _clear_transient_orths(self):
        """Remove transient lexemes from the index (generally at the end of the memory zone)"""
        for orth in self._transient_orths:
-            self._by_orth.pop(orth)
+            map_clear(self._by_orth.c_map, orth)
        self._transient_orths.clear()
 
     def __contains__(self, key):
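
Finally, a quick sketch of the new Vocab.in_memory_zone property in use (assumes a spaCy build that includes this commit):

import spacy

nlp = spacy.blank("en")
print(nlp.vocab.in_memory_zone)      # False: vocab.mem is the permanent pool

with nlp.memory_zone():
    print(nlp.vocab.in_memory_zone)  # True: allocations go to the zone's pool
    doc = nlp("lexemes created here are transient")

print(nlp.vocab.in_memory_zone)      # False again; transient lexemes evicted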
