minor changes
@@ -19,7 +19,8 @@ class DeepONet(torch.nn.Module):
    <https://doi.org/10.1038/s42256-021-00302-5>`_

    """

    def __init__(self, branch_net, trunk_net, output_variables, inner_size=10):
    def __init__(self, branch_net, trunk_net, output_variables, inner_size=10,
                 features=None, features_net=None):
        """
        :param torch.nn.Module branch_net: the neural network to use as branch
            model. It has to take as input a :class:`LabelTensor`. The number
@@ -64,17 +65,48 @@ class DeepONet(torch.nn.Module):
        self.output_variables = output_variables
        self.output_dimension = len(output_variables)

        trunk_out_dim = trunk_net.layers[-1].out_features
        branch_out_dim = branch_net.layers[-1].out_features

        if trunk_out_dim != branch_out_dim:
            raise ValueError('Branch and trunk networks do not have the same '
                             'output dimension.')

        self.trunk_net = trunk_net
        self.branch_net = branch_net

        if isinstance(self.branch_net.output_variables, int) and isinstance(self.trunk_net.output_variables, int):
            if self.branch_net.output_dimension == self.trunk_net.output_dimension:
                self.inner_size = self.branch_net.output_dimension
            else:
                raise ValueError('Branch and trunk networks do not have the same output dimension.')
        else:
            warnings.warn("The output dimension of the branch and trunk networks has been set by default to 10 for each output variable. To change it, set the output_variables of the networks to an integer.")
            self.inner_size = self.output_dimension*inner_size

        if features:
            # if len(features) != features_net.layers[0].in_features:
            #     raise ValueError('Incompatible features')
            # if trunk_out_dim != features_net.layers[-1].out_features:
            #     raise ValueError('Incompatible features')

            self.features = features
            # self.features_net = nn.Sequential(
            #     nn.Linear(len(features), 10), nn.Softplus(),
            #     # nn.Linear(10, 10), nn.Softplus(),
            #     nn.Linear(10, trunk_out_dim)
            # )
            self.features_net = nn.Sequential(
                nn.Linear(len(features), trunk_out_dim)
            )

        self.reduction = nn.Linear(trunk_out_dim, self.output_dimension)

        # print(self.branch_net.output_variables)
        # print(self.trunk_net.output_variables)
        # if isinstance(self.branch_net.output_variables, int) and isinstance(self.trunk_net.output_variables, int):
        #     if self.branch_net.output_dimension == self.trunk_net.output_dimension:
        #         self.inner_size = self.branch_net.output_dimension
        #         print('here')
        #     else:
        #         raise ValueError('Branch and trunk networks do not have the same output dimension.')
        # else:
        #     warnings.warn("The output dimension of the branch and trunk networks has been set by default to 10 for each output variable. To change it, set the output_variables of the networks to an integer.")
        #     self.inner_size = self.output_dimension*inner_size

    @property
    def input_variables(self):
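For reference, a minimal plain-PyTorch sketch of the dimension check and of the default features_net built above. The `nn.Sequential` stand-ins and the feature function are assumptions for illustration only; PINA's networks are LabelTensor-aware and expose a `.layers` attribute, so the last-layer lookup is written here as `net[-1]` instead of `net.layers[-1]`.

import torch
import torch.nn as nn

# stand-ins for the branch and trunk networks (made-up sizes)
branch_net = nn.Sequential(nn.Linear(3, 20), nn.Tanh(), nn.Linear(20, 10))
trunk_net = nn.Sequential(nn.Linear(2, 20), nn.Tanh(), nn.Linear(20, 10))

trunk_out_dim = trunk_net[-1].out_features
branch_out_dim = branch_net[-1].out_features
if trunk_out_dim != branch_out_dim:
    raise ValueError('Branch and trunk networks do not have the same output dimension.')

# one hypothetical extra feature function, mapped onto the trunk output size
features = [lambda x: torch.sin(x[:, 0:1])]
features_net = nn.Sequential(nn.Linear(len(features), trunk_out_dim))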
@@ -89,17 +121,27 @@ class DeepONet(torch.nn.Module):
        :return: the output computed by the model.
        :rtype: LabelTensor
        """
        input_feature = []
        for feature in self.features:
            # print(feature)
            input_feature.append(feature(x))
        input_feature = torch.cat(input_feature, dim=1)

        branch_output = self.branch_net(
            x.extract(self.branch_net.input_variables))
        trunk_output = self.trunk_net(
            x.extract(self.trunk_net.input_variables))
        local_size = int(self.inner_size/self.output_dimension)
        for i, var in enumerate(self.output_variables):
            start = i*local_size
            stop = (i+1)*local_size
            local_output = LabelTensor(torch.sum(branch_output[:, start:stop] * trunk_output[:, start:stop], dim=1).reshape(-1, 1), var)
            if i == 0:
                output_ = local_output
            else:
                output_ = output_.append(local_output)
        feat_output = self.features_net(input_feature)
        output_ = self.reduction(branch_output * trunk_output * feat_output)
        output_ = self.reduction(trunk_output * feat_output)
        output_ = LabelTensor(output_, self.output_variables)
        # local_size = int(trunk_output.shape[1]/self.output_dimension)
        # for i, var in enumerate(self.output_variables):
        #     start = i*local_size
        #     stop = (i+1)*local_size
        #     local_output = LabelTensor(torch.sum(branch_output[:, start:stop] * trunk_output[:, start:stop], dim=1).reshape(-1, 1), var)
        #     if i == 0:
        #         output_ = local_output
        #     else:
        #         output_ = output_.append(local_output)
        return output_
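The combination step above mixes two strategies: a per-variable dot product between matching slices of the branch and trunk outputs, and a feature-augmented path where the elementwise product is passed through the linear reduction. A shape-only sketch in plain torch, with made-up sizes and the LabelTensor bookkeeping omitted:

import torch

batch, inner, n_out = 4, 10, 2            # illustrative sizes only
branch_output = torch.rand(batch, inner)
trunk_output = torch.rand(batch, inner)
feat_output = torch.rand(batch, inner)    # features_net applied to the extra features

# feature-augmented path: elementwise product, then a linear map to the outputs
reduction = torch.nn.Linear(inner, n_out)
output_ = reduction(trunk_output * feat_output)           # shape (batch, n_out)

# per-variable dot-product path: split the inner size evenly across outputs
local_size = inner // n_out
columns = [
    torch.sum(branch_output[:, i*local_size:(i+1)*local_size]
              * trunk_output[:, i*local_size:(i+1)*local_size],
              dim=1, keepdim=True)
    for i in range(n_out)
]
output_loop = torch.cat(columns, dim=1)                   # shape (batch, n_out)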
pina/pinn.py
@@ -120,21 +120,53 @@ class PINN(object):
        return self

    def span_pts(self, n_spatial, n_params=0, mode_spatial='grid', mode_param='random', locations='all'):
    def span_pts(self, *args, **kwargs):
        """
        >>> pinn.span_pts(n=10, mode='grid')
        >>> pinn.span_pts(n=10, mode='grid', variables=['x'])
        """

        def merge_tensors(tensors):  # name to be changed
            if len(tensors) == 2:
                tensor1 = tensors[0]
                tensor2 = tensors[1]
                n1 = tensor1.shape[0]
                n2 = tensor2.shape[0]

                tensor1 = LabelTensor(tensor1.repeat(n2, 1), labels=tensor1.labels)
                tensor2 = LabelTensor(
                    tensor2.repeat_interleave(n1, dim=0), labels=tensor2.labels)
                return tensor1.append(tensor2)
            else:
                pass
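merge_tensors pairs every row of the first sampled tensor with every row of the second, i.e. a Cartesian product of the two point sets. A small stand-alone check of the repeat/repeat_interleave pattern, using plain tensors instead of LabelTensor and made-up sample values:

import torch

# toy stand-ins: 3 samples of a spatial variable and 2 of a parameter
tensor1 = torch.tensor([[0.0], [0.5], [1.0]])   # would carry labels ['x']
tensor2 = torch.tensor([[1.0], [2.0]])          # would carry labels ['mu']
n1, n2 = tensor1.shape[0], tensor2.shape[0]

t1 = tensor1.repeat(n2, 1)                 # tensor1 cycles fast: x0, x1, x2, x0, x1, x2
t2 = tensor2.repeat_interleave(n1, dim=0)  # tensor2 changes slowly: mu0 x3, mu1 x3
cartesian = torch.cat([t1, t2], dim=1)     # all 6 (x, mu) combinations
print(cartesian.shape)                     # torch.Size([6, 2])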
        if isinstance(args[0], int) and isinstance(args[1], str):
            pass
            variables = self.problem.input_variables
        elif all(isinstance(arg, dict) for arg in args):
            print(args)
            arguments = args
            pass
        elif all(key in kwargs for key in ['n', 'mode']):
            variables = self.problem.input_variables
            pass
        else:
            raise RuntimeError

        locations = kwargs.get('locations', 'all')

        if locations == 'all':
            locations = [condition for condition in self.problem.conditions]
        for location in locations:
            condition = self.problem.conditions[location]

            try:
                pts = condition.location.sample(n_spatial, mode_spatial, variables=self.problem.spatial_variables)
                if n_params != 0:
                    pts_params = condition.location.sample(n_params, mode_param, variables=self.problem.parameters)
                    pts = LabelTensor(pts.repeat(n_params, 1), pts.labels)
                    pts_params = LabelTensor(pts_params.repeat_interleave(n_spatial).reshape((n_spatial*n_params, len(self.problem.parameters))), pts_params.labels)
                    pts = pts.append(pts_params)
            except:
                pts = condition.input_points
            pts = merge_tensors([
                condition.location.sample(
                    argument['n'],
                    argument['mode'],
                    variables=argument['variables'])
                for argument in arguments])

            self.input_pts[location] = pts #.double() # TODO
            self.input_pts[location] = (
                self.input_pts[location].to(dtype=self.dtype,
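Read off the dispatcher above, the new signature appears to accept either the keyword form shown in the docstring or one dict per group of variables with 'n', 'mode' and 'variables' keys. The exact accepted combinations are an assumption based on this hunk only; `pinn` and the location name below are hypothetical:

# keyword form, as in the docstring
pinn.span_pts(n=10, mode='grid')

# one dict per variable group, matching the `dict` branch of the dispatcher
pinn.span_pts(
    {'n': 20, 'mode': 'grid', 'variables': ['x', 'y']},
    {'n': 5, 'mode': 'random', 'variables': ['mu']},
    locations=['gamma_top'],   # made-up condition name
)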
@@ -168,9 +200,10 @@ class PINN(object):
        for condition_name in self.problem.conditions:
            condition = self.problem.conditions[condition_name]
            pts = self.input_pts[condition_name]
            predicted = self.model(pts)

            if hasattr(condition, 'function'):
                pts = self.input_pts[condition_name]
                predicted = self.model(pts)
                if isinstance(condition.function, list):
                    for function in condition.function:
                        residuals = function(pts, predicted)
@@ -181,6 +214,10 @@ class PINN(object):
                        local_loss = condition.data_weight*self._compute_norm(residuals)
                        losses.append(local_loss)
            elif hasattr(condition, 'output_points'):
                pts = condition.input_points
                # print(pts)
                predicted = self.model(pts)
                # print(predicted)
                residuals = predicted - condition.output_points
                local_loss = condition.data_weight*self._compute_norm(residuals)
                losses.append(local_loss)
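The new elif adds a plain data-fit term: the residual is the gap between the model prediction on the condition's input points and its supervised output points, scaled by the condition's data_weight. A self-contained sketch of that term, where the mean-squared norm is only an assumption for what _compute_norm does:

import torch

predicted = torch.rand(8, 1)       # model output on the condition's input points (made up)
output_points = torch.rand(8, 1)   # supervised targets attached to the condition (made up)
data_weight = 1.0

residuals = predicted - output_points
local_loss = data_weight * torch.mean(residuals ** 2)  # assumed norm; PINN delegates to _compute_norm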
@@ -106,6 +106,7 @@ class Plotter:
            ind_dict[location] = ind_to_exclude
        import functools
        from functools import reduce

        final_inds = reduce(np.intersect1d, ind_dict.values())
        predicted_output = obj.model(pts)
        predicted_output = predicted_output.extract([component])
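The reduce call keeps only the indices shared by every location's exclusion set. A quick stand-alone illustration with made-up index arrays and location names:

import numpy as np
from functools import reduce

# hypothetical indices to exclude, collected per plot location
ind_dict = {
    'gamma_top': np.array([0, 2, 5, 7]),
    'gamma_bottom': np.array([2, 3, 5, 9]),
}
final_inds = reduce(np.intersect1d, ind_dict.values())
print(final_inds)  # [2 5] -- only the indices excluded everywhere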
@@ -122,7 +123,7 @@ class Plotter:
            fig.colorbar(cb, ax=axes[2])
        else:
            fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
            cb = getattr(axes, method)(*grids_container, predicted_output.reshape(res, res).detach())
            cb = getattr(axes, method)(*grids_container, predicted_output.reshape(res, res).detach(), levels=32)
            fig.colorbar(cb, ax=axes)

        if filename:
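The only change in this hunk is the extra levels=32 argument, which fixes the number of filled contour levels instead of leaving the choice to matplotlib. A minimal stand-alone version with synthetic data in place of the model output:

import numpy as np
import matplotlib.pyplot as plt

res = 64
x, y = np.meshgrid(np.linspace(0, 1, res), np.linspace(0, 1, res))
z = np.sin(2 * np.pi * x) * np.cos(2 * np.pi * y)   # stand-in for predicted_output

fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
cb = axes.contourf(x, y, z, levels=32)  # 32 filled levels rather than the default
fig.colorbar(cb, ax=axes)
plt.show()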
@@ -71,7 +71,7 @@ class Span(Location):
        if not len(spatial_fixed_)==0:
            pts_fixed_ = torch.ones(pts.shape[0], len(spatial_fixed_),
                                    dtype=pts.dtype) * fixed

            pts_fixed_ = pts_fixed_.float()
            pts_fixed_ = LabelTensor(pts_fixed_, spatial_fixed_)
            pts_range_ = pts_range_.append(pts_fixed_)
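The Span hunk pads the sampled points with one constant column per fixed variable, matching the sample count and dtype of the ranged points before wrapping the result in a LabelTensor. A small sketch of that padding with made-up names and sizes:

import torch

pts = torch.rand(5, 2)        # e.g. 5 samples of the ranged variables ('x', 'y')
fixed = 3.0                   # value imposed on a fixed variable, e.g. 't'
spatial_fixed_ = ['t']

pts_fixed_ = torch.ones(pts.shape[0], len(spatial_fixed_), dtype=pts.dtype) * fixed
pts_fixed_ = pts_fixed_.float()
print(pts_fixed_.shape)       # torch.Size([5, 1]) -- one constant column to append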