"""Geospatial data parser & utilities module."""
import functools
import json
import logging
import math
import os
import pickle
import cv2 as cv
import numpy as np
import shapely
import torch
import tqdm
import thelper.tasks
import thelper.utils
from thelper.data import Dataset, ImageFolderDataset
from thelper.data.geo.utils import parse_raster_metadata
try:
from osgeo import gdal, ogr, osr
except (ImportError, ModuleNotFoundError):
import gdal
import ogr
import osr
# module-level logger, named after this module (used by the parsers defined below)
logger = logging.getLogger(__name__)
class VectorCropDataset(Dataset):
    """Abstract dataset used to combine geojson vector data and rasters.

    Vector features are parsed from a geojson file and reprojected to a target SRS. A 'cleaner'
    function then flags each of them as 'clean' (or not), and a 'cropper' function converts them
    into crop 'samples'. At runtime, each sample is turned into a raster crop (reprojected from
    all rasters that intersect it) and a binary mask in which the sample's features are burned.
    Unless ``force_parse`` is set, parsed features and crops are cached as pickle files in the
    folder of the vector file.
    """

    def __init__(self, raster_path, vector_path, px_size=None, skew=None,
                 allow_outlying_vectors=True, clip_outlying_vectors=True,
                 vector_area_min=0.0, vector_area_max=float("inf"),
                 vector_target_prop=None, feature_buffer=None, master_roi=None,
                 srs_target="3857", raster_key="raster", mask_key="mask",
                 cleaner=None, cropper=None, force_parse=False,
                 reproj_rasters=False, reproj_all_cpus=True,
                 keep_rasters_open=True, transforms=None):
        """Validates the parser settings, then parses rasters, features and crops.

        Args:
            raster_path: path to the raster file(s); glob patterns are allowed.
            vector_path: path to the geojson file containing the vector features.
            px_size: pixel size (resolution) of the crops, as a scalar or a two-element
                list/tuple; defaults to (1.0, 1.0).
            skew: pixel skew of the crops, as a scalar or a two-element list/tuple; defaults
                to (0.0, 0.0).
            allow_outlying_vectors: forwarded to the geojson parser.
            clip_outlying_vectors: forwarded to the geojson parser.
            vector_area_min: minimum feature area for the default cleaner to flag it as clean.
            vector_area_max: maximum feature area for the default cleaner to flag it as clean.
            vector_target_prop: dictionary of property name-value pairs that a feature must
                match for the default cleaner to flag it as clean.
            feature_buffer: strictly positive buffer forwarded to the ROI computation of the
                default cropper.
            master_roi: path to a geojson/shapefile or polygon object used to restrict the
                raster coverage.
            srs_target: target SRS, as an EPSG code (int, "xxxx" or "EPSG:xxxx" string) or as
                an ``osr.SpatialReference`` object.
            raster_key: key under which the raster crop is stored in the loaded samples.
            mask_key: key under which the feature mask is stored in the loaded samples.
            cleaner: callable that receives the parsed feature list and returns it with a
                'clean' flag set on each feature; defaults to ``_default_feature_cleaner``.
            cropper: callable ``(features, rasters_data, coverage, srs_target) -> samples``;
                defaults to ``_default_feature_cropper``.
            force_parse: if True, cached features/crops are neither read nor written.
            reproj_rasters: forwarded to the raster parser.
            reproj_all_cpus: if True, crop reprojection uses the ``NUM_THREADS=ALL_CPUS`` option.
            keep_rasters_open: forwarded to the raster file opener.
            transforms: transformation(s) applied to each loaded sample.
        """
        import thelper.data.geo as geo
        # before anything else, create a hash to cache parsed data
        # note: 'vars()' returns every local defined so far (all constructor arguments plus the
        # 'geo' module); adding a local above this point would change the hash (and cache names)
        cache_hash = thelper.utils.get_params_hash(
            {k: v for k, v in vars().items() if not k.startswith("_") and k != "self"}) if not force_parse else None
        assert isinstance(raster_path, str), "raster file/folder path should be given as string"
        assert isinstance(vector_path, str), "vector file/folder path should be given as string"
        self.raster_path = raster_path
        self.vector_path = vector_path
        assert px_size is None or \
            (isinstance(px_size, (list, tuple)) and all([isinstance(i, (float, int)) for i in px_size]) and len(px_size) == 2) or \
            isinstance(px_size, (float, int)), "pixel size (resolution) must be float/int or list/tuple"
        if px_size is None:
            self.px_size = (1.0, 1.0)
        elif isinstance(px_size, (list, tuple)):
            self.px_size = (float(px_size[0]), float(px_size[1]))
        else:
            self.px_size = (float(px_size), float(px_size))
        assert skew is None or \
            (isinstance(skew, (list, tuple)) and all([isinstance(i, (float, int)) for i in skew]) and len(skew) == 2) or \
            isinstance(skew, (float, int)), "pixel skew must be float/int or list/tuple"
        if skew is None:
            self.skew = (0.0, 0.0)
        elif isinstance(skew, (list, tuple)):
            self.skew = (float(skew[0]), float(skew[1]))
        else:
            self.skew = (float(skew), float(skew))
        assert isinstance(allow_outlying_vectors, bool), "unexpected flag type"
        assert isinstance(clip_outlying_vectors, bool), "unexpected flag type"
        assert isinstance(force_parse, bool), "unexpected flag type"
        assert isinstance(reproj_rasters, bool), "unexpected flag type"
        assert isinstance(reproj_all_cpus, bool), "unexpected flag type"
        assert isinstance(keep_rasters_open, bool), "unexpected flag type"
        self.allow_outlying = allow_outlying_vectors
        self.clip_outlying = clip_outlying_vectors
        self.force_parse = force_parse
        self.reproj_rasters = reproj_rasters
        self.reproj_all_cpus = reproj_all_cpus
        self.keep_rasters_open = keep_rasters_open
        assert isinstance(vector_area_min, (float, int)) and vector_area_min >= 0, \
            "min surface filter value must be >= 0"
        assert isinstance(vector_area_max, (float, int)) and vector_area_max >= vector_area_min, \
            "max surface filter value must be greater than minimum surface value"
        self.area_min = float(vector_area_min)
        self.area_max = float(vector_area_max)
        assert vector_target_prop is None or isinstance(vector_target_prop, dict), \
            "feature target props should be specified as dictionary of property name-value pairs for search"
        self.target_prop = {} if vector_target_prop is None else vector_target_prop
        assert feature_buffer is None or (isinstance(feature_buffer, (int, float)) and feature_buffer > 0), \
            "feature roi 'buffer' value should be strictly positive int/float"
        self.feature_buffer = feature_buffer
        assert isinstance(master_roi, (str, shapely.geometry.polygon.Polygon,
                                       shapely.geometry.multipolygon.MultiPolygon)) or master_roi is None, \
            "invalid master roi (should be path to geojson/shapefile or polygon object)"
        assert isinstance(srs_target, (str, int, osr.SpatialReference)), \
            "target SRS must be given as EPSG int/str or as osr.SpatialReference"
        self.srs_target = srs_target
        if isinstance(self.srs_target, (str, int)):
            # EPSG codes may be given as 'EPSG:xxxx' or 'xxxx' strings, or directly as ints
            if isinstance(self.srs_target, str):
                self.srs_target = int(self.srs_target.replace("EPSG:", ""))
            srs_target_obj = osr.SpatialReference()
            srs_target_obj.ImportFromEPSG(self.srs_target)
            self.srs_target = srs_target_obj
        self.master_roi = geo.utils.parse_roi(master_roi, srs_target=self.srs_target) \
            if isinstance(master_roi, str) else master_roi
        assert isinstance(raster_key, str), "raster key must be given as string"
        self.raster_key = raster_key
        assert isinstance(mask_key, str), "mask key must be given as string"
        self.mask_key = mask_key
        super().__init__(transforms=transforms)
        self.rasters_data, self.coverage = self._parse_rasters(self.raster_path, self.srs_target, reproj_rasters)
        if self.master_roi is not None:
            self.coverage = self.coverage.intersection(self.master_roi)
        if cleaner is None:
            cleaner = functools.partial(self._default_feature_cleaner, area_min=self.area_min,
                                        area_max=self.area_max, target_prop=self.target_prop)
        self.features = self._parse_features(self.vector_path, self.srs_target, self.coverage, cache_hash,
                                             self.allow_outlying, self.clip_outlying, cleaner)
        if cropper is None:
            cropper = functools.partial(self._default_feature_cropper, px_size=self.px_size,
                                        skew=self.skew, feature_buffer=self.feature_buffer)
        self.samples = self._parse_crops(cropper, self.vector_path, cache_hash)
        # all keys already in sample dicts should be 'meta'; mask & raster will be added later
        # note: keys are de-duplicated in first-seen order so that the task is identical across runs
        meta_keys = list(dict.fromkeys(key for s in self.samples for key in s))
        if self.mask_key not in meta_keys:
            meta_keys.append(self.mask_key)
        # create default task without gt specification (this is a pretty basic parser)
        self.task = thelper.tasks.Task(input_key=self.raster_key, meta_keys=meta_keys)
        self.display_debug = False  # for internal debugging purposes only

    @staticmethod
    def _default_feature_cleaner(features, area_min, area_max, target_prop=None):
        """Flags geometric features as 'clean' based on some criteria (may be modified in derived classes).

        A feature is flagged clean when its geometry area lies in ``[area_min, area_max]`` and, if
        ``target_prop`` is given and the feature has a 'properties' dict, when all target
        name-value pairs are found in its properties. The feature dicts are updated in place and
        the same list is returned.
        """
        # note: we use a flag here instead of removing bad features so that end-users can still use them if needed
        for feature in tqdm.tqdm(features, desc="cleaning up features"):
            assert isinstance(feature, dict) and "clean" not in feature
            feature["clean"] = True
            if target_prop is not None and "properties" in feature and isinstance(feature["properties"], dict):
                props = feature["properties"]
                if not all(k in props and props[k] == v for k, v in target_prop.items()):
                    feature["clean"] = False
            if not (area_min <= feature["geometry"].area <= area_max):
                feature["clean"] = False
        return features

    @staticmethod
    def _default_feature_cropper(features, rasters_data, coverage, srs_target, px_size, skew, feature_buffer):
        """Returns the samples for a set of features (may be modified in derived classes).

        One sample is generated per 'clean' feature, using the ROI computed around its geometry.
        Each sample lists the rasters that intersect the ROI and all features (clean or not) whose
        geometry intersects it. Note: 'coverage' is part of the cropper interface but unused here.
        """
        # note: default behavior = just center on the feature, and pad if required by user
        import thelper.data.geo as geo
        samples = []
        clean_feats = [f for f in features if f["clean"]]
        srs_target_wkt = srs_target.ExportToWkt()
        for feature in tqdm.tqdm(clean_feats, desc="validating crop candidates"):
            assert feature["clean"]  # should not get here with bad features
            roi, roi_tl, roi_br, crop_width, crop_height = \
                geo.utils.get_feature_roi(feature["geometry"], px_size, skew, feature_buffer)
            # test all raster regions that touch the selected feature
            raster_hits = [raster_idx for raster_idx, raster_data in enumerate(rasters_data)
                           if raster_data["target_roi"].intersects(roi)]
            # make list of all features that are included in the roi
            # note: no centroid-distance pre-filter here, as it would drop large geometries (e.g.
            # long lines) whose centroid lies far away from the focal feature but that cross the roi
            roi_features = [f for f in features if f["geometry"].intersects(roi)]
            # prepare actual 'sample' for crop generation at runtime
            samples.append({
                "features": roi_features,
                "focal": feature,
                "roi": roi,
                "roi_tl": roi_tl,
                "roi_br": roi_br,
                "raster_hits": raster_hits,
                "crop_width": crop_width,
                "crop_height": crop_height,
                "geotransform": np.asarray((roi_tl[0], px_size[0], skew[0],
                                            roi_tl[1], skew[1], px_size[1])),
                "srs": srs_target_wkt,
            })
        return samples

    @staticmethod
    def _parse_rasters(path, srs, reproj_rasters):
        """Parses rasters (geotiffs) and returns metadata/coverage information.

        All rasters must share the same band count and data type. Coordinate transformations
        to/from the target SRS are added to each raster's metadata dict.
        """
        import thelper.data.geo as geo
        logger.info(f"parsing rasters from path '{path}'...")
        raster_paths = thelper.utils.get_file_paths(path, ".", allow_glob=True)
        rasters_data, coverage = geo.utils.parse_rasters(raster_paths, srs, reproj_rasters)
        assert rasters_data, f"could not find any usable rasters at '{raster_paths}'"
        logger.debug(f"rasters total coverage area = {coverage.area:.2f}")
        for idx, data in enumerate(rasters_data):
            logger.debug(f"raster #{idx + 1} area = {data['target_roi'].area:.2f}")
            # here, we enforce that raster datatypes/bandcounts match
            assert data["band_count"] == rasters_data[0]["band_count"], \
                "parser expects that all raster band counts match" + \
                f"(found {str(data['band_count'])} and {str(rasters_data[0]['band_count'])})"
            assert data["data_type"] == rasters_data[0]["data_type"], \
                "parser expects that all raster data types match" + \
                f"(found {str(data['data_type'])} and {str(rasters_data[0]['data_type'])})"
            data["to_target_transform"] = osr.CoordinateTransformation(data["srs"], srs)
            data["from_target_transform"] = osr.CoordinateTransformation(srs, data["srs"])
        return rasters_data, coverage

    @staticmethod
    def _parse_features(path, srs, roi, cache_hash, allow_outlying, clip_outlying, cleaner):
        """Parses vector files (geojsons) and returns geometry information.

        If ``cache_hash`` is set, the cleaned features are loaded from (or saved to) a
        '<hash>.feats.pkl' file next to the vector file.
        """
        import thelper.data.geo as geo
        logger.info(f"parsing vectors from path '{path}'...")
        assert os.path.isfile(path) and path.endswith("geojson"), \
            "vector file must be provided as geojson (shapefile support still incomplete)"
        cache_file_path = os.path.join(os.path.dirname(path), cache_hash + ".feats.pkl") \
            if cache_hash else None
        if cache_file_path is not None and os.path.exists(cache_file_path):
            logger.debug(f"parsing cached feature data from '{cache_file_path}'...")
            # NOTE(review): unpickling executes arbitrary code; the cache folder must be trusted
            with open(cache_file_path, "rb") as fd:
                features = pickle.load(fd)
        else:
            with open(path) as vector_fd:
                vector_data = json.load(vector_fd)
            features = geo.utils.parse_geojson(vector_data, srs_target=srs, roi=roi,
                                               allow_outlying=allow_outlying, clip_outlying=clip_outlying)
            features = cleaner(features)
            if cache_file_path is not None:
                logger.debug(f"caching clean data to '{cache_file_path}'...")
                with open(cache_file_path, "wb") as fd:
                    pickle.dump(features, fd)
        logger.debug(f"cleanup resulted in {len([f for f in features if f['clean']])} features of interest")
        return features

    def _parse_crops(self, cropper, cache_file_path, cache_hash):
        """Parses crops based on prior feature/raster data.

        Each 'crop' corresponds to a sample that can be loaded at runtime. Despite its name,
        ``cache_file_path`` is the path of a file whose folder receives the '<hash>.crops.pkl'
        cache (the vector file path is given by the constructor).
        """
        logger.info("preparing crops...")
        cache_file_path = os.path.join(os.path.dirname(cache_file_path), cache_hash + ".crops.pkl") \
            if cache_hash else None
        if cache_file_path is not None and os.path.exists(cache_file_path):
            logger.debug(f"parsing cached crop data from '{cache_file_path}'...")
            # NOTE(review): unpickling executes arbitrary code; the cache folder must be trusted
            with open(cache_file_path, "rb") as fd:
                samples = pickle.load(fd)
        else:
            samples = cropper(self.features, self.rasters_data, self.coverage, self.srs_target)
            if cache_file_path is not None:
                logger.debug(f"caching crop data to '{cache_file_path}'...")
                with open(cache_file_path, "wb") as fd:
                    pickle.dump(samples, fd)
        return samples

    def _process_crop(self, sample):
        """Returns a crop for a specific (internal) set of sampled features.

        Returns:
            A ``(crop, mask)`` pair. ``crop`` is a masked array of shape (height, width, bands)
            whose mask flags the pixels that received no valid (non-nodata) raster value, and
            ``mask`` is a uint8 array of shape (height, width) set to 1 where one of the sample's
            features was rasterized (all touched pixels), and 0 elsewhere.
        """
        import thelper.data.geo as geo
        # remember: we assume that all rasters have the same intrinsic settings
        crop_datatype = geo.utils.GDAL2NUMPY_TYPE_CONV[self.rasters_data[0]["data_type"]]
        crop_size = (sample["crop_height"], sample["crop_width"], self.rasters_data[0]["band_count"])
        # every pixel starts masked; pixels get unmasked once a raster provides a valid value
        crop = np.ma.array(np.zeros(crop_size, dtype=crop_datatype), mask=np.ones(crop_size, dtype=np.uint8))
        mask = np.zeros(crop_size[0:2], dtype=np.uint8)
        srs_target_wkt = self.srs_target.ExportToWkt()
        crop_raster_gdal = gdal.GetDriverByName("MEM").Create("", crop_size[1], crop_size[0],
                                                              crop_size[2], self.rasters_data[0]["data_type"])
        crop_raster_gdal.SetGeoTransform(sample["geotransform"])
        crop_raster_gdal.SetProjection(srs_target_wkt)
        crop_mask_gdal = gdal.GetDriverByName("MEM").Create("", crop_size[1], crop_size[0], 1, gdal.GDT_Byte)
        crop_mask_gdal.SetGeoTransform(sample["geotransform"])
        crop_mask_gdal.SetProjection(srs_target_wkt)
        crop_mask_gdal.GetRasterBand(1).WriteArray(np.zeros(crop_size[0:2], dtype=np.uint8))
        # burn all of the sample's features into the mask via an in-memory OGR layer
        # note: 'ogr_dataset' must stay referenced while its layer is in use
        ogr_dataset = ogr.GetDriverByName("Memory").CreateDataSource("mask")
        ogr_layer = ogr_dataset.CreateLayer("feature_mask", srs=self.srs_target)
        for feature in sample["features"]:
            ogr_feature = ogr.Feature(ogr_layer.GetLayerDefn())
            ogr_geometry = ogr.CreateGeometryFromWkt(feature["geometry"].wkt)
            ogr_feature.SetGeometry(ogr_geometry)
            ogr_layer.CreateFeature(ogr_feature)
        gdal.RasterizeLayer(crop_mask_gdal, [1], ogr_layer, burn_values=[1], options=["ALL_TOUCHED=TRUE"])
        np.copyto(dst=mask, src=crop_mask_gdal.GetRasterBand(1).ReadAsArray())
        for raster_idx in sample["raster_hits"]:
            rasterfile = geo.utils.open_rasterfile(self.rasters_data[raster_idx],
                                                   keep_rasters_open=self.keep_rasters_open)
            assert rasterfile.RasterCount == crop_size[2], "unexpected raster count"
            # using all cpus should be ok since we probably cant parallelize this loader anyway (swig serialization issues)
            options = ["NUM_THREADS=ALL_CPUS"] if self.reproj_all_cpus else []
            geo.utils.reproject_crop(rasterfile, crop_raster_gdal, crop_size, crop_datatype, reproj_opt=options, fill_nodata=True)
            for raster_band_idx in range(crop_raster_gdal.RasterCount):
                curr_band = crop_raster_gdal.GetRasterBand(raster_band_idx + 1)
                curr_band_array = curr_band.ReadAsArray()
                flag_mask = curr_band_array != curr_band.GetNoDataValue()
                # copy valid values only, and unmask the pixels that received one
                np.copyto(dst=crop.data[:, :, raster_band_idx], src=curr_band_array, where=flag_mask)
                np.bitwise_and(crop.mask[:, :, raster_band_idx], np.invert(flag_mask), out=crop.mask[:, :, raster_band_idx])
        # noinspection PyUnusedLocal
        crop_raster_gdal = None  # noqa # close local fd
        # noinspection PyUnusedLocal
        crop_mask_gdal = None  # noqa # close local fd
        # noinspection PyUnusedLocal
        rasterfile = None  # noqa # close input fd
        return crop, mask

    def __getitem__(self, idx):
        """Returns the data sample (a dictionary) for a specific (0-based) index.

        Negative indices count from the end; slices are forwarded to ``_getitems``. The returned
        dict contains the raster crop (under ``raster_key``), the feature mask (under
        ``mask_key``), and all keys of the internal sample.

        Raises:
            AssertionError: if the index is out of range.
        """
        if isinstance(idx, slice):
            return self._getitems(idx)
        sample_count = len(self.samples)
        # both bounds are checked so that indices below -len do not silently wrap around
        if not -sample_count <= idx < sample_count:
            raise AssertionError("sample index is out-of-range")
        if idx < 0:
            idx = sample_count + idx
        sample = self.samples[idx]
        crop, mask = self._process_crop(sample)
        if self.display_debug:
            crop = cv.cvtColor(crop, cv.COLOR_GRAY2BGR)
            mask = cv.cvtColor(mask, cv.COLOR_GRAY2BGR)
            mask[:, :, 1:3] = 0
            crop = cv.normalize(crop, dst=crop, alpha=0, beta=255, norm_type=cv.NORM_MINMAX, dtype=cv.CV_8U)
            mask = cv.normalize(mask, dst=mask, alpha=0, beta=255, norm_type=cv.NORM_MINMAX, dtype=cv.CV_8U)
            crop = np.bitwise_or(crop, mask)
        sample = {
            self.raster_key: np.array(crop.data, copy=True),
            self.mask_key: mask,
            **sample
        }
        if self.transforms:
            sample = self.transforms(sample)
        return sample
