PS

백준 11659번: 구간 합 구하기 4 [실버3] - Python

Alsong 2023. 11. 22. 20:11

https://www.acmicpc.net/problem/11659

 

11659번: 구간 합 구하기 4

첫째 줄에 수의 개수 N과 합을 구해야 하는 횟수 M이 주어진다. 둘째 줄에는 N개의 수가 주어진다. 수는 1,000보다 작거나 같은 자연수이다. 셋째 줄부터 M개의 줄에는 합을 구해야 하는 구간 i와 j

www.acmicpc.net

 

풀이

처음 문제를 보자마자 든 생각은 '너무 쉬운데 이게 왜 실버3이지?' 라는 생각이었네요. 그냥 리스트 슬라이스 쓰면 되는 거 아닌가 하구요. 하지만 수의 개수 N도 최대 100,000이고, 합을 구해야 하는 개수 M도 최대 100,000이기 때문에 그 방법으로는 100,000 x 100,000 크기의 연산을 해야 하기 때문에 시간 초과가 발생하게 됩니다.. (항상 조건 안 보고 시간초과 나고 나서야 조건 확인하는 나란 녀석..)

 

시간 초과 코드

n, m = map(int, input().split())
nums = list(map(int, input().split()))

for _ in range(m):
    i, j = map(int, input().split())
    print(sum(nums[i-1:j]))

 

그럼 이 문제는 어떻게 풀어야 할까요? 누적 합이라는 알고리즘을 쓰면 됩니다. 

 

 

누적 합(Prefix Sum) 알고리즘

이것은 무엇이냐. 어떤 수열에서 각각의 인덱스까지 구간의 합을 구하는 것을 의미합니다. 어떠한 리스트 arr가 있다고 했을 때 이 리스트의 누적 합을 prefixSum이라고 한다면, prefixSum[i]arr[0] + arr[1] + ... + arr[i]를 의미합니다.

 

그럼 이 누적 합 알고리즘을 통해서 이 문제를 어떻게 풀면 될까요? 이 문제는 어떠한 수열의 구간 합을 구하는 문제이니, 누적 합에서 누적 합을 빼면 됩니다.

무슨 말이냐면 어떤 리스트 arr가 있다고 했을 때 이 리스트의 i인덱스 부터 j인덱스 까지의 합이 j까지의 누적합과 i-1까지의 누적합의 차라는 뜻입니다. 예를 들어 인덱스 2 ~ 4의 구간의 합은 0 ~ 4의 합에서 0 ~ 1의 합을 뺀 것과 같겠죠.

 

코드로 표현하면 다음과 같습니다,

sum(arr[i : j]) == prefixSum[j] - prefixSum[i-1]

 

 

정답 코드

import sys
input = sys.stdin.readline
n, m = map(int, input().split())
nums = list(map(int, input().split()))

# 누적 합 계산
prefixSum = [0]*(n+1)
prefixSum[1] = nums[0]
for k in range(2, n+1):
    prefixSum[k] = prefixSum[k-1] + nums[k-1]

# 누적 합으로 구간 합 구하기
for _ in range(m):
    i, j = map(int, input().split())
    print(prefixSum[j] - prefixSum[i-1])

 

누적 합 알고리즘을 사용하지 않고 하나하나 모두 계산한다면 시간복잡도는 O(NM)이 됩니다. M개의 구간합 계산에 대하여 O(N)의 시간복잡도를 가지는 슬라이스를 각각 해야 하기 때문이죠. N과 M이 모두 최대 100,000이기 때문에 이 방법으로는 시간복잡도가 너무 높아져버립니다..!

 

하지만 누적 합 알고리즘을 사용한다면, 누적 합 계산 부분에서 O(N)의 시간이 소요되고, 구간 합 계산 부분에서 O(M)의 시간이 소요됩니다. 결과적으로는 O(N+M)의 시간복잡도가 되는 것이죠.

참고로 O(N+M)은 시간복잡도가 N과 M의 크기에 동시에 영향을 받는다는 것이 아니라 N과 M 중 더 큰 값에 따라 시간복잡도가 결정된다는 뜻입니다. 즉 이 코드의 시간복잡도는 N>M일 경우 O(N), M>N일 경우 O(M)이 됩니다. 

 

 

여담

참고로 이 문제도 입력을 많이 받아야 하기 때문에 sys.stdin.readline()의 사용이 필수적입니다. 그냥 input()을 통해 입력을 받으면 누적 합 알고리즘을 사용했더라도 시간 초과가 뜨게 됩니다. 

 

그런데 다른 분들의 코드를 보던 중 PyPy 코드는 sys를 쓰지 않고도 정답 처리가 된 것을 발견했습니다?!

 

이것이 PyPy의 위엄인가..

 

처음엔 '대체 어떤 코드길래 4836ms나 걸리지?'라고 생각했는데요, sys를 안 썼더군요.. python3에서는 sys를 쓰지 않으면 얄짤없이 시간 초과인 반면 PyPy3은 sys를 쓰지 않아도 오래 걸리긴 해도 정답 처리를 받는 모습이었습니다. 정말 PyPy가 빠르긴 한가 보군요..

 

Python3 / sys 사용안함

 

PyPy3 / sys 사용안함

 

Python3 / sys 사용

 

PyPy3 / sys 사용

 

PyPy가 이렇게 빠르구나 라는걸 알게 해 준 시간이었습니다. (중요한 건 아니지만 신기하군요)