From 54588ddf4cf06e747a6240f58dc81d72318a2046 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 23 Feb 2023 15:02:14 +0100 Subject: [PATCH] add temporary fix to treat data input --- pina/pinn.py | 37 ++++--------------------------------- 1 file changed, 4 insertions(+), 33 deletions(-) diff --git a/pina/pinn.py b/pina/pinn.py index 215b53d..0be477b 100644 --- a/pina/pinn.py +++ b/pina/pinn.py @@ -146,39 +146,6 @@ class PINN(object): return self - def _create_dataloader(self): - """Private method for creating dataloader - - :return: dataloader - :rtype: torch.utils.data.DataLoader - """ - if self.batch_size is None: - return [self.input_pts] - - def custom_collate(batch): - # extracting pts labels - _, pts = list(batch[0].items())[0] - labels = pts.labels - # calling default torch collate - collate_res = default_collate(batch) - # save collate result in dict - res = {} - for key, val in collate_res.items(): - val.labels = labels - res[key] = val - return res - - # creating dataset, list of dataset for each location - datasets = [MyDataSet(key, val) - for key, val in self.input_pts.items()] - # creating dataloader - dataloaders = [DataLoader(dataset=dat, - batch_size=self.batch_size, - collate_fn=custom_collate) - for dat in datasets] - - return dict(zip(self.input_pts.keys(), dataloaders)) - def span_pts(self, *args, **kwargs): """ >>> pinn.span_pts(n=10, mode='grid') @@ -227,6 +194,10 @@ class PINN(object): self.model.train() epoch = 0 + # Add all condition with `input_points` to dataloader + for condition in list(set(self.problem.conditions.keys()) - set(self.input_pts.keys())): + self.input_pts[condition] = self.problem.conditions[condition] + data_loader = self.data_set.dataloader header = []