diff --git a/config/main.py b/config/main.py index d0c7c4258b..969e58d594 100644 --- a/config/main.py +++ b/config/main.py @@ -5012,7 +5012,7 @@ def polling_int(ctx, interval): config_db.mod_entry('SFLOW', 'global', sflow_tbl['global']) def is_valid_sample_rate(rate): - return rate in range(256, 8388608 + 1) + return rate.isdigit() and int(rate) in range(256, 8388608 + 1) # @@ -5070,24 +5070,31 @@ def disable(ctx, ifname): # @interface.command('sample-rate') @click.argument('ifname', metavar='', required=True, type=str) -@click.argument('rate', metavar='', required=True, type=int) +@click.argument('rate', metavar='', required=True, type=str) @click.pass_context def sample_rate(ctx, ifname, rate): config_db = ctx.obj['db'] if not interface_name_is_valid(config_db, ifname) and ifname != 'all': click.echo('Invalid interface name') return - if not is_valid_sample_rate(rate): - click.echo('Error: Sample rate must be between 256 and 8388608') + if not is_valid_sample_rate(rate) and rate != 'default': + click.echo('Error: Sample rate must be between 256 and 8388608 or default') return sess_dict = config_db.get_table('SFLOW_SESSION') - if sess_dict and ifname in sess_dict: + if sess_dict and ifname in sess_dict.keys(): + if rate == 'default': + if 'sample_rate' not in sess_dict[ifname]: + return + del sess_dict[ifname]['sample_rate'] + config_db.set_entry('SFLOW_SESSION', ifname, sess_dict[ifname]) + return sess_dict[ifname]['sample_rate'] = rate config_db.mod_entry('SFLOW_SESSION', ifname, sess_dict[ifname]) else: - config_db.mod_entry('SFLOW_SESSION', ifname, {'sample_rate': rate}) + if rate != 'default': + config_db.mod_entry('SFLOW_SESSION', ifname, {'sample_rate': rate}) # diff --git a/tests/sflow_test.py b/tests/sflow_test.py index 0e15f1e027..ecb2782534 100644 --- a/tests/sflow_test.py +++ b/tests/sflow_test.py @@ -290,6 +290,45 @@ def test_config_enable_all_intf(self): sflowSession = db.cfgdb.get_table('SFLOW_SESSION') assert sflowSession["all"]["admin_state"] == "up" + def test_config_sflow_intf_sample_rate_default(self): + db = Db() + runner = CliRunner() + obj = {'db':db.cfgdb} + + # mock interface_name_is_valid + config.interface_name_is_valid = mock.MagicMock(return_value = True) + + result_out1 = runner.invoke(show.cli.commands["sflow"].commands["interface"], [], obj=Db()) + print(result_out1.exit_code, result_out1.output) + assert result_out1.exit_code == 0 + + # set sample-rate to 2500 + result = runner.invoke(config.config.commands["sflow"]. + commands["interface"].commands["sample-rate"], + ["Ethernet2", "2500"], obj=obj) + print(result.exit_code, result.output) + assert result.exit_code == 0 + + # we can not use 'show sflow interface', becasue 'show sflow interface' + # gets data from appDB, we need to fetch data from configDB for verification + sflowSession = db.cfgdb.get_table('SFLOW_SESSION') + assert sflowSession["Ethernet2"]["sample_rate"] == "2500" + + # set sample-rate to default + result = runner.invoke(config.config.commands["sflow"]. + commands["interface"].commands["sample-rate"], + ["Ethernet2", "default"], obj=obj) + print(result.exit_code, result.output) + assert result.exit_code == 0 + + result_out2 = runner.invoke(show.cli.commands["sflow"].commands["interface"], [], obj=Db()) + print(result_out2.exit_code, result_out2.output) + assert result_out2.exit_code == 0 + assert result_out1.output == result_out2.output + + return + + @classmethod def teardown_class(cls): print("TEARDOWN")