diff --git a/crates/turbo-tasks-memory/tests/trait_ref_cell_mode.rs b/crates/turbo-tasks-memory/tests/trait_ref_cell_mode.rs new file mode 100644 index 0000000000000..77b092fdf4e7c --- /dev/null +++ b/crates/turbo-tasks-memory/tests/trait_ref_cell_mode.rs @@ -0,0 +1,144 @@ +#![feature(arbitrary_self_types)] + +use anyhow::Result; +use turbo_tasks::{IntoTraitRef, State, TraitRef, Upcast, Vc}; +use turbo_tasks_testing::{register, run, Registration}; + +static REGISTRATION: Registration = register!(); + +// Test that with `cell = "shared"`, the cell will be re-used as long as the +// value is equal. +#[tokio::test] +async fn test_trait_ref_shared_cell_mode() { + run(®ISTRATION, async { + let input = CellIdSelector { + value: 42, + cell_idx: State::new(0), + } + .cell(); + + // create the task and compute it + let counter_value_vc = shared_value_from_input(input); + let trait_ref_a = counter_value_vc.into_trait_ref().await.unwrap(); + + // invalidate the task, and pick a different cell id for the next execution + input.await.unwrap().cell_idx.set_unconditionally(1); + + // recompute the task + let trait_ref_b = counter_value_vc.into_trait_ref().await.unwrap(); + + for trait_ref in [&trait_ref_a, &trait_ref_b] { + assert_eq!( + *TraitRef::cell(trait_ref.clone()).get_value().await.unwrap(), + 42 + ); + } + + // because we're using `cell = "shared"`, these trait refs must use the same + // underlying Arc/SharedRef (by identity) + assert!(TraitRef::ptr_eq(&trait_ref_a, &trait_ref_b)); + }) + .await +} + +// Test that with `cell = "new"`, the cell will is never re-used, even if the +// value is equal. +#[tokio::test] +async fn test_trait_ref_new_cell_mode() { + run(®ISTRATION, async { + let input = CellIdSelector { + value: 42, + cell_idx: State::new(0), + } + .cell(); + + // create the task and compute it + let counter_value_vc = new_value_from_input(input); + let trait_ref_a = counter_value_vc.into_trait_ref().await.unwrap(); + + // invalidate the task, and pick a different cell id for the next execution + input.await.unwrap().cell_idx.set_unconditionally(1); + + // recompute the task + let trait_ref_b = counter_value_vc.into_trait_ref().await.unwrap(); + + for trait_ref in [&trait_ref_a, &trait_ref_b] { + assert_eq!( + *TraitRef::cell(trait_ref.clone()).get_value().await.unwrap(), + 42 + ); + } + + // because we're using `cell = "new"`, these trait refs must use different + // underlying Arc/SharedRefs (by identity) + assert!(!TraitRef::ptr_eq(&trait_ref_a, &trait_ref_b)); + }) + .await +} + +#[turbo_tasks::value_trait] +trait ValueTrait { + fn get_value(&self) -> Vc; +} + +#[turbo_tasks::value(transparent, cell = "shared")] +struct SharedValue(usize); + +#[turbo_tasks::value(transparent, cell = "new")] +struct NewValue(usize); + +#[turbo_tasks::value_impl] +impl ValueTrait for SharedValue { + #[turbo_tasks::function] + fn get_value(&self) -> Vc { + Vc::cell(self.0) + } +} + +#[turbo_tasks::value_impl] +impl ValueTrait for NewValue { + #[turbo_tasks::function] + fn get_value(&self) -> Vc { + Vc::cell(self.0) + } +} + +#[turbo_tasks::value] +struct CellIdSelector { + value: usize, + cell_idx: State, +} + +async fn value_from_input( + input: Vc, + mut cell_fn: impl FnMut(usize) -> Vc, +) -> Result>> +where + T: ValueTrait + Upcast>, +{ + let input = input.await?; + + // create multiple cells so that we can pick from them, simulating a function + // with non-deterministic ordering that returns a "random" cell that happens to + // contain the same value + let mut upcast_vcs = Vec::new(); + for _idx in 0..2 { + upcast_vcs.push(Vc::upcast((cell_fn)(input.value))); + } + + // pick a different cell idx upon each invalidation/execution + let picked_vc = upcast_vcs[*input.cell_idx.get()]; + + // round-trip through `TraitRef::cell` + Ok(TraitRef::cell(picked_vc.into_trait_ref().await?)) +} + +#[turbo_tasks::function] +async fn shared_value_from_input(input: Vc) -> Result>> { + value_from_input::(input, Vc::::cell).await +} + +#[turbo_tasks::function] +async fn new_value_from_input(input: Vc) -> Result>> { + value_from_input::(input, Vc::::cell).await +} diff --git a/crates/turbo-tasks/src/trait_ref.rs b/crates/turbo-tasks/src/trait_ref.rs index ddb35820435e2..5473254b0cf6e 100644 --- a/crates/turbo-tasks/src/trait_ref.rs +++ b/crates/turbo-tasks/src/trait_ref.rs @@ -4,10 +4,10 @@ use anyhow::Result; use serde::{Deserialize, Serialize}; use crate::{ - manager::find_cell_by_type, + registry::get_value_type, task::shared_reference::TypedSharedReference, vc::{cast::VcCast, ReadVcFuture, VcValueTraitCast}, - RawVc, Vc, VcValueTrait, + Vc, VcValueTrait, }; /// Similar to a [`ReadRef`][crate::ReadRef], but contains a value trait @@ -90,6 +90,10 @@ where _t: PhantomData, } } + + pub fn ptr_eq(this: &Self, other: &Self) -> bool { + triomphe::Arc::ptr_eq(&this.shared_reference.1 .0, &other.shared_reference.1 .0) + } } impl TraitRef @@ -99,12 +103,11 @@ where /// Returns a new cell that points to a value that implements the value /// trait `T`. pub fn cell(trait_ref: TraitRef) -> Vc { - // See Safety clause above. - let TypedSharedReference(ty, shared_ref) = trait_ref.shared_reference; - let local_cell = find_cell_by_type(ty); - local_cell.update_with_shared_reference(shared_ref); - let raw_vc: RawVc = local_cell.into(); - raw_vc.into() + let TraitRef { + shared_reference, .. + } = trait_ref; + let value_type = get_value_type(shared_reference.0); + (value_type.raw_cell)(shared_reference).into() } }