306 lines
13 KiB
Python
306 lines
13 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
|
|
"""
|
|
Module IntegrityRoutine: contains the IntegrityRoutine class, which helps with the FIPS 140-2 build-time integrity routine.
|
|
This module calculates the HMAC over the covered address ranges and embeds it, together with the supporting data, into the ELF file.
|
|
"""
|
|
|
|
import hmac
|
|
import hashlib
|
|
import bisect
|
|
import itertools
|
|
import binascii
|
|
from ELF import *
|
|
|
|
__author__ = "Vadym Stupakov"
|
|
__copyright__ = "Copyright (c) 2017 Samsung Electronics"
|
|
__credits__ = ["Vadym Stupakov"]
|
|
__version__ = "1.0"
|
|
__maintainer__ = "Vadym Stupakov"
|
|
__email__ = "v.stupakov@samsung.com"
|
|
__status__ = "Production"
|
|
|
|
|
|
class IntegrityRoutine(ELF):
    """
    Utils for the FIPS build-time integrity process.

    Builds on the ELF base class (symbol, section and relocation lookups) to
    compute an HMAC-SHA256 over selected address ranges of the ELF file and to
    write the result, plus the covered address table, back into that file.
    """

    # Maximum number of bytes read from the ELF file per HMAC update.
    _HMAC_READ_CHUNK = 64 * 1024

    def __init__(self, elf_file, readelf_path="readelf"):
        """
        :param elf_file: forwarded unchanged to ELF.__init__
        :param readelf_path: forwarded unchanged to ELF.__init__
        """
        ELF.__init__(self, elf_file, readelf_path)

    @staticmethod
    def __remove_all_dublicates(lst):
        """
        Removes, in place, every value that occurs more than once in the list.
        For instance: transforms [1, 2, 3, 1] -> [2, 3]
        The relative order of the surviving values is preserved.
        :param lst: input list of hashable values (modified in place)
        :return: the same list object, w/o duplicated values
        """
        # Count first, then filter: this inspects every element (including the
        # last one) and copes with values that appear three or more times.
        occurrences = dict()
        for value in lst:
            occurrences[value] = occurrences.get(value, 0) + 1
        lst[:] = [value for value in lst if occurrences[value] == 1]
        return lst

    def get_reloc_gaps(self, start_addr, end_addr):
        """
        Builds the list of address gaps occupied by relocations.
        :param start_addr: start address :int
        :param end_addr: end address: int
        :returns list of relocation gaps like [[gap_start, gap_end], [gap_start, gap_end], ...]
        """
        boundaries = list()
        for addr in self.get_relocs(start_addr, end_addr):
            # Each relocation contributes its start and start + 8 as boundaries.
            # NOTE(review): 8 is presumably the size of one relocated slot
            # (64-bit pointer) - confirm against the target architecture.
            boundaries.extend((addr, addr + 8))
        # A boundary shared by two back-to-back relocations appears twice;
        # dropping both copies merges the neighbours into a single gap.
        self.__remove_all_dublicates(boundaries)
        boundaries.sort()
        return [[gap_start, gap_end] for gap_start, gap_end in self.utils.pairwise(boundaries)]

    def get_addrs_for_hmac(self, sec_sym_sequence, relocs_gaps=None):
        """
        Generate addresses for calculating HMAC
        :param sec_sym_sequence: {section_name: [sym_name1, sym_name2, ...]}
        :param relocs_gaps: [[start_gap_addr, end_gap_addr], [start_gap_addr, end_gap_addr]] or None
        :return: addresses for calculating HMAC: [[addr_start, addr_end], [addr_start, addr_end], ...]
        """
        addrs_for_hmac = list()
        # Symbol addresses are collected the same way for every section.
        for sym_names in sec_sym_sequence.values():
            for symbol in self.get_symbol_by_name(sym_names):
                addrs_for_hmac.append(symbol.addr)
        # relocs_gaps defaults to None; only mix gap boundaries in when given.
        if relocs_gaps is not None:
            addrs_for_hmac.extend(self.utils.flatten(relocs_gaps))
        addrs_for_hmac.sort()
        return [[addr_start, addr_end] for addr_start, addr_end in self.utils.pairwise(addrs_for_hmac)]

    def embed_bytes(self, vaddr, in_bytes):
        """
        Write bytes to ELF file
        :param vaddr: virtual address in ELF
        :param in_bytes: byte array to write
        """
        offset = self.vaddr_to_file_offset(vaddr)
        # "rb+" overwrites in place without truncating the file.
        with open(self.get_elf_file(), "rb+") as elf_file:
            elf_file.seek(offset)
            elf_file.write(in_bytes)

    def __update_hmac(self, hmac_obj, file_obj, file_offset_start, file_offset_end):
        """
        Update hmac with the bytes between two addresses.
        Both bounds are converted with self.utils.to_int; the start is mapped
        through vaddr_to_file_offset before seeking. Data is fed to the HMAC in
        chunks of at most _HMAC_READ_CHUNK bytes.
        :param hmac_obj: object with an update(bytes) method
        :param file_obj: ELF file opened for binary reading
        :param file_offset_start: could be string or int
        :param file_offset_end: could be string or int
        :raises ValueError: if the end address is below the start address
        """
        range_start = self.utils.to_int(file_offset_start)
        range_end = self.utils.to_int(file_offset_end)
        remaining = range_end - range_start
        if remaining < 0:
            # read() with a negative size would consume the file up to EOF and
            # silently hash unrelated bytes.
            raise ValueError("Invalid HMAC range: end {} is below start {}".format(range_end, range_start))

        file_obj.seek(self.vaddr_to_file_offset(range_start))
        while remaining > 0:
            chunk = file_obj.read(min(remaining, self._HMAC_READ_CHUNK))
            if not chunk:
                # End of file reached before the range end: hash what exists.
                break
            hmac_obj.update(chunk)
            remaining -= len(chunk)

    def get_hmac(self, offset_sequence, key, output_type="byte"):
        """
        Calculate HMAC (SHA-256) over the given address ranges of the ELF file
        :param offset_sequence: start and end addresses sequence [addr_start, addr_end], [addr_start, addr_end], ...]
        :param key HMAC key: string value
        :param output_type string value. Could be "hex" or "byte"
        :return: bytearray or hex string
        :raises ValueError: if output_type is neither "byte" nor "hex"
        """
        if output_type not in ("byte", "hex"):
            raise ValueError("Invalid output type! Could be \"hex\" or \"byte\"")

        digest = hmac.new(bytearray(key.encode("utf-8")), digestmod=hashlib.sha256)
        with open(self.get_elf_file(), "rb") as elf_fp:
            for addr_start, addr_end in offset_sequence:
                self.__update_hmac(digest, elf_fp, addr_start, addr_end)

        if output_type == "byte":
            return digest.digest()
        return digest.hexdigest()

    def __find_nearest_symbol_by_vaddr(self, vaddr, method):
        """
        Find nearest symbol near vaddr
        :param vaddr: address that must resolve to a symbol
        :param method: bisect-style callable taking (list, value)
        :return: idx of symbol from self.get_symbols()
        :raises ValueError: if no symbol is found at vaddr
        """
        if self.get_symbol_by_vaddr(vaddr) is None:
            raise ValueError("Can't find symbol by vaddr")
        # NOTE(review): bisect needs list(self.get_symbols()) to be sorted
        # ascending - confirm that ELF guarantees this ordering.
        return method(list(self.get_symbols()), vaddr)

    def find_rnearest_symbol_by_vaddr(self, vaddr):
        """
        Find right nearest symbol near vaddr
        :param vaddr:
        :return: idx of symbol from self.get_symbols()
        """
        return self.__find_nearest_symbol_by_vaddr(vaddr, bisect.bisect_right)

    def find_lnearest_symbol_by_vaddr(self, vaddr):
        """
        Find left nearest symbol near vaddr
        :param vaddr:
        :return: idx of symbol from self.get_symbols()
        """
        return self.__find_nearest_symbol_by_vaddr(vaddr, bisect.bisect_left)

    def find_symbols_between_vaddrs(self, vaddr_start, vaddr_end):
        """
        Returns list of symbols between two virtual addresses
        (start inclusive, end exclusive)
        :param vaddr_start:
        :param vaddr_end:
        :return: [(Symbol(), Section)]
        :raises ValueError: if either address does not resolve to a symbol
        """
        symbol_start = self.get_symbol_by_vaddr(vaddr_start)
        symbol_end = self.get_symbol_by_vaddr(vaddr_end)
        if symbol_start is None or symbol_end is None:
            raise ValueError("Error: Cannot find symbol by vaddr. vaddr should coincide with symbol address!")

        idx_start = self.find_lnearest_symbol_by_vaddr(vaddr_start)
        idx_end = self.find_lnearest_symbol_by_vaddr(vaddr_end)

        # Materialise the symbol addresses once instead of once per index.
        symbol_addrs = list(self.get_symbols())

        sym_sec = list()
        for idx in range(idx_start, idx_end):
            symbol_addr = symbol_addrs[idx]
            symbol = self.get_symbol_by_vaddr(symbol_addr)
            section = self.get_section_by_vaddr(symbol_addr)
            sym_sec.append((symbol, section))

        sym_sec.sort(key=lambda pair: pair[0])
        return sym_sec

    @staticmethod
    def __get_skipped_bytes(symbol, relocs):
        """
        Sums up the relocation gaps that lie entirely inside the symbol.
        :param symbol: Symbol()
        :param relocs: [[start1, end1], [start2, end2]]
        :return: Returns skipped bytes and [[start, end]] addresses that show which bytes were skipped
        """
        symbol_start_addr = symbol.addr
        symbol_end_addr = symbol.addr + symbol.size
        skipped_bytes = 0
        reloc_addrs = list()
        for reloc_start, reloc_end in relocs:
            if symbol_start_addr <= reloc_start and reloc_end <= symbol_end_addr:
                skipped_bytes += reloc_end - reloc_start
                reloc_addrs.append([reloc_start, reloc_end])
            # Stop at the first gap starting beyond the symbol.
            # NOTE(review): this early exit assumes relocs is sorted ascending.
            if reloc_start > symbol_end_addr:
                break

        return skipped_bytes, reloc_addrs

    def print_covered_info(self, sec_sym, relocs, print_reloc_addrs=False, sort_by="address", reverse=False):
        """
        Prints information about covered symbols in detailed table:
        |N| symbol name | symbol address | symbol section | bytes skipped | skipped bytes address range |
        |1| symbol | 0xXXXXXXXXXXXXXXXX | .rodata | 8 | [[addr1, addr2], [addr1, addr2]] |
        followed by three summary lines (symbol covered, all covered, skipped).
        :param sec_sym: {section_name : [sym_name1, sym_name2]}
        :param relocs: [[start1, end1], [start2, end2]]
        :param print_reloc_addrs: print or not skipped bytes address range
        :param sort_by: method for sorting table. Could be: "address", "name", "section"
        :param reverse: sort order
        :raises ValueError: if sort_by is not one of the supported values
        """
        # Rows are (symbol, section, skipped_bytes, reloc_addrs_str) tuples.
        sort_methods = {
            "address": lambda row: row[0].addr,
            "name": lambda row: row[0].name,
            "section": lambda row: row[1].name,
        }
        sort_method = sort_methods.get(sort_by.lower())
        if sort_method is None:
            raise ValueError("Invalid sort type!")

        with_reloc_addrs = print_reloc_addrs is True
        table_format = "|{:4}| {:50} | {:18} | {:20} | {:15} |"
        if with_reloc_addrs:
            table_format += "{:32} |"

        # str.format ignores the surplus last argument when the extra column
        # is not part of the format.
        print(table_format.format("N", "symbol name", "symbol address", "symbol section", "bytes skipped",
                                  "skipped bytes address range"))

        data_to_print = list()
        for sym_names in sec_sym.values():
            for symbol_start, symbol_end in self.utils.pairwise(self.get_symbol_by_name(sym_names)):
                symbol_sec_in_range = self.find_symbols_between_vaddrs(symbol_start.addr, symbol_end.addr)
                for symbol, section in symbol_sec_in_range:
                    # Zero-sized symbols are left out of the table.
                    if symbol.size <= 0:
                        continue
                    skipped_bytes, reloc_addrs = self.__get_skipped_bytes(symbol, relocs)
                    reloc_addrs_str = "[{}]".format(", ".join(
                        "[{}, {}]".format(hex(start_addr), hex(end_addr))
                        for start_addr, end_addr in reloc_addrs))
                    data_to_print.append((symbol, section, skipped_bytes, reloc_addrs_str))

        skipped_bytes_size = 0
        symbol_covered_size = 0
        data_to_print.sort(key=sort_method, reverse=reverse)
        for cnt, (symbol, section, skipped_bytes, reloc_addrs_str) in enumerate(data_to_print, start=1):
            symbol_covered_size += symbol.size
            skipped_bytes_size += skipped_bytes
            row = [cnt, symbol.name, hex(symbol.addr), section.name, self.utils.human_size(skipped_bytes)]
            if with_reloc_addrs:
                row.append(reloc_addrs_str)
            print(table_format.format(*row))

        addrs_for_hmac = self.get_addrs_for_hmac(sec_sym, relocs)
        all_covered_size = sum(addr_end - addr_start for addr_start, addr_end in addrs_for_hmac)
        print("Symbol covered bytes len: {} ".format(self.utils.human_size(symbol_covered_size - skipped_bytes_size)))
        print("All covered bytes len : {} ".format(self.utils.human_size(all_covered_size)))
        print("Skipped bytes len : {} ".format(self.utils.human_size(skipped_bytes_size)))

    def dump_covered_bytes(self, vaddr_seq, out_file):
        """
        Dumps covered bytes
        :param vaddr_seq: [[start1, end1], [start2, end2]] start - end sequence of covered bytes
        :param out_file: file where will be stored dumped bytes
        """
        with open(self.get_elf_file(), "rb") as elf_fp:
            with open(out_file, "wb") as out_fp:
                for vaddr_start, vaddr_end in vaddr_seq:
                    elf_fp.seek(self.vaddr_to_file_offset(vaddr_start))
                    out_fp.write(elf_fp.read(vaddr_end - vaddr_start))

    def make_integrity(self, sec_sym, module_name, debug=False, print_reloc_addrs=False, sort_by="address",
                       reverse=False):
        """
        Calculate HMAC and embed needed info
        :param sec_sym: {sec_name: [addr1, addr2, ..., addrN]}
        :param module_name: module name that you want to make integrity. See Makefile targets
        :param debug: If True prints debug information
        :param print_reloc_addrs: If True, print relocation addresses that are skipped
        :param sort_by: sort method
        :param reverse: sort order
        """
        rel_addr_start = self.get_symbol_by_name("first_" + module_name + "_rodata")
        rel_addr_end = self.get_symbol_by_name("last_" + module_name + "_rodata")

        reloc_gaps = self.get_reloc_gaps(rel_addr_start.addr, rel_addr_end.addr)
        addrs_for_hmac = self.get_addrs_for_hmac(sec_sym, reloc_gaps)

        # NOTE(review): the HMAC key is a fixed string; presumably the runtime
        # self-test uses the very same key - confirm before changing it.
        digest = self.get_hmac(addrs_for_hmac, "The quick brown fox jumps over the lazy dog")

        # Embed the digest, the covered address table and the build-time
        # address at the locations of their dedicated symbols.
        self.embed_bytes(self.get_symbol_by_name("builtime_" + module_name + "_hmac").addr,
                         self.utils.to_bytearray(digest))

        self.embed_bytes(self.get_symbol_by_name("integrity_" + module_name + "_addrs").addr,
                         self.utils.to_bytearray(addrs_for_hmac))

        buildtime_addr = self.get_symbol_by_name(module_name + "_buildtime_address").addr
        self.embed_bytes(buildtime_addr, self.utils.to_bytearray(buildtime_addr))

        print("HMAC for \"{}\" module is: {}".format(module_name, binascii.hexlify(digest)))
        if debug:
            self.print_covered_info(sec_sym, reloc_gaps, print_reloc_addrs=print_reloc_addrs, sort_by=sort_by,
                                    reverse=reverse)
            self.dump_covered_bytes(addrs_for_hmac, "covered_dump_for_" + module_name + ".bin")

        print("FIPS integrity procedure has been finished for {}".format(module_name))
|