1830E - Bully Sort - CodeForces Solution


Tags: data structures, math

Please click on ads to support us.

C++ Code:

#include <bits/extc++.h>
#include <bits/stdc++.h>

#ifdef ZAYIN
#include "dbg/dbg.h"
#else
#define dbg(...)
#endif

// Debug instrumentation counters. CNT is bumped by the 2D Fenwick tree and
// CNT3 by treap node maintenance; nothing in this file prints them except a
// commented-out line in main.
long long CNT{0};
long long CNT2{0};
long long CNT3{0};

// Ordered element stored in the treap. Every comparison looks only at `x`;
// `val` is a payload that AccumulatedInfo sums up.
struct Info {
  int x;
  int val;

  // `payload` defaults to 0, so a bare int converts implicitly to an Info key.
  Info(int key, int payload = 0) : x(key), val(payload) {}

  friend bool operator<(const Info& a, const Info& b) { return a.x < b.x; }
  friend bool operator==(const Info& a, const Info& b) { return a.x == b.x; }
  friend bool operator<=(const Info& a, const Info& b) { return !(b.x < a.x); }

  // Debug printing shows only the key, e.g. "Info(3)".
  friend std::ostream& operator<<(std::ostream& out, const Info& item) {
    out << "Info(" << item.x << ")";
    return out;
  }
};

// Subtree summary for the treap: the sum of Info::val over a range.
// Combination is addition, and a default-constructed value (0) is the identity.
struct AccumulatedInfo {
  long long sum;

  AccumulatedInfo(long long s = 0) : sum(s) {}
  AccumulatedInfo(Info info) : sum(info.val) {}

  friend AccumulatedInfo operator+(const AccumulatedInfo& a, const AccumulatedInfo& b) {
    const long long total = a.sum + b.sum;
    return AccumulatedInfo(total);
  }

  // Debug printing, e.g. "AccumulatedInfo(42)".
  friend std::ostream& operator<<(std::ostream& out, const AccumulatedInfo& acc) {
    out << "AccumulatedInfo(" << acc.sum << ")";
    return out;
  }
};

// Treap (randomized balanced BST) ordered by Info, keeping a per-subtree
// AccumulatedInfo summary.
// Requirements visible in the code:
//   - Info is copyable and comparable with < and <=.
//   - AccumulatedInfo is value-constructible (AccumulatedInfo() is the
//     identity) and constructible from Info.
//   - AccumulatedInfo is combinable with operator+, which is always applied
//     in left-to-right sequence order.
template <typename Info, typename AccumulatedInfo>
class treap {
 public:
  struct node {
    Info info;
    // Summary of every Info in this subtree, combined in in-order sequence.
    AccumulatedInfo accumulated_info;
    using priority_type = unsigned long long;
    // Random heap key; merge() keeps larger priorities closer to the root.
    priority_type priority;
    // Number of nodes in this subtree.
    size_t size;
    node *lson, *rson, *parent;

    // Uniformly random priority from a time-seeded 64-bit Mersenne Twister.
    static priority_type get_priority() {
      static std::mt19937_64 engine(std::chrono::steady_clock::now().time_since_epoch().count());
      static std::uniform_int_distribution<priority_type> rng(
          0, std::numeric_limits<priority_type>::max());
      return rng(engine);
    }

    node(Info info_)
        : info(info_),
          accumulated_info(info_),
          priority(get_priority()),
          size(1),
          lson(nullptr),
          rson(nullptr),
          parent(nullptr) {}

    // Recomputes size and accumulated_info from the children and points the
    // children's parent links here. This node's own parent link is cleared;
    // an ancestor's maintain() re-links it.
    node* maintain() {
      ++CNT3;
      parent = nullptr;
      accumulated_info = AccumulatedInfo(info);
      size = 1;
      if (lson) {
        accumulated_info = lson->accumulated_info + accumulated_info;
        size += lson->size;
        lson->parent = this;
      }
      if (rson) {
        accumulated_info = accumulated_info + rson->accumulated_info;
        size += rson->size;
        rson->parent = this;
      }
      return this;
    }

