Source code for common.vision.datasets.keypoint_detection.freihand
"""
@author: Junguang Jiang
@contact: [email protected]
"""
import json
import os
import os.path as osp
import random
import time

import numpy as np
import torch
from PIL import Image
from torchvision.datasets.utils import download_and_extract_archive

from ...transforms.keypoint_detection import *
from .keypoint_dataset import Hand21KeypointDataset
from .util import *
""" General util functions. """
def _assert_exist(p):
    msg = 'File does not exist: %s' % p
assert os.path.exists(p), msg
def json_load(p):
_assert_exist(p)
with open(p, 'r') as fi:
d = json.load(fi)
return d
def load_db_annotation(base_path, set_name=None):
if set_name is None:
# only training set annotations are released so this is a valid default choice
set_name = 'training'
print('Loading FreiHAND dataset index ...')
t = time.time()
# assumed paths to data containers
k_path = os.path.join(base_path, '%s_K.json' % set_name)
mano_path = os.path.join(base_path, '%s_mano.json' % set_name)
xyz_path = os.path.join(base_path, '%s_xyz.json' % set_name)
    # load the annotation files if they exist
K_list = json_load(k_path)
mano_list = json_load(mano_path)
xyz_list = json_load(xyz_path)
# should have all the same length
assert len(K_list) == len(mano_list), 'Size mismatch.'
assert len(K_list) == len(xyz_list), 'Size mismatch.'
print('Loading of %d samples done in %.2f seconds' % (len(K_list), time.time()-t))
return list(zip(K_list, mano_list, xyz_list))
def projectPoints(xyz, K):
""" Project 3D coordinates into image space. """
xyz = np.array(xyz)
K = np.array(K)
uv = np.matmul(K, xyz.T).T
return uv[:, :2] / uv[:, -1:]
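# Illustrative check (not part of the original module): with a hypothetical
# pinhole intrinsic matrix K = [[100, 0, 112], [0, 100, 112], [0, 0, 1]],
# the point (0.1, 0.2, 0.5) projects to
# (100 * 0.1 / 0.5 + 112, 100 * 0.2 / 0.5 + 112) = (132.0, 152.0):
#   >>> projectPoints([[0.1, 0.2, 0.5]], [[100, 0, 112], [0, 100, 112], [0, 0, 1]])
#   array([[132., 152.]])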
""" Dataset related functions. """
def db_size(set_name):
""" Hardcoded size of the datasets. """
if set_name == 'training':
        return 32560  # number of unique samples (they exist in multiple 'versions')
elif set_name == 'evaluation':
return 3960
else:
assert 0, 'Invalid choice.'
class sample_version:
gs = 'gs' # green screen
hom = 'hom' # homogenized
sample = 'sample' # auto colorization with sample points
auto = 'auto' # auto colorization without sample points: automatic color hallucination
db_size = db_size('training')
@classmethod
def valid_options(cls):
return [cls.gs, cls.hom, cls.sample, cls.auto]
@classmethod
def check_valid(cls, version):
msg = 'Invalid choice: "%s" (must be in %s)' % (version, cls.valid_options())
assert version in cls.valid_options(), msg
@classmethod
def map_id(cls, id, version):
cls.check_valid(version)
return id + cls.db_size*cls.valid_options().index(version)
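    # Illustrative mapping (not part of the original module): map_id(5, 'hom')
    # returns 5 + 32560 * 1 = 32565, so the homogenized version of training
    # sample 5 is stored as 'training/rgb/00032565.jpg'.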
class FreiHand(Hand21KeypointDataset):
"""`FreiHand Dataset <https://lmb.informatik.uni-freiburg.de/projects/freihand/>`_
Args:
root (str): Root directory of dataset
split (str, optional): The dataset split, supports ``train``, ``test``, or ``all``.
        task (str, optional): The post-processing option used to create the dataset. Choices include ``'gs'``: green screen \
recording, ``'auto'``: auto colorization without sample points: automatic color hallucination, \
``'sample'``: auto colorization with sample points, ``'hom'``: homogenized, \
and ``'all'``: all hands. Default: 'all'.
download (bool, optional): If true, downloads the dataset from the internet and puts it \
in root directory. If dataset is already downloaded, it is not downloaded again.
transforms (callable, optional): A function/transform that takes in a dict (which contains PIL image and
its labels) and returns a transformed version. E.g, :class:`~common.vision.transforms.keypoint_detection.Resize`.
image_size (tuple): (width, height) of the image. Default: (256, 256)
heatmap_size (tuple): (width, height) of the heatmap. Default: (64, 64)
        sigma (int): sigma parameter when generating the heatmap. Default: 2

    .. note:: In `root`, the following files will exist after downloading.
::
*.json
training/
evaluation/
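
    Examples (illustrative; assumes the archive is extracted under ``data/freihand``
    and that ``my_transforms`` is a user-supplied transform pipeline)::

        >>> dataset = FreiHand(root="data/freihand", split='train', task='gs',
        ...                    transforms=my_transforms)  # my_transforms is hypothetical
        >>> image, target, target_weight, meta = dataset[0]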
"""
def __init__(self, root, split='train', task='all', download=True, **kwargs):
if download:
if not osp.exists(osp.join(root, "training")) or not osp.exists(osp.join(root, "evaluation")):
download_and_extract_archive("https://lmb.informatik.uni-freiburg.de/data/freihand/FreiHAND_pub_v2.zip",
download_root=root, filename="FreiHAND_pub_v2.zip", remove_finished=False,
extract_root=root)
assert split in ['train', 'test', 'all']
self.split = split
assert task in ['all', 'gs', 'auto', 'sample', 'hom']
self.task = task
        if task == 'all':
            samples = []
            for version in ('gs', 'auto', 'sample', 'hom'):
                samples += self.get_samples(root, version)
        else:
            samples = self.get_samples(root, task)
random.seed(42)
random.shuffle(samples)
samples_len = len(samples)
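        # hold out min(20% of the samples, 3200) as the test split; the fixed
        # seed above makes the train/test partition reproducible across runs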
samples_split = min(int(samples_len * 0.2), 3200)
if self.split == 'train':
samples = samples[samples_split:]
elif self.split == 'test':
samples = samples[:samples_split]
super(FreiHand, self).__init__(root, samples, **kwargs)
def __getitem__(self, index):
sample = self.samples[index]
image_name = sample['name']
image_path = os.path.join(self.root, image_name)
image = Image.open(image_path)
keypoint3d_camera = np.array(sample['keypoint3d']) # NUM_KEYPOINTS x 3
keypoint2d = np.array(sample['keypoint2d']) # NUM_KEYPOINTS x 2
intrinsic_matrix = np.array(sample['intrinsic_matrix'])
Zc = keypoint3d_camera[:, 2]
        # Crop the image so that the hand is centered;
        # the crop is 1.5 times larger than the hand's bounding box.
        # Cropping changes Xc and Yc but leaves Zc unchanged.
bounding_box = get_bounding_box(keypoint2d)
w, h = image.size
left, upper, right, lower = scale_box(bounding_box, w, h, 1.5)
image, keypoint2d = crop(image, upper, left, lower - upper, right - left, keypoint2d)
# Change all hands to right hands
if sample['left'] is False:
image, keypoint2d = hflip(image, keypoint2d)
image, data = self.transforms(image, keypoint2d=keypoint2d, intrinsic_matrix=intrinsic_matrix)
keypoint2d = data['keypoint2d']
intrinsic_matrix = data['intrinsic_matrix']
keypoint3d_camera = keypoint2d_to_3d(keypoint2d, intrinsic_matrix, Zc)
        # normalize 2D pose:
visible = np.ones((self.num_keypoints, ), dtype=np.float32)
visible = visible[:, np.newaxis]
# 2D heatmap
target, target_weight = generate_target(keypoint2d, visible, self.heatmap_size, self.sigma, self.image_size)
target = torch.from_numpy(target)
target_weight = torch.from_numpy(target_weight)
# normalize 3D pose:
# put middle finger metacarpophalangeal (MCP) joint in the center of the coordinate system
# and make distance between wrist and middle finger MCP joint to be of length 1
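        # (in the 21-keypoint hand layout, index 0 is the wrist and index 9 is
        # the middle finger MCP joint, hence the indices used below)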
keypoint3d_n = keypoint3d_camera - keypoint3d_camera[9:10, :]
keypoint3d_n = keypoint3d_n / np.sqrt(np.sum(keypoint3d_n[0, :] ** 2))
z = keypoint3d_n[:, 2]
meta = {
'image': image_name,
'keypoint2d': keypoint2d, # (NUM_KEYPOINTS x 2)
'keypoint3d': keypoint3d_n, # (NUM_KEYPOINTS x 3)
'z': z,
}
return image, target, target_weight, meta
    def get_samples(self, root, version='gs'):
        set_name = 'training'
        # load annotations of this set
        db_data_anno = load_db_annotation(root, set_name)
        version_map = {
            'gs': sample_version.gs,
            'hom': sample_version.hom,
            'sample': sample_version.sample,
            'auto': sample_version.auto
        }
        samples = []
        for idx in range(db_size(set_name)):
            image_name = os.path.join(set_name, 'rgb',
                                      '%08d.jpg' % sample_version.map_id(idx, version_map[version]))
            mask_name = os.path.join(set_name, 'mask', '%08d.jpg' % idx)
intrinsic_matrix, mano, keypoint3d = db_data_anno[idx]
keypoint2d = projectPoints(keypoint3d, intrinsic_matrix)
sample = {
'name': image_name,
'mask_name': mask_name,
'keypoint2d': keypoint2d,
'keypoint3d': keypoint3d,
'intrinsic_matrix': intrinsic_matrix,
'left': False
}
samples.append(sample)
return samples