diff --git a/pina/pinn.py b/pina/pinn.py index 860c464..1ca2636 100644 --- a/pina/pinn.py +++ b/pina/pinn.py @@ -3,6 +3,7 @@ import torch from .problem import AbstractProblem from .label_tensor import LabelTensor +from .utils import merge_tensors torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732 @@ -126,27 +127,6 @@ class PINN(object): >>> pinn.span_pts(n=10, mode='grid', location=['bound1']) >>> pinn.span_pts(n=10, mode='grid', variables=['x']) """ - - def merge_tensors(tensors): # name to be changed - if len(tensors) == 2: - tensor1 = tensors[0] - tensor2 = tensors[1] - n1 = tensor1.shape[0] - n2 = tensor2.shape[0] - - tensor1 = LabelTensor( - tensor1.repeat(n2, 1), - labels=tensor1.labels) - tensor2 = LabelTensor( - tensor2.repeat_interleave(n1, dim=0), - labels=tensor2.labels) - return tensor1.append(tensor2) - elif len(tensors) == 1: - return tensors[0] - else: - recursive_result = merge_tensors(tensors[1:]) - return merge_tensors([tensors[0], recursive_result]) - if isinstance(args[0], int) and isinstance(args[1], str): argument = {} argument['n'] = int(args[0]) @@ -171,19 +151,20 @@ class PINN(object): for location in locations: condition = self.problem.conditions[location] - pts = merge_tensors([ - condition.location.sample( - argument['n'], - argument['mode'], - variables=argument['variables']) - for argument in arguments]) + samples = tuple(condition.location.sample( + argument['n'], + argument['mode'], + variables=argument['variables']) + for argument in arguments) + pts = merge_tensors(samples) - self.input_pts[location] = pts #.double() # TODO - self.input_pts[location] = ( - self.input_pts[location].to(dtype=self.dtype, - device=self.device)) - self.input_pts[location].requires_grad_(True) - self.input_pts[location].retain_grad() + # TODO + # pts = pts.double() + pts = pts.to(dtype=self.dtype, device=self.device) + pts.requires_grad_(True) + pts.retain_grad() + + self.input_pts[location] = pts def train(self, stop=100, frequency_print=2, save_loss=1, trial=None): @@ -193,17 +174,14 @@ class PINN(object): for condition_name in self.problem.conditions: condition = self.problem.conditions[condition_name] - if hasattr(condition, 'function'): - if isinstance(condition.function, list): - for function in condition.function: - header.append(f'{condition_name}{function.__name__}') - - continue - - header.append(f'{condition_name}') + if (hasattr(condition, 'function') and + isinstance(condition.function, list)): + for function in condition.function: + header.append(f'{condition_name}{function.__name__}') + else: + header.append(f'{condition_name}') while True: - losses = [] for condition_name in self.problem.conditions: diff --git a/pina/utils.py b/pina/utils.py index 769ee97..ed18af0 100644 --- a/pina/utils.py +++ b/pina/utils.py @@ -1,9 +1,11 @@ """Utils module""" +from functools import reduce + def number_parameters(model, aggregate=True, only_trainable=True): #TODO: check """ Return the number of parameters of a given `model`. - + :param torch.nn.Module model: the torch module to inspect. :param bool aggregate: if True the return values is an integer corresponding to the total amount of parameters of whole model. If False, it returns a @@ -25,3 +27,19 @@ def number_parameters(model, aggregate=True, only_trainable=True): #TODO: check tmp = sum(tmp.values()) return tmp + + +def merge_tensors(tensors): # name to be changed + if tensors: + return reduce(merge_two_tensors, tensors[1:], tensors[0]) + raise ValueError("Expected at least one tensor") + + +def merge_two_tensors(tensor1, tensor2): + n1 = tensor1.shape[0] + n2 = tensor2.shape[0] + + tensor1 = LabelTensor(tensor1.repeat(n2, 1), labels=tensor1.labels) + tensor2 = LabelTensor(tensor2.repeat_interleave(n1, dim=0), + labels=tensor2.labels) + return tensor1.append(tensor2)