Split into multiple files

This commit is contained in:
pjht 2024-08-20 19:12:34 -05:00
parent 09fa549dea
commit 4f6eb0080b
Signed by: pjht
GPG Key ID: 7B5F6AFBEC7EE78E
6 changed files with 527 additions and 429 deletions

117
src/ata_command.rs Normal file
View File

@ -0,0 +1,117 @@
use std::io::Cursor;
use binread::BinRead;
use crate::{
ahci_structs::{DeviceRegsWriteBuilder, RegH2DFis, RegH2DFisBuilder},
identify::{IdentifyData, IdentifyPacketData},
};
#[allow(clippy::module_name_repetitions)]
pub trait AtaCommandDataIn {
type ProcessedData<'a>;
fn get_fis(&self) -> RegH2DFis;
fn data_len(&self) -> usize;
fn process_data<'a>(&self, data: &'a [u8]) -> Self::ProcessedData<'a>;
}
pub struct IdentifyCommand;
impl AtaCommandDataIn for IdentifyCommand {
type ProcessedData<'a> = IdentifyData;
fn get_fis(&self) -> RegH2DFis {
RegH2DFisBuilder::default()
.cmd(true)
.regs(
DeviceRegsWriteBuilder::default()
.commmad(0xEC)
.build()
.unwrap(),
)
.build()
.unwrap()
}
fn data_len(&self) -> usize {
512
}
fn process_data(&self, data: &[u8]) -> Self::ProcessedData<'_> {
IdentifyData::read(&mut Cursor::new(data)).unwrap()
}
}
pub struct IdentifyPacketCommand;
impl AtaCommandDataIn for IdentifyPacketCommand {
type ProcessedData<'a> = IdentifyPacketData;
fn get_fis(&self) -> RegH2DFis {
RegH2DFisBuilder::default()
.cmd(true)
.regs(
DeviceRegsWriteBuilder::default()
.commmad(0xA1)
.build()
.unwrap(),
)
.build()
.unwrap()
}
fn data_len(&self) -> usize {
512
}
fn process_data(&self, data: &[u8]) -> Self::ProcessedData<'_> {
IdentifyPacketData::read(&mut Cursor::new(data)).unwrap()
}
}
pub struct ReadDmaExtCommand {
lba: u64,
count: u16,
}
impl ReadDmaExtCommand {
pub fn new(lba: u64, count: u16) -> Self {
Self { lba, count }
}
}
impl AtaCommandDataIn for ReadDmaExtCommand {
type ProcessedData<'a> = &'a [u8];
fn get_fis(&self) -> RegH2DFis {
let lba_bytes = self.lba.to_le_bytes();
RegH2DFisBuilder::default()
.cmd(true)
.regs(
DeviceRegsWriteBuilder::default()
.commmad(0x25)
.lba(true)
.countl(self.count as u8)
.counth((self.count >> 8) as u8)
.lba0(lba_bytes[0])
.lba1(lba_bytes[1])
.lba2(lba_bytes[2])
.lba3(lba_bytes[3])
.lba4(lba_bytes[4])
.lba5(lba_bytes[5])
.build()
.unwrap(),
)
.build()
.unwrap()
}
fn data_len(&self) -> usize {
(self.count * 512) as usize
}
fn process_data<'a>(&self, data: &'a [u8]) -> Self::ProcessedData<'a> {
data
}
}

1
src/ata_dev.rs Normal file
View File

@ -0,0 +1 @@

127
src/hba.rs Normal file
View File