    static size_t get_size(node* x) { return x ? x->size : 0; }

    // Concatenates x and y (every element of x precedes every element of y)
    // and returns the new root.
    static node* merge(node* x, node* y) {
      if (!x) return y;
      if (!y) return x;
      if (x->priority > y->priority) {
        x->rson = merge(x->rson, y);
        return x->maintain();
      }
      y->lson = merge(x, y->lson);
      return y->maintain();
    }

    // split such that predicate(all node in x)=true
    // precondition: if predicate(rt[i])=false, predicate(rt[j])=false for i<=j
    static void split_by_info(node* rt,
                              node*& x,
                              node*& y,
                              const std::function<bool(Info)>& predicate) {
      if (!rt) {
        x = nullptr;
        y = nullptr;
        return;
      }
      if (predicate(rt->info)) {
        x = rt;
        split_by_info(rt->rson, x->rson, y, predicate);
      } else {
        y = rt;
        split_by_info(rt->lson, x, y->lson, predicate);
      }
      rt->maintain();
    }

    // Splits rt into x, the longest prefix P with predicate(summary of P)
    // true, and y, the rest.
    // `prefix` is the summary of everything that precedes rt in the full
    // sequence; top-level callers leave it defaulted.
    // precondition: if predicate(summary of rt[:i]) is false, it is also
    // false for every j >= i.
    static void split_by_accumulated_info(node* rt,
                                          node*& x,
                                          node*& y,
                                          const std::function<bool(AccumulatedInfo)>& predicate,
                                          AccumulatedInfo prefix = AccumulatedInfo()) {
      if (!rt) {
        x = nullptr;
        y = nullptr;
        return;
      }

      // Summary of the sequence up to and including rt itself.
      AccumulatedInfo through_rt = prefix;
      if (rt->lson) {
        through_rt = through_rt + rt->lson->accumulated_info;
      }
      through_rt = through_rt + AccumulatedInfo(rt->info);

      if (predicate(through_rt)) {
        x = rt;
        split_by_accumulated_info(rt->rson, x->rson, y, predicate, through_rt);
      } else {
        y = rt;
        split_by_accumulated_info(rt->lson, x, y->lson, predicate, prefix);
      }
      rt->maintain();
    }

    // split such that x->size=size (or all of rt when size exceeds it)
    static void split_by_size(node* rt, node*& x, node*& y, size_t size) {
      if (!rt) {
        x = nullptr;
        y = nullptr;
        return;
      }

      size_t size_with_lson = node::get_size(rt->lson) + 1;
      if (size_with_lson <= size) {
        x = rt;
        split_by_size(rt->rson, x->rson, y, size - size_with_lson);
      } else {
        y = rt;
        split_by_size(rt->lson, x, y->lson, size);
      }
      rt->maintain();
    }

    // Returns the k-th node in sequence order (index from 0); requires 0 <= k < size.
    static node* get_kth(node* rt, int k) {
      assert(rt && 0 <= k && static_cast<size_t>(k) < rt->size);
      size_t remaining = static_cast<size_t>(k) + 1;
      for (;;) {
        size_t size_with_lson = node::get_size(rt->lson) + 1;
        if (size_with_lson == remaining) return rt;
        if (size_with_lson < remaining) {
          remaining -= size_with_lson;
          rt = rt->rson;
        } else {
          rt = rt->lson;
        }
      }
    }
  };

