Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

转换规则 No. 10 #170

Closed
wants to merge 47 commits into from
Closed
Show file tree
Hide file tree
Changes from 39 commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
7220d31
fix
enkilee Jul 12, 2023
421082c
fix
enkilee Jul 13, 2023
94052b2
fix
enkilee Jul 13, 2023
240c942
Merge branch 'master' into master
enkilee Jul 13, 2023
932539b
fix
enkilee Jul 13, 2023
4f9fe54
remove stft
enkilee Jul 13, 2023
f0514ae
add stft
enkilee Jul 14, 2023
f19c1e9
fix
enkilee Jul 14, 2023
8aece8b
fix
enkilee Jul 15, 2023
d290778
fix
enkilee Jul 15, 2023
af29221
fix
enkilee Jul 15, 2023
f211c16
fix
enkilee Jul 15, 2023
d4afe68
fix
enkilee Jul 15, 2023
305310d
fix
enkilee Jul 15, 2023
6b141f5
fix
enkilee Jul 18, 2023
5391062
fix
enkilee Jul 18, 2023
cc447a0
fix
enkilee Jul 18, 2023
16405d7
fix
enkilee Jul 18, 2023
99948c7
fix
enkilee Jul 18, 2023
480b26f
fix
enkilee Jul 18, 2023
e468b93
fix
enkilee Jul 18, 2023
2b42e83
fix
enkilee Jul 18, 2023
ecc1b9c
Merge branch 'optest-stft' of https://github.com/enkilee/PaConvert in…
enkilee Jul 18, 2023
0d0ffc2
fix
enkilee Jul 18, 2023
dd27616
fix
enkilee Jul 19, 2023
beaa6fa
fix
enkilee Jul 19, 2023
bfb3dfc
fix
enkilee Jul 19, 2023
48941d3
fix
enkilee Jul 19, 2023
4f328f5
fix
enkilee Jul 19, 2023
0904eff
fix
enkilee Jul 19, 2023
a7bd400
fix
enkilee Jul 19, 2023
e014679
fix
enkilee Jul 19, 2023
567f49b
fix
enkilee Jul 19, 2023
ad58800
fix
enkilee Jul 19, 2023
68611af
fix
enkilee Jul 19, 2023
205a64d
fix
enkilee Jul 19, 2023
c7be7c7
fix
enkilee Jul 19, 2023
81708fb
fix
enkilee Jul 19, 2023
22ec03f
fix
enkilee Jul 20, 2023
242b317
Merge branch 'master' into optest-stft
enkilee Aug 1, 2023
1ecc4c5
fix
enkilee Aug 1, 2023
26dab40
fix
enkilee Aug 1, 2023
5b0e18e
add unitest
enkilee Aug 7, 2023
e678380
add unittest
enkilee Aug 8, 2023
4de4947
Merge branch 'PaddlePaddle:master' into optest-stft
enkilee Aug 30, 2023
e6483e4
CI
enkilee Sep 5, 2023
6da74b8
Merge branch 'master' into optest-stft
enkilee Oct 20, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions paconvert/api_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -9090,6 +9090,25 @@
"out"
]
},
"torch.stft": {
"Matcher": "StftMatcher",
"paddle_api": "paddle.signal.stft",
"args_list": [
"input",
"n_fft",
"hop_length",
"win_length",
"window",
"center",
"pad_mode",
"normalized",
"onesided",
"return_complex"
],
"kwargs_change": {
"input": "x"
}
},
"torch.sub": {
"Matcher": "SubMatcher",
"args_list": [
Expand Down
61 changes: 61 additions & 0 deletions paconvert/api_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3071,6 +3071,67 @@ def generate_code(self, kwargs):
return code


class StftMatcher(BaseMatcher):
def generate_code(self, kwargs):
if "hop_length" not in kwargs:
enkilee marked this conversation as resolved.
Show resolved Hide resolved
kwargs["hop_length"] = None
enkilee marked this conversation as resolved.
Show resolved Hide resolved
if "win_length" not in kwargs:
kwargs["win_length"] = None
if "window" not in kwargs:
kwargs["window"] = None
if "center" not in kwargs:
kwargs["center"] = "True"
if "pad_mode" not in kwargs:
kwargs["pad_mode"] = "'reflect'"
if "normalized" not in kwargs:
kwargs["normalized"] = "False"

return_complex_temp = (
kwargs.pop("return_complex") if "return_complex" in kwargs else None
)

if return_complex_temp:
enkilee marked this conversation as resolved.
Show resolved Hide resolved
kwargs["onesided"] = "True" if return_complex_temp != "(False)" else "False"

if "out" in kwargs and kwargs["out"] is not None:
API_TEMPLATE = textwrap.dedent(
"""
paddle.assign((paddle.signal.stft(x={}, n_fft={}, hop_length={}, win_length={}, window={}, center={}, pad_mode={}, normalized={}, onesided={})), output={})
"""
)
code = API_TEMPLATE.format(
kwargs["input"],
kwargs["n_fft"],
kwargs["hop_length"],
kwargs["win_length"],
kwargs["window"],
kwargs["center"],
kwargs["pad_mode"],
kwargs["normalized"],
kwargs["onesided"],
kwargs["out"],
)
else:
API_TEMPLATE = textwrap.dedent(
"""
paddle.signal.stft(x={}, n_fft={}, hop_length={}, win_length={}, window={}, center={}, pad_mode={}, normalized={}, onesided={})
"""
)
n_fft = get_unique_name("n_fft")
enkilee marked this conversation as resolved.
Show resolved Hide resolved
code = API_TEMPLATE.format(
kwargs["input"],
kwargs["n_fft"],
kwargs["hop_length"],
kwargs["win_length"],
kwargs["window"],
kwargs["center"],
kwargs["pad_mode"],
kwargs["normalized"],
kwargs["onesided"],
)
return code


class CovMatcher(BaseMatcher):
def generate_code(self, kwargs):
if "input" not in kwargs:
Expand Down
91 changes: 91 additions & 0 deletions tests/test_stft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import textwrap

from apibase import APIBase

obj = APIBase("torch.stft")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([[ 1.3398, 0.2663, -0.2686, 0.2450],
[-0.7401, -0.8805, -0.3402, -1.1936],
[ 0.4907, -1.3948, -1.0691, -0.3132],
[-1.6092, 0.5419, -0.2993, 0.3195]])
n_fft = 4
result = torch.stft(x, n_fft=n_fft, return_complex=True)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([[ 1.3398, 0.2663, -0.2686, 0.2450],
[-0.7401, -0.8805, -0.3402, -1.1936],
[ 0.4907, -1.3948, -1.0691, -0.3132],
[-1.6092, 0.5419, -0.2993, 0.3195]])
n_fft = 4
hop_length = 4
result = torch.stft(x, n_fft=n_fft, hop_length=hop_length, return_complex=True)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([[ 1.3398, 0.2663, -0.2686, 0.2450],
[-0.7401, -0.8805, -0.3402, -1.1936],
[ 0.4907, -1.3948, -1.0691, -0.3132],
[-1.6092, 0.5419, -0.2993, 0.3195]])
n_fft = 4
win_length = 4
result = torch.stft(x, n_fft=n_fft, win_length=win_length, return_complex=True)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_4():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

单测case需要加一些,这个API参数很多,可以写七八个case

infoflow 2023-08-03 12-31-08

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已加。CI已过,请有空审核

pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([[ 1.3398, 0.2663, -0.2686, 0.2450],
[-0.7401, -0.8805, -0.3402, -1.1936],
[ 0.4907, -1.3948, -1.0691, -0.3132],
[-1.6092, 0.5419, -0.2993, 0.3195]])
n_fft = 4
result = torch.stft(x, n_fft=n_fft, center=False, return_complex=True)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这四种情况需要都测到:

  1. 全部指定关键字
  2. 全部不指定关键字
  3. 改变关键字顺序
  4. 默认参数全部省略

"""
)
obj.run(pytorch_code, ["result"])