From fba8877f4a6cc1e245d2adc416408eeb980ee645 Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Sun, 12 Mar 2023 13:29:20 +0100 Subject: [PATCH] cleanup netlink impl --- src/interface/android.rs | 340 ++++++++++++++++++++++++--------------- 1 file changed, 206 insertions(+), 134 deletions(-) diff --git a/src/interface/android.rs b/src/interface/android.rs index cc85a42..b8addde 100644 --- a/src/interface/android.rs +++ b/src/interface/android.rs @@ -56,166 +56,238 @@ mod netlink { LinkMessage, RtnlMessage, }; use netlink_sys::{protocols::NETLINK_ROUTE, Socket}; + use std::io; use std::net::{Ipv4Addr, Ipv6Addr}; use crate::interface::{Interface, InterfaceType, Ipv4Net, Ipv6Net, MacAddr}; pub fn unix_interfaces() -> Vec { - let socket = Socket::new(NETLINK_ROUTE).unwrap(); - let mut ifaces = Vec::new(); - enumerate_netlink( - &socket, - RtnlMessage::GetLink(LinkMessage::default()), - &mut ifaces, - ); - enumerate_netlink( - &socket, - RtnlMessage::GetAddress(AddressMessage::default()), - &mut ifaces, - ); - + if let Ok(socket) = Socket::new(NETLINK_ROUTE) { + if let Err(err) = enumerate_netlink( + &socket, + RtnlMessage::GetLink(LinkMessage::default()), + &mut ifaces, + handle_new_link, + ) { + eprintln!("unable to list interfaces: {:?}", err); + }; + if let Err(err) = enumerate_netlink( + &socket, + RtnlMessage::GetAddress(AddressMessage::default()), + &mut ifaces, + handle_new_addr, + ) { + eprintln!("unable to list addresses: {:?}", err); + } + } ifaces } - fn enumerate_netlink(socket: &Socket, msg: RtnlMessage, ifaces: &mut Vec) { - let mut packet = NetlinkMessage::new(NetlinkHeader::default(), NetlinkPayload::from(msg)); - packet.header.flags = NLM_F_DUMP | NLM_F_REQUEST; - packet.header.sequence_number = 1; - packet.finalize(); + fn handle_new_link(ifaces: &mut Vec, msg: RtnlMessage) -> io::Result<()> { + match msg { + RtnlMessage::NewLink(link_msg) => { + let mut interface: Interface = Interface { + index: link_msg.header.index, + name: String::new(), + friendly_name: None, + description: None, + if_type: InterfaceType::try_from(link_msg.header.link_layer_type as u32) + .unwrap_or(InterfaceType::Unknown), + mac_addr: None, + ipv4: Vec::new(), + ipv6: Vec::new(), + flags: link_msg.header.flags, + transmit_speed: None, + receive_speed: None, + gateway: None, + }; + + for nla in link_msg.nlas { + match nla { + LinkNla::IfName(name) => { + interface.name = name; + } + LinkNla::Address(addr) => { + match addr.len() { + 6 => { + interface.mac_addr = + Some(MacAddr::new(addr.try_into().unwrap())); + } + 4 => { + let ip = Ipv4Addr::from(<[u8; 4]>::try_from(addr).unwrap()); + interface + .ipv4 + .push(Ipv4Net::new_with_netmask(ip, Ipv4Addr::UNSPECIFIED)); + } + _ => { + // unclear what these would be + } + } + } + _ => {} + } + } + ifaces.push(interface); + } + _ => {} + } + + Ok(()) + } + + fn handle_new_addr(ifaces: &mut Vec, msg: RtnlMessage) -> io::Result<()> { + match msg { + RtnlMessage::NewAddress(addr_msg) => { + if let Some(interface) = + ifaces.iter_mut().find(|i| i.index == addr_msg.header.index) + { + for nla in addr_msg.nlas { + match nla { + AddressNla::Address(addr) => match addr.len() { + 4 => { + let ip = Ipv4Addr::from(<[u8; 4]>::try_from(addr).unwrap()); + interface + .ipv4 + .push(Ipv4Net::new(ip, addr_msg.header.prefix_len)); + } + 16 => { + let ip = Ipv6Addr::from(<[u8; 16]>::try_from(addr).unwrap()); + interface + .ipv6 + .push(Ipv6Net::new(ip, addr_msg.header.prefix_len)); + } + _ => { + // what else? + } + }, + _ => {} + } + } + } else { + eprintln!( + "found unknown interface with index: {}", + addr_msg.header.index + ); + } + } + _ => {} + } + + Ok(()) + } - let mut buf = vec![0; packet.header.length as usize]; + struct NetlinkIter<'a> { + socket: &'a Socket, + /// Buffer for received data. + buf: Vec, + /// Size of the data available in `buf`. + size: usize, + /// Offset into the data currently in `buf`. + offset: usize, + /// Are we don iterating? + done: bool, + } - // TODO: gracefully handle error - assert!(buf.len() == packet.buffer_len()); - packet.serialize(&mut buf[..]); + impl<'a> NetlinkIter<'a> { + fn new(socket: &'a Socket, msg: RtnlMessage) -> io::Result { + let mut packet = + NetlinkMessage::new(NetlinkHeader::default(), NetlinkPayload::from(msg)); + packet.header.flags = NLM_F_DUMP | NLM_F_REQUEST; + packet.header.sequence_number = 1; + packet.finalize(); - socket.send(&buf[..], 0).unwrap(); + let mut buf = vec![0; packet.header.length as usize]; + assert_eq!(buf.len(), packet.buffer_len()); + packet.serialize(&mut buf[..]); + socket.send(&buf[..], 0)?; - let mut receive_buffer = vec![0; 4096]; - let mut offset = 0; + Ok(NetlinkIter { + socket, + offset: 0, + size: 0, + buf: vec![0u8; 4096], + done: false, + }) + } + } - loop { - let size = socket.recv(&mut &mut receive_buffer[..], 0).unwrap(); + impl<'a> Iterator for NetlinkIter<'a> { + type Item = io::Result; - loop { - let bytes = &receive_buffer[offset..]; - let rx_packet: NetlinkMessage = - NetlinkMessage::deserialize(bytes).unwrap(); + fn next(&mut self) -> Option { + if self.done { + return None; + } - match rx_packet.payload { - NetlinkPayload::Done => { - return; - } - NetlinkPayload::Error(err) => { - eprintln!("Error: {:?}", err); - return; + while !self.done { + // Outer loop + if self.size == 0 { + match self.socket.recv(&mut &mut self.buf[..], 0) { + Ok(size) => { + self.size = size; + self.offset = 0; + } + Err(err) => { + self.done = true; + return Some(Err(err)); + } } - NetlinkPayload::InnerMessage(msg) => { - match msg { - RtnlMessage::NewLink(link_msg) => { - let mut interface: Interface = Interface { - index: link_msg.header.index, - name: String::new(), - friendly_name: None, - description: None, - if_type: InterfaceType::try_from( - link_msg.header.link_layer_type as u32, - ) - .unwrap_or(InterfaceType::Unknown), - mac_addr: None, - ipv4: Vec::new(), - ipv6: Vec::new(), - flags: link_msg.header.flags, - transmit_speed: None, - receive_speed: None, - gateway: None, - }; - - for nla in link_msg.nlas { - match nla { - LinkNla::IfName(name) => { - interface.name = name; - } - LinkNla::Address(addr) => { - match addr.len() { - 6 => { - interface.mac_addr = Some(MacAddr::new( - addr.try_into().unwrap(), - )); - } - 4 => { - let ip = Ipv4Addr::from( - <[u8; 4]>::try_from(addr).unwrap(), - ); - interface.ipv4.push(Ipv4Net::new_with_netmask( - ip, - Ipv4Addr::UNSPECIFIED, - )); - } - _ => { - // unclear what these would be - } - } - } - _ => {} - } - } - ifaces.push(interface); + } + + let bytes = &self.buf[self.offset..]; + match NetlinkMessage::::deserialize(bytes) { + Ok(packet) => { + self.offset += packet.header.length as usize; + if packet.header.length == 0 || self.offset == self.size { + // mark this message as fully read + self.size = 0; + } + match packet.payload { + NetlinkPayload::Done => { + self.done = true; + return None; } - RtnlMessage::NewAddress(addr_msg) => { - if let Some(interface) = - ifaces.iter_mut().find(|i| i.index == addr_msg.header.index) - { - for nla in addr_msg.nlas { - match nla { - AddressNla::Address(addr) => match addr.len() { - 4 => { - let ip = Ipv4Addr::from( - <[u8; 4]>::try_from(addr).unwrap(), - ); - interface.ipv4.push(Ipv4Net::new( - ip, - addr_msg.header.prefix_len, - )); - } - 16 => { - let ip = Ipv6Addr::from( - <[u8; 16]>::try_from(addr).unwrap(), - ); - interface.ipv6.push(Ipv6Net::new( - ip, - addr_msg.header.prefix_len, - )); - } - _ => { - // what else? - } - }, - _ => {} - } - } - } else { - eprintln!( - "found unknown interface with index: {}", - addr_msg.header.index - ); - } + NetlinkPayload::Error(err) => { + self.done = true; + return Some(Err(io::Error::new( + io::ErrorKind::Other, + err.to_string(), + ))); } + NetlinkPayload::InnerMessage(msg) => return Some(Ok(msg)), _ => { - // not expecting other messages + continue; } } } - _ => {} - } - offset += rx_packet.header.length as usize; - if offset == size || rx_packet.header.length == 0 { - offset = 0; - break; + Err(err) => { + self.done = true; + return Some(Err(io::Error::new(io::ErrorKind::Other, err.to_string()))); + } } } + + None + } + } + + fn enumerate_netlink( + socket: &Socket, + msg: RtnlMessage, + ifaces: &mut Vec, + cb: F, + ) -> io::Result<()> + where + F: Fn(&mut Vec, RtnlMessage) -> io::Result<()>, + { + let iter = NetlinkIter::new(socket, msg)?; + for msg in iter { + let msg = msg?; + cb(ifaces, msg)?; } + + Ok(()) } #[cfg(test)]