#6 fine-tune

Open
6000 wants to merge 37 commits from fine-tune into main
  1. LICENSE (+21 -0)
  2. README.md (+25 -11)
  3. cfgs/dataset_configs/ScanObjectNN_hardest.yaml (+1 -1)
  4. cfgs/dataset_configs/ShapeNet-55.yaml (+2 -2)
  5. cfgs/finetune_scan_hardest.yaml (+3 -3)
  6. main.py (+2 -0)
  7. models/Point_MAE.py (+39 -11)
  8. utils/parser.py (+6 -4)

LICENSE (+21 -0)

@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2022 PANG-Yatian, YUAN-Li

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

README.md (+25 -11)

@@ -1,6 +1,9 @@
# Point-MAE

- ## Implementation for paper: Masked Autoencoders for Point Cloud Self-supervised Learning
+ ## Masked Autoencoders for Point Cloud Self-supervised Learning, [ECCV 2022](https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136620591.pdf), [ArXiv](https://arxiv.org/abs/2203.06604)

+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/masked-autoencoders-for-point-cloud-self/3d-point-cloud-classification-on-scanobjectnn)](https://paperswithcode.com/sota/3d-point-cloud-classification-on-scanobjectnn?p=masked-autoencoders-for-point-cloud-self)
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/masked-autoencoders-for-point-cloud-self/3d-point-cloud-classification-on-modelnet40)](https://paperswithcode.com/sota/3d-point-cloud-classification-on-modelnet40?p=masked-autoencoders-for-point-cloud-self)

In this work, we present a novel scheme of masked autoencoders for point cloud self-supervised learning, termed as Point-MAE. Our Point-MAE is neat and efficient, with minimal modifications based on the properties of the point cloud. In classification tasks, Point-MAE outperforms all the other self-supervised learning methods on ScanObjectNN and ModelNet40. Point-MAE also advances state-of-the-art accuracies by 1.5%-2.3% in the few-shot learning on ModelNet40.

@@ -9,7 +12,7 @@ In this work, we present a novel scheme of masked autoencoders for point cloud s
</div>

## 1. Requirements
- PyTorch >= 1.7.0;
+ PyTorch >= 1.7.0 < 1.11.0;
python >= 3.7;
CUDA >= 9.0;
GCC >= 4.9;
@@ -26,7 +29,7 @@ python setup.py install --user
cd ./extensions/emd
python setup.py install --user
# PointNet++
pip install "git+git://github.com/erikwijmans/Pointnet2_PyTorch.git#egg=pointnet2_ops&subdirectory=pointnet2_ops_lib"
pip install "git+https://github.com/erikwijmans/Pointnet2_PyTorch.git#egg=pointnet2_ops&subdirectory=pointnet2_ops_lib"
# GPU kNN
pip install --upgrade https://github.com/unlimblue/KNN_CUDA/releases/download/0.2/KNN_CUDA-0.2-py3-none-any.whl
```
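
Before training, a quick import check can confirm the CUDA extensions built; a minimal sketch, assuming the standard module layout of the two pip-installed packages:

```python
# sanity check: these imports fail if the CUDA extensions did not build
import torch
from knn_cuda import KNN                    # from the KNN_CUDA wheel above
from pointnet2_ops import pointnet2_utils   # from Pointnet2_PyTorch

assert torch.cuda.is_available(), "a CUDA device is required for these ops"
knn = KNN(k=4, transpose_mode=True)         # constructor as used in models/Point_MAE.py
print("CUDA extensions OK")
```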
@@ -38,13 +41,13 @@ We use ShapeNet, ScanObjectNN, ModelNet40 and ShapeNetPart in this work. See [DA
## 3. Point-MAE Models
| Task | Dataset | Config | Acc.| Download|
| ----- | ----- |-----| -----| -----|
- | Pre-training | ShapeNet |[pretrain.yaml](./cfgs/pretrain.yaml)| N.A. | |
- | Classification | ScanObjectNN |[finetune_scan_hardest.yaml](./cfgs/finetune_scan_hardest.yaml)| 84.52%| |
- | Classification | ScanObjectNN |[finetune_scan_objbg.yaml](./cfgs/finetune_scan_objbg.yaml)| 88.29%| |
- | Classification | ScanObjectNN |[finetune_scan_objonly.yaml](./cfgs/finetune_scan_objonly.yaml)| 90.01%| |
- | Classification | ModelNet40(1k) |[finetune_modelnet.yaml](./cfgs/finetune_modelnet.yaml)| 93.80%| |
- | Classification | ModelNet40(8k) |[finetune_modelnet_8k.yaml](./cfgs/finetune_modelnet_8k.yaml)| 94.04%| |
- | Part segmentation| ShapeNetPart| [segmentation](./segmentation)| 86.1% mIoU| |
+ | Pre-training | ShapeNet |[pretrain.yaml](./cfgs/pretrain.yaml)| N.A. | [here](https://github.com/Pang-Yatian/Point-MAE/releases/download/main/pretrain.pth) |
+ | Classification | ScanObjectNN |[finetune_scan_hardest.yaml](./cfgs/finetune_scan_hardest.yaml)| 85.18%| [here](https://github.com/Pang-Yatian/Point-MAE/releases/download/main/scan_hardest.pth) |
+ | Classification | ScanObjectNN |[finetune_scan_objbg.yaml](./cfgs/finetune_scan_objbg.yaml)| 90.02%| [here](https://github.com/Pang-Yatian/Point-MAE/releases/download/main/scan_objbg.pth) |
+ | Classification | ScanObjectNN |[finetune_scan_objonly.yaml](./cfgs/finetune_scan_objonly.yaml)| 88.29%| [here](https://github.com/Pang-Yatian/Point-MAE/releases/download/main/scan_objonly.pth) |
+ | Classification | ModelNet40(1k) |[finetune_modelnet.yaml](./cfgs/finetune_modelnet.yaml)| 93.80%| [here](https://github.com/Pang-Yatian/Point-MAE/releases/download/main/modelnet_1k.pth) |
+ | Classification | ModelNet40(8k) |[finetune_modelnet_8k.yaml](./cfgs/finetune_modelnet_8k.yaml)| 94.04%| [here](https://github.com/Pang-Yatian/Point-MAE/releases/download/main/modelnet_8k.pth) |
+ | Part segmentation| ShapeNetPart| [segmentation](./segmentation)| 86.1% mIoU| [here](https://github.com/Pang-Yatian/Point-MAE/releases/download/main/part_seg.pth) |

| Task | Dataset | Config | 5w10s Acc. (%)| 5w20s Acc. (%)| 10w10s Acc. (%)| 10w20s Acc. (%)|
| ----- | ----- |-----| -----| -----|-----|-----|
@@ -98,6 +101,17 @@ python main_vis.py --test --ckpts <path/to/pre-trained/model> --config cfgs/pret

## Acknowledgements

- Our codes are build upon [Point-BERT](https://github.com/lulutang0608/Point-BERT), [Pointnet2_PyTorch](https://github.com/erikwijmans/Pointnet2_PyTorch) and [Pointnet_Pointnet2_pytorch](https://github.com/yanx27/Pointnet_Pointnet2_pytorch)
+ Our codes are built upon [Point-BERT](https://github.com/lulutang0608/Point-BERT), [Pointnet2_PyTorch](https://github.com/erikwijmans/Pointnet2_PyTorch) and [Pointnet_Pointnet2_pytorch](https://github.com/yanx27/Pointnet_Pointnet2_pytorch)

+ ## Reference
+ 
+ ```
+ @misc{pang2022masked,
+       title={Masked Autoencoders for Point Cloud Self-supervised Learning},
+       author={Yatian Pang and Wenxiao Wang and Francis E. H. Tay and Wei Liu and Yonghong Tian and Li Yuan},
+       year={2022},
+       eprint={2203.06604},
+       archivePrefix={arXiv},
+       primaryClass={cs.CV}
+ }
+ ```

cfgs/dataset_configs/ScanObjectNN_hardest.yaml (+1 -1)

@@ -1,2 +1,2 @@
NAME: ScanObjectNN_hardest
- ROOT: data/ScanObjectNN/main_split
+ ROOT: /tmp/dataset/ScanObjectNN/ScanObjectNN/main_split

cfgs/dataset_configs/ShapeNet-55.yaml (+2 -2)

@@ -1,4 +1,4 @@
NAME: ShapeNet
- DATA_PATH: data/ShapeNet55-34/ShapeNet-55
+ DATA_PATH: /tmp/dataset/ShapeNet55-34/ShapeNet-55
N_POINTS: 8192
- PC_PATH: data/ShapeNet55-34/shapenet_pc
+ PC_PATH: /tmp/dataset/ShapeNet55-34/shapenet_pc

cfgs/finetune_scan_hardest.yaml (+3 -3)

@@ -8,7 +8,7 @@ optimizer : {
scheduler: {
type: CosLR,
kwargs: {
- epochs: 200,
+ epochs: 300,
initial_epochs : 10
}}

@@ -35,5 +35,5 @@ model : {
npoints: 2048
total_bs : 32
step_per_update : 1
- max_epoch : 200
- grad_norm_clip : 10
+ max_epoch : 300
+ grad_norm_clip : 10
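
The change above extends fine-tuning from 200 to 300 epochs under the same `CosLR` schedule. Assuming `CosLR` is the common linear-warmup-plus-cosine-decay recipe (as in timm-style schedulers; the learning-rate values below are illustrative, not from this config), the per-epoch rate behaves roughly like:

```python
import math

def cos_lr(epoch, base_lr=5e-4, min_lr=1e-6, epochs=300, initial_epochs=10):
    # assumed warmup + cosine recipe; epochs/initial_epochs mirror the yaml above
    if epoch < initial_epochs:                        # linear warmup
        return base_lr * (epoch + 1) / initial_epochs
    t = (epoch - initial_epochs) / (epochs - initial_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * t))
```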

main.py (+2 -0)

@@ -8,6 +8,8 @@ import time
import os
import torch
from tensorboardX import SummaryWriter
+ os.environ['CUDA_LAUNCH_BLOCKING'] = '1'  # make CUDA kernel launches synchronous so errors surface at the failing call (debugging aid; slows execution)


def main():
# args


models/Point_MAE.py (+39 -11)

@@ -10,7 +10,40 @@ from utils.checkpoint import get_missing_parameters_message, get_unexpected_para
from utils.logger import *
import random
from knn_cuda import KNN
- from extensions.chamfer_dist import ChamferDistanceL1, ChamferDistanceL2
+ # from extensions.chamfer_dist import ChamferDistanceL1, ChamferDistanceL2
+ # from pytorch3d.loss import chamfer_distance

+ def chamfer(pc1, pc2):
+     """
+     One-directional Chamfer distance between batched point clouds.
+     Arguments:
+         pc1: reference points, size (bs, num_point1, num_feature).
+         pc2: query points, size (bs, num_point2, num_feature).
+     Returns:
+         distances: mean squared distance from each point in pc1 to its
+         nearest neighbor in pc2.
+     """
+     bs, num_point1, num_features1 = pc1.shape
+     _, num_point2, num_features2 = pc2.shape
+     # tile pc1 along the point axis: (bs, num_point1 * num_point2, C)
+     expanded_pc1 = pc1.repeat(1, num_point2, 1)
+     # (bs, N2, C) -> (bs, N2, 1, C) -> (bs, N2, N1, C) -> (bs, N2 * N1, C)
+     expanded_pc2 = torch.reshape(torch.unsqueeze(pc2, 2).repeat(1, 1, num_point1, 1), (bs, -1, num_features2))
+ 
+     distances = (expanded_pc1 - expanded_pc2) * (expanded_pc1 - expanded_pc2)
+     # distances = torch.sqrt(distances)
+     distances = torch.sum(distances, dim=2)  # pairwise squared distances between every point in pc1 and pc2
+     distances = torch.reshape(distances, (bs, num_point2, num_point1))
+     distances = torch.min(distances, dim=1)[0]  # nearest pc2 neighbor for each pc1 point
+     distances = torch.mean(distances)
+     return distances
+ 
+ def chamfer_distance(pc1, pc2):
+     dist1 = chamfer(pc1, pc2)
+     dist2 = chamfer(pc2, pc1)
+     return torch.mean(dist1) + torch.mean(dist2)
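
For comparison, the same mean-reduced squared Chamfer loss can be computed without materializing the `N1 * N2` point expansion by using `torch.cdist`; a minimal sketch (the name `chamfer_distance_cdist` is illustrative):

```python
import torch

def chamfer_distance_cdist(pc1, pc2):
    # pc1: (bs, N1, C), pc2: (bs, N2, C); same convention as chamfer_distance above
    d2 = torch.cdist(pc1, pc2) ** 2     # (bs, N1, N2) squared pairwise distances
    loss1 = d2.min(dim=2)[0].mean()     # each pc1 point -> nearest pc2 point
    loss2 = d2.min(dim=1)[0].mean()     # each pc2 point -> nearest pc1 point
    return loss1 + loss2
```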


class Encoder(nn.Module): ## Embedding module
@@ -115,6 +148,9 @@ class Attention(nn.Module):
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)

+ # print(q.shape)
+ # print(k.shape)
+ # print((q @ k.transpose(-2, -1)).shape)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
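
The commented-out prints added above trace the attention shapes. For illustration, with made-up sizes (`B`, `N`, `C`, `num_heads` below are arbitrary, not from the model config):

```python
import torch

B, N, C, num_heads = 2, 128, 384, 6    # illustrative sizes only
x = torch.randn(B, N, 3 * C)           # stand-in for the qkv projection output
qkv = x.reshape(B, N, 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]       # each (B, num_heads, N, C // num_heads)
print(q.shape)                          # torch.Size([2, 6, 128, 64])
print(k.shape)                          # torch.Size([2, 6, 128, 64])
print((q @ k.transpose(-2, -1)).shape)  # torch.Size([2, 6, 128, 128])
```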
@@ -366,16 +402,7 @@ class Point_MAE(nn.Module):
trunc_normal_(self.mask_token, std=.02)
self.loss = config.loss
# loss
- self.build_loss_func(self.loss)
- 
- def build_loss_func(self, loss_type):
-     if loss_type == "cdl1":
-         self.loss_func = ChamferDistanceL1().cuda()
-     elif loss_type == 'cdl2':
-         self.loss_func = ChamferDistanceL2().cuda()
-     else:
-         raise NotImplementedError
-     # self.loss_func = emd().cuda()
+ self.loss_func = chamfer_distance


def forward(self, pts, vis = False, **kwargs):
@@ -400,6 +427,7 @@ class Point_MAE(nn.Module):

gt_points = neighborhood[mask].reshape(B*M,-1,3)
loss1 = self.loss_func(rebuild_points, gt_points)
+ print(loss1)  # debug: reconstruction loss on each forward pass

if vis: #visualization
vis_points = neighborhood[~mask].reshape(B * (self.num_group - M), -1, 3)


utils/parser.py (+6 -4)

@@ -7,6 +7,7 @@ def get_args():
parser.add_argument(
'--config',
type = str,
+ default='cfgs/finetune_scan_hardest.yaml',
help = 'yaml config file')
parser.add_argument(
'--launcher',
@@ -31,7 +32,8 @@ def get_args():
parser.add_argument('--exp_name', type = str, default='default', help = 'experiment name')
parser.add_argument('--loss', type=str, default='cd1', help='loss name')
parser.add_argument('--start_ckpts', type = str, default=None, help = 'reload used ckpt path')
- parser.add_argument('--ckpts', type = str, default=None, help = 'test used ckpt path')
+ # parser.add_argument('--ckpts', type = str, default="/tmp/dataset/ckpt-last/ckpt-last.pth", help = 'test used ckpt path')
+ parser.add_argument('--ckpts', type = str, default="/tmp/dataset/pretrain/pretrain.pth", help = 'test used ckpt path')
parser.add_argument('--val_freq', type = int, default=1, help = 'test freq')
parser.add_argument(
'--vote',
@@ -51,7 +53,7 @@ def get_args():
parser.add_argument(
'--finetune_model',
action='store_true',
- default=False,
+ default=True,
help = 'finetune modelnet with pretrained weight')
parser.add_argument(
'--scratch_model',
@@ -95,8 +97,8 @@ def get_args():
args.exp_name = 'test_' + args.exp_name
if args.mode is not None:
args.exp_name = args.exp_name + '_' +args.mode
- args.experiment_path = os.path.join('./experiments', Path(args.config).stem, Path(args.config).parent.stem, args.exp_name)
- args.tfboard_path = os.path.join('./experiments', Path(args.config).stem, Path(args.config).parent.stem, 'TFBoard', args.exp_name)
+ args.experiment_path = os.path.join('/tmp/output/experiments', Path(args.config).stem, Path(args.config).parent.stem, args.exp_name)
+ args.tfboard_path = os.path.join('/tmp/output/experiments', Path(args.config).stem, Path(args.config).parent.stem, 'TFBoard', args.exp_name)
args.log_name = Path(args.config).stem
create_experiment_dir(args)
return args
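
With these defaults, the output locations are fully determined by the config path; a quick check of the `Path` arithmetic (illustrative, using the new default config):

```python
import os
from pathlib import Path

config, exp_name = 'cfgs/finetune_scan_hardest.yaml', 'default'
experiment_path = os.path.join('/tmp/output/experiments', Path(config).stem,
                               Path(config).parent.stem, exp_name)
print(experiment_path)  # /tmp/output/experiments/finetune_scan_hardest/cfgs/default
```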

