// 1436C - Binary Search - Codeforces solution
//
// Tags: binary search, combinatorics (*1500)
//
// C++ code:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <iostream>
#include <list>
#include <numeric> // std::iota
#include <string>
#include <utility>
#include <vector>

using namespace std;
using namespace __gnu_cxx;//for matrix power built_in function.
//#include "MATH_A.h"
//https://atcoder.jp/contests/dp for DP algorithms
//https://codeforces.com/blog/entry/95106 for all ALGOs in CP
#if 1
// Contest shorthand macros and type aliases.
// Every macro argument is parenthesized so that passing an expression
// (e.g. sz(*p), srt(*p), rep(i, n + 1)) expands the way it reads.
#define rep(q, w) for ((q) = 0; (q) < (w); (q)++)
#define pii std::pair<int,int>
#define pis std::pair<int,string>
#define psi std::pair<string,int>
#define vi std::vector<int>
#define vs std::vector<string>
#define vc std::vector<char>
#define lc std::list<char>
#define ls std::list<string>
#define li std::list<int>
#define pb push_back
#define popb pop_back
#define fastio std::ios_base::sync_with_stdio(false), std::cout.tie(0), std::cin.tie(0);
#define all(v) (v).begin(),(v).end()
#define sz(v) ((v).size())
#define srtR(v) std::sort((v).rbegin(), (v).rend())
#define srt(v) std::sort((v).begin(), (v).end())
//#define F first
//#define S second
// arrRange(ty, a, mn, mx): declares storage a_ of mx-mn+1 elements and a pointer
// a such that a[mn] is a_[0]. NOTE(review): for mn > 0 the pointer a_ - mn points
// before the array, which is formally UB even though it "works" in practice.
#define arrRange(ty, a, mn, mx) ty a##_[(mx)-(mn)+1], *a=(a##_)-(mn)
#define clr(a, i) memset((a),(i),sizeof(a))

typedef long long ll;
typedef unsigned long long ull;
typedef long double ld;
typedef double db;
#endif  // closes the `#if 1` above, which previously had no matching #endif
/*
|                               |
|      WE ARE <  > team         |
|                               |
|  learn more, more and more    |
|                               |
*/
// pi to full long double precision. The previous literal (3.14159265359) was a
// double rounded to 11 decimals, off by about 2e-13.
const long double pi = 3.141592653589793238462643383279502884L;
const long long MOD = 1e9 + 7;                 // modulus used by the solution below
const int OO = 0x3f3f3f3f;                     // large int sentinel (byte 0x3f repeated)
const long long OOO = 0x3f3f3f3f3f3f3f3fLL;    // same sentinel pattern for 64-bit values
// (row, column) deltas for grid moves; the first four are the orthogonal moves,
// the last four the diagonals.
const int dr[8] = {-1, 0, 1, 0, -1, -1, 1, 1};
const int dc[8] = {0, -1, 0, 1, -1, 1, 1, -1};

//namespace mathA {
/*double sinA(double dig) {
    return sin(dig * pi / 180);
}

double cosA(double dig) {
    return cos(dig * pi / 180);
}

double tanA(double dig) {
    int we = dig;
    if (we == dig) {
        if (we % 90 == 0) {
            if ((we / 90) % 2 == 1)std::cout << "ERROR\n";
            else return 0;
        }
    }
    return tan(dig * pi / 180);
}

double secA(double dig) {
    int qw = cos(dig * pi / 180);
    if (qw == 0)std::cout << "ERROR\n";
    else return (1 / qw);
}

double cosecA(double dig) {
    int qw = sin(dig * pi / 180);
    if (qw == 0)std::cout << "ERROR\n";
    else return (1 / qw);
}

double cotA(double dig) {
    int qw = tan(dig * pi / 180);
    if (qw == 0)std::cout << "ERROR\n";
    else return (1 / qw);
}*/

