[Pytorch] 텐서 슬라이싱 Tensor Slicing
by Jeonghyeok Park
안녕하세요.
텐서 슬라이싱 Tensor Slicing
오늘 소개해드릴 내용은 pytorch에서 텐서를 슬라이싱하는 방법입니다.
tensor를 조작하는 것은 아주 중요하니 예시와 함께 하나하나 설명해드리겠습니다.
왜 tesnor slicing이 필요할까?
tensor slicing은 모델을 구현하는 과정에서 많이 사용됩니다. 2차원 matrix라면 중학교 때 행렬을 공부했으니 꽤 익숙하고 어느 정도 구조가 상상이 갑니다. 하지만 3차원, 4차원으로 넘어가면.. 헷갈리기 시작하죠.
예를 하나 들어보겠습니다. [batch_size, sequence_length, embedding_dimension]으로 구성된 3차원 matrix A이 있고 모델의 크기를 줄이려 sequence_length를 반토막 내야한다고 가정하겠습니다. (앞쪽 반토막만 사용) 그럼 어떻게 해야할까요? A의 shape을 (3,4,5)로 설정합니다. 최종 목표는 shape이 (3,2,5)이고 1 dim (sequence_length)의 윗쪽 반만 남아있어야합니다.
먼저 for문으로 한번 해보겠습니다.
[Code]
A = torch.arange(3*4*5).view(3,4,5)
print(A)
torch.stack([A[:,i,:] for i in range(A.size(1)) if i < A.size(1)//2], 0).transpose(0,1).view(3,2,5)
[Output]
tensor([[[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19]],
[[20, 21, 22, 23, 24],
[25, 26, 27, 28, 29],
[30, 31, 32, 33, 34],
[35, 36, 37, 38, 39]],
[[40, 41, 42, 43, 44],
[45, 46, 47, 48, 49],
[50, 51, 52, 53, 54],
[55, 56, 57, 58, 59]]])
tensor([[[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9]],
[[20, 21, 22, 23, 24],
[25, 26, 27, 28, 29]],
[[40, 41, 42, 43, 44],
[45, 46, 47, 48, 49]]])
for문으로 하려니 매우 길고 더욱 복잡하네요.. 사실 이렇게 짜본적이 없어서 이게 for문으로 할 수 있는 최선인지는 잘 모르겠습니다. 하지만 목표는 달성했습니다.
slicing을 이용해봅시다.
[Code]
A[:,:A.size(1)//2,:]
[Output]
tensor([[[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9]],
[[20, 21, 22, 23, 24],
[25, 26, 27, 28, 29]],
[[40, 41, 42, 43, 44],
[45, 46, 47, 48, 49]]])
for문을 이용하는 것에 비해 아주 깔끔하고 명쾌합니다. 가독성도 좋죠. 이 예시는 꼼꼼하게 보실 필요 없습니다. slicing이 깔끔하고 가독성도 좋구나, 텐서 조작할 때 이걸 써야겠다 라는 느낌을 받으셨다면 좋겠습니다.
설명에 필요한 tensor.shape, tensor.size에 대해
[Code]
A = torch.arange(3*4*5).view(3,4,5)
A.shape, A.size(0)
[Output]
(torch.Size([3, 4, 5]), 3)
텐서의 size를 출력해주는 내장함수입니다.
그럼 본격적으로 tensor에서 어떤 걸 뽑아낼지(목표), slicing으로 구현 시작해보겠습니다.
목표1 : [3,4] shape의 2차원 매트릭스에서 (0~2) 열만 추출하라
[Code]
A = torch.arange(3*4).view(3,4)
print(A)
A[:, 0:3]
[Output]
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
tensor([[ 0, 1, 2],
[ 4, 5, 6],
[ 8, 9, 10]])
목표2 : [3,4] shape의 2차원 매트릭스에서 1, 3 열만 추출하라
[Code]
A = torch.arange(3*4).view(3,4)
print(A)
A[:, [1,3]]
[Output]
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
tensor([[ 1, 3],
[ 5, 7],
[ 9, 11]])
목표3: [3,4] shape의 2차원 매트릭스에서 1,2 열의 순서를 바꾸어라
[Code]
A = torch.arange(3*4).view(3,4)
print(A)
A[:, [0,2,1,3]]
[Output]
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
tensor([[ 0, 2, 1, 3],
[ 4, 6, 5, 7],
[ 8, 10, 9, 11]])
목표4: [3,4] shape의 2차원 매트릭스에서 각 행에서 3,2,1 열의 원소를 뽑아내라 -> [3,6,9] 가 나와야 합니다.
[Code]
A = torch.arange(3*4).view(3,4)
print(A)
A[torch.arange(A.size(0)), [3,2,1]]
[Output]
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
tensor([3, 6, 9])
목표5: [3,4,5] shape의 3차원 매트릭스에서 1dim의 마지막 row만 뽑아내라
[Code]
A = torch.arange(3*4*5).view(3,4,5)
print(A)
A[:,3] # == A[:,3,:]
[Output]
tensor([[[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19]],
[[20, 21, 22, 23, 24],
[25, 26, 27, 28, 29],
[30, 31, 32, 33, 34],
[35, 36, 37, 38, 39]],
[[40, 41, 42, 43, 44],
[45, 46, 47, 48, 49],
[50, 51, 52, 53, 54],
[55, 56, 57, 58, 59]]])
tensor([[15, 16, 17, 18, 19],
[35, 36, 37, 38, 39],
[55, 56, 57, 58, 59]])
목표6: [3,4,5] shape의 3차원 매트릭스에서 2dim의 마지막 원소만 뽑아내라
[Code]
A = torch.arange(3*4*5).view(3,4,5)
print(A)
A[:,-1,-1]
[Output]
tensor([[[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19]],
[[20, 21, 22, 23, 24],
[25, 26, 27, 28, 29],
[30, 31, 32, 33, 34],
[35, 36, 37, 38, 39]],
[[40, 41, 42, 43, 44],
[45, 46, 47, 48, 49],
[50, 51, 52, 53, 54],
[55, 56, 57, 58, 59]]])
tensor([19, 39, 59])
목표7: [3,4,5] shape의 3차원 매트릭스에서 1dim의 1,3 번째 row의 2번째 원소만 뽑아내라
[Code]
A = torch.arange(3*4*5).view(3,4,5)
print(A)
A[:,[1,3],2]
[Output]
tensor([[[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19]],
[[20, 21, 22, 23, 24],
[25, 26, 27, 28, 29],
[30, 31, 32, 33, 34],
[35, 36, 37, 38, 39]],
[[40, 41, 42, 43, 44],
[45, 46, 47, 48, 49],
[50, 51, 52, 53, 54],
[55, 56, 57, 58, 59]]])
tensor([[ 7, 17],
[27, 37],
[47, 57]])
목표8: [3,4,5] shape의 3차원 매트릭스에서 2dim의 0, 4번째 원소들을 전부 뽑아내라
[Code]
A = torch.arange(3*4*5).view(3,4,5)
print(A)
A[:,:,[0,-1]]
[Output]
tensor([[[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19]],
[[20, 21, 22, 23, 24],
[25, 26, 27, 28, 29],
[30, 31, 32, 33, 34],
[35, 36, 37, 38, 39]],
[[40, 41, 42, 43, 44],
[45, 46, 47, 48, 49],
[50, 51, 52, 53, 54],
[55, 56, 57, 58, 59]]])
tensor([[[ 0, 4],
[ 5, 9],
[10, 14],
[15, 19]],
[[20, 24],
[25, 29],
[30, 34],
[35, 39]],
[[40, 44],
[45, 49],
[50, 54],
[55, 59]]])
목표9: [2,2,3,3] shape의 4차원 매트릭스에서 2dim의 1low의 1번째 원소들을 전부 뽑아내라
[Code]
A = torch.arange(2*2*3*3).view(2,2,3,3)
print(A)
A[:,:,1,1]
[Output]
tensor([[[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8]],
[[ 9, 10, 11],
[12, 13, 14],
[15, 16, 17]]],
[[[18, 19, 20],
[21, 22, 23],
[24, 25, 26]],
[[27, 28, 29],
[30, 31, 32],
[33, 34, 35]]]])
tensor([[ 4, 13],
[22, 31]])
목표10: [2,2,3,3] shape의 4차원 매트릭스에서 2dim에서 [1,2,0,1]에 해당하는 row를 뽑아내라
[Code]
A = torch.arange(2*2*3*3).view(2,2,3,3)
print(A)
torch.stack([A[0,torch.arange(A.size(1)), [1,2]], A[1,torch.arange(A.size(1)), [0,1]]])
[Output]
tensor([[[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8]],
[[ 9, 10, 11],
[12, 13, 14],
[15, 16, 17]]],
[[[18, 19, 20],
[21, 22, 23],
[24, 25, 26]],
[[27, 28, 29],
[30, 31, 32],
[33, 34, 35]]]])
tensor([[[ 3, 4, 5],
[15, 16, 17]],
[[18, 19, 20],
[30, 31, 32]]])
이상 pytorch에서 tensor slicing에 대한 설명과 예시였습니다.
3차원, 4차원으로 넘어가면.. 머리 속으로 할 수 있는 분은 몇 없을 거라 생각합니다. 실제 모델에 사용되는 tensor들의 크기는 매우 크니 크기를 좀 줄여서 test해본 후 코드에 적용하는 것이 답이라 생각합니다. 그리고 더블 체크 개념으로 모델이 training과정에 exit()을 넣고 제대로 잘 되고 있나, 코드에서 받아온 tesnor와 실제 데이터와 비교해보는 것은 너무 좋습니다!
감사합니다.
Subscribe via RSS