  // A (root, rank) snapshot. It does not modify the tree, but it becomes
  // stale if the treap is modified after the iterator was created.
  struct iterator {
    node* rt;
    int rank;
    iterator(node* rt_, int rank_) : rt(rt_), rank(rank_) {
      assert(0 <= rank && static_cast<size_t>(rank) <= node::get_size(rt));
    }
    iterator& operator++() { return *this = iterator(rt, rank + 1); }
    iterator operator++(int) {
      iterator res = *this;
      ++*this;
      return res;
    }
    iterator& operator--() { return *this = iterator(rt, rank - 1); }
    iterator operator--(int) {
      iterator res = *this;
      --*this;
      return res;
    }
    friend bool operator==(const iterator& lhs, const iterator& rhs) {
      return lhs.rt == rhs.rt && lhs.rank == rhs.rank;
    }
    friend bool operator!=(const iterator& lhs, const iterator& rhs) {
      return lhs.rt != rhs.rt || lhs.rank != rhs.rank;
    }

    int get_rank() { return rank; }
    Info info() { return node::get_kth(rt, rank)->info; }

    // Summary of elements [0, rank], i.e. the whole sequence when rank == size.
    // Walks the tree instead of split/merge so the owning treap's root is
    // never invalidated. Requires a non-empty treap.
    AccumulatedInfo prefix_accumulated_info() {
      assert(node::get_size(rt) > 0);
      size_t remaining = static_cast<size_t>(rank) + 1;
      AccumulatedInfo res = AccumulatedInfo();
      for (node* cur = rt; cur && remaining > 0;) {
        if (cur->size <= remaining) {
          // The whole subtree lies inside the prefix.
          res = res + cur->accumulated_info;
          break;
        }
        size_t lsz = node::get_size(cur->lson);
        if (remaining <= lsz) {
          cur = cur->lson;
          continue;
        }
        if (cur->lson) res = res + cur->lson->accumulated_info;
        res = res + AccumulatedInfo(cur->info);
        remaining -= lsz + 1;
        cur = cur->rson;
      }
      return res;
    }

    // Summary of elements [rank, size). Requires rank < size. Non-mutating,
    // like prefix_accumulated_info().
    AccumulatedInfo suffix_accumulated_info() {
      assert(static_cast<size_t>(rank) < node::get_size(rt));
      size_t skip = static_cast<size_t>(rank);
      AccumulatedInfo res = AccumulatedInfo();  // summary of what lies to the right
      for (node* cur = rt; cur;) {
        if (skip == 0) {
          res = cur->accumulated_info + res;
          break;
        }
        size_t lsz = node::get_size(cur->lson);
        if (skip <= lsz) {
          // cur (at in-subtree index lsz) is in the suffix, and so is its right subtree.
          AccumulatedInfo tail = AccumulatedInfo(cur->info);
          if (cur->rson) tail = tail + cur->rson->accumulated_info;
          res = tail + res;
          cur = cur->lson;
        } else {
          skip -= lsz + 1;
          cur = cur->rson;
        }
      }
      return res;
    }
  };

  // Takes ownership of rt_ (which may be null).
  treap(node* rt_ = nullptr) : rt(rt_) {}
  // The treap owns its nodes; a copy would free them twice.
  treap(const treap&) = delete;
  treap& operator=(const treap&) = delete;
  ~treap() { destroy(rt); }

  // Inserts v before any existing elements that compare equal to it.
  void insert(Info v) {
    node *l, *r;
    node::split_by_info(rt, l, r, [&](Info x) { return x < v; });
    node* mid = new node(v);
    rt = node::merge(node::merge(l, mid), r);
  }

  // Removes one element equal to v (all of them when erase_all); no-op if absent.
  void erase(Info v, bool erase_all = false) {
    node *l, *mid, *r;
    node::split_by_info(rt, l, r, [&](Info x) { return x < v; });
    node::split_by_info(r, mid, r, [&](Info x) { return x <= v; });
    if (mid) {
      if (erase_all) {
        destroy(mid);
        mid = nullptr;
      } else {
        // Unlink the root of the equal range, splice its children, and free it.
        node* victim = mid;
        if (victim->lson) victim->lson->parent = nullptr;
        if (victim->rson) victim->rson->parent = nullptr;
        mid = node::merge(victim->lson, victim->rson);
        delete victim;
      }
    }
    rt = node::merge(node::merge(l, mid), r);
  }

