minor changes
This commit is contained in:
@@ -19,7 +19,8 @@ class DeepONet(torch.nn.Module):
|
|||||||
<https://doi.org/10.1038/s42256-021-00302-5>`_
|
<https://doi.org/10.1038/s42256-021-00302-5>`_
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def __init__(self, branch_net, trunk_net, output_variables, inner_size=10):
|
def __init__(self, branch_net, trunk_net, output_variables, inner_size=10,
|
||||||
|
features=None, features_net=None):
|
||||||
"""
|
"""
|
||||||
:param torch.nn.Module branch_net: the neural network to use as branch
|
:param torch.nn.Module branch_net: the neural network to use as branch
|
||||||
model. It has to take as input a :class:`LabelTensor`. The number
|
model. It has to take as input a :class:`LabelTensor`. The number
|
||||||
@@ -64,17 +65,48 @@ class DeepONet(torch.nn.Module):
|
|||||||
self.output_variables = output_variables
|
self.output_variables = output_variables
|
||||||
self.output_dimension = len(output_variables)
|
self.output_dimension = len(output_variables)
|
||||||
|
|
||||||
|
trunk_out_dim = trunk_net.layers[-1].out_features
|
||||||
|
branch_out_dim = branch_net.layers[-1].out_features
|
||||||
|
|
||||||
|
if trunk_out_dim != branch_out_dim:
|
||||||
|
raise ValueError('Branch and trunk networks have not the same '
|
||||||
|
'output dimension.')
|
||||||
|
|
||||||
self.trunk_net = trunk_net
|
self.trunk_net = trunk_net
|
||||||
self.branch_net = branch_net
|
self.branch_net = branch_net
|
||||||
|
|
||||||
if isinstance(self.branch_net.output_variables, int) and isinstance(self.branch_net.output_variables, int):
|
if features:
|
||||||
if self.branch_net.output_dimension == self.trunk_net.output_dimension:
|
# if len(features) != features_net.layers[0].in_features:
|
||||||
self.inner_size = self.branch_net.output_dimension
|
# raise ValueError('Incompatible features')
|
||||||
else:
|
# if trunk_out_dim != features_net.layers[-1].out_features:
|
||||||
raise ValueError('Branch and trunk networks have not the same output dimension.')
|
# raise ValueError('Incompatible features')
|
||||||
else:
|
|
||||||
warnings.warn("The output dimension of the branch and trunk networks has been imposed by default as 10 for each output variable. To set it change the output_variable of networks to an integer.")
|
self.features = features
|
||||||
self.inner_size = self.output_dimension*inner_size
|
# self.features_net = nn.Sequential(
|
||||||
|
# nn.Linear(len(features), 10), nn.Softplus(),
|
||||||
|
# # nn.Linear(10, 10), nn.Softplus(),
|
||||||
|
# nn.Linear(10, trunk_out_dim)
|
||||||
|
# )
|
||||||
|
self.features_net = nn.Sequential(
|
||||||
|
nn.Linear(len(features), trunk_out_dim)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
self.reduction = nn.Linear(trunk_out_dim, self.output_dimension)
|
||||||
|
|
||||||
|
# print(self.branch_net.output_variables)
|
||||||
|
# print(self.trunk_net.output_variables)
|
||||||
|
# if isinstance(self.branch_net.output_variables, int) and isinstance(self.branch_net.output_variables, int):
|
||||||
|
# if self.branch_net.output_dimension == self.trunk_net.output_dimension:
|
||||||
|
# self.inner_size = self.branch_net.output_dimension
|
||||||
|
# print('qui')
|
||||||
|
# else:
|
||||||
|
# raise ValueError('Branch and trunk networks have not the same output dimension.')
|
||||||
|
# else:
|
||||||
|
# warnings.warn("The output dimension of the branch and trunk networks has been imposed by default as 10 for each output variable. To set it change the output_variable of networks to an integer.")
|
||||||
|
# self.inner_size = self.output_dimension*inner_size
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def input_variables(self):
|
def input_variables(self):
|
||||||
@@ -89,17 +121,27 @@ class DeepONet(torch.nn.Module):
|
|||||||
:return: the output computed by the model.
|
:return: the output computed by the model.
|
||||||
:rtype: LabelTensor
|
:rtype: LabelTensor
|
||||||
"""
|
"""
|
||||||
|
input_feature = []
|
||||||
|
for feature in self.features:
|
||||||
|
#print(feature)
|
||||||
|
input_feature.append(feature(x))
|
||||||
|
input_feature = torch.cat(input_feature, dim=1)
|
||||||
|
|
||||||
branch_output = self.branch_net(
|
branch_output = self.branch_net(
|
||||||
x.extract(self.branch_net.input_variables))
|
x.extract(self.branch_net.input_variables))
|
||||||
trunk_output = self.trunk_net(
|
trunk_output = self.trunk_net(
|
||||||
x.extract(self.trunk_net.input_variables))
|
x.extract(self.trunk_net.input_variables))
|
||||||
local_size = int(self.inner_size/self.output_dimension)
|
feat_output = self.features_net(input_feature)
|
||||||
for i, var in enumerate(self.output_variables):
|
output_ = self.reduction(branch_output * trunk_output * feat_output)
|
||||||
start = i*local_size
|
output_ = self.reduction(trunk_output * feat_output)
|
||||||
stop = (i+1)*local_size
|
output_ = LabelTensor(output_, self.output_variables)
|
||||||
local_output = LabelTensor(torch.sum(branch_output[:, start:stop] * trunk_output[:, start:stop], dim=1).reshape(-1, 1), var)
|
# local_size = int(trunk_output.shape[1]/self.output_dimension)
|
||||||
if i==0:
|
# for i, var in enumerate(self.output_variables):
|
||||||
output_ = local_output
|
# start = i*local_size
|
||||||
else:
|
# stop = (i+1)*local_size
|
||||||
output_ = output_.append(local_output)
|
# local_output = LabelTensor(torch.sum(branch_output[:, start:stop] * trunk_output[:, start:stop], dim=1).reshape(-1, 1), var)
|
||||||
|
# if i==0:
|
||||||
|
# output_ = local_output
|
||||||
|
# else:
|
||||||
|
# output_ = output_.append(local_output)
|
||||||
return output_
|
return output_
|
||||||
|
|||||||
61
pina/pinn.py
61
pina/pinn.py
@@ -120,21 +120,53 @@ class PINN(object):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
def span_pts(self, n_spatial, n_params=0, mode_spatial='grid', mode_param='random', locations='all'):
|
def span_pts(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
>>> pinn.span_pts(n=10, mode='grid')
|
||||||
|
>>> pinn.span_pts(n=10, mode='grid', variables=['x'])
|
||||||
|
"""
|
||||||
|
|
||||||
|
def merge_tensors(tensors): # name to be changed
|
||||||
|
if len(tensors) == 2:
|
||||||
|
tensor1 = tensors[0]
|
||||||
|
tensor2 = tensors[1]
|
||||||
|
n1 = tensor1.shape[0]
|
||||||
|
n2 = tensor2.shape[0]
|
||||||
|
|
||||||
|
tensor1 = LabelTensor(tensor1.repeat(n2, 1), labels=tensor1.labels)
|
||||||
|
tensor2 = LabelTensor(
|
||||||
|
tensor2.repeat_interleave(n1, dim=0), labels=tensor2.labels)
|
||||||
|
return tensor1.append(tensor2)
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if isinstance(args[0], int) and isinstance(args[1], str):
|
||||||
|
pass
|
||||||
|
variables = self.problem.input_variables
|
||||||
|
elif all(isinstance(arg, dict) for arg in args):
|
||||||
|
print(args)
|
||||||
|
arguments = args
|
||||||
|
pass
|
||||||
|
elif all(key in kwargs for key in ['n', 'mode']):
|
||||||
|
variables = self.problem.input_variables
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise RuntimeError
|
||||||
|
|
||||||
|
locations = kwargs.get('locations', 'all')
|
||||||
|
|
||||||
if locations == 'all':
|
if locations == 'all':
|
||||||
locations = [condition for condition in self.problem.conditions]
|
locations = [condition for condition in self.problem.conditions]
|
||||||
for location in locations:
|
for location in locations:
|
||||||
condition = self.problem.conditions[location]
|
condition = self.problem.conditions[location]
|
||||||
|
|
||||||
try:
|
pts = merge_tensors([
|
||||||
pts = condition.location.sample(n_spatial, mode_spatial, variables=self.problem.spatial_variables)
|
condition.location.sample(
|
||||||
if n_params != 0:
|
argument['n'],
|
||||||
pts_params = condition.location.sample(n_params, mode_param, variables=self.problem.parameters)
|
argument['mode'],
|
||||||
pts = LabelTensor(pts.repeat(n_params, 1), pts.labels)
|
variables=argument['variables'])
|
||||||
pts_params = LabelTensor(pts_params.repeat_interleave(n_spatial).reshape((n_spatial*n_params, len(self.problem.parameters))), pts_params.labels)
|
for argument in arguments])
|
||||||
pts = pts.append(pts_params)
|
|
||||||
except:
|
|
||||||
pts = condition.input_points
|
|
||||||
self.input_pts[location] = pts #.double() # TODO
|
self.input_pts[location] = pts #.double() # TODO
|
||||||
self.input_pts[location] = (
|
self.input_pts[location] = (
|
||||||
self.input_pts[location].to(dtype=self.dtype,
|
self.input_pts[location].to(dtype=self.dtype,
|
||||||
@@ -168,9 +200,10 @@ class PINN(object):
|
|||||||
|
|
||||||
for condition_name in self.problem.conditions:
|
for condition_name in self.problem.conditions:
|
||||||
condition = self.problem.conditions[condition_name]
|
condition = self.problem.conditions[condition_name]
|
||||||
pts = self.input_pts[condition_name]
|
|
||||||
predicted = self.model(pts)
|
|
||||||
if hasattr(condition, 'function'):
|
if hasattr(condition, 'function'):
|
||||||
|
pts = self.input_pts[condition_name]
|
||||||
|
predicted = self.model(pts)
|
||||||
if isinstance(condition.function, list):
|
if isinstance(condition.function, list):
|
||||||
for function in condition.function:
|
for function in condition.function:
|
||||||
residuals = function(pts, predicted)
|
residuals = function(pts, predicted)
|
||||||
@@ -181,6 +214,10 @@ class PINN(object):
|
|||||||
local_loss = condition.data_weight*self._compute_norm(residuals)
|
local_loss = condition.data_weight*self._compute_norm(residuals)
|
||||||
losses.append(local_loss)
|
losses.append(local_loss)
|
||||||
elif hasattr(condition, 'output_points'):
|
elif hasattr(condition, 'output_points'):
|
||||||
|
pts = condition.input_points
|
||||||
|
# print(pts)
|
||||||
|
predicted = self.model(pts)
|
||||||
|
# print(predicted)
|
||||||
residuals = predicted - condition.output_points
|
residuals = predicted - condition.output_points
|
||||||
local_loss = condition.data_weight*self._compute_norm(residuals)
|
local_loss = condition.data_weight*self._compute_norm(residuals)
|
||||||
losses.append(local_loss)
|
losses.append(local_loss)
|
||||||
|
|||||||
@@ -106,6 +106,7 @@ class Plotter:
|
|||||||
ind_dict[location] = ind_to_exclude
|
ind_dict[location] = ind_to_exclude
|
||||||
import functools
|
import functools
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
|
|
||||||
final_inds = reduce(np.intersect1d, ind_dict.values())
|
final_inds = reduce(np.intersect1d, ind_dict.values())
|
||||||
predicted_output = obj.model(pts)
|
predicted_output = obj.model(pts)
|
||||||
predicted_output = predicted_output.extract([component])
|
predicted_output = predicted_output.extract([component])
|
||||||
@@ -122,7 +123,7 @@ class Plotter:
|
|||||||
fig.colorbar(cb, ax=axes[2])
|
fig.colorbar(cb, ax=axes[2])
|
||||||
else:
|
else:
|
||||||
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
|
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
|
||||||
cb = getattr(axes, method)(*grids_container, predicted_output.reshape(res, res).detach())
|
cb = getattr(axes, method)(*grids_container, predicted_output.reshape(res, res).detach(), levels=32)
|
||||||
fig.colorbar(cb, ax=axes)
|
fig.colorbar(cb, ax=axes)
|
||||||
|
|
||||||
if filename:
|
if filename:
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ class Span(Location):
|
|||||||
if not len(spatial_fixed_)==0:
|
if not len(spatial_fixed_)==0:
|
||||||
pts_fixed_ = torch.ones(pts.shape[0], len(spatial_fixed_),
|
pts_fixed_ = torch.ones(pts.shape[0], len(spatial_fixed_),
|
||||||
dtype=pts.dtype) * fixed
|
dtype=pts.dtype) * fixed
|
||||||
|
pts_fixed_ = pts_fixed_.float()
|
||||||
pts_fixed_ = LabelTensor(pts_fixed_, spatial_fixed_)
|
pts_fixed_ = LabelTensor(pts_fixed_, spatial_fixed_)
|
||||||
pts_range_ = pts_range_.append(pts_fixed_)
|
pts_range_ = pts_range_.append(pts_fixed_)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user