Skip to content

Commit

Permalink
#91 add tests for VRT
Browse files Browse the repository at this point in the history
  • Loading branch information
akorosov committed Mar 9, 2018
1 parent 323521c commit 0f2fc9d
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 9 deletions.
54 changes: 53 additions & 1 deletion nansat/tests/test_vrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,16 @@ def test_init_from_old__gdal_dataset(self):
self.assertIn('filename', list(vrt.dataset.GetMetadata().keys()))
self.assertIn('AREA_OR_POINT', vrt.dataset.GetMetadata())

def test_init_from_old__gdal_dataset2(self):
ds = gdal.Open(os.path.join(ntd.test_data_path, 'gcps.tif'))
with warnings.catch_warnings(record=True) as w:
vrt = VRT(ds)
self.assertEqual(w[0].category, NansatFutureWarning)
self.assertIsInstance(vrt.dataset, gdal.Dataset)
self.assertTrue(vrt.filename.startswith('/vsimem/'))
self.assertIn('filename', list(vrt.dataset.GetMetadata().keys()))
self.assertIn('AREA_OR_POINT', vrt.dataset.GetMetadata())

def test_init_from_old__vrt_dataset(self):
ds = gdal.Open(os.path.join(ntd.test_data_path, 'gcps.tif'))
with warnings.catch_warnings(record=True) as w:
Expand All @@ -432,11 +442,13 @@ def test_init_from_old__vrt_dataset(self):
def test_init_from_old__dataset_params(self):
ds = gdal.Open(os.path.join(ntd.test_data_path, 'gcps.tif'))
with warnings.catch_warnings(record=True) as w:
vrt = VRT(srcGeoTransform=(0, 1, 0, 0, 0, -1), srcRasterXSize=10, srcRasterYSize=20)
vrt = VRT(srcGeoTransform=(0, 1, 0, 0, 0, -1), srcRasterXSize=10, srcRasterYSize=20,
srcMetadata={'meta_key1': 'meta_value1'})
self.assertEqual(w[0].category, NansatFutureWarning)
self.assertIsInstance(vrt.dataset, gdal.Dataset)
self.assertEqual(vrt.dataset.RasterXSize, 10)
self.assertTrue(vrt.filename.startswith('/vsimem/'))
self.assertIn('meta_key1', vrt.dataset.GetMetadata())

def test_init_from_old__array(self):
a = np.random.randn(100,100)
Expand Down Expand Up @@ -508,5 +520,45 @@ def test_get_super_vrt(self):
self.assertEqual(vrt2.dataset.GetMetadataItem(str('AREA_OR_POINT')), 'Area')


def test_filename_warning(self):
vrt = VRT()
with warnings.catch_warnings(record=True) as w:
vrt_filename = vrt.fileName
self.assertEqual(vrt_filename, vrt.filename)

def test_get_sub_vrt0(self):
vrt1 = VRT()
vrt2 = vrt1.get_sub_vrt()
self.assertEqual(vrt1, vrt2)

def test_get_sub_vrt3(self):
vrt1 = VRT().get_super_vrt().get_super_vrt().get_super_vrt()
vrt2 = vrt1.get_sub_vrt(3)
self.assertEqual(vrt2.vrt, None)

def test_get_sub_vrt_steps_0(self):
vrt1 = VRT().get_super_vrt()
vrt2 = vrt1.get_sub_vrt(steps=0)
self.assertEqual(vrt1, vrt2)

def test_transform_points(self):
ds = gdal.Open(os.path.join(ntd.test_data_path, 'gcps.tif'))
vrt1 = VRT.from_gdal_dataset(ds, metadata=ds.GetMetadata())
vrt1.tps = True
lon, lat = vrt1.transform_points([1, 2, 3], [4, 5, 6])
self.assertTrue(np.allclose(lon, np.array([28.23549571, 28.24337106, 28.25126129])))
self.assertTrue(np.allclose(lat, np.array([71.52509848, 71.51913744, 71.51317568])))

def test_make_filename(self):
filename1 = VRT._make_filename()
filename2 = VRT._make_filename(extention='smth')
filename3 = VRT._make_filename(nomem=True)
self.assertTrue(filename1.startswith('/vsimem/'))
self.assertTrue(filename2.startswith('/vsimem/'))
self.assertTrue(filename2.endswith('.smth'))
self.assertTrue(os.path.exists(filename3))



if __name__ == "__main__":
unittest.main()
9 changes: 1 addition & 8 deletions nansat/vrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -1050,8 +1050,6 @@ def create_band(self, src, dst=None):
srcs = [src]
elif type(src) in [list, tuple]:
srcs = src
else:
raise ValueError('Wrong src type (%s)! Should be dict or list/tuple of dict'%type(src))

# Check if dst is given, or create empty dict
if dst is None:
Expand Down Expand Up @@ -1482,12 +1480,7 @@ def transform_points(self, col_vector, row_vector, dst2src=0,
lonlat = transformer.TransformPoints(dst2src, xy)[0]

# convert return to lon,lat vectors
lonlat = np.array(lonlat)
if lonlat.shape[0] > 0:
lon_vector = lonlat[:, 0]
lat_vector = lonlat[:, 1]
else:
lon_vector, lat_vector = [], []
lon_vector, lat_vector, _ = np.array(lonlat).T

return lon_vector, lat_vector

Expand Down

0 comments on commit 0f2fc9d

Please sign in to comment.