Skip to content

Replace GetSymbolTable and derivatives with GetOrInit method #240

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
9 changes: 9 additions & 0 deletions internal/ast/symbol.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,15 @@ type Symbol struct {

type SymbolTable map[string]*Symbol

// GetOrInit returns the symbol table, or initializes it if it's nil.
// This will modify whatever holds the SymbolTable, so is not safe for concurrent use.
func (s *SymbolTable) GetOrInit() SymbolTable {
if *s == nil {
*s = make(SymbolTable)
}
return *s
}

const InternalSymbolNamePrefix = "\xFE" // Invalid UTF8 sequence, will never occur as IdentifierName

const (
Expand Down
19 changes: 0 additions & 19 deletions internal/ast/utilities.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,25 +41,6 @@ func GetSymbolId(symbol *Symbol) SymbolId {
return SymbolId(id)
}

func GetSymbolTable(data *SymbolTable) SymbolTable {
if *data == nil {
*data = make(SymbolTable)
}
return *data
}

func GetMembers(symbol *Symbol) SymbolTable {
return GetSymbolTable(&symbol.Members)
}

func GetExports(symbol *Symbol) SymbolTable {
return GetSymbolTable(&symbol.Exports)
}

func GetLocals(container *Node) SymbolTable {
return GetSymbolTable(&container.LocalsContainerData().Locals)
}

// Determines if a node is missing (either `nil` or empty)
func NodeIsMissing(node *Node) bool {
return node == nil || node.Loc.Pos() == node.Loc.End() && node.Loc.Pos() >= 0 && node.Kind != KindEndOfFile
Expand Down
58 changes: 35 additions & 23 deletions internal/binder/binder.go
Original file line number Diff line number Diff line change
Expand Up @@ -377,13 +377,25 @@ func GetSymbolNameForPrivateIdentifier(containingClassSymbol *ast.Symbol, descri
return ast.InternalSymbolNamePrefix + "#" + strconv.Itoa(int(ast.GetSymbolId(containingClassSymbol))) + "@" + description
}

func getMembers(symbol *ast.Symbol) ast.SymbolTable {
return symbol.Members.GetOrInit()
}

func getExports(symbol *ast.Symbol) ast.SymbolTable {
return symbol.Exports.GetOrInit()
}

func getLocals(container *ast.Node) ast.SymbolTable {
return container.LocalsContainerData().Locals.GetOrInit()
}

func (b *Binder) declareModuleMember(node *ast.Node, symbolFlags ast.SymbolFlags, symbolExcludes ast.SymbolFlags) *ast.Symbol {
hasExportModifier := ast.GetCombinedModifierFlags(node)&ast.ModifierFlagsExport != 0
if symbolFlags&ast.SymbolFlagsAlias != 0 {
if node.Kind == ast.KindExportSpecifier || (node.Kind == ast.KindImportEqualsDeclaration && hasExportModifier) {
return b.declareSymbol(ast.GetExports(b.container.Symbol()), b.container.Symbol(), node, symbolFlags, symbolExcludes)
return b.declareSymbol(getExports(b.container.Symbol()), b.container.Symbol(), node, symbolFlags, symbolExcludes)
}
return b.declareSymbol(ast.GetLocals(b.container), nil /*parent*/, node, symbolFlags, symbolExcludes)
return b.declareSymbol(getLocals(b.container), nil /*parent*/, node, symbolFlags, symbolExcludes)
}
// Exported module members are given 2 symbols: A local symbol that is classified with an ExportValue flag,
// and an associated export symbol with all the correct flags set on it. There are 2 main reasons:
Expand All @@ -402,33 +414,33 @@ func (b *Binder) declareModuleMember(node *ast.Node, symbolFlags ast.SymbolFlags
// and should never be merged directly with other augmentation, and the latter case would be possible if automatic merge is allowed.
if !ast.IsAmbientModule(node) && (hasExportModifier || b.container.Flags&ast.NodeFlagsExportContext != 0) {
if !ast.IsLocalsContainer(b.container) || (ast.HasSyntacticModifier(node, ast.ModifierFlagsDefault) && b.getDeclarationName(node) == ast.InternalSymbolNameMissing) {
return b.declareSymbol(ast.GetExports(b.container.Symbol()), b.container.Symbol(), node, symbolFlags, symbolExcludes)
return b.declareSymbol(getExports(b.container.Symbol()), b.container.Symbol(), node, symbolFlags, symbolExcludes)
// No local symbol for an unnamed default!
}
exportKind := ast.SymbolFlagsNone
if symbolFlags&ast.SymbolFlagsValue != 0 {
exportKind = ast.SymbolFlagsExportValue
}
local := b.declareSymbol(ast.GetLocals(b.container), nil /*parent*/, node, exportKind, symbolExcludes)
local.ExportSymbol = b.declareSymbol(ast.GetExports(b.container.Symbol()), b.container.Symbol(), node, symbolFlags, symbolExcludes)
local := b.declareSymbol(getLocals(b.container), nil /*parent*/, node, exportKind, symbolExcludes)
local.ExportSymbol = b.declareSymbol(getExports(b.container.Symbol()), b.container.Symbol(), node, symbolFlags, symbolExcludes)
node.ExportableData().LocalSymbol = local
return local
}
return b.declareSymbol(ast.GetLocals(b.container), nil /*parent*/, node, symbolFlags, symbolExcludes)
return b.declareSymbol(getLocals(b.container), nil /*parent*/, node, symbolFlags, symbolExcludes)
}

func (b *Binder) declareClassMember(node *ast.Node, symbolFlags ast.SymbolFlags, symbolExcludes ast.SymbolFlags) *ast.Symbol {
if ast.IsStatic(node) {
return b.declareSymbol(ast.GetExports(b.container.Symbol()), b.container.Symbol(), node, symbolFlags, symbolExcludes)
return b.declareSymbol(getExports(b.container.Symbol()), b.container.Symbol(), node, symbolFlags, symbolExcludes)
}
return b.declareSymbol(ast.GetMembers(b.container.Symbol()), b.container.Symbol(), node, symbolFlags, symbolExcludes)
return b.declareSymbol(getMembers(b.container.Symbol()), b.container.Symbol(), node, symbolFlags, symbolExcludes)
}

func (b *Binder) declareSourceFileMember(node *ast.Node, symbolFlags ast.SymbolFlags, symbolExcludes ast.SymbolFlags) *ast.Symbol {
if ast.IsExternalModule(b.file) {
return b.declareModuleMember(node, symbolFlags, symbolExcludes)
}
return b.declareSymbol(ast.GetLocals(b.file.AsNode()), nil /*parent*/, node, symbolFlags, symbolExcludes)
return b.declareSymbol(getLocals(b.file.AsNode()), nil /*parent*/, node, symbolFlags, symbolExcludes)
}

func (b *Binder) declareSymbolAndAddToSymbolTable(node *ast.Node, symbolFlags ast.SymbolFlags, symbolExcludes ast.SymbolFlags) *ast.Symbol {
Expand All @@ -440,14 +452,14 @@ func (b *Binder) declareSymbolAndAddToSymbolTable(node *ast.Node, symbolFlags as
case ast.KindClassExpression, ast.KindClassDeclaration:
return b.declareClassMember(node, symbolFlags, symbolExcludes)
case ast.KindEnumDeclaration:
return b.declareSymbol(ast.GetExports(b.container.Symbol()), b.container.Symbol(), node, symbolFlags, symbolExcludes)
return b.declareSymbol(getExports(b.container.Symbol()), b.container.Symbol(), node, symbolFlags, symbolExcludes)
case ast.KindTypeLiteral, ast.KindJSDocTypeLiteral, ast.KindObjectLiteralExpression, ast.KindInterfaceDeclaration, ast.KindJsxAttributes:
return b.declareSymbol(ast.GetMembers(b.container.Symbol()), b.container.Symbol(), node, symbolFlags, symbolExcludes)
return b.declareSymbol(getMembers(b.container.Symbol()), b.container.Symbol(), node, symbolFlags, symbolExcludes)
case ast.KindFunctionType, ast.KindConstructorType, ast.KindCallSignature, ast.KindConstructSignature, ast.KindJSDocSignature,
ast.KindIndexSignature, ast.KindMethodDeclaration, ast.KindMethodSignature, ast.KindConstructor, ast.KindGetAccessor,
ast.KindSetAccessor, ast.KindFunctionDeclaration, ast.KindFunctionExpression, ast.KindArrowFunction,
ast.KindClassStaticBlockDeclaration, ast.KindTypeAliasDeclaration, ast.KindJSTypeAliasDeclaration, ast.KindMappedType:
return b.declareSymbol(ast.GetLocals(b.container), nil /*parent*/, node, symbolFlags, symbolExcludes)
return b.declareSymbol(getLocals(b.container), nil /*parent*/, node, symbolFlags, symbolExcludes)
}
panic("Unhandled case in declareSymbolAndAddToSymbolTable")
}
Expand Down Expand Up @@ -771,7 +783,7 @@ func (b *Binder) bindSourceFileIfExternalModule() {
b.bindSourceFileAsExternalModule()
// Create symbol equivalent for the module.exports = {}
originalSymbol := b.file.Symbol
b.declareSymbol(ast.GetSymbolTable(&b.file.Symbol.Exports), b.file.Symbol, b.file.AsNode(), ast.SymbolFlagsProperty, ast.SymbolFlagsAll)
b.declareSymbol(b.file.Symbol.Exports.GetOrInit(), b.file.Symbol, b.file.AsNode(), ast.SymbolFlagsProperty, ast.SymbolFlagsAll)
b.file.Symbol = originalSymbol
}
}
Expand Down Expand Up @@ -833,7 +845,7 @@ func (b *Binder) bindNamespaceExportDeclaration(node *ast.Node) {
case !node.Parent.AsSourceFile().IsDeclarationFile:
b.errorOnNode(node, diagnostics.Global_module_exports_may_only_appear_in_declaration_files)
default:
b.declareSymbol(ast.GetSymbolTable(&b.file.Symbol.GlobalExports), b.file.Symbol, node, ast.SymbolFlagsAlias, ast.SymbolFlagsAliasExcludes)
b.declareSymbol(b.file.Symbol.GlobalExports.GetOrInit(), b.file.Symbol, node, ast.SymbolFlagsAlias, ast.SymbolFlagsAliasExcludes)
}
}

Expand All @@ -850,12 +862,12 @@ func (b *Binder) bindExportDeclaration(node *ast.Node) {
b.bindAnonymousDeclaration(node, ast.SymbolFlagsExportStar, b.getDeclarationName(node))
} else if decl.ExportClause == nil {
// All export * declarations are collected in an __export symbol
b.declareSymbol(ast.GetExports(b.container.Symbol()), b.container.Symbol(), node, ast.SymbolFlagsExportStar, ast.SymbolFlagsNone)
b.declareSymbol(getExports(b.container.Symbol()), b.container.Symbol(), node, ast.SymbolFlagsExportStar, ast.SymbolFlagsNone)
} else if ast.IsNamespaceExport(decl.ExportClause) {
// declareSymbol walks up parents to find name text, parent _must_ be set
// but won't be set by the normal binder walk until `bindChildren` later on.
setParent(decl.ExportClause, node)
b.declareSymbol(ast.GetExports(b.container.Symbol()), b.container.Symbol(), decl.ExportClause, ast.SymbolFlagsAlias, ast.SymbolFlagsAliasExcludes)
b.declareSymbol(getExports(b.container.Symbol()), b.container.Symbol(), decl.ExportClause, ast.SymbolFlagsAlias, ast.SymbolFlagsAliasExcludes)
}
}

Expand All @@ -870,7 +882,7 @@ func (b *Binder) bindExportAssignment(node *ast.Node) {
}
// If there is an `export default x;` alias declaration, can't `export default` anything else.
// (In contrast, you can still have `export default function f() {}` and `export default interface I {}`.)
symbol := b.declareSymbol(ast.GetExports(b.container.Symbol()), b.container.Symbol(), node, flags, ast.SymbolFlagsAll)
symbol := b.declareSymbol(getExports(b.container.Symbol()), b.container.Symbol(), node, flags, ast.SymbolFlagsAll)
if node.AsExportAssignment().IsExportEquals {
// Will be an error later, since the module already has other exports. Just make sure this has a valueDeclaration set.
SetValueDeclaration(symbol, node)
Expand Down Expand Up @@ -946,12 +958,12 @@ func (b *Binder) bindClassLikeDeclaration(node *ast.Node) {
// module might have an exported variable called 'prototype'. We can't allow that as
// that would clash with the built-in 'prototype' for the class.
prototypeSymbol := b.newSymbol(ast.SymbolFlagsProperty|ast.SymbolFlagsPrototype, "prototype")
symbolExport := ast.GetExports(symbol)[prototypeSymbol.Name]
symbolExport := getExports(symbol)[prototypeSymbol.Name]
if symbolExport != nil {
setParent(name, node)
b.errorOnNode(symbolExport.Declarations[0], diagnostics.Duplicate_identifier_0, ast.SymbolName(prototypeSymbol))
}
ast.GetExports(symbol)[prototypeSymbol.Name] = prototypeSymbol
getExports(symbol)[prototypeSymbol.Name] = prototypeSymbol
prototypeSymbol.Parent = symbol
}

Expand Down Expand Up @@ -1014,7 +1026,7 @@ func (b *Binder) bindFunctionPropertyAssignment(node *ast.Node) {
b.bindAnonymousDeclaration(node, ast.SymbolFlagsProperty|ast.SymbolFlagsAssignment, ast.InternalSymbolNameComputed)
addLateBoundAssignmentDeclarationToSymbol(node, funcSymbol)
} else {
b.declareSymbol(ast.GetExports(funcSymbol), funcSymbol, node, ast.SymbolFlagsProperty|ast.SymbolFlagsAssignment, ast.SymbolFlagsPropertyExcludes)
b.declareSymbol(getExports(funcSymbol), funcSymbol, node, ast.SymbolFlagsProperty|ast.SymbolFlagsAssignment, ast.SymbolFlagsPropertyExcludes)
}
}
}
Expand Down Expand Up @@ -1075,7 +1087,7 @@ func (b *Binder) bindParameter(node *ast.Node) {
if ast.IsParameterPropertyDeclaration(node, node.Parent) {
classDeclaration := node.Parent.Parent
flags := ast.SymbolFlagsProperty | core.IfElse(decl.QuestionToken != nil, ast.SymbolFlagsOptional, ast.SymbolFlagsNone)
b.declareSymbol(ast.GetMembers(classDeclaration.Symbol()), classDeclaration.Symbol(), node, flags, ast.SymbolFlagsPropertyExcludes)
b.declareSymbol(getMembers(classDeclaration.Symbol()), classDeclaration.Symbol(), node, flags, ast.SymbolFlagsPropertyExcludes)
}
}

