advent_of_code_2021/16/main.cpp
2021-12-16 19:29:01 +01:00

312 lines
7.5 KiB
C++

#include <algorithm>
#include <cctype>
#include <cstdint>
#include <deque>
#include <exception>
#include <fstream>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
/// One node of a decoded BITS transmission: a version, a type id, an optional
/// value (a literal, or a cached evaluation result) and an ordered list of
/// sub-packets.
class Packet {
public:
  Packet(int version, int type) : version_(version), type_(type) {}

  /// Stores `value` and marks the packet as carrying a value.
  void setValue(uint64_t value) {
    value_ = value;
    has_value_ = true;
  }

  /// Appends `p` as the last sub-packet (moved into place).
  void addSubPacket(Packet p) { subs_.push_back(std::move(p)); }

  int getVersion() const { return version_; }

  int getType() const { return type_; }

  /// Stored value; 0 if setValue() was never called.
  uint64_t getValue() const { return value_; }

  /// True when at least one sub-packet has been added.
  bool hasSubs() const { return !subs_.empty(); }

  const std::vector<Packet> &getSubs() const { return subs_; }

  std::vector<Packet> &getSubs() { return subs_; }

  /// True once setValue() has been called.
  bool hasValue() const { return has_value_; }

private:
  int version_;
  int type_;
  uint64_t value_{};
  bool has_value_ = false;
  std::vector<Packet> subs_{};
};
/// Converts one hexadecimal digit to its 4-bit value (0..15).
///
/// Accepts '0'-'9', 'A'-'F' and 'a'-'f'. Any other character (for example a
/// stray '\r' or space) yields 0 instead of an arbitrary out-of-range value.
uint8_t getBitsFromChar(const char &c) {
  if (c >= '0' && c <= '9') {
    return static_cast<uint8_t>(c - '0');
  }
  if (c >= 'A' && c <= 'F') {
    return static_cast<uint8_t>(c - 'A' + 10);
  }
  if (c >= 'a' && c <= 'f') {
    return static_cast<uint8_t>(c - 'a' + 10);
  }
  return 0;
}
/// Reads the first line of `file_name` as a hex string and packs it into
/// bytes, two digits per byte with the first digit in the high nibble.
/// An odd trailing digit fills the high nibble of a final byte whose low
/// nibble is zero. Trailing whitespace (e.g. the '\r' of a CRLF line ending)
/// is ignored.
///
/// @throws std::runtime_error if the file cannot be opened.
std::vector<uint8_t> getBits(const std::string &file_name) {
  std::ifstream file(file_name);
  if (!file) {
    throw std::runtime_error("cannot open input file: " + file_name);
  }
  std::string str;
  std::getline(file, str);
  // Drop trailing whitespace so it is not decoded as hex digits.
  while (!str.empty() &&
         std::isspace(static_cast<unsigned char>(str.back()))) {
    str.pop_back();
  }
  std::vector<uint8_t> bits((str.length() + 1) / 2, 0);
  for (size_t i = 0; i < str.length(); i++) {
    const uint8_t nibble = getBitsFromChar(str[i]);
    // Even-indexed digits land in the high nibble, odd-indexed in the low.
    bits[i / 2] |= (i % 2 == 0) ? static_cast<uint8_t>(nibble << 4) : nibble;
  }
  return bits;
}
/// Returns the bit at absolute position `bitpos` as 0 or 1, where position 0
/// is the most significant bit of bits[0].
///
/// Positions past the end of `bits` read as 0 (padding), so truncated or
/// malformed input cannot cause an out-of-bounds access.
uint8_t getBit(const std::vector<uint8_t> &bits, uint64_t bitpos) {
  const uint64_t index = bitpos / 8;
  if (index >= bits.size()) {
    return 0;
  }
  const unsigned shift = 7u - static_cast<unsigned>(bitpos % 8);
  return static_cast<uint8_t>((bits[index] >> shift) & 1u);
}
/// Decodes a literal value made of 5-bit groups starting at `bitpos`. Each
/// group is a continuation flag (its high bit) followed by 4 value bits,
/// MSB first.
///
/// Decoding stops after the first group whose flag is clear. It also stops
/// when `bitpos` has reached `bitpos + size` before a new group would start.
///
/// @return {bit position just past the last group read, decoded value}
std::pair<uint64_t, uint64_t> getPacketValue(const std::vector<uint8_t> &bits,
                                             uint64_t size, uint64_t bitpos) {
  const uint64_t stop = bitpos + size;
  uint64_t value = 0;
  bool keep_going = bitpos < stop;
  while (keep_going) {
    unsigned chunk = 0;
    for (int k = 0; k < 5; ++k, ++bitpos) {
      chunk = (chunk << 1) | getBit(bits, bitpos);
    }
    value = (value << 4) | (chunk & 0xFu);
    keep_going = (chunk & 0x10u) != 0 && bitpos < stop;
  }
  return std::make_pair(bitpos, value);
}
/// Reads `count` bits MSB-first starting at bit position `pos` and returns
/// them as an unsigned integer (`count` must not exceed 64).
static uint64_t readBitField(const std::vector<uint8_t> &bits, uint64_t pos,
                             int count) {
  uint64_t value = 0;
  for (int i = 0; i < count; i++) {
    value = (value << 1) | getBit(bits, pos + static_cast<uint64_t>(i));
  }
  return value;
}

