Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(unstable_derive): trap atom methods #2741

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
100 changes: 51 additions & 49 deletions src/vanilla/store.ts
Original file line number Diff line number Diff line change
Expand Up @@ -210,15 +210,21 @@ const flushPending = (pending: Pending) => {
}
}

type GetAtomState = <Value>(
atom: Atom<Value>,
originAtomState?: AtomState,
) => AtomState<Value>

// internal & unstable type
type StoreArgs = readonly [
getAtomState: GetAtomState,
// possible other arguments in the future
getAtomState: <Value>(atom: Atom<Value>) => AtomState<Value>,
atomRead: <Value>(
atom: Atom<Value>,
...params: Parameters<Atom<Value>['read']>
) => Value,
atomWrite: <Value, Args extends unknown[], Result>(
atom: WritableAtom<Value, Args, Result>,
...params: Parameters<WritableAtom<Value, Args, Result>['write']>
) => Result,
atomOnMount: <Value, Args extends unknown[], Result>(
atom: WritableAtom<Value, Args, Result>,
setAtom: (...args: Args) => Result,
) => OnUnmount | void,
]

// for debugging purpose only
Expand All @@ -245,7 +251,12 @@ type Store = PrdStore | (PrdStore & DevStoreRev4)
export type INTERNAL_DevStoreRev4 = DevStoreRev4
export type INTERNAL_PrdStore = PrdStore

