From f9efeaadaaf0abc53547e042a1c7079aa1643e06 Mon Sep 17 00:00:00 2001
From: zhen
Date: Fri, 23 Aug 2024 17:48:59 +0800
Subject: [PATCH] [experiment] Verify command injection when starting experiments asynchronously (#3685)

# Description

Added validation of the command parameters (`nodes`, `from_nodes` and `session`) before the orchestrator command is executed, rejecting values that contain risky characters.

# All Promptflow Contribution checklist:
- [ ] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other significant changes.**
- [ ] **I have read the [contribution guidelines](https://github.com/microsoft/promptflow/blob/main/CONTRIBUTING.md).**
- [ ] **I confirm that all new dependencies are compatible with the MIT license.**
- [ ] **Create an issue and link to the pull request to get dedicated review from promptflow team. Learn more: [suggested workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [ ] Title of the pull request is clear and informative.
- [ ] There are a small number of commits, each of which has an informative message. This means that previously merged commits do not appear in the history of the PR. For more information on cleaning up the commits in your PR, [see this page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [ ] Pull request includes test coverage for the included changes.
---
 .../_orchestrator/experiment_orchestrator.py  | 12 ++++++++++
 .../sdk_cli_test/e2etests/test_experiment.py  | 23 +++++++++++++++++++
 2 files changed, 35 insertions(+)

diff --git a/src/promptflow-devkit/promptflow/_sdk/_orchestrator/experiment_orchestrator.py b/src/promptflow-devkit/promptflow/_sdk/_orchestrator/experiment_orchestrator.py
index 978479d462b..7c427a78a47 100644
--- a/src/promptflow-devkit/promptflow/_sdk/_orchestrator/experiment_orchestrator.py
+++ b/src/promptflow-devkit/promptflow/_sdk/_orchestrator/experiment_orchestrator.py
@@ -6,6 +6,7 @@
 import json
 import os
 import platform
+import re
 import signal
 import subprocess
 import sys
@@ -414,6 +415,14 @@ def async_start(self, executable_path=None, nodes=None, from_nodes=None, attempt
         :return: Experiment info.
         :rtype: ~promptflow.entities.Experiment
         """
+        def _params_inject_validation(params, param_name):
+            # Reject values that could inject extra commands: only letters, digits, spaces, underscores and
+            # dashes are allowed. fullmatch (not match + "$") is required so a trailing newline is rejected too.
+            pattern = r"[a-zA-Z0-9 _\-]*"
+            for item in params:
+                if not re.fullmatch(pattern, item):
+                    raise ExperimentValueError(f"Invalid character found in the parameter {params} of {param_name}.")
+
         # Setup file handler
         file_handler, index = _set_up_experiment_log_handler(experiment_path=self.experiment._output_dir, index=attempt)
         logger.addHandler(file_handler._stream_handler)
@@ -423,10 +432,13 @@ def async_start(self, executable_path=None, nodes=None, from_nodes=None, attempt
         executable_path = executable_path or sys.executable
         args = [executable_path, __file__, "start", "--experiment", self.experiment.name]
         if nodes:
+            _params_inject_validation(nodes, "nodes")
             args = args + ["--nodes"] + nodes
         if from_nodes:
+            _params_inject_validation(from_nodes, "from-nodes")
             args = args + ["--from-nodes"] + from_nodes
         if kwargs.get("session"):
+            _params_inject_validation(kwargs.get("session"), "session")
             args = args + ["--session", kwargs.get("session")]
         args = args + ["--attempt", str(index)]
         # Start an orchestrator process using detach mode
diff --git a/src/promptflow-devkit/tests/sdk_cli_test/e2etests/test_experiment.py b/src/promptflow-devkit/tests/sdk_cli_test/e2etests/test_experiment.py
index 14a74f59b0b..40c3d4758b1 100644
--- a/src/promptflow-devkit/tests/sdk_cli_test/e2etests/test_experiment.py
+++ b/src/promptflow-devkit/tests/sdk_cli_test/e2etests/test_experiment.py
@@ -217,6 +217,29 @@ def test_experiment_start_from_nodes(self):
         assert len(exp.node_runs["main"]) == 3
         assert len(exp.node_runs["echo"]) == 2
 
+    @pytest.mark.usefixtures("use_secrets_config_file", "recording_injection", "setup_local_connection")
+    def test_experiment_start_with_command_injection(self):
+        template_path = EXP_ROOT / "basic-script-template" / "basic-script.exp.yaml"
+        # Load template and create experiment
+        template = load_common(ExperimentTemplate, source=template_path)
+        experiment = Experiment.from_template(template)
+        client = PFClient()
+        exp = client._experiments.create_or_update(experiment)
+
+        # Each start call must be rejected on its own; bind the exception per block so no assert reuses a stale one
+        injection_command = ";bad command;"
+        with pytest.raises(ExperimentValueError) as error:
+            client._experiments.start(exp, nodes=[injection_command])
+        assert "Invalid character found" in str(error.value)
+
+        with pytest.raises(ExperimentValueError) as error:
+            client._experiments.start(exp, from_nodes=[injection_command])
+        assert "Invalid character found" in str(error.value)
+
+        with pytest.raises(ExperimentValueError) as error:
+            client._experiments.start(exp, session=injection_command)
+        assert "Invalid character found" in str(error.value)
+
     @pytest.mark.skipif(condition=not pytest.is_live, reason="Injection cannot passed to detach process.")
     def test_cancel_experiment(self):
         template_path = EXP_ROOT / "command-node-exp-template" / "basic-command.exp.yaml"