/// Reads every element of V from `input`, in order. V must already have the
/// desired size; this only fills existing elements.
template<class T>
inline std::istream &operator>>(std::istream &input, std::vector<T> &V) {
    // Read from the stream we were handed (not std::cin) so file and string
    // streams work as expected.
    for (auto &in: V) input >> in;
    return input;
}

/// Writes the elements of V to `out`, each followed by a single space.
/// Takes V by const reference so const vectors and temporaries can be printed.
template<class T>
inline std::ostream &operator<<(std::ostream &out, const std::vector<T> &V) {
    for (const auto &h: V) out << h << " ";
    return out;
}

/// Trial-division primality test. Negative inputs are tested by absolute value
/// (so isPrime(-7) is true); 0 and 1 are not prime.
bool isPrime(long long num) {
    // |num| in unsigned arithmetic: handles LLONG_MIN, and lets the divisor
    // loop use x * x without overflow. (With an `int` divisor, x * x overflowed,
    // which is UB, as soon as num exceeded about 2^31.)
    const unsigned long long m = num < 0 ? 0ULL - static_cast<unsigned long long>(num)
                                         : static_cast<unsigned long long>(num);
    if (m < 2) return false;
    if (m % 2 == 0) return m == 2;
    for (unsigned long long x = 3; x * x <= m; x += 2)
        if (m % x == 0) return false;
    return true;
}

// Length of the vector (A, B): the distance from the origin to the point (A, B),
// computed without intermediate overflow by hypot.
double distAtoB(double A, double B) {
    const double length = std::hypot(A, B);
    return length;
}

// Euclid's algorithm, iterative form. gcd(x, 0) == x. The sign of the result
// follows C++'s truncating `%`, so it may be negative for negative inputs.
ll gcd(ll A, ll B) {
    while (B != 0) {
        const ll rest = A % B;
        A = B;
        B = rest;
    }
    return A;
}

long long lcm(long long f, long long l)
    return (f / gcd(f, l) * l);

/*ll fact(ll number) {
    if (number < 0)number *= -1;
    if (number == 0)return 1;
    if (number == 1)return number;
    else return number * fact(number - 1);
}*/

/// n-th Fibonacci number (F(0) = 0, F(1) = 1), computed exactly with integer
/// fast doubling in O(log |n|) steps.
/// - The previous Binet-formula version used doubles: it cannot be exact once
///   F(n) exceeds 2^53 (n >= 79), and its cast truncated rather than rounded.
/// - Exact for |n| <= 92 (F(92) is the largest value that fits in long long);
///   beyond that the result is F(n) modulo 2^64.
/// - Negative n uses F(-k) = (-1)^(k+1) * F(k).
long long fib(long long n) {
    const unsigned long long k = n < 0 ? 0ULL - static_cast<unsigned long long>(n)
                                       : static_cast<unsigned long long>(n);
    // Invariant: a = F(m), b = F(m + 1), where m is the prefix of k's bits read so far.
    // Unsigned arithmetic is exact modulo 2^64 and never overflows into UB.
    unsigned long long a = 0, b = 1;
    for (int bit = 63; bit >= 0; --bit) {
        const unsigned long long f2m = a * (2 * b - a);   // F(2m)
        const unsigned long long f2m1 = a * a + b * b;    // F(2m + 1)
        if ((k >> bit) & 1ULL) {
            a = f2m1;
            b = f2m + f2m1;
        } else {
            a = f2m;
            b = f2m1;
        }
    }
    if (n < 0 && k % 2 == 0) a = 0 - a;  // F(-k) is negative for even k
    return static_cast<long long>(a);
}

