diff --git a/src/ata_command.rs b/src/ata_command.rs new file mode 100644 index 0000000..a4913a6 --- /dev/null +++ b/src/ata_command.rs @@ -0,0 +1,117 @@ +use std::io::Cursor; + +use binread::BinRead; + +use crate::{ + ahci_structs::{DeviceRegsWriteBuilder, RegH2DFis, RegH2DFisBuilder}, + identify::{IdentifyData, IdentifyPacketData}, +}; + +#[allow(clippy::module_name_repetitions)] +pub trait AtaCommandDataIn { + type ProcessedData<'a>; + fn get_fis(&self) -> RegH2DFis; + fn data_len(&self) -> usize; + fn process_data<'a>(&self, data: &'a [u8]) -> Self::ProcessedData<'a>; +} + +pub struct IdentifyCommand; + +impl AtaCommandDataIn for IdentifyCommand { + type ProcessedData<'a> = IdentifyData; + + fn get_fis(&self) -> RegH2DFis { + RegH2DFisBuilder::default() + .cmd(true) + .regs( + DeviceRegsWriteBuilder::default() + .commmad(0xEC) + .build() + .unwrap(), + ) + .build() + .unwrap() + } + + fn data_len(&self) -> usize { + 512 + } + + fn process_data(&self, data: &[u8]) -> Self::ProcessedData<'_> { + IdentifyData::read(&mut Cursor::new(data)).unwrap() + } +} + +pub struct IdentifyPacketCommand; + +impl AtaCommandDataIn for IdentifyPacketCommand { + type ProcessedData<'a> = IdentifyPacketData; + + fn get_fis(&self) -> RegH2DFis { + RegH2DFisBuilder::default() + .cmd(true) + .regs( + DeviceRegsWriteBuilder::default() + .commmad(0xA1) + .build() + .unwrap(), + ) + .build() + .unwrap() + } + + fn data_len(&self) -> usize { + 512 + } + + fn process_data(&self, data: &[u8]) -> Self::ProcessedData<'_> { + IdentifyPacketData::read(&mut Cursor::new(data)).unwrap() + } +} + +pub struct ReadDmaExtCommand { + lba: u64, + count: u16, +} + +impl ReadDmaExtCommand { + pub fn new(lba: u64, count: u16) -> Self { + Self { lba, count } + } +} + +impl AtaCommandDataIn for ReadDmaExtCommand { + type ProcessedData<'a> = &'a [u8]; + + fn get_fis(&self) -> RegH2DFis { + let lba_bytes = self.lba.to_le_bytes(); + + RegH2DFisBuilder::default() + .cmd(true) + .regs( + DeviceRegsWriteBuilder::default() + .commmad(0x25) + .lba(true) + .countl(self.count as u8) + .counth((self.count >> 8) as u8) + .lba0(lba_bytes[0]) + .lba1(lba_bytes[1]) + .lba2(lba_bytes[2]) + .lba3(lba_bytes[3]) + .lba4(lba_bytes[4]) + .lba5(lba_bytes[5]) + .build() + .unwrap(), + ) + .build() + .unwrap() + } + + fn data_len(&self) -> usize { + (self.count * 512) as usize + } + + fn process_data<'a>(&self, data: &'a [u8]) -> Self::ProcessedData<'a> { + data + } +} diff --git a/src/ata_dev.rs b/src/ata_dev.rs new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/src/ata_dev.rs @@ -0,0 +1 @@ + diff --git a/src/hba.rs b/src/hba.rs new file mode 100644 index 0000000..9725864 --- /dev/null +++ b/src/hba.rs @@ -0,0 +1,127 @@ +use std::os::mikros::address_space::ACTIVE_SPACE; + +use crate::{ + ahci_structs::{CommandHeader, FisBuf, GenHC, PortRegs, CAP, GHC}, + port::AhciPort, +}; + +pub struct Hba { + regs: &'static GenHC, + ports: Vec, + syslog_client: syslog_rpc::Client, +} + +#[derive(Clone, Copy, Debug)] +pub enum HBAInitError { + No64BitDma, +} + +impl Hba { + pub unsafe fn new( + reg_base: *mut u8, + syslog_client: syslog_rpc::Client, + ) -> Result { + let regs = unsafe { &*(reg_base as *const GenHC) }; + + let mut hba = Self { + regs, + ports: Vec::new(), + syslog_client, + }; + let supports_64bit_dma = regs.CAP.read(CAP::S64A) > 0; + if supports_64bit_dma { + syslog_client + .send_text_message("ahci", "HBA supports 64bit DMA".to_string()) + .unwrap(); + } else { + syslog_client + .send_text_message("ahci", "HBA does not support 64bit DMA".to_string()) + .unwrap(); + syslog_client + .send_text_message( + "ahci", + "Aborting, there is no way to ensure buffers from the OS are in the low 4G." + .to_string(), + ) + .unwrap(); + return Err(HBAInitError::No64BitDma); + } + hba.reset(); + let num_raw_ports = (regs.CAP.read(CAP::NP) + 1) as usize; + let port_reg_base = unsafe { reg_base.add(0x100).cast::() }; + let num_ports = (regs.PI.get().count_ones()) as usize; + if num_ports == num_raw_ports { + syslog_client + .send_text_message("ahci", format!("HBA has {num_ports} ports")) + .unwrap(); + } else { + syslog_client + .send_text_message( + "ahci", + format!("HBA has {num_raw_ports} ports, but only {num_ports} are usable",), + ) + .unwrap(); + } + + let max_cmd_slots = (regs.CAP.read(CAP::NCS) + 1) as usize; + + let fis_buf_mem = num_ports * 256; + let comm_list_mem_per_port = max_cmd_slots * 32; + let comm_list_mem = comm_list_mem_per_port * num_ports; + let total_mem = fis_buf_mem + comm_list_mem; + + let (buf, buf_phys) = ACTIVE_SPACE + .lock() + .unwrap() + .map_free_cont_phys(total_mem.div_ceil(4096)) + .unwrap(); + unsafe { + buf.write_bytes(0, total_mem); + } + + let ports_raw = regs.PI.get(); + + for port_idx in 0..32 { + if (ports_raw & (1 << port_idx)) == 0 { + continue; + } + + let regs = unsafe { &*(port_reg_base.add(port_idx)) }; + let cmd_list_phys = buf_phys + (comm_list_mem_per_port * hba.ports.len()) as u64; + let cmd_list = unsafe { + buf.add(comm_list_mem_per_port * hba.ports.len()) + .cast::() + }; + let fis_buf_phys = buf_phys + (comm_list_mem + 256 * hba.ports.len()) as u64; + let fis_buf = unsafe { FisBuf::new(buf.add(comm_list_mem + 256 * hba.ports.len())) }; + + hba.ports.push(AhciPort::new( + regs, + port_idx, + fis_buf, + fis_buf_phys, + cmd_list, + cmd_list_phys, + max_cmd_slots, + )); + } + Ok(hba) + } + + fn reset(&self) { + self.syslog_client + .send_text_message("ahci", "Resetting HBA".to_string()) + .unwrap(); + self.regs.GHC.modify(GHC::AE::SET); + self.regs.GHC.modify(GHC::HR::SET); + while self.regs.GHC.read(GHC::HR) > 0 {} + self.syslog_client + .send_text_message("ahci", "Reset HBA".to_string()) + .unwrap(); + self.regs.GHC.modify(GHC::AE::SET); + } + + pub fn ports(&self) -> &[AhciPort] { + &self.ports + } +} diff --git a/src/identify.rs b/src/identify.rs index 5713a4d..4403e20 100644 --- a/src/identify.rs +++ b/src/identify.rs @@ -162,7 +162,7 @@ pub struct IdentifyData { #[derive(Clone, Debug, BinRead)] #[allow(unused)] -pub struct IdentifyDataATAPI { +pub struct IdentifyPacketData { #[br(seek_before = SeekFrom::Start(0))] pub gen_cfg: u16, #[br(seek_before = SeekFrom::Start(2*2))] diff --git a/src/main.rs b/src/main.rs index c996e00..7e8a8a9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -20,318 +20,22 @@ #![allow(clippy::tuple_array_conversions)] mod ahci_structs; +mod ata_command; +mod ata_dev; +mod hba; mod identify; +mod port; -use std::{ - io::Cursor, - os::mikros::{address_space::ACTIVE_SPACE, syscalls}, - ptr::NonNull, -}; +use std::os::mikros::{address_space::ACTIVE_SPACE, syscalls}; -use ahci_structs::{ - CommandHeader, CommandTableHeader, DeviceRegsWriteBuilder, FisBuf, GenHC, PortRegs, Prd, PxCMD, - PxIS, PxSERR, PxSSTS, PxTFD, RegH2DFis, RegH2DFisBuilder, CAP, GHC, -}; -use binread::BinRead; -use identify::{IdentifyData, IdentifyDataATAPI}; +use ata_command::{IdentifyCommand, IdentifyPacketCommand, ReadDmaExtCommand}; +use hba::Hba; use itertools::Itertools; use uuid::Uuid; -use volatile::VolatilePtr; use x86_64::structures::paging::PageTableFlags; use std::fmt::Debug; -struct AhciPort { - regs: &'static PortRegs, - phys_no: usize, - fis_buf: FisBuf, - cmd_list: *mut CommandHeader, - #[allow(dead_code)] - cmd_list_len: usize, - has_device: bool, -} - -#[derive(Copy, Clone, Debug)] -enum CommandIssueError { - PrdtTooBig, - DataTooBig, - CommandFailed, - DataTooSmall, -} - -impl AhciPort { - - unsafe fn issue_command_prdt( - &self, - fis: &RegH2DFis, - prdt: &[Prd], - ) -> Result<(), CommandIssueError> { - if prdt.len() > 65535 { - return Err(CommandIssueError::PrdtTooBig); - } - - let mut tbl_hdr = CommandTableHeader { - cfis: [0; 64], - acmd: [0; 16], - rsvd: [0; 48], - }; - - tbl_hdr.cfis[0..RegH2DFis::BYTE_SIZE].copy_from_slice(&fis.to_bytes()); - - let tbl_size = std::mem::size_of::() + std::mem::size_of_val(prdt); - - let (buf, buf_phys) = ACTIVE_SPACE - .lock() - .unwrap() - .map_free_cont_phys(tbl_size.div_ceil(4096)) - .unwrap(); - - let tbl_hdr_ptr = buf.cast::(); - let prdt_ptr = unsafe { - buf.add(std::mem::size_of::()) - .cast::() - }; - let cmd_hdr_ptr = self.cmd_list; - - let cmd_hdr = CommandHeader { - flags: (RegH2DFis::BYTE_SIZE / 4) as u16, - prdtl: prdt.len() as u16, - prdbc: 0, - ctba: buf_phys as u32, - ctbau: (buf_phys >> 32) as u32, - rsvd: [0; 4], - }; - - unsafe { - tbl_hdr_ptr.write_volatile(tbl_hdr); - prdt_ptr.copy_from_nonoverlapping(&prdt[0], prdt.len()); - cmd_hdr_ptr.write_volatile(cmd_hdr); - } - - self.regs.PxCI.set(0x1); - - while (self.regs.PxCI.get() & 0x1) > 0 && !self.has_fatal_error() {} - - if self.has_fatal_error() { - self.regs.PxCMD.modify(PxCMD::ST::CLEAR); - while self.regs.PxCMD.read(PxCMD::CR) > 0 {} - Self::clear_serr(self.regs); - self.regs.PxIS.modify(PxIS::HBFS::SET); - self.regs.PxIS.modify(PxIS::HBDS::SET); - self.regs.PxIS.modify(PxIS::IFS::SET); - self.regs.PxIS.modify(PxIS::TFES::SET); - if self.regs.PxTFD.read(PxTFD::STS_BSY) > 0 || self.regs.PxTFD.read(PxTFD::STS_DRQ) > 0 - { - unimplemented!("Port reset on error not implemented") - } - self.regs.PxCMD.modify(PxCMD::ST::SET); - return Err(CommandIssueError::CommandFailed); - } - - ACTIVE_SPACE - .lock() - .unwrap() - .unmap(buf, tbl_size.div_ceil(4096)) - .unwrap(); - - Ok(()) - } - - fn issue_data_in_command<'a, T: AtaCommandDataIn>( - &self, - command: &T, - buf: &'a mut [u8], - ) -> Result, CommandIssueError> { - if buf.len() < command.data_len() { - return Err(CommandIssueError::DataTooSmall); - } - let len = buf.len(); - - if len > 0x40_0000 * 0xFFFF { - return Err(CommandIssueError::DataTooBig); - } - - let (data_buf, data_buf_phys) = ACTIVE_SPACE - .lock() - .unwrap() - .map_free_cont_phys(len.div_ceil(4096)) - .unwrap(); - - let num_prds = len.div_ceil(0x40_0000); - - let prdt = (0..num_prds) - .map(|i| { - let size = if i == num_prds - 1 { - len - (i * 0x40_0000) - } else { - 0x40_0000 - }; - println!( - "Prd::new({:#x}, {:#x}, false)", - data_buf_phys + (i * 0x40_0000) as u64, - size as u32 - ); - Prd::new(data_buf_phys + (i * 0x40_0000) as u64, size as u32, false).unwrap() - }) - .collect_vec(); - - unsafe { - self.issue_command_prdt(&command.get_fis(), &prdt)?; - } - - let data_vol = unsafe { - VolatilePtr::new(NonNull::slice_from_raw_parts( - NonNull::new(data_buf).unwrap(), - len, - )) - }; - - data_vol.copy_into_slice(buf); - - ACTIVE_SPACE - .lock() - .unwrap() - .unmap(data_buf, len.div_ceil(4096)) - .unwrap(); - - Ok(command.process_data(buf)) - } - - fn has_fatal_error(&self) -> bool { - self.regs.PxIS.read(PxIS::HBFS) > 0 - || self.regs.PxIS.read(PxIS::HBDS) > 0 - || self.regs.PxIS.read(PxIS::IFS) > 0 - || self.regs.PxIS.read(PxIS::TFES) > 0 - } - - fn clear_serr(regs: &PortRegs) { - regs.PxSERR.modify(PxSERR::DIAG_X::SET); - regs.PxSERR.modify(PxSERR::DIAG_F::SET); - regs.PxSERR.modify(PxSERR::DIAG_T::SET); - regs.PxSERR.modify(PxSERR::DIAG_S::SET); - regs.PxSERR.modify(PxSERR::DIAG_H::SET); - regs.PxSERR.modify(PxSERR::DIAG_C::SET); - regs.PxSERR.modify(PxSERR::DIAG_B::SET); - regs.PxSERR.modify(PxSERR::DIAG_W::SET); - regs.PxSERR.modify(PxSERR::DIAG_I::SET); - regs.PxSERR.modify(PxSERR::DIAG_N::SET); - regs.PxSERR.modify(PxSERR::ERR_E::SET); - regs.PxSERR.modify(PxSERR::ERR_P::SET); - regs.PxSERR.modify(PxSERR::ERR_C::SET); - regs.PxSERR.modify(PxSERR::ERR_T::SET); - regs.PxSERR.modify(PxSERR::ERR_M::SET); - regs.PxSERR.modify(PxSERR::ERR_I::SET); - } -} - -trait AtaCommandDataIn { - type ProcessedData<'a>; - fn get_fis(&self) -> RegH2DFis; - fn data_len(&self) -> usize; - fn process_data<'a>(&self, data: &'a [u8]) -> Self::ProcessedData<'a>; -} - -struct IdentifyCommand; - -impl AtaCommandDataIn for IdentifyCommand { - type ProcessedData<'a> = IdentifyData; - - fn get_fis(&self) -> RegH2DFis { - RegH2DFisBuilder::default() - .cmd(true) - .regs( - DeviceRegsWriteBuilder::default() - .commmad(0xEC) - .build() - .unwrap(), - ) - .build() - .unwrap() - } - - fn data_len(&self) -> usize { - 512 - } - - fn process_data(&self, data: &[u8]) -> Self::ProcessedData<'_> { - IdentifyData::read(&mut Cursor::new(data)).unwrap() - } -} - -struct IdentifyPacketCommand; - -impl AtaCommandDataIn for IdentifyPacketCommand { - type ProcessedData<'a> = IdentifyDataATAPI; - - fn get_fis(&self) -> RegH2DFis { - RegH2DFisBuilder::default() - .cmd(true) - .regs( - DeviceRegsWriteBuilder::default() - .commmad(0xA1) - .build() - .unwrap(), - ) - .build() - .unwrap() - } - - fn data_len(&self) -> usize { - 512 - } - - fn process_data(&self, data: &[u8]) -> Self::ProcessedData<'_> { - IdentifyDataATAPI::read(&mut Cursor::new(data)).unwrap() - } -} - -struct ReadDmaExtCommand { - lba: u64, - count: u16, -} - -impl ReadDmaExtCommand { - fn new(lba: u64, count: u16) -> Self { - Self { lba, count } - } -} - -impl AtaCommandDataIn for ReadDmaExtCommand { - type ProcessedData<'a> = &'a [u8]; - - fn get_fis(&self) -> RegH2DFis { - let lba_bytes = self.lba.to_le_bytes(); - - RegH2DFisBuilder::default() - .cmd(true) - .regs( - DeviceRegsWriteBuilder::default() - .commmad(0x25) - .lba(true) - .countl(self.count as u8) - .counth((self.count >> 8) as u8) - .lba0(lba_bytes[0]) - .lba1(lba_bytes[1]) - .lba2(lba_bytes[2]) - .lba3(lba_bytes[3]) - .lba4(lba_bytes[4]) - .lba5(lba_bytes[5]) - .build() - .unwrap(), - ) - .build() - .unwrap() - } - - fn data_len(&self) -> usize { - (self.count * 512) as usize - } - - fn process_data<'a>(&self, data: &'a [u8]) -> Self::ProcessedData<'a> { - data - } -} - fn main() { let syslog_pid = loop { if let Some(pid) = syscalls::try_get_registered(2) { @@ -375,124 +79,13 @@ fn main() { ) .unwrap() }; - let ghc_regs = unsafe { &*(reg_base as *const GenHC) }; - let supports_64bit_dma = ghc_regs.CAP.read(CAP::S64A) > 0; - if supports_64bit_dma { - syslog_client - .send_text_message("ahci", "HBA supports 64bit DMA".to_string()) - .unwrap(); - } else { - syslog_client - .send_text_message("ahci", "HBA does not support 64bit DMA".to_string()) - .unwrap(); - syslog_client - .send_text_message( - "ahci", - "Aborting, there is no way to ensure buffers from the OS are in the low 4G." - .to_string(), - ) - .unwrap(); - return; - } + let hba = unsafe { Hba::new(reg_base, syslog_client) }.unwrap(); - syslog_client - .send_text_message("ahci", "Resetting HBA".to_string()) - .unwrap(); - ghc_regs.GHC.modify(GHC::AE::SET); - ghc_regs.GHC.modify(GHC::HR::SET); - while ghc_regs.GHC.read(GHC::HR) > 0 {} - syslog_client - .send_text_message("ahci", "Reset HBA".to_string()) - .unwrap(); - ghc_regs.GHC.modify(GHC::AE::SET); - - let num_raw_ports = (ghc_regs.CAP.read(CAP::NP) + 1) as usize; - let port_reg_base = unsafe { reg_base.add(0x100).cast::() }; - let num_ports = (ghc_regs.PI.get().count_ones()) as usize; - if num_ports == num_raw_ports { - syslog_client - .send_text_message("ahci", format!("HBA has {num_ports} ports")) - .unwrap(); - } else { - syslog_client - .send_text_message( - "ahci", - format!("HBA has {num_raw_ports} ports, but only {num_ports} are usable",), - ) - .unwrap(); - } - - let max_cmd_slots = (ghc_regs.CAP.read(CAP::NCS) + 1) as usize; - - let fis_buf_mem = num_ports * 256; - let comm_list_mem_per_port = max_cmd_slots * 32; - let comm_list_mem = comm_list_mem_per_port * num_ports; - let total_mem = fis_buf_mem + comm_list_mem; - - let (buf, buf_phys) = ACTIVE_SPACE - .lock() - .unwrap() - .map_free_cont_phys(total_mem.div_ceil(4096)) - .unwrap(); - unsafe { - buf.write_bytes(0, total_mem); - } - - let mut avail_ports = Vec::new(); - let avail_ports_raw = ghc_regs.PI.get(); - - for port_idx in 0..32 { - if (avail_ports_raw & (1 << port_idx)) == 0 { - continue; - } - - let port = unsafe { &*(port_reg_base.add(port_idx)) }; - let cmd_reg = &port.PxCMD; - if !(cmd_reg.read(PxCMD::ST) == 0 - && cmd_reg.read(PxCMD::CR) == 0 - && cmd_reg.read(PxCMD::FRE) == 0 - && cmd_reg.read(PxCMD::FR) == 0) - { - cmd_reg.modify(PxCMD::ST::CLEAR); - while cmd_reg.read(PxCMD::CR) > 0 {} - if cmd_reg.read(PxCMD::FRE) == 1 { - cmd_reg.modify(PxCMD::FRE::CLEAR); - while cmd_reg.read(PxCMD::FR) > 0 {} - } - } - let cl_phys_ptr = buf_phys + (comm_list_mem_per_port * avail_ports.len()) as u64; - let fis_phys_ptr = buf_phys + (comm_list_mem + 256 * avail_ports.len()) as u64; - port.PxCLB.set(cl_phys_ptr as u32); - port.PxCLBU.set((cl_phys_ptr >> 32) as u32); - port.PxFB.set(fis_phys_ptr as u32); - port.PxFBU.set((fis_phys_ptr >> 32) as u32); - port.PxCMD.modify(PxCMD::FRE::SET); - AhciPort::clear_serr(port); - - let has_device = port.PxTFD.read(PxTFD::STS_BSY) == 0 - && port.PxTFD.read(PxTFD::STS_DRQ) == 0 - && port.PxSSTS.read(PxSSTS::DET) == 3; - if has_device { - port.PxCMD.modify(PxCMD::ST::SET); - } - avail_ports.push(AhciPort { - regs: port, - phys_no: port_idx, - fis_buf: unsafe { FisBuf::new(buf.add(comm_list_mem + 256 * avail_ports.len())) }, - cmd_list: unsafe { - buf.add(comm_list_mem_per_port * avail_ports.len()) - .cast::() - }, - cmd_list_len: max_cmd_slots, - has_device, - }); - } - - for port in &avail_ports { - if !port.has_device { + for port in hba.ports() { + if !port.has_device() { syslog_client - .send_text_message("ahci", format!("Port {}: Empty", port.phys_no)) + .send_text_message("ahci", format!("Port {}: Empty", port.phys_no())) .unwrap(); continue; } @@ -522,7 +115,7 @@ fn main() { "ahci", format!( "Port {}: {} {}, firmware {} ({})", - port.phys_no, + port.phys_no(), identify_data.model, identify_data.serial, identify_data.firmware, @@ -531,19 +124,26 @@ fn main() { ) .unwrap(); } else { - let reg_fis = port.fis_buf.get_d2h_reg_fis().unwrap(); + let regs = port.regs().unwrap(); - if reg_fis.regs.lba1 != 0x14 || reg_fis.regs.lba2 != 0xEB { + if regs.lba1 != 0x14 || regs.lba2 != 0xEB { syslog_client - .send_text_message("ahci", format!("Port {} failed to identify", port.phys_no)) + .send_text_message( + "ahci", + format!("Port {} failed to identify", port.phys_no()), + ) .unwrap(); continue; } - let Ok(identify_data) = port.issue_data_in_command(&IdentifyPacketCommand, &mut ident_buf) + let Ok(identify_data) = + port.issue_data_in_command(&IdentifyPacketCommand, &mut ident_buf) else { syslog_client - .send_text_message("ahci", format!("Port {} failed to identify", port.phys_no)) + .send_text_message( + "ahci", + format!("Port {} failed to identify", port.phys_no()), + ) .unwrap(); continue; }; @@ -553,7 +153,7 @@ fn main() { "ahci", format!( "Port {}: {} {}, firmware {}", - port.phys_no, + port.phys_no(), identify_data.model, identify_data.serial, identify_data.firmware @@ -565,7 +165,7 @@ fn main() { let mut mbr = [0; 512]; - avail_ports[0] + hba.ports()[0] .issue_data_in_command(&ReadDmaExtCommand::new(0, 1), &mut mbr) .unwrap(); @@ -606,7 +206,7 @@ fn main() { let mut gpt_header = [0; 512]; - avail_ports[0] + hba.ports()[0] .issue_data_in_command(&ReadDmaExtCommand::new(1, 1), &mut gpt_header) .unwrap(); @@ -641,7 +241,7 @@ fn main() { let mut gpt_part_table = vec![0; part_table_num_lbas * 512]; - avail_ports[0] + hba.ports()[0] .issue_data_in_command( &ReadDmaExtCommand::new(part_table_start_lba as u64, part_table_num_lbas as u16), gpt_part_table.as_mut_slice(), diff --git a/src/port.rs b/src/port.rs new file mode 100644 index 0000000..495c717 --- /dev/null +++ b/src/port.rs @@ -0,0 +1,253 @@ +use std::{os::mikros::address_space::ACTIVE_SPACE, ptr::NonNull}; + +use itertools::Itertools; +use volatile::VolatilePtr; + +use crate::{ + ahci_structs::{ + CommandHeader, CommandTableHeader, DeviceRegsRead, FisBuf, PortRegs, Prd, PxCMD, PxIS, + PxSERR, PxSSTS, PxTFD, RegH2DFis, + }, + ata_command::AtaCommandDataIn, +}; + +#[allow(clippy::module_name_repetitions)] +pub struct AhciPort { + regs: &'static PortRegs, + phys_no: usize, + fis_buf: FisBuf, + cmd_list: *mut CommandHeader, + #[allow(dead_code)] + cmd_list_len: usize, + has_device: bool, +} + +#[derive(Copy, Clone, Debug)] +pub enum CommandIssueError { + PrdtTooBig, + DataTooBig, + CommandFailed, + DataTooSmall, +} + +impl AhciPort { + pub fn new( + regs: &'static PortRegs, + phys_no: usize, + fis_buf: FisBuf, + fis_buf_phys: u64, + cmd_list: *mut CommandHeader, + cmd_list_phys: u64, + cmd_list_len: usize, + ) -> Self { + let cmd_reg = ®s.PxCMD; + if !(cmd_reg.read(PxCMD::ST) == 0 + && cmd_reg.read(PxCMD::CR) == 0 + && cmd_reg.read(PxCMD::FRE) == 0 + && cmd_reg.read(PxCMD::FR) == 0) + { + cmd_reg.modify(PxCMD::ST::CLEAR); + while cmd_reg.read(PxCMD::CR) > 0 {} + if cmd_reg.read(PxCMD::FRE) == 1 { + cmd_reg.modify(PxCMD::FRE::CLEAR); + while cmd_reg.read(PxCMD::FR) > 0 {} + } + } + regs.PxCLB.set(cmd_list_phys as u32); + regs.PxCLBU.set((cmd_list_phys >> 32) as u32); + regs.PxFB.set(fis_buf_phys as u32); + regs.PxFBU.set((fis_buf_phys >> 32) as u32); + regs.PxCMD.modify(PxCMD::FRE::SET); + Self::clear_serr(regs); + + let has_device = regs.PxTFD.read(PxTFD::STS_BSY) == 0 + && regs.PxTFD.read(PxTFD::STS_DRQ) == 0 + && regs.PxSSTS.read(PxSSTS::DET) == 3; + if has_device { + regs.PxCMD.modify(PxCMD::ST::SET); + } + + Self { + regs, + phys_no, + fis_buf, + cmd_list, + cmd_list_len, + has_device, + } + } + + unsafe fn issue_command_prdt( + &self, + fis: &RegH2DFis, + prdt: &[Prd], + ) -> Result<(), CommandIssueError> { + if prdt.len() > 65535 { + return Err(CommandIssueError::PrdtTooBig); + } + + let mut tbl_hdr = CommandTableHeader { + cfis: [0; 64], + acmd: [0; 16], + rsvd: [0; 48], + }; + + tbl_hdr.cfis[0..RegH2DFis::BYTE_SIZE].copy_from_slice(&fis.to_bytes()); + + let tbl_size = std::mem::size_of::() + std::mem::size_of_val(prdt); + + let (buf, buf_phys) = ACTIVE_SPACE + .lock() + .unwrap() + .map_free_cont_phys(tbl_size.div_ceil(4096)) + .unwrap(); + + let tbl_hdr_ptr = buf.cast::(); + let prdt_ptr = unsafe { + buf.add(std::mem::size_of::()) + .cast::() + }; + let cmd_hdr_ptr = self.cmd_list; + + let cmd_hdr = CommandHeader { + flags: (RegH2DFis::BYTE_SIZE / 4) as u16, + prdtl: prdt.len() as u16, + prdbc: 0, + ctba: buf_phys as u32, + ctbau: (buf_phys >> 32) as u32, + rsvd: [0; 4], + }; + + unsafe { + tbl_hdr_ptr.write_volatile(tbl_hdr); + prdt_ptr.copy_from_nonoverlapping(&prdt[0], prdt.len()); + cmd_hdr_ptr.write_volatile(cmd_hdr); + } + + self.regs.PxCI.set(0x1); + + while (self.regs.PxCI.get() & 0x1) > 0 && !self.has_fatal_error() {} + + if self.has_fatal_error() { + self.regs.PxCMD.modify(PxCMD::ST::CLEAR); + while self.regs.PxCMD.read(PxCMD::CR) > 0 {} + Self::clear_serr(self.regs); + self.regs.PxIS.modify(PxIS::HBFS::SET); + self.regs.PxIS.modify(PxIS::HBDS::SET); + self.regs.PxIS.modify(PxIS::IFS::SET); + self.regs.PxIS.modify(PxIS::TFES::SET); + if self.regs.PxTFD.read(PxTFD::STS_BSY) > 0 || self.regs.PxTFD.read(PxTFD::STS_DRQ) > 0 + { + unimplemented!("Port reset on error not implemented") + } + self.regs.PxCMD.modify(PxCMD::ST::SET); + return Err(CommandIssueError::CommandFailed); + } + + ACTIVE_SPACE + .lock() + .unwrap() + .unmap(buf, tbl_size.div_ceil(4096)) + .unwrap(); + + Ok(()) + } + + pub fn issue_data_in_command<'a, T: AtaCommandDataIn>( + &self, + command: &T, + buf: &'a mut [u8], + ) -> Result, CommandIssueError> { + if buf.len() < command.data_len() { + return Err(CommandIssueError::DataTooSmall); + } + let len = buf.len(); + + if len > 0x40_0000 * 0xFFFF { + return Err(CommandIssueError::DataTooBig); + } + + let (data_buf, data_buf_phys) = ACTIVE_SPACE + .lock() + .unwrap() + .map_free_cont_phys(len.div_ceil(4096)) + .unwrap(); + + let num_prds = len.div_ceil(0x40_0000); + + let prdt = (0..num_prds) + .map(|i| { + let size = if i == num_prds - 1 { + len - (i * 0x40_0000) + } else { + 0x40_0000 + }; + println!( + "Prd::new({:#x}, {:#x}, false)", + data_buf_phys + (i * 0x40_0000) as u64, + size as u32 + ); + Prd::new(data_buf_phys + (i * 0x40_0000) as u64, size as u32, false).unwrap() + }) + .collect_vec(); + + unsafe { + self.issue_command_prdt(&command.get_fis(), &prdt)?; + } + + let data_vol = unsafe { + VolatilePtr::new(NonNull::slice_from_raw_parts( + NonNull::new(data_buf).unwrap(), + len, + )) + }; + + data_vol.copy_into_slice(buf); + + ACTIVE_SPACE + .lock() + .unwrap() + .unmap(data_buf, len.div_ceil(4096)) + .unwrap(); + + Ok(command.process_data(buf)) + } + + fn has_fatal_error(&self) -> bool { + self.regs.PxIS.read(PxIS::HBFS) > 0 + || self.regs.PxIS.read(PxIS::HBDS) > 0 + || self.regs.PxIS.read(PxIS::IFS) > 0 + || self.regs.PxIS.read(PxIS::TFES) > 0 + } + + fn clear_serr(regs: &PortRegs) { + regs.PxSERR.modify(PxSERR::DIAG_X::SET); + regs.PxSERR.modify(PxSERR::DIAG_F::SET); + regs.PxSERR.modify(PxSERR::DIAG_T::SET); + regs.PxSERR.modify(PxSERR::DIAG_S::SET); + regs.PxSERR.modify(PxSERR::DIAG_H::SET); + regs.PxSERR.modify(PxSERR::DIAG_C::SET); + regs.PxSERR.modify(PxSERR::DIAG_B::SET); + regs.PxSERR.modify(PxSERR::DIAG_W::SET); + regs.PxSERR.modify(PxSERR::DIAG_I::SET); + regs.PxSERR.modify(PxSERR::DIAG_N::SET); + regs.PxSERR.modify(PxSERR::ERR_E::SET); + regs.PxSERR.modify(PxSERR::ERR_P::SET); + regs.PxSERR.modify(PxSERR::ERR_C::SET); + regs.PxSERR.modify(PxSERR::ERR_T::SET); + regs.PxSERR.modify(PxSERR::ERR_M::SET); + regs.PxSERR.modify(PxSERR::ERR_I::SET); + } + + pub fn phys_no(&self) -> usize { + self.phys_no + } + + pub fn has_device(&self) -> bool { + self.has_device + } + + pub fn regs(&self) -> Option { + self.fis_buf.get_d2h_reg_fis().map(|x| x.regs) + } +}