| 1 | // SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause | 
|---|
| 2 | /* ****************************************************************** | 
|---|
| 3 | * Common functions of New Generation Entropy library | 
|---|
| 4 | * Copyright (c) Meta Platforms, Inc. and affiliates. | 
|---|
| 5 | * | 
|---|
| 6 | *  You can contact the author at : | 
|---|
| 7 | *  - FSE+HUF source repository : https://github.com/Cyan4973/FiniteStateEntropy | 
|---|
| 8 | *  - Public forum : https://groups.google.com/forum/#!forum/lz4c | 
|---|
| 9 | * | 
|---|
| 10 | * This source code is licensed under both the BSD-style license (found in the | 
|---|
| 11 | * LICENSE file in the root directory of this source tree) and the GPLv2 (found | 
|---|
| 12 | * in the COPYING file in the root directory of this source tree). | 
|---|
| 13 | * You may select, at your option, one of the above-listed licenses. | 
|---|
| 14 | ****************************************************************** */ | 
|---|
| 15 |  | 
|---|
| 16 | /* ************************************* | 
|---|
| 17 | *  Dependencies | 
|---|
| 18 | ***************************************/ | 
|---|
| 19 | #include "mem.h" | 
|---|
| 20 | #include "error_private.h"       /* ERR_*, ERROR */ | 
|---|
| 21 | #define FSE_STATIC_LINKING_ONLY  /* FSE_MIN_TABLELOG */ | 
|---|
| 22 | #include "fse.h" | 
|---|
| 23 | #include "huf.h" | 
|---|
| 24 | #include "bits.h"                /* ZSDT_highbit32, ZSTD_countTrailingZeros32 */ | 
|---|
| 25 |  | 
|---|
| 26 |  | 
|---|
| 27 | /*===   Version   ===*/ | 
|---|
| 28 | unsigned FSE_versionNumber(void) { return FSE_VERSION_NUMBER; } | 
|---|
| 29 |  | 
|---|
| 30 |  | 
|---|
/*===   Error Management   ===*/
/* The FSE and HUF entry points below are thin forwarders to the shared
 * ERR_* helpers from error_private.h, so both libraries report errors
 * through the same code space. */

/*! FSE_isError() : @return non-zero when `code` encodes an error. */
unsigned FSE_isError(size_t code)
{
    return ERR_isError(code);
}

/*! FSE_getErrorName() : @return a readable name for error `code`. */
const char* FSE_getErrorName(size_t code)
{
    return ERR_getErrorName(code);
}

/*! HUF_isError() : @return non-zero when `code` encodes an error. */
unsigned HUF_isError(size_t code)
{
    return ERR_isError(code);
}

