Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Repo Sync #326

Merged
merged 3 commits into from
Aug 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 2 additions & 20 deletions .bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -59,24 +59,6 @@ build:macos --host_macos_minimum_os=11.0
build:linux --copt=-fopenmp
build:linux --linkopt=-fopenmp

build:asan --strip=never
build:asan --copt -fno-sanitize-recover=all
build:asan --copt -fsanitize=address
build:asan --copt -Og
build:asan --copt -g
build:asan --copt -fno-omit-frame-pointer
build:asan --linkopt -fsanitize=address
build:asan --linkopt -static-libasan

build:ubsan --strip=never
build:ubsan --copt -fno-sanitize-recover=all
build:ubsan --copt -fsanitize=undefined
build:ubsan --copt -Og
build:ubsan --copt -g
build:ubsan --copt -fno-omit-frame-pointer
build:ubsan --linkopt -fsanitize=undefined
build:ubsan --linkopt -static-libubsan

build:macos-asan --features=asan
build:macos-ubsan --features=ubsan
build:asan --features=asan
build:ubsan --features=ubsan

1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
- [API] Add intrinsic support
- [Feature] Support half type
- [Feature] Add Psi Progress
- [Feature] Add SineOp/CosineOp support

## 20230705

Expand Down
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ docker exec -it spu-dev-$(whoami) bash
#### Linux

