Skip to content

Instantly share code, notes, and snippets.

@folkertdev
Created June 7, 2023 19:42
Show Gist options
  • Save folkertdev/80cf998e84219e11dc3506f5975045bb to your computer and use it in GitHub Desktop.
Save folkertdev/80cf998e84219e11dc3506f5975045bb to your computer and use it in GitHub Desktop.
mio responding to a message arriving on a socket's error queue (reported as EPOLLERR)
use std::net::SocketAddr;
use std::os::unix::io::AsRawFd;
use std::thread;
use mio::net::UdpSocket;
use mio::{Events, Interest, Poll, Token};
/// Token identifying the UDP socket's registration with the mio poll instance.
const SOCKET_TOKEN: Token = Token(0);

/// Size of the buffer that receives the payload of each error-queue message.
const BUFFER_SIZE: usize = 2048;

/// Control-buffer size for one `IP_TTL` control message carrying a `c_int` payload.
///
/// This is `CMSG_SPACE` (header + payload + trailing padding), which is what
/// `msg_controllen` must be set to; `CMSG_LEN` is only the value for `cmsg_len`.
// SAFETY: CMSG_SPACE only performs arithmetic on its argument.
const TTL_CMSG_SPACE: usize =
    unsafe { libc::CMSG_SPACE(std::mem::size_of::<libc::c_int>() as libc::c_uint) } as usize;

/// Control-buffer size used when reading the error queue.
///
/// Comfortably larger than one `IP_RECVERR` message, whose payload is a
/// `sock_extended_err` followed by the offending `sockaddr_in`.
const RECV_CONTROL_SIZE: usize = 256;

/// Zero-initialised ancillary-data buffer of `N` bytes with the alignment of `cmsghdr`.
///
/// Casting a plain byte (or `i32`) array to `*mut cmsghdr` is not guaranteed to be
/// suitably aligned; overlaying a `cmsghdr` in a union is the approach shown in cmsg(3).
#[repr(C)]
#[allow(dead_code)] // fields are only accessed through raw pointers handed to the kernel
union CmsgBuffer<const N: usize> {
    _align: libc::cmsghdr,
    bytes: [u8; N],
}

impl<const N: usize> CmsgBuffer<N> {
    /// Returns an all-zero buffer.
    fn zeroed() -> Self {
        Self { bytes: [0; N] }
    }

    /// Raw pointer to the start of the buffer, suitable for `msghdr::msg_control`.
    fn as_mut_ptr(&mut self) -> *mut libc::c_void {
        (self as *mut Self).cast()
    }
}

/// Demonstrates mio waking up when a message lands on a UDP socket's error queue.
///
/// A UDP socket bound to 127.0.0.1:8000 gets `IP_RECVERR` enabled and is registered with mio
/// for readability. A background thread sends a datagram from that socket to 127.0.0.1:1234
/// once per second. When the kernel queues an error for one of those datagrams (presumably an
/// ICMP port-unreachable, assuming nothing listens on port 1234), mio reports the event as
/// `is_error()` and the main loop drains the error queue with `recvmsg(..., MSG_ERRQUEUE)`.
///
/// # Panics
///
/// Panics on any unexpected OS error; this is a demo, not a library.
fn main() {
    // Create a UDP socket and bind it to a local address.
    let socket_addr: SocketAddr = "127.0.0.1:8000".parse().unwrap();
    let mut socket = UdpSocket::bind(socket_addr).expect("Failed to bind socket");
    let socket_fd = socket.as_raw_fd();

    enable_recv_err(socket_fd);

    let mut poll = Poll::new().expect("Failed to create poll");
    poll.registry()
        .register(&mut socket, SOCKET_TOKEN, Interest::READABLE)
        .expect("Failed to register socket");
    let mut events = Events::with_capacity(128);

    // The sender only receives the raw fd. That is sound because `socket` lives until
    // `main` returns, and the loop below never returns normally.
    spawn_sender(socket_fd);

    let mut buffer = [0u8; BUFFER_SIZE];
    loop {
        poll.poll(&mut events, None).expect("Failed to poll events");
        for event in events.iter() {
            match event.token() {
                SOCKET_TOKEN => {
                    dbg!(&event);
                    if event.is_error() {
                        println!("reading");
                        drain_error_queue(socket_fd, &mut buffer);
                    }
                }
                _ => unreachable!(),
            }
        }
    }
}