@ -0,0 +1,127 @@
use std::os::mikros::address_space::ACTIVE_SPACE;
use crate::{
ahci_structs::{CommandHeader, FisBuf, GenHC, PortRegs, CAP, GHC},
port::AhciPort,
};
pub struct Hba {
regs: &'static GenHC,
ports: Vec<AhciPort>,
syslog_client: syslog_rpc::Client,
}
#[derive(Clone, Copy, Debug)]
pub enum HBAInitError {
No64BitDma,
}
impl Hba {
pub unsafe fn new(
reg_base: *mut u8,
syslog_client: syslog_rpc::Client,
) -> Result<Self, HBAInitError> {
let regs = unsafe { &*(reg_base as *const GenHC) };
let mut hba = Self {
regs,
ports: Vec::new(),
syslog_client,
};
let supports_64bit_dma = regs.CAP.read(CAP::S64A) > 0;
if supports_64bit_dma {
syslog_client
.send_text_message("ahci", "HBA supports 64bit DMA".to_string())
.unwrap();
} else {
syslog_client
.send_text_message("ahci", "HBA does not support 64bit DMA".to_string())
.unwrap();
syslog_client
.send_text_message(
"ahci",
"Aborting, there is no way to ensure buffers from the OS are in the low 4G."
.to_string(),
)
.unwrap();
return Err(HBAInitError::No64BitDma);
}
hba.reset();
let num_raw_ports = (regs.CAP.read(CAP::NP) + 1) as usize;
let port_reg_base = unsafe { reg_base.add(0x100).cast::<PortRegs>() };
let num_ports = (regs.PI.get().count_ones()) as usize;
if num_ports == num_raw_ports {
syslog_client
.send_text_message("ahci", format!("HBA has {num_ports} ports"))
.unwrap();
} else {
syslog_client
.send_text_message(
"ahci",
format!("HBA has {num_raw_ports} ports, but only {num_ports} are usable",),
)
.unwrap();
}
let max_cmd_slots = (regs.CAP.read(CAP::NCS) + 1) as usize;
let fis_buf_mem = num_ports * 256;
let comm_list_mem_per_port = max_cmd_slots * 32;
let comm_list_mem = comm_list_mem_per_port * num_ports;
let total_mem = fis_buf_mem + comm_list_mem;
let (buf, buf_phys) = ACTIVE_SPACE
.lock()
.unwrap()
.map_free_cont_phys(total_mem.div_ceil(4096))
.unwrap();
unsafe {
buf.write_bytes(0, total_mem);
}
let ports_raw = regs.PI.get();
for port_idx in 0..32 {
if (ports_raw & (1 << port_idx)) == 0 {
continue;
}
let regs = unsafe { &*(port_reg_base.add(port_idx)) };
let cmd_list_phys = buf_phys + (comm_list_mem_per_port * hba.ports.len()) as u64;
let cmd_list = unsafe {
buf.add(comm_list_mem_per_port * hba.ports.len())
.cast::<CommandHeader>()
};
let fis_buf_phys = buf_phys + (comm_list_mem + 256 * hba.ports.len()) as u64;
let fis_buf = unsafe { FisBuf::new(buf.add(comm_list_mem + 256 * hba.ports.len())) };
hba.ports.push(AhciPort::new(
regs,
port_idx,
fis_buf,
fis_buf_phys,
cmd_list,
cmd_list_phys,
max_cmd_slots,
));
}
Ok(hba)
}
fn reset(&self) {
self.syslog_client
.send_text_message("ahci", "Resetting HBA".to_string())
.unwrap();
self.regs.GHC.modify(GHC::AE::SET);
self.regs.GHC.modify(GHC::HR::SET);
while self.regs.GHC.read(GHC::HR) > 0 {}
self.syslog_client
.send_text_message("ahci", "Reset HBA".to_string())
.unwrap();
self.regs.GHC.modify(GHC::AE::SET);
}
pub fn ports(&self) -> &[AhciPort] {
&self.ports
}
}

View File