// a / b rounded toward +infinity (b != 0). Built-in `/` truncates toward zero,
// so an inexact quotient is bumped up by one when the sign bits of a and b
// agree ((a ^ b) > 0), i.e. when the exact quotient is positive.
ll cdiv(ll a, ll b) {
    const ll truncated = a / b;
    if ((a ^ b) > 0 && a % b != 0) return truncated + 1;
    return truncated;
}
// a / b rounded toward -infinity (b != 0). When a and b have opposite signs
// ((a ^ b) < 0) the exact quotient is negative, so an inexact truncated
// quotient is moved down by one.
ll fdiv(ll a, ll b) {
    const ll truncated = a / b;
    if ((a ^ b) < 0 && a % b != 0) return truncated - 1;
    return truncated;
}
/// a / b rounded to the nearest integer, ties rounded up (toward +infinity),
/// for any signs of a and b (b != 0). This equals floor((a + b / 2) / b).
/// The old `(a + b / 2) / b` truncated instead of flooring, which is wrong
/// whenever a + b / 2 and b have opposite signs (e.g. rounding(-5, 3) gave -1
/// instead of -2).
long long rounding(long long a, long long b) {
    const long long num = a + b / 2;
    long long q = num / b;
    // Turn C++'s truncating division into floor division.
    if (num % b != 0 && ((num < 0) != (b < 0))) --q;
    return q;
}

/// Returns `number` masked to bit `ind`. The result is non-zero iff that bit is
/// set, and it is the bit's weight (1 << ind), not 0/1. Returns 0 when `ind` is
/// outside [0, bit width of int).
int getBit(int number, int ind) {
    if (ind < 0 || ind >= static_cast<int>(sizeof(int) * 8)) return 0;
    // Shift an unsigned 1: `1 << 31` on a signed int overflows (UB in C++17).
    const unsigned mask = 1u << ind;
    return static_cast<int>(mask & static_cast<unsigned>(number));
}

/// Returns `number` with bit `ind` set (val == true) or cleared (val == false).
/// Returns `number` unchanged when `ind` is outside [0, bit width of int).
int setBit(int number, int ind, bool val) {
    if (ind < 0 || ind >= static_cast<int>(sizeof(int) * 8)) return number;
    const unsigned mask = 1u << ind;  // unsigned shift: no UB for ind == 31
    const unsigned raw = static_cast<unsigned>(number);
    return static_cast<int>(val ? (raw | mask) : (raw & ~mask));
}

// x raised to the n-th power by iterative square-and-multiply, with no modulus
// (the caller must keep the result within long long). power(x, 0) == 1.
// Like the recursive form it replaces, a negative n behaves like |n|.
long long power(long long x, long long n) {
    unsigned long long e = n < 0 ? 0ULL - static_cast<unsigned long long>(n)
                                 : static_cast<unsigned long long>(n);
    long long result = 1;
    long long base = x;
    while (e != 0) {
        if (e & 1ULL) result *= base;
        e >>= 1;
        if (e != 0) base *= base;  // skip the square that would never be used
    }
    return result;
}

// (x^n) mod m by iterative square-and-multiply.
// - n == 0 returns 1 unreduced (even when m == 1).
// - A negative n behaves like |n|.
// - The sign of a non-zero result follows C++'s truncating `%`.
// - (m - 1)^2 must fit in ll.
ll powerMOD(ll x, ll n, ll m) {
    if (n == 0) return (1LL);
    unsigned long long e = n < 0 ? 0ULL - static_cast<unsigned long long>(n)
                                 : static_cast<unsigned long long>(n);
    ll base = x % m;
    ll acc = 1;
    for (; e != 0; e >>= 1) {
        if (e & 1ULL) acc = (acc * base) % m;
        base = (base * base) % m;
    }
    return acc;
}

/// Euler's totient phi(n) by trial-division factorisation:
/// phi(n) = product over prime powers p^k exactly dividing n of p^(k-1) * (p - 1).
/// phi(1) == 1; returns 0 for n <= 0, where phi is undefined.
/// Previously the exponent counter was never incremented, so every prime
/// factor contributed p * (p - 1) (e.g. phi(6) came out as 4).
long long ETF(long long n) {
    if (n <= 0) return 0;
    long long ans = 1;
    for (long long x = 2; x * x <= n; x++) {
        if (n % x != 0) continue;
        // The first factor x contributes (x - 1); each further one contributes x.
        n /= x;
        ans *= x - 1;
        while (n % x == 0) {
            n /= x;
            ans *= x;
        }
    }
    if (n != 1) ans *= (n - 1);  // leftover n is a single prime factor
    return ans;
}

