Skip to content

Commit

Permalink
Repo Sync (secretflow#326)
Browse files Browse the repository at this point in the history
  • Loading branch information
anakinxc authored Aug 23, 2023
1 parent 294b897 commit a62c779
Show file tree
Hide file tree
Showing 115 changed files with 5,805 additions and 3,283 deletions.
22 changes: 2 additions & 20 deletions .bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -59,24 +59,6 @@ build:macos --host_macos_minimum_os=11.0
build:linux --copt=-fopenmp
build:linux --linkopt=-fopenmp

build:asan --strip=never
build:asan --copt -fno-sanitize-recover=all
build:asan --copt -fsanitize=address
build:asan --copt -Og
build:asan --copt -g
build:asan --copt -fno-omit-frame-pointer
build:asan --linkopt -fsanitize=address
build:asan --linkopt -static-libasan

build:ubsan --strip=never
build:ubsan --copt -fno-sanitize-recover=all
build:ubsan --copt -fsanitize=undefined
build:ubsan --copt -Og
build:ubsan --copt -g
build:ubsan --copt -fno-omit-frame-pointer
build:ubsan --linkopt -fsanitize=undefined
build:ubsan --linkopt -static-libubsan

build:macos-asan --features=asan
build:macos-ubsan --features=ubsan
build:asan --features=asan
build:ubsan --features=ubsan

1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
- [API] Add intrinsic support
- [Feature] Support half type
- [Feature] Add Psi Progress
- [Feature] Add SineOp/CosineOp support

## 20230705

Expand Down
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ docker exec -it spu-dev-$(whoami) bash
#### Linux

```sh
Install gcc>=11.2, cmake>=3.18, ninja, nasm>=2.15, python==3.8, bazel==6.2.1, golang
Install gcc>=11.2, cmake>=3.18, ninja, nasm>=2.15, python>=3.8, bazel==6.2.1, golang

python3 -m pip install -r requirements.txt
python3 -m pip install -r requirements-dev.txt
Expand Down
12 changes: 12 additions & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ rules_foreign_cc_dependencies(
register_preinstalled_tools = True,
)

load("@rules_cuda//cuda:repositories.bzl", "register_detected_cuda_toolchains", "rules_cuda_dependencies")

rules_cuda_dependencies()

register_detected_cuda_toolchains()

load("@xla//:workspace4.bzl", "xla_workspace4")

xla_workspace4()
Expand All @@ -64,3 +70,9 @@ python_configure(
name = "local_config_python",
python_version = "3",
)

load("@rules_proto_grpc//:repositories.bzl", "rules_proto_grpc_repos", "rules_proto_grpc_toolchains")

rules_proto_grpc_toolchains()

rules_proto_grpc_repos()
103 changes: 102 additions & 1 deletion bazel/patches/seal.patch
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ index 31e07441..c34d0a45 100644
// Write the poly_modulus_degree. Note that it will always be positive.
*param_data_ptr++ = static_cast<uint64_t>(poly_modulus_degree_);

+ *param_data_ptr++ = static_cast<uint64_t>(use_special_prime_);
+ *param_data_ptr++ = static_cast<uint64_t>(use_special_prime_);
for (const auto &mod : coeff_modulus_)
{
*param_data_ptr++ = mod.value();
Expand Down Expand Up @@ -118,3 +118,104 @@ index 9e1fbe48..eb71c4ac 100644
parms_id_type parms_id_ = parms_id_zero;
};
} // namespace seal

diff --git a/native/src/seal/evaluator.cpp b/native/src/seal/evaluator.cpp
index dabd3bab..afaa71dc 100644
--- a/native/src/seal/evaluator.cpp
+++ b/native/src/seal/evaluator.cpp
@@ -2382,6 +2382,7 @@ namespace seal
size_t encrypted_size = encrypted.size();
// Use key_context_data where permutation tables exist since previous runs.
auto galois_tool = context_.key_context_data()->galois_tool();
+ bool is_ntt_form = encrypted.is_ntt_form();

// Size check
if (!product_fits_in(coeff_count, coeff_modulus_size))
@@ -2412,7 +2413,7 @@ namespace seal
// DO NOT CHANGE EXECUTION ORDER OF FOLLOWING SECTION
// BEGIN: Apply Galois for each ciphertext
// Execution order is sensitive, since apply_galois is not inplace!
- if (parms.scheme() == scheme_type::bfv)
+ if (not is_ntt_form)
{
// !!! DO NOT CHANGE EXECUTION ORDER!!!

@@ -2426,7 +2427,7 @@ namespace seal
// Next transform encrypted.data(1)
galois_tool->apply_galois(encrypted_iter[1], coeff_modulus_size, galois_elt, coeff_modulus, temp);
}
- else if (parms.scheme() == scheme_type::ckks || parms.scheme() == scheme_type::bgv)
+ else
{
// !!! DO NOT CHANGE EXECUTION ORDER!!!

@@ -2440,10 +2441,6 @@ namespace seal
// Next transform encrypted.data(1)
galois_tool->apply_galois_ntt(encrypted_iter[1], coeff_modulus_size, galois_elt, temp);
}
- else
- {
- throw logic_error("scheme not implemented");
- }

// Wipe encrypted.data(1)
set_zero_poly(coeff_count, coeff_modulus_size, encrypted.data(1));
@@ -2530,6 +2527,7 @@ namespace seal
auto &key_context_data = *context_.key_context_data();
auto &key_parms = key_context_data.parms();
auto scheme = parms.scheme();
+ bool is_ntt_form = encrypted.is_ntt_form();

// Verify parameters.
if (!is_metadata_valid_for(encrypted, context_) || !is_buffer_valid(encrypted))
@@ -2559,14 +2557,6 @@ namespace seal
{
throw invalid_argument("pool is uninitialized");
}
- if (scheme == scheme_type::bfv && encrypted.is_ntt_form())
- {
- throw invalid_argument("BFV encrypted cannot be in NTT form");
- }
- if (scheme == scheme_type::ckks && !encrypted.is_ntt_form())
- {
- throw invalid_argument("CKKS encrypted must be in NTT form");
- }
if (scheme == scheme_type::bgv && !encrypted.is_ntt_form())
{
throw invalid_argument("BGV encrypted must be in NTT form");
@@ -2605,7 +2595,7 @@ namespace seal
set_uint(target_iter, decomp_modulus_size * coeff_count, t_target);

// In CKKS or BGV, t_target is in NTT form; switch back to normal form
- if (scheme == scheme_type::ckks || scheme == scheme_type::bgv)
+ if (is_ntt_form)
{
inverse_ntt_negacyclic_harvey(t_target, decomp_modulus_size, key_ntt_tables);
}
@@ -2632,7 +2622,7 @@ namespace seal
ConstCoeffIter t_operand;

// RNS-NTT form exists in input
- if ((scheme == scheme_type::ckks || scheme == scheme_type::bgv) && (I == J))
+ if (is_ntt_form && (I == J))
{
t_operand = target_iter[J];
}
@@ -2789,7 +2779,7 @@ namespace seal
SEAL_ITERATE(t_ntt, coeff_count, [fix](auto &K) { K += fix; });

uint64_t qi_lazy = qi << 1; // some multiples of qi
- if (scheme == scheme_type::ckks)
+ if (is_ntt_form)
{
// This ntt_negacyclic_harvey_lazy results in [0, 4*qi).
ntt_negacyclic_harvey_lazy(t_ntt, get<2>(J));
@@ -2802,7 +2792,7 @@ namespace seal
qi_lazy = qi << 2;
#endif
}
- else if (scheme == scheme_type::bfv)
+ else
{
inverse_ntt_negacyclic_harvey_lazy(get<0, 1>(J), get<2>(J));
}
22 changes: 22 additions & 0 deletions bazel/patches/xla.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
diff --git a/third_party/tsl/workspace1.bzl b/third_party/tsl/workspace1.bzl
index 4cfb6da82..0e3774834 100644
--- a/third_party/tsl/workspace1.bzl
+++ b/third_party/tsl/workspace1.bzl
@@ -3,7 +3,7 @@
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")
load("@io_bazel_rules_closure//closure:defs.bzl", "closure_repositories")
-load("@rules_cuda//cuda:dependencies.bzl", "rules_cuda_dependencies")
+# load("@rules_cuda//cuda:dependencies.bzl", "rules_cuda_dependencies")
load("@rules_pkg//:deps.bzl", "rules_pkg_dependencies")

# buildifier: disable=unnamed-macro
@@ -14,7 +14,7 @@ def workspace(with_rules_cc = True):
with_rules_cc: whether to load and patch rules_cc repository.
"""
native.register_toolchains("@local_config_python//:py_toolchain")
- rules_cuda_dependencies(with_rules_cc)
+ # rules_cuda_dependencies(with_rules_cc)
rules_pkg_dependencies()

closure_repositories()
28 changes: 26 additions & 2 deletions bazel/repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ SECRETFLOW_GIT = "https://github.com/secretflow"
YACL_COMMIT_ID = "a9c1d7d119c80eb75d5ec63ee6cd77145dff18c2"

def spu_deps():
_rules_cuda()
_rules_proto_grpc()
_bazel_platform()
_upb()
_com_github_xtensor_xtensor()
Expand Down Expand Up @@ -61,6 +63,24 @@ def spu_deps():
path = "/opt/homebrew/opt/libomp/",
)

def _rules_proto_grpc():
http_archive(
name = "rules_proto_grpc",
sha256 = "928e4205f701b7798ce32f3d2171c1918b363e9a600390a25c876f075f1efc0a",
strip_prefix = "rules_proto_grpc-4.4.0",
urls = [
"https://github.com/rules-proto-grpc/rules_proto_grpc/releases/download/4.4.0/rules_proto_grpc-4.4.0.tar.gz",
],
)

def _rules_cuda():
http_archive(
name = "rules_cuda",
sha256 = "fa1462c4c3104de44489800a1da055f55afa57795789539c835e069818786f71",
strip_prefix = "rules_cuda-cab1fa2dd0e1f8489f566c91a5025856cf5ae572",
urls = ["https://github.com/bazel-contrib/rules_cuda/archive/cab1fa2dd0e1f8489f566c91a5025856cf5ae572.tar.gz"],
)

def _bazel_platform():
http_archive(
name = "platforms",
Expand Down Expand Up @@ -138,8 +158,8 @@ def _com_github_xtensor_xtl():
)

def _com_github_openxla_xla():
OPENXLA_COMMIT = "f23282956725050dd07996d3b80d6788ad8aaeba"
OPENXLA_SHA256 = "6e87093f550573de12af247edfa590d76adc6e931da6b551e78dc2c0c2bbd04d"
OPENXLA_COMMIT = "0c99beffabc5d43fa29f121674eb59e14a22c779"
OPENXLA_SHA256 = "d4c7511a496aeb917976c0d8a65de374c395546f0c3d4077d9dfd4df780d7ea8"

SKYLIB_VERSION = "1.3.0"

Expand All @@ -160,6 +180,10 @@ def _com_github_openxla_xla():
sha256 = OPENXLA_SHA256,
strip_prefix = "xla-" + OPENXLA_COMMIT,
type = ".tar.gz",
patch_args = ["-p1"],
patches = [
"@spulib//bazel:patches/xla.patch",
],
urls = [
"https://github.com/openxla/xla/archive/{commit}.tar.gz".format(commit = OPENXLA_COMMIT),
],
Expand Down
3 changes: 3 additions & 0 deletions bazel/spu.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def spu_cc_binary(
cc_binary(
linkopts = linkopts,
copts = copts + _spu_copts(),
linkstatic = True,
**kargs
)

Expand All @@ -66,6 +67,7 @@ def spu_cc_library(
local_defines = local_defines + [
"SPU_BUILD",
],
linkstatic = True,
**kargs
)

Expand Down Expand Up @@ -111,5 +113,6 @@ def spu_cc_test(
local_defines = local_defines + [
"SPU_BUILD",
],
linkstatic = True,
**kwargs
)
6 changes: 3 additions & 3 deletions docs/reference/xla_status.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ Count: Total = 4, fully supported = 2
| `ceil` | fully |
| `convert` | fully |
| `count_leading_zeros`| no |
| `cosine` | no |
| `cosine` | fully |
| `exponential` | fully |
| `exponential_minus_one`| fully |
| `floor` | fully |
Expand All @@ -46,11 +46,11 @@ Count: Total = 4, fully supported = 2
| `round_nearest_afz`| not |
| `rsqrt` | fully |
| `sign` | partial |
| `sine` | not |
| `sine` | fully |
| `sqrt` | fully |
| `tanh` | fully |

Count: Total = 24, fully supported = 12, partial = 0
Count: Total = 24, fully supported = 16, partial = 1

### XLA binary element-wise ops

Expand Down
18 changes: 8 additions & 10 deletions examples/python/ml/jax_lr/jax_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,11 @@ def save_and_load_model():
print(W_r, b_r)

x_test, y_test = dsutil.breast_cancer(slice(None, None, None), False)
print(
"AUC(save_and_load_model)={}".format(
metrics.roc_auc_score(y_test, predict(x_test, W_r, b_r))
)
)

score = metrics.roc_auc_score(y_test, predict(x_test, W_r, b_r))
print("AUC(save_and_load_model)={}".format(score))

return score


def compute_score(W_r, b_r, type):
Expand All @@ -163,10 +163,7 @@ def train(x1, x2, y):
x2, _ = ppd.device("P2")(dsutil.breast_cancer)(slice(15, None), True)
W, b = train(x1, x2, y)

W_r, b_r = ppd.get(W), ppd.get(b)
print(W_r, b_r)

return W_r, b_r
return W, b


if __name__ == "__main__":
Expand All @@ -176,5 +173,6 @@ def train(x1, x2, y):
compute_score(w[1], b[1], 'cpu, manual_grad')
print('Run on SPU\n------\n')
w, b = run_on_spu()
compute_score(w, b, 'spu')
w_r, b_r = ppd.get(w), ppd.get(b)
compute_score(w_r, b_r, 'spu')
save_and_load_model()
10 changes: 9 additions & 1 deletion examples/python/ml/ml_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def test_jax_lr(self):
from examples.python.ml.jax_lr import jax_lr

w, b = profile_test_point(jax_lr.run_on_spu)
score = jax_lr.compute_score(w, b, 'spu')
score = jax_lr.compute_score(ppd.get(w), ppd.get(b), 'spu')

self.assertGreater(score, 0.95)

Expand Down Expand Up @@ -222,6 +222,13 @@ def test_torch_experiment(self):
score = torch_experiment.run_inference_on_spu(model)
self.assertGreater(score, 0.9)

def test_save_and_load_model(self):
from examples.python.ml.jax_lr import jax_lr

score = jax_lr.save_and_load_model()
self.assertGreater(score, 0.9)
pass


def suite():
suite = unittest.TestSuite()
Expand All @@ -239,6 +246,7 @@ def suite():
suite.addTest(UnitTests('test_stax_nn'))
suite.addTest(UnitTests('test_tf_experiment'))
suite.addTest(UnitTests('test_torch_experiment'))
suite.addTest(UnitTests('test_save_and_load_model'))
return suite


Expand Down
Loading

0 comments on commit a62c779

Please sign in to comment.