supervised working
This commit is contained in:
@@ -0,0 +1,7 @@
|
||||
__all__ = [
|
||||
]
|
||||
|
||||
from .pina_dataloader import SamplePointLoader
|
||||
from .data_dataset import DataPointDataset
|
||||
from .sample_dataset import SamplePointDataset
|
||||
from .pina_batch import Batch
|
||||
41
pina/data/data_dataset.py
Normal file
41
pina/data/data_dataset.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from torch.utils.data import Dataset
|
||||
import torch
|
||||
from ..label_tensor import LabelTensor
|
||||
|
||||
|
||||
class DataPointDataset(Dataset):
|
||||
|
||||
def __init__(self, problem, device) -> None:
|
||||
super().__init__()
|
||||
input_list = []
|
||||
output_list = []
|
||||
self.condition_names = []
|
||||
|
||||
for name, condition in problem.conditions.items():
|
||||
if hasattr(condition, "output_points"):
|
||||
input_list.append(problem.conditions[name].input_points)
|
||||
output_list.append(problem.conditions[name].output_points)
|
||||
self.condition_names.append(name)
|
||||
|
||||
self.input_pts = LabelTensor.stack(input_list)
|
||||
self.output_pts = LabelTensor.stack(output_list)
|
||||
|
||||
if self.input_pts != []:
|
||||
self.condition_indeces = torch.cat(
|
||||
[
|
||||
torch.tensor([i] * len(input_list[i]))
|
||||
for i in range(len(self.condition_names))
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
else: # if there are no data points
|
||||
self.condition_indeces = torch.tensor([])
|
||||
self.input_pts = torch.tensor([])
|
||||
self.output_pts = torch.tensor([])
|
||||
|
||||
self.input_pts = self.input_pts.to(device)
|
||||
self.output_pts = self.output_pts.to(device)
|
||||
self.condition_indeces = self.condition_indeces.to(device)
|
||||
|
||||
def __len__(self):
|
||||
return self.input_pts.shape[0]
|
||||
36
pina/data/pina_batch.py
Normal file
36
pina/data/pina_batch.py
Normal file
@@ -0,0 +1,36 @@
|
||||
|
||||
|
||||
class Batch:
|
||||
"""
|
||||
This class is used to create a dataset of sample points.
|
||||
"""
|
||||
|
||||
def __init__(self, type_, idx, *args, **kwargs) -> None:
|
||||
"""
|
||||
"""
|
||||
if type_ == "sample":
|
||||
|
||||
if len(args) != 2:
|
||||
raise RuntimeError
|
||||
|
||||
input = args[0]
|
||||
conditions = args[1]
|
||||
|
||||
self.input = input[idx]
|
||||
self.condition = conditions[idx]
|
||||
|
||||
elif type_ == "data":
|
||||
|
||||
if len(args) != 3:
|
||||
raise RuntimeError
|
||||
|
||||
input = args[0]
|
||||
output = args[1]
|
||||
conditions = args[2]
|
||||
|
||||
self.input = input[idx]
|
||||
self.output = output[idx]
|
||||
self.condition = conditions[idx]
|
||||
|
||||
else:
|
||||
raise ValueError("Invalid number of arguments.")
|
||||
@@ -1,84 +1,8 @@
|
||||
from torch.utils.data import Dataset
|
||||
import torch
|
||||
from ..label_tensor import LabelTensor
|
||||
|
||||
|
||||
class SamplePointDataset(Dataset):
|
||||
"""
|
||||
This class is used to create a dataset of sample points.
|
||||
"""
|
||||
|
||||
def __init__(self, problem, device) -> None:
|
||||
"""
|
||||
:param dict input_pts: The input points.
|
||||
"""
|
||||
super().__init__()
|
||||
pts_list = []
|
||||
self.condition_names = []
|
||||
|
||||
for name, condition in problem.conditions.items():
|
||||
if not hasattr(condition, "output_points"):
|
||||
pts_list.append(problem.input_pts[name])
|
||||
self.condition_names.append(name)
|
||||
|
||||
self.pts = LabelTensor.vstack(pts_list)
|
||||
|
||||
if self.pts != []:
|
||||
self.condition_indeces = torch.cat(
|
||||
[
|
||||
torch.tensor([i] * len(pts_list[i]))
|
||||
for i in range(len(self.condition_names))
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
else: # if there are no sample points
|
||||
self.condition_indeces = torch.tensor([])
|
||||
self.pts = torch.tensor([])
|
||||
|
||||
self.pts = self.pts.to(device)
|
||||
self.condition_indeces = self.condition_indeces.to(device)
|
||||
|
||||
def __len__(self):
|
||||
return self.pts.shape[0]
|
||||
|
||||
|
||||
class DataPointDataset(Dataset):
|
||||
|
||||
def __init__(self, problem, device) -> None:
|
||||
super().__init__()
|
||||
input_list = []
|
||||
output_list = []
|
||||
self.condition_names = []
|
||||
|
||||
for name, condition in problem.conditions.items():
|
||||
if hasattr(condition, "output_points"):
|
||||
input_list.append(problem.conditions[name].input_points)
|
||||
output_list.append(problem.conditions[name].output_points)
|
||||
self.condition_names.append(name)
|
||||
|
||||
self.input_pts = LabelTensor.vstack(input_list)
|
||||
self.output_pts = LabelTensor.vstack(output_list)
|
||||
|
||||
if self.input_pts != []:
|
||||
self.condition_indeces = torch.cat(
|
||||
[
|
||||
torch.tensor([i] * len(input_list[i]))
|
||||
for i in range(len(self.condition_names))
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
else: # if there are no data points
|
||||
self.condition_indeces = torch.tensor([])
|
||||
self.input_pts = torch.tensor([])
|
||||
self.output_pts = torch.tensor([])
|
||||
|
||||
self.input_pts = self.input_pts.to(device)
|
||||
self.output_pts = self.output_pts.to(device)
|
||||
self.condition_indeces = self.condition_indeces.to(device)
|
||||
|
||||
def __len__(self):
|
||||
return self.input_pts.shape[0]
|
||||
|
||||
from .sample_dataset import SamplePointDataset
|
||||
from .data_dataset import DataPointDataset
|
||||
from .pina_batch import Batch
|
||||
|
||||
class SamplePointLoader:
|
||||
"""
|
||||
@@ -133,6 +57,8 @@ class SamplePointLoader:
|
||||
else:
|
||||
self.random_idx = torch.arange(len(self.batch_list))
|
||||
|
||||
self._prepare_batches()
|
||||
|
||||
def _prepare_data_dataset(self, dataset, batch_size, shuffle):
|
||||
"""
|
||||
Prepare the dataset for data points.
|
||||
@@ -169,7 +95,7 @@ class SamplePointLoader:
|
||||
self.batch_output_pts = torch.tensor_split(
|
||||
dataset.output_pts, batch_num
|
||||
)
|
||||
|
||||
print(input_labels)
|
||||
for i in range(len(self.batch_input_pts)):
|
||||
self.batch_input_pts[i].labels = input_labels
|
||||
self.batch_output_pts[i].labels = output_labels
|
||||
@@ -216,6 +142,29 @@ class SamplePointLoader:
|
||||
self.tensor_conditions, batch_num
|
||||
)
|
||||
|
||||
def _prepare_batches(self):
|
||||
"""
|
||||
Prepare the batches.
|
||||
"""
|
||||
self.batches = []
|
||||
for i in range(len(self.batch_list)):
|
||||
type_, idx_ = self.batch_list[i]
|
||||
|
||||
if type_ == "sample":
|
||||
batch = Batch(
|
||||
"sample", idx_,
|
||||
self.batch_sample_pts,
|
||||
self.batch_sample_conditions)
|
||||
else:
|
||||
batch = Batch(
|
||||
"data", idx_,
|
||||
self.batch_input_pts,
|
||||
self.batch_output_pts,
|
||||
self.batch_data_conditions)
|
||||
print(batch.input.labels)
|
||||
|
||||
self.batches.append(batch)
|
||||
|
||||
def __iter__(self):
|
||||
"""
|
||||
Return an iterator over the points. Any element of the iterator is a
|
||||
@@ -233,21 +182,24 @@ class SamplePointLoader:
|
||||
:rtype: iter
|
||||
"""
|
||||
# for i in self.random_idx:
|
||||
for i in range(len(self.batch_list)):
|
||||
type_, idx_ = self.batch_list[i]
|
||||
for i in self.random_idx:
|
||||
yield self.batches[i]
|
||||
|
||||
if type_ == "sample":
|
||||
d = {
|
||||
"pts": self.batch_sample_pts[idx_].requires_grad_(True),
|
||||
"condition": self.batch_sample_conditions[idx_],
|
||||
}
|
||||
else:
|
||||
d = {
|
||||
"pts": self.batch_input_pts[idx_].requires_grad_(True),
|
||||
"output": self.batch_output_pts[idx_],
|
||||
"condition": self.batch_data_conditions[idx_],
|
||||
}
|
||||
yield d
|
||||
# for i in range(len(self.batch_list)):
|
||||
# type_, idx_ = self.batch_list[i]
|
||||
|
||||
# if type_ == "sample":
|
||||
# d = {
|
||||
# "pts": self.batch_sample_pts[idx_].requires_grad_(True),
|
||||
# "condition": self.batch_sample_conditions[idx_],
|
||||
# }
|
||||
# else:
|
||||
# d = {
|
||||
# "pts": self.batch_input_pts[idx_].requires_grad_(True),
|
||||
# "output": self.batch_output_pts[idx_],
|
||||
# "condition": self.batch_data_conditions[idx_],
|
||||
# }
|
||||
# yield d
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
43
pina/data/sample_dataset.py
Normal file
43
pina/data/sample_dataset.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from torch.utils.data import Dataset
|
||||
import torch
|
||||
|
||||
from ..label_tensor import LabelTensor
|
||||
|
||||
|
||||
class SamplePointDataset(Dataset):
|
||||
"""
|
||||
This class is used to create a dataset of sample points.
|
||||
"""
|
||||
|
||||
def __init__(self, problem, device) -> None:
|
||||
"""
|
||||
:param dict input_pts: The input points.
|
||||
"""
|
||||
super().__init__()
|
||||
pts_list = []
|
||||
self.condition_names = []
|
||||
|
||||
for name, condition in problem.conditions.items():
|
||||
if not hasattr(condition, "output_points"):
|
||||
pts_list.append(problem.input_pts[name])
|
||||
self.condition_names.append(name)
|
||||
|
||||
self.pts = LabelTensor.stack(pts_list)
|
||||
|
||||
if self.pts != []:
|
||||
self.condition_indeces = torch.cat(
|
||||
[
|
||||
torch.tensor([i] * len(pts_list[i]))
|
||||
for i in range(len(self.condition_names))
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
else: # if there are no sample points
|
||||
self.condition_indeces = torch.tensor([])
|
||||
self.pts = torch.tensor([])
|
||||
|
||||
self.pts = self.pts.to(device)
|
||||
self.condition_indeces = self.condition_indeces.to(device)
|
||||
|
||||
def __len__(self):
|
||||
return self.pts.shape[0]
|
||||
Reference in New Issue
Block a user