Skip to content

Commit a230e08

Browse files
authored
NHIRS-8 - Improve memory usage for large raster input (#41)
* NHIRS-8: Improve memory usage for large raster input * No longer reads the entire raster into memory, reads only the cells defined in the exposure data * Removed the ability to pass an in-memory array via 'from_array' * Added a 'ThreadPoolExecutor' for some performance improvement when reading hazard data for large exposure datasets * NHIRS-8: Update docstring Co-authored-by: Callum McKenna <[email protected]>
1 parent 7d0d83a commit a230e08

File tree

6 files changed

+112
-171
lines changed

6 files changed

+112
-171
lines changed

hazimp/jobs/jobs.py

Lines changed: 27 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -612,13 +612,9 @@ def __init__(self):
612612
super(LoadRaster, self).__init__()
613613
self.call_funct = LOADRASTER
614614

615-
# R0913:326: Too many arguments (9/6)
616-
# pylint: disable=R0913
617-
def __call__(self, context, attribute_label,
615+
def __call__(self, context, attribute_label, file_list,
618616
clip_exposure2all_hazards=False,
619-
file_list=None, file_format=None, variable=None,
620-
raster=None, upper_left_x=None, upper_left_y=None,
621-
cell_size=None, no_data_value=None):
617+
file_format=None, variable=None, no_data_value=None):
622618
"""
623619
Load one or more files and get the value for all the
624620
exposure points. All files have to be of the same attribute.
@@ -628,79 +624,45 @@ def __call__(self, context, attribute_label,
628624
:param attribute_label: The string to be associated with this data.
629625
:param clip_exposure2all_hazards: True if the exposure data is
630626
clippped to the hazard data, so no hazard values are ignored.
631-
632627
:param file_list: A list of files or a single file to be loaded.
633-
OR
634-
:param raster: A 2D numeric array of the raster values, North is up.
635-
:param upper_left_x: The longitude at the upper left corner.
636-
:param upper_left_y: The latitude at the upper left corner.
637-
:param cell_size: The cell size.
638628
:param no_data_value: Values in the raster that represent no data.
639629
640-
641630
Context return:
642631
exposure_att: Add the file values into this dictionary.
643632
key: column titles
644633
value: column values, except the title
645634
"""
646635

647-
# We need a file or a full set of raster info.
648-
if file_list is None:
649-
# The raster info is being passed as an array
650-
assert raster is not None
651-
assert upper_left_x is not None
652-
assert upper_left_y is not None
653-
assert cell_size is not None
654-
assert no_data_value is not None
655-
a_raster = raster_module.Raster.from_array(
656-
raster, upper_left_x,
657-
upper_left_y,
658-
cell_size,
659-
no_data_value)
660-
661-
if clip_exposure2all_hazards:
662-
# Reduce the context to the hazard area
663-
# before the raster info has been added to the context
664-
extent = a_raster.extent()
665-
context.clip_exposure(*extent)
666-
667-
file_data = a_raster.raster_data_at_points(
668-
context.exposure_long,
669-
context.exposure_lat)
670-
file_data = np.where(file_data == no_data_value, np.NAN,
671-
file_data)
672-
context.exposure_att[attribute_label] = file_data
673-
else:
674-
if isinstance(file_list, str):
675-
file_list = [file_list]
676-
677-
for f in file_list:
678-
f = misc.download_file_from_s3_if_needed(f)
679-
dt = misc.get_file_mtime(f)
680-
atts = {"dcterms:title": "Source hazard data",
681-
"prov:type": "prov:Dataset",
682-
"prov:atLocation": os.path.basename(f),
683-
"prov:format": os.path.splitext(f)[1].replace('.', ''),
684-
"prov:generatedAtTime": dt, }
685-
if file_format == 'nc' and variable:
686-
atts['prov:variable'] = variable
687-
hazent = context.prov.entity(":Hazard data", atts)
688-
context.prov.used(context.provlabel, hazent)
636+
if isinstance(file_list, str):
637+
file_list = [file_list]
689638

639+
for f in file_list:
640+
f = misc.download_file_from_s3_if_needed(f)
641+
dt = misc.get_file_mtime(f)
642+
atts = {"dcterms:title": "Source hazard data",
643+
"prov:type": "prov:Dataset",
644+
"prov:atLocation": os.path.basename(f),
645+
"prov:format": os.path.splitext(f)[1].replace('.', ''),
646+
"prov:generatedAtTime": dt, }
690647
if file_format == 'nc' and variable:
691-
file_list = misc.mod_file_list(file_list, variable)
648+
atts['prov:variable'] = variable
649+
hazent = context.prov.entity(":Hazard data", atts)
650+
context.prov.used(context.provlabel, hazent)
651+
652+
if file_format == 'nc' and variable:
653+
file_list = misc.mod_file_list(file_list, variable)
692654

693-
file_data, extent = raster_module.files_raster_data_at_points(
694-
context.exposure_long,
695-
context.exposure_lat, file_list)
696-
file_data[file_data == no_data_value] = np.NAN
655+
file_data, extent = raster_module.files_raster_data_at_points(
656+
context.exposure_long,
657+
context.exposure_lat, file_list)
658+
file_data[file_data == no_data_value] = np.NAN
697659

698-
context.exposure_att[attribute_label] = file_data
660+
context.exposure_att[attribute_label] = file_data
699661

700-
if clip_exposure2all_hazards:
701-
# Clipping the exposure points after the data has been added.
702-
# Not optimised for speed, but easy to implement.
703-
context.clip_exposure(*extent)
662+
if clip_exposure2all_hazards:
663+
# Clipping the exposure points after the data has been added.
664+
# Not optimised for speed, but easy to implement.
665+
context.clip_exposure(*extent)
704666

705667

706668
class AggregateLoss(Job):

hazimp/raster.py

Lines changed: 28 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,13 @@
1717

1818
"""
1919
Manipulate raster data.
20-
21-
Currently loads the entire raster layer into memory, which can blow out memory
22-
usage in some cases.
23-
24-
# TODO: optimise raster loading and reading
2520
"""
21+
import threading
22+
from concurrent.futures.thread import ThreadPoolExecutor
2623

27-
import numpy
2824
import gdal
29-
from gdalconst import GA_ReadOnly, GDT_Float32
25+
import numpy
26+
from gdalconst import GA_ReadOnly
3027

3128

3229
class Raster(object):
@@ -43,13 +40,11 @@ class Raster(object):
4340
# R0913: 34:Raster.__init__: Too many arguments (9/6)
4441
# pylint: disable=R0902, R0913
4542

46-
def __init__(self, raster, upper_left_x, upper_left_y,
43+
def __init__(self, filename, upper_left_x, upper_left_y,
4744
x_pixel, y_pixel, no_data_value, x_size, y_size):
4845
"""
4946
50-
:param raster: A 2D numeric array of the raster values, North is up.
51-
The values are listed in 'English reading order' i.e.
52-
left-right and top-down.
47+
:param filename: The raster file path string.
5348
:param upper_left_x: The longitude at the upper left corner of the
5449
top left pixel.
5550
:param upper_left_y: The latitude at the upper left corner of the
@@ -62,23 +57,22 @@ def __init__(self, raster, upper_left_x, upper_left_y,
6257
:param y_size: Number of rows.
6358
:param no_data_value: Values in the raster that represent no data.
6459
"""
65-
self.raster = raster
60+
self.filename = filename
6661
self.ul_x = upper_left_x
6762
self.ul_y = upper_left_y
6863
self.x_pixel = x_pixel
6964
self.y_pixel = y_pixel
7065
self.no_data_value = no_data_value
7166
self.x_size = x_size
7267
self.y_size = y_size
73-
self.raster[self.raster == self.no_data_value] = numpy.NAN
7468

7569
@classmethod
7670
def from_file(cls, filename):
7771
"""
7872
Load a file in a raster file format known to GDAL.
7973
Note, image must be 'North up'.
8074
81-
:param filename: The csv file path string.
75+
:param filename: The raster file path string.
8276
:returns: A Raster instance.
8377
"""
8478

@@ -99,40 +93,9 @@ def from_file(cls, filename):
9993
y_size = dataset.RasterYSize # This will be a negative value.
10094
band = dataset.GetRasterBand(1)
10195
no_data_value = band.GetNoDataValue()
102-
raster = band.ReadAsArray(0, 0, x_size, y_size, buf_type=GDT_Float32)
103-
instance = cls(raster, upper_left_x, upper_left_y,
96+
instance = cls(filename, upper_left_x, upper_left_y,
10497
x_pixel, y_pixel, no_data_value, x_size, y_size)
105-
return instance
106-
107-
@classmethod
108-
def from_array(cls, raster, upper_left_x, upper_left_y,
109-
cell_size, no_data_value, dtype='float'):
110-
"""
111-
Convert numeric array of raster data and info to a raster instance.
112-
The values are listed in 'English reading order' i.e.
113-
left-right and top-down.
114-
115-
:param raster: A 2D numeric array of the raster values, North is up.
116-
:param upper_left_x: The longitude at the upper left corner.
117-
:param upper_left_y: The latitude at the upper left corner.
118-
:param cell_size: The cell size.
119-
:param no_data_value: Values in the raster that represent no data.
120-
:param dtype: Data type for the raster values (default float).
121-
:returns: A Raster instance
122-
"""
123-
raster = numpy.array(raster, dtype=dtype, copy=False)
124-
if not len(raster.shape) == 2:
125-
msg = ('Bad Raster shape %s' % (str(raster.shape)))
126-
raise TypeError(msg)
127-
128-
x_size = raster.shape[1]
129-
y_size = raster.shape[0]
130-
131-
x_pixel = cell_size
132-
y_pixel = -cell_size
13398

134-
instance = cls(raster, upper_left_x, upper_left_y,
135-
x_pixel, y_pixel, no_data_value, x_size, y_size)
13699
return instance
137100

138101
def raster_data_at_points(self, lon, lat):
@@ -168,11 +131,25 @@ def raster_data_at_points(self, lon, lat):
168131
raw_row_offset = (lat - self.ul_y) / self.y_pixel
169132
row_offset = numpy.trunc(raw_row_offset).astype(int)
170133

171-
values[good_indexes] = self.raster[row_offset[good_indexes],
172-
col_offset[good_indexes]]
173-
# Change NODATA_value to NAN
174-
values = numpy.where(values == self.no_data_value, numpy.NAN,
175-
values)
134+
data = threading.local()
135+
136+
def read_cell(i, x, y):
137+
if 'band' not in data.__dict__:
138+
data.dataset = gdal.Open(self.filename, GA_ReadOnly)
139+
data.band = data.dataset.GetRasterBand(1)
140+
141+
values[i] = data.band.ReadAsArray(x, y, 1, 1)[0]
142+
143+
with ThreadPoolExecutor() as executor:
144+
for index in good_indexes:
145+
executor.submit(read_cell, index,
146+
col_offset[index].item(),
147+
row_offset[index].item())
148+
149+
# Change NODATA_value to NAN
150+
values = numpy.where(values == self.no_data_value, numpy.NAN,
151+
values)
152+
176153
return values
177154

178155
def extent(self):

tests/data/basic_raster.aai

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
ncols 3
2+
nrows 2
3+
xllcorner +0.
4+
yllcorner +8.
5+
cellsize 1
6+
NODATA_value -9999
7+
1 2 -9999
8+
4 5 6

tests/jobs/test_jobs.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from hazimp import context
4747
from hazimp import misc
4848
from hazimp import parallel
49+
from tests import CWD
4950
from tests.jobs.test_vulnerability_model import build_example1
5051

5152
prov = mock.MagicMock(name='prov.model')
@@ -569,20 +570,11 @@ def test_load_raster_clippingIII(self, mock_used):
569570
inst(con_in, **test_kwargs)
570571
os.remove(f.name)
571572

572-
raster = array([[1, 2, -9999], [4, 5, 6]])
573-
upper_left_x = 0
574-
upper_left_y = 10
575-
cell_size = 1
576-
no_data_value = -9999
577573
haz_v = 'haz_v'
578574
inst = JOBS[LOADRASTER]
579575
test_kwargs = {'attribute_label': haz_v,
580576
'clip_exposure2all_hazards': True,
581-
'raster': raster,
582-
'upper_left_x': upper_left_x,
583-
'upper_left_y': upper_left_y,
584-
'cell_size': cell_size,
585-
'no_data_value': no_data_value}
577+
'file_list': [str(CWD / 'data/basic_raster.aai')]}
586578
inst(con_in, **test_kwargs)
587579

588580
# There should be only no exposure points

tests/test_raster.py

Lines changed: 2 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@
3131
Test the raster module.
3232
"""
3333

34-
import unittest
35-
import tempfile
3634
import os
35+
import tempfile
36+
import unittest
3737

3838
import numpy
3939
from scipy import asarray, allclose, nan
@@ -179,42 +179,6 @@ def test2_files_raster_data_at_points(self):
179179
for a_file in files:
180180
os.remove(a_file)
181181

182-
def test3_raster_data_from_array(self):
183-
# A test based on this info;
184-
# http://en.wikipedia.org/wiki/Esri_grid
185-
# Let's hope no one edits the data....
186-
raster = [[-9999, -9999, 5, 2], [-9999, 20, 100, 36],
187-
[3, 8, 35, 10], [32, 42, 50, 6],
188-
[88, 75, 27, 9], [13, 5, 1, -9999]]
189-
upper_left_x = 0.
190-
upper_left_y = 300.
191-
cell_size = 50.0
192-
no_data_value = -9999
193-
194-
# Just outside the midpoint of all sides
195-
lon = asarray([125, 125, 125, 125, 125, 125])
196-
lat = asarray([275, 225, 175, 125, 75, 25])
197-
198-
raster = Raster.from_array(raster, upper_left_x, upper_left_y,
199-
cell_size, no_data_value)
200-
self.assertEqual(raster.ul_x, 0)
201-
self.assertEqual(raster.ul_y, 300)
202-
self.assertEqual(raster.x_pixel, 50)
203-
self.assertEqual(raster.y_pixel, -50)
204-
self.assertEqual(raster.x_size, 4)
205-
self.assertEqual(raster.y_size, 6)
206-
207-
data = raster.raster_data_at_points(lon, lat)
208-
self.assertTrue(allclose(data, asarray([5.0, 100.0, 35.0,
209-
50.0, 27.0, 1.0])))
210-
211-
# testing extent
212-
min_long, min_lat, max_long, max_lat = raster.extent()
213-
self.assertEqual(min_long, 0)
214-
self.assertEqual(min_lat, 0)
215-
self.assertEqual(max_long, 200)
216-
self.assertEqual(max_lat, 300)
217-
218182
def test3_recalc_max(self):
219183
max_extent = (0, 0, 0, 0)
220184
extent = [-10, -20, 20, 40]
@@ -232,7 +196,6 @@ def test3_recalc_max(self):
232196
self.assertEqual(old_max_extent, max_extent)
233197

234198

235-
# -------------------------------------------------------------
236199
if __name__ == "__main__":
237200
Suite = unittest.makeSuite(TestRaster, 'test')
238201
Runner = unittest.TextTestRunner()

0 commit comments

Comments
 (0)