[java] 하나로-MST/Kruskal/Prim (SWEA, 1251)

SWEA의 "하나로"문제(1251)를 Kruskal, Prim 알고리즘을 이용해 해결한다.

Posted by Seoyoung Lee on March 24, 2021 · 12 mins read

문제(Link) : 1251. 하나로


최소신장트리 문제로 Kruskal, Prim 알고리즘을 써서 해결해야하는 문제였다. 주어진 값들의 범위가 커 인접행렬 대신 인접 리스트를 구성했다.

참고자료 없이도 Kruskal, Prim 알고리즘을 구현할 수 있도록 암기가 필요할 듯하다.(ㅠㅠ)

1. Kruskal 알고리즘 이용

먼저 크루스칼은 Edge클래스를 만들어서 가능한 모든 간선을 PriorityQueue에 저장한다. (방향 없는 그래프이므로 내부 for문을 i+1부터 시작해서 중복을 없앤다.)

그리고 비용이 가장 작은 간선부터 Union-Find기법으로 사이클이 만들어지지 않도록 간선을 선택한다. 간선을 선택하는 개념이므로 cnt == N-1이면 종료한다.

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.PriorityQueue;
import java.util.StringTokenizer;

public class Solution_1251_Kruskal {
	
	static class Edge implements Comparable<Edge>{
		int from, to;
		long cost;
		
		Edge(int from, int to, long cost){
			this.from = from;
			this.to = to;
			this.cost = cost;
		}
		
		@Override
		public int compareTo(Edge o) {
			return Long.compare(this.cost, o.cost);
		}
	}

    static int[] parents;
	
	public static void main(String[] args) throws IOException {
		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
		StringTokenizer st;
        
		int T = Integer.parseInt(br.readLine());
		for (int tc = 1; tc <= T; tc++) {
			int N = Integer.parseInt(br.readLine());
			long[] X = new long[N];
			long[] Y = new long[N];
			
			st = new StringTokenizer(br.readLine());
			for (int i = 0; i < N; i++)
				X[i] = Long.parseLong(st.nextToken());
			st = new StringTokenizer(br.readLine());
			for (int i = 0; i < N; i++)
				Y[i] = Long.parseLong(st.nextToken());
			double E = Double.parseDouble(br.readLine());

			parents=new int[N];
			for (int i = 0; i < N; i++)	parents[i] = i;
			
			PriorityQueue<Edge> pq = new PriorityQueue<>();
			for (int i = 0; i < N; i++) {
				for (int j = i+1; j < N; j++) {
					long L = (X[i]-X[j])*(X[i]-X[j]) + (Y[i]-Y[j])*(Y[i]-Y[j]);
					pq.add(new Edge(i, j, L));
				}
			}
			
			long ans = 0;
			int cnt = 0;
			while(!pq.isEmpty()) {
				Edge e = pq.poll();
				if(union(e.from, e.to))	continue;
				ans += e.cost;
				if(++cnt == N-1)	break;
			}
			System.out.println("#"+tc+" "+Math.round(ans*E));
		}
	}
	
	static int find(int x){
		if(parents[x] == x) return x;
		return parents[x] = find(parents[x]);
	}
	
	static boolean union(int a,int b){
		int pa = find(a);
		int pb = find(b);
		if(pa != pb){ 
			parents[pb] = pa;
			return false;
		}
		return true;
	}
}

2. Prim 알고리즘 이용

kruskal보다 짧지만 더 복잡했던거 같다. kruskal에서는 Edge 클래스를 만들었지만 Prim은 Node클래스를 사용한다.

그리고 LinkedList[] list를 사용해 모든 간선을 넣어준다. 여기서는 연결되어진 두 노드에서 서로에게 갈 수 있도록 같은 간선이어도 중복해서 넣어준다.

위와 마찬가지로 PriorityQueue를 이용하는데 0번 노드부터 시작하도록 Node(0,0)을 초기값으로 준다.

그리고 방문체크를 하면서 방문한 노드들과 연결된 노드들의 비용이 작은 것 부터 선택한다. 노드를 선택하는 개념이므로 cnt == N이면 종료한다.

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.LinkedList;
import java.util.PriorityQueue;
import java.util.StringTokenizer;

public class Solution_1251_Prim {
	static class Node implements Comparable<Node>{
		int to;
		long cost;
		Edge(int to, long cost){
			this.to = to;
			this.cost = cost;
		}
		
		@Override
		public int compareTo(Node o) {
			return Long.compare(this.cost, o.cost);
		}
	}
	
	public static void main(String[] args) throws IOException {
		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
		StringTokenizer st;
        
		int T = Integer.parseInt(br.readLine());
		for (int tc = 1; tc <= T; tc++) {
			int N = Integer.parseInt(br.readLine());
			long[] X = new long[N];
			long[] Y = new long[N];
			boolean[] visit = new boolean[N];

			st = new StringTokenizer(br.readLine());
			for (int i = 0; i < N; i++)
				X[i] = Long.parseLong(st.nextToken());
			st = new StringTokenizer(br.readLine());
			for (int i = 0; i < N; i++)
				Y[i] = Long.parseLong(st.nextToken());
			double E = Double.parseDouble(br.readLine());

			LinkedList<Node>[] list = new LinkedList[N];    // 가능한 모든 간선의 비용을 저장

			for (int i = 0; i < N; i++) {
				list[i] = new LinkedList<>();
				for (int j = 0; j < N; j++) {
					if(i == j)	continue;
					long L = (X[i]-X[j])*(X[i]-X[j]) + (Y[i]-Y[j])*(Y[i]-Y[j]);
					list[i].add(new Node(j, L));
				}
			}
			
			PriorityQueue<Node> pq = new PriorityQueue<>();
			pq.add(new Node(0, 0));
			long ans = 0;
			int cnt = 0;
			while(!pq.isEmpty()) {
				Node n = pq.poll();
				if(visit[n.to])	continue;
				visit[n.to] = true;
				ans += n.cost;
				if(++cnt == N)	break;
				
				for (Node node : list[n.to]) {
					if(!visit[node.to])	pq.add(node);
				}
			}
			System.out.println("#"+tc+" "+Math.round(ans*E));
		}
	}
}