From 088649e042fc2557b17c4f31080359351f6ee83c Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 11 May 2022 16:42:11 +0200 Subject: [PATCH] minor changes --- pina/model/deeponet.py | 78 ++++++++++++++++++++++++++++++++---------- pina/pinn.py | 61 ++++++++++++++++++++++++++------- pina/plotter.py | 3 +- pina/span.py | 2 +- 4 files changed, 112 insertions(+), 32 deletions(-) diff --git a/pina/model/deeponet.py b/pina/model/deeponet.py index 17a5df8..753ae17 100644 --- a/pina/model/deeponet.py +++ b/pina/model/deeponet.py @@ -19,7 +19,8 @@ class DeepONet(torch.nn.Module): `_ """ - def __init__(self, branch_net, trunk_net, output_variables, inner_size=10): + def __init__(self, branch_net, trunk_net, output_variables, inner_size=10, + features=None, features_net=None): """ :param torch.nn.Module branch_net: the neural network to use as branch model. It has to take as input a :class:`LabelTensor`. The number @@ -64,17 +65,48 @@ class DeepONet(torch.nn.Module): self.output_variables = output_variables self.output_dimension = len(output_variables) + trunk_out_dim = trunk_net.layers[-1].out_features + branch_out_dim = branch_net.layers[-1].out_features + + if trunk_out_dim != branch_out_dim: + raise ValueError('Branch and trunk networks have not the same ' + 'output dimension.') + self.trunk_net = trunk_net self.branch_net = branch_net - if isinstance(self.branch_net.output_variables, int) and isinstance(self.branch_net.output_variables, int): - if self.branch_net.output_dimension == self.trunk_net.output_dimension: - self.inner_size = self.branch_net.output_dimension - else: - raise ValueError('Branch and trunk networks have not the same output dimension.') - else: - warnings.warn("The output dimension of the branch and trunk networks has been imposed by default as 10 for each output variable. To set it change the output_variable of networks to an integer.") - self.inner_size = self.output_dimension*inner_size + if features: + # if len(features) != features_net.layers[0].in_features: + # raise ValueError('Incompatible features') + # if trunk_out_dim != features_net.layers[-1].out_features: + # raise ValueError('Incompatible features') + + self.features = features + # self.features_net = nn.Sequential( + # nn.Linear(len(features), 10), nn.Softplus(), + # # nn.Linear(10, 10), nn.Softplus(), + # nn.Linear(10, trunk_out_dim) + # ) + self.features_net = nn.Sequential( + nn.Linear(len(features), trunk_out_dim) + ) + + + + + self.reduction = nn.Linear(trunk_out_dim, self.output_dimension) + + # print(self.branch_net.output_variables) + # print(self.trunk_net.output_variables) + # if isinstance(self.branch_net.output_variables, int) and isinstance(self.branch_net.output_variables, int): + # if self.branch_net.output_dimension == self.trunk_net.output_dimension: + # self.inner_size = self.branch_net.output_dimension + # print('qui') + # else: + # raise ValueError('Branch and trunk networks have not the same output dimension.') + # else: + # warnings.warn("The output dimension of the branch and trunk networks has been imposed by default as 10 for each output variable. To set it change the output_variable of networks to an integer.") + # self.inner_size = self.output_dimension*inner_size @property def input_variables(self): @@ -89,17 +121,27 @@ class DeepONet(torch.nn.Module): :return: the output computed by the model. :rtype: LabelTensor """ + input_feature = [] + for feature in self.features: + #print(feature) + input_feature.append(feature(x)) + input_feature = torch.cat(input_feature, dim=1) + branch_output = self.branch_net( x.extract(self.branch_net.input_variables)) trunk_output = self.trunk_net( x.extract(self.trunk_net.input_variables)) - local_size = int(self.inner_size/self.output_dimension) - for i, var in enumerate(self.output_variables): - start = i*local_size - stop = (i+1)*local_size - local_output = LabelTensor(torch.sum(branch_output[:, start:stop] * trunk_output[:, start:stop], dim=1).reshape(-1, 1), var) - if i==0: - output_ = local_output - else: - output_ = output_.append(local_output) + feat_output = self.features_net(input_feature) + output_ = self.reduction(branch_output * trunk_output * feat_output) + output_ = self.reduction(trunk_output * feat_output) + output_ = LabelTensor(output_, self.output_variables) + # local_size = int(trunk_output.shape[1]/self.output_dimension) + # for i, var in enumerate(self.output_variables): + # start = i*local_size + # stop = (i+1)*local_size + # local_output = LabelTensor(torch.sum(branch_output[:, start:stop] * trunk_output[:, start:stop], dim=1).reshape(-1, 1), var) + # if i==0: + # output_ = local_output + # else: + # output_ = output_.append(local_output) return output_ diff --git a/pina/pinn.py b/pina/pinn.py index 3263792..1f0f3b9 100644 --- a/pina/pinn.py +++ b/pina/pinn.py @@ -120,21 +120,53 @@ class PINN(object): return self - def span_pts(self, n_spatial, n_params=0, mode_spatial='grid', mode_param='random', locations='all'): + def span_pts(self, *args, **kwargs): + """ + >>> pinn.span_pts(n=10, mode='grid') + >>> pinn.span_pts(n=10, mode='grid', variables=['x']) + """ + + def merge_tensors(tensors): # name to be changed + if len(tensors) == 2: + tensor1 = tensors[0] + tensor2 = tensors[1] + n1 = tensor1.shape[0] + n2 = tensor2.shape[0] + + tensor1 = LabelTensor(tensor1.repeat(n2, 1), labels=tensor1.labels) + tensor2 = LabelTensor( + tensor2.repeat_interleave(n1, dim=0), labels=tensor2.labels) + return tensor1.append(tensor2) + else: + pass + + if isinstance(args[0], int) and isinstance(args[1], str): + pass + variables = self.problem.input_variables + elif all(isinstance(arg, dict) for arg in args): + print(args) + arguments = args + pass + elif all(key in kwargs for key in ['n', 'mode']): + variables = self.problem.input_variables + pass + else: + raise RuntimeError + + locations = kwargs.get('locations', 'all') + if locations == 'all': locations = [condition for condition in self.problem.conditions] for location in locations: condition = self.problem.conditions[location] - try: - pts = condition.location.sample(n_spatial, mode_spatial, variables=self.problem.spatial_variables) - if n_params != 0: - pts_params = condition.location.sample(n_params, mode_param, variables=self.problem.parameters) - pts = LabelTensor(pts.repeat(n_params, 1), pts.labels) - pts_params = LabelTensor(pts_params.repeat_interleave(n_spatial).reshape((n_spatial*n_params, len(self.problem.parameters))), pts_params.labels) - pts = pts.append(pts_params) - except: - pts = condition.input_points + pts = merge_tensors([ + condition.location.sample( + argument['n'], + argument['mode'], + variables=argument['variables']) + for argument in arguments]) + self.input_pts[location] = pts #.double() # TODO self.input_pts[location] = ( self.input_pts[location].to(dtype=self.dtype, @@ -168,9 +200,10 @@ class PINN(object): for condition_name in self.problem.conditions: condition = self.problem.conditions[condition_name] - pts = self.input_pts[condition_name] - predicted = self.model(pts) + if hasattr(condition, 'function'): + pts = self.input_pts[condition_name] + predicted = self.model(pts) if isinstance(condition.function, list): for function in condition.function: residuals = function(pts, predicted) @@ -181,6 +214,10 @@ class PINN(object): local_loss = condition.data_weight*self._compute_norm(residuals) losses.append(local_loss) elif hasattr(condition, 'output_points'): + pts = condition.input_points + # print(pts) + predicted = self.model(pts) + # print(predicted) residuals = predicted - condition.output_points local_loss = condition.data_weight*self._compute_norm(residuals) losses.append(local_loss) diff --git a/pina/plotter.py b/pina/plotter.py index 6f73af7..b87988d 100644 --- a/pina/plotter.py +++ b/pina/plotter.py @@ -106,6 +106,7 @@ class Plotter: ind_dict[location] = ind_to_exclude import functools from functools import reduce + final_inds = reduce(np.intersect1d, ind_dict.values()) predicted_output = obj.model(pts) predicted_output = predicted_output.extract([component]) @@ -122,7 +123,7 @@ class Plotter: fig.colorbar(cb, ax=axes[2]) else: fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(8, 6)) - cb = getattr(axes, method)(*grids_container, predicted_output.reshape(res, res).detach()) + cb = getattr(axes, method)(*grids_container, predicted_output.reshape(res, res).detach(), levels=32) fig.colorbar(cb, ax=axes) if filename: diff --git a/pina/span.py b/pina/span.py index c72f197..fd1ca12 100644 --- a/pina/span.py +++ b/pina/span.py @@ -71,7 +71,7 @@ class Span(Location): if not len(spatial_fixed_)==0: pts_fixed_ = torch.ones(pts.shape[0], len(spatial_fixed_), dtype=pts.dtype) * fixed - + pts_fixed_ = pts_fixed_.float() pts_fixed_ = LabelTensor(pts_fixed_, spatial_fixed_) pts_range_ = pts_range_.append(pts_fixed_)