PyTorch의 .contiguous()란 무엇인가..?
딥러닝 모델을 PyTorch로 구현하다 보면 x.contiguous()라는 함수를 만나는 경우가 많습니다.
처음에는 “그냥 넣어야 에러가 안 나네?”하고 지나치기 쉬운데, 실제로는 텐서의 메모리 구조와 관련된 중요한 개념입니다.
이 글에서는 .contiguous()가 왜 필요한지, 그리고 그 배경이 되는 stride(보폭) 개념까지 함께 알아보겠습니다.
.contiguous란?
PyTorch에서 .contiguous() 는 텐서를 메모리상에서 연속적인(Contiguous) 형태로 재배열해 주는 함수입니다.
PyTorch의 많은 연산 - 특히 .view() - 은 연속된 메모리 구조를 가진 텐서만 처리할 수 있습니다.
그런데 transpose(), permute() 등을 쓰면 내부 stride가 바뀌어서 비연속적인 텐서가 되기 때문에, .contiguous() 를 써서 메모리 정렬을 다시 맞춰주는 것입니다.
그런데 Stride는 뭐지?
Stride는 각 차원에서 인접한 원소로 이동할 때 메모리에서 건너뛰는 칸 수입니다.
PyTorch 텐서는 1차원 메모리에 저장됩니다. 다차원 텐서를 표현하기 위해 stride 값을 이용해 위치를 계산합니다.
예시:
x = torch.tensor([[1, 2, 3],
[4, 5, 6]])
print(x.stride()) # 출력: (3, 1)
이 의미는:
- dim 0 (행): 한 행 아래로 가려면 3칸을 건너뛰어야 하니까 stride[0] = 3
- dim 1 (열): 한 열 오른쪽으로 가려면 1칸만 이동하면 되니까 stride[1] = 1
즉, 내부 메모리 배치는 [1, 2, 3, 4, 5, 6]입니다.
🔄 transpose() 후에는?
x_t = x.transpose(0, 1)
print(x_t.stride()) # 출력: (1, 3)
이제 메모리 stride가 바뀌었습니다.
하지만 실제 데이터는 복사되지 않았고, stride만 바뀌었기 때문에 비연속 텐서가 됩니다.
❌ .view() 오류 발생
이렇게 transpose 된 텐서에 .view()를 사용하면 오류가 발생할 수 있습니다:
x_t = x.transpose(0, 1)
# 오류 발생 가능
x_t.view(-1)
# RuntimeError: view size is not compatible with input tensor’s size and stride...
.contiguous()로 해결
x_t = x.transpose(0, 1)
x_reshaped = x_t.contiguous().view(-1)
.contiguous()는 메모리를 실제로 복사해서 연속적인 형태로 정렬해 줍니다.
이제 .view()를 안전하게 사용할 수 있죠!
.view() vs .reshape() 차이
.view() | 메모리 순서를 그대로 재사용 | ✅ 필요함 |
.reshape() | 내부적으로 contiguous()를 호출해 필요시 복사 | ❌ 자동 해결됨 |
따라서 안전한 코드는 .reshape()를 사용하는 것이고,
성능을 더 세밀히 다루고 싶다면 .contiguous().view() 조합을 쓰면 됩니다.
실제로 Seq2Seq 코드 구현 과정에서 .contiguous()가 사용됩니다.
def merge_encoder_hiddens(self, encoder_hiddens):
# (n_layers * 2, batch_size, hidden_size // 2) -> (n_layers , batch_size, hidden_size)
h_0_tgt, c_0_tgt = encoder_hiddens
batch_size = h_0_tgt.size(1)
h_0_tgt = h_0_tgt.transpose(0, 1).contiguous().view(batch_size, -1, self.hidden_sze).transpose(0, 1).contiguous()
c_0_tgt = c_0_tgt.transpose(0, 1).contiguous().view(batch_size, -1, self.hidden_sze).transpose(0, 1).contiguous()
return h_0_tgt, c_0_tgt
위 코드는 Seq2Seq의 Encoder hidden state를 decoder에 넣기 위해 크기를 맞추는 코드입니다.
Encoder는 bidirectional LSTM을 통해 구현했기 때문에 이러한 과정이 발생합니다.
마무리
.contiguous()는 단순히 "넣으면 되네" 하고 넘어가기엔 꽤 중요한 개념입니다.
텐서의 메모리 구조와 stride에 대한 이해는, 딥러닝 모델 구현의 안정성과 성능에 직접적인 영향을 줍니다.
앞으로 PyTorch 코드에서 contiguous()를 만나면,
"아, 이건 메모리 정렬이 필요해서 쓰는 거구나!" 하고 이해할 수 있게 되셨길 바랍니다 🙌
읽어주셔서 감사합니다!!