From 2884a39eb1a0fbe6f048ceac41642d2c7afd68f5 Mon Sep 17 00:00:00 2001 From: Simon Gardling Date: Tue, 7 Apr 2026 20:04:06 -0400 Subject: [PATCH] llama-cpp: patch for vulkan support instead --- flake.lock | 12 +- flake.nix | 3 +- patches/0002-llamacpp-vulkan-turbo3.patch | 392 ++++++++++++++++++++++ services/llama-cpp.nix | 6 +- 4 files changed, 404 insertions(+), 9 deletions(-) create mode 100644 patches/0002-llamacpp-vulkan-turbo3.patch diff --git a/flake.lock b/flake.lock index f6cc9e2..931a4bd 100644 --- a/flake.lock +++ b/flake.lock @@ -325,16 +325,16 @@ ] }, "locked": { - "lastModified": 1774922513, - "narHash": "sha256-TKk1i8AZzxy4/z0MkqKxoGf/CQDvoL+jo8JDtZeCRy8=", - "owner": "apollosenvy", + "lastModified": 1775603401, + "narHash": "sha256-kp+cnqLX+K4M6gBc5Iy4S+G0xkz78qVEcO1xmNTrtgM=", + "owner": "TheTom", "repo": "llama-cpp-turboquant", - "rev": "9e80e93ceb115bc5055997c373d8c09bfa47a565", + "rev": "a4e8af4455d34d4872f967e615c8212643c2123e", "type": "github" }, "original": { - "owner": "apollosenvy", - "ref": "pr/vulkan-turbo3", + "owner": "TheTom", + "ref": "feature/turboquant-kv-cache", "repo": "llama-cpp-turboquant", "type": "github" } diff --git a/flake.nix b/flake.nix index 728fa02..56319ee 100644 --- a/flake.nix +++ b/flake.nix @@ -29,8 +29,7 @@ }; llamacpp = { - # url = "github:TheTom/llama-cpp-turboquant/feature/turboquant-kv-cache"; - url = "github:apollosenvy/llama-cpp-turboquant/pr/vulkan-turbo3"; + url = "github:TheTom/llama-cpp-turboquant/feature/turboquant-kv-cache"; inputs.nixpkgs.follows = "nixpkgs"; }; diff --git a/patches/0002-llamacpp-vulkan-turbo3.patch b/patches/0002-llamacpp-vulkan-turbo3.patch new file mode 100644 index 0000000..a8e0f7b --- /dev/null +++ b/patches/0002-llamacpp-vulkan-turbo3.patch @@ -0,0 +1,392 @@ +From 9e80e93ceb115bc5055997c373d8c09bfa47a565 Mon Sep 17 00:00:00 2001 +From: Tuklus-Labs +Date: Mon, 30 Mar 2026 07:48:27 -0700 +Subject: [PATCH] feat: Vulkan compute shader support for turbo3 KV cache + +Full turbo3 quantize/dequant pipeline for Vulkan backend: + +- types.glsl: block_turbo3_0 struct (norm + qs[8] + signs[4]) +- dequant_turbo3_0.comp: standalone dequant shader (3-bit index + reconstruction from 2-bit qs + 1-bit signs, centroid lookup) +- dequant_funcs.glsl: inline dequant for get_rows/mul_mat paths +- dequant_funcs_cm2.glsl: cooperative matrix 2 FA path support +- copy_to_quant.comp: quantize function with norm correction +- vulkan-shaders-gen.cpp: turbo3_0 type registration +- ggml-vulkan.cpp: pipeline creation and supports_op dispatch + +Tested on AMD 7900 XTX (RADV): 243 pp / 25.8 tg t/s with turbo3 KV. + +Co-Authored-By: Claude Opus 4.6 (1M context) +--- + ggml/src/ggml-vulkan/ggml-vulkan.cpp | 13 ++++- + .../vulkan-shaders/copy_to_quant.comp | 58 +++++++++++++++++++ + .../vulkan-shaders/dequant_funcs.glsl | 36 ++++++++++++ + .../vulkan-shaders/dequant_funcs_cm2.glsl | 29 ++++++++++ + .../vulkan-shaders/dequant_turbo3_0.comp | 46 +++++++++++++++ + .../src/ggml-vulkan/vulkan-shaders/types.glsl | 17 ++++++ + .../vulkan-shaders/vulkan-shaders-gen.cpp | 5 +- + 7 files changed, 201 insertions(+), 3 deletions(-) + create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/dequant_turbo3_0.comp + +diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp +index 221e6fa04e9..bf826075c11 100644 +--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp ++++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp +@@ -4177,6 +4177,7 @@ static void ggml_vk_load_shaders(vk_device& device) { + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_XS], "dequant_iq4_xs", dequant_iq4_xs_len, dequant_iq4_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_MXFP4], "dequant_mxfp4", dequant_mxfp4_len, dequant_mxfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_TURBO3_0], "dequant_turbo3_0", dequant_turbo3_0_len, dequant_turbo3_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + + // get_rows + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); +@@ -4202,6 +4203,7 @@ static void ggml_vk_load_shaders(vk_device& device) { + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs", get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_MXFP4], "get_rows_mxfp4", get_rows_mxfp4_len, get_rows_mxfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_TURBO3_0], "get_rows_turbo3_0", get_rows_turbo3_0_len, get_rows_turbo3_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_I32], "get_rows_i32", get_rows_i32_len, get_rows_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); +@@ -4227,6 +4229,7 @@ static void ggml_vk_load_shaders(vk_device& device) { + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs_f32", get_rows_iq4_xs_f32_len, get_rows_iq4_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_TURBO3_0], "get_rows_turbo3_0_f32", get_rows_turbo3_0_f32_len, get_rows_turbo3_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true); +@@ -4294,6 +4297,7 @@ static void ggml_vk_load_shaders(vk_device& device) { + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_TURBO3_0], "cpy_f32_turbo3_0", cpy_f32_turbo3_0_rte_len, cpy_f32_turbo3_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + } else { + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); +@@ -4301,6 +4305,7 @@ static void ggml_vk_load_shaders(vk_device& device) { + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_TURBO3_0], "cpy_f32_turbo3_0", cpy_f32_turbo3_0_len, cpy_f32_turbo3_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + } + + #define SET_ROWS(itype, rte) \ +@@ -4312,7 +4317,8 @@ static void ggml_vk_load_shaders(vk_device& device) { + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_0], "set_rows_q5_0" #itype, set_rows_q5_0 ## itype ## rte ## _len, set_rows_q5_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_1], "set_rows_q5_1" #itype, set_rows_q5_1 ## itype ## rte ## _len, set_rows_q5_1 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q8_0], "set_rows_q8_0" #itype, set_rows_q8_0 ## itype ## rte ## _len, set_rows_q8_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ +- ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_IQ4_NL], "set_rows_iq4_nl" #itype, set_rows_iq4_nl ## itype ## rte ## _len, set_rows_iq4_nl ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); ++ ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_IQ4_NL], "set_rows_iq4_nl" #itype, set_rows_iq4_nl ## itype ## rte ## _len, set_rows_iq4_nl ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ ++ ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_TURBO3_0], "set_rows_turbo3_0" #itype, set_rows_turbo3_0 ## itype ## rte ## _len, set_rows_turbo3_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); + + if (device->float_controls_rte_fp16) { + SET_ROWS(_i32, _rte) +@@ -4330,6 +4336,7 @@ static void ggml_vk_load_shaders(vk_device& device) { + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q5_1], "cpy_q5_1_f32", cpy_q5_1_f32_len, cpy_q5_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q8_0], "cpy_q8_0_f32", cpy_q8_0_f32_len, cpy_q8_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_IQ4_NL], "cpy_iq4_nl_f32", cpy_iq4_nl_f32_len, cpy_iq4_nl_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_TURBO3_0], "cpy_turbo3_0_f32", cpy_turbo3_0_f32_len, cpy_turbo3_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_TURBO3_0), 1, 1}, {}, 1); + + auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) { + std::string s; +@@ -15376,6 +15383,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_MXFP4: ++ case GGML_TYPE_TURBO3_0: + case GGML_TYPE_I32: + return true; + default: +@@ -15394,6 +15402,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: ++ case GGML_TYPE_TURBO3_0: + return true; + default: + return false; +@@ -15417,6 +15426,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: ++ case GGML_TYPE_TURBO3_0: + return true; + default: + break; +@@ -15431,6 +15441,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: ++ case GGML_TYPE_TURBO3_0: + return true; + default: + break; +diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +index b8c40eec102..54331e28c82 100644 +--- a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp ++++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +@@ -184,6 +184,64 @@ void quantize(uint dst_idx, uint src_idx) + } + #endif + ++#if defined(DATA_A_TURBO3_0) ++void quantize(uint dst_idx, uint src_idx) ++{ ++ const float centroids[8] = float[8]( ++ -0.190685, -0.117832, -0.065717, -0.021460, ++ 0.021460, 0.065717, 0.117832, 0.190685 ++ ); ++ const float midpoints[7] = float[7]( ++ -0.154259, -0.091775, -0.043589, 0.0, 0.043589, 0.091775, 0.154259 ++ ); ++ ++ // Compute L2 norm ++ float norm_sq = 0.0; ++ [[unroll]] for (int j = 0; j < 32; ++j) { ++ float v = data_s[src_idx + j]; ++ norm_sq += v * v; ++ } ++ float norm = sqrt(norm_sq); ++ float inv_norm = (norm > 1e-10) ? (1.0 / norm) : 0.0; ++ ++ // Clear output ++ [[unroll]] for (int j = 0; j < 8; ++j) data_q[dst_idx].qs[j] = uint8_t(0); ++ [[unroll]] for (int j = 0; j < 4; ++j) data_q[dst_idx].signs[j] = uint8_t(0); ++ ++ // Accumulate centroid reconstruction norm for correction ++ float recon_norm_sq = 0.0; ++ ++ // Quantize each element ++ [[unroll]] for (int j = 0; j < 32; ++j) { ++ float val = data_s[src_idx + j] * inv_norm; ++ ++ // Find nearest centroid via midpoint comparison ++ uint idx = 0; ++ if (val < midpoints[0]) idx = 0; ++ else if (val < midpoints[1]) idx = 1; ++ else if (val < midpoints[2]) idx = 2; ++ else if (val < midpoints[3]) idx = 3; ++ else if (val < midpoints[4]) idx = 4; ++ else if (val < midpoints[5]) idx = 5; ++ else if (val < midpoints[6]) idx = 6; ++ else idx = 7; ++ ++ recon_norm_sq += centroids[idx] * centroids[idx]; ++ ++ // Pack: low 2 bits to qs, high 1 bit to signs ++ uint low2 = idx & 0x3; ++ uint hi1 = (idx >> 2) & 0x1; ++ data_q[dst_idx].qs[j / 4] |= uint8_t(low2 << ((j % 4) * 2)); ++ data_q[dst_idx].signs[j / 8] |= uint8_t(hi1 << (j % 8)); ++ } ++ ++ // Norm correction: scale so reconstruction matches original norm ++ float recon_norm = sqrt(recon_norm_sq); ++ float corrected_norm = (recon_norm > 1e-10) ? (norm / recon_norm) : norm; ++ data_q[dst_idx].norm = float16_t(corrected_norm); ++} ++#endif ++ + #if defined(DATA_A_IQ4_NL) + uint best_index(float x) { + if (x <= kvalues_iq4nl[0]) return 0; +diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +index 7865a6bda79..eefffe9d502 100644 +--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl ++++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +@@ -602,3 +602,39 @@ vec2 get_dm(uint ib, uint a_offset) { + return vec2(1, 0); + } + #endif ++ ++#if defined(DATA_A_TURBO3_0) ++vec2 dequantize(uint ib, uint iqs, uint a_offset) { ++ // PolarQuant 3-bit centroids (Lloyd-Max for Gaussian) ++ const float centroids[8] = float[8]( ++ -0.190685, -0.117832, -0.065717, -0.021460, ++ 0.021460, 0.065717, 0.117832, 0.190685 ++ ); ++ ++ // iqs is the element index within the block (0..31), we decode 2 consecutive elements ++ const uint j0 = iqs; ++ const uint j1 = iqs + 1; ++ ++ // Extract 2-bit low indices from qs (4 per byte) ++ const uint low2_0 = (uint(data_a[a_offset + ib].qs[j0 / 4]) >> ((j0 % 4) * 2)) & 0x3; ++ const uint low2_1 = (uint(data_a[a_offset + ib].qs[j1 / 4]) >> ((j1 % 4) * 2)) & 0x3; ++ ++ // Extract 1-bit high from signs (8 per byte) ++ const uint hi1_0 = (uint(data_a[a_offset + ib].signs[j0 / 8]) >> (j0 % 8)) & 0x1; ++ const uint hi1_1 = (uint(data_a[a_offset + ib].signs[j1 / 8]) >> (j1 % 8)) & 0x1; ++ ++ // Combine to 3-bit index ++ const uint idx0 = low2_0 | (hi1_0 << 2); ++ const uint idx1 = low2_1 | (hi1_1 << 2); ++ ++ return vec2(centroids[idx0], centroids[idx1]); ++} ++vec4 dequantize4(uint ib, uint iqs, uint a_offset) { ++ vec2 v0 = dequantize(ib, iqs, a_offset); ++ vec2 v1 = dequantize(ib, iqs + 2, a_offset); ++ return vec4(v0.x, v0.y, v1.x, v1.y); ++} ++vec2 get_dm(uint ib, uint a_offset) { ++ return vec2(float(data_a[a_offset + ib].norm), 0); ++} ++#endif +diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +index 8ac6482dc94..03d200bd964 100644 +--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl ++++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +@@ -685,6 +685,33 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords + } + #endif + ++#if defined(DATA_A_TURBO3_0) ++layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufTURBO3_0 { ++ block_turbo3_0 block; ++}; ++ ++float16_t dequantFuncTURBO3_0(const in decodeBufTURBO3_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) ++{ ++ const float centroids[8] = float[8]( ++ -0.190685, -0.117832, -0.065717, -0.021460, ++ 0.021460, 0.065717, 0.117832, 0.190685 ++ ); ++ const float norm = float(bl.block.norm); ++ const uint j = coordInBlock[1]; ++ ++ // Extract 2-bit low index from qs (4 per byte) ++ const uint low2 = (uint(bl.block.qs[j / 4]) >> ((j % 4) * 2)) & 0x3; ++ ++ // Extract 1-bit high from signs (8 per byte) ++ const uint hi1 = (uint(bl.block.signs[j / 8]) >> (j % 8)) & 0x1; ++ ++ // Combine to 3-bit index ++ const uint idx = low2 | (hi1 << 2); ++ ++ return float16_t(centroids[idx] * norm); ++} ++#endif ++ + #if defined(DATA_A_Q4_0) + #define dequantFuncA dequantFuncQ4_0 + #elif defined(DATA_A_Q4_1) +@@ -729,6 +756,8 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords + #define dequantFuncA dequantFuncIQ4_NL + #elif defined(DATA_A_MXFP4) + #define dequantFuncA dequantFuncMXFP4 ++#elif defined(DATA_A_TURBO3_0) ++#define dequantFuncA dequantFuncTURBO3_0 + #elif defined(DATA_A_F32) + #define dequantFuncA dequantFuncF32 + #endif +diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_turbo3_0.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_turbo3_0.comp +new file mode 100644 +index 00000000000..17b9bd9eb4b +--- /dev/null ++++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_turbo3_0.comp +@@ -0,0 +1,46 @@ ++#version 450 ++ ++#include "dequant_head.glsl" ++ ++layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; ++ ++layout (binding = 0) readonly buffer A {block_turbo3_0 data_a[];}; ++layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; ++ ++void main() { ++ const float centroids[8] = float[8]( ++ -0.190685, -0.117832, -0.065717, -0.021460, ++ 0.021460, 0.065717, 0.117832, 0.190685 ++ ); ++ ++ const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; ++ ++ const uint tid = gl_LocalInvocationID.x % 64; ++ const uint il = tid/32; ++ const uint ir = tid%32; ++ const uint ib = 32*i + ir; ++ if (ib >= p.nel / 32) { ++ return; ++ } ++ ++ const uint b_idx = 1024*i + 32*ir + 16*il; ++ ++ const float norm = float(data_a[ib].norm); ++ ++ const uint q_start = 16*il; ++ ++ [[unroll]] for (uint l = 0; l < 16; ++l) { ++ const uint j = q_start + l; ++ ++ // Extract 2-bit low index from qs (4 per byte) ++ const uint low2 = (uint(data_a[ib].qs[j / 4]) >> ((j % 4) * 2)) & 0x3; ++ ++ // Extract 1-bit high from signs (8 per byte) ++ const uint hi1 = (uint(data_a[ib].signs[j / 8]) >> (j % 8)) & 0x1; ++ ++ // Combine to 3-bit index ++ const uint idx = low2 | (hi1 << 2); ++ ++ data_b[b_idx + l] = D_TYPE(centroids[idx] * norm); ++ } ++} +diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +index bdb2c09259b..e3635fa01b7 100644 +--- a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl ++++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +@@ -1696,6 +1696,23 @@ struct block_mxfp4 + #define A_TYPE block_mxfp4 + #endif + ++#define QUANT_K_TURBO3_0 32 ++#define QUANT_R_TURBO3_0 1 ++ ++struct block_turbo3_0 ++{ ++ float16_t norm; ++ uint8_t qs[8]; // 2-bit centroid indices (4 per byte) ++ uint8_t signs[4]; // 1-bit high bit of 3-bit index (8 per byte) ++}; ++ ++#if defined(DATA_A_TURBO3_0) ++#define QUANT_K QUANT_K_TURBO3_0 ++#define QUANT_R QUANT_R_TURBO3_0 ++#define QUANT_AUXF 1 ++#define A_TYPE block_turbo3_0 ++#endif ++ + #if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS) + const int8_t kvalues_iq4nl_const[16] = { + int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10), +diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +index 8186dba36f6..90253243ab8 100644 +--- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp ++++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +@@ -66,6 +66,7 @@ const std::vector type_names = { + "iq4_nl", + "mxfp4", + "bf16", ++ "turbo3_0", + }; + + enum MatMulIdType { +@@ -757,13 +758,13 @@ void process_shaders() { + string_to_spv("cpy_transpose_16", "copy_transpose.comp", {{"A_TYPE", "uint16_t"}, {"D_TYPE", "uint16_t"}}); + string_to_spv("cpy_transpose_32", "copy_transpose.comp", {{"A_TYPE", "uint"}, {"D_TYPE", "uint"}}); + +- for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { ++ for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl", "turbo3_0"}) { + string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("cpy_f32_" + t + "_rte", "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}}); + string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + } + +- for (std::string t : {"f32", "f16", "bf16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { ++ for (std::string t : {"f32", "f16", "bf16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl", "turbo3_0"}) { + string_to_spv("set_rows_" + t + "_i32", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("set_rows_" + t + "_i32_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}}); + string_to_spv("set_rows_" + t + "_i64", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); diff --git a/services/llama-cpp.nix b/services/llama-cpp.nix index 1f1d834..cdf9955 100644 --- a/services/llama-cpp.nix +++ b/services/llama-cpp.nix @@ -23,7 +23,11 @@ in ); port = service_configs.ports.private.llama_cpp.port; host = "0.0.0.0"; - package = (lib.optimizePackage inputs.llamacpp.packages.${pkgs.system}.vulkan); + package = lib.optimizePackage ( + inputs.llamacpp.packages.${pkgs.system}.vulkan.overrideAttrs (old: { + patches = (old.patches or [ ]) ++ [ ../patches/0002-llamacpp-vulkan-turbo3.patch ]; + }) + ); extraFlags = [ "-ngl" "999"