Skip to content

Commit

Permalink
Fix inference using OV backend
Browse files Browse the repository at this point in the history
Signed-off-by: Kazantsev, Roman <[email protected]>
  • Loading branch information
rkazants committed Nov 1, 2024
1 parent 22d5c17 commit fa6b461
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 2 deletions.
2 changes: 2 additions & 0 deletions keras/src/backend/openvino/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from keras.src.backend.common.name_scope import name_scope
from keras.src.backend.openvino import core
from keras.src.backend.openvino import image
from keras.src.backend.openvino import linalg
from keras.src.backend.openvino import math
from keras.src.backend.openvino import nn
from keras.src.backend.openvino import numpy
from keras.src.backend.openvino import random
from keras.src.backend.openvino.core import IS_THREAD_SAFE
from keras.src.backend.openvino.core import SUPPORTS_SPARSE_TENSORS
from keras.src.backend.openvino.core import Variable
from keras.src.backend.openvino.core import cast
Expand Down
3 changes: 3 additions & 0 deletions keras/src/backend/openvino/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from keras.src.backend.common.stateless_scope import StatelessScope

SUPPORTS_SPARSE_TENSORS = False
IS_THREAD_SAFE = True

OPENVINO_DTYPES = {
"float16": ov.Type.f16,
Expand Down Expand Up @@ -116,6 +117,8 @@ def convert_to_numpy(x):


def is_tensor(x):
if isinstance(x, (ov.runtime.Output)):
return False
if isinstance(x, ov.Tensor):
return True
return False
Expand Down
6 changes: 4 additions & 2 deletions keras/src/backend/openvino/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def _get_compiled_model(self):
# prepare compiled model from scratch
del self.ov_compiled_model
ov_inputs = []
parameters = []
for _input in self._inputs:
ov_type = OPENVINO_DTYPES[_input.dtype]
ov_shape = _input.shape
Expand All @@ -79,13 +80,14 @@ def _get_compiled_model(self):
if ov_shape[i] is None:
ov_shape[i] = -1
param = ov_opset.parameter(shape=ov_shape, dtype=ov_type)
ov_inputs.append(param)
parameters.append(param)
ov_inputs.append(param.output(0))
# build OpenVINO graph ov.Model
ov_outputs = self._run_through_graph(
ov_inputs, operation_fn=lambda op: op
)
ov_outputs = tree.flatten(ov_outputs)
ov_model = ov.Model(results=ov_outputs, parameters=ov_inputs)
ov_model = ov.Model(results=ov_outputs, parameters=parameters)
self.ov_compiled_model = ov.compile_model(ov_model, get_device())
self.ov_device = get_device()
return self.ov_compiled_model
Expand Down

0 comments on commit fa6b461

Please sign in to comment.