Skip to content

Commit

Permalink
[Add Randomized SVD in PCA] (#300)
Browse files Browse the repository at this point in the history
# Pull Request

## What problem does this PR solve?
使用 SPU 优化 PCA 算法
Issue Number: Fixed #259

## Possible side effects?

- Performance:
1. 收敛速度更快(体现在能支持更大的特征维度)
2. 不需要显示的计算原数据集的协方差矩阵
- Backward compatibility:
  • Loading branch information
tarantula-leo authored Aug 15, 2023
1 parent 9857c82 commit f726019
Show file tree
Hide file tree
Showing 9 changed files with 388 additions and 57 deletions.
1 change: 1 addition & 0 deletions sml/decomposition/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@ package(default_visibility = ["//visibility:public"])
py_library(
name = "pca",
srcs = ["pca.py"],
deps = ["//sml/utils:extmath"],
)
52 changes: 52 additions & 0 deletions sml/decomposition/emulations/3pc.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
{
"id": "outsourcing.3pc",
"nodes": {
"node:0": "127.0.0.1:9920",
"node:1": "127.0.0.1:9921",
"node:2": "127.0.0.1:9922",
"node:3": "127.0.0.1:9923",
"node:4": "127.0.0.1:9924"
},
"devices": {
"SPU": {
"kind": "SPU",
"config": {
"node_ids": [
"node:0",
"node:1",
"node:2"
],
"spu_internal_addrs": [
"127.0.0.1:9930",
"127.0.0.1:9931",
"127.0.0.1:9932"
],
"experimental_data_folder": [
"/tmp/spu_data_0/",
"/tmp/spu_data_1/",
"/tmp/spu_data_2/"
],
"runtime_config": {
"protocol": "ABY3",
"field": "FM128",
"fxp_fraction_bits": 30,
"enable_pphlo_profile": true,
"enable_hal_profile": true,
"enable_pphlo_trace": false
}
}
},
"P1": {
"kind": "PYU",
"config": {
"node_id": "node:3"
}
},
"P2": {
"kind": "PYU",
"config": {
"node_id": "node:4"
}
}
}
}
17 changes: 17 additions & 0 deletions sml/decomposition/emulations/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,20 @@ py_binary(
"//sml/utils:emulation",
],
)

py_binary(
name = "rsvd_pca_emul",
srcs = ["rsvd_pca_emul.py"],
data = [":conf"],
deps = [
"//sml/decomposition:pca",
"//sml/utils:emulation",
],
)

filegroup(
name = "conf",
srcs = [
"3pc.json",
],
)
6 changes: 3 additions & 3 deletions sml/decomposition/emulations/pca_emul.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@
from sklearn.decomposition import PCA as SklearnPCA

# Add the library directory to the path
sys.path.append(os.path.join(os.path.dirname(__file__), '../../'))
sys.path.append(os.path.join(os.path.dirname(__file__), '../../../'))

import sml.utils.emulation as emulation
from sml.decomposition.pca import PCA


def emul_PCA(mode: emulation.Mode.MULTIPROCESS):
def emul_powerPCA(mode: emulation.Mode.MULTIPROCESS):
def proc(X):
model = PCA(
method='power_iteration',
Expand Down Expand Up @@ -94,4 +94,4 @@ def proc_reconstruct(X):


if __name__ == "__main__":
emul_PCA(emulation.Mode.MULTIPROCESS)
emul_powerPCA(emulation.Mode.MULTIPROCESS)
107 changes: 107 additions & 0 deletions sml/decomposition/emulations/rsvd_pca_emul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Copyright 2023 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys

import jax.numpy as jnp
import jax.random as random
import numpy as np
from sklearn.decomposition import PCA as SklearnPCA

# Add the library directory to the path
sys.path.append(os.path.join(os.path.dirname(__file__), '../../../'))

import sml.utils.emulation as emulation
from sml.decomposition.pca import PCA


def emul_rsvdPCA(mode: emulation.Mode.MULTIPROCESS):
print("emul rsvdPCA.")

def proc(X, random_matrix, n_components, n_oversamples, max_power_iter, scale):
model = PCA(
method='rsvd',
n_components=n_components,
n_oversamples=n_oversamples,
random_matrix=random_matrix,
max_power_iter=max_power_iter,
scale=scale,
)

model.fit(X)
X_transformed = model.transform(X)
X_variances = model._variances
X_reconstructed = model.inverse_transform(X_transformed)

return X_transformed, X_variances, X_reconstructed

try:
# bandwidth and latency only work for docker mode
conf_path = "sml/decomposition/emulations/3pc.json"
emulator = emulation.Emulator(conf_path, mode, bandwidth=300, latency=20)
emulator.up()

# Create a simple dataset
X = random.normal(random.PRNGKey(0), (1000, 20))
X_spu = emulator.seal(X)
n_components = 5
n_oversamples = 10
max_power_iter = 300
scale = (10000000, 10000)

# Create random_matrix
random_state = np.random.RandomState(0)
random_matrix = random_state.normal(
size=(X.shape[1], n_components + n_oversamples)
)
random_matrix_spu = emulator.seal(random_matrix)

result = emulator.run(proc, static_argnums=(2, 3, 4, 5))(
X_spu, random_matrix_spu, n_components, n_oversamples, max_power_iter, scale
)
print("X_transformed_spu: ", result[0][:5, :])
print("X_variance_spu: ", result[1])
print("X_reconstructed_spu:", result[2][:5, :])

# The transformed data should have 2 dimensions
assert result[0].shape[1] == n_components

# The mean of the transformed data should be approximately 0
assert jnp.allclose(jnp.mean(result[0], axis=0), 0, atol=1e-3)

# Compare with sklearn
model = SklearnPCA(
n_components=n_components,
svd_solver="randomized",
power_iteration_normalizer="QR",
random_state=0,
)
model.fit(X)
X_transformed = model.transform(X)
X_variances = model.explained_variance_
X_reconstructed = model.inverse_transform(X_transformed)

print("X_transformed_sklearn: ", X_transformed[:5, :])
print("X_variances_sklearn: ", X_variances)
print("X_reconstructed_sklearn: ", X_reconstructed[:5, :])

assert np.allclose(X_reconstructed, result[2], atol=1e-1)

finally:
emulator.down()


if __name__ == "__main__":
emul_rsvdPCA(emulation.Mode.MULTIPROCESS)
Loading

0 comments on commit f726019

Please sign in to comment.