#include <algorithm>
#include <cctype>
#include <cstdint>
#include <fstream>
#include <iostream>
#include <limits>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

/// A decoded packet: a version number, a type id, and either a stored value
/// (literal packets, type 4) or a list of sub-packets (all other types).
class Packet {
public:
    Packet(int version, int type) : _version(version), _type(type) {}

    /// Stores a value and marks the packet as having one. Used for literal
    /// packets by the parser, and by calculatePacket() to cache the result
    /// of evaluating an operator packet.
    void setValue(uint64_t value) {
        _has_value = true;
        _value = value;
    }

    void addSubPacket(Packet p) { _subs.push_back(std::move(p)); }

    int getVersion() const { return _version; }
    int getType() const { return _type; }
    uint64_t getValue() const { return _value; }
    bool hasSubs() const { return !_subs.empty(); }
    const std::vector<Packet> &getSubs() const { return _subs; }
    std::vector<Packet> &getSubs() { return _subs; }
    bool hasValue() const { return _has_value; }

private:
    int _version;
    int _type;
    uint64_t _value{};
    bool _has_value = false;
    std::vector<Packet> _subs{};
};

/// Converts one hexadecimal digit (0-9, A-F, a-f) to its 4-bit value.
/// @throws std::invalid_argument if `c` is not a hexadecimal digit.
uint8_t getBitsFromChar(const char &c) {
    if (c >= '0' && c <= '9') {
        return static_cast<uint8_t>(c - '0');
    }
    if (c >= 'A' && c <= 'F') {
        return static_cast<uint8_t>(c - 'A' + 10);
    }
    if (c >= 'a' && c <= 'f') {
        return static_cast<uint8_t>(c - 'a' + 10);
    }
    throw std::invalid_argument(std::string("not a hexadecimal digit: '") + c + "'");
}

/// Reads the first line of `file_name` as a hex string and packs it into
/// bytes, first digit in the high nibble. An odd trailing digit fills the
/// high nibble of the last byte and leaves its low nibble zero.
/// @throws std::runtime_error if the file cannot be opened.
/// @throws std::invalid_argument if the line contains a non-hex character.
std::vector<uint8_t> getBits(const std::string &file_name) {
    std::ifstream file(file_name);
    if (!file) {
        throw std::runtime_error("cannot open input file: " + file_name);
    }
    std::string str;
    std::getline(file, str);
    // Trailing whitespace (e.g. the '\r' of a CRLF line ending) is not payload.
    while (!str.empty() && std::isspace(static_cast<unsigned char>(str.back()))) {
        str.pop_back();
    }

    std::vector<uint8_t> bits((str.length() + 1) / 2, 0);
    for (size_t i = 0; i < str.length(); i++) {
        const uint8_t nibble = getBitsFromChar(str[i]);
        // Even string positions are a byte's high nibble, odd ones its low nibble.
        bits[i / 2] |= (i % 2 == 0) ? static_cast<uint8_t>(nibble << 4) : nibble;
    }
    return bits;
}

/// Returns bit number `bitpos` of `bits`, where bit 0 is the most
/// significant bit of byte 0.
/// @throws std::out_of_range if `bitpos` lies past the end of `bits`
/// (truncated or malformed input).
uint8_t getBit(const std::vector<uint8_t> &bits, uint64_t bitpos) {
    const uint8_t byte = bits.at(static_cast<size_t>(bitpos / 8));
    return static_cast<uint8_t>((byte >> (7 - bitpos % 8)) & 1u);
}

/// Reads `count` consecutive bits starting at `bitpos` as an unsigned
/// integer, first bit most significant.
static uint64_t readBits(const std::vector<uint8_t> &bits, uint64_t bitpos, int count) {
    uint64_t value = 0;
    for (int i = 0; i < count; i++) {
        value = (value << 1) | getBit(bits, bitpos + static_cast<uint64_t>(i));
    }
    return value;
}

/// Decodes a literal value starting at `bitpos`. The value is a run of
/// 5-bit groups: a continuation flag followed by 4 value bits; a group whose
/// flag is clear is the last one. No new group is started once
/// `bitpos + size` is reached.
/// @return {bit position just after the last group read, decoded value}
std::pair<uint64_t, uint64_t> getPacketValue(const std::vector<uint8_t> &bits, uint64_t size, uint64_t bitpos) {
    uint64_t result = 0;
    const uint64_t length = bitpos + size;
    while (bitpos < length) {
        const uint64_t group = readBits(bits, bitpos, 5);
        bitpos += 5;
        result = (result << 4) | (group & 0b01111);
        if (!(group & 0b10000)) {
            break;
        }
    }
    return {bitpos, result};
}

