diff --git a/nansat/node.py b/nansat/node.py index b400504d2..8d5a131e9 100644 --- a/nansat/node.py +++ b/nansat/node.py @@ -169,31 +169,6 @@ def delNode(self, tag, options=None): for i in sorted(ideleted, reverse=True): self.children.pop(i) - def find_dom_child(self, dom, tagName, n=0): - '''Recoursively find child of the dom''' - children = dom.childNodes - theChild = None - - chn = 0 - for child in children: - print(child, child.nodeType, chn) - if child.nodeType == 1: - print(child.tagName) - if str(child.tagName) == tagName: - print(child.tagName, tagName, 'OK') - if chn == n: - theChild = child - chn += 1 - - if theChild is not None: - break - - if child.hasChildNodes(): - print('has childs') - theChild = self.find_dom_child(child, tagName, n) - - return theChild - def nodeList(self, tag): ''' Produce a list of subnodes with the same tag. diff --git a/nansat/pointbrowser.py b/nansat/pointbrowser.py index 29c784205..c747d8949 100644 --- a/nansat/pointbrowser.py +++ b/nansat/pointbrowser.py @@ -19,9 +19,9 @@ import numpy as np try: - if 'DISPLAY' not in os.environ: - import matplotlib; matplotlib.use('Agg') import matplotlib + if 'DISPLAY' not in os.environ: + matplotlib.use('Agg') import matplotlib.pyplot as plt except ImportError: MATPLOTLIB_IS_INSTALLED = False @@ -63,11 +63,11 @@ class PointBrowser(): lines = None coordinates = None - def __init__(self, data, fmt='x-k', **kwargs): + def __init__(self, data, fmt='x-k', force_interactive=True, **kwargs): """Open figure with imshow and colorbar""" if not MATPLOTLIB_IS_INSTALLED: raise ImportError(' Matplotlib is not installed ') - if not matplotlib.is_interactive(): + if force_interactive and not matplotlib.is_interactive(): raise SystemError(''' Python is started with -pylab option, transect will not work. Please restart python without -pylab.''') diff --git a/nansat/tests/test_nansat.py b/nansat/tests/test_nansat.py index 39440e37f..4fec61758 100644 --- a/nansat/tests/test_nansat.py +++ b/nansat/tests/test_nansat.py @@ -748,22 +748,15 @@ def test_get_transect_data(self): self.assertEqual(type(t['lat']), np.ndarray) self.assertEqual(type(t['lon']), np.ndarray) - @unittest.skipUnless(MATPLOTLIB_IS_INSTALLED and 'DISPLAY' in os.environ, 'Matplotlib is required') - def test_digitize_points(self): - ''' shall return empty array in non interactive mode ''' - for backend in matplotlib.rcsetup.interactive_bk: - # Find a supported interactive backend - try: - plt.switch_backend(backend) - break; - except: - pass - plt.ion() - n1 = Nansat(self.test_file_gcps, log_level=40) - points = n1.digitize_points(1) - - self.assertEqual(len(points), 0) - plt.ioff() + @patch('nansat.nansat.PointBrowser') + def test_digitize_points(self, mock_PointBrowser): + """ shall create PointBrowser and call PointBrowser.get_points() """ + value = 'points' + mock_PointBrowser().get_points.return_value = value + n = Nansat(self.test_file_gcps, log_level=40) + points = n.digitize_points(1) + self.assertTrue(mock_PointBrowser.called_once()) + self.assertEqual(points, value) def test_crop(self): n1 = Nansat(self.test_file_gcps, log_level=40) diff --git a/nansat/tests/test_node.py b/nansat/tests/test_node.py index b9bd408bc..4a2ee3a17 100644 --- a/nansat/tests/test_node.py +++ b/nansat/tests/test_node.py @@ -23,10 +23,14 @@ def test_creation(self): tag = 'Root' value = ' Value ' anAttr = 'elValue' + new_value = 'New Value' node = Node(tag, value=value, anAttr=anAttr) self.assertEqual(node.tag, tag) self.assertDictEqual(node.attributes, {'anAttr': anAttr}) self.assertEqual(node.value, value.strip()) + self.assertEqual(node[tag], value.strip()) + node[tag] = new_value + self.assertEqual(node.value, new_value) def test_getAttributeList(self): tag = 'Root' @@ -145,3 +149,12 @@ def test_search_node(self): root += firstLevel3 self.assertEqual(root.node(firstLevelTag,0), firstLevel) self.assertEqual(root.node(firstLevelTag,1), firstLevel2) + + def test_str(self): + tag = 'Root' + value = 'Value' + node = Node(tag, value=value) + self.assertEqual(str(node), '%s\n value: [%s]' % (tag, value)) + +if __name__ == "__main__": + unittest.main() diff --git a/nansat/tests/test_pointbrowser.py b/nansat/tests/test_pointbrowser.py index ad9a85e6a..9f60adfa9 100644 --- a/nansat/tests/test_pointbrowser.py +++ b/nansat/tests/test_pointbrowser.py @@ -16,11 +16,10 @@ from nansat.pointbrowser import PointBrowser try: - if 'DISPLAY' not in os.environ: - import matplotlib; matplotlib.use('Agg') import matplotlib + if 'DISPLAY' not in os.environ: + matplotlib.use('Agg') import matplotlib.pyplot as plt - plt.switch_backend('qt5agg') except ImportError: MATPLOTLIB_IS_INSTALLED = False else: @@ -28,17 +27,11 @@ class PointBrowserTest(unittest.TestCase): - @unittest.skipUnless(MATPLOTLIB_IS_INSTALLED and 'DISPLAY' in os.environ, 'Matplotlib is required') - def setUp(self): - plt.switch_backend('qt5agg') - plt.ion() - data = np.ndarray(shape=(4, 4), dtype=float, order='F') - self.point = PointBrowser(data) - def test_onclick(self): + point_browser = PointBrowser(np.zeros((4, 4)), force_interactive=False) event = Event(xdata=0, ydata=0, key=None) - self.point.onclick(event) - t = self.point._convert_coordinates()[0] + point_browser.onclick(event) + t = point_browser._convert_coordinates()[0] self.assertIsInstance(t, np.ndarray) xPoints = t[0] self.assertIsInstance(xPoints, np.ndarray) @@ -48,20 +41,26 @@ def test_onclick(self): self.assertEqual(yPoints[0], event.ydata, "y coordinates is set wrong") def test_onclick_multilines(self): + point_browser = PointBrowser(np.zeros((4, 4)), force_interactive=False) events = [] events.append(Event(xdata=0, ydata=0, key=None)) events.append(Event(xdata=1, ydata=0, key=None)) events.append(Event(xdata=2, ydata=2, key='AnyKeyButZorAltZ')) events.append(Event(xdata=2, ydata=3, key=None)) for event in events: - self.point.onclick(event) - points = self.point._convert_coordinates() + point_browser.onclick(event) + points = point_browser._convert_coordinates() self.assertEqual(len(points), 2, 'There should be two transects') self.assertTrue(np.alltrue(points[0] == np.array([[0, 1], [0, 0]])), 't1 is not correct') self.assertTrue(np.alltrue(points[1] == np.array([[2, 2], [2, 3]])), 't2 is not correct') + @unittest.skipIf('DISPLAY' in os.environ, 'Non-interactive mode is required') + def test_fail_non_interactive(self): + with self.assertRaises(SystemError): + point_browser = PointBrowser(np.zeros((4, 4))) + class Event: def __init__(self, **kwds): self.__dict__.update(kwds) diff --git a/nansat/tests/test_vrt.py b/nansat/tests/test_vrt.py index 96f83a368..54d7f88fe 100644 --- a/nansat/tests/test_vrt.py +++ b/nansat/tests/test_vrt.py @@ -452,6 +452,16 @@ def test_init_from_old__gdal_dataset(self): self.assertIn('filename', list(vrt.dataset.GetMetadata().keys())) self.assertIn('AREA_OR_POINT', vrt.dataset.GetMetadata()) + def test_init_from_old__gdal_dataset2(self): + ds = gdal.Open(os.path.join(ntd.test_data_path, 'gcps.tif')) + with warnings.catch_warnings(record=True) as w: + vrt = VRT(ds) + self.assertEqual(w[0].category, NansatFutureWarning) + self.assertIsInstance(vrt.dataset, gdal.Dataset) + self.assertTrue(vrt.filename.startswith('/vsimem/')) + self.assertIn('filename', list(vrt.dataset.GetMetadata().keys())) + self.assertIn('AREA_OR_POINT', vrt.dataset.GetMetadata()) + def test_init_from_old__vrt_dataset(self): ds = gdal.Open(os.path.join(ntd.test_data_path, 'gcps.tif')) with warnings.catch_warnings(record=True) as w: @@ -463,11 +473,13 @@ def test_init_from_old__vrt_dataset(self): def test_init_from_old__dataset_params(self): ds = gdal.Open(os.path.join(ntd.test_data_path, 'gcps.tif')) with warnings.catch_warnings(record=True) as w: - vrt = VRT(srcGeoTransform=(0, 1, 0, 0, 0, -1), srcRasterXSize=10, srcRasterYSize=20) + vrt = VRT(srcGeoTransform=(0, 1, 0, 0, 0, -1), srcRasterXSize=10, srcRasterYSize=20, + srcMetadata={'meta_key1': 'meta_value1'}) self.assertEqual(w[0].category, NansatFutureWarning) self.assertIsInstance(vrt.dataset, gdal.Dataset) self.assertEqual(vrt.dataset.RasterXSize, 10) self.assertTrue(vrt.filename.startswith('/vsimem/')) + self.assertIn('meta_key1', vrt.dataset.GetMetadata()) def test_init_from_old__array(self): a = np.random.randn(100,100) @@ -548,5 +560,44 @@ def test_get_super_vrt_and_copy(self): self.assertFalse(data is None) self.assertTrue(np.all(data == array)) + def test_filename_warning(self): + vrt = VRT() + with warnings.catch_warnings(record=True) as w: + vrt_filename = vrt.fileName + self.assertEqual(vrt_filename, vrt.filename) + + def test_get_sub_vrt0(self): + vrt1 = VRT() + vrt2 = vrt1.get_sub_vrt() + self.assertEqual(vrt1, vrt2) + + def test_get_sub_vrt3(self): + vrt1 = VRT().get_super_vrt().get_super_vrt().get_super_vrt() + vrt2 = vrt1.get_sub_vrt(3) + self.assertEqual(vrt2.vrt, None) + + def test_get_sub_vrt_steps_0(self): + vrt1 = VRT().get_super_vrt() + vrt2 = vrt1.get_sub_vrt(steps=0) + self.assertEqual(vrt1, vrt2) + + def test_transform_points(self): + ds = gdal.Open(os.path.join(ntd.test_data_path, 'gcps.tif')) + vrt1 = VRT.from_gdal_dataset(ds, metadata=ds.GetMetadata()) + vrt1.tps = True + lon, lat = vrt1.transform_points([1, 2, 3], [4, 5, 6]) + self.assertTrue(np.allclose(lon, np.array([28.23549571, 28.24337106, 28.25126129]))) + self.assertTrue(np.allclose(lat, np.array([71.52509848, 71.51913744, 71.51317568]))) + + def test_make_filename(self): + filename1 = VRT._make_filename() + filename2 = VRT._make_filename(extention='smth') + filename3 = VRT._make_filename(nomem=True) + self.assertTrue(filename1.startswith('/vsimem/')) + self.assertTrue(filename2.startswith('/vsimem/')) + self.assertTrue(filename2.endswith('.smth')) + self.assertTrue(os.path.exists(filename3)) + + if __name__ == "__main__": unittest.main() diff --git a/nansat/vrt.py b/nansat/vrt.py index 507dc14f9..9535d419a 100644 --- a/nansat/vrt.py +++ b/nansat/vrt.py @@ -1064,8 +1064,6 @@ def create_band(self, src, dst=None): srcs = [src] elif type(src) in [list, tuple]: srcs = src - else: - raise ValueError('Wrong src type (%s)! Should be dict or list/tuple of dict'%type(src)) # Check if dst is given, or create empty dict if dst is None: @@ -1496,12 +1494,7 @@ def transform_points(self, col_vector, row_vector, dst2src=0, lonlat = transformer.TransformPoints(dst2src, xy)[0] # convert return to lon,lat vectors - lonlat = np.array(lonlat) - if lonlat.shape[0] > 0: - lon_vector = lonlat[:, 0] - lat_vector = lonlat[:, 1] - else: - lon_vector, lat_vector = [], [] + lon_vector, lat_vector, _ = np.array(lonlat).T return lon_vector, lat_vector