Skip to content

Commit

Permalink
Add some type hints (zhelyabuzhsky#40)
Browse files Browse the repository at this point in the history
* Add ': Stockfish' type hints in test_models.py.

* Fix mypy errors.

* Add various type hints in models.py.

* Minor test_models.py updates.
  • Loading branch information
johndoknjas authored May 31, 2023
1 parent 417e57d commit 5d4f89b
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 116 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
!LICENSE

/*.exe
/Notes/

.DS_Store
*.pyc
Expand Down
71 changes: 36 additions & 35 deletions stockfish/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from __future__ import annotations
import subprocess
from typing import Any, List, Optional
from typing import Any, List, Optional, Union
import copy
from os import path
from dataclasses import dataclass
Expand Down Expand Up @@ -53,9 +53,9 @@ def __init__(
"UCI_LimitStrength": False,
"UCI_Elo": 1350,
}
self._debug_view = debug_view
self._debug_view: bool = debug_view

self._path = path
self._path: str = path
self._stockfish = subprocess.Popen(
self._path,
universal_newlines=True,
Expand All @@ -64,7 +64,7 @@ def __init__(
stderr=subprocess.STDOUT,
)

self._has_quit_command_been_sent = False
self._has_quit_command_been_sent: bool = False

self._stockfish_major_version: int = int(
self._read_line().split(" ")[1].split(".")[0].replace("-", "")
Expand Down Expand Up @@ -346,10 +346,10 @@ def get_board_visual(self, perspective_white: bool = True) -> str:
```
"""
self._put("d")
board_rep_lines = []
count_lines = 0
board_rep_lines: List[str] = []
count_lines: int = 0
while count_lines < 17:
board_str = self._read_line()
board_str: str = self._read_line()
if "+" in board_str or "|" in board_str:
count_lines += 1
if perspective_white:
Expand Down Expand Up @@ -481,7 +481,7 @@ def set_num_nodes(self, num_nodes: int = 1000000) -> None:
or num_nodes < 1
):
raise TypeError("num_nodes must be an integer higher than 0")
self._num_nodes = num_nodes
self._num_nodes: int = num_nodes

def get_num_nodes(self) -> int:
"""Returns configured number of nodes to search
Expand Down Expand Up @@ -582,8 +582,8 @@ def _is_fen_syntax_valid(fen: str) -> bool:
if len(regexList[0].split("/")) != 8:
return False # 8 rows not present.
for fenPart in regexList[0].split("/"):
field_sum = 0
previous_was_digit = False
field_sum: int = 0
previous_was_digit: bool = False
for c in fenPart:
if c in ["1", "2", "3", "4", "5", "6", "7", "8"]:
if previous_was_digit:
Expand All @@ -610,10 +610,10 @@ def is_fen_valid(self, fen: str) -> bool:
"""
if not Stockfish._is_fen_syntax_valid(fen):
return False
temp_sf = Stockfish(path=self._path, parameters={"Hash": 1})
temp_sf: Stockfish = Stockfish(path=self._path, parameters={"Hash": 1})
# Using a new temporary SF instance, in case the fen is an illegal position that causes
# the SF process to crash.
best_move = None
best_move: Optional[str] = None
temp_sf.set_fen_position(fen, False)
try:
temp_sf._put("go depth 10")
Expand Down Expand Up @@ -649,7 +649,7 @@ def is_move_correct(self, move_value: str) -> bool:
self.info = old_self_info
return is_move_correct

def get_wdl_stats(self) -> Optional[List]:
def get_wdl_stats(self) -> Optional[List[int]]:
"""Returns Stockfish's win/draw/loss stats for the side to move.
Returns:
Expand All @@ -667,7 +667,7 @@ def get_wdl_stats(self) -> Optional[List]:
)

self._go()
lines = []
lines: List[List[str]] = []
while True:
text = self._read_line()
splitted_text = text.split(" ")
Expand All @@ -681,7 +681,7 @@ def get_wdl_stats(self) -> Optional[List]:
index_of_multipv = current_line.index("multipv")
if current_line[index_of_multipv + 1] == "1" and "wdl" in current_line:
index_of_wdl = current_line.index("wdl")
wdl_stats = []
wdl_stats: List[int] = []
for i in range(1, 4):
wdl_stats.append(int(current_line[index_of_wdl + i]))
return wdl_stats
Expand Down Expand Up @@ -720,14 +720,14 @@ def get_evaluation(self) -> dict:
+ """ get_evaluation will still return full strength Stockfish's evaluation of the position."""
)

compare = (
compare: int = (
1 if self.get_turn_perspective() or ("w" in self.get_fen_position()) else -1
)
# If the user wants the evaluation specified relative to who is to move, this will be done.
# Otherwise, the evaluation will be in terms of white's side (positive meaning advantage white,
# negative meaning advantage black).
self._go()
evaluation = dict()
evaluation: dict = dict()
while True:
text = self._read_line()
splitted_text = text.split(" ")
Expand All @@ -751,7 +751,7 @@ def get_static_eval(self) -> Optional[float]:
"""

# Stockfish gives the static eval from white's perspective:
compare = (
compare: int = (
1
if not self.get_turn_perspective() or ("w" in self.get_fen_position())
else -1
Expand Down Expand Up @@ -813,8 +813,8 @@ def get_top_moves(
)

# remember global values
old_multipv = self._parameters["MultiPV"]
old_num_nodes = self._num_nodes
old_multipv: int = self._parameters["MultiPV"]
old_num_nodes: int = self._num_nodes

# to get number of top moves, we use Stockfish's MultiPV option (i.e., multiple principal variations).
# set MultiPV to num_top_moves requested
Expand All @@ -828,7 +828,7 @@ def get_top_moves(
self._num_nodes = num_nodes
self._go_nodes()

lines = []
lines: List[List[str]] = []

# parse output into a list of lists
# this loop will run until Stockfish has finished evaluating the position
Expand All @@ -846,7 +846,7 @@ def get_top_moves(

# set perspective of evaluations. if get_turn_perspective() is True, or white to move,
# use Stockfish's values, otherwise invert values.
perspective = (
perspective: int = (
1 if self.get_turn_perspective() or ("w" in self.get_fen_position()) else -1
)

Expand All @@ -872,7 +872,7 @@ def get_top_moves(
if (num_nodes > 0) and (int(self._pick(line, "nodes")) < self._num_nodes):
break

move_evaluation = {
move_evaluation: dict[str, Union[str, int, None]] = {
# get move
"Move": self._pick(line, "pv"),
# get cp if available
Expand Down Expand Up @@ -916,7 +916,7 @@ def get_top_moves(

return top_moves

def _pick(self, line: list, value: str = "", index: int = 1) -> str:
def _pick(self, line: list[str], value: str = "", index: int = 1) -> str:
return line[line.index(value) + index]

def get_what_is_on_square(self, square: str) -> Optional[Piece]:
Expand All @@ -934,8 +934,8 @@ def get_what_is_on_square(self, square: str) -> Optional[Piece]:
>>> piece = stockfish.get_what_is_on_square("e2")
"""

file_letter = square[0].lower()
rank_num = int(square[1])
file_letter: str = square[0].lower()
rank_num: int = int(square[1])
if (
len(square) != 2
or file_letter < "a"
Expand All @@ -946,12 +946,9 @@ def get_what_is_on_square(self, square: str) -> Optional[Piece]:
raise ValueError(
"square argument to the get_what_is_on_square function isn't valid."
)
rank_visual = self.get_board_visual().splitlines()[17 - 2 * rank_num]
piece_as_char = rank_visual[2 + (ord(file_letter) - ord("a")) * 4]
if piece_as_char == " ":
return None
else:
return Stockfish.Piece(piece_as_char)
rank_visual: str = self.get_board_visual().splitlines()[17 - 2 * rank_num]
piece_as_char: str = rank_visual[2 + (ord(file_letter) - ord("a")) * 4]
return None if piece_as_char == " " else Stockfish.Piece(piece_as_char)

def will_move_be_a_capture(self, move_value: str) -> Capture:
"""Returns whether the proposed move will be a direct capture,
Expand All @@ -973,9 +970,13 @@ def will_move_be_a_capture(self, move_value: str) -> Capture:
"""
if not self.is_move_correct(move_value):
raise ValueError("The proposed move is not valid in the current position.")
starting_square_piece = self.get_what_is_on_square(move_value[:2])
ending_square_piece = self.get_what_is_on_square(move_value[2:4])
if ending_square_piece != None:
starting_square_piece: Optional[Stockfish.Piece] = self.get_what_is_on_square(
move_value[:2]
)
ending_square_piece: Optional[Stockfish.Piece] = self.get_what_is_on_square(
move_value[2:4]
)
if ending_square_piece is not None:
if not self._parameters["UCI_Chess960"]:
return Stockfish.Capture.DIRECT_CAPTURE
else:
Expand Down
Loading

0 comments on commit 5d4f89b

Please sign in to comment.