Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Add Randomized SVD in PCA] #300

Merged
merged 32 commits into from
Aug 15, 2023
Merged

[Add Randomized SVD in PCA] #300

merged 32 commits into from
Aug 15, 2023

Conversation

tarantula-leo
Copy link
Contributor

Pull Request

What problem does this PR solve?

使用 SPU 优化 PCA 算法
Issue Number: Fixed #259

Possible side effects?

  • Performance:
  1. 收敛速度更快(体现在能支持更大的特征维度)
  2. 不需要显示的计算原数据集的协方差矩阵
  • Backward compatibility:

@tarantula-leo
Copy link
Contributor Author

@deadlywing
emul文件可能是我这边环境原因,没有办法执行,看下哪里需要调整,simulation可以通过。

@deadlywing
Copy link
Contributor

@deadlywing
emul文件可能是我这边环境原因,没有办法执行,看下哪里需要调整,simulation可以通过。

我简单了看了一下,暂时有一些问题:

  1. pca target的deps缺失
py_library(
    name = "pca",
    srcs = ["pca.py"],
    deps = ["//sml/utils:extmath"],  # 需要指定deps,否则无法运行
)
  1. emul文件没有把输入数据seal,所以实际上emul是明文执行的
  2. 您可以暂时在decomposition目录下新建一个config,将环大小设置为128,fxp也可以自由修改,然后emulaiton的时候,指定这个文件,否则example目录下的config是64bit环,很容易overflow
  3. 我本地emul和test都是没有通过测试的(即使把数据维度调小),辛苦您检查一下实现逻辑

PLUS:emul执行需要本地从源码使用bazel运行,如果您是从pip安装的spu包则无法运行,您可以重新搭建一个本地的运行环境方便后续测试~

Thanks

@tarantula-leo
Copy link
Contributor Author

按提交的代码数据集大小跑出来的结果是多少?我本地测试应该没有问题,Simulation的。

@deadlywing
Copy link
Contributor

# Copyright 2023 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import unittest

import jax.numpy as jnp
import numpy as np
from jax import random
from sklearn.decomposition import PCA as SklearnPCA
import spu.spu_pb2 as spu_pb2
import spu.utils.simulation as spsim

# Add the sml directory to the path
sys.path.append(os.path.join(os.path.dirname(__file__), '../../../'))

from sml.decomposition.pca import PCA


class UnitTests(unittest.TestCase):
    def test_power(self):
        print("\ntest power ...")
        sim = spsim.Simulator.simple(
            3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64
        )

        # Test fit_transform
        def proc_transform(X):
            model = PCA(
                method='power_iteration',
                n_components=2,
            )

            model.fit(X)
            X_transformed = model.transform(X)
            X_variances = model._variances

            return X_transformed, X_variances

        # Create a simple dataset
        X = random.normal(random.PRNGKey(0), (15, 100))

        # Run the simulation
        result = spsim.sim_jax(sim, proc_transform)(X)

        X_np = np.array(X)

        # Run fit_transform using sklearn
        sklearn_pca = SklearnPCA(n_components=2)
        X_transformed_sklearn = sklearn_pca.fit_transform(X_np)

        # Compare the transform results
        print("X_transformed_sklearn: ", X_transformed_sklearn[:2, :])
        print("X_transformed_jax", result[0][:2, :])

        # Compare the variance results
        print(
            "X_transformed_sklearn.explained_variance_: ",
            sklearn_pca.explained_variance_,
        )
        print("X_transformed_jax.explained_variance_: ", result[1])

    def test_rsvd(self):
        print("\nrsvd test...")

        config = spu_pb2.RuntimeConfig(
            protocol=spu_pb2.ProtocolKind.ABY3,
            field=spu_pb2.FieldType.FM128,
            fxp_fraction_bits=30,
        )
        sim = spsim.Simulator(3, config)

        # Test fit_transform
        def proc_transform(X, random_matrix):
            model = PCA(
                method='rsvd',
                n_components=2,
                random_matrix=random_matrix,
            )

            model.fit(X)
            X_transformed = model.transform(X)
            X_variances = model._variances

            return X_transformed, X_variances

        # Create a simple dataset
        X = random.normal(random.PRNGKey(0), (15, 100))

        # Create random_matrix
        random_state = np.random.RandomState(0)
        random_matrix = random_state.normal(size=(X.shape[1], 2))

        # Run the simulation
        result = spsim.sim_jax(sim, proc_transform)(X, random_matrix)

        X_np = np.array(X)

        # Run fit_transform using sklearn
        sklearn_pca = SklearnPCA(n_components=2)
        sklearn_pca.fit(X_np)
        X_transformed_sklearn = sklearn_pca.transform(X_np)

        # Compare the transform results
        print("X_transformed_sklearn: ", X_transformed_sklearn[:2, :])
        print("X_transformed_jax", result[0][:2, :])

        # Compare the variance results
        print(
            "X_transformed_sklearn.explained_variance_: ",
            sklearn_pca.explained_variance_,
        )
        print("X_transformed_jax.explained_variance_: ", result[1])


