Skip to content

Commit

Permalink
[Device][Backend] refactor: extract methods and deviceguard (#4)
Browse files Browse the repository at this point in the history
Refactor: extract methods and deviceguard
  • Loading branch information
shink authored Jul 17, 2024
1 parent c1b2f0a commit 4045d05
Show file tree
Hide file tree
Showing 33 changed files with 379 additions and 398 deletions.
6 changes: 5 additions & 1 deletion .github/workflows/ci_build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,13 @@ on:
types:
- 'published'

concurrency:
group: '${{ github.workflow }}-${{ github.event.pull_request.head.label || github.head_ref || github.ref }}'
cancel-in-progress: ${{ !contains(github.ref, 'release/')}}

jobs:
build:
name: Build for Python${{ matrix.python-version }}
name: Build with Python${{ matrix.python-version }}
runs-on: ubuntu-latest
defaults:
run:
Expand Down
1 change: 1 addition & 0 deletions cmake/Dependencies.cmake
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# ---[ Googletest
if(BUILD_TEST)
enable_testing()
add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/googletest)
include_directories(BEFORE SYSTEM ${PROJECT_SOURCE_DIR}/third_party/googletest/googletest/include)
include_directories(BEFORE SYSTEM ${PROJECT_SOURCE_DIR}/third_party/googletest/googlemock/include)
Expand Down
Empty file removed csrc/core/.keep
Empty file.
146 changes: 146 additions & 0 deletions csrc/core/PrivateUse1Guard.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
#pragma once

#include <c10/core/DeviceType.h>
#include <c10/core/impl/InlineDeviceGuard.h>
#include <c10/core/impl/InlineStreamGuard.h>

#include "csrc/core/impl/PrivateUse1GuardImpl.h"

