Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BE] Fix all flake8 violations in smoke_test.py #1553

Merged
merged 1 commit into from
Sep 29, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 24 additions & 16 deletions test/smoke_test/smoke_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
},
]


class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
Expand All @@ -53,6 +54,7 @@ def forward(self, x):
output = self.fc1(x)
return output


def check_version(package: str) -> None:
# only makes sense to check nightly package where dates are known
if channel == "nightly":
Expand All @@ -65,32 +67,33 @@ def check_version(package: str) -> None:
else:
print(f"Skip version check for channel {channel} as stable version is None")


def check_nightly_binaries_date(package: str) -> None:
from datetime import datetime, timedelta
format_dt = '%Y%m%d'

torch_str = torch.__version__
date_t_str = re.findall("dev\d+", torch.__version__)
date_t_str = re.findall("dev\\d+", torch.__version__)
date_t_delta = datetime.now() - datetime.strptime(date_t_str[0][3:], format_dt)
if date_t_delta.days >= NIGHTLY_ALLOWED_DELTA:
raise RuntimeError(
f"the binaries are from {date_t_str} and are more than {NIGHTLY_ALLOWED_DELTA} days old!"
)

if(package == "all"):
if package == "all":
for module in MODULES:
imported_module = importlib.import_module(module["name"])
module_version = imported_module.__version__
date_m_str = re.findall("dev\d+", module_version)
date_m_str = re.findall("dev\\d+", module_version)
date_m_delta = datetime.now() - datetime.strptime(date_m_str[0][3:], format_dt)
print(f"Nightly date check for {module['name']} version {module_version}")
if date_m_delta.days > NIGHTLY_ALLOWED_DELTA:
raise RuntimeError(
f"Expected {module['name']} to be less then {NIGHTLY_ALLOWED_DELTA} days. But its {date_m_delta}"
)


def test_cuda_runtime_errors_captured() -> None:
cuda_exception_missed=True
cuda_exception_missed = True
try:
print("Testing test_cuda_runtime_errors_captured")
torch._assert_async(torch.tensor(0, device="cuda"))
Expand All @@ -101,14 +104,15 @@ def test_cuda_runtime_errors_captured() -> None:
cuda_exception_missed = False
else:
raise e
if(cuda_exception_missed):
raise RuntimeError( f"Expected CUDA RuntimeError but have not received!")
if cuda_exception_missed:
raise RuntimeError("Expected CUDA RuntimeError but have not received!")


def smoke_test_cuda(package: str, runtime_error_check: str) -> None:
if not torch.cuda.is_available() and is_cuda_system:
raise RuntimeError(f"Expected CUDA {gpu_arch_ver}. However CUDA is not loaded.")

if(package == 'all' and is_cuda_system):
if package == 'all' and is_cuda_system:
for module in MODULES:
imported_module = importlib.import_module(module["name"])
# TBD for vision move extension module to private so it will
Expand All @@ -131,12 +135,10 @@ def smoke_test_cuda(package: str, runtime_error_check: str) -> None:
print(f"cuDNN enabled? {torch.backends.cudnn.enabled}")

# torch.compile is available only on Linux and python 3.8-3.10
if (sys.platform == "linux" or sys.platform == "linux2") and sys.version_info < (3, 11, 0) and channel == "release":
smoke_test_compile()
elif (sys.platform == "linux" or sys.platform == "linux2") and channel != "release":
if sys.platform in ["linux", "linux2"] and (sys.version_info < (3, 11, 0) or channel != "release"):
smoke_test_compile()

if(runtime_error_check == "enabled"):
if runtime_error_check == "enabled":
test_cuda_runtime_errors_captured()


Expand All @@ -148,6 +150,7 @@ def smoke_test_conv2d() -> None:
m = nn.Conv2d(16, 33, 3, stride=2)
# non-square kernels and unequal stride and with padding
m = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
assert m is not None
# non-square kernels and unequal stride and with padding and dilation
basic_conv = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
input = torch.randn(20, 16, 50, 100)
Expand All @@ -156,16 +159,19 @@ def smoke_test_conv2d() -> None:
if is_cuda_system:
print("Testing smoke_test_conv2d with cuda")
conv = nn.Conv2d(3, 3, 3).cuda()
x = torch.randn(1, 3, 24, 24).cuda()
x = torch.randn(1, 3, 24, 24, device="cuda")
with torch.cuda.amp.autocast():
out = conv(x)
assert out is not None

supported_dtypes = [torch.float16, torch.float32, torch.float64]
for dtype in supported_dtypes:
print(f"Testing smoke_test_conv2d with cuda for {dtype}")
conv = basic_conv.to(dtype).cuda()
input = torch.randn(20, 16, 50, 100, device="cuda").type(dtype)
output = conv(input)
assert output is not None


def smoke_test_linalg() -> None:
print("Testing smoke_test_linalg")
Expand All @@ -189,10 +195,13 @@ def smoke_test_linalg() -> None:
A = torch.randn(20, 16, 50, 100, device="cuda").type(dtype)
torch.linalg.svd(A)


def smoke_test_compile() -> None:
supported_dtypes = [torch.float16, torch.float32, torch.float64]

def foo(x: torch.Tensor) -> torch.Tensor:
return torch.sin(x) + torch.cos(x)

for dtype in supported_dtypes:
print(f"Testing smoke_test_compile for {dtype}")
x = torch.rand(3, 3, device="cuda").type(dtype)
Expand All @@ -209,6 +218,7 @@ def foo(x: torch.Tensor) -> torch.Tensor:
model = Net().to(device="cuda")
x_pt2 = torch.compile(model, mode="max-autotune")(x)


def smoke_test_modules():
cwd = os.getcwd()
for module in MODULES:
Expand All @@ -224,9 +234,7 @@ def smoke_test_modules():
smoke_test_command, stderr=subprocess.STDOUT, shell=True,
universal_newlines=True)
except subprocess.CalledProcessError as exc:
raise RuntimeError(
f"Module {module['name']} FAIL: {exc.returncode} Output: {exc.output}"
)
raise RuntimeError(f"Module {module['name']} FAIL: {exc.returncode} Output: {exc.output}")
else:
print("Output: \n{}\n".format(output))

Expand Down
Loading