add temporary fix to treat data input

This commit is contained in:
Your Name
2023-02-23 15:02:14 +01:00
committed by Nicola Demo
parent 7ce080fd85
commit 54588ddf4c

View File

@@ -146,39 +146,6 @@ class PINN(object):
return self
def _create_dataloader(self):
"""Private method for creating dataloader
:return: dataloader
:rtype: torch.utils.data.DataLoader
"""
if self.batch_size is None:
return [self.input_pts]
def custom_collate(batch):
# extracting pts labels
_, pts = list(batch[0].items())[0]
labels = pts.labels
# calling default torch collate
collate_res = default_collate(batch)
# save collate result in dict
res = {}
for key, val in collate_res.items():
val.labels = labels
res[key] = val
return res
# creating dataset, list of dataset for each location
datasets = [MyDataSet(key, val)
for key, val in self.input_pts.items()]
# creating dataloader
dataloaders = [DataLoader(dataset=dat,
batch_size=self.batch_size,
collate_fn=custom_collate)
for dat in datasets]
return dict(zip(self.input_pts.keys(), dataloaders))
def span_pts(self, *args, **kwargs):
"""
>>> pinn.span_pts(n=10, mode='grid')
@@ -227,6 +194,10 @@ class PINN(object):
self.model.train()
epoch = 0
# Add all condition with `input_points` to dataloader
for condition in list(set(self.problem.conditions.keys()) - set(self.input_pts.keys())):
self.input_pts[condition] = self.problem.conditions[condition]
data_loader = self.data_set.dataloader
header = []