From 69b39ac40b1f66593e53f93ee0b3a9d5449045e3 Mon Sep 17 00:00:00 2001 From: Jamie-Cui Date: Thu, 6 Jun 2024 16:28:43 +0800 Subject: [PATCH] repo-sync-2024-06-06T16:28:31+0800 --- yacl/crypto/ecc/FourQlib/BUILD.bazel | 1 + yacl/crypto/ecc/FourQlib/FourQ_group.cc | 51 +++- yacl/crypto/ecc/ec_point.h | 5 + yacl/crypto/ecc/lib25519/ed25519_group.cc | 5 +- yacl/crypto/ecc/lib25519/lib25519_group.cc | 19 +- yacl/crypto/ecc/libsodium/BUILD.bazel | 15 + yacl/crypto/ecc/libsodium/sodium_factory.cc | 9 + yacl/crypto/ecc/libsodium/x25519_group.cc | 195 +++++++++++++ yacl/crypto/ecc/libsodium/x25519_group.h | 70 +++++ yacl/crypto/ecc/mcl/mcl_ec_group.cc | 2 +- yacl/crypto/pairing/pairing_test.cc | 2 +- yacl/kernel/BUILD.bazel | 33 +++ yacl/kernel/algorithms/base_vole.h | 2 +- yacl/kernel/algorithms/base_vole_test.cc | 2 +- yacl/kernel/algorithms/ferret_ote.cc | 186 +++++++++++- yacl/kernel/algorithms/ferret_ote.h | 21 +- yacl/kernel/algorithms/ferret_ote_rn.h | 4 +- yacl/kernel/algorithms/ferret_ote_test.cc | 12 +- yacl/kernel/algorithms/gywz_ote_test.cc | 4 +- yacl/kernel/algorithms/iknp_ote.cc | 2 +- yacl/kernel/algorithms/iknp_ote.h | 2 +- yacl/kernel/algorithms/iknp_ote_test.cc | 4 +- yacl/kernel/algorithms/kos_ote.cc | 2 +- yacl/kernel/algorithms/kos_ote.h | 2 +- yacl/kernel/algorithms/kos_ote_test.cc | 6 +- yacl/kernel/algorithms/ot_store.cc | 187 +++++++----- yacl/kernel/algorithms/ot_store.h | 68 +++-- yacl/kernel/algorithms/ot_store_test.cc | 44 +++ yacl/kernel/algorithms/sgrr_ote.cc | 2 +- yacl/kernel/algorithms/softspoken_ote.cc | 266 ++++++++++++++++- yacl/kernel/algorithms/softspoken_ote.h | 48 ++- yacl/kernel/algorithms/softspoken_ote_test.cc | 36 +++ yacl/kernel/kernel.h | 29 +- yacl/kernel/ot_kernel.cc | 275 ++++++++++++++++++ yacl/kernel/ot_kernel.h | 129 ++++++++ yacl/kernel/ot_kernel_bench.cc | 91 ++++++ yacl/kernel/ot_kernel_test.cc | 112 +++++++ yacl/kernel/svole_kernel.h | 2 +- yacl/link/transport/channel_test.cc | 157 +++++----- yacl/math/galois_field/factory/mcl_factory.cc | 4 +- 40 files changed, 1843 insertions(+), 263 deletions(-) create mode 100644 yacl/crypto/ecc/libsodium/x25519_group.cc create mode 100644 yacl/crypto/ecc/libsodium/x25519_group.h create mode 100644 yacl/kernel/ot_kernel.cc create mode 100644 yacl/kernel/ot_kernel.h create mode 100644 yacl/kernel/ot_kernel_bench.cc create mode 100644 yacl/kernel/ot_kernel_test.cc diff --git a/yacl/crypto/ecc/FourQlib/BUILD.bazel b/yacl/crypto/ecc/FourQlib/BUILD.bazel index df3d2941..d0138bfb 100644 --- a/yacl/crypto/ecc/FourQlib/BUILD.bazel +++ b/yacl/crypto/ecc/FourQlib/BUILD.bazel @@ -29,6 +29,7 @@ yacl_cc_library( hdrs = ["FourQ_group.h"], deps = [ "//yacl/crypto/ecc:spi", + "//yacl/crypto/hash:ssl_hash", "@com_github_microsoft_FourQlib//:FourQlib", ], ) diff --git a/yacl/crypto/ecc/FourQlib/FourQ_group.cc b/yacl/crypto/ecc/FourQlib/FourQ_group.cc index 5dbd287c..0ae320b0 100644 --- a/yacl/crypto/ecc/FourQlib/FourQ_group.cc +++ b/yacl/crypto/ecc/FourQlib/FourQ_group.cc @@ -14,13 +14,21 @@ #include "yacl/crypto/ecc/FourQlib/FourQ_group.h" +#include "absl/types/span.h" + +#include "yacl/crypto/hash/ssl_hash.h" + namespace yacl::crypto::FourQ { // Elements (a+b*i) over GF(p^2), where a and b are defined over GF(p), are // encoded as a||b, with a in the least significant position. MPInt F2elm2MPInt(const f2elm_t f2elm) { + f2elm_t c; + fp2copy1271(const_cast(f2elm), c); + mod1271(c[0]); + mod1271(c[1]); MPInt r(0, 256); - r.FromMagBytes(yacl::ByteContainerView(f2elm, 32), Endian::little); + r.FromMagBytes(yacl::ByteContainerView(c, 32), Endian::little); return r; } @@ -28,6 +36,8 @@ MPInt F2elm2MPInt(const f2elm_t f2elm) { void MPIntToF2elm(const MPInt& x, f2elm_t f2elm) { memset(f2elm, 0, 32); x.ToMagBytes(reinterpret_cast(f2elm), 32, Endian::little); + mod1271(f2elm[0]); + mod1271(f2elm[1]); } FourQGroup::FourQGroup(const CurveMeta& meta) : EcGroupSketch(meta) { @@ -257,8 +267,28 @@ EcPoint FourQGroup::DeserializePoint(ByteContainerView buf, return r; } -EcPoint FourQGroup::HashToCurve(HashToCurveStrategy, std::string_view) const { - YACL_THROW("not impl"); +EcPoint FourQGroup::HashToCurve(HashToCurveStrategy strategy, + std::string_view input) const { + YACL_ENFORCE(strategy == HashToCurveStrategy::Autonomous, + "FourQlib only supports Autonomous strategy now. select={}", + static_cast(strategy)); + + std::vector sha_bytes = + SslHash(HashAlgorithm::SHA512) + .Update(absl::Span(input.data(), input.size())) + .CumulativeHash(); + auto* f2elmt = reinterpret_cast(sha_bytes.data()); + mod1271(reinterpret_cast(f2elmt)[0]); + mod1271(reinterpret_cast(f2elmt)[1]); + + point_t p; + ECCRYPTO_STATUS status = ::HashToCurve(reinterpret_cast(f2elmt), p); + YACL_ENFORCE(status == ECCRYPTO_SUCCESS, FourQ_get_error_message(status)); + + EcPoint r(std::in_place_type); + point_setup(p, CastR1(r)); + + return r; } size_t FourQGroup::HashPoint(const EcPoint& point) const { @@ -295,6 +325,10 @@ bool FourQGroup::PointEqual(const EcPoint& p1, const EcPoint& p2) const { f2elm_t b; fp2mul1271(p1p->x, p2p->z, a); fp2mul1271(p1p->z, p2p->x, b); + mod1271(a[0]); + mod1271(a[1]); + mod1271(b[0]); + mod1271(b[1]); auto* pa = reinterpret_cast(a); auto* pb = reinterpret_cast(b); for (size_t i = 0; i < 2 * NWORDS_FIELD; ++i) { @@ -305,6 +339,10 @@ bool FourQGroup::PointEqual(const EcPoint& p1, const EcPoint& p2) const { fp2mul1271(p1p->y, p2p->z, a); fp2mul1271(p1p->z, p2p->y, b); + mod1271(a[0]); + mod1271(a[1]); + mod1271(b[0]); + mod1271(b[1]); pa = reinterpret_cast(a); pb = reinterpret_cast(b); for (size_t i = 0; i < 2 * NWORDS_FIELD; ++i) { @@ -331,7 +369,10 @@ bool FourQGroup::IsInfinity(const EcPoint& point) const { const_cast(reinterpret_cast(CastR1(point)->x)); auto* z = const_cast(reinterpret_cast(CastR1(point)->z)); - + mod1271(x); + mod1271(x + 2); + mod1271(z); + mod1271(z + 2); return is_zero_ct(x, 2 * NWORDS_FIELD) || is_zero_ct(z, 2 * NWORDS_FIELD); } @@ -347,4 +388,4 @@ point_extproj* FourQGroup::CastR1(EcPoint& p) { return reinterpret_cast(std::get(p).data()); } -} // namespace yacl::crypto::FourQ \ No newline at end of file +} // namespace yacl::crypto::FourQ diff --git a/yacl/crypto/ecc/ec_point.h b/yacl/crypto/ecc/ec_point.h index fbd00e72..aba51010 100644 --- a/yacl/crypto/ecc/ec_point.h +++ b/yacl/crypto/ecc/ec_point.h @@ -30,6 +30,11 @@ enum class PointOctetFormat { // The format is determined by the library itself. Autonomous, + // Uncompressed format + // The point is encoded as x||y + // For X25519, only need the x value + Uncompressed, + // ANSI X9.62 compressed format // The point is encoded as z||x, where the octet z specifies which solution of // the quadratic equation y is. diff --git a/yacl/crypto/ecc/lib25519/ed25519_group.cc b/yacl/crypto/ecc/lib25519/ed25519_group.cc index 22694477..4c4408e4 100644 --- a/yacl/crypto/ecc/lib25519/ed25519_group.cc +++ b/yacl/crypto/ecc/lib25519/ed25519_group.cc @@ -17,9 +17,10 @@ namespace yacl::crypto::lib25519 { MPInt Fe25519ToMPInt(const fe25519& x) { - // TODO: whether to freeze x first? + fe25519 t = x; + fe25519_freeze(&t); MPInt r(0, 255); - r.FromMagBytes(yacl::ByteContainerView(&x, 32), Endian::little); + r.FromMagBytes(yacl::ByteContainerView(&t, 32), Endian::little); return r; } diff --git a/yacl/crypto/ecc/lib25519/lib25519_group.cc b/yacl/crypto/ecc/lib25519/lib25519_group.cc index 46712a1a..c4b0202e 100644 --- a/yacl/crypto/ecc/lib25519/lib25519_group.cc +++ b/yacl/crypto/ecc/lib25519/lib25519_group.cc @@ -118,15 +118,12 @@ size_t Lib25519Group::HashPoint(const EcPoint& point) const { const auto* p3 = CastP3(point); fe25519 recip; fe25519 x; - fe25519_invert(&recip, &p3->z); fe25519_mul(&x, &p3->x, &recip); - - uint64_t buf[4]; // x is always 255 bits - fe25519_pack(reinterpret_cast(buf), &x); + fe25519_freeze(&x); std::hash h; - return h(buf[0]) ^ h(buf[1]) ^ h(buf[2]) ^ h(buf[3]); + return h(x.v[0]) ^ h(x.v[1]) ^ h(x.v[2]) ^ h(x.v[3]); } bool Lib25519Group::PointEqual(const EcPoint& p1, const EcPoint& p2) const { @@ -143,19 +140,13 @@ bool Lib25519Group::PointEqual(const EcPoint& p1, const EcPoint& p2) const { fe25519 b; fe25519_mul(&a, &p1p->x, &p2p->z); fe25519_mul(&b, &p1p->z, &p2p->x); - for (size_t i = 0; i < sizeof(fe25519) / sizeof(a.v[0]); ++i) { - if (a.v[i] != b.v[i]) { - return false; - } + if (!fe25519_iseq_vartime(&a, &b)) { + return false; } fe25519_mul(&a, &p1p->y, &p2p->z); fe25519_mul(&b, &p1p->z, &p2p->y); - uint128_t buf_a[2]; - uint128_t buf_b[2]; - fe25519_pack(reinterpret_cast(buf_a), &a); - fe25519_pack(reinterpret_cast(buf_b), &b); - return buf_a[0] == buf_b[0] && buf_a[1] == buf_b[1]; + return fe25519_iseq_vartime(&a, &b); } const ge25519_p3* Lib25519Group::CastP3(const yacl::crypto::EcPoint& p) { diff --git a/yacl/crypto/ecc/libsodium/BUILD.bazel b/yacl/crypto/ecc/libsodium/BUILD.bazel index f6b9af5f..d2726a7c 100644 --- a/yacl/crypto/ecc/libsodium/BUILD.bazel +++ b/yacl/crypto/ecc/libsodium/BUILD.bazel @@ -23,6 +23,7 @@ yacl_cc_library( ], deps = [ ":ed25519_group", + ":x25519_group", ], alwayslink = 1, ) @@ -57,6 +58,20 @@ yacl_cc_library( ], ) +yacl_cc_library( + name = "x25519_group", + srcs = [ + "x25519_group.cc", + ], + hdrs = [ + "x25519_group.h", + ], + deps = [ + ":sodium_group", + "//yacl/crypto/hash:hash_utils", + ], +) + yacl_cc_test( name = "ed25519_test", srcs = ["ed25519_test.cc"], diff --git a/yacl/crypto/ecc/libsodium/sodium_factory.cc b/yacl/crypto/ecc/libsodium/sodium_factory.cc index 6b910a6a..8f7c5d8f 100644 --- a/yacl/crypto/ecc/libsodium/sodium_factory.cc +++ b/yacl/crypto/ecc/libsodium/sodium_factory.cc @@ -15,6 +15,7 @@ #include #include "yacl/crypto/ecc/libsodium/ed25519_group.h" +#include "yacl/crypto/ecc/libsodium/x25519_group.h" namespace yacl::crypto::sodium { @@ -23,6 +24,12 @@ const std::string kLibName = "libsodium"; std::map kPredefinedCurves = { {"ed25519", + { + (2_mp).Pow(255) - 19_mp, // p = 2^255 - 19 + (2_mp).Pow(252) + "0x14def9dea2f79cd65812631a5cf5d3ed"_mp, // n + "8"_mp // h + }}, + {"curve25519", { (2_mp).Pow(255) - 19_mp, // p = 2^255 - 19 (2_mp).Pow(252) + "0x14def9dea2f79cd65812631a5cf5d3ed"_mp, // n @@ -36,6 +43,8 @@ std::unique_ptr Create(const CurveMeta &meta) { if (meta.LowerName() == "ed25519") { return std::make_unique(meta, conf); + } else if (meta.LowerName() == "curve25519") { + return std::make_unique(meta, conf); } else { YACL_THROW("unexpected curve {}", meta.name); } diff --git a/yacl/crypto/ecc/libsodium/x25519_group.cc b/yacl/crypto/ecc/libsodium/x25519_group.cc new file mode 100644 index 00000000..c1a3e2d9 --- /dev/null +++ b/yacl/crypto/ecc/libsodium/x25519_group.cc @@ -0,0 +1,195 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "yacl/crypto/ecc/libsodium/x25519_group.h" + +#include "sodium/crypto_scalarmult_curve25519.h" + +#include "yacl/crypto/hash/hash_utils.h" + +namespace yacl::crypto::sodium { + +X25519Group::X25519Group(const CurveMeta& meta, const CurveParam& param) + : SodiumGroup(meta, param) {} + +EcPoint X25519Group::GetGenerator() const { YACL_THROW("not implemented"); } + +EcPoint X25519Group::Add(const EcPoint&, const EcPoint&) const { + YACL_THROW("not implemented"); +} + +void X25519Group::AddInplace(EcPoint*, const EcPoint&) const { + YACL_THROW("not implemented"); +} + +EcPoint X25519Group::Sub(const EcPoint&, const EcPoint&) const { + YACL_THROW("not implemented"); +} + +void X25519Group::SubInplace(EcPoint*, const EcPoint&) const { + YACL_THROW("not implemented"); +} + +EcPoint X25519Group::Double(const EcPoint&) const { + YACL_THROW("not implemented"); +} + +void X25519Group::DoubleInplace(EcPoint*) const { + YACL_THROW("not implemented"); +} + +EcPoint X25519Group::MulBase(const MPInt& scalar) const { + Array32 buf; + memset(buf.data(), 0, sizeof(buf)); + scalar.Mod(param_.n).ToMagBytes(buf.data(), buf.size(), Endian::little); + + EcPoint r(std::in_place_type); + YACL_ENFORCE(0 == + crypto_scalarmult_curve25519_base(CastString(r), buf.data())); + return r; +} + +EcPoint X25519Group::Mul(const EcPoint& point, const MPInt& scalar) const { + Array32 buf; + memset(buf.data(), 0, sizeof(buf)); + scalar.Mod(param_.n).ToMagBytes(buf.data(), buf.size(), Endian::little); + + EcPoint r(std::in_place_type); + YACL_ENFORCE(0 == crypto_scalarmult_curve25519(CastString(r), buf.data(), + CastString(point))); + + return r; +} + +void X25519Group::MulInplace(EcPoint* point, const MPInt& scalar) const { + Array32 buf; + memset(buf.data(), 0, sizeof(buf)); + scalar.ToMagBytes(buf.data(), buf.size(), Endian::little); + + YACL_ENFORCE(0 == crypto_scalarmult_curve25519(CastString(*point), buf.data(), + CastString(*point))); +} + +EcPoint X25519Group::MulDoubleBase(const MPInt&, const MPInt&, + const EcPoint&) const { + YACL_THROW("not implemented"); +} + +EcPoint X25519Group::Negate(const EcPoint&) const { + YACL_THROW("not implemented"); +} + +void X25519Group::NegateInplace(EcPoint*) const { + YACL_THROW("not implemented"); +} + +AffinePoint X25519Group::GetAffinePoint(const EcPoint&) const { + YACL_THROW("not implemented"); +} + +bool X25519Group::IsInCurveGroup(const EcPoint&) const { + YACL_THROW("not implemented"); +} + +bool X25519Group::IsInfinity(const EcPoint&) const { + YACL_THROW("not implemented"); +} + +uint64_t X25519Group::GetSerializeLength(PointOctetFormat format) const { + switch (format) { + case PointOctetFormat::Autonomous: + case PointOctetFormat::Uncompressed: + return 32; + default: + YACL_THROW("{} only support Uncompressed format, given={}", + GetLibraryName(), static_cast(format)); + } +} + +Buffer X25519Group::SerializePoint(const EcPoint& point, + PointOctetFormat format) const { + switch (format) { + case PointOctetFormat::Autonomous: + case PointOctetFormat::Uncompressed: { + Buffer buf(32); + memcpy(buf.data(), CastString(point), 32); + return buf; + } + default: + YACL_THROW("{} only support Uncompressed format, given={}", + GetLibraryName(), static_cast(format)); + } +} + +void X25519Group::SerializePoint(const EcPoint& point, PointOctetFormat format, + Buffer* buf) const { + *buf = SerializePoint(point, format); +} + +void X25519Group::SerializePoint(const EcPoint& point, PointOctetFormat format, + uint8_t* buf, uint64_t buf_size) const { + switch (format) { + case PointOctetFormat::Autonomous: + case PointOctetFormat::Uncompressed: { + YACL_ENFORCE(buf_size >= 32, "buf size is smaller than needed 32"); + memcpy(buf, CastString(point), 32); + break; + } + default: + YACL_THROW("{} only support Uncompressed format, given={}", + GetLibraryName(), static_cast(format)); + } +} + +EcPoint X25519Group::DeserializePoint(ByteContainerView buf, + PointOctetFormat format) const { + switch (format) { + case PointOctetFormat::Autonomous: + case PointOctetFormat::Uncompressed: { + YACL_ENFORCE(buf.size() == 32, "buf size not equal to 32"); + EcPoint p(std::in_place_type); + memcpy(CastString(p), buf.data(), buf.size()); + return p; + } + default: + YACL_THROW("{} only support Uncompressed format, given={}", + GetLibraryName(), static_cast(format)); + } +} + +EcPoint X25519Group::HashToCurve(HashToCurveStrategy strategy, + std::string_view input) const { + switch (strategy) { + case HashToCurveStrategy::Autonomous: + case HashToCurveStrategy::HashAsPointX_SHA2: + return yacl::crypto::Sha256(input); + default: + YACL_THROW("hash to curve strategy {} not supported", + static_cast(strategy)); + } +} + +const unsigned char* X25519Group::CastString(const EcPoint& p) { + YACL_ENFORCE(std::holds_alternative(p), + "Illegal EcPoint, expected Array32, real={}", p.index()); + return std::get(p).data(); +} + +unsigned char* X25519Group::CastString(EcPoint& p) { + YACL_ENFORCE(std::holds_alternative(p), + "Illegal EcPoint, expected Array32, real={}", p.index()); + return std::get(p).data(); +} + +} // namespace yacl::crypto::sodium \ No newline at end of file diff --git a/yacl/crypto/ecc/libsodium/x25519_group.h b/yacl/crypto/ecc/libsodium/x25519_group.h new file mode 100644 index 00000000..c422d9f4 --- /dev/null +++ b/yacl/crypto/ecc/libsodium/x25519_group.h @@ -0,0 +1,70 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "yacl/crypto/ecc/ec_point.h" +#include "yacl/crypto/ecc/libsodium/sodium_group.h" + +namespace yacl::crypto::sodium { + +class X25519Group : public SodiumGroup { + public: + X25519Group(const CurveMeta& meta, const CurveParam& param); + EcPoint GetGenerator() const override; + + EcPoint Add(const EcPoint& p1, const EcPoint& p2) const override; + void AddInplace(EcPoint* p1, const EcPoint& p2) const override; + + EcPoint Sub(const EcPoint& p1, const EcPoint& p2) const override; + void SubInplace(EcPoint* p1, const EcPoint& p2) const override; + + EcPoint Double(const EcPoint& p) const override; + void DoubleInplace(EcPoint* p) const override; + + EcPoint Mul(const EcPoint& point, const MPInt& scalar) const override; + void MulInplace(EcPoint* point, const MPInt& scalar) const override; + + EcPoint MulBase(const MPInt& scalar) const override; + EcPoint MulDoubleBase(const MPInt& s1, const MPInt& s2, + const EcPoint& p2) const override; + EcPoint Negate(const EcPoint& point) const override; + void NegateInplace(EcPoint* point) const override; + + // EcPoint(SodiumPoint) -> AffinePoint + AffinePoint GetAffinePoint(const EcPoint& point) const override; + + bool IsInCurveGroup(const EcPoint& point) const override; + bool IsInfinity(const EcPoint& point) const override; + + uint64_t GetSerializeLength(PointOctetFormat format) const override; + + Buffer SerializePoint(const EcPoint& point, + PointOctetFormat format) const override; + void SerializePoint(const EcPoint& point, PointOctetFormat format, + Buffer* buf) const override; + void SerializePoint(const EcPoint& point, PointOctetFormat format, + uint8_t* buf, uint64_t buf_size) const override; + EcPoint DeserializePoint(ByteContainerView buf, + PointOctetFormat format) const override; + + EcPoint HashToCurve(HashToCurveStrategy strategy, + std::string_view str) const override; + + private: + static const unsigned char* CastString(const EcPoint& p); + static unsigned char* CastString(EcPoint& p); +}; + +} // namespace yacl::crypto::sodium \ No newline at end of file diff --git a/yacl/crypto/ecc/mcl/mcl_ec_group.cc b/yacl/crypto/ecc/mcl/mcl_ec_group.cc index 49a31ca0..830593f5 100644 --- a/yacl/crypto/ecc/mcl/mcl_ec_group.cc +++ b/yacl/crypto/ecc/mcl/mcl_ec_group.cc @@ -422,7 +422,7 @@ EcPoint MclGroupT::HashToStdCurve(HashToCurveStrategy strategy, buf = Blake3Hash((bits + 7) / 8).Update(str).CumulativeHash(); } - Fp p; + Fp p{0}; p.deserialize(buf.data(), buf.size()); mcl::ec::tryAndIncMapTo(*CastAny(ret), p); return ret; diff --git a/yacl/crypto/pairing/pairing_test.cc b/yacl/crypto/pairing/pairing_test.cc index 9ca743df..28322d72 100644 --- a/yacl/crypto/pairing/pairing_test.cc +++ b/yacl/crypto/pairing/pairing_test.cc @@ -390,7 +390,7 @@ TEST_P(BNSnarkTest, SpiTest) { } } -TEST(Pairing_Multi_Instance_Test, Works) { +TEST(DISABLED_Pairing_Multi_Instance_Test, Works) { PairingName pairing_name = "bls12-381"; for (auto lib_name : PairingGroupFactory::Instance().ListLibraries(pairing_name)) { diff --git a/yacl/kernel/BUILD.bazel b/yacl/kernel/BUILD.bazel index 596245d4..c61086f9 100644 --- a/yacl/kernel/BUILD.bazel +++ b/yacl/kernel/BUILD.bazel @@ -55,3 +55,36 @@ yacl_cc_test( "//yacl/link:test_util", ], ) + +yacl_cc_library( + name = "ot_kernel", + srcs = [ + "ot_kernel.cc", + ], + hdrs = ["ot_kernel.h"], + deps = [ + ":kernel", + "//yacl/kernel/algorithms:ferret_ote", + "//yacl/kernel/algorithms:softspoken_ote", + ], +) + +yacl_cc_test( + name = "ot_kernel_test", + srcs = ["ot_kernel_test.cc"], + copts = AES_COPT_FLAGS, + deps = [ + ":ot_kernel", + "//yacl/link:test_util", + ], +) + +yacl_cc_binary( + name = "ot_kernel_bench", + srcs = ["ot_kernel_bench.cc"], + deps = [ + ":ot_kernel", + "//yacl/link:test_util", + "@com_github_google_benchmark//:benchmark", + ], +) diff --git a/yacl/kernel/algorithms/base_vole.h b/yacl/kernel/algorithms/base_vole.h index b0a9f43a..29b79cd0 100644 --- a/yacl/kernel/algorithms/base_vole.h +++ b/yacl/kernel/algorithms/base_vole.h @@ -75,7 +75,7 @@ void inline Ot2VoleRecv(OtRecvStore& recv_ot, absl::Span u, YACL_ENFORCE(recv_ot.Size() >= size * T_bits); // [Warning] Copying, low efficiency - auto choices = recv_ot.CopyChoice(); + auto choices = recv_ot.CopyBitBuf(); memcpy(u.data(), choices.data(), size * sizeof(T)); std::array v_buff; diff --git a/yacl/kernel/algorithms/base_vole_test.cc b/yacl/kernel/algorithms/base_vole_test.cc index 80cdaf03..901fca11 100644 --- a/yacl/kernel/algorithms/base_vole_test.cc +++ b/yacl/kernel/algorithms/base_vole_test.cc @@ -75,7 +75,7 @@ DECLARE_OT2VOLE_TEST(GF128, GF128); // Vole: GF(2^128) x GF(2^128) auto lctxs = link::test::SetupWorld(2); \ const uint64_t vole_num = GetParam().num; \ auto rot = MockRots(128); \ - auto delta128 = rot.recv.CopyChoice().data()[0]; \ + auto delta128 = rot.recv.CopyBitBuf().data()[0]; \ std::vector u(vole_num); \ std::vector v(vole_num); \ std::vector w(vole_num); \ diff --git a/yacl/kernel/algorithms/ferret_ote.cc b/yacl/kernel/algorithms/ferret_ote.cc index edf0be17..7f6535f3 100644 --- a/yacl/kernel/algorithms/ferret_ote.cc +++ b/yacl/kernel/algorithms/ferret_ote.cc @@ -49,6 +49,184 @@ uint64_t FerretCotHelper(const LpnParam& lpn_param, uint64_t /*ot_num*/, return lpn_param.k + mpcot_cot + check_cot; } +void FerretOtExtSend(const std::shared_ptr& ctx, + const OtSendStore& base_cot, const LpnParam& lpn_param, + uint64_t ot_num, /* compact mode */ OtSendStore* out, + bool mal) { + YACL_ENFORCE(ctx->WorldSize() == 2); // Make sure that OT has two parties + YACL_ENFORCE(base_cot.Type() == OtStoreType::Compact); + YACL_ENFORCE(base_cot.Size() >= FerretCotHelper(lpn_param, ot_num)); + YACL_ENFORCE( + ot_num >= 2 * lpn_param.t, + "ot_num is {}, which should be much greater than 2 * lpn_param.t ({})", + ot_num, 2 * lpn_param.t); + YACL_ENFORCE(out->Type() == OtStoreType::Compact); + YACL_ENFORCE(out->Size() == ot_num); + + // get constants: the number of cot needed for mpcot phase + const auto mpcot_cot_num = MpCotRNHelper(lpn_param.t, lpn_param.n, mal); + + // get constants: batch information + const uint64_t cache_size = lpn_param.k + mpcot_cot_num; + const uint64_t batch_size = lpn_param.n - cache_size; + const uint64_t batch_num = (ot_num + batch_size - 1) / batch_size; + out->SetDelta(base_cot.GetDelta()); + + // prepare v (before silent expansion), where w = v ^ u * delta + // FIX ME: "Slice" would would force to slice original OtStore from "begin" to + // "end", it would be better to use "NextSlice" here, but it's not a const + // function. + auto cot_mpcot = base_cot.Slice(0, mpcot_cot_num); + auto cot_seed = base_cot.Slice(mpcot_cot_num, mpcot_cot_num + lpn_param.k); + auto working_v = cot_seed.CopyCotBlkBuf(); + + // get lpn public matrix A + uint128_t seed = SyncSeedSend(ctx); + LocalLinearCode<10> llc(seed, lpn_param.n, lpn_param.k); + + // placeholder for the outputs + auto out_span = out->GetBlkBufSpan(); + + // For uniform noise assumption only + // CuckooIndex::Options option; + // std::unique_ptr simple_map; + // if (lpn_param.noise_asm == LpnNoiseAsm::UniformNoise) { + // YACL_THROW("Not Implemented!"); + // option = CuckooIndex::SelectParams(lpn_param.t, kFerretCuckooStashNum, + // kFerretCuckooHashNum); + // simple_map = MakeSimpleMap(option, lpn_param.n); + // } + + auto spcot_size = lpn_param.n / lpn_param.t; + for (uint64_t i = 0; i < batch_num; ++i) { + // the ot generated by this batch (including the seeds for next batch if + // necessary) + auto batch_ot_num = std::min(lpn_param.n, ot_num - i * batch_size); + auto working_s = out_span.subspan(i * batch_size, batch_ot_num); + + auto idx_num = lpn_param.t; + auto idx_range = batch_ot_num; + if (lpn_param.noise_asm == LpnNoiseAsm::RegularNoise) { + MpCotRNSend(ctx, cot_mpcot, idx_range, idx_num, spcot_size, working_s, + mal); + } else { + YACL_THROW("Not Implemented!"); + // MpCotUNSend(ctx, cot_mpcot, simple_map, option, working_s); + } + + // use lpn to calculate v*A + // llc.Encode(in,out) would calculate out = out + in * A + llc.Encode(working_v, working_s); + + // bool is_last_batch = (i == batch_num - 1); + // update v (first lpn_k of va^s) + if ((ot_num - i * batch_size) > batch_ot_num) { + // update v for the next batch + memcpy(working_v.data(), working_s.data() + batch_size, + lpn_param.k * sizeof(uint128_t)); + + // manually set the cot for next batch mpcot + UninitAlignedVector mpcot_data(mpcot_cot_num); + memcpy(mpcot_data.data(), working_s.data() + batch_size + lpn_param.k, + mpcot_cot_num * sizeof(uint128_t)); + cot_mpcot = + MakeCompactOtSendStore(std::move(mpcot_data), base_cot.GetDelta()); + } else { + break; + } + } +} + +void FerretOtExtRecv(const std::shared_ptr& ctx, + const OtRecvStore& base_cot, const LpnParam& lpn_param, + uint64_t ot_num, /* compact mode */ OtRecvStore* out, + bool mal) { + YACL_ENFORCE(ctx->WorldSize() == 2); // Make sure that OT has two parties + YACL_ENFORCE(base_cot.Type() == OtStoreType::Compact); + YACL_ENFORCE(base_cot.Size() >= FerretCotHelper(lpn_param, ot_num)); + YACL_ENFORCE( + ot_num >= 2 * lpn_param.t, + "ot_num is {}, which should be much greater than 2 * lpn_param.t ({})", + ot_num, 2 * lpn_param.t); + YACL_ENFORCE(out->Type() == OtStoreType::Compact); + YACL_ENFORCE(out->Size() == ot_num); + + // get constants: the number of cot needed for mpcot phase + const auto mpcot_cot_num = MpCotRNHelper(lpn_param.t, lpn_param.n, mal); + + // get constants: batch information + const uint64_t cache_size = lpn_param.k + mpcot_cot_num; + const uint64_t batch_size = lpn_param.n - cache_size; + const uint64_t batch_num = (ot_num + batch_size - 1) / batch_size; + + // F2, but we store it in uint128_t + UninitAlignedVector u(lpn_param.k); + + // prepare u, w, where w = v ^ u * delta + // FIX ME: "Slice" would would force to slice original OtStore from "begin" to + // "end", it would be better to use "NextSlice" here, but it's not a const + // function. + auto cot_mpcot = base_cot.Slice(0, mpcot_cot_num); + auto cot_seed = base_cot.Slice(mpcot_cot_num, mpcot_cot_num + lpn_param.k); + auto working_w = cot_seed.CopyBlkBuf(); + + // get lpn public matrix A + uint128_t seed = SyncSeedRecv(ctx); + LocalLinearCode<10> llc(seed, lpn_param.n, lpn_param.k); + + // placeholder for the outputs + auto out_span = out->GetBlkBufSpan(); + + // For uniform noise assumption only + // CuckooIndex::Options option; + // std::unique_ptr simple_map; + // if (lpn_param.noise_asm == LpnNoiseAsm::UniformNoise) { + // YACL_THROW("Not Implemented!"); + // option = CuckooIndex::SelectParams(lpn_param.t, kFerretCuckooStashNum, + // kFerretCuckooHashNum); + // simple_map = MakeSimpleMap(option, lpn_param.n); + // } + + auto spcot_size = lpn_param.n / lpn_param.t; + for (uint64_t i = 0; i < batch_num; ++i) { + // the ot generated by this batch (including the seeds for next batch if + // necessary) + auto batch_ot_num = std::min(lpn_param.n, ot_num - i * batch_size); + auto working_r = out_span.subspan(i * batch_size, batch_ot_num); + + // run mpcot (get r) + auto idx_num = lpn_param.t; + auto idx_range = batch_ot_num; + + if (lpn_param.noise_asm == LpnNoiseAsm::RegularNoise) { + MpCotRNRecv(ctx, cot_mpcot, idx_range, idx_num, spcot_size, working_r, + mal); + } else { + YACL_THROW("Not Implemented!"); + // MpCotUNRecv(ctx, cot_mpcot, simple_map, option, e, working_r); + } + + // use lpn to calculate w*A, and u*A + // llc.Encode(in,out) would calculate out = out + in * A + llc.Encode(working_w, working_r); + + // bool is_last_batch = (i == batch_num - 1); + if ((ot_num - i * batch_size) > batch_ot_num) { + // update u, w (first lpn_k of va^s) + memcpy(working_w.data(), working_r.data() + batch_size, + lpn_param.k * sizeof(uint128_t)); + + // manually set the cot for next batch mpcot + UninitAlignedVector mpcot_data(mpcot_cot_num); + memcpy(mpcot_data.data(), working_r.data() + batch_size + lpn_param.k, + mpcot_cot_num * sizeof(uint128_t)); + cot_mpcot = MakeCompactOtRecvStore(std::move(mpcot_data)); + } else { + break; + } + } +} + OtSendStore FerretOtExtSend(const std::shared_ptr& ctx, const OtSendStore& base_cot, const LpnParam& lpn_param, uint64_t ot_num, @@ -76,7 +254,7 @@ OtSendStore FerretOtExtSend(const std::shared_ptr& ctx, // function. auto cot_mpcot = base_cot.Slice(0, mpcot_cot_num); auto cot_seed = base_cot.Slice(mpcot_cot_num, mpcot_cot_num + lpn_param.k); - auto working_v = cot_seed.CopyCotBlocks(); + auto working_v = cot_seed.CopyCotBlkBuf(); // get lpn public matrix A uint128_t seed = SyncSeedSend(ctx); @@ -167,7 +345,7 @@ OtRecvStore FerretOtExtRecv(const std::shared_ptr& ctx, // function. auto cot_mpcot = base_cot.Slice(0, mpcot_cot_num); auto cot_seed = base_cot.Slice(mpcot_cot_num, mpcot_cot_num + lpn_param.k); - auto working_w = cot_seed.CopyBlocks(); + auto working_w = cot_seed.CopyBlkBuf(); // get lpn public matrix A uint128_t seed = SyncSeedRecv(ctx); @@ -249,7 +427,7 @@ void FerretOtExtSend_cheetah(const std::shared_ptr& ctx, // prepare v (before silent expansion), where w = v ^ u * delta auto cot_seed = base_cot.Slice(0, lpn_param.k); auto cot_mpcot = base_cot.Slice(lpn_param.k, lpn_param.k + mpcot_cot_num); - auto working_v = cot_seed.CopyCotBlocks(); + auto working_v = cot_seed.CopyCotBlkBuf(); // get lpn public matrix A uint128_t seed = SyncSeedSend(ctx); @@ -334,7 +512,7 @@ void FerretOtExtRecv_cheetah(const std::shared_ptr& ctx, // prepare u, w, where w = v ^ u * delta auto cot_seed = base_cot.Slice(0, lpn_param.k); auto cot_mpcot = base_cot.Slice(lpn_param.k, lpn_param.k + mpcot_cot_num); - auto working_w = cot_seed.CopyBlocks(); + auto working_w = cot_seed.CopyBlkBuf(); // get lpn public matrix A uint128_t seed = SyncSeedRecv(ctx); diff --git a/yacl/kernel/algorithms/ferret_ote.h b/yacl/kernel/algorithms/ferret_ote.h index 4d9f6564..c33452b1 100644 --- a/yacl/kernel/algorithms/ferret_ote.h +++ b/yacl/kernel/algorithms/ferret_ote.h @@ -55,11 +55,25 @@ namespace yacl::crypto { // Security assumptions: // > Correlation-robust hash function, for more details about its // implementation, see `yacl/crypto-tools/rp.h` -// > Primal LPN, for more details, please see the original paper +// > Primal LPN, see: https://eprint.iacr.org/2020/924.pdf uint64_t FerretCotHelper(const LpnParam& lpn_param, uint64_t ot_num, bool mal = false); +void FerretOtExtSend(const std::shared_ptr& ctx, + /* compact mode */ const OtSendStore& base_cot, + const LpnParam& lpn_param, uint64_t ot_num, + /* compact mode */ OtSendStore* out, bool mal = false); + +void FerretOtExtRecv(const std::shared_ptr& ctx, + /* compact mode */ const OtRecvStore& base_cot, + const LpnParam& lpn_param, uint64_t ot_num, + /* compact mode */ OtRecvStore* out, bool mal = false); + +// ------------------------------ +// Historical or Customized APIs +// ------------------------------ + OtSendStore FerretOtExtSend(const std::shared_ptr& ctx, const OtSendStore& base_cot, const LpnParam& lpn_param, uint64_t ot_num, @@ -70,11 +84,6 @@ OtRecvStore FerretOtExtRecv(const std::shared_ptr& ctx, const LpnParam& lpn_param, uint64_t ot_num, bool mal = false); -// -// -------------------------- -// Customized -// -------------------------- -// // [Warning] for cheetah only void FerretOtExtSend_cheetah(const std::shared_ptr& ctx, const OtSendStore& base_cot, diff --git a/yacl/kernel/algorithms/ferret_ote_rn.h b/yacl/kernel/algorithms/ferret_ote_rn.h index 58c455fc..fbb86cdb 100644 --- a/yacl/kernel/algorithms/ferret_ote_rn.h +++ b/yacl/kernel/algorithms/ferret_ote_rn.h @@ -139,9 +139,9 @@ inline void MpCotRNRecv(const std::shared_ptr& ctx, auto uhash = math::UniversalHash(seed, out); // [Warning] low efficency - uint128_t choices = check_cot.CopyChoice().data()[0]; + uint128_t choices = check_cot.CopyBitBuf().data()[0]; - auto check_cot_data = check_cot.CopyBlocks(); + auto check_cot_data = check_cot.CopyBlkBuf(); auto diff = PackGf128(absl::MakeSpan(check_cot_data)); uhash = uhash ^ diff; diff --git a/yacl/kernel/algorithms/ferret_ote_test.cc b/yacl/kernel/algorithms/ferret_ote_test.cc index 80a1dfce..ea1c3c81 100644 --- a/yacl/kernel/algorithms/ferret_ote_test.cc +++ b/yacl/kernel/algorithms/ferret_ote_test.cc @@ -51,14 +51,18 @@ TEST_P(FerretOtExtTest, Works) { auto cots_compact = MockCompactOts(cot_num); // mock cots // WHEN + + OtSendStore ot_send(ot_num, OtStoreType::Compact); + OtRecvStore ot_recv(ot_num, OtStoreType::Compact); + auto sender = std::async([&] { - return FerretOtExtSend(lctxs[0], cots_compact.send, lpn_param, ot_num); + FerretOtExtSend(lctxs[0], cots_compact.send, lpn_param, ot_num, &ot_send); }); auto receiver = std::async([&] { - return FerretOtExtRecv(lctxs[1], cots_compact.recv, lpn_param, ot_num); + FerretOtExtRecv(lctxs[1], cots_compact.recv, lpn_param, ot_num, &ot_recv); }); - auto ot_recv = receiver.get(); - auto ot_send = sender.get(); + receiver.get(); + sender.get(); // THEN auto zero = MakeUint128(0, 0); diff --git a/yacl/kernel/algorithms/gywz_ote_test.cc b/yacl/kernel/algorithms/gywz_ote_test.cc index 2e092fd1..0e19947d 100644 --- a/yacl/kernel/algorithms/gywz_ote_test.cc +++ b/yacl/kernel/algorithms/gywz_ote_test.cc @@ -76,7 +76,7 @@ TEST_P(GywzParamTest, FerretSpCotWork) { auto base_ot = MockCompactOts(math::Log2Ceil(n)); // mock many base OTs auto delta = base_ot.send.GetDelta(); - // [Warning] Compact Cot doest not support CopyChoice. + // [Warning] Compact Cot doest not support CopyBitBuf. // TODO: fix it uint32_t index = 0; for (uint32_t i = 0; i < math::Log2Ceil(n); ++i) { @@ -117,7 +117,7 @@ TEST_P(GywzParamTest, FixIndexSpCotWork) { uint128_t delta = SecureRandSeed(); auto base_ot = MockCots(math::Log2Ceil(n), delta); // mock many base OTs - uint32_t index = base_ot.recv.CopyChoice().data()[0]; + uint32_t index = base_ot.recv.CopyBitBuf().data()[0]; std::vector send_out(n); std::vector recv_out(n); diff --git a/yacl/kernel/algorithms/iknp_ote.cc b/yacl/kernel/algorithms/iknp_ote.cc index 8f6d53cb..9aa9b263 100644 --- a/yacl/kernel/algorithms/iknp_ote.cc +++ b/yacl/kernel/algorithms/iknp_ote.cc @@ -85,7 +85,7 @@ void IknpOtExtSend(const std::shared_ptr& ctx, // Transpose. MatrixTranspose128(&batch0); - auto tmp_choice = base_ot.CopyChoice(); + auto tmp_choice = base_ot.CopyBitBuf(); batch1 = XorBatchedBlock(absl::MakeSpan(batch0), static_cast(*tmp_choice.data())); diff --git a/yacl/kernel/algorithms/iknp_ote.h b/yacl/kernel/algorithms/iknp_ote.h index 8dd01918..c53cdb98 100644 --- a/yacl/kernel/algorithms/iknp_ote.h +++ b/yacl/kernel/algorithms/iknp_ote.h @@ -75,7 +75,7 @@ inline OtSendStore IknpOtExtSend(const std::shared_ptr &ctx, IknpOtExtSend(ctx, base_ot, absl::MakeSpan(blocks), cot); auto ret = MakeOtSendStore(blocks); if (cot) { - auto tmp_choice = base_ot.CopyChoice(); + auto tmp_choice = base_ot.CopyBitBuf(); ret.SetDelta(static_cast(*tmp_choice.data())); } return ret; // FIXME: Drop explicit copy diff --git a/yacl/kernel/algorithms/iknp_ote_test.cc b/yacl/kernel/algorithms/iknp_ote_test.cc index 0807ec35..db28abc3 100644 --- a/yacl/kernel/algorithms/iknp_ote_test.cc +++ b/yacl/kernel/algorithms/iknp_ote_test.cc @@ -112,7 +112,7 @@ TEST_P(IknpCotExtTest, Works) { // THEN // cot correlation = base ot choice - uint128_t check = base_ot.recv.CopyChoice().data()[0]; + uint128_t check = base_ot.recv.CopyBitBuf().data()[0]; for (size_t i = 0; i < num_ot; ++i) { EXPECT_NE(recv_out[i], 0); EXPECT_NE(send_out[i][0], 0); @@ -141,7 +141,7 @@ TEST_P(IknpCotExtTest, OtStoreWorks) { // THEN // cot correlation = base ot choice - uint128_t check = base_ot.recv.CopyChoice().data()[0]; // base ot choices + uint128_t check = base_ot.recv.CopyBitBuf().data()[0]; // base ot choices uint128_t delta = send_out.GetDelta(); // cot's delta EXPECT_EQ(check, delta); for (size_t i = 0; i < num_ot; ++i) { diff --git a/yacl/kernel/algorithms/kos_ote.cc b/yacl/kernel/algorithms/kos_ote.cc index 052415cf..20824bc7 100644 --- a/yacl/kernel/algorithms/kos_ote.cc +++ b/yacl/kernel/algorithms/kos_ote.cc @@ -201,7 +201,7 @@ void KosOtExtSend(const std::shared_ptr& ctx, } } - uint128_t delta = static_cast(*base_ot.CopyChoice().data()); + uint128_t delta = static_cast(*base_ot.CopyBitBuf().data()); q_ext.resize(ot_num_valid); auto& batch0 = q_ext; auto batch1 = VecXorMonochrome(absl::MakeSpan(q_ext), delta); diff --git a/yacl/kernel/algorithms/kos_ote.h b/yacl/kernel/algorithms/kos_ote.h index 8281e3f5..750ff854 100644 --- a/yacl/kernel/algorithms/kos_ote.h +++ b/yacl/kernel/algorithms/kos_ote.h @@ -82,7 +82,7 @@ inline OtSendStore KosOtExtSend(const std::shared_ptr& ctx, KosOtExtSend(ctx, base_ot, absl::MakeSpan(blocks), cot); auto ret = MakeOtSendStore(blocks); if (cot) { - auto tmp_choice = base_ot.CopyChoice(); + auto tmp_choice = base_ot.CopyBitBuf(); ret.SetDelta(static_cast(*tmp_choice.data())); } return ret; // FIXME: Drop explicit copy diff --git a/yacl/kernel/algorithms/kos_ote_test.cc b/yacl/kernel/algorithms/kos_ote_test.cc index 080eba0c..b414698b 100644 --- a/yacl/kernel/algorithms/kos_ote_test.cc +++ b/yacl/kernel/algorithms/kos_ote_test.cc @@ -57,7 +57,7 @@ TEST_P(KosOtExtTest, RotTestWorks) { sender.get(); // THEN - uint128_t delta = ot_store.recv.CopyChoice().data()[0]; + uint128_t delta = ot_store.recv.CopyBitBuf().data()[0]; uint128_t zero = MakeUint128(0, 0); for (size_t i = 0; i < num_ot; ++i) { bool choice = choices[i]; @@ -95,7 +95,7 @@ TEST_P(KosOtExtTest, CotTestWorks) { sender.get(); // THEN - uint128_t check = ot_store.recv.CopyChoice().data()[0]; + uint128_t check = ot_store.recv.CopyBitBuf().data()[0]; uint128_t zero = MakeUint128(0, 0); for (size_t i = 0; i < num_ot; ++i) { bool choice = choices[i]; @@ -154,7 +154,7 @@ TEST_P(KosOtExtTest, CotStoreTestWorks) { auto send_out = sender.get(); // THEN - uint128_t check = ot_store.recv.CopyChoice().data()[0]; // base ot choices + uint128_t check = ot_store.recv.CopyBitBuf().data()[0]; // base ot choices uint128_t delta = send_out.GetDelta(); // cot's delta EXPECT_EQ(check, delta); for (size_t i = 0; i < num_ot; ++i) { diff --git a/yacl/kernel/algorithms/ot_store.cc b/yacl/kernel/algorithms/ot_store.cc index 19cb9335..72c43d10 100644 --- a/yacl/kernel/algorithms/ot_store.cc +++ b/yacl/kernel/algorithms/ot_store.cc @@ -24,9 +24,9 @@ namespace yacl::crypto { -//================================// -// Slice Base // -//================================// +//---------------------------------- +// Slice Base +//---------------------------------- void SliceBase::ConsistencyCheck() const { YACL_ENFORCE( @@ -76,9 +76,9 @@ void SliceBase::Reset() { ConsistencyCheck(); } -//================================// -// OtRecvStore // -//================================// +//---------------------------------- +// OtRecvStore +//---------------------------------- OtRecvStore::OtRecvStore(BitBufPtr bit_ptr, BlkBufPtr blk_ptr, uint64_t use_ctr, uint64_t use_size, uint64_t buf_ctr, uint64_t buf_size, @@ -98,33 +98,29 @@ OtRecvStore::OtRecvStore(uint64_t num, OtStoreType type) : type_(type) { ConsistencyCheck(); } -Buffer OtRecvStore::GetChoiceBuf() { - // Constructs Buffer object by copy - return Buffer(bit_buf_->data(), bit_buf_->num_blocks() * sizeof(uint128_t)); -} +OtRecvStore OtRecvStore::NextSlice(uint64_t num) { + // Recall: A new slice looks like the follwoing: + // + // |---------------|-----slice-----|----------------| internal buffer + // a b c d + // + // internal_use_ctr_ = b + // internal_use_size_ = c - b + // internal_buf_ctr_ = b + // internal_buf_size_ = c -Buffer OtRecvStore::GetBlockBuf() { - // Constructs Buffer object by copy - return Buffer(blk_buf_->data(), blk_buf_->size() * sizeof(uint128_t)); -} + auto out = Slice(GetBufCtr(), GetBufCtr() + num); + // |---------------|-----slice-----|----------------| internal buffer + // a b c d + // + // internal_use_ctr_ = a + // internal_use_size_ = d - a + // internal_buf_ctr_ = c (since the underlying buffer is already sliced to c) + // internal_buf_size_ = d - a -void OtRecvStore::Reset() { - SliceBase::Reset(); - bit_buf_.reset(); - blk_buf_.reset(); - type_ = OtStoreType::Compact; - ConsistencyCheck(); -} + IncreaseBufCtr(num); // increase the buffer counter -void OtRecvStore::ConsistencyCheck() const { - SliceBase::ConsistencyCheck(); - YACL_ENFORCE(blk_buf_->size() >= internal_buf_size_, - "Actual buffer size: {}, but recorded " - "internal buffer size is: {}", - blk_buf_->size(), internal_buf_size_); - if (type_ == OtStoreType::Normal) { - YACL_ENFORCE_EQ(bit_buf_->size(), blk_buf_->size()); - } + return out; } // FIX ME: a const OtRecvStore could execute "Slice" to get a non-const @@ -148,32 +144,50 @@ OtRecvStore OtRecvStore::Slice(uint64_t begin, uint64_t end) const { return {bit_buf_, blk_buf_, slice_use_ctr, slice_use_size, slice_buf_ctr, slice_buf_size, type_}; } -OtRecvStore OtRecvStore::NextSlice(uint64_t num) { - // Recall: A new slice looks like the follwoing: - // - // |---------------|-----slice-----|----------------| internal buffer - // a b c d - // - // internal_use_ctr_ = b - // internal_use_size_ = c - b - // internal_buf_ctr_ = b - // internal_buf_size_ = c - auto out = Slice(GetBufCtr(), GetBufCtr() + num); - // |---------------|-----slice-----|----------------| internal buffer - // a b c d - // - // internal_use_ctr_ = a - // internal_use_size_ = d - a - // internal_buf_ctr_ = c (since the underlying buffer is already sliced to c) - // internal_buf_size_ = d - a +void OtRecvStore::Reset() { + SliceBase::Reset(); + bit_buf_.reset(); + blk_buf_.reset(); + type_ = OtStoreType::Compact; + ConsistencyCheck(); +} - IncreaseBufCtr(num); // increase the buffer counter +uint128_t OtRecvStore::GetBlock(uint64_t idx) const { + YACL_ENFORCE(idx < GetUseSize()); + return blk_buf_->operator[](GetBufIdx(idx)); +} - return out; +void OtRecvStore::SetBlock(uint64_t idx, uint128_t val) { + blk_buf_->operator[](GetBufIdx(idx)) = val; +} + +absl::Span OtRecvStore::GetBlkBufSpan() { + YACL_ENFORCE(!IsSliced()); + return {reinterpret_cast(blk_buf_->data()), GetBufSize()}; +} + +Buffer OtRecvStore::GetBlkBuf() { + // Constructs Buffer object by copy + return {blk_buf_->data(), blk_buf_->size() * sizeof(uint128_t)}; +} + +OtRecvStore::BlkBufPtr OtRecvStore::StealBlkBuf() { + YACL_ENFORCE(blk_buf_.use_count() == 1); + auto ret = std::move(blk_buf_); + Reset(); + return ret; +} + +UninitAlignedVector OtRecvStore::CopyBlkBuf() const { + return {blk_buf_->begin() + internal_use_ctr_, + blk_buf_->begin() + internal_use_ctr_ + internal_use_size_}; } uint8_t OtRecvStore::GetChoice(uint64_t idx) const { + YACL_ENFORCE(idx < GetUseSize(), + "required idx={} is larger than the allowed use size={}", idx, + GetUseSize()); if (type_ == OtStoreType::Compact) { return blk_buf_->operator[](GetBufIdx(idx)) & 0x1; } else { @@ -181,27 +195,28 @@ uint8_t OtRecvStore::GetChoice(uint64_t idx) const { } } -uint128_t OtRecvStore::GetBlock(uint64_t idx) const { - return blk_buf_->operator[](GetBufIdx(idx)); -} - void OtRecvStore::SetChoice(uint64_t idx, bool val) { YACL_ENFORCE(type_ == OtStoreType::Normal, "Manipulating choice is currently not allowed in compact mode"); bit_buf_->operator[](GetBufIdx(idx)) = val; } -void OtRecvStore::SetBlock(uint64_t idx, uint128_t val) { - blk_buf_->operator[](GetBufIdx(idx)) = val; -} - void OtRecvStore::FlipChoice(uint64_t idx) { YACL_ENFORCE(type_ == OtStoreType::Normal, "Manipulating choice is currently not allowed in compact mode"); bit_buf_->operator[](GetBufIdx(idx)).flip(); } -dynamic_bitset OtRecvStore::CopyChoice() const { +Buffer OtRecvStore::GetBitBuf() { + // Constructs Buffer object by copy + return {bit_buf_->data(), bit_buf_->num_blocks() * sizeof(uint128_t)}; +} + +void OtRecvStore::SetBitBuf(const dynamic_bitset& in) { + bit_buf_ = std::make_shared>(in); +} + +dynamic_bitset OtRecvStore::CopyBitBuf() const { // [Warning] low efficency if (type_ == OtStoreType::Compact) { dynamic_bitset out(Size()); @@ -218,9 +233,15 @@ dynamic_bitset OtRecvStore::CopyChoice() const { return out; } -UninitAlignedVector OtRecvStore::CopyBlocks() const { - return {blk_buf_->begin() + internal_use_ctr_, - blk_buf_->begin() + internal_use_ctr_ + internal_use_size_}; +void OtRecvStore::ConsistencyCheck() const { + SliceBase::ConsistencyCheck(); + YACL_ENFORCE(blk_buf_->size() >= internal_buf_size_, + "Actual buffer size: {}, but recorded " + "internal buffer size is: {}", + blk_buf_->size(), internal_buf_size_); + if (type_ == OtStoreType::Normal) { + YACL_ENFORCE_EQ(bit_buf_->size(), blk_buf_->size()); + } } OtRecvStore MakeOtRecvStore(const dynamic_bitset& choices, @@ -303,9 +324,9 @@ OtRecvStore MakeCompactOtRecvStore(UninitAlignedVector&& blocks) { OtStoreType::Compact}; } -//================================// -// OtSendStore // -//================================// +//---------------------------------- +// OtSendStore +//---------------------------------- OtSendStore::OtSendStore(BlkBufPtr blk_ptr, uint128_t delta, uint64_t use_ctr, uint64_t use_size, uint64_t buf_ctr, uint64_t buf_size, @@ -326,11 +347,6 @@ OtSendStore::OtSendStore(uint64_t num, OtStoreType type) : type_(type) { ConsistencyCheck(); } -Buffer OtSendStore::GetBlockBuf() { - // Constructs Buffer object by copy - return Buffer(blk_buf_->data(), blk_buf_->size() * sizeof(uint128_t)); -} - void OtSendStore::Reset() { SliceBase::Reset(); blk_buf_.reset(); @@ -413,6 +429,11 @@ uint128_t OtSendStore::GetDelta() const { return delta_; } +void OtSendStore::SetDelta(uint128_t delta) { + YACL_ENFORCE(delta != 0, "Error, you can not set delta to zero."); + delta_ = delta; +} + uint128_t OtSendStore::GetBlock(uint64_t ot_idx, uint64_t msg_idx) const { YACL_ENFORCE(msg_idx == 0 || msg_idx == 1); const uint64_t ot_blk_num = (type_ == OtStoreType::Compact) ? 1 : 2; @@ -424,11 +445,6 @@ uint128_t OtSendStore::GetBlock(uint64_t ot_idx, uint64_t msg_idx) const { } } -void OtSendStore::SetDelta(uint128_t delta) { - YACL_ENFORCE(delta != 0, "Error, you can not set delta to zero."); - delta_ = delta; -} - void OtSendStore::SetNormalBlock(uint64_t ot_idx, uint64_t msg_idx, uint128_t val) { YACL_ENFORCE(type_ == OtStoreType::Normal, @@ -443,9 +459,26 @@ void OtSendStore::SetCompactBlock(uint64_t ot_idx, uint128_t val) { blk_buf_->operator[](GetBufIdx(ot_idx)) = val; } -UninitAlignedVector OtSendStore::CopyCotBlocks() const { +Buffer OtSendStore::GetBlkBuf() { + // Constructs Buffer object by copy + return {blk_buf_->data(), blk_buf_->size() * sizeof(uint128_t)}; +} + +absl::Span OtSendStore::GetBlkBufSpan() { + YACL_ENFORCE(!IsSliced()); + return {reinterpret_cast(blk_buf_->data()), GetBufSize()}; +} + +OtSendStore::BlkBufPtr OtSendStore::StealBlkBuf() { + YACL_ENFORCE(blk_buf_.use_count() == 1); + auto ret = std::move(blk_buf_); + Reset(); + return ret; +} + +UninitAlignedVector OtSendStore::CopyCotBlkBuf() const { YACL_ENFORCE(type_ == OtStoreType::Compact, - "CopyCotBlocks() is only allowed in compact mode"); + "CopyCotBlkBuf() is only allowed in compact mode"); return {blk_buf_->begin() + internal_buf_ctr_, blk_buf_->begin() + internal_buf_ctr_ + internal_use_size_}; } diff --git a/yacl/kernel/algorithms/ot_store.h b/yacl/kernel/algorithms/ot_store.h index 08d2274d..c898cb41 100644 --- a/yacl/kernel/algorithms/ot_store.h +++ b/yacl/kernel/algorithms/ot_store.h @@ -113,38 +113,56 @@ class OtRecvStore : public SliceBase { // get ot store type OtStoreType Type() const { return type_; } - // get a buffer copy of choice buf - Buffer GetChoiceBuf(); - - // get a buffer copy of block buf - Buffer GetBlockBuf(); - // reset ot store void Reset(); - // get the avaliable ot number for this slice1 + // get the avaliable ot number for this slice uint64_t Size() const { return GetUseSize(); } - // access a choice bit with a given slice index - uint8_t GetChoice(uint64_t idx) const; + // ----------------- + // Manipulate blocks + // ----------------- // access a block element with the given index uint128_t GetBlock(uint64_t idx) const; - // modify a choice bit(val) with a given slice index - void SetChoice(uint64_t idx, bool val); - // modify a block with a given slice index void SetBlock(uint64_t idx, uint128_t val); + // get a span of the block + absl::Span GetBlkBufSpan(); + + // get a buffer copy of block buf (in bytes) + Buffer GetBlkBuf(); + + // allow steal (in bytes) + BlkBufPtr StealBlkBuf(); + + // copy out the blocks (in bytes) + UninitAlignedVector CopyBlkBuf() const; + + // ------------------ + // Manipulate choices + // ------------------ + + // access a choice bit with a given slice index + uint8_t GetChoice(uint64_t idx) const; + // modify a choice bit(val) with a given slice index + + void SetChoice(uint64_t idx, bool val); + // modify a choice bit(val) with a given slice index + // flip a choice bit with a given slice index void FlipChoice(uint64_t idx); - // copy out the sliced choice buffer [wanring: low efficiency] - dynamic_bitset CopyChoice() const; + // get a buffer copy of bit buf (choice buf) + Buffer GetBitBuf(); + + // set bit buf + void SetBitBuf(const dynamic_bitset& in); // copy out the sliced choice buffer [wanring: low efficiency] - UninitAlignedVector CopyBlocks() const; + dynamic_bitset CopyBitBuf() const; private: // check the consistency of ot receiver store @@ -209,9 +227,6 @@ class OtSendStore : public SliceBase { // get ot store type OtStoreType Type() const { return type_; } - // get a buffer copy of block buf - Buffer GetBlockBuf(); - // reset ot store void Reset(); @@ -221,20 +236,29 @@ class OtSendStore : public SliceBase { // access the delta of the cot uint128_t GetDelta() const; - // access a block with the given index - uint128_t GetBlock(uint64_t ot_idx, uint64_t msg_idx) const; - // set the delta of the cot void SetDelta(uint128_t delta); + // access a block with the given index + uint128_t GetBlock(uint64_t ot_idx, uint64_t msg_idx) const; + // modify a block with the given index void SetNormalBlock(uint64_t ot_idx, uint64_t msg_idx, uint128_t val); // set a cot block void SetCompactBlock(uint64_t ot_idx, uint128_t val); + // get a buffer copy of block buf + Buffer GetBlkBuf(); + + // get a span of block buf + absl::Span GetBlkBufSpan(); + + // allow steal + BlkBufPtr StealBlkBuf(); + // copy out cot blocks - UninitAlignedVector CopyCotBlocks() const; + UninitAlignedVector CopyCotBlkBuf() const; private: // check the consistency of ot receiver store diff --git a/yacl/kernel/algorithms/ot_store_test.cc b/yacl/kernel/algorithms/ot_store_test.cc index bd95e364..d1cd8be4 100644 --- a/yacl/kernel/algorithms/ot_store_test.cc +++ b/yacl/kernel/algorithms/ot_store_test.cc @@ -75,6 +75,10 @@ RandCompactOtSendStore(uint64_t num) { } } // namespace +// ---------------------------------------------------- +// TESTs +// ---------------------------------------------------- + TEST(OtRecvStoreTest, ConstructorTest) { // GIVEN const size_t ot_num = 100; @@ -134,6 +138,46 @@ TEST(OtRecvStoreTest, GetElementsTest) { EXPECT_THROW(ot_store.GetBlock(-1), yacl::Exception); } +TEST(OtRecvStoreTest, BlkBufTest) { + // GIVEN + const size_t ot_num = 100; + auto recv_choices = RandBits>(ot_num); + auto recv_blocks = RandVec(ot_num); + auto ot_store = MakeOtRecvStore(recv_choices, recv_blocks); + + // get element tests + auto idx = RandInRange(ot_num); + + auto span = ot_store.GetBlkBufSpan(); + EXPECT_EQ(span.size(), ot_num); + + EXPECT_EQ(ot_store.GetChoice(idx), recv_choices[idx]); + EXPECT_EQ(span[idx], recv_blocks[idx]); + + EXPECT_EQ(ot_store.GetChoice(0), recv_choices[0]); + EXPECT_EQ(span[0], recv_blocks[0]); +} + +TEST(OtSendStoreTest, BlkBufTest) { + // GIVEN + const size_t ot_num = 100; + std::vector> blocks(ot_num); + Prg prg; + for (size_t i = 0; i < ot_num; ++i) { + blocks[i][0] = prg(); + blocks[i][1] = prg(); + } + auto ot_store = MakeOtSendStore(blocks); + + // get element tests + auto idx = RandInRange(ot_num); + + auto span = ot_store.GetBlkBufSpan(); + EXPECT_EQ(span.size(), 2 * ot_num); + EXPECT_EQ(span[2 * idx], blocks[idx][0]); + EXPECT_EQ(span[2 * idx + 1], blocks[idx][1]); +} + TEST(OtSendStoreTest, ConstructorTest) { // GIVEN const uint64_t ot_num = 2; diff --git a/yacl/kernel/algorithms/sgrr_ote.cc b/yacl/kernel/algorithms/sgrr_ote.cc index c8e360f5..14febe5f 100644 --- a/yacl/kernel/algorithms/sgrr_ote.cc +++ b/yacl/kernel/algorithms/sgrr_ote.cc @@ -307,7 +307,7 @@ void SgrrOtExtRecv_fixed_index(const OtRecvStore& base_ot, uint32_t n, // we need log(n) 1-2 OTs from log(n) ROTs // most significant bit first - dynamic_bitset choice = base_ot.CopyChoice(); + dynamic_bitset choice = base_ot.CopyBitBuf(); const uint64_t index = GetPuncturedIndex(choice, ot_num - 1); YACL_ENFORCE_LT(index, n); // index < n diff --git a/yacl/kernel/algorithms/softspoken_ote.cc b/yacl/kernel/algorithms/softspoken_ote.cc index de907ae6..12ff4379 100644 --- a/yacl/kernel/algorithms/softspoken_ote.cc +++ b/yacl/kernel/algorithms/softspoken_ote.cc @@ -22,7 +22,9 @@ #include "yacl/base/aligned_vector.h" #include "yacl/base/byte_container_view.h" +#include "yacl/base/exception.h" #include "yacl/crypto/tools/common.h" +#include "yacl/kernel/algorithms/ot_store.h" #include "yacl/math/f2k/f2k.h" #include "yacl/utils/matrix_utils.h" #include "yacl/utils/serialize.h" @@ -158,6 +160,28 @@ inline void XorReduceImpl(uint64_t k, absl::Span inout) { } // namespace +void SoftspokenOtExtSend(const std::shared_ptr& ctx, + /* rot */ const OtRecvStore& base_ot, + /* cot */ OtSendStore* out, uint64_t k, uint64_t step, + bool mal) { + std::vector> send_blocks(out->Size()); + auto send = + SoftspokenOtExtSender(k, step, mal, out->Type() == OtStoreType::Compact); + send.OneTimeSetup(ctx, base_ot); + send.Send(ctx, out); +} + +void SoftspokenOtExtRecv(const std::shared_ptr& ctx, + /* rot */ const OtSendStore& base_ot, + const dynamic_bitset& choices, + /* cot */ OtRecvStore* out, uint64_t k, uint64_t step, + bool mal) { + auto recv = SoftspokenOtExtReceiver(k, step, mal, + out->Type() == OtStoreType::Compact); + recv.OneTimeSetup(ctx, base_ot); + recv.Recv(ctx, choices, out); +} + SoftspokenOtExtSender::SoftspokenOtExtSender(uint64_t k, uint64_t step, bool mal, bool compact) : k_(k), step_(step), mal_(mal), compact_(compact) { @@ -240,7 +264,7 @@ void SoftspokenOtExtSender::OneTimeSetup( dup_base_ot.SetChoice(0, 1); } - delta_ = dup_base_ot.CopyChoice().data()[0]; + delta_ = dup_base_ot.CopyBitBuf().data()[0]; auto recv_size = 128 * 2 * sizeof(uint128_t) + pprf_num_ * (mal_ ? 64 : 0); auto recv_buf = ctx->Recv(ctx->NextRank(), "SGRR_OTE:RECV-CORR"); @@ -256,7 +280,7 @@ void SoftspokenOtExtSender::OneTimeSetup( auto sub_ot = dup_base_ot.NextSlice(k_limit); // TODO(@wenfan): [low efficiency] It would copy dynamic_bitset. // punctured index for i-th pprf - punctured_idx_[i] = sub_ot.CopyChoice().data()[0]; + punctured_idx_[i] = sub_ot.CopyBitBuf().data()[0]; // punctured leaves for the i-th pprf auto leaves = absl::MakeSpan(punctured_leaves_.data() + i * pprf_range_, range_limit); @@ -608,6 +632,7 @@ void SoftspokenOtExtSender::Send( const uint64_t batch_num = (expand_numOt - batch_offset + kBatchSize - 1) / kBatchSize; const uint64_t all_batch_num = super_batch_num * step + batch_num; + YACL_ENFORCE(all_batch_num * kBatchSize == expand_numOt); UninitAlignedVector, 32> allV(all_batch_num); @@ -723,6 +748,243 @@ void SoftspokenOtExtSender::Send( } } +void SoftspokenOtExtSender::Send(const std::shared_ptr& ctx, + OtSendStore* out) { + YACL_ENFORCE(out->Type() == OtStoreType::Compact); + + if (!inited_) { + OneTimeSetup(ctx); + } + + const uint64_t& step = step_; + const uint64_t batch_size = kBatchSize; + const uint64_t super_batch_size = step * batch_size; + const uint64_t numOt = out->Size(); + + const uint64_t expand_numOt = + (numOt + kS + kBatchSize - 1) / kBatchSize * kBatchSize; + const uint64_t super_batch_num = numOt / super_batch_size; + const uint64_t batch_offset = super_batch_num * super_batch_size; + const uint64_t batch_num = + (expand_numOt - batch_offset + kBatchSize - 1) / kBatchSize; + const uint64_t all_batch_num = super_batch_num * step + batch_num; + + YACL_ENFORCE(all_batch_num * kBatchSize == expand_numOt); + + out->SetDelta(delta_); + UninitAlignedVector, 32> allV(all_batch_num); + // OT extension + // AVX need to be aligned to 32 bytes. + // Extra one array for consitency check in batch_num for-loop. + UninitAlignedVector, 32> V(step + 1); + + // Hash Buffer to perform AES/PRG + // Xor Buffer to perform XorReduce ( \sum x PRG(M_x) ) + auto hash_buff = UninitAlignedVector(compress_leaves_.size()); + auto xor_buff = UninitAlignedVector(pprf_num_ * pprf_range_, 0); + + // deal with super batch + for (uint64_t t = 0; t < super_batch_num; ++t) { + // The same as IKNP OTe, see `yacl/crypto/primitive/ot/iknp_ote_cc` + // 1. receive the masked choices + auto recv_buff = ctx->Recv(ctx->NextRank(), "softspoken_switch_u"); + auto recv_U = absl::MakeSpan(static_cast(recv_buff.data()), + recv_buff.size() / sizeof(uint128_t)); + YACL_ENFORCE(recv_U.size() == step * pprf_num_); + + for (uint64_t s = 0; s < step; ++s) { + // 2. smallfield/subspace VOLE + GenSfVole(absl::MakeSpan(hash_buff), absl::MakeSpan(xor_buff), + absl::MakeSpan(recv_U.data() + s * pprf_num_, pprf_num_), + absl::MakeSpan(V[s])); + if (mal_) { + allV[t * step + s] = V[s]; + } + + // 3. Matrix Transpose + MatrixTranspose128(&V[s]); + + for (uint64_t j = 0; j < kBatchSize; ++j) { + out->SetCompactBlock(t * super_batch_size + s * kBatchSize + j, + V[s][j]); + } + } + } + + // deal with normal batch + for (uint64_t t = 0; t < batch_num; ++t) { + // The same as IKNP OTe + // 1. receive the masked choices + auto recv_buff = ctx->Recv(ctx->NextRank(), "softspoken_switch_u"); + auto recv_U = absl::MakeSpan(static_cast(recv_buff.data()), + recv_buff.size() / sizeof(uint128_t)); + YACL_ENFORCE(recv_U.size() == pprf_num_); + + // 2. smallfield/subspace VOLE + GenSfVole(absl::MakeSpan(hash_buff), absl::MakeSpan(xor_buff), + absl::MakeSpan(recv_U), absl::MakeSpan(V[t])); + if (mal_) { + allV[super_batch_num * step + t] = V[t]; + } + + // 3. Matrix Transpose + if (numOt > batch_offset + t * kBatchSize) { + MatrixTranspose128(&V[t]); + + const uint64_t limit = + std::min(kBatchSize, numOt - batch_offset - t * kBatchSize); + for (uint64_t j = 0; j < limit; ++j) { + out->SetCompactBlock(batch_offset + t * kBatchSize + j, V[t][j]); + } + } + } + + if (mal_) { + // Sender generates a random seed and sends it to receiver. + uint128_t seed = SyncSeedSend(ctx); + + // Consistency check + std::vector rand_samples(all_batch_num * 2); + PrgAesCtr(seed, absl::Span(rand_samples)); + + CheckMsg check_msgs; + for (size_t i = 0; i < all_batch_num; ++i) { + for (size_t k = 0; k < kKappa; ++k) { + check_msgs.t[k] ^= ClMul64( + absl::MakeSpan(rand_samples.data() + i * 2, 2), + absl::MakeSpan(reinterpret_cast(allV[i].data() + k), 2)); + } + } + + CheckMsg msgs; + std::array check_vals; + for (size_t k = 0; k < kKappa; ++k) { + check_vals[k] = Reduce64(check_msgs.t[k]); + } + + msgs.Unpack(ctx->Recv(ctx->NextRank(), fmt::format("MAL-SS-CHECK-FINAL"))); + + for (size_t k = 0; k < kKappa; ++k) { + auto recv_check_val = msgs.t[k] ^ (p_idx_mask_[k] & msgs.x); + YACL_ENFORCE(recv_check_val == check_vals[k]); + } + } +} + +void SoftspokenOtExtReceiver::Recv(const std::shared_ptr& ctx, + const dynamic_bitset& choices, + /* compact cot */ OtRecvStore* out) { + YACL_ENFORCE(out->Type() == OtStoreType::Compact); + + if (!inited_) { + OneTimeSetup(ctx); + } + + YACL_ENFORCE(choices.size() == out->Size()); + const uint64_t& step = step_; + const uint64_t batch_size = kBatchSize; + const uint64_t super_batch_size = step * batch_size; + const uint64_t numOt = out->Size(); + const uint64_t expand_numOt = + (numOt + kS + kBatchSize - 1) / kBatchSize * kBatchSize; + const uint64_t super_batch_num = numOt / super_batch_size; + const uint64_t batch_offset = super_batch_num * super_batch_size; + const uint64_t batch_num = + (expand_numOt - batch_offset + kBatchSize - 1) / kBatchSize; + const uint64_t all_batch_num = super_batch_num * step + batch_num; + YACL_ENFORCE(all_batch_num * kBatchSize == expand_numOt); + + UninitAlignedVector, 32> allW(all_batch_num); + auto choice_ext = ExtendChoice(choices, expand_numOt); + // AVX need to be aligned to 32 bytes. + // Extra one array for consitency check in batch_num for-loop. + UninitAlignedVector, 32> W(step + 1); + // AES Buffer & Xor Buffer to perform AES/PRG and XorReduce + auto xor_buff = UninitAlignedVector(pprf_num_ * pprf_range_, 0); + UninitAlignedVector U(pprf_num_ * step); + + // deal with super batch + for (uint64_t t = 0; t < super_batch_num; ++t) { + // The same as IKNP OTe, see `yacl/crypto/primitive/ot/iknp_ote_cc` + // 1. smallfield/subspace VOLE + for (uint64_t s = 0; s < step; ++s) { + GenSfVole(choice_ext.data()[t * step + s], absl::MakeSpan(xor_buff), + absl::MakeSpan(U.data() + s * pprf_num_, pprf_num_), + absl::MakeSpan(W[s])); + if (mal_) { + allW[t * step + s] = W[s]; + } + } + // 2. send the masked choices + ctx->SendAsync(ctx->NextRank(), + ByteContainerView(U.data(), U.size() * sizeof(uint128_t)), + "softspoken_switch_u"); + for (uint64_t s = 0; s < step; ++s) { + // 3. matrix transpose + MatrixTranspose128(&W[s]); + for (uint64_t j = 0; j < kBatchSize; ++j) { + out->SetBlock(t * super_batch_size + s * batch_size + j, W[s][j]); + } + } + } + + // deal with normal bathc + for (uint64_t t = 0; t < batch_num; ++t) { + // The same as IKNP OTe + // 1. smallfield/subspace VOLE + GenSfVole(choice_ext.data()[super_batch_num * step + t], + absl::MakeSpan(xor_buff), absl::MakeSpan(U), + absl::MakeSpan(W[t])); + if (mal_) { + allW[super_batch_num * step + t] = W[t]; + } + // 2. send the masked choices + ctx->SendAsync(ctx->NextRank(), + ByteContainerView(U.data(), pprf_num_ * sizeof(uint128_t)), + "softspoken_switch_u"); + + // 3. matrix transpose + if (numOt > batch_offset + t * kBatchSize) { + MatrixTranspose128(&W[t]); + const uint64_t limit = + std::min(kBatchSize, numOt - batch_offset - t * kBatchSize); + for (uint64_t j = 0; j < limit; ++j) { + out->SetBlock(batch_offset + t * kBatchSize + j, W[t][j]); + } + } + } + + if (mal_) { + // Recevies the random seed from sender + uint128_t seed = SyncSeedRecv(ctx); + + // Consistency check + std::vector rand_samples(all_batch_num * 2); + PrgAesCtr(seed, absl::Span(rand_samples)); + + CheckMsg check_msgs; + auto choice_span = absl::MakeSpan( + reinterpret_cast(choice_ext.data()), all_batch_num * 2); + check_msgs.x ^= ClMul64(absl::MakeSpan(rand_samples), choice_span); + + for (size_t i = 0; i < all_batch_num; ++i) { + for (size_t k = 0; k < kKappa; ++k) { + check_msgs.t[k] ^= ClMul64( + absl::MakeSpan(rand_samples.data() + i * 2, 2), + absl::MakeSpan(reinterpret_cast(allW[i].data() + k), 2)); + } + } + + CheckMsg msgs; + msgs.x = Reduce64(check_msgs.x); + for (size_t k = 0; k < kKappa; ++k) { + msgs.t[k] = Reduce64(check_msgs.t[k]); + } + auto buf = msgs.Pack(); + ctx->SendAsync(ctx->NextRank(), buf, fmt::format("MAL-SS-CHECK-FINAL")); + } +} + // old style interface void SoftspokenOtExtReceiver::Recv(const std::shared_ptr& ctx, const dynamic_bitset& choices, diff --git a/yacl/kernel/algorithms/softspoken_ote.h b/yacl/kernel/algorithms/softspoken_ote.h index 8503c0c1..42d18f06 100644 --- a/yacl/kernel/algorithms/softspoken_ote.h +++ b/yacl/kernel/algorithms/softspoken_ote.h @@ -42,10 +42,9 @@ namespace yacl::crypto { // SoftSpoken OT Extension Implementation // -// This implementation bases on Softspoken OTE, for more theoretical details, -// see https://eprint.iacr.org/2022/192.pdf, figure 7 and figure 8. -// In our implementation, we choose repetition code as default code for subspace -// VOLE. +// See https://eprint.iacr.org/2022/192.pdf, figure 7 and figure 8. +// In our implementation, we choose repetition code as the default code for +// subspace VOLE. // // kappa/k instances kappa/k instances // +---------+ +----------+ +------------------+ +-------+ @@ -67,13 +66,30 @@ namespace yacl::crypto { // NOTE: // => OT Extension sender requires receiver base ot context. // => OT Extension receiver requires sender base ot context. -// => Computation cost would be O(2^k/k). -// => Communication for each OT needs 128/k bits. +// => Computation complexity: O(2^k/k). +// => Communication for each OT: 128/k bits. // => parameter k should be a small number (no greater than 10). -// => k = 2, 4, 8 are recommended in the localhost, LAN, WAN setting -// respectively. -// => step = 64 for k = 1 or 2; step = 32 for k = 3 or 4. - +// +// Our recommendations for different use cases (you may also choose your own +// parameter) +// 1. localhost: k=2, step=64 +// 2. LAN: k=4, step=32 +// 3. WAN: k=8, step=32 + +void SoftspokenOtExtSend(const std::shared_ptr& ctx, + /* rot */ const OtRecvStore& base_ot, + /* cot */ OtSendStore* out, uint64_t k = 2, + uint64_t step = 0, bool mal = false); + +void SoftspokenOtExtRecv(const std::shared_ptr& ctx, + /* rot */ const OtSendStore& base_ot, + const dynamic_bitset& choices, + /* cot */ OtRecvStore* out, uint64_t k = 2, + uint64_t step = 0, bool mal = false); + +// ------------------------------ +// Historical or Customized APIs +// ------------------------------ class SoftspokenOtExtSender { public: explicit SoftspokenOtExtSender(uint64_t k = 2, uint64_t step = 0, @@ -84,6 +100,10 @@ class SoftspokenOtExtSender { void OneTimeSetup(const std::shared_ptr& ctx, const OtRecvStore& base_ot /* rot */); + // this function should be the core api + void Send(const std::shared_ptr& ctx, + /* compact cot*/ OtSendStore* out); + // old-style interface void Send(const std::shared_ptr& ctx, absl::Span> send_blocks, bool cot = false); @@ -139,6 +159,9 @@ class SoftspokenOtExtSender { bool compact_{false}; // compact mode }; +// ------------------------------ +// Historical or Customized APIs +// ------------------------------ class SoftspokenOtExtReceiver { public: explicit SoftspokenOtExtReceiver(uint64_t k = 2, uint64_t step = 0, @@ -149,6 +172,11 @@ class SoftspokenOtExtReceiver { void OneTimeSetup(const std::shared_ptr& ctx, const OtSendStore& base_ot /* rot */); + // this function should be the core api + void Recv(const std::shared_ptr& ctx, + const dynamic_bitset& choices, + /* compact cot */ OtRecvStore* out); + // old-style interface void Recv(const std::shared_ptr& ctx, const dynamic_bitset& choices, diff --git a/yacl/kernel/algorithms/softspoken_ote_test.cc b/yacl/kernel/algorithms/softspoken_ote_test.cc index 3ff0344e..24c461fd 100644 --- a/yacl/kernel/algorithms/softspoken_ote_test.cc +++ b/yacl/kernel/algorithms/softspoken_ote_test.cc @@ -361,6 +361,42 @@ TEST_P(SoftspokenOtExtTest, CotStoreWorks) { } } +TEST_P(SoftspokenOtExtTest, APIWorks) { + // GIVEN + const int kWorldSize = 2; + const size_t num_ot = GetParam().num_ot; + const bool mal = GetParam().mal; + auto lctxs = link::test::SetupWorld(kWorldSize); // setup network + auto base_ot = MockRots(128); // mock option + auto choices = RandBits>(num_ot); // get input + + // WHEN + // std::vector> send_out(num_ot); + // std::vector recv_out(num_ot); + OtSendStore ot_send(num_ot, OtStoreType::Compact); + OtRecvStore ot_recv(num_ot, OtStoreType::Compact); + + std::future sender = std::async([&] { + SoftspokenOtExtSend(lctxs[0], base_ot.recv, &ot_send, 2, false, mal); + }); + std::future receiver = std::async([&] { + SoftspokenOtExtRecv(lctxs[1], base_ot.send, choices, &ot_recv, 2, false, + mal); + }); + receiver.get(); + sender.get(); + + // THEN + for (size_t i = 0; i < num_ot; ++i) { + // correctness of ot + EXPECT_EQ(ot_send.GetBlock(i, ot_recv.GetChoice(i)), ot_recv.GetBlock(i)); + + // generated ot messages should not equal + EXPECT_NE(ot_send.GetBlock(i, 1 - ot_recv.GetChoice(i)), + ot_recv.GetBlock(i)); + } +} + INSTANTIATE_TEST_SUITE_P(Works_Instances, SoftspokenStepTest, testing::Values(StepTestParams{1}, // StepTestParams{2}, // diff --git a/yacl/kernel/kernel.h b/yacl/kernel/kernel.h index f2ca4d38..b9b82f28 100644 --- a/yacl/kernel/kernel.h +++ b/yacl/kernel/kernel.h @@ -20,9 +20,14 @@ namespace yacl::crypto { +// Kernel interface class class Kernel { public: - enum class Kind { SingleThread, MultiThread }; + enum class Kind { + SingleThread, // supports eval + MultiThread, // supports eval, eval_multithread + Streaming // supports eval, eval_multithread, and eval_streaming + }; virtual ~Kernel() = default; @@ -32,13 +37,29 @@ class Kernel { // virtual void comm(); - // virtual void eval(); + // virtual void eval() = 0; }; -// Stream kernel -class StreamKernel : public Kernel { +// Single-thread kernel +class SingleThreadKernel : public Kernel { public: Kind kind() const override { return Kind::SingleThread; } + + // virtual void eval(/* kernel-specific args*/) = 0; +}; + +// Multi-thread kernel +class MultiThreadKernel : public Kernel { + public: + Kind kind() const override { return Kind::MultiThread; } + // virtual void eval_multithread(/* kernel-specific args*/) = 0; +}; + +// Streaming kernel +class StreamingKernel : public Kernel { + public: + Kind kind() const override { return Kind::Streaming; } + // virtual void eval_streaming(/* kernel-specific args*/) = 0; }; } // namespace yacl::crypto diff --git a/yacl/kernel/ot_kernel.cc b/yacl/kernel/ot_kernel.cc new file mode 100644 index 00000000..923dbd9c --- /dev/null +++ b/yacl/kernel/ot_kernel.cc @@ -0,0 +1,275 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "yacl/kernel/ot_kernel.h" + +#include +#include +#include + +#include "yacl/base/byte_container_view.h" +#include "yacl/base/exception.h" +#include "yacl/crypto/rand/rand.h" +#include "yacl/crypto/tools/common.h" +#include "yacl/crypto/tools/ro.h" +#include "yacl/kernel/algorithms/base_ot.h" +#include "yacl/kernel/algorithms/ferret_ote.h" +#include "yacl/kernel/algorithms/ot_store.h" +#include "yacl/kernel/algorithms/softspoken_ote.h" +#include "yacl/link/context.h" +#include "yacl/secparam.h" +#include "yacl/utils/parallel.h" +#include "yacl/utils/thread_pool.h" + +namespace yacl::crypto { + +namespace { + +constexpr uint32_t kBatchSize = 128; + +using OtMsg = uint128_t; +using OtMsgPair = std::array; +using OtChoices = dynamic_bitset; + +// Inplace-conversion from cot to rot +void naive_cot2rot(const OtSendStore& cot_store, OtSendStore* rot_store) { + YACL_ENFORCE(cot_store.Size() == rot_store->Size()); // size should match + YACL_ENFORCE(cot_store.Type() == OtStoreType::Compact); // compact mode + YACL_ENFORCE(rot_store->Type() == OtStoreType::Normal); // normal mode + const uint32_t ot_num = cot_store.Size(); // warning: narrow conversion + parallel_for(0, ot_num, 1, [&](uint64_t beg, uint64_t end) { + for (uint64_t i = beg; i < end; ++i) { + rot_store->SetNormalBlock(i, 0, CrHash_128(cot_store.GetBlock(i, 0))); + rot_store->SetNormalBlock(i, 1, CrHash_128(cot_store.GetBlock(i, 1))); + } + }); +} + +// Inplace-conversion from cot to rot +void naive_cot2rot(const OtRecvStore& cot_store, OtRecvStore* rot_store) { + const uint32_t ot_num = cot_store.Size(); // warning: narrow conversion + YACL_ENFORCE(cot_store.Type() == OtStoreType::Compact); // compact mode + YACL_ENFORCE(rot_store->Type() == OtStoreType::Normal); // normal mode + auto choices = cot_store.CopyBitBuf(); + rot_store->SetBitBuf(choices); + + parallel_for(0, ot_num, 1, [&](uint64_t beg, uint64_t end) { + for (uint64_t i = beg; i < end; ++i) { + rot_store->SetBlock(i, CrHash_128(cot_store.GetBlock(i))); + } + }); +} + +// Conversion from cot to rot (OtStoreType == Normal) +[[maybe_unused]] void naive_rot2ot( + const std::shared_ptr& lctx, + const OtSendStore& ot_store, absl::Span msgpairs) { + static_assert(kBatchSize % 128 == 0); // batch size should be multiple of 128 + YACL_ENFORCE(ot_store.Type() == OtStoreType::Normal); + YACL_ENFORCE(ot_store.Size() == msgpairs.size()); + const uint32_t ot_num = msgpairs.size(); + const uint32_t batch_num = (ot_num + kBatchSize - 1) / kBatchSize; + + dynamic_bitset masked_choices(ot_num); + auto buf = lctx->Recv(lctx->NextRank(), ""); + std::memcpy(masked_choices.data(), buf.data(), buf.size()); + + // for each batch + for (uint32_t i = 0; i < batch_num; ++i) { + const uint32_t limit = std::min(kBatchSize, ot_num - i * kBatchSize); + + // generate masks for all msg pairs + std::vector batch_send(limit); + for (uint32_t j = 0; j < limit; ++j) { + auto idx = i * kBatchSize + j; + // fmt::print("{} {}\n", idx, masked_choices.size()); + + if (!masked_choices[idx]) { + batch_send[j][0] = ot_store.GetBlock(idx, 0) ^ msgpairs[idx][0]; + batch_send[j][1] = ot_store.GetBlock(idx, 1) ^ msgpairs[idx][1]; + } else { + batch_send[j][0] = ot_store.GetBlock(idx, 1) ^ msgpairs[idx][0]; + batch_send[j][1] = ot_store.GetBlock(idx, 0) ^ msgpairs[idx][1]; + } + } + + lctx->SendAsync( + lctx->NextRank(), + ByteContainerView(batch_send.data(), sizeof(uint128_t) * limit * 2), + ""); + } +} + +[[maybe_unused]] void naive_rot2ot( + const std::shared_ptr& lctx, + const OtRecvStore& ot_store, const OtChoices& choices, + absl::Span out) { + static_assert(kBatchSize % 128 == 0); // batch size should be multiple of 128 + YACL_ENFORCE(ot_store.Type() == OtStoreType::Normal); + YACL_ENFORCE(ot_store.Size() == choices.size()); + const uint32_t ot_num = ot_store.Size(); + const uint32_t batch_num = (ot_num + kBatchSize - 1) / kBatchSize; + + auto masked_choice = ot_store.CopyBitBuf() ^ choices; + lctx->SendAsync( + lctx->NextRank(), + ByteContainerView(masked_choice.data(), + sizeof(uint128_t) * masked_choice.num_blocks()), + "Sending masked choices"); + + // for each batch + for (uint32_t i = 0; i < batch_num; ++i) { + const uint32_t limit = std::min(kBatchSize, ot_num - i * kBatchSize); + + // receive masked messages + auto buf = lctx->Recv(lctx->NextRank(), ""); + std::vector batch_recv(limit); + std::memcpy(batch_recv.data(), buf.data(), buf.size()); + + for (uint32_t j = 0; j < limit; ++j) { + auto idx = i * kBatchSize + j; + // fmt::print("{} {}\n", idx, choices.size()); + out[idx] = batch_recv[j][choices[idx]] ^ ot_store.GetBlock(idx); + } + } +} + +} // namespace + +void OtKernel::init(const std::shared_ptr& lctx) { + switch (ext_algorithm_) { + case ExtAlgorithm::Ferret: { + auto required_ot_num = + FerretCotHelper(LpnParam::GetDefault(), LpnParam::GetDefault().n); + + // we use softspoken to init ferret ote + OtKernel ss_ote_kernel(role_, ExtAlgorithm::SoftSpoken); + ss_ote_kernel.init(lctx); + if (role_ == Role::Sender) { + init_ot_cache_ = OtSendStore(required_ot_num, OtStoreType::Compact); + ss_ote_kernel.eval_cot_random_choice( + lctx, required_ot_num, &std::get(init_ot_cache_)); + } else { + init_ot_cache_ = OtRecvStore(required_ot_num, OtStoreType::Compact); + ss_ote_kernel.eval_cot_random_choice( + lctx, required_ot_num, &std::get(init_ot_cache_)); + } + break; + } + case ExtAlgorithm::SoftSpoken: + if (role_ == Role::Sender) { + ss_core_ = SoftspokenOtExtSender(2, 0, false, /* compact ot */ true); + std::get(ss_core_).OneTimeSetup(lctx); + } else { + ss_core_ = SoftspokenOtExtReceiver(2, 0, false, /* compact ot */ true); + std::get(ss_core_).OneTimeSetup(lctx); + } + break; + default: + YACL_THROW("Unsupported OT Extension Algorithm"); + } + inited_ = true; +} + +void OtKernel::eval_cot_random_choice( + const std::shared_ptr& lctx, uint64_t ot_num, + OtSendStore* out) { + YACL_ENFORCE(ot_num == out->Size()); // size should match + YACL_ENFORCE(!out->IsSliced()); // no slice + YACL_ENFORCE(inited_); + + // the output ot store should be in compact mode + YACL_ENFORCE(out->Type() == OtStoreType::Compact); + + switch (ext_algorithm_) { + case ExtAlgorithm::Ferret: { + // ferret ot sender needs OtSendStore + YACL_ENFORCE(std::holds_alternative(init_ot_cache_)); + + auto lpn_param = LpnParam::GetDefault(); // use default Lpn parameter + *out = FerretOtExtSend(lctx, std::get(init_ot_cache_), + lpn_param, ot_num); + break; + } + case ExtAlgorithm::SoftSpoken: { + YACL_ENFORCE(std::holds_alternative(ss_core_)); + std::get(ss_core_).Send(lctx, out); + break; + } + default: + YACL_THROW("Unsupported OT Extension Algorithm"); + } +} + +void OtKernel::eval_cot_random_choice( + const std::shared_ptr& lctx, uint64_t ot_num, + OtRecvStore* out) { + YACL_ENFORCE(ot_num == out->Size()); // size should match + YACL_ENFORCE(!out->IsSliced()); // no slice + YACL_ENFORCE(out->Type() == OtStoreType::Compact); // compact mode + YACL_ENFORCE(inited_); + + // the output ot store should be in compact mode + YACL_ENFORCE(out->Type() == OtStoreType::Compact); + + switch (ext_algorithm_) { + case ExtAlgorithm::Ferret: { + // ferret ot sender needs OtRecvStore + YACL_ENFORCE(std::holds_alternative(init_ot_cache_)); + + auto lpn_param = LpnParam::GetDefault(); // use default Lpn parameter + *out = FerretOtExtRecv(lctx, std::get(init_ot_cache_), + lpn_param, ot_num); + break; + } + case ExtAlgorithm::SoftSpoken: { + YACL_ENFORCE(std::holds_alternative(ss_core_)); + auto choices = SecureRandBits(ot_num); + std::get(ss_core_).Recv(lctx, choices, out); + break; + } + default: + YACL_THROW("Unsupported OT Extension Algorithm"); + } +} + +void OtKernel::eval_rot(const std::shared_ptr& lctx, + uint64_t ot_num, OtSendStore* out) { + YACL_ENFORCE(ot_num == out->Size()); // size should match + YACL_ENFORCE(!out->IsSliced()); // no slice + YACL_ENFORCE(out->Type() == OtStoreType::Normal); // normal mode + OtSendStore cot(ot_num, OtStoreType::Compact); + eval_cot_random_choice(lctx, ot_num, &cot); + naive_cot2rot(cot, out); +} + +void OtKernel::eval_rot(const std::shared_ptr& lctx, + uint64_t ot_num, OtRecvStore* out) { + YACL_ENFORCE(ot_num == out->Size()); // size should match + YACL_ENFORCE(!out->IsSliced()); // no slice + YACL_ENFORCE(out->Type() == OtStoreType::Normal); // normal mode + OtRecvStore cot(ot_num, OtStoreType::Compact); + eval_cot_random_choice(lctx, ot_num, &cot); + naive_cot2rot(cot, out); +} + +// void OtKernel::eval_rot(const std::shared_ptr& lctx, +// uint64_t ot_num, OtSendStore* out) { +// eval_cot_random_choice(lctx, ot_num, out); +// } +// void OtKernel::eval_rot(const std::shared_ptr& lctx, +// uint64_t ot_num, +// /* random choice */ OtRecvStore* out) {} + +} // namespace yacl::crypto diff --git a/yacl/kernel/ot_kernel.h b/yacl/kernel/ot_kernel.h new file mode 100644 index 00000000..697b8fb0 --- /dev/null +++ b/yacl/kernel/ot_kernel.h @@ -0,0 +1,129 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "yacl/base/dynamic_bitset.h" +#include "yacl/kernel/algorithms/ot_store.h" +#include "yacl/kernel/kernel.h" + +/* submodules */ +#include "yacl/kernel/algorithms/base_ot.h" +#include "yacl/kernel/algorithms/ferret_ote.h" +#include "yacl/kernel/algorithms/softspoken_ote.h" +#include "yacl/secparam.h" + +namespace yacl::crypto { + +// -------------------------- +// Kernel: Oblivious Transfer +// -------------------------- +// OT kernel is an application-level API. The functionality of oblivious +// transfer could be seen as the following: +// +// [sender] [receiver] +// +// m0 ----> +----------+ <---- b +// | OT | +// m1 ----> +----------+ ----> mb +// +// where, +// - m0 is a uint128_t message +// - m1 is a uint128_t message +// - b is a single bit +// - mb is a uint128_t message, when b=0, mb=m0, when b=1, mb=m1. + +class OtKernel : SingleThreadKernel { + public: + enum class Role { Sender, Receiver }; + + enum class ExtAlgorithm { + Ferret, // default: softspoken + ferrret + SoftSpoken, // faster on LAN + // IKNP, // not recommended + // KOS, // not recommended + }; + + // constructor + explicit OtKernel(Role role, + ExtAlgorithm ext_algorithm = ExtAlgorithm::Ferret) + : role_(role), ext_algorithm_(ext_algorithm) {} + + // the one-time setup (base OT + extra) + void init(const std::shared_ptr& lctx); + + // re-init this kernel, kernel behaves the same as self-destroy, and re-init a + // new kernel, but refresh should be faster + // void refresh(); + + // ---------------------------------- + // Correlated OT, a.k.a. delta-ot + // ---------------------------------- + // In the correlated ot case, the two messages received by the sender is + // delta-correlated, which means m0 xor m1 = delta + // + // Note: all correlated ot instances are stored in *compact mode* + // see: yacl/kernel/algorithms/ot_store.h + void eval_cot_random_choice(const std::shared_ptr& lctx, + uint64_t ot_num, + /* compact mode */ OtSendStore* out); + void eval_cot_random_choice(const std::shared_ptr& lctx, + uint64_t ot_num, + /* compact mode */ OtRecvStore* out); + + // TODO(@shanzhu): Add this feature + // void cot_update_delta(); // update the delta of this ot kernel + + // ------------------------------- + // Random OT + // ------------------------------- + // Random ot with random messages and random chocies, rot will first generate + // *ot_num* cot instances, and then runs crhash (correlation-robust hash + // function) in parallel for all the cot messages. Even though the output + // ot_store of cot is compact, we use normal ot_store for rots. + void eval_rot(const std::shared_ptr& lctx, uint64_t ot_num, + /* normal mode */ OtSendStore* out); + void eval_rot(const std::shared_ptr& lctx, uint64_t ot_num, + /* normal mode */ OtRecvStore* out); + + private: + // ------------------------------- + // Configurations for Kernel + // ------------------------------- + const Role role_; // receiver or sender + + // OT Extension algorithm + const ExtAlgorithm ext_algorithm_ = ExtAlgorithm::Ferret; + + // -------------------------------// + // Kernel Internal States + // -------------------------------// + + // whether this kernel has been inited (e.g. one-time-setup) + bool inited_ = false; + + // the underlying store type for internal cache + using StoreTy = std::variant; + StoreTy init_ot_cache_; // ot cache from the init phase + + // the underlying core type for softspoken + using SoftSpokenCoreTy = std::variant; + SoftSpokenCoreTy ss_core_; +}; + +} // namespace yacl::crypto diff --git a/yacl/kernel/ot_kernel_bench.cc b/yacl/kernel/ot_kernel_bench.cc new file mode 100644 index 00000000..383f9680 --- /dev/null +++ b/yacl/kernel/ot_kernel_bench.cc @@ -0,0 +1,91 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "benchmark/benchmark.h" + +#include "yacl/kernel/ot_kernel.h" +#include "yacl/link/test_util.h" + +namespace yacl::crypto { + +static void BM_Ferret_OT_single_thread(benchmark::State& state) { + auto lctxs = link::test::SetupWorld(2); + for (auto _ : state) { + state.PauseTiming(); + { + const size_t num_ot = 1 << 24; + OtSendStore ot_send(num_ot, OtStoreType::Compact); // placeholder + OtRecvStore ot_recv(num_ot, OtStoreType::Compact); // placeholder + OtKernel kernel0(OtKernel::Role::Sender, OtKernel::ExtAlgorithm::Ferret); + OtKernel kernel1(OtKernel::Role::Receiver, + OtKernel::ExtAlgorithm::Ferret); + + // WHEN + state.ResumeTiming(); + auto sender = std::async([&] { + kernel0.init(lctxs[0]); + kernel0.eval_cot_random_choice(lctxs[0], num_ot, &ot_send); + }); + auto receiver = std::async([&] { + kernel1.init(lctxs[1]); + kernel1.eval_cot_random_choice(lctxs[1], num_ot, &ot_recv); + }); + sender.get(); + receiver.get(); + state.PauseTiming(); + } + state.ResumeTiming(); + } +} + +static void BM_SoftSpoken_OT_single_thread(benchmark::State& state) { + auto lctxs = link::test::SetupWorld(2); + for (auto _ : state) { + state.PauseTiming(); + { + const size_t num_ot = 1 << 24; + OtSendStore ot_send(num_ot, OtStoreType::Compact); // placeholder + OtRecvStore ot_recv(num_ot, OtStoreType::Compact); // placeholder + OtKernel kernel0(OtKernel::Role::Sender, + OtKernel::ExtAlgorithm::SoftSpoken); + OtKernel kernel1(OtKernel::Role::Receiver, + OtKernel::ExtAlgorithm::SoftSpoken); + + // WHEN + state.ResumeTiming(); + auto sender = std::async([&] { + kernel0.init(lctxs[0]); + kernel0.eval_cot_random_choice(lctxs[0], num_ot, &ot_send); + }); + auto receiver = std::async([&] { + kernel1.init(lctxs[1]); + kernel1.eval_cot_random_choice(lctxs[1], num_ot, &ot_recv); + }); + sender.get(); + receiver.get(); + state.PauseTiming(); + } + state.ResumeTiming(); + } +} +} // namespace yacl::crypto + +BENCHMARK(yacl::crypto::BM_Ferret_OT_single_thread) + ->Iterations(1) + ->Unit(benchmark::kMillisecond); + +BENCHMARK(yacl::crypto::BM_SoftSpoken_OT_single_thread) + ->Iterations(1) + ->Unit(benchmark::kMillisecond); +BENCHMARK_MAIN(); diff --git a/yacl/kernel/ot_kernel_test.cc b/yacl/kernel/ot_kernel_test.cc new file mode 100644 index 00000000..5086804a --- /dev/null +++ b/yacl/kernel/ot_kernel_test.cc @@ -0,0 +1,112 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "yacl/kernel/ot_kernel.h" + +#include + +#include +#include + +#include "gtest/gtest.h" + +#include "yacl/kernel/algorithms/ot_store.h" +#include "yacl/link/test_util.h" + +namespace yacl::crypto { + +struct TestParams { + size_t num_ot; + OtKernel::ExtAlgorithm ext_algorithm; +}; + +class OtTest : public ::testing::TestWithParam {}; + +TEST_P(OtTest, EvalCotRandomChoice) { + auto lctxs = link::test::SetupWorld(2); + + const size_t num_ot = GetParam().num_ot; + const auto ext_algorithm = GetParam().ext_algorithm; + + OtSendStore ot_send(num_ot, OtStoreType::Compact); // placeholder + OtRecvStore ot_recv(num_ot, OtStoreType::Compact); // placeholder + + OtKernel kernel0(OtKernel::Role::Sender, ext_algorithm); + OtKernel kernel1(OtKernel::Role::Receiver, ext_algorithm); + + // WHEN + auto sender = std::async([&] { + kernel0.init(lctxs[0]); + kernel0.eval_cot_random_choice(lctxs[0], num_ot, &ot_send); + }); + auto receiver = std::async([&] { + kernel1.init(lctxs[1]); + kernel1.eval_cot_random_choice(lctxs[1], num_ot, &ot_recv); + }); + sender.get(); + receiver.get(); + + EXPECT_EQ(ot_send.Type(), OtStoreType::Compact); + EXPECT_EQ(ot_recv.Type(), OtStoreType::Compact); + + for (uint64_t i = 0; i < num_ot; ++i) { + // correctness of ot + EXPECT_EQ(ot_send.GetBlock(i, ot_recv.GetChoice(i)), ot_recv.GetBlock(i)); + + // generated ot messages should not equal + EXPECT_NE(ot_send.GetBlock(i, 1 - ot_recv.GetChoice(i)), + ot_recv.GetBlock(i)); + + // generated choice should be random + // ... + } +} + +TEST_P(OtTest, EvalRot) { + auto lctxs = link::test::SetupWorld(2); + + const size_t num_ot = GetParam().num_ot; + const auto ext_algorithm = GetParam().ext_algorithm; + + OtSendStore ot_send(num_ot, OtStoreType::Normal); // placeholder + OtRecvStore ot_recv(num_ot, OtStoreType::Normal); // placeholder + + OtKernel kernel0(OtKernel::Role::Sender, ext_algorithm); + OtKernel kernel1(OtKernel::Role::Receiver, ext_algorithm); + + // WHEN + auto sender = std::async([&] { + kernel0.init(lctxs[0]); + kernel0.eval_rot(lctxs[0], num_ot, &ot_send); + }); + auto receiver = std::async([&] { + kernel1.init(lctxs[1]); + kernel1.eval_rot(lctxs[1], num_ot, &ot_recv); + }); + sender.get(); + receiver.get(); + + for (uint64_t i = 0; i < num_ot; ++i) { + EXPECT_EQ(ot_send.GetBlock(i, ot_recv.GetChoice(i)), ot_recv.GetBlock(i)); + EXPECT_NE(ot_send.GetBlock(i, 1 - ot_recv.GetChoice(i)), + ot_recv.GetBlock(i)); + } +} + +INSTANTIATE_TEST_SUITE_P( + Works_Instances, OtTest, + testing::Values(TestParams{1 << 20, OtKernel::ExtAlgorithm::Ferret}, + TestParams{1 << 20, OtKernel::ExtAlgorithm::SoftSpoken} // + )); +} // namespace yacl::crypto diff --git a/yacl/kernel/svole_kernel.h b/yacl/kernel/svole_kernel.h index b6581073..db130ecb 100644 --- a/yacl/kernel/svole_kernel.h +++ b/yacl/kernel/svole_kernel.h @@ -32,7 +32,7 @@ namespace yacl::crypto { // - where a is in GF(2^64), delta, b, c are in GF(2^128) // Sender receives: c, delta // Receiver receives: a, b -class SVoleKernel : StreamKernel { +class SVoleKernel : StreamingKernel { public: enum class Role { Sender, Receiver }; diff --git a/yacl/link/transport/channel_test.cc b/yacl/link/transport/channel_test.cc index 0d7b9450..6d6300a2 100644 --- a/yacl/link/transport/channel_test.cc +++ b/yacl/link/transport/channel_test.cc @@ -14,6 +14,9 @@ #include "yacl/link/transport/channel.h" +#include +#include + #include "fmt/format.h" #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -21,30 +24,27 @@ #include "brpc/errno.pb.h" #include "interconnection/link/transport.pb.h" +namespace brpc { + +DECLARE_bool(usercode_in_pthread); + +} // namespace brpc + namespace yacl::link::transport::test { -namespace ic_pb = org::interconnection::link; +struct Initial { + Initial() { brpc::FLAGS_usercode_in_pthread = true; } +}; -static std::string RandStr(size_t length) { - auto randchar = []() -> char { - const char charset[] = - "0123456789" - "ABCDEFGHIJKLMNOPQRSTUVWXYZ" - "abcdefghijklmnopqrstuvwxyz"; - const size_t max_index = (sizeof(charset) - 1); - return charset[rand() % max_index]; - }; - std::string str(length, 0); - std::generate_n(str.begin(), length, randchar); - return str; -} +static Initial g_initial{}; + +namespace ic_pb = org::interconnection::link; class MockTransportLink : public TransportLink { public: using TransportLink::TransportLink; MOCK_METHOD(void, SetMaxBytesPerChunk, (size_t), (override)); - MOCK_METHOD(std::unique_ptr, PackMonoRequest, - (const std::string&, ByteContainerView), (const, override)); + MOCK_METHOD(void, UnpackMonoRequest, (const Request&, std::string*, ByteContainerView*), (const, override)); @@ -65,6 +65,15 @@ class MockTransportLink : public TransportLink { static size_t max_bytes_per_chunk = 2; return max_bytes_per_chunk; } + std::unique_ptr<::google::protobuf::Message> PackMonoRequest( + const std::string& key, ByteContainerView value) const override { + auto request = std::make_unique(); + request->set_key(key); + request->set_value(value.data(), value.size()); + request->set_trans_type(ic_pb::TransType::MONO); + + return request; + } std::unique_ptr<::google::protobuf::Message> PackChunkedRequest( const std::string& key, ByteContainerView value, size_t offset, size_t total_length) const override { @@ -84,6 +93,10 @@ class MockTransportLink : public TransportLink { class ChannelSendRetryTest : public testing::Test { protected: void SetUp() override { + brpc::FLAGS_usercode_in_pthread = true; + SPDLOG_INFO("brpc::usercode_in_pthread: {}", + brpc::FLAGS_usercode_in_pthread); + const size_t send_rank = 0; const size_t recv_rank = 1; sender_delegate_ = @@ -99,6 +112,7 @@ class ChannelSendRetryTest : public testing::Test { brpc::HTTP_STATUS_MULTIPLE_CHOICES}; sender_ = std::make_shared(sender_delegate_, false, retry_options); + SPDLOG_INFO("test_start"); } void TearDown() override { @@ -111,109 +125,68 @@ class ChannelSendRetryTest : public testing::Test { }; TEST_F(ChannelSendRetryTest, NoRetrySuccess) { - EXPECT_CALL(*sender_delegate_, SendRequest(::testing::_, ::testing::_)) + EXPECT_CALL(*sender_delegate_, SendRequest) .Times(1) - .WillRepeatedly([]() {}); + .WillRepeatedly([](const TransportLink::Request&, uint32_t) {}); const std::string key = "key"; - sender_->Send(key, {}); -} - -TEST_F(ChannelSendRetryTest, NoRetryFail) { - EXPECT_CALL(*sender_delegate_, SendRequest(::testing::_, ::testing::_)) - .Times(1) - .WillRepeatedly([]() { - throw yacl::LinkError("not valid error code.", brpc::EUNUSED); - }); - const std::string key = "key"; - sender_->Send(key, {}); + auto request = sender_delegate_->PackMonoRequest(key, "t"); + sender_->SendRequestWithRetry(*request, 0); } TEST_F(ChannelSendRetryTest, RetrySuccess) { - EXPECT_CALL(*sender_delegate_, SendRequest(::testing::_, ::testing::_)) + EXPECT_CALL(*sender_delegate_, SendRequest) .Times(2) - .WillOnce([]() { - throw yacl::LinkError("valid error code, will retry once and success.", - brpc::ENOSERVICE); - }) - .WillRepeatedly([]() {}); + .WillOnce(testing::Throw(yacl::LinkError( + "valid error code, will retry once and success.", brpc::ENOSERVICE))) + .WillRepeatedly([](const TransportLink::Request&, uint32_t) {}); const std::string key = "key"; - sender_->Send(key, {}); + auto request = sender_delegate_->PackMonoRequest(key, "t"); + sender_->SendRequestWithRetry(*request, 0); } TEST_F(ChannelSendRetryTest, RetryFail) { - EXPECT_CALL(*sender_delegate_, SendRequest(::testing::_, ::testing::_)) + EXPECT_CALL(*sender_delegate_, SendRequest) .Times(4) - .WillRepeatedly([]() { - throw yacl::LinkError( - "valid error code, will retry max count and fail.", - brpc::ENOSERVICE); - }); + .WillRepeatedly(testing::Throw( + yacl::LinkError("valid error code, will retry max count and fail.", + brpc::ENOSERVICE))); const std::string key = "key"; - sender_->Send(key, {}); + auto request = sender_delegate_->PackMonoRequest(key, "t"); + EXPECT_THROW(sender_->SendRequestWithRetry(*request, 0), yacl::LinkError); } TEST_F(ChannelSendRetryTest, HttpNoRetryFail) { - EXPECT_CALL(*sender_delegate_, SendRequest(::testing::_, ::testing::_)) + EXPECT_CALL(*sender_delegate_, SendRequest) .Times(1) - .WillRepeatedly([]() { - throw yacl::LinkError("not valid http code and no retry.", brpc::EHTTP, - brpc::HTTP_STATUS_GATEWAY_TIMEOUT); - }); + .WillRepeatedly(testing::Throw( + yacl::LinkError("not valid http code and no retry.", brpc::EHTTP, + brpc::HTTP_STATUS_GATEWAY_TIMEOUT))); const std::string key = "key"; - sender_->Send(key, {}); + auto request = sender_delegate_->PackMonoRequest(key, "t"); + EXPECT_THROW(sender_->SendRequestWithRetry(*request, 0), yacl::LinkError); } TEST_F(ChannelSendRetryTest, HttpRetrySuccess) { - EXPECT_CALL(*sender_delegate_, SendRequest(::testing::_, ::testing::_)) + EXPECT_CALL(*sender_delegate_, SendRequest) .Times(2) - .WillOnce([]() { - throw yacl::LinkError("valid http code, will retry once and success.", - brpc::EHTTP, brpc::HTTP_STATUS_BAD_GATEWAY); - }) - .WillRepeatedly([]() {}); + .WillOnce(testing::Throw( + yacl::LinkError("valid http code, will retry once and success.", + brpc::EHTTP, brpc::HTTP_STATUS_BAD_GATEWAY))) + .WillRepeatedly([](const TransportLink::Request&, uint32_t) {}); const std::string key = "key"; - sender_->Send(key, {}); + auto request = sender_delegate_->PackMonoRequest(key, "t"); + sender_->SendRequestWithRetry(*request, 0); } TEST_F(ChannelSendRetryTest, HttpRetryFail) { - EXPECT_CALL(*sender_delegate_, SendRequest(::testing::_, ::testing::_)) + EXPECT_CALL(*sender_delegate_, SendRequest) .Times(4) - .WillRepeatedly([]() { - throw yacl::LinkError("valid http code, will retry max count and fail.", - brpc::EHTTP, brpc::HTTP_STATUS_BAD_GATEWAY); - }); - const std::string key = "key"; - sender_->Send(key, {}); -} - -TEST_F(ChannelSendRetryTest, ChunkedNoRetrySuccess) { - EXPECT_CALL(*sender_delegate_, SendRequest(::testing::_, ::testing::_)) - .Times(2) - .WillRepeatedly([]() {}); - const std::string key = "key"; - sender_->Send(key, RandStr(3)); -} - -TEST_F(ChannelSendRetryTest, ChunkedNoRetryFail) { - EXPECT_CALL(*sender_delegate_, SendRequest(::testing::_, ::testing::_)) - .Times(2) // chunk1 fail | chunk2 fail - .WillRepeatedly([]() { - throw yacl::LinkError("not valid error code.", brpc::EUNUSED); - }); - const std::string key = "key"; - sender_->Send(key, RandStr(3)); -} - -TEST_F(ChannelSendRetryTest, ChunkedRetrySuccess) { - EXPECT_CALL(*sender_delegate_, SendRequest(::testing::_, ::testing::_)) - .Times(3) // chunk1 fail | chunk1 retry | chunk2 success - .WillOnce([]() { - throw yacl::LinkError("valid error code, will retry once and success.", - brpc::ENOSERVICE); - }) - .WillRepeatedly([]() {}); + .WillRepeatedly(testing::Throw( + yacl::LinkError("valid http code, will retry max count and fail.", + brpc::EHTTP, brpc::HTTP_STATUS_BAD_GATEWAY))); const std::string key = "key"; - sender_->Send(key, RandStr(3)); + auto request = sender_delegate_->PackMonoRequest(key, "t"); + EXPECT_THROW(sender_->SendRequestWithRetry(*request, 0), yacl::LinkError); } } // namespace yacl::link::transport::test diff --git a/yacl/math/galois_field/factory/mcl_factory.cc b/yacl/math/galois_field/factory/mcl_factory.cc index c90eea93..2cca5b9b 100644 --- a/yacl/math/galois_field/factory/mcl_factory.cc +++ b/yacl/math/galois_field/factory/mcl_factory.cc @@ -257,7 +257,7 @@ template T MclField::RandomT() const { const auto per_size = BASE_FP_SIZE; - T ret; + T ret{0}; Buffer buf(per_size * degree); typename T::BaseFp p; for (uint64_t i = 0; i < degree; i++) { @@ -292,7 +292,7 @@ size_t MclField::Serialize(const T& x, uint8_t* buf, template T MclField::DeserializeT(ByteContainerView buffer) const { - T ret; + T ret{0}; ret.deserialize(buffer.data(), buffer.size()); return ret; }