Skip to content

Extract a create_wrapper_function for use in allocator shim writing #113722

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 13, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 90 additions & 92 deletions compiler/rustc_codegen_llvm/src/allocator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use rustc_middle::ty::TyCtxt;
use rustc_session::config::{DebugInfo, OomStrategy};

use crate::debuginfo;
use crate::llvm::{self, False, True};
use crate::llvm::{self, Context, False, Module, True, Type};
use crate::ModuleLlvm;

pub(crate) unsafe fn codegen(
Expand All @@ -29,7 +29,6 @@ pub(crate) unsafe fn codegen(
};
let i8 = llvm::LLVMInt8TypeInContext(llcx);
let i8p = llvm::LLVMPointerTypeInContext(llcx, 0);
let void = llvm::LLVMVoidTypeInContext(llcx);

if kind == AllocatorKind::Default {
for method in ALLOCATOR_METHODS {
Expand All @@ -54,102 +53,25 @@ pub(crate) unsafe fn codegen(
panic!("invalid allocator output")
}
};
let ty = llvm::LLVMFunctionType(
output.unwrap_or(void),
args.as_ptr(),
args.len() as c_uint,
False,
);
let name = global_fn_name(method.name);
let llfn =
llvm::LLVMRustGetOrInsertFunction(llmod, name.as_ptr().cast(), name.len(), ty);

if tcx.sess.target.default_hidden_visibility {
llvm::LLVMRustSetVisibility(llfn, llvm::Visibility::Hidden);
}
if tcx.sess.must_emit_unwind_tables() {
let uwtable = attributes::uwtable_attr(llcx);
attributes::apply_to_llfn(llfn, llvm::AttributePlace::Function, &[uwtable]);
}

let callee = default_fn_name(method.name);
let callee =
llvm::LLVMRustGetOrInsertFunction(llmod, callee.as_ptr().cast(), callee.len(), ty);
llvm::LLVMRustSetVisibility(callee, llvm::Visibility::Hidden);

let llbb = llvm::LLVMAppendBasicBlockInContext(llcx, llfn, "entry\0".as_ptr().cast());

let llbuilder = llvm::LLVMCreateBuilderInContext(llcx);
llvm::LLVMPositionBuilderAtEnd(llbuilder, llbb);
let args = args
.iter()
.enumerate()
.map(|(i, _)| llvm::LLVMGetParam(llfn, i as c_uint))
.collect::<Vec<_>>();
let ret = llvm::LLVMRustBuildCall(
llbuilder,
ty,
callee,
args.as_ptr(),
args.len() as c_uint,
[].as_ptr(),
0 as c_uint,
);
llvm::LLVMSetTailCall(ret, True);
if output.is_some() {
llvm::LLVMBuildRet(llbuilder, ret);
} else {
llvm::LLVMBuildRetVoid(llbuilder);
}
llvm::LLVMDisposeBuilder(llbuilder);
let from_name = global_fn_name(method.name);
let to_name = default_fn_name(method.name);

create_wrapper_function(tcx, llcx, llmod, &from_name, &to_name, &args, output, false);
}
}

// rust alloc error handler
let args = [usize, usize]; // size, align

let ty = llvm::LLVMFunctionType(void, args.as_ptr(), args.len() as c_uint, False);
let name = "__rust_alloc_error_handler";
let llfn = llvm::LLVMRustGetOrInsertFunction(llmod, name.as_ptr().cast(), name.len(), ty);
// -> ! DIFlagNoReturn
let no_return = llvm::AttributeKind::NoReturn.create_attr(llcx);
attributes::apply_to_llfn(llfn, llvm::AttributePlace::Function, &[no_return]);

if tcx.sess.target.default_hidden_visibility {
llvm::LLVMRustSetVisibility(llfn, llvm::Visibility::Hidden);
}
if tcx.sess.must_emit_unwind_tables() {
let uwtable = attributes::uwtable_attr(llcx);
attributes::apply_to_llfn(llfn, llvm::AttributePlace::Function, &[uwtable]);
}

