diff --git a/di.go b/di.go index cd3772b..009a6d4 100644 --- a/di.go +++ b/di.go @@ -58,6 +58,7 @@ const ( var initializeShutdownLock sync.Mutex var createInstanceLock sync.Mutex +var configureInstanceLock sync.RWMutex var containerInitialized int32 var beans = make(map[string]reflect.Type) var beanFactories = make(map[string]func(context.Context) (interface{}, error)) @@ -65,6 +66,8 @@ var scopes = make(map[string]Scope) var singletonInstances = make(map[string]interface{}) var userCreatedInstances = make(map[string]bool) var beanPostprocessors = make(map[reflect.Type][]func(bean interface{}) error) +var configurations = make(map[reflect.Type]interface{}) +var configurationTypeCache = make(map[reflect.Type]reflect.Type) // InitializingBean is an interface marking beans that need to be additionally initialized after the container is ready. type InitializingBean interface { @@ -95,6 +98,20 @@ func RegisterBeanPostprocessor(beanType reflect.Type, postprocessor func(bean in return nil } +func RegisterBeanConfiguration [T interface{}](configuration T) error { + initializeShutdownLock.Lock() + defer initializeShutdownLock.Unlock() + if atomic.CompareAndSwapInt32(&containerInitialized, 1, 1) { + return errors.New("container is already initialized: can't register bean configuration") + } + confType := reflect.TypeOf(configuration) + if _, contains := configurations[confType]; contains { + return errors.New("configuration of this type is already registered") + } + configurations[confType] = configuration + return nil +} + // InitializeContainer function initializes the IoC container. func InitializeContainer() error { initializeShutdownLock.Lock() @@ -259,9 +276,17 @@ func injectSingletonDependencies() error { func injectDependencies(beanID string, instance interface{}, chain map[string]bool) error { logrus.WithField("beanID", beanID).Trace("injecting dependencies") instanceType := beans[beanID] - instanceElement := instanceType.Elem() + return injectDependenciesWithType(instanceType.Elem(), beanID, instance, chain) +} + +func injectDependenciesWithType(instanceElement reflect.Type, beanID string, instance interface{}, chain map[string]bool) error { for i := 0; i < instanceElement.NumField(); i++ { field := instanceElement.Field(i) + if field.Type.Kind() == reflect.Struct && field.Anonymous { + fieldToInject := reflect.ValueOf(instance).Elem().Field(i) + injectDependenciesWithType(field.Type, beanID, fieldToInject.Addr().Interface(), chain) + continue + } beanToInject, ok := field.Tag.Lookup(string(inject)) if !ok { continue @@ -457,14 +482,56 @@ func initializeSingletonInstances() error { return nil } +func applyConfiguration(beanType reflect.Type, instance interface{}) error { + configureInstanceLock.RLock() + configType, ok := configurationTypeCache[beanType] + configureInstanceLock.RUnlock() + + if ok { + if configType != nil { + method := reflect.ValueOf(instance).MethodByName("Configure") + config := configurations[configType] + val:=method.Call([]reflect.Value{reflect.ValueOf(config)}) + if val[0].Interface() != nil { + return val[0].Interface().(error) + } + } + return nil + } else { + configureInstanceLock.Lock() + defer configureInstanceLock.Unlock() + configurationTypeCache[beanType] = nil + for configType, config := range configurations { + method := reflect.ValueOf(instance).MethodByName("Configure") + if method.IsValid() && method.Type().NumIn() == 1 { + if configType.AssignableTo(method.Type().In(0)) { + configurationTypeCache[beanType] = configType + val:=method.Call([]reflect.Value{reflect.ValueOf(config)}) + if val[0].Interface() != nil { + return val[0].Interface().(error) + } + break + } + } + } + return nil + } +} + func initializeInstance(beanID string, instance interface{}) error { + bean := reflect.TypeOf(instance) + + // Configure first, then PostConstruct + if err:= applyConfiguration(bean, instance); err != nil { + return err + } + if impl, ok := instance.(InitializingBean); ok { logrus.WithField("beanID", beanID).Trace("initializing bean") if err := impl.PostConstruct(); err != nil { return err } } - bean := reflect.TypeOf(instance) if postprocessors, ok := beanPostprocessors[bean]; ok { logrus.WithField("beanID", beanID).Trace("postprocessing bean") for _, postprocessor := range postprocessors { @@ -473,6 +540,7 @@ func initializeInstance(beanID string, instance interface{}) error { } } } + return nil } diff --git a/di_test.go b/di_test.go index 723266f..7d8b851 100644 --- a/di_test.go +++ b/di_test.go @@ -1166,3 +1166,102 @@ func (suite *TestSuite) TestShutdownContinueOnError() { assert.Equal(suite.T(), 5, len(closedSingletons)) assert.Equal(suite.T(), 5, len(singletonBeansWithErrorOnClose)) } + +func (suite *TestSuite) TestInjectInParent() { + type SingletonBeanParent struct { + otherBean1 someInterface `di.inject:""` + } + type SingletonBeanChild struct { + SingletonBeanParent + otherBean2 someInterface `di.inject:""` + } + + overwritten, err := RegisterBean("singletonBean", reflect.TypeOf((*SingletonBeanChild)(nil))) + assert.False(suite.T(), overwritten) + assert.NoError(suite.T(), err) + overwritten, err = RegisterBean("otherBean", reflect.TypeOf((*otherBean)(nil))) + assert.False(suite.T(), overwritten) + assert.NoError(suite.T(), err) + err = InitializeContainer() + assert.NoError(suite.T(), err) + instance, err := GetInstanceSafe("singletonBean") + assert.NoError(suite.T(), err) + assert.NotNil(suite.T(), instance.(*SingletonBeanChild).otherBean1) + assert.NotNil(suite.T(), instance.(*SingletonBeanChild).otherBean2) + _, ok := instance.(*SingletonBeanChild).otherBean1.(*otherBean) + assert.True(suite.T(), ok) + assert.EqualValues(suite.T(), instance.(*SingletonBeanChild).otherBean1, instance.(*SingletonBeanChild).otherBean2) +} + +type SingletonConfiguringInnerBean struct { + value int +} +type SingletonInnerBean struct { +} + +type SingletonConfiguredBean struct { + innerConfigBean *SingletonConfiguringInnerBean `di.inject:""` + innerBean *SingletonInnerBean `di.inject:""` + value int + postConstructCalled bool +} + +type Config1 struct { + Conf int +} + +type Config2 struct { + Conf int +} + +func (bean *SingletonConfiguredBean) Configure(conf *Config1) error { + bean.value = conf.Conf + return nil +} + +func (bean *SingletonConfiguredBean) PostConstruct() error { + bean.postConstructCalled = true + if bean.value == 0 { + return errors.New("not initialized") + } + return nil +} + +func (bean *SingletonConfiguringInnerBean) Configure(conf *Config2) error { + bean.value = conf.Conf + return nil +} + +func (suite *TestSuite) TestInjectConfiguration() { + overwritten, err := RegisterBean("singletonBean1", reflect.TypeOf((*SingletonConfiguredBean)(nil))) + assert.False(suite.T(), overwritten) + assert.NoError(suite.T(), err) + overwritten, err = RegisterBean("singletonBean2", reflect.TypeOf((*SingletonConfiguredBean)(nil))) + assert.False(suite.T(), overwritten) + assert.NoError(suite.T(), err) + overwritten, err = RegisterBean("innerConfigBean", reflect.TypeOf((*SingletonConfiguringInnerBean)(nil))) + assert.False(suite.T(), overwritten) + assert.NoError(suite.T(), err) + overwritten, err = RegisterBean("innerBean", reflect.TypeOf((*SingletonInnerBean)(nil))) + assert.False(suite.T(), overwritten) + assert.NoError(suite.T(), err) + RegisterBeanConfiguration(&Config1{Conf: 11}) + RegisterBeanConfiguration(&Config2{Conf: 22}) + err = InitializeContainer() + assert.NoError(suite.T(), err) + instance, err := GetInstanceSafe("singletonBean1") + assert.NoError(suite.T(), err) + assert.NotNil(suite.T(), instance.(*SingletonConfiguredBean).innerConfigBean) + assert.NotNil(suite.T(), instance.(*SingletonConfiguredBean).innerBean) + assert.Equal(suite.T(), 11, instance.(*SingletonConfiguredBean).value) + assert.Equal(suite.T(), 22, instance.(*SingletonConfiguredBean).innerConfigBean.value) + assert.Equal(suite.T(), true, instance.(*SingletonConfiguredBean).postConstructCalled) + + instance, err = GetInstanceSafe("singletonBean2") + assert.NoError(suite.T(), err) + assert.NotNil(suite.T(), instance.(*SingletonConfiguredBean).innerConfigBean) + assert.NotNil(suite.T(), instance.(*SingletonConfiguredBean).innerBean) + assert.Equal(suite.T(), 11, instance.(*SingletonConfiguredBean).value) + assert.Equal(suite.T(), 22, instance.(*SingletonConfiguredBean).innerConfigBean.value) + assert.Equal(suite.T(), true, instance.(*SingletonConfiguredBean).postConstructCalled) +} \ No newline at end of file