1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
|
- # flake8: noqa: E402
- import argparse
- import warnings
- from pathlib import Path
- from typing import Optional, Tuple, Union
- import xarray
- warnings.filterwarnings("ignore", category=UserWarning)
- import math
- from dataclasses import dataclass
- import numpy as np
- import rioxarray
- from deadtrees.utils.data_handling import (
- make_blocks_vectorized,
- unmake_blocks_vectorized,
- )
@dataclass
class TileInfo:
    """Geometry summary of an inspected raster.

    Instances are built by ``inspect_tile``, which fills both fields from
    the shape of the raster it inspected.
    """

    # Shape of the inspected array, as reported by its ``.shape``.
    size: Tuple[int, int]
    # Number of subtiles needed to cover ``size`` along each of the two axes
    # (rounded up, so a partially covered subtile still counts).
    subtiles: Tuple[int, int]
def divisible_without_remainder(a: int, b: int) -> bool:
    """Return whether ``a`` is an exact multiple of ``b``.

    A zero divisor is reported as "not divisible" (``False``) instead of
    raising ``ZeroDivisionError``, so callers can use the result directly
    as a validity check.
    """
    if b == 0:
        return False
    # The comparison already yields a bool; no conditional expression needed.
    return a % b == 0
def inspect_tile(
    infile: Union[str, Path, xarray.DataArray],
    tile_shape: Tuple[int, int] = (8192, 8192),
    subtile_shape: Tuple[int, int] = (512, 512),
) -> TileInfo:
    """Report the size of a raster and how many subtiles are needed to cover it.

    Args:
        infile: Path of a raster readable by ``rioxarray.open_rasterio`` or an
            already opened ``xarray.DataArray``. For a path, band 1 is
            selected (and the band dimension dropped) before the shape is
            read; a ``DataArray`` is used as passed in.
        tile_shape: Shape of the (padded) tile along the two axes.
        subtile_shape: Shape of one subtile along the two axes.

    Returns:
        A ``TileInfo`` whose ``size`` is the shape of the inspected array and
        whose ``subtiles`` is ``ceil(size / subtile_shape)`` per axis.

    Raises:
        ValueError: If ``tile_shape`` is not an exact multiple of
            ``subtile_shape`` along either axis.
    """
    if isinstance(infile, xarray.DataArray):
        # The array is owned by the caller. Entering it as a context manager
        # would close it on exit, so only its shape is read here.
        shape = tuple(infile.shape)
    else:
        # The raster is opened here, so it is also closed here. The context
        # manager wraps the opened raster itself rather than a derived
        # selection of it.
        with rioxarray.open_rasterio(infile) as raster:
            shape = tuple(raster.sel(band=1, drop=True).shape)

    # Axis 0 is reported as "v" and axis 1 as "h" in the error messages.
    for axis, label in enumerate(("v", "h")):
        if not divisible_without_remainder(tile_shape[axis], subtile_shape[axis]):
            raise ValueError(
                f"Shapes unaligned ({label}): {tile_shape[axis], subtile_shape[axis]}"
            )

    # Round up so that a partially covered subtile is still counted.
    subtiles = (
        math.ceil(shape[0] / subtile_shape[0]),
        math.ceil(shape[1] / subtile_shape[1]),
    )
    return TileInfo(size=shape, subtiles=subtiles)