if __name__ == "__main__":
    unittest.main()
test power ...
X_transformed_sklearn:  [[-3.1667597  5.433083 ]
 [-5.998851  -2.5643744]]
X_transformed_jax [[ 3.1572075 -5.4363747]
 [ 6.003105   2.5544815]]
X_transformed_sklearn.explained_variance_:  [13.892871 12.988095]
X_transformed_jax.explained_variance_:  [13.891659 12.987503]

rsvd test...
X_transformed_sklearn:  [[-3.1667576  5.4330835]
 [-5.9988513 -2.5643716]]
X_transformed_jax [[ 6.9492254 -1.3078341]
 [-2.1523845 -4.7718153]]
X_transformed_sklearn.explained_variance_:  [13.892871 12.988095]
X_transformed_jax.explained_variance_:  [12.294682 11.470435]

@tarantula-leo
Copy link
Contributor Author

你需要用类似1000-10这样的数据集测试,15-100的可能会默认使用full svd,而且数据集大小变化scale参数也需要调整到合适的区间,scale过大会出现下溢,过小会出现上溢。

@tarantula-leo
Copy link
Contributor Author

这里的scale大小是针对1000-10的数据集,可以直接用我提交代码的test文件跑一下。

if scale is None:
        scale = [10000000, 10000]

@deadlywing
Copy link
Contributor

我改回1000,10也是一样,误差比较明显。您需要看一下transform以后的向量是否近似一致以及特征值是否一致。

单元测试里的assert测试不包括这些,需要肉眼去看结果...

@deadlywing
Copy link
Contributor

我不确定是不是rsvd对特征向量的误差就是比较大,因为我们之前只关注了奇异值,而没有看过U和V矩阵...

@tarantula-leo
Copy link
Contributor Author

transform的结果是多少?还有特征值

@tarantula-leo
Copy link
Contributor Author

之前看varience误差很小,如果你测试的时候发现varience误差也比较大可能是哪里参数或者设置和我这里不一样。

@tarantula-leo
Copy link
Contributor Author

n_sample = 1000,n_feature = 10,n_components = 5:

X_transformed_sklearn:  [[ 1.8261861  -0.65116286  0.2735902  -0.8654603  -0.3307426 ]
 [ 1.5359008   1.4993718   0.2324587   0.1412468  -0.5224671 ]
 [-1.8814251  -0.5651471  -0.7853469   0.7124165  -0.34269035]
 ...
 [ 0.2749451   1.6580818   1.3672045  -1.6250662   0.8621    ]
 [-1.1529483  -0.05411027  1.0677718  -1.8566325  -0.1780819 ]
 [ 0.04752983  0.17162137 -0.20342752  0.13385767 -0.09883766]]
X_transformed_jax [[ 1.8237323   0.64224774  0.2782427   0.8701657  -0.33234033]
 [ 1.5347306  -1.5002834   0.24384737 -0.14954297 -0.5210677 ]
 [-1.8802099   0.5633879  -0.79882276 -0.70185053 -0.33993524]
 ...
 [ 0.27663878 -1.6467705   1.4006331   1.5978043   0.8619212 ]
 [-1.153742    0.05693035  1.0916754   1.8462943  -0.18820946]
 [ 0.04826732 -0.1734069  -0.20169292 -0.13092276 -0.10489745]]
X_transformed_sklearn.explained_variance_:  [1.1966611 1.0949419 1.043913  0.9964739 0.9429506]
X_transformed_jax.explained_variance_:  [1.1965173  1.0949839  1.0438068  0.99661297 0.94274837]

@deadlywing
Copy link
Contributor

X_transformed_sklearn:  [[-2.1166942   1.0700845  -0.14631309 -1.3690518   1.8845886 ]
 [-1.2855983  -0.7558893   0.71847    -0.00927168  0.47726887]
 [ 2.148614    0.5903857   0.9635215  -0.5206121  -0.31438693]
 [-0.76127815  1.5483521   0.5180957  -0.6482162  -1.1910266 ]
 [ 0.13378008  0.00686819  0.34645978 -0.32356977  0.08244695]]
