Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions core/functions/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ type Function struct {
ReturnType id.Type
ParameterNames []string
ParameterTypes []id.Type
ParameterDefaults []string
Variadic bool
IsNonDeterministic bool
Strict bool
Expand Down
11 changes: 8 additions & 3 deletions core/functions/serialization.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func (function Function) Serialize(ctx context.Context) ([]byte, error) {

// Write all of the functions to the writer
writer := utils.NewWriter(256)
writer.VariableUint(2) // Version
writer.VariableUint(3) // Version
// Write the function data
writer.Id(function.ID.AsId())
writer.Id(function.ReturnType.AsId())
Expand All @@ -58,6 +58,8 @@ func (function Function) Serialize(ctx context.Context) ([]byte, error) {
// Write version 2 data
writer.String(function.SQLDefinition)
writer.Bool(function.SetOf)
// Write version 3 data
writer.StringSlice(function.ParameterDefaults)
// Returns the data
return writer.Data(), nil
}
Expand All @@ -70,7 +72,7 @@ func DeserializeFunction(ctx context.Context, data []byte) (Function, error) {
}
reader := utils.NewReader(data)
version := reader.VariableUint()
if version > 2 {
if version > 3 {
return Function{}, errors.Errorf("version %d of functions is not supported, please upgrade the server", version)
}

Expand Down Expand Up @@ -101,10 +103,13 @@ func DeserializeFunction(ctx context.Context, data []byte) (Function, error) {
f.ExtensionName = reader.String()
f.ExtensionSymbol = reader.String()
}
if version == 2 {
if version >= 2 {
f.SQLDefinition = reader.String()
f.SetOf = reader.Bool()
}
if version >= 3 {
f.ParameterDefaults = reader.StringSlice()
}
if !reader.IsEmpty() {
return Function{}, errors.Errorf("extra data found while deserializing a function")
}
Expand Down
19 changes: 10 additions & 9 deletions core/procedures/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,16 @@ type Collection struct {

// Procedure represents a created procedure.
type Procedure struct {
ID id.Procedure
ParameterNames []string
ParameterTypes []id.Type
ParameterModes []ParameterMode
Definition string
ExtensionName string // Only used when this is an extension procedure
ExtensionSymbol string // Only used when this is an extension procedure
Operations []plpgsql.InterpreterOperation // Only used when this is a plpgsql language
SQLDefinition string // Only used when this is a sql language
ID id.Procedure
ParameterNames []string
ParameterTypes []id.Type
ParameterModes []ParameterMode
ParameterDefaults []string
Definition string
ExtensionName string // Only used when this is an extension procedure
ExtensionSymbol string // Only used when this is an extension procedure
Operations []plpgsql.InterpreterOperation // Only used when this is a plpgsql language
SQLDefinition string // Only used when this is a sql language
}

var _ objinterface.Collection = (*Collection)(nil)
Expand Down
11 changes: 8 additions & 3 deletions core/procedures/serialization.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func (procedure Procedure) Serialize(ctx context.Context) ([]byte, error) {

// Write all of the procedures to the writer
writer := utils.NewWriter(256)
writer.VariableUint(0) // Version
writer.VariableUint(1) // Version
// Write the procedure data
writer.Id(procedure.ID.AsId())
writer.StringSlice(procedure.ParameterNames)
Expand All @@ -56,6 +56,8 @@ func (procedure Procedure) Serialize(ctx context.Context) ([]byte, error) {
writer.Int32(int32(op.Index))
writer.StringMap(op.Options)
}
// Write version 1 data
writer.StringSlice(procedure.ParameterDefaults)
// Returns the data
return writer.Data(), nil
}
Expand All @@ -68,8 +70,8 @@ func DeserializeProcedure(ctx context.Context, data []byte) (Procedure, error) {
}
reader := utils.NewReader(data)
version := reader.VariableUint()
if version > 0 {
return Procedure{}, errors.Errorf("version %d of functions is not supported, please upgrade the server", version)
if version > 1 {
return Procedure{}, errors.Errorf("version %d of procedures is not supported, please upgrade the server", version)
}

// Read from the reader
Expand Down Expand Up @@ -100,6 +102,9 @@ func DeserializeProcedure(ctx context.Context, data []byte) (Procedure, error) {
op.Options = reader.StringMap()
p.Operations[opIdx] = op
}
if version >= 1 {
p.ParameterDefaults = reader.StringSlice()
}
if !reader.IsEmpty() {
return Procedure{}, errors.New("extra data found while deserializing a procedure")
}
Expand Down
8 changes: 5 additions & 3 deletions server/analyzer/create_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,11 @@ func ValidateCreateFunction(ctx *sql.Context, a *analyzer.Analyzer, n sql.Node,
}

builder := planbuilder.New(ctx, a.Catalog, nil)
_, _, err = builder.BindOnly(ct.SqlDefParsed, ct.SqlDef, nil)
if err != nil {
return nil, transform.SameTree, err
for _, parsed := range ct.SqlDefParsedStmts {
_, _, err = builder.BindOnly(parsed, ct.SqlDef, nil)
if err != nil {
return nil, transform.SameTree, err
}
}

return n, transform.SameTree, nil
Expand Down
6 changes: 4 additions & 2 deletions server/analyzer/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ const (
ruleId_ValidateCreateTable // validateCreateTable
ruleId_ValidateCreateSchema // validateCreateSchema
ruleId_ResolveAlterColumn // resolveAlterColumn
ruleId_ValidateCreateFunction
ruleId_ResolveValuesTypes // resolveValuesTypes
ruleId_ValidateCreateFunction // validateCreateFunction
ruleId_ResolveValuesTypes // resolveValuesTypes
ruleId_ResolveProcedureDefaults // resolveProcedureDefaults
)

// Init adds additional rules to the analyzer to handle Doltgres-specific functionality.
Expand All @@ -66,6 +67,7 @@ func Init() {
analyzer.Rule{Id: ruleId_AssignTriggers, Apply: AssignTriggers},
analyzer.Rule{Id: ruleId_ValidateCreateFunction, Apply: ValidateCreateFunction},
analyzer.Rule{Id: ruleId_ValidateCreateSchema, Apply: ValidateCreateSchema},
analyzer.Rule{Id: ruleId_ResolveProcedureDefaults, Apply: ResolveProcedureDefaults},
)

analyzer.OnceBeforeDefault = append([]analyzer.Rule{
Expand Down
32 changes: 32 additions & 0 deletions server/analyzer/optimize_functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,14 @@
package analyzer

import (
"fmt"

"github.com/cockroachdb/errors"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/analyzer"
"github.com/dolthub/go-mysql-server/sql/expression"
"github.com/dolthub/go-mysql-server/sql/plan"
"github.com/dolthub/go-mysql-server/sql/planbuilder"
"github.com/dolthub/go-mysql-server/sql/transform"

"github.com/dolthub/doltgresql/server/functions/framework"
Expand Down Expand Up @@ -58,6 +62,13 @@ func OptimizeFunctions(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, sc
if quickFunction := compiledFunction.GetQuickFunction(); quickFunction != nil {
return quickFunction, transform.NewTree, nil
}

// fill in default exprs if applicable
if err := compiledFunction.ResolveDefaultValues(func(defExpr string) (sql.Expression, error) {
return getDefaultExpr(ctx, a.Catalog, defExpr)
}); err != nil {
return nil, transform.SameTree, err
}
}
if v, ok := in.(*plan.Values); ok {
hasMultipleExpressionTuples = len(v.ExpressionTuples) > 1
Expand Down Expand Up @@ -92,6 +103,13 @@ func OptimizeFunctions(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, sc
return nil, transform.SameTree, err
}
}

// fill in default exprs if applicablea
if err = compiledFunction.ResolveDefaultValues(func(defExpr string) (sql.Expression, error) {
return getDefaultExpr(ctx, a.Catalog, defExpr)
}); err != nil {
return nil, transform.SameTree, err
}
}
return expr, transform.SameTree, nil
})
Expand All @@ -113,3 +131,17 @@ func OptimizeFunctions(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, sc
return projectNode, sameNode && sameExprs, err
})
}

// getDefaultExpr takes the default value definition, parses, builds and returns sql.ColumnDefaultValue.
func getDefaultExpr(ctx *sql.Context, c sql.Catalog, defExpr string) (sql.Expression, error) {
builder := planbuilder.New(ctx, c, nil)
proj, _, _, _, err := builder.Parse(fmt.Sprintf("select %s", defExpr), nil, false)
if err != nil {
return nil, err
}
parsedExpr := proj.(*plan.Project).Projections[0]
if a, ok := parsedExpr.(*expression.Alias); ok {
parsedExpr = a.Child
}
return parsedExpr, nil
}
130 changes: 130 additions & 0 deletions server/analyzer/resolve_routine_defaults.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
// Copyright 2026 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package analyzer

import (
"strings"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/analyzer"
"github.com/dolthub/go-mysql-server/sql/plan"
"github.com/dolthub/go-mysql-server/sql/transform"

"github.com/dolthub/doltgresql/core"
"github.com/dolthub/doltgresql/core/extensions"
"github.com/dolthub/doltgresql/core/id"
"github.com/dolthub/doltgresql/server/functions"
"github.com/dolthub/doltgresql/server/functions/framework"
pgnodes "github.com/dolthub/doltgresql/server/node"
pgtypes "github.com/dolthub/doltgresql/server/types"
)

// ResolveProcedureDefaults resolves default expressions of routines that are in string format by parsing it into sql.Expression.
// This function retrieves the procedure overloads and sets CompiledFunction in the Call node.
func ResolveProcedureDefaults(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, scope *plan.Scope, selector analyzer.RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) {
switch n := node.(type) {
case *pgnodes.Call:
procCollection, err := core.GetProceduresCollectionFromContext(ctx)
if err != nil {
return nil, transform.SameTree, err
}
typesCollection, err := core.GetTypesCollectionFromContext(ctx)
if err != nil {
return nil, transform.SameTree, err
}
schemaName, err := core.GetSchemaName(ctx, nil, n.SchemaName)
if err != nil {
return nil, transform.SameTree, err
}
procName := id.NewProcedure(schemaName, n.ProcedureName)
overloads, err := procCollection.GetProcedureOverloads(ctx, procName)
if err != nil {
return nil, transform.SameTree, err
}
if len(overloads) == 0 {
if strings.HasPrefix(n.ProcedureName, "dolt_") {
return nil, transform.SameTree, functions.ErrDoltProcedureSelectOnly
}
return nil, transform.SameTree, sql.ErrStoredProcedureDoesNotExist.New(n.ProcedureName)
}

same := transform.SameTree
overloadTree := framework.NewOverloads()
for _, overload := range overloads {
paramTypes := make([]*pgtypes.DoltgresType, len(overload.ParameterTypes))
for i, paramType := range overload.ParameterTypes {
paramTypes[i], err = typesCollection.GetType(ctx, paramType)
if err != nil || paramTypes[i] == nil {
return nil, transform.SameTree, err
}
}
// TODO: we should probably have procedure equivalents instead of converting these to functions
// probably fine for now since we don't implement/support the differing functionality between the two just yet
if len(overload.ExtensionName) > 0 {
if err = overloadTree.Add(framework.CFunction{
ID: id.Function(overload.ID),
ReturnType: pgtypes.Void,
ParameterTypes: paramTypes,
Variadic: false,
IsNonDeterministic: true,
Strict: false,
ExtensionName: extensions.LibraryIdentifier(overload.ExtensionName),
ExtensionSymbol: overload.ExtensionSymbol,
}); err != nil {
return nil, transform.SameTree, err
}
} else if len(overload.SQLDefinition) > 0 {
if err = overloadTree.Add(framework.SQLFunction{
ID: id.Function(overload.ID),
ReturnType: pgtypes.Void,
ParameterNames: overload.ParameterNames,
ParameterTypes: paramTypes,
ParameterDefaults: overload.ParameterDefaults,
Variadic: false,
IsNonDeterministic: true,
Strict: false,
SqlStatement: overload.SQLDefinition,
SetOf: false,
}); err != nil {
return nil, transform.SameTree, err
}
} else {
if err = overloadTree.Add(framework.InterpretedFunction{
ID: id.Function(overload.ID),
ReturnType: pgtypes.Void,
ParameterNames: overload.ParameterNames,
ParameterTypes: paramTypes,
Variadic: false,
IsNonDeterministic: true,
Strict: false,
Statements: overload.Operations,
}); err != nil {
return nil, transform.SameTree, err
}
}
}
compiledFunction := framework.NewCompiledFunction(n.ProcedureName, n.Exprs, overloadTree, false)
// fill in default exprs if applicable
if err := compiledFunction.ResolveDefaultValues(func(defExpr string) (sql.Expression, error) {
return getDefaultExpr(ctx, a.Catalog, defExpr)
}); err != nil {
return nil, transform.SameTree, err
}
n.CompiledFunc = compiledFunction
return node, same, nil
default:
return node, transform.SameTree, nil
}
}
12 changes: 4 additions & 8 deletions server/analyzer/resolve_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,26 +70,22 @@ func ResolveTypeForNodes(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node,
if err != nil {
return nil, transform.NewTree, err
}
paramTypes := make([]*pgtypes.DoltgresType, len(n.ParameterTypes))
for i := range n.ParameterTypes {
paramTypes[i], err = resolveType(ctx, db, n.ParameterTypes[i])
for i := range n.Parameters {
n.Parameters[i].Type, err = resolveType(ctx, db, n.Parameters[i].Type)
if err != nil {
return nil, transform.NewTree, err
}
}
n.ReturnType = retType
n.ParameterTypes = paramTypes
return node, transform.NewTree, nil
case *pgnodes.CreateProcedure:
paramTypes := make([]*pgtypes.DoltgresType, len(n.ParameterTypes))
for i := range n.ParameterTypes {
for i := range n.Parameters {
var err error
paramTypes[i], err = resolveType(ctx, db, n.ParameterTypes[i])
n.Parameters[i].Type, err = resolveType(ctx, db, n.Parameters[i].Type)
if err != nil {
return nil, transform.NewTree, err
}
}
n.ParameterTypes = paramTypes
return node, transform.NewTree, nil
case *plan.CreateTable:
for _, col := range n.TargetSchema() {
Expand Down
2 changes: 1 addition & 1 deletion server/analyzer/validate_column_defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import (
pgnode "github.com/dolthub/doltgresql/server/node"
)

// validateColumnDefaults ensures that newly created column defaults from a DDL statement are legal for the type of
// ValidateColumnDefaults ensures that newly created column defaults from a DDL statement are legal for the type of
// column, various other business logic checks to match MySQL's logic.
func ValidateColumnDefaults(ctx *sql.Context, _ *analyzer.Analyzer, n sql.Node, _ *plan.Scope, _ analyzer.RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) {
span, ctx := ctx.Span("validateColumnDefaults")
Expand Down
Loading
Loading