|
- # Copyright (c) OpenMMLab. All rights reserved.
- from unittest import TestCase
-
- import numpy as np
- import torch
-
- from mmengine.dataset import default_collate, pseudo_collate
- from mmengine.structures import BaseDataElement
- from mmengine.utils import is_list_of
-
-
class TestDataUtils(TestCase):
    """Tests for the ``pseudo_collate`` and ``default_collate`` helpers."""

    def test_pseudo_collate(self):
        """Check ``pseudo_collate`` on flat dicts and on tuples of dicts.

        The assertions verify that collated fields are returned as
        per-key lists and that the tensor / data-sample objects placed
        in those lists are the very same objects that were passed in.
        """
        tensor_a = torch.randn(1, 3, 5)
        tensor_b = torch.randn(1, 3, 5)
        target_a = torch.randn(1)
        target_b = torch.randn(1)

        # Case 1: a list of dicts whose values are all tensors.
        collated = pseudo_collate([
            dict(inputs=tensor_a, data_sample=target_a),
            dict(inputs=tensor_b, data_sample=target_b),
        ])
        for idx, tensor in enumerate((tensor_a, tensor_b)):
            self.assertTrue(torch.allclose(tensor, collated['inputs'][idx]))
        for idx, target in enumerate((target_a, target_b)):
            self.assertTrue(
                torch.allclose(target, collated['data_sample'][idx]))

        # Case 2: a list of dicts where `data_sample` is a
        # `BaseDataElement` instead of a tensor.
        sample_a = BaseDataElement(label=torch.tensor(1))
        sample_b = BaseDataElement(label=torch.tensor(1))
        collated = pseudo_collate([
            dict(inputs=tensor_a, data_sample=sample_a),
            dict(inputs=tensor_b, data_sample=sample_b),
        ])
        collated_inputs = collated['inputs']
        collated_samples = collated['data_sample']

        # The inputs come back as a list holding the original tensors.
        self.assertTrue(is_list_of(collated_inputs, torch.Tensor))
        self.assertIs(tensor_a, collated_inputs[0])
        self.assertIs(tensor_b, collated_inputs[1])

        # The data samples come back as the original objects as well.
        self.assertIs(collated_samples[0], sample_a)
        self.assertIs(collated_samples[1], sample_b)

        # Case 3: a list of tuples, each tuple holding two nested dicts.
        def make_item(tensor, sample, value, name):
            return dict(
                inputs=tensor,
                data_sample=sample,
                value=value,
                name=name,
                nested=dict(data_sample=sample))

        collated = pseudo_collate([
            (make_item(tensor_a, sample_a, 1, '1'),
             make_item(tensor_b, sample_b, 2, '2')),
            (make_item(tensor_a, sample_a, 1, '1'),
             make_item(tensor_b, sample_b, 2, '2')),
        ])
        first, second = collated[0], collated[1]
        inputs_by_slot = (first['inputs'], second['inputs'])
        samples_by_slot = (first['data_sample'], second['data_sample'])
        values_by_slot = (first['value'], second['value'])
        names_by_slot = (first['name'], second['name'])
        nested_by_slot = (first['nested'], second['nested'])

        for slot_inputs in inputs_by_slot:
            self.assertTrue(is_list_of(slot_inputs, torch.Tensor))
        for slot_inputs, tensor in zip(inputs_by_slot, (tensor_a, tensor_b)):
            for idx in range(2):
                self.assertIs(slot_inputs[idx], tensor)

        for slot_samples, sample in zip(samples_by_slot,
                                        (sample_a, sample_b)):
            for idx in range(2):
                self.assertIs(slot_samples[idx], sample)

        self.assertEqual(values_by_slot[0], [1, 1])
        self.assertEqual(values_by_slot[1], [2, 2])

        self.assertEqual(names_by_slot[0], ['1', '1'])
        self.assertEqual(names_by_slot[1], ['2', '2'])

        for slot_nested, sample in zip(nested_by_slot, (sample_a, sample_b)):
            for idx in range(2):
                self.assertIs(slot_nested['data_sample'][idx], sample)

    def test_default_collate(self):
        """Check tensor stacking and scalar conversion in the collate output.

        The assertions verify that batched tensors are stacked along a
        new leading dimension and that Python ints and numpy arrays are
        returned as tensors holding the expected values.
        """
        # `default_collate` has common logic with `pseudo_collate`, therefore
        # only test it can stack batch tensor, convert int or float to tensor.
        tensor_a = torch.randn(1, 3, 5)
        tensor_b = torch.randn(1, 3, 5)

        def make_item(tensor, number):
            return dict(inputs=tensor, value=number, array=np.array(number))

        collated = default_collate([
            (make_item(tensor_a, 1), make_item(tensor_b, 2)),
            (make_item(tensor_a, 1), make_item(tensor_b, 2)),
        ])
        first, second = collated[0], collated[1]
        stacked_inputs = (first['inputs'], second['inputs'])
        stacked_values = (first['value'], second['value'])
        stacked_arrays = (first['array'], second['array'])

        for stacked in stacked_inputs:
            self.assertEqual(tuple(stacked.shape), (2, 1, 3, 5))

        for stacked, tensor in zip(stacked_inputs, (tensor_a, tensor_b)):
            self.assertTrue(
                torch.allclose(stacked, torch.stack([tensor, tensor])))

        expected = (
            torch.stack([torch.tensor(1), torch.tensor(1)]),
            torch.stack([torch.tensor(2), torch.tensor(2)]),
        )

        # Cast to the expected dtype before comparing so that only the
        # values, not the integer width, decide the outcome.
        for stacked, wanted in zip(stacked_values, expected):
            self.assertTrue(
                torch.allclose(stacked.to(wanted.dtype), wanted))

        for stacked, wanted in zip(stacked_arrays, expected):
            self.assertTrue(
                torch.allclose(stacked.to(wanted.dtype), wanted))
|