/*! HUF_getErrorName() : @return a readable name for error `code`. */
const char* HUF_getErrorName(size_t code)
{
    return ERR_getErrorName(code);
}
|---|
| 37 |  | 
|---|
| 38 |  | 
|---|
| 39 | /*-************************************************************** | 
|---|
| 40 | *  FSE NCount encoding-decoding | 
|---|
| 41 | ****************************************************************/ | 
|---|
| 42 | FORCE_INLINE_TEMPLATE | 
|---|
| 43 | size_t FSE_readNCount_body(short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr, | 
|---|
| 44 | const void* , size_t hbSize) | 
|---|
| 45 | { | 
|---|
| 46 | const BYTE* const istart = (const BYTE*) headerBuffer; | 
|---|
| 47 | const BYTE* const iend = istart + hbSize; | 
|---|
| 48 | const BYTE* ip = istart; | 
|---|
| 49 | int nbBits; | 
|---|
| 50 | int remaining; | 
|---|
| 51 | int threshold; | 
|---|
| 52 | U32 bitStream; | 
|---|
| 53 | int bitCount; | 
|---|
| 54 | unsigned charnum = 0; | 
|---|
| 55 | unsigned const maxSV1 = *maxSVPtr + 1; | 
|---|
| 56 | int previous0 = 0; | 
|---|
| 57 |  | 
|---|
| 58 | if (hbSize < 8) { | 
|---|
| 59 | /* This function only works when hbSize >= 8 */ | 
|---|
| 60 | char buffer[8] = {0}; | 
|---|
| 61 | ZSTD_memcpy(buffer, headerBuffer, hbSize); | 
|---|
| 62 | {   size_t const countSize = FSE_readNCount(normalizedCounter, maxSymbolValuePtr: maxSVPtr, tableLogPtr, | 
|---|
| 63 | rBuffer: buffer, rBuffSize: sizeof(buffer)); | 
|---|
| 64 | if (FSE_isError(code: countSize)) return countSize; | 
|---|
| 65 | if (countSize > hbSize) return ERROR(corruption_detected); | 
|---|
| 66 | return countSize; | 
|---|
| 67 | }   } | 
|---|
| 68 | assert(hbSize >= 8); | 
|---|
| 69 |  | 
|---|
| 70 | /* init */ | 
|---|
| 71 | ZSTD_memset(normalizedCounter, 0, (*maxSVPtr+1) * sizeof(normalizedCounter[0]));   /* all symbols not present in NCount have a frequency of 0 */ | 
|---|
| 72 | bitStream = MEM_readLE32(memPtr: ip); | 
|---|
| 73 | nbBits = (bitStream & 0xF) + FSE_MIN_TABLELOG;   /* extract tableLog */ | 
|---|
| 74 | if (nbBits > FSE_TABLELOG_ABSOLUTE_MAX) return ERROR(tableLog_tooLarge); | 
|---|
| 75 | bitStream >>= 4; | 
|---|
| 76 | bitCount = 4; | 
|---|
| 77 | *tableLogPtr = nbBits; | 
|---|
| 78 | remaining = (1<<nbBits)+1; | 
|---|
| 79 | threshold = 1<<nbBits; | 
|---|
| 80 | nbBits++; | 
|---|
| 81 |  | 
|---|
| 82 | for (;;) { | 
|---|
| 83 | if (previous0) { | 
|---|
| 84 | /* Count the number of repeats. Each time the | 
|---|
| 85 | * 2-bit repeat code is 0b11 there is another | 
|---|
| 86 | * repeat. | 
|---|
| 87 | * Avoid UB by setting the high bit to 1. | 
|---|
| 88 | */ | 
|---|
| 89 | int repeats = ZSTD_countTrailingZeros32(val: ~bitStream | 0x80000000) >> 1; | 
|---|
| 90 | while (repeats >= 12) { | 
|---|
| 91 | charnum += 3 * 12; | 
|---|
| 92 | if (LIKELY(ip <= iend-7)) { | 
|---|
| 93 | ip += 3; | 
|---|
| 94 | } else { | 
|---|
| 95 | bitCount -= (int)(8 * (iend - 7 - ip)); | 
|---|
| 96 | bitCount &= 31; | 
|---|
| 97 | ip = iend - 4; | 
|---|
| 98 | } | 
|---|
| 99 | bitStream = MEM_readLE32(memPtr: ip) >> bitCount; | 
|---|
| 100 | repeats = ZSTD_countTrailingZeros32(val: ~bitStream | 0x80000000) >> 1; | 
|---|
| 101 | } | 
|---|
| 102 | charnum += 3 * repeats; | 
|---|
| 103 | bitStream >>= 2 * repeats; | 
|---|
| 104 | bitCount += 2 * repeats; | 
|---|
| 105 |  | 
|---|
| 106 | /* Add the final repeat which isn't 0b11. */ | 
|---|
| 107 | assert((bitStream & 3) < 3); | 
|---|
| 108 | charnum += bitStream & 3; | 
|---|
| 109 | bitCount += 2; | 
|---|
| 110 |  | 
|---|
| 111 | /* This is an error, but break and return an error | 
|---|
| 112 | * at the end, because returning out of a loop makes | 
|---|
| 113 | * it harder for the compiler to optimize. | 
|---|
| 114 | */ | 
|---|
| 115 | if (charnum >= maxSV1) break; | 
|---|
| 116 |  | 
|---|
| 117 | /* We don't need to set the normalized count to 0 | 
|---|
| 118 | * because we already memset the whole buffer to 0. | 
|---|
| 119 | */ | 
|---|
| 120 |  | 
|---|
| 121 | if (LIKELY(ip <= iend-7) || (ip + (bitCount>>3) <= iend-4)) { | 
|---|
| 122 | assert((bitCount >> 3) <= 3); /* For first condition to work */ | 
|---|
| 123 | ip += bitCount>>3; | 
|---|
| 124 | bitCount &= 7; | 
|---|
| 125 | } else { | 
|---|
| 126 | bitCount -= (int)(8 * (iend - 4 - ip)); | 
|---|
| 127 | bitCount &= 31; | 
|---|
| 128 | ip = iend - 4; | 
|---|
| 129 | } | 
|---|
| 130 | bitStream = MEM_readLE32(memPtr: ip) >> bitCount; | 
|---|
| 131 | } | 
|---|
| 132 | { | 
|---|
| 133 | int const max = (2*threshold-1) - remaining; | 
|---|
| 134 | int count; | 
|---|
| 135 |  | 
|---|
| 136 | if ((bitStream & (threshold-1)) < (U32)max) { | 
|---|
| 137 | count = bitStream & (threshold-1); | 
|---|
| 138 | bitCount += nbBits-1; | 
|---|
| 139 | } else { | 
|---|
| 140 | count = bitStream & (2*threshold-1); | 
|---|
| 141 | if (count >= threshold) count -= max; | 
|---|
| 142 | bitCount += nbBits; | 
|---|
| 143 | } | 
|---|
| 144 |  | 
|---|
| 145 | count--;   /* extra accuracy */ | 
|---|
| 146 | /* When it matters (small blocks), this is a | 
|---|
| 147 | * predictable branch, because we don't use -1. | 
|---|
| 148 | */ | 
|---|
| 149 | if (count >= 0) { | 
|---|
| 150 | remaining -= count; | 
|---|
| 151 | } else { | 
|---|
| 152 | assert(count == -1); | 
|---|
| 153 | remaining += count; | 
|---|
| 154 | } | 
|---|
| 155 | normalizedCounter[charnum++] = (short)count; | 
|---|
| 156 | previous0 = !count; | 
|---|
| 157 |  | 
|---|
| 158 | assert(threshold > 1); | 
|---|
| 159 | if (remaining < threshold) { | 
|---|
| 160 | /* This branch can be folded into the | 
|---|
| 161 | * threshold update condition because we | 
|---|
| 162 | * know that threshold > 1. | 
|---|
| 163 | */ | 
|---|
| 164 | if (remaining <= 1) break; | 
|---|
| 165 | nbBits = ZSTD_highbit32(val: remaining) + 1; | 
|---|
| 166 | threshold = 1 << (nbBits - 1); | 
|---|
| 167 | } | 
|---|
| 168 | if (charnum >= maxSV1) break; | 
|---|
| 169 |  | 
|---|
| 170 | if (LIKELY(ip <= iend-7) || (ip + (bitCount>>3) <= iend-4)) { | 
|---|
| 171 | ip += bitCount>>3; | 
|---|
| 172 | bitCount &= 7; | 
|---|
| 173 | } else { | 
|---|
| 174 | bitCount -= (int)(8 * (iend - 4 - ip)); | 
|---|
| 175 | bitCount &= 31; | 
|---|
| 176 | ip = iend - 4; | 
|---|
| 177 | } | 
|---|
| 178 | bitStream = MEM_readLE32(memPtr: ip) >> bitCount; | 
|---|
| 179 | }   } | 
|---|
| 180 | if (remaining != 1) return ERROR(corruption_detected); | 
|---|
| 181 | /* Only possible when there are too many zeros. */ | 
|---|
| 182 | if (charnum > maxSV1) return ERROR(maxSymbolValue_tooSmall); | 
|---|
| 183 | if (bitCount > 32) return ERROR(corruption_detected); | 
|---|
| 184 | *maxSVPtr = charnum-1; | 
|---|
| 185 |  | 
|---|
| 186 | ip += (bitCount+7)>>3; | 
|---|
| 187 | return ip-istart; | 
|---|
| 188 | } | 
|---|
| 189 |  | 
|---|
/* Avoids the FORCE_INLINE of the _body() function.
 * Generic (non-BMI2) instantiation of FSE_readNCount_body(); same contract. */
