# Source code for thelper.data.geo.infer

import json
import logging
import os
from typing import TYPE_CHECKING

import gdal
import numpy as np
import torch

import thelper.concepts
import thelper.data.geo
from thelper.infer.base import Tester

if TYPE_CHECKING:
    from typing import AnyStr, Optional, Tuple  # noqa: F401

logger = logging.getLogger(__name__)


@thelper.concepts.classification
class SlidingWindowTester(Tester):
    """Tester that satisfies the requirements of the :class:`Tester` in order to run classification inference."""

    def __init__(self,
                 session_name,   # type: AnyStr
                 session_dir,    # type: AnyStr
                 model,          # type: thelper.typedefs.ModelType
                 task,           # type: thelper.tasks.Task
                 loaders,        # type: thelper.typedefs.MultiLoaderType
                 config,         # type: thelper.typedefs.ConfigDict
                 ckptdata=None   # type: Optional[thelper.typedefs.CheckpointContentType]
                 ):
        super(SlidingWindowTester, self).__init__(session_name, session_dir, model, task, loaders,
                                                  config, ckptdata=ckptdata)
        # because 'tester' gets called explicitly during inference, check for tester/runner before trainer key
        # this way we can favor using a detailed config which specified both trainer/tester simultaneously and
        # use the correct one with all corresponding CLI modes
        runner_config = thelper.utils.get_key(["runner", "tester", "trainer"], config)
        self.normalize_loss = thelper.utils.get_key_def("normalize_loss", runner_config, True)

    def eval_epoch(self, model, epoch, dev, loader, metrics, output_path):
        """Computes the pixelwise prediction on an image.

        It does the prediction per batch size of N pixels. It returns the class predicted and its probability.
        The results are saved into two images created with the same size and projection info as the input rasters.

        The ``class`` image gives the class id, a number between 1 and the number of classes for corresponding
        pixels. Class id 0 is reserved for ``nodata``.

        The ``probs`` image contains N-class channels with the probability values of the pixels for each class.
        The probabilities by default are normalised.

        Also, a ``config-classes.json`` file is created listing the ``name-to-class-id`` mapping that was used to
        generate the values in the ``class`` image (i.e.: class names defined by the pre-trained ``model``).

        Args:
            model: the model with which to run inference that is already uploaded to the target device(s).
            epoch: the epoch index we are training for (0-based, and should normally only be 0 for single test pass).
            dev: the target device that tensors should be uploaded to (corresponding to model's device(s)).
            loader: the data loader used to get transformed test samples.
            metrics: the dictionary of metrics/consumers to report inference results (mostly loggers and basic report
                generator in this case since there shouldn't be ground truth labels to validate against).
            output_path: directory where output files should be written, if necessary.
        """
        ds_type = thelper.data.geo.parsers.SlidingWindowDataset
        if not isinstance(loader.dataset, ds_type):
            raise AssertionError(f"Only dataset '{ds_type.__module__}.{ds_type.__name__}' is supported by "
                                 f"test session runner {SlidingWindowTester.__module__}.{SlidingWindowTester.__name__}")
        output_path = os.path.abspath(output_path)
        class_count = len(model.task.class_names)
        class_ds, probs_ds = self._prepare_output_rasters(loader.dataset, output_path, class_count)
        # build the raster's name-to-id mapping on a COPY of the task's class indices; incrementing the task's own
        # dict in place would corrupt it for any later use (and would shift the ids again on a second call)
        class_indices = {name: index + 1 for name, index in model.task.class_indices.items()}
        class_indices['no_data'] = 0  # id 0 is reserved for nodata in the output 'class' raster
        class_indices_file = "config-classes.json"
        class_indices_file_path = os.path.join(output_path, class_indices_file)
        logger.debug("Writing class indices: [%s]", class_indices_file_path)
        with open(class_indices_file_path, 'w') as f:
            json.dump(class_indices, f, indent=4)
        normalize = torch.nn.Softmax(dim=1) if self.normalize_loss else lambda _: _  # Normalizing/pass-through
        model.eval()
        with torch.no_grad():
            n_batches = len(loader)
            n_patches = loader.batch_size
            logger.debug("Starting inference of %s batches each composed of %s patch samples", n_batches, n_patches)
            for k, sample in enumerate(loader):
                logger.info(f"Batch {k+1} of {n_batches}: {(k+1)/n_batches:4.1%}")
                # per-patch center coordinates (pixel offsets into the output rasters)
                center_x0 = self._move_tensor(sample[loader.dataset.center_key][0],
                                              dev="cpu", detach=True).data.numpy()
                center_y0 = self._move_tensor(sample[loader.dataset.center_key][1],
                                              dev="cpu", detach=True).data.numpy()
                n_data = center_x0.shape[0]  # batch-size
                x_data = sample[loader.dataset.image_key]
                x_data = self._move_tensor(x_data, dev=dev)
                y_prob = model(x_data)
                y_prob = normalize(y_prob)
                y_class_indices = torch.argmax(y_prob, dim=1)
                y_class_indices = self._move_tensor(y_class_indices, dev="cpu", detach=True).data.numpy()
                y_prob = self._move_tensor(y_prob, dev="cpu", detach=True).data.numpy()
                # loop each patch from the batch, writing one pixel per patch into each output raster
                for j in range(n_data):
                    class_id = np.array([[y_class_indices[j] + 1]])  # +1: shift ids so 0 stays nodata
                    x0 = int(center_x0[j])
                    y0 = int(center_y0[j])
                    class_ds.GetRasterBand(1).WriteArray(class_id, x0, y0)
                    for p in range(y_prob.shape[1]):
                        probs_ds.GetRasterBand(p + 1).WriteArray(
                            np.array([[y_prob[j, p]]], dtype='float32'), x0, y0)
                # save written changes to disk
                class_ds.FlushCache()
                probs_ds.FlushCache()
        logger.debug("Closing output rasters")
        class_ds = None  # noqa # close file
        probs_ds = None  # noqa # close file

    @staticmethod
    def _prepare_output_rasters(raster_loader, output_path, class_count):
        # type: (thelper.data.geo.parsers.SlidingWindowDataset, AnyStr, int) -> Tuple[gdal.Dataset, gdal.Dataset]
        """Generates the ``class`` and ``probs`` datasets to be filled by inference results.

        Both rasters reuse the size and projection metadata of the loader's input raster; each is created,
        flushed/closed, then reopened in update mode so inference can write into it incrementally.
        """
        logger.info("Preparing output rasters")
        logger.debug("using output name: [%s]", raster_loader.raster["name"])
        xsize = raster_loader.raster["xsize"]
        ysize = raster_loader.raster["ysize"]
        georef = raster_loader.raster["georef"]
        affine = raster_loader.raster["affine"]
        raster_name = raster_loader.raster["name"]
        # Create the class raster output (single byte band; 0 reserved for nodata)
        raster_class_name = f"{raster_name}_class.tif"
        raster_class_path = os.path.join(output_path, raster_class_name)
        class_ds = gdal.GetDriverByName('GTiff').Create(raster_class_path, xsize, ysize, 1, gdal.GDT_Byte)
        if class_ds is None:
            raise IOError(f"Unable to create: [{raster_class_path}]")
        logger.debug(f"Creating: [{raster_class_path}]")
        class_ds.SetGeoTransform(affine)
        class_ds.SetProjection(georef)
        class_band = class_ds.GetRasterBand(1)
        class_band.SetNoDataValue(0)
        class_ds.FlushCache()  # save to disk
        class_ds = None  # noqa # need to close before open-update
        class_band = None  # noqa # also close band (remove ptr)
        class_ds = gdal.Open(raster_class_path, gdal.GA_Update)
        # Create the probabilities raster output (one float32 band per class)
        raster_prob_name = f"{raster_name}_probs.tif"
        raster_prob_path = os.path.join(output_path, raster_prob_name)
        probs_ds = gdal.GetDriverByName('GTiff').Create(raster_prob_path, xsize, ysize,
                                                        class_count, gdal.GDT_Float32)
        if probs_ds is None:
            raise IOError(f"Unable to create: [{raster_prob_path}]")
        logger.debug(f"Creating: [{raster_prob_path}]")
        probs_ds.SetGeoTransform(affine)
        probs_ds.SetProjection(georef)
        probs_ds.FlushCache()  # save to disk
        probs_ds = None  # noqa # need to close before open-update
        probs_ds = gdal.Open(raster_prob_path, gdal.GA_Update)
        return class_ds, probs_ds