diff --git a/src/symbol_table.rs b/src/symbol_table.rs index a35fb66..fc84160 100644 --- a/src/symbol_table.rs +++ b/src/symbol_table.rs @@ -3,16 +3,51 @@ use anyhow::anyhow; use elf::gabi::{STT_FILE, STT_SECTION}; use elf::CachedReadBytes; use indexmap::IndexSet; +use itertools::Itertools; use std::collections::HashMap; +use std::fmt::Display; use std::fs::File; +use thiserror::Error; + +pub struct SymbolDisplayer<'a>(&'a SymbolTable); + +impl Display for SymbolDisplayer<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_fmt(format_args!( + "{}", + self.0 + .symbols + .iter() + .format_with("\n", |(name, symbol), g| { + g(&format_args!("{name}: {symbol}")) + }) + )) + } +} + +#[derive(Debug, Copy, Clone, Error)] +#[error("Invalid symbol table")] +struct InvalidSymbolTable; + +pub struct BreakpointDisplayer<'a>(&'a SymbolTable); + +impl Display for BreakpointDisplayer<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_fmt(format_args!("{}", self.0.breakpoints.iter().format("\n"))) + } +} #[derive(Debug)] pub struct SymbolTable { - pub symbols: HashMap, - pub breakpoints: IndexSet, + symbols: HashMap, + breakpoints: IndexSet, pub active: bool, } +#[derive(Debug, Copy, Clone, Error)] +#[error("Invalid symbol name")] +pub struct InvalidSymbolName; + impl SymbolTable { pub fn new(symbols: HashMap) -> Self { Self { @@ -41,14 +76,14 @@ impl SymbolTable { Ok(Self::new(symbols)) } - pub fn update_symbols(&mut self, symbols: HashMap) { + pub fn update_symbols_from(&mut self, table: Self) { self.breakpoints = self .breakpoints .iter() .cloned() - .filter(|sym| symbols.contains_key(sym)) + .filter(|sym| table.symbols.contains_key(sym)) .collect::>(); - self.symbols = symbols; + self.symbols = table.symbols; } pub fn breakpoint_set_at(&self, addr: u32) -> bool { @@ -57,6 +92,14 @@ impl SymbolTable { .any(|sym| self.symbols[sym].value() == addr) } + pub fn set_breakpoint(&mut self, symbol: String) { + self.breakpoints.insert(symbol); + } + + pub fn delete_breakpoint(&mut self, symbol: &str) -> bool { + self.breakpoints.shift_remove(symbol) + } + pub fn address_to_symbol(&self, addr: u32) -> Option<(&String, u32)> { self.symbols .iter() @@ -64,4 +107,20 @@ impl SymbolTable { .map(|(sym_name, sym)| (sym_name, addr - sym.value())) .min_by_key(|(_, offset)| *offset) } + + pub fn get_symbol(&self, symbol: &str) -> anyhow::Result<&Symbol> { + Ok(self.symbols.get(symbol).ok_or(InvalidSymbolName)?) + } + + pub fn contains_symbol(&self, symbol: &str) -> bool { + self.symbols.contains_key(symbol) + } + + pub fn symbol_displayer(&self) -> SymbolDisplayer<'_> { + SymbolDisplayer(self) + } + + pub fn breakpoint_displayer(&self) -> BreakpointDisplayer { + BreakpointDisplayer(self) + } } diff --git a/src/symbol_tables.rs b/src/symbol_tables.rs index 39f93b8..1d6c51f 100644 --- a/src/symbol_tables.rs +++ b/src/symbol_tables.rs @@ -1,13 +1,15 @@ use std::{fmt::Display, path::Path}; -use crate::{location::Location, symbol::Symbol, symbol_table::SymbolTable}; +use crate::{ + location::Location, + symbol::Symbol, + symbol_table::{InvalidSymbolName, SymbolTable}, +}; use indexmap::IndexMap; use itertools::Itertools; use parse_int::parse; use thiserror::Error; -pub struct SymbolDisplayer<'a>(&'a SymbolTables); - fn displayer_common<'a, F, T: Display>( symbol_tables: &'a SymbolTables, f: &mut std::fmt::Formatter<'_>, @@ -31,13 +33,11 @@ where )) } +pub struct SymbolDisplayer<'a>(&'a SymbolTables); + impl Display for SymbolDisplayer<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - displayer_common(self.0, f, |table| { - table.symbols.iter().format_with("\n", |(name, symbol), g| { - g(&format_args!("{name}: {symbol}")) - }) - }) + displayer_common(self.0, f, |table| table.symbol_displayer()) } } @@ -45,15 +45,11 @@ impl Display for SymbolDisplayer<'_> { #[error("Invalid symbol table")] struct InvalidSymbolTable; -#[derive(Debug, Copy, Clone, Error)] -#[error("Invalid symbol name")] -struct InvalidSymbolName; - pub struct BreakpointDisplayer<'a>(&'a SymbolTables); impl Display for BreakpointDisplayer<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - displayer_common(self.0, f, |table| table.breakpoints.iter().format("\n")) + displayer_common(self.0, f, |table| table.breakpoint_displayer()) } } @@ -97,7 +93,7 @@ impl SymbolTables { let new_table = SymbolTable::read_from_file(path)?; let table_name = Path::new(&path).file_name().unwrap().to_str().unwrap(); if let Some(table) = self.tables.get_mut(table_name) { - table.update_symbols(new_table.symbols); + table.update_symbols_from(new_table); } else { self.tables.insert(table_name.to_string(), new_table); }; @@ -110,12 +106,12 @@ impl SymbolTables { } pub fn set_breakpoint(&mut self, table: &str, symbol: String) -> anyhow::Result<()> { - self.get_table_mut(table)?.breakpoints.insert(symbol); + self.get_table_mut(table)?.set_breakpoint(symbol); Ok(()) } pub fn delete_breakpoint(&mut self, table: &str, symbol: &str) -> anyhow::Result { - Ok(self.get_table_mut(table)?.breakpoints.shift_remove(symbol)) + Ok(self.get_table_mut(table)?.delete_breakpoint(symbol)) } pub fn symbol_displayer(&self) -> SymbolDisplayer<'_> { @@ -131,11 +127,7 @@ impl SymbolTables { } pub fn get(&self, table: &str, symbol: &str) -> anyhow::Result<&Symbol> { - Ok(self - .get_table(table)? - .symbols - .get(symbol) - .ok_or(InvalidSymbolName)?) + self.get_table(table)?.get_symbol(symbol) } pub fn parse_location(&self, location: &str) -> anyhow::Result { @@ -145,16 +137,10 @@ impl SymbolTables { table_name = self .tables .iter() - .find(|(_, table)| table.symbols.contains_key(symbol_name)) + .find(|(_, table)| table.contains_symbol(symbol_name)) .ok_or(InvalidSymbolName)? .0; - } else if !self - .tables - .get(table_name) - .ok_or(InvalidSymbolTable)? - .symbols - .contains_key(symbol_name) - { + } else if !self.get_table(table_name)?.contains_symbol(symbol_name) { return Err(InvalidSymbolName.into()); } Ok(Location::Symbol((