Skip to content

Commit 08b0f27

Browse files
committed
Do not alias fields of tracked_struct Values when updating
1 parent 4d92253 commit 08b0f27

File tree

4 files changed

+74
-85
lines changed

4 files changed

+74
-85
lines changed

src/attach.rs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,10 @@ impl Attached {
4141
// Already attached? Assert that the database has not changed.
4242
// NOTE: It's important to use `addr_eq` here because `NonNull::eq`
4343
// not only compares the address but also the type's metadata.
44-
if !std::ptr::addr_eq(current_db.as_ptr(), new_db.as_ptr()) {
45-
panic!(
46-
"Cannot change database mid-query. current: {current_db:?}, new: {new_db:?}",
47-
);
48-
}
44+
assert!(
45+
std::ptr::addr_eq(current_db.as_ptr(), new_db.as_ptr()),
46+
"Cannot change database mid-query. current: {current_db:?}, new: {new_db:?}"
47+
);
4948

5049
Self { state: None }
5150
} else {

src/revision.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,13 @@ impl OptionalAtomicRevision {
104104
)
105105
}
106106

107+
pub(crate) fn swap_mut(&mut self, val: Option<Revision>) -> Option<Revision> {
108+
Revision::from_opt(std::mem::replace(
109+
self.data.get_mut(),
110+
val.map_or(0, |r| r.as_usize()),
111+
))
112+
}
113+
107114
pub(crate) fn compare_exchange(
108115
&self,
109116
current: Option<Revision>,

src/tracked_struct.rs

Lines changed: 62 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -400,18 +400,17 @@ where
400400
disambiguator,
401401
};
402402