namespace c10_backend {

// This code is kind of boilerplatey. See Note [Whither the DeviceGuard
// boilerplate]

/// A variant of DeviceGuard that is specialized for PrivateUse1. It accepts
/// integer indices (interpreting them as PrivateUse1 devices) and is a little
/// more efficient than DeviceGuard (it compiles to straight line
/// cudaSetDevice/cudaGetDevice calls); however, it can only be used
/// from code that links against PrivateUse1 directly.
template <typename T>
struct PrivateUse1Guard {
/// No default constructor; see Note [Omitted default constructor from RAII]
explicit PrivateUse1Guard() = delete;

/// Set the current PrivateUse1 device to the passed device index.
explicit PrivateUse1Guard(c10::DeviceIndex device_index)
: guard_(device_index) {}

/// Sets the current PrivateUse1 device to the passed device. Errors if the
/// passed device is not a PrivateUse1 device.
explicit PrivateUse1Guard(c10::Device device) : guard_(device) {}

// Copy is not allowed
PrivateUse1Guard(const PrivateUse1Guard&) = delete;
PrivateUse1Guard& operator=(const PrivateUse1Guard&) = delete;

// Move is not allowed (there is no uninitialized state)
PrivateUse1Guard(PrivateUse1Guard&& other) = delete;
PrivateUse1Guard& operator=(PrivateUse1Guard&& other) = delete;

/// Sets the PrivateUse1 device to the given device. Errors if the given
/// device is not a PrivateUse1 device.
void set_device(c10::Device device) {
guard_.set_device(device);
}

/// Sets the PrivateUse1 device to the given device. Errors if the given
/// device is not a PrivateUse1 device. (This method is provided for
/// uniformity with DeviceGuard).
void reset_device(c10::Device device) {
guard_.reset_device(device);
}

/// Sets the PrivateUse1 device to the given device index.
void set_index(c10::DeviceIndex device_index) {
guard_.set_index(device_index);
}

/// Returns the device that was set upon construction of the guard
c10::Device original_device() const {
return guard_.original_device();
}

/// Returns the last device that was set via `set_device`, if any, otherwise
/// the device passed during construction.
c10::Device current_device() const {
return guard_.current_device();
}

private:
/// The guard for the current device.
c10::impl::InlineDeviceGuard<T> guard_;
};

/// A variant of OptionalDeviceGuard that is specialized for PrivateUse1. See
/// PrivateUse1Guard for when you can use this.
template <typename T>
struct OptionalPrivateUse1Guard {
/// Create an uninitialized OptionalPrivateUse1Guard.
explicit OptionalPrivateUse1Guard() : guard_() {}

/// Set the current PrivateUse1 device to the passed Device, if it is not
/// nullopt.
explicit OptionalPrivateUse1Guard(std::optional<c10::Device> device_opt)
: guard_(device_opt) {}

/// Set the current PrivateUse1 device to the passed device index, if it is
/// not nullopt
explicit OptionalPrivateUse1Guard(
std::optional<c10::DeviceIndex> device_index_opt)
: guard_(device_index_opt) {}

// Copy is not allowed
OptionalPrivateUse1Guard(const OptionalPrivateUse1Guard&) = delete;
OptionalPrivateUse1Guard& operator=(const OptionalPrivateUse1Guard&) = delete;

// See Note [Move construction for RAII guards is tricky]
OptionalPrivateUse1Guard(OptionalPrivateUse1Guard&& other) = delete;
// See Note [Move assignment for RAII guards is tricky]
OptionalPrivateUse1Guard& operator=(OptionalPrivateUse1Guard&& other) =
delete;

/// Sets the PrivateUse1 device to the given device, initializing the guard if
/// it is not already initialized. Errors if the given device is not a
/// PrivateUse1 device.
void set_device(c10::Device device) {
guard_.set_device(device);
}

/// Sets the PrivateUse1 device to the given device, initializing the guard if
/// it is not already initialized. Errors if the given device is not a
/// PrivateUse1 device. (This method is provided for uniformity with
/// OptionalDeviceGuard).
void reset_device(c10::Device device) {
guard_.reset_device(device);
}

/// Sets the PrivateUse1 device to the given device index, initializing the
/// guard if it is not already initialized.
void set_index(c10::DeviceIndex device_index) {
guard_.set_index(device_index);
}

/// Returns the device that was set immediately prior to initialization of the
/// guard, or nullopt if the guard is uninitialized.
std::optional<c10::Device> original_device() const {
return guard_.original_device();
}

/// Returns the most recent device that was set using this device guard,
/// either from construction, or via set_device, if the guard is initialized,
/// or nullopt if the guard is uninitialized.
std::optional<c10::Device> current_device() const {
return guard_.current_device();
}

/// Restore the original PrivateUse1 device, resetting this guard to
/// uninitialized state.
void reset() {
guard_.reset();
}

private:
c10::impl::InlineOptionalDeviceGuard<T> guard_;
};

} // namespace c10_backend
27 changes: 27 additions & 0 deletions csrc/core/impl/PrivateUse1GuardImpl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#pragma once

#include <c10/core/DeviceType.h>
#include <c10/core/impl/InlineDeviceGuard.h>
#include <c10/macros/Macros.h>

namespace c10_backend::impl {

/**
* All classes which inherit from PrivateUse1GuardImpl should be declared
* 'final'.
*/
struct PrivateUse1GuardImpl : public c10::impl::DeviceGuardImplInterface {
static constexpr c10::DeviceType static_type = c10::DeviceType::PrivateUse1;

PrivateUse1GuardImpl() = default;

explicit PrivateUse1GuardImpl(c10::DeviceType t) {
TORCH_INTERNAL_ASSERT(t == c10::DeviceType::PrivateUse1);
}

c10::DeviceType type() const final {
return c10::DeviceType::PrivateUse1;
}
};

} // namespace c10_backend::impl
2 changes: 1 addition & 1 deletion csrc/npu/CachingAllocatorHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ ExpandableSegment* createExpandableSegment(
size_t size);

// Wraps the insert event function
void insertEventWrapper(int device, std::function<void()> insertEventFn);
void insertEventWrapper(c10::DeviceIndex device, std::function<void()> insertEventFn);