Expand Down Expand Up @@ -1119,7 +1131,7 @@ func (b *Binder) bindBlockScopedDeclaration(node *ast.Node, symbolFlags ast.Symb
}
fallthrough
default:
b.declareSymbol(ast.GetLocals(b.blockScopeContainer), nil /*parent*/, node, symbolFlags, symbolExcludes)
b.declareSymbol(getLocals(b.blockScopeContainer), nil /*parent*/, node, symbolFlags, symbolExcludes)
}
}

Expand All @@ -1138,7 +1150,7 @@ func (b *Binder) bindTypeParameter(node *ast.Node) {
if node.Parent.Kind == ast.KindInferType {
container := b.getInferTypeContainer(node.Parent)
if container != nil {
b.declareSymbol(ast.GetLocals(container), nil /*parent*/, node, ast.SymbolFlagsTypeParameter, ast.SymbolFlagsTypeParameterExcludes)
b.declareSymbol(getLocals(container), nil /*parent*/, node, ast.SymbolFlagsTypeParameter, ast.SymbolFlagsTypeParameterExcludes)
} else {
b.bindAnonymousDeclaration(node, ast.SymbolFlagsTypeParameter, b.getDeclarationName(node))
}
Expand Down
14 changes: 7 additions & 7 deletions internal/checker/checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -1305,7 +1305,7 @@ func (c *Checker) mergeModuleAugmentation(moduleName *ast.Node) {
}) {
merged := c.mergeSymbol(moduleAugmentation.Symbol, mainModule, true /*unidirectional*/)
// moduleName will be a StringLiteral since this is not `declare global`.
ast.GetSymbolTable(&c.patternAmbientModuleAugmentations)[moduleName.Text()] = merged
c.patternAmbientModuleAugmentations.GetOrInit()[moduleName.Text()] = merged
} else {
if mainModule.Exports[ast.InternalSymbolNameExportStar] != nil && len(moduleAugmentation.Symbol.Exports) != 0 {
// We may need to merge the module augmentation's exports into the target symbols of the resolved exports
Expand Down Expand Up @@ -13257,10 +13257,10 @@ func (c *Checker) mergeSymbol(target *ast.Symbol, source *ast.Symbol, unidirecti
}
target.Declarations = append(target.Declarations, source.Declarations...)
if source.Members != nil {
c.mergeSymbolTable(ast.GetSymbolTable(&target.Members), source.Members, unidirectional, nil)
c.mergeSymbolTable(target.Members.GetOrInit(), source.Members, unidirectional, nil)
}
if source.Exports != nil {
c.mergeSymbolTable(ast.GetSymbolTable(&target.Exports), source.Exports, unidirectional, target)
c.mergeSymbolTable(target.Exports.GetOrInit(), source.Exports, unidirectional, target)
}
if !unidirectional {
c.recordMergedSymbol(target, source)
Expand Down Expand Up @@ -14284,7 +14284,7 @@ func (c *Checker) getCommonJSExportEquals(exported *ast.Symbol, moduleSymbol *as
merged = c.cloneSymbol(exported)
}
merged.Flags |= ast.SymbolFlagsValueModule
mergedExports := ast.GetExports(merged)
mergedExports := merged.Exports.GetOrInit()
for name, s := range moduleSymbol.Exports {
if name != ast.InternalSymbolNameExportEquals {
if existing, ok := mergedExports[name]; ok {
Expand Down Expand Up @@ -19752,9 +19752,9 @@ func (c *Checker) getPropertyOfUnionOrIntersectionType(t *Type, name string, ski
func (c *Checker) getUnionOrIntersectionProperty(t *Type, name string, skipObjectFunctionPropertyAugment bool) *ast.Symbol {
var cache ast.SymbolTable
if skipObjectFunctionPropertyAugment {
cache = ast.GetSymbolTable(&t.AsUnionOrIntersectionType().propertyCacheWithoutFunctionPropertyAugment)
cache = t.AsUnionOrIntersectionType().propertyCacheWithoutFunctionPropertyAugment.GetOrInit()
} else {
cache = ast.GetSymbolTable(&t.AsUnionOrIntersectionType().propertyCache)
cache = t.AsUnionOrIntersectionType().propertyCache.GetOrInit()
}
if prop := cache[name]; prop != nil {
return prop
Expand All @@ -19764,7 +19764,7 @@ func (c *Checker) getUnionOrIntersectionProperty(t *Type, name string, skipObjec
cache[name] = prop
// Propagate an entry from the non-augmented cache to the augmented cache unless the property is partial.
if skipObjectFunctionPropertyAugment && prop.CheckFlags&ast.CheckFlagsPartial == 0 {
augmentedCache := ast.GetSymbolTable(&t.AsUnionOrIntersectionType().propertyCache)
augmentedCache := t.AsUnionOrIntersectionType().propertyCache.GetOrInit()
if augmentedCache[name] == nil {
augmentedCache[name] = prop
}
Expand Down