gdal/autotest/gcore/numpy_rw_multidim.py

295 строки
8.7 KiB
Python
Исполняемый файл

#!/usr/bin/env pytest
###############################################################################
# $Id$
#
# Project: GDAL/OGR Test Suite
# Purpose: Test multidimensional support with numpy
# Author: Even Rouault <even.rouault@spatialys.com>
#
###############################################################################
# Copyright (c) 2019, Even Rouault <even.rouault@spatialys.com>
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
###############################################################################
import struct
import gdaltest
import pytest
from osgeo import gdal
###############################################################################
# Verify that NumPy support (osgeo.gdal_array) can be loaded and that the NUMPY driver is found.
def test_numpy_rw_multidim_init():
    """Register the NUMPY driver and store it on ``gdaltest.numpy_drv``."""
    # Reset first: the other tests skip when this is None, so a skip below
    # propagates to them.
    gdaltest.numpy_drv = None

    # Importing gdal_array is what allows the NUMPY driver to be registered.
    pytest.importorskip("osgeo.gdal_array")
    gdal.AllRegister()

    numpy_driver = gdal.GetDriverByName("NUMPY")
    gdaltest.numpy_drv = numpy_driver
    assert gdaltest.numpy_drv is not None, "NUMPY driver not found!"
###############################################################################
def test_numpy_rw_multidim_readasarray_writearray():
    """Round-trip a 2x3 Byte array through WriteArray()/ReadAsArray() (MEM driver)."""
    if gdaltest.numpy_drv is None:
        pytest.skip()

    import numpy as np

    mem_ds = gdal.GetDriverByName("MEM").CreateMultiDimensional("myds")
    root = mem_ds.GetRootGroup()
    dims = [
        root.CreateDimension("dim0", None, None, 2),
        root.CreateDimension("dim1", None, None, 3),
    ]
    md_array = root.CreateMDArray(
        "myarray", dims, gdal.ExtendedDataType.Create(gdal.GDT_Byte)
    )
    assert md_array

    expected = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8)
    assert md_array.WriteArray(expected) == gdal.CE_None

    readback = md_array.ReadAsArray()
    assert readback.shape == (2, 3)
    assert np.array_equal(readback, expected)

    # Also check with the raw (non-NumPy) Read() so that buffer-striding
    # issues cannot hide behind the NumPy code path.
    fmt = "B" * 6
    assert struct.unpack(fmt, md_array.Read()) == (1, 2, 3, 4, 5, 6)

    # Explicit row-major strides: same order as the default read.
    row_major = struct.unpack(fmt, md_array.Read(buffer_stride=[3, 1]))
    assert row_major == (1, 2, 3, 4, 5, 6)

    # Column-major strides: the values come out transposed.
    column_major = struct.unpack(fmt, md_array.Read(buffer_stride=[1, 2]))
    assert column_major == (1, 4, 2, 5, 3, 6)
###############################################################################
def test_numpy_rw_multidim_numpy_array_as_dataset():
    """Open NumPy arrays of several dtypes as GDAL multidimensional datasets.

    Each 2x3 array is wrapped with gdal_array.OpenMultiDimensionalNumPyArray()
    and must read back unchanged from the root group's "array" MDArray.
    """
    if gdaltest.numpy_drv is None:
        pytest.skip()

    import numpy as np
    from osgeo import gdal_array

    # np.complex64 replaces the former np.cfloat entry. That alias was removed
    # in NumPy 2.0 (AttributeError). In NumPy 1.x it was merely a synonym of
    # np.complex128/np.cdouble, so single-precision complex data was never
    # exercised.
    for typ in (
        np.int8,
        np.uint8,
        np.uint16,
        np.int16,
        np.uint32,
        np.int32,
        np.float32,
        np.float64,
        np.complex64,
        np.cdouble,
    ):
        ar = np.array([[1, 2, 3], [4, 5, 6]], dtype=typ)
        ds = gdal_array.OpenMultiDimensionalNumPyArray(ar)
        assert ds, typ
        rg = ds.GetRootGroup()
        assert rg, typ
        myarray = rg.OpenMDArray("array")
        assert myarray, typ
        # The message names the dtype that failed the round trip.
        assert np.array_equal(myarray.ReadAsArray(), ar), typ
###############################################################################
def test_numpy_rw_multidim_readasarray_writearray_negative_strides():
    """Write a NumPy view with negative strides and check what gets stored."""
    if gdaltest.numpy_drv is None:
        pytest.skip()

    import numpy as np

    mem_ds = gdal.GetDriverByName("MEM").CreateMultiDimensional("myds")
    root = mem_ds.GetRootGroup()
    dims = [
        root.CreateDimension("dim0", None, None, 2),
        root.CreateDimension("dim1", None, None, 3),
    ]
    md_array = root.CreateMDArray(
        "myarray", dims, gdal.ExtendedDataType.Create(gdal.GDT_Byte)
    )
    assert md_array

    # Reversing both axes yields a view with negative strides.
    flipped = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8)[::-1, ::-1]
    assert md_array.WriteArray(flipped) == gdal.CE_None

    readback = md_array.ReadAsArray()
    assert readback.shape == (2, 3)
    assert np.array_equal(readback, flipped)

    # Cross-check with the raw (non-NumPy) Read() so that buffer-striding
    # issues cannot hide behind the NumPy code path.
    fmt = "B" * 6
    assert struct.unpack(fmt, md_array.Read()) == (6, 5, 4, 3, 2, 1)

    # Explicit row-major strides: same order as the default read.
    row_major = struct.unpack(fmt, md_array.Read(buffer_stride=[3, 1]))
    assert row_major == (6, 5, 4, 3, 2, 1)

    # Column-major strides: the values come out transposed.
    column_major = struct.unpack(fmt, md_array.Read(buffer_stride=[1, 2]))
    assert column_major == (6, 3, 5, 2, 4, 1)
