diff --git a/src/memfd_secret.c b/src/memfd_secret.c index 3c5320a..16f1468 100644 --- a/src/memfd_secret.c +++ b/src/memfd_secret.c @@ -1,13 +1,16 @@ #include +#include #include +#include #include + #include "ifunc.h" #include -#define READ_ONCE(slot) ((__typeof__(slot))(*(const volatile __typeof__(slot)*)(slot))) -#define WRITE_ONCE(slot, value) (*((volatile __typeof__(slot)*)(slot)) = (value)) +#define READ_ONCE(slot) ((__typeof__(slot))(*(const volatile __typeof__(slot)*)&(slot))) +#define WRITE_ONCE(slot, value) (*((volatile __typeof__(slot)*)&(slot)) = (value)) __attribute__((gnu_inline)) static inline @@ -69,8 +72,15 @@ int IFUNC_IMPL(memfd_secret, $enabled) (unsigned int flags) __attribute__((visibility("hidden"))) int IFUNC_IMPL(memfd_secret, $disabled) (unsigned int flags) { - if( FD_CLOEXEC != MEMFD_CLOEXEC ) { // NOTE: This is a constant expression, and this code will be removed if they are equal. - //TODO: Translate mask `flags`, from `FD_CLOEXEC` (if it is set) -> `MEMFD_CLOEXEC`. + // Translate mask `flags`, from `FD_CLOEXEC` (if it is set) -> `MEMFD_CLOEXEC`. + if( FD_CLOEXEC != MFD_CLOEXEC ) { // NOTE: This is a constant expression, and this code will be removed if they are equal. + // Check if all bit(s) of `FD_CLOEXEC` is in `flags`. + if((flags & FD_CLOEXEC) == FD_CLOEXEC) { + // Mask out the `FD_CLOEXEC` bit(s) + flags &= ~FD_CLOEXEC; + // Mask in the `MFD_CLOEXEC` bit(s) + flags |= MFD_CLOEXEC; + } // NOTE: We do not need to check cases where `flags & FD_CLOEXEC` is non-zero but the above branch is not hit, that would be an invalid call anyway. Plus I highly doubt any system will set `FD_CLOEXEC` to be more than 1 set bit anyway. } return memfd_create("memfd_secret@?", flags); }