class Tiler:
    """Cut a raster into square subtiles and stitch processed subtiles back.

    Typical flow: ``load_file`` reads a raster and zero-pads it to
    ``tile_shape``; ``get_batches`` returns the subtiles that overlap the
    original raster; ``put_batches`` takes one result per returned subtile
    and reassembles them; ``write_file`` stores the reassembled single-band
    result, cropped back to the original raster size.
    """

    def __init__(
        self,
        infile: Optional[Union[str, Path]] = None,
        tile_shape: Optional[Tuple[int, int]] = (2048, 2048),
        subtile_shape: Optional[Tuple[int, int]] = (256, 256),
    ) -> None:
        """Store the configuration; no file is opened until ``load_file``.

        Args:
            infile: Optional raster path, only remembered here.
            tile_shape: Shape the raster is padded to.
            subtile_shape: Shape of one subtile; must be square.

        Raises:
            ValueError: If ``subtile_shape`` is not square.
        """
        self._infile = infile
        self._tile_shape = tile_shape
        self._subtile_shape = subtile_shape
        self._require_square(subtile_shape)

        self._source: Optional[xarray.DataArray] = None
        self._target: Optional[xarray.DataArray] = None
        self._indata: Optional[np.ndarray] = None
        self._outdata: Optional[np.ndarray] = None
        # Shape of the full (unfiltered) subtile array of the loaded file.
        self._batch_shape: Optional[Tuple[int, ...]] = None
        # Flat boolean mask: True for subtiles that overlap the raster.
        self._subtiles_to_use: Optional[np.ndarray] = None
        self._tile_info: Optional[TileInfo] = None

    @staticmethod
    def _require_square(subtile_shape: Tuple[int, int]) -> None:
        """Raise ``ValueError`` unless both subtile dimensions are equal."""
        if subtile_shape[0] != subtile_shape[1]:
            raise ValueError("Subtile required to have matching x/y dims")

    def load_file(
        self,
        infile: Union[str, Path],
        tile_shape: Optional[Tuple[int, int]] = None,
        subtile_shape: Optional[Tuple[int, int]] = None,
    ) -> None:
        """Read a raster and prepare padded input/output buffers for it.

        Args:
            infile: Raster path passed to ``rioxarray.open_rasterio``.
            tile_shape: Overrides the configured tile shape when given.
            subtile_shape: Overrides the configured subtile shape when given;
                must be square.

        Raises:
            ValueError: If ``subtile_shape`` is not square, if the tile shape
                is not a multiple of the subtile shape (raised by
                ``inspect_tile``), or if the raster is larger than the tile.
        """
        self._infile = infile
        self._tile_shape = tile_shape or self._tile_shape
        if subtile_shape:
            self._require_square(subtile_shape)
            self._subtile_shape = subtile_shape

        self._tile_info = inspect_tile(
            self._infile, self._tile_shape, self._subtile_shape
        )

        # Fail early with a clear message. Without this check the padding
        # step below fails with a NumPy broadcasting ValueError instead.
        height, width = self._tile_info.size[0], self._tile_info.size[1]
        if height > self._tile_shape[0] or width > self._tile_shape[1]:
            raise ValueError(
                f"Raster of size {self._tile_info.size} does not fit into "
                f"tile shape {tuple(self._tile_shape)}"
            )

        self._source = rioxarray.open_rasterio(
            self._infile, chunks={"band": 4, "x": 256, "y": 256}
        )

        # Define the padded input array and place the original data inside.
        source_values = self._source.values
        if self._tile_shape != self._tile_info.size:
            # Use the band count of the source rather than a fixed 4, so the
            # padded and unpadded branches always agree on the band count.
            self._indata = np.zeros(
                (source_values.shape[0], *self._tile_shape),
                dtype=self._source.dtype,
            )
            self._indata[
                :, 0 : source_values.shape[1], 0 : source_values.shape[2]
            ] = source_values
        else:
            self._indata = source_values

        # Output DataArray (single band), derived from band 1 of the source.
        self._target = (
            self._source.sel(band=1, drop=True).astype("uint8").copy(deep=True)
        )

        # Padded output array covering the whole tile.
        self._outdata = np.zeros(self._tile_shape, dtype="uint8")

        # Mark only the subtiles that overlap the raster.
        subtiles_mask = np.zeros(
            (
                self._tile_shape[0] // self._subtile_shape[0],
                self._tile_shape[1] // self._subtile_shape[1],
            ),
            dtype=bool,
        )
        subtiles_mask[
            0 : self._tile_info.subtiles[0], 0 : self._tile_info.subtiles[1]
        ] = True
        self._subtiles_to_use = subtiles_mask.ravel()

        # A shape cached for a previously loaded file no longer applies.
        self._batch_shape = None

    def write_file(self, outfile: Union[str, Path]) -> None:
        """Write the reassembled result as an LZW-compressed, tiled raster.

        Does nothing when no file has been loaded yet.
        """
        if self._target is not None:
            # Copy the data from the padded output array into the DataArray,
            # cropped to the original raster size.
            self._target[:] = self._outdata[
                0 : self._tile_info.size[0], 0 : self._tile_info.size[1]
            ]
            self._target.rio.to_raster(outfile, compress="LZW", tiled=True)

    def get_batches(self) -> np.ndarray:
        """Return the subtiles of the padded input that overlap the raster."""
        subtiles = make_blocks_vectorized(self._indata, self._subtile_shape[0])
        if self._batch_shape is None:
            self._batch_shape = subtiles.shape
        return subtiles[self._subtiles_to_use]

    def _expand_batches(self, batches: np.ndarray) -> np.ndarray:
        """Scatter ``batches`` into a zero-filled stack covering all subtiles.

        Entry ``i`` of ``batches`` is placed at the position of the ``i``-th
        True flag of the subtile mask; every other position stays zero. The
        dtype of ``batches`` is preserved. Surplus entries are ignored.

        Raises:
            IndexError: If fewer batches are supplied than subtiles are used.
        """
        mask = np.asarray(self._subtiles_to_use, dtype=bool)
        batches = np.asarray(batches)
        used = int(np.count_nonzero(mask))
        if batches.shape[0] < used:
            raise IndexError(
                f"Expected {used} batches but only got {batches.shape[0]}"
            )
        expanded = np.zeros((mask.size, *batches.shape[1:]), dtype=batches.dtype)
        expanded[mask] = batches[:used]
        return expanded

    def put_batches(self, batches: np.ndarray) -> None:
        """Reassemble processed subtiles into the output tile.

        Args:
            batches: One result per subtile returned by ``get_batches``, in
                the same order.

        Raises:
            IndexError: If fewer batches are supplied than subtiles are used.
        """
        self._outdata = unmake_blocks_vectorized(
            self._expand_batches(batches),
            self._subtile_shape[0],
            self._tile_shape[0],
            self._tile_shape[1],
        )

        # Pass the data into the geo-registered DataArray (only the part of
        # the padded tile that the original raster covers).
        self._target = self._target.load()
        self._target.loc[:] = self._outdata[
            0 : self._tile_info.size[0], 0 : self._tile_info.size[1]
        ]
|