diff --git a/npsr/hwy.h b/npsr/hwy.h new file mode 100644 index 0000000..cf821b7 --- /dev/null +++ b/npsr/hwy.h @@ -0,0 +1,39 @@ +#ifndef NPSR_HWY_H_ +#define NPSR_HWY_H_ + +#include +// This macro is used to define intrinsics that are: +// NOTE: equals to HWY_API. +// - always inlined +// - flattened (no separate stack frame) +// - marked maybe unused to suppress warnings when they are not used +// NOTE: we do not need to use HWY_ATTR because we wrap Highway intrinsics in +// HWY_BEFORE_NAMESPACE()/HWY_AFTER_NAMESPACE() +// which implies the necessary target attributes via #pragma. +#define NPSR_INTRIN static HWY_INLINE HWY_FLATTEN HWY_MAYBE_UNUSED +#endif // NPSR_HWY_H_ +
#if defined(NPSR_HWY_FOREACH_H_) == defined(HWY_TARGET_TOGGLE) // NOLINT +#ifdef NPSR_HWY_FOREACH_H_ +#undef NPSR_HWY_FOREACH_H_ +#else +#define NPSR_HWY_FOREACH_H_ +#endif +
HWY_BEFORE_NAMESPACE(); +namespace npsr::HWY_NAMESPACE { +namespace hn = hwy::HWY_NAMESPACE; +using hn::DFromV; +using hn::MFromD; +using hn::Rebind; +using hn::RebindToUnsigned; +using hn::TFromD; +using hn::TFromV; +using hn::VFromD; +constexpr bool kNativeFMA = HWY_NATIVE_FMA != 0; +
inline HWY_ATTR void DummyToSuppressUnusedWarning() {} +} // namespace npsr::HWY_NAMESPACE +HWY_AFTER_NAMESPACE(); +
#endif // NPSR_HWY_FOREACH_H_ diff --git a/npsr/lut-inl.h b/npsr/lut-inl.h new file mode 100644 index 0000000..eb7611c --- /dev/null +++ b/npsr/lut-inl.h @@ -0,0 +1,216 @@ +#if defined(NPSR_LUT_INL_H_) == defined(HWY_TARGET_TOGGLE) // NOLINT +#ifdef NPSR_LUT_INL_H_ +#undef NPSR_LUT_INL_H_ +#else +#define NPSR_LUT_INL_H_ +#endif +
#include +
#include "npsr/hwy.h" +
HWY_BEFORE_NAMESPACE(); +
namespace npsr::HWY_NAMESPACE { +
/** + * @brief SIMD-optimized lookup table implementation + * + * This class provides an efficient lookup table. + * It stores data in both row-major and column-major + * formats to optimize different access patterns. 
+ * + * @tparam T Element type (must match SIMD vector element type) + * @tparam kRows Number of rows in the lookup table + * @tparam kCols Number of columns in the lookup table + * + * Example usage: + * @code + * // Create a 2x4 lookup table + * constexpr Lut lut{{1.0f, 2.0f, 3.0f, 4.0f}, {5.0f, 6.0f, 7.0f, 8.0f}}; + * // Load values using SIMD indices + * auto indices = Set(d, 2); // SIMD vector of indices + * Vec out0, out1; + * lut.Load(indices, out0, out1); + * @endcode + */ +template +class Lut { + public: + static constexpr size_t kLength = kRows * kCols; + + /** + * @brief Construct a lookup table from row arrays + * + * @tparam ColSizes Size of each row array (deduced) + * @param rows Variable number of arrays, each representing a row + * + * @note All rows must have exactly kCols elements + * @note The constructor is constexpr for compile-time initialization + */ + template + constexpr Lut(const T (&...rows)[ColSizes]) : row_{} { + // Check that we have the right number of rows + static_assert(sizeof...(rows) == kRows, + "Number of rows doesn't match template parameter"); + // Check that all rows have the same number of columns + static_assert(((ColSizes == kCols) && ...), + "All rows must have the same number of columns"); + + // Copy data using recursive template approach + ToRowMajor_<0>(rows...); + } + + /** + * @brief Load values from the LUT using SIMD indices + * + * This method performs efficient SIMD lookups by selecting the optimal + * implementation based on the vector size and LUT dimensions. 
+ * + * @tparam VU SIMD vector type for indices + * @tparam OutV Output SIMD vector types (must match number of rows) + * @param idx SIMD vector of column indices + * @param out Output vectors (one per row) + * + * @note The number of output vectors must exactly match kRows + * @note Index values must be in range [0, kCols) + */ + template + HWY_INLINE void Load(VU idx, OutV &...out) const { + static_assert(sizeof...(OutV) == kRows, + "Number of output vectors must match number of rows in LUT"); + using namespace hn; + using TU = TFromV; + static_assert(sizeof(TU) == sizeof(T), + "Index type must match LUT element type"); + // Row-major based optimization + LoadRow_(idx, out...); + } + + private: + /// Convert input rows to row-major storage format + template + constexpr void ToRowMajor_(const T (&...rows)[ColSizes]) { + if constexpr (RowIDX < kRows) { + auto row_array = std::get(std::make_tuple(rows...)); + for (size_t col = 0; col < kCols; ++col) { + row_[RowIDX * kCols + col] = row_array[col]; + } + ToRowMajor_(rows...); + } + } + + /// Dispatch to optimal row-load implementation based on vector/LUT size + template + HWY_INLINE void LoadRow_(VU idx, OutV &...out) const { + using namespace hn; + using DU = DFromV; + const DU du; + using D = Rebind; + const D d; + + HWY_LANES_CONSTEXPR size_t kLanes = Lanes(du); + if HWY_LANES_CONSTEXPR (kLanes == kCols) { + // Vector size matches table width - use single table lookup + const auto ind = IndicesFromVec(d, idx); + LoadX1_(ind, out...); + } else if HWY_LANES_CONSTEXPR (kLanes * 2 == kCols) { + // Vector size is half table width - use two table lookup + const auto ind = IndicesFromVec(d, idx); + LoadX2_(ind, out...); + } else { + // Fallback to gather for other configurations + LoadGather_(idx, out...); + } + } + + // Load using single table lookup (vector size == table width) + template + HWY_INLINE void LoadX1_(const VInd &ind, OutV0 &out0, OutV &...out) const { + using namespace hn; + using D = DFromV; + const D d; 
+ + const OutV0 lut0 = LoadU(d, row_ + Off); + out0 = TableLookupLanes(lut0, ind); + + if constexpr (sizeof...(OutV) > 0) { + LoadX1_(ind, out...); + } + } + + // Load using two table lookups (vector size == table width / 2) + template + HWY_INLINE void LoadX2_(const VInd &ind, OutV0 &out0, OutV &...out) const { + using namespace hn; + using D = DFromV; + const D d; + + constexpr size_t kLanes = kCols / 2; + const OutV0 lut0 = LoadU(d, row_ + Off); + const OutV0 lut1 = LoadU(d, row_ + Off + kLanes); + out0 = TwoTablesLookupLanes(d, lut0, lut1, ind); + + if constexpr (sizeof...(OutV) > 0) { + LoadX2_(ind, out...); + } + } + + // General fallback using gather instructions + template + HWY_INLINE void LoadGather_(const VU &idx, OutV0 &out0, OutV &...out) const { + using namespace hn; + using D = DFromV; + const D d; + out0 = GatherIndex(d, row_ + Off, BitCast(RebindToSigned(), idx)); + if constexpr (sizeof...(OutV) > 0) { + LoadGather_(idx, out...); + } + } + + // Row-major + HWY_ALIGN T row_[kLength]; +}; + +/** + * @brief Deduction guide for automatic dimension detection + * + * Allows constructing a Lut without explicitly specifying dimensions: + * @code + * Lut lut{row0, row1, row2}; // Dimensions deduced from arrays + * @endcode + */ +template +Lut(const T (&first)[First], const T (&...rest)[Rest]) + -> Lut; + +/** + * @brief Factory function that requires explicit type specification + * + * This approach forces users to specify the type T explicitly while + * automatically deducing the dimensions from the array arguments. + * + * Note: We use MakeLut since partial deduction guides (e.g., Lut{...}) + * require C++20, but this codebase targets C++17. 
+ * + * @tparam T Element type (must be explicitly specified) + * @param first First row array + * @param rest Additional row arrays + * @return Lut with deduced dimensions + * + * Usage: + * @code + * auto lut = MakeLut(row0, row1, row2); // T explicit, dimensions + * deduced + * @endcode + */ +template +constexpr auto MakeLut(const T (&first)[First], const T (&...rest)[Rest]) { + return Lut{first, rest...}; +} + +} // namespace npsr::HWY_NAMESPACE + +HWY_AFTER_NAMESPACE(); + +#endif // NPSR_LUT_INL_H_ diff --git a/npsr/npsr.h b/npsr/npsr.h new file mode 100644 index 0000000..7e9d08a --- /dev/null +++ b/npsr/npsr.h @@ -0,0 +1,11 @@ +// To include them once per target, which is ensured by the toggle check. +#if defined(NPSR_NPSR_H_) == defined(HWY_TARGET_TOGGLE) // NOLINT +#ifdef NPSR_NPSR_H_ +#undef NPSR_NPSR_H_ +#else +#define NPSR_NPSR_H_ +#endif + +#include "npsr/trig/inl.h" + +#endif // NPSR_NPSR_H_ diff --git a/npsr/precise.h b/npsr/precise.h new file mode 100644 index 0000000..ec4e617 --- /dev/null +++ b/npsr/precise.h @@ -0,0 +1,198 @@ +#ifndef NPSR_PRECISE_H_ +#define NPSR_PRECISE_H_ + +#include +#include +#include +#include + +namespace npsr { +using std::is_same_v; + +// Tag types for configuring floating-point behavior +// These allow compile-time configuration without runtime overhead + +// Algorithm configuration tags +// Skip extended precision for |x| > 2^24 (float) or 2^53 (double) +struct _NoLargeArgument {}; +// Skip checks for NaN, Inf, and other special values +struct _NoSpecialCases {}; +struct _NoExceptions {}; // Disable floating-point exception tracking +// Use faster, less accurate algorithms (typ. 
1-4 ULP vs 1.0 ULP) +struct _LowAccuracy {}; + +// Convenience constants for cleaner API +constexpr auto kNoLargeArgument = _NoLargeArgument{}; +constexpr auto kNoSpecialCases = _NoSpecialCases{}; +constexpr auto kNoExceptions = _NoExceptions{}; +constexpr auto kLowAccuracy = _LowAccuracy{}; + +// Subnormal (denormal) number handling modes +// Controls how the CPU handles numbers smaller than the minimum normalized +// value +struct Subnormal { + struct _DAZ {}; // Denormals Are Zero: treat subnormals as zero on input + struct _FTZ {}; // Flush To Zero: round subnormal results to zero + struct _IEEE754 { + }; // Strict IEEE 754 compliance: handle subnormals correctly + + static constexpr auto kDAZ = _DAZ{}; + static constexpr auto kFTZ = _FTZ{}; + static constexpr auto kIEEE754 = _IEEE754{}; +}; + +// Floating-point exception flags +// These match the standard C library FE_* macros +class FPExceptions { + public: + static constexpr auto kNone = 0; +// guard against missing macros on some platforms +// (e.g. 
Emscripten) +#ifdef FE_INVALID + static constexpr auto kInvalid = FE_INVALID; +#else + static constexpr auto kInvalid = 0; +#endif +#ifdef FE_DIVBYZERO + static constexpr auto kDivByZero = FE_DIVBYZERO; +#else + static constexpr auto kDivByZero = 0; +#endif +#ifdef FE_OVERFLOW + static constexpr auto kOverflow = FE_OVERFLOW; +#else + static constexpr auto kOverflow = 0; +#endif +#ifdef FE_UNDERFLOW + static constexpr auto kUnderflow = FE_UNDERFLOW; +#else + static constexpr auto kUnderflow = 0; +#endif + static constexpr auto kAll = kInvalid | kDivByZero | kOverflow | kUnderflow; + + void Raise(int errors) noexcept { mask_ |= errors; } + + protected: + void Load() noexcept { loaded_ = std::fegetexceptflag(&saved_, kAll) == 0; } + + ~FPExceptions() noexcept { + if (loaded_) { + std::fesetexceptflag(&saved_, kAll); + } + if (mask_ != kNone) { + std::feraiseexcept(mask_); + } + } + + private: + bool loaded_ = false; + int mask_ = kNone; + std::fexcept_t saved_; +}; + +/** + * @brief RAII floating-point precision control class + * + * The Precise class provides automatic management of floating-point + * environment settings during its lifetime. It uses RAII principles to save + * the current floating-point state on construction and restore it on + * destruction. + * + * The class is configured using variadic template arguments that specify + * the desired floating-point behavior through tag types. + * + * **IMPORTANT PERFORMANCE NOTE**: Create the Precise object BEFORE loops, + * not inside them. The constructor and destructor have overhead from saving + * and restoring floating-point state, so it should be done once per + * computational scope, not per iteration. 
+ * + * @tparam Args Variadic template arguments for configuration flags + * + * Configuration options: + * - kLowAccuracy: Use faster algorithms with ~1-4 ULP error (default: high + * accuracy ~1.0 ULP) + * - kNoLargeArgument: Skip extended precision reduction for large arguments + * - kNoSpecialCases: Skip NaN/Inf handling (assumes finite inputs) + * - kNoExceptions: Disable FP exception tracking for better performance + * - Subnormal::kDAZ/kFTZ: Flush subnormals to zero for performance + * - Subnormal::kIEEE754: Strict IEEE 754 compliance (default if DAZ/FTZ not + * specified) + * + * @example + * ```cpp + * using namespace hwy::HWY_NAMESPACE; + * using namespace npsr; + * using namespace npsr::HWY_NAMESPACE; + * + * // Configure for maximum performance with reduced accuracy + * Precise precise = {kLowAccuracy, kNoSpecialCases, kNoLargeArgument}; + * + * const ScalableTag d; + * using V = Vec; + * + * for (size_t i = 0; i < n; i += Lanes(d)) { + * V input = LoadU(d, &input_data[i]); + * V result = Sin(precise, input); // Uses configured precision + * StoreU(result, d, &output_data[i]); + * } + * ``` + */ +template +class Precise : public FPExceptions { + public: + // Default constructor saves current FP state + Precise() noexcept { + // Save exception flags unless disabled + if constexpr (!kNoExceptions) { + FPExceptions::Load(); + } + } + + // Variadic constructor for tag-based configuration + template + Precise(T1&& arg1, Rest&&... 
rest) noexcept : Precise() { + // Tags are processed at compile time via template parameters + // This constructor exists to enable Precise{tag1, tag2, ...} syntax + } + + // Compile-time configuration queries + // These allow algorithms to optimize based on precision requirements + static constexpr bool kNoExceptions = (is_same_v<_NoExceptions, Args> || ...); + static constexpr bool kNoLargeArgument = + (is_same_v<_NoLargeArgument, Args> || ...); + static constexpr bool kNoSpecialCases = + (is_same_v<_NoSpecialCases, Args> || ...); + static constexpr bool kLowAccuracy = (is_same_v<_LowAccuracy, Args> || ...); + + // Derived flags (defaults when not explicitly specified) + static constexpr bool kHighAccuracy = !kLowAccuracy; + static constexpr bool kLargeArgument = !kNoLargeArgument; + static constexpr bool kSpecialCases = !kNoSpecialCases; + static constexpr bool kExceptions = !kNoExceptions; + + // Subnormal handling configuration + static constexpr bool kDAZ = (is_same_v || ...); + static constexpr bool kFTZ = (is_same_v || ...); + static constexpr bool _kIEEE754 = + (is_same_v || ...); + + // Ensure IEEE754 mode is exclusive with DAZ/FTZ + static_assert(!_kIEEE754 || !(kDAZ || kFTZ), + "IEEE754 mode cannot be used " + "with Denormals Are Zero (DAZ) or Flush To Zero (FTZ) " + "subnormal handling"); + + // Default to IEEE754 if no subnormal mode specified + static constexpr bool kIEEE754 = _kIEEE754 || !(kDAZ || kFTZ); +}; // class Precise +
// Deduction guides for convenient construction +
// Enable Precise{} with no arguments +Precise() -> Precise<>; +// Enable Precise{tag1, tag2, ...} syntax +template +Precise(T1&&, Rest&&...) 
-> Precise, std::decay_t...>; + +} // namespace npsr +#endif // NPSR_PRECISE_H_ diff --git a/npsr/trig/extended-inl.h b/npsr/trig/extended-inl.h new file mode 100644 index 0000000..cdce91a --- /dev/null +++ b/npsr/trig/extended-inl.h @@ -0,0 +1,312 @@ +#if defined(NPSR_TRIG_EXTENDED_INL_H_) == defined(HWY_TARGET_TOGGLE) // NOLINT +#ifdef NPSR_TRIG_EXTENDED_INL_H_ +#undef NPSR_TRIG_EXTENDED_INL_H_ +#else +#define NPSR_TRIG_EXTENDED_INL_H_ +#endif + +#include "npsr/hwy.h" +#include "npsr/trig/data/data.h" +#include "npsr/trig/low-inl.h" // Operation + +HWY_BEFORE_NAMESPACE(); +namespace npsr::HWY_NAMESPACE::trig { + +template +NPSR_INTRIN V Extended(V x) { + using namespace hn; + namespace data = ::npsr::trig::data; + using hwy::ExponentBits; + using hwy::MantissaBits; + using hwy::MantissaMask; + using hwy::SignMask; + + using D = DFromV; + using DI = RebindToSigned; + using DU = RebindToUnsigned; + using VI = Vec; + using VU = Vec; + using T = TFromV; + using TU = TFromV; + const D d; + const DI di; + const DU du; + + constexpr bool kIsSingle = std::is_same_v; + + // ============================================================================= + // PHASE 1: Table Lookup for Reduction Constants + // ============================================================================= + // Each table entry contains 3 consecutive values [high, mid, low] + // providing ~96-bit (F32) or ~192-bit (F64) precision for (4/π) × 2^exp + VU u_exponent = GetBiasedExponent(x); + VI i_table_idx = + BitCast(di, Add(ShiftLeft<1>(u_exponent), u_exponent)); // × 2 + 1 = × 3 + + // Gather three parts of (4/π) × 2^exp from precomputed table + // Generated by npsr/trig/data/reduction.h.sol + VU u_p_hi = GatherIndex(du, data::kLargeReductionTable, i_table_idx); + VU u_p_med = GatherIndex(du, data::kLargeReductionTable + 1, i_table_idx); + VU u_p_lo = GatherIndex(du, data::kLargeReductionTable + 2, i_table_idx); + + // ============================================================================= 
+ // PHASE 2: Extract and Normalize Mantissa + // ============================================================================= + + // Extract and Normalize Mantissa + V abx = Abs(x); + VU u_input = BitCast(du, abx); + VU u_significand = And(u_input, Set(du, MantissaMask())); + // Add implicit leading 1 bit + VU u_integer_bit = + Or(u_significand, Set(du, static_cast(1) << MantissaBits())); + VU u_mantissa = Or(u_significand, u_integer_bit); + + // Split mantissa into halves for extended precision multiplication + // F32: 16-bit halves, F64: 32-bit halves + constexpr int kHalfShift = (sizeof(T) / 2) * 8; + VU u_low_mask = Set(du, (static_cast(1) << kHalfShift) - 1); + VU u_m0 = And(u_mantissa, u_low_mask); + VU u_m1 = ShiftRight(u_mantissa); + + // Split reduction constants into halves + VU u_p0 = And(u_p_lo, u_low_mask); + VU u_p1 = ShiftRight(u_p_lo); + VU u_p2 = And(u_p_med, u_low_mask); + VU u_p3 = ShiftRight(u_p_med); + VU u_p4 = And(u_p_hi, u_low_mask); + VU u_p5 = ShiftRight(u_p_hi); + + // ============================================================================= + // PHASE 3: Extended Precision Multiplication + // ============================================================================= + // mantissa × (4/π × 2^exp) using half-word multiplications + // F32: 16×16→32 bit, F64: 32×32→64 bit multiplications + + // Products with highest precision part + VU u_m04 = Mul(u_m0, u_p4); + VU u_m05 = Mul(u_m0, u_p5); + VU u_m14 = Mul(u_m1, u_p4); + // Omit u_m1 × u_p5 to prevent overflow + + // Products with medium precision part + VU u_m02 = Mul(u_m0, u_p2); + VU u_m03 = Mul(u_m0, u_p3); + VU u_m12 = Mul(u_m1, u_p2); + VU u_m13 = Mul(u_m1, u_p3); + + // Products with lowest precision part + VU u_m01 = Mul(u_m0, u_p1); + VU u_m10 = Mul(u_m1, u_p0); + VU u_m11 = Mul(u_m1, u_p1); + + // ============================================================================= + // PHASE 4: Carry Propagation and Result Assembly + // 
============================================================================= + // Extract carry bits from each product + VU u_carry04 = ShiftRight(u_m04); + VU u_carry02 = ShiftRight(u_m02); + VU u_carry03 = ShiftRight(u_m03); + VU u_carry01 = ShiftRight(u_m01); + VU u_carry10 = ShiftRight(u_m10); + + // Extract lower halves + VU u_low04 = And(u_m04, u_low_mask); + VU u_low02 = And(u_m02, u_low_mask); + VU u_low05 = And(u_m05, u_low_mask); + VU u_low03 = And(u_m03, u_low_mask); + + // Column-wise accumulation (Intel SVML pattern) + VU u_col3 = Add(u_low05, Add(u_m14, u_carry04)); + VU u_col2 = Add(u_low04, Add(u_m13, u_carry03)); + VU u_col1 = Add(u_low02, Add(u_m11, u_carry01)); + VU u_col0 = Add(u_low03, Add(u_m12, u_carry02)); + + // Carry propagation through columns + VU u_sum0 = Add(u_carry10, u_col1); + VU u_carry_final0 = ShiftRight(u_sum0); + VU u_sum1 = Add(u_carry_final0, u_col0); + VU u_carry_final1 = ShiftRight(u_sum1); + VU u_sum1_shifted = ShiftLeft(u_sum1); + VU u_sum2 = Add(u_carry_final1, u_col2); + VU u_carry_final2 = ShiftRight(u_sum2); + VU u_sum3 = Add(u_carry_final2, u_col3); + + // Assemble final result + VU u_result0 = And(u_sum0, u_low_mask); + VU u_result2 = And(u_sum2, u_low_mask); + VU u_result3 = ShiftLeft(u_sum3); + + VU u_n_hi = Add(u_result3, u_result2); + VU u_n_lo = Add(u_sum1_shifted, u_result0); + + // ============================================================================= + // PHASE 5: Extract Quotient and Fractional Parts + // ============================================================================= + + // Extract integer quotient. 9 for F32, 12 for F64 + constexpr int kQuotientShift = ExponentBits() + 1; + VU u_shifted_n = ShiftRight(u_n_hi); + + // fractional shifts derived from magic constants + // F32: 5, 18, 14 (sum = 37, total with quotient = 46 = 2×23) + // F64: 28, 24, 40 (sum = 92, total with quotient = 104 = 2×52) + constexpr int kFracLowShift = kIsSingle ? 
5 : 28; + constexpr int kFracMidShift = kIsSingle ? 18 : 24; + constexpr int kFracHighShift = kIsSingle ? 14 : 40; + + // Verify total shift constraint + constexpr int kTotalShift = + kQuotientShift + kFracLowShift + kFracMidShift + kFracHighShift; + static_assert(kTotalShift == (kIsSingle ? 46 : 104), + "Total shift must equal 2×mantissa_bits"); + + // Extract fractional parts + constexpr TU kFracMidMask = (static_cast(1) << kFracMidShift) - 1; + VU u_frac_low_bits = And(u_n_lo, Set(du, kFracMidMask)); + VU u_shifted_sig_lo = ShiftLeft(u_frac_low_bits); + VU u_frac_mid_bits = ShiftRight(u_n_lo); + constexpr TU kFracHighMask = + (static_cast(1) << (kFracHighShift - kFracLowShift)) - 1; + VU u_frac_high_bits = And(u_n_hi, Set(du, kFracHighMask)); + + // ============================================================================= + // PHASE 6: Conversion to Floating Point + // ============================================================================= + // magic constants for branchless int→float conversion + // Handle sign bit + VU u_sign_bit = And(BitCast(du, x), Set(du, SignMask())); + VU u_exponent_part = + Xor(u_sign_bit, BitCast(du, Set(d, static_cast(1.0)))); + VU u_quotient_signed = Or(u_shifted_n, u_exponent_part); + + // Magic number conversion for quotient + V shifter = Set(d, kIsSingle ? 0x1.8p15f : 0x1.8p43); + V integer_part = Add(shifter, BitCast(d, u_quotient_signed)); + + V n_hi = Sub(integer_part, shifter); + n_hi = Sub(BitCast(d, u_quotient_signed), n_hi); + + // constants for fractional parts + VU u_epsilon = BitCast(du, Set(d, kIsSingle ? 0x1p-23f : 0x1p-52)); + VU u_exp_mid = Xor(u_sign_bit, u_epsilon); + VU u_shifted_sig_mid = + Or(ShiftLeft(u_frac_high_bits), u_frac_mid_bits); + VU u_frac_mid_combined = Or(u_shifted_sig_mid, u_exp_mid); + V shifter_mid = BitCast(d, u_exp_mid); + V n_med = Sub(BitCast(d, u_frac_mid_combined), shifter_mid); + + VU u_epsilon_low = BitCast(du, Set(d, kIsSingle ? 
0x1p-46f : 0x1p-104)); + VU u_exp_low = Xor(u_sign_bit, u_epsilon_low); + VU u_frac_low_combined = Or(u_shifted_sig_lo, u_exp_low); + + V exp_low = BitCast(d, u_exp_low); + V frac_low_combined = BitCast(d, u_frac_low_combined); + + V n = Add(n_hi, n_med); + V n_lo = Sub(n_hi, n); + n_lo = Add(n_med, n_lo); + n_lo = Add(n_lo, Sub(frac_low_combined, exp_low)); + + // ============================================================================= + // PHASE 7: Convert to Radians + // ============================================================================= + + // Multiply by π with error compensation (Cody-Waite multiplication) + constexpr auto kPiMul2 = data::kPiMul2; + const V pi2_hi = Set(d, kPiMul2[0]); + const V pi2_med = Set(d, kPiMul2[1]); + + V r = Mul(pi2_hi, n); + V r_lo, r_w0, r_w1; + if constexpr (!kNativeFMA && kIsSingle) { + using DW = RepartitionToWide; + using DH = Half; + using VW = Vec; + const DW dw; + const DH dh; + VW pi2_whi = Set(dw, data::kPiMul2[0]); + VW r0 = Mul(pi2_whi, PromoteUpperTo(dw, n)); + VW r1 = Mul(pi2_whi, PromoteLowerTo(dw, n)); + VW r_lo_w0 = Sub(r0, PromoteUpperTo(dw, r)); + VW r_lo_w1 = Sub(r1, PromoteLowerTo(dw, r)); + r_lo = Combine(d, DemoteTo(dh, r_lo_w0), DemoteTo(dh, r_lo_w1)); + r_w0 = BitCast(d, r0); + r_w1 = BitCast(d, r1); + } else { + r_lo = MulSub(pi2_hi, n, r); + r_lo = MulAdd(pi2_med, n, r_lo); + } + r_lo = MulAdd(pi2_hi, n_lo, r_lo); + + // ============================================================================= + // PHASE 8: Small Argument Handling + // ============================================================================= + + const V min_input = Set(d, static_cast(0x1p-20)); + const auto ismall_arg = Gt(min_input, abx); + + r = IfThenElse(ismall_arg, x, r); + r_lo = IfThenElse(ismall_arg, Zero(d), r_lo); + V r2 = Mul(r, r); + + // ============================================================================= + // PHASE 9: Table Lookup + // 
============================================================================= + + // Generated by npsr/trig/data/approx.h.sol + const T *table_base = OP == Operation::kCos ? data::kCosApproxTable + : data::kSinApproxTable; + + // Calculate table index + VU u_n_mask = Set(du, kIsSingle ? 0xFF : 0x1FF); + VU u_index = And(BitCast(du, integer_part), u_n_mask); + VI u_table_index = BitCast(di, ShiftLeft<2>(u_index)); + + const V deriv_hi = GatherIndex(d, table_base, u_table_index); + const V sigma = GatherIndex(d, table_base + 1, u_table_index); + const V func_hi = GatherIndex(d, table_base + 2, u_table_index); + const V func_lo = GatherIndex(d, table_base + 3, u_table_index); + const V deriv = Add(deriv_hi, sigma); + + // ============================================================================= + // PHASE 10: Final Assembly + // ============================================================================= + V res_lo = NegMulAdd(func_hi, r, deriv); + res_lo = MulAdd(res_lo, r_lo, func_lo); + V res_hi_lo = MulAdd(sigma, r, func_hi); + V res_hi = MulAdd(deriv_hi, r, res_hi_lo); + + V sum_cor = MulAdd(sigma, r, Sub(func_hi, res_hi_lo)); + V deriv_hi_r_cor = MulAdd(deriv_hi, r, Sub(res_hi_lo, res_hi)); + deriv_hi_r_cor = Add(deriv_hi_r_cor, sum_cor); + res_lo = Add(res_lo, deriv_hi_r_cor); + + // Polynomial corrections + V s2 = Set(d, kIsSingle ? 0x1.1110b8p-7f : 0x1.1110fabb3551cp-7); + V s1 = Set(d, kIsSingle ? -0x1.555556p-3f : -0x1.5555555554448p-3); + V sin_poly = MulAdd(s2, r2, s1); + sin_poly = Mul(sin_poly, r); + sin_poly = Mul(sin_poly, r2); + + V c1 = Set(d, kIsSingle ? 
0x1.5554f8p-5f : 0x1.5555555554ccfp-5); + const V neg_half = Set(d, static_cast(-0.5)); + V cos_poly; + if constexpr (kIsSingle) { + cos_poly = MulAdd(c1, r2, neg_half); + } else { + V c2 = Set(d, -0x1.6c16ab163b2d7p-10); + cos_poly = MulAdd(c2, r2, c1); + cos_poly = MulAdd(cos_poly, r2, neg_half); + } + cos_poly = Mul(cos_poly, r2); + + res_lo = MulAdd(sin_poly, deriv, res_lo); + res_lo = MulAdd(cos_poly, func_hi, res_lo); + return Add(res_hi, res_lo); +} +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace npsr::HWY_NAMESPACE::trig +HWY_AFTER_NAMESPACE(); + +#endif // NPSR_TRIG_EXTENDED_INL_H_ diff --git a/npsr/trig/high-inl.h b/npsr/trig/high-inl.h new file mode 100644 index 0000000..68cdd9f --- /dev/null +++ b/npsr/trig/high-inl.h @@ -0,0 +1,299 @@ +#if defined(NPSR_TRIG_HIGH_INL_H_) == defined(HWY_TARGET_TOGGLE) // NOLINT +#ifdef NPSR_TRIG_HIGH_INL_H_ +#undef NPSR_TRIG_HIGH_INL_H_ +#else +#define NPSR_TRIG_HIGH_INL_H_ +#endif + +#include "npsr/hwy.h" +#include "npsr/lut-inl.h" +#include "npsr/trig/data/data.h" +#include "npsr/trig/low-inl.h" // Operation + +HWY_BEFORE_NAMESPACE(); + +namespace npsr::HWY_NAMESPACE::trig { + +template )> +NPSR_INTRIN V High(V x) { + using namespace hn; + namespace data = ::npsr::trig::data; + + using T = TFromV; + using D = DFromV; + using DU = RebindToUnsigned; + using DH = Half; + using DW = RepartitionToWide; + using VW = Vec; + + const D d; + const DU du; + const DH dh; + const DW dw; + // Load frequently used constants as vector registers + const V abs_mask = BitCast(d, Set(du, 0x7FFFFFFF)); + const V x_abs = And(abs_mask, x); + const V x_sign = AndNot(x_abs, x); + + // Transform cosine to sine using identity: cos(x) = sin(x + π/2) + const V half_pi = Set(d, data::kHalfPi); + V x_trans = x_abs; + if constexpr (OP == Operation::kCos) { + x_trans = Add(x_abs, half_pi); + } + // check zero input/subnormal for cosine (cos(~0) = 1) + const auto is_cos_near_zero = Eq(x_trans, half_pi); + + // Compute N = 
round(input/π) + const V magic_round = Set(d, 0x1.8p23f); + V n_biased = MulAdd(x_trans, Set(d, data::kInvPi), magic_round); + V n = Sub(n_biased, magic_round); + + // Adjust quotient for cosine (accounts for π/2 phase shift) + if constexpr (OP == Operation::kCos) { + // For cosine, we computed N = round((x + π/2)/π) but need N' for x: + // N = round((x + π/2)/π) = round(x/π + 0.5) + // This is often 1 more than round(x/π), so we subtract 0.5: + // N' = N - 0.5 + n = Sub(n, Set(d, 0.5f)); + } + auto WideCal = [](const VW &nh, const VW &xh_abs) -> VW { + const DFromV dw; + constexpr auto kPiPrec35 = data::kPiPrec35; + VW r = NegMulAdd(nh, Set(dw, kPiPrec35[0]), xh_abs); + r = NegMulAdd(nh, Set(dw, kPiPrec35[1]), r); + VW r2 = Mul(r, r); + + // Polynomial coefficients for sin(r) approximation on [-π/2, π/2] + const VW c9 = Set(dw, 0x1.5dbdf0e4c7deep-19); + const VW c7 = Set(dw, -0x1.9f6ffeea73463p-13); + const VW c5 = Set(dw, 0x1.110ed3804ca96p-7); + const VW c3 = Set(dw, -0x1.55554bc836587p-3); + VW poly = MulAdd(c9, r2, c7); + poly = MulAdd(r2, poly, c5); + poly = MulAdd(r2, poly, c3); + poly = Mul(poly, r2); + poly = MulAdd(r, poly, r); + return poly; + }; + + VW poly_lo = WideCal(PromoteLowerTo(dw, n), PromoteLowerTo(dw, x_abs)); + VW poly_up = WideCal(PromoteUpperTo(dw, n), PromoteUpperTo(dw, x_abs)); + + V poly = Combine(d, DemoteTo(dh, poly_up), DemoteTo(dh, poly_lo)); + // Extract octant sign information from quotient and flip the sign bit + poly = Xor(poly, + BitCast(d, ShiftLeft(BitCast(du, n_biased)))); + if constexpr (OP == Operation::kCos) { + poly = IfThenElse(is_cos_near_zero, Set(d, 1.0f), poly); + } else { + // Restore original sign for sine (odd function) + poly = Xor(poly, x_sign); + } + return poly; +} +/** + * This function computes sin(x) or cos(x) for |x| < 2^24 using the Cody-Waite + * reduction algorithm combined with table lookup and polynomial approximation, + * achieves < 1 ULP error for |x| < 2^24. + * + * Algorithm Overview: + * 1. 
Range Reduction: Reduces input x to r where |r| < π/16 + * - Computes n = round(x * 16/π) and r = x - n*π/16 + * - Uses multi-precision arithmetic (3 parts of π/16) for accuracy + * + * 2. Table Lookup: Retrieves precomputed sin(n*π/16) and cos(n*π/16) + * - Includes high and low precision parts for cos values + * + * 3. Polynomial Approximation: Computes sin(r) and cos(r) + * - sin(r) ≈ r * (1 + r²*P_sin(r²)) where P_sin is a minimax polynomial + * - cos(r) ≈ 1 + r²*P_cos(r²) where P_cos is a minimax polynomial + * + * 4. Reconstruction: Applies angle addition formulas + * - sin(x) = sin(n*π/16 + r) = sin(n*π/16)*cos(r) + cos(n*π/16)*sin(r) + * - cos(x) = cos(n*π/16 + r) = cos(n*π/16)*cos(r) - sin(n*π/16)*sin(r) + * + */ +template )> +NPSR_INTRIN V High(V x) { + using namespace hn; + namespace data = ::npsr::trig::data; + using T = TFromV; + using D = DFromV; + using DU = RebindToUnsigned; + using VU = Vec; + + const D d; + const DU du; + + // Step 1: Range reduction - find n such that x = n*(π/16) + r, where |r| < + // π/16 + V magic = Set(d, 0x1.8p52); + V n_biased = MulAdd(x, Set(d, data::k16DivPi), magic); + V n = Sub(n_biased, magic); + + // Extract integer index for table lookup (n mod 16) + VU n_int = BitCast(du, n_biased); + VU table_idx = And(n_int, Set(du, 0xF)); // Mask to get n mod 16 + + // Step 2: Load precomputed sine/cosine values for n mod 16 + V sin_hi, cos_hi, cos_lo; + kKPi16Table.Load(table_idx, sin_hi, cos_hi, cos_lo); + // Note: cos_lo and sin_lo are packed together (32 bits each) to save memory. + // cos_lo can be used as-is since it's in the upper bits, sin_lo needs + // extraction. The precision loss is negligible for the final result. + // see data/lut-inl.h.sol for the table generation code. 
+ V sin_lo = BitCast(d, ShiftLeft<32>(BitCast(du, cos_lo))); + + // Step 3: Multi-precision computation of remainder r + // r = x - n*(π/16)_high + constexpr auto kPiDiv16Prec29 = data::kPiDiv16Prec29; + V r_hi = NegMulAdd(n, Set(d, kPiDiv16Prec29[0]), x); + if constexpr (!kNativeFMA) { + // For F64, we need to handle the low precision part separately + r_hi = NegMulAdd(n, Set(d, kPiDiv16Prec29[3]), r_hi); + } + const V pi16_med = Set(d, kPiDiv16Prec29[1]); + const V pi16_lo = Set(d, kPiDiv16Prec29[2]); + V r_med = NegMulAdd(n, pi16_med, r_hi); + V r = NegMulAdd(n, pi16_lo, r_med); + + // Compute low precision part of r for extra accuracy + V term = NegMulAdd(pi16_med, n, Sub(r_hi, r_med)); + V r_lo = MulAdd(pi16_lo, n, Sub(r, r_med)); + r_lo = Sub(term, r_lo); + + // Step 4: Polynomial approximation + V r2 = Mul(r, r); + + // Minimax polynomial for (sin(r)/r - 1) + // sin(r)/r = 1 - r²/3! + r⁴/5! - r⁶/7! + ... + // This polynomial computes the terms after 1 + V sin_poly = Set(d, 0x1.71c97d22a73ddp-19); + sin_poly = MulAdd(sin_poly, r2, Set(d, -0x1.a01a00ed01edep-13)); + sin_poly = MulAdd(sin_poly, r2, Set(d, 0x1.111111110e99dp-7)); + sin_poly = MulAdd(sin_poly, r2, Set(d, -0x1.5555555555555p-3)); + + // Minimax polynomial for (cos(r) - 1)/r² + // cos(r) = 1 - r²/2! + r⁴/4! - r⁶/6! + ... 
+ // This polynomial computes (cos(r) - 1)/r² + V cos_poly = Set(d, 0x1.9ffd7d9d749bcp-16); + cos_poly = MulAdd(cos_poly, r2, Set(d, -0x1.6c16c075d73f8p-10)); + cos_poly = MulAdd(cos_poly, r2, Set(d, 0x1.555555554e8d6p-5)); + cos_poly = MulAdd(cos_poly, r2, Set(d, -0x1.ffffffffffffcp-2)); + + // Step 5: Reconstruction using angle addition formulas + // + // Mathematical equivalence between traditional and SVML approaches: + // + // Traditional angle addition: + // sin(a+r) = sin(a)*cos(r) + cos(a)*sin(r) + // cos(a+r) = cos(a)*cos(r) - sin(a)*sin(r) + // + // Where for small r (|r| < π/16): + // cos(r) ≈ 1 + r²*cos_poly + // sin(r) ≈ r*(1 + sin_poly) ≈ r + r*sin_poly + // + // SVML's efficient linear approximation: + // sin(a+r) ≈ sin(a) + cos(a)*r + polynomial_corrections + // cos(a+r) ≈ cos(a) - sin(a)*r + polynomial_corrections + // + // This is mathematically equivalent but computationally more efficient: + // - Uses first-order linear terms directly: Sh + Ch*R, Ch - R*Sh + // - Applies higher-order polynomial corrections separately + // - Fewer multiplications and better numerical stability + // + // Implementation follows SVML structure: + // sin(n*π/16 + r) = sin_table + cos_table*remainder (+ corrections) + // cos(n*π/16 + r) = cos_table - sin_table*remainder (+ corrections) + V result; + if constexpr (OP == Operation::kCos) { + // Cosine reconstruction: cos_table - sin_table*remainder + // Equivalent to: cos(a)*cos(r) - sin(a)*sin(r) but more efficient + V res_hi = NegMulAdd(r, sin_hi, cos_hi); // cos_hi - r*sin_hi + + // This captures the precision lost in the main computation + V r_sin_hi = Sub(cos_hi, res_hi); // Extract high part of multiplication + + // Handles rounding errors and adds sin_low contribution + V r_sin_low = MulSub(r, sin_hi, r_sin_hi); // Compute multiplication error + V sin_low_corr = MulAdd(r, sin_lo, r_sin_low); // Add sin_low term + + // This is used to apply the low-precision remainder correction + V sin_cos_r = MulAdd(r, cos_hi, 
sin_hi); + + // Main low precision correction: cos_low - r_low*(sin_table + cos_table*r) + // Applies the effect of the low-precision remainder on the final result + V low_corr = NegMulAdd(r_lo, sin_cos_r, cos_lo); + + // Polynomial corrections using the remainder + V r_sin = Mul(r, sin_hi); // For polynomial application + + // Apply polynomial corrections: cos_table*cos_poly - r*sin_table*sin_poly + // This handles the higher-order terms from cos(r) and sin(r) expansions + V poly_corr = Mul(cos_hi, cos_poly); // cos(a) * (cos(r)-1)/r² + // - sin(a)*r * (sin(r)/r-1) + poly_corr = NegMulAdd(r_sin, sin_poly, poly_corr); + + // Combine all low precision corrections + V total_low = Sub(low_corr, sin_low_corr); + + // Final assembly: main_term + r²*polynomial_corrections + low_corrections + result = MulAdd(r2, poly_corr, total_low); + result = Add(res_hi, result); + + } else { + // Sine reconstruction: sin_table + cos_table*remainder + // Equivalent to: sin(a)*cos(r) + cos(a)*sin(r) but more efficient + V res_hi = MulAdd(r, cos_hi, sin_hi); // sin_hi + r*cos_hi + + // This captures the precision lost in the main computation + V r_cos_hi = Sub(res_hi, sin_hi); // Extract high part of multiplication + + // Handles rounding errors and adds cos_low contribution + V r_cos_low = MulSub(r, cos_hi, r_cos_hi); // Compute multiplication error + V cos_low_corr = MulAdd(r, cos_lo, r_cos_low); // Add cos_low term + + // Intermediate term for r_low correction: cos_table - sin_table*r + // This is used to apply the low-precision remainder correction + V cos_r_sin = NegMulAdd(r, sin_hi, cos_hi); + + // Main low precision correction: sin_low - r_low*(cos_table - sin_table*r) + // Applies the effect of the low-precision remainder on the final result + V low_corr = MulAdd(r_lo, cos_r_sin, sin_lo); + // Polynomial corrections using the remainder + V r_cos = Mul(r, cos_hi); // For polynomial application + + // Apply polynomial corrections: sin_table*cos_poly + r*cos_table*sin_poly + // This 
handles the higher-order terms from cos(r) and sin(r) expansions + V poly_corr = Mul(sin_hi, cos_poly); // sin(a) * (cos(r)-1)/r² + poly_corr = + MulAdd(r_cos, sin_poly, poly_corr); // + cos(a)*r * (sin(r)/r-1) + + // Combine all low precision corrections + V total_low = Add(low_corr, cos_low_corr); + // Final assembly: main_term + r²*polynomial_corrections + low_corrections + result = MulAdd(r2, poly_corr, total_low); + result = Add(res_hi, result); + } + + // Apply final sign correction same for both sine and cosine + // Both functions change sign every π radians, corresponding to bit 4 of n_int + // This unified approach works because: + // - sin(x + π) = -sin(x) + // - cos(x + π) = -cos(x) + VU x_sign_int = ShiftLeft<63>(BitCast(du, x)); + // XOR with quadrant info in n_biased + VU combined = Xor(BitCast(du, n_biased), ShiftLeft<4>(x_sign_int)); + // Extract final sign + VU sign = ShiftRight<4>(combined); + sign = ShiftLeft<63>(sign); + result = Xor(result, BitCast(d, sign)); // Apply sign flip + return result; +} +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace npsr::HWY_NAMESPACE::trig + +HWY_AFTER_NAMESPACE(); + +#endif // NPSR_TRIG_HIGH_INL_H_ diff --git a/npsr/trig/inl.h b/npsr/trig/inl.h new file mode 100644 index 0000000..2e22bd7 --- /dev/null +++ b/npsr/trig/inl.h @@ -0,0 +1,154 @@ +// Main trigonometric function dispatcher for Highway SIMD library +// This file provides the public API for sine and cosine functions with +// configurable precision, special case handling, and algorithm selection +// +// The implementation automatically selects between three algorithms: +// 1. Low precision: ~1-4 ULP error, fastest +// 2. High precision: ~1 ULP error, moderate speed +// 3. 
Extended precision: Exact for |x| > 2^24 (float) or 2^53 (double) + +#if defined(NPSR_TRIG_INL_H_) == defined(HWY_TARGET_TOGGLE) // NOLINT +#ifdef NPSR_TRIG_INL_H_ +#undef NPSR_TRIG_INL_H_ +#else +#define NPSR_TRIG_INL_H_ +#endif + +#include "npsr/hwy.h" +#include "npsr/precise.h" +#include "npsr/trig/extended-inl.h" // Payne-Hanek reduction for huge arguments +#include "npsr/trig/high-inl.h" // High precision with table lookup +#include "npsr/trig/low-inl.h" // Fast low precision implementation + +HWY_BEFORE_NAMESPACE(); + +namespace npsr::HWY_NAMESPACE::trig { +/** + * @brief Unified sine/cosine implementation with configurable precision + * + * This template function dispatches to the appropriate algorithm based on: + * - Precision requirements (Low vs High accuracy) + * - Input magnitude (standard vs extended precision for large arguments) + * - Special case handling (NaN, Inf) + * + * @tparam OP Operation type: kSin or kCos + * @tparam Prec Precise configuration class with accuracy/feature flags + * @tparam V Highway vector type + * + * @param prec Precise object controlling FP environment and exceptions + * @param x Input vector + * @return sin(x) or cos(x) depending on OP + * + * Algorithm selection: + * 1. If kLowAccuracy: Use Low<> (Cody-Waite with minimal polynomial) + * 2. Otherwise: Use High<> (π/16 reduction with table lookup) + * 3. 
If kLargeArgument and |x| > threshold: Override with Extended<> + * + * Thresholds for extended precision: + * - Float: |x| > 10,000 (empirically chosen for accuracy) + * - Double: |x| > 2^24 (16,777,216 - where 53-bit mantissa loses precision) + */ +template +NPSR_INTRIN V Trig(Prec &prec, V x) { + using namespace hwy::HWY_NAMESPACE; + constexpr bool kIsSingle = std::is_same_v, float>; + const DFromV d; + V ret; + // Step 1: Select base algorithm based on accuracy requirements + if constexpr (Prec::kLowAccuracy) { + // Low precision: Cody-Waite reduction with degree-9 polynomial + // Error: ~2 ULP and 3~ for non-fma + ret = Low(x); + } else { + // High precision: π/16 reduction with table lookup + polynomial + // Error: ~1 ULP + ret = High(x); + } + // Step 2: Handle special cases (NaN, Inf) if enabled + auto is_finite = IsFinite(x); + if constexpr (Prec::kSpecialCases) { + // IEEE 754 requires: sin(±∞) = NaN, cos(±∞) = NaN + ret = IfThenElse(is_finite, ret, NaN(d)); + // -0.0 should return -0.0 for sine + if constexpr (OP == Operation::kSin) { + ret = IfThenElse(Eq(x, Set(d, 0.0)), x, ret); + } + } + // Step 3: Handle very large arguments if enabled + // For |x| > threshold, standard algorithms lose precision due to + // catastrophic cancellation in x - n*π reduction + if constexpr (Prec::kLargeArgument) { + // Thresholds chosen based on when standard reduction loses accuracy: + // - Float: 10,000 is conservative but ensures < 1 ULP error + // - Double: 2^24 is where mantissa can't represent x and x-2π distinctly + auto has_large_arg = + And(Gt(Abs(x), Set(d, kIsSingle ? 
10000.0f : 16777216.0)), is_finite); + + // Extended precision is expensive, only use when necessary + if (HWY_UNLIKELY(!AllFalse(d, has_large_arg))) { + // Payne-Hanek reduction: Uses ~96-bit (float) or ~192-bit (double) + // precision for 4/π to maintain accuracy for huge arguments + ret = IfThenElse(has_large_arg, Extended(x), ret); + } + } + // Step 4: Raise invalid operation exception for infinity inputs + if constexpr (Prec::kExceptions) { + prec.Raise(!AllFalse(d, IsInf(x)) ? FPExceptions::kInvalid : 0); + } + return ret; +} + +} // namespace npsr::HWY_NAMESPACE::trig + +// Public API in the main npsr namespace +namespace npsr::HWY_NAMESPACE { + +/** + * @brief Compute sine of vector elements with configurable precision + * + * @tparam Prec Precise configuration (e.g., Precise{kLowAccuracy}) + * @tparam V Highway vector type + * @param prec Precise object managing FP environment + * @param x Input vector + * @return sin(x) for each element + * + * @example + * ```cpp + * Precise prec{ + * kLowAccuracy, kNoLargeArgument, kNoExceptions, kNoSpecialCases + * }; + * auto result = Sin(prec, input_vector); + * ``` + */ +template +NPSR_INTRIN V Sin(Prec &prec, V x) { + return trig::Trig(prec, x); +} + +/** + * @brief Compute cosine of vector elements with configurable precision + * + * @tparam Prec Precise configuration (e.g., Precise{kLowAccuracy}) + * @tparam V Highway vector type + * @param prec Precise object managing FP environment + * @param x Input vector + * @return cos(x) for each element + * + * @example + * ```cpp + * Precise prec{ + * kLowAccuracy, kNoLargeArgument, kNoExceptions, kNoSpecialCases + * }; + * auto result = Cos(prec, input_vector); + * ``` + */ +template +NPSR_INTRIN V Cos(Prec &prec, V x) { + return trig::Trig(prec, x); +} + +} // namespace npsr::HWY_NAMESPACE + +HWY_AFTER_NAMESPACE(); + +#endif // NPSR_TRIG_INL_H_ diff --git a/npsr/trig/low-inl.h b/npsr/trig/low-inl.h new file mode 100644 index 0000000..6da18ae --- /dev/null +++ 
b/npsr/trig/low-inl.h @@ -0,0 +1,153 @@ +#if defined(NPSR_TRIG_LOW_INL_H_) == defined(HWY_TARGET_TOGGLE) // NOLINT +#ifdef NPSR_TRIG_LOW_INL_H_ +#undef NPSR_TRIG_LOW_INL_H_ +#else +#define NPSR_TRIG_LOW_INL_H_ +#endif + +#include "npsr/hwy.h" +#include "npsr/trig/data/data.h" + +HWY_BEFORE_NAMESPACE(); + +namespace npsr::HWY_NAMESPACE::trig { + +enum class Operation { kSin = 0, kCos = 1 }; + +template )> +NPSR_INTRIN V PolyLow(V r, V r2) { + using namespace hn; + + const DFromV d; + constexpr bool kCos = OP == Operation::kCos; + const V c9 = Set(d, kCos ? 0x1.5d866ap-19f : 0x1.5dbdfp-19f); + const V c7 = Set(d, kCos ? -0x1.9f6d9ep-13 : -0x1.9f6ffep-13f); + const V c5 = Set(d, kCos ? 0x1.110ec8p-7 : 0x1.110eccp-7f); + const V c3 = Set(d, -0x1.55554cp-3f); + V poly = MulAdd(c9, r2, c7); + poly = MulAdd(r2, poly, c5); + poly = MulAdd(r2, poly, c3); + if constexpr (OP == Operation::kCos) { + // Although this path handles cosine, we have already transformed the + // input using the identity: cos(x) = sin(x + π/2) This means we're no + // longer directly evaluating a cosine Taylor series; instead, we evaluate + // the sine approximation polynomial at (x + π/2). + // + // The sine approximation has the general form: + // sin(r) ≈ r + r³ · P(r²) + // + // So, we compute: + // r³ = r · r² + // sin(r) ≈ r + r³ · poly + // + // This formulation preserves accuracy by computing the highest order + // terms last, which benefits from FMA to reduce rounding error. 
+ V r3 = Mul(r2, r); + poly = MulAdd(r3, poly, r); + } else { + poly = Mul(poly, r2); + poly = MulAdd(r, poly, r); + } + return poly; +} + +template )> +NPSR_INTRIN V PolyLow(V r, V r2) { + using namespace hn; + + const DFromV d; + const V c15 = Set(d, -0x1.9f1517e9f65fp-41); + const V c13 = Set(d, 0x1.60e6bee01d83ep-33); + const V c11 = Set(d, -0x1.ae6355aaa4a53p-26); + const V c9 = Set(d, 0x1.71de3806add1ap-19); + const V c7 = Set(d, -0x1.a01a019a659ddp-13); + const V c5 = Set(d, 0x1.111111110a573p-7); + const V c3 = Set(d, -0x1.55555555554a8p-3); + V poly = MulAdd(c15, r2, c13); + poly = MulAdd(r2, poly, c11); + poly = MulAdd(r2, poly, c9); + poly = MulAdd(r2, poly, c7); + poly = MulAdd(r2, poly, c5); + poly = MulAdd(r2, poly, c3); + return poly; +} + +template +NPSR_INTRIN V Low(V x) { + using namespace hn; + using hwy::SignMask; + namespace data = ::npsr::trig::data; + + const DFromV d; + const RebindToUnsigned du; + using T = TFromV; + // Load frequently used constants as vector registers + const V abs_mask = BitCast(d, Set(du, SignMask() - 1)); + const V x_abs = And(abs_mask, x); + const V x_sign = AndNot(x_abs, x); + + constexpr bool kIsSingle = std::is_same_v; + // Transform cosine to sine using identity: cos(x) = sin(x + π/2) + const V half_pi = Set(d, data::kHalfPi); + V x_trans = x_abs; + if constexpr (OP == Operation::kCos) { + x_trans = Add(x_abs, half_pi); + } + // check zero input/subnormal for cosine (cos(~0) = 1) + const auto is_cos_near_zero = Eq(x_trans, half_pi); + + // Compute N = round(x/π) using "magic number" technique + // and stores integer part in mantissa + const V magic_round = Set(d, kIsSingle ? 
0x1.8p23f : 0x1.8p52); + V n_biased = MulAdd(x_trans, Set(d, data::kInvPi), magic_round); + V n = Sub(n_biased, magic_round); + + // Adjust quotient for cosine (accounts for π/2 phase shift) + if constexpr (OP == Operation::kCos) { + // For cosine, we computed N = round((x + π/2)/π) but need N' for x: + // N = round((x + π/2)/π) = round(x/π + 0.5) + // This is often 1 more than round(x/π), so we subtract 0.5: + // N' = N - 0.5 + n = Sub(n, Set(d, static_cast(0.5))); + } + // Use Cody-Waite method with triple-precision PI + constexpr auto kPi = data::kPi; + + V r = NegMulAdd(n, Set(d, kPi[0]), x_abs); + r = NegMulAdd(n, Set(d, kPi[1]), r); + V r_lo = NegMulAdd(n, Set(d, kPi[2]), r); + if constexpr (!kNativeFMA) { + if (!kIsSingle) { + r = r_lo; + } + r_lo = NegMulAdd(n, Set(d, kPi[3]), r_lo); + } + + if constexpr (kIsSingle) { + r = r_lo; + } + V r2 = Mul(r, r); + V poly = PolyLow(r, r2); + + if constexpr (!kIsSingle) { + V r2_corr = Mul(r2, r_lo); + poly = MulAdd(r2_corr, poly, r_lo); + } + + // Extract octant sign information from quotient and flip the sign bit + poly = Xor(poly, + BitCast(d, ShiftLeft(BitCast(du, n_biased)))); + if constexpr (OP == Operation::kCos) { + poly = IfThenElse(is_cos_near_zero, Set(d, static_cast(1.0)), poly); + } else { + // Restore original sign for sine (odd function) + poly = Xor(poly, x_sign); + } + return poly; +} +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace npsr::HWY_NAMESPACE::trig + +HWY_AFTER_NAMESPACE(); + +#endif // NPSR_TRIG_LOW_INL_H_