Skip to content

Commit

Permalink
mtie should output the correct rank and size for the output operator.
Browse files Browse the repository at this point in the history
Fixes #725
  • Loading branch information
luitjens committed Aug 20, 2024
1 parent 15c7e82 commit 50fb9db
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 3 deletions.
4 changes: 2 additions & 2 deletions include/matx/core/tie.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,12 @@ struct mtie : public BaseOp<mtie<Ts...>>{

static __MATX_INLINE__ constexpr int32_t Rank()
{
return matxNoRank;
return decltype(cuda::std::get<0>(ts_))::Rank();
}

constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ auto Size([[maybe_unused]] int dim) const noexcept
{
return 0;
return cuda::std::get<0>(ts_).Size(dim);
}

template <typename Executor>
Expand Down
54 changes: 53 additions & 1 deletion test/00_transform/Norm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,22 @@ TYPED_TEST(NormTestFloatTypes, VectorL1)
this->pb->RunTVGenerator("vector_l1");
this->pb->NumpyToTensorView(this->in_v, "in_v");
this->pb->NumpyToTensorView(this->out_v, "out_v");

auto redOp = vector_norm(this->in_v, NormOrder::L1);

EXPECT_TRUE(redOp.Rank() == this->out_v.Rank());
for(int i = 0; i < redOp.Rank(); i++) {
EXPECT_TRUE(redOp.Size(i) == this->out_v.Size(i));
}
(this->out_v = redOp).run(this->exec);
MATX_TEST_ASSERT_COMPARE(this->pb, this->out_v, "out_v", this->thresh);

(this->out_v = TestType(0)).run(this->exec);

// example-begin vector-norm-test-1
(this->out_v = vector_norm(this->in_v, NormOrder::L1)).run(this->exec);
// example-end vector-norm-test-1

MATX_TEST_ASSERT_COMPARE(this->pb, this->out_v, "out_v", this->thresh);

MATX_EXIT_HANDLER();
Expand All @@ -102,6 +114,19 @@ TYPED_TEST(NormTestFloatTypes, VectorL2)
this->pb->RunTVGenerator("vector_l2");
this->pb->NumpyToTensorView(this->in_v, "in_v");
this->pb->NumpyToTensorView(this->out_v, "out_v");

auto redOp = vector_norm(this->in_v, NormOrder::L2);

EXPECT_TRUE(redOp.Rank() == this->out_v.Rank());
for(int i = 0; i < redOp.Rank(); i++) {
EXPECT_TRUE(redOp.Size(i) == this->out_v.Size(i));
}
(this->out_v = redOp).run(this->exec);

MATX_TEST_ASSERT_COMPARE(this->pb, this->out_v, "out_v", this->thresh);

(this->out_v = TestType(0)).run(this->exec);

// example-begin vector-norm-test-2
(this->out_v = vector_norm(this->in_v, NormOrder::L2)).run(this->exec);
// example-end vector-norm-test-2
Expand All @@ -119,6 +144,20 @@ TYPED_TEST(NormTestFloatTypes, MatrixL1)
this->pb->RunTVGenerator("matrix_l1");
this->pb->NumpyToTensorView(this->in_m, "in_m");
this->pb->NumpyToTensorView(this->out_m, "out_m");

auto redOp = matrix_norm(this->in_m, NormOrder::L1);

EXPECT_TRUE(redOp.Rank() == this->out_v.Rank());
for(int i = 0; i < redOp.Rank(); i++) {
EXPECT_TRUE(redOp.Size(i) == this->out_v.Size(i));
}

(this->out_m = redOp).run(this->exec);

MATX_TEST_ASSERT_COMPARE(this->pb, this->out_m, "out_m", this->thresh);

(this->out_v = TestType(0)).run(this->exec);

// example-begin matrix-norm-test-1
(this->out_m = matrix_norm(this->in_m, NormOrder::L1)).run(this->exec);
// example-end matrix-norm-test-1
Expand All @@ -136,6 +175,19 @@ TYPED_TEST(NormTestFloatTypes, MatrixL2)
this->pb->RunTVGenerator("matrix_frob");
this->pb->NumpyToTensorView(this->in_m, "in_m");
this->pb->NumpyToTensorView(this->out_m, "out_m");

auto redOp = matrix_norm(this->in_m, NormOrder::FROB);

EXPECT_TRUE(redOp.Rank() == this->out_v.Rank());
for(int i = 0; i < redOp.Rank(); i++) {
EXPECT_TRUE(redOp.Size(i) == this->out_v.Size(i));
}
(this->out_m = redOp).run(this->exec);

MATX_TEST_ASSERT_COMPARE(this->pb, this->out_m, "out_m", this->thresh);

(this->out_v = TestType(0)).run(this->exec);

// example-begin matrix-norm-test-2
(this->out_m = matrix_norm(this->in_m, NormOrder::FROB)).run(this->exec);
// example-end matrix-norm-test-2
Expand Down

0 comments on commit 50fb9db

Please sign in to comment.