[Add Randomized SVD in PCA] #300
@deadlywing |
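I took a quick look; there are a few issues for now:
py_library(
    name = "pca",
    srcs = ["pca.py"],
    deps = ["//sml/utils:extmath"],  # deps must be specified, otherwise it will not run
)
PS: the emulation has to be run locally from source with bazel. If you installed the spu package via pip, it will not run; you can set up a local environment again to make later testing easier. Thanks |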
What results do you get with the dataset size used in the submitted code? My local test (in Simulation) should be fine. |
# Copyright 2023 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import unittest
import jax.numpy as jnp
import numpy as np
from jax import random
from sklearn.decomposition import PCA as SklearnPCA
import spu.spu_pb2 as spu_pb2
import spu.utils.simulation as spsim
# Add the sml directory to the path
sys.path.append(os.path.join(os.path.dirname(__file__), '../../../'))
from sml.decomposition.pca import PCA
class UnitTests(unittest.TestCase):
    def test_power(self):
        print("\ntest power ...")
        sim = spsim.Simulator.simple(
            3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64
        )

        # Test fit_transform
        def proc_transform(X):
            model = PCA(
                method='power_iteration',
                n_components=2,
            )
            model.fit(X)
            X_transformed = model.transform(X)
            X_variances = model._variances
            return X_transformed, X_variances

        # Create a simple dataset
        X = random.normal(random.PRNGKey(0), (15, 100))
        # Run the simulation
        result = spsim.sim_jax(sim, proc_transform)(X)
        X_np = np.array(X)
        # Run fit_transform using sklearn
        sklearn_pca = SklearnPCA(n_components=2)
        X_transformed_sklearn = sklearn_pca.fit_transform(X_np)
        # Compare the transform results
        print("X_transformed_sklearn: ", X_transformed_sklearn[:2, :])
        print("X_transformed_jax", result[0][:2, :])
        # Compare the variance results
        print(
            "X_transformed_sklearn.explained_variance_: ",
            sklearn_pca.explained_variance_,
        )
        print("X_transformed_jax.explained_variance_: ", result[1])

    def test_rsvd(self):
        print("\nrsvd test...")
        config = spu_pb2.RuntimeConfig(
            protocol=spu_pb2.ProtocolKind.ABY3,
            field=spu_pb2.FieldType.FM128,
            fxp_fraction_bits=30,
        )
        sim = spsim.Simulator(3, config)

        # Test fit_transform
        def proc_transform(X, random_matrix):
            model = PCA(
                method='rsvd',
                n_components=2,
                random_matrix=random_matrix,
            )
            model.fit(X)
            X_transformed = model.transform(X)
            X_variances = model._variances
            return X_transformed, X_variances

        # Create a simple dataset
        X = random.normal(random.PRNGKey(0), (15, 100))
        # Create random_matrix
        random_state = np.random.RandomState(0)
        random_matrix = random_state.normal(size=(X.shape[1], 2))
        # Run the simulation
        result = spsim.sim_jax(sim, proc_transform)(X, random_matrix)
        X_np = np.array(X)
        # Run fit_transform using sklearn
        sklearn_pca = SklearnPCA(n_components=2)
        sklearn_pca.fit(X_np)
        X_transformed_sklearn = sklearn_pca.transform(X_np)
        # Compare the transform results
        print("X_transformed_sklearn: ", X_transformed_sklearn[:2, :])
        print("X_transformed_jax", result[0][:2, :])
        # Compare the variance results
        print(
            "X_transformed_sklearn.explained_variance_: ",
            sklearn_pca.explained_variance_,
        )
        print("X_transformed_jax.explained_variance_: ", result[1])


if __name__ == "__main__":
    unittest.main()

test power ...
X_transformed_sklearn: [[-3.1667597 5.433083 ]
[-5.998851 -2.5643744]]
X_transformed_jax [[ 3.1572075 -5.4363747]
[ 6.003105 2.5544815]]
X_transformed_sklearn.explained_variance_: [13.892871 12.988095]
X_transformed_jax.explained_variance_: [13.891659 12.987503]
rsvd test...
X_transformed_sklearn: [[-3.1667576 5.4330835]
[-5.9988513 -2.5643716]]
X_transformed_jax [[ 6.9492254 -1.3078341]
[-2.1523845 -4.7718153]]
X_transformed_sklearn.explained_variance_: [13.892871 12.988095]
X_transformed_jax.explained_variance_: [12.294682 11.470435] |
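You need to test with a dataset of roughly 1000 x 10; a 15 x 100 dataset may fall back to full SVD by default. Also, when the dataset size changes, the scale parameter has to be adjusted to a suitable range: a scale that is too large causes underflow, and one that is too small causes overflow. |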
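To make the underflow/overflow remark concrete, here is a minimal sketch of why dividing by a scale matters in fixed-point arithmetic. It uses a deliberately simplified model (round to a grid of step 2**-frac_bits inside a signed ring) and is not SPU's actual arithmetic; the helper `to_fixed_point` and the sample values are assumptions made for illustration only. The bit widths mirror `fxp_fraction_bits=30` with `FM128` from the rsvd test.

```python
import numpy as np


def to_fixed_point(x, frac_bits, total_bits):
    """Simplified fixed-point model (hypothetical helper, not SPU code).

    Rounds x onto a grid of step 2**-frac_bits and reports whether the
    encoded integer would leave a signed ring of total_bits bits.
    """
    encoded = np.round(np.asarray(x, dtype=np.float64) * 2.0**frac_bits)
    overflowed = bool(np.any(np.abs(encoded) >= 2.0 ** (total_bits - 1)))
    return encoded / 2.0**frac_bits, overflowed


FRAC_BITS, TOTAL_BITS = 30, 128
value = 0.5

# Scale too large: the quotient is smaller than the grid step and rounds to 0.
too_small, _ = to_fixed_point(value / 1e12, FRAC_BITS, TOTAL_BITS)

# Scale too small (modelled here as an unscaled, very large intermediate):
# the encoded integer no longer fits in the ring.
_, overflow = to_fixed_point(value * 1e40, FRAC_BITS, TOTAL_BITS)

# A moderate scale keeps the value representable with small rounding error.
ok, ok_overflow = to_fixed_point(value / 1e4, FRAC_BITS, TOTAL_BITS)

print("underflowed to zero:", float(too_small) == 0.0)
print("overflowed:", overflow)
print("moderate scale error:", abs(float(ok) - value / 1e4))
```

In this model the usable window is bounded below by the grid step and above by the ring size, which is why the scale has to be retuned when the dataset shape changes the magnitude of the intermediate products.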
The scale here is tuned for the 1000 x 10 dataset; you can run the test file from my commit directly.
|
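I changed it back to (1000, 10) and it is the same: the error is quite noticeable. You need to check whether the transformed vectors are approximately equal and whether the eigenvalues match. The asserts in the unit test do not cover these, so the results have to be inspected by eye... |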
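One thing that makes the by-eye comparison harder is that principal components are only defined up to sign: in the power-iteration output above, the SPU result is the sklearn result with both columns negated. A minimal sketch of a sign-insensitive comparison follows; the helper `align_signs` is hypothetical and is not part of sml or sklearn, and the numbers are the first two rows printed by the power test.

```python
import numpy as np


def align_signs(reference, candidate):
    """Flip each column of candidate so that it correlates positively with reference."""
    signs = np.sign(np.sum(reference * candidate, axis=0))
    signs[signs == 0] = 1.0  # leave columns with zero correlation untouched
    return candidate * signs


# First two rows of the transformed data from the power-iteration test output.
sklearn_rows = np.array([[-3.1667597, 5.433083], [-5.998851, -2.5643744]])
spu_rows = np.array([[3.1572075, -5.4363747], [6.003105, 2.5544815]])

aligned = align_signs(sklearn_rows, spu_rows)
max_abs_err = np.max(np.abs(aligned - sklearn_rows))
print("max abs error after sign alignment:", max_abs_err)
```

With the signs aligned, a tolerance-based assert on the transformed data becomes possible; the same alignment does not help when the components themselves differ, as in the rsvd output.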
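I am not sure whether rsvd simply has a larger error on the singular vectors, because so far we only looked at the singular values and never checked the U and V matrices... |
What are the transform results? And the eigenvalues? |
The variance error looked very small before. If the variance error is also large in your test, some parameter or setting may differ from mine. |
n_sample = 1000, n_feature = 10, n_components = 5: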
|
X_transformed_sklearn: [[-2.1166942 1.0700845 -0.14631309 -1.3690518 1.8845886 ]
[-1.2855983 -0.7558893 0.71847 -0.00927168 0.47726887]
[ 2.148614 0.5903857 0.9635215 -0.5206121 -0.31438693]
[-0.76127815 1.5483521 0.5180957 -0.6482162 -1.1910266 ]
[ 0.13378008 0.00686819 0.34645978 -0.32356977 0.08244695]]
X_transformed_jax [[ 1.8285344 0.6458801 0.2799378 0.85873485 -0.33970326]
[ 1.5334405 -1.4996622 0.24053855 -0.14607927 -0.5198155 ]
[-1.8808262 0.56276166 -0.79416734 -0.7032572 -0.34177095]
[ 1.2244229 1.2555254 0.6243129 0.9388125 -0.7295229 ]
[ 0.00891032 -0.1405381 0.32475436 0.97715914 0.31833112]]
X_transformed_sklearn.explained_variance_: [1.2137077 1.110193 1.0762373 1.0629445 1.0329939]
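X_transformed_jax.explained_variance_: [1.1966664 1.0949516 1.0440481 0.996428 0.9429427]
I used the test file you provided without changing any parameters; I only kept the first 5 rows of the transformed result. PS: my local scikit-learn version is 1.3.0 (in theory this should not make any difference...) |
Now I remember: I had modified the code where sklearn calls randomized_svd, because n_oversamples is not supported here, and to align a few other parameters: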
|
sklearn.decomposition._pca._fit_truncated |
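I took a look: sklearn does not recommend n_oversamples=0. Would you consider adding this feature? Otherwise the result still seems to deviate somewhat from the true value (compared with the full SVD result). |
Did you tune the parameters in PCA or in randomized_svd? What are the exact values? |
What I mean is: our goal is to compute the SVD as accurately as possible, but with n_oversamples=0 the computation may not be very accurate, since it deviates considerably from the full SVD result. In fact, even with sklearn's PCA, if I set n_oversamples very small, e.g. 1 or 2, the result is also inaccurate. |
So I think supporting n_oversamples is necessary. |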
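The effect of n_oversamples can be reproduced in plaintext. The following is a minimal NumPy sketch of randomized SVD with QR-normalized power iterations; it is an illustration under stated assumptions, not the implementation in sml/utils/extmath.py and not sklearn's code, and the function name `randomized_svd_sketch` is hypothetical. The data shape (1000, 10) and n_components = 5 follow the discussion above.

```python
import numpy as np


def randomized_svd_sketch(A, n_components, n_oversamples, n_iter, seed=0):
    """Plain NumPy randomized SVD (illustrative sketch only)."""
    rng = np.random.default_rng(seed)
    k = n_components + n_oversamples
    # Random test matrix with n_components + n_oversamples columns.
    Q = rng.normal(size=(A.shape[1], k))
    # Power iterations, re-orthonormalizing with QR at every step.
    for _ in range(n_iter):
        Q, _ = np.linalg.qr(A @ Q)
        Q, _ = np.linalg.qr(A.T @ Q)
    # Orthonormal basis for the sampled range of A.
    Q, _ = np.linalg.qr(A @ Q)
    # Project A onto that basis and decompose the small matrix.
    B = Q.T @ A
    u_tilde, s, vt = np.linalg.svd(B, full_matrices=False)
    u = Q @ u_tilde
    return u[:, :n_components], s[:n_components], vt[:n_components]


rng = np.random.default_rng(0)
A = rng.normal(size=(1000, 10))
s_full = np.linalg.svd(A, compute_uv=False)

# No oversampling: only 5 random directions are used to capture 5 components.
_, s_none, _ = randomized_svd_sketch(A, n_components=5, n_oversamples=0, n_iter=2)
# Oversampling by 5: the 10 random directions span the whole 10-dim row space.
_, s_over, _ = randomized_svd_sketch(A, n_components=5, n_oversamples=5, n_iter=2)

print("error without oversampling:", np.abs(s_none - s_full[:5]))
print("error with oversampling:   ", np.abs(s_over - s_full[:5]))
```

Because the projected matrix is a compression of A, the approximate singular values can never exceed the exact ones; oversampling narrows the gap, and in this small example the oversampled basis covers the full range of A, so the leading singular values match the full SVD up to floating-point error. The price is a wider random matrix and wider intermediate products.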
Also, I noticed that your commit did not pass the Python linter check; please make sure the files you commit are properly formatted:
|
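The reason I had not considered adding an n_oversamples interface is that, as things stand, using it causes a large performance loss. After adding the n_oversamples parameter: |
extmath.py needs to be updated as well.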
|
sml/decomposition/tests/pca_test.py
Outdated
# Run fit_transform using sklearn
# Copy to sklearn.decomposition._pca._fit_truncated
# U, S, Vt = randomized_svd(X, n_components=n_components, power_iteration_normalizer = 'QR', svd_lapack_driver = 'gesvd', random_state = 0, n_oversamples = 10)
sklearn_pca = SklearnPCA(n_components=n_components)
sklearn's PCA needs its parameters set manually:
model = SklearnPCA(
    n_components=n_components,
    svd_solver="randomized",
    power_iteration_normalizer="QR",
    random_state=0,
)
We cannot ask every later user to modify their local sklearn...
Likewise, please update the parameters in emul.
@@ -93,5 +93,91 @@ def proc_reconstruct(X):
        emulator.down()


def emul_rsvdPCA(mode: emulation.Mode.MULTIPROCESS):
    def proc(X):
This needs to accept a random matrix as well.
assert jnp.allclose(jnp.mean(result[0], axis=0), 0, atol=1e-3)

# Compare with sklearn
model = SklearnPCA(n_components=n_components)
As in the tests file, sklearn's parameters need to be set manually.
sml/decomposition/pca.py
Outdated
Additional number of random vectors to sample the range of A so as
to ensure proper conditioning. The total number of random vectors
used to find the range of A is n_components + n_oversamples
max_power_iter : int, default=100
The default value is 300.
Also, please make a fresh copy of the config under the emulations directory and change the field and fxp_fraction_bits in it. |
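As a rough illustration of that request, the sketch below derives an rsvd-specific config from a base one by overriding `field` and `fxp_fraction_bits`. The values FM128 and 30 are the ones used in the rsvd simulation test above. Everything else is an assumption: the helper `override_runtime_config` is hypothetical, the nested `runtime_config` layout is a stand-in for the real 3pc.json (which has more keys), and the exact spelling of the field value in the JSON file should be checked against the repository.

```python
import copy
import json


def override_runtime_config(conf, **overrides):
    """Return a deep copy of conf with overrides applied to every 'runtime_config' dict."""
    patched = copy.deepcopy(conf)

    def visit(node):
        if isinstance(node, dict):
            for key, child in node.items():
                if key == "runtime_config" and isinstance(child, dict):
                    child.update(overrides)
                else:
                    visit(child)
        elif isinstance(node, list):
            for child in node:
                visit(child)

    visit(patched)
    return patched


# Hypothetical stand-in for the contents of the base 3pc.json.
base_conf = {
    "devices": {
        "SPU": {"config": {"runtime_config": {"protocol": "ABY3", "field": "FM64"}}}
    }
}

rsvd_conf = override_runtime_config(base_conf, field="FM128", fxp_fraction_bits=30)
print(json.dumps(rsvd_conf, indent=4))
```

Keeping a separate copy under the emulations directory means the wider field and larger fraction-bit setting needed by rsvd do not affect other emulations that share the default config.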
sml/utils/extmath.py
Outdated
Omega = random_matrix / scale[0]
Q = rsvd_iteration(A, Omega, scale[1], n_iter)
B = jnp.dot(Q.T, A)
u_tilde, s, v = svd(B, eigh_iter)
u = jnp.dot(Q, u_tilde)
return u, s, v
u = jnp.dot(Q, u_tilde)[:, :]
u = jnp.dot(Q, u_tilde)
Isn't this enough?
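The point of the comment can be checked quickly: a full slice `[:, :]` selects every row and column, so it returns the same values as the unsliced product. The sketch below uses NumPy rather than jax.numpy (basic slicing behaves the same way for this case); the shapes and the `n_components` value are made up for illustration, and the truncated variant is only a guess at what a slice could be useful for once oversampled columns exist.

```python
import numpy as np

rng = np.random.default_rng(0)
Q = rng.normal(size=(6, 4))
u_tilde = rng.normal(size=(4, 4))

u_full = np.dot(Q, u_tilde)
u_sliced = np.dot(Q, u_tilde)[:, :]  # full slice: same values as u_full

# Hypothetical: a slice only changes the result when it actually truncates,
# e.g. keeping the first n_components columns after oversampling.
n_components = 2
u_truncated = np.dot(Q, u_tilde)[:, :n_components]

print(np.array_equal(u_full, u_sliced), u_truncated.shape)
```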
sml/decomposition/3pc.json
Outdated
Please put this config file under the emulations directory.
def emul_rsvdPCA(mode: emulation.Mode.MULTIPROCESS):
    print("emul rsvdPCA.")

    def proc(X, random_matrix, n_components, n_oversamples, max_power_iter, scale):
        model = PCA(
            method='rsvd',
            n_components=n_components,
            n_oversamples=n_oversamples,
            random_matrix=random_matrix,
            max_power_iter=max_power_iter,
            scale=scale,
        )
        model.fit(X)
        X_transformed = model.transform(X)
        X_variances = model._variances
        X_reconstructed = model.inverse_transform(X_transformed)
        return X_transformed, X_variances, X_reconstructed

    try:
        # bandwidth and latency only work for docker mode
        conf_path = "sml/decomposition/emulations/3pc.json"
        emulator = emulation.Emulator(conf_path, mode, bandwidth=300, latency=20)
        emulator.up()
        # Create a simple dataset
        X = random.normal(random.PRNGKey(0), (1000, 20))
        X_spu = emulator.seal(X)
        n_components = 5
        n_oversamples = 10
        max_power_iter = 300
        scale = (10000000, 10000)
        # Create random_matrix
        random_state = np.random.RandomState(0)
        random_matrix = random_state.normal(
            size=(X.shape[1], n_components + n_oversamples)
        )
        random_matrix_spu = emulator.seal(random_matrix)
        result = emulator.run(proc, static_argnums=(2, 3, 4, 5))(
            X_spu, random_matrix_spu, n_components, n_oversamples, max_power_iter, scale
        )
        print("X_transformed_spu: ", result[0][:5, :])
        print("X_variance_spu: ", result[1])
        print("X_reconstructed_spu:", result[2][:5, :])
        # The transformed data should have n_components dimensions
        assert result[0].shape[1] == n_components
        # The mean of the transformed data should be approximately 0
        assert jnp.allclose(jnp.mean(result[0], axis=0), 0, atol=1e-3)
        # Compare with sklearn
        model = SklearnPCA(
            n_components=n_components,
            svd_solver="randomized",
            power_iteration_normalizer="QR",
            random_state=0,
        )
        model.fit(X)
        X_transformed = model.transform(X)
        X_variances = model.explained_variance_
        X_reconstructed = model.inverse_transform(X_transformed)
        print("X_transformed_sklearn: ", X_transformed[:5, :])
        print("X_variances_sklearn: ", X_variances)
        print("X_reconstructed_sklearn: ", X_reconstructed[:5, :])
        assert np.allclose(X_reconstructed, result[2], atol=1e-1)
    finally:
        emulator.down()
|
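The quantities checked in the emulation above (output width, zero mean of the projection, reconstruction) can be mirrored in plaintext to see what the asserts are testing. This is a minimal NumPy sketch using a full SVD, written only for illustration: `pca_plain` is a hypothetical helper, not the sml PCA class, and it uses NumPy's own random generator rather than the JAX key from the emulation.

```python
import numpy as np


def pca_plain(X, n_components):
    """Plaintext PCA via full SVD (illustrative reference, not the sml code)."""
    mean = X.mean(axis=0)
    Xc = X - mean  # center the data
    _, s, vt = np.linalg.svd(Xc, full_matrices=False)
    components = vt[:n_components]
    # Explained variance of each component: s^2 / (n_samples - 1).
    variances = s[:n_components] ** 2 / (X.shape[0] - 1)
    transformed = Xc @ components.T
    reconstructed = transformed @ components + mean
    return transformed, variances, reconstructed


rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 20))

transformed, variances, reconstructed = pca_plain(X, n_components=5)
# Keeping every component makes the reconstruction exact up to rounding.
_, _, full_reconstruction = pca_plain(X, n_components=20)

print("transformed shape:", transformed.shape)
print("column means:", transformed.mean(axis=0))
print("explained variances:", variances)
```

The projection has zero mean because the data is centered before projecting, which is what the atol=1e-3 assert checks under fixed-point noise. With only 5 of 20 components kept on unstructured Gaussian data, the reconstruction is lossy, so comparing the SPU reconstruction against sklearn's reconstruction (rather than against X) is the meaningful check.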
I reorganized your emul part a bit so that it calls fit one time fewer. |
Please create a new emul file under the emulations folder, and add two entries to the corresponding BUILD.bazel as well:
py_binary(
    name = "pca_emul",
    srcs = ["pca_emul.py"],  # put the new .py file you created here
    data = [":conf"],
    deps = [
        "//sml/decomposition:pca",
        "//sml/utils:emulation",
    ],
)
filegroup(
    name = "conf",
    srcs = [
        "3pc.json",
    ],
)
|
# Copyright 2023 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("@rules_python//python:defs.bzl", "py_binary")
package(default_visibility = ["//visibility:public"])
py_binary(
    name = "pca_emul",
    srcs = ["pca_emul.py"],
    deps = [
        "//sml/decomposition:pca",
        "//sml/utils:emulation",
    ],
)

py_binary(
    name = "rsvd_pca_emul",
    srcs = ["rsvd_pca_emul.py"],
    data = [":conf"],
    deps = [
        "//sml/decomposition:pca",
        "//sml/utils:emulation",
    ],
)

filegroup(
    name = "conf",
    srcs = [
        "3pc.json",
    ],
)
|
LGTM
Pull Request
What problem does this PR solve?
Optimize the PCA algorithm using SPU
Issue Number: Fixed #259
Possible side effects?