Source code for thelper.data.geo.agrivis

"""
Agricultural Semantic Segentation Challenge Dataset Interface

Original author: David Landry (david.landry@crim.ca)
Updated by Pierre-Luc St-Charles (April 2020)
"""

import copy
import logging
import os
import pprint
import shutil
import typing

import h5py
import numpy as np
import torch.utils
import torch.utils.data
import tqdm

import thelper.data
import thelper.tasks
import thelper.utils
from thelper.data.parsers import Dataset

logger = logging.getLogger(__name__)


class_names = [
    "background",     # optional, depending on task
    "cloud_shadow",
    "double_plant",
    "planter_skip",
    "standing_water",
    "waterway",
    "weed_cluster",
]

approx_weight_map = {
    "background": 0.7754729398810614,
    "cloud_shadow": 0.02987549383646342,
    "double_plant": 0.006768273283806349,
    "planter_skip": 0.0016827190442308664,
    "standing_water": 0.015964306228958156,
    "waterway": 0.012930148362618188,
    "weed_cluster": 0.1573061193628617
}

dontcare = 255


[docs]class Hdf5AgricultureDataset(Dataset):
[docs] def __init__( self, hdf5_path: typing.AnyStr, group_name: typing.AnyStr, transforms: typing.Any = None, use_global_normalization: bool = True, keep_file_open: bool = False, load_meta_keys: bool = False, copy_to_slurm_tmpdir: bool = False, ): super().__init__(transforms, deepcopy=False) if copy_to_slurm_tmpdir: assert os.path.isfile(hdf5_path), f"invalid input hdf5 path: {hdf5_path}" slurm_tmpdir = thelper.utils.get_slurm_tmpdir() assert slurm_tmpdir is not None, "undefined SLURM_TMPDIR env variable" dest_hdf5_path = os.path.join(slurm_tmpdir, "agrivis.hdf5") if not os.path.isfile(dest_hdf5_path): shutil.copyfile(hdf5_path, dest_hdf5_path) hdf5_path = dest_hdf5_path logger.info(f"reading AgriVis challenge {group_name} data from: {hdf5_path}") self.hdf5_path = hdf5_path self.group_name = group_name self.load_meta_keys = load_meta_keys with h5py.File(self.hdf5_path, "r") as archive: assert group_name in archive, \ "unexpected dataset name (should be train/val/test)" dataset = archive[group_name] expected_keys = ["boundaries", "features", "keys"] if group_name != "test": expected_keys += ["labels", "n_labelled_pixels"] assert all([k in dataset.keys() for k in expected_keys]), \ "missing at least one of the expected dataset group keys" assert all([len(dataset[k]) == len(dataset["keys"]) for k in expected_keys]), \ "dataset sample count mismatch across all subgroups" if group_name != "test": assert dataset["labels"].shape[-1] == len(class_names) - 1, \ "unexpected dataset label map count while accounting for background" meta_iter = zip(dataset["keys"], dataset["n_labelled_pixels"]) else: meta_iter = zip(dataset["keys"], [None] * len(dataset["keys"])) self.samples = [{ # list pre-fill "image": None, "label_map": None, "key": key, "mask": None, "pxcounts": pxcounts, } for key, pxcounts in meta_iter] logger.info(f"loaded metadata for {len(self.samples)} patches") self.task = thelper.tasks.Segmentation( class_names=class_names, input_key="image", label_map_key="label_map", meta_keys=["key", "mask", "pxcounts"], dontcare=dontcare, ) self.use_global_normalization = use_global_normalization self.image_mean = np.asarray([ 121.6028380635106, 118.52572985557143, 116.36513065674848, 108.47336023815292, ], dtype=np.float32) self.image_stddev = np.asarray([ 41.47667301013803, 41.782106439616534, 45.04215840534553, 44.53299631408866, ], dtype=np.float32) self.hdf5_handle = h5py.File(self.hdf5_path, "r") if keep_file_open else None
# self.squished = 0 def __len__(self): return len(self.samples)
[docs] def __getitem__(self, idx): if isinstance(idx, slice): return self._getitems(idx) assert idx < len(self.samples), "sample index is out-of-range" if idx < 0: idx = len(self.samples) + idx label_map = None if self.hdf5_handle is not None: image = self.hdf5_handle[self.group_name]["features"][idx] mask = self.hdf5_handle[self.group_name]["boundaries"][idx] if self.group_name != "test": label_map = self.hdf5_handle[self.group_name]["labels"][idx] else: with h5py.File(self.hdf5_path, mode="r") as archive: image = archive[self.group_name]["features"][idx] mask = archive[self.group_name]["boundaries"][idx] if self.group_name != "test": label_map = archive[self.group_name]["labels"][idx] if self.use_global_normalization: image = (image.astype(np.float32) - self.image_mean) / self.image_stddev mask = mask.astype(np.int16) if label_map is not None: # note: we might squish some overlapping labels, but these are very rare... (<0.07%) out_label_map = np.zeros((image.shape[0], image.shape[1]), dtype=np.int16) for label_idx in range(1, len(class_names)): orig_label_map_idx = label_idx - 1 curr_label_map = label_map[..., orig_label_map_idx] # overlap = np.logical_and(out_label_map != 0, curr_label_map) # self.squished += np.count_nonzero(overlap) out_label_map = np.where(curr_label_map, np.int16(label_idx), out_label_map) label_map = out_label_map label_map = np.where(mask, label_map, np.int16(dontcare)) sample = { "image": image, "label_map": label_map, "mask": mask, # drop key if neccessary to fix batching w/ PyTorch default collate "key": self.samples[idx]["key"] if self.load_meta_keys else None, "pxcounts": copy.deepcopy(self.samples[idx]["pxcounts"]), } if self.transforms: sample = self.transforms(sample) return sample
def _compute_statistics(dataset) -> typing.Dict: array_alloc_size = len(dataset) stat_arrays = [ # alloc three vals per item: px-wise sum, px-wise sqsum, px count np.zeros((array_alloc_size, 3), dtype=np.float64) for band in range(4) ] for batch_idx, sample in enumerate(tqdm.tqdm(dataset)): image = sample["image"] assert image.ndim == 3 and image.shape[-1] == 4 for ch_idx in range(4): image_ch = image[..., ch_idx] stat_arrays[ch_idx][batch_idx] = ( np.sum(image_ch, dtype=np.float64), np.sum(np.square(image_ch, dtype=np.float64)), np.float64(image_ch.size), ) stat_map = {} for band_idx, band_array in enumerate(stat_arrays): tot_size = np.sum(band_array[:, 2]) mean = np.sum(band_array[:, 0]) / tot_size stddev = np.sqrt(np.sum(band_array[:, 1]) / tot_size - np.square(mean)) stat_map[band_idx] = {"mean": mean, "stddev": stddev} return stat_map def _compute_class_weights(dataset) -> typing.Dict: class_counts = {key: 0 for key in class_names} tot_samples = 0 image_px_count = 512 * 512 # fixed size for sample in tqdm.tqdm(dataset): if sample["pxcounts"] is not None: tot_labeled_px = 0 for class_idx, px_count in enumerate(sample["pxcounts"]): class_counts[class_names[class_idx + 1]] += px_count tot_labeled_px += px_count class_counts["background"] += image_px_count - tot_labeled_px tot_samples += 1 tot_count = tot_samples * image_px_count class_weights = { key: count / tot_count for key, count in class_counts.items() } return class_weights if __name__ == "__main__": # @@@@ TODO: CONVERT TO PROPER TEST logging.basicConfig() logging.getLogger().setLevel(logging.NOTSET) dataset = torch.utils.data.ConcatDataset([ Hdf5AgricultureDataset( hdf5_path="/shared/data_ufast_ext4/datasets/agrivis/agri_v2.hdf5", group_name=group_name, use_global_normalization=False, keep_file_open=True, ) for group_name in ["train", "val", "test"] ]) # out_map = _compute_statistics(dataset) out_map = _compute_class_weights(dataset) logging.info(f"out_map =\n{pprint.pformat(out_map, indent=4)}") print("all done")