class TileDataset(VectorCropDataset):
    """Abstract dataset used to systematically tile vector data and rasters.

    Instead of generating one crop per feature, the bounding box of the raster coverage is
    tiled with a fixed tile size (in pixels) and overlap. Each tile becomes a sample listing
    the rasters and features that intersect it.
    """

    def __init__(self, raster_path, vector_path, tile_size, tile_overlap=0,
                 skip_empty_tiles=False, skip_nodata_tiles=True, px_size=None,
                 allow_outlying_vectors=True, clip_outlying_vectors=True,
                 vector_area_min=0.0, vector_area_max=float("inf"),
                 vector_target_prop=None, master_roi=None, srs_target="3857",
                 raster_key="raster", mask_key="mask", cleaner=None,
                 force_parse=False, reproj_rasters=False,
                 reproj_all_cpus=True, keep_rasters_open=True, transforms=None):
        """Validates the tiling settings and forwards everything else to the parent parser.

        Args:
            tile_size: tile size in pixels, as a scalar or a (width, height) pair.
            tile_overlap: overlap between adjacent tiles in pixels; must be non-negative and
                strictly smaller than both tile dimensions.
            skip_empty_tiles: if True, tiles that intersect no feature are not kept.
            skip_nodata_tiles: if True, tiles whose reprojected content contains only nodata
                values are not kept.
            px_size: pixel size in meters/degrees (scalar or pair); defaults to (1.0, 1.0).
            (other arguments: see ``VectorCropDataset``)
        """
        # note1: input 'tile_size' must be given in pixels
        # note2: input 'tile_overlap' must be given in pixels
        # note3: input 'px_size' must be given in meters/degrees
        if isinstance(tile_size, (float, int)):
            tile_size = (tile_size, tile_size)
        assert isinstance(tile_size, (tuple, list)) and len(tile_size) == 2, \
            "invalid tile size (should be scalar or two-elem tuple)"
        assert all([t > 0 for t in tile_size]), "unexpected tile size value (should be positive)"
        tile_size = [float(t) for t in tile_size]  # convert all vals if necessary
        assert isinstance(tile_overlap, (float, int)) and tile_overlap >= 0, \
            "unexpected tile overlap (should be non-negative scalar)"
        tile_overlap = float(tile_overlap)
        # the tiling stride is 'tile_size - tile_overlap'; a null stride would divide by zero and a
        # negative one would make the tiling loop run forever
        assert tile_overlap < min(tile_size), "tile overlap must be smaller than the tile size"
        assert isinstance(skip_empty_tiles, bool), "unexpected flag type (should be bool)"
        assert isinstance(skip_nodata_tiles, bool), "unexpected flag type (should be bool)"
        # the cropper receives its pixel size directly (not the parent's normalized 'self.px_size'),
        # so normalize it here exactly like the parent does (None or scalar would otherwise crash)
        assert px_size is None or \
            (isinstance(px_size, (list, tuple)) and all([isinstance(i, (float, int)) for i in px_size]) and len(px_size) == 2) or \
            isinstance(px_size, (float, int)), "pixel size (resolution) must be float/int or list/tuple"
        if px_size is None:
            tile_px_size = (1.0, 1.0)
        elif isinstance(px_size, (list, tuple)):
            tile_px_size = (float(px_size[0]), float(px_size[1]))
        else:
            tile_px_size = (float(px_size), float(px_size))
        cropper = functools.partial(self._tile_cropper, tile_size=tile_size, tile_overlap=tile_overlap,
                                    skip_empty_tiles=skip_empty_tiles, skip_nodata_tiles=skip_nodata_tiles,
                                    keep_rasters_open=keep_rasters_open, px_size=tile_px_size)
        super().__init__(raster_path=raster_path, vector_path=vector_path, px_size=px_size, skew=None,
                         allow_outlying_vectors=allow_outlying_vectors, clip_outlying_vectors=clip_outlying_vectors,
                         vector_area_min=vector_area_min, vector_area_max=vector_area_max, vector_target_prop=vector_target_prop,
                         master_roi=master_roi, srs_target=srs_target, raster_key=raster_key, mask_key=mask_key,
                         cleaner=cleaner, cropper=cropper, force_parse=force_parse, reproj_rasters=reproj_rasters,
                         reproj_all_cpus=reproj_all_cpus, keep_rasters_open=keep_rasters_open, transforms=transforms)

    @staticmethod
    def _tile_cropper(features, rasters_data, coverage, srs_target, tile_size, tile_overlap,
                      skip_empty_tiles, skip_nodata_tiles, keep_rasters_open, px_size):
        """Returns the samples obtained by tiling the raster coverage (may be modified in derived classes).

        The bounding box of 'coverage' is scanned row by row with a stride of
        ``tile_size - tile_overlap`` pixels, starting at ``-tile_overlap``. A tile is kept if it
        intersects at least one raster, if (when ``skip_nodata_tiles`` is set) the reprojection of
        an intersecting raster yields at least one non-nodata value, and if (when
        ``skip_empty_tiles`` is set) at least one feature intersects it.
        """
        import thelper.data.geo as geo
        # instead of iterating over features to generate samples, we tile the raster(s)
        # note: the 'coverage' geometry should already be in the target srs
        roi_tl, roi_br = geo.utils.get_feature_bbox(coverage)
        roi_geotransform = (roi_tl[0], px_size[0], 0.0,
                            roi_tl[1], 0.0, px_size[1])
        srs_target_wkt = srs_target.ExportToWkt()
        # remember: we assume that all rasters have the same intrinsic settings
        crop_datatype = geo.utils.GDAL2NUMPY_TYPE_CONV[rasters_data[0]["data_type"]]
        tile_width = int(round(tile_size[0]))
        tile_height = int(round(tile_size[1]))
        # (height, width, bands): same layout as the 'crop_size' given to 'reproject_crop' at runtime
        tile_crop_size = (tile_height, tile_width, rasters_data[0]["band_count"])
        # gdal's 'Create' expects the x size (width) before the y size (height)
        crop_raster_gdal = gdal.GetDriverByName("MEM").Create("", tile_width, tile_height,
                                                              tile_crop_size[2], rasters_data[0]["data_type"])
        crop_raster_gdal.SetProjection(srs_target_wkt)
        samples = []
        crop_id = 0
        roi_px_br = geo.utils.get_pxcoord(roi_geotransform, *roi_br)
        step_x = tile_size[0] - tile_overlap
        step_y = tile_size[1] - tile_overlap
        nb_iter_y = int(math.ceil((roi_px_br[1] + tile_overlap) / step_y))
        nb_iter_x = int(math.ceil((roi_px_br[0] + tile_overlap) / step_x))
        pbar = tqdm.tqdm(total=nb_iter_y * nb_iter_x, desc="validating crop candidates")
        roi_offset_px_y = -tile_overlap
        while roi_offset_px_y < roi_px_br[1]:
            roi_offset_px_x = -tile_overlap
            while roi_offset_px_x < roi_px_br[0]:
                pbar.update(1)
                crop_px_tl = (roi_offset_px_x, roi_offset_px_y)
                crop_px_br = (crop_px_tl[0] + tile_size[0], crop_px_tl[1] + tile_size[1])
                crop_tl = geo.utils.get_geocoord(roi_geotransform, *crop_px_tl)
                crop_br = geo.utils.get_geocoord(roi_geotransform, *crop_px_br)
                crop_geom = shapely.geometry.Polygon([crop_tl, (crop_br[0], crop_tl[1]),
                                                      crop_br, (crop_tl[0], crop_br[1])])
                crop_geotransform = (crop_tl[0], px_size[0], 0.0,
                                     crop_tl[1], 0.0, px_size[1])
                crop_raster_gdal.SetGeoTransform(crop_geotransform)
                raster_hits = []
                # when nodata tiles are not skipped, no raster content check is needed
                found_valid_intersection = not skip_nodata_tiles
                for raster_idx, raster_data in enumerate(rasters_data):
                    if not raster_data["target_roi"].intersects(crop_geom):
                        continue
                    if not found_valid_intersection:
                        rasterfile = geo.utils.open_rasterfile(raster_data, keep_rasters_open=keep_rasters_open)
                        # yeah, we reproject the crop, preprocessing is slow, deal with it
                        geo.utils.reproject_crop(rasterfile, crop_raster_gdal, tile_crop_size, crop_datatype, fill_nodata=True)
                        for raster_band_idx in range(crop_raster_gdal.RasterCount):
                            curr_band = crop_raster_gdal.GetRasterBand(raster_band_idx + 1)
                            found_valid_intersection = found_valid_intersection or \
                                np.count_nonzero(curr_band.ReadAsArray() != curr_band.GetNoDataValue()) > 0
                    raster_hits.append(raster_idx)
                if raster_hits and found_valid_intersection:
                    crop_centroid = crop_geom.centroid
                    crop_radius = np.linalg.norm(np.asarray(crop_tl) - np.asarray(crop_br)) / 2
                    # cheap pre-filter (any geometry touching the tile lies within half of its
                    # diagonal from its centroid), followed by the exact intersection test
                    crop_features = [f for f in features
                                     if f["geometry"].distance(crop_centroid) <= crop_radius
                                     and f["geometry"].intersects(crop_geom)]
                    if crop_features or not skip_empty_tiles:
                        # prepare actual 'sample' for crop generation at runtime
                        samples.append({
                            "features": crop_features,
                            "id": crop_id,
                            "roi": crop_geom,
                            "roi_tl": crop_tl,
                            "roi_br": crop_br,
                            "raster_hits": raster_hits,
                            "crop_width": tile_width,
                            "crop_height": tile_height,
                            "geotransform": np.asarray(crop_geotransform),
                        })
                        crop_id += 1
                roi_offset_px_x += step_x
            roi_offset_px_y += step_y
        pbar.close()
        return samples
