-
Notifications
You must be signed in to change notification settings - Fork 9.6k
/
ggml-cuda.cu
11708 lines (9540 loc) · 433 KB
/
ggml-cuda.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#include "ggml-cuda.h"
#include "ggml.h"
#include "ggml-backend-impl.h"
#if defined(GGML_USE_HIPBLAS)
#define GGML_COMMON_DECL_HIP
#define GGML_COMMON_IMPL_HIP
#else
#define GGML_COMMON_DECL_CUDA
#define GGML_COMMON_IMPL_CUDA
#endif
#include "ggml-common.h"
#include <algorithm>
#include <array>
#include <assert.h>
#include <atomic>
#include <cinttypes>
#include <cstddef>
#include <cstdint>
#include <float.h>
#include <limits>
#include <map>
#include <memory>
#include <mutex>
#include <stdint.h>
#include <stdio.h>
#include <string>
#include <vector>
// stringize macro for converting __CUDA_ARCH_LIST__ (list of integers) to string
// Two-level expansion: STRINGIZE first macro-expands its arguments (so e.g. __CUDA_ARCH_LIST__
// is replaced by its value), then STRINGIZE_IMPL applies # to the expanded token list,
// yielding a single string literal that includes the commas.
#define STRINGIZE_IMPL(...) #__VA_ARGS__
#define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__)
#if defined(GGML_USE_HIPBLAS)
#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
#include <hip/hip_fp16.h>
#ifdef __HIP_PLATFORM_AMD__
// for rocblas_initialize()
#include "rocblas/rocblas.h"
#endif // __HIP_PLATFORM_AMD__
#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
#define CUBLAS_OP_N HIPBLAS_OP_N
#define CUBLAS_OP_T HIPBLAS_OP_T
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define CUBLAS_TF32_TENSOR_OP_MATH 0
#define CUDA_R_16F HIPBLAS_R_16F
#define CUDA_R_32F HIPBLAS_R_32F
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
#define cublasCreate hipblasCreate
#define cublasDestroy hipblasDestroy
#define cublasGemmEx hipblasGemmEx
#define cublasGemmBatchedEx hipblasGemmBatchedEx
#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
#define cublasHandle_t hipblasHandle_t
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
#define cublasSetStream hipblasSetStream
#define cublasSgemm hipblasSgemm
#define cublasStatus_t hipblasStatus_t
#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
#define cudaDeviceProp hipDeviceProp_t
#define cudaDeviceSynchronize hipDeviceSynchronize
#define cudaError_t hipError_t
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
#define cudaEventCreateWithFlags hipEventCreateWithFlags
#define cudaEventDisableTiming hipEventDisableTiming
#define cudaEventRecord hipEventRecord
#define cudaEventSynchronize hipEventSynchronize
#define cudaEvent_t hipEvent_t
#define cudaEventDestroy hipEventDestroy
#define cudaFree hipFree
#define cudaFreeHost hipHostFree
#define cudaGetDevice hipGetDevice
#define cudaGetDeviceCount hipGetDeviceCount
#define cudaGetDeviceProperties hipGetDeviceProperties
#define cudaGetErrorString hipGetErrorString
#define cudaGetLastError hipGetLastError
#define cudaHostRegister hipHostRegister
#define cudaHostRegisterPortable hipHostRegisterPortable
#define cudaHostRegisterReadOnly hipHostRegisterReadOnly
#define cudaHostUnregister hipHostUnregister
#define cudaLaunchHostFunc hipLaunchHostFunc
#ifdef GGML_HIP_UMA
#define cudaMalloc hipMallocManaged
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size)
#else
#define cudaMalloc hipMalloc
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
#endif
#define cudaMemcpy hipMemcpy
#define cudaMemcpyAsync hipMemcpyAsync
#define cudaMemcpyPeerAsync hipMemcpyPeerAsync
#define cudaMemcpy2DAsync hipMemcpy2DAsync
#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
#define cudaMemcpyKind hipMemcpyKind
#define cudaMemset hipMemset
#define cudaMemsetAsync hipMemsetAsync
#define cudaMemGetInfo hipMemGetInfo
#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
#define cudaSetDevice hipSetDevice
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
#define cudaStreamDestroy hipStreamDestroy
#define cudaStreamFireAndForget hipStreamFireAndForget
#define cudaStreamNonBlocking hipStreamNonBlocking
#define cudaStreamPerThread hipStreamPerThread
#define cudaStreamSynchronize hipStreamSynchronize
#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
#define cudaStream_t hipStream_t
#define cudaSuccess hipSuccess
#define __trap abort
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE
#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH
#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR
#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
#else
#include <cuda_runtime.h>
#include <cuda.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#if CUDART_VERSION < 11020
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
#define CUBLAS_COMPUTE_16F CUDA_R_16F
#define CUBLAS_COMPUTE_32F CUDA_R_32F
#define cublasComputeType_t cudaDataType_t
#endif // CUDART_VERSION < 11020
#endif // defined(GGML_USE_HIPBLAS)
// Minimum CUDART_VERSION for using __hmax / __hmax2.
#define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
// Compute capabilities below use the encoding 100*major + 10*minor (as built in ggml_cuda_init).
#define CC_PASCAL 600
#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
#define CC_VOLTA 700
// AMD devices have this offset added to their compute capability, which places them in a
// numeric range separate from NVIDIA values.
#define CC_OFFSET_AMD 1000000
#define CC_RDNA1 (CC_OFFSET_AMD + 1010)
#define CC_RDNA2 (CC_OFFSET_AMD + 1030)
#define CC_RDNA3 (CC_OFFSET_AMD + 1100)
// define this if you want to always fallback to MMQ kernels and not use cuBLAS for matrix multiplication
// on modern hardware, using cuBLAS is recommended as it utilizes F16 tensor cores which are very performant
// for large computational tasks. the drawback is that this requires some extra amount of VRAM:
// - 7B quantized model: +100-200 MB
// - 13B quantized model: +200-400 MB
//
//#define GGML_CUDA_FORCE_MMQ
// TODO: improve this to be correct for more hardware
// for example, currently fails for GeForce GTX 1660 which is TURING arch (> VOLTA) but does not have tensor cores
#if !defined(GGML_CUDA_FORCE_MMQ)
#define CUDA_USE_TENSOR_CORES
#endif
#define MMVQ_MAX_BATCH_SIZE 8 // max batch size to use MMVQ kernels
#define MMQ_MAX_BATCH_SIZE 32 // max batch size to use MMQ kernels when tensor cores are available
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
// The code relies on half and ggml_fp16_t having the same size.
static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
// Print a CUDA/cuBLAS/driver failure (message, current device, call site, failing statement)
// to stderr and abort through GGML_ASSERT so that a stack trace is produced. Never returns.
[[noreturn]]
static void ggml_cuda_error(const char * stmt, const char * func, const char * file, const int line, const char * msg) {
    // Sentinel printed if the device query itself fails; its status is deliberately ignored.
    int current_id = -1;
    cudaGetDevice(&current_id);

    fprintf(stderr, "CUDA error: %s\n", msg);
    fprintf(stderr, " current device: %d, in function %s at %s:%d\n", current_id, func, file, line);
    fprintf(stderr, " %s\n", stmt);

    // abort with GGML_ASSERT to get a stack trace
    GGML_ASSERT(!"CUDA error");
}
// Generic status check: evaluates `err` exactly once; if the result differs from `success`,
// reports it via ggml_cuda_error (which does not return), using `error_fn` to convert the
// status into a message. The do/while(0) wrapper makes the macro a single statement.
#define CUDA_CHECK_GEN(err, success, error_fn) \
do { \
auto err_ = (err); \
if (err_ != (success)) { \
ggml_cuda_error(#err, __func__, __FILE__, __LINE__, error_fn(err_)); \
} \
} while (0)
// Check a CUDA runtime API call returning cudaError_t.
#define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)
#if CUDART_VERSION >= 12000
// Human-readable name of a cuBLAS status code (provided by cuBLAS itself on CUDA 12+).
static const char * cublas_get_error_str(const cublasStatus_t err) {
    return cublasGetStatusString(err);
}
#else
// Human-readable name of a cuBLAS status code. Older toolkits (and HIP) lack
// cublasGetStatusString, so the known codes are mapped by hand; anything else is "unknown error".
static const char * cublas_get_error_str(const cublasStatus_t err) {
    struct status_name {
        cublasStatus_t status;
        const char *   name;
    };
    static const status_name known_statuses[] = {
        { CUBLAS_STATUS_SUCCESS,          "CUBLAS_STATUS_SUCCESS"          },
        { CUBLAS_STATUS_NOT_INITIALIZED,  "CUBLAS_STATUS_NOT_INITIALIZED"  },
        { CUBLAS_STATUS_ALLOC_FAILED,     "CUBLAS_STATUS_ALLOC_FAILED"     },
        { CUBLAS_STATUS_INVALID_VALUE,    "CUBLAS_STATUS_INVALID_VALUE"    },
        { CUBLAS_STATUS_ARCH_MISMATCH,    "CUBLAS_STATUS_ARCH_MISMATCH"    },
        { CUBLAS_STATUS_MAPPING_ERROR,    "CUBLAS_STATUS_MAPPING_ERROR"    },
        { CUBLAS_STATUS_EXECUTION_FAILED, "CUBLAS_STATUS_EXECUTION_FAILED" },
        { CUBLAS_STATUS_INTERNAL_ERROR,   "CUBLAS_STATUS_INTERNAL_ERROR"   },
        { CUBLAS_STATUS_NOT_SUPPORTED,    "CUBLAS_STATUS_NOT_SUPPORTED"    },
    };
    for (const status_name & entry : known_statuses) {
        if (entry.status == err) {
            return entry.name;
        }
    }
    return "unknown error";
}
#endif // CUDART_VERSION >= 12000
// Check a cuBLAS call returning cublasStatus_t.
#define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str)
#if !defined(GGML_USE_HIPBLAS)
// Human-readable message for a CUDA driver API status (CUresult).
// cuGetErrorString fails for codes it does not recognize and then provides no valid string;
// fall back to a fixed message so that ggml_cuda_error, which prints the result with "%s",
// never receives a null or uninitialized pointer.
static const char * cu_get_error_str(CUresult err) {
    const char * err_str = nullptr;
    if (cuGetErrorString(err, &err_str) != CUDA_SUCCESS || err_str == nullptr) {
        return "unknown error";
    }
    return err_str;
}
// Check a CUDA driver API call returning CUresult.
#define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
#endif
// GGML_CUDA_ASSUME(x): optimizer hint that x is true, implemented with __builtin_assume
// when CUDART_VERSION >= 11100; otherwise it expands to nothing (x is not evaluated).
#if CUDART_VERSION >= 11100
#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
#else
#define GGML_CUDA_ASSUME(x)
#endif // CUDART_VERSION >= 11100
// Upper bound on the number of streams kept per device (see ggml_backend_cuda_context::streams).
#define GGML_CUDA_MAX_STREAMS 8
// Per-tensor GPU bookkeeping for multi-device use.
// NOTE(review): not referenced elsewhere in this chunk; presumably attached via ggml_tensor::extra — confirm.
struct ggml_tensor_extra_gpu {
void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
cudaEvent_t events[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS]; // events for synchronizing multiple GPUs
};
// Make `device` the current CUDA device of the calling host thread (aborts on error).
// cudaSetDevice is only called when the device actually changes: this is faster on Windows,
// probably because the Windows CUDA libraries forget to make this check before invoking the drivers.
static void ggml_cuda_set_device(const int device) {
    int current_device;
    CUDA_CHECK(cudaGetDevice(&current_device));
    if (device == current_device) {
        return;
    }
    CUDA_CHECK(cudaSetDevice(device));
}
// Return the current CUDA device of the calling host thread (aborts on error).
static int ggml_cuda_get_device() {
    int device_id = -1;
    CUDA_CHECK(cudaGetDevice(&device_id));
    return device_id;
}
// Process-wide description of the visible CUDA/HIP devices, filled once by ggml_cuda_init.
struct ggml_cuda_device_info {
// number of visible devices (0 if the runtime could not be initialized)
int device_count;
struct cuda_device_info {
int cc; // compute capability
size_t smpb; // max. shared memory per block
bool vmm; // virtual memory support
size_t vmm_granularity; // granularity of virtual memory
// total device memory in bytes; NOTE(review): verify that ggml_cuda_init populates this field
size_t total_vram;
};
// compute capability is encoded as 100*major + 10*minor (+ CC_OFFSET_AMD on AMD)
cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {};
// entry id = fraction of the summed device memory held by devices 0..id-1, i.e. the start of
// device id's share of rows when splitting tensors (see get_row_split)
std::array<float, GGML_CUDA_MAX_DEVICES> default_tensor_split = {};
};
// Query all visible devices once and build the device table.
// On failure to query the device count, logs the error and returns a table with device_count == 0.
// For each device it records compute capability, shared memory per block, total VRAM and VMM
// support (with the recommended VMM granularity on CUDA). It also fills the default tensor
// split, where entry id is the fraction of total VRAM held by devices 0..id-1.
static ggml_cuda_device_info ggml_cuda_init() {
#ifdef __HIP_PLATFORM_AMD__
    // Workaround for a rocBLAS bug when using multiple graphics cards:
    // https://github.com/ROCmSoftwarePlatform/rocBLAS/issues/1346
    rocblas_initialize();
    CUDA_CHECK(cudaDeviceSynchronize());
#endif

    ggml_cuda_device_info info = {};

    cudaError_t err = cudaGetDeviceCount(&info.device_count);
    if (err != cudaSuccess) {
        fprintf(stderr, "%s: failed to initialize " GGML_CUDA_NAME ": %s\n", __func__, cudaGetErrorString(err));
        // do not rely on whatever the failed call may have written
        info.device_count = 0;
        return info;
    }

    GGML_ASSERT(info.device_count <= GGML_CUDA_MAX_DEVICES);

    int64_t total_vram = 0;
#if defined(GGML_CUDA_FORCE_MMQ)
    fprintf(stderr, "%s: GGML_CUDA_FORCE_MMQ: yes\n", __func__);
#else
    fprintf(stderr, "%s: GGML_CUDA_FORCE_MMQ: no\n", __func__);
#endif
#if defined(CUDA_USE_TENSOR_CORES)
    fprintf(stderr, "%s: CUDA_USE_TENSOR_CORES: yes\n", __func__);
#else
    fprintf(stderr, "%s: CUDA_USE_TENSOR_CORES: no\n", __func__);
#endif
    fprintf(stderr, "%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count);
    for (int id = 0; id < info.device_count; ++id) {
        int device_vmm = 0;

#if !defined(GGML_USE_HIPBLAS)
        CUdevice device;
        CU_CHECK(cuDeviceGet(&device, id));
        CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device));

        if (device_vmm) {
            CUmemAllocationProp alloc_prop = {};
            alloc_prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
            alloc_prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
            alloc_prop.location.id = id;
            CU_CHECK(cuMemGetAllocationGranularity(&info.devices[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
        }
#endif // !defined(GGML_USE_HIPBLAS)
        info.devices[id].vmm = !!device_vmm;

        cudaDeviceProp prop;
        CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
        fprintf(stderr, " Device %d: %s, compute capability %d.%d, VMM: %s\n", id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");

        // cumulative VRAM of the preceding devices; normalized to a fraction after the loop
        info.default_tensor_split[id] = total_vram;
        total_vram += prop.totalGlobalMem;
        info.devices[id].total_vram = prop.totalGlobalMem;

#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
        info.devices[id].cc = 100*prop.major + 10*prop.minor + CC_OFFSET_AMD;
#else
        info.devices[id].cc = 100*prop.major + 10*prop.minor;
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
        info.devices[id].smpb = prop.sharedMemPerBlock;
    }

    for (int id = 0; id < info.device_count; ++id) {
        info.default_tensor_split[id] /= total_vram;
    }

    // configure logging to stdout
    // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));

    return info;
}
// Lazily built, process-wide device table. The function-local static is initialized exactly
// once (C++11 guarantees this even for concurrent first calls), so ggml_cuda_init runs once.
static const ggml_cuda_device_info & get_cuda_global_info() {
    static ggml_cuda_device_info global_info = ggml_cuda_init();
    return global_info;
}
// #define DEBUG_CUDA_MALLOC
// Abstract device-memory pool. Implementations in this file: ggml_cuda_pool_leg (legacy
// cudaMalloc buffer cache) and ggml_cuda_pool_vmm (virtual-memory based pool).
struct ggml_cuda_pool {
virtual ~ggml_cuda_pool() = default;
// Return device memory for at least `size` bytes; the size actually reserved is written
// to *actual_size and is what should later be passed to free().
virtual void * alloc(size_t size, size_t * actual_size) = 0;
// Give back memory obtained from alloc(); `size` is the *actual_size reported by alloc.
virtual void free(void * ptr, size_t size) = 0;
};
// Legacy device-memory pool built on plain cudaMalloc/cudaFree.
// Freed buffers are cached (up to MAX_BUFFERS) instead of being released. alloc() reuses the
// cached buffer with the smallest sufficient size and calls cudaMalloc only on a miss,
// over-allocating by ~5% (rounded up to 256 bytes) to improve later reuse.
// pool_size counts all bytes obtained from cudaMalloc and not yet cudaFree'd, whether they
// are currently cached or handed out.
struct ggml_cuda_pool_leg : public ggml_cuda_pool {
    static const int MAX_BUFFERS = 256;

    int device;
    struct ggml_cuda_buffer {
        void * ptr = nullptr;
        size_t size = 0;
    };

    ggml_cuda_buffer buffer_pool[MAX_BUFFERS] = {};
    size_t pool_size = 0;

    explicit ggml_cuda_pool_leg(int device) :
        device(device) {
    }

    // Free every cached buffer. Memory still handed out is not in buffer_pool, so if any
    // allocation outlives the pool, pool_size stays non-zero and the assertion fires.
    ~ggml_cuda_pool_leg() {
        ggml_cuda_set_device(device);
        for (int i = 0; i < MAX_BUFFERS; ++i) {
            ggml_cuda_buffer & b = buffer_pool[i];
            if (b.ptr != nullptr) {
                CUDA_CHECK(cudaFree(b.ptr));
                pool_size -= b.size;
            }
        }
        GGML_ASSERT(pool_size == 0);
    }

    void * alloc(size_t size, size_t * actual_size) override {
#ifdef DEBUG_CUDA_MALLOC
        int nnz = 0;
        size_t max_size = 0;
#endif
        // Best-fit search over the cached buffers. Buffers exceeding the request by 2^36
        // bytes or more are never chosen. An exact fit is returned immediately.
        size_t best_diff = 1ull << 36;
        int ibest = -1;
        for (int i = 0; i < MAX_BUFFERS; ++i) {
            ggml_cuda_buffer& b = buffer_pool[i];
            if (b.ptr != nullptr) {
#ifdef DEBUG_CUDA_MALLOC
                ++nnz;
                if (b.size > max_size) max_size = b.size;
#endif
                if (b.size >= size) {
                    size_t diff = b.size - size;
                    if (diff < best_diff) {
                        best_diff = diff;
                        ibest = i;
                        if (!best_diff) {
                            void * ptr = b.ptr;
                            *actual_size = b.size;
                            b.ptr = nullptr;
                            b.size = 0;
                            return ptr;
                        }
                    }
                }
            }
        }
        if (ibest >= 0) {
            ggml_cuda_buffer& b = buffer_pool[ibest];
            void * ptr = b.ptr;
            *actual_size = b.size;
            b.ptr = nullptr;
            b.size = 0;
            return ptr;
        }

        // cache miss: allocate a new buffer ~5% larger than requested, rounded up to 256 bytes
        void * ptr;
        size_t look_ahead_size = (size_t) (1.05 * size);
        look_ahead_size = 256 * ((look_ahead_size + 255)/256);
        ggml_cuda_set_device(device);
        CUDA_CHECK(cudaMalloc((void **) &ptr, look_ahead_size));
        *actual_size = look_ahead_size;
        pool_size += look_ahead_size;
#ifdef DEBUG_CUDA_MALLOC
        fprintf(stderr, "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, device, nnz,
                (uint32_t)(max_size/1024/1024), (uint32_t)(pool_size/1024/1024), (uint32_t)(size/1024/1024));
#endif
        return ptr;
    }

    // Return a buffer to the cache; if all MAX_BUFFERS slots are taken, release it to the driver.
    void free(void * ptr, size_t size) override {
        for (int i = 0; i < MAX_BUFFERS; ++i) {
            ggml_cuda_buffer& b = buffer_pool[i];
            if (b.ptr == nullptr) {
                b.ptr = ptr;
                b.size = size;
                return;
            }
        }
        // the message names the actual constant that limits the cache size
        fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_BUFFERS\n");
        ggml_cuda_set_device(device);
        CUDA_CHECK(cudaFree(ptr));
        pool_size -= size;
    }
};
// pool with virtual memory
#if !defined(GGML_USE_HIPBLAS)
// Device-memory pool based on the CUDA driver's virtual memory management (VMM) API.
// A CUDA_POOL_VMM_MAX_SIZE range of virtual addresses is reserved on first use. Physical
// memory is then created and mapped at the end of the pool, in multiples of the device's VMM
// granularity, as demand grows. Allocation is a bump pointer (pool_used), so frees must occur
// in exact reverse order of allocations; this is checked by an assertion in free().
struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
    static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB

    int device;
    CUdeviceptr pool_addr = 0;
    size_t pool_used = 0;
    size_t pool_size = 0;
    size_t granularity;

    explicit ggml_cuda_pool_vmm(int device) :
        device(device),
        granularity(get_cuda_global_info().devices[device].vmm_granularity) {
    }

    // Unmap all mapped physical memory and release the reserved address range.
    ~ggml_cuda_pool_vmm() {
        if (pool_addr != 0) {
            CU_CHECK(cuMemUnmap(pool_addr, pool_size));
            CU_CHECK(cuMemAddressFree(pool_addr, CUDA_POOL_VMM_MAX_SIZE));
        }
    }

    void * alloc(size_t size, size_t * actual_size) override {
        // round up the allocation size to the alignment to ensure that all allocations are aligned for all data types
        const size_t alignment = 128;
        size = alignment * ((size + alignment - 1) / alignment);

        size_t avail = pool_size - pool_used;

        if (size > avail) {
            // round up to the next multiple of the granularity
            size_t reserve_size = size - avail;
            reserve_size = granularity * ((reserve_size + granularity - 1) / granularity);

            GGML_ASSERT(pool_size + reserve_size <= CUDA_POOL_VMM_MAX_SIZE);

            // allocate more physical memory
            CUmemAllocationProp prop = {};
            prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
            prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
            prop.location.id = device;
            CUmemGenericAllocationHandle handle;
            CU_CHECK(cuMemCreate(&handle, reserve_size, &prop, 0));

            // reserve virtual address space (if not already reserved)
            if (pool_addr == 0) {
                CU_CHECK(cuMemAddressReserve(&pool_addr, CUDA_POOL_VMM_MAX_SIZE, 0, 0, 0));
            }

            // map at the end of the pool
            CU_CHECK(cuMemMap(pool_addr + pool_size, reserve_size, 0, handle, 0));

            // the memory allocation handle is no longer needed after mapping
            CU_CHECK(cuMemRelease(handle));

            // set access
            CUmemAccessDesc access = {};
            access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
            access.location.id = device;
            access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
            CU_CHECK(cuMemSetAccess(pool_addr + pool_size, reserve_size, &access, 1));

            // add to the pool
            pool_size += reserve_size;

            //printf("cuda pool[%d]: size increased to %llu MB (reserved %llu MB)\n",
            //       device, (unsigned long long) (pool_size/1024/1024),
            //       (unsigned long long) (reserve_size/1024/1024));
        }

        GGML_ASSERT(pool_addr != 0);

        void * ptr = (void *) (pool_addr + pool_used);
        *actual_size = size;
        pool_used += size;

#ifdef DEBUG_CUDA_MALLOC
        printf("cuda pool[%d]: allocated %llu bytes at %p\n", device, (unsigned long long) size, ptr);
#endif

        return ptr;
    }

    void free(void * ptr, size_t size) override {
#ifdef DEBUG_CUDA_MALLOC
        printf("cuda pool[%d]: freed %llu bytes at %p\n", device, (unsigned long long) size, ptr);
#endif

        pool_used -= size;

        // all deallocations must be in reverse order of the allocations
        GGML_ASSERT(ptr == (void *) (pool_addr + pool_used));
    }
};
#endif // !defined(GGML_USE_HIPBLAS)
// Scoped handle for one typed allocation from a ggml_cuda_pool. When the handle is destroyed,
// the memory goes back to the pool together with the size the pool reported.
// Copying and moving are disabled, so exactly one handle owns each allocation.
template<typename T>
struct ggml_cuda_pool_alloc {
    ggml_cuda_pool * pool = nullptr;
    T * ptr = nullptr;
    size_t actual_size = 0;

    ggml_cuda_pool_alloc() = default;

    explicit ggml_cuda_pool_alloc(ggml_cuda_pool & pool) : pool(&pool) {
    }

    ggml_cuda_pool_alloc(ggml_cuda_pool & pool, size_t size) : pool(&pool) {
        alloc(size);
    }

    ~ggml_cuda_pool_alloc() {
        if (ptr == nullptr) {
            return;
        }
        pool->free(ptr, actual_size);
    }

    // Allocate room for `size` elements of T (not bytes). Requires a bound pool and no
    // outstanding allocation on this handle.
    T * alloc(size_t size) {
        GGML_ASSERT(pool != nullptr);
        GGML_ASSERT(ptr == nullptr);
        void * raw = pool->alloc(size * sizeof(T), &actual_size);
        ptr = static_cast<T *>(raw);
        return ptr;
    }

    // Bind this handle to `pool`, then allocate `size` elements from it.
    T * alloc(ggml_cuda_pool & pool, size_t size) {
        this->pool = &pool;
        return alloc(size);
    }

    T * get() {
        return ptr;
    }

    ggml_cuda_pool_alloc(const ggml_cuda_pool_alloc &) = delete;
    ggml_cuda_pool_alloc(ggml_cuda_pool_alloc &&) = delete;
    ggml_cuda_pool_alloc& operator=(const ggml_cuda_pool_alloc &) = delete;
    ggml_cuda_pool_alloc& operator=(ggml_cuda_pool_alloc &&) = delete;
};
// backend interface
// Per-backend CUDA state. `device` is the backend's main device. Streams, cuBLAS handles and
// memory pools are kept per device index and created lazily on first request.
struct ggml_backend_cuda_context {
int device;
std::string name;
// destroyed in the destructor if it has been created
cudaEvent_t copy_event = nullptr;
// lazily created streams, indexed [device][stream]
cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
// lazily created cuBLAS handles, one per device
cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
explicit ggml_backend_cuda_context(int device) :
device(device),
name(GGML_CUDA_NAME + std::to_string(device)) {
}
// Destroys the event and every stream / cuBLAS handle that was created. The pools (declared
// further below) are released after this body runs, by their unique_ptr destructors.
~ggml_backend_cuda_context() {
if (copy_event != nullptr) {
CUDA_CHECK(cudaEventDestroy(copy_event));
}
for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
if (streams[i][j] != nullptr) {
CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
}
}
if (cublas_handles[i] != nullptr) {
CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
}
}
}
// Return stream `stream` of `device`, creating it with cudaStreamNonBlocking on first use.
cudaStream_t stream(int device, int stream) {
if (streams[device][stream] == nullptr) {
ggml_cuda_set_device(device);
CUDA_CHECK(cudaStreamCreateWithFlags(&streams[device][stream], cudaStreamNonBlocking));
}
return streams[device][stream];
}
// Stream 0 of the backend's main device.
cudaStream_t stream() {
return stream(device, 0);
}
// Return the cuBLAS handle of `device`, creating it on first use with math mode
// CUBLAS_TF32_TENSOR_OP_MATH (a no-op macro under HIP).
cublasHandle_t cublas_handle(int device) {
if (cublas_handles[device] == nullptr) {
ggml_cuda_set_device(device);
CUBLAS_CHECK(cublasCreate(&cublas_handles[device]));
CUBLAS_CHECK(cublasSetMathMode(cublas_handles[device], CUBLAS_TF32_TENSOR_OP_MATH));
}
return cublas_handles[device];
}
// cuBLAS handle of the backend's main device.
cublasHandle_t cublas_handle() {
return cublas_handle(device);
}
// pool
// one memory pool per device, created on first use
std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES];
// Use the VMM pool when the device supports virtual memory (CUDA builds only); otherwise
// fall back to the legacy buffer-caching pool.
static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device) {
#if !defined(GGML_USE_HIPBLAS)
if (get_cuda_global_info().devices[device].vmm) {
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device));
}
#endif
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
}
// Pool of `device`, created on first use.
ggml_cuda_pool & pool(int device) {
if (pools[device] == nullptr) {
pools[device] = new_pool_for_device(device);
}
return *pools[device];
}
// Pool of the backend's main device.
ggml_cuda_pool & pool() {
return pool(device);
}
};
// cuda buffer
// Context of a CUDA backend buffer. It owns one device allocation, which is released with
// cudaFree when the context is deleted, and a display name GGML_CUDA_NAME + "<device>".
struct ggml_backend_cuda_buffer_context {
    int device;
    void * dev_ptr = nullptr;
    std::string name;

