icpc_library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub ebi-fly13/icpc_library

:heavy_check_mark: BinaryTrie
(data_structure/BinaryTrie.hpp)

実装の際に省略できるもの

count , erase , order , get , xor_all の実装は他の機能と独立であり、省略できます。

説明

この BinaryTrie は非負整数の多重集合を管理します。

コンストラクタ

BinrayTrie<T, MAX_LOG> ()

T でbit長が MAX_LOG であるBinaryTrieを構築します。

Tint または ll が想定されています。

多重集合で扱われる値 x0 <= x < (1 << MAX_LOG) が満たされている必要があります。 ただし、 MAX_LOG 桁以上の桁が $0$ であるという過程で動作するため、正確には非負整数 x は $\bmod$ 1 << MAX_LOG で扱われます。

insert

void insert(T x)

多重集合に $x$ を $1$ つ挿入します。既にその要素が存在する場合でも、重複をカウントします。

erase

void erase(T x)

多重集合から $x$ を $1$ つ削除します。 $x$ が存在しない場合は何も起こりません。特にエラーを起こしません。

count

int count(T x = -1)

count() と呼んだ、あるいは x = -1 のとき、多重集合の要素数を返します。

そうでないとき、多重集合に含まれる x の個数を返します。

order

int order(T x)

多重集合に含まれる $x$ より小さい要素の個数を返します。

get

T get(int k)

0_indexed で $k$ 番目に小さい値を返します。 $k$ が多重集合の要素数以上である場合は最大値を返します。多重集合が空の場合は $0$ を返します。

xor_all

void xor_all(T x)

多重集合の要素をすべて $x$ との XOR をとった値に置き換えます。

使い方

int main(){
    BinaryTrie<int,31> bt;
    bt.insert(1);
    bt.insert(2); // multiset mst = {1,2}
    cout << bt.get(0) << endl; // 1
    bt.xor_all(2); // mst = {0,3}
    cout << bt.get(0) << endl; // 0
    bt.erase(2); // NO change
    cout << bt.count() << endl; // 2
    cout << bt.get(1) << endl; // 3
    bt.erase(0); // mst = {3}
}

Depends on

Verified with

Code

#pragma once

#include "../template/template.hpp"

namespace lib {

using namespace std;

template <typename T, int MAX_LOG>  // T = int/ll, 0 <= x < 2 ^ MAX_LOG
struct BinaryTrie {                 // set(multiset) of integer
    struct node {
        node *p;
        array<node *, 2> ch;
        int exist;  // number of item
        int sz;     // number of integers exist in the subtree of this node
        node() : p(nullptr), ch({nullptr, nullptr}), exist(0), sz(0) {}
    };
    BinaryTrie() : lazy(T(0)) {}
    int size(node *v) {
        if (v == nullptr) {
            return 0;
        }
        return v->sz;
    }
    int count(T x = -1) {
        node *v = root;
        if (x < 0) return v->sz;
        x ^= lazy;
        rrep(i, 0, MAX_LOG) {
            int j = x >> i & 1;
            if (v->ch[j] == nullptr) {
                return 0;
            }
            v = v->ch[j];
        }
        return v->sz;
    }
    void insert(T x) {
        x ^= lazy;
        node *v = root;
        rrep(i, 0, MAX_LOG) {
            int j = x >> i & 1;
            if (v->ch[j] == nullptr) {
                v->ch[j] = new node();
                v->ch[j]->p = v;
            }
            v = v->ch[j];
        }
        v->exist++;
        update(v);
        rep(i, 0, MAX_LOG) {
            v = v->p;
            update(v);
        }
    }
    void erase(T x) {
        x ^= lazy;
        node *v = root;
        rrep(i, 0, MAX_LOG) {
            int j = x >> i & 1;
            if (v->ch[j] == nullptr) {
                return;
            }
            v = v->ch[j];
        }
        if (v->exist == 0) return;
        v->exist--;
        update(v);
        rrep(i, 0, MAX_LOG) {
            node *p = v->p;
            if (size(v) == 0) {
                if (v == p->ch[0])
                    p->ch[0] = nullptr;
                else
                    p->ch[1] = nullptr;
                delete v;
            }
            v = p;
            update(v);
        }
    }
    int order(T x) {  // number of element which is less than x
        node *v = root;
        int res = 0;
        rrep(i, 0, MAX_LOG) {
            int j = lazy >> i & 1;
            if ((x >> i & 1) == 0) {
                v = v->ch[j];
            } else {
                res += size(v->ch[j]);
                v = v->ch[j ^ 1];
            }
            if (v == nullptr) {
                break;
            }
        }
        return res;
    }
    T get(int k) {  // value of kth(0_indexed) element, order(get(k)) = k
        node *v = root;
        T ans = T(0);
        rrep(i, 0, MAX_LOG) {
            int j = lazy >> i & 1;
            if (k < size(v->ch[j])) {
                v = v->ch[j];
            } else {
                k -= size(v->ch[j]);
                v = v->ch[j ^ 1];
                ans |= T(1) << i;
            }
        }
        return ans;
    }
    void xor_all(T x) {
        lazy ^= x;
    }

