diff --git a/src/promptflow-devkit/promptflow/_sdk/_orchestrator/experiment_orchestrator.py b/src/promptflow-devkit/promptflow/_sdk/_orchestrator/experiment_orchestrator.py index 978479d462b..7c427a78a47 100644 --- a/src/promptflow-devkit/promptflow/_sdk/_orchestrator/experiment_orchestrator.py +++ b/src/promptflow-devkit/promptflow/_sdk/_orchestrator/experiment_orchestrator.py @@ -6,6 +6,7 @@ import json import os import platform +import re import signal import subprocess import sys @@ -414,6 +415,14 @@ def async_start(self, executable_path=None, nodes=None, from_nodes=None, attempt :return: Experiment info. :rtype: ~promptflow.entities.Experiment """ + def _params_inject_validation(params, param_name): + # Verify that the command is injected in the parameters. + # parameters can only consist of numeric, alphabetic parameters, strikethrough and dash. + pattern = r'^[a-zA-Z0-9 _\-]*$' + for item in params: + if not bool(re.match(pattern, item)): + raise ExperimentValueError(f"Invalid character found in the parameter {params} of {param_name}.") + # Setup file handler file_handler, index = _set_up_experiment_log_handler(experiment_path=self.experiment._output_dir, index=attempt) logger.addHandler(file_handler._stream_handler) @@ -423,10 +432,13 @@ def async_start(self, executable_path=None, nodes=None, from_nodes=None, attempt executable_path = executable_path or sys.executable args = [executable_path, __file__, "start", "--experiment", self.experiment.name] if nodes: + _params_inject_validation(nodes, "nodes") args = args + ["--nodes"] + nodes if from_nodes: + _params_inject_validation(from_nodes, "from-nodes") args = args + ["--from-nodes"] + from_nodes if kwargs.get("session"): + _params_inject_validation(kwargs.get("session"), "session") args = args + ["--session", kwargs.get("session")] args = args + ["--attempt", str(index)] # Start an orchestrator process using detach mode diff --git a/src/promptflow-devkit/tests/sdk_cli_test/e2etests/test_experiment.py b/src/promptflow-devkit/tests/sdk_cli_test/e2etests/test_experiment.py index 14a74f59b0b..40c3d4758b1 100644 --- a/src/promptflow-devkit/tests/sdk_cli_test/e2etests/test_experiment.py +++ b/src/promptflow-devkit/tests/sdk_cli_test/e2etests/test_experiment.py @@ -217,6 +217,29 @@ def test_experiment_start_from_nodes(self): assert len(exp.node_runs["main"]) == 3 assert len(exp.node_runs["echo"]) == 2 + @pytest.mark.usefixtures("use_secrets_config_file", "recording_injection", "setup_local_connection") + def test_experiment_start_with_command_injection(self): + template_path = EXP_ROOT / "basic-script-template" / "basic-script.exp.yaml" + # Load template and create experiment + template = load_common(ExperimentTemplate, source=template_path) + experiment = Experiment.from_template(template) + client = PFClient() + exp = client._experiments.create_or_update(experiment) + + # Test start experiment with injection command + injection_command = ";bad command;" + with pytest.raises(ExperimentValueError) as error: + client._experiments.start(exp, nodes=[injection_command]) + assert "Invalid character found" in str(error.value) + + with pytest.raises(ExperimentValueError): + client._experiments.start(exp, from_nodes=[injection_command]) + assert "Invalid character found" in str(error.value) + + with pytest.raises(ExperimentValueError): + client._experiments.start(exp, session=injection_command) + assert "Invalid character found" in str(error.value) + @pytest.mark.skipif(condition=not pytest.is_live, reason="Injection cannot passed to detach process.") def test_cancel_experiment(self): template_path = EXP_ROOT / "command-node-exp-template" / "basic-command.exp.yaml"