// Sum of the decimal digits of `number`. `%` truncates toward zero, so for a
// negative input each digit, and therefore the sum, comes out negative.
ll SUM(ll number) {
    ll digitSum = 0;
    for (; number != 0; number /= 10)
        digitSum += number % 10;
    return digitSum;
}

/// Linear sieve up to C.
/// - Returns a vector of size C + 1 whose element i (2 <= i <= C) is the
///   smallest prime factor of i, so element i == i exactly when i is prime.
/// - Elements 0 and 1 are 0. An empty vector is returned for C < 0.
/// Previously primes were never appended to lsPrime, so no composite was ever
/// marked.
std::vector<int> SieveLiner(int C) {
    if (C < 0) return {};
    std::vector<int> primes(C + 1), lsPrime;  // primes[i]: smallest prime factor of i
    for (int i = 2; i <= C; i++) {
        if (primes[i] == 0) {
            primes[i] = i;
            lsPrime.push_back(i);
        }
        // Mark i * p for each prime p <= spf(i); every composite is marked
        // exactly once, by its smallest prime factor. 1LL avoids i * p overflow.
        for (std::size_t j = 0; j < lsPrime.size() && lsPrime[j] <= primes[i] &&
                                1LL * i * lsPrime[j] <= C; j++)
            primes[i * lsPrime[j]] = lsPrime[j];
    }
    return primes;
}

//std::map<ll, ll> f;
//ll fibb(ll n)
//    if (f.count(n)) return f[n];
//    if (n == 0) return 0;
//    if (n == 1 || n == 2) return 1;
//    if (n % 2 == 0)
//    {
//        ll k = n / 2;
//        ll ret1 = fibb(k - 1), ret2 = fibb(k);
//        return f[n] = ((((2ll * ret1) % MOD + ret2) % MOD) * ret2) % MOD;
//    }
//    else
//    {
//        ll k = (n + 1) / 2;
//        ll ret1 = fibb(k - 1), ret2 = fibb(k);
//        return f[n] = ((ret1 * ret1) % MOD + (ret2 * ret2) % MOD) % MOD;
//    }

/// Prints row N of Pascal's triangle to std::cout as "C(N,0), C(N,1), ..., C(N,N)"
/// with no trailing newline. Uses C(N,i) = C(N,i-1) * (N - i + 1) / i, where the
/// division is exact.
/// The product is formed in 64-bit, which keeps rows exact up to about N = 60;
/// with `int` it overflowed from about N = 30.
void generateNthrow(int N) {
    long long prev = 1;  // C(N, 0)
    std::cout << prev;
    for (int i = 1; i <= N; i++) {
        const long long curr = prev * (N - i + 1) / i;
        std::cout << ", " << curr;
        prev = curr;
    }
}

/// Counts how many truncating divisions by num1 it takes to bring num2 to 0.
/// - This is the number of base-num1 digits of num2, i.e.
///   floor(log_num1(num2)) + 1 for num2 >= 1 and num1 >= 2.
/// - logNM(b, 0) == 0.
/// - Returns 0 when |num1| <= 1, where the loop would never end or would divide by 0.
/// Previously the counter was never incremented, so the result was always 0.
int logNM(int num1, int num2) {
    if (num1 >= -1 && num1 <= 1) return 0;
    int counter = 0;
    while (num2) {
        num2 /= num1;
        ++counter;
    }
    return counter;
}

/// Binary representation of `num`, most significant bit first.
/// - f == false: the low 32 bits (32 characters).
/// - f == true: all 64 bits.
/// The previous loop started one bit too low (31 / 63 characters), silently
/// dropping the top bit, e.g. the sign bit of negative values.
std::string bits(long long num, bool f) {
    const int width = f ? 64 : 32;
    // Unsigned view so shifting/testing bit 63 is well-defined.
    const unsigned long long u = static_cast<unsigned long long>(num);
    std::string str;
    str.reserve(width);
    for (int x = width - 1; x >= 0; x--)
        str += ((u >> x) & 1ULL) ? '1' : '0';
    return str;
}