/// Parses one packet whose header begins at bit `start`; `size` is the
/// number of bits available to it. Header: 3 version bits, 3 type bits.
/// Type 4 is followed by a literal value. Any other type is followed by one
/// bit selecting how its sub-packets are delimited: set means an 11-bit
/// sub-packet count, clear means a 15-bit total sub-packet length in bits.
/// @return {bit position just after the packet, parsed packet}
/// @throws std::out_of_range if parsing runs past the end of `bits`.
std::pair<uint64_t, Packet> bitsToPacket(const std::vector<uint8_t> &bits, uint64_t start, uint64_t size) {
    const int version = static_cast<int>(readBits(bits, start, 3));
    const int type = static_cast<int>(readBits(bits, start + 3, 3));
    const uint64_t length = start + size;
    Packet result(version, type);

    if (type == 4) {
        auto res = getPacketValue(bits, size - 6, start + 6);
        result.setValue(res.second);
        return {res.first, std::move(result)};
    }

    uint64_t pos = start + 7;
    if (getBit(bits, start + 6)) {
        // Delimited by count: parse exactly `count` sub-packets back to back.
        const uint64_t count = readBits(bits, pos, 11);
        pos += 11;
        for (uint64_t i = 0; i < count; i++) {
            auto res = bitsToPacket(bits, pos, length - pos);
            pos = res.first;
            result.addSubPacket(std::move(res.second));
        }
    } else {
        // Delimited by length: parse sub-packets until `sub_length` bits are consumed.
        const uint64_t sub_length = readBits(bits, pos, 15);
        pos += 15;
        const uint64_t sub_end = pos + sub_length;
        while (pos < sub_end) {
            auto res = bitsToPacket(bits, pos, sub_end - pos);
            pos = res.first;
            result.addSubPacket(std::move(res.second));
        }
    }
    return {pos, std::move(result)};
}

/// Sum of the version numbers of `packet` and all of its nested sub-packets.
uint64_t versionSum(const Packet &packet) {
    uint64_t sum = static_cast<uint64_t>(packet.getVersion());
    for (const auto &sub : packet.getSubs()) {
        sum += versionSum(sub);
    }
    return sum;
}

uint64_t part1(const Packet &packet) { return versionSum(packet); }

/// Evaluates a packet. A packet that already has a value returns it.
/// Operator types: 0 sum, 1 product, 2 minimum, 3 maximum; 5 greater-than,
/// 6 less-than, 7 equal-to compare the first two sub-packets and yield 1 or 0.
/// The result is cached on each evaluated packet via setValue(), so after
/// evaluation those packets report hasValue() == true.
/// @throws std::out_of_range if a comparison packet has fewer than two sub-packets.
uint64_t calculatePacket(Packet &packet) {
    if (packet.hasValue()) {
        return packet.getValue();
    }
    auto &subs = packet.getSubs();
    uint64_t result = 0;
    switch (packet.getType()) {
    case 0:
        result = 0;
        for (auto &sub : subs) {
            result += calculatePacket(sub);
        }
        break;
    case 1:
        result = 1;
        for (auto &sub : subs) {
            result *= calculatePacket(sub);
        }
        break;
    case 2:
        result = std::numeric_limits<uint64_t>::max();
        for (auto &sub : subs) {
            result = std::min(result, calculatePacket(sub));
        }
        break;
    case 3:
        result = 0;
        for (auto &sub : subs) {
            result = std::max(result, calculatePacket(sub));
        }
        break;
    case 4:
        result = packet.getValue();
        break;
    case 5:
        result = calculatePacket(subs.at(0)) > calculatePacket(subs.at(1)) ? 1 : 0;
        break;
    case 6:
        result = calculatePacket(subs.at(0)) < calculatePacket(subs.at(1)) ? 1 : 0;
        break;
    case 7:
        result = calculatePacket(subs.at(0)) == calculatePacket(subs.at(1)) ? 1 : 0;
        break;
    default:
        break;
    }
    packet.setValue(result);
    return result;
}

uint64_t part2(Packet &packet) { return calculatePacket(packet); }

int main(int argc, char **argv) {
    if (argc < 2) {
        std::cerr << "You must provide input file!" << std::endl;
        return 1;
    }
    try {
        auto bits = getBits(argv[1]);
        auto packet = bitsToPacket(bits, 0, bits.size() * 8).second;
        std::cout << "Sum of all packet's versions is \033[91;1m" << part1(packet) << "\033[0m." << '\n';
        std::cout << "The resulting value of the outermost packet is \033[91;1m" << part2(packet) << "\033[0m." << '\n';
    } catch (const std::exception &e) {
        // Unreadable file, non-hex input, or a packet stream that runs past the data.
        std::cerr << "Error: " << e.what() << '\n';
        return 1;
    }
    return 0;
}