Fix bug for asort kernel & faster sampler with GPU sorting #2730

Open · wants to merge 3 commits into base: main
2 changes: 1 addition & 1 deletion candle-core/src/cuda_backend/device.rs
@@ -76,7 +76,7 @@ impl CudaDevice {
self.id
}

fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
pub fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
let elem_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
let slice = match dtype {
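Making const_impl public lets code outside this module fill a freshly allocated device buffer with a constant value; the sort path below uses it to build padding rows for the bitonic kernel. A minimal sketch of the call, assuming the cuda feature is enabled and a CudaDevice handle is created elsewhere (shape and dtype here are illustrative):

use candle_core::{CudaDevice, DType, Result, Shape};

// Sketch only: fill a (nrows, ncols_pad) buffer with f64::MAX so padded slots
// sort to the end of each row in ascending order (f64::MIN for descending).
fn alloc_padding(dev: &CudaDevice) -> Result<()> {
    let (nrows, ncols_pad) = (4usize, 8usize);
    let _padding = dev.const_impl(f64::MAX, &Shape::from((nrows, ncols_pad)), DType::F32)?;
    Ok(())
}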
117 changes: 104 additions & 13 deletions candle-core/src/sort.rs
@@ -1,10 +1,12 @@
use crate::{Result, Tensor};
use crate::{DType, Result, Shape, Storage, Tensor};
use rayon::prelude::*;

#[derive(Debug, Clone, Copy)]
#[derive(Debug, Clone)]
struct ArgSort {
asc: bool,
last_dim: usize,
dtype: DType,
#[cfg(feature = "cuda")]
indices: Tensor,
}

impl ArgSort {
@@ -56,7 +58,8 @@ impl ArgSort {
mod cuda {
use super::*;
use crate::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
result::memcpy_dtod_sync, CudaSlice, DevicePtr, DeviceRepr, LaunchAsync, LaunchConfig,
ValidAsZeroBits,
};
use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, WrapErr};
use crate::{CudaDevice, WithDType};
@@ -74,22 +77,100 @@ mod cuda {
Some((o1, o2)) => src.slice(o1..o2),
};
let elem_count = layout.shape().elem_count();
let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
let func = if self.asc {
dev.get_or_load_func(&kernel_name::<T>("asort_asc"), kernels::SORT)?
} else {
dev.get_or_load_func(&kernel_name::<T>("asort_desc"), kernels::SORT)?
};
let ncols = self.last_dim;
let nrows = elem_count / ncols;
let (indices, _) = self.indices.storage_and_layout();

let indices = match &*indices {
Storage::Cuda(k) => k.as_cuda_slice::<u32>()?.to_owned(),
_ => crate::bail!("indices must be a cuda tensor"),
};
let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;

//each row must be padded to a power-of-two length for bitonic sort
let ncols_pad = next_power_of_2(ncols);
let params = (&slice, &dst, ncols as i32, ncols_pad as i32);
//alloc temp buffer holding the padded rows
let tmp_rows = dev.const_impl(
if self.asc {
std::f64::MAX
} else {
std::f64::MIN
},
&Shape::from((nrows, ncols_pad)),
self.dtype,
)?;
let tmp_indices = unsafe { dev.alloc::<u32>(ncols_pad) }.w()?;
// Determine the number of threads per block and blocks per row
let max_threads_per_block = 1024;
let threads_per_block = max_threads_per_block.min(ncols_pad);
let blocks_per_row = (ncols_pad + threads_per_block - 1) / threads_per_block;

let cfg = LaunchConfig {
grid_dim: (1, nrows as u32, 1),
block_dim: (ncols_pad as u32, 1, 1),
shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
grid_dim: (blocks_per_row as u32, 1, 1),
block_dim: (threads_per_block as u32, 1, 1),
shared_mem_bytes: (threads_per_block * std::mem::size_of::<u32>()) as u32,
};
unsafe { func.launch(cfg, params) }.w()?;

unsafe {
for row in 0..nrows {
let start_o = row * ncols;
let slice_row = slice.slice(start_o..);
let dst_row = dst.slice(start_o..);
let tmp_row_ptr = match &tmp_rows.slice {
S::U8(inp) => *inp.slice(start_o..).device_ptr(),
S::U32(inp) => *inp.slice(start_o..).device_ptr(),
S::I64(inp) => *inp.slice(start_o..).device_ptr(),
S::BF16(inp) => *inp.slice(start_o..).device_ptr(),
S::F16(inp) => *inp.slice(start_o..).device_ptr(),
S::F32(inp) => *inp.slice(start_o..).device_ptr(),
S::F64(inp) => *inp.slice(start_o..).device_ptr(),
};

memcpy_dtod_sync(
tmp_row_ptr,
*slice_row.device_ptr(),
ncols * std::mem::size_of::<T>(),
)
.w()?;
memcpy_dtod_sync(
*tmp_indices.device_ptr(),
*indices.device_ptr(),
ncols * std::mem::size_of::<u32>(),
)
.w()?;

let mut k = 2;
while k <= ncols_pad {
// Minor step
let mut j = k >> 1;
while j > 0 {
let params = (tmp_row_ptr, &tmp_indices, j as i32, k as i32);
func.clone().launch(cfg, params).w()?;
j = j >> 1;
}
k <<= 1;
}

//copy the valid (unpadded) elements back
memcpy_dtod_sync(
*slice_row.device_ptr(),
tmp_row_ptr,
ncols * std::mem::size_of::<T>(),
)
.w()?;
memcpy_dtod_sync(
*dst_row.device_ptr(),
*tmp_indices.device_ptr(),
ncols * std::mem::size_of::<u32>(),
)
.w()?;
}
}
Ok(S::U32(dst))
}
}
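For reference, the (k, j) loop driven from the host above follows the standard bitonic network: k doubles from 2 up to the padded row length and j halves from k/2 down to 1, with one kernel launch per (k, j) pass. The CPU sketch below (a hypothetical reference helper, not part of this PR) performs the same compare-exchange schedule and shows why each row must be padded to a power of two:

// CPU reference of the bitonic schedule the GPU kernel runs per row.
// `vals` must have a power-of-two length; `idx` tracks the argsort result.
fn bitonic_argsort_ref(vals: &mut [f32], idx: &mut [u32], ascending: bool) {
    assert_eq!(vals.len(), idx.len());
    let n = vals.len();
    assert!(n.is_power_of_two(), "rows must be padded to a power of two");
    let mut k = 2;
    while k <= n {
        let mut j = k >> 1;
        while j > 0 {
            for i in 0..n {
                let ij = i ^ j;
                if ij > i {
                    // The comparison direction alternates with the k-block i sits in.
                    let up = (i & k) == 0;
                    let should_swap = if up == ascending {
                        vals[i] > vals[ij]
                    } else {
                        vals[i] < vals[ij]
                    };
                    if should_swap {
                        vals.swap(i, ij);
                        idx.swap(i, ij);
                    }
                }
            }
            j >>= 1;
        }
        k <<= 1;
    }
}

When the padded slots are initialized to MAX (ascending) or MIN (descending) as in the host code above, the first ncols entries of idx end up holding the argsort of the original row.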
@@ -221,8 +302,18 @@ impl Tensor {
None => crate::bail!("empty last-dim in arg-sort"),
Some(last_dim) => *last_dim,
};
#[cfg(feature = "cuda")]
let indices_cpu = (0..last_dim).into_iter().map(|a| a as u32).collect();
#[cfg(feature = "cuda")]
let indices = Tensor::from_vec(indices_cpu, (1, last_dim), self.device())?;
// No need for a backward pass for arg sort.
self.apply_op1_no_bwd(&ArgSort { asc, last_dim })
self.apply_op1_no_bwd(&ArgSort {
asc,
last_dim,
dtype: self.dtype(),
#[cfg(feature = "cuda")]
indices,
})
}

/// Sorts the tensor along the last dimension, returns the sorted tensor together with the
@@ -237,8 +328,8 @@
op: "sort_last_dim",
});
}
let asort = self.arg_sort_last_dim(asc)?;
let sorted = self.gather(&asort, crate::D::Minus1)?;
let sorted = self.copy()?;
let asort = sorted.arg_sort_last_dim(asc)?;
Ok((sorted, asort))
}
}
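With this change, sort_last_dim first copies the tensor and then argsorts the copy, so on CUDA the bitonic kernel sorts the values in place while producing the index tensor, rather than going through a separate gather. A minimal usage sketch (CPU device shown here; with the cuda feature, Device::new_cuda(0) exercises the new path):

use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let device = Device::Cpu;
    let t = Tensor::new(&[[3f32, 1., 4., 1.], [5., 9., 2., 6.]], &device)?;
    // Returns the sorted values and the argsort indices along the last dimension.
    let (sorted, indices) = t.sort_last_dim(true)?;
    println!("{sorted}\n{indices}");
    Ok(())
}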
37 changes: 26 additions & 11 deletions candle-examples/examples/llama/main.rs
@@ -18,11 +18,11 @@ use clap::{Parser, ValueEnum};
use candle::{DType, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use hf_hub::{api::sync::Api, Repo, RepoType};
use std::io::Write;

use candle_transformers::models::llama as model;
use hf_hub::{api::sync::Api, Repo, RepoType};
use model::{Llama, LlamaConfig};
use std::io::Write;
use std::path::Path;

const EOS_TOKEN: &str = "</s>";
const DEFAULT_PROMPT: &str = "My favorite theorem is ";
@@ -120,6 +120,9 @@ struct Args {
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 128)]
repeat_last_n: usize,

#[arg(long)]
weight_path: Option<String>,
}

fn main() -> Result<()> {
@@ -173,8 +176,15 @@ fn main() -> Result<()> {
let revision = args.revision.unwrap_or("main".to_string());
let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));

let tokenizer_filename = api.get("tokenizer.json")?;
let config_filename = api.get("config.json")?;
let tokenizer_filename = match &args.weight_path {
Some(path) => Path::new(path).join("tokenizer.json"),
_ => api.get("tokenizer.json")?,
};
let config_filename = match &args.weight_path {
Some(path) => Path::new(path).join("config.json"),
_ => api.get("config.json")?,
};

let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let config = config.into_config(args.use_flash_attn);

@@ -187,9 +197,13 @@ fn main() -> Result<()> {
| Which::V31Instruct
| Which::V32_3b
| Which::V32_3bInstruct
| Which::Solar10_7B => {
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
}
| Which::Solar10_7B => match &args.weight_path {
Some(path) => candle_examples::hub_load_local_safetensors(
path,
"model.safetensors.index.json",
)?,
_ => candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?,
},
Which::SmolLM2_360M
| Which::SmolLM2_360MInstruct
| Which::SmolLM2_135M
@@ -198,9 +212,10 @@
| Which::SmolLM2_1BInstruct
| Which::V32_1b
| Which::V32_1bInstruct
| Which::TinyLlama1_1BChat => {
vec![api.get("model.safetensors")?]
}
| Which::TinyLlama1_1BChat => match &args.weight_path {
Some(path) => vec![Path::new(path).join("model.safetensors")],
_ => vec![api.get("model.safetensors")?],
},
};
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;

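The helper candle_examples::hub_load_local_safetensors used above is not shown in this diff. A plausible sketch of what it might do, assuming the sharded index file follows the usual weight_map layout mapping tensor names to shard file names (names, error type, and handling here are illustrative, not the PR's actual implementation):

use std::path::{Path, PathBuf};

// Hypothetical sketch of a local counterpart to hub_load_safetensors: read the
// index JSON from the given directory and resolve every shard to a local path.
pub fn hub_load_local_safetensors(path: &str, json_file: &str) -> anyhow::Result<Vec<PathBuf>> {
    let json: serde_json::Value =
        serde_json::from_slice(&std::fs::read(Path::new(path).join(json_file))?)?;
    let weight_map = json
        .get("weight_map")
        .and_then(|v| v.as_object())
        .ok_or_else(|| anyhow::anyhow!("no weight_map in {json_file}"))?;
    // Collect the distinct shard file names and join them onto the local directory.
    let mut shards: Vec<&str> = weight_map.values().filter_map(|v| v.as_str()).collect();
    shards.sort();
    shards.dedup();
    Ok(shards.into_iter().map(|f| Path::new(path).join(f)).collect())
}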
87 changes: 39 additions & 48 deletions candle-kernels/src/sort.cu
@@ -5,72 +5,63 @@
#include<stdint.h>

template<typename T>
static inline __device__ void ggml_cuda_swap(T & a, T & b) {
inline __device__ void swap(T & a, T & b) {
T tmp = a;
a = b;
b = tmp;
}

template<int order, typename T>
static __device__ void k_argsort(const T * x, uint32_t * dst, const int ncols, int ncols_pad) {
// bitonic sort
int col = threadIdx.x;
int row = blockIdx.y;
template <typename T>
__device__ void bitonicSortGPU(T* arr, uint32_t * dst, int j, int k, bool ascending) {
unsigned int i, ij;
i = threadIdx.x + blockDim.x * blockIdx.x;
ij = i ^ j;

if (col >= ncols_pad) {
return;
}

const T * x_row = x + row * ncols;
extern __shared__ int dst_row[];

// initialize indices
dst_row[col] = col;

__syncthreads();

for (int k = 2; k <= ncols_pad; k *= 2) {
for (int j = k / 2; j > 0; j /= 2) {
int ixj = col ^ j;
if (ixj > col) {
if ((col & k) == 0) {
if (dst_row[col] >= ncols ||
(dst_row[ixj] < ncols && (order == SORT_ORDER_ASC ?
x_row[dst_row[col]] > x_row[dst_row[ixj]] :
x_row[dst_row[col]] < x_row[dst_row[ixj]]))
) {
ggml_cuda_swap(dst_row[col], dst_row[ixj]);
}
} else {
if (dst_row[ixj] >= ncols ||
(dst_row[col] < ncols && (order == SORT_ORDER_ASC ?
x_row[dst_row[col]] < x_row[dst_row[ixj]] :
x_row[dst_row[col]] > x_row[dst_row[ixj]]))
) {
ggml_cuda_swap(dst_row[col], dst_row[ixj]);
}
if (ij > i) {
if ((i & k) == 0) {
// Sort in ascending order
if (ascending) {
if (arr[i] > arr[ij]) {
swap(arr[i], arr[ij]);
swap(dst[i], dst[ij]);
}
}
// Sort in descending order
else {
if (arr[i] < arr[ij]) {
swap(arr[i], arr[ij]);
swap(dst[i], dst[ij]);
}
}
} else {
// Sort in ascending order
if (ascending) {
if (arr[i] < arr[ij]) {
swap(arr[i], arr[ij]);
swap(dst[i], dst[ij]);
}
}
// Sort in descending order
else {
if (arr[i] > arr[ij]) {
swap(arr[i], arr[ij]);
swap(dst[i], dst[ij]);
}
}
__syncthreads();
}
}

// copy the result to dst without the padding
if (col < ncols) {
dst[row * ncols + col] = dst_row[col];
}
}

#define ASORT_OP(TYPENAME, RUST_NAME) \
extern "C" __global__ void asort_asc_##RUST_NAME( \
const TYPENAME * x, uint32_t * dst, const int ncols, int ncols_pad \
TYPENAME * x, uint32_t * dst, const int j, const int k \
) { \
k_argsort<SORT_ORDER_ASC>(x, dst, ncols, ncols_pad); \
bitonicSortGPU(x, dst, j, k, true);\
} \
extern "C" __global__ void asort_desc_##RUST_NAME( \
const TYPENAME * x, uint32_t * dst, const int ncols, int ncols_pad \
TYPENAME * x, uint32_t * dst, const int j, const int k \
) { \
k_argsort<SORT_ORDER_DESC>(x, dst, ncols, ncols_pad); \
bitonicSortGPU(x, dst, j, k, false);\
} \

#if __CUDA_ARCH__ >= 800