diff --git a/paddle/fluid/operators/frame_op.cc b/paddle/fluid/operators/frame_op.cc new file mode 100644 index 0000000000000..850854144099c --- /dev/null +++ b/paddle/fluid/operators/frame_op.cc @@ -0,0 +1,182 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/frame_op.h" + +namespace paddle { +namespace operators { + +class FrameOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "frame"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "frame"); + + const int frame_length = ctx->Attrs().Get<int>("frame_length"); + const int hop_length = ctx->Attrs().Get<int>("hop_length"); + const int axis = ctx->Attrs().Get<int>("axis"); + + const auto x_dims = ctx->GetInputDim("X"); + const int x_rank = x_dims.size(); + + PADDLE_ENFORCE_GT(hop_length, 0, + platform::errors::InvalidArgument( + "Attribute(hop_length) of FrameOp should be greater " + "than 0, but got %s.", + hop_length)); + PADDLE_ENFORCE_EQ( + (axis == 0 || axis == -1), true, + platform::errors::InvalidArgument( + "Attribute(axis) of FrameOp should 0 or -1, but got %s.", axis)); + + std::vector<int64_t> output_shape; + int seq_length; + int n_frames; + + int start_axis; + int end_axis; + + if (axis == 0) { + seq_length = x_dims[0]; +
start_axis = 1; + end_axis = x_rank - 1; + } else { + seq_length = x_dims[x_rank - 1]; + start_axis = 0; + end_axis = x_rank - 2; + } + + PADDLE_ENFORCE_LE(frame_length, seq_length, + platform::errors::InvalidArgument( + "Attribute(frame_length) of FrameOp should be less " + "equal than sequence length, but got (%s) > (%s).", + frame_length, seq_length)); + + // It won't go into for loop when x_rank == 1U. + for (int i = start_axis; i <= end_axis; i++) { + output_shape.push_back(x_dims[i]); + } + + n_frames = 1 + (seq_length - frame_length) / hop_length; + + if (axis == 0) { + // (n_frames, frame_length, ...) + output_shape.insert(output_shape.begin(), frame_length); + output_shape.insert(output_shape.begin(), n_frames); + } else { + // (..., frame_length, n_frames) + output_shape.push_back(frame_length); + output_shape.push_back(n_frames); + } + + ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + const auto in_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + return framework::OpKernelType(in_dtype, ctx.GetPlace()); + } +}; + +class FrameOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of frame op."); + AddOutput("Out", "(Tensor), The output tensor of frame op."); + AddAttr<int>("frame_length", + "Frame Length" + "Other doc of frame length arg..."); + AddAttr<int>("hop_length", + "Hop Length" + "Other doc of hop length arg..."); + AddAttr<int>("axis", + "Axis" + "Other doc of axis arg...") + .SetDefault(-1); + AddComment(R"DOC( + Frame Operator. + + This operator is used to slice frame of input $X$. 
+ + )DOC"); + } +}; + +class FrameOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "frame_grad"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", + "Out@GRAD", "frame_grad"); + const auto x_dims = ctx->GetInputDim("X"); + if (ctx->HasOutput(framework::GradVarName("X"))) { + ctx->SetOutputDim(framework::GradVarName("X"), x_dims); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + const auto in_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + return framework::OpKernelType(in_dtype, ctx.GetPlace()); + } +}; + +template <typename T> +class FrameOpGradMaker : public framework::SingleGradOpMaker<T> { + public: + using framework::SingleGradOpMaker<T>::SingleGradOpMaker; + + void Apply(GradOpPtr<T> retv) const override { + retv->SetType("frame_grad"); + retv->SetInput("X", this->Input("X")); + retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + retv->SetAttrMap(this->Attrs()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(frame, ops::FrameOp, ops::FrameOpMaker, + ops::FrameOpGradMaker<paddle::framework::OpDesc>, + ops::FrameOpGradMaker<paddle::imperative::OpBase>); + +REGISTER_OPERATOR(frame_grad, ops::FrameOpGrad); + +REGISTER_OP_CPU_KERNEL( + frame, ops::FrameKernel<paddle::platform::CPUDeviceContext, int>, + ops::FrameKernel<paddle::platform::CPUDeviceContext, int64_t>, + ops::FrameKernel<paddle::platform::CPUDeviceContext, float>, + ops::FrameKernel<paddle::platform::CPUDeviceContext, double>, + ops::FrameKernel<paddle::platform::CPUDeviceContext, + paddle::platform::complex<float>>, + ops::FrameKernel<paddle::platform::CPUDeviceContext, + paddle::platform::complex<double>>); + +REGISTER_OP_CPU_KERNEL( + frame_grad, ops::FrameGradKernel<paddle::platform::CPUDeviceContext, float>, + ops::FrameGradKernel<paddle::platform::CPUDeviceContext, double>, + ops::FrameGradKernel<paddle::platform::CPUDeviceContext, + paddle::platform::complex<float>>, + ops::FrameGradKernel<paddle::platform::CPUDeviceContext, + paddle::platform::complex<double>>); diff --git a/paddle/fluid/operators/frame_op.cu b/paddle/fluid/operators/frame_op.cu new file mode 100644 index 0000000000000..203cc757ce687 --- /dev/null +++ 
b/paddle/fluid/operators/frame_op.cu @@ -0,0 +1,40 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/frame_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_CUDA_KERNEL( + frame, ops::FrameKernel<paddle::platform::CUDADeviceContext, int>, + ops::FrameKernel<paddle::platform::CUDADeviceContext, int64_t>, + ops::FrameKernel<paddle::platform::CUDADeviceContext, + paddle::platform::float16>, + ops::FrameKernel<paddle::platform::CUDADeviceContext, float>, + ops::FrameKernel<paddle::platform::CUDADeviceContext, double>, + ops::FrameKernel<paddle::platform::CUDADeviceContext, + paddle::platform::complex<float>>, + ops::FrameKernel<paddle::platform::CUDADeviceContext, + paddle::platform::complex<double>>); + +REGISTER_OP_CUDA_KERNEL( + frame_grad, + ops::FrameGradKernel<paddle::platform::CUDADeviceContext, + paddle::platform::float16>, + ops::FrameGradKernel<paddle::platform::CUDADeviceContext, float>, + ops::FrameGradKernel<paddle::platform::CUDADeviceContext, double>, + ops::FrameGradKernel<paddle::platform::CUDADeviceContext, + paddle::platform::complex<float>>, + ops::FrameGradKernel<paddle::platform::CUDADeviceContext, + paddle::platform::complex<double>>); diff --git a/paddle/fluid/operators/frame_op.h b/paddle/fluid/operators/frame_op.h new file mode 100644 index 0000000000000..aa2cee7bee474 --- /dev/null +++ b/paddle/fluid/operators/frame_op.h @@ -0,0 +1,490 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#pragma once + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/for_range.h" + +namespace paddle { +namespace operators { +using Tensor = framework::Tensor; + +template <typename T> +struct DataMappingFunctor { + DataMappingFunctor(const T* x, T* out, size_t seq_length, size_t frame_length, + size_t n_frames, size_t hop_length) + : x_(x), + out_(out), + seq_length_(seq_length), + frame_length_(frame_length), + n_frames_(n_frames), + hop_length_(hop_length) {} + + /* + Convert sequences to frames. + + 1. Dimension information: + + Sequences Frames + (N, seq_length) -> (N, frame_length, n_frames) + + 2. Mapping from `i` to `src_idx` and `trg_idx` can be derived from: + + a. Notion + - `i` stands for the flattened index of a bunch of frames. + - `src_idx` and `trg_idx` are the 1D indices of seqs and frames + respectively. + + b. Sample idx + ```cpp + sample_idx = i / (n_frames_ * frame_length_); + ``` + + c. Maps `i` to `f` and `n`. + ```cpp + f = i % (n_frames_ * frame_length_) / n_frames_; + n = i % (n_frames_ * frame_length_) % n_frames_; + ``` + + d. Replace `sample_idx`, `f` and `n` in the equations followed. + ```cpp + src_idx = sample_idx * seq_length_ + n * hop_length_ + f; + trg_idx = sample_idx * n_frames_ * frame_length_ + f * n_frames_ + n; + out_[trg_idx] = x_[src_idx]; + ``` + + e. Result can be deduced shown in the function body below. 
+ */ + HOSTDEVICE void operator()(size_t i) const { + size_t src_idx; + size_t trg_idx; + src_idx = i / (n_frames_ * frame_length_) * seq_length_ + + i % (n_frames_ * frame_length_) % n_frames_ * hop_length_ + + i % (n_frames_ * frame_length_) / n_frames_; + trg_idx = i / (n_frames_ * frame_length_) * n_frames_ * frame_length_ + + i % (n_frames_ * frame_length_) / n_frames_ * n_frames_ + + i % (n_frames_ * frame_length_) % n_frames_; + out_[trg_idx] = x_[src_idx]; + } + + const T* x_; + T* out_; + size_t seq_length_; + size_t frame_length_; + size_t n_frames_; + size_t hop_length_; +}; + +template <typename T> +struct DataMappingGradFunctor { + DataMappingGradFunctor(const T* d_out, T* d_x, size_t seq_length, + size_t frame_length, size_t n_frames, + size_t hop_length) + : d_out_(d_out), + d_x_(d_x), + seq_length_(seq_length), + frame_length_(frame_length), + n_frames_(n_frames), + hop_length_(hop_length) {} + + /* + Accumulate output gradient d_out to d_x. + + 1. Dimension information: + + d_out d_x + (N, frame_length, n_frames) -> (N, seq_length) + + 2. Using a sliding window to find source indices from `d_out` according to + `i`: + + a. Notion + - `i` stands for the flattened index of `d_x`. + - `seq_i` stands for a relative index of a `d_x` sample. + - `left`: Starting index of a frame window. + - `right`: Ending index of a frame window. + + b. Sample idx + ```cpp + sample_idx = i / seq_length_; + ``` + + c. Slides a window with length of `frame_length` to find `f` and `n`. + - `n`: The idx of num_frames_, increases in each hop. + - `f`: The idx of frame_lengths_, relative idx from left of a sliding + window. + + d. Accumulate all grads from d_out. 
+ ```cpp + d_x_[i] += + d_out_[sample_idx * frame_length_ * n_frames_ + f * n_frames_ + n]; + ``` + */ + HOSTDEVICE void operator()(size_t i) const { + size_t sample_idx = i / seq_length_; + size_t seq_i = i % seq_length_; + + // Sliding window + d_x_[i] = 0; // Init d_x_[i] to 0, and a while loop followed to sum up all + // grads from d_out_. + + size_t n = get_start_frame_idx(seq_i); + size_t f; + size_t left = n * hop_length_; + size_t right = left + frame_length_ - 1; + + while (left <= seq_i && right < seq_length_) { + f = seq_i - left; + d_x_[i] += + d_out_[sample_idx * frame_length_ * n_frames_ + f * n_frames_ + n]; + // Next frame. + left += hop_length_; + right += hop_length_; + n += 1; + } + } + + /* + Calculate minimum value of frame index `n` to satisfy the inequality: + + seq_i <= right + ==> seq_i <= left + frame_length - 1 + ==> seq_i <= hop_length_ * n + frame_length_ - 1 + */ + HOSTDEVICE size_t get_start_frame_idx(size_t seq_i) const { + int64_t tmp = seq_i + 1 - frame_length_; + if (tmp > 0) { + size_t n = tmp / hop_length_; + if (tmp % hop_length_ == 0) { + return n; + } else { + return n + 1; + } + } else { + return 0; + } + } + + const T* d_out_; + T* d_x_; + size_t seq_length_; + size_t frame_length_; + size_t n_frames_; + size_t hop_length_; +}; + +template <typename DeviceContext, typename T> +struct FrameFunctor { + void operator()(const DeviceContext& dev_ctx, const Tensor* input, + Tensor* output, size_t seq_length, size_t frame_length, + size_t n_frames, size_t hop_length, + bool is_grad = false) const { + auto numel = output->numel(); + auto* input_data = input->data<T>(); + auto* output_data = output->data<T>(); + + platform::ForRange<DeviceContext> for_range(dev_ctx, numel); + if (!is_grad) { + DataMappingFunctor<T> functor(input_data, output_data, seq_length, + frame_length, n_frames, hop_length); + for_range(functor); + } else { + DataMappingGradFunctor<T> functor(input_data, output_data, seq_length, + frame_length, n_frames, hop_length); + for_range(functor); + } + } +}; + +template <typename DeviceContext, typename T> +static 
inline void TransCompute(const framework::ExecutionContext& ctx, + const Tensor& x, Tensor* out, + const std::vector<int>& perm) { + int rank = x.dims().size(); + PADDLE_ENFORCE_EQ((rank == 2 || rank == 3), true, + platform::errors::InvalidArgument( + "Rank of x should be 2 or 3, but got %d.", rank)); + + if (!out->IsInitialized()) { + auto dims_vec = framework::vectorize(x.dims()); + for (int i = 0; i < rank; ++i) { + dims_vec[i] = x.dims()[perm[i]]; + } + out->Resize(framework::make_ddim(dims_vec)); + out->mutable_data<T>(ctx.GetPlace()); + } + + auto& dev_ctx = ctx.device_context<DeviceContext>(); + + switch (rank) { + case 2: + math::Transpose<DeviceContext, T, 2> trans2; + trans2(dev_ctx, x, out, perm); + break; + case 3: + math::Transpose<DeviceContext, T, 3> trans3; + trans3(dev_ctx, x, out, perm); + break; + default: + break; + } +} + +template <typename DeviceContext, typename T> +class FrameKernel : public framework::OpKernel<T> { + public: + /* + Frame kernel slice frames from input sequences. The main steps as follow: + + - Case 1 - input dims == 1: + - axis is -1: Call a FrameFunctor to compute directly. + - axis is 0: Transpose output firstly, and then it falls into + case axis is -1. Finally, it restores the dims of + output tensor. + + - Case 2 - input dims == 2: + - axis is -1: Call a FrameFunctor to compute directly. + - axis is 0: Transpose both input and output firstly, and then it falls + into case axis is -1. Finally, it restores the dims of + output tensor. + + - Case 3 - input dims > 2: + Flatten the input and output to 2D and 3D respectively so that it + falls into Case 2. Finally, it restores the dims of output tensor. 
+ */ + void Compute(const framework::ExecutionContext& ctx) const override { + const Tensor* x = ctx.Input<Tensor>("X"); + Tensor* out = ctx.Output<Tensor>("Out"); + out->mutable_data<T>(ctx.GetPlace()); + const size_t x_rank = x->dims().size(); + const size_t out_rank = out->dims().size(); + + const int frame_length = ctx.Attr<int>("frame_length"); + const int hop_length = ctx.Attr<int>("hop_length"); + const int axis = ctx.Attr<int>("axis"); + const int n_frames = + (axis == 0) ? out->dims()[0] : out->dims()[out_rank - 1]; + const int seq_length = (axis == 0) ? x->dims()[0] : x->dims()[x_rank - 1]; + + auto& dev_ctx = ctx.device_context<DeviceContext>(); + + // When the number of input dims is larger than 2, it needs to copy + // from x to resize input into 2d and output into 3d. Moreover, output + // dims will be restored at the last step. + Tensor x_(x->type()); + x_ = *x; + + framework::DDim preserved_dims; + if (x_rank > 2) { + // Save dims used to flatten both input and output tensors and restore + // output tensor. + framework::DDim x_resized_dims; + framework::DDim out_resized_dims; + if (axis == 0) { + preserved_dims = framework::slice_ddim(x_.dims(), 1, x_rank); + x_resized_dims = {seq_length, framework::product(preserved_dims)}; + out_resized_dims = {n_frames, frame_length, + framework::product(preserved_dims)}; + } else { + preserved_dims = framework::slice_ddim(x_.dims(), 0, x_rank - 1); + x_resized_dims = {framework::product(preserved_dims), seq_length}; + out_resized_dims = {framework::product(preserved_dims), frame_length, + n_frames}; + } + x_.Resize(x_resized_dims); + out->Resize(out_resized_dims); + } + + Tensor trans_x(x_.type()); + Tensor trans_out(out->type()); + + // Transpose input and output in case that axis is 0. 
+ if (axis == 0) { + if (x_rank == 1U) { + trans_x = x_; + std::vector<int> perm_out{1, 0}; + TransCompute<DeviceContext, T>(ctx, *out, &trans_out, perm_out); + } else { + std::vector<int> perm_x{1, 0}; + TransCompute<DeviceContext, T>(ctx, x_, &trans_x, perm_x); + std::vector<int> perm_out{2, 1, 0}; + TransCompute<DeviceContext, T>(ctx, *out, &trans_out, perm_out); + } + } else { + trans_x = x_; + trans_out = *out; + } + + FrameFunctor<DeviceContext, T>()(dev_ctx, &trans_x, &trans_out, seq_length, + frame_length, n_frames, hop_length, + /*is_grad*/ false); + + // Transpose output in case axis is 0. + if (axis == 0) { + if (x_rank == 1U) { + std::vector<int> perm_out{1, 0}; + TransCompute<DeviceContext, T>(ctx, trans_out, out, perm_out); + } else { + std::vector<int> perm_out{2, 1, 0}; + TransCompute<DeviceContext, T>(ctx, trans_out, out, perm_out); + } + } + + // Restore output dims when the number of dims is larger than 2. + if (x_rank > 2) { + std::vector<int64_t> restored_out_shape; + for (int i = 0; i < preserved_dims.size(); i++) { + restored_out_shape.push_back(preserved_dims[i]); + } + + if (axis == 0) { + // (n_frames, frame_length, ...) + restored_out_shape.insert(restored_out_shape.begin(), frame_length); + restored_out_shape.insert(restored_out_shape.begin(), n_frames); + } else { + // (..., frame_length, n_frames) + restored_out_shape.push_back(frame_length); + restored_out_shape.push_back(n_frames); + } + + out->Resize(framework::make_ddim(restored_out_shape)); + } + } +}; + +template <typename DeviceContext, typename T> +class FrameGradKernel : public framework::OpKernel<T> { + public: + /* + Frame gradient kernel accumulate gradient `d_x` from `d_out`. The + main steps as follow: + + - Case 1 - d_x dims == 1: + - axis is -1: Call a FrameFunctor to compute directly. Notes that + `is_grad` is set to true to select gradient data functor. + - axis is 0: Transpose `d_out` firstly, and then it falls into + case axis is -1. + + - Case 2 - d_x dims == 2: + - axis is -1: Call a FrameFunctor to compute directly. + - axis is 0: Transpose both `d_x` and `d_out` firstly, and then it + falls into case axis is -1. 
Finally, it restores the + dims of `d_x`. + + - Case 3 - d_x dims > 2: + Flatten the `d_x` and `d_out` to 2D and 3D respectively so that it + falls into Case 2. Finally, it restores the dims of `d_x` tensor. + */ + void Compute(const framework::ExecutionContext& ctx) const { + const framework::Tensor* d_out = + ctx.Input<Tensor>(framework::GradVarName("Out")); + framework::Tensor* d_x = + ctx.Output<Tensor>(framework::GradVarName("X")); + d_x->mutable_data<T>(ctx.GetPlace()); + const size_t d_out_rank = d_out->dims().size(); + const size_t d_x_rank = d_x->dims().size(); + + const int frame_length = ctx.Attr<int>("frame_length"); + const int hop_length = ctx.Attr<int>("hop_length"); + const int axis = ctx.Attr<int>("axis"); + const int n_frames = + (axis == 0) ? d_out->dims()[0] : d_out->dims()[d_out_rank - 1]; + const int seq_length = + (axis == 0) ? d_x->dims()[0] : d_x->dims()[d_x_rank - 1]; + + auto& dev_ctx = ctx.device_context<DeviceContext>(); + + Tensor d_out_(d_out->type()); + d_out_ = *d_out; + + framework::DDim preserved_dims; + if (d_x_rank > 2) { + // Save dims used to flatten both input and output tensors and restore + // output tensor. + framework::DDim d_x_resized_dims; + framework::DDim d_out_resized_dims; + if (axis == 0) { + preserved_dims = framework::slice_ddim(d_x->dims(), 1, d_x_rank); + d_x_resized_dims = {seq_length, framework::product(preserved_dims)}; + d_out_resized_dims = {n_frames, frame_length, + framework::product(preserved_dims)}; + } else { + preserved_dims = framework::slice_ddim(d_x->dims(), 0, d_x_rank - 1); + d_x_resized_dims = {framework::product(preserved_dims), seq_length}; + d_out_resized_dims = {framework::product(preserved_dims), frame_length, + n_frames}; + } + d_x->Resize(d_x_resized_dims); + d_out_.Resize(d_out_resized_dims); + } + + Tensor trans_d_x(d_x->type()); + Tensor trans_d_out(d_out_.type()); + + // Transpose input and output in case that axis is 0. 
+ if (axis == 0) { + if (d_x_rank == 1U) { + trans_d_x = *d_x; + std::vector<int> perm_d_out{1, 0}; + TransCompute<DeviceContext, T>(ctx, d_out_, &trans_d_out, perm_d_out); + } else { + std::vector<int> perm_d_x{1, 0}; + TransCompute<DeviceContext, T>(ctx, *d_x, &trans_d_x, perm_d_x); + std::vector<int> perm_d_out{2, 1, 0}; + TransCompute<DeviceContext, T>(ctx, d_out_, &trans_d_out, perm_d_out); + } + } else { + trans_d_x = *d_x; + trans_d_out = d_out_; + } + + FrameFunctor<DeviceContext, T>()(dev_ctx, &trans_d_out, &trans_d_x, + seq_length, frame_length, n_frames, + hop_length, + /*is_grad*/ true); + + // Transpose output in case axis is 0. + if (axis == 0 && d_x_rank > 1U) { + std::vector<int> perm_d_x{1, 0}; + TransCompute<DeviceContext, T>(ctx, trans_d_x, d_x, perm_d_x); + } + + // Restore output dims when the number of dims is larger than 2. + if (d_x_rank > 2) { + std::vector<int64_t> restored_d_x_shape; + for (int i = 0; i < preserved_dims.size(); i++) { + restored_d_x_shape.push_back(preserved_dims[i]); + } + + if (axis == 0) { + // (seq_length, ...) + restored_d_x_shape.insert(restored_d_x_shape.begin(), seq_length); + } else { + // (..., seq_length) + restored_d_x_shape.push_back(seq_length); + } + + d_x->Resize(framework::make_ddim(restored_d_x_shape)); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_frame_op.py b/python/paddle/fluid/tests/unittests/test_frame_op.py new file mode 100644 index 0000000000000..ec478c9c834fc --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_frame_op.py @@ -0,0 +1,141 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from numpy.lib.stride_tricks import as_strided +import paddle +import unittest + +from op_test import OpTest + + +def frame_from_librosa(x, frame_length, hop_length, axis=-1): + if axis == -1 and not x.flags["F_CONTIGUOUS"]: + x = np.asfortranarray(x) + elif axis == 0 and not x.flags["C_CONTIGUOUS"]: + x = np.ascontiguousarray(x) + + n_frames = 1 + (x.shape[axis] - frame_length) // hop_length + strides = np.asarray(x.strides) + + new_stride = np.prod(strides[strides > 0] // x.itemsize) * x.itemsize + + if axis == -1: + shape = list(x.shape)[:-1] + [frame_length, n_frames] + strides = list(strides) + [hop_length * new_stride] + + elif axis == 0: + shape = [n_frames, frame_length] + list(x.shape)[1:] + strides = [hop_length * new_stride] + list(strides) + + else: + raise ValueError("Frame axis={} must be either 0 or -1".format(axis)) + + return as_strided(x, shape=shape, strides=strides) + + +class TestFrameOp(OpTest): + def setUp(self): + self.op_type = "frame" + self.shape, self.type, self.attrs = self.initTestCase() + self.inputs = { + 'X': np.random.random(size=self.shape).astype(self.type), + } + self.outputs = { + 'Out': frame_from_librosa( + x=self.inputs['X'], **self.attrs) + } + + def initTestCase(self): + input_shape = (150, ) + input_type = 'float64' + attrs = { + 'frame_length': 50, + 'hop_length': 15, + 'axis': -1, + } + return input_shape, input_type, attrs + + def test_check_output(self): + paddle.enable_static() + self.check_output() + paddle.disable_static() + + def test_check_grad_normal(self): + paddle.enable_static() 
+ self.check_grad(['X'], 'Out') + paddle.disable_static() + + +class TestCase1(TestFrameOp): + def initTestCase(self): + input_shape = (150, ) + input_type = 'float64' + attrs = { + 'frame_length': 50, + 'hop_length': 15, + 'axis': 0, + } + return input_shape, input_type, attrs + + +class TestCase2(TestFrameOp): + def initTestCase(self): + input_shape = (8, 150) + input_type = 'float64' + attrs = { + 'frame_length': 50, + 'hop_length': 15, + 'axis': -1, + } + return input_shape, input_type, attrs + + +class TestCase3(TestFrameOp): + def initTestCase(self): + input_shape = (150, 8) + input_type = 'float64' + attrs = { + 'frame_length': 50, + 'hop_length': 15, + 'axis': 0, + } + return input_shape, input_type, attrs + + +# FIXME(chenxiaojie06): There are bugs when input dims >= 3 in librosa. +# class TestCase3(TestFrameOp): +# def initTestCase(self): +# input_shape = (4, 2, 150) +# input_type = 'int32' +# attrs = { +# 'frame_length': 50, +# 'hop_length': 15, +# 'axis': -1, +# } +# return input_shape, input_type, attrs + +# class TestCase4(TestFrameOp): +# def initTestCase(self): +# input_shape = (150, 4, 2) +# input_type = 'int32' +# attrs = { +# 'frame_length': 50, +# 'hop_length': 15, +# 'axis': 0, +# } +# return input_shape, input_type, attrs + +if __name__ == '__main__': + unittest.main()