Skip to content

Pull upstream changes #34

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions cmd/aws-lambda-rie/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
package main

import (
"bytes"
"fmt"
"io"
"io/ioutil"
"math"
"net/http"
Expand All @@ -24,7 +24,7 @@ import (

type Sandbox interface {
Init(i *interop.Init, invokeTimeoutMs int64)
Invoke(responseWriter io.Writer, invoke *interop.Invoke) error
Invoke(responseWriter http.ResponseWriter, invoke *interop.Invoke) error
}

var initDone bool
Expand Down Expand Up @@ -98,7 +98,7 @@ func InvokeHandler(w http.ResponseWriter, r *http.Request, sandbox Sandbox) {
InvokedFunctionArn: fmt.Sprintf("arn:aws:lambda:us-east-1:012345678912:function:%s", GetenvWithDefault("AWS_LAMBDA_FUNCTION_NAME", "test_function")),
TraceID: r.Header.Get("X-Amzn-Trace-Id"),
LambdaSegmentID: r.Header.Get("X-Amzn-Segment-Id"),
Payload: bodyBytes,
Payload: bytes.NewReader(bodyBytes),
CorrelationID: "invokeCorrelationID",
}
fmt.Println("START RequestId: " + invokePayload.ID + " Version: " + functionVersion)
Expand Down
17 changes: 7 additions & 10 deletions cmd/aws-lambda-rie/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import (
log "github.com/sirupsen/logrus"
)


const (
optBootstrap = "/opt/bootstrap"
runtimeBootstrap = "/var/runtime/bootstrap"
Expand Down Expand Up @@ -58,24 +57,22 @@ func getCLIArgs() (options, []string) {
}

func getBootstrap(args []string, opts options) (*rapidcore.Bootstrap, string) {
var bootstrapLookupCmdList [][]string
var bootstrapLookupCmd []string
var handler string
currentWorkingDir := "/var/task" // default value

if len(args) <= 1 {
bootstrapLookupCmdList = [][]string{
[]string{fmt.Sprintf("%s/bootstrap", currentWorkingDir)},
[]string{optBootstrap},
[]string{runtimeBootstrap},
bootstrapLookupCmd = []string{
fmt.Sprintf("%s/bootstrap", currentWorkingDir),
optBootstrap,
runtimeBootstrap,
}

// handler is used later to set an env var for Lambda Image support
handler = ""
} else if len(args) > 1 {

bootstrapLookupCmdList = [][]string{
args[1:],
}
bootstrapLookupCmd = args[1:]

if cwd, err := os.Getwd(); err == nil {
currentWorkingDir = cwd
Expand All @@ -92,5 +89,5 @@ func getBootstrap(args []string, opts options) (*rapidcore.Bootstrap, string) {
log.Panic("insufficient arguments: bootstrap not provided")
}

return rapidcore.NewBootstrap(bootstrapLookupCmdList, currentWorkingDir), handler
return rapidcore.NewBootstrapSingleCmd(bootstrapLookupCmd, currentWorkingDir), handler
}
41 changes: 41 additions & 0 deletions lambda/core/directinvoke/customerheaders.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

package directinvoke

import (
"bytes"
"encoding/base64"
"encoding/json"
)

type CustomerHeaders struct {
CognitoIdentityID string `json:"Cognito-Identity-Id"`
CognitoIdentityPoolID string `json:"Cognito-Identity-Pool-Id"`
ClientContext string `json:"Client-Context"`
}

func (s CustomerHeaders) Dump() string {
if (s == CustomerHeaders{}) {
return ""
}

custHeadersJSON, err := json.Marshal(&s)
if err != nil {
panic(err)
}

return base64.StdEncoding.EncodeToString(custHeadersJSON)
}

func (s *CustomerHeaders) Load(in string) error {
*s = CustomerHeaders{}

if in == "" {
return nil
}

base64Decoder := base64.NewDecoder(base64.StdEncoding, bytes.NewReader([]byte(in)))

return json.NewDecoder(base64Decoder).Decode(s)
}
25 changes: 25 additions & 0 deletions lambda/core/directinvoke/customerheaders_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

package directinvoke

import (
"github.com/stretchr/testify/require"
"testing"
)

func TestCustomerHeadersEmpty(t *testing.T) {
in := CustomerHeaders{}
out := CustomerHeaders{}

require.NoError(t, out.Load(in.Dump()))
require.Equal(t, in, out)
}

func TestCustomerHeaders(t *testing.T) {
in := CustomerHeaders{CognitoIdentityID: "asd"}
out := CustomerHeaders{}

require.NoError(t, out.Load(in.Dump()))
require.Equal(t, in, out)
}
106 changes: 106 additions & 0 deletions lambda/core/directinvoke/directinvoke.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

package directinvoke

import (
"io"
"net/http"

"github.com/go-chi/chi"
"go.amzn.com/lambda/interop"
)

const (
InvokeIDHeader = "Invoke-Id"
InvokedFunctionArnHeader = "Invoked-Function-Arn"
VersionIDHeader = "Invoked-Function-Version"
ReservationTokenHeader = "Reservation-Token"
CustomerHeadersHeader = "Customer-Headers"
ContentTypeHeader = "Content-Type"

ErrorTypeHeader = "Error-Type"

EndOfResponseTrailer = "End-Of-Response"

SandboxErrorType = "Error.Sandbox"
)

const (
EndOfResponseComplete = "Complete"
EndOfResponseTruncated = "Truncated"
EndOfResponseOversized = "Oversized"
)

var MaxDirectResponseSize int64 = interop.MaxPayloadSize // this is intentionally not a constant so we can configure it via CLI

func renderBadRequest(w http.ResponseWriter, r *http.Request, errorType string) {
w.Header().Set(ErrorTypeHeader, errorType)
w.WriteHeader(http.StatusBadRequest)
w.Header().Set(EndOfResponseTrailer, EndOfResponseComplete)
}

// ReceiveDirectInvoke parses invoke and verifies it against Token message. Uses deadline provided by Token
// Renders BadRequest in case of error
func ReceiveDirectInvoke(w http.ResponseWriter, r *http.Request, token interop.Token) (*interop.Invoke, error) {
w.Header().Set("Trailer", EndOfResponseTrailer)

custHeaders := CustomerHeaders{}
if err := custHeaders.Load(r.Header.Get(CustomerHeadersHeader)); err != nil {
renderBadRequest(w, r, interop.ErrMalformedCustomerHeaders.Error())
return nil, interop.ErrMalformedCustomerHeaders
}

inv := &interop.Invoke{
ID: r.Header.Get(InvokeIDHeader),
ReservationToken: chi.URLParam(r, "reservationtoken"),
InvokedFunctionArn: r.Header.Get(InvokedFunctionArnHeader),
VersionID: r.Header.Get(VersionIDHeader),
ContentType: r.Header.Get(ContentTypeHeader),
CognitoIdentityID: custHeaders.CognitoIdentityID,
CognitoIdentityPoolID: custHeaders.CognitoIdentityPoolID,
TraceID: token.TraceID,
LambdaSegmentID: token.LambdaSegmentID,
ClientContext: custHeaders.ClientContext,
Payload: r.Body,
CorrelationID: "invokeCorrelationID",
DeadlineNs: token.DeadlineNs,
}

if inv.ID != token.InvokeID {
renderBadRequest(w, r, interop.ErrInvalidInvokeID.Error())
return nil, interop.ErrInvalidInvokeID
}

if inv.ReservationToken != token.ReservationToken {
renderBadRequest(w, r, interop.ErrInvalidReservationToken.Error())
return nil, interop.ErrInvalidReservationToken
}

if inv.VersionID != token.VersionID {
renderBadRequest(w, r, interop.ErrInvalidFunctionVersion.Error())
return nil, interop.ErrInvalidFunctionVersion
}

w.Header().Set(VersionIDHeader, token.VersionID)
w.Header().Set(ReservationTokenHeader, token.ReservationToken)
w.Header().Set(InvokeIDHeader, token.InvokeID)

return inv, nil
}

func SendDirectInvokeResponse(additionalHeaders map[string]string, payload io.Reader, w http.ResponseWriter) error {
for k, v := range additionalHeaders {
w.Header().Add(k, v)
}

n, err := io.Copy(w, io.LimitReader(payload, MaxDirectResponseSize+1)) // +1 because we do allow 10MB but not 10MB + 1 byte
if err != nil {
w.Header().Set(EndOfResponseTrailer, EndOfResponseTruncated)
} else if n == MaxDirectResponseSize+1 {
w.Header().Set(EndOfResponseTrailer, EndOfResponseOversized)
} else {
w.Header().Set(EndOfResponseTrailer, EndOfResponseComplete)
}
return err
}
9 changes: 8 additions & 1 deletion lambda/core/registrations.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,13 @@ func (s *registrationServiceImpl) getInternalStateDescription(appCtx appctx.Appl
}

func (s *registrationServiceImpl) CountAgents() int {
s.mutex.Lock()
defer s.mutex.Unlock()

return s.countAgentsUnsafe()
}

func (s *registrationServiceImpl) countAgentsUnsafe() int {
res := 0
s.externalAgents.Visit(func(a *ExternalAgent) {
res++
Expand Down Expand Up @@ -237,7 +244,7 @@ func (s *registrationServiceImpl) CreateInternalAgent(agentName string) (*Intern
return nil, ErrRegistrationServiceOff
}

if s.CountAgents() >= MaxAgentsAllowed {
if s.countAgentsUnsafe() >= MaxAgentsAllowed {
return nil, ErrTooManyExtensions
}

Expand Down
18 changes: 16 additions & 2 deletions lambda/core/statejson/description.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@ import (

// StateDescription ...
type StateDescription struct {
Name string `json:"name"`
LastModified int64 `json:"lastModified"`
Name string `json:"name"`
LastModified int64 `json:"lastModified"`
ResponseTimeNs int64 `json:"responseTimeNs"`
}

// RuntimeDescription ...
Expand All @@ -34,10 +35,23 @@ type InternalStateDescription struct {
FirstFatalError string `json:"firstFatalError"`
}

// ResetDescription describes fields of the response to an INVOKE API request
type ResetDescription struct {
ExtensionsResetMs int64 `json:"extensionsResetMs"`
}

func (s *InternalStateDescription) AsJSON() []byte {
bytes, err := json.Marshal(s)
if err != nil {
log.Panicf("Failed to marshall internal states: %s", err)
}
return bytes
}

func (s *ResetDescription) AsJSON() []byte {
bytes, err := json.Marshal(s)
if err != nil {
log.Panicf("Failed to marshall reset description: %s", err)
}
return bytes
}
13 changes: 11 additions & 2 deletions lambda/core/states.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ type Runtime struct {
currentState RuntimeState
stateLastModified time.Time
Pid int
responseTime time.Time

RuntimeStartedState RuntimeState
RuntimeInitErrorState RuntimeState
Expand Down Expand Up @@ -150,19 +151,27 @@ func (s *Runtime) InitError() error {
func (s *Runtime) ResponseSent() error {
s.ManagedThread.Lock()
defer s.ManagedThread.Unlock()
return s.currentState.ResponseSent()
err := s.currentState.ResponseSent()
if err == nil {
s.responseTime = time.Now()
}
return err
}

// GetRuntimeDescription returns runtime description object for debugging purposes
func (s *Runtime) GetRuntimeDescription() statejson.RuntimeDescription {
s.ManagedThread.Lock()
defer s.ManagedThread.Unlock()
return statejson.RuntimeDescription{
res := statejson.RuntimeDescription{
State: statejson.StateDescription{
Name: s.currentState.Name(),
LastModified: s.stateLastModified.UnixNano() / int64(time.Millisecond),
},
}
if !s.responseTime.IsZero() {
res.State.ResponseTimeNs = s.responseTime.UnixNano()
}
return res
}

// NewRuntime returns new Runtime instance.
Expand Down
1 change: 1 addition & 0 deletions lambda/core/states_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ func TestRuntimeStateTransitionsFromInvocationResponseState(t *testing.T) {
runtime.SetState(runtime.RuntimeInvocationResponseState)
assert.NoError(t, runtime.ResponseSent())
assert.Equal(t, runtime.RuntimeResponseSentState, runtime.GetState())
assert.NotEqual(t, 0, runtime.GetRuntimeDescription().State.ResponseTimeNs)
// InvocationResponse-> InvocationResponse
runtime.SetState(runtime.RuntimeInvocationResponseState)
assert.Equal(t, ErrNotAllowed, runtime.InvocationResponse())
Expand Down
1 change: 1 addition & 0 deletions lambda/fatalerror/fatalerror.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ const (
AgentLaunchError ErrorType = "Extension.LaunchError" // agent could not be launched
RuntimeExit ErrorType = "Runtime.ExitError"
InvalidEntrypoint ErrorType = "Runtime.InvalidEntrypoint"
InvalidWorkingDir ErrorType = "Runtime.InvalidWorkingDir"
InvalidTaskConfig ErrorType = "Runtime.InvalidTaskConfig"
Unknown ErrorType = "Unknown"
)
Loading