    ggml_backend_cuda_buffer_context(int device, void * dev_ptr)
        : device(device)
        , dev_ptr(dev_ptr)
        , name(GGML_CUDA_NAME + std::to_string(device)) {
    }

    ~ggml_backend_cuda_buffer_context() {
        CUDA_CHECK(cudaFree(dev_ptr));
    }
};
// Display name of a CUDA buffer (stored in its context).
GGML_CALL static const char * ggml_backend_cuda_buffer_get_name(ggml_backend_buffer_t buffer) {
    const auto * buf_ctx = static_cast<const ggml_backend_cuda_buffer_context *>(buffer->context);
    return buf_ctx->name.c_str();
}
// A buffer is treated as a CUDA buffer when its interface uses this file's get_name callback.
GGML_CALL static bool ggml_backend_buffer_is_cuda(ggml_backend_buffer_t buffer) {
    const bool uses_cuda_iface = buffer->iface.get_name == ggml_backend_cuda_buffer_get_name;
    return uses_cuda_iface;
}
// Release a CUDA buffer by deleting its context (whose destructor frees the device memory).
GGML_CALL static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) {
    delete static_cast<ggml_backend_cuda_buffer_context *>(buffer->context);
}
// Base device address of a CUDA buffer.
GGML_CALL static void * ggml_backend_cuda_buffer_get_base(ggml_backend_buffer_t buffer) {
    return static_cast<ggml_backend_cuda_buffer_context *>(buffer->context)->dev_ptr;
}
// Prepare a tensor placed in a CUDA buffer.
// - A view starting at offset 0 of its source inherits the source's backend and extra.
// - A non-view quantized tensor whose allocation is larger than ggml_nbytes (padding) gets
//   the padding bytes zeroed.
GGML_CALL static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
    ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;

    const bool is_view = tensor->view_src != NULL;
    if (is_view && tensor->view_offs == 0) {
        assert(tensor->view_src->buffer->buft == buffer->buft);
        tensor->backend = tensor->view_src->backend;
        tensor->extra   = tensor->view_src->extra;
        return;
    }

    if (!ggml_is_quantized(tensor->type)) {
        return;
    }

    // initialize padding to 0 to avoid possible NaN values
    const size_t original_size = ggml_nbytes(tensor);
    const size_t padded_size   = ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
    if (padded_size > original_size && tensor->view_src == nullptr) {
        ggml_cuda_set_device(ctx->device);
        CUDA_CHECK(cudaMemset((char *)tensor->data + original_size, 0, padded_size - original_size));
    }
}
// Copy `size` host bytes from `data` to tensor->data + offset; returns after the copy completes.
GGML_CALL static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
    const auto * buf_ctx = static_cast<const ggml_backend_cuda_buffer_context *>(buffer->context);
    char * dst = (char *)tensor->data + offset;

    ggml_cuda_set_device(buf_ctx->device);
    CUDA_CHECK(cudaMemcpyAsync(dst, data, size, cudaMemcpyHostToDevice, cudaStreamPerThread));
    // wait so that the caller may reuse `data` immediately
    CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
}
// Copy `size` bytes from tensor->data + offset to host memory `data`; returns after the copy completes.
GGML_CALL static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
    const auto * buf_ctx = static_cast<const ggml_backend_cuda_buffer_context *>(buffer->context);
    const char * src = (const char *)tensor->data + offset;

    ggml_cuda_set_device(buf_ctx->device);
    CUDA_CHECK(cudaMemcpyAsync(data, src, size, cudaMemcpyDeviceToHost, cudaStreamPerThread));
    // the host buffer is only valid to read once the copy has finished
    CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
}
// Copy src into dst when src also lives in a CUDA buffer. Same-device copies use a
// device-to-device memcpy; cross-device copies use a peer copy unless GGML_CUDA_NO_PEER_COPY
// is defined, in which case false is returned. Also returns false when src is not in a CUDA
// buffer, so the caller can fall back to another path. The copy completes before returning.
// NOTE(review): dst->buffer is assumed to be a CUDA buffer (this callback belongs to the CUDA
// buffer interface) — confirm for other buffer kinds.
GGML_CALL static bool ggml_backend_cuda_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
    if (!ggml_backend_buffer_is_cuda(src->buffer)) {
        return false;
    }

    ggml_backend_cuda_buffer_context * src_ctx = (ggml_backend_cuda_buffer_context *)src->buffer->context;
    ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *)dst->buffer->context;
    const size_t nbytes = ggml_nbytes(src);

    // cudaStreamPerThread is the per-thread stream of the *current* device, so select the
    // source device first (set_tensor/get_tensor likewise select their buffer's device).
    if (src_ctx->device == dst_ctx->device) {
        ggml_cuda_set_device(src_ctx->device);
        CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, nbytes, cudaMemcpyDeviceToDevice, cudaStreamPerThread));
    } else {
#ifdef GGML_CUDA_NO_PEER_COPY
        return false;
#else
        ggml_cuda_set_device(src_ctx->device);
        CUDA_CHECK(cudaMemcpyPeerAsync(dst->data, dst_ctx->device, src->data, src_ctx->device, nbytes, cudaStreamPerThread));
#endif
    }
    CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
    return true;

    GGML_UNUSED(buffer);
}
// Fill the whole buffer with `value`. The device is synchronized before and after the fill.
GGML_CALL static void ggml_backend_cuda_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
    auto * buf_ctx = static_cast<ggml_backend_cuda_buffer_context *>(buffer->context);
    ggml_cuda_set_device(buf_ctx->device);

    CUDA_CHECK(cudaDeviceSynchronize());
    CUDA_CHECK(cudaMemset(buf_ctx->dev_ptr, value, buffer->size));
    CUDA_CHECK(cudaDeviceSynchronize());
}
// Callback table for CUDA buffers. Its get_name entry is also what ggml_backend_buffer_is_cuda
// compares against to recognize CUDA buffers. No reset callback is provided.
static ggml_backend_buffer_i ggml_backend_cuda_buffer_interface = {
/* .get_name = */ ggml_backend_cuda_buffer_get_name,
/* .free_buffer = */ ggml_backend_cuda_buffer_free_buffer,
/* .get_base = */ ggml_backend_cuda_buffer_get_base,
/* .init_tensor = */ ggml_backend_cuda_buffer_init_tensor,
/* .set_tensor = */ ggml_backend_cuda_buffer_set_tensor,
/* .get_tensor = */ ggml_backend_cuda_buffer_get_tensor,
/* .cpy_tensor = */ ggml_backend_cuda_buffer_cpy_tensor,
/* .clear = */ ggml_backend_cuda_buffer_clear,
/* .reset = */ NULL,
};
// cuda buffer type
// Context of a CUDA buffer type: the device it allocates on and its display name.
struct ggml_backend_cuda_buffer_type_context {
int device;
std::string name;
};
// Display name of a CUDA buffer type (stored in its context).
GGML_CALL static const char * ggml_backend_cuda_buffer_type_name(ggml_backend_buffer_type_t buft) {
    const auto * buft_ctx = static_cast<const ggml_backend_cuda_buffer_type_context *>(buft->context);
    return buft_ctx->name.c_str();
}
// Allocate a device buffer of at least one byte on the buffer type's device.
// On failure: logs the error, clears the CUDA error state and returns nullptr.
GGML_CALL static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
    ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *)buft->context;

    ggml_cuda_set_device(buft_ctx->device);

    size = std::max(size, (size_t)1); // cudaMalloc returns null for size 0

    void * dev_ptr;
    cudaError_t err = cudaMalloc(&dev_ptr, size);
    if (err != cudaSuccess) {
        // The failure is also recorded as the thread's last error. Clear it so that a later,
        // unrelated CUDA_CHECK(cudaGetLastError()) does not report this already-handled failure.
        (void) cudaGetLastError();
        fprintf(stderr, "%s: allocating %.2f MiB on device %d: cudaMalloc failed: %s\n", __func__, size/1024.0/1024.0, buft_ctx->device, cudaGetErrorString(err));
        return nullptr;
    }

    ggml_backend_cuda_buffer_context * ctx = new ggml_backend_cuda_buffer_context(buft_ctx->device, dev_ptr);

    return ggml_backend_buffer_init(buft, ggml_backend_cuda_buffer_interface, ctx, size);
}
// Alignment in bytes reported for tensors placed in CUDA buffers.
GGML_CALL static size_t ggml_backend_cuda_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
    GGML_UNUSED(buft);
    const size_t alignment = 128;
    return alignment;
}
// Bytes to reserve for `tensor`: ggml_nbytes, plus, for quantized types whose ne[0] is not a
// multiple of MATRIX_ROW_PADDING, the size of the missing elements of one row, enough to pad
// the last row to a multiple of MATRIX_ROW_PADDING.
GGML_CALL static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
    GGML_UNUSED(buft);

    const size_t nbytes = ggml_nbytes(tensor);
    if (!ggml_is_quantized(tensor->type)) {
        return nbytes;
    }

    const int64_t row_remainder = tensor->ne[0] % MATRIX_ROW_PADDING;
    if (row_remainder == 0) {
        return nbytes;
    }
    return nbytes + ggml_row_size(tensor->type, MATRIX_ROW_PADDING - row_remainder);
}
// A CUDA buffer type is usable by a backend only if that backend is CUDA and runs on the same device.
GGML_CALL static bool ggml_backend_cuda_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
    if (ggml_backend_is_cuda(backend)) {
        const auto * buft_ctx = static_cast<const ggml_backend_cuda_buffer_type_context *>(buft->context);
        const auto * cuda_ctx = static_cast<const ggml_backend_cuda_context *>(backend->context);
        return buft_ctx->device == cuda_ctx->device;
    }
    return false;
}
// Callback table for CUDA buffer types. The max-size and is-host callbacks are left NULL.
static ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface = {
/* .get_name = */ ggml_backend_cuda_buffer_type_name,
/* .alloc_buffer = */ ggml_backend_cuda_buffer_type_alloc_buffer,
/* .get_alignment = */ ggml_backend_cuda_buffer_type_get_alignment,
/* .get_max_size = */ NULL, // defaults to SIZE_MAX
/* .get_alloc_size = */ ggml_backend_cuda_buffer_type_get_alloc_size,
/* .supports_backend = */ ggml_backend_cuda_buffer_type_supports_backend,
/* .is_host = */ NULL,
};
// Return the buffer type for `device`, or nullptr if `device` is not a valid device index.
// The buffer types for all GGML_CUDA_MAX_DEVICES slots are created once, under a mutex, on the
// first call; their contexts are never freed (they live for the rest of the process).
GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) {
    static std::mutex mutex;
    std::lock_guard<std::mutex> lock(mutex);

    // negative indices would otherwise index the static array below out of bounds
    if (device < 0 || device >= ggml_backend_cuda_get_device_count()) {
        return nullptr;
    }

    static ggml_backend_buffer_type ggml_backend_cuda_buffer_types[GGML_CUDA_MAX_DEVICES];

    static bool ggml_backend_cuda_buffer_type_initialized = false;

    if (!ggml_backend_cuda_buffer_type_initialized) {
        for (int i = 0; i < GGML_CUDA_MAX_DEVICES; i++) {
            ggml_backend_cuda_buffer_types[i] = {
                /* .iface    = */ ggml_backend_cuda_buffer_type_interface,
                /* .context  = */ new ggml_backend_cuda_buffer_type_context{i, GGML_CUDA_NAME + std::to_string(i)},
            };
        }
        ggml_backend_cuda_buffer_type_initialized = true;
    }

    return &ggml_backend_cuda_buffer_types[device];
}
// cuda split buffer
// Row granularity used when splitting a tensor of `type` across devices.
// Only devices that get a non-empty share of rows contribute to the min/max compute capability
// that selects the rounding. A device's share is non-empty when tensor_split[id] is below the
// next device's start, or below 1.0f for the last device.
// Types not listed below hit GGML_ASSERT(false).
// NOTE(review): if no device gets rows, min/max keep their INT_MAX/INT_MIN sentinels.
static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split) {
int64_t min_compute_capability = INT_MAX;
int64_t max_compute_capability = INT_MIN;
for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
if (tensor_split[id] < (id + 1 < ggml_backend_cuda_get_device_count() ? tensor_split[id + 1] : 1.0f)) {
if (min_compute_capability > get_cuda_global_info().devices[id].cc) {
min_compute_capability = get_cuda_global_info().devices[id].cc;
}
if (max_compute_capability < get_cuda_global_info().devices[id].cc) {
max_compute_capability = get_cuda_global_info().devices[id].cc;
}
}
}
// AMD (HIP) table: thresholds compare against CC_RDNA2
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
switch(type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
return max_compute_capability >= CC_RDNA2 ? 128 : 64;
case GGML_TYPE_F16:
case GGML_TYPE_F32:
return 1;
case GGML_TYPE_Q2_K:
return max_compute_capability >= CC_RDNA2 ? 128 : 32;
case GGML_TYPE_Q3_K:
return min_compute_capability < CC_RDNA2 ? 128 : 64;
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ3_S:
return max_compute_capability >= CC_RDNA2 ? 128 : 64;
default:
GGML_ASSERT(false);
}
#else
// NVIDIA table: thresholds compare against CC_VOLTA; min_compute_capability is unused here
switch(type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
return max_compute_capability >= CC_VOLTA ? 128 : 64;
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
return 64;
case GGML_TYPE_F16:
case GGML_TYPE_F32:
return 1;
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ3_S:
return max_compute_capability >= CC_VOLTA ? 128 : 64;
case GGML_TYPE_Q6_K:
return 64;
default:
GGML_ASSERT(false);
}
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
}
// Compute the row range [*row_low, *row_high) of `tensor` assigned to device `id`.
// tensor_split[id] is the fraction of rows at which device id starts (device 0 starts at 0).
// Both bounds are rounded down to a multiple of get_row_rounding(), except that the last device
// always ends at nrows. Consecutive devices therefore share boundaries.
static void get_row_split(int64_t * row_low, int64_t * row_high, const ggml_tensor * tensor, const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split, int id) {
    const int64_t nrows    = ggml_nrows(tensor);
    const int64_t rounding = get_row_rounding(tensor->type, tensor_split);

    int64_t low = 0;
    if (id != 0) {
        low = nrows*tensor_split[id];
    }
    low -= low % rounding;

    int64_t high;
    if (id == ggml_backend_cuda_get_device_count() - 1) {
        high = nrows;
    } else {
        high = nrows*tensor_split[id + 1];
        high -= high % rounding;
    }

    *row_low  = low;
    *row_high = high;
}