안녕하세요.

텐서 슬라이싱 Tensor Slicing

오늘 소개해드릴 내용은 pytorch에서 텐서를 슬라이싱하는 방법입니다.

tensor를 조작하는 것은 아주 중요하니 예시와 함께 하나하나 설명해드리겠습니다.


왜 tesnor slicing이 필요할까?

tensor slicing은 모델을 구현하는 과정에서 많이 사용됩니다. 2차원 matrix라면 중학교 때 행렬을 공부했으니 꽤 익숙하고 어느 정도 구조가 상상이 갑니다. 하지만 3차원, 4차원으로 넘어가면.. 헷갈리기 시작하죠.

예를 하나 들어보겠습니다. [batch_size, sequence_length, embedding_dimension]으로 구성된 3차원 matrix A이 있고 모델의 크기를 줄이려 sequence_length를 반토막 내야한다고 가정하겠습니다. (앞쪽 반토막만 사용) 그럼 어떻게 해야할까요? A의 shape을 (3,4,5)로 설정합니다. 최종 목표는 shape이 (3,2,5)이고 1 dim (sequence_length)의 윗쪽 반만 남아있어야합니다.

먼저 for문으로 한번 해보겠습니다.

[Code]

A = torch.arange(3*4*5).view(3,4,5)
print(A)
torch.stack([A[:,i,:] for i in range(A.size(1)) if i < A.size(1)//2], 0).transpose(0,1).view(3,2,5)

[Output]

tensor([[[ 0,  1,  2,  3,  4],
         [ 5,  6,  7,  8,  9],
         [10, 11, 12, 13, 14],
         [15, 16, 17, 18, 19]],

        [[20, 21, 22, 23, 24],
         [25, 26, 27, 28, 29],
         [30, 31, 32, 33, 34],
         [35, 36, 37, 38, 39]],

        [[40, 41, 42, 43, 44],
         [45, 46, 47, 48, 49],
         [50, 51, 52, 53, 54],
         [55, 56, 57, 58, 59]]])
tensor([[[ 0,  1,  2,  3,  4],
         [ 5,  6,  7,  8,  9]],

        [[20, 21, 22, 23, 24],
         [25, 26, 27, 28, 29]],

        [[40, 41, 42, 43, 44],
         [45, 46, 47, 48, 49]]])

for문으로 하려니 매우 길고 더욱 복잡하네요.. 사실 이렇게 짜본적이 없어서 이게 for문으로 할 수 있는 최선인지는 잘 모르겠습니다. 하지만 목표는 달성했습니다.

slicing을 이용해봅시다.

[Code]

A[:,:A.size(1)//2,:]

[Output]

tensor([[[ 0,  1,  2,  3,  4],
         [ 5,  6,  7,  8,  9]],

        [[20, 21, 22, 23, 24],
         [25, 26, 27, 28, 29]],

        [[40, 41, 42, 43, 44],
         [45, 46, 47, 48, 49]]])

for문을 이용하는 것에 비해 아주 깔끔하고 명쾌합니다. 가독성도 좋죠. 이 예시는 꼼꼼하게 보실 필요 없습니다. slicing이 깔끔하고 가독성도 좋구나, 텐서 조작할 때 이걸 써야겠다 라는 느낌을 받으셨다면 좋겠습니다.

설명에 필요한 tensor.shape, tensor.size에 대해

[Code]

A = torch.arange(3*4*5).view(3,4,5)
A.shape, A.size(0)

[Output]

(torch.Size([3, 4, 5]), 3)

텐서의 size를 출력해주는 내장함수입니다.


그럼 본격적으로 tensor에서 어떤 걸 뽑아낼지(목표), slicing으로 구현 시작해보겠습니다. 

목표1 : [3,4] shape의 2차원 매트릭스에서 (0~2) 열만 추출하라

[Code]

A = torch.arange(3*4).view(3,4)
print(A)
A[:, 0:3]

[Output]

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
tensor([[ 0,  1,  2],
        [ 4,  5,  6],
        [ 8,  9, 10]])

목표2 : [3,4] shape의 2차원 매트릭스에서 1, 3 열만 추출하라

[Code]

A = torch.arange(3*4).view(3,4)
print(A)
A[:, [1,3]]

[Output]

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
tensor([[ 1,  3],
        [ 5,  7],
        [ 9, 11]])

목표3: [3,4] shape의 2차원 매트릭스에서 1,2 열의 순서를 바꾸어라

[Code]

A = torch.arange(3*4).view(3,4)
print(A)
A[:, [0,2,1,3]]

[Output]

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
tensor([[ 0,  2,  1,  3],
        [ 4,  6,  5,  7],
        [ 8, 10,  9, 11]])

목표4: [3,4] shape의 2차원 매트릭스에서 각 행에서 3,2,1 열의 원소를 뽑아내라 -> [3,6,9] 가 나와야 합니다.

[Code]

A = torch.arange(3*4).view(3,4)
print(A)
A[torch.arange(A.size(0)), [3,2,1]]

[Output]

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
tensor([3, 6, 9])

목표5: [3,4,5] shape의 3차원 매트릭스에서 1dim의 마지막 row만 뽑아내라

[Code]

A = torch.arange(3*4*5).view(3,4,5)
print(A)
A[:,3] # == A[:,3,:]

[Output]

tensor([[[ 0,  1,  2,  3,  4],
         [ 5,  6,  7,  8,  9],
         [10, 11, 12, 13, 14],
         [15, 16, 17, 18, 19]],

        [[20, 21, 22, 23, 24],
         [25, 26, 27, 28, 29],
         [30, 31, 32, 33, 34],
         [35, 36, 37, 38, 39]],

        [[40, 41, 42, 43, 44],
         [45, 46, 47, 48, 49],
         [50, 51, 52, 53, 54],
         [55, 56, 57, 58, 59]]])
tensor([[15, 16, 17, 18, 19],
        [35, 36, 37, 38, 39],
        [55, 56, 57, 58, 59]])

목표6: [3,4,5] shape의 3차원 매트릭스에서 2dim의 마지막 원소만 뽑아내라

[Code]

A = torch.arange(3*4*5).view(3,4,5)
print(A)
A[:,-1,-1]

[Output]

tensor([[[ 0,  1,  2,  3,  4],
         [ 5,  6,  7,  8,  9],
         [10, 11, 12, 13, 14],
         [15, 16, 17, 18, 19]],

        [[20, 21, 22, 23, 24],
         [25, 26, 27, 28, 29],
         [30, 31, 32, 33, 34],
         [35, 36, 37, 38, 39]],

        [[40, 41, 42, 43, 44],
         [45, 46, 47, 48, 49],
         [50, 51, 52, 53, 54],
         [55, 56, 57, 58, 59]]])
tensor([19, 39, 59])

목표7: [3,4,5] shape의 3차원 매트릭스에서 1dim의 1,3 번째 row의 2번째 원소만 뽑아내라

[Code]

A = torch.arange(3*4*5).view(3,4,5)
print(A)
A[:,[1,3],2]

[Output]

tensor([[[ 0,  1,  2,  3,  4],
         [ 5,  6,  7,  8,  9],
         [10, 11, 12, 13, 14],
         [15, 16, 17, 18, 19]],

        [[20, 21, 22, 23, 24],
         [25, 26, 27, 28, 29],
         [30, 31, 32, 33, 34],
         [35, 36, 37, 38, 39]],

        [[40, 41, 42, 43, 44],
         [45, 46, 47, 48, 49],
         [50, 51, 52, 53, 54],
         [55, 56, 57, 58, 59]]])
tensor([[ 7, 17],
        [27, 37],
        [47, 57]])

목표8: [3,4,5] shape의 3차원 매트릭스에서 2dim의 0, 4번째 원소들을 전부 뽑아내라

[Code]

A = torch.arange(3*4*5).view(3,4,5)
print(A)
A[:,:,[0,-1]]

[Output]


tensor([[[ 0,  1,  2,  3,  4],
         [ 5,  6,  7,  8,  9],
         [10, 11, 12, 13, 14],
         [15, 16, 17, 18, 19]],

        [[20, 21, 22, 23, 24],
         [25, 26, 27, 28, 29],
         [30, 31, 32, 33, 34],
         [35, 36, 37, 38, 39]],

        [[40, 41, 42, 43, 44],
         [45, 46, 47, 48, 49],
         [50, 51, 52, 53, 54],
         [55, 56, 57, 58, 59]]])
tensor([[[ 0,  4],
         [ 5,  9],
         [10, 14],
         [15, 19]],

        [[20, 24],
         [25, 29],
         [30, 34],
         [35, 39]],

        [[40, 44],
         [45, 49],
         [50, 54],
         [55, 59]]])

목표9: [2,2,3,3] shape의 4차원 매트릭스에서 2dim의 1low의 1번째 원소들을 전부 뽑아내라

[Code]

A = torch.arange(2*2*3*3).view(2,2,3,3)
print(A)
A[:,:,1,1]

[Output]

tensor([[[[ 0,  1,  2],
          [ 3,  4,  5],
          [ 6,  7,  8]],

         [[ 9, 10, 11],
          [12, 13, 14],
          [15, 16, 17]]],


        [[[18, 19, 20],
          [21, 22, 23],
          [24, 25, 26]],

         [[27, 28, 29],
          [30, 31, 32],
          [33, 34, 35]]]])
tensor([[ 4, 13],
        [22, 31]])

목표10: [2,2,3,3] shape의 4차원 매트릭스에서 2dim에서 [1,2,0,1]에 해당하는 row를 뽑아내라

[Code]

A = torch.arange(2*2*3*3).view(2,2,3,3)
print(A)
torch.stack([A[0,torch.arange(A.size(1)), [1,2]], A[1,torch.arange(A.size(1)), [0,1]]])

[Output]

tensor([[[[ 0,  1,  2],
          [ 3,  4,  5],
          [ 6,  7,  8]],

         [[ 9, 10, 11],
          [12, 13, 14],
          [15, 16, 17]]],


        [[[18, 19, 20],
          [21, 22, 23],
          [24, 25, 26]],

         [[27, 28, 29],
          [30, 31, 32],
          [33, 34, 35]]]])
tensor([[[ 3,  4,  5],
         [15, 16, 17]],

        [[18, 19, 20],
         [30, 31, 32]]])

이상 pytorch에서 tensor slicing에 대한 설명과 예시였습니다.

3차원, 4차원으로 넘어가면.. 머리 속으로 할 수 있는 분은 몇 없을 거라 생각합니다. 실제 모델에 사용되는 tensor들의 크기는 매우 크니 크기를 좀 줄여서 test해본 후 코드에 적용하는 것이 답이라 생각합니다. 그리고 더블 체크 개념으로 모델이 training과정에 exit()을 넣고 제대로 잘 되고 있나, 코드에서 받아온 tesnor와 실제 데이터와 비교해보는 것은 너무 좋습니다!

감사합니다.