class ImageFolderGDataset(ImageFolderDataset):
    """Image folder dataset specialization interface for classification tasks on geospatial images.

    This specialization is used to parse simple image subfolders, and it essentially replaces the very
    basic ``torchvision.datasets.ImageFolder`` interface with similar functionalities. It is used to provide
    a proper task interface as well as path metadata in each loaded packet for metrics/logging output.

    The difference with the parent class ImageFolderDataset is the use of gdal to manage multi-channel images
    found in the remote sensing domain. The user can specify the (1-based) channels to load. By default, the
    first three channels are loaded [1,2,3].

    .. seealso::
        | :class:`thelper.data.parsers.ImageDataset`
        | :class:`thelper.data.parsers.ClassificationDataset`
        | :class:`thelper.data.parsers.ImageFolderDataset`
    """

    def __init__(self, root, transforms=None, image_key="image", label_key="label",
                 path_key="path", idx_key="idx", channels=None):
        """Image folder dataset parser constructor.

        ``channels`` is the list of 1-based raster band indices to load; an empty or missing
        list selects bands [1, 2, 3].
        """
        super().__init__(root=root, transforms=transforms, image_key=image_key,
                         path_key=path_key, label_key=label_key, idx_key=idx_key)
        self.channels = channels if channels else [1, 2, 3]

    def __getitem__(self, idx):
        """Returns the data sample (a dictionary) for a specific (0-based) index.

        The selected bands of the image are read with gdal and stacked along the last axis.
        Negative indices count from the end; slices are forwarded to ``_getitems``.

        Raises:
            AssertionError: if the index is out of range.
            Exception: if gdal cannot open the image file.
            RuntimeError: if a requested band does not exist or cannot be read.
        """
        if isinstance(idx, slice):
            return self._getitems(idx)
        sample_count = len(self.samples)
        # both bounds are checked so that indices below -len do not silently wrap around
        if not -sample_count <= idx < sample_count:
            raise AssertionError("sample index is out-of-range")
        if idx < 0:
            idx = sample_count + idx
        sample = self.samples[idx]
        raster_path = sample[self.path_key]
        raster_ds = gdal.Open(raster_path, gdal.GA_ReadOnly)
        if raster_ds is None:
            raise Exception(f"File not found: {raster_path}")
        try:
            image = []
            for channel in self.channels:
                # 'GetRasterBand' returns None for invalid band indices, so check the range first
                image_arr = raster_ds.GetRasterBand(channel).ReadAsArray() \
                    if 1 <= channel <= raster_ds.RasterCount else None
                if image_arr is None:
                    logger.fatal(f"Band not found: {channel}")
                    raise RuntimeError(f"Band not found: {channel}")
                image.append(image_arr)
            image = np.dstack(image)
        finally:
            raster_ds = None  # noqa # flush (also on error)
        sample = {
            self.image_key: image,
            self.idx_key: idx,
            **sample
        }
        if self.transforms:
            sample = self.transforms(sample)
        return sample
