Skip to content

Commit

Permalink
feat(api): add a way for the server to disable certain platforms (#83)
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jan 25, 2023
1 parent 43787f0 commit 67d51a9
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 1 deletion.
2 changes: 1 addition & 1 deletion api/onnx_web/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def load_platforms():

providers = get_available_providers()
available_platforms = [p for p in platform_providers if (
platform_providers[p] in providers)]
platform_providers[p] in providers and p not in context.block_platforms)]

print('available acceleration platforms: %s' % (available_platforms))

Expand Down
3 changes: 3 additions & 0 deletions api/onnx_web/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,15 @@ def __init__(
params_path: str = '.',
cors_origin: str = '*',
num_workers: int = 1,
block_platforms: List[str] = [],
) -> None:
self.bundle_path = bundle_path
self.model_path = model_path
self.output_path = output_path
self.params_path = params_path
self.cors_origin = cors_origin
self.num_workers = num_workers
self.block_platforms = block_platforms

@classmethod
def from_environ(cls):
Expand All @@ -81,6 +83,7 @@ def from_environ(cls):
# others
cors_origin=environ.get('ONNX_WEB_CORS_ORIGIN', '*').split(','),
num_workers=int(environ.get('ONNX_WEB_NUM_WORKERS', 1)),
block_platforms=environ.get('ONNX_WEB_BLOCK_PLATFORMS', '').split(',')
)


Expand Down

0 comments on commit 67d51a9

Please sign in to comment.