
Commit 3d660d8

further reduce torch symbolic execution overhead (#393)
Co-authored-by: Haifeng Jin <[email protected]>
haifeng-jin committed Jun 22, 2023
1 parent 7d282a2 commit 3d660d8
Showing 4 changed files with 34 additions and 29 deletions.
keras_core/backend/torch/core.py (1 addition, 1 deletion)
@@ -254,7 +254,7 @@ def vectorized_map(function, elements):
 def scatter(indices, values, shape):
     indices = convert_to_tensor(indices)
     values = convert_to_tensor(values)
-    zeros = torch.zeros(shape, dtype=values.dtype).to(get_device())
+    zeros = torch.zeros(shape, dtype=values.dtype, device=get_device())
 
     index_length = indices.shape[-1]
     value_shape = shape[index_length:]
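
The change above is the pattern this commit applies across the whole backend: pass device= at construction time instead of allocating on the default device and then copying with .to(). A minimal sketch of the two forms (the shape and the stand-in get_device() below are illustrative assumptions, not part of the diff):

import torch

def get_device():
    # Stand-in for keras_core's helper: returns "meta" during symbolic
    # execution and a real device ("cpu", "cuda", ...) during eager runs.
    return "cpu"

# Before: materialize on the default device, then copy to the target.
# When the target differs, this pays for an extra allocation and transfer.
zeros_before = torch.zeros((1024, 1024), dtype=torch.float32).to(get_device())

# After: a single allocation, directly on the target device.
zeros_after = torch.zeros((1024, 1024), dtype=torch.float32, device=get_device())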

keras_core/backend/torch/math.py (5 additions, 7 deletions)
@@ -7,9 +7,9 @@
 def segment_sum(data, segment_ids, num_segments=None, **kwargs):
     data = convert_to_tensor(data)
     segment_ids = convert_to_tensor(segment_ids)
-    num_repeats = (
-        torch.prod(torch.tensor(data.shape[1:])).long().to(get_device())
-    )
+    num_repeats = torch.prod(
+        torch.tensor(data.shape[1:], device=get_device())
+    ).long()
     # To use `scatter_add` in torch, we need to replicate `segment_ids` into the
     # shape of `data`.
     segment_ids = (
@@ -32,10 +32,8 @@ def segment_sum(data, segment_ids, num_segments=None, **kwargs):
     # Add one more dimension to the result shape with the "+1".
     shape = (num_segments + 1,) + tuple(data.shape[1:])
 
-    result = (
-        torch.zeros(*shape)
-        .to(get_device())
-        .scatter_add(0, segment_ids, data.float())
+    result = torch.zeros(*shape, device=get_device()).scatter_add(
+        0, segment_ids, data.float()
     )
 
     # Removing the extra dimension.
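
For context, segment_sum sums the rows of data that share a segment id; torch has no direct equivalent, hence the scatter_add formulation above. A small standalone illustration of that approach (the values are made up for the example):

import torch

data = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
segment_ids = torch.tensor([0, 0, 1])
num_segments = 2

# Replicate segment_ids to the shape of data so scatter_add can route
# each element of data into its segment's row.
index = segment_ids.unsqueeze(-1).expand_as(data)
result = torch.zeros(num_segments, data.shape[1]).scatter_add(0, index, data)
# result: tensor([[4., 6.], [5., 6.]])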

keras_core/backend/torch/numpy.py (16 additions, 12 deletions)
@@ -69,14 +69,14 @@ def ones(shape, dtype="float32"):
     dtype = to_torch_dtype(dtype)
     if isinstance(shape, int):
         shape = (shape,)
-    return torch.ones(size=shape, dtype=dtype).to(get_device())
+    return torch.ones(size=shape, dtype=dtype, device=get_device())
 
 
 def zeros(shape, dtype="float32"):
     dtype = to_torch_dtype(dtype)
     if isinstance(shape, int):
         shape = (shape,)
-    return torch.zeros(size=shape, dtype=dtype).to(get_device())
+    return torch.zeros(size=shape, dtype=dtype, device=get_device())
 
 
 def zeros_like(x, dtype=None):
@@ -144,8 +144,10 @@ def append(
 def arange(start, stop=None, step=1, dtype=None):
     dtype = to_torch_dtype(dtype)
     if stop is None:
-        return torch.arange(end=start, dtype=dtype)
-    return torch.arange(start, stop, step=step, dtype=dtype)
+        return torch.arange(end=start, dtype=dtype, device=get_device())
+    return torch.arange(
+        start, stop, step=step, dtype=dtype, device=get_device()
+    )
 
 
 def arccos(x):
@@ -310,7 +312,7 @@ def dot(x, y):
 
 def empty(shape, dtype="float32"):
     dtype = to_torch_dtype(dtype)
-    return torch.empty(size=shape, dtype=dtype)
+    return torch.empty(size=shape, dtype=dtype, device=get_device())
 
 
 def equal(x1, x2):
@@ -355,7 +357,9 @@ def full(shape, fill_value, dtype=None):
         expand_size = len(shape) - len(fill_value.shape)
         tile_shape = tuple(shape[:expand_size]) + (1,) * len(fill_value.shape)
         return torch.tile(fill_value, tile_shape)
-    return torch.full(size=shape, fill_value=fill_value, dtype=dtype)
+    return torch.full(
+        size=shape, fill_value=fill_value, dtype=dtype, device=get_device()
+    )
 
 
 def full_like(x, fill_value, dtype=None):
@@ -436,7 +440,7 @@ def linspace(
     if hasattr(start, "__len__") and hasattr(stop, "__len__"):
         start, stop = convert_to_tensor(start), convert_to_tensor(stop)
         stop = cast(stop, dtype) if endpoint is False and dtype else stop
-    steps = torch.arange(num, dtype=dtype).to(get_device()) / (num - 1)
+    steps = torch.arange(num, dtype=dtype, device=get_device()) / (num - 1)
 
     # reshape `steps` to allow for broadcasting
     for i in range(start.ndim):
@@ -510,7 +514,7 @@ def logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, axis=0):
     if hasattr(start, "__len__") and hasattr(stop, "__len__"):
         start, stop = convert_to_tensor(start), convert_to_tensor(stop)
         stop = cast(stop, dtype) if endpoint is False and dtype else stop
-    steps = torch.arange(num, dtype=dtype).to(get_device()) / (num - 1)
+    steps = torch.arange(num, dtype=dtype, device=get_device()) / (num - 1)
 
     # reshape `steps` to allow for broadcasting
     for i in range(start.ndim):
@@ -787,7 +791,7 @@ def trace(x, offset=None, axis1=None, axis2=None):
 def tri(N, M=None, k=0, dtype="float32"):
     dtype = to_torch_dtype(dtype)
     M = M or N
-    x = torch.ones((N, M), dtype=dtype)
+    x = torch.ones((N, M), dtype=dtype, device=get_device())
     return torch.tril(x, diagonal=k)
 
 
@@ -883,7 +887,7 @@ def eye(N, M=None, k=None, dtype="float32"):
     M = N if M is None else M
     k = 0 if k is None else k
     if k == 0:
-        return torch.eye(N, M, dtype=dtype).to(get_device())
+        return torch.eye(N, M, dtype=dtype, device=get_device())
     diag_length = np.maximum(N, M)
-    diag = torch.ones(diag_length, dtype=dtype)
-    return torch.diag(diag, diagonal=k)[:N, :M].to(get_device())
+    diag = torch.ones(diag_length, dtype=dtype, device=get_device())
+    return torch.diag(diag, diagonal=k)[:N, :M]
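
All of the numpy.py changes follow the same device-at-construction pattern. The payoff is largest during symbolic execution, where keras_core runs the torch backend on the "meta" device: a meta tensor tracks only shape and dtype, so creating it directly on "meta" allocates nothing, whereas creating it on CPU first and converting throws away a real allocation. A quick sketch (the sizes are arbitrary):

import torch

# A meta tensor has shape and dtype but no storage, so this "allocation"
# is free regardless of size.
x = torch.ones((4096, 4096), device="meta")
print(x.shape, x.dtype, x.device)  # torch.Size([4096, 4096]) torch.float32 meta

# By contrast, this first fills 64 MiB of real CPU memory, then discards it.
y = torch.ones((4096, 4096)).to("meta")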

keras_core/backend/torch/random.py (12 additions, 9 deletions)
@@ -102,10 +102,10 @@ def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
     x = normal(shape + (4,), mean=0, stddev=1, dtype=dtype, seed=seed)
     valid = (x > -2) & (x < 2)
     indexes = valid.max(-1, keepdim=True)[1]
-    trunc_x = torch.empty(shape)
+    trunc_x = torch.empty(shape, device=get_device())
     trunc_x.data.copy_(x.gather(-1, indexes).squeeze(-1))
     trunc_x.data.mul_(stddev).add_(mean)
-    return trunc_x.to(get_device())
+    return trunc_x
 
 
 def _get_concrete_noise_shape(inputs, noise_shape):
@@ -122,16 +122,19 @@ def _get_concrete_noise_shape(inputs, noise_shape):
 
 
 def dropout(inputs, rate, noise_shape=None, seed=None):
-    seed, _ = draw_seed(seed)
-    generator = torch.Generator()
-    generator.manual_seed(int(seed))
-
     keep_prob = 1.0 - rate
     noise_shape = _get_concrete_noise_shape(inputs, noise_shape)
-    keep_prob_matrix = torch.full(noise_shape, keep_prob)
-    mask = torch.bernoulli(keep_prob_matrix, generator=generator).bool()
+    keep_prob_matrix = torch.full(noise_shape, keep_prob, device=get_device())
+    generator = torch_seed_generator(seed)
+
+    # Do not use generator during symbolic execution.
+    if get_device() == "meta":
+        mask = torch.bernoulli(keep_prob_matrix)
+    else:
+        mask = torch.bernoulli(keep_prob_matrix, generator=generator)
+
+    mask = mask.bool()
     mask = torch.broadcast_to(mask, inputs.shape)
-    mask = mask.to(get_device())
     return torch.where(
         mask, inputs / keep_prob, torch.zeros_like(inputs, dtype=inputs.dtype)
     )
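
The dropout rewrite also sidesteps RNG state during symbolic execution: a torch.Generator is bound to a real device, so the diff skips it when sampling on "meta", where only shapes matter anyway. A minimal sketch of the guard (framing it as a standalone helper is my own illustration, not keras_core's API, and the exact failure mode of a generator on "meta" is an assumption):

import torch

def bernoulli_mask(keep_prob_matrix, generator=None):
    # On "meta", bernoulli only needs to propagate shape/dtype; skip the
    # generator, which is tied to a real device's RNG state.
    if keep_prob_matrix.device.type == "meta":
        return torch.bernoulli(keep_prob_matrix).bool()
    return torch.bernoulli(keep_prob_matrix, generator=generator).bool()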
