Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automated Code Change #66578

Merged
merged 1 commit into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
64 changes: 0 additions & 64 deletions third_party/xla/xla/service/cpu/cpu_executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -269,70 +269,6 @@ Status CpuExecutable::ExecuteComputeFunction(
return OkStatus();
}

absl::StatusOr<std::unique_ptr<Executable>> CpuExecutable::LoadFromObjFile(
std::unique_ptr<HloModule> hlo_module, absl::string_view obj_file,
absl::string_view mlir_module,
std::unique_ptr<BufferAssignment> buffer_assignment,
XlaFrameworkMapping xla_framework_mapping,
runtime::JitExecutable::Options opts) {
VLOG(1) << "Load serialized Cpu executable from object file: module="
<< hlo_module->name();

runtime::DialectRegistry dialects;
opts.compiler.register_dialects(dialects);
auto threading = mlir::MLIRContext::Threading::DISABLED;
auto ctx = std::make_unique<mlir::MLIRContext>(*dialects, threading);
ctx->loadAllAvailableDialects();

// Load MLIR module behind the compiled object file.
auto module = mlir::parseSourceString<mlir::ModuleOp>(mlir_module, ctx.get());
if (!module) return Internal("Failed to parse AOT compiled module");

llvm::StringRef data(obj_file.data(), obj_file.size());
auto buffer = llvm::MemoryBuffer::getMemBuffer(data, hlo_module->name());

// Recover function signatures using calling convention and type converter.
auto func = mlir::cast<mlir::func::FuncOp>(module->lookupSymbol("main"));
mlir::FunctionType func_type = func.getFunctionType();
absl::StatusOr<runtime::FunctionType> sig =
opts.compiler.type_converter.Convert(func_type);
if (!sig.ok())
return Internal("Type converter failed to convert function type");

mlir::FunctionType runtime_type = opts.compiler.calling_convention(func_type);
if (!runtime_type)
return Internal("Calling convention failed to convert function type");

absl::StatusOr<runtime::FunctionType> runtime_sig =
opts.compiler.type_converter.Convert(runtime_type);
if (!runtime_sig.ok())
return Internal(
"Type converter failed to convert runtime function type");

// Cpu executable has a single exported function.
std::vector<runtime::Executable::LoadFunction> functions;
functions.push_back({"main", std::move(*sig), std::move(*runtime_sig)});

// Load XLA Runtime executable from an object file.
auto executable = runtime::Executable::LoadFromObjFile(
hlo_module->name(), std::move(buffer), std::move(functions),
opts.compiler.symbols_binding);

if (!executable.ok())
return Internal("Failed to load XLA Runtime executable: %s",
executable.status().message());

// Move runtime::Executable ownership to the XlaRuntimeCpuExecutable.
auto executable_ptr =
std::make_unique<runtime::Executable>(std::move(executable.value()));
auto xla_runtime_executable = std::make_unique<XlaRuntimeCpuExecutable>(
std::move(executable_ptr), xla_framework_mapping);

return CpuExecutable::Create(std::move(hlo_module), nullptr, nullptr,
std::move(buffer_assignment),
std::move(xla_runtime_executable));
}

absl::StatusOr<ExecutionOutput> CpuExecutable::CreateResultShapedBuffer(
const ServiceExecutableRunOptions* run_options,
absl::Span<MaybeOwningDeviceMemory> buffers,
Expand Down
37 changes: 0 additions & 37 deletions third_party/xla/xla/service/cpu/cpu_executable.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,19 +92,6 @@ class XlaRuntimeCpuExecutable {
return std::string_view(obj_file->getBuffer());
}

absl::StatusOr<std::string_view> GetMlirModule() const {
if (!std::holds_alternative<std::unique_ptr<runtime::JitExecutable>>(
executable_)) {
return Internal("No JitExecutable");
}

runtime::JitExecutable* jit_executable =
std::get<std::unique_ptr<runtime::JitExecutable>>(executable_).get();
return jit_executable->mlir_module();
}

XlaFrameworkMapping xla_framework_mapping() { return xla_framework_mapping_; }

private:
// In JIT compilation mode `JitExecutable` is used. In AOT compilation mode
// `Executable` is used.
Expand Down Expand Up @@ -161,15 +148,6 @@ class CpuExecutable : public Executable {
absl::Span<MaybeOwningDeviceMemory const> buffers,
HloExecutionProfile* hlo_execution_profile);

// Returns an Executable that is loaded from an object file (XLA program
// compiled to a native function using the XLA Runtime stack).
static absl::StatusOr<std::unique_ptr<Executable>> LoadFromObjFile(
std::unique_ptr<HloModule> hlo_module, absl::string_view obj_file,
absl::string_view mlir_module,
std::unique_ptr<BufferAssignment> buffer_assignment,
XlaFrameworkMapping xla_framework_mapping,
runtime::JitExecutable::Options opts);

absl::Span<const std::string> obj_files() const { return obj_files_; }

void set_obj_files(std::vector<std::string> obj_files) {
Expand Down Expand Up @@ -201,21 +179,6 @@ class CpuExecutable : public Executable {

int64_t SizeOfGeneratedCodeInBytes() const override;

absl::StatusOr<std::string_view> GetObjFile() const {
if (!IsXlaRuntime()) return Unimplemented("Not an XLA Runtime executable");
return xla_runtime_executable_->GetObjFile();
}

absl::StatusOr<std::string_view> GetMlirModule() const {
if (!IsXlaRuntime()) return Unimplemented("Not an XLA Runtime executable");
return xla_runtime_executable_->GetMlirModule();
}

absl::StatusOr<XlaFrameworkMapping> GetXlaFrameworkMapping() const {
if (!IsXlaRuntime()) return Unimplemented("Not an XLA Runtime executable");
return xla_runtime_executable_->xla_framework_mapping();
}

private:
// Creates an array suitable for passing as the "buffer_table" argument to the
// JIT compiled function pointer.
Expand Down