```sh
Install gcc>=11.2, cmake>=3.18, ninja, nasm>=2.15, python==3.8, bazel==6.2.1, golang
Install gcc>=11.2, cmake>=3.18, ninja, nasm>=2.15, python>=3.8, bazel==6.2.1, golang

python3 -m pip install -r requirements.txt
python3 -m pip install -r requirements-dev.txt
Expand Down
12 changes: 12 additions & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ rules_foreign_cc_dependencies(
register_preinstalled_tools = True,
)

load("@rules_cuda//cuda:repositories.bzl", "register_detected_cuda_toolchains", "rules_cuda_dependencies")

rules_cuda_dependencies()

register_detected_cuda_toolchains()

load("@xla//:workspace4.bzl", "xla_workspace4")

xla_workspace4()
Expand All @@ -64,3 +70,9 @@ python_configure(
name = "local_config_python",
python_version = "3",
)

load("@rules_proto_grpc//:repositories.bzl", "rules_proto_grpc_repos", "rules_proto_grpc_toolchains")

rules_proto_grpc_toolchains()

rules_proto_grpc_repos()
103 changes: 102 additions & 1 deletion bazel/patches/seal.patch
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ index 31e07441..c34d0a45 100644
// Write the poly_modulus_degree. Note that it will always be positive.
*param_data_ptr++ = static_cast<uint64_t>(poly_modulus_degree_);

+ *param_data_ptr++ = static_cast<uint64_t>(use_special_prime_);
+ *param_data_ptr++ = static_cast<uint64_t>(use_special_prime_);
for (const auto &mod : coeff_modulus_)
{
*param_data_ptr++ = mod.value();
Expand Down Expand Up @@ -118,3 +118,104 @@ index 9e1fbe48..eb71c4ac 100644
parms_id_type parms_id_ = parms_id_zero;
};
} // namespace seal

diff --git a/native/src/seal/evaluator.cpp b/native/src/seal/evaluator.cpp
index dabd3bab..afaa71dc 100644
--- a/native/src/seal/evaluator.cpp
+++ b/native/src/seal/evaluator.cpp
@@ -2382,6 +2382,7 @@ namespace seal
size_t encrypted_size = encrypted.size();
// Use key_context_data where permutation tables exist since previous runs.
auto galois_tool = context_.key_context_data()->galois_tool();
+ bool is_ntt_form = encrypted.is_ntt_form();

// Size check
if (!product_fits_in(coeff_count, coeff_modulus_size))
@@ -2412,7 +2413,7 @@ namespace seal
// DO NOT CHANGE EXECUTION ORDER OF FOLLOWING SECTION
// BEGIN: Apply Galois for each ciphertext
// Execution order is sensitive, since apply_galois is not inplace!
- if (parms.scheme() == scheme_type::bfv)
+ if (not is_ntt_form)
{
// !!! DO NOT CHANGE EXECUTION ORDER!!!

@@ -2426,7 +2427,7 @@ namespace seal
// Next transform encrypted.data(1)
galois_tool->apply_galois(encrypted_iter[1], coeff_modulus_size, galois_elt, coeff_modulus, temp);
}
- else if (parms.scheme() == scheme_type::ckks || parms.scheme() == scheme_type::bgv)
+ else
{
// !!! DO NOT CHANGE EXECUTION ORDER!!!

@@ -2440,10 +2441,6 @@ namespace seal
// Next transform encrypted.data(1)
galois_tool->apply_galois_ntt(encrypted_iter[1], coeff_modulus_size, galois_elt, temp);
}
- else
- {
- throw logic_error("scheme not implemented");
- }

// Wipe encrypted.data(1)
set_zero_poly(coeff_count, coeff_modulus_size, encrypted.data(1));
@@ -2530,6 +2527,7 @@ namespace seal
auto &key_context_data = *context_.key_context_data();
auto &key_parms = key_context_data.parms();
auto scheme = parms.scheme();
+ bool is_ntt_form = encrypted.is_ntt_form();

// Verify parameters.
if (!is_metadata_valid_for(encrypted, context_) || !is_buffer_valid(encrypted))
@@ -2559,14 +2557,6 @@ namespace seal
{
throw invalid_argument("pool is uninitialized");
}
- if (scheme == scheme_type::bfv && encrypted.is_ntt_form())
- {
- throw invalid_argument("BFV encrypted cannot be in NTT form");
- }
- if (scheme == scheme_type::ckks && !encrypted.is_ntt_form())
- {
- throw invalid_argument("CKKS encrypted must be in NTT form");
- }
if (scheme == scheme_type::bgv && !encrypted.is_ntt_form())
{
throw invalid_argument("BGV encrypted must be in NTT form");
@@ -2605,7 +2595,7 @@ namespace seal
set_uint(target_iter, decomp_modulus_size * coeff_count, t_target);

// In CKKS or BGV, t_target is in NTT form; switch back to normal form
- if (scheme == scheme_type::ckks || scheme == scheme_type::bgv)
+ if (is_ntt_form)
{
inverse_ntt_negacyclic_harvey(t_target, decomp_modulus_size, key_ntt_tables);
}
@@ -2632,7 +2622,7 @@ namespace seal
ConstCoeffIter t_operand;

// RNS-NTT form exists in input
- if ((scheme == scheme_type::ckks || scheme == scheme_type::bgv) && (I == J))
+ if (is_ntt_form && (I == J))
{
t_operand = target_iter[J];
}
@@ -2789,7 +2779,7 @@ namespace seal
SEAL_ITERATE(t_ntt, coeff_count, [fix](auto &K) { K += fix; });

uint64_t qi_lazy = qi << 1; // some multiples of qi
- if (scheme == scheme_type::ckks)
+ if (is_ntt_form)
{
// This ntt_negacyclic_harvey_lazy results in [0, 4*qi).
ntt_negacyclic_harvey_lazy(t_ntt, get<2>(J));
@@ -2802,7 +2792,7 @@ namespace seal
qi_lazy = qi << 2;
#endif
}
- else if (scheme == scheme_type::bfv)
+ else
{
inverse_ntt_negacyclic_harvey_lazy(get<0, 1>(J), get<2>(J));
}
22 changes: 22 additions & 0 deletions bazel/patches/xla.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
diff --git a/third_party/tsl/workspace1.bzl b/third_party/tsl/workspace1.bzl
index 4cfb6da82..0e3774834 100644
--- a/third_party/tsl/workspace1.bzl
+++ b/third_party/tsl/workspace1.bzl
@@ -3,7 +3,7 @@
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")
load("@io_bazel_rules_closure//closure:defs.bzl", "closure_repositories")
-load("@rules_cuda//cuda:dependencies.bzl", "rules_cuda_dependencies")
+# load("@rules_cuda//cuda:dependencies.bzl", "rules_cuda_dependencies")
load("@rules_pkg//:deps.bzl", "rules_pkg_dependencies")

# buildifier: disable=unnamed-macro
@@ -14,7 +14,7 @@ def workspace(with_rules_cc = True):
with_rules_cc: whether to load and patch rules_cc repository.
"""
native.register_toolchains("@local_config_python//:py_toolchain")
- rules_cuda_dependencies(with_rules_cc)
+ # rules_cuda_dependencies(with_rules_cc)
rules_pkg_dependencies()

closure_repositories()
28 changes: 26 additions & 2 deletions bazel/repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ SECRETFLOW_GIT = "https://github.com/secretflow"
YACL_COMMIT_ID = "a9c1d7d119c80eb75d5ec63ee6cd77145dff18c2"

def spu_deps():
_rules_cuda()
_rules_proto_grpc()
_bazel_platform()
_upb()
_com_github_xtensor_xtensor()
Expand Down Expand Up @@ -61,6 +63,24 @@ def spu_deps():
path = "/opt/homebrew/opt/libomp/",
)

def _rules_proto_grpc():
http_archive(
name = "rules_proto_grpc",
sha256 = "928e4205f701b7798ce32f3d2171c1918b363e9a600390a25c876f075f1efc0a",
strip_prefix = "rules_proto_grpc-4.4.0",
urls = [
"https://github.com/rules-proto-grpc/rules_proto_grpc/releases/download/4.4.0/rules_proto_grpc-4.4.0.tar.gz",
],
)

def _rules_cuda():
http_archive(
name = "rules_cuda",
sha256 = "fa1462c4c3104de44489800a1da055f55afa57795789539c835e069818786f71",
strip_prefix = "rules_cuda-cab1fa2dd0e1f8489f566c91a5025856cf5ae572",
urls = ["https://github.com/bazel-contrib/rules_cuda/archive/cab1fa2dd0e1f8489f566c91a5025856cf5ae572.tar.gz"],
)

def _bazel_platform():
http_archive(
name = "platforms",
Expand Down Expand Up @@ -138,8 +158,8 @@ def _com_github_xtensor_xtl():
)

def _com_github_openxla_xla():
OPENXLA_COMMIT = "f23282956725050dd07996d3b80d6788ad8aaeba"
OPENXLA_SHA256 = "6e87093f550573de12af247edfa590d76adc6e931da6b551e78dc2c0c2bbd04d"
OPENXLA_COMMIT = "0c99beffabc5d43fa29f121674eb59e14a22c779"
OPENXLA_SHA256 = "d4c7511a496aeb917976c0d8a65de374c395546f0c3d4077d9dfd4df780d7ea8"

SKYLIB_VERSION = "1.3.0"

Expand All @@ -160,6 +180,10 @@ def _com_github_openxla_xla():
sha256 = OPENXLA_SHA256,
strip_prefix = "xla-" + OPENXLA_COMMIT,
type = ".tar.gz",
patch_args = ["-p1"],
patches = [
"@spulib//bazel:patches/xla.patch",
],
urls = [
"https://github.com/openxla/xla/archive/{commit}.tar.gz".format(commit = OPENXLA_COMMIT),
],
Expand Down
3 changes: 3 additions & 0 deletions bazel/spu.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def spu_cc_binary(
cc_binary(
linkopts = linkopts,
copts = copts + _spu_copts(),
linkstatic = True,
**kargs
)

Expand All @@ -66,6 +67,7 @@ def spu_cc_library(
local_defines = local_defines + [
"SPU_BUILD",
],
linkstatic = True,
**kargs
)

Expand Down Expand Up @@ -111,5 +113,6 @@ def spu_cc_test(
local_defines = local_defines + [
"SPU_BUILD",
],
linkstatic = True,
**kwargs
)
6 changes: 3 additions & 3 deletions docs/reference/xla_status.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ Count: Total = 4, fully supported = 2
| `ceil` | fully |
| `convert` | fully |
| `count_leading_zeros`| no |
| `cosine` | no |
| `cosine` | fully |
| `exponential` | fully |
| `exponential_minus_one`| fully |
| `floor` | fully |
Expand All @@ -46,11 +46,11 @@ Count: Total = 4, fully supported = 2
| `round_nearest_afz`| not |
| `rsqrt` | fully |
| `sign` | partial |
| `sine` | not |
| `sine` | fully |
| `sqrt` | fully |
| `tanh` | fully |

Count: Total = 24, fully supported = 12, partial = 0
Count: Total = 24, fully supported = 16, partial = 1

### XLA binary element-wise ops

Expand Down
18 changes: 8 additions & 10 deletions examples/python/ml/jax_lr/jax_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,11 @@ def save_and_load_model():
print(W_r, b_r)

x_test, y_test = dsutil.breast_cancer(slice(None, None, None), False)
print(
"AUC(save_and_load_model)={}".format(
metrics.roc_auc_score(y_test, predict(x_test, W_r, b_r))
)
)

score = metrics.roc_auc_score(y_test, predict(x_test, W_r, b_r))
print("AUC(save_and_load_model)={}".format(score))

return score


def compute_score(W_r, b_r, type):
Expand All @@ -163,10 +163,7 @@ def train(x1, x2, y):
x2, _ = ppd.device("P2")(dsutil.breast_cancer)(slice(15, None), True)
W, b = train(x1, x2, y)

W_r, b_r = ppd.get(W), ppd.get(b)
print(W_r, b_r)

return W_r, b_r
return W, b


if __name__ == "__main__":
Expand All @@ -176,5 +173,6 @@ def train(x1, x2, y):
compute_score(w[1], b[1], 'cpu, manual_grad')
print('Run on SPU\n------\n')
w, b = run_on_spu()
compute_score(w, b, 'spu')
w_r, b_r = ppd.get(w), ppd.get(b)
compute_score(w_r, b_r, 'spu')
save_and_load_model()
10 changes: 9 additions & 1 deletion examples/python/ml/ml_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def test_jax_lr(self):
from examples.python.ml.jax_lr import jax_lr

w, b = profile_test_point(jax_lr.run_on_spu)
score = jax_lr.compute_score(w, b, 'spu')
score = jax_lr.compute_score(ppd.get(w), ppd.get(b), 'spu')

self.assertGreater(score, 0.95)

Expand Down Expand Up @@ -222,6 +222,13 @@ def test_torch_experiment(self):
score = torch_experiment.run_inference_on_spu(model)
self.assertGreater(score, 0.9)

def test_save_and_load_model(self):
from examples.python.ml.jax_lr import jax_lr

score = jax_lr.save_and_load_model()
self.assertGreater(score, 0.9)
pass


def suite():
suite = unittest.TestSuite()
Expand All @@ -239,6 +246,7 @@ def suite():
suite.addTest(UnitTests('test_stax_nn'))
suite.addTest(UnitTests('test_tf_experiment'))
suite.addTest(UnitTests('test_torch_experiment'))
suite.addTest(UnitTests('test_save_and_load_model'))
return suite


Expand Down
Loading
Loading