From 36ad2d2b6f3830fa00a61ce1f0a9bf2567f5aa28 Mon Sep 17 00:00:00 2001 From: renqingquan <2944007663@qq.com> Date: Sat, 3 Dec 2022 22:56:45 +0800 Subject: [PATCH] =?UTF-8?q?=E4=B8=8A=E4=BC=A0=E6=96=87=E4=BB=B6=E8=87=B3?= =?UTF-8?q?=20''?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- learn_conv1d.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 learn_conv1d.py diff --git a/learn_conv1d.py b/learn_conv1d.py new file mode 100644 index 0000000..30be6bf --- /dev/null +++ b/learn_conv1d.py @@ -0,0 +1,31 @@ +# Created on 2018/12/10 +# Author: Kaituo XU + +import torch +torch.manual_seed(123) + +Cin, Cout, F, S = 4, 3, 2, 1 +B, Lin = 2, 5 +D = 2 +print('Cin, Cout, F, S, B, Lin, D=', Cin, Cout, F, S, B, Lin, D) +conv1d = torch.nn.Conv1d(Cin, Cin, F, S, bias=False, padding=(F-1)*D, dilation=D, groups=Cin) +#conv1d = torch.nn.Conv1d(Cin, Cin, F, S, bias=False, padding=1, dilation=1, groups=Cin) +inputs = torch.randint(3, (B, Cin, Lin)) +conv1d.weight.data = torch.randint(5, conv1d.weight.size()) +outputs = conv1d(inputs) +# Lout = (Lin - F) / S + 1 + +print('weight', conv1d.weight.size()) +print('inputs', inputs.size()) +print('outputs', outputs.size()) + +print('inputs\n', inputs) +print('weight\n', conv1d.weight) +print('outputs\n', outputs) +print('chomp outputs\n', outputs[:,:,:-(F-1)*D]) + +# m = torch.nn.Conv1d(16, 33, 3, stride=2) +# print(m.weight.size()) +# input = torch.randn(20, 16, 50) +# output = m(input) +# print(output.size()) \ No newline at end of file -- 2.34.1