-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathmlp_interpolant_nd.py
More file actions
236 lines (210 loc) · 10.3 KB
/
mlp_interpolant_nd.py
File metadata and controls
236 lines (210 loc) · 10.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
import torch
import torch.nn as nn
from typing import List, Tuple, Sequence, Union, Callable, Optional
from src.models.interpolant_nd import SpectralInterpolationND
from src.models.mlp import MLP
class MLPSpectralInterpolationND(SpectralInterpolationND):
    """Spectral interpolant whose nodal values are generated by an internal MLP.

    Instead of optimising *one* independent parameter per spectral grid node, we learn a
    low-dimensional neural network *f_θ : ℝᴰ → ℝ* that is **only evaluated at the grid
    nodes**. Everywhere else in the domain we rely on the parent class' spectral /
    barycentric interpolation machinery exactly as before.

    This matches the typical PINN workflow where an MLP implicitly defines a continuous
    function, while still preserving the nice derivative & quadrature properties of the
    spectral grid.
    """

    ###########################################################################
    # Initialisation
    ###########################################################################
    def __init__(
        self,
        Ns: List[int],
        bases: List[str],
        domains: List[Tuple[float, float]],
        *,
        device: str = "cpu",
        fd_k: Optional[List[int]] = None,
        # ---- MLP-specific kwargs ----
        hidden_layers: Sequence[int] = (128, 128),
        # NOTE: ``torch.tanh`` is a plain function, not an nn.Module, so the
        # annotation is a Callable (the previous ``torch.nn.Module`` hint was wrong).
        activation: Callable[[torch.Tensor], torch.Tensor] = torch.tanh,
        use_mlp_for_forward: bool = False,
        use_mlp_for_derivatives: bool = False,
    ) -> None:
        """Create a spectral interpolant whose grid values come from an MLP.

        Parameters
        ----------
        Ns, bases, domains, device, fd_k
            Forwarded verbatim to :class:`SpectralInterpolationND`.
        hidden_layers, activation
            Describe the architecture of the internal MLP *f_θ*.
            WARNING: the wrapped :class:`MLP` takes a single ``hidden_dim``, so only
            ``len(hidden_layers)`` (depth) and ``hidden_layers[0]`` (width) are used;
            any differing widths after the first entry are silently ignored.
        use_mlp_for_forward
            If True, ``forward`` evaluates the raw MLP instead of interpolating.
        use_mlp_for_derivatives
            If True, use the MLP's automatic differentiation for derivatives.
            If False (default), use the spectral interpolant's derivative machinery.
        """
        # 1) Construct the parent object *first* so that we inherit all node logistics.
        super().__init__(
            Ns=Ns,
            bases=bases,
            domains=domains,
            device=device,
            fd_k=fd_k,
        )

        # Store flags
        self.use_mlp_for_forward = use_mlp_for_forward
        self.use_mlp_for_derivatives = use_mlp_for_derivatives
        print(
            f"[MLPSpectralInterpolationND] use_mlp_for_forward: {self.use_mlp_for_forward}, use_mlp_for_derivatives: {self.use_mlp_for_derivatives}"
        )

        #####################################################################
        # 2) Replace the parent's learned *Parameter* at the nodes with a   #
        #    **buffer**. The actual learnables now live inside `self.mlp`.  #
        #####################################################################
        # Keep the storage but detach it so it is not treated as an optimisable param.
        with torch.no_grad():
            _initial_values = self.values.detach().clone()
        # Remove from nn.Module parameter registry.
        del self._parameters["values"]  # type: ignore[attr-defined]
        # Re-register as a buffer – it will only be populated on-the-fly.
        self.register_buffer("values", _initial_values, persistent=False)

        # 3) Pre-compute & cache the **flattened grid coordinates** so we don't rebuild
        #    them every forward pass. Shape: (∏ₖ Nₖ, D).
        flat_coords = torch.stack([g.reshape(-1) for g in self.mesh], dim=1)
        self.register_buffer("_flat_coords", flat_coords, persistent=False)

        # 4) Build the multi-layer perceptron using the existing MLP class.
        #    Only the depth and the FIRST width of ``hidden_layers`` are honoured
        #    (see docstring warning above).
        self.mlp = MLP(
            n_dim=self.n_dim,
            n_layers=len(hidden_layers) + 1,  # +1 for output layer
            hidden_dim=hidden_layers[0],  # Use first hidden layer size
            activation=activation,
            device=device,
        )

    ###########################################################################
    # Internal helpers
    ###########################################################################
    def _compute_node_values(self) -> torch.Tensor:
        """Evaluate the MLP at *all* spectral nodes & reshape to `(N1,…,ND)`."""
        node_vals = self.mlp(self._flat_coords)  # (P, 1) where P = ∏ Nᵢ
        return node_vals.view(*self.Ns)

    def _with_node_values(self, func: Callable, *args, **kwargs):
        """Utility to call *func* while temporarily exposing fresh node values via
        `self.values` so that we can re-use the parent class' derivative helpers.
        """
        new_vals = self._compute_node_values()
        # Swap-in / swap-out pattern – gradients flow through new_vals just fine.
        old_vals = self.values
        try:
            self.values = new_vals  # type: ignore[assignment]
            return func(*args, **kwargs)
        finally:
            # Always restore, even if `func` raises, so the buffer never leaks state.
            self.values = old_vals  # type: ignore[assignment]

    ###########################################################################
    # Public API – forward / derivative wrappers
    ###########################################################################
    def interpolate(self, x_eval: Union[List[torch.Tensor], torch.Tensor], values=None):
        """Interpolate at `x_eval`; node values default to the current MLP output."""
        if values is None:
            values = self._compute_node_values()
        return super().interpolate(x_eval, values=values)

    def forward(self, x_eval: Union[List[torch.Tensor], torch.Tensor]):  # type: ignore[override]
        """Evaluate the interpolant at arbitrary points, identical signature to base."""
        if self.use_mlp_for_forward:
            # Use the MLP's own logic for both list/tuple and tensor
            out = self.mlp(x_eval)
            return out
        else:
            node_vals = self._compute_node_values()
            if isinstance(x_eval, (list, tuple)):
                return super().interpolate(x_eval, values=node_vals)
            elif isinstance(x_eval, torch.Tensor):
                return super().interpolate_batch(x_eval, values=node_vals)
            else:
                raise TypeError(
                    f"Expected list/tuple of tensors or Tensor, got {type(x_eval)}"
                )

    def derivative(
        self,
        x_eval: Union[List[torch.Tensor], torch.Tensor],
        k: Tuple[int, ...],
        *,
        use_spectral: bool = False,
    ) -> torch.Tensor:  # type: ignore[override]
        """Compute mixed derivative of interpolant at arbitrary evaluation points.

        `k[i]` is the derivative order along dimension `i`.
        If use_mlp_for_derivatives is True, this will use the MLP's automatic differentiation.
        Otherwise, it will use the spectral interpolant's derivative machinery.
        """
        if self.use_mlp_for_derivatives or self.use_mlp_for_forward:
            if isinstance(x_eval, (list, tuple)):
                # Form meshgrid and flatten for autograd
                x_mesh = torch.meshgrid(*x_eval, indexing="ij")
                x_stack = torch.stack(x_mesh, dim=-1)  # (..., n_dim)
                orig_shape = x_stack.shape[:-1]
                # Detached copy: autograd needs a fresh leaf, and we must not
                # mutate the caller's tensors.
                x_flat = (
                    x_stack.reshape(-1, self.n_dim)
                    .clone()
                    .detach()
                    .requires_grad_(True)
                )
                y = self.mlp(x_flat)
                # Apply derivatives in sequence (one autograd pass per order).
                for dim, order in enumerate(k):
                    if order > 0:
                        for _ in range(order):
                            grad = torch.autograd.grad(
                                y.sum(), x_flat, create_graph=True, retain_graph=True
                            )[0]
                            y = grad[:, dim : dim + 1]
                y = y.squeeze(-1) if y.ndim > 1 and y.shape[-1] == 1 else y
                return y.reshape(*orig_shape)
            elif isinstance(x_eval, torch.Tensor):
                x_eval = x_eval.clone().detach().requires_grad_(True)
                y = self.mlp(x_eval)
                for dim, order in enumerate(k):
                    if order > 0:
                        for _ in range(order):
                            grad = torch.autograd.grad(
                                y.sum(), x_eval, create_graph=True, retain_graph=True
                            )[0]
                            y = grad[:, dim : dim + 1]
                return y.squeeze(-1) if y.ndim > 1 and y.shape[-1] == 1 else y
            else:
                raise TypeError(
                    f"Expected list/tuple of tensors or Tensor, got {type(x_eval)}"
                )
        else:
            # Use the spectral interpolant's derivative machinery
            return self._with_node_values(
                super().derivative, x_eval, k, use_spectral=use_spectral
            )

    def fd_derivative(
        self,
        x_eval: Union[List[torch.Tensor], torch.Tensor],
        dim: int,
        h: float,
        *,
        scheme: str = "central",
    ) -> torch.Tensor:  # type: ignore[override]
        """Compute FD derivative at arbitrary points.

        If use_mlp_for_derivatives is True, this will use the MLP's automatic
        differentiation (``h`` and ``scheme`` are then ignored — the result is
        an exact autograd derivative, not a finite difference).
        Otherwise, it will use the spectral interpolant's FD derivative machinery.
        """
        if self.use_mlp_for_derivatives:
            # Convert list of tensors to a single tensor if needed
            if isinstance(x_eval, (list, tuple)):
                x_eval = torch.stack(x_eval, dim=1)
            # BUGFIX: work on a detached copy instead of calling
            # ``x_eval.requires_grad_(True)`` in place — the in-place call
            # mutates the caller's tensor and raises on non-leaf tensors that
            # are already part of a graph. This mirrors ``derivative`` above.
            x_req = x_eval.clone().detach().requires_grad_(True)
            y = self.mlp(x_req)
            # Compute derivative in specified dimension
            grad = torch.autograd.grad(y.sum(), x_req, create_graph=True)[0]
            return grad[:, dim]
        else:
            # Use the spectral interpolant's FD derivative machinery
            return self._with_node_values(
                super().fd_derivative, x_eval, dim, h, scheme=scheme
            )

    ###########################################################################
    # Convenience helpers – e.g. direct loss against nodal predictions
    ###########################################################################
    @torch.no_grad()
    def nodal_values(self) -> torch.Tensor:
        """Return the current *detached* nodal field (helpful for e.g. monitoring)."""
        return self._compute_node_values().detach().clone()