IT Study/코딩테스트 by Python

[백준] 2751번 수 정렬하기2(정렬)_python (+시도과정, 예시답안)

짹짹체유 2023. 8. 8. 17:23

수 정렬하기2

분류: 정렬

문제

N개의 수가 주어졌을 때, 이를 오름차순으로 정렬하는 프로그램을 작성하시오.

 

입력

첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)이 주어진다. 둘째 줄부터 N개의 줄에는 수가 주어진다. 이 수는 절댓값이 1,000,000보다 작거나 같은 정수이다. 수는 중복되지 않는다.

 

출력

첫째 줄부터 N개의 줄에 오름차순으로 정렬한 결과를 한 줄에 하나씩 출력한다.

 

예제 입력 1

5
5
4
3
2
1

예제 출력 1

1
2
3
4
5

 


💡문제 해결 IDEA

1) 수는 중복되지 않기 때문에 계수 정렬은 x

2) 퀵 정렬로 시도 -> 메모리 초과

3) 삽입 정렬로 시도 -> 시간 초과

4) 계수 정렬로 시도 -> 걍 틀림....

5) 퀵 정렬을 다시 정리 -> 약 25%까지 가다가 메모리 초과

 

#1차 시도 - 퀵 정렬
import sys
N = int(input())
n_list = []
for _ in range(N):
    n_list.append(int(sys.stdin.readline()))

def quick(list):
    if len(list) <= 1:
        return list
    pivot = list[0]
    r_list = list[1:]
    left_list = [x for x in r_list if x <= pivot]
    right_list = [x for x in r_list if x > pivot]
    return quick(left_list) + [pivot] + quick(right_list)

for num in quick(n_list):
    print(num)

메모리 초과

-> 퀵 정렬은 최악의 경우 O(N^2)

 

#2차 시도 - 삽입정렬
import sys
N = int(input())
n_list = [int(sys.stdin.readline().strip()) for _ in range(N)]

for i in range(1, len(n_list)):
    for j in range(i, 0, -1):
        if n_list[j] < n_list[j-1]:
            n_list[j], n_list[j-1] = n_list[j-1], n_list[j]
        else:
            break

for num in n_list:
    print(num)

시간 초과

-> 삽입 정렬은 O(N^2)

 

 

# 3차 시도 - 계수 정렬
import sys
N = int(input())
n_list = [int(sys.stdin.readline().strip()) for _ in range(N)]

cnt_list = [0] * (N+1)

for i in n_list:
    cnt_list[i] = 1

for n in range(len(cnt_list)):
    if cnt_list[n] == 1:
        print(n)

걍 틀림

 

필독! FAQ

1. O(N^2)짜리 정렬 알고리즘을 사용하면 시간 초과 ex. 버블 정렬, 선택 정렬, 삽입 정렬

 -> O(NlogN) 이하의 복잡도를 갖는 정렬을 사용해야 함 ex. 병합 정렬, 힙 정렬

2. 퀵 정렬은 최악의 경우 O(N^2)임

 -> 피벗으로 중앙값의 중앙값 고르기, 재귀 깊어지면 다른 정렬 사용하기, 랜덤으로 섞은 뒤 수행하기 등으로 회피할 수 있음.

 -> 연습하기 위한 목적으로만 사용하고 어떤 알고리즘 문제에도 사용하지 않는 것이 좋음.

3. 힙 정렬은 복잡한 편. 힙 정렬이 요구하는 것이 무엇인지 정확히 알고 사용해야 함.

4. Python은 매우 느림. Pypy2나 PyPy3로 제출해볼 것

 

https://www.acmicpc.net/board/view/31887

 

글 읽기 - ★☆★☆★ [필독] 수 정렬하기 2 FAQ ★☆★☆★

댓글을 작성하려면 로그인해야 합니다.

www.acmicpc.net

 

# 4차 시도 - 병합 정렬
import sys
N = int(input())
n_list = [int(sys.stdin.readline().strip()) for _ in range(N)]

def merge_sort(arr):
    if len(arr) < 2:
        return arr
    
    mid = len(arr) // 2
    low_arr = merge_sort(arr[:mid])
    high_arr = merge_sort(arr[mid:])

    merge_arr = []
    i = j = 0
    while (i < len(low_arr)) and (j < len(high_arr)):
        if low_arr[i] <= high_arr[j]:
            merge_arr.append(low_arr[i])
            i += 1
        else:
            merge_arr.append(high_arr[j])
            j += 1
    merge_arr += low_arr[i:]
    merge_arr += high_arr[j:]

    return merge_arr

merge_list = merge_sort(n_list)
for num in merge_list:
    print(num)

정답

메모리: 100124KB, 시간: 4656ms

 

# 최적화
import sys
N = int(input())
n_list = [int(sys.stdin.readline().strip()) for _ in range(N)]

def merge_sort(arr):
    def sort(low, high):
        if high - low < 2:
            return arr
        mid = (low + high) // 2
        sort(low, mid)
        sort(mid, high)
        result = merge(low, mid, high)
        return result
        
    def merge(low, mid, high):
        temp = []
        i, j = low, mid

        while i < mid and j < high:
            if arr[i] < arr[j]:
                temp.append(arr[i])
                i += 1
            else:
                temp.append(arr[j])
                j += 1

        while i < mid:
            temp.append(arr[i])
            i += 1
        while j < high:
            temp.append(arr[j])
            j += 1
        
        for k in range(low, high):
            arr[k] = temp[k-low]
        return temp
    
    return sort(0, len(arr))

merge_list = merge_sort(n_list)
for num in merge_list:
    print(num)

정답

-> 인덱스를 활용해서 접근해서 메모리 활용이 적음 + 시간은 더 걸림

메모리: 83764KB, 시간: 5084ms

 

해당 부분을 왜 추가한 것인지 이해 x

=> 제거 후 merge_list를 출력하기 전 오름차순으로 정렬 후 하나씩 출력

 

# 최종답안
import sys
N = int(input())
n_list = [int(sys.stdin.readline().strip()) for _ in range(N)]

def merge_sort(arr):
    def sort(low, high):
        if high - low < 2:
            return arr
        mid = (low + high) // 2
        sort(low, mid)
        sort(mid, high)
        result = merge(low, mid, high)
        return result
        
    def merge(low, mid, high):
        temp = []
        i, j = low, mid

        while i < mid and j < high:
            if arr[i] < arr[j]:
                temp.append(arr[i])
                i += 1
            else:
                temp.append(arr[j])
                j += 1

        while i < mid:
            temp.append(arr[i])
            i += 1
        while j < high:
            temp.append(arr[j])
            j += 1
        
        return temp
    
    return sort(0, len(arr))

merge_list = merge_sort(n_list)
merge_list.sort()
for num in merge_list:
    print(num)

정답

메모리: 83764KB, 시간: 3188ms

-> 메모리사용은 같으나 시간 단축

 

+ PyPy3로 제출

메모리: 231344KB, 시간: 1380ms

-> 메모리사용은 증가했으나 시간은 많이 단축

 

참고자료

https://www.daleseo.com/sort-merge/

반응형