/// Decodes one packet, and recursively its sub-packets, starting at bit
/// position `start`.
///
/// Layout: a 3-bit version, then a 3-bit type id.
/// - Type 4 is a literal; its value is decoded by getPacketValue().
/// - Any other type is an operator. Bit 6 is its length type id:
///   1 means an 11-bit sub-packet count follows;
///   0 means a 15-bit total length (in bits) of the sub-packets follows.
///
/// @param bits  packed input bits
/// @param start bit position where the packet begins
/// @param size  number of bits available from `start`; bounds how far a
///              literal's value groups may be read
/// @return {bit position just past this packet, decoded packet}
std::pair<uint64_t, Packet> bitsToPacket(const std::vector<uint8_t> &bits,
                                         uint64_t start, uint64_t size) {
  const int version = static_cast<int>(readBitField(bits, start, 3));
  const int type = static_cast<int>(readBitField(bits, start + 3, 3));
  const uint64_t length = start + size;
  Packet result(version, type);
  uint64_t end = 0;
  if (type == 4) {
    // Clamp so a window smaller than the header cannot underflow into a
    // huge read bound.
    const uint64_t payload = size > 6 ? size - 6 : 0;
    auto res = getPacketValue(bits, payload, start + 6);
    result.setValue(res.second);
    end = res.first;
  } else if (getBit(bits, start + 6)) {
    // Length type 1: an 11-bit count of immediate sub-packets.
    const uint64_t count = readBitField(bits, start + 7, 11);
    uint64_t pos = start + 7 + 11;
    for (uint64_t i = 0; i < count; i++) {
      auto res = bitsToPacket(bits, pos, length > pos ? length - pos : 0);
      pos = res.first;
      result.addSubPacket(std::move(res.second));
    }
    end = pos;
  } else {
    // Length type 0: a 15-bit total length, in bits, of the sub-packets.
    const uint64_t sub_length = readBitField(bits, start + 7, 15);
    uint64_t pos = start + 7 + 15;
    const uint64_t sub_end = pos + sub_length;
    while (pos < sub_end) {
      auto res = bitsToPacket(bits, pos, sub_end - pos);
      pos = res.first;
      result.addSubPacket(std::move(res.second));
    }
    end = pos;
  }
  // Move rather than copy: a copy would duplicate the whole sub-tree at every
  // recursion level.
  return {end, std::move(result)};
}
uint64_t versionSum(const Packet &packet) {
uint64_t sum = packet.getVersion();
if (packet.hasSubs()) {
for (auto &sub : packet.getSubs()) {
sum += versionSum(sub);
}
}
return sum;
}
/// Part 1 answer: the total of all version numbers in the packet hierarchy.
uint64_t part1(const Packet &packet) {
  const uint64_t answer = versionSum(packet);
  return answer;
}
/// Evaluates the expression rooted at `packet` and caches the result on each
/// evaluated packet via setValue(). A packet that already has a value
/// (a literal, or one evaluated earlier) returns it directly.
///
/// Operator types:
/// - 0 sum
/// - 1 product
/// - 2 minimum (max uint64_t if there are no sub-packets)
/// - 3 maximum (0 if there are no sub-packets)
/// - 5 greater-than, 6 less-than, 7 equal-to; each yields 1 or 0 and
///   compares the first two sub-packets
///
/// Sum and product wrap on uint64_t overflow.
///
/// @throws std::out_of_range if a comparison packet (types 5-7) has fewer
///         than two sub-packets.
uint64_t calculatePacket(Packet &packet) {
  if (packet.hasValue()) {
    return packet.getValue();
  }
  std::vector<uint64_t> operands;
  operands.reserve(packet.getSubs().size());
  for (Packet &sub : packet.getSubs()) {
    operands.push_back(calculatePacket(sub));
  }
  uint64_t result = 0;
  switch (packet.getType()) {
  case 0:
    for (uint64_t v : operands) {
      result += v;
    }
    break;
  case 1:
    result = 1;
    for (uint64_t v : operands) {
      result *= v;
    }
    break;
  case 2:
    result = operands.empty()
                 ? ~uint64_t{0}
                 : *std::min_element(operands.begin(), operands.end());
    break;
  case 3:
    result = operands.empty()
                 ? uint64_t{0}
                 : *std::max_element(operands.begin(), operands.end());
    break;
  case 4:
    result = packet.getValue();
    break;
  // .at() turns a malformed comparison packet into an exception instead of
  // an out-of-bounds read.
  case 5:
    result = operands.at(0) > operands.at(1) ? 1 : 0;
    break;
  case 6:
    result = operands.at(0) < operands.at(1) ? 1 : 0;
    break;
  case 7:
    result = operands.at(0) == operands.at(1) ? 1 : 0;
    break;
  default:
    break;
  }
  packet.setValue(result);
  return result;
}
/// Part 2 answer: the evaluated value of the outermost packet. Evaluation
/// caches values on the packets, which is why the argument is non-const.
uint64_t part2(Packet &packet) {
  const uint64_t answer = calculatePacket(packet);
  return answer;
}
/// Entry point: decodes the hex transmission in the file named by argv[1] and
/// prints both puzzle answers.
///
/// Returns 1 on a missing argument, empty input, or any error raised while
/// reading or decoding.
int main(int argc, char **argv) {
  if (argc < 2) {
    std::cerr << "You must provide input file!" << std::endl;
    return 1;
  }
  try {
    const auto bits = getBits(argv[1]);
    if (bits.empty()) {
      std::cerr << "Input file contains no packet data!" << std::endl;
      return 1;
    }
    auto packet = bitsToPacket(bits, 0, bits.size() * 8).second;
    std::cout << "Sum of all packet's versions is \033[91;1m" << part1(packet)
              << "\033[0m." << '\n';
    std::cout << "The resulting value of the outermost packet is \033[91;1m"
              << part2(packet) << "\033[0m." << '\n';
  } catch (const std::exception &e) {
    std::cerr << "Error: " << e.what() << std::endl;
    return 1;
  }
  return 0;
}