// Returns the current stream for the given device
void* getCurrentStream(c10::DeviceIndex);
Expand Down
22 changes: 11 additions & 11 deletions csrc/npu/NPUCachingAllocator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ struct BlockPool {
};

struct Block {
int device;
c10::DeviceIndex device;
void* stream; // allocation stream
stream_set stream_uses; // streams on which the block was used
size_t size; // block size in bytes
Expand All @@ -155,7 +155,7 @@ struct Block {
// memory out from our cache.
std::shared_ptr<c10::GatheredContext> context_when_segment_allocated;

Block(int device, void* stream, size_t size, BlockPool* pool, void* ptr)
Block(c10::DeviceIndex device, void* stream, size_t size, BlockPool* pool, void* ptr)
: device(device),
stream(stream),
stream_uses(),
Expand All @@ -170,7 +170,7 @@ struct Block {
gc_count(0) {}

// constructor for search key
Block(int device, void* stream, size_t size)
Block(c10::DeviceIndex device, void* stream, size_t size)
: device(device),
stream(stream),
stream_uses(),
Expand Down Expand Up @@ -244,7 +244,7 @@ static std::string format_size(uint64_t size) {

struct AllocParams {
AllocParams(
int device,
c10::DeviceIndex device,
size_t size,
void* stream,
BlockPool* pool,
Expand All @@ -256,7 +256,7 @@ struct AllocParams {
block(nullptr),
err(ACL_ERROR_NONE) {}

int device() const {
c10::DeviceIndex device() const {
return search_key.device;
}
void* stream() const {
Expand Down Expand Up @@ -641,7 +641,7 @@ class DeviceCachingAllocator {
// All public methods (except the above) acquire the allocator mutex.
// Thus, do not call a public method from another public method.

Block* malloc(int device, size_t orig_size, void* stream) {
Block* malloc(c10::DeviceIndex device, size_t orig_size, void* stream) {
// done outside the lock because we don't know what locks the recorder needs
// to have...
auto context = maybeGatherContext(RecordContext::STATE);
Expand Down Expand Up @@ -2012,7 +2012,7 @@ class NpuCachingAllocator : public NPUAllocator {
return !device_allocator.empty();
}
/** allocates a block which is safe to use from the provided stream */
void malloc(void** devPtr, int device, size_t size, void* stream) {
void malloc(void** devPtr, c10::DeviceIndex device, size_t size, void* stream) {
Block* block = device_allocator[device]->malloc(device, size, stream);
add_allocated_block(block);
*devPtr = static_cast<void*>(block->ptr);
Expand Down Expand Up @@ -2060,7 +2060,7 @@ class NpuCachingAllocator : public NPUAllocator {
}

bool isHistoryEnabled() override {
int device = 0;
c10::DeviceIndex device = 0;
NPU_CHECK_ERROR(c10_npu::GetDevice(&device));
return device_allocator[device]->isHistoryEnabled();
}
Expand Down Expand Up @@ -2154,7 +2154,7 @@ class NpuCachingAllocator : public NPUAllocator {
}

c10::DataPtr allocate(size_t size) override {
int device = 0;
c10::DeviceIndex device = 0;
NPU_CHECK_ERROR(c10_npu::GetDevice(&device));
void* devPtr = nullptr;
void (*deleteFunc)(void*) = &local_raw_delete;
Expand Down Expand Up @@ -2215,7 +2215,7 @@ class NpuCachingAllocator : public NPUAllocator {
if (nbytes == 0) {
return nullptr;
}
int device = 0;
c10::DeviceIndex device = 0;
NPU_CHECK_ERROR(c10_npu::GetDevice(&device));
void* r = nullptr;
malloc(&r, device, nbytes, getCurrentStream(device));
Expand All @@ -2226,7 +2226,7 @@ class NpuCachingAllocator : public NPUAllocator {
if (nbytes == 0) {
return nullptr;
}
int device;
c10::DeviceIndex device;
NPU_CHECK_ERROR(c10_npu::GetDevice(&device));
void* r = nullptr;
malloc(&r, device, nbytes, stream);
Expand Down
Loading

0 comments on commit 4045d05

Please sign in to comment.