Skip to content

Commit

Permalink
Minimal implementation of local cells
Browse files Browse the repository at this point in the history
  • Loading branch information
bgw committed Jul 26, 2024
1 parent 2070b13 commit 10b9aa5
Show file tree
Hide file tree
Showing 8 changed files with 235 additions and 65 deletions.
10 changes: 10 additions & 0 deletions crates/turbo-tasks-macros/src/value_macro.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,16 @@ pub fn value(args: TokenStream, input: TokenStream) -> TokenStream {
let content = self;
turbo_tasks::Vc::cell_private(#cell_access_content)
}

/// Places a value in a task-local cell stored in the current task.
///
/// Task-local cells are stored in a task-local arena, and do not persist outside the
/// lifetime of the current task (including child tasks). Task-local cells can be resolved
/// to be converted into normal cells.
#cell_prefix fn local_cell(self) -> turbo_tasks::Vc<Self> {
let content = self;
turbo_tasks::Vc::local_cell_private(#cell_access_content)
}
};

let into = if let IntoMode::New | IntoMode::Shared = into_mode {
Expand Down
2 changes: 1 addition & 1 deletion crates/turbo-tasks-memory/src/output.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ impl Display for OutputContent {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
OutputContent::Empty => write!(f, "empty"),
OutputContent::Link(raw_vc) => write!(f, "link {}", raw_vc),
OutputContent::Link(raw_vc) => write!(f, "link {:?}", raw_vc),
OutputContent::Error(err) => write!(f, "error {}", err),
OutputContent::Panic(Some(message)) => write!(f, "panic {}", message),
OutputContent::Panic(None) => write!(f, "panic"),
Expand Down
45 changes: 45 additions & 0 deletions crates/turbo-tasks-memory/tests/local_cell.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#![feature(arbitrary_self_types)]

use turbo_tasks::Vc;
use turbo_tasks_testing::{register, run, Registration};

static REGISTRATION: Registration = register!();

#[turbo_tasks::value]
struct Wrapper(u32);

#[turbo_tasks::value(transparent)]
struct TransparentWrapper(u32);

#[tokio::test]
async fn store_and_read() {
run(&REGISTRATION, async {
let a: Vc<u32> = Vc::local_cell(42);
assert_eq!(*a.await.unwrap(), 42);

let b = Wrapper(42).local_cell();
assert_eq!((*b.await.unwrap()).0, 42);

let c = TransparentWrapper(42).local_cell();
assert_eq!(*c.await.unwrap(), 42);
})
.await
}

#[tokio::test]
async fn store_and_read_generic() {
run(&REGISTRATION, async {
// `Vc<Vec<Vc<T>>>` is stored as `Vc<Vec<Vc<()>>>` and requires special
// transmute handling
let cells: Vc<Vec<Vc<u32>>> =
Vc::local_cell(vec![Vc::local_cell(1), Vc::local_cell(2), Vc::cell(3)]);

let mut output = Vec::new();
for el in cells.await.unwrap() {
output.push(*el.await.unwrap());
}

assert_eq!(output, vec![1, 2, 3]);
})
.await
}
1 change: 1 addition & 0 deletions crates/turbo-tasks/src/id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ define_id!(ValueTypeId: u32);
define_id!(TraitTypeId: u32);
define_id!(BackendJobId: u32);
define_id!(ExecutionId: u64, derive(Debug));
define_id!(LocalCellId: u32, derive(Debug));

impl Debug for TaskId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Expand Down
166 changes: 118 additions & 48 deletions crates/turbo-tasks/src/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ use crate::{
},
capture_future::{self, CaptureFuture},
event::{Event, EventListener},
id::{BackendJobId, FunctionId, TraitTypeId},
id_factory::IdFactoryWithReuse,
id::{BackendJobId, ExecutionId, FunctionId, LocalCellId, TraitTypeId},
id_factory::{IdFactory, IdFactoryWithReuse},
magic_any::MagicAny,
raw_vc::{CellId, RawVc},
registry,
Expand Down Expand Up @@ -243,6 +243,7 @@ pub struct TurboTasks<B: Backend + 'static> {
this: Weak<Self>,
backend: B,
task_id_factory: IdFactoryWithReuse<TaskId>,
execution_id_factory: IdFactory<ExecutionId>,
stopped: AtomicBool,
currently_scheduled_tasks: AtomicUsize,
currently_scheduled_foreground_jobs: AtomicUsize,
Expand All @@ -257,7 +258,6 @@ pub struct TurboTasks<B: Backend + 'static> {
program_start: Instant,
}

#[derive(Default)]
struct CurrentTaskState {
/// Affected tasks, that are tracked during task execution. These tasks will
/// be invalidated when the execution finishes or before reading a cell
Expand All @@ -266,6 +266,26 @@ struct CurrentTaskState {

/// True if the current task has state in cells
stateful: bool,

/// A unique identifier created for each unique `CurrentTaskState`. Used to
/// check that [`CurrentTaskState::local_cells`] are valid for the current
/// `RawVc::LocalCell`.
execution_id: ExecutionId,

/// Cells for locally allocated Vcs (`RawVc::LocalCell`). This is freed
/// (along with `CurrentTaskState`) when
local_cells: Vec<TypedCellContent>,
}

impl CurrentTaskState {
fn new(execution_id: ExecutionId) -> Self {
Self {
tasks_to_notify: Vec::new(),
stateful: false,
execution_id,
local_cells: Vec::new(),
}
}
}

// TODO implement our own thread pool and make these thread locals instead
Expand Down Expand Up @@ -293,6 +313,7 @@ impl<B: Backend + 'static> TurboTasks<B> {
this: this.clone(),
backend,
task_id_factory,
execution_id_factory: IdFactory::new(),
stopped: AtomicBool::new(false),
currently_scheduled_tasks: AtomicUsize::new(0),
currently_scheduled_background_jobs: AtomicUsize::new(0),
Expand Down Expand Up @@ -490,53 +511,57 @@ impl<B: Backend + 'static> TurboTasks<B> {
let future = async move {
#[allow(clippy::blocks_in_conditions)]
while CURRENT_TASK_STATE
.scope(Default::default(), async {
if this.stopped.load(Ordering::Acquire) {
return false;
}
.scope(
RefCell::new(CurrentTaskState::new(this.execution_id_factory.get())),
async {
if this.stopped.load(Ordering::Acquire) {
return false;
}

// Setup thread locals
CELL_COUNTERS
.scope(Default::default(), async {
let Some(TaskExecutionSpec { future, span }) =
this.backend.try_start_task_execution(task_id, &*this)
else {
return false;
};

async {
let (result, duration, memory_usage) =
CaptureFuture::new(AssertUnwindSafe(future).catch_unwind())
.await;

let result = result.map_err(|any| match any.downcast::<String>() {
Ok(owned) => Some(Cow::Owned(*owned)),
Err(any) => match any.downcast::<&'static str>() {
Ok(str) => Some(Cow::Borrowed(*str)),
Err(_) => None,
},
});
this.backend.task_execution_result(task_id, result, &*this);
let stateful = this.finish_current_task_state();
let cell_counters =
CELL_COUNTERS.with(|cc| take(&mut *cc.borrow_mut()));
let schedule_again = this.backend.task_execution_completed(
task_id,
duration,
memory_usage,
cell_counters,
stateful,
&*this,
);
// task_execution_completed might need to notify tasks
this.notify_scheduled_tasks();
schedule_again
}
.instrument(span)
// Setup thread locals
CELL_COUNTERS
.scope(Default::default(), async {
let Some(TaskExecutionSpec { future, span }) =
this.backend.try_start_task_execution(task_id, &*this)
else {
return false;
};

async {
let (result, duration, memory_usage) =
CaptureFuture::new(AssertUnwindSafe(future).catch_unwind())
.await;

let result =
result.map_err(|any| match any.downcast::<String>() {
Ok(owned) => Some(Cow::Owned(*owned)),
Err(any) => match any.downcast::<&'static str>() {
Ok(str) => Some(Cow::Borrowed(*str)),
Err(_) => None,
},
});
this.backend.task_execution_result(task_id, result, &*this);
let stateful = this.finish_current_task_state();
let cell_counters =
CELL_COUNTERS.with(|cc| take(&mut *cc.borrow_mut()));
let schedule_again = this.backend.task_execution_completed(
task_id,
duration,
memory_usage,
cell_counters,
stateful,
&*this,
);
// task_execution_completed might need to notify tasks
this.notify_scheduled_tasks();
schedule_again
}
.instrument(span)
.await
})
.await
})
.await
})
},
)
.await
{}
this.finish_primary_job();
Expand Down Expand Up @@ -841,6 +866,7 @@ impl<B: Backend + 'static> TurboTasks<B> {
let CurrentTaskState {
tasks_to_notify,
stateful,
..
} = &mut *cell.borrow_mut();
(*stateful, take(tasks_to_notify))
});
Expand Down Expand Up @@ -1688,3 +1714,47 @@ pub fn find_cell_by_type(ty: ValueTypeId) -> CurrentCellRef {
}
})
}

