Improve conditions and refactor dataset classes (#475)

* Reimplement conditions

* Refactor datasets and implement LabelBatch

---------

Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
Filippo Olivo
2025-03-07 11:24:09 +01:00
committed by Nicola Demo
parent bdad144461
commit a0cbf1c44a
40 changed files with 943 additions and 550 deletions

View File

@@ -114,10 +114,10 @@ def test_dummy_dataloader(input_, output_):
assert isinstance(data, list)
assert isinstance(data[0], tuple)
if isinstance(input_, list):
assert isinstance(data[0][1]["input_points"], Batch)
assert isinstance(data[0][1]["input"], Batch)
else:
assert isinstance(data[0][1]["input_points"], torch.Tensor)
assert isinstance(data[0][1]["output_points"], torch.Tensor)
assert isinstance(data[0][1]["input"], torch.Tensor)
assert isinstance(data[0][1]["target"], torch.Tensor)
dataloader = dm.val_dataloader()
assert isinstance(dataloader, DummyDataloader)
@@ -126,10 +126,10 @@ def test_dummy_dataloader(input_, output_):
assert isinstance(data, list)
assert isinstance(data[0], tuple)
if isinstance(input_, list):
assert isinstance(data[0][1]["input_points"], Batch)
assert isinstance(data[0][1]["input"], Batch)
else:
assert isinstance(data[0][1]["input_points"], torch.Tensor)
assert isinstance(data[0][1]["output_points"], torch.Tensor)
assert isinstance(data[0][1]["input"], torch.Tensor)
assert isinstance(data[0][1]["target"], torch.Tensor)
@pytest.mark.parametrize(
@@ -157,10 +157,10 @@ def test_dataloader(input_, output_, automatic_batching):
data = next(iter(dataloader))
assert isinstance(data, dict)
if isinstance(input_, list):
assert isinstance(data["data"]["input_points"], Batch)
assert isinstance(data["data"]["input"], Batch)
else:
assert isinstance(data["data"]["input_points"], torch.Tensor)
assert isinstance(data["data"]["output_points"], torch.Tensor)
assert isinstance(data["data"]["input"], torch.Tensor)
assert isinstance(data["data"]["target"], torch.Tensor)
dataloader = dm.val_dataloader()
assert isinstance(dataloader, DataLoader)
@@ -168,10 +168,10 @@ def test_dataloader(input_, output_, automatic_batching):
data = next(iter(dataloader))
assert isinstance(data, dict)
if isinstance(input_, list):
assert isinstance(data["data"]["input_points"], Batch)
assert isinstance(data["data"]["input"], Batch)
else:
assert isinstance(data["data"]["input_points"], torch.Tensor)
assert isinstance(data["data"]["output_points"], torch.Tensor)
assert isinstance(data["data"]["input"], torch.Tensor)
assert isinstance(data["data"]["target"], torch.Tensor)
from pina import LabelTensor
@@ -212,15 +212,15 @@ def test_dataloader_labels(input_, output_, automatic_batching):
data = next(iter(dataloader))
assert isinstance(data, dict)
if isinstance(input_, list):
assert isinstance(data["data"]["input_points"], Batch)
assert isinstance(data["data"]["input_points"].x, LabelTensor)
assert data["data"]["input_points"].x.labels == ["u", "v", "w"]
assert data["data"]["input_points"].pos.labels == ["x", "y"]
assert isinstance(data["data"]["input"], Batch)
assert isinstance(data["data"]["input"].x, LabelTensor)
assert data["data"]["input"].x.labels == ["u", "v", "w"]
assert data["data"]["input"].pos.labels == ["x", "y"]
else:
assert isinstance(data["data"]["input_points"], LabelTensor)
assert data["data"]["input_points"].labels == ["u", "v", "w"]
assert isinstance(data["data"]["output_points"], LabelTensor)
assert data["data"]["output_points"].labels == ["u", "v", "w"]
assert isinstance(data["data"]["input"], LabelTensor)
assert data["data"]["input"].labels == ["u", "v", "w"]
assert isinstance(data["data"]["target"], LabelTensor)
assert data["data"]["target"].labels == ["u", "v", "w"]
dataloader = dm.val_dataloader()
assert isinstance(dataloader, DataLoader)
@@ -228,13 +228,13 @@ def test_dataloader_labels(input_, output_, automatic_batching):
data = next(iter(dataloader))
assert isinstance(data, dict)
if isinstance(input_, list):
assert isinstance(data["data"]["input_points"], Batch)
assert isinstance(data["data"]["input_points"].x, LabelTensor)
assert data["data"]["input_points"].x.labels == ["u", "v", "w"]
assert data["data"]["input_points"].pos.labels == ["x", "y"]
assert isinstance(data["data"]["input"], Batch)
assert isinstance(data["data"]["input"].x, LabelTensor)
assert data["data"]["input"].x.labels == ["u", "v", "w"]
assert data["data"]["input"].pos.labels == ["x", "y"]
else:
assert isinstance(data["data"]["input_points"], torch.Tensor)
assert isinstance(data["data"]["input_points"], LabelTensor)
assert data["data"]["input_points"].labels == ["u", "v", "w"]
assert isinstance(data["data"]["output_points"], torch.Tensor)
assert data["data"]["output_points"].labels == ["u", "v", "w"]
assert isinstance(data["data"]["input"], torch.Tensor)
assert isinstance(data["data"]["input"], LabelTensor)
assert data["data"]["input"].labels == ["u", "v", "w"]
assert isinstance(data["data"]["target"], torch.Tensor)
assert data["data"]["target"].labels == ["u", "v", "w"]