class SlidingWindowDataset(Dataset):
    """Sliding window dataset specialization interface for classification tasks over a geospatial image.

    The dataset runs a sliding window (stride of one pixel) over the whole geospatial image in order to
    return square tile patches. The operation can be accomplished over multiple raster bands if they can be
    found in the provided raster container. Each sample holds the patch (as float32, bands stacked along the
    last axis) and the pixel coordinates (x, y) of its center.
    """

    def __init__(self, raster_path, raster_bands, patch_size, transforms=None, image_key="image"):
        """Parses the raster metadata and enumerates all patch locations.

        Args:
            raster_path: path to the raster file.
            raster_bands: list of (1-based) raster band indices to read.
            patch_size: side length of the square patches, in pixels.
            transforms: transformation(s) applied to each loaded sample.
            image_key: key under which the patch is stored in the loaded samples.
        """
        super().__init__(transforms=transforms)
        self.logger.debug("Creating %s with [%s]", type(self).__name__, raster_path)
        self.image_key = image_key
        self.center_key = "center"
        self.raster_dss = []
        # update raster metadata that can be used by other objects
        self.raster = {"path": raster_path, "bands": raster_bands}
        raster_ds = gdal.OpenShared(raster_path, gdal.GA_ReadOnly)
        parse_raster_metadata(self.raster, raster_ds)
        xsize = raster_ds.RasterXSize
        ysize = raster_ds.RasterYSize
        self.patch_size = patch_size
        self.raster["xsize"] = xsize
        self.raster["ysize"] = ysize
        self.raster["georef"] = raster_ds.GetProjectionRef()
        self.raster["affine"] = raster_ds.GetGeoTransform()
        raster_ds = None  # noqa # flush dataset
        # generate patch sample locations as (xoff, yoff, xsize, ysize) windows; every top-left
        # offset from 0 to 'size - patch_size' (inclusive) keeps the whole patch inside the raster
        lines = ysize - self.patch_size + 1
        cols = xsize - self.patch_size + 1
        self.samples = [(x, y, self.patch_size, self.patch_size)
                        for y in range(lines) for x in range(cols)]
        self.n_samples = len(self.samples)
        self.logger.info(f"Number of samples: {self.n_samples}")
        self.done = False

    def __len__(self):
        """Returns the number of patches in the dataset."""
        return self.n_samples

    def __getitem__(self, idx):
        """Returns the patch sample (a dictionary) for a specific (0-based) index.

        On first use, one shared gdal dataset handle is opened per data loader worker (or a single
        one when called outside of a worker process), and each worker then reads through its own.
        """
        # get the number of workers and the current worker's id; no worker info means that we
        # are running in the main process (e.g. a data loader without workers)
        info = torch.utils.data.get_worker_info()
        worker_count = 1 if info is None else info.num_workers
        worker_id = 0 if info is None else info.id
        # open the data with gdal once per worker in multithread shared mode (done once)
        if not self.done:
            raster_path = self.raster.get('reader', self.raster['path'])
            self.logger.info(f"Single time load of raster: [{raster_path}]")
            for _ in range(worker_count):
                self.raster_dss.append(gdal.OpenShared(raster_path, gdal.GA_ReadOnly))
            self.done = True
        # do the processing with the gdal dataset associated with the worker's id
        patch = self.samples[idx]
        raster_ds = self.raster_dss[worker_id]
        image = np.dstack([raster_ds.GetRasterBand(raster_band).ReadAsArray(*patch)
                           for raster_band in self.raster["bands"]])
        offsets = patch[:2]
        half_size = self.patch_size // 2
        sample = {
            self.image_key: np.array(image.data, copy=True, dtype='float32'),
            self.center_key: (offsets[0] + half_size, offsets[1] + half_size),
        }
        if self.transforms:
            sample = self.transforms(sample)
        return sample