###############################################################################
def test_numpy_rw_multidim_numpy_array_as_dataset_negative_strides():
    """Open negatively-strided NumPy views of several dtypes as GDAL datasets.

    Each 2x3 array is reversed along both axes before being wrapped with
    gdal_array.OpenMultiDimensionalNumPyArray(). It must read back unchanged
    from the root group's "array" MDArray.
    """
    if gdaltest.numpy_drv is None:
        pytest.skip()

    import numpy as np
    from osgeo import gdal_array

    # np.complex64 replaces the former np.cfloat entry. That alias was removed
    # in NumPy 2.0 (AttributeError). In NumPy 1.x it was merely a synonym of
    # np.complex128/np.cdouble, so single-precision complex data was never
    # exercised.
    for typ in (
        np.int8,
        np.uint8,
        np.uint16,
        np.int16,
        np.uint32,
        np.int32,
        np.float32,
        np.float64,
        np.complex64,
        np.cdouble,
    ):
        ar = np.array([[1, 2, 3], [4, 5, 6]], dtype=typ)
        ar = ar[::-1, ::-1]  # Test negative strides
        ds = gdal_array.OpenMultiDimensionalNumPyArray(ar)
        assert ds, typ
        rg = ds.GetRootGroup()
        assert rg, typ
        myarray = rg.OpenMDArray("array")
        assert myarray, typ
        # The message names the dtype that failed the round trip.
        assert np.array_equal(myarray.ReadAsArray(), ar), typ
###############################################################################
def test_numpy_rw_multidim_compound_datatype():
    """Map a GDAL compound type to a NumPy structured dtype and round-trip data."""
    if gdaltest.numpy_drv is None:
        pytest.skip()

    import numpy as np
    from osgeo import gdal_array

    mem_ds = gdal.GetDriverByName("MEM").CreateMultiDimensional("myds")
    root = mem_ds.GetRootGroup()
    dim = root.CreateDimension("dim0", None, None, 2)

    # 8-byte record: Int16 "x" at offset 0, Int32 "y" at offset 4.
    x_comp = gdal.EDTComponent.Create(
        "x", 0, gdal.ExtendedDataType.Create(gdal.GDT_Int16)
    )
    y_comp = gdal.EDTComponent.Create(
        "y", 4, gdal.ExtendedDataType.Create(gdal.GDT_Int32)
    )
    compound_type = gdal.ExtendedDataType.CreateCompound("mytype", 8, [x_comp, y_comp])
    md_array = root.CreateMDArray("myarray", [dim], compound_type)
    assert md_array

    # The NumPy structured dtype must mirror the compound layout exactly.
    np_dtype = gdal_array.ExtendedDataTypeToNumPyDataType(compound_type)
    assert np_dtype.itemsize == 8
    assert np_dtype.names == ("x", "y")
    assert np_dtype.fields["x"] == (np.int16, 0)
    assert np_dtype.fields["y"] == (np.int32, 4)

    # Raw write. Native-mode "hi" packing pads after the int16, so the int32
    # lands at offset 4 (assuming the usual 4-byte int alignment).
    raw_records = struct.pack("hi" * 2, 32767, 1000000, -32768, -1000000)
    assert md_array.Write(raw_records) == gdal.CE_None

    first_read = md_array.ReadAsArray()
    assert first_read.dtype == np_dtype
    expected = np.array(
        [(32767, 1000000), (-32768, -1000000)], dtype=first_read.dtype
    )
    assert np.array_equal(first_read, expected)

    # Write a NumPy structured array and read it back.
    structured = np.array([(-32768, -1000000), (-32767, 1000000)], dtype=np_dtype)
    assert md_array.WriteArray(structured) == gdal.CE_None
    assert np.array_equal(md_array.ReadAsArray(), structured)
###############################################################################
@pytest.mark.parametrize(
    "datatype",
    [
        gdal.GDT_Byte,
        gdal.GDT_Int16,
        gdal.GDT_UInt16,
        gdal.GDT_Int32,
        gdal.GDT_UInt32,
        gdal.GDT_Float32,
        gdal.GDT_Float64,
        gdal.GDT_CInt16,
        gdal.GDT_CInt32,
        gdal.GDT_CFloat32,
        gdal.GDT_CFloat64,
    ],
    ids=gdal.GetDataTypeName,
)
def test_numpy_rw_multidim_datatype(datatype):
    """Write [0, 1] into a 1-D array of *datatype* and read it back.

    The read is checked twice: once into a freshly allocated result, and once
    into a caller-supplied ``buf_obj``.
    """
    if gdaltest.numpy_drv is None:
        pytest.skip()

    import numpy as np

    mem_ds = gdal.GetDriverByName("MEM").CreateMultiDimensional("myds")
    root = mem_ds.GetRootGroup()
    dim = root.CreateDimension("dim0", None, None, 2)
    md_array = root.CreateMDArray(
        "myarray", [dim], gdal.ExtendedDataType.Create(datatype)
    )
    assert md_array

    # The source is always uint16. The array stores *datatype*, so the values
    # are converted on write.
    source = np.arange(2, dtype=np.uint16)
    assert md_array.WriteArray(source) == gdal.CE_None

    fresh = md_array.ReadAsArray()
    assert np.array_equal(fresh, source)

    # Read again into a pre-allocated zeroed buffer of the same shape/dtype.
    into_buffer = md_array.ReadAsArray(buf_obj=np.zeros(fresh.shape, fresh.dtype))
    assert np.array_equal(into_buffer, source)