Skip to content

Commit

Permalink
Implement remainder of todos for local Vcs
Browse files Browse the repository at this point in the history
  • Loading branch information
bgw committed Jul 30, 2024
1 parent 914331b commit 0f8caae
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 45 deletions.
64 changes: 62 additions & 2 deletions crates/turbo-tasks-memory/tests/local_cell.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#![feature(arbitrary_self_types)]

use turbo_tasks::Vc;
use turbo_tasks::{debug::ValueDebug, test_helpers::current_task_for_testing, ValueDefault, Vc};
use turbo_tasks_testing::{register, run, Registration};

static REGISTRATION: Registration = register!();
Expand Down Expand Up @@ -57,6 +57,57 @@ async fn test_return_resolved() {
.await
}

#[turbo_tasks::value_trait]
trait UnimplementedTrait {}

#[tokio::test]
async fn test_try_resolve_sidecast() {
run(&REGISTRATION, async {
let trait_vc: Vc<Box<dyn ValueDebug>> = Vc::upcast(Vc::<u32>::local_cell(42));

// `u32` is both a `ValueDebug` and a `ValueDefault`, so this sidecast is valid
let sidecast_vc = Vc::try_resolve_sidecast::<Box<dyn ValueDefault>>(trait_vc)
.await
.unwrap();
assert!(sidecast_vc.is_some());

// `u32` is not an `UnimplementedTrait` though, so this should return None
let wrongly_sidecast_vc = Vc::try_resolve_sidecast::<Box<dyn UnimplementedTrait>>(trait_vc)
.await
.unwrap();
assert!(wrongly_sidecast_vc.is_none());
})
.await
}

#[tokio::test]
async fn test_try_resolve_downcast_type() {
run(&REGISTRATION, async {
let trait_vc: Vc<Box<dyn ValueDebug>> = Vc::upcast(Vc::<u32>::local_cell(42));

let downcast_vc: Vc<u32> = Vc::try_resolve_downcast_type(trait_vc)
.await
.unwrap()
.unwrap();
assert_eq!(*downcast_vc.await.unwrap(), 42);

let wrongly_downcast_vc: Option<Vc<i64>> =
Vc::try_resolve_downcast_type(trait_vc).await.unwrap();
assert!(wrongly_downcast_vc.is_none());
})
.await
}

#[tokio::test]
async fn test_get_task_id() {
run(&REGISTRATION, async {
// the task id as reported by the RawVc
let vc_task_id = Vc::into_raw(Vc::<()>::local_cell(())).get_task_id();
assert_eq!(vc_task_id, current_task_for_testing());
})
.await
}

