DoxigAlpha

Modulus

A modulus, defining a finite field. All operations within the field are performed modulo this modulus, without heap allocations. `max_bits` is the maximum bit length of a modulus that this type can hold.

Fields of this type

Fields

#
zero:Fe
The neutral element.
v:FeUint
The modulus value.
rr:Fe
R^2 for the Montgomery representation.
m0inv:Limb
Negated inverse of the lowest limb of the modulus, modulo 2^t_bits (the Montgomery reduction constant).
leading:usize
Number of leading zero bits in the most significant limb of the modulus.

Actual size of the modulus, in bits.

Functions

#
bits
Actual size of the modulus, in bits.
one
Returns the element `1`.
fromUint
Creates a new modulus from a `Uint` value. The modulus must be odd and larger than 2.
fromPrimitive
Creates a new modulus from a primitive value. The modulus must be odd and larger than 2.
fromBytes
Creates a new modulus from a byte string.
toBytes
Serializes the modulus to a byte string.
rejectNonCanonical
Rejects field elements that are not in the canonical form.
add
Adds two field elements (mod m).
sub
Subtracts two field elements (mod m).
toMontgomery
Converts a field element to the Montgomery form.
fromMontgomery
Takes a field element out of the Montgomery form.
reduce
Reduces an arbitrary `Uint`, converting it to a field element.
mul
Multiplies two field elements.
sq
Squares a field element.
pow
Returns x^e (mod m) in constant time.
powPublic
Returns x^e (mod m), assuming that the exponent is public.
powWithEncodedExponent
Returns x^e (mod m), with the exponent provided as a byte string.
powWithEncodedPublicExponent
Returns x^e (mod m), the exponent being public and provided as a byte string.

Source

Implementation

