Link

https://www.acmicpc.net/problem/9007

Detail

네 수의 합 중에서 K와 가장 가까운 수를 찾는 간단한 문제다. 각 배열의 길이가 최대 1000이므로 1000 * 1000의 최대 길이를 가지는 배열을 만들어 각 두 배열의 합을 저장하고,

두 배열의 합으로 K와 가장 가까운 수를 Binary Search 를 사용해서 해결하면 문제를 풀 수 있다.

이때, 두 배열 (Sum1, Sum2)중에 Lower_Bound를 사용할 Sum2만 정렬해도 되지만, Sum1을 같이 정렬해 줄 경우에 캐시 적중률을 상승시켜 더 좋은 결과를 유도할 수 있다.

Code

#include<bits/stdc++.h>
using namespace std;
 
#define DEBUG false
 
inline int getCloser(int a, int b, int target) {
    int val_a = abs(a - target);
    int val_b = abs(b - target);
 
    if(val_a == val_b) {
        return min(a, b);
    }
 
    return val_a < val_b ? a : b;
}
 
int solve(int n, int k, const vector<vector<int>>& weights) {
    vector<int> sum1(n*n, 0);
    vector<int> sum2(n*n, 0);
 
    for(int i = 0; i < n; i++) {
        for(int j = 0; j < n; j++) {
            sum1[i*n + j] = weights[0][i] + weights[1][j];
            sum2[i*n + j] = weights[2][i] + weights[3][j];
        }
    }
 
    // 캐시 적중률 향상을 위해 정렬
    sort(sum1.begin(), sum1.end());
    sort(sum2.begin(), sum2.end());
 
 
    // 이분탐색으로 sum[i] + sum[j] 중 target과 가장 가까운 값을 찾는다.
    // 거리가 같은 경우 더 작은 값을 선택한다.
    int answer = -4e8;
    for(int first : sum1) {
        int target = k - first;
        auto it = lower_bound(sum2.begin(), sum2.end(), target);
 
        if(it == sum2.begin()) {
            answer = getCloser(answer, first + *it, k);
            continue;
        }
 
        // 찾았을 경우
        if(it != sum2.end()) {
            int over = *it;
            int under = *(--it);
 
            answer = getCloser(answer, first + over, k);
            answer = getCloser(answer, first + under, k);
        } else {
            answer = getCloser(answer, first + *(it - 1), k);
        }
    }
 
    return answer;
}
 
 
int main() {
    int t_case;
    cin >> t_case;
 
    vector<vector<int>> weights(4);
    int target;
 
    while(t_case--) {
 
        if(DEBUG) cout << endl << "Case " << t_case << endl;
 
        int n, k;
        cin >> k >> n;
 
        if(DEBUG) cout << "n: " << n << " k: " << k << endl;
 
        for(int i = 0; i < 4; i++) {
            weights[i].clear();
            for (int j = 0; j < n; j++) {
                int w;
                cin >> w;
                weights[i].push_back(w);
            }
        }
 
        if(DEBUG) {
            for(int i = 0; i < 4; i++) {
                for (int j = 0; j < n; j++) {
                    cout << weights[i][j] << " ";
                }
                cout << endl;
            }
        }
 
        cout << solve(n, k, weights) << endl;
    }
    
}