static size_t FSE_readNCount_body_default(
        short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
        const void* headerBuffer, size_t hbSize)
{
    return FSE_readNCount_body(normalizedCounter, maxSVPtr, tableLogPtr, headerBuffer, hbSize);
}
|---|
| 197 |  | 
|---|
#if DYNAMIC_BMI2
/* Instantiation of FSE_readNCount_body() compiled with BMI2_TARGET_ATTRIBUTE;
 * only built when DYNAMIC_BMI2 is enabled. Same contract as the body. */
BMI2_TARGET_ATTRIBUTE static size_t FSE_readNCount_body_bmi2(
        short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
        const void* headerBuffer, size_t hbSize)
{
    return FSE_readNCount_body(normalizedCounter, maxSVPtr, tableLogPtr, headerBuffer, hbSize);
}
#endif
|---|
| 206 |  | 
|---|
/*! FSE_readNCount_bmi2() :
 *  Same as FSE_readNCount(). When built with DYNAMIC_BMI2 and `bmi2` is
 *  non-zero, the BMI2-targeted instantiation is used; otherwise the generic one.
 *  @return : number of header bytes consumed, or an error code (see FSE_isError()). */
size_t FSE_readNCount_bmi2(
        short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
        const void* headerBuffer, size_t hbSize, int bmi2)
{
#if DYNAMIC_BMI2
    if (bmi2) {
        return FSE_readNCount_body_bmi2(normalizedCounter, maxSVPtr, tableLogPtr, headerBuffer, hbSize);
    }
#endif
    (void)bmi2;
    return FSE_readNCount_body_default(normalizedCounter, maxSVPtr, tableLogPtr, headerBuffer, hbSize);
}
|---|
| 219 |  | 
|---|
/*! FSE_readNCount() :
 *  Decodes an FSE normalized-count header using the generic (non-BMI2) path.
 *  @return : number of header bytes consumed, or an error code (see FSE_isError()). */