403-
let current_revision = zalsa.current_revision();
404403
match zalsa_local.tracked_struct_id(&identity) {
405404
Some(id) => {
406405
// The struct already exists in the intern map.
407406
zalsa_local.add_output(self.database_key_index(id).into());
408-
self.update(zalsa, current_revision, id, &current_deps, fields);
407+
self.update(zalsa, id, &current_deps, fields);
409408
C::struct_from_id(id)
410409
}
411410

412411
None => {
413412
// This is a new tracked struct, so create an entry in the struct map.
414-
let id = self.allocate(zalsa, zalsa_local, current_revision, &current_deps, fields);
413+
let id = self.allocate(zalsa, zalsa_local, &current_deps, fields);
415414
let key = self.database_key_index(id);
416415
zalsa_local.add_output(key.into());
417416
zalsa_local.store_tracked_struct_id(identity, id);
@@ -424,10 +423,10 @@ where
424423
&'db self,
425424
zalsa: &'db Zalsa,
426425
zalsa_local: &'db ZalsaLocal,
427-
current_revision: Revision,
428426
current_deps: &StampedValue<()>,
429427
fields: C::Fields<'db>,
430428
) -> Id {
429+
let current_revision = zalsa.current_revision();
431430
let value = |_| Value {
432431
created_at: current_revision,
433432
updated_at: OptionalAtomicRevision::new(Some(current_revision)),
@@ -440,16 +439,14 @@ where
440439

441440
if let Some(id) = self.free_list.pop() {
442441
let data_raw = Self::data_raw(zalsa.table(), id);
443-
assert!(
442+
debug_assert!(
444443
unsafe { (*data_raw).updated_at.load().is_none() },
445-
"free list entry for `{id:?}` does not have `None` for `updated_at`"
444+
"free list entry for `{id:?}` should not be locked"
446445
);
447446

448447
// Overwrite the free-list entry. Use `*foo = ` because the entry
449448
// has been previously initialized and we want to free the old contents.
450-
unsafe {
451-
*data_raw = value(id);
452-
}
449+
unsafe { *data_raw = value(id) };
453450

454451
id
455452
} else {
@@ -467,7 +464,6 @@ where
467464
fn update<'db>(
468465
&'db self,
469466
zalsa: &'db Zalsa,
470-
current_revision: Revision,
471467
id: Id,
472468
current_deps: &StampedValue<()>,
473469
fields: C::Fields<'db>,
@@ -508,6 +504,7 @@ where
508504
// during the current revision and thus obtained an `&` reference to those fields
509505
// that is still live.
510506

507+
let current_revision = zalsa.current_revision();
511508
// UNSAFE: Marking as mut requires exclusive access for the duration of
512509
// the `mut`. We have now *claimed* this data by swapping in `None`,
513510
// any attempt to read concurrently will panic.
@@ -524,17 +521,19 @@ where
524521
// Acquire the write-lock. This can only fail if there is a parallel thread
525522
// reading from this same `id`, which can only happen if the user has leaked it.
526523
// Tsk tsk.
527-
let swapped_out = unsafe { (*data_raw).updated_at.swap(None) };
528-
if swapped_out != last_updated_at {
524+
525+
let swapped = unsafe { (*data_raw).updated_at.swap(None) };
526+
if last_updated_at != swapped {
529527
panic!(
530528
"failed to acquire write lock, id `{id:?}` must have been leaked across threads"
531529
);
532530
}
533531

534-
// UNSAFE: Marking as mut requires exclusive access for the duration of
532+
// SAFETY: Marking as mut requires exclusive access for the duration of
535533
// the `mut`. We have now *claimed* this data by swapping in `None`,
536-
// any attempt to read concurrently will panic.
537-
let data = unsafe { &mut *data_raw };
534+
// any attempt to read concurrently will panic. Note that we cannot create
535+
// a `&mut` reference to the full `Value` though because
536+
// another thread may access `updated_at` concurrently.
538537

539538
// SAFETY: We assert that the pointer to `data.revisions`
540539
// is a pointer into the database referencing a value
@@ -544,8 +543,8 @@ where
544543
unsafe {
545544
if C::update_fields(
546545
current_revision,
547-
&mut data.revisions,
548-
self.to_self_ptr(std::ptr::addr_of_mut!(data.fields)),
546+
&mut (*data_raw).revisions,
547+
self.to_self_ptr(std::ptr::addr_of_mut!((*data_raw).fields)),
549548
fields,
550549
) {
551550
// Consider this a new tracked-struct (even though it still uses the same id)
@@ -554,22 +553,28 @@ where
554553
// which makes Salsa consider two tracked structs to still be the same
555554
// even though the fields are different.
556555
// See `tracked-struct-id-field-bad-hash` for more details.
557-
data.created_at = current_revision;
556+
(*data_raw).revisions = C::new_revisions(current_revision);
557+
(*data_raw).created_at = current_revision;
558+
} else if current_deps.durability < (*data_raw).durability {
559+
(*data_raw).revisions = C::new_revisions(current_revision);
560+
(*data_raw).created_at = current_revision;
558561
}
562+
(*data_raw).durability = current_deps.durability;
559563
}
560-
if current_deps.durability < data.durability {
561-
data.revisions = C::new_revisions(current_revision);
562-
data.created_at = current_revision;
563-
}
564-
data.durability = current_deps.durability;
565-
let swapped_out = data.updated_at.swap(Some(current_revision));
566-
assert!(swapped_out.is_none());
564+
let swapped_out = unsafe { (*data_raw).updated_at.swap_mut(Some(current_revision)) };
565+
assert!(
566+
swapped_out.is_none(),
567+
"two concurrent writers to {id:?}, should not be possible"
568+
);
567569
}
568570

569571
/// Fetch the data for a given id created by this ingredient from the table,
570572
/// -giving it the appropriate type.
571-
fn data(table: &Table, id: Id) -> &Value<C> {
572-
table.get(id)
573+
fn data(table: &Table, id: Id, current_revision: Revision) -> &Value<C> {
574+
let val = Self::data_raw(table, id);
575+
acquire_read_lock(unsafe { &(*val).updated_at }, current_revision);
576+
// We have acquired the read lock, so it is safe to return a reference to the data.
577+
unsafe { &*val }
573578
}
574579

575580
fn data_raw(table: &Table, id: Id) -> *mut Value<C> {
@@ -594,29 +599,23 @@ where
594599
});
595600

596601
let zalsa = db.zalsa();
597-
let current_revision = zalsa.current_revision();
598602
let data = Self::data_raw(zalsa.table(), id);
599603

600604
// We want to set `updated_at` to `None`, signalling that other field values
601605
// cannot be read. The current value should be `Some(R0)` for some older revision.
602-
let data_ref = unsafe { &*data };
603-
match data_ref.updated_at.load() {
606+
match unsafe { (*data).updated_at.swap(None) }{
604607
None => {
605608
panic!("cannot delete write-locked id `{id:?}`; value leaked across threads");
606609
}
607-
Some(r) if r == current_revision => panic!(
610+
Some(r) if r == zalsa.current_revision() => panic!(
608611
"cannot delete read-locked id `{id:?}`; value leaked across threads or user functions not deterministic"
609612
),
610-
Some(r) => {
611-
if data_ref.updated_at.compare_exchange(Some(r), None).is_err() {
612-
panic!("race occurred when deleting value `{id:?}`")
613-
}
614-
}
613+
Some(_) => ()
615614
}
616615

617616
// Take the memo table. This is safe because we have modified `data_ref.updated_at` to `None`
618-
// and the code that references the memo-table has a read-lock.
619-
let memo_table = unsafe { (*data).take_memo_table() };
617+
// signalling that we have acquired the write lock
618+
let memo_table = std::mem::take(unsafe { &mut (*data).memos });
620619

621620
// SAFETY: We have verified that no more references to these memos exist and so we are good
622621
// to drop them.
@@ -648,7 +647,7 @@ where
648647
s: C::Struct<'db>,
649648
) -> &'db C::Fields<'db> {
650649
let id = C::deref_struct(s);
651-
let value = Self::data(db.zalsa().table(), id);
650+
let value = Self::data(db.zalsa().table(), id, db.zalsa().current_revision());
652651
unsafe { self.to_self_ref(&value.fields) }
653652
}
654653

@@ -670,9 +669,7 @@ where
670669
let (zalsa, zalsa_local) = db.zalsas();
671670
let id = C::deref_struct(s);
672671
let field_ingredient_index = self.ingredient_index.successor(relative_tracked_index);
673-
let data = Self::data(zalsa.table(), id);
674-
675-
data.read_lock(zalsa.current_revision());
672+
let data = Self::data(zalsa.table(), id, zalsa.current_revision());
676673

677674
let field_changed_at = data.revisions[relative_tracked_index];
678675

@@ -697,9 +694,7 @@ where
697694
) -> &'db C::Fields<'db> {
698695
let (zalsa, zalsa_local) = db.zalsas();
699696
let id = C::deref_struct(s);
700-
let data = Self::data(zalsa.table(), id);
701-
702-
data.read_lock(zalsa.current_revision());
697+
let data = Self::data(zalsa.table(), id, zalsa.current_revision());
703698

704699
// Add a dependency on the tracked struct itself.
705700
zalsa_local.report_tracked_read(
@@ -742,7 +737,7 @@ where
742737
revision: Revision,
743738
) -> MaybeChangedAfter {
744739
let zalsa = db.zalsa();
745-
let data = Self::data(zalsa.table(), input);
740+
let data = Self::data(zalsa.table(), input, zalsa.current_revision());
746741

747742
MaybeChangedAfter::from(data.created_at > revision)
748743
}
@@ -761,9 +756,7 @@ where
761756
_executor: DatabaseKeyIndex,
762757
_output_key: crate::Id,
763758
) {
764-
// we used to update `update_at` field but now we do it lazilly when data is accessed
765-
//
766-
// FIXME: delete this method
759+
// we used to update `update_at` field but now we do it lazily when data is accessed
767760
}
768761

769762
fn remove_stale_output(
@@ -776,7 +769,7 @@ where
776769
// `executor` creates a tracked struct `salsa_output_key`,
777770
// but it did not in the current revision.
778771
// In that case, we can delete `stale_output_key` and any data associated with it.
779-
self.delete_entity(db.as_dyn_database(), stale_output_key);
772+
self.delete_entity(db, stale_output_key);
780773
}
781774

782775
fn fmt_index(&self, index: Option<crate::Id>, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
@@ -811,34 +804,22 @@ where
811804
pub fn fields(&self) -> &C::Fields<'static> {
812805
&self.fields
813806
}
807+
}
814808

815-
fn take_memo_table(&mut self) -> MemoTable {
816-
// This fn is only called after `updated_at` has been set to `None`;
817-
// this ensures that there is no concurrent access
818-
// (and that the `&mut self` is accurate...).
819-
assert!(self.updated_at.load().is_none());
820-
821-
std::mem::take(&mut self.memos)
822-
}
823-
824-
fn read_lock(&self, current_revision: Revision) {
825-
loop {
826-
match self.updated_at.load() {
827-
None => {
828-
panic!("access to field whilst the value is being initialized");
829-
}
830-
Some(r) => {
831-
if r == current_revision {
832-
return;
833-
}
834-
835-
if self
836-
.updated_at
837-
.compare_exchange(Some(r), Some(current_revision))
838-
.is_ok()
839-
{
840-
break;
841-
}
809+
fn acquire_read_lock(updated_at: &OptionalAtomicRevision, current_revision: Revision) {
810+
loop {
811+
match updated_at.load() {
812+
None => panic!(
813+
"write lock taken; value leaked across threads or user functions not deterministic"
814+
),
815+
// the read lock was taken by someone else, so we also succeed
816+
Some(r) if r == current_revision => return,
817+
Some(r) => {
818+
if updated_at
819+
.compare_exchange(Some(r), Some(current_revision))
820+
.is_ok()
821+
{
822+
break;
842823
}
843824
}
844825
}
@@ -849,23 +830,25 @@ impl<C> Slot for Value<C>
849830
where
850831
C: Configuration,
851832
{
833+
// FIXME: `&self` may alias here before the lock is taken?
852834
unsafe fn memos(&self, current_revision: Revision) -> &crate::table::memo::MemoTable {
853835
// Acquiring the read lock here with the current revision
854836
// ensures that there is no danger of a race
855837
// when deleting a tracked struct.
856-
self.read_lock(current_revision);
838+
acquire_read_lock(&self.updated_at, current_revision);
857839
&self.memos
858840
}
859841

860842
fn memos_mut(&mut self) -> &mut crate::table::memo::MemoTable {
861843
&mut self.memos
862844
}
863845

846+
// FIXME: `&self` may alias here?
864847
unsafe fn syncs(&self, current_revision: Revision) -> &crate::table::sync::SyncTable {
865848
// Acquiring the read lock here with the current revision
866849
// ensures that there is no danger of a race
867850
// when deleting a tracked struct.
868-
self.read_lock(current_revision);
851+
acquire_read_lock(&self.updated_at, current_revision);
869852
&self.syncs
870853
}
871854
}

src/tracked_struct/tracked_field.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ where
6262
revision: crate::Revision,
6363
) -> MaybeChangedAfter {
6464
let zalsa = db.zalsa();
65-
let data = <super::IngredientImpl<C>>::data(zalsa.table(), input);
65+
let data = <super::IngredientImpl<C>>::data(zalsa.table(), input, zalsa.current_revision());
6666
let field_changed_at = data.revisions[self.field_index];
6767
MaybeChangedAfter::from(field_changed_at > revision)
6868
}

0 commit comments

Comments
 (0)