#[turbo_tasks::value(eq = "manual")]
#[derive(Default)]
struct Untracked {
Expand All @@ -83,7 +134,7 @@ async fn get_untracked_local_cell() -> Vc<Untracked> {

#[tokio::test]
#[should_panic(expected = "Local Vcs must only be accessed within their own task")]
async fn test_panics_on_local_cell_escape() {
async fn test_panics_on_local_cell_escape_read() {
run(&REGISTRATION, async {
get_untracked_local_cell()
.await
Expand All @@ -94,3 +145,12 @@ async fn test_panics_on_local_cell_escape() {
})
.await
}

#[tokio::test]
#[should_panic(expected = "Local Vcs must only be accessed within their own task")]
async fn test_panics_on_local_cell_escape_get_task_id() {
run(&REGISTRATION, async {
Vc::into_raw(get_untracked_local_cell().await.unwrap().cell).get_task_id();
})
.await
}
29 changes: 23 additions & 6 deletions crates/turbo-tasks/src/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1751,7 +1751,8 @@ pub(crate) fn create_local_cell(value: TypedSharedReference) -> (ExecutionId, Lo
/// local cells are always filled. The returned value can be cheaply converted
/// with `.into()`.
///
/// Panics if the ExecutionId does not match the expected value.
/// Panics if the [`ExecutionId`] does not match the current task's
/// `execution_id`.
pub(crate) fn read_local_cell(
execution_id: ExecutionId,
local_cell_id: LocalCellId,
Expand All @@ -1762,12 +1763,28 @@ pub(crate) fn read_local_cell(
local_cells,
..
} = &*cell.borrow();
assert_eq!(
execution_id, *expected_execution_id,
"This Vc is local. Local Vcs must only be accessed within their own task. Resolve the \
Vc to convert it into a non-local version."
);
assert_eq_local_cell(execution_id, *expected_execution_id);
// local cell ids are one-indexed (they use NonZeroU32)
local_cells[(*local_cell_id as usize) - 1].clone()
})
}

/// Panics if the [`ExecutionId`] does not match the current task's
/// `execution_id`.
pub(crate) fn assert_execution_id(execution_id: ExecutionId) {
CURRENT_TASK_STATE.with(|cell| {
let CurrentTaskState {
execution_id: expected_execution_id,
..
} = &*cell.borrow();
assert_eq_local_cell(execution_id, *expected_execution_id);
})
}

fn assert_eq_local_cell(actual: ExecutionId, expected: ExecutionId) {
assert_eq!(
actual, expected,
"This Vc is local. Local Vcs must only be accessed within their own task. Resolve the Vc \
to convert it into a non-local version."
);
}
85 changes: 48 additions & 37 deletions crates/turbo-tasks/src/raw_vc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@ use crate::{
backend::{CellContent, TypedCellContent},
event::EventListener,
id::{ExecutionId, LocalCellId},
manager::{read_local_cell, read_task_cell, read_task_output, TurboTasksApi},
manager::{
assert_execution_id, current_task, read_local_cell, read_task_cell, read_task_output,
TurboTasksApi,
},
registry::{self, get_value_type},
turbo_tasks, CollectiblesSource, TaskId, TraitTypeId, ValueTypeId, Vc, VcValueTrait,
turbo_tasks, CollectiblesSource, TaskId, TraitTypeId, ValueType, ValueTypeId, Vc, VcValueTrait,
};

#[derive(Error, Debug)]
Expand Down Expand Up @@ -100,38 +103,31 @@ impl RawVc {
self,
trait_type: TraitTypeId,
) -> Result<Option<RawVc>, ResolveTypeError> {
let tt = turbo_tasks();
tt.notify_scheduled_tasks();
let mut current = self;
loop {
match current {
RawVc::TaskOutput(task) => {
current = read_task_output(&*tt, task, false)
.await
.map_err(|source| ResolveTypeError::TaskError { source })?;
}
RawVc::TaskCell(task, index) => {
let content = read_task_cell(&*tt, task, index)
.await
.map_err(|source| ResolveTypeError::ReadError { source })?;
if let TypedCellContent(value_type, CellContent(Some(_))) = content {
if get_value_type(value_type).has_trait(&trait_type) {
return Ok(Some(RawVc::TaskCell(task, index)));
} else {
return Ok(None);
}
} else {
return Err(ResolveTypeError::NoContent);
}
}
RawVc::LocalCell(_, _) => todo!(),
}
}
self.resolve_type_inner(|value_type_id| {
let value_type = get_value_type(value_type_id);
(value_type.has_trait(&trait_type), Some(value_type))
})
.await
}

pub(crate) async fn resolve_value(
self,
value_type: ValueTypeId,
) -> Result<Option<RawVc>, ResolveTypeError> {
self.resolve_type_inner(|cell_value_type| (cell_value_type == value_type, None))
.await
}

/// Helper for `resolve_trait` and `resolve_value`.
///
/// After finding a cell, returns `Ok(Some(...))` when `conditional` returns
/// `true`, and `Ok(None)` when `conditional` returns `false`.
///
/// As an optimization, `conditional` may return the `&'static ValueType` to
/// avoid a potential extra lookup later.
async fn resolve_type_inner(
self,
conditional: impl FnOnce(ValueTypeId) -> (bool, Option<&'static ValueType>),
) -> Result<Option<RawVc>, ResolveTypeError> {
let tt = turbo_tasks();
tt.notify_scheduled_tasks();
Expand All @@ -147,17 +143,29 @@ impl RawVc {
let content = read_task_cell(&*tt, task, index)
.await
.map_err(|source| ResolveTypeError::ReadError { source })?;
if let TypedCellContent(cell_value_type, CellContent(Some(_))) = content {
if cell_value_type == value_type {
return Ok(Some(RawVc::TaskCell(task, index)));
if let TypedCellContent(value_type, CellContent(Some(_))) = content {
return Ok(if conditional(value_type).0 {
Some(RawVc::TaskCell(task, index))
} else {
return Ok(None);
}
None
});
} else {
return Err(ResolveTypeError::NoContent);
}
}
RawVc::LocalCell(_, _) => todo!(),
RawVc::LocalCell(execution_id, local_cell_id) => {
let shared_reference = read_local_cell(execution_id, local_cell_id);
return Ok(
if let (true, value_type) = conditional(shared_reference.0) {
// re-use the `ValueType` lookup from `conditional`, if it exists
let value_type =
value_type.unwrap_or_else(|| get_value_type(shared_reference.0));
Some((value_type.raw_cell)(shared_reference))
} else {
None
},
);
}
}
}
}
Expand All @@ -172,7 +180,7 @@ impl RawVc {
self.resolve_inner(/* strongly_consistent */ true).await
}

pub(crate) async fn resolve_inner(self, strongly_consistent: bool) -> Result<RawVc> {
async fn resolve_inner(self, strongly_consistent: bool) -> Result<RawVc> {
let tt = turbo_tasks();
let mut current = self;
let mut notified = false;
Expand Down Expand Up @@ -203,7 +211,10 @@ impl RawVc {
pub fn get_task_id(&self) -> TaskId {
match self {
RawVc::TaskOutput(t) | RawVc::TaskCell(t, _) => *t,
RawVc::LocalCell(_, _) => todo!(),
RawVc::LocalCell(execution_id, _) => {
assert_execution_id(*execution_id);
current_task("RawVc::get_task_id")
}
}
}
}
Expand Down

0 comments on commit 0f8caae

Please sign in to comment.