@ -162,7 +162,7 @@ pub struct IdentifyData {
#[derive(Clone, Debug, BinRead)]
#[allow(unused)]
pub struct IdentifyDataATAPI {
pub struct IdentifyPacketData {
#[br(seek_before = SeekFrom::Start(0))]
pub gen_cfg: u16,
#[br(seek_before = SeekFrom::Start(2*2))]

View File

@ -20,318 +20,22 @@
#![allow(clippy::tuple_array_conversions)]
mod ahci_structs;
mod ata_command;
mod ata_dev;
mod hba;
mod identify;
mod port;
use std::{
io::Cursor,
os::mikros::{address_space::ACTIVE_SPACE, syscalls},
ptr::NonNull,
};
use std::os::mikros::{address_space::ACTIVE_SPACE, syscalls};
use ahci_structs::{
CommandHeader, CommandTableHeader, DeviceRegsWriteBuilder, FisBuf, GenHC, PortRegs, Prd, PxCMD,
PxIS, PxSERR, PxSSTS, PxTFD, RegH2DFis, RegH2DFisBuilder, CAP, GHC,
};
use binread::BinRead;
use identify::{IdentifyData, IdentifyDataATAPI};
use ata_command::{IdentifyCommand, IdentifyPacketCommand, ReadDmaExtCommand};
use hba::Hba;
use itertools::Itertools;
use uuid::Uuid;
use volatile::VolatilePtr;
use x86_64::structures::paging::PageTableFlags;
use std::fmt::Debug;
struct AhciPort {
regs: &'static PortRegs,
phys_no: usize,
fis_buf: FisBuf,
cmd_list: *mut CommandHeader,
#[allow(dead_code)]
cmd_list_len: usize,
has_device: bool,
}
#[derive(Copy, Clone, Debug)]
enum CommandIssueError {
PrdtTooBig,
DataTooBig,
CommandFailed,
DataTooSmall,
}
impl AhciPort {
unsafe fn issue_command_prdt(
&self,
fis: &RegH2DFis,
prdt: &[Prd],
) -> Result<(), CommandIssueError> {
if prdt.len() > 65535 {
return Err(CommandIssueError::PrdtTooBig);
}
let mut tbl_hdr = CommandTableHeader {
cfis: [0; 64],
acmd: [0; 16],
rsvd: [0; 48],
};
tbl_hdr.cfis[0..RegH2DFis::BYTE_SIZE].copy_from_slice(&fis.to_bytes());
let tbl_size = std::mem::size_of::<CommandTableHeader>() + std::mem::size_of_val(prdt);
let (buf, buf_phys) = ACTIVE_SPACE
.lock()
.unwrap()
.map_free_cont_phys(tbl_size.div_ceil(4096))
.unwrap();
let tbl_hdr_ptr = buf.cast::<CommandTableHeader>();
let prdt_ptr = unsafe {
buf.add(std::mem::size_of::<CommandTableHeader>())
.cast::<Prd>()
};
let cmd_hdr_ptr = self.cmd_list;
let cmd_hdr = CommandHeader {
flags: (RegH2DFis::BYTE_SIZE / 4) as u16,
prdtl: prdt.len() as u16,
prdbc: 0,
ctba: buf_phys as u32,
ctbau: (buf_phys >> 32) as u32,
rsvd: [0; 4],
};
unsafe {
tbl_hdr_ptr.write_volatile(tbl_hdr);
prdt_ptr.copy_from_nonoverlapping(&prdt[0], prdt.len());
cmd_hdr_ptr.write_volatile(cmd_hdr);
}
self.regs.PxCI.set(0x1);
while (self.regs.PxCI.get() & 0x1) > 0 && !self.has_fatal_error() {}
if self.has_fatal_error() {
self.regs.PxCMD.modify(PxCMD::ST::CLEAR);
while self.regs.PxCMD.read(PxCMD::CR) > 0 {}
Self::clear_serr(self.regs);
self.regs.PxIS.modify(PxIS::HBFS::SET);
self.regs.PxIS.modify(PxIS::HBDS::SET);
self.regs.PxIS.modify(PxIS::IFS::SET);
self.regs.PxIS.modify(PxIS::TFES::SET);
if self.regs.PxTFD.read(PxTFD::STS_BSY) > 0 || self.regs.PxTFD.read(PxTFD::STS_DRQ) > 0
{
unimplemented!("Port reset on error not implemented")
}
self.regs.PxCMD.modify(PxCMD::ST::SET);
return Err(CommandIssueError::CommandFailed);
}
ACTIVE_SPACE
.lock()
.unwrap()
.unmap(buf, tbl_size.div_ceil(4096))
.unwrap();
Ok(())
}
fn issue_data_in_command<'a, T: AtaCommandDataIn>(
&self,
command: &T,
buf: &'a mut [u8],
) -> Result<T::ProcessedData<'a>, CommandIssueError> {
if buf.len() < command.data_len() {
return Err(CommandIssueError::DataTooSmall);
}
let len = buf.len();
if len > 0x40_0000 * 0xFFFF {
return Err(CommandIssueError::DataTooBig);
}
let (data_buf, data_buf_phys) = ACTIVE_SPACE
.lock()
.unwrap()
.map_free_cont_phys(len.div_ceil(4096))
.unwrap();
let num_prds = len.div_ceil(0x40_0000);
let prdt = (0..num_prds)
.map(|i| {
let size = if i == num_prds - 1 {
len - (i * 0x40_0000)
} else {
0x40_0000
};
println!(
"Prd::new({:#x}, {:#x}, false)",
data_buf_phys + (i * 0x40_0000) as u64,
size as u32
);
Prd::new(data_buf_phys + (i * 0x40_0000) as u64, size as u32, false).unwrap()
})
.collect_vec();
unsafe {
self.issue_command_prdt(&command.get_fis(), &prdt)?;
}
let data_vol = unsafe {
VolatilePtr::new(NonNull::slice_from_raw_parts(
NonNull::new(data_buf).unwrap(),
len,
))
};
data_vol.copy_into_slice(buf);
ACTIVE_SPACE
.lock()
.unwrap()
.unmap(data_buf, len.div_ceil(4096))
.unwrap();
Ok(command.process_data(buf))
}
fn has_fatal_error(&self) -> bool {
self.regs.PxIS.read(PxIS::HBFS) > 0
|| self.regs.PxIS.read(PxIS::HBDS) > 0
|| self.regs.PxIS.read(PxIS::IFS) > 0
|| self.regs.PxIS.read(PxIS::TFES) > 0
}
fn clear_serr(regs: &PortRegs) {
regs.PxSERR.modify(PxSERR::DIAG_X::SET);
regs.PxSERR.modify(PxSERR::DIAG_F::SET);
regs.PxSERR.modify(PxSERR::DIAG_T::SET);
regs.PxSERR.modify(PxSERR::DIAG_S::SET);
regs.PxSERR.modify(PxSERR::DIAG_H::SET);
regs.PxSERR.modify(PxSERR::DIAG_C::SET);
regs.PxSERR.modify(PxSERR::DIAG_B::SET);
regs.PxSERR.modify(PxSERR::DIAG_W::SET);
regs.PxSERR.modify(PxSERR::DIAG_I::SET);
regs.PxSERR.modify(PxSERR::DIAG_N::SET);
regs.PxSERR.modify(PxSERR::ERR_E::SET);
regs.PxSERR.modify(PxSERR::ERR_P::SET);
regs.PxSERR.modify(PxSERR::ERR_C::SET);
regs.PxSERR.modify(PxSERR::ERR_T::SET);
regs.PxSERR.modify(PxSERR::ERR_M::SET);
regs.PxSERR.modify(PxSERR::ERR_I::SET);
}
}
trait AtaCommandDataIn {
type ProcessedData<'a>;
fn get_fis(&self) -> RegH2DFis;
fn data_len(&self) -> usize;
fn process_data<'a>(&self, data: &'a [u8]) -> Self::ProcessedData<'a>;
}
struct IdentifyCommand;
impl AtaCommandDataIn for IdentifyCommand {
type ProcessedData<'a> = IdentifyData;
fn get_fis(&self) -> RegH2DFis {
RegH2DFisBuilder::default()
.cmd(true)
.regs(
DeviceRegsWriteBuilder::default()
.commmad(0xEC)
.build()
.unwrap(),
)
.build()
.unwrap()
}
fn data_len(&self) -> usize {
512
}
fn process_data(&self, data: &[u8]) -> Self::ProcessedData<'_> {
IdentifyData::read(&mut Cursor::new(data)).unwrap()
}
}
struct IdentifyPacketCommand;
impl AtaCommandDataIn for IdentifyPacketCommand {
type ProcessedData<'a> = IdentifyDataATAPI;
fn get_fis(&self) -> RegH2DFis {
RegH2DFisBuilder::default()
.cmd(true)
.regs(
DeviceRegsWriteBuilder::default()
.commmad(0xA1)
.build()
.unwrap(),
)
.build()
.unwrap()
}
fn data_len(&self) -> usize {
512
}
fn process_data(&self, data: &[u8]) -> Self::ProcessedData<'_> {
IdentifyDataATAPI::read(&mut Cursor::new(data)).unwrap()
}
}
struct ReadDmaExtCommand {
lba: u64,
count: u16,
}
impl ReadDmaExtCommand {
fn new(lba: u64, count: u16) -> Self {
Self { lba, count }
}
}
impl AtaCommandDataIn for ReadDmaExtCommand {
type ProcessedData<'a> = &'a [u8];
fn get_fis(&self) -> RegH2DFis {
let lba_bytes = self.lba.to_le_bytes();
RegH2DFisBuilder::default()
.cmd(true)
.regs(
DeviceRegsWriteBuilder::default()
.commmad(0x25)
.lba(true)
.countl(self.count as u8)
.counth((self.count >> 8) as u8)
.lba0(lba_bytes[0])
.lba1(lba_bytes[1])
.lba2(lba_bytes[2])
.lba3(lba_bytes[3])
.lba4(lba_bytes[4])
.lba5(lba_bytes[5])
.build()
.unwrap(),
)
.build()
.unwrap()
}
fn data_len(&self) -> usize {
(self.count * 512) as usize
}
fn process_data<'a>(&self, data: &'a [u8]) -> Self::ProcessedData<'a> {
data
}
}
fn main() {
let syslog_pid = loop {
if let Some(pid) = syscalls::try_get_registered(2) {
@ -375,124 +79,13 @@ fn main() {
)
.unwrap()
};
let ghc_regs = unsafe { &*(reg_base as *const GenHC) };
let supports_64bit_dma = ghc_regs.CAP.read(CAP::S64A) > 0;
if supports_64bit_dma {
syslog_client
.send_text_message("ahci", "HBA supports 64bit DMA".to_string())
.unwrap();
} else {
syslog_client
.send_text_message("ahci", "HBA does not support 64bit DMA".to_string())
.unwrap();
syslog_client
.send_text_message(
"ahci",
"Aborting, there is no way to ensure buffers from the OS are in the low 4G."
.to_string(),
)
.unwrap();
return;
}
let hba = unsafe { Hba::new(reg_base, syslog_client) }.unwrap();
syslog_client
.send_text_message("ahci", "Resetting HBA".to_string())
.unwrap();
ghc_regs.GHC.modify(GHC::AE::SET);
ghc_regs.GHC.modify(GHC::HR::SET);
while ghc_regs.GHC.read(GHC::HR) > 0 {}
syslog_client
.send_text_message("ahci", "Reset HBA".to_string())
.unwrap();
ghc_regs.GHC.modify(GHC::AE::SET);
let num_raw_ports = (ghc_regs.CAP.read(CAP::NP) + 1) as usize;
let port_reg_base = unsafe { reg_base.add(0x100).cast::<PortRegs>() };
let num_ports = (ghc_regs.PI.get().count_ones()) as usize;
if num_ports == num_raw_ports {
syslog_client
.send_text_message("ahci", format!("HBA has {num_ports} ports"))
.unwrap();
} else {
syslog_client
.send_text_message(
"ahci",
format!("HBA has {num_raw_ports} ports, but only {num_ports} are usable",),
)
.unwrap();
}
let max_cmd_slots = (ghc_regs.CAP.read(CAP::NCS) + 1) as usize;
let fis_buf_mem = num_ports * 256;
let comm_list_mem_per_port = max_cmd_slots * 32;
let comm_list_mem = comm_list_mem_per_port * num_ports;
let total_mem = fis_buf_mem + comm_list_mem;
let (buf, buf_phys) = ACTIVE_SPACE
.lock()
.unwrap()
.map_free_cont_phys(total_mem.div_ceil(4096))
.unwrap();
unsafe {
buf.write_bytes(0, total_mem);
}
let mut avail_ports = Vec::new();
let avail_ports_raw = ghc_regs.PI.get();
for port_idx in 0..32 {
if (avail_ports_raw & (1 << port_idx)) == 0 {
continue;
}
let port = unsafe { &*(port_reg_base.add(port_idx)) };
let cmd_reg = &port.PxCMD;
if !(cmd_reg.read(PxCMD::ST) == 0
&& cmd_reg.read(PxCMD::CR) == 0
&& cmd_reg.read(PxCMD::FRE) == 0
&& cmd_reg.read(PxCMD::FR) == 0)
{
cmd_reg.modify(PxCMD::ST::CLEAR);
while cmd_reg.read(PxCMD::CR) > 0 {}
if cmd_reg.read(PxCMD::FRE) == 1 {
cmd_reg.modify(PxCMD::FRE::CLEAR);
while cmd_reg.read(PxCMD::FR) > 0 {}
}
}
let cl_phys_ptr = buf_phys + (comm_list_mem_per_port * avail_ports.len()) as u64;
let fis_phys_ptr = buf_phys + (comm_list_mem + 256 * avail_ports.len()) as u64;
port.PxCLB.set(cl_phys_ptr as u32);
port.PxCLBU.set((cl_phys_ptr >> 32) as u32);
port.PxFB.set(fis_phys_ptr as u32);
port.PxFBU.set((fis_phys_ptr >> 32) as u32);
port.PxCMD.modify(PxCMD::FRE::SET);
AhciPort::clear_serr(port);
let has_device = port.PxTFD.read(PxTFD::STS_BSY) == 0
&& port.PxTFD.read(PxTFD::STS_DRQ) == 0
&& port.PxSSTS.read(PxSSTS::DET) == 3;
if has_device {
port.PxCMD.modify(PxCMD::ST::SET);
}
avail_ports.push(AhciPort {
regs: port,
phys_no: port_idx,
fis_buf: unsafe { FisBuf::new(buf.add(comm_list_mem + 256 * avail_ports.len())) },
cmd_list: unsafe {
buf.add(comm_list_mem_per_port * avail_ports.len())
.cast::<CommandHeader>()
},
cmd_list_len: max_cmd_slots,
has_device,
});
}
for port in &avail_ports {
if !port.has_device {
for port in hba.ports() {
if !port.has_device() {
syslog_client
.send_text_message("ahci", format!("Port {}: Empty", port.phys_no))
.send_text_message("ahci", format!("Port {}: Empty", port.phys_no()))
.unwrap();
continue;
}
@ -522,7 +115,7 @@ fn main() {
"ahci",
format!(
"Port {}: {} {}, firmware {} ({})",
port.phys_no,
port.phys_no(),
identify_data.model,
identify_data.serial,
identify_data.firmware,
@ -531,19 +124,26 @@ fn main() {
)
.unwrap();
} else {
let reg_fis = port.fis_buf.get_d2h_reg_fis().unwrap();
let regs = port.regs().unwrap();
if reg_fis.regs.lba1 != 0x14 || reg_fis.regs.lba2 != 0xEB {
if regs.lba1 != 0x14 || regs.lba2 != 0xEB {
syslog_client
.send_text_message("ahci", format!("Port {} failed to identify", port.phys_no))
.send_text_message(
"ahci",
format!("Port {} failed to identify", port.phys_no()),
)
.unwrap();
continue;
}
let Ok(identify_data) = port.issue_data_in_command(&IdentifyPacketCommand, &mut ident_buf)
let Ok(identify_data) =
port.issue_data_in_command(&IdentifyPacketCommand, &mut ident_buf)
else {
syslog_client
.send_text_message("ahci", format!("Port {} failed to identify", port.phys_no))
.send_text_message(
"ahci",
format!("Port {} failed to identify", port.phys_no()),
)
.unwrap();
continue;
};
@ -553,7 +153,7 @@ fn main() {
"ahci",
format!(
"Port {}: {} {}, firmware {}",
port.phys_no,
port.phys_no(),
identify_data.model,
identify_data.serial,
identify_data.firmware
@ -565,7 +165,7 @@ fn main() {
let mut mbr = [0; 512];
avail_ports[0]
hba.ports()[0]
.issue_data_in_command(&ReadDmaExtCommand::new(0, 1), &mut mbr)
.unwrap();
@ -606,7 +206,7 @@ fn main() {
let mut gpt_header = [0; 512];
avail_ports[0]
hba.ports()[0]
.issue_data_in_command(&ReadDmaExtCommand::new(1, 1), &mut gpt_header)
.unwrap();
@ -641,7 +241,7 @@ fn main() {
let mut gpt_part_table = vec![0; part_table_num_lbas * 512];
avail_ports[0]
hba.ports()[0]
.issue_data_in_command(
&ReadDmaExtCommand::new(part_table_start_lba as u64, part_table_num_lbas as u16),
gpt_part_table.as_mut_slice(),

253
src/port.rs Normal file
View File

@ -0,0 +1,253 @@
use std::{os::mikros::address_space::ACTIVE_SPACE, ptr::NonNull};
use itertools::Itertools;
use volatile::VolatilePtr;
use crate::{
ahci_structs::{
CommandHeader, CommandTableHeader, DeviceRegsRead, FisBuf, PortRegs, Prd, PxCMD, PxIS,
PxSERR, PxSSTS, PxTFD, RegH2DFis,
},
ata_command::AtaCommandDataIn,
};
#[allow(clippy::module_name_repetitions)]
pub struct AhciPort {
regs: &'static PortRegs,
phys_no: usize,
fis_buf: FisBuf,
cmd_list: *mut CommandHeader,
#[allow(dead_code)]
cmd_list_len: usize,
has_device: bool,
}
#[derive(Copy, Clone, Debug)]
pub enum CommandIssueError {
PrdtTooBig,
DataTooBig,
CommandFailed,
DataTooSmall,
}
impl AhciPort {
pub fn new(
regs: &'static PortRegs,
phys_no: usize,
fis_buf: FisBuf,
fis_buf_phys: u64,
cmd_list: *mut CommandHeader,
cmd_list_phys: u64,
cmd_list_len: usize,
) -> Self {
let cmd_reg = &regs.PxCMD;
if !(cmd_reg.read(PxCMD::ST) == 0
&& cmd_reg.read(PxCMD::CR) == 0
&& cmd_reg.read(PxCMD::FRE) == 0
&& cmd_reg.read(PxCMD::FR) == 0)
{
cmd_reg.modify(PxCMD::ST::CLEAR);
while cmd_reg.read(PxCMD::CR) > 0 {}
if cmd_reg.read(PxCMD::FRE) == 1 {
cmd_reg.modify(PxCMD::FRE::CLEAR);
while cmd_reg.read(PxCMD::FR) > 0 {}
}
}
regs.PxCLB.set(cmd_list_phys as u32);
regs.PxCLBU.set((cmd_list_phys >> 32) as u32);
regs.PxFB.set(fis_buf_phys as u32);
regs.PxFBU.set((fis_buf_phys >> 32) as u32);
regs.PxCMD.modify(PxCMD::FRE::SET);
Self::clear_serr(regs);
let has_device = regs.PxTFD.read(PxTFD::STS_BSY) == 0
&& regs.PxTFD.read(PxTFD::STS_DRQ) == 0
&& regs.PxSSTS.read(PxSSTS::DET) == 3;
if has_device {
regs.PxCMD.modify(PxCMD::ST::SET);
}
Self {
regs,
phys_no,
fis_buf,
cmd_list,
cmd_list_len,
has_device,
}
}
unsafe fn issue_command_prdt(
&self,
fis: &RegH2DFis,
prdt: &[Prd],
) -> Result<(), CommandIssueError> {
if prdt.len() > 65535 {
return Err(CommandIssueError::PrdtTooBig);
}
let mut tbl_hdr = CommandTableHeader {
cfis: [0; 64],
acmd: [0; 16],
rsvd: [0; 48],
};
tbl_hdr.cfis[0..RegH2DFis::BYTE_SIZE].copy_from_slice(&fis.to_bytes());
let tbl_size = std::mem::size_of::<CommandTableHeader>() + std::mem::size_of_val(prdt);
let (buf, buf_phys) = ACTIVE_SPACE
.lock()
.unwrap()
.map_free_cont_phys(tbl_size.div_ceil(4096))
.unwrap();
let tbl_hdr_ptr = buf.cast::<CommandTableHeader>();
let prdt_ptr = unsafe {
buf.add(std::mem::size_of::<CommandTableHeader>())
.cast::<Prd>()
};
let cmd_hdr_ptr = self.cmd_list;
let cmd_hdr = CommandHeader {
flags: (RegH2DFis::BYTE_SIZE / 4) as u16,
prdtl: prdt.len() as u16,
prdbc: 0,
ctba: buf_phys as u32,
ctbau: (buf_phys >> 32) as u32,
rsvd: [0; 4],
};
unsafe {
tbl_hdr_ptr.write_volatile(tbl_hdr);
prdt_ptr.copy_from_nonoverlapping(&prdt[0], prdt.len());
cmd_hdr_ptr.write_volatile(cmd_hdr);
}
self.regs.PxCI.set(0x1);
while (self.regs.PxCI.get() & 0x1) > 0 && !self.has_fatal_error() {}
if self.has_fatal_error() {
self.regs.PxCMD.modify(PxCMD::ST::CLEAR);
while self.regs.PxCMD.read(PxCMD::CR) > 0 {}
Self::clear_serr(self.regs);
self.regs.PxIS.modify(PxIS::HBFS::SET);
self.regs.PxIS.modify(PxIS::HBDS::SET);
self.regs.PxIS.modify(PxIS::IFS::SET);
self.regs.PxIS.modify(PxIS::TFES::SET);
if self.regs.PxTFD.read(PxTFD::STS_BSY) > 0 || self.regs.PxTFD.read(PxTFD::STS_DRQ) > 0
{
unimplemented!("Port reset on error not implemented")
}
self.regs.PxCMD.modify(PxCMD::ST::SET);
return Err(CommandIssueError::CommandFailed);
}
ACTIVE_SPACE
.lock()
.unwrap()
.unmap(buf, tbl_size.div_ceil(4096))
.unwrap();
Ok(())
}
pub fn issue_data_in_command<'a, T: AtaCommandDataIn>(
&self,
command: &T,
buf: &'a mut [u8],
) -> Result<T::ProcessedData<'a>, CommandIssueError> {
if buf.len() < command.data_len() {
return Err(CommandIssueError::DataTooSmall);
}
let len = buf.len();
if len > 0x40_0000 * 0xFFFF {
return Err(CommandIssueError::DataTooBig);
}
let (data_buf, data_buf_phys) = ACTIVE_SPACE
.lock()
.unwrap()
.map_free_cont_phys(len.div_ceil(4096))
.unwrap();
let num_prds = len.div_ceil(0x40_0000);
let prdt = (0..num_prds)
.map(|i| {
let size = if i == num_prds - 1 {
len - (i * 0x40_0000)
} else {
0x40_0000
};
println!(
"Prd::new({:#x}, {:#x}, false)",
data_buf_phys + (i * 0x40_0000) as u64,
size as u32
);
Prd::new(data_buf_phys + (i * 0x40_0000) as u64, size as u32, false).unwrap()
})
.collect_vec();
unsafe {
self.issue_command_prdt(&command.get_fis(), &prdt)?;
}
let data_vol = unsafe {
VolatilePtr::new(NonNull::slice_from_raw_parts(
NonNull::new(data_buf).unwrap(),
len,
))
};
data_vol.copy_into_slice(buf);
ACTIVE_SPACE
.lock()
.unwrap()
.unmap(data_buf, len.div_ceil(4096))
.unwrap();
Ok(command.process_data(buf))
}
fn has_fatal_error(&self) -> bool {
self.regs.PxIS.read(PxIS::HBFS) > 0
|| self.regs.PxIS.read(PxIS::HBDS) > 0
|| self.regs.PxIS.read(PxIS::IFS) > 0
|| self.regs.PxIS.read(PxIS::TFES) > 0
}
fn clear_serr(regs: &PortRegs) {
regs.PxSERR.modify(PxSERR::DIAG_X::SET);
regs.PxSERR.modify(PxSERR::DIAG_F::SET);
regs.PxSERR.modify(PxSERR::DIAG_T::SET);
regs.PxSERR.modify(PxSERR::DIAG_S::SET);
regs.PxSERR.modify(PxSERR::DIAG_H::SET);
regs.PxSERR.modify(PxSERR::DIAG_C::SET);
regs.PxSERR.modify(PxSERR::DIAG_B::SET);
regs.PxSERR.modify(PxSERR::DIAG_W::SET);
regs.PxSERR.modify(PxSERR::DIAG_I::SET);
regs.PxSERR.modify(PxSERR::DIAG_N::SET);
regs.PxSERR.modify(PxSERR::ERR_E::SET);
regs.PxSERR.modify(PxSERR::ERR_P::SET);
regs.PxSERR.modify(PxSERR::ERR_C::SET);
regs.PxSERR.modify(PxSERR::ERR_T::SET);
regs.PxSERR.modify(PxSERR::ERR_M::SET);
regs.PxSERR.modify(PxSERR::ERR_I::SET);
}
pub fn phys_no(&self) -> usize {
self.phys_no
}
pub fn has_device(&self) -> bool {
self.has_device
}
pub fn regs(&self) -> Option<DeviceRegsRead> {
self.fis_buf.get_d2h_reg_fis().map(|x| x.regs)
}
}