백준 11659번: 구간 합 구하기 4 [실버3] - Python
https://www.acmicpc.net/problem/11659
11659번: 구간 합 구하기 4
첫째 줄에 수의 개수 N과 합을 구해야 하는 횟수 M이 주어진다. 둘째 줄에는 N개의 수가 주어진다. 수는 1,000보다 작거나 같은 자연수이다. 셋째 줄부터 M개의 줄에는 합을 구해야 하는 구간 i와 j
www.acmicpc.net
풀이
처음 문제를 보자마자 든 생각은 '너무 쉬운데 이게 왜 실버3이지?' 라는 생각이었네요. 그냥 리스트 슬라이스 쓰면 되는 거 아닌가 하구요. 하지만 수의 개수 N도 최대 100,000이고, 합을 구해야 하는 개수 M도 최대 100,000이기 때문에 그 방법으로는 100,000 x 100,000 크기의 연산을 해야 하기 때문에 시간 초과가 발생하게 됩니다.. (항상 조건 안 보고 시간초과 나고 나서야 조건 확인하는 나란 녀석..)
시간 초과 코드
n, m = map(int, input().split())
nums = list(map(int, input().split()))
for _ in range(m):
i, j = map(int, input().split())
print(sum(nums[i-1:j]))
그럼 이 문제는 어떻게 풀어야 할까요? 누적 합이라는 알고리즘을 쓰면 됩니다.
누적 합(Prefix Sum) 알고리즘
이것은 무엇이냐. 어떤 수열에서 각각의 인덱스까지 구간의 합을 구하는 것을 의미합니다. 어떠한 리스트 arr
가 있다고 했을 때 이 리스트의 누적 합을 prefixSum
이라고 한다면, prefixSum[i]
는 arr[0] + arr[1] + ... + arr[i]
를 의미합니다.
그럼 이 누적 합 알고리즘을 통해서 이 문제를 어떻게 풀면 될까요? 이 문제는 어떠한 수열의 구간 합을 구하는 문제이니, 누적 합에서 누적 합을 빼면 됩니다.
무슨 말이냐면 어떤 리스트 arr
가 있다고 했을 때 이 리스트의 i
인덱스 부터 j
인덱스 까지의 합이 j
까지의 누적합과 i-1
까지의 누적합의 차라는 뜻입니다. 예를 들어 인덱스 2 ~ 4의 구간의 합은 0 ~ 4의 합에서 0 ~ 1의 합을 뺀 것과 같겠죠.
코드로 표현하면 다음과 같습니다,
sum(arr[i : j]) == prefixSum[j] - prefixSum[i-1]
정답 코드
import sys
input = sys.stdin.readline
n, m = map(int, input().split())
nums = list(map(int, input().split()))
# 누적 합 계산
prefixSum = [0]*(n+1)
prefixSum[1] = nums[0]
for k in range(2, n+1):
prefixSum[k] = prefixSum[k-1] + nums[k-1]
# 누적 합으로 구간 합 구하기
for _ in range(m):
i, j = map(int, input().split())
print(prefixSum[j] - prefixSum[i-1])
누적 합 알고리즘을 사용하지 않고 하나하나 모두 계산한다면 시간복잡도는 O(NM)이 됩니다. M개의 구간합 계산에 대하여 O(N)의 시간복잡도를 가지는 슬라이스를 각각 해야 하기 때문이죠. N과 M이 모두 최대 100,000이기 때문에 이 방법으로는 시간복잡도가 너무 높아져버립니다..!
하지만 누적 합 알고리즘을 사용한다면, 누적 합 계산 부분에서 O(N)의 시간이 소요되고, 구간 합 계산 부분에서 O(M)의 시간이 소요됩니다. 결과적으로는 O(N+M)의 시간복잡도가 되는 것이죠.
참고로 O(N+M)은 시간복잡도가 N과 M의 크기에 동시에 영향을 받는다는 것이 아니라 N과 M 중 더 큰 값에 따라 시간복잡도가 결정된다는 뜻입니다. 즉 이 코드의 시간복잡도는 N>M일 경우 O(N), M>N일 경우 O(M)이 됩니다.
여담
참고로 이 문제도 입력을 많이 받아야 하기 때문에 sys.stdin.readline()
의 사용이 필수적입니다. 그냥 input()
을 통해 입력을 받으면 누적 합 알고리즘을 사용했더라도 시간 초과가 뜨게 됩니다.
그런데 다른 분들의 코드를 보던 중 PyPy 코드는 sys를 쓰지 않고도 정답 처리가 된 것을 발견했습니다?!
처음엔 '대체 어떤 코드길래 4836ms나 걸리지?'라고 생각했는데요, sys를 안 썼더군요.. python3에서는 sys를 쓰지 않으면 얄짤없이 시간 초과인 반면 PyPy3은 sys를 쓰지 않아도 오래 걸리긴 해도 정답 처리를 받는 모습이었습니다. 정말 PyPy가 빠르긴 한가 보군요..
PyPy가 이렇게 빠르구나 라는걸 알게 해 준 시간이었습니다. (중요한 건 아니지만 신기하군요)