let callee = alloc_error_handler_name(alloc_error_handler_kind);
let callee = llvm::LLVMRustGetOrInsertFunction(llmod, callee.as_ptr().cast(), callee.len(), ty);
// -> ! DIFlagNoReturn
attributes::apply_to_llfn(callee, llvm::AttributePlace::Function, &[no_return]);
llvm::LLVMRustSetVisibility(callee, llvm::Visibility::Hidden);

let llbb = llvm::LLVMAppendBasicBlockInContext(llcx, llfn, "entry\0".as_ptr().cast());

let llbuilder = llvm::LLVMCreateBuilderInContext(llcx);
llvm::LLVMPositionBuilderAtEnd(llbuilder, llbb);
let args = args
.iter()
.enumerate()
.map(|(i, _)| llvm::LLVMGetParam(llfn, i as c_uint))
.collect::<Vec<_>>();
let ret = llvm::LLVMRustBuildCall(
llbuilder,
ty,
callee,
args.as_ptr(),
args.len() as c_uint,
[].as_ptr(),
0 as c_uint,
create_wrapper_function(
tcx,
llcx,
llmod,
"__rust_alloc_error_handler",
&alloc_error_handler_name(alloc_error_handler_kind),
&[usize, usize], // size, align
None,
true,
);
llvm::LLVMSetTailCall(ret, True);
llvm::LLVMBuildRetVoid(llbuilder);
llvm::LLVMDisposeBuilder(llbuilder);

// __rust_alloc_error_handler_should_panic
let name = OomStrategy::SYMBOL;
Expand All @@ -175,3 +97,79 @@ pub(crate) unsafe fn codegen(
dbg_cx.finalize(tcx.sess);
}
}

fn create_wrapper_function(
tcx: TyCtxt<'_>,
llcx: &Context,
llmod: &Module,
from_name: &str,
to_name: &str,
args: &[&Type],
output: Option<&Type>,
no_return: bool,
) {
unsafe {
let ty = llvm::LLVMFunctionType(
output.unwrap_or_else(|| llvm::LLVMVoidTypeInContext(llcx)),
args.as_ptr(),
args.len() as c_uint,
False,
);
let llfn = llvm::LLVMRustGetOrInsertFunction(
llmod,
from_name.as_ptr().cast(),
from_name.len(),
ty,
);
let no_return = if no_return {
// -> ! DIFlagNoReturn
let no_return = llvm::AttributeKind::NoReturn.create_attr(llcx);
attributes::apply_to_llfn(llfn, llvm::AttributePlace::Function, &[no_return]);
Some(no_return)
} else {
None
};

if tcx.sess.target.default_hidden_visibility {
llvm::LLVMRustSetVisibility(llfn, llvm::Visibility::Hidden);
}
if tcx.sess.must_emit_unwind_tables() {
let uwtable = attributes::uwtable_attr(llcx);
attributes::apply_to_llfn(llfn, llvm::AttributePlace::Function, &[uwtable]);
}

let callee =
llvm::LLVMRustGetOrInsertFunction(llmod, to_name.as_ptr().cast(), to_name.len(), ty);
if let Some(no_return) = no_return {
// -> ! DIFlagNoReturn
attributes::apply_to_llfn(callee, llvm::AttributePlace::Function, &[no_return]);
}
llvm::LLVMRustSetVisibility(callee, llvm::Visibility::Hidden);

let llbb = llvm::LLVMAppendBasicBlockInContext(llcx, llfn, "entry\0".as_ptr().cast());

let llbuilder = llvm::LLVMCreateBuilderInContext(llcx);
llvm::LLVMPositionBuilderAtEnd(llbuilder, llbb);
let args = args
.iter()
.enumerate()
.map(|(i, _)| llvm::LLVMGetParam(llfn, i as c_uint))
.collect::<Vec<_>>();
let ret = llvm::LLVMRustBuildCall(
llbuilder,
ty,
callee,
args.as_ptr(),
args.len() as c_uint,
[].as_ptr(),
0 as c_uint,
);
llvm::LLVMSetTailCall(ret, True);
if output.is_some() {
llvm::LLVMBuildRet(llbuilder, ret);
} else {
llvm::LLVMBuildRetVoid(llbuilder);
}
llvm::LLVMDisposeBuilder(llbuilder);
}
}