Skip to content

Commit

Permalink
Add very initial comparison
Browse files Browse the repository at this point in the history
  • Loading branch information
ntjohnson1 committed Dec 23, 2023
1 parent eb6a74d commit f67c7d4
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 5 deletions.
2 changes: 2 additions & 0 deletions py-tensor/benchmark/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Initial Benchmarking
This is a crude starting point to comparing runtimes to pyttb.
45 changes: 45 additions & 0 deletions py-tensor/benchmark/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import pyttb as ttb
from tensor_ext import Dense, Kruskal, cp_als
from typing import Tuple
import time
import numpy as np

def run_pyttb(source: ttb.tensor, init: ttb.ktensor, rank: int) -> ttb.ktensor:
M, _, _ = ttb.cp_als(source, rank, init=init, printitn=-1)
return M

def run_rusty(source:Dense, init: Kruskal, rank:int) -> Kruskal:
M = cp_als(source, init, rank)
return M

def benchmark(shape: Tuple[int,...], num_iters: int):
pyttb_time = 0.0
rusty_time = 0.0
for _ in range(num_iters):
data = np.random.random(shape)
rank = 2
weights = np.ones((rank,))
factors = tuple(np.random.random((first, rank)) for first in shape)
py_tensor = ttb.tensor(data)
py_ktensor = ttb.ktensor(list(factors), weights)
rust_dense = Dense(data)
rust_kruskal = Kruskal(weights, factors)

# TODO should provide init here for fair comparison
start = time.time()
pyttb_result = run_pyttb(py_tensor, py_ktensor, rank)
pyttb_time += time.time() - start

start = time.time()
rusty_result = run_rusty(rust_dense, rust_kruskal, rank)
rusty_time += time.time() - start

np.testing.assert_allclose(pyttb_result.full().data, rusty_result.full().data)
print(
f"Pyttb time: {pyttb_time}\n"
f"Rust time: {rusty_time}\n"
f"Shape: {shape}"
)

if __name__=="__main__":
benchmark((100, 100, 100), 10)
1 change: 1 addition & 0 deletions py-tensor/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ dependencies = [
[project.optional-dependencies]
dev = [
"pytest",
"pyttb@git+https://github.com/sandialabs/pyttb",
]
8 changes: 3 additions & 5 deletions src/cp/cp_als.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ pub fn cp_als(
Array::<f64, Ix2>::zeros((input_tensor.shape[dimorder[dimorder.len() - 1]], rank));

if printitn > 0 {
print!("CP ALS:");
print!("CP ALS:\n");
}

// Main Loop: Iterate until convergence
Expand Down Expand Up @@ -65,11 +65,10 @@ pub fn cp_als(
y = y * utu.slice(s![.., .., i]);
}
}

if y.abs_diff_eq(&Array::<f64, Ix2>::zeros((rank, rank)), 1e-8) {
if y.abs_diff_eq(&Array::<f64, Ix2>::zeros((rank, rank)), 1e-16) {
factors_new = Array::<f64, Ix2>::zeros(factors_new.raw_dim());
} else {
for i in 0..ndim {
for i in 0..input_tensor.shape[*n] {
// TODO using same y every time so update to more efficient pre-factor
let mut update = factors_new.slice_mut(s![i, ..]);
update.assign(&y.t().solve(&update.t()).unwrap().t());
Expand All @@ -92,7 +91,6 @@ pub fn cp_als(
}

factors[*n] = factors_new;
//FIXME: Left off on defining this update
utu.slice_mut(s![.., .., *n])
.assign(&factors[*n].t().dot(&factors[*n]));
}
Expand Down

0 comments on commit f67c7d4

Please sign in to comment.