-
Notifications
You must be signed in to change notification settings - Fork 0
/
comb.hpp
101 lines (89 loc) · 2.4 KB
/
comb.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
//------------------------------------------------------------------------------
//
// This file enumerates all combinations of t elements chosen from 0 .. n-1.
// usage (combinations of 3 from 0..4):
//
// all_comb ac(5, 3);
// do {
// for (auto it = ac.begin(); it != ac.end(); ++it)
// std::cout << *it << " ";
// std::cout << std::endl;
// } while (ac.next_comb());
//
// output:
// 0 1 2
// 0 2 3
// ... etc ...
//
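// Combinations are produced in revolving-door order: each one differs from the
// previous one by exactly one element. Details may be found in Knuth, TAOCP
// vol. 4A, section 7.2.1.3, Algorithm R.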
//
//------------------------------------------------------------------------------
//
// This file is licensed under the LGPL v3
// Look at: https://www.gnu.org/licenses/lgpl-3.0.en.html for details
//
//------------------------------------------------------------------------------
#pragma once
#include <cassert>
#include <iterator>
#include <numeric>
#include <vector>
/// Enumerates all t-element combinations of {0, ..., n-1} using Knuth's
/// Algorithm R (revolving-door order, TAOCP vol. 4A, 7.2.1.3).
///
/// The current combination is exposed through [begin(), end()) in increasing
/// order. Internally combination_ holds t + 1 entries: combination_[k] is
/// Knuth's c_{k+1} for k < t, and combination_[t] is the sentinel
/// c_{t+1} = n, which end() deliberately excludes.
class all_comb {
  std::vector<int> combination_;  // c_1 < ... < c_t, followed by sentinel n
  int total_;                     // n, the size of the ground set

public:
  /// Iterators over the current combination (the sentinel is excluded).
  auto begin() { return combination_.begin(); }
  auto end() { return std::prev(combination_.end()); }
  auto begin() const { return combination_.begin(); }
  auto end() const { return std::prev(combination_.end()); }

  /// @param n size of the ground set {0, ..., n-1}; must be > 0.
  /// @param t number of elements per combination; must satisfy 0 < t <= n.
  /// The object starts at the first combination 0, 1, ..., t-1.
  all_comb(int n, int t) : combination_(t + 1), total_(n) {
    assert(n > 0 && "Makes no sense to combine things out of nothing");
    assert(t > 0 && "Makes no sense to combine zero number of things");
    assert(n >= t && "Cannot choose more than n things out of n");
    reinit();
  }

  /// Resets to the first combination 0, 1, ..., t-1 and restores the
  /// sentinel n in the last slot.
  void reinit() {
    std::iota(combination_.begin(), combination_.end(), 0);
    combination_.back() = total_;
  }

  /// Advances to the next combination.
  /// @return true if a new combination is now current; false when every
  ///         combination has been visited, in which case the object has been
  ///         reset to the first combination (see reinit()).
  bool next_comb() {
    const int t = static_cast<int>(combination_.size()) - 1;
    // t == 0 (only possible when asserts are disabled) has exactly one,
    // empty, combination; without this guard the code below would modify
    // the sentinel.
    if (t < 1)
      return false;

    bool skip_r4 = false;
    // Step R3 [easy case?]
    if (t % 2 == 1) {
      if (combination_[0] + 1 < combination_[1]) {
        combination_[0] += 1;
        return true;
      }
      if (t == 1) {
        // Algorithm R is stated for t > 1: with t == 1, step R4 at j = 2
        // would overwrite the sentinel. The single element has reached
        // n-1, so the enumeration is complete.
        reinit();
        return false;
      }
    } else {
      if (combination_[0] > 0) {
        combination_[0] -= 1;
        return true;
      }
      skip_r4 = true;  // even t goes straight to R5 with j = 2
    }

    // j is Knuth's 1-based index: c_j lives in combination_[j - 1].
    int j = 2;
    for (;;) {
      // Step R4 [try to decrease c_j]; here c_j == c_{j-1} + 1.
      if (!skip_r4) {
        assert(combination_[j - 1] == combination_[j - 2] + 1);
        if (combination_[j - 1] >= j) {
          combination_[j - 1] = combination_[j - 2];
          combination_[j - 2] = j - 2;
          break;
        }
        j += 1;  // continue with R5 at the next index
      }
      // Step R5 [try to increase c_j]; here c_{j-1} == j - 2.
      assert(combination_[j - 2] == j - 2);
      if (combination_[j - 1] + 1 < combination_[j]) {
        combination_[j - 2] = combination_[j - 1];
        combination_[j - 1] += 1;
        break;
      }
      // Step R6 [go on]: once j passes t every combination has been visited.
      j += 1;
      if (j > t) {
        reinit();
        return false;
      }
      skip_r4 = false;
    }
    return true;
  }
};