/// Number of set bits in the 64-bit two's-complement representation of num.
/// Uses Kernighan's trick: each `v & (v - 1)` clears the lowest set bit.
/// Previously the count was never incremented, so the result was always 0.
long long cntBit(long long num) {
    // Unsigned copy: with a negative signed value, `num - 1` would eventually
    // overflow at LLONG_MIN.
    unsigned long long v = static_cast<unsigned long long>(num);
    long long a = 0;
    while (v) {
        v &= v - 1;
        ++a;
    }
    return a;
}

// Extended Euclid. Returns g = gcd(a, b), with the sign produced by C++'s
// truncating division, and sets x, y so that a * x + b * y == g.
long long extend_gcd(long long a, long long b, long long &x, long long &y) {
    if (b == 0) {
        x = 1;
        y = 0;
        return a;
    }
    // Solve for (b, a % b), then back-substitute a % b == a - (a / b) * b.
    long long subX = 0, subY = 0;
    const long long g = extend_gcd(b, a % b, subX, subY);
    x = subY;
    y = subX - (a / b) * subY;
    return g;
}

// Modular inverse of a modulo m via extended Euclid.
// Returns -1 when extend_gcd does not report gcd(a, m) == 1 (no inverse).
// Otherwise returns x normalised by ((x % m) + m) % m, which lies in [0, m)
// for positive m.
long long mod_reverse(long long a, long long m) {
    long long x = 0, y = 0;
    const long long g = extend_gcd(a, m, x, y);
    if (g != 1) return -1;
    return ((x % m) + m) % m;
}

// Digits of `number` in `base`, most significant first
// (e.g. repBase(10, 2) == {1, 0, 1, 0}). Zero yields an empty vector.
std::vector<int> repBase(int number, int base) {
    std::vector<int> digits;
    for (; number != 0; number /= base)
        digits.push_back(number % base);  // collected least significant first
    std::reverse(digits.begin(), digits.end());
    return digits;
}


//const int N = 1e3 + 7, M = 1.5e7 + 5;

/// Disjoint-set union over nodes 0 .. Nodes-1, using union by size and path
/// compression. Usage: `DSU d(n); d(u, v);` merges the sets containing u and v.
class DSU {
public:  // previously there was no access specifier, so even the constructors were private
    std::vector<int> p, sz, nodeUsed;  // parent links, set sizes (valid at roots), nodes touched by merges
    int ncmp = 0, mx = 0;              // number of components, size of the largest component
    int Nodes = 0;

    DSU() {}

    explicit DSU(int nNodes) {
        Nodes = nNodes;
        p.assign(nNodes, 0);
        std::iota(p.begin(), p.end(), 0);  // every node starts as its own root
        sz.assign(nNodes, 1);
        ncmp = nNodes;
        mx = 1;
    }

    /// Root of u's set; compresses the path on the way back.
    int find(int u) {
        return p[u] == u ? u : p[u] = find(p[u]);
    }

    /// Merges the sets of u and v. Returns true if they were different sets.
    bool operator()(int u, int v) {
        u = find(u), v = find(v);
        if (u == v) return false;
        if (sz[u] > sz[v]) std::swap(u, v);  // hang the smaller tree under the larger
        p[u] = v;
        sz[v] += sz[u];
        // Only merged roots ever change p[] or sz[]; remember them for clearNodes().
        nodeUsed.push_back(u);
        nodeUsed.push_back(v);
        --ncmp;  // previously never updated
        mx = std::max(mx, sz[v]);
        return true;
    }

    bool same(int u, int v) {
        return find(u) == find(v);
    }

    /// Turns every node touched by a merge back into a singleton (as after
    /// construction) without an O(Nodes) pass. Previously nodeUsed was never
    /// filled, so this was a no-op.
    void clearNodes() {
        for (int x : nodeUsed) {
            p[x] = x;
            sz[x] = 1;
        }
        nodeUsed.clear();
        ncmp = Nodes;
        mx = 1;
    }

