Skip to content

Commit

Permalink
Merge pull request #24 from thom311/th/mypy
Browse files Browse the repository at this point in the history
[th/mypy] fix typing annotations so that `mypy --strict` passes and enable github action
  • Loading branch information
bn222 authored Oct 2, 2024
2 parents ffe67d9 + ca83dcb commit ae97860
Show file tree
Hide file tree
Showing 15 changed files with 126 additions and 84 deletions.
7 changes: 7 additions & 0 deletions .github/workflows/lint-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ jobs:
python -m pip install --upgrade pip
python -m pip install black
python -m pip install flake8
python -m pip install mypy
python -m pip install -r requirements.txt
python -m pip install types-paramiko
python -m pip install types-requests
- name: Check code formatting with Black
run: |
black --version
Expand All @@ -29,3 +32,7 @@ jobs:
run: |
flake8 --version
flake8
- name: Type check with Mypy
run: |
mypy --version
mypy --strict --config-file mypy.ini
2 changes: 1 addition & 1 deletion bfb
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import time
import common_bf


def main():
def main() -> None:
parser = argparse.ArgumentParser(
description="Downloads BFB images and sends it to the BF."
)
Expand Down
41 changes: 26 additions & 15 deletions common_bf.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,35 @@
import dataclasses
import os
import shlex
import subprocess
import sys

from collections import namedtuple
from typing import Optional


def run(cmd: str, env: dict = os.environ.copy()):
Result = namedtuple("Result", "out err returncode")
@dataclasses.dataclass(frozen=True)
class Result:
out: str
err: str
returncode: int


def run(cmd: str, env: dict[str, str] = os.environ.copy()) -> Result:
args = shlex.split(cmd)
pipe = subprocess.PIPE
with subprocess.Popen(args, stdout=pipe, stderr=pipe, env=env) as proc:
out = proc.stdout.read().decode("utf-8")
err = proc.stderr.read().decode("utf-8")
proc.communicate()
ret = proc.returncode
return Result(out, err, ret)
res = subprocess.run(
args,
capture_output=True,
env=env,
)

return Result(
out=res.stdout.decode("utf-8"),
err=res.stderr.decode("utf-8"),
returncode=res.returncode,
)


def all_interfaces():
def all_interfaces() -> dict[str, str]:
out = run("lshw -c network -businfo").out
ret = {}
for e in out.split("\n")[2:]:
Expand All @@ -32,13 +43,13 @@ def all_interfaces():
return ret


def find_bf_pci_addresses():
def find_bf_pci_addresses() -> list[str]:
ai = all_interfaces()
bfs = [e for e in ai.items() if "BlueField" in e[1]]
return [k.split("@")[1] for k, v in bfs]


def find_bf_pci_addresses_or_quit(bf_id):
def find_bf_pci_addresses_or_quit(bf_id: int) -> str:
bf_pci = find_bf_pci_addresses()
if not bf_pci:
print("No BF found")
Expand All @@ -49,7 +60,7 @@ def find_bf_pci_addresses_or_quit(bf_id):
return bf_pci[bf_id]


def mst_flint(pci):
def mst_flint(pci: str) -> dict[str, str]:
out = run(f"mstflint -d {pci} q").out
ret = {}
for e in out.split("\n"):
Expand All @@ -67,7 +78,7 @@ def mst_flint(pci):
return ret


def bf_version(pci):
def bf_version(pci: str) -> Optional[int]:
out = run("lshw -c network -businfo").out
for e in out.split("\n"):
if not e.startswith(f"pci@{pci}"):
Expand Down
2 changes: 1 addition & 1 deletion console
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import os
import common_bf


def main():
def main() -> None:
parser = argparse.ArgumentParser(
description="Select BF to connect to with a console."
)
Expand Down
2 changes: 1 addition & 1 deletion cx_fwup
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import os
import sys


def main():
def main() -> None:
os.system("chmod +x mlxup")
r = os.system("/mlxup -y")
sys.exit(r)
Expand Down
37 changes: 23 additions & 14 deletions dpu-tools/dpu-tools
Original file line number Diff line number Diff line change
@@ -1,33 +1,42 @@
#!/usr/bin/env python3

import argparse
import dataclasses
import os
import re
import shlex
import shutil
import subprocess
import tempfile

from collections import namedtuple

@dataclasses.dataclass(frozen=True)
class Result:
out: str
err: str
returncode: int

def run(cmd: str, env: dict = os.environ.copy()):
Result = namedtuple("Result", "out err returncode")