#
/// Returns a type describing an odd modulus of at most `max_bits` bits, together with the
/// precomputed constants needed for Montgomery arithmetic in the field it defines.
/// No operation performs heap allocations.
pub fn Modulus(comptime max_bits: comptime_int) type {
    return struct {
        const Self = @This();

        /// A field element, representing a value within the field defined by this modulus.
        pub const Fe = Fe_(max_bits);

        const FeUint = Fe.FeUint;

        /// The neutral element for addition, sized to the active limbs of the modulus.
        zero: Fe,

        /// The modulus value (normalized: no leading zero limbs).
        v: FeUint,

        /// R^2 mod m for the Montgomery representation, where R = 2^(t_bits * limbs_count).
        rr: Fe,
        /// Negated inverse of the lowest limb of the modulus, modulo 2^t_bits.
        /// Montgomery reduction uses it to choose the multiple of the modulus to add.
        m0inv: Limb,
        /// Number of leading zero bits in the most significant limb of the modulus.
        leading: usize,

        // Number of active limbs in the modulus.
        fn limbs_count(self: Self) usize {
            return self.v.limbs_len;
        }

        /// Actual size of the modulus, in bits.
        pub fn bits(self: Self) usize {
            return self.limbs_count() * t_bits - self.leading;
        }

        /// Returns the element `1`, in the regular (non-Montgomery) representation.
        pub fn one(self: Self) Fe {
            var fe = self.zero;
            fe.v.limbs()[0] = 1;
            return fe;
        }

        /// Creates a new modulus from a `Uint` value.
        /// The modulus must be odd and larger than 2.
        ///
        /// Errors: `error.EvenModulus` if `v_` is even, `error.ModulusTooSmall` if it is 1.
        pub fn fromUint(v_: FeUint) InvalidModulusError!Self {
            if (!v_.isOdd()) return error.EvenModulus;

            var v = v_.normalize();
            const hi = v.limbsConst()[v.limbs_len - 1];
            const lo = v.limbsConst()[0];

            if (v.limbs_len < 2 and lo < 3) {
                return error.ModulusTooSmall;
            }

            // The top bit of each limb is a carry bit and never part of the value.
            const leading = @clz(hi) - carry_bits;

            // Newton/Hensel iteration for lo^-1 mod 2^t_bits. For odd lo, y = lo is already
            // correct to 3 bits, and every iteration doubles the number of correct low bits.
            var y = lo;

            inline for (0..comptime math.log2_int(usize, t_bits)) |_| {
                y = y *% (2 -% lo *% y);
            }
            // m0inv = -lo^-1 (mod 2^t_bits)
            const m0inv = (@as(Limb, 1) << t_bits) - (@as(TLimb, @truncate(y)));

            const zero = Fe{ .v = FeUint.zero };

            var m = Self{
                .zero = zero,
                .v = v,
                .leading = leading,
                .m0inv = m0inv,
                .rr = undefined, // will be computed right after
            };
            m.shrink(&m.zero) catch unreachable;
            computeRR(&m);

            return m;
        }

        /// Creates a new modulus from a primitive value.
        /// The modulus must be odd and larger than 2.
        pub fn fromPrimitive(comptime T: type, x: T) (InvalidModulusError || OverflowError)!Self {
            comptime assert(@bitSizeOf(T) <= max_bits); // Primitive type is larger than the modulus type.
            const v = try FeUint.fromPrimitive(T, x);
            return try Self.fromUint(v);
        }

        /// Creates a new modulus from a byte string.
        /// The modulus must be odd and larger than 2.
        pub fn fromBytes(bytes: []const u8, comptime endian: Endian) (InvalidModulusError || OverflowError)!Self {
            const v = try FeUint.fromBytes(bytes, endian);
            return try Self.fromUint(v);
        }

        /// Serializes the modulus to a byte string.
        pub fn toBytes(self: Self, bytes: []u8, comptime endian: Endian) OverflowError!void {
            return self.v.toBytes(bytes, endian);
        }

        /// Rejects field elements that are not in the canonical form: a different
        /// number of active limbs than the modulus, or a value `>= m`.
        pub fn rejectNonCanonical(self: Self, fe: Fe) error{NonCanonical}!void {
            if (fe.limbs_count() != self.limbs_count() or ct.limbsCmpGeq(fe.v, self.v)) {
                return error.NonCanonical;
            }
        }

        // Makes the number of active limbs in a field element match the one of the modulus.
        // Fails with `error.Overflow` if the element has fewer limbs than the modulus, or if
        // any limb that would be dropped is non-zero.
        fn shrink(self: Self, fe: *Fe) OverflowError!void {
            const new_len = self.limbs_count();
            if (fe.limbs_count() < new_len) return error.Overflow;
            var acc: Limb = 0;
            for (fe.v.limbsConst()[new_len..]) |limb| {
                acc |= limb;
            }
            if (acc != 0) return error.Overflow;
            if (new_len > fe.v.limbs_buffer.len) return error.Overflow;
            fe.v.limbs_len = new_len;
        }

        // Computes R^2 for the Montgomery representation.
        // Starts from 2^(t_bits * (n - 1)), which is below m, and shifts in n + 1 zero limbs,
        // reducing after each one, to reach 2^(2 * t_bits * n) mod m.
        fn computeRR(self: *Self) void {
            self.rr = self.zero;
            const n = self.rr.limbs_count();
            self.rr.v.limbs()[n - 1] = 1;
            for ((n - 1)..(2 * n)) |_| {
                self.shiftIn(&self.rr, 0);
            }
            self.shrink(&self.rr) catch unreachable;
        }

        /// Computes x << t_bits + y (mod m)
        /// Bits of `y` are shifted in one at a time, from the most significant down. The
        /// conditional subtraction of `m` uses `ct.select`/`cmov` rather than branches.
        /// `x` is expected to be already reduced (< m).
        fn shiftIn(self: Self, x: *Fe, y: Limb) void {
            var d = self.zero;
            const x_limbs = x.v.limbs();
            const d_limbs = d.v.limbs();
            const m_limbs = self.v.limbsConst();

            var need_sub = false;
            var i: usize = t_bits - 1;
            while (true) : (i -= 1) {
                var carry: u1 = @truncate(math.shr(Limb, y, i));
                var borrow: u1 = 0;
                for (0..self.limbs_count()) |j| {
                    // Continue from x, or from x - m if the previous step required a subtraction.
                    const l = ct.select(need_sub, d_limbs[j], x_limbs[j]);
                    var res = (l << 1) + carry;
                    x_limbs[j] = @as(TLimb, @truncate(res));
                    carry = @truncate(res >> t_bits);

                    res = x_limbs[j] -% m_limbs[j] -% borrow;
                    d_limbs[j] = @as(TLimb, @truncate(res));

                    borrow = @truncate(res >> t_bits);
                }
                need_sub = ct.eql(carry, borrow);
                if (i == 0) break;
            }
            x.v.cmov(need_sub, d.v);
        }

        /// Adds two field elements (mod m).
        /// Both operands are expected to be reduced; a single conditional subtraction is performed.
        pub fn add(self: Self, x: Fe, y: Fe) Fe {
            var out = x;
            const overflow = out.v.addWithOverflow(y.v);
            const underflow: u1 = @bitCast(ct.limbsCmpLt(out.v, self.v));
            // Subtract m if the sum overflowed, or if it did not overflow but is >= m.
            const need_sub = ct.eql(overflow, underflow);
            _ = out.v.conditionalSubWithOverflow(need_sub, self.v);
            return out;
        }

        /// Subtracts two field elements (mod m).
        /// Both operands are expected to be reduced; m is added back if the subtraction borrowed.
        pub fn sub(self: Self, x: Fe, y: Fe) Fe {
            var out = x;
            const underflow: bool = @bitCast(out.v.subWithOverflow(y.v));
            _ = out.v.conditionalAddWithOverflow(underflow, self.v);
            return out;
        }

        /// Converts a field element to the Montgomery form.
        /// Returns `error.UnexpectedRepresentation` if it already is in Montgomery form.
        pub fn toMontgomery(self: Self, x: *Fe) RepresentationError!void {
            if (x.montgomery) {
                return error.UnexpectedRepresentation;
            }
            self.shrink(x) catch unreachable;
            x.* = self.montgomeryMul(x.*, self.rr);
            x.montgomery = true;
        }

        /// Takes a field element out of the Montgomery form.
        /// Returns `error.UnexpectedRepresentation` if it is not in Montgomery form.
        pub fn fromMontgomery(self: Self, x: *Fe) RepresentationError!void {
            if (!x.montgomery) {
                return error.UnexpectedRepresentation;
            }
            self.shrink(x) catch unreachable;
            x.* = self.montgomeryMul(x.*, self.one());
            x.montgomery = false;
        }

        /// Reduces an arbitrary `Uint`, converting it to a field element.
        /// `x` may have any number of limbs, including fewer than the modulus.
        pub fn reduce(self: Self, x: anytype) Fe {
            var out = self.zero;
            const n = self.limbs_count();
            const x_limbs = x.limbsConst();
            const out_limbs = out.v.limbs();

            // A value with fewer limbs than the modulus is below 2^(t_bits * (n - 1)), which is
            // <= m because the top limb of the (normalized) modulus is non-zero: already reduced.
            if (x.limbs_len < n) {
                @memcpy(out_limbs[0..x.limbs_len], x_limbs[0..x.limbs_len]);
                return out;
            }

            // Seed the result with the top n - 1 limbs of x (necessarily < m), then shift the
            // remaining limbs in one by one, reducing after each.
            const seeded = n - 1;
            const rest = x.limbs_len - seeded;
            @memcpy(out_limbs[0..seeded], x_limbs[rest..x.limbs_len]);
            var i = rest;
            while (i > 0) {
                i -= 1;
                self.shiftIn(&out, x_limbs[i]);
            }
            return out;
        }

        // Word-by-word Montgomery product: leaves x * y * R^-1 in `d`, possibly not fully
        // reduced, and returns the carry out of the most significant limb. Callers apply the
        // final conditional subtraction of m. `d` is expected to start at zero.
        fn montgomeryLoop(self: Self, d: *Fe, x: Fe, y: Fe) u1 {
            assert(d.limbs_count() == x.limbs_count());
            assert(d.limbs_count() == y.limbs_count());
            assert(d.limbs_count() == self.limbs_count());

            const a_limbs = x.v.limbsConst();
            const b_limbs = y.v.limbsConst();
            const d_limbs = d.v.limbs();
            const m_limbs = self.v.limbsConst();

            var overflow: u1 = 0;
            for (0..self.limbs_count()) |i| {
                var carry: Limb = 0;

                var wide = ct.mulWide(a_limbs[i], b_limbs[0]);
                var z_lo = @addWithOverflow(d_limbs[0], wide.lo);
                // f makes the lowest limb of d + a_i*b + f*m divisible by 2^t_bits.
                const f = @as(TLimb, @truncate(z_lo[0] *% self.m0inv));
                var z_hi = wide.hi +% z_lo[1];
                wide = ct.mulWide(f, m_limbs[0]);
                z_lo = @addWithOverflow(z_lo[0], wide.lo);
                z_hi +%= z_lo[1];
                z_hi +%= wide.hi;
                carry = (z_hi << 1) | (z_lo[0] >> t_bits);

                for (1..self.limbs_count()) |j| {
                    wide = ct.mulWide(a_limbs[i], b_limbs[j]);
                    z_lo = @addWithOverflow(d_limbs[j], wide.lo);
                    z_hi = wide.hi +% z_lo[1];
                    wide = ct.mulWide(f, m_limbs[j]);
                    z_lo = @addWithOverflow(z_lo[0], wide.lo);
                    z_hi +%= z_lo[1];
                    z_hi +%= wide.hi;
                    z_lo = @addWithOverflow(z_lo[0], carry);
                    z_hi +%= z_lo[1];
                    // Store one limb down: this is the division by 2^t_bits.
                    if (j > 0) {
                        d_limbs[j - 1] = @as(TLimb, @truncate(z_lo[0]));
                    }
                    carry = (z_hi << 1) | (z_lo[0] >> t_bits);
                }
                const z = overflow + carry;
                d_limbs[self.limbs_count() - 1] = @as(TLimb, @truncate(z));
                overflow = @as(u1, @truncate(z >> t_bits));
            }
            return overflow;
        }

        // Montgomery multiplication: returns x * y * R^-1 (mod m).
        // The result is flagged as Montgomery when both inputs share a representation.
        fn montgomeryMul(self: Self, x: Fe, y: Fe) Fe {
            var d = self.zero;
            assert(x.limbs_count() == self.limbs_count());
            assert(y.limbs_count() == self.limbs_count());
            const overflow = self.montgomeryLoop(&d, x, y);
            const underflow = 1 -% @intFromBool(ct.limbsCmpGeq(d.v, self.v));
            const need_sub = ct.eql(overflow, underflow);
            _ = d.v.conditionalSubWithOverflow(need_sub, self.v);
            d.montgomery = x.montgomery == y.montgomery;
            return d;
        }

        // Montgomery squaring: returns x * x * R^-1 (mod m), always flagged as Montgomery.
        fn montgomerySq(self: Self, x: Fe) Fe {
            var d = self.zero;
            assert(x.limbs_count() == self.limbs_count());
            const overflow = self.montgomeryLoop(&d, x, x);
            const underflow = 1 -% @intFromBool(ct.limbsCmpGeq(d.v, self.v));
            const need_sub = ct.eql(overflow, underflow);
            _ = d.v.conditionalSubWithOverflow(need_sub, self.v);
            d.montgomery = true;
            return d;
        }

        // Returns x^e (mod m), with the exponent provided as a byte string.
        // `public` must be set to `false` if the exponent is secret.
        // The result is returned in the regular (non-Montgomery) representation.
        fn powWithEncodedExponentInternal(self: Self, x: Fe, e: []const u8, endian: Endian, comptime public: bool) NullExponentError!Fe {
            var acc: u8 = 0;
            for (e) |b| acc |= b;
            if (acc == 0) return error.NullExponent;

            var out = self.one();
            self.toMontgomery(&out) catch unreachable;

            // The short path branches on exponent bits, so it must only ever be taken for
            // public exponents. Parenthesized: `and` binds tighter than `or`.
            if (public and (e.len < 3 or (e.len == 3 and e[if (endian == .big) 0 else 2] <= 0b1111))) {
                // Do not use a precomputation table for short, public exponents
                var x_m = x;
                if (x.montgomery == false) {
                    self.toMontgomery(&x_m) catch unreachable;
                }
                var s = switch (endian) {
                    .big => 0,
                    .little => e.len - 1,
                };
                while (true) {
                    const b = e[s];
                    var j: u3 = 7;
                    while (true) : (j -= 1) {
                        out = self.montgomerySq(out);
                        const k: u1 = @truncate(b >> j);
                        if (k != 0) {
                            const t = self.montgomeryMul(out, x_m);
                            @memcpy(out.v.limbs(), t.v.limbsConst());
                        }
                        if (j == 0) break;
                    }
                    switch (endian) {
                        .big => {
                            s += 1;
                            if (s == e.len) break;
                        },
                        .little => {
                            if (s == 0) break;
                            s -= 1;
                        },
                    }
                }
            } else {
                // Use a precomputation table for large exponents
                // pc[i] = x^(i+1) in Montgomery form, for 4-bit windows.
                var pc = [1]Fe{x} ++ [_]Fe{self.zero} ** 14;
                if (x.montgomery == false) {
                    self.toMontgomery(&pc[0]) catch unreachable;
                }
                for (1..pc.len) |i| {
                    pc[i] = self.montgomeryMul(pc[i - 1], pc[0]);
                }
                var t0 = self.zero;
                var s = switch (endian) {
                    .big => 0,
                    .little => e.len - 1,
                };
                while (true) {
                    const b = e[s];
                    for ([_]u3{ 4, 0 }) |j| {
                        for (0..4) |_| {
                            out = self.montgomerySq(out);
                        }
                        const k = (b >> j) & 0b1111;
                        if (public or std.options.side_channels_mitigations == .none) {
                            if (k == 0) continue;
                            t0 = pc[k - 1];
                        } else {
                            // Scan the whole table so the memory access pattern does not depend on k.
                            for (pc, 0..) |t, i| {
                                t0.v.cmov(ct.eql(k, @as(u8, @truncate(i + 1))), t.v);
                            }
                        }
                        const t1 = self.montgomeryMul(out, t0);
                        if (public) {
                            @memcpy(out.v.limbs(), t1.v.limbsConst());
                        } else {
                            out.v.cmov(!ct.eql(k, 0), t1.v);
                        }
                    }
                    switch (endian) {
                        .big => {
                            s += 1;
                            if (s == e.len) break;
                        },
                        .little => {
                            if (s == 0) break;
                            s -= 1;
                        },
                    }
                }
            }
            self.fromMontgomery(&out) catch unreachable;
            return out;
        }

        /// Multiplies two field elements.
        /// The result is in the regular (non-Montgomery) representation, regardless of the
        /// representation of the inputs.
        pub fn mul(self: Self, x: Fe, y: Fe) Fe {
            if (x.montgomery != y.montgomery) {
                return self.montgomeryMul(x, y);
            }
            var a_ = x;
            if (x.montgomery == false) {
                self.toMontgomery(&a_) catch unreachable;
            } else {
                self.fromMontgomery(&a_) catch unreachable;
            }
            return self.montgomeryMul(a_, y);
        }

        /// Squares a field element.
        /// Like `mul`, the result is in the regular (non-Montgomery) representation, and its
        /// `montgomery` flag is set accordingly.
        pub fn sq(self: Self, x: Fe) Fe {
            return self.mul(x, x);
        }

        /// Returns x^e (mod m) in constant time.
        pub fn pow(self: Self, x: Fe, e: Fe) NullExponentError!Fe {
            var buf: [Fe.encoded_bytes]u8 = undefined;
            e.toBytes(&buf, native_endian) catch unreachable;
            return self.powWithEncodedExponent(x, &buf, native_endian);
        }

        /// Returns x^e (mod m), assuming that the exponent is public.
        /// The function remains constant time with respect to `x`.
        pub fn powPublic(self: Self, x: Fe, e: Fe) NullExponentError!Fe {
            const e_normalized = Fe{ .v = e.v.normalize() };
            const e_limbs_len = e_normalized.v.limbs_len;
            var buf_: [Fe.encoded_bytes]u8 = undefined;
            const buf = buf_[0 .. math.divCeil(usize, e_limbs_len * t_bits, 8) catch unreachable];
            e_normalized.toBytes(buf, .little) catch unreachable;

            // Exact bit length of the exponent: each lower limb carries t_bits bits, and the
            // top limb contributes its significant bits. Trimming by whole bytes of @clz alone
            // would ignore the carry bit of every lower limb and could drop a significant byte.
            const top = e_normalized.v.limbsConst()[e_limbs_len - 1];
            const top_bits = @bitSizeOf(Limb) - @as(usize, @clz(top));
            const e_bits = (e_limbs_len - 1) * t_bits + top_bits;
            const e_len = @min(buf.len, math.divCeil(usize, e_bits, 8) catch unreachable);
            return self.powWithEncodedPublicExponent(x, buf[0..e_len], .little);
        }

        /// Returns x^e (mod m), with the exponent provided as a byte string.
        /// Exponents are usually small, so this function is faster than `powPublic` as a field element
        /// doesn't have to be created if a serialized representation is already available.
        ///
        /// If the exponent is public, `powWithEncodedPublicExponent()` can be used instead for a slight speedup.
        pub fn powWithEncodedExponent(self: Self, x: Fe, e: []const u8, endian: Endian) NullExponentError!Fe {
            return self.powWithEncodedExponentInternal(x, e, endian, false);
        }

        /// Returns x^e (mod m), the exponent being public and provided as a byte string.
        /// Exponents are usually small, so this function is faster than `powPublic` as a field element
        /// doesn't have to be created if a serialized representation is already available.
        ///
        /// If the exponent is secret, `powWithEncodedExponent` must be used instead.
        pub fn powWithEncodedPublicExponent(self: Self, x: Fe, e: []const u8, endian: Endian) NullExponentError!Fe {
            return self.powWithEncodedExponentInternal(x, e, endian, true);
        }
    };
}