const buildStore = (getAtomState: StoreArgs[0]): Store => {
const buildStore = (
getAtomState: StoreArgs[0],
atomRead: StoreArgs[1],
atomWrite: StoreArgs[2],
atomOnMount: StoreArgs[3],
): Store => {
dai-shi marked this conversation as resolved.
Show resolved Hide resolved
// for debugging purpose only
let debugMountedAtoms: Set<AnyAtom>

Expand All @@ -264,11 +275,7 @@ const buildStore = (getAtomState: StoreArgs[0]): Store => {
if (isPromiseLike(valueOrPromise)) {
patchPromiseForCancelability(valueOrPromise)
for (const a of atomState.d.keys()) {
addPendingPromiseToDependency(
atom,
valueOrPromise,
getAtomState(a, atomState),
)
addPendingPromiseToDependency(atom, valueOrPromise, getAtomState(a))
}
atomState.v = valueOrPromise
delete atomState.e
Expand All @@ -287,9 +294,9 @@ const buildStore = (getAtomState: StoreArgs[0]): Store => {
const readAtomState = <Value>(
pending: Pending | undefined,
atom: Atom<Value>,
atomState: AtomState<Value>,
force?: (a: AnyAtom) => boolean,
): AtomState<Value> => {
const atomState = getAtomState(atom)
// See if we can skip recomputing this atom.
if (!force?.(atom) && isAtomStateInitialized(atomState)) {
// If the atom is mounted, we can use the cache.
Expand All @@ -304,8 +311,7 @@ const buildStore = (getAtomState: StoreArgs[0]): Store => {
([a, n]) =>
// Recursively, read the atom state of the dependency, and
// check if the atom epoch number is unchanged
readAtomState(pending, a, getAtomState(a, atomState), force).n ===
n,
readAtomState(pending, a, force).n === n,
)
) {
return atomState
Expand All @@ -316,7 +322,7 @@ const buildStore = (getAtomState: StoreArgs[0]): Store => {
let isSync = true
const getter: Getter = <V>(a: Atom<V>) => {
if (isSelfAtom(atom, a)) {
const aState = getAtomState(a, atomState)
const aState = getAtomState(a)
if (!isAtomStateInitialized(aState)) {
if (hasInitialValue(a)) {
setAtomStateValueOrPromise(a, aState, a.init)
Expand All @@ -328,12 +334,7 @@ const buildStore = (getAtomState: StoreArgs[0]): Store => {
return returnAtomValue(aState)
}
// a !== atom
const aState = readAtomState(
pending,
a,
getAtomState(a, atomState),
force,
)
const aState = readAtomState(pending, a, force)
if (isSync) {
addDependency(pending, atom, atomState, a, aState)
} else {
Expand Down Expand Up @@ -374,7 +375,7 @@ const buildStore = (getAtomState: StoreArgs[0]): Store => {
},
}
try {
const valueOrPromise = atom.read(getter, options as never)
const valueOrPromise = atomRead(atom, getter, options as never)
setAtomStateValueOrPromise(atom, atomState, valueOrPromise)
if (isPromiseLike(valueOrPromise)) {
valueOrPromise.onCancel?.(() => controller?.abort())
Expand All @@ -399,7 +400,7 @@ const buildStore = (getAtomState: StoreArgs[0]): Store => {
}

const readAtom = <Value>(atom: Atom<Value>): Value =>
returnAtomValue(readAtomState(undefined, atom, getAtomState(atom)))
returnAtomValue(readAtomState(undefined, atom))

const getDependents = <Value>(
pending: Pending,
Expand All @@ -408,16 +409,16 @@ const buildStore = (getAtomState: StoreArgs[0]): Store => {
): Map<AnyAtom, AtomState> => {
const dependents = new Map<AnyAtom, AtomState>()
for (const a of atomState.m?.t || []) {
dependents.set(a, getAtomState(a, atomState))
dependents.set(a, getAtomState(a))
}
for (const atomWithPendingPromise of atomState.p) {
dependents.set(
atomWithPendingPromise,
getAtomState(atomWithPendingPromise, atomState),
getAtomState(atomWithPendingPromise),
)
}
getPendingDependents(pending, atom)?.forEach((dependent) => {
dependents.set(dependent, getAtomState(dependent, atomState))
dependents.set(dependent, getAtomState(dependent))
})
return dependents
}
Expand Down Expand Up @@ -471,7 +472,7 @@ const buildStore = (getAtomState: StoreArgs[0]): Store => {
}
}
if (hasChangedDeps) {
readAtomState(pending, a, aState, isMarked)
readAtomState(pending, a, isMarked)
mountDependencies(pending, a, aState)
if (prevEpochNumber !== aState.n) {
addPendingAtom(pending, a, aState)
Expand All @@ -485,16 +486,15 @@ const buildStore = (getAtomState: StoreArgs[0]): Store => {
const writeAtomState = <Value, Args extends unknown[], Result>(
pending: Pending,
atom: WritableAtom<Value, Args, Result>,
atomState: AtomState<Value>,
...args: Args
): Result => {
const getter: Getter = <V>(a: Atom<V>) =>
returnAtomValue(readAtomState(pending, a, getAtomState(a, atomState)))
returnAtomValue(readAtomState(pending, a))
const setter: Setter = <V, As extends unknown[], R>(
a: WritableAtom<V, As, R>,
...args: As
) => {
const aState = getAtomState(a, atomState)
const aState = getAtomState(a)
let r: R | undefined
if (isSelfAtom(atom, a)) {
if (!hasInitialValue(a)) {
Expand All @@ -511,12 +511,12 @@ const buildStore = (getAtomState: StoreArgs[0]): Store => {
recomputeDependents(pending, a, aState)
}
} else {
r = writeAtomState(pending, a, aState, ...args) as R
r = writeAtomState(pending, a, ...args) as R
}
flushPending(pending)
return r as R
}
const result = atom.write(getter, setter, ...args)
const result = atomWrite(atom, getter, setter, ...args)
return result
}

Expand All @@ -525,7 +525,7 @@ const buildStore = (getAtomState: StoreArgs[0]): Store => {
...args: Args
): Result => {
const pending = createPending()
const result = writeAtomState(pending, atom, getAtomState(atom), ...args)
const result = writeAtomState(pending, atom, ...args)
flushPending(pending)
return result
}
Expand All @@ -538,15 +538,15 @@ const buildStore = (getAtomState: StoreArgs[0]): Store => {
if (atomState.m && !isPendingPromise(atomState.v)) {
for (const a of atomState.d.keys()) {
if (!atomState.m.d.has(a)) {
const aMounted = mountAtom(pending, a, getAtomState(a, atomState))
const aMounted = mountAtom(pending, a, getAtomState(a))
aMounted.t.add(atom)
atomState.m.d.add(a)
}
}
for (const a of atomState.m.d || []) {
if (!atomState.d.has(a)) {
atomState.m.d.delete(a)
const aMounted = unmountAtom(pending, a, getAtomState(a, atomState))
const aMounted = unmountAtom(pending, a, getAtomState(a))
aMounted?.t.delete(atom)
}
}
Expand All @@ -560,10 +560,10 @@ const buildStore = (getAtomState: StoreArgs[0]): Store => {
): Mounted => {
if (!atomState.m) {
// recompute atom state
readAtomState(pending, atom, atomState)
readAtomState(pending, atom)
// mount dependencies first
for (const a of atomState.d.keys()) {
const aMounted = mountAtom(pending, a, getAtomState(a, atomState))
const aMounted = mountAtom(pending, a, getAtomState(a))
aMounted.t.add(atom)
}
// mount self
Expand All @@ -575,12 +575,11 @@ const buildStore = (getAtomState: StoreArgs[0]): Store => {
if (import.meta.env?.MODE !== 'production') {
debugMountedAtoms.add(atom)
}
if (isActuallyWritableAtom(atom) && atom.onMount) {
if (isActuallyWritableAtom(atom)) {
const mounted = atomState.m
const { onMount } = atom
addPendingFunction(pending, () => {
const onUnmount = onMount((...args) =>
writeAtomState(pending, atom, atomState, ...args),
const onUnmount = atomOnMount(atom, (...args) =>
writeAtomState(pending, atom, ...args),
)
if (onUnmount) {
mounted.u = onUnmount
Expand All @@ -599,9 +598,7 @@ const buildStore = (getAtomState: StoreArgs[0]): Store => {
if (
atomState.m &&
!atomState.m.l.size &&
!Array.from(atomState.m.t).some((a) =>
getAtomState(a, atomState).m?.d.has(atom),
)
!Array.from(atomState.m.t).some((a) => getAtomState(a).m?.d.has(atom))
) {
// unmount self
const onUnmount = atomState.m.u
Expand All @@ -614,7 +611,7 @@ const buildStore = (getAtomState: StoreArgs[0]): Store => {
}
// unmount dependencies
for (const a of atomState.d.keys()) {
const aMounted = unmountAtom(pending, a, getAtomState(a, atomState))
const aMounted = unmountAtom(pending, a, getAtomState(a))
aMounted?.t.delete(atom)
}
return undefined
Expand All @@ -638,7 +635,7 @@ const buildStore = (getAtomState: StoreArgs[0]): Store => {
}

const unstable_derive = (fn: (...args: StoreArgs) => StoreArgs) =>
buildStore(...fn(getAtomState))
buildStore(...fn(getAtomState, atomRead, atomWrite, atomOnMount))

const store: Store = {
get: readAtom,
Expand Down Expand Up @@ -693,7 +690,12 @@ export const createStore = (): Store => {
}
return atomState
}
return buildStore(getAtomState)
return buildStore(
getAtomState,
(atom, ...params) => atom.read(...params),
(atom, ...params) => atom.write(...params),
(atom, ...params) => atom.onMount?.(...params),
)
}

let defaultStore: Store | undefined
Expand Down
Loading