前回に引き続き、LOUDSの実装。
LOUDSは木構造を、ビットベクトルで表現したが、このビットベクトルに関して簡単な計算を行うことで、木に対して探索をする。

その計算とは以下のようなものである。

  • rank1(x) ・・・ 位置 x までに現れる 1 の個数を返す
  • rank0(x) ・・・ 位置 x までに現れる 0 の個数を返す
  • select1(t) ・・・ 1が t 回現れる位置を返す
  • select0(t) ・・・ 0が t 回現れる位置を返す

ベクトルの要素を先頭から走査することで簡単に計算することができるが、時間計算量は要素数\(N\)に対して、\(O(N)\)である。

これを空間計算量\(o(N)\)の補助データを追加することで、定数時間で高速に計算できるデータ構造を 簡潔ビットベクトル(または完備辞書)という。

rankの高速化

まず、32ビットの整数を用いて、ビットベクトルを要素32個ずつに分解しておく。
ビット演算を工夫することで、この整数に現れる1の数は定数時間で計算することができる。
それらの整数列の累積和を保持しておくことで、rankは定数時間で計算できる。

selectの高速化

selectを定数時間で行う方法もあるらしいが、実装は大変らしい。
残念ながら二分探索で\(O(\log N)\)で妥協することとする。

実装

package bitvector

import (
    "fmt"
    "math/bits"
)

const blocksize8 = 8

type BitVector8 struct {
    blocks []uint8
    offset uint
    ranks  []int
}

func NewBitVector8() *BitVector8 {
    return &BitVector8{ranks: []int{0}}
}

func (l *BitVector8) Add(b byte) {
    if b > 1 {
        panic(fmt.Sprintf("required: 0 or 1, but got: %d", b))
    }

    if l.offset == 0 {
        l.blocks = append(l.blocks, uint8(b))
        l.offset = 1
    } else {
        if b == 1 {
            l.blocks[len(l.blocks)-1] += 1 << l.offset
        }
        l.offset += 1

        if l.offset == blocksize8 {
            l.ranks = append(l.ranks, bits.OnesCount8(l.blocks[len(l.blocks)-1])+l.ranks[len(l.ranks)-1])
            l.offset = 0
        }
    }
}

func (l *BitVector8) String() (res string) {
    for i := 0; i < len(l.blocks); i++ {
        tmp := fmt.Sprintf("%08b", bits.Reverse8(uint8(l.blocks[i])))
        if i == len(l.blocks)-1 {
            if l.offset != 0 {
                tmp = tmp[:l.offset]
            }
        } else {
            tmp += "|"
        }
        res += tmp
    }
    return
}

func (l *BitVector8) Rank1(i int) (res int) {
    res = l.ranks[i/blocksize8]
    res += bits.OnesCount8(l.blocks[i/blocksize8] & (1<<(uint(i%8+1)) - 1))
    return res
}

func (l *BitVector8) Rank0(i int) (res int) {
    return i + 1 - l.Rank1(i)
}

func (l *BitVector8) Select1(count int) (res int) {
    left := 0
    right := (len(l.blocks)-1)*blocksize8 + int(l.offset) - 1

    if l.Rank1(left) > count {
        return -1
    }

    if l.Rank1(right) < count {
        return -1
    }

    for left+1 < right {
        mid := (left + right) / 2
        if l.Rank1(mid) < count {
            left = mid
        } else {
            right = mid
        }
    }

    return right
}

func (l *BitVector8) Select0(count int) (res int) {
    left := 0
    right := (len(l.blocks)-1)*blocksize8 + int(l.offset) - 1

    if l.Rank0(left) > count {
        return -1
    }

    if l.Rank0(right) < count {
        return -1
    }

    for left+1 < right {
        mid := (left + right) / 2
        if l.Rank0(mid) < count {
            left = mid
        } else {
            right = mid
        }
    }

    return right
}

func (l *BitVector8) extend() {
    next := make([]uint8, cap(l.blocks)*2)
    copy(next, l.blocks)
    l.blocks = next
}

テストケース

今回は複雑になったので、テストケースも充実させた。

package bitvector

import (
    "fmt"
    "testing"
)

func buildBitVector8(t *testing.T, b string) *BitVector8 {
    t.Helper()

    v := NewBitVector8()
    for i := range b {
        switch string(b[i]) {
        case "1":
            v.Add(1)
        case "0":
            v.Add(0)
        }
    }

    return v
}

func TestBitVector8_String(t *testing.T) {
    expected := "01011100|10"

    v := buildBitVector8(t, expected)
    got := v.String()

    if got != expected {
        t.Fatalf("got: %v, expected: %v", got, expected)
    }
}

func TestBitVector8_Rank1(t *testing.T) {
    cases := []struct {
        str      string
        index    int
        expected int
    }{
        {"01011", 3, 2},
        {"01011", 4, 3},
        {"01011101", 7, 5},
        {"01011101|1", 7, 5},
        {"01011101|1", 8, 6},
        {"01011101|10", 9, 6},
        {"01011101|01011101|01011101|0101110", 6, 4},
        {"01011101|01011101|01011101|0101110", 22, 14},
    }

    for _, tt := range cases {
        t.Run(fmt.Sprintf("rank(%s,%d)=%d", tt.str, tt.index, tt.expected), func(t *testing.T) {
            v := buildBitVector8(t, tt.str)
            got := v.Rank1(tt.index)

            if got != tt.expected {
                t.Fatalf("got: %v, expected: %v", got, tt.expected)
            }
        })
    }
}

func TestBitVector8_Rank0(t *testing.T) {
    cases := []struct {
        str      string
        index    int
        expected int
    }{
        {"01011", 3, 2},
        {"01011", 4, 2},
        {"01011101", 7, 3},
        {"01011101|1", 7, 3},
        {"01011101|1", 8, 3},
        {"01011101|10", 9, 4},
    }

    for _, tt := range cases {
        t.Run(fmt.Sprintf("rank(%s,%d)=%d", tt.str, tt.index, tt.expected), func(t *testing.T) {
            v := buildBitVector8(t, tt.str)
            got := v.Rank0(tt.index)

            if got != tt.expected {
                t.Fatalf("got: %v, expected: %v", got, tt.expected)
            }
        })
    }
}

func TestBitVector8_Select1(t *testing.T) {
    cases := []struct {
        str      string
        count    int
        expected int
    }{
        {"01011", 2, 3},
        {"01011111|111", 2, 3},
        {"00000", 1, -1},
        {"01011", 4, -1},
    }

    for _, tt := range cases {
        t.Run(fmt.Sprintf("select1(%s,%d)=%d", tt.str, tt.count, tt.expected), func(t *testing.T) {
            v := buildBitVector8(t, tt.str)
            got := v.Select1(tt.count)

            if got != tt.expected {
                t.Fatalf("got: %v, expected: %v", got, tt.expected)
            }
        })
    }
}

func TestBitVector8_Select0(t *testing.T) {
    cases := []struct {
        str      string
        count    int
        expected int
    }{
        {"01011", 2, 2},
        {"01000000|001", 3, 3},
        {"11111", 1, -1},
        {"01011", 4, -1},
    }

    for _, tt := range cases {
        t.Run(fmt.Sprintf("(%s,%d)=%d", tt.str, tt.count, tt.expected), func(t *testing.T) {
            v := buildBitVector8(t, tt.str)
            got := v.Select0(tt.count)

            if got != tt.expected {
                t.Fatalf("got: %v, expected: %v", got, tt.expected)
            }
        })
    }
}

次回は、今回の簡潔ビットベクトルのパッケージを使って、Trie木の検索を行っていく。