  private:
    T lazy;
    node *root = new node();
    void update(node *v) {
        v->sz = v->exist + size(v->ch[0]) + size(v->ch[1]);
    }
};

}  // namespace lib
#line 2 "data_structure/BinaryTrie.hpp"

#line 2 "template/template.hpp"

#include <bits/stdc++.h>

#define rep(i, s, n) for (int i = (int)(s); i < (int)(n); i++)
#define rrep(i, s, n) for (int i = (int)(n)-1; i >= (int)(s); i--)
#define all(v) v.begin(), v.end()

using ll = long long;
using ld = long double;
using ull = unsigned long long;

template <typename T> bool chmin(T &a, const T &b) {
    if (a <= b) return false;
    a = b;
    return true;
}
template <typename T> bool chmax(T &a, const T &b) {
    if (a >= b) return false;
    a = b;
    return true;
}

namespace lib {

using namespace std;

}  // namespace lib

// using namespace lib;
#line 4 "data_structure/BinaryTrie.hpp"

namespace lib {

using namespace std;

template <typename T, int MAX_LOG>  // T = int/ll, 0 <= x < 2 ^ MAX_LOG
struct BinaryTrie {                 // set(multiset) of integer
    struct node {
        node *p;
        array<node *, 2> ch;
        int exist;  // number of item
        int sz;     // number of integers exist in the subtree of this node
        node() : p(nullptr), ch({nullptr, nullptr}), exist(0), sz(0) {}
    };
    BinaryTrie() : lazy(T(0)) {}
    int size(node *v) {
        if (v == nullptr) {
            return 0;
        }
        return v->sz;
    }
    int count(T x = -1) {
        node *v = root;
        if (x < 0) return v->sz;
        x ^= lazy;
        rrep(i, 0, MAX_LOG) {
            int j = x >> i & 1;
            if (v->ch[j] == nullptr) {
                return 0;
            }
            v = v->ch[j];
        }
        return v->sz;
    }
    void insert(T x) {
        x ^= lazy;
        node *v = root;
        rrep(i, 0, MAX_LOG) {
            int j = x >> i & 1;
            if (v->ch[j] == nullptr) {
                v->ch[j] = new node();
                v->ch[j]->p = v;
            }
            v = v->ch[j];
        }
        v->exist++;
        update(v);
        rep(i, 0, MAX_LOG) {
            v = v->p;
            update(v);
        }
    }
    void erase(T x) {
        x ^= lazy;
        node *v = root;
        rrep(i, 0, MAX_LOG) {
            int j = x >> i & 1;
            if (v->ch[j] == nullptr) {
                return;
            }
            v = v->ch[j];
        }
        if (v->exist == 0) return;
        v->exist--;
        update(v);
        rrep(i, 0, MAX_LOG) {
            node *p = v->p;
            if (size(v) == 0) {
                if (v == p->ch[0])
                    p->ch[0] = nullptr;
                else
                    p->ch[1] = nullptr;
                delete v;
            }
            v = p;
            update(v);
        }
    }
    int order(T x) {  // number of element which is less than x
        node *v = root;
        int res = 0;
        rrep(i, 0, MAX_LOG) {
            int j = lazy >> i & 1;
            if ((x >> i & 1) == 0) {
                v = v->ch[j];
            } else {
                res += size(v->ch[j]);
                v = v->ch[j ^ 1];
            }
            if (v == nullptr) {
                break;
            }
        }
        return res;
    }
    T get(int k) {  // value of kth(0_indexed) element, order(get(k)) = k
        node *v = root;
        T ans = T(0);
        rrep(i, 0, MAX_LOG) {
            int j = lazy >> i & 1;
            if (k < size(v->ch[j])) {
                v = v->ch[j];
            } else {
                k -= size(v->ch[j]);
                v = v->ch[j ^ 1];
                ans |= T(1) << i;
            }
        }
        return ans;
    }
    void xor_all(T x) {
        lazy ^= x;
    }

  private:
    T lazy;
    node *root = new node();
    void update(node *v) {
        v->sz = v->exist + size(v->ch[0]) + size(v->ch[1]);
    }
};

}  // namespace lib
Back to top page