summaryrefslogtreecommitdiff
path: root/drivers/pe_bliss/pe_checksum.cpp
blob: f6d23f0e1088223d31d657131ddd2ff3a7e9e7e5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#include "pe_checksum.h"
#include "pe_structures.h"
#include "pe_base.h"

namespace pe_bliss
{
using namespace pe_win;

namespace
{
//Captures an istream's exception mask and read position on construction and
//restores both on destruction (leaving the stream's error flags cleared).
//Because restoration happens in the destructor, it runs on every exit path,
//including exceptions of any type.
class istream_state_guard
{
public:
	explicit istream_state_guard(std::istream& stream)
		:stream_(stream), exceptions_(stream.exceptions()), offset_(stream.tellg())
	{}

	~istream_state_guard()
	{
		//Order matters:
		//- seekg() does nothing while failbit is set, so flags are cleared first;
		//- exceptions() throws immediately if a flag it watches is already set,
		//  so the mask is restored only after the flags are cleared again
		//  (seekg() itself may have set failbit).
		stream_.clear();
		stream_.seekg(offset_);
		stream_.clear();
		stream_.exceptions(exceptions_);
	}

private:
	std::istream& stream_;
	std::ios_base::iostate exceptions_;
	std::streamoff offset_;

	//Non-copyable
	istream_state_guard(const istream_state_guard&);
	istream_state_guard& operator=(const istream_state_guard&);
};
}

//Calculate checksum of image
//Computes the PE image checksum over the whole stream: the file is summed as
//little-endian 32-bit words with end-around carry (the bytes of the header
//"CheckSum" field itself are treated as zero), folded to 16 bits, and the file
//size is added. A trailing partial word is zero-padded.
//The stream's read position and exception mask are restored, and its error
//flags cleared, before returning - on success and on failure alike.
//Any exception thrown by pe_base::read_dos_header propagates to the caller.
uint32_t calculate_checksum(std::istream& file)
{
	//Restore istream state on every exit path
	istream_state_guard guard(file);

	//Report read errors through stream flags instead of exceptions:
	//the last DWORD may extend past the end of the file
	file.exceptions(std::ios::goodbit);

	//Read DOS header
	image_dos_header header;
	pe_base::read_dos_header(file, header);

	//"CheckSum" field position in optional PE headers - it's always 64 for PE and PE+
	static const std::streamoff checksum_pos_in_optional_headers = 64;
	//Calculate real PE headers "CheckSum" field position:
	//e_lfanew + 4-byte PE signature + file header + offset inside optional header
	const std::streamoff pe_checksum_pos = static_cast<std::streamoff>(header.e_lfanew)
		+ static_cast<std::streamoff>(sizeof(uint32_t) + sizeof(image_file_header))
		+ checksum_pos_in_optional_headers;

	//Rewind to the beginning of the file
	file.clear();
	file.seekg(0);

	//Total file size (also added to the final checksum)
	const std::streamoff filesize = pe_utils::get_file_size(file);

	//2^32: threshold above which the carry is folded back into the low 32 bits
	const unsigned long long top = static_cast<unsigned long long>(0xFFFFFFFFu) + 1;

	//Checksum value
	unsigned long long checksum = 0;

	//Calculate checksum for each DWORD of file
	for(std::streamoff i = 0; i < filesize; i += 4)
	{
		//Read exactly 4 bytes; bytes beyond the end of the file stay zero
		unsigned char bytes[4] = {0, 0, 0, 0};
		file.read(reinterpret_cast<char*>(bytes), sizeof(bytes));

		//Assemble a little-endian DWORD independently of the host's
		//integer sizes and byte order, treating bytes of the "CheckSum"
		//field as zero (also correct if that field is not DWORD-aligned)
		uint32_t dw = 0;
		for(unsigned int b = 0; b < sizeof(bytes); ++b)
		{
			const std::streamoff pos = i + b;
			if(pos >= pe_checksum_pos && pos < pe_checksum_pos + 4)
				continue;

			dw |= static_cast<uint32_t>(bytes[b]) << (8 * b);
		}

		//Add with end-around carry
		checksum = (checksum & 0xffffffff) + dw + (checksum >> 32);
		if(checksum > top)
			checksum = (checksum & 0xffffffff) + (checksum >> 32);
	}

	//Finish checksum: fold to 16 bits
	checksum = (checksum & 0xffff) + (checksum >> 16);
	checksum = checksum + (checksum >> 16);
	checksum = checksum & 0xffff;

	//Add file length (only the low 32 bits survive the final conversion)
	checksum += static_cast<unsigned long long>(filesize);

	//Return checksum
	return static_cast<uint32_t>(checksum);
}
}