  iterator get_kth(int k) { return iterator(rt, k); }

  // the first one >=v
  iterator lowerbound(Info v) {
    node *l, *r;
    node::split_by_info(rt, l, r, [&](Info x) { return x < v; });
    int rank = node::get_size(l);
    rt = node::merge(l, r);
    return iterator(rt, rank);
  }

  // the first one >v
  iterator upperbound(Info v) {
    node *l, *r;
    node::split_by_info(rt, l, r, [&](Info x) { return x <= v; });
    int rank = node::get_size(l);
    rt = node::merge(l, r);
    return iterator(rt, rank);
  }

  // Summary of all elements <= v (the identity AccumulatedInfo() if none).
  AccumulatedInfo lesum(Info v) {
    node *l, *r;
    node::split_by_info(rt, l, r, [&](Info x) { return x <= v; });
    AccumulatedInfo result = AccumulatedInfo();
    if (l) result = l->accumulated_info;
    rt = node::merge(l, r);
    return result;
  }

  iterator begin() { return iterator(rt, 0); }
  // Safe on an empty treap (rank 0 == size 0).
  iterator end() { return iterator(rt, static_cast<int>(node::get_size(rt))); }

 private:
  // Frees every node of the subtree rooted at x.
  static void destroy(node* x) {
    if (!x) return;
    destroy(x->lson);
    destroy(x->rson);
    delete x;
  }

  node* rt;
};

// GNU pb_ds red-black tree used as a set of unique keys with rank queries
// (order_of_key: how many stored keys are strictly less than a given key).
template <class Key>
using ordered_set = __gnu_pbds::tree<Key,
                                     __gnu_pbds::null_type,
                                     std::less<Key>,
                                     __gnu_pbds::rb_tree_tag,
                                     __gnu_pbds::tree_order_statistics_node_update>;

// 0-index based
// 2D point-count structure: a Fenwick tree over rows whose nodes are
// order-statistics sets of column indices. Each (row, column) pair is stored
// at most once, because the per-node containers are sets and duplicate
// inserts are ignored.
template <typename T>
class fenwick_tree_2d {
 public:
  // n_: number of rows; valid row indices are [0, n_). The column count is
  // not needed because columns live in ordered sets.
  fenwick_tree_2d(int n_, [[maybe_unused]] int m_) : n(n_), tree(n_) {}

  // Inserts (v > 0) or removes (v < 0) the point (x, y); v == 0 is a no-op.
  // Rows outside [0, n) are ignored. A negative row previously made the
  // update loop spin forever at i == 0.
  void add(int x, int y, T v) {
    if (x < 0 || x >= n || v == T()) return;
    for (int i = x + 1; i <= n; i += i & -i) {
      ++CNT;
      if (v > T()) {
        tree[i - 1].insert(y);
      } else {
        tree[i - 1].erase(y);
      }
    }
  }

  // Number of stored points with row <= x and column <= y.
  T lowerbound_sum(int x, int y) {
    auto ans = T();
    // Rows past the end would index outside `tree`; clamp to the last row.
    if (x >= n) x = n - 1;
    for (int i = x + 1; i > 0; i -= i & -i) {
      ++CNT;
      ans += tree[i - 1].order_of_key(y + 1);
    }
    return ans;
  }

  // Number of stored points with row <= x and column in [ly, ry] (0 if empty).
  T columnar_sum(int x, int ly, int ry) {
    if (ly > ry) return T();
    return lowerbound_sum(x, ry) - lowerbound_sum(x, ly - 1);
  }

  // Number of stored points with row in [lx, rx] and column in [ly, ry].
  T range_sum(int lx, int rx, int ly, int ry) {
    if (lx > rx || ly > ry) return T();
    return columnar_sum(rx, ly, ry) - columnar_sum(lx - 1, ly, ry);
  }