def run(cmd: str, env: dict[str, str] = os.environ.copy()) -> Result:
args = shlex.split(cmd)
pipe = subprocess.PIPE
with subprocess.Popen(args, stdout=pipe, stderr=pipe, env=env) as proc:
out = proc.stdout.read().decode("utf-8")
err = proc.stderr.read().decode("utf-8")
proc.communicate()
ret = proc.returncode
return Result(out, err, ret)
res = subprocess.run(
args,
capture_output=True,
env=env,
)

return Result(
out=res.stdout.decode("utf-8"),
err=res.stderr.decode("utf-8"),
returncode=res.returncode,
)


def reset(args):
def reset(args: argparse.Namespace) -> None:
run("ssh [email protected] sudo reboot")


def console(args):
def console(args: argparse.Namespace) -> None:
if args.target == "imc":
minicom_cmd = "minicom -b 460800 -D /dev/ttyUSB2"
else:
Expand Down Expand Up @@ -64,7 +73,7 @@ def find_bus_pci_address(address: str) -> str:
return "Invalid PCI address format"


def list_dpus(args):
def list_dpus(args: argparse.Namespace) -> None:
del args
devs = {}
for e in run("lspci").out.split("\n"):
Expand All @@ -81,7 +90,7 @@ def list_dpus(args):
print(f"{i: 5d} {k.ljust(8)} {d.ljust(12)} {kind}")


def main():
def main() -> None:
parser = argparse.ArgumentParser(description="Tools to interact with an IPU")
subparsers = parser.add_subparsers(
title="subcommands", description="Valid subcommands", dest="subcommand"
Expand Down
2 changes: 1 addition & 1 deletion fwdefaults
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import os
import common_bf


def main():
def main() -> None:
parser = argparse.ArgumentParser(
description="Resets the firmware settings on the BF to defaults."
)
Expand Down
21 changes: 12 additions & 9 deletions fwup
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,33 @@ import argparse
import requests
import sys

from typing import Any

import common_bf


class RemoteAPI:
def __init__(self, bf_version):
def __init__(self, bf_version: int):
self._remote_url = f"https://downloaders.azurewebsites.net/downloaders/bluefield{bf_version}_fw_downloader/helper.php"

def get_latest_version(self):
def get_latest_version(self) -> str:
data = {
"action": "get_versions",
}
response = requests.post(self._remote_url, data=data)
return response.json()["latest"]
s = response.json()["latest"]
assert isinstance(s, str)
return s

def get_distros(self, v):
def get_distros(self, v: str) -> Any:
data = {
"action": "get_distros",
"version": v,
}
r = requests.post(self._remote_url, data=data)

return r.json()

def get_os(self, version, distro):
def get_os(self, version: str, distro: str) -> Any:
data = {
"action": "get_oses",
"version": version,
Expand All @@ -36,7 +39,7 @@ class RemoteAPI:
r = requests.post(self._remote_url, data=data)
return r.json()[0]

def get_download_info(self, version, distro, os_param):
def get_download_info(self, version: str, distro: str, os_param: str) -> Any:
data = {
"action": "get_download_info",
"version": version,
Expand All @@ -48,7 +51,7 @@ class RemoteAPI:
return r.json()


def update_bf_firmware(args):
def update_bf_firmware(args: argparse.Namespace) -> int:
bf = common_bf.find_bf_pci_addresses_or_quit(args.id)
target_psid = common_bf.mst_flint(bf)["PSID"]
bf_version = common_bf.bf_version(bf)
Expand Down Expand Up @@ -91,7 +94,7 @@ def update_bf_firmware(args):
return 0


def main():
def main() -> None:
parser = argparse.ArgumentParser(
description="Specify the id of the BF. Updates the firmware on the BF to the latest avaible one."
)
Expand Down
2 changes: 1 addition & 1 deletion fwversion
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import argparse
import common_bf


def main():
def main() -> None:
parser = argparse.ArgumentParser(description="Shows firmware version.")
parser.add_argument(
"-i",
Expand Down
2 changes: 1 addition & 1 deletion get_mode
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import argparse
import common_bf


def main():
def main() -> None:
parser = argparse.ArgumentParser(description="Reads the current mode of the BF.")
parser.add_argument(
"-i",
Expand Down
2 changes: 1 addition & 1 deletion listbf
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import common_bf


def main():
def main() -> None:
bf2s = common_bf.find_bf_pci_addresses()
print("ID PCI-Address")
print("----- ------------")
Expand Down
7 changes: 7 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[mypy]
strict = true
scripts_are_modules = true
files = *.py, bfb, console, cx_fwup, fwdefaults, fwup, fwversion, get_mode, listbf, pxeboot, reset, set_mode, dpu-tools/dpu-tools

[mypy-pexpect]
ignore_missing_imports = true
Loading

0 comments on commit ae97860

Please sign in to comment.