pub(crate) fn create_local_cell(value: TypedCellContent) -> (ExecutionId, LocalCellId) {
CURRENT_TASK_STATE.with(|cell| {
let CurrentTaskState {
execution_id,
local_cells,
..
} = &mut *cell.borrow_mut();

// store in the task-local arena
local_cells.push(value);

// generate a one-indexed id
let raw_local_cell_id = local_cells.len();
let local_cell_id = if cfg!(debug_assertions) {
LocalCellId::from(u32::try_from(raw_local_cell_id).unwrap())
} else {
unsafe { LocalCellId::new_unchecked(raw_local_cell_id as u32) }
};

(*execution_id, local_cell_id)
})
}

/// Panics if the ExecutionId does not match the expected value.
pub(crate) fn read_local_cell(
execution_id: ExecutionId,
local_cell_id: LocalCellId,
) -> TypedCellContent {
CURRENT_TASK_STATE.with(|cell| {
let CurrentTaskState {
execution_id: expected_execution_id,
local_cells,
..
} = &*cell.borrow();
assert_eq!(
execution_id, *expected_execution_id,
"This Vc is local. Local Vcs must only be accessed within their own task. Resolve the \
Vc to convert it into a non-local version."
);
// local cell ids are one-indexed (they use NonZeroU32)
local_cells[(*local_cell_id as usize) - 1].clone()
})
}
27 changes: 13 additions & 14 deletions crates/turbo-tasks/src/raw_vc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ use thiserror::Error;
use crate::{
backend::{CellContent, TypedCellContent},
event::EventListener,
manager::{read_task_cell, read_task_output, TurboTasksApi},
id::{ExecutionId, LocalCellId},
manager::{read_local_cell, read_task_cell, read_task_output, TurboTasksApi},
registry::{self, get_value_type},
turbo_tasks, CollectiblesSource, TaskId, TraitTypeId, ValueTypeId, Vc, VcValueTrait,
};
Expand Down Expand Up @@ -53,13 +54,16 @@ impl Display for CellId {
pub enum RawVc {
TaskOutput(TaskId),
TaskCell(TaskId, CellId),
#[serde(skip)]
LocalCell(ExecutionId, LocalCellId),
}

impl RawVc {
pub(crate) fn is_resolved(&self) -> bool {
match self {
RawVc::TaskOutput(_) => false,
RawVc::TaskCell(_, _) => true,
RawVc::LocalCell(_, _) => false,
}
}

Expand Down Expand Up @@ -120,6 +124,7 @@ impl RawVc {
return Err(ResolveTypeError::NoContent);
}
}
RawVc::LocalCell(_, _) => todo!(),
}
}
}
Expand Down Expand Up @@ -152,6 +157,7 @@ impl RawVc {
return Err(ResolveTypeError::NoContent);
}
}
RawVc::LocalCell(_, _) => todo!(),
}
}
}
Expand All @@ -171,6 +177,7 @@ impl RawVc {
current = read_task_output(&*tt, task, false).await?;
}
RawVc::TaskCell(_, _) => return Ok(current),
RawVc::LocalCell(_, _) => todo!(),
}
}
}
Expand All @@ -190,6 +197,7 @@ impl RawVc {
current = read_task_output(&*tt, task, true).await?;
}
RawVc::TaskCell(_, _) => return Ok(current),
RawVc::LocalCell(_, _) => todo!(),
}
}
}
Expand All @@ -202,6 +210,7 @@ impl RawVc {
pub fn get_task_id(&self) -> TaskId {
match self {
RawVc::TaskOutput(t) | RawVc::TaskCell(t, _) => *t,
RawVc::LocalCell(_, _) => todo!(),
}
}
}
Expand All @@ -227,19 +236,6 @@ impl CollectiblesSource for RawVc {
}
}

impl Display for RawVc {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
RawVc::TaskOutput(task) => {
write!(f, "output of {}", task)
}
RawVc::TaskCell(task, index) => {
write!(f, "value {} of {}", index, task)
}
}
}
}

pub struct ReadRawVcFuture {
turbo_tasks: Arc<dyn TurboTasksApi>,
strongly_consistent: bool,
Expand Down Expand Up @@ -358,6 +354,9 @@ impl Future for ReadRawVcFuture {
Err(err) => return Poll::Ready(Err(err)),
}
}
RawVc::LocalCell(execution_id, local_cell_id) => {
return Poll::Ready(Ok(read_local_cell(execution_id, local_cell_id)));
}
};
// SAFETY: listener is from previous pinned this
match unsafe { Pin::new_unchecked(&mut listener) }.poll(cx) {
Expand Down
Loading

0 comments on commit 10b9aa5

Please sign in to comment.