Improve conditions and refactor dataset classes (#475)
* Reimplement conditions * Refactor datasets and implement LabelBatch --------- Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
committed by
Nicola Demo
parent
bdad144461
commit
a0cbf1c44a
@@ -114,10 +114,10 @@ def test_dummy_dataloader(input_, output_):
|
||||
assert isinstance(data, list)
|
||||
assert isinstance(data[0], tuple)
|
||||
if isinstance(input_, list):
|
||||
assert isinstance(data[0][1]["input_points"], Batch)
|
||||
assert isinstance(data[0][1]["input"], Batch)
|
||||
else:
|
||||
assert isinstance(data[0][1]["input_points"], torch.Tensor)
|
||||
assert isinstance(data[0][1]["output_points"], torch.Tensor)
|
||||
assert isinstance(data[0][1]["input"], torch.Tensor)
|
||||
assert isinstance(data[0][1]["target"], torch.Tensor)
|
||||
|
||||
dataloader = dm.val_dataloader()
|
||||
assert isinstance(dataloader, DummyDataloader)
|
||||
@@ -126,10 +126,10 @@ def test_dummy_dataloader(input_, output_):
|
||||
assert isinstance(data, list)
|
||||
assert isinstance(data[0], tuple)
|
||||
if isinstance(input_, list):
|
||||
assert isinstance(data[0][1]["input_points"], Batch)
|
||||
assert isinstance(data[0][1]["input"], Batch)
|
||||
else:
|
||||
assert isinstance(data[0][1]["input_points"], torch.Tensor)
|
||||
assert isinstance(data[0][1]["output_points"], torch.Tensor)
|
||||
assert isinstance(data[0][1]["input"], torch.Tensor)
|
||||
assert isinstance(data[0][1]["target"], torch.Tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -157,10 +157,10 @@ def test_dataloader(input_, output_, automatic_batching):
|
||||
data = next(iter(dataloader))
|
||||
assert isinstance(data, dict)
|
||||
if isinstance(input_, list):
|
||||
assert isinstance(data["data"]["input_points"], Batch)
|
||||
assert isinstance(data["data"]["input"], Batch)
|
||||
else:
|
||||
assert isinstance(data["data"]["input_points"], torch.Tensor)
|
||||
assert isinstance(data["data"]["output_points"], torch.Tensor)
|
||||
assert isinstance(data["data"]["input"], torch.Tensor)
|
||||
assert isinstance(data["data"]["target"], torch.Tensor)
|
||||
|
||||
dataloader = dm.val_dataloader()
|
||||
assert isinstance(dataloader, DataLoader)
|
||||
@@ -168,10 +168,10 @@ def test_dataloader(input_, output_, automatic_batching):
|
||||
data = next(iter(dataloader))
|
||||
assert isinstance(data, dict)
|
||||
if isinstance(input_, list):
|
||||
assert isinstance(data["data"]["input_points"], Batch)
|
||||
assert isinstance(data["data"]["input"], Batch)
|
||||
else:
|
||||
assert isinstance(data["data"]["input_points"], torch.Tensor)
|
||||
assert isinstance(data["data"]["output_points"], torch.Tensor)
|
||||
assert isinstance(data["data"]["input"], torch.Tensor)
|
||||
assert isinstance(data["data"]["target"], torch.Tensor)
|
||||
|
||||
|
||||
from pina import LabelTensor
|
||||
@@ -212,15 +212,15 @@ def test_dataloader_labels(input_, output_, automatic_batching):
|
||||
data = next(iter(dataloader))
|
||||
assert isinstance(data, dict)
|
||||
if isinstance(input_, list):
|
||||
assert isinstance(data["data"]["input_points"], Batch)
|
||||
assert isinstance(data["data"]["input_points"].x, LabelTensor)
|
||||
assert data["data"]["input_points"].x.labels == ["u", "v", "w"]
|
||||
assert data["data"]["input_points"].pos.labels == ["x", "y"]
|
||||
assert isinstance(data["data"]["input"], Batch)
|
||||
assert isinstance(data["data"]["input"].x, LabelTensor)
|
||||
assert data["data"]["input"].x.labels == ["u", "v", "w"]
|
||||
assert data["data"]["input"].pos.labels == ["x", "y"]
|
||||
else:
|
||||
assert isinstance(data["data"]["input_points"], LabelTensor)
|
||||
assert data["data"]["input_points"].labels == ["u", "v", "w"]
|
||||
assert isinstance(data["data"]["output_points"], LabelTensor)
|
||||
assert data["data"]["output_points"].labels == ["u", "v", "w"]
|
||||
assert isinstance(data["data"]["input"], LabelTensor)
|
||||
assert data["data"]["input"].labels == ["u", "v", "w"]
|
||||
assert isinstance(data["data"]["target"], LabelTensor)
|
||||
assert data["data"]["target"].labels == ["u", "v", "w"]
|
||||
|
||||
dataloader = dm.val_dataloader()
|
||||
assert isinstance(dataloader, DataLoader)
|
||||
@@ -228,13 +228,13 @@ def test_dataloader_labels(input_, output_, automatic_batching):
|
||||
data = next(iter(dataloader))
|
||||
assert isinstance(data, dict)
|
||||
if isinstance(input_, list):
|
||||
assert isinstance(data["data"]["input_points"], Batch)
|
||||
assert isinstance(data["data"]["input_points"].x, LabelTensor)
|
||||
assert data["data"]["input_points"].x.labels == ["u", "v", "w"]
|
||||
assert data["data"]["input_points"].pos.labels == ["x", "y"]
|
||||
assert isinstance(data["data"]["input"], Batch)
|
||||
assert isinstance(data["data"]["input"].x, LabelTensor)
|
||||
assert data["data"]["input"].x.labels == ["u", "v", "w"]
|
||||
assert data["data"]["input"].pos.labels == ["x", "y"]
|
||||
else:
|
||||
assert isinstance(data["data"]["input_points"], torch.Tensor)
|
||||
assert isinstance(data["data"]["input_points"], LabelTensor)
|
||||
assert data["data"]["input_points"].labels == ["u", "v", "w"]
|
||||
assert isinstance(data["data"]["output_points"], torch.Tensor)
|
||||
assert data["data"]["output_points"].labels == ["u", "v", "w"]
|
||||
assert isinstance(data["data"]["input"], torch.Tensor)
|
||||
assert isinstance(data["data"]["input"], LabelTensor)
|
||||
assert data["data"]["input"].labels == ["u", "v", "w"]
|
||||
assert isinstance(data["data"]["target"], torch.Tensor)
|
||||
assert data["data"]["target"].labels == ["u", "v", "w"]
|
||||
|
||||
Reference in New Issue
Block a user