@@ -8,6 +8,7 @@ from typing import Iterator, List, Optional
8
8
from libc.stdint cimport uint32_t
9
9
from libc.string cimport memcpy
10
10
from murmurhash.mrmr cimport hash32, hash64
11
+ from preshed.maps cimport map_clear
11
12
12
13
import srsly
13
14
@@ -125,10 +126,9 @@ cdef class StringStore:
125
126
self .mem = Pool()
126
127
self ._non_temp_mem = self .mem
127
128
self ._map = PreshMap()
128
- self ._transient_map = None
129
129
if strings is not None :
130
130
for string in strings:
131
- self .add(string)
131
+ self .add(string, allow_transient = False )
132
132
133
133
def __getitem__ (self , object string_or_id ):
134
134
""" Retrieve a string from a given hash, or vice versa.
@@ -158,17 +158,17 @@ cdef class StringStore:
158
158
return SYMBOLS_BY_INT[str_hash]
159
159
else :
160
160
utf8str = < Utf8Str* > self ._map.get(str_hash)
161
- if utf8str is NULL and self ._transient_map is not None :
162
- utf8str = < Utf8Str* > self ._transient_map.get(str_hash)
161
+ if utf8str is NULL :
162
+ raise KeyError (Errors.E018.format(hash_value = string_or_id))
163
+ else :
164
+ return decode_Utf8Str(utf8str)
163
165
else :
164
166
# TODO: Raise an error instead
165
167
utf8str = < Utf8Str* > self ._map.get(string_or_id)
166
- if utf8str is NULL and self ._transient_map is not None :
167
- utf8str = < Utf8Str* > self ._transient_map.get(str_hash)
168
- if utf8str is NULL :
169
- raise KeyError (Errors.E018.format(hash_value = string_or_id))
170
- else :
171
- return decode_Utf8Str(utf8str)
168
+ if utf8str is NULL :
169
+ raise KeyError (Errors.E018.format(hash_value = string_or_id))
170
+ else :
171
+ return decode_Utf8Str(utf8str)
172
172
173
173
def as_int (self , key ):
174
174
""" If key is an int, return it; otherwise, get the int value."""
@@ -184,16 +184,12 @@ cdef class StringStore:
184
184
else :
185
185
return self [key]
186
186
187
- def __reduce__ (self ):
188
- strings = list (self .non_transient_keys())
189
- return (StringStore, (strings,), None , None , None )
190
-
191
187
def __len__ (self ) -> int:
192
188
"""The number of strings in the store.
193
189
194
190
RETURNS (int): The number of strings in the store.
195
191
"""
196
- return self._keys .size() + self._transient_keys.size()
192
+ return self.keys .size() + self._transient_keys.size()
197
193
198
194
@contextmanager
199
195
def memory_zone(self, mem: Optional[Pool] = None) -> Pool:
@@ -209,13 +205,13 @@ cdef class StringStore:
209
205
if mem is None:
210
206
mem = Pool()
211
207
self.mem = mem
212
- self._transient_map = PreshMap()
213
208
yield mem
214
- self.mem = self._non_temp_mem
215
- self._transient_map = None
209
+ for key in self._transient_keys:
210
+ map_clear( self._map.c_map, key)
216
211
self._transient_keys.clear()
212
+ self.mem = self._non_temp_mem
217
213
218
- def add(self, string: str, allow_transient: bool = False ) -> int:
214
+ def add(self, string: str, allow_transient: Optional[ bool] = None ) -> int:
219
215
""" Add a string to the StringStore.
220
216
221
217
string (str): The string to add.
@@ -226,6 +222,8 @@ cdef class StringStore:
226
222
internally should not.
227
223
RETURNS (uint64): The string's hash value.
228
224
"""
225
+ if allow_transient is None:
226
+ allow_transient = self.mem is not self._non_temp_mem
229
227
cdef hash_t str_hash
230
228
if isinstance(string, str):
231
229
if string in SYMBOLS_BY_STR:
@@ -273,17 +271,13 @@ cdef class StringStore:
273
271
# TODO: Raise an error instead
274
272
if self._map.get(string_or_id) is not NULL:
275
273
return True
276
- elif self._transient_map is not None and self._transient_map.get(string_or_id) is not NULL:
277
- return True
278
274
else:
279
275
return False
280
276
if str_hash < len(SYMBOLS_BY_INT):
281
277
return True
282
278
else:
283
279
if self._map.get(str_hash) is not NULL:
284
280
return True
285
- elif self._transient_map is not None and self._transient_map.get(string_or_id) is not NULL:
286
- return True
287
281
else:
288
282
return False
289
283
@@ -292,32 +286,21 @@ cdef class StringStore:
292
286
293
287
YIELDS (str): A string in the store.
294
288
"""
295
- yield from self.non_transient_keys()
296
- yield from self.transient_keys()
297
-
298
- def non_transient_keys(self) -> Iterator[str]:
299
- """ Iterate over the stored strings in insertion order.
300
-
301
- RETURNS: A list of strings.
302
- """
303
289
cdef int i
304
290
cdef hash_t key
305
291
for i in range(self.keys.size()):
306
292
key = self.keys[i]
307
293
utf8str = <Utf8Str*>self._map.get(key)
308
294
yield decode_Utf8Str(utf8str)
295
+ for i in range(self._transient_keys.size()):
296
+ key = self._transient_keys[i]
297
+ utf8str = <Utf8Str*>self._map.get(key)
298
+ yield decode_Utf8Str(utf8str)
309
299
310
300
def __reduce__(self):
311
301
strings = list(self)
312
302
return (StringStore, (strings,), None, None, None)
313
303
314
- def transient_keys(self) -> Iterator[str]:
315
- if self._transient_map is None:
316
- return []
317
- for i in range(self._transient_keys.size()):
318
- utf8str = <Utf8Str*>self._transient_map.get(self._transient_keys[i])
319
- yield decode_Utf8Str(utf8str)
320
-
321
304
def values(self) -> List[int]:
322
305
""" Iterate over the stored strings' hashes in insertion order.
323
306
@@ -327,12 +310,9 @@ cdef class StringStore:
327
310
hashes = [None] * self._keys.size()
328
311
for i in range(self._keys.size()):
329
312
hashes[i] = self._keys[i]
330
- if self._transient_map is not None:
331
- transient_hashes = [None] * self._transient_keys.size()
332
- for i in range(self._transient_keys.size()):
333
- transient_hashes[i] = self._transient_keys[i]
334
- else:
335
- transient_hashes = []
313
+ transient_hashes = [None] * self._transient_keys.size()
314
+ for i in range(self._transient_keys.size()):
315
+ transient_hashes[i] = self._transient_keys[i]
336
316
return hashes + transient_hashes
337
317
338
318
def to_disk(self, path):
@@ -383,8 +363,10 @@ cdef class StringStore:
383
363
384
364
def _reset_and_load(self, strings):
385
365
self.mem = Pool()
366
+ self._non_temp_mem = self.mem
386
367
self._map = PreshMap()
387
368
self.keys.clear()
369
+ self._transient_keys.clear()
388
370
for string in strings:
389
371
self.add(string, allow_transient=False)
390
372
@@ -401,19 +383,10 @@ cdef class StringStore:
401
383
cdef Utf8Str* value = <Utf8Str*>self._map.get(key)
402
384
if value is not NULL:
403
385
return value
404
- if allow_transient and self._transient_map is not None:
405
- # If we've already allocated a transient string, and now we
406
- # want to intern it permanently, we'll end up with the string
407
- # in both places. That seems fine -- I don't see why we need
408
- # to remove it from the transient map.
409
- value = <Utf8Str*>self._transient_map.get(key)
410
- if value is not NULL:
411
- return value
412
386
value = _allocate(self.mem, <unsigned char*>utf8_string, length)
413
- if allow_transient and self._transient_map is not None:
414
- self._transient_map.set(key, value)
387
+ self._map.set(key, value)
388
+ if allow_transient and self.mem is not self._non_temp_mem:
415
389
self._transient_keys.push_back(key)
416
390
else:
417
- self._map.set(key, value)
418
391
self.keys.push_back(key)
419
392
return value
0 commit comments