size_t FSE_readNCount(
        short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
        const void* headerBuffer, size_t hbSize)
{
    return FSE_readNCount_bmi2(normalizedCounter, maxSVPtr, tableLogPtr, headerBuffer, hbSize, /* bmi2 */ 0);
}
|---|
| 226 |  | 
|---|
| 227 |  | 
|---|
| 228 | /*! HUF_readStats() : | 
|---|
| 229 | Read compact Huffman tree, saved by HUF_writeCTable(). | 
|---|
| 230 | `huffWeight` is destination buffer. | 
|---|
| 231 | `rankStats` is assumed to be a table of at least HUF_TABLELOG_MAX U32. | 
|---|
| 232 | @return : size read from `src` , or an error Code . | 
|---|
| 233 | Note : Needed by HUF_readCTable() and HUF_readDTableX?() . | 
|---|
| 234 | */ | 
|---|
| 235 | size_t HUF_readStats(BYTE* huffWeight, size_t hwSize, U32* rankStats, | 
|---|
| 236 | U32* nbSymbolsPtr, U32* tableLogPtr, | 
|---|
| 237 | const void* src, size_t srcSize) | 
|---|
| 238 | { | 
|---|
| 239 | U32 wksp[HUF_READ_STATS_WORKSPACE_SIZE_U32]; | 
|---|
| 240 | return HUF_readStats_wksp(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workspace: wksp, wkspSize: sizeof(wksp), /* flags */ 0); | 
|---|
| 241 | } | 
|---|
| 242 |  | 
|---|
| 243 | FORCE_INLINE_TEMPLATE size_t | 
|---|
| 244 | HUF_readStats_body(BYTE* huffWeight, size_t hwSize, U32* rankStats, | 
|---|
| 245 | U32* nbSymbolsPtr, U32* tableLogPtr, | 
|---|
| 246 | const void* src, size_t srcSize, | 
|---|
| 247 | void* workSpace, size_t wkspSize, | 
|---|
| 248 | int bmi2) | 
|---|
| 249 | { | 
|---|
| 250 | U32 weightTotal; | 
|---|
| 251 | const BYTE* ip = (const BYTE*) src; | 
|---|
| 252 | size_t iSize; | 
|---|
| 253 | size_t oSize; | 
|---|
| 254 |  | 
|---|
| 255 | if (!srcSize) return ERROR(srcSize_wrong); | 
|---|
| 256 | iSize = ip[0]; | 
|---|
| 257 | /* ZSTD_memset(huffWeight, 0, hwSize);   *//* is not necessary, even though some analyzer complain ... */ | 
|---|
| 258 |  | 
|---|
| 259 | if (iSize >= 128) {  /* special header */ | 
|---|
| 260 | oSize = iSize - 127; | 
|---|
| 261 | iSize = ((oSize+1)/2); | 
|---|
| 262 | if (iSize+1 > srcSize) return ERROR(srcSize_wrong); | 
|---|
| 263 | if (oSize >= hwSize) return ERROR(corruption_detected); | 
|---|
| 264 | ip += 1; | 
|---|
| 265 | {   U32 n; | 
|---|
| 266 | for (n=0; n<oSize; n+=2) { | 
|---|
| 267 | huffWeight[n]   = ip[n/2] >> 4; | 
|---|
| 268 | huffWeight[n+1] = ip[n/2] & 15; | 
|---|
| 269 | }   }   } | 
|---|
| 270 | else  {   /* header compressed with FSE (normal case) */ | 
|---|
| 271 | if (iSize+1 > srcSize) return ERROR(srcSize_wrong); | 
|---|
| 272 | /* max (hwSize-1) values decoded, as last one is implied */ | 
|---|
| 273 | oSize = FSE_decompress_wksp_bmi2(dst: huffWeight, dstCapacity: hwSize-1, cSrc: ip+1, cSrcSize: iSize, maxLog: 6, workSpace, wkspSize, bmi2); | 
|---|
| 274 | if (FSE_isError(code: oSize)) return oSize; | 
|---|
| 275 | } | 
|---|
| 276 |  | 
|---|
| 277 | /* collect weight stats */ | 
|---|
| 278 | ZSTD_memset(rankStats, 0, (HUF_TABLELOG_MAX + 1) * sizeof(U32)); | 
|---|
| 279 | weightTotal = 0; | 
|---|
| 280 | {   U32 n; for (n=0; n<oSize; n++) { | 
|---|
| 281 | if (huffWeight[n] > HUF_TABLELOG_MAX) return ERROR(corruption_detected); | 
|---|
| 282 | rankStats[huffWeight[n]]++; | 
|---|
| 283 | weightTotal += (1 << huffWeight[n]) >> 1; | 
|---|
| 284 | }   } | 
|---|
| 285 | if (weightTotal == 0) return ERROR(corruption_detected); | 
|---|
| 286 |  | 
|---|
| 287 | /* get last non-null symbol weight (implied, total must be 2^n) */ | 
|---|
| 288 | {   U32 const tableLog = ZSTD_highbit32(val: weightTotal) + 1; | 
|---|
| 289 | if (tableLog > HUF_TABLELOG_MAX) return ERROR(corruption_detected); | 
|---|
| 290 | *tableLogPtr = tableLog; | 
|---|
| 291 | /* determine last weight */ | 
|---|
| 292 | {   U32 const total = 1 << tableLog; | 
|---|
| 293 | U32 const rest = total - weightTotal; | 
|---|
| 294 | U32 const verif = 1 << ZSTD_highbit32(val: rest); | 
|---|
| 295 | U32 const lastWeight = ZSTD_highbit32(val: rest) + 1; | 
|---|
| 296 | if (verif != rest) return ERROR(corruption_detected);    /* last value must be a clean power of 2 */ | 
|---|
| 297 | huffWeight[oSize] = (BYTE)lastWeight; | 
|---|
| 298 | rankStats[lastWeight]++; | 
|---|
| 299 | }   } | 
|---|
| 300 |  | 
|---|
| 301 | /* check tree construction validity */ | 
|---|
| 302 | if ((rankStats[1] < 2) || (rankStats[1] & 1)) return ERROR(corruption_detected);   /* by construction : at least 2 elts of rank 1, must be even */ | 
|---|
| 303 |  | 
|---|
| 304 | /* results */ | 
|---|
| 305 | *nbSymbolsPtr = (U32)(oSize+1); | 
|---|
| 306 | return iSize+1; | 
|---|
| 307 | } | 
|---|
| 308 |  | 
|---|
| 309 | /* Avoids the FORCE_INLINE of the _body() function. */ | 
|---|
| 310 | static size_t HUF_readStats_body_default(BYTE* huffWeight, size_t hwSize, U32* rankStats, | 
|---|
| 311 | U32* nbSymbolsPtr, U32* tableLogPtr, | 
|---|
| 312 | const void* src, size_t srcSize, | 
|---|
| 313 | void* workSpace, size_t wkspSize) | 
|---|
| 314 | { | 
|---|
| 315 | return HUF_readStats_body(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize, bmi2: 0); | 
|---|
| 316 | } | 
|---|
| 317 |  | 
|---|
#if DYNAMIC_BMI2
/* Instantiation of HUF_readStats_body() compiled with BMI2_TARGET_ATTRIBUTE
 * and bmi2 = 1; only built when DYNAMIC_BMI2 is enabled. */