/// Enables `IP_RECVERR` on `fd`, so errors for datagrams sent on it are queued on the
/// socket's error queue, readable with `recvmsg(..., MSG_ERRQUEUE)`.
///
/// # Panics
///
/// Panics if `setsockopt` fails.
fn enable_recv_err(fd: std::os::unix::io::RawFd) {
    let recv_err: libc::c_int = 1;
    // SAFETY: the option value pointer and length describe a live `c_int` on the stack.
    let res = unsafe {
        libc::setsockopt(
            fd,
            libc::SOL_IP,
            libc::IP_RECVERR,
            &recv_err as *const libc::c_int as *const libc::c_void,
            std::mem::size_of::<libc::c_int>() as libc::socklen_t,
        )
    };
    if res == -1 {
        panic!("{:?}", std::io::Error::last_os_error());
    }
}

/// Spawns a thread that sends `"Hello, Socket!"` from `fd` to 127.0.0.1:1234 once per second
/// with `sendmsg`, attaching an `IP_TTL` control message that sets the TTL to 64.
///
/// The caller must keep the socket behind `fd` open for as long as the process runs.
///
/// # Panics
///
/// The spawned thread panics if `sendmsg` fails.
fn spawn_sender(fd: std::os::unix::io::RawFd) {
    thread::spawn(move || {
        // Destination 127.0.0.1:1234; port and address are stored in network byte order.
        // SAFETY: `sockaddr_in` is plain old data, so all-zero is a valid value.
        let mut dest_addr: libc::sockaddr_in = unsafe { std::mem::zeroed() };
        dest_addr.sin_family = libc::AF_INET as _;
        dest_addr.sin_port = 1234u16.to_be();
        dest_addr.sin_addr.s_addr = libc::INADDR_LOOPBACK.to_be();

        let message = "Hello, Socket!";
        let mut iov = libc::iovec {
            // sendmsg only reads through this pointer; `*mut` is just what `iovec` requires.
            iov_base: message.as_ptr() as *mut libc::c_void,
            iov_len: message.len(),
        };

        let mut cmsg_buf = CmsgBuffer::<TTL_CMSG_SPACE>::zeroed();

        // SAFETY: `msghdr` is plain old data; all-zero means "no name, no iov, no control".
        let mut msg: libc::msghdr = unsafe { std::mem::zeroed() };
        msg.msg_name = std::ptr::addr_of_mut!(dest_addr).cast();
        msg.msg_namelen = std::mem::size_of::<libc::sockaddr_in>() as libc::socklen_t;
        msg.msg_iov = &mut iov;
        msg.msg_iovlen = 1;
        msg.msg_control = cmsg_buf.as_mut_ptr();
        // The size of the whole control buffer (CMSG_SPACE), not the message's CMSG_LEN.
        msg.msg_controllen = TTL_CMSG_SPACE as _;

        // SAFETY: `msg_control` points at a zeroed, `cmsghdr`-aligned buffer of
        // `msg_controllen` bytes, which CMSG_SPACE sized for one header plus a `c_int`.
        unsafe {
            let cmsg = libc::CMSG_FIRSTHDR(&msg);
            assert!(!cmsg.is_null(), "control buffer too small for a cmsghdr");
            (*cmsg).cmsg_len =
                libc::CMSG_LEN(std::mem::size_of::<libc::c_int>() as libc::c_uint) as _;
            (*cmsg).cmsg_level = libc::IPPROTO_IP;
            (*cmsg).cmsg_type = libc::IP_TTL as libc::c_int;
            // CMSG_DATA is not guaranteed to be aligned for the payload type (cmsg(3)).
            libc::CMSG_DATA(cmsg).cast::<libc::c_int>().write_unaligned(64);
        }

        loop {
            thread::sleep(std::time::Duration::from_secs(1));
            println!("sending");
            // SAFETY: every pointer in `msg` refers to a local of this closure, all of which
            // outlive the call; `fd` is kept open by the caller.
            let res = unsafe { libc::sendmsg(fd, &msg, 0) };
            if res == -1 {
                panic!("{:?}", std::io::Error::last_os_error());
            }
        }
    });
}

