Skip to content
2 changes: 1 addition & 1 deletion packages/wallet/core/src/signers/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ export interface Signer {

export interface SapientSigner {
readonly address: MaybePromise<Address.Address>
readonly imageHash: MaybePromise<Hex.Hex | undefined>
readonly imageHash: MaybePromise<Hex.Hex>

signSapient: (
wallet: Address.Address,
Expand Down
84 changes: 57 additions & 27 deletions packages/wallet/core/src/signers/session-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,15 @@ export class SessionManager implements SapientSigner {
this._provider = options.provider
}

get imageHash(): Promise<Hex.Hex | undefined> {
get imageHash(): Promise<Hex.Hex> {
return this.getImageHash()
}

async getImageHash(): Promise<Hex.Hex | undefined> {
async getImageHash(): Promise<Hex.Hex> {
const { configuration } = await this.wallet.getStatus()
const sessionConfigLeaf = Config.findSignerLeaf(configuration, this.address)
if (!sessionConfigLeaf || !Config.isSapientSignerLeaf(sessionConfigLeaf)) {
return undefined
throw new Error(`Session configuration not found for wallet ${this.wallet.address}`)
}
return sessionConfigLeaf.imageHash
}
Expand All @@ -72,6 +72,10 @@ export class SessionManager implements SapientSigner {
if (!imageHash) {
throw new Error(`Session configuration not found for image hash ${imageHash}`)
}
return this._getTopologyForImageHash(imageHash)
}

private async _getTopologyForImageHash(imageHash: Hex.Hex): Promise<SessionConfig.SessionsTopology> {
const tree = await this.stateProvider.getTree(imageHash)
if (!tree) {
throw new Error(`Session configuration not found for image hash ${imageHash}`)
Expand Down Expand Up @@ -131,8 +135,21 @@ export class SessionManager implements SapientSigner {
}

async findSignersForCalls(wallet: Address.Address, chainId: number, calls: Payload.Call[]): Promise<SessionSigner[]> {
if (!Address.isEqual(this.wallet.address, wallet)) {
throw new Error('Wallet address mismatch')
}
// Only use signers that match the topology
const topology = await this.topology
return this._findSignersForCalls(wallet, chainId, calls, topology)
}

private async _findSignersForCalls(
wallet: Address.Address,
chainId: number,
calls: Payload.Call[],
topology: SessionConfig.SessionsTopology,
): Promise<SessionSigner[]> {
// Only use signers that match the topology
const identitySigners = SessionConfig.getIdentitySigners(topology)
if (identitySigners.length === 0) {
throw new Error('Identity signers not found')
Expand Down Expand Up @@ -173,11 +190,24 @@ export class SessionManager implements SapientSigner {
wallet: Address.Address,
chainId: number,
calls: Payload.Call[],
): Promise<Payload.Call | null> {
if (!Address.isEqual(wallet, this.wallet.address)) {
throw new Error('Wallet address mismatch')
}
const topology = await this.topology
return this._prepareIncrement(wallet, chainId, calls, topology)
}

private async _prepareIncrement(
wallet: Address.Address,
chainId: number,
calls: Payload.Call[],
topology: SessionConfig.SessionsTopology,
): Promise<Payload.Call | null> {
if (calls.length === 0) {
throw new Error('No calls provided')
}
const signers = await this.findSignersForCalls(wallet, chainId, calls)
const signers = await this._findSignersForCalls(wallet, chainId, calls, topology)

// Create a map of signers to their associated calls
const signerToCalls = new Map<SessionSigner, Payload.Call[]>()
Expand Down Expand Up @@ -233,18 +263,18 @@ export class SessionManager implements SapientSigner {
if (!Address.isEqual(wallet, this.wallet.address)) {
throw new Error('Wallet address mismatch')
}
if (this._provider) {
const providerChainId = await this._provider.request({
method: 'eth_chainId',
})
if (providerChainId !== Hex.fromNumber(chainId)) {
throw new Error(`Provider chain id mismatch, expected ${Hex.fromNumber(chainId)} but got ${providerChainId}`)
}
}
if ((await this.imageHash) !== imageHash) {
throw new Error('Unexpected image hash')
}
//FIXME Test chain id
// if (this._provider) {
// const providerChainId = await this._provider.request({
// method: 'eth_chainId',
// })
// if (providerChainId !== Hex.fromNumber(chainId)) {
// throw new Error(`Provider chain id mismatch, expected ${Hex.fromNumber(chainId)} but got ${providerChainId}`)
// }
// }
const topology = await this._getTopologyForImageHash(imageHash)
if (!Payload.isCalls(payload) || payload.calls.length === 0) {
throw new Error('Only calls are supported')
}
Expand All @@ -254,7 +284,7 @@ export class SessionManager implements SapientSigner {
throw new Error(`Space ${payload.space} is too large`)
}

const signers = await this.findSignersForCalls(wallet, chainId, payload.calls)
const signers = await this._findSignersForCalls(wallet, chainId, payload.calls, topology)
if (signers.length !== payload.calls.length) {
throw new Error('No signer supported for call')
}
Expand All @@ -270,7 +300,7 @@ export class SessionManager implements SapientSigner {
)

// Check if the last call is an increment usage call
const expectedIncrement = await this.prepareIncrement(wallet, chainId, payload.calls)
const expectedIncrement = await this._prepareIncrement(wallet, chainId, payload.calls, topology)
if (expectedIncrement) {
let actualIncrement: Payload.Call
if (
Expand Down Expand Up @@ -327,7 +357,7 @@ export class SessionManager implements SapientSigner {
// Perform encoding
const encodedSignature = SessionSignature.encodeSessionCallSignatures(
signatures,
await this.topology,
topology,
identitySigner,
explicitSigners,
implicitSigners,
Expand All @@ -346,23 +376,23 @@ export class SessionManager implements SapientSigner {
payload: Payload.Parented,
signature: SignatureTypes.SignatureOfSapientSignerLeaf,
): Promise<boolean> {
if (!Payload.isCalls(payload)) {
if (!Address.isEqual(wallet, this.wallet.address)) {
throw new Error('Wallet address mismatch')
}
if (!Payload.isCalls(payload) || payload.calls.length === 0) {
// Only calls are supported
return false
}

if (!this._provider) {
throw new Error('Provider not set')
}
//FIXME Test chain id
// const providerChainId = await this._provider.request({
// method: 'eth_chainId',
// })
// if (providerChainId !== Hex.fromNumber(chainId)) {
// throw new Error(
// `Provider chain id mismatch, expected ${Hex.fromNumber(chainId)} but got ${providerChainId}`,
// )
// }
// Test chain id
const providerChainId = await this._provider.request({
method: 'eth_chainId',
})
if (providerChainId !== Hex.fromNumber(chainId)) {
throw new Error(`Provider chain id mismatch, expected ${Hex.fromNumber(chainId)} but got ${providerChainId}`)
}

const encodedPayload = Payload.encodeSapient(chainId, payload)
const encodedCallData = AbiFunction.encodeData(Constants.RECOVER_SAPIENT_SIGNATURE, [
Expand Down
21 changes: 11 additions & 10 deletions packages/wallet/core/src/wallet.ts
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,6 @@ export class Wallet {
// If the latest configuration does not match the onchain configuration
// then we bundle the update into the transaction envelope
if (!options?.noConfigUpdate) {
const status = await this.getStatus(provider)
if (status.imageHash !== status.onChainImageHash) {
calls.push({
to: this.address,
Expand Down Expand Up @@ -402,7 +401,7 @@ export class Wallet {
factory,
factoryData,
},
...(await this.prepareBlankEnvelope(Number(chainId))),
...(await this.prepareBlankEnvelope(Number(chainId), status.configuration)),
}
}

Expand Down Expand Up @@ -461,15 +460,15 @@ export class Wallet {
}
}

const [chainId, nonce] = await Promise.all([
const [chainId, nonce, status] = await Promise.all([
provider.request({ method: 'eth_chainId' }),
this.getNonce(provider, space),
this.getStatus(provider),
])

// If the latest configuration does not match the onchain configuration
// then we bundle the update into the transaction envelope
if (!options?.noConfigUpdate) {
const status = await this.getStatus(provider)
if (status.imageHash !== status.onChainImageHash) {
calls.push({
to: this.address,
Expand All @@ -490,7 +489,7 @@ export class Wallet {
nonce,
calls,
},
...(await this.prepareBlankEnvelope(Number(chainId))),
...(await this.prepareBlankEnvelope(Number(chainId), status.configuration)),
}
}

Expand Down Expand Up @@ -597,13 +596,15 @@ export class Wallet {
return encoded
}

private async prepareBlankEnvelope(chainId: number) {
const status = await this.getStatus()

private async prepareBlankEnvelope(chainId: number, configuration?: Config.Config) {
if (!configuration) {
const status = await this.getStatus()
configuration = status.configuration
}
return {
wallet: this.address,
chainId: chainId,
configuration: status.configuration,
chainId,
configuration,
}
}
}
4 changes: 2 additions & 2 deletions packages/wallet/core/test/session-manager.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ for (const extension of ALL_EXTENSIONS) {
'should fail to sign with an expired explicit session',
async () => {
const provider = Provider.from(RpcTransport.fromHttp(LOCAL_RPC_URL))
const chainId = 0
const chainId = Number(await provider.request({ method: 'eth_chainId' }))

// Create unique identity and state provider for this test
const identityPrivateKey = Secp256k1.randomPrivateKey()
Expand Down Expand Up @@ -561,7 +561,7 @@ for (const extension of ALL_EXTENSIONS) {
}

// Sign the transaction
expect(sessionManager.signSapient(wallet.address, chainId, payload, imageHash)).rejects.toThrow(
await expect(sessionManager.signSapient(wallet.address, chainId, payload, imageHash)).rejects.toThrow(
'No signers match the topology',
)
},
Expand Down
47 changes: 29 additions & 18 deletions packages/wallet/dapp-client/src/ChainSessionManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -673,27 +673,22 @@ export class ChainSessionManager {
): Promise<void> {
if (!this.provider || !this.wallet)
throw new InitializationError('Manager core components not ready for explicit session.')
if (!this.sessionManager) throw new InitializationError('Main session manager is not initialized.')

const signerAddress = Address.fromPublicKey(Secp256k1.getPublicKey({ privateKey: pk }))

const maxRetries = allowRetries ? 3 : 1
let lastError: Error | null = null

for (let attempt = 1; attempt <= maxRetries; attempt++) {
try {
const tempManager = new Signers.SessionManager(this.wallet, {
sessionManagerAddress: Extensions.Rc3.sessions,
provider: this.provider,
})
const topology = await tempManager.getTopology()
const topology = await this.sessionManager.getTopology()

const signerAddress = Address.fromPublicKey(Secp256k1.getPublicKey({ privateKey: pk }))
const permissions = SessionConfig.getSessionPermissions(topology, signerAddress)

if (!permissions) {
throw new InitializationError(`Permissions not found for session key.`)
}

if (!this.sessionManager) throw new InitializationError('Main session manager is not initialized.')

const explicitSigner = new Signers.Session.Explicit(pk, permissions)
this.sessionManager = this.sessionManager.withExplicitSigner(explicitSigner)

Expand All @@ -711,9 +706,10 @@ export class ChainSessionManager {
return
} catch (err) {
lastError = err instanceof Error ? err : new Error(String(err))
if (attempt < maxRetries) {
await new Promise((resolve) => setTimeout(resolve, 1000 * attempt))
}
}
if (attempt < maxRetries) {
console.error('Explicit session initialization failed, retrying...')
await new Promise((resolve) => setTimeout(resolve, 1000 * attempt))
}
}
if (lastError)
Expand Down Expand Up @@ -755,13 +751,31 @@ export class ChainSessionManager {
}
}

/**
* Checks if the current session has a valid signer.
* @returns A promise that resolves to true if the session has a valid signer, false otherwise.
*/
async hasValidSigner(): Promise<boolean> {
if (!this.wallet || !this.sessionManager || !this.provider || !this.isInitialized) {
return false
}

const signerValidity = await this.sessionManager.listSignerValidity(this.chainId)
if (signerValidity.some((s) => s.isValid)) {
return true
}
// SessionSignerInvalidReason available here
return false
}

/**
* Fetches fee options for a set of transactions.
* @param wallet The wallet address to use for the fee options.
* @param calls The transactions to estimate fees for.
* @returns A promise that resolves with an array of fee options.
* @throws {FeeOptionError} If fetching fee options fails.
*/
async getFeeOptions(calls: Transaction[]): Promise<Relayer.FeeOption[]> {
async getFeeOptions(wallet: Address.Address, calls: Transaction[]): Promise<Relayer.FeeOption[]> {
const callsToSend = calls.map((tx) => ({
to: tx.to,
value: tx.value,
Expand All @@ -772,8 +786,7 @@ export class ChainSessionManager {
behaviorOnError: tx.behaviorOnError ?? ('revert' as const),
}))
try {
const signedCall = await this._buildAndSignCalls(callsToSend)
const feeOptions = await this.relayer.feeOptions(signedCall.to, this.chainId, callsToSend)
const feeOptions = await this.relayer.feeOptions(wallet, this.chainId, callsToSend)
return feeOptions.options
} catch (err) {
throw new FeeOptionError(`Failed to get fee options: ${err instanceof Error ? err.message : String(err)}`)
Expand Down Expand Up @@ -948,9 +961,7 @@ export class ChainSessionManager {
...envelope.payload,
parentWallets: [this.wallet.address],
}
const imageHash = await this.sessionManager.imageHash
if (imageHash === undefined) throw new SessionConfigError('Session manager image hash is undefined')

const imageHash = await this.sessionManager.getImageHash()
const signature = await this.sessionManager.signSapient(
this.wallet.address,
this.chainId,
Expand Down
Loading