X_transformed_jax [[ 1.8285344   0.6458801   0.2799378   0.85873485 -0.33970326]
 [ 1.5334405  -1.4996622   0.24053855 -0.14607927 -0.5198155 ]
 [-1.8808262   0.56276166 -0.79416734 -0.7032572  -0.34177095]
 [ 1.2244229   1.2555254   0.6243129   0.9388125  -0.7295229 ]
 [ 0.00891032 -0.1405381   0.32475436  0.97715914  0.31833112]]
X_transformed_sklearn.explained_variance_:  [1.2137077 1.110193  1.0762373 1.0629445 1.0329939]
X_transformed_jax.explained_variance_:  [1.1966664 1.0949516 1.0440481 0.996428  0.9429427]

我就用的你提供的test文件,参数都没动,transform后的结果我截取了前5行;
好像我们的运行结果区别在于sklearn,但是我手动设置过svd_solver,无论是full还是randomize结果都是一致的,,

PS: 我本地的scikit-learn版本是1.3.0(理论上这个不应该会有啥区别啊...)

@tarantula-leo
Copy link
Contributor Author

想起来了,sklearn调用randomized_svd这里的代码我调整过,因为没有加n_oversamples,还有对齐一些参数:

        if svd_solver == "arpack":
            v0 = _init_arpack_v0(min(X.shape), random_state)
            U, S, Vt = svds(X, k=n_components, tol=self.tol, v0=v0)
            # svds doesn't abide by scipy.linalg.svd/randomized_svd
            # conventions, so reverse its outputs.
            S = S[::-1]
            # flip eigenvectors' sign to enforce deterministic output
            U, Vt = svd_flip(U[:, ::-1], Vt[::-1])

        elif svd_solver == "randomized":
            # sign flipping is done inside
            # U, S, Vt = randomized_svd(
            #     X,
            #     n_components=n_components,
            #     n_oversamples=self.n_oversamples,
            #     n_iter=self.iterated_power,
            #     power_iteration_normalizer=self.power_iteration_normalizer,
            #     flip_sign=True,
            #     random_state=random_state,
            # )
            U, S, Vt = randomized_svd(X,n_components=n_components,power_iteration_normalizer = 'QR',svd_lapack_driver = 'gesvd',n_oversamples=0, random_state=0)

@tarantula-leo
Copy link
Contributor Author

sklearn.decomposition._pca._fit_truncated

@deadlywing
Copy link
Contributor

想起来了,sklearn调用randomized_svd这里的代码我调整过,因为没有加n_oversamples,还有对齐一些参数:

        if svd_solver == "arpack":
            v0 = _init_arpack_v0(min(X.shape), random_state)
            U, S, Vt = svds(X, k=n_components, tol=self.tol, v0=v0)
            # svds doesn't abide by scipy.linalg.svd/randomized_svd
            # conventions, so reverse its outputs.
            S = S[::-1]
            # flip eigenvectors' sign to enforce deterministic output
            U, Vt = svd_flip(U[:, ::-1], Vt[::-1])

        elif svd_solver == "randomized":
            # sign flipping is done inside
            # U, S, Vt = randomized_svd(
            #     X,
            #     n_components=n_components,
            #     n_oversamples=self.n_oversamples,
            #     n_iter=self.iterated_power,
            #     power_iteration_normalizer=self.power_iteration_normalizer,
            #     flip_sign=True,
            #     random_state=random_state,
            # )
            U, S, Vt = randomized_svd(X,n_components=n_components,power_iteration_normalizer = 'QR',svd_lapack_driver = 'gesvd',n_oversamples=0, random_state=0)

我看了一下,sklearn不推荐n_oversamples=0的设置,要不您考虑增加一下这个feature?不然似乎结果还是会和真实值有点偏差。(和full svd的结果比较)

@tarantula-leo
Copy link
Contributor Author

你是在PCA里调参数的还是在randomized_svd里调的?具体参数是多少

@deadlywing
Copy link
Contributor

我的意思是,我们的目的是尽可能准确的计算svd,但是在n_oversamples=0的时候可能计算不是特别准确(因为和full-svd的结果有比较大的偏差。实际上,即使是sklearn的pca,如果我令n_oversamples特别小,如1,2,那结果也是不准确的)

