#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
This module defines a function that runs a RasterProcessing on sliding windows of a raster image.
"""
from itertools import repeat
import logging.config
import os
from typing import List
import multiprocessing
import numpy as np
import numpy.ma as ma
import rasterio
from rasterio.windows import Window
from tqdm.contrib.concurrent import process_map
from eolab.georastertools import utils
from eolab.georastertools.processing import RasterProcessing
_logger = logging.getLogger(__name__)
def compute_sliding(input_image: str, output_image: str, rasterprocessing: RasterProcessing,
                    window_size: tuple = (1024, 1024), window_overlap: int = 0,
                    pad_mode: str = "edge", bands: List[int] = None):
    """
    Apply a sliding window raster processing operation on an input image and save the result.

    This function processes a raster image in small sliding windows, allowing efficient
    memory management for large datasets by processing chunks. The specified `rasterprocessing`
    operation is applied to each window, with options for padding and overlapping windows.

    Args:
        input_image (str): Path to the input raster image file to be processed.
        output_image (str): Path to save the output raster image after processing.
        rasterprocessing (RasterProcessing): A processing object defining the algorithm and
            parameters to apply on each window of the input image.
        window_size (tuple(int, int), optional): Size (width, height) of each window for
            processing, default is (1024, 1024).
        window_overlap (int, optional): Number of margin pixels added on every side of a
            window; consecutive read windows therefore share 2 * window_overlap pixels.
            The margin is discarded before writing. Default is 0.
        pad_mode (str, optional): Padding mode used where a window extends beyond the
            image boundaries, default is "edge". Refer to the
            `numpy.pad <https://numpy.org/doc/stable/reference/generated/numpy.pad.html>`_
            documentation for valid modes.
        bands (List[int], optional): 1-based indices of the bands to process. If None or
            empty, all bands of the input image are processed.

    Raises:
        ValueError: If specified band indices are out of range for the input image.

    Note:
        Windows are processed concurrently in a pool of worker *processes*
        (``tqdm.contrib.concurrent.process_map``). Writes to the output file are
        serialized with a lock obtained from a ``multiprocessing.Manager``. The number
        of workers can be set with the ``RASTERTOOLS_MAXWORKERS`` environment variable
        and the progress bar disabled with ``RASTERTOOLS_NOTQDM`` ("true" or "1").
        Window padding at the image boundaries is applied as specified, and sliding
        window indices are computed internally.
    """
    with rasterio.Env(GDAL_VRT_ENABLE_PYTHON=True):
        with rasterio.open(input_image) as src:
            profile = src.profile

            # set block size (reduced when the image is smaller than the window)
            blockxsize, blockysize = window_size
            if src.width < blockxsize:
                blockxsize = utils.highest_power_of_2(src.width)
            if src.height < blockysize:
                blockysize = utils.highest_power_of_2(src.height)

            # dtype and creation options of output data
            dtype = rasterprocessing.dtype or rasterio.float32
            in_dtype = rasterprocessing.in_dtype or dtype
            nbits = rasterprocessing.nbits
            compress = rasterprocessing.compress or src.compression or 'lzw'
            # compare with None: 0 is a legitimate nodata value and must not be
            # replaced by the nodata value of the input image
            if rasterprocessing.nodata is not None:
                nodata = rasterprocessing.nodata
            else:
                nodata = src.nodata

            # check band index and handle all bands options (when bands is an empty list)
            if bands is None or len(bands) == 0:
                bands = src.indexes
            elif min(bands) < 1 or max(bands) > src.count:
                raise ValueError(f"Invalid bands, all values are not in range [1, {src.count}]")

            # setup profile for output image
            profile.update(driver='GTiff', blockxsize=blockxsize, blockysize=blockysize,
                           tiled=True, dtype=dtype, nbits=nbits, compress=compress,
                           nodata=nodata, count=len(bands))
            # create the output file; workers reopen it in "r+" mode to fill it
            with rasterio.open(output_image, "w", **profile):
                pass

            # create the generator of sliding windows
            sliding_gen = _sliding_windows((src.width, src.height),
                                           window_size, window_overlap)
            if rasterprocessing.per_band_algo:
                # one task per (window, band) pair
                sliding_windows_bands = [(w, [b]) for w in sliding_gen for b in bands]
            else:
                # one task per window, processing all requested bands at once
                sliding_windows_bands = [(w, bands) for w in sliding_gen]

            # options for tqdm's process_map (process pool + progress bar)
            kwargs = {
                "total": len(sliding_windows_bands),
                "disable": os.getenv("RASTERTOOLS_NOTQDM", 'False').lower() in ['true', '1']
            }
            max_workers = os.getenv("RASTERTOOLS_MAXWORKERS")
            if max_workers is not None:
                kwargs["max_workers"] = int(max_workers)

            # A manager-backed lock can be passed to worker processes. Using the
            # manager as a context manager shuts its server process down afterwards.
            with multiprocessing.Manager() as manager:
                write_lock = manager.Lock()
                process_map(_process_sliding, repeat(rasterprocessing),
                            repeat(input_image), repeat(output_image),
                            sliding_windows_bands, repeat(window_overlap),
                            repeat(pad_mode), repeat(in_dtype),
                            repeat(write_lock),
                            **kwargs)
def _process_sliding(rasterprocessing: RasterProcessing,
                     input_image, output_image,
                     sliding_windowbands, window_overlap,
                     pad_mode, dtype, write_lock):
    """Internal method that computes the raster data for a specific window.

    This method can be called safely by several processes thanks to the lock
    that prevents them from writing the output file simultaneously.

    Args:
        rasterprocessing (RasterProcessing): processing whose ``compute`` method is
            applied to the (padded) window data.
        input_image (str): path of the image to read.
        output_image (str): path of the already-created image to update.
        sliding_windowbands (tuple): ``((read_window, pad, write_window), bands)``, as
            built by ``compute_sliding`` from ``_sliding_windows``.
        window_overlap (int): number of margin pixels discarded on every side of the
            computed window before writing it.
        pad_mode (str): ``numpy.pad`` mode used where the window exceeds the image.
        dtype: dtype the input data is cast to before computation.
        write_lock: lock serializing the writes to ``output_image``.
    """
    (read_window, pad, write_window), bands = sliding_windowbands
    with rasterio.Env(GDAL_VRT_ENABLE_PYTHON=True):
        with rasterio.open(input_image) as src:
            dataset = _read_dataset(src, bands, read_window, pad, pad_mode)
        dataset = dataset.astype(dtype)

        # The computation can be performed concurrently
        output = rasterprocessing.compute(dataset)

        # Remove the overlap margin on both spatial axes. Stops are derived from the
        # array shape because "-window_overlap" gives an empty slice when the overlap
        # is 0 (x[0:-0] == x[0:0]).
        # NOTE(review): assumes compute() keeps the spatial size of its input — confirm.
        nrows, ncols = output.shape[-2:]
        core = output[...,
                      window_overlap:nrows - window_overlap,
                      window_overlap:ncols - window_overlap]

        # Use of the lock to avoid writing in //
        with write_lock:
            with rasterio.open(output_image, mode="r+") as dst:
                if rasterprocessing.per_band_algo:
                    # here bands only contain a single item which is the band number
                    dst.write_band(bands[0], core[0], window=write_window)
                else:
                    dst.write(core, window=write_window)
def _read_dataset(src, bands: List[int], window: Window, pad: tuple, pad_mode: str):
    """Read a window of a source dataset as a masked array, padding it if needed.

    Args:
        src:
            Source dataset as given by rasterio.open(...)
        bands ([int]):
            List of bands to read or None if all bands shall be read. In both cases
            the data read is 3-dimensional (band, row, column).
        window (:obj:`rasterio.windows.Window`):
            Window of data to read
        pad (tuple of tuple of int):
            Pad to apply to the read data: (padx, pady). padx
            and pady are also a tuple. The first (resp. second) value is the
            pad to apply at the beginning (resp. end) of the window.
        pad_mode:
            Method to pad data
            (See https://numpy.org/doc/stable/reference/generated/numpy.pad.html)

    Returns:
        The numpy masked array. When padding is applied, the mask is padded with the
        same mode and the fill value of the read data is kept.
    """
    if bands is not None:
        dataset = src.read(bands, window=window, masked=True)
    else:
        dataset = src.read(window=window, masked=True)

    padx, pady = pad
    if padx == (0, 0) and pady == (0, 0):
        # window fully inside the image: nothing to pad
        return dataset

    # no padding on the band axis, then rows (y) and columns (x)
    pad_width = [(0, 0), pady, padx]
    # np.pad returns a plain ndarray of the data (the mask is dropped), so the mask
    # is padded separately and reattached below
    pad_dataset = np.pad(dataset, pad_width=pad_width, mode=pad_mode)
    if ma.is_masked(dataset):
        mask = np.pad(dataset.mask, pad_width=pad_width, mode=pad_mode)
    else:
        mask = ma.nomask
    # keep the fill value set when reading so that padded windows are filled the
    # same way as unpadded ones
    return ma.masked_array(pad_dataset, mask=mask, fill_value=dataset.fill_value)
def _sliding_windows(image_size, window_size, overlap):
    """Create a generator of windows of a given size with an overlap.

    (*) = Image boundaries
    (-) = Windows boundaries

    The 1st window starts at position (-overlap, -overlap) in the
    coordinates reference of the image. The last window ends at position
    (width + overlap, height + overlap) in the coordinates reference of the image.
    Consecutive windows are shifted by the window size minus twice the overlap.

    ```
    (-o,-o)
       |-----------| ... |-----------|
       |  (0,0)    |     |           |
       |    *******|*****|*******    |
       |    *      |     |      *    |
       |-----------| ... |-----------|
            *                   *
           ...                 ...
            *                   *
       |-----------| ... |-----------|
       |    *      |     |      *    |
       |    *******|*****|*******    |
       |           |     |   (w,h)   |
       |-----------| ... |-----------|
                                 (w+o,h+o)
    ```

    Args:
        image_size (tuple or int):
            Total size of the image that is windowed
        window_size (tuple or int):
            Window size
        overlap (tuple or int):
            Number of pixels that overlap previous (or next) window

    Returns:
        A generator of tuples. Each tuple contains: the window to read,
        the pad to apply when the window is on the edge of the image, the
        corresponding window of "useful data" (the window without the
        overlapping pixels). Pad is given as a tuple (padx, pady). padx
        and pady are also a tuple. The first (resp. second) value is the
        pad to apply at the beginning (resp. end) of the window.

    Raises:
        ValueError: (on first iteration) if the window size is not greater than
            twice the overlap, which would give a null or negative shift.
    """
    w_width, w_height = utils.to_tuple(window_size)
    width, height = utils.to_tuple(image_size)
    c_overlap, r_overlap = utils.to_tuple(overlap)

    # shift is the window size reduced by the overlap on both sides
    col_shift = w_width - 2 * c_overlap
    row_shift = w_height - 2 * r_overlap
    if col_shift <= 0 or row_shift <= 0:
        raise ValueError(
            f"Window size {(w_width, w_height)} must be greater than twice "
            f"the overlap {(c_overlap, r_overlap)}")

    # compute 2d slices from the normalized parameters
    slices = utils.slices_2d((w_width, w_height),
                             (col_shift, row_shift),
                             # stop is increased by overlap
                             (width + c_overlap, height + r_overlap),
                             # start is decreased by overlap
                             (-c_overlap, -r_overlap))

    # iterate the slices and compute windows and padding
    for row_start, row_stop, col_start, col_stop in slices:
        # compute padding when the window is partially outside the image
        padx = (max(0, -col_start), max(0, col_stop - width))
        pady = (max(0, -row_start), max(0, row_stop - height))
        # read window should not be outside the actual dataset window
        r_window = Window.from_slices((max(0, row_start), min(height, row_stop)),
                                      (max(0, col_start), min(width, col_stop)))
        # write window corresponds to the window without overlap
        w_window = Window.from_slices((row_start + r_overlap, row_stop - r_overlap),
                                      (col_start + c_overlap, col_stop - c_overlap))
        yield r_window, (padx, pady), w_window