    /// Size of the set containing `node`.
    int capacity(int node) {
        return sz[find(node)];
    }
};

/*=============================================== ~main code~ =======================================*/

constexpr int N = 100000, M = 20000;  // size limits; not referenced by the solution code shown here

// a^n modulo MOD by square-and-multiply over the bits of n (expects n >= 0).
// run() calls it with n = MOD - 2 to obtain modular inverses.
int binPow(int a, int n) {
    int result = 1;
    for (; n != 0; n >>= 1) {
        if (n & 1) result = (1LL * result * a) % MOD;
        a = (1LL * a * a) % MOD;
    }
    return result;
}

/// Simulates a binary search over [0, n) that moves right (left = middle + 1)
/// when middle <= x_position and left (right = middle) otherwise. It counts:
/// - cnt_less: probes left of x_position (positions needing a value below the target);
/// - cnt_big: probes right of x_position (positions needing a value above it).
/// The counters are incremented, not reset, so callers should pass zeros.
/// Previously cnt_big was never incremented.
void binarySearch(int n, int x_position, int &cnt_big, int &cnt_less) {
    int left = 0, right = n;
    while (left < right) {
        const int middle = (left + right) / 2;
        if (middle <= x_position) {
            if (middle != x_position) cnt_less++;  // probing the target itself adds no constraint
            left = middle + 1;
        } else {
            cnt_big++;
            right = middle;
        }
    }
}

/// Number of ordered selections of k items out of n: n! / (n - k)! mod MOD.
/// Expects fact[i] = i! mod MOD and inv[i] = its modular inverse (as built in
/// run()), valid at least up to index n.
/// Returns 0 when k > n or k < 0; the latter previously read inv[n - k] out of range.
int P(int n, int k, const std::vector<long long> &fact, const std::vector<long long> &inv) {
    if (k < 0 || k > n) return 0;
    return static_cast<int>((fact[n] % MOD) * inv[n - k] % MOD);
}

// Reads n, x and x_position, then prints the product of three factors, modulo MOD:
// - ordered choices of the n - x larger values for the "big" probes;
// - ordered choices of the x - 1 smaller values for the "less" probes;
// - (number of remaining free positions)!
void run(){
    int n, x, x_position;
    std::cin >> n >> x >> x_position;

    // fact[i] = i! mod MOD; inv[i] = (i!)^(MOD-2) mod MOD, its modular inverse.
    std::vector<long long> fact(n + 1, 1LL), inv(n + 1, 1LL);
    for (int i = 1; i <= n; ++i) {
        fact[i] = (fact[i - 1] * i) % MOD;
        inv[i] = binPow(fact[i], MOD - 2);
    }

    int cnt_big = 0, cnt_less = 0;
    binarySearch(n, x_position, cnt_big, cnt_less);

    // Positions neither probed nor holding x can take the leftover values in any order.
    const int free_slots = n - cnt_big - cnt_less - 1;

    long long ways = P(n - x, cnt_big, fact, inv);
    ways = ways * P(x - 1, cnt_less, fact, inv) % MOD;
    ways = ways * fact[free_slots] % MOD;

    std::cout << ways << std::endl;
}


int main()
#if 0
    freopen("./input.in", "r", stdin);
    //freopen("C:\\Users\\Alostaz\\CLionProjects\\comp\\output.txt", "w", stdout);

    clock_t before = clock();
    int t = 1;
    //std::cin >> t;

    for (int CASE = 1; CASE <= t; CASE++)
        //        std::cout << "Case #" << CASE << ": ";


        if (CASE != t)std::cout << "\n";
    /*int m,n;
        if(!n or !m)break;
        std::cout<<"Problem "<<t++<<":\n";

    clock_t after = clock();
    cerr << "\nRun Time: " << (after - before - 0.0) / CLOCKS_PER_SEC << "\n";