@deadlywing
Copy link
Contributor

所以我觉得支持n_oversamples是必要的

@deadlywing
Copy link
Contributor

另外,我注意到您的commit没有通过python linter测试,麻烦您确保自己commit的文件被正确的格式化过:

  • python文件需要使用black格式化,并且使用isort优化import顺序

@tarantula-leo
Copy link
Contributor Author

n_oversamples接口之前没考虑加的原因是,使用了的话在目前情况下会导致很大的性能损失,增加n_oversamples参数后:
一方面naive svd计算量会增加n_oversamples/n_components倍;
另一方面,power_iteration的迭代次数随着矩阵增大也大约需要增加n_oversamples/n_components倍。
rsvd可以把原power_iteration中的n_feature降低到n_components(+n_oversamples),在n_feature大的时候会体现的明显,但目前受限于使用power_iteration计算特征值/特征向量,没法使用特别大的n_feature。
整体看下来后面是需要去针对eig/eigh函数看下能如何在spu中实现高效的算法。

@tarantula-leo
Copy link
Contributor Author

extmath.py里也需要更新下

def randomized_svd(
    A,
    n_components,
    n_oversamples,
    random_matrix,
    n_iter=4,
    scale=None,
    eigh_iter=300,
):
    if scale is None:
        scale = [10000000, 10000]
    assert random_matrix.shape == (
        A.shape[1],
        n_components + n_oversamples,
    ), f"Expected random_matrix to be ({A.shape[1]}, {n_components + n_oversamples}) array, got {random_matrix.shape}"
    Omega = random_matrix / scale[0]
    Q = rsvd_iteration(A, Omega, scale[1], n_iter)
    B = jnp.dot(Q.T, A)
    u_tilde, s, v = svd(B, eigh_iter)
    u = jnp.dot(Q, u_tilde)[:,:]
    return u[:, :n_components], s[:n_components], v[:n_components, :]

@deadlywing
Copy link
Contributor

我试了一下,增加n_oversamples后精度确实明显提高了不少,感觉可以整理一下这部分的内容,先merge进来?

PS:现在还需要修改的点有
image
(您可以按我的提示先修改emul部分,我可以帮您测试运行)

  1. max_iter这个参数可以修改为max_power_iter,明确参数含义
  2. extmath.py需要您这边commit一下更新
  3. 需要使用black和isort格式化一下所有相关的python文件,否则ci无法通过

Thanks

sml/utils/BUILD.bazel Show resolved Hide resolved
# Run fit_transform using sklearn
# Copy to sklearn.decomposition._pca._fit_truncated
# U, S, Vt = randomized_svd(X, n_components=n_components, power_iteration_normalizer = 'QR', svd_lapack_driver = 'gesvd', random_state = 0, n_oversamples = 10)
sklearn_pca = SklearnPCA(n_components=n_components)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sklearn的pca需要手动调整参数

model = SklearnPCA(
            n_components=n_components,
            svd_solver="randomized",
            power_iteration_normalizer="QR",
            random_state=0,
        )

不能要求后面的用户都去改自己本地的sklearn...
同理请修改emul里的参数

@@ -93,5 +93,91 @@ def proc_reconstruct(X):
emulator.down()


def emul_rsvdPCA(mode: emulation.Mode.MULTIPROCESS):
def proc(X):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

要增加能接受random matrix

sml/decomposition/emulations/pca_emul.py Outdated Show resolved Hide resolved
sml/decomposition/emulations/pca_emul.py Outdated Show resolved Hide resolved
assert jnp.allclose(jnp.mean(result[0], axis=0), 0, atol=1e-3)

# Compare with sklearn
model = SklearnPCA(n_components=n_components)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

类似tests文件,需要手动设置好sklearn的参数

Additional number of random vectors to sample the range of A so as
to ensure proper conditioning. The total number of random vectors
used to find the range of A is n_components + n_oversamples
max_power_iter : int, default=100
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

默认值为300

@deadlywing
Copy link
Contributor

另外,需要您在emulations目录下重新copy一份config,
https://github.com/secretflow/spu/blob/main/examples/python/conf/3pc.json

然后修改里面的field和fxp_fraction_bits;
然后在创建emulator的时候把这个config文件路径传进去...

Omega = random_matrix / scale[0]
Q = rsvd_iteration(A, Omega, scale[1], n_iter)
B = jnp.dot(Q.T, A)
u_tilde, s, v = svd(B, eigh_iter)
u = jnp.dot(Q, u_tilde)
return u, s, v
u = jnp.dot(Q, u_tilde)[:, :]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

