Skip to content

Commit 270bf04

Browse files
author
Noelle Cheng
committed
sampler
1 parent 1d1f43c commit 270bf04

File tree

6 files changed

+100
-91
lines changed

6 files changed

+100
-91
lines changed

map2loop/mapdata.py

Lines changed: 0 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1433,63 +1433,6 @@ def get_value_from_raster(self, datatype: Datatype, x, y):
14331433
val = data.ReadAsArray(px, py, 1, 1)[0][0]
14341434
return val
14351435

1436-
@beartype.beartype
1437-
def __value_from_raster(self, inv_geotransform, data, x: float, y: float):
1438-
"""
1439-
Get the value from a raster dataset at the specified point
1440-
1441-
Args:
1442-
inv_geotransform (gdal.GeoTransform):
1443-
The inverse of the data's geotransform
1444-
data (numpy.array):
1445-
The raster data
1446-
x (float):
1447-
The easting coordinate of the value
1448-
y (float):
1449-
The northing coordinate of the value
1450-
1451-
Returns:
1452-
float or int: The value at the point specified
1453-
"""
1454-
px = int(inv_geotransform[0] + inv_geotransform[1] * x + inv_geotransform[2] * y)
1455-
py = int(inv_geotransform[3] + inv_geotransform[4] * x + inv_geotransform[5] * y)
1456-
# Clamp values to the edges of raster if past boundary, similar to GL_CLIP
1457-
px = max(px, 0)
1458-
px = min(px, data.shape[0] - 1)
1459-
py = max(py, 0)
1460-
py = min(py, data.shape[1] - 1)
1461-
return data[px][py]
1462-
1463-
@beartype.beartype
1464-
def get_value_from_raster_df(self, datatype: Datatype, df: pandas.DataFrame):
1465-
"""
1466-
Add a 'Z' column to a dataframe with the heights from the 'X' and 'Y' coordinates
1467-
1468-
Args:
1469-
datatype (Datatype):
1470-
The datatype of the raster map to retrieve from
1471-
df (pandas.DataFrame):
1472-
The original dataframe with 'X' and 'Y' columns
1473-
1474-
Returns:
1475-
pandas.DataFrame: The modified dataframe
1476-
"""
1477-
if len(df) <= 0:
1478-
df["Z"] = []
1479-
return df
1480-
data = self.get_map_data(datatype)
1481-
if data is None:
1482-
logger.warning("Cannot get value from data as data is not loaded")
1483-
return None
1484-
1485-
inv_geotransform = gdal.InvGeoTransform(data.GetGeoTransform())
1486-
data_array = numpy.array(data.GetRasterBand(1).ReadAsArray().T)
1487-
1488-
df["Z"] = df.apply(
1489-
lambda row: self.__value_from_raster(inv_geotransform, data_array, row["X"], row["Y"]),
1490-
axis=1,
1491-
)
1492-
return df
14931436

14941437
@beartype.beartype
14951438
def extract_all_contacts(self, save_contacts=True):

map2loop/project.py

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# internal imports
22
from map2loop.fault_orientation import FaultOrientationNearest
3-
from .utils import hex_to_rgb
3+
from .utils import hex_to_rgb, set_z_values_from_raster_df
44
from .m2l_enums import VerboseLevel, ErrorState, Datatype
55
from .mapdata import MapData
66
from .sampler import Sampler, SamplerDecimator, SamplerSpacing
@@ -506,23 +506,19 @@ def sample_map_data(self):
506506
logger.info(
507507
f"Sampling geology map data using {self.samplers[Datatype.GEOLOGY].sampler_label}"
508508
)
509-
self.geology_samples = self.samplers[Datatype.GEOLOGY].sample(
510-
self.map_data.get_map_data(Datatype.GEOLOGY), self.map_data
511-
)
512-
logger.info(
513-
f"Sampling structure map data using {self.samplers[Datatype.STRUCTURE].sampler_label}"
514-
)
515-
self.structure_samples = self.samplers[Datatype.STRUCTURE].sample(
516-
self.map_data.get_map_data(Datatype.STRUCTURE), self.map_data
517-
)
509+
geology_data = self.map_data.get_map_data(Datatype.GEOLOGY)
510+
dtm_data = self.map_data.get_map_data(Datatype.DTM)
511+
512+
self.geology_samples = self.samplers[Datatype.GEOLOGY].sample(geology_data)
513+
logger.info(f"Sampling structure map data using {self.samplers[Datatype.STRUCTURE].sampler_label}")
514+
515+
self.structure_samples = self.samplers[Datatype.STRUCTURE].sample(self.map_data.get_map_data(Datatype.STRUCTURE), dtm_data, geology_data)
518516
logger.info(f"Sampling fault map data using {self.samplers[Datatype.FAULT].sampler_label}")
519-
self.fault_samples = self.samplers[Datatype.FAULT].sample(
520-
self.map_data.get_map_data(Datatype.FAULT), self.map_data
521-
)
517+
518+
self.fault_samples = self.samplers[Datatype.FAULT].sample(self.map_data.get_map_data(Datatype.FAULT))
522519
logger.info(f"Sampling fold map data using {self.samplers[Datatype.FOLD].sampler_label}")
523-
self.fold_samples = self.samplers[Datatype.FOLD].sample(
524-
self.map_data.get_map_data(Datatype.FOLD), self.map_data
525-
)
520+
521+
self.fold_samples = self.samplers[Datatype.FOLD].sample(self.map_data.get_map_data(Datatype.FOLD))
526522