/// Reads and prints every message currently queued on `fd`'s error queue.
///
/// mio registrations are edge-triggered, so one wakeup can stand for several queued errors.
/// Reading until `recvmsg` reports `EAGAIN` (`WouldBlock`) ensures none is left behind. It
/// also turns a wakeup that finds the queue already empty into a no-op instead of a panic.
///
/// # Panics
///
/// Panics if `recvmsg` fails with anything other than `EAGAIN` or `EINTR`.
fn drain_error_queue(fd: std::os::unix::io::RawFd, buffer: &mut [u8]) {
    let mut control = CmsgBuffer::<RECV_CONTROL_SIZE>::zeroed();
    loop {
        let mut iov = libc::iovec {
            iov_base: buffer.as_mut_ptr().cast(),
            iov_len: buffer.len(),
        };
        // SAFETY: `msghdr` is plain old data; all-zero means "no name, no iov, no control".
        let mut msg: libc::msghdr = unsafe { std::mem::zeroed() };
        msg.msg_iov = &mut iov;
        msg.msg_iovlen = 1;
        // Without a control buffer the `sock_extended_err` describing the error is cut off
        // (MSG_CTRUNC). `recvmsg` overwrites `msg_controllen` with the bytes it used, which is
        // why `msg` is rebuilt on every iteration.
        msg.msg_control = control.as_mut_ptr();
        msg.msg_controllen = RECV_CONTROL_SIZE as _;

        // SAFETY: `iov` and `control` describe live, writable buffers of the advertised sizes.
        let recv_len = unsafe { libc::recvmsg(fd, &mut msg, libc::MSG_ERRQUEUE) };
        if recv_len == -1 {
            let err = std::io::Error::last_os_error();
            match err.kind() {
                // Queue is empty: everything pending has been read.
                std::io::ErrorKind::WouldBlock => return,
                std::io::ErrorKind::Interrupted => continue,
                _ => panic!("{err:?}"),
            }
        }
        println!("Received message from error queue:");
        println!("Message Length: {} bytes", recv_len);
        println!(
            "Message Content: {}",
            String::from_utf8_lossy(&buffer[..recv_len as usize])
        );
        print_extended_errors(&msg);
    }
}

/// Prints the extended error carried by each `IP_RECVERR` control message in `msg`.
///
/// `msg` must have just been filled in by a successful `recvmsg` call.
fn print_extended_errors(msg: &libc::msghdr) {
    if msg.msg_flags & libc::MSG_CTRUNC != 0 {
        println!("(control data truncated)");
    }
    // SAFETY: after a successful `recvmsg`, `msg_control`/`msg_controllen` describe the
    // control messages the kernel wrote; CMSG_FIRSTHDR/CMSG_NXTHDR stay within those bounds
    // and return null once they are exhausted.
    unsafe {
        let mut cmsg = libc::CMSG_FIRSTHDR(msg);
        while !cmsg.is_null() {
            if (*cmsg).cmsg_level == libc::SOL_IP && (*cmsg).cmsg_type == libc::IP_RECVERR {
                let ee = libc::CMSG_DATA(cmsg)
                    .cast::<libc::sock_extended_err>()
                    .read_unaligned();
                // errno values are small positive integers, so the cast cannot truncate.
                let err = std::io::Error::from_raw_os_error(ee.ee_errno as i32);
                println!(
                    "Extended error: {} (origin {}, type {}, code {})",
                    err, ee.ee_origin, ee.ee_type, ee.ee_code
                );
            }
            cmsg = libc::CMSG_NXTHDR(msg, cmsg);
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment