diff --git a/.gitignore b/.gitignore index f5ca491..61f96e3 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ docs/_build tmp/ .coverage htmlcov +*.iml diff --git a/bin/wifi b/bin/wifi new file mode 100755 index 0000000..140d437 --- /dev/null +++ b/bin/wifi @@ -0,0 +1,205 @@ +#!/usr/bin/python +from __future__ import print_function +import argparse +import sys +import os + +from wifi import Cell, Scheme +from wifi.utils import print_table, match as fuzzy_match +from wifi.exceptions import ConnectionError, InterfaceError + +try: # Python 2.x + input = raw_input +except NameError: + pass + + +def fuzzy_find_cell(interface, query): + match_partial = lambda cell: fuzzy_match(query, cell.ssid) + + matches = Cell.where(interface, match_partial) + + num_unique_matches = len(set(cell.ssid for cell in matches)) + assert num_unique_matches > 0, "Couldn't find a network that matches '{}'".format(query) + assert num_unique_matches < 2, "Found more than one network that matches '{}'".format(query) + + # Several cells of the same SSID + if len(matches) > 1: + matches.sort(key=lambda cell: cell.signal) + + return matches[0] + + +def find_cell(interface, query): + cell = Cell.where(interface, lambda cell: cell.ssid.lower() == query.lower()) + + try: + cell = cell[0] + except IndexError: + cell = fuzzy_find_cell(interface, query) + return cell + + +def get_scheme_params(interface, scheme, ssid=None): + cell = find_cell(interface, ssid or scheme) + passkey = None if not cell.encrypted else input('passkey> ') + + return interface, scheme, cell, passkey + + +def scan_command(args): + print_table([[cell.signal, cell.ssid, 'protected' if cell.encrypted else 'unprotected'] for cell in Cell.all(args.interface)]) + + +def list_command(args): + for scheme in Scheme.for_file(args.file).all(): + print(scheme.name) + + +def show_command(args): + scheme = Scheme.for_file(args.file).for_cell(*get_scheme_params(args.interface, args.scheme, args.ssid)) + print(scheme) + + +def add_command(args): + scheme_class = Scheme.for_file(args.file) + assert not scheme_class.find(args.interface, args.scheme), "That scheme has already been used" + + scheme = scheme_class.for_cell(*get_scheme_params(args.interface, args.scheme, args.ssid)) + scheme.save() + + +def connect_command(args): + scheme_class = Scheme.for_file(args.file) + if args.adhoc: + # ensure that we have the adhoc utility scheme + try: + adhoc_scheme = scheme_class(args.interface, 'adhoc') + adhoc_scheme.save() + except AssertionError: + pass + except IOError: + assert False, "Can't write on {0!r}, do you have required privileges?".format(args.file) + + scheme = scheme_class.for_cell(*get_scheme_params(args.interface, 'adhoc', args.scheme)) + else: + scheme = scheme_class.find(args.interface, args.scheme) + assert scheme, "Couldn't find a scheme named {0!r}, did you mean to use -a?".format(args.scheme) + + try: + scheme.activate() + except ConnectionError: + assert False, "Failed to connect to %s." % scheme.name + + +def autoconnect_command(args): + ssids = [cell.ssid for cell in Cell.all(args.interface)] + + for scheme in Scheme.all(): + # TODO: make it easier to get the SSID off of a scheme. + ssid = scheme.options.get('wpa-ssid', scheme.options.get('wireless-essid')) + if ssid in ssids: + sys.stderr.write('Connecting to "%s".\n' % ssid) + try: + scheme.activate() + except ConnectionError: + assert False, "Failed to connect to %s." % scheme.name + break + else: + assert False, "Couldn't find any schemes that are currently available." + + +def arg_parser(): + parser = argparse.ArgumentParser() + parser.add_argument('-i', + '--interface', + default='wlan0', + help="Specifies which interface to use (wlan0, eth0, etc.)") + parser.add_argument('-f', + '--file', + default='/etc/network/interfaces', + help="Specifies which file for scheme storage.") + + subparsers = parser.add_subparsers(title='commands') + + parser_scan = subparsers.add_parser('scan', help="Shows a list of available networks.") + parser_scan.set_defaults(func=scan_command) + + parser_list = subparsers.add_parser('list', help="Shows a list of networks already configured.") + parser_list.set_defaults(func=list_command) + + scheme_help = ("A memorable nickname for a wireless network." + " If SSID is not provided, the network will be guessed using SCHEME.") + ssid_help = ("The SSID for the network to which you wish to connect." + " This is fuzzy matched, so you don't have to be precise.") + + parser_show = subparsers.add_parser('config', + help="Prints the configuration to connect to a new network.") + parser_show.add_argument('scheme', help=scheme_help, metavar='SCHEME') + parser_show.add_argument('ssid', nargs='?', help=ssid_help, metavar='SSID') + parser_show.set_defaults(func=show_command) + + parser_add = subparsers.add_parser('add', + help="Adds the configuration to connect to a new network.") + parser_add.add_argument('scheme', help=scheme_help, metavar='SCHEME') + parser_add.add_argument('ssid', nargs='?', help=ssid_help, metavar='SSID') + parser_add.set_defaults(func=add_command) + + parser_connect = subparsers.add_parser('connect', + help="Connects to the network corresponding to SCHEME") + parser_connect.add_argument('scheme', + help="The nickname of the network to which you wish to connect.", + metavar='SCHEME') + parser_connect.add_argument('-a', + '--ad-hoc', + dest='adhoc', + action="store_true", + help="Connect to a network without storing it in the config file") + parser_connect.set_defaults(func=connect_command) + + + # TODO: how to specify the correct interfaces file to work off of. + parser_connect.get_options = lambda: [scheme.name for scheme in Scheme.all()] + + parser_autoconnect = subparsers.add_parser( + 'autoconnect', + help="Searches for saved schemes that are currently" + " available and connects to the first one it finds." + ) + parser_autoconnect.set_defaults(func=autoconnect_command) + + return parser, subparsers + + +def autocomplete(position, wordlist, subparsers): + if position == 1: + ret = subparsers.choices.keys() + else: + try: + prev = wordlist[position - 1] + ret = subparsers.choices[prev].get_options() + except (IndexError, KeyError, AttributeError): + ret = [] + + print(' '.join(ret)) + + +if __name__ == "__main__": + parser, subparsers = arg_parser() + + if len(sys.argv) == 1: + argv = ['scan'] + else: + argv = sys.argv[1:] + + args = parser.parse_args(argv) + + try: + if 'WIFI_AUTOCOMPLETE' in os.environ: + autocomplete(int(os.environ['COMP_CWORD']), + os.environ['COMP_WORDS'].split(), subparsers) + else: + args.func(args) + except (AssertionError, InterfaceError) as e: + sys.stderr.write("Error: ") + sys.exit(e) diff --git a/docs/index.rst b/docs/index.rst index 71f34a2..c5a5eff 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -6,7 +6,8 @@ Using this library, you can discover networks, connect to them, save your config The original impetus for creating this library was my frustration with with connecting to the Internet using NetworkManager and wicd. It is very much for computer programmers, not so much for normal computer users. -Wifi is built on top the old technologies of the `/etc/network/interfaces` file and `ifup` and `ifdown`. +Wifi is built on top the old technologies of the `/etc/network/interfaces` file and `ifup` and `ifdown` as well as +`hostapd` and `dnsmasq` for creating access points. It is inspired by `ifscheme`. The library also comes with an executable that you can use to manage your WiFi connections. diff --git a/setup.py b/setup.py index 51fece7..733ffdd 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -from setuptools import setup +from setuptools import setup, Command import os import sys @@ -16,38 +16,120 @@ def read(fname): install_requires = [ 'setuptools', 'pbkdf2', + 'netaddr' ] try: import argparse except: install_requires.append('argparse') -version = '1.0.0' - -should_install_cli = os.environ.get('WIFI_INSTALL_CLI') not in ['False', '0'] -command_name = os.environ.get('WIFI_CLI_NAME', 'wifi') - -if command_name == 'wifi.py': - print( - "Having a command name of wifi.py will result in a weird ImportError" - " that doesn't seem possible to work around. Pretty much any other" - " name seems to work though." - ) - sys.exit(1) - -entry_points = {} -data_files = [] - -if should_install_cli: - entry_points['console_scripts'] = [ - '{command} = wifi.cli:main'.format(command=command_name), - ] - # make sure we actually have write access to the target folder and if not don't - # include it in data_files - if os.access('/etc/bash_completion.d/', os.W_OK): - data_files.append(('/etc/bash_completion.d/', ['extras/wifi-completion.bash'])) +version = '1.0.1' + +EXTRAS = [ + ('/etc/bash_completion.d/', [('extras/wifi-completion.bash', 'wifi-completion', 0644)]) +] + + +def get_extra_tuple(entry): + if isinstance(entry, (tuple, list)): + if len(entry) == 2: + path, mode = entry + filename = os.path.basename(path) + elif len(entry) == 3: + path, filename, mode = entry + elif len(entry) == 1: + path = entry[0] + filename = os.path.basename(path) + mode = None + else: + return None + else: - print("Not installing bash completion because of lack of permissions.") + path = entry + filename = os.path.basename(path) + mode = None + + return path, filename, mode + + +class InstallExtrasCommand(Command): + description = "install extras like init scripts and config files" + user_options = [("force", "F", "force overwriting files if they already exist")] + + def initialize_options(self): + self.force = None + + def finalize_options(self): + if self.force is None: + self.force = False + + def run(self): + global EXTRAS + import shutil + import os + + for target, files in EXTRAS: + for entry in files: + extra_tuple = get_extra_tuple(entry) + if extra_tuple is None: + print("Can't parse entry for target %s, skipping it: %r" % (target, entry)) + continue + + path, filename, mode = extra_tuple + target_path = os.path.join(target, filename) + + path_exists = os.path.exists(target_path) + if path_exists and not self.force: + print("Skipping copying %s to %s as it already exists, use --force to overwrite" % (path, target_path)) + continue + + try: + shutil.copy(path, target_path) + if mode: + os.chmod(target_path, mode) + print("Copied %s to %s and changed mode to %o" % (path, target_path, mode)) + else: + print("Copied %s to %s" % (path, target_path)) + except Exception as e: + if not path_exists and os.path.exists(target_path): + # we'll try to clean up again + try: + os.remove(target_path) + except: + pass + + import sys + print("Error while copying %s to %s (%s), aborting" % (path, target_path, e.message)) + sys.exit(-1) + + +class UninstallExtrasCommand(Command): + description = "uninstall extras like init scripts and config files" + user_options = [] + + def initialize_options(self): + pass + + def finalize_options(self): + pass + + def run(self): + global EXTRAS + import os + + for target, files in EXTRAS: + for entry in files: + extra_tuple = get_extra_tuple(entry) + if extra_tuple is None: + print("Can't parse entry for target %s, skipping it: %r" % (target, entry)) + + path, filename, mode = extra_tuple + target_path = os.path.join(target, filename) + try: + os.remove(target_path) + print("Removed %s" % target_path) + except Exception as e: + print("Error while deleting %s from %s (%s), please remove manually" % (filename, target, e.message)) setup( name='wifi', @@ -57,7 +139,7 @@ def read(fname): description=__doc__, long_description='\n\n'.join([read('README.rst'), read('CHANGES.rst')]), packages=['wifi'], - entry_points=entry_points, + scripts=['bin/wifi'], test_suite='tests', platforms=["Debian"], license='BSD', @@ -72,5 +154,8 @@ def read(fname): "Programming Language :: Python :: 2.7", "Programming Language :: Python :: 3.3", ], - data_files=data_files + cmdclass={ + 'install_extras': InstallExtrasCommand, + 'uninstall_extras': UninstallExtrasCommand + } ) diff --git a/tests/test_ap.py b/tests/test_ap.py new file mode 100644 index 0000000..fbc626b --- /dev/null +++ b/tests/test_ap.py @@ -0,0 +1,242 @@ +from unittest import TestCase +import tempfile +import os +import shutil + +from wifi import Hostapd, Dnsmasq +from wifi.exceptions import BindError + +HOSTAPD_FILE_WITHOUT_ENCRYPTION = """interface=wlan0 +driver=nl80211 +ssid=SsidWithoutEncryption +channel=3 +""" + +HOSTAPD_FILE_WITH_ENCRYPTION = """interface=wlan0 +driver=madwifi +ssid=SsidWithEncryption +channel=5 +wpa=3 +wpa_passphrase=MySecretPresharedKey +wpa_key_mgmt=WPA-PSK +wpa_pairwise=TKIP CCMP +rsn_pairwise=CCMP +hw_mode=g +auth_algs=3 +""" + +class TestHostapd(TestCase): + + def setUp(self): + self.confd = tempfile.mkdtemp() + + with open(os.path.join(self.confd, "ssid_without_encryption.conf"), "w") as f: + f.write(HOSTAPD_FILE_WITHOUT_ENCRYPTION) + + with open(os.path.join(self.confd, "ssid_with_encryption.conf"), "w") as f: + f.write(HOSTAPD_FILE_WITH_ENCRYPTION) + + self.Hostapd = Hostapd.for_hostapd_and_confd(None, self.confd) + + def tearDown(self): + shutil.rmtree(self.confd) + + def test_str(self): + hostapd = self.Hostapd('wlan0', 'some_name', 'SomeSsid', 3) + self.assertEquals(str(hostapd), "interface=wlan0\ndriver=nl80211\nssid=SomeSsid\nchannel=3") + + hostapd = self.Hostapd('wlan0', 'some_name', 'SomeSsid', 3, driver='madwifi') + self.assertEquals(str(hostapd), "interface=wlan0\ndriver=madwifi\nssid=SomeSsid\nchannel=3") + + hostapd = self.Hostapd('wlan0', 'some_name', 'SomeSsid', 3, psk="SuperSecret") + self.assertEqual(str(hostapd), "interface=wlan0\ndriver=nl80211\nssid=SomeSsid\nchannel=3\nwpa=3\nwpa_passphrase=SuperSecret\nwpa_key_mgmt=WPA-PSK\nwpa_pairwise=TKIP CCMP\nrsn_pairwise=CCMP") + + hostapd = self.Hostapd('wlan0', 'some_name', 'SomeSsid', 3, options=dict(some_option="some_value")) + self.assertEquals(str(hostapd), "interface=wlan0\ndriver=nl80211\nssid=SomeSsid\nchannel=3\nsome_option=some_value") + + def test_find(self): + with_encryption = self.Hostapd.find('wlan0', 'ssid_with_encryption') + self.assertIsNotNone(with_encryption) + self.assertEquals(with_encryption.ssid, "SsidWithEncryption") + + wrong_interface = self.Hostapd.find('wlan1', 'ssid_with_encryption') + self.assertIsNone(wrong_interface) + + unknown = self.Hostapd.find('wlan0', 'unknown_ssid') + self.assertIsNone(unknown) + + def test_delete(self): + with_encryption = self.Hostapd.find('wlan0', 'ssid_with_encryption') + with_encryption.delete() + self.assertIsNone(self.Hostapd.find('wlan0', 'ssid_with_encryption')) + self.assertIsNotNone(self.Hostapd.find('wlan0', 'ssid_without_encryption')) + + def test_parse(self): + hostapd = self.Hostapd.from_hostapd_conf(os.path.join(self.confd, "ssid_without_encryption.conf")) + self.assertEquals("wlan0", hostapd.interface) + self.assertEquals("ssid_without_encryption", hostapd.name) + self.assertEquals("SsidWithoutEncryption", hostapd.ssid) + self.assertEquals(3, hostapd.channel) + self.assertEquals("nl80211", hostapd.driver) + self.assertIsNone(hostapd.psk) + self.assertDictEqual(dict(), hostapd.options) + + hostapd = self.Hostapd.from_hostapd_conf(os.path.join(self.confd, "ssid_with_encryption.conf")) + self.assertEquals("wlan0", hostapd.interface) + self.assertEquals("ssid_with_encryption", hostapd.name) + self.assertEquals("SsidWithEncryption", hostapd.ssid) + self.assertEquals(5, hostapd.channel) + self.assertEquals("madwifi", hostapd.driver) + self.assertEquals("MySecretPresharedKey", hostapd.psk) + self.assertEquals(2, len(hostapd.options)) + self.assertTrue("hw_mode" in hostapd.options) + self.assertEquals("g", hostapd.options["hw_mode"]) + self.assertTrue("auth_algs" in hostapd.options) + self.assertEquals("3", hostapd.options["auth_algs"]) + + + def test_save(self): + hostapd = self.Hostapd('wlan0', 'test', 'Test', 3) + hostapd.save() + self.assertIsNotNone(self.Hostapd.find('wlan0', 'test')) + + def test_save_overwrite(self): + hostapd = self.Hostapd('wlan0', 'ssid_without_encryption', 'SsidWithoutEncryption', 3, driver='madwifi') + + try: + hostapd.save() + self.fail("Expected an exception") + except: + pass + + existing_hostapd = self.Hostapd.find('wlan0', 'ssid_without_encryption') + self.assertIsNotNone(existing_hostapd) + self.assertEquals(existing_hostapd.driver, 'nl80211') + + hostapd.save(allow_overwrite=True) + existing_hostapd = self.Hostapd.find('wlan0', 'ssid_without_encryption') + self.assertIsNotNone(existing_hostapd) + self.assertEquals(existing_hostapd.driver, 'madwifi') + + +DNSMASQ_FILE_1 = """interface=wlan0 +bind-interfaces +dhcp-range=192.168.0.100,192.168.0.200,600 +""" + +DNSMASQ_FILE_2 = """interface=wlan0 +bind-interfaces +dhcp-range=10.10.0.1,10.10.254.254,7200 +local=/mydomain/ +domain=mydomain +expand-hosts +dhcp-option=option:router,10.0.0.1 +dhcp-option=option:ntp-server,10.0.0.2 +read-ethers +""" + +DNSMASQ_FILE_3 = """interface=wlan0 +bind-interfaces +dhcp-range=192.168.0.100,192.168.0.200,5m +""" + +DNSMASQ_FILE_4 = """interface=wlan0 +bind-interfaces +dhcp-range=192.168.0.100,192.168.0.200,12h +""" + + +class TestDnsmasq(TestCase): + + def setUp(self): + self.confd = tempfile.mkdtemp() + + with open(os.path.join(self.confd, "dnsmasq_1.conf"), "w") as f: + f.write(DNSMASQ_FILE_1) + + with open(os.path.join(self.confd, "dnsmasq_2.conf"), "w") as f: + f.write(DNSMASQ_FILE_2) + + with open(os.path.join(self.confd, "dnsmasq_3.conf"), "w") as f: + f.write(DNSMASQ_FILE_3) + + with open(os.path.join(self.confd, "dnsmasq_4.conf"), "w") as f: + f.write(DNSMASQ_FILE_4) + + self.Dnsmasq = Dnsmasq.for_dnsmasq_and_confd(None, self.confd) + + def tearDown(self): + shutil.rmtree(self.confd) + + def test_str(self): + dnsmasq = self.Dnsmasq("wlan0", "test", "192.168.1.100", "192.168.1.200") + self.assertEquals(str(dnsmasq), "interface=wlan0\nbind-interfaces\ndhcp-range=192.168.1.100,192.168.1.200,600") + + dnsmasq = self.Dnsmasq("wlan0", "test", "10.10.0.1", "10.10.254.254", lease_time=7200, gateway="10.0.0.1", domain="mydomain") + self.assertEquals(str(dnsmasq), "interface=wlan0\nbind-interfaces\ndhcp-range=10.10.0.1,10.10.254.254,7200\nlocal=/mydomain/\ndomain=mydomain\nexpand-hosts\ndhcp-option=option:router,10.0.0.1") + + dnsmasq = self.Dnsmasq("wlan0", "test", "10.10.0.1", "10.10.254.254", gateway="10.0.0.1", options={"dhcp-option": ["option:ntp-server,10.0.0.2"]}) + self.assertEquals(str(dnsmasq), "interface=wlan0\nbind-interfaces\ndhcp-range=10.10.0.1,10.10.254.254,600\ndhcp-option=option:router,10.0.0.1\ndhcp-option=option:ntp-server,10.0.0.2") + + def test_parse(self): + dnsmasq = self.Dnsmasq.from_dnsmasq_conf(os.path.join(self.confd, "dnsmasq_1.conf")) + self.assertEquals("wlan0", dnsmasq.interface) + self.assertEquals("dnsmasq_1", dnsmasq.name) + self.assertEquals("192.168.0.100", dnsmasq.start) + self.assertEquals("192.168.0.200", dnsmasq.end) + self.assertEquals(600, dnsmasq.lease_time) + self.assertIsNone(dnsmasq.gateway) + self.assertIsNone(dnsmasq.domain) + self.assertDictEqual(dict(), dnsmasq.options) + + dnsmasq = self.Dnsmasq.from_dnsmasq_conf(os.path.join(self.confd, "dnsmasq_2.conf")) + self.assertEquals("wlan0", dnsmasq.interface) + self.assertEquals("dnsmasq_2", dnsmasq.name) + self.assertEquals("10.10.0.1", dnsmasq.start) + self.assertEquals("10.10.254.254", dnsmasq.end) + self.assertEquals(7200, dnsmasq.lease_time) + self.assertEquals("10.0.0.1", dnsmasq.gateway) + self.assertEquals("mydomain", dnsmasq.domain) + self.assertEquals(2, len(dnsmasq.options)) + self.assertTrue("dhcp-option" in dnsmasq.options) + self.assertEquals(1, len(dnsmasq.options["dhcp-option"])) + self.assertEquals("option:ntp-server,10.0.0.2", dnsmasq.options["dhcp-option"][0]) + self.assertTrue("read-ethers" in dnsmasq.options) + self.assertIsNone(dnsmasq.options["read-ethers"]) + + dnsmasq = self.Dnsmasq.from_dnsmasq_conf(os.path.join(self.confd, "dnsmasq_3.conf")) + self.assertEquals(300, dnsmasq.lease_time) + + dnsmasq = self.Dnsmasq.from_dnsmasq_conf(os.path.join(self.confd, "dnsmasq_4.conf")) + self.assertEquals(43200, dnsmasq.lease_time) + + def test_find(self): + self.assertIsNotNone(self.Dnsmasq.find("wlan0", "dnsmasq_1")) + self.assertIsNone(self.Dnsmasq.find("eth0", "dnsmasq_1")) + self.assertIsNone(self.Dnsmasq.find("wlan0", "unknown")) + + def test_save(self): + dnsmasq = self.Dnsmasq('wlan0', 'test', '192.168.10.10', '192.168.10.20') + dnsmasq.save() + self.assertIsNotNone(self.Dnsmasq.find('wlan0', 'test')) + pass + + def test_save_overwrite(self): + dnsmasq = self.Dnsmasq('wlan0', 'dnsmasq_1', '192.168.10.100', '192.168.10.200') + + try: + dnsmasq.save() + self.fail("Expected an exception") + except: + pass + + existing_dnsmasq = self.Dnsmasq.find('wlan0', 'dnsmasq_1') + self.assertIsNotNone(existing_dnsmasq) + self.assertEquals(existing_dnsmasq.start, '192.168.0.100') + self.assertEquals(existing_dnsmasq.end, '192.168.0.200') + + dnsmasq.save(allow_overwrite=True) + existing_dnsmasq = self.Dnsmasq.find('wlan0', 'dnsmasq_1') + self.assertIsNotNone(existing_dnsmasq) + self.assertEquals(existing_dnsmasq.start, '192.168.10.100') + self.assertEquals(existing_dnsmasq.end, '192.168.10.200') diff --git a/tests/test_schemes.py b/tests/test_schemes.py index be5c843..1b9f501 100644 --- a/tests/test_schemes.py +++ b/tests/test_schemes.py @@ -63,10 +63,10 @@ def test_scheme_extraction(self): work, coffee, home, coffee2 = list(extract_schemes(NETWORK_INTERFACES_FILE))[:4] assert work.name == 'work' - assert work.options['wpa-ssid'] == 'workwifi' + assert work.options['wpa-ssid'] == ['workwifi'] assert coffee.name == 'coffee' - assert coffee.options['wireless-essid'] == 'Coffee WiFi' + assert coffee.options['wireless-essid'] == ['Coffee WiFi'] def test_with_hyphen(self): with_hyphen = self.Scheme.find('wlan0', 'with-hyphen') @@ -79,7 +79,7 @@ def test_str(self): scheme = self.Scheme('wlan0', 'test') assert str(scheme) == 'iface wlan0-test inet dhcp\n' - scheme = self.Scheme('wlan0', 'test', { + scheme = self.Scheme('wlan0', 'test', options={ 'wpa-ssid': 'workwifi', }) @@ -88,7 +88,7 @@ def test_str(self): def test_find(self): work = self.Scheme.find('wlan0', 'work') - assert work.options['wpa-ssid'] == 'workwifi' + assert work.options['wpa-ssid'] == ['workwifi'] def test_delete(self): work = self.Scheme.find('wlan0', 'work') @@ -124,8 +124,8 @@ def test_unencrypted(self): scheme = Scheme.for_cell('wlan0', 'test', cell) self.assertEqual(scheme.options, { - 'wireless-essid': 'SSID', - 'wireless-channel': 'auto', + 'wireless-essid': ['SSID'], + 'wireless-channel': ['auto'], }) def test_wep_hex(self): @@ -140,8 +140,8 @@ def test_wep_hex(self): scheme = Scheme.for_cell('wlan0', 'test', cell, key) self.assertEqual(scheme.options, { - 'wireless-essid': 'SSID', - 'wireless-key': key + 'wireless-essid': ['SSID'], + 'wireless-key': [key] }) def test_wep_ascii(self): @@ -156,8 +156,8 @@ def test_wep_ascii(self): scheme = Scheme.for_cell('wlan0', 'test', cell, key) self.assertEqual(scheme.options, { - 'wireless-essid': 'SSID', - 'wireless-key': 's:' + key + 'wireless-essid': ['SSID'], + 'wireless-key': ['s:' + key] }) def test_wpa2(self): @@ -169,9 +169,9 @@ def test_wpa2(self): scheme = Scheme.for_cell('wlan0', 'test', cell, b'passkey') self.assertEqual(scheme.options, { - 'wpa-ssid': 'SSID', - 'wpa-psk': 'ea1548d4e8850c8d94c5ef9ed6fe483981b64c1436952cb1bf80c08a68cdc763', - 'wireless-channel': 'auto', + 'wpa-ssid': ['SSID'], + 'wpa-psk': ['ea1548d4e8850c8d94c5ef9ed6fe483981b64c1436952cb1bf80c08a68cdc763'], + 'wireless-channel': ['auto'], }) def test_wpa(self): @@ -183,9 +183,9 @@ def test_wpa(self): scheme = Scheme.for_cell('wlan0', 'test', cell, 'passkey') self.assertEqual(scheme.options, { - 'wpa-ssid': 'SSID', - 'wpa-psk': 'ea1548d4e8850c8d94c5ef9ed6fe483981b64c1436952cb1bf80c08a68cdc763', - 'wireless-channel': 'auto', + 'wpa-ssid': ['SSID'], + 'wpa-psk': ['ea1548d4e8850c8d94c5ef9ed6fe483981b64c1436952cb1bf80c08a68cdc763'], + 'wireless-channel': ['auto'], }) diff --git a/wifi/__init__.py b/wifi/__init__.py index 777f763..ba2bb7e 100644 --- a/wifi/__init__.py +++ b/wifi/__init__.py @@ -1,2 +1,3 @@ from wifi.scan import Cell from wifi.scheme import Scheme +from wifi.ap import AccessPoint, Hostapd, Dnsmasq diff --git a/wifi/ap.py b/wifi/ap.py new file mode 100644 index 0000000..e83b119 --- /dev/null +++ b/wifi/ap.py @@ -0,0 +1,717 @@ +from __future__ import print_function, absolute_import + +import netaddr +import os +import logging +import re + +from wifi import Scheme +from wifi.exceptions import BindError +from wifi.utils import mac_addr_pattern +import wifi.subprocess_compat as subprocess + + +bound_ap_re = re.compile(r"^Using interface (?P\w+) with hwaddr %s and ssid '(?P[^']+)'" % mac_addr_pattern, flags=re.MULTILINE) + + +class Hostapd(object): + """ + A wrapper for managing hostapd configuration files stored under /etc/hostapd/conf.d and + managing hostapd service instances based on them. + + Note: The directory /etc/hostapd/conf.d does not usually exist and has to be created before using + this class. Alternatively provide a different location for file storage by creating a custom type + wrapper using `Hostapd.for_hostapd_and_confd(hostapd=, confd=)`. + """ + + # location of hostapd binary + hostapd = "/usr/sbin/hostapd" + + # location of hostapd config folder + confd = "/etc/hostapd/conf.d/" + + # our logger instance + logger = logging.getLogger(__name__) + + @classmethod + def for_hostapd_and_confd(cls, hostapd, confd): + return type(cls)(cls.__name__, (cls,), { + 'hostapd': hostapd if hostapd is not None else cls.hostapd, + 'confd': confd if confd is not None else cls.confd, + }) + + def __init__(self, interface, name, ssid, channel, driver=None, psk=None, options=None): + self.interface = interface + self.driver = driver if driver is not None else "nl80211" + self.name = name + self.ssid = ssid + self.channel = channel + self.psk = psk + + self.options = options if options else dict() + + def __str__(self): + # default parameters for a simply ap + conf = [ + "interface={interface}", + "driver={driver}", + "ssid={ssid}", + "channel={channel}" + ] + + if self.psk is not None: + # parameters for encryption via WPA + conf += [ + "wpa=3", + "wpa_passphrase={psk}", + "wpa_key_mgmt=WPA-PSK", + "wpa_pairwise=TKIP CCMP", + "rsn_pairwise=CCMP" + ] + + if self.options: + # any additional options given + conf += ["{k}={v}".format(k=k, v=v) for k, v in self.options.items()] + + return "\n".join(conf).format(**vars(self)) + + def __repr__(self): + return "Hostapd(interface={interface!r}, driver={driver!r}, name={name!r}, ssid={ssid!r})"\ + .format(**vars(self)) + + def save(self, allow_overwrite=False): + existing_hostapd = self.__class__.find(self.interface, self.name) + if existing_hostapd: + if not allow_overwrite: + raise RuntimeError("Config for interface %s named %s does already exists and overwrite is not allowed" % (self.interface, self.name)) + existing_hostapd.delete() + + with open(self.configfile, "w") as f: + f.write(str(self)) + + def delete(self): + if self.is_running(): + self.deactivate() + + try: + os.remove(self.configfile) + except OSError as e: + self._logger.warn("Could not delete %s: %s" % (self.configfile, e)) + + def activate(self): + try: + output = subprocess.check_output([self.__class__.hostapd, "-dd", "-B", self.configfile], stderr=subprocess.STDOUT) + self._logger.info("Started hostapd: {output}".format(output=output)) + return True + except subprocess.CalledProcessError as e: + self._logger.warn("Error while starting hostapd: {output}".format(output=e.output)) + raise e + + def deactivate(self): + pid = self.get_pid() + if pid is None: + return + subprocess.check_call(["kill", pid]) + + def get_pid(self): + pids = [pid for pid in os.listdir("/proc") if pid.isdigit()] + for pid in pids: + try: + with open(os.path.join("/proc", pid, "cmdline"), "r") as f: + line = f.readline() + if self.__class__.hostapd in line and self.configfile in line: + return pid + except: + # the pid might just have vanished because the process exited normally, no need to worry + pass + return None + + def is_running(self): + return self.get_pid() is not None + + @property + def configfile(self): + return os.path.join(self.__class__.confd, "{name}.conf".format(name=self.name)) + + @property + def _logger(self): + return self.__class__.logger + + @classmethod + def all(cls): + result = [] + for conf in os.listdir(cls.confd): + if conf.endswith(".conf"): + filename = os.path.join(cls.confd, conf) + try: + ap = cls.from_hostapd_conf(filename) + result.append(ap) + except: + cls.logger.exception("Could not retrieve hostapd from file %s:" % filename) + return result + + @classmethod + def find(cls, interface, name): + try: + return cls.where(lambda x: x.name == name and x.interface == interface)[0] + except IndexError: + return None + + @classmethod + def where(cls, fn): + return list(filter(fn, cls.all())) + + @classmethod + def from_hostapd_conf(cls, configfile): + if not os.path.exists(configfile): + raise IOError("Configfile not found: %s" % configfile) + + name = os.path.basename(configfile)[:-len(".conf")] + + conf_options = dict() + with open(configfile, "r") as f: + for line in f: + k, v = line.strip().split("=", 1) + conf_options[k] = v + + for key in ("interface", "ssid", "channel"): + if not key in conf_options or conf_options[key] is None: + raise RuntimeError("Invalid config, %s is missing or none" % key) + + options = dict((k, conf_options[k]) for k in conf_options if not k in ["interface", "driver", "ssid", "channel", "wpa", "wpa_passphrase", "wpa_key_mgmt", "wpa_pairwise", "rsn_pairwise"]) + psk = conf_options["wpa_passphrase"] if "wpa_passphrase" in conf_options else None + try: + channel = int(conf_options["channel"]) + except ValueError as e: + raise RuntimeError("Invalid config, %r is an invalid channel" % conf_options["channel"], e) + + return cls(conf_options["interface"], name, conf_options["ssid"], channel, driver=conf_options["driver"], psk=psk, options=options) + + + def parse_hostapd_output(self, output): + print(output) + matches = bound_ap_re.search(output) + if matches: + return True + else: + raise BindError("Could not bind hostapd %r to interface %s:\n%s" % (self, self.interface, output)) + + +class Dnsmasq(object): + """ + A wrapper for managing dnsmasq configurations (stored as .conf files under /etc/dnsmasq.conf.d) + and managing dnsmasq service instances based on them. + + Note: The directory /etc/dnsmasq.conf.d does not usually exist and has to be created before using + this class. Alternatively provide a different location for file storage by creating a custom type + wrapper using `Dnsmasq.for_dnsmasq_and_confd(dnsmasq=, confd=)`. + """ + + # dnsmasq binary + dnsmasq = "/usr/sbin/dnsmasq" + + # dnsmasq configuration storage + confd = "/etc/dnsmasq.conf.d" + + # our logger + logger = logging.getLogger(__name__) + + @classmethod + def for_dnsmasq_and_confd(cls, dnsmasq, confd): + return type(cls)(cls.__name__, (cls,), { + 'dnsmasq': dnsmasq if dnsmasq is not None else cls.dnsmasq, + 'confd': confd if confd is not None else cls.confd, + }) + + @classmethod + def all(cls): + result = [] + for conf in os.listdir(cls.confd): + if conf.endswith(".conf"): + filename = os.path.join(cls.confd, conf) + try: + dnsmasq = cls.from_dnsmasq_conf(os.path.join(cls.confd, conf)) + result.append(dnsmasq) + except: + cls.logger.exception("Could not retrieve dnsmasq config from file %s" % filename) + return result + + @classmethod + def where(cls, fn): + return list(filter(fn, cls.all())) + + @classmethod + def find(cls, interface, name): + try: + return cls.where(lambda x: x.name == name and x.interface == interface)[0] + except IndexError: + return None + + @classmethod + def from_dnsmasq_conf(cls, configfile): + """ + Creates a :class:`Dnsmasq` config from a given dnsmasq configuration file. + + :param configfile: path of config file to create instance from + :return: created instance + """ + + if not os.path.exists(configfile): + raise IOError("Configfile not found: %s" % configfile) + + name = os.path.basename(configfile)[:-len(".conf")] + + conf = dict() + additional_options = dict() + with open(configfile, "r") as f: + for line in f: + line = line.strip() + if not line: + continue + + # split or "key=value" pairs + split_line = map(str.strip, line.split("=", 1)) + if len(split_line) > 1: + # this is an actual "key=value" pair + k, v = split_line + else: + # this is only a single "key" option without value + k = split_line[0] + v = None + + if k == "interface": + conf["interface"] = v + continue + + elif k == "dhcp-range": + # format is either "dhcp-range=,," or + # "dhcp-range=,,," + opts = v.split(",") + if len(opts) > 3: + # strip off tags, we don't care about them + opts = opts[-3:] + conf["start"], conf["end"], lease_time = opts + + # lease time can be given as "h", "m" or "" + if lease_time.endswith("h"): + # hours = 60 minutes * 60 seconds + factor = 60 * 60 + lease_time = lease_time[:-1] + elif lease_time.endswith("m"): + # minutes = 60 seconds + factor = 60 + lease_time = lease_time[:-1] + else: + # seconds + factor = 1 + + try: + conf["lease_time"] = int(lease_time) * factor + continue + except ValueError: + cls.logger.exception("Could not convert lease time value %s" % lease_time) + + elif k == "domain": + conf["domain"] = v + continue + + elif k == "dhcp-option": + # parse dhcp options, right now we only support gateway definition, which + # is provided in the format "dhcp-option=option:router," or + # "dhcp-option=3," + opts = v.split(",") + if len(opts) == 2 and (opts[0] == "option:router" or opts[0] == "3"): + conf["gateway"] = opts[1] + continue + + elif k in ("bind-interfaces", "local", "expand-hosts"): + # ignore known parameters that are used for general setup or domain setup + continue + + # if we came this far then we have an additional option at hand + if v is not None: + # if the value is not None, we create a key => list entry and add the value to it + if not k in additional_options: + additional_options[k] = list() + additional_options[k].append(v) + else: + # if the value is None it's a key only entry, so we add a key => None entry + additional_options[k] = None + + # make sure the mandatory parameters are all there + for key in "interface", "start", "end": + if not key in conf or conf[key] is None: + raise RuntimeError("Invalid config, %s is missing or None" % key) + + # make sure the optional arguments that are not supplied are all set to None + for key in "lease_time", "domain", "gateway": + if not key in conf: + conf[key] = None + + return Dnsmasq(conf["interface"], name, conf["start"], conf["end"], lease_time=conf["lease_time"], + gateway=conf["gateway"], domain=conf["domain"], options=additional_options) + + + def __init__(self, interface, name, start, end, lease_time=None, gateway=None, domain=None, options=None): + """ + :param interface: the interface on which to listen + :param name: the name of the configuration + :param start: the start ip of the managed dhcp range + :param end: the end ip of the managed dhcp range + :param lease_time: the lease time for given dhcp leases, defaults to 600s + :param gateway: the gateway to hand out via dhcp, defaults to no gateway + :param domain: the local domain to define, default to no domain + :param options: additional options, dict of either key => list (for possibly + multiple key-value-pairs) or key => None for key-only-statements + in the config + """ + + self.interface = interface + self.name = name + self.start = start + self.end = end + self.lease_time = lease_time if lease_time else 600 + self.gateway = gateway + self.domain = domain + self.options = options if options else dict() + + def __str__(self): + # basic dhcp setup for dnsmasq + conf = [ + "interface={interface}", + "bind-interfaces", + "dhcp-range={start},{end},{lease_time}" + ] + + # if a local domain is configured, add the corresponding configuration lines + if self.domain: + conf += [ + "local=/{domain}/", + "domain={domain}", + "expand-hosts" + ] + + # if a gateway is configured, add the corresponding configuration line + if self.gateway: + conf += [ + "dhcp-option=option:router,{gateway}" + ] + + # add any additional dnsmasq options that were provided + if self.options: + for k, l in self.options.items(): + if l is not None: + for v in l: + conf.append("{key}={value}".format(key=k, value=v)) + else: + conf.append(k) + + return "\n".join(conf).format(**vars(self)) + + def __repr__(self): + return "Dnsmasq(interface={interface}, name={name}, start={start}, end={end})".format(**vars(self)) + + def save(self, allow_overwrite=False): + """ + Saves the config to the defined `confd` directory. + + :param allow_overwrite: whether to overwrite an existing config of the same name, raises an + exception if such a config is found and set to `False` (the default) + """ + + existing_dnsmasq = self.__class__.find(self.interface, self.name) + if existing_dnsmasq: + if not allow_overwrite: + raise RuntimeError("Config for interface %s named %s does already exists and overwrite is not allowed" % (self.interface, self.name)) + existing_dnsmasq.delete() + + with open(self.configfile, "w") as f: + f.write(str(self)) + + def delete(self): + """ Deletes the config, deactivates it before if it's currently active. """ + + if self.is_running(): + self.deactivate() + try: + os.remove(self.configfile) + except OSError as e: + self._logger.warn("Could not delete %s: %s" % (self.configfile, e)) + + def activate(self): + """ Activates this config. """ + + try: + output = subprocess.check_output([self.__class__.dnsmasq, "--conf-file={file}".format(file=self.configfile)], stderr=subprocess.STDOUT) + self._logger.info("Started dnsmasq: {output}".format(output=output)) + except subprocess.CalledProcessError as e: + self._logger.warn("Error while starting dnsmasq: {output}".format(output=e.output)) + raise e + + def deactivate(self): + """ Deactivates this config. """ + + pid = self.get_pid() + if pid is None: + return + subprocess.check_call(["kill", pid]) + + def get_pid(self): + """ Get's the pid of the dnsmasq process running this config, or None if not currently running. """ + + pids = [pid for pid in os.listdir("/proc") if pid.isdigit()] + for pid in pids: + try: + with open(os.path.join("/proc", pid, "cmdline"), "r") as f: + line = f.readline() + if self.__class__.dnsmasq in line and self.configfile in line: + return pid + except: + pass + return None + + def is_running(self): + """ Returns a boolean indicating whether this config is currently active or not. """ + + return self.get_pid() is not None + + @property + def configfile(self): + return os.path.join(self.__class__.confd, "{name}.conf".format(name=self.name)) + + @property + def _logger(self): + return self.__class__.logger + + +class AccessPoint(object): + """ + Manages access point configurations by wrapping the hostapd, dnsmasq and scheme configurations + they are based on and allows starting and stopping the access point altogether. + """ + + # class providing the hostapd wrapper + hostapd_cls = Hostapd + + # class providing the dnsmasq wrapper + dnsmasq_cls = Dnsmasq + + # class providing the scheme wrapper + scheme_cls = Scheme + + @classmethod + def for_classes(cls, hostapd_cls=None, dnsmasq_cls=None, scheme_cls=None): + return type(cls)(cls.__name__, (cls,), { + 'hostapd_cls': hostapd_cls if hostapd_cls is not None else cls.hostapd_cls, + 'dnsmasq_cls': dnsmasq_cls if dnsmasq_cls is not None else cls.dnsmasq_cls, + 'scheme_cls': scheme_cls if scheme_cls is not None else cls.scheme_cls + }) + + @classmethod + def all(cls): + hostapds = {(hostapd.interface, hostapd.name): hostapd for hostapd in cls.hostapd_cls.all()} + dnsmasqs = {(dnsmasq.interface, dnsmasq.name): dnsmasq for dnsmasq in cls.dnsmasq_cls.all()} + schemes = {(scheme.interface, scheme.name): scheme for scheme in cls.scheme_cls.all()} + + result = [] + for key in hostapds: + if key in dnsmasqs and key in schemes: + result.append(AccessPoint(hostapds[key], dnsmasqs[key], schemes[key])) + return result + + @classmethod + def where(cls, fn): + return list(filter(fn, cls.all())) + + @classmethod + def find(cls, interface, name): + try: + return cls.where(lambda x: x.name == name and x.interface == interface)[0] + except IndexError: + return None + + @classmethod + def for_arguments(cls, interface, name, ssid, channel, ip, network, start, end, forwarding_to=None, + hostap_options=None, dnsmasq_options=None, scheme_options=None): + """ + Creates a new access point configuration for the given arguments. + + :param string interface: the interface on which to create the access point + :param string name: the configuration name + :param string ssid: the SSID to create + :param int channel: the channel on which to create the access point + :param string ip: the ip to assign to the interface serving as access point + :param string network: the network of the access point + :param string start: start address of IP address range handled by dhcp server + :param string end: end address of IP address range handled by dhcp server + :param string forwarding_to: interface to forward to, defaults to None for no + forwarding enabled + :param dict hostap_options: hostap options, defaults to None. Parameters + `driver` and `psk` will be used as their counterparts during + `Hostapd` construction, all other options will be given as + `options` to the Hostapd constructor. + :param dict dnsmasq_options: dnsmasq options, defaults to None. Parameters + `lease_time`, `domain` and `gateway` will be used as their + counterparts during `Dnsmasq` construction, all other options + will be given as `options` to the Dnsmasq constructor. + :param dict scheme_options: scheme options, defaults to None. Note that + `address`, `netmask` and `broadcast` will be overwritten with + the values derived from `ip` and `network`. If `forwarding_to` + is set `post-up` and `pre-down` will be extended to include + the necessary firewalling rules and forward-sysctl-calls + :return: the resulting `AccessPoint` instance + """ + + network_address = netaddr.IPNetwork(network) + + # prepare hostapd options + if hostap_options is None: + hostap_options = dict() + + if "driver" in hostap_options: + driver = hostap_options["driver"] + del hostap_options["driver"] + else: + driver = None + + if "psk" in hostap_options: + psk = hostap_options["psk"] + del hostap_options["psk"] + else: + psk = None + + # create hostapd config + hostapd = cls.hostapd_cls(interface, name, ssid, channel, driver, psk=psk, options=hostap_options) + + # prepare dnsmasq options + if dnsmasq_options is None: + dnsmasq_options = dict() + + if "lease_time" in dnsmasq_options: + lease_time = dnsmasq_options["lease_time"] + del dnsmasq_options["lease_time"] + else: + lease_time = None + + if "domain" in dnsmasq_options: + domain = dnsmasq_options["domain"] + del dnsmasq_options["domain"] + else: + domain = None + + if "gateway" in dnsmasq_options: + gateway = dnsmasq_options["gateway"] + del dnsmasq_options["gateway"] + else: + gateway = None + + # create dnsmasq + dnsmasq = cls.dnsmasq_cls(interface, name, start, end, lease_time=lease_time, gateway=gateway, domain=domain, options=dnsmasq_options) + + # prepare scheme options + if scheme_options == None: + scheme_options = dict() + + # create a scheme with static configuration, given ip and netmask -- those parameters will be ruthlessly + # overridden if they were already present in the supplied scheme_options + scheme_options.update(dict( + address=[ip], + netmask=[str(network_address.netmask)], + broadcast=[str(network_address.broadcast)] + )) + + if forwarding_to is not None: + # if forwarding is enabled, also add some rules and stuff + if not "post-up" in scheme_options: + scheme_options["post-up"] = [] + scheme_options["post-up"] += [ + # flush current tables + "/sbin/iptables -F", + "/sbin/iptables -X", + "/sbin/iptables -t nat -F", + # setup forwarding rules + "/sbin/iptables -A FORWARD -o {forward} -i {interface} -s {network} -m conntrack --ctstate NEW -j ACCEPT".format(forward=forwarding_to, network=str(network_address), interface=interface), + "/sbin/iptables -A FORWARD -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", + "/sbin/iptables -t nat -A POSTROUTING -o eth0 -j MASQUERADE", + # enable forwarding + "/sbin/sysctl -w net.ipv4.ip_forward=1" + ] + + if not "pre-down" in scheme_options: + scheme_options["pre-down"] = [] + scheme_options["pre-down"] += [ + # disable forwarding + "/sbin/sysctl -w net.ipv4.ip_forward=0", + # flush current tables + "/sbin/iptables -F", + "/sbin/iptables -X", + "/sbin/iptables -t nat -F", + ] + scheme = cls.scheme_cls(interface, name, type="static", options=scheme_options) + + return cls(hostapd, dnsmasq, scheme) + + def __init__(self, hostapd, dnsmasq, scheme): + """ + Constructor for the :class:`AccessPoint` instance, takes :class:`Hostapd`, :class:`Dnsmasq` and :class:`Scheme` + instance to utilize. + + Should normally not be used directly from calling code, instead use the provided factory `for_arguments`. + + :param hostapd: :class:`Hostapd` instance + :param dnsmasq: :class:`Dnsmasq` instance + :param scheme: :class:`Scheme` instance + """ + + self._logger = logging.getLogger(__name__) + + self.hostapd = hostapd + self.dnsmasq = dnsmasq + self.scheme = scheme + + def save(self, allow_overwrite=False): + """ + Saves all wrapped configurations. + :param allow_overwrite: whether to allow overwriting of existing configs, defaults to False + """ + + self.hostapd.save(allow_overwrite=allow_overwrite) + self.dnsmasq.save(allow_overwrite=allow_overwrite) + self.scheme.save(allow_overwrite=allow_overwrite) + + def delete(self): + """ Deletes all wrapped configurations. """ + + self.hostapd.delete() + self.dnsmasq.delete() + self.scheme.delete() + + def activate(self): + """ Activates the access point by activating all wrapped configurations. """ + + self.hostapd.activate() + self.scheme.activate() + self.dnsmasq.activate() + + def deactivate(self): + """ Deactivates the access point by deactivating all wrapped configurations. """ + + self.dnsmasq.deactivate() + self.scheme.deactivate() + self.hostapd.deactivate() + + @property + def name(self): + return self.hostapd.name + + @property + def interface(self): + return self.hostapd.interface + + def is_running(self): + """ Returns whether the access point is currently running (either hostap or dnsmasq) or not. """ + return self.hostapd.is_running() or self.dnsmasq.is_running() + + def __repr__(self): + return "AccessPoint(hostapd={hostapd!r}, dnsmasq={dnsmasq!r}, scheme={scheme!r})".format(**vars(self)) + diff --git a/wifi/exceptions.py b/wifi/exceptions.py index 3a75b37..b8102ea 100644 --- a/wifi/exceptions.py +++ b/wifi/exceptions.py @@ -1,6 +1,13 @@ -class ConnectionError(Exception): +class WifiError(Exception): + pass + +class ConnectionError(WifiError): pass -class InterfaceError(Exception): +class InterfaceError(WifiError): pass + + +class BindError(WifiError): + pass \ No newline at end of file diff --git a/wifi/scheme.py b/wifi/scheme.py index c1a5026..bf1d702 100644 --- a/wifi/scheme.py +++ b/wifi/scheme.py @@ -1,10 +1,23 @@ import re import itertools +import logging +import socket +import fcntl +import struct import wifi.subprocess_compat as subprocess from pbkdf2 import PBKDF2 from wifi.utils import ensure_file_exists -from wifi.exceptions import ConnectionError +from wifi.exceptions import * + + +def get_ip_address(ifname): + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + return socket.inet_ntoa(fcntl.ioctl( + s.fileno(), + 0x8915, # SIOCGIFADDR + struct.pack('256s', ifname[:15]) + )[20:24]) def configuration(cell, passkey=None): @@ -77,18 +90,26 @@ def for_file(cls, interfaces): 'interfaces': interfaces, }) - def __init__(self, interface, name, options=None): + def __init__(self, interface, name, type="dhcp", options=None): self.interface = interface self.name = name + self.type = type + + if options: + for k, v in options.items(): + if not isinstance(v, (list, tuple)): + options[k] = [v] self.options = options or {} + self.logger = logging.getLogger(__name__) + def __str__(self): """ Returns the representation of a scheme that you would need in the /etc/network/interfaces file. """ - iface = "iface {interface}-{name} inet dhcp".format(**vars(self)) - options = ''.join("\n {k} {v}".format(k=k, v=v) for k, v in self.options.items()) + iface = "iface {interface}-{name} inet {type}".format(**vars(self)) + options = ''.join("\n {k} {v}".format(k=k, v=v) for k in self.options.keys() for v in self.options[k]) return iface + options + '\n' def __repr__(self): @@ -124,13 +145,17 @@ def for_cell(cls, interface, name, cell, passkey=None): Intuits the configuration needed for a specific :class:`Cell` and creates a :class:`Scheme` for it. """ - return cls(interface, name, configuration(cell, passkey)) + return cls(interface, name, options=configuration(cell, passkey)) - def save(self): + def save(self, allow_overwrite=False): """ Writes the configuration to the :attr:`interfaces` file. """ - assert not self.find(self.interface, self.name), "This scheme already exists" + existing_scheme = self.find(self.interface, self.name) + if existing_scheme: + if not allow_overwrite: + raise RuntimeError("Scheme for interface %s named %s already exists and overwrite is forbidden" % (self.interface, self.name)) + existing_scheme.delete() with open(self.interfaces, 'a') as f: f.write('\n') @@ -140,14 +165,14 @@ def delete(self): """ Deletes the configuration from the :attr:`interfaces` file. """ - iface = "iface %s-%s inet dhcp" % (self.interface, self.name) + iface = "iface %s-%s inet %s" % (self.interface, self.name, self.type) content = '' with open(self.interfaces, 'r') as f: skip = False for line in f: if not line.strip(): skip = False - elif line.strip() == iface: + elif line.strip().startswith(iface): skip = True if not skip: content += line @@ -160,7 +185,7 @@ def iface(self): def as_args(self): args = list(itertools.chain.from_iterable( - ('-o', '{k}={v}'.format(k=k, v=v)) for k, v in self.options.items())) + ('-o', '{k}={v}'.format(k=k, v=v)) for k in self.options.keys() for v in self.options[k])) return [self.interface + '=' + self.iface] + args @@ -169,18 +194,36 @@ def activate(self): Connects to the network as configured in this scheme. """ - subprocess.check_output(['/sbin/ifdown', self.interface], stderr=subprocess.STDOUT) - ifup_output = subprocess.check_output(['/sbin/ifup'] + self.as_args(), stderr=subprocess.STDOUT) + self.deactivate() + try: + ifup_output = subprocess.check_output(['/sbin/ifup'] + self.as_args(), stderr=subprocess.STDOUT) + except subprocess.CalledProcessError as e: + self.logger.exception("Error while trying to connect to %s" % self.iface) + self.logger.error("Output: %s" % e.output) + raise InterfaceError("Failed to connect to %r: %s" % (self, e.message)) + ifup_output = ifup_output.decode('utf-8') return self.parse_ifup_output(ifup_output) + def deactivate(self): + """ + Disconnects from the network as configured in this scheme. + """ + + subprocess.check_output(['/sbin/ifdown', self.iface], stderr=subprocess.STDOUT) + def parse_ifup_output(self, output): - matches = bound_ip_re.search(output) - if matches: - return Connection(scheme=self, ip_address=matches.group('ip_address')) + if self.type == "dhcp": + matches = bound_ip_re.search(output) + if matches: + return Connection(scheme=self, ip_address=matches.group('ip_address')) + elif "already configured" in output: + return Connection(scheme=self, ip_address=get_ip_address(self.interface)) + else: + raise ConnectionError("Failed to connect to %r" % self) else: - raise ConnectionError("Failed to connect to %r" % self) + return Connection(scheme=self, ip_address=self.options["address"][0]) class Connection(object): @@ -192,8 +235,8 @@ def __init__(self, scheme, ip_address): self.ip_address = ip_address -scheme_re = re.compile(r'iface\s+(?P[^-]+)(?:-(?P\S+))?') - +# TODO: support other interfaces +scheme_re = re.compile(r'iface\s+(?Pwlan\d?)(?:-(?P\w+))?\s+inet\s+(?P\w+)') def extract_schemes(interfaces, scheme_class=Scheme): lines = interfaces.splitlines() @@ -206,15 +249,18 @@ def extract_schemes(interfaces, scheme_class=Scheme): match = scheme_re.match(line) if match: options = {} - interface, scheme = match.groups() + interface, scheme, type = match.groups() if not scheme or not interface: continue while lines and lines[0].startswith(' '): key, value = re.sub(r'\s{2,}', ' ', lines.pop(0).strip()).split(' ', 1) - options[key] = value + if not key in options: + options[key] = [] + options[key].append(value) - scheme = scheme_class(interface, scheme, options) + scheme = scheme_class(interface, scheme, type=type, options=options) yield scheme + diff --git a/wifi/utils.py b/wifi/utils.py index 7fa6084..45af1af 100644 --- a/wifi/utils.py +++ b/wifi/utils.py @@ -55,3 +55,7 @@ def ensure_file_exists(filename): """ if not os.path.exists(filename): open(filename, 'a').close() + +cidr_v4_pattern = r"(([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){3}([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])(\/(\d|[1-2]\d|3[0-2]))" +mac_addr_pattern = r"[a-fA-F0-9]{2}:[a-fA-F0-9]{2}:[a-fA-F0-9]{2}:[a-fA-F0-9]{2}:[a-fA-F0-9]{2}:[a-fA-F0-9]{2}" +