 private:
  int n;
  std::vector<ordered_set<int>> tree;
};

// Maintains the number of inverted pairs among a set of points (x, y).
// Two points form an inverted pair when one has a smaller x and a larger y.
// Points are toggled in and out with change_point.
struct DS2d {
  DS2d(int n_) : n(n_), point(n_, n_) {}

  long long get_sum() const { return sum; }

  // Adds (d = +1) or removes (d = -1) the point (x, y). `sum` is adjusted by
  // the number of stored points that are up-left or down-right of it.
  // Coordinates are expected in [0, n).
  void change_point(int x, int y, int d) {
    // range_sum returns long long; keep it that way instead of narrowing to int.
    const long long crossing = point.range_sum(0, x - 1, y + 1, n - 1) +
                               point.range_sum(x + 1, n - 1, 0, y - 1);
    sum += crossing * d;
    point.add(x, y, d);
  }

  int n;
  long long sum = 0;
  fenwick_tree_2d<long long> point;
};

namespace fastio {

class ostream;

// Manipulator-like object: streaming it into fastio::ostream invokes the
// stored handler on that stream (used for `endl`).
class ostream_control_char {
 public:
  // Accepts any callable usable as void(ostream&). This is written as an
  // explicit template because a `const auto&` parameter is C++20-only syntax.
  template <typename Handler>
  ostream_control_char(const Handler& handler) : handler_(handler) {}

 private:
  const std::function<void(ostream&)> handler_;
  friend class ostream;
};

// Buffered writer over a C FILE*. Bytes accumulate in an internal buffer and
// are written out when the buffer is full, on reflesh()/flush(), and on
// destruction.
class ostream {
 public:
  ostream(FILE* file) : file(file) {}
  // A copy would keep `pointer` aimed into the source object's buffer.
  ostream(const ostream&) = delete;
  ostream& operator=(const ostream&) = delete;
  ~ostream() { reflesh(); }

  template <typename T>
  ostream& operator<<(T t) {
    write_single(t);
    return *this;
  }

  ostream& operator<<(const ostream_control_char& oct) {
    oct.handler_(*this);
    return *this;
  }

  // Writes the buffered bytes and then fflush()es the FILE.
  void flush() {
    reflesh();
    fflush(file);
  }
  // Writes the buffered bytes to the FILE without fflush().
  void reflesh() {
    fwrite(buffer, sizeof(char), pointer - buffer, file);
    pointer = buffer;
  }

 private:
  void write_single(const char& c) {
    if (pointer == buffer + buffer_size) reflesh();
    *(pointer++) = c;
  }
  void write_single(const char* s) { write_single(std::string(s)); }
  void write_single(const std::string& s) {
    for (char c : s) write_single(c);
  }
  template <typename AT, typename std::enable_if_t<std::is_arithmetic<AT>::value, int> = 0>
  void write_single(const AT a) {
    return write_single(std::to_string(a));
  }

  FILE* file;
  static const int buffer_size = 1 << 23;
  char buffer[buffer_size], *pointer = buffer;
  friend class ostream_control_char;
};
ostream cout(stdout);

const ostream_control_char endl([](ostream& os) {
  os << '\n';
  os.flush();
});

// Buffered reader of decimal integers over a C FILE*.
// It no longer derives from std::istream. That base was never attached to a
// streambuf, and it relied on a non-standard protected default constructor.
class istream {
 public:
  istream(FILE* file) : file(file) {}
  // A copy would keep `pointer`/`limit` aimed into the source object's buffer.
  istream(const istream&) = delete;
  istream& operator=(const istream&) = delete;

  // Skips ahead to the next '-' or digit, then parses an optionally negative
  // decimal integer into `a`. If the input ends before a number starts, `a`
  // is left untouched and the stream converts to false.
  template <typename AT, typename std::enable_if_t<std::is_arithmetic<AT>::value, int> = 0>
  istream& operator>>(AT& a) {
    int c;
    do {
      c = read_single();
    } while (c != EOF && !std::isdigit(c) && c != '-');
    if (c == EOF) {
      failed = true;
      return *this;
    }
    int sign = 1;
    if (c == '-') {
      sign = -1;
      c = read_single();
    }
    AT abs_value = 0;
    for (; c != EOF && std::isdigit(c); c = read_single()) {
      abs_value = abs_value * 10 + (c - '0');
    }
    a = sign * abs_value;
    return *this;
  }

  // True until a read reaches end of input without finding a number.
  operator bool() const { return !failed; }

 private:
  // Refills the buffer once it is exhausted. Returns false when the FILE
  // yields no more bytes.
  bool refill() {
    if (pointer != limit) return true;
    const size_t got = fread(buffer, sizeof(char), buffer_size, file);
    pointer = buffer;
    limit = buffer + got;  // only bytes actually read are valid
    return got != 0;
  }

  // Next byte as an unsigned char value (safe for std::isdigit), or EOF.
  int read_single() {
    if (!refill()) return EOF;
    return static_cast<unsigned char>(*(pointer++));
  }

  FILE* file;
  static const int buffer_size = 1 << 23;
  char buffer[buffer_size];
  char* pointer = buffer;  // next unread byte
  char* limit = buffer;    // one past the last valid byte of the latest fread
  bool failed = false;
};

istream cin(stdin);
}  // namespace fastio

int main() {
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);
  std::cout.tie(nullptr);

  int n, m;
  fastio::cin >> n >> m;

  // The permutation, converted to 0-based values.
  std::vector<int> a(n);
  for (int& value : a) {
    fastio::cin >> value;
    value -= 1;
  }

  // The printed answer is step_sum (sum of |i - a[i]|) minus the inversion
  // count tracked by ds2d.
  DS2d ds2d(n);
  long long step_sum = 0;
  // Adds (sign = +1) or retracts (sign = -1) position pos's contribution.
  auto modify = [&](int pos, int sign) {
    step_sum += std::abs(pos - a[pos]) * sign;
    ds2d.change_point(pos, a[pos], sign);
  };

  dbg(a);
  for (int pos = 0; pos < n; ++pos) {
    modify(pos, 1);
    dbg(pos, ds2d.get_sum());
  }

  while (m--) {
    int x, y;
    fastio::cin >> x >> y;
    --x;
    --y;
    // Retract both positions, swap their values, then re-add them.
    for (int pos : {x, y}) modify(pos, -1);
    std::swap(a[x], a[y]);
    for (int pos : {x, y}) modify(pos, 1);

    const long long answer = step_sum - ds2d.get_sum();
    dbg(step_sum, ds2d.get_sum());
    fastio::cout << answer << '\n';
  }
  return 0;
}
/*
[README BEFORE SUBMISSION]
1. Should the values be stored as long long?
2. If multiple test cases are enabled, is the sum of n/m/q over all test cases guaranteed to be bounded?
*/


Comments

Submit
0 Comments
More Questions

1613B - Absent Remainder
1536B - Prinzessin der Verurteilung
1699B - Almost Ternary Matrix
1545A - AquaMoon and Strange Sort
538B - Quasi Binary
424A - Squats
1703A - YES or YES
494A - Treasure
48B - Land Lot
835A - Key races
1622C - Set or Decrease
1682A - Palindromic Indices
903C - Boxes Packing
887A - Div 64
755B - PolandBall and Game
808B - Average Sleep Time
1515E - Phoenix and Computers
1552B - Running for Gold
994A - Fingerprints
1221C - Perfect Team
1709C - Recover an RBS
378A - Playing with Dice
248B - Chilly Willy
1709B - Also Try Minecraft
1418A - Buying Torches
131C - The World is a Theatre
1696A - NIT orz
1178D - Prime Graph
1711D - Rain
534A - Exam