add temporary fix to treat data input
This commit is contained in:
37
pina/pinn.py
37
pina/pinn.py
@@ -146,39 +146,6 @@ class PINN(object):
|
|||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def _create_dataloader(self):
|
|
||||||
"""Private method for creating dataloader
|
|
||||||
|
|
||||||
:return: dataloader
|
|
||||||
:rtype: torch.utils.data.DataLoader
|
|
||||||
"""
|
|
||||||
if self.batch_size is None:
|
|
||||||
return [self.input_pts]
|
|
||||||
|
|
||||||
def custom_collate(batch):
|
|
||||||
# extracting pts labels
|
|
||||||
_, pts = list(batch[0].items())[0]
|
|
||||||
labels = pts.labels
|
|
||||||
# calling default torch collate
|
|
||||||
collate_res = default_collate(batch)
|
|
||||||
# save collate result in dict
|
|
||||||
res = {}
|
|
||||||
for key, val in collate_res.items():
|
|
||||||
val.labels = labels
|
|
||||||
res[key] = val
|
|
||||||
return res
|
|
||||||
|
|
||||||
# creating dataset, list of dataset for each location
|
|
||||||
datasets = [MyDataSet(key, val)
|
|
||||||
for key, val in self.input_pts.items()]
|
|
||||||
# creating dataloader
|
|
||||||
dataloaders = [DataLoader(dataset=dat,
|
|
||||||
batch_size=self.batch_size,
|
|
||||||
collate_fn=custom_collate)
|
|
||||||
for dat in datasets]
|
|
||||||
|
|
||||||
return dict(zip(self.input_pts.keys(), dataloaders))
|
|
||||||
|
|
||||||
def span_pts(self, *args, **kwargs):
|
def span_pts(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
>>> pinn.span_pts(n=10, mode='grid')
|
>>> pinn.span_pts(n=10, mode='grid')
|
||||||
@@ -227,6 +194,10 @@ class PINN(object):
|
|||||||
|
|
||||||
self.model.train()
|
self.model.train()
|
||||||
epoch = 0
|
epoch = 0
|
||||||
|
# Add all condition with `input_points` to dataloader
|
||||||
|
for condition in list(set(self.problem.conditions.keys()) - set(self.input_pts.keys())):
|
||||||
|
self.input_pts[condition] = self.problem.conditions[condition]
|
||||||
|
|
||||||
data_loader = self.data_set.dataloader
|
data_loader = self.data_set.dataloader
|
||||||
|
|
||||||
header = []
|
header = []
|
||||||
|
|||||||
Reference in New Issue
Block a user