u = jnp.dot(Q, u_tilde)就可以了?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个配置文件请放到emulations目录下

@deadlywing
Copy link
Contributor

def emul_rsvdPCA(mode: emulation.Mode.MULTIPROCESS):
    print("emul rsvdPCA.")

    def proc(X, random_matrix, n_components, n_oversamples, max_power_iter, scale):
        model = PCA(
            method='rsvd',
            n_components=n_components,
            n_oversamples=n_oversamples,
            random_matrix=random_matrix,
            max_power_iter=max_power_iter,
            scale=scale,
        )

        model.fit(X)
        X_transformed = model.transform(X)
        X_variances = model._variances
        X_reconstructed = model.inverse_transform(X_transformed)

        return X_transformed, X_variances, X_reconstructed

    try:
        # bandwidth and latency only work for docker mode
        conf_path = "sml/decomposition/emulations/3pc.json"
        emulator = emulation.Emulator(conf_path, mode, bandwidth=300, latency=20)
        emulator.up()

        # Create a simple dataset
        X = random.normal(random.PRNGKey(0), (1000, 20))
        X_spu = emulator.seal(X)
        n_components = 5
        n_oversamples = 10
        max_power_iter = 300
        scale = (10000000, 10000)

        # Create random_matrix
        random_state = np.random.RandomState(0)
        random_matrix = random_state.normal(
            size=(X.shape[1], n_components + n_oversamples)
        )
        random_matrix_spu = emulator.seal(random_matrix)

        result = emulator.run(proc, static_argnums=(2, 3, 4, 5))(
            X_spu, random_matrix_spu, n_components, n_oversamples, max_power_iter, scale
        )
        print("X_transformed_spu: ", result[0][:5, :])
        print("X_variance_spu: ", result[1])
        print("X_reconstructed_spu:", result[2][:5, :])

        # The transformed data should have 2 dimensions
        assert result[0].shape[1] == n_components

        # The mean of the transformed data should be approximately 0
        assert jnp.allclose(jnp.mean(result[0], axis=0), 0, atol=1e-3)

        # Compare with sklearn
        model = SklearnPCA(
            n_components=n_components,
            svd_solver="randomized",
            power_iteration_normalizer="QR",
            random_state=0,
        )
        model.fit(X)
        X_transformed = model.transform(X)
        X_variances = model.explained_variance_
        X_reconstructed = model.inverse_transform(X_transformed)

        print("X_transformed_sklearn: ", X_transformed[:5, :])
        print("X_variances_sklearn: ", X_variances)
        print("X_reconstructed_sklearn: ", X_reconstructed[:5, :])

        assert np.allclose(X_reconstructed, result[2], atol=1e-1)

    finally:
        emulator.down()

@deadlywing
Copy link
Contributor

我把你的emul部分重新整理了一下,,少fit一次,

@deadlywing
Copy link
Contributor

辛苦您重新在emulations的文件夹下建一个emul文件哈,

对应的BUILD.bazel也要增加2条哈,麻烦补充一下:

py_binary(
    name = "pca_emul",
    srcs = ["pca_emul.py"],  # 这里填你新建的py文件
    data = [":conf"],  
    deps = [
        "//sml/decomposition:pca",
        "//sml/utils:emulation",
    ],
)

filegroup(
    name = "conf",
    srcs = [
        "3pc.json",
    ],
)

@deadlywing
Copy link
Contributor

# Copyright 2023 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("@rules_python//python:defs.bzl", "py_binary")

package(default_visibility = ["//visibility:public"])

py_binary(
    name = "pca_emul",
    srcs = ["pca_emul.py"],
    deps = [
        "//sml/decomposition:pca",
        "//sml/utils:emulation",
    ],
)

py_binary(
    name = "rsvd_pca_emul",
    srcs = ["rsvd_pca_emul.py"],
    data = [":conf"],
    deps = [
        "//sml/decomposition:pca",
        "//sml/utils:emulation",
    ],
)

filegroup(
    name = "conf",
    srcs = [
        "3pc.json",
    ],
)

Copy link
Contributor

@deadlywing deadlywing left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@deadlywing deadlywing merged commit f726019 into secretflow:main Aug 15, 2023
6 checks passed
@github-actions github-actions bot locked and limited conversation to collaborators Aug 15, 2023
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

使用 SPU 优化 PCA 算法
2 participants