Skip to content

Commit

Permalink
Remove union shorthand
Browse files Browse the repository at this point in the history
  • Loading branch information
Sam Anklesaria committed Jun 24, 2024
1 parent f0cbaa9 commit 1d16f4e
Showing 1 changed file with 16 additions and 12 deletions.
28 changes: 16 additions & 12 deletions test/contrib/hsgp/test_approximation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from functools import reduce
from operator import mul
from typing import Literal
from typing import Literal, Union

import numpy as np
import pytest
Expand Down Expand Up @@ -117,7 +117,11 @@ def synthetic_two_dim_data() -> tuple[ArrayImpl, ArrayImpl]:
],
)
def test_kernel_approx_squared_exponential(
x1: ArrayImpl, x2: ArrayImpl, length: float | ArrayImpl, ell: float, xfail: bool
x1: ArrayImpl,
x2: ArrayImpl,
length: Union[float, ArrayImpl],
ell: float,
xfail: bool,
):
"""ensure that the approximation of the squared exponential kernel is accurate,
matching the exact kernel implementation from sklearn.
Expand All @@ -140,7 +144,7 @@ def test_kernel_approx_squared_exponential(
def _exact_rbf(length):
return RBF(length)(x1, x2).squeeze(axis=-1)

if isinstance(length, float | int):
if isinstance(length, Union[float, int]):
exact = _exact_rbf(length)
elif length.ndim == 1:
exact = _exact_rbf(length)
Expand Down Expand Up @@ -218,7 +222,7 @@ def test_kernel_approx_squared_matern(
def _exact_matern(length):
return Matern(length_scale=length, nu=nu)(x1, x2).squeeze(axis=-1)

if isinstance(length, float | int):
if isinstance(length, Union[float, int]):
exact = _exact_matern(length)
elif length.ndim == 1:
exact = _exact_matern(length)
Expand Down Expand Up @@ -280,8 +284,8 @@ def test_approximation_squared_exponential(
x: ArrayImpl,
alpha: float,
length: float,
ell: int | float | list[int | float],
m: int | list[int],
ell: Union[int, float, list[Union[int, float]]],
m: Union[int, list[int]],
non_centered: bool,
):
def model(x, alpha, length, ell, m, non_centered):
Expand Down Expand Up @@ -332,8 +336,8 @@ def test_approximation_matern(
nu: float,
alpha: float,
length: float,
ell: int | float | list[int | float],
m: int | list[int],
ell: Union[int, float, list[Union[int, float]]],
m: Union[int, list[int]],
non_centered: bool,
):
def model(x, nu, alpha, length, ell, m, non_centered):
Expand Down Expand Up @@ -375,8 +379,8 @@ def model(x, nu, alpha, length, ell, m, non_centered):
def test_squared_exponential_gp_model(
synthetic_one_dim_data,
synthetic_two_dim_data,
ell: float | int | list[float | int],
m: int | list[int],
ell: Union[float, int, list[Union[float, int]]],
m: Union[int, list[int]],
non_centered: bool,
num_dim: Literal[1, 2],
):
Expand Down Expand Up @@ -433,8 +437,8 @@ def test_matern_gp_model(
synthetic_one_dim_data,
synthetic_two_dim_data,
nu: float,
ell: int | float | list[float | int],
m: int | list[int],
ell: Union[int, float, list[Union[float, int]]],
m: Union[int, list[int]],
non_centered: bool,
num_dim: Literal[1, 2],
):
Expand Down

0 comments on commit 1d16f4e

Please sign in to comment.