static BMI2_TARGET_ATTRIBUTE size_t HUF_readStats_body_bmi2(BYTE* huffWeight, size_t hwSize, U32* rankStats,
                                                            U32* nbSymbolsPtr, U32* tableLogPtr,
                                                            const void* src, size_t srcSize,
                                                            void* workSpace, size_t wkspSize)
{
    return HUF_readStats_body(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize, 1);
}
#endif
|---|
| 327 |  | 
|---|
| 328 | size_t HUF_readStats_wksp(BYTE* huffWeight, size_t hwSize, U32* rankStats, | 
|---|
| 329 | U32* nbSymbolsPtr, U32* tableLogPtr, | 
|---|
| 330 | const void* src, size_t srcSize, | 
|---|
| 331 | void* workSpace, size_t wkspSize, | 
|---|
| 332 | int flags) | 
|---|
| 333 | { | 
|---|
| 334 | #if DYNAMIC_BMI2 | 
|---|
| 335 | if (flags & HUF_flags_bmi2) { | 
|---|
| 336 | return HUF_readStats_body_bmi2(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize); | 
|---|
| 337 | } | 
|---|
| 338 | #endif | 
|---|
| 339 | (void)flags; | 
|---|
| 340 | return HUF_readStats_body_default(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize); | 
|---|
| 341 | } | 
|---|
| 342 |  | 
|---|