527523
def extract_geology_contacts(self):
528524
"""
@@ -532,11 +528,9 @@ def extract_geology_contacts(self):
532528
self.map_data.extract_basal_contacts(self.stratigraphic_column.column)
533529

534530
# sample the contacts
535-
self.map_data.sampled_contacts = self.samplers[Datatype.GEOLOGY].sample(
536-
self.map_data.basal_contacts
537-
)
538-
539-
self.map_data.get_value_from_raster_df(Datatype.DTM, self.map_data.sampled_contacts)
531+
self.map_data.sampled_contacts = self.samplers[Datatype.GEOLOGY].sample(self.map_data.basal_contacts)
532+
dtm_data = self.map_data.get_map_data(Datatype.DTM)
533+
set_z_values_from_raster_df(dtm_data, self.map_data.sampled_contacts)
540534

541535
def calculate_stratigraphic_order(self, take_best=False):
542536
"""
@@ -714,7 +708,8 @@ def calculate_fault_orientations(self):
714708
self.map_data.get_map_data(Datatype.FAULT_ORIENTATION),
715709
self.map_data,
716710
)
717-
self.map_data.get_value_from_raster_df(Datatype.DTM, self.fault_orientations)
711+
dtm_data = self.map_data.get_map_data(Datatype.DTM)
712+
set_z_values_from_raster_df(dtm_data, self.fault_orientations)
718713
else:
719714
logger.warning(
720715
"No fault orientation data found, skipping fault orientation calculation"
@@ -739,7 +734,8 @@ def summarise_fault_data(self):
739734
"""
740735
Use the fault shapefile to make a summary of each fault by name
741736
"""
742-
self.map_data.get_value_from_raster_df(Datatype.DTM, self.fault_samples)
737+
dtm_data = self.map_data.get_map_data(Datatype.DTM)
738+
set_z_values_from_raster_df(dtm_data, self.fault_samples)
743739

744740
self.deformation_history.summarise_data(self.fault_samples)
745741
self.deformation_history.faults = self.throw_calculator.compute(

map2loop/sampler.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# internal imports
22
from .m2l_enums import Datatype
33
from .mapdata import MapData
4+
from .utils import set_z_values_from_raster_df
45

56
# external imports
67
from abc import ABC, abstractmethod
@@ -10,6 +11,7 @@
1011
import shapely
1112
import numpy
1213
from typing import Optional
14+
from osgeo import gdal
1315

1416

1517
class Sampler(ABC):
@@ -38,7 +40,7 @@ def type(self):
3840
@beartype.beartype
3941
@abstractmethod
4042
def sample(
41-
self, spatial_data: geopandas.GeoDataFrame, map_data: Optional[MapData] = None
43+
self, spatial_data: geopandas.GeoDataFrame, dtm_data: Optional[geopandas.GeoDataFrame] = None, geology_data: Optional[geopandas.GeoDataFrame] = None
4244
) -> pandas.DataFrame:
4345
"""
4446
Execute sampling method (abstract method)
@@ -73,7 +75,7 @@ def __init__(self, decimation: int = 1):
7375

7476
@beartype.beartype
7577
def sample(
76-
self, spatial_data: geopandas.GeoDataFrame, map_data: Optional[MapData] = None
78+
self, spatial_data: geopandas.GeoDataFrame, dtm_data: Optional[gdal.Dataset] = None, geology_data: Optional[geopandas.GeoDataFrame] = None
7779
) -> pandas.DataFrame:
7880
"""
7981
Execute sample method takes full point data, samples the data and returns the decimated points
@@ -87,10 +89,16 @@ def sample(
8789
data = spatial_data.copy()
8890
data["X"] = data.geometry.x
8991
data["Y"] = data.geometry.y
90-
data["Z"] = map_data.get_value_from_raster_df(Datatype.DTM, data)["Z"]
91-
data["layerID"] = geopandas.sjoin(
92-
data, map_data.get_map_data(Datatype.GEOLOGY), how='left'
93-
)['index_right']
92+
if dtm_data is not None:
93+
data["Z"] = set_z_values_from_raster_df(dtm_data, data)["Z"]
94+
else:
95+
data["Z"] = None
96+
if geology_data is not None:
97+
data["layerID"] = geopandas.sjoin(
98+
data, geology_data, how='left'
99+
)['index_right']
100+
else:
101+
data["layerID"] = None
94102
data.reset_index(drop=True, inplace=True)
95103

96104
return pandas.DataFrame(data[:: self.decimation].drop(columns="geometry"))
@@ -118,7 +126,7 @@ def __init__(self, spacing: float = 50.0):
118126

119127
@beartype.beartype
120128
def sample(
121-
self, spatial_data: geopandas.GeoDataFrame, map_data: Optional[MapData] = None
129+
self, spatial_data: geopandas.GeoDataFrame, dtm_data: Optional[geopandas.GeoDataFrame] = None, geology_data: Optional[geopandas.GeoDataFrame] = None
122130
) -> pandas.DataFrame:
123131
"""
124132
Execute sample method takes full point data, samples the data and returns the sampled points

map2loop/thickness_calculator.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
calculate_endpoints,
66
multiline_to_line,
77
find_segment_strike_from_pt,
8+
set_z_values_from_raster_df
89
)
910
from .m2l_enums import Datatype
1011
from .interpolators import DipDipDirectionInterpolator
@@ -271,7 +272,8 @@ def compute(
271272
# set the crs of the contacts to the crs of the units
272273
contacts = contacts.set_crs(crs=basal_contacts.crs)
273274
# get the elevation Z of the contacts
274-
contacts = map_data.get_value_from_raster_df(Datatype.DTM, contacts)
275+
dtm_data = map_data.get_map_data(Datatype.DTM)
276+
contacts = set_z_values_from_raster_df(dtm_data, contacts)
275277
# update the geometry of the contact points to include the Z value
276278
contacts["geometry"] = contacts.apply(
277279
lambda row: shapely.Point(row.geometry.x, row.geometry.y, row["Z"]), axis=1
@@ -299,7 +301,8 @@ def compute(
299301
# set the crs of the interpolated orientations to the crs of the units
300302
interpolated_orientations = interpolated_orientations.set_crs(crs=basal_contacts.crs)
301303
# get the elevation Z of the interpolated points
302-
interpolated = map_data.get_value_from_raster_df(Datatype.DTM, interpolated_orientations)
304+
dtm_data = map_data.get_map_data(Datatype.DTM)
305+
interpolated = set_z_values_from_raster_df(dtm_data, interpolated_orientations)
303306
# update the geometry of the interpolated points to include the Z value
304307
interpolated["geometry"] = interpolated.apply(
305308
lambda row: shapely.Point(row.geometry.x, row.geometry.y, row["Z"]), axis=1

map2loop/utils.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pandas
88
import re
99
import json
10+
from osgeo import gdal
1011

1112
from .logging import getLogger
1213
logger = getLogger(__name__)
@@ -527,4 +528,62 @@ def update_from_legacy_file(
527528
with open(json_save_path, "w") as f:
528529
json.dump(parsed_data, f, indent=4)
529530

530-
return file_map
531+
return file_map
532+
533+
@beartype.beartype
534+
def value_from_raster(inv_geotransform, data, x: float, y: float):
535+
"""
536+
Get the value from a raster dataset at the specified point
537+
538+
Args:
539+
inv_geotransform (gdal.GeoTransform):
540+
The inverse of the data's geotransform
541+
data (numpy.array):
542+
The raster data
543+
x (float):
544+
The easting coordinate of the value
545+
y (float):
546+
The northing coordinate of the value
547+
548+
Returns:
549+
float or int: The value at the point specified
550+
"""
551+
px = int(inv_geotransform[0] + inv_geotransform[1] * x + inv_geotransform[2] * y)
552+
py = int(inv_geotransform[3] + inv_geotransform[4] * x + inv_geotransform[5] * y)
553+
# Clamp values to the edges of raster if past boundary, similar to GL_CLIP
554+
px = max(px, 0)
555+
px = min(px, data.shape[0] - 1)
556+
py = max(py, 0)
557+
py = min(py, data.shape[1] - 1)
558+
return data[px][py]
559+
560+
@beartype.beartype
561+
def set_z_values_from_raster_df(dtm_data: gdal.Dataset, df: pandas.DataFrame):
562+
"""
563+
Add a 'Z' column to a dataframe with the heights from the 'X' and 'Y' coordinates
564+
565+
Args:
566+
dtm_data (gdal.Dataset):
567+
DTM data from raster map
568+
df (pandas.DataFrame):
569+
The original dataframe with 'X' and 'Y' columns
570+
571+
Returns:
572+
pandas.DataFrame: The modified dataframe
573+
"""
574+
if len(df) <= 0:
575+
df["Z"] = []
576+
return df
577+
578+
if dtm_data is None:
579+
logger.warning("Cannot get value from data as data is not loaded")
580+
return None
581+
582+
inv_geotransform = gdal.InvGeoTransform(dtm_data.GetGeoTransform())
583+
data_array = numpy.array(dtm_data.GetRasterBand(1).ReadAsArray().T)
584+
585+
df["Z"] = df.apply(
586+
lambda row: value_from_raster(inv_geotransform, data_array, row["X"], row["Y"]),
587+
axis=1,
588+
)
589+
return df

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ name = 'map2loop'
66
description = 'Generate 3D model data from 2D maps.'
77
authors = [{name = 'Loop team'}]
88
readme = 'README.md'
9-
requires-python = '>=3.8'
9+
requires-python = '>=3.8,<3.13'
1010
keywords = [ "earth sciences",
1111
"geology